// Prototype of an epoll-based TCP layer (Linux only):
//   * Chain / QueueMISO       - intrusive singly-linked list helpers and a
//                               multi-producer / single-consumer FIFO on top of them;
//   * AMutex / ARWMutex       - spinning mutex and reader/writer mutex built on atomics;
//   * TcpSocket / Connection / TcpProtocol
//                             - socket wrapper, per-connection I/O task queues and the
//                               thread that waits on the epoll instance.
// NOTE(review): the event loop does not service connections or servers yet (see the
// TODOs in TcpProtocol::threadRun), and Connection::connect/accept/closeWait are only
// declared here.

#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <unistd.h>

enum WorkMode { WM_STOPPED, WM_STARTING, WM_STARTED, WM_STOPPING };

// Static helpers for intrusive singly-linked lists whose link is the member
// selected by NextField. A node's link must be null when it is pushed.
template<typename T, T* T::*NextField>
class Chain {
private:
    Chain() { }

public:
    typedef T Type;
    static constexpr Type* Type::*Next = NextField;

    // Pushes item in front of the plain list `first`.
    static inline void push(Type* &first, Type *item) {
        assert(item && !(item->*Next));
        item->*Next = first;
        first = item;
    }

    // Pushes item in front of the shared list `first`; safe with concurrent pushers.
    static inline void push(std::atomic<Type*> &first, Type &item) {
        assert(!(item.*Next));
        item.*Next = first.load();
        // A failed CAS stores the current head into item.*Next, so simply retry.
        while (!first.compare_exchange_weak(item.*Next, &item)) { }
    }

    // Unlinks and returns the head of the plain list, or nullptr when it is empty.
    static inline Type* pop(Type* &first) {
        Type *item = first;
        if (!item) return nullptr;
        first = item->*Next;
        item->*Next = nullptr;
        return item;
    }

    // Unlinks and returns the head of the shared list, or nullptr when it is empty.
    // NOTE(review): this is a Treiber-stack pop; with several concurrent poppers, or
    // nodes recycled while a pop is in flight, it is exposed to ABA / use-after-free.
    static inline Type* pop(std::atomic<Type*> &first) {
        Type *item = first.load();
        do {
            if (!item) return nullptr;
        } while (!first.compare_exchange_weak(item, item->*Next));
        item->*Next = nullptr;
        return item;
    }

    // Reverses the list in place (every link, including the old tail's).
    static inline void invert(Type* &first) {
        Type *prev = nullptr;
        Type *curr = first;
        while (curr) {
            Type *next = curr->*Next;
            curr->*Next = prev;
            prev = curr;
            curr = next;
        }
        first = prev;
    }

    // Returns the reversed list.
    static inline Type* inverted(Type *first) {
        invert(first);
        return first;
    }
};

// Intrusive FIFO: any number of threads may push(); a single thread calls
// peek()/pop(). Pushing and popping may run simultaneously.
template<typename T, T* T::*NextField>
class QueueMISO {
public:
    typedef T Type;
    typedef ::Chain<T, NextField> Chain;
    static constexpr Type* Type::*Next = NextField;

private:
    std::atomic<Type*> pushed;  // newest-first stack filled by producers
    Type *first;                // oldest-first list owned by the consumer
    std::atomic<int> counter;   // items pushed and not yet popped

public:
    inline QueueMISO(): pushed(nullptr), first(nullptr), counter(0) { }
    inline ~QueueMISO() { assert(!pushed && !first && !counter); }

    // Appends item (its link must be null). Callable from any thread.
    inline void push(Type &item) {
        // Count before publishing so count() cannot go negative if the consumer
        // pops the item before this thread gets to increment.
        ++counter;
        Chain::push(pushed, item);
    }

    // Removes and returns the oldest item; the queue must not be empty. Consumer only.
    inline Type& pop() {
        peek();  // refill the consumer list if needed, so a prior peek() is optional
        Type *item = Chain::pop(first);
        assert(item);
        --counter;
        return *item;
    }

    // Returns the oldest item without removing it, or nullptr. Consumer only.
    inline Type* peek() {
        if (!first) first = Chain::inverted(pushed.exchange(nullptr));
        return first;
    }

