|
|
56d550 |
|
|
|
56d550 |
#include <ctime>
|
|
|
56d550 |
#include <cstdlib>
|
|
|
56d550 |
|
|
|
025224 |
#include "nnlayer.inc.cpp"
|
|
|
025224 |
#include "nnlayer.conv.inc.cpp"
|
|
|
025224 |
#include "nnlayer.lnk.inc.cpp"
|
|
|
56d550 |
#include "nntrain.inc.cpp"
|
|
|
56d550 |
|
|
|
56d550 |
|
|
|
56d550 |
int main() {
|
|
|
56d550 |
srand(time(NULL));
|
|
|
56d550 |
|
|
|
025224 |
const char *filename = "data/output/weights.bin"; // 28x28
|
|
|
025224 |
//const char *filename = "data/output/weights22.bin";
|
|
|
025224 |
|
|
|
56d550 |
printf("load training data\n");
|
|
|
025224 |
Trainer t(0.25);
|
|
|
56d550 |
if (!t.loadSymbolMap("data/symbols-data.bin", 784, 10)) return 1;
|
|
|
025224 |
//if (!t.loadSymbolMap("data/symbols22-data.bin", 22*22, 10)) return 1;
|
|
|
56d550 |
|
|
|
56d550 |
printf("create neural network\n");
|
|
|
025224 |
//Layer l(nullptr, 784);
|
|
|
025224 |
//new LayerSimple(l, 128);
|
|
|
025224 |
//new LayerSimple(l, 64);
|
|
|
025224 |
//new LayerSimple(l, 10);
|
|
|
025224 |
|
|
|
025224 |
//Layer l(nullptr, 784);
|
|
|
025224 |
//new LayerConvolution<false, 6,2="">(l, 12, 12, 4);</false,>
|
|
|
025224 |
//new LayerConvolution<false, 6,2="">(l, 4, 4, 16);</false,>
|
|
|
025224 |
////new LayerConvolution<false, 4="">(l, 1, 1, 128);</false,>
|
|
|
025224 |
////new LayerSimple(l, 64);
|
|
|
025224 |
//new LayerSimple(l, 10);
|
|
|
025224 |
|
|
|
025224 |
//Layer l(nullptr, 784);
|
|
|
025224 |
//new LayerConvolution<false, 7="">(l, 22, 22, 1);</false,>
|
|
|
025224 |
//new LayerConvolution<false, 9="">(l, 14, 14, 1);</false,>
|
|
|
025224 |
//new LayerConvolution<false,11>(l, 4, 4, 1);</false,11>
|
|
|
025224 |
//new LayerSimple(l, 10);
|
|
|
025224 |
|
|
|
025224 |
Layer l(nullptr, 784);
|
|
|
53488e |
new LayerLinkConvolution(l, 28, 28, 1, 22, 22, 1, 60);
|
|
|
53488e |
new LayerLinkConvolution(l, 22, 22, 1, 14, 14, 1, 100);
|
|
|
53488e |
new LayerLinkConvolution(l, 14, 14, 1, 4, 4, 1, 140);
|
|
|
025224 |
new LayerSimple(l, 10);
|
|
|
025224 |
|
|
|
53488e |
printf(" neurons: %d, links %d, memSize: %llu\n", l.totalSize(), l.totalLinks(), (unsigned long long)l.totalMemSize());
|
|
|
56d550 |
|
|
|
025224 |
//printf("try load previously saved network\n");
|
|
|
025224 |
//l.load(filename);
|
|
|
56d550 |
|
|
|
56d550 |
printf("train\n");
|
|
|
025224 |
t.trainSimple(l, 1000, 0.5, 10000);
|
|
|
025224 |
//for(double q = 0.5; q > 0.05; q *= 0.5)
|
|
|
025224 |
// t.trainBlock(l, 50, 20, q, 10000);
|
|
|
025224 |
//t.trainBlock(l, 50, 20, 0.5, 10000);
|
|
|
025224 |
//t.trainBlock(l, 10, 100, 0.5, 1);
|
|
|
56d550 |
|
|
|
56d550 |
printf("save neural network weights\n");
|
|
|
025224 |
if (!l.save(filename)) return 1;
|
|
|
56d550 |
|
|
|
56d550 |
return 0;
|
|
|
56d550 |
}
|
|
|
56d550 |
|