sql-blocks 1.25.113__py3-none-any.whl → 1.25.514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_blocks/sql_blocks.py CHANGED
@@ -42,23 +42,34 @@ class SQLObject:
42
42
  if not table_name:
43
43
  return
44
44
  cls = SQLObject
45
+ is_file_name = any([
46
+ '/' in table_name, '.' in table_name
47
+ ])
48
+ ref = table_name
49
+ if is_file_name:
50
+ ref = table_name.split('/')[-1].split('.')[0]
45
51
  if cls.ALIAS_FUNC:
46
- self.__alias = cls.ALIAS_FUNC(table_name)
52
+ self.__alias = cls.ALIAS_FUNC(ref)
47
53
  elif ' ' in table_name.strip():
48
54
  table_name, self.__alias = table_name.split()
49
- elif '_' in table_name:
55
+ elif '_' in ref:
50
56
  self.__alias = ''.join(
51
57
  word[0].lower()
52
- for word in table_name.split('_')
58
+ for word in ref.split('_')
53
59
  )
54
60
  else:
55
- self.__alias = table_name.lower()[:3]
61
+ self.__alias = ref.lower()[:3]
56
62
  self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
57
63
 
58
64
  @property
59
65
  def table_name(self) -> str:
60
66
  return self.values[FROM][0].split()[0]
61
67
 
68
+ def set_file_format(self, pattern: str):
69
+ if '{' not in pattern:
70
+ pattern = '{}' + pattern
71
+ self.values[FROM][0] = pattern.format(self.aka())
72
+
62
73
  @property
63
74
  def alias(self) -> str:
64
75
  if self.__alias:
@@ -71,8 +82,14 @@ class SQLObject:
71
82
  return KEYWORD[key][0].format(appendix.get(key, ''))
72
83
 
73
84
  @staticmethod
74
- def is_named_field(fld: str, key: str) -> bool:
75
- return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
85
+ def is_named_field(fld: str, name: str='') -> bool:
86
+ return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
87
+
88
+ def has_named_field(self, name: str) -> bool:
89
+ return any(
90
+ self.is_named_field(fld, name)
91
+ for fld in self.values.get(SELECT, [])
92
+ )
76
93
 
77
94
  def diff(self, key: str, search_list: list, exact: bool=False) -> set:
78
95
  def disassemble(source: list) -> list:
@@ -82,14 +99,17 @@ class SQLObject:
82
99
  for fld in source:
83
100
  result += re.split(r'([=()]|<>|\s+ON\s+|\s+on\s+)', fld)
84
101
  return result
85
- def cleanup(fld: str) -> str:
102
+ def cleanup(text: str) -> str:
103
+ text = re.sub(r'[\n\t]', ' ', text)
86
104
  if exact:
87
- fld = fld.lower()
88
- return fld.strip()
105
+ text = text.lower()
106
+ return text.strip()
89
107
  def field_set(source: list) -> set:
90
108
  return set(
91
109
  (
92
- fld if self.is_named_field(fld, key) else
110
+ fld
111
+ if key == SELECT and self.is_named_field(fld, key)
112
+ else
93
113
  re.sub(pattern, '', cleanup(fld))
94
114
  )
95
115
  for string in disassemble(source)
@@ -121,7 +141,8 @@ class SQLObject:
121
141
 
122
142
  SQL_CONST_SYSDATE = 'SYSDATE'
123
143
  SQL_CONST_CURR_DATE = 'Current_date'
124
- SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE]
144
+ SQL_ROW_NUM = 'ROWNUM'
145
+ SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
125
146
 
126
147
 
127
148
  class Field:
@@ -138,7 +159,7 @@ class Field:
138
159
  name = name.strip()
139
160
  if name in ('_', '*'):
140
161
  name = '*'
141
- elif not is_const():
162
+ elif not is_const() and not main.has_named_field(name):
142
163
  name = f'{main.alias}.{name}'
143
164
  if Function in cls.__bases__:
144
165
  name = f'{cls.__name__}({name})'
@@ -176,15 +197,35 @@ class Dialect(Enum):
176
197
  POSTGRESQL = 3
177
198
  MYSQL = 4
178
199
 
200
+ SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
201
+ CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
202
+
179
203
  class Function:
180
204
  dialect = Dialect.ANSI
205
+ inputs = None
206
+ output = None
207
+ separator = ', '
208
+ auto_convert = True
209
+ append_param = False
181
210
 
182
211
  def __init__(self, *params: list):
212
+ def set_func_types(param):
213
+ if self.auto_convert and isinstance(param, Function):
214
+ func = param
215
+ main_param = self.inputs[0]
216
+ unfriendly = all([
217
+ func.output != main_param,
218
+ func.output != ANY,
219
+ main_param != ANY
220
+ ])
221
+ if unfriendly:
222
+ return Cast(func, main_param)
223
+ return param
183
224
  # --- Replace class methods by instance methods: ------
184
225
  self.add = self.__add
185
226
  self.format = self.__format
186
227
  # -----------------------------------------------------
187
- self.params = [str(p) for p in params]
228
+ self.params = [set_func_types(p) for p in params]
188
229
  self.field_class = Field
189
230
  self.pattern = self.get_pattern()
190
231
  self.extra = {}
@@ -201,14 +242,35 @@ class Function:
201
242
  def __str__(self) -> str:
202
243
  return self.pattern.format(
203
244
  func_name=self.__class__.__name__,
204
- params=', '.join(self.params)
245
+ params=self.separator.join(str(p) for p in self.params)
205
246
  )
206
247
 