    inline int count() const { return counter; }
    inline bool empty() const { return counter == 0; }
};

// Spinning mutex usable with std::lock_guard (lock / try_lock / unlock).
class AMutex {
private:
    std::atomic<bool> locked;

public:
    inline AMutex(): locked(false) { }
    inline ~AMutex() { assert(!locked); }

    inline bool try_lock() { return !locked.exchange(true); }
    inline void lock() { while (!try_lock()) std::this_thread::yield(); }
    inline void unlock() { locked = false; }
};

// Spinning reader/writer mutex: counter is -1 while a writer holds it, otherwise
// the number of readers. Writers can starve under a steady stream of readers.
class ARWMutex {
private:
    std::atomic<int> counter;

public:
    inline ARWMutex(): counter(0) { }
    inline ~ARWMutex() { assert(!counter); }

    inline bool try_lock() {
        int expected = 0;
        return counter.compare_exchange_strong(expected, -1);
    }
    inline void lock() {
        // Each attempt starts from a fresh `expected = 0`: reusing the value a failed
        // CAS wrote back would let a writer "acquire" a mutex that is held.
        while (!try_lock()) std::this_thread::yield();
    }
    inline void unlock() { assert(counter < 0); counter = 0; }

    // Fails only while a writer holds the mutex; retries on reader contention.
    inline bool try_lock_read() {
        int cnt = counter;
        do {
            if (cnt < 0) return false;
        } while (!counter.compare_exchange_weak(cnt, cnt + 1));
        return true;
    }
    inline void lock_read() { while (!try_lock_read()) std::this_thread::yield(); }
    inline void unlock_read() { assert(counter > 0); --counter; }
};

typedef std::lock_guard<AMutex> ALock;
typedef std::lock_guard<ARWMutex> AWLock;

// Scoped shared (reader) lock on an ARWMutex. Calling unlock() manually and then
// letting the destructor run releases the read lock twice.
class ARLock {
private:
    ARWMutex &mutex;

public:
    inline ARLock(ARWMutex &mutex): mutex(mutex) { lock(); }
    inline ~ARLock() { unlock(); }
    ARLock(const ARLock&) = delete;
    ARLock& operator=(const ARLock&) = delete;

    inline void lock() { mutex.lock_read(); }
    inline void unlock() { mutex.unlock_read(); }
};

typedef std::lock_guard<std::mutex> Lock;
typedef std::unique_lock<std::mutex> UniqLock;

class TcpProtocol;
class Server;
class Connection;

// A non-blocking, TCP_NODELAY socket registered with a protocol's epoll instance
// (epoll data.ptr == this). Closing the fd and deregistering happen in the destructor.
class TcpSocket {
public:
    TcpProtocol &protocol;
    int sockId;
    Server *server;          // owner when this is a listening socket, else nullptr
    Connection *connection;  // owner when this is a connection socket, else nullptr

    explicit TcpSocket(TcpProtocol &protocol, int sockId);
    ~TcpSocket();
    TcpSocket(const TcpSocket&) = delete;
    TcpSocket& operator=(const TcpSocket&) = delete;

    // Queues this socket for deletion by the protocol thread and wakes that thread.
    void closeReq();
    // Replaces the epoll interest mask with EPOLLIN and/or EPOLLOUT.
    void update(bool read, bool write);
};

// One TCP connection: I/O task queues fed from any thread, plus handlers invoked
// by the protocol thread. Two life-cycle fields coexist:
//   * state / stateWanted - lock-free; send()/receive() accept tasks only while
//                           STATE_CONNECTED <= state < STATE_CLOSED;
//   * mode / socket       - guarded by `mutex`; open()/close() walk through WorkMode.
class Connection {
public:
    enum State {
        STATE_NONE,
        STATE_CONNECTING,
        STATE_CONNECTED,
        STATE_CLOSE_HINT,
        STATE_CLOSE_REQ,
        STATE_CLOSED,
    };

