#include <climits>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include "protocol.h"
#define DebugMSG HideDebugMSG
// Build an idle protocol object: no worker thread yet, both descriptors
// marked invalid (-1), both state fields at STATE_NONE, no registered
// waiters, and value-initialized ends of the socket list.
Protocol::Protocol():
    thread(),
    epollFd(-1),
    eventFd(-1),
    state(STATE_NONE),
    stateWanted(STATE_NONE),
    closeWaiters(0),
    finishWaiters(0),
    sockFirst(),
    sockLast()
{ }
// Tear the protocol down.
//
// The object is expected to be idle (STATE_NONE) or completely finished
// (STATE_FINISHED) at this point; this is asserted. closeWait() is still
// called, the destructor then spins until no closeWait() caller is left
// (finishWaiters drops to zero), and finally the worker thread is joined and
// its std::thread object released.
Protocol::~Protocol() {
    assert(state == STATE_NONE || state == STATE_FINISHED);
    closeWait();
    while(finishWaiters > 0)
        std::this_thread::yield();
    DebugMSG();
    if (thread) {
        thread->join();
        // open() allocates the std::thread with `new`; joining alone does not
        // free it, so release it here to avoid leaking the object.
        delete thread;
        thread = nullptr;
    }
    DebugMSG("deleted");
}
bool Protocol::open() {
DebugMSG();
State s = state;
do {
if (s != STATE_NONE && s != STATE_FINISHED) return false;
} while(!state.compare_exchange_weak(s, STATE_OPENING));
if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
if (thread) thread->join();
thread = new std::thread(&Protocol::threadRun, this);
state = STATE_INITIALIZING;
DebugMSG();
return true;
}
// Ask the worker thread to advance to the given state.
//
// The request is dropped when the current state is STATE_OPENING or lower,
// STATE_CLOSING or higher, or already at/above the requested state.
// stateWanted is only ever raised: if it is already at or above the requested
// value nothing happens. When the stored value was actually changed, the
// worker is poked through wakeup().
void Protocol::wantState(State state) {
    ReadLock lock(stateProtector);

    State current = this->state;
    DebugMSG("%d, %d", current, state);

    const bool outsideActiveRange =
        current <= STATE_OPENING || current >= STATE_CLOSING;
    if (outsideActiveRange || current >= state)
        return;

    DebugMSG("apply");

    // Raise stateWanted with a CAS loop; on failure `wanted` is refreshed
    // with the value currently stored and the bound is checked again.
    State wanted = stateWanted;
    while (wanted < state) {
        if (stateWanted.compare_exchange_weak(wanted, state)) {
            wakeup();
            return;
        }
    }
}
// Request the STATE_CLOSE_REQ state through wantState().
void Protocol::closeReq() {
    DebugMSG();
    wantState(STATE_CLOSE_REQ);
}
// Request the STATE_CLOSING state through wantState().
void Protocol::close() {
    DebugMSG();
    wantState(STATE_CLOSING);
}
// Block until the protocol has been closed and its worker thread has
// published STATE_FINISHED.
//
// timeoutUs - time in microseconds to wait for STATE_CLOSED before close()
//             is called; when zero (or when the state snapshot is already
//             STATE_CLOSING or higher) no timed wait is performed and, if the
//             snapshot is still below STATE_CLOSING, close() is called
//             right away.
// withReq   - when true and the state snapshot is below STATE_CLOSE_REQ,
//             closeReq() is issued first.
//
// Returns immediately when the state is STATE_OPENING or lower, or
// STATE_FINISHED or higher.
void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) {
DebugMSG();
// NOTE(review): closeReq() below goes through wantState(), which takes a
// ReadLock on the same stateProtector while this lock() is still held -
// presumably lock() is a shared/reentrant acquisition; confirm against the
// declaration of stateProtector.
stateProtector.lock();
State s = state;
if (s <= STATE_OPENING || s >= STATE_FINISHED)
{ stateProtector.unlock(); return; }
if (s == STATE_CLOSED) {
// Already closed: only the final STATE_FINISHED still has to be awaited.
finishWaiters++;
stateProtector.unlock();
} else {
if (withReq && s < STATE_CLOSE_REQ) closeReq();
// Register as a close waiter before releasing the protector; the worker
// thread keeps calling notify_all() while closeWaiters is above zero.
closeWaiters++;
stateProtector.unlock();
// The mutex is local to this call, so it does not provide mutual exclusion
// with the notifying thread; it exists only because the condition wait
// calls take a lock argument. Wake-ups are made reliable by the worker's
// repeated notify_all() loop and by re-reading `state` after every wait.
std::mutex fakeMutex;
{
std::unique_lock<std::mutex> fakeLock(fakeMutex);
// Timed phase: `s` here is still the snapshot taken under the protector.
if (timeoutUs && s < STATE_CLOSING) {
std::chrono::steady_clock::time_point t
= std::chrono::steady_clock::now()
+ std::chrono::microseconds(timeoutUs);
while(s < STATE_CLOSED) {
// Deadline passed: request STATE_CLOSING and fall through to the untimed wait.
if (t <= std::chrono::steady_clock::now()) { close(); break; }
DebugMSG("wait_until, %d", s);
closeCondition.wait_until(fakeLock, t);
s = state;
}
}
// May repeat the close() issued on timeout above (s is unchanged in that
// case); wantState() ignores a request that does not raise stateWanted.
if (s < STATE_CLOSING) close();
// Untimed phase: wait until the worker publishes STATE_CLOSED (or later).
while(s < STATE_CLOSED) {
DebugMSG("wait, %d", s);
closeCondition.wait(fakeLock);
s = state;
}
DebugMSG("awaiten");
}
// Become a finish waiter before ceasing to be a close waiter, so this
// caller is counted in at least one of the two counters at all times.
finishWaiters++;
closeWaiters--;
}
// Spin until the worker thread has published STATE_FINISHED.
while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
finishWaiters--;
DebugMSG();
}
// Wake the worker thread out of epoll_wait() by making eventFd readable.
//
// A write to an eventfd adds the written 64-bit value to its counter, and the
// descriptor is readable only while the counter is non-zero. The value must
// therefore be non-zero: adding 0 leaves the counter unchanged and does not
// wake the poller.
void Protocol::wakeup() {
    const unsigned long long increment = 1;
    // Best effort: a failed write (for example on an invalid descriptor, or
    // EAGAIN on the non-blocking eventfd when the counter is saturated) is
    // deliberately ignored.
    const ssize_t written = ::write(eventFd, &increment, sizeof(increment));
    (void)written;
}
// Body of the worker thread started by open().
//
// Creates the epoll instance and the eventfd, publishes STATE_OPEN, then
// loops: dispatch the events returned by the previous epoll_wait(), apply a
// pending stateWanted transition, drain sockQueue, and wait for new events.
// The loop ends once the local state is STATE_CLOSING; the function then
// publishes STATE_CLOSED, notifies close waiters, releases the descriptors
// and finally publishes STATE_FINISHED.
void Protocol::threadRun() {
DebugMSG();
epollFd = ::epoll_create(32);
eventFd = ::eventfd(0, EFD_NONBLOCK);
assert(epollFd >= 0);
assert(eventFd >= 0);
// Register the eventfd with zero-initialized user data: a null data.ptr is
// what distinguishes eventfd entries from socket entries in the loop below.
struct epoll_event event = {};
event.events = EPOLLIN;
epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event);
const int maxCount = 16;
struct epoll_event events[maxCount] = {};
// Starts at zero so the first iteration dispatches nothing; a negative
// result from epoll_wait() likewise skips the dispatch loop.
int count = 0;
State stateLocal = STATE_OPEN;
state = stateLocal;
while(true) {
// 1. Dispatch the events collected by the previous epoll_wait().
for(int i = 0; i < count; ++i) {
if (Socket *socket = (Socket*)events[i].data.ptr) {
socket->handleEvents(events[i].events);
} else {
// eventfd entry: keep reading while a read succeeds and yields zero.
unsigned long long num = 0;
while(::read(eventFd, &num, sizeof(num)) > 0 && !num);
}
}
// 2. Apply a requested state transition, if any.
State sw = stateWanted;
if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) {
DebugMSG("to closse req");
state = stateLocal = STATE_CLOSE_REQ;
// NOTE(review): presumably waits until current holders of stateProtector
// (e.g. wantState()/closeWait()) have left - confirm against its declaration.
stateProtector.wait();
for(Socket *socket = sockFirst; socket; socket = socket->next)
socket->closeReq();
} else
if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) {
DebugMSG("to closing");
state = stateLocal = STATE_CLOSING;
stateProtector.wait();
for(Socket *socket = sockFirst; socket; socket = socket->next)
socket->close(true);
}
// 3. Let every queued socket process its own state change.
while(sockQueue.peek()) {
Socket &socket = sockQueue.pop();
socket.enqueued = false;
socket.handleState();
}
// 4. Leave the loop once closing has been applied; the socket list is
// expected to be empty by then (asserted below).
if (stateLocal == STATE_CLOSING)
break;
// 5. Wait for socket events or an eventfd wake-up, for at most 1000 ms.
count = epoll_wait(epollFd, events, maxCount, 1000);
DebugMSG("epoll count: %d", count);
}
assert(!sockFirst);
DebugMSG("closeWaiters");
state = STATE_CLOSED;
// Keep notifying until every closeWait() caller has observed STATE_CLOSED
// and decremented closeWaiters.
while(closeWaiters > 0)
{ closeCondition.notify_all(); std::this_thread::yield(); }
epoll_ctl(epollFd, EPOLL_CTL_DEL, eventFd, &event);
::close(epollFd);
::close(eventFd);
epollFd = -1;
eventFd = -1;
assert(!sockFirst);
assert(!sockQueue.peek());
stateWanted = STATE_NONE;
stateLocal = STATE_NONE;
// Published last: closeWait() callers spin on this value, and open()
// accepts STATE_FINISHED as a restartable state.
state = STATE_FINISHED;
DebugMSG("finished");
}