// (removed stray web-scrape artifact "Blob Blame Raw" — GitHub file-view chrome, not source code)
#ifndef LAYER_INC_CPP
#define LAYER_INC_CPP


#include "common.inc.cpp"



// Base class that owns the serialization contract for a flat array of
// Weight entries: raw binary save/load to an optional file path, plus a
// virtual hook for emitting a human-viewable "demo" of the weights.
// Note: this class does NOT own `weights`; subclasses (see Layer) decide
// allocation and lifetime.
class WeightHolder {
public:
  // Number of Weight entries covered by save()/load(); never negative.
  const int weightsCount;
  // Non-owning pointer to the weight array; may be filled in by a subclass
  // after construction (it is left mutable for exactly that reason).
  Weight *weights;

  // Optional path used by save()/load(); nullptr disables file persistence.
  const char *filename;

  explicit WeightHolder(int weightsCount = 0, Weight *weights = nullptr):
    weightsCount(weightsCount), weights(weights), filename()
    { assert(weightsCount >= 0); }


  virtual ~WeightHolder() { }


  // Writes the whole weight array to `filename` as one raw binary chunk,
  // then delegates to saveDemo(). The file step is skipped when there is
  // no filename, no weights, or demoOnly is set.
  // Returns false after reporting the failure on stderr; true otherwise.
  bool save(bool demoOnly = false) {
    if (filename && weightsCount && !demoOnly) {
      FILE *f = fopen(filename, "wb");
      if (!f)
        return fprintf(stderr, "cannot open file for write: %s\n", filename), false;
      // One fwrite of the entire array: count is 1, so the return value is
      // 1 on a complete write and 0 on any failure.
      if (!fwrite(weights, sizeof(*weights)*weightsCount, 1, f))
        return fclose(f), fprintf(stderr, "cannot write to file: %s\n", filename), false;
      fclose(f);
    }
    return saveDemo();
  }


  // Reads the weight array back from `filename` (inverse of save()).
  // A missing filename or an empty weight set counts as success.
  // Returns false after reporting the failure on stderr; true otherwise.
  bool load() {
    if (filename && weightsCount) {
      FILE *f = fopen(filename, "rb");
      if (!f)
        return fprintf(stderr, "cannot open file for read: %s\n", filename), false;
      if (!fread(weights, sizeof(*weights)*weightsCount, 1, f))
        return fclose(f), fprintf(stderr, "cannot read from file: %s\n", filename), false;
      fclose(f);
    }
    return true;
  }


  // Subclass hook for producing a demo/visualization of the weights;
  // the default implementation does nothing and reports success.
  virtual bool saveDemo() { return true; }
};



// One layer of the network. Layers form an intrusive doubly linked list:
// `prev` is non-owning, while `next` is owned (deleted in the destructor),
// so deleting the head layer tears down the whole chain.
class Layer: public WeightHolder {
public:
  Layer *prev, *next;  // prev: non-owning back-link; next: owned successor

  Layout layout;       // geometry descriptor for this layer's neurons

  Neuron *neurons;     // owned array of `neuronsCount` neurons (new[]/delete[])
  int neuronsCount;

  // True when this layer allocated `weights` itself (weightsCount > 0 and no
  // external buffer was passed in); controls delete[] in the destructor.
  bool ownWeights;

  bool skipTrain;      // flag consulted by training code elsewhere; defaults to false

  Stat stat;           // per-layer bookkeeping: counts and memory footprint

  // Per-thread sub-layouts produced by split() and consumed by the
  // pass/backpass overrides during multithreaded execution.
  Layout::List mtLayouts;
  Layout::List mtPrevLayouts;


  // Links this layer after the back() of `prev`'s chain (if any) and
  // allocates zero-filled neuron storage, plus weight storage when no
  // external `weights` buffer was supplied but weightsCount > 0.
  // Preconditions (asserted): non-empty layout, non-negative weight count,
  // and weights only on non-input layers (prev != nullptr).
  Layer(Layer *prev, const Layout &layout, int weightsCount = 0, Weight *weights = nullptr):
    WeightHolder(weightsCount, weights),
    prev(prev ? &prev->back() : nullptr),  // always append at the end of the chain
    next(),
    layout(layout),
    neurons(),
    neuronsCount(layout.getCount()),
    ownWeights(!weights && weightsCount),
    skipTrain()
  {
    assert(layout);
    assert(neuronsCount > 0);
    assert(weightsCount >= 0);
    assert(prev || !weightsCount);

    if (this->prev) this->prev->next = this;
    if (neuronsCount) {
      neurons = new Neuron[neuronsCount];
      // NOTE(review): memset assumes Neuron is trivially copyable (POD-like);
      // confirm against Neuron's definition in common.inc.cpp.
      memset(neurons, 0, sizeof(*neurons)*neuronsCount);
    }
    if (ownWeights) {
      this->weights = new Weight[weightsCount];
      // NOTE(review): same trivially-copyable assumption for Weight.
      memset(this->weights, 0, sizeof(*this->weights)*weightsCount);
    }

    stat.neurons = neuronsCount;
    stat.activeNeurons = layout.getActiveCount();
    stat.weights = weightsCount;
    stat.links = weightsCount;  // links == weights here; subclasses may differ
    stat.memsize = neuronsCount*sizeof(*neurons);
    if (ownWeights) stat.memsize += weightsCount*sizeof(*weights);
  }


