Blob Blame Raw

#include <cstring>
#include <climits>

#include <chrono>
#include <thread>
#include <mutex>

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <netinet/in.h>

#include "protocol.h"
#include "connection.h"


#define DebugMSG ShowDebugMSG



Connection::Connection():
	server(), srvPrev(), srvNext() { }


// Destructor: delegates all teardown to destroy() (defined elsewhere).
Connection::~Connection() {
	destroy();
}


// Requests that this connection be opened on `protocol`.
//
// OpenLock (defined elsewhere) performs the open bookkeeping for
// `address`/`addressSize`. Only if it reports success are `sockId` and
// `server` recorded. handleState() later treats sockId >= 0 as an
// already-accepted socket, and a non-null server as the owner whose
// connection list this connection joins.
//
// Returns false, leaving sockId/server untouched, when OpenLock fails.
bool Connection::open(Protocol &protocol, void *address, size_t addressSize, int sockId, Server *server) {
	OpenLock guard(*this, protocol, address, addressSize);
	if (guard.success) {
		this->sockId = sockId;
		this->server = server;
	}
	return guard.success;
}


// Client-side open: there is no pre-accepted socket and no owning server.
// handleState() will therefore create a socket and connect it to `address`.
bool Connection::open(Protocol &protocol, void *address, size_t addressSize) {
	const int noSocket = -1;
	Server *const noServer = nullptr;
	return open(protocol, address, addressSize, noSocket, noServer);
}


// Queues `task` to be filled from the socket, then refreshes the epoll
// interest set from the current queue sizes.
//
// The task is only accepted while the state lies in [STATE_OPEN, STATE_CLOSING).
// Otherwise nothing is queued and false is returned. The state is sampled
// under a ReadLock on stateProtector.
bool Connection::read(Task &task) {
	DebugMSG();
	ReadLock guard(stateProtector);
	const State current = state;
	const bool accepting = current >= STATE_OPEN && current < STATE_CLOSING;
	if (!accepting)
		return false;
	DebugMSG();
	readQueue.push(task);
	updateEvents(readQueue.count(), writeQueue.count());
	return true;
}


// Queues `task` to be sent over the socket, then refreshes the epoll
// interest set from the current queue sizes.
//
// The task is only accepted while the state lies in [STATE_OPEN, STATE_CLOSING).
// Otherwise nothing is queued and false is returned. The state is sampled
// under a ReadLock on stateProtector.
bool Connection::write(Task &task) {
	DebugMSG();
	ReadLock guard(stateProtector);
	const State current = state;
	const bool accepting = current >= STATE_OPEN && current < STATE_CLOSING;
	if (!accepting)
		return false;
	DebugMSG();
	writeQueue.push(task);
	updateEvents(readQueue.count(), writeQueue.count());
	return true;
}


void Connection::handleState() {
	if (stateLocal == STATE_NONE) {
		DebugMSG("none");
		bool connected = false;
		
		if (sockId >= 0) {
			DebugMSG("accepting");
			initSocket(sockId);
			connected = true;
		} else
		if (addressSize >= sizeof(sa_family_t)) {
			DebugMSG("connecting");
			struct sockaddr *addr = (struct sockaddr*)address;
			sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
			if (sockId >= 0) {
				initSocket(sockId);
				if (0 == ::connect(sockId, addr, (int)addressSize))
					connected = true;
			} else {
				DebugMSG("none -> cannot create socket");
			}
		}

		if (sockId >= 0) {
			DebugMSG("opening");
			prev = protocol->sockLast;
			(prev ? prev->next : protocol->sockFirst) = this;
			if (server) {
				srvPrev = server->connLast;
				(srvPrev ? srvPrev->srvNext : server->connFirst) = this;
			}
			updateEvents(false, !connected);
			
			if (connected) {
				DebugMSG("to open");
				state = stateLocal = STATE_OPEN;
				stateProtector.wait();
				onOpen(address, addressSize);
			} else {
				assert(!server);
				DebugMSG("to connecting");
				state = stateLocal = STATE_CONNECTING;
				stateProtector.wait();
			}
		} else {
			DebugMSG("none -> to closed");
			Server *srv = server;
			assert(connected || !server);
			server = nullptr;
			sockId = -1;
			onOpeningError();
			state = STATE_CLOSED;
			stateProtector.wait();
			while(closeWaiters > 0)
				{ closeCondition.notify_all(); std::this_thread::yield(); }
			state = STATE_FINISHED;
			DebugMSG("none -> to finished");
			if (srv) server->handleDisconnect(*this, true);
		}
	} else {
		State sw = stateWanted;
		
		if ( stateLocal == STATE_CONNECTING
		  && sw == STATE_CLOSING )
		{
			DebugMSG("connecting -> to closing");
			onOpeningError();
			state = stateLocal = STATE_CLOSING;
			stateProtector.wait();
		}
		
		if ( stateLocal == STATE_OPEN
		  && sw == STATE_CLOSE_REQ )
		{
			DebugMSG("to close req");
			state = stateLocal = STATE_CLOSE_REQ;
			stateProtector.wait();
			onCloseReqested();
		}
		
		if ( stateLocal >= STATE_OPEN
		  && stateLocal <= STATE_CLOSE_REQ
		  && sw == STATE_CLOSING )
		{
			DebugMSG("to closing");
			state = stateLocal = STATE_CLOSING;
			stateProtector.wait();
			while(Task *task = readQueue.peek())
				{ readQueue.pop(); onReadReady(*task, true); }
			while(Task *task = writeQueue.peek())
				{ writeQueue.pop(); onWriteReady(*task, true); }
			onClose(error);
		}
		
		if ( stateLocal == STATE_CLOSING && !enqueued ) {
			DebugMSG("to closed");
			Server *srv = server;
			bool err = error;
			
			if (server) {
				(srvPrev ? srvPrev->srvNext : server->connFirst) = srvNext;
				(srvNext ? srvNext->srvPrev : server->connLast ) = srvPrev;
				server = nullptr;
			}
			error = false;
			
			(prev ? prev->next : protocol->sockFirst) = next;
			(next ? next->prev : protocol->sockLast ) = prev;
			if (sockId >= 0) {
				struct epoll_event event = {};
				epoll_ctl(protocol->epollFd, EPOLL_CTL_DEL, sockId, &event);
				::close(sockId);
				sockId = -1;
			}
			events = 0xffffffff;
			protocol = nullptr;
			
			memset(address, 0, sizeof(address));
			addressSize = 0;
			stateLocal = STATE_NONE;
			state = STATE_CLOSED;
			stateWanted = STATE_NONE;
			while(closeWaiters > 0)
				{ closeCondition.notify_all(); std::this_thread::yield(); }
			state = STATE_FINISHED;
			DebugMSG("to finished");
			if (srv) srv->handleDisconnect(*this, err);
		}
	}
}


