Blame simple/neural/nn-trainer.c
|
|
ec0d7e |
|
|
|
ec0d7e |
#include <time.h>
|
|
|
ec0d7e |
#include <stdlib.h>
|
|
|
ec0d7e |
|
|
|
ec0d7e |
#include "nntrain.inc.c"
|
|
|
ec0d7e |
|
|
|
ec0d7e |
|
|
|
ec0d7e |
int main() {
|
|
|
ec0d7e |
srand(time(NULL));
|
|
|
ec0d7e |
|
|
|
ec0d7e |
printf("load training data\n");
|
|
|
ec0d7e |
NeuralTrainer *nt = ntNewSymbolMap("data/symbols-data.bin", 784, 10);
|
|
|
ec0d7e |
if (!nt) return 1;
|
|
|
ec0d7e |
|
|
|
ec0d7e |
printf("create neural network\n");
|
|
|
ec0d7e |
double tr = 0.1;
|
|
|
ec0d7e |
NeuralLayer *nl = nlNew(NULL, 784, tr);
|
|
|
ec0d7e |
nlNew(nl, 256, tr);
|
|
|
ec0d7e |
nlNew(nl, 128, tr);
|
|
|
ec0d7e |
nlNew(nl, 64, tr);
|
|
|
ec0d7e |
nlNew(nl, 64, tr);
|
|
|
ec0d7e |
nlNew(nl, 10, tr);
|
|
|
ec0d7e |
|
|
|
ec0d7e |
printf("try load previously saved network\n");
|
|
|
ec0d7e |
nlLoad(nl, "data/weights.bin");
|
|
|
ec0d7e |
|
|
|
ec0d7e |
printf("train\n");
|
|
|
ec0d7e |
ntTrain(nt, nl, 10, 100, 0.4);
|
|
|
ec0d7e |
ntTrain(nt, nl, 10, 100, 0.2);
|
|
|
ec0d7e |
ntTrain(nt, nl, 10, 100, 0.1);
|
|
|
ec0d7e |
ntTrain(nt, nl, 10, 100, 0.05);
|
|
|
ec0d7e |
|
|
|
ec0d7e |
printf("save neural network weights\n");
|
|
|
ec0d7e |
if (!nlSave(nl, "data/output/weights.bin")) return 1;
|
|
|
ec0d7e |
|
|
|
ec0d7e |
return 0;
|
|
|
ec0d7e |
}
|
|
|
ec0d7e |
|