Blame socket.cpp

b42f0a
b42f0a
#include <cstring>
b42f0a
b42f0a
#include <fcntl.h>
b42f0a
#include <unistd.h>
b42f0a
#include <sys/types.h>
b42f0a
#include <sys/socket.h>
b42f0a
#include <sys/epoll.h>
b42f0a
#include <netinet/in.h>
b42f0a
#include <netinet/tcp.h>
b42f0a
b42f0a
#include "socket.h"
b42f0a
#include "protocol.h"
b42f0a
b42f0a
b42f0a
#define DebugMSG ShowDebugMSG
b42f0a
b42f0a
b42f0a
b42f0a
/**
 * Builds an unopened socket.
 *
 * Both the current and the wanted state start at STATE_NONE, the object is
 * not marked as enqueued, both waiter counters are zero and there is no
 * descriptor yet (sockId == -1). Everything else is value-initialized.
 */
Socket::Socket()
	: state(STATE_NONE)
	, stateWanted(STATE_NONE)
	, enqueued(false)
	, closeWaiters(0)
	, finishWaiters(0)
	, address()
	, addressSize()
	, protocol()
	, prev()
	, next()
	, queueNext()
	, sockId(-1)
	, stateLocal()
	, events(0xffffffff) // all-ones sentinel: updateEvents() uses it to pick EPOLL_CTL_ADD over EPOLL_CTL_MOD
{
}
b42f0a
b42f0a
b42f0a
/**
 * Destructor. Only checks (in debug builds) that the socket has already
 * reached STATE_DESTROYED, which is the state destroy() leaves it in.
 */
Socket::~Socket() {
	assert(state == STATE_DESTROYED);
}
b42f0a
b42f0a
b42f0a
void Socket::destroy() {
b42f0a
	assert(state == STATE_NONE || state == STATE_FINISHED);
b42f0a
	++finishWaiters;
b42f0a
	closeWait();
b42f0a
	state = STATE_DESTROYED;
b42f0a
	--finishWaiters;
b42f0a
	while(finishWaiters > 0)
b42f0a
		std::this_thread::yield();
b42f0a
}
b42f0a
b42f0a
b42f0a
/**
 * First half of opening a socket on the given protocol.
 *
 * @param protocol     protocol that will own the socket
 * @param address      raw address bytes, copied into this->address
 * @param addressSize  number of bytes at @p address (1..MAX_ADDR_SIZE)
 * @return true when the socket was claimed (state is STATE_OPENING); in that
 *         case protocol.stateProtector is still LOCKED on return and is
 *         released by openEnd(). Returns false, with nothing left locked,
 *         when the arguments are invalid, the socket is not idle, it has
 *         been destroyed, or the protocol is outside the
 *         [STATE_INITIALIZING, STATE_CLOSE_REQ) range.
 */
bool Socket::openBegin(Protocol &protocol, void *address, size_t addressSize) {
	const bool addressUsable = address && addressSize && addressSize <= MAX_ADDR_SIZE;
	if (!addressUsable)
		return false;

	// Atomically claim the socket: only an idle (NONE/FINISHED) socket may
	// be switched to STATE_OPENING.
	State previous = state;
	for (;;) {
		if (!(previous == STATE_NONE || previous == STATE_FINISHED))
			return false;
		if (state.compare_exchange_weak(previous, STATE_OPENING))
			break;
	}

	// A finished socket may still have threads leaving closeWait()/destroy().
	if (previous == STATE_FINISHED) {
		while (finishWaiters > 0) {
			std::this_thread::yield();
		}
	}
	if (state == STATE_DESTROYED)
		return false;

	protocol.stateProtector.lock();
	const State protocolState = protocol.state;
	const bool protocolUsable =
		protocolState >= STATE_INITIALIZING && protocolState < STATE_CLOSE_REQ;
	if (!protocolUsable) {
		// Give the claim back before reporting failure.
		state = STATE_NONE;
		protocol.stateProtector.unlock();
		return false;
	}

	this->protocol = &protocol;
	memcpy(this->address, address, addressSize);
	this->addressSize = addressSize;

	// Intentionally returns with protocol.stateProtector held; see openEnd().
	return true;
}
b42f0a
b42f0a
b42f0a
void Socket::openEnd() {
b42f0a
	state = STATE_INITIALIZING;
b42f0a
	enqueue();
b42f0a
	protocol->stateProtector.unlock();
b42f0a
}
b42f0a
b42f0a
b42f0a
/**
 * Applies the standard options to a freshly created descriptor:
 * TCP_NODELAY is enabled and the descriptor is switched to non-blocking mode.
 *
 * Both steps are best-effort; failures are not reported to the caller.
 *
 * @param sockId  descriptor to configure (the parameter, not the member)
 */
