// layer.conv.test.inc.cpp — tests for convolution layer iterators and layers
#ifndef LAYER_CONV_TEST_INC_CPP
#define LAYER_CONV_TEST_INC_CPP



#include "layer.test.inc.cpp"
#include "layer.conv.inc.cpp"
#include "layer.conv.shared.inc.cpp"


class ConvTest: public LayerTest {
public:
  static void init(const Layout &cl, const Layout &pl, const Kernel &k, bool shared = false)
    { Test::init(cl.getCount(), pl.getCount(), (shared ? 1 : cl.getActiveCount())*k.sx*k.sy*pl.getD()); }


  static bool verifyWeights(const char *name, const Layout &cl, const Layout &pl, const Kernel &k) {
    Stage st(name);
    for(int cy = cl.y0; cy < cl.y1; ++cy)
    for(int cx = cl.x0; cx < cl.x1; ++cx)
    for(int cz = cl.z0; cz < cl.z1; ++cz) {
      int ci = (cy*cl.sx + cx)*cl.sz + cz;
      for(int ky = 0; ky < k.sy; ++ky)
      for(int kx = 0; kx < k.sx; ++kx)
      for(int pz = pl.z0; pz < pl.z1; ++pz) {
        int wi = ((cy - cl.y0)*cl.getW() + cx - cl.x0)*cl.getD() + cz - cl.z0;
        wi = ((wi*k.sy + ky)*k.sx + kx)*pl.getD() + pz - pl.z0;

        int px = pl.x0 + (cx - cl.x0)*k.dx + k.ox + kx;
        int py = pl.y0 + (cy - cl.y0)*k.dy + k.oy + ky;
        if ( px < pl.x0 || px >= pl.x1
          || py < pl.y0 || py >= pl.y1 ) continue;

        int pi = (py*pl.sx + px)*pl.sz + pz;

        int s = (int)p_neurons.size();
        int w = weights[wi].i;
        int i = ci*s + pi + 1;

        if (w != i) {
          int ww = w;
          int wpz = ww%pl.sz; ww /= pl.sz;
          int wpx = ww%pl.sx; ww /= pl.sx;
          int wpy = ww%pl.sy; ww /= pl.sy;
          int wcz = ww%cl.sz; ww /= cl.sz;
          int wcx = ww%cl.sx; ww /= cl.sx;
          int wcy = ww;

          printf(
            "wrong index: %d = ((%d*%d + %d)*%d + %d)*%d + (%d*%d + %d)*%d + %d + 1,\n"
            "expected:    %d = ((%d*%d + %d)*%d + %d)*%d + (%d*%d + %d)*%d + %d + 1\n"
            "wi %d, ky %d, kx %d \n",
            w,
            wcy, cl.sx, wcx, cl.sz, wcz, s,
            wpy, pl.sx, wpx, pl.sz, wpz,
            i,
            cy, cl.sx, cx, cl.sz, cz, s,
            py, pl.sx, px, pl.sz, pz,
            wi, ky, kx );
          pl.printYXZ("prev layout");
          cl.printYXZ("curr layout");
          k.printYX("kernel");
          ++errors;
          return st;
        }
      }
    }
    return st;
  }

