// Implementation of Protocol (declared in protocol.h).

#include <cerrno>
#include <climits>
#include <cstring>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>

#include "protocol.h"


// Debug-trace routing for this file: DebugMSG expands to HideDebugMSG and
// SrvDebugMSG expands to ShowDebugMSG.
// NOTE(review): both targets are presumably defined in protocol.h and,
// judging by their names, silence resp. print the message - confirm.
#define DebugMSG    HideDebugMSG
#define SrvDebugMSG ShowDebugMSG



// Creates an idle protocol object: no worker thread, no epoll/event
// descriptors (-1), both state fields at STATE_NONE, zero waiters and
// empty connection/server lists.
Protocol::Protocol()
	: thread()
	, epollFd(-1)
	, eventFd(-1)
	, state(STATE_NONE)
	, stateWanted(STATE_NONE)
	, closeWaiters(0)
	, finishWaiters(0)
	, connFirst()
	, connLast()
	, srvFirst()
	, srvLast()
{
}


// Destroys the protocol object.
// The object must be idle (STATE_NONE) or completely shut down
// (STATE_FINISHED); this is asserted. The destructor then waits until no
// caller is left inside closeWait(), joins the worker thread and releases
// the thread object that open() allocated with `new`.
Protocol::~Protocol() {
	assert(state == STATE_NONE || state == STATE_FINISHED);
	closeWait();
	// callers of closeWait() still read members of this object until they
	// have decremented finishWaiters
	while(finishWaiters > 0) std::this_thread::yield();
	DebugMSG();
	if (thread) {
		thread->join();
		// open() creates the thread with `new`; without this delete the
		// std::thread object would leak
		delete thread;
		thread = nullptr;
	}
	DebugMSG("deleted");
}


bool Protocol::open() {
	DebugMSG();
	State s = state;
	do {
		if (s != STATE_NONE && s != STATE_FINISHED) return false;
	} while(!state.compare_exchange_weak(s, STATE_OPENING));
	if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
	if (thread) thread->join();
	
	thread = new std::thread(&Protocol::threadRun, this);
	state = STATE_INITIALIZING;
	DebugMSG();
	return true;
}


// Records a request to advance the protocol to `state` and wakes the worker.
// Nothing happens when the protocol is not running (state at or below
// STATE_OPENING, or at/after STATE_CLOSING), when it already reached
// `state`, or when an equal or later state has been requested before.
void Protocol::wantState(State state) {
	ReadLock lock(stateProtector);
	const State current = this->state;
	DebugMSG("%d, %d", current, state);
	const bool notRunning = current <= STATE_OPENING || current >= STATE_CLOSING;
	if (notRunning || current >= state) return;
	DebugMSG("apply");
	
	// raise stateWanted to `state`; a failed exchange reloads `wanted`
	for(State wanted = stateWanted; wanted < state; )
		if (stateWanted.compare_exchange_weak(wanted, state))
			{ wakeup(); return; }
}


// Requests the STATE_CLOSE_REQ transition (see wantState).
void Protocol::closeReq() {
	DebugMSG();
	wantState(STATE_CLOSE_REQ);
}
// Requests the STATE_CLOSING transition (see wantState).
void Protocol::close() {
	DebugMSG();
	wantState(STATE_CLOSING);
}


