Patch summary (free-form notes placed ahead of the first diff header)
=====================================================================

connection.cpp
  * ~Connection() now asserts that state is STATE_NONE or STATE_FINISHED
    (previously STATE_NONE or STATE_CLOSED).
  * open() now proceeds when the protocol state is in
    [STATE_INITIALIZING, STATE_CLOSE_REQ) instead of
    [STATE_OPEN, STATE_CLOSE_REQ).

main.cpp
  * Adds a myprintf(...) macro that calls printf and then fflush(stdout),
    and replaces the printf calls shown in the hunks with it.
  * The socket address becomes htonl(INADDR_LOOPBACK) instead of the raw
    value 0x8002.
  * The return values of server.open() / connection.open() are now checked
    and a "cannot open ..." message is printed on failure.
  * The fgetc(stdin) pause is replaced by a 10 second sleep, and
    close(false) is replaced by closeWait().

protocol.cpp
  * DebugMSG is switched to HideDebugMSG. A new SrvDebugMSG (ShowDebugMSG)
    traces each state transition in Protocol::threadSrvQueue().

server.cpp
  * Defines DebugMSG as ShowDebugMSG and logs the reason for each early
    return in Server::open().
  * ~Server() now asserts STATE_NONE or STATE_FINISHED. open() accepts the
    same protocol-state range as Connection::open().
  * wantState() no longer returns early when stateWanted <= STATE_OPENING.

Review notes
  * myprintf is defined as "do { ... } while(0);". The trailing semicolon
    defeats the do/while(0) idiom: "if (c) myprintf(...); else ..." will
    not compile. Drop the semicolon after while(0).
  * NOTE(review): main.cpp now uses std::this_thread::sleep_for and
    std::chrono::seconds, but the hunks shown add no <thread>/<chrono>
    include. Confirm they arrive via an existing header.
  * Server::open() prints "protocol is closed" also when the protocol state
    is below STATE_INITIALIZING (not yet initialized), so that message can
    be misleading.
  * NOTE(review): with the "sw <= STATE_OPENING" guard removed, wantState()
    can record a request and reach protocol->srvQueue.push(*this) while
    stateWanted is STATE_NONE. Confirm that `protocol` is always set by then
    (e.g. for a server that was never opened).
  * In this copy of the patch the line breaks are collapsed, so the hunk
    bodies cannot be applied as-is. Regenerate the patch from the
    repository (git diff / git format-patch) before applying it.

