Blob Blame Raw

#include <cstdio>

#include "socket.h"
#include "tunnel.h"



// Construct an idle tunnel: no sockets open (udpSockId/tcpSockId are -1),
// the clock sampled from monotonicTime(), and the timing parameters below.
// NOTE(review): the duration values look like microseconds (1000 = 1 ms,
// 300000000 = 300 s) — confirm against monotonicTime().
// NOTE(review): several initializers read members set earlier in this list
// (time, udpSendDuration, udpResendDuration, udpConfirmDuration,
// udpSilenceDuration). Members are initialized in declaration order, not in
// list order, so this is only correct if tunnel.h declares them in this same
// order — confirm.
Tunnel::Tunnel():
	time(monotonicTime()),
	udpSendDuration(1000),
	udpPartialSendDuration(500000),
	udpResendDuration(udpSendDuration * PACKETS_COUNT), // presumably one send interval per packet slot — TODO confirm
	udpConfirmDuration(udpResendDuration/8), // also the interval between TERM packet batches (see run/iteration)
	udpSilenceDuration(300000000), // a connection with no UDP input for this long is deactivated (see iteration)
	udpKeepAliveDuration(udpSilenceDuration/32), // idle output time after which a forced confirm is sent (see iteration)
	pollDuration(5000000), // base timeout added to the current time to form the poll deadline (see iteration)
	udpSockId(-1),
	tcpSockId(-1),
	nextTermSendTime(time + udpConfirmDuration) // time of the first flush of queued TERM packets
	{ }


Tunnel::~Tunnel() {
	// Full teardown on destruction: close() shuts down the UDP server, which in
	// turn shuts down the TCP server and every connection socket.
	close();
}


/**
 * (Re)start the UDP side of the tunnel.
 *
 * Any running UDP server (and with it the TCP server and all connections) is
 * closed first. Then the cipher key is set from `key` and a UDP socket is bound
 * to `address`. `remoteTcpAddress` is remembered on success; when its port is
 * non-zero, CMD_DATA packets for unknown connections open new TCP connections
 * to it (see recvUdpPacket).
 *
 * @return true on success; false if the key is rejected or binding fails
 *         (the key is reset again in the latter case).
 */
bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddress, const char *key) {
	closeUdpServer();
	printf("Init UDP server at address: "); address.print();
	if (remoteTcpAddress.port)
		{ printf(", with remote TCP address: "); remoteTcpAddress.print(); }
	printf("\n");
	
	// every status message ends with a newline so later log output starts on
	// its own line (matches the other messages in this file)
	if (!crypt.setKey(key))
		{ printf("bad key, cancelled\n"); return false; }
	
	udpSockId = Socket::udpBind(address);
	if (udpSockId < 0)
		{ crypt.resetKey(); printf("cannot bind, cancelled\n"); return false; }
	
	printf("success\n");
	this->remoteTcpAddress = remoteTcpAddress;
	return true;
}


/**
 * (Re)start the TCP listening side of the tunnel.
 *
 * Requires the UDP server to be running. Any running TCP server is closed
 * first. Connections accepted on the listening socket are tunnelled to
 * `remoteUdpAddress` (see iteration).
 *
 * @return true on success; false if the UDP server is not initialized or the
 *         listening socket cannot be created.
 */
bool Tunnel::initTcpServer(const Address &address, const Address &remoteUdpAddress) {
	closeTcpServer();
	printf("Init TCP server at address: "); address.print(); printf("\n");
	
	if (udpSockId < 0)
		{ printf("UDP server was not initialized, cancelled\n"); return false; }
	
	tcpSockId = Socket::tcpListen(address);
	if (tcpSockId < 0)
		{ printf("cannot listen, cancelled\n"); return false; }
	
	printf("success\n");
	this->remoteUdpAddress = remoteUdpAddress;
	return true;
}


