naboo 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- naboo-0.1.0/PKG-INFO +44 -0
- naboo-0.1.0/README.md +28 -0
- naboo-0.1.0/naboo/__init__.py +0 -0
- naboo-0.1.0/naboo/db.py +899 -0
- naboo-0.1.0/pyproject.toml +33 -0
naboo-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: naboo
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary:
|
|
5
|
+
Home-page: https://github.com/bdoms/naboo
|
|
6
|
+
License: MIT
|
|
7
|
+
Author: Brendan Doms
|
|
8
|
+
Requires-Python: >=3.12,<4.0
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Requires-Dist: asyncpg (==0.29.0)
|
|
13
|
+
Project-URL: Repository, https://github.com/bdoms/naboo
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# Naboo
|
|
17
|
+
|
|
18
|
+
A light-weight, asynchronous ORM-like wrapper around `asyncpg` targeting Python 3.12+.
|
|
19
|
+
|
|
20
|
+
All records are returned as dictionaries, because you're just going to encode to JSON anyway.
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
## Testing
|
|
24
|
+
|
|
25
|
+
Get into postgres:
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
sudo -u postgres psql
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
Then setup the test database and permissions (or modify the environment variables in `pytest.ini` instead):
|
|
32
|
+
|
|
33
|
+
```sql
|
|
34
|
+
CREATE DATABASE naboo_test;
|
|
35
|
+
CREATE USER naboo_test_user WITH PASSWORD 'naboo_test_password';
|
|
36
|
+
GRANT ALL PRIVILEGES ON DATABASE naboo_test TO naboo_test_user;
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
Then to run tests:
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pytest
|
|
43
|
+
```
|
|
44
|
+
|
naboo-0.1.0/README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Naboo
|
|
2
|
+
|
|
3
|
+
A light-weight, asynchronous ORM-like wrapper around `asyncpg` targeting Python 3.12+.
|
|
4
|
+
|
|
5
|
+
All records are returned as dictionaries, because you're just going to encode to JSON anyway.
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
## Testing
|
|
9
|
+
|
|
10
|
+
Get into postgres:
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
sudo -u postgres psql
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Then setup the test database and permissions (or modify the environment variables in `pytest.ini` instead):
|
|
17
|
+
|
|
18
|
+
```sql
|
|
19
|
+
CREATE DATABASE naboo_test;
|
|
20
|
+
CREATE USER naboo_test_user WITH PASSWORD 'naboo_test_password';
|
|
21
|
+
GRANT ALL PRIVILEGES ON DATABASE naboo_test TO naboo_test_user;
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
Then to run tests:
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pytest
|
|
28
|
+
```
|
|
File without changes
|
naboo-0.1.0/naboo/db.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
from datetime import datetime, date, time
|
|
4
|
+
import hashlib
|
|
5
|
+
import inspect
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
import asyncpg
|
|
9
|
+
|
|
10
|
+
MAX_FIELD_LENGTH = 1024 * 1024 # 1 MB
|
|
11
|
+
MAX_LABEL_LENGTH = 63
|
|
12
|
+
MAX_LIMIT = 10000
|
|
13
|
+
MAX_SUBQUERY_ARGS = 9
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def validateName(s, allowed='_'):
    """Return True when every character of *s* is alphanumeric or in *allowed*.

    When '.' is permitted (i.e. *s* may be a schema-qualified table), a quoted
    table portion like `schema."table"` has its surrounding quotes stripped
    before validation so the quotes themselves don't fail the check.
    """
    if '.' in allowed and '.' in s:
        # this is a table, where we can allow quotes around the table name
        prefix, remainder = s.split('.', 1)
        if remainder[:1] == '"' and remainder[-1:] == '"':
            # reconstruct without the surrounding quotes for checking
            s = prefix + '.' + remainder[1:-1]

    for char in s:
        if not (char.isalnum() or char in allowed):
            return False

    return True
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# NOTE that a class with only classmethods like this and changing the class properties is effectively a singleton
|
|
28
|
+
class Database:
    """Singleton-style holder for the shared asyncpg connection pool.

    Only classmethods and class attributes are used, so mutating `pool`
    affects every consumer - effectively a singleton.
    """

    pool = None

    @classmethod
    async def startup(cls, name, user, password, host='localhost', port='5432'):
        """Create the shared connection pool.

        See here for connection arguments:
        https://github.com/MagicStack/asyncpg/blob/master/asyncpg/connection.py
        """
        cls.pool = await asyncpg.create_pool(host=host, port=port, user=user, password=password, database=name)

    @classmethod
    async def shutdown(cls):
        """Close the pool gracefully, force-terminating after one second."""
        # we have to enforce a timeout externally ourselves - this is the recommended way in the docs
        try:
            await asyncio.wait_for(cls.pool.close(), timeout=1)
        except asyncio.TimeoutError:
            cls.pool.terminate()

    @classmethod
    async def connect(cls):
        """Acquire a raw connection; caller is responsible for releasing it."""
        return await cls.pool.acquire()

    @classmethod
    async def connection(cls):
        """Async generator yielding a pooled connection that is auto-released."""
        async with cls.pool.acquire() as conn:
            yield conn

    @classmethod
    async def dropTables(cls, conn):
        # intended to be overridden by applications that know their schema
        raise NotImplementedError

    @classmethod
    def labelName(cls, label_name):
        """Return *label_name*, hashed when it exceeds postgres' identifier limit.

        We'd prefer not to hash in order to keep names readable.
        NOTE that 63 chars is the postgres default for label name max length.
        """
        if len(label_name) > MAX_LABEL_LENGTH:
            # BUG FIX: sha256 requires bytes - passing the str raised TypeError
            # NOTE that this produces output that's 64 characters long, which postgres will truncate
            # but the collision metrics on 63 vs 64 is so low in practice it's not worth caring about
            label_name = hashlib.sha256(label_name.encode()).hexdigest()
        return label_name
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Field:
    """Base class for all column types.

    Subclasses set `db_type` (the postgres type name) and may override
    `field_type` (to add modifiers such as varchar length) and `constraint`
    (to emit extra DDL such as triggers or foreign keys).
    """

    def __init__(self, null=False, default=None, unique=False) -> None:
        self.null = null
        self.default = default
        self.unique = unique

    @property
    def field_type(self):
        # subclasses may override to append type modifiers
        return self.db_type

    def constraint(self, table_name, col_name):  # NOQA: ARG002
        # extra DDL to run alongside table creation; none by default
        return None

    def create(self, table_name, col_name):
        """Return the (column DDL, extra constraint DDL) pair for this field.

        Raises ValueError for unsafe identifiers and TypeError for invalid
        primary-key configurations.
        """
        # NOTE: dot/period is allowed for tables so that schemas can be included
        if not validateName(table_name, allowed='._') or not validateName(col_name):
            raise ValueError('Invalid table or column name: ' + table_name + ' ' + col_name)

        col = '"' + col_name + '" ' + self.field_type

        if self.default is not None:
            # wrapping default in a string means that it should also work here for bools, ints, etc.
            col += ' DEFAULT ' + str(self.default)

        if col_name == 'id':
            # BUG FIX: IntField's db_type is 'integer', not 'int' - the previous
            # ('uuid', 'int') check made integer primary keys impossible
            if self.field_type not in ('uuid', 'integer', 'int'):
                raise TypeError('Field type is not allowed for primary keys: ' + self.field_type)

            if isinstance(self, ForeignKeyField):
                raise TypeError('Foreign keys are not allowed as primary keys')

            col += ' PRIMARY KEY'
        else:
            if not self.null:
                col += ' NOT NULL'

            if self.unique:
                col += ' UNIQUE'

        constraint = self.constraint(table_name, col_name)
        return col, constraint
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class BooleanField(Field):
    """A boolean column."""

    db_type = 'boolean'

    def __init__(self, default=None, **kwargs) -> None:
        # reject non-bool defaults up front rather than letting bad DDL reach the database
        if not (default is None or isinstance(default, bool)):
            raise TypeError('Invalid default type: ' + str(type(default)))

        super().__init__(default=default, **kwargs)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class CharField(Field):
    """A varchar column with a maximum length."""

    db_type = 'varchar'

    def __init__(self, max_length=MAX_FIELD_LENGTH, default=None, **kwargs) -> None:
        # default needs to be wrapped in single quotes for the generated DDL
        if default is not None:
            if not isinstance(default, str):
                raise TypeError('Invalid default type: ' + str(type(default)))

            # NOTE: in theory we could escape these below, but not sure it's safe
            # probably need some other method to ensure there aren't any crazy tricks here
            # after way too much investigation it's unclear if there is a good method to allow this
            if "'" in default or '\\' in default:
                raise ValueError('Single quotes and backslashes are not allowed in default values: ' + default)

            # wrap it in quotes
            default = "'" + default + "'"

        super().__init__(default=default, **kwargs)
        # BUG FIX: was stored under the misspelled name `max_lenth`;
        # keep the old attribute as an alias so existing readers still work
        self.max_length = max_length
        self.max_lenth = max_length

    @property
    def field_type(self):
        return self.db_type + '(' + str(self.max_length) + ')'
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class DateField(Field):
    """A date column (no time component)."""

    db_type = 'date'

    def __init__(self, default=None, **kwargs) -> None:
        if default is not None:
            if not isinstance(default, date):
                raise TypeError('Invalid default type: ' + str(type(default)))

            # NOTE(review): the formatted default is not quoted in the DDL
            # (Field.create does `DEFAULT ` + str(default)) - confirm postgres accepts it
            default = default.strftime('%Y-%m-%d')

        super().__init__(default=default, **kwargs)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class DateTimeField(Field):
    """A naive timestamp column with optional auto-set / auto-update behavior.

    `auto_now_add` sets the column on INSERT via a CURRENT_TIMESTAMP default;
    `auto_now` additionally installs a trigger that refreshes it on UPDATE.
    """

    db_type = 'timestamp without time zone'

    def __init__(self, auto_now_add=False, auto_now=False, default=None, **kwargs) -> None:
        if auto_now_add or auto_now:
            if default is not None:
                raise ValueError('`default` must not be defined when `auto_now` or `auto_now_add` is True')

            default = 'CURRENT_TIMESTAMP'
        elif default is not None:
            if not isinstance(default, datetime):
                raise TypeError('Invalid default type: ' + str(type(default)))

            # NOTE that this assumes the time is UTC already
            default = default.strftime('%Y-%m-%d %H:%M:%S.%f')

        super().__init__(default=default, **kwargs)

        self.auto_now_add = auto_now_add
        self.auto_now = auto_now

    def constraint(self, table_name, col_name):
        """Return trigger DDL that refreshes this column on UPDATE, or None.

        A function + trigger pair is created per column - possible to do with one
        shared function but very ugly, see
        https://dba.stackexchange.com/questions/127787/trigger-function-taking-column-names-as-parameters-to-modify-the-row
        """
        if not validateName(table_name, allowed='._') or not validateName(col_name):
            raise ValueError('Invalid table or column name: ' + table_name + ' ' + col_name)

        if not self.auto_now:
            return None

        full_name = table_name.replace('.', '_').replace('"', '') + '_' + col_name
        function_name = Database.labelName('auto_now_function_' + full_name)
        # a trigger that calls the function on update:
        trigger_name = Database.labelName('auto_now_trigger_' + full_name)

        parts = [
            f'CREATE OR REPLACE FUNCTION "{function_name}"() RETURNS TRIGGER AS $$\n',
            'BEGIN\n',
            f'NEW."{col_name}" = NOW();\n',
            'RETURN NEW;\n',
            'END;\n',
            "$$ language 'plpgsql';\n",
            f'CREATE TRIGGER "{trigger_name}" BEFORE UPDATE ON {table_name} ',
            f'FOR EACH ROW EXECUTE PROCEDURE "{function_name}"();',
        ]
        return ''.join(parts)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class IntField(Field):
    """An integer column."""

    db_type = 'integer'

    def __init__(self, default=None, **kwargs) -> None:
        # only ints (or no default at all) are acceptable
        if not (default is None or isinstance(default, int)):
            raise TypeError('Invalid default type: ' + str(type(default)))

        super().__init__(default=default, **kwargs)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class FloatField(Field):
    """A float column."""

    db_type = 'float'

    def __init__(self, default=None, **kwargs) -> None:
        # GENERALIZED: also accept int defaults (e.g. FloatField(default=0)),
        # which previously raised; postgres coerces them to float.
        # bool is excluded explicitly because it subclasses int in Python.
        if default is not None and (isinstance(default, bool) or not isinstance(default, (int, float))):
            raise TypeError('Invalid default type: ' + str(type(default)))

        super().__init__(default=default, **kwargs)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class TextField(Field):
    """An unbounded text column."""

    db_type = 'text'

    def __init__(self, default=None, **kwargs) -> None:
        # default needs to be wrapped in single quotes for the generated DDL
        if default is not None:
            if not isinstance(default, str):
                raise TypeError('Invalid default type: ' + str(type(default)))

            # NOTE: in theory these characters could be escaped, but it's unclear
            # whether any escaping scheme here is actually safe, so reject outright
            if "'" in default or '\\' in default:
                raise ValueError('Single quotes and backslashes are not allowed in default values: ' + default)

            # wrap it in quotes
            default = "'" + default + "'"

        super().__init__(default=default, **kwargs)

    @property
    def field_type(self):
        # text takes no length modifier
        return self.db_type
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class TimeField(Field):
    """A naive time-of-day column."""

    db_type = 'time without time zone'

    def __init__(self, default=None, **kwargs) -> None:
        if default is not None:
            if not isinstance(default, time):
                raise TypeError('Invalid default type: ' + str(type(default)))

            # NOTE that this assumes the time is UTC already
            default = default.strftime('%H:%M:%S.%f')

        super().__init__(default=default, **kwargs)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class UUIDField(Field):
    """A uuid column, defaulting to server-side random generation."""

    db_type = 'uuid'

    def __init__(self, default='gen_random_uuid()', **kwargs) -> None:
        acceptable = default is None or default == 'gen_random_uuid()' or isinstance(default, UUID)
        if not acceptable:
            raise TypeError('Invalid default type: ' + str(type(default)))

        super().__init__(default=default, **kwargs)

    @classmethod
    def convert(cls, value):
        """Return a compact url-safe text form of *value* (a UUID): base64 of the 16 raw bytes, padding stripped."""
        encoded = base64.urlsafe_b64encode(value.bytes)
        return encoded.rstrip(b'=').decode()
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class ForeignKeyField(Field):
    """A column referencing another model's primary key, with ON DELETE CASCADE.

    The column's db_type mirrors the referenced model's `id` field, so both
    uuid and integer keyed models are supported.
    """

    def __init__(self, model_class, default=None, **kwargs) -> None:
        id_type = model_class.id.db_type
        if id_type == 'uuid':
            if default is not None and not isinstance(default, UUID):
                raise TypeError('Invalid default type: ' + str(type(default)))
        elif id_type in ('integer', 'int'):
            # BUG FIX: IntField's db_type is 'integer', so the previous 'int'-only
            # comparison sent integer-keyed models to NotImplementedError below
            if default is not None and not isinstance(default, int):
                raise TypeError('Invalid default type: ' + str(type(default)))
        else:
            # FUTURE: in theory strings/charvars could work here too
            raise NotImplementedError

        super().__init__(default=default, **kwargs)
        self.model_class = model_class

        # we can support either int or uuid by dynamically using the other model's type
        self.db_type = id_type

    def constraint(self, table_name, col_name):  # NOQA: ARG002
        """Return the FOREIGN KEY constraint DDL for this column."""
        if not validateName(col_name):
            raise ValueError('Invalid column name: ' + col_name)

        # FUTURE: be able to disable "ON DELETE CASCADE"
        name = Database.labelName(col_name + '_fkey')
        return f'CONSTRAINT "{name}" FOREIGN KEY("{col_name}") REFERENCES {self.model_class.schema_table}(id) ' \
            + 'ON DELETE CASCADE'
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class Query:
    """Incremental builder for a parameterized SELECT against one model's table.

    Methods return `self` so calls can be chained; terminal coroutines
    (`all`, `first`, `count`) execute the built SQL with positional args.
    """

    # FUTURE: should we also support "is distinct from" or "between" here?
    # https://www.postgresql.org/docs/13/functions-comparison.html
    DIRECTIONS = ('ASC', 'DESC')
    FUNCTIONS = ('LOWER', 'UPPER')
    LOGICAL = ('AND', 'OR')
    OPERATORS = ('=', '!=', '<', '>', '<=', '>=', 'LIKE', 'ILIKE')  # 'IN', '%', '*', '!'
    IS_OPERATORS = ('IS', 'IS NOT')
    # BUG FIX: 'UNKNOWN' was misspelled 'UNKOWN', which postgres rejects as a syntax
    # error, so the valid `IS UNKNOWN` predicate could never be produced
    IS_VALUES = ('NULL', 'TRUE', 'FALSE', 'UNKNOWN')
    # FUTURE: should we alias alternatives like `=>` and `=<` ? better to just have one way or allow either?

    def __init__(self, conn, model_class, columns=None, alias=None) -> None:
        # FUTURE: support selecting specific columns here - note that the count function below will have to be modified
        self.conn = conn
        self.model_class = model_class
        self.args = []
        self.logic_level = 0
        self.alias = alias
        self.order_by_sql = ''
        self.limit_sql = ''
        self.offset_sql = ''

        self._sql = 'SELECT '
        if columns:
            for column in columns:
                self._check_col(column)

            self._sql += ', '.join(['"' + column + '"' for column in columns])
        else:
            self._sql += '*'

        self._sql += f' FROM {model_class.schema_table}'

        if alias:
            self._sql += f' AS "{alias}"'

    @property
    def sql(self):
        """The full statement with ORDER BY / LIMIT / OFFSET endings applied."""
        sql = self._sql

        if self.order_by_sql:
            sql += self.order_by_sql

        if self.limit_sql:
            sql += self.limit_sql

        if self.offset_sql:
            sql += self.offset_sql

        return sql

    def _check_col(self, name):
        # guard against injecting arbitrary identifiers: only declared fields pass
        fields = self.model_class.fields
        if name not in fields:
            raise ValueError(f'Unknown field {name}')

    def start_logic(self):
        """Open a parenthesized logic group (adds WHERE if not present yet)."""
        if ' WHERE ' not in self._sql:
            self._sql += ' WHERE'

        self._sql += ' ('
        self.logic_level += 1

        return self

    def end_logic(self):
        """Close the innermost logic group."""
        if self.logic_level < 1:
            raise RuntimeError('Tried to close a logic group without one open')

        if self._sql[-1] == '(':
            raise RuntimeError('Empty logic group')

        self._sql += ')'

        self.logic_level -= 1
        return self

    def where(self, col_name, operator, col_value, logic='AND', func=None, parent_query=None):
        """Append a condition; `parent_query` correlates against an outer query's column."""
        self._check_col(col_name)

        if operator in Query.IS_OPERATORS:
            if col_value not in Query.IS_VALUES:
                msg = 'Values for "' + operator + '" operator must be one of ' + str(Query.IS_VALUES)
                msg += ', unknown value: ' + col_value
                raise ValueError(msg)
        elif operator not in Query.OPERATORS:
            raise ValueError('Unsupported operator: ' + operator)

        if logic and logic not in Query.LOGICAL:
            raise ValueError('Unsupported logic: ' + logic)

        if func and func not in Query.FUNCTIONS:
            raise ValueError('Unsupported function: ' + func)

        if ' WHERE ' in self._sql:
            # no separator or joiner directly after an opening paren
            if self._sql[-1] != '(':
                self._sql += ' '
                if logic:
                    self._sql += logic + ' '
        else:
            self._sql += ' WHERE '

        alias = self.alias and (self.alias + '"."') or ''
        column = f'"{alias}{col_name}"'

        if func:
            column = f'{func}({column})'

        # support "is null"
        # FUTURE: is there a more elegant way to support this?
        if operator in Query.IS_OPERATORS:
            self._sql += f'{column} {operator} {col_value}'
        elif parent_query:
            # this checks that the col value is a part of the parent, and applies an alias
            parent_query._check_col(col_value)  # NOQA: SLF001

            if not parent_query.alias:
                raise RuntimeError('Alias is required when using a parent query column as a value')

            if self.alias == parent_query.alias:
                raise RuntimeError('Parent queries and sub queries must not have the same alias')

            parent_col_value = f'"{parent_query.alias}"."{col_value}"'

            self._sql += f'{column} {operator} {parent_col_value}'
        else:
            position = len(self.args) + 1

            # FUTURE: handle tuples? handle other logic matching conditions?
            pos = f'ANY(${position})' if isinstance(col_value, list) else f'${position}'

            self._sql += f'{column} {operator} {pos}'

            self.args.append(col_value)

        return self

    def add_logic(self, logic):
        """Append a bare AND/OR joiner (for use between logic groups)."""
        if logic not in Query.LOGICAL:
            raise ValueError('Unsupported logic: ' + logic)

        self._sql += ' ' + logic

        return self

    def exists(self, subquery):
        """Append `EXISTS (subquery)`, renumbering the subquery's placeholders."""
        if not isinstance(subquery, Query):
            raise TypeError('Subquery must be an instance of `Query`')

        sql = subquery.sql
        if subquery.args:
            # NOTE that because of the intermediary replacement mechanism below we have to limit the number of args
            # NOTE: naive replace('$1') would also catch '$10', so double digits are unsafe
            # FUTURE: this seems like a low limit, but probably a better way to do this anyway?
            if len(subquery.args) > MAX_SUBQUERY_ARGS:
                raise RuntimeError('Too many args in subquery: ' + str(len(subquery.args)))

            # to avoid conflicts where we do something like replace $1 with $2 and then accidentally replace
            # the replaced $2 with something else instead of the actual placeholder $2 later
            # we replace everything with a letter first, and then go back and replace with the actual number
            for i, _arg in enumerate(subquery.args):
                sql = sql.replace(f'${i + 1}', f'${chr(i + 65)}')

            position = len(self.args) + 1
            for i, arg in enumerate(subquery.args):
                sql = sql.replace(f'${chr(i + 65)}', f'${position}')
                self.args.append(arg)
                position += 1

        # WARNING - the subquery can be anything right now - do not expose to end users like this!
        self._sql += ' EXISTS (' + sql + ')'

        return self

    def order_by(self, col_name, direction='ASC'):
        """Append an ORDER BY term; repeat calls add additional sort keys."""
        self._check_col(col_name)

        if direction not in Query.DIRECTIONS:
            raise ValueError('Unsupported direction: ' + direction)

        if ' ORDER BY ' in self.order_by_sql:
            self.order_by_sql += ','
        else:
            self.order_by_sql += ' ORDER BY'

        self.order_by_sql += f' "{col_name}" {direction}'

        return self

    # FUTURE: do we need to support this?
    # def group_by(self, col_name):

    def limit(self, n: int):
        """Cap the number of returned rows (1..MAX_LIMIT); callable once."""
        if not isinstance(n, int):
            raise TypeError('Limit must be an integer: ' + str(n))

        if n < 1:
            raise ValueError('Limit must be greater than zero: ' + str(n))

        if n > MAX_LIMIT:
            raise ValueError('Limit must be 10000 or less: ' + str(n))

        if ' LIMIT ' in self.limit_sql:
            raise RuntimeError('Multiple calls to limit on the same query are not allowed')

        self.limit_sql += f' LIMIT {n}'

        return self

    def offset(self, n: int):
        """Skip the first *n* rows; callable once."""
        if not isinstance(n, int):
            raise TypeError('Offset must be an integer: ' + str(n))

        if n < 0:
            raise ValueError('Offset must not be negative: ' + str(n))

        if ' OFFSET ' in self.offset_sql:
            raise RuntimeError('Multiple calls to offset on the same query are not allowed')

        self.offset_sql += f' OFFSET {n}'

        return self

    async def all(self):
        """Execute and return every matching row as a list of dicts."""
        if self.logic_level > 0:
            raise RuntimeError('Tried to query without closing all logic groups')

        return [self.model_class.convert(row) for row in await self.conn.fetch(self.sql, *self.args)]

    async def count(self):
        """Execute a COUNT(*) variant of the query and return the total."""
        if self.logic_level > 0:
            raise RuntimeError('Tried to query without closing all logic groups')

        # NOTE that we purposefully use the _sql here that doesn't have endings applied
        # because those can mess with the count
        # NOTE the 1 is important here so we don't replace the select on subqueries if they exist
        sql = self._sql.replace('SELECT * FROM', 'SELECT COUNT(*) FROM', 1)

        return await self.conn.fetchval(sql, *self.args)

    async def first(self):
        """Execute and return the first matching row as a dict, or None."""
        if self.logic_level > 0:
            raise RuntimeError('Tried to query without closing all logic groups')

        return self.model_class.convert(await self.conn.fetchrow(self.sql, *self.args))
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class Model:
|
|
588
|
+
|
|
589
|
+
_fields = None
|
|
590
|
+
_fields_class = None
|
|
591
|
+
|
|
592
|
+
class Meta:
|
|
593
|
+
# optional defaults:
|
|
594
|
+
schema = 'public'
|
|
595
|
+
|
|
596
|
+
# required on child classes (no default):
|
|
597
|
+
# table = ''
|
|
598
|
+
|
|
599
|
+
# constraints = {}
|
|
600
|
+
|
|
601
|
+
# docs: https://magicstack.github.io/asyncpg/current/api/index.html
|
|
602
|
+
# examples: https://github.com/jordic/fastapi_asyncpg/blob/master/fastapi_asyncpg/sql.py
|
|
603
|
+
|
|
604
|
+
# FUTURE: can we generalize checking these values and turning them into properties
|
|
605
|
+
# so we don't have to do each individually?
|
|
606
|
+
|
|
607
|
+
# also NOTE that caching these by setting the values on the class after first access can cause big problems
|
|
608
|
+
# we have cls.schema_table = schema_table at the end of that and it cached the BASE class version for all children
|
|
609
|
+
# caching stuff like that will have to take the cls.__name__ into account
|
|
610
|
+
    @classmethod
    @property
    def meta_table(cls):
        # Table name: Meta.table when declared, otherwise the lowercased class name.
        # NOTE(review): chaining @classmethod with @property was deprecated in
        # Python 3.11 and removed in 3.13 - confirm against the supported range
        return hasattr(cls.Meta, 'table') and cls.Meta.table or cls.__name__.lower()
