diff --git a/.gitignore b/.gitignore index fc44ca9..a45b846 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /icetunnel2-debug /icetunnel2 /deps/ +/test/ diff --git a/build.sh b/build.sh index 58759c0..21f9429 100755 --- a/build.sh +++ b/build.sh @@ -4,7 +4,8 @@ set -e OPTS="-Ideps -Ldeps/cryptopp -lcryptopp" -if [ "$1" = "debug" ]; then +if [ "$1" == "debug" ]; then + echo debug c++ -Wall -g -O0 *.cpp $OPTS -o icetunnel2-debug else c++ -Wall -DNDEBUG -O3 *.cpp $OPTS -o icetunnel2 diff --git a/common.cpp b/common.cpp index 91cd471..d45d69f 100644 --- a/common.cpp +++ b/common.cpp @@ -20,3 +20,16 @@ Time monotonicTime() { Time globalTime() { return time(NULL)*1000000ull; } + + +const char* bitsToString(const void *data, int size) { + static char buffer[8*PACKET_SIZE + 1] = {}; + if (size > PACKET_SIZE) size = PACKET_SIZE; + char *c = buffer; + for(const unsigned char *p = (const unsigned char*)data, *e = p + size; p < e; ++p) + for(int i = 0; i < 8; ++i, ++c) + *c = (*p & (1 << i)) ? '#': '.'; + *c = 0; + return buffer; +} + diff --git a/common.h b/common.h index 206c19a..ae39f67 100644 --- a/common.h +++ b/common.h @@ -2,6 +2,24 @@ #define COMMON_H +#include + + +#define HideDebugMSG(...) do {} while(0) + +#ifdef NDEBUG +# define ShowDebugMSG(...) do {} while(0) +#else +# define ShowDebugMSG(...) do { \ + printf("DEBUG %16lld %s:%s:%d:", monotonicTime(), __FILE__, __func__, __LINE__); \ + printf(" " __VA_ARGS__); \ + printf("\n"); \ + fflush(stdout); \ + } while(0) +#endif + + + enum { PACKET_SIZE = 470, PACKETS_COUNT = 8*PACKET_SIZE, @@ -27,6 +45,7 @@ enum : unsigned short { typedef unsigned long long Time; Time monotonicTime(); Time globalTime(); +const char* bitsToString(const void *data, int size); inline bool cycleLequal(unsigned int a, unsigned int b) diff --git a/crypt.cpp b/crypt.cpp index 461377f..8d2cc6a 100644 --- a/crypt.cpp +++ b/crypt.cpp @@ -11,6 +11,11 @@ +//#define DebugMsg HideDebugMSG +#define DebugMsg ShowDebugMSG + + + class Crypt::Private { public: Crypt &crypt; @@ -37,6 +42,7 @@ public: tq[index] = timeQwant; for(unsigned char *e = iv[index], *c = e + CRYPT_BLOCK - 1; c >= e; --c) { *c = timeQwant & 0xff; timeQwant >>= 8; } + DebugMsg("%d, %llu", index, tq[index]); } inline bool hash256(const unsigned char *src, int srcSize, unsigned char *dst) { @@ -87,7 +93,7 @@ bool Crypt::setKey(const char *key) { static_assert(CRYPT_BLOCK > HASH_SIZE); static const char salt[] = "dnmnfhne78hj0fklf2891ve"; - const int count = CRYPT_BLOCK; + const int count = CRYPT_BLOCK/HASH_SIZE; try { priv->hasher.Restart(); priv->hasher.Update((const unsigned char*)salt, sizeof(salt)); @@ -115,8 +121,8 @@ void Crypt::setTime(Time time, Time qwant) { priv->setTimeQwant(tq, 0); priv->setTimeQwant(tq + 1, 1); } else { - priv->setTimeQwant(tq + 1, 1); - priv->setTimeQwant(tq, 0); + priv->setTimeQwant(tq + 1, 0); + priv->setTimeQwant(tq, 1); } } @@ -135,7 +141,7 @@ bool Crypt::encrypt(const void *src, int srcSize, void *dst, int &dstSize) { unsigned char data[CRYPT_PACKET_SIZE] = {}; memcpy(data + HASH_SIZE, src, srcSize); static_assert(HASH_SIZE == 256/8); - if ( !priv->hash256(data + HASH_SIZE, size - 8, data) + if ( !priv->hash256(data + HASH_SIZE, size - HASH_SIZE, data) || !priv->encrypt(data, (unsigned char*)dst, size) ) return false; @@ -163,7 +169,7 @@ bool Crypt::decrypt(const void *src, int srcSize, void *dst, int &dstSize) { for(int i = 0; i < 2; ++i) { static_assert(HASH_SIZE == 256/8); if ( !priv->decrypt((unsigned char*)src, data[i], size, i) - || !priv->hash256(data[i] + HASH_SIZE, size - 8, hash[i]) ) return false; + || !priv->hash256(data[i] + HASH_SIZE, size - HASH_SIZE, hash[i]) ) return false; valid[i] = 0 == memcmp(hash, data[i], HASH_SIZE); } diff --git a/socket.cpp b/socket.cpp index 6470f15..fe36c9d 100644 --- a/socket.cpp +++ b/socket.cpp @@ -1,6 +1,8 @@ #include #include +#include +#include #include #include @@ -16,7 +18,7 @@ -static void addressToInaddr(const Address &address, struct sockaddr_in addr) { +static void addressToInaddr(const Address &address, struct sockaddr_in &addr) { memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_addr.s_addr = address.ipUInt; @@ -59,7 +61,7 @@ bool Poll::wait(Time duration) { if (res > 0) { struct pollfd *pfd = pfdFirst; for(List::iterator i = list.begin(); i != list.end(); ++i, ++pfd) { - int e = pfd->revents; + unsigned int e = pfd->revents; i->canRead = (bool)(e & POLLIN); i->canWrite = (bool)(e & POLLOUT); i->closed = (bool)(e & (POLLHUP | POLLRDHUP | POLLERR | POLLNVAL)); @@ -99,7 +101,8 @@ int Socket::tcpListen(const Address &address) { } -int Socket::tcpAccept(int sockId) { +int Socket::tcpAccept(int sockId, Address &address) { + address = Address(); struct sockaddr_in addr = {}; addr.sin_family = AF_INET; socklen_t len = sizeof(addr); @@ -107,6 +110,8 @@ int Socket::tcpAccept(int sockId) { if (newSockId < 0) return -1; if (!configureTcpSocket(newSockId)) { ::close(newSockId); return -1; } + address.ipUInt = addr.sin_addr.s_addr; + address.port = ntohs(addr.sin_port); return newSockId; } diff --git a/socket.h b/socket.h index 538e4d6..2ef0920 100644 --- a/socket.h +++ b/socket.h @@ -46,7 +46,7 @@ private: public: static int tcpConnect(const Address &address); static int tcpListen(const Address &address); - static int tcpAccept(int sockId); + static int tcpAccept(int sockId, Address &address); static int tcpSendAsync(int sockId, const void *data, int size); static int tcpRecvAsync(int sockId, void *data, int size); diff --git a/tcpqueue.cpp b/tcpqueue.cpp index 852776d..3a6a85b 100644 --- a/tcpqueue.cpp +++ b/tcpqueue.cpp @@ -8,6 +8,11 @@ +//#define DebugMsg HideDebugMSG +#define DebugMsg ShowDebugMSG + + + TcpQueue::TcpQueue(Connection &connection): connection(connection), buffer(), begin(), end() { } @@ -58,23 +63,26 @@ int TcpQueue::readFromTcp() { return 0; int doneSize = 0; - int parts[2]; - parts[0] = ALLOCATED_SIZE - end; - parts[1] = parts[0] - size; - if (parts[1] < 0) - { parts[0] = size; parts[1] = 0; } + int parts[2] = { size, 0 }; + int p0 = ALLOCATED_SIZE - end; + if (size > p0) + { parts[0] = p0; parts[1] = size - p0; } for(int i = 0; i < 2; ++i) { while(parts[i]) { int res = Socket::tcpRecvAsync(connection.tcpSockId, buffer + end, parts[i]); - if (res <= 0) + if (res <= 0) { + DebugMsg("%u, %d", connection.id.id, doneSize); return doneSize; + } doneSize += res; end += res; parts[i] -= res; } if (end >= ALLOCATED_SIZE) end = 0; } + + DebugMsg("%u, %d", connection.id.id, doneSize); return doneSize; } @@ -87,22 +95,26 @@ int TcpQueue::writeToTcp() { return 0; int doneSize = 0; - int parts[2]; - parts[0] = ALLOCATED_SIZE - begin; - parts[1] = parts[0] - size; - if (parts[1] < 0) - { parts[0] = size; parts[1] = 0; } + int parts[2] = { size, 0 }; + int p0 = ALLOCATED_SIZE - begin; + if (size > p0) + { parts[0] = p0; parts[1] = size - p0; } for(int i = 0; i < 2; ++i) { while(parts[i]) { int res = Socket::tcpSendAsync(connection.tcpSockId, buffer + begin, parts[i]); - if (res <= 0) + if (res <= 0) { + DebugMsg("%u, %d", connection.id.id, doneSize); return doneSize; + } doneSize += res; begin += res; parts[i] -= res; } if (begin >= ALLOCATED_SIZE) begin = 0; } + + DebugMsg("%u, %d", connection.id.id, doneSize); return doneSize; } + diff --git a/tunnel.cpp b/tunnel.cpp index 7647eda..38b6093 100644 --- a/tunnel.cpp +++ b/tunnel.cpp @@ -6,10 +6,15 @@ +//#define DebugMsg HideDebugMSG +#define DebugMsg ShowDebugMSG + + + Tunnel::Tunnel(): time(monotonicTime()), udpSendDuration(1000), - udpPartialSendDuration(500000), + udpPartialSendDuration(10*udpSendDuration), udpResendDuration(udpSendDuration * PACKETS_COUNT), udpConfirmDuration(udpResendDuration/8), udpSilenceDuration(300000000), @@ -34,13 +39,13 @@ bool Tunnel::initUdpServer(const Address &address, const Address &remoteTcpAddre printf("\n"); if (!crypt.setKey(key)) - { printf("bad key, cancelled"); return false; } + { printf("bad key, cancelled\n"); return false; } udpSockId = Socket::udpBind(address); if (udpSockId < 0) - { crypt.resetKey(); printf("cannot bind, cancelled"); return false; } + { crypt.resetKey(); printf("cannot bind, cancelled\n"); return false; } - printf("success"); + printf("success\n"); this->remoteTcpAddress = remoteTcpAddress; return true; } @@ -55,9 +60,9 @@ bool Tunnel::initTcpServer(const Address &address, const Address &remoteUdpAddre tcpSockId = Socket::tcpListen(address); if (tcpSockId < 0) - { printf("cannot listen, cancelled"); return false; } + { printf("cannot listen, cancelled\n"); return false; } - printf("success"); + printf("success\n"); this->remoteUdpAddress = remoteUdpAddress; return true; } @@ -67,7 +72,7 @@ void Tunnel::closeUdpServer() { if (udpSockId < 0) return; closeTcpServer(); - printf("Close UDP server"); + printf("Close UDP server\n"); for(ConnMap::iterator i = connections.begin(); i != connections.end(); ++i) Socket::close(i->second.tcpSockId); connections.clear(); @@ -81,7 +86,7 @@ void Tunnel::closeUdpServer() { void Tunnel::closeTcpServer() { if (tcpSockId < 0) return; - printf("Close TCP server"); + printf("Close TCP server\n"); Socket::close(tcpSockId); tcpSockId = -1; remoteUdpAddress = Address(); @@ -101,19 +106,25 @@ bool Tunnel::recvUdpPacket() { unsigned char data[CRYPT_PACKET_SIZE] = {}; int dataSize = CRYPT_PACKET_SIZE; - if (!crypt.decrypt(raw, rawSize, data, dataSize)) + if (!crypt.decrypt(raw, rawSize, data, dataSize)) { + DebugMsg("bad packet received%s", (id.address.print(), "")); return true; + } void *ptr = data; Packet &packet = *(Packet*)ptr; int cmd = packet.getCommand(); - if (cmd < 0 || cmd > CMDMAX) + if (cmd < 0 || cmd > CMDMAX) { + DebugMsg("bad command %d", cmd); return true; - + } + int size = packet.getSize(); - if (size + HEADER_SIZE > dataSize) + if (size + HEADER_SIZE > dataSize) { + DebugMsg("bad size %d", size); return true; + } id.id = packet.connId; ConnMap::iterator i = connections.find(id); @@ -130,12 +141,14 @@ bool Tunnel::recvUdpPacket() { } if (i == connections.end()) { - termQueue.insert(id); + DebugMsg("unknown connection: %u", id.id); + if (cmd != CMD_TERM) termQueue.insert(id); return true; } Connection &conn = i->second; if (!conn.udpActive) { + DebugMsg("inactive connection: %u", id.id); if (cmd != CMD_TERM) termQueue.insert(id); return true; } @@ -149,6 +162,7 @@ bool Tunnel::recvUdpPacket() { conn.udpSendQueue.confirm(packet.index, packet.payload, size*8); } else if (cmd == CMD_TERM) { + DebugMsg("%u, inactivate udp by incoming term packet", conn.id.id); conn.udpActive = false; } @@ -178,6 +192,7 @@ bool Tunnel::iteration() { e.wantRead = true; if (tcpSockId >= 0) { + poll.list.emplace_back(); Poll::Entry &e = poll.list.back(); e.sockId = tcpSockId; e.wantRead = true; @@ -201,11 +216,15 @@ bool Tunnel::iteration() { } // wait for events - poll.wait(pollTime); + poll.wait(pollTime - time); time = monotonicTime(); - crypt.setTime(time, timeQwant); + crypt.setTime(globalTime(), timeQwant); if (poll.list[0].closed || (tcpSockId >= 0 && poll.list[1].closed)) { + if (poll.list[0].closed) + printf("UDP socket closed\n"); + if (tcpSockId >= 0 && poll.list[1].closed) + printf("TCP server socket closed\n"); poll.list.clear(); close(); return false; @@ -217,12 +236,16 @@ bool Tunnel::iteration() { if (tcpSockId >= 0 && poll.list[1].canRead) { // accept new tcp connection - int sockId = Socket::tcpAccept(tcpSockId); + Address address; + int sockId = Socket::tcpAccept(tcpSockId, address); if (sockId >= 0) { ConnId id; id.address = remoteUdpAddress; id.id = crypt.random(); - Connection &conn = connections.emplace(id, Connection::Args(*this, id, tcpSockId)).first->second; + printf("Open incoming connection from TCP: "); + address.print(); + printf(", id: %u\n", id.id); + Connection &conn = connections.emplace(id, Connection::Args(*this, id, sockId)).first->second; conn.tcpRecvQueue.readFromTcp(); } } @@ -240,6 +263,7 @@ bool Tunnel::iteration() { if (i->canRead && conn.udpActive) conn.tcpRecvQueue.readFromTcp(); if (i->closed || (!conn.udpActive && !conn.tcpSendQueue.busySize())) { + DebugMsg("%u, inactivate tcp", conn.id.id); Socket::close(conn.tcpSockId); conn.tcpSockId = -1; } @@ -254,51 +278,67 @@ bool Tunnel::iteration() { if (conn.udpRecvQueue.sendUdpConfirm()) sent = true; if (sent) conn.lastUdpOutputTime = time; if (timeLequal(conn.lastUdpOutputTime + udpKeepAliveDuration, time)) { + DebugMsg("%u, udp keep alive", conn.id.id); conn.udpRecvQueue.sendUdpConfirm(true); conn.lastUdpOutputTime = time; } } } - // close connections + // remove closed connections + for(ClosedConnMap::iterator i = closedConnections.begin(); i != closedConnections.end();) + if (timeLess(i->second, time)) + closedConnections.erase(i++); else ++i; + + // close connections for(ConnMap::iterator i = connections.begin(); i != connections.end();) { ConnMap::iterator j = i++; Connection &conn = j->second; - if (conn.udpActive && timeLess(conn.lastUdpInputTime + udpSilenceDuration, time)) + if (conn.udpActive && timeLess(conn.lastUdpInputTime + udpSilenceDuration, time)) { + DebugMsg("%u, inactivate udp by silence timeout", conn.id.id); conn.udpActive = false; + termQueue.insert(conn.id); + } if (!conn.udpActive && !conn.wantWriteToTcp()) { + DebugMsg("%u, inactivate tcp", conn.id.id); Socket::close(conn.tcpSockId); conn.tcpSockId = -1; + printf("Close connection, id: %u\n", conn.id.id); closedConnections[conn.id] = time + 2*udpSilenceDuration; - connections.erase(i); + connections.erase(j); continue; } if (conn.tcpSockId < 0 && !conn.udpSendQueue.canSentToUdp()) { + DebugMsg("%u, inactivate udp, all data sent", conn.id.id); conn.udpActive = false; + termQueue.insert(conn.id); + printf("Close connection, id: %u\n", conn.id.id); closedConnections[conn.id] = time + 2*udpSilenceDuration; - connections.erase(i); + connections.erase(j); continue; } if (!conn.udpActive && conn.tcpSockId < 0) { + printf("Close connection, id: %u\n", conn.id.id); closedConnections[conn.id] = time + 2*udpSilenceDuration; - connections.erase(i); + connections.erase(j); continue; } } // send udp term packets - if (timeLequal(nextTermSendTime, time)) { + if (timeLequal(nextTermSendTime, time) && !termQueue.empty()) { for(ConnSet::const_iterator i = termQueue.begin(); i != termQueue.end(); ++i) { + DebugMsg("%u, send term packet", i->id); Packet packet; packet.init(i->id, 0, CMD_TERM, 0); sendUdpPacket(i->address, packet); } termQueue.clear(); - nextTermSendTime = time + nextTermSendTime; + nextTermSendTime = time + udpConfirmDuration; } poll.list.clear(); diff --git a/udprecvqueue.cpp b/udprecvqueue.cpp index d4050f7..b312c32 100644 --- a/udprecvqueue.cpp +++ b/udprecvqueue.cpp @@ -7,6 +7,11 @@ +//#define DebugMsg HideDebugMSG +#define DebugMsg ShowDebugMSG + + + UdpRecvQueue::UdpRecvQueue(Connection &connection): connection(connection), nextConfirmSendTime(connection.tunnel.time + connection.tunnel.udpConfirmDuration), @@ -35,6 +40,7 @@ bool UdpRecvQueue::recvUdp(unsigned int index, const void *data, int size) { if (e.received) return false; + DebugMsg("%u, %u, %d", connection.id.id, index, size); e.received = true; e.size = size; if (e.size) { @@ -62,6 +68,7 @@ bool UdpRecvQueue::sendUdpConfirm(bool force) { for(int i = 0; i < (int)count; ++i) if (entries[(current + i)%PACKETS_COUNT].received) packet.payload[i/8] |= (1 << (i%8)); + DebugMsg("%u, %u, %s", connection.id.id, currentIndex, bitsToString(packet.payload, bytesCount)); connection.tunnel.sendUdpPacket(connection.id.address, packet); confirmRequired = false; @@ -82,6 +89,7 @@ int UdpRecvQueue::sendTcp() { break; if (!e.size) { + DebugMsg("%d, inactivate udp, all packets received", connection.id.id); connection.udpActive = false; connection.tunnel.termQueue.insert(connection.id); } else @@ -92,6 +100,7 @@ int UdpRecvQueue::sendTcp() { ++sent; memset(&e, 0, sizeof(e)); current = (current + 1)%PACKETS_COUNT; + ++currentIndex; } if (sent) confirmRequired = true; return sent; diff --git a/udpsendqueue.cpp b/udpsendqueue.cpp index 84c88b0..b272c01 100644 --- a/udpsendqueue.cpp +++ b/udpsendqueue.cpp @@ -6,6 +6,10 @@ #include "udpsendqueue.h" +//#define DebugMsg HideDebugMSG +#define DebugMsg ShowDebugMSG + + UdpSendQueue::UdpSendQueue(Connection &connection): connection(connection), @@ -35,7 +39,7 @@ UdpSendQueue::Entry* UdpSendQueue::allocEntry() { e->prev = nullptr; e->next = busyFirst; (e->next ? e->next->prev : busyLast) = e; - busyLast = e; + busyFirst = e; return e; } @@ -60,22 +64,21 @@ void UdpSendQueue::freeEntry(Entry *e) { void UdpSendQueue::confirm(unsigned int index, const unsigned char *bits, unsigned int bitsCount) { - while(busyFirst && cycleLess(busyFirst->index, index)) - freeEntry(busyFirst); + DebugMsg("%u, %u, %s", connection.id.id, index, bitsToString(bits, bitsCount ? (bitsCount-1)/8 + 1 : 0)); Entry *e = busyFirst; while(e) { - unsigned int bit = e->index - index; - if (bit > bitsCount) - break; - unsigned int byte = bit / 8; - bit %= 8; - if (bits[byte] & (1 << bit)) { - Entry *ee = e; - e = e->next; - freeEntry(ee); - } else { - e = e->next; + if (e->index >= index) { + unsigned int bit = e->index - index; + unsigned int byte = bit / 8; + bit %= 8; + if ( !(bits[byte] & (1 << bit)) ) + { e = e->next; continue; } } + + DebugMsg("%u, %u - confirmed", connection.id.id, e->index); + Entry *ee = e; + e = e->next; + freeEntry(ee); } } @@ -86,7 +89,7 @@ bool UdpSendQueue::sendUdpSingle() { return false; Entry *e = busyFirst; - if (!timeLequal(e->resendTime, time)) + if (e && !timeLequal(e->resendTime, time)) e = nullptr; if (!e && !freeFirst) @@ -107,7 +110,7 @@ bool UdpSendQueue::sendUdpSingle() { return false; } - if (( e = allocEntry() )) + if ( !(e = allocEntry()) ) return false; e->index = nextIndex++; e->size = size; @@ -117,10 +120,11 @@ bool UdpSendQueue::sendUdpSingle() { Packet packet; packet.init(connection.id.id, e->index, CMD_DATA, e->size); if (e->size) memcpy(packet.payload, e->data, e->size); + DebugMsg("%u, %u, %hu", connection.id.id, e->index, e->size); connection.tunnel.sendUdpPacket(connection.id.address, packet); e->resendTime = time + connection.tunnel.udpResendDuration; - nextSendTime += connection.tunnel.udpPartialSendDuration; + nextSendTime += connection.tunnel.udpSendDuration; nextPartialSendTime = time + connection.tunnel.udpPartialSendDuration; moveToEnd(e); return true; @@ -154,7 +158,7 @@ void UdpSendQueue::whenWriteToUdp(Time &t) const { Time tt = t; if (size > 0 && timeLess(nextPartialSendTime, tt)) tt = nextPartialSendTime; - if (timeLess(busyFirst->resendTime, tt)) tt = busyFirst->resendTime; + if (busyFirst && timeLess(busyFirst->resendTime, tt)) tt = busyFirst->resendTime; if (timeLess(tt, nextSendTime)) tt = nextSendTime; - if (timeLess(t, tt)) t = tt; + if (timeLess(tt, t)) t = tt; }