sql-blocks 1.25.112__py3-none-any.whl → 1.25.514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_blocks/sql_blocks.py CHANGED
@@ -42,23 +42,34 @@ class SQLObject:
42
42
  if not table_name:
43
43
  return
44
44
  cls = SQLObject
45
+ is_file_name = any([
46
+ '/' in table_name, '.' in table_name
47
+ ])
48
+ ref = table_name
49
+ if is_file_name:
50
+ ref = table_name.split('/')[-1].split('.')[0]
45
51
  if cls.ALIAS_FUNC:
46
- self.__alias = cls.ALIAS_FUNC(table_name)
52
+ self.__alias = cls.ALIAS_FUNC(ref)
47
53
  elif ' ' in table_name.strip():
48
54
  table_name, self.__alias = table_name.split()
49
- elif '_' in table_name:
55
+ elif '_' in ref:
50
56
  self.__alias = ''.join(
51
57
  word[0].lower()
52
- for word in table_name.split('_')
58
+ for word in ref.split('_')
53
59
  )
54
60
  else:
55
- self.__alias = table_name.lower()[:3]
61
+ self.__alias = ref.lower()[:3]
56
62
  self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
57
63
 
58
64
  @property
59
65
  def table_name(self) -> str:
60
66
  return self.values[FROM][0].split()[0]
61
67
 
68
+ def set_file_format(self, pattern: str):
69
+ if '{' not in pattern:
70
+ pattern = '{}' + pattern
71
+ self.values[FROM][0] = pattern.format(self.aka())
72
+
62
73
  @property
63
74
  def alias(self) -> str:
64
75
  if self.__alias:
@@ -70,6 +81,16 @@ class SQLObject:
70
81
  appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
71
82
  return KEYWORD[key][0].format(appendix.get(key, ''))
72
83
 
84
+ @staticmethod
85
def is_named_field(fld: str, name: str='') -> bool:
    """Tell whether `fld` declares an alias through ``AS``.

    Args:
        fld: A field expression, e.g. ``"SUM(o.price) AS total"``.
        name: Optional alias to look for. When empty, any ``as``/``AS``
            alias counts.

    Returns:
        True when the alias (or any alias, if `name` is empty) is found.
    """
    # Same case handling as before: only all-lower or all-upper "as".
    pattern = r'(\s+as\s+|\s+AS\s+)'
    if name:
        # Escape the name so characters such as "(" or "*" are taken
        # literally, and refuse a longer identifier that merely starts
        # with `name` (e.g. "total" must not match "AS total_price").
        pattern += re.escape(name) + r'(?!\w)'
    return re.search(pattern, fld) is not None
87
+
88
+ def has_named_field(self, name: str) -> bool:
89
+ return any(
90
+ self.is_named_field(fld, name)
91
+ for fld in self.values.get(SELECT, [])
92
+ )
93
+
73
94
  def diff(self, key: str, search_list: list, exact: bool=False) -> set:
74
95
  def disassemble(source: list) -> list:
75
96
  if not exact:
@@ -78,16 +99,17 @@ class SQLObject:
78
99
  for fld in source:
79
100
  result += re.split(r'([=()]|<>|\s+ON\s+|\s+on\s+)', fld)
80
101
  return result
81
- def cleanup(fld: str) -> str:
102
+ def cleanup(text: str) -> str:
103
+ text = re.sub(r'[\n\t]', ' ', text)
82
104
  if exact:
83
- fld = fld.lower()
84
- return fld.strip()
85
- def is_named_field(fld: str) -> bool:
86
- return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
105
+ text = text.lower()
106
+ return text.strip()
87
107
  def field_set(source: list) -> set:
88
108
  return set(
89
109
  (
90
- fld if is_named_field(fld) else
110
+ fld
111
+ if key == SELECT and self.is_named_field(fld, key)
112
+ else
91
113
  re.sub(pattern, '', cleanup(fld))
92
114
  )
93
115
  for string in disassemble(source)
@@ -105,18 +127,22 @@ class SQLObject:
105
127
  return s1.symmetric_difference(s2)
106
128
  return s1 - s2
107
129
 
108
def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
    """Drop from `self.values` every item that refers to `search`.

    Args:
        search: Text (usually a field name) to look for.
        keys: Which entries of `self.values` are filtered.
        exact: When False, an item is dropped if it merely contains
            `search`. When True, the item must END with `search` as a
            whole identifier, optionally preceded by an alias
            (``c.id`` matches ``id``; ``c.user_id`` does not).

    Every key in `keys` ends up present in `self.values`, holding the
    surviving items (possibly an empty list).
    """
    if exact:
        # The look-behind stops a longer name with the same suffix from
        # matching; re.escape keeps regex metacharacters in `search`
        # literal.
        ending = re.compile(fr'(?<!\w){re.escape(search)}$')

        def not_match(item: str) -> bool:
            return ending.search(item) is None
    else:
        def not_match(item: str) -> bool:
            return search not in item

    for key in keys:
        self.values[key] = [
            item for item in self.values.get(key, [])
            if not_match(item)
        ]
115
140
 
116
141
 
117
142
  SQL_CONST_SYSDATE = 'SYSDATE'
118
143
  SQL_CONST_CURR_DATE = 'Current_date'
119
- SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE]
144
+ SQL_ROW_NUM = 'ROWNUM'
145
+ SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
120
146
 
121
147
 
122
148
  class Field:
@@ -133,7 +159,7 @@ class Field:
133
159
  name = name.strip()
134
160
  if name in ('_', '*'):
135
161
  name = '*'
136
- elif not is_const():
162
+ elif not is_const() and not main.has_named_field(name):
137
163
  name = f'{main.alias}.{name}'
138
164
  if Function in cls.__bases__:
139
165
  name = f'{cls.__name__}({name})'
@@ -171,15 +197,35 @@ class Dialect(Enum):
171
197
  POSTGRESQL = 3
172
198
  MYSQL = 4
173
199
 
200
+ SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
201
+ CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
202
+
174
203
  class Function:
175
204
  dialect = Dialect.ANSI
205
+ inputs = None
206
+ output = None
207
+ separator = ', '
208
+ auto_convert = True
209
+ append_param = False
176
210
 
177
211
  def __init__(self, *params: list):