// Blocks the caller until the protocol has reached STATE_FINISHED.
//   timeoutUs - when non-zero (and the protocol is not closing yet) the
//               caller waits at most this many microseconds before it
//               escalates with close()
//   withReq   - when true and no close request was made yet, closeReq() is
//               issued first
// Returns immediately when the state is at or below STATE_OPENING or is
// already STATE_FINISHED.
void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) {
	DebugMSG();
	stateProtector.lock();
	State s = state;
	if (s <= STATE_OPENING || s >= STATE_FINISHED)
		{ stateProtector.unlock(); return; }
	
	
	if (s == STATE_CLOSED) {
		// already closed: only the final FINISHED transition is awaited
		finishWaiters++;
		stateProtector.unlock();
	} else {
		// NOTE(review): closeReq() ends up in wantState(), which takes a
		// ReadLock on stateProtector while it is locked here - confirm the
		// lock type allows that.
		if (withReq && s < STATE_CLOSE_REQ) closeReq();
		closeWaiters++;
		stateProtector.unlock();
		
		// The mutex is local to this call and only exists to satisfy the
		// condition variable interface. A notification sent between the
		// state check and the wait is not lost for good because the worker
		// repeats notify_all() for as long as closeWaiters is non-zero.
		std::mutex fakeMutex;
		{
			std::unique_lock<std::mutex> fakeLock(fakeMutex);
			if (timeoutUs && s < STATE_CLOSING) {
				// phase 1: give the protocol until the deadline, then
				// request the close ourselves
				std::chrono::steady_clock::time_point t
					= std::chrono::steady_clock::now()
						+ std::chrono::microseconds(timeoutUs);
				while(s < STATE_CLOSED) {
					if (t <= std::chrono::steady_clock::now()) { close(); break; }
					DebugMSG("wait_until, %d", s);
					closeCondition.wait_until(fakeLock, t);
					s = state;
				}
			}
			// phase 2: make sure a close was requested, then wait for it
			if (s < STATE_CLOSING) close();
			while(s < STATE_CLOSED) {
				DebugMSG("wait, %d", s);
				closeCondition.wait(fakeLock);
				s = state;
			}
			DebugMSG("awaiten");
		}
		
		// register as finish waiter before leaving the close waiters, so
		// this caller is counted in at least one of them at any time
		finishWaiters++;
		closeWaiters--;
	}
	
	
	// spin until the worker publishes the final state
	while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
	
	finishWaiters--;
	DebugMSG();
}


// Wakes the worker thread out of epoll_wait() by signalling the eventfd.
void Protocol::wakeup() {
	// A write to an eventfd ADDS the value to its counter, and the
	// descriptor only becomes readable while the counter is non-zero.
	// Writing 0 therefore wakes nobody; the increment has to be non-zero.
	unsigned long long num = 1;
	// Best effort: a failure means either that the counter is already
	// saturated (the worker is signalled anyway) or that the descriptor is
	// not open, in which case there is no worker to wake.
	ssize_t res = ::write(eventFd, &num, sizeof(num));
	(void)res;
}


// Synchronises the epoll interest mask of `connection` with its queues:
// EPOLLIN is wanted while the read queue is non-empty, EPOLLOUT while the
// write queue is non-empty. The mask is cached in connection.events and
// epoll is only touched when the mask actually changed.
void Protocol::updateEvents(Connection &connection) {
	unsigned int wanted = 0;
	if (connection.readQueue.count()) wanted |= EPOLLIN;
	if (connection.writeQueue.count()) wanted |= EPOLLOUT;
	
	if (connection.events.exchange(wanted) != wanted) {
		struct epoll_event ev = {};
		ev.data.ptr = &connection;
		ev.events = wanted;
		epoll_ctl(epollFd, EPOLL_CTL_MOD, connection.sockId, &ev);
	}
}


// Prepares a socket for use by the worker: enables TCP_NODELAY and switches
// the descriptor to non-blocking mode. Failures are ignored (best effort).
void Protocol::initSocket(int sockId) {
	int tcp_nodelay = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(tcp_nodelay));
	// F_GETFL returns -1 on failure; OR-ing that into the new flags would
	// hand an all-ones flag word to F_SETFL
	const int flags = fcntl(sockId, F_GETFL, 0);
	if (flags != -1)
		fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
}


