sql-blocks 1.25.2__py3-none-any.whl → 1.25.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_blocks/sql_blocks.py +577 -123
- {sql_blocks-1.25.2.dist-info → sql_blocks-1.25.47.dist-info}/METADATA +295 -5
- sql_blocks-1.25.47.dist-info/RECORD +7 -0
- sql_blocks-1.25.2.dist-info/RECORD +0 -7
- {sql_blocks-1.25.2.dist-info → sql_blocks-1.25.47.dist-info}/LICENSE +0 -0
- {sql_blocks-1.25.2.dist-info → sql_blocks-1.25.47.dist-info}/WHEEL +0 -0
- {sql_blocks-1.25.2.dist-info → sql_blocks-1.25.47.dist-info}/top_level.txt +0 -0
sql_blocks/sql_blocks.py
CHANGED
@@ -26,7 +26,7 @@ TO_LIST = lambda x: x if isinstance(x, list) else [x]
|
|
26
26
|
|
27
27
|
|
28
28
|
class SQLObject:
|
29
|
-
ALIAS_FUNC =
|
29
|
+
ALIAS_FUNC = None
|
30
30
|
""" ^^^^^^^^^^^^^^^^^^^^^^^^
|
31
31
|
You can change the behavior by assigning
|
32
32
|
a user function to SQLObject.ALIAS_FUNC
|
@@ -41,21 +41,35 @@ class SQLObject:
|
|
41
41
|
def set_table(self, table_name: str):
|
42
42
|
if not table_name:
|
43
43
|
return
|
44
|
-
|
44
|
+
cls = SQLObject
|
45
|
+
is_file_name = any([
|
46
|
+
'/' in table_name, '.' in table_name
|
47
|
+
])
|
48
|
+
ref = table_name
|
49
|
+
if is_file_name:
|
50
|
+
ref = table_name.split('/')[-1].split('.')[0]
|
51
|
+
if cls.ALIAS_FUNC:
|
52
|
+
self.__alias = cls.ALIAS_FUNC(ref)
|
53
|
+
elif ' ' in table_name.strip():
|
45
54
|
table_name, self.__alias = table_name.split()
|
46
|
-
elif '_' in
|
55
|
+
elif '_' in ref:
|
47
56
|
self.__alias = ''.join(
|
48
57
|
word[0].lower()
|
49
|
-
for word in
|
58
|
+
for word in ref.split('_')
|
50
59
|
)
|
51
60
|
else:
|
52
|
-
self.__alias =
|
61
|
+
self.__alias = ref.lower()[:3]
|
53
62
|
self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
|
54
63
|
|
55
64
|
@property
|
56
65
|
def table_name(self) -> str:
|
57
66
|
return self.values[FROM][0].split()[0]
|
58
67
|
|
68
|
+
def set_file_format(self, pattern: str):
|
69
|
+
if '{' not in pattern:
|
70
|
+
pattern = '{}' + pattern
|
71
|
+
self.values[FROM][0] = pattern.format(self.aka())
|
72
|
+
|
59
73
|
@property
|
60
74
|
def alias(self) -> str:
|
61
75
|
if self.__alias:
|
@@ -67,6 +81,16 @@ class SQLObject:
|
|
67
81
|
appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
|
68
82
|
return KEYWORD[key][0].format(appendix.get(key, ''))
|
69
83
|
|
84
|
+
@staticmethod
|
85
|
+
def is_named_field(fld: str, name: str='') -> bool:
|
86
|
+
return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
|
87
|
+
|
88
|
+
def has_named_field(self, name: str) -> bool:
|
89
|
+
return any(
|
90
|
+
self.is_named_field(fld, name)
|
91
|
+
for fld in self.values.get(SELECT, [])
|
92
|
+
)
|
93
|
+
|
70
94
|
def diff(self, key: str, search_list: list, exact: bool=False) -> set:
|
71
95
|
def disassemble(source: list) -> list:
|
72
96
|
if not exact:
|
@@ -79,12 +103,12 @@ class SQLObject:
|
|
79
103
|
if exact:
|
80
104
|
fld = fld.lower()
|
81
105
|
return fld.strip()
|
82
|
-
def is_named_field(fld: str) -> bool:
|
83
|
-
return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
|
84
106
|
def field_set(source: list) -> set:
|
85
107
|
return set(
|
86
108
|
(
|
87
|
-
fld
|
109
|
+
fld
|
110
|
+
if key == SELECT and self.is_named_field(fld, key)
|
111
|
+
else
|
88
112
|
re.sub(pattern, '', cleanup(fld))
|
89
113
|
)
|
90
114
|
for string in disassemble(source)
|
@@ -102,13 +126,22 @@ class SQLObject:
|
|
102
126
|
return s1.symmetric_difference(s2)
|
103
127
|
return s1 - s2
|
104
128
|
|
105
|
-
def delete(self, search: str, keys: list=USUAL_KEYS):
|
129
|
+
def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
|
130
|
+
if exact:
|
131
|
+
not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
|
132
|
+
else:
|
133
|
+
not_match = lambda item: search not in item
|
106
134
|
for key in keys:
|
107
|
-
|
108
|
-
|
109
|
-
if
|
110
|
-
|
111
|
-
|
135
|
+
self.values[key] = [
|
136
|
+
item for item in self.values.get(key, [])
|
137
|
+
if not_match(item)
|
138
|
+
]
|
139
|
+
|
140
|
+
|
141
|
+
SQL_CONST_SYSDATE = 'SYSDATE'
|
142
|
+
SQL_CONST_CURR_DATE = 'Current_date'
|
143
|
+
SQL_ROW_NUM = 'ROWNUM'
|
144
|
+
SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
|
112
145
|
|
113
146
|
|
114
147
|
class Field:
|
@@ -116,10 +149,16 @@ class Field:
|
|
116
149
|
|
117
150
|
@classmethod
|
118
151
|
def format(cls, name: str, main: SQLObject) -> str:
|
152
|
+
def is_const() -> bool:
|
153
|
+
return any([
|
154
|
+
re.findall('[.()0-9]', name),
|
155
|
+
name in SQL_CONSTS,
|
156
|
+
re.findall(r'\w+\s*[+-]\s*\w+', name)
|
157
|
+
])
|
119
158
|
name = name.strip()
|
120
159
|
if name in ('_', '*'):
|
121
160
|
name = '*'
|
122
|
-
elif not
|
161
|
+
elif not is_const() and not main.has_named_field(name):
|
123
162
|
name = f'{main.alias}.{name}'
|
124
163
|
if Function in cls.__bases__:
|
125
164
|
name = f'{cls.__name__}({name})'
|
@@ -150,90 +189,210 @@ class NamedField:
|
|
150
189
|
)
|
151
190
|
|
152
191
|
|
192
|
+
class Dialect(Enum):
|
193
|
+
ANSI = 0
|
194
|
+
SQL_SERVER = 1
|
195
|
+
ORACLE = 2
|
196
|
+
POSTGRESQL = 3
|
197
|
+
MYSQL = 4
|
198
|
+
|
199
|
+
SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
|
200
|
+
CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
|
201
|
+
|
153
202
|
class Function:
|
154
|
-
|
203
|
+
dialect = Dialect.ANSI
|
204
|
+
inputs = None
|
205
|
+
output = None
|
206
|
+
separator = ', '
|
207
|
+
auto_convert = True
|
208
|
+
append_param = False
|
155
209
|
|
156
210
|
def __init__(self, *params: list):
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
211
|
+
def set_func_types(param):
|
212
|
+
if self.auto_convert and isinstance(param, Function):
|
213
|
+
func = param
|
214
|
+
main_param = self.inputs[0]
|
215
|
+
unfriendly = all([
|
216
|
+
func.output != main_param,
|
217
|
+
func.output != ANY,
|
218
|
+
main_param != ANY
|
219
|
+
])
|
220
|
+
if unfriendly:
|
221
|
+
return Cast(func, main_param)
|
222
|
+
return param
|
223
|
+
# --- Replace class methods by instance methods: ------
|
224
|
+
self.add = self.__add
|
225
|
+
self.format = self.__format
|
226
|
+
# -----------------------------------------------------
|
227
|
+
self.params = [set_func_types(p) for p in params]
|
228
|
+
self.field_class = Field
|
229
|
+
self.pattern = self.get_pattern()
|
162
230
|
self.extra = {}
|
163
231
|
|
232
|
+
def get_pattern(self) -> str:
|
233
|
+
return '{func_name}({params})'
|
234
|
+
|
164
235
|
def As(self, field_alias: str, modifiers=None):
|
165
236
|
if modifiers:
|
166
237
|
self.extra[field_alias] = TO_LIST(modifiers)
|
167
|
-
self.
|
238
|
+
self.field_class = NamedField(field_alias)
|
168
239
|
return self
|
169
240
|
|
241
|
+
def __str__(self) -> str:
|
242
|
+
return self.pattern.format(
|
243
|
+
func_name=self.__class__.__name__,
|
244
|
+
params=self.separator.join(str(p) for p in self.params)
|
245
|
+
)
|
246
|
+
|
170
247
|
@classmethod
|
171
|
-
def
|
172
|
-
|
173
|
-
|
174
|
-
|
248
|
+
def help(cls) -> str:
|
249
|
+
descr = ' '.join(B.__name__ for B in cls.__bases__)
|
250
|
+
params = cls.inputs or ''
|
251
|
+
return cls().get_pattern().format(
|
252
|
+
func_name=f'{descr} {cls.__name__}',
|
253
|
+
params=cls.separator.join(str(p) for p in params)
|
254
|
+
) + f' Return {cls.output}'
|
255
|
+
|
256
|
+
def set_main_param(self, name: str, main: SQLObject) -> bool:
|
257
|
+
nested_functions = [
|
258
|
+
param for param in self.params if isinstance(param, Function)
|
259
|
+
]
|
260
|
+
for func in nested_functions:
|
261
|
+
if func.inputs:
|
262
|
+
func.set_main_param(name, main)
|
263
|
+
return
|
264
|
+
new_params = [Field.format(name, main)]
|
265
|
+
if self.append_param:
|
266
|
+
self.params += new_params
|
175
267
|
else:
|
176
|
-
params =
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
268
|
+
self.params = new_params + self.params
|
269
|
+
|
270
|
+
def __format(self, name: str, main: SQLObject) -> str:
|
271
|
+
if name not in '*_':
|
272
|
+
self.set_main_param(name, main)
|
273
|
+
return str(self)
|
274
|
+
|
275
|
+
@classmethod
|
276
|
+
def format(cls, name: str, main: SQLObject):
|
277
|
+
return cls().__format(name, main)
|
183
278
|
|
184
279
|
def __add(self, name: str, main: SQLObject):
|
185
280
|
name = self.format(name, main)
|
186
|
-
self.
|
281
|
+
self.field_class.add(name, main)
|
187
282
|
if self.extra:
|
188
283
|
main.__call__(**self.extra)
|
189
284
|
|
190
|
-
@classmethod
|
191
|
-
def get_instance(cls):
|
192
|
-
obj = Function.instance.get(cls.__name__)
|
193
|
-
if not obj:
|
194
|
-
obj = cls()
|
195
|
-
return obj
|
196
|
-
|
197
285
|
@classmethod
|
198
286
|
def add(cls, name: str, main: SQLObject):
|
199
|
-
cls
|
287
|
+
cls().__add(name, main)
|
200
288
|
|
201
289
|
|
202
290
|
# ---- String Functions: ---------------------------------
|
203
291
|
class SubString(Function):
|
204
|
-
|
292
|
+
inputs = [CHAR, INT, INT]
|
293
|
+
output = CHAR
|
294
|
+
|
295
|
+
def get_pattern(self) -> str:
|
296
|
+
if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
|
297
|
+
return 'Substr({params})'
|
298
|
+
return super().get_pattern()
|
205
299
|
|
206
300
|
# ---- Numeric Functions: --------------------------------
|
207
301
|
class Round(Function):
|
208
|
-
|
302
|
+
inputs = [FLOAT]
|
303
|
+
output = FLOAT
|
209
304
|
|
210
305
|
# --- Date Functions: ------------------------------------
|
211
306
|
class DateDiff(Function):
|
307
|
+
inputs = [DATE]
|
308
|
+
output = DATE
|
309
|
+
append_param = True
|
310
|
+
|
311
|
+
def __str__(self) -> str:
|
312
|
+
def is_field_or_func(name: str) -> bool:
|
313
|
+
candidate = re.sub(
|
314
|
+
'[()]', '', name.split('.')[-1]
|
315
|
+
)
|
316
|
+
return candidate.isidentifier()
|
317
|
+
if self.dialect != Dialect.SQL_SERVER:
|
318
|
+
params = [str(p) for p in self.params]
|
319
|
+
return ' - '.join(
|
320
|
+
p if is_field_or_func(p) else f"'{p}'"
|
321
|
+
for p in params
|
322
|
+
) # <==== Date subtract
|
323
|
+
return super().__str__()
|
324
|
+
|
325
|
+
|
326
|
+
class DatePart(Function):
|
327
|
+
inputs = [DATE]
|
328
|
+
output = INT
|
329
|
+
|
330
|
+
def get_pattern(self) -> str:
|
331
|
+
interval = self.__class__.__name__
|
332
|
+
database_type = {
|
333
|
+
Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
|
334
|
+
Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
|
335
|
+
}
|
336
|
+
if self.dialect in database_type:
|
337
|
+
return database_type[self.dialect]
|
338
|
+
return super().get_pattern()
|
339
|
+
|
340
|
+
class Year(DatePart):
|
212
341
|
...
|
213
|
-
class
|
342
|
+
class Month(DatePart):
|
214
343
|
...
|
215
|
-
class DatePart
|
344
|
+
class Day(DatePart):
|
216
345
|
...
|
346
|
+
|
347
|
+
|
217
348
|
class Current_Date(Function):
|
218
|
-
|
349
|
+
output = DATE
|
350
|
+
|
351
|
+
def get_pattern(self) -> str:
|
352
|
+
database_type = {
|
353
|
+
Dialect.ORACLE: SQL_CONST_SYSDATE,
|
354
|
+
Dialect.POSTGRESQL: SQL_CONST_CURR_DATE,
|
355
|
+
Dialect.SQL_SERVER: 'getDate()'
|
356
|
+
}
|
357
|
+
if self.dialect in database_type:
|
358
|
+
return database_type[self.dialect]
|
359
|
+
return super().get_pattern()
|
360
|
+
# --------------------------------------------------------
|
219
361
|
|
220
|
-
class
|
362
|
+
class Frame:
|
221
363
|
break_lines: bool = True
|
222
364
|
|
223
365
|
def over(self, **args):
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
366
|
+
"""
|
367
|
+
How to use:
|
368
|
+
over(field1=OrderBy, field2=Partition)
|
369
|
+
"""
|
370
|
+
keywords = ''
|
371
|
+
for field, obj in args.items():
|
372
|
+
is_valid = any([
|
373
|
+
obj is OrderBy,
|
374
|
+
obj is Partition,
|
375
|
+
isinstance(obj, Rows),
|
376
|
+
])
|
377
|
+
if not is_valid:
|
378
|
+
continue
|
379
|
+
keywords += '{}{} {}'.format(
|
380
|
+
'\n\t\t' if self.break_lines else ' ',
|
381
|
+
obj.cls_to_str(), field if field != '_' else ''
|
382
|
+
)
|
231
383
|
if keywords and self.break_lines:
|
232
384
|
keywords += '\n\t'
|
233
|
-
self.pattern =
|
385
|
+
self.pattern = self.get_pattern() + f' OVER({keywords})'
|
234
386
|
return self
|
235
387
|
|
236
388
|
|
389
|
+
class Aggregate(Frame):
|
390
|
+
inputs = [FLOAT]
|
391
|
+
output = FLOAT
|
392
|
+
|
393
|
+
class Window(Frame):
|
394
|
+
...
|
395
|
+
|
237
396
|
# ---- Aggregate Functions: -------------------------------
|
238
397
|
class Avg(Aggregate, Function):
|
239
398
|
...
|
@@ -246,11 +405,32 @@ class Sum(Aggregate, Function):
|
|
246
405
|
class Count(Aggregate, Function):
|
247
406
|
...
|
248
407
|
|
408
|
+
# ---- Window Functions: -----------------------------------
|
409
|
+
class Row_Number(Window, Function):
|
410
|
+
output = INT
|
411
|
+
|
412
|
+
class Rank(Window, Function):
|
413
|
+
output = INT
|
414
|
+
|
415
|
+
class Lag(Window, Function):
|
416
|
+
output = ANY
|
417
|
+
|
418
|
+
class Lead(Window, Function):
|
419
|
+
output = ANY
|
420
|
+
|
421
|
+
|
249
422
|
# ---- Conversions and other Functions: ---------------------
|
250
423
|
class Coalesce(Function):
|
251
|
-
|
424
|
+
inputs = [ANY]
|
425
|
+
output = ANY
|
426
|
+
|
252
427
|
class Cast(Function):
|
253
|
-
|
428
|
+
inputs = [ANY]
|
429
|
+
output = ANY
|
430
|
+
separator = ' As '
|
431
|
+
|
432
|
+
|
433
|
+
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
254
434
|
|
255
435
|
|
256
436
|
class ExpressionField:
|
@@ -275,15 +455,20 @@ class ExpressionField:
|
|
275
455
|
class FieldList:
|
276
456
|
separator = ','
|
277
457
|
|
278
|
-
def __init__(self, fields: list=[], class_types = [Field]):
|
458
|
+
def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
|
279
459
|
if isinstance(fields, str):
|
280
460
|
fields = [
|
281
461
|
f.strip() for f in fields.split(self.separator)
|
282
462
|
]
|
283
463
|
self.fields = fields
|
284
464
|
self.class_types = class_types
|
465
|
+
self.ziped = ziped
|
285
466
|
|
286
467
|
def add(self, name: str, main: SQLObject):
|
468
|
+
if self.ziped: # --- One class per field...
|
469
|
+
for field, class_type in zip(self.fields, self.class_types):
|
470
|
+
class_type.add(field, main)
|
471
|
+
return
|
287
472
|
for field in self.fields:
|
288
473
|
for class_type in self.class_types:
|
289
474
|
class_type.add(field, main)
|
@@ -329,23 +514,35 @@ def quoted(value) -> str:
|
|
329
514
|
return str(value)
|
330
515
|
|
331
516
|
|
517
|
+
class Position(Enum):
|
518
|
+
Middle = 0
|
519
|
+
StartsWith = 1
|
520
|
+
EndsWith = 2
|
521
|
+
|
522
|
+
|
332
523
|
class Where:
|
333
524
|
prefix = ''
|
334
525
|
|
335
|
-
def __init__(self,
|
336
|
-
self.
|
526
|
+
def __init__(self, content: str):
|
527
|
+
self.content = content
|
337
528
|
|
338
529
|
@classmethod
|
339
530
|
def __constructor(cls, operator: str, value):
|
340
|
-
return cls(
|
531
|
+
return cls(f'{operator} {quoted(value)}')
|
341
532
|
|
342
533
|
@classmethod
|
343
534
|
def eq(cls, value):
|
344
535
|
return cls.__constructor('=', value)
|
345
536
|
|
346
537
|
@classmethod
|
347
|
-
def contains(cls,
|
348
|
-
return cls(
|
538
|
+
def contains(cls, text: str, pos: Position = Position.Middle):
|
539
|
+
return cls(
|
540
|
+
"LIKE '{}{}{}'".format(
|
541
|
+
'%' if pos != Position.StartsWith else '',
|
542
|
+
text,
|
543
|
+
'%' if pos != Position.EndsWith else ''
|
544
|
+
)
|
545
|
+
)
|
349
546
|
|
350
547
|
@classmethod
|
351
548
|
def gt(cls, value):
|
@@ -373,9 +570,42 @@ class Where:
|
|
373
570
|
values = ','.join(quoted(v) for v in values)
|
374
571
|
return cls(f'IN ({values})')
|
375
572
|
|
573
|
+
@classmethod
|
574
|
+
def formula(cls, formula: str):
|
575
|
+
where = cls( ExpressionField(formula) )
|
576
|
+
where.add = where.add_expression
|
577
|
+
return where
|
578
|
+
|
579
|
+
def add_expression(self, name: str, main: SQLObject):
|
580
|
+
self.content = self.content.format(name, main)
|
581
|
+
main.values.setdefault(WHERE, []).append('{} {}'.format(
|
582
|
+
self.prefix, self.content
|
583
|
+
))
|
584
|
+
|
585
|
+
@classmethod
|
586
|
+
def join(cls, query: SQLObject):
|
587
|
+
where = cls(query)
|
588
|
+
where.add = where.add_join
|
589
|
+
return where
|
590
|
+
|
591
|
+
def add_join(self, name: str, main: SQLObject):
|
592
|
+
query = self.content
|
593
|
+
main.values[FROM].append(f',{query.table_name} {query.alias}')
|
594
|
+
for key in USUAL_KEYS:
|
595
|
+
main.update_values(key, query.values.get(key, []))
|
596
|
+
main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
|
597
|
+
a1=main.alias, f1=name,
|
598
|
+
a2=query.alias, f2=query.key_field
|
599
|
+
))
|
600
|
+
|
376
601
|
def add(self, name: str, main: SQLObject):
|
602
|
+
func_type = FUNCTION_CLASS.get(name.lower())
|
603
|
+
if func_type:
|
604
|
+
name = func_type.format('*', main)
|
605
|
+
elif not main.has_named_field(name):
|
606
|
+
name = Field.format(name, main)
|
377
607
|
main.values.setdefault(WHERE, []).append('{}{} {}'.format(
|
378
|
-
self.prefix,
|
608
|
+
self.prefix, name, self.content
|
379
609
|
))
|
380
610
|
|
381
611
|
|
@@ -383,6 +613,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
|
|
383
613
|
getattr(Where, method) for method in
|
384
614
|
('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
|
385
615
|
)
|
616
|
+
startswith, endswith = [
|
617
|
+
lambda x: contains(x, Position.StartsWith),
|
618
|
+
lambda x: contains(x, Position.EndsWith)
|
619
|
+
]
|
386
620
|
|
387
621
|
|
388
622
|
class Not(Where):
|
@@ -390,7 +624,7 @@ class Not(Where):
|
|
390
624
|
|
391
625
|
@classmethod
|
392
626
|
def eq(cls, value):
|
393
|
-
return Where(
|
627
|
+
return Where(f'<> {quoted(value)}')
|
394
628
|
|
395
629
|
|
396
630
|
class Case:
|
@@ -399,20 +633,24 @@ class Case:
|
|
399
633
|
self.default = None
|
400
634
|
self.field = field
|
401
635
|
|
402
|
-
def when(self, condition: Where, result
|
636
|
+
def when(self, condition: Where, result):
|
637
|
+
if isinstance(result, str):
|
638
|
+
result = quoted(result)
|
403
639
|
self.__conditions[result] = condition
|
404
640
|
return self
|
405
641
|
|
406
|
-
def else_value(self, default
|
642
|
+
def else_value(self, default):
|
643
|
+
if isinstance(default, str):
|
644
|
+
default = quoted(default)
|
407
645
|
self.default = default
|
408
646
|
return self
|
409
647
|
|
410
648
|
def add(self, name: str, main: SQLObject):
|
411
649
|
field = Field.format(self.field, main)
|
412
|
-
default =
|
650
|
+
default = self.default
|
413
651
|
name = 'CASE \n{}\n\tEND AS {}'.format(
|
414
652
|
'\n'.join(
|
415
|
-
f'\t\tWHEN {field} {cond.
|
653
|
+
f'\t\tWHEN {field} {cond.content} THEN {res}'
|
416
654
|
for res, cond in self.__conditions.items()
|
417
655
|
) + f'\n\t\tELSE {default}' if default else '',
|
418
656
|
name
|
@@ -425,14 +663,13 @@ class Options:
|
|
425
663
|
self.__children: dict = values
|
426
664
|
|
427
665
|
def add(self, logical_separator: str, main: SQLObject):
|
428
|
-
|
429
|
-
|
430
|
-
"""
|
666
|
+
if logical_separator not in ('AND', 'OR'):
|
667
|
+
raise ValueError('`logical_separator` must be AND or OR')
|
431
668
|
conditions: list[str] = []
|
432
669
|
child: Where
|
433
670
|
for field, child in self.__children.items():
|
434
671
|
conditions.append(' {} {} '.format(
|
435
|
-
Field.format(field, main), child.
|
672
|
+
Field.format(field, main), child.content
|
436
673
|
))
|
437
674
|
main.values.setdefault(WHERE, []).append(
|
438
675
|
'(' + logical_separator.join(conditions) + ')'
|
@@ -450,18 +687,25 @@ class Between:
|
|
450
687
|
Where.gte(self.start).add(name, main),
|
451
688
|
Where.lte(self.end).add(name, main)
|
452
689
|
|
690
|
+
class SameDay(Between):
|
691
|
+
def __init__(self, date: str):
|
692
|
+
super().__init__(
|
693
|
+
f'{date} 00:00:00',
|
694
|
+
f'{date} 23:59:59',
|
695
|
+
)
|
696
|
+
|
697
|
+
|
453
698
|
|
454
699
|
class Clause:
|
455
700
|
@classmethod
|
456
701
|
def format(cls, name: str, main: SQLObject) -> str:
|
457
702
|
def is_function() -> bool:
|
458
703
|
diff = main.diff(SELECT, [name.lower()], True)
|
459
|
-
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
460
704
|
return diff.intersection(FUNCTION_CLASS)
|
461
705
|
found = re.findall(r'^_\d', name)
|
462
706
|
if found:
|
463
707
|
name = found[0].replace('_', '')
|
464
|
-
elif main.alias and not is_function():
|
708
|
+
elif '.' not in name and main.alias and not is_function():
|
465
709
|
name = f'{main.alias}.{name}'
|
466
710
|
return name
|
467
711
|
|
@@ -470,6 +714,34 @@ class SortType(Enum):
|
|
470
714
|
ASC = ''
|
471
715
|
DESC = ' DESC'
|
472
716
|
|
717
|
+
class Row:
|
718
|
+
def __init__(self, value: int=0):
|
719
|
+
self.value = value
|
720
|
+
|
721
|
+
def __str__(self) -> str:
|
722
|
+
return '{} {}'.format(
|
723
|
+
'UNBOUNDED' if self.value == 0 else self.value,
|
724
|
+
self.__class__.__name__.upper()
|
725
|
+
)
|
726
|
+
|
727
|
+
class Preceding(Row):
|
728
|
+
...
|
729
|
+
class Following(Row):
|
730
|
+
...
|
731
|
+
class Current(Row):
|
732
|
+
def __str__(self) -> str:
|
733
|
+
return 'CURRENT ROW'
|
734
|
+
|
735
|
+
class Rows:
|
736
|
+
def __init__(self, *rows: list[Row]):
|
737
|
+
self.rows = rows
|
738
|
+
|
739
|
+
def cls_to_str(self) -> str:
|
740
|
+
return 'ROWS {}{}'.format(
|
741
|
+
'BETWEEN ' if len(self.rows) > 1 else '',
|
742
|
+
' AND '.join(str(row) for row in self.rows)
|
743
|
+
)
|
744
|
+
|
473
745
|
|
474
746
|
class OrderBy(Clause):
|
475
747
|
sort: SortType = SortType.ASC
|
@@ -479,6 +751,16 @@ class OrderBy(Clause):
|
|
479
751
|
name = cls.format(name, main)
|
480
752
|
main.values.setdefault(ORDER_BY, []).append(name+cls.sort.value)
|
481
753
|
|
754
|
+
@classmethod
|
755
|
+
def cls_to_str(cls) -> str:
|
756
|
+
return ORDER_BY
|
757
|
+
|
758
|
+
PARTITION_BY = 'PARTITION BY'
|
759
|
+
class Partition:
|
760
|
+
@classmethod
|
761
|
+
def cls_to_str(cls) -> str:
|
762
|
+
return PARTITION_BY
|
763
|
+
|
482
764
|
|
483
765
|
class GroupBy(Clause):
|
484
766
|
@classmethod
|
@@ -494,7 +776,7 @@ class Having:
|
|
494
776
|
|
495
777
|
def add(self, name: str, main:SQLObject):
|
496
778
|
main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
|
497
|
-
self.function.format(name, main), self.condition.
|
779
|
+
self.function.format(name, main), self.condition.content
|
498
780
|
)
|
499
781
|
|
500
782
|
@classmethod
|
@@ -524,7 +806,7 @@ class Rule:
|
|
524
806
|
...
|
525
807
|
|
526
808
|
class QueryLanguage:
|
527
|
-
pattern = '{select}{_from}{where}{group_by}{order_by}'
|
809
|
+
pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
|
528
810
|
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
529
811
|
|
530
812
|
@staticmethod
|
@@ -547,18 +829,21 @@ class QueryLanguage:
|
|
547
829
|
return self.join_with_tabs(values, ' AND ')
|
548
830
|
|
549
831
|
def sort_by(self, values: list) -> str:
|
550
|
-
return self.join_with_tabs(values)
|
832
|
+
return self.join_with_tabs(values, ',')
|
551
833
|
|
552
834
|
def set_group(self, values: list) -> str:
|
553
835
|
return self.join_with_tabs(values, ',')
|
554
836
|
|
837
|
+
def set_limit(self, values: list) -> str:
|
838
|
+
return self.join_with_tabs(values, ' ')
|
839
|
+
|
555
840
|
def __init__(self, target: 'Select'):
|
556
|
-
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
|
841
|
+
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
|
557
842
|
self.TABULATION = '\n\t' if target.break_lines else ' '
|
558
843
|
self.LINE_BREAK = '\n' if target.break_lines else ' '
|
559
844
|
self.TOKEN_METHODS = {
|
560
845
|
SELECT: self.add_field, FROM: self.get_tables,
|
561
|
-
WHERE: self.extract_conditions,
|
846
|
+
WHERE: self.extract_conditions, LIMIT: self.set_limit,
|
562
847
|
ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
|
563
848
|
}
|
564
849
|
self.result = {}
|
@@ -862,10 +1147,13 @@ class SQLParser(Parser):
|
|
862
1147
|
if not key in values:
|
863
1148
|
continue
|
864
1149
|
separator = self.class_type.get_separator(key)
|
1150
|
+
cls = {
|
1151
|
+
ORDER_BY: OrderBy, GROUP_BY: GroupBy
|
1152
|
+
}.get(key, Field)
|
865
1153
|
obj.values[key] = [
|
866
|
-
|
1154
|
+
cls.format(fld, obj)
|
867
1155
|
for fld in re.split(separator, values[key])
|
868
|
-
if (fld != '*' and len(tables) == 1) or obj.match(fld)
|
1156
|
+
if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
|
869
1157
|
]
|
870
1158
|
result[obj.alias] = obj
|
871
1159
|
self.queries = list( result.values() )
|
@@ -925,16 +1213,26 @@ class CypherParser(Parser):
|
|
925
1213
|
if token in self.TOKEN_METHODS:
|
926
1214
|
return
|
927
1215
|
class_list = [Field]
|
928
|
-
if '
|
1216
|
+
if '*' in token:
|
1217
|
+
token = token.replace('*', '')
|
1218
|
+
self.queries[-1].key_field = token
|
1219
|
+
return
|
1220
|
+
elif '$' in token:
|
929
1221
|
func_name, token = token.split('$')
|
930
1222
|
if func_name == 'count':
|
931
1223
|
if not token:
|
932
1224
|
token = 'count_1'
|
933
|
-
|
934
|
-
|
1225
|
+
pk_field = self.queries[-1].key_field or 'id'
|
1226
|
+
Count().As(token, extra_classes).add(pk_field, self.queries[-1])
|
1227
|
+
return
|
935
1228
|
else:
|
936
|
-
|
937
|
-
|
1229
|
+
class_type = FUNCTION_CLASS.get(func_name)
|
1230
|
+
if not class_type:
|
1231
|
+
raise ValueError(f'Unknown function `{func_name}`.')
|
1232
|
+
if ':' in token:
|
1233
|
+
token, field_alias = token.split(':')
|
1234
|
+
class_type = class_type().As(field_alias)
|
1235
|
+
class_list = [class_type]
|
938
1236
|
class_list += extra_classes
|
939
1237
|
FieldList(token, class_list).add('', self.queries[-1])
|
940
1238
|
|
@@ -949,10 +1247,13 @@ class CypherParser(Parser):
|
|
949
1247
|
def add_foreign_key(self, token: str, pk_field: str=''):
|
950
1248
|
curr, last = [self.queries[i] for i in (-1, -2)]
|
951
1249
|
if not pk_field:
|
952
|
-
if
|
953
|
-
|
954
|
-
|
955
|
-
|
1250
|
+
if last.key_field:
|
1251
|
+
pk_field = last.key_field
|
1252
|
+
else:
|
1253
|
+
if not last.values.get(SELECT):
|
1254
|
+
raise IndexError(f'Primary Key not found for {last.table_name}.')
|
1255
|
+
pk_field = last.values[SELECT][-1].split('.')[-1]
|
1256
|
+
last.delete(pk_field, [SELECT], exact=True)
|
956
1257
|
if '{}' in token:
|
957
1258
|
foreign_fld = token.format(
|
958
1259
|
last.table_name.lower()
|
@@ -967,12 +1268,11 @@ class CypherParser(Parser):
|
|
967
1268
|
if fld not in curr.values.get(GROUP_BY, [])
|
968
1269
|
]
|
969
1270
|
foreign_fld = fields[0].split('.')[-1]
|
970
|
-
curr.delete(foreign_fld, [SELECT])
|
1271
|
+
curr.delete(foreign_fld, [SELECT], exact=True)
|
971
1272
|
if curr.join_type == JoinType.RIGHT:
|
972
1273
|
pk_field, foreign_fld = foreign_fld, pk_field
|
973
1274
|
if curr.join_type == JoinType.RIGHT:
|
974
1275
|
curr, last = last, curr
|
975
|
-
# pk_field, foreign_fld = foreign_fld, pk_field
|
976
1276
|
k = ForeignKey.get_key(curr, last)
|
977
1277
|
ForeignKey.references[k] = (foreign_fld, pk_field)
|
978
1278
|
|
@@ -1158,21 +1458,30 @@ class Select(SQLObject):
|
|
1158
1458
|
|
1159
1459
|
def add(self, name: str, main: SQLObject):
|
1160
1460
|
old_tables = main.values.get(FROM, [])
|
1161
|
-
|
1162
|
-
|
1461
|
+
if len(self.values[FROM]) > 1:
|
1462
|
+
old_tables += self.values[FROM][1:]
|
1463
|
+
new_tables = []
|
1464
|
+
row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
|
1163
1465
|
jt=self.join_type.value,
|
1164
1466
|
tb=self.aka(),
|
1165
1467
|
a1=main.alias, f1=name,
|
1166
1468
|
a2=self.alias, f2=self.key_field
|
1167
1469
|
)
|
1168
|
-
|
1169
|
-
|
1470
|
+
if row not in old_tables[1:]:
|
1471
|
+
new_tables.append(row)
|
1472
|
+
main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
|
1170
1473
|
for key in USUAL_KEYS:
|
1171
1474
|
main.update_values(key, self.values.get(key, []))
|
1172
1475
|
|
1173
|
-
def
|
1476
|
+
def copy(self) -> SQLObject:
|
1174
1477
|
from copy import deepcopy
|
1175
|
-
|
1478
|
+
return deepcopy(self)
|
1479
|
+
|
1480
|
+
def no_relation_error(self, other: SQLObject):
|
1481
|
+
raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
|
1482
|
+
|
1483
|
+
def __add__(self, other: SQLObject):
|
1484
|
+
query = self.copy()
|
1176
1485
|
if query.table_name.lower() == other.table_name.lower():
|
1177
1486
|
for key in USUAL_KEYS:
|
1178
1487
|
query.update_values(key, other.values.get(key, []))
|
@@ -1185,7 +1494,7 @@ class Select(SQLObject):
|
|
1185
1494
|
PrimaryKey.add(primary_key, query)
|
1186
1495
|
query.add(foreign_field, other)
|
1187
1496
|
return other
|
1188
|
-
|
1497
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1189
1498
|
elif primary_key:
|
1190
1499
|
PrimaryKey.add(primary_key, other)
|
1191
1500
|
other.add(foreign_field, query)
|
@@ -1205,16 +1514,48 @@ class Select(SQLObject):
|
|
1205
1514
|
if self.diff(key, other.values.get(key, []), True):
|
1206
1515
|
return False
|
1207
1516
|
return True
|
1517
|
+
|
1518
|
+
def __sub__(self, other: SQLObject) -> SQLObject:
|
1519
|
+
fk_field, primary_k = ForeignKey.find(self, other)
|
1520
|
+
if fk_field:
|
1521
|
+
query = self.copy()
|
1522
|
+
other = other.copy()
|
1523
|
+
else:
|
1524
|
+
fk_field, primary_k = ForeignKey.find(other, self)
|
1525
|
+
if not fk_field:
|
1526
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1527
|
+
query = other.copy()
|
1528
|
+
other = self.copy()
|
1529
|
+
query.__class__ = NotSelectIN
|
1530
|
+
Field.add(fk_field, query)
|
1531
|
+
query.add(primary_k, other)
|
1532
|
+
return other
|
1208
1533
|
|
1209
1534
|
def limit(self, row_count: int=100, offset: int=0):
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1535
|
+
if Function.dialect == Dialect.SQL_SERVER:
|
1536
|
+
fields = self.values.get(SELECT)
|
1537
|
+
if fields:
|
1538
|
+
fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
|
1539
|
+
else:
|
1540
|
+
self.values[SELECT] = [f'SELECT TOP({row_count}) *']
|
1541
|
+
return self
|
1542
|
+
if Function.dialect == Dialect.ORACLE:
|
1543
|
+
Where.gte(row_count).add(SQL_ROW_NUM, self)
|
1544
|
+
if offset > 0:
|
1545
|
+
Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
|
1546
|
+
return self
|
1547
|
+
self.values[LIMIT] = ['{}{}'.format(
|
1548
|
+
row_count, f' OFFSET {offset}' if offset > 0 else ''
|
1549
|
+
)]
|
1214
1550
|
return self
|
1215
1551
|
|
1216
|
-
def match(self,
|
1217
|
-
|
1552
|
+
def match(self, field: str, key: str) -> bool:
|
1553
|
+
'''
|
1554
|
+
Recognizes if the field is from the current table
|
1555
|
+
'''
|
1556
|
+
if key in (ORDER_BY, GROUP_BY) and '.' not in field:
|
1557
|
+
return self.has_named_field(field)
|
1558
|
+
return re.findall(f'\b*{self.alias}[.]', field) != []
|
1218
1559
|
|
1219
1560
|
@classmethod
|
1220
1561
|
def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
|
@@ -1226,12 +1567,10 @@ class Select(SQLObject):
|
|
1226
1567
|
for rule in rules:
|
1227
1568
|
rule.apply(self)
|
1228
1569
|
|
1229
|
-
def add_fields(self, fields: list,
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
if group_by:
|
1234
|
-
class_types += [GroupBy]
|
1570
|
+
def add_fields(self, fields: list, class_types=None):
|
1571
|
+
if not class_types:
|
1572
|
+
class_types = []
|
1573
|
+
class_types += [Field]
|
1235
1574
|
FieldList(fields, class_types).add('', self)
|
1236
1575
|
|
1237
1576
|
def translate_to(self, language: QueryLanguage) -> str:
|
@@ -1251,6 +1590,95 @@ class NotSelectIN(SelectIN):
|
|
1251
1590
|
condition_class = Not
|
1252
1591
|
|
1253
1592
|
|
1593
|
+
class CTE(Select):
    '''
    A Select rendered after a `WITH <name> AS (...)` block.
    The block is made of the queries in `query_list`,
    combined with UNION ALL.
    '''
    prefix = ''

    def __init__(self, table_name: str, query_list: list[Select]):
        '''
        table_name: passed to the Select constructor and used
            as the name of the WITH block.
        query_list: the queries placed inside the WITH block;
            line breaking is turned off on each of them.
        '''
        super().__init__(table_name)
        for inner in query_list:
            inner.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        # Total length of the single-line entries under the usual keys;
        # past 70 characters the outer query is rendered with line breaks.
        total_chars = sum(
            len(text)
            for key in USUAL_KEYS
            for text in self.values.get(key, [])
            if '\n' not in text
        )
        if total_chars > 70:
            self.break_lines = True

        def justify(query: Select) -> str:
            # Re-wraps the text of an inner query: it is split on the
            # keywords, AND, OR and commas, and a new line is started
            # whenever the current one already has 50+ characters.
            keywords = '|'.join(KEYWORD)
            pieces = re.split(fr'({keywords}|AND|OR|,)', str(query))
            wrapped, current = [], ''
            for piece in pieces:
                if len(current) >= 50:
                    wrapped.append(current)
                    current = ''
                current += piece
            if current:
                wrapped.append(current)
            return '\n '.join(wrapped)

        prefix, name = self.prefix, self.table_name
        inner_sql = '\nUNION ALL\n '.join(
            justify(inner) for inner in self.query_list
        )
        outer_sql = super().__str__()
        return 'WITH {}{} AS (\n {}\n){}'.format(
            prefix, name, inner_sql, outer_sql
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        '''
        Repeats `pattern` once per field, turns the result into
        separate queries with `detect(..., join_queries=False)`
        and adds them to this CTE, paired with `fields` through
        `FieldList(..., ziped=True)`.

        fields: a list, or a string with comma separated names.
        Returns the CTE itself, with line breaking turned on.
        '''
        names = fields.split(',') if isinstance(fields, str) else fields
        count = len(names)
        queries = detect(
            pattern*count, join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
|
1640
|
+
|
1641
|
+
class Recursive(CTE):
    '''
    A CTE rendered with the `RECURSIVE ` prefix.
    When it holds more than one query, the last one also
    receives this CTE's own table and alias in its FROM entries.
    '''
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        if len(self.query_list) > 1:
            # Add the self-reference only once: appending it on every
            # call made each new str() repeat `, <table> <alias>` in the
            # FROM entries of the last query.
            self_reference = f', {self.table_name} {self.alias}'
            from_items = self.query_list[-1].values[FROM]
            if self_reference not in from_items:
                from_items.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        '''
        Builds a Recursive CTE made of two queries detected from
        the same `pattern` (it is repeated twice).

        name: table name of the new CTE.
        pattern: text given to `detect` to produce the two queries.
        formula: expression for the second query; each `[n]` refers
            to the n-th SELECT entry (1-based) of that query. The
            first `[n]` found defines the field the formula is added
            to and is replaced by `%`; the others are replaced by the
            field names they refer to.
        init_value: value the first SELECT entry of the first query
            is compared to, through `Where.eq`.
        format: forwarded to `detect`.
        '''
        # NOTE: this resets the alias function for every SQLObject and
        # does not restore it afterwards.
        SQLObject.ALIAS_FUNC = None
        def get_field(obj: SQLObject, pos: int) -> str:
            # Field name without anything that comes before the last dot
            return obj.values[SELECT][pos].split('.')[-1]
        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for num in re.findall(r'\[(\d+)\]', formula):
            num = int(num)
            if not foreign_key:
                foreign_key = get_field(t2, num-1)
                formula = formula.replace(f'[{num}]', '%')
            else:
                formula = formula.replace(f'[{num}]', get_field(t2, num-1))
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        '''
        Adds a counter field called `name` to every query of the
        CTE: the first query gets `start AS name`, the others get
        `(name<increment>) AS name`.
        Returns the CTE itself.
        '''
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
|
1678
|
+
|
1679
|
+
|
1680
|
+
# ----- Rules -----
|
1681
|
+
|
1254
1682
|
class RulePutLimit(Rule):
|
1255
1683
|
@classmethod
|
1256
1684
|
def apply(cls, target: Select):
|
@@ -1314,6 +1742,8 @@ class RuleDateFuncReplace(Rule):
|
|
1314
1742
|
@classmethod
|
1315
1743
|
def apply(cls, target: Select):
|
1316
1744
|
for i, condition in enumerate(target.values.get(WHERE, [])):
|
1745
|
+
if not '(' in condition:
|
1746
|
+
continue
|
1317
1747
|
tokens = [
|
1318
1748
|
t.strip() for t in cls.REGEX.split(condition) if t.strip()
|
1319
1749
|
]
|
@@ -1325,6 +1755,32 @@ class RuleDateFuncReplace(Rule):
|
|
1325
1755
|
target.values[WHERE][i] = ' AND '.join(temp.values[WHERE])
|
1326
1756
|
|
1327
1757
|
|
1758
|
+
class RuleReplaceJoinBySubselect(Rule):
    '''
    Re-parses the target query and, for each joined query that
    qualifies, turns it into a SubSelect before adding it back
    to the main query.
    '''
    @classmethod
    def apply(cls, target: Select):
        '''
        A joined query is kept as a regular join when any of
        these is true:
          * it has SELECT entries;
          * it has no WHERE entries;
          * no foreign key field was found for it;
          * its table name is the first item of some entry in
            `ForeignKey.references`.
        Otherwise its class is replaced by SubSelect and the
        primary key is added to it as a field.
        `target.values` is replaced only if at least one query
        was converted.
        '''
        parsed = Select.parse( str(target) )
        main, *others = parsed
        converted = 0
        for query in others:
            fk_field, primary_k = ForeignKey.find(main, query)
            referenced = [
                ref[0] == query.table_name for ref in ForeignKey.references
            ]
            selected = query.values.get(SELECT, [])
            conditions = query.values.get(WHERE, [])
            reasons_to_keep = [
                len(selected) > 0,
                len(conditions) == 0,
                not fk_field,
                any(referenced),
            ]
            if not any(reasons_to_keep):
                # Becomes a sub-query that returns only its primary key
                query.__class__ = SubSelect
                Field.add(primary_k, query)
                converted += 1
            query.add(fk_field, main)
        if converted:
            target.values = main.values.copy()
|
1782
|
+
|
1783
|
+
|
1328
1784
|
def parser_class(text: str) -> Parser:
|
1329
1785
|
PARSER_REGEX = [
|
1330
1786
|
(r'select.*from', SQLParser),
|
@@ -1339,7 +1795,7 @@ def parser_class(text: str) -> Parser:
|
|
1339
1795
|
return None
|
1340
1796
|
|
1341
1797
|
|
1342
|
-
def detect(text: str) -> Select:
|
1798
|
+
def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
|
1343
1799
|
from collections import Counter
|
1344
1800
|
parser = parser_class(text)
|
1345
1801
|
if not parser:
|
@@ -1350,21 +1806,19 @@ def detect(text: str) -> Select:
|
|
1350
1806
|
continue
|
1351
1807
|
pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
|
1352
1808
|
for begin, end in pos[::-1]:
|
1353
|
-
new_name = f'{table}_{count}' # See set_table (line
|
1809
|
+
new_name = f'{table}_{count}' # See set_table (line 55)
|
1354
1810
|
Select.EQUIVALENT_NAMES[new_name] = table
|
1355
1811
|
text = text[:begin] + new_name + '(' + text[end:]
|
1356
1812
|
count -= 1
|
1357
1813
|
query_list = Select.parse(text, parser)
|
1814
|
+
if format:
|
1815
|
+
for query in query_list:
|
1816
|
+
query.set_file_format(format)
|
1817
|
+
if not join_queries:
|
1818
|
+
return query_list
|
1358
1819
|
result = query_list[0]
|
1359
1820
|
for query in query_list[1:]:
|
1360
1821
|
result += query
|
1361
1822
|
return result
|
1823
|
+
# ===========================================================================================//
|
1362
1824
|
|
1363
|
-
if __name__ == "__main__":
|
1364
|
-
OrderBy.sort = SortType.DESC
|
1365
|
-
query = Select(
|
1366
|
-
'order_Detail d',
|
1367
|
-
customer_id=GroupBy,
|
1368
|
-
_=Sum('d.unitPrice * d.quantity').As('total', OrderBy)
|
1369
|
-
)
|
1370
|
-
print(query)
|