// tcplab.cpp

#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <climits>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Lifecycle state shared by TcpProtocol and Connection.
// Transitions used in this file: STOPPED -> STARTING -> STARTED -> STOPPING -> STOPPED.
enum WorkMode {
	WM_STOPPED  = 0, // idle; start()/open() may be called
	WM_STARTING = 1, // start()/open() in progress
	WM_STARTED  = 2, // running
	WM_STOPPING = 3  // shutdown in progress
};

// Static helpers for an intrusive singly linked list ("chain").
//
// T         - element type.
// NextField - pointer to the member of T that stores the link, e.g. &Node::next.
//
// An element is unlinked when its link member is nullptr. push() asserts
// this, and both pop() overloads restore it on the returned element.
template<typename T, T* T::*NextField>
class Chain {
private:
	Chain() { }

public:
	typedef T Type;
	static constexpr Type* Type::*Next = NextField;

	// Pushes `item` on top of the chain `first` (single-threaded).
	// Note: `.*` binds looser than unary `!`, hence the parentheses.
	static inline void push(Type* &first, Type &item)
		{ assert(!(item.*Next)); item.*Next = first; first = &item; }
	static inline void push(Type* &first, Type *item)
		{ assert(item); push(first, *item); }

	// Lock-free push onto `first`; may be called from several threads at once.
	static inline void push(std::atomic<Type*> &first, Type &item) {
		assert(!(item.*Next));
		item.*Next = first.load();
		// On failure compare_exchange_weak stores the current head in item.*Next,
		// so the next attempt links to the up-to-date head.
		while(!first.compare_exchange_weak(item.*Next, &item)) { }
	}

	// Removes and returns the top element of `first`, or nullptr if it is empty.
	static inline Type* pop(Type* &first) {
		Type *item = first;
		if (!item) return nullptr;
		first = item->*Next;
		item->*Next = nullptr;
		return item;
	}

	// Lock-free pop from `first`, or nullptr if it is empty.
	// NOTE(review): reading item->*Next while other threads pop/re-push the same
	// element is subject to the ABA problem; verify before using with multiple consumers.
	static inline Type* pop(std::atomic<Type*> &first) {
		Type *item = first.load();
		do {
			if (!item) return nullptr;
		} while(!first.compare_exchange_weak(item, item->*Next));
		item->*Next = nullptr;
		return item;
	}

	// Reverses the chain in place; `first` ends up pointing at the former tail.
	static inline void invert(Type* &first) {
		Type *prev = nullptr;
		while(Type *curr = first) {
			first = curr->*Next;
			curr->*Next = prev;
			prev = curr;
		}
		first = prev;
	}

	// Returns the reversed chain that started at `first`.
	static inline Type* inverted(Type *first)
		{ invert(first); return first; }
};

// multhreaded input, single threaded output
71057f
// simultaneus read and write are allowed
71057f
template<typename t*="" t,="" t::*nextfield=""></typename>
71057f
class QueueMISO {
71057f
public:
71057f
	typedef T Type;
71057f
	typedef Chain<type, nextfield=""> Chain;</type,>
71057f
	static constexpr Type* Type::*Next = NextField;
71057f
	
71057f
private:
71057f
	std::atomic<type*> pushed;</type*>
71057f
	Type *first;
71057f
	std::atomic<int> counter;</int>
71057f
	
71057f
public:
71057f
	inline QueueMISO():
71057f
		pushed(nullptr), first() { }
71057f
	inline ~QueueMISO()
71057f
		{ assert(!pushed && !first && !counter); }
71057f
	inline void push(Type &item)
71057f
		{ Chain::push(pushed, item); counter++; }
71057f
	inline Type& pop()
71057f
		{ assert(first); Type &item = *Chain::pop(first); counter++; return item; }
71057f
	inline Type *peek()
71057f
		{ return first ? first : (first = Chain::inverted(pushed.exchange(nullptr))); }
71057f
	inline int count() const
71057f
		{ return counter; }
71057f
};
71057f
71057f
71057f
71057f
// Busy-waiting mutex built on an atomic flag. It provides lock/try_lock/unlock,
// so it can be used with std::lock_guard (see ALock).
class AMutex {
private:
	std::atomic<bool> locked;

public:
	inline AMutex(): locked(false) { }
	inline ~AMutex() { assert(!locked); }

	// Takes the lock if it is free; never blocks.
	inline bool try_lock()
		{ bool expected = false; return locked.compare_exchange_strong(expected, true); }

	// Spins, yielding the CPU, until the lock is taken.
	inline void lock() {
		// A failed CAS writes the current value (true) into `expected`. Retrying
		// without resetting it would "succeed" while another thread holds the lock.
		for (bool expected = false; !locked.compare_exchange_weak(expected, true); expected = false)
			std::this_thread::yield();
	}

