Blame simple/neural/nn-trainer.c

Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include "nntrain.inc.c"
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
int main() {
Ivan Mahonin ec0d7e
  srand(time(NULL));
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  printf("load training data\n");
Ivan Mahonin ec0d7e
  NeuralTrainer *nt = ntNewSymbolMap("data/symbols-data.bin", 784, 10);
Ivan Mahonin ec0d7e
  if (!nt) return 1;
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  printf("create neural network\n");
Ivan Mahonin ec0d7e
  double tr = 0.1;
Ivan Mahonin ec0d7e
  NeuralLayer *nl = nlNew(NULL, 784, tr);
Ivan Mahonin ec0d7e
  nlNew(nl, 256, tr);
Ivan Mahonin ec0d7e
  nlNew(nl, 128, tr);
Ivan Mahonin ec0d7e
  nlNew(nl, 64, tr);
Ivan Mahonin ec0d7e
  nlNew(nl, 64, tr);
Ivan Mahonin ec0d7e
  nlNew(nl, 10, tr);
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  printf("try load previously saved network\n");
Ivan Mahonin ec0d7e
  nlLoad(nl, "data/weights.bin");
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  printf("train\n");
Ivan Mahonin ec0d7e
  ntTrain(nt, nl, 10, 100, 0.4);
Ivan Mahonin ec0d7e
  ntTrain(nt, nl, 10, 100, 0.2);
Ivan Mahonin ec0d7e
  ntTrain(nt, nl, 10, 100, 0.1);
Ivan Mahonin ec0d7e
  ntTrain(nt, nl, 10, 100, 0.05);
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  printf("save neural network weights\n");
Ivan Mahonin ec0d7e
  if (!nlSave(nl, "data/output/weights.bin")) return 1;
Ivan Mahonin ec0d7e
Ivan Mahonin ec0d7e
  return 0;
Ivan Mahonin ec0d7e
}
Ivan Mahonin ec0d7e