212
+ def set_func_types(param):
213
+ if self.auto_convert and isinstance(param, Function):
214
+ func = param
215
+ main_param = self.inputs[0]
216
+ unfriendly = all([
217
+ func.output != main_param,
218
+ func.output != ANY,
219
+ main_param != ANY
220
+ ])
221
+ if unfriendly:
222
+ return Cast(func, main_param)
223
+ return param
178
224
  # --- Replace class methods by instance methods: ------
179
225
  self.add = self.__add
180
226
  self.format = self.__format
181
227
  # -----------------------------------------------------
182
- self.params = [str(p) for p in params]
228
+ self.params = [set_func_types(p) for p in params]
183
229
  self.field_class = Field
184
230
  self.pattern = self.get_pattern()
185
231
  self.extra = {}
@@ -196,14 +242,35 @@ class Function:
196
242
  def __str__(self) -> str:
197
243
  return self.pattern.format(
198
244
  func_name=self.__class__.__name__,
199
- params=', '.join(self.params)
245
+ params=self.separator.join(str(p) for p in self.params)
200
246
  )
201
247
 
248
+ @classmethod
249
+ def help(cls) -> str:
250
+ descr = ' '.join(B.__name__ for B in cls.__bases__)
251
+ params = cls.inputs or ''
252
+ return cls().get_pattern().format(
253
+ func_name=f'{descr} {cls.__name__}',
254
+ params=cls.separator.join(str(p) for p in params)
255
+ ) + f' Return {cls.output}'
256
+
257
+ def set_main_param(self, name: str, main: SQLObject) -> bool:
258
+ nested_functions = [
259
+ param for param in self.params if isinstance(param, Function)
260
+ ]
261
+ for func in nested_functions:
262
+ if func.inputs:
263
+ func.set_main_param(name, main)
264
+ return
265
+ new_params = [Field.format(name, main)]
266
+ if self.append_param:
267
+ self.params += new_params
268
+ else:
269
+ self.params = new_params + self.params
270
+
202
271
  def __format(self, name: str, main: SQLObject) -> str:
203
272
  if name not in '*_':
204
- self.params = [
205
- Field.format(name, main)
206
- ] + self.params
273
+ self.set_main_param(name, main)
207
274
  return str(self)
208
275
 
209
276
  @classmethod
@@ -223,6 +290,9 @@ class Function:
223
290
 
224
291
  # ---- String Functions: ---------------------------------
225
292
  class SubString(Function):
293
+ inputs = [CHAR, INT, INT]
294
+ output = CHAR
295
+
226
296
  def get_pattern(self) -> str:
227
297
  if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
228
298
  return 'Substr({params})'
@@ -230,31 +300,55 @@ class SubString(Function):
230
300
 
231
301
  # ---- Numeric Functions: --------------------------------
232
302
  class Round(Function):
233
- ...
303
+ inputs = [FLOAT]
304
+ output = FLOAT
234
305
 
235
306
  # --- Date Functions: ------------------------------------
236
307
  class DateDiff(Function):
237
- def get_pattern(self) -> str:
308
+ inputs = [DATE]
309
+ output = DATE
310
+ append_param = True
311
+
312
+ def __str__(self) -> str:
238
313
  def is_field_or_func(name: str) -> bool:
239
- return re.sub('[()]', '', name).isidentifier()
314
+ candidate = re.sub(
315
+ '[()]', '', name.split('.')[-1]
316
+ )
317
+ return candidate.isidentifier()
240
318
  if self.dialect != Dialect.SQL_SERVER:
319
+ params = [str(p) for p in self.params]
241
320
  return ' - '.join(
242
321
  p if is_field_or_func(p) else f"'{p}'"
243
- for p in self.params
322
+ for p in params
244
323
  ) # <==== Date subtract
245
- return super().get_pattern()
324
+ return super().__str__()
325
+
326
+
327
+ class DatePart(Function):
328
+ inputs = [DATE]
329
+ output = INT
246
330
 
247
- class Year(Function):
248
331
  def get_pattern(self) -> str:
332
+ interval = self.__class__.__name__
249
333
  database_type = {
250
- Dialect.ORACLE: 'Extract(YEAR FROM {params})',
251
- Dialect.POSTGRESQL: "Date_Part('year', {params})",
334
+ Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
335
+ Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
252
336
  }
253
337
  if self.dialect in database_type:
254
338
  return database_type[self.dialect]
255
339
  return super().get_pattern()
256
340
 
341
+ class Year(DatePart):
342
+ ...
343
+ class Month(DatePart):
344
+ ...
345
+ class Day(DatePart):
346
+ ...
347
+
348
+
257
349
  class Current_Date(Function):
350
+ output = DATE
351
+
258
352
  def get_pattern(self) -> str:
259
353
  database_type = {
260
354
  Dialect.ORACLE: SQL_CONST_SYSDATE,
@@ -277,14 +371,15 @@ class Frame:
277
371
  keywords = ''
278
372
  for field, obj in args.items():
279
373
  is_valid = any([
280
- obj is class_type # or isinstance(obj, class_type)
281
- for class_type in (OrderBy, Partition)
374
+ obj is OrderBy,
375
+ obj is Partition,
376
+ isinstance(obj, Rows),
282
377
  ])
283
378
  if not is_valid:
284
379
  continue
285
380
  keywords += '{}{} {}'.format(
286
381
  '\n\t\t' if self.break_lines else ' ',
287
- obj.cls_to_str(), field
382
+ obj.cls_to_str(), field if field != '_' else ''
288
383
  )
289
384
  if keywords and self.break_lines:
290
385
  keywords += '\n\t'
@@ -293,7 +388,8 @@ class Frame:
293
388
 
294
389
 
295
390
  class Aggregate(Frame):
296
- ...
391
+ inputs = [FLOAT]
392
+ output = FLOAT
297
393
 
298
394
  class Window(Frame):
299
395
  ...
@@ -312,20 +408,30 @@ class Count(Aggregate, Function):
312
408
 
313
409
  # ---- Window Functions: -----------------------------------
314
410
  class Row_Number(Window, Function):
315
- ...
411
+ output = INT
412
+
316
413
  class Rank(Window, Function):
317
- ...
414
+ output = INT
415
+
318
416
  class Lag(Window, Function):
319
- ...
417
+ output = ANY
418
+
320
419
  class Lead(Window, Function):
321
- ...
420
+ output = ANY
322
421
 
323
422
 
324
423
  # ---- Conversions and other Functions: ---------------------
325
424
  class Coalesce(Function):
326
- ...
425
+ inputs = [ANY]
426
+ output = ANY
427
+
327
428
  class Cast(Function):
328
- ...
429
+ inputs = [ANY]
430
+ output = ANY
431
+ separator = ' As '
432
+
433
+
434
+ FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
329
435
 
330
436
 
331
437
  class ExpressionField:
@@ -350,15 +456,20 @@ class ExpressionField:
350
456
  class FieldList:
351
457
  separator = ','
352
458
 
353
- def __init__(self, fields: list=[], class_types = [Field]):
459
+ def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
354
460
  if isinstance(fields, str):
355
461
  fields = [
356
462
  f.strip() for f in fields.split(self.separator)
357
463
  ]
358
464
  self.fields = fields
359
465
  self.class_types = class_types
466
+ self.ziped = ziped
360
467
 
361
468
  def add(self, name: str, main: SQLObject):
469
+ if self.ziped: # --- One class per field...
470
+ for field, class_type in zip(self.fields, self.class_types):
471
+ class_type.add(field, main)
472
+ return
362
473
  for field in self.fields:
363
474
  for class_type in self.class_types:
364
475
  class_type.add(field, main)
@@ -400,36 +511,40 @@ class ForeignKey:
400
511
 
401
512
def quoted(value) -> str:
    """Render `value` as SQL literal text.

    Strings are wrapped in single quotes; any other value is returned
    through ``str()`` unchanged.

    Raises:
        PermissionError: If a string contains the standalone word ``or``
            (any case), which is treated as a possible injection attempt.
    """
    if isinstance(value, str):
        # Coarse guard kept as-is: callers may rely on this exception.
        if re.search(r'\bor\b', value, re.IGNORECASE):
            raise PermissionError('Possible SQL injection attempt')
        # Double embedded single quotes so the literal cannot be closed
        # early (e.g. O'Brien -> 'O''Brien').
        value = "'{}'".format(value.replace("'", "''"))
    return str(value)
405
518
 
406
519
 
407
520
  class Position(Enum):
521
+ StartsWith = -1
408
522
  Middle = 0
409
- StartsWith = 1
410
- EndsWith = 2
523
+ EndsWith = 1
411
524
 
412
525
 
413
526
  class Where:
414
527
  prefix = ''
415
528
 
416
- def __init__(self, expr: str):
417
- self.expr = expr
529
+ def __init__(self, content: str):
530
+ self.content = content
418
531
 
419
532
  @classmethod
420
533
  def __constructor(cls, operator: str, value):
421
- return cls(expr=f'{operator} {quoted(value)}')
534
+ return cls(f'{operator} {quoted(value)}')
422
535
 
423
536
  @classmethod
424
537
  def eq(cls, value):
425
538
  return cls.__constructor('=', value)
426
539
 
427
540
  @classmethod
428
- def contains(cls, content: str, pos: Position = Position.Middle):
541
+ def contains(cls, text: str, pos: int | Position = Position.Middle):
542
+ if isinstance(pos, int):
543
+ pos = Position(pos)
429
544
  return cls(
430
545
  "LIKE '{}{}{}'".format(
431
546
  '%' if pos != Position.StartsWith else '',
432
- content,
547
+ text,
433
548
  '%' if pos != Position.EndsWith else ''
434
549
  )
435
550
  )
@@ -460,9 +575,43 @@ class Where:
460
575
  values = ','.join(quoted(v) for v in values)
461
576
  return cls(f'IN ({values})')
462
577
 
578
+ @classmethod
579
+ def formula(cls, formula: str):
580
+ where = cls( ExpressionField(formula) )
581
+ where.add = where.add_expression
582
+ return where
583
+
584
+ def add_expression(self, name: str, main: SQLObject):
585
+ self.content = self.content.format(name, main)
586
+ main.values.setdefault(WHERE, []).append('{} {}'.format(
587
+ self.prefix, self.content
588
+ ))
589
+
590
+ @classmethod
591
+ def join(cls, query: SQLObject):
592
+ where = cls(query)
593
+ where.add = where.add_join
594
+ return where
595
+
596
+ def add_join(self, name: str, main: SQLObject):
597
+ query = self.content
598
+ main.values[FROM].append(f',{query.table_name} {query.alias}')
599
+ for key in USUAL_KEYS:
600
+ main.update_values(key, query.values.get(key, []))
601
+ if query.key_field:
602
+ main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
603
+ a1=main.alias, f1=name,
604
+ a2=query.alias, f2=query.key_field
605
+ ))
606
+
463
607
  def add(self, name: str, main: SQLObject):
608
+ func_type = FUNCTION_CLASS.get(name.lower())
609
+ if func_type:
610
+ name = func_type.format('*', main)
611
+ elif not main.has_named_field(name):
612
+ name = Field.format(name, main)
464
613
  main.values.setdefault(WHERE, []).append('{}{} {}'.format(
465
- self.prefix, Field.format(name, main), self.expr
614
+ self.prefix, name, self.content
466
615
  ))
467
616
 
468
617
 
@@ -470,6 +619,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
470
619
  getattr(Where, method) for method in
471
620
  ('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
472
621
  )
622
+ startswith, endswith = [
623
+ lambda x: contains(x, Position.StartsWith),
624
+ lambda x: contains(x, Position.EndsWith)
625
+ ]
473
626
 
474
627
 
475
628
  class Not(Where):
@@ -477,7 +630,7 @@ class Not(Where):
477
630
 
478
631
  @classmethod
479
632
  def eq(cls, value):
480
- return Where(expr=f'<> {quoted(value)}')
633
+ return Where(f'<> {quoted(value)}')
481
634
 
482
635
 
483
636
  class Case:
@@ -486,22 +639,26 @@ class Case:
486
639
  self.default = None
487
640
  self.field = field
488
641
 
489
- def when(self, condition: Where, result: str):
642
+ def when(self, condition: Where, result):
643
+ if isinstance(result, str):
644
+ result = quoted(result)
490
645
  self.__conditions[result] = condition
491
646
  return self
492
647
 
493
- def else_value(self, default: str):
648
+ def else_value(self, default):
649
+ if isinstance(default, str):
650
+ default = quoted(default)
494
651
  self.default = default
495
652
  return self
496
653
 
497
654
  def add(self, name: str, main: SQLObject):
498
655
  field = Field.format(self.field, main)