	inline void unlock()
		{ assert(locked); locked = false; }
};

// Spinning reader/writer lock.
// counter == 0: free; counter > 0: number of readers; counter == -1: held by a writer.
class ARWMutex {
private:
	std::atomic<int> counter;

public:
	inline ARWMutex(): counter(0) { }
	inline ~ARWMutex() { assert(!counter); }

	// Exclusive (writer) side.
	inline bool try_lock()
		{ int expected = 0; return counter.compare_exchange_strong(expected, -1); }

	inline void lock() {
		// Reset `expected` on every attempt: a failed CAS stores the current reader
		// count in it, and retrying with that value would let a writer in beside readers.
		for (int expected = 0; !counter.compare_exchange_weak(expected, -1); expected = 0)
			std::this_thread::yield();
	}

	inline void unlock()
		{ assert(counter < 0); counter = 0; }

	// Shared (reader) side.
	inline bool try_lock_read() {
		int cnt = counter;
		do {
			if (cnt < 0) return false; // writer active
		} while (!counter.compare_exchange_weak(cnt, cnt + 1));
		return true;
	}

	inline void lock_read() {
		int cnt = counter;
		while (true) {
			if (cnt < 0) {
				// Writer active: wait and re-read, otherwise we would spin on a stale value forever.
				std::this_thread::yield();
				cnt = counter;
				continue;
			}
			if (counter.compare_exchange_weak(cnt, cnt + 1)) return;
		}
	}

	inline void unlock_read()
		{ assert(counter > 0); --counter; }
};

typedef std::lock_guard<amutex> ALock;</amutex>
71057f
typedef std::lock_guard<arwmutex> AWLock;</arwmutex>
71057f
71057f
class ARLock {
71057f
private:
71057f
	ARWMutex &mutex;
71057f
public:
71057f
	inline ARLock(ARWMutex &mutex): mutex(mutex) { lock(); }
71057f
	inline ~ARLock() { unlock(); }
71057f
	inline void lock() { mutex.lock_read(); }
71057f
	inline void unlock() { mutex.unlock_read(); }
71057f
};
71057f
71057f
71057f
71057f
// std::mutex lock aliases: Lock for plain scoped sections, UniqLock where a
// condition_variable wait must release and reacquire the mutex.
typedef std::lock_guard<std::mutex> Lock;
typedef std::unique_lock<std::mutex> UniqLock;

// Forward declarations: TcpSocket and Connection refer to each other and to
// TcpProtocol, and Server is only used through pointers in the code shown here.
class TcpProtocol;
class Server;
class Connection;

