Blob Blame Raw

#include <climits>
#include <cassert>
#include <cstddef>

#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <set>
#include <vector>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>



// Lifecycle phase of a worker object (used by both TcpProtocol and Connection).
// The enumerators keep their implicit values 0..3 in declaration order.
enum WorkMode {
	WM_STOPPED,   // not running
	WM_STARTING,  // start requested, not yet running
	WM_STARTED,   // running
	WM_STOPPING   // stop requested, not yet stopped
};


// Static helpers for an intrusive, singly linked chain of nodes.
//
// T is the node type and NextField is the pointer-to-member inside T that links
// a node to the following one. A chain is identified by a pointer to its first
// node; nullptr means "empty". Nodes are never allocated or freed here.
template<typename T, T* T::*NextField>
class Chain {
private:
	Chain() { } // static-only helper, never instantiated
public:
	typedef T Type;
	static constexpr Type* Type::*Next = NextField;
	
	// Links `item` in front of `first`. The item must not be linked anywhere yet.
	// Note the parentheses: `!item.*Next` would parse as `(!item).*Next`.
	static inline void push(Type* &first, Type &item)
		{ assert(!(item.*Next)); item.*Next = first; first = &item; }
	
	// Pointer flavour of the overload above; `item` must not be null.
	static inline void push(Type* &first, Type *item)
		{ assert(item); push(first, *item); }
	
	// Lock-free push: may be called from several threads at once.
	// On CAS failure the current head is written back into item.*Next,
	// so the loop simply retries with the refreshed link.
	static inline void push(std::atomic<Type*> &first, Type &item) {
		assert(!(item.*Next));
		item.*Next = first.load();
		while(!first.compare_exchange_weak(item.*Next, &item)) { }
	}
	
	// Unlinks and returns the first node, or nullptr when the chain is empty.
	static inline Type* pop(Type* &first) {
		Type *item = first;
		if (!item) return nullptr;
		first = item->*Next;
		item->*Next = nullptr;
		return item;
	}
	
	// Atomic flavour of pop(). The CAS loop has no ABA protection, so it must
	// not race with another thread that pops and re-pushes the same nodes.
	static inline Type* pop(std::atomic<Type*> &first) {
		Type *item = first.load();
		do {
			if (!item) return nullptr;
		} while(!first.compare_exchange_weak(item, item->*Next));
		item->*Next = nullptr;
		return item;
	}
	
	// Reverses the chain in place; `first` ends up pointing at the former tail.
	static inline void invert(Type* &first) {
		Type *prev = nullptr;
		Type *curr = first;
		while(curr) {
			Type *next = curr->*Next;
			curr->*Next = prev;
			prev = curr;
			curr = next;
		}
		first = prev;
	}
	
	// Reverses the chain starting at `first` and returns the new first node.
	static inline Type* inverted(Type *first)
		{ invert(first); return first; }
};


// multhreaded input, single threaded output
// simultaneus read and write are allowed
template<typename T, T* T::*NextField>
class QueueMISO {
public:
	typedef T Type;
	typedef Chain<Type, NextField> Chain;
	static constexpr Type* Type::*Next = NextField;
	
private:
	std::atomic<Type*> pushed;
	Type *first;
	std::atomic<int> counter;
	
public:
	inline QueueMISO():
		pushed(nullptr), first() { }
	inline ~QueueMISO()
		{ assert(!pushed && !first && !counter); }
	inline void push(Type &item)
		{ Chain::push(pushed, item); counter++; }
	inline Type& pop()
		{ assert(first); Type &item = *Chain::pop(first); counter++; return item; }
	inline Type *peek()
		{ return first ? first : (first = Chain::inverted(pushed.exchange(nullptr))); }
	inline int count() const
		{ return counter; }
};



// Spin lock built on a single atomic flag.
// Provides lock()/try_lock()/unlock(), so it can be used with std::lock_guard.
class AMutex {
private:
	std::atomic<bool> locked;
public:
	inline AMutex(): locked(false) { }
	
	// The mutex must be released before it is destroyed.
	inline ~AMutex() { assert(!locked); }
	
	// Takes the lock if it is free; returns whether it was taken.
	inline bool try_lock()
		{ bool expected = false; return locked.compare_exchange_strong(expected, true); }
	
