Blame projects/neural/layer.inc.cpp

Ivan Mahonin e865c9
#ifndef LAYER_INC_CPP
Ivan Mahonin e865c9
#define LAYER_INC_CPP
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin b579b3
#include "common.inc.cpp"
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin b579b3
class WeightHolder {
Ivan Mahonin e865c9
public:
Ivan Mahonin b579b3
  const int weightsCount;
Ivan Mahonin b579b3
  Weight *weights;
Ivan Mahonin 15c502
  
Ivan Mahonin b579b3
  const char *filename;
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  explicit WeightHolder(int weightsCount = 0, Weight *weights = nullptr):
Ivan Mahonin b579b3
    weightsCount(weightsCount), weights(weights), filename()
Ivan Mahonin b579b3
    { assert(weightsCount >= 0); }
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  virtual ~WeightHolder() { }
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  
Ivan Mahonin b579b3
  bool save(bool demoOnly = false) {
Ivan Mahonin b579b3
    if (filename && weightsCount && !demoOnly) {
Ivan Mahonin b579b3
      FILE *f = fopen(filename, "wb");
Ivan Mahonin b579b3
      if (!f)
Ivan Mahonin b579b3
        return printf("cannot open file for write: %s\n", filename), false;
Ivan Mahonin b579b3
      if (!fwrite(weights, sizeof(*weights)*weightsCount, 1, f))
Ivan Mahonin b579b3
        return fclose(f), printf("cannot write to file: %s\n", filename), false;
Ivan Mahonin b579b3
      fclose(f);
Ivan Mahonin b579b3
    }
Ivan Mahonin b579b3
    return saveDemo();
Ivan Mahonin b579b3
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin b579b3
  bool load() {
Ivan Mahonin b579b3
    if (filename && weightsCount) {
Ivan Mahonin b579b3
      FILE *f = fopen(filename, "rb");
Ivan Mahonin b579b3
      if (!f)
Ivan Mahonin b579b3
        return printf("cannot open file for read: %s\n", filename), false;
Ivan Mahonin b579b3
      if (!fread(weights, sizeof(*weights)*weightsCount, 1, f))
Ivan Mahonin b579b3
        return fclose(f), printf("cannot read from file: %s\n", filename), false;
Ivan Mahonin b579b3
      fclose(f);
Ivan Mahonin b579b3
    }
Ivan Mahonin b579b3
    return true;
Ivan Mahonin e865c9
  }
Ivan Mahonin b579b3
  
Ivan Mahonin e865c9
Ivan Mahonin b579b3
  virtual bool saveDemo() { return true; }
Ivan Mahonin e865c9
};
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin b579b3
Ivan Mahonin b579b3
class Layer: public WeightHolder {
Ivan Mahonin e865c9
public:
Ivan Mahonin e865c9
  Layer *prev, *next;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Layout layout;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Neuron *neurons;
Ivan Mahonin e865c9
  int neuronsCount;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  bool ownWeights;
Ivan Mahonin e865c9
Ivan Mahonin b579b3
  bool skipTrain;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Stat stat;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Layout::List mtLayouts;
Ivan Mahonin 036a8f
  Layout::List mtPrevLayouts;
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  Layer(Layer *prev, const Layout &layout, int weightsCount = 0, Weight *weights = nullptr):
Ivan Mahonin b579b3
    WeightHolder(weightsCount, weights),
Ivan Mahonin e865c9
    prev(prev ? &prev->back() : nullptr),
Ivan Mahonin e865c9
    next(),
Ivan Mahonin e865c9
    layout(layout),
Ivan Mahonin e865c9
    neurons(),
Ivan Mahonin e865c9
    neuronsCount(layout.getCount()),
Ivan Mahonin e865c9
    ownWeights(!weights && weightsCount),
Ivan Mahonin b579b3
    skipTrain()
Ivan Mahonin e865c9
  {
Ivan Mahonin e865c9
    assert(layout);
Ivan Mahonin e865c9
    assert(neuronsCount > 0);
Ivan Mahonin e865c9
    assert(weightsCount >= 0);
Ivan Mahonin 15c502
    assert(prev || !weightsCount);
Ivan Mahonin e865c9
Ivan Mahonin e865c9
    if (this->prev) this->prev->next = this;
Ivan Mahonin e865c9
    if (neuronsCount) {
Ivan Mahonin e865c9
      neurons = new Neuron[neuronsCount];
Ivan Mahonin e865c9
      memset(neurons, 0, sizeof(*neurons)*neuronsCount);
Ivan Mahonin e865c9
    }
Ivan Mahonin e865c9
    if (ownWeights) {
Ivan Mahonin e865c9
      this->weights = new Weight[weightsCount];
Ivan Mahonin e865c9
      memset(this->weights, 0, sizeof(*this->weights)*weightsCount);
Ivan Mahonin e865c9
    }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
    stat.neurons = neuronsCount;
Ivan Mahonin e865c9
    stat.activeNeurons = layout.getActiveCount();
Ivan Mahonin e865c9
    stat.weights = weightsCount;
Ivan Mahonin e865c9
    stat.links = weightsCount;
Ivan Mahonin e865c9
    stat.memsize = neuronsCount*sizeof(*neurons);
Ivan Mahonin e865c9
    if (ownWeights) stat.memsize += weightsCount*sizeof(*weights);
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin b579b3
  ~Layer() {
Ivan Mahonin e865c9
    if (next) delete next;
Ivan Mahonin e865c9
    if (neurons) delete[] neurons;
Ivan Mahonin e865c9
    if (ownWeights) delete[] weights;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  inline Layer& front()
Ivan Mahonin e865c9
    { Layer *l = this; while(l->prev) l = l->prev; return *l; }
Ivan Mahonin e865c9
  inline Layer& back()
Ivan Mahonin e865c9
    { Layer *l = this; while(l->next) l = l->next; return *l; }
Ivan Mahonin e865c9
  inline Stat sumStat() const
Ivan Mahonin e865c9
    { Stat s; for(const Layer *l = this; l; l = l->next) s += l->stat; return s; }
Ivan Mahonin e865c9
Ivan Mahonin b579b3
  bool save(bool demoOnly = false)
Ivan Mahonin b579b3
    { return WeightHolder::save(demoOnly) && (!next || next->save(demoOnly)); }
Ivan Mahonin b579b3
  bool load()
Ivan Mahonin b579b3
    { return WeightHolder::load() && (!next || next->load()); }
Ivan Mahonin e865c9
Ivan Mahonin b579b3
 
Ivan Mahonin e865c9
  void clearAccum() {
Ivan Mahonin e865c9
    Accum a = {};
Ivan Mahonin e865c9
    for(Neuron *in = neurons, *e = in + neuronsCount; in < e; ++in)
Ivan Mahonin e865c9
      in->a = a;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
Ivan Mahonin e865c9
  
Ivan Mahonin e865c9
  void fillWeights(WeightReal wmin, WeightReal wmax) {
Ivan Mahonin e865c9
    WeightReal k = (wmax - wmin)/RAND_MAX;
Ivan Mahonin e865c9
    for(Weight *iw = weights, *e = iw + weightsCount; iw < e; ++iw)
Ivan Mahonin e865c9
      iw->w = rand()*k + wmin;
Ivan Mahonin e865c9
  }
Ivan Mahonin e865c9
 
Ivan Mahonin b579b3
Ivan Mahonin 036a8f
  virtual void split(int threadsCount) {
Ivan Mahonin 036a8f
    layout.split(mtLayouts, threadsCount);
Ivan Mahonin 036a8f
    if (prev) prev->layout.split(mtPrevLayouts, threadsCount);
Ivan Mahonin 036a8f
  }
Ivan Mahonin e865c9
  virtual void pass(Barrier &barrier) { }
Ivan Mahonin e865c9
  virtual void backpassWeights(Barrier &barrier) { }
Ivan Mahonin e865c9
  virtual void backpassDeltas(Barrier &barrier) { }
Ivan Mahonin e865c9
  
Ivan Mahonin e865c9
  virtual void testPass() { }
Ivan Mahonin e865c9
  virtual void testBackpass() { }
Ivan Mahonin 036a8f
  
Ivan Mahonin 036a8f
  
Ivan Mahonin b579b3
  void passFull(const Layer *last = nullptr, int threadsCount = 1) {
Ivan Mahonin e4740d
    class H: public ThreadControl {
Ivan Mahonin e4740d
    public:
Ivan Mahonin b579b3
      Layer &layer;
Ivan Mahonin b579b3
      const Layer *last;
Ivan Mahonin b579b3
      
Ivan Mahonin e4740d
      H(Layer &layer, const Layer *last): layer(layer), last(last) { }
Ivan Mahonin b579b3
      
Ivan Mahonin e4740d
      void threadFunc(Barrier &barrier) override {
Ivan Mahonin b579b3
        for(Layer *l = layer.next; l; l = l->next) {
Ivan Mahonin b579b3
          l->pass(barrier);
Ivan Mahonin b579b3
          if (l == last || !l->next) break;
Ivan Mahonin b579b3
          barrier.wait();
Ivan Mahonin b579b3
        }
Ivan Mahonin b579b3
      }
Ivan Mahonin e4740d
    };
Ivan Mahonin e4740d
    H(*this, last).runThreads(threadsCount);
Ivan Mahonin b579b3
  }
Ivan Mahonin e865c9
};
Ivan Mahonin e865c9
Ivan Mahonin e865c9
Ivan Mahonin e865c9
#endif