#include <algorithm>
#include <atomic>
#include <cassert>
#include <climits>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
// Start/stop lifecycle shared by TcpProtocol and Connection:
// STOPPED -> STARTING -> STARTED -> STOPPING (-> STOPPED again).
enum WorkMode {
    WM_STOPPED,   // not running
    WM_STARTING,  // start requested, not yet running
    WM_STARTED,   // running
    WM_STOPPING   // stop requested, shutdown in progress
};
/// Static helpers for an intrusive singly-linked list whose link is the
/// `T*` member designated by NextField. Not instantiable.
template<typename T, T* T::*NextField>
class Chain {
private:
    Chain() { }
public:
    typedef T Type;
    static constexpr Type* Type::*Next = NextField;

    /// Pushes `item` onto the front of the plain list `first`.
    /// `item` must not be linked anywhere (its Next field must be null).
    static inline void push(Type* &first, Type &item)
        { assert(!(item.*Next)); item.*Next = first; first = &item; }

    /// Pointer form of push(); `item` must be non-null.
    static inline void push(Type* &first, Type *item)
        { assert(item); push(first, *item); }

    /// Lock-free push onto an atomic list head; may run concurrently with
    /// other pushes and with exchange() of the head.
    static inline void push(std::atomic<Type*> &first, Type &item) {
        assert(!(item.*Next));
        item.*Next = first;
        // On failure the CAS reloads the current head into item.*Next.
        while (!first.compare_exchange_weak(item.*Next, &item)) { }
    }

    /// Detaches and returns the front of the plain list, or nullptr if empty.
    static inline Type* pop(Type* &first) {
        Type *item = first;
        if (!item) return nullptr;
        first = item->*Next;
        item->*Next = nullptr;
        return item;
    }

    /// Lock-free pop from an atomic list head, or nullptr if empty.
    /// NOTE(review): classic Treiber pop; subject to ABA if popped items can
    /// be re-pushed while another thread is popping.
    static inline Type* pop(std::atomic<Type*> &first) {
        Type *item = first;
        do {
            if (!item) return nullptr;
        } while (!first.compare_exchange_weak(item, item->*Next));
        item->*Next = nullptr;
        return item;
    }

    /// Reverses the plain list in place.
    static inline void invert(Type* &first) {
        Type *prev = nullptr;
        Type *curr = first;
        while (curr) {
            Type *next = curr->*Next;
            curr->*Next = prev;
            prev = curr;
            curr = next;
        }
        first = prev;
    }

    /// Returns the reversed list.
    static inline Type* inverted(Type *first)
        { invert(first); return first; }
};

// Multiple-producer / single-consumer intrusive FIFO queue.
// push() may be called from any number of threads at once; peek()/pop()
// must be called from a single consumer thread, concurrently with pushes.
template<typename T, T* T::*NextField>
class QueueMISO {
public:
    typedef T Type;
    typedef ::Chain<Type, NextField> Chain;
    static constexpr Type* Type::*Next = NextField;
private:
    std::atomic<Type*> pushed;   // producers' LIFO stack (newest first)
    Type *first;                 // consumer's FIFO list (oldest first)
    std::atomic<int> counter;    // number of queued items
public:
    inline QueueMISO():
        pushed(nullptr), first(nullptr), counter(0) { }
    inline ~QueueMISO()
        { assert(!pushed && !first && !counter); }

    /// Enqueues `item` (not owned). Thread-safe for multiple producers.
    inline void push(Type &item)
        { Chain::push(pushed, item); counter++; }

    /// Dequeues the oldest item. The queue must not be empty. Consumer only.
    inline Type& pop() {
        Type *item = peek();  // refills `first` from `pushed` when needed
        assert(item);
        Chain::pop(first);
        counter--;
        return *item;
    }

    /// Returns the oldest item without removing it, or nullptr. Consumer only.
    inline Type *peek()
        { return first ? first : (first = Chain::inverted(pushed.exchange(nullptr))); }

    /// Number of queued items (may lag briefly behind concurrent pushes).
    inline int count() const
        { return counter; }

    /// True when count() is zero.
    inline bool empty() const
        { return counter == 0; }
};
/// Spin lock built on an atomic flag; satisfies the Lockable requirements
/// (lock / try_lock / unlock) so it works with std::lock_guard.
class AMutex {
private:
    std::atomic<bool> locked;
public:
    inline AMutex(): locked(false) { }
    inline ~AMutex() { assert(!locked); }