	// Spins until the lock is taken.
	inline void lock() {
		bool expected = false;
		while(!locked.compare_exchange_weak(expected, true)) {
			// A failed CAS stores the observed value (true) into `expected`;
			// without this reset the next round would "succeed" on a held lock.
			expected = false;
			std::this_thread::yield();
		}
	}
	
	// Releases the lock.
	inline void unlock()
		{ locked = false; }
};


// Spinning reader/writer lock built on one atomic counter:
//   -1  a writer holds the lock
//    0  the lock is free
//   >0  number of readers holding the lock
class ARWMutex {
private:
	std::atomic<int> counter;
public:
	inline ARWMutex(): counter(0) { }
	
	// The mutex must be fully released before it is destroyed.
	inline ~ARWMutex() { assert(!counter); }
	
	// Takes the write lock if nobody holds the mutex; returns whether it was taken.
	inline bool try_lock()
		{ int expected = 0; return counter.compare_exchange_strong(expected, -1); }
	
	// Spins until the write lock is taken.
	inline void lock() {
		int expected = 0;
		while(!counter.compare_exchange_weak(expected, -1)) {
			// A failed CAS overwrites `expected` with the observed value;
			// reset it, or the next round would replace active readers by -1.
			expected = 0;
			std::this_thread::yield();
		}
	}
	
	// Releases the write lock.
	inline void unlock()
		{ assert(counter < 0); counter = 0; }
	
	// Takes a read lock unless a writer holds the mutex; returns whether it was taken.
	inline bool try_lock_read() {
		int cnt = counter;
		do {
			if (cnt < 0) return false;
		} while(!counter.compare_exchange_weak(cnt, cnt + 1));
		return true;
	}
	
	// Spins until a read lock is taken.
	inline void lock_read() {
		int cnt = counter;
		while(true) {
			if (cnt < 0) {
				// A writer is active: wait and re-read the counter.
				std::this_thread::yield();
				cnt = counter;
				continue;
			}
			if (counter.compare_exchange_weak(cnt, cnt + 1)) return;
		}
	}
	
	// Releases one read lock.
	inline void unlock_read()
		{ assert(counter > 0); counter--; }
};


// Scoped guards: ALock holds an AMutex, AWLock holds the write side of an ARWMutex.
using ALock = std::lock_guard<AMutex>;
using AWLock = std::lock_guard<ARWMutex>;

// Scoped guard for the read side of an ARWMutex: takes a read lock on
// construction and releases it on destruction.
class ARLock {
private:
	ARWMutex &mutex;
	
	// Not copyable: a copy would release the same read lock a second time.
	ARLock(const ARLock&) = delete;
	ARLock& operator=(const ARLock&) = delete;
public:
	inline ARLock(ARWMutex &mutex): mutex(mutex) { lock(); }
	inline ~ARLock() { unlock(); }
	
	// Manual control; every lock() must be matched by exactly one unlock(),
	// and the guard must hold the read lock when it is destroyed.
	inline void lock() { mutex.lock_read(); }
	inline void unlock() { mutex.unlock_read(); }
};



// Short names for the standard guards over std::mutex.
using Lock = std::lock_guard<std::mutex>;
using UniqLock = std::unique_lock<std::mutex>;


// Forward declarations: TcpSocket (below) holds a reference/pointers to these
// types before their definitions appear.
// NOTE(review): Server is only declared, never defined, in the visible part of the file.
class TcpProtocol;
class Server;
class Connection;




// A socket descriptor registered in the epoll set of a TcpProtocol.
// The object owns the descriptor: the destructor unregisters and closes it.
class TcpSocket {
public:
	TcpProtocol &protocol;
	int sockId;
	// At most one of these is expected to be set; the protocol thread reads
	// them to decide how to dispatch an event, so they must start out null.
	Server *server = nullptr;
	Connection *connection = nullptr;
	
	// Makes the descriptor non-blocking and registers it with the protocol's epoll set.
	explicit TcpSocket(TcpProtocol &protocol, int sockId);
	~TcpSocket();
	
	// Not copyable: a copy would close the same descriptor twice.
	TcpSocket(const TcpSocket&) = delete;
	TcpSocket& operator=(const TcpSocket&) = delete;
	
	// Asks the protocol thread to close this socket.
	void closeReq();
	
	// Selects which readiness events (readable / writable) epoll reports for this socket.
	void update(bool read, bool write);
};


class Connection {
public:
	enum State {
		STATE_NONE,
		STATE_CONNECTING,
		STATE_CONNECTED,
		STATE_CLOSE_HINT,
		STATE_CLOSE_REQ,
		STATE_CLOSED,
	};
	