499
- default = quoted(self.default)
656
+ default = self.default
500
657
  name = 'CASE \n{}\n\tEND AS {}'.format(
501
658
  '\n'.join(
502
- f'\t\tWHEN {field} {cond.expr} THEN {quoted(res)}'
659
+ f'\t\tWHEN {field} {cond.content} THEN {res}'
503
660
  for res, cond in self.__conditions.items()
504
- ) + f'\n\t\tELSE {default}' if default else '',
661
+ ) + (f'\n\t\tELSE {default}' if default else ''),
505
662
  name
506
663
  )
507
664
  main.values.setdefault(SELECT, []).append(name)
@@ -512,42 +669,69 @@ class Options:
512
669
  self.__children: dict = values
513
670
 
514
671
  def add(self, logical_separator: str, main: SQLObject):
515
- if logical_separator not in ('AND', 'OR'):
672
+ if logical_separator.upper() not in ('AND', 'OR'):
516
673
  raise ValueError('`logical_separator` must be AND or OR')
517
- conditions: list[str] = []
674
+ temp = Select(f'{main.table_name} {main.alias}')
518
675
  child: Where
519
676
  for field, child in self.__children.items():
520
- conditions.append(' {} {} '.format(
521
- Field.format(field, main), child.expr
522
- ))
677
+ child.add(field, temp)
523
678
  main.values.setdefault(WHERE, []).append(
524
- '(' + logical_separator.join(conditions) + ')'
679
+ '(' + f'\n\t{logical_separator} '.join(temp.values[WHERE]) + ')'
525
680
  )
526
681
 
527
682
 
528
683
  class Between:
684
+ is_literal: bool = False
685
+
529
686
  def __init__(self, start, end):
530
687
  if start > end:
531
688
  start, end = end, start
532
689
  self.start = start
533
690
  self.end = end
534
691
 
692
+ def literal(self) -> Where:
693
+ return Where('BETWEEN {} AND {}'.format(
694
+ self.start, self.end
695
+ ))
696
+
535
697
  def add(self, name: str, main:SQLObject):
536
- Where.gte(self.start).add(name, main),
698
+ if self.is_literal:
699
+ return self.literal().add(name, main)
700
+ Where.gte(self.start).add(name, main)
537
701
  Where.lte(self.end).add(name, main)
538
702
 
703
+ class SameDay(Between):
704
+ def __init__(self, date: str):
705
+ super().__init__(
706
+ f'{date} 00:00:00',
707
+ f'{date} 23:59:59',
708
+ )
709
+
710
+
711
+ class Range(Case):
712
+ INC_FUNCTION = lambda x: x + 1
713
+
714
+ def __init__(self, field: str, values: dict):
715
+ super().__init__(field)
716
+ start = 0
717
+ cls = self.__class__
718
+ for label, value in sorted(values.items(), key=lambda item: item[1]):
719
+ self.when(
720
+ Between(start, value).literal(), label
721
+ )
722
+ start = cls.INC_FUNCTION(value)
723
+
539
724
 
540
725
  class Clause:
541
726
  @classmethod
542
727
  def format(cls, name: str, main: SQLObject) -> str:
543
728
  def is_function() -> bool:
544
729
  diff = main.diff(SELECT, [name.lower()], True)
545
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
546
730
  return diff.intersection(FUNCTION_CLASS)
547
731
  found = re.findall(r'^_\d', name)
548
732
  if found:
549
733
  name = found[0].replace('_', '')
550
- elif main.alias and not is_function():
734
+ elif '.' not in name and main.alias and not is_function():
551
735
  name = f'{main.alias}.{name}'
552
736
  return name
553
737
 
@@ -556,6 +740,34 @@ class SortType(Enum):
556
740
  ASC = ''
557
741
  DESC = ' DESC'
558
742
 
743
+ class Row:
744
+ def __init__(self, value: int=0):
745
+ self.value = value
746
+
747
+ def __str__(self) -> str:
748
+ return '{} {}'.format(
749
+ 'UNBOUNDED' if self.value == 0 else self.value,
750
+ self.__class__.__name__.upper()
751
+ )
752
+
753
+ class Preceding(Row):
754
+ ...
755
+ class Following(Row):
756
+ ...
757
+ class Current(Row):
758
+ def __str__(self) -> str:
759
+ return 'CURRENT ROW'
760
+
761
+ class Rows:
762
+ def __init__(self, *rows: list[Row]):
763
+ self.rows = rows
764
+
765
+ def cls_to_str(self) -> str:
766
+ return 'ROWS {}{}'.format(
767
+ 'BETWEEN ' if len(self.rows) > 1 else '',
768
+ ' AND '.join(str(row) for row in self.rows)
769
+ )
770
+
559
771
 
560
772
  class OrderBy(Clause):
561
773
  sort: SortType = SortType.ASC
@@ -590,7 +802,7 @@ class Having:
590
802
 
591
803
  def add(self, name: str, main:SQLObject):
592
804
  main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
593
- self.function.format(name, main), self.condition.expr
805
+ self.function.format(name, main), self.condition.content
594
806
  )
595
807
 
596
808
  @classmethod
@@ -620,12 +832,20 @@ class Rule:
620
832
  ...
621
833
 
622
834
  class QueryLanguage:
623
- pattern = '{select}{_from}{where}{group_by}{order_by}'
835
+ pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
624
836
  has_default = {key: bool(key == SELECT) for key in KEYWORD}
625
837
 
626
838
  @staticmethod
627
def remove_alias(text: str) -> str:
    """Strip ``alias.`` prefixes from the field part of `text`.

    Newlines and tabs are first replaced by spaces. If `text` contains a
    colon, only the part before the FIRST colon has its prefixes removed;
    the remainder is kept, with single quotes turned into double quotes.

    Args:
        text: A field or ``field:value`` expression.

    Returns:
        The expression without table-alias prefixes.
    """
    value, sep = '', ''
    text = re.sub('[\n\t]', ' ', text)
    if ':' in text:
        text, value = text.split(':', maxsplit=1)
        sep = ':'
    # A prefix must start with a letter or underscore, so the integer
    # part of a decimal constant such as 10.5 is no longer removed.
    without_alias = re.sub(r'\b[^\W\d]\w*[.]', '', text)
    return '{}{}{}'.format(
        without_alias, sep, value.replace("'", '"')
    )
629
849
 
630
850
  def join_with_tabs(self, values: list, sep: str='') -> str:
