diff --git a/connection.cpp b/connection.cpp index f7f5735..47b098d 100644 --- a/connection.cpp +++ b/connection.cpp @@ -9,6 +9,10 @@ #include "connection.h" +#define DebugMSG HideDebugMSG + + + Connection::Connection(): state(STATE_NONE), stateWanted(STATE_NONE), @@ -23,7 +27,8 @@ Connection::Connection(): srvPrev(), srvNext(), queueNext(), sockId(-1), - stateLocal() + stateLocal(), + errorLocal() { } @@ -64,13 +69,15 @@ bool Connection::open(Protocol &protocol, void *address, size_t addressSize, int void Connection::wantState(State state, bool error) { ReadLock lock(stateProtector); State s = this->state; + DebugMSG("%d, %d", s, state); if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return; + DebugMSG("apply"); State sw = stateWanted; do { - if (sw <= STATE_OPENING || sw >= state) return; + if (sw >= state) return; } while(!stateWanted.compare_exchange_weak(sw, state)); - if (error) { error = false; this->error.compare_exchange_strong(error, true); } + if (error) { DebugMSG("error"); error = false; this->error.compare_exchange_strong(error, true); } if (!enqueued.exchange(true)) protocol->connQueue.push(*this); protocol->wakeup(); } @@ -128,19 +135,23 @@ void Connection::closeWait(unsigned long long timeoutUs, bool withReq, bool erro bool Connection::read(Task &task) { + DebugMSG(); ReadLock lock(stateProtector); State s = state; if (s < STATE_OPEN || s >= STATE_CLOSING) return false; - writeQueue.push(task); + DebugMSG(); + readQueue.push(task); protocol->updateEvents(*this); return true; } bool Connection::write(Task &task) { + DebugMSG(); ReadLock lock(stateProtector); State s = state; if (s < STATE_OPEN || s >= STATE_CLOSING) return false; + DebugMSG(); writeQueue.push(task); protocol->updateEvents(*this); return true; diff --git a/connection.h b/connection.h index 0ceee1c..5d8b58b 100644 --- a/connection.h +++ b/connection.h @@ -81,6 +81,7 @@ private: Connection *queueNext; int sockId; State stateLocal; + bool errorLocal; std::atomic events; 
TaskQueue readQueue; @@ -97,7 +98,7 @@ public: Connection(); virtual ~Connection(); - bool open(Protocol &protocol, void *address, size_t addressSize, int sockId = 0, Server *server = nullptr); + bool open(Protocol &protocol, void *address, size_t addressSize, int sockId = -1, Server *server = nullptr); void closeReq(); void close(bool error); void closeWait(unsigned long long timeoutUs = 0, bool withReq = true, bool error = true); diff --git a/main.cpp b/main.cpp index 958f1aa..e19d939 100644 --- a/main.cpp +++ b/main.cpp @@ -18,16 +18,28 @@ public: char data[2][100]; MyConnection(const char *name): name(name), readTask(data[0], 16), writeTask(data[1], 16), data() - { strcpy(data[1], "01234567890123456789"); } + { strcpy(data[1], "01234\n\n\n567890123456789"); } protected: void onOpeningError() override { myprintf("%s: onOpeningError\n", name); } void onOpen(const void*, size_t) override { myprintf("%s: onOpen\n", name); read(readTask); write(writeTask); } - bool onReadReady(Task &task, bool error) override - { myprintf("%s: onReadReady: %s, %d\n", name, (char*)task.data, error); return true; } - bool onWriteReady(Task &task, bool error) override - { myprintf("%s: onWriteReady: %d\n", name, error); write(task); return true; } + + bool onReadReady(Task &task, bool error) override { + data[0][task.completion] = 0; + myprintf("%s: onReadReady: %lu, %lu, %s, %d\n", name, task.completion, task.size, (char*)task.data, error); + task.completion = 0; + read(task); + return true; + } + + bool onWriteReady(Task &task, bool error) override { + myprintf("%s: onWriteReady: %d\n", name, error); + task.completion = 0; + write(task); + return true; + } + void onCloseReqested() override { myprintf("%s: onCloseReqested\n", name); } void onClose(bool error) override @@ -65,6 +77,7 @@ int main(int argc, char **) { addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); addr.sin_port = htons(1532); + addr.sin_port = htons(80); if (server) { myprintf("init server\n"); 
@@ -87,7 +100,7 @@ int main(int argc, char **) { } else { myprintf("connection open\n"); - myprintf("wait 10s\n"); - std::this_thread::sleep_for( std::chrono::seconds(10) ); + myprintf("wait 2s\n"); + std::this_thread::sleep_for( std::chrono::seconds(2) ); //fgetc(stdin); myprintf("close connection\n"); connection.closeWait(); diff --git a/protocol.cpp b/protocol.cpp index 5e15bb7..ed20e23 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -15,8 +15,9 @@ #include "protocol.h" -#define DebugMSG HideDebugMSG -#define SrvDebugMSG ShowDebugMSG +#define DebugMSG HideDebugMSG +#define SrvDebugMSG HideDebugMSG +#define ConnDebugMSG HideDebugMSG @@ -138,10 +139,13 @@ void Protocol::wakeup() { void Protocol::updateEvents(Connection &connection) { - unsigned int events = 0; + unsigned int events = EPOLLERR; if (connection.readQueue.count()) events |= EPOLLIN; if (connection.writeQueue.count()) events |= EPOLLOUT; + + ConnDebugMSG("events: %08x", events); if (connection.events.exchange(events) == events) return; + ConnDebugMSG("apply"); struct epoll_event event = {}; event.data.ptr = &connection; @@ -159,33 +163,47 @@ void Protocol::initSocket(int sockId) { void Protocol::threadConnQueue() { while(connQueue.peek()) { + ConnDebugMSG(); Connection &connection = connQueue.pop(); connection.enqueued = false; if (connection.stateLocal == Connection::STATE_NONE) { + ConnDebugMSG("none"); int sockId = connection.sockId; bool connected = false; if (sockId >= 0) { - initSocket(sockId); + ConnDebugMSG("accepting"); + initSocket(sockId); connected = true; /* a socket handed to open() (accepted by a Server) is already connected; it must not take the CONNECTING path, which asserts !connection.server */ } else if (connection.addressSize >= sizeof(sa_family_t)) { + ConnDebugMSG("connecting"); struct sockaddr *addr = (struct sockaddr*)connection.address; sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); if (sockId >= 0) { initSocket(sockId); - int res = ::connect(sockId, addr, (int)connection.addressSize); - if (res == 0) { + if (0 == ::connect(sockId, addr, (int)connection.addressSize)) connected = true; - } else - if (res != EINPROGRESS) { - ::close(sockId); - sockId = -1; - } + } else { + 
ConnDebugMSG("none -> cannot create socket"); + } + } + + struct epoll_event event = {}; + if (sockId >= 0) { + event.data.ptr = &connection; + event.events |= EPOLLERR; + if (!connected) event.events |= EPOLLOUT; + if (0 != epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event)) { + ConnDebugMSG("epoll_ctl(%d, ADD, %d, %08x) errno: %d", epollFd, sockId, event.events, errno); + ::close(sockId); + connection.sockId = sockId = -1; + assert(false); } } if (sockId >= 0) { + ConnDebugMSG("opening"); connection.protocol = this; connection.prev = connLast; assert(!connection.next); @@ -197,43 +215,57 @@ void Protocol::threadConnQueue() { } connection.sockId = sockId; - struct epoll_event event = {}; - event.data.ptr = &connection; + connection.events = event.events; /* cache exactly what was registered above (EPOLLOUT included while connecting) so updateEvents() compares against the real epoll registration in both branches */ if (connected) { - epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); - connection.onOpen(connection.address, connection.addressSize); + ConnDebugMSG("to open"); connection.state = connection.stateLocal = Connection::STATE_OPEN; connection.stateProtector.wait(); + connection.onOpen(connection.address, connection.addressSize); } else { - event.events = EPOLLOUT; - epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); + assert(!connection.server); + ConnDebugMSG("to connecting"); connection.state = connection.stateLocal = Connection::STATE_CONNECTING; connection.stateProtector.wait(); } } else { - connection.sockId = -1; + ConnDebugMSG("none -> to closed"); + Server *server = connection.server; + assert(connected || !server); connection.server = nullptr; + connection.sockId = -1; connection.onOpeningError(); connection.state = Connection::STATE_CLOSED; while(connection.closeWaiters > 0) { connection.closeCondition.notify_all(); std::this_thread::yield(); } connection.state = Connection::STATE_FINISHED; + ConnDebugMSG("none -> to finished"); + if (server) threadSrvDisconnect(*server, connection, true); } } else { Connection::State stateWanted = connection.stateWanted; - if ( connection.stateLocal >= 
Connection::STATE_CONNECTING - && connection.stateLocal <= Connection::STATE_OPEN + if ( connection.stateLocal == Connection::STATE_CONNECTING + && stateWanted == Connection::STATE_CLOSING ) + { + ConnDebugMSG("connecting -> to closing"); + connection.onOpeningError(); + connection.state = connection.stateLocal = Connection::STATE_CLOSING; + connection.stateProtector.wait(); + } + + if ( connection.stateLocal == Connection::STATE_OPEN && stateWanted == Connection::STATE_CLOSE_REQ ) { + ConnDebugMSG("to close req"); connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ; connection.stateProtector.wait(); connection.onCloseReqested(); } else - if ( connection.stateLocal >= Connection::STATE_CONNECTING + if ( connection.stateLocal >= Connection::STATE_OPEN && connection.stateLocal <= Connection::STATE_CLOSE_REQ && stateWanted == Connection::STATE_CLOSING ) { + ConnDebugMSG("to closing"); connection.state = connection.stateLocal = Connection::STATE_CLOSING; connection.stateProtector.wait(); while(Connection::Task *task = connection.readQueue.peek()) { @@ -246,18 +278,28 @@ void Protocol::threadConnQueue() { if (!connection.onWriteReady(*task, true)) assert(false); } - connection.onClose(false); + connection.errorLocal = connection.error; + connection.onClose(connection.errorLocal); } if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) { - if (connection.server) { + ConnDebugMSG("to closed"); + Server *server = connection.server; + if (server) { (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext; (connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev; connection.server = nullptr; } + bool error = connection.errorLocal; + connection.errorLocal = false; (connection.prev ? connection.prev->next : connFirst) = connection.next; (connection.next ? 
connection.next->prev : connLast) = connection.prev; - connection.sockId = -1; + if (connection.sockId >= 0) { + struct epoll_event event = {}; + epoll_ctl(epollFd, EPOLL_CTL_DEL, connection.sockId, &event); + ::close(connection.sockId); + connection.sockId = -1; + } connection.protocol = nullptr; memset(connection.address, 0, sizeof(connection.address)); connection.addressSize = 0; @@ -267,6 +309,8 @@ void Protocol::threadConnQueue() { while(connection.closeWaiters > 0) { connection.closeCondition.notify_all(); std::this_thread::yield(); } connection.state = Connection::STATE_FINISHED; + ConnDebugMSG("to finished"); + if (server) threadSrvDisconnect(*server, connection, error); } } } @@ -277,31 +321,17 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { bool needUpdate = false; if (connection.stateLocal == Connection::STATE_CONNECTING) { - if (events & EPOLLHUP) { - connection.onOpeningError(); - if (connection.server) { - (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext; - (connection.srvNext ? 
connection.next->prev : connLast) = connection.prev; - connection.sockId = -1; - connection.protocol = nullptr; - memset(connection.address, 0, sizeof(connection.address)); - connection.addressSize = 0; - connection.stateWanted = Connection::STATE_NONE; - connection.stateLocal = Connection::STATE_NONE; - connection.state = Connection::STATE_CLOSED; - while(connection.closeWaiters > 0) - { connection.closeCondition.notify_all(); std::this_thread::yield(); } - connection.state = Connection::STATE_FINISHED; + ConnDebugMSG("connecting"); + if (events & (EPOLLHUP | EPOLLERR)) { + connection.close(true); + return; } if (events & EPOLLOUT) { - connection.onOpen(connection.address, connection.addressSize); + ConnDebugMSG("connecting -> to open"); connection.state = connection.stateLocal = Connection::STATE_OPEN; connection.stateProtector.wait(); + connection.onOpen(connection.address, connection.addressSize); needUpdate = true; } } @@ -311,6 +341,7 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { return; if (events & EPOLLIN) { + ConnDebugMSG("input"); bool ready = true; int pops = 0; while(ready) { @@ -327,16 +358,16 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { connection.sockId, (unsigned char*)task->data + task->completion, (int)size, MSG_DONTWAIT ); - if (res == EAGAIN) { + if (res == EAGAIN || res == 0) { ready = false; break; } else - if (res <= 0) { + if (res < 0) { connection.close(true); return; } else { - completion += size; - task->completion += size; + completion += res; + task->completion += res; } } @@ -348,14 +379,17 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { } } - if (events & EPOLLHUP) { + if (events & (EPOLLHUP | EPOLLERR)) { + ConnDebugMSG("hup"); connection.close((bool)(events & EPOLLERR)); return; } if (events & EPOLLOUT) { + ConnDebugMSG("output"); bool ready = true; int pops = 0; + int count = connection.writeQueue.count() + 2; while(ready) { 
Connection::Task *task = connection.writeQueue.peek(); if (!task) { @@ -370,16 +404,16 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { connection.sockId, (const unsigned char*)task->data + task->completion, - (int)size, MSG_DONTWAIT ); + (int)size, MSG_DONTWAIT | MSG_NOSIGNAL ); /* get EPIPE back instead of SIGPIPE, whose default action terminates the process */ - if (res == EAGAIN) { + if (res == 0 || (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK))) { /* send() reports would-block as -1 with errno set; 0 bytes sent means no progress either */ ready = false; break; } else - if (res <= 0) { + if (res < 0) { connection.close(true); return; } else { - completion += size; - task->completion += size; + completion += res; + task->completion += res; } } @@ -387,6 +421,7 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { connection.writeQueue.pop(); if (!connection.onWriteReady(*task, false)) connection.writeQueue.unpop(*task); else ++pops; + if (--count <= 0) ready = false; } } } @@ -395,6 +430,18 @@ void Protocol::threadConnEvents(Connection &connection, unsigned int events) { } +void Protocol::threadSrvDisconnect(Server& server, Connection& connection, bool error) { + server.onDisconnect(&connection, error); + if (!server.connFirst && server.stateLocal == Server::STATE_CLOSING_CONNECTIONS) { + SrvDebugMSG("to closing"); + server.state = server.stateLocal = Server::STATE_CLOSING; + server.stateProtector.wait(); + server.onClose(server.error); + if (!server.enqueued.exchange(true)) srvQueue.push(server); + } +} + + void Protocol::threadSrvQueue() { while(srvQueue.peek()) { SrvDebugMSG(); @@ -431,12 +478,12 @@ void Protocol::threadSrvQueue() { struct epoll_event event = {}; event.data.ptr = &server; - event.events = EPOLLIN; + event.events = EPOLLIN | EPOLLERR; epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); - server.onOpen(); server.state = server.stateLocal = Server::STATE_OPEN; server.stateProtector.wait(); + server.onOpen(); } else { SrvDebugMSG("opening error -> to closed"); server.onOpeningError(); @@ -477,13 +524,18 @@ void Protocol::threadSrvQueue() { SrvDebugMSG("to closing"); server.state = server.stateLocal = Server::STATE_CLOSING; 
server.stateProtector.wait(); - server.onClose(false); + server.onClose(server.error); } } if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) { SrvDebugMSG("to closed"); - server.state = Server::STATE_CLOSED; + if (server.sockId >= 0) { + struct epoll_event event = {}; + epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); + ::close(server.sockId); + server.sockId = -1; + } (server.prev ? server.prev->next : srvFirst) = server.next; (server.next ? server.next->prev : srvLast) = server.prev; server.protocol = nullptr; @@ -492,6 +544,7 @@ void Protocol::threadSrvQueue() { server.stateWanted = Server::STATE_NONE; server.stateLocal = Server::STATE_NONE; SrvDebugMSG("closeWaiters"); + server.state = Server::STATE_CLOSED; while(server.closeWaiters > 0) { server.closeCondition.notify_all(); std::this_thread::yield(); } server.state = Server::STATE_FINISHED; @@ -503,7 +556,7 @@ void Protocol::threadSrvQueue() { void Protocol::threadSrvEvents(Server &server, unsigned int events) { - if (events & EPOLLHUP) { + if (events & (EPOLLHUP | EPOLLERR)) { if ( server.stateLocal >= Server::STATE_OPEN && server.stateLocal < Server::STATE_CLOSING ) server.close((bool)(events & EPOLLERR)); @@ -536,6 +589,8 @@ void Protocol::threadRun() { DebugMSG(); epollFd = ::epoll_create(32); eventFd = ::eventfd(0, EFD_NONBLOCK); + assert(epollFd >= 0); + assert(eventFd >= 0); struct epoll_event event = {}; event.events = EPOLLIN; @@ -595,12 +650,12 @@ void Protocol::threadRun() { break; } - DebugMSG("epoll_wait: %d, %d", stateLocal, sw); count = epoll_wait(epollFd, events, maxCount, 1000); + DebugMSG("epoll count: %d", count); } DebugMSG("closeWaiters"); - state = STATE_FINISHED; + state = STATE_CLOSED; while(closeWaiters > 0) { closeCondition.notify_all(); std::this_thread::yield(); } diff --git a/protocol.h b/protocol.h index 7154e45..7a9ed74 100644 --- a/protocol.h +++ b/protocol.h @@ -53,6 +53,7 @@ private: void threadConnQueue(); void threadConnEvents(Connection 
&connection, unsigned int events); + void threadSrvDisconnect(Server &server, Connection &connection, bool error); void threadSrvQueue(); void threadSrvEvents(Server &server, unsigned int events); void threadRun(); diff --git a/utils.h b/utils.h index 338144c..58d867d 100644 --- a/utils.h +++ b/utils.h @@ -142,7 +142,9 @@ public: inline void push(Type &item) { Queue::push(item); counter++; } inline Type& pop() - { Type &item = Queue::pop(); counter++; return item; } + { Type &item = Queue::pop(); counter--; return item; } + inline void unpop(Type &item) + { Queue::unpop(item); counter++; } /* re-queue the item and restore the count pop() decremented; nothing placed after a return statement would ever run */ inline int count() const { return counter; } };