|
|
b579b3 |
#ifndef TRAIN_SEGMENT_INC_CPP
|
|
|
b579b3 |
#define TRAIN_SEGMENT_INC_CPP
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
#include "segment.inc.cpp"
|
|
|
b579b3 |
#include "layer.inc.cpp"
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
class TrainerSegment {
|
|
|
b579b3 |
private:
|
|
|
b579b3 |
std::atomic<unsigned int=""> barrierCounter;</unsigned>
|
|
|
b579b3 |
std::vector<qualitypair> qualities;</qualitypair>
|
|
|
b579b3 |
|
|
|
b579b3 |
public:
|
|
|
b579b3 |
Segment *segment;
|
|
|
b579b3 |
AccumReal ratio;
|
|
|
b579b3 |
int threadsCount;
|
|
|
b579b3 |
|
|
|
b579b3 |
int measuresPerBlock;
|
|
|
b579b3 |
int trainsPerBlock;
|
|
|
b579b3 |
int blocksPerSaving;
|
|
|
b579b3 |
|
|
|
b579b3 |
int blocksCount;
|
|
|
b579b3 |
AccumReal qmin;
|
|
|
b579b3 |
|
|
|
b579b3 |
protected:
|
|
|
b579b3 |
volatile int x, y, z;
|
|
|
b579b3 |
|
|
|
b579b3 |
virtual bool prepare() { return true; }
|
|
|
b579b3 |
virtual bool prepareBlock(int block, bool measureOnly) { return true; }
|
|
|
b579b3 |
virtual void finishBlock(int block) { }
|
|
|
b579b3 |
virtual void finish() { }
|
|
|
b579b3 |
|
|
|
b579b3 |
virtual void loadData(Barrier &barrier, int block, int iter, bool measure) { }
|
|
|
b579b3 |
|
|
|
b579b3 |
private:
|
|
|
b579b3 |
void threadFunc(int tid, unsigned int seed, int block, bool measureOnly) {
|
|
|
b579b3 |
Barrier barrier(barrierCounter, tid, threadsCount, seed);
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair q;
|
|
|
b579b3 |
if (!measureOnly) {
|
|
|
b579b3 |
for(int i = 0; i < trainsPerBlock; ++i) {
|
|
|
b579b3 |
barrier.wait();
|
|
|
b579b3 |
loadData(barrier, block, i, false);
|
|
|
b579b3 |
barrier.wait();
|
|
|
b579b3 |
q.train += segment->pass(barrier, x, y, z, ratio);
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
for(int i = 0; i < measuresPerBlock; ++i) {
|
|
|
b579b3 |
barrier.wait();
|
|
|
b579b3 |
loadData(barrier, block, i, true);
|
|
|
b579b3 |
barrier.wait();
|
|
|
b579b3 |
q.measure += segment->pass(barrier, x, y, z, 0);
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
qualities[tid] = q;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair runThreads(int block, bool measureOnly) {
|
|
|
b579b3 |
barrierCounter = 0;
|
|
|
b579b3 |
std::vector<std::thread*> t(threadsCount, nullptr);</std::thread*>
|
|
|
b579b3 |
for(int i = 1; i < threadsCount; ++i)
|
|
|
b579b3 |
t[i] = new std::thread(&TrainerSegment::threadFunc, this, i, rand(), block, measureOnly);
|
|
|
b579b3 |
threadFunc(0, rand(), block, measureOnly);
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair q = qualities[0];
|
|
|
b579b3 |
for(int i = 1; i < threadsCount; ++i)
|
|
|
b579b3 |
{ t[i]->join(); delete t[i]; q += qualities[i]; }
|
|
|
b579b3 |
|
|
|
b579b3 |
q.measure *= 1/(AccumReal)measuresPerBlock;
|
|
|
b579b3 |
q.train *= 1/(AccumReal)trainsPerBlock;
|
|
|
b579b3 |
return q;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
public:
|
|
|
b579b3 |
TrainerSegment():
|
|
|
b579b3 |
barrierCounter(0),
|
|
|
b579b3 |
segment(),
|
|
|
b579b3 |
ratio(),
|
|
|
b579b3 |
threadsCount(1),
|
|
|
b579b3 |
measuresPerBlock(),
|
|
|
b579b3 |
trainsPerBlock(),
|
|
|
b579b3 |
blocksPerSaving(),
|
|
|
b579b3 |
blocksCount(),
|
|
|
b579b3 |
qmin(),
|
|
|
b579b3 |
x(), y(), z()
|
|
|
b579b3 |
{ }
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair run() {
|
|
|
b579b3 |
int trainsPerBlock = ratio > 0 ? this->trainsPerBlock : 0;
|
|
|
b579b3 |
int blocksCount = trainsPerBlock > 0 || this->blocksCount > 0 ? this->blocksCount : 1;
|
|
|
b579b3 |
|
|
|
b579b3 |
assert(segment);
|
|
|
b579b3 |
assert(threadsCount > 0);
|
|
|
b579b3 |
assert(measuresPerBlock >= 0);
|
|
|
b579b3 |
assert(trainsPerBlock >= 0);
|
|
|
b579b3 |
assert(measuresPerBlock + trainsPerBlock > 0);
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair bad(Quality::bad(), Quality::bad());
|
|
|
b579b3 |
|
|
|
b579b3 |
printf("training segment: threads %d, trainsPerBlock %d, measuresPerBlock %d, ratio: %lf\n", threadsCount, trainsPerBlock, measuresPerBlock, ratio);
|
|
|
b579b3 |
fflush(stdout);
|
|
|
b579b3 |
|
|
|
b579b3 |
qualities.clear();
|
|
|
b579b3 |
qualities.resize(threadsCount);
|
|
|
b579b3 |
segment->split(threadsCount);
|
|
|
b579b3 |
|
|
|
b579b3 |
if (!prepare())
|
|
|
b579b3 |
return printf("cannot prepare\n"), bad;
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
QualityPair result = bad, best = result, saved = result;
|
|
|
b579b3 |
long long fullTimeStartUs = timeUs();
|
|
|
b579b3 |
int i = 0;
|
|
|
b579b3 |
int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
|
|
|
b579b3 |
int nextSave = i + bps;
|
|
|
b579b3 |
while(true) {
|
|
|
b579b3 |
bool measureOnly = measuresPerBlock > 0 && (!i || trainsPerBlock <= 0);
|
|
|
b579b3 |
|
|
|
b579b3 |
if (!prepareBlock(i, measureOnly)) {
|
|
|
b579b3 |
printf("cannot prepare block\n");
|
|
|
b579b3 |
result = bad;
|
|
|
b579b3 |
break;
|
|
|
b579b3 |
};
|
|
|
b579b3 |
|
|
|
b579b3 |
long long runTimeUs = timeUs();
|
|
|
b579b3 |
result = runThreads(i, measureOnly);
|
|
|
b579b3 |
runTimeUs = timeUs() - runTimeUs;
|
|
|
b579b3 |
|
|
|
b579b3 |
finishBlock(i);
|
|
|
b579b3 |
|
|
|
b579b3 |
long long t = timeUs();
|
|
|
b579b3 |
long long fullTimeUs = t - fullTimeStartUs;
|
|
|
b579b3 |
fullTimeStartUs = t;
|
|
|
b579b3 |
++i;
|
|
|
b579b3 |
|
|
|
b579b3 |
Quality q = measuresPerBlock > 0 ? result.measure : result.train;
|
|
|
b579b3 |
|
|
|
b579b3 |
if (i == 1) saved = result;
|
|
|
b579b3 |
bool good = result < best;
|
|
|
b579b3 |
bool done = (blocksCount > 0 && i >= blocksCount) || q.human <= qmin;
|
|
|
b579b3 |
bool saving = !measureOnly && ratio > 0 && (i >= nextSave || done) && result < saved;
|
|
|
b579b3 |
if (good) best = result;
|
|
|
b579b3 |
|
|
|
b579b3 |
Quality bq = measuresPerBlock > 0 ? best.measure : best.train;
|
|
|
b579b3 |
|
|
|
b579b3 |
printf("%4d, total %7d, avg.result %12g (%12g), best %12g (%12g), time: %f / %f%s\n",
|
|
|
b579b3 |
i, i*trainsPerBlock,
|
|
|
b579b3 |
q.human, q.train, bq.human, bq.train,
|
|
|
b579b3 |
runTimeUs*0.000001, fullTimeUs*0.000001,
|
|
|
b579b3 |
(saving ? ", saving" : "" ) );
|
|
|
b579b3 |
fflush(stdout);
|
|
|
b579b3 |
|
|
|
b579b3 |
if (saving) {
|
|
|
b579b3 |
if (!segment->save()) {
|
|
|
b579b3 |
printf("saving failed\n");
|
|
|
b579b3 |
result = bad;
|
|
|
b579b3 |
break;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
saved = result;
|
|
|
b579b3 |
nextSave += bps;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
if (done) break;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
finish();
|
|
|
b579b3 |
|
|
|
b579b3 |
return result;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
};
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
#endif
|