248
+ @classmethod
249
+ def help(cls) -> str:
250
+ descr = ' '.join(B.__name__ for B in cls.__bases__)
251
+ params = cls.inputs or ''
252
+ return cls().get_pattern().format(
253
+ func_name=f'{descr} {cls.__name__}',
254
+ params=cls.separator.join(str(p) for p in params)
255
+ ) + f' Return {cls.output}'
256
+
257
+ def set_main_param(self, name: str, main: SQLObject) -> bool:
258
+ nested_functions = [
259
+ param for param in self.params if isinstance(param, Function)
260
+ ]
261
+ for func in nested_functions:
262
+ if func.inputs:
263
+ func.set_main_param(name, main)
264
+ return
265
+ new_params = [Field.format(name, main)]
266
+ if self.append_param:
267
+ self.params += new_params
268
+ else:
269
+ self.params = new_params + self.params
270
+
207
271
  def __format(self, name: str, main: SQLObject) -> str:
208
272
  if name not in '*_':
209
- self.params = [
210
- Field.format(name, main)
211
- ] + self.params
273
+ self.set_main_param(name, main)
212
274
  return str(self)
213
275
 
214
276
  @classmethod
@@ -228,6 +290,9 @@ class Function:
228
290
 
229
291
  # ---- String Functions: ---------------------------------
230
292
  class SubString(Function):
293
+ inputs = [CHAR, INT, INT]
294
+ output = CHAR
295
+
231
296
  def get_pattern(self) -> str:
232
297
  if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
233
298
  return 'Substr({params})'
@@ -235,31 +300,55 @@ class SubString(Function):
235
300
 
236
301
  # ---- Numeric Functions: --------------------------------
237
302
  class Round(Function):
238
- ...
303
+ inputs = [FLOAT]
304
+ output = FLOAT
239
305
 
240
306
  # --- Date Functions: ------------------------------------
241
307
  class DateDiff(Function):
242
- def get_pattern(self) -> str:
308
+ inputs = [DATE]
309
+ output = DATE
310
+ append_param = True
311
+
312
+ def __str__(self) -> str:
243
313
  def is_field_or_func(name: str) -> bool:
244
- return re.sub('[()]', '', name).isidentifier()
314
+ candidate = re.sub(
315
+ '[()]', '', name.split('.')[-1]
316
+ )
317
+ return candidate.isidentifier()
245
318
  if self.dialect != Dialect.SQL_SERVER:
319
+ params = [str(p) for p in self.params]
246
320
  return ' - '.join(
247
321
  p if is_field_or_func(p) else f"'{p}'"
248
- for p in self.params
322
+ for p in params
249
323
  ) # <==== Date subtract
250
- return super().get_pattern()
324
+ return super().__str__()
325
+
326
+
327
+ class DatePart(Function):
328
+ inputs = [DATE]
329
+ output = INT
251
330
 
252
- class Year(Function):
253
331
  def get_pattern(self) -> str:
332
+ interval = self.__class__.__name__
254
333
  database_type = {
255
- Dialect.ORACLE: 'Extract(YEAR FROM {params})',
256
- Dialect.POSTGRESQL: "Date_Part('year', {params})",
334
+ Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
335
+ Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
257
336
  }
258
337
  if self.dialect in database_type:
259
338
  return database_type[self.dialect]
260
339
  return super().get_pattern()
261
340
 
341
+ class Year(DatePart):
342
+ ...
343
+ class Month(DatePart):
344
+ ...
345
+ class Day(DatePart):
346
+ ...
347
+
348
+
262
349
  class Current_Date(Function):
350
+ output = DATE
351
+
263
352
  def get_pattern(self) -> str:
