// projects/neural/train.segment.inc.cpp

b579b3
#ifndef TRAIN_SEGMENT_INC_CPP
b579b3
#define TRAIN_SEGMENT_INC_CPP
b579b3
b579b3
b579b3
#include "segment.inc.cpp"
b579b3
#include "layer.inc.cpp"
b579b3
b579b3
b579b3
class TrainerSegment {
b579b3
private:
b579b3
  std::atomic<unsigned int=""> barrierCounter;</unsigned>
b579b3
  std::vector<qualitypair> qualities;</qualitypair>
b579b3
b579b3
public:
b579b3
  Segment *segment;
b579b3
  AccumReal ratio;
b579b3
  int threadsCount;
b579b3
  
b579b3
  int measuresPerBlock;
b579b3
  int trainsPerBlock;
b579b3
  int blocksPerSaving;
b579b3
  
b579b3
  int blocksCount;
b579b3
  AccumReal qmin;
b579b3
b579b3
protected:
b579b3
  volatile int x, y, z;
b579b3
  
b579b3
  virtual bool prepare() { return true; }
b579b3
  virtual bool prepareBlock(int block, bool measureOnly) { return true; }
b579b3
  virtual void finishBlock(int block) { }
b579b3
  virtual void finish() { }
b579b3
b579b3
  virtual void loadData(Barrier &barrier, int block, int iter, bool measure) { }
b579b3
b579b3
private:
b579b3
  void threadFunc(int tid, unsigned int seed, int block, bool measureOnly) {
b579b3
    Barrier barrier(barrierCounter, tid, threadsCount, seed);
b579b3
b579b3
    QualityPair q;
b579b3
    if (!measureOnly) {
b579b3
      for(int i = 0; i < trainsPerBlock; ++i) {
b579b3
        barrier.wait();
b579b3
        loadData(barrier, block, i, false);
b579b3
        barrier.wait();
b579b3
        q.train += segment->pass(barrier, x, y, z, ratio);
b579b3
      }
b579b3
    }
b579b3
    
b579b3
    for(int i = 0; i < measuresPerBlock; ++i) {
b579b3
      barrier.wait();
b579b3
      loadData(barrier, block, i, true);
b579b3
      barrier.wait();
b579b3
      q.measure += segment->pass(barrier, x, y, z, 0);
b579b3
    }
b579b3
    
b579b3
    qualities[tid] = q;
b579b3
  }
b579b3
b579b3
b579b3
  QualityPair runThreads(int block, bool measureOnly) {
b579b3
    barrierCounter = 0;
b579b3
    std::vector<std::thread*> t(threadsCount, nullptr);</std::thread*>
b579b3
    for(int i = 1; i < threadsCount; ++i)
b579b3
      t[i] = new std::thread(&TrainerSegment::threadFunc, this, i, rand(), block, measureOnly);
b579b3
    threadFunc(0, rand(), block, measureOnly);
b579b3
b579b3
    QualityPair q = qualities[0];
b579b3
    for(int i = 1; i < threadsCount; ++i)
b579b3
      { t[i]->join(); delete t[i]; q += qualities[i]; }
b579b3
    
b579b3
    q.measure *= 1/(AccumReal)measuresPerBlock;
b579b3
    q.train *= 1/(AccumReal)trainsPerBlock;
b579b3
    return q;
b579b3
  }
b579b3
b579b3
b579b3
public:
b579b3
  TrainerSegment():
b579b3
    barrierCounter(0),
b579b3
    segment(),
b579b3
    ratio(),
b579b3
    threadsCount(1),
b579b3
    measuresPerBlock(),
b579b3
    trainsPerBlock(),
b579b3
    blocksPerSaving(),
b579b3
    blocksCount(),
b579b3
    qmin(),
b579b3
    x(), y(), z()
b579b3
    { }
b579b3
b579b3
b579b3
  QualityPair run() {
b579b3
    int trainsPerBlock = ratio > 0 ? this->trainsPerBlock : 0;
b579b3
    int blocksCount = trainsPerBlock > 0 || this->blocksCount > 0 ? this->blocksCount : 1;
b579b3
b579b3
    assert(segment);
b579b3
    assert(threadsCount > 0);
b579b3
    assert(measuresPerBlock >= 0);
b579b3
    assert(trainsPerBlock >= 0);
b579b3
    assert(measuresPerBlock + trainsPerBlock > 0);
b579b3
b579b3
    QualityPair bad(Quality::bad(), Quality::bad());
b579b3
    
b579b3
    printf("training segment: threads %d, trainsPerBlock %d, measuresPerBlock %d, ratio: %lf\n", threadsCount, trainsPerBlock, measuresPerBlock, ratio);
b579b3
    fflush(stdout);
b579b3
b579b3
    qualities.clear();
b579b3
    qualities.resize(threadsCount);
b579b3
    segment->split(threadsCount);
b579b3
    
b579b3
    if (!prepare())
b579b3
      return printf("cannot prepare\n"), bad;
b579b3
b579b3
    
b579b3
    QualityPair result = bad, best = result, saved = result;
b579b3
    long long fullTimeStartUs = timeUs();
b579b3
    int i = 0;
b579b3
    int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
b579b3
    int nextSave = i + bps;
b579b3
    while(true) {
b579b3
      bool measureOnly = measuresPerBlock > 0 && (!i || trainsPerBlock <= 0);
b579b3
      
b579b3
      if (!prepareBlock(i, measureOnly)) {
b579b3
        printf("cannot prepare block\n");
b579b3
        result = bad;
b579b3
        break;
b579b3
      };
b579b3
b579b3
      long long runTimeUs = timeUs();
b579b3
      result = runThreads(i, measureOnly);
b579b3
      runTimeUs = timeUs() - runTimeUs;
b579b3
b579b3
      finishBlock(i);
b579b3
b579b3
      long long t = timeUs();
b579b3
      long long fullTimeUs = t - fullTimeStartUs;
b579b3
      fullTimeStartUs = t;
b579b3
      ++i;
b579b3
b579b3
      Quality q = measuresPerBlock > 0 ? result.measure : result.train;
b579b3
      
b579b3
      if (i == 1) saved = result;
b579b3
      bool good = result < best;
b579b3
      bool done = (blocksCount > 0 && i >= blocksCount) || q.human <= qmin;
b579b3
      bool saving = !measureOnly && ratio > 0 && (i >= nextSave || done) && result < saved;
b579b3
      if (good) best = result;
b579b3
      
b579b3
      Quality bq = measuresPerBlock > 0 ? best.measure : best.train;
b579b3
      
b579b3
      printf("%4d, total %7d, avg.result %12g (%12g), best %12g (%12g), time: %f / %f%s\n",
b579b3
        i, i*trainsPerBlock,
b579b3
        q.human, q.train, bq.human, bq.train,
b579b3
        runTimeUs*0.000001, fullTimeUs*0.000001,
b579b3
        (saving ? ", saving" : "" ) );
b579b3
      fflush(stdout);
b579b3
b579b3
      if (saving) {
b579b3
        if (!segment->save()) {
b579b3
          printf("saving failed\n");
b579b3
          result = bad;
b579b3
          break;
b579b3
        }
b579b3
        saved = result;
b579b3
        nextSave += bps;
b579b3
      }
b579b3
b579b3
      if (done) break;
b579b3
    }
b579b3
b579b3
    finish();
b579b3
b579b3
    return result;
b579b3
  }
b579b3
};
b579b3
b579b3
b579b3
#endif