diff --git a/connection.cpp b/connection.cpp index 47b098d..86a880d 100644 --- a/connection.cpp +++ b/connection.cpp @@ -1,147 +1,56 @@ #include +#include #include #include #include +#include +#include +#include +#include +#include + #include "protocol.h" #include "connection.h" -#define DebugMSG HideDebugMSG +#define DebugMSG ShowDebugMSG Connection::Connection(): - state(STATE_NONE), - stateWanted(STATE_NONE), - enqueued(false), - closeWaiters(0), - finishWaiters(0), - address(), - addressSize(), - protocol(), - server(), - prev(), next(), - srvPrev(), srvNext(), - queueNext(), - sockId(-1), - stateLocal(), - errorLocal() - { } - - -Connection::~Connection() { - assert(state == STATE_NONE || state == STATE_FINISHED); - closeWait(); - while(finishWaiters > 0) std::this_thread::yield(); -} + server(), srvPrev(), srvNext() { } + + +Connection::~Connection() + { destroy(); } bool Connection::open(Protocol &protocol, void *address, size_t addressSize, int sockId, Server *server) { - if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false; - - State s = state; - do { - if (s != STATE_NONE && s != STATE_FINISHED) return false; - } while(!state.compare_exchange_weak(s, STATE_OPENING)); - if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield(); - - ReadLock lock(protocol.stateProtector); - Protocol::State ps = protocol.state; - if (ps < Protocol::STATE_INITIALIZING || ps >= Protocol::STATE_CLOSE_REQ) - { state = STATE_NONE; return false; } - - memcpy(this->address, address, addressSize); - this->addressSize = addressSize; + OpenLock lock(*this, protocol, address, addressSize); + if (!lock.success) + return false; this->sockId = sockId; this->server = server; - - this->protocol = &protocol; - state = STATE_INITIALIZING; - if (!enqueued.exchange(true)) protocol.connQueue.push(*this); - protocol.wakeup(); return true; } -void Connection::wantState(State state, bool error) { - ReadLock lock(stateProtector); - State s = this->state; - 
DebugMSG("%d, %d", s, state); - if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return; - DebugMSG("apply"); - - State sw = stateWanted; - do { - if (sw >= state) return; - } while(!stateWanted.compare_exchange_weak(sw, state)); - if (error) { DebugMSG("error"); error = false; this->error.compare_exchange_strong(error, true); } - if (!enqueued.exchange(true)) protocol->connQueue.push(*this); - protocol->wakeup(); -} - - -void Connection::closeReq() - { wantState(STATE_CLOSE_REQ, false); } -void Connection::close(bool error) - { wantState(STATE_CLOSING, error); } - - -void Connection::closeWait(unsigned long long timeoutUs, bool withReq, bool error) { - stateProtector.lock(); - State s = state; - if (s <= STATE_OPENING || s >= STATE_FINISHED) - { stateProtector.unlock(); return; } - - if (s == STATE_CLOSED) { - finishWaiters++; - stateProtector.unlock(); - } else { - if (withReq && s < STATE_CLOSE_REQ) closeReq(); - closeWaiters++; - stateProtector.unlock(); - - std::mutex fakeMutex; - { - std::unique_lock fakeLock(fakeMutex); - if (timeoutUs && s < STATE_CLOSING) { - std::chrono::steady_clock::time_point t - = std::chrono::steady_clock::now() - + std::chrono::microseconds(timeoutUs); - while(s < STATE_CLOSED) { - if (t <= std::chrono::steady_clock::now()) { close(error); break; } - closeCondition.wait_until(fakeLock, t); - s = state; - } - } - if (s < STATE_CLOSING) close(error); - while(s < STATE_CLOSED) { - closeCondition.wait(fakeLock); - s = state; - } - } - - finishWaiters++; - closeWaiters--; - } - - - while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; } - - finishWaiters--; -} +bool Connection::open(Protocol &protocol, void *address, size_t addressSize) + { return open(protocol, address, addressSize, -1, nullptr); } bool Connection::read(Task &task) { DebugMSG(); ReadLock lock(stateProtector); State s = state; - if (s < STATE_OPEN || s >= STATE_CLOSING) return false; + if (s < STATE_OPEN || s >= STATE_CLOSING) + return false; 
DebugMSG(); readQueue.push(task); - protocol->updateEvents(*this); + updateEvents(readQueue.count(), writeQueue.count()); return true; } @@ -150,18 +59,255 @@ bool Connection::write(Task &task) { DebugMSG(); ReadLock lock(stateProtector); State s = state; - if (s < STATE_OPEN || s >= STATE_CLOSING) return false; + if (s < STATE_OPEN || s >= STATE_CLOSING) + return false; DebugMSG(); writeQueue.push(task); - protocol->updateEvents(*this); + updateEvents(readQueue.count(), writeQueue.count()); return true; } -void Connection::onOpeningError() { } -void Connection::onOpen(const void*, size_t) { } -bool Connection::onReadReady(Task&, bool) { return true; } -bool Connection::onWriteReady(Task&, bool) { return true; } -void Connection::onCloseReqested() { } -void Connection::onClose(bool) { } +void Connection::handleState() { + if (stateLocal == STATE_NONE) { + DebugMSG("none"); + bool connected = false; + + if (sockId >= 0) { + DebugMSG("accepting"); + initSocket(sockId); + connected = true; + } else + if (addressSize >= sizeof(sa_family_t)) { + DebugMSG("connecting"); + struct sockaddr *addr = (struct sockaddr*)address; + sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); + if (sockId >= 0) { + initSocket(sockId); + if (0 == ::connect(sockId, addr, (int)addressSize)) + connected = true; + } else { + DebugMSG("none -> cannot create socket"); + } + } + + if (sockId >= 0) { + DebugMSG("opening"); + prev = protocol->sockLast; + (prev ? prev->next : protocol->sockFirst) = this; + if (server) { + srvPrev = server->connLast; + (srvPrev ? 
srvPrev->srvNext : server->connFirst) = this; + } + updateEvents(false, !connected); + + if (connected) { + DebugMSG("to open"); + state = stateLocal = STATE_OPEN; + stateProtector.wait(); + onOpen(address, addressSize); + } else { + assert(!server); + DebugMSG("to connecting"); + state = stateLocal = STATE_CONNECTING; + stateProtector.wait(); + } + } else { + DebugMSG("none -> to closed"); + Server *srv = server; + assert(connected || !server); + server = nullptr; + sockId = -1; + onOpeningError(); + state = STATE_CLOSED; + stateProtector.wait(); + while(closeWaiters > 0) + { closeCondition.notify_all(); std::this_thread::yield(); } + state = STATE_FINISHED; + DebugMSG("none -> to finished"); + if (srv) server->handleDisconnect(*this, true); + } + } else { + State sw = stateWanted; + + if ( stateLocal == STATE_CONNECTING + && sw == STATE_CLOSING ) + { + DebugMSG("connecting -> to closing"); + onOpeningError(); + state = stateLocal = STATE_CLOSING; + stateProtector.wait(); + } + + if ( stateLocal == STATE_OPEN + && sw == STATE_CLOSE_REQ ) + { + DebugMSG("to close req"); + state = stateLocal = STATE_CLOSE_REQ; + stateProtector.wait(); + onCloseReqested(); + } + + if ( stateLocal >= STATE_OPEN + && stateLocal <= STATE_CLOSE_REQ + && sw == STATE_CLOSING ) + { + DebugMSG("to closing"); + state = stateLocal = STATE_CLOSING; + stateProtector.wait(); + while(Task *task = readQueue.peek()) + { readQueue.pop(); onReadReady(*task, true); } + while(Task *task = writeQueue.peek()) + { writeQueue.pop(); onWriteReady(*task, true); } + onClose(error); + } + + if ( stateLocal == STATE_CLOSING && !enqueued ) { + DebugMSG("to closed"); + Server *srv = server; + bool err = error; + + if (server) { + (srvPrev ? srvPrev->srvNext : server->connFirst) = srvNext; + (srvNext ? srvNext->srvPrev : server->connLast ) = srvPrev; + server = nullptr; + } + error = false; + + (prev ? prev->next : protocol->sockFirst) = next; + (next ? 
next->prev : protocol->sockLast ) = prev; + if (sockId >= 0) { + struct epoll_event event = {}; + epoll_ctl(protocol->epollFd, EPOLL_CTL_DEL, sockId, &event); + ::close(sockId); + sockId = -1; + } + events = 0xffffffff; + protocol = nullptr; + + memset(address, 0, sizeof(address)); + addressSize = 0; + stateLocal = STATE_NONE; + state = STATE_CLOSED; + stateWanted = STATE_NONE; + while(closeWaiters > 0) + { closeCondition.notify_all(); std::this_thread::yield(); } + state = STATE_FINISHED; + DebugMSG("to finished"); + if (srv) srv->handleDisconnect(*this, err); + } + } +} + + +void Connection::handleEvents(unsigned int events) { + bool needUpdate = false; + DebugMSG("%08x", events); + + if (stateLocal == STATE_CONNECTING) { + DebugMSG("connecting"); + if (events & (EPOLLHUP | EPOLLERR)) { + close(true); + return; + } + + if (events & EPOLLOUT) { + DebugMSG("connecting -> to open"); + state = stateLocal = STATE_OPEN; + stateProtector.wait(); + onOpen(address, addressSize); + needUpdate = true; + } + } + + if ( stateLocal < STATE_OPEN + || stateLocal >= STATE_CLOSING ) + return; + + bool closing = (bool)(events & (EPOLLHUP | EPOLLRDHUP | EPOLLERR)); + + if (events & EPOLLIN) { + DebugMSG("input"); + bool ready = true; + int pops = 0; + int count = (readQueue.count() + 2)*2; + while(ready) { + Task *task = readQueue.peek(); + if (!task) { + if (pops) needUpdate = true; + break; + } + + while(task->size > task->completion) { + size_t size = task->size - task->completion; + if (size > INT_MAX) size = INT_MAX; + int res = ::recv( + sockId, + (unsigned char*)task->data + task->completion, + (int)size, MSG_DONTWAIT ); + if (res <= 0) { + ready = false; + break; + } else { + task->completion += res; + } + } + + if (task->size <= task->completion) { + readQueue.pop(); + onReadReady(*task, false); + ++pops; + if (!closing && --count < 0) ready = false; + } + } + } + + if (closing) { + DebugMSG("hup"); + close((bool)(events & EPOLLERR)); + return; + } + + if (events & 
EPOLLOUT) { + DebugMSG("output"); + bool ready = true; + int pops = 0; + int count = writeQueue.count() + 2; + while(ready) { + Task *task = writeQueue.peek(); + if (!task) { + if (pops) needUpdate = true; + break; + } + while(task->size > task->completion) { + size_t size = task->size - task->completion; + if (size > INT_MAX) size = INT_MAX; + int res = ::send( + sockId, + (const unsigned char*)task->data + task->completion, + (int)size, MSG_DONTWAIT ); + if (res <= 0) { + ready = false; + break; + } else { + task->completion += res; + } + } + + if (task->size <= task->completion) { + writeQueue.pop(); + onWriteReady(*task, false); + ++pops; + if (--count <= 0) ready = false; + } + } + } + + if (needUpdate) + updateEvents(readQueue.count(), writeQueue.count()); +} + + +void Connection::onReadReady(Task&, bool) { } +void Connection::onWriteReady(Task&, bool) { } diff --git a/connection.h b/connection.h index 5d8b58b..0bdd658 100644 --- a/connection.h +++ b/connection.h @@ -2,117 +2,60 @@ #define CONNECTION_H -#include +#include "socket.h" -#include -#include -#include "utils.h" - - -#define MAX_ADDR_SIZE 256 - - -class Protocol; class Server; -class Root { +class Connection: public Socket { public: - virtual ~Root() { } -}; - - -class Connection: public Root { -public: - enum State: int { - STATE_NONE, - STATE_OPENING, - STATE_INITIALIZING, - STATE_RESOLVING, - STATE_CONNECTING, - STATE_OPEN, - STATE_CLOSE_REQ, - STATE_CLOSING, - STATE_CLOSED, - STATE_FINISHED - }; - struct Task { void *data; size_t size; size_t completion; - bool watch; Task *next; inline explicit Task( void *data = nullptr, size_t size = 0, size_t completion = 0, - bool watch = false, Task *next = nullptr ): - data(data), size(size), completion(completion), watch(watch), next(next) + data(data), size(size), completion(completion), next(next) { assert(!!data == !!size); } }; - typedef CounteredQueueMISO TaskQueue; private: friend class Protocol; + friend class Server; - std::atomic state; - 
std::atomic stateWanted; - ReadProtector stateProtector; - std::atomic enqueued; - std::atomic error; - - std::atomic closeWaiters; - std::atomic finishWaiters; - std::condition_variable closeCondition; - - unsigned char address[MAX_ADDR_SIZE]; - size_t addressSize; - - Protocol *protocol; Server *server; - Connection *prev, *next; Connection *srvPrev, *srvNext; - Connection *queueNext; - int sockId; - State stateLocal; bool errorLocal; - std::atomic events; TaskQueue readQueue; TaskQueue writeQueue; - void wantState(State state, bool error); - -public: - typedef QueueMISO Queue; - friend class QueueMISO; - friend class ChainFuncs; + bool open(Protocol &protocol, void *address, size_t addressSize, int sockId, Server *server); public: Connection(); - virtual ~Connection(); + ~Connection(); - bool open(Protocol &protocol, void *address, size_t addressSize, int sockId = -1, Server *server = nullptr); - void closeReq(); - void close(bool error); - void closeWait(unsigned long long timeoutUs = 0, bool withReq = true, bool error = true); + bool open(Protocol &protocol, void *address, size_t addressSize); bool read(Task &task); bool write(Task &task); +private: + void handleState() override; + void handleEvents(unsigned int events) override; + protected: - virtual void onOpeningError(); - virtual void onOpen(const void *address, size_t addressSize); - virtual bool onReadReady(Task &task, bool error); - virtual bool onWriteReady(Task &task, bool error); - virtual void onCloseReqested(); - virtual void onClose(bool error); + virtual void onReadReady(Task &task, bool error); + virtual void onWriteReady(Task &task, bool error); }; diff --git a/main.cpp b/main.cpp index 1cd4ae2..a56b328 100644 --- a/main.cpp +++ b/main.cpp @@ -18,30 +18,28 @@ public: char data[2][100]; MyConnection(const char *name): name(name), readTask(data[0], 16), writeTask(data[1], 16), data() - { strcpy(data[1], "01234\n\n\n567890123456789"); } + { strcpy(data[1], "01234567890123456789"); } protected: 
void onOpeningError() override { myprintf("%s: onOpeningError\n", name); } void onOpen(const void*, size_t) override { myprintf("%s: onOpen\n", name); read(readTask); write(writeTask); } - bool onReadReady(Task &task, bool error) override { + void onReadReady(Task &task, bool error) override { data[0][task.completion] = 0; myprintf("%s: onReadReady: %lu, %lu, %s, %d\n", name, task.completion, task.size, (char*)task.data, error); task.completion = 0; read(task); - return true; } - bool onWriteReady(Task &task, bool error) override { + void onWriteReady(Task &task, bool error) override { myprintf("%s: onWriteReady: %d\n", name, error); task.completion = 0; write(task); - return true; } void onCloseReqested() override - { myprintf("%s: onCloseReqested\n", name); } + { myprintf("%s: onCloseReqested\n", name); close(false); } void onClose(bool error) override { myprintf("%s: onClose: %d\n", name, (int)error); } }; @@ -51,7 +49,7 @@ class MyServer: public Server { protected: void onOpeningError() override { myprintf("Server: onOpeningError\n"); } - void onOpen() override + void onOpen(const void*, size_t) override { myprintf("Server: onOpen\n"); } Connection* onConnect(const void*, size_t) override { myprintf("Server: onConnect\n"); return new MyConnection("server connection"); } @@ -77,7 +75,7 @@ int main(int argc, char **) { addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); addr.sin_port = htons(1532); - addr.sin_port = htons(80); + //addr.sin_port = htons(80); if (server) { myprintf("init server\n"); @@ -87,8 +85,8 @@ int main(int argc, char **) { } else { myprintf("server open\n"); myprintf("wait 10s\n"); - std::this_thread::sleep_for( std::chrono::seconds(10) ); - //fgetc(stdin); + //std::this_thread::sleep_for( std::chrono::seconds(10) ); + fgetc(stdin); myprintf("close server\n"); server.closeWait(); } @@ -100,8 +98,8 @@ int main(int argc, char **) { } else { myprintf("connection open\n"); myprintf("wait 10s\n"); - 
std::this_thread::sleep_for( std::chrono::seconds(10) ); - //fgetc(stdin); + //std::this_thread::sleep_for( std::chrono::seconds(10) ); + fgetc(stdin); myprintf("close connection\n"); connection.closeWait(); } diff --git a/protocol.cpp b/protocol.cpp index f3cac23..417e561 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -16,9 +16,6 @@ #define DebugMSG HideDebugMSG -#define SrvDebugMSG HideDebugMSG -#define ConnDebugMSG ShowDebugMSG - Protocol::Protocol(): @@ -29,8 +26,7 @@ Protocol::Protocol(): stateWanted(STATE_NONE), closeWaiters(0), finishWaiters(0), - connFirst(), connLast(), - srvFirst(), srvLast() + sockFirst(), sockLast() { } @@ -138,457 +134,6 @@ void Protocol::wakeup() { } -void Protocol::updateEvents(Connection &connection) { - unsigned int events = EPOLLERR | EPOLLRDHUP; - if (connection.readQueue.count()) events |= EPOLLIN; - if (connection.writeQueue.count()) events |= EPOLLOUT; - - ConnDebugMSG("events: %08x", events); - if (connection.events.exchange(events) == events) return; - ConnDebugMSG("apply"); - - struct epoll_event event = {}; - event.data.ptr = &connection; - event.events = events; - epoll_ctl(epollFd, EPOLL_CTL_MOD, connection.sockId, &event); -} - - -void Protocol::initSocket(int sockId) { - int tcp_nodelay = 1; - setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int)); - fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK); -} - - -void Protocol::threadConnQueue() { - while(connQueue.peek()) { - ConnDebugMSG(); - Connection &connection = connQueue.pop(); - connection.enqueued = false; - - if (connection.stateLocal == Connection::STATE_NONE) { - ConnDebugMSG("none"); - int sockId = connection.sockId; - bool connected = false; - - if (sockId >= 0) { - ConnDebugMSG("accepting"); - initSocket(sockId); - } else - if (connection.addressSize >= sizeof(sa_family_t)) { - ConnDebugMSG("connecting"); - struct sockaddr *addr = (struct sockaddr*)connection.address; - sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); - if 
(sockId >= 0) { - initSocket(sockId); - if (0 == ::connect(sockId, addr, (int)connection.addressSize)) - connected = true; - } else { - ConnDebugMSG("none -> cannot create socket"); - } - } - - struct epoll_event event = {}; - if (sockId >= 0) { - event.data.ptr = &connection; - event.events = EPOLLERR | EPOLLRDHUP; - if (!connected) event.events |= EPOLLOUT; - if (0 != epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event)) { - ConnDebugMSG("epoll_ctl(%d, ADD, %d, %08x) errno: %d", epollFd, sockId, event.events, errno); - ::close(sockId); - connection.sockId = sockId = -1; - assert(false); - } - } - - if (sockId >= 0) { - ConnDebugMSG("opening"); - connection.protocol = this; - connection.prev = connLast; - assert(!connection.next); - (connection.prev ? connection.prev->next : connFirst) = &connection; - if (connection.server) { - connection.srvPrev = connection.server->connLast; - assert(!connection.srvNext); - (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = &connection; - } - connection.sockId = sockId; - - if (connected) { - ConnDebugMSG("to open"); - connection.events = event.events; - connection.state = connection.stateLocal = Connection::STATE_OPEN; - connection.stateProtector.wait(); - connection.onOpen(connection.address, connection.addressSize); - } else { - assert(!connection.server); - ConnDebugMSG("to connecting"); - connection.state = connection.stateLocal = Connection::STATE_CONNECTING; - connection.stateProtector.wait(); - } - } else { - ConnDebugMSG("none -> to closed"); - Server *server = connection.server; - assert(connected || !server); - connection.server = nullptr; - connection.sockId = -1; - connection.onOpeningError(); - connection.state = Connection::STATE_CLOSED; - while(connection.closeWaiters > 0) - { connection.closeCondition.notify_all(); std::this_thread::yield(); } - connection.state = Connection::STATE_FINISHED; - ConnDebugMSG("none -> to finished"); - if (server) threadSrvDisconnect(*server, 
connection, true); - } - } else { - Connection::State stateWanted = connection.stateWanted; - - if ( connection.stateLocal == Connection::STATE_CONNECTING - && stateWanted == Connection::STATE_CLOSING ) - { - ConnDebugMSG("connecting -> to closing"); - connection.onOpeningError(); - connection.state = connection.stateLocal = Connection::STATE_CLOSING; - connection.stateProtector.wait(); - } - - if ( connection.stateLocal == Connection::STATE_OPEN - && stateWanted == Connection::STATE_CLOSE_REQ ) - { - ConnDebugMSG("to close req"); - connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ; - connection.stateProtector.wait(); - connection.onCloseReqested(); - } else - if ( connection.stateLocal >= Connection::STATE_OPEN - && connection.stateLocal <= Connection::STATE_CLOSE_REQ - && stateWanted == Connection::STATE_CLOSING ) - { - ConnDebugMSG("to closing"); - connection.state = connection.stateLocal = Connection::STATE_CLOSING; - connection.stateProtector.wait(); - while(Connection::Task *task = connection.readQueue.peek()) { - connection.readQueue.pop(); - if (!connection.onReadReady(*task, true)) - assert(false); - } - while(Connection::Task *task = connection.writeQueue.peek()) { - connection.writeQueue.pop(); - if (!connection.onWriteReady(*task, true)) - assert(false); - } - connection.errorLocal = connection.error; - connection.onClose(connection.errorLocal); - } - - if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) { - ConnDebugMSG("to closed"); - Server *server = connection.server; - if (server) { - (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext; - (connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev; - connection.server = nullptr; - } - bool error = connection.errorLocal; - connection.errorLocal = false; - (connection.prev ? connection.prev->next : connFirst) = connection.next; - (connection.next ? 
connection.next->prev : connLast) = connection.prev; - if (connection.sockId >= 0) { - struct epoll_event event = {}; - epoll_ctl(epollFd, EPOLL_CTL_DEL, connection.sockId, &event); - ::close(connection.sockId); - connection.sockId = -1; - } - connection.protocol = nullptr; - memset(connection.address, 0, sizeof(connection.address)); - connection.addressSize = 0; - connection.stateWanted = Connection::STATE_NONE; - connection.stateLocal = Connection::STATE_NONE; - connection.state = Connection::STATE_CLOSED; - while(connection.closeWaiters > 0) - { connection.closeCondition.notify_all(); std::this_thread::yield(); } - connection.state = Connection::STATE_FINISHED; - ConnDebugMSG("to finished"); - if (server) threadSrvDisconnect(*server, connection, error); - } - } - } -} - - -void Protocol::threadConnEvents(Connection &connection, unsigned int events) { - bool needUpdate = false; - - if (connection.stateLocal == Connection::STATE_CONNECTING) { - ConnDebugMSG("connecting"); - if (events & (EPOLLHUP | EPOLLERR)) { - connection.close(true); - return; - } - - if (events & EPOLLOUT) { - ConnDebugMSG("connecting -> to open"); - connection.state = connection.stateLocal = Connection::STATE_OPEN; - connection.stateProtector.wait(); - connection.onOpen(connection.address, connection.addressSize); - needUpdate = true; - } - } - - if ( connection.stateLocal < Connection::STATE_OPEN - || connection.stateLocal >= Connection::STATE_CLOSING ) - return; - - bool closing = (bool)(events & (EPOLLHUP | EPOLLRDHUP | EPOLLERR)); - - if (events & EPOLLIN) { - ConnDebugMSG("input"); - bool ready = true; - int pops = 0; - int count = (connection.readQueue.count() + 2)*2; - while(ready) { - Connection::Task *task = connection.readQueue.peek(); - if (!task) { - if (pops) needUpdate = true; - break; - } - size_t completion = 0; - while(task->size > task->completion) { - size_t size = task->size - task->completion; - if (size > INT_MAX) size = INT_MAX; - int res = ::recv( - connection.sockId, 
- (unsigned char*)task->data + task->completion, - (int)size, MSG_DONTWAIT ); - if (res == EAGAIN || res == 0) { - ready = false; - break; - } else - if (res < 0) { - connection.close(true); - return; - } else { - completion += res; - task->completion += res; - } - } - - if (task->size <= task->completion || (completion && task->watch)) { - connection.readQueue.pop(); - if (!connection.onReadReady(*task, false)) - connection.readQueue.unpop(*task); else ++pops; - if (!closing && --count < 0) ready = false; - } - } - } - - if (closing) { - ConnDebugMSG("hup"); - connection.close((bool)(events & EPOLLERR)); - return; - } - - if (events & EPOLLOUT) { - ConnDebugMSG("output"); - bool ready = true; - int pops = 0; - int count = connection.writeQueue.count() + 2; - while(ready) { - Connection::Task *task = connection.writeQueue.peek(); - if (!task) { - if (pops) needUpdate = true; - break; - } - size_t completion = 0; - while(task->size > task->completion) { - size_t size = task->size - task->completion; - if (size > INT_MAX) size = INT_MAX; - int res = ::send( - connection.sockId, - (const unsigned char*)task->data + task->completion, - (int)size, MSG_DONTWAIT ); - if (res == EAGAIN || res == 0) { - ready = false; - break; - } else - if (res < 0) { - connection.close(true); - return; - } else { - completion += res; - task->completion += res; - } - } - - if (task->size <= task->completion || (completion && task->watch)) { - connection.writeQueue.pop(); - if (!connection.onWriteReady(*task, false)) - connection.writeQueue.unpop(*task); else ++pops; - if (--count <= 0) ready = false; - } - } - } - - if (needUpdate) updateEvents(connection); -} - - -void Protocol::threadSrvDisconnect(Server& server, Connection& connection, bool error) { - server.onDisconnect(&connection, error); - if (!server.connFirst && server.stateLocal == Server::STATE_CLOSING_CONNECTIONS) { - SrvDebugMSG("to closing"); - server.state = server.stateLocal = Server::STATE_CLOSING; - 
server.stateProtector.wait(); - server.onClose(server.error); - if (!server.enqueued.exchange(true)) srvQueue.push(server); - } -} - - -void Protocol::threadSrvQueue() { - while(srvQueue.peek()) { - SrvDebugMSG(); - Server &server = srvQueue.pop(); - server.enqueued = false; - - if (server.stateLocal == Server::STATE_NONE) { - SrvDebugMSG("none"); - int sockId = -1; - if (server.addressSize >= sizeof(sa_family_t)) { - struct sockaddr *addr = (struct sockaddr*)server.address; - sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); - if (sockId >= 0) { - initSocket(sockId); - if ( 0 != ::bind(sockId, addr, (int)server.addressSize) - || 0 != ::listen(sockId, 128) ) - { - SrvDebugMSG("cannot bind/listen"); - ::close(sockId); - sockId = -1; - } - } else { - SrvDebugMSG("cannot create socket"); - } - } - - if (sockId >= 0) { - SrvDebugMSG("to open"); - server.protocol = this; - server.prev = srvLast; - assert(!server.next); - (server.prev ? server.prev->next : srvFirst) = &server; - server.sockId = sockId; - - struct epoll_event event = {}; - event.data.ptr = &server; - event.events = EPOLLIN | EPOLLERR; - epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); - - server.state = server.stateLocal = Server::STATE_OPEN; - server.stateProtector.wait(); - server.onOpen(); - } else { - SrvDebugMSG("opening error -> to closed"); - server.onOpeningError(); - server.state = Server::STATE_CLOSED; - while(server.closeWaiters > 0) - { server.closeCondition.notify_all(); std::this_thread::yield(); } - server.state = Server::STATE_FINISHED; - SrvDebugMSG("opening error -> finished"); - } - } else { - Server::State stateWanted = server.stateWanted; - if ( server.stateLocal == Server::STATE_OPEN - && stateWanted == Server::STATE_CLOSE_REQ ) - { - SrvDebugMSG("to close req"); - for(Connection *conn = server.connFirst; conn; conn = conn->srvNext) - conn->closeReq(); - struct epoll_event event = {}; - epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); - server.state = server.stateLocal = 
Server::STATE_CLOSE_REQ; - server.stateProtector.wait(); - server.onCloseReqested(); - } else - if ( ( server.stateLocal == Server::STATE_OPEN - || server.stateLocal == Server::STATE_CLOSE_REQ ) - && stateWanted == Server::STATE_CLOSING_CONNECTIONS ) - { - SrvDebugMSG("to closing connections"); - if (server.stateLocal != Server::STATE_CLOSE_REQ) { - struct epoll_event event = {}; - epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); - } - if (server.connFirst) { - for(Connection *conn = server.connFirst; conn; conn = conn->srvNext) - conn->close(true); - server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS; - } else { - SrvDebugMSG("to closing"); - server.state = server.stateLocal = Server::STATE_CLOSING; - server.stateProtector.wait(); - server.onClose(server.error); - } - } - - if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) { - SrvDebugMSG("to closed"); - if (server.sockId >= 0) { - struct epoll_event event = {}; - epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); - ::close(server.sockId); - server.sockId = -1; - } - (server.prev ? server.prev->next : srvFirst) = server.next; - (server.next ? 
server.next->prev : srvLast) = server.prev; - server.protocol = nullptr; - memset(server.address, 0, sizeof(server.address)); - server.addressSize = 0; - server.stateWanted = Server::STATE_NONE; - server.stateLocal = Server::STATE_NONE; - SrvDebugMSG("closeWaiters"); - server.state = Server::STATE_CLOSED; - while(server.closeWaiters > 0) - { server.closeCondition.notify_all(); std::this_thread::yield(); } - server.state = Server::STATE_FINISHED; - SrvDebugMSG("finished"); - } - } - } -} - - -void Protocol::threadSrvEvents(Server &server, unsigned int events) { - if (events & (EPOLLHUP | EPOLLERR)) { - if ( server.stateLocal >= Server::STATE_OPEN - && server.stateLocal < Server::STATE_CLOSING ) - server.close((bool)(events & EPOLLERR)); - return; - } - - if (events & EPOLLIN) { - if ( server.stateLocal >= Server::STATE_OPEN - && server.stateLocal < Server::STATE_CLOSE_REQ) - { - unsigned char address[MAX_ADDR_SIZE] = {}; - socklen_t len = sizeof(address); - int res = ::accept(server.sockId, (struct sockaddr*)address, &len); - assert(len < sizeof(address)); - if (res >= 0) { - if (Connection *conn = server.onConnect(address, len)) - if (!conn->open(*this, address, len, res)) - server.onDisconnect(conn, true); - } else - if (res != EAGAIN) { - server.close(true); - return; - } - } - } -} - - void Protocol::threadRun() { DebugMSG(); epollFd = ::epoll_create(32); @@ -609,16 +154,11 @@ void Protocol::threadRun() { while(true) { for(int i = 0; i < count; ++i) { - Root *pointer = (Root*)events[i].data.ptr; - if (!pointer) { + if (Socket *socket = (Socket*)events[i].data.ptr) { + socket->handleEvents(events[i].events); + } else { unsigned long long num = 0; while(::read(eventFd, &num, sizeof(num)) > 0 && !num); - } else - if (Connection *connection = dynamic_cast(pointer)) { - threadConnEvents(*connection, events[i].events); - } else - if (Server *server = dynamic_cast(pointer)) { - threadSrvEvents(*server, events[i].events); } } @@ -627,48 +167,45 @@ void
Protocol::threadRun() { DebugMSG("to closse req"); state = stateLocal = STATE_CLOSE_REQ; stateProtector.wait(); - for(Server *server = srvFirst; server; server = server->next) - server->closeReq(); - for(Connection *conn = connFirst; conn; conn = conn->next) - if (!conn->server) - conn->closeReq(); + for(Socket *socket = sockFirst; socket; socket = socket->next) + socket->closeReq(); } else if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) { DebugMSG("to closing"); state = stateLocal = STATE_CLOSING; stateProtector.wait(); - for(Server *server = srvFirst; server; server = server->next) - server->close(true); - for(Connection *conn = connFirst; conn; conn = conn->next) - if (!conn->server) - conn->close(true); + for(Socket *socket = sockFirst; socket; socket = socket->next) + socket->close(true); } - while(connQueue.peek() || srvQueue.peek()) { - threadConnQueue(); - threadSrvQueue(); + while(sockQueue.peek()) { + Socket &socket = sockQueue.pop(); + socket.enqueued = false; + socket.handleState(); } - if (stateLocal == STATE_CLOSING) { - assert(!connFirst && !srvFirst); + if (stateLocal == STATE_CLOSING) break; - } count = epoll_wait(epollFd, events, maxCount, 1000); DebugMSG("epoll count: %d", count); } + assert(!sockFirst); + DebugMSG("closeWaiters"); state = STATE_CLOSED; while(closeWaiters > 0) { closeCondition.notify_all(); std::this_thread::yield(); } + epoll_ctl(epollFd, EPOLL_CTL_DEL, eventFd, &event); + ::close(epollFd); ::close(eventFd); epollFd = -1; eventFd = -1; - assert(!connFirst && !srvFirst); - assert(!connQueue.peek() && !srvQueue.peek()); + assert(!sockFirst); + assert(!sockQueue.peek()); stateWanted = STATE_NONE; stateLocal = STATE_NONE; diff --git a/protocol.h b/protocol.h index 7a9ed74..caf587f 100644 --- a/protocol.h +++ b/protocol.h @@ -11,19 +11,8 @@ class Protocol { -public: - enum State { - STATE_NONE, - STATE_OPENING, - STATE_INITIALIZING, - STATE_OPEN, - STATE_CLOSE_REQ, - STATE_CLOSING, -
STATE_CLOSED, - STATE_FINISHED - }; - private: + friend class Socket; friend class Server; friend class Connection; @@ -39,23 +28,12 @@ private: std::atomic finishWaiters; std::condition_variable closeCondition; - Connection *connFirst, *connLast; - Connection::Queue connQueue; + Socket *sockFirst, *sockLast; + Socket::Queue sockQueue; - Server *srvFirst, *srvLast; - Server::Queue srvQueue; - void wakeup(); - void updateEvents(Connection &connection); - void initSocket(int sockId); - void wantState(State state); - void threadConnQueue(); - void threadConnEvents(Connection &connection, unsigned int events); - void threadSrvDisconnect(Server &server, Connection &connection, bool error); - void threadSrvQueue(); - void threadSrvEvents(Server &server, unsigned int events); void threadRun(); public: diff --git a/server.cpp b/server.cpp index e36344c..504f4ab 100644 --- a/server.cpp +++ b/server.cpp @@ -5,6 +5,12 @@ #include #include +#include +#include +#include +#include +#include + #include "protocol.h" #include "server.h" @@ -13,127 +19,167 @@ Server::Server(): - state(STATE_NONE), - stateWanted(STATE_NONE), - enqueued(false), - error(false), - closeWaiters(0), - finishWaiters(0), - address(), - addressSize(), - protocol(), - prev(), next(), - queueNext(), - connFirst(), connLast(), - sockId(-1), - stateLocal() + connFirst(), connLast() { } -Server::~Server() { - assert(state == STATE_NONE || state == STATE_FINISHED); - closeWait(); - while(finishWaiters > 0) std::this_thread::yield(); -} +Server::~Server() + { destroy();} -bool Server::open(Protocol &protocol, void *address, size_t addressSize) { - if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) - { DebugMSG("bad args"); return false; } - - State s = state; - do { - if (s != STATE_NONE && s != STATE_FINISHED) - { DebugMSG("already open or in closing state"); return false; } - } while(!state.compare_exchange_weak(s, STATE_OPENING)); - if (s == STATE_FINISHED) while(finishWaiters > 0)
std::this_thread::yield(); - - ReadLock lock(protocol.stateProtector); - Protocol::State ps = protocol.state; - if (ps < Protocol::STATE_INITIALIZING || ps >= Protocol::STATE_CLOSE_REQ) - { DebugMSG("protocol is closed"); state = STATE_NONE; return false; } - - memcpy(this->address, address, addressSize); - this->addressSize = addressSize; - - this->protocol = &protocol; - state = STATE_INITIALIZING; - if (!enqueued.exchange(true)) protocol.srvQueue.push(*this); - protocol.wakeup(); - return true; -} +bool Server::open(Protocol &protocol, void *address, size_t addressSize) + { OpenLock lock(*this, protocol, address, addressSize); return lock.success; } -void Server::wantState(State state, bool error) { - ReadLock lock(stateProtector); - State s = this->state; - if (s <= STATE_OPENING || s >= STATE_CLOSING_CONNECTIONS || s >= state) return; - - State sw = stateWanted; - do { - if (sw >= state) return; - } while(!stateWanted.compare_exchange_weak(sw, state)); - if (error) { error = false; this->error.compare_exchange_strong(error, true); } - if (!enqueued.exchange(true)) protocol->srvQueue.push(*this); - protocol->wakeup(); +void Server::handleDisconnect(Connection &connection, bool error) { + onDisconnect(&connection, error); + if (!connFirst && stateLocal == STATE_CLOSING_CONNECTIONS) { + DebugMSG("to closing"); + state = stateLocal = STATE_CLOSING; + stateProtector.wait(); + onClose(error); + enqueue(); + } } -void Server::closeReq() - { wantState(STATE_CLOSE_REQ, false); } -void Server::close(bool error) - { wantState(STATE_CLOSING_CONNECTIONS, error); } - - -void Server::closeWait(unsigned long long timeoutUs, bool withReq) { - stateProtector.lock(); - State s = state; - if (s <= STATE_OPENING || s >= STATE_FINISHED) - { stateProtector.unlock(); return; } - - if (s == STATE_CLOSED) { - finishWaiters++; - stateProtector.unlock(); +void Server::handleState() { + if (stateLocal == STATE_NONE) { + DebugMSG("none"); + int sockId = -1; + if (addressSize >=
sizeof(sa_family_t)) { + struct sockaddr *addr = (struct sockaddr*)address; + sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); + if (sockId >= 0) { + initSocket(sockId); + if ( 0 != ::bind(sockId, addr, (socklen_t)addressSize) + || 0 != ::listen(sockId, 128) ) + { + DebugMSG("cannot bind/listen"); + ::close(sockId); + sockId = -1; + } + } else { + DebugMSG("cannot create socket"); + } + } + + if (sockId >= 0) { + DebugMSG("to open"); + // append this server to the tail of the protocol socket list + prev = protocol->sockLast; + next = nullptr; + (prev ? prev->next : protocol->sockFirst) = this; + protocol->sockLast = this; + this->sockId = sockId; + + updateEvents(true, false); + + state = stateLocal = STATE_OPEN; + stateProtector.wait(); + onOpen(address, addressSize); + } else { + DebugMSG("opening error -> to closed"); + onOpeningError(); + state = STATE_CLOSED; + while(closeWaiters > 0) + { closeCondition.notify_all(); std::this_thread::yield(); } + state = STATE_FINISHED; + DebugMSG("opening error -> finished"); + } } else { - if (withReq && s < STATE_CLOSE_REQ) closeReq(); - closeWaiters++; - stateProtector.unlock(); + State sw = stateWanted; + if ( stateLocal == STATE_OPEN + && sw == STATE_CLOSE_REQ ) + { + DebugMSG("to close req"); + for(Connection *conn = connFirst; conn; conn = conn->srvNext) + conn->closeReq(); + updateEvents(false, false); + state = stateLocal = STATE_CLOSE_REQ; + stateProtector.wait(); + onCloseReqested(); + } - std::mutex fakeMutex; + if ( stateLocal >= STATE_OPEN + && stateLocal <= STATE_CLOSE_REQ + && sw == STATE_CLOSING ) { - std::unique_lock fakeLock(fakeMutex); - if (timeoutUs && s < STATE_CLOSING) { - std::chrono::steady_clock::time_point t - = std::chrono::steady_clock::now() - + std::chrono::microseconds(timeoutUs); - while(s < STATE_CLOSED) { - if (t <= std::chrono::steady_clock::now()) { close(true); break; } - closeCondition.wait_until(fakeLock, t); - s = state; - } - } - if (s < STATE_CLOSING) close(true); - while(s < STATE_CLOSED) { - closeCondition.wait(fakeLock); - s = state; + DebugMSG("to closing connections");
updateEvents(false, false); + if (connFirst) { + for(Connection *conn = connFirst; conn; conn = conn->srvNext) + conn->close(true); + state = stateLocal = STATE_CLOSING_CONNECTIONS; + stateProtector.wait(); + } else { + DebugMSG("to closing"); + state = stateLocal = STATE_CLOSING; + stateProtector.wait(); + onClose(error); } } - finishWaiters++; - closeWaiters--; + if ( stateLocal == STATE_CLOSING && !enqueued ) { + DebugMSG("to closed"); + if (sockId >= 0) { + struct epoll_event event = {}; + epoll_ctl(protocol->epollFd, EPOLL_CTL_DEL, sockId, &event); + ::close(sockId); + sockId = -1; + } + events = 0xffffffff; + (prev ? prev->next : protocol->sockFirst) = next; + (next ? next->prev : protocol->sockLast ) = prev; + protocol = nullptr; + memset(this->address, 0, sizeof(address)); + addressSize = 0; + stateWanted = STATE_NONE; + stateLocal = STATE_NONE; + DebugMSG("closeWaiters"); + state = STATE_CLOSED; + while(closeWaiters > 0) + { closeCondition.notify_all(); std::this_thread::yield(); } + state = STATE_FINISHED; + DebugMSG("finished"); + } + } +} + + +void Server::handleEvents(unsigned int events) { + if (events & (EPOLLHUP | EPOLLERR)) { + if ( stateLocal >= STATE_OPEN + && stateLocal < STATE_CLOSING ) + close((bool)(events & EPOLLERR)); + return; } - - while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; } - - finishWaiters--; + if (events & EPOLLIN) { + if ( stateLocal >= STATE_OPEN + && stateLocal < STATE_CLOSE_REQ) + { + unsigned char address[MAX_ADDR_SIZE] = {}; + socklen_t len = sizeof(address); + int res = ::accept(sockId, (struct sockaddr*)address, &len); + assert(len <= sizeof(address)); + // the accepted descriptor stays ours until conn->open() succeeds + if (res >= 0) { + Connection *conn = onConnect(address, len); + if (!conn) { + ::close(res); + } else + if (!conn->open(*protocol, address, len, res, this)) { + ::close(res); + onDisconnect(conn, true); + } + } + } + } } -void Server::onOpeningError() { } -void Server::onOpen() { } + Connection* Server::onConnect(const void*, size_t) { return nullptr; } void Server::onDisconnect(Connection*, bool) { } -void
Server::onCloseReqested() { } -void Server::onClose(bool) { } diff --git a/server.h b/server.h index cd176a6..c61d70d 100644 --- a/server.h +++ b/server.h @@ -5,67 +5,30 @@ #include "connection.h" -class Server: public Root { - enum State: int { - STATE_NONE, - STATE_OPENING, - STATE_INITIALIZING, - STATE_RESOLVING, - STATE_OPEN, - STATE_CLOSE_REQ, - STATE_CLOSING_CONNECTIONS, - STATE_CLOSING, - STATE_CLOSED, - STATE_FINISHED - }; - +class Server: public Socket { private: friend class Protocol; + friend class Connection; - std::atomic state; - std::atomic stateWanted; - ReadProtector stateProtector; - std::atomic enqueued; - std::atomic error; - - std::atomic closeWaiters; - std::atomic finishWaiters; - std::condition_variable closeCondition; - - unsigned char address[MAX_ADDR_SIZE]; - size_t addressSize; - - Protocol *protocol; - Server *prev, *next; - Server *queueNext; Connection *connFirst, *connLast; - int sockId; - State stateLocal; - - void wantState(State state, bool error); - -public: - typedef QueueMISO Queue; - friend class QueueMISO; - friend class ChainFuncs; public: Server(); ~Server(); bool open(Protocol &protocol, void *address, size_t addressSize); - void closeReq(); - void close(bool error); - void closeWait(unsigned long long timeoutUs = 0, bool withReq = true); + +private: + void handleDisconnect(Connection &connection, bool error); + + void handleState() override; + void handleEvents(unsigned int events) override; protected: - virtual void onOpeningError(); - virtual void onOpen(); virtual Connection* onConnect(const void *address, size_t addressSize); virtual void onDisconnect(Connection *connection, bool error); - virtual void onCloseReqested(); - virtual void onClose(bool error); }; #endif + diff --git a/socket.cpp b/socket.cpp new file mode 100644 index 0000000..01a4186 --- /dev/null +++ b/socket.cpp @@ -0,0 +1,198 @@ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "socket.h" +#include
"protocol.h" + + +#define DebugMSG ShowDebugMSG + + + +Socket::Socket(): + state(STATE_NONE), + stateWanted(STATE_NONE), + enqueued(false), + error(false), + closeWaiters(0), + finishWaiters(0), + address(), + addressSize(), + protocol(), + prev(), next(), + queueNext(), + sockId(-1), + stateLocal(), + events(0xffffffff) + { } + + +Socket::~Socket() + { assert(state == STATE_DESTROYED); } + + +void Socket::destroy() { + assert(state == STATE_NONE || state == STATE_FINISHED); + ++finishWaiters; + closeWait(); + state = STATE_DESTROYED; + --finishWaiters; + while(finishWaiters > 0) + std::this_thread::yield(); +} + + +bool Socket::openBegin(Protocol &protocol, void *address, size_t addressSize) { + if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false; + + State s = state; + do { + if (s != STATE_NONE && s != STATE_FINISHED) + return false; + } while(!state.compare_exchange_weak(s, STATE_OPENING)); + if (s == STATE_FINISHED) + while(finishWaiters > 0) + std::this_thread::yield(); + if (state == STATE_DESTROYED) + return false; + + protocol.stateProtector.lock(); + State ps = protocol.state; + if (ps < STATE_INITIALIZING || ps >= STATE_CLOSE_REQ) { + state = STATE_NONE; + protocol.stateProtector.unlock(); + return false; + } + + this->protocol = &protocol; + memcpy(this->address, address, addressSize); + this->addressSize = addressSize; + // drop flags left by a previous session (wantState() ignores STATE_OPENING) + stateWanted = STATE_NONE; + error = false; + + return true; +} + + +void Socket::openEnd() { + state = STATE_INITIALIZING; + enqueue(); + protocol->stateProtector.unlock(); +} + + +void Socket::initSocket(int sockId) { + int tcp_nodelay = 1; + setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int)); + fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK); +} + + +void Socket::enqueue() { + if (!enqueued.exchange(true)) { + protocol->sockQueue.push(*this); + protocol->wakeup(); + } +} + + +void Socket::updateEvents(bool wantRead, bool wantWrite) { + unsigned int e = EPOLLERR | EPOLLRDHUP; + if (wantRead) e |= EPOLLIN; + if (wantWrite) e |= EPOLLOUT;
+ + DebugMSG("events: %08x", e); + unsigned int pe = events.exchange(e); + if (pe == e) + return; + DebugMSG("apply"); + + struct epoll_event event = {}; + event.data.ptr = this; + event.events = e; + epoll_ctl(protocol->epollFd, pe == 0xffffffff ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, sockId, &event); +} + + +void Socket::wantState(State state, bool error) { + ReadLock lock(stateProtector); + State s = this->state; + DebugMSG("%d, %d", s, state); + if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return; + DebugMSG("apply"); + + State sw = stateWanted; + do { + if (sw >= state) return; + } while(!stateWanted.compare_exchange_weak(sw, state)); + if (error) { DebugMSG("error"); error = false; this->error.compare_exchange_strong(error, true); } + enqueue(); +} + + +void Socket::closeReq() + { wantState(STATE_CLOSE_REQ, false); } +void Socket::close(bool error) + { wantState(STATE_CLOSING, error); } + + +void Socket::closeWait(unsigned long long timeoutUs, bool withReq, bool error) { + stateProtector.lock(); + State s = state; + if (s <= STATE_OPENING || s >= STATE_FINISHED) + { stateProtector.unlock(); return; } + + if (s == STATE_CLOSED) { + finishWaiters++; + stateProtector.unlock(); + } else { + if (withReq && s < STATE_CLOSE_REQ) closeReq(); + closeWaiters++; + stateProtector.unlock(); + + std::mutex fakeMutex; + { + std::unique_lock fakeLock(fakeMutex); + if (timeoutUs && s < STATE_CLOSING) { + std::chrono::steady_clock::time_point t + = std::chrono::steady_clock::now() + + std::chrono::microseconds(timeoutUs); + while(s < STATE_CLOSED) { + if (t <= std::chrono::steady_clock::now()) { close(error); break; } + closeCondition.wait_until(fakeLock, t); + s = state; + } + } + if (s < STATE_CLOSING) close(error); + while(s < STATE_CLOSED) { + closeCondition.wait(fakeLock); + s = state; + } + } + + finishWaiters++; + closeWaiters--; + } + + + while(s < STATE_FINISHED) { std::this_thread::yield(); s = state; } + + finishWaiters--; +} + + +void
Socket::onOpeningError() { } +void Socket::onOpen(const void*, size_t) { } +void Socket::onCloseReqested() { } +void Socket::onClose(bool) { } + + diff --git a/socket.h b/socket.h new file mode 100644 index 0000000..d29f1ef --- /dev/null +++ b/socket.h @@ -0,0 +1,109 @@ +#ifndef SOCKET_H +#define SOCKET_H + + +#include + +#include +#include + +#include "utils.h" + + +#define MAX_ADDR_SIZE 256 + + + +enum State: int { + STATE_NONE, + STATE_OPENING, + STATE_INITIALIZING, + STATE_RESOLVING, + STATE_CONNECTING, + STATE_OPEN, + STATE_CLOSE_REQ, + STATE_CLOSING_CONNECTIONS, + STATE_CLOSING, + STATE_CLOSED, + STATE_FINISHED, + STATE_DESTROYED +}; + + +class Protocol; +class Connection; +class Server; + + +class Socket { +private: + friend class Protocol; + friend class Connection; + friend class Server; + + std::atomic state; + std::atomic stateWanted; + ReadProtector stateProtector; + std::atomic enqueued; + std::atomic error; + + std::atomic closeWaiters; + std::atomic finishWaiters; + std::condition_variable closeCondition; + + unsigned char address[MAX_ADDR_SIZE]; + size_t addressSize; + + Protocol *protocol; + Socket *prev, *next; + Socket *queueNext; + int sockId; + State stateLocal; + + std::atomic events; + + struct OpenLock { + Socket &socket; + const bool success; + OpenLock(Socket &socket, Protocol &protocol, void *address, size_t addressSize): + socket(socket), success(socket.openBegin(protocol, address, addressSize)) { } + ~OpenLock() + { if (success) socket.openEnd(); } + }; + + bool openBegin(Protocol &protocol, void *address, size_t addressSize); + void openEnd(); + + void initSocket(int sockId); + void enqueue(); + void updateEvents(bool wantRead, bool wantWrite); + void wantState(State state, bool error); + + void destroy(); + +public: + typedef QueueMISO Queue; + friend class QueueMISO; + friend class ChainFuncs; + +public: + Socket(); + virtual ~Socket(); + + void closeReq(); + void close(bool error); + void closeWait(unsigned long long timeoutUs = 0,
bool withReq = true, bool error = true); + +private: + virtual void handleState() = 0; + virtual void handleEvents(unsigned int events) = 0; + +protected: + virtual void onOpeningError(); + virtual void onOpen(const void *address, size_t addressSize); + virtual void onCloseReqested(); + virtual void onClose(bool error); +}; + + +#endif