// NOTE: "Blob Blame Raw" here was repository-viewer page residue, not source code.
#ifndef BENCHMARK_INC_CPP
#define BENCHMARK_INC_CPP


#include "common.inc.cpp"
#include "layer.conv.inc.cpp"


/// Multi-threaded micro-benchmark of the convolution inner loops.
///
/// A "previous" layer (pl / pvalues) and a "current" layer (cl / cvalues) are
/// connected through kernel k and a weights array. Worker threads are started
/// via ThreadControl::runThreads() and execute threadFunc(); the elapsed time
/// and a checksum of all buffers are printed by run().
class Benchmark: public ThreadControl {
private:
  typedef int Int;
  typedef double Float;
  
  int repeats;                 // number of passes; set by init() so that links*repeats >= totalLinks
  int mode;                    // 0: accumulate into cvalues, 1: accumulate into pvalues
  Layout pl, cl;               // layouts of the previous and the current layer
  Kernel k;                    // kernel size (sx, sy), step (dx, dy) and offset (ox, oy)
  std::vector<Float> pvalues;  // previous layer values, pl.getCount() items
  std::vector<Float> cvalues;  // current layer values, cl.getCount() items
  std::vector<Float> weights;  // cl.getD()*k.sx*k.sy*pl.getD() items
  
  
  /// Benchmark body executed by every worker thread.
  ///
  /// k_sx, k_sy, c_d and p_d must equal k.sx, k.sy, cl.getD() and pl.getD()
  /// (asserted below). They are passed separately because the dispatchers
  /// call this always_inline function with literal values - presumably so the
  /// compiler can specialize the loops for the common sizes.
  ///
  /// Work is split between threads using barrier.tid and barrier.threads;
  /// the index striding assumes 0 <= tid < threads - TODO confirm in Barrier.
  __attribute__((always_inline))
  void threadFuncXYCP(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
    // local copies of the members that are read in the hot loops
    Layout pl = this->pl;
    Layout cl = this->cl;
    Kernel k = this->k;
    
    assert(k.sx == k_sx);
    assert(k.sy == k_sy);
    assert(cl.getD() == c_d);
    assert(pl.getD() == p_d);
    
    Int tid = barrier.tid;
    Int ts = barrier.threads;

    // sizes of the active area of the current layer
    Int c_w = cl.getW();
    Int c_h = cl.getH();
    Int c_wd = c_w*c_d;
    Int c_hw = c_h*c_w;
    Int c_hwd = c_h*c_wd;
    Int c_sxz = cl.sx*cl.sz;   // row stride of the current layer
    
    Int p_sxz = pl.sx*pl.sz;   // row stride of the previous layer
    Int p_szk = pl.sz*k.dx;    // previous-layer step for one current-layer column
    Int p_sxzk = p_sxz*k.dy;   // previous-layer step for one current-layer row

    Int k_syx = k_sx*k_sy;     // kernel cells
    Int k_syxd = k_syx*p_d;    // weights per one current-layer channel
    
    // pointers to the first active item of each buffer
    Float *cvalues = this->cvalues.data() + cl.y0*c_sxz + cl.x0*cl.sz + cl.z0;
    Float *pvalues = this->pvalues.data() + (pl.y0 + k.oy)*p_sxz + (pl.x0 + k.ox)*pl.sz + pl.z0;
    Float *weights = this->weights.data();
    
    if (mode == 0) {
      // every current-layer item gets a weighted sum of previous-layer items;
      // the items are striped over the threads
      for(Int r = repeats; r; --r) {
        for(Int ci = tid; ci < c_hwd; ci += ts) {
          Int cy = ci/c_wd;
          Int cx = ci%c_wd/c_d;
          Int cz = ci%c_d;
          
          Float *ic = &cvalues[ cy*c_sxz + cx*cl.sz + cz ];
          Float *ip = &pvalues[ cy*p_sxzk + cx*p_szk ];
          Float *iw = &weights[ cz*k_syxd ];

          Float a = 0;
          
          // NOTE(review): the weight read here is cz*k_syxd + pz + ki, while
          // mode 1 reads cz*k_syxd + ki*p_d + pz - confirm the difference
          // in layout is intended.
          for(Int pz = 0; pz < p_d; ++pz, ++ip, ++iw)
          for(Int ki = 0; ki < k_syx; ++ki) {
            Int ky = ki/k_sx;
            Int kx = ki%k_sx;
            a += iw[ki] * ip[ ky*p_sxz + kx*pl.sz ];
          }
          
          *ic = a;
        }
        barrier.wait2();
      }
    } else
    if (mode == 1) {
      if (c_w > 1 || c_h > 1 || pl.sx != pl.getW() || pl.sz != p_d) {
        // general case: one pass per kernel cell, separated by barrier.wait2();
        // within a pass the current-layer cells are striped over the threads
        // and each one adds its contribution to the previous layer
        for(Int r = repeats; r; --r)
        for(Int ki = 0; ki < k_syx; ++ki) {
          Int ky = ki/k_sx;
          Int kx = ki%k_sx;
          Float *ip = &pvalues[ ky*p_sxz + kx*pl.sz ];
          Float *iw = &weights[ ki*p_d ];
          for(Int ci = tid; ci < c_hw; ci += ts) {
            Int cy = ci/c_w;
            Int cx = ci%c_w;
            Float *iip = &ip[ cy*p_sxzk + cx*p_szk ];
            Float *iic = &cvalues[ cy*c_sxz + cx*cl.sz ];
            
            for(Int cz = 0; cz < c_d; ++cz) {
              Float *iiw = &iw[ cz*k_syxd ];
              Float a = iic[cz];
              
              for(Int pz = 0; pz < p_d; ++pz)
                iip[pz] += iiw[pz] * a;
            }
          }
          barrier.wait2();
        }
      } else {
        // single current-layer cell over a densely stored previous layer:
        // each of the k_syxd previous-layer items is a dot product of the
        // c_d current-layer values with c_d consecutive weights.
        //
        // The items are divided into balanced contiguous per-thread ranges
        // that together cover [0, k_syxd) exactly. The former loop
        // `for(i = cnt*tid; i < cnt; ++i)` left every thread except the
        // first one idle and dropped the k_syxd % ts remainder.
        Int begin = (Int)((long long)k_syxd*tid/ts);
        Int end = (Int)((long long)k_syxd*(tid + 1)/ts);
        for(Int r = repeats; r; --r) {
          for(Int i = begin; i < end; ++i) {
            Float *iw = &weights[ i*c_d ];
            Float a = 0;
            for(Int cz = 0; cz < c_d; ++cz)
              a += iw[ cz ] * cvalues[ cz ];
            pvalues[ i ] = a;
          }
          barrier.wait2();
        }
      }
    }
  }

