#ifndef BENCHMARK_INC_CPP
#define BENCHMARK_INC_CPP
#include "common.inc.cpp"
#include "layer.conv.inc.cpp"
// Convolution-layer micro-benchmark. Each case fills random "previous" (pvalues),
// "current" (cvalues) and weight buffers, runs threadFunc on the requested number
// of threads (ThreadControl::runThreads presumably invokes threadFunc once per
// worker - confirm in ThreadControl), then prints the elapsed time and a checksum.
//
// Weight layout used by both passes:
//   weights[cz*k_syxd + (ky*k_sx + kx)*p_d + pz]
class Benchmark: public ThreadControl {
private:
	typedef int Int;
	typedef double Float;

	int repeats;                 // number of passes each thread performs
	int mode;                    // 0: compute cvalues from pvalues; 1: scatter cvalues into pvalues
	Layout pl, cl;               // layouts of the previous (pvalues) and current (cvalues) layers
	Kernel k;                    // kernel size / step / offset
	std::vector<Float> pvalues;
	std::vector<Float> cvalues;
	std::vector<Float> weights;

	// Worker body. The dispatchers below pass literal values for k_sx, k_sy,
	// c_d and p_d where they can; with always_inline this lets the compiler
	// specialise the inner loops for those sizes. The asserts check that the
	// literals agree with the configured layouts.
	__attribute__((always_inline))
	void threadFuncXYCP(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
		// Local copies of the members so the hot loops do not go through `this`.
		Layout pl = this->pl;
		Layout cl = this->cl;
		Kernel k = this->k;
		assert(k.sx == k_sx);
		assert(k.sy == k_sy);
		assert(cl.getD() == c_d);
		assert(pl.getD() == p_d);

		Int tid = barrier.tid;
		Int ts = barrier.threads;

		Int c_w = cl.getW();
		Int c_h = cl.getH();
		Int c_wd = c_w*c_d;
		Int c_hw = c_h*c_w;
		Int c_hwd = c_h*c_wd;
		Int c_sxz = cl.sx*cl.sz;     // cvalues row stride
		Int p_sxz = pl.sx*pl.sz;     // pvalues row stride
		Int p_szk = pl.sz*k.dx;      // pvalues step between neighbouring output columns
		Int p_sxzk = p_sxz*k.dy;     // pvalues step between neighbouring output rows
		Int k_syx = k_sx*k_sy;       // kernel cells
		Int k_syxd = k_syx*p_d;      // weights per output channel

		Float *cvalues = this->cvalues.data() + cl.y0*c_sxz + cl.x0*cl.sz + cl.z0;
		Float *pvalues = this->pvalues.data() + (pl.y0 + k.oy)*p_sxz + (pl.x0 + k.ox)*pl.sz + pl.z0;
		Float *weights = this->weights.data();

		if (mode == 0) {
			// Each output value c[cy,cx,cz] is the dot product of its weights with
			// the kernel window of pvalues; output values are dealt round-robin to threads.
			for(Int r = repeats; r; --r) {
				for(Int ci = tid; ci < c_hwd; ci += ts) {
					Int cy = ci/c_wd;
					Int cx = ci%c_wd/c_d;
					Int cz = ci%c_d;
					Float *ic = &cvalues[ cy*c_sxz + cx*cl.sz + cz ];
					Float *ip = &pvalues[ cy*p_sxzk + cx*p_szk ];
					Float *iw = &weights[ cz*k_syxd ];
					Float a = 0;
					for(Int pz = 0; pz < p_d; ++pz, ++ip, ++iw)
						for(Int ki = 0; ki < k_syx; ++ki) {
							Int ky = ki/k_sx;
							Int kx = ki%k_sx;
							// Kernel cells are p_d weights apart (same layout as mode 1).
							a += iw[ ki*p_d ] * ip[ ky*p_sxz + kx*pl.sz ];
						}
					*ic = a;
				}
				barrier.wait2();
			}
		} else
		if (mode == 1) {
			if (c_w > 1 || c_h > 1 || pl.sx != pl.getW() || pl.sz != p_d) {
				// General case: pvalues += weights * cvalues, one kernel cell at a
				// time, output cells dealt round-robin to threads. The barrier after
				// each kernel cell keeps all threads on the same cell.
				// NOTE(review): writes are race-free only if distinct output cells map
				// to distinct pvalues cells for a fixed kernel cell (k.dx, k.dy >= 1
				// and pl.sz >= p_d, presumably) - confirm.
				for(Int r = repeats; r; --r)
					for(Int ki = 0; ki < k_syx; ++ki) {
						Int ky = ki/k_sx;
						Int kx = ki%k_sx;
						Float *ip = &pvalues[ ky*p_sxz + kx*pl.sz ];
						Float *iw = &weights[ ki*p_d ];
						for(Int ci = tid; ci < c_hw; ci += ts) {
							Int cy = ci/c_w;
							Int cx = ci%c_w;
							Float *iip = &ip[ cy*p_sxzk + cx*p_szk ];
							Float *iic = &cvalues[ cy*c_sxz + cx*cl.sz ];
							for(Int cz = 0; cz < c_d; ++cz) {
								Float *iiw = &iw[ cz*k_syxd ];
								Float a = iic[cz];
								for(Int pz = 0; pz < p_d; ++pz)
									iip[pz] += iiw[pz] * a;
							}
						}
						barrier.wait2();
					}
			} else {
				// Special case: 1x1 output and a previous layer with no x/z padding.
				// pvalues is treated as one flat array of k_syxd values (this assumes
				// the kernel covers the whole previous layer). Each thread gets its
				// own contiguous slice [p_begin, p_end); the slices cover every index.
				// NOTE(review): this path reads weights transposed (weights[pi*c_d + cz])
				// and overwrites pvalues instead of accumulating - presumably a
				// memory-layout experiment rather than the same computation; confirm.
				Int p_begin = (Int)((long long)k_syxd*tid/ts);
				Int p_end   = (Int)((long long)k_syxd*(tid + 1)/ts);
				for(Int r = repeats; r; --r) {
					for(Int pi = p_begin; pi < p_end; ++pi) {
						Float *iw = &weights[ pi*c_d ];
						Float a = 0;
						for(Int cz = 0; cz < c_d; ++cz)
							a += iw[ cz ] * cvalues[ cz ];
						pvalues[ pi ] = a;
					}
					barrier.wait2();
				}
			}
		}
	}