264
353
  database_type = {
265
354
  Dialect.ORACLE: SQL_CONST_SYSDATE,
@@ -282,14 +371,15 @@ class Frame:
282
371
  keywords = ''
283
372
  for field, obj in args.items():
284
373
  is_valid = any([
285
- obj is class_type # or isinstance(obj, class_type)
286
- for class_type in (OrderBy, Partition)
374
+ obj is OrderBy,
375
+ obj is Partition,
376
+ isinstance(obj, Rows),
287
377
  ])
288
378
  if not is_valid:
289
379
  continue
290
380
  keywords += '{}{} {}'.format(
291
381
  '\n\t\t' if self.break_lines else ' ',
292
- obj.cls_to_str(), field
382
+ obj.cls_to_str(), field if field != '_' else ''
293
383
  )
294
384
  if keywords and self.break_lines:
295
385
  keywords += '\n\t'
@@ -298,7 +388,8 @@ class Frame:
298
388
 
299
389
 
300
390
  class Aggregate(Frame):
301
- ...
391
+ inputs = [FLOAT]
392
+ output = FLOAT
302
393
 
303
394
  class Window(Frame):
304
395
  ...
@@ -317,20 +408,30 @@ class Count(Aggregate, Function):
317
408
 
318
409
  # ---- Window Functions: -----------------------------------
319
410
  class Row_Number(Window, Function):
320
- ...
411
+ output = INT
412
+
321
413
  class Rank(Window, Function):
322
- ...
414
+ output = INT
415
+
323
416
  class Lag(Window, Function):
324
- ...
417
+ output = ANY
418
+
325
419
  class Lead(Window, Function):
326
- ...
420
+ output = ANY
327
421
 
328
422
 
329
423
  # ---- Conversions and other Functions: ---------------------
330
424
  class Coalesce(Function):
331
- ...
425
+ inputs = [ANY]
426
+ output = ANY
427
+
332
428
  class Cast(Function):
333
- ...
429
+ inputs = [ANY]
430
+ output = ANY
431
+ separator = ' As '
432
+
433
+
434
+ FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
334
435
 
335
436
 
336
437
  class ExpressionField:
@@ -355,15 +456,20 @@ class ExpressionField:
355
456
  class FieldList:
356
457
  separator = ','
357
458
 
358
- def __init__(self, fields: list=[], class_types = [Field]):
459
+ def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
359
460
  if isinstance(fields, str):
360
461
  fields = [
361
462
  f.strip() for f in fields.split(self.separator)
362
463
  ]
363
464
  self.fields = fields
364
465
  self.class_types = class_types
466
+ self.ziped = ziped
365
467
 
366
468
  def add(self, name: str, main: SQLObject):
469
+ if self.ziped: # --- One class per field...
470
+ for field, class_type in zip(self.fields, self.class_types):
471
+ class_type.add(field, main)
472
+ return
367
473
  for field in self.fields:
368
474
  for class_type in self.class_types:
369
475
  class_type.add(field, main)
@@ -405,36 +511,40 @@ class ForeignKey:
405
511
 
406
512
  def quoted(value) -> str:
407
513
  if isinstance(value, str):
514
+ if re.search(r'\bor\b', value, re.IGNORECASE):
515
+ raise PermissionError('Possible SQL injection attempt')
408
516
  value = f"'{value}'"
409
517
  return str(value)
410
518
 
411
519
 
412
520
  class Position(Enum):
521
+ StartsWith = -1
413
522
  Middle = 0
414
- StartsWith = 1
415
- EndsWith = 2
523
+ EndsWith = 1
416
524
 
417
525
 
418
526
  class Where:
419
527
  prefix = ''
420
528
 
421
- def __init__(self, expr: str):
422
- self.expr = expr
529
+ def __init__(self, content: str):
530
+ self.content = content
423
531
 
424
532
  @classmethod
425
533
  def __constructor(cls, operator: str, value):
426
- return cls(expr=f'{operator} {quoted(value)}')
534
+ return cls(f'{operator} {quoted(value)}')
427
535
 
428
536
  @classmethod
429
537
  def eq(cls, value):
430
538
  return cls.__constructor('=', value)
431
539
 
432
540
  @classmethod
433
- def contains(cls, content: str, pos: Position = Position.Middle):
541
+ def contains(cls, text: str, pos: int | Position = Position.Middle):
542
+ if isinstance(pos, int):
543
+ pos = Position(pos)
434
544
  return cls(
435
545
  "LIKE '{}{}{}'".format(
436
546
  '%' if pos != Position.StartsWith else '',
437
- content,
547
+ text,
438
548
  '%' if pos != Position.EndsWith else ''
439
549
  )
440
550
  )
@@ -465,9 +575,43 @@ class Where:
465
575
  values = ','.join(quoted(v) for v in values)
466
576
  return cls(f'IN ({values})')
467
577
 
578
+ @classmethod
579
+ def formula(cls, formula: str):
580
+ where = cls( ExpressionField(formula) )
581
+ where.add = where.add_expression
582
+ return where
583
+
584
+ def add_expression(self, name: str, main: SQLObject):
585
+ self.content = self.content.format(name, main)
586
+ main.values.setdefault(WHERE, []).append('{} {}'.format(
587
+ self.prefix, self.content
588
+ ))
589
+
590
+ @classmethod
591
+ def join(cls, query: SQLObject):
592
+ where = cls(query)
593
+ where.add = where.add_join
594
+ return where
595
+
596
+ def add_join(self, name: str, main: SQLObject):
597
+ query = self.content
598
+ main.values[FROM].append(f',{query.table_name} {query.alias}')
599
+ for key in USUAL_KEYS:
600
+ main.update_values(key, query.values.get(key, []))
601
+ if query.key_field:
602
+ main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
603
+ a1=main.alias, f1=name,
604
+ a2=query.alias, f2=query.key_field
605
+ ))
606
+
468
607
  def add(self, name: str, main: SQLObject):
608
+ func_type = FUNCTION_CLASS.get(name.lower())
609
+ if func_type:
610
+ name = func_type.format('*', main)
611
+ elif not main.has_named_field(name):
612
+ name = Field.format(name, main)
469
613
  main.values.setdefault(WHERE, []).append('{}{} {}'.format(
470
- self.prefix, Field.format(name, main), self.expr
614
+ self.prefix, name, self.content
471
615
  ))
472
616
 
473
617
 
@@ -475,6 +619,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
475
619
  getattr(Where, method) for method in
476
620
  ('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
477
621
  )
622
+ startswith, endswith = [
623
+ lambda x: contains(x, Position.StartsWith),
624
+ lambda x: contains(x, Position.EndsWith)
625
+ ]
478
626
 
479
627
 
480
628
  class Not(Where):
@@ -482,7 +630,7 @@ class Not(Where):
482
630
 
483
631
  @classmethod
484
632
  def eq(cls, value):
485
- return Where(expr=f'<> {quoted(value)}')
633
+ return Where(f'<> {quoted(value)}')
486
634
 
487
635
 
488
636
  class Case:
@@ -491,22 +639,26 @@ class Case:
491
639
  self.default = None
492
640
  self.field = field
493
641
 
494
- def when(self, condition: Where, result: str):
642
+ def when(self, condition: Where, result):
643
+ if isinstance(result, str):
644
+ result = quoted(result)
495
645
  self.__conditions[result] = condition
496
646
  return self
497
647
 
498
- def else_value(self, default: str):
648
+ def else_value(self, default):
649
+ if isinstance(default, str):
650
+ default = quoted(default)
499
651
  self.default = default
500
652
  return self
501
653
 
502
654
  def add(self, name: str, main: SQLObject):
503
655
  field = Field.format(self.field, main)
504
- default = quoted(self.default)
656
+ default = self.default
505
657
  name = 'CASE \n{}\n\tEND AS {}'.format(
506
658
  '\n'.join(
507
- f'\t\tWHEN {field} {cond.expr} THEN {quoted(res)}'
659
+ f'\t\tWHEN {field} {cond.content} THEN {res}'
508
660
  for res, cond in self.__conditions.items()
509
- ) + f'\n\t\tELSE {default}' if default else '',
661
+ ) + (f'\n\t\tELSE {default}' if default else ''),
510
662
  name
511
663
  )
512
664
  main.values.setdefault(SELECT, []).append(name)
@@ -517,42 +669,69 @@ class Options:
517
669
  self.__children: dict = values
518
670
 
519
671
  def add(self, logical_separator: str, main: SQLObject):
520
- if logical_separator not in ('AND', 'OR'):
672
+ if logical_separator.upper() not in ('AND', 'OR'):
521
673
  raise ValueError('`logical_separator` must be AND or OR')
