Blob Blame Raw

#include <cassert>
#include <cstddef>
#include <cstring>

#include <cryptopp/osrng.h>
#include <cryptopp/sha3.h>
#include <cryptopp/modes.h>
#include <cryptopp/threefish.h>


#include "crypt.h"



//#define DebugMsg HideDebugMSG
#define DebugMsg ShowDebugMSG



// Private implementation of Crypt. It owns the derived key, the two
// time-quantum IVs and the Crypto++ primitives (auto-seeded RNG, SHA3-256,
// Threefish-1024 in CBC mode).
// NOTE(review): the IV is fully determined by the time quantum, and the
// plaintext digest is the first encrypted block. So identical messages
// encrypted within the same quantum give identical ciphertexts (equality leaks).
// Changing that would change the wire format.
class Crypt::Private {
public:
	Crypt &crypt;                       // owning facade
	unsigned char key[CRYPT_BLOCK];     // derived cipher key, written by Crypt::setKey
	unsigned char iv[2][CRYPT_BLOCK];   // IV slots, each encoding tq[slot] (see setTimeQwant)
	Time tq[2];                         // time quantum currently encoded in each iv[] slot
	
	CryptoPP::DefaultAutoSeededRNG random;
	CryptoPP::SHA3_256 hasher;
	CryptoPP::CBC_Mode<CryptoPP::Threefish1024>::Encryption enc;
	CryptoPP::CBC_Mode<CryptoPP::Threefish1024>::Decryption dec;
	
	// key, iv and tq start value-initialised (all zero).
	explicit Private(Crypt &crypt):
		crypt(crypt), key(), iv(), tq() { }
	
	// Scrub key material before the storage is released.
	~Private() {
		secureWipe(key, sizeof(key));
		secureWipe(iv, sizeof(iv));
		secureWipe(tq, sizeof(tq));
	}
	
	// Zero `size` bytes at `buf` through a volatile pointer.
	// Why not memset: a memset issued right before the object dies is a dead
	// store, and the optimiser may legally delete it. The key would then
	// survive in freed memory.
	static void secureWipe(void *buf, std::size_t size) {
		volatile unsigned char *p = static_cast<volatile unsigned char*>(buf);
		while (size--) *p++ = 0;
	}
	
	// Make IV slot `index` encode `timeQwant` big-endian, right-aligned in the
	// CRYPT_BLOCK-byte buffer. Each byte further left takes the value shifted
	// right by 8 more bits.
	// The call is a no-op when the slot already encodes that quantum.
	// index: 0 is the slot encrypt() uses; 1 is the alternate slot decrypt() also tries.
	void setTimeQwant(Time timeQwant, int index) {
		if (tq[index] == timeQwant) return;
		tq[index] = timeQwant;
		// Index loop instead of a decrementing pointer. The old pointer loop
		// terminated by forming a pointer before the array start, which is UB.
		unsigned char *v = iv[index];
		for (int i = CRYPT_BLOCK - 1; i >= 0; --i) {
			v[i] = static_cast<unsigned char>(timeQwant & 0xff);
			timeQwant >>= 8;
		}
		// Cast so the argument always matches the %llu conversion.
		DebugMsg("%d, %llu", index, static_cast<unsigned long long>(tq[index]));
	}
	
	// dst <- SHA3-256(src[0 .. srcSize)). dst must hold the full digest
	// (HASH_SIZE bytes; callers static_assert HASH_SIZE == 256/8).
	// Returns false if Crypto++ throws, true otherwise.
	inline bool hash256(const unsigned char *src, int srcSize, unsigned char *dst) {
		try {
			hasher.Restart();
			hasher.CalculateDigest(dst, src, srcSize);
			return true;
		} catch(...) { }
		return false;
	}
	
	// CBC-encrypt `size` bytes from src into dst using `key` and IV slot 0.
	// The cipher is re-keyed on every call. Callers pass a multiple of
	// CRYPT_BLOCK. Returns false if Crypto++ throws.
	inline bool encrypt(const unsigned char *src, unsigned char *dst, int size) {
		try {
			enc.SetKeyWithIV(key, CRYPT_BLOCK, iv[0]);
			enc.ProcessData(dst, src, size);
			return true;
		} catch(...) { }
		return false;
	}
	
	// CBC-decrypt `size` bytes from src into dst using `key` and IV slot
	// `index` (0 or 1). Returns false if Crypto++ throws.
	inline bool decrypt(const unsigned char *src, unsigned char *dst, int size, int index) {
		try {
			dec.SetKeyWithIV(key, CRYPT_BLOCK, iv[index]);
			dec.ProcessData(dst, src, size);
			return true;
		} catch(...) { }
		return false;
	}
};




Crypt::Crypt():
	priv(new Private(*this)) { }


// Tear down: zero the key first, then free the private state.
Crypt::~Crypt() {
	resetKey();
	Private *owned = priv;
	priv = nullptr;
	delete owned;
}


// Return 32 random bits from the auto-seeded generator held in the private state.
unsigned int Crypt::random() {
	auto &rng = priv->random;
	return rng.GenerateWord32();
}


bool Crypt::setKey(const char *key) {
	static_assert(CRYPT_BLOCK % HASH_SIZE == 0);
	static_assert(CRYPT_BLOCK > HASH_SIZE);
	
	static const char salt[] = "dnmnfhne78hj0fklf2891ve";
	const int count = CRYPT_BLOCK/HASH_SIZE;
	try {
		priv->hasher.Restart();
		priv->hasher.Update((const unsigned char*)salt, sizeof(salt));
		priv->hasher.Update((const unsigned char*)key, strlen(key));
		priv->hasher.Final(priv->key);
		for(int i = 1; i < count; ++i) {
			priv->hasher.Restart();
			priv->hasher.CalculateDigest(priv->key + i*HASH_SIZE, priv->key + (i - 1)*HASH_SIZE, HASH_SIZE);
		}
		return true;
	} catch(...) { }
	return false;
}


// Overwrite the whole derived key buffer with zeros.
void Crypt::resetKey() {
	auto &k = priv->key;
	memset(k, 0, sizeof k);
}


// Pick the two IV slots from the current time. Quanta are `qwant` time units long.
//   - Slot 0 (used by encrypt) gets the quantum `time` falls in.
//   - Slot 1 gets the neighbouring quantum closer to `time`: the next one when
//     `time` is in the second half of its quantum, the previous one otherwise.
// decrypt tries both slots. Presumably this tolerates peer clock skew of up to
// about half a quantum.
// NOTE(review): for odd qwant, the last unit of each quantum is assigned to the
// next quantum. Peers running this same code still agree.
// NOTE(review): if Time is unsigned, time < qwant/2 wraps around.
// A non-positive qwant would divide by zero. It asserts in debug builds and
// is ignored (the IVs are left unchanged) otherwise.
void Crypt::setTime(Time time, Time qwant) {
	assert(qwant > 0);
	if (qwant <= 0) return;
	
	Time hq = qwant/2;
	time -= hq;               // shift so the split point lands mid-quantum
	Time tq = time/qwant;
	if (time % qwant < hq) {
		// second half of quantum tq: neighbour is tq + 1
		priv->setTimeQwant(tq,     0);
		priv->setTimeQwant(tq + 1, 1);
	} else {
		// first half of quantum tq + 1: neighbour is tq
		priv->setTimeQwant(tq + 1, 0);
		priv->setTimeQwant(tq,     1);
	}
}


// Encrypt `srcSize` bytes from src into a packet in dst.
// Packet plaintext layout:
//   SHA3-256(payload) || payload
//   payload = src zero-padded up to the next CRYPT_BLOCK boundary.
// The whole packet is encrypted with Threefish-1024-CBC using the key and IV slot 0.
// The original length is not stored; decrypt returns the padded payload.
// dstSize: capacity of dst on entry. On success it becomes the packet size
// (a multiple of CRYPT_BLOCK, at most CRYPT_PACKET_SIZE).
// Returns true on success, or for empty input (dstSize = 0).
// Returns false (dstSize = 0) in these cases:
//   - null src or dst;
//   - negative srcSize;
//   - the packet would not fit CRYPT_PACKET_SIZE or dst;
//   - Crypto++ throws.
bool Crypt::encrypt(const void *src, int srcSize, void *dst, int &dstSize) {
	int ds = dstSize;
	dstSize = 0;
	if (!srcSize) return true;
	if (!src || !dst) return false;
	// A negative size would reach memcpy as a huge size_t.
	if (srcSize < 0) return false;
	// Reject oversize input before the int arithmetic below can overflow.
	if (srcSize > CRYPT_PACKET_SIZE - HASH_SIZE) return false;
	
	// ceil((digest + payload) / CRYPT_BLOCK) whole cipher blocks
	const int count = (srcSize + HASH_SIZE + CRYPT_BLOCK - 1)/CRYPT_BLOCK;
	const int size = count*CRYPT_BLOCK;
	if (ds < size) return false;
	if (size > CRYPT_PACKET_SIZE) return false;
	
	// Zero-initialised, so the tail padding is deterministic and covered by the digest.
	unsigned char data[CRYPT_PACKET_SIZE] = {};
	memcpy(data + HASH_SIZE, src, srcSize);
	static_assert(HASH_SIZE == 256/8);
	if ( !priv->hash256(data + HASH_SIZE, size - HASH_SIZE, data)
	  || !priv->encrypt(data, static_cast<unsigned char*>(dst), size) )
		return false;
	
	dstSize = size;
	return true;
}


// Constant-time equality of two byte buffers. The running time does not depend
// on where the first mismatch is, so the digest check below does not reveal
// how many leading bytes matched.
static bool cryptBytesEqual(const unsigned char *a, const unsigned char *b, int size) {
	unsigned char diff = 0;
	for (int i = 0; i < size; ++i)
		diff = static_cast<unsigned char>(diff | (a[i] ^ b[i]));
	return diff == 0;
}

// Decrypt a packet produced by Crypt::encrypt.
// src/srcSize: the ciphertext. It must be a positive multiple of CRYPT_BLOCK
//   and at most CRYPT_PACKET_SIZE bytes.
// dst/dstSize: dstSize is the capacity of dst on entry, and must be at least
//   srcSize - HASH_SIZE. On success it becomes srcSize - HASH_SIZE. That count
//   includes encrypt's zero padding, because the real length is not transmitted.
// The packet is tried with IV slot 0, then slot 1 (see setTime). The first slot
// whose embedded SHA3-256 digest matches the decrypted payload is accepted.
// Returns true on success, or for empty input (dstSize = 0).
// Returns false (dstSize = 0, dst untouched) on bad arguments, when no slot
// verifies, or when Crypto++ throws.
bool Crypt::decrypt(const void *src, int srcSize, void *dst, int &dstSize) {
	int ds = dstSize;
	dstSize = 0;
	if (!srcSize) return true;
	if (!src || !dst) return false;
	// A negative multiple of CRYPT_BLOCK would otherwise pass the modulo check.
	if (srcSize < 0 || srcSize % CRYPT_BLOCK) return false;
	if (srcSize > CRYPT_PACKET_SIZE) return false;
	
	const int size = srcSize;                 // whole cipher blocks
	const int outSize = size - HASH_SIZE;     // payload after the digest
	if (ds < outSize) return false;
	
	unsigned char hash[HASH_SIZE] = {};
	unsigned char data[CRYPT_PACKET_SIZE] = {};
	for (int i = 0; i < 2; ++i) {
		static_assert(HASH_SIZE == 256/8);
		if ( !priv->decrypt(static_cast<const unsigned char*>(src), data, size, i)
		  || !priv->hash256(data + HASH_SIZE, outSize, hash) ) return false;
		// Compare against the digest of *this* attempt. The old code passed the
		// 2-D array `hash`, which always meant hash[0], so slot 1 could never verify.
		if (cryptBytesEqual(hash, data, HASH_SIZE)) {
			memcpy(dst, data + HASH_SIZE, outSize);
			dstSize = outSize;
			return true;
		}
	}
	return false;
}