#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include "socket.h"
#include "protocol.h"
#define DebugMSG ShowDebugMSG
// Puts a freshly constructed socket into its idle state:
//  - state / stateWanted = STATE_NONE (the state openBegin() accepts),
//  - not enqueued, no close/finish waiters,
//  - protocol, address, addressSize, prev/next/queueNext and stateLocal
//    value-initialized,
//  - sockId = -1 (no descriptor yet),
//  - events = 0xffffffff, the sentinel updateEvents() interprets as "not yet
//    added to epoll" (it then issues EPOLL_CTL_ADD instead of EPOLL_CTL_MOD).
// `error` is cleared explicitly: wantState() only ever raises it with a
// false -> true compare-exchange, so it needs a defined starting value
// (a default-initialized std::atomic<bool> is indeterminate before C++20).
// NOTE(review): redundant but harmless if socket.h already gives `error` a
// default member initializer; move `error(false)` to match the declaration
// order in socket.h if -Wreorder complains.
Socket::Socket():
    state(STATE_NONE),
    stateWanted(STATE_NONE),
    enqueued(false),
    error(false),
    closeWaiters(0),
    finishWaiters(0),
    address(),
    addressSize(),
    protocol(),
    prev(), next(),
    queueNext(),
    sockId(-1),
    stateLocal(),
    events(0xffffffff)
{ }
// A socket may only be deleted after destroy() has moved it to
// STATE_DESTROYED; debug builds verify that here.
Socket::~Socket()
{
    assert(state == STATE_DESTROYED);
}
// Final teardown that moves the socket to STATE_DESTROYED (the state
// ~Socket() asserts). Must only be called while the socket is idle
// (STATE_NONE) or fully shut down (STATE_FINISHED).
void Socket::destroy() {
assert(state == STATE_NONE || state == STATE_FINISHED);
// Count ourselves as a finish-waiter: openBegin(), after claiming a socket
// that was STATE_FINISHED, spins while finishWaiters > 0 and then checks for
// STATE_DESTROYED, so it will observe the store below and bail out.
++finishWaiters;
// Given the assert above, closeWait() is expected to return immediately
// (it ignores sockets <= STATE_OPENING or >= STATE_FINISHED); presumably a
// safety net for release builds — TODO confirm.
closeWait();
state = STATE_DESTROYED;
--finishWaiters;
// Wait for every other finish-waiter (closeWait() callers) to drain,
// presumably so nothing touches the object once the caller deletes it.
while(finishWaiters > 0)
std::this_thread::yield();
}
// First half of the two-step open; openEnd() is the second half.
// Atomically claims the socket by moving it from STATE_NONE or STATE_FINISHED
// to STATE_OPENING. It then checks that `protocol`'s state lies in
// [STATE_INITIALIZING, STATE_CLOSE_REQ) and records the protocol and a copy of
// the first `addressSize` bytes of `address`.
// Returns false when: `address` is null, `addressSize` is 0 or larger than
// MAX_ADDR_SIZE, the socket is in any other state, the socket was destroyed
// while we waited, or the protocol's state is outside that range (in that
// last case the socket is reset to STATE_NONE).
// On success returns true WITH protocol.stateProtector still locked; the
// caller must follow up with openEnd(), which releases it.
bool Socket::openBegin(Protocol &protocol, void *address, size_t addressSize) {
if (!address || !addressSize || addressSize > MAX_ADDR_SIZE) return false;
State s = state;
// CAS loop: retry while the state changes underneath us, but give up as soon
// as it is observed in anything other than NONE/FINISHED.
do {
if (s != STATE_NONE && s != STATE_FINISHED)
return false;
} while(!state.compare_exchange_weak(s, STATE_OPENING));
// Reusing a finished socket: let finish-waiters (closeWait()/destroy()) leave
// first, then make sure destroy() did not publish STATE_DESTROYED meanwhile.
if (s == STATE_FINISHED)
while(finishWaiters > 0)
std::this_thread::yield();
if (state == STATE_DESTROYED)
return false;
// On success this lock stays held until openEnd().
protocol.stateProtector.lock();
State ps = protocol.state;
if (ps < STATE_INITIALIZING || ps >= STATE_CLOSE_REQ) {
state = STATE_NONE;
protocol.stateProtector.unlock();
return false;
}
this->protocol = &protocol;
memcpy(this->address, address, addressSize);
this->addressSize = addressSize;
return true;
}
void Socket::openEnd() {
state = STATE_INITIALIZING;
enqueue();
protocol->stateProtector.unlock();
}
// Applies the options this class expects on a connection descriptor:
// TCP_NODELAY (disable Nagle's algorithm) and O_NONBLOCK.
// Both steps are best-effort; failures are not reported to the caller.
// The descriptor configured is the `sockId` argument (the member of the
// same name is not read here).
void Socket::initSocket(int sockId) {
    int tcpNoDelay = 1;
    setsockopt(sockId, IPPROTO_TCP, TCP_NODELAY, &tcpNoDelay, sizeof(tcpNoDelay));
    // F_GETFL returns -1 on failure. OR-ing O_NONBLOCK into -1 and handing
    // that to F_SETFL would attempt to set every status flag, so only update
    // the flags when they were actually read.
    const int flags = fcntl(sockId, F_GETFL, 0);
    if (flags != -1)
        fcntl(sockId, F_SETFL, flags | O_NONBLOCK);
}
// Pushes this socket onto its protocol's sockQueue and wakes the protocol.
// The `enqueued` flag guards against double insertion: only the caller that
// flips it from false to true performs the push; later callers return
// immediately while the flag remains set.
void Socket::enqueue()
{
    const bool alreadyQueued = enqueued.exchange(true);
    if (alreadyQueued)
        return;
    protocol->sockQueue.push(*this);
    protocol->wakeup();
}
// Sets the epoll interest mask of this socket's descriptor.
// EPOLLERR and EPOLLRDHUP are always requested; EPOLLIN / EPOLLOUT are added
// when wantRead / wantWrite are true. The last mask is cached in `events` so
// an unchanged mask costs no syscall. A cached value of 0xffffffff means the
// descriptor has not been added to protocol->epollFd yet, so EPOLL_CTL_ADD is
// used instead of EPOLL_CTL_MOD. The epoll_ctl() result is not checked.
void Socket::updateEvents(bool wantRead, bool wantWrite)
{
    const unsigned int notRegistered = 0xffffffff;
    const unsigned int mask = EPOLLERR | EPOLLRDHUP
                            | (wantRead ? EPOLLIN : 0u)
                            | (wantWrite ? EPOLLOUT : 0u);
    DebugMSG("events: %08x", mask);

    const unsigned int previous = events.exchange(mask);
    if (previous == mask)
        return;  // nothing changed, skip the syscall
    DebugMSG("apply");

    epoll_event change{};
    change.events = mask;
    change.data.ptr = this;
    const int op = (previous == notRegistered) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD;
    epoll_ctl(protocol->epollFd, op, sockId, &change);
}
void Socket::wantState(State state, bool error) {
ReadLock lock(stateProtector);
State s = this->state;
DebugMSG("%d, %d", s, state);
if (s <= STATE_OPENING || s >= STATE_CLOSING || s >= state) return;
DebugMSG("apply");
State sw = stateWanted;
do {
if (sw >= state) return;
} while(!stateWanted.compare_exchange_weak(sw, state));
if (error) { DebugMSG("error"); error = false; this->error.compare_exchange_strong(error, true); }
enqueue();
}
// Requests STATE_CLOSE_REQ through wantState() without raising the error flag.
void Socket::closeReq()
{
    const bool raiseError = false;
    wantState(STATE_CLOSE_REQ, raiseError);
}
void Socket::close(bool error)
{ wantState(STATE_CLOSING, error); }
// Blocks until the socket reaches STATE_FINISHED.
//
// Returns immediately if the socket is <= STATE_OPENING or already
// >= STATE_FINISHED. Otherwise:
//  - if already STATE_CLOSED, just waits for STATE_FINISHED;
//  - else, when `withReq` is true and no close request is pending
//    (state < STATE_CLOSE_REQ), asks for one via closeReq();
//  - when `timeoutUs` is non-zero and the socket is not yet closing, waits up
//    to `timeoutUs` microseconds for STATE_CLOSED, calling close(error) on
//    timeout;
//  - if the socket is still not closing after that (or right away when
//    timeoutUs == 0), calls close(error);
//  - waits on closeCondition for STATE_CLOSED, then spins (yield) until
//    STATE_FINISHED.
// closeWaiters / finishWaiters count the threads currently in each phase.
// NOTE(review): timeoutUs values beyond the steady_clock range overflow the
// deadline computation below; confirm callers never pass "infinite" sentinels.
void Socket::closeWait(unsigned long long timeoutUs, bool withReq, bool error) {
    stateProtector.lock();
    State s = state;
    if (s <= STATE_OPENING || s >= STATE_FINISHED) {
        stateProtector.unlock();
        return;
    }
    if (s == STATE_CLOSED) {
        finishWaiters++;
        stateProtector.unlock();
    } else {
        closeWaiters++;
        stateProtector.unlock();
        // closeReq() -> wantState() acquires a ReadLock on stateProtector, so it
        // must run only after the exclusive lock above has been released;
        // calling it while still holding that lock self-deadlocks on a
        // non-recursive reader/writer lock. wantState() re-reads the current
        // state itself, so a state change in between is handled there.
        if (withReq && s < STATE_CLOSE_REQ) closeReq();
        // closeCondition is waited on with a local mutex that nothing else can
        // lock; it only satisfies the wait() API. NOTE(review): this is only
        // free of lost wake-ups if the notifier keeps notifying while
        // closeWaiters > 0 — confirm in the protocol code.
        std::mutex fakeMutex;
        {
            std::unique_lock<std::mutex> fakeLock(fakeMutex);
            if (timeoutUs && s < STATE_CLOSING) {
                std::chrono::steady_clock::time_point t
                    = std::chrono::steady_clock::now()
                    + std::chrono::microseconds(timeoutUs);
                while (s < STATE_CLOSED) {
                    if (t <= std::chrono::steady_clock::now()) { close(error); break; }
                    closeCondition.wait_until(fakeLock, t);
                    s = state;
                }
            }
            if (s < STATE_CLOSING) close(error);
            while (s < STATE_CLOSED) {
                closeCondition.wait(fakeLock);
                s = state;
            }
        }
        finishWaiters++;
        closeWaiters--;
    }
    while (s < STATE_FINISHED) { std::this_thread::yield(); s = state; }
    finishWaiters--;
}
// Default hook: intentionally does nothing.
void Socket::onOpeningError() {}
// Default hook: ignores both arguments and does nothing.
void Socket::onOpen(const void * /*unused*/, size_t /*unused*/) {}
// Default hook: intentionally does nothing. The identifier's spelling
// ("Reqested") must match its declaration, so it is kept as is.
void Socket::onCloseReqested() {}
// Default hook: ignores its flag argument and does nothing.
void Socket::onClose(bool /*unused*/) {}