522
- conditions: list[str] = []
674
+ temp = Select(f'{main.table_name} {main.alias}')
523
675
  child: Where
524
676
  for field, child in self.__children.items():
525
- conditions.append(' {} {} '.format(
526
- Field.format(field, main), child.expr
527
- ))
677
+ child.add(field, temp)
528
678
  main.values.setdefault(WHERE, []).append(
529
- '(' + logical_separator.join(conditions) + ')'
679
+ '(' + f'\n\t{logical_separator} '.join(temp.values[WHERE]) + ')'
530
680
  )
531
681
 
532
682
 
533
683
  class Between:
684
+ is_literal: bool = False
685
+
534
686
  def __init__(self, start, end):
535
687
  if start > end:
536
688
  start, end = end, start
537
689
  self.start = start
538
690
  self.end = end
539
691
 
692
+ def literal(self) -> Where:
693
+ return Where('BETWEEN {} AND {}'.format(
694
+ self.start, self.end
695
+ ))
696
+
540
697
  def add(self, name: str, main:SQLObject):
541
- Where.gte(self.start).add(name, main),
698
+ if self.is_literal:
699
+ return self.literal().add(name, main)
700
+ Where.gte(self.start).add(name, main)
542
701
  Where.lte(self.end).add(name, main)
543
702
 
703
class SameDay(Between):
    """Between whose bounds are the first and last second of one date."""

    def __init__(self, date: str):
        first_second = '{} 00:00:00'.format(date)
        last_second = '{} 23:59:59'.format(date)
        super().__init__(first_second, last_second)
709
+
710
+
711
class Range(Case):
    """Case whose branches are consecutive BETWEEN intervals.

    `values` maps each label to the upper bound of its interval. The
    intervals are built in ascending order of upper bound; the first one
    starts at 0 and each following one starts at INC_FUNCTION applied to
    the previous upper bound.
    """
    # Step from one interval's upper bound to the next interval's start.
    INC_FUNCTION = lambda x: x + 1

    def __init__(self, field: str, values: dict):
        super().__init__(field)
        next_value = type(self).INC_FUNCTION
        lower = 0
        by_upper_bound = sorted(values.items(), key=lambda pair: pair[1])
        for label, upper in by_upper_bound:
            interval = Between(lower, upper).literal()
            self.when(interval, label)
            lower = next_value(upper)
723
+
544
724
 
545
725
  class Clause:
546
726
  @classmethod
547
727
  def format(cls, name: str, main: SQLObject) -> str:
548
728
  def is_function() -> bool:
549
729
  diff = main.diff(SELECT, [name.lower()], True)
550
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
551
730
  return diff.intersection(FUNCTION_CLASS)
552
731
  found = re.findall(r'^_\d', name)
553
732
  if found:
554
733
  name = found[0].replace('_', '')
555
- elif main.alias and not is_function():
734
+ elif '.' not in name and main.alias and not is_function():
556
735
  name = f'{main.alias}.{name}'
557
736
  return name
558
737
 
@@ -561,6 +740,34 @@ class SortType(Enum):
561
740
  ASC = ''
562
741
  DESC = ' DESC'
563
742
 
743
class Row:
    """One bound of a window frame.

    Renders as `<value> <CLASS NAME IN UPPER CASE>`; a value of 0 is
    rendered as `UNBOUNDED`.
    """

    def __init__(self, value: int = 0):
        self.value = value

    def __str__(self) -> str:
        offset = self.value
        if offset == 0:
            offset = 'UNBOUNDED'
        keyword = type(self).__name__.upper()
        return f'{offset} {keyword}'
752
+
753
class Preceding(Row):
    """Row bound whose rendered keyword is PRECEDING."""
755
class Following(Row):
    """Row bound whose rendered keyword is FOLLOWING."""
757
class Current(Row):
    """Row bound that always renders as CURRENT ROW, ignoring `value`."""

    def __str__(self) -> str:
        bound = 'CURRENT ROW'
        return bound
760
+
761
class Rows:
    """ROWS frame clause built from one or two frame bounds.

    Each positional argument is a single bound; it is rendered with
    `str()`. With more than one bound the clause uses the
    `BETWEEN ... AND ...` form.
    """

    def __init__(self, *rows: 'Row'):
        # `rows` is the tuple of bounds; each element is one Row,
        # not a list of rows.
        self.rows = rows

    def cls_to_str(self) -> str:
        """Return the clause text, e.g. `ROWS BETWEEN a AND b`."""
        between = 'BETWEEN ' if len(self.rows) > 1 else ''
        bounds = ' AND '.join(str(row) for row in self.rows)
        return f'ROWS {between}{bounds}'
770
+
564
771
 
565
772
  class OrderBy(Clause):
566
773
  sort: SortType = SortType.ASC
@@ -595,7 +802,7 @@ class Having:
595
802
 
596
803
  def add(self, name: str, main:SQLObject):
597
804
  main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
598
- self.function.format(name, main), self.condition.expr
805
+ self.function.format(name, main), self.condition.content
599
806
  )
600
807
 
601
808
  @classmethod
@@ -625,12 +832,20 @@ class Rule:
625
832
  ...
626
833
 
627
834
  class QueryLanguage:
628
- pattern = '{select}{_from}{where}{group_by}{order_by}'
835
+ pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
629
836
  has_default = {key: bool(key == SELECT) for key in KEYWORD}
630
837
 
631
838
  @staticmethod
632
- def remove_alias(fld: str) -> str:
633
- return ''.join(re.split(r'\w+[.]', fld))
839
+ def remove_alias(text: str) -> str:
840
+ value, sep = '', ''
841
+ text = re.sub('[\n\t]', ' ', text)
842
+ if ':' in text:
843
+ text, value = text.split(':', maxsplit=1)
844
+ sep = ':'
845
+ return '{}{}{}'.format(
846
+ ''.join(re.split(r'\w+[.]', text)),
847
+ sep, value.replace("'", '"')
848
+ )
634
849
 
