Blob Blame Raw

#include <cstdio>
#include <cerrno>
#include <cstring>

#include "tunnel.h"


// Parses a textual "a.b.c.d:port" IPv4 endpoint into `address`.
//
// `address` is first reset to a default-constructed Address and is only
// written to when the whole string is valid, so a failed parse never leaves
// a half-filled address behind.
//
// The string is accepted only if:
//   - it consists solely of decimal digits, '.' and ':' (no signs, no spaces);
//   - all five numeric fields are present and nothing follows the port;
//   - every octet is in 0..255 and the port is in 1..65535.
//
// Returns true on success; on failure logs the offending string and
// returns false.
static bool parseAddress(const char *str, Address &address) {
	address = Address();

	// Parse into wide temporaries so out-of-range values can be detected
	// instead of silently wrapping inside a narrow destination.
	unsigned int octet[4] = { 0, 0, 0, 0 };
	unsigned int port = 0;
	int used = 0;

	// %u would also accept a leading sign or whitespace, so restrict the
	// alphabet up front.
	bool ok = str[0] != '\0' && strspn(str, "0123456789.:") == strlen(str);

	if (ok) {
		// %n records how many characters were consumed; it is not counted
		// in the return value, so a full match yields exactly 5.
		int fields = sscanf( str, "%u.%u.%u.%u:%u%n",
				&octet[0], &octet[1], &octet[2], &octet[3],
				&port, &used );
		ok = fields == 5 && str[used] == '\0'
		  && port != 0 && port <= 65535u;
	}

	for(int i = 0; ok && i < 4; ++i)
		ok = octet[i] <= 255u;

	if (!ok) {
		logf("Wrong ip-address: [%s]\n", str);
		return false;
	}

	for(int i = 0; i < 4; ++i)
		address.ip[i] = (unsigned char)octet[i];
	address.port = (unsigned short)port;
	return true;
}


static bool openLogFile(const char *logfile) {
	if (0 != strcmp(logfile, "-")) {
		bool eq = stderr == stdout;
		stderr = stderr ? fopen(logfile, "a+") : freopen(logfile, "a+", stderr);
		if (!stderr)
			{ logf("Cannot open log file: %s\n", logfile); return false; }
		if (!eq && stdout && stdout != stderr) fclose(stdout);
		stdout = stderr;
	}
	return true;
}



// Zeroes `size` bytes of the key buffer and releases it.
// The writes go through a volatile pointer so the compiler cannot drop the
// wipe as a dead store ahead of free(). A null `key` is ignored.
static void wipeAndFreeKey(char *key, size_t size) {
	if (!key) return;
	volatile char *p = key;
	for(size_t k = 0; k < size; ++k) p[k] = 0;
	free(key);
}


// Program entry point.
//
// Accepted command lines:
//   prog keyfile bytesPerSecond localUdp:port remoteUdp:port localTcp:port logfile|-
//   prog keyfile bytesPerSecond localUdp:port remoteTcp:port logfile|-
//
// Steps: redirect logging, parse the speed limit and addresses, read the key
// from `keyfile`, initialise the tunnel and run it.
//
// Returns 0 after tunnel.run() returns, 1 on any argument, key-file or
// initialisation error.
int main(int argc, char **argv) {
	
	// parse args
	
	if (argc != 6 && argc != 7) {
		logf("Usage: %s keyfile bytesPerSecond localUdp:port remoteUdp:port localTcp:port logfile|-\n", argv[0]);
		logf("   or: %s keyfile bytesPerSecond localUdp:port remoteTcp:port logfile|-\n", argv[0]);
		return 1;
	}
	
	unsigned long long bytesPerSecond = 0;
	Address localUdp, remoteUdp, localTcp, remoteTcp;
	
	// The log file is always the last argument. It is opened first so that
	// the remaining diagnostics end up in it.
	bool logOk = openLogFile(argv[argc - 1]);
	
	// bytesPerSecond stays 0 when nothing could be parsed, so one check
	// covers both "not a number" and an explicit zero.
	sscanf(argv[2], "%llu", &bytesPerSecond);
	if (!bytesPerSecond)
		{ logf("Wrong bytes per second: %llu\n", bytesPerSecond); return 1; }
	
	// Parse every address before bailing out so that each bad one is reported.
	bool addrOk = parseAddress(argv[3], localUdp);
	if (argc == 7) {
		addrOk = parseAddress(argv[4], remoteUdp) && addrOk;
		addrOk = parseAddress(argv[5], localTcp) && addrOk;
	} else {
		addrOk = parseAddress(argv[4], remoteTcp) && addrOk;
	}
	if (!addrOk || !logOk) return 1;
	
	// load key
	
	FILE *f = fopen(argv[1], "r");
	if (!f) {
		logf("Cannot open keyfile: [%s]\n", argv[1]);
		return 1;
	}
	const size_t maxSize = 1024*1024;
	char *key = (char*)calloc(1, maxSize);
	if (!key) {
		logf("Cannot allocate memory for key\n");
		fclose(f);
		return 1;
	}
	
	// Copy bytes until EOF, a NUL byte, or the buffer is full. One byte is
	// kept free so the zero-filled buffer is always NUL-terminated.
	char *c = key, *e = c + maxSize - 1;
	int i = 0;
	while((i = fgetc(f)) > 0 && c < e) *c++ = (char)i;
	
	// Anything but a clean EOF is a failure: an embedded NUL, a key that does
	// not fit, or an I/O error. ferror() is used because errno may hold a
	// stale value from an earlier call.
	bool readFailed = i != EOF || ferror(f);
	fclose(f);
	if (readFailed) {
		logf("Cannot read keyfile: [%s]\n", argv[1]);
		wipeAndFreeKey(key, maxSize);
		return 1;
	}
	
	// init tunnel
	
	Tunnel tunnel;
	
	tunnel.initSpeed(bytesPerSecond);
	
	// The key is wiped right after this call, whether or not it succeeds.
	bool udpOk = tunnel.initUdpServer(localUdp, remoteTcp, key);
	wipeAndFreeKey(key, maxSize);
	key = NULL;
	if (!udpOk)
		return 1;

	// The TCP server is started only when both of its addresses were given
	// (the seven-argument form).
	if ( localTcp.port && remoteUdp.port
	  && !tunnel.initTcpServer(localTcp, remoteUdp) )
		return 1;
	
	tunnel.run();
	
	return 0;
}