sql-blocks 1.25.112__py3-none-any.whl → 1.25.514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_blocks/sql_blocks.py +596 -113
- {sql_blocks-1.25.112.dist-info → sql_blocks-1.25.514.dist-info}/METADATA +270 -9
- sql_blocks-1.25.514.dist-info/RECORD +7 -0
- sql_blocks-1.25.112.dist-info/RECORD +0 -7
- {sql_blocks-1.25.112.dist-info → sql_blocks-1.25.514.dist-info}/LICENSE +0 -0
- {sql_blocks-1.25.112.dist-info → sql_blocks-1.25.514.dist-info}/WHEEL +0 -0
- {sql_blocks-1.25.112.dist-info → sql_blocks-1.25.514.dist-info}/top_level.txt +0 -0
sql_blocks/sql_blocks.py
CHANGED
@@ -42,23 +42,34 @@ class SQLObject:
|
|
42
42
|
if not table_name:
|
43
43
|
return
|
44
44
|
cls = SQLObject
|
45
|
+
is_file_name = any([
|
46
|
+
'/' in table_name, '.' in table_name
|
47
|
+
])
|
48
|
+
ref = table_name
|
49
|
+
if is_file_name:
|
50
|
+
ref = table_name.split('/')[-1].split('.')[0]
|
45
51
|
if cls.ALIAS_FUNC:
|
46
|
-
self.__alias = cls.ALIAS_FUNC(
|
52
|
+
self.__alias = cls.ALIAS_FUNC(ref)
|
47
53
|
elif ' ' in table_name.strip():
|
48
54
|
table_name, self.__alias = table_name.split()
|
49
|
-
elif '_' in
|
55
|
+
elif '_' in ref:
|
50
56
|
self.__alias = ''.join(
|
51
57
|
word[0].lower()
|
52
|
-
for word in
|
58
|
+
for word in ref.split('_')
|
53
59
|
)
|
54
60
|
else:
|
55
|
-
self.__alias =
|
61
|
+
self.__alias = ref.lower()[:3]
|
56
62
|
self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
|
57
63
|
|
58
64
|
@property
|
59
65
|
def table_name(self) -> str:
|
60
66
|
return self.values[FROM][0].split()[0]
|
61
67
|
|
68
|
+
def set_file_format(self, pattern: str):
|
69
|
+
if '{' not in pattern:
|
70
|
+
pattern = '{}' + pattern
|
71
|
+
self.values[FROM][0] = pattern.format(self.aka())
|
72
|
+
|
62
73
|
@property
|
63
74
|
def alias(self) -> str:
|
64
75
|
if self.__alias:
|
@@ -70,6 +81,16 @@ class SQLObject:
|
|
70
81
|
appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
|
71
82
|
return KEYWORD[key][0].format(appendix.get(key, ''))
|
72
83
|
|
84
|
+
@staticmethod
|
85
|
+
def is_named_field(fld: str, name: str='') -> bool:
|
86
|
+
return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
|
87
|
+
|
88
|
+
def has_named_field(self, name: str) -> bool:
|
89
|
+
return any(
|
90
|
+
self.is_named_field(fld, name)
|
91
|
+
for fld in self.values.get(SELECT, [])
|
92
|
+
)
|
93
|
+
|
73
94
|
def diff(self, key: str, search_list: list, exact: bool=False) -> set:
|
74
95
|
def disassemble(source: list) -> list:
|
75
96
|
if not exact:
|
@@ -78,16 +99,17 @@ class SQLObject:
|
|
78
99
|
for fld in source:
|
79
100
|
result += re.split(r'([=()]|<>|\s+ON\s+|\s+on\s+)', fld)
|
80
101
|
return result
|
81
|
-
def cleanup(
|
102
|
+
def cleanup(text: str) -> str:
|
103
|
+
text = re.sub(r'[\n\t]', ' ', text)
|
82
104
|
if exact:
|
83
|
-
|
84
|
-
return
|
85
|
-
def is_named_field(fld: str) -> bool:
|
86
|
-
return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
|
105
|
+
text = text.lower()
|
106
|
+
return text.strip()
|
87
107
|
def field_set(source: list) -> set:
|
88
108
|
return set(
|
89
109
|
(
|
90
|
-
fld
|
110
|
+
fld
|
111
|
+
if key == SELECT and self.is_named_field(fld, key)
|
112
|
+
else
|
91
113
|
re.sub(pattern, '', cleanup(fld))
|
92
114
|
)
|
93
115
|
for string in disassemble(source)
|
@@ -105,18 +127,22 @@ class SQLObject:
|
|
105
127
|
return s1.symmetric_difference(s2)
|
106
128
|
return s1 - s2
|
107
129
|
|
108
|
-
def delete(self, search: str, keys: list=USUAL_KEYS):
|
130
|
+
def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
|
131
|
+
if exact:
|
132
|
+
not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
|
133
|
+
else:
|
134
|
+
not_match = lambda item: search not in item
|
109
135
|
for key in keys:
|
110
|
-
|
111
|
-
|
112
|
-
if
|
113
|
-
|
114
|
-
self.values[key] = result
|
136
|
+
self.values[key] = [
|
137
|
+
item for item in self.values.get(key, [])
|
138
|
+
if not_match(item)
|
139
|
+
]
|
115
140
|
|
116
141
|
|
117
142
|
SQL_CONST_SYSDATE = 'SYSDATE'
|
118
143
|
SQL_CONST_CURR_DATE = 'Current_date'
|
119
|
-
|
144
|
+
SQL_ROW_NUM = 'ROWNUM'
|
145
|
+
SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
|
120
146
|
|
121
147
|
|
122
148
|
class Field:
|
@@ -133,7 +159,7 @@ class Field:
|
|
133
159
|
name = name.strip()
|
134
160
|
if name in ('_', '*'):
|
135
161
|
name = '*'
|
136
|
-
elif not is_const():
|
162
|
+
elif not is_const() and not main.has_named_field(name):
|
137
163
|
name = f'{main.alias}.{name}'
|
138
164
|
if Function in cls.__bases__:
|
139
165
|
name = f'{cls.__name__}({name})'
|
@@ -171,15 +197,35 @@ class Dialect(Enum):
|
|
171
197
|
POSTGRESQL = 3
|
172
198
|
MYSQL = 4
|
173
199
|
|
200
|
+
SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
|
201
|
+
CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
|
202
|
+
|
174
203
|
class Function:
|
175
204
|
dialect = Dialect.ANSI
|
205
|
+
inputs = None
|
206
|
+
output = None
|
207
|
+
separator = ', '
|
208
|
+
auto_convert = True
|
209
|
+
append_param = False
|
176
210
|
|
177
211
|
def __init__(self, *params: list):
|
212
|
+
def set_func_types(param):
|
213
|
+
if self.auto_convert and isinstance(param, Function):
|
214
|
+
func = param
|
215
|
+
main_param = self.inputs[0]
|
216
|
+
unfriendly = all([
|
217
|
+
func.output != main_param,
|
218
|
+
func.output != ANY,
|
219
|
+
main_param != ANY
|
220
|
+
])
|
221
|
+
if unfriendly:
|
222
|
+
return Cast(func, main_param)
|
223
|
+
return param
|
178
224
|
# --- Replace class methods by instance methods: ------
|
179
225
|
self.add = self.__add
|
180
226
|
self.format = self.__format
|
181
227
|
# -----------------------------------------------------
|
182
|
-
self.params = [
|
228
|
+
self.params = [set_func_types(p) for p in params]
|
183
229
|
self.field_class = Field
|
184
230
|
self.pattern = self.get_pattern()
|
185
231
|
self.extra = {}
|
@@ -196,14 +242,35 @@ class Function:
|
|
196
242
|
def __str__(self) -> str:
|
197
243
|
return self.pattern.format(
|
198
244
|
func_name=self.__class__.__name__,
|
199
|
-
params=
|
245
|
+
params=self.separator.join(str(p) for p in self.params)
|
200
246
|
)
|
201
247
|
|
248
|
+
@classmethod
|
249
|
+
def help(cls) -> str:
|
250
|
+
descr = ' '.join(B.__name__ for B in cls.__bases__)
|
251
|
+
params = cls.inputs or ''
|
252
|
+
return cls().get_pattern().format(
|
253
|
+
func_name=f'{descr} {cls.__name__}',
|
254
|
+
params=cls.separator.join(str(p) for p in params)
|
255
|
+
) + f' Return {cls.output}'
|
256
|
+
|
257
|
+
def set_main_param(self, name: str, main: SQLObject) -> bool:
|
258
|
+
nested_functions = [
|
259
|
+
param for param in self.params if isinstance(param, Function)
|
260
|
+
]
|
261
|
+
for func in nested_functions:
|
262
|
+
if func.inputs:
|
263
|
+
func.set_main_param(name, main)
|
264
|
+
return
|
265
|
+
new_params = [Field.format(name, main)]
|
266
|
+
if self.append_param:
|
267
|
+
self.params += new_params
|
268
|
+
else:
|
269
|
+
self.params = new_params + self.params
|
270
|
+
|
202
271
|
def __format(self, name: str, main: SQLObject) -> str:
|
203
272
|
if name not in '*_':
|
204
|
-
self.
|
205
|
-
Field.format(name, main)
|
206
|
-
] + self.params
|
273
|
+
self.set_main_param(name, main)
|
207
274
|
return str(self)
|
208
275
|
|
209
276
|
@classmethod
|
@@ -223,6 +290,9 @@ class Function:
|
|
223
290
|
|
224
291
|
# ---- String Functions: ---------------------------------
|
225
292
|
class SubString(Function):
|
293
|
+
inputs = [CHAR, INT, INT]
|
294
|
+
output = CHAR
|
295
|
+
|
226
296
|
def get_pattern(self) -> str:
|
227
297
|
if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
|
228
298
|
return 'Substr({params})'
|
@@ -230,31 +300,55 @@ class SubString(Function):
|
|
230
300
|
|
231
301
|
# ---- Numeric Functions: --------------------------------
|
232
302
|
class Round(Function):
|
233
|
-
|
303
|
+
inputs = [FLOAT]
|
304
|
+
output = FLOAT
|
234
305
|
|
235
306
|
# --- Date Functions: ------------------------------------
|
236
307
|
class DateDiff(Function):
|
237
|
-
|
308
|
+
inputs = [DATE]
|
309
|
+
output = DATE
|
310
|
+
append_param = True
|
311
|
+
|
312
|
+
def __str__(self) -> str:
|
238
313
|
def is_field_or_func(name: str) -> bool:
|
239
|
-
|
314
|
+
candidate = re.sub(
|
315
|
+
'[()]', '', name.split('.')[-1]
|
316
|
+
)
|
317
|
+
return candidate.isidentifier()
|
240
318
|
if self.dialect != Dialect.SQL_SERVER:
|
319
|
+
params = [str(p) for p in self.params]
|
241
320
|
return ' - '.join(
|
242
321
|
p if is_field_or_func(p) else f"'{p}'"
|
243
|
-
for p in
|
322
|
+
for p in params
|
244
323
|
) # <==== Date subtract
|
245
|
-
return super().
|
324
|
+
return super().__str__()
|
325
|
+
|
326
|
+
|
327
|
+
class DatePart(Function):
|
328
|
+
inputs = [DATE]
|
329
|
+
output = INT
|
246
330
|
|
247
|
-
class Year(Function):
|
248
331
|
def get_pattern(self) -> str:
|
332
|
+
interval = self.__class__.__name__
|
249
333
|
database_type = {
|
250
|
-
Dialect.ORACLE: 'Extract(
|
251
|
-
Dialect.POSTGRESQL: "Date_Part('
|
334
|
+
Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
|
335
|
+
Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
|
252
336
|
}
|
253
337
|
if self.dialect in database_type:
|
254
338
|
return database_type[self.dialect]
|
255
339
|
return super().get_pattern()
|
256
340
|
|
341
|
+
class Year(DatePart):
|
342
|
+
...
|
343
|
+
class Month(DatePart):
|
344
|
+
...
|
345
|
+
class Day(DatePart):
|
346
|
+
...
|
347
|
+
|
348
|
+
|
257
349
|
class Current_Date(Function):
|
350
|
+
output = DATE
|
351
|
+
|
258
352
|
def get_pattern(self) -> str:
|
259
353
|
database_type = {
|
260
354
|
Dialect.ORACLE: SQL_CONST_SYSDATE,
|
@@ -277,14 +371,15 @@ class Frame:
|
|
277
371
|
keywords = ''
|
278
372
|
for field, obj in args.items():
|
279
373
|
is_valid = any([
|
280
|
-
obj is
|
281
|
-
|
374
|
+
obj is OrderBy,
|
375
|
+
obj is Partition,
|
376
|
+
isinstance(obj, Rows),
|
282
377
|
])
|
283
378
|
if not is_valid:
|
284
379
|
continue
|
285
380
|
keywords += '{}{} {}'.format(
|
286
381
|
'\n\t\t' if self.break_lines else ' ',
|
287
|
-
obj.cls_to_str(), field
|
382
|
+
obj.cls_to_str(), field if field != '_' else ''
|
288
383
|
)
|
289
384
|
if keywords and self.break_lines:
|
290
385
|
keywords += '\n\t'
|
@@ -293,7 +388,8 @@ class Frame:
|
|
293
388
|
|
294
389
|
|
295
390
|
class Aggregate(Frame):
|
296
|
-
|
391
|
+
inputs = [FLOAT]
|
392
|
+
output = FLOAT
|
297
393
|
|
298
394
|
class Window(Frame):
|
299
395
|
...
|
@@ -312,20 +408,30 @@ class Count(Aggregate, Function):
|
|
312
408
|
|
313
409
|
# ---- Window Functions: -----------------------------------
|
314
410
|
class Row_Number(Window, Function):
|
315
|
-
|
411
|
+
output = INT
|
412
|
+
|
316
413
|
class Rank(Window, Function):
|
317
|
-
|
414
|
+
output = INT
|
415
|
+
|
318
416
|
class Lag(Window, Function):
|
319
|
-
|
417
|
+
output = ANY
|
418
|
+
|
320
419
|
class Lead(Window, Function):
|
321
|
-
|
420
|
+
output = ANY
|
322
421
|
|
323
422
|
|
324
423
|
# ---- Conversions and other Functions: ---------------------
|
325
424
|
class Coalesce(Function):
|
326
|
-
|
425
|
+
inputs = [ANY]
|
426
|
+
output = ANY
|
427
|
+
|
327
428
|
class Cast(Function):
|
328
|
-
|
429
|
+
inputs = [ANY]
|
430
|
+
output = ANY
|
431
|
+
separator = ' As '
|
432
|
+
|
433
|
+
|
434
|
+
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
329
435
|
|
330
436
|
|
331
437
|
class ExpressionField:
|
@@ -350,15 +456,20 @@ class ExpressionField:
|
|
350
456
|
class FieldList:
|
351
457
|
separator = ','
|
352
458
|
|
353
|
-
def __init__(self, fields: list=[], class_types = [Field]):
|
459
|
+
def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
|
354
460
|
if isinstance(fields, str):
|
355
461
|
fields = [
|
356
462
|
f.strip() for f in fields.split(self.separator)
|
357
463
|
]
|
358
464
|
self.fields = fields
|
359
465
|
self.class_types = class_types
|
466
|
+
self.ziped = ziped
|
360
467
|
|
361
468
|
def add(self, name: str, main: SQLObject):
|
469
|
+
if self.ziped: # --- One class per field...
|
470
|
+
for field, class_type in zip(self.fields, self.class_types):
|
471
|
+
class_type.add(field, main)
|
472
|
+
return
|
362
473
|
for field in self.fields:
|
363
474
|
for class_type in self.class_types:
|
364
475
|
class_type.add(field, main)
|
@@ -400,36 +511,40 @@ class ForeignKey:
|
|
400
511
|
|
401
512
|
def quoted(value) -> str:
|
402
513
|
if isinstance(value, str):
|
514
|
+
if re.search(r'\bor\b', value, re.IGNORECASE):
|
515
|
+
raise PermissionError('Possible SQL injection attempt')
|
403
516
|
value = f"'{value}'"
|
404
517
|
return str(value)
|
405
518
|
|
406
519
|
|
407
520
|
class Position(Enum):
|
521
|
+
StartsWith = -1
|
408
522
|
Middle = 0
|
409
|
-
|
410
|
-
EndsWith = 2
|
523
|
+
EndsWith = 1
|
411
524
|
|
412
525
|
|
413
526
|
class Where:
|
414
527
|
prefix = ''
|
415
528
|
|
416
|
-
def __init__(self,
|
417
|
-
self.
|
529
|
+
def __init__(self, content: str):
|
530
|
+
self.content = content
|
418
531
|
|
419
532
|
@classmethod
|
420
533
|
def __constructor(cls, operator: str, value):
|
421
|
-
return cls(
|
534
|
+
return cls(f'{operator} {quoted(value)}')
|
422
535
|
|
423
536
|
@classmethod
|
424
537
|
def eq(cls, value):
|
425
538
|
return cls.__constructor('=', value)
|
426
539
|
|
427
540
|
@classmethod
|
428
|
-
def contains(cls,
|
541
|
+
def contains(cls, text: str, pos: int | Position = Position.Middle):
|
542
|
+
if isinstance(pos, int):
|
543
|
+
pos = Position(pos)
|
429
544
|
return cls(
|
430
545
|
"LIKE '{}{}{}'".format(
|
431
546
|
'%' if pos != Position.StartsWith else '',
|
432
|
-
|
547
|
+
text,
|
433
548
|
'%' if pos != Position.EndsWith else ''
|
434
549
|
)
|
435
550
|
)
|
@@ -460,9 +575,43 @@ class Where:
|
|
460
575
|
values = ','.join(quoted(v) for v in values)
|
461
576
|
return cls(f'IN ({values})')
|
462
577
|
|
578
|
+
@classmethod
|
579
|
+
def formula(cls, formula: str):
|
580
|
+
where = cls( ExpressionField(formula) )
|
581
|
+
where.add = where.add_expression
|
582
|
+
return where
|
583
|
+
|
584
|
+
def add_expression(self, name: str, main: SQLObject):
|
585
|
+
self.content = self.content.format(name, main)
|
586
|
+
main.values.setdefault(WHERE, []).append('{} {}'.format(
|
587
|
+
self.prefix, self.content
|
588
|
+
))
|
589
|
+
|
590
|
+
@classmethod
|
591
|
+
def join(cls, query: SQLObject):
|
592
|
+
where = cls(query)
|
593
|
+
where.add = where.add_join
|
594
|
+
return where
|
595
|
+
|
596
|
+
def add_join(self, name: str, main: SQLObject):
|
597
|
+
query = self.content
|
598
|
+
main.values[FROM].append(f',{query.table_name} {query.alias}')
|
599
|
+
for key in USUAL_KEYS:
|
600
|
+
main.update_values(key, query.values.get(key, []))
|
601
|
+
if query.key_field:
|
602
|
+
main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
|
603
|
+
a1=main.alias, f1=name,
|
604
|
+
a2=query.alias, f2=query.key_field
|
605
|
+
))
|
606
|
+
|
463
607
|
def add(self, name: str, main: SQLObject):
|
608
|
+
func_type = FUNCTION_CLASS.get(name.lower())
|
609
|
+
if func_type:
|
610
|
+
name = func_type.format('*', main)
|
611
|
+
elif not main.has_named_field(name):
|
612
|
+
name = Field.format(name, main)
|
464
613
|
main.values.setdefault(WHERE, []).append('{}{} {}'.format(
|
465
|
-
self.prefix,
|
614
|
+
self.prefix, name, self.content
|
466
615
|
))
|
467
616
|
|
468
617
|
|
@@ -470,6 +619,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
|
|
470
619
|
getattr(Where, method) for method in
|
471
620
|
('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
|
472
621
|
)
|
622
|
+
startswith, endswith = [
|
623
|
+
lambda x: contains(x, Position.StartsWith),
|
624
|
+
lambda x: contains(x, Position.EndsWith)
|
625
|
+
]
|
473
626
|
|
474
627
|
|
475
628
|
class Not(Where):
|
@@ -477,7 +630,7 @@ class Not(Where):
|
|
477
630
|
|
478
631
|
@classmethod
|
479
632
|
def eq(cls, value):
|
480
|
-
return Where(
|
633
|
+
return Where(f'<> {quoted(value)}')
|
481
634
|
|
482
635
|
|
483
636
|
class Case:
|
@@ -486,22 +639,26 @@ class Case:
|
|
486
639
|
self.default = None
|
487
640
|
self.field = field
|
488
641
|
|
489
|
-
def when(self, condition: Where, result
|
642
|
+
def when(self, condition: Where, result):
|
643
|
+
if isinstance(result, str):
|
644
|
+
result = quoted(result)
|
490
645
|
self.__conditions[result] = condition
|
491
646
|
return self
|
492
647
|
|
493
|
-
def else_value(self, default
|
648
|
+
def else_value(self, default):
|
649
|
+
if isinstance(default, str):
|
650
|
+
default = quoted(default)
|
494
651
|
self.default = default
|
495
652
|
return self
|
496
653
|
|
497
654
|
def add(self, name: str, main: SQLObject):
|
498
655
|
field = Field.format(self.field, main)
|
499
|
-
default =
|
656
|
+
default = self.default
|
500
657
|
name = 'CASE \n{}\n\tEND AS {}'.format(
|
501
658
|
'\n'.join(
|
502
|
-
f'\t\tWHEN {field} {cond.
|
659
|
+
f'\t\tWHEN {field} {cond.content} THEN {res}'
|
503
660
|
for res, cond in self.__conditions.items()
|
504
|
-
) + f'\n\t\tELSE {default}' if default else '',
|
661
|
+
) + (f'\n\t\tELSE {default}' if default else ''),
|
505
662
|
name
|
506
663
|
)
|
507
664
|
main.values.setdefault(SELECT, []).append(name)
|
@@ -512,42 +669,69 @@ class Options:
|
|
512
669
|
self.__children: dict = values
|
513
670
|
|
514
671
|
def add(self, logical_separator: str, main: SQLObject):
|
515
|
-
if logical_separator not in ('AND', 'OR'):
|
672
|
+
if logical_separator.upper() not in ('AND', 'OR'):
|
516
673
|
raise ValueError('`logical_separator` must be AND or OR')
|
517
|
-
|
674
|
+
temp = Select(f'{main.table_name} {main.alias}')
|
518
675
|
child: Where
|
519
676
|
for field, child in self.__children.items():
|
520
|
-
|
521
|
-
Field.format(field, main), child.expr
|
522
|
-
))
|
677
|
+
child.add(field, temp)
|
523
678
|
main.values.setdefault(WHERE, []).append(
|
524
|
-
'(' + logical_separator.join(
|
679
|
+
'(' + f'\n\t{logical_separator} '.join(temp.values[WHERE]) + ')'
|
525
680
|
)
|
526
681
|
|
527
682
|
|
528
683
|
class Between:
|
684
|
+
is_literal: bool = False
|
685
|
+
|
529
686
|
def __init__(self, start, end):
|
530
687
|
if start > end:
|
531
688
|
start, end = end, start
|
532
689
|
self.start = start
|
533
690
|
self.end = end
|
534
691
|
|
692
|
+
def literal(self) -> Where:
|
693
|
+
return Where('BETWEEN {} AND {}'.format(
|
694
|
+
self.start, self.end
|
695
|
+
))
|
696
|
+
|
535
697
|
def add(self, name: str, main:SQLObject):
|
536
|
-
|
698
|
+
if self.is_literal:
|
699
|
+
return self.literal().add(name, main)
|
700
|
+
Where.gte(self.start).add(name, main)
|
537
701
|
Where.lte(self.end).add(name, main)
|
538
702
|
|
703
|
+
class SameDay(Between):
|
704
|
+
def __init__(self, date: str):
|
705
|
+
super().__init__(
|
706
|
+
f'{date} 00:00:00',
|
707
|
+
f'{date} 23:59:59',
|
708
|
+
)
|
709
|
+
|
710
|
+
|
711
|
+
class Range(Case):
|
712
|
+
INC_FUNCTION = lambda x: x + 1
|
713
|
+
|
714
|
+
def __init__(self, field: str, values: dict):
|
715
|
+
super().__init__(field)
|
716
|
+
start = 0
|
717
|
+
cls = self.__class__
|
718
|
+
for label, value in sorted(values.items(), key=lambda item: item[1]):
|
719
|
+
self.when(
|
720
|
+
Between(start, value).literal(), label
|
721
|
+
)
|
722
|
+
start = cls.INC_FUNCTION(value)
|
723
|
+
|
539
724
|
|
540
725
|
class Clause:
|
541
726
|
@classmethod
|
542
727
|
def format(cls, name: str, main: SQLObject) -> str:
|
543
728
|
def is_function() -> bool:
|
544
729
|
diff = main.diff(SELECT, [name.lower()], True)
|
545
|
-
FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
|
546
730
|
return diff.intersection(FUNCTION_CLASS)
|
547
731
|
found = re.findall(r'^_\d', name)
|
548
732
|
if found:
|
549
733
|
name = found[0].replace('_', '')
|
550
|
-
elif main.alias and not is_function():
|
734
|
+
elif '.' not in name and main.alias and not is_function():
|
551
735
|
name = f'{main.alias}.{name}'
|
552
736
|
return name
|
553
737
|
|
@@ -556,6 +740,34 @@ class SortType(Enum):
|
|
556
740
|
ASC = ''
|
557
741
|
DESC = ' DESC'
|
558
742
|
|
743
|
+
class Row:
|
744
|
+
def __init__(self, value: int=0):
|
745
|
+
self.value = value
|
746
|
+
|
747
|
+
def __str__(self) -> str:
|
748
|
+
return '{} {}'.format(
|
749
|
+
'UNBOUNDED' if self.value == 0 else self.value,
|
750
|
+
self.__class__.__name__.upper()
|
751
|
+
)
|
752
|
+
|
753
|
+
class Preceding(Row):
|
754
|
+
...
|
755
|
+
class Following(Row):
|
756
|
+
...
|
757
|
+
class Current(Row):
|
758
|
+
def __str__(self) -> str:
|
759
|
+
return 'CURRENT ROW'
|
760
|
+
|
761
|
+
class Rows:
|
762
|
+
def __init__(self, *rows: list[Row]):
|
763
|
+
self.rows = rows
|
764
|
+
|
765
|
+
def cls_to_str(self) -> str:
|
766
|
+
return 'ROWS {}{}'.format(
|
767
|
+
'BETWEEN ' if len(self.rows) > 1 else '',
|
768
|
+
' AND '.join(str(row) for row in self.rows)
|
769
|
+
)
|
770
|
+
|
559
771
|
|
560
772
|
class OrderBy(Clause):
|
561
773
|
sort: SortType = SortType.ASC
|
@@ -590,7 +802,7 @@ class Having:
|
|
590
802
|
|
591
803
|
def add(self, name: str, main:SQLObject):
|
592
804
|
main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
|
593
|
-
self.function.format(name, main), self.condition.
|
805
|
+
self.function.format(name, main), self.condition.content
|
594
806
|
)
|
595
807
|
|
596
808
|
@classmethod
|
@@ -620,12 +832,20 @@ class Rule:
|
|
620
832
|
...
|
621
833
|
|
622
834
|
class QueryLanguage:
|
623
|
-
pattern = '{select}{_from}{where}{group_by}{order_by}'
|
835
|
+
pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
|
624
836
|
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
625
837
|
|
626
838
|
@staticmethod
|
627
|
-
def remove_alias(
|
628
|
-
|
839
|
+
def remove_alias(text: str) -> str:
|
840
|
+
value, sep = '', ''
|
841
|
+
text = re.sub('[\n\t]', ' ', text)
|
842
|
+
if ':' in text:
|
843
|
+
text, value = text.split(':', maxsplit=1)
|
844
|
+
sep = ':'
|
845
|
+
return '{}{}{}'.format(
|
846
|
+
''.join(re.split(r'\w+[.]', text)),
|
847
|
+
sep, value.replace("'", '"')
|
848
|
+
)
|
629
849
|
|
630
850
|
def join_with_tabs(self, values: list, sep: str='') -> str:
|
631
851
|
sep = sep + self.TABULATION
|
@@ -643,18 +863,21 @@ class QueryLanguage:
|
|
643
863
|
return self.join_with_tabs(values, ' AND ')
|
644
864
|
|
645
865
|
def sort_by(self, values: list) -> str:
|
646
|
-
return self.join_with_tabs(values)
|
866
|
+
return self.join_with_tabs(values, ',')
|
647
867
|
|
648
868
|
def set_group(self, values: list) -> str:
|
649
869
|
return self.join_with_tabs(values, ',')
|
650
870
|
|
871
|
+
def set_limit(self, values: list) -> str:
|
872
|
+
return self.join_with_tabs(values, ' ')
|
873
|
+
|
651
874
|
def __init__(self, target: 'Select'):
|
652
|
-
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
|
875
|
+
self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
|
653
876
|
self.TABULATION = '\n\t' if target.break_lines else ' '
|
654
877
|
self.LINE_BREAK = '\n' if target.break_lines else ' '
|
655
878
|
self.TOKEN_METHODS = {
|
656
879
|
SELECT: self.add_field, FROM: self.get_tables,
|
657
|
-
WHERE: self.extract_conditions,
|
880
|
+
WHERE: self.extract_conditions, LIMIT: self.set_limit,
|
658
881
|
ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
|
659
882
|
}
|
660
883
|
self.result = {}
|
@@ -690,7 +913,8 @@ class MongoDBLanguage(QueryLanguage):
|
|
690
913
|
LOGICAL_OP_TO_MONGO_FUNC = {
|
691
914
|
'>': '$gt', '>=': '$gte',
|
692
915
|
'<': '$lt', '<=': '$lte',
|
693
|
-
'=': '$eq', '<>': '$ne',
|
916
|
+
'=': '$eq', '<>': '$ne',
|
917
|
+
'like': '$regex', 'LIKE': '$regex',
|
694
918
|
}
|
695
919
|
OPERATORS = '|'.join(op for op in LOGICAL_OP_TO_MONGO_FUNC)
|
696
920
|
REGEX = {
|
@@ -743,7 +967,7 @@ class MongoDBLanguage(QueryLanguage):
|
|
743
967
|
field, *op, const = tokens
|
744
968
|
op = ''.join(op)
|
745
969
|
expr = '{begin}{op}:{const}{end}'.format(
|
746
|
-
begin='{', const=const, end='}',
|
970
|
+
begin='{', const=const.replace('%', '.*'), end='}',
|
747
971
|
op=cls.LOGICAL_OP_TO_MONGO_FUNC[op],
|
748
972
|
)
|
749
973
|
where_list.append(f'{field}:{expr}')
|
@@ -852,6 +1076,55 @@ class Neo4JLanguage(QueryLanguage):
|
|
852
1076
|
return ''
|
853
1077
|
|
854
1078
|
|
1079
|
+
class DatabricksLanguage(QueryLanguage):
|
1080
|
+
pattern = '{_from}{where}{group_by}{order_by}{select}{limit}'
|
1081
|
+
has_default = {key: bool(key == SELECT) for key in KEYWORD}
|
1082
|
+
|
1083
|
+
def __init__(self, target: 'Select'):
|
1084
|
+
super().__init__(target)
|
1085
|
+
self.aggregation_fields = []
|
1086
|
+
|
1087
|
+
def add_field(self, values: list) -> str:
|
1088
|
+
AGG_FUNCS = '|'.join(cls.__name__ for cls in Aggregate.__subclasses__())
|
1089
|
+
# --------------------------------------------------------------
|
1090
|
+
def is_agg_field(fld: str) -> bool:
|
1091
|
+
return re.findall(fr'({AGG_FUNCS})[(]', fld, re.IGNORECASE)
|
1092
|
+
# --------------------------------------------------------------
|
1093
|
+
new_values = []
|
1094
|
+
for val in values:
|
1095
|
+
if is_agg_field(val):
|
1096
|
+
self.aggregation_fields.append(val)
|
1097
|
+
else:
|
1098
|
+
new_values.append(val)
|
1099
|
+
values = new_values
|
1100
|
+
return super().add_field(values)
|
1101
|
+
|
1102
|
+
def prefix(self, key: str) -> str:
|
1103
|
+
def get_aggregate() -> str:
|
1104
|
+
return 'AGGREGATE {} '.format(
|
1105
|
+
','.join(self.aggregation_fields)
|
1106
|
+
)
|
1107
|
+
return '{}{}{}{}{}'.format(
|
1108
|
+
'|> ' if key != FROM else '',
|
1109
|
+
self.LINE_BREAK,
|
1110
|
+
get_aggregate() if key == GROUP_BY else '',
|
1111
|
+
key, self.TABULATION
|
1112
|
+
)
|
1113
|
+
|
1114
|
+
# def get_tables(self, values: list) -> str:
|
1115
|
+
# return self.join_with_tabs(values)
|
1116
|
+
|
1117
|
+
# def extract_conditions(self, values: list) -> str:
|
1118
|
+
# return self.join_with_tabs(values, ' AND ')
|
1119
|
+
|
1120
|
+
# def sort_by(self, values: list) -> str:
|
1121
|
+
# return self.join_with_tabs(values, ',')
|
1122
|
+
|
1123
|
+
def set_group(self, values: list) -> str:
|
1124
|
+
return self.join_with_tabs(values, ',')
|
1125
|
+
|
1126
|
+
|
1127
|
+
|
855
1128
|
class Parser:
|
856
1129
|
REGEX = {}
|
857
1130
|
|
@@ -958,10 +1231,13 @@ class SQLParser(Parser):
|
|
958
1231
|
if not key in values:
|
959
1232
|
continue
|
960
1233
|
separator = self.class_type.get_separator(key)
|
1234
|
+
cls = {
|
1235
|
+
ORDER_BY: OrderBy, GROUP_BY: GroupBy
|
1236
|
+
}.get(key, Field)
|
961
1237
|
obj.values[key] = [
|
962
|
-
|
1238
|
+
cls.format(fld, obj)
|
963
1239
|
for fld in re.split(separator, values[key])
|
964
|
-
if (fld != '*' and len(tables) == 1) or obj.match(fld)
|
1240
|
+
if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
|
965
1241
|
]
|
966
1242
|
result[obj.alias] = obj
|
967
1243
|
self.queries = list( result.values() )
|
@@ -1021,16 +1297,26 @@ class CypherParser(Parser):
|
|
1021
1297
|
if token in self.TOKEN_METHODS:
|
1022
1298
|
return
|
1023
1299
|
class_list = [Field]
|
1024
|
-
if '
|
1300
|
+
if '*' in token:
|
1301
|
+
token = token.replace('*', '')
|
1302
|
+
self.queries[-1].key_field = token
|
1303
|
+
return
|
1304
|
+
elif '$' in token:
|
1025
1305
|
func_name, token = token.split('$')
|
1026
1306
|
if func_name == 'count':
|
1027
1307
|
if not token:
|
1028
1308
|
token = 'count_1'
|
1029
|
-
|
1030
|
-
|
1309
|
+
pk_field = self.queries[-1].key_field or 'id'
|
1310
|
+
Count().As(token, extra_classes).add(pk_field, self.queries[-1])
|
1311
|
+
return
|
1031
1312
|
else:
|
1032
|
-
|
1033
|
-
|
1313
|
+
class_type = FUNCTION_CLASS.get(func_name)
|
1314
|
+
if not class_type:
|
1315
|
+
raise ValueError(f'Unknown function `{func_name}`.')
|
1316
|
+
if ':' in token:
|
1317
|
+
token, field_alias = token.split(':')
|
1318
|
+
class_type = class_type().As(field_alias)
|
1319
|
+
class_list = [class_type]
|
1034
1320
|
class_list += extra_classes
|
1035
1321
|
FieldList(token, class_list).add('', self.queries[-1])
|
1036
1322
|
|
@@ -1045,10 +1331,13 @@ class CypherParser(Parser):
|
|
1045
1331
|
def add_foreign_key(self, token: str, pk_field: str=''):
|
1046
1332
|
curr, last = [self.queries[i] for i in (-1, -2)]
|
1047
1333
|
if not pk_field:
|
1048
|
-
if
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1334
|
+
if last.key_field:
|
1335
|
+
pk_field = last.key_field
|
1336
|
+
else:
|
1337
|
+
if not last.values.get(SELECT):
|
1338
|
+
raise IndexError(f'Primary Key not found for {last.table_name}.')
|
1339
|
+
pk_field = last.values[SELECT][-1].split('.')[-1]
|
1340
|
+
last.delete(pk_field, [SELECT], exact=True)
|
1052
1341
|
if '{}' in token:
|
1053
1342
|
foreign_fld = token.format(
|
1054
1343
|
last.table_name.lower()
|
@@ -1063,12 +1352,11 @@ class CypherParser(Parser):
|
|
1063
1352
|
if fld not in curr.values.get(GROUP_BY, [])
|
1064
1353
|
]
|
1065
1354
|
foreign_fld = fields[0].split('.')[-1]
|
1066
|
-
curr.delete(foreign_fld, [SELECT])
|
1355
|
+
curr.delete(foreign_fld, [SELECT], exact=True)
|
1067
1356
|
if curr.join_type == JoinType.RIGHT:
|
1068
1357
|
pk_field, foreign_fld = foreign_fld, pk_field
|
1069
1358
|
if curr.join_type == JoinType.RIGHT:
|
1070
1359
|
curr, last = last, curr
|
1071
|
-
# pk_field, foreign_fld = foreign_fld, pk_field
|
1072
1360
|
k = ForeignKey.get_key(curr, last)
|
1073
1361
|
ForeignKey.references[k] = (foreign_fld, pk_field)
|
1074
1362
|
|
@@ -1192,7 +1480,18 @@ class MongoParser(Parser):
|
|
1192
1480
|
|
1193
1481
|
def begin_conditions(self, value: str):
|
1194
1482
|
self.where_list = {}
|
1483
|
+
self.field_method = self.first_ORfield
|
1195
1484
|
return Where
|
1485
|
+
|
1486
|
+
def first_ORfield(self, text: str):
|
1487
|
+
if text.startswith('$'):
|
1488
|
+
return
|
1489
|
+
found = re.search(r'\w+[:]', text)
|
1490
|
+
if not found:
|
1491
|
+
return
|
1492
|
+
self.field_method = None
|
1493
|
+
p1, p2 = found.span()
|
1494
|
+
self.last_field = text[p1: p2-1]
|
1196
1495
|
|
1197
1496
|
def increment_brackets(self, value: str):
|
1198
1497
|
self.brackets[value] += 1
|
@@ -1201,6 +1500,7 @@ class MongoParser(Parser):
|
|
1201
1500
|
self.method = self.new_query
|
1202
1501
|
self.last_field = ''
|
1203
1502
|
self.where_list = None
|
1503
|
+
self.field_method = None
|
1204
1504
|
self.PARAM_BY_FUNCTION = {
|
1205
1505
|
'find': Where, 'aggregate': GroupBy, 'sort': OrderBy
|
1206
1506
|
}
|
@@ -1230,13 +1530,14 @@ class MongoParser(Parser):
|
|
1230
1530
|
self.close_brackets(
|
1231
1531
|
BRACKET_PAIR[token]
|
1232
1532
|
)
|
1533
|
+
elif self.field_method:
|
1534
|
+
self.field_method(token)
|
1233
1535
|
self.method = self.TOKEN_METHODS.get(token)
|
1234
1536
|
# ----------------------------
|
1235
1537
|
|
1236
1538
|
|
1237
1539
|
class Select(SQLObject):
|
1238
1540
|
join_type: JoinType = JoinType.INNER
|
1239
|
-
REGEX = {}
|
1240
1541
|
EQUIVALENT_NAMES = {}
|
1241
1542
|
|
1242
1543
|
def __init__(self, table_name: str='', **values):
|
@@ -1254,21 +1555,30 @@ class Select(SQLObject):
|
|
1254
1555
|
|
1255
1556
|
def add(self, name: str, main: SQLObject):
|
1256
1557
|
old_tables = main.values.get(FROM, [])
|
1257
|
-
|
1258
|
-
|
1558
|
+
if len(self.values[FROM]) > 1:
|
1559
|
+
old_tables += self.values[FROM][1:]
|
1560
|
+
new_tables = []
|
1561
|
+
row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
|
1259
1562
|
jt=self.join_type.value,
|
1260
1563
|
tb=self.aka(),
|
1261
1564
|
a1=main.alias, f1=name,
|
1262
1565
|
a2=self.alias, f2=self.key_field
|
1263
1566
|
)
|
1264
|
-
|
1265
|
-
|
1567
|
+
if row not in old_tables[1:]:
|
1568
|
+
new_tables.append(row)
|
1569
|
+
main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
|
1266
1570
|
for key in USUAL_KEYS:
|
1267
1571
|
main.update_values(key, self.values.get(key, []))
|
1268
1572
|
|
1269
|
-
def
|
1573
|
+
def copy(self) -> SQLObject:
    '''
    Return a deep copy of this object, so the copy can be changed
    without touching the original.
    '''
    import copy as copy_module
    return copy_module.deepcopy(self)
|
1576
|
+
|
1577
|
+
def no_relation_error(self, other: SQLObject):
    '''
    Always raises ValueError, naming the two tables for which
    no relationship was found.
    '''
    message = f'No relationship found between {self.table_name} and {other.table_name}.'
    raise ValueError(message)
|
1579
|
+
|
1580
|
+
def __add__(self, other: SQLObject):
|
1581
|
+
query = self.copy()
|
1272
1582
|
if query.table_name.lower() == other.table_name.lower():
|
1273
1583
|
for key in USUAL_KEYS:
|
1274
1584
|
query.update_values(key, other.values.get(key, []))
|
@@ -1281,7 +1591,7 @@ class Select(SQLObject):
|
|
1281
1591
|
PrimaryKey.add(primary_key, query)
|
1282
1592
|
query.add(foreign_field, other)
|
1283
1593
|
return other
|
1284
|
-
|
1594
|
+
self.no_relation_error(other) # === raise ERROR ... ===
|
1285
1595
|
elif primary_key:
|
1286
1596
|
PrimaryKey.add(primary_key, other)
|
1287
1597
|
other.add(foreign_field, query)
|
@@ -1301,16 +1611,48 @@ class Select(SQLObject):
|
|
1301
1611
|
if self.diff(key, other.values.get(key, []), True):
|
1302
1612
|
return False
|
1303
1613
|
return True
|
1614
|
+
|
1615
|
+
def __sub__(self, other: SQLObject) -> SQLObject:
    '''
    Combine two related queries through a NotSelectIN object.

    The relation is searched with `ForeignKey.find` in both directions
    (self -> other first, then other -> self). Both operands are copied,
    so neither original is modified. The copy on the side where the
    foreign key was found is turned into a NotSelectIN and attached to
    the copy of the other operand, which is the value returned.

    Raises ValueError (through `no_relation_error`) when no relation
    exists in either direction.
    '''
    fk_field, primary_k = ForeignKey.find(self, other)
    if not fk_field:
        fk_field, primary_k = ForeignKey.find(other, self)
        if not fk_field:
            self.no_relation_error(other)  # === raise ERROR ... ===
        excluded, kept = other.copy(), self.copy()
    else:
        excluded, kept = self.copy(), other.copy()
    excluded.__class__ = NotSelectIN
    Field.add(fk_field, excluded)
    excluded.add(primary_k, kept)
    return kept
|
1304
1630
|
|
1305
1631
|
def limit(self, row_count: int=100, offset: int=0):
    '''
    Restrict the number of rows returned, according to `Function.dialect`.

    row_count: maximum number of rows.
    offset: number of leading rows to skip (ignored when <= 0).

    * SQL_SERVER: prefixes the first SELECT field with `SELECT TOP(n)`
      (or stores `SELECT TOP(n) *` when there are no fields);
      `offset` is not used in this branch.
    * ORACLE: adds conditions on SQL_ROW_NUM.
    * any other dialect: stores `n` or `n OFFSET m` under the LIMIT key.

    Returns self, so calls can be chained.
    '''
    dialect = Function.dialect
    if dialect == Dialect.SQL_SERVER:
        fields = self.values.get(SELECT)
        if fields:
            fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
        else:
            self.values[SELECT] = [f'SELECT TOP({row_count}) *']
        return self
    if dialect == Dialect.ORACLE:
        # The row number must be bounded from ABOVE to act as a limit
        # (the previous lower bound `>= row_count` skipped the first rows
        # instead of limiting them). Assumes Where.gte / Where.lte emit
        # `>=` / `<=` conditions -- TODO confirm.
        if offset > 0:
            # NOTE(review): in Oracle a lower bound on ROWNUM placed directly
            # in the WHERE clause matches no rows; a sub-select is normally
            # required for real pagination -- confirm how this is rendered.
            Where.gte(offset + 1).add(SQL_ROW_NUM, self)
        Where.lte(row_count + offset).add(SQL_ROW_NUM, self)
        return self
    suffix = f' OFFSET {offset}' if offset > 0 else ''
    self.values[LIMIT] = [f'{row_count}{suffix}']
    return self
|
1311
1648
|
|
1312
|
-
def match(self,
|
1313
|
-
|
1649
|
+
def match(self, field: str, key: str) -> bool:
    '''
    Recognizes if the field is from the current table

    field: the field expression to inspect.
    key: the clause the field belongs to; for ORDER_BY / GROUP_BY a field
         without a dot is resolved with `has_named_field`.

    In every other case the field belongs to this table when it contains
    this object's alias followed by a dot, as a whole word.
    '''
    if key in (ORDER_BY, GROUP_BY) and '.' not in field:
        return self.has_named_field(field)
    # A raw string is required: in a plain string `\b` is a backspace
    # character, not a word boundary, so alias `p` also matched `sp.name`.
    # re.escape keeps unusual alias characters from being read as regex.
    alias_prefix = fr'\b{re.escape(self.alias)}[.]'
    return re.search(alias_prefix, field) is not None
|
1314
1656
|
|
1315
1657
|
@classmethod
|
1316
1658
|
def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
|
@@ -1322,12 +1664,10 @@ class Select(SQLObject):
|
|
1322
1664
|
for rule in rules:
|
1323
1665
|
rule.apply(self)
|
1324
1666
|
|
1325
|
-
def add_fields(self, fields: list,
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
if group_by:
|
1330
|
-
class_types += [GroupBy]
|
1667
|
+
def add_fields(self, fields: list, class_types=None):
    '''
    Add several fields to this query at once.

    fields: the fields handed to FieldList.
    class_types: optional classes to apply to each field, in addition
                 to Field, which is always appended at the end.
    '''
    # Build a new list instead of using `+=`, so the caller's sequence
    # is never modified (and a tuple is accepted as well).
    class_types = [*(class_types or []), Field]
    FieldList(fields, class_types).add('', self)
|
1332
1672
|
|
1333
1673
|
def translate_to(self, language: QueryLanguage) -> str:
|
@@ -1347,6 +1687,95 @@ class NotSelectIN(SelectIN):
|
|
1347
1687
|
condition_class = Not
|
1348
1688
|
|
1349
1689
|
|
1690
|
+
class CTE(Select):
    '''
    A Select rendered as `WITH <name> AS ( ...queries... ) <main select>`.
    The inner queries are kept in `query_list` and are combined
    with UNION ALL when the object is converted to text.
    '''
    prefix = ''  # text placed between WITH and the table name

    def __init__(self, table_name: str, query_list: list[Select]):
        '''
        table_name: name given to the WITH block (and to this Select).
        query_list: the queries that form the body of the WITH block;
                    `break_lines` is switched off on each of them.
        '''
        super().__init__(table_name)
        for inner_query in query_list:
            inner_query.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        '''
        Render the WITH block followed by this object's own Select text.
        As a side effect, `break_lines` is turned on when the values of
        the USUAL_KEYS (those without a line break) exceed 70 characters.
        '''
        total = sum(
            len(text)
            for key in USUAL_KEYS
            for text in self.values.get(key, [])
            if '\n' not in text
        )
        if total > 70:
            self.break_lines = True
        # ---------------------------------------------------------
        def justify(query: Select) -> str:
            # Split the query text at keywords/AND/OR/commas and start a
            # new line whenever the current one has reached 50 characters.
            chunks, current = [], ''
            keywords = '|'.join(KEYWORD)
            for piece in re.split(fr'({keywords}|AND|OR|,)', str(query)):
                if len(current) >= 50:
                    chunks.append(current)
                    current = piece
                else:
                    current += piece
            if current:
                chunks.append(current)
            return '\n '.join(chunks)
        # ---------------------------------------------------------
        prefix, name = self.prefix, self.table_name
        body = '\nUNION ALL\n '.join(map(justify, self.query_list))
        main_select = super().__str__()
        return 'WITH {}{} AS (\n {}\n){}'.format(
            prefix, name, body, main_select
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        '''
        Create one query per field from `pattern` (repeated once for each
        field and parsed with `detect`) and pair the fields with those
        queries through FieldList.

        fields: a list, or a comma-separated string, of fields.
        format: forwarded to `detect`.
        Returns self, with `break_lines` turned on.
        '''
        names = fields.split(',') if isinstance(fields, str) else fields
        queries = detect(
            pattern * len(names), join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
|
1737
|
+
|
1738
|
+
class Recursive(CTE):
    '''
    A CTE rendered with the RECURSIVE prefix, whose last inner query
    also reads from the CTE itself.
    '''
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        '''
        Render the recursive CTE. When there is more than one inner query,
        the last one gets this CTE (name and alias) added to its FROM list.
        '''
        if len(self.query_list) > 1:
            tables = self.query_list[-1].values[FROM]
            self_reference = f', {self.table_name} {self.alias}'
            # Converting to text more than once must not keep appending
            # the same self-reference to the FROM list.
            if self_reference not in tables:
                tables.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        '''
        Build a Recursive CTE out of two queries parsed from `pattern`.

        name: name of the CTE.
        pattern: text handed (twice) to `detect` to obtain both queries.
        formula: expression with `[n]` placeholders, where n is the
                 1-based position of a SELECT field of the second query.
                 The first placeholder becomes `%` and selects the field
                 that receives the formula; the others are replaced by
                 the field names.
        init_value: value compared (Where.eq) with the first SELECT field
                    of the first query.
        format: forwarded to `detect`.

        Note: resets SQLObject.ALIAS_FUNC to None.
        '''
        SQLObject.ALIAS_FUNC = None
        def get_field(obj: SQLObject, pos: int) -> str:
            # Field name without the alias prefix
            return obj.values[SELECT][pos].split('.')[-1]
        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for token in re.findall(r'\[(\d+)\]', formula):
            # Replace the placeholder exactly as written, so a value such
            # as `[01]` is substituted too (int() would rebuild it as `[1]`).
            placeholder = f'[{token}]'
            field_name = get_field(t2, int(token) - 1)
            if not foreign_key:
                foreign_key = field_name
                formula = formula.replace(placeholder, '%')
            else:
                formula = formula.replace(placeholder, field_name)
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        '''
        Add a counter column called `name` to every inner query: the first
        query gets `start AS name`, the others `(name<increment>) AS name`.
        Returns self.
        '''
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
|
1775
|
+
|
1776
|
+
|
1777
|
+
# ----- Rules -----
|
1778
|
+
|
1350
1779
|
class RulePutLimit(Rule):
|
1351
1780
|
@classmethod
|
1352
1781
|
def apply(cls, target: Select):
|
@@ -1410,6 +1839,8 @@ class RuleDateFuncReplace(Rule):
|
|
1410
1839
|
@classmethod
|
1411
1840
|
def apply(cls, target: Select):
|
1412
1841
|
for i, condition in enumerate(target.values.get(WHERE, [])):
|
1842
|
+
if not '(' in condition:
|
1843
|
+
continue
|
1413
1844
|
tokens = [
|
1414
1845
|
t.strip() for t in cls.REGEX.split(condition) if t.strip()
|
1415
1846
|
]
|
@@ -1431,12 +1862,13 @@ class RuleReplaceJoinBySubselect(Rule):
|
|
1431
1862
|
more_relations = any([
|
1432
1863
|
ref[0] == query.table_name for ref in ForeignKey.references
|
1433
1864
|
])
|
1434
|
-
|
1865
|
+
keep_join = any([
|
1435
1866
|
len( query.values.get(SELECT, []) ) > 0,
|
1436
1867
|
len( query.values.get(WHERE, []) ) == 0,
|
1437
1868
|
not fk_field, more_relations
|
1438
1869
|
])
|
1439
|
-
if
|
1870
|
+
if keep_join:
|
1871
|
+
query.add(fk_field, main)
|
1440
1872
|
continue
|
1441
1873
|
query.__class__ = SubSelect
|
1442
1874
|
Field.add(primary_k, query)
|
@@ -1460,7 +1892,7 @@ def parser_class(text: str) -> Parser:
|
|
1460
1892
|
return None
|
1461
1893
|
|
1462
1894
|
|
1463
|
-
def detect(text: str) -> Select:
|
1895
|
+
def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
|
1464
1896
|
from collections import Counter
|
1465
1897
|
parser = parser_class(text)
|
1466
1898
|
if not parser:
|
@@ -1471,14 +1903,65 @@ def detect(text: str) -> Select:
|
|
1471
1903
|
continue
|
1472
1904
|
pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
|
1473
1905
|
for begin, end in pos[::-1]:
|
1474
|
-
new_name = f'{table}_{count}' # See set_table (line
|
1906
|
+
new_name = f'{table}_{count}' # See set_table (line 55)
|
1475
1907
|
Select.EQUIVALENT_NAMES[new_name] = table
|
1476
1908
|
text = text[:begin] + new_name + '(' + text[end:]
|
1477
1909
|
count -= 1
|
1478
1910
|
query_list = Select.parse(text, parser)
|
1911
|
+
if format:
|
1912
|
+
for query in query_list:
|
1913
|
+
query.set_file_format(format)
|
1914
|
+
if not join_queries:
|
1915
|
+
return query_list
|
1479
1916
|
result = query_list[0]
|
1480
1917
|
for query in query_list[1:]:
|
1481
1918
|
result += query
|
1482
1919
|
return result
|
1483
|
-
|
1484
|
-
|
1920
|
+
# ===========================================================================================//
|
1921
|
+
|
1922
|
+
|
1923
|
+
if __name__ == "__main__":
|
1924
|
+
# def identifica_suspeitos() -> Select:
|
1925
|
+
# """Shows which people have characteristics matching the suspect's description"""
|
1926
|
+
# Select.join_type = JoinType.LEFT
|
1927
|
+
# return Select(
|
1928
|
+
# 'Suspeito s', id=Field,
|
1929
|
+
# _=Where.join(
|
1930
|
+
# Select('Pessoa p',
|
1931
|
+
# OR=Options(
|
1932
|
+
# pessoa=Where('= s.id'),
|
1933
|
+
# altura=Where.formula('ABS(% - s.{f}) < 0.5'),
|
1934
|
+
# peso=Where.formula('ABS(% - s.{f}) < 0.5'),
|
1935
|
+
# cabelo=Where.formula('% = s.{f}'),
|
1936
|
+
# olhos=Where.formula('% = s.{f}'),
|
1937
|
+
# sexo=Where.formula('% = s.{f}'),
|
1938
|
+
# ),
|
1939
|
+
# nome=Field
|
1940
|
+
# )
|
1941
|
+
# )
|
1942
|
+
# )
|
1943
|
+
# query = identifica_suspeitos()
|
1944
|
+
# print('='*50)
|
1945
|
+
# print(query)
|
1946
|
+
# print('-'*50)
|
1947
|
+
script = '''
|
1948
|
+
db.people.find({
|
1949
|
+
{
|
1950
|
+
$or: [
|
1951
|
+
status:{$eq:"B"},
|
1952
|
+
age:{$lt:50}
|
1953
|
+
]
|
1954
|
+
},
|
1955
|
+
age:{$gte:18}, status:{$eq:"A"}
|
1956
|
+
},{
|
1957
|
+
name: 1, user_id: 1
|
1958
|
+
}).sort({
|
1959
|
+
'''
|
1960
|
+
print('='*50)
|
1961
|
+
q1 = Select.parse(script, MongoParser)[0]
|
1962
|
+
print(q1)
|
1963
|
+
print('-'*50)
|
1964
|
+
q2 = q1.translate_to(MongoDBLanguage)
|
1965
|
+
print(q2)
|
1966
|
+
# print('-'*50)
|
1967
|
+
print('='*50)
|