631
851
  sep = sep + self.TABULATION
@@ -643,18 +863,21 @@ class QueryLanguage:
643
863
  return self.join_with_tabs(values, ' AND ')
644
864
 
645
865
  def sort_by(self, values: list) -> str:
646
- return self.join_with_tabs(values)
866
+ return self.join_with_tabs(values, ',')
647
867
 
648
868
  def set_group(self, values: list) -> str:
649
869
  return self.join_with_tabs(values, ',')
650
870
 
871
+ def set_limit(self, values: list) -> str:
872
+ return self.join_with_tabs(values, ' ')
873
+
651
874
  def __init__(self, target: 'Select'):
652
- self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
875
+ self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
653
876
  self.TABULATION = '\n\t' if target.break_lines else ' '
654
877
  self.LINE_BREAK = '\n' if target.break_lines else ' '
655
878
  self.TOKEN_METHODS = {
656
879
  SELECT: self.add_field, FROM: self.get_tables,
657
- WHERE: self.extract_conditions,
880
+ WHERE: self.extract_conditions, LIMIT: self.set_limit,
658
881
  ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
659
882
  }
660
883
  self.result = {}
@@ -690,7 +913,8 @@ class MongoDBLanguage(QueryLanguage):
690
913
  LOGICAL_OP_TO_MONGO_FUNC = {
691
914
  '>': '$gt', '>=': '$gte',
692
915
  '<': '$lt', '<=': '$lte',
693
- '=': '$eq', '<>': '$ne',
916
+ '=': '$eq', '<>': '$ne',
917
+ 'like': '$regex', 'LIKE': '$regex',
694
918
  }
695
919
  OPERATORS = '|'.join(op for op in LOGICAL_OP_TO_MONGO_FUNC)