class TcpSocket {
71057f
public:
71057f
	TcpProtocol &protocol;
71057f
	int sockId;
71057f
	Server *server;
71057f
	Connection *connection;
71057f
	
71057f
	explicit TcpSocket(TcpProtocol &protocol, int sockId);
71057f
	~TcpSocket();
71057f
	void closeReq();
71057f
	void update(bool read, bool write);
71057f
};
71057f
71057f
71057f
class Connection {
71057f
public:
71057f
	enum State {
71057f
		STATE_NONE,
71057f
		STATE_CONNECTING,
71057f
		STATE_CONNECTED,
71057f
		STATE_CLOSE_HINT,
71057f
		STATE_CLOSE_REQ,
71057f
		STATE_CLOSED,
71057f
	};
71057f
	
71057f
	struct Task {
71057f
		unsigned char *data;
71057f
		int size;
71057f
		int position;
71057f
		Task *next;
71057f
		
71057f
		explicit Task(void *data = nullptr, int size = 0):
71057f
			data((unsigned char*)data), size(size), position() { }
71057f
	};
71057f
	typedef QueueMISO<task, &task::next=""> Queue;</task,>
71057f
	
71057f
	
71057f
	std::atomic<int> stateReaders;  // count of threads that currently reads the state</int>
71057f
	std::atomic<state> state;       // write from protocol thread, read from others</state>
71057f
	std::atomic<state> stateWanted; // read from protocol thread, write from others</state>
71057f
	State stateQuickCopy;           // for use in protocol thread only
71057f
	std::atomic<unsigned int=""> lastEvents;</unsigned>
71057f
	
71057f
	TcpProtocol *protocol;
71057f
	int sockId;
71057f
	Queue readQueue;
71057f
	Queue writeQueue;
71057f
	
71057f
	// private
71057f
	void updateEvents() {
71057f
		assert(stateReaders > 0); // << must be locked
71057f
		
71057f
		unsigned int events = 0;
71057f
		if (readQueue.count()) events |= EPOLLIN;
71057f
		if (writeQueue.count()) events |= EPOLLOUT;
71057f
		if (lastEvents.exchange(events) == events) return;
71057f
		
71057f
		struct epoll_event event = {};
71057f
		event.data.ptr = this;
71057f
		event.events = events;
71057f
		epoll_ctl(protocol->pollId, EPOLL_CTL_MOD, sockId, &event);
71057f
	}
71057f
	
71057f
	// public functions
71057f
	void connect(TcpProtocol &protocol, const void *address, size_t size);
71057f
	void accept(TcpProtocol &protocol, int sockId);
71057f
	void closeHint() { wantState(STATE_CLOSE_HINT); }
71057f
	void closeReq() { wantState(STATE_CLOSED); }
71057f
	
71057f
	void closeWait(unsigned long long timeoutUs);
71057f
	
71057f
	void wantState(State state) {
71057f
		State sw = stateWanted;
71057f
		while(sw < state && !stateWanted.compare_exchange_weak(sw, state));
71057f
	}
71057f
	
71057f
	bool send(Task &task) {
71057f
		stateReaders++;
71057f
		State s = state;
71057f
		if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
71057f
		writeQueue.push(task);
71057f
		updateEvents();
71057f
		stateReaders--;
71057f
		return true;
71057f
	}
71057f
	
71057f
	bool receive(Task &task) {
71057f
		stateReaders++;
71057f
		State s = state;
71057f
		if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
71057f
		readQueue.push(task);
71057f
		updateEvents();
71057f
		stateReaders--;
71057f
		return true;
71057f
	}
71057f
	
71057f
	
71057f
	// handlers, must be called from protocol thread
71057f
	virtual void onOpen() { }
71057f
	virtual void onClose(bool) { }
71057f
	virtual void onReadReady(const Task&) { }
71057f
	virtual void onWriteReady(const Task&) { }
71057f
	
71057f
	
71057f
	std::mutex mutex;
71057f
	WorkMode mode;
71057f
	TcpSocket *socket;
71057f
	
71057f
	
71057f
	// must be called from Protocol thread only
71057f
	bool open(TcpSocket *socket) {
71057f
		{
71057f
			Lock lock(mutex);
71057f
			if (mode != WM_STOPPED) return false;
71057f
			mode = WM_STARTING;
71057f
			this->socket = socket;
71057f
		}
71057f
		
71057f
		onOpen();
71057f
		
71057f
		{ 
71057f
			Lock lock(mutex);
71057f
			mode = WM_STARTED;
71057f
		}
71057f
		
71057f
		return true;
71057f
	}
71057f
	
71057f
	void close(bool err) {
71057f
		{
71057f
			Lock lock(mutex);
71057f
			assert(mode == WM_STARTED);
71057f
			mode = WM_STOPPING;
71057f
			if (!err && (!readQueue.empty() || !writeQueue.empty()))
71057f
				err = true;
71057f
		}
71057f
		
71057f
		onClose(err);
71057f
		
71057f
		{
71057f
			socket = nullptr;
71057f
		}
71057f
	}
71057f
	
71057f
	// must be called internally
71057f
	
71057f
	void closeReq() {
71057f
		Lock lock(mutex);
71057f
		if (mode != WM_STARTING && mode != WM_STARTED) return;
71057f
		socket->closeReq();
71057f
	}
71057f
	
71057f
	void readReq(void *data, int size) {
71057f
		assert(data && size > 0);
71057f
		Lock lock(mutex);
71057f
		if (mode != WM_STARTING && mode != WM_STARTED) return;
71057f
		readQueue.emplace_back(data, size);
71057f
		socket->update(!readQueue.empty(), !writeQueue.empty());
71057f
	}
71057f
	
71057f
	void writeReq(void *data, int size) {
71057f
		assert(data && size > 0);
71057f
		Lock lock(mutex);
71057f
		if (mode != WM_STARTING && mode != WM_STARTED) return;
71057f
		writeQueue.emplace_back(data, size);
71057f
		socket->update(!readQueue.empty(), !writeQueue.empty());
71057f
	}
71057f
};
71057f
71057f
71057f
71057f
class TcpProtocol {
71057f
private:
71057f
	friend class TcpSocket;
71057f
	
71057f
	std::mutex mutex;
71057f
	std::thread *thread;
71057f
	std::condition_variable condModeChanged;
71057f
	WorkMode mode;
71057f
	int epollId;
71057f
	int eventsId;
71057f
	
71057f
	std::vector<tcpsocket*> socketsToClose;</tcpsocket*>
71057f
	
71057f
	// lock mutex befor call
71057f
	void wakeup() {
71057f
		unsigned long long num = 0;
71057f
		::write(eventsId, &num, sizeof(num));
71057f
	}
71057f
	
71057f
	void threadRun() {
71057f
		const int maxCount = 16;
71057f
		struct epoll_event events[maxCount] = {};
71057f
		int count = 0;
71057f
		std::vector<tcpsocket*> socketsToClose;</tcpsocket*>
71057f
		
71057f
		{
71057f
			Lock lock(mutex);
71057f
			assert(mode == WM_STARTING);
71057f
			mode = WM_STARTED;
71057f
			condModeChanged.notify_all();
71057f
		}
71057f
		
71057f
		while(true) {
71057f
			for(int i = 0; i < count; ++i) {
71057f
				unsigned int e = events[i].events;
71057f
				TcpSocket *socket = (TcpSocket*)events[i].data.ptr;
71057f
				if (!socket) {
71057f
					unsigned long long num = 0;
71057f
					while(::read(eventsId, &num, sizeof(num)) > 0);
71057f
				} else
71057f
				if (socket->connection) {
71057f
					// lock connection
71057f
					//if (connection->
71057f
				} else
71057f
				if (socket->server) {
71057f
					// lock server
71057f
				} else {
71057f
					delete socket;
71057f
				}
71057f
			}
71057f
			
71057f
			
71057f
			{
71057f
				Lock lock(mutex);
71057f
				if (mode == WM_STOPPING) {
71057f
					socketsToClose
71057f
				}
71057f
			}
71057f
			
71057f
			count = epoll_wait(epollId, events, maxCount, 10000);
71057f
		}
71057f
	}
71057f
	
71057f
public:
71057f
	TcpProtocol():
71057f
		thread(),
71057f
		mode(WM_STOPPED),
71057f
		epollId(-1),
71057f
		eventsId(-1)
71057f
		{ }
71057f
	
71057f
	~TcpProtocol()
71057f
		{ stop(); }
71057f
	
71057f
	bool start() {
71057f
		{
71057f
			Lock lock(mutex);
71057f
			if (mode == WM_STARTED || mode == WM_STARTING) return true;
71057f
			if (mode != WM_STOPPED) return false;
71057f
			mode = WM_STARTING;
71057f
		}
71057f
		
71057f
		condModeChanged.notify_all();
71057f
		
71057f
		epollId = ::epoll_create(32);
71057f
		eventsId = ::eventfd(0, EFD_NONBLOCK);
71057f
		assert(epollId >= 0 && eventsId >= 0);
71057f
		
71057f
		struct epoll_event event = {};
71057f
		event.events = EPOLLIN;
71057f
		epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event);
71057f
		
71057f
		thread = new std::thread(&TcpProtocol::threadRun, this);
71057f
	}	