|
|
614
|
+
|
|
615
|
+
    @classmethod
    @property
    def meta_schema(cls):
        # Schema name: Meta.schema when declared on the child class, otherwise
        # the Model default ('public').
        # NOTE(review): chaining @classmethod with @property was deprecated in
        # Python 3.11 and removed in 3.13 - confirm against the supported range
        return hasattr(cls.Meta, 'schema') and cls.Meta.schema or Model.Meta.schema
|
|
619
|
+
|
|
620
|
+
    @classmethod
    @property
    def schema_table(cls):
        # Fully qualified, quoted table reference used by every generated statement.
        # NOTE: if schema is quoted together with the table name then postgres assumes it's in the public schema
        # e.g. "public.user" gets converted to "public.public.user" which doesn't exist
        # could still put schema in quotes separately if we're concerned ("schema"."table")
        # but they're all system defined so it shouldn't be an issue
        return f'{cls.meta_schema}."{cls.meta_table}"'
|
|
628
|
+
|
|
629
|
+
    @classmethod
    @property
    def fields(cls):
        # Mapping of column name -> Field instance declared on this class,
        # discovered via introspection and cached per class name.
        # NOTE: this guard is needed to avoid infinite recursion caused by the getmembers call
        # and the class name keeps earlier calls to base classes from overriding child classes
        if cls._fields is None or cls._fields_class != cls.__name__:
            cls._fields_class = cls.__name__
            cls._fields = {}

            # every non-routine, non-underscore attribute is a candidate column
            attributes = inspect.getmembers(cls, lambda a: not(inspect.isroutine(a)))
            attrs = [a for a in attributes if not a[0].startswith('_')]
            for name, field in attrs:
                if isinstance(field, Field):
                    # NOTE: the strip here allows for correcting conflicts between built in methods and field names
                    # e.g. there's a `create` method so we call a field `create_` and it works because of this
                    cls._fields[name.strip('_')] = field

        return cls._fields
|
|
647
|
+
|
|
648
|
+
@classmethod
|
|
649
|
+
def convert(cls, item):
|
|
650
|
+
if not item:
|
|
651
|
+
return None
|
|
652
|
+
|
|
653
|
+
# NOTE: when this is called the item is still a Record type from asyncpg
|
|
654
|
+
return dict(item)
|
|
655
|
+
|
|
656
|
+
# NOTE that these all return dicts rather than objects, and that's ok, probably preferable
|
|
657
|
+
    @classmethod
    async def get(cls, conn, record_id):
        """Fetch one record by primary key; returns a dict, or None when not found."""
        return cls.convert(await conn.fetchrow(f'SELECT * FROM {cls.schema_table} WHERE id = $1', record_id)) # NOQA: S608
|
|
661
|
+
|
|
662
|
+
@classmethod
|
|
663
|
+
def select(cls, conn, columns=None, alias=None):
|
|
664
|
+
return Query(conn, cls, columns=columns, alias=alias)
|
|
665
|
+
|
|
666
|
+
@classmethod
|
|
667
|
+
async def create(cls, conn, **kwargs):
|
|
668
|
+
fields = cls.fields
|
|
669
|
+
names = []
|
|
670
|
+
values = []
|
|
671
|
+
|
|
672
|
+
# FUTURE: we could loop through the fields instead
|
|
673
|
+
# and then if we wanted to we could have a python only default (i.e. not controlled by the database)
|
|
674
|
+
for name, value in kwargs.items():
|
|
675
|
+
if name in fields:
|
|
676
|
+
if name == 'id':
|
|
677
|
+
raise KeyError(f'Primary key field {name} is auto generated, do not specify')
|
|
678
|
+
|
|
679
|
+
# NOTE: letting the database exclusively worry about values, nulls, uniques, etc. for now
|
|
680
|
+
names.append('"' + name + '"')
|
|
681
|
+
values.append(value)
|
|
682
|
+
else:
|
|
683
|
+
raise KeyError(f'Unknown field {name}')
|
|
684
|
+
|
|
685
|
+
sql = f'INSERT INTO {cls.schema_table} '
|
|
686
|
+
|
|
687
|
+
# we have to check for this having something in it because it's possible all values are defaults
|
|
688
|
+
if names:
|
|
689
|
+
# NOTE that we use the python formatting to create the $1, $2, $3, etc. places for sql
|
|
690
|
+
# but then pass the actual values through the function rather than trying to format them
|
|
691
|
+
places = ', '.join([f'${i}' for i in range(1, len(values) + 1)])
|
|
692
|
+
sql += f'({", ".join(names)}) VALUES ({places}) '
|
|
693
|
+
else:
|
|
694
|
+
sql += 'DEFAULT VALUES '
|
|
695
|
+
|
|
696
|
+
# NOTE: we want to be sure to use "RETURNING *" here so we can send back data exactly as the db has it
|
|
697
|
+
sql += 'RETURNING *'
|
|
698
|
+
|
|
699
|
+
return cls.convert(await conn.fetchrow(sql, *values))
|
|
700
|
+
|
|
701
|
+
# FUTURE: support bulk insert
|
|
702
|
+
# NOTE: for bulk insert we can use executemany, but that doesn't return ids
|
|
703
|
+
# see https://stackoverflow.com/questions/43739123/best-way-to-insert-multiple-rows-with-asyncpg
|
|
704
|
+
# which includes an answer on how to do that if we need it
|
|
705
|
+
|
|
706
|
+
# async def insert(conn, table, values):
|
|
707
|
+
# qs = "insert into {table} ({columns}) values ({values}) returning *".format(
|
|
708
|
+
# table=table,
|
|
709
|
+
# values=",".join([f"${p + 1}" for p in range(len(values.values()))]),
|
|
710
|
+
# columns=",".join(list(values.keys())),
|
|
711
|
+
# )
|
|
712
|
+
# return await conn.fetchrow(qs, *list(values.values()))
|
|
713
|
+
|
|
714
|
+
# FUTURE: support bulk update
|
|
715
|
+
@classmethod
async def update(cls, conn, record_id, **kwargs):
    """Update a single record by primary key and return the stored row as a dict.

    FUTURE: this only works for a single record, create a batch version for multiple at once
    """
    model_fields = cls.fields
    column_names = []
    column_values = []

    # FUTURE: we could loop through the fields instead
    # and then if we wanted to we could have python only logic for updates (i.e. not controlled by the database)
    for field_name, field_value in kwargs.items():
        if field_name not in model_fields:
            raise KeyError(f'Unknown field: {field_name}')

        if field_name == 'id':
            raise KeyError(f'Primary key field {field_name} is auto generated, do not specify')

        # NOTE: letting the database exclusively worry about values, nulls, uniques, etc. for now
        column_values.append(field_value)
        column_names.append(f'"{field_name}"')

    if not column_names:
        raise ValueError('No fields to update')

    # NOTE: we want to be sure to use "RETURNING *" here so we can send back data exactly as the db has it
    # also note that we use the python formatting to create the $1, $2, $3, etc. places for sql
    # but then pass the actual values through the function rather than trying to format them
    assignments = ', '.join(
        f'{name}=${position}' for position, name in enumerate(column_names, start=1)
    )

    sql = f'UPDATE {cls.schema_table} SET {assignments} WHERE id=${len(column_values) + 1} RETURNING *' # NOQA: S608

    column_values.append(record_id)
    return cls.convert(await conn.fetchrow(sql, *column_values))
|
|
747
|
+
|
|
748
|
+
# FUTURE: the update and create methods are very similar - could we combine them somehow?
|
|
749
|
+
|
|
750
|
+
@classmethod
async def delete(cls, conn, record_id):
    """Delete one record by primary key; returns the number of rows removed."""
    # the database responds with the text "DELETE N" where N is the amount of things deleted
    result = await conn.execute(f'DELETE FROM {cls.schema_table} WHERE id = $1', record_id) # NOQA: S608
    return int(result.split(' ', 1)[1])
|
|
756
|
+
|
|
757
|
+
@classmethod
async def delete_where(cls, conn, col_name, operator, col_value, and_name=None, and_operator=None, and_value=None):
    """Delete every record matching `col_name operator col_value`, with an optional
    single AND clause for exclusions; returns the number of rows deleted.

    Raises KeyError for unknown field names and ValueError for unsupported
    operators or invalid IS values.
    """

    fields = cls.fields

    if col_name not in fields:
        raise KeyError(f'Unknown field: {col_name}')

    # FIX: and_name is interpolated into the sql below just like col_name,
    # so it must be validated against the model's fields too
    if and_name and and_name not in fields:
        raise KeyError(f'Unknown field: {and_name}')

    def validate(op, value):
        # IS / IS NOT comparisons only accept a fixed set of keyword values
        if op in Query.IS_OPERATORS:
            if value not in Query.IS_VALUES:
                msg = 'Values for "' + op + '" operator must be one of ' + str(Query.IS_VALUES)
                msg += ', unknown value: ' + value
                raise ValueError(msg)
        elif op not in Query.OPERATORS:
            raise ValueError('Unsupported operator: ' + op)

    validate(operator, col_value)

    if and_operator:
        validate(and_operator, and_value)

    # this returns the text "DELETE N" where N is the amount of things deleted
    sql = f'DELETE FROM {cls.schema_table} WHERE "{col_name}" {operator} ' # NOQA: S608

    if isinstance(col_value, list):
        sql += 'ANY($1)'
    else:
        sql += '$1'

    args = [col_value]

    # support a single and clause to allow for exclusions
    # FUTURE: this is already pretty complicated, any more and we need to refactor
    # could try to combine with query, but a lot of that stuff (order by, offset, etc.) doesn't apply here
    # FIX: previously this tested the truthiness of and_value, which silently dropped
    # the AND clause for falsy-but-valid values like 0, False or '' - deleting far more
    # rows than the caller intended; an explicit None check keeps those clauses
    if and_name and and_operator and and_value is not None:
        sql += f' AND "{and_name}" {and_operator} $2'
        args.append(and_value)

    response = await conn.execute(sql, *args)
    _delete, amount = response.split(' ', 1)
    return int(amount)
