Blame projects/neural/train.segment.inc.cpp

Ivan Mahonin b579b3
#ifndef TRAIN_SEGMENT_INC_CPP
Ivan Mahonin b579b3
#define TRAIN_SEGMENT_INC_CPP
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin b579b3
#include "segment.inc.cpp"
Ivan Mahonin b579b3
#include "layer.inc.cpp"
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin e4740d
class TrainerSegment: public ThreadControl {
Ivan Mahonin b579b3
private:
Ivan Mahonin b579b3
  std::vector<QualityPair> qualities;
Ivan Mahonin e4740d
  
Ivan Mahonin b579b3
public:
Ivan Mahonin b579b3
  Segment *segment;
Ivan Mahonin b579b3
  AccumReal ratio;
Ivan Mahonin b579b3
  int threadsCount;
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  int measuresPerBlock;
Ivan Mahonin b579b3
  int trainsPerBlock;
Ivan Mahonin b579b3
  int blocksPerSaving;
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  int blocksCount;
Ivan Mahonin b579b3
  AccumReal qmin;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
protected:
Ivan Mahonin b579b3
  volatile int x, y, z;
Ivan Mahonin e4740d
  int currentBlock;
Ivan Mahonin e4740d
  bool currentBlockMeasureOnly;
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  virtual bool prepare() { return true; }
Ivan Mahonin b579b3
  virtual bool prepareBlock(int block, bool measureOnly) { return true; }
Ivan Mahonin b579b3
  virtual void finishBlock(int block) { }
Ivan Mahonin b579b3
  virtual void finish() { }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
  virtual void loadData(Barrier &barrier, int block, int iter, bool measure) { }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
private:
Ivan Mahonin e4740d
  void threadFunc(Barrier &barrier) override {
Ivan Mahonin e4740d
    int block = currentBlock;
Ivan Mahonin e4740d
    bool measureOnly = currentBlockMeasureOnly;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    QualityPair q;
Ivan Mahonin b579b3
    if (!measureOnly) {
Ivan Mahonin b579b3
      for(int i = 0; i < trainsPerBlock; ++i) {
Ivan Mahonin b579b3
        barrier.wait();
Ivan Mahonin b579b3
        loadData(barrier, block, i, false);
Ivan Mahonin b579b3
        barrier.wait();
Ivan Mahonin b579b3
        q.train += segment->pass(barrier, x, y, z, ratio);
Ivan Mahonin b579b3
      }
Ivan Mahonin b579b3
    }
Ivan Mahonin b579b3
    
Ivan Mahonin b579b3
    for(int i = 0; i < measuresPerBlock; ++i) {
Ivan Mahonin b579b3
      barrier.wait();
Ivan Mahonin b579b3
      loadData(barrier, block, i, true);
Ivan Mahonin b579b3
      barrier.wait();
Ivan Mahonin b579b3
      q.measure += segment->pass(barrier, x, y, z, 0);
Ivan Mahonin b579b3
    }
Ivan Mahonin b579b3
    
Ivan Mahonin e4740d
    qualities[barrier.tid] = q;
Ivan Mahonin b579b3
  }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin b579b3
  QualityPair runThreads(int block, bool measureOnly) {
Ivan Mahonin e4740d
    currentBlock = block;
Ivan Mahonin e4740d
    currentBlockMeasureOnly = measureOnly;
Ivan Mahonin e4740d
    ThreadControl::runThreads(threadsCount);
Ivan Mahonin e4740d
    QualityPair q;
Ivan Mahonin e4740d
    for(int i = 0; i < threadsCount; ++i) q += qualities[i];
Ivan Mahonin b579b3
    q.measure *= 1/(AccumReal)measuresPerBlock;
Ivan Mahonin b579b3
    q.train *= 1/(AccumReal)trainsPerBlock;
Ivan Mahonin b579b3
    return q;
Ivan Mahonin b579b3
  }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin b579b3
public:
Ivan Mahonin b579b3
  TrainerSegment():
Ivan Mahonin b579b3
    segment(),
Ivan Mahonin b579b3
    ratio(),
Ivan Mahonin b579b3
    threadsCount(1),
Ivan Mahonin b579b3
    measuresPerBlock(),
Ivan Mahonin b579b3
    trainsPerBlock(),
Ivan Mahonin b579b3
    blocksPerSaving(),
Ivan Mahonin b579b3
    blocksCount(),
Ivan Mahonin b579b3
    qmin(),
Ivan Mahonin b579b3
    x(), y(), z()
Ivan Mahonin b579b3
    { }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin b579b3
  QualityPair run() {
Ivan Mahonin b579b3
    int trainsPerBlock = ratio > 0 ? this->trainsPerBlock : 0;
Ivan Mahonin b579b3
    int blocksCount = trainsPerBlock > 0 || this->blocksCount > 0 ? this->blocksCount : 1;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    assert(segment);
Ivan Mahonin b579b3
    assert(threadsCount > 0);
Ivan Mahonin b579b3
    assert(measuresPerBlock >= 0);
Ivan Mahonin b579b3
    assert(trainsPerBlock >= 0);
Ivan Mahonin b579b3
    assert(measuresPerBlock + trainsPerBlock > 0);
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    QualityPair bad(Quality::bad(), Quality::bad());
Ivan Mahonin b579b3
    
Ivan Mahonin b579b3
    printf("training segment: threads %d, trainsPerBlock %d, measuresPerBlock %d, ratio: %lf\n", threadsCount, trainsPerBlock, measuresPerBlock, ratio);
Ivan Mahonin b579b3
    fflush(stdout);
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    qualities.clear();
Ivan Mahonin b579b3
    qualities.resize(threadsCount);
Ivan Mahonin b579b3
    segment->split(threadsCount);
Ivan Mahonin b579b3
    
Ivan Mahonin b579b3
    if (!prepare())
Ivan Mahonin b579b3
      return printf("cannot prepare\n"), bad;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    
Ivan Mahonin b579b3
    QualityPair result = bad, best = result, saved = result;
Ivan Mahonin b579b3
    long long fullTimeStartUs = timeUs();
Ivan Mahonin b579b3
    int i = 0;
Ivan Mahonin b579b3
    int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
Ivan Mahonin b579b3
    int nextSave = i + bps;
Ivan Mahonin b579b3
    while(true) {
Ivan Mahonin b579b3
      bool measureOnly = measuresPerBlock > 0 && (!i || trainsPerBlock <= 0);
Ivan Mahonin b579b3
      
Ivan Mahonin b579b3
      if (!prepareBlock(i, measureOnly)) {
Ivan Mahonin b579b3
        printf("cannot prepare block\n");
Ivan Mahonin b579b3
        result = bad;
Ivan Mahonin b579b3
        break;
Ivan Mahonin b579b3
      };
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      long long runTimeUs = timeUs();
Ivan Mahonin b579b3
      result = runThreads(i, measureOnly);
Ivan Mahonin b579b3
      runTimeUs = timeUs() - runTimeUs;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      finishBlock(i);
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      long long t = timeUs();
Ivan Mahonin b579b3
      long long fullTimeUs = t - fullTimeStartUs;
Ivan Mahonin b579b3
      fullTimeStartUs = t;
Ivan Mahonin b579b3
      ++i;
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      Quality q = measuresPerBlock > 0 ? result.measure : result.train;
Ivan Mahonin b579b3
      
Ivan Mahonin b579b3
      if (i == 1) saved = result;
Ivan Mahonin b579b3
      bool good = result < best;
Ivan Mahonin b579b3
      bool done = (blocksCount > 0 && i >= blocksCount) || q.human <= qmin;
Ivan Mahonin b579b3
      bool saving = !measureOnly && ratio > 0 && (i >= nextSave || done) && result < saved;
Ivan Mahonin b579b3
      if (good) best = result;
Ivan Mahonin b579b3
      
Ivan Mahonin b579b3
      Quality bq = measuresPerBlock > 0 ? best.measure : best.train;
Ivan Mahonin b579b3
      
Ivan Mahonin b579b3
      printf("%4d, total %7d, avg.result %12g (%12g), best %12g (%12g), time: %f / %f%s\n",
Ivan Mahonin b579b3
        i, i*trainsPerBlock,
Ivan Mahonin b579b3
        q.human, q.train, bq.human, bq.train,
Ivan Mahonin b579b3
        runTimeUs*0.000001, fullTimeUs*0.000001,
Ivan Mahonin b579b3
        (saving ? ", saving" : "" ) );
Ivan Mahonin b579b3
      fflush(stdout);
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      if (saving) {
Ivan Mahonin b579b3
        if (!segment->save()) {
Ivan Mahonin b579b3
          printf("saving failed\n");
Ivan Mahonin b579b3
          result = bad;
Ivan Mahonin b579b3
          break;
Ivan Mahonin b579b3
        }
Ivan Mahonin b579b3
        saved = result;
Ivan Mahonin b579b3
        nextSave += bps;
Ivan Mahonin b579b3
      }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
      if (done) break;
Ivan Mahonin b579b3
    }
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    finish();
Ivan Mahonin b579b3
Ivan Mahonin b579b3
    return result;
Ivan Mahonin b579b3
  }
Ivan Mahonin b579b3
};
Ivan Mahonin b579b3
Ivan Mahonin b579b3
Ivan Mahonin b579b3
#endif