void Tunnel::closeUdpServer() {
	if (udpSockId < 0) return;
	closeTcpServer();
	
	printf("Close UDP server");
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i)
		Socket::close(i->second.tcpSockId);
	connections.clear();
	termQueue.clear();
	Socket::close(udpSockId);
	udpSockId = -1;
	remoteTcpAddress = Address();
	crypt.resetKey();
}


/**
 * Shut down the TCP listening socket, if any, and forget the remote UDP
 * address. Existing connections are not touched here.
 */
void Tunnel::closeTcpServer() {
	if (tcpSockId < 0) return;
	printf("Close TCP server\n");
	Socket::close(tcpSockId);
	tcpSockId = -1;
	remoteUdpAddress = Address();
}


void Tunnel::close() {
	// Complete teardown: closing the UDP server also closes the TCP server and
	// all connection sockets (see closeUdpServer).
	closeUdpServer();
}


bool Tunnel::recvUdpPacket() {
	ConnId id;
	unsigned char raw[CRYPT_PACKET_SIZE] = {};
	int rawSize = Socket::udpRecvAsync(udpSockId, id.address, raw, sizeof(raw));
	if (rawSize <= 0)
		return false;
	
	unsigned char data[CRYPT_PACKET_SIZE] = {};
	int dataSize = CRYPT_PACKET_SIZE;
	if (!crypt.decrypt(raw, rawSize, data, dataSize))
		return true;
	
	void *ptr = data;
	Packet &packet = *(Packet*)ptr;
	
	int cmd = packet.getCommand();
	if (cmd < 0 || cmd > CMDMAX)
		return true;

	int size = packet.getSize();
	if (size + HEADER_SIZE > dataSize)
		return true;
	
	id.id = packet.connId;
	ConnMap::iterator i = connections.find(id);
	if (i == connections.end() && cmd == CMD_DATA && remoteTcpAddress.port && !closedConnections.count(id)) {
		printf("Open incoming connection from UDP: ");
		id.address.print();
		printf(", id: %u\n", id.id);
		int sockId = Socket::tcpConnect(remoteTcpAddress);
		if (sockId < 0) {
			printf("cannot init TCP connection, cancelled\n");
		} else {
			i = connections.emplace(id, Connection::Args(*this, id, sockId)).first;
		}
	}
	
	if (i == connections.end()) {
		termQueue.insert(id);
		return true;
	}
	
	Connection &conn = i->second;
	if (!conn.udpActive) {
		if (cmd != CMD_TERM) termQueue.insert(id);
		return true;
	}
	
	if (cmd == CMD_DATA) {
		conn.lastUdpInputTime = time;
		conn.udpRecvQueue.recvUdp(packet.index, packet.payload, size);
	} else
	if (cmd == CMD_CONFIRM) {
		conn.lastUdpInputTime = time;
		conn.udpSendQueue.confirm(packet.index, packet.payload, size*8);
	} else
	if (cmd == CMD_TERM) {
		conn.udpActive = false;
	}
	
	return true;
}


void Tunnel::sendUdpPacket(const Address &address, const Packet &packet) {
	unsigned char raw[CRYPT_PACKET_SIZE] = {};
	int rawSize = CRYPT_PACKET_SIZE;
	if (crypt.encrypt(&packet, HEADER_SIZE + packet.getSize(), raw, rawSize))
		Socket::udpSend(udpSockId, address, raw, rawSize);
}


