#include <cassert>
#include <cstring>
#include <cryptopp/osrng.h>
#include <cryptopp/sha3.h>
#include <cryptopp/modes.h>
#include <cryptopp/threefish.h>
#include "crypt.h"
//#define DebugMsg HideDebugMSG
#define DebugMsg ShowDebugMSG
/// Private implementation state of Crypt: the derived cipher key, the two
/// time-quantum IVs and the Crypto++ primitives (RNG, SHA3-256 and
/// CBC/Threefish-1024) that use them.
class Crypt::Private {
public:
    Crypt &crypt;                      // back-reference to the owning Crypt (not used in the code shown)
    unsigned char key[CRYPT_BLOCK];    // cipher key, filled by Crypt::setKey()
    unsigned char iv[2][CRYPT_BLOCK];  // [0]: used by encrypt() and tried first by decrypt(); [1]: alternative IV for decrypt()
    Time tq[2];                        // time quantum each iv[] entry was last derived from
    CryptoPP::DefaultAutoSeededRNG random;
    CryptoPP::SHA3_256 hasher;
    CryptoPP::CBC_Mode<CryptoPP::Threefish1024>::Encryption enc;
    CryptoPP::CBC_Mode<CryptoPP::Threefish1024>::Decryption dec;

    inline Private(Crypt &crypt):
        crypt(crypt), key(), iv(), tq() { }

    inline ~Private() {
        // A plain memset() immediately before destruction is a dead store the
        // optimizer is allowed to drop; wipe through a volatile pointer instead.
        wipe(key, sizeof(key));
        wipe(iv, sizeof(iv));
        wipe(tq, sizeof(tq));
    }

    /// Re-derives iv[index] from timeQwant (no-op if it is unchanged).
    /// The low-order bytes of timeQwant are stored big-endian, right-aligned
    /// at the end of the CRYPT_BLOCK-byte IV.
    inline void setTimeQwant(Time timeQwant, int index) {
        if (tq[index] == timeQwant) return;
        tq[index] = timeQwant;
        unsigned char *v = iv[index];
        // Index loop: the previous pointer loop decremented one element past
        // the start of the array, which is undefined behaviour.
        for (int i = CRYPT_BLOCK - 1; i >= 0; --i) {
            v[i] = static_cast<unsigned char>(timeQwant & 0xff);
            timeQwant >>= 8;
        }
        // Cast so the argument always matches the %llu conversion.
        DebugMsg("%d, %llu", index, static_cast<unsigned long long>(tq[index]));
    }

    /// Writes the SHA3-256 digest of src[0..srcSize) to dst (HASH_SIZE bytes).
    /// @return false if Crypto++ threw.
    inline bool hash256(const unsigned char *src, int srcSize, unsigned char *dst) {
        try {
            hasher.Restart();
            hasher.CalculateDigest(dst, src, srcSize);
            return true;
        } catch(...) { }
        return false;
    }

    /// CBC-encrypts size bytes (a whole number of blocks) with key and iv[0].
    /// @return false if Crypto++ threw.
    inline bool encrypt(const unsigned char *src, unsigned char *dst, int size) {
        try {
            enc.SetKeyWithIV(key, CRYPT_BLOCK, iv[0]);
            enc.ProcessData(dst, src, size);
            return true;
        } catch(...) { }
        return false;
    }

    /// CBC-decrypts size bytes with key and iv[index].
    /// @return false if Crypto++ threw.
    inline bool decrypt(const unsigned char *src, unsigned char *dst, int size, int index) {
        try {
            dec.SetKeyWithIV(key, CRYPT_BLOCK, iv[index]);
            dec.ProcessData(dst, src, size);
            return true;
        } catch(...) { }
        return false;
    }

