#ifndef NNLAYER2_MT_INC_CPP
#define NNLAYER2_MT_INC_CPP


#include "nnlayer2.inc.cpp"


#include <atomic>
#include <cassert>
#include <functional>
#include <thread>
#include <vector>


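// Spin barrier: every participating thread owns a Barrier built over the same
// shared atomic counter. wait() bumps the counter and busy-waits until all
// 'threads' participants of the current round have checked in; every thread
// must call wait() the same number of times.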
class Barrier {
private:
  std::atomic<unsigned int> &counter;
  const unsigned int threads;
  unsigned int next;
public:
  inline Barrier(std::atomic<unsigned int> &counter, unsigned int threads): counter(counter), threads(threads), next() { }
  inline void wait() { next += threads; ++counter; while(counter < next); }
};


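// Multithreaded trainer for a network built from nnlayer2.inc.cpp. Every
// worker thread iterates over all samples, but inside each layer it only
// touches its own range of neurons and links; spin barriers keep the threads
// in lockstep between layers.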
class TrainMT {
private:
  struct LDesc {
    int nb, ne, lb, le;
    double sumQ;
    LDesc(): nb(), ne(), lb(), le(), sumQ() { }
  };

public:
  Layer *layer;                // first layer of the network to train
  const unsigned char *dataX;  // input samples, one byte per input neuron
  const unsigned char *dataY;  // expected outputs, one byte per output neuron
  int strideX;                 // bytes between consecutive input samples
  int strideY;                 // bytes between consecutive expected outputs
  int *shuffle;                // order in which the sample indices are visited
  int count;                   // number of samples
  Real trainRatio;             // training rate; 0 runs a forward-only pass

  TrainMT():
    layer(),
    dataX(),
    dataY(),
    strideX(),
    strideY(),
    shuffle(),
    count(),
    trainRatio() { }

private:
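  // Per-thread worker. ldescs points to this thread's block of layersCount
  // descriptors: one neuron range [nb, ne) and one link range [lb, le) per
  // layer. The thread's accumulated squared error is written to ldescs[0].sumQ.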
  void trainFunc(int tid, int threads, std::atomic<unsigned int> &barrierCounter, LDesc *ldescs) {
    Barrier barrier(barrierCounter, threads);

    Layer &fl = *layer;
    Layer &bl = layer->back();
    int layersCount = fl.totalLayers();
    LDesc *fld = ldescs, *bld = fld + layersCount - 1;

    Real trainRatio = this->trainRatio;

    double sumQ = 0;
    for(int i = 0; i < count; ++i) {
      int ii = shuffle[i];
      const unsigned char *curX = dataX + strideX*ii;
      const unsigned char *curY = dataY + strideY*ii;

      // load this thread's slice of input neurons from the sample bytes
      const unsigned char *px = curX + fld->nb;
      for(Neuron *in = fl.neurons + fld->nb, *e = fl.neurons + fld->ne; in < e; ++in, ++px)
        in->v = Real(*px)*Real(1/255.0);

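      // forward pass; the barrier guarantees that every thread has finished
      // the previous layer before any thread starts reading it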
      LDesc *ld = fld + 1;
      for(Layer *l = fl.next; l; l = l->next, ++ld) {
        barrier.wait();
        l->pass(ld->nb, ld->ne);
      }

      // compare this thread's slice of output neurons with the expected bytes:
      // accumulate the squared error and scale each neuron's d by error and rate
      double q = 0;
      const unsigned char *py = curY + bld->nb;
      for(Neuron *in = bl.neurons + bld->nb, *e = bl.neurons + bld->ne; in < e; ++in, ++py) {
        Real d = Real(*py)*Real(1/255.0) - in->v;
        in->d *= d * trainRatio;
        q += d*d;
      }
      sumQ += q;

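      // backward pass; skipped entirely when trainRatio is zero (evaluation only)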
      if (trainRatio > 0) {
        ld = bld - 1;
        for(Layer *l = bl.prev; l; l = l->prev, --ld) {
          barrier.wait();
          l->backpass(ld->lb, ld->le);
        }
      }

      //if (!tid) printf(" - %d, %f, %f\n", i, q, sumQ);
    }

    ldescs->sumQ = sumQ;  // report this thread's error sum via its first descriptor
  }

public:
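  // Splits every layer's neurons and links into per-thread ranges, runs
  // trainFunc() on 'threads' threads (id 0 runs in the calling thread) and
  // returns the average squared error per output value.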
  double train(int threads) {
    assert(threads > 0);
    assert(layer && !layer->prev);
    assert(dataX && dataY && shuffle);
    assert(count > 0);
    assert(trainRatio >= 0);

    int layersCount = layer->totalLayers();
    assert(layersCount > 0);
    std::vector<LDesc> ldescs( threads*layersCount );

    int layerId = 0;
    for(Layer *l = layer; l; l = l->next, ++layerId) {
      assert(layerId < layersCount);
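      // neuron range: an equal contiguous slice per thread; the last thread
      // also takes the division remainder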
      int tsize = l->size/threads;
      for(int tid = 0; tid < threads; ++tid) {
        LDesc &desc = ldescs[tid*layersCount + layerId];
        desc.nb = tid*tsize;
        desc.ne = desc.nb + tsize;
        if (tid == threads-1) desc.ne = l->size;
      }

      // link ranges: chunks of roughly lsize/threads links each; a chunk
      // boundary is only placed where links[].nprev changes, so runs of links
      // with the same nprev are never split between threads
      if (int lsize = l->size*l->lsize) {
        int tlsize = lsize/threads;
        int ipn = l->links[ l->lfirst ].nprev;
        int tid = 0;
        int linksInChunk = 0;

        ldescs[tid*layersCount + layerId].lb = l->lfirst;
        if (threads > 1) {
          for(int il = l->lfirst; il != lsize; il = l->links[il].lnext, ++linksInChunk) {
            Link &link = l->links[il];
            if (ipn != link.nprev) {
              if (linksInChunk >= tlsize) {
                ldescs[tid*layersCount + layerId].le = il;
                ++tid;
                linksInChunk -= tlsize;
                ldescs[tid*layersCount + layerId].lb = il;
                if (tid == threads - 1) break;
              }
              ipn = link.nprev;
            }
          }
        }
        // the last assigned chunk runs to lsize (threads that got no chunk keep lb == le == 0)
        ldescs[tid*layersCount + layerId].le = lsize;
      }
    }
    assert(layerId == layersCount);

    std::atomic<unsigned int> barrierCounter(0);
    // worker threads get ids 1..threads-1; thread 0 runs in the calling thread
    std::vector<std::thread> t;
    t.reserve(threads - 1);
    for(int i = 1; i < threads; ++i)
      t.emplace_back(&TrainMT::trainFunc, this, i, threads, std::ref(barrierCounter), &ldescs[i*layersCount]);
    trainFunc(0, threads, barrierCounter, &ldescs[0]);

    double result = ldescs[0].sumQ;
    for(int i = 1; i < threads; ++i)
      { t[i-1].join(); result += ldescs[i*layersCount].sumQ; }

    return result/(count * layer->back().size);  // average squared error per output value
  }
};
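
// Usage sketch. The Layer construction API lives in nnlayer2.inc.cpp and is
// not shown here, so buildNetwork(), imagesX, labelsY and the sizes below are
// hypothetical placeholders, not identifiers defined by this code.
//
//   Layer *net = buildNetwork();                // first layer of a ready network
//   std::vector<int> order(sampleCount);
//   for(int i = 0; i < sampleCount; ++i)
//     order[i] = i;                             // or a shuffled permutation
//
//   TrainMT trainer;
//   trainer.layer      = net;
//   trainer.dataX      = imagesX;               // sampleCount * inputSize bytes
//   trainer.dataY      = labelsY;               // sampleCount * outputSize bytes
//   trainer.strideX    = inputSize;
//   trainer.strideY    = outputSize;
//   trainer.shuffle    = order.data();
//   trainer.count      = sampleCount;
//   trainer.trainRatio = 0.1;                   // 0 would measure error without training
//   double avgQ = trainer.train(4);             // number of worker threads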


#endif