sql-blocks 1.25.13__py3-none-any.whl → 1.25.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_blocks/sql_blocks.py +594 -111
- {sql_blocks-1.25.13.dist-info → sql_blocks-1.25.51.dist-info}/METADATA +324 -5
- sql_blocks-1.25.51.dist-info/RECORD +7 -0
- sql_blocks-1.25.13.dist-info/RECORD +0 -7
- {sql_blocks-1.25.13.dist-info → sql_blocks-1.25.51.dist-info}/LICENSE +0 -0
- {sql_blocks-1.25.13.dist-info → sql_blocks-1.25.51.dist-info}/WHEEL +0 -0
- {sql_blocks-1.25.13.dist-info → sql_blocks-1.25.51.dist-info}/top_level.txt +0 -0
sql_blocks/sql_blocks.py
CHANGED
@@ -26,7 +26,7 @@ TO_LIST = lambda x: x if isinstance(x, list) else [x]
|
|
26
26
|
|
27
27
|
|
28
28
|
class SQLObject:
|
29
|
-
ALIAS_FUNC =
|
29
|
+
ALIAS_FUNC = None
|
30
30
|
""" ^^^^^^^^^^^^^^^^^^^^^^^^
|
31
31
|
You can change the behavior by assigning
|
32
32
|
a user function to SQLObject.ALIAS_FUNC
|
@@ -41,21 +41,35 @@ class SQLObject:
|
|
41
41
|
def set_table(self, table_name: str):
|
42
42
|
if not table_name:
|
43
43
|
return
|
44
|
-
|
44
|
+
cls = SQLObject
|
45
|
+
is_file_name = any([
|
46
|
+
'/' in table_name, '.' in table_name
|
47
|
+
])
|
48
|
+
ref = table_name
|
49
|
+
if is_file_name:
|
50
|
+
ref = table_name.split('/')[-1].split('.')[0]
|
51
|
+
if cls.ALIAS_FUNC:
|
52
|
+
self.__alias = cls.ALIAS_FUNC(ref)
|
53
|
+
elif ' ' in table_name.strip():
|
45
54
|
table_name, self.__alias = table_name.split()
|
46
|
-
elif '_' in
|
55
|
+
elif '_' in ref:
|
47
56
|
self.__alias = ''.join(
|
48
57
|
word[0].lower()
|
49
|
-
for word in
|
58
|
+
for word in ref.split('_')
|
50
59
|
)
|
51
60
|
else:
|
52
|
-
self.__alias =
|
61
|
+
self.__alias = ref.lower()[:3]
|
53
62
|
self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
|
54
63
|
|
55
64
|
@property
|
56
65
|
def table_name(self) -> str:
|
57
66
|
return self.values[FROM][0].split()[0]
|
58
67
|
|
68
|
+
def set_file_format(self, pattern: str):
|
69
|
+
if '{' not in pattern:
|
70
|
+
pattern = '{}' + pattern
|
71
|
+
self.values[FROM][0] = pattern.format(self.aka())
|
72
|
+
|
59
73
|
@property
|
60
74
|
def alias(self) -> str:
|
61
75
|
if self.__alias:
|
@@ -67,6 +81,16 @@ class SQLObject:
|
|
67
81
|
appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
|
68
82
|
return KEYWORD[key][0].format(appendix.get(key, ''))
|
69
83
|
|
84
|
+
@staticmethod
|
85
|
+
def is_named_field(fld: str, name: str='') -> bool:
|
86
|
+
return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
|
87
|
+
|
88
|
+
def has_named_field(self, name: str) -> bool:
|
89
|
+
return any(
|
90
|
+
self.is_named_field(fld, name)
|
91
|
+
for fld in self.values.get(SELECT, [])
|
92
|
+
)
|
93
|
+
|
70
94
|
def diff(self, key: str, search_list: list, exact: bool=False) -> set:
|
71
95
|
def disassemble(source: list) -> list:
|
72
96
|
if not exact:
|
@@ -79,12 +103,12 @@ class SQLObject:
|
|
79
103
|
if exact:
|
80
104
|
fld = fld.lower()
|
81
105
|
return fld.strip()
|
82
|
-
def is_named_field(fld: str) -> bool:
|
83
|
-
return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
|
84
106
|
def field_set(source: list) -> set:
|
85
107
|
return set(
|
86
108
|
(
|
87
|
-
fld
|
109
|
+
fld
|
110
|
+
if key == SELECT and self.is_named_field(fld, key)
|
111
|
+
else
|
88
112
|
re.sub(pattern, '', cleanup(fld))
|
89
113
|
)
|
90
114
|
for string in disassemble(source)
|
@@ -102,13 +126,22 @@ class SQLObject:
|
|
102
126
|
return s1.symmetric_difference(s2)
|
103
127
|
return s1 - s2
|
104
128
|
|
105
|
-
def delete(self, search: str, keys: list=USUAL_KEYS):
|
129
|
+
def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
|
130
|
+
if exact:
|
131
|
+
not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
|
132
|
+
else:
|
133
|
+
not_match = lambda item: search not in item
|
106
134
|
for key in keys:
|
107
|
-
|
108
|
-
|
109
|
-
if
|
110
|
-
|
111
|
-
|
135
|
+
self.values[key] = [
|
136
|
+
item for item in self.values.get(key, [])
|
137
|
+
if not_match(item)
|
138
|
+
]
|
139
|
+
|
140
|
+
|
141
|
+
SQL_CONST_SYSDATE = 'SYSDATE'
|
142
|
+
SQL_CONST_CURR_DATE = 'Current_date'
|
143
|
+
SQL_ROW_NUM = 'ROWNUM'
|
144
|
+
SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
|
112
145
|
|
113
146
|
|
114
147
|
class Field:
|
@@ -116,10 +149,16 @@ class Field:
|
|
116
149
|
|
117
150
|
@classmethod
|
118
151
|
def format(cls, name: str, main: SQLObject) -> str:
|
152
|
+
def is_const() -> bool:
|
153
|
+
return any([
|
154
|
+
re.findall('[.()0-9]', name),
|
155
|
+
name in SQL_CONSTS,
|
156
|
+
re.findall(r'\w+\s*[+-]\s*\w+', name)
|
157
|
+
])
|
119
158
|
name = name.strip()
|
120
159
|
if name in ('_', '*'):
|
121
160
|
name = '*'
|
122
|
-
elif not
|
161
|
+
elif not is_const() and not main.has_named_field(name):
|
123
162
|
name = f'{main.alias}.{name}'
|
124
163
|
if Function in cls.__bases__:
|
125
164
|
name = f'{cls.__name__}({name})'
|
@@ -150,35 +189,89 @@ class NamedField:
|
|
150
189
|
)
|
151
190
|
|
152
191
|
|
192
|
+
class Dialect(Enum):
|
193
|
+
ANSI = 0
|
194
|
+
SQL_SERVER = 1
|
195
|
+
ORACLE = 2
|
196
|
+
POSTGRESQL = 3
|
197
|
+
MYSQL = 4
|
198
|
+
|
199
|
+
SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
|
200
|
+
CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
|
201
|
+
|
153
202
|
class Function:
|
203
|
+
dialect = Dialect.ANSI
|
204
|
+
inputs = None
|
205
|
+
output = None
|
206
|
+
separator = ', '
|
207
|
+
auto_convert = True
|
208
|
+
append_param = False
|
209
|
+
|
154
210
|
def __init__(self, *params: list):
|
211
|
+
def set_func_types(param):
|
212
|
+
if self.auto_convert and isinstance(param, Function):
|
213
|
+
func = param
|
214
|
+
main_param = self.inputs[0]
|
215
|
+
unfriendly = all([
|
216
|
+
func.output != main_param,
|
217
|
+
func.output != ANY,
|
218
|
+
main_param != ANY
|
219
|
+
])
|
220
|
+
if unfriendly:
|
221
|
+
return Cast(func, main_param)
|
222
|
+
return param
|
155
223
|
# --- Replace class methods by instance methods: ------
|
156
224
|
self.add = self.__add
|
157
225
|
self.format = self.__format
|
158
226
|
# -----------------------------------------------------
|
159
|
-
self.params = [
|
227
|
+
self.params = [set_func_types(p) for p in params]
|
160
228
|
self.field_class = Field
|
161
|
-
self.pattern =
|
229
|
+
self.pattern = self.get_pattern()
|
162
230
|
self.extra = {}
|
163
231
|
|
232
|
+
def get_pattern(self) -> str:
|
233
|
+
return '{func_name}({params})'
|
234
|
+
|
164
235
|
def As(self, field_alias: str, modifiers=None):
|
165
236
|
if modifiers:
|
166
237
|
self.extra[field_alias] = TO_LIST(modifiers)
|
167
238
|
self.field_class = NamedField(field_alias)
|
168
239
|
return self
|
169
240
|
|
170
|
-
def
|
171
|
-
if name in '*_' and self.params:
|
172
|
-
params = self.params
|
173
|
-
else:
|
174
|
-
params = [
|
175
|
-
Field.format(name, main)
|
176
|
-
] + self.params
|
241
|
+
def __str__(self) -> str:
|
177
242
|
return self.pattern.format(
|
178
|
-
self.__class__.__name__,
|
179
|
-
|
243
|
+
func_name=self.__class__.__name__,
|
244
|
+
params=self.separator.join(str(p) for p in self.params)
|
180
245
|
)
|
181
246
|
|
247
|
+
@classmethod
|
248
|
+
def help(cls) -> str:
|
249
|
+
descr = ' '.join(B.__name__ for B in cls.__bases__)
|
250
|
+
params = cls.inputs or ''
|
251
|
+
return cls().get_pattern().format(
|
252
|
+
func_name=f'{descr} {cls.__name__}',
|
253
|
+
params=cls.separator.join(str(p) for p in params)
|
254
|
+
) + f' Return {cls.output}'
|
255
|
+
|
256
|
+
def set_main_param(self, name: str, main: SQLObject) -> bool:
|
257
|
+
nested_functions = [
|
258
|
+
param for param in self.params if isinstance(param, Function)
|
259
|
+
]
|
260
|
+
for func in nested_functions:
|
261
|
+
if func.inputs:
|
262
|
+
func.set_main_param(name, main)
|
263
|
+
return
|
264
|
+
new_params = [Field.format(name, main)]
|
265
|
+
if self.append_param:
|
266
|
+
self.params += new_params
|
267
|
+
else:
|
268
|
+
self.params = new_params + self.params
|
269
|
+
|
270
|
+
def __format(self, name: str, main: SQLObject) -> str:
|
271
|
+
if name not in '*_':
|
272
|
+
self.set_main_param(name, main)
|
273
|
+
return str(self)
|
274
|
+
|
182
275
|
@classmethod
|
183
276
|
def format(cls, name: str, main: SQLObject):
|
184
277
|
return cls().__format(name, main)
|
@@ -196,39 +289,110 @@ class Function:
|
|
196
289
|
|
197
290
|
# ---- String Functions: ---------------------------------
|
198
291
|
class SubString(Function):
|
199
|
-
|
292
|
+
inputs = [CHAR, INT, INT]
|
293
|
+
output = CHAR
|
294
|
+
|
295
|
+
def get_pattern(self) -> str:
|
296
|
+
if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
|
297
|
+
return 'Substr({params})'
|
298
|
+
return super().get_pattern()
|
200
299
|
|
201
300
|
# ---- Numeric Functions: --------------------------------
|
202
301
|
class Round(Function):
|
203
|
-
|
302
|
+
inputs = [FLOAT]
|
303
|
+
output = FLOAT
|
204
304
|
|
205
305
|
# --- Date Functions: ------------------------------------
|
206
306
|
class DateDiff(Function):
|
307
|
+
inputs = [DATE]
|
308
|
+
output = DATE
|
309
|
+
append_param = True
|
310
|
+
|
311
|
+
def __str__(self) -> str:
|
312
|
+
def is_field_or_func(name: str) -> bool:
|
313
|
+
candidate = re.sub(
|
314
|
+
'[()]', '', name.split('.')[-1]
|
315
|
+
)
|
316
|
+
return candidate.isidentifier()
|
317
|
+
if self.dialect != Dialect.SQL_SERVER:
|
318
|
+
params = [str(p) for p in self.params]
|
319
|
+
return ' - '.join(
|
320
|
+
p if is_field_or_func(p) else f"'{p}'"
|
321
|
+
for p in params
|
322
|
+
) # <==== Date subtract
|
323
|
+
return super().__str__()
|
324
|
+
|
325
|
+
|
326
|
+
class DatePart(Function):
|
327
|
+
inputs = [DATE]
|
328
|
+
output = INT
|
329
|
+
|
330
|
+
def get_pattern(self) -> str:
|
331
|
+
interval = self.__class__.__name__
|
332
|
+
database_type = {
|
333
|
+
Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
|
334
|
+
Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
|
335
|
+
}
|
336
|
+
if self.dialect in database_type:
|
337
|
+
return database_type[self.dialect]
|
338
|
+
return super().get_pattern()
|
339
|
+
|
340
|
+
class Year(DatePart):
|
207
341
|
...
|
208
|
-
class
|
342
|
+
class Month(DatePart):
|
209
343
|
...
|
210
|
-
class DatePart
|
344
|
+
class Day(DatePart):
|
211
345
|
...
|
346
|
+
|
347
|
+
|
212
348
|
class Current_Date(Function):
|
213
|
-
|
349
|
+
output = DATE
|
350
|
+
|
351
|
+
def get_pattern(self) -> str:
|
352
|
+
database_type = {
|
353
|
+
Dialect.ORACLE: SQL_CONST_SYSDATE,
|
354
|
+
Dialect.POSTGRESQL: SQL_CONST_CURR_DATE,
|
355
|
+
Dialect.SQL_SERVER: 'getDate()'
|
356
|
+
}
|
357
|
+
if self.dialect in database_type:
|
358
|
+
return database_type[self.dialect]
|
359
|
+
return super().get_pattern()
|
360
|
+
# --------------------------------------------------------
|
214
361
|
|
215
|
-
class
|
362
|
+
class Frame:
|
216
363
|
break_lines: bool = True
|
217
364
|
|
218
365
|
def over(self, **args):
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
366
|
+
"""
|
367
|
+
How to use:
|
368
|
+
over(field1=OrderBy, field2=Partition)
|
369
|
+
"""
|
370
|
+
keywords = ''
|
371
|
+
for field, obj in args.items():
|
372
|
+
is_valid = any([
|
373
|
+
obj is OrderBy,
|
374
|
+
obj is Partition,
|
375
|
+
isinstance(obj, Rows),
|
376
|
+
])
|
377
|
+
if not is_valid:
|
378
|
+
continue
|
379
|
+
keywords += '{}{} {}'.format(
|
380
|
+
'\n\t\t' if self.break_lines else ' ',
|
381
|
+
obj.cls_to_str(), field if field != '_' else ''
|
382
|
+
)
|
226
383
|
if keywords and self.break_lines:
|
227
384
|
keywords += '\n\t'
|
228
|
-
self.pattern =
|
385
|
+
self.pattern = self.get_pattern() + f' OVER({keywords})'
|
229
386
|
return self
|
230
387
|
|
231
388
|
|
389
|
+
class Aggregate(Frame):
|
390
|
+
inputs = [FLOAT]
|
391
|
+
output = FLOAT
|
392
|
+
|
393
|
+
class Window(Frame):
|
394
|
+
...
|
395
|
+
|
232
396
|
# ---- Aggregate Functions: -------------------------------
|
233
397
|
class Avg(Aggregate, Function):
|
234
398
|
...
|
@@ -241,11 +405,32 @@ class Sum(Aggregate, Function):
|
|
241
405
|
class Count(Aggregate, Function):
|
242
406
|
...
|
243
407
|
|
408
|
+
# ---- Window Functions: -----------------------------------
|
409
|
+
class Row_Number(Window, Function):
|
410
|
+
output = INT
|
411
|
+
|
412
|
+
class Rank(Window, Function):
|
413
|
+
output = INT
|
414
|
+
|
415
|
+
class Lag(Window, Function):
|
416
|
+
output = ANY
|
417
|
+
|
418
|
+
class Lead(Window, Function):
|
419
|
+
output = ANY
|
420
|
+
|
421
|
+
|
244
422
|
# ---- Conversions and other Functions: ---------------------
|
245
423
|
class Coalesce(Function):
|
246
|
-
|
424
|
+
inputs = [ANY]
|
425
|
+
output = ANY
|
426
|
+
|
247
427
|
class Cast(Function):
|
248
|
-
|
428
|
+
inputs = [ANY]
|
429
|
+
output = ANY
|
430
|
+
separator = ' As '
|
431
|
+
|
432
|
+
|
433
|
+
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
249
434
|
|
250
435
|
|
251
436
|
class ExpressionField:
|
@@ -270,15 +455,20 @@ class ExpressionField:
|
|
270
455
|
class FieldList:
|
271
456
|
separator = ','
|
272
457
|
|
273
|
-
def __init__(self, fields: list=[], class_types = [Field]):
|
458
|
+
def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
|
274
459
|
if isinstance(fields, str):
|
275
460
|
fields = [
|
276
461
|
f.strip() for f in fields.split(self.separator)
|
277
462
|
]
|
278
463
|
self.fields = fields
|
279
464
|
self.class_types = class_types
|
465
|
+
self.ziped = ziped
|
280
466
|
|
281
467
|
def add(self, name: str, main: SQLObject):
|
468
|
+
if self.ziped: # --- One class per field...
|
469
|
+
for field, class_type in zip(self.fields, self.class_types):
|
470
|
+
class_type.add(field, main)
|
471
|
+
return
|
282
472
|
for field in self.fields:
|
283
473
|
for class_type in self.class_types:
|
284
474
|
class_type.add(field, main)
|
@@ -320,27 +510,43 @@ class ForeignKey:
|
|
320
510
|
|
321
511
|
def quoted(value) -> str:
|
322
512
|
if isinstance(value, str):
|
513
|
+
if re.search(r'\bor\b', value, re.IGNORECASE):
|
514
|
+
raise PermissionError('Possible SQL injection attempt')
|
323
515
|
value = f"'{value}'"
|
324
516
|
return str(value)
|
325
517
|
|
326
518
|
|
519
|
+
class Position(Enum):
|
520
|
+
StartsWith = -1
|
521
|
+
Middle = 0
|
522
|
+
EndsWith = 1
|
523
|
+
|
524
|
+
|
327
525
|
class Where:
|
328
526
|
prefix = ''
|
329
527
|
|
330
|
-
def __init__(self,
|
331
|
-
self.
|
528
|
+
def __init__(self, content: str):
|
529
|
+
self.content = content
|
332
530
|
|
333
531
|
@classmethod
|
334
532
|
def __constructor(cls, operator: str, value):
|
335
|
-
return cls(
|
533
|
+
return cls(f'{operator} {quoted(value)}')
|
336
534
|
|
337
535
|
@classmethod
|
338
536
|
def eq(cls, value):
|
339
537
|
return cls.__constructor('=', value)
|
340
538
|
|
341
539
|
@classmethod
|
342
|
-
def contains(cls,
|
343
|
-
|
540
|
+
def contains(cls, text: str, pos: int | Position = Position.Middle):
|
541
|
+
if isinstance(pos, int):
|
542
|
+
pos = Position(pos)
|
543
|
+
return cls(
|
544
|
+
"LIKE '{}{}{}'".format(
|
545
|
+
'%' if pos != Position.StartsWith else '',
|
546
|
+
text,
|
547
|
+
'%' if pos != Position.EndsWith else ''
|
548
|
+
)
|
549
|
+
)
|
344
550
|
|
345
551
|
@classmethod
|
346
552
|
def gt(cls, value):
|
@@ -368,9 +574,42 @@ class Where:
|
|
368
574
|
values = ','.join(quoted(v) for v in values)
|
369
575
|
return cls(f'IN ({values})')
|
370
576
|
|
577
|
+
@classmethod
|
578
|
+
def formula(cls, formula: str):
|
579
|
+
where = cls( ExpressionField(formula) )
|
580
|
+
where.add = where.add_expression
|
581
|
+
return where
|
582
|
+
|
583
|
+
def add_expression(self, name: str, main: SQLObject):
|
584
|
+
self.content = self.content.format(name, main)
|
585
|
+
main.values.setdefault(WHERE, []).append('{} {}'.format(
|
586
|
+
self.prefix, self.content
|
587
|
+
))
|
588
|
+
|
589
|
+
@classmethod
|
590
|
+
def join(cls, query: SQLObject):
|
591
|
+
where = cls(query)
|
592
|
+
where.add = where.add_join
|
593
|
+
return where
|
594
|
+
|
595
|
+
def add_join(self, name: str, main: SQLObject):
|
596
|
+
query = self.content
|
597
|
+
main.values[FROM].append(f',{query.table_name} {query.alias}')
|
598
|
+
for key in USUAL_KEYS:
|
599
|
+
main.update_values(key, query.values.get(key, []))
|
600
|
+
main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
|
601
|
+
a1=main.alias, f1=name,
|
602
|
+
a2=query.alias, f2=query.key_field
|
603
|
+
))
|
604
|
+
|
371
605
|
def add(self, name: str, main: SQLObject):
|
606
|
+
func_type = FUNCTION_CLASS.get(name.lower())
|
607
|
+
if func_type:
|
608
|
+
name = func_type.format('*', main)
|
609
|
+
elif not main.has_named_field(name):
|
610
|
+
name = Field.format(name, main)
|
372
611
|
main.values.setdefault(WHERE, []).append('{}{} {}'.format(
|
373
|
-
self.prefix,
|
612
|
+
self.prefix, name, self.content
|
374
613
|
))
|
375
614
|
|
376
615
|
|
@@ -378,6 +617,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
|
|
378
617
|
getattr(Where, method) for method in
|
379
618
|
('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
|
380
619
|
)
|
620
|
+
startswith, endswith = [
|
621
|
+
lambda x: contains(x, Position.StartsWith),
|
622
|
+
lambda x: contains(x, Position.EndsWith)
|
623
|
+
]
|
381
624
|
|
382
625
|
|
383
626
|
class Not(Where):
|
@@ -385,7 +628,7 @@ class Not(Where):
|
|
385
628
|
|
386
629
|
@classmethod
|
387
630
|
def eq(cls, value):
|
388
|
-
return Where(
|
631
|
+
return Where(f'<> {quoted(value)}')
|
389
632
|
|
390
633
|
|
391
634
|
class Case:
|
@@ -394,22 +637,26 @@ class Case:
|
|
394
637
|
self.default = None
|
395
638
|
self.field = field
|
396
639
|
|
397
|
-
def when(self, condition: Where, result
|
640
|
+
def when(self, condition: Where, result):
|
641
|
+
if isinstance(result, str):
|
642
|
+
result = quoted(result)
|
398
643
|
self.__conditions[result] = condition
|
399
644
|
return self
|
400
645
|
|
401
|
-
def else_value(self, default
|
646
|
+
def else_value(self, default):
|
647
|
+
if isinstance(default, str):
|
648
|
+
default = quoted(default)
|
402
649
|
self.default = default
|
403
650
|
return self
|
404
651
|
|
405
652
|
def add(self, name: str, main: SQLObject):
|
406
653
|
field = Field.format(self.field, main)
|
407
|
-
default =
|
654
|
+
default = self.default
|
408
655
|
name = 'CASE \n{}\n\tEND AS {}'.format(
|
409
656
|
'\n'.join(
|
410
|
-
f'\t\tWHEN {field} {cond.
|
657
|
+
f'\t\tWHEN {field} {cond.content} THEN {res}'
|
411
658
|
for res, cond in self.__conditions.items()
|
412
|
-
) + f'\n\t\tELSE {default}' if default else '',
|
659
|
+
) + (f'\n\t\tELSE {default}' if default else ''),
|
413
660
|
name
|
414
661
|
)
|
415
662
|
main.values.setdefault(SELECT, []).append(name)
|
@@ -420,14 +667,13 @@ class Options:
|
|
420
667
|
self.__children: dict = values
|
421
668
|
|
422
669
|
def add(self, logical_separator: str, main: SQLObject):
|
423
|
-
|
424
|
-
|
425
|
-
"""
|
670
|
+
if logical_separator not in ('AND', 'OR'):
|
671
|
+
raise ValueError('`logical_separator` must be AND or OR')
|
426
672
|
conditions: list[str] = []
|
427
673
|
child: Where
|
428
674
|
for field, child in self.__children.items():
|
429
675
|
conditions.append(' {} {} '.format(
|
430
|
-
Field.format(field, main), child.
|
676
|
+
Field.format(field, main), child.content
|
431
677
|
))
|
432
678
|
main.values.setdefault(WHERE, []).append(
|
433
679
|
'(' + logical_separator.join(conditions) + ')'
|
@@ -435,28 +681,57 @@ class Options:
|
|
435
681
|
|
436
682
|
|
437
683
|
class Between:
|
684
|
+
is_literal: bool = False
|
685
|
+
|
438
686
|
def __init__(self, start, end):
|
439
687
|
if start > end:
|
440
688
|
start, end = end, start
|
441
689
|
self.start = start
|
442
690
|
self.end = end
|
443
691
|
|
692
|
+
def literal(self) -> Where:
|
693
|
+
return Where('BETWEEN {} AND {}'.format(
|
694
|
+
self.start, self.end
|
695
|
+
))
|
696
|
+
|
444
697
|
def add(self, name: str, main:SQLObject):
|
445
|
-
|
698
|
+
if self.is_literal:
|
699
|
+
return self.literal().add(name, main)
|
700
|
+
Where.gte(self.start).add(name, main)
|
446
701
|
Where.lte(self.end).add(name, main)
|
447
702
|
|
703
|
+
class SameDay(Between):
|
704
|
+
def __init__(self, date: str):
|
705
|
+
super().__init__(
|
706
|
+
f'{date} 00:00:00',
|
707
|
+
f'{date} 23:59:59',
|
708
|
+
)
|
709
|
+
|
710
|
+
|
711
|
+
class Range(Case):
|
712
|
+
INC_FUNCTION = lambda x: x + 1
|
713
|
+
|
714
|
+
def __init__(self, field: str, values: dict):
|
715
|
+
super().__init__(field)
|
716
|
+
start = 0
|
717
|
+
cls = self.__class__
|
718
|
+
for label, value in sorted(values.items(), key=lambda item: item[1]):
|
719
|
+
self.when(
|
720
|
+
Between(start, value).literal(), label
|
721
|
+
)
|
722
|
+
start = cls.INC_FUNCTION(value)
|
723
|
+
|
448
724
|
|
449
725
|
class Clause:
|
450
726
|
@classmethod
|
451
727
|
def format(cls, name: str, main: SQLObject) -> str:
|
452
728
|
def is_function() -> bool:
|
453
729
|
diff = main.diff(SELECT, [name.lower()], True)
|
454
|
-
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
455
730
|
return diff.intersection(FUNCTION_CLASS)
|
456
731
|
found = re.findall(r'^_\d', name)
|
457
732
|
if found:
|
458
733
|
name = found[0].replace('_', '')
|
459
|
-
elif main.alias and not is_function():
|
734
|
+
elif '.' not in name and main.alias and not is_function():
|
460
735
|
name = f'{main.alias}.{name}'
|
461
736
|
return name
|
462
737
|
|
@@ -465,6 +740,34 @@ class SortType(Enum):
|
|
465
740
|
ASC = ''
|
466
741
|
DESC = ' DESC'
|
467
742
|
|
743
|
+
class Row:
|
744
|
+
def __init__(self, value: int=0):
|
745
|
+
self.value = value
|
746
|
+
|
747
|
+
def __str__(self) -> str:
|
748
|
+
return '{} {}'.format(
|
749
|
+
'UNBOUNDED' if self.value == 0 else self.value,
|
750
|
+
self.__class__.__name__.upper()
|
751
|
+
)
|
752
|
+
|
753
|
+
class Preceding(Row):
|
754
|
+
...
|
755
|
+
class Following(Row):
|
756
|
+
...
|
757
|
+
class Current(Row):
|
758
|
+
def __str__(self) -> str:
|
759
|
+
return 'CURRENT ROW'
|
760
|
+
|
761
|
+
class Rows:
|
762
|
+
def __init__(self, *rows: list[Row]):
|
763
|
+
self.rows = rows
|
764
|
+
|
765
|
+
def cls_to_str(self) -> str:
|
766
|
+
return 'ROWS {}{}'.format(
|
767
|
+
'BETWEEN ' if len(self.rows) > 1 else '',
|
768
|
+
' AND '.join(str(row) for row in self.rows)
|
769
|
+
)
|
770
|
+
|
468
771
|
|
469
772
|
class OrderBy(Clause):
|
470
773
|
sort: SortType = SortType.ASC
|
@@ -474,6 +777,16 @@ class OrderBy(Clause):
|
|
474
777
|
name = cls.format(name, main)
|
475
778
|
main.values.setdefault(ORDER_BY, []).append(name+cls.sort.value)
|
476
779
|
|
780
|
+
@classmethod
|
781
|
+
def cls_to_str(cls) -> str:
|
782
|
+
return ORDER_BY
|
783
|
+
|
784
|
+
PARTITION_BY = 'PARTITION BY'
|
785
|
+
class Partition:
|
786
|
+
@classmethod
|
787
|
+
def cls_to_str(cls) -> str:
|
788
|
+
return PARTITION_BY
|
789
|
+
|
477
790
|
|
478
791
|
class GroupBy(Clause):
|
479
792
|
@classmethod
|
@@ -489,7 +802,7 @@ class Having:
|
|
489
802
|
|
490
803
|
def add(self, name: str, main:SQLObject):
|
491
804
|
main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
|
492
|
-
self.function.format(name, main), self.condition.
|
805
|
+
self.function.format(name, main), self.condition.content
|
493
806
|
)
|
494
807
|
|
495
808
|
@classmethod
|
@@ -519,7 +832,7 @@ class Rule:
|
|
519
832
|
...
|
520
833
|
|
521
834
|
class QueryLanguage:
|
522
|
-
pattern = '{select}{_from}{where}{group_by}{order_by}'
|
835
|
+
pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
|
523
836
|
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
524
837
|
|
525
838
|
@staticmethod
|
@@ -542,18 +855,21 @@ class QueryLanguage:
|
|
542
855
|
return self.join_with_tabs(values, ' AND ')
|
543
856
|
|
544
857
|
def sort_by(self, values: list) -> str:
|
545
|
-
return self.join_with_tabs(values)
|
858
|
+
return self.join_with_tabs(values, ',')
|
546
859
|
|
547
860
|
def set_group(self, values: list) -> str:
|
548
861
|
return self.join_with_tabs(values, ',')
|
549
862
|
|
863
|
+
def set_limit(self, values: list) -> str:
|
864
|
+
return self.join_with_tabs(values, ' ')
|
865
|
+
|
550
866
|
def __init__(self, target: 'Select'):
|
551
|
-
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
|
867
|
+
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
|
552
868
|
self.TABULATION = '\n\t' if target.break_lines else ' '
|
553
869
|
self.LINE_BREAK = '\n' if target.break_lines else ' '
|
554
870
|
self.TOKEN_METHODS = {
|
555
871
|
SELECT: self.add_field, FROM: self.get_tables,
|
556
|
-
WHERE: self.extract_conditions,
|
872
|
+
WHERE: self.extract_conditions, LIMIT: self.set_limit,
|
557
873
|
ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
|
558
874
|
}
|
559
875
|
self.result = {}
|
@@ -857,10 +1173,13 @@ class SQLParser(Parser):
|
|
857
1173
|
if not key in values:
|
858
1174
|
continue
|
859
1175
|
separator = self.class_type.get_separator(key)
|
1176
|
+
cls = {
|
1177
|
+
ORDER_BY: OrderBy, GROUP_BY: GroupBy
|
1178
|
+
}.get(key, Field)
|
860
1179
|
obj.values[key] = [
|
861
|
-
|
1180
|
+
cls.format(fld, obj)
|
862
1181
|
for fld in re.split(separator, values[key])
|
863
|
-
if (fld != '*' and len(tables) == 1) or obj.match(fld)
|
1182
|
+
if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
|
864
1183
|
]
|
865
1184
|
result[obj.alias] = obj
|
866
1185
|
self.queries = list( result.values() )
|
@@ -920,16 +1239,26 @@ class CypherParser(Parser):
|
|
920
1239
|
if token in self.TOKEN_METHODS:
|
921
1240
|
return
|
922
1241
|
class_list = [Field]
|
923
|
-
if '
|
1242
|
+
if '*' in token:
|
1243
|
+
token = token.replace('*', '')
|
1244
|
+
self.queries[-1].key_field = token
|
1245
|
+
return
|
1246
|
+
elif '$' in token:
|
924
1247
|
func_name, token = token.split('$')
|
925
1248
|
if func_name == 'count':
|
926
1249
|
if not token:
|
927
1250
|
token = 'count_1'
|
928
|
-
|
929
|
-
|
1251
|
+
pk_field = self.queries[-1].key_field or 'id'
|
1252
|
+
Count().As(token, extra_classes).add(pk_field, self.queries[-1])
|
1253
|
+
return
|
930
1254
|
else:
|
931
|
-
|
932
|
-
|
1255
|
+
class_type = FUNCTION_CLASS.get(func_name)
|
1256
|
+
if not class_type:
|
1257
|
+
raise ValueError(f'Unknown function `{func_name}`.')
|
1258
|
+
if ':' in token:
|
1259
|
+
token, field_alias = token.split(':')
|
1260
|
+
class_type = class_type().As(field_alias)
|
1261
|
+
class_list = [class_type]
|
933
1262
|
class_list += extra_classes
|
934
1263
|
FieldList(token, class_list).add('', self.queries[-1])
|
935
1264
|
|
@@ -944,10 +1273,13 @@ class CypherParser(Parser):
|
|
944
1273
|
def add_foreign_key(self, token: str, pk_field: str=''):
|
945
1274
|
curr, last = [self.queries[i] for i in (-1, -2)]
|
946
1275
|
if not pk_field:
|
947
|
-
if
|
948
|
-
|
949
|
-
|
950
|
-
|
1276
|
+
if last.key_field:
|
1277
|
+
pk_field = last.key_field
|
1278
|
+
else:
|
1279
|
+
if not last.values.get(SELECT):
|
1280
|
+
raise IndexError(f'Primary Key not found for {last.table_name}.')
|
1281
|
+
pk_field = last.values[SELECT][-1].split('.')[-1]
|
1282
|
+
last.delete(pk_field, [SELECT], exact=True)
|
951
1283
|
if '{}' in token:
|
952
1284
|
foreign_fld = token.format(
|
953
1285
|
last.table_name.lower()
|
@@ -962,12 +1294,11 @@ class CypherParser(Parser):
|
|
962
1294
|
if fld not in curr.values.get(GROUP_BY, [])
|
963
1295
|
]
|
964
1296
|
foreign_fld = fields[0].split('.')[-1]
|
965
|
-
curr.delete(foreign_fld, [SELECT])
|
1297
|
+
curr.delete(foreign_fld, [SELECT], exact=True)
|
966
1298
|
if curr.join_type == JoinType.RIGHT:
|
967
1299
|
pk_field, foreign_fld = foreign_fld, pk_field
|
968
1300
|
if curr.join_type == JoinType.RIGHT:
|
969
1301
|
curr, last = last, curr
|
970
|
-
# pk_field, foreign_fld = foreign_fld, pk_field
|
971
1302
|
k = ForeignKey.get_key(curr, last)
|
972
1303
|
ForeignKey.references[k] = (foreign_fld, pk_field)
|
973
1304
|
|
@@ -1135,7 +1466,6 @@ class MongoParser(Parser):
|
|
1135
1466
|
|
1136
1467
|
class Select(SQLObject):
|
1137
1468
|
join_type: JoinType = JoinType.INNER
|
1138
|
-
REGEX = {}
|
1139
1469
|
EQUIVALENT_NAMES = {}
|
1140
1470
|
|
1141
1471
|
def __init__(self, table_name: str='', **values):
|
@@ -1153,21 +1483,30 @@ class Select(SQLObject):
|
|
1153
1483
|
|
1154
1484
|
def add(self, name: str, main: SQLObject):
|
1155
1485
|
old_tables = main.values.get(FROM, [])
|
1156
|
-
|
1157
|
-
|
1486
|
+
if len(self.values[FROM]) > 1:
|
1487
|
+
old_tables += self.values[FROM][1:]
|
1488
|
+
new_tables = []
|
1489
|
+
row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
|
1158
1490
|
jt=self.join_type.value,
|
1159
1491
|
tb=self.aka(),
|
1160
1492
|
a1=main.alias, f1=name,
|
1161
1493
|
a2=self.alias, f2=self.key_field
|
1162
1494
|
)
|
1163
|
-
|
1164
|
-
|
1495
|
+
if row not in old_tables[1:]:
|
1496
|
+
new_tables.append(row)
|
1497
|
+
main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
|
1165
1498
|
for key in USUAL_KEYS:
|
1166
1499
|
main.update_values(key, self.values.get(key, []))
|
1167
1500
|
|
1168
|
-
def
|
1501
|
+
def copy(self) -> SQLObject:
|
1169
1502
|
from copy import deepcopy
|
1170
|
-
|
1503
|
+
return deepcopy(self)
|
1504
|
+
|
1505
|
+
def no_relation_error(self, other: SQLObject):
|
1506
|
+
raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
|
1507
|
+
|
1508
|
+
def __add__(self, other: SQLObject):
|
1509
|
+
query = self.copy()
|
1171
1510
|
if query.table_name.lower() == other.table_name.lower():
|
1172
1511
|
for key in USUAL_KEYS:
|
1173
1512
|
query.update_values(key, other.values.get(key, []))
|
@@ -1180,7 +1519,7 @@ class Select(SQLObject):
|
|
1180
1519
|
PrimaryKey.add(primary_key, query)
|
1181
1520
|
query.add(foreign_field, other)
|
1182
1521
|
return other
|
1183
|
-
|
1522
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1184
1523
|
elif primary_key:
|
1185
1524
|
PrimaryKey.add(primary_key, other)
|
1186
1525
|
other.add(foreign_field, query)
|
@@ -1200,16 +1539,48 @@ class Select(SQLObject):
|
|
1200
1539
|
if self.diff(key, other.values.get(key, []), True):
|
1201
1540
|
return False
|
1202
1541
|
return True
|
1542
|
+
|
1543
|
+
def __sub__(self, other: SQLObject) -> SQLObject:
|
1544
|
+
fk_field, primary_k = ForeignKey.find(self, other)
|
1545
|
+
if fk_field:
|
1546
|
+
query = self.copy()
|
1547
|
+
other = other.copy()
|
1548
|
+
else:
|
1549
|
+
fk_field, primary_k = ForeignKey.find(other, self)
|
1550
|
+
if not fk_field:
|
1551
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1552
|
+
query = other.copy()
|
1553
|
+
other = self.copy()
|
1554
|
+
query.__class__ = NotSelectIN
|
1555
|
+
Field.add(fk_field, query)
|
1556
|
+
query.add(primary_k, other)
|
1557
|
+
return other
|
1203
1558
|
|
1204
1559
|
def limit(self, row_count: int=100, offset: int=0):
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1560
|
+
if Function.dialect == Dialect.SQL_SERVER:
|
1561
|
+
fields = self.values.get(SELECT)
|
1562
|
+
if fields:
|
1563
|
+
fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
|
1564
|
+
else:
|
1565
|
+
self.values[SELECT] = [f'SELECT TOP({row_count}) *']
|
1566
|
+
return self
|
1567
|
+
if Function.dialect == Dialect.ORACLE:
|
1568
|
+
Where.gte(row_count).add(SQL_ROW_NUM, self)
|
1569
|
+
if offset > 0:
|
1570
|
+
Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
|
1571
|
+
return self
|
1572
|
+
self.values[LIMIT] = ['{}{}'.format(
|
1573
|
+
row_count, f' OFFSET {offset}' if offset > 0 else ''
|
1574
|
+
)]
|
1209
1575
|
return self
|
1210
1576
|
|
1211
|
-
def match(self, field: str, key: str) -> bool:
    '''
    Recognizes if the field is from the current table.

    ORDER BY / GROUP BY entries without a table qualifier are resolved
    through `has_named_field`.  Any other field belongs to this table when
    it contains ``<alias>.`` starting at a word boundary, so alias "o"
    matches "o.name" and "SUM(o.total)" but not "co.name".
    '''
    if key in (ORDER_BY, GROUP_BY) and '.' not in field:
        return self.has_named_field(field)
    # Raw string: in a plain string literal '\b' is a backspace character,
    # not the regex word boundary.  re.escape protects aliases that come
    # from a user-supplied ALIAS_FUNC.
    pattern = rf'\b{re.escape(self.alias)}[.]'
    return re.search(pattern, field) is not None