void Protocol::threadConnQueue() {
	while(connQueue.peek()) {
		Connection &connection = connQueue.pop();
		connection.enqueued = false;
		
		if (connection.stateLocal == Connection::STATE_NONE) {
			int sockId = connection.sockId;
			bool connected = false;
			
			if (sockId >= 0) {
				initSocket(sockId);
			} else
			if (connection.addressSize >= sizeof(sa_family_t)) {
				struct sockaddr *addr = (struct sockaddr*)connection.address;
				sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
				if (sockId >= 0) {
					initSocket(sockId);
					int res = ::connect(sockId, addr, (int)connection.addressSize);
					if (res == 0) {
						connected = true;
					} else
					if (res != EINPROGRESS) {
						::close(sockId);
						sockId = -1;
					}
				}
			}
			
			if (sockId >= 0) {
				connection.protocol = this;
				connection.prev = connLast;
				assert(!connection.next);
				(connection.prev ? connection.prev->next : connFirst) = &connection;
				if (connection.server) {
					connection.srvPrev = connection.server->connLast;
					assert(!connection.srvNext);
					(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = &connection;
				}
				connection.sockId = sockId;
				
				struct epoll_event event = {};
				event.data.ptr = &connection;
				if (connected) {
					epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
					connection.onOpen(connection.address, connection.addressSize);
					connection.state = connection.stateLocal = Connection::STATE_OPEN;
					connection.stateProtector.wait();
				} else {
					event.events = EPOLLOUT;
					epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
					connection.state = connection.stateLocal = Connection::STATE_CONNECTING;
					connection.stateProtector.wait();
				}
			} else {
				connection.sockId = -1;
				connection.server = nullptr;
				connection.onOpeningError();
				connection.state = Connection::STATE_CLOSED;
				while(connection.closeWaiters > 0)
					{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
				connection.state = Connection::STATE_FINISHED;
			}
		} else {
			Connection::State stateWanted = connection.stateWanted;
			
			if ( connection.stateLocal >= Connection::STATE_CONNECTING
			  && connection.stateLocal <= Connection::STATE_OPEN
			  && stateWanted == Connection::STATE_CLOSE_REQ )
			{
				connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ;
				connection.stateProtector.wait();
				connection.onCloseReqested();
			} else
			if ( connection.stateLocal >= Connection::STATE_CONNECTING
			  && connection.stateLocal <= Connection::STATE_CLOSE_REQ
			  && stateWanted == Connection::STATE_CLOSING )
			{
				connection.state = connection.stateLocal = Connection::STATE_CLOSING;
				connection.stateProtector.wait();
				while(Connection::Task *task = connection.readQueue.peek()) {
					connection.readQueue.pop();
					if (!connection.onReadReady(*task, true))
						assert(false);
				}
				while(Connection::Task *task = connection.writeQueue.peek()) {
					connection.writeQueue.pop();
					if (!connection.onWriteReady(*task, true))
						assert(false);
				}
				connection.onClose(false);
			}
			
			if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) {
				if (connection.server) {
					(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
					(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
					connection.server = nullptr;
				}
				(connection.prev ? connection.prev->next : connFirst) = connection.next;
				(connection.next ? connection.next->prev : connLast) = connection.prev;
				connection.sockId = -1;
				connection.protocol = nullptr;
				memset(connection.address, 0, sizeof(connection.address));
				connection.addressSize = 0;
				connection.stateWanted = Connection::STATE_NONE;
				connection.stateLocal = Connection::STATE_NONE;
				connection.state = Connection::STATE_CLOSED;
				while(connection.closeWaiters > 0)
					{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
				connection.state = Connection::STATE_FINISHED;
			}
		}
	}
}


// Worker-thread handler for epoll events of one connection.
//   connection - the connection the events belong to
//   events     - epoll event mask reported for its socket
// A connecting socket is either promoted to STATE_OPEN (EPOLLOUT) or torn
// down (EPOLLHUP/EPOLLERR). For an open connection the pending read and
// write tasks are served until the socket would block; finished tasks are
// reported through onReadReady()/onWriteReady().
void Protocol::threadConnEvents(Connection &connection, unsigned int events) {
	bool needUpdate = false;
	
	if (connection.stateLocal == Connection::STATE_CONNECTING) {
		if (events & (EPOLLHUP | EPOLLERR)) {
			// connecting failed
			connection.onOpeningError();
			if (connection.server) {
				(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
				(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
				connection.server = nullptr;
			}
			(connection.prev ? connection.prev->next : connFirst) = connection.next;
			(connection.next ? connection.next->prev : connLast) = connection.prev;
			// leave no stale links behind, the connection may be opened again
			connection.srvPrev = nullptr;
			connection.srvNext = nullptr;
			connection.prev = nullptr;
			connection.next = nullptr;
			// closing the descriptor also removes it from the epoll set
			if (connection.sockId >= 0)
				::close(connection.sockId);
			connection.sockId = -1;
			connection.protocol = nullptr;
			memset(connection.address, 0, sizeof(connection.address));
			connection.addressSize = 0;
			connection.stateWanted = Connection::STATE_NONE;
			connection.stateLocal = Connection::STATE_NONE;
			connection.state = Connection::STATE_CLOSED;
			while(connection.closeWaiters > 0)
				{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
			connection.state = Connection::STATE_FINISHED;
			// The connection is finished. A failed connect is usually reported
			// together with EPOLLOUT, which must not reopen it below.
			return;
		}
		
		if (events & EPOLLOUT) {
			connection.onOpen(connection.address, connection.addressSize);
			connection.state = connection.stateLocal = Connection::STATE_OPEN;
			connection.stateProtector.wait();
			// replace the EPOLLOUT registration used while connecting
			needUpdate = true;
		}
	}

	if ( connection.stateLocal < Connection::STATE_OPEN
	  || connection.stateLocal >= Connection::STATE_CLOSING )
		return;
	
	if (events & EPOLLIN) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.readQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0; // bytes received for this task in this round
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				ssize_t res = ::recv(
					connection.sockId,
					(unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT );
				if (res > 0) {
					// recv() may deliver less than requested:
					// count what was really received
					completion += (size_t)res;
					task->completion += (size_t)res;
				} else
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					// nothing more to read right now
					ready = false;
					break;
				} else
				if (res < 0 && errno == EINTR) {
					continue; // interrupted by a signal, try again
				} else {
					// peer closed the connection (0) or a real error
					connection.close(true);
					return;
				}
			}
			
			if (task->size <= task->completion || (completion && task->watch)) {
				connection.readQueue.pop();
				if (!connection.onReadReady(*task, false))
					connection.readQueue.unpop(*task); else ++pops;
			}
		}
	}
	
	if (events & EPOLLHUP) {
		connection.close((bool)(events & EPOLLERR));
		return;
	}
	
	if (events & EPOLLOUT) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.writeQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0; // bytes sent for this task in this round
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				ssize_t res = ::send(
					connection.sockId,
					(const unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT );
				if (res > 0) {
					// send() may accept less than requested:
					// count what was really sent
					completion += (size_t)res;
					task->completion += (size_t)res;
				} else
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					// the socket buffer is full
					ready = false;
					break;
				} else
				if (res < 0 && errno == EINTR) {
					continue; // interrupted by a signal, try again
				} else {
					connection.close(true);
					return;
				}
			}
			
			if (task->size <= task->completion || (completion && task->watch)) {
				connection.writeQueue.pop();
				if (!connection.onWriteReady(*task, false))
					connection.writeQueue.unpop(*task); else ++pops;
			}
		}
	}
	
	if (needUpdate) updateEvents(connection);
}


// Worker-thread handler of the server request queue.
// Every queued server is either
//   * opened (stateLocal == STATE_NONE): a listening socket is created,
//     bound, registered with epoll and the server is linked into the
//     protocol list; or
//   * advanced towards closing according to server.stateWanted, and finally
//     unlinked, closed and reported as CLOSED -> FINISHED.
void Protocol::threadSrvQueue() {
	while(srvQueue.peek()) {
		SrvDebugMSG();
		Server &server = srvQueue.pop();
		server.enqueued = false;
		
		if (server.stateLocal == Server::STATE_NONE) {
			SrvDebugMSG("none");
			int sockId = -1;
			if (server.addressSize >= sizeof(sa_family_t)) {
				struct sockaddr *addr = (struct sockaddr*)server.address;
				sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
				if (sockId >= 0) {
					initSocket(sockId);
					if ( 0 != ::bind(sockId, addr, (int)server.addressSize)
					  || 0 != ::listen(sockId, 128) )
					{
						SrvDebugMSG("cannot bind/listen");
						::close(sockId);
						sockId = -1;
					}
				} else {
					SrvDebugMSG("cannot create socket");
				}
			}
			
			if (sockId >= 0) {
				SrvDebugMSG("to open");
				// append to the protocol-wide list and move the tail pointer
				server.protocol = this;
				server.prev = srvLast;
				server.next = nullptr;
				(server.prev ? server.prev->next : srvFirst) = &server;
				srvLast = &server;
				server.sockId = sockId;
				
				struct epoll_event event = {};
				event.data.ptr = &server;
				event.events = EPOLLIN;
				epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
				
				server.onOpen();
				server.state = server.stateLocal = Server::STATE_OPEN;
				server.stateProtector.wait();
			} else {
				SrvDebugMSG("opening error -> to closed");
				server.onOpeningError();
				server.state = Server::STATE_CLOSED;
				// repeat the notification until every waiter has seen it
				while(server.closeWaiters > 0)
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
				server.state = Server::STATE_FINISHED;
				SrvDebugMSG("opening error -> finished");
			}
		} else {
			Server::State stateWanted = server.stateWanted;
			if ( server.stateLocal == Server::STATE_OPEN
			  && stateWanted == Server::STATE_CLOSE_REQ )
			{
				SrvDebugMSG("to close req");
				for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
					conn->closeReq();
				// stop accepting new connections
				struct epoll_event event = {};
				epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
				server.state = server.stateLocal = Server::STATE_CLOSE_REQ;
				server.stateProtector.wait();
				server.onCloseReqested();
			} else
			if ( ( server.stateLocal == Server::STATE_OPEN
				|| server.stateLocal == Server::STATE_CLOSE_REQ )
			  && stateWanted == Server::STATE_CLOSING_CONNECTIONS )
			{
				SrvDebugMSG("to closing connections");
				if (server.stateLocal != Server::STATE_CLOSE_REQ) {
					// not removed from epoll by a close request yet
					struct epoll_event event = {};
					epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
				}
				if (server.connFirst) {
					for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
						conn->close(true);
					// NOTE(review): no code visible in this file moves the server
					// from STATE_CLOSING_CONNECTIONS to STATE_CLOSING after its
					// connections are gone - confirm that it is done elsewhere.
					server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS;
				} else {
					SrvDebugMSG("to closing");
					server.state = server.stateLocal = Server::STATE_CLOSING;
					server.stateProtector.wait();
					server.onClose(false);
				}
			}
			
			// final teardown, skipped when a callback re-enqueued the server
			if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) {
				SrvDebugMSG("to closed");
				server.state = Server::STATE_CLOSED;
				(server.prev ? server.prev->next : srvFirst) = server.next;
				(server.next ? server.next->prev : srvLast) = server.prev;
				// leave no stale links behind, the server may be opened again
				server.prev = nullptr;
				server.next = nullptr;
				// release the listening socket, otherwise the descriptor leaks
				// and the address stays bound
				if (server.sockId >= 0)
					::close(server.sockId);
				server.sockId = -1;
				server.protocol = nullptr;
				memset(server.address, 0, sizeof(server.address));
				server.addressSize = 0;
				server.stateWanted = Server::STATE_NONE;
				server.stateLocal = Server::STATE_NONE;
				SrvDebugMSG("closeWaiters");
				while(server.closeWaiters > 0)
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
				server.state = Server::STATE_FINISHED;
				SrvDebugMSG("finished");
			}
		}
	}
}


