class Cursor:
    """Wrap a driver cursor: format SQL templates and convert fetched values.

    The type registries are read from ``connection.pool.server``:
    ``dbtypebytype`` maps a Python type to a converter that is applied to every
    fetched value of exactly that type, and ``dbtypebychar`` maps a template
    conversion character to a converter used by ``parse`` when building SQL.

    Parameters:
        connection: the owning connection wrapper; it is handed to the
            converters and to ``parse``.
        internal: the underlying driver cursor being wrapped.
    """

    def __init__(self, connection, internal):
        self.connection = connection
        server = connection.pool.server
        self.typebytype = server.dbtypebytype
        self.typebychar = server.dbtypebychar
        self.internal = internal
        # Object returned by internal.__enter__() while inside a `with` block.
        self.entered = None
        # Row iterator created by __iter__ (or lazily by __next__).
        self.iterator = None

    def get_internal(self):
        """Return the entered cursor inside a `with` block, else the raw one."""
        return self.internal if self.entered is None else self.entered

    def __enter__(self):
        self.entered = self.internal.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Leave the wrapped cursor's context and forward its verdict.

        The wrapped cursor's return value is passed through so that it keeps
        control over exception suppression, and ``entered`` is cleared even
        when the wrapped ``__exit__`` itself raises.
        """
        try:
            return self.entered.__exit__(exc_type, exc_value, exc_tb)
        finally:
            self.entered = None

    def __iter__(self):
        self.iterator = iter(self.get_internal())
        return self

    def __next__(self):
        """Return the next row with every value passed through its converter.

        Mapping rows are converted in place and returned as the same dict;
        any other row is returned as a new list.
        """
        if self.iterator is None:
            # next() was called without a preceding iter(); start iterating now
            # instead of failing on next(None).
            self.iterator = iter(self.get_internal())
        row = next(self.iterator)
        if isinstance(row, dict):
            # Only existing keys are reassigned, so the dict size never changes
            # while it is being iterated.
            for key, value in row.items():
                row[key] = self._from_db(value)
            return row
        return [self._from_db(value) for value in row]

    def _from_db(self, value):
        """Convert one fetched value using the converter for its exact type."""
        converter = self.typebytype.get(type(value))
        if converter is None:
            return value
        # Converters take the connection (not this cursor) as first argument.
        return converter.from_db(self.connection, value)

    def execute(self, sql = None, *args, **kvargs):
        """Execute ``sql`` and return whatever the wrapped cursor returns.

        When positional or keyword arguments are given, ``sql`` is treated as
        a template and expanded with ``parse`` first. Without arguments the
        text is sent verbatim, so ``%%`` is not collapsed in that case.
        """
        if args or kvargs:
            sql = parse(self.typebychar, self.connection, sql, *args, **kvargs)
        return self.get_internal().execute(sql)
self.finished - return self.internal.cursor(MySQLdb.cursors.DictCursor) + cursorclass = MySQLdb.cursors.DictCursor if as_dict else MySQLdb.cursors.Cursor + cursor = Cursor(self, self.internal.cursor(cursorclass)) + return cursor + def cursor_list(self): + return self.cursor(False) + def cursor_dict(self): + return self.cursor(True) + + def query(self, as_dict, sql = None, *args, **kvargs): + with self.cursor(as_dict) as cursor: + cursor.execute(sql, *args, **kvargs) + return list(cursor) + def query_list(self, sql = None, *args, **kvargs): + return self.query(False, sql, *args, **kvargs) + def query_dict(self, sql = None, *args, **kvargs): + return self.query(True, sql, *args, **kvargs) + + def execute(self, sql = None, *args, **kvargs): + with self.cursor() as cursor: + cursor.execute(sql, *args, **kvargs) def insert_id(self, *args, **kwargs): assert not self.finished @@ -31,19 +96,13 @@ class Connection: def begin(self): assert self.finished self.finished = False - with self.cursor() as cursor: - cursor.execute("SET autocommit=0") if self.readonly: - with self.cursor() as cursor: - cursor.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ") - with self.cursor() as cursor: - cursor.execute("START TRANSACTION READ ONLY, WITH CONSISTENT SNAPSHOT") + self.execute("SET TRANSACTION ISOLATION LEVEL REPEATABLE READ") + self.execute("START TRANSACTION READ ONLY, WITH CONSISTENT SNAPSHOT") else: - with self.cursor() as cursor: - cursor.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") - with self.cursor() as cursor: - cursor.execute("START TRANSACTION READ WRITE") - self.now = datetime.datetime.now(self.pool.server.config['timezone']) + self.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") + self.execute("START TRANSACTION READ WRITE, WITH CONSISTENT SNAPSHOT") + self.now = datetime.datetime.now(datetime.timezone.utc) def commit(self): assert not self.finished diff --git a/db/parser.py b/db/parser.py new file mode 100644 index 0000000..3d4fbd9 --- 
# ---------------------------------------------------------------------------
# db/parser.py
# ---------------------------------------------------------------------------


def parse(typebychar, connection, text, *args, **kvargs):
    """Expand a printf-like SQL template into a complete SQL string.

    Recognised sequences in ``text``:
        ``%%``        a literal percent sign
        ``%c``        next positional argument, rendered by ``typebychar[c]``
        ``%(name)c``  keyword argument ``name``, rendered by ``typebychar[c]``

    Every value is rendered with ``typebychar[c].to_db(connection, value)``.

    Raises:
        ValueError: the template ends in the middle of a ``%`` sequence.
        KeyError: unknown conversion character or missing keyword argument.
        IndexError: fewer positional arguments than positional placeholders.
    """
    chars = iter(text)
    pieces = []
    index = 0
    try:
        for char in chars:
            if char != '%':
                pieces.append(char)
                continue
            # A bare next() here raises StopIteration on a truncated template;
            # the surrounding try turns that into a descriptive error.
            code = next(chars)
            if code == '%':
                pieces.append('%')
            elif code == '(':
                name = ''
                while True:
                    char = next(chars)
                    if char == ')':
                        break
                    name += char
                code = next(chars)
                pieces.append(typebychar[code].to_db(connection, kvargs[name]))
            else:
                pieces.append(typebychar[code].to_db(connection, args[index]))
                index += 1
    except StopIteration:
        raise ValueError('unexpected end of sql template') from None
    return ''.join(pieces)


# ---------------------------------------------------------------------------
# db/types.py
# ---------------------------------------------------------------------------

import datetime


dateformat = '%Y-%m-%dT%H:%M:%S'
datepattern = '0000-00-00T00:00:00'


def date_to_str(date):
    """Format ``date`` with ``dateformat``, left-padded with zeros to full width."""
    return date.strftime(dateformat).rjust(len(datepattern), '0')


def set_timezone(date, tz = datetime.timezone.utc):
    """Return a copy of ``date`` whose tzinfo is ``tz`` (fields are not shifted)."""
    return datetime.datetime(
        year = date.year,
        month = date.month,
        day = date.day,
        hour = date.hour,
        minute = date.minute,
        second = date.second,
        microsecond = date.microsecond,
        tzinfo = tz )


def str_to_date(s):
    """Parse a ``dateformat`` string and tag the result as UTC."""
    return set_timezone( datetime.datetime.strptime(s, dateformat) )


def _identifier(value):
    """Return ``str(value)`` if it is a valid identifier, else raise ValueError.

    An explicit raise is used instead of ``assert`` because these names are
    placed into SQL text and assertions are removed when Python runs with -O.
    """
    result = str(value)
    if not result.isidentifier():
        raise ValueError('invalid sql identifier: %r' % (result,))
    return result


class Type:
    """Base converter between raw values, database values and SQL text."""

    def from_raw(self, value):
        """Normalise an arbitrary input value; the base type yields None."""
        return None

    def from_db(self, connection, value):
        """Convert a value fetched from the database; identity by default."""
        return value

    def to_db(self, connection, value):
        """Render ``value`` as SQL text via ``connection.escape``.

        NOTE(review): ``connection`` must provide ``escape()``; its definition
        is not visible here - confirm it returns ``str``.
        """
        return connection.escape(self.from_raw(value))


class Int(Type):
    def from_raw(self, value):
        # Falsy input (None, '', 0) becomes 0.
        return int(value if value else 0)


class String(Type):
    def from_raw(self, value):
        return str(value)

    def to_db(self, connection, value):
        # NOTE(review): assumes connection.escape_string() returns str - confirm.
        return "'" + connection.escape_string(self.from_raw(value)) + "'"


class Float(Type):
    def from_raw(self, value):
        # Falsy input (None, '', 0) becomes 0.0.
        return float(value if value else 0)


class Date(Type):
    def from_raw(self, value):
        """Normalise a datetime or ``dateformat`` string to a UTC datetime.

        A datetime is round-tripped through ``dateformat``, which drops its
        microseconds and replaces its tzinfo with UTC.
        """
        text = date_to_str(value) if isinstance(value, datetime.datetime) else str(value)
        return str_to_date(text)

    def from_db(self, connection, value):
        return set_timezone(value)


class Field(Type):
    """Column name, rendered inside backticks."""

    def from_raw(self, value):
        return _identifier(value)

    def to_db(self, connection, value):
        return '`' + connection.escape_string(self.from_raw(value)) + '`'


class Table(Type):
    """Table name; the configured ``db.prefix`` is added or removed."""

    def from_raw(self, value):
        return _identifier(value)

    def from_db(self, connection, value):
        """Strip the configured prefix from a table name read from the database."""
        value = str(value)
        prefix = connection.pool.server.config['db']['prefix']
        if not value.startswith(prefix):
            raise ValueError('table name %r lacks prefix %r' % (value, prefix))
        return self.from_raw(value[len(prefix):])

    def to_db(self, connection, value):
        prefix = connection.pool.server.config['db']['prefix']
        return '`' + connection.escape_string(prefix + self.from_raw(value)) + '`'


tint = Int()
tstring = String()
tfloat = Float()
tdate = Date()
tfield = Field()
ttable = Table()


# Converter applied to fetched values, keyed by their exact Python type.
bytype = {
    int : tint,
    str : tstring,
    float : tfloat,
    datetime.datetime: tdate,
}

# Converter used for each template conversion character understood by parse().
bychar = {
    'd': tint,
    's': tstring,
    'f': tfloat,
    'D': tdate,
    'F': tfield,
    'T': ttable,
}
' + request.t("Hello World!") + '
' \ - + '' + "Env:\n" + str(env) + '
' + + content = request.t("Hello World!") + + content += '' + "Env:\n" + str(env) + '
' - tables = [] - with request.connection.cursor() as cursor: - cursor.execute('SHOW TABLES') - for row in cursor: - tables.append(str(list(row.values())[0])) - content += 'DB tables: ' + ', '.join(tables) + '
' + tables = list(v[0] for v in request.connection.query_list('SHOW TABLES')) + content += 'DB tables: ' + ', '.join(tables) + '
' + + rows = request.connection.query_dict('SELECT * FROM %T', 'test') + content += 'Rows of test table: ' + str(rows) + '
' return request.complete_content(content) diff --git a/server.py b/server.py index 9167496..50a2e0f 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,6 @@ import datetime +from db.types import bytype, bychar from db.pool import Pool @@ -13,10 +14,10 @@ class Server: self.config = { 'urlprefix' : urlprefix, 'urldataprefix' : str(config.get('urldataprefix', urlprefix + '/data')), - 'timezone' : config.get('timezone', datetime.timezone.utc), 'db' : { 'connection' : dict(config_db.get('connection', dict())), + 'prefix' : str(config_db.get('prefix', '')), 'retrytime' : float(config_db.get('retrytime', 0)), 'pool': { 'read' : int(config_db_pool.get('read' , 10)), @@ -25,10 +26,14 @@ class Server: }, } - assert type(self.config['timezone']) is datetime.timezone assert self.config['db']['retrytime'] >= 0 + assert self.config['db']['prefix'] == '' \ + or self.config['db']['prefix'].isidentifier() assert self.config['db']['pool']['read'] > 0 assert self.config['db']['pool']['write'] > 0 + self.dbtypebytype = bytype + self.dbtypebychar = bychar self.dbpool = Pool(self) +