#ifndef LAYER_INC_CPP
#define LAYER_INC_CPP
#include "common.inc.cpp"
/// Base for objects that hold a flat array of Weight values which can be
/// persisted to, and restored from, a binary file.
///
/// The array is written and read as one raw block of
/// sizeof(Weight)*weightsCount bytes (no header, in-memory representation),
/// so a file is only readable by a build with the same Weight layout.
/// WeightHolder never allocates or frees `weights`; ownership is left to
/// derived classes.
class WeightHolder {
public:
	const int weightsCount;   ///< number of elements in `weights` (asserted >= 0)
	Weight *weights;          ///< weight array; never allocated or freed by WeightHolder
	const char *filename;     ///< file used by save()/load(); nullptr (the default) disables file I/O

	explicit WeightHolder(int weightsCount = 0, Weight *weights = nullptr):
		weightsCount(weightsCount), weights(weights), filename()
		{ assert(weightsCount >= 0); }

	virtual ~WeightHolder() { }

	/// Writes the weights to `filename` (unless demoOnly) and then calls saveDemo().
	/// The file step is skipped when filename is null or weightsCount is 0.
	/// @param demoOnly  when true, skip writing the weight file and only call saveDemo()
	/// @return false (after printing a message) if the file cannot be opened,
	///         written or flushed/closed; otherwise the result of saveDemo().
	bool save(bool demoOnly = false) {
		if (filename && weightsCount && !demoOnly) {
			FILE *f = fopen(filename, "wb");
			if (!f)
				return printf("cannot open file for write: %s\n", filename), false;
			if (!fwrite(weights, sizeof(*weights)*weightsCount, 1, f))
				return fclose(f), printf("cannot write to file: %s\n", filename), false;
			// fclose() flushes the stdio buffer; if that flush fails the data
			// never reached the file even though fwrite() reported success.
			if (fclose(f))
				return printf("cannot write to file: %s\n", filename), false;
		}
		return saveDemo();
	}

	/// Reads weightsCount weights from `filename` into `weights`.
	/// Does nothing and succeeds when filename is null or weightsCount is 0.
	/// @return false (after printing a message) if the file cannot be opened
	///         or does not provide sizeof(Weight)*weightsCount bytes.
	bool load() {
		if (filename && weightsCount) {
			FILE *f = fopen(filename, "rb");
			if (!f)
				return printf("cannot open file for read: %s\n", filename), false;
			if (!fread(weights, sizeof(*weights)*weightsCount, 1, f))
				return fclose(f), printf("cannot read from file: %s\n", filename), false;
			fclose(f);
		}
		return true;
	}

	/// Hook called at the end of every save(), including demoOnly saves.
	/// The default implementation does nothing and succeeds.
	virtual bool saveDemo() { return true; }
};
/// One layer of a network. Layers form a doubly linked chain through
/// prev/next, and each layer owns the layers after it: destroying a layer
/// also destroys every following layer.
///
/// A layer always owns its `neurons` array. It owns `weights` only when it
/// allocated them itself (constructed with weights == nullptr and
/// weightsCount > 0); a caller-supplied weight array is only borrowed.
class Layer: public WeightHolder {
public:
	Layer *prev, *next;          ///< neighbours in the chain (nullptr at the ends)
	Layout layout;               ///< neuron layout of this layer
	Neuron *neurons;             ///< owned array of neuronsCount zero-filled neurons
	int neuronsCount;            ///< layout.getCount()
	bool ownWeights;             ///< true if `weights` was allocated here and is freed by ~Layer
	bool skipTrain;              ///< starts false; NOTE(review): presumably tells training to skip this layer - confirm
	Stat stat;                   ///< neuron/weight/memory counts, filled in by the constructor
	Layout::List mtLayouts;      ///< per-thread parts of `layout`, filled by split()
	Layout::List mtPrevLayouts;  ///< per-thread parts of prev->layout, filled by split()

	/// Creates a layer and appends it to the END of the chain containing `prev`
	/// (it is linked after prev->back(), not necessarily directly after `prev`).
	/// @param prev          any layer of an existing chain, or nullptr for a first layer
	/// @param layout        neuron layout; must be valid and have at least one neuron
	/// @param weightsCount  number of weights (>= 0); non-zero requires a `prev`
	/// @param weights       weight array to borrow, or nullptr to allocate and own
	///                      a zero-filled array of weightsCount elements
	Layer(Layer *prev, const Layout &layout, int weightsCount = 0, Weight *weights = nullptr):
		WeightHolder(weightsCount, weights),
		prev(prev ? &prev->back() : nullptr),
		next(),
		layout(layout),
		neurons(),
		neuronsCount(layout.getCount()),
		ownWeights(!weights && weightsCount),
		skipTrain()
	{
		assert(layout);
		assert(neuronsCount > 0);
		assert(weightsCount >= 0);
		assert(prev || !weightsCount);
		if (neuronsCount) {
			neurons = new Neuron[neuronsCount];
			memset(neurons, 0, sizeof(*neurons)*neuronsCount);
		}
		if (ownWeights) {
			this->weights = new Weight[weightsCount];
			memset(this->weights, 0, sizeof(*this->weights)*weightsCount);
		}
		stat.neurons = neuronsCount;
		stat.activeNeurons = layout.getActiveCount();
		stat.weights = weightsCount;
		stat.links = weightsCount;
		stat.memsize = neuronsCount*sizeof(*neurons);
		if (ownWeights) stat.memsize += weightsCount*sizeof(*weights);
		// Link into the chain only after every allocation has succeeded, so a
		// throwing `new` cannot leave prev->next pointing at a dead object.
		if (this->prev) this->prev->next = this;
	}

