#ifndef TRAIN_SEGMENT_INC_CPP
#define TRAIN_SEGMENT_INC_CPP
#include "segment.inc.cpp"
#include "layer.inc.cpp"
/// Runs training and/or measuring passes over a Segment in blocks, using
/// several worker threads.
///
/// One block consists of `trainsPerBlock` training passes (skipped for a
/// measure-only block) followed by `measuresPerBlock` measuring passes.
/// After each block the averaged quality is printed, and the segment is
/// saved when a save is due and the result is better than the last saved one.
/// Subclasses customize data loading and per-block setup through the
/// protected virtual hooks.
class TrainerSegment: public ThreadControl {
private:
    // Per-thread result of the last block, indexed by barrier.tid.
    std::vector<QualityPair> qualities;

public:
    Segment *segment;       // segment to process; run() asserts it is non-null
    AccumReal ratio;        // passed to segment->pass() on training passes; training is disabled when <= 0
    int threadsCount;       // number of worker threads; run() asserts it is > 0
    int measuresPerBlock;   // measuring passes per block (may be 0)
    int trainsPerBlock;     // training passes per block (may be 0)
    int blocksPerSaving;    // minimal number of blocks between saves; values <= 0 are treated as 1
    int blocksCount;        // stop after this many blocks; <= 0 means no block limit while training
    AccumReal qmin;         // stop as soon as the block's `human` quality is <= qmin

protected:
    // Arguments forwarded to segment->pass().
    // NOTE(review): presumably updated by loadData() overrides - confirm against subclasses.
    volatile int x, y, z;

    // Parameters of the block currently being executed; written by
    // runThreads() before the workers start and read in threadFunc().
    int currentBlock;
    bool currentBlockMeasureOnly;

    /// Called once by run() before the first block. Returning false aborts
    /// run() immediately (finish() is not called in that case).
    virtual bool prepare() { return true; }

    /// Called by run() before every block. Returning false stops the block
    /// loop and makes run() return the "bad" quality.
    virtual bool prepareBlock(int block, bool measureOnly) { return true; }

    /// Called by run() after the worker threads of a block have finished.
    virtual void finishBlock(int block) { }

    /// Called by run() once after the block loop has ended.
    virtual void finish() { }

    /// Called from threadFunc() before each pass, between two barrier waits.
    /// `iter` is the pass index inside the block; `measure` is true for
    /// measuring passes and false for training passes.
    virtual void loadData(Barrier &barrier, int block, int iter, bool measure) { }

private:
    /// Worker body for one block: runs the training passes (unless the block
    /// is measure-only), then the measuring passes, and stores the summed
    /// quality of this thread into qualities[barrier.tid].
    void threadFunc(Barrier &barrier) override {
        const int block = currentBlock;
        const bool measureOnly = currentBlockMeasureOnly;

        QualityPair q;
        if (!measureOnly) {
            for(int i = 0; i < trainsPerBlock; ++i) {
                barrier.wait();
                loadData(barrier, block, i, false);
                barrier.wait();
                q.train += segment->pass(barrier, x, y, z, ratio);
            }
        }
        for(int i = 0; i < measuresPerBlock; ++i) {
            barrier.wait();
            loadData(barrier, block, i, true);
            barrier.wait();
            // ratio 0 is passed for measuring passes
            q.measure += segment->pass(barrier, x, y, z, 0);
        }
        qualities[barrier.tid] = q;
    }

    /// Executes one block on `threadsCount` threads and returns the sum of
    /// the per-thread results, scaled by 1/measuresPerBlock and
    /// 1/trainsPerBlock respectively.
    QualityPair runThreads(int block, bool measureOnly) {
        currentBlock = block;
        currentBlockMeasureOnly = measureOnly;
        ThreadControl::runThreads(threadsCount);

        QualityPair q;
        for(int i = 0; i < threadsCount; ++i)
            q += qualities[i];

        // run() allows either count to be zero (only their sum must be
        // positive). No passes of that kind were executed then, so skip the
        // scaling instead of dividing by zero.
        if (measuresPerBlock > 0)
            q.measure *= 1/(AccumReal)measuresPerBlock;
        if (trainsPerBlock > 0)
            q.train *= 1/(AccumReal)trainsPerBlock;
        return q;
    }

public:
    TrainerSegment():
        segment(),
        ratio(),
        threadsCount(1),
        measuresPerBlock(),
        trainsPerBlock(),
        blocksPerSaving(),
        blocksCount(),
        qmin(),
        x(), y(), z(),
        currentBlock(),
        currentBlockMeasureOnly()
    { }

