#include <cerrno>
#include <climits>
#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include "protocol.h"
#define DebugMSG HideDebugMSG
#define SrvDebugMSG ShowDebugMSG
Protocol::Protocol():
thread(),
epollFd(-1),
eventFd(-1),
state(STATE_NONE),
stateWanted(STATE_NONE),
closeWaiters(0),
finishWaiters(0),
connFirst(), connLast(),
srvFirst(), srvLast()
{ }
// Destroys the protocol: waits for any running worker to finish, then joins
// and releases the worker thread object.
Protocol::~Protocol() {
	assert(state == STATE_NONE || state == STATE_FINISHED);
	closeWait();
	// Other threads may still be inside closeWait() observing STATE_FINISHED.
	while(finishWaiters > 0) std::this_thread::yield();
	DebugMSG();
	if (thread) {
		if (thread->joinable()) thread->join();
		// The thread object was allocated with `new` in open(); release it.
		delete thread;
		thread = nullptr;
	}
	DebugMSG("deleted");
}
bool Protocol::open() {
DebugMSG();
State s = state;
do {
if (s != STATE_NONE && s != STATE_FINISHED) return false;
} while(!state.compare_exchange_weak(s, STATE_OPENING));
if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
if (thread) thread->join();
thread = new std::thread(&Protocol::threadRun, this);
state = STATE_INITIALIZING;
DebugMSG();
return true;
}
// Requests that the worker thread advance to `state`.
// The request is ignored unless the protocol is currently running (strictly
// between OPENING and CLOSING) and `state` is ahead of both the current state
// and any previously requested state. Requests only ever move forward. When
// the request is recorded, the worker is woken.
void Protocol::wantState(State state) {
	ReadLock lock(stateProtector);
	const State current = this->state;
	DebugMSG("%d, %d", current, state);
	const bool running = current > STATE_OPENING && current < STATE_CLOSING;
	if (!running || current >= state) return;
	DebugMSG("apply");
	State requested = stateWanted;
	while (requested < state) {
		if (stateWanted.compare_exchange_weak(requested, state)) {
			wakeup();
			return;
		}
	}
}
// Asks the worker to begin a graceful close (STATE_CLOSE_REQ).
void Protocol::closeReq() {
	DebugMSG();
	wantState(STATE_CLOSE_REQ);
}
// Asks the worker to close immediately (STATE_CLOSING).
void Protocol::close() {
	DebugMSG();
	wantState(STATE_CLOSING);
}
// Blocks until the protocol's worker thread has reached STATE_FINISHED.
// Returns immediately if the protocol is not running (state <= OPENING) or is
// already FINISHED.
//   withReq   - if true and no close was requested yet, request a graceful
//               close (closeReq) first.
//   timeoutUs - how long (in microseconds) to wait for a graceful close to
//               reach STATE_CLOSED before forcing close(). 0 forces close()
//               immediately.
void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) {
	DebugMSG();
	stateProtector.lock();
	State s = state;
	if (s <= STATE_OPENING || s >= STATE_FINISHED)
		{ stateProtector.unlock(); return; }
	if (s == STATE_CLOSED) {
		finishWaiters++;
		stateProtector.unlock();
	} else {
		closeWaiters++;
		stateProtector.unlock();
		// closeReq() goes through wantState(), which takes a ReadLock on the same
		// stateProtector. Calling it while holding the lock above risks
		// self-deadlock, so it is issued after unlocking. wantState() re-checks
		// the state itself.
		if (withReq && s < STATE_CLOSE_REQ) closeReq();
		// The mutex only exists to satisfy the condition-variable API. The worker
		// keeps calling notify_all() while closeWaiters > 0, so a missed wakeup
		// is retried.
		std::mutex fakeMutex;
		{
			std::unique_lock<std::mutex> fakeLock(fakeMutex);
			if (timeoutUs && s < STATE_CLOSING) {
				std::chrono::steady_clock::time_point t
					= std::chrono::steady_clock::now()
					+ std::chrono::microseconds(timeoutUs);
				while(s < STATE_CLOSED) {
					if (t <= std::chrono::steady_clock::now()) { close(); break; }
					DebugMSG("wait_until, %d", s);
					closeCondition.wait_until(fakeLock, t);
					s = state;
				}
			}
			if (s < STATE_CLOSING) close();
			while(s < STATE_CLOSED) {
				DebugMSG("wait, %d", s);
				closeCondition.wait(fakeLock);
				s = state;
			}
			DebugMSG("awaiten");
		}
		finishWaiters++;
		closeWaiters--;
	}
	while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
	finishWaiters--;
	DebugMSG();
}
// Wakes the worker thread out of epoll_wait() by signalling the eventfd.
// write() on an eventfd adds the value to its counter, and the descriptor is
// readable (and wakes epoll) only while the counter is non-zero. Writing 0,
// as before, was a no-op and left the worker sleeping until the 1 s timeout.
void Protocol::wakeup() {
	unsigned long long num = 1;
	::write(eventFd, &num, sizeof(num));
}
// Re-arms the connection's epoll interest mask. It asks for EPOLLIN while
// read tasks are queued and EPOLLOUT while write tasks are queued. The
// cached mask in connection.events avoids a redundant epoll_ctl when
// nothing changed.
void Protocol::updateEvents(Connection &connection) {
	const unsigned int readBit  = connection.readQueue.count()  ? (unsigned int)EPOLLIN  : 0u;
	const unsigned int writeBit = connection.writeQueue.count() ? (unsigned int)EPOLLOUT : 0u;
	const unsigned int wanted = readBit | writeBit;
	if (connection.events.exchange(wanted) != wanted) {
		struct epoll_event change = {};
		change.events = wanted;
		change.data.ptr = &connection;
		epoll_ctl(epollFd, EPOLL_CTL_MOD, connection.sockId, &change);
	}
}
// Prepares a socket for the event loop: disables Nagle's algorithm
// (TCP_NODELAY) and switches the descriptor to non-blocking mode, keeping
// its other status flags. Failures are ignored.
void Protocol::initSocket(int sockId) {
	const int enable = 1;
	setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));
	const int flags = fcntl(sockId, F_GETFL, 0);
	fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
}
void Protocol::threadConnQueue() {
while(connQueue.peek()) {
Connection &connection = connQueue.pop();
connection.enqueued = false;
if (connection.stateLocal == Connection::STATE_NONE) {
int sockId = connection.sockId;
bool connected = false;
if (sockId >= 0) {
initSocket(sockId);
} else
if (connection.addressSize >= sizeof(sa_family_t)) {
struct sockaddr *addr = (struct sockaddr*)connection.address;
sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
if (sockId >= 0) {
initSocket(sockId);
int res = ::connect(sockId, addr, (int)connection.addressSize);
if (res == 0) {
connected = true;
} else
if (res != EINPROGRESS) {
::close(sockId);
sockId = -1;
}
}
}
if (sockId >= 0) {
connection.protocol = this;
connection.prev = connLast;
assert(!connection.next);
(connection.prev ? connection.prev->next : connFirst) = &connection;
if (connection.server) {
connection.srvPrev = connection.server->connLast;
assert(!connection.srvNext);
(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = &connection;
}
connection.sockId = sockId;
struct epoll_event event = {};
event.data.ptr = &connection;
if (connected) {
epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
connection.onOpen(connection.address, connection.addressSize);
connection.state = connection.stateLocal = Connection::STATE_OPEN;
connection.stateProtector.wait();
} else {
event.events = EPOLLOUT;
epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
connection.state = connection.stateLocal = Connection::STATE_CONNECTING;
connection.stateProtector.wait();
}
} else {
connection.sockId = -1;
connection.server = nullptr;
connection.onOpeningError();
connection.state = Connection::STATE_CLOSED;
while(connection.closeWaiters > 0)
{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
connection.state = Connection::STATE_FINISHED;
}
} else {
Connection::State stateWanted = connection.stateWanted;
if ( connection.stateLocal >= Connection::STATE_CONNECTING
&& connection.stateLocal <= Connection::STATE_OPEN
&& stateWanted == Connection::STATE_CLOSE_REQ )
{
connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ;
connection.stateProtector.wait();
connection.onCloseReqested();
} else
if ( connection.stateLocal >= Connection::STATE_CONNECTING
&& connection.stateLocal <= Connection::STATE_CLOSE_REQ
&& stateWanted == Connection::STATE_CLOSING )
{
connection.state = connection.stateLocal = Connection::STATE_CLOSING;
connection.stateProtector.wait();
while(Connection::Task *task = connection.readQueue.peek()) {
connection.readQueue.pop();
if (!connection.onReadReady(*task, true))
assert(false);
}
while(Connection::Task *task = connection.writeQueue.peek()) {
connection.writeQueue.pop();
if (!connection.onWriteReady(*task, true))
assert(false);
}
connection.onClose(false);
}
if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) {
if (connection.server) {
(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
connection.server = nullptr;
}
(connection.prev ? connection.prev->next : connFirst) = connection.next;
(connection.next ? connection.next->prev : connLast) = connection.prev;
connection.sockId = -1;
connection.protocol = nullptr;
memset(connection.address, 0, sizeof(connection.address));
connection.addressSize = 0;
connection.stateWanted = Connection::STATE_NONE;
connection.stateLocal = Connection::STATE_NONE;
connection.state = Connection::STATE_CLOSED;
while(connection.closeWaiters > 0)
{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
connection.state = Connection::STATE_FINISHED;
}
}
}
}
// Worker-thread handler for epoll events on a connection socket.
//  * CONNECTING: EPOLLHUP/EPOLLERR, or a non-zero SO_ERROR on EPOLLOUT, means
//    the connect failed. The connection is then unlinked and finished, and the
//    function returns. A clean EPOLLOUT moves it to OPEN.
//  * OPEN/CLOSE_REQ: fills queued read tasks from recv() and drains queued
//    write tasks via send() until the socket would block. Completed tasks, or
//    partially progressed "watch" tasks, go to onReadReady/onWriteReady. Any
//    hard error or HUP/ERR requests connection.close().
// The epoll mask is refreshed via updateEvents() when queue contents changed.
void Protocol::threadConnEvents(Connection &connection, unsigned int events) {
	bool needUpdate = false;
	if (connection.stateLocal == Connection::STATE_CONNECTING) {
		bool failed = (events & (EPOLLHUP | EPOLLERR)) != 0;
		if (!failed && (events & EPOLLOUT)) {
			// Writability only signals that the non-blocking connect finished. Its
			// result must be read from SO_ERROR (see connect(2)).
			int error = 0;
			socklen_t errorSize = sizeof(error);
			if ( 0 != ::getsockopt(connection.sockId, SOL_SOCKET, SO_ERROR, &error, &errorSize)
			  || 0 != error )
				failed = true;
		}
		if (failed) {
			connection.onOpeningError();
			if (connection.server) {
				(connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext;
				(connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev;
				connection.server = nullptr;
			}
			(connection.prev ? connection.prev->next : connFirst) = connection.next;
			(connection.next ? connection.next->prev : connLast) = connection.prev;
			// NOTE(review): the socket fd is not closed/removed from epoll here.
			// Confirm onOpeningError()/Connection releases it.
			connection.sockId = -1;
			connection.protocol = nullptr;
			memset(connection.address, 0, sizeof(connection.address));
			connection.addressSize = 0;
			connection.stateWanted = Connection::STATE_NONE;
			connection.stateLocal = Connection::STATE_NONE;
			connection.state = Connection::STATE_CLOSED;
			while(connection.closeWaiters > 0)
				{ connection.closeCondition.notify_all(); std::this_thread::yield(); }
			connection.state = Connection::STATE_FINISHED;
			// The connection is finished. Do NOT fall through into the EPOLLOUT
			// branch, which would call onOpen() on it.
			return;
		}
		if (events & EPOLLOUT) {
			connection.onOpen(connection.address, connection.addressSize);
			connection.state = connection.stateLocal = Connection::STATE_OPEN;
			connection.stateProtector.wait();
			needUpdate = true;
		}
	}
	if ( connection.stateLocal < Connection::STATE_OPEN
	  || connection.stateLocal >= Connection::STATE_CLOSING )
		return;
	if (events & EPOLLIN) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.readQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0;
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				ssize_t res = ::recv(
					connection.sockId,
					(unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT );
				// recv() returns -1 and sets errno on failure. It never returns EAGAIN
				// as a value.
				if (res < 0 && errno == EINTR) continue;
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					ready = false;
					break;
				}
				if (res <= 0) {
					connection.close(true);
					return;
				}
				// Advance by the bytes actually received, which may be fewer than requested.
				completion += (size_t)res;
				task->completion += (size_t)res;
			}
			if (task->size <= task->completion || (completion && task->watch)) {
				connection.readQueue.pop();
				if (!connection.onReadReady(*task, false))
					connection.readQueue.unpop(*task); else ++pops;
			}
		}
	}
	if (events & (EPOLLHUP | EPOLLERR)) {
		// EPOLLERR is always reported by epoll. Left unhandled, it would make the
		// loop spin.
		connection.close((bool)(events & EPOLLERR));
		return;
	}
	if (events & EPOLLOUT) {
		bool ready = true;
		int pops = 0;
		while(ready) {
			Connection::Task *task = connection.writeQueue.peek();
			if (!task) {
				if (pops) needUpdate = true;
				break;
			}
			size_t completion = 0;
			while(task->size > task->completion) {
				size_t size = task->size - task->completion;
				if (size > INT_MAX) size = INT_MAX;
				// MSG_NOSIGNAL: report EPIPE instead of raising SIGPIPE (which would
				// terminate the process by default) when the peer has gone away.
				ssize_t res = ::send(
					connection.sockId,
					(const unsigned char*)task->data + task->completion,
					size, MSG_DONTWAIT | MSG_NOSIGNAL );
				if (res < 0 && errno == EINTR) continue;
				if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
					ready = false;
					break;
				}
				if (res <= 0) {
					connection.close(true);
					return;
				}
				// Advance by the bytes actually sent, which may be fewer than requested.
				completion += (size_t)res;
				task->completion += (size_t)res;
			}
			if (task->size <= task->completion || (completion && task->watch)) {
				connection.writeQueue.pop();
				if (!connection.onWriteReady(*task, false))
					connection.writeQueue.unpop(*task); else ++pops;
			}
		}
	}
	if (needUpdate) updateEvents(connection);
}
// Worker-thread handler for the server queue. Drains srvQueue and, per server:
//  * stateLocal == NONE: creates, binds and listens on a socket for `address`,
//    links the server into this protocol's list, registers it with epoll for
//    EPOLLIN (accept) and moves it to OPEN. On failure it reports
//    onOpeningError() and finishes the server.
//  * otherwise: applies CLOSE_REQ (stops accepting, asks connections to close
//    gracefully) or CLOSING_CONNECTIONS (force-closes connections; goes to
//    CLOSING once none remain). When CLOSING and no longer enqueued, it
//    unlinks and resets the server and releases close waiters.
void Protocol::threadSrvQueue() {
	while(srvQueue.peek()) {
		SrvDebugMSG();
		Server &server = srvQueue.pop();
		server.enqueued = false;
		if (server.stateLocal == Server::STATE_NONE) {
			SrvDebugMSG("none");
			int sockId = -1;
			if (server.addressSize >= sizeof(sa_family_t)) {
				struct sockaddr *addr = (struct sockaddr*)server.address;
				sockId = ::socket(addr->sa_family, SOCK_STREAM, 0);
				if (sockId >= 0) {
					initSocket(sockId);
					if ( 0 != ::bind(sockId, addr, (socklen_t)server.addressSize)
					  || 0 != ::listen(sockId, 128) )
					{
						SrvDebugMSG("cannot bind/listen");
						::close(sockId);
						sockId = -1;
					}
				} else {
					SrvDebugMSG("cannot create socket");
				}
			}
			if (sockId >= 0) {
				SrvDebugMSG("to open");
				server.protocol = this;
				// Append to the protocol's server list (tail must be updated too).
				server.prev = srvLast;
				assert(!server.next);
				(server.prev ? server.prev->next : srvFirst) = &server;
				srvLast = &server;
				server.sockId = sockId;
				struct epoll_event event = {};
				event.data.ptr = &server;
				event.events = EPOLLIN;
				epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event);
				server.onOpen();
				server.state = server.stateLocal = Server::STATE_OPEN;
				server.stateProtector.wait();
			} else {
				SrvDebugMSG("opening error -> to closed");
				server.onOpeningError();
				server.state = Server::STATE_CLOSED;
				while(server.closeWaiters > 0)
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
				server.state = Server::STATE_FINISHED;
				SrvDebugMSG("opening error -> finished");
			}
		} else {
			Server::State stateWanted = server.stateWanted;
			if ( server.stateLocal == Server::STATE_OPEN
			  && stateWanted == Server::STATE_CLOSE_REQ )
			{
				SrvDebugMSG("to close req");
				for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
					conn->closeReq();
				// Stop accepting new connections.
				struct epoll_event event = {};
				epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
				server.state = server.stateLocal = Server::STATE_CLOSE_REQ;
				server.stateProtector.wait();
				server.onCloseReqested();
			} else
			if ( ( server.stateLocal == Server::STATE_OPEN
			    || server.stateLocal == Server::STATE_CLOSE_REQ )
			  && stateWanted == Server::STATE_CLOSING_CONNECTIONS )
			{
				SrvDebugMSG("to closing connections");
				// In CLOSE_REQ the listening socket was already removed from epoll.
				if (server.stateLocal != Server::STATE_CLOSE_REQ) {
					struct epoll_event event = {};
					epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event);
				}
				if (server.connFirst) {
					for(Connection *conn = server.connFirst; conn; conn = conn->srvNext)
						conn->close(true);
					server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS;
				} else {
					SrvDebugMSG("to closing");
					server.state = server.stateLocal = Server::STATE_CLOSING;
					server.stateProtector.wait();
					server.onClose(false);
				}
			}
			if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) {
				SrvDebugMSG("to closed");
				server.state = Server::STATE_CLOSED;
				(server.prev ? server.prev->next : srvFirst) = server.next;
				(server.next ? server.next->prev : srvLast) = server.prev;
				// NOTE(review): server.sockId (the listening socket created above) is
				// neither closed nor reset here. Confirm onClose() releases it.
				server.protocol = nullptr;
				memset(server.address, 0, sizeof(server.address));
				server.addressSize = 0;
				server.stateWanted = Server::STATE_NONE;
				server.stateLocal = Server::STATE_NONE;
				SrvDebugMSG("closeWaiters");
				while(server.closeWaiters > 0)
					{ server.closeCondition.notify_all(); std::this_thread::yield(); }
				server.state = Server::STATE_FINISHED;
				SrvDebugMSG("finished");
			}
		}
	}
}
// Worker-thread handler for epoll events on a listening socket.
//  * EPOLLHUP/EPOLLERR: closes the server (if open and not already closing).
//  * EPOLLIN: accepts one pending connection and offers it to
//    server.onConnect(). The socket is handed to the returned Connection, or
//    closed if the server declines it. Hard accept() errors close the server.
//    Transient ones (EAGAIN, EINTR, ECONNABORTED) are ignored.
void Protocol::threadSrvEvents(Server &server, unsigned int events) {
	if (events & (EPOLLHUP | EPOLLERR)) {
		if ( server.stateLocal >= Server::STATE_OPEN
		  && server.stateLocal < Server::STATE_CLOSING )
			server.close((bool)(events & EPOLLERR));
		return;
	}
	if (events & EPOLLIN) {
		if ( server.stateLocal >= Server::STATE_OPEN
		  && server.stateLocal < Server::STATE_CLOSE_REQ)
		{
			unsigned char address[MAX_ADDR_SIZE] = {};
			socklen_t len = sizeof(address);
			int res = ::accept(server.sockId, (struct sockaddr*)address, &len);
			if (res >= 0) {
				// len > sizeof(address) would mean the peer address was truncated.
				assert(len <= sizeof(address));
				if (Connection *conn = server.onConnect(address, len)) {
					// NOTE(review): if open() fails, confirm whether the accepted fd
					// `res` is released (by onDisconnect or Connection). It is not
					// closed here.
					if (!conn->open(*this, address, len, res))
						server.onDisconnect(conn, true);
				} else {
					// Server declined the connection: don't leak the accepted socket.
					::close(res);
				}
			} else
			// accept() returns -1 and reports the reason in errno. Only
			// non-transient errors should tear the server down.
			if ( errno != EAGAIN && errno != EWOULDBLOCK
			  && errno != EINTR && errno != ECONNABORTED )
			{
				server.close(true);
				return;
			}
		}
	}
}
// Worker-thread main loop (started by open()).
// Creates the epoll instance and the wakeup eventfd, publishes STATE_OPEN, then
// repeatedly:
//   1. dispatches epoll events (null data.ptr = eventfd wakeup, otherwise a
//      Connection or Server);
//   2. applies a requested protocol-level CLOSE_REQ / CLOSING transition to all
//      servers and to connections without a server;
//   3. drains the connection and server queues;
//   4. exits once CLOSING (all lists empty), otherwise waits up to 1 s in epoll.
// On exit it releases its descriptors and only then publishes STATE_FINISHED
// and releases closeWait() callers.
void Protocol::threadRun() {
	DebugMSG();
	epollFd = ::epoll_create(32);
	eventFd = ::eventfd(0, EFD_NONBLOCK);
	struct epoll_event event = {};
	event.events = EPOLLIN;
	// data.ptr stays null: the dispatch loop below uses that to recognise the eventfd.
	epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event);
	const int maxCount = 16;
	struct epoll_event events[maxCount] = {};
	int count = 0;
	State stateLocal = STATE_OPEN;
	state = stateLocal;
	while(true) {
		for(int i = 0; i < count; ++i) {
			Root *pointer = (Root*)events[i].data.ptr;
			if (!pointer) {
				// Drain the eventfd counter so it stops reporting readable.
				unsigned long long num = 0;
				while(::read(eventFd, &num, sizeof(num)) > 0 && !num) { }
			} else
			if (Connection *connection = dynamic_cast<Connection*>(pointer)) {
				threadConnEvents(*connection, events[i].events);
			} else
			if (Server *server = dynamic_cast<Server*>(pointer)) {
				threadSrvEvents(*server, events[i].events);
			}
		}
		State sw = stateWanted;
		if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) {
			DebugMSG("to closse req");
			state = stateLocal = STATE_CLOSE_REQ;
			stateProtector.wait();
			for(Server *server = srvFirst; server; server = server->next)
				server->closeReq();
			// Connections owned by a server are handled by that server.
			for(Connection *conn = connFirst; conn; conn = conn->next)
				if (!conn->server)
					conn->closeReq();
		} else
		if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) {
			DebugMSG("to closing");
			state = stateLocal = STATE_CLOSING;
			stateProtector.wait();
			for(Server *server = srvFirst; server; server = server->next)
				server->close(true);
			for(Connection *conn = connFirst; conn; conn = conn->next)
				if (!conn->server)
					conn->close(true);
		}
		while(connQueue.peek() || srvQueue.peek()) {
			threadConnQueue();
			threadSrvQueue();
		}
		if (stateLocal == STATE_CLOSING) {
			assert(!connFirst && !srvFirst);
			break;
		}
		DebugMSG("epoll_wait: %d, %d", stateLocal, sw);
		count = epoll_wait(epollFd, events, maxCount, 1000);
	}
	// Release resources BEFORE publishing STATE_FINISHED. Once FINISHED is
	// visible, open() may immediately CAS it to STATE_OPENING, so the state must
	// not be written again after that point (previously a second FINISHED store
	// here could clobber a concurrent open()).
	int oldEpollFd = epollFd;
	int oldEventFd = eventFd;
	epollFd = -1;
	eventFd = -1;
	::close(oldEpollFd);
	::close(oldEventFd);
	assert(!connFirst && !srvFirst);
	assert(!connQueue.peek() && !srvQueue.peek());
	stateWanted = STATE_NONE;
	DebugMSG("closeWaiters");
	state = STATE_FINISHED;
	// closeWait() callers wait with a private mutex, so keep notifying until all
	// of them have observed the final state.
	while(closeWaiters > 0)
		{ closeCondition.notify_all(); std::this_thread::yield(); }
	DebugMSG("finished");
}