    // A buffer to read into or write from.
    struct Task {
        unsigned char *data;
        int size;
        int position;
        Task *next;   // intrusive link used by Queue; must be null when the task is queued
        bool owned;   // true for tasks allocated by readReq()/writeReq(); deleted when drained
        explicit Task(void *data = nullptr, int size = 0):
            data(static_cast<unsigned char*>(data)), size(size), position(0),
            next(nullptr), owned(false) { }
    };

    typedef QueueMISO<Task, &Task::next> Queue;

    std::atomic<int> stateReaders;        // threads currently between reading `state` and queueing
    std::atomic<State> state;             // written by the protocol thread, read by others
    std::atomic<State> stateWanted;       // written by others, read by the protocol thread
    State stateQuickCopy;                 // for use in the protocol thread only
    std::atomic<unsigned int> lastEvents; // epoll mask last passed to socket->update()
    TcpProtocol *protocol;
    int sockId;
    Queue readQueue;
    Queue writeQueue;

    std::mutex mutex;                     // guards mode and socket
    WorkMode mode;
    TcpSocket *socket;

    Connection():
        stateReaders(0), state(STATE_NONE), stateWanted(STATE_NONE), stateQuickCopy(STATE_NONE),
        lastEvents(0), protocol(nullptr), sockId(-1), mode(WM_STOPPED), socket(nullptr) { }
    virtual ~Connection() { }

    // Brings the socket's epoll mask in line with the queues. The caller must be
    // counted in stateReaders (as send()/receive() are).
    void updateEvents() {
        assert(stateReaders > 0);
        Lock lock(mutex);
        updateEventsLocked();
    }

    // Declared only; no definitions are visible here.
    void connect(TcpProtocol &protocol, const void *address, size_t size);
    void accept(TcpProtocol &protocol, int sockId);

    void closeHint() { wantState(STATE_CLOSE_HINT); }

    // Requests closing: raises stateWanted to STATE_CLOSED and, while the connection
    // is starting or started with a socket attached, asks that socket to close.
    void closeReq() {
        wantState(STATE_CLOSED);
        Lock lock(mutex);
        if (mode != WM_STARTING && mode != WM_STARTED) return;
        if (socket) socket->closeReq();
    }

    void closeWait(unsigned long long timeoutUs);

    // Raises stateWanted to `target`; never lowers it.
    void wantState(State target) {
        State current = stateWanted;
        while (current < target && !stateWanted.compare_exchange_weak(current, target)) { }
    }

    // Queues a caller-owned task for writing / reading; returns false (and queues
    // nothing) unless STATE_CONNECTED <= state < STATE_CLOSED.
    bool send(Task &task) { return enqueue(writeQueue, task); }
    bool receive(Task &task) { return enqueue(readQueue, task); }

    // handlers, called from the protocol thread
    virtual void onOpen() { }
    virtual void onClose(bool) { }
    virtual void onReadReady(const Task&) { }
    virtual void onWriteReady(const Task&) { }

    // Attaches `socket` and runs onOpen(); fails unless the mode is WM_STOPPED.
    // Must be called from the protocol thread only.
    bool open(TcpSocket *socket) {
        {
            Lock lock(mutex);
            if (mode != WM_STOPPED) return false;
            mode = WM_STARTING;
            this->socket = socket;
            lastEvents = 0;  // a freshly registered TcpSocket has an empty epoll mask
        }
        onOpen();
        {
            Lock lock(mutex);
            mode = WM_STARTED;
        }
        return true;
    }

    // Stops accepting tasks, drops pending ones and runs onClose(err); `err` is forced
    // to true when tasks were still pending. Protocol thread only; mode must be WM_STARTED.
    void close(bool err) {
        {
            Lock lock(mutex);
            assert(mode == WM_STARTED);
            mode = WM_STOPPING;  // readReq()/writeReq() stop queueing from here on
        }
        // Reject new send()/receive() calls, then wait for the ones that already passed
        // their state check so nothing is pushed while the queues are drained.
        state = STATE_CLOSED;
        while (stateReaders) std::this_thread::yield();

        bool pending = drain(readQueue);
        if (drain(writeQueue)) pending = true;
        if (pending) err = true;

        onClose(err);
        {
            Lock lock(mutex);
            socket = nullptr;
            mode = WM_STOPPED;  // allow a later open()
        }
    }