696
920
  REGEX = {
@@ -743,7 +967,7 @@ class MongoDBLanguage(QueryLanguage):
743
967
  field, *op, const = tokens
744
968
  op = ''.join(op)
745
969
  expr = '{begin}{op}:{const}{end}'.format(
746
- begin='{', const=const, end='}',
970
+ begin='{', const=const.replace('%', '.*'), end='}',
747
971
  op=cls.LOGICAL_OP_TO_MONGO_FUNC[op],
748
972
  )
749
973
  where_list.append(f'{field}:{expr}')
@@ -852,6 +1076,55 @@ class Neo4JLanguage(QueryLanguage):
852
1076
  return ''
853
1077
 
854
1078
 
1079
class DatabricksLanguage(QueryLanguage):
    """Query renderer that emits the FROM clause first and joins the
    remaining clauses with the ``|>`` pipe operator.

    Aggregate calls found in the SELECT list are set aside and emitted
    as ``AGGREGATE ...`` in front of the GROUP BY clause.
    """
    pattern = '{_from}{where}{group_by}{order_by}{select}{limit}'
    has_default = {key: bool(key == SELECT) for key in KEYWORD}

    def __init__(self, target: 'Select'):
        super().__init__(target)
        # Aggregate expressions removed from SELECT by `add_field`.
        self.aggregation_fields = []

    def add_field(self, values: list) -> str:
        """Split aggregate calls out of `values` and render the rest
        through the base class."""
        agg_names = '|'.join(
            re.escape(cls.__name__) for cls in Aggregate.__subclasses__()
        )
        # The word boundary keeps e.g. "Discount(" from being taken for
        # "Count(". With no aggregate class there is nothing to detect.
        agg_call = re.compile(
            fr'\b({agg_names})[(]', re.IGNORECASE
        ) if agg_names else None
        # --------------------------------------------------------------
        def is_agg_field(fld: str) -> bool:
            return agg_call is not None and agg_call.search(fld) is not None
        # --------------------------------------------------------------
        new_values = []
        for val in values:
            if is_agg_field(val):
                self.aggregation_fields.append(val)
            else:
                new_values.append(val)
        return super().add_field(new_values)

    def prefix(self, key: str) -> str:
        """Build the text placed before clause `key`: the pipe operator
        (except for FROM), a line break, the AGGREGATE list (GROUP BY
        only), the keyword and the tabulation."""
        def get_aggregate() -> str:
            return 'AGGREGATE {} '.format(
                ','.join(self.aggregation_fields)
            )
        return '{}{}{}{}{}'.format(
            '|> ' if key != FROM else '',
            self.LINE_BREAK,
            get_aggregate() if key == GROUP_BY else '',
            key, self.TABULATION
        )

    def set_group(self, values: list) -> str:
        """Join the GROUP BY fields with commas."""
        return self.join_with_tabs(values, ',')
1125
+
1126
+
1127
+
855
1128
  class Parser:
856
1129
  REGEX = {}
857
1130
 
@@ -958,10 +1231,13 @@ class SQLParser(Parser):
958
1231
  if not key in values:
959
1232
  continue
960
1233
  separator = self.class_type.get_separator(key)
1234
+ cls = {
1235
+ ORDER_BY: OrderBy, GROUP_BY: GroupBy
1236
+ }.get(key, Field)
961
1237
  obj.values[key] = [
962
- Field.format(fld, obj)
1238
+ cls.format(fld, obj)
963
1239
  for fld in re.split(separator, values[key])
964
- if (fld != '*' and len(tables) == 1) or obj.match(fld)
1240
+ if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
965
1241
  ]
966
1242
  result[obj.alias] = obj
967
1243
  self.queries = list( result.values() )
@@ -1021,16 +1297,26 @@ class CypherParser(Parser):
1021
1297
  if token in self.TOKEN_METHODS:
1022
1298
  return
1023
1299
  class_list = [Field]
1024
- if '$' in token:
1300
+ if '*' in token:
1301
+ token = token.replace('*', '')
1302
+ self.queries[-1].key_field = token
1303
+ return
1304
+ elif '$' in token:
1025
1305
  func_name, token = token.split('$')
1026
1306
  if func_name == 'count':
1027
1307
  if not token:
1028
1308
  token = 'count_1'
1029
- NamedField(token, Count).add('*', self.queries[-1])
1030
- class_list = []
1309
+ pk_field = self.queries[-1].key_field or 'id'
1310
+ Count().As(token, extra_classes).add(pk_field, self.queries[-1])
1311
+ return
1031
1312
  else:
1032
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
1033
- class_list = [ FUNCTION_CLASS[func_name] ]
1313
+ class_type = FUNCTION_CLASS.get(func_name)
1314
+ if not class_type:
1315
+ raise ValueError(f'Unknown function `{func_name}`.')
1316
+ if ':' in token:
1317
+ token, field_alias = token.split(':')
1318
+ class_type = class_type().As(field_alias)
1319
+ class_list = [class_type]
1034
1320
  class_list += extra_classes
1035
1321
  FieldList(token, class_list).add('', self.queries[-1])
1036
1322
 
@@ -1045,10 +1331,13 @@ class CypherParser(Parser):
1045
1331
  def add_foreign_key(self, token: str, pk_field: str=''):
1046
1332
  curr, last = [self.queries[i] for i in (-1, -2)]
1047
1333
  if not pk_field:
1048
- if not last.values.get(SELECT):
1049
- raise IndexError(f'Primary Key not found for {last.table_name}.')
1050
- pk_field = last.values[SELECT][-1].split('.')[-1]
1051
- last.delete(pk_field, [SELECT])
1334
+ if last.key_field:
1335
+ pk_field = last.key_field
1336
+ else:
1337
+ if not last.values.get(SELECT):
1338
+ raise IndexError(f'Primary Key not found for {last.table_name}.')
1339
+ pk_field = last.values[SELECT][-1].split('.')[-1]
1340
+ last.delete(pk_field, [SELECT], exact=True)
1052
1341
  if '{}' in token:
1053
1342
  foreign_fld = token.format(
1054
1343
  last.table_name.lower()
@@ -1063,12 +1352,11 @@ class CypherParser(Parser):
1063
1352
  if fld not in curr.values.get(GROUP_BY, [])
1064
1353
  ]
1065
1354
  foreign_fld = fields[0].split('.')[-1]
1066
- curr.delete(foreign_fld, [SELECT])
1355
+ curr.delete(foreign_fld, [SELECT], exact=True)
1067
1356
  if curr.join_type == JoinType.RIGHT:
1068
1357
  pk_field, foreign_fld = foreign_fld, pk_field
1069
1358
  if curr.join_type == JoinType.RIGHT:
1070
1359
  curr, last = last, curr
1071
- # pk_field, foreign_fld = foreign_fld, pk_field
1072
1360
  k = ForeignKey.get_key(curr, last)
1073
1361
  ForeignKey.references[k] = (foreign_fld, pk_field)
1074
1362
 
@@ -1192,7 +1480,18 @@ class MongoParser(Parser):
1192
1480
 
1193
1481
  def begin_conditions(self, value: str):
1194
1482
  self.where_list = {}
1483
+ self.field_method = self.first_ORfield
1195
1484
  return Where
1485
+
1486
+ def first_ORfield(self, text: str):
1487
+ if text.startswith('$'):
1488
+ return
1489
+ found = re.search(r'\w+[:]', text)
1490
+ if not found:
1491
+ return
1492
+ self.field_method = None
1493
+ p1, p2 = found.span()
1494
+ self.last_field = text[p1: p2-1]
1196
1495
 
1197
1496
  def increment_brackets(self, value: str):
1198
1497
  self.brackets[value] += 1
@@ -1201,6 +1500,7 @@ class MongoParser(Parser):
1201
1500
  self.method = self.new_query
1202
1501
  self.last_field = ''
1203
1502
  self.where_list = None
1503
+ self.field_method = None
1204
1504
  self.PARAM_BY_FUNCTION = {
1205
1505
  'find': Where, 'aggregate': GroupBy, 'sort': OrderBy
1206
1506
  }
@@ -1230,13 +1530,14 @@ class MongoParser(Parser):
1230
1530
  self.close_brackets(
1231
1531
  BRACKET_PAIR[token]
1232
1532
  )
1533
+ elif self.field_method:
1534
+ self.field_method(token)
1233
1535
  self.method = self.TOKEN_METHODS.get(token)
1234
1536
  # ----------------------------
1235
1537
 
1236
1538
 
1237
1539
  class Select(SQLObject):
1238
1540
  join_type: JoinType = JoinType.INNER
1239
- REGEX = {}
1240
1541
  EQUIVALENT_NAMES = {}
1241
1542
 
1242
1543
  def __init__(self, table_name: str='', **values):
@@ -1254,21 +1555,30 @@ class Select(SQLObject):
1254
1555
 
1255
1556
  def add(self, name: str, main: SQLObject):
1256
1557
  old_tables = main.values.get(FROM, [])
1257
- new_tables = set([
1258
- '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1558
+ if len(self.values[FROM]) > 1:
1559
+ old_tables += self.values[FROM][1:]
1560
+ new_tables = []
1561
+ row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1259
1562
  jt=self.join_type.value,
1260
1563
  tb=self.aka(),
1261
1564
  a1=main.alias, f1=name,
1262
1565
  a2=self.alias, f2=self.key_field
1263
1566
  )
1264
- ] + old_tables[1:])
1265
- main.values[FROM] = old_tables[:1] + list(new_tables)
1567
+ if row not in old_tables[1:]:
1568
+ new_tables.append(row)
1569
+ main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
1266
1570
  for key in USUAL_KEYS:
1267
1571
  main.update_values(key, self.values.get(key, []))
1268
1572
 
1269
- def __add__(self, other: SQLObject):
1573
+ def copy(self) -> SQLObject:
1270
1574
  from copy import deepcopy
1271
- query = deepcopy(self)
1575
+ return deepcopy(self)
1576
+
1577
+ def no_relation_error(self, other: SQLObject):
1578
+ raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
1579
+
1580
+ def __add__(self, other: SQLObject):
1581
+ query = self.copy()
1272
1582
  if query.table_name.lower() == other.table_name.lower():
1273
1583
  for key in USUAL_KEYS:
1274
1584
  query.update_values(key, other.values.get(key, []))
@@ -1281,7 +1591,7 @@ class Select(SQLObject):
1281
1591
  PrimaryKey.add(primary_key, query)
1282
1592
  query.add(foreign_field, other)
1283
1593
  return other
1284
- raise ValueError(f'No relationship found between {query.table_name} and {other.table_name}.')
1594
+ self.no_relation_error(other) # === raise ERROR ... ===
1285
1595
  elif primary_key:
1286
1596
  PrimaryKey.add(primary_key, other)
1287
1597
  other.add(foreign_field, query)
