diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..405a8d3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.kdev4
+ssbc
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..12e1c50
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+set -e
+
+c++ -Wall -g -pthread *.cpp -o ssbc
+
diff --git a/error.cpp b/error.cpp
new file mode 100644
index 0000000..bfce1b4
--- /dev/null
+++ b/error.cpp
@@ -0,0 +1,3 @@
+
+#include "error.h"
+
diff --git a/error.h b/error.h
new file mode 100644
index 0000000..64423ec
--- /dev/null
+++ b/error.h
@@ -0,0 +1,16 @@
+#ifndef ERROR_H
+#define ERROR_H
+
+
+typedef unsigned int ErrorCode;
+
+
+enum : ErrorCode {
+    ERR_NONE = 0,
+    ERR_COMMON = 1000,
+    ERR_TCP = 2000,
+};
+
+
+
+#endif
diff --git a/main.cpp b/main.cpp
new file mode 100644
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,81 @@
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+#include "tcp.h"
+#include "main.h"
+
+
+// Prints "<prefix>a.b.c.d:port". The octets and the port are unsigned,
+// hence %hhu/%hu.
+static void printAddress(const char *prefix, const Tcp::Address &address) {
+    printf( "%s%hhu.%hhu.%hhu.%hhu:%hu\n", prefix,
+        address.ip[0], address.ip[1], address.ip[2], address.ip[3], address.port );
+}
+
+
+// Manual test driver: "tcpconnect" runs a client and "tcpserver" a server on
+// 127.0.0.1:1562; both sides exchange fixed-size 64-byte text messages.
+int main(int argc, char** argv) {
+    if (argc != 2)
+        { printf("one arg expected\n"); return 1; }
+
+    Tcp::Address address(127, 0, 0, 1, 1562);
+
+    if (0 == strcmp(argv[1], "tcpconnect")) {
+        printf("tcpconnect\n");
+
+        Tcp::Connection connection;
+        if (connection.open(address))
+            printAddress("connected to server: ", address);
+        char buf[65] = {};
+        while(connection.check()) {
+            // Zero the whole buffer so bytes from the previous read are not
+            // sent again after the terminating NUL of the new message.
+            memset(buf, 0, sizeof(buf));
+            snprintf(buf, sizeof(buf), "client say %d\n", rand());
+            if (connection.write(buf, sizeof(buf) - 1))
+                printf("sent: %s\n", buf);
+            if (connection.read(buf, sizeof(buf) - 1))
+                printf("received: %s\n", buf);
+        }
+
+        printf("last error: %u\n", connection.getLastError());
+        return connection.getLastError();
+    }
+
+    if (0 == strcmp(argv[1], "tcpserver")) {
+        printf("tcpserver\n");
+
+        // Runs on its own server thread for every accepted connection.
+        class Handler: public Tcp::Handler {
+        public:
+            void handleTcpConnection(Tcp::Connection &connection) override {
+                printAddress("received connection from: ", connection.getRemoteAddr());
+                while(connection.check()) {
+                    char buf[65] = {};
+                    if (connection.read(buf, sizeof(buf) - 1))
+                        printf("received: %s\n", buf);
+                    memset(buf, 0, sizeof(buf));
+                    snprintf(buf, sizeof(buf), "server say %d\n", rand());
+                    if (connection.write(buf, sizeof(buf) - 1))
+                        printf("sent: %s\n", buf);
+                }
+            }
+        } handler;
+
+        Tcp::Server server;
+        if (server.start(address, &handler))
+            printAddress("listening at: ", address);
+        server.join();
+
+        printf("last error: %u\n", server.getLastError());
+        return server.getLastError();
+    }
+
+    printf("wrong args\n");
+    return 1;
+}
+
+
diff --git a/main.h b/main.h
new file mode 100644
index 0000000..92bf9a0
--- /dev/null
+++ b/main.h
@@ -0,0 +1,5 @@
+#ifndef MAIN_H
+#define MAIN_H
+
+
+#endif
diff --git a/socket.h b/socket.h
new file mode 100644
index 0000000..a508e8c
--- /dev/null
+++ b/socket.h
@@ -0,0 +1,8 @@
+#ifndef SOCKET_H
+#define SOCKET_H
+
+
+
+
+
+#endif
diff --git a/tcp.cpp b/tcp.cpp
new file mode 100644
--- /dev/null
+++ b/tcp.cpp
@@ -0,0 +1,328 @@
+
+#include <cerrno>
+#include <climits>
+#include <cstring>
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tcp.h"
+
+
+using namespace Tcp;
+
+
+// Upper bound on a single poll() wait, i.e. how quickly a pending stop()
+// is noticed by read(), write() and the accept loop.
+static const int pollTimeoutMs = 100;
+
+// Descriptor value meaning "no socket". 0 is a valid descriptor (stdin),
+// so it must not double as the empty marker.
+static const SockId invalidSockId = -1;
+
+// Flags for send(): MSG_NOSIGNAL (where available) turns a write to a
+// connection reset by the peer into an EPIPE error instead of a SIGPIPE,
+// whose default action terminates the whole process.
+#ifdef MSG_NOSIGNAL
+static const int sendFlags = MSG_DONTWAIT | MSG_NOSIGNAL;
+#else
+static const int sendFlags = MSG_DONTWAIT;
+#endif
+
+// True for errno values after which a non-blocking socket call should
+// simply be retried rather than treated as a failure.
+static bool isTransientError(int err)
+    { return err == EAGAIN || err == EWOULDBLOCK || err == EINTR; }
+
+
+//ErrorCode Address::resolve(const char *addrString);
+
+
+Connection::Connection():
+    status(STATUS_NONE),
+    lastError(ERR_NONE),
+    sockId(invalidSockId),
+    stopping(false)
+    { }
+
+Connection::Connection(SockId sockId, const Address &remoteAddr):
+    status(STATUS_OPEN),
+    lastError(ERR_NONE),
+    sockId(sockId),
+    remoteAddr(remoteAddr),
+    stopping(false)
+    { }
+
+Connection::~Connection()
+    { close(); }
+
+// Connects to remoteAddr with a blocking connect(), then switches the
+// socket to non-blocking mode. Any previously open socket is closed first.
+// On failure returns false with lastError set to ERR_TCP_BAD_ADDRESS or
+// ERR_TCP_CONNECTION_FAILED.
+bool Connection::open(const Address &remoteAddr) {
+    close();
+
+    this->remoteAddr = remoteAddr;
+    if (!this->remoteAddr.isValid())
+        { close(ERR_TCP_BAD_ADDRESS); return false; }
+
+    sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    if (sockId < 0)
+        { close(ERR_TCP_CONNECTION_FAILED); return false; }
+    status = STATUS_OPEN;
+
+    struct sockaddr_in addr = {};
+    addr.sin_family = AF_INET;
+    memcpy(&addr.sin_addr, remoteAddr.ip, sizeof(remoteAddr.ip));
+    addr.sin_port = htons(remoteAddr.port);
+
+    if (0 != connect(sockId, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)))
+        { close(ERR_TCP_CONNECTION_FAILED); return false; }
+
+    fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
+    return true;
+}
+
+// Closes the socket if one is open. A non-zero errorCode marks the
+// connection as lost and is kept as lastError; ERR_NONE returns it to
+// STATUS_NONE and clears lastError.
+void Connection::close(ErrorCode errorCode) {
+    if (status == STATUS_OPEN && sockId != invalidSockId)
+        ::close(sockId);
+    sockId = invalidSockId;
+    status = errorCode ? STATUS_LOST : STATUS_NONE;
+    lastError = errorCode;
+    stopping = false;
+}
+
+// Receives exactly `size` bytes into `data`, polling in pollTimeoutMs slices
+// so a stop() is honoured while waiting. Returns true only if every byte
+// arrived; on a socket error or an orderly shutdown by the peer the
+// connection is closed with ERR_TCP_CONNECTION_LOST. A null `data` only
+// succeeds for size == 0 on an open connection.
+bool Connection::read(void *data, size_t size) {
+    if (!data)
+        return check() && size == 0;
+
+    char *dst = static_cast<char*>(data);
+    struct pollfd pfd = {};
+    pfd.fd = sockId;
+    pfd.events = POLLIN;
+
+    while(check() && size) {
+        int pr = poll(&pfd, 1, pollTimeoutMs);
+        if (pr < 0) {
+            if (errno == EINTR)
+                continue;
+            close(ERR_TCP_CONNECTION_LOST);
+            return false;
+        }
+        if (pr == 0)
+            continue;
+
+        size_t chunk = size > INT_MAX ? INT_MAX : size;
+        ssize_t r = ::recv(sockId, dst, chunk, MSG_DONTWAIT);
+        if (r < 0) {
+            if (!isTransientError(errno))
+                { close(ERR_TCP_CONNECTION_LOST); return false; }
+        } else if (r == 0) {
+            // Orderly shutdown by the peer. The socket stays readable, so
+            // without this check the loop would spin forever.
+            close(ERR_TCP_CONNECTION_LOST);
+            return false;
+        } else {
+            size -= static_cast<size_t>(r);
+            dst += r;
+        }
+    }
+
+    return size == 0;
+}
+
+// Sends exactly `size` bytes from `data`, with the same polling, stop and
+// error semantics as read().
+bool Connection::write(const void *data, size_t size) {
+    if (!data)
+        return check() && size == 0;
+
+    const char *src = static_cast<const char*>(data);
+    struct pollfd pfd = {};
+    pfd.fd = sockId;
+    pfd.events = POLLOUT;
+
+    while(check() && size) {
+        int pr = poll(&pfd, 1, pollTimeoutMs);
+        if (pr < 0) {
+            if (errno == EINTR)
+                continue;
+            close(ERR_TCP_CONNECTION_LOST);
+            return false;
+        }
+        if (pr == 0)
+            continue;
+
+        size_t chunk = size > INT_MAX ? INT_MAX : size;
+        ssize_t r = ::send(sockId, src, chunk, sendFlags);
+        if (r < 0) {
+            if (!isTransientError(errno))
+                { close(ERR_TCP_CONNECTION_LOST); return false; }
+        } else {
+            size -= static_cast<size_t>(r);
+            src += r;
+        }
+    }
+
+    return size == 0;
+}
+
+
+Handler::~Handler() { }
+
+
+Server::Server():
+    status(STATUS_NONE),
+    lastError(ERR_NONE),
+    sockId(invalidSockId),
+    localAddr(),
+    handler(),
+    thread(),
+    stopping(false)
+    { }
+
+Server::~Server()
+    { stop(); }
+
+
+// Accept loop, run on the thread created by start(). Accepts connections
+// until stop() is requested or the listening socket fails, running
+// handler->handleTcpConnection() for each one on a dedicated thread. On
+// exit it stops and joins every handler thread, frees the connections and
+// closes the listening socket.
+// NOTE(review): handler threads are only reaped here, so a long-running
+// server keeps one thread object and Connection per past connection.
+void Server::listen() {
+    struct pollfd pfd = {};
+    pfd.fd = sockId;
+    pfd.events = POLLIN;
+
+    while(!stopping) {
+        int pr = poll(&pfd, 1, pollTimeoutMs);
+        if (pr < 0) {
+            if (errno == EINTR)
+                continue;
+            lastError = ERR_TCP_LISTEN_LOST;
+            break;
+        }
+        // pr counts ready descriptors; it is not an event mask. Whatever
+        // made the socket ready (a pending client or an error) is
+        // resolved by accept() below.
+        if (pr == 0)
+            continue;
+
+        struct sockaddr_in addr = {};
+        socklen_t addrlen = sizeof(addr);
+        int sid = ::accept(sockId, reinterpret_cast<sockaddr*>(&addr), &addrlen);
+        if (sid < 0) {
+            // ECONNABORTED: the client went away before it was accepted.
+            if (isTransientError(errno) || errno == ECONNABORTED)
+                continue;
+            lastError = ERR_TCP_LISTEN_LOST;
+            break;
+        }
+
+        Address remoteAddr;
+        memcpy(remoteAddr.ip, &addr.sin_addr, sizeof(remoteAddr.ip));
+        remoteAddr.port = ntohs(addr.sin_port);
+
+        ConnDesc desc = {};
+        desc.connection = new Connection(sid, remoteAddr);
+        desc.thread = new std::thread(
+            &Handler::handleTcpConnection, handler, std::ref(*desc.connection));
+        connections.push_back(desc);
+    }
+    status = lastError ? STATUS_LOST : STATUS_CLOSED;
+
+    for (ConnDesc &desc : connections)
+        desc.connection->stop();
+    for (ConnDesc &desc : connections) {
+        desc.thread->join();
+        delete desc.thread;
+        delete desc.connection;
+    }
+    // The descriptors now dangle; drop them so that a later start()/stop()
+    // cycle does not join or delete them a second time.
+    connections.clear();
+
+    ::close(sockId);
+    sockId = invalidSockId;
+    stopping = false;
+}
+
+// Binds a listening socket to localAddr and runs listen() on a background
+// thread that hands every accepted connection to handler. A previous run is
+// stopped first. On failure returns false with lastError set to
+// ERR_TCP_BAD_ADDRESS, ERR_TCP_BAD_HANDLER or ERR_TCP_LISTEN_FAILED.
+bool Server::start(const Address &localAddr, Handler *handler) {
+    stop();
+
+    this->localAddr = localAddr;
+    this->handler = handler;
+    lastError = ERR_NONE;  // don't let an earlier failure taint this run
+
+    if (!this->localAddr.isValid())
+        { status = STATUS_LOST; lastError = ERR_TCP_BAD_ADDRESS; return false; }
+    if (!this->handler)
+        { status = STATUS_LOST; lastError = ERR_TCP_BAD_HANDLER; return false; }
+
+    sockId = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+
+    struct sockaddr_in addr = {};
+    addr.sin_family = AF_INET;
+    memcpy(&addr.sin_addr, localAddr.ip, sizeof(localAddr.ip));
+    addr.sin_port = htons(localAddr.port);
+
+    if ( sockId < 0
+      || 0 != ::bind(sockId, reinterpret_cast<sockaddr*>(&addr), sizeof(addr))
+      || 0 != ::listen(sockId, 32) )
+    {
+        if (sockId >= 0)
+            ::close(sockId);  // do not leak the socket on bind/listen failure
+        sockId = invalidSockId;
+        status = STATUS_LOST;
+        lastError = ERR_TCP_LISTEN_FAILED;
+        this->handler = nullptr;
+        return false;
+    }
+
+    fcntl(sockId, F_SETFL, fcntl(sockId, F_GETFL, 0) | O_NONBLOCK);
+    // Mark the server open before the thread exists, so getStatus() and
+    // check() already report it when start() returns.
+    status = STATUS_OPEN;
+    thread = new std::thread(&Server::listen, this);
+    return true;
+}
+
+// Waits for the accept thread to finish, then forgets it and the handler.
+void Server::join() {
+    if (!thread)
+        return;
+    thread->join();
+    delete thread;
+    thread = nullptr;
+    handler = nullptr;
+}
+
+// Asks the accept thread to exit and waits for it; no-op when not running.
+void Server::stop() {
+    if (!thread)
+        return;
+    stopping = true;
+    join();
+}
+
diff --git a/tcp.h b/tcp.h
new file mode 100644
--- /dev/null
+++ b/tcp.h
@@ -0,0 +1,175 @@
+#ifndef TCP_H
+#define TCP_H
+
+
+#include <atomic>
+#include <cstddef>
+#include <thread>
+#include <vector>
+
+#include "error.h"
+
+
+enum : ErrorCode {
+    ERR_TCP_COMMON = ERR_TCP,
+    ERR_TCP_RESOLVE_FAILED,
+    ERR_TCP_BAD_ADDRESS,
+    ERR_TCP_BAD_HANDLER,
+    ERR_TCP_CONNECTION_FAILED,
+    ERR_TCP_CONNECTION_LOST,
+    ERR_TCP_LISTEN_FAILED,
+    ERR_TCP_LISTEN_LOST,
+};
+
+
+namespace Tcp {
+
+// Lifecycle state of a Connection or Server.
+enum Status {
+    STATUS_NONE,    // not open; also a Connection after a clean close()
+    STATUS_OPEN,
+    STATUS_CLOSED,  // Server stopped without error
+    STATUS_LOST,    // closed because of an error, see getLastError()
+};
+
+
+// POSIX socket descriptor; -1 when no socket is held.
+typedef int SockId;
+
+
+// IPv4 address: ip[0] is the first (most significant) octet, port is in
+// host byte order.
+class Address {
+public:
+    unsigned char ip[4];
+    unsigned short port;
+
+    inline Address():
+        ip(), port() { }
+    inline Address(unsigned char ip0, unsigned char ip1, unsigned char ip2, unsigned char ip3, unsigned short port):
+        ip{ip0, ip1, ip2, ip3}, port(port) { }
+    inline Address(const unsigned char *ip, unsigned short port):
+        Address(ip[0], ip[1], ip[2], ip[3], port) { }
+
+    // 0.0.0.0 and port 0 are rejected, so "any address" / "any port"
+    // cannot be used.
+    inline bool isValidIp() const
+        { return ip[0] || ip[1] || ip[2] || ip[3]; }
+    inline bool isValidPort() const
+        { return port != 0; }
+    inline bool isValid() const
+        { return isValidIp() && isValidPort(); }
+
+    // Declared only: the definition in tcp.cpp is still commented out.
+    ErrorCode resolve(const char *addrString);
+};
+
+
+// A single TCP stream socket; not copyable. read()/write() move exact byte
+// counts and give up once stop() has been requested from another thread.
+class Connection {
+private:
+    Status status;
+    ErrorCode lastError;
+    SockId sockId;
+    Address remoteAddr;
+    std::atomic<bool> stopping;  // set by stop(), possibly from another thread
+
+    Connection(const Connection&) = delete;
+    Connection& operator=(const Connection&) = delete;
+
+public:
+    Connection();
+    // Adopts an already connected socket (e.g. from accept()); the socket
+    // is closed by close() or the destructor.
+    Connection(SockId sockId, const Address &remoteAddr);
+    ~Connection();
+
+    inline Status getStatus() const
+        { return status; }
+    inline ErrorCode getLastError() const
+        { return lastError; }
+    inline SockId getSockId() const
+        { return sockId; }
+    inline const Address& getRemoteAddr() const
+        { return remoteAddr; }
+
+    // Requests a close; the thread using the connection performs it on its
+    // next check(), which read()/write() call after every poll timeout.
+    inline void stop()
+        { stopping = true; }
+    // Performs a pending stop() and reports whether the connection is open.
+    inline bool check() {
+        if (stopping) close(ERR_TCP_CONNECTION_LOST);
+        return status == STATUS_OPEN;
+    }
+
+    bool open(const Address &remoteAddr);
+    void close(ErrorCode errorCode = ERR_NONE);
+
+    bool read(void *data, size_t size);
+    bool write(const void *data, size_t size);
+};
+
+
+// Callback interface for Server. handleTcpConnection() runs on its own
+// thread for each accepted connection; the Server owns the Connection and
+// frees it only after the call has returned and the server has stopped.
+class Handler {
+public:
+    virtual ~Handler();
+    virtual void handleTcpConnection(Connection &connection) = 0;
+};
+
+
+// Listening socket whose accept loop runs on a background thread started
+// by start(). Not copyable; the destructor stops the server.
+class Server {
+public:
+    struct ConnDesc {
+        Connection *connection;
+        std::thread *thread;
+    };
+    typedef std::vector<ConnDesc> ConnList;
+
+private:
+    Status status;
+    ErrorCode lastError;
+    SockId sockId;
+    Address localAddr;
+    Handler *handler;
+    std::thread *thread;
+    ConnList connections;
+    std::atomic<bool> stopping;
+
+    Server(const Server&) = delete;
+    Server& operator=(const Server&) = delete;
+
+    void listen();
+
+public:
+    Server();
+    ~Server();
+
+    inline Status getStatus() const
+        { return status; }
+    inline ErrorCode getLastError() const
+        { return lastError; }
+    inline SockId getSockId() const
+        { return sockId; }
+    inline const Address& getLocalAddr() const
+        { return localAddr; }
+    inline Handler* getHandler() const
+        { return handler; }
+
+    inline bool check()
+        { return !stopping && status == STATUS_OPEN; }
+
+    bool start(const Address &localAddr, Handler *handler);
+    void join();
+    void stop();
+};
+
+} // Tcp
+
+#endif