    // Queues a heap-allocated task over a caller buffer while starting/started;
    // otherwise does nothing.
    void readReq(void *data, int size) { enqueueOwned(readQueue, data, size); }
    void writeReq(void *data, int size) { enqueueOwned(writeQueue, data, size); }

private:
    // Recomputes the epoll mask from the queue sizes and forwards changes to the
    // socket; `mutex` must be held. Without a socket nothing is recorded.
    void updateEventsLocked() {
        if (!socket) return;
        const bool wantRead = !readQueue.empty();
        const bool wantWrite = !writeQueue.empty();
        unsigned int events = 0;
        if (wantRead) events |= EPOLLIN;
        if (wantWrite) events |= EPOLLOUT;
        if (lastEvents.exchange(events) == events) return;
        socket->update(wantRead, wantWrite);
    }

    // Shared body of send()/receive(); stateReaders is always decremented again.
    bool enqueue(Queue &queue, Task &task) {
        ++stateReaders;  // announce before reading state; close() waits for this to drop
        const State s = state;
        const bool accepting = s >= STATE_CONNECTED && s < STATE_CLOSED;
        if (accepting) {
            queue.push(task);
            updateEvents();
        }
        --stateReaders;
        return accepting;
    }

    // Shared body of readReq()/writeReq().
    void enqueueOwned(Queue &queue, void *data, int size) {
        assert(data && size > 0);
        Lock lock(mutex);
        if (mode != WM_STARTING && mode != WM_STARTED) return;
        Task *task = new Task(data, size);
        task->owned = true;
        queue.push(*task);
        updateEventsLocked();
    }

    // Pops every task (consumer side), deleting owned ones; returns whether any was pending.
    static bool drain(Queue &queue) {
        bool any = false;
        while (Task *task = queue.peek()) {
            queue.pop();  // unlink before a possible delete
            any = true;
            if (task->owned) delete task;
        }
        return any;
    }
};

// Owns the epoll instance, a wakeup eventfd and the thread that waits on them.
// Modes: WM_STOPPED -start()-> WM_STARTING -(thread)-> WM_STARTED
//        -stop()-> WM_STOPPING -(thread joined)-> WM_STOPPED.
class TcpProtocol {
private:
    friend class TcpSocket;

    std::mutex mutex;                         // guards mode and socketsToClose
    std::thread thread;
    std::condition_variable condModeChanged;  // notified on every mode change
    WorkMode mode;
    int epollId;
    int eventsId;                             // eventfd, registered with data.ptr == nullptr
    std::vector<TcpSocket*> socketsToClose;   // filled by TcpSocket::closeReq()

    // Interrupts epoll_wait in the protocol thread; call with `mutex` locked.
    void wakeup() {
        const unsigned long long one = 1;  // eventfd becomes readable only for a non-zero add
        const ssize_t written = ::write(eventsId, &one, sizeof(one));
        (void)written;  // EAGAIN means the counter is saturated, i.e. already readable
    }

    void closeDescriptors() {
        if (eventsId >= 0) ::close(eventsId);
        if (epollId >= 0) ::close(epollId);
        eventsId = -1;
        epollId = -1;
    }

