Blame protocol.cpp

71057f
71057f
#include <cerrno>
#include <climits>
#include <cstring>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>

#include "protocol.h"
71057f
71057f
71057f
71057f
// Destroys the protocol. It must not be running: only a never-opened (STATE_NONE),
// closing-down (STATE_CLOSED) or fully finished (STATE_FINISHED) protocol may be
// destroyed. Waits for closeWait() callers to leave, then joins and frees the worker.
Protocol::~Protocol() {
	// threadRun() ends in STATE_FINISHED, so a protocol that was opened and then closed
	// normally is in that state here and must be accepted as well.
	assert(state == STATE_NONE || state == STATE_CLOSED || state == STATE_FINISHED);

	closeWait();

	while(finishWaiters > 0) std::this_thread::yield();

	// open() allocates the worker with new; join it and release the object.
	if (thread) {
		if (thread->joinable()) thread->join();
		delete thread;
		thread = nullptr;
	}
}
71057f
71057f
71057f
bool Protocol::open() {
71057f
	State s = state;
71057f
	do {
71057f
		if (s != STATE_NONE && s != STATE_FINISHED) return false;
71057f
	} while(!state.compare_exchange_weak(s, STATE_OPENING));
71057f
	if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
71057f
	if (thread) thread->join();
71057f
	
71057f
	thread = new std::thread(&Protocol::threadRun, this);
71057f
	return true;
71057f
}
71057f
71057f
71057f
// Requests a forward state transition (STATE_CLOSE_REQ or STATE_CLOSING) from the worker
// thread and wakes it. The request is ignored unless the protocol is running (past
// STATE_OPENING and before STATE_CLOSING), the target is ahead of the current state, and
// the target is ahead of any request already pending in stateWanted.
// Runs under a read lock on stateProtector; presumably this lets the worker's
// stateProtector.wait() after a state change wait out in-flight requests (confirm in header).
void Protocol::wantState(State state) {
	ReadLock guard(stateProtector);

	const State current = this->state;
	const bool running = current > STATE_OPENING && current < STATE_CLOSING;
	if (!running || state <= current) return;

	// Raise stateWanted monotonically; an equal or stronger pending request wins.
	State pending = stateWanted;
	for (;;) {
		if (pending <= STATE_OPENING || pending >= state) return;
		if (stateWanted.compare_exchange_weak(pending, state)) break;
	}

	wakeup();
}
71057f
71057f
71057f
void Protocol::closeReq()
71057f
	{ wantState(STATE_CLOSE_REQ); }
71057f
void Protocol::close()
71057f
	{ wantState(STATE_CLOSING); }
