|
|
b579b3 |
#ifndef SEGMENT_TEST_INC_CPP
|
|
|
b579b3 |
#define SEGMENT_TEST_INC_CPP
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
#include "test.inc.cpp"
|
|
|
b579b3 |
#include "segment.inc.cpp"
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
class SegmentTest: public Test {
|
|
|
b579b3 |
public:
|
|
|
b579b3 |
static bool testSegment(const char *name, Segment &segment, Layout l, int x, int y, int z, NeuronReal trainRatio) {
|
|
|
b579b3 |
Stage st(name);
|
|
|
b579b3 |
|
|
|
e4740d |
class H: public ThreadControl {
|
|
|
e4740d |
public:
|
|
|
b579b3 |
Segment &segment;
|
|
|
b579b3 |
int x, y, z;
|
|
|
b579b3 |
Quality testQ;
|
|
|
b579b3 |
NeuronReal ratio;
|
|
|
b579b3 |
std::vector<quality> qualities;</quality>
|
|
|
b579b3 |
|
|
|
e4740d |
H(Segment &segment, int x, int y, int z, NeuronReal ratio): segment(segment), x(x), y(y), z(z), ratio(ratio) { }
|
|
|
b579b3 |
|
|
|
b579b3 |
void prepareData()
|
|
|
b579b3 |
{ memcpy(segment.weights, weights.data(), segment.weightsCount*sizeof(Weight)); }
|
|
|
b579b3 |
|
|
|
e4740d |
void threadFunc(Barrier &barrier) override
|
|
|
e4740d |
{ qualities[barrier.tid] = segment.pass(barrier, x, y, z, ratio); }
|
|
|
b579b3 |
|
|
|
b579b3 |
bool test(const char *name, int threadsCount) {
|
|
|
b579b3 |
Stage st(name);
|
|
|
b579b3 |
|
|
|
b579b3 |
assert(threadsCount > 0);
|
|
|
b579b3 |
qualities.clear();
|
|
|
b579b3 |
qualities.resize(threadsCount);
|
|
|
b579b3 |
|
|
|
b579b3 |
prepareData();
|
|
|
b579b3 |
|
|
|
b579b3 |
segment.split(threadsCount);
|
|
|
e4740d |
runThreads(threadsCount);
|
|
|
b579b3 |
|
|
|
e4740d |
Quality q;
|
|
|
e4740d |
for(int i = 0; i < threadsCount; ++i) q += qualities[i];
|
|
|
b579b3 |
|
|
|
b579b3 |
if ( fabs(q.train - testQ.train) > 1e-10
|
|
|
b579b3 |
|| fabs(q.human - testQ.human) > 1e-10 )
|
|
|
b579b3 |
{
|
|
|
b579b3 |
printf("qualities differs, was %g (%g), expected %g (%g)\n",
|
|
|
b579b3 |
q.human, q.train, testQ.human, testQ.train );
|
|
|
b579b3 |
++errors;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
for(int i = 0; i < segment.weightsCount; ++i) {
|
|
|
b579b3 |
WeightReal a = segment.weights[i].w;
|
|
|
b579b3 |
WeightReal b = weights[i + segment.weightsCount].w;
|
|
|
b579b3 |
if (fabs(a - b) > 1e-10) {
|
|
|
b579b3 |
printf("weights differs at %d, was %g, expected %g\n", i, a, b);
|
|
|
b579b3 |
segment.layout.printYXZ("layout");
|
|
|
b579b3 |
++errors; break;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
}
|
|
|
b579b3 |
|
|
|
b579b3 |
return st;
|
|
|
b579b3 |
}
|
|
|
b579b3 |
} h(segment, x, y, z, trainRatio);
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
assert(segment.weightsCount > 0);
|
|
|
b579b3 |
|
|
|
b579b3 |
int valuesCount = l.getCount();
|
|
|
b579b3 |
init(0, 0, segment.weightsCount*3, valuesCount);
|
|
|
b579b3 |
|
|
|
b579b3 |
for(int i = 0; i < valuesCount; ++i)
|
|
|
b579b3 |
values[i] = rand()/(NeuronReal)RAND_MAX;
|
|
|
b579b3 |
for(int i = 0; i < segment.weightsCount; ++i)
|
|
|
b579b3 |
weights[i].w = (WeightReal)(2.0*rand()/RAND_MAX - 1);
|
|
|
b579b3 |
|
|
|
b579b3 |
segment.layout = l;
|
|
|
b579b3 |
segment.f_values = values.data();
|
|
|
b579b3 |
segment.weights = &weights[segment.weightsCount];
|
|
|
b579b3 |
segment.check(x, y, z);
|
|
|
b579b3 |
|
|
|
b579b3 |
h.prepareData();
|
|
|
b579b3 |
h.testQ = segment.testPass(x, y, z, trainRatio);
|
|
|
b579b3 |
segment.weights += segment.weightsCount;
|
|
|
b579b3 |
|
|
|
b579b3 |
h.test("single-thread", 1);
|
|
|
b579b3 |
h.test("single-thread-repeat", 1);
|
|
|
b579b3 |
h.test("2-threads", 2);
|
|
|
b579b3 |
h.test("7-threads", 7);
|
|
|
b579b3 |
h.test("7-threads-repeat", 7);
|
|
|
b579b3 |
h.test("8-threads", 8);
|
|
|
b579b3 |
|
|
|
b579b3 |
return st;
|
|
|
b579b3 |
}
|
|
|
e4740d |
|
|
|
e4740d |
|
|
|
e4740d |
static bool testSegment(const char *name, Segment &segment) {
|
|
|
e4740d |
Layout l = segment.layout;
|
|
|
e4740d |
int x = l.x0 + rand()%(l.getW() - segment.sx + 1);
|
|
|
e4740d |
int y = l.y0 + rand()%(l.getH() - segment.sy + 1);
|
|
|
e4740d |
int z = l.z0 + rand()%(l.getD() - segment.sz + 1);
|
|
|
e4740d |
return testSegment(name, segment, l, x, y, z, 0.1);
|
|
|
e4740d |
}
|
|
|
b579b3 |
};
|
|
|
b579b3 |
|
|
|
b579b3 |
|
|
|
b579b3 |
#endif
|