From 71057fc82d5cd126e9b3231eb225bd396c08baa0 Mon Sep 17 00:00:00 2001
From: Ivan Mahonin
Date: Oct 14 2021 16:09:50 +0000
Subject: first implementation

---

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3c6778a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.kdev4
+tcplab
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..e29ca23
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+set -e
+
+c++ -Wall -g -pthread *.cpp -o tcplab
+
diff --git a/connection.cpp b/connection.cpp
new file mode 100644
index 0000000..ce499dd
--- /dev/null
+++ b/connection.cpp
@@ -0,0 +1,156 @@
+
+#include <cstring>
+
+#include <chrono>
+#include <mutex>
+#include <thread>
+
+#include "protocol.h"
+#include "connection.h"
+
+
+Connection::Connection():
+    state(STATE_NONE),
+    stateWanted(STATE_NONE),
+    enqueued(false),
+    closeWaiters(0),
+    finishWaiters(0),
+    address(),
+    addressSize(),
+    protocol(),
+    server(),
+    prev(), next(),
+    srvPrev(), srvNext(),
+    queueNext(),
+    sockId(-1),
+    stateLocal()
+    { }
+
+
+Connection::~Connection() {
+    assert(state == STATE_NONE || state == STATE_CLOSED);
+    closeWait();
+    while(finishWaiters > 0) std::this_thread::yield();
+}
+
+
+bool Connection::open(Protocol &protocol, void *address, size_t addressSize, int sockId, Server *server) {
+    if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false;
+
+    State s = state;
+    do {
+        if (s != STATE_NONE && s != STATE_FINISHED) return false;
+    } while(!state.compare_exchange_weak(s, STATE_OPENING));
+    if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield();
+
+    ReadLock lock(protocol.stateProtector);
+    Protocol::State ps = protocol.state;
+    if (ps < Protocol::STATE_OPEN || ps >= Protocol::STATE_CLOSE_REQ)
+        { state = STATE_NONE; return false; }
+
+    memcpy(this->address, address, addressSize);
+    this->addressSize = addressSize;
+    this->sockId = sockId;
+    this->server = server;
+
+    this->protocol = &protocol;
+    state = STATE_INITIALIZING;
+    if (!enqueued.exchange(true)) protocol.connQueue.push(*this);
+    protocol.wakeup();
+    return true;
+}
+
+
+void Connection::wantState(State state, bool error) {
+    ReadLock lock(stateProtector);
+    State s = this->state;
+    if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return;
+
+    State sw = stateWanted;
+    do {
+        if (sw <= STATE_OPENING || sw >= state) return;
+    } while(!stateWanted.compare_exchange_weak(sw, state));
+    if (error) { error = false; this->error.compare_exchange_strong(error, true); }
+    if (!enqueued.exchange(true)) protocol->connQueue.push(*this);
+    protocol->wakeup();
+}
+
+
+void Connection::closeReq()
+    { wantState(STATE_CLOSE_REQ, false); }
+void Connection::close(bool error)
+    { wantState(STATE_CLOSING, error); }
+
+
+void Connection::closeWait(unsigned long long timeoutUs, bool withReq, bool error) {
+    stateProtector.lock();
+    State s = state;
+    if (s <= STATE_OPENING || s >= STATE_FINISHED)
+        { stateProtector.unlock(); return; }
+
+    if (s == STATE_CLOSED) {
+        finishWaiters++;
+        stateProtector.unlock();
+    } else {
+        if (withReq && s < STATE_CLOSE_REQ) closeReq();
+        closeWaiters++;
+        stateProtector.unlock();
+
+        std::mutex fakeMutex;
+        {
+            std::unique_lock<std::mutex> fakeLock(fakeMutex);
+            if (timeoutUs && s < STATE_CLOSING) {
+                std::chrono::steady_clock::time_point t
+                    = std::chrono::steady_clock::now()
+                    + std::chrono::microseconds(timeoutUs);
+                while(s < STATE_CLOSED) {
+                    if (t <= std::chrono::steady_clock::now()) { close(error); break; }
+                    closeCondition.wait_until(fakeLock, t);
+                    s = state;
+                }
+            }
+            if (s < STATE_CLOSING)
+                close(error);
+            while(s < STATE_CLOSED) {
+                closeCondition.wait(fakeLock);
+                s = state;
+            }
+        }
+
+        finishWaiters++;
+        closeWaiters--;
+    }
+
+
+    while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; }
+
+    finishWaiters--;
+}
+
+
+bool Connection::read(Task &task) {
+    ReadLock lock(stateProtector);
+    State s = state;
+    if (s < STATE_OPEN || s >= STATE_CLOSING) return false;
+    readQueue.push(task);
+    protocol->updateEvents(*this);
+    return true;
+}
+
+
+bool Connection::write(Task &task) {
+    ReadLock lock(stateProtector);
+    State s = state;
+    if (s < STATE_OPEN || s >= STATE_CLOSING) return false;
+    writeQueue.push(task);
+    protocol->updateEvents(*this);
+    return true;
+}
+
+
+void Connection::onOpeningError() { }
+void Connection::onOpen(const void*, size_t) { }
+bool Connection::onReadReady(Task&, bool) { return true; }
+bool Connection::onWriteReady(Task&, bool) { return true; }
+void Connection::onCloseReqested() { }
+void Connection::onClose(bool) { }
+
diff --git a/connection.h b/connection.h
new file mode 100644
index 0000000..0ceee1c
--- /dev/null
+++ b/connection.h
@@ -0,0 +1,118 @@
+#ifndef CONNECTION_H
+#define CONNECTION_H
+
+
+#include <cassert>
+
+#include <atomic>
+#include <condition_variable>
+
+#include "utils.h"
+
+
+#define MAX_ADDR_SIZE 256
+
+
+class Protocol;
+class Server;
+
+
+class Root {
+public:
+    virtual ~Root() { }
+};
+
+
+class Connection: public Root {
+public:
+    enum State: int {
+        STATE_NONE,
+        STATE_OPENING,
+        STATE_INITIALIZING,
+        STATE_RESOLVING,
+        STATE_CONNECTING,
+        STATE_OPEN,
+        STATE_CLOSE_REQ,
+        STATE_CLOSING,
+        STATE_CLOSED,
+        STATE_FINISHED
+    };
+
+    struct Task {
+        void *data;
+        size_t size;
+        size_t completion;
+        bool watch;
+        Task *next;
+
+        inline explicit Task(
+            void *data = nullptr,
+            size_t size = 0,
+            size_t completion = 0,
+            bool watch = false,
+            Task *next = nullptr
+        ):
+            data(data), size(size), completion(completion), watch(watch), next(next)
+            { assert(!!data == !!size); }
+    };
+
+    typedef CounteredQueueMISO<Task, &Task::next> TaskQueue;
+
+private:
+    friend class Protocol;
+
+    std::atomic<State> state;
+    std::atomic<State> stateWanted;
+    ReadProtector stateProtector;
+    std::atomic<bool> enqueued;
+    std::atomic<bool> error;
+
+    std::atomic<int> closeWaiters;
+    std::atomic<int> finishWaiters;
+    std::condition_variable closeCondition;
+
+    unsigned char address[MAX_ADDR_SIZE];
+    size_t addressSize;
+
+    Protocol *protocol;
+    Server *server;
+    Connection *prev, *next;
+    Connection *srvPrev, *srvNext;
+    Connection *queueNext;
+    int sockId;
+    State stateLocal;
+
+    std::atomic<unsigned int> events;
+    TaskQueue readQueue;
+    TaskQueue writeQueue;
+
+    void wantState(State state, bool error);
+
+public:
+    typedef QueueMISO<Connection, &Connection::queueNext> Queue;
+    friend class QueueMISO<Connection, &Connection::queueNext>;
+    friend class ChainFuncs<Connection, &Connection::queueNext>;
+
+public:
+    Connection();
+    virtual ~Connection();
+
+    bool open(Protocol &protocol, void *address, size_t addressSize, int sockId = 0, Server *server = nullptr);
+    void closeReq();
+    void close(bool error);
+    void closeWait(unsigned long long timeoutUs = 0, bool withReq = true, bool error = true);
+
+    bool read(Task &task);
+    bool write(Task &task);
+
+protected:
+    virtual void onOpeningError();
+    virtual void onOpen(const void *address, size_t addressSize);
+    virtual bool onReadReady(Task &task, bool error);
+    virtual bool onWriteReady(Task &task, bool error);
+    virtual void onCloseReqested();
+    virtual void onClose(bool error);
+};
+
+
+#endif
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..4c4853b
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,9 @@
+
+#include "protocol.h"
+
+
+int main() {
+    return 0;
+}
+
+
diff --git a/protocol.cpp
b/protocol.cpp new file mode 100644 index 0000000..ea121ad --- /dev/null +++ b/protocol.cpp @@ -0,0 +1,563 @@ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "protocol.h" + + + +Protocol::~Protocol() { + assert(state == STATE_NONE || state == STATE_CLOSED); + closeWait(); + while(finishWaiters > 0) std::this_thread::yield(); + if (thread) thread->join(); +} + + +bool Protocol::open() { + State s = state; + do { + if (s != STATE_NONE && s != STATE_FINISHED) return false; + } while(!state.compare_exchange_weak(s, STATE_OPENING)); + if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield(); + if (thread) thread->join(); + + thread = new std::thread(&Protocol::threadRun, this); + return true; +} + + +void Protocol::wantState(State state) { + ReadLock lock(stateProtector); + State s = this->state; + if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return; + + State sw = stateWanted; + do { + if (sw <= STATE_OPENING || sw >= state) return; + } while(!stateWanted.compare_exchange_weak(sw, state)); + wakeup(); +} + + +void Protocol::closeReq() + { wantState(STATE_CLOSE_REQ); } +void Protocol::close() + { wantState(STATE_CLOSING); } + + +void Protocol::closeWait(unsigned long long timeoutUs, bool withReq) { + stateProtector.lock(); + State s = state; + if (s <= STATE_OPENING || s >= STATE_FINISHED) + { stateProtector.unlock(); return; } + + + if (s == STATE_CLOSED) { + finishWaiters++; + stateProtector.unlock(); + } else { + if (withReq && s < STATE_CLOSE_REQ) closeReq(); + closeWaiters++; + stateProtector.unlock(); + + std::mutex fakeMutex; + { + std::unique_lock fakeLock(fakeMutex); + if (timeoutUs && s < STATE_CLOSING) { + std::chrono::steady_clock::time_point t + = std::chrono::steady_clock::now() + + std::chrono::microseconds(timeoutUs); + while(s < STATE_CLOSED) { + if (t <= std::chrono::steady_clock::now()) { close(); break; } + closeCondition.wait_until(fakeLock, t); + s = state; + } + } + if (s < STATE_CLOSING) close(); + while(s < STATE_CLOSED) { + closeCondition.wait(fakeLock); + s = state; + } + } + + finishWaiters++; + closeWaiters--; + } + + + while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; } + + finishWaiters--; +} + + +void Protocol::wakeup() { + unsigned long long num = 0; + ::write(eventFd, &num, sizeof(num)); +} + + +void Protocol::updateEvents(Connection &connection) { + unsigned int events = 0; + if (connection.readQueue.count()) events |= EPOLLIN; + if (connection.writeQueue.count()) events |= EPOLLOUT; + if (connection.events.exchange(events) == events) return; + + struct epoll_event event = {}; + event.data.ptr = &connection; + event.events = events; + epoll_ctl(epollFd, EPOLL_CTL_MOD, connection.sockId, &event); +} + + +void Protocol::initSocket(int sockId) { + int tcp_nodelay = 1; + setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int)); + fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK); +} + + +void Protocol::threadConnQueue() { + while(connQueue.peek()) { + Connection &connection = connQueue.pop(); + connection.enqueued = false; + + if (connection.stateLocal == Connection::STATE_NONE) { + int sockId = connection.sockId; + bool connected = false; + + if (sockId >= 0) { + initSocket(sockId); + } else + if (connection.addressSize >= sizeof(sa_family_t)) { + struct sockaddr *addr = (struct sockaddr*)connection.address; + sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); + if (sockId >= 0) { + initSocket(sockId); + int 
res = ::connect(sockId, addr, (int)connection.addressSize); + if (res == 0) { + connected = true; + } else + if (res != EINPROGRESS) { + ::close(sockId); + sockId = -1; + } + } + } + + if (sockId >= 0) { + connection.protocol = this; + connection.prev = connLast; + assert(!connection.next); + (connection.prev ? connection.prev->next : connFirst) = &connection; + if (connection.server) { + connection.srvPrev = connection.server->connLast; + assert(!connection.srvNext); + (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = &connection; + } + connection.sockId = sockId; + + struct epoll_event event = {}; + event.data.ptr = &connection; + if (connected) { + epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); + connection.onOpen(connection.address, connection.addressSize); + connection.state = connection.stateLocal = Connection::STATE_OPEN; + connection.stateProtector.wait(); + } else { + event.events = EPOLLOUT; + epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); + connection.state = connection.stateLocal = Connection::STATE_CONNECTING; + connection.stateProtector.wait(); + } + } else { + connection.sockId = -1; + connection.server = nullptr; + connection.onOpeningError(); + connection.state = Connection::STATE_CLOSED; + while(connection.closeWaiters > 0) + { connection.closeCondition.notify_all(); std::this_thread::yield(); } + connection.state = Connection::STATE_FINISHED; + } + } else { + Connection::State stateWanted = connection.stateWanted; + + if ( connection.stateLocal >= Connection::STATE_CONNECTING + && connection.stateLocal <= Connection::STATE_OPEN + && stateWanted == Connection::STATE_CLOSE_REQ ) + { + connection.state = connection.stateLocal = Connection::STATE_CLOSE_REQ; + connection.stateProtector.wait(); + connection.onCloseReqested(); + } else + if ( connection.stateLocal >= Connection::STATE_CONNECTING + && connection.stateLocal <= Connection::STATE_CLOSE_REQ + && stateWanted == Connection::STATE_CLOSING ) + { + connection.state = connection.stateLocal = Connection::STATE_CLOSING; + connection.stateProtector.wait(); + while(Connection::Task *task = connection.readQueue.peek()) { + connection.readQueue.pop(); + if (!connection.onReadReady(*task, true)) + assert(false); + } + while(Connection::Task *task = connection.writeQueue.peek()) { + connection.writeQueue.pop(); + if (!connection.onWriteReady(*task, true)) + assert(false); + } + connection.onClose(false); + } + + if (connection.stateLocal == Connection::STATE_CLOSING && !connection.enqueued) { + if (connection.server) { + (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext; + (connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev; + connection.server = nullptr; + } + (connection.prev ? connection.prev->next : connFirst) = connection.next; + (connection.next ? 
connection.next->prev : connLast) = connection.prev; + connection.sockId = -1; + connection.protocol = nullptr; + memset(connection.address, 0, sizeof(connection.address)); + connection.addressSize = 0; + connection.stateWanted = Connection::STATE_NONE; + connection.stateLocal = Connection::STATE_NONE; + connection.state = Connection::STATE_CLOSED; + while(connection.closeWaiters > 0) + { connection.closeCondition.notify_all(); std::this_thread::yield(); } + connection.state = Connection::STATE_FINISHED; + } + } + } +} + + +void Protocol::threadConnEvents(Connection &connection, unsigned int events) { + bool needUpdate = false; + + if (connection.stateLocal == Connection::STATE_CONNECTING) { + if (events & EPOLLHUP) { + connection.onOpeningError(); + if (connection.server) { + (connection.srvPrev ? connection.srvPrev->srvNext : connection.server->connFirst) = connection.srvNext; + (connection.srvNext ? connection.srvNext->srvPrev : connection.server->connLast) = connection.srvPrev; + connection.server = nullptr; + } + (connection.prev ? connection.prev->next : connFirst) = connection.next; + (connection.next ? connection.next->prev : connLast) = connection.prev; + connection.sockId = -1; + connection.protocol = nullptr; + memset(connection.address, 0, sizeof(connection.address)); + connection.addressSize = 0; + connection.stateWanted = Connection::STATE_NONE; + connection.stateLocal = Connection::STATE_NONE; + connection.state = Connection::STATE_CLOSED; + while(connection.closeWaiters > 0) + { connection.closeCondition.notify_all(); std::this_thread::yield(); } + connection.state = Connection::STATE_FINISHED; + } + + if (events & EPOLLOUT) { + connection.onOpen(connection.address, connection.addressSize); + connection.state = connection.stateLocal = Connection::STATE_OPEN; + connection.stateProtector.wait(); + needUpdate = true; + } + } + + if ( connection.stateLocal < Connection::STATE_OPEN + || connection.stateLocal >= Connection::STATE_CLOSING ) + return; + + if (events & EPOLLIN) { + bool ready = true; + int pops = 0; + while(ready) { + Connection::Task *task = connection.readQueue.peek(); + if (!task) { + if (pops) needUpdate = true; + break; + } + size_t completion = 0; + while(task->size > task->completion) { + size_t size = task->size - task->completion; + if (size > INT_MAX) size = INT_MAX; + int res = ::recv( + connection.sockId, + (unsigned char*)task->data + task->completion, + (int)size, MSG_DONTWAIT ); + if (res == EAGAIN) { + ready = false; + break; + } else + if (res <= 0) { + connection.close(true); + return; + } else { + completion += size; + task->completion += size; + } + } + + if (task->size <= task->completion || (completion && task->watch)) { + connection.readQueue.pop(); + if (!connection.onReadReady(*task, false)) + connection.readQueue.unpop(*task); else ++pops; + } + } + } + + if (events & EPOLLHUP) { + connection.close((bool)(events & EPOLLERR)); + return; + } + + if (events & EPOLLOUT) { + bool ready = true; + int pops = 0; + while(ready) { + Connection::Task *task = connection.writeQueue.peek(); + if (!task) { + if (pops) needUpdate = true; + break; + } + size_t completion = 0; + while(task->size > task->completion) { + size_t size = task->size - task->completion; + if (size > INT_MAX) size = INT_MAX; + int res = ::send( + connection.sockId, + (const unsigned char*)task->data + task->completion, + (int)size, MSG_DONTWAIT ); + if (res == EAGAIN) { + ready = false; + break; + } else + if (res <= 0) { + connection.close(true); + return; + } else { + completion 
+= size; + task->completion += size; + } + } + + if (task->size <= task->completion || (completion && task->watch)) { + connection.writeQueue.pop(); + if (!connection.onWriteReady(*task, false)) + connection.writeQueue.unpop(*task); else ++pops; + } + } + } + + if (needUpdate) updateEvents(connection); +} + + +void Protocol::threadSrvQueue() { + while(srvQueue.peek()) { + Server &server = srvQueue.pop(); + server.enqueued = false; + + if (server.stateLocal == Server::STATE_NONE) { + int sockId = -1; + if (server.addressSize >= sizeof(sa_family_t)) { + struct sockaddr *addr = (struct sockaddr*)server.address; + sockId = ::socket(addr->sa_family, SOCK_STREAM, 0); + if (sockId >= 0) { + initSocket(sockId); + if ( 0 != ::bind(sockId, addr, (int)server.addressSize) + || 0 != ::listen(sockId, 128) ) + { + ::close(sockId); + sockId = -1; + } + } + } + + if (sockId >= 0) { + server.protocol = this; + server.prev = srvLast; + assert(!server.next); + (server.prev ? server.prev->next : srvFirst) = &server; + server.sockId = sockId; + + struct epoll_event event = {}; + event.data.ptr = &server; + event.events = EPOLLIN; + epoll_ctl(epollFd, EPOLL_CTL_ADD, sockId, &event); + + server.onOpen(); + server.state = server.stateLocal = Server::STATE_OPEN; + server.stateProtector.wait(); + } else { + server.onOpeningError(); + server.state = Server::STATE_CLOSED; + while(server.closeWaiters > 0) + { server.closeCondition.notify_all(); std::this_thread::yield(); } + server.state = Server::STATE_FINISHED; + } + } else { + Server::State stateWanted = server.stateWanted; + if ( server.stateLocal == Server::STATE_OPEN + && stateWanted == Server::STATE_CLOSE_REQ ) + { + for(Connection *conn = server.connFirst; conn; conn = conn->srvNext) + conn->closeReq(); + struct epoll_event event = {}; + epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); + server.state = server.stateLocal = Server::STATE_CLOSE_REQ; + server.stateProtector.wait(); + server.onCloseReqested(); + } else + if ( ( server.stateLocal == Server::STATE_OPEN + || server.stateLocal == Server::STATE_CLOSE_REQ ) + && stateWanted == Server::STATE_CLOSING_CONNECTIONS ) + { + if (server.stateLocal != Server::STATE_CLOSE_REQ) { + struct epoll_event event = {}; + epoll_ctl(epollFd, EPOLL_CTL_DEL, server.sockId, &event); + } + if (server.connFirst) { + for(Connection *conn = server.connFirst; conn; conn = conn->srvNext) + conn->close(true); + server.state = server.stateLocal = Server::STATE_CLOSING_CONNECTIONS; + } else { + server.state = server.stateLocal = Server::STATE_CLOSING; + server.stateProtector.wait(); + server.onClose(false); + } + } + + if (server.stateLocal == Server::STATE_CLOSING && !server.enqueued) { + server.state = Server::STATE_CLOSED; + (server.prev ? server.prev->next : srvFirst) = server.next; + (server.next ? 
server.next->prev : srvLast) = server.prev; + server.protocol = nullptr; + memset(server.address, 0, sizeof(server.address)); + server.addressSize = 0; + server.stateWanted = Server::STATE_NONE; + server.stateLocal = Server::STATE_NONE; + while(server.closeWaiters > 0) + { server.closeCondition.notify_all(); std::this_thread::yield(); } + server.state = Server::STATE_FINISHED; + } + } + } +} + + +void Protocol::threadSrvEvents(Server &server, unsigned int events) { + if (events & EPOLLHUP) { + if ( server.stateLocal >= Server::STATE_OPEN + && server.stateLocal < Server::STATE_CLOSING ) + server.close((bool)(events & EPOLLERR)); + return; + } + + if (events & EPOLLIN) { + if ( server.stateLocal >= Server::STATE_OPEN + && server.stateLocal < Server::STATE_CLOSE_REQ) + { + unsigned char address[MAX_ADDR_SIZE] = {}; + socklen_t len = sizeof(address); + int res = ::accept(server.sockId, (struct sockaddr*)address, &len); + assert(len < sizeof(address)); + if (res >= 0) { + if (Connection *conn = server.onConnect(address, len)) + if (!conn->open(*this, address, len, res)) + server.onDisconnect(conn, true); + } else + if (res != EAGAIN) { + server.close(true); + return; + } + } + } +} + + +void Protocol::threadRun() { + epollFd = ::epoll_create(32); + eventFd = ::eventfd(0, EFD_NONBLOCK); + + struct epoll_event event = {}; + event.events = EPOLLIN; + epoll_ctl(epollFd, EPOLL_CTL_ADD, eventFd, &event); + + const int maxCount = 16; + struct epoll_event events[maxCount] = {}; + int count = 0; + + State stateLocal = STATE_OPEN; + state = stateLocal; + + while(true) { + for(int i = 0; i < count; ++i) { + Root *pointer = (Root*)events[i].data.ptr; + if (!pointer) { + unsigned long long num = 0; + while(::read(eventFd, &num, sizeof(num)) > 0 && !num); + } else + if (Connection *connection = dynamic_cast(pointer)) { + threadConnEvents(*connection, events[i].events); + } else + if (Server *server = dynamic_cast(pointer)) { + threadSrvEvents(*server, events[i].events); + } + } + + State sw = stateWanted; + if (stateLocal == STATE_OPEN && sw == STATE_CLOSE_REQ) { + state = stateLocal = STATE_CLOSE_REQ; + stateProtector.wait(); + for(Server *server = srvFirst; server; server = server->next) + server->closeReq(); + for(Connection *conn = connFirst; conn; conn = conn->next) + if (!conn->server) + conn->closeReq(); + } else + if (stateLocal >= STATE_OPEN && stateLocal <= STATE_CLOSE_REQ && sw == STATE_CLOSING) { + state = STATE_CLOSING; + stateProtector.wait(); + for(Server *server = srvFirst; server; server = server->next) + server->close(true); + for(Connection *conn = connFirst; conn; conn = conn->next) + if (!conn->server) + conn->close(true); + } + + while(connQueue.peek() || srvQueue.peek()) { + threadConnQueue(); + threadSrvQueue(); + } + + if (stateLocal == STATE_CLOSING) { + assert(!srvFirst && !connFirst); + break; + } + + count = epoll_wait(epollFd, events, maxCount, 10000); + } + + while(closeWaiters > 0) + { closeCondition.notify_all(); std::this_thread::yield(); } + + ::close(epollFd); + ::close(eventFd); + + state = STATE_FINISHED; +} + diff --git a/protocol.h b/protocol.h new file mode 100644 index 0000000..3f115b1 --- /dev/null +++ b/protocol.h @@ -0,0 +1,70 @@ +#ifndef PROTOCOL_H +#define PROTOCOL_H + + +#include + +#include "utils.h" +#include "connection.h" +#include "server.h" + + + +class Protocol { +public: + enum State { + STATE_NONE, + STATE_OPENING, + STATE_OPEN, + STATE_CLOSE_REQ, + STATE_CLOSING, + STATE_CLOSED, + STATE_FINISHED + }; + +private: + friend class Server; + friend class 
Connection; + + std::thread *thread; + int epollFd; + int eventFd; + + std::atomic state; + std::atomic stateWanted; + ReadProtector stateProtector; + + std::atomic closeWaiters; + std::atomic finishWaiters; + std::condition_variable closeCondition; + + Connection *connFirst, *connLast; + Connection::Queue connQueue; + + Server *srvFirst, *srvLast; + Server::Queue srvQueue; + + void wakeup(); + void updateEvents(Connection &connection); + void initSocket(int sockId); + + void wantState(State state); + + void threadConnQueue(); + void threadConnEvents(Connection &connection, unsigned int events); + void threadSrvQueue(); + void threadSrvEvents(Server &server, unsigned int events); + void threadRun(); + +public: + Protocol(); + virtual ~Protocol(); + + bool open(); + void closeReq(); + void close(); + void closeWait(unsigned long long timeoutUs = 0, bool withReq = true); +}; + + +#endif diff --git a/server.cpp b/server.cpp new file mode 100644 index 0000000..af80b4b --- /dev/null +++ b/server.cpp @@ -0,0 +1,131 @@ + +#include + +#include +#include +#include + +#include "protocol.h" +#include "server.h" + + +Server::Server(): + state(STATE_NONE), + stateWanted(STATE_NONE), + enqueued(false), + address(), + addressSize(), + protocol(), + prev(), next(), + queueNext(), + connFirst(), connLast(), + sockId(-1), + stateLocal() + { } + + +Server::~Server() { + assert(state == STATE_NONE || state == STATE_CLOSED); + closeWait(); + while(finishWaiters > 0) std::this_thread::yield(); +} + + +bool Server::open(Protocol &protocol, void *address, size_t addressSize) { + if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false; + + State s = state; + do { + if (s != STATE_NONE && s != STATE_FINISHED) return false; + } while(!state.compare_exchange_weak(s, STATE_OPENING)); + if (s == STATE_FINISHED) while(finishWaiters > 0) std::this_thread::yield(); + + ReadLock lock(protocol.stateProtector); + Protocol::State ps = protocol.state; + if (ps < Protocol::STATE_OPEN || ps >= Protocol::STATE_CLOSE_REQ) + { state = STATE_NONE; return false; } + + memcpy(this->address, address, addressSize); + this->addressSize = addressSize; + + this->protocol = &protocol; + state = STATE_INITIALIZING; + if (!enqueued.exchange(true)) protocol.srvQueue.push(*this); + protocol.wakeup(); + return true; +} + + +void Server::wantState(State state, bool error) { + ReadLock lock(stateProtector); + State s = this->state; + if (s <= STATE_OPENING || s >= STATE_CLOSING_CONNECTIONS || s >= state) return; + + State sw = stateWanted; + do { + if (sw <= STATE_OPENING || sw >= state) return; + } while(!stateWanted.compare_exchange_weak(sw, state)); + if (error) { error = false; this->error.compare_exchange_strong(error, true); } + if (!enqueued.exchange(true)) protocol->srvQueue.push(*this); + protocol->wakeup(); +} + + +void Server::closeReq() + { wantState(STATE_CLOSE_REQ, false); } +void Server::close(bool error) + { wantState(STATE_CLOSING_CONNECTIONS, error); } + + +void Server::closeWait(unsigned long long timeoutUs, bool withReq) { + stateProtector.lock(); + State s = state; + if (s <= STATE_OPENING || s >= STATE_FINISHED) + { stateProtector.unlock(); return; } + + if (s == STATE_CLOSED) { + finishWaiters++; + stateProtector.unlock(); + } else { + if (withReq && s < STATE_CLOSE_REQ) closeReq(); + closeWaiters++; + stateProtector.unlock(); + + std::mutex fakeMutex; + { + std::unique_lock fakeLock(fakeMutex); + if (timeoutUs && s < STATE_CLOSING) { + std::chrono::steady_clock::time_point t + = 
std::chrono::steady_clock::now() + + std::chrono::microseconds(timeoutUs); + while(s < STATE_CLOSED) { + if (t <= std::chrono::steady_clock::now()) { close(true); break; } + closeCondition.wait_until(fakeLock, t); + s = state; + } + } + if (s < STATE_CLOSING) close(true); + while(s < STATE_CLOSED) { + closeCondition.wait(fakeLock); + s = state; + } + } + + finishWaiters++; + closeWaiters--; + } + + + while(s != STATE_FINISHED) { std::this_thread::yield(); s = state; } + + finishWaiters--; +} + + +void Server::onOpeningError() { } +void Server::onOpen() { } +Connection* Server::onConnect(void*, size_t) { return nullptr; } +void Server::onDisconnect(Connection*, bool) { } +void Server::onCloseReqested() { } +void Server::onClose(bool) { } + diff --git a/server.h b/server.h new file mode 100644 index 0000000..4742c27 --- /dev/null +++ b/server.h @@ -0,0 +1,71 @@ +#ifndef SERVER_H +#define SERVER_H + + +#include "connection.h" + + +class Server: public Root { + enum State: int { + STATE_NONE, + STATE_OPENING, + STATE_INITIALIZING, + STATE_RESOLVING, + STATE_OPEN, + STATE_CLOSE_REQ, + STATE_CLOSING_CONNECTIONS, + STATE_CLOSING, + STATE_CLOSED, + STATE_FINISHED + }; + +private: + friend class Protocol; + + std::atomic state; + std::atomic stateWanted; + ReadProtector stateProtector; + std::atomic enqueued; + std::atomic error; + + std::atomic closeWaiters; + std::atomic finishWaiters; + std::condition_variable closeCondition; + + unsigned char address[MAX_ADDR_SIZE]; + size_t addressSize; + + Protocol *protocol; + Server *prev, *next; + Server *queueNext; + Connection *connFirst, *connLast; + int sockId; + State stateLocal; + + void wantState(State state, bool error); + +public: + typedef QueueMISO Queue; + friend class QueueMISO; + friend class ChainFuncs; + +public: + Server(); + ~Server(); + + bool open(Protocol &protocol, void *address, size_t addressSize); + void closeReq(); + void close(bool error); + void closeWait(unsigned long long timeoutUs = 0, bool withReq = true); + +protected: + virtual void onOpeningError(); + virtual void onOpen(); + virtual Connection* onConnect(void *address, size_t addressSize); + virtual void onDisconnect(Connection *connection, bool error); + virtual void onCloseReqested(); + virtual void onClose(bool error); +}; + + +#endif diff --git a/tcplab.cpp1 b/tcplab.cpp1 new file mode 100644 index 0000000..b14b76c --- /dev/null +++ b/tcplab.cpp1 @@ -0,0 +1,496 @@ + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +enum WorkMode { + WM_STOPPED, + WM_STARTING, + WM_STARTED, + WM_STOPPING }; + + +template +class Chain { +private: + Chain() { } +public: + typedef T Type; + static constexpr Type* Type::*Next = NextField; + + static inline void push(Type* &first, Type *item) + { assert(!item.*Next); item.*Next = first; first = &item; } + static inline void push(std::atomic &first, Type &item) + { assert(!item.*Next); item.*Next = first; while(!first.compare_exchange_weak(item.*Next, &item)); } + + static inline Type* pop(Type* &first) { + Type *item = first; + if (!item) return nullptr; + first = item->next; + item->next = nullptr; + return item; + } + static inline Type* pop(std::atomic &first) { + Type *item = first; + do { + if (!item) return nullptr; + } while(!first.compare_exchange_weak(item, item->*Next)); + item->next = nullptr; + return item; + } + + static inline void invert(Type* &first) { + if (Item *curr = first) { + Item *prev = nullptr; + 
while(Item *next = curr->*Next) + { curr->*Next = prev; prev = curr; curr = next; } + first = curr; + } + } + static inline Type* inverted(Type *first) + { invert(first); return first; } +}; + + +// multhreaded input, single threaded output +// simultaneus read and write are allowed +template +class QueueMISO { +public: + typedef T Type; + typedef Chain Chain; + static constexpr Type* Type::*Next = NextField; + +private: + std::atomic pushed; + Type *first; + std::atomic counter; + +public: + inline QueueMISO(): + pushed(nullptr), first() { } + inline ~QueueMISO() + { assert(!pushed && !first && !counter); } + inline void push(Type &item) + { Chain::push(pushed, item); counter++; } + inline Type& pop() + { assert(first); Type &item = *Chain::pop(first); counter++; return item; } + inline Type *peek() + { return first ? first : (first = Chain::inverted(pushed.exchange(nullptr))); } + inline int count() const + { return counter; } +}; + + + +class AMutex { +private: + std::atomic locked; +public: + inline AMutex(): locked(false) { } + inline ~AMutex() { assert(!locked); } + + inline bool try_lock() + { bool l = false; return locked.compare_exchange_strong(l, true); } + inline void lock() + { bool l = false; while(!locked.compare_exchange_weak(l, true)); } + inline void unlock() + { locked = false; } +}; + + +class ARWMutex { +private: + std::atomic counter; +public: + inline ARWMutex(): counter(0) { } + inline ~ARWMutex() { assert(!counter); } + + inline bool try_lock() + { int cnt = 0; return counter.compare_exchange_strong(cnt, -1); } + inline void lock() + { int cnt = 0; while(!counter.compare_exchange_weak(cnt, -1)) std::this_thread::yield(); } + inline void unlock() + { assert(counter < 0); counter = 0; } + + inline bool try_lock_read() { + int cnt = counter; + do { if (cnt < 0) return false; } while (counter.compare_exchange_weak(cnt, cnt + 1)); + return true; + } + inline void lock_read() + { int cnt = counter; while(cnt < 0 || !counter.compare_exchange_strong(cnt, cnt + 1)); } + inline void unlock_read() + { assert(counter >= 0); counter--; } +}; + + +typedef std::lock_guard ALock; +typedef std::lock_guard AWLock; + +class ARLock { +private: + ARWMutex &mutex; +public: + inline ARLock(ARWMutex &mutex): mutex(mutex) { lock(); } + inline ~ARLock() { unlock(); } + inline void lock() { mutex.lock_read(); } + inline void unlock() { mutex.unlock_read(); } +}; + + + +typedef std::lock_guard Lock; +typedef std::unique_lock UniqLock; + + +class TcpProtocol; +class Server; +class Connection; + + + + +class TcpSocket { +public: + TcpProtocol &protocol; + int sockId; + Server *server; + Connection *connection; + + explicit TcpSocket(TcpProtocol &protocol, int sockId); + ~TcpSocket(); + void closeReq(); + void update(bool read, bool write); +}; + + +class Connection { +public: + enum State { + STATE_NONE, + STATE_CONNECTING, + STATE_CONNECTED, + STATE_CLOSE_HINT, + STATE_CLOSE_REQ, + STATE_CLOSED, + }; + + struct Task { + unsigned char *data; + int size; + int position; + Task *next; + + explicit Task(void *data = nullptr, int size = 0): + data((unsigned char*)data), size(size), position() { } + }; + typedef QueueMISO Queue; + + + std::atomic stateReaders; // count of threads that currently reads the state + std::atomic state; // write from protocol thread, read from others + std::atomic stateWanted; // read from protocol thread, write from others + State stateQuickCopy; // for use in protocol thread only + std::atomic lastEvents; + + TcpProtocol *protocol; + int sockId; + Queue readQueue; + 
Queue writeQueue; + + // private + void updateEvents() { + assert(stateReaders > 0); // << must be locked + + unsigned int events = 0; + if (readQueue.count()) events |= EPOLLIN; + if (writeQueue.count()) events |= EPOLLOUT; + if (lastEvents.exchange(events) == events) return; + + struct epoll_event event = {}; + event.data.ptr = this; + event.events = events; + epoll_ctl(protocol->pollId, EPOLL_CTL_MOD, sockId, &event); + } + + // public functions + void connect(TcpProtocol &protocol, const void *address, size_t size); + void accept(TcpProtocol &protocol, int sockId); + void closeHint() { wantState(STATE_CLOSE_HINT); } + void closeReq() { wantState(STATE_CLOSED); } + + void closeWait(unsigned long long timeoutUs); + + void wantState(State state) { + State sw = stateWanted; + while(sw < state && !stateWanted.compare_exchange_weak(sw, state)); + } + + bool send(Task &task) { + stateReaders++; + State s = state; + if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false; + writeQueue.push(task); + updateEvents(); + stateReaders--; + return true; + } + + bool receive(Task &task) { + stateReaders++; + State s = state; + if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false; + readQueue.push(task); + updateEvents(); + stateReaders--; + return true; + } + + + // handlers, must be called from protocol thread + virtual void onOpen() { } + virtual void onClose(bool) { } + virtual void onReadReady(const Task&) { } + virtual void onWriteReady(const Task&) { } + + + std::mutex mutex; + WorkMode mode; + TcpSocket *socket; + + + // must be called from Protocol thread only + bool open(TcpSocket *socket) { + { + Lock lock(mutex); + if (mode != WM_STOPPED) return false; + mode = WM_STARTING; + this->socket = socket; + } + + onOpen(); + + { + Lock lock(mutex); + mode = WM_STARTED; + } + + return true; + } + + void close(bool err) { + { + Lock lock(mutex); + assert(mode == WM_STARTED); + mode = WM_STOPPING; + if (!err && (!readQueue.empty() || !writeQueue.empty())) + err = true; + } + + onClose(err); + + { + socket = nullptr; + } + } + + // must be called internally + + void closeReq() { + Lock lock(mutex); + if (mode != WM_STARTING && mode != WM_STARTED) return; + socket->closeReq(); + } + + void readReq(void *data, int size) { + assert(data && size > 0); + Lock lock(mutex); + if (mode != WM_STARTING && mode != WM_STARTED) return; + readQueue.emplace_back(data, size); + socket->update(!readQueue.empty(), !writeQueue.empty()); + } + + void writeReq(void *data, int size) { + assert(data && size > 0); + Lock lock(mutex); + if (mode != WM_STARTING && mode != WM_STARTED) return; + writeQueue.emplace_back(data, size); + socket->update(!readQueue.empty(), !writeQueue.empty()); + } +}; + + + +class TcpProtocol { +private: + friend class TcpSocket; + + std::mutex mutex; + std::thread *thread; + std::condition_variable condModeChanged; + WorkMode mode; + int epollId; + int eventsId; + + std::vector socketsToClose; + + // lock mutex befor call + void wakeup() { + unsigned long long num = 0; + ::write(eventsId, &num, sizeof(num)); + } + + void threadRun() { + const int maxCount = 16; + struct epoll_event events[maxCount] = {}; + int count = 0; + std::vector socketsToClose; + + { + Lock lock(mutex); + assert(mode == WM_STARTING); + mode = WM_STARTED; + condModeChanged.notify_all(); + } + + while(true) { + for(int i = 0; i < count; ++i) { + unsigned int e = events[i].events; + TcpSocket *socket = (TcpSocket*)events[i].data.ptr; + if (!socket) { + unsigned long long num = 0; + while(::read(eventsId, &num, 
sizeof(num)) > 0); + } else + if (socket->connection) { + // lock connection + //if (connection-> + } else + if (socket->server) { + // lock server + } else { + delete socket; + } + } + + + { + Lock lock(mutex); + if (mode == WM_STOPPING) { + socketsToClose + } + } + + count = epoll_wait(epollId, events, maxCount, 10000); + } + } + +public: + TcpProtocol(): + thread(), + mode(WM_STOPPED), + epollId(-1), + eventsId(-1) + { } + + ~TcpProtocol() + { stop(); } + + bool start() { + { + Lock lock(mutex); + if (mode == WM_STARTED || mode == WM_STARTING) return true; + if (mode != WM_STOPPED) return false; + mode = WM_STARTING; + } + + condModeChanged.notify_all(); + + epollId = ::epoll_create(32); + eventsId = ::eventfd(0, EFD_NONBLOCK); + assert(epollId >= 0 && eventsId >= 0); + + struct epoll_event event = {}; + event.events = EPOLLIN; + epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event); + + thread = new std::thread(&TcpProtocol::threadRun, this); + } + + void stop() { + { + UniqLock lock(mutex); + while(true) { + if (mode == WM_STOPPED) return; + if (mode == WM_STARTED) break; + condModeChanged.wait(lock); + } + mode = WM_STOPPING; + } + + condModeChanged.notify_all(); + thread->join(); + delete thread; + thread = nullptr; + ::close(epollId); + ::close(eventsId); + + { + Lock lock(mutex); + mode = WM_STOPPED; + } + condModeChanged.notify_all(); + } +}; + + + +TcpSocket::TcpSocket(TcpProtocol &protocol, int sockId): + protocol(protocol), + sockId(sockId) +{ + int tcp_nodelay = 1; + setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(int)); + fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK); + struct epoll_event event = {}; + event.data.ptr = this; + epoll_ctl(protocol.epollId, EPOLL_CTL_ADD, sockId, &event); +} + +TcpSocket::~TcpSocket() { + struct epoll_event event = {}; + epoll_ctl(protocol.epollId, EPOLL_CTL_DEL, sockId, &event); + ::close(sockId); +} + +void TcpSocket::closeReq() { + Lock lock(protocol.mutex); + socketsToClose.push_back(this); + protocol.wakeup(); +} + +void TcpSocket::update(bool read, bool write) { + struct epoll_event event = {}; + event.data.ptr = this; + if (read) event.events |= EPOLLIN; + if (write) event.events |= EPOLLOUT; + epoll_ctl(protocol.epollId, EPOLL_CTL_MOD, sockId, &event); +} + + diff --git a/utils.h b/utils.h new file mode 100644 index 0000000..e759c1e --- /dev/null +++ b/utils.h @@ -0,0 +1,135 @@ +#ifndef UTILS_H +#define UTILS_H + + +#include + +#include +#include + + +class ReadProtector { +private: + std::atomic enters; + std::atomic exits; + +public: + inline ReadProtector(): enters(0), exits(0) { } + inline ~ReadProtector() { assert(!(enters - exits)); } + inline void lock() { enters++; } + inline void unlock() { exits++; } + inline void wait() const { + int en = enters; + while(en - exits > 0) std::this_thread::yield(); + } +}; + + +class ReadLock { +private: + ReadProtector &protector; + ReadLock(const ReadLock&) = delete; +public: + inline explicit ReadLock(ReadProtector &protector): protector(protector) { protector.lock(); } + inline ~ReadLock() { protector.unlock(); } +}; + + + + +template +class ChainFuncs { +private: + ChainFuncs() { } + +public: + typedef T Type; + static constexpr Type* Type::*Next = NextField; + + static inline void push(Type* &first, Type &item) + { assert(!(item.*Next)); item.*Next = first; first = &item; } + static inline void push(std::atomic &first, Type &item) + { assert(!(item.*Next)); item.*Next = first; while(!first.compare_exchange_weak(item.*Next, &item)); } + + static inline 
Type* pop(Type* &first) { + Type *item = first; + if (!item) return nullptr; + first = item->next; + item->next = nullptr; + return item; + } + static inline Type* pop(std::atomic &first) { + Type *item = first; + do { + if (!item) return nullptr; + } while(!first.compare_exchange_weak(item, item->*Next)); + item->next = nullptr; + return item; + } + + static inline void invert(Type* &first) { + if (Type *curr = first) { + Type *prev = nullptr; + while(Type *next = curr->*Next) + { curr->*Next = prev; prev = curr; curr = next; } + first = curr; + } + } + static inline Type* inverted(Type *first) + { invert(first); return first; } +}; + + +// multhreaded input, single threaded output +// simultaneus read and write are allowed +template +class QueueMISO { +public: + typedef T Type; + typedef ChainFuncs Funcs; + static constexpr Type* Type::*Next = NextField; + +private: + std::atomic pushed; + Type *first; + +public: + inline QueueMISO(): + pushed(nullptr), first() { } + inline ~QueueMISO() + { assert(!pushed && !first); } + inline void push(Type &item) + { Funcs::push(pushed, item); } + inline Type& pop() + { assert(first); return *Funcs::pop(first); } + inline void unpop(Type &item) + { return Funcs::push(first, item); } + inline Type *peek() + { return first ? first : (first = Funcs::inverted(pushed.exchange(nullptr))); } +}; + + +template +class CounteredQueueMISO: public QueueMISO { +public: + typedef T Type; + typedef QueueMISO Queue; + +private: + std::atomic counter; + +public: + inline CounteredQueueMISO(): + counter(0) { } + inline ~CounteredQueueMISO() + { assert(!counter); } + inline void push(Type &item) + { Queue::push(item); counter++; } + inline Type& pop() + { Type &item = Queue::pop(); counter++; return item; } + inline int count() const + { return counter; } +}; + + +#endif
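
For reference, the empty main() in main.cpp could eventually drive this API along the lines below. This is an illustrative sketch only, not part of the patch; the LogConnection class, the loopback address and port 9000 are placeholder assumptions.

#include <arpa/inet.h>
#include <netinet/in.h>

#include "protocol.h"

// Hypothetical subclass: only observes the callbacks made by the protocol thread.
class LogConnection: public Connection {
protected:
    void onOpen(const void*, size_t) override { /* socket connected */ }
    void onOpeningError() override            { /* connect() failed */ }
    void onClose(bool error) override         { /* closed; error flag set by the event loop */ }
};

int main() {
    Protocol protocol;
    if (!protocol.open()) return 1;                  // starts the epoll thread

    sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(9000);                     // placeholder port
    addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

    LogConnection conn;
    // sockId = -1 asks the protocol thread to create and connect the socket itself
    conn.open(protocol, &addr, sizeof(addr), -1);

    conn.closeWait(1000000);                         // request close, wait up to 1 s, then force it
    protocol.closeWait();                            // wait for the protocol thread to shut down
    return 0;
}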