|
|
802
|
+
|
|
803
|
+
@classmethod
def _generateColumns(cls, fields):
    """Turn field definitions into SQL column strings plus two constraint lists:
    inline constraints (included in CREATE TABLE) and deferred constraints
    (full statements executed afterwards).
    """
    columns = []
    inline_constraints = []
    deferred_constraints = []

    for col_name, field in fields.items():
        column, constraint = field.create(cls.schema_table, col_name)
        columns.append(column)

        if not constraint:
            continue

        # some constraints like foreign keys can be included during table creation
        # but others can only be done later because they rely on executing a separate sql command
        # for doing things like creating functions and triggers
        if constraint.startswith('CONSTRAINT'):
            inline_constraints.append(constraint)
        else:
            # because after constraints are executed all together they need to be separate commands
            # so we enforce that here
            if not constraint.endswith(';'):
                constraint += ';'

            deferred_constraints.append(constraint)

    return columns, inline_constraints, deferred_constraints
|
|
828
|
+
|
|
829
|
+
@classmethod
async def createTable(cls, conn):
    """Create this model's table, then apply deferred constraints and any
    Meta-declared unique indexes and regular indexes.
    """

    fields = cls.fields

    if not fields:
        raise ValueError('No fields found on model')

    if 'id' not in fields:
        raise KeyError('"id" column must be explicitly defined')

    columns, constraints, after_constraints = cls._generateColumns(fields)

    # also include constraints from the table itself here
    # NOTE that these are assumed to all be check constraints only
    for name, constraint in getattr(cls.Meta, 'constraints', {}).items():
        constraints.append('CONSTRAINT ' + name + ' ' + constraint)

    # constraints go after all of the column definitions
    columns += constraints

    await conn.execute(f'CREATE TABLE {cls.schema_table} ({", ".join(columns)})')

    if after_constraints:
        # each after constraint is expected to be a full statement and thus end in a semicolon - see check above
        await conn.execute(' '.join(after_constraints))

    for name, index in getattr(cls.Meta, 'unique_indexes', {}).items():
        await conn.execute(f'CREATE UNIQUE INDEX "{name}" ON {cls.schema_table} ({index})')

    for name, index in getattr(cls.Meta, 'indexes', {}).items():
        await conn.execute(f'CREATE INDEX "{name}" ON {cls.schema_table} USING {index}')