71057f
71057f
71057f
void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) {
71057f
	stateProtector.lock();
71057f
	State s = state;
71057f
	if (s <= STATE_OPENING || s >= STATE_FINISHED)
71057f
		{ stateProtector.unlock(); return; }
71057f
	
71057f
	
71057f
	if (s == STATE_CLOSED) {
71057f
		finishWaiters++;
71057f
		stateProtector.unlock();
71057f
	} else {
71057f
		if (withReq && s < STATE_CLOSE_REQ) closeReq();
71057f
		closeWaiters++;
71057f
		stateProtector.unlock();
71057f
		
71057f
		std::mutex fakeMutex;
71057f
		{
71057f
			std::unique_lock<std::mutex> fakeLock(fakeMutex);</std::mutex>
71057f
			if (timeoutUs && s < STATE_CLOSING) {
71057f
				std::chrono::steady_clock::time_point t
71057f
					= std::chrono::steady_clock::now()
71057f
						+ std::chrono::microseconds(timeoutUs);
71057f
				while(s < STATE_CLOSED) {
71057f
					if (t <= std::chrono::steady_clock::now()) { close(); break; }
71057f
					closeCondition.wait_until(fakeLock, t);
71057f
					s = state;
71057f
				}
71057f
			}
71057f
			if (s < STATE_CLOSING) close();
71057f
			while(s < STATE_CLOSED) {
71057f
				closeCondition.wait(fakeLock);
71057f
				s = state;
71057f
			}
71057f
		}
71057f
		
71057f
		finishWaiters++;
71057f
		closeWaiters--;
71057f
	}
71057f
	
71057f
	
71057f
	while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
71057f
	
71057f
	finishWaiters--;
71057f
}
71057f
71057f
71057f
void Protocol::wakeup() {
71057f
	unsigned long long num = 0;
71057f
	::write(eventFd, &num, sizeof(num));
71057f
}
71057f
71057f
71057f
// Re-arms epoll for a connection: EPOLLIN while read tasks are queued, EPOLLOUT while
// write tasks are queued. connection.events caches the registered mask, so the
// epoll_ctl() call is skipped when nothing changed.
void Protocol::updateEvents(Connection &connection) {
	unsigned int events = 0;
	if (connection.readQueue.count()) events |= EPOLLIN;
	if (connection.writeQueue.count()) events |= EPOLLOUT;
	if (connection.events.exchange(events) == events) return;

	struct epoll_event event = {};
	// threadRun() decodes data.ptr as Root*, so store exactly that pointer type.
	event.data.ptr = static_cast<Root*>(&connection);
	event.events = events;
	epoll_ctl(epollFd, EPOLL_CTL_MOD, connection.sockId, &event);
}
71057f
71057f
71057f
// Prepares a TCP socket for the epoll loop: disables Nagle's algorithm (TCP_NODELAY)
// and switches the descriptor to non-blocking mode.
void Protocol::initSocket(int sockId) {
	const int noDelay = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &noDelay, sizeof(noDelay));

	const int flags = fcntl(sockId, F_GETFL, 0);
	fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
}
71057f
71057f
71057f
void Protocol::threadConnQueue() {
71057f
	while(connQueue.peek()) {
71057f
		Connection &connection = connQueue.pop();
71057f
		connection.enqueued = false;
71057f
		
71057f
		if (connection.stateLocal == Connection::STATE_NONE) {
71057f
			int sockId = connection.sockId;
71057f
			bool connected = false;
71057f
			
71057f
			if (sockId >= 0) {
71057f
				initSocket(sockId);
71057f
			} else
71057f
			if (connection.addressSize >= sizeof(sa_family_t)) {
71057f
				struct sockaddr *addr = (struct sockaddr*)connection.address;
71057f
				sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
71057f
				if (sockId >= 0) {
71057f
					initSocket(sockId);
71057f
					int res = ::connect(sockId, addr, (int)connection.addressSize);
71057f
					if (res == 0) {
71057f
						connected = true;
71057f
					} else
71057f
					if (res != EINPROGRESS) {
71057f
						::close(sockId);
71057f
						sockId = -1;
71057f
					}
71057f
				}
71057f
			}
71057f
			
71057f
			if (sockId >= 0) {
71057f
				connection.protocol = this;
71057f
				connection.prev = connLast;
71057f
				assert(!connection.next);
71057f
				(connection.prev ? connection.prev->next : connFirst) = &connection;
71057f
				if (connection.server) {
71057f
					connection.srvPrev = connection.server->connLast;
71057f
					assert(!connection.srvNext);
71057f
					(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = &connection;
71057f
				}
71057f
				connection.sockId = sockId;
71057f
				
71057f
				struct epoll_event event = {};
71057f
				event.data.ptr = &connection;
71057f
				if (connected) {
71057f
					epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
71057f
					connection.onOpen(connection.address, connection.addressSize);
71057f
					connection.state = connection.stateLocal = Connection::STATE_OPEN;
71057f
					connection.stateProtector.wait();
71057f
				} else {
71057f
					event.events = EPOLLOUT;
71057f
					epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
71057f
					connection.state = connection.stateLocal = Connection::STATE_CONNECTING;
71057f
					connection.stateProtector.wait();
71057f
				}
71057f
			} else {
71057f
				connection.sockId = -1;
71057f
				connection.server = nullptr;
71057f
				connection.onOpeningError();
71057f
				connection.state = Connection::STATE_CLOSED;
71057f
				while(connection.closeWaiters > 0)
71057f
					{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
71057f
				connection.state = Connection::STATE_FINISHED;
71057f
			}
71057f
		} else {
71057f
			Connection::State stateWanted = connection.stateWanted;
71057f
			
71057f
			if ( connection.stateLocal >= Connection::STATE_CONNECTING
71057f
			  && connection.stateLocal <= Connection::STATE_OPEN
71057f
			  && stateWanted == Connection::STATE_CLOSE_REQ )
71057f
			{
71057f
				connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ;
71057f
				connection.stateProtector.wait();
71057f
				connection.onCloseReqested();
71057f
			} else
71057f
			if ( connection.stateLocal >= Connection::STATE_CONNECTING
71057f
			  && connection.stateLocal <= Connection::STATE_CLOSE_REQ
71057f
			  && stateWanted == Connection::STATE_CLOSING )
71057f
			{
71057f
				connection.state = connection.stateLocal = Connection::STATE_CLOSING;
71057f
				connection.stateProtector.wait();
71057f
				while(Connection::Task *task = connection.readQueue.peek()) {
71057f
					connection.readQueue.pop();
71057f
					if (!connection.onReadReady(*task, true))
71057f
						assert(false);
71057f
				}
71057f
				while(Connection::Task *task = connection.writeQueue.peek()) {
71057f
					connection.writeQueue.pop();
71057f
					if (!connection.onWriteReady(*task, true))
71057f
						assert(false);
71057f
				}
71057f
				connection.onClose(false);
71057f
			}
71057f
			
71057f
			if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) {
71057f
				if (connection.server) {
71057f
					(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
71057f
					(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
71057f
					connection.server = nullptr;
71057f
				}
71057f
				(connection.prev ? connection.prev->next : connFirst) = connection.next;
71057f
				(connection.next ? connection.next->prev : connLast) = connection.prev;
71057f
				connection.sockId = -1;
71057f
				connection.protocol = nullptr;
71057f
				memset(connection.address, 0, sizeof(connection.address));
71057f
				connection.addressSize = 0;
71057f
				connection.stateWanted = Connection::STATE_NONE;
71057f
				connection.stateLocal = Connection::STATE_NONE;
71057f
				connection.state = Connection::STATE_CLOSED;
71057f
				while(connection.closeWaiters > 0)
71057f
					{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
71057f
				connection.state = Connection::STATE_FINISHED;
71057f
			}
71057f
		}
71057f
	}
71057f
}
71057f
71057f
71057f
// Handles epoll events for one connection on the worker thread.
//  - STATE_CONNECTING: EPOLLERR/EPOLLHUP or a non-zero SO_ERROR means the connect failed.
//    onOpeningError() is called, the connection is unlinked, its socket is released and
//    it is reset to STATE_NONE. Otherwise EPOLLOUT completes the open (onOpen(), STATE_OPEN).
//  - STATE_OPEN / STATE_CLOSE_REQ: EPOLLIN fills queued read tasks and EPOLLOUT drains
//    queued write tasks. A task is handed to onReadReady()/onWriteReady() when complete,
//    or after any progress if task->watch is set; a callback returning false puts the task
//    back. EPOLLHUP/EPOLLERR (checked after reading) and socket errors request close().
// Finally the epoll mask is re-synchronized with the queues if a queue was consumed.
void Protocol::threadConnEvents(Connection &connection, unsigned int events) {
	bool needUpdate = false;

	if (connection.stateLocal == Connection::STATE_CONNECTING) {
		bool failed = (events & (EPOLLHUP | EPOLLERR)) != 0;
		if (!failed && (events & EPOLLOUT)) {
			// Writability alone does not mean success for a non-blocking connect();
			// the outcome is reported through SO_ERROR.
			int sockError = 0;
			socklen_t len = sizeof(sockError);
			failed = 0 != ::getsockopt(connection.sockId, SOL_SOCKET, SO_ERROR, &sockError, &len)
			      || 0 != sockError;
		}

		if (failed) {
			connection.onOpeningError();
			if (connection.server) {
				(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
				(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
				connection.srvPrev = connection.srvNext = nullptr;
				connection.server = nullptr;
			}
			(connection.prev ? connection.prev->next : connFirst) = connection.next;
			(connection.next ? connection.next->prev : connLast) = connection.prev;
			connection.prev = connection.next = nullptr;

			// Release the socket (nothing else in this file closes it).
			struct epoll_event event = {};
			epoll_ctl(epollFd, EPOLL_CTL_DEL, connection.sockId, &event);
			::close(connection.sockId);
			connection.sockId = -1;

			connection.protocol = nullptr;
			memset(connection.address, 0, sizeof(connection.address));
			connection.addressSize = 0;
			connection.stateWanted = Connection::STATE_NONE;
			connection.stateLocal = Connection::STATE_NONE;
			connection.state = Connection::STATE_CLOSED;
			connection.stateProtector.wait();
			while(connection.closeWaiters > 0)
				{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
			connection.state = Connection::STATE_FINISHED;
			// The connection is finished; it must not fall through to onOpen()/data handling.
			return;
		}

		if (events & EPOLLOUT) {
			connection.onOpen(connection.address, connection.addressSize);
			connection.state = connection.stateLocal = Connection::STATE_OPEN;
			connection.stateProtector.wait();
			needUpdate = true;
		}
	}

	if ( connection.stateLocal < Connection::STATE_OPEN
	  || connection.stateLocal >= Connection::STATE_CLOSING )
		return;

	if (events & EPOLLIN) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.readQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0;
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				// recv() returns -1 and reports the reason in errno.
				ssize_t res = ::recv(
					connection.sockId,
					(unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT );
				if (res < 0 && errno == EINTR) continue;
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					ready = false;
					break;
				}
				if (res <= 0) {
					connection.close(true);
					return;
				}
				// Advance by the bytes actually received, which may be fewer than requested.
				completion += (size_t)res;
				task->completion += (size_t)res;
			}

			if (task->size <= task->completion || (completion && task->watch)) {
				connection.readQueue.pop();
				if (!connection.onReadReady(*task, false))
					connection.readQueue.unpop(*task); else ++pops;
			}
		}
	}

	if (events & (EPOLLHUP | EPOLLERR)) {
		connection.close((bool)(events & EPOLLERR));
		return;
	}

	if (events & EPOLLOUT) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.writeQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0;
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				// MSG_NOSIGNAL: a reset peer must yield EPIPE, not a process-killing SIGPIPE.
				ssize_t res = ::send(
					connection.sockId,
					(const unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT | MSG_NOSIGNAL );
				if (res < 0 && errno == EINTR) continue;
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					ready = false;
					break;
				}
				if (res <= 0) {
					connection.close(true);
					return;
				}
				// Advance by the bytes actually sent, which may be fewer than requested.
				completion += (size_t)res;
				task->completion += (size_t)res;
			}

			if (task->size <= task->completion || (completion && task->watch)) {
				connection.writeQueue.pop();
				if (!connection.onWriteReady(*task, false))
					connection.writeQueue.unpop(*task); else ++pops;
			}
		}
	}

	if (needUpdate) updateEvents(connection);
}
71057f
71057f
71057f
void Protocol::threadSrvQueue() {
71057f
	while(srvQueue.peek()) {
71057f
		Server &server = srvQueue.pop();
71057f
		server.enqueued = false;
71057f
		
71057f
		if (server.stateLocal == Server::STATE_NONE) {
71057f
			int sockId = -1;
71057f
			if (server.addressSize >= sizeof(sa_family_t)) {
71057f
				struct sockaddr *addr = (struct sockaddr*)server.address;
71057f
				sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
71057f
				if (sockId >= 0) {
71057f
					initSocket(sockId);
71057f
					if ( 0 != ::bind(sockId, addr, (int)server.addressSize)
71057f
					  || 0 != ::listen(sockId, 128) )
71057f
					{
71057f
						::close(sockId);
71057f
						sockId = -1;
71057f
					}
71057f
				}
71057f
			}
71057f
			
71057f
			if (sockId >= 0) {
71057f
				server.protocol = this;
71057f
				server.prev = srvLast;
71057f
				assert(!server.next);
71057f
				(server.prev ? server.prev->next : srvFirst) = &server;
71057f
				server.sockId = sockId;
71057f
				
71057f
				struct epoll_event event = {};
71057f
				event.data.ptr = &server;
71057f
				event.events = EPOLLIN;
71057f
				epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
71057f
				
71057f
				server.onOpen();
71057f
				server.state = server.stateLocal = Server::STATE_OPEN;
71057f
				server.stateProtector.wait();
71057f
			} else {
71057f
				server.onOpeningError();
71057f
				server.state = Server::STATE_CLOSED;
71057f
				while(server.closeWaiters > 0)
71057f
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
71057f
				server.state = Server::STATE_FINISHED;
71057f
			}
71057f
		} else {
71057f
			Server::State stateWanted = server.stateWanted;
71057f
			if ( server.stateLocal == Server::STATE_OPEN
71057f
			  && stateWanted == Server::STATE_CLOSE_REQ )
71057f
			{
71057f
				for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
71057f
					conn->closeReq();
71057f
				struct epoll_event event = {};
71057f
				epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
71057f
				server.state = server.stateLocal = Server::STATE_CLOSE_REQ;
71057f
				server.stateProtector.wait();
71057f
				server.onCloseReqested();
71057f
			} else
71057f
			if ( ( server.stateLocal == Server::STATE_OPEN
71057f
				|| server.stateLocal == Server::STATE_CLOSE_REQ )
71057f
			  && stateWanted == Server::STATE_CLOSING_CONNECTIONS )
71057f
			{
71057f
				if (server.stateLocal != Server::STATE_CLOSE_REQ) {
71057f
					struct epoll_event event = {};
71057f
					epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
71057f
				}
71057f
				if (server.connFirst) {
71057f
					for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
71057f
						conn->close(true);
71057f
					server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS;
71057f
				} else {
71057f
					server.state = server.stateLocal = Server::STATE_CLOSING;
71057f
					server.stateProtector.wait();
71057f
					server.onClose(false);
71057f
				}
71057f
			}
71057f
			
71057f
			if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) {
71057f
				server.state = Server::STATE_CLOSED;
71057f
				(server.prev ? server.prev->next : srvFirst) = server.next;
71057f
				(server.next ? server.next->prev : srvLast) = server.prev;
71057f
				server.protocol = nullptr;
71057f
				memset(server.address, 0, sizeof(server.address));
71057f
				server.addressSize = 0;
71057f
				server.stateWanted = Server::STATE_NONE;
71057f
				server.stateLocal = Server::STATE_NONE;
71057f
				while(server.closeWaiters > 0)
71057f
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
71057f
				server.state = Server::STATE_FINISHED;
71057f
			}
71057f
		}
71057f
	}
71057f
}
71057f
71057f
71057f
// Handles epoll events for a listening server socket. Hang-up or error closes the server.
// Readability (only while STATE_OPEN) accepts one pending connection and offers it to
// server.onConnect(); the returned Connection is opened on this protocol with the
// accepted socket.
void Protocol::threadSrvEvents(Server &server, unsigned int events) {
	if (events & (EPOLLHUP | EPOLLERR)) {
		if ( server.stateLocal >= Server::STATE_OPEN
		  && server.stateLocal < Server::STATE_CLOSING )
			server.close((bool)(events & EPOLLERR));
		return;
	}

	if (!(events & EPOLLIN)) return;
	if ( server.stateLocal < Server::STATE_OPEN
	  || server.stateLocal >= Server::STATE_CLOSE_REQ )
		return;

	unsigned char address[MAX_ADDR_SIZE] = {};
	socklen_t len = sizeof(address);
	int res = ::accept(server.sockId, (struct sockaddr*)address, &len);
	if (res < 0) {
		// accept() returns -1 and reports the reason in errno. These errors are transient
		// (nothing to accept now, interrupted, or the peer aborted first); epoll reports
		// the socket again when another connection is pending.
		if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR || errno == ECONNABORTED)
			return;
		server.close(true);
		return;
	}

	// len may legitimately equal the buffer size; a larger value means truncation.
	assert(len <= sizeof(address));
	if (len > sizeof(address)) len = sizeof(address);

	Connection *conn = server.onConnect(address, len);
	if (!conn) {
		// No Connection was provided, so the accepted socket would otherwise leak.
		::close(res);
	} else
	if (!conn->open(*this, address, len, res)) {
		// NOTE(review): assumes a failed Connection::open() does not keep the socket — confirm.
		::close(res);
		server.onDisconnect(conn, true);
	}
}
71057f
71057f
71057f
void Protocol::threadRun() {
71057f
	epollFd = ::epoll_create(32);
71057f
	eventFd = ::eventfd(0, EFD_NONBLOCK);
71057f
	
71057f
	struct epoll_event event = {};
71057f
	event.events = EPOLLIN;
71057f
	epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event);