	// Passes a literal p_d for common previous-layer depths (see threadFuncXYCP).
	__attribute__((always_inline))
	void threadFuncXYC(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
		switch(p_d) {
		case 3:  return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 3 );
		case 4:  return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 4 );
		case 24: return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 24 );
		case 48: return threadFuncXYCP( barrier, k_sx, k_sy, c_d, 48 );
		}
		threadFuncXYCP( barrier, k_sx, k_sy, c_d, p_d );
	}

	// Passes a literal c_d for common current-layer depths (see threadFuncXYCP).
	__attribute__((always_inline))
	void threadFuncXY(Barrier &barrier, int k_sx, int k_sy, int c_d, int p_d) {
		switch(c_d) {
		case 3:  return threadFuncXYC( barrier, k_sx, k_sy, 3, p_d );
		case 4:  return threadFuncXYC( barrier, k_sx, k_sy, 4, p_d );
		case 24: return threadFuncXYC( barrier, k_sx, k_sy, 24, p_d );
		case 48: return threadFuncXYC( barrier, k_sx, k_sy, 48, p_d );
		}
		threadFuncXYC( barrier, k_sx, k_sy, c_d, p_d );
	}

	// Per-thread entry point (overrides ThreadControl). Passes a literal kernel
	// size for the square 4x4 kernel, then dispatches on the layer depths.
	void threadFunc(Barrier &barrier) override {
		Int k_sx = k.sx;
		Int k_sy = k.sy;
		Int c_d = cl.getD();
		Int p_d = pl.getD();
		if (k_sy == k_sx) switch(k_sx) {
			case 4: return threadFuncXY( barrier, 4, 4, c_d, p_d );
		}
		threadFuncXY( barrier, k_sx, k_sy, c_d, p_d );
	}

	// Sets up one case: stores the geometry, sizes the buffers, fills them with
	// rand() values and chooses `repeats` = ceil(totalLinks / links), where
	// links = weights.size() * output cells (products evaluated per pass).
	void init(int mode, Layout pl, Layout cl, Kernel k, long long totalLinks) {
		assert(pl);
		assert(cl);
		assert(k);
		assert(totalLinks > 0);
		assert(0 <= pl.x0 + k.ox && (cl.getW()-1)*k.dx + k.ox + k.sx <= pl.sx);
		assert(0 <= pl.y0 + k.oy && (cl.getH()-1)*k.dy + k.oy + k.sy <= pl.sy);
		this->mode = mode;
		this->pl = pl;
		this->cl = cl;
		this->k = k;
		pvalues.resize(pl.getCount());
		cvalues.resize(cl.getCount());
		weights.resize(cl.getD()*k.sx*k.sy*pl.getD());
		for(int i = 0; i < (int)pvalues.size(); ++i) pvalues[i] = rand();
		for(int i = 0; i < (int)cvalues.size(); ++i) cvalues[i] = rand();
		for(int i = 0; i < (int)weights.size(); ++i) weights[i] = rand();
		long long links = weights.size() * cl.getW() * cl.getH();
		repeats = (totalLinks - 1)/links + 1;
		//printf( "benchmark init: prev %lld, curr %lld, links %lld, repeats %d, total links: %lld\n",
		//	(long long)pvalues.size(), (long long)cvalues.size(), links, repeats, links*repeats );
	}

	// Runs one case in one mode and prints "<name> <mode>: <seconds>, <checksum>".
	void run(const char *name, int threadsCount, int mode, Layout pl, Layout cl, Kernel k, long long totalLinks) {
		init(mode, pl, cl, k, totalLinks);
		volatile long long t0 = timeUs();
		runThreads(threadsCount);
		volatile long long t1 = timeUs();
		// Checksum over every buffer so the results are observable. It is printed
		// as a double: sums of products of rand() values can exceed the long long
		// range (e.g. with RAND_MAX = 2^31-1), and converting such a double to
		// long long is undefined behaviour.
		Float sum = 0;
		for(int i = 0; i < (int)pvalues.size(); ++i) sum += pvalues[i];
		for(int i = 0; i < (int)cvalues.size(); ++i) sum += cvalues[i];
		for(int i = 0; i < (int)weights.size(); ++i) sum += weights[i];
		printf("%s %d: %f, %g\n", name, mode, (t1 - t0)*1e-6, sum);
	}

	// Runs one case in mode 0 and then mode 1.
	void run(const char *name, int threadsCount, Layout pl, Layout cl, Kernel k, long long totalLinks) {
		for(int mode = 0; mode < 2; ++mode)
			run(name, threadsCount, mode, pl, cl, k, totalLinks);
	}