// Worker-thread handler for epoll events of a listening socket.
//   server - the server the events belong to
//   events - epoll event mask reported for its socket
// EPOLLHUP closes a server that is open; EPOLLIN accepts one pending
// connection and offers it to the server through onConnect().
void Protocol::threadSrvEvents(Server &server, unsigned int events) {
	if (events & EPOLLHUP) {
		if ( server.stateLocal >= Server::STATE_OPEN
		  && server.stateLocal < Server::STATE_CLOSING )
			server.close((bool)(events & EPOLLERR));
		return;
	}
	
	if (events & EPOLLIN) {
		if ( server.stateLocal >= Server::STATE_OPEN
		  && server.stateLocal < Server::STATE_CLOSE_REQ)
		{
			unsigned char address[MAX_ADDR_SIZE] = {};
			socklen_t len = sizeof(address);
			int res = ::accept(server.sockId, (struct sockaddr*)address, &len);
			if (res >= 0) {
				// `len` is only meaningful after a successful accept(); a value
				// above the buffer size would mean the address was truncated
				assert(len <= sizeof(address));
				Connection *conn = server.onConnect(address, len);
				if (!conn) {
					// nobody took the connection: do not leak the descriptor
					::close(res);
				} else
				if (!conn->open(*this, address, len, res)) {
					server.onDisconnect(conn, true);
					// NOTE(review): assumes that a failed open() did not take
					// ownership of the descriptor - confirm.
					::close(res);
				}
			} else
			// accept() reports failure as -1 and puts the reason into errno;
			// these reasons are transient and only mean "nothing to accept now"
			if ( errno != EAGAIN && errno != EWOULDBLOCK
			  && errno != EINTR && errno != ECONNABORTED )
			{
				server.close(true);
				return;
			}
		}
	}
}