71057f
	
71057f
	const int maxCount = 16;
71057f
	struct epoll_event events[maxCount] = {};
71057f
	int count = 0;
71057f
	
71057f
	State stateLocal = STATE_OPEN;
71057f
	state = stateLocal;
71057f
	
71057f
	while(true) {
71057f
		for(int i = 0; i < count; ++i) {
71057f
			Root *pointer = (Root*)events[i].data.ptr;
71057f
			if (!pointer) {
71057f
				unsigned long long num = 0;
71057f
				while(::read(eventFd, &num, sizeof(num)) > 0 && !num);
71057f
			} else
71057f
			if (Connection *connection = dynamic_cast<connection*>(pointer)) {</connection*>
71057f
				threadConnEvents(*connection, events[i].events);
71057f
			} else
71057f
			if (Server *server = dynamic_cast<server*>(pointer)) {</server*>
71057f
				threadSrvEvents(*server, events[i].events);
71057f
			}
71057f
		}
71057f
		
71057f
		State sw = stateWanted;
71057f
		if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) {
71057f
			state = stateLocal = STATE_CLOSE_REQ;
71057f
			stateProtector.wait();
71057f
			for(Server *server = srvFirst; server; server = server->next)
71057f
				server->closeReq();
71057f
			for(Connection *conn = connFirst; conn; conn = conn->next)
71057f
				if (!conn->server)
71057f
					conn->closeReq();
71057f
		} else
71057f
		if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) {
71057f
			state = STATE_CLOSING;
71057f
			stateProtector.wait();
71057f
			for(Server *server = srvFirst; server; server = server->next)
71057f
				server->close(true);
71057f
			for(Connection *conn = connFirst; conn; conn = conn->next)
71057f
				if (!conn->server)
71057f
					conn->close(true);
71057f
		}
71057f
		
71057f
		while(connQueue.peek() || srvQueue.peek()) {
71057f
			threadConnQueue();
71057f
			threadSrvQueue();
71057f
		}
71057f
		
71057f
		if (stateLocal == STATE_CLOSING) {
71057f
			assert(!srvFirst && !connFirst);
71057f
			break;
71057f
		}
71057f
		
71057f
		count = epoll_wait(epollFd, events, maxCount, 10000);
71057f
	}
71057f
	
71057f
	while(closeWaiters > 0)
71057f
		{ closeCondition.notify_all(); std::this_thread::yield(); }
71057f
	
71057f
	::close(epollFd);
71057f
	::close(eventFd);
71057f
	
71057f
	state = STATE_FINISHED;
71057f
}
71057f