    // Protocol thread body: waits for events, deletes sockets queued for closing,
    // and exits once stop() has switched the mode to WM_STOPPING.
    void threadRun() {
        const int maxCount = 16;
        struct epoll_event events[maxCount] = {};
        std::vector<TcpSocket*> closing;

        {
            Lock lock(mutex);
            assert(mode == WM_STARTING);
            mode = WM_STARTED;
        }
        condModeChanged.notify_all();

        while (true) {
            const int count = ::epoll_wait(epollId, events, maxCount, 10000);
            for (int i = 0; i < count; ++i) {
                TcpSocket *socket = static_cast<TcpSocket*>(events[i].data.ptr);
                if (!socket) {
                    // wakeup eventfd: reset it so the next epoll_wait blocks again
                    unsigned long long num = 0;
                    while (::read(eventsId, &num, sizeof(num)) > 0) { }
                } else if (socket->connection) {
                    // TODO: lock the connection and service its queues
                } else if (socket->server) {
                    // TODO: lock the server and accept pending clients
                } else {
                    closing.push_back(socket);  // orphaned socket
                }
            }

            bool stopping;
            {
                Lock lock(mutex);
                closing.insert(closing.end(), socketsToClose.begin(), socketsToClose.end());
                socketsToClose.clear();
                stopping = (mode == WM_STOPPING);
            }
            // A socket may be both orphaned and close-requested: delete it only once.
            // NOTE(review): deleted sockets may still be referenced by their Connection/Server.
            std::sort(closing.begin(), closing.end(), std::less<TcpSocket*>());
            closing.erase(std::unique(closing.begin(), closing.end()), closing.end());
            for (TcpSocket *socket : closing) delete socket;
            closing.clear();

            // TODO(review): sockets still open at this point are not closed on stop.
            if (stopping) break;
        }
    }

public:
    TcpProtocol(): mode(WM_STOPPED), epollId(-1), eventsId(-1) { }
    ~TcpProtocol() { stop(); }
    TcpProtocol(const TcpProtocol&) = delete;
    TcpProtocol& operator=(const TcpProtocol&) = delete;

    // Creates the epoll instance and starts the protocol thread. Returns true when
    // already starting/started, false while stopping or when setup fails.
    bool start() {
        {
            Lock lock(mutex);
            if (mode == WM_STARTED || mode == WM_STARTING) return true;
            if (mode != WM_STOPPED) return false;
            mode = WM_STARTING;
        }
        condModeChanged.notify_all();

        epollId = ::epoll_create(32);
        eventsId = ::eventfd(0, EFD_NONBLOCK);
        struct epoll_event event = {};  // data.ptr stays nullptr: marks the wakeup eventfd
        event.events = EPOLLIN;
        if (epollId < 0 || eventsId < 0
            || ::epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event) != 0) {
            closeDescriptors();
            {
                Lock lock(mutex);
                mode = WM_STOPPED;
            }
            condModeChanged.notify_all();
            return false;
        }

        {
            // threadRun() takes this lock before reporting WM_STARTED, so stop() can
            // never join() while `thread` is still being assigned.
            Lock lock(mutex);
            thread = std::thread(&TcpProtocol::threadRun, this);
        }
        return true;
    }

    // Stops and joins the protocol thread, then closes the descriptors. Waits for a
    // pending start() to finish first. Must not be called from the protocol thread.
    void stop() {
        {
            UniqLock lock(mutex);
            while (true) {
                if (mode == WM_STOPPED) return;
                if (mode == WM_STARTED) break;
                condModeChanged.wait(lock);
            }
            mode = WM_STOPPING;
            wakeup();  // make epoll_wait return now instead of after its timeout
        }
        condModeChanged.notify_all();

        thread.join();
        closeDescriptors();

        {
            Lock lock(mutex);
            mode = WM_STOPPED;
        }
        condModeChanged.notify_all();
    }
};

TcpSocket::TcpSocket(TcpProtocol &protocol, int sockId):
    protocol(protocol), sockId(sockId), server(nullptr), connection(nullptr)
{
    int tcpNoDelay = 1;
    ::setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay));
    ::fcntl(sockId, F_SETFL, ::fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
    // Registered with an empty interest mask; update() enables EPOLLIN/EPOLLOUT later.
    struct epoll_event event = {};
    event.data.ptr = this;
    ::epoll_ctl(protocol.epollId, EPOLL_CTL_ADD, sockId, &event);
}

TcpSocket::~TcpSocket() {
    struct epoll_event event = {};
    ::epoll_ctl(protocol.epollId, EPOLL_CTL_DEL, sockId, &event);
    ::close(sockId);
}

void TcpSocket::closeReq() {
    Lock lock(protocol.mutex);
    protocol.socketsToClose.push_back(this);
    protocol.wakeup();
}

void TcpSocket::update(bool read, bool write) {
    struct epoll_event event = {};
    event.data.ptr = this;
    if (read) event.events |= EPOLLIN;
    if (write) event.events |= EPOLLOUT;
    ::epoll_ctl(protocol.epollId, EPOLL_CTL_MOD, sockId, &event);
}