// (repository web-view residue: "Blob Blame Raw")
#ifndef TRAIN_DIGIT_INC_CPP
#define TRAIN_DIGIT_INC_CPP


#include "train.inc.cpp"
#include "layer.simple.inc.cpp"


class TrainerDigit: public Trainer {
protected:
  std::vector<unsigned char> data;
  std::vector<unsigned int> shuffle;
  Layout ofl, obl;
  Layout::List oflist, oblist;
  int stride, count;

public:
  TrainerDigit(): stride(), count() { }

  bool loadSymbolMap(const char *filename) {
    data.clear();

    FILE *f = fopen(filename, "rb");
    if (!f)
      return printf("cannot open file for read: %s\n", filename), false;
    fseek(f, 0, SEEK_END);
    size_t fs = ftello(f);
    fseek(f, 0, SEEK_SET);

    data.resize(fs, 0);
    if (!fread(data.data(), fs, 1, f))
      return printf("cannot read from file: %s\n", filename), fclose(f), data.clear(), false;

    fclose(f);
    return true;
  }
  
  static void printSymbol(const unsigned char *data, int w, int h, int index = -1) {
    if (index >= 0) printf("\nsymbol %d (%d):\n", (int)data[w*h], index);
               else printf("\nsymbol %d:\n", (int)data[w*h]);
    for(int i = 0; i < h; ++i) {
      for(int j = 0; j < w; ++j) printf("%s", data[i*w+j] > 128u ? "#" : " ");
      printf("\n");
    }
    printf("\n");
  }

  void printSymbol(int index) {
    const Layout &l = layer->layout;
    printSymbol(&data[(l.getActiveCount()+1)*index], l.getW(), l.getH(), index);
  }

protected:
  bool prepare() override {
    ofl = optimizeLayoutSimple(fl->layout);
    obl = optimizeLayoutSimple(bl->layout);
    ofl.split(oflist, threadsCount);
    obl.split(oblist, threadsCount);
    stride = ofl.getActiveCount() + 1;
    count = data.size()/stride;
    if (count <= 0) return false;
    shuffle.resize(count);
    for(int i = 0; i < count; ++i)
      shuffle[i] = i;
    return true;
  }


  bool prepareBlock() override {
    int cnt = itersPerBlock > count ? count : itersPerBlock;
    for(int i = 0; i < cnt; ++i) {
      int j = rand()%count;
      if (i != j) std::swap(shuffle[i], shuffle[j]);
    }
    return true;
  }


  void loadData(Barrier &barrier, int, int iter) override {
    struct I: public Iter {
      typedef const unsigned char* DataType;
      static inline void iter4(Neuron &n, DataType d, DataAccumType&) { n.v = *d/(NeuronReal)255; }
    };
    const unsigned char *id = data.data() + shuffle[iter%count]*stride;
    iterateNeurons2<I>(oflist[barrier.tid], ofl, fl->neurons, id);
  }


  Quality verifyData(Barrier &barrier, int, int iter) override {
    Quality q;
    if (barrier.tid) return q;
    
    struct I: public Iter {
      typedef int DataType;
      struct DataAccumType { int ri, mi; NeuronReal m, ratio, q; };
      static inline void iter4(Neuron &n, DataType d, DataAccumType &a) {
        NeuronReal v1 = d == a.ri;
        NeuronReal v0 = n.v;
        NeuronReal diff = v1 - v0;
        n.d *= diff*a.ratio;
        a.q += diff*diff;
        if (a.m < v0) { a.m = v0; a.mi = d; }
      }
    };
    
    int index = shuffle[iter%count];
    if (index == 59915) {
      ++skipBackpass;
      return q;
    }
    
    I::DataAccumType a = { data[ (index + 1)*stride - 1 ], 0, 0, ratio };
    iterateNeurons2<I>(obl, obl, bl->neurons, 0, 1, &a);
    
    q.train = sqrt(a.q/obl.getActiveCount());
    q.human = a.mi != a.ri;
    //if (!q.human && q.train < 0.01) ++skipBackpass;
    //if (!q.human) ++skipBackpass;
    //if (q.human) printSymbol(index);
    return q;
  }
};


#endif