|
|
e865c9 |
#ifndef TRAIN_IMAGE_INC_CPP
|
|
|
e865c9 |
#define TRAIN_IMAGE_INC_CPP
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
b579b3 |
#include <set></set>
|
|
|
b579b3 |
|
|
|
e865c9 |
#include "train.inc.cpp"
|
|
|
e865c9 |
#include "layer.simple.inc.cpp"
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
class TrainerImage: public Trainer {
|
|
|
e865c9 |
protected:
|
|
|
e865c9 |
std::vector<unsigned char=""> data;</unsigned>
|
|
|
b579b3 |
std::vector<unsigned char=""> tmpdata;</unsigned>
|
|
|
b579b3 |
std::vector<int> shuffle;</int>
|
|
|
b579b3 |
std::vector<int> shuffle2;</int>
|
|
|
b579b3 |
Layout pbl;
|
|
|
b579b3 |
Layout::List flist, blist;
|
|
|
b579b3 |
FILE *f;
|
|
|
b579b3 |
size_t imgsize;
|
|
|
b579b3 |
int count;
|
|
|
b579b3 |
int workCount;
|
|
|
e865c9 |
|
|
|
e865c9 |
public:
|
|
|
b579b3 |
int pad;
|
|
|
b579b3 |
const char *datafile;
|
|
|
b579b3 |
const char *outfile;
|
|
|
b579b3 |
Layer *dataLayer;
|
|
|
b579b3 |
|
|
|
b579b3 |
TrainerImage(): f(), imgsize(), count(), workCount(), pad(), datafile(), outfile(), dataLayer() { }
|
|
|
e865c9 |
|
|
|
b579b3 |
protected:
|
|
|
b579b3 |
bool prepare() override {
|
|
|
b579b3 |
assert(datafile);
|
|
|
b579b3 |
assert(fl->layout.getD() == 3);
|
|
|
b579b3 |
|
|
|
e4740d |
#ifndef NDEBUG
|
|
|
b579b3 |
Layer *dl = dataLayer ? dataLayer : fl;
|
|
|
b579b3 |
assert(dl->layout.getW() == bl->layout.getW());
|
|
|
b579b3 |
assert(dl->layout.getH() == bl->layout.getH());
|
|
|
b579b3 |
assert(dl->layout.getD() == bl->layout.getD());
|
|
|
e4740d |
#endif
|
|
|
b579b3 |
|
|
|
b579b3 |
imgsize = fl->layout.getActiveCount();
|
|
|
b579b3 |
fl->layout.split(flist, threadsCount);
|
|
|
b579b3 |
bl->layout.split(blist, threadsCount);
|
|
|
b579b3 |
pbl = bl->layout;
|
|
|
b579b3 |
pbl.padXY(pad);
|
|
|
b579b3 |
|
|
|
b579b3 |
f = fopen(datafile, "rb");
|
|
|
b579b3 |
if (!f) return false;
|
|
|
e865c9 |
|
|
|
b579b3 |
fseeko64(f, 0, SEEK_END);
|
|
|
b579b3 |
long long size = ftello64(f);
|
|
|
b579b3 |
count = size/imgsize;
|
|
|
b579b3 |
if (count < 1) return fclose(f), f = nullptr, false;
|
|
|
e865c9 |
|
|
|
b579b3 |
workCount = itersPerBlock > count ? count : itersPerBlock;
|
|
|
b579b3 |
printf("allocated size: %lld\n", (long long)(imgsize*workCount));
|
|
|
b579b3 |
data.resize(workCount*imgsize);
|
|
|
e865c9 |
|
|
|
b579b3 |
shuffle.resize(count);
|
|
|
b579b3 |
for(int i = 0; i < count; ++i)
|
|
|
b579b3 |
shuffle[i] = i;
|
|
|
e865c9 |
|
|
|
b579b3 |
shuffle2.resize(workCount);
|
|
|
b579b3 |
for(int i = 0; i < workCount; ++i)
|
|
|
b579b3 |
shuffle2[i] = i;
|
|
|
b579b3 |
|
|
|
b579b3 |
return loadBlocks();
|
|
|
b579b3 |
//return true;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
b579b3 |
|
|
|
b579b3 |
void finish() override
|
|
|
b579b3 |
{ if (f) fclose(f), f = nullptr; }
|
|
|
e865c9 |
|
|
|
b579b3 |
|
|
|
b579b3 |
bool loadBlocks() {
|
|
|
b579b3 |
for(int i = 0; i < workCount; ++i) {
|
|
|
b579b3 |
int j = rand()%count;
|
|
|
b579b3 |
if (i != j) std::swap(shuffle[i], shuffle[j]);
|
|
|
e865c9 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
typedef std::pair<int, int=""> Pair;</int,>
|
|
|
b579b3 |
typedef std::set<pair> Set;</pair>
|
|
|
b579b3 |
Set set;
|
|
|
b579b3 |
for(int i = 0; i < workCount; ++i)
|
|
|
b579b3 |
set.insert(Pair(shuffle[i], i));
|
|
|
b579b3 |
|
|
|
b579b3 |
for(Set::iterator i = set.begin(); i != set.end(); ++i) {
|
|
|
b579b3 |
fseeko64(f, i->first*imgsize, SEEK_SET);
|
|
|
b579b3 |
if (!fread(data.data() + i->second*imgsize, imgsize, 1, f))
|
|
|
b579b3 |
return fclose(f), f = nullptr, false;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
return true;
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
bool prepareBlock() override {
|
|
|
b579b3 |
for(int i = 0; i < workCount; ++i) {
|
|
|
b579b3 |
int j = rand()%workCount;
|
|
|
b579b3 |
if (i != j) std::swap(shuffle2[i], shuffle2[j]);
|
|
|
e865c9 |
}
|
|
|
b579b3 |
//return loadBlocks();
|
|
|
e865c9 |
return true;
|
|
|
e865c9 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
void finishBlock() override {
|
|
|
b579b3 |
if (outfile && !dataLayer) {
|
|
|
b579b3 |
std::string outfile0(outfile);
|
|
|
b579b3 |
std::string outfile1 = outfile0 + ".1.tga";
|
|
|
b579b3 |
outfile0 += ".0.tga";
|
|
|
b579b3 |
|
|
|
b579b3 |
unsigned char *id0 = data.data() + shuffle2[(itersPerBlock-1)%workCount]*imgsize;
|
|
|
b579b3 |
tgaSave(outfile0.c_str(), id0, fl->layout.getW(), fl->layout.getH(), fl->layout.getD());
|
|
|
b579b3 |
|
|
|
b579b3 |
struct I: public Iter {
|
|
|
b579b3 |
typedef unsigned char* DataType;
|
|
|
b579b3 |
static inline void iter4(Neuron &n, DataType d, DataAccumType&) { *d = n.v < 0 ? 0 : n.v > 1 ? 255 : (unsigned char)(n.v*255.999); }
|
|
|
b579b3 |
};
|
|
|
b579b3 |
|
|
|
b579b3 |
tmpdata.resize(imgsize);
|
|
|
b579b3 |
unsigned char *id1 = tmpdata.data();
|
|
|
b579b3 |
iterateNeurons2(bl->layout, bl->layout, bl->neurons, id1);
|
|
|
b579b3 |
tgaSave(outfile1.c_str(), id1, bl->layout.getW(), bl->layout.getH(), bl->layout.getD());
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
void loadData(Barrier &barrier, int, int iter) override {
|
|
|
e865c9 |
struct I: public Iter {
|
|
|
b579b3 |
typedef const unsigned char* DataType;
|
|
|
b579b3 |
static inline void iter4(Neuron &n, DataType d, DataAccumType&) { n.v = *d/(NeuronReal)255; }
|
|
|
e865c9 |
};
|
|
|
b579b3 |
const unsigned char *id = data.data() + shuffle2[iter%workCount]*imgsize;
|
|
|
b579b3 |
iterateNeurons2(flist[barrier.tid], fl->layout, fl->neurons, id);
|
|
|
e865c9 |
}
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
b579b3 |
Quality verifyData(Barrier &barrier, int, int iter) override {
|
|
|
b579b3 |
Layout l = blist[barrier.tid];
|
|
|
b579b3 |
Layout dl = bl->layout;
|
|
|
b579b3 |
Layout pl = pbl;
|
|
|
b579b3 |
|
|
|
b579b3 |
int d = l.getD();
|
|
|
b579b3 |
int w = l.getW();
|
|
|
b579b3 |
int dx = l.sz - d;
|
|
|
b579b3 |
int dy = (l.sx - w)*l.sz;
|
|
|
b579b3 |
int ddx = dl.getD();
|
|
|
b579b3 |
int ddy = (dl.getW() - w)*ddx;
|
|
|
b579b3 |
|
|
|
b579b3 |
AccumReal aq = 0;
|
|
|
b579b3 |
NeuronReal ratio = this->ratio;
|
|
|
b579b3 |
Neuron *in = bl->neurons + (l.y0*l.sx + l.x0)*l.sz + l.z0;
|
|
|
b579b3 |
const unsigned char *id = data.data() + shuffle2[iter%workCount]*imgsize + ((l.y0-dl.y0)*l.sx + l.x0-dl.x0)*l.sz + l.z0-dl.z0;
|
|
|
e865c9 |
|
|
|
b579b3 |
for(int y = l.y0; y < l.y1; ++y, in += dy, id += ddy) {
|
|
|
b579b3 |
bool outside = y < pl.y0 || y >= pl.y1;
|
|
|
b579b3 |
for(int x = l.x0; x < l.x1; ++x, in += dx, id += ddx) {
|
|
|
b579b3 |
if (outside || x < pl.x0 || x >= pl.x1) {
|
|
|
b579b3 |
for(Neuron *e = in + d; in < e; ++in) in->d = 0;
|
|
|
b579b3 |
} else {
|
|
|
b579b3 |
const unsigned char *iid = id;
|
|
|
b579b3 |
for(Neuron *e = in + d; in < e; ++in, ++iid) {
|
|
|
b579b3 |
NeuronReal v1 = *iid/(NeuronReal)255;
|
|
|
b579b3 |
NeuronReal v0 = in->v;
|
|
|
b579b3 |
NeuronReal diff = v1 - v0;
|
|
|
b579b3 |
in->d *= diff*ratio;
|
|
|
b579b3 |
aq += diff*diff;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
e865c9 |
|
|
|
b579b3 |
return Quality( sqrt(aq/pbl.getActiveCount()) );
|
|
|
e865c9 |
}
|
|
|
e865c9 |
};
|
|
|
e865c9 |
|
|
|
e865c9 |
|
|
|
e865c9 |
#endif
|