sql-blocks 1.25.2__py3-none-any.whl → 1.25.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_blocks/sql_blocks.py CHANGED
@@ -26,7 +26,7 @@ TO_LIST = lambda x: x if isinstance(x, list) else [x]
26
26
 
27
27
 
28
28
  class SQLObject:
29
- ALIAS_FUNC = lambda t: t.lower()[:3]
29
+ ALIAS_FUNC = None
30
30
  """ ^^^^^^^^^^^^^^^^^^^^^^^^
31
31
  You can change the behavior by assigning
32
32
  a user function to SQLObject.ALIAS_FUNC
@@ -41,21 +41,35 @@ class SQLObject:
41
41
  def set_table(self, table_name: str):
42
42
  if not table_name:
43
43
  return
44
- if ' ' in table_name.strip():
44
+ cls = SQLObject
45
+ is_file_name = any([
46
+ '/' in table_name, '.' in table_name
47
+ ])
48
+ ref = table_name
49
+ if is_file_name:
50
+ ref = table_name.split('/')[-1].split('.')[0]
51
+ if cls.ALIAS_FUNC:
52
+ self.__alias = cls.ALIAS_FUNC(ref)
53
+ elif ' ' in table_name.strip():
45
54
  table_name, self.__alias = table_name.split()
46
- elif '_' in table_name:
55
+ elif '_' in ref:
47
56
  self.__alias = ''.join(
48
57
  word[0].lower()
49
- for word in table_name.split('_')
58
+ for word in ref.split('_')
50
59
  )
51
60
  else:
52
- self.__alias = SQLObject.ALIAS_FUNC(table_name)
61
+ self.__alias = ref.lower()[:3]
53
62
  self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
54
63
 
55
64
  @property
56
65
  def table_name(self) -> str:
57
66
  return self.values[FROM][0].split()[0]
58
67
 
68
+ def set_file_format(self, pattern: str):
69
+ if '{' not in pattern:
70
+ pattern = '{}' + pattern
71
+ self.values[FROM][0] = pattern.format(self.aka())
72
+
59
73
  @property
60
74
  def alias(self) -> str:
61
75
  if self.__alias:
@@ -67,6 +81,16 @@ class SQLObject:
67
81
  appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
68
82
  return KEYWORD[key][0].format(appendix.get(key, ''))
69
83
 
84
+ @staticmethod
85
+ def is_named_field(fld: str, name: str='') -> bool:
86
+ return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
87
+
88
+ def has_named_field(self, name: str) -> bool:
89
+ return any(
90
+ self.is_named_field(fld, name)
91
+ for fld in self.values.get(SELECT, [])
92
+ )
93
+
70
94
  def diff(self, key: str, search_list: list, exact: bool=False) -> set:
71
95
  def disassemble(source: list) -> list:
72
96
  if not exact:
@@ -79,12 +103,12 @@ class SQLObject:
79
103
  if exact:
80
104
  fld = fld.lower()
81
105
  return fld.strip()
82
- def is_named_field(fld: str) -> bool:
83
- return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
84
106
  def field_set(source: list) -> set:
85
107
  return set(
86
108
  (
87
- fld if is_named_field(fld) else
109
+ fld
110
+ if key == SELECT and self.is_named_field(fld, key)
111
+ else
88
112
  re.sub(pattern, '', cleanup(fld))
89
113
  )
90
114
  for string in disassemble(source)
@@ -102,13 +126,22 @@ class SQLObject:
102
126
  return s1.symmetric_difference(s2)
103
127
  return s1 - s2
104
128
 
105
- def delete(self, search: str, keys: list=USUAL_KEYS):
129
+ def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
130
+ if exact:
131
+ not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
132
+ else:
133
+ not_match = lambda item: search not in item
106
134
  for key in keys:
107
- result = []
108
- for item in self.values.get(key, []):
109
- if search not in item:
110
- result.append(item)
111
- self.values[key] = result
135
+ self.values[key] = [
136
+ item for item in self.values.get(key, [])
137
+ if not_match(item)
138
+ ]
139
+
140
+
141
+ SQL_CONST_SYSDATE = 'SYSDATE'
142
+ SQL_CONST_CURR_DATE = 'Current_date'
143
+ SQL_ROW_NUM = 'ROWNUM'
144
+ SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
112
145
 
113
146
 
114
147
  class Field:
@@ -116,10 +149,16 @@ class Field:
116
149
 
117
150
  @classmethod
118
151
  def format(cls, name: str, main: SQLObject) -> str:
152
+ def is_const() -> bool:
153
+ return any([
154
+ re.findall('[.()0-9]', name),
155
+ name in SQL_CONSTS,
156
+ re.findall(r'\w+\s*[+-]\s*\w+', name)
157
+ ])
119
158
  name = name.strip()
120
159
  if name in ('_', '*'):
121
160
  name = '*'
122
- elif not re.findall('[.()0-9]', name):
161
+ elif not is_const() and not main.has_named_field(name):
123
162
  name = f'{main.alias}.{name}'
124
163
  if Function in cls.__bases__:
125
164
  name = f'{cls.__name__}({name})'
@@ -150,90 +189,210 @@ class NamedField:
150
189
  )
151
190
 
152
191
 
192
+ class Dialect(Enum):
193
+ ANSI = 0
194
+ SQL_SERVER = 1
195
+ ORACLE = 2
196
+ POSTGRESQL = 3
197
+ MYSQL = 4
198
+
199
+ SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
200
+ CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
201
+
153
202
  class Function:
154
- instance: dict = {}
203
+ dialect = Dialect.ANSI
204
+ inputs = None
205
+ output = None
206
+ separator = ', '
207
+ auto_convert = True
208
+ append_param = False
155
209
 
156
210
  def __init__(self, *params: list):
157
- func_name = self.__class__.__name__
158
- Function.instance[func_name] = self
159
- self.params = [str(p) for p in params]
160
- self.class_type = Field
161
- self.pattern = '{}({})'
211
+ def set_func_types(param):
212
+ if self.auto_convert and isinstance(param, Function):
213
+ func = param
214
+ main_param = self.inputs[0]
215
+ unfriendly = all([
216
+ func.output != main_param,
217
+ func.output != ANY,
218
+ main_param != ANY
219
+ ])
220
+ if unfriendly:
221
+ return Cast(func, main_param)
222
+ return param
223
+ # --- Replace class methods by instance methods: ------
224
+ self.add = self.__add
225
+ self.format = self.__format
226
+ # -----------------------------------------------------
227
+ self.params = [set_func_types(p) for p in params]
228
+ self.field_class = Field
229
+ self.pattern = self.get_pattern()
162
230
  self.extra = {}
163
231
 
232
+ def get_pattern(self) -> str:
233
+ return '{func_name}({params})'
234
+
164
235
  def As(self, field_alias: str, modifiers=None):
165
236
  if modifiers:
166
237
  self.extra[field_alias] = TO_LIST(modifiers)
167
- self.class_type = NamedField(field_alias)
238
+ self.field_class = NamedField(field_alias)
168
239
  return self
169
240
 
241
+ def __str__(self) -> str:
242
+ return self.pattern.format(
243
+ func_name=self.__class__.__name__,
244
+ params=self.separator.join(str(p) for p in self.params)
245
+ )
246
+
170
247
  @classmethod
171
- def format(cls, name: str, main: SQLObject) -> str:
172
- obj = cls.get_instance()
173
- if name in '*_' and obj.params:
174
- params = obj.params
248
+ def help(cls) -> str:
249
+ descr = ' '.join(B.__name__ for B in cls.__bases__)
250
+ params = cls.inputs or ''
251
+ return cls().get_pattern().format(
252
+ func_name=f'{descr} {cls.__name__}',
253
+ params=cls.separator.join(str(p) for p in params)
254
+ ) + f' Return {cls.output}'
255
+
256
+ def set_main_param(self, name: str, main: SQLObject) -> bool:
257
+ nested_functions = [
258
+ param for param in self.params if isinstance(param, Function)
259
+ ]
260
+ for func in nested_functions:
261
+ if func.inputs:
262
+ func.set_main_param(name, main)
263
+ return
264
+ new_params = [Field.format(name, main)]
265
+ if self.append_param:
266
+ self.params += new_params
175
267
  else:
176
- params = [
177
- Field.format(name, main)
178
- ] + obj.params
179
- return obj.pattern.format(
180
- cls.__name__,
181
- ', '.join(params)
182
- )
268
+ self.params = new_params + self.params
269
+
270
+ def __format(self, name: str, main: SQLObject) -> str:
271
+ if name not in '*_':
272
+ self.set_main_param(name, main)
273
+ return str(self)
274
+
275
+ @classmethod
276
+ def format(cls, name: str, main: SQLObject):
277
+ return cls().__format(name, main)
183
278
 
184
279
  def __add(self, name: str, main: SQLObject):
185
280
  name = self.format(name, main)
186
- self.class_type.add(name, main)
281
+ self.field_class.add(name, main)
187
282
  if self.extra:
188
283
  main.__call__(**self.extra)
189
284
 
190
- @classmethod
191
- def get_instance(cls):
192
- obj = Function.instance.get(cls.__name__)
193
- if not obj:
194
- obj = cls()
195
- return obj
196
-
197
285
  @classmethod
198
286
  def add(cls, name: str, main: SQLObject):
199
- cls.get_instance().__add(name, main)
287
+ cls().__add(name, main)
200
288
 
201
289
 
202
290
  # ---- String Functions: ---------------------------------
203
291
  class SubString(Function):
204
- ...
292
+ inputs = [CHAR, INT, INT]
293
+ output = CHAR
294
+
295
+ def get_pattern(self) -> str:
296
+ if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
297
+ return 'Substr({params})'
298
+ return super().get_pattern()
205
299
 
206
300
  # ---- Numeric Functions: --------------------------------
207
301
  class Round(Function):
208
- ...
302
+ inputs = [FLOAT]
303
+ output = FLOAT
209
304
 
210
305
  # --- Date Functions: ------------------------------------
211
306
  class DateDiff(Function):
307
+ inputs = [DATE]
308
+ output = DATE
309
+ append_param = True
310
+
311
+ def __str__(self) -> str:
312
+ def is_field_or_func(name: str) -> bool:
313
+ candidate = re.sub(
314
+ '[()]', '', name.split('.')[-1]
315
+ )
316
+ return candidate.isidentifier()
317
+ if self.dialect != Dialect.SQL_SERVER:
318
+ params = [str(p) for p in self.params]
319
+ return ' - '.join(
320
+ p if is_field_or_func(p) else f"'{p}'"
321
+ for p in params
322
+ ) # <==== Date subtract
323
+ return super().__str__()
324
+
325
+
326
+ class DatePart(Function):
327
+ inputs = [DATE]
328
+ output = INT
329
+
330
+ def get_pattern(self) -> str:
331
+ interval = self.__class__.__name__
332
+ database_type = {
333
+ Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
334
+ Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
335
+ }
336
+ if self.dialect in database_type:
337
+ return database_type[self.dialect]
338
+ return super().get_pattern()
339
+
340
+ class Year(DatePart):
212
341
  ...
213
- class Extract(Function):
342
+ class Month(DatePart):
214
343
  ...
215
- class DatePart(Function):
344
+ class Day(DatePart):
216
345
  ...
346
+
347
+
217
348
  class Current_Date(Function):
218
- ...
349
+ output = DATE
350
+
351
+ def get_pattern(self) -> str:
352
+ database_type = {
353
+ Dialect.ORACLE: SQL_CONST_SYSDATE,
354
+ Dialect.POSTGRESQL: SQL_CONST_CURR_DATE,
355
+ Dialect.SQL_SERVER: 'getDate()'
356
+ }
357
+ if self.dialect in database_type:
358
+ return database_type[self.dialect]
359
+ return super().get_pattern()
360
+ # --------------------------------------------------------
219
361
 
220
- class Aggregate:
362
+ class Frame:
221
363
  break_lines: bool = True
222
364
 
223
365
  def over(self, **args):
224
- keywords = ' '.join(
225
- '{}{} BY {}'.format(
226
- '\n\t\t' if self.break_lines else '',
227
- key.upper(), args[key]
228
- ) for key in ('partition', 'order')
229
- if key in args
230
- )
366
+ """
367
+ How to use:
368
+ over(field1=OrderBy, field2=Partition)
369
+ """
370
+ keywords = ''
371
+ for field, obj in args.items():
372
+ is_valid = any([
373
+ obj is OrderBy,
374
+ obj is Partition,
375
+ isinstance(obj, Rows),
376
+ ])
377
+ if not is_valid:
378
+ continue
379
+ keywords += '{}{} {}'.format(
380
+ '\n\t\t' if self.break_lines else ' ',
381
+ obj.cls_to_str(), field if field != '_' else ''
382
+ )
231
383
  if keywords and self.break_lines:
232
384
  keywords += '\n\t'
233
- self.pattern = '{}({})' + f' OVER({keywords})'
385
+ self.pattern = self.get_pattern() + f' OVER({keywords})'
234
386
  return self
235
387
 
236
388
 
389
+ class Aggregate(Frame):
390
+ inputs = [FLOAT]
391
+ output = FLOAT
392
+
393
+ class Window(Frame):
394
+ ...
395
+
237
396
  # ---- Aggregate Functions: -------------------------------
238
397
  class Avg(Aggregate, Function):
239
398
  ...
@@ -246,11 +405,32 @@ class Sum(Aggregate, Function):
246
405
  class Count(Aggregate, Function):
247
406
  ...
248
407
 
408
+ # ---- Window Functions: -----------------------------------
409
+ class Row_Number(Window, Function):
410
+ output = INT
411
+
412
+ class Rank(Window, Function):
413
+ output = INT
414
+
415
+ class Lag(Window, Function):
416
+ output = ANY
417
+
418
+ class Lead(Window, Function):
419
+ output = ANY
420
+
421
+
249
422
  # ---- Conversions and other Functions: ---------------------
250
423
  class Coalesce(Function):
251
- ...
424
+ inputs = [ANY]
425
+ output = ANY
426
+
252
427
  class Cast(Function):
253
- ...
428
+ inputs = [ANY]
429
+ output = ANY
430
+ separator = ' As '
431
+
432
+
433
+ FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
254
434
 
255
435
 
256
436
  class ExpressionField:
@@ -275,15 +455,20 @@ class ExpressionField:
275
455
  class FieldList:
276
456
  separator = ','
277
457
 
278
- def __init__(self, fields: list=[], class_types = [Field]):
458
+ def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
279
459
  if isinstance(fields, str):
280
460
  fields = [
281
461
  f.strip() for f in fields.split(self.separator)
282
462
  ]
283
463
  self.fields = fields
284
464
  self.class_types = class_types
465
+ self.ziped = ziped
285
466
 
286
467
  def add(self, name: str, main: SQLObject):
468
+ if self.ziped: # --- One class per field...
469
+ for field, class_type in zip(self.fields, self.class_types):
470
+ class_type.add(field, main)
471
+ return
287
472
  for field in self.fields:
288
473
  for class_type in self.class_types:
289
474
  class_type.add(field, main)
@@ -329,23 +514,35 @@ def quoted(value) -> str:
329
514
  return str(value)
330
515
 
331
516
 
517
+ class Position(Enum):
518
+ Middle = 0
519
+ StartsWith = 1
520
+ EndsWith = 2
521
+
522
+
332
523
  class Where:
333
524
  prefix = ''
334
525
 
335
- def __init__(self, expr: str):
336
- self.expr = expr
526
+ def __init__(self, content: str):
527
+ self.content = content
337
528
 
338
529
  @classmethod
339
530
  def __constructor(cls, operator: str, value):
340
- return cls(expr=f'{operator} {quoted(value)}')
531
+ return cls(f'{operator} {quoted(value)}')
341
532
 
342
533
  @classmethod
343
534
  def eq(cls, value):
344
535
  return cls.__constructor('=', value)
345
536
 
346
537
  @classmethod
347
- def contains(cls, value: str):
348
- return cls(f"LIKE '%{value}%'")
538
+ def contains(cls, text: str, pos: Position = Position.Middle):
539
+ return cls(
540
+ "LIKE '{}{}{}'".format(
541
+ '%' if pos != Position.StartsWith else '',
542
+ text,
543
+ '%' if pos != Position.EndsWith else ''
544
+ )
545
+ )
349
546
 
350
547
  @classmethod
351
548
  def gt(cls, value):
@@ -373,9 +570,42 @@ class Where:
373
570
  values = ','.join(quoted(v) for v in values)
374
571
  return cls(f'IN ({values})')
375
572
 
573
+ @classmethod
574
+ def formula(cls, formula: str):
575
+ where = cls( ExpressionField(formula) )
576
+ where.add = where.add_expression
577
+ return where
578
+
579
+ def add_expression(self, name: str, main: SQLObject):
580
+ self.content = self.content.format(name, main)
581
+ main.values.setdefault(WHERE, []).append('{} {}'.format(
582
+ self.prefix, self.content
583
+ ))
584
+
585
+ @classmethod
586
+ def join(cls, query: SQLObject):
587
+ where = cls(query)
588
+ where.add = where.add_join
589
+ return where
590
+
591
+ def add_join(self, name: str, main: SQLObject):
592
+ query = self.content
593
+ main.values[FROM].append(f',{query.table_name} {query.alias}')
594
+ for key in USUAL_KEYS:
595
+ main.update_values(key, query.values.get(key, []))
596
+ main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
597
+ a1=main.alias, f1=name,
598
+ a2=query.alias, f2=query.key_field
599
+ ))
600
+
376
601
  def add(self, name: str, main: SQLObject):
602
+ func_type = FUNCTION_CLASS.get(name.lower())
603
+ if func_type:
604
+ name = func_type.format('*', main)
605
+ elif not main.has_named_field(name):
606
+ name = Field.format(name, main)
377
607
  main.values.setdefault(WHERE, []).append('{}{} {}'.format(
378
- self.prefix, Field.format(name, main), self.expr
608
+ self.prefix, name, self.content
379
609
  ))
380
610
 
381
611
 
@@ -383,6 +613,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
383
613
  getattr(Where, method) for method in
384
614
  ('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
385
615
  )
616
+ startswith, endswith = [
617
+ lambda x: contains(x, Position.StartsWith),
618
+ lambda x: contains(x, Position.EndsWith)
619
+ ]
386
620
 
387
621
 
388
622
  class Not(Where):
@@ -390,7 +624,7 @@ class Not(Where):
390
624
 
391
625
  @classmethod
392
626
  def eq(cls, value):
393
- return Where(expr=f'<> {quoted(value)}')
627
+ return Where(f'<> {quoted(value)}')
394
628
 
395
629
 
396
630
  class Case:
@@ -399,20 +633,24 @@ class Case:
399
633
  self.default = None
400
634
  self.field = field
401
635
 
402
- def when(self, condition: Where, result: str):
636
+ def when(self, condition: Where, result):
637
+ if isinstance(result, str):
638
+ result = quoted(result)
403
639
  self.__conditions[result] = condition
404
640
  return self
405
641
 
406
- def else_value(self, default: str):
642
+ def else_value(self, default):
643
+ if isinstance(default, str):
644
+ default = quoted(default)
407
645
  self.default = default
408
646
  return self
409
647
 
410
648
  def add(self, name: str, main: SQLObject):
411
649
  field = Field.format(self.field, main)
412
- default = quoted(self.default)
650
+ default = self.default
413
651
  name = 'CASE \n{}\n\tEND AS {}'.format(
414
652
  '\n'.join(
415
- f'\t\tWHEN {field} {cond.expr} THEN {quoted(res)}'
653
+ f'\t\tWHEN {field} {cond.content} THEN {res}'
416
654
  for res, cond in self.__conditions.items()
417
655
  ) + f'\n\t\tELSE {default}' if default else '',
418
656
  name
@@ -425,14 +663,13 @@ class Options:
425
663
  self.__children: dict = values
426
664
 
427
665
  def add(self, logical_separator: str, main: SQLObject):
428
- """
429
- `logical_separator` must be AND or OR
430
- """
666
+ if logical_separator not in ('AND', 'OR'):
667
+ raise ValueError('`logical_separator` must be AND or OR')
431
668
  conditions: list[str] = []
432
669
  child: Where
433
670
  for field, child in self.__children.items():
434
671
  conditions.append(' {} {} '.format(
435
- Field.format(field, main), child.expr
672
+ Field.format(field, main), child.content
436
673
  ))
437
674
  main.values.setdefault(WHERE, []).append(
438
675
  '(' + logical_separator.join(conditions) + ')'
@@ -450,18 +687,25 @@ class Between:
450
687
  Where.gte(self.start).add(name, main),
451
688
  Where.lte(self.end).add(name, main)
452
689
 
690
+ class SameDay(Between):
691
+ def __init__(self, date: str):
692
+ super().__init__(
693
+ f'{date} 00:00:00',
694
+ f'{date} 23:59:59',
695
+ )
696
+
697
+
453
698
 
454
699
  class Clause:
455
700
  @classmethod
456
701
  def format(cls, name: str, main: SQLObject) -> str:
457
702
  def is_function() -> bool:
458
703
  diff = main.diff(SELECT, [name.lower()], True)
459
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
460
704
  return diff.intersection(FUNCTION_CLASS)
461
705
  found = re.findall(r'^_\d', name)
462
706
  if found:
463
707
  name = found[0].replace('_', '')
464
- elif main.alias and not is_function():
708
+ elif '.' not in name and main.alias and not is_function():
465
709
  name = f'{main.alias}.{name}'
466
710
  return name
467
711
 
@@ -470,6 +714,34 @@ class SortType(Enum):
470
714
  ASC = ''
471
715
  DESC = ' DESC'
472
716
 
717
+ class Row:
718
+ def __init__(self, value: int=0):
719
+ self.value = value
720
+
721
+ def __str__(self) -> str:
722
+ return '{} {}'.format(
723
+ 'UNBOUNDED' if self.value == 0 else self.value,
724
+ self.__class__.__name__.upper()
725
+ )
726
+
727
+ class Preceding(Row):
728
+ ...
729
+ class Following(Row):
730
+ ...
731
+ class Current(Row):
732
+ def __str__(self) -> str:
733
+ return 'CURRENT ROW'
734
+
735
+ class Rows:
736
+ def __init__(self, *rows: list[Row]):
737
+ self.rows = rows
738
+
739
+ def cls_to_str(self) -> str:
740
+ return 'ROWS {}{}'.format(
741
+ 'BETWEEN ' if len(self.rows) > 1 else '',
742
+ ' AND '.join(str(row) for row in self.rows)
743
+ )
744
+
473
745
 
474
746
  class OrderBy(Clause):
475
747
  sort: SortType = SortType.ASC
@@ -479,6 +751,16 @@ class OrderBy(Clause):
479
751
  name = cls.format(name, main)
480
752
  main.values.setdefault(ORDER_BY, []).append(name+cls.sort.value)
481
753
 
754
+ @classmethod
755
+ def cls_to_str(cls) -> str:
756
+ return ORDER_BY
757
+
758
+ PARTITION_BY = 'PARTITION BY'
759
+ class Partition:
760
+ @classmethod
761
+ def cls_to_str(cls) -> str:
762
+ return PARTITION_BY
763
+
482
764
 
483
765
  class GroupBy(Clause):
484
766
  @classmethod
@@ -494,7 +776,7 @@ class Having:
494
776
 
495
777
  def add(self, name: str, main:SQLObject):
496
778
  main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
497
- self.function.format(name, main), self.condition.expr
779
+ self.function.format(name, main), self.condition.content
498
780
  )
499
781
 
500
782
  @classmethod
@@ -524,7 +806,7 @@ class Rule:
524
806
  ...
525
807
 
526
808
  class QueryLanguage:
527
- pattern = '{select}{_from}{where}{group_by}{order_by}'
809
+ pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
528
810
  has_default = {key: bool(key == SELECT) for key in KEYWORD}
529
811
 
530
812
  @staticmethod
@@ -547,18 +829,21 @@ class QueryLanguage:
547
829
  return self.join_with_tabs(values, ' AND ')
548
830
 
549
831
  def sort_by(self, values: list) -> str:
550
- return self.join_with_tabs(values)
832
+ return self.join_with_tabs(values, ',')
551
833
 
552
834
  def set_group(self, values: list) -> str:
553
835
  return self.join_with_tabs(values, ',')
554
836
 
837
+ def set_limit(self, values: list) -> str:
838
+ return self.join_with_tabs(values, ' ')
839
+
555
840
  def __init__(self, target: 'Select'):
556
- self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
841
+ self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
557
842
  self.TABULATION = '\n\t' if target.break_lines else ' '
558
843
  self.LINE_BREAK = '\n' if target.break_lines else ' '
559
844
  self.TOKEN_METHODS = {
560
845
  SELECT: self.add_field, FROM: self.get_tables,
561
- WHERE: self.extract_conditions,
846
+ WHERE: self.extract_conditions, LIMIT: self.set_limit,
562
847
  ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
563
848
  }
564
849
  self.result = {}
@@ -862,10 +1147,13 @@ class SQLParser(Parser):
862
1147
  if not key in values:
863
1148
  continue
864
1149
  separator = self.class_type.get_separator(key)
1150
+ cls = {
1151
+ ORDER_BY: OrderBy, GROUP_BY: GroupBy
1152
+ }.get(key, Field)
865
1153
  obj.values[key] = [
866
- Field.format(fld, obj)
1154
+ cls.format(fld, obj)
867
1155
  for fld in re.split(separator, values[key])
868
- if (fld != '*' and len(tables) == 1) or obj.match(fld)
1156
+ if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
869
1157
  ]
870
1158
  result[obj.alias] = obj
871
1159
  self.queries = list( result.values() )
@@ -925,16 +1213,26 @@ class CypherParser(Parser):
925
1213
  if token in self.TOKEN_METHODS:
926
1214
  return
927
1215
  class_list = [Field]
928
- if '$' in token:
1216
+ if '*' in token:
1217
+ token = token.replace('*', '')
1218
+ self.queries[-1].key_field = token
1219
+ return
1220
+ elif '$' in token:
929
1221
  func_name, token = token.split('$')
930
1222
  if func_name == 'count':
931
1223
  if not token:
932
1224
  token = 'count_1'
933
- NamedField(token, Count).add('*', self.queries[-1])
934
- class_list = []
1225
+ pk_field = self.queries[-1].key_field or 'id'
1226
+ Count().As(token, extra_classes).add(pk_field, self.queries[-1])
1227
+ return
935
1228
  else:
936
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
937
- class_list = [ FUNCTION_CLASS[func_name] ]
1229
+ class_type = FUNCTION_CLASS.get(func_name)
1230
+ if not class_type:
1231
+ raise ValueError(f'Unknown function `{func_name}`.')
1232
+ if ':' in token:
1233
+ token, field_alias = token.split(':')
1234
+ class_type = class_type().As(field_alias)
1235
+ class_list = [class_type]
938
1236
  class_list += extra_classes
939
1237
  FieldList(token, class_list).add('', self.queries[-1])
940
1238
 
@@ -949,10 +1247,13 @@ class CypherParser(Parser):
949
1247
  def add_foreign_key(self, token: str, pk_field: str=''):
950
1248
  curr, last = [self.queries[i] for i in (-1, -2)]
951
1249
  if not pk_field:
952
- if not last.values.get(SELECT):
953
- raise IndexError(f'Primary Key not found for {last.table_name}.')
954
- pk_field = last.values[SELECT][-1].split('.')[-1]
955
- last.delete(pk_field, [SELECT])
1250
+ if last.key_field:
1251
+ pk_field = last.key_field
1252
+ else:
1253
+ if not last.values.get(SELECT):
1254
+ raise IndexError(f'Primary Key not found for {last.table_name}.')
1255
+ pk_field = last.values[SELECT][-1].split('.')[-1]
1256
+ last.delete(pk_field, [SELECT], exact=True)
956
1257
  if '{}' in token:
957
1258
  foreign_fld = token.format(
958
1259
  last.table_name.lower()
@@ -967,12 +1268,11 @@ class CypherParser(Parser):
967
1268
  if fld not in curr.values.get(GROUP_BY, [])
968
1269
  ]
969
1270
  foreign_fld = fields[0].split('.')[-1]
970
- curr.delete(foreign_fld, [SELECT])
1271
+ curr.delete(foreign_fld, [SELECT], exact=True)
971
1272
  if curr.join_type == JoinType.RIGHT:
972
1273
  pk_field, foreign_fld = foreign_fld, pk_field
973
1274
  if curr.join_type == JoinType.RIGHT:
974
1275
  curr, last = last, curr
975
- # pk_field, foreign_fld = foreign_fld, pk_field
976
1276
  k = ForeignKey.get_key(curr, last)
977
1277
  ForeignKey.references[k] = (foreign_fld, pk_field)
978
1278
 
@@ -1158,21 +1458,30 @@ class Select(SQLObject):
1158
1458
 
1159
1459
  def add(self, name: str, main: SQLObject):
1160
1460
  old_tables = main.values.get(FROM, [])
1161
- new_tables = set([
1162
- '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1461
+ if len(self.values[FROM]) > 1:
1462
+ old_tables += self.values[FROM][1:]
1463
+ new_tables = []
1464
+ row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1163
1465
  jt=self.join_type.value,
1164
1466
  tb=self.aka(),
1165
1467
  a1=main.alias, f1=name,
1166
1468
  a2=self.alias, f2=self.key_field
1167
1469
  )
1168
- ] + old_tables[1:])
1169
- main.values[FROM] = old_tables[:1] + list(new_tables)
1470
+ if row not in old_tables[1:]:
1471
+ new_tables.append(row)
1472
+ main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
1170
1473
  for key in USUAL_KEYS:
1171
1474
  main.update_values(key, self.values.get(key, []))
1172
1475
 
1173
- def __add__(self, other: SQLObject):
1476
+ def copy(self) -> SQLObject:
1174
1477
  from copy import deepcopy
1175
- query = deepcopy(self)
1478
+ return deepcopy(self)
1479
+
1480
+ def no_relation_error(self, other: SQLObject):
1481
+ raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
1482
+
1483
+ def __add__(self, other: SQLObject):
1484
+ query = self.copy()
1176
1485
  if query.table_name.lower() == other.table_name.lower():
1177
1486
  for key in USUAL_KEYS:
1178
1487
  query.update_values(key, other.values.get(key, []))
@@ -1185,7 +1494,7 @@ class Select(SQLObject):
1185
1494
  PrimaryKey.add(primary_key, query)
1186
1495
  query.add(foreign_field, other)
1187
1496
  return other
1188
- raise ValueError(f'No relationship found between {query.table_name} and {other.table_name}.')
1497
+ self.no_relation_error(other) # === raise ERROR ... ===
1189
1498
  elif primary_key:
1190
1499
  PrimaryKey.add(primary_key, other)
1191
1500
  other.add(foreign_field, query)
@@ -1205,16 +1514,48 @@ class Select(SQLObject):
1205
1514
  if self.diff(key, other.values.get(key, []), True):
1206
1515
  return False
1207
1516
  return True
1517
+
1518
+ def __sub__(self, other: SQLObject) -> SQLObject:
1519
+ fk_field, primary_k = ForeignKey.find(self, other)
1520
+ if fk_field:
1521
+ query = self.copy()
1522
+ other = other.copy()
1523
+ else:
1524
+ fk_field, primary_k = ForeignKey.find(other, self)
1525
+ if not fk_field:
1526
+ self.no_relation_error(other) # === raise ERROR ... ===
1527
+ query = other.copy()
1528
+ other = self.copy()
1529
+ query.__class__ = NotSelectIN
1530
+ Field.add(fk_field, query)
1531
+ query.add(primary_k, other)
1532
+ return other
1208
1533
 
1209
1534
  def limit(self, row_count: int=100, offset: int=0):
1210
- result = [str(row_count)]
1211
- if offset > 0:
1212
- result.append(f'OFFSET {offset}')
1213
- self.values.setdefault(LIMIT, result)
1535
+ if Function.dialect == Dialect.SQL_SERVER:
1536
+ fields = self.values.get(SELECT)
1537
+ if fields:
1538
+ fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
1539
+ else:
1540
+ self.values[SELECT] = [f'SELECT TOP({row_count}) *']
1541
+ return self
1542
+ if Function.dialect == Dialect.ORACLE:
1543
+ Where.gte(row_count).add(SQL_ROW_NUM, self)
1544
+ if offset > 0:
1545
+ Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
1546
+ return self
1547
+ self.values[LIMIT] = ['{}{}'.format(
1548
+ row_count, f' OFFSET {offset}' if offset > 0 else ''
1549
+ )]
1214
1550
  return self
1215
1551
 
1216
- def match(self, expr: str) -> bool:
1217
- return re.findall(f'\b*{self.alias}[.]', expr) != []
1552
+ def match(self, field: str, key: str) -> bool:
1553
+ '''
1554
+ Recognizes if the field is from the current table
1555
+ '''
1556
+ if key in (ORDER_BY, GROUP_BY) and '.' not in field:
1557
+ return self.has_named_field(field)
1558
+ return re.findall(f'\b*{self.alias}[.]', field) != []
1218
1559
 
1219
1560
  @classmethod
1220
1561
  def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
@@ -1226,12 +1567,10 @@ class Select(SQLObject):
1226
1567
  for rule in rules:
1227
1568
  rule.apply(self)
1228
1569
 
1229
- def add_fields(self, fields: list, order_by: bool=False, group_by:bool=False):
1230
- class_types = [Field]
1231
- if order_by:
1232
- class_types += [OrderBy]
1233
- if group_by:
1234
- class_types += [GroupBy]
1570
+ def add_fields(self, fields: list, class_types=None):
1571
+ if not class_types:
1572
+ class_types = []
1573
+ class_types += [Field]
1235
1574
  FieldList(fields, class_types).add('', self)
1236
1575
 
1237
1576
  def translate_to(self, language: QueryLanguage) -> str:
@@ -1251,6 +1590,95 @@ class NotSelectIN(SelectIN):
1251
1590
  condition_class = Not
1252
1591
 
1253
1592
 
1593
class CTE(Select):
    """
    A Select that is rendered after a `WITH <name> AS (...)` block,
    whose content is the list of queries joined by UNION ALL.
    """
    # Text written between WITH and the table name
    prefix = ''

    def __init__(self, table_name: str, query_list: list[Select]):
        """
        table_name: name given to the common table expression
        query_list: the queries that go inside the parentheses;
            their `break_lines` flag is turned off here.
        """
        super().__init__(table_name)
        for inner in query_list:
            inner.break_lines = False
        self.query_list = query_list
        self.break_lines = False

    def __str__(self) -> str:
        # Long outer queries are rendered with line breaks: only the
        # values without an embedded line break are measured
        total = sum(
            len(item)
            for key in USUAL_KEYS
            for item in self.values.get(key, [])
            if '\n' not in item
        )
        if total > 70:
            self.break_lines = True

        def justify(query: Select) -> str:
            # Wraps the text of an inner query, starting a new line at a
            # keyword/AND/OR/comma boundary once the current line has
            # reached 50 characters
            alternatives = '|'.join(KEYWORD)
            lines, current = [], ''
            for piece in re.split(fr'({alternatives}|AND|OR|,)', str(query)):
                if len(current) < 50:
                    current += piece
                    continue
                lines.append(current)
                current = piece
            if current:
                lines.append(current)
            return '\n '.join(lines)

        header = f'{self.prefix}{self.table_name}'
        inner_block = '\nUNION ALL\n '.join(
            justify(inner) for inner in self.query_list
        )
        outer_query = super().__str__()
        return 'WITH {} AS (\n {}\n){}'.format(header, inner_block, outer_query)

    def join(self, pattern: str, fields: list | str, format: str=''):
        """
        Repeats `pattern` once per field, lets `detect` build one query
        for each repetition and adds them to this CTE through FieldList.

        fields: a list of fields or a comma separated string
        format: forwarded to `detect`

        Returns the CTE itself, so that calls can be chained.
        """
        names = fields.split(',') if isinstance(fields, str) else fields
        queries = detect(
            pattern * len(names), join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
1640
+
1641
class Recursive(CTE):
    """
    A CTE rendered as `WITH RECURSIVE ...`, where the last inner query
    also reads from the CTE itself.
    """
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        if len(self.query_list) > 1:
            # The last query refers back to the CTE. Rendering must not
            # change the result, so the reference is added only once,
            # no matter how many times str() is called.
            self_reference = f', {self.table_name} {self.alias}'
            from_items = self.query_list[-1].values[FROM]
            if self_reference not in from_items:
                from_items.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        """
        Builds a recursive CTE made of two queries detected from
        `pattern` (the pattern is repeated twice).

        name: name of the CTE
        formula: expression with `[n]` placeholders, where n is the
            1-based position of a SELECT field of the second query.
            The first placeholder found selects the field that gets the
            condition and is replaced by '%'; the other ones are
            replaced by the name of the field.
        init_value: value compared (equality) with the first field of
            the first query
        format: forwarded to `detect`

        NOTE: resets SQLObject.ALIAS_FUNC to None (global setting).
        """
        SQLObject.ALIAS_FUNC = None

        def get_field(obj: SQLObject, pos: int) -> str:
            # Field name without the alias prefix
            return obj.values[SELECT][pos].split('.')[-1]

        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for digits in re.findall(r'\[(\d+)\]', formula):
            # Keep the text exactly as written (e.g. "[01]"), otherwise a
            # zero padded placeholder would never be found by replace()
            placeholder = f'[{digits}]'
            field_name = get_field(t2, int(digits) - 1)
            if not foreign_key:
                foreign_key = field_name
                formula = formula.replace(placeholder, '%')
            else:
                formula = formula.replace(placeholder, field_name)
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        """
        Adds a counter column called `name` to every inner query: the
        first query gets `start` and the following ones get the
        previous value combined with `increment`.

        Returns the CTE itself, so that calls can be chained.
        """
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
1678
+
1679
+
1680
+ # ----- Rules -----
1681
+
1254
1682
  class RulePutLimit(Rule):
1255
1683
  @classmethod
1256
1684
  def apply(cls, target: Select):
@@ -1314,6 +1742,8 @@ class RuleDateFuncReplace(Rule):
1314
1742
  @classmethod
1315
1743
  def apply(cls, target: Select):
1316
1744
  for i, condition in enumerate(target.values.get(WHERE, [])):
1745
+ if not '(' in condition:
1746
+ continue
1317
1747
  tokens = [
1318
1748
  t.strip() for t in cls.REGEX.split(condition) if t.strip()
1319
1749
  ]
@@ -1325,6 +1755,32 @@ class RuleDateFuncReplace(Rule):
1325
1755
  target.values[WHERE][i] = ' AND '.join(temp.values[WHERE])
1326
1756
 
1327
1757
 
1758
class RuleReplaceJoinBySubselect(Rule):
    """
    Rule that re-parses the target query and turns a joined table into
    a SubSelect when that table is only used for filtering.
    """
    @classmethod
    def apply(cls, target: Select):
        """
        target: the query to be changed in place; its values are
            replaced only when at least one join was converted.
        """
        main, *others = Select.parse( str(target) )
        converted = False
        for query in others:
            fk_field, primary_k = ForeignKey.find(main, query)
            more_relations = any([
                ref[0] == query.table_name for ref in ForeignKey.references
            ])
            has_fields = len( query.values.get(SELECT, []) ) > 0
            no_filter = len( query.values.get(WHERE, []) ) == 0
            # The join stays when the table contributes fields, has no
            # condition, has no foreign key to the main table or takes
            # part in other relations
            keep_join = has_fields or no_filter or not fk_field or more_relations
            if not keep_join:
                query.__class__ = SubSelect
                Field.add(primary_k, query)
                converted = True
            # In both cases the query is attached back to the main one
            query.add(fk_field, main)
        if converted:
            target.values = main.values.copy()
1782
+
1783
+
1328
1784
  def parser_class(text: str) -> Parser:
1329
1785
  PARSER_REGEX = [
1330
1786
  (r'select.*from', SQLParser),
@@ -1339,7 +1795,7 @@ def parser_class(text: str) -> Parser:
1339
1795
  return None
1340
1796
 
1341
1797
 
1342
- def detect(text: str) -> Select:
1798
+ def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
1343
1799
  from collections import Counter
1344
1800
  parser = parser_class(text)
1345
1801
  if not parser:
@@ -1350,21 +1806,19 @@ def detect(text: str) -> Select:
1350
1806
  continue
1351
1807
  pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
1352
1808
  for begin, end in pos[::-1]:
1353
- new_name = f'{table}_{count}' # See set_table (line 45)
1809
+ new_name = f'{table}_{count}' # See set_table (line 55)
1354
1810
  Select.EQUIVALENT_NAMES[new_name] = table
1355
1811
  text = text[:begin] + new_name + '(' + text[end:]
1356
1812
  count -= 1
1357
1813
  query_list = Select.parse(text, parser)
1814
+ if format:
1815
+ for query in query_list:
1816
+ query.set_file_format(format)
1817
+ if not join_queries:
1818
+ return query_list
1358
1819
  result = query_list[0]
1359
1820
  for query in query_list[1:]:
1360
1821
  result += query
1361
1822
  return result
1823
+ # ===========================================================================================//
1362
1824
 
1363
- if __name__ == "__main__":
1364
- OrderBy.sort = SortType.DESC
1365
- query = Select(
1366
- 'order_Detail d',
1367
- customer_id=GroupBy,
1368
- _=Sum('d.unitPrice * d.quantity').As('total', OrderBy)
1369
- )
1370
- print(query)