bool Tunnel::iteration() {
	if (udpSockId < 0)
		return false;
	
	// prepare poll
	
	poll.list.clear();
	
	poll.list.emplace_back();
	Poll::Entry &e = poll.list.back();
	e.sockId = udpSockId;
	e.wantRead = true;
	
	if (tcpSockId >= 0) {
		Poll::Entry &e = poll.list.back();
		e.sockId = tcpSockId;
		e.wantRead = true;
	}
	
	Time pollTime = time + pollDuration;
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
		Connection &conn = i->second;
		if (conn.tcpSockId >= 0) {
			poll.list.emplace_back();
			Poll::Entry &e = poll.list.back();
			e.connection = &conn;
			e.sockId = conn.tcpSockId;
			e.wantRead = conn.udpActive && conn.wantReadFromTcp();
			e.wantWrite = conn.wantWriteToTcp();
		}
		if (conn.udpActive) {
			Time t = conn.whenWriteToUdp();
			if (timeLess(pollTime, t)) pollTime = t;
		}
	}
	
	// wait for events
	poll.wait(pollTime);
	time = monotonicTime();
	crypt.setTime(time);
	
	if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) {
		poll.list.clear();
		close();
		return false;
	}
	
	// read and parse udp
	if (poll.list[0].canRead)
		while(recvUdpPacket());
	
	if (tcpSockId >= 0 && poll.list[1].canRead) {
		// accept new tcp connection
		int sockId = Socket::tcpAccept(tcpSockId);
		if (sockId >= 0) {
			ConnId id;
			id.address = remoteUdpAddress;
			id.id = ConnId::generateId();
			Connection &conn = connections.emplace(id, Connection::Args(*this, id, tcpSockId)).first->second;
			conn.tcpRecvQueue.readFromTcp();
		}
	}
	
	// tcp read and write
	int first = tcpSockId < 0 ? 1 : 2;
	for(Poll::List::const_iterator i = poll.list.begin() + first; i != poll.list.end(); ++i) {
		Connection &conn = *i->connection;
		conn.udpRecvQueue.sendTcp();
		if (i->canWrite) while(true) {
			int size = conn.tcpSendQueue.writeToTcp();
			if (size <= 0) break;
			conn.udpRecvQueue.sendTcp();
		}
		if (i->canRead && conn.udpActive)
			conn.tcpRecvQueue.readFromTcp();
		if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) {
			Socket::close(conn.tcpSockId);
			conn.tcpSockId = -1;
		}
	}
	
	// write to udp
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
		Connection &conn = i->second;
		if (conn.udpActive) {
			bool sent = false;
			if (conn.udpSendQueue.sendUdp()) sent = true;
			if (conn.udpRecvQueue.sendUdpConfirm()) sent = true;
			if (sent) conn.lastUdpOutputTime = time;
			if (timeLequal(conn.lastUdpOutputTime + udpKeepAliveDuration, time)) {
				conn.udpRecvQueue.sendUdpConfirm(true);
				conn.lastUdpOutputTime = time;
			}
		}
	}
	
	// close connections	
	for(ConnMap::iterator i = connections.begin(); i != connections.end();) {
		ConnMap::iterator j = i++;
		Connection &conn = j->second;
		
		if (conn.udpActive && timeLess(conn.lastUdpInputTime + udpSilenceDuration, time))
			conn.udpActive = false;
		
		if (!conn.udpActive && !conn.wantWriteToTcp()) {
			Socket::close(conn.tcpSockId);
			conn.tcpSockId = -1;
			closedConnections[conn.id] = time + 2*udpSilenceDuration;
			connections.erase(i);
			continue;
		}
		
		if (conn.tcpSockId < 0 && !conn.udpSendQueue.canSentToUdp()) {
			conn.udpActive = false;
			closedConnections[conn.id] = time + 2*udpSilenceDuration;
			connections.erase(i);
			continue;
		}
		
		if (!conn.udpActive && conn.tcpSockId < 0) {
			closedConnections[conn.id] = time + 2*udpSilenceDuration;
			connections.erase(i);
			continue;
		}
	}
	
	// send udp term packets
	if (timeLequal(nextTermSendTime, time)) {
		for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) {
			Packet packet;
			packet.init(i->id, 0, CMD_TERM, 0);
			sendUdpPacket(i->address, packet);
		}
		termQueue.clear();
		nextTermSendTime = time + nextTermSendTime;
	}
	
	poll.list.clear();
	return true;
}


// Main loop: resynchronize the clock, schedule the first TERM flush one
// udpConfirmDuration ahead, then keep calling iteration() until it reports
// that the tunnel has stopped.
void Tunnel::run() {
	time = monotonicTime();
	nextTermSendTime = time + udpConfirmDuration;
	bool running = true;
	while (running)
		running = iteration();
}