Blob Blame Raw

#include <cassert>
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>

#include "protocol.h"
#include "server.h"
#include "connection.h"
#include "tcpprotocol.h"



// Wraps an already created TCP descriptor `sockId` into a Socket bound to
// either a connection or a server of `protocol`.
// The descriptor is tuned before the socket is announced as complete:
//   - TCP_NODELAY disables Nagle's algorithm, so small writes go out at once;
//   - O_NONBLOCK is added, because the worker services sockets with
//     non-blocking recv/send/accept calls.
// Return values of setsockopt/fcntl are not checked.
TcpSocket::TcpSocket(
	Protocol &protocol,
	const Connection::Handle &connection,
	const Server::Handle &server,
	const Address &address,
	int sockId
):
	Socket(protocol, connection, server, address),
	sockId(sockId)
{
	const int enableNoDelay = 1;
	::setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &enableNoDelay, sizeof(enableNoDelay));
	
	const int currentFlags = ::fcntl(sockId, F_GETFL, 0);
	::fcntl(sockId, F_SETFL, currentFlags | O_NONBLOCK);
	
	onConstructionComplete();
}


// Creates an idle protocol: no worker thread and no epoll/eventfd
// descriptors until onStart() runs (0 is the "no descriptor" value that
// onStart()/onStop() also use).
TcpProtocol::TcpProtocol():
	thread(nullptr),
	epollId(0),
	eventsId(0)
{ }


// Stops the protocol before it is destroyed.
// NOTE(review): presumably Protocol::stop() ends up in onStop(), which joins
// the worker and closes the descriptors — confirm in the base class.
TcpProtocol::~TcpProtocol() {
	stop();
}


// Reads the run of decimal digits starting at `cursor` into `value` and
// advances `cursor` to the first non-digit character.
// Fails (leaving `cursor` and `value` untouched) when there are no digits or
// the number exceeds `maxValue`. The limit is checked after every digit, so
// with the small limits used here the accumulator cannot overflow.
static bool tcpReadDecimal(const char *&cursor, unsigned long maxValue, unsigned long &value) {
	const char *c = cursor;
	unsigned long result = 0;
	for(; *c >= '0' && *c <= '9'; ++c) {
		result = result*10 + (unsigned long)(*c - '0');
		if (result > maxValue) return false;
	}
	if (c == cursor) return false;
	cursor = c;
	value = result;
	return true;
}


// Parses strict dotted-quad IPv4 notation "a.b.c.d" (each part 0..255,
// nothing before or after it) into `ip`. On failure `ip` is left untouched.
static bool tcpParseIpv4(const char *text, unsigned char (&ip)[4]) {
	unsigned char parsed[4] = {};
	const char *c = text;
	for(int i = 0; i < 4; ++i) {
		if (i > 0) {
			if (*c != '.') return false;
			++c;
		}
		unsigned long part = 0;
		if (!tcpReadDecimal(c, 255, part)) return false;
		parsed[i] = (unsigned char)part;
	}
	if (*c) return false;
	memcpy(ip, parsed, sizeof(parsed));
	return true;
}


// Converts a textual address into an AF_INET socket address.
//
// Addresses already of type SOCKET are accepted unchanged. COMMON_STRING
// addresses may be:
//   "host:port" - host is a dotted-quad literal or a name resolved with
//                 gethostbyname();
//   "port"      - digits only; the IPv4 address stays 0.0.0.0 (INADDR_ANY);
//   "host"      - rejected, because the port is mandatory.
// The port must be a plain decimal number in 1..65535.
//
// Returns ERR_ADDRESS_INCORRECT for malformed input (checked before any name
// lookup is attempted) and ERR_ADDRESS_NOT_FOUND when the host cannot be
// resolved to an IPv4 address.
// NOTE: gethostbyname() uses static storage and is not thread-safe.
ErrorCode TcpProtocol::resolve(Address &address) {
	if (address.type == Address::SOCKET)
		return ERR_NONE;
	
	if ( address.type != Address::COMMON_STRING
	  || address.text.empty() )
		return ERR_ADDRESS_INCORRECT;
	
	// split to host and port
	const char *addrString = address.text.c_str();
	const char *colon = addrString;
	while(*colon && *colon != ':') ++colon;
	
	std::string hostBuf;
	const char *host = nullptr;
	const char *port = nullptr;
	if (*colon) {
		hostBuf.assign(addrString, colon);
		host = hostBuf.c_str();
		port = colon + 1;
	} else {
		bool allDigits = true;
		for(const char *c = addrString; *c && allDigits; ++c)
			if (*c < '0' || *c > '9') allDigits = false;
		(allDigits ? port : host) = addrString;
	}
	
	// parse port: digits only (no sign, no trailing text), 1..65535
	unsigned long portNum = 0;
	if (!port)
		return ERR_ADDRESS_INCORRECT;
	const char *portEnd = port;
	if ( !tcpReadDecimal(portEnd, 65535, portNum)
	  || *portEnd
	  || !portNum )
		return ERR_ADDRESS_INCORRECT;
	
	// get ipv4
	unsigned char ip[4] = {};
	if (host && !tcpParseIpv4(host, ip)) {
		// not a literal dotted quad: ask the resolver, accept IPv4 results only
		hostent *he = gethostbyname(host);
		if ( !he
		  || he->h_addrtype != AF_INET
		  || he->h_length != (int)sizeof(ip)
		  || !he->h_addr_list
		  || !he->h_addr_list[0] )
			return ERR_ADDRESS_NOT_FOUND;
		memcpy(ip, he->h_addr_list[0], sizeof(ip));
	}
	
	address.set(Address::SOCKET);
	sockaddr_in &addr = address.as<sockaddr_in>();
	// NOTE(review): Address::set may not clear the buffer; keep sin_zero zeroed
	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	memcpy(&addr.sin_addr, ip, sizeof(ip));
	addr.sin_port = htons((unsigned short)portNum);
	
	return ERR_NONE;
}


void TcpProtocol::threadRun() {
	const int maxCount = 1024;
	struct epoll_event events[maxCount] = {};
	int count = 0;
	
	while(true) {
		{ // locked scope
			Lock lock(mutex);
			
			{ // read wakeup signals
				void *ptr = nullptr;
				while(::read(eventsId, &ptr, sizeof(ptr)) > 0);
			}
			
			// process socket events
			for(int i = 0; i < count; ++i) {
				if (TcpSocket *socket = (TcpSocket*)events[i].data.ptr) {
					if (socketToRemove(socket)) continue;
					
					unsigned int e = events[i].events;
					int sockId = socket->sockId;
					const Connection::Handle &connection = socket->connection;
					const Server::Handle &server = socket->server;
					
					if (e & EPOLLIN) {
						if (connection) {
							ErrorCode err = ERR_NONE;
							bool ready = true;
							while(ready) {
								if (socketToRemove(socket)) break;
								Connection::ReadLock readLock(*connection);
								if (!readLock) break;
								while(true) {
									size_t remains = readLock->remains();
									if (!remains) break;
									if (remains > INT_MAX) remains = INT_MAX;
									int res = ::recv(sockId, readLock->current(), (int)remains, MSG_DONTWAIT);
									if (res <= 0) {
										if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST;
										ready = false;
										break;
									}
									readLock->processed(res);
								}
							}
							if (socketToRemove(socket)) continue;
							if (err) connection->close(err);
						} else
						if (server) {
							Address address;
							address.type = Address::SOCKET;
							socklen_t len = sizeof(sockaddr_in);
							int res = ::accept(sockId, (struct sockaddr*)address.data, &len);
							assert(len < sizeof(address.data));
							if (res >= 0) {
								address.size = len;
								Connection::Handle conn = serverConnect(server, address);
								if (conn && !socketToRemove(socket)) {
									TcpSocket *s = new TcpSocket(*this, conn, server, address, res);
									if (ERR_NONE != connectionOpen(conn, s))
										{ s->finalize(); serverDisconnect(server, conn); }
								}
							} else
							if (res != EAGAIN) {
								server->stop(ERR_SERVER_LISTENING_LOST);
							}
						}
					}
					
					if (socketToRemove(socket)) continue;
					
					if (e & EPOLLOUT) {
						if (connection) {
							ErrorCode err = ERR_NONE;
							bool ready = true;
							while(ready) {
								if (socketToRemove(socket)) break;
								Connection::WriteLock writeLock(*connection);
								if (!writeLock) break;
								while(true) {
									size_t remains = writeLock->remains();
									if (!remains) break;
									if (remains > INT_MAX) remains = INT_MAX;
									int res = ::send(sockId, writeLock->current(), (int)remains, MSG_DONTWAIT);
									if (res <= 0) {
										if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST;
										ready = false;
										break;
									}
									writeLock->processed(res);
								}
							}
							if (socketToRemove(socket)) continue;
							if (err) connection->close(err);
						}
					}
					
					if (socketToRemove(socket)) continue;
					
					if (e & EPOLLHUP) {
						if (connection) {
							connection->close(e & EPOLLERR ? ERR_CONNECTION_LOST : 0);
						} else
						if (server) {
							server->stop(e & EPOLLERR ? ERR_CONNECTION_LOST : 0);
						}
					}
				}
			}
			
			// update epoll
			Socket *socket = socketChFirst();
			while(socket) {
				int sockId = ((TcpSocket*)socket)->sockId;
				struct epoll_event event = {};
				if (socketToRemove(socket)) {
					epoll_ctl(epollId, EPOLL_CTL_DEL, sockId, &event);
					socket = socketRemoveChanged(socket);
				} else {
					unsigned int flags = socketFlags(socket);
					event.data.ptr = socket;
					if (flags & Socket::WANTREAD) event.events |= EPOLLIN;
					if (flags & Socket::WANTWRITE) event.events |= EPOLLOUT;
					epoll_ctl(epollId, EPOLL_CTL_MOD, sockId, &event);
					socket = socketSetUnchanged(socket);
				}
			}
			
			if (!isActive())
				break;
		} // end of locked ecope
		
		count = epoll_wait(epollId, events, count, 1000);
		if (count < 0) break;
	}
	
	std::unique_lock<Mutex> uniqlock(mutex);
	if (isActive()) {
		stop(ERR_TCP_WORKER_FAILED);
		while(isActive()) threadWaitCondition.wait(uniqlock);
	}
}


// Opens an outgoing TCP connection to `address` (must be of type SOCKET) and
// hands the new socket to `connection`.
// The connect() call itself is blocking; the socket becomes non-blocking in
// the TcpSocket constructor.
// Returns ERR_ADDRESS_INCORRECT for a non-socket address and
// ERR_CONNECTION_FAILED when the descriptor cannot be created or connected;
// otherwise the result of connectionOpen() (on failure the socket is
// finalized).
ErrorCode TcpProtocol::onConnect(const Connection::Handle &connection, const Address &address) {
	if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT;
	
	int sockId = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sockId < 0) return ERR_CONNECTION_FAILED;
	
	if (0 != ::connect(sockId, &address.as<struct sockaddr>(), address.size)) {
		::close(sockId);
		return ERR_CONNECTION_FAILED;
	}
	
	TcpSocket *socket = new TcpSocket(*this, connection, nullptr, address, sockId);
	ErrorCode errorCode = connectionOpen(connection, socket);
	if (errorCode) socket->finalize();
	return errorCode;
}


// Creates a listening TCP socket bound to `address` (must be of type SOCKET)
// and registers it with `server`; incoming clients are accepted by the
// worker thread.
// Returns ERR_ADDRESS_INCORRECT for a non-socket address and
// ERR_SERVER_LISTENING_FAILED when the descriptor cannot be created, bound or
// switched to listening mode; otherwise the result of serverStart() (on
// failure the socket is finalized).
ErrorCode TcpProtocol::onListen(const Server::Handle &server, const Address &address) {
	if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT;
	
	int sockId = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sockId < 0) return ERR_SERVER_LISTENING_FAILED;
	
	// accept() only works on a socket that has been put into listening mode
	if ( 0 != ::bind(sockId, &address.as<struct sockaddr>(), address.size)
	  || 0 != ::listen(sockId, SOMAXCONN) )
	{
		::close(sockId);
		return ERR_SERVER_LISTENING_FAILED;
	}
	
	TcpSocket *socket = new TcpSocket(*this, nullptr, server, address, sockId);
	ErrorCode errorCode = serverStart(server, socket);
	if (errorCode) socket->finalize();
	return errorCode;
}


// Creates the epoll instance and the non-blocking eventfd used to wake the
// worker, registers the eventfd in epoll, and starts the worker thread.
// The eventfd is registered with a zeroed event.data, i.e. a null pointer,
// which is how threadRun() tells wakeups apart from socket events.
// On any failure every descriptor created so far is closed, the members are
// reset to 0 and ERR_PROTOCOL_INITIALIZATION_FAILED is returned.
ErrorCode TcpProtocol::onStart() {
	epollId = ::epoll_create(32);
	if (epollId < 0) {
		epollId = 0;
		return ERR_PROTOCOL_INITIALIZATION_FAILED;
	}
	
	eventsId = ::eventfd(0, EFD_NONBLOCK);
	if (eventsId < 0) {
		::close(epollId);
		epollId = 0;
		eventsId = 0;
		return ERR_PROTOCOL_INITIALIZATION_FAILED;
	}
	
	struct epoll_event event = {};
	event.events = EPOLLIN;
	if (0 != epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event)) {
		// without this registration the worker could never be woken up
		::close(eventsId);
		eventsId = 0;
		::close(epollId);
		epollId = 0;
		return ERR_PROTOCOL_INITIALIZATION_FAILED;
	}
	
	thread = new std::thread(&TcpProtocol::threadRun, this);
	return ERR_NONE;
}


// Shuts the worker down and releases the descriptors created by onStart().
// The worker is woken through the eventfd (so its loop re-checks isActive())
// and through threadWaitCondition (where it waits after a failure), then it
// is joined. `thread` is reset so no dangling pointer remains afterwards.
// NOTE(review): joining would throw if this ever ran on the worker thread
// itself — confirm that Protocol::stop() never calls onStop() from there.
void TcpProtocol::onStop(ErrorCode) {
	onSocketChanged(nullptr);
	threadWaitCondition.notify_all();
	if (thread) {
		if (thread->joinable()) thread->join();
		delete thread;
		thread = nullptr;
	}
	
	struct epoll_event event = {};
	epoll_ctl(epollId, EPOLL_CTL_DEL, eventsId, &event);
	
	::close(eventsId);
	eventsId = 0;
	
	::close(epollId);
	epollId = 0;
}


// Wakes the worker thread so it re-applies changed sockets to epoll.
// An eventfd only becomes readable while its counter is non-zero, and it
// accepts writes of exactly 8 bytes, so 1 is added as a uint64_t (the old
// code wrote the value 0, which never woke the worker). The write is best
// effort: it can only fail with EAGAIN when the counter is already saturated,
// in which case the descriptor is readable anyway.
void TcpProtocol::onSocketChanged(Socket*) {
	const uint64_t one = 1;
	::write(eventsId, &one, sizeof(one));
}