// Handles epoll readiness `events` reported for this connection's socket.
//
// While connecting:
//  - EPOLLHUP/EPOLLERR fail the connection via close(true).
//  - EPOLLOUT completes it: STATE_OPEN, then onOpen().
//
// While open (STATE_OPEN up to, but not including, STATE_CLOSING):
//  - EPOLLIN fills queued read tasks using non-blocking recv().
//  - EPOLLHUP/EPOLLRDHUP/EPOLLERR then close the connection; the error
//    flag is set only for EPOLLERR.
//  - Otherwise EPOLLOUT drains queued write tasks using non-blocking send().
//
// Each loop completes at most a bounded number of tasks per call. The one
// exception is reads while closing, which drain everything available.
// NOTE(review): the bound is presumably for fairness between connections.
// If any task was completed, the epoll interest set is refreshed at the end.
void Connection::handleEvents(unsigned int events) {
	bool needUpdate = false;
	DebugMSG("%08x", events);
	
	if (stateLocal == STATE_CONNECTING) {
		DebugMSG("connecting");
		if (events & (EPOLLHUP | EPOLLERR)) {
			close(true);
			return;
		}
		
		if (events & EPOLLOUT) {
			DebugMSG("connecting -> to open");
			state = stateLocal = STATE_OPEN;
			stateProtector.wait();
			onOpen(address, addressSize);
			needUpdate = true;
		}
	}

	if ( stateLocal < STATE_OPEN
	  || stateLocal >= STATE_CLOSING )
		return;
	
	bool closing = (bool)(events & (EPOLLHUP | EPOLLRDHUP | EPOLLERR));
	
	if (events & EPOLLIN) {
		DebugMSG("input");
		bool ready = true;
		int pops = 0;
		int count = (readQueue.count() + 2)*2;
		while(ready) {
			Task *task = readQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				int res = ::recv(
					sockId,
					(unsigned char*)task->data + task->completion,
					(int)size, MSG_DONTWAIT );
				// 0 = EOF, <0 = no data available / error: stop reading for now.
				if (res <= 0) {
					ready = false;
					break;
				} else {
					task->completion += res;
				}
			}
			
			if (task->size <= task->completion) {
				readQueue.pop();
				onReadReady(*task, false);
				++pops;
				if (!closing && --count < 0) ready = false;
			}
		}
	}
	
	if (closing) {
		DebugMSG("hup");
		close((bool)(events & EPOLLERR));
		return;
	}
	
	if (events & EPOLLOUT) {
		DebugMSG("output");
		bool ready = true;
		int pops = 0;
		int count = writeQueue.count() + 2;
		while(ready) {
			Task *task = writeQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				// MSG_NOSIGNAL: if the peer has already reset the connection,
				// send() fails with EPIPE instead of raising SIGPIPE, whose
				// default action terminates the whole process. The broken
				// connection is then closed by the HUP/ERR path above.
				int res = ::send(
					sockId,
					(const unsigned char*)task->data + task->completion,
					(int)size, MSG_DONTWAIT | MSG_NOSIGNAL );
				if (res <= 0) {
					ready = false;
					break;
				} else {
					task->completion += res;
				}
			}
			
			if (task->size <= task->completion) {
				writeQueue.pop();
				onWriteReady(*task, false);
				++pops;
				if (--count <= 0) ready = false;
			}
		}
	}
	
	if (needUpdate)
		updateEvents(readQueue.count(), writeQueue.count());
}


// Default hook for a finished read task: intentionally does nothing.
// Called with false when the task has been filled, and with true when it is
// flushed unfilled because the connection is closing.
// NOTE(review): presumably virtual and meant to be overridden — confirm.
void Connection::onReadReady(Task& /*task*/, bool /*closing*/) {
}
// Default hook for a finished write task: intentionally does nothing.
// Called with false when the task has been fully sent, and with true when it
// is flushed unsent because the connection is closing.
// NOTE(review): presumably virtual and meant to be overridden — confirm.
void Connection::onWriteReady(Task& /*task*/, bool /*closing*/) {
}