	struct Task {
		unsigned char *data;
		int size;
		int position;
		Task *next;
		
		explicit Task(void *data = nullptr, int size = 0):
			data((unsigned char*)data), size(size), position() { }
	};
	typedef QueueMISO<Task, &Task::next> Queue;
	
	
	std::atomic<int> stateReaders;  // count of threads that currently reads the state
	std::atomic<State> state;       // write from protocol thread, read from others
	std::atomic<State> stateWanted; // read from protocol thread, write from others
	State stateQuickCopy;           // for use in protocol thread only
	std::atomic<unsigned int> lastEvents;
	
	TcpProtocol *protocol;
	int sockId;
	Queue readQueue;
	Queue writeQueue;
	
	// private
	void updateEvents() {
		assert(stateReaders > 0); // << must be locked
		
		unsigned int events = 0;
		if (readQueue.count()) events |= EPOLLIN;
		if (writeQueue.count()) events |= EPOLLOUT;
		if (lastEvents.exchange(events) == events) return;
		
		struct epoll_event event = {};
		event.data.ptr = this;
		event.events = events;
		epoll_ctl(protocol->pollId, EPOLL_CTL_MOD, sockId, &event);
	}
	
	// public functions
	void connect(TcpProtocol &protocol, const void *address, size_t size);
	void accept(TcpProtocol &protocol, int sockId);
	void closeHint() { wantState(STATE_CLOSE_HINT); }
	void closeReq() { wantState(STATE_CLOSED); }
	
	void closeWait(unsigned long long timeoutUs);
	
	void wantState(State state) {
		State sw = stateWanted;
		while(sw < state && !stateWanted.compare_exchange_weak(sw, state));
	}
	
	bool send(Task &task) {
		stateReaders++;
		State s = state;
		if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
		writeQueue.push(task);
		updateEvents();
		stateReaders--;
		return true;
	}
	
	bool receive(Task &task) {
		stateReaders++;
		State s = state;
		if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
		readQueue.push(task);
		updateEvents();
		stateReaders--;
		return true;
	}
	
	
	// handlers, must be called from protocol thread
	virtual void onOpen() { }
	virtual void onClose(bool) { }
	virtual void onReadReady(const Task&) { }
	virtual void onWriteReady(const Task&) { }
	
	
	std::mutex mutex;
	WorkMode mode;
	TcpSocket *socket;
	
	
	// must be called from Protocol thread only
	bool open(TcpSocket *socket) {
		{
			Lock lock(mutex);
			if (mode != WM_STOPPED) return false;
			mode = WM_STARTING;
			this->socket = socket;
		}
		
		onOpen();
		
		{ 
			Lock lock(mutex);
			mode = WM_STARTED;
		}
		
		return true;
	}
	
	void close(bool err) {
		{
			Lock lock(mutex);
			assert(mode == WM_STARTED);
			mode = WM_STOPPING;
			if (!err && (!readQueue.empty() || !writeQueue.empty()))
				err = true;
		}
		
		onClose(err);
		
		{
			socket = nullptr;
		}
	}
	
	// must be called internally
	
	void closeReq() {
		Lock lock(mutex);
		if (mode != WM_STARTING && mode != WM_STARTED) return;
		socket->closeReq();
	}
	
	void readReq(void *data, int size) {
		assert(data && size > 0);
		Lock lock(mutex);
		if (mode != WM_STARTING && mode != WM_STARTED) return;
		readQueue.emplace_back(data, size);
		socket->update(!readQueue.empty(), !writeQueue.empty());
	}
	
	void writeReq(void *data, int size) {
		assert(data && size > 0);
		Lock lock(mutex);
		if (mode != WM_STARTING && mode != WM_STARTED) return;
		writeQueue.emplace_back(data, size);
		socket->update(!readQueue.empty(), !writeQueue.empty());
	}
};



class TcpProtocol {
private:
	friend class TcpSocket;
	
	std::mutex mutex;
	std::thread *thread;
	std::condition_variable condModeChanged;
	WorkMode mode;
	int epollId;
	int eventsId;
	
	std::vector<TcpSocket*> socketsToClose;
	
	// lock mutex befor call
	void wakeup() {
		unsigned long long num = 0;
		::write(eventsId, &num, sizeof(num));
	}
	
