// Blob Blame Raw
#ifndef TRAIN_SEGMENT_INC_CPP
#define TRAIN_SEGMENT_INC_CPP


#include "segment.inc.cpp"
#include "layer.inc.cpp"


// Runs block-wise training and/or measuring of a Segment on several threads.
//
// Usage: fill in the public fields, then call run(). Subclasses supply data
// and per-block bookkeeping through the protected virtual hooks.
//
// Public settings:
//   segment          - segment to process; run() asserts it is non-null
//   ratio            - value passed to Segment::pass() on training iterations;
//                      when it is not positive run() does no training at all
//   threadsCount     - number of worker threads; run() asserts it is positive
//   measuresPerBlock - measuring iterations per block (>= 0)
//   trainsPerBlock   - training iterations per block (>= 0)
//   blocksPerSaving  - minimal number of blocks between savings; values below
//                      1 behave like 1
//   blocksCount      - number of blocks to run; a non-positive value means
//                      "no limit" while training and "one block" when only
//                      measuring
//   qmin             - run() stops as soon as the `human` value of the block
//                      quality is <= qmin
class TrainerSegment: public ThreadControl {
private:
  // One slot per worker thread, written by threadFunc() and summed by runThreads().
  std::vector<QualityPair> qualities;

public:
  Segment *segment;
  AccumReal ratio;
  int threadsCount;

  int measuresPerBlock;
  int trainsPerBlock;
  int blocksPerSaving;

  int blocksCount;
  AccumReal qmin;

protected:
  // Coordinates handed to Segment::pass() by every worker thread.
  // NOTE(review): presumably updated by loadData() overrides - confirm.
  volatile int x, y, z;

  // Parameters of the block currently being processed; set by runThreads()
  // before the worker threads are started.
  int currentBlock;
  bool currentBlockMeasureOnly;

  // Called once by run() before the first block. Returning false makes run()
  // return a "bad" result immediately (finish() is not called in that case).
  virtual bool prepare() { return true; }

  // Called by run() before each block. Returning false stops the run with a
  // "bad" result.
  virtual bool prepareBlock(int block, bool measureOnly) { return true; }

  // Called by run() after the worker threads of a block have completed.
  virtual void finishBlock(int block) { }

  // Called by run() once the block loop has ended.
  virtual void finish() { }

  // Called by every worker thread before each pass, between two
  // barrier.wait() calls. `iter` is the iteration index inside the block and
  // `measure` is true for measuring passes, false for training passes.
  virtual void loadData(Barrier &barrier, int block, int iter, bool measure) { }

private:
  // Body of a single worker thread for the current block.
  // Runs the training iterations (skipped for measure-only blocks) and then
  // the measuring iterations, accumulating what Segment::pass() returns, and
  // stores the sum into this thread's slot of `qualities`.
  void threadFunc(Barrier &barrier) override {
    const int block = currentBlock;
    const bool measureOnly = currentBlockMeasureOnly;

    QualityPair q;
    if (!measureOnly) {
      for(int i = 0; i < trainsPerBlock; ++i) {
        barrier.wait();
        loadData(barrier, block, i, false);
        barrier.wait();
        q.train += segment->pass(barrier, x, y, z, ratio);
      }
    }

    // Measuring passes use a zero ratio.
    for(int i = 0; i < measuresPerBlock; ++i) {
      barrier.wait();
      loadData(barrier, block, i, true);
      barrier.wait();
      q.measure += segment->pass(barrier, x, y, z, 0);
    }

    // NOTE(review): assumes barrier.tid is in [0, threadsCount) - confirm.
    qualities[barrier.tid] = q;
  }


