#ifndef SEGMENT_TEST_INC_CPP
#define SEGMENT_TEST_INC_CPP
#include "test.inc.cpp"
#include "segment.inc.cpp"
/// Test harness for Segment: checks that the multi-threaded Segment::pass()
/// reproduces the quality and the final weights produced by the reference
/// Segment::testPass() on identical input, for several thread counts.
class SegmentTest: public Test {
public:
    /// Runs Segment::pass() at a fixed position with 1, 2, 7 and 8 threads
    /// (some counts repeated) and compares every run with testPass().
    ///
    /// @param name        stage name used for reporting
    /// @param segment     segment under test; its layout, f_values and weights
    ///                    are overwritten to point into this test's buffers
    /// @param l           layout assigned to the segment
    /// @param x, y, z     position passed to check(), testPass() and pass()
    /// @param trainRatio  ratio passed to testPass() and pass()
    /// @return the Stage result for this test (converted to bool)
    static bool testSegment(const char *name, Segment &segment, Layout l, int x, int y, int z, NeuronReal trainRatio) {
        Stage st(name);

        // Runs segment.pass() on each worker thread and verifies the combined
        // result against the reference captured in testQ / the weights buffer.
        class H: public ThreadControl {
        public:
            Segment &segment;
            int x, y, z;
            Quality testQ;                  // reference quality from testPass()
            NeuronReal ratio;
            std::vector<Quality> qualities; // one slot per thread, indexed by barrier.tid

            H(Segment &segment, int x, int y, int z, NeuronReal ratio):
                segment(segment), x(x), y(y), z(z), ratio(ratio) { }

            // Copies the initial weights (first slice of `weights`) into the
            // slice the segment currently points at, so every run starts from
            // the same state.
            void prepareData()
                { memcpy(segment.weights, weights.data(), segment.weightsCount*sizeof(Weight)); }

            // Per-thread body: each thread stores its partial quality.
            void threadFunc(Barrier &barrier) override
                { qualities[barrier.tid] = segment.pass(barrier, x, y, z, ratio); }

            // Resets the weights, runs pass() on `threadsCount` threads, then
            // compares the summed quality with testQ and the resulting weights
            // with the reference slice. Increments `errors` on any mismatch.
            bool test(const char *name, int threadsCount) {
                // Absolute tolerance for quality and weight comparisons.
                const double eps = 1e-10;

                Stage st(name);
                assert(threadsCount > 0);

                qualities.clear();
                qualities.resize(threadsCount);
                prepareData();
                segment.split(threadsCount);
                runThreads(threadsCount);

                // Per-thread qualities are summed; the total must match the
                // single reference testPass() result.
                Quality q;
                for(int i = 0; i < threadsCount; ++i) q += qualities[i];
                if ( fabs(q.train - testQ.train) > eps
                  || fabs(q.human - testQ.human) > eps )
                {
                    printf("qualities differs, was %g (%g), expected %g (%g)\n",
                        q.human, q.train, testQ.human, testQ.train );
                    ++errors;
                }

                // segment.weights points at the working slice; the reference
                // weights left by testPass() are in the slice right after the
                // initial one. Report only the first differing weight.
                for(int i = 0; i < segment.weightsCount; ++i) {
                    WeightReal a = segment.weights[i].w;
                    WeightReal b = weights[i + segment.weightsCount].w;
                    if (fabs(a - b) > eps) {
                        printf("weights differs at %d, was %g, expected %g\n", i, a, b);
                        segment.layout.printYXZ("layout");
                        ++errors; break;
                    }
                }
                return st;
            }
        } h(segment, x, y, z, trainRatio);

        assert(segment.weightsCount > 0);

        // init() is asked for weightsCount*3 weights, used as three slices of
        // weightsCount (wc) each:
        //   [0,   wc)  - random initial weights, the source for every run
        //   [wc,  2wc) - state left by the reference testPass()
        //   [2wc, 3wc) - working copy used by the threaded pass() runs
        int valuesCount = l.getCount();
        init(0, 0, segment.weightsCount*3, valuesCount);
        for(int i = 0; i < valuesCount; ++i)
            values[i] = rand()/(NeuronReal)RAND_MAX;
        for(int i = 0; i < segment.weightsCount; ++i)
            weights[i].w = (WeightReal)(2.0*rand()/RAND_MAX - 1);

        segment.layout = l;
        segment.f_values = values.data();

        // Reference run on the second slice.
        segment.weights = &weights[segment.weightsCount];
        segment.check(x, y, z);
        h.prepareData();
        h.testQ = segment.testPass(x, y, z, trainRatio);

        // All threaded runs work on the third slice.
        segment.weights += segment.weightsCount;
        h.test("single-thread", 1);
        h.test("single-thread-repeat", 1);
        h.test("2-threads", 2);
        h.test("7-threads", 7);
        h.test("7-threads-repeat", 7);
        h.test("8-threads", 8);
        return st;
    }

    /// Convenience overload. Uses the segment's current layout, ratio 0.1, and
    /// a random position chosen so that [x, x+sx) lies within
    /// [x0, x0+getW()), and likewise for y/sy/getH() and z/sz/getD().
    static bool testSegment(const char *name, Segment &segment) {
        Layout l = segment.layout;

        // The segment must fit inside the layout. Otherwise the modulus below
        // is zero (undefined behavior) or negative (out-of-range position).
        assert(l.getW() >= segment.sx);
        assert(l.getH() >= segment.sy);
        assert(l.getD() >= segment.sz);

        int x = l.x0 + rand()%(l.getW() - segment.sx + 1);
        int y = l.y0 + rand()%(l.getH() - segment.sy + 1);
        int z = l.z0 + rand()%(l.getD() - segment.sz + 1);
        return testSegment(name, segment, l, x, y, z, 0.1);
    }
};
#endif