Blob Blame Raw

#include <cstdio>
#include <tuple>
#include <utility>

#include "socket.h"
#include "tunnel.h"



// Constructs an idle tunnel: no UDP/TCP socket is open yet (both ids are -1)
// and the timing parameters receive their default values.
// NOTE(review): several initializers below read other members
// (udpResendDuration reads udpSendDuration, udpConfirmDuration reads
// udpResendDuration, nextTermSendTime reads time and udpConfirmDuration).
// C++ initializes members in their declaration order, not in the order written
// here, so tunnel.h must declare them in this same order — confirm.
// Durations are presumably in the units returned by monotonicTime() — TODO confirm.
Tunnel::Tunnel():
	time(monotonicTime()),
	udpSendDuration(1000),
	udpPartialSendDuration(500000),
	// derived from udpSendDuration and PACKETS_COUNT
	udpResendDuration(udpSendDuration * PACKETS_COUNT),
	// an eighth of the resend interval
	udpConfirmDuration(udpResendDuration/8),
	// upper bound for a single poll wait in iteration()
	pollDuration(5000000),
	udpSockId(-1),
	tcpSockId(-1),
	// first flush of queued CMD_TERM packets (run() re-arms this on start)
	nextTermSendTime(time + udpConfirmDuration)
	{ }

// Destroying the tunnel shuts it down through close(), releasing the sockets
// it still holds.
Tunnel::~Tunnel() {
	this->close();
}


// (Re)opens the UDP side of the tunnel.
//
// Any previously opened UDP server (and with it the TCP server and all
// connections) is closed first. On success the UDP socket bound to `address`
// is stored in udpSockId and `remoteTcpAddress` is remembered; readUdp() only
// creates connections for unknown UDP peers when remoteTcpAddress.port is
// non-zero.
//
// @param address           local address to bind the UDP socket to
// @param remoteTcpAddress  TCP endpoint for peer-initiated connections
//                          (port 0 disables peer-initiated connections)
// @return true if the UDP socket was bound, false otherwise
bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddress) {
	closeUdpServer();
	printf("Init UDP server at address: "); address.print();
	if (remoteTcpAddress.port)
		{ printf(", with remote TCP address: "); remoteTcpAddress.print(); }
	printf("\n");
	
	udpSockId = Socket::udpBind(address);
	if (udpSockId < 0)
		{ printf("cannot bind, cancelled\n"); return false; }
	
	// terminate the log line so following messages start on their own line
	printf("success\n");
	this->remoteTcpAddress = remoteTcpAddress;
	return true;
}

// (Re)opens the TCP listening socket of the tunnel.
//
// Requires the UDP server to be initialized first (initUdpServer). Any
// previously opened TCP server is closed. The new socket is bound to
// `address` and put into listening mode with a backlog of 32.
//
// NOTE(review): remoteUdpAddress (used by iteration() as the UDP destination
// of accepted TCP connections) is never assigned in this file — confirm
// where it is meant to be set.
//
// @param address  local address to listen on; its fields are copied into
//                 sockaddr_in verbatim, as elsewhere in this file
// @return true if the socket is listening, false otherwise
bool Tunnel::initTcpServer(const Address &address) {
	closeTcpServer();
	printf("Init TCP server at address: "); address.print(); printf("\n");
	
	if (udpSockId < 0)
		{ printf("UDP server was not initialized, cancelled\n"); return false; }
	
	int sockId = ::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sockId < 0)
		{ printf("cannot create socket, cancelled\n"); return false; }
	
	struct sockaddr_in addr = {};
	addr.sin_family = AF_INET; // sockaddr families use AF_*, PF_* is for socket()
	addr.sin_addr.s_addr = address.ipUInt;
	addr.sin_port = address.port;
	if (0 != ::bind(sockId, (struct sockaddr*)&addr, sizeof(addr)))
		{ ::close(sockId); printf("cannot bind, cancelled\n"); return false; }
	if (0 != ::listen(sockId, 32))
		{ ::close(sockId); printf("cannot listen, cancelled\n"); return false; }

	tcpSockId = sockId;
	printf("success\n");
	return true;
}


// Closes the UDP side and everything that depends on it: the TCP listening
// socket, every connection's TCP socket, the connection table and the queue of
// pending CMD_TERM packets. Does nothing if the UDP socket is not open.
void Tunnel::closeUdpServer() {
	if (udpSockId < 0) return;
	// initTcpServer() refuses to run without UDP, so the TCP server goes too
	closeTcpServer();
	
	printf("Close UDP server\n");
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
		// iteration() sets tcpSockId to -1 once a connection's TCP side is
		// closed; don't close an invalid descriptor
		if (i->second.tcpSockId >= 0)
			Socket::close(i->second.tcpSockId);
	}
	connections.clear();
	termQueue.clear();
	Socket::close(udpSockId);
	udpSockId = -1;
	remoteTcpAddress = Address();
}


// Closes the TCP listening socket (if open) and forgets remoteUdpAddress.
// Already-established connections are left untouched.
void Tunnel::closeTcpServer() {
	if (tcpSockId < 0) return;
	printf("Close TCP server\n");
	Socket::close(tcpSockId);
	tcpSockId = -1;
	remoteUdpAddress = Address();
}


void Tunnel::close()
	{ closeUdpServer(); }


