|
|
e865c9 |
#ifndef TRAIN_DIGIT_INC_CPP
|
|
|
e865c9 |
#define TRAIN_DIGIT_INC_CPP
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
#include "train.inc.cpp"
|
|
|
e865c9 |
#include "layer.simple.inc.cpp"
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
class TrainerDigit: public Trainer {
|
|
|
e865c9 |
protected:
|
|
|
e865c9 |
std::vector<unsigned char=""> data;</unsigned>
|
|
|
e865c9 |
std::vector<unsigned int=""> shuffle;</unsigned>
|
|
|
e865c9 |
Layout ofl, obl;
|
|
|
e865c9 |
Layout::List oflist, oblist;
|
|
|
e865c9 |
int stride, count;
|
|
|
e865c9 |
|
|
|
e865c9 |
public:
|
|
|
e865c9 |
TrainerDigit(): stride(), count() { }
|
|
|
e865c9 |
|
|
|
e865c9 |
bool loadSymbolMap(const char *filename) {
|
|
|
e865c9 |
data.clear();
|
|
|
e865c9 |
|
|
|
e865c9 |
FILE *f = fopen(filename, "rb");
|
|
|
e865c9 |
if (!f)
|
|
|
e865c9 |
return printf("cannot open file for read: %s\n", filename), false;
|
|
|
e865c9 |
fseek(f, 0, SEEK_END);
|
|
|
e865c9 |
size_t fs = ftello(f);
|
|
|
e865c9 |
fseek(f, 0, SEEK_SET);
|
|
|
e865c9 |
|
|
|
e865c9 |
data.resize(fs, 0);
|
|
|
e865c9 |
if (!fread(data.data(), fs, 1, f))
|
|
|
e865c9 |
return printf("cannot read from file: %s\n", filename), fclose(f), data.clear(), false;
|
|
|
e865c9 |
|
|
|
e865c9 |
fclose(f);
|
|
|
e865c9 |
return true;
|
|
|
e865c9 |
}
|
|
|
8e5348 |
|
|
|
8e5348 |
static void printSymbol(const unsigned char *data, int w, int h, int index = -1) {
|
|
|
8e5348 |
if (index >= 0) printf("\nsymbol %d (%d):\n", (int)data[w*h], index);
|
|
|
8e5348 |
else printf("\nsymbol %d:\n", (int)data[w*h]);
|
|
|
8e5348 |
for(int i = 0; i < h; ++i) {
|
|
|
8e5348 |
for(int j = 0; j < w; ++j) printf("%s", data[i*w+j] > 128u ? "#" : " ");
|
|
|
8e5348 |
printf("\n");
|
|
|
8e5348 |
}
|
|
|
8e5348 |
printf("\n");
|
|
|
8e5348 |
}
|
|
|
8e5348 |
|
|
|
8e5348 |
void printSymbol(int index) {
|
|
|
8e5348 |
const Layout &l = layer->layout;
|
|
|
8e5348 |
printSymbol(&data[(l.getActiveCount()+1)*index], l.getW(), l.getH(), index);
|
|
|
8e5348 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
protected:
|
|
|
e865c9 |
bool prepare() override {
|
|
|
e865c9 |
ofl = optimizeLayoutSimple(fl->layout);
|
|
|
e865c9 |
obl = optimizeLayoutSimple(bl->layout);
|
|
|
e865c9 |
ofl.split(oflist, threadsCount);
|
|
|
e865c9 |
obl.split(oblist, threadsCount);
|
|
|
e865c9 |
stride = ofl.getActiveCount() + 1;
|
|
|
e865c9 |
count = data.size()/stride;
|
|
|
e865c9 |
if (count <= 0) return false;
|
|
|
e865c9 |
shuffle.resize(count);
|
|
|
e865c9 |
for(int i = 0; i < count; ++i)
|
|
|
e865c9 |
shuffle[i] = i;
|
|
|
e865c9 |
return true;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
bool prepareBlock() override {
|
|
|
e865c9 |
int cnt = itersPerBlock > count ? count : itersPerBlock;
|
|
|
e865c9 |
for(int i = 0; i < cnt; ++i) {
|
|
|
e865c9 |
int j = rand()%count;
|
|
|
e865c9 |
if (i != j) std::swap(shuffle[i], shuffle[j]);
|
|
|
e865c9 |
}
|
|
|
e865c9 |
return true;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
void loadData(Barrier &barrier, int, int iter) override {
|
|
|
e865c9 |
struct I: public Iter {
|
|
|
e865c9 |
typedef const unsigned char* DataType;
|
|
|
e865c9 |
static inline void iter4(Neuron &n, DataType d, DataAccumType&) { n.v = *d/(NeuronReal)255; }
|
|
|
e865c9 |
};
|
|
|
e865c9 |
const unsigned char *id = data.data() + shuffle[iter%count]*stride;
|
|
|
e865c9 |
iterateNeurons2(oflist[barrier.tid], ofl, fl->neurons, id);
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
8e5348 |
Quality verifyData(Barrier &barrier, int, int iter) override {
|
|
|
b579b3 |
Quality q;
|
|
|
8e5348 |
if (barrier.tid) return q;
|
|
|
8e5348 |
|
|
|
e865c9 |
struct I: public Iter {
|
|
|
e865c9 |
typedef int DataType;
|
|
|
8e5348 |
struct DataAccumType { int ri, mi; NeuronReal m, ratio, q; };
|
|
|
e865c9 |
static inline void iter4(Neuron &n, DataType d, DataAccumType &a) {
|
|
|
e865c9 |
NeuronReal v1 = d == a.ri;
|
|
|
e865c9 |
NeuronReal v0 = n.v;
|
|
|
8e5348 |
NeuronReal diff = v1 - v0;
|
|
|
8e5348 |
n.d *= diff*a.ratio;
|
|
|
8e5348 |
a.q += diff*diff;
|
|
|
e865c9 |
if (a.m < v0) { a.m = v0; a.mi = d; }
|
|
|
e865c9 |
}
|
|
|
e865c9 |
};
|
|
|
e865c9 |
|
|
|
8e5348 |
int index = shuffle[iter%count];
|
|
|
8e5348 |
if (index == 59915) {
|
|
|
8e5348 |
++skipBackpass;
|
|
|
8e5348 |
return q;
|
|
|
8e5348 |
}
|
|
|
8e5348 |
|
|
|
8e5348 |
I::DataAccumType a = { data[ (index + 1)*stride - 1 ], 0, 0, ratio };
|
|
|
e865c9 |
iterateNeurons2(obl, obl, bl->neurons, 0, 1, &a);
|
|
|
e865c9 |
|
|
|
8e5348 |
q.train = sqrt(a.q/obl.getActiveCount());
|
|
|
8e5348 |
q.human = a.mi != a.ri;
|
|
|
8e5348 |
//if (!q.human && q.train < 0.01) ++skipBackpass;
|
|
|
8e5348 |
//if (!q.human) ++skipBackpass;
|
|
|
8e5348 |
//if (q.human) printSymbol(index);
|
|
|
8e5348 |
return q;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
};
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
#endif
|