diff --git a/connection.cpp b/connection.cpp index 1481ef2..0ca6544 100644 --- a/connection.cpp +++ b/connection.cpp @@ -62,7 +62,7 @@ void Connection::close(ErrorCode errorCode) { switching = true; onClose(errorCode); - if (getServer()) getServer()->disconnect(getConnection()); + getProtocol().serverDisconnect(getServer(), getConnection()); Socket *socketCopy = socket; clean(); switching = false; @@ -72,24 +72,6 @@ void Connection::close(ErrorCode errorCode) { } -Connection::ReqId Connection::writeReq(const void *data, size_t size, void *userData) { - if (!data || !size) return 0; - Lock lock(mutex); - if (!started) return 0; - writeQueue.emplace_back(++lastId, userData, data, size); - return lastId; -} - - -Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { - if (!data || !size) return 0; - Lock lock(mutex); - if (!started) return 0; - readQueue.emplace_back(++lastId, userData, data, size); - return lastId; -} - - void Connection::closeReq() { Lock lock(mutex); if (!started || closeRequested) return; @@ -116,6 +98,24 @@ void Connection::closeWait(unsigned long long timeoutUs, bool withRequest) { } +Connection::ReqId Connection::writeReq(const void *data, size_t size, void *userData) { + if (!data || !size) return 0; + Lock lock(mutex); + if (!started) return 0; + writeQueue.emplace_back(++lastId, userData, data, size); + return lastId; +} + + +Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { + if (!data || !size) return 0; + Lock lock(mutex); + if (!started) return 0; + readQueue.emplace_back(++lastId, userData, data, size); + return lastId; +} + + ErrorCode Connection::onOpen() { return ERR_NONE; } void Connection::onCloseRequested() { } void Connection::onClose(ErrorCode) { } diff --git a/connection.h b/connection.h index 2b37c64..a6e9737 100644 --- a/connection.h +++ b/connection.h @@ -144,6 +144,10 @@ public: protected: // mutex must be locked before calling of all following methods + inline bool isCloseRequested() const { return started && closeRequested; } + inline bool isAllReadDone() const { return readQueue.empty() || readQueue.front().success(); } + inline bool isAllWriteDone() const { return writeQueue.empty() || writeQueue.front().success(); } + inline bool isAllDone() const { return isAllReadDone() && isAllWriteDone(); } Protocol& getProtocol() const; const Address& getRemoteAddress() const; const THandle& getConnection() const; diff --git a/main.cpp b/main.cpp index b45311d..812ed15 100644 --- a/main.cpp +++ b/main.cpp @@ -1,69 +1,160 @@ #include -#include "tcp.h" +#include "tcpprotocol.h" #include "main.h" -int main(int argc, char** argv) { - if (argc != 3) - { printf("two args expected\n"); return 1; } +namespace { + class MyConnection: public Connection { + private: + enum { BUFSIZE = 128 }; + const bool serverSide; + char buffer[2][BUFSIZE + 1]; + + public: + std::string name; + + explicit MyConnection(const char *name, bool serverSide): + serverSide(serverSide), buffer(), name(name) { } + + protected: + void sendData() { + for(int i = 0; i < BUFSIZE; ++i) buffer[1][i] = '0' + rand()%10; + printf("%s: send: %s\n", name.c_str(), buffer[1]); + writeReq(buffer[1], BUFSIZE); + } + + ErrorCode onOpen() override { + printf("%s: MyConnection::onOpen()\n", name.c_str()); + if (serverSide) sendData(); + readReq(buffer[0], BUFSIZE); + return ERR_NONE; + } + + void onCloseRequested() override { + printf("%s: MyConnection::onCloseRequested()\n", name.c_str()); + if (isAllDone()) close(); + } + + void onClose(ErrorCode errorCode) override { + printf("%s: MyConnection::onClose(%u)\n", name.c_str(), errorCode); + } + + void onReadReady(const ReadReq&) override { + printf("%s: MyConnection::onReadReady()\n", name.c_str()); + printf("%s: received: %s\n", name.c_str(), buffer[0]); + + if (isCloseRequested()) { + if (serverSide) sendData(); else + if (isAllDone()) close(); + } else { + sendData(); + readReq(buffer[0], BUFSIZE); + } + } + + void onWriteReady(const WriteReq&) override { + printf("%s: MyConnection::onWriteReady()\n", name.c_str()); + if (isCloseRequested() && isAllDone() && serverSide) close(); + } + }; - Tcp::Address address; - address.resolve(argv[2]); - if (0 == strcmp(argv[1], "tcpclient")) { - printf("tcpconnect\n"); - - Tcp::Connection connection; - if (connection.open(address)) - printf( "connected to server: %hhu.%hhu.%hhu.%hhu:%hu\n", - address.ip[0], address.ip[1], address.ip[2], address.ip[3], address.port ); - char buf[65] = {}; - while(connection.check()) { - sprintf(buf, "client say %d\n", rand()); - if (connection.write(buf, sizeof(buf) - 1)) - printf("sent: %s\n", buf); - if (connection.read(buf, sizeof(buf) - 1)) - printf("received: %s\n", buf); + class MyServer: public Server { + private: + unsigned int lastId; + + public: + MyServer(): lastId() { } + + protected: + ErrorCode onStart() override { + printf("MyServer::onStart()\n"); + return ERR_NONE; } - printf("last error: %u\n", connection.getLastError()); - return connection.getLastError(); - } + void onStopRequested() override { + printf("MyServer::onStopRequested()\n"); + } + + void onStop(ErrorCode errorCode) override { + printf("MyServer::onStop(%u)\n", errorCode); + } + + Connection::Handle onConnect(const Address &remoteAddress) override { + printf("MyServer::onConnect( "); + for(int i = 0; i < (int)remoteAddress.size; ++i) printf("%d.", remoteAddress.data[i]); + printf(" )\n"); + + char buf[1024] = {}; + sprintf(buf, "remote%d", ++lastId); + printf("new connection: %s\n", buf); + + return Connection::Handle(new MyConnection(buf, true)); + } + + void onDisconnect(const Connection::Handle &connection) override { + printf("MyServer::onDisconnect( %s )\n", ((MyConnection*)connection.pointer())->name.c_str()); + } + }; +} + + +int main(int argc, char** argv) { + if (argc != 3) + { printf("two args expected\n"); return 1; } + + bool client = (0 == strcmp(argv[1], "tcpclient")); + bool server = (0 == strcmp(argv[1], "tcpserver")); + if (!client && !server) + { printf("wrong args\n"); return 1; } + + printf("start tcp protocol\n"); + TcpProtocol::Handle protocol(new TcpProtocol()); + protocol->start(); - if (0 == strcmp(argv[1], "tcpserver")) { - printf("tcpserver\n"); - - class Handler: public Tcp::Handler { - public: - void handleTcpConnection(Tcp::Connection &connection) override { - const Tcp::Address &address = connection.getRemoteAddr(); - printf( "received connection from: %hhu.%hhu.%hhu.%hhu:%hu\n", - address.ip[0], address.ip[1], address.ip[2], address.ip[3], address.port ); - while(connection.check()) { - char buf[65] = {}; - if (connection.read(buf, sizeof(buf) - 1)) - printf("received: %s\n", buf); - sprintf(buf, "server say %d\n", rand()); - if (connection.write(buf, sizeof(buf) - 1)) - printf("sent: %s\n", buf); - } - } - } handler; + Address address; + address.set(Address::COMMON_STRING, argv[2]); + ErrorCode errorCode = protocol->resolve(address); + if (errorCode) { + printf("cannot resolve address: %s, errorCode: %u\n", argv[2], errorCode); + } else + if (client) { + printf("tcpclient, %s\n", argv[2]); - Tcp::Server server; - if (server.start(address, &handler)) - printf( "listening at: %hhu.%hhu.%hhu.%hhu:%hu\n", - address.ip[0], address.ip[1], address.ip[2], address.ip[3], address.port ); - server.join(); + Connection::Handle connection(new MyConnection("client", false)); + errorCode = protocol->connect(connection, address); + if (errorCode) { + printf("cannot open connection, errorCode: %u\n", errorCode); + } else { + printf("connected\n"); + fgetc(stdin); + printf("disconnection\n"); + connection->closeWait(10); + printf("disconnected\n"); + } + } else { + printf("tcpserver, %s\n", argv[2]); - printf("last error: %u\n", server.getLastError()); - return server.getLastError(); + Server::Handle server(new MyServer()); + errorCode = protocol->listen(server, address); + if (errorCode) { + printf("cannot start server, errorCode: %u\n", errorCode); + } else { + printf("server started\n"); + fgetc(stdin); + printf("stop server\n"); + server->stopWait(10); + printf("server stopped\n"); + } } - - printf("wrong args\n"); - return 1; + + printf("stop tcp protocol\n"); + protocol->stopWait(10*1000000ll); + printf("finished, last error: %u\n", errorCode); + + return errorCode; } diff --git a/protocol.cpp b/protocol.cpp index 0ca8666..f62e50d 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -109,6 +109,13 @@ Connection::Handle Protocol::serverConnect(const Server::Handle &server, const A } +void Protocol::serverDisconnect(const Server::Handle &server, const Connection::Handle &connection) { + if (!server || !connection) return; + server->disconnect(connection); +} + + + int Protocol::stopServer(Server &server) { Lock lock(mutex); int count = 0; diff --git a/protocol.h b/protocol.h index 7da1daa..e84e517 100644 --- a/protocol.h +++ b/protocol.h @@ -131,9 +131,12 @@ protected: inline Socket* socketSetUnchanged(Socket *socket) { socketCheck(socket, true); Socket *chNext = socket->chNext; socket->setChanged(false); return chNext; } + friend class Server; + friend class Connection; ErrorCode connectionOpen(const Connection::Handle &connection, Socket *socket); ErrorCode serverStart(const Server::Handle &server, Socket *socket); Connection::Handle serverConnect(const Server::Handle &server, const Address &address); + void serverDisconnect(const Server::Handle &server, const Connection::Handle &connection); private: // mutex must be locked before call diff --git a/server.cpp b/server.cpp index 6dbbba2..204c126 100644 --- a/server.cpp +++ b/server.cpp @@ -100,7 +100,7 @@ void Server::disconnect(const Connection::Handle &connection) ErrorCode Server::onStart() { return ERR_NONE; } -void Server::onStopReq() +void Server::onStopRequested() { } void Server::onStop(ErrorCode) { } diff --git a/server.h b/server.h index 859d5b8..1b5ac34 100644 --- a/server.h +++ b/server.h @@ -44,7 +44,6 @@ private: void clean(); friend class Protocol; - friend class Connection; ErrorCode start(Socket *socket); THandle connect(const Address &remoteAddress); void disconnect(const THandle &connection); diff --git a/tcpprotocol.cpp b/tcpprotocol.cpp index 376739c..58d1224 100644 --- a/tcpprotocol.cpp +++ b/tcpprotocol.cpp @@ -37,7 +37,11 @@ TcpSocket::TcpSocket( TcpProtocol::TcpProtocol(): - thread() { } + thread(), epollId(), eventsId() { } + + +TcpProtocol::~TcpProtocol() + { stop(); } ErrorCode TcpProtocol::resolve(Address &address) { @@ -158,7 +162,7 @@ void TcpProtocol::threadRun() { if (conn && !socketToRemove(socket)) { TcpSocket *s = new TcpSocket(*this, conn, server, address, res); if (ERR_NONE != connectionOpen(conn, s)) - s->finalize(); + { s->finalize(); serverDisconnect(server, conn); } } } else if (res != EAGAIN) { diff --git a/tcpprotocol.h b/tcpprotocol.h index 651c8e1..df66304 100644 --- a/tcpprotocol.h +++ b/tcpprotocol.h @@ -48,7 +48,7 @@ private: public: TcpProtocol(); - ~TcpProtocol() { stop(); } + ~TcpProtocol(); ErrorCode resolve(Address &address) override;