Blob Blame Raw

import datetime
import traceback
import MySQLdb.cursors

from db.parser import parse


class Cursor:
  """Wrapper around a driver cursor that converts column values to project types.

  Each value in a fetched row is looked up by its exact Python type
  (``type(value)``, so subclasses are not matched) in the server's
  ``dbtypebytype`` table. When an entry is found, its ``from_db(cursor, value)``
  hook produces the returned value. Dict rows are converted in place and
  returned; any other row (e.g. a tuple) is returned as a new list.
  """

  def __init__(self, connection, internal):
    self.connection = connection
    # Conversion tables shared by the server object reached through the pool:
    # one keyed by Python type (used when reading rows), one passed to
    # ``parse`` when building SQL (presumably keyed by placeholder character).
    self.typebytype = self.connection.pool.server.dbtypebytype
    self.typebychar = self.connection.pool.server.dbtypebychar
    self.internal = internal
    self.entered = None   # result of internal.__enter__() while inside ``with``
    self.iterator = None  # iterator over the current result set

  def get_internal(self):
    """Return the active driver cursor: the entered one inside ``with``, else ``internal``."""
    return self.internal if self.entered is None else self.entered

  def __enter__(self):
    self.entered = self.internal.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, tb):
    # ``tb`` rather than ``traceback`` so the module name is not shadowed.
    # ``entered`` is reset even if the underlying __exit__ raises.
    try:
      self.entered.__exit__(exc_type, exc_value, tb)
    finally:
      self.entered = None

  def __iter__(self):
    self.iterator = iter(self.get_internal())
    return self

  def _from_db(self, value):
    """Convert one column value via its registered type, or return it unchanged."""
    t = self.typebytype.get(type(value))
    return t.from_db(self, value) if t else value

  def __next__(self):
    """Fetch the next row and convert its values (see class docstring)."""
    if self.iterator is None:
      # Allow ``next(cursor)`` without an explicit ``iter(cursor)`` first.
      self.iterator = iter(self.get_internal())
    data = next(self.iterator)
    if type(data) is dict:
      # Reassigning existing keys while iterating is safe (size is unchanged).
      for k, v in data.items():
        data[k] = self._from_db(v)
      return data
    return [self._from_db(v) for v in data]

  def execute(self, sql = None, *args, **kvargs):
    """Execute ``sql`` on the active driver cursor.

    When positional or keyword arguments are given, the final SQL text is
    built by ``parse(self.typebychar, self.connection, sql, *args, **kvargs)``.
    Otherwise ``sql`` is sent as-is.

    Returns:
      Whatever the driver cursor's ``execute`` returns. For MySQLdb this is
      the number of affected rows.
    """
    if args or kvargs:
      sql = parse(self.typebychar, self.connection, sql, *args, **kvargs)
    return self.get_internal().execute(sql)


class Connection:
  """A database connection that always holds one open transaction until ended.

  A transaction is started on construction via ``begin``. ``commit`` or
  ``rollback`` ends it, and ``release`` rolls back anything still open.
  With ``readonly`` true, the transaction is a REPEATABLE READ, read-only
  transaction with a consistent snapshot. Otherwise it is a SERIALIZABLE,
  read-write transaction.
  """

  def __init__(self, pool, internal, readonly = True):
    self.pool = pool           # owning pool; Cursor reads pool.server's type tables
    self.internal = internal   # underlying driver connection (presumably MySQLdb)
    self.readonly = readonly
    self.finished = True       # True while no transaction is open
    self.begin()

  def cursor(self, as_dict = False):
    """Return a ``Cursor`` whose rows are dicts (``as_dict``) or sequences."""
    assert not self.finished
    cursorclass = MySQLdb.cursors.DictCursor if as_dict else MySQLdb.cursors.Cursor
    cursor = Cursor(self, self.internal.cursor(cursorclass))
    return cursor
  def cursor_list(self):
    return self.cursor(False)
  def cursor_dict(self):
    return self.cursor(True)

  def query(self, as_dict, sql = None, *args, **kvargs):
    """Execute ``sql`` and return all converted rows as a list."""
    with self.cursor(as_dict) as cursor:
      cursor.execute(sql, *args, **kvargs)
      return list(cursor)
  def query_list(self, sql = None, *args, **kvargs):
    return self.query(False, sql, *args, **kvargs)
  def query_dict(self, sql = None, *args, **kvargs):
    return self.query(True, sql, *args, **kvargs)

  def execute(self, sql = None, *args, **kvargs):
    """Execute ``sql`` on a short-lived cursor, discarding any result rows."""
    with self.cursor() as cursor:
      cursor.execute(sql, *args, **kvargs)

  def insert_id(self, *args, **kwargs):
    assert not self.finished
    return self.internal.insert_id(*args, **kwargs)

  @staticmethod
  def _to_str(value):
    """Decode ``bytes`` returned by the driver as UTF-8; pass anything else through."""
    return value.decode("utf8") if isinstance(value, bytes) else value

  def escape(self, *args, **kwargs):
    """Driver ``escape`` with the result returned as ``str``."""
    return self._to_str(self.internal.escape(*args, **kwargs))

  def escape_string(self, *args, **kwargs):
    """Driver ``escape_string`` with the result returned as ``str``."""
    return self._to_str(self.internal.escape_string(*args, **kwargs))

  def begin(self):
    """Start a new transaction and record its start time in ``self.now`` (UTC)."""
    assert self.finished
    # Set before executing, because ``cursor()`` requires an open transaction.
    self.finished = False
    if self.readonly:
      self.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
      self.execute("START TRANSACTION READ ONLY, WITH CONSISTENT SNAPSHOT")
    else:
      self.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
      # MySQL ignores WITH CONSISTENT SNAPSHOT, and emits a warning, at any
      # isolation level other than REPEATABLE READ, so it is not requested here.
      self.execute("START TRANSACTION READ WRITE")
    self.now = datetime.datetime.now(datetime.timezone.utc)

  def commit(self):
    """Commit the open transaction; on failure the transaction stays open."""
    assert not self.finished
    self.internal.commit()
    self.finished = True

  def rollback(self):
    """Roll back the open transaction; on failure the transaction stays open."""
    assert not self.finished
    self.internal.rollback()
    self.finished = True

  def release(self):
    """Best-effort rollback of any still-open transaction.

    A failing rollback is reported on stdout and swallowed, so releasing never
    raises. ``finished`` stays False in that case.
    """
    if not self.finished:
      try:
        self.rollback()
      except Exception:
        # format_exc() already contains the exception message; print it once.
        print(traceback.format_exc())