|
1213
1584
|
|
1214
1585
|
@classmethod
|
1215
1586
|
def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
|
@@ -1221,12 +1592,10 @@ class Select(SQLObject):
|
|
1221
1592
|
for rule in rules:
|
1222
1593
|
rule.apply(self)
|
1223
1594
|
|
1224
|
-
def add_fields(self, fields: list, class_types=None):
    '''
    Add `fields` to the query through FieldList.

    `class_types` is an optional list (or tuple) of field classes; `Field`
    is always appended.  A new list is built, so the caller's sequence is
    left untouched and repeated calls do not accumulate `Field` entries.
    '''
    class_types = [*(class_types or []), Field]
    FieldList(fields, class_types).add('', self)
|
1231
1600
|
|
1232
1601
|
def translate_to(self, language: QueryLanguage) -> str:
|
@@ -1246,6 +1615,95 @@ class NotSelectIN(SelectIN):
|
|
1246
1615
|
condition_class = Not
|
1247
1616
|
|
1248
1617
|
|
1618
|
+
class CTE(Select):
    '''
    Common Table Expression, rendered as
    ``WITH <prefix><table_name> AS (<inner queries>)<outer query>``.

    `query_list` holds the inner queries, joined with ``UNION ALL``;
    the CTE itself (a Select on `table_name`) renders the outer query.
    '''
    prefix = ''  # inserted between "WITH " and the name (Recursive uses 'RECURSIVE ')

    def __init__(self, table_name: str, query_list: list[Select]):
        super().__init__(table_name)
        # Inner queries are laid out by `justify` in __str__, so their own
        # line breaking is disabled (this mutates the given Select objects).
        for query in query_list:
            query.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        # Enable line breaks for the outer query once the single-line
        # entries of its usual clauses exceed 70 characters in total.
        size = 0
        for key in USUAL_KEYS:
            size += sum(len(v) for v in self.values.get(key, []) if '\n' not in v)
        if size > 70:
            self.break_lines = True
        # ---------------------------------------------------------
        # Split points: SQL keywords, AND, OR and commas.  The word
        # alternatives are anchored with \b so identifiers that merely
        # contain them (BRAND, COLOR, ORDERS...) are never cut in two
        # and wrapped onto separate lines.
        split_regex = re.compile(
            r'(\b(?:{}|AND|OR)\b|,)'.format('|'.join(KEYWORD))
        )

        def justify(query: Select) -> str:
            # Greedy wrap: start a new line once the current one has
            # reached 50 characters.
            result, line = [], ''
            for word in split_regex.split(str(query)):
                if len(line) >= 50:
                    result.append(line)
                    line = ''
                line += word
            if line:
                result.append(line)
            return '\n '.join(result)
        # ---------------------------------------------------------
        return 'WITH {}{} AS (\n {}\n){}'.format(
            self.prefix, self.table_name,
            '\nUNION ALL\n '.join(
                justify(q) for q in self.query_list
            ), super().__str__()
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        '''
        Add one entry per item of `fields`, each tied to a query detected
        from `pattern`.

        `fields` may be a list or a comma-separated string.  `pattern` is
        repeated once per field and passed to `detect` with
        join_queries=False.  The resulting queries and `fields` go to
        FieldList(..., ziped=True), which presumably pairs the n-th field
        with the n-th query (confirm in FieldList).
        Line breaks are enabled.  Returns self.
        '''
        if isinstance(fields, str):
            count = len( fields.split(',') )
        else:
            count = len(fields)
        queries = detect(
            pattern*count, join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
|
1665
|
+
|
1666
|
+
class Recursive(CTE):
    '''
    ``WITH RECURSIVE`` variant of CTE.

    With two or more inner queries, the last one (the recursive step) also
    reads from the CTE itself: ``, <table_name> <alias>`` is added to its
    FROM entries when the CTE is rendered.
    '''
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        if len(self.query_list) > 1:
            # __str__ may run many times.  Add the self-reference only once,
            # otherwise every rendering would repeat the table in FROM.
            from_list = self.query_list[-1].values[FROM]
            self_ref = f', {self.table_name} {self.alias}'
            if self_ref not in from_list:
                from_list.append(self_ref)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        '''
        Build a two-query recursive CTE named `name` from `pattern`.

        `pattern` is detected twice (join_queries=False), giving the anchor
        query t1 and the recursive query t2:
          * t1 gets the condition Where.eq(init_value) on <pk>, the first
            field selected by t1;
          * in `formula`, ``[n]`` refers to the n-th (1-based) field selected
            by t2.  The first placeholder found names the field the
            Where.formula condition is attached to and is replaced by '%'
            (presumably Where.formula substitutes the field there -- confirm).
            Later placeholders are replaced by their field names.  Without
            placeholders, the condition is attached to <pk>.
        SQLObject.ALIAS_FUNC is disabled while the queries and the CTE are
        built, then restored to the caller's value.
        '''
        saved_alias_func = SQLObject.ALIAS_FUNC
        SQLObject.ALIAS_FUNC = None
        try:
            def get_field(obj: SQLObject, pos: int) -> str:
                # Selected field name without its table/alias qualifier.
                return obj.values[SELECT][pos].split('.')[-1]
            t1, t2 = detect(
                pattern*2, join_queries=False, format=format
            )
            pk_field = get_field(t1, 0)
            foreign_key = ''
            for num in re.findall(r'\[(\d+)\]', formula):
                num = int(num)
                if not foreign_key:
                    foreign_key = get_field(t2, num-1)
                    formula = formula.replace(f'[{num}]', '%')
                else:
                    formula = formula.replace(f'[{num}]', get_field(t2, num-1))
            Where.eq(init_value).add(pk_field, t1)
            Where.formula(formula).add(foreign_key or pk_field, t2)
            return cls(name, [t1, t2])
        finally:
            # Do not leave the user's alias function permanently cleared.
            SQLObject.ALIAS_FUNC = saved_alias_func

    def counter(self, name: str, start, increment: str='+1'):
        '''
        Add a computed column `name` to the inner queries.

        The first query selects ``<start> AS <name>``; every later one
        selects ``(<name><increment>) AS <name>``.  Returns self.
        '''
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
|
1703
|
+
|
1704
|
+
|
1705
|
+
# ----- Rules -----
|
1706
|
+
|
1249
1707
|
class RulePutLimit(Rule):
|
1250
1708
|
@classmethod
|
1251
1709
|
def apply(cls, target: Select):
|
@@ -1309,6 +1767,8 @@ class RuleDateFuncReplace(Rule):
|
|
1309
1767
|
@classmethod
|
1310
1768
|
def apply(cls, target: Select):
|
1311
1769
|
for i, condition in enumerate(target.values.get(WHERE, [])):
|
1770
|
+
if not '(' in condition:
|
1771
|
+
continue
|
1312
1772
|
tokens = [
|
1313
1773
|
t.strip() for t in cls.REGEX.split(condition) if t.strip()
|
1314
1774
|
]
|
@@ -1320,6 +1780,32 @@ class RuleDateFuncReplace(Rule):
|
|
1320
1780
|
target.values[WHERE][i] = ' AND '.join(temp.values[WHERE])
|
1321
1781
|
|
1322
1782
|
|
1783
|
+
class RuleReplaceJoinBySubselect(Rule):
    '''
    Replace a joined table by a sub-select when the join only filters.

    The target is re-parsed into its main query and the joined queries.
    A joined query is converted (its class becomes SubSelect and it selects
    the primary key) only when all of these hold:
      * it selects no fields;
      * it has WHERE conditions;
      * ForeignKey.find(main, query) yields a foreign key;
      * no entry `ref` of ForeignKey.references has ref[0] equal to its
        table name.
    Every joined query is then added to the main query through the
    foreign key.  The target's values are replaced by the main query's
    values only when at least one query was converted.
    '''
    @classmethod
    def apply(cls, target: Select):
        main, *joined = Select.parse(str(target))
        changed = False
        for other in joined:
            fk_field, primary_k = ForeignKey.find(main, other)
            referenced_elsewhere = any([
                ref[0] == other.table_name for ref in ForeignKey.references
            ])
            selects_fields = len(other.values.get(SELECT, [])) > 0
            has_no_filter = len(other.values.get(WHERE, [])) == 0
            keep_join = any([
                selects_fields, has_no_filter,
                not fk_field, referenced_elsewhere,
            ])
            if not keep_join:
                other.__class__ = SubSelect
                Field.add(primary_k, other)
                changed = True
            other.add(fk_field, main)
        if changed:
            target.values = main.values.copy()
|
1807
|
+
|
1808
|
+
|
1323
1809
|
def parser_class(text: str) -> Parser:
|
1324
1810
|
PARSER_REGEX = [
|
1325
1811
|
(r'select.*from', SQLParser),
|
@@ -1334,7 +1820,7 @@ def parser_class(text: str) -> Parser:
|
|
1334
1820
|
return None
|
1335
1821
|
|
1336
1822
|
|
1337
|
-
def detect(text: str) -> Select:
|
1823
|
+
def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
|
1338
1824
|
from collections import Counter
|
1339
1825
|
parser = parser_class(text)
|
1340
1826
|
if not parser:
|
@@ -1345,21 +1831,18 @@ def detect(text: str) -> Select:
|
|
1345
1831
|
continue
|
1346
1832
|
pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
|
1347
1833
|
for begin, end in pos[::-1]:
|
1348
|
-
new_name = f'{table}_{count}' # See set_table (line
|
1834
|
+
new_name = f'{table}_{count}' # See set_table (line 55)
|
1349
1835
|
Select.EQUIVALENT_NAMES[new_name] = table
|
1350
1836
|
text = text[:begin] + new_name + '(' + text[end:]
|
1351
1837
|
count -= 1
|
1352
1838
|
query_list = Select.parse(text, parser)
|
1839
|
+
if format:
|
1840
|
+
for query in query_list:
|
1841
|
+
query.set_file_format(format)
|
1842
|
+
if not join_queries:
|
1843
|
+
return query_list
|
1353
1844
|
result = query_list[0]
|
1354
1845
|
for query in query_list[1:]:
|
1355
1846
|
result += query
|
1356
1847
|
return result
|
1357
|
-
|
1358
|
-
if __name__ == "__main__":
|
1359
|
-
OrderBy.sort = SortType.DESC
|
1360
|
-
query = Select(
|
1361
|
-
'order_Detail d',
|
1362
|
-
customer_id=GroupBy,
|
1363
|
-
_=Sum('d.unitPrice * d.quantity').As('total', OrderBy)
|
1364
|
-
)
|
1365
|
-
print(query)
|
1848
|
+
# ===========================================================================================//
|