#include <cstdio>
#include "socket.h"
#include "tunnel.h"
// Builds an idle tunnel: neither the UDP nor the TCP server is open
// (both socket ids start at -1) and the timing parameters take their
// default values, expressed in the unit of monotonicTime()
// (NOTE(review): presumably microseconds — confirm).
// NOTE(review): udpResendDuration, udpConfirmDuration and nextTermSendTime are
// computed from other members. That is only well-defined if those members are
// declared earlier in tunnel.h — confirm the declaration order. run() recomputes
// nextTermSendTime before the event loop starts.
Tunnel::Tunnel()
    : time(monotonicTime())
    , udpSendDuration(1000)
    , udpPartialSendDuration(500000)
    , udpResendDuration(udpSendDuration * PACKETS_COUNT)
    , udpConfirmDuration(udpResendDuration / 8)
    , pollDuration(5000000)
    , udpSockId(-1)
    , tcpSockId(-1)
    , nextTermSendTime(time + udpConfirmDuration)
{
}
// Destroying the tunnel shuts down every server socket and connection it owns.
Tunnel::~Tunnel() {
    this->close();
}
// Opens the UDP side of the tunnel, bound to `address`.
// Any previously open UDP server is closed first (closeUdpServer() also closes
// the TCP server and all connections).
// remoteTcpAddress: stored for readUdp(), which only accepts new incoming UDP
//   connections when its port is non-zero.
// Returns true on success. Returns false if the UDP socket cannot be bound; the
// tunnel is then left without a UDP server.
bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddress) {
    closeUdpServer();

    printf("Init UDP server at address: ");
    address.print();
    if (remoteTcpAddress.port) {
        printf(", with remote TCP address: ");
        remoteTcpAddress.print();
    }
    printf("\n");

    udpSockId = Socket::udpBind(address);
    if (udpSockId < 0) {
        // Terminate the log line so the next message starts on its own line.
        printf("cannot bind, cancelled\n");
        return false;
    }
    printf("success\n");

    this->remoteTcpAddress = remoteTcpAddress;
    return true;
}
bool Tunnel::initTcpServer(const Address &address) {
closeTcpServer();
printf("Init TCP server at address: "); address.print(); printf("\n");
if (udpSockId < 0)
{ printf("UDP server was not initialized, cancelled\n"); return false; }
int sockId = ::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
struct sockaddr_in addr = {};
addr.sin_family = PF_INET;
addr.sin_addr.s_addr = address.ipUInt;
addr.sin_port = address.port;
if (0 != ::bind(sockId, (struct sockaddr*)&addr, sizeof(addr)))
{ ::close(sockId); printf("cannot bind, cancelled\n"); return false; }
if (0 != ::listen(sockId, 32))
{ ::close(sockId); printf("cannot listen, cancelled\n"); return false; }
tcpSockId = sockId;
printf("success");
return true;
}
// Closes the UDP server and everything that depends on it: the TCP listener,
// every connection's TCP socket, the connection table and any queued TERM
// notices. Also forgets the remote TCP address.
// Does nothing if the UDP server is not open.
void Tunnel::closeUdpServer() {
    if (udpSockId < 0) return;
    closeTcpServer();
    printf("Close UDP server\n");
    for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
        // iteration() sets tcpSockId to -1 once a connection's socket is
        // closed, so skip those instead of closing -1 again.
        if (i->second.tcpSockId >= 0)
            Socket::close(i->second.tcpSockId);
    }
    connections.clear();
    termQueue.clear();
    Socket::close(udpSockId);
    udpSockId = -1;
    remoteTcpAddress = Address();
}
// Closes the TCP listening socket and forgets remoteUdpAddress.
// Connections that were already accepted keep their own sockets; they are
// not touched here. Does nothing if the TCP server is not open.
void Tunnel::closeTcpServer() {
    if (tcpSockId < 0) return;
    printf("Close TCP server\n");
    Socket::close(tcpSockId);
    tcpSockId = -1;
    remoteUdpAddress = Address();
}
void Tunnel::close()
{ closeUdpServer(); }
void Tunnel::readUdp() {
// read and parse UDP input
Packet packet;
const int minPacketSize = offsetof(Packet, payload) - offsetof(Packet, connId);
while(true) {
struct sockaddr_in addr = {};
socklen_t addrlen = sizeof(addr);
int res = ::recvfrom(
udpSockId,
&packet.connId,
MSG_DONTWAIT,
sizeof(packet) - offsetof(Packet, connId),
(struct sockaddr*)&addr,
&addrlen );
if (res <= 0)
break;
if (res < minPacketSize)
continue;
int cmd = packet.getCommand();
int size = packet.getSize();
if (res != minPacketSize + packet.getSize())
continue;
if (cmd != CMD_TERM && cmd != CMD_CONFIRM && cmd != CMD_DATA)
continue;
ConnId id;
id.address.ipUInt = addr.sin_addr.s_addr;
id.address.port = addr.sin_port;
id.id = packet.connId;
int cmd = packet.getCommand();
ConnMap::iterator i = connections.find(id);
Connection *conn = i == connections.end() ? nullptr : &i->second;
if (!conn && remoteTcpAddress.port && cmd != CMD_TERM) {
printf("Open incoming connection from UDP: ");
id.address.print();
printf(", id: %u\n", id.id);
int sockId = ::socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
struct sockaddr_in addr = {};
addr.sin_family = PF_INET;
addr.sin_addr.s_addr = id.address.ipUInt;
addr.sin_port = id.address.port;
if (0 != ::connect(sockId, (struct sockaddr*)&addr, sizeof(addr))) {
::close(sockId);
printf("cannot init TCP connection, cancelled");
continue;
}
conn = &connections[id];
conn->id = id;
conn->tcpSockId = sockId;
conn->udpActive = true;
}
if (!conn) {
if (cmd != CMD_TERM)
rejectedConnections.insert(id);
continue;
}
if (cmd == CMD_TERM) {
conn->udpActive = false;
continue;
}
if (cmd == CMD_CONFIRM) {
unsigned int index = packet.index;
for(int i = 0; i < size; ++i)
for(int j = 0; j < 8; ++j, ++index)
if (packet.payload[i] & (1 << j))
if (PacketInfo *pi = conn->toUdpQueue[index])
if (pi->loaded)
pi->confirmed = true;
while(PacketInfo *pi = conn->toUdpQueue.first())
if (pi->loaded && pi->confirmed)
conn->toUdpQueue.pop(); else break;
continue;
}
}
}
bool Tunnel::iteration() {
if (udpSockId < 0)
return false;
// prepare poll
poll.list.clear();
poll.list.emplace_back();
Poll::Entry &e = poll.list.back();
e.sockId = udpSockId;
e.wantRead = true;
if (tcpSockId >= 0) {
Poll::Entry &e = poll.list.back();
e.sockId = tcpSockId;
e.wantRead = true;
}
Time pollTime = time + pollDuration;
for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
Connection &conn = i->second;
if (conn.tcpSockId >= 0) {
poll.list.emplace_back();
Poll::Entry &e = poll.list.back();
e.connection = &conn;
e.sockId = conn.tcpSockId;
e.wantRead = conn.udpActive && conn.wantReadFromTcp();
e.wantWrite = conn.wantWriteToTcp();
}
if (conn.udpActive) {
Time t = conn.whenWriteToUdp();
if (timeLess(pollTime, t)) pollTime = t;
}
}
// wait for events
poll.wait(pollTime);
time = monotonicTime();
if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) {
poll.list.clear();
close();
return false;
}
// read and parse udp
if (poll.list[0].canRead)
while(recvUdpPacket());
if (tcpSockId >= 0 && poll.list[1].canRead) {
// accept new tcp connection
int sockId = Socket::tcpAccept(tcpSockId);
if (sockId >= 0) {
ConnId id;
id.address = remoteUdpAddress;
id.id = ConnId::generateId();
Connection &conn = connections.emplace(id, *this, id, tcpSockId).first->second;
conn.tcpRecvQueue.readFromTcp();
}
}
// tcp read and write
int first = tcpSockId < 0 ? 1 : 2;
for(Poll::List::const_iterator i = poll.list.begin() + first; i != poll.list.end(); ++i) {
Connection &conn = *i->connection;
conn.udpRecvQueue.sendTcp();
if (i->canWrite)
conn.tcpSendQueue.writeToTcp();
if (i->canRead && conn.udpActive)
conn.tcpRecvQueue.readFromTcp();
if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) {
Socket::close(conn.tcpSockId);
conn.tcpSockId = -1;
}
}
// write to udp
for(ConnMap::iterator i = connections.begin(); i != connections.end();) {
ConnMap::iterator j = i++;
Connection &conn = j->second;
if (conn.udpActive) {
conn.udpSendQueue.sendUdp();
conn.udpRecvQueue.sendUdpConfirm();
} else
if (conn.tcpSockId < 0) {
connections.erase(j);
}
}
// send udp term packets
if (timeLequal(nextTermSendTime, time)) {
for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) {
Packet packet;
packet.init(i->id, 0, CMD_TERM, 0);
sendUdpPacket(i->address, packet);
}
termQueue.clear();
nextTermSendTime = time + nextTermSendTime;
}
poll.list.clear();
return true;
}
void Tunnel::run() {
time = monotonicTime();
nextTermSendTime = time + udpConfirmDuration;
while(iteration());
}