// tcp.cpp

91ef56
91ef56
#include <cerrno>
#include <cassert>
#include <cstring>
#include <climits>
#include <cctype>
#include <cstdio>

#include <mutex>
#include <string>
#include <thread>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
9b327d
91ef56
#include "tcp.h"
91ef56
91ef56
91ef56
// Length, in milliseconds, of each individual poll() wait in the I/O and accept loops.
static constexpr int pollTimeoutMs = 100;
// Wait budget, in milliseconds, that Connection::read()/write() may consume
// (in pollTimeoutMs steps) before giving up on a transfer.
static constexpr int connTimeoutMs = 300000;
91ef56
91ef56
91ef56
using namespace Tcp;
91ef56
91ef56
30ecb0
// Parses the whole of [begin, end) as an unsigned decimal integer no greater than maxValue
// (maxValue must be small enough that maxValue * 10 + 9 fits in unsigned long).
// Unlike sscanf("%hu"/"%hhu"), this rejects empty input, signs, whitespace, trailing
// characters and out-of-range values instead of silently accepting or wrapping them.
static bool parseDecimal(const char *begin, const char *end, unsigned long maxValue, unsigned long &value) {
	if (begin == end)
		return false;
	unsigned long result = 0;
	for (const char *c = begin; c != end; ++c) {
		if (*c < '0' || *c > '9')
			return false;
		result = result * 10 + static_cast<unsigned long>(*c - '0');
		if (result > maxValue)
			return false;
	}
	value = result;
	return true;
}

// Parses a strict dotted-quad IPv4 literal "a.b.c.d" (four parts, each 0..255, nothing else).
// Returns false, leaving `octets` possibly partially written, if `text` is not such a literal.
static bool parseIpv4(const char *text, unsigned char octets[4]) {
	const char *cursor = text;
	for (int i = 0; i < 4; ++i) {
		const char *partEnd = cursor;
		while (*partEnd && *partEnd != '.')
			++partEnd;
		unsigned long part = 0;
		if (!parseDecimal(cursor, partEnd, 255, part))
			return false;
		octets[i] = static_cast<unsigned char>(part);
		if (i < 3) {
			if (*partEnd != '.')
				return false;
			cursor = partEnd + 1;
		} else if (*partEnd) {
			return false;   // trailing text after the fourth part
		}
	}
	return true;
}

