diff --git a/common.cpp b/common.cpp new file mode 100644 index 0000000..b4cd4e2 --- /dev/null +++ b/common.cpp @@ -0,0 +1,8 @@ + +#include "common.h" + + + +void Address::print() const + { printf("%hhu.%hhu.%hhu.%hhu:%hu", ip[0], ip[1], ip[2], ip[3], port); } + diff --git a/common.h b/common.h index 54e2280..0e1763e 100644 --- a/common.h +++ b/common.h @@ -21,6 +21,8 @@ enum : unsigned short { typedef unsigned long long Time; +Time monotonicTime(); +Time globalTime(); inline bool cycleLequal(unsigned int a, unsigned int b) @@ -62,6 +64,8 @@ public: : other.address < address ? false : id < other.id; } + + static unsigned int generateId(); }; @@ -81,7 +85,7 @@ struct __attribute__((__packed__)) Packet { inline int getCommand() const { return cmdSize >> CMD_SHIFT; } - inline init( + inline void init( unsigned int connId, unsigned int index, unsigned short cmd, @@ -89,8 +93,7 @@ struct __attribute__((__packed__)) Packet { { this->connId = connId; this->index = index; - this->cmd = cmd; - this->size = size; + this->cmdSize = (cmd << CMD_SHIFT) | size; } }; diff --git a/connection.cpp b/connection.cpp new file mode 100644 index 0000000..30d1f80 --- /dev/null +++ b/connection.cpp @@ -0,0 +1,18 @@ + +#include "tunnel.h" +#include "connection.h" + + +Connection::Connection(Tunnel &tunnel, const ConnId &id, int tcpSockId): + tunnel(tunnel), + id(id), + tcpSockId(tcpSockId), + udpActive(true), + tcpSendQueue(*this), + tcpRecvQueue(*this), + udpSendQueue(*this), + udpRecvQueue(*this) + { } + +Connection::~Connection() { } + diff --git a/connection.h b/connection.h new file mode 100644 index 0000000..4707952 --- /dev/null +++ b/connection.h @@ -0,0 +1,36 @@ +#ifndef CONNECTION_H +#define CONNECTION_H + + +#include "tcpqueue.h" +#include "udpsendqueue.h" +#include "udprecvqueue.h" + + +class Tunnel; + + +class Connection { +public: + Tunnel &tunnel; + const ConnId id; + int tcpSockId; + bool udpActive; + + TcpQueue tcpSendQueue; + TcpQueue tcpRecvQueue; + UdpSendQueue 
udpSendQueue; + UdpRecvQueue udpRecvQueue; + + Connection(Tunnel &tunnel, const ConnId &id, int tcpSockId); + ~Connection(); + Connection(const Connection&) = delete; + Connection& operator=(const Connection&) = delete; + + bool wantReadFromTcp() const; + bool wantWriteToTcp() const; + Time whenWriteToUdp() const; +}; + + +#endif diff --git a/socket.cpp b/socket.cpp new file mode 100644 index 0000000..72c09eb --- /dev/null +++ b/socket.cpp @@ -0,0 +1,2 @@ + +#include "socket.h" diff --git a/socket.h b/socket.h index c563eb4..920afde 100644 --- a/socket.h +++ b/socket.h @@ -1,22 +1,40 @@ #ifndef SOCKET_H -#define SOCEKT_H +#define SOCKET_H +#include + #include "common.h" +class Connection; + + class Poll { -private: - int count; - int allocatedCount; - void *data; +public: + struct Entry { + Connection *connection; + int sockId; + bool wantRead; + bool wantWrite; + bool canRead; + bool canWrite; + bool closed; + + inline Entry(): + connection(), sockId(-1), wantRead(), wantWrite(), canRead(), canWrite(), closed() { } + }; + typedef std::vector List; + +private: + std::vector data; + public: + List list; + Poll(); ~Poll(); - - void add(int sockId, bool wantRead, bool wantWrite); - void clear(); bool wait(Time time); }; diff --git a/tcpqueue.cpp b/tcpqueue.cpp index a7c47c6..852776d 100644 --- a/tcpqueue.cpp +++ b/tcpqueue.cpp @@ -20,14 +20,14 @@ bool TcpQueue::push(const void *data, int size) { if (!data || size <= 0 || freeSize() < size) return false; - int part1 = ALLOCED_SIZE - end; + int part1 = ALLOCATED_SIZE - end; if (size > part1) { memcpy(buffer + end, data, part1); memcpy(buffer, (const unsigned char*)data + part1, size - part1); } else { memcpy(buffer + end, data, size); } - end = (end + size)%ALLOCED_SIZE; + end = (end + size)%ALLOCATED_SIZE; return true; } @@ -37,14 +37,14 @@ bool TcpQueue::pop(void *data, int size) { if (!data || size <= 0 || busySize() < size) return false; - int part1 = ALLOCED_SIZE - begin; + int part1 = ALLOCATED_SIZE - begin; if (size > part1) { memcpy(data, buffer + begin, part1); memcpy((unsigned char*)data + part1, 
buffer, size - part1); } else { memcpy(data, buffer + begin, size); } - begin = (begin + size)%ALLOCED_SIZE; + begin = (begin + size)%ALLOCATED_SIZE; return true; } @@ -59,7 +59,7 @@ int TcpQueue::readFromTcp() { int doneSize = 0; int parts[2]; - parts[0] = ALLOCED_SIZE - end; + parts[0] = ALLOCATED_SIZE - end; parts[1] = parts[0] - size; if (parts[1] < 0) { parts[0] = size; parts[1] = 0; } @@ -73,7 +73,7 @@ int TcpQueue::readFromTcp() { end += res; parts[i] -= res; } - if (end >= ALLOCED_SIZE) end = 0; + if (end >= ALLOCATED_SIZE) end = 0; } return doneSize; } @@ -88,7 +88,7 @@ int TcpQueue::writeToTcp() { int doneSize = 0; int parts[2]; - parts[0] = ALLOCED_SIZE - begin; + parts[0] = ALLOCATED_SIZE - begin; parts[1] = parts[0] - size; if (parts[1] < 0) { parts[0] = size; parts[1] = 0; } @@ -102,7 +102,7 @@ int TcpQueue::writeToTcp() { begin += res; parts[i] -= res; } - if (begin >= ALLOCED_SIZE) begin = 0; + if (begin >= ALLOCATED_SIZE) begin = 0; } return doneSize; } diff --git a/tcpqueue.h b/tcpqueue.h index d1d7f7d..4ac3af6 100644 --- a/tcpqueue.h +++ b/tcpqueue.h @@ -10,12 +10,12 @@ class Connection; class TcpQueue { public: - enum { ALLOCED_SIZE = TCP_BUFFER_SIZE + 1 }; + enum { ALLOCATED_SIZE = TCP_BUFFER_SIZE + 1 }; public: Connection &connection; private: - unsigned char buffer[ALLOCED_SIZE]; + unsigned char buffer[ALLOCATED_SIZE]; int begin; int end; @@ -24,7 +24,7 @@ public: ~TcpQueue(); inline int busySize() const - { return end < begin ? end + ALLOCED_SIZE - begin : end - begin; } + { return end < begin ? 
end + ALLOCATED_SIZE - begin : end - begin; } inline int freeSize() const { return TCP_BUFFER_SIZE - busySize(); } diff --git a/tunnel.cpp b/tunnel.cpp new file mode 100644 index 0000000..949938e --- /dev/null +++ b/tunnel.cpp @@ -0,0 +1,277 @@ + +#include +#include <tuple> + +#include "socket.h" +#include "tunnel.h" + + + +Tunnel::Tunnel(): + time(monotonicTime()), + udpSendDuration(1000), + udpPartialSendDuration(500000), + udpResendDuration(udpSendDuration * PACKETS_COUNT), + udpConfirmDuration(udpResendDuration/8), + pollDuration(5000000), + udpSockId(-1), + tcpSockId(-1), + nextTermSendTime(time + udpConfirmDuration) + { } + +Tunnel::~Tunnel() + { close(); } + + +bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddress) { + closeUdpServer(); + printf("Init UDP server at address: "); address.print(); + if (remoteTcpAddress.port) + { printf(", with remote TCP address: "); remoteTcpAddress.print(); } + printf("\n"); + + udpSockId = Socket::udpBind(address); + if (udpSockId < 0) + { printf("cannot bind, cancelled\n"); return false; } + + printf("success\n"); + this->remoteTcpAddress = remoteTcpAddress; + return true; +} + +bool Tunnel::initTcpServer(const Address &address, const Address &remoteUdpAddress) { + closeTcpServer(); + printf("Init TCP server at address: "); address.print(); printf("\n"); + + if (udpSockId < 0) + { printf("UDP server was not initialized, cancelled\n"); return false; } + + int sockId = ::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); + struct sockaddr_in addr = {}; + addr.sin_family = PF_INET; + addr.sin_addr.s_addr = address.ipUInt; + addr.sin_port = address.port; + if (0 != ::bind(sockId, (struct sockaddr*)&addr, sizeof(addr))) + { ::close(sockId); printf("cannot bind, cancelled\n"); return false; } + if (0 != ::listen(sockId, 32)) + { ::close(sockId); printf("cannot listen, cancelled\n"); return false; } + + tcpSockId = sockId; + this->remoteUdpAddress = remoteUdpAddress; + printf("success\n"); + return true; +} + + +void Tunnel::closeUdpServer() { + if (udpSockId < 0) return; + closeTcpServer(); + + printf("Close 
UDP server\n"); + for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) + Socket::close(i->second.tcpSockId); + connections.clear(); + termQueue.clear(); + Socket::close(udpSockId); + udpSockId = -1; + remoteTcpAddress = Address(); +} + + +void Tunnel::closeTcpServer() { + if (tcpSockId < 0) return; + printf("Close TCP server\n"); + Socket::close(tcpSockId); + tcpSockId = -1; + remoteUdpAddress = Address(); +} + + +void Tunnel::close() + { closeUdpServer(); } + + +void Tunnel::readUdp() { + // read and parse UDP input + Packet packet; + const int minPacketSize = offsetof(Packet, payload) - offsetof(Packet, connId); + while(true) { + struct sockaddr_in addr = {}; + socklen_t addrlen = sizeof(addr); + int res = ::recvfrom( + udpSockId, + &packet.connId, + sizeof(packet) - offsetof(Packet, connId), + MSG_DONTWAIT, + (struct sockaddr*)&addr, + &addrlen ); + if (res <= 0) + break; + if (res < minPacketSize) + continue; + + int cmd = packet.getCommand(); + int size = packet.getSize(); + if (res != minPacketSize + size) + continue; + if (cmd != CMD_TERM && cmd != CMD_CONFIRM && cmd != CMD_DATA) + continue; + + ConnId id; + id.address.ipUInt = addr.sin_addr.s_addr; + id.address.port = addr.sin_port; + id.id = packet.connId; + ConnMap::iterator i = connections.find(id); + Connection *conn = i == connections.end() ? 
nullptr : &i->second; + + if (!conn && remoteTcpAddress.port && cmd != CMD_TERM) { + printf("Open incoming connection from UDP: "); + id.address.print(); + printf(", id: %u\n", id.id); + + int sockId = ::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); + struct sockaddr_in tcpAddr = {}; + tcpAddr.sin_family = PF_INET; + tcpAddr.sin_addr.s_addr = remoteTcpAddress.ipUInt; + tcpAddr.sin_port = remoteTcpAddress.port; + if (0 != ::connect(sockId, (struct sockaddr*)&tcpAddr, sizeof(tcpAddr))) { + ::close(sockId); + printf("cannot init TCP connection, cancelled\n"); + continue; + } + + conn = &connections.emplace(std::piecewise_construct, std::forward_as_tuple(id), std::forward_as_tuple(*this, id, sockId)).first->second; + } + + if (!conn) { + if (cmd != CMD_TERM) + termQueue.insert(id); + continue; + } + + if (cmd == CMD_TERM) { + conn->udpActive = false; + continue; + } + + if (cmd == CMD_CONFIRM) { + conn->udpSendQueue.confirm(packet.index, (const unsigned char*)packet.payload, size*8); + continue; + } + } +} + + +bool Tunnel::iteration() { + if (udpSockId < 0) + return false; + + // prepare poll + + poll.list.clear(); + + poll.list.emplace_back(); + Poll::Entry &e = poll.list.back(); + e.sockId = udpSockId; + e.wantRead = true; + + if (tcpSockId >= 0) { + poll.list.emplace_back(); + Poll::Entry &e = poll.list.back(); + e.sockId = tcpSockId; + e.wantRead = true; + } + + Time pollTime = time + pollDuration; + for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) { + Connection &conn = i->second; + if (conn.tcpSockId >= 0) { + poll.list.emplace_back(); + Poll::Entry &e = poll.list.back(); + e.connection = &conn; + e.sockId = conn.tcpSockId; + e.wantRead = conn.udpActive && conn.wantReadFromTcp(); + e.wantWrite = conn.wantWriteToTcp(); + } + if (conn.udpActive) { + Time t = 
conn.whenWriteToUdp(); + if (timeLess(t, pollTime)) pollTime = t; + } + } + + // wait for events + poll.wait(pollTime); + time = monotonicTime(); + + if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) { + poll.list.clear(); + close(); + return false; + } + + // read and parse udp + if (poll.list[0].canRead) + while(recvUdpPacket()); + + if (tcpSockId >= 0 && poll.list[1].canRead) { + // accept new tcp connection + int sockId = Socket::tcpAccept(tcpSockId); + if (sockId >= 0) { + ConnId id; + id.address = remoteUdpAddress; + id.id = ConnId::generateId(); + Connection &conn = connections.emplace(std::piecewise_construct, std::forward_as_tuple(id), std::forward_as_tuple(*this, id, sockId)).first->second; + conn.tcpRecvQueue.readFromTcp(); + } + } + + // tcp read and write + int first = tcpSockId < 0 ? 1 : 2; + for(Poll::List::const_iterator i = poll.list.begin() + first; i != poll.list.end(); ++i) { + Connection &conn = *i->connection; + conn.udpRecvQueue.sendTcp(); + if (i->canWrite) + conn.tcpSendQueue.writeToTcp(); + if (i->canRead && conn.udpActive) + conn.tcpRecvQueue.readFromTcp(); + if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) { + Socket::close(conn.tcpSockId); + conn.tcpSockId = -1; + } + } + + // write to udp + for(ConnMap::iterator i = connections.begin(); i != connections.end();) { + ConnMap::iterator j = i++; + Connection &conn = j->second; + if (conn.udpActive) { + conn.udpSendQueue.sendUdp(); + conn.udpRecvQueue.sendUdpConfirm(); + } else + if (conn.tcpSockId < 0) { + connections.erase(j); + } + } + + // send udp term packets + if (timeLequal(nextTermSendTime, time)) { + for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) { + Packet packet; + packet.init(i->id, 0, CMD_TERM, 0); + sendUdpPacket(i->address, packet); + } + termQueue.clear(); + nextTermSendTime = time + udpConfirmDuration; + } + + poll.list.clear(); + return true; +} + + +void Tunnel::run() { + time = monotonicTime(); + nextTermSendTime = time + udpConfirmDuration; + while(iteration()); +} + 
diff --git a/tunnel.h b/tunnel.h new file mode 100644 index 0000000..15ae065 --- /dev/null +++ b/tunnel.h @@ -0,0 +1,57 @@ +#ifndef TUNNEL_H +#define TUNNEL_H + + +#include "socket.h" +#include "connection.h" + +#include +#include + + + +class Tunnel { +public: + typedef std::map ConnMap; + typedef std::set ConnSet; + +public: + Time time; + Time udpSendDuration; + Time udpPartialSendDuration; + Time udpResendDuration; + Time udpConfirmDuration; + Time pollDuration; + + ConnMap connections; + ConnSet termQueue; + +private: + Poll poll; + int udpSockId; + int tcpSockId; + Address remoteUdpAddress; + Address remoteTcpAddress; + Time nextTermSendTime; + +public: + Tunnel(); + ~Tunnel(); + + bool initUdpServer(const Address &address, const Address &remoteTcpAddress); + bool initTcpServer(const Address &address, const Address &remoteUdpAddress); + + void closeTcpServer(); + void closeUdpServer(); + void close(); + + bool recvUdpPacket(); + void readUdp(); + void sendUdpPacket(const Address &address, const Packet &packet); + bool iteration(); + + void run(); +}; + + +#endif diff --git a/udpsendqueue.cpp b/udpsendqueue.cpp index 6d3c455..dd1de09 100644 --- a/udpsendqueue.cpp +++ b/udpsendqueue.cpp @@ -70,7 +70,7 @@ void UdpSendQueue::confirm(unsigned int index, const unsigned char *bits, unsign } -bool UdpSendQueue::sendUdp() { +bool UdpSendQueue::sendUdpSingle() { Time time = connection.tunnel.time; if (timeLess(time, nextSendTime)) return false; @@ -108,9 +108,9 @@ bool UdpSendQueue::sendUdp() { } Packet packet; - packet.init(connection.id, e->index, CMD_DATA, e->size); + packet.init(connection.id.id, e->index, CMD_DATA, e->size); if (e->size) memcpy(packet.payload, e->data, e->size); - connection.tunnel.udpSend(packet); + connection.tunnel.sendUdpPacket(connection.id.address, packet); e->resendTime = time + connection.tunnel.udpResendDuration; nextSendTime += connection.tunnel.udpPartialSendDuration; @@ -118,3 +118,17 @@ bool UdpSendQueue::sendUdp() { return true; } + +int UdpSendQueue::sendUdp() { + int sent 
= 0; + while(sendUdpSingle()) { + ++sent; + if (sent == 100) { + if (timeLess(nextSendTime, connection.tunnel.time)) + nextSendTime = connection.tunnel.time; + break; + } + } + return sent; +} + diff --git a/udpsendqueue.h b/udpsendqueue.h index 98458d4..25ef443 100644 --- a/udpsendqueue.h +++ b/udpsendqueue.h @@ -5,6 +5,7 @@ #include "common.h" + class Connection; @@ -38,7 +39,8 @@ public: ~UdpSendQueue(); void confirm(unsigned int index, const unsigned char *bits, unsigned int bitsCount); - bool sendUdp(); + bool sendUdpSingle(); + int sendUdp(); };