Blame simple/neural/nn-trainer.cpp

56d550
56d550
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include "nnlayer.inc.cpp"
#include "nnlayer.conv.inc.cpp"
#include "nnlayer.lnk.inc.cpp"
#include "nntrain.inc.cpp"
56d550
56d550
56d550
int main() {
56d550
  srand(time(NULL));
56d550
025224
  const char *filename = "data/output/weights.bin"; // 28x28
025224
  //const char *filename = "data/output/weights22.bin";
025224
56d550
  printf("load training data\n");
025224
  Trainer t(0.25);
56d550
  if (!t.loadSymbolMap("data/symbols-data.bin", 784, 10)) return 1;
025224
  //if (!t.loadSymbolMap("data/symbols22-data.bin", 22*22, 10)) return 1;
56d550
56d550
  printf("create neural network\n");
025224
  //Layer l(nullptr, 784);
025224
  //new LayerSimple(l, 128);
025224
  //new LayerSimple(l, 64);
025224
  //new LayerSimple(l, 10);
025224
025224
  //Layer l(nullptr, 784);
025224
  //new LayerConvolution<false, 6,2="">(l, 12, 12, 4);</false,>
025224
  //new LayerConvolution<false, 6,2="">(l, 4, 4, 16);</false,>
025224
  ////new LayerConvolution<false, 4="">(l, 1, 1, 128);</false,>
025224
  ////new LayerSimple(l, 64);
025224
  //new LayerSimple(l, 10);
025224
025224
  //Layer l(nullptr, 784);
025224
  //new LayerConvolution<false, 7="">(l, 22, 22, 1);</false,>
025224
  //new LayerConvolution<false, 9="">(l, 14, 14, 1);</false,>
025224
  //new LayerConvolution<false,11>(l, 4, 4, 1);</false,11>
025224
  //new LayerSimple(l, 10);
025224
025224
  Layer l(nullptr, 784);
53488e
  new LayerLinkConvolution(l, 28, 28, 1, 22, 22, 1,  60);
53488e
  new LayerLinkConvolution(l, 22, 22, 1, 14, 14, 1, 100);
53488e
  new LayerLinkConvolution(l, 14, 14, 1,  4,  4, 1, 140);
025224
  new LayerSimple(l, 10);
025224
53488e
  printf("  neurons: %d, links %d, memSize: %llu\n", l.totalSize(), l.totalLinks(), (unsigned long long)l.totalMemSize());
56d550
025224
  //printf("try load previously saved network\n");
025224
  //l.load(filename);
56d550
56d550
  printf("train\n");
025224
  t.trainSimple(l, 1000, 0.5, 10000);
025224
  //for(double q = 0.5; q > 0.05; q *= 0.5)
025224
  //  t.trainBlock(l, 50, 20, q, 10000);
025224
  //t.trainBlock(l, 50, 20, 0.5, 10000);
025224
  //t.trainBlock(l, 10, 100, 0.5, 1);
56d550
56d550
  printf("save neural network weights\n");
025224
  if (!l.save(filename)) return 1;
56d550
56d550
  return 0;
56d550
}
56d550