diff --git a/main.cpp b/main.cpp index bfe0b63..24fa62a 100644 --- a/main.cpp +++ b/main.cpp @@ -6,6 +6,61 @@ #include "tunnel.h" +static bool parseFloatConnections(bool &floatConnections) { /* Reads ICETUNNEL2_FLOAT: "1" sets the flag, "0" clears it, an unset variable leaves it untouched; any other value is logged and makes the function return false. */ + const char *envvar = "ICETUNNEL2_FLOAT"; + const char *str = getenv(envvar); + if (str) { + if (str[0] == '1' && !str[1]) { + logf("Allow floating connections (envvar: %s)\n", envvar); + floatConnections = true; + } else + if (str[0] == '0' && !str[1]) { + logf("Disallow floating connections (envvar: %s)\n", envvar); + floatConnections = false; + } else { + logf("Wrong floating flag in envvar %s: %s\n", envvar, str); + return false; + } + } + return true; +} + +/* Reads ICETUNNEL2_IPMAP, a '|'-separated list of "a.b.c.d:e.f.g.h" entries, and stores each pair as ipMap[first address] = second address. Returns true when the variable is unset or empty or every entry parses; on the first malformed entry it logs the whole value and returns false (entries parsed before it stay in ipMap). */ +static bool parseAddressMap(Tunnel::IpMap &ipMap) { + const char *envvar = "ICETUNNEL2_IPMAP"; + const char *str = getenv(envvar); + + bool success = true; + const char *pos = str; + while(pos && *pos) { + union { /* ip[0] and ip[1] overlay ips[0..3] and ips[4..7] (assumes a 4-byte unsigned int); two scalar members declared here would both sit at offset 0 and alias ips[0..3] */ + unsigned int ip[2]; + unsigned char ips[8] = {}; + }; + int cnt = sscanf( pos, "%hhu.%hhu.%hhu.%hhu:%hhu.%hhu.%hhu.%hhu", + &ips[0], &ips[1], &ips[2], &ips[3], + &ips[4], &ips[5], &ips[6], &ips[7] ); + if (cnt != 8) { success = false; break; } + + logf("Add ip map %hhu.%hhu.%hhu.%hhu -> %hhu.%hhu.%hhu.%hhu (envvar: %s)\n", + ips[0], ips[1], ips[2], ips[3], + ips[4], ips[5], ips[6], ips[7], envvar ); + ipMap[ip[0]] = ip[1]; + + while(*pos && *pos != '|') ++pos; + if (*pos != '|') break; + ++pos; + } + + if (!success) { + logf("Wrong address map in envvar %s: %s\n", envvar, str); + return false; + } + + return true; +} + + static bool parseAddress(const char *str, Address &address) { address = Address(); sscanf( str, "%hhu.%hhu.%hhu.%hhu:%hu", @@ -87,6 +142,9 @@ int main(int argc, char **argv) { Tunnel tunnel; + if (!parseFloatConnections(tunnel.floatConnections)) return 1; + if (!parseAddressMap(tunnel.ipMap)) return 1; + tunnel.initSpeed(bytesPerSecond); if ( !tunnel.initUdpServer(localUdp, remoteTcp, key) ) { diff --git a/tunnel.cpp b/tunnel.cpp index a003202..28f7d5b 100644 --- a/tunnel.cpp +++ 
b/tunnel.cpp @@ -116,6 +116,9 @@ bool Tunnel::recvUdpPacket() { if (rawSize <= 0) return false; + IpMap::const_iterator ipi = ipMap.find(id.address.ipUInt); + if (ipi != ipMap.end()) id.address.ipUInt = ipi->second; + unsigned char data[CRYPT_PACKET_SIZE] = {}; int dataSize = CRYPT_PACKET_SIZE; if (!crypt.decrypt(raw, rawSize, data, dataSize)) { @@ -262,8 +265,11 @@ bool Tunnel::iteration() { int sockId = Socket::tcpAccept(tcpSockId, address); if (sockId >= 0) { ConnId id, floatId; - id.address = remoteUdpAddress; floatId.id = id.id = crypt.random(); + id.address = remoteUdpAddress; + IpMap::const_iterator ipi = ipMap.find(id.address.ipUInt); + if (ipi != ipMap.end()) id.address.ipUInt = ipi->second; + logf("Open incoming connection from TCP: "); address.print(); logf(", id: %u\n", id.id); diff --git a/tunnel.h b/tunnel.h index 8ef8c3a..ef5edd1 100644 --- a/tunnel.h +++ b/tunnel.h @@ -16,6 +16,7 @@ public: typedef std::map ConnMap; typedef std::map ClosedConnMap; typedef std::set ConnSet; + typedef std::map IpMap; public: Time time; @@ -29,6 +30,8 @@ public: Time pollDuration; bool floatConnections; + IpMap ipMap; + ConnMap connections; ClosedConnMap closedConnections; ConnSet termQueue;