Blame simple/neural/nn-heli.cpp

025224
025224
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <helianthus.h>

#include "nnlayer.inc.cpp"
025224
025224
025224
Layer *nl;
025224
Framebuffer fb, fbMin;
025224
Animation fbAnim, fbMinAnim;
025224
025224
int wasPressed;
025224
double prevX, prevY;
025224
025224
025224
025224
void prepareImage() {
025224
  int w, h;
025224
  unsigned char *pixels = NULL;
025224
025224
  saveState();
025224
  target(fb);
025224
  imageFromViewport(&w, &h, &pixels);
025224
  restoreState();
025224
  if (!pixels) return;
025224
025224
  int x0 = w, y0 = h, x1 = 0, y1 = 0;
025224
  for(int y = 0; y < h; ++y) {
025224
    for(int x = 0; x < w; ++x) {
025224
      if (imageGetPixel(w, h, pixels, x, y) != 0x000000ff) {
025224
        if (x0 > x) x0 = x;
025224
        if (x1 < x) x1 = x;
025224
        if (y0 > y) y0 = y;
025224
        if (y1 < y) y1 = y;
025224
      }
025224
    }
025224
  }
025224
  free(pixels);
025224
  pixels = NULL;
025224
025224
  if (x1 < x0 || y1 < y0) return;
025224
025224
  int fw = framebufferGetWidth(fbMin);
025224
  int fh = framebufferGetHeight(fbMin);
025224
025224
  double wx = x1 - x0 + 1;
025224
  double wy = y1 - y0 + 1;
025224
  double s = (fw - 4)/(double)(wx > wy ? wx : wy);
025224
  double cx = (x0 + x1)/2.0;
025224
  double cy = (y0 + y1)/2.0;
025224
025224
  double xx = fw/2 - s*cx;
025224
  double yy = fh/2 - s*cy;
025224
  double ww = s*w;
025224
  double hh = s*h;
025224
025224
  saveState();
025224
  target(fbMin);
025224
  noStroke();
025224
  rectTextured(fbAnim, xx, yy, ww, hh);
025224
  imageFromViewport(&w, &h, &pixels);
025224
  restoreState();
025224
025224
  if (!pixels) return;
025224
  double *p = nl->a;
025224
  for(int y = 0; y < h; ++y)
025224
    for(int x = 0; x < w; ++x)
025224
      *p++ = colorGetValue(imageGetPixel(w, h, pixels, x, y));
025224
}
025224
025224
025224
void init() {
025224
  background(COLOR_BLACK);
025224
  stroke(COLOR_WHITE);
025224
  fb = createFramebufferEx(512, 512, NULL, FALSE, FALSE, TRUE);
025224
  fbMin = createFramebufferEx(28, 28, NULL, FALSE, FALSE, TRUE);
025224
  fbAnim = createAnimationFromFramebuffer(fb);
025224
  fbMinAnim = createAnimationFromFramebuffer(fbMin);
025224
025224
  saveState();
025224
  target(fb);
025224
  clear();
025224
  target(fbMin);
025224
  clear();
025224
  restoreState();
025224
025224
  nl = new Layer(nullptr, 784);
025224
  new LayerAdd(*nl, 784);
025224
  new LayerSimple(*nl, 256);
025224
  new LayerSimple(*nl, 128);
025224
  new LayerSimple(*nl, 64);
025224
  new LayerSimple(*nl, 64);
025224
  new LayerSimple(*nl, 10);
025224
025224
  //nl = new Layer(nullptr, 22*22);
025224
  //new LayerConvolution(*nl, 10, 10, 4);
025224
  //new LayerConvolution(*nl, 4, 4, 16);
025224
  //new LayerConvolution(*nl, 1, 1, 64);
025224
  //new LayerPlain(*nl, 64);
025224
  //new LayerPlain(*nl, 32);
025224
  //new LayerPlain(*nl, 10);
025224
025224
  nl->load("data/weights.bin");
025224
}
025224
025224
025224
void draw() {
025224
  saveState();
025224
025224
  if (mouseDown("left")) {
025224
    double x = mouseX(), y = mouseY();
025224
    if (!wasPressed) prevX = x, prevY = y;
025224
025224
    saveState();
025224
    strokeWidth(32);
025224
    target(fb);
025224
    line(prevX, prevY, x, y);
025224
    restoreState();
025224
025224
    prevX = x, prevY = y;
025224
    wasPressed = TRUE;
025224
  } else {
025224
    wasPressed = FALSE;
025224
  }
025224
025224
  if (keyWentDown("space")) {
025224
    prepareImage();
025224
    nl->pass();
025224
    saveState();
025224
    target(fb);
025224
    clear();
025224
    restoreState();
025224
  }
025224
025224
  noStroke();
025224
  rectTextured(fbAnim, 0, 0, 512, 512);
025224
025224
  stroke(COLOR_WHITE);
025224
  rectTextured(fbMinAnim, 16, 16, 22, 22);
025224
025224
  noFill();
025224
025224
  Layer &nlb = nl->back();
025224
  textSize(8);
025224
  int res = 0;
025224
  for(int i = 0; i < 10; ++i) {
025224
    if (nlb.a[i] > nlb.a[res]) res = i;
025224
    textf(16, 90+8*i, "%d: %lf", i, nlb.a[i]);
025224
  }
025224
  textSize(16);
025224
  textf(16, 60, "%d", res);
025224
025224
  restoreState();
025224
}
025224
025224
025224
int main(int largc, char **largv) {
025224
  windowSetVariableFrameRate();
025224
  windowSetInit(&init);
025224
  windowSetDraw(&draw);
025224
  windowRun();
025224
  return 0;
025224
}
025224