|
|
866
|
+
|
|
867
|
+
@classmethod
async def dropTable(cls, conn):
    """Drop this model's table entirely."""
    await conn.execute(f'DROP TABLE {cls.schema_table}')
|
|
871
|
+
|
|
872
|
+
@classmethod
async def addColumns(cls, conn, fields):
    """Add new columns (and their constraints) to an existing table via ALTER TABLE."""

    if not fields:
        raise ValueError('No fields found on model')

    if 'id' in fields:
        raise KeyError('"id" column can not be added after table creation')

    columns, constraints, after_constraints = cls._generateColumns(fields)

    # alter table needs additional verbage vs create table:
    # columns need "add column" and constraints need "add"
    clauses = ['ADD COLUMN ' + column for column in columns]

    # add constraints to the end of the list
    clauses += ['ADD ' + constraint for constraint in constraints]

    await conn.execute(f'ALTER TABLE {cls.schema_table} {", ".join(clauses)}')

    if after_constraints:
        # each after constraint is expected to be a full statement and thus end in a semicolon - see check above
        await conn.execute(' '.join(after_constraints))
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = 'naboo'
|
|
3
|
+
version = '0.1.0'
|
|
4
|
+
description = ''
|
|
5
|
+
|
|
6
|
+
license = 'MIT'
|
|
7
|
+
|
|
8
|
+
authors = ['Brendan Doms']
|
|
9
|
+
|
|
10
|
+
readme = 'README.md'
|
|
11
|
+
|
|
12
|
+
repository = 'https://github.com/bdoms/naboo'
|
|
13
|
+
homepage = 'https://github.com/bdoms/naboo'
|
|
14
|
+
|
|
15
|
+
# packages = [
|
|
16
|
+
# {include = 'naboo'}
|
|
17
|
+
# ]
|
|
18
|
+
|
|
19
|
+
# include = [
|
|
20
|
+
# ]
|
|
21
|
+
|
|
22
|
+
# exclude = [
|
|
23
|
+
# ]
|
|
24
|
+
|
|
25
|
+
[tool.poetry.dependencies]
|
|
26
|
+
python = '^3.12'
|
|
27
|
+
asyncpg = '0.29.0'
|
|
28
|
+
|
|
29
|
+
[tool.poetry.dev-dependencies]
|
|
30
|
+
pytest = '8.3.3'
|
|
31
|
+
pytest-asyncio = '0.24.0'
|
|
32
|
+
pytest-env = '1.1.4'
|
|
33
|
+
ruff = '0.6.5'
|