  static bool testIterators(const char *name, const Layout &cl, const Layout &pl, const Kernel &k, int threads) {
    Stage st(name);

    assert(cl && pl && k && threads > 0);
    Layout::List clist, plist;
    cl.split(clist, threads);
    pl.split(plist, threads);
    
    struct I: public Iter {
      static inline void init(Neuron &n, Iter::AccumType &a) { ++n.a.i; a.i = (AccumInt)(&n - c_neurons.data()); }
      static inline void iter(Neuron &n, Weight &w, Iter::AccumType &a) {
        if (w.i)
          ++errors;
        w.i = (WeightInt)(&n - p_neurons.data() + a.i*p_neurons.size() + 1);
      }
      static inline void iter2(Neuron &cn, Neuron &pn, Weight &w) {
        if (w.i)
          ++errors;
        w.i = (WeightInt)((&cn - c_neurons.data())*p_neurons.size() + &pn - p_neurons.data() + 1);
        pn.v = pn.v + 1;
      }
    };

    if (threads == 1) {
      Stage st("iterateTestConvolution");
      init(cl, pl, k);
      iterateTestConvolution<I>(cl, pl, k, c_neurons.data(), p_neurons.data(), weights.data());
      verifyNeurons("conv-neurons", cl, c_neurons.data());
      verifyWeights("conv-weights", cl, pl, k);
    }

    {
      Stage st("iterateConvolution");
      init(cl, pl, k);
      for(int i = 0; i < threads; ++i)
        iterateConvolution<I>(clist[i], pl, cl, k, c_neurons.data(), p_neurons.data(), weights.data());
      verifyNeurons("conv-neurons", cl, c_neurons.data());
      verifyWeights("conv-weights", cl, pl, k);
    }

    {
      Stage st("iterateConvolutionPoint");
      init(cl, pl, k);
      int e = errors;
      for(int ky = 0; ky < k.sy && errors == e; ++ky)
      for(int kx = 0; kx < k.sx && errors == e; ++kx) {
        for(int i = 0; i < threads; ++i)
          iterateConvolutionPoint<I>(clist[i], pl, cl, k, kx, ky, c_neurons.data(), p_neurons.data(), weights.data());
        if (!verifyNeuronsAccum(pl, p_neurons.data(), cl.getD(), true))
          printf("kx: %d, ky: %d\n", kx, ky), k.printYX("kernel");
      }
      verifyNeurons("conv-neurons", pl, p_neurons.data(), true);
      verifyWeights("conv-weights", cl, pl, k);
    }

    return st;
  }

  static bool testIterators(const char *name, const Layout &cl, const Layout &pl, const Kernel &k) {
    Stage st(name);
    testIterators( "single-thread", cl, pl, k,   1 );
    testIterators( "2-threads",     cl, pl, k,   2 );
    testIterators( "7-threads",     cl, pl, k,   7 );
    testIterators( "8-threads",     cl, pl, k,   8 );
    testIterators( "512-threads",   cl, pl, k, 512 );
    return st;
  }
  
  static bool test(const char *name, const Layout &cl, const Layout &pl, const Kernel &k) {
    Stage st(name);
    
    testIterators("iterators", cl, pl, k);

    {
      Layer l(nullptr, pl);
      new LayerConv<funcSigmoidExp>(l, cl, k);
      testLayer("LayerConv", l);
    }

    {
      Layer l(nullptr, cl);
      new LayerDeconv<funcSigmoidExp>(l, pl, k);
      testLayer("LayerDeconv", l);
    }

    {
      Layer l(nullptr, pl);
      new LayerConvShared<funcSigmoidExp>(l, cl, k);
      testLayer("LayerConvShared", l);
    }

    {
      Layer l(nullptr, cl);
      new LayerDeconvShared<funcSigmoidExp>(l, pl, k);
      testLayer("LayerDeconvShared", l);
    }

    return st;
  }

  static bool test(const char *name = "convolution") {
    Stage st(name);
    test( "square", Layout(64, 64, 4), Layout(128, 128, 4).expandXY(2),                 Kernel(5, 2, -2)           );
    test( "rect1",  Layout(63, 43, 5), Layout( 63,  85, 3).expandX(2).   expandY(3),    Kernel(5, 7, 1, 2, -2, -3) );
    test( "rect2",  Layout(43, 63, 3), Layout( 85,  63, 5).expandX(3).   expandY(3, 2), Kernel(7, 5, 2, 1, -3, -2) );
    test( "rect3",  Layout(64, 48, 5), Layout( 64,  96, 3).expandX(1, 2).expandY(3, 1), Kernel(4, 6, 1, 2, -1, -3) );
    test( "pad",    Layout(64, 48, 5).expandX(3, 4).expandY(4, 3).expandZ(5, 4),
                    Layout(64, 96, 3).expandX(6, 5).expandY(7, 6).expandZ(0, 1),        Kernel(4, 6, 1, 2, -1, -3) );
    return st;
  }
};


#endif