Blob Blame Raw

#include <climits>
#include <cstring>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>

#include "protocol.h"


// Route this file's DebugMSG(...) calls to HideDebugMSG.
// NOTE(review): HideDebugMSG is not defined in this file; presumably it is a
// macro from protocol.h that discards its arguments - confirm.
#define DebugMSG     HideDebugMSG


// Constructs an idle protocol object: no worker thread, no descriptors,
// both state fields at STATE_NONE and no registered waiters or sockets.
Protocol::Protocol()
	: thread()                    // worker is only created by open()
	, epollFd(-1)                 // real descriptors are created in threadRun()
	, eventFd(-1)
	, state(STATE_NONE)
	, stateWanted(STATE_NONE)
	, closeWaiters(0)
	, finishWaiters(0)
	, sockFirst()
	, sockLast()
{
}


// Destroys the protocol object.
//
// The object is expected to be either never opened or already finished
// (checked by the assert). closeWait() is still invoked so that a release
// build, where the assert is compiled out, waits for the worker to finish.
// The worker thread object created by open() is joined and released here.
Protocol::~Protocol() {
	assert(state == STATE_NONE || state == STATE_FINISHED);
	closeWait();
	// Let every thread still inside closeWait() leave before members go away.
	while(finishWaiters > 0) std::this_thread::yield();
	DebugMSG();
	if (thread) {
		thread->join();
		// The std::thread object was allocated with new in open();
		// joining alone does not release it.
		delete thread;
		thread = nullptr;
	}
	DebugMSG("deleted");
}


bool Protocol::open() {
	DebugMSG();
	State s = state;
	do {
		if (s != STATE_NONE && s != STATE_FINISHED) return false;
	} while(!state.compare_exchange_weak(s, STATE_OPENING));
	if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
	if (thread) thread->join();
	
	thread = new std::thread(&Protocol::threadRun, this);
	state = STATE_INITIALIZING;
	DebugMSG();
	return true;
}


// Asks the worker thread to advance to `state`.
//
// The request is dropped when the protocol has not got past STATE_OPENING,
// is already at STATE_CLOSING or later, or has already reached `state`.
// Otherwise stateWanted is raised to `state` (it is never lowered) and the
// worker is woken up.
void Protocol::wantState(State state) {
	ReadLock lock(stateProtector);
	const State current = this->state;
	DebugMSG("%d, %d", current, state);

	const bool notRunningYet  = current <= STATE_OPENING;
	const bool alreadyClosing = current >= STATE_CLOSING;
	const bool alreadyReached = current >= state;
	if (notRunningYet || alreadyClosing || alreadyReached)
		return;
	DebugMSG("apply");

	// Raise stateWanted with a CAS loop; on failure `wanted` is refreshed
	// with the value another thread has stored and the check is repeated.
	State wanted = stateWanted;
	while (wanted < state) {
		if (stateWanted.compare_exchange_weak(wanted, state)) {
			wakeup();
			return;
		}
	}
}


// Requests the STATE_CLOSE_REQ transition from the worker thread.
void Protocol::closeReq() {
	DebugMSG();
	wantState(STATE_CLOSE_REQ);
}
// Requests the STATE_CLOSING transition from the worker thread.
void Protocol::close() {
	DebugMSG();
	wantState(STATE_CLOSING);
}


// Blocks the calling thread until the protocol reaches STATE_FINISHED.
//
// timeoutUs - how long (in microseconds) to wait for STATE_CLOSED before
//             close() is called; when 0, close() is called right away
//             (unless the state is already STATE_CLOSING or later).
// withReq   - when true and the state is still below STATE_CLOSE_REQ,
//             closeReq() is issued first.
//
// Returns immediately when the state is at or below STATE_OPENING or already
// STATE_FINISHED.
void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) {
	DebugMSG();
	stateProtector.lock();
	State s = state;
	// Nothing to wait for: not opened yet, or already finished.
	if (s <= STATE_OPENING || s >= STATE_FINISHED)
		{ stateProtector.unlock(); return; }
	
	
	if (s == STATE_CLOSED) {
		// Already closed: only the final STATE_FINISHED wait below is needed.
		finishWaiters++;
		stateProtector.unlock();
	} else {
		if (withReq && s < STATE_CLOSE_REQ) closeReq();
		// Registered while stateProtector is still held; the worker keeps
		// notifying closeCondition for as long as closeWaiters > 0.
		closeWaiters++;
		stateProtector.unlock();
		
		// Local mutex that no other thread ever locks; it exists only because
		// the condition wait calls require a lock argument.
		// NOTE(review): std::condition_variable requires all concurrent
		// waiters to use the same mutex - confirm closeCondition's type.
		std::mutex fakeMutex;
		{
			std::unique_lock<std::mutex> fakeLock(fakeMutex);
			if (timeoutUs && s < STATE_CLOSING) {
				// Deadline after which close() is called.
				std::chrono::steady_clock::time_point t
					= std::chrono::steady_clock::now()
						+ std::chrono::microseconds(timeoutUs);
				while(s < STATE_CLOSED) {
					if (t <= std::chrono::steady_clock::now()) { close(); break; }
					DebugMSG("wait_until, %d", s);
					closeCondition.wait_until(fakeLock, t);
					s = state;
				}
			}
			// No timeout given, or `s` still holds a pre-closing value:
			// ask for STATE_CLOSING, then wait without a deadline.
			if (s < STATE_CLOSING) close();
			while(s < STATE_CLOSED) {
				DebugMSG("wait, %d", s);
				closeCondition.wait(fakeLock);
				s = state;
			}
			DebugMSG("awaiten");
		}
		
		// finishWaiters is incremented before closeWaiters is decremented, so
		// this waiter is always counted in at least one of the two counters.
		finishWaiters++;
		closeWaiters--;
	}
	
	
	// Spin until the worker has published STATE_FINISHED.
	while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
	
	finishWaiters--;
	DebugMSG();
}