  // Processes one block on `threadsCount` threads and returns the summed
  // per-thread qualities averaged over the number of iterations.
  QualityPair runThreads(int block, bool measureOnly) {
    currentBlock = block;
    currentBlockMeasureOnly = measureOnly;
    ThreadControl::runThreads(threadsCount);

    QualityPair total;
    for(int i = 0; i < threadsCount; ++i) total += qualities[i];

    // Average only the components that were actually accumulated. A zero
    // iteration count is a valid configuration (train-only or measure-only);
    // dividing by it would turn the unused component into NaN.
    if (measuresPerBlock > 0) total.measure *= 1/static_cast<AccumReal>(measuresPerBlock);
    if (trainsPerBlock > 0) total.train *= 1/static_cast<AccumReal>(trainsPerBlock);
    return total;
  }


public:
  TrainerSegment():
    segment(),
    ratio(),
    threadsCount(1),
    measuresPerBlock(),
    trainsPerBlock(),
    blocksPerSaving(),
    blocksCount(),
    qmin(),
    x(), y(), z(),
    currentBlock(),
    currentBlockMeasureOnly()
    { }


  // Runs the whole training/measuring session.
  //
  // Returns the quality of the last processed block, or a pair of
  // Quality::bad() when prepare(), prepareBlock() or Segment::save() fails.
  QualityPair run() {
    // Effective settings: without a positive ratio nothing is trained, and a
    // pure measuring run with no explicit block count is limited to one block.
    const int effectiveTrains = ratio > 0 ? trainsPerBlock : 0;
    const int effectiveBlocks = (effectiveTrains > 0 || blocksCount > 0) ? blocksCount : 1;

    assert(segment);
    assert(threadsCount > 0);
    assert(measuresPerBlock >= 0);
    assert(effectiveTrains >= 0);
    assert(measuresPerBlock + effectiveTrains > 0);

    QualityPair bad(Quality::bad(), Quality::bad());

    printf("training segment: threads %d, trainsPerBlock %d, measuresPerBlock %d, ratio: %lf\n", threadsCount, effectiveTrains, measuresPerBlock, ratio);
    fflush(stdout);

    qualities.clear();
    qualities.resize(threadsCount);
    segment->split(threadsCount);

    if (!prepare()) {
      printf("cannot prepare\n");
      return bad;
    }

    QualityPair result = bad, best = result, saved = result;
    long long fullTimeStartUs = timeUs();
    int i = 0; // block index before the increment below, block count after it
    const int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
    int nextSave = i + bps;
    while(true) {
      // The very first block only measures (when measuring is enabled), and so
      // does every block when there is nothing to train.
      const bool measureOnly = measuresPerBlock > 0 && (!i || effectiveTrains <= 0);

      if (!prepareBlock(i, measureOnly)) {
        printf("cannot prepare block\n");
        result = bad;
        break;
      }

      long long runTimeUs = timeUs();
      result = runThreads(i, measureOnly);
      runTimeUs = timeUs() - runTimeUs;

      finishBlock(i);

      // Time since the end of the previous block (hooks included).
      const long long now = timeUs();
      const long long fullTimeUs = now - fullTimeStartUs;
      fullTimeStartUs = now;
      ++i;

      // Quality used for reporting and for the stop condition: the measured
      // one when measuring is enabled, the training one otherwise.
      Quality q = measuresPerBlock > 0 ? result.measure : result.train;

      // The first block's result becomes the baseline a later result has to
      // beat before anything is saved.
      if (i == 1) saved = result;
      const bool good = result < best;
      const bool done = (effectiveBlocks > 0 && i >= effectiveBlocks) || q.human <= qmin;
      const bool saving = !measureOnly && ratio > 0 && (i >= nextSave || done) && result < saved;
      if (good) best = result;

      Quality bq = measuresPerBlock > 0 ? best.measure : best.train;

      printf("%4d, total %7d, avg.result %12g (%12g), best %12g (%12g), time: %f / %f%s\n",
        i, i*effectiveTrains,
        q.human, q.train, bq.human, bq.train,
        runTimeUs*0.000001, fullTimeUs*0.000001,
        (saving ? ", saving" : "" ) );
      fflush(stdout);

      if (saving) {
        if (!segment->save()) {
          printf("saving failed\n");
          result = bad;
          break;
        }
        saved = result;
        nextSave += bps;
      }

      if (done) break;
    }

    finish();

    return result;
  }
};


#endif