// Fills this address from a textual description, resetting it first.
// Accepted forms:
//   "host:port" - host is a dotted-quad IPv4 literal or a name resolved to IPv4,
//   "host"      - any string that is not made only of digits (port keeps its default),
//   "port"      - a string made only of digits (ip keeps its default).
// The port must be a plain decimal number in 0..65535.
// Returns ERR_NONE on success, ERR_TCP_BAD_ADDRESS for a null string, a malformed port or an
// unresolvable host (fields parsed successfully before the failure are kept).
ErrorCode Address::resolve(const char *addrString) {
	*this = Address();

	if (!addrString)
		return ERR_TCP_BAD_ADDRESS;

	// Split the input on the first ':'.
	const char *colon = strchr(addrString, ':');
	std::string hostStr;
	const char *host = nullptr;
	const char *port = nullptr;
	if (colon) {
		hostStr.assign(addrString, colon);
		host = hostStr.c_str();
		port = colon + 1;
	} else {
		bool allDigits = true;
		for (const char *c = addrString; *c && allDigits; ++c)
			if (*c < '0' || *c > '9') allDigits = false;
		(allDigits ? port : host) = addrString;
	}

	ErrorCode errorCode = ERR_NONE;

	if (host) {
		unsigned char octets[4];
		if (parseIpv4(host, octets)) {
			for (int i = 0; i < 4; ++i)
				ip[i] = octets[i];
		} else {
			memset(ip, 0, sizeof(ip));
			// getaddrinfo() is used instead of gethostbyname(): the latter returns static
			// storage and is not safe to call from several threads at once.
			addrinfo hints = {};
			hints.ai_family = AF_INET;
			hints.ai_socktype = SOCK_STREAM;
			addrinfo *found = nullptr;
			if (0 == getaddrinfo(host, nullptr, &hints, &found) && found && found->ai_addr) {
				const sockaddr_in *sin = reinterpret_cast<const sockaddr_in*>(found->ai_addr);
				memcpy(ip, &sin->sin_addr, sizeof(ip));
			} else {
				errorCode = ERR_TCP_BAD_ADDRESS;
			}
			if (found)
				freeaddrinfo(found);
		}
	}

	if (port) {
		unsigned long value = 0;
		if (parseDecimal(port, port + strlen(port), 65535, value))
			this->port = static_cast<unsigned short>(value);
		else
			errorCode = ERR_TCP_BAD_ADDRESS;
	}

	return errorCode;
}
91ef56
91ef56
91ef56
// Creates an idle connection: no socket, no error, not stopping.
Connection::Connection()
	: status(STATUS_NONE)
	, lastError(ERR_NONE)
	, sockId()
	, stopping(false)
{
}
91ef56
91ef56
// Wraps an already-connected socket (e.g. one returned by accept()).
// The connection starts in STATUS_OPEN and takes over `sockId`, closing it in close().
Connection::Connection(SockId sockId, const Address &remoteAddr)
	: status(STATUS_OPEN)
	, lastError(ERR_NONE)
	, sockId(sockId)
	, remoteAddr(remoteAddr)
	, stopping(false)
{
}
91ef56
91ef56
// Releases the socket, if this connection still owns an open one.
Connection::~Connection() {
	close();
}
91ef56
91ef56
bool Connection::open(const Address &remoteAddr) {
91ef56
	close();
91ef56
	
91ef56
	this->remoteAddr = remoteAddr;
91ef56
	if (!this->remoteAddr.isValid())
91ef56
		{ close(ERR_TCP_BAD_ADDRESS); return false; }
91ef56
91ef56
	sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
91ef56
	status = STATUS_OPEN;
91ef56
	
91ef56
	struct sockaddr_in addr = {};
91ef56
	addr.sin_family = AF_INET;
91ef56
	memcpy(&addr.sin_addr, remoteAddr.ip, sizeof(remoteAddr.ip));
91ef56
	addr.sin_port = htons(remoteAddr.port);
91ef56
	
91ef56
	if (0 != connect(sockId, (sockaddr*)&addr, sizeof(addr)))
91ef56
		{ close(ERR_TCP_CONNECTION_FAILED); return false; }
91ef56
	
9b327d
	int tcp_nodelay = 1;
91ef56
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
9b327d
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int));
9b327d
	
91ef56
	return true;