  /// Dispatches on the previous-layer depth: the listed depths are forwarded
  /// as literals, any other depth is forwarded as is.
  __attribute__((always_inline))
  void threadFuncXYC(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
    switch(p_d) {
    case   3: return threadFuncXYCP( barrier, k_sx, k_sy, c_d,  3 );
    case   4: return threadFuncXYCP( barrier, k_sx, k_sy, c_d,  4 );
    case  24: return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 24 );
    // the literal must match the case label (it used to be 28, which
    // contradicts the p_d assertion in threadFuncXYCP)
    case  48: return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 48 );
    }
    threadFuncXYCP( barrier, k_sx, k_sy, c_d, p_d );
  }

  /// Dispatches on the current-layer depth: the listed depths are forwarded
  /// as literals, any other depth is forwarded as is.
  __attribute__((always_inline))
  void threadFuncXY(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
    switch(c_d) {
    case   3: return threadFuncXYC( barrier, k_sx, k_sy,  3, p_d );
    case   4: return threadFuncXYC( barrier, k_sx, k_sy,  4, p_d );
    case  24: return threadFuncXYC( barrier, k_sx, k_sy, 24, p_d );
    case  48: return threadFuncXYC( barrier, k_sx, k_sy, 48, p_d );
    }
    threadFuncXYC( barrier, k_sx, k_sy, c_d, p_d );
  }

  /// Thread entry point: reads the sizes from the stored kernel and layouts
  /// and dispatches, forwarding a 4x4 kernel as literals.
  void threadFunc(Barrier &barrier) override {
    Int k_sx = k.sx;
    Int k_sy = k.sy;
    Int c_d = cl.getD();
    Int p_d = pl.getD();

    if (k_sy == k_sx) switch(k_sx) {
    case 4: return threadFuncXY( barrier, 4, 4, c_d, p_d );
    }
    threadFuncXY( barrier, k_sx, k_sy, c_d, p_d );
  }
  

  /// Stores the configuration, fills all buffers with rand() values and
  /// chooses `repeats` so that at least totalLinks links are processed.
  ///
  /// @param mode        algorithm selector, see threadFuncXYCP
  /// @param pl          layout of the previous layer
  /// @param cl          layout of the current layer
  /// @param k           kernel connecting the layers
  /// @param totalLinks  requested amount of work, must be positive
  void init(int mode, Layout pl, Layout cl, Kernel k, long long totalLinks) {
    assert(pl);
    assert(cl);
    assert(k);
    assert(totalLinks > 0);
    assert(0 <= pl.x0 + k.ox && (cl.getW()-1)*k.dx + k.ox + k.sx <= pl.sx);
    assert(0 <= pl.y0 + k.oy && (cl.getH()-1)*k.dy + k.oy + k.sy <= pl.sy);
    
    this->mode = mode;
    this->pl = pl;
    this->cl = cl;
    this->k = k;
    
    pvalues.resize(pl.getCount());
    cvalues.resize(cl.getCount());
    weights.resize(cl.getD()*k.sx*k.sy*pl.getD());
    
    // the buffers are filled in a fixed order: pvalues, cvalues, weights
    for(Float &v : pvalues) v = rand();
    for(Float &v : cvalues) v = rand();
    for(Float &v : weights) v = rand();
    
    // links handled by one pass; repeats is totalLinks/links rounded up
    long long links = weights.size() * cl.getW() * cl.getH();
    repeats = (totalLinks - 1)/links + 1;
    
    //printf( "benchmark init: prev %lld, curr %lld, links %lld, repeats %d, total links: %lld\n",
    //  (long long)pvalues.size(), (long long)cvalues.size(), links, repeats, links*repeats );
  }

  
  /// Runs one configuration in one mode and prints the name, the mode,
  /// the elapsed seconds and a checksum of all buffers.
  void run(const char *name, int threadsCount, int mode, Layout pl, Layout cl, Kernel k, long long totalLinks) {
    init(mode, pl, cl, k, totalLinks);
    
    volatile long long t0 = timeUs();
    runThreads(threadsCount);
    volatile long long t1 = timeUs();

    // the results are summed and printed, so the computed values are observable
    Float sum = 0;
    for(Float v : pvalues) sum += v;
    for(Float v : cvalues) sum += v;
    for(Float v : weights) sum += v;
    // NOTE(review): converting a double that is out of the long long range is
    // undefined behavior; with rand()-sized inputs the sum may get that
    // large - confirm, and consider printing the sum as a floating value.
    printf("%s %d: %f, %lld\n", name, mode, (t1 - t0)*1e-6, (long long)sum);
  }

  
  /// Runs one configuration in both modes (0 and 1).
  void run(const char *name, int threadsCount, Layout pl, Layout cl, Kernel k, long long totalLinks) {
    for(int mode = 0; mode < 2; ++mode)
      run(name, threadsCount, mode, pl, cl, k, totalLinks);
  }


