From 2499ad58cd8b8b89c9ba96c1e46f8a2b64d27e59 Mon Sep 17 00:00:00 2001 From: Ivan Mahonin Date: Oct 10 2021 15:03:01 +0000 Subject: fix bugs --- diff --git a/common.h b/common.h index 75175a7..3ff5688 100644 --- a/common.h +++ b/common.h @@ -2,8 +2,63 @@ #define COMMON_H +#include + #include "error.h" #include "handle.h" + +class RecursiveMutex { +public: + class CondWrapper { + private: + int cnt; + bool locked; + inline void lockCnt(int cnt) { for(int i = 0; i < cnt; ++i) mutex.lock(); } + inline void unlockCnt(int cnt) { for(int i = 0; i < cnt; ++i) mutex.unlock(); } + + public: + RecursiveMutex &mutex; + + inline CondWrapper(RecursiveMutex &mutex): + cnt(mutex.count()), locked(true), mutex(mutex) { } + inline ~CondWrapper() + { if (!locked) lock(); } + + inline void lock() { + if (!locked) lockCnt(cnt); + locked = true; + } + + inline void unlock() { + if (locked) unlockCnt(cnt); + locked = false; + } + + inline bool try_lock() { + if (locked || cnt <= 0) return true; + if (!mutex.try_lock()) return false; + lockCnt(cnt-1); + locked = true; + return true; + } + }; + +private: + int cnt; + +public: + std::recursive_mutex mutex; + + inline RecursiveMutex(): cnt() { } + inline ~RecursiveMutex() { assert(!cnt); } + + inline int count() const { return cnt; } + inline void lock() { mutex.lock(); ++cnt; } + inline bool try_lock() { if (mutex.try_lock()) { ++cnt; return true; } else return false; } + inline void unlock() { mutex.unlock(); --cnt; } +}; + + #endif diff --git a/connection.cpp b/connection.cpp index 0ca6544..1d8194b 100644 --- a/connection.cpp +++ b/connection.cpp @@ -36,6 +36,22 @@ void Connection::clean() { } +void Connection::readDone() { + ReadReq req = readQueue.front(); + readQueue.pop_front(); + onReadReady(req); + if (isAllReadDone()) socket->setWantRead(false); +} + + +void Connection::writeDone() { + WriteReq req = writeQueue.front(); + writeQueue.pop_front(); + onWriteReady(req); + if (isAllWriteDone()) socket->setWantWrite(false); +} + + ErrorCode Connection::open(Socket *socket) { Lock lock(mutex); close(); @@ -74,7 +90,7 @@ void Connection::close(ErrorCode errorCode) { void Connection::closeReq() { Lock lock(mutex); - if (!started || closeRequested) return; + if (!started || switching || closeRequested) return; closeRequested = true; onCloseRequested(); } @@ -83,7 +99,7 @@ void Connection::closeReq() { void Connection::closeWait(unsigned long long timeoutUs, bool withRequest) { std::unique_lock uniqlock(mutex); - if (!started) return; + if (!started || switching) return; if (withRequest) closeReq(); unsigned long long timeUs = monotonicTimeUs() + timeoutUs; if (closeTimeUs > timeUs) @@ -101,8 +117,9 @@ void Connection::closeWait(unsigned long long timeoutUs, bool withRequest) { Connection::ReqId Connection::writeReq(const void *data, size_t size, void *userData) { if (!data || !size) return 0; Lock lock(mutex); - if (!started) return 0; + if (started == switching) return 0; // already closed or closing now writeQueue.emplace_back(++lastId, userData, data, size); + socket->setWantWrite(true); return lastId; } @@ -110,8 +127,9 @@ Connection::ReqId Connection::writeReq(const void *data, size_t size, void *user Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { if (!data || !size) return 0; Lock lock(mutex); - if (!started) return 0; + if (started == switching) return 0; // already closed or closing now readQueue.emplace_back(++lastId, userData, data, size); + socket->setWantRead(true); return lastId; } diff --git a/connection.h b/connection.h index a6e9737..9bee5fa 100644 --- a/connection.h +++ b/connection.h @@ -31,7 +31,7 @@ public: typedef THandle Handle; typedef unsigned long long ReqId; - typedef std::recursive_mutex Mutex; + typedef RecursiveMutex Mutex; typedef std::lock_guard Lock; template @@ -80,7 +80,7 @@ public: inline TRWLock(Connection &connection, typename ReqType::Queue &queue): lock(connection.mutex), connection(connection), - request(connection.started && !queue.empty() ? &queue.front() : nullptr) + request(connection.started && !connection.switching && !queue.empty() ? &queue.front() : nullptr) { } public: @@ -99,7 +99,7 @@ public: inline ReadLock(Connection &connection): LockType(connection, connection.readQueue) { } inline ~ReadLock() - { if (success()) connection.onReadReady(*request); } + { if (success()) connection.readDone(); } }; class WriteLock: public TRWLock { @@ -107,7 +107,7 @@ public: inline WriteLock(Connection &connection): LockType(connection, connection.writeQueue) { } inline ~WriteLock() - { if (success()) connection.onWriteReady(*request); } + { if (success()) connection.writeDone(); } }; private: @@ -130,7 +130,10 @@ private: friend class Protocol; ErrorCode open(Socket *socket); - + + void readDone(); + void writeDone(); + public: Connection(); virtual ~Connection(); diff --git a/main.cpp b/main.cpp index 812ed15..7191237 100644 --- a/main.cpp +++ b/main.cpp @@ -1,6 +1,8 @@ #include +#include + #include "tcpprotocol.h" #include "main.h" @@ -8,26 +10,29 @@ namespace { class MyConnection: public Connection { private: - enum { BUFSIZE = 128 }; + enum { BUFSIZE = 16 }; const bool serverSide; + const bool serverHello; char buffer[2][BUFSIZE + 1]; public: std::string name; - explicit MyConnection(const char *name, bool serverSide): - serverSide(serverSide), buffer(), name(name) { } + explicit MyConnection(const char *name, bool serverSide, bool serverHello = false): + serverSide(serverSide), serverHello(serverHello), buffer(), name(name) { } protected: void sendData() { for(int i = 0; i < BUFSIZE; ++i) buffer[1][i] = '0' + rand()%10; + buffer[1][BUFSIZE - 2] = '\n'; + buffer[1][BUFSIZE - 1] = '\n'; printf("%s: send: %s\n", name.c_str(), buffer[1]); writeReq(buffer[1], BUFSIZE); } ErrorCode onOpen() override { printf("%s: MyConnection::onOpen()\n", name.c_str()); - if (serverSide) sendData(); + if (serverSide == serverHello) sendData(); readReq(buffer[0], BUFSIZE); return ERR_NONE; } @@ -129,7 +134,8 @@ int main(int argc, char** argv) { printf("cannot open connection, errorCode: %u\n", errorCode); } else { printf("connected\n"); - fgetc(stdin); + usleep(5000000ll); + //fgetc(stdin); printf("disconnection\n"); connection->closeWait(10); printf("disconnected\n"); diff --git a/protocol.cpp b/protocol.cpp index f62e50d..f1e75c0 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -77,7 +77,7 @@ void Socket::finalize() { Protocol::Protocol(): first(), last(), chFirst(), chLast(), - started(), switching(), stopping(), stopRequested(), stopTimeUs() { } + started(), switching(), stopRequested(), stopTimeUs() { } Protocol::~Protocol() { @@ -142,7 +142,6 @@ void Protocol::clean() { assert(!socketFirst()); while(socketFirst()) delete socketFirst(); started = false; - stopping = false; stopRequested = false; stopTimeUs = 0; } @@ -170,7 +169,6 @@ void Protocol::stop(ErrorCode errorCode) { if (!started || switching) return; switching = true; - stopping = true; int unfinishedConnections = 0; for(Socket *s = socketFirst(); s; s = socketNext(s)) { if (!socketToRemove(s)) { @@ -195,7 +193,7 @@ void Protocol::stop(ErrorCode errorCode) { void Protocol::stopReq() { Lock lock(mutex); - if (!started || stopRequested) return; + if (!started || switching || stopRequested) return; stopRequested = true; for(Socket *s = socketFirst(); s; s = socketNext(s)) { if (!(socketFlags(s) & Socket::REMOVED)) { @@ -213,7 +211,7 @@ void Protocol::stopReq() { void Protocol::stopWait(unsigned long long timeoutUs, bool withRequest) { std::unique_lock uniqlock(mutex); - if (!started) return; + if (!started || switching) return; if (withRequest) stopReq(); unsigned long long timeUs = monotonicTimeUs() + timeoutUs; if (stopTimeUs > timeUs) diff --git a/protocol.h b/protocol.h index e84e517..464fa35 100644 --- a/protocol.h +++ b/protocol.h @@ -71,7 +71,7 @@ class Protocol: public Shared { public: typedef THandle Handle; - typedef std::recursive_mutex Mutex; + typedef RecursiveMutex Mutex; typedef std::lock_guard Lock; private: @@ -80,7 +80,6 @@ private: Socket *chFirst, *chLast; bool started; bool switching; - bool stopping; bool stopRequested; unsigned long long stopTimeUs; std::condition_variable_any stopWaitCondition; @@ -101,8 +100,10 @@ protected: // mutex must be locked before call inline bool isStarted() const { return started; } + inline bool isStarting() const + { return !started && switching; } inline bool isStopping() const - { return stopping; } + { return started && switching; } inline bool isStopRequested() const { return stopRequested; } inline bool isActive() const diff --git a/server.cpp b/server.cpp index 204c126..d9fd7ea 100644 --- a/server.cpp +++ b/server.cpp @@ -67,7 +67,7 @@ void Server::stop(ErrorCode errorCode) { void Server::stopReq() { Lock lock(mutex); - if (!started || stopRequested) return; + if (!started || switching || stopRequested) return; stopRequested = true; getProtocol().stopServerReq(*this); onStopRequested(); @@ -77,13 +77,13 @@ void Server::stopReq() { void Server::stopWait(unsigned long long timeoutUs, bool withRequest) { std::unique_lock uniqlock(mutex); - if (!socket) return; + if (!started || switching) return; if (withRequest) stopReq(); unsigned long long timeUs = monotonicTimeUs() + timeoutUs; if (stopTimeUs > timeUs) { stopTimeUs = timeUs; stopWaitCondition.notify_all(); } - while(!socket) { + while(started) { unsigned long long timeUs = monotonicTimeUs(); if (timeUs >= stopTimeUs) { stop(ERR_CONNECTION_LOST); break; } diff --git a/server.h b/server.h index 1b5ac34..6bfb503 100644 --- a/server.h +++ b/server.h @@ -26,7 +26,7 @@ class Server: public Shared { public: typedef THandle Handle; - typedef std::recursive_mutex Mutex; + typedef RecursiveMutex Mutex; typedef std::lock_guard Lock; private: diff --git a/tcpprotocol.cpp b/tcpprotocol.cpp index 58d1224..6b90517 100644 --- a/tcpprotocol.cpp +++ b/tcpprotocol.cpp @@ -37,11 +37,35 @@ TcpSocket::TcpSocket( TcpProtocol::TcpProtocol(): - thread(), epollId(), eventsId() { } + wantRun(false), thread(), epollId(), eventsId() +{ + epollId = ::epoll_create(32); + eventsId = ::eventfd(0, EFD_NONBLOCK); + + struct epoll_event event = {}; + event.events = EPOLLIN; + epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event); + + thread = new std::thread(&TcpProtocol::threadRun, this); +} -TcpProtocol::~TcpProtocol() - { stop(); } +TcpProtocol::~TcpProtocol() { + stop(); + + controlCommand(1); + thread->join(); + delete thread; + + struct epoll_event event = {}; + epoll_ctl(epollId, EPOLL_CTL_DEL, eventsId, &event); + + ::close(eventsId); + eventsId = 0; + + ::close(epollId); + epollId = 0; +} ErrorCode TcpProtocol::resolve(Address &address) { @@ -93,7 +117,8 @@ ErrorCode TcpProtocol::resolve(Address &address) { return ERR_ADDRESS_INCORRECT; address.set(Address::SOCKET); - sockaddr_in &addr = address.as(); + address.size = sizeof(struct sockaddr_in); + struct sockaddr_in &addr = address.as(); addr.sin_family = AF_INET; memcpy(&addr.sin_addr, ip, sizeof(ip)); addr.sin_port = htons(portNum); @@ -102,150 +127,180 @@ ErrorCode TcpProtocol::resolve(Address &address) { } -void TcpProtocol::threadRun() { - const int maxCount = 1024; - struct epoll_event events[maxCount] = {}; - int count = 0; +void TcpProtocol::controlCommand(unsigned long long cmd) + { ::write(eventsId, &cmd, sizeof(cmd)); } + + +void TcpProtocol::threadConnectionRead(TcpSocket *socket, const Connection::Handle &connection) { + int sockId = socket->sockId; + ErrorCode err = ERR_NONE; + bool ready = true; + while(ready) { + if (socketToRemove(socket)) return; + Connection::ReadLock readLock(*connection); + if (!readLock) break; + while(true) { + size_t remains = readLock->remains(); + if (!remains) break; + if (remains > INT_MAX) remains = INT_MAX; + int res = ::recv(sockId, readLock->current(), (int)remains, MSG_DONTWAIT); + if (res <= 0) { + if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST; + ready = false; + break; + } + readLock->processed(res); + } + } + if (err) connection->close(err); +} + + +void TcpProtocol::threadConnectionWrite(TcpSocket *socket, const Connection::Handle &connection) { + int sockId = socket->sockId; + ErrorCode err = ERR_NONE; + bool ready = true; + while(ready) { + if (socketToRemove(socket)) return; + Connection::WriteLock writeLock(*connection); + if (!writeLock) break; + while(true) { + size_t remains = writeLock->remains(); + if (!remains) break; + if (remains > INT_MAX) remains = INT_MAX; + int res = ::send(sockId, writeLock->current(), (int)remains, MSG_DONTWAIT); + if (res <= 0) { + if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST; + ready = false; + break; + } + writeLock->processed(res); + } + } + if (err) connection->close(err); +} + + +void TcpProtocol::threadProcessEvents(void *eventsPtr, int count) { + struct epoll_event *events = (struct epoll_event*)eventsPtr; - while(true) { - { // locked scope - Lock lock(mutex); + // process socket events + for(int i = 0; i < count; ++i) { + if (TcpSocket *socket = (TcpSocket*)events[i].data.ptr) { + if (socketToRemove(socket)) continue; - { // read wakeup signals - void *ptr = nullptr; - while(::read(eventsId, &ptr, sizeof(ptr)) > 0); - } + unsigned int e = events[i].events; + int sockId = socket->sockId; + const Connection::Handle &connection = socket->connection; + const Server::Handle &server = socket->server; - // process socket events - for(int i = 0; i < count; ++i) { - if (TcpSocket *socket = (TcpSocket*)events[i].data.ptr) { - if (socketToRemove(socket)) continue; - - unsigned int e = events[i].events; - int sockId = socket->sockId; - const Connection::Handle &connection = socket->connection; - const Server::Handle &server = socket->server; - - if (e & EPOLLIN) { - if (connection) { - ErrorCode err = ERR_NONE; - bool ready = true; - while(ready) { - if (socketToRemove(socket)) break; - Connection::ReadLock readLock(*connection); - if (!readLock) break; - while(true) { - size_t remains = readLock->remains(); - if (!remains) break; - if (remains > INT_MAX) remains = INT_MAX; - int res = ::recv(sockId, readLock->current(), (int)remains, MSG_DONTWAIT); - if (res <= 0) { - if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST; - ready = false; - break; - } - readLock->processed(res); - } - } - if (socketToRemove(socket)) continue; - if (err) connection->close(err); - } else - if (server) { - Address address; - address.type = Address::SOCKET; - socklen_t len = sizeof(sockaddr_in); - int res = ::accept(sockId, (struct sockaddr*)address.data, &len); - assert(len < sizeof(address.data)); - if (res >= 0) { - address.size = len; - Connection::Handle conn = serverConnect(server, address); - if (conn && !socketToRemove(socket)) { - TcpSocket *s = new TcpSocket(*this, conn, server, address, res); - if (ERR_NONE != connectionOpen(conn, s)) - { s->finalize(); serverDisconnect(server, conn); } - } - } else - if (res != EAGAIN) { - server->stop(ERR_SERVER_LISTENING_LOST); - } - } - } - - if (socketToRemove(socket)) continue; - - if (e & EPOLLOUT) { - if (connection) { - ErrorCode err = ERR_NONE; - bool ready = true; - while(ready) { - if (socketToRemove(socket)) break; - Connection::WriteLock writeLock(*connection); - if (!writeLock) break; - while(true) { - size_t remains = writeLock->remains(); - if (!remains) break; - if (remains > INT_MAX) remains = INT_MAX; - int res = ::send(sockId, writeLock->current(), (int)remains, MSG_DONTWAIT); - if (res <= 0) { - if (res != 0 && res != EAGAIN) err = ERR_CONNECTION_LOST; - ready = false; - break; - } - writeLock->processed(res); - } - } - if (socketToRemove(socket)) continue; - if (err) connection->close(err); - } - } - - if (socketToRemove(socket)) continue; - - if (e & EPOLLHUP) { - if (connection) { - connection->close(e & EPOLLERR ? ERR_CONNECTION_LOST : 0); - } else - if (server) { - server->stop(e & EPOLLERR ? ERR_CONNECTION_LOST : 0); + if (e & EPOLLIN) { + if (connection) { + threadConnectionRead(socket, connection); + } else + if (server) { + Address address; + address.type = Address::SOCKET; + socklen_t len = sizeof(sockaddr_in); + int res = ::accept(sockId, (struct sockaddr*)address.data, &len); + assert(len < sizeof(address.data)); + if (res >= 0) { + address.size = len; + Connection::Handle conn = serverConnect(server, address); + if (conn && !socketToRemove(socket)) { + TcpSocket *s = new TcpSocket(*this, conn, server, address, res); + if (ERR_NONE != connectionOpen(conn, s)) + { s->finalize(); serverDisconnect(server, conn); } } + } else + if (res != EAGAIN) { + server->stop(ERR_SERVER_LISTENING_LOST); } } } - // update epoll - Socket *socket = socketChFirst(); - while(socket) { - int sockId = ((TcpSocket*)socket)->sockId; - struct epoll_event event = {}; - if (socketToRemove(socket)) { - epoll_ctl(epollId, EPOLL_CTL_DEL, sockId, &event); - socket = socketRemoveChanged(socket); - } else { - unsigned int flags = socketFlags(socket); - event.data.ptr = socket; - if (flags & Socket::WANTREAD) event.events |= EPOLLIN; - if (flags & Socket::WANTWRITE) event.events |= EPOLLOUT; - epoll_ctl(epollId, EPOLL_CTL_MOD, sockId, &event); - socket = socketSetUnchanged(socket); - } + if (socketToRemove(socket)) continue; + + if (e & EPOLLOUT) { + if (connection) + threadConnectionWrite(socket, connection); } - if (!isActive()) - break; - } // end of locked ecope + if (socketToRemove(socket)) continue; + + if (e & EPOLLHUP) { + if (connection) { + connection->close(e & EPOLLERR ? ERR_CONNECTION_LOST : 0); + } else + if (server) { + server->stop(e & EPOLLERR ? ERR_SERVER_LISTENING_LOST : 0); + } + } + } + } + + // update epoll + Socket *socket = socketChFirst(); + while(socket) { + int sockId = ((TcpSocket*)socket)->sockId; + struct epoll_event event = {}; + unsigned int flags = socketFlags(socket); + + //if (!(flags & Socket::REMOVED) && flags & Socket::WANTREAD) { + // threadConnectionRead((TcpSocket*)socket, socket->connection); + // flags = socketFlags(socket); + //} - count = epoll_wait(epollId, events, count, 1000); - if (count < 0) break; + //if (!(flags & Socket::REMOVED) && flags & Socket::WANTWRITE) { + // threadConnectionWrite((TcpSocket*)socket, socket->connection); + // flags = socketFlags(socket); + //} + + if (flags & Socket::REMOVED) { + epoll_ctl(epollId, EPOLL_CTL_DEL, sockId, &event); + socket = socketRemoveChanged(socket); + } else { + event.data.ptr = socket; + if (flags & Socket::WANTREAD) event.events |= EPOLLIN; + if (flags & Socket::WANTWRITE) event.events |= EPOLLOUT; + epoll_ctl(epollId, (flags & Socket::CREATED) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, sockId, &event); + socket = socketSetUnchanged(socket); + } } +} + + +void TcpProtocol::threadRun() { + const int maxCount = 16; + struct epoll_event events[maxCount] = {}; + int count = 0; - std::unique_lock uniqlock(mutex); - if (isActive()) { - stop(ERR_TCP_WORKER_FAILED); - while(isActive()) threadWaitCondition.wait(uniqlock); + bool idle = true; + mutexIdle.lock(); + while(true) { + unsigned long long num = 0; + while(::read(eventsId, &num, sizeof(num)) > 0 && !num); + + if (num == 3 && idle) + { idle = false; mutexRun.lock(); mutexIdle.unlock(); } + + if (idle) { + if (num == 1) break; + } else { + Lock lock(mutex); + threadProcessEvents(events, count); + } + + if (num == 2 && !idle) + { idle = true; mutexIdle.lock(); mutexRun.unlock(); } + + count = epoll_wait(epollId, events, maxCount, 100); } + mutexIdle.unlock(); } + ErrorCode TcpProtocol::onConnect(const Connection::Handle &connection, const Address &address) { if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT; @@ -279,49 +334,22 @@ ErrorCode TcpProtocol::onListen(const Server::Handle &server, const Address &add ErrorCode TcpProtocol::onStart() { - epollId = ::epoll_create(32); - if (epollId < 0) { - epollId = 0; - return ERR_PROTOCOL_INITIALIZATION_FAILED; - } - - eventsId = ::eventfd(0, EFD_NONBLOCK); - if (eventsId < 0) { - ::close(epollId); - epollId = 0; - eventsId = 0; - return ERR_PROTOCOL_INITIALIZATION_FAILED; - } - - struct epoll_event event = {}; - event.events = EPOLLIN; - epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event); - - thread = new std::thread(&TcpProtocol::threadRun, this); + controlCommand(3); + std::lock_guard lock(mutexIdle); return ERR_NONE; } void TcpProtocol::onStop(ErrorCode) { - onSocketChanged(nullptr); - threadWaitCondition.notify_all(); - thread->join(); - delete thread; - - struct epoll_event event = {}; - epoll_ctl(epollId, EPOLL_CTL_DEL, eventsId, &event); - - ::close(eventsId); - eventsId = 0; - - ::close(epollId); - epollId = 0; + int cnt = mutex.count(); + for(int i = 0; i < cnt; ++i) mutex.unlock(); + controlCommand(2); + { std::lock_guard lock(mutexRun); } + for(int i = 0; i < cnt; ++i) mutex.lock(); } -void TcpProtocol::onSocketChanged(Socket*) { - void *ptr = nullptr; - ::write(eventsId, &ptr, sizeof(ptr)); -} +void TcpProtocol::onSocketChanged(Socket*) + { controlCommand(0); } diff --git a/tcpprotocol.h b/tcpprotocol.h index df66304..207fe7a 100644 --- a/tcpprotocol.h +++ b/tcpprotocol.h @@ -39,11 +39,18 @@ class TcpProtocol: public Protocol { private: friend class TcpSocket; + std::mutex mutexIdle; + std::mutex mutexRun; + std::atomic wantRun; + std::thread *thread; - std::condition_variable_any threadWaitCondition; int epollId; int eventsId; + void controlCommand(unsigned long long cmd); + void threadConnectionWrite(TcpSocket *socket, const Connection::Handle &connection); + void threadConnectionRead(TcpSocket *socket, const Connection::Handle &connection); + void threadProcessEvents(void *eventsPtr, int count); void threadRun(); public: