#ifndef TRAIN_DIGIT_INC_CPP
#define TRAIN_DIGIT_INC_CPP
#include "train.inc.cpp"
#include "layer.simple.inc.cpp"
// Trainer for a classification task whose samples come from one flat binary file
// (see loadSymbolMap). Each record is `stride` bytes: one byte per active input
// neuron (scaled to [0,1] by /255 in loadData), followed by one label byte that
// is the index of the expected output neuron (read in verifyDataMain).
class TrainerDigit: public Trainer {
protected:
	std::vector<unsigned char> data;     // raw contents of the symbol-map file
	std::vector<unsigned int> shuffle;   // permutation of record indices [0, count)
	Layout ofl, obl;                     // optimized layouts of fl (inputs) and bl (outputs)
	Layout::List oflist, oblist;         // those layouts split into threadsCount parts
	int stride, count;                   // bytes per record, number of whole records

public:
	TrainerDigit(): stride(), count() { }

	// Loads the whole file `filename` into `data`.
	// Returns true on success. On failure, prints a message to stdout, leaves
	// `data` empty, and returns false. An empty file is reported as a read failure.
	bool loadSymbolMap(const char *filename) {
		data.clear();
		FILE *f = fopen(filename, "rb");
		if (!f)
			return printf("cannot open file for read: %s\n", filename), false;
		// Closes the file on every exit path, including an exception thrown by resize().
		struct Closer { FILE *file; ~Closer() { fclose(file); } } closer = { f };

		bool ok = false;
		if (fseek(f, 0, SEEK_END) == 0) {
			// ftello() returns -1 on error; it must not be converted to size_t,
			// because that would request an enormous allocation.
			const auto fs = ftello(f);
			if (fs > 0 && fseek(f, 0, SEEK_SET) == 0) {
				data.resize(static_cast<size_t>(fs), 0);
				ok = fread(data.data(), data.size(), 1, f) == 1;
			}
		}
		if (!ok) {
			data.clear();
			return printf("cannot read from file: %s\n", filename), false;
		}
		return true;
	}

protected:
	// Builds the optimized per-thread layouts, derives the record size from the
	// number of active input neurons (+1 label byte), and resets `shuffle` to the
	// identity permutation. Returns false if `data` holds no complete record.
	bool prepare() override {
		ofl = optimizeLayoutSimple(fl->layout);
		obl = optimizeLayoutSimple(bl->layout);
		ofl.split(oflist, threadsCount);
		obl.split(oblist, threadsCount);
		stride = ofl.getActiveCount() + 1;
		count = data.size()/stride;   // a trailing partial record is ignored
		if (count <= 0) return false;
		shuffle.resize(count);
		for(int i = 0; i < count; ++i)
			shuffle[i] = i;
		return true;
	}

	// Returns a pseudo-random integer in [0, n) for n > 0, drawn from rand().
	// When n exceeds RAND_MAX (which may be as small as 32767), two draws are
	// combined so that indices above RAND_MAX are reachable. Otherwise exactly
	// one rand() call is made.
	static int randomBelow(int n) {
		unsigned long long r = static_cast<unsigned long long>(rand());
		if (n > RAND_MAX)
			r = r*(static_cast<unsigned long long>(RAND_MAX) + 1) + static_cast<unsigned long long>(rand());
		return static_cast<int>(r % static_cast<unsigned long long>(n));
	}

	// Randomizes the first min(itersPerBlock, count) entries of `shuffle` with a
	// partial Fisher-Yates shuffle. Each position i is swapped with a position
	// drawn from [i, count), so the prefix is a random sample without repetition
	// (uniform up to the modulo bias of rand()).
	bool prepareBlock() override {
		int cnt = itersPerBlock > count ? count : itersPerBlock;
		for(int i = 0; i < cnt; ++i) {
			int j = i + randomBelow(count - i);
			if (i != j) std::swap(shuffle[i], shuffle[j]);
		}
		return true;
	}

	// Writes the input bytes of record shuffle[iter % count] into this thread's
	// share (oflist[barrier.tid]) of the input neurons, scaled to [0, 1].
	void loadData(Barrier &barrier, int, int iter) override {
		struct I: public Iter {
			typedef const unsigned char* DataType;
			static inline void iter4(Neuron &n, DataType d, DataAccumType&) { n.v = *d/(NeuronReal)255; }
		};
		const unsigned char *id = data.data() + shuffle[iter%count]*stride;
		iterateNeurons2<I>(oflist[barrier.tid], ofl, fl->neurons, id);
	}

	// Compares the output neurons with a one-hot target built from the record's
	// label byte. Each neuron's `d` is multiplied by (target - output); presumably
	// `d` holds the activation derivative at this point (confirm in Trainer).
	// Returns 1 if the largest output is not the labelled class, otherwise 0.
	AccumReal verifyDataMain(int, int iter) override {
		struct I: public Iter {
			typedef int DataType;   // presumably the output neuron index (start 0, step 1)
			// ri: expected class; mi: index of the largest output so far (-1 before the
			// first neuron); m: that largest output value.
			struct DataAccumType { int ri, mi; NeuronReal m; };
			static inline void iter4(Neuron &n, DataType d, DataAccumType &a) {
				NeuronReal v1 = d == a.ri;   // target: 1 for the labelled class, else 0
				NeuronReal v0 = n.v;         // actual output
				n.d *= v1 - v0;
				// Strict '<' keeps the first maximum on ties. Seeding from the first
				// neuron keeps the argmax correct even when every output is <= 0.
				if (a.mi < 0 || a.m < v0) { a.m = v0; a.mi = d; }
			}
		};
		// The label is the last byte of the record.
		I::DataAccumType a = { data[ (shuffle[iter%count] + 1)*stride - 1 ], -1, 0 };
		iterateNeurons2<I>(obl, obl, bl->neurons, 0, 1, &a);
		return a.mi != a.ri;
	}
};
#endif