Blame projects/neural/segment.layer.inc.cpp

Ivan Mahonin e4740d
#ifndef SEGMENT_LAYER_INC_CPP
Ivan Mahonin e4740d
#define SEGMENT_LAYER_INC_CPP
Ivan Mahonin e4740d
Ivan Mahonin e4740d
Ivan Mahonin e4740d
#include "segment.inc.cpp"
Ivan Mahonin e4740d
#include "layer.inc.cpp"
Ivan Mahonin e4740d
Ivan Mahonin e4740d
Ivan Mahonin e4740d
class SegmentLayer: public Segment {
Ivan Mahonin e4740d
public:
Ivan Mahonin e4740d
  Layer *fl;
Ivan Mahonin e4740d
  Layer *bl;
Ivan Mahonin e4740d
  
Ivan Mahonin e4740d
  SegmentLayer(Layer &fl, Layer &bl):
Ivan Mahonin e4740d
    Segment(fl.layout.getW(), fl.layout.getH(), fl.layout.getD(), fl.weightsCount, fl.weights)
Ivan Mahonin e4740d
  {
Ivan Mahonin e4740d
    assert(fl.layout.getW() == bl.layout.getW());
Ivan Mahonin e4740d
    assert(fl.layout.getH() == bl.layout.getH());
Ivan Mahonin e4740d
    assert(fl.layout.getD() == bl.layout.getD());
Ivan Mahonin e4740d
    filename = fl.filename;
Ivan Mahonin e4740d
  }
Ivan Mahonin e4740d
  
Ivan Mahonin e4740d
  SegmentLayer(Layer &layer): SegmentLayer(layer, layer.back()) { }
Ivan Mahonin e4740d
  
Ivan Mahonin e4740d
  
Ivan Mahonin e4740d
  Quality pass(Barrier &barrier, int x, int y, int z, NeuronReal trainRatio) override {
Ivan Mahonin e4740d
    check(x, y, z);
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    // copy values
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    // pass
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    Layer *ffl = fl;
Ivan Mahonin e4740d
    for(Layer *l = fl->next; l; l = l->next) {
Ivan Mahonin e4740d
      if (!l->skipTrain) ffl = l;
Ivan Mahonin e4740d
      barrier.wait();
Ivan Mahonin e4740d
      l->pass(barrier);
Ivan Mahonin e4740d
      if (l == bl) break;
Ivan Mahonin e4740d
      if (!l->next) return Quality::nan();
Ivan Mahonin e4740d
    }
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    // verify
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    Quality q;
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    if (trainRatio <= 0) return q;
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    // back pass deltas
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    for(Layer *l = bl; l != ffl; l = l->next) {
Ivan Mahonin e4740d
      barrier.wait();
Ivan Mahonin e4740d
      l->backpassDeltas(barrier);
Ivan Mahonin e4740d
    }
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    // beack pass weights
Ivan Mahonin e4740d
Ivan Mahonin e4740d
    for(Layer *l = bl; l; l = l->next) {
Ivan Mahonin e4740d
      if (!l->skipTrain) {
Ivan Mahonin e4740d
        barrier.wait();
Ivan Mahonin e4740d
        l->backpassDeltas(barrier);
Ivan Mahonin e4740d
      }
Ivan Mahonin e4740d
    }
Ivan Mahonin e4740d
    
Ivan Mahonin e4740d
    return q;
Ivan Mahonin e4740d
  }
Ivan Mahonin e4740d
};
Ivan Mahonin e4740d
Ivan Mahonin e4740d
Ivan Mahonin e4740d
Ivan Mahonin e4740d
Ivan Mahonin e4740d
#endif
Ivan Mahonin e4740d
Ivan Mahonin e4740d