From f07ad6690178af7f019f3ebf657c934fe7f77dd2 Mon Sep 17 00:00:00 2001 From: Ivan Mahonin Date: Oct 08 2021 15:05:13 +0000 Subject: protocol abstraction --- diff --git a/address.h b/address.h new file mode 100644 index 0000000..4b339ec --- /dev/null +++ b/address.h @@ -0,0 +1,32 @@ +#ifndef ADDRESS_H +#define ADDRESS_H + + +#include + + +class Address { +public: + enum Type { + NONE, + COMMON_STRING, + SOCKET, + }; + + Type type; + std::string text; + char data[512]; + + inline Address(): type(NONE), data() { } + + template + inline const T& as() const + { assert(sizeof(T) < sizeof(data)); return *reinterpret_cast(data); } + + template + inline T& as() const + { assert(sizeof(T) < sizeof(data)); return *reinterpret_cast(data); } +}; + + +#endif diff --git a/connection.cpp b/connection.cpp new file mode 100644 index 0000000..5d81c1d --- /dev/null +++ b/connection.cpp @@ -0,0 +1,76 @@ + +#include "protocol.h" +#include "connection.h" + + +Connection::Connection(): + socket(), lastId() { } + +Connection::~Connection() + { close(); } + + +Protocol& Connection::getProtocol() const + { assert(socket); return socket->protocol; } +const Address& Connection::getRemoteAddress() const + { assert(socket); return socket->address; } +Server* Connection::getServer() const + { assert(socket); return socket->server; } + + +void Connection::open(Socket *socket) { + Lock lock(mutex); + close(); + if (!socket) return; + this->socket = socket; + onOpen(); +} + + +ErrorCode Connection::open(Protocol &protocol, const Address &remoteAddress) + { return protocol.connect(*this, remoteAddress); } + + +void Connection::close(ErrorCode errorCode) { + Socket *socketCopy; + { + Lock lock(mutex); + if (!socket) + return; + if (!errorCode && (!readQueue.empty() || !writeQueue.empty())) + errorCode = ERR_CONNECTION_UNFINISHED_TASKS; + + onClose(errorCode); + + readQueue.clear(); + writeQueue.clear(); + socketCopy = socket; + socket = nullptr; + } + socketCopy->close(); +} + + +Connection::ReqId Connection::writeReq(const void *data, size_t size, void *userData) { + if (!data || !size) return 0; + Lock lock(mutex); + if (!socket) return 0; + writeQueue.emplace_back(++lastId, userData, data, size); + return lastId; +} + + +Connection::ReqId Connection::readReq(void *data, size_t size, void *userData) { + if (!data || !size) return 0; + Lock lock(mutex); + if (!socket) return 0; + readQueue.emplace_back(++lastId, userData, data, size); + return lastId; +} + + +void Connection::onOpen() { } +void Connection::onClose(ErrorCode) { } +void Connection::onReadReady(const ReadReq&) { } +void Connection::onWriteReady(const WriteReq&) { } + diff --git a/connection.h b/connection.h new file mode 100644 index 0000000..70c6156 --- /dev/null +++ b/connection.h @@ -0,0 +1,137 @@ +#ifndef CONNECTION_H +#define CONNECTION_H + + +#include + +#include +#include + +#include "error.h" +#include "address.h" + + +enum : ErrorCode { + ERR_CONNECTION_COMMON = ERR_CONNECTION, + ERR_CONNECTION_FAILED, + ERR_CONNECTION_UNFINISHED_TASKS, + ERR_CONNECTION_LOST, +}; + + +class Protocol; +class Socket; +class Server; + + +class Connection { +public: + typedef unsigned long long ReqId; + typedef std::recursive_mutex Mutex; + typedef std::lock_guard Lock; + + template + struct TRWReq { + typedef T Type; + typedef std::deque< TRWReq > Queue; + + ReqId id; + void *userData; + Type *data; + size_t size; + size_t doneSize; + + inline TRWReq(ReqId id = 0, void *userData = nullptr, Type *data = nullptr, size_t size = 0): + id(id), userData(userData), data(data), size(size), doneSize() { } + + inline bool success() const + { return doneSize >= size; } + inline size_t remains() const + { assert(size >= doneSize); return size - doneSize; } + inline void processed(size_t size) { + assert(size <= remains()); + if (doneSize > size || size > remains()) + doneSize = size; else doneSize += size; + } + inline Type* current() const + { assert(size >= doneSize); return const_cast((const char*)data + doneSize); } + }; + + typedef TRWReq ReadReq; + typedef TRWReq WriteReq; + typedef ReadReq::Queue ReadQueue; + typedef WriteReq::Queue WriteQueue; + + template + class TRWLock { + public: + typedef TRWReq ReqType; + typedef TRWLock LockType; + + Lock lock; + Connection &connection; + ReqType * const request; + + protected: + inline TRWLock(Connection &connection, typename ReqType::Queue &queue): + lock(connection.mutex), connection(connection), request(queue.empty() ? nullptr : &queue.front()) { } + + public: + inline operator bool() const + { return request; } + inline bool success() const + { return request && request->success(); } + inline ReqType& operator*() const + { assert(request); return request; } + inline ReqType* operator->() const + { assert(request); return request; } + }; + + class ReadLock: public TRWLock { + public: + inline ReadLock(Connection &connection): + LockType(connection, connection.readQueue) { } + inline ~ReadLock() + { if (success()) connection.onReadReady(*request); } + }; + + class WriteLock: public TRWLock { + public: + inline WriteLock(Connection &connection): + LockType(connection, connection.writeQueue) { } + inline ~WriteLock() + { if (success()) connection.onWriteReady(*request); } + }; + +private: + Mutex mutex; + Socket *socket; + + ReqId lastId; + ReadQueue readQueue; + WriteQueue writeQueue; + +public: + Connection(); + virtual ~Connection(); + + void open(Socket *socket); + ErrorCode open(Protocol &protocol, const Address &remoteAddress); + void close(ErrorCode errorCode = ERR_NONE); + + ReqId writeReq(const void *data, size_t size, void *userData = nullptr); + ReqId readReq(void *data, size_t size, void *userData = nullptr); + +protected: + Protocol& getProtocol() const; + const Address& getRemoteAddress() const; + Server* getServer() const; + + virtual void onOpen(); + virtual void onClose(ErrorCode errorCode); + virtual void onReadReady(const ReadReq &req); + virtual void onWriteReady(const WriteReq &req); +}; + + +#endif diff --git a/error.h b/error.h index 64423ec..3741b2d 100644 --- a/error.h +++ b/error.h @@ -2,15 +2,22 @@ #define ERROR_H +#include + + typedef unsigned int ErrorCode; enum : ErrorCode { - ERR_NONE = 0, - ERR_COMMON = 1000, - ERR_TCP = 2000, + ERR_NONE = 0, + ERR_CONNECTION = 1000, + ERR_SERVER = 2000, + ERR_TCP = 3000, }; - +enum : ErrorCode { + ERR_COMMON = ERR_NONE + 1, + ERR_NOT_IMPLEMENTED, +}; #endif diff --git a/protocol.cpp b/protocol.cpp new file mode 100644 index 0000000..b282c11 --- /dev/null +++ b/protocol.cpp @@ -0,0 +1,47 @@ + +#include "protocol.h" +#include "server.h" +#include "connection.h" + + +Socket::Socket(Protocol &protocol, Connection *connection, Server *server, const Address &address): + closed(), prev(), next(), protocol(protocol), connection(connection), server(server), address(address) +{ + Protocol::Lock lock(protocol.mutex); + prev = protocol.last; + (prev ? prev->next : protocol.first) = this; +} + +Socket::~Socket() + { assert(closed); } + +void Socket::onClose() + { } + +void Socket::close() { + Protocol::Lock lock(protocol.mutex); + onClose(); + (prev ? prev->next : protocol.first) = next; + (next ? next->prev : protocol.last) = prev; + closed = true; + delete this; +} + + + +Protocol::Protocol(): + first(), last() { } +Protocol::~Protocol() + { assert(!first); } + + +void Protocol::closeAll() { + Lock lock(mutex); + while(first) first->close(); +} + + +ErrorCode Protocol::connect(Connection&, const Address&) + { return ERR_CONNECTION_FAILED; } +ErrorCode Protocol::listen(Server&, const Address&) + { return ERR_SERVER_LISTENING_FAILED; } diff --git a/protocol.h b/protocol.h new file mode 100644 index 0000000..4a2485b --- /dev/null +++ b/protocol.h @@ -0,0 +1,59 @@ +#ifndef PROTOCOL_H +#define PROTOCOL_H + + +#include + +#include "error.h" +#include "address.h" + + +class Protocol; +class Connection; +class Server; + + +class Socket { +private: + bool closed; + Socket *prev, *next; + +public: + Protocol &protocol; + Connection * const connection; + Server * const server; + const Address &address; + +protected: + Socket(const Socket&) = delete; + Socket(Protocol &protocol, Connection *connection, Server *server, const Address &address); + +public: + void close(); + virtual ~Socket(); + +protected: + virtual void onClose(); +}; + + +class Protocol { +public: + typedef std::recursive_mutex Mutex; + typedef std::lock_guard Lock; + +private: + friend class Socket; + Mutex mutex; + Socket *first, *last; + +public: + Protocol(); + virtual ~Protocol(); + virtual ErrorCode connect(Connection &connection, const Address &address); + virtual ErrorCode listen(Server &server, const Address &address); + void closeAll(); +}; + + +#endif diff --git a/server.cpp b/server.cpp new file mode 100644 index 0000000..9752b6a --- /dev/null +++ b/server.cpp @@ -0,0 +1,54 @@ + +#include "protocol.h" +#include "server.h" +#include "connection.h" + + + +Server::Server(): + socket() { } + + +Server::~Server() + { } + + +void Server::open(Socket *socket) { + Lock lock(mutex); + close(); + if (!socket) return; + this->socket = socket; + onOpen(); +} + + +ErrorCode Server::open(Protocol &protocol, const Address &localAddres) + { return protocol.listen(*this, localAddres); } + + +void Server::close(ErrorCode errorCode) { + Socket *socketCopy; + { + Lock lock(mutex); + if (!socket) return; + onClose(errorCode); + socketCopy = socket; + socket = nullptr; + } + socketCopy->close(); +} + +Protocol& Server::getProtocol() const + { assert(socket); return socket->protocol; } +const Address& Server::getLocalAddress() const + { assert(socket); return socket->address; } + + +void Server::onOpen() + { } +void Server::onClose(ErrorCode) + { } +Connection* Server::onConnect(const Address&) + { return nullptr; } +void Server::onDisconnect(Connection *connection) + { delete connection; } diff --git a/server.h b/server.h new file mode 100644 index 0000000..90d50bd --- /dev/null +++ b/server.h @@ -0,0 +1,52 @@ +#ifndef SERVER_H +#define SERVER_H + + +#include + +#include "error.h" +#include "address.h" + + +class Socket; +class Protocol; +class Connection; + + +enum : ErrorCode { + ERR_SERVER_COMMON = ERR_SERVER, + ERR_SERVER_LISTENING_FAILED, + ERR_SERVER_UNFINISHED_CONNECTIONS, + ERR_SERVER_LISTENING_LOST, +}; + + +class Server { +public: + typedef std::recursive_mutex Mutex; + typedef std::lock_guard Lock; + +private: + Mutex mutex; + Socket *socket; + +public: + Server(); + ~Server(); + + void open(Socket *socket); + ErrorCode open(Protocol &protocol, const Address &localAddres); + void close(ErrorCode errorCode = ERR_NONE); + +protected: + Protocol& getProtocol() const; + const Address& getLocalAddress() const; + + void onOpen(); + void onClose(ErrorCode errorCode); + Connection* onConnect(const Address &remoteAddres); + void onDisconnect(Connection *connection); +}; + + +#endif