Blob Blame Raw

import datetime
import traceback
import MySQLdb.cursors


class Cursor:
  """Wrapper around a low-level DB cursor owned by a connection wrapper.

  The wrapper adds three things on top of the ``internal`` cursor:

  * ``execute`` renders the SQL template through ``connection.parse`` before
    handing the final text to the underlying cursor;
  * iteration converts every fetched value whose exact Python type has an
    entry in ``connection.typebytype`` by calling ``from_db(cursor, value)``
    on that entry;
  * the context-manager protocol is forwarded to the underlying cursor.
  """

  def __init__(self, connection, internal):
    """Store the owning ``connection`` wrapper and the ``internal`` cursor."""
    self.connection = connection
    self.internal = internal
    # Object returned by internal.__enter__() while inside a `with` block,
    # None otherwise.
    self.entered = None
    # Row iterator created by __iter__(); None until iteration is started.
    self.iterator = None

  def get_internal(self):
    """Return the entered cursor inside a ``with`` block, else the raw one."""
    return self.internal if self.entered is None else self.entered

  def __enter__(self):
    """Enter the underlying cursor's context and return this wrapper."""
    self.entered = self.internal.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Leave the underlying cursor's context.

    Returns None, so an exception raised inside the ``with`` body is never
    suppressed by this wrapper.
    """
    try:
      self.entered.__exit__(exc_type, exc_value, traceback)
    finally:
      # Reset even when the underlying __exit__ raises, so the wrapper does
      # not keep reporting itself as "entered" afterwards.
      self.entered = None

  def __iter__(self):
    """Start (or restart) iteration over the current underlying cursor."""
    self.iterator = iter(self.get_internal())
    return self

  def _converter_for(self, value):
    """Return the type handler registered for ``type(value)``, or None."""
    return self.connection.typebytype.get(type(value))

  def __next__(self):
    """Return the next row with registered value types converted.

    Dict rows (including dict subclasses) are updated in place and returned
    as the same object; any other row is returned as a new list.
    Raises StopIteration when the underlying iterator is exhausted.
    """
    row = next(self.iterator)
    if isinstance(row, dict):
      # Only existing keys are reassigned, so mutating while iterating
      # items() is safe here.
      for key, value in row.items():
        converter = self._converter_for(value)
        if converter:
          row[key] = converter.from_db(self, value)
      return row

    converted = []
    for value in row:
      converter = self._converter_for(value)
      converted.append(converter.from_db(self, value) if converter else value)
    return converted

  def execute(self, sql, *args, **kvargs):
    """Render the ``sql`` template with the given arguments and execute it.

    On failure the rendered query is printed for diagnosis and the original
    exception is re-raised unchanged.
    """
    sql = self.connection.parse(sql, *args, **kvargs)
    try:
      self.get_internal().execute(sql)
    except Exception:
      print('SQL Error in query:')
      print(sql)
      # Bare raise keeps the original traceback intact.
      raise


class Connection:
  """Transactional wrapper around a low-level MySQLdb connection.

  Constructing the object immediately starts a transaction (see ``begin``).
  SQL is written as a template: ``%X`` consumes the next positional argument,
  ``%(name)X`` consumes the keyword argument ``name``, where ``X`` is a
  one-character key into ``typebychar`` whose entry renders the value through
  ``to_db(connection, value)``; ``%%`` yields a literal percent sign.
  """

  def __init__(self, pool, internal, readonly = True):
    """Wrap ``internal`` and begin a read-only or read-write transaction."""
    self.pool = pool
    self.server = self.pool.server
    # Type handlers looked up by Python type (reading) and by template
    # character (writing).
    self.typebytype = self.server.dbtypebytype
    self.typebychar = self.server.dbtypebychar
    self.cache = self.server.dbcache.create_connection(self)
    # Not used inside this class; presumably assigned by the caller.
    self.request = None
    self.internal = internal
    self.readonly = readonly
    # True whenever no transaction is open; begin() requires it to be True.
    self.finished = True
    self.begin()

  def parse(self, text, *args, **kvargs):
    """Render the SQL template ``text`` into final SQL text.

    Raises Exception when the template ends in the middle of a placeholder.
    An unknown type character or a missing argument surfaces as the
    underlying KeyError / IndexError.
    """
    chars = iter(text)
    pieces = []  # joined once at the end instead of repeated string +=
    index = 0    # position of the next unused positional argument
    try:
      for c in chars:
        if c != '%':
          pieces.append(c)
          continue
        # A StopIteration from any next() below means the template was cut
        # off inside a placeholder; it is translated by the handler.
        c = next(chars)
        if c == '%':
          pieces.append('%')
        elif c == '(':
          name = []
          while True:
            c = next(chars)
            if c == ')': break
            name.append(c)
          c = next(chars)
          pieces.append(self.typebychar[c].to_db(self, kvargs[''.join(name)]))
        else:
          pieces.append(self.typebychar[c].to_db(self, args[index]))
          index += 1
    except StopIteration:
      raise Exception('unexpected end of sql template') from None
    return ''.join(pieces)

  def cursor(self, as_dict = False, sql = None, *args, **kvargs):
    """Create a cursor wrapper, optionally executing ``sql`` on it first.

    ``as_dict`` selects MySQLdb's DictCursor instead of the plain Cursor.
    Requires an open transaction.
    """
    assert not self.finished
    cursorclass = MySQLdb.cursors.DictCursor if as_dict else MySQLdb.cursors.Cursor
    cursor = Cursor(self, self.internal.cursor(cursorclass))
    if sql:
      try:
        cursor.execute(sql, *args, **kvargs)
      except Exception:
        # The caller never receives this cursor, so nobody else could close
        # it. A failure while closing must not hide the original error.
        try: cursor.internal.close()
        except Exception: print(traceback.format_exc())
        raise
    return cursor

  def cursor_list(self, sql = None, *args, **kvargs):
    """Shortcut for ``cursor(False, ...)``: rows come back as lists."""
    return self.cursor(False, sql, *args, **kvargs)

  def cursor_dict(self, sql = None, *args, **kvargs):
    """Shortcut for ``cursor(True, ...)``: rows come back as dicts."""
    return self.cursor(True, sql, *args, **kvargs)

  def query(self, as_dict, sql = None, *args, **kvargs):
    """Execute ``sql`` and return all converted rows as a list."""
    with self.cursor(as_dict, sql, *args, **kvargs) as cursor:
      return list(cursor)

  def query_list(self, sql = None, *args, **kvargs):
    """Shortcut for ``query(False, ...)``."""
    return self.query(False, sql, *args, **kvargs)

  def query_dict(self, sql = None, *args, **kvargs):
    """Shortcut for ``query(True, ...)``."""
    return self.query(True, sql, *args, **kvargs)

  def execute(self, sql = None, *args, **kvargs):
    """Execute ``sql`` and discard any result; returns None."""
    with self.cursor(False, sql, *args, **kvargs):
      return

  def insert_id(self):
    """Return ``insert_id()`` of the underlying connection."""
    assert not self.finished
    return self.internal.insert_id()

  def _as_text(self, value):
    """Decode ``bytes`` results of the escape helpers as UTF-8 text."""
    return value.decode("utf8") if isinstance(value, bytes) else value

  def escape(self, *args, **kwargs):
    """Call the underlying ``escape`` and return text instead of bytes."""
    return self._as_text(self.internal.escape(*args, **kwargs))

  def escape_string(self, *args, **kwargs):
    """Call the underlying ``escape_string`` and return text instead of bytes."""
    return self._as_text(self.internal.escape_string(*args, **kwargs))

  def begin(self):
    """Start a new transaction; the previous one must be finished.

    Read-only connections use REPEATABLE READ with a consistent snapshot,
    read-write connections use SERIALIZABLE. ``self.now`` is set to the
    current UTC time once the transaction has been started.
    """
    assert self.finished
    # Cleared first because execute() -> cursor() asserts an open transaction.
    self.finished = False
    self.execute("SET sql_mode='STRICT_TRANS_TABLES'")
    if self.readonly:
      self.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
      self.execute("START TRANSACTION READ ONLY, WITH CONSISTENT SNAPSHOT")
    else:
      self.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
      self.execute("START TRANSACTION READ WRITE")
    self.now = datetime.datetime.now(datetime.timezone.utc)

  def commit(self):
    """Commit the open transaction and mark the connection as finished."""
    assert not self.finished
    self.internal.commit()
    self.finished = True

  def rollback(self):
    """Roll back the open transaction and mark the connection as finished."""
    assert not self.finished
    self.internal.rollback()
    self.finished = True

  def release(self):
    """Roll back a still-open transaction, reporting but not raising errors."""
    if not self.finished:
      # Deliberately best-effort: a failing rollback is printed, not raised.
      try: self.rollback()
      except Exception as e:
        print(traceback.format_exc())
        print(e)