public:
  /// Runs the enabled benchmark configurations with the given number of
  /// threads. The configurations inside the comment block are disabled.
  void run(int threadsCount = 1) {
    //printf("run benchmark: %d\n", threadsCount);
    /*
    run( "514x3  -> 258x24", threadsCount,
         Layout(514, 514,  3).expandXY(2),
         Layout(258, 258, 24).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "258x24 -> 130x48", threadsCount,
         Layout(258, 258, 24).expandXY(2),
         Layout(130, 130, 48).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "130x48 -> 66x96 ", threadsCount,
         Layout(130, 130, 48).expandXY(2),
         Layout( 66,  66, 96).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "66x96  -> 34x144", threadsCount,
         Layout( 66,  66, 96).expandXY(2),
         Layout(34, 34,  144).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "34x144 -> 18x216", threadsCount,
         Layout(34, 34,  144).expandXY(2),
         Layout(18, 18,  216).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "18x216 -> 10x324", threadsCount,
         Layout(18, 18,  216).expandXY(2),
         Layout(10, 10,  324).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "10x324 -> 6x486 ", threadsCount,
         Layout(10, 10,  324).expandXY(2),
         Layout( 6,  6,  486).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    run( "6x486 -> 4x729  ", threadsCount,
         Layout( 6,  6,  486).expandXY(2),
         Layout( 4,  4,  729).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
    */
    run( "4x768 -> 1x1093 ", threadsCount,
         Layout( 4,  4,  768).expandXY(0),
         Layout( 1,  1, 1093).expandXY(0), Kernel(4, 2,  0), 10ll*1000*1000*1000 );
  }
};



#endif