635
850
  def join_with_tabs(self, values: list, sep: str='') -> str:
636
851
  sep = sep + self.TABULATION
@@ -648,18 +863,21 @@ class QueryLanguage:
648
863
  return self.join_with_tabs(values, ' AND ')
649
864
 
650
865
  def sort_by(self, values: list) -> str:
651
- return self.join_with_tabs(values)
866
+ return self.join_with_tabs(values, ',')
652
867
 
653
868
  def set_group(self, values: list) -> str:
654
869
  return self.join_with_tabs(values, ',')
655
870
 
871
+ def set_limit(self, values: list) -> str:
872
+ return self.join_with_tabs(values, ' ')
873
+
656
874
  def __init__(self, target: 'Select'):
657
- self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
875
+ self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
658
876
  self.TABULATION = '\n\t' if target.break_lines else ' '
659
877
  self.LINE_BREAK = '\n' if target.break_lines else ' '
660
878
  self.TOKEN_METHODS = {
661
879
  SELECT: self.add_field, FROM: self.get_tables,
662
- WHERE: self.extract_conditions,
880
+ WHERE: self.extract_conditions, LIMIT: self.set_limit,
663
881
  ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
664
882
  }
665
883
  self.result = {}
@@ -695,7 +913,8 @@ class MongoDBLanguage(QueryLanguage):
695
913
  LOGICAL_OP_TO_MONGO_FUNC = {
696
914
  '>': '$gt', '>=': '$gte',
697
915
  '<': '$lt', '<=': '$lte',
698
- '=': '$eq', '<>': '$ne',
916
+ '=': '$eq', '<>': '$ne',
917
+ 'like': '$regex', 'LIKE': '$regex',
699
918
  }
700
919
  OPERATORS = '|'.join(op for op in LOGICAL_OP_TO_MONGO_FUNC)
