Blame projects/neural/view-digits.cpp

8e5348
8e5348
#include <helianthus.h></helianthus.h>
8e5348
8e5348
8e5348
#include "layer.all.inc.cpp"
8e5348
8e5348
8e5348
Layer *nl;
8e5348
Framebuffer fb, fbMin;
8e5348
Animation fbAnim, fbMinAnim;
8e5348
8e5348
int wasPressed;
8e5348
double prevX, prevY;
8e5348
8e5348
8e5348
8e5348
void prepareImage() {
8e5348
  int w, h;
8e5348
  unsigned char *pixels = NULL;
8e5348
8e5348
  saveState();
8e5348
  target(fb);
8e5348
  imageFromViewport(&w, &h, &pixels);
8e5348
  restoreState();
8e5348
  if (!pixels) return;
8e5348
8e5348
  int x0 = w, y0 = h, x1 = 0, y1 = 0;
8e5348
  for(int y = 0; y < h; ++y) {
8e5348
    for(int x = 0; x < w; ++x) {
8e5348
      if (imageGetPixel(w, h, pixels, x, y) != 0x000000ff) {
8e5348
        if (x0 > x) x0 = x;
8e5348
        if (x1 < x) x1 = x;
8e5348
        if (y0 > y) y0 = y;
8e5348
        if (y1 < y) y1 = y;
8e5348
      }
8e5348
    }
8e5348
  }
8e5348
  free(pixels);
8e5348
  pixels = NULL;
8e5348
8e5348
  if (x1 < x0 || y1 < y0) return;
8e5348
8e5348
  int fw = framebufferGetWidth(fbMin);
8e5348
  int fh = framebufferGetHeight(fbMin);
8e5348
8e5348
  double wx = x1 - x0 + 1;
8e5348
  double wy = y1 - y0 + 1;
8e5348
  double s = (fw - 4)/(double)(wx > wy ? wx : wy);
8e5348
  double cx = (x0 + x1)/2.0;
8e5348
  double cy = (y0 + y1)/2.0;
8e5348
8e5348
  double xx = fw/2 - s*cx;
8e5348
  double yy = fh/2 - s*cy;
8e5348
  double ww = s*w;
8e5348
  double hh = s*h;
8e5348
8e5348
  saveState();
8e5348
  target(fbMin);
8e5348
  noStroke();
8e5348
  rectTextured(fbAnim, xx, yy, ww, hh);
8e5348
  imageFromViewport(&w, &h, &pixels);
8e5348
  restoreState();
8e5348
8e5348
  if (!pixels) return;
8e5348
  Neuron *in = nl->neurons;
8e5348
  for(int y = 0; y < h; ++y)
8e5348
    for(int x = 0; x < w; ++x)
8e5348
      (in++)->v = colorGetValue(imageGetPixel(w, h, pixels, x, y));
8e5348
}
8e5348
8e5348
8e5348
void init() {
8e5348
  background(COLOR_BLACK);
8e5348
  stroke(COLOR_WHITE);
8e5348
  fb = createFramebufferEx(512, 512, NULL, FALSE, FALSE, TRUE);
8e5348
  fbMin = createFramebufferEx(28, 28, NULL, FALSE, FALSE, TRUE);
8e5348
  fbAnim = createAnimationFromFramebuffer(fb);
8e5348
  fbMinAnim = createAnimationFromFramebuffer(fbMin);
8e5348
8e5348
  saveState();
8e5348
  target(fb);
8e5348
  clear();
8e5348
  target(fbMin);
8e5348
  clear();
8e5348
  restoreState();
8e5348
8e5348
  #define FILENAME "data/weights-digit.bin"
8e5348
  nl = new Layer(                   nullptr, Layout(28, 28));
8e5348
  (new LayerSimple<funcsigmoidexp>    ( *nl, Layout(256)      ))->filename = FILENAME "1";</funcsigmoidexp>
8e5348
  (new LayerSimple<funcsigmoidexp>    ( *nl, Layout( 64)      ))->filename = FILENAME "2";</funcsigmoidexp>
8e5348
  (new LayerSimple<funcsigmoidexp>    ( *nl, Layout( 10)      ))->filename = FILENAME "3";</funcsigmoidexp>
8e5348
  nl->load();
8e5348
}
8e5348
8e5348
8e5348
void draw() {
8e5348
  saveState();
8e5348
8e5348
  if (mouseDown("left")) {
8e5348
    double x = mouseX(), y = mouseY();
8e5348
    if (!wasPressed) prevX = x, prevY = y;
8e5348
8e5348
    saveState();
8e5348
    strokeWidth(32);
8e5348
    target(fb);
8e5348
    line(prevX, prevY, x, y);
8e5348
    restoreState();
8e5348
8e5348
    prevX = x, prevY = y;
8e5348
    wasPressed = TRUE;
8e5348
  } else {
8e5348
    wasPressed = FALSE;
8e5348
  }
8e5348
8e5348
  if (keyWentDown("space")) {
8e5348
    prepareImage();
8e5348
    for(Layer *l = nl->next; l; l = l->next)
8e5348
      l->testPass();
8e5348
    saveState();
8e5348
    target(fb);
8e5348
    clear();
8e5348
    restoreState();
8e5348
  }
8e5348
8e5348
  noStroke();
8e5348
  rectTextured(fbAnim, 0, 0, 512, 512);
8e5348
8e5348
  stroke(COLOR_WHITE);
8e5348
  rectTextured(fbMinAnim, 16, 16, 28, 28);
8e5348
8e5348
  noFill();
8e5348
8e5348
  Layer &nlb = nl->back();
8e5348
  textSize(8);
8e5348
  int res = 0;
8e5348
  for(int i = 0; i < 10; ++i) {
8e5348
    if (nlb.neurons[i].v > nlb.neurons[res].v) res = i;
8e5348
    textf(16, 90+8*i, "%d: %lf", i, nlb.neurons[i].v);
8e5348
  }
8e5348
  textSize(16);
8e5348
  textf(16, 60, "%d", res);
8e5348
8e5348
  restoreState();
8e5348
}
8e5348
8e5348
8e5348
int main(int largc, char **largv) {
8e5348
  windowSetVariableFrameRate();
8e5348
  windowSetInit(&init);
8e5348
  windowSetDraw(&draw);
8e5348
  windowRun();
8e5348
  return 0;
8e5348
}
8e5348