// Blob Blame Raw  (navigation residue from the repository web viewer; not part of the code)
#ifndef LAYER_SUB_INC_CPP
#define LAYER_SUB_INC_CPP


#include "layer.simple.inc.cpp"


// Subsampling layer.
//
// Every neuron of this layer is fed by one 2x2 block (along x and y) of the
// previous layer at the same depth index; the previous layer must therefore be
// exactly twice as wide and twice as high and have the same depth (asserted in
// the pass methods). From each block the neuron with the largest `v` is
// selected and `func` is invoked with this layer's neuron and that value.
//
// `func` is a compile-time callable invoked as func(Neuron&, NeuronReal).
// NOTE(review): presumably it writes the neuron's `v` and derivative `d`;
// its definition is not visible here - confirm.
template<Func func>
class LayerSub: public Layer {
public:
  // Result of optimizeLayoutSimple(layout); walked by backpassDeltas().
  Layout optLayout;
  // Per-thread pieces of optLayout, filled in by split().
  Layout::List mtOptLayouts;
  // One slot per active neuron: the previous-layer neuron selected for it by
  // the most recent pass(). Slots start out as nullptr.
  std::vector<Neuron*> choosen;
  
  // Attaches this layer after `prev` with the given layout and allocates one
  // `choosen` slot for every active neuron of that layout.
  LayerSub(Layer &prev, const Layout &layout):
    Layer(&prev, layout),
    optLayout(optimizeLayoutSimple(layout)),
    choosen(layout.getActiveCount(), nullptr)
    { }


  // Splits the base layer as usual and additionally splits optLayout into
  // `threadsCount` pieces stored in mtOptLayouts.
  void split(int threadsCount) override {
    Layer::split(threadsCount);
    optLayout.split(mtOptLayouts, threadsCount);
  }

  
  // Forward pass over the part of the layer assigned to thread `barrier.tid`.
  //
  // For every neuron of that part: look at the four source neurons of its
  // 2x2 block, remember the one with the largest `v` (the first one wins ties),
  // reset `d` of all four to zero, call func() with the winning value, multiply
  // this neuron's `d` by the `d` the winner had before it was reset, and store
  // a pointer to the winner in `choosen`.
  void pass(Barrier &barrier) override {
    Layout part  = mtLayouts[barrier.tid];
    Layout src   = prev->layout;
    Layout whole = layout;
    if (!part) return;

    assert(src.getW() == whole.getW()*2);
    assert(src.getH() == whole.getH()*2);
    assert(src.getD() == whole.getD());
    assert(part.isSubLayoutOf(whole));

    // extents of this thread's part and strides inside this layer's buffer
    int partH       = part.getH();
    int partW       = part.getW();
    int partD       = part.getD();
    int rowStride   = part.sx*part.sz;        // one full buffer row
    int rowUsed     = partW*part.sz;          // portion of a row walked by this part
    int totalSpan   = partH*rowStride;        // distance from the first row to the end
    int dstRowSkip  = rowStride - rowUsed;    // applied after a row is finished
    int dstCellSkip = part.sz - partD;        // applied after a depth column is finished

    // `choosen` is packed by the active size of the whole layout
    int wholeD       = whole.getD();
    int wholeW       = whole.getW();
    int pickRowSkip  = (wholeW - partW)*wholeD;
    int pickCellSkip = wholeD - partD;
    
    // the source advances two cells per x step and two rows per y step
    int srcRowSkip  = (src.sx - partW)*src.sz*2;
    int srcCellSkip = src.sz*2 - partD;
    
    // offsets from the first neuron of a 2x2 block to its other three members
    int rightOff = src.sz;
    int belowOff = src.sx*src.sz;
    int blockOffs[3] = { rightOff, belowOff, rightOff + belowOff };

    // position of this part relative to the active origin of the whole layout
    int relX = part.x0 - whole.x0;
    int relY = part.y0 - whole.y0;
    int relZ = part.z0 - whole.z0;
    
    Neuron *dst = neurons + (part.y0*rowStride + part.x0*part.sz + part.z0);
    Neuron *blk = prev->neurons + ((src.y0 + relY*2)*src.sx + src.x0 + relX*2)*src.sz + src.z0 + relZ;
    Neuron **pick = choosen.data() + (relY*wholeW + relX)*wholeD + relZ;
    
    Neuron *partEnd = dst + totalSpan;
    while(dst < partEnd) {
      Neuron *rowEnd = dst + rowUsed;
      while(dst < rowEnd) {
        Neuron *cellEnd = dst + partD;
        while(dst < cellEnd) {
          // the first block member is the initial winner
          Neuron *best = blk;
          NeuronReal bestV = best->v, bestD = best->d;
          best->d = 0;
          
          // challenge it with the remaining three; every candidate gets d = 0
          for(int k = 0; k < 3; ++k) {
            Neuron *cand = blk + blockOffs[k];
            if (bestV < cand->v) { bestV = cand->v; bestD = cand->d; best = cand; }
            cand->d = 0;
          }
          
          func(*dst, bestV);
          dst->d *= bestD;
          *pick = best;
          
          ++dst; ++blk; ++pick;
        }
        dst += dstCellSkip; blk += srcCellSkip; pick += pickCellSkip;
      }
      dst += dstRowSkip; blk += srcRowSkip; pick += pickRowSkip;
    }
  }
  