    /// Takes the lock if it is free; never blocks.
    inline bool try_lock()
        { bool expected = false; return locked.compare_exchange_strong(expected, true); }

    /// Spins (yielding) until the lock is taken.
    inline void lock() {
        for (;;) {
            // `expected` must be reset every attempt: a failed CAS overwrites it
            // with the current value (true), which would make the next CAS
            // "succeed" while another thread still holds the lock.
            bool expected = false;
            if (locked.compare_exchange_weak(expected, true)) return;
            std::this_thread::yield();
        }
    }

    inline void unlock()
        { locked = false; }
};
/// Spinning reader/writer lock.
/// counter encoding: 0 = free, N > 0 = held by N readers, -1 = held by a writer.
/// NOTE(review): no writer preference — a steady stream of readers can starve writers.
class ARWMutex {
private:
    std::atomic<int> counter;
public:
    inline ARWMutex(): counter(0) { }
    inline ~ARWMutex() { assert(!counter); }

    /// Takes the exclusive (writer) lock if completely free; never blocks.
    inline bool try_lock()
        { int cnt = 0; return counter.compare_exchange_strong(cnt, -1); }

    /// Spins until the exclusive lock is taken. Each attempt starts from a
    /// fresh expected value of 0 so it can never replace a reader count.
    inline void lock()
        { while (!try_lock()) std::this_thread::yield(); }

    inline void unlock()
        { assert(counter < 0); counter = 0; }

    /// Takes a shared (reader) lock unless a writer holds the mutex.
    /// Retries only on contention with other readers.
    inline bool try_lock_read() {
        int cnt = counter;
        do {
            if (cnt < 0) return false;
        } while (!counter.compare_exchange_weak(cnt, cnt + 1));
        return true;
    }

    /// Spins until a shared lock is taken; re-reads the counter on every
    /// attempt so it proceeds once the writer releases.
    inline void lock_read()
        { while (!try_lock_read()) std::this_thread::yield(); }

    inline void unlock_read()
        { assert(counter > 0); counter--; }
};
// Scoped exclusive guards for the spinning mutexes.
using ALock = std::lock_guard<AMutex>;     // holds an AMutex for the enclosing scope
using AWLock = std::lock_guard<ARWMutex>;  // holds an ARWMutex's writer side for the enclosing scope
/// Scoped shared (reader) guard for ARWMutex: lock_read() in the constructor,
/// unlock_read() in the destructor.
/// Non-copyable: a copy would call unlock_read() a second time on destruction.
/// NOTE(review): calling unlock() manually still leads to a second unlock_read()
/// in the destructor.
class ARLock {
private:
    ARWMutex &mutex;
public:
    inline ARLock(ARWMutex &mutex): mutex(mutex) { lock(); }
    inline ~ARLock() { unlock(); }
    ARLock(const ARLock&) = delete;
    ARLock& operator=(const ARLock&) = delete;
    inline void lock() { mutex.lock_read(); }
    inline void unlock() { mutex.unlock_read(); }
};
// Guards for std::mutex: Lock for plain scoped sections, UniqLock where the
// lock must be released/re-acquired (e.g. condition_variable::wait).
using Lock = std::lock_guard<std::mutex>;
using UniqLock = std::unique_lock<std::mutex>;
class TcpProtocol;
class Server;
class Connection;

