From 6462285f42d1f2ac3338dc86a3c787589aa0f4ea Mon Sep 17 00:00:00 2001 From: Ivan Mahonin Date: Oct 10 2021 06:48:16 +0000 Subject: improve protocol abstraction --- diff --git a/connection.cpp b/connection.cpp index 7321559..a8fe9ce 100644 --- a/connection.cpp +++ b/connection.cpp @@ -6,10 +6,14 @@ Connection::Connection(): - socket(), lastId(), closeRequested(), closeTimeUs() { } + started(), switching(), socket(), lastId(), closeRequested(), closeTimeUs() { } -Connection::~Connection() - { close(); } + +Connection::~Connection() { + Lock lock(mutex); + assert(!started && !socket); + clean(); +} Protocol& Connection::getProtocol() const @@ -20,34 +24,47 @@ const Server::Handle& Connection::getServer() const { assert(socket); return socket->server; } -void Connection::open(Socket *socket) { +void Connection::clean() { + readQueue.clear(); + writeQueue.clear(); + started = false; + socket = nullptr; + closeRequested = false; + closeTimeUs = 0; +} + + +ErrorCode Connection::open(Socket *socket) { Lock lock(mutex); close(); - if (!socket) return; + + if (!socket) return ERR_INVALID_ARGS; + if (switching) return ERR_CONNECTION_IS_SWITCHING; this->socket = socket; - onOpen(); + + switching = true; + ErrorCode errorCode = onOpen(); + if (errorCode) clean(); else started = true; + switching = false; + + return errorCode; } void Connection::close(ErrorCode errorCode) { - Socket *socketCopy; - { - Lock lock(mutex); - if (!socket) - return; - if (!errorCode && (!readQueue.empty() || !writeQueue.empty())) - errorCode = ERR_CONNECTION_UNFINISHED_TASKS; - - onClose(errorCode); - - readQueue.clear(); - writeQueue.clear(); - socketCopy = socket; - socket = nullptr; - closeRequested = false; - closeTimeUs = 0; - closeAwaitCondition.notify_all(); - } + Lock lock(mutex); + + if (!started || switching) return; + if (!errorCode && (!readQueue.empty() || !writeQueue.empty())) + errorCode = ERR_CONNECTION_UNFINISHED_TASKS; + + switching = true; + onClose(errorCode); + Socket *socketCopy = socket; + clean(); + switching = false; + + closeWaitCondition.notify_all(); socketCopy->finalize(); } @@ -55,7 +72,7 @@ void Connection::close(ErrorCode errorCode) { Connection::ReqId Connection::writeReq(const void *data, size_t size, void *userData) { if (!data || !size) return 0; Lock lock(mutex); - if (!socket) return 0; + if (!started) return 0; writeQueue.emplace_back(++lastId, userData, data, size); return lastId; } @@ -64,7 +81,7 @@ Connection::ReqId Connection::writeReq(const void *data, size_t size, void *user Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { if (!data || !size) return 0; Lock lock(mutex); - if (!socket) return 0; + if (!started) return 0; readQueue.emplace_back(++lastId, userData, data, size); return lastId; } @@ -72,31 +89,31 @@ Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { void Connection::closeReq() { Lock lock(mutex); - if (!socket || closeRequested) return; + if (!started || closeRequested) return; closeRequested = true; onCloseRequested(); } -void Connection::closeAwait(unsigned long long timeoutUs, bool withRequest) { +void Connection::closeWait(unsigned long long timeoutUs, bool withRequest) { std::unique_lock uniqlock(mutex); - if (!socket) return; + if (!started) return; if (withRequest) closeReq(); unsigned long long timeUs = monotonicTimeUs() + timeoutUs; if (closeTimeUs > timeUs) - { closeTimeUs = timeUs; closeAwaitCondition.notify_all(); } + { closeTimeUs = timeUs; closeWaitCondition.notify_all(); } - while(!socket) { + while(started) { unsigned long long timeUs = monotonicTimeUs(); if (timeUs >= closeTimeUs) - { close(); break; } - closeAwaitCondition.wait_for(uniqlock, std::chrono::microseconds(closeTimeUs - timeUs)); + { close(ERR_CONNECTION_LOST); break; } + closeWaitCondition.wait_for(uniqlock, std::chrono::microseconds(closeTimeUs - timeUs)); } } -void Connection::onOpen() { } +ErrorCode Connection::onOpen() { return ERR_NONE; } void Connection::onCloseRequested() { } void Connection::onClose(ErrorCode) { } void Connection::onReadReady(const ReadReq&) { } diff --git a/connection.h b/connection.h index dbe6916..9c823ac 100644 --- a/connection.h +++ b/connection.h @@ -17,6 +17,7 @@ enum : ErrorCode { ERR_CONNECTION_FAILED, ERR_CONNECTION_UNFINISHED_TASKS, ERR_CONNECTION_LOST, + ERR_CONNECTION_IS_SWITCHING, }; @@ -108,6 +109,8 @@ public: private: Mutex mutex; + bool started; + bool switching; Socket *socket; ReqId lastId; @@ -116,22 +119,25 @@ private: bool closeRequested; unsigned long long closeTimeUs; - std::condition_variable_any closeAwaitCondition; - + std::condition_variable_any closeWaitCondition; + + // mutex must be locked before call + // used in open and close + void clean(); + friend class Protocol; - void open(Socket *socket); + ErrorCode open(Socket *socket); public: Connection(); virtual ~Connection(); void close(ErrorCode errorCode = ERR_NONE); + void closeReq(); + void closeWait(unsigned long long timeoutUs, bool withReqest = true); ReqId writeReq(const void *data, size_t size, void *userData = nullptr); ReqId readReq(void *data, size_t size, void *userData = nullptr); - void closeReq(); - - void closeAwait(unsigned long long timeoutUs, bool withReqest = true); protected: // mutex must be locked before calling of all following methods @@ -139,7 +145,7 @@ protected: const Address& getRemoteAddress() const; const THandle& getServer() const; - virtual void onOpen(); + virtual ErrorCode onOpen(); virtual void onCloseRequested(); virtual void onClose(ErrorCode errorCode); virtual void onReadReady(const ReadReq &req); diff --git a/error.h b/error.h index 3814c2b..60fc130 100644 --- a/error.h +++ b/error.h @@ -19,7 +19,7 @@ enum : ErrorCode { enum : ErrorCode { ERR_COMMON = ERR_NONE + 1, ERR_NOT_IMPLEMENTED, - ERR_INVALIDD_ARGS, + ERR_INVALID_ARGS, }; diff --git a/protocol.cpp b/protocol.cpp index 57beaeb..11b7b9a 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -1,4 +1,5 @@ +#include "utils.h" #include "protocol.h" #include "server.h" #include "connection.h" @@ -34,7 +35,7 @@ void Socket::setChanged(bool changed) { flags |= CHANGED; chPrev = protocol.chLast; (chPrev ? chPrev->chNext : protocol.chFirst) = this; - protocol.wakeup(); + protocol.onSocketChanged(this); } else { flags &= ~CHANGED; (chPrev ? chPrev->chNext : protocol.chFirst) = chNext; @@ -70,31 +71,164 @@ void Socket::finalize() { Protocol::Protocol(): - first(), last(), chFirst(), chLast() { } -Protocol::~Protocol() - { assert(!first); } + first(), last(), chFirst(), chLast(), + started(), switching(), stopping(), stopRequested(), stopTimeUs() { } -void Protocol::connectionOpen(const Connection::Handle &connection, Socket *socket) { - assert(connection && socket); - connection->open(socket); +Protocol::~Protocol() { + Lock lockStartStop(mutexStartStop); + Lock lock(mutex); + assert(!started); + clean(); } -void Protocol::closeAll() { +ErrorCode Protocol::connectionOpen(const Connection::Handle &connection, Socket *socket) + { return connection && socket ? connection->open(socket) : ERR_INVALID_ARGS; } +ErrorCode Protocol::serverStart(const Server::Handle &server, Socket *socket) + { return server && socket ? server->start(socket) : ERR_INVALID_ARGS; } + + +int Protocol::stopServer(Server &server) { Lock lock(mutex); - while(first) - if (first->connection) first->connection->close(); else - if (first->server) first->server->close(); else - first->finalize(); + int count = 0; + if (started) + for(Socket *s = socketFirst(); s; s = socketNext(s)) + if (!(socketFlags(s) & Socket::REMOVED) && s->server == &server && s->connection) + { s->connection->close(ERR_CONNECTION_LOST); ++count; } + return count; } -void Protocol::wakeup() - { } -ErrorCode Protocol::resolve(Address&) +int Protocol::stopServerReq(Server &server) { + Lock lock(mutex); + int count = 0; + if (started) + for(Socket *s = socketFirst(); s; s = socketNext(s)) + if (!(socketFlags(s) & Socket::REMOVED) && s->server == &server && s->connection) + { s->connection->closeReq(); ++count; } + return count; +} + + +void Protocol::clean() { + assert(!socketFirst()); + while(socketFirst()) delete socketFirst(); + started = false; + stopping = false; + stopRequested = false; + stopTimeUs = 0; +} + + +ErrorCode Protocol::start() { + Lock lockStartStop(mutexStartStop); + Lock lock(mutex); + stop(); + + if (switching) return ERR_PROTOCOL_IS_SWITCHING; + + switching = true; + ErrorCode errorCode = onStart(); + if (errorCode) clean(); else started = true; + switching = false; + + return errorCode; +} + + +void Protocol::stop(ErrorCode errorCode) { + Lock lockStartStop(mutexStartStop); + Lock lock(mutex); + if (!started || switching) return; + + switching = true; + stopping = true; + int unfinishedConnections = 0; + for(Socket *s = socketFirst(); s; s = socketNext(s)) { + if (!(socketFlags(s) & Socket::REMOVED)) { + if (!s->server && s->connection) { + s->connection->close(ERR_CONNECTION_LOST); + ++unfinishedConnections; + } else + if (s->server && !s->connection) { + s->server->stop(ERR_SERVER_LISTENING_LOST); + ++unfinishedConnections; + } + } + } + if (!errorCode && unfinishedConnections) errorCode = ERR_PROTOCOL_UNFINISHED_CONNECTIONS; + onStop(errorCode); + clean(); + switching = false; + + stopWaitCondition.notify_all(); +} + + +void Protocol::stopReq() { + Lock lock(mutex); + if (!started || stopRequested) return; + stopRequested = true; + for(Socket *s = socketFirst(); s; s = socketNext(s)) { + if (!(socketFlags(s) & Socket::REMOVED)) { + if (!s->server && s->connection) + s->connection->closeReq(); + else + if (s->server && !s->connection) + s->server->stopReq(); + } + } +} + + +void Protocol::stopWait(unsigned long long timeoutUs, bool withRequest) { + std::unique_lock uniqlock(mutex); + + if (!started) return; + if (withRequest) stopReq(); + unsigned long long timeUs = monotonicTimeUs() + timeoutUs; + if (stopTimeUs > timeUs) + { stopTimeUs = timeUs; stopWaitCondition.notify_all(); } + + while(started) { + unsigned long long timeUs = monotonicTimeUs(); + if (timeUs >= stopTimeUs) + { stop(); break; } + stopWaitCondition.wait_for(uniqlock, std::chrono::microseconds(stopTimeUs - timeUs)); + } +} + + +ErrorCode Protocol::connect(const Connection::Handle &connection, const Address &address) { + if (!connection) return ERR_INVALID_ARGS; + Lock lock(mutex); + if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; + return onConnect(connection, address); +} + + +ErrorCode Protocol::listen(const Server::Handle &server, const Address &address) { + if (!server) return ERR_INVALID_ARGS; + Lock lock(mutex); + if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; + return onListen(server, address); +} + + +ErrorCode Protocol::resolve(Address&) { return ERR_ADDRESS_INCORRECT; } -ErrorCode Protocol::connect(const Connection::Handle&, const Address&) + + +ErrorCode Protocol::onStart() + { return ERR_NONE; } +void Protocol::onStop(ErrorCode) + { } +void Protocol::onSocketChanged(Socket*) + { } +ErrorCode Protocol::onConnect(const Connection::Handle&, const Address&) { return ERR_CONNECTION_FAILED; } -ErrorCode Protocol::listen(const Server::Handle&, const Address&) +ErrorCode Protocol::onListen(const Server::Handle&, const Address&) { return ERR_SERVER_LISTENING_FAILED; } + + diff --git a/protocol.h b/protocol.h index 1fec707..1482139 100644 --- a/protocol.h +++ b/protocol.h @@ -13,6 +13,8 @@ enum : ErrorCode { ERR_PROTOCOL_COMMON = ERR_SERVER, ERR_PROTOCOL_NOT_INITIALIZED, + ERR_PROTOCOL_UNFINISHED_CONNECTIONS, + ERR_PROTOCOL_IS_SWITCHING, }; @@ -73,7 +75,18 @@ private: friend class Socket; Socket *first, *last; Socket *chFirst, *chLast; + bool started; + bool switching; + bool stopping; + bool stopRequested; + unsigned long long stopTimeUs; + std::condition_variable_any stopWaitCondition; +protected: + Mutex mutexStartStop; // using only inside start and stop functions + Mutex mutex; + +private: inline void socketCheck(Socket *socket, bool mustBeChanged = false) { assert(socket); assert(&socket->protocol == this); @@ -81,11 +94,19 @@ private: } protected: - Mutex mutex; - bool allowNewConnections; - // any class derived from Protocol may access to Socket private members via these methods // mutex must be locked before call + inline bool isStarted() const + { return started; } + inline bool isStopping() const + { return stopping; } + inline bool isStopRequested() const + { return stopRequested; } + inline bool isActive() const + { return isStarted() && !isStopping(); } + inline bool isFullActive() const + { return isActive() && !isStopRequested(); } + inline Socket* socketFirst() { return first; } inline Socket* socketNext(Socket *socket) @@ -97,27 +118,50 @@ protected: inline unsigned int socketFlags(Socket *socket) { socketCheck(socket); return socket->flags; } - + inline Socket* socketRemove(Socket *socket) { socketCheck(socket); Socket *next = socket->next; delete socket; return next; } inline Socket* socketRemoveChanged(Socket *socket) { socketCheck(socket, true); Socket *chNext = socket->chNext; delete socket; return chNext; } inline Socket* socketSetUnchanged(Socket *socket) { socketCheck(socket, true); Socket *chNext = socket->chNext; socket->setChanged(false); return chNext; } - - // thread-safe methods - void connectionOpen(const Connection::Handle &connection, Socket *socket); - virtual void wakeup(); + ErrorCode connectionOpen(const Connection::Handle &connection, Socket *socket); + ErrorCode serverStart(const Server::Handle &server, Socket *socket); + +private: + // mutex must be locked before call + // used in start and stop + void clean(); + + // functions used by the Server class + // thread-safe + friend class Server; + int stopServer(Server &server); + int stopServerReq(Server &server); public: // thread-safe methods Protocol(); virtual ~Protocol(); + + ErrorCode start(); + void stop(ErrorCode errorCode = ERR_NONE); + void stopReq(); + void stopWait(unsigned long long timeoutUs, bool withRequest = true); + + ErrorCode connect(const Connection::Handle &connection, const Address &address); + ErrorCode listen(const Server::Handle &server, const Address &address); + virtual ErrorCode resolve(Address &address); - virtual ErrorCode connect(const Connection::Handle &connection, const Address &address); - virtual ErrorCode listen(const Server::Handle &server, const Address &address); - void closeAll(); + +protected: + // mutex must be locked before call + virtual ErrorCode onStart(); + virtual void onStop(ErrorCode errorCode); + virtual void onSocketChanged(Socket *socket); + virtual ErrorCode onConnect(const Connection::Handle &connection, const Address &address); + virtual ErrorCode onListen(const Server::Handle &server, const Address &address); }; diff --git a/server.cpp b/server.cpp index 3e2983f..7ea9ead 100644 --- a/server.cpp +++ b/server.cpp @@ -1,4 +1,5 @@ +#include "utils.h" #include "protocol.h" #include "server.h" #include "connection.h" @@ -6,43 +7,93 @@ Server::Server(): - socket() { } + started(), switching(), socket(), stopRequested(), stopTimeUs() { } -Server::~Server() - { } +Server::~Server() { + Lock lock(mutex); + assert(!started && !socket); + clean(); +} + + +Protocol& Server::getProtocol() const + { assert(socket); return socket->protocol; } +const Address& Server::getLocalAddress() const + { assert(socket); return socket->address; } + + +void Server::clean() { + started = false; + socket = nullptr; + stopRequested = false; + stopTimeUs = 0; +} -void Server::open(Socket *socket) { +ErrorCode Server::start(Socket *socket) { Lock lock(mutex); - close(); - if (!socket) return; + stop(); + + if (!socket) return ERR_INVALID_ARGS; + if (switching) return ERR_SERVER_IS_SWITCHING; this->socket = socket; - onOpen(); + + switching = true; + ErrorCode errorCode = onStart(); + if (errorCode) clean(); else started = true; + switching = false; + + return errorCode; } -void Server::close(ErrorCode errorCode) { - Socket *socketCopy; - { - Lock lock(mutex); - if (!socket) return; - onClose(errorCode); - socketCopy = socket; - socket = nullptr; - } +void Server::stop(ErrorCode errorCode) { + Lock lock(mutex); + + if (!started || switching) return; + + switching = true; + int unfinichedConnections = getProtocol().stopServer(*this); + if (!errorCode && unfinichedConnections) errorCode = ERR_SERVER_UNFINISHED_CONNECTIONS; + onStop(errorCode); + Socket *socketCopy = socket; + clean(); + switching = false; + socketCopy->finalize(); } -Protocol& Server::getProtocol() const - { assert(socket); return socket->protocol; } -const Address& Server::getLocalAddress() const - { assert(socket); return socket->address; } + +void Server::stopReq() { + Lock lock(mutex); + if (!started || stopRequested) return; + stopRequested = true; + getProtocol().stopServerReq(*this); +} -void Server::onOpen() - { } -void Server::onClose(ErrorCode) +void Server::stopWait(unsigned long long timeoutUs, bool withRequest) { + std::unique_lock uniqlock(mutex); + + if (!socket) return; + if (withRequest) stopReq(); + unsigned long long timeUs = monotonicTimeUs() + timeoutUs; + if (stopTimeUs > timeUs) + { stopTimeUs = timeUs; stopWaitCondition.notify_all(); } + + while(!socket) { + unsigned long long timeUs = monotonicTimeUs(); + if (timeUs >= stopTimeUs) + { stop(ERR_CONNECTION_LOST); break; } + stopWaitCondition.wait_for(uniqlock, std::chrono::microseconds(stopTimeUs - timeUs)); + } +} + + +ErrorCode Server::onStart() + { return ERR_NONE; } +void Server::onStop(ErrorCode) { } Connection::Handle Server::onConnect(const Address&) { return nullptr; } diff --git a/server.h b/server.h index 1253ea5..f8cdb38 100644 --- a/server.h +++ b/server.h @@ -18,6 +18,7 @@ enum : ErrorCode { ERR_SERVER_LISTENING_FAILED, ERR_SERVER_UNFINISHED_CONNECTIONS, ERR_SERVER_LISTENING_LOST, + ERR_SERVER_IS_SWITCHING, }; @@ -30,23 +31,35 @@ public: private: Mutex mutex; + bool started; + bool switching; Socket *socket; + bool stopRequested; + unsigned long long stopTimeUs; + std::condition_variable_any stopWaitCondition; + + // mutex must be locked before call + // used in open and close + void clean(); + friend class Protocol; - void open(Socket *socket); + ErrorCode start(Socket *socket); public: Server(); ~Server(); - void close(ErrorCode errorCode = ERR_NONE); + void stop(ErrorCode errorCode = ERR_NONE); + void stopReq(); + void stopWait(unsigned long long timeoutUs, bool withRequest = true); protected: Protocol& getProtocol() const; const Address& getLocalAddress() const; - void onOpen(); - void onClose(ErrorCode errorCode); + ErrorCode onStart(); + void onStop(ErrorCode errorCode); THandle onConnect(const Address &remoteAddres); void onDisconnect(const THandle &connection); }; diff --git a/tcpprotocol.cpp b/tcpprotocol.cpp index 44c2e6a..b99639e 100644 --- a/tcpprotocol.cpp +++ b/tcpprotocol.cpp @@ -37,7 +37,7 @@ TcpSocket::TcpSocket( TcpProtocol::TcpProtocol(): - thread(), stopping() { } + thread() { } ErrorCode TcpProtocol::resolve(Address &address) { @@ -105,14 +105,10 @@ void TcpProtocol::threadRun() { epoll_wait(epollId, events, count, 1000); break; } - closeAll(); } -ErrorCode TcpProtocol::connect(const Connection::Handle &connection, const Address &address) { - if (!thread) return ERR_PROTOCOL_NOT_INITIALIZED; - if (!connection) return ERR_INVALIDD_ARGS; - +ErrorCode TcpProtocol::onConnect(const Connection::Handle &connection, const Address &address) { Address remoteAddres = address; ErrorCode errorCode = resolve(remoteAddres); if (errorCode) return errorCode; @@ -131,21 +127,20 @@ ErrorCode TcpProtocol::connect(const Connection::Handle &connection, const Addre } -ErrorCode TcpProtocol::listen(const Server::Handle &server, const Address &address) { +ErrorCode TcpProtocol::onListen(const Server::Handle &server, const Address &address) { return ERR_NOT_IMPLEMENTED; } -void TcpProtocol::start() { - stop(); +ErrorCode TcpProtocol::onStart() { epollId = ::epoll_create(32); thread = new std::thread(&TcpProtocol::threadRun, this); + return ERR_NONE; } -void TcpProtocol::stop() { +void TcpProtocol::onStop(ErrorCode errorCode) { if (!thread) return; - { Lock lock2(mutex); stopping = true; } thread->join(); delete thread; ::close(epollId); diff --git a/tcpprotocol.h b/tcpprotocol.h index 7da4d78..710069f 100644 --- a/tcpprotocol.h +++ b/tcpprotocol.h @@ -34,19 +34,21 @@ private: friend class TcpSocket; std::thread *thread; - bool stopping; int epollId; void threadRun(); public: TcpProtocol(); + ~TcpProtocol() { stop(); } ErrorCode resolve(Address &address) override; - ErrorCode connect(const Connection::Handle &connection, const Address &address) override; - ErrorCode listen(const Server::Handle &server, const Address &address) override; - void start(); - void stop(); + +protected: + ErrorCode onStart(); + void onStop(ErrorCode); + ErrorCode onConnect(const Connection::Handle &connection, const Address &address) override; + ErrorCode onListen(const Server::Handle &server, const Address &address) override; };