// Wakes the worker thread out of epoll_wait() by signalling the eventfd.
//
// A write to an eventfd adds the written value to its counter, and the
// descriptor is readable only while that counter is greater than zero.
// The value written must therefore be non-zero; writing 0 would leave the
// counter unchanged and the worker would sleep until its poll timeout.
void Protocol::wakeup() {
	const unsigned long long num = 1;
	// Best effort: a failed write is deliberately ignored.
	const ssize_t written = ::write(eventFd, &num, sizeof(num));
	(void)written;
}


// Body of the worker thread started by open().
//
// Creates the epoll instance and the wake-up eventfd, publishes STATE_OPEN,
// then loops: dispatch epoll events, apply the state requested through
// stateWanted, process queued sockets, and wait for new events. Once
// STATE_CLOSING has been entered the loop ends, STATE_CLOSED is published,
// waiters in closeWait() are notified, the descriptors are closed and
// finally STATE_FINISHED is published.
void Protocol::threadRun() {
	DebugMSG();
	epollFd = ::epoll_create(32);
	eventFd = ::eventfd(0, EFD_NONBLOCK);
	assert(epollFd >= 0);
	assert(eventFd >= 0);
	
	// Register the eventfd with zeroed user data: a null data.ptr is how the
	// dispatch loop below tells the wake-up descriptor apart from sockets.
	struct epoll_event event = {};
	event.events = EPOLLIN;
	epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event);
	
	const int maxCount = 16;
	struct epoll_event events[maxCount] = {};
	int count = 0;
	
	// stateLocal is this thread's own copy of the state it has published.
	State stateLocal = STATE_OPEN;
	state = stateLocal;
	
	while(true) {
		// Dispatch the events returned by the previous epoll_wait(); count is
		// 0 on the first pass and the loop is skipped if epoll_wait() failed.
		for(int i = 0; i < count; ++i) {
			if (Socket *socket = (Socket*)events[i].data.ptr) {
				socket->handleEvents(events[i].events);
			} else {
				// Wake-up eventfd: read it so it stops being reported.
				unsigned long long num = 0;
				while(::read(eventFd, &num, sizeof(num)) > 0 && !num);
			}
		}
		
		// Apply the transition requested via wantState(), if any.
		State sw = stateWanted;
		if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) {
			DebugMSG("to closse req");
			state = stateLocal = STATE_CLOSE_REQ;
			// NOTE(review): presumably waits until threads currently holding
			// stateProtector have released it - confirm against its definition.
			stateProtector.wait();
			for(Socket *socket = sockFirst; socket; socket = socket->next)
				socket->closeReq();
		} else
		if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) {
			DebugMSG("to closing");
			state = stateLocal = STATE_CLOSING;
			stateProtector.wait();
			for(Socket *socket = sockFirst; socket; socket = socket->next)
				socket->close(true);
		}
		
		// Process every socket queued for state handling.
		while(sockQueue.peek()) {
			Socket &socket = sockQueue.pop();
			socket.enqueued = false;
			socket.handleState();
		}
		
		if (stateLocal == STATE_CLOSING)
			break;
		
		// Wait for events for at most 1000 ms.
		count = epoll_wait(epollFd, events, maxCount, 1000);
		DebugMSG("epoll count: %d", count);
	}
	
	assert(!sockFirst);
	
	DebugMSG("closeWaiters");
	state = STATE_CLOSED;
	// Keep notifying until every thread counted in closeWaiters has left its
	// wait in closeWait() and decremented the counter.
	while(closeWaiters > 0)
		{ closeCondition.notify_all(); std::this_thread::yield(); }
	
	epoll_ctl(epollFd, EPOLL_CTL_DEL, eventFd, &event);

	::close(epollFd);
	::close(eventFd);
	epollFd = -1;
	eventFd = -1;
	assert(!sockFirst);
	assert(!sockQueue.peek());
	
	// Reset the request and publish STATE_FINISHED last; open() and the
	// final loop of closeWait() key off this value.
	stateWanted = STATE_NONE;
	stateLocal = STATE_NONE;
	state = STATE_FINISHED;
	DebugMSG("finished");
}