    /// Zeroes n bytes at p through a volatile pointer so the stores are not
    /// removed as dead by the optimizer.
    static void wipe(void *p, size_t n) {
        volatile unsigned char *b = static_cast<volatile unsigned char*>(p);
        while (n--) *b++ = 0;
    }
};
Crypt::Crypt():
priv(new Private(*this)) { }
/// Zeroes the key and releases the private state.
Crypt::~Crypt() {
    resetKey();
    Private *const doomed = priv;
    priv = nullptr;
    delete doomed;
}
unsigned int Crypt::random()
{ return priv->random.GenerateWord32(); }
bool Crypt::setKey(const char *key) {
static_assert(CRYPT_BLOCK % HASH_SIZE == 0);
static_assert(CRYPT_BLOCK > HASH_SIZE);
static const char salt[] = "dnmnfhne78hj0fklf2891ve";
const int count = CRYPT_BLOCK/HASH_SIZE;
try {
priv->hasher.Restart();
priv->hasher.Update((const unsigned char*)salt, sizeof(salt));
priv->hasher.Update((const unsigned char*)key, strlen(key));
priv->hasher.Final(priv->key);
for(int i = 1; i < count; ++i) {
priv->hasher.Restart();
priv->hasher.CalculateDigest(priv->key + i*HASH_SIZE, priv->key + (i - 1)*HASH_SIZE, HASH_SIZE);
}
return true;
} catch(...) { }
return false;
}
/// Overwrites every byte of the cipher key with zero.
void Crypt::resetKey() {
    for (unsigned char &byte : priv->key)
        byte = 0;
}
/// Selects the IVs for the time quantum of `time`.
///
/// iv[0] (used by encrypt() and tried first by decrypt()) comes from the
/// quantum containing `time` (exactly so for even qwant). iv[1] comes from
/// the neighbouring quantum nearer to `time`, so packets encrypted just
/// across a quantum boundary can still be decrypted.
/// @param time   current time, in the same units as qwant.
/// @param qwant  quantum length; must be non-zero (a zero value leaves the IVs unchanged).
/// NOTE(review): if Time is unsigned, time < qwant/2 wraps around below;
/// confirm callers never pass such values.
void Crypt::setTime(Time time, Time qwant) {
    assert(qwant != 0);
    if (qwant == 0) return;   // avoid division by zero in release builds
    const Time hq = qwant/2;
    time -= hq;
    const Time tq = time/qwant;
    if (time % qwant < hq) {
        // time is in the second half of quantum tq: neighbour is tq + 1.
        priv->setTimeQwant(tq, 0);
        priv->setTimeQwant(tq + 1, 1);
    } else {
        // time is in the first half of quantum tq + 1: neighbour is tq.
        priv->setTimeQwant(tq + 1, 0);
        priv->setTimeQwant(tq, 1);
    }
}
/// Encrypts srcSize bytes from src into dst.
///
/// Plaintext packet layout: SHA3-256 of the rest of the packet (HASH_SIZE
/// bytes), then the data, then zero padding up to a whole number of
/// CRYPT_BLOCK blocks. The packet is CBC/Threefish-1024 encrypted with the
/// current key and iv[0]. The original length is not stored.
/// @param dstSize  in: capacity of dst in bytes; out: bytes written
///                 (0 on failure or empty input).
/// @return true on success or empty input; false on bad arguments,
///         insufficient capacity, oversize input, or a Crypto++ failure.
bool Crypt::encrypt(const void *src, int srcSize, void *dst, int &dstSize) {
    const int capacity = dstSize;
    dstSize = 0;
    if (!srcSize) return true;
    if (!src || !dst || srcSize < 0) return false;
    // Reject oversize input before the size arithmetic below can overflow int.
    if (srcSize > CRYPT_PACKET_SIZE - HASH_SIZE) return false;
    // Hash prefix + data, rounded up to whole cipher blocks.
    const int count = (srcSize + HASH_SIZE - 1)/CRYPT_BLOCK + 1;
    const int size = count*CRYPT_BLOCK;
    if (capacity < size) return false;
    if (size > CRYPT_PACKET_SIZE) return false;
    unsigned char data[CRYPT_PACKET_SIZE] = {};  // zero-initialised: also provides the padding
    memcpy(data + HASH_SIZE, src, srcSize);
    static_assert(HASH_SIZE == 256/8);
    if ( !priv->hash256(data + HASH_SIZE, size - HASH_SIZE, data)
      || !priv->encrypt(data, static_cast<unsigned char*>(dst), size) )
        return false;
    dstSize = size;
    return true;
}
/// Decrypts a packet produced by Crypt::encrypt().
///
/// The packet is decrypted with iv[0] first. If the SHA3-256 stored in its
/// first HASH_SIZE bytes does not match the rest, it is retried with iv[1]
/// (the neighbouring time quantum).
/// @param src      ciphertext; srcSize must be a positive multiple of CRYPT_BLOCK.
/// @param dst      receives srcSize - HASH_SIZE bytes: the data followed by any
///                 zero padding encrypt() added (the original length is not stored).
/// @param dstSize  in: capacity of dst; out: bytes written (0 on failure or empty input).
/// @return true on success or empty input; false on bad arguments, a Crypto++
///         failure, or a hash mismatch under both IVs.
bool Crypt::decrypt(const void *src, int srcSize, void *dst, int &dstSize) {
    const int capacity = dstSize;
    dstSize = 0;
    if (!srcSize) return true;
    if (!src || !dst || srcSize < 0) return false;
    if (srcSize % CRYPT_BLOCK) return false;
    if (srcSize > CRYPT_PACKET_SIZE) return false;
    const int size = srcSize;              // already a whole number of blocks
    const int payload = size - HASH_SIZE;  // bytes after the hash prefix
    if (capacity < payload) return false;
    static_assert(HASH_SIZE == 256/8);
    unsigned char hash[HASH_SIZE] = {};
    unsigned char data[CRYPT_PACKET_SIZE] = {};
    for (int i = 0; i < 2; ++i) {
        if ( !priv->decrypt(static_cast<const unsigned char*>(src), data, size, i)
          || !priv->hash256(data + HASH_SIZE, payload, hash) ) return false;
        // Each attempt is checked against the hash computed for *that*
        // decryption; comparing against a stale hash would reject valid iv[1] packets.
        // NOTE(review): memcmp is not constant-time; consider
        // CryptoPP::VerifyBufsEqual if timing leaks matter here.
        if (0 == memcmp(hash, data, HASH_SIZE)) {
            memcpy(dst, data + HASH_SIZE, payload);
            dstSize = payload;
            return true;
        }
    }
    return false;
}