diff --git a/connection.cpp b/connection.cpp index ce499dd..f7f5735 100644 --- a/connection.cpp +++ b/connection.cpp @@ -28,7 +28,7 @@ Connection::Connection(): Connection::~Connection() { - assert(state == STATE_NONE || state == STATE_CLOSED); + assert(state == STATE_NONE || state == STATE_FINISHED); closeWait(); while(finishWaiters > 0) std::this_thread::yield(); } @@ -45,7 +45,7 @@ bool Connection::open(Protocol &protocol, void *address, size_t addressSize, int ReadLock lock(protocol.stateProtector); Protocol::State ps = protocol.state; - if (ps < Protocol::STATE_OPEN || ps >= Protocol::STATE_CLOSE_REQ) + if (ps < Protocol::STATE_INITIALIZING || ps >= Protocol::STATE_CLOSE_REQ) { state = STATE_NONE; return false; } memcpy(this->address, address, addressSize); diff --git a/main.cpp b/main.cpp index 1d6ee37..958f1aa 100644 --- a/main.cpp +++ b/main.cpp @@ -8,6 +8,9 @@ #include "protocol.h" +#define myprintf(...) do { printf(__VA_ARGS__); fflush(stdout); } while(0); + + class MyConnection: public Connection { public: const char *name; @@ -18,34 +21,34 @@ public: { strcpy(data[1], "01234567890123456789"); } protected: void onOpeningError() override - { printf("%s: onOpeningError\n", name); } + { myprintf("%s: onOpeningError\n", name); } void onOpen(const void*, size_t) override - { printf("%s: onOpen\n", name); read(readTask); write(writeTask); } + { myprintf("%s: onOpen\n", name); read(readTask); write(writeTask); } bool onReadReady(Task &task, bool error) override - { printf("%s: onReadReady: %s, %d\n", name, (char*)task.data, error); return true; } + { myprintf("%s: onReadReady: %s, %d\n", name, (char*)task.data, error); return true; } bool onWriteReady(Task &task, bool error) override - { printf("%s: onWriteReady: %d\n", name, error); write(task); return true; } + { myprintf("%s: onWriteReady: %d\n", name, error); write(task); return true; } void onCloseReqested() override - { printf("%s: onCloseReqested\n", name); } + { myprintf("%s: onCloseReqested\n", name); 
} void onClose(bool error) override - { printf("%s: onClose: %d\n", name, (int)error); } + { myprintf("%s: onClose: %d\n", name, (int)error); } }; class MyServer: public Server { protected: void onOpeningError() override - { printf("Server: onOpeningError\n"); } + { myprintf("Server: onOpeningError\n"); } void onOpen() override - { printf("Server: onOpen\n"); } + { myprintf("Server: onOpen\n"); } Connection* onConnect(const void*, size_t) override - { printf("Server: onConnect\n"); return new MyConnection("server connection"); } + { myprintf("Server: onConnect\n"); return new MyConnection("server connection"); } void onDisconnect(Connection *connection, bool error) override - { printf("Server: onDisconnect: %d\n", error); delete connection; } + { myprintf("Server: onDisconnect: %d\n", error); delete connection; } void onCloseReqested() override - { printf("Server: onCloseReqested\n"); } + { myprintf("Server: onCloseReqested\n"); } void onClose(bool error) override - { printf("Server: onClose: %d\n", error); } + { myprintf("Server: onClose: %d\n", error); } }; @@ -53,37 +56,47 @@ protected: int main(int argc, char **) { bool server = argc == 1; - printf("init protocol\n"); + myprintf("init protocol\n"); Protocol protocol; protocol.open(); - printf("protocol open\n"); + myprintf("protocol open\n"); struct sockaddr_in addr = {}; addr.sin_family = AF_INET; - addr.sin_addr.s_addr = 0x8002; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); addr.sin_port = htons(1532); if (server) { - printf("init server\n"); + myprintf("init server\n"); MyServer server; - server.open(protocol, &addr, sizeof(addr)); - printf("server open\n"); - fgetc(stdin); - printf("close server\n"); - server.close(false); + if (!server.open(protocol, &addr, sizeof(addr))) { + myprintf("cannot open server\n"); + } else { + myprintf("server open\n"); + myprintf("wait 10s\n"); + std::this_thread::sleep_for( std::chrono::seconds(10) ); + //fgetc(stdin); + myprintf("close server\n"); + server.closeWait(); 
} } else { - printf("init connection\n"); + myprintf("init connection\n"); MyConnection connection("client connection"); - connection.open(protocol, &addr, sizeof(addr)); - printf("connection open\n"); - fgetc(stdin); - printf("close connection\n"); - connection.close(false); + if (!connection.open(protocol, &addr, sizeof(addr))) { + myprintf("cannot open connection\n"); + } else { + myprintf("connection open\n"); + myprintf("wait 10s\n"); + std::this_thread::sleep_for( std::chrono::seconds(10) ); + //fgetc(stdin); + myprintf("close connection\n"); + connection.closeWait(); + } } - printf("close protocol\n"); + myprintf("close protocol\n"); protocol.closeWait(); - printf("protocol closed\n"); + myprintf("protocol closed\n"); return 0; } diff --git a/protocol.cpp b/protocol.cpp index e05cb31..5e15bb7 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -15,7 +15,8 @@ #include "protocol.h" -#define DebugMSG ShowDebugMSG +#define DebugMSG HideDebugMSG +#define SrvDebugMSG ShowDebugMSG @@ -396,11 +397,12 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { void Protocol::threadSrvQueue() { while(srvQueue.peek()) { - DebugMSG(); + SrvDebugMSG(); Server &server = srvQueue.pop(); server.enqueued = false; if (server.stateLocal == Server::STATE_NONE) { + SrvDebugMSG("none"); int sockId = -1; if (server.addressSize >= sizeof(sa_family_t)) { struct sockaddr *addr = (struct sockaddr*)server.address; @@ -410,13 +412,17 @@ void Protocol::threadSrvQueue() { if ( 0 != ::bind(sockId, addr, (int)server.addressSize) || 0 != ::listen(sockId, 128) ) { + SrvDebugMSG("cannot bind/listen"); ::close(sockId); sockId = -1; } + } else { + SrvDebugMSG("cannot create socket"); } } if (sockId >= 0) { + SrvDebugMSG("to open"); server.protocol = this; server.prev = srvLast; assert(!server.next); @@ -432,17 +438,20 @@ void Protocol::threadSrvQueue() { server.state = server.stateLocal = Server::STATE_OPEN; server.stateProtector.wait(); } else { + SrvDebugMSG("opening error -> 
to closed"); server.onOpeningError(); server.state = Server::STATE_CLOSED; while(server.closeWaiters > 0) { server.closeCondition.notify_all(); std::this_thread::yield(); } server.state = Server::STATE_FINISHED; + SrvDebugMSG("opening error -> finished"); } } else { Server::State stateWanted = server.stateWanted; if ( server.stateLocal == Server::STATE_OPEN && stateWanted == Server::STATE_CLOSE_REQ ) { + SrvDebugMSG("to close req"); for(Connection *conn = server.connFirst; conn; conn = conn->srvNext) conn->closeReq(); struct epoll_event event = {}; @@ -455,6 +464,7 @@ void Protocol::threadSrvQueue() { || server.stateLocal == Server::STATE_CLOSE_REQ ) && stateWanted == Server::STATE_CLOSING_CONNECTIONS ) { + SrvDebugMSG("to closing connections"); if (server.stateLocal != Server::STATE_CLOSE_REQ) { struct epoll_event event = {}; epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); @@ -464,6 +474,7 @@ void Protocol::threadSrvQueue() { conn->close(true); server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS; } else { + SrvDebugMSG("to closing"); server.state = server.stateLocal = Server::STATE_CLOSING; server.stateProtector.wait(); server.onClose(false); @@ -471,6 +482,7 @@ void Protocol::threadSrvQueue() { } if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) { + SrvDebugMSG("to closed"); server.state = Server::STATE_CLOSED; (server.prev ? server.prev->next : srvFirst) = server.next; (server.next ? 
server.next->prev : srvLast) = server.prev; @@ -479,9 +491,11 @@ void Protocol::threadSrvQueue() { server.addressSize = 0; server.stateWanted = Server::STATE_NONE; server.stateLocal = Server::STATE_NONE; + SrvDebugMSG("closeWaiters"); while(server.closeWaiters > 0) { server.closeCondition.notify_all(); std::this_thread::yield(); } server.state = Server::STATE_FINISHED; + SrvDebugMSG("finished"); } } } diff --git a/server.cpp b/server.cpp index 1132d22..e36344c 100644 --- a/server.cpp +++ b/server.cpp @@ -9,6 +9,9 @@ #include "server.h" +#define DebugMSG ShowDebugMSG + + Server::Server(): state(STATE_NONE), stateWanted(STATE_NONE), @@ -28,25 +31,27 @@ Server::Server(): Server::~Server() { - assert(state == STATE_NONE || state == STATE_CLOSED); + assert(state == STATE_NONE || state == STATE_FINISHED); closeWait(); while(finishWaiters > 0) std::this_thread::yield(); } bool Server::open(Protocol &protocol, void *address, size_t addressSize) { - if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false; + if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) + { DebugMSG("bad args"); return false; } State s = state; do { - if (s != STATE_NONE && s != STATE_FINISHED) return false; + if (s != STATE_NONE && s != STATE_FINISHED) + { DebugMSG("already open or in closing state"); return false; } } while(!state.compare_exchange_weak(s, STATE_OPENING)); if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield(); ReadLock lock(protocol.stateProtector); Protocol::State ps = protocol.state; - if (ps < Protocol::STATE_OPEN || ps >= Protocol::STATE_CLOSE_REQ) - { state = STATE_NONE; return false; } + if (ps < Protocol::STATE_INITIALIZING || ps >= Protocol::STATE_CLOSE_REQ) + { DebugMSG("protocol is closed"); state = STATE_NONE; return false; } memcpy(this->address, address, addressSize); this->addressSize = addressSize; @@ -66,7 +71,7 @@ void Server::wantState(State state, bool error) { State sw = stateWanted; do { - if (sw <= STATE_OPENING || 
sw >= state) return; + if (sw >= state) return; } while(!stateWanted.compare_exchange_weak(sw, state)); if (error) { error = false; this->error.compare_exchange_strong(error, true); } if (!enqueued.exchange(true)) protocol->srvQueue.push(*this);