sql-blocks 1.25.113__py3-none-any.whl → 1.25.514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_blocks/sql_blocks.py +571 -116
- {sql_blocks-1.25.113.dist-info → sql_blocks-1.25.514.dist-info}/METADATA +270 -9
- sql_blocks-1.25.514.dist-info/RECORD +7 -0
- sql_blocks-1.25.113.dist-info/RECORD +0 -7
- {sql_blocks-1.25.113.dist-info → sql_blocks-1.25.514.dist-info}/LICENSE +0 -0
- {sql_blocks-1.25.113.dist-info → sql_blocks-1.25.514.dist-info}/WHEEL +0 -0
- {sql_blocks-1.25.113.dist-info → sql_blocks-1.25.514.dist-info}/top_level.txt +0 -0
sql_blocks/sql_blocks.py
CHANGED
@@ -42,23 +42,34 @@ class SQLObject:
|
|
42
42
|
if not table_name:
|
43
43
|
return
|
44
44
|
cls = SQLObject
|
45
|
+
is_file_name = any([
|
46
|
+
'/' in table_name, '.' in table_name
|
47
|
+
])
|
48
|
+
ref = table_name
|
49
|
+
if is_file_name:
|
50
|
+
ref = table_name.split('/')[-1].split('.')[0]
|
45
51
|
if cls.ALIAS_FUNC:
|
46
|
-
self.__alias = cls.ALIAS_FUNC(
|
52
|
+
self.__alias = cls.ALIAS_FUNC(ref)
|
47
53
|
elif ' ' in table_name.strip():
|
48
54
|
table_name, self.__alias = table_name.split()
|
49
|
-
elif '_' in
|
55
|
+
elif '_' in ref:
|
50
56
|
self.__alias = ''.join(
|
51
57
|
word[0].lower()
|
52
|
-
for word in
|
58
|
+
for word in ref.split('_')
|
53
59
|
)
|
54
60
|
else:
|
55
|
-
self.__alias =
|
61
|
+
self.__alias = ref.lower()[:3]
|
56
62
|
self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
|
57
63
|
|
58
64
|
@property
|
59
65
|
def table_name(self) -> str:
|
60
66
|
return self.values[FROM][0].split()[0]
|
61
67
|
|
68
|
+
def set_file_format(self, pattern: str):
|
69
|
+
if '{' not in pattern:
|
70
|
+
pattern = '{}' + pattern
|
71
|
+
self.values[FROM][0] = pattern.format(self.aka())
|
72
|
+
|
62
73
|
@property
|
63
74
|
def alias(self) -> str:
|
64
75
|
if self.__alias:
|
@@ -71,8 +82,14 @@ class SQLObject:
|
|
71
82
|
return KEYWORD[key][0].format(appendix.get(key, ''))
|
72
83
|
|
73
84
|
@staticmethod
|
74
|
-
def is_named_field(fld: str,
|
75
|
-
return
|
85
|
+
def is_named_field(fld: str, name: str='') -> bool:
|
86
|
+
return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
|
87
|
+
|
88
|
+
def has_named_field(self, name: str) -> bool:
|
89
|
+
return any(
|
90
|
+
self.is_named_field(fld, name)
|
91
|
+
for fld in self.values.get(SELECT, [])
|
92
|
+
)
|
76
93
|
|
77
94
|
def diff(self, key: str, search_list: list, exact: bool=False) -> set:
|
78
95
|
def disassemble(source: list) -> list:
|
@@ -82,14 +99,17 @@ class SQLObject:
|
|
82
99
|
for fld in source:
|
83
100
|
result += re.split(r'([=()]|<>|\s+ON\s+|\s+on\s+)', fld)
|
84
101
|
return result
|
85
|
-
def cleanup(
|
102
|
+
def cleanup(text: str) -> str:
|
103
|
+
text = re.sub(r'[\n\t]', ' ', text)
|
86
104
|
if exact:
|
87
|
-
|
88
|
-
return
|
105
|
+
text = text.lower()
|
106
|
+
return text.strip()
|
89
107
|
def field_set(source: list) -> set:
|
90
108
|
return set(
|
91
109
|
(
|
92
|
-
fld
|
110
|
+
fld
|
111
|
+
if key == SELECT and self.is_named_field(fld, key)
|
112
|
+
else
|
93
113
|
re.sub(pattern, '', cleanup(fld))
|
94
114
|
)
|
95
115
|
for string in disassemble(source)
|
@@ -121,7 +141,8 @@ class SQLObject:
|
|
121
141
|
|
122
142
|
SQL_CONST_SYSDATE = 'SYSDATE'
|
123
143
|
SQL_CONST_CURR_DATE = 'Current_date'
|
124
|
-
|
144
|
+
SQL_ROW_NUM = 'ROWNUM'
|
145
|
+
SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
|
125
146
|
|
126
147
|
|
127
148
|
class Field:
|
@@ -138,7 +159,7 @@ class Field:
|
|
138
159
|
name = name.strip()
|
139
160
|
if name in ('_', '*'):
|
140
161
|
name = '*'
|
141
|
-
elif not is_const():
|
162
|
+
elif not is_const() and not main.has_named_field(name):
|
142
163
|
name = f'{main.alias}.{name}'
|
143
164
|
if Function in cls.__bases__:
|
144
165
|
name = f'{cls.__name__}({name})'
|
@@ -176,15 +197,35 @@ class Dialect(Enum):
|
|
176
197
|
POSTGRESQL = 3
|
177
198
|
MYSQL = 4
|
178
199
|
|
200
|
+
SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
|
201
|
+
CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
|
202
|
+
|
179
203
|
class Function:
|
180
204
|
dialect = Dialect.ANSI
|
205
|
+
inputs = None
|
206
|
+
output = None
|
207
|
+
separator = ', '
|
208
|
+
auto_convert = True
|
209
|
+
append_param = False
|
181
210
|
|
182
211
|
def __init__(self, *params: list):
|
212
|
+
def set_func_types(param):
|
213
|
+
if self.auto_convert and isinstance(param, Function):
|
214
|
+
func = param
|
215
|
+
main_param = self.inputs[0]
|
216
|
+
unfriendly = all([
|
217
|
+
func.output != main_param,
|
218
|
+
func.output != ANY,
|
219
|
+
main_param != ANY
|
220
|
+
])
|
221
|
+
if unfriendly:
|
222
|
+
return Cast(func, main_param)
|
223
|
+
return param
|
183
224
|
# --- Replace class methods by instance methods: ------
|
184
225
|
self.add = self.__add
|
185
226
|
self.format = self.__format
|
186
227
|
# -----------------------------------------------------
|
187
|
-
self.params = [
|
228
|
+
self.params = [set_func_types(p) for p in params]
|
188
229
|
self.field_class = Field
|
189
230
|
self.pattern = self.get_pattern()
|
190
231
|
self.extra = {}
|
@@ -201,14 +242,35 @@ class Function:
|
|
201
242
|
def __str__(self) -> str:
|
202
243
|
return self.pattern.format(
|
203
244
|
func_name=self.__class__.__name__,
|
204
|
-
params=
|
245
|
+
params=self.separator.join(str(p) for p in self.params)
|
205
246
|
)
|
206
247
|
|
248
|
+
@classmethod
|
249
|
+
def help(cls) -> str:
|
250
|
+
descr = ' '.join(B.__name__ for B in cls.__bases__)
|
251
|
+
params = cls.inputs or ''
|
252
|
+
return cls().get_pattern().format(
|
253
|
+
func_name=f'{descr} {cls.__name__}',
|
254
|
+
params=cls.separator.join(str(p) for p in params)
|
255
|
+
) + f' Return {cls.output}'
|
256
|
+
|
257
|
+
def set_main_param(self, name: str, main: SQLObject) -> bool:
|
258
|
+
nested_functions = [
|
259
|
+
param for param in self.params if isinstance(param, Function)
|
260
|
+
]
|
261
|
+
for func in nested_functions:
|
262
|
+
if func.inputs:
|
263
|
+
func.set_main_param(name, main)
|
264
|
+
return
|
265
|
+
new_params = [Field.format(name, main)]
|
266
|
+
if self.append_param:
|
267
|
+
self.params += new_params
|
268
|
+
else:
|
269
|
+
self.params = new_params + self.params
|
270
|
+
|
207
271
|
def __format(self, name: str, main: SQLObject) -> str:
|
208
272
|
if name not in '*_':
|
209
|
-
self.
|
210
|
-
Field.format(name, main)
|
211
|
-
] + self.params
|
273
|
+
self.set_main_param(name, main)
|
212
274
|
return str(self)
|
213
275
|
|
214
276
|
@classmethod
|
@@ -228,6 +290,9 @@ class Function:
|
|
228
290
|
|
229
291
|
# ---- String Functions: ---------------------------------
|
230
292
|
class SubString(Function):
|
293
|
+
inputs = [CHAR, INT, INT]
|
294
|
+
output = CHAR
|
295
|
+
|
231
296
|
def get_pattern(self) -> str:
|
232
297
|
if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
|
233
298
|
return 'Substr({params})'
|
@@ -235,31 +300,55 @@ class SubString(Function):
|
|
235
300
|
|
236
301
|
# ---- Numeric Functions: --------------------------------
|
237
302
|
class Round(Function):
|
238
|
-
|
303
|
+
inputs = [FLOAT]
|
304
|
+
output = FLOAT
|
239
305
|
|
240
306
|
# --- Date Functions: ------------------------------------
|
241
307
|
class DateDiff(Function):
|
242
|
-
|
308
|
+
inputs = [DATE]
|
309
|
+
output = DATE
|
310
|
+
append_param = True
|
311
|
+
|
312
|
+
def __str__(self) -> str:
|
243
313
|
def is_field_or_func(name: str) -> bool:
|
244
|
-
|
314
|
+
candidate = re.sub(
|
315
|
+
'[()]', '', name.split('.')[-1]
|
316
|
+
)
|
317
|
+
return candidate.isidentifier()
|
245
318
|
if self.dialect != Dialect.SQL_SERVER:
|
319
|
+
params = [str(p) for p in self.params]
|
246
320
|
return ' - '.join(
|
247
321
|
p if is_field_or_func(p) else f"'{p}'"
|
248
|
-
for p in
|
322
|
+
for p in params
|
249
323
|
) # <==== Date subtract
|
250
|
-
return super().
|
324
|
+
return super().__str__()
|
325
|
+
|
326
|
+
|
327
|
+
class DatePart(Function):
|
328
|
+
inputs = [DATE]
|
329
|
+
output = INT
|
251
330
|
|
252
|
-
class Year(Function):
|
253
331
|
def get_pattern(self) -> str:
|
332
|
+
interval = self.__class__.__name__
|
254
333
|
database_type = {
|
255
|
-
Dialect.ORACLE: 'Extract(
|
256
|
-
Dialect.POSTGRESQL: "Date_Part('
|
334
|
+
Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
|
335
|
+
Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
|
257
336
|
}
|
258
337
|
if self.dialect in database_type:
|
259
338
|
return database_type[self.dialect]
|
260
339
|
return super().get_pattern()
|
261
340
|
|
341
|
+
class Year(DatePart):
|
342
|
+
...
|
343
|
+
class Month(DatePart):
|
344
|
+
...
|
345
|
+
class Day(DatePart):
|
346
|
+
...
|
347
|
+
|
348
|
+
|
262
349
|
class Current_Date(Function):
|
350
|
+
output = DATE
|
351
|
+
|
263
352
|
def get_pattern(self) -> str:
|
264
353
|
database_type = {
|
265
354
|
Dialect.ORACLE: SQL_CONST_SYSDATE,
|
@@ -282,14 +371,15 @@ class Frame:
|
|
282
371
|
keywords = ''
|
283
372
|
for field, obj in args.items():
|
284
373
|
is_valid = any([
|
285
|
-
obj is
|
286
|
-
|
374
|
+
obj is OrderBy,
|
375
|
+
obj is Partition,
|
376
|
+
isinstance(obj, Rows),
|
287
377
|
])
|
288
378
|
if not is_valid:
|
289
379
|
continue
|
290
380
|
keywords += '{}{} {}'.format(
|
291
381
|
'\n\t\t' if self.break_lines else ' ',
|
292
|
-
obj.cls_to_str(), field
|
382
|
+
obj.cls_to_str(), field if field != '_' else ''
|
293
383
|
)
|
294
384
|
if keywords and self.break_lines:
|
295
385
|
keywords += '\n\t'
|
@@ -298,7 +388,8 @@ class Frame:
|
|
298
388
|
|
299
389
|
|
300
390
|
class Aggregate(Frame):
|
301
|
-
|
391
|
+
inputs = [FLOAT]
|
392
|
+
output = FLOAT
|
302
393
|
|
303
394
|
class Window(Frame):
|
304
395
|
...
|
@@ -317,20 +408,30 @@ class Count(Aggregate, Function):
|
|
317
408
|
|
318
409
|
# ---- Window Functions: -----------------------------------
|
319
410
|
class Row_Number(Window, Function):
|
320
|
-
|
411
|
+
output = INT
|
412
|
+
|
321
413
|
class Rank(Window, Function):
|
322
|
-
|
414
|
+
output = INT
|
415
|
+
|
323
416
|
class Lag(Window, Function):
|
324
|
-
|
417
|
+
output = ANY
|
418
|
+
|
325
419
|
class Lead(Window, Function):
|
326
|
-
|
420
|
+
output = ANY
|
327
421
|
|
328
422
|
|
329
423
|
# ---- Conversions and other Functions: ---------------------
|
330
424
|
class Coalesce(Function):
|
331
|
-
|
425
|
+
inputs = [ANY]
|
426
|
+
output = ANY
|
427
|
+
|
332
428
|
class Cast(Function):
|
333
|
-
|
429
|
+
inputs = [ANY]
|
430
|
+
output = ANY
|
431
|
+
separator = ' As '
|
432
|
+
|
433
|
+
|
434
|
+
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
334
435
|
|
335
436
|
|
336
437
|
class ExpressionField:
|
@@ -355,15 +456,20 @@ class ExpressionField:
|
|
355
456
|
class FieldList:
|
356
457
|
separator = ','
|
357
458
|
|
358
|
-
def __init__(self, fields: list=[], class_types = [Field]):
|
459
|
+
def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
|
359
460
|
if isinstance(fields, str):
|
360
461
|
fields = [
|
361
462
|
f.strip() for f in fields.split(self.separator)
|
362
463
|
]
|
363
464
|
self.fields = fields
|
364
465
|
self.class_types = class_types
|
466
|
+
self.ziped = ziped
|
365
467
|
|
366
468
|
def add(self, name: str, main: SQLObject):
|
469
|
+
if self.ziped: # --- One class per field...
|
470
|
+
for field, class_type in zip(self.fields, self.class_types):
|
471
|
+
class_type.add(field, main)
|
472
|
+
return
|
367
473
|
for field in self.fields:
|
368
474
|
for class_type in self.class_types:
|
369
475
|
class_type.add(field, main)
|
@@ -405,36 +511,40 @@ class ForeignKey:
|
|
405
511
|
|
406
512
|
def quoted(value) -> str:
|
407
513
|
if isinstance(value, str):
|
514
|
+
if re.search(r'\bor\b', value, re.IGNORECASE):
|
515
|
+
raise PermissionError('Possible SQL injection attempt')
|
408
516
|
value = f"'{value}'"
|
409
517
|
return str(value)
|
410
518
|
|
411
519
|
|
412
520
|
class Position(Enum):
|
521
|
+
StartsWith = -1
|
413
522
|
Middle = 0
|
414
|
-
|
415
|
-
EndsWith = 2
|
523
|
+
EndsWith = 1
|
416
524
|
|
417
525
|
|
418
526
|
class Where:
|
419
527
|
prefix = ''
|
420
528
|
|
421
|
-
def __init__(self,
|
422
|
-
self.
|
529
|
+
def __init__(self, content: str):
|
530
|
+
self.content = content
|
423
531
|
|
424
532
|
@classmethod
|
425
533
|
def __constructor(cls, operator: str, value):
|
426
|
-
return cls(
|
534
|
+
return cls(f'{operator} {quoted(value)}')
|
427
535
|
|
428
536
|
@classmethod
|
429
537
|
def eq(cls, value):
|
430
538
|
return cls.__constructor('=', value)
|
431
539
|
|
432
540
|
@classmethod
|
433
|
-
def contains(cls,
|
541
|
+
def contains(cls, text: str, pos: int | Position = Position.Middle):
|
542
|
+
if isinstance(pos, int):
|
543
|
+
pos = Position(pos)
|
434
544
|
return cls(
|
435
545
|
"LIKE '{}{}{}'".format(
|
436
546
|
'%' if pos != Position.StartsWith else '',
|
437
|
-
|
547
|
+
text,
|
438
548
|
'%' if pos != Position.EndsWith else ''
|
439
549
|
)
|
440
550
|
)
|
@@ -465,9 +575,43 @@ class Where:
|
|
465
575
|
values = ','.join(quoted(v) for v in values)
|
466
576
|
return cls(f'IN ({values})')
|
467
577
|
|
578
|
+
@classmethod
|
579
|
+
def formula(cls, formula: str):
|
580
|
+
where = cls( ExpressionField(formula) )
|
581
|
+
where.add = where.add_expression
|
582
|
+
return where
|
583
|
+
|
584
|
+
def add_expression(self, name: str, main: SQLObject):
|
585
|
+
self.content = self.content.format(name, main)
|
586
|
+
main.values.setdefault(WHERE, []).append('{} {}'.format(
|
587
|
+
self.prefix, self.content
|
588
|
+
))
|
589
|
+
|
590
|
+
@classmethod
|
591
|
+
def join(cls, query: SQLObject):
|
592
|
+
where = cls(query)
|
593
|
+
where.add = where.add_join
|
594
|
+
return where
|
595
|
+
|
596
|
+
def add_join(self, name: str, main: SQLObject):
|
597
|
+
query = self.content
|
598
|
+
main.values[FROM].append(f',{query.table_name} {query.alias}')
|
599
|
+
for key in USUAL_KEYS:
|
600
|
+
main.update_values(key, query.values.get(key, []))
|
601
|
+
if query.key_field:
|
602
|
+
main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
|
603
|
+
a1=main.alias, f1=name,
|
604
|
+
a2=query.alias, f2=query.key_field
|
605
|
+
))
|
606
|
+
|
468
607
|
def add(self, name: str, main: SQLObject):
|
608
|
+
func_type = FUNCTION_CLASS.get(name.lower())
|
609
|
+
if func_type:
|
610
|
+
name = func_type.format('*', main)
|
611
|
+
elif not main.has_named_field(name):
|
612
|
+
name = Field.format(name, main)
|
469
613
|
main.values.setdefault(WHERE, []).append('{}{} {}'.format(
|
470
|
-
self.prefix,
|
614
|
+
self.prefix, name, self.content
|
471
615
|
))
|
472
616
|
|
473
617
|
|
@@ -475,6 +619,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
|
|
475
619
|
getattr(Where, method) for method in
|
476
620
|
('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
|
477
621
|
)
|
622
|
+
startswith, endswith = [
|
623
|
+
lambda x: contains(x, Position.StartsWith),
|
624
|
+
lambda x: contains(x, Position.EndsWith)
|
625
|
+
]
|
478
626
|
|
479
627
|
|
480
628
|
class Not(Where):
|
@@ -482,7 +630,7 @@ class Not(Where):
|
|
482
630
|
|
483
631
|
@classmethod
|
484
632
|
def eq(cls, value):
|
485
|
-
return Where(
|
633
|
+
return Where(f'<> {quoted(value)}')
|
486
634
|
|
487
635
|
|
488
636
|
class Case:
|
@@ -491,22 +639,26 @@ class Case:
|
|
491
639
|
self.default = None
|
492
640
|
self.field = field
|
493
641
|
|
494
|
-
def when(self, condition: Where, result
|
642
|
+
def when(self, condition: Where, result):
|
643
|
+
if isinstance(result, str):
|
644
|
+
result = quoted(result)
|
495
645
|
self.__conditions[result] = condition
|
496
646
|
return self
|
497
647
|
|
498
|
-
def else_value(self, default
|
648
|
+
def else_value(self, default):
|
649
|
+
if isinstance(default, str):
|
650
|
+
default = quoted(default)
|
499
651
|
self.default = default
|
500
652
|
return self
|
501
653
|
|
502
654
|
def add(self, name: str, main: SQLObject):
|
503
655
|
field = Field.format(self.field, main)
|
504
|
-
default =
|
656
|
+
default = self.default
|
505
657
|
name = 'CASE \n{}\n\tEND AS {}'.format(
|
506
658
|
'\n'.join(
|
507
|
-
f'\t\tWHEN {field} {cond.
|
659
|
+
f'\t\tWHEN {field} {cond.content} THEN {res}'
|
508
660
|
for res, cond in self.__conditions.items()
|
509
|
-
) + f'\n\t\tELSE {default}' if default else '',
|
661
|
+
) + (f'\n\t\tELSE {default}' if default else ''),
|
510
662
|
name
|
511
663
|
)
|
512
664
|
main.values.setdefault(SELECT, []).append(name)
|
@@ -517,42 +669,69 @@ class Options:
|
|
517
669
|
self.__children: dict = values
|
518
670
|
|
519
671
|
def add(self, logical_separator: str, main: SQLObject):
|
520
|
-
if logical_separator not in ('AND', 'OR'):
|
672
|
+
if logical_separator.upper() not in ('AND', 'OR'):
|
521
673
|
raise ValueError('`logical_separator` must be AND or OR')
|
522
|
-
|
674
|
+
temp = Select(f'{main.table_name} {main.alias}')
|
523
675
|
child: Where
|
524
676
|
for field, child in self.__children.items():
|
525
|
-
|
526
|
-
Field.format(field, main), child.expr
|
527
|
-
))
|
677
|
+
child.add(field, temp)
|
528
678
|
main.values.setdefault(WHERE, []).append(
|
529
|
-
'(' + logical_separator.join(
|
679
|
+
'(' + f'\n\t{logical_separator} '.join(temp.values[WHERE]) + ')'
|
530
680
|
)
|
531
681
|
|
532
682
|
|
533
683
|
class Between:
|
684
|
+
is_literal: bool = False
|
685
|
+
|
534
686
|
def __init__(self, start, end):
|
535
687
|
if start > end:
|
536
688
|
start, end = end, start
|
537
689
|
self.start = start
|
538
690
|
self.end = end
|
539
691
|
|
692
|
+
def literal(self) -> Where:
|
693
|
+
return Where('BETWEEN {} AND {}'.format(
|
694
|
+
self.start, self.end
|
695
|
+
))
|
696
|
+
|
540
697
|
def add(self, name: str, main:SQLObject):
|
541
|
-
|
698
|
+
if self.is_literal:
|
699
|
+
return self.literal().add(name, main)
|
700
|
+
Where.gte(self.start).add(name, main)
|
542
701
|
Where.lte(self.end).add(name, main)
|
543
702
|
|
703
|
+
class SameDay(Between):
|
704
|
+
def __init__(self, date: str):
|
705
|
+
super().__init__(
|
706
|
+
f'{date} 00:00:00',
|
707
|
+
f'{date} 23:59:59',
|
708
|
+
)
|
709
|
+
|
710
|
+
|
711
|
+
class Range(Case):
|
712
|
+
INC_FUNCTION = lambda x: x + 1
|
713
|
+
|
714
|
+
def __init__(self, field: str, values: dict):
|
715
|
+
super().__init__(field)
|
716
|
+
start = 0
|
717
|
+
cls = self.__class__
|
718
|
+
for label, value in sorted(values.items(), key=lambda item: item[1]):
|
719
|
+
self.when(
|
720
|
+
Between(start, value).literal(), label
|
721
|
+
)
|
722
|
+
start = cls.INC_FUNCTION(value)
|
723
|
+
|
544
724
|
|
545
725
|
class Clause:
|
546
726
|
@classmethod
|
547
727
|
def format(cls, name: str, main: SQLObject) -> str:
|
548
728
|
def is_function() -> bool:
|
549
729
|
diff = main.diff(SELECT, [name.lower()], True)
|
550
|
-
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
551
730
|
return diff.intersection(FUNCTION_CLASS)
|
552
731
|
found = re.findall(r'^_\d', name)
|
553
732
|
if found:
|
554
733
|
name = found[0].replace('_', '')
|
555
|
-
elif main.alias and not is_function():
|
734
|
+
elif '.' not in name and main.alias and not is_function():
|
556
735
|
name = f'{main.alias}.{name}'
|
557
736
|
return name
|
558
737
|
|
@@ -561,6 +740,34 @@ class SortType(Enum):
|
|
561
740
|
ASC = ''
|
562
741
|
DESC = ' DESC'
|
563
742
|
|
743
|
+
class Row:
|
744
|
+
def __init__(self, value: int=0):
|
745
|
+
self.value = value
|
746
|
+
|
747
|
+
def __str__(self) -> str:
|
748
|
+
return '{} {}'.format(
|
749
|
+
'UNBOUNDED' if self.value == 0 else self.value,
|
750
|
+
self.__class__.__name__.upper()
|
751
|
+
)
|
752
|
+
|
753
|
+
class Preceding(Row):
|
754
|
+
...
|
755
|
+
class Following(Row):
|
756
|
+
...
|
757
|
+
class Current(Row):
|
758
|
+
def __str__(self) -> str:
|
759
|
+
return 'CURRENT ROW'
|
760
|
+
|
761
|
+
class Rows:
|
762
|
+
def __init__(self, *rows: list[Row]):
|
763
|
+
self.rows = rows
|
764
|
+
|
765
|
+
def cls_to_str(self) -> str:
|
766
|
+
return 'ROWS {}{}'.format(
|
767
|
+
'BETWEEN ' if len(self.rows) > 1 else '',
|
768
|
+
' AND '.join(str(row) for row in self.rows)
|
769
|
+
)
|
770
|
+
|
564
771
|
|
565
772
|
class OrderBy(Clause):
|
566
773
|
sort: SortType = SortType.ASC
|
@@ -595,7 +802,7 @@ class Having:
|
|
595
802
|
|
596
803
|
def add(self, name: str, main:SQLObject):
|
597
804
|
main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
|
598
|
-
self.function.format(name, main), self.condition.
|
805
|
+
self.function.format(name, main), self.condition.content
|
599
806
|
)
|
600
807
|
|
601
808
|
@classmethod
|
@@ -625,12 +832,20 @@ class Rule:
|
|
625
832
|
...
|
626
833
|
|
627
834
|
class QueryLanguage:
|
628
|
-
pattern = '{select}{_from}{where}{group_by}{order_by}'
|
835
|
+
pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
|
629
836
|
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
630
837
|
|
631
838
|
@staticmethod
|
632
|
-
def remove_alias(
|
633
|
-
|
839
|
+
def remove_alias(text: str) -> str:
|
840
|
+
value, sep = '', ''
|
841
|
+
text = re.sub('[\n\t]', ' ', text)
|
842
|
+
if ':' in text:
|
843
|
+
text, value = text.split(':', maxsplit=1)
|
844
|
+
sep = ':'
|
845
|
+
return '{}{}{}'.format(
|
846
|
+
''.join(re.split(r'\w+[.]', text)),
|
847
|
+
sep, value.replace("'", '"')
|
848
|
+
)
|
634
849
|
|
635
850
|
def join_with_tabs(self, values: list, sep: str='') -> str:
|
636
851
|
sep = sep + self.TABULATION
|
@@ -648,18 +863,21 @@ class QueryLanguage:
|
|
648
863
|
return self.join_with_tabs(values, ' AND ')
|
649
864
|
|
650
865
|
def sort_by(self, values: list) -> str:
|
651
|
-
return self.join_with_tabs(values)
|
866
|
+
return self.join_with_tabs(values, ',')
|
652
867
|
|
653
868
|
def set_group(self, values: list) -> str:
|
654
869
|
return self.join_with_tabs(values, ',')
|
655
870
|
|
871
|
+
def set_limit(self, values: list) -> str:
|
872
|
+
return self.join_with_tabs(values, ' ')
|
873
|
+
|
656
874
|
def __init__(self, target: 'Select'):
|
657
|
-
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
|
875
|
+
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
|
658
876
|
self.TABULATION = '\n\t' if target.break_lines else ' '
|
659
877
|
self.LINE_BREAK = '\n' if target.break_lines else ' '
|
660
878
|
self.TOKEN_METHODS = {
|
661
879
|
SELECT: self.add_field, FROM: self.get_tables,
|
662
|
-
WHERE: self.extract_conditions,
|
880
|
+
WHERE: self.extract_conditions, LIMIT: self.set_limit,
|
663
881
|
ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
|
664
882
|
}
|
665
883
|
self.result = {}
|
@@ -695,7 +913,8 @@ class MongoDBLanguage(QueryLanguage):
|
|
695
913
|
LOGICAL_OP_TO_MONGO_FUNC = {
|
696
914
|
'>': '$gt', '>=': '$gte',
|
697
915
|
'<': '$lt', '<=': '$lte',
|
698
|
-
'=': '$eq', '<>': '$ne',
|
916
|
+
'=': '$eq', '<>': '$ne',
|
917
|
+
'like': '$regex', 'LIKE': '$regex',
|
699
918
|
}
|
700
919
|
OPERATORS = '|'.join(op for op in LOGICAL_OP_TO_MONGO_FUNC)
|
701
920
|
REGEX = {
|
@@ -748,7 +967,7 @@ class MongoDBLanguage(QueryLanguage):
|
|
748
967
|
field, *op, const = tokens
|
749
968
|
op = ''.join(op)
|
750
969
|
expr = '{begin}{op}:{const}{end}'.format(
|
751
|
-
begin='{', const=const, end='}',
|
970
|
+
begin='{', const=const.replace('%', '.*'), end='}',
|
752
971
|
op=cls.LOGICAL_OP_TO_MONGO_FUNC[op],
|
753
972
|
)
|
754
973
|
where_list.append(f'{field}:{expr}')
|
@@ -857,6 +1076,55 @@ class Neo4JLanguage(QueryLanguage):
|
|
857
1076
|
return ''
|
858
1077
|
|
859
1078
|
|
1079
|
+
class DatabricksLanguage(QueryLanguage):
|
1080
|
+
pattern = '{_from}{where}{group_by}{order_by}{select}{limit}'
|
1081
|
+
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
1082
|
+
|
1083
|
+
def __init__(self, target: 'Select'):
|
1084
|
+
super().__init__(target)
|
1085
|
+
self.aggregation_fields = []
|
1086
|
+
|
1087
|
+
def add_field(self, values: list) -> str:
|
1088
|
+
AGG_FUNCS = '|'.join(cls.__name__ for cls in Aggregate.__subclasses__())
|
1089
|
+
# --------------------------------------------------------------
|
1090
|
+
def is_agg_field(fld: str) -> bool:
|
1091
|
+
return re.findall(fr'({AGG_FUNCS})[(]', fld, re.IGNORECASE)
|
1092
|
+
# --------------------------------------------------------------
|
1093
|
+
new_values = []
|
1094
|
+
for val in values:
|
1095
|
+
if is_agg_field(val):
|
1096
|
+
self.aggregation_fields.append(val)
|
1097
|
+
else:
|
1098
|
+
new_values.append(val)
|
1099
|
+
values = new_values
|
1100
|
+
return super().add_field(values)
|
1101
|
+
|
1102
|
+
def prefix(self, key: str) -> str:
|
1103
|
+
def get_aggregate() -> str:
|
1104
|
+
return 'AGGREGATE {} '.format(
|
1105
|
+
','.join(self.aggregation_fields)
|
1106
|
+
)
|
1107
|
+
return '{}{}{}{}{}'.format(
|
1108
|
+
'|> ' if key != FROM else '',
|
1109
|
+
self.LINE_BREAK,
|
1110
|
+
get_aggregate() if key == GROUP_BY else '',
|
1111
|
+
key, self.TABULATION
|
1112
|
+
)
|
1113
|
+
|
1114
|
+
# def get_tables(self, values: list) -> str:
|
1115
|
+
# return self.join_with_tabs(values)
|
1116
|
+
|
1117
|
+
# def extract_conditions(self, values: list) -> str:
|
1118
|
+
# return self.join_with_tabs(values, ' AND ')
|
1119
|
+
|
1120
|
+
# def sort_by(self, values: list) -> str:
|
1121
|
+
# return self.join_with_tabs(values, ',')
|
1122
|
+
|
1123
|
+
def set_group(self, values: list) -> str:
|
1124
|
+
return self.join_with_tabs(values, ',')
|
1125
|
+
|
1126
|
+
|
1127
|
+
|
860
1128
|
class Parser:
|
861
1129
|
REGEX = {}
|
862
1130
|
|
@@ -963,8 +1231,11 @@ class SQLParser(Parser):
|
|
963
1231
|
if not key in values:
|
964
1232
|
continue
|
965
1233
|
separator = self.class_type.get_separator(key)
|
1234
|
+
cls = {
|
1235
|
+
ORDER_BY: OrderBy, GROUP_BY: GroupBy
|
1236
|
+
}.get(key, Field)
|
966
1237
|
obj.values[key] = [
|
967
|
-
|
1238
|
+
cls.format(fld, obj)
|
968
1239
|
for fld in re.split(separator, values[key])
|
969
1240
|
if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
|
970
1241
|
]
|
@@ -1026,7 +1297,11 @@ class CypherParser(Parser):
|
|
1026
1297
|
if token in self.TOKEN_METHODS:
|
1027
1298
|
return
|
1028
1299
|
class_list = [Field]
|
1029
|
-
if '
|
1300
|
+
if '*' in token:
|
1301
|
+
token = token.replace('*', '')
|
1302
|
+
self.queries[-1].key_field = token
|
1303
|
+
return
|
1304
|
+
elif '$' in token:
|
1030
1305
|
func_name, token = token.split('$')
|
1031
1306
|
if func_name == 'count':
|
1032
1307
|
if not token:
|
@@ -1035,8 +1310,13 @@ class CypherParser(Parser):
|
|
1035
1310
|
Count().As(token, extra_classes).add(pk_field, self.queries[-1])
|
1036
1311
|
return
|
1037
1312
|
else:
|
1038
|
-
|
1039
|
-
|
1313
|
+
class_type = FUNCTION_CLASS.get(func_name)
|
1314
|
+
if not class_type:
|
1315
|
+
raise ValueError(f'Unknown function `{func_name}`.')
|
1316
|
+
if ':' in token:
|
1317
|
+
token, field_alias = token.split(':')
|
1318
|
+
class_type = class_type().As(field_alias)
|
1319
|
+
class_list = [class_type]
|
1040
1320
|
class_list += extra_classes
|
1041
1321
|
FieldList(token, class_list).add('', self.queries[-1])
|
1042
1322
|
|
@@ -1051,10 +1331,13 @@ class CypherParser(Parser):
|
|
1051
1331
|
def add_foreign_key(self, token: str, pk_field: str=''):
|
1052
1332
|
curr, last = [self.queries[i] for i in (-1, -2)]
|
1053
1333
|
if not pk_field:
|
1054
|
-
if
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1334
|
+
if last.key_field:
|
1335
|
+
pk_field = last.key_field
|
1336
|
+
else:
|
1337
|
+
if not last.values.get(SELECT):
|
1338
|
+
raise IndexError(f'Primary Key not found for {last.table_name}.')
|
1339
|
+
pk_field = last.values[SELECT][-1].split('.')[-1]
|
1340
|
+
last.delete(pk_field, [SELECT], exact=True)
|
1058
1341
|
if '{}' in token:
|
1059
1342
|
foreign_fld = token.format(
|
1060
1343
|
last.table_name.lower()
|
@@ -1197,7 +1480,18 @@ class MongoParser(Parser):
|
|
1197
1480
|
|
1198
1481
|
def begin_conditions(self, value: str):
|
1199
1482
|
self.where_list = {}
|
1483
|
+
self.field_method = self.first_ORfield
|
1200
1484
|
return Where
|
1485
|
+
|
1486
|
+
def first_ORfield(self, text: str):
|
1487
|
+
if text.startswith('$'):
|
1488
|
+
return
|
1489
|
+
found = re.search(r'\w+[:]', text)
|
1490
|
+
if not found:
|
1491
|
+
return
|
1492
|
+
self.field_method = None
|
1493
|
+
p1, p2 = found.span()
|
1494
|
+
self.last_field = text[p1: p2-1]
|
1201
1495
|
|
1202
1496
|
def increment_brackets(self, value: str):
|
1203
1497
|
self.brackets[value] += 1
|
@@ -1206,6 +1500,7 @@ class MongoParser(Parser):
|
|
1206
1500
|
self.method = self.new_query
|
1207
1501
|
self.last_field = ''
|
1208
1502
|
self.where_list = None
|
1503
|
+
self.field_method = None
|
1209
1504
|
self.PARAM_BY_FUNCTION = {
|
1210
1505
|
'find': Where, 'aggregate': GroupBy, 'sort': OrderBy
|
1211
1506
|
}
|
@@ -1235,13 +1530,14 @@ class MongoParser(Parser):
|
|
1235
1530
|
self.close_brackets(
|
1236
1531
|
BRACKET_PAIR[token]
|
1237
1532
|
)
|
1533
|
+
elif self.field_method:
|
1534
|
+
self.field_method(token)
|
1238
1535
|
self.method = self.TOKEN_METHODS.get(token)
|
1239
1536
|
# ----------------------------
|
1240
1537
|
|
1241
1538
|
|
1242
1539
|
class Select(SQLObject):
|
1243
1540
|
join_type: JoinType = JoinType.INNER
|
1244
|
-
REGEX = {}
|
1245
1541
|
EQUIVALENT_NAMES = {}
|
1246
1542
|
|
1247
1543
|
def __init__(self, table_name: str='', **values):
|
@@ -1259,21 +1555,30 @@ class Select(SQLObject):
|
|
1259
1555
|
|
1260
1556
|
def add(self, name: str, main: SQLObject):
|
1261
1557
|
old_tables = main.values.get(FROM, [])
|
1262
|
-
|
1263
|
-
|
1558
|
+
if len(self.values[FROM]) > 1:
|
1559
|
+
old_tables += self.values[FROM][1:]
|
1560
|
+
new_tables = []
|
1561
|
+
row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
|
1264
1562
|
jt=self.join_type.value,
|
1265
1563
|
tb=self.aka(),
|
1266
1564
|
a1=main.alias, f1=name,
|
1267
1565
|
a2=self.alias, f2=self.key_field
|
1268
1566
|
)
|
1269
|
-
|
1270
|
-
|
1567
|
+
if row not in old_tables[1:]:
|
1568
|
+
new_tables.append(row)
|
1569
|
+
main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
|
1271
1570
|
for key in USUAL_KEYS:
|
1272
1571
|
main.update_values(key, self.values.get(key, []))
|
1273
1572
|
|
1274
|
-
def
|
1573
|
+
def copy(self) -> SQLObject:
|
1275
1574
|
from copy import deepcopy
|
1276
|
-
|
1575
|
+
return deepcopy(self)
|
1576
|
+
|
1577
|
+
def no_relation_error(self, other: SQLObject):
|
1578
|
+
raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
|
1579
|
+
|
1580
|
+
def __add__(self, other: SQLObject):
|
1581
|
+
query = self.copy()
|
1277
1582
|
if query.table_name.lower() == other.table_name.lower():
|
1278
1583
|
for key in USUAL_KEYS:
|
1279
1584
|
query.update_values(key, other.values.get(key, []))
|
@@ -1286,7 +1591,7 @@ class Select(SQLObject):
|
|
1286
1591
|
PrimaryKey.add(primary_key, query)
|
1287
1592
|
query.add(foreign_field, other)
|
1288
1593
|
return other
|
1289
|
-
|
1594
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1290
1595
|
elif primary_key:
|
1291
1596
|
PrimaryKey.add(primary_key, other)
|
1292
1597
|
other.add(foreign_field, query)
|
@@ -1306,12 +1611,39 @@ class Select(SQLObject):
|
|
1306
1611
|
if self.diff(key, other.values.get(key, []), True):
|
1307
1612
|
return False
|
1308
1613
|
return True
|
1614
|
+
|
1615
|
+
def __sub__(self, other: SQLObject) -> SQLObject:
|
1616
|
+
fk_field, primary_k = ForeignKey.find(self, other)
|
1617
|
+
if fk_field:
|
1618
|
+
query = self.copy()
|
1619
|
+
other = other.copy()
|
1620
|
+
else:
|
1621
|
+
fk_field, primary_k = ForeignKey.find(other, self)
|
1622
|
+
if not fk_field:
|
1623
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1624
|
+
query = other.copy()
|
1625
|
+
other = self.copy()
|
1626
|
+
query.__class__ = NotSelectIN
|
1627
|
+
Field.add(fk_field, query)
|
1628
|
+
query.add(primary_k, other)
|
1629
|
+
return other
|
1309
1630
|
|
1310
1631
|
def limit(self, row_count: int=100, offset: int=0):
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1632
|
+
if Function.dialect == Dialect.SQL_SERVER:
|
1633
|
+
fields = self.values.get(SELECT)
|
1634
|
+
if fields:
|
1635
|
+
fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
|
1636
|
+
else:
|
1637
|
+
self.values[SELECT] = [f'SELECT TOP({row_count}) *']
|
1638
|
+
return self
|
1639
|
+
if Function.dialect == Dialect.ORACLE:
|
1640
|
+
Where.gte(row_count).add(SQL_ROW_NUM, self)
|
1641
|
+
if offset > 0:
|
1642
|
+
Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
|
1643
|
+
return self
|
1644
|
+
self.values[LIMIT] = ['{}{}'.format(
|
1645
|
+
row_count, f' OFFSET {offset}' if offset > 0 else ''
|
1646
|
+
)]
|
1315
1647
|
return self
|
1316
1648
|
|
1317
1649
|
def match(self, field: str, key: str) -> bool:
|
@@ -1319,11 +1651,7 @@ class Select(SQLObject):
|
|
1319
1651
|
Recognizes if the field is from the current table
|
1320
1652
|
'''
|
1321
1653
|
if key in (ORDER_BY, GROUP_BY) and '.' not in field:
|
1322
|
-
return
|
1323
|
-
self.is_named_field(fld, SELECT)
|
1324
|
-
for fld in self.values[SELECT]
|
1325
|
-
if field in fld
|
1326
|
-
)
|
1654
|
+
return self.has_named_field(field)
|
1327
1655
|
return re.findall(f'\b*{self.alias}[.]', field) != []
|
1328
1656
|
|
1329
1657
|
@classmethod
|
@@ -1336,12 +1664,10 @@ class Select(SQLObject):
|
|
1336
1664
|
for rule in rules:
|
1337
1665
|
rule.apply(self)
|
1338
1666
|
|
1339
|
-
def add_fields(self, fields: list,
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
if group_by:
|
1344
|
-
class_types += [GroupBy]
|
1667
|
+
def add_fields(self, fields: list, class_types=None):
|
1668
|
+
if not class_types:
|
1669
|
+
class_types = []
|
1670
|
+
class_types += [Field]
|
1345
1671
|
FieldList(fields, class_types).add('', self)
|
1346
1672
|
|
1347
1673
|
def translate_to(self, language: QueryLanguage) -> str:
|
@@ -1361,6 +1687,95 @@ class NotSelectIN(SelectIN):
|
|
1361
1687
|
condition_class = Not
|
1362
1688
|
|
1363
1689
|
|
1690
|
+
class CTE(Select):
|
1691
|
+
prefix = ''
|
1692
|
+
|
1693
|
+
def __init__(self, table_name: str, query_list: list[Select]):
|
1694
|
+
super().__init__(table_name)
|
1695
|
+
for query in query_list:
|
1696
|
+
query.break_lines = False
|
1697
|
+
self.query_list = query_list
|
1698
|
+
self.break_lines = False
|
1699
|
+
|
1700
|
+
def __str__(self) -> str:
|
1701
|
+
size = 0
|
1702
|
+
for key in USUAL_KEYS:
|
1703
|
+
size += sum(len(v) for v in self.values.get(key, []) if '\n' not in v)
|
1704
|
+
if size > 70:
|
1705
|
+
self.break_lines = True
|
1706
|
+
# ---------------------------------------------------------
|
1707
|
+
def justify(query: Select) -> str:
|
1708
|
+
result, line = [], ''
|
1709
|
+
keywords = '|'.join(KEYWORD)
|
1710
|
+
for word in re.split(fr'({keywords}|AND|OR|,)', str(query)):
|
1711
|
+
if len(line) >= 50:
|
1712
|
+
result.append(line)
|
1713
|
+
line = ''
|
1714
|
+
line += word
|
1715
|
+
if line:
|
1716
|
+
result.append(line)
|
1717
|
+
return '\n '.join(result)
|
1718
|
+
# ---------------------------------------------------------
|
1719
|
+
return 'WITH {}{} AS (\n {}\n){}'.format(
|
1720
|
+
self.prefix, self.table_name,
|
1721
|
+
'\nUNION ALL\n '.join(
|
1722
|
+
justify(q) for q in self.query_list
|
1723
|
+
), super().__str__()
|
1724
|
+
)
|
1725
|
+
|
1726
|
+
def join(self, pattern: str, fields: list | str, format: str=''):
|
1727
|
+
if isinstance(fields, str):
|
1728
|
+
count = len( fields.split(',') )
|
1729
|
+
else:
|
1730
|
+
count = len(fields)
|
1731
|
+
queries = detect(
|
1732
|
+
pattern*count, join_queries=False, format=format
|
1733
|
+
)
|
1734
|
+
FieldList(fields, queries, ziped=True).add('', self)
|
1735
|
+
self.break_lines = True
|
1736
|
+
return self
|
1737
|
+
|
1738
|
+
class Recursive(CTE):
|
1739
|
+
prefix = 'RECURSIVE '
|
1740
|
+
|
1741
|
+
def __str__(self) -> str:
|
1742
|
+
if len(self.query_list) > 1:
|
1743
|
+
self.query_list[-1].values[FROM].append(
|
1744
|
+
f', {self.table_name} {self.alias}')
|
1745
|
+
return super().__str__()
|
1746
|
+
|
1747
|
+
@classmethod
|
1748
|
+
def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
|
1749
|
+
SQLObject.ALIAS_FUNC = None
|
1750
|
+
def get_field(obj: SQLObject, pos: int) -> str:
|
1751
|
+
return obj.values[SELECT][pos].split('.')[-1]
|
1752
|
+
t1, t2 = detect(
|
1753
|
+
pattern*2, join_queries=False, format=format
|
1754
|
+
)
|
1755
|
+
pk_field = get_field(t1, 0)
|
1756
|
+
foreign_key = ''
|
1757
|
+
for num in re.findall(r'\[(\d+)\]', formula):
|
1758
|
+
num = int(num)
|
1759
|
+
if not foreign_key:
|
1760
|
+
foreign_key = get_field(t2, num-1)
|
1761
|
+
formula = formula.replace(f'[{num}]', '%')
|
1762
|
+
else:
|
1763
|
+
formula = formula.replace(f'[{num}]', get_field(t2, num-1))
|
1764
|
+
Where.eq(init_value).add(pk_field, t1)
|
1765
|
+
Where.formula(formula).add(foreign_key or pk_field, t2)
|
1766
|
+
return cls(name, [t1, t2])
|
1767
|
+
|
1768
|
+
def counter(self, name: str, start, increment: str='+1'):
|
1769
|
+
for i, query in enumerate(self.query_list):
|
1770
|
+
if i == 0:
|
1771
|
+
Field.add(f'{start} AS {name}', query)
|
1772
|
+
else:
|
1773
|
+
Field.add(f'({name}{increment}) AS {name}', query)
|
1774
|
+
return self
|
1775
|
+
|
1776
|
+
|
1777
|
+
# ----- Rules -----
|
1778
|
+
|
1364
1779
|
class RulePutLimit(Rule):
|
1365
1780
|
@classmethod
|
1366
1781
|
def apply(cls, target: Select):
|
@@ -1424,6 +1839,8 @@ class RuleDateFuncReplace(Rule):
|
|
1424
1839
|
@classmethod
|
1425
1840
|
def apply(cls, target: Select):
|
1426
1841
|
for i, condition in enumerate(target.values.get(WHERE, [])):
|
1842
|
+
if not '(' in condition:
|
1843
|
+
continue
|
1427
1844
|
tokens = [
|
1428
1845
|
t.strip() for t in cls.REGEX.split(condition) if t.strip()
|
1429
1846
|
]
|
@@ -1475,7 +1892,7 @@ def parser_class(text: str) -> Parser:
|
|
1475
1892
|
return None
|
1476
1893
|
|
1477
1894
|
|
1478
|
-
def detect(text: str) -> Select:
|
1895
|
+
def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
|
1479
1896
|
from collections import Counter
|
1480
1897
|
parser = parser_class(text)
|
1481
1898
|
if not parser:
|
@@ -1486,27 +1903,65 @@ def detect(text: str) -> Select:
|
|
1486
1903
|
continue
|
1487
1904
|
pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
|
1488
1905
|
for begin, end in pos[::-1]:
|
1489
|
-
new_name = f'{table}_{count}' # See set_table (line
|
1906
|
+
new_name = f'{table}_{count}' # See set_table (line 55)
|
1490
1907
|
Select.EQUIVALENT_NAMES[new_name] = table
|
1491
1908
|
text = text[:begin] + new_name + '(' + text[end:]
|
1492
1909
|
count -= 1
|
1493
1910
|
query_list = Select.parse(text, parser)
|
1911
|
+
if format:
|
1912
|
+
for query in query_list:
|
1913
|
+
query.set_file_format(format)
|
1914
|
+
if not join_queries:
|
1915
|
+
return query_list
|
1494
1916
|
result = query_list[0]
|
1495
1917
|
for query in query_list[1:]:
|
1496
1918
|
result += query
|
1497
1919
|
return result
|
1498
|
-
|
1499
|
-
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1920
|
+
# ===========================================================================================//
|
1921
|
+
|
1922
|
+
|
1923
|
+
if __name__ == "__main__":
|
1924
|
+
# def identifica_suspeitos() -> Select:
|
1925
|
+
# """Mostra quais pessoas tem caracteríosticas iguais à descrição do suspeito"""
|
1926
|
+
# Select.join_type = JoinType.LEFT
|
1927
|
+
# return Select(
|
1928
|
+
# 'Suspeito s', id=Field,
|
1929
|
+
# _=Where.join(
|
1930
|
+
# Select('Pessoa p',
|
1931
|
+
# OR=Options(
|
1932
|
+
# pessoa=Where('= s.id'),
|
1933
|
+
# altura=Where.formula('ABS(% - s.{f}) < 0.5'),
|
1934
|
+
# peso=Where.formula('ABS(% - s.{f}) < 0.5'),
|
1935
|
+
# cabelo=Where.formula('% = s.{f}'),
|
1936
|
+
# olhos=Where.formula('% = s.{f}'),
|
1937
|
+
# sexo=Where.formula('% = s.{f}'),
|
1938
|
+
# ),
|
1939
|
+
# nome=Field
|
1940
|
+
# )
|
1941
|
+
# )
|
1942
|
+
# )
|
1943
|
+
# query = identifica_suspeitos()
|
1944
|
+
# print('='*50)
|
1945
|
+
# print(query)
|
1946
|
+
# print('-'*50)
|
1947
|
+
script = '''
|
1948
|
+
db.people.find({
|
1949
|
+
{
|
1950
|
+
$or: [
|
1951
|
+
status:{$eq:"B"},
|
1952
|
+
age:{$lt:50}
|
1953
|
+
]
|
1954
|
+
},
|
1955
|
+
age:{$gte:18}, status:{$eq:"A"}
|
1956
|
+
},{
|
1957
|
+
name: 1, user_id: 1
|
1958
|
+
}).sort({
|
1959
|
+
'''
|
1960
|
+
print('='*50)
|
1961
|
+
q1 = Select.parse(script, MongoParser)[0]
|
1962
|
+
print(q1)
|
1963
|
+
print('-'*50)
|
1964
|
+
q2 = q1.translate_to(MongoDBLanguage)
|
1965
|
+
print(q2)
|
1966
|
+
# print('-'*50)
|
1967
|
+
print('='*50)
|