Blame projects/neural/train.inc.cpp

Ivan Mahonin 036a8f
#ifndef TRAIN_INC_CPP
Ivan Mahonin 036a8f
#define TRAIN_INC_CPP
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
#include "layer.inc.cpp"
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e4740d
class Trainer: public ThreadControl {
Ivan Mahonin e865c9
private:
Ivan Mahonin 8e5348
  std::vector<Quality> qualities;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
public:
Ivan Mahonin e865c9
  Layer *layer;
Ivan Mahonin e865c9
  AccumReal ratio;
Ivan Mahonin e865c9
  int threadsCount;
Ivan Mahonin e865c9
  int itersPerBlock;
Ivan Mahonin e865c9
  int blocksPerSaving;
Ivan Mahonin e865c9
  int blocksCount;
Ivan Mahonin e865c9
  AccumReal qmin;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
protected:
Ivan Mahonin 8e5348
  std::atomic<unsigned int> skipBackpass;
Ivan Mahonin e865c9
  Layer *fl;
Ivan Mahonin e865c9
  Layer *bl;
Ivan Mahonin b579b3
  Layer *ffl;
Ivan Mahonin e4740d
  int currentBlock;
Ivan Mahonin e865c9
  
Ivan Mahonin e865c9
  virtual bool prepare() { return true; }
Ivan Mahonin e865c9
  virtual bool prepareBlock() { return true; }
Ivan Mahonin e865c9
  virtual void finishBlock() { }
Ivan Mahonin e865c9
  virtual void finish() { }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  virtual void loadData(Barrier &barrier, int block, int iter) { }
Ivan Mahonin 8e5348
  virtual Quality verifyData(Barrier &barrier, int block, int iter) { return Quality{}; }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
private:
Ivan Mahonin e4740d
  void threadFunc(Barrier &barrier) override {
Ivan Mahonin e4740d
    int block = currentBlock;
Ivan Mahonin b579b3
    Quality sumQ;
Ivan Mahonin e865c9
    for(int i = 0; i < itersPerBlock; ++i) {
Ivan Mahonin e865c9
      barrier.wait();
Ivan Mahonin e865c9
      loadData(barrier, block, i);
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      for(Layer *l = fl->next; l; l = l->next) {
Ivan Mahonin e865c9
        barrier.wait();
Ivan Mahonin e865c9
        l->pass(barrier);
Ivan Mahonin e865c9
      }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      barrier.wait();
Ivan Mahonin 8e5348
      skipBackpass = 0;
Ivan Mahonin e865c9
      sumQ += verifyData(barrier, block, i);
Ivan Mahonin 8e5348
      
Ivan Mahonin e865c9
      barrier.wait();
Ivan Mahonin 8e5348
      bool skipBp = skipBackpass;
Ivan Mahonin e865c9
      
Ivan Mahonin e865c9
      barrier.wait();
Ivan Mahonin b579b3
      if (ffl && ratio > 0 && !skipBp) {
Ivan Mahonin b579b3
        for(Layer *l = bl; l != ffl; l = l->prev) {
Ivan Mahonin e865c9
          barrier.wait();
Ivan Mahonin e865c9
          l->backpassDeltas(barrier);
Ivan Mahonin e865c9
        }
Ivan Mahonin e865c9
        for(Layer *l = bl; l->prev; l = l->prev) {
Ivan Mahonin b579b3
          if (!l->skipTrain) {
Ivan Mahonin b579b3
            barrier.wait();
Ivan Mahonin b579b3
            l->backpassWeights(barrier);
Ivan Mahonin b579b3
          }
Ivan Mahonin e865c9
        }
Ivan Mahonin e865c9
      }
Ivan Mahonin e865c9
    }
Ivan Mahonin e4740d
    qualities[barrier.tid] = sumQ;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin 8e5348
  Quality runThreads(int block) {
Ivan Mahonin e4740d
    currentBlock = block;
Ivan Mahonin e4740d
    ThreadControl::runThreads(threadsCount);
Ivan Mahonin e4740d
    Quality q;
Ivan Mahonin e4740d
    for(int i = 0; i < threadsCount; ++i) q += qualities[i];
Ivan Mahonin e4740d
    return q *= 1/(AccumReal)itersPerBlock;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
public:
Ivan Mahonin e865c9
  Trainer():
Ivan Mahonin e865c9
    layer(),
Ivan Mahonin e865c9
    ratio(),
Ivan Mahonin e865c9
    threadsCount(1),
Ivan Mahonin e865c9
    itersPerBlock(100),
Ivan Mahonin e865c9
    blocksPerSaving(),
Ivan Mahonin e865c9
    blocksCount(1000),
Ivan Mahonin e865c9
    qmin(),
Ivan Mahonin 8e5348
    skipBackpass(0),
Ivan Mahonin e865c9
    fl(),
Ivan Mahonin e865c9
    bl() { }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Trainer& configure(
Ivan Mahonin e865c9
    Layer &layer,
Ivan Mahonin e865c9
    AccumReal ratio,
Ivan Mahonin e865c9
    int threadsCount,
Ivan Mahonin e865c9
    int itersPerBlock,
Ivan Mahonin e865c9
    int blocksPerSaving,
Ivan Mahonin e865c9
    int blocksCount,
Ivan Mahonin e865c9
    AccumReal qmin )
Ivan Mahonin e865c9
  {
Ivan Mahonin e865c9
    this->layer           = &layer;
Ivan Mahonin e865c9
    this->ratio           = ratio;
Ivan Mahonin e865c9
    this->threadsCount    = threadsCount;
Ivan Mahonin e865c9
    this->itersPerBlock   = itersPerBlock;
Ivan Mahonin e865c9
    this->blocksPerSaving = blocksPerSaving;
Ivan Mahonin e865c9
    this->blocksCount     = blocksCount;
Ivan Mahonin e865c9
    this->qmin            = qmin;
Ivan Mahonin e865c9
    return *this;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin 8e5348
  Quality run() {
Ivan Mahonin e865c9
    assert(layer && !layer->prev && layer->next);
Ivan Mahonin e865c9
    assert(threadsCount > 0);
Ivan Mahonin e865c9
    assert(itersPerBlock > 0);
Ivan Mahonin e865c9
Ivan Mahonin 8e5348
    Quality bad = {INFINITY, INFINITY};
Ivan Mahonin 8e5348
    
Ivan Mahonin e865c9
    printf("training: threads %d, itersPerBlock %d, ratio: %lf\n", threadsCount, itersPerBlock, ratio);
Ivan Mahonin 8e5348
    fflush(stdout);
Ivan Mahonin e865c9
Ivan Mahonin e865c9
    fl = layer;
Ivan Mahonin e865c9
    bl = &layer->back();
Ivan Mahonin b579b3
    ffl = fl->next;
Ivan Mahonin b579b3
    while(ffl && ffl->skipTrain) ffl = ffl->next;
Ivan Mahonin e865c9
    
Ivan Mahonin e865c9
    qualities.clear();
Ivan Mahonin 8e5348
    qualities.resize(threadsCount, Quality{});
Ivan Mahonin e865c9
    for(Layer *l = layer; l; l = l->next)
Ivan Mahonin e865c9
      l->split(threadsCount);
Ivan Mahonin e865c9
    
Ivan Mahonin e865c9
    if (!prepare())
Ivan Mahonin 8e5348
      return printf("cannot prepare\n"), bad;
Ivan Mahonin e865c9
Ivan Mahonin 8e5348
    
Ivan Mahonin 8e5348
    
Ivan Mahonin 8e5348
    AccumReal ratioCopy = ratio;
Ivan Mahonin 8e5348
    Quality result = bad, best = result, saved = result;
Ivan Mahonin e865c9
    long long fullTimeStartUs = timeUs();
Ivan Mahonin 8e5348
    ratio = 0;
Ivan Mahonin e865c9
    int i = 0;
Ivan Mahonin 8e5348
    int bps = blocksPerSaving > 0 ? blocksPerSaving : 1;
Ivan Mahonin 8e5348
    int nextSave = i + bps;
Ivan Mahonin e865c9
    while(true) {
Ivan Mahonin e865c9
      if (!prepareBlock()) {
Ivan Mahonin e865c9
        printf("cannot prepare block\n");
Ivan Mahonin 8e5348
        result = bad;
Ivan Mahonin e865c9
        break;
Ivan Mahonin e865c9
      };
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      long long runTimeUs = timeUs();
Ivan Mahonin e865c9
      result = runThreads(i);
Ivan Mahonin e865c9
      runTimeUs = timeUs() - runTimeUs;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      finishBlock();
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      long long t = timeUs();
Ivan Mahonin e865c9
      long long fullTimeUs = t - fullTimeStartUs;
Ivan Mahonin e865c9
      fullTimeStartUs = t;
Ivan Mahonin e865c9
      ++i;
Ivan Mahonin e865c9
Ivan Mahonin 8e5348
      if (i == 1) saved = result;
Ivan Mahonin 8e5348
      bool good = result < best;
Ivan Mahonin 8e5348
      bool done = (blocksCount > 0 && i >= blocksCount) || result.human <= qmin;
Ivan Mahonin 8e5348
      bool saving = ratio > 0 && (i >= nextSave || done) && result < saved;
Ivan Mahonin 8e5348
      if (good) best = result;
Ivan Mahonin 8e5348
      
Ivan Mahonin 8e5348
      printf("%4d, total %7d, avg.result %f (%f), best %f (%f), time: %f / %f%s\n",
Ivan Mahonin 8e5348
        i, i*itersPerBlock,
Ivan Mahonin 8e5348
        result.human, result.train, best.human, best.train,
Ivan Mahonin 8e5348
        runTimeUs*0.000001, fullTimeUs*0.000001,
Ivan Mahonin 8e5348
        (saving ? ", saving" : "" ) );
Ivan Mahonin 8e5348
      fflush(stdout);
Ivan Mahonin 8e5348
Ivan Mahonin 8e5348
      if (saving) {
Ivan Mahonin 8e5348
        if (!layer->save()) {
Ivan Mahonin 8e5348
          printf("saving failed\n");
Ivan Mahonin 8e5348
          result = bad;
Ivan Mahonin 8e5348
          break;
Ivan Mahonin 8e5348
        }
Ivan Mahonin 8e5348
        saved = result;
Ivan Mahonin 8e5348
        nextSave += bps;
Ivan Mahonin e865c9
      }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
      if (done) break;
Ivan Mahonin 8e5348
      ratio = ratioCopy;
Ivan Mahonin e865c9
    }
Ivan Mahonin 8e5348
    ratio = ratioCopy;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
    finish();
Ivan Mahonin e865c9
Ivan Mahonin e865c9
    return result;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
};
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
#endif