From 51b3f082e9ff07c0cfe983ae3aeef4a73a252212 Mon Sep 17 00:00:00 2001 From: Ivan Mahonin Date: Oct 15 2021 08:33:25 +0000 Subject: fix protocol --- diff --git a/main.cpp b/main.cpp index 4c4853b..1d6ee37 100644 --- a/main.cpp +++ b/main.cpp @@ -1,8 +1,90 @@ +#include +#include + +#include +#include + #include "protocol.h" -int main() { +class MyConnection: public Connection { +public: + const char *name; + Task readTask, writeTask; + char data[2][100]; + MyConnection(const char *name): + name(name), readTask(data[0], 16), writeTask(data[1], 16), data() + { strcpy(data[1], "01234567890123456789"); } +protected: + void onOpeningError() override + { printf("%s: onOpeningError\n", name); } + void onOpen(const void*, size_t) override + { printf("%s: onOpen\n", name); read(readTask); write(writeTask); } + bool onReadReady(Task &task, bool error) override + { printf("%s: onReadReady: %s, %d\n", name, (char*)task.data, error); return true; } + bool onWriteReady(Task &task, bool error) override + { printf("%s: onWriteReady: %d\n", name, error); write(task); return true; } + void onCloseReqested() override + { printf("%s: onCloseReqested\n", name); } + void onClose(bool error) override + { printf("%s: onClose: %d\n", name, (int)error); } +}; + + +class MyServer: public Server { +protected: + void onOpeningError() override + { printf("Server: onOpeningError\n"); } + void onOpen() override + { printf("Server: onOpen\n"); } + Connection* onConnect(const void*, size_t) override + { printf("Server: onConnect\n"); return new MyConnection("server connection"); } + void onDisconnect(Connection *connection, bool error) override + { printf("Server: onDisconnect: %d\n", error); delete connection; } + void onCloseReqested() override + { printf("Server: onCloseReqested\n"); } + void onClose(bool error) override + { printf("Server: onClose: %d\n", error); } +}; + + + +int main(int argc, char **) { + bool server = argc == 1; + + printf("init protocol\n"); + Protocol protocol; + protocol.open(); + printf("protocol open\n"); + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = 0x8002; + addr.sin_port = htons(1532); + + if (server) { + printf("init server\n"); + MyServer server; + server.open(protocol, &addr, sizeof(addr)); + printf("server open\n"); + fgetc(stdin); + printf("close server\n"); + server.close(false); + } else { + printf("init connection\n"); + MyConnection connection("client connection"); + connection.open(protocol, &addr, sizeof(addr)); + printf("connection open\n"); + fgetc(stdin); + printf("close connection\n"); + connection.close(false); + } + + printf("close protocol\n"); + protocol.closeWait(); + printf("protocol closed\n"); + return 0; } diff --git a/protocol.cpp b/protocol.cpp index ea121ad..e05cb31 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -15,16 +15,35 @@ #include "protocol.h" +#define DebugMSG ShowDebugMSG + + + +Protocol::Protocol(): + thread(), + epollFd(-1), + eventFd(-1), + state(STATE_NONE), + stateWanted(STATE_NONE), + closeWaiters(0), + finishWaiters(0), + connFirst(), connLast(), + srvFirst(), srvLast() + { } + Protocol::~Protocol() { - assert(state == STATE_NONE || state == STATE_CLOSED); + assert(state == STATE_NONE || state == STATE_FINISHED); closeWait(); while(finishWaiters > 0) std::this_thread::yield(); + DebugMSG(); if (thread) thread->join(); + DebugMSG("deleted"); } bool Protocol::open() { + DebugMSG(); State s = state; do { if (s != STATE_NONE && s != STATE_FINISHED) return false; @@ -33,6 +52,8 @@ bool Protocol::open() { if (thread) thread->join(); thread = new std::thread(&Protocol::threadRun, this); + state = STATE_INITIALIZING; + DebugMSG(); return true; } @@ -40,23 +61,26 @@ bool Protocol::open() { void Protocol::wantState(State state) { ReadLock lock(stateProtector); State s = this->state; + DebugMSG("%d, %d", s, state); if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return; + DebugMSG("apply"); State sw = stateWanted; do { - if (sw <= STATE_OPENING || sw >= state) return; + if (sw >= state) return; } while(!stateWanted.compare_exchange_weak(sw, state)); wakeup(); } void Protocol::closeReq() - { wantState(STATE_CLOSE_REQ); } + { DebugMSG(); wantState(STATE_CLOSE_REQ); } void Protocol::close() - { wantState(STATE_CLOSING); } + { DebugMSG(); wantState(STATE_CLOSING); } void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) { + DebugMSG(); stateProtector.lock(); State s = state; if (s <= STATE_OPENING || s >= STATE_FINISHED) @@ -80,15 +104,18 @@ void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) { + std::chrono::microseconds(timeoutUs); while(s < STATE_CLOSED) { if (t <= std::chrono::steady_clock::now()) { close(); break; } + DebugMSG("wait_until, %d", s); closeCondition.wait_until(fakeLock, t); s = state; } } if (s < STATE_CLOSING) close(); while(s < STATE_CLOSED) { + DebugMSG("wait, %d", s); closeCondition.wait(fakeLock); s = state; } + DebugMSG("awaiten"); } finishWaiters++; @@ -99,6 +126,7 @@ void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) { while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; } finishWaiters--; + DebugMSG(); } @@ -368,6 +396,7 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { void Protocol::threadSrvQueue() { while(srvQueue.peek()) { + DebugMSG(); Server &server = srvQueue.pop(); server.enqueued = false; @@ -490,6 +519,7 @@ void Protocol::threadSrvEvents(Server &server, unsigned int events) { void Protocol::threadRun() { + DebugMSG(); epollFd = ::epoll_create(32); eventFd = ::eventfd(0, EFD_NONBLOCK); @@ -521,6 +551,7 @@ void Protocol::threadRun() { State sw = stateWanted; if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) { + DebugMSG("to closse req"); state = stateLocal = STATE_CLOSE_REQ; stateProtector.wait(); for(Server *server = srvFirst; server; server = server->next) @@ -530,7 +561,8 @@ void Protocol::threadRun() { conn->closeReq(); } else if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) { - state = STATE_CLOSING; + DebugMSG("to closing"); + state = stateLocal = STATE_CLOSING; stateProtector.wait(); for(Server *server = srvFirst; server; server = server->next) server->close(true); @@ -545,19 +577,29 @@ void Protocol::threadRun() { } if (stateLocal == STATE_CLOSING) { - assert(!srvFirst && !connFirst); + assert(!connFirst && !srvFirst); break; } - count = epoll_wait(epollFd, events, maxCount, 10000); + DebugMSG("epoll_wait: %d, %d", stateLocal, sw); + count = epoll_wait(epollFd, events, maxCount, 1000); } + DebugMSG("closeWaiters"); + state = STATE_FINISHED; while(closeWaiters > 0) { closeCondition.notify_all(); std::this_thread::yield(); } ::close(epollFd); ::close(eventFd); + epollFd = -1; + eventFd = -1; + assert(!connFirst && !srvFirst); + assert(!connQueue.peek() && !srvQueue.peek()); + stateWanted = STATE_NONE; + stateLocal = STATE_NONE; state = STATE_FINISHED; + DebugMSG("finished"); } diff --git a/protocol.h b/protocol.h index 3f115b1..7154e45 100644 --- a/protocol.h +++ b/protocol.h @@ -15,6 +15,7 @@ public: enum State { STATE_NONE, STATE_OPENING, + STATE_INITIALIZING, STATE_OPEN, STATE_CLOSE_REQ, STATE_CLOSING, diff --git a/server.cpp b/server.cpp index af80b4b..1132d22 100644 --- a/server.cpp +++ b/server.cpp @@ -13,6 +13,9 @@ Server::Server(): state(STATE_NONE), stateWanted(STATE_NONE), enqueued(false), + error(false), + closeWaiters(0), + finishWaiters(0), address(), addressSize(), protocol(), @@ -124,7 +127,7 @@ void Server::closeWait(unsigned long long timeoutUs, bool withReq) { void Server::onOpeningError() { } void Server::onOpen() { } -Connection* Server::onConnect(void*, size_t) { return nullptr; } +Connection* Server::onConnect(const void*, size_t) { return nullptr; } void Server::onDisconnect(Connection*, bool) { } void Server::onCloseReqested() { } void Server::onClose(bool) { } diff --git a/server.h b/server.h index 4742c27..cd176a6 100644 --- a/server.h +++ b/server.h @@ -61,7 +61,7 @@ public: protected: virtual void onOpeningError(); virtual void onOpen(); - virtual Connection* onConnect(void *address, size_t addressSize); + virtual Connection* onConnect(const void *address, size_t addressSize); virtual void onDisconnect(Connection *connection, bool error); virtual void onCloseReqested(); virtual void onClose(bool error); diff --git a/utils.h b/utils.h index e759c1e..338144c 100644 --- a/utils.h +++ b/utils.h @@ -3,9 +3,25 @@ #include +#include #include #include +#include + + +#define ShowDebugMSG(...) do { \ + printf("DEBUG %10ld %s:%s:%d:", \ + std::chrono::duration_cast( \ + std::chrono::steady_clock::now().time_since_epoch() ).count(), \ + __FILE__, __func__, __LINE__); \ + printf(" " __VA_ARGS__); \ + printf("\n"); \ + fflush(stdout); \ +} while(0) + +#define HideDebugMSG(...) do {} while(0) + class ReadProtector {