#include <cstdio>
#include "socket.h"
#include "tunnel.h"
// Construct an idle tunnel: no sockets are open (udpSockId/tcpSockId = -1)
// and all timing parameters get their default values. Durations are added to
// values returned by monotonicTime(), so they share its unit
// (NOTE(review): the magnitudes suggest microseconds — confirm).
// NOTE(review): several initializers read other members (udpSendDuration,
// udpResendDuration, udpSilenceDuration, time, udpConfirmDuration). C++
// initializes members in declaration order, not in the order written here, so
// this is only correct if tunnel.h declares them in this same order — verify.
Tunnel::Tunnel():
time(monotonicTime()),
udpSendDuration(1000),
udpPartialSendDuration(500000),
udpResendDuration(udpSendDuration * PACKETS_COUNT),
udpConfirmDuration(udpResendDuration/8),
udpSilenceDuration(300000000), // connections with no UDP input for this long are deactivated in iteration()
udpKeepAliveDuration(udpSilenceDuration/32), // after this long without UDP output a forced confirm is sent
pollDuration(5000000), // base poll timeout added to the current time in iteration()
udpSockId(-1),
tcpSockId(-1),
nextTermSendTime(time + udpConfirmDuration) // first batch of queued TERM packets is due one confirm interval from now
{ }
// Destructor: shuts the tunnel down through close().
Tunnel::~Tunnel() {
    this->close();
}
// Open the UDP side of the tunnel.
//   address          - local address the UDP socket is bound to.
//   remoteTcpAddress - TCP endpoint that new connections arriving over UDP are
//                      forwarded to; a zero port means incoming UDP packets
//                      never open new connections.
//   key              - encryption key passed to crypt.setKey().
// Any previously open UDP server (and with it the TCP server) is closed first.
// Returns true on success; on failure no UDP socket is left open.
bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddress, const char *key) {
    closeUdpServer();
    printf("Init UDP server at address: "); address.print();
    if (remoteTcpAddress.port)
    { printf(", with remote TCP address: "); remoteTcpAddress.print(); }
    printf("\n");
    if (!crypt.setKey(key))
    { printf("bad key, cancelled\n"); return false; }
    udpSockId = Socket::udpBind(address);
    if (udpSockId < 0)
    { crypt.resetKey(); printf("cannot bind, cancelled\n"); return false; }
    printf("success\n");
    this->remoteTcpAddress = remoteTcpAddress;
    return true;
}
// Open the TCP listening side of the tunnel.
//   address          - local address to listen on.
//   remoteUdpAddress - UDP peer that connections accepted here are tunnelled to.
// Any previous TCP server is closed first. The UDP server must already be
// initialized, otherwise the call is cancelled. Returns true on success.
bool Tunnel::initTcpServer(const Address &address, const Address &remoteUdpAddress) {
    closeTcpServer();
    printf("Init TCP server at address: "); address.print(); printf("\n");
    if (udpSockId < 0)
    { printf("UDP server was not initialized, cancelled\n"); return false; }
    tcpSockId = Socket::tcpListen(address);
    if (tcpSockId < 0)
    { printf("cannot listen, cancelled\n"); return false; }
    printf("success\n");
    this->remoteUdpAddress = remoteUdpAddress;
    return true;
}
void Tunnel::closeUdpServer() {
if (udpSockId < 0) return;
closeTcpServer();
printf("Close UDP server");
for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i)
Socket::close(i->second.tcpSockId);
connections.clear();
termQueue.clear();
Socket::close(udpSockId);
udpSockId = -1;
remoteTcpAddress = Address();
crypt.resetKey();
}
// Stop listening for TCP connections and forget the remote UDP address.
// Connections that were already accepted are left untouched.
// Does nothing if the TCP server is not open.
void Tunnel::closeTcpServer() {
    if (tcpSockId < 0) return;
    printf("Close TCP server\n");
    Socket::close(tcpSockId);
    tcpSockId = -1;
    remoteUdpAddress = Address();
}
// Shut the whole tunnel down. Closing the UDP server also closes the TCP
// server and every forwarded connection.
void Tunnel::close() {
    closeUdpServer();
}
bool Tunnel::recvUdpPacket() {
ConnId id;
unsigned char raw[CRYPT_PACKET_SIZE] = {};
int rawSize = Socket::udpRecvAsync(udpSockId, id.address, raw, sizeof(raw));
if (rawSize <= 0)
return false;
unsigned char data[CRYPT_PACKET_SIZE] = {};
int dataSize = CRYPT_PACKET_SIZE;
if (!crypt.decrypt(raw, rawSize, data, dataSize))
return true;
void *ptr = data;
Packet &packet = *(Packet*)ptr;
int cmd = packet.getCommand();
if (cmd < 0 || cmd > CMDMAX)
return true;
int size = packet.getSize();
if (size + HEADER_SIZE > dataSize)
return true;
id.id = packet.connId;
ConnMap::iterator i = connections.find(id);
if (i == connections.end() && cmd == CMD_DATA && remoteTcpAddress.port && !closedConnections.count(id)) {
printf("Open incoming connection from UDP: ");
id.address.print();
printf(", id: %u\n", id.id);
int sockId = Socket::tcpConnect(remoteTcpAddress);
if (sockId < 0) {
printf("cannot init TCP connection, cancelled\n");
} else {
i = connections.emplace(id, Connection::Args(*this, id, sockId)).first;
}
}
if (i == connections.end()) {
termQueue.insert(id);
return true;
}
Connection &conn = i->second;
if (!conn.udpActive) {
if (cmd != CMD_TERM) termQueue.insert(id);
return true;
}
if (cmd == CMD_DATA) {
conn.lastUdpInputTime = time;
conn.udpRecvQueue.recvUdp(packet.index, packet.payload, size);
} else
if (cmd == CMD_CONFIRM) {
conn.lastUdpInputTime = time;
conn.udpSendQueue.confirm(packet.index, packet.payload, size*8);
} else
if (cmd == CMD_TERM) {
conn.udpActive = false;
}
return true;
}
void Tunnel::sendUdpPacket(const Address &address, const Packet &packet) {
unsigned char raw[CRYPT_PACKET_SIZE] = {};
int rawSize = CRYPT_PACKET_SIZE;
if (crypt.encrypt(&packet, HEADER_SIZE + packet.getSize(), raw, rawSize))
Socket::udpSend(udpSockId, address, raw, rawSize);
}
// Run one step of the tunnel event loop:
//   1. build the poll list: UDP socket, TCP listener (if any), each open TCP connection;
//   2. wait until a socket is ready or the earliest UDP send deadline arrives;
//   3. drain incoming UDP packets and accept a pending TCP connection;
//   4. move data between the TCP sockets and the per-connection queues;
//   5. send UDP data / confirm / keep-alive packets;
//   6. retire finished connections;
//   7. periodically send queued TERM packets and expire old closed-connection ids.
// Returns false if the UDP server is not running or a server socket reported
// closed (the whole tunnel is then closed); true otherwise.
bool Tunnel::iteration() {
    if (udpSockId < 0)
        return false;

    // prepare poll: entry 0 is the UDP socket, entry 1 the TCP listener (if any)
    poll.list.clear();
    poll.list.emplace_back();
    {
        Poll::Entry &e = poll.list.back();
        e.sockId = udpSockId;
        e.wantRead = true;
    }
    if (tcpSockId >= 0) {
        // The listener needs its own entry; reusing back() would overwrite the UDP entry.
        poll.list.emplace_back();
        Poll::Entry &e = poll.list.back();
        e.sockId = tcpSockId;
        e.wantRead = true;
    }
    // Entries from here on belong to connections.
    const int firstConnEntry = static_cast<int>(poll.list.size());

    // Poll deadline: at most pollDuration from now, earlier if any active
    // connection needs to write to UDP sooner.
    Time pollTime = time + pollDuration;
    for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
        Connection &conn = i->second;
        if (conn.tcpSockId >= 0) {
            poll.list.emplace_back();
            Poll::Entry &e = poll.list.back();
            e.connection = &conn;
            e.sockId = conn.tcpSockId;
            e.wantRead = conn.udpActive && conn.wantReadFromTcp();
            e.wantWrite = conn.wantWriteToTcp();
        }
        if (conn.udpActive) {
            Time t = conn.whenWriteToUdp();
            // keep the earliest deadline so pending UDP sends are not delayed
            if (timeLess(t, pollTime)) pollTime = t;
        }
    }

    // wait for events
    poll.wait(pollTime);
    time = monotonicTime();
    crypt.setTime(time);
    if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) {
        poll.list.clear();
        close();
        return false;
    }

    // read and parse udp
    if (poll.list[0].canRead)
        while(recvUdpPacket());

    // accept new tcp connection
    if (tcpSockId >= 0 && poll.list[1].canRead) {
        int sockId = Socket::tcpAccept(tcpSockId);
        if (sockId >= 0) {
            ConnId id;
            id.address = remoteUdpAddress;
            id.id = ConnId::generateId();
            // the connection owns the accepted socket, not the listening one
            Connection &conn = connections.emplace(id, Connection::Args(*this, id, sockId)).first->second;
            conn.tcpRecvQueue.readFromTcp();
        }
    }

    // tcp read and write
    for(Poll::List::const_iterator i = poll.list.begin() + firstConnEntry; i != poll.list.end(); ++i) {
        Connection &conn = *i->connection;
        conn.udpRecvQueue.sendTcp();
        if (i->canWrite) while(true) {
            int size = conn.tcpSendQueue.writeToTcp();
            if (size <= 0) break;
            conn.udpRecvQueue.sendTcp();
        }
        if (i->canRead && conn.udpActive)
            conn.tcpRecvQueue.readFromTcp();
        if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) {
            Socket::close(conn.tcpSockId);
            conn.tcpSockId = -1;
        }
    }

    // write to udp
    for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
        Connection &conn = i->second;
        if (conn.udpActive) {
            bool sent = false;
            if (conn.udpSendQueue.sendUdp()) sent = true;
            if (conn.udpRecvQueue.sendUdpConfirm()) sent = true;
            if (sent) conn.lastUdpOutputTime = time;
            if (timeLequal(conn.lastUdpOutputTime + udpKeepAliveDuration, time)) {
                conn.udpRecvQueue.sendUdpConfirm(true);
                conn.lastUdpOutputTime = time;
            }
        }
    }

    // close connections
    for(ConnMap::iterator i = connections.begin(); i != connections.end();) {
        // advance first so that erasing the current element (j) is safe
        ConnMap::iterator j = i++;
        Connection &conn = j->second;
        if (conn.udpActive && timeLess(conn.lastUdpInputTime + udpSilenceDuration, time))
            conn.udpActive = false;
        bool remove = false;
        if (!conn.udpActive && !conn.wantWriteToTcp()) {
            Socket::close(conn.tcpSockId);
            conn.tcpSockId = -1;
            remove = true;
        } else
        if (conn.tcpSockId < 0 && !conn.udpSendQueue.canSentToUdp()) {
            conn.udpActive = false;
            remove = true;
        } else
        if (!conn.udpActive && conn.tcpSockId < 0) {
            remove = true;
        }
        if (remove) {
            // remember the id so late packets do not reopen it; record before erase destroys conn
            closedConnections[conn.id] = time + 2*udpSilenceDuration;
            connections.erase(j);
        }
    }

    // periodic work: send udp term packets, expire closed-connection ids
    if (timeLequal(nextTermSendTime, time)) {
        for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) {
            Packet packet;
            packet.init(i->id, 0, CMD_TERM, 0);
            sendUdpPacket(i->address, packet);
        }
        termQueue.clear();

        // Drop ids whose quarantine has expired so closedConnections does not
        // grow forever; done on this periodic tick to avoid a scan every iteration.
        for(auto c = closedConnections.begin(); c != closedConnections.end();) {
            auto expired = c++;
            if (timeLequal(expired->second, time))
                closedConnections.erase(expired);
        }

        // schedule relative to now, matching the constructor and run()
        nextTermSendTime = time + udpConfirmDuration;
    }

    poll.list.clear();
    return true;
}
// Main loop: refresh the clock and the TERM-send schedule, then keep calling
// iteration() until it reports that the tunnel has stopped.
void Tunnel::run() {
    time = monotonicTime();
    nextTermSendTime = time + udpConfirmDuration;
    for (bool running = true; running; )
        running = iteration();
}