701
920
  REGEX = {
@@ -748,7 +967,7 @@ class MongoDBLanguage(QueryLanguage):
748
967
  field, *op, const = tokens
749
968
  op = ''.join(op)
750
969
  expr = '{begin}{op}:{const}{end}'.format(
751
- begin='{', const=const, end='}',
970
+ begin='{', const=const.replace('%', '.*'), end='}',
752
971
  op=cls.LOGICAL_OP_TO_MONGO_FUNC[op],
753
972
  )
754
973
  where_list.append(f'{field}:{expr}')
@@ -857,6 +1076,55 @@ class Neo4JLanguage(QueryLanguage):
857
1076
  return ''
858
1077
 
859
1078
 
1079
class DatabricksLanguage(QueryLanguage):
    """Query rendering where every clause except FROM is prefixed by `|> `.

    The clause order puts the field list after ORDER BY. Fields that call
    an Aggregate subclass are taken out of the field list and emitted in
    an `AGGREGATE ...` segment placed before the GROUP BY keyword.
    """
    pattern = '{_from}{where}{group_by}{order_by}{select}{limit}'
    has_default = {key: bool(key == SELECT) for key in KEYWORD}

    def __init__(self, target: 'Select'):
        super().__init__(target)
        # Aggregate expressions collected by `add_field`.
        self.aggregation_fields = []

    def add_field(self, values: list) -> str:
        """Set aside aggregate fields and render the remaining ones."""
        agg_names = '|'.join(
            cls.__name__ for cls in Aggregate.__subclasses__()
        )
        # An aggregate field is one that calls an Aggregate subclass
        # by name, in any letter case.
        agg_call = re.compile(fr'({agg_names})[(]', re.IGNORECASE)
        plain_fields = []
        for val in values:
            if agg_call.search(val):
                self.aggregation_fields.append(val)
            else:
                plain_fields.append(val)
        return super().add_field(plain_fields)

    def prefix(self, key: str) -> str:
        """Return the text placed before the values of clause `key`."""
        pipe = '|> ' if key != FROM else ''
        aggregate = ''
        if key == GROUP_BY:
            aggregate = 'AGGREGATE {} '.format(
                ','.join(self.aggregation_fields)
            )
        return '{}{}{}{}{}'.format(
            pipe, self.LINE_BREAK, aggregate, key, self.TABULATION
        )

    def set_group(self, values: list) -> str:
        """Join the GROUP BY values with commas."""
        return self.join_with_tabs(values, ',')
1125
+
1126
+
1127
+
860
1128
  class Parser:
861
1129
  REGEX = {}
862
1130
 
@@ -963,8 +1231,11 @@ class SQLParser(Parser):
963
1231
  if not key in values:
964
1232
  continue
965
1233
  separator = self.class_type.get_separator(key)
1234
+ cls = {
1235
+ ORDER_BY: OrderBy, GROUP_BY: GroupBy
1236
+ }.get(key, Field)
966
1237
  obj.values[key] = [
967
- Field.format(fld, obj)
1238
+ cls.format(fld, obj)
968
1239
  for fld in re.split(separator, values[key])
969
1240
  if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
970
1241
  ]
@@ -1026,7 +1297,11 @@ class CypherParser(Parser):
1026
1297
  if token in self.TOKEN_METHODS:
1027
1298
  return
1028
1299
  class_list = [Field]
1029
- if '$' in token:
1300
+ if '*' in token:
1301
+ token = token.replace('*', '')
1302
+ self.queries[-1].key_field = token
1303
+ return
1304
+ elif '$' in token:
1030
1305
  func_name, token = token.split('$')
1031
1306
  if func_name == 'count':
1032
1307
  if not token:
@@ -1035,8 +1310,13 @@ class CypherParser(Parser):
1035
1310
  Count().As(token, extra_classes).add(pk_field, self.queries[-1])
1036
1311
  return
1037
1312
  else:
1038
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
1039
- class_list = [ FUNCTION_CLASS[func_name] ]
1313
+ class_type = FUNCTION_CLASS.get(func_name)
1314
+ if not class_type:
1315
+ raise ValueError(f'Unknown function `{func_name}`.')
1316
+ if ':' in token:
1317
+ token, field_alias = token.split(':')
1318
+ class_type = class_type().As(field_alias)
1319
+ class_list = [class_type]
1040
1320
  class_list += extra_classes
1041
1321
  FieldList(token, class_list).add('', self.queries[-1])
1042
1322
 
@@ -1051,10 +1331,13 @@ class CypherParser(Parser):
1051
1331
  def add_foreign_key(self, token: str, pk_field: str=''):
1052
1332
  curr, last = [self.queries[i] for i in (-1, -2)]
1053
1333
  if not pk_field:
1054
- if not last.values.get(SELECT):
1055
- raise IndexError(f'Primary Key not found for {last.table_name}.')
1056
- pk_field = last.values[SELECT][-1].split('.')[-1]
1057
- last.delete(pk_field, [SELECT], exact=True)
1334
+ if last.key_field:
1335
+ pk_field = last.key_field
1336
+ else:
1337
+ if not last.values.get(SELECT):
1338
+ raise IndexError(f'Primary Key not found for {last.table_name}.')
1339
+ pk_field = last.values[SELECT][-1].split('.')[-1]
1340
+ last.delete(pk_field, [SELECT], exact=True)
1058
1341
  if '{}' in token:
1059
1342
  foreign_fld = token.format(
1060
1343
  last.table_name.lower()
@@ -1197,7 +1480,18 @@ class MongoParser(Parser):
1197
1480
 
1198
1481
  def begin_conditions(self, value: str):
1199
1482
  self.where_list = {}
1483
+ self.field_method = self.first_ORfield
1200
1484
  return Where
1485
+
1486
+ def first_ORfield(self, text: str):
1487
+ if text.startswith('$'):
1488
+ return
1489
+ found = re.search(r'\w+[:]', text)
1490
+ if not found:
1491
+ return
1492
+ self.field_method = None
1493
+ p1, p2 = found.span()
1494
+ self.last_field = text[p1: p2-1]
1201
1495
 
1202
1496
  def increment_brackets(self, value: str):
1203
1497
  self.brackets[value] += 1
@@ -1206,6 +1500,7 @@ class MongoParser(Parser):
1206
1500
  self.method = self.new_query
1207
1501
  self.last_field = ''
1208
1502
  self.where_list = None
1503
+ self.field_method = None
1209
1504
  self.PARAM_BY_FUNCTION = {
1210
1505
  'find': Where, 'aggregate': GroupBy, 'sort': OrderBy
1211
1506
  }
@@ -1235,13 +1530,14 @@ class MongoParser(Parser):
1235
1530
  self.close_brackets(
1236
1531
  BRACKET_PAIR[token]
1237
1532
  )
1533
+ elif self.field_method:
1534
+ self.field_method(token)
1238
1535
  self.method = self.TOKEN_METHODS.get(token)
1239
1536
  # ----------------------------
1240
1537
 
1241
1538
 
1242
1539
  class Select(SQLObject):
1243
1540
  join_type: JoinType = JoinType.INNER
1244
- REGEX = {}
1245
1541
  EQUIVALENT_NAMES = {}
1246
1542
 
1247
1543
  def __init__(self, table_name: str='', **values):
@@ -1259,21 +1555,30 @@ class Select(SQLObject):
1259
1555
 
1260
1556
  def add(self, name: str, main: SQLObject):
1261
1557
  old_tables = main.values.get(FROM, [])
1262
- new_tables = set([
1263
- '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1558
+ if len(self.values[FROM]) > 1:
1559
+ old_tables += self.values[FROM][1:]
1560
+ new_tables = []
1561
+ row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1264
1562
  jt=self.join_type.value,
1265
1563
  tb=self.aka(),
1266
1564
  a1=main.alias, f1=name,
1267
1565
  a2=self.alias, f2=self.key_field
1268
1566
  )
1269
- ] + old_tables[1:])
1270
- main.values[FROM] = old_tables[:1] + list(new_tables)
1567
+ if row not in old_tables[1:]:
1568
+ new_tables.append(row)
1569
+ main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
1271
1570
  for key in USUAL_KEYS:
1272
1571
  main.update_values(key, self.values.get(key, []))
1273
1572
 
1274
- def __add__(self, other: SQLObject):
1573
+ def copy(self) -> SQLObject:
1275
1574
  from copy import deepcopy
1276
- query = deepcopy(self)
1575
+ return deepcopy(self)
1576
+
1577
+ def no_relation_error(self, other: SQLObject):
1578
+ raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
1579
+
1580
+ def __add__(self, other: SQLObject):
1581
+ query = self.copy()
1277
1582
  if query.table_name.lower() == other.table_name.lower():
1278
1583
  for key in USUAL_KEYS:
1279
1584
  query.update_values(key, other.values.get(key, []))
@@ -1286,7 +1591,7 @@ class Select(SQLObject):
1286
1591
  PrimaryKey.add(primary_key, query)
1287
1592
  query.add(foreign_field, other)
1288
1593
  return other
1289
- raise ValueError(f'No relationship found between {query.table_name} and {other.table_name}.')
1594
+ self.no_relation_error(other) # === raise ERROR ... ===
1290
1595
  elif primary_key:
1291
1596
  PrimaryKey.add(primary_key, other)
1292
1597
  other.add(foreign_field, query)
@@ -1306,12 +1611,39 @@ class Select(SQLObject):
1306
1611
  if self.diff(key, other.values.get(key, []), True):
1307
1612
  return False
1308
1613
  return True
1614
+
1615
+ def __sub__(self, other: SQLObject) -> SQLObject:
1616
+ fk_field, primary_k = ForeignKey.find(self, other)
1617
+ if fk_field:
1618
+ query = self.copy()
1619
+ other = other.copy()
1620
+ else:
1621
+ fk_field, primary_k = ForeignKey.find(other, self)
1622
+ if not fk_field:
1623
+ self.no_relation_error(other) # === raise ERROR ... ===
1624
+ query = other.copy()
1625
+ other = self.copy()
1626
+ query.__class__ = NotSelectIN
1627
+ Field.add(fk_field, query)
1628
+ query.add(primary_k, other)
1629
+ return other
1309
1630
 
1310
1631
  def limit(self, row_count: int=100, offset: int=0):
1311
- result = [str(row_count)]
1312
- if offset > 0:
1313
- result.append(f'OFFSET {offset}')
1314
- self.values.setdefault(LIMIT, result)
1632
+ if Function.dialect == Dialect.SQL_SERVER:
1633
+ fields = self.values.get(SELECT)
1634
+ if fields:
1635
+ fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
1636
+ else:
1637
+ self.values[SELECT] = [f'SELECT TOP({row_count}) *']
1638
+ return self
1639
+ if Function.dialect == Dialect.ORACLE:
1640
+ Where.gte(row_count).add(SQL_ROW_NUM, self)
1641
+ if offset > 0:
1642
+ Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
1643
+ return self
1644
+ self.values[LIMIT] = ['{}{}'.format(
1645
+ row_count, f' OFFSET {offset}' if offset > 0 else ''
1646
+ )]
1315
1647
  return self
1316
1648
 
1317
1649
  def match(self, field: str, key: str) -> bool:
@@ -1319,11 +1651,7 @@ class Select(SQLObject):
1319
1651
  Recognizes if the field is from the current table
1320
1652
  '''
1321
1653
  if key in (ORDER_BY, GROUP_BY) and '.' not in field:
1322
- return any(
1323
- self.is_named_field(fld, SELECT)
1324
- for fld in self.values[SELECT]
1325
- if field in fld
1326
- )
1654
+ return self.has_named_field(field)
1327
1655
  return re.findall(f'\b*{self.alias}[.]', field) != []
1328
1656
 
1329
1657
  @classmethod
@@ -1336,12 +1664,10 @@ class Select(SQLObject):
1336
1664
  for rule in rules:
1337
1665
  rule.apply(self)
1338
1666
 
1339
- def add_fields(self, fields: list, order_by: bool=False, group_by:bool=False):
1340
- class_types = [Field]
1341
- if order_by:
1342
- class_types += [OrderBy]
1343
- if group_by:
1344
- class_types += [GroupBy]
1667
def add_fields(self, fields: list, class_types=None):
    '''
    Adds `fields` to this query through a FieldList.

    fields: the fields handed to FieldList.
    class_types: optional classes to apply to every field; `Field` is
        always appended after them.

    A new list is built on every call, so the sequence received from the
    caller is never modified (the in-place ``+=`` used before made a
    list that was reused across calls accumulate `Field` entries).
    '''
    class_types = [*(class_types or []), Field]
    FieldList(fields, class_types).add('', self)
1346
1672
 
1347
1673
  def translate_to(self, language: QueryLanguage) -> str:
@@ -1361,6 +1687,95 @@ class NotSelectIN(SelectIN):
1361
1687
  condition_class = Not
1362
1688
 
1363
1689
 
1690
class CTE(Select):
    '''
    Select that is rendered with a leading ``WITH <name> AS (...)`` clause.

    The text inside the parentheses is built from `query_list`, whose
    items are joined with ``UNION ALL``; the text after the parentheses
    is the regular rendering of this object as a Select.
    '''
    prefix = ''  # Text placed between ``WITH `` and the table name.

    def __init__(self, table_name: str, query_list: list[Select]):
        '''
        table_name: name given to the expression (and to this Select).
        query_list: queries placed inside the parentheses; `break_lines`
            is switched off on each of them.
        '''
        super().__init__(table_name)
        for query in query_list:
            query.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        '''
        Renders the WITH clause followed by this object's own query.

        Side effect: `break_lines` is switched on when the single-line
        values stored under USUAL_KEYS add up to more than 70 characters.
        '''
        size = sum(
            len(value)
            for key in USUAL_KEYS
            for value in self.values.get(key, [])
            if '\n' not in value
        )
        if size > 70:
            self.break_lines = True
        keywords = '|'.join(KEYWORD)
        # AND / OR only count as separators when they are whole words;
        # otherwise an identifier such as ORDERS or BRAND would be split
        # and a line break could be inserted in the middle of it.
        separators = re.compile(fr'({keywords}|\bAND\b|\bOR\b|,)')
        # ---------------------------------------------------------
        def justify(query: Select) -> str:
            '''Wraps the query text once a line reaches 50 characters.'''
            result, line = [], ''
            for word in separators.split(str(query)):
                if len(line) >= 50:
                    result.append(line)
                    line = ''
                line += word
            if line:
                result.append(line)
            return '\n '.join(result)
        # ---------------------------------------------------------
        return 'WITH {}{} AS (\n {}\n){}'.format(
            self.prefix, self.table_name,
            '\nUNION ALL\n '.join(
                justify(q) for q in self.query_list
            ), super().__str__()
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        '''
        Parses `pattern` once per field and attaches the resulting queries.

        pattern: text understood by `detect`; it is repeated as many times
            as there are fields.
        fields: a list of fields or a comma-separated string.
        format: forwarded to `detect` unchanged.

        The queries and the fields are handed to FieldList with
        ``ziped=True`` (presumably pairing field N with query N -- confirm
        against FieldList). Switches `break_lines` on and returns `self`.
        '''
        if isinstance(fields, str):
            count = len(fields.split(','))
        else:
            count = len(fields)
        queries = detect(
            pattern*count, join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
1737
+
1738
class Recursive(CTE):
    '''
    CTE rendered as ``WITH RECURSIVE``: when there is more than one query,
    the last one also reads from the expression itself.
    '''
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        '''
        Renders the expression, first making sure that the last query
        lists this expression (``, <table_name> <alias>``) in its FROM.

        The reference is appended only when it is not there yet, so
        converting the object to text more than once no longer repeats it.
        '''
        if len(self.query_list) > 1:
            self_reference = f', {self.table_name} {self.alias}'
            from_items = self.query_list[-1].values[FROM]
            if self_reference not in from_items:
                from_items.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        '''
        Builds a recursive expression out of two copies of `pattern`.

        name: name of the expression.
        pattern: text understood by `detect`; it is parsed twice, giving
            the first query (t1) and the second query (t2).
        formula: condition for t2. Every ``[N]`` stands for the N-th
            (1-based) SELECT field of t2: the first placeholder found
            becomes ``%`` and its field is the one the formula is added
            to; the remaining placeholders become the field names.
        init_value: value the first SELECT field of t1 must be equal to.
        format: forwarded to `detect` unchanged.

        Side effect: resets ``SQLObject.ALIAS_FUNC`` to None.
        '''
        SQLObject.ALIAS_FUNC = None
        def get_field(obj: SQLObject, pos: int) -> str:
            return obj.values[SELECT][pos].split('.')[-1]
        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for token in re.findall(r'\[(\d+)\]', formula):
            # Replace the placeholder exactly as it was written, so that
            # a zero-padded index such as ``[01]`` is substituted as well.
            placeholder = f'[{token}]'
            field_name = get_field(t2, int(token) - 1)
            if not foreign_key:
                foreign_key = field_name
                formula = formula.replace(placeholder, '%')
            else:
                formula = formula.replace(placeholder, field_name)
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        '''
        Adds a running column called `name` to every query in the list:
        the first query selects `start`, every other query selects
        ``(name<increment>)``. Returns `self`.
        '''
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
1775
+
1776
+
1777
+ # ----- Rules -----
1778
+
1364
1779
  class RulePutLimit(Rule):
1365
1780
  @classmethod
1366
1781
  def apply(cls, target: Select):
@@ -1424,6 +1839,8 @@ class RuleDateFuncReplace(Rule):
1424
1839
  @classmethod
1425
1840
  def apply(cls, target: Select):
1426
1841
  for i, condition in enumerate(target.values.get(WHERE, [])):
1842
+ if not '(' in condition:
1843
+ continue
1427
1844
  tokens = [
1428
1845
  t.strip() for t in cls.REGEX.split(condition) if t.strip()
1429
1846
  ]
@@ -1475,7 +1892,7 @@ def parser_class(text: str) -> Parser:
1475
1892
  return None
1476
1893
 
1477
1894
 
1478
- def detect(text: str) -> Select:
1895
+ def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
1479
1896
  from collections import Counter
1480
1897
  parser = parser_class(text)
1481
1898
  if not parser:
@@ -1486,27 +1903,65 @@ def detect(text: str) -> Select:
1486
1903
  continue
1487
1904
  pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
1488
1905
  for begin, end in pos[::-1]:
1489
- new_name = f'{table}_{count}' # See set_table (line 45)
1906
+ new_name = f'{table}_{count}' # See set_table (line 55)
1490
1907
  Select.EQUIVALENT_NAMES[new_name] = table
1491
1908
  text = text[:begin] + new_name + '(' + text[end:]
1492
1909
  count -= 1
1493
1910
  query_list = Select.parse(text, parser)
1911
+ if format:
1912
+ for query in query_list:
1913
+ query.set_file_format(format)
1914
+ if not join_queries:
1915
+ return query_list
1494
1916
  result = query_list[0]
1495
1917
  for query in query_list[1:]:
1496
1918
  result += query
1497
1919
  return result
1498
-
1499
-
1500
- if __name__ == '__main__':
1501
- p, c, a = Select.parse('''
1502
- Professor(?nome="Júlio Cascalles", id)
1503
- <- Curso@disciplina(professor, aluno) ->
1504
- Aluno(id ^count$qtd_alunos)
1505
- ''', CypherParser)
1506
- query = p + c + a
1507
- print('#######################################')
1508
- print(query)
1509
- print('***************************************')
1510
- query.optimize([RuleReplaceJoinBySubselect])
1511
- print(query)
1512
- print('#######################################')
1920
+ # ===========================================================================================//
1921
+
1922
+
1923
+ if __name__ == "__main__":
1924
+ # def identifica_suspeitos() -> Select:
1925
+ # """Mostra quais pessoas têm características iguais à descrição do suspeito"""
1926
+ # Select.join_type = JoinType.LEFT
1927
+ # return Select(
1928
+ # 'Suspeito s', id=Field,
1929
+ # _=Where.join(
1930
+ # Select('Pessoa p',
1931
+ # OR=Options(
1932
+ # pessoa=Where('= s.id'),
1933
+ # altura=Where.formula('ABS(% - s.{f}) < 0.5'),
1934
+ # peso=Where.formula('ABS(% - s.{f}) < 0.5'),
1935
+ # cabelo=Where.formula('% = s.{f}'),
1936
+ # olhos=Where.formula('% = s.{f}'),
1937
+ # sexo=Where.formula('% = s.{f}'),
1938
+ # ),
1939
+ # nome=Field
1940
+ # )
1941
+ # )
1942
+ # )
1943
+ # query = identifica_suspeitos()
1944
+ # print('='*50)
1945
+ # print(query)
1946
+ # print('-'*50)
1947
+ script = '''
1948
+ db.people.find({
1949
+ {
1950
+ $or: [
1951
+ status:{$eq:"B"},
1952
+ age:{$lt:50}
1953
+ ]
1954
+ },
1955
+ age:{$gte:18}, status:{$eq:"A"}
1956
+ },{
1957
+ name: 1, user_id: 1
1958
+ }).sort({
1959
+ '''
1960
+ print('='*50)
1961
+ q1 = Select.parse(script, MongoParser)[0]
1962
+ print(q1)
1963
+ print('-'*50)
1964
+ q2 = q1.translate_to(MongoDBLanguage)
1965
+ print(q2)
1966
+ # print('-'*50)
1967
+ print('='*50)