71057f
	
71057f
	void stop() {
71057f
		{
71057f
			UniqLock lock(mutex);
71057f
			while(true) {
71057f
				if (mode == WM_STOPPED) return;
71057f
				if (mode == WM_STARTED) break;
71057f
				condModeChanged.wait(lock);
71057f
			}
71057f
			mode = WM_STOPPING;
71057f
		}
71057f
		
71057f
		condModeChanged.notify_all();
71057f
		thread->join();
71057f
		delete thread;
71057f
		thread = nullptr;
71057f
		::close(epollId);
71057f
		::close(eventsId);
71057f
		
71057f
		{
71057f
			Lock lock(mutex);
71057f
			mode = WM_STOPPED;
71057f
		}
71057f
		condModeChanged.notify_all();
71057f
	}
71057f
};
71057f
71057f
71057f
71057f
// Enables TCP_NODELAY and O_NONBLOCK on `sockId`, then registers it with the
// protocol's epoll instance with an empty event mask (update() sets the mask later).
// The socket starts attached to neither a Server nor a Connection.
TcpSocket::TcpSocket(TcpProtocol &protocol, int sockId):
	protocol(protocol),
	sockId(sockId),
	server(nullptr),
	connection(nullptr)
{
	int tcp_nodelay = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(tcp_nodelay));
	fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);

	struct epoll_event event = {};
	event.data.ptr = this; // the protocol thread casts data.ptr back to TcpSocket*
	epoll_ctl(protocol.epollId, EPOLL_CTL_ADD, sockId, &event);
}

// Unregisters the descriptor from the protocol's epoll instance and closes it.
TcpSocket::~TcpSocket() {
	// Per epoll_ctl(2), kernels before 2.6.9 required a non-null event even for EPOLL_CTL_DEL.
	struct epoll_event unused = {};
	epoll_ctl(protocol.epollId, EPOLL_CTL_DEL, sockId, &unused);
	::close(sockId);
}

void TcpSocket::closeReq() {
71057f
	Lock lock(protocol.mutex);
71057f
	socketsToClose.push_back(this);
71057f
	protocol.wakeup();
71057f
}
71057f
71057f
void TcpSocket::update(bool read, bool write) {
71057f
	struct epoll_event event = {};
71057f
	event.data.ptr = this;
71057f
	if (read) event.events |= EPOLLIN;
71057f
	if (write) event.events |= EPOLLOUT;
71057f
	epoll_ctl(protocol.epollId, EPOLL_CTL_MOD, sockId, &event);
71057f
}
71057f
71057f