@@ -1301,16 +1611,48 @@ class Select(SQLObject):
1301
1611
  if self.diff(key, other.values.get(key, []), True):
1302
1612
  return False
1303
1613
  return True
1614
+
1615
def __sub__(self, other: SQLObject) -> SQLObject:
    '''
    Combine two related queries through a NotSelectIN sub-query.

    The foreign-key relationship is looked up first as (self, other) and,
    failing that, as (other, self). Both operands are copied, so neither
    original is modified. The copy on the foreign-key side of the lookup is
    re-classed as NotSelectIN, receives the foreign-key field, and is then
    attached to the other copy through the primary key.

    Returns the copy that received the sub-query.
    Raises ValueError (via `no_relation_error`) when no relationship is
    found in either direction.
    '''
    pair = (self, other)
    fk_field, primary_k = ForeignKey.find(*pair)
    if not fk_field:
        # Try the relationship the other way round.
        pair = (other, self)
        fk_field, primary_k = ForeignKey.find(*pair)
        if not fk_field:
            self.no_relation_error(other)  # === raise ERROR ... ===
    subquery, target = [obj.copy() for obj in pair]
    subquery.__class__ = NotSelectIN
    Field.add(fk_field, subquery)
    subquery.add(primary_k, target)
    return target
1304
1630
 
1305
1631
  def limit(self, row_count: int=100, offset: int=0):
1306
- result = [str(row_count)]
1307
- if offset > 0:
1308
- result.append(f'OFFSET {offset}')
1309
- self.values.setdefault(LIMIT, result)
1632
+ if Function.dialect == Dialect.SQL_SERVER:
1633
+ fields = self.values.get(SELECT)
1634
+ if fields:
1635
+ fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
1636
+ else:
1637
+ self.values[SELECT] = [f'SELECT TOP({row_count}) *']
1638
+ return self
1639
+ if Function.dialect == Dialect.ORACLE:
1640
+ Where.gte(row_count).add(SQL_ROW_NUM, self)
1641
+ if offset > 0:
1642
+ Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
1643
+ return self
1644
+ self.values[LIMIT] = ['{}{}'.format(
1645
+ row_count, f' OFFSET {offset}' if offset > 0 else ''
1646
+ )]
1310
1647
  return self
1311
1648
 
1312
- def match(self, expr: str) -> bool:
1313
- return re.findall(f'\b*{self.alias}[.]', expr) != []
1649
def match(self, field: str, key: str) -> bool:
    '''
    Recognizes if the field is from the current table.

    field: field expression, e.g. `p.name` or `COUNT(p.id)`.
    key: clause the field belongs to.

    For ORDER BY / GROUP BY entries without a dot, the decision is delegated
    to `has_named_field`. Otherwise the field matches when it contains this
    table's alias, as a whole word, followed by a dot.
    '''
    if '.' not in field and key in (ORDER_BY, GROUP_BY):
        return self.has_named_field(field)
    # Raw string is required: in a normal string '\b' is a backspace, so the
    # old pattern degraded to a plain substring test and alias `p` wrongly
    # matched `sp.name`.
    pattern = rf'\b{re.escape(self.alias)}[.]'
    return re.search(pattern, field) is not None
1314
1656
 
1315
1657
  @classmethod
1316
1658
  def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
@@ -1322,12 +1664,10 @@ class Select(SQLObject):
1322
1664
  for rule in rules:
1323
1665
  rule.apply(self)
1324
1666
 
1325
- def add_fields(self, fields: list, order_by: bool=False, group_by:bool=False):
1326
- class_types = [Field]
1327
- if order_by:
1328
- class_types += [OrderBy]
1329
- if group_by:
1330
- class_types += [GroupBy]
1667
def add_fields(self, fields: list, class_types=None):
    '''
    Add several fields to this query in one call.

    fields: the fields, passed through to FieldList unchanged.
    class_types: optional extra classes to apply to every field; `Field`
        is always appended after them.
    '''
    # Build a new list instead of `+=`, which would grow the caller's list
    # in place and pile up `Field` entries across repeated calls.
    class_types = [*(class_types or []), Field]
    FieldList(fields, class_types).add('', self)
1332
1672
 
1333
1673
  def translate_to(self, language: QueryLanguage) -> str:
@@ -1347,6 +1687,95 @@ class NotSelectIN(SelectIN):
1347
1687
  condition_class = Not
1348
1688
 
1349
1689
 
