// SegmentLayer: a Segment that runs a chain of Layer objects.
#ifndef SEGMENT_LAYER_INC_CPP
#define SEGMENT_LAYER_INC_CPP


#include "segment.inc.cpp"
#include "layer.inc.cpp"


// Segment that drives a chain of Layer objects linked through Layer::next:
// a forward pass over the layers after `fl` up to and including `bl`, and,
// when training is requested, a backward pass.
class SegmentLayer: public Segment {
public:
  Layer *fl; // front layer; the forward pass starts at fl->next
  Layer *bl; // back layer; the forward pass stops after processing it
  
  // Builds a segment over the chain starting at `fl` and ending at `bl`.
  // The Segment base is constructed from fl's layout W/H/D, weightsCount and
  // weights; `bl` must have the same layout W/H/D as `fl` (asserted).
  // Only the addresses of `fl` and `bl` are stored, so both layers must
  // outlive this object.
  SegmentLayer(Layer &fl, Layer &bl):
    Segment(fl.layout.getW(), fl.layout.getH(), fl.layout.getD(), fl.weightsCount, fl.weights),
    // The parameters shadow the members of the same name: without these
    // initializers the member pointers stay indeterminate and pass() would
    // dereference garbage.
    fl(&fl),
    bl(&bl)
  {
    assert(fl.layout.getW() == bl.layout.getW());
    assert(fl.layout.getH() == bl.layout.getH());
    assert(fl.layout.getD() == bl.layout.getD());
    filename = fl.filename;
  }
  
  // Convenience form: the back layer is whatever layer.back() returns.
  SegmentLayer(Layer &layer): SegmentLayer(layer, layer.back()) { }
  
  
  // Runs one forward pass and, if trainRatio > 0, the backward pass.
  //   barrier    - waited on before every per-layer step and forwarded to it
  //   x, y, z    - only handed to check() here
  //   trainRatio - values <= 0 skip the backward pass entirely
  // Returns a default-constructed Quality (no verification is implemented
  // yet), or Quality::nan() when the chain ends before `bl` is reached.
  Quality pass(Barrier &barrier, int x, int y, int z, NeuronReal trainRatio) override {
    check(x, y, z); // presumably validates the coordinates - defined outside this file
    
    // copy values
    // TODO: not implemented - nothing is copied into `fl` here.
    
    // pass
    
    // ffl starts at fl and is overwritten by every visited layer that has
    // skipTrain == false, so it ends up as the LAST such layer.
    // NOTE(review): the name and its use below suggest the FIRST trainable
    // layer was intended - confirm.
    Layer *ffl = fl;
    for(Layer *l = fl->next; l; l = l->next) {
      if (!l->skipTrain) ffl = l;
      barrier.wait();
      l->pass(barrier);
      if (l == bl) break;
      // Reached the end of the chain without meeting bl.
      // NOTE(review): this returns after barrier.wait() calls; confirm every
      // thread sharing the barrier takes the same exit.
      if (!l->next) return Quality::nan();
    }
    
    // verify
    // TODO: not implemented - q is returned as default-constructed.
    
    Quality q;
    
    if (trainRatio <= 0) return q;
    
    // back pass deltas
    
    // NOTE(review): this walks FORWARD (->next) from bl. Unless ffl == bl,
    // `l` never meets ffl on a non-circular chain and becomes null, which is
    // then dereferenced. A backward link (e.g. a `prev` pointer, if Layer has
    // one) is probably what was meant - confirm against Layer.
    for(Layer *l = bl; l != ffl; l = l->next) {
      barrier.wait();
      l->backpassDeltas(barrier);
    }
    
    // back pass weights

    // NOTE(review): this also walks forward from bl and calls
    // backpassDeltas() again; the heading suggests a weights-update method
    // was intended - confirm against Layer.
    for(Layer *l = bl; l; l = l->next) {
      if (!l->skipTrain) {
        barrier.wait();
        l->backpassDeltas(barrier);
      }
    }
    
    return q;
  }
};




#endif