public:
	// Runs the enabled benchmark cases with `threadsCount` worker threads.
	void run(int threadsCount = 1) {
		//printf("run benchmark: %d\n", threadsCount);
		/*
		run( "514x3 -> 258x24", threadsCount,
			Layout(514, 514, 3).expandXY(2),
			Layout(258, 258, 24).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "258x24 -> 130x48", threadsCount,
			Layout(258, 258, 24).expandXY(2),
			Layout(130, 130, 48).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "130x48 -> 66x96 ", threadsCount,
			Layout(130, 130, 48).expandXY(2),
			Layout( 66, 66, 96).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "66x96 -> 34x144", threadsCount,
			Layout( 66, 66, 96).expandXY(2),
			Layout(34, 34, 144).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "34x144 -> 18x216", threadsCount,
			Layout(34, 34, 144).expandXY(2),
			Layout(18, 18, 216).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "18x216 -> 10x324", threadsCount,
			Layout(18, 18, 216).expandXY(2),
			Layout(10, 10, 324).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "10x324 -> 6x486 ", threadsCount,
			Layout(10, 10, 324).expandXY(2),
			Layout( 6, 6, 486).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		run( "6x486 -> 4x729 ", threadsCount,
			Layout( 6, 6, 486).expandXY(2),
			Layout( 4, 4, 729).expandXY(2), Kernel(4, 2, -2), 10ll*1000*1000*1000 );
		*/
		run( "4x768 -> 1x1093 ", threadsCount,
			Layout( 4, 4, 768).expandXY(0),
			Layout( 1, 1, 1093).expandXY(0), Kernel(4, 2, 0), 10ll*1000*1000*1000 );
	}
};
#endif