  // Backward pass over the piece of optLayout assigned to thread `barrier.tid`:
  // copies each neuron's `d` into the previous-layer neuron recorded for it in
  // `choosen`.
  //
  // NOTE(review): `choosen` is indexed here through optLayout while pass()
  // fills it through `layout`; this assumes optimizeLayoutSimple() keeps the
  // same linear order of active neurons - confirm.
  void backpassDeltas(Barrier &barrier) override {
    Layout part  = mtOptLayouts[barrier.tid];
    Layout whole = optLayout;
    if (!part) return;

    // extents of this thread's piece and strides inside this layer's buffer
    int partH       = part.getH();
    int partW       = part.getW();
    int partD       = part.getD();
    int rowStride   = part.sx*part.sz;
    int rowUsed     = partW*part.sz;
    int totalSpan   = partH*rowStride;
    int dstRowSkip  = rowStride - rowUsed;
    int dstCellSkip = part.sz - partD;

    // `choosen` is packed by the active size of the whole (optimized) layout
    int wholeD       = whole.getD();
    int wholeW       = whole.getW();
    int pickRowSkip  = (wholeW - partW)*wholeD;
    int pickCellSkip = wholeD - partD;
    
    Neuron *cur = neurons + (part.y0*rowStride + part.x0*part.sz + part.z0);
    Neuron **pick = choosen.data() + ((part.y0 - whole.y0)*wholeW + part.x0 - whole.x0)*wholeD + part.z0 - whole.z0;
    
    Neuron *partEnd = cur + totalSpan;
    while(cur < partEnd) {
      Neuron *rowEnd = cur + rowUsed;
      while(cur < rowEnd) {
        Neuron *cellEnd = cur + partD;
        while(cur < cellEnd) {
          // a null slot means no forward pass has filled it yet
          assert(*pick);
          (*pick)->d = cur->d;
          ++cur; ++pick;
        }
        cur += dstCellSkip; pick += pickCellSkip;
      }
      cur += dstRowSkip; pick += pickRowSkip;
    }
  }

  
  // Straightforward single-threaded forward pass using plain coordinates.
  //
  // For every active neuron it selects the source neuron with the largest `v`
  // in the 2x2 block, zeroes `d` of all four, restores the winner's own `d`,
  // and calls func() with the winning value. Unlike pass() it does not touch
  // `choosen` and does not multiply this layer's `d`.
  void testPass() override {
    Layout cur = layout;
    Layout src = prev->layout;

    assert(src.getW() == cur.getW()*2);
    assert(src.getH() == cur.getH()*2);
    assert(src.getD() == cur.getD());
    
    for(int y = cur.y0; y < cur.y1; ++y)
    for(int x = cur.x0; x < cur.x1; ++x)
    for(int z = cur.z0; z < cur.z1; ++z) {
      Neuron &dst = neurons[ (y*cur.sx + x)*cur.sz + z ];

      Neuron *best = nullptr;
      NeuronReal bestV = 0, bestD = 0;

      for(int ky = 0; ky < 2; ++ky)
      for(int kx = 0; kx < 2; ++kx) {
        int sx = src.x0 + (x - cur.x0)*2 + kx;
        int sy = src.y0 + (y - cur.y0)*2 + ky;
        int sz = src.z0 + z - cur.z0;
        
        Neuron &cand = prev->neurons[ (sy*src.sx + sx)*src.sz + sz ];
        // the very first candidate is always accepted
        if (best == nullptr || bestV < cand.v) { bestV = cand.v; bestD = cand.d; best = &cand; }
        cand.d = 0;
      }
      
      assert(best);
      best->d = bestD;   // only the winner keeps its derivative
      func(dst, bestV);
    }
  }

  
  // Straightforward single-threaded backward pass matching testPass():
  // multiplies `d` of all four source neurons of every 2x2 block by the `d` of
  // the neuron they feed. testPass() leaves `d` at zero for every block member
  // except the winner.
  void testBackpass() override {
    Layout cur = layout;
    Layout src = prev->layout;

    assert(src.getW() == cur.getW()*2);
    assert(src.getH() == cur.getH()*2);
    assert(src.getD() == cur.getD());
    
    for(int y = cur.y0; y < cur.y1; ++y)
    for(int x = cur.x0; x < cur.x1; ++x)
    for(int z = cur.z0; z < cur.z1; ++z) {
      Neuron &dst = neurons[ (y*cur.sx + x)*cur.sz + z ];
      
      for(int ky = 0; ky < 2; ++ky)
      for(int kx = 0; kx < 2; ++kx) {
        int sx = src.x0 + (x - cur.x0)*2 + kx;
        int sy = src.y0 + (y - cur.y0)*2 + ky;
        int sz = src.z0 + z - cur.z0;
        
        Neuron &cand = prev->neurons[ (sy*src.sx + sx)*src.sz + sz ];
        cand.d *= dst.d;
      }
    }
  }
};


#endif