#ifndef SEGMENT_LAYER_INC_CPP
#define SEGMENT_LAYER_INC_CPP
#include "segment.inc.cpp"
#include "layer.inc.cpp"
/// A Segment spanning a chain of layers, from a front layer `fl` to a back
/// layer `bl` that is reached from `fl` by following `next` pointers.
/// Dimensions, weight count/weights and filename are taken from the front layer.
/// The layers are referenced, not owned: they must outlive this object.
class SegmentLayer: public Segment {
public:
    Layer *fl; ///< front (first) layer of the segment; set by the constructor
    Layer *bl; ///< back (last) layer of the segment; set by the constructor

    /// Builds a segment spanning `first` .. `last`.
    /// Both layers must have the same W/H/D layout (checked with assert).
    /// Parameters are named so they do not shadow the `fl`/`bl` members, which
    /// were previously left uninitialized and then dereferenced in pass().
    SegmentLayer(Layer &first, Layer &last):
        Segment(first.layout.getW(), first.layout.getH(), first.layout.getD(), first.weightsCount, first.weights),
        fl(&first),
        bl(&last)
    {
        assert(first.layout.getW() == last.layout.getW());
        assert(first.layout.getH() == last.layout.getH());
        assert(first.layout.getD() == last.layout.getD());
        filename = first.filename;
    }

    /// Builds a segment spanning `layer` .. `layer.back()`.
    SegmentLayer(Layer &layer): SegmentLayer(layer, layer.back()) { }

    /// Runs the layers of this segment forward and, when trainRatio > 0,
    /// runs the backward passes. Every layer step is preceded by
    /// barrier.wait(); presumably all threads sharing `barrier` call this in
    /// lockstep on the same segment — confirm with callers.
    /// @return Quality::nan() if `bl` is not reachable from `fl` via `next`;
    ///         otherwise a default-constructed Quality (no verification is
    ///         computed yet).
    Quality pass(Barrier &barrier, int x, int y, int z, NeuronReal trainRatio) override {
        check(x, y, z);

        // copy values
        // TODO(review): nothing is copied yet; presumably (x, y, z) should select
        // the sample loaded into `fl` — confirm.

        // Forward pass over fl->next .. bl inclusive (fl itself is not run;
        // presumably it holds the input). `ffl` ends as the last layer in that
        // range whose skipTrain is false, or fl if there is none.
        Layer *ffl = fl;
        for (Layer *l = fl; l != bl; ) {
            l = l->next;
            if (!l) return Quality::nan(); // ran off the end: bl does not follow fl
            if (!l->skipTrain) ffl = l;
            barrier.wait();
            l->pass(barrier);
        }

        // verify
        // NOTE(review): no verification is computed; q stays default-constructed.
        Quality q;
        if (trainRatio <= 0) return q;

        // back pass deltas
        // NOTE(review): this walks *forward* from bl via `next`, while ffl lies
        // at or before bl. So the loop is empty whenever bl is trainable
        // (ffl == bl). Otherwise it runs past the last layer and dereferences
        // null. A backward walk (e.g. via a prev link) down to the first
        // trainable layer looks intended — confirm against Layer.
        for (Layer *l = bl; l != ffl; l = l->next) {
            barrier.wait();
            l->backpassDeltas(barrier);
        }

        // back pass weights
        // NOTE(review): this calls backpassDeltas again (backpassWeights was
        // probably meant). It also visits bl and every layer after it rather
        // than the layers inside the segment — confirm.
        for (Layer *l = bl; l; l = l->next) {
            if (!l->skipTrain) {
                barrier.wait();
                l->backpassDeltas(barrier);
            }
        }
        return q;
    }
};
#endif