1690
class CTE(Select):
    '''
    A Select rendered behind a `WITH <name> AS (...)` header whose body is
    the UNION ALL of the queries in `query_list`.
    '''
    # Text placed between `WITH ` and the table name.
    prefix = ''

    def __init__(self, table_name: str, query_list: list[Select]):
        '''
        table_name: name given to the common table expression.
        query_list: queries forming its body; each one has `break_lines`
            switched off.
        '''
        super().__init__(table_name)
        for member in query_list:
            member.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        '''
        Render the WITH header, the wrapped member queries joined by
        UNION ALL, and finally this object's own Select text.
        '''
        # Switch to multi-line output when the single-line clause values
        # add up to more than 70 characters.
        total = sum(
            len(text)
            for key in USUAL_KEYS
            for text in self.values.get(key, [])
            if '\n' not in text
        )
        if total > 70:
            self.break_lines = True

        def wrap(query: Select) -> str:
            # Split on SQL keywords / AND / OR / commas and start a new line
            # whenever the current one has reached 50 characters.
            keywords = '|'.join(KEYWORD)
            chunks, current = [], ''
            for piece in re.split(fr'({keywords}|AND|OR|,)', str(query)):
                if len(current) >= 50:
                    chunks.append(current)
                    current = piece
                else:
                    current += piece
            if current:
                chunks.append(current)
            return '\n '.join(chunks)

        template = 'WITH {}{} AS (\n {}\n){}'
        return template.format(
            self.prefix,
            self.table_name,
            '\nUNION ALL\n '.join(map(wrap, self.query_list)),
            super().__str__(),
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        '''
        Parse `pattern` once per field and add the results to this CTE.

        pattern: query text understood by `detect`; it is repeated once for
            every field.
        fields: list of fields, or a comma-separated string of them.
        format: forwarded to `detect`.

        Returns self, so calls can be chained.
        '''
        count = len(fields.split(',')) if isinstance(fields, str) else len(fields)
        queries = detect(pattern * count, join_queries=False, format=format)
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
1737
+
1738
class Recursive(CTE):
    '''
    A CTE rendered as `WITH RECURSIVE ...`, whose last member query also
    selects from the CTE itself.
    '''
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        '''
        Render the recursive CTE.

        When there is more than one member query, the last one gets
        `, <table_name> <alias>` added to its FROM list. The entry is added
        only once, so calling str() repeatedly does not duplicate it.
        '''
        if len(self.query_list) > 1:
            from_items = self.query_list[-1].values[FROM]
            self_reference = f', {self.table_name} {self.alias}'
            if self_reference not in from_items:
                from_items.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        '''
        Build a recursive CTE from one query pattern.

        name: name of the CTE.
        pattern: query text understood by `detect`; it is parsed twice,
            giving the first query (t1) and the second query (t2).
        formula: condition for t2. Each `[n]` refers to the n-th (1-based)
            selected field of t2. The first placeholder becomes `%` and its
            field is the one the condition is attached to; later
            placeholders are replaced by their field names.
        init_value: value the first selected field of t1 must equal.
        format: forwarded to `detect`.

        Side effect: resets `SQLObject.ALIAS_FUNC` to None.
        '''
        SQLObject.ALIAS_FUNC = None

        def get_field(obj: SQLObject, pos: int) -> str:
            # Field name without its alias prefix.
            return obj.values[SELECT][pos].split('.')[-1]

        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for token in re.findall(r'\[(\d+)\]', formula):
            # Replace the text exactly as written, so `[01]` is handled too.
            placeholder = f'[{token}]'
            column = get_field(t2, int(token) - 1)
            if not foreign_key:
                foreign_key = column
                formula = formula.replace(placeholder, '%')
            else:
                formula = formula.replace(placeholder, column)
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        '''
        Add a counter column called `name` to every member query.

        The first query selects `start AS name`; every later query selects
        `(name<increment>) AS name`.

        Returns self, so calls can be chained.
        '''
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
1775
+
1776
+
1777
+ # ----- Rules -----
1778
+
1350
1779
  class RulePutLimit(Rule):
1351
1780
  @classmethod
1352
1781
  def apply(cls, target: Select):
@@ -1410,6 +1839,8 @@ class RuleDateFuncReplace(Rule):
1410
1839
  @classmethod
1411
1840
  def apply(cls, target: Select):
1412
1841
  for i, condition in enumerate(target.values.get(WHERE, [])):
1842
+ if not '(' in condition:
1843
+ continue
1413
1844
  tokens = [
1414
1845
  t.strip() for t in cls.REGEX.split(condition) if t.strip()
1415
1846
  ]
@@ -1431,12 +1862,13 @@ class RuleReplaceJoinBySubselect(Rule):
1431
1862
  more_relations = any([
1432
1863
  ref[0] == query.table_name for ref in ForeignKey.references
1433
1864
  ])
1434
- invalid = any([
1865
+ keep_join = any([
1435
1866
  len( query.values.get(SELECT, []) ) > 0,
1436
1867
  len( query.values.get(WHERE, []) ) == 0,
1437
1868
  not fk_field, more_relations
1438
1869
  ])
1439
- if invalid:
1870
+ if keep_join:
1871
+ query.add(fk_field, main)
1440
1872
  continue
1441
1873
  query.__class__ = SubSelect
1442
1874
  Field.add(primary_k, query)
@@ -1460,7 +1892,7 @@ def parser_class(text: str) -> Parser:
1460
1892
  return None
1461
1893
 
1462
1894
 
1463
- def detect(text: str) -> Select:
1895
+ def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
1464
1896
  from collections import Counter
1465
1897
  parser = parser_class(text)
1466
1898
  if not parser:
@@ -1471,14 +1903,65 @@ def detect(text: str) -> Select:
1471
1903
  continue
1472
1904
  pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
1473
1905
  for begin, end in pos[::-1]:
1474
- new_name = f'{table}_{count}' # See set_table (line 45)
1906
+ new_name = f'{table}_{count}' # See set_table (line 55)
1475
1907
  Select.EQUIVALENT_NAMES[new_name] = table
1476
1908
  text = text[:begin] + new_name + '(' + text[end:]
1477
1909
  count -= 1
1478
1910
  query_list = Select.parse(text, parser)
1911
+ if format:
1912
+ for query in query_list:
1913
+ query.set_file_format(format)
1914
+ if not join_queries:
1915
+ return query_list
1479
1916
  result = query_list[0]
1480
1917
  for query in query_list[1:]:
1481
1918
  result += query
1482
1919
  return result
1483
-
1484
-
1920
+ # ===========================================================================================//
1921
+
1922
+
1923
+ if __name__ == "__main__":
1924
+ # def identifica_suspeitos() -> Select:
1925
+ # """Mostra quais pessoas têm características iguais à descrição do suspeito"""
1926
+ # Select.join_type = JoinType.LEFT
1927
+ # return Select(
1928
+ # 'Suspeito s', id=Field,
1929
+ # _=Where.join(
1930
+ # Select('Pessoa p',
1931
+ # OR=Options(
1932
+ # pessoa=Where('= s.id'),
1933
+ # altura=Where.formula('ABS(% - s.{f}) < 0.5'),
1934
+ # peso=Where.formula('ABS(% - s.{f}) < 0.5'),
1935
+ # cabelo=Where.formula('% = s.{f}'),
1936
+ # olhos=Where.formula('% = s.{f}'),
1937
+ # sexo=Where.formula('% = s.{f}'),
1938
+ # ),
1939
+ # nome=Field
1940
+ # )
1941
+ # )
1942
+ # )
1943
+ # query = identifica_suspeitos()
1944
+ # print('='*50)
1945
+ # print(query)
1946
+ # print('-'*50)
1947
+ script = '''
1948
+ db.people.find({
1949
+ {
1950
+ $or: [
1951
+ status:{$eq:"B"},
1952
+ age:{$lt:50}
1953
+ ]
1954
+ },
1955
+ age:{$gte:18}, status:{$eq:"A"}
1956
+ },{
1957
+ name: 1, user_id: 1
1958
+ }).sort({
1959
+ '''
1960
+ print('='*50)
1961
+ q1 = Select.parse(script, MongoParser)[0]
1962
+ print(q1)
1963
+ print('-'*50)
1964
+ q2 = q1.translate_to(MongoDBLanguage)
1965
+ print(q2)
1966
+ # print('-'*50)
1967
+ print('='*50)