sql-blocks 1.25.13__py3-none-any.whl → 1.25.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_blocks/sql_blocks.py CHANGED
@@ -26,7 +26,7 @@ TO_LIST = lambda x: x if isinstance(x, list) else [x]
26
26
 
27
27
 
28
28
  class SQLObject:
29
- ALIAS_FUNC = lambda t: t.lower()[:3]
29
+ ALIAS_FUNC = None
30
30
  """ ^^^^^^^^^^^^^^^^^^^^^^^^
31
31
  You can change the behavior by assigning
32
32
  a user function to SQLObject.ALIAS_FUNC
@@ -41,21 +41,35 @@ class SQLObject:
41
41
  def set_table(self, table_name: str):
42
42
  if not table_name:
43
43
  return
44
- if ' ' in table_name.strip():
44
+ cls = SQLObject
45
+ is_file_name = any([
46
+ '/' in table_name, '.' in table_name
47
+ ])
48
+ ref = table_name
49
+ if is_file_name:
50
+ ref = table_name.split('/')[-1].split('.')[0]
51
+ if cls.ALIAS_FUNC:
52
+ self.__alias = cls.ALIAS_FUNC(ref)
53
+ elif ' ' in table_name.strip():
45
54
  table_name, self.__alias = table_name.split()
46
- elif '_' in table_name:
55
+ elif '_' in ref:
47
56
  self.__alias = ''.join(
48
57
  word[0].lower()
49
- for word in table_name.split('_')
58
+ for word in ref.split('_')
50
59
  )
51
60
  else:
52
- self.__alias = SQLObject.ALIAS_FUNC(table_name)
61
+ self.__alias = ref.lower()[:3]
53
62
  self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
54
63
 
55
64
  @property
56
65
  def table_name(self) -> str:
57
66
  return self.values[FROM][0].split()[0]
58
67
 
68
+ def set_file_format(self, pattern: str):
69
+ if '{' not in pattern:
70
+ pattern = '{}' + pattern
71
+ self.values[FROM][0] = pattern.format(self.aka())
72
+
59
73
  @property
60
74
  def alias(self) -> str:
61
75
  if self.__alias:
@@ -67,6 +81,16 @@ class SQLObject:
67
81
  appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
68
82
  return KEYWORD[key][0].format(appendix.get(key, ''))
69
83
 
84
+ @staticmethod
85
+ def is_named_field(fld: str, name: str='') -> bool:
86
+ return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
87
+
88
+ def has_named_field(self, name: str) -> bool:
89
+ return any(
90
+ self.is_named_field(fld, name)
91
+ for fld in self.values.get(SELECT, [])
92
+ )
93
+
70
94
  def diff(self, key: str, search_list: list, exact: bool=False) -> set:
71
95
  def disassemble(source: list) -> list:
72
96
  if not exact:
@@ -79,12 +103,12 @@ class SQLObject:
79
103
  if exact:
80
104
  fld = fld.lower()
81
105
  return fld.strip()
82
- def is_named_field(fld: str) -> bool:
83
- return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
84
106
  def field_set(source: list) -> set:
85
107
  return set(
86
108
  (
87
- fld if is_named_field(fld) else
109
+ fld
110
+ if key == SELECT and self.is_named_field(fld, key)
111
+ else
88
112
  re.sub(pattern, '', cleanup(fld))
89
113
  )
90
114
  for string in disassemble(source)
@@ -102,13 +126,22 @@ class SQLObject:
102
126
  return s1.symmetric_difference(s2)
103
127
  return s1 - s2
104
128
 
105
- def delete(self, search: str, keys: list=USUAL_KEYS):
129
+ def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
130
+ if exact:
131
+ not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
132
+ else:
133
+ not_match = lambda item: search not in item
106
134
  for key in keys:
107
- result = []
108
- for item in self.values.get(key, []):
109
- if search not in item:
110
- result.append(item)
111
- self.values[key] = result
135
+ self.values[key] = [
136
+ item for item in self.values.get(key, [])
137
+ if not_match(item)
138
+ ]
139
+
140
+
141
+ SQL_CONST_SYSDATE = 'SYSDATE'
142
+ SQL_CONST_CURR_DATE = 'Current_date'
143
+ SQL_ROW_NUM = 'ROWNUM'
144
+ SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
112
145
 
113
146
 
114
147
  class Field:
@@ -116,10 +149,16 @@ class Field:
116
149
 
117
150
  @classmethod
118
151
  def format(cls, name: str, main: SQLObject) -> str:
152
+ def is_const() -> bool:
153
+ return any([
154
+ re.findall('[.()0-9]', name),
155
+ name in SQL_CONSTS,
156
+ re.findall(r'\w+\s*[+-]\s*\w+', name)
157
+ ])
119
158
  name = name.strip()
120
159
  if name in ('_', '*'):
121
160
  name = '*'
122
- elif not re.findall('[.()0-9]', name):
161
+ elif not is_const() and not main.has_named_field(name):
123
162
  name = f'{main.alias}.{name}'
124
163
  if Function in cls.__bases__:
125
164
  name = f'{cls.__name__}({name})'
@@ -150,35 +189,89 @@ class NamedField:
150
189
  )
151
190
 
152
191
 
192
+ class Dialect(Enum):
193
+ ANSI = 0
194
+ SQL_SERVER = 1
195
+ ORACLE = 2
196
+ POSTGRESQL = 3
197
+ MYSQL = 4
198
+
199
+ SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
200
+ CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
201
+
153
202
  class Function:
203
+ dialect = Dialect.ANSI
204
+ inputs = None
205
+ output = None
206
+ separator = ', '
207
+ auto_convert = True
208
+ append_param = False
209
+
154
210
  def __init__(self, *params: list):
211
+ def set_func_types(param):
212
+ if self.auto_convert and isinstance(param, Function):
213
+ func = param
214
+ main_param = self.inputs[0]
215
+ unfriendly = all([
216
+ func.output != main_param,
217
+ func.output != ANY,
218
+ main_param != ANY
219
+ ])
220
+ if unfriendly:
221
+ return Cast(func, main_param)
222
+ return param
155
223
  # --- Replace class methods by instance methods: ------
156
224
  self.add = self.__add
157
225
  self.format = self.__format
158
226
  # -----------------------------------------------------
159
- self.params = [str(p) for p in params]
227
+ self.params = [set_func_types(p) for p in params]
160
228
  self.field_class = Field
161
- self.pattern = '{}({})'
229
+ self.pattern = self.get_pattern()
162
230
  self.extra = {}
163
231
 
232
+ def get_pattern(self) -> str:
233
+ return '{func_name}({params})'
234
+
164
235
  def As(self, field_alias: str, modifiers=None):
165
236
  if modifiers:
166
237
  self.extra[field_alias] = TO_LIST(modifiers)
167
238
  self.field_class = NamedField(field_alias)
168
239
  return self
169
240
 
170
- def __format(self, name: str, main: SQLObject) -> str:
171
- if name in '*_' and self.params:
172
- params = self.params
173
- else:
174
- params = [
175
- Field.format(name, main)
176
- ] + self.params
241
+ def __str__(self) -> str:
177
242
  return self.pattern.format(
178
- self.__class__.__name__,
179
- ', '.join(params)
243
+ func_name=self.__class__.__name__,
244
+ params=self.separator.join(str(p) for p in self.params)
180
245
  )
181
246
 
247
+ @classmethod
248
+ def help(cls) -> str:
249
+ descr = ' '.join(B.__name__ for B in cls.__bases__)
250
+ params = cls.inputs or ''
251
+ return cls().get_pattern().format(
252
+ func_name=f'{descr} {cls.__name__}',
253
+ params=cls.separator.join(str(p) for p in params)
254
+ ) + f' Return {cls.output}'
255
+
256
+ def set_main_param(self, name: str, main: SQLObject) -> bool:
257
+ nested_functions = [
258
+ param for param in self.params if isinstance(param, Function)
259
+ ]
260
+ for func in nested_functions:
261
+ if func.inputs:
262
+ func.set_main_param(name, main)
263
+ return
264
+ new_params = [Field.format(name, main)]
265
+ if self.append_param:
266
+ self.params += new_params
267
+ else:
268
+ self.params = new_params + self.params
269
+
270
+ def __format(self, name: str, main: SQLObject) -> str:
271
+ if name not in '*_':
272
+ self.set_main_param(name, main)
273
+ return str(self)
274
+
182
275
  @classmethod
183
276
  def format(cls, name: str, main: SQLObject):
184
277
  return cls().__format(name, main)
@@ -196,39 +289,110 @@ class Function:
196
289
 
197
290
  # ---- String Functions: ---------------------------------
198
291
  class SubString(Function):
199
- ...
292
+ inputs = [CHAR, INT, INT]
293
+ output = CHAR
294
+
295
+ def get_pattern(self) -> str:
296
+ if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
297
+ return 'Substr({params})'
298
+ return super().get_pattern()
200
299
 
201
300
  # ---- Numeric Functions: --------------------------------
202
301
  class Round(Function):
203
- ...
302
+ inputs = [FLOAT]
303
+ output = FLOAT
204
304
 
205
305
  # --- Date Functions: ------------------------------------
206
306
  class DateDiff(Function):
307
+ inputs = [DATE]
308
+ output = DATE
309
+ append_param = True
310
+
311
+ def __str__(self) -> str:
312
+ def is_field_or_func(name: str) -> bool:
313
+ candidate = re.sub(
314
+ '[()]', '', name.split('.')[-1]
315
+ )
316
+ return candidate.isidentifier()
317
+ if self.dialect != Dialect.SQL_SERVER:
318
+ params = [str(p) for p in self.params]
319
+ return ' - '.join(
320
+ p if is_field_or_func(p) else f"'{p}'"
321
+ for p in params
322
+ ) # <==== Date subtract
323
+ return super().__str__()
324
+
325
+
326
+ class DatePart(Function):
327
+ inputs = [DATE]
328
+ output = INT
329
+
330
+ def get_pattern(self) -> str:
331
+ interval = self.__class__.__name__
332
+ database_type = {
333
+ Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
334
+ Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
335
+ }
336
+ if self.dialect in database_type:
337
+ return database_type[self.dialect]
338
+ return super().get_pattern()
339
+
340
+ class Year(DatePart):
207
341
  ...
208
- class Extract(Function):
342
+ class Month(DatePart):
209
343
  ...
210
- class DatePart(Function):
344
+ class Day(DatePart):
211
345
  ...
346
+
347
+
212
348
  class Current_Date(Function):
213
- ...
349
+ output = DATE
350
+
351
+ def get_pattern(self) -> str:
352
+ database_type = {
353
+ Dialect.ORACLE: SQL_CONST_SYSDATE,
354
+ Dialect.POSTGRESQL: SQL_CONST_CURR_DATE,
355
+ Dialect.SQL_SERVER: 'getDate()'
356
+ }
357
+ if self.dialect in database_type:
358
+ return database_type[self.dialect]
359
+ return super().get_pattern()
360
+ # --------------------------------------------------------
214
361
 
215
- class Aggregate:
362
+ class Frame:
216
363
  break_lines: bool = True
217
364
 
218
365
  def over(self, **args):
219
- keywords = ' '.join(
220
- '{}{} BY {}'.format(
221
- '\n\t\t' if self.break_lines else '',
222
- key.upper(), args[key]
223
- ) for key in ('partition', 'order')
224
- if key in args
225
- )
366
+ """
367
+ How to use:
368
+ over(field1=OrderBy, field2=Partition)
369
+ """
370
+ keywords = ''
371
+ for field, obj in args.items():
372
+ is_valid = any([
373
+ obj is OrderBy,
374
+ obj is Partition,
375
+ isinstance(obj, Rows),
376
+ ])
377
+ if not is_valid:
378
+ continue
379
+ keywords += '{}{} {}'.format(
380
+ '\n\t\t' if self.break_lines else ' ',
381
+ obj.cls_to_str(), field if field != '_' else ''
382
+ )
226
383
  if keywords and self.break_lines:
227
384
  keywords += '\n\t'
228
- self.pattern = '{}({})' + f' OVER({keywords})'
385
+ self.pattern = self.get_pattern() + f' OVER({keywords})'
229
386
  return self
230
387
 
231
388
 
389
+ class Aggregate(Frame):
390
+ inputs = [FLOAT]
391
+ output = FLOAT
392
+
393
+ class Window(Frame):
394
+ ...
395
+
232
396
  # ---- Aggregate Functions: -------------------------------
233
397
  class Avg(Aggregate, Function):
234
398
  ...
@@ -241,11 +405,32 @@ class Sum(Aggregate, Function):
241
405
  class Count(Aggregate, Function):
242
406
  ...
243
407
 
408
+ # ---- Window Functions: -----------------------------------
409
+ class Row_Number(Window, Function):
410
+ output = INT
411
+
412
+ class Rank(Window, Function):
413
+ output = INT
414
+
415
+ class Lag(Window, Function):
416
+ output = ANY
417
+
418
+ class Lead(Window, Function):
419
+ output = ANY
420
+
421
+
244
422
  # ---- Conversions and other Functions: ---------------------
245
423
  class Coalesce(Function):
246
- ...
424
+ inputs = [ANY]
425
+ output = ANY
426
+
247
427
  class Cast(Function):
248
- ...
428
+ inputs = [ANY]
429
+ output = ANY
430
+ separator = ' As '
431
+
432
+
433
+ FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
249
434
 
250
435
 
251
436
  class ExpressionField:
@@ -270,15 +455,20 @@ class ExpressionField:
270
455
  class FieldList:
271
456
  separator = ','
272
457
 
273
- def __init__(self, fields: list=[], class_types = [Field]):
458
+ def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
274
459
  if isinstance(fields, str):
275
460
  fields = [
276
461
  f.strip() for f in fields.split(self.separator)
277
462
  ]
278
463
  self.fields = fields
279
464
  self.class_types = class_types
465
+ self.ziped = ziped
280
466
 
281
467
  def add(self, name: str, main: SQLObject):
468
+ if self.ziped: # --- One class per field...
469
+ for field, class_type in zip(self.fields, self.class_types):
470
+ class_type.add(field, main)
471
+ return
282
472
  for field in self.fields:
283
473
  for class_type in self.class_types:
284
474
  class_type.add(field, main)
@@ -324,23 +514,35 @@ def quoted(value) -> str:
324
514
  return str(value)
325
515
 
326
516
 
517
+ class Position(Enum):
518
+ Middle = 0
519
+ StartsWith = 1
520
+ EndsWith = 2
521
+
522
+
327
523
  class Where:
328
524
  prefix = ''
329
525
 
330
- def __init__(self, expr: str):
331
- self.expr = expr
526
+ def __init__(self, content: str):
527
+ self.content = content
332
528
 
333
529
  @classmethod
334
530
  def __constructor(cls, operator: str, value):
335
- return cls(expr=f'{operator} {quoted(value)}')
531
+ return cls(f'{operator} {quoted(value)}')
336
532
 
337
533
  @classmethod
338
534
  def eq(cls, value):
339
535
  return cls.__constructor('=', value)
340
536
 
341
537
  @classmethod
342
- def contains(cls, value: str):
343
- return cls(f"LIKE '%{value}%'")
538
+ def contains(cls, text: str, pos: Position = Position.Middle):
539
+ return cls(
540
+ "LIKE '{}{}{}'".format(
541
+ '%' if pos != Position.StartsWith else '',
542
+ text,
543
+ '%' if pos != Position.EndsWith else ''
544
+ )
545
+ )
344
546
 
345
547
  @classmethod
346
548
  def gt(cls, value):
@@ -368,9 +570,42 @@ class Where:
368
570
  values = ','.join(quoted(v) for v in values)
369
571
  return cls(f'IN ({values})')
370
572
 
573
+ @classmethod
574
+ def formula(cls, formula: str):
575
+ where = cls( ExpressionField(formula) )
576
+ where.add = where.add_expression
577
+ return where
578
+
579
+ def add_expression(self, name: str, main: SQLObject):
580
+ self.content = self.content.format(name, main)
581
+ main.values.setdefault(WHERE, []).append('{} {}'.format(
582
+ self.prefix, self.content
583
+ ))
584
+
585
+ @classmethod
586
+ def join(cls, query: SQLObject):
587
+ where = cls(query)
588
+ where.add = where.add_join
589
+ return where
590
+
591
+ def add_join(self, name: str, main: SQLObject):
592
+ query = self.content
593
+ main.values[FROM].append(f',{query.table_name} {query.alias}')
594
+ for key in USUAL_KEYS:
595
+ main.update_values(key, query.values.get(key, []))
596
+ main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
597
+ a1=main.alias, f1=name,
598
+ a2=query.alias, f2=query.key_field
599
+ ))
600
+
371
601
  def add(self, name: str, main: SQLObject):
602
+ func_type = FUNCTION_CLASS.get(name.lower())
603
+ if func_type:
604
+ name = func_type.format('*', main)
605
+ elif not main.has_named_field(name):
606
+ name = Field.format(name, main)
372
607
  main.values.setdefault(WHERE, []).append('{}{} {}'.format(
373
- self.prefix, Field.format(name, main), self.expr
608
+ self.prefix, name, self.content
374
609
  ))
375
610
 
376
611
 
@@ -378,6 +613,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
378
613
  getattr(Where, method) for method in
379
614
  ('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
380
615
  )
616
+ startswith, endswith = [
617
+ lambda x: contains(x, Position.StartsWith),
618
+ lambda x: contains(x, Position.EndsWith)
619
+ ]
381
620
 
382
621
 
383
622
  class Not(Where):
@@ -385,7 +624,7 @@ class Not(Where):
385
624
 
386
625
  @classmethod
387
626
  def eq(cls, value):
388
- return Where(expr=f'<> {quoted(value)}')
627
+ return Where(f'<> {quoted(value)}')
389
628
 
390
629
 
391
630
  class Case:
@@ -394,20 +633,24 @@ class Case:
394
633
  self.default = None
395
634
  self.field = field
396
635
 
397
- def when(self, condition: Where, result: str):
636
+ def when(self, condition: Where, result):
637
+ if isinstance(result, str):
638
+ result = quoted(result)
398
639
  self.__conditions[result] = condition
399
640
  return self
400
641
 
401
- def else_value(self, default: str):
642
+ def else_value(self, default):
643
+ if isinstance(default, str):
644
+ default = quoted(default)
402
645
  self.default = default
403
646
  return self
404
647
 
405
648
  def add(self, name: str, main: SQLObject):
406
649
  field = Field.format(self.field, main)
407
- default = quoted(self.default)
650
+ default = self.default
408
651
  name = 'CASE \n{}\n\tEND AS {}'.format(
409
652
  '\n'.join(
410
- f'\t\tWHEN {field} {cond.expr} THEN {quoted(res)}'
653
+ f'\t\tWHEN {field} {cond.content} THEN {res}'
411
654
  for res, cond in self.__conditions.items()
412
655
  ) + f'\n\t\tELSE {default}' if default else '',
413
656
  name
@@ -420,14 +663,13 @@ class Options:
420
663
  self.__children: dict = values
421
664
 
422
665
  def add(self, logical_separator: str, main: SQLObject):
423
- """
424
- `logical_separator` must be AND or OR
425
- """
666
+ if logical_separator not in ('AND', 'OR'):
667
+ raise ValueError('`logical_separator` must be AND or OR')
426
668
  conditions: list[str] = []
427
669
  child: Where
428
670
  for field, child in self.__children.items():
429
671
  conditions.append(' {} {} '.format(
430
- Field.format(field, main), child.expr
672
+ Field.format(field, main), child.content
431
673
  ))
432
674
  main.values.setdefault(WHERE, []).append(
433
675
  '(' + logical_separator.join(conditions) + ')'
@@ -445,18 +687,25 @@ class Between:
445
687
  Where.gte(self.start).add(name, main),
446
688
  Where.lte(self.end).add(name, main)
447
689
 
690
+ class SameDay(Between):
691
+ def __init__(self, date: str):
692
+ super().__init__(
693
+ f'{date} 00:00:00',
694
+ f'{date} 23:59:59',
695
+ )
696
+
697
+
448
698
 
449
699
  class Clause:
450
700
  @classmethod
451
701
  def format(cls, name: str, main: SQLObject) -> str:
452
702
  def is_function() -> bool:
453
703
  diff = main.diff(SELECT, [name.lower()], True)
454
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
455
704
  return diff.intersection(FUNCTION_CLASS)
456
705
  found = re.findall(r'^_\d', name)
457
706
  if found:
458
707
  name = found[0].replace('_', '')
459
- elif main.alias and not is_function():
708
+ elif '.' not in name and main.alias and not is_function():
460
709
  name = f'{main.alias}.{name}'
461
710
  return name
462
711
 
@@ -465,6 +714,34 @@ class SortType(Enum):
465
714
  ASC = ''
466
715
  DESC = ' DESC'
467
716
 
717
+ class Row:
718
+ def __init__(self, value: int=0):
719
+ self.value = value
720
+
721
+ def __str__(self) -> str:
722
+ return '{} {}'.format(
723
+ 'UNBOUNDED' if self.value == 0 else self.value,
724
+ self.__class__.__name__.upper()
725
+ )
726
+
727
+ class Preceding(Row):
728
+ ...
729
+ class Following(Row):
730
+ ...
731
+ class Current(Row):
732
+ def __str__(self) -> str:
733
+ return 'CURRENT ROW'
734
+
735
+ class Rows:
736
+ def __init__(self, *rows: list[Row]):
737
+ self.rows = rows
738
+
739
+ def cls_to_str(self) -> str:
740
+ return 'ROWS {}{}'.format(
741
+ 'BETWEEN ' if len(self.rows) > 1 else '',
742
+ ' AND '.join(str(row) for row in self.rows)
743
+ )
744
+
468
745
 
469
746
  class OrderBy(Clause):
470
747
  sort: SortType = SortType.ASC
@@ -474,6 +751,16 @@ class OrderBy(Clause):
474
751
  name = cls.format(name, main)
475
752
  main.values.setdefault(ORDER_BY, []).append(name+cls.sort.value)
476
753
 
754
+ @classmethod
755
+ def cls_to_str(cls) -> str:
756
+ return ORDER_BY
757
+
758
+ PARTITION_BY = 'PARTITION BY'
759
+ class Partition:
760
+ @classmethod
761
+ def cls_to_str(cls) -> str:
762
+ return PARTITION_BY
763
+
477
764
 
478
765
  class GroupBy(Clause):
479
766
  @classmethod
@@ -489,7 +776,7 @@ class Having:
489
776
 
490
777
  def add(self, name: str, main:SQLObject):
491
778
  main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
492
- self.function.format(name, main), self.condition.expr
779
+ self.function.format(name, main), self.condition.content
493
780
  )
494
781
 
495
782
  @classmethod
@@ -519,7 +806,7 @@ class Rule:
519
806
  ...
520
807
 
521
808
  class QueryLanguage:
522
- pattern = '{select}{_from}{where}{group_by}{order_by}'
809
+ pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
523
810
  has_default = {key: bool(key == SELECT) for key in KEYWORD}
524
811
 
525
812
  @staticmethod
@@ -542,18 +829,21 @@ class QueryLanguage:
542
829
  return self.join_with_tabs(values, ' AND ')
543
830
 
544
831
  def sort_by(self, values: list) -> str:
545
- return self.join_with_tabs(values)
832
+ return self.join_with_tabs(values, ',')
546
833
 
547
834
  def set_group(self, values: list) -> str:
548
835
  return self.join_with_tabs(values, ',')
549
836
 
837
+ def set_limit(self, values: list) -> str:
838
+ return self.join_with_tabs(values, ' ')
839
+
550
840
  def __init__(self, target: 'Select'):
551
- self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
841
+ self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
552
842
  self.TABULATION = '\n\t' if target.break_lines else ' '
553
843
  self.LINE_BREAK = '\n' if target.break_lines else ' '
554
844
  self.TOKEN_METHODS = {
555
845
  SELECT: self.add_field, FROM: self.get_tables,
556
- WHERE: self.extract_conditions,
846
+ WHERE: self.extract_conditions, LIMIT: self.set_limit,
557
847
  ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
558
848
  }
559
849
  self.result = {}
@@ -857,10 +1147,13 @@ class SQLParser(Parser):
857
1147
  if not key in values:
858
1148
  continue
859
1149
  separator = self.class_type.get_separator(key)
1150
+ cls = {
1151
+ ORDER_BY: OrderBy, GROUP_BY: GroupBy
1152
+ }.get(key, Field)
860
1153
  obj.values[key] = [
861
- Field.format(fld, obj)
1154
+ cls.format(fld, obj)
862
1155
  for fld in re.split(separator, values[key])
863
- if (fld != '*' and len(tables) == 1) or obj.match(fld)
1156
+ if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
864
1157
  ]
865
1158
  result[obj.alias] = obj
866
1159
  self.queries = list( result.values() )
@@ -920,16 +1213,26 @@ class CypherParser(Parser):
920
1213
  if token in self.TOKEN_METHODS:
921
1214
  return
922
1215
  class_list = [Field]
923
- if '$' in token:
1216
+ if '*' in token:
1217
+ token = token.replace('*', '')
1218
+ self.queries[-1].key_field = token
1219
+ return
1220
+ elif '$' in token:
924
1221
  func_name, token = token.split('$')
925
1222
  if func_name == 'count':
926
1223
  if not token:
927
1224
  token = 'count_1'
928
- NamedField(token, Count).add('*', self.queries[-1])
929
- class_list = []
1225
+ pk_field = self.queries[-1].key_field or 'id'
1226
+ Count().As(token, extra_classes).add(pk_field, self.queries[-1])
1227
+ return
930
1228
  else:
931
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
932
- class_list = [ FUNCTION_CLASS[func_name] ]
1229
+ class_type = FUNCTION_CLASS.get(func_name)
1230
+ if not class_type:
1231
+ raise ValueError(f'Unknown function `{func_name}`.')
1232
+ if ':' in token:
1233
+ token, field_alias = token.split(':')
1234
+ class_type = class_type().As(field_alias)
1235
+ class_list = [class_type]
933
1236
  class_list += extra_classes
934
1237
  FieldList(token, class_list).add('', self.queries[-1])
935
1238
 
@@ -944,10 +1247,13 @@ class CypherParser(Parser):
944
1247
  def add_foreign_key(self, token: str, pk_field: str=''):
945
1248
  curr, last = [self.queries[i] for i in (-1, -2)]
946
1249
  if not pk_field:
947
- if not last.values.get(SELECT):
948
- raise IndexError(f'Primary Key not found for {last.table_name}.')
949
- pk_field = last.values[SELECT][-1].split('.')[-1]
950
- last.delete(pk_field, [SELECT])
1250
+ if last.key_field:
1251
+ pk_field = last.key_field
1252
+ else:
1253
+ if not last.values.get(SELECT):
1254
+ raise IndexError(f'Primary Key not found for {last.table_name}.')
1255
+ pk_field = last.values[SELECT][-1].split('.')[-1]
1256
+ last.delete(pk_field, [SELECT], exact=True)
951
1257
  if '{}' in token:
952
1258
  foreign_fld = token.format(
953
1259
  last.table_name.lower()
@@ -962,12 +1268,11 @@ class CypherParser(Parser):
962
1268
  if fld not in curr.values.get(GROUP_BY, [])
963
1269
  ]
964
1270
  foreign_fld = fields[0].split('.')[-1]
965
- curr.delete(foreign_fld, [SELECT])
1271
+ curr.delete(foreign_fld, [SELECT], exact=True)
966
1272
  if curr.join_type == JoinType.RIGHT:
967
1273
  pk_field, foreign_fld = foreign_fld, pk_field
968
1274
  if curr.join_type == JoinType.RIGHT:
969
1275
  curr, last = last, curr
970
- # pk_field, foreign_fld = foreign_fld, pk_field
971
1276
  k = ForeignKey.get_key(curr, last)
972
1277
  ForeignKey.references[k] = (foreign_fld, pk_field)
973
1278
 
@@ -1153,21 +1458,30 @@ class Select(SQLObject):
1153
1458
 
1154
1459
  def add(self, name: str, main: SQLObject):
1155
1460
  old_tables = main.values.get(FROM, [])
1156
- new_tables = set([
1157
- '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1461
+ if len(self.values[FROM]) > 1:
1462
+ old_tables += self.values[FROM][1:]
1463
+ new_tables = []
1464
+ row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1158
1465
  jt=self.join_type.value,
1159
1466
  tb=self.aka(),
1160
1467
  a1=main.alias, f1=name,
1161
1468
  a2=self.alias, f2=self.key_field
1162
1469
  )
1163
- ] + old_tables[1:])
1164
- main.values[FROM] = old_tables[:1] + list(new_tables)
1470
+ if row not in old_tables[1:]:
1471
+ new_tables.append(row)
1472
+ main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
1165
1473
  for key in USUAL_KEYS:
1166
1474
  main.update_values(key, self.values.get(key, []))
1167
1475
 
1168
- def __add__(self, other: SQLObject):
1476
+ def copy(self) -> SQLObject:
1169
1477
  from copy import deepcopy
1170
- query = deepcopy(self)
1478
+ return deepcopy(self)
1479
+
1480
+ def no_relation_error(self, other: SQLObject):
1481
+ raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
1482
+
1483
+ def __add__(self, other: SQLObject):
1484
+ query = self.copy()
1171
1485
  if query.table_name.lower() == other.table_name.lower():
1172
1486
  for key in USUAL_KEYS:
1173
1487
  query.update_values(key, other.values.get(key, []))
@@ -1180,7 +1494,7 @@ class Select(SQLObject):
1180
1494
  PrimaryKey.add(primary_key, query)
1181
1495
  query.add(foreign_field, other)
1182
1496
  return other
1183
- raise ValueError(f'No relationship found between {query.table_name} and {other.table_name}.')
1497
+ self.no_relation_error(other) # === raise ERROR ... ===
1184
1498
  elif primary_key:
1185
1499
  PrimaryKey.add(primary_key, other)
1186
1500
  other.add(foreign_field, query)
@@ -1200,16 +1514,48 @@ class Select(SQLObject):
1200
1514
  if self.diff(key, other.values.get(key, []), True):
1201
1515
  return False
1202
1516
  return True
1517
+
1518
+ def __sub__(self, other: SQLObject) -> SQLObject:
1519
+ fk_field, primary_k = ForeignKey.find(self, other)
1520
+ if fk_field:
1521
+ query = self.copy()
1522
+ other = other.copy()
1523
+ else:
1524
+ fk_field, primary_k = ForeignKey.find(other, self)
1525
+ if not fk_field:
1526
+ self.no_relation_error(other) # === raise ERROR ... ===
1527
+ query = other.copy()
1528
+ other = self.copy()
1529
+ query.__class__ = NotSelectIN
1530
+ Field.add(fk_field, query)
1531
+ query.add(primary_k, other)
1532
+ return other
1203
1533
 
1204
1534
  def limit(self, row_count: int=100, offset: int=0):
1205
- result = [str(row_count)]
1206
- if offset > 0:
1207
- result.append(f'OFFSET {offset}')
1208
- self.values.setdefault(LIMIT, result)
1535
+ if Function.dialect == Dialect.SQL_SERVER:
1536
+ fields = self.values.get(SELECT)
1537
+ if fields:
1538
+ fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
1539
+ else:
1540
+ self.values[SELECT] = [f'SELECT TOP({row_count}) *']
1541
+ return self
1542
+ if Function.dialect == Dialect.ORACLE:
1543
+ Where.gte(row_count).add(SQL_ROW_NUM, self)
1544
+ if offset > 0:
1545
+ Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
1546
+ return self
1547
+ self.values[LIMIT] = ['{}{}'.format(
1548
+ row_count, f' OFFSET {offset}' if offset > 0 else ''
1549
+ )]
1209
1550
  return self
1210
1551
 
1211
- def match(self, expr: str) -> bool:
1212
- return re.findall(f'\b*{self.alias}[.]', expr) != []
1552
+ def match(self, field: str, key: str) -> bool:
1553
+ '''
1554
+ Recognizes if the field is from the current table
1555
+ '''
1556
+ if key in (ORDER_BY, GROUP_BY) and '.' not in field:
1557
+ return self.has_named_field(field)
1558
+ return re.findall(f'\b*{self.alias}[.]', field) != []
1213
1559
 
1214
1560
  @classmethod
1215
1561
  def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
@@ -1221,12 +1567,10 @@ class Select(SQLObject):
1221
1567
  for rule in rules:
1222
1568
  rule.apply(self)
1223
1569
 
1224
- def add_fields(self, fields: list, order_by: bool=False, group_by:bool=False):
1225
- class_types = [Field]
1226
- if order_by:
1227
- class_types += [OrderBy]
1228
- if group_by:
1229
- class_types += [GroupBy]
1570
+ def add_fields(self, fields: list, class_types=None):
1571
+ if not class_types:
1572
+ class_types = []
1573
+ class_types += [Field]
1230
1574
  FieldList(fields, class_types).add('', self)
1231
1575
 
1232
1576
  def translate_to(self, language: QueryLanguage) -> str:
@@ -1246,6 +1590,95 @@ class NotSelectIN(SelectIN):
1246
1590
  condition_class = Not
1247
1591
 
1248
1592
 
1593
+ class CTE(Select):
1594
+ prefix = ''
1595
+
1596
+ def __init__(self, table_name: str, query_list: list[Select]):
1597
+ super().__init__(table_name)
1598
+ for query in query_list:
1599
+ query.break_lines = False
1600
+ self.query_list = query_list
1601
+ self.break_lines = False
1602
+
1603
+ def __str__(self) -> str:
1604
+ size = 0
1605
+ for key in USUAL_KEYS:
1606
+ size += sum(len(v) for v in self.values.get(key, []) if '\n' not in v)
1607
+ if size > 70:
1608
+ self.break_lines = True
1609
+ # ---------------------------------------------------------
1610
+ def justify(query: Select) -> str:
1611
+ result, line = [], ''
1612
+ keywords = '|'.join(KEYWORD)
1613
+ for word in re.split(fr'({keywords}|AND|OR|,)', str(query)):
1614
+ if len(line) >= 50:
1615
+ result.append(line)
1616
+ line = ''
1617
+ line += word
1618
+ if line:
1619
+ result.append(line)
1620
+ return '\n '.join(result)
1621
+ # ---------------------------------------------------------
1622
+ return 'WITH {}{} AS (\n {}\n){}'.format(
1623
+ self.prefix, self.table_name,
1624
+ '\nUNION ALL\n '.join(
1625
+ justify(q) for q in self.query_list
1626
+ ), super().__str__()
1627
+ )
1628
+
1629
+ def join(self, pattern: str, fields: list | str, format: str=''):
1630
+ if isinstance(fields, str):
1631
+ count = len( fields.split(',') )
1632
+ else:
1633
+ count = len(fields)
1634
+ queries = detect(
1635
+ pattern*count, join_queries=False, format=format
1636
+ )
1637
+ FieldList(fields, queries, ziped=True).add('', self)
1638
+ self.break_lines = True
1639
+ return self
1640
+
1641
+ class Recursive(CTE):
1642
+ prefix = 'RECURSIVE '
1643
+
1644
+ def __str__(self) -> str:
1645
+ if len(self.query_list) > 1:
1646
+ self.query_list[-1].values[FROM].append(
1647
+ f', {self.table_name} {self.alias}')
1648
+ return super().__str__()
1649
+
1650
+ @classmethod
1651
+ def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
1652
+ SQLObject.ALIAS_FUNC = None
1653
+ def get_field(obj: SQLObject, pos: int) -> str:
1654
+ return obj.values[SELECT][pos].split('.')[-1]
1655
+ t1, t2 = detect(
1656
+ pattern*2, join_queries=False, format=format
1657
+ )
1658
+ pk_field = get_field(t1, 0)
1659
+ foreign_key = ''
1660
+ for num in re.findall(r'\[(\d+)\]', formula):
1661
+ num = int(num)
1662
+ if not foreign_key:
1663
+ foreign_key = get_field(t2, num-1)
1664
+ formula = formula.replace(f'[{num}]', '%')
1665
+ else:
1666
+ formula = formula.replace(f'[{num}]', get_field(t2, num-1))
1667
+ Where.eq(init_value).add(pk_field, t1)
1668
+ Where.formula(formula).add(foreign_key or pk_field, t2)
1669
+ return cls(name, [t1, t2])
1670
+
1671
+ def counter(self, name: str, start, increment: str='+1'):
1672
+ for i, query in enumerate(self.query_list):
1673
+ if i == 0:
1674
+ Field.add(f'{start} AS {name}', query)
1675
+ else:
1676
+ Field.add(f'({name}{increment}) AS {name}', query)
1677
+ return self
1678
+
1679
+
1680
+ # ----- Rules -----
1681
+
1249
1682
  class RulePutLimit(Rule):
1250
1683
  @classmethod
1251
1684
  def apply(cls, target: Select):
@@ -1309,6 +1742,8 @@ class RuleDateFuncReplace(Rule):
1309
1742
  @classmethod
1310
1743
  def apply(cls, target: Select):
1311
1744
  for i, condition in enumerate(target.values.get(WHERE, [])):
1745
+ if not '(' in condition:
1746
+ continue
1312
1747
  tokens = [
1313
1748
  t.strip() for t in cls.REGEX.split(condition) if t.strip()
1314
1749
  ]
@@ -1320,6 +1755,32 @@ class RuleDateFuncReplace(Rule):
1320
1755
  target.values[WHERE][i] = ' AND '.join(temp.values[WHERE])
1321
1756
 
1322
1757
 
1758
+ class RuleReplaceJoinBySubselect(Rule):
1759
+ @classmethod
1760
+ def apply(cls, target: Select):
1761
+ main, *others = Select.parse( str(target) )
1762
+ modified = False
1763
+ for query in others:
1764
+ fk_field, primary_k = ForeignKey.find(main, query)
1765
+ more_relations = any([
1766
+ ref[0] == query.table_name for ref in ForeignKey.references
1767
+ ])
1768
+ keep_join = any([
1769
+ len( query.values.get(SELECT, []) ) > 0,
1770
+ len( query.values.get(WHERE, []) ) == 0,
1771
+ not fk_field, more_relations
1772
+ ])
1773
+ if keep_join:
1774
+ query.add(fk_field, main)
1775
+ continue
1776
+ query.__class__ = SubSelect
1777
+ Field.add(primary_k, query)
1778
+ query.add(fk_field, main)
1779
+ modified = True
1780
+ if modified:
1781
+ target.values = main.values.copy()
1782
+
1783
+
1323
1784
  def parser_class(text: str) -> Parser:
1324
1785
  PARSER_REGEX = [
1325
1786
  (r'select.*from', SQLParser),
@@ -1334,7 +1795,7 @@ def parser_class(text: str) -> Parser:
1334
1795
  return None
1335
1796
 
1336
1797
 
1337
- def detect(text: str) -> Select:
1798
+ def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
1338
1799
  from collections import Counter
1339
1800
  parser = parser_class(text)
1340
1801
  if not parser:
@@ -1345,21 +1806,19 @@ def detect(text: str) -> Select:
1345
1806
  continue
1346
1807
  pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
1347
1808
  for begin, end in pos[::-1]:
1348
- new_name = f'{table}_{count}' # See set_table (line 45)
1809
+ new_name = f'{table}_{count}' # See set_table (line 55)
1349
1810
  Select.EQUIVALENT_NAMES[new_name] = table
1350
1811
  text = text[:begin] + new_name + '(' + text[end:]
1351
1812
  count -= 1
1352
1813
  query_list = Select.parse(text, parser)
1814
+ if format:
1815
+ for query in query_list:
1816
+ query.set_file_format(format)
1817
+ if not join_queries:
1818
+ return query_list
1353
1819
  result = query_list[0]
1354
1820
  for query in query_list[1:]:
1355
1821
  result += query
1356
1822
  return result
1823
+ # ===========================================================================================//
1357
1824
 
1358
- if __name__ == "__main__":
1359
- OrderBy.sort = SortType.DESC
1360
- query = Select(
1361
- 'order_Detail d',
1362
- customer_id=GroupBy,
1363
- _=Sum('d.unitPrice * d.quantity').As('total', OrderBy)
1364
- )
1365
- print(query)