	// A layer owns raw arrays and its successor chain; a shallow copy would
	// free them twice, so copying is disabled.
	Layer(const Layer&) = delete;
	Layer& operator=(const Layer&) = delete;

	/// Destroys this layer and, recursively, every layer after it. The layer
	/// is also detached from its predecessor so prev->next never dangles.
	~Layer() {
		if (prev && prev->next == this) prev->next = nullptr;
		delete next;
		delete[] neurons;
		if (ownWeights) delete[] weights;
	}

	/// First layer of the chain (follows prev pointers).
	inline Layer& front()
		{ Layer *l = this; while(l->prev) l = l->prev; return *l; }
	/// Last layer of the chain (follows next pointers).
	inline Layer& back()
		{ Layer *l = this; while(l->next) l = l->next; return *l; }

	/// Sum of `stat` over this layer and every layer after it.
	inline Stat sumStat() const
		{ Stat s; for(const Layer *l = this; l; l = l->next) s += l->stat; return s; }

	/// Saves this layer and then each following layer; stops at the first failure.
	bool save(bool demoOnly = false)
		{ return WeightHolder::save(demoOnly) && (!next || next->save(demoOnly)); }
	/// Loads this layer and then each following layer; stops at the first failure.
	bool load()
		{ return WeightHolder::load() && (!next || next->load()); }

	/// Resets the accumulator `a` of every neuron to a value-initialized Accum.
	void clearAccum() {
		Accum a = {};
		for(Neuron *in = neurons, *e = in + neuronsCount; in < e; ++in)
			in->a = a;
	}

	/// Sets every weight's `w` to a rand()-based value in [wmin, wmax].
	void fillWeights(WeightReal wmin, WeightReal wmax) {
		WeightReal k = (wmax - wmin)/RAND_MAX;
		for(Weight *iw = weights, *e = iw + weightsCount; iw < e; ++iw)
			iw->w = rand()*k + wmin;
	}

	/// Splits `layout` (and prev->layout, if any) into threadsCount parts via
	/// Layout::split, storing them in mtLayouts / mtPrevLayouts.
	virtual void split(int threadsCount) {
		layout.split(mtLayouts, threadsCount);
		if (prev) prev->layout.split(mtPrevLayouts, threadsCount);
	}

	// Per-layer processing hooks for concrete layer types; the defaults do nothing.
	virtual void pass(Barrier &barrier) { }
	virtual void backpassWeights(Barrier &barrier) { }
	virtual void backpassDeltas(Barrier &barrier) { }
	virtual void testPass() { }
	virtual void testBackpass() { }

	/// Calls pass() on the layers after this one, in chain order, stopping after
	/// `last` (or at the end of the chain if `last` is nullptr or not reached).
	/// Every layer of the chain, including this one, is first split() for
	/// threadsCount threads. Thread 0 is the calling thread; threads
	/// 1..threadsCount-1 are spawned and joined before returning, and all
	/// threads wait on a shared barrier between consecutive layers.
	/// @param last          last layer to process, or nullptr for the whole chain
	/// @param threadsCount  number of threads; values below 1 are treated as 1
	void passFull(const Layer *last = nullptr, int threadsCount = 1) {
		// A zero or negative count would give the barrier no participants.
		if (threadsCount < 1) threadsCount = 1;

		struct H {
			Layer &layer;
			const Layer *last;
			std::atomic<unsigned int> barrierCounter;
			int threadsCount;
			H(Layer &layer, const Layer *last, int threadsCount):
				layer(layer), last(last), barrierCounter(0), threadsCount(threadsCount) { }
			void func(int tid, unsigned int seed) {
				Barrier barrier(barrierCounter, tid, threadsCount, seed);
				for(Layer *l = layer.next; l; l = l->next) {
					l->pass(barrier);
					if (l == last || !l->next) break;
					barrier.wait();
				}
			}
		} h(*this, last, threadsCount);

		for(Layer *l = this; l; l = l->next)
			l->split(threadsCount);

		// Seeds are drawn in the same order as before: workers 1..n-1 first,
		// then the calling thread.
		std::vector<std::thread> workers;
		workers.reserve(threadsCount - 1);
		for(int i = 1; i < threadsCount; ++i)
			workers.emplace_back(&H::func, &h, i, rand());
		h.func(0, rand());
		for(std::thread &t : workers)
			t.join();
	}
};
#endif