void Tunnel::readUdp() {
	// read and parse UDP input
	Packet packet;
	const int minPacketSize = offsetof(Packet, payload) - offsetof(Packet, connId);
	while(true) {
		struct sockaddr_in addr = {};
		socklen_t addrlen = sizeof(addr);
		int res = ::recvfrom(
			udpSockId,
			&packet.connId,
			MSG_DONTWAIT,
			sizeof(packet) - offsetof(Packet, connId),
			(struct sockaddr*)&addr,
			&addrlen );
		if (res <= 0)
			break;
		if (res < minPacketSize)
			continue;
		
		int cmd = packet.getCommand();
		int size = packet.getSize();
		if (res != minPacketSize + packet.getSize())
			continue;
		if (cmd != CMD_TERM && cmd != CMD_CONFIRM && cmd != CMD_DATA)
			continue;
		
		ConnId id;
		id.address.ipUInt = addr.sin_addr.s_addr;
		id.address.port = addr.sin_port;
		id.id = packet.connId;
		int cmd = packet.getCommand();
		ConnMap::iterator i = connections.find(id);
		Connection *conn = i == connections.end() ? nullptr : &i->second;
		
		if (!conn && remoteTcpAddress.port && cmd != CMD_TERM) {
			printf("Open incoming connection from UDP: ");
			id.address.print();
			printf(", id: %u\n", id.id);
			
			int sockId = ::socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
			struct sockaddr_in addr = {};
			addr.sin_family = PF_INET;
			addr.sin_addr.s_addr = id.address.ipUInt;
			addr.sin_port = id.address.port;
			if (0 != ::connect(sockId, (struct sockaddr*)&addr, sizeof(addr))) {
				::close(sockId);
				printf("cannot init TCP connection, cancelled");
				continue;
			}
			
			conn = &connections[id];
			conn->id = id;
			conn->tcpSockId = sockId;
			conn->udpActive = true;
		}
		
		if (!conn) {
			if (cmd != CMD_TERM)
				rejectedConnections.insert(id);
			continue;
		}
		
		if (cmd == CMD_TERM) {
			conn->udpActive = false;
			continue;
		}
		
		if (cmd == CMD_CONFIRM) {
			unsigned int index = packet.index;
			for(int i = 0; i < size; ++i)
				for(int j = 0; j < 8; ++j, ++index)
					if (packet.payload[i] & (1 << j))
						if (PacketInfo *pi = conn->toUdpQueue[index])
							if (pi->loaded)
								pi->confirmed = true;
			while(PacketInfo *pi = conn->toUdpQueue.first())
				if (pi->loaded && pi->confirmed)
					conn->toUdpQueue.pop(); else break;
			continue;
		}
	}
}


bool Tunnel::iteration() {
	if (udpSockId < 0)
		return false;
	
	// prepare poll
	
	poll.list.clear();
	
	poll.list.emplace_back();
	Poll::Entry &e = poll.list.back();
	e.sockId = udpSockId;
	e.wantRead = true;
	
	if (tcpSockId >= 0) {
		Poll::Entry &e = poll.list.back();
		e.sockId = tcpSockId;
		e.wantRead = true;
	}
	
	Time pollTime = time + pollDuration;
	for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) {
		Connection &conn = i->second;
		if (conn.tcpSockId >= 0) {
			poll.list.emplace_back();
			Poll::Entry &e = poll.list.back();
			e.connection = &conn;
			e.sockId = conn.tcpSockId;
			e.wantRead = conn.udpActive && conn.wantReadFromTcp();
			e.wantWrite = conn.wantWriteToTcp();
		}
		if (conn.udpActive) {
			Time t = conn.whenWriteToUdp();
			if (timeLess(pollTime, t)) pollTime = t;
		}
	}
	
	// wait for events
	poll.wait(pollTime);
	time = monotonicTime();
	
	if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) {
		poll.list.clear();
		close();
		return false;
	}
	
	// read and parse udp
	if (poll.list[0].canRead)
		while(recvUdpPacket());
	
	if (tcpSockId >= 0 && poll.list[1].canRead) {
		// accept new tcp connection
		int sockId = Socket::tcpAccept(tcpSockId);
		if (sockId >= 0) {
			ConnId id;
			id.address = remoteUdpAddress;
			id.id = ConnId::generateId();
			Connection &conn = connections.emplace(id, *this, id, tcpSockId).first->second;
			conn.tcpRecvQueue.readFromTcp();
		}
	}
	
	// tcp read and write
	int first = tcpSockId < 0 ? 1 : 2;
	for(Poll::List::const_iterator i = poll.list.begin() + first; i != poll.list.end(); ++i) {
		Connection &conn = *i->connection;
		conn.udpRecvQueue.sendTcp();
		if (i->canWrite)
			conn.tcpSendQueue.writeToTcp();
		if (i->canRead && conn.udpActive)
			conn.tcpRecvQueue.readFromTcp();
		if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) {
			Socket::close(conn.tcpSockId);
			conn.tcpSockId = -1;
		}
	}
	
	// write to udp
	for(ConnMap::iterator i = connections.begin(); i != connections.end();) {
		ConnMap::iterator j = i++;
		Connection &conn = j->second;
		if (conn.udpActive) {
			conn.udpSendQueue.sendUdp();
			conn.udpRecvQueue.sendUdpConfirm();
		} else
		if (conn.tcpSockId < 0) {
			connections.erase(j);
		}
	}
	
	// send udp term packets
	if (timeLequal(nextTermSendTime, time)) {
		for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) {
			Packet packet;
			packet.init(i->id, 0, CMD_TERM, 0);
			sendUdpPacket(i->address, packet);
		}
		termQueue.clear();
		nextTermSendTime = time + nextTermSendTime;
	}
	
	poll.list.clear();
	return true;
}


// Runs the event loop until iteration() reports that the tunnel is closed.
// The clock and the first term-packet flush are re-armed relative to "now",
// since time may have passed since construction.
void Tunnel::run() {
	time = monotonicTime();
	nextTermSendTime = time + udpConfirmDuration;
	for(;;) {
		if (!iteration())
			break;
	}
}