Blame db/connection.py

cbf076
cbf076
import datetime
cbf076
import traceback
4656ad
import MySQLdb.cursors
4656ad
4656ad
4656ad
class Cursor:
  """Wrapper around a driver cursor that renders SQL and converts results.

  SQL given to ``execute`` is rendered through ``connection.parse`` before it
  reaches the driver. Rows produced by iteration have each value passed
  through ``connection.typebytype``, a mapping from a value's exact Python
  type to a converter whose ``from_db(cursor, value)`` result replaces it.

  The wrapper is a context manager: while inside a ``with`` block, the object
  returned by the wrapped cursor's ``__enter__`` is used instead of the
  wrapped cursor itself, and leaving the block calls its ``__exit__``.
  """

  def __init__(self, connection, internal):
    # Owning connection: supplies parse() and the typebytype converters.
    self.connection = connection
    # The wrapped driver cursor.
    self.internal = internal
    # Value returned by internal.__enter__() while inside ``with``, else None.
    self.entered = None
    # Row iterator created by __iter__.
    self.iterator = None

  def get_internal(self):
    """Return the driver cursor in use: the entered one if any, else internal."""
    return self.internal if self.entered is None else self.entered

  def __enter__(self):
    self.entered = self.internal.__enter__()
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Forget the entered cursor even if its __exit__ raises, so
    # get_internal() never keeps handing out a cursor that was being closed.
    try:
      self.entered.__exit__(exc_type, exc_value, traceback)
    finally:
      self.entered = None

  def __iter__(self):
    self.iterator = iter(self.get_internal())
    return self

  def _convert(self, value):
    """Return ``value`` run through its type's from_db converter, if any."""
    converter = self.connection.typebytype.get(type(value))
    return converter.from_db(self, value) if converter else value

  def __next__(self):
    """Fetch the next row and convert its values.

    Dict rows are converted in place and returned; any other row is returned
    as a new list of converted values. ``__iter__`` must have been called
    first, since it creates ``self.iterator``.
    """
    row = next(self.iterator)
    if type(row) is dict:
      # Only existing keys are reassigned, so the dict is never resized
      # while iterating over its items.
      for key, value in row.items():
        row[key] = self._convert(value)
      return row
    return [self._convert(value) for value in row]

  def execute(self, sql, *args, **kvargs):
    """Render ``sql`` via ``connection.parse(sql, *args, **kvargs)`` and run it.

    If the driver raises, the rendered SQL is printed and the original
    exception is re-raised with its traceback intact.
    """
    sql = self.connection.parse(sql, *args, **kvargs)
    try:
      self.get_internal().execute(sql)
    except Exception:
      print('SQL Error in query:')
      print(sql)
      raise
cbf076
cbf076
cbf076
class Connection:
cbf076
  def __init__(self, pool, internal, readonly = True):
cbf076
    self.pool = pool
b838e2
    self.server = self.pool.server
b838e2
    self.typebytype = self.server.dbtypebytype
b838e2
    self.typebychar = self.server.dbtypebychar
b838e2
    self.cache = self.server.dbcache.create_connection(self)
05525d
    self.request = None
cbf076
    self.internal = internal
cbf076
    self.readonly = readonly
cbf076
    self.finished = True
cbf076
    self.begin()
e5b0ac
    self.on_commit = list()
e5b0ac
    self.on_rollback = list()
cbf076
b838e2
  def parse(self, text, *args, **kvargs):
b838e2
    i = iter(text)
b838e2
    index = 0
b838e2
    result = ''
b838e2
    try:
b838e2
      while True:
b838e2
        try: c = next(i)
b838e2
        except StopIteration: break
b838e2
        if c == '%':
b838e2
          c = next(i)
b838e2
          if c == '(':
b838e2
            field = ''
b838e2
            while True:
b838e2
              c = next(i)
b838e2
              if c == ')': break
b838e2
              field += c
b838e2
            c = next(i)
b838e2
            result += self.typebychar[c].to_db(self, kvargs[field])
b838e2
          elif c == '%':
b838e2
            result += '%'
b838e2
          else:
b838e2
            result += self.typebychar[c].to_db(self, args[index])
b838e2
            index += 1
b838e2
        else:
b838e2
          result += c
b838e2
    except StopIteration:
b838e2
      raise Exception('unexpeted end of sql template')
b838e2
    return result