91ef56
}
91ef56
91ef56
// Closes the socket if the connection is open and resets the state.
// A non-zero `errorCode` leaves the connection in STATUS_LOST with that error recorded;
// ERR_NONE (zero) leaves it in STATUS_NONE.
void Connection::close(ErrorCode errorCode) {
	// Only an open connection owns a live descriptor.
	if (STATUS_OPEN == status)
		::close(sockId);
	sockId = 0;

	lastError = errorCode;
	if (errorCode)
		status = STATUS_LOST;
	else
		status = STATUS_NONE;

	stopping = false;
}
91ef56
91ef56
bool Connection::read(void *data, size_t size) {
91ef56
	if (!data)
91ef56
		return check() && size == 0;
91ef56
	
91ef56
	struct pollfd pfd = {};
91ef56
	pfd.fd = sockId;
91ef56
	pfd.events = POLLIN;
91ef56
	
6173c6
	int remainMs = connTimeoutMs;
6173c6
	while(check() && size && remainMs >= 0) {
91ef56
		int pr = poll(&pfd, 1, pollTimeoutMs);
91ef56
		if (pr < 0)
91ef56
			{ close(ERR_TCP_CONNECTION_LOST); return false; }
91ef56
		if (pr > 0) {
91ef56
			int s = size > INT_MAX ? INT_MAX : (int)size;
0ea18d
			int r = ::recv(sockId, data, s, MSG_DONTWAIT | MSG_NOSIGNAL);
91ef56
			if (r < 0) {
91ef56
				if (errno != EAGAIN)
91ef56
					{ close(ERR_TCP_CONNECTION_LOST); return false; }
91ef56
			} else {
91ef56
				size -= r;
91ef56
				data = (char*)data + r;
91ef56
			}
91ef56
		}
6173c6
		remainMs -= pollTimeoutMs;
91ef56
	}
91ef56
91ef56
	return size == 0;
91ef56
}
91ef56
91ef56
bool Connection::write(const void *data, size_t size) {
91ef56
	if (!data)
91ef56
		return check() && size == 0;
91ef56
	
91ef56
	struct pollfd pfd = {};
91ef56
	pfd.fd = sockId;
91ef56
	pfd.events = POLLOUT;
91ef56
	
6173c6
	int remainMs = connTimeoutMs;
6173c6
	while(check() && size && remainMs >= 0) {
91ef56
		int pr = poll(&pfd, 1, pollTimeoutMs);
91ef56
		if (pr < 0)
91ef56
			{ close(ERR_TCP_CONNECTION_LOST); return false; }
91ef56
		if (pr > 0) {
91ef56
			int s = size > INT_MAX ? INT_MAX : (int)size;
0ea18d
			int r = ::send(sockId, data, s, MSG_DONTWAIT | MSG_NOSIGNAL);
91ef56
			if (r < 0) {
91ef56
				if (errno != EAGAIN)
91ef56
					{ close(ERR_TCP_CONNECTION_LOST); return false; }
91ef56
			} else {
91ef56
				size -= r;
6173c6
				data = (const char*)data + r;
91ef56
			}
91ef56
		}
6173c6
		remainMs -= pollTimeoutMs;
91ef56
	}
91ef56
91ef56
	return size == 0;
91ef56
}
91ef56
91ef56
91ef56
// Out-of-line definition of the Handler destructor; there is nothing to release here.
Handler::~Handler()
{
}
91ef56
91ef56
91ef56
// Creates an idle server: no socket, no handler and no listening thread yet.
Server::Server()
	: status(STATUS_NONE)
	, lastError(ERR_NONE)
	, sockId()
	, localAddr()
	, handler()
	, thread()
	, stopping(false)
{
}
91ef56
91ef56
// Stops the listening thread (if running) and waits for it before the object goes away.
Server::~Server() {
	stop();
}
91ef56
0ea18d
// Entry point of a per-client worker thread (started by Server::listen()).
// Runs the user handler on the connection, then, under connectionsMutex, moves this
// worker's descriptor from `connections` to `connectionsFinished` so that the listening
// thread can join the worker and free its thread and Connection objects.
void Server::handleConnection(ConnList::iterator iter) {
	ConnDesc &desc = *iter;
	handler->handleTcpConnection(*desc.connection);

	std::lock_guard<std::mutex> guard(connectionsMutex);
	connectionsFinished.push_back(desc);
	connections.erase(iter);
}
0ea18d
0ea18d
void Server::cleanConnections() {
0ea18d
	while(!connectionsFinished.empty()) {
0ea18d
		ConnDesc &desc = connectionsFinished.back();
0ea18d
		desc.thread->join();
0ea18d
		delete desc.thread;
0ea18d
		delete desc.connection;
0ea18d
		connectionsFinished.pop_back();
0ea18d
	}
0ea18d
}
91ef56
91ef56
void Server::listen() {
91ef56
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
91ef56
91ef56
	struct sockaddr_in addr = {};
91ef56
	Address remoteAddr;
91ef56
	
91ef56
	struct pollfd pfd = {};
91ef56
	pfd.fd = sockId;
91ef56
	pfd.events = POLLIN;
91ef56
	
91ef56
	status = STATUS_OPEN;
91ef56
	while(!stopping) {
91ef56
		int pr = poll(&pfd, 1, pollTimeoutMs);
91ef56
		if (pr < 0)
91ef56
			{ lastError = ERR_TCP_LISTEN_LOST; break; }
0ea18d
			
91ef56
		if (pr > 0 && (pr & POLLIN)) {
91ef56
			socklen_t addrlen = sizeof(addr);
91ef56
			int sid = ::accept(sockId, (sockaddr*)&addr, &addrlen);
91ef56
			if (sid < 0) {
91ef56
				if (errno != EAGAIN)
91ef56
					{ lastError = ERR_TCP_LISTEN_LOST; break; }
91ef56
				continue;
91ef56
			}
91ef56
			
91ef56
			memcpy(remoteAddr.ip, &addr.sin_addr, sizeof(remoteAddr.ip));
91ef56
			remoteAddr.port = ntohs(addr.sin_port);
91ef56
			
0ea18d
			ConnList::iterator iter;
0ea18d
			{
0ea18d
				std::lock_guard<std::mutex> lock(connectionsMutex);</std::mutex>
0ea18d
				cleanConnections();
0ea18d
				iter = connections.emplace(connections.end());
0ea18d
			}
0ea18d
			iter->connection = new Connection(sid, remoteAddr);
0ea18d
			iter->thread = new std::thread(&Server::handleConnection, this, iter);
0ea18d
		} else {
0ea18d
			std::lock_guard<std::mutex> lock(connectionsMutex);</std::mutex>
0ea18d
			cleanConnections();
91ef56
		}
91ef56
	}
91ef56
	status = lastError ? STATUS_LOST : STATUS_CLOSED;
91ef56
	
0ea18d
	connectionsMutex.lock();
91ef56
	for(ConnList::iterator i = connections.begin(); i != connections.end(); ++i)
91ef56
		i->connection->stop();
0ea18d
	while(!connections.empty()) {
0ea18d
		std::thread *thread = connections.front().thread;
0ea18d
		connectionsMutex.unlock();
0ea18d
		thread->join();
0ea18d
		connectionsMutex.lock();
91ef56
	}
0ea18d
	cleanConnections();
0ea18d
	connectionsMutex.unlock();
91ef56
	
91ef56
	stopping = false;
91ef56
}
91ef56
91ef56
bool Server::start(const Address &localAddr, Handler *handler) {
91ef56
	stop();
91ef56
	
91ef56
	this->localAddr = localAddr;
91ef56
	this->handler = handler;
91ef56
	
30ecb0
	if (!this->localAddr.isValidPort())
91ef56
		{ status = STATUS_LOST; lastError = ERR_TCP_BAD_ADDRESS; return false; }
91ef56
	if (!this->handler)
91ef56
		{ status = STATUS_LOST; lastError = ERR_TCP_BAD_HANDLER; return false; }
91ef56
	
91ef56
	sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
91ef56
	
91ef56
	struct sockaddr_in addr = {};
91ef56
	addr.sin_family = AF_INET;
91ef56
	memcpy(&addr.sin_addr, localAddr.ip, sizeof(localAddr.ip));
91ef56
	addr.sin_port = htons(localAddr.port);
91ef56
	
91ef56
	if ( 0 != ::bind(sockId, (sockaddr*)&addr, sizeof(addr))
91ef56
	  || 0 != ::listen(sockId, 32) )
91ef56
	{
91ef56
		status = STATUS_LOST;
91ef56
		lastError = ERR_TCP_LISTEN_FAILED;
91ef56
		this->handler = nullptr;
91ef56
		return false;
91ef56
	}
91ef56
	
9b327d
	int tcp_nodelay = 1;
91ef56
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
9b327d
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int));
9b327d
	
91ef56
	thread = new std::thread(&Server::listen, this);
91ef56
	return true;
91ef56
}
91ef56
91ef56
void Server::join() {
91ef56
	if (!thread)
91ef56
		return;
91ef56
	thread->join();
91ef56
	delete thread;
91ef56
	thread = nullptr;
91ef56
	handler = nullptr;
91ef56
}
91ef56
91ef56
void Server::stop() {
91ef56
	if (!thread)
91ef56
		return;
91ef56
	stopping = true;
91ef56
	join();
91ef56
}
91ef56