Blame tcpprotocol.cpp

541903
9bbba5
#include <cassert>
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstring>
#include <mutex>
#include <string>
#include <thread>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>

#include "protocol.h"
#include "server.h"
#include "connection.h"
#include "tcpprotocol.h"
541903
541903
541903
541903
// Wraps an already created TCP socket descriptor: enables TCP_NODELAY,
// switches the descriptor to non-blocking mode and then signals the base
// class that construction has finished.
TcpSocket::TcpSocket(
	Protocol &protocol,
	const Connection::Handle &connection,
	const Server::Handle &server,
	const Address &address,
	int sockId
):
	Socket(protocol, connection, server, address),
	sockId(sockId)
{
	// send small packets immediately instead of coalescing them
	const int enable = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));

	// the worker thread only ever performs non-blocking I/O on this descriptor
	const int currentFlags = fcntl(sockId, F_GETFL, 0);
	fcntl(sockId, F_SETFL, currentFlags | O_NONBLOCK);

	onConstructionComplete();
}
541903
541903
541903
// Creates the epoll instance and the non-blocking control eventfd, registers
// the eventfd for readability, and launches the worker thread (threadRun).
TcpProtocol::TcpProtocol():
	wantRun(false), thread(), epollId(), eventsId()
{
	epollId = ::epoll_create(32);
	eventsId = ::eventfd(0, EFD_NONBLOCK);

	// data.ptr is left null: threadProcessEvents() skips entries without a socket
	struct epoll_event controlEvent = {};
	controlEvent.events = EPOLLIN;
	epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &controlEvent);

	// the worker starts in its idle state and waits for control commands
	thread = new std::thread(&TcpProtocol::threadRun, this);
}
5a39be
5a39be
2499ad
// Stops the protocol, shuts down and joins the worker thread, then releases
// the control eventfd and the epoll descriptor.
TcpProtocol::~TcpProtocol() {
	stop();

	// threadRun() leaves its loop when it receives command 1 while idle
	controlCommand(1);
	thread->join();
	delete thread;

	struct epoll_event unusedEvent = {};
	epoll_ctl(epollId, EPOLL_CTL_DEL, eventsId, &unusedEvent);

	::close(eventsId);
	eventsId = 0;

	::close(epollId);
	epollId = 0;
}
541903
541903
541903
// Parses a decimal TCP port in the range 1..65535. Only plain digits are
// accepted: no sign, no whitespace, no trailing characters.
// Returns false (leaving `port` untouched) for anything else.
static bool tcpParsePort(const char *text, unsigned short &port) {
	if (!text || !*text) return false;
	unsigned long value = 0;
	for(const char *c = text; *c; ++c) {
		if (*c < '0' || *c > '9') return false;
		value = value*10 + (unsigned long)(*c - '0');
		if (value > 65535) return false; // stop before it can overflow
	}
	if (!value) return false;
	port = (unsigned short)value;
	return true;
}

// Resolves `host` to an IPv4 address (network byte order). A strict
// dotted-quad literal is converted directly; anything else is looked up
// with getaddrinfo(), restricted to AF_INET.
// Returns false if no IPv4 address was found.
static bool tcpResolveHost(const char *host, struct in_addr &out) {
	if (1 == inet_pton(AF_INET, host, &out))
		return true;

	struct addrinfo hints = {};
	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_STREAM;
	struct addrinfo *list = nullptr;
	if (0 != getaddrinfo(host, nullptr, &hints, &list) || !list)
		return false;

	bool found = false;
	for(const struct addrinfo *ai = list; ai; ai = ai->ai_next) {
		if ( ai->ai_family == AF_INET
		  && ai->ai_addr
		  && ai->ai_addrlen >= sizeof(struct sockaddr_in) )
		{
			out = ((const struct sockaddr_in*)ai->ai_addr)->sin_addr;
			found = true;
			break;
		}
	}
	freeaddrinfo(list);
	return found;
}

