Blob Blame Raw

#include <cerrno>
#include <cassert>
#include <cstring>
#include <climits>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <netinet/in.h>
       
#include "tcp.h"


// Upper bound, in milliseconds, on each poll() wait below; it bounds how long
// a blocking loop can go before re-evaluating its stop condition.
static constexpr int pollTimeoutMs = 100;


// Lets the out-of-line definitions below refer to the types declared in
// tcp.h (Connection, Server, Handler, ...) without a Tcp:: qualifier.
using namespace Tcp;


//ErrorCode Address::resolve(const char *addrString);


/// Creates a connection object in STATUS_NONE with no socket attached;
/// call open() to actually connect it.
Connection::Connection()
	: status(STATUS_NONE)
	, lastError(ERR_NONE)
	, sockId()
	, stopping(false)
{
}

/// Wraps an already-connected socket descriptor (e.g. one returned by
/// accept()) and starts in STATUS_OPEN. Because the status is open, the
/// descriptor is closed later by close() / the destructor.
///
/// @param sockId      Connected socket descriptor to adopt.
/// @param remoteAddr  Address of the peer on the other end of @p sockId.
Connection::Connection(SockId sockId, const Address &remoteAddr)
	: status(STATUS_OPEN)
	, lastError(ERR_NONE)
	, sockId(sockId)
	, remoteAddr(remoteAddr)
	, stopping(false)
{
}

/// Releases the socket, if one is still open, via close().
Connection::~Connection() {
	close();
}

/**
 * Connects to @p remoteAddr over TCP/IPv4.
 *
 * Any previously open socket is closed first. The connect() itself is
 * blocking; once it succeeds the socket is switched to non-blocking mode
 * for the poll()-driven read()/write() loops.
 *
 * @param remoteAddr  Peer to connect to; must satisfy isValid().
 * @return true on success. On failure returns false and leaves the object
 *         closed with lastError set to ERR_TCP_BAD_ADDRESS (invalid address)
 *         or ERR_TCP_CONNECTION_FAILED (socket creation or connect failed).
 */
bool Connection::open(const Address &remoteAddr) {
	close();
	
	this->remoteAddr = remoteAddr;
	if (!this->remoteAddr.isValid())
		{ close(ERR_TCP_BAD_ADDRESS); return false; }

	// Check the descriptor before marking the connection open, so a failed
	// socket() never reaches connect() or ::close().
	int sid = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sid < 0)
		{ close(ERR_TCP_CONNECTION_FAILED); return false; }
	sockId = sid;
	status = STATUS_OPEN;
	
	struct sockaddr_in addr = {};
	addr.sin_family = AF_INET;
	memcpy(&addr.sin_addr, remoteAddr.ip, sizeof(remoteAddr.ip));
	addr.sin_port = htons(remoteAddr.port);
	
	if (0 != connect(sockId, (sockaddr*)&addr, sizeof(addr)))
		{ close(ERR_TCP_CONNECTION_FAILED); return false; }
	
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
	return true;
}

/**
 * Closes the socket (only when the connection is in STATUS_OPEN) and resets
 * the object.
 *
 * @param errorCode  Reason for closing. A non-zero code leaves the object in
 *                   STATUS_LOST, zero leaves it in STATUS_NONE; either way
 *                   the code is stored as lastError.
 */
void Connection::close(ErrorCode errorCode) {
	const bool ownsSocket = (status == STATUS_OPEN);
	if (ownsSocket)
		::close(sockId);
	sockId = 0;

	if (errorCode)
		status = STATUS_LOST;
	else
		status = STATUS_NONE;
	lastError = errorCode;

	stopping = false;
}

/**
 * Reads exactly @p size bytes into @p data.
 *
 * Waits in poll() slices of pollTimeoutMs and keeps going only while
 * check() returns true. NOTE(review): check() is assumed to report "still
 * open and no stop requested"; its definition is not in this file.
 *
 * @param data  Destination buffer of at least @p size bytes. When null,
 *              nothing is read; the call only returns check() && size == 0.
 * @param size  Number of bytes to read.
 * @return true once all @p size bytes have been read. Returns false if check()
 *         fails first. Also returns false if the socket errors or the peer
 *         closes the connection; both of those cases close this object with
 *         ERR_TCP_CONNECTION_LOST.
 */
bool Connection::read(void *data, size_t size) {
	if (!data)
		return check() && size == 0;
	
	struct pollfd pfd = {};
	pfd.fd = sockId;
	pfd.events = POLLIN;
	
	while(check() && size) {
		int pr = poll(&pfd, 1, pollTimeoutMs);
		if (pr < 0) {
			// A signal interrupting poll() is not a connection failure.
			if (errno == EINTR)
				continue;
			close(ERR_TCP_CONNECTION_LOST);
			return false;
		}
		if (pr == 0)
			continue; // timeout: loop around to re-run check()
		
		// Cap the request so the byte count always fits in an int.
		int s = size > INT_MAX ? INT_MAX : (int)size;
		ssize_t r = ::recv(sockId, data, s, MSG_DONTWAIT);
		if (r < 0) {
			if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
				continue;
			close(ERR_TCP_CONNECTION_LOST);
			return false;
		}
		if (r == 0) {
			// Orderly shutdown by the peer: no more data will arrive, and
			// poll() would keep reporting POLLIN, so retrying would spin.
			close(ERR_TCP_CONNECTION_LOST);
			return false;
		}
		size -= (size_t)r;
		data = (char*)data + r;
	}

	return size == 0;
}

