myrtille 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
myrtille/lib/cfg.py ADDED
@@ -0,0 +1,12 @@
+ import datetime
+ import pydantic
+
+
+ class Database(pydantic.BaseModel):
+     user: str
+     password: str
+     host: str
+     port: int
+     echo: bool | None = None
+     pool_size: int | None = None
+     timeout: datetime.timedelta | None = None
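
For reference, this config is an ordinary pydantic model, so it can be constructed directly or loaded from application settings. A minimal sketch (all values are illustrative, not taken from the package):

    import datetime

    from myrtille.lib import cfg

    config = cfg.Database(
        user='app',
        password='secret',   # illustrative credentials
        host='127.0.0.1',
        port=3306,
        echo=True,
        timeout=datetime.timedelta(seconds=5),
    )
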
myrtille/lib/db.py ADDED
@@ -0,0 +1,120 @@
+ import contextlib
+ import time
+ import typing
+
+ import asyncmy
+ import pydantic
+
+ from myrtille.lib import cfg
+
+ ParamsType: typing.TypeAlias = typing.Collection[typing.Any]
+ Queryable: typing.TypeAlias = 'Database | Connection'
+
+
+ def _format_request(stmt: str, *, params: ParamsType | None = None):
+     # Best-effort interpolation of params into the statement, used only for
+     # log and error messages (never for what is sent to the server).
+     if params is not None:
+         try:
+             stmt = stmt % tuple(params)
+         except Exception:
+             if len(params) == 0:
+                 param_part = 'no params'
+             else:
+                 param_part = f'params {", ".join(map(repr, params))}'
+             stmt = f"Invalid stmt '{stmt}' with {param_part}"
+     return f"Request '{' '.join(stmt.split())}'"
+
+
+ class Database(pydantic.BaseModel):
+     model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
+
+     config: cfg.Database
+     pool: asyncmy.pool.Pool = pydantic.Field(exclude=True)
+
+     @contextlib.asynccontextmanager
+     async def acquire(self):
+         async with self.pool.acquire() as cnx:
+             yield Connection(database=self, cnx=cnx)
+             # Discard anything the caller did not explicitly commit.
+             await cnx.rollback()
+
+     async def execute(self, stmt: str, *, params: ParamsType | None = None):
+         async with self.acquire() as cnx:
+             await cnx.execute(stmt, params=params)
+             await cnx.commit()
+
+     async def fetch_all(self, stmt: str, *, params: typing.Sequence[typing.Any] | None = None):
+         async with self.acquire() as cnx:
+             return await cnx.fetch_all(stmt, params=params)
+
+     async def fetch_optional(self, stmt: str, *, params: typing.Sequence[typing.Any] | None = None):
+         async with self.acquire() as cnx:
+             return await cnx.fetch_optional(stmt, params=params)
+
+
+ class Connection(pydantic.BaseModel):
+     model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
+
+     database: Database
+     cnx: asyncmy.connection.Connection
+
+     async def execute(self, stmt: str, *, params: ParamsType | None = None):
+         t0 = time.perf_counter()
+         try:
+             async with self.cnx.cursor() as cursor:
+                 await cursor.execute(stmt, params)
+             if self.database.config.echo:
+                 print(
+                     f'Log: {_format_request(stmt, params=params)}: {time.perf_counter() - t0:2.2f}s'
+                 )
+         except Exception as e:
+             e.add_note(f'In {_format_request(stmt, params=params)}')
+             raise
+
+     async def fetch_all(self, stmt: str, params: typing.Sequence[typing.Any] | None = None):
+         t0 = time.perf_counter()
+         try:
+             async with self.cnx.cursor() as cursor:
+                 await cursor.execute(stmt, params)
+                 rows = await cursor.fetchall()
+             if self.database.config.echo:
+                 print(
+                     f'Log: {_format_request(stmt, params=params)}: {time.perf_counter() - t0:2.2f}s'
+                 )
+             return rows
+         except Exception as e:
+             e.add_note(f'In {_format_request(stmt, params=params)}')
+             raise
+
+     async def fetch_optional(self, stmt: str, params: typing.Sequence[typing.Any] | None = None):
+         rows = await self.fetch_all(stmt, params)
+         if len(rows) > 1:
+             raise Exception(
+                 f'{_format_request(stmt, params=params)} returned {len(rows)} (!= 1) rows'
+             )
+         elif len(rows) == 0:
+             return None
+         return rows[0]
+
+     async def commit(self):
+         await self.cnx.commit()
+
+
+ @contextlib.asynccontextmanager
+ async def make_database(db_config: cfg.Database):
+     async with asyncmy.pool.create_pool(
+         host=db_config.host,
+         port=db_config.port,
+         user=db_config.user,
+         password=db_config.password,
+         autocommit=False,
+         echo=db_config.echo or False,
+         # Default to one year in seconds, i.e. effectively no timeout.
+         connect_timeout=db_config.timeout.total_seconds()
+         if db_config.timeout is not None
+         else 31536000,
+         minsize=db_config.pool_size or 1,
+         maxsize=db_config.pool_size or 1,
+     ) as pool:
+         try:
+             yield Database(config=db_config, pool=pool)
+         finally:
+             pool.close()
+             await pool.wait_closed()
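
A hypothetical caller of this module, assuming a reachable MySQL server and illustrative credentials; make_database owns the pool's lifetime, and the Database methods handle commit/rollback as shown above:

    import asyncio
    import datetime

    from myrtille.lib import cfg, db

    async def main():
        # Illustrative config; not taken from the package.
        config = cfg.Database(
            user='app', password='secret', host='127.0.0.1', port=3306,
            echo=True, timeout=datetime.timedelta(seconds=5),
        )
        async with db.make_database(config) as database:
            await database.execute('CREATE TABLE IF NOT EXISTS t (id int)')
            rows = await database.fetch_all('SELECT id FROM t WHERE id > %s', params=[0])
            row = await database.fetch_optional('SELECT id FROM t WHERE id = %s', params=[1])
            print(rows, row)  # row is None when nothing matches; >1 matches raises

    asyncio.run(main())
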
myrtille/lib/util.py ADDED
@@ -0,0 +1,9 @@
+ import typing
+
+
+ def as_any(v: typing.Any, /):
+     # Identity function that erases the static type to Any.
+     return v
+
+
+ def snake_to_pascal(s: str) -> str:
+     return ''.join(word.capitalize() for word in s.split('_'))
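
Both helpers are pure. Note that str.capitalize() lowercases the tail of each word, so the conversion is lossy for mixed-case input:

    from myrtille.lib import util

    assert util.snake_to_pascal('created_at') == 'CreatedAt'
    assert util.snake_to_pascal('user_ID') == 'UserId'  # tail is lowercased
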
@@ -0,0 +1,59 @@
+ import asyncio
+ import re
+ import typing
+
+ import pydantic
+
+ from myrtille.lib import cfg, db
+
+
+ def _correct_ddl(ddl: str):
+     # Removes display width on integer types
+     # Removes floating point precision on time functions
+     for s in [
+         'tinyint',
+         'smallint',
+         'int',
+         'bigint',
+         'DEFAULT CURRENT_TIMESTAMP',
+         'ON UPDATE CURRENT_TIMESTAMP',
+     ]:
+         ddl = re.sub(rf' {s}\([0-9]*\)', f' {s}', ddl, flags=re.IGNORECASE)
+
+     # Removes non-standard float precision
+     for s in ['float', 'double']:
+         ddl = re.sub(rf' {s}\([0-9]*,[0-9]*\)', f' {s}', ddl, flags=re.IGNORECASE)
+
+     return ddl
+
+
+ class _Table(pydantic.BaseModel):
+     schema_name: str
+     table_name: str
+
+     async def get_ddl(self, database: db.Database) -> str:
+         show_create_response = await database.fetch_optional(
+             f'SHOW CREATE TABLE `{self.schema_name}`.`{self.table_name}`'
+         )
+         assert show_create_response is not None
+         (_, ddl) = show_create_response
+         assert isinstance(ddl, str)
+         return _correct_ddl(ddl)
+
+
+ async def export(schema_name: str, config: cfg.Database) -> typing.Sequence[str]:
+     async with db.make_database(config) as database:
+         return await asyncio.gather(
+             *(
+                 _Table(schema_name=schema_name, table_name=table_name).get_ddl(database)
+                 for (table_name,) in await database.fetch_all(
+                     """
+                     SELECT TABLE_NAME
+                     FROM INFORMATION_SCHEMA.TABLES
+                     WHERE TABLE_TYPE = "BASE TABLE"
+                     AND TABLE_SCHEMA = %s
+                     """,
+                     params=[schema_name],
+                 )
+             )
+         )
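
export() opens and closes its own pool, so a caller only needs a schema name and a config. A sketch (the import path is a guess, since this file's name is missing from the diff):

    import asyncio

    from myrtille.lib import cfg
    from myrtille.mysql import export  # hypothetical path; the diff omits the file name

    async def dump(schema: str):
        config = cfg.Database(user='app', password='secret', host='127.0.0.1', port=3306)
        for ddl in await export.export(schema, config):
            print(ddl + ';\n')

    asyncio.run(dump('my_schema'))
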
@@ -0,0 +1,353 @@
+ from myrtille.mysql import types
+
+
+ def generate_ternary_option(option: bool):
+     return '1' if option else '0'
+
+
+ def generate_ref_action(action: types.RefAction):
+     match action:
+         case types.RefAction.RESTRICT:
+             return 'RESTRICT'
+         case types.RefAction.CASCADE:
+             return 'CASCADE'
+         case types.RefAction.SET_NULL:
+             return 'SET NULL'
+
+
+ def generate_key_part(key: types.KeyPart):
+     res = f'`{key.identifier}`'
+     if key.length is not None:
+         res += f'({key.length})'
+     if key.direction is not None:
+         res += f' {key.direction.name}'
+     return res
+
+
+ def generate_data_type(data_type: types.DataType):
+     arguments: list[str] = []
+
+     if (
+         isinstance(data_type, types.Datetime | types.Timestamp | types.Time)
+         and data_type.precision is not None
+     ):
+         arguments = [f'{data_type.precision}']
+
+     if isinstance(
+         data_type, types.Bit | types.Char | types.Varchar | types.Binary | types.Varbinary
+     ):
+         arguments = [f'{data_type.length}']
+
+     if isinstance(data_type, types.Set | types.Enum):
+         arguments = [f"'{v}'" for v in data_type.values]
+
+     match data_type:
+         case types.Tinyint():
+             data_types_name = 'tinyint'
+         case types.Smallint():
+             data_types_name = 'smallint'
+         case types.Mediumint():
+             data_types_name = 'mediumint'
+         case types.Int():
+             data_types_name = 'int'
+         case types.Bigint():
+             data_types_name = 'bigint'
+         case types.Decimal():
+             data_types_name = 'decimal'
+             if data_type.precision is not None:
+                 arguments = [f'{data_type.precision}', f'{data_type.scale or 0}']
+         case types.Float():
+             data_types_name = 'float'
+         case types.Double():
+             data_types_name = 'double'
+         case types.Bit():
+             data_types_name = 'bit'
+         case types.Datetime():
+             data_types_name = 'datetime'
+         case types.Timestamp():
+             data_types_name = 'timestamp'
+         case types.Time():
+             data_types_name = 'time'
+         case types.Date():
+             data_types_name = 'date'
+         case types.Year():
+             data_types_name = 'year'
+         case types.Char():
+             data_types_name = 'char'
+         case types.Varchar():
+             data_types_name = 'varchar'
+         case types.TinyText():
+             data_types_name = 'tinytext'
+         case types.Text():
+             data_types_name = 'text'
+         case types.MediumText():
+             data_types_name = 'mediumtext'
+         case types.LongText():
+             data_types_name = 'longtext'
+         case types.Enum():
+             data_types_name = 'enum'
+         case types.Set():
+             data_types_name = 'set'
+         case types.Binary():
+             data_types_name = 'binary'
+         case types.Varbinary():
+             data_types_name = 'varbinary'
+         case types.TinyBlob():
+             data_types_name = 'tinyblob'
+         case types.Blob():
+             data_types_name = 'blob'
+         case types.MediumBlob():
+             data_types_name = 'mediumblob'
+         case types.LongBlob():
+             data_types_name = 'longblob'
+         case types.Json():
+             data_types_name = 'json'
+         case types.Geometry():
+             data_types_name = 'geometry'
+         case types.Point():
+             data_types_name = 'point'
+         case types.Linestring():
+             data_types_name = 'linestring'
+         case types.Polygon():
+             data_types_name = 'polygon'
+         case types.Geometrycollection():
+             data_types_name = 'geomcollection'
+         case types.Multipoint():
+             data_types_name = 'multipoint'
+         case types.Multilinestring():
+             data_types_name = 'multilinestring'
+         case types.Multipolygon():
+             data_types_name = 'multipolygon'
+     return data_types_name + (f'({",".join(arguments)})' if len(arguments) > 0 else '')
+
+
+ def generate_literal(literal: types.Literal):
+     match literal:
+         case types.TextLiteral():
+             return f"'{literal.text}'"
+         case types.NullLiteral():
+             return 'NULL'
+
+
+ def generate_default(default_value: types.DefaultValue):
+     match default_value:
+         case types.ExprDefaultAttribute():
+             return f'({default_value.expr})'
+         case types.LiteralDefaultAttribute():
+             return generate_literal(default_value.value)
+
+
+ def generate_column(column: types.Column):
+     attributes: list[str] = []
+
+     if isinstance(column.data_type, types.IntegerDataType) and column.data_type.unsigned:
+         attributes.append('unsigned')
+
+     if isinstance(column.data_type, types.TextDataType):
+         if column.data_type.charset is not None:
+             attributes.append(f'CHARACTER SET {column.data_type.charset}')
+         if column.data_type.collate is not None:
+             attributes.append(f'COLLATE {column.data_type.collate}')
+
+     # /*!NNNNN ... */ is a MySQL versioned comment: servers at or above that
+     # version execute the body, older servers ignore it.
+     if column.format != types.ColumnFormat.DEFAULT:
+         attributes.append(f'/*!50606 COLUMN_FORMAT {column.format.name} */')
+
+     if column.storage_media != types.StorageMedia.DEFAULT:
+         attributes.append(f'/*!50606 STORAGE {column.storage_media.name} */')
+
+     if column.generated is not None:
+         attributes.append(
+             f'GENERATED ALWAYS AS ({column.generated.expr}) {column.generated.type.name}'
+         )
+
+     if column.non_nullable is not None:
+         attributes.append(('NOT ' if column.non_nullable else '') + 'NULL')
+
+     if isinstance(column.data_type, types.SpatialDataType) and column.data_type.srid is not None:
+         attributes.append(f'/*!80003 SRID {column.data_type.srid} */')
+
+     if isinstance(column.data_type, types.IntegerDataType) and column.data_type.auto_increment:
+         attributes.append('AUTO_INCREMENT')
+
+     if column.default_value is not None:
+         attributes.append(f'DEFAULT {generate_default(column.default_value)}')
+
+     if isinstance(column.data_type, types.Datetime | types.Timestamp):
+         if column.data_type.default_now:
+             attributes.append('DEFAULT CURRENT_TIMESTAMP')
+         if column.data_type.on_update_now:
+             attributes.append('ON UPDATE CURRENT_TIMESTAMP')
+
+     if column.comment is not None:
+         attributes.append(f"COMMENT '{column.comment}'")
+     if column.invisible:
+         attributes.append('/*!80023 INVISIBLE */')
+
+     return f'`{column.name}` {generate_data_type(column.data_type)}' + ''.join(
+         f' {a}' for a in attributes
+     )
+
+
+ def generate_constraint(constraint: types.Constraint):
+     match constraint:
+         case types.ForeignConstraint():
+             col_names = f'({", ".join(f"`{c}`" for c in constraint.columns)})'
+             ref_col_names = f'({", ".join(f"`{c}`" for c in constraint.references.ref_columns)})'
+             parts = [
+                 f'CONSTRAINT `{constraint.name}` FOREIGN KEY {col_names}',
+                 f'REFERENCES `{constraint.references.ref_table}` {ref_col_names}',
+             ]
+             if constraint.references.on_delete is not None:
+                 parts.append(f'ON DELETE {generate_ref_action(constraint.references.on_delete)}')
+             if constraint.references.on_update is not None:
+                 parts.append(f'ON UPDATE {generate_ref_action(constraint.references.on_update)}')
+             return ' '.join(parts)
+         case types.CheckConstraint():
+             return f'CONSTRAINT `{constraint.name}` CHECK ({constraint.expr})'
+         case _:
+             key_list = f'({",".join(generate_key_part(key) for key in constraint.key_list)})'
+             match constraint:
+                 case types.PrimaryConstraint():
+                     return f'PRIMARY KEY {key_list}'
+                 case types.UniqueConstraint():
+                     return f'UNIQUE KEY `{constraint.name}` {key_list}'
+                 case types.IndexConstraint():
+                     return f'KEY `{constraint.name}` {key_list}'
+                 case types.FulltextConstraint():
+                     return f'FULLTEXT KEY `{constraint.name}` {key_list}'
+                 case types.SpatialConstraint():
+                     return f'SPATIAL KEY `{constraint.name}` {key_list}'
+
+
+ def generate_create_options(create_options: types.CreateOptions):
+     attributes: list[str] = []
+     if create_options.tablespace is not None:
+         attributes.append(f'/*!50100 TABLESPACE `{create_options.tablespace}` */')
+     if create_options.engine is not None:
+         attributes.append(f'ENGINE={create_options.engine}')
+     if create_options.auto_increment is not None:
+         attributes.append(f'AUTO_INCREMENT={create_options.auto_increment}')
+     if create_options.charset is not None:
+         attributes.append(f'DEFAULT CHARSET={create_options.charset}')
+     if create_options.collate is not None:
+         attributes.append(f'COLLATE={create_options.collate}')
+     if create_options.avg_row_length is not None:
+         attributes.append(f'AVG_ROW_LENGTH={create_options.avg_row_length}')
+     if create_options.stats_persistent is not None:
+         attributes.append(
+             f'STATS_PERSISTENT={generate_ternary_option(create_options.stats_persistent)}'
+         )
+     if create_options.row_format is not None:
+         attributes.append(f'ROW_FORMAT={create_options.row_format.name}')
+     if create_options.comment is not None:
+         attributes.append(f"COMMENT='{create_options.comment}'")
+
+     return ' '.join(attributes)
+
+
+ def generate_key_hash_partitioning_method(
+     partitioning: types.KeyPartitionType | types.HashPartitionType,
+ ):
+     match partitioning:
+         case types.KeyPartitionType():
+             s = 'KEY'
+         case types.HashPartitionType():
+             s = f'HASH ({partitioning.expr})'
+     return ('LINEAR ' if partitioning.linear else '') + s
+
+
+ def generate_partitioning_method(partitioning: types.Partitioning):
+     match partitioning:
+         case types.ListPartitioning() | types.RangePartitioning():
+             match partitioning.expr_or_columns:
+                 case str():
+                     args = f'({partitioning.expr_or_columns})'
+                 case list():
+                     args = f'COLUMNS ({", ".join(partitioning.expr_or_columns)})'
+             match partitioning:
+                 case types.RangePartitioning():
+                     return f'RANGE {args}'
+                 case types.ListPartitioning():
+                     return f'LIST {args}'
+         case types.KeyHashPartitioning():
+             return generate_key_hash_partitioning_method(partitioning.type)
+
+
+ def generate_partitioning(partitioning: types.Partitioning):
+     return f'PARTITION BY {generate_partitioning_method(partitioning)}'
+
+
+ def generate_subpartitioning(subpartitioning: types.Subpartitioning):
+     return f'SUBPARTITION BY {generate_key_hash_partitioning_method(subpartitioning.type)}'
+
+
+ def generate_partition_options(options: types.PartitionOptions):
+     parts: dict[str, str] = {}
+     if options.engine is not None:
+         parts['ENGINE'] = options.engine
+     return [f'{k} = {v}' for k, v in parts.items()]
+
+
+ def generate_key_hash_partition_def(partition: types.KeyHashPartition):
+     return ' '.join((partition.name, *generate_partition_options(partition.options)))
+
+
+ def generate_partition_lt_values(values: types.ValuesLessThan):
+     if values == [None]:
+         args = 'MAXVALUE'
+     else:
+         args = f'({", ".join((v if v is not None else "MAXVALUE") for v in values)})'
+
+     return f'VALUES LESS THAN {args}'
+
+
+ def generate_partition_in_values(values: types.ValuesIn):
+     elements = [(f'({", ".join(v)})' if isinstance(v, list) else v) for v in values]
+     return f'VALUES IN ({", ".join(elements)})'
+
+
+ def generate_subpartition(subpartition: types.KeyHashPartition):
+     return f'SUBPARTITION {generate_key_hash_partition_def(subpartition)}'
+
+
+ def generate_partition(
+     partition: types.RangePartition | types.ListPartition | types.KeyHashPartition,
+ ):
+     match partition:
+         case types.RangePartition() | types.ListPartition():
+             match partition:
+                 case types.RangePartition():
+                     values = generate_partition_lt_values(partition.values)
+                 case types.ListPartition():
+                     values = generate_partition_in_values(partition.values)
+             parts = [partition.name, values, *generate_partition_options(partition.options)]
+             definition = ' '.join(parts)
+
+             subpartitions = [generate_subpartition(sub) for sub in partition.subpartitions]
+             if len(subpartitions) > 0:
+                 definition += f'\n ({",\n ".join(subpartitions)})'
+         case types.KeyHashPartition():
+             definition = generate_key_hash_partition_def(partition)
+     return f'PARTITION {definition}'
+
+
+ def generate_partitioning_clause(partitioning: types.Partitioning):
+     parts = [generate_partitioning(partitioning)]
+     if partitioning.sub is not None:
+         parts.append(generate_subpartitioning(partitioning.sub))
+     parts.append(f'({",\n ".join(generate_partition(p) for p in partitioning.partitions)})')
+     return f'/*!50100 {"\n".join(parts)} */'
+
+
+ def generate(table: types.Table):
+     elements = [generate_column(c) for c in table.columns] + [
+         generate_constraint(c) for c in table.constraints
+     ]
+     options = generate_create_options(table.options)
+     statement = (
+         f'CREATE TABLE `{table.name}` (\n{",\n".join(f" {e}" for e in elements)}\n) {options}'
+     )
+     if table.partitioning is not None:
+         statement += '\n' + generate_partitioning_clause(table.partitioning)
+     return statement
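
These generators compose bottom-up into generate(), and the leaf functions can be exercised on their own. A sketch, assuming the pydantic models in types accept the keyword fields inferred from the code above (the module path is again a guess, since the diff omits this file's name):

    from myrtille.mysql import generate, types  # hypothetical path for this generator module

    # Assumed constructors; field names are inferred from attribute access above.
    print(generate.generate_data_type(types.Varchar(length=255)))             # varchar(255)
    print(generate.generate_data_type(types.Decimal(precision=10, scale=2)))  # decimal(10,2)
    print(generate.generate_ref_action(types.RefAction.SET_NULL))             # SET NULL
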