// Converts a COMMON_STRING address into a SOCKET address holding a
// sockaddr_in. Accepted forms:
//   "host:port"  host is an IPv4 literal or a host name
//   ":port"      or "port" (digits only) - binds/connects to INADDR_ANY
// Returns ERR_NONE on success (or if the address is already a SOCKET),
// ERR_ADDRESS_INCORRECT for malformed input or a missing/invalid port, and
// ERR_ADDRESS_NOT_FOUND if the host name cannot be resolved.
ErrorCode TcpProtocol::resolve(Address &address) {
	if (address.type == Address::SOCKET)
		return ERR_NONE;

	if ( address.type != Address::COMMON_STRING
	  || address.text.empty() )
		return ERR_ADDRESS_INCORRECT;

	// split to host and port
	const std::string &text = address.text;
	std::string host;
	std::string port;
	const std::string::size_type colon = text.find(':');
	if (colon != std::string::npos) {
		host = text.substr(0, colon);
		port = text.substr(colon + 1);
	} else
	if (text.find_first_not_of("0123456789") == std::string::npos) {
		port = text;
	} else {
		host = text;
	}

	// validate the port first: it is cheap, the host lookup may block
	unsigned short portNum = 0;
	if (!tcpParsePort(port.c_str(), portNum))
		return ERR_ADDRESS_INCORRECT;

	// get ipv4 (zero, i.e. INADDR_ANY, when no host was given)
	struct in_addr ip = {};
	if (!host.empty() && !tcpResolveHost(host.c_str(), ip))
		return ERR_ADDRESS_NOT_FOUND;

	address.set(Address::SOCKET);
	address.size = sizeof(struct sockaddr_in);
	struct sockaddr_in &addr = address.as<struct sockaddr_in>();
	addr = sockaddr_in(); // clear sin_zero and any stale bytes
	addr.sin_family = AF_INET;
	addr.sin_addr = ip;
	addr.sin_port = htons(portNum);

	return ERR_NONE;
}
541903
541903
2499ad
void TcpProtocol::controlCommand(unsigned long long cmd)
2499ad
	{ ::write(eventsId, &cmd, sizeof(cmd)); }
2499ad
2499ad
2499ad
// Worker-thread handler for a readable connection socket. It repeatedly takes
// the connection's read lock and fills its buffer with recv() until:
//   - the lock is unavailable,
//   - the socket has no more data,
//   - the socket is marked for removal, or
//   - an error occurs.
// A hard socket error closes the connection with ERR_CONNECTION_LOST.
// An orderly shutdown by the peer (recv() == 0) closes it without an error.
// In that case the socket would otherwise stay readable forever under
// level-triggered epoll and the worker would spin.
void TcpProtocol::threadConnectionRead(TcpSocket *socket, const Connection::Handle &connection) {
	int sockId = socket->sockId;
	ErrorCode err = ERR_NONE;
	bool peerClosed = false;
	bool ready = true;
	while(ready) {
		if (socketToRemove(socket)) return;
		Connection::ReadLock readLock(*connection);
		if (!readLock) break;
		while(true) {
			size_t remains = readLock->remains();
			if (!remains) break;
			if (remains > INT_MAX) remains = INT_MAX; // keep the result within int
			int res = (int)::recv(sockId, readLock->current(), remains, MSG_DONTWAIT);
			if (res < 0) {
				// recv() reports the failure reason in errno, not in its return value
				if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR)
					err = ERR_CONNECTION_LOST;
				ready = false;
				break;
			}
			if (res == 0) {
				peerClosed = true; // remains > 0 here, so 0 means end of stream
				ready = false;
				break;
			}
			readLock->processed(res);
		}
	}
	// close only after the read lock has been released
	if (err) connection->close(err);
	else
	if (peerClosed) connection->close(ERR_NONE);
}
2499ad
2499ad
2499ad
// Worker-thread handler for a writable connection socket. It repeatedly takes
// the connection's write lock and sends the pending data with send() until:
//   - the lock is unavailable,
//   - the socket cannot accept more data,
//   - the socket is marked for removal, or
//   - an error occurs.
// A hard socket error closes the connection with ERR_CONNECTION_LOST.
void TcpProtocol::threadConnectionWrite(TcpSocket *socket, const Connection::Handle &connection) {
	// MSG_NOSIGNAL: a peer that has closed must yield EPIPE, not a
	// process-killing SIGPIPE
#ifdef MSG_NOSIGNAL
	const int sendFlags = MSG_DONTWAIT | MSG_NOSIGNAL;
#else
	const int sendFlags = MSG_DONTWAIT;
#endif
	int sockId = socket->sockId;
	ErrorCode err = ERR_NONE;
	bool ready = true;
	while(ready) {
		if (socketToRemove(socket)) return;
		Connection::WriteLock writeLock(*connection);
		if (!writeLock) break;
		while(true) {
			size_t remains = writeLock->remains();
			if (!remains) break;
			if (remains > INT_MAX) remains = INT_MAX; // keep the result within int
			int res = (int)::send(sockId, writeLock->current(), remains, sendFlags);
			if (res <= 0) {
				// send() reports the failure reason in errno, not in its return value
				if (res < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR)
					err = ERR_CONNECTION_LOST;
				ready = false;
				break;
			}
			writeLock->processed(res);
		}
	}
	// close only after the write lock has been released
	if (err) connection->close(err);
}
2499ad
2499ad
2499ad
void TcpProtocol::threadProcessEvents(void *eventsPtr, int count) {
2499ad
	struct epoll_event *events = (struct epoll_event*)eventsPtr;
9bbba5
	
2499ad
	// process socket events
2499ad
	for(int i = 0; i < count; ++i) {
2499ad
		if (TcpSocket *socket = (TcpSocket*)events[i].data.ptr) {
2499ad
			if (socketToRemove(socket)) continue;
9bbba5
			
2499ad
			unsigned int e = events[i].events;
2499ad
			int sockId = socket->sockId;
2499ad
			const Connection::Handle &connection = socket->connection;
2499ad
			const Server::Handle &server = socket->server;
9bbba5
			
2499ad
			if (e & EPOLLIN) {
2499ad
				if (connection) {
2499ad
					threadConnectionRead(socket, connection);
2499ad
				} else
2499ad
				if (server) {
2499ad
					Address address;
2499ad
					address.type = Address::SOCKET;
2499ad
					socklen_t len = sizeof(sockaddr_in);
2499ad
					int res = ::accept(sockId, (struct sockaddr*)address.data, &len);
2499ad
					assert(len < sizeof(address.data));
2499ad
					if (res >= 0) {
2499ad
						address.size = len;
2499ad
						Connection::Handle conn = serverConnect(server, address);
2499ad
						if (conn && !socketToRemove(socket)) {
2499ad
							TcpSocket *s = new TcpSocket(*this, conn, server, address, res);
2499ad
							if (ERR_NONE != connectionOpen(conn, s))
2499ad
								{ s->finalize(); serverDisconnect(server, conn); }
9bbba5
						}
2499ad
					} else
2499ad
					if (res != EAGAIN) {
2499ad
						server->stop(ERR_SERVER_LISTENING_LOST);
9bbba5
					}
9bbba5
				}
9bbba5
			}
9bbba5
			
2499ad
			if (socketToRemove(socket)) continue;
2499ad
			
2499ad
			if (e & EPOLLOUT) {
2499ad
				if (connection)
2499ad
					threadConnectionWrite(socket, connection);
9bbba5
			}
9bbba5
			
2499ad
			if (socketToRemove(socket)) continue;
2499ad
			
2499ad
			if (e & EPOLLHUP) {
2499ad
				if (connection) {
2499ad
					connection->close(e & EPOLLERR ? ERR_CONNECTION_LOST : 0);
2499ad
				} else
2499ad
				if (server) {
2499ad
					server->stop(e & EPOLLERR ? ERR_SERVER_LISTENING_LOST : 0);
2499ad
				}
2499ad
			}
2499ad
		}
2499ad
	}
2499ad
	
2499ad
	// update epoll
2499ad
	Socket *socket = socketChFirst();
2499ad
	while(socket) {
2499ad
		int sockId = ((TcpSocket*)socket)->sockId;
2499ad
		struct epoll_event event = {};
2499ad
		unsigned int flags = socketFlags(socket);
2499ad
		
2499ad
		//if (!(flags & Socket::REMOVED) && flags & Socket::WANTREAD) {
2499ad
		//	threadConnectionRead((TcpSocket*)socket, socket->connection);
2499ad
		//	flags = socketFlags(socket);
2499ad
		//}
9bbba5
		
2499ad
		//if (!(flags & Socket::REMOVED) && flags & Socket::WANTWRITE) {
2499ad
		//	threadConnectionWrite((TcpSocket*)socket, socket->connection);
2499ad
		//	flags = socketFlags(socket);
2499ad
		//}
2499ad
		
2499ad
		if (flags & Socket::REMOVED) {
2499ad
			epoll_ctl(epollId, EPOLL_CTL_DEL, sockId, &event);
2499ad
			socket = socketRemoveChanged(socket);
2499ad
		} else {
2499ad
			event.data.ptr = socket;
2499ad
			if (flags & Socket::WANTREAD) event.events |= EPOLLIN;
2499ad
			if (flags & Socket::WANTWRITE) event.events |= EPOLLOUT;
2499ad
			epoll_ctl(epollId, (flags & Socket::CREATED) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, sockId, &event);
2499ad
			socket = socketSetUnchanged(socket);
2499ad
		}
9bbba5
	}
2499ad
}
2499ad
2499ad
2499ad
void TcpProtocol::threadRun() {
2499ad
	const int maxCount = 16;
2499ad
	struct epoll_event events[maxCount] = {};
2499ad
	int count = 0;
9bbba5
	
2499ad
	bool idle = true;
2499ad
	mutexIdle.lock();
2499ad
	while(true) {
2499ad
		unsigned long long num = 0;
2499ad
		while(::read(eventsId, &num, sizeof(num)) > 0 && !num);
2499ad
		
2499ad
		if (num == 3 && idle)
2499ad
			{ idle = false; mutexRun.lock(); mutexIdle.unlock(); }
2499ad
		
2499ad
		if (idle) {
2499ad
			if (num == 1) break;
2499ad
		} else {
2499ad
			Lock lock(mutex);
2499ad
			threadProcessEvents(events, count);
2499ad
		}
2499ad
		
2499ad
		if (num == 2 && !idle)
2499ad
			{ idle = true; mutexIdle.lock(); mutexRun.unlock(); }
2499ad
		
2499ad
		count = epoll_wait(epollId, events, maxCount, 100);
541903
	}
2499ad
	mutexIdle.unlock();
541903
}
541903
541903
2499ad
646228
// Opens an outgoing TCP connection to a resolved SOCKET address. The connect()
// is performed while the socket is still blocking; TcpSocket then switches it
// to non-blocking mode.
// Returns:
//   - ERR_ADDRESS_INCORRECT for a non-SOCKET address;
//   - ERR_CONNECTION_FAILED if the socket cannot be created or connected;
//   - otherwise the result of connectionOpen().
ErrorCode TcpProtocol::onConnect(const Connection::Handle &connection, const Address &address) {
	if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT;

	// qualified: the local named `socket` below must not be confused with the syscall
	int sockId = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sockId < 0) return ERR_CONNECTION_FAILED;

	if (0 != ::connect(sockId, &address.as<struct sockaddr>(), address.size)) {
		::close(sockId);
		return ERR_CONNECTION_FAILED;
	}

	TcpSocket *socket = new TcpSocket(*this, connection, nullptr, address, sockId);
	ErrorCode errorCode = connectionOpen(connection, socket);
	if (errorCode) socket->finalize();
	return errorCode;
}
541903
541903
646228
// Creates a listening TCP socket bound to a resolved SOCKET address and hands
// it to the server.
// Returns:
//   - ERR_ADDRESS_INCORRECT for a non-SOCKET address;
//   - ERR_SERVER_LISTENING_FAILED if the socket cannot be created, bound, or
//     put into the listening state;
//   - otherwise the result of serverStart().
ErrorCode TcpProtocol::onListen(const Server::Handle &server, const Address &address) {
	if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT;

	int sockId = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sockId < 0) return ERR_SERVER_LISTENING_FAILED;

	// allow restarting on a port whose old connections are still in TIME_WAIT
	int reuse = 1;
	setsockopt(sockId, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));

	// listen() is required: without it accept() fails and the socket never
	// reports incoming connections
	if ( 0 != ::bind(sockId, &address.as<struct sockaddr>(), address.size)
	  || 0 != ::listen(sockId, SOMAXCONN) )
	{
		::close(sockId);
		return ERR_SERVER_LISTENING_FAILED;
	}

	TcpSocket *socket = new TcpSocket(*this, nullptr, server, address, sockId);
	ErrorCode errorCode = serverStart(server, socket);
	if (errorCode) socket->finalize();
	return errorCode;
}
541903
541903
646228
// Switches the worker thread from idle to running. Command 3 makes threadRun()
// take mutexRun and release mutexIdle, so acquiring mutexIdle here waits for
// that transition.
// NOTE(review): if the worker has not yet taken mutexIdle at startup, this
// can return before the transition happens — confirm this is acceptable.
ErrorCode TcpProtocol::onStart() {
	controlCommand(3);
	mutexIdle.lock();
	mutexIdle.unlock();
	return ERR_NONE;
}
541903
541903
9bbba5
// Switches the worker thread from running back to idle and waits for it.
// The calling thread's hold on `mutex` (counted via mutex.count(), presumably
// a recursive lock) is fully released while waiting. Otherwise the worker
// could not finish its current threadProcessEvents() pass, which runs under
// `mutex`. The original lock depth is restored afterwards.
void TcpProtocol::onStop(ErrorCode) {
	const int depth = mutex.count();
	for(int n = depth; n > 0; --n) mutex.unlock();

	// command 2: the worker releases mutexRun once it is idle again
	controlCommand(2);
	mutexRun.lock();
	mutexRun.unlock();

	for(int n = depth; n > 0; --n) mutex.lock();
}
541903
9bbba5
2499ad
// Notifies the worker that the socket set changed. The changes are applied by
// the "update epoll" pass of threadProcessEvents().
// NOTE(review): writing 0 to an eventfd leaves its counter unchanged, so it
// presumably does not wake epoll_wait(). Changes are then picked up when the
// 100 ms timeout in threadRun() expires — confirm.
void TcpProtocol::onSocketChanged(Socket*) {
	const unsigned long long socketsChanged = 0;
	controlCommand(socketsChanged);
}
9bbba5
9bbba5