diff --git a/address.h b/address.h index 2481349..af105f5 100644 --- a/address.h +++ b/address.h @@ -34,7 +34,7 @@ public: inline void set(Type type, const char *text = nullptr) { this->type = type; - if (text) this->text = text; + if (text) this->text = text; else this->text.clear(); size = 0; memset(data, 0, sizeof(data)); } diff --git a/connection.cpp b/connection.cpp index a8fe9ce..1481ef2 100644 --- a/connection.cpp +++ b/connection.cpp @@ -20,6 +20,8 @@ Protocol& Connection::getProtocol() const { assert(socket); return socket->protocol; } const Address& Connection::getRemoteAddress() const { assert(socket); return socket->address; } +const Connection::Handle& Connection::getConnection() const + { assert(socket); return socket->connection; } const Server::Handle& Connection::getServer() const { assert(socket); return socket->server; } @@ -60,6 +62,7 @@ void Connection::close(ErrorCode errorCode) { switching = true; onClose(errorCode); + if (getServer()) getServer()->disconnect(getConnection()); Socket *socketCopy = socket; clean(); switching = false; diff --git a/connection.h b/connection.h index 9c823ac..2b37c64 100644 --- a/connection.h +++ b/connection.h @@ -78,7 +78,10 @@ public: protected: inline TRWLock(Connection &connection, typename ReqType::Queue &queue): - lock(connection.mutex), connection(connection), request(queue.empty() ? nullptr : &queue.front()) { } + lock(connection.mutex), + connection(connection), + request(connection.started && !queue.empty() ?
&queue.front() : nullptr) + { } public: inline operator bool() const @@ -143,6 +146,7 @@ protected: // mutex must be locked before calling of all following methods Protocol& getProtocol() const; const Address& getRemoteAddress() const; + const THandle& getConnection() const; const THandle& getServer() const; virtual ErrorCode onOpen(); diff --git a/protocol.cpp b/protocol.cpp index 11b7b9a..0ca8666 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -13,11 +13,10 @@ Socket::Socket( const Address &address ): protocol(protocol), connection(connection), server(server), address(address), - flags(CREATED), prev(), next(), chPrev(), chNext() + flags(), prev(), next(), chPrev(), chNext() { prev = protocol.last; (prev ? prev->next : protocol.first) = this; - setChanged(true); } @@ -29,6 +28,12 @@ Socket::~Socket() { } +void Socket::onConstructionComplete() { + flags |= CREATED; + setChanged(true); +} + + void Socket::setChanged(bool changed) { if ((bool)(flags & CHANGED) == changed) return; if (changed) { @@ -37,7 +42,7 @@ void Socket::setChanged(bool changed) { (chPrev ? chPrev->chNext : protocol.chFirst) = this; protocol.onSocketChanged(this); } else { - flags &= ~CHANGED; + flags &= ~(CHANGED | CREATED); (chPrev ? chPrev->chNext : protocol.chFirst) = chNext; (chNext ? chNext->chPrev : protocol.chLast) = chPrev; chPrev = chNext = nullptr; @@ -83,10 +88,25 @@ Protocol::~Protocol() { } -ErrorCode Protocol::connectionOpen(const Connection::Handle &connection, Socket *socket) - { return connection && socket ? connection->open(socket) : ERR_INVALID_ARGS; } -ErrorCode Protocol::serverStart(const Server::Handle &server, Socket *socket) - { return server && socket ?
server->start(socket) : ERR_INVALID_ARGS; } +ErrorCode Protocol::connectionOpen(const Connection::Handle &connection, Socket *socket) { + if (!connection || !socket) return ERR_INVALID_ARGS; + if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; + return connection->open(socket); +} + + +ErrorCode Protocol::serverStart(const Server::Handle &server, Socket *socket) { + if (!server || !socket) return ERR_INVALID_ARGS; + if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; + return server->start(socket); +} + + +Connection::Handle Protocol::serverConnect(const Server::Handle &server, const Address &address) { + if (!server) return nullptr; + if (!isFullActive()) return nullptr; + return server->connect(address); +} int Protocol::stopServer(Server &server) { @@ -94,7 +114,7 @@ int Protocol::stopServer(Server &server) { int count = 0; if (started) for(Socket *s = socketFirst(); s; s = socketNext(s)) - if (!(socketFlags(s) & Socket::REMOVED) && s->server == &server && s->connection) + if (!socketToRemove(s) && s->server == &server && s->connection) { s->connection->close(ERR_CONNECTION_LOST); ++count; } return count; } @@ -105,7 +125,7 @@ int Protocol::stopServerReq(Server &server) { int count = 0; if (started) for(Socket *s = socketFirst(); s; s = socketNext(s)) - if (!(socketFlags(s) & Socket::REMOVED) && s->server == &server && s->connection) + if (!socketToRemove(s) && s->server == &server && s->connection) { s->connection->closeReq(); ++count; } return count; } @@ -146,7 +166,7 @@ void Protocol::stop(ErrorCode errorCode) { stopping = true; int unfinishedConnections = 0; for(Socket *s = socketFirst(); s; s = socketNext(s)) { - if (!(socketFlags(s) & Socket::REMOVED)) { + if (!socketToRemove(s)) { if (!s->server && s->connection) { s->connection->close(ERR_CONNECTION_LOST); ++unfinishedConnections; @@ -179,6 +199,7 @@ void Protocol::stopReq() { s->server->stopReq(); } } + onStopRequested(); } @@ -202,17 +223,25 @@ void Protocol::stopWait(unsigned long long
timeoutUs, bool withRequest) { ErrorCode Protocol::connect(const Connection::Handle &connection, const Address &address) { if (!connection) return ERR_INVALID_ARGS; + Address resolvedAddress = address; + ErrorCode errorCode = resolve(resolvedAddress); + if (errorCode) return errorCode; + Lock lock(mutex); if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; - return onConnect(connection, address); + return onConnect(connection, resolvedAddress); } ErrorCode Protocol::listen(const Server::Handle &server, const Address &address) { if (!server) return ERR_INVALID_ARGS; + Address resolvedAddress = address; + ErrorCode errorCode = resolve(resolvedAddress); + if (errorCode) return errorCode; + Lock lock(mutex); if (!isFullActive()) return ERR_PROTOCOL_NOT_INITIALIZED; - return onListen(server, address); + return onListen(server, resolvedAddress); } @@ -222,10 +251,12 @@ ErrorCode Protocol::resolve(Address&) ErrorCode Protocol::onStart() { return ERR_NONE; } -void Protocol::onStop(ErrorCode) +void Protocol::onStopRequested() { } -void Protocol::onSocketChanged(Socket*) +void Protocol::onStop(ErrorCode) { } +void Protocol::onSocketChanged(Socket *socket) + { if (socketToRemove(socket)) socketRemove(socket); } ErrorCode Protocol::onConnect(const Connection::Handle&, const Address&) { return ERR_CONNECTION_FAILED; } ErrorCode Protocol::onListen(const Server::Handle&, const Address&) diff --git a/protocol.h b/protocol.h index 1482139..7da1daa 100644 --- a/protocol.h +++ b/protocol.h @@ -13,6 +13,7 @@ enum : ErrorCode { ERR_PROTOCOL_COMMON = ERR_SERVER, ERR_PROTOCOL_NOT_INITIALIZED, + ERR_PROTOCOL_INITIALIZATION_FAILED, ERR_PROTOCOL_UNFINISHED_CONNECTIONS, ERR_PROTOCOL_IS_SWITCHING, }; @@ -56,6 +57,8 @@ protected: const Address &address ); ~Socket(); + void onConstructionComplete(); + public: // thread-saef methods void setWantRead(bool want); @@ -118,6 +121,8 @@ protected: inline unsigned int socketFlags(Socket *socket) { socketCheck(socket); return socket->flags; } +
inline unsigned int socketToRemove(Socket *socket) + { socketCheck(socket); return socket->flags & Socket::REMOVED; } inline Socket* socketRemove(Socket *socket) { socketCheck(socket); Socket *next = socket->next; delete socket; return next; } @@ -128,6 +133,7 @@ protected: ErrorCode connectionOpen(const Connection::Handle &connection, Socket *socket); ErrorCode serverStart(const Server::Handle &server, Socket *socket); + Connection::Handle serverConnect(const Server::Handle &server, const Address &address); private: // mutex must be locked before call @@ -158,6 +164,7 @@ public: protected: // mutex must be locked before call virtual ErrorCode onStart(); + virtual void onStopRequested(); virtual void onStop(ErrorCode errorCode); virtual void onSocketChanged(Socket *socket); virtual ErrorCode onConnect(const Connection::Handle &connection, const Address &address); diff --git a/server.cpp b/server.cpp index 7ea9ead..6dbbba2 100644 --- a/server.cpp +++ b/server.cpp @@ -70,6 +70,7 @@ void Server::stopReq() { if (!started || stopRequested) return; stopRequested = true; getProtocol().stopServerReq(*this); + onStopRequested(); } @@ -91,8 +92,16 @@ void Server::stopWait(unsigned long long timeoutUs, bool withRequest) { } +Connection::Handle Server::connect(const Address &remoteAddress) + { return onConnect(remoteAddress); } +void Server::disconnect(const Connection::Handle &connection) + { onDisconnect(connection); } + + ErrorCode Server::onStart() { return ERR_NONE; } +void Server::onStopRequested() + { } void Server::onStop(ErrorCode) { } Connection::Handle Server::onConnect(const Address&) diff --git a/server.h b/server.h index f8cdb38..859d5b8 100644 --- a/server.h +++ b/server.h @@ -44,7 +44,10 @@ private: void clean(); friend class Protocol; + friend class Connection; ErrorCode start(Socket *socket); + THandle connect(const Address &remoteAddress); + void disconnect(const THandle &connection); public: Server(); @@ -55,13 +58,15 @@ public: void stopWait(unsigned long long
timeoutUs, bool withRequest = true); protected: + // mutex must be locked before call of all following methods Protocol& getProtocol() const; const Address& getLocalAddress() const; - ErrorCode onStart(); - void onStop(ErrorCode errorCode); - THandle onConnect(const Address &remoteAddres); - void onDisconnect(const THandle &connection); + virtual ErrorCode onStart(); + virtual void onStopRequested(); + virtual void onStop(ErrorCode errorCode); + virtual THandle onConnect(const Address &remoteAddress); + virtual void onDisconnect(const THandle &connection); }; diff --git a/tcpprotocol.cpp b/tcpprotocol.cpp index b99639e..376739c 100644 --- a/tcpprotocol.cpp +++ b/tcpprotocol.cpp @@ -1,14 +1,17 @@ +#include <cerrno> +#include <climits> + #include #include #include #include #include +#include <sys/eventfd.h> #include #include #include - #include "protocol.h" #include "server.h" #include "connection.h" @@ -30,9 +33,7 @@ TcpSocket::TcpSocket( setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int)); fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK); - struct epoll_event event = {}; - event.data.ptr = this; - epoll_ctl(getProtocol().epollId, EPOLL_CTL_ADD, sockId, &event); + onConstructionComplete(); } @@ -99,51 +100,226 @@ ErrorCode TcpProtocol::resolve(Address &address) { void TcpProtocol::threadRun() { - const int count = 1024; - struct epoll_event events[count] = {}; + const int maxCount = 1024; + struct epoll_event events[maxCount] = {}; + int count = 0; + while(true) { - epoll_wait(epollId, events, count, 1000); - break; + { // locked scope + Lock lock(mutex); + + { // read wakeup signals + eventfd_t value = 0; + while(::read(eventsId, &value, sizeof(value)) > 0); + } + + // process socket events + for(int i = 0; i < count; ++i) { + if (TcpSocket *socket = (TcpSocket*)events[i].data.ptr) { + if (socketToRemove(socket)) continue; + + unsigned int e = events[i].events; + int sockId = socket->sockId; + const Connection::Handle &connection = socket->connection; + const Server::Handle
&server = socket->server; + + if (e & EPOLLIN) { + if (connection) { + ErrorCode err = ERR_NONE; bool eof = false; + bool ready = true; + while(ready) { + if (socketToRemove(socket)) break; + Connection::ReadLock readLock(*connection); + if (!readLock) break; + while(true) { + size_t remains = readLock->remains(); + if (!remains) break; + if (remains > INT_MAX) remains = INT_MAX; + int res = ::recv(sockId, readLock->current(), (int)remains, MSG_DONTWAIT); + if (res <= 0) { + if (res == 0) eof = true; else if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) err = ERR_CONNECTION_LOST; + ready = false; + break; + } + readLock->processed(res); + } + } + if (socketToRemove(socket)) continue; + if (err || eof) connection->close(err); + } else + if (server) { + Address address; + address.type = Address::SOCKET; + socklen_t len = sizeof(sockaddr_in); + int res = ::accept(sockId, (struct sockaddr*)address.data, &len); + assert(len <= sizeof(address.data)); + if (res >= 0) { + address.size = len; + Connection::Handle conn = serverConnect(server, address); + if (conn && !socketToRemove(socket)) { + TcpSocket *s = new TcpSocket(*this, conn, server, address, res); + if (ERR_NONE != connectionOpen(conn, s)) + s->finalize(); + } else ::close(res); + } else + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR && errno != ECONNABORTED) { + server->stop(ERR_SERVER_LISTENING_LOST); + } + } + } + + if (socketToRemove(socket)) continue; + + if (e & EPOLLOUT) { + if (connection) { + ErrorCode err = ERR_NONE; + bool ready = true; + while(ready) { + if (socketToRemove(socket)) break; + Connection::WriteLock writeLock(*connection); + if (!writeLock) break; + while(true) { + size_t remains = writeLock->remains(); + if (!remains) break; + if (remains > INT_MAX) remains = INT_MAX; + int res = ::send(sockId, writeLock->current(), (int)remains, MSG_DONTWAIT | MSG_NOSIGNAL); + if (res <= 0) { + if (res < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) err = ERR_CONNECTION_LOST; + ready = false; + break; + } + writeLock->processed(res); + } + } + if (socketToRemove(socket)) continue; + if (err) connection->close(err); + } + } + + if (socketToRemove(socket)) continue; + + if (e & (EPOLLHUP | EPOLLERR))
{ + if (connection) { + connection->close(e & EPOLLERR ? ERR_CONNECTION_LOST : 0); + } else + if (server) { + server->stop(e & EPOLLERR ? ERR_CONNECTION_LOST : 0); + } + } + } + } + + // update epoll + Socket *socket = socketChFirst(); + while(socket) { + int sockId = ((TcpSocket*)socket)->sockId; + struct epoll_event event = {}; + if (socketToRemove(socket)) { + epoll_ctl(epollId, EPOLL_CTL_DEL, sockId, &event); + socket = socketRemoveChanged(socket); + } else { + unsigned int flags = socketFlags(socket); + event.data.ptr = socket; + if (flags & Socket::WANTREAD) event.events |= EPOLLIN; + if (flags & Socket::WANTWRITE) event.events |= EPOLLOUT; + epoll_ctl(epollId, (flags & Socket::CREATED) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD, sockId, &event); + socket = socketSetUnchanged(socket); + } + } + + if (!isActive()) + break; + } // end of locked scope + + count = epoll_wait(epollId, events, maxCount, 1000); + if (count < 0) { if (errno != EINTR) break; count = 0; } + } + + std::unique_lock uniqlock(mutex); + if (isActive()) { + stop(ERR_TCP_WORKER_FAILED); + while(isActive()) threadWaitCondition.wait(uniqlock); } } ErrorCode TcpProtocol::onConnect(const Connection::Handle &connection, const Address &address) { - Address remoteAddres = address; - ErrorCode errorCode = resolve(remoteAddres); - if (errorCode) return errorCode; - - Lock lock(mutex); + if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT; int sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (0 != ::connect(sockId, &remoteAddres.as(), remoteAddres.size)) { + if (0 != ::connect(sockId, &address.as(), address.size)) { ::close(sockId); return ERR_CONNECTION_FAILED; } - TcpSocket *socket = new TcpSocket(*this, connection, nullptr, remoteAddres, sockId); - connectionOpen(connection, socket); - return ERR_NONE; + TcpSocket *socket = new TcpSocket(*this, connection, nullptr, address, sockId); + ErrorCode errorCode = connectionOpen(connection, socket); + if (errorCode) socket->finalize(); + return errorCode; } ErrorCode TcpProtocol::onListen(const Server::Handle &server, const
Address &address) { - return ERR_NOT_IMPLEMENTED; + if (address.type != Address::SOCKET) return ERR_ADDRESS_INCORRECT; + + int sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (0 != ::bind(sockId, &address.as(), address.size) || 0 != ::listen(sockId, SOMAXCONN)) { + ::close(sockId); + return ERR_SERVER_LISTENING_FAILED; + } + + TcpSocket *socket = new TcpSocket(*this, nullptr, server, address, sockId); + ErrorCode errorCode = serverStart(server, socket); + if (errorCode) socket->finalize(); + return errorCode; } ErrorCode TcpProtocol::onStart() { epollId = ::epoll_create(32); + if (epollId < 0) { + epollId = 0; + return ERR_PROTOCOL_INITIALIZATION_FAILED; + } + + eventsId = ::eventfd(0, EFD_NONBLOCK); + if (eventsId < 0) { + ::close(epollId); + epollId = 0; + eventsId = 0; + return ERR_PROTOCOL_INITIALIZATION_FAILED; + } + + struct epoll_event event = {}; + event.events = EPOLLIN; + epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event); + thread = new std::thread(&TcpProtocol::threadRun, this); return ERR_NONE; } -void TcpProtocol::onStop(ErrorCode errorCode) { - if (!thread) return; +void TcpProtocol::onStop(ErrorCode) { + onSocketChanged(nullptr); + threadWaitCondition.notify_all(); thread->join(); delete thread; + thread = nullptr; + + struct epoll_event event = {}; + epoll_ctl(epollId, EPOLL_CTL_DEL, eventsId, &event); + + ::close(eventsId); + eventsId = 0; + ::close(epollId); epollId = 0; } + +void TcpProtocol::onSocketChanged(Socket*) { + eventfd_t value = 1; + ::write(eventsId, &value, sizeof(value)); +} + + diff --git a/tcpprotocol.h b/tcpprotocol.h index 710069f..651c8e1 100644 --- a/tcpprotocol.h +++ b/tcpprotocol.h @@ -8,10 +8,16 @@ #include "protocol.h" +enum : ErrorCode { + ERR_TCP_COMMON = ERR_TCP, + ERR_TCP_WORKER_FAILED, +}; + + class TcpProtocol; -class TcpSocket: public Socket { +class TcpSocket final: public Socket { public: const int sockId; protected: @@ -34,7 +40,9 @@ private: friend class TcpSocket; std::thread *thread; + std::condition_variable_any threadWaitCondition; int epollId; + int
eventsId; void threadRun(); @@ -45,8 +53,9 @@ public: ErrorCode resolve(Address &address) override; protected: - ErrorCode onStart(); - void onStop(ErrorCode); + ErrorCode onStart() override; + void onStop(ErrorCode) override; + void onSocketChanged(Socket *socket) override; ErrorCode onConnect(const Connection::Handle &connection, const Address &address) override; ErrorCode onListen(const Server::Handle &server, const Address &address) override; };