  // Destroys the rest of the chain first, then this layer's own storage.
  // NOTE(review): `delete next` recurses once per remaining layer, so an
  // extremely deep chain could exhaust the stack — fine for typical nets.
  ~Layer() {
    if (next) delete next;
    if (neurons) delete[] neurons;
    if (ownWeights) delete[] weights;
  }


  // Chain navigation: first / last layer reachable from this one.
  inline Layer& front()
    { Layer *l = this; while(l->prev) l = l->prev; return *l; }
  inline Layer& back()
    { Layer *l = this; while(l->next) l = l->next; return *l; }
  // Aggregate statistics over this layer and all its successors.
  inline Stat sumStat() const
    { Stat s; for(const Layer *l = this; l; l = l->next) s += l->stat; return s; }

  // Chain-wide persistence: hides (intentionally) WeightHolder::save/load
  // and forwards along the list, stopping at the first failure.
  bool save(bool demoOnly = false)
    { return WeightHolder::save(demoOnly) && (!next || next->save(demoOnly)); }
  bool load()
    { return WeightHolder::load() && (!next || next->load()); }

 
  // Zeroes every neuron's accumulator (value-initialized Accum).
  void clearAccum() {
    Accum a = {};
    for(Neuron *in = neurons, *e = in + neuronsCount; in < e; ++in)
      in->a = a;
  }

  
  // Fills weights with uniform pseudo-random values in [wmin, wmax],
  // using rand() — deterministic for a fixed srand() seed.
  void fillWeights(WeightReal wmin, WeightReal wmax) {
    WeightReal k = (wmax - wmin)/RAND_MAX;
    for(Weight *iw = weights, *e = iw + weightsCount; iw < e; ++iw)
      iw->w = rand()*k + wmin;
  }
 

  // Prepares per-thread work partitions for this layer (and its input side)
  // ahead of a multithreaded pass; overridden passes read mt*Layouts.
  virtual void split(int threadsCount) {
    layout.split(mtLayouts, threadsCount);
    if (prev) prev->layout.split(mtPrevLayouts, threadsCount);
  }
  // Forward / backward computation hooks; no-ops in the base class.
  virtual void pass(Barrier &barrier) { }
  virtual void backpassWeights(Barrier &barrier) { }
  virtual void backpassDeltas(Barrier &barrier) { }
  
  // Single-threaded self-test hooks; no-ops in the base class.
  virtual void testPass() { }
  virtual void testBackpass() { }
  
  
  // Runs a full forward pass over the chain starting at this->next, on
  // `threadsCount` threads, stopping after `last` (or the final layer).
  // All threads execute the same layer sequence and rendezvous on a shared
  // Barrier between layers; the barrier wait is skipped after the final
  // processed layer since no further synchronization is needed.
  void passFull(const Layer *last = nullptr, int threadsCount = 1) {
    struct H {
      Layer &layer;
      const Layer *last;
      std::atomic<unsigned int> barrierCounter;  // shared counter backing the Barrier
      std::vector<std::thread*> threads;         // slot 0 unused: main thread runs tid 0

      H(Layer &layer, const Layer *last, int threadsCount): layer(layer), last(last), barrierCounter(0), threads(threadsCount) { }
      
      // Per-thread body: walk the chain, passing each layer and syncing
      // with the other threads between layers.
      void func(int tid, unsigned int seed) {
        Barrier barrier(barrierCounter, tid, threads.size(), seed);
        for(Layer *l = layer.next; l; l = l->next) {
          l->pass(barrier);
          if (l == last || !l->next) break;
          barrier.wait();
        }
      }
    } h(*this, last, threadsCount);
    
    // Partition every layer's work before any worker starts.
    for(Layer *l = this; l; l = l->next)
      l->split(threadsCount);
    // rand() is called only on this (launching) thread, so the per-thread
    // seeds are generated without data races.
    for(int i = 1; i < threadsCount; ++i)
      h.threads[i] = new std::thread(&H::func, &h, i, rand());
    h.func(0, rand());  // main thread participates as tid 0
    for(int i = 1; i < threadsCount; ++i)
      { h.threads[i]->join(); delete h.threads[i]; }
  }
};


#endif