// Body of the worker thread.
// Creates the epoll instance and the wake-up eventfd, publishes STATE_OPEN
// and then loops: dispatch the reported events, apply a requested state
// change, drain the connection/server request queues, wait for new events.
// The loop ends once STATE_CLOSING has been reached and the queues are
// empty; the thread then releases the waiters of closeWait(), closes its
// descriptors and publishes STATE_FINISHED.
void Protocol::threadRun() {
	DebugMSG();
	epollFd = ::epoll_create(32);
	eventFd = ::eventfd(0, EFD_NONBLOCK);
	
	// the eventfd is registered with a null data pointer, which is how its
	// events are told apart from connection and server events below
	struct epoll_event event = {};
	event.events = EPOLLIN;
	epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event);
	
	const int maxCount = 16;
	struct epoll_event events[maxCount] = {};
	int count = 0;
	
	State stateLocal = STATE_OPEN;
	state = stateLocal;
	
	while(true) {
		for(int i = 0; i < count; ++i) {
			Root *pointer = (Root*)events[i].data.ptr;
			if (!pointer) {
				// wake-up signal: drain the eventfd counter
				unsigned long long num = 0;
				while(::read(eventFd, &num, sizeof(num)) > 0 && !num);
			} else
			if (Connection *connection = dynamic_cast<Connection*>(pointer)) {
				threadConnEvents(*connection, events[i].events);
			} else
			if (Server *server = dynamic_cast<Server*>(pointer)) {
				threadSrvEvents(*server, events[i].events);
			}
		}
		
		State sw = stateWanted;
		if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) {
			DebugMSG("to closse req");
			state = stateLocal = STATE_CLOSE_REQ;
			stateProtector.wait();
			// connections owned by a server are asked by their server
			for(Server *server = srvFirst; server; server = server->next)
				server->closeReq();
			for(Connection *conn = connFirst; conn; conn = conn->next)
				if (!conn->server)
					conn->closeReq();
		} else
		if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) {
			DebugMSG("to closing");
			state = stateLocal = STATE_CLOSING;
			stateProtector.wait();
			for(Server *server = srvFirst; server; server = server->next)
				server->close(true);
			for(Connection *conn = connFirst; conn; conn = conn->next)
				if (!conn->server)
					conn->close(true);
		}
		
		// handling one queue may enqueue into the other, repeat until both are empty
		while(connQueue.peek() || srvQueue.peek()) {
			threadConnQueue();
			threadSrvQueue();
		}
		
		if (stateLocal == STATE_CLOSING) {
			assert(!connFirst && !srvFirst);
			break;
		}
		
		DebugMSG("epoll_wait: %d, %d", stateLocal, sw);
		// a negative result (error) simply skips the dispatch loop above
		count = epoll_wait(epollFd, events, maxCount, 1000);
	}
	
	DebugMSG("closeWaiters");
	// Publish CLOSED - not FINISHED - while this thread still cleans up.
	// closeWait() callers leave their wait at CLOSED but keep spinning until
	// FINISHED, and open() cannot claim the object before the cleanup below
	// is done and `state` is written for the last time.
	state = STATE_CLOSED;
	// repeat the notification until every waiter has seen it
	while(closeWaiters > 0)
		{ closeCondition.notify_all(); std::this_thread::yield(); }
	
	::close(epollFd);
	::close(eventFd);
	epollFd = -1;
	eventFd = -1;
	assert(!connFirst && !srvFirst);
	assert(!connQueue.peek() && !srvQueue.peek());
	
	stateWanted = STATE_NONE;
	state = STATE_FINISHED;
	DebugMSG("finished");
}