	void threadRun() {
		const int maxCount = 16;
		struct epoll_event events[maxCount] = {};
		int count = 0;
		std::vector<TcpSocket*> socketsToClose;
		
		{
			Lock lock(mutex);
			assert(mode == WM_STARTING);
			mode = WM_STARTED;
			condModeChanged.notify_all();
		}
		
		while(true) {
			for(int i = 0; i < count; ++i) {
				unsigned int e = events[i].events;
				TcpSocket *socket = (TcpSocket*)events[i].data.ptr;
				if (!socket) {
					unsigned long long num = 0;
					while(::read(eventsId, &num, sizeof(num)) > 0);
				} else
				if (socket->connection) {
					// lock connection
					//if (connection->
				} else
				if (socket->server) {
					// lock server
				} else {
					delete socket;
				}
			}
			
			
			{
				Lock lock(mutex);
				if (mode == WM_STOPPING) {
					socketsToClose
				}
			}
			
			count = epoll_wait(epollId, events, maxCount, 10000);
		}
	}
	
public:
	TcpProtocol():
		thread(),
		mode(WM_STOPPED),
		epollId(-1),
		eventsId(-1)
		{ }
	
	~TcpProtocol()
		{ stop(); }
	
	bool start() {
		{
			Lock lock(mutex);
			if (mode == WM_STARTED || mode == WM_STARTING) return true;
			if (mode != WM_STOPPED) return false;
			mode = WM_STARTING;
		}
		
		condModeChanged.notify_all();
		
		epollId = ::epoll_create(32);
		eventsId = ::eventfd(0, EFD_NONBLOCK);
		assert(epollId >= 0 && eventsId >= 0);
		
		struct epoll_event event = {};
		event.events = EPOLLIN;
		epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event);
		
		thread = new std::thread(&TcpProtocol::threadRun, this);
	}	
	
	void stop() {
		{
			UniqLock lock(mutex);
			while(true) {
				if (mode == WM_STOPPED) return;
				if (mode == WM_STARTED) break;
				condModeChanged.wait(lock);
			}
			mode = WM_STOPPING;
		}
		
		condModeChanged.notify_all();
		thread->join();
		delete thread;
		thread = nullptr;
		::close(epollId);
		::close(eventsId);
		
		{
			Lock lock(mutex);
			mode = WM_STOPPED;
		}
		condModeChanged.notify_all();
	}
};



// Takes over descriptor `sockId`: disables Nagle's algorithm, switches the
// descriptor to non-blocking mode and registers it in the protocol's epoll
// set with no read/write interest yet (see update()).
TcpSocket::TcpSocket(TcpProtocol &protocol, int sockId):
	protocol(protocol),
	sockId(sockId),
	server(nullptr),      // the protocol thread reads both pointers,
	connection(nullptr)   // so they must not be left uninitialized
{
	int tcp_nodelay = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(tcp_nodelay));
	
	// F_GETFL returns -1 on failure; OR-ing that into the flags would set every bit.
	const int flags = fcntl(sockId, F_GETFL, 0);
	if (flags >= 0)
		fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
	
	struct epoll_event event = {};
	event.data.ptr = this; // the protocol thread casts the payload back to TcpSocket*
	epoll_ctl(protocol.epollId, EPOLL_CTL_ADD, sockId, &event);
}

// Removes the descriptor from the protocol's epoll set, then closes it.
TcpSocket::~TcpSocket() {
	// EPOLL_CTL_DEL ignores the event argument; a zeroed one is passed anyway.
	struct epoll_event ignored = {};
	epoll_ctl(protocol.epollId, EPOLL_CTL_DEL, sockId, &ignored);
	
	::close(sockId);
}

// Queues this socket on the protocol's close list and wakes the protocol
// thread up. The list belongs to TcpProtocol and is guarded by its mutex.
void TcpSocket::closeReq() {
	Lock lock(protocol.mutex);
	protocol.socketsToClose.push_back(this);
	protocol.wakeup(); // must be called with protocol.mutex held
}

// Replaces the epoll interest set of this socket: EPOLLIN when `read` is
// set, EPOLLOUT when `write` is set. The payload stays this TcpSocket.
void TcpSocket::update(bool read, bool write) {
	struct epoll_event request = {};
	request.events |= read ? EPOLLIN : 0;
	request.events |= write ? EPOLLOUT : 0;
	request.data.ptr = this;
	epoll_ctl(protocol.epollId, EPOLL_CTL_MOD, sockId, &request);
}