void Socket::initSocket(int sockId) {
	const int tcpNoDelay = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay));

	// F_GETFL returns -1 on failure; OR-ing that with O_NONBLOCK would ask
	// F_SETFL to set every flag bit, so only update the flags when the
	// current value was actually read.
	const int flags = fcntl(sockId, F_GETFL, 0);
	if (flags != -1)
		fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
}
b42f0a
b42f0a
b42f0a
void Socket::enqueue() {
b42f0a
	if (!enqueued.exchange(true)) {
b42f0a
		protocol->sockQueue.push(*this);
b42f0a
		protocol->wakeup();
b42f0a
	}
b42f0a
}
b42f0a
b42f0a
b42f0a
void Socket::updateEvents(bool wantRead, bool wantWrite) {
b42f0a
	unsigned int e = EPOLLERR | EPOLLRDHUP;
b42f0a
	if (wantRead) e |= EPOLLIN;
b42f0a
	if (wantWrite) e |= EPOLLOUT;
b42f0a
b42f0a
	DebugMSG("events: %08x", e);
b42f0a
	unsigned int pe = events.exchange(e);
b42f0a
	if (pe == e)
b42f0a
		return;
b42f0a
	DebugMSG("apply");
b42f0a
	
b42f0a
	struct epoll_event event = {};
b42f0a
	event.data.ptr = this;
b42f0a
	event.events = e;
b42f0a
	epoll_ctl(protocol->epollFd, pe == 0xffffffff ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, sockId, &event);
b42f0a
}
b42f0a
b42f0a
b42f0a
/**
 * Requests that the socket advance to @p state.
 *
 * The request is recorded only when the socket is past STATE_OPENING, has
 * not yet reached STATE_CLOSING, is below the requested state, and no equal
 * or higher state has already been requested. On success the socket is
 * enqueued for processing.
 *
 * @param state  target state (compared by enum order)
 * @param error  when true and the request is recorded, raises this->error
 *               if it was not raised already
 */
void Socket::wantState(State state, bool error) {
	ReadLock lock(stateProtector);

	const State current = this->state;
	DebugMSG("%d, %d", current, state);

	const bool inActiveRange = current > STATE_OPENING && current < STATE_CLOSING;
	if (!inActiveRange || current >= state)
		return;
	DebugMSG("apply");

	// Raise stateWanted to `state`, never lowering a stronger request.
	State wanted = stateWanted;
	for (;;) {
		if (wanted >= state)
			return;
		if (stateWanted.compare_exchange_weak(wanted, state))
			break;
	}

	if (error) {
		DebugMSG("error");
		bool expected = false;
		this->error.compare_exchange_strong(expected, true);
	}

	enqueue();
}
b42f0a
b42f0a
b42f0a
void Socket::closeReq()
b42f0a
	{ wantState(STATE_CLOSE_REQ, false); }
b42f0a
void Socket::close(bool error)
b42f0a
	{ wantState(STATE_CLOSING, error); }
b42f0a
b42f0a
b42f0a
void Socket::closeWait(unsigned long long timeoutUs, bool withReq, bool error) {
b42f0a
	stateProtector.lock();
b42f0a
	State s = state;
b42f0a
	if (s <= STATE_OPENING || s >= STATE_FINISHED)
b42f0a
		{ stateProtector.unlock(); return; }
b42f0a
	
b42f0a
	if (s == STATE_CLOSED) {
b42f0a
		finishWaiters++;
b42f0a
		stateProtector.unlock();
b42f0a
	} else {
b42f0a
		if (withReq && s < STATE_CLOSE_REQ) closeReq();
b42f0a
		closeWaiters++;
b42f0a
		stateProtector.unlock();
b42f0a
		
b42f0a
		std::mutex fakeMutex;
b42f0a
		{
b42f0a
			std::unique_lock<std::mutex> fakeLock(fakeMutex);</std::mutex>
b42f0a
			if (timeoutUs && s < STATE_CLOSING) {
b42f0a
				std::chrono::steady_clock::time_point t
b42f0a
					= std::chrono::steady_clock::now()
b42f0a
						+ std::chrono::microseconds(timeoutUs);
b42f0a
				while(s < STATE_CLOSED) {
b42f0a
					if (t <= std::chrono::steady_clock::now()) { close(error); break; }
b42f0a
					closeCondition.wait_until(fakeLock, t);
b42f0a
					s = state;
b42f0a
				}
b42f0a
			}
b42f0a
			if (s < STATE_CLOSING) close(error);
b42f0a
			while(s < STATE_CLOSED) {
b42f0a
				closeCondition.wait(fakeLock);
b42f0a
				s = state;
b42f0a
			}
b42f0a
		}
b42f0a
		
b42f0a
		finishWaiters++;
b42f0a
		closeWaiters--;
b42f0a
	}
b42f0a
	
b42f0a
	
b42f0a
	while(s < STATE_FINISHED) { std::this_thread::yield(); s = state; }
b42f0a
	
b42f0a
	finishWaiters--;
b42f0a
}
b42f0a
b42f0a
b42f0a
// Default event hooks. All of them are intentionally empty here.
// NOTE(review): presumably these are virtual and overridden by concrete
// socket classes - confirm against the declarations in socket.h.

/** Hook for a failed open attempt; does nothing by default. */
void Socket::onOpeningError() {
}

/** Hook for a successful open; both arguments are ignored by default. */
void Socket::onOpen(const void*, size_t) {
}

/** Hook for a close request; does nothing by default. */
void Socket::onCloseReqested() {
}

/** Hook for a completed close; the flag is ignored by default. */
void Socket::onClose(bool) {
}
b42f0a
b42f0a