/// A socket registered with a TcpProtocol's epoll instance. Owns `sockId`:
/// the destructor deregisters it from epoll and closes it, so the object is
/// not copyable (a copy would close the descriptor twice).
class TcpSocket {
public:
    TcpProtocol &protocol;              // protocol whose epoll set holds this socket
    int sockId;                         // owned socket descriptor
    Server *server = nullptr;           // owner to dispatch to, if any
    Connection *connection = nullptr;   // owner to dispatch to, if any (checked before server)
    explicit TcpSocket(TcpProtocol &protocol, int sockId);
    ~TcpSocket();
    TcpSocket(const TcpSocket&) = delete;
    TcpSocket& operator=(const TcpSocket&) = delete;
    /// Asks the protocol thread to close (delete) this socket.
    void closeReq();
    /// Replaces the epoll interest set with EPOLLIN and/or EPOLLOUT.
    void update(bool read, bool write);
};
class Connection {
public:
enum State {
STATE_NONE,
STATE_CONNECTING,
STATE_CONNECTED,
STATE_CLOSE_HINT,
STATE_CLOSE_REQ,
STATE_CLOSED,
};
struct Task {
unsigned char *data;
int size;
int position;
Task *next;
explicit Task(void *data = nullptr, int size = 0):
data((unsigned char*)data), size(size), position() { }
};
typedef QueueMISO<Task, &Task::next> Queue;
std::atomic<int> stateReaders; // count of threads that currently reads the state
std::atomic<State> state; // write from protocol thread, read from others
std::atomic<State> stateWanted; // read from protocol thread, write from others
State stateQuickCopy; // for use in protocol thread only
std::atomic<unsigned int> lastEvents;
TcpProtocol *protocol;
int sockId;
Queue readQueue;
Queue writeQueue;
// private
void updateEvents() {
assert(stateReaders > 0); // << must be locked
unsigned int events = 0;
if (readQueue.count()) events |= EPOLLIN;
if (writeQueue.count()) events |= EPOLLOUT;
if (lastEvents.exchange(events) == events) return;
struct epoll_event event = {};
event.data.ptr = this;
event.events = events;
epoll_ctl(protocol->pollId, EPOLL_CTL_MOD, sockId, &event);
}
// public functions
void connect(TcpProtocol &protocol, const void *address, size_t size);
void accept(TcpProtocol &protocol, int sockId);
void closeHint() { wantState(STATE_CLOSE_HINT); }
void closeReq() { wantState(STATE_CLOSED); }
void closeWait(unsigned long long timeoutUs);
void wantState(State state) {
State sw = stateWanted;
while(sw < state && !stateWanted.compare_exchange_weak(sw, state));
}
bool send(Task &task) {
stateReaders++;
State s = state;
if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
writeQueue.push(task);
updateEvents();
stateReaders--;
return true;
}
bool receive(Task &task) {
stateReaders++;
State s = state;
if (s < STATE_CONNECTED || s >= STATE_CLOSED) return false;
readQueue.push(task);
updateEvents();
stateReaders--;
return true;
}
// handlers, must be called from protocol thread
virtual void onOpen() { }
virtual void onClose(bool) { }
virtual void onReadReady(const Task&) { }
virtual void onWriteReady(const Task&) { }
std::mutex mutex;
WorkMode mode;
TcpSocket *socket;
// must be called from Protocol thread only
bool open(TcpSocket *socket) {
{
Lock lock(mutex);
if (mode != WM_STOPPED) return false;
mode = WM_STARTING;
this->socket = socket;
}
onOpen();
{
Lock lock(mutex);
mode = WM_STARTED;
}
return true;
}
void close(bool err) {
{
Lock lock(mutex);
assert(mode == WM_STARTED);
mode = WM_STOPPING;
if (!err && (!readQueue.empty() || !writeQueue.empty()))
err = true;
}
onClose(err);
{
socket = nullptr;
}
}
// must be called internally
void closeReq() {
Lock lock(mutex);
if (mode != WM_STARTING && mode != WM_STARTED) return;
socket->closeReq();
}
void readReq(void *data, int size) {
assert(data && size > 0);
Lock lock(mutex);
if (mode != WM_STARTING && mode != WM_STARTED) return;
readQueue.emplace_back(data, size);
socket->update(!readQueue.empty(), !writeQueue.empty());
}
void writeReq(void *data, int size) {
assert(data && size > 0);
Lock lock(mutex);
if (mode != WM_STARTING && mode != WM_STARTED) return;
writeQueue.emplace_back(data, size);
socket->update(!readQueue.empty(), !writeQueue.empty());
}
};
class TcpProtocol {
private:
friend class TcpSocket;
std::mutex mutex;
std::thread *thread;
std::condition_variable condModeChanged;
WorkMode mode;
int epollId;
int eventsId;
std::vector<TcpSocket*> socketsToClose;
// lock mutex befor call
void wakeup() {
unsigned long long num = 0;
::write(eventsId, &num, sizeof(num));
}
void threadRun() {
const int maxCount = 16;
struct epoll_event events[maxCount] = {};
int count = 0;
std::vector<TcpSocket*> socketsToClose;
{
Lock lock(mutex);
assert(mode == WM_STARTING);
mode = WM_STARTED;
condModeChanged.notify_all();
}
while(true) {
for(int i = 0; i < count; ++i) {
unsigned int e = events[i].events;
TcpSocket *socket = (TcpSocket*)events[i].data.ptr;
if (!socket) {
unsigned long long num = 0;
while(::read(eventsId, &num, sizeof(num)) > 0);
} else
if (socket->connection) {
// lock connection
//if (connection->
} else
if (socket->server) {
// lock server
} else {
delete socket;
}
}
{
Lock lock(mutex);
if (mode == WM_STOPPING) {
socketsToClose
}
}
count = epoll_wait(epollId, events, maxCount, 10000);
}
}
public:
TcpProtocol():
thread(),
mode(WM_STOPPED),
epollId(-1),
eventsId(-1)
{ }
~TcpProtocol()
{ stop(); }
bool start() {
{
Lock lock(mutex);
if (mode == WM_STARTED || mode == WM_STARTING) return true;
if (mode != WM_STOPPED) return false;
mode = WM_STARTING;
}
condModeChanged.notify_all();
epollId = ::epoll_create(32);
eventsId = ::eventfd(0, EFD_NONBLOCK);
assert(epollId >= 0 && eventsId >= 0);
struct epoll_event event = {};
event.events = EPOLLIN;
epoll_ctl(epollId, EPOLL_CTL_ADD, eventsId, &event);
thread = new std::thread(&TcpProtocol::threadRun, this);
}
void stop() {
{
UniqLock lock(mutex);
while(true) {
if (mode == WM_STOPPED) return;
if (mode == WM_STARTED) break;
condModeChanged.wait(lock);
}
mode = WM_STOPPING;
}
condModeChanged.notify_all();
thread->join();
delete thread;
thread = nullptr;
::close(epollId);
::close(eventsId);
{
Lock lock(mutex);
mode = WM_STOPPED;
}
condModeChanged.notify_all();
}
};
// Takes ownership of `sockId`, switches it to non-blocking with TCP_NODELAY,
// and registers it (with no events yet) in the protocol's epoll set.
// Setup is best effort: failures are not reported, since the constructor has
// no error channel.
TcpSocket::TcpSocket(TcpProtocol &protocol, int sockId):
    protocol(protocol),
    sockId(sockId),
    server(nullptr),
    connection(nullptr)
{
    int tcp_nodelay = 1;
    setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay, sizeof(tcp_nodelay));
    // Only OR in O_NONBLOCK when F_GETFL succeeded; -1 | O_NONBLOCK would set every flag.
    const int flags = fcntl(sockId, F_GETFL, 0);
    if (flags >= 0)
        fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
    // data.ptr = this is what TcpProtocol::threadRun() casts back to TcpSocket*.
    struct epoll_event event = {};
    event.data.ptr = this;
    epoll_ctl(protocol.epollId, EPOLL_CTL_ADD, sockId, &event);
}
// Deregisters the descriptor from the protocol's epoll set, then closes it.
// A non-null event pointer is passed to EPOLL_CTL_DEL; epoll_ctl(2) notes that
// kernels before 2.6.9 required one.
TcpSocket::~TcpSocket() {
    const int pollFd = protocol.epollId;
    epoll_event ignored{};
    epoll_ctl(pollFd, EPOLL_CTL_DEL, sockId, &ignored);
    ::close(sockId);
}
// Queues this socket in the protocol's close list and wakes the protocol
// thread. socketsToClose is a TcpProtocol member, so it must be reached
// through `protocol`. The list is guarded by protocol.mutex, which wakeup()
// also expects to be held.
void TcpSocket::closeReq() {
    Lock lock(protocol.mutex);
    protocol.socketsToClose.push_back(this);
    protocol.wakeup();
}
// Replaces this socket's epoll interest set with EPOLLIN (if `read`) and/or
// EPOLLOUT (if `write`); data.ptr is re-set to this socket.
void TcpSocket::update(bool read, bool write) {
    const unsigned int wanted = (read ? EPOLLIN : 0) | (write ? EPOLLOUT : 0);
    struct epoll_event event = {};
    event.events = wanted;
    event.data.ptr = this;
    epoll_ctl(protocol.epollId, EPOLL_CTL_MOD, sockId, &event);
}