Blob Blame Raw

import threading


class CacheItem:
  """Node of the doubly linked recency list kept by its owner.

  The owner must expose ``first`` and ``last`` attributes pointing at the
  head and tail of the list (``None`` when the list is empty).  The head is
  the most recently created or touched item.
  """

  def __init__(self, owner, key, value):
    """Store ``key``/``value`` and link the new item at the head of the list."""
    self.owner = owner
    self.key = key
    self.value = value
    self.prev = None
    self.next = None
    self._link_first()

  def _link_first(self):
    """Insert this item, which must not currently be linked, at the head."""
    self.prev = None
    self.next = self.owner.first
    if self.next is not None:
      self.next.prev = self
    else:
      # the list was empty, so this item is also the tail
      self.owner.last = self
    self.owner.first = self

  def touch(self):
    """Move the item to the head of the list and return its value.

    The value is returned in every case, including when the item is
    already the head and nothing has to be relinked.
    """
    if self.prev is not None:
      # Unlink from the current position, then re-insert at the head so the
      # item stays reachable from the owner's list.
      self.detouch()
      self._link_first()
    return self.value

  def detouch(self):
    """Unlink the item from the list, fixing up the owner's head and tail.

    The item's own ``prev``/``next`` references are left as they were.
    """
    if self.prev is not None:
      self.prev.next = self.next
    else:
      self.owner.first = self.next
    if self.next is not None:
      self.next.prev = self.prev
    else:
      self.owner.last = self.prev


class CacheTable:
  """Bounded key/value store backed by a dict plus a linked list of items.

  ``items`` maps each key to its ``CacheItem``; ``first`` and ``last`` are
  the head and tail of the item list.  When more than ``maxcount`` entries
  are held, entries are evicted from the tail of the list.
  """

  def __init__(self, maxcount):
    """Create an empty table that keeps at most ``maxcount`` entries."""
    self.items = dict()
    self.first = None
    self.last = None
    self.maxcount = maxcount

  def get(self, key):
    """Return the value stored under ``key``, or ``None`` if it is absent.

    A found item is touched first.  The value is read from the item itself
    so the result does not depend on what ``touch()`` returns.
    """
    item = self.items.get(key)
    if item is None:
      return None
    item.touch()
    return item.value

  def set(self, key, value):
    """Store ``value`` under ``key``, evicting tail entries when over capacity."""
    item = self.items.get(key)
    if item is not None:
      # existing entry: overwrite in place, the entry count is unchanged
      item.value = value
      item.touch()
      return
    self.items[key] = CacheItem(self, key, value)
    # Stop when the list is empty instead of dereferencing a missing tail
    # (e.g. with a negative maxcount).
    while len(self.items) > self.maxcount and self.last is not None:
      victim = self.last
      del self.items[victim.key]
      victim.detouch()

  def unset(self, key):
    """Remove the entry stored under ``key``; do nothing if it is absent."""
    item = self.items.pop(key, None)
    if item is not None:
      item.detouch()

  def clear(self):
    """Drop every entry and reset the list to empty."""
    self.first = None
    self.last = None
    # remove cyclic references to help GC
    for item in self.items.values():
      item.prev = None
      item.next = None
    self.items.clear()


class Cache:
  """Cache of SELECT results, kept per table name and keyed by SQL text.

  ``tables`` maps a table name to its ``CacheTable``; every access to it is
  done while holding ``lock``.
  """

  def __init__(self, server):
    """Read the per-table entry limit from ``server.config['db']['cache']['maxcount']``."""
    self.server = server
    self.lock = threading.Lock()
    self.tables = dict()
    self.maxcount = self.server.config['db']['cache']['maxcount']

  def create_connection(self, connection):
    """Return a ``CacheConnection`` binding this cache to ``connection``."""
    return CacheConnection(self, connection)

  def clear(self):
    """Empty and forget every per-table cache."""
    with self.lock:
      for t in self.tables.values():
        t.clear()
      self.tables.clear()

  def _table(self, table):
    """Return the ``CacheTable`` for ``table``, creating it on first use.

    The caller must hold ``self.lock``.
    """
    tbl = self.tables.get(table)
    if tbl is None:
      self.tables[table] = tbl = CacheTable(self.maxcount)
    return tbl

  def build_select(self, connection, table, fields):
    """Build the SELECT statement that matches ``fields`` on ``table``.

    ``fields`` must be a dict mapping str field names to int or str values.
    Conditions are emitted in sorted field-name order, so dicts holding the
    same pairs produce the same statement regardless of insertion order;
    ``select`` and ``reset`` rely on that because the statement is the
    cache key.  An empty dict yields ``WHERE 1``.

    Returns whatever ``connection.parse`` returns for the template.
    Raises AssertionError for a value that is neither int nor str.
    """
    assert(type(fields) is dict)
    for name in fields:
      assert(type(name) is str)
    where = list()
    args = list()
    for name in sorted(fields):
      value = fields[name]
      if type(value) is int:
        where.append('%F=%d')
      elif type(value) is str:
        where.append('%F=%s')
      else:
        # Raised explicitly so the condition cannot be silently dropped when
        # assert statements are disabled (python -O).
        raise AssertionError('unsupported value type for field %r: %r' % (name, type(value)))
      args.append(name)
      args.append(value)
    condition = ' AND '.join(where) if where else '1'
    return connection.parse('SELECT * FROM %T WHERE ' + condition, table, *args)

  def select(self, connection, table, fields):
    """Return the rows matching ``fields``, querying only on a cache miss.

    On a miss the rows come from ``connection.query_dict`` and are stored
    in the cache.  The returned list is the same object that is cached.
    """
    assert(type(table) is str)
    sql = self.build_select(connection, table, fields)

    with self.lock:
      rows = self._table(table).get(sql)

    if rows is None:
      # the query itself runs without holding the lock
      rows = connection.query_dict(sql)
      assert(type(rows) is list)
      with self.lock:
        # look the table up again: clear() may have dropped it meanwhile
        self._table(table).set(sql, rows)
    return rows

  def reset(self, connection, table, fields = None):
    """Invalidate cached rows of ``table``.

    With ``fields`` omitted (``None``) every cached query of the table is
    dropped; otherwise only the entry for the statement that
    ``build_select`` produces for ``fields``.
    """
    assert(type(table) is str)
    sql = None
    if fields is not None:
      sql = self.build_select(connection, table, fields)
    with self.lock:
      tbl = self.tables.get(table)
      if tbl is None:
        return
      if sql is None:
        tbl.clear()
      else:
        tbl.unset(sql)
    

class CacheConnection:
  """Pairs a cache with one connection so callers need not pass the connection."""

  def __init__(self, cache, connection):
    """Remember the cache and the connection handed to it on every call."""
    self.connection = connection
    self.cache = cache

  def clear(self):
    """Forward to ``cache.clear()``."""
    self.cache.clear()

  def select(self, table, fields):
    """Return ``cache.select`` for ``table``/``fields`` using the bound connection."""
    rows = self.cache.select(self.connection, table, fields)
    return rows

  def reset(self, table, fields = None):
    """Forward to ``cache.reset`` using the bound connection."""
    self.cache.reset(self.connection, table, fields)

  def row(self, table, id):
    """Return the single row selected with ``{'id': id}``, or ``None`` if none.

    Asserts that the selection holds at most one row.
    """
    matches = self.select(table, {'id': id})
    if len(matches) != 0:
      assert(len(matches) == 1)
      return matches[0]
    return None

  def reset_row(self, table, id):
    """Reset the entry selected with ``{'id': id}``."""
    by_id = {'id': id}
    self.reset(table, by_id)