/**
 * Writes exactly @p size bytes from @p data.
 *
 * Waits in poll() slices of pollTimeoutMs and keeps going only while
 * check() returns true. NOTE(review): check() is assumed to report "still
 * open and no stop requested"; its definition is not in this file.
 *
 * @param data  Source buffer of at least @p size bytes. When null, nothing
 *              is sent; the call only returns check() && size == 0.
 * @param size  Number of bytes to send.
 * @return true once all @p size bytes have been handed to the kernel. Returns
 *         false if check() fails first. Also returns false on a socket error,
 *         in which case the object is closed with ERR_TCP_CONNECTION_LOST.
 */
bool Connection::write(const void *data, size_t size) {
	if (!data)
		return check() && size == 0;
	
	int sendFlags = MSG_DONTWAIT;
#ifdef MSG_NOSIGNAL
	// Without this, sending to a socket whose peer has gone away raises
	// SIGPIPE, whose default action terminates the whole process. With it,
	// send() fails with EPIPE and we report a lost connection instead.
	sendFlags |= MSG_NOSIGNAL;
#endif
	
	struct pollfd pfd = {};
	pfd.fd = sockId;
	pfd.events = POLLOUT;
	
	while(check() && size) {
		int pr = poll(&pfd, 1, pollTimeoutMs);
		if (pr < 0) {
			// A signal interrupting poll() is not a connection failure.
			if (errno == EINTR)
				continue;
			close(ERR_TCP_CONNECTION_LOST);
			return false;
		}
		if (pr == 0)
			continue; // timeout: loop around to re-run check()
		
		// Cap the request so the byte count always fits in an int.
		int s = size > INT_MAX ? INT_MAX : (int)size;
		ssize_t r = ::send(sockId, data, s, sendFlags);
		if (r < 0) {
			if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
				continue;
			close(ERR_TCP_CONNECTION_LOST);
			return false;
		}
		size -= (size_t)r;
		data = (const char*)data + r;
	}

	return size == 0;
}


// Out-of-line, empty definition of Handler's destructor.
Handler::~Handler()
{
}


/// Creates an idle server: status STATUS_NONE, no handler and no listening
/// thread. Call start() to begin accepting connections.
Server::Server()
	: status(STATUS_NONE)
	, lastError(ERR_NONE)
	, sockId()
	, localAddr()
	, handler()
	, thread()
	, stopping(false)
{
}

/// Shuts down the listening thread, if one is running, via stop().
Server::~Server() {
	stop();
}


void Server::listen() {
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);

	struct sockaddr_in addr = {};
	Address remoteAddr;
	
	struct pollfd pfd = {};
	pfd.fd = sockId;
	pfd.events = POLLIN;
	
	status = STATUS_OPEN;
	while(!stopping) {
		int pr = poll(&pfd, 1, pollTimeoutMs);
		if (pr < 0)
			{ lastError = ERR_TCP_LISTEN_LOST; break; }
		if (pr > 0 && (pr & POLLIN)) {
			socklen_t addrlen = sizeof(addr);
			int sid = ::accept(sockId, (sockaddr*)&addr, &addrlen);
			if (sid < 0) {
				if (errno != EAGAIN)
					{ lastError = ERR_TCP_LISTEN_LOST; break; }
				continue;
			}
			
			memcpy(remoteAddr.ip, &addr.sin_addr, sizeof(remoteAddr.ip));
			remoteAddr.port = ntohs(addr.sin_port);
			
			ConnDesc desc = {};
			desc.connection = new Connection(sid, remoteAddr);
			desc.thread = new std::thread(
				&Handler::handleTcpConnection, handler, std::ref(*desc.connection));
			connections.push_back(desc);
		}
	}
	status = lastError ? STATUS_LOST : STATUS_CLOSED;
	
	for(ConnList::iterator i = connections.begin(); i != connections.end(); ++i)
		i->connection->stop();
	for(ConnList::iterator i = connections.begin(); i != connections.end(); ++i) {
		i->thread->join();
		delete i->thread;
		delete i->connection;
	}
	
	stopping = false;
}

bool Server::start(const Address &localAddr, Handler *handler) {
	stop();
	
	this->localAddr = localAddr;
	this->handler = handler;
	
	if (!this->localAddr.isValid())
		{ status = STATUS_LOST; lastError = ERR_TCP_BAD_ADDRESS; return false; }
	if (!this->handler)
		{ status = STATUS_LOST; lastError = ERR_TCP_BAD_HANDLER; return false; }
	
	sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	
	struct sockaddr_in addr = {};
	addr.sin_family = AF_INET;
	memcpy(&addr.sin_addr, localAddr.ip, sizeof(localAddr.ip));
	addr.sin_port = htons(localAddr.port);
	
	if ( 0 != ::bind(sockId, (sockaddr*)&addr, sizeof(addr))
	  || 0 != ::listen(sockId, 32) )
	{
		status = STATUS_LOST;
		lastError = ERR_TCP_LISTEN_FAILED;
		this->handler = nullptr;
		return false;
	}
	
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
	thread = new std::thread(&Server::listen, this);
	return true;
}

/// Blocks until the listening thread (if any) finishes, then frees it and
/// clears the handler. This does not itself request a stop; see stop().
void Server::join() {
	if (thread != nullptr) {
		auto *listener = thread;
		listener->join();
		delete listener;
		thread = nullptr;
		handler = nullptr;
	}
}

/// Requests the listening thread to exit, then waits for it via join().
/// Does nothing when no listening thread exists.
void Server::stop() {
	if (thread != nullptr) {
		stopping = true; // polled by listen() between poll() timeouts
		join();
	}
}