    /// Runs the block loop until the block limit is reached, the quality
    /// drops to `qmin` or below, or a hook / save fails.
    ///
    /// @return the result of the last executed block, or a pair of
    ///         Quality::bad() when prepare(), prepareBlock() or
    ///         segment->save() failed.
    QualityPair run() {
        // Effective values (they intentionally shadow the members):
        // no training at all when ratio <= 0, and exactly one block when
        // neither training nor an explicit block limit is configured.
        int trainsPerBlock = ratio > 0 ? this->trainsPerBlock : 0;
        int blocksCount = trainsPerBlock > 0 || this->blocksCount > 0 ? this->blocksCount : 1;

        assert(segment);
        assert(threadsCount > 0);
        assert(measuresPerBlock >= 0);
        assert(trainsPerBlock >= 0);
        assert(measuresPerBlock + trainsPerBlock > 0);

        QualityPair bad(Quality::bad(), Quality::bad());

        printf("training segment: threads %d, trainsPerBlock %d, measuresPerBlock %d, ratio: %lf\n", threadsCount, trainsPerBlock, measuresPerBlock, ratio);
        fflush(stdout);

        qualities.clear();
        qualities.resize(threadsCount);
        segment->split(threadsCount);

        if (!prepare()) {
            printf("cannot prepare\n");
            return bad;
        }

        QualityPair result = bad, best = result, saved = result;
        long long fullTimeStartUs = timeUs();
        int i = 0;
        int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
        int nextSave = i + bps;

        while(true) {
            // The very first block is measure-only when measuring is enabled;
            // so is every block when training is disabled.
            bool measureOnly = measuresPerBlock > 0 && (!i || trainsPerBlock <= 0);

            if (!prepareBlock(i, measureOnly)) {
                printf("cannot prepare block\n");
                result = bad;
                break;
            }

            // runTimeUs: time spent in the worker threads only;
            // fullTimeUs: time since the previous block was completed.
            long long runTimeUs = timeUs();
            result = runThreads(i, measureOnly);
            runTimeUs = timeUs() - runTimeUs;

            finishBlock(i);

            long long t = timeUs();
            long long fullTimeUs = t - fullTimeStartUs;
            fullTimeStartUs = t;

            // from here on `i` is the number of completed blocks
            ++i;

            // Measured quality is preferred for reporting and for the stop
            // condition; training quality is used when nothing is measured.
            Quality q = measuresPerBlock > 0 ? result.measure : result.train;

            // The first block becomes the baseline a later block has to beat
            // before the segment is saved.
            if (i == 1) saved = result;

            bool good = result < best;
            bool done = (blocksCount > 0 && i >= blocksCount) || q.human <= qmin;
            bool saving = !measureOnly && ratio > 0 && (i >= nextSave || done) && result < saved;
            if (good) best = result;

            Quality bq = measuresPerBlock > 0 ? best.measure : best.train;
            printf("%4d, total %7d, avg.result %12g (%12g), best %12g (%12g), time: %f / %f%s\n",
                i, i*trainsPerBlock,
                q.human, q.train, bq.human, bq.train,
                runTimeUs*0.000001, fullTimeUs*0.000001,
                (saving ? ", saving" : "" ) );
            fflush(stdout);

            if (saving) {
                if (!segment->save()) {
                    printf("saving failed\n");
                    result = bad;
                    break;
                }
                saved = result;
                nextSave += bps;
            }

            if (done) break;
        }

        finish();
        return result;
    }
};
#endif