b838e2
05525d
  def cursor(self, as_dict = False, sql = None, *args, **kvargs):
    """Open a new wrapped cursor inside the current transaction.

    Args:
      as_dict: use ``MySQLdb.cursors.DictCursor`` (rows as dicts) when true,
        otherwise ``MySQLdb.cursors.Cursor``.
      sql: if truthy, executed immediately via ``Cursor.execute`` with
        ``args`` / ``kvargs`` as template parameters.

    Returns:
      The ``Cursor`` wrapper. If the immediate execute fails, the driver
      cursor is closed before the exception propagates, because the caller
      never receives a cursor it could close itself.
    """
    assert not self.finished
    cursorclass = MySQLdb.cursors.DictCursor if as_dict else MySQLdb.cursors.Cursor
    internal = self.internal.cursor(cursorclass)
    cursor = Cursor(self, internal)
    if sql:
      try:
        cursor.execute(sql, *args, **kvargs)
      except BaseException:
        internal.close()
        raise
    return cursor
05525d
  def cursor_list(self, sql = None, *args, **kvargs):
05525d
    return self.cursor(False, sql, *args, **kvargs)
05525d
  def cursor_dict(self, sql = None, *args, **kvargs):
05525d
    return self.cursor(True, sql, *args, **kvargs)
4656ad
  
4656ad
  def query(self, as_dict, sql = None, *args, **kvargs):
05525d
    with self.cursor(as_dict, sql, *args, **kvargs) as cursor:
4656ad
      return list(cursor)
4656ad
  def query_list(self, sql = None, *args, **kvargs):
4656ad
    return self.query(False, sql, *args, **kvargs)
4656ad
  def query_dict(self, sql = None, *args, **kvargs):
4656ad
    return self.query(True, sql, *args, **kvargs)
4656ad
4656ad
  def execute(self, sql = None, *args, **kvargs):
05525d
    with self.cursor(False, sql, *args, **kvargs):
05525d
      return
cbf076
05525d
  def insert_id(self):
cbf076
    assert not self.finished
05525d
    return self.internal.insert_id()
cbf076
  
cbf076
  def escape(self, *args, **kwargs):
cbf076
    r = self.internal.escape(*args, **kwargs)
cbf076
    return r.decode("utf8") if type(r) is bytes else r
cbf076
cbf076
  def escape_string(self, *args, **kwargs):
cbf076
    r = self.internal.escape_string(*args, **kwargs)
cbf076
    return r.decode("utf8") if type(r) is bytes else r
cbf076
  
cbf076
  def begin(self):
cbf076
    assert self.finished
cbf076
    self.finished = False
b838e2
    self.execute("SET sql_mode='STRICT_TRANS_TABLES'")
cbf076
    if self.readonly:
4656ad
      self.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ")
4656ad
      self.execute("START TRANSACTION READ ONLY, WITH CONSISTENT SNAPSHOT")
cbf076
    else:
4656ad
      self.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
b838e2
      self.execute("START TRANSACTION READ WRITE")
4656ad
    self.now = datetime.datetime.now(datetime.timezone.utc)
cbf076
e5b0ac
  def process_events(self, events, skip_errors = False):
e5b0ac
    while events:
e5b0ac
      events_copy = list(events)
e5b0ac
      events.clear()
e5b0ac
      for event in events_copy:
e5b0ac
        try:
e5b0ac
          event[0](*event[1], *event[2])
e5b0ac
        except Exception as e:
e5b0ac
          print("exception in event")
e5b0ac
          print(traceback.format_exc())
e5b0ac
          print(e)
e5b0ac
          if not skip_errors:
e5b0ac
            raise e
e5b0ac
e5b0ac
  def call_on_commit(self, function, *args, **kvargs):
e5b0ac
    self.on_commit.append((function, args, kvargs))
e5b0ac
  def call_on_rollback(self, function, *args, **kvargs):
e5b0ac
    self.on_rollback.append((function, args, kvargs))
e5b0ac
cbf076
  def commit(self):
cbf076
    assert not self.finished
e5b0ac
    self.process_events(self.on_commit)
e5b0ac
    self.on_commit.clear()
e5b0ac
    self.on_rollback.clear()
cbf076
    self.internal.commit()
cbf076
    self.finished = True
cbf076
  
cbf076
  def rollback(self):
cbf076
    assert not self.finished
e5b0ac
    self.process_events(self.on_rollback, skip_errors = True)
e5b0ac
    self.on_commit.clear()
e5b0ac
    self.on_rollback.clear()
cbf076
    self.internal.rollback()
cbf076
    self.finished = True
cbf076
      
cbf076
  def release(self):
cbf076
    if not self.finished:
cbf076
      try: self.rollback()
cbf076
      except Exception as e:
cbf076
        print(traceback.format_exc())
cbf076
        print(e)