sql-blocks 1.25.13__py3-none-any.whl → 1.25.51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_blocks/sql_blocks.py CHANGED
@@ -26,7 +26,7 @@ TO_LIST = lambda x: x if isinstance(x, list) else [x]
26
26
 
27
27
 
28
28
  class SQLObject:
29
- ALIAS_FUNC = lambda t: t.lower()[:3]
29
+ ALIAS_FUNC = None
30
30
  """ ^^^^^^^^^^^^^^^^^^^^^^^^
31
31
  You can change the behavior by assigning
32
32
  a user function to SQLObject.ALIAS_FUNC
@@ -41,21 +41,35 @@ class SQLObject:
41
41
  def set_table(self, table_name: str):
42
42
  if not table_name:
43
43
  return
44
- if ' ' in table_name.strip():
44
+ cls = SQLObject
45
+ is_file_name = any([
46
+ '/' in table_name, '.' in table_name
47
+ ])
48
+ ref = table_name
49
+ if is_file_name:
50
+ ref = table_name.split('/')[-1].split('.')[0]
51
+ if cls.ALIAS_FUNC:
52
+ self.__alias = cls.ALIAS_FUNC(ref)
53
+ elif ' ' in table_name.strip():
45
54
  table_name, self.__alias = table_name.split()
46
- elif '_' in table_name:
55
+ elif '_' in ref:
47
56
  self.__alias = ''.join(
48
57
  word[0].lower()
49
- for word in table_name.split('_')
58
+ for word in ref.split('_')
50
59
  )
51
60
  else:
52
- self.__alias = SQLObject.ALIAS_FUNC(table_name)
61
+ self.__alias = ref.lower()[:3]
53
62
  self.values.setdefault(FROM, []).append(f'{table_name} {self.alias}')
54
63
 
55
64
  @property
56
65
  def table_name(self) -> str:
57
66
  return self.values[FROM][0].split()[0]
58
67
 
68
+ def set_file_format(self, pattern: str):
69
+ if '{' not in pattern:
70
+ pattern = '{}' + pattern
71
+ self.values[FROM][0] = pattern.format(self.aka())
72
+
59
73
  @property
60
74
  def alias(self) -> str:
61
75
  if self.__alias:
@@ -67,6 +81,16 @@ class SQLObject:
67
81
  appendix = {WHERE: r'\s+and\s+|', FROM: r'\s+join\s+|\s+JOIN\s+'}
68
82
  return KEYWORD[key][0].format(appendix.get(key, ''))
69
83
 
84
+ @staticmethod
85
+ def is_named_field(fld: str, name: str='') -> bool:
86
+ return re.search(fr'(\s+as\s+|\s+AS\s+){name}', fld)
87
+
88
+ def has_named_field(self, name: str) -> bool:
89
+ return any(
90
+ self.is_named_field(fld, name)
91
+ for fld in self.values.get(SELECT, [])
92
+ )
93
+
70
94
  def diff(self, key: str, search_list: list, exact: bool=False) -> set:
71
95
  def disassemble(source: list) -> list:
72
96
  if not exact:
@@ -79,12 +103,12 @@ class SQLObject:
79
103
  if exact:
80
104
  fld = fld.lower()
81
105
  return fld.strip()
82
- def is_named_field(fld: str) -> bool:
83
- return key == SELECT and re.search(r'\s+as\s+|\s+AS\s+', fld)
84
106
  def field_set(source: list) -> set:
85
107
  return set(
86
108
  (
87
- fld if is_named_field(fld) else
109
+ fld
110
+ if key == SELECT and self.is_named_field(fld, key)
111
+ else
88
112
  re.sub(pattern, '', cleanup(fld))
89
113
  )
90
114
  for string in disassemble(source)
@@ -102,13 +126,22 @@ class SQLObject:
102
126
  return s1.symmetric_difference(s2)
103
127
  return s1 - s2
104
128
 
105
- def delete(self, search: str, keys: list=USUAL_KEYS):
129
+ def delete(self, search: str, keys: list=USUAL_KEYS, exact: bool=False):
130
+ if exact:
131
+ not_match = lambda item: not re.search(fr'\w*[.]*{search}$', item)
132
+ else:
133
+ not_match = lambda item: search not in item
106
134
  for key in keys:
107
- result = []
108
- for item in self.values.get(key, []):
109
- if search not in item:
110
- result.append(item)
111
- self.values[key] = result
135
+ self.values[key] = [
136
+ item for item in self.values.get(key, [])
137
+ if not_match(item)
138
+ ]
139
+
140
+
141
+ SQL_CONST_SYSDATE = 'SYSDATE'
142
+ SQL_CONST_CURR_DATE = 'Current_date'
143
+ SQL_ROW_NUM = 'ROWNUM'
144
+ SQL_CONSTS = [SQL_CONST_SYSDATE, SQL_CONST_CURR_DATE, SQL_ROW_NUM]
112
145
 
113
146
 
114
147
  class Field:
@@ -116,10 +149,16 @@ class Field:
116
149
 
117
150
  @classmethod
118
151
  def format(cls, name: str, main: SQLObject) -> str:
152
+ def is_const() -> bool:
153
+ return any([
154
+ re.findall('[.()0-9]', name),
155
+ name in SQL_CONSTS,
156
+ re.findall(r'\w+\s*[+-]\s*\w+', name)
157
+ ])
119
158
  name = name.strip()
120
159
  if name in ('_', '*'):
121
160
  name = '*'
122
- elif not re.findall('[.()0-9]', name):
161
+ elif not is_const() and not main.has_named_field(name):
123
162
  name = f'{main.alias}.{name}'
124
163
  if Function in cls.__bases__:
125
164
  name = f'{cls.__name__}({name})'
@@ -150,35 +189,89 @@ class NamedField:
150
189
  )
151
190
 
152
191
 
192
+ class Dialect(Enum):
193
+ ANSI = 0
194
+ SQL_SERVER = 1
195
+ ORACLE = 2
196
+ POSTGRESQL = 3
197
+ MYSQL = 4
198
+
199
+ SQL_TYPES = 'CHAR INT DATE FLOAT ANY'.split()
200
+ CHAR, INT, DATE, FLOAT, ANY = SQL_TYPES
201
+
153
202
  class Function:
203
+ dialect = Dialect.ANSI
204
+ inputs = None
205
+ output = None
206
+ separator = ', '
207
+ auto_convert = True
208
+ append_param = False
209
+
154
210
  def __init__(self, *params: list):
211
+ def set_func_types(param):
212
+ if self.auto_convert and isinstance(param, Function):
213
+ func = param
214
+ main_param = self.inputs[0]
215
+ unfriendly = all([
216
+ func.output != main_param,
217
+ func.output != ANY,
218
+ main_param != ANY
219
+ ])
220
+ if unfriendly:
221
+ return Cast(func, main_param)
222
+ return param
155
223
  # --- Replace class methods by instance methods: ------
156
224
  self.add = self.__add
157
225
  self.format = self.__format
158
226
  # -----------------------------------------------------
159
- self.params = [str(p) for p in params]
227
+ self.params = [set_func_types(p) for p in params]
160
228
  self.field_class = Field
161
- self.pattern = '{}({})'
229
+ self.pattern = self.get_pattern()
162
230
  self.extra = {}
163
231
 
232
+ def get_pattern(self) -> str:
233
+ return '{func_name}({params})'
234
+
164
235
  def As(self, field_alias: str, modifiers=None):
165
236
  if modifiers:
166
237
  self.extra[field_alias] = TO_LIST(modifiers)
167
238
  self.field_class = NamedField(field_alias)
168
239
  return self
169
240
 
170
- def __format(self, name: str, main: SQLObject) -> str:
171
- if name in '*_' and self.params:
172
- params = self.params
173
- else:
174
- params = [
175
- Field.format(name, main)
176
- ] + self.params
241
+ def __str__(self) -> str:
177
242
  return self.pattern.format(
178
- self.__class__.__name__,
179
- ', '.join(params)
243
+ func_name=self.__class__.__name__,
244
+ params=self.separator.join(str(p) for p in self.params)
180
245
  )
181
246
 
247
+ @classmethod
248
+ def help(cls) -> str:
249
+ descr = ' '.join(B.__name__ for B in cls.__bases__)
250
+ params = cls.inputs or ''
251
+ return cls().get_pattern().format(
252
+ func_name=f'{descr} {cls.__name__}',
253
+ params=cls.separator.join(str(p) for p in params)
254
+ ) + f' Return {cls.output}'
255
+
256
+ def set_main_param(self, name: str, main: SQLObject) -> bool:
257
+ nested_functions = [
258
+ param for param in self.params if isinstance(param, Function)
259
+ ]
260
+ for func in nested_functions:
261
+ if func.inputs:
262
+ func.set_main_param(name, main)
263
+ return
264
+ new_params = [Field.format(name, main)]
265
+ if self.append_param:
266
+ self.params += new_params
267
+ else:
268
+ self.params = new_params + self.params
269
+
270
+ def __format(self, name: str, main: SQLObject) -> str:
271
+ if name not in '*_':
272
+ self.set_main_param(name, main)
273
+ return str(self)
274
+
182
275
  @classmethod
183
276
  def format(cls, name: str, main: SQLObject):
184
277
  return cls().__format(name, main)
@@ -196,39 +289,110 @@ class Function:
196
289
 
197
290
  # ---- String Functions: ---------------------------------
198
291
  class SubString(Function):
199
- ...
292
+ inputs = [CHAR, INT, INT]
293
+ output = CHAR
294
+
295
+ def get_pattern(self) -> str:
296
+ if self.dialect in (Dialect.ORACLE, Dialect.MYSQL):
297
+ return 'Substr({params})'
298
+ return super().get_pattern()
200
299
 
201
300
  # ---- Numeric Functions: --------------------------------
202
301
  class Round(Function):
203
- ...
302
+ inputs = [FLOAT]
303
+ output = FLOAT
204
304
 
205
305
  # --- Date Functions: ------------------------------------
206
306
  class DateDiff(Function):
307
+ inputs = [DATE]
308
+ output = DATE
309
+ append_param = True
310
+
311
+ def __str__(self) -> str:
312
+ def is_field_or_func(name: str) -> bool:
313
+ candidate = re.sub(
314
+ '[()]', '', name.split('.')[-1]
315
+ )
316
+ return candidate.isidentifier()
317
+ if self.dialect != Dialect.SQL_SERVER:
318
+ params = [str(p) for p in self.params]
319
+ return ' - '.join(
320
+ p if is_field_or_func(p) else f"'{p}'"
321
+ for p in params
322
+ ) # <==== Date subtract
323
+ return super().__str__()
324
+
325
+
326
+ class DatePart(Function):
327
+ inputs = [DATE]
328
+ output = INT
329
+
330
+ def get_pattern(self) -> str:
331
+ interval = self.__class__.__name__
332
+ database_type = {
333
+ Dialect.ORACLE: 'Extract('+interval+' FROM {params})',
334
+ Dialect.POSTGRESQL: "Date_Part('"+interval+"', {params})",
335
+ }
336
+ if self.dialect in database_type:
337
+ return database_type[self.dialect]
338
+ return super().get_pattern()
339
+
340
+ class Year(DatePart):
207
341
  ...
208
- class Extract(Function):
342
+ class Month(DatePart):
209
343
  ...
210
- class DatePart(Function):
344
+ class Day(DatePart):
211
345
  ...
346
+
347
+
212
348
  class Current_Date(Function):
213
- ...
349
+ output = DATE
350
+
351
+ def get_pattern(self) -> str:
352
+ database_type = {
353
+ Dialect.ORACLE: SQL_CONST_SYSDATE,
354
+ Dialect.POSTGRESQL: SQL_CONST_CURR_DATE,
355
+ Dialect.SQL_SERVER: 'getDate()'
356
+ }
357
+ if self.dialect in database_type:
358
+ return database_type[self.dialect]
359
+ return super().get_pattern()
360
+ # --------------------------------------------------------
214
361
 
215
- class Aggregate:
362
+ class Frame:
216
363
  break_lines: bool = True
217
364
 
218
365
  def over(self, **args):
219
- keywords = ' '.join(
220
- '{}{} BY {}'.format(
221
- '\n\t\t' if self.break_lines else '',
222
- key.upper(), args[key]
223
- ) for key in ('partition', 'order')
224
- if key in args
225
- )
366
+ """
367
+ How to use:
368
+ over(field1=OrderBy, field2=Partition)
369
+ """
370
+ keywords = ''
371
+ for field, obj in args.items():
372
+ is_valid = any([
373
+ obj is OrderBy,
374
+ obj is Partition,
375
+ isinstance(obj, Rows),
376
+ ])
377
+ if not is_valid:
378
+ continue
379
+ keywords += '{}{} {}'.format(
380
+ '\n\t\t' if self.break_lines else ' ',
381
+ obj.cls_to_str(), field if field != '_' else ''
382
+ )
226
383
  if keywords and self.break_lines:
227
384
  keywords += '\n\t'
228
- self.pattern = '{}({})' + f' OVER({keywords})'
385
+ self.pattern = self.get_pattern() + f' OVER({keywords})'
229
386
  return self
230
387
 
231
388
 
389
+ class Aggregate(Frame):
390
+ inputs = [FLOAT]
391
+ output = FLOAT
392
+
393
+ class Window(Frame):
394
+ ...
395
+
232
396
  # ---- Aggregate Functions: -------------------------------
233
397
  class Avg(Aggregate, Function):
234
398
  ...
@@ -241,11 +405,32 @@ class Sum(Aggregate, Function):
241
405
  class Count(Aggregate, Function):
242
406
  ...
243
407
 
408
+ # ---- Window Functions: -----------------------------------
409
+ class Row_Number(Window, Function):
410
+ output = INT
411
+
412
+ class Rank(Window, Function):
413
+ output = INT
414
+
415
+ class Lag(Window, Function):
416
+ output = ANY
417
+
418
+ class Lead(Window, Function):
419
+ output = ANY
420
+
421
+
244
422
  # ---- Conversions and other Functions: ---------------------
245
423
  class Coalesce(Function):
246
- ...
424
+ inputs = [ANY]
425
+ output = ANY
426
+
247
427
  class Cast(Function):
248
- ...
428
+ inputs = [ANY]
429
+ output = ANY
430
+ separator = ' As '
431
+
432
+
433
+ FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
249
434
 
250
435
 
251
436
  class ExpressionField:
@@ -270,15 +455,20 @@ class ExpressionField:
270
455
  class FieldList:
271
456
  separator = ','
272
457
 
273
- def __init__(self, fields: list=[], class_types = [Field]):
458
+ def __init__(self, fields: list=[], class_types = [Field], ziped: bool=False):
274
459
  if isinstance(fields, str):
275
460
  fields = [
276
461
  f.strip() for f in fields.split(self.separator)
277
462
  ]
278
463
  self.fields = fields
279
464
  self.class_types = class_types
465
+ self.ziped = ziped
280
466
 
281
467
  def add(self, name: str, main: SQLObject):
468
+ if self.ziped: # --- One class per field...
469
+ for field, class_type in zip(self.fields, self.class_types):
470
+ class_type.add(field, main)
471
+ return
282
472
  for field in self.fields:
283
473
  for class_type in self.class_types:
284
474
  class_type.add(field, main)
@@ -320,27 +510,43 @@ class ForeignKey:
320
510
 
321
511
  def quoted(value) -> str:
322
512
  if isinstance(value, str):
513
+ if re.search(r'\bor\b', value, re.IGNORECASE):
514
+ raise PermissionError('Possible SQL injection attempt')
323
515
  value = f"'{value}'"
324
516
  return str(value)
325
517
 
326
518
 
519
+ class Position(Enum):
520
+ StartsWith = -1
521
+ Middle = 0
522
+ EndsWith = 1
523
+
524
+
327
525
  class Where:
328
526
  prefix = ''
329
527
 
330
- def __init__(self, expr: str):
331
- self.expr = expr
528
+ def __init__(self, content: str):
529
+ self.content = content
332
530
 
333
531
  @classmethod
334
532
  def __constructor(cls, operator: str, value):
335
- return cls(expr=f'{operator} {quoted(value)}')
533
+ return cls(f'{operator} {quoted(value)}')
336
534
 
337
535
  @classmethod
338
536
  def eq(cls, value):
339
537
  return cls.__constructor('=', value)
340
538
 
341
539
  @classmethod
342
- def contains(cls, value: str):
343
- return cls(f"LIKE '%{value}%'")
540
+ def contains(cls, text: str, pos: int | Position = Position.Middle):
541
+ if isinstance(pos, int):
542
+ pos = Position(pos)
543
+ return cls(
544
+ "LIKE '{}{}{}'".format(
545
+ '%' if pos != Position.StartsWith else '',
546
+ text,
547
+ '%' if pos != Position.EndsWith else ''
548
+ )
549
+ )
344
550
 
345
551
  @classmethod
346
552
  def gt(cls, value):
@@ -368,9 +574,42 @@ class Where:
368
574
  values = ','.join(quoted(v) for v in values)
369
575
  return cls(f'IN ({values})')
370
576
 
577
+ @classmethod
578
+ def formula(cls, formula: str):
579
+ where = cls( ExpressionField(formula) )
580
+ where.add = where.add_expression
581
+ return where
582
+
583
+ def add_expression(self, name: str, main: SQLObject):
584
+ self.content = self.content.format(name, main)
585
+ main.values.setdefault(WHERE, []).append('{} {}'.format(
586
+ self.prefix, self.content
587
+ ))
588
+
589
+ @classmethod
590
+ def join(cls, query: SQLObject):
591
+ where = cls(query)
592
+ where.add = where.add_join
593
+ return where
594
+
595
+ def add_join(self, name: str, main: SQLObject):
596
+ query = self.content
597
+ main.values[FROM].append(f',{query.table_name} {query.alias}')
598
+ for key in USUAL_KEYS:
599
+ main.update_values(key, query.values.get(key, []))
600
+ main.values.setdefault(WHERE, []).append('({a1}.{f1} = {a2}.{f2})'.format(
601
+ a1=main.alias, f1=name,
602
+ a2=query.alias, f2=query.key_field
603
+ ))
604
+
371
605
  def add(self, name: str, main: SQLObject):
606
+ func_type = FUNCTION_CLASS.get(name.lower())
607
+ if func_type:
608
+ name = func_type.format('*', main)
609
+ elif not main.has_named_field(name):
610
+ name = Field.format(name, main)
372
611
  main.values.setdefault(WHERE, []).append('{}{} {}'.format(
373
- self.prefix, Field.format(name, main), self.expr
612
+ self.prefix, name, self.content
374
613
  ))
375
614
 
376
615
 
@@ -378,6 +617,10 @@ eq, contains, gt, gte, lt, lte, is_null, inside = (
378
617
  getattr(Where, method) for method in
379
618
  ('eq', 'contains', 'gt', 'gte', 'lt', 'lte', 'is_null', 'inside')
380
619
  )
620
+ startswith, endswith = [
621
+ lambda x: contains(x, Position.StartsWith),
622
+ lambda x: contains(x, Position.EndsWith)
623
+ ]
381
624
 
382
625
 
383
626
  class Not(Where):
@@ -385,7 +628,7 @@ class Not(Where):
385
628
 
386
629
  @classmethod
387
630
  def eq(cls, value):
388
- return Where(expr=f'<> {quoted(value)}')
631
+ return Where(f'<> {quoted(value)}')
389
632
 
390
633
 
391
634
  class Case:
@@ -394,22 +637,26 @@ class Case:
394
637
  self.default = None
395
638
  self.field = field
396
639
 
397
- def when(self, condition: Where, result: str):
640
+ def when(self, condition: Where, result):
641
+ if isinstance(result, str):
642
+ result = quoted(result)
398
643
  self.__conditions[result] = condition
399
644
  return self
400
645
 
401
- def else_value(self, default: str):
646
+ def else_value(self, default):
647
+ if isinstance(default, str):
648
+ default = quoted(default)
402
649
  self.default = default
403
650
  return self
404
651
 
405
652
  def add(self, name: str, main: SQLObject):
406
653
  field = Field.format(self.field, main)
407
- default = quoted(self.default)
654
+ default = self.default
408
655
  name = 'CASE \n{}\n\tEND AS {}'.format(
409
656
  '\n'.join(
410
- f'\t\tWHEN {field} {cond.expr} THEN {quoted(res)}'
657
+ f'\t\tWHEN {field} {cond.content} THEN {res}'
411
658
  for res, cond in self.__conditions.items()
412
- ) + f'\n\t\tELSE {default}' if default else '',
659
+ ) + (f'\n\t\tELSE {default}' if default else ''),
413
660
  name
414
661
  )
415
662
  main.values.setdefault(SELECT, []).append(name)
@@ -420,14 +667,13 @@ class Options:
420
667
  self.__children: dict = values
421
668
 
422
669
  def add(self, logical_separator: str, main: SQLObject):
423
- """
424
- `logical_separator` must be AND or OR
425
- """
670
+ if logical_separator not in ('AND', 'OR'):
671
+ raise ValueError('`logical_separator` must be AND or OR')
426
672
  conditions: list[str] = []
427
673
  child: Where
428
674
  for field, child in self.__children.items():
429
675
  conditions.append(' {} {} '.format(
430
- Field.format(field, main), child.expr
676
+ Field.format(field, main), child.content
431
677
  ))
432
678
  main.values.setdefault(WHERE, []).append(
433
679
  '(' + logical_separator.join(conditions) + ')'
@@ -435,28 +681,57 @@ class Options:
435
681
 
436
682
 
437
683
  class Between:
684
+ is_literal: bool = False
685
+
438
686
  def __init__(self, start, end):
439
687
  if start > end:
440
688
  start, end = end, start
441
689
  self.start = start
442
690
  self.end = end
443
691
 
692
+ def literal(self) -> Where:
693
+ return Where('BETWEEN {} AND {}'.format(
694
+ self.start, self.end
695
+ ))
696
+
444
697
  def add(self, name: str, main:SQLObject):
445
- Where.gte(self.start).add(name, main),
698
+ if self.is_literal:
699
+ return self.literal().add(name, main)
700
+ Where.gte(self.start).add(name, main)
446
701
  Where.lte(self.end).add(name, main)
447
702
 
703
+ class SameDay(Between):
704
+ def __init__(self, date: str):
705
+ super().__init__(
706
+ f'{date} 00:00:00',
707
+ f'{date} 23:59:59',
708
+ )
709
+
710
+
711
+ class Range(Case):
712
+ INC_FUNCTION = lambda x: x + 1
713
+
714
+ def __init__(self, field: str, values: dict):
715
+ super().__init__(field)
716
+ start = 0
717
+ cls = self.__class__
718
+ for label, value in sorted(values.items(), key=lambda item: item[1]):
719
+ self.when(
720
+ Between(start, value).literal(), label
721
+ )
722
+ start = cls.INC_FUNCTION(value)
723
+
448
724
 
449
725
  class Clause:
450
726
  @classmethod
451
727
  def format(cls, name: str, main: SQLObject) -> str:
452
728
  def is_function() -> bool:
453
729
  diff = main.diff(SELECT, [name.lower()], True)
454
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
455
730
  return diff.intersection(FUNCTION_CLASS)
456
731
  found = re.findall(r'^_\d', name)
457
732
  if found:
458
733
  name = found[0].replace('_', '')
459
- elif main.alias and not is_function():
734
+ elif '.' not in name and main.alias and not is_function():
460
735
  name = f'{main.alias}.{name}'
461
736
  return name
462
737
 
@@ -465,6 +740,34 @@ class SortType(Enum):
465
740
  ASC = ''
466
741
  DESC = ' DESC'
467
742
 
743
+ class Row:
744
+ def __init__(self, value: int=0):
745
+ self.value = value
746
+
747
+ def __str__(self) -> str:
748
+ return '{} {}'.format(
749
+ 'UNBOUNDED' if self.value == 0 else self.value,
750
+ self.__class__.__name__.upper()
751
+ )
752
+
753
+ class Preceding(Row):
754
+ ...
755
+ class Following(Row):
756
+ ...
757
+ class Current(Row):
758
+ def __str__(self) -> str:
759
+ return 'CURRENT ROW'
760
+
761
+ class Rows:
762
+ def __init__(self, *rows: list[Row]):
763
+ self.rows = rows
764
+
765
+ def cls_to_str(self) -> str:
766
+ return 'ROWS {}{}'.format(
767
+ 'BETWEEN ' if len(self.rows) > 1 else '',
768
+ ' AND '.join(str(row) for row in self.rows)
769
+ )
770
+
468
771
 
469
772
  class OrderBy(Clause):
470
773
  sort: SortType = SortType.ASC
@@ -474,6 +777,16 @@ class OrderBy(Clause):
474
777
  name = cls.format(name, main)
475
778
  main.values.setdefault(ORDER_BY, []).append(name+cls.sort.value)
476
779
 
780
+ @classmethod
781
+ def cls_to_str(cls) -> str:
782
+ return ORDER_BY
783
+
784
+ PARTITION_BY = 'PARTITION BY'
785
+ class Partition:
786
+ @classmethod
787
+ def cls_to_str(cls) -> str:
788
+ return PARTITION_BY
789
+
477
790
 
478
791
  class GroupBy(Clause):
479
792
  @classmethod
@@ -489,7 +802,7 @@ class Having:
489
802
 
490
803
  def add(self, name: str, main:SQLObject):
491
804
  main.values[GROUP_BY][-1] += ' HAVING {} {}'.format(
492
- self.function.format(name, main), self.condition.expr
805
+ self.function.format(name, main), self.condition.content
493
806
  )
494
807
 
495
808
  @classmethod
@@ -519,7 +832,7 @@ class Rule:
519
832
  ...
520
833
 
521
834
  class QueryLanguage:
522
- pattern = '{select}{_from}{where}{group_by}{order_by}'
835
+ pattern = '{select}{_from}{where}{group_by}{order_by}{limit}'
523
836
  has_default = {key: bool(key == SELECT) for key in KEYWORD}
524
837
 
525
838
  @staticmethod
@@ -542,18 +855,21 @@ class QueryLanguage:
542
855
  return self.join_with_tabs(values, ' AND ')
543
856
 
544
857
  def sort_by(self, values: list) -> str:
545
- return self.join_with_tabs(values)
858
+ return self.join_with_tabs(values, ',')
546
859
 
547
860
  def set_group(self, values: list) -> str:
548
861
  return self.join_with_tabs(values, ',')
549
862
 
863
+ def set_limit(self, values: list) -> str:
864
+ return self.join_with_tabs(values, ' ')
865
+
550
866
  def __init__(self, target: 'Select'):
551
- self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY]
867
+ self.KEYWORDS = [SELECT, FROM, WHERE, GROUP_BY, ORDER_BY, LIMIT]
552
868
  self.TABULATION = '\n\t' if target.break_lines else ' '
553
869
  self.LINE_BREAK = '\n' if target.break_lines else ' '
554
870
  self.TOKEN_METHODS = {
555
871
  SELECT: self.add_field, FROM: self.get_tables,
556
- WHERE: self.extract_conditions,
872
+ WHERE: self.extract_conditions, LIMIT: self.set_limit,
557
873
  ORDER_BY: self.sort_by, GROUP_BY: self.set_group,
558
874
  }
559
875
  self.result = {}
@@ -857,10 +1173,13 @@ class SQLParser(Parser):
857
1173
  if not key in values:
858
1174
  continue
859
1175
  separator = self.class_type.get_separator(key)
1176
+ cls = {
1177
+ ORDER_BY: OrderBy, GROUP_BY: GroupBy
1178
+ }.get(key, Field)
860
1179
  obj.values[key] = [
861
- Field.format(fld, obj)
1180
+ cls.format(fld, obj)
862
1181
  for fld in re.split(separator, values[key])
863
- if (fld != '*' and len(tables) == 1) or obj.match(fld)
1182
+ if (fld != '*' and len(tables) == 1) or obj.match(fld, key)
864
1183
  ]
865
1184
  result[obj.alias] = obj
866
1185
  self.queries = list( result.values() )
@@ -920,16 +1239,26 @@ class CypherParser(Parser):
920
1239
  if token in self.TOKEN_METHODS:
921
1240
  return
922
1241
  class_list = [Field]
923
- if '$' in token:
1242
+ if '*' in token:
1243
+ token = token.replace('*', '')
1244
+ self.queries[-1].key_field = token
1245
+ return
1246
+ elif '$' in token:
924
1247
  func_name, token = token.split('$')
925
1248
  if func_name == 'count':
926
1249
  if not token:
927
1250
  token = 'count_1'
928
- NamedField(token, Count).add('*', self.queries[-1])
929
- class_list = []
1251
+ pk_field = self.queries[-1].key_field or 'id'
1252
+ Count().As(token, extra_classes).add(pk_field, self.queries[-1])
1253
+ return
930
1254
  else:
931
- FUNCTION_CLASS = {f.__name__.lower(): f for f in Function.__subclasses__()}
932
- class_list = [ FUNCTION_CLASS[func_name] ]
1255
+ class_type = FUNCTION_CLASS.get(func_name)
1256
+ if not class_type:
1257
+ raise ValueError(f'Unknown function `{func_name}`.')
1258
+ if ':' in token:
1259
+ token, field_alias = token.split(':')
1260
+ class_type = class_type().As(field_alias)
1261
+ class_list = [class_type]
933
1262
  class_list += extra_classes
934
1263
  FieldList(token, class_list).add('', self.queries[-1])
935
1264
 
@@ -944,10 +1273,13 @@ class CypherParser(Parser):
944
1273
  def add_foreign_key(self, token: str, pk_field: str=''):
945
1274
  curr, last = [self.queries[i] for i in (-1, -2)]
946
1275
  if not pk_field:
947
- if not last.values.get(SELECT):
948
- raise IndexError(f'Primary Key not found for {last.table_name}.')
949
- pk_field = last.values[SELECT][-1].split('.')[-1]
950
- last.delete(pk_field, [SELECT])
1276
+ if last.key_field:
1277
+ pk_field = last.key_field
1278
+ else:
1279
+ if not last.values.get(SELECT):
1280
+ raise IndexError(f'Primary Key not found for {last.table_name}.')
1281
+ pk_field = last.values[SELECT][-1].split('.')[-1]
1282
+ last.delete(pk_field, [SELECT], exact=True)
951
1283
  if '{}' in token:
952
1284
  foreign_fld = token.format(
953
1285
  last.table_name.lower()
@@ -962,12 +1294,11 @@ class CypherParser(Parser):
962
1294
  if fld not in curr.values.get(GROUP_BY, [])
963
1295
  ]
964
1296
  foreign_fld = fields[0].split('.')[-1]
965
- curr.delete(foreign_fld, [SELECT])
1297
+ curr.delete(foreign_fld, [SELECT], exact=True)
966
1298
  if curr.join_type == JoinType.RIGHT:
967
1299
  pk_field, foreign_fld = foreign_fld, pk_field
968
1300
  if curr.join_type == JoinType.RIGHT:
969
1301
  curr, last = last, curr
970
- # pk_field, foreign_fld = foreign_fld, pk_field
971
1302
  k = ForeignKey.get_key(curr, last)
972
1303
  ForeignKey.references[k] = (foreign_fld, pk_field)
973
1304
 
@@ -1135,7 +1466,6 @@ class MongoParser(Parser):
1135
1466
 
1136
1467
  class Select(SQLObject):
1137
1468
  join_type: JoinType = JoinType.INNER
1138
- REGEX = {}
1139
1469
  EQUIVALENT_NAMES = {}
1140
1470
 
1141
1471
  def __init__(self, table_name: str='', **values):
@@ -1153,21 +1483,30 @@ class Select(SQLObject):
1153
1483
 
1154
1484
  def add(self, name: str, main: SQLObject):
1155
1485
  old_tables = main.values.get(FROM, [])
1156
- new_tables = set([
1157
- '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1486
+ if len(self.values[FROM]) > 1:
1487
+ old_tables += self.values[FROM][1:]
1488
+ new_tables = []
1489
+ row = '{jt}JOIN {tb} {a2} ON ({a1}.{f1} = {a2}.{f2})'.format(
1158
1490
  jt=self.join_type.value,
1159
1491
  tb=self.aka(),
1160
1492
  a1=main.alias, f1=name,
1161
1493
  a2=self.alias, f2=self.key_field
1162
1494
  )
1163
- ] + old_tables[1:])
1164
- main.values[FROM] = old_tables[:1] + list(new_tables)
1495
+ if row not in old_tables[1:]:
1496
+ new_tables.append(row)
1497
+ main.values[FROM] = old_tables[:1] + new_tables + old_tables[1:]
1165
1498
  for key in USUAL_KEYS:
1166
1499
  main.update_values(key, self.values.get(key, []))
1167
1500
 
1168
- def __add__(self, other: SQLObject):
1501
+ def copy(self) -> SQLObject:
1169
1502
  from copy import deepcopy
1170
- query = deepcopy(self)
1503
+ return deepcopy(self)
1504
+
1505
+ def no_relation_error(self, other: SQLObject):
1506
+ raise ValueError(f'No relationship found between {self.table_name} and {other.table_name}.')
1507
+
1508
+ def __add__(self, other: SQLObject):
1509
+ query = self.copy()
1171
1510
  if query.table_name.lower() == other.table_name.lower():
1172
1511
  for key in USUAL_KEYS:
1173
1512
  query.update_values(key, other.values.get(key, []))
@@ -1180,7 +1519,7 @@ class Select(SQLObject):
1180
1519
  PrimaryKey.add(primary_key, query)
1181
1520
  query.add(foreign_field, other)
1182
1521
  return other
1183
- raise ValueError(f'No relationship found between {query.table_name} and {other.table_name}.')
1522
+ self.no_relation_error(other) # === raise ERROR ... ===
1184
1523
  elif primary_key:
1185
1524
  PrimaryKey.add(primary_key, other)
1186
1525
  other.add(foreign_field, query)
@@ -1200,16 +1539,48 @@ class Select(SQLObject):
1200
1539
  if self.diff(key, other.values.get(key, []), True):
1201
1540
  return False
1202
1541
  return True
1542
+
1543
+ def __sub__(self, other: SQLObject) -> SQLObject:
1544
+ fk_field, primary_k = ForeignKey.find(self, other)
1545
+ if fk_field:
1546
+ query = self.copy()
1547
+ other = other.copy()
1548
+ else:
1549
+ fk_field, primary_k = ForeignKey.find(other, self)
1550
+ if not fk_field:
1551
+ self.no_relation_error(other) # === raise ERROR ... ===
1552
+ query = other.copy()
1553
+ other = self.copy()
1554
+ query.__class__ = NotSelectIN
1555
+ Field.add(fk_field, query)
1556
+ query.add(primary_k, other)
1557
+ return other
1203
1558
 
1204
1559
  def limit(self, row_count: int=100, offset: int=0):
1205
- result = [str(row_count)]
1206
- if offset > 0:
1207
- result.append(f'OFFSET {offset}')
1208
- self.values.setdefault(LIMIT, result)
1560
+ if Function.dialect == Dialect.SQL_SERVER:
1561
+ fields = self.values.get(SELECT)
1562
+ if fields:
1563
+ fields[0] = f'SELECT TOP({row_count}) {fields[0]}'
1564
+ else:
1565
+ self.values[SELECT] = [f'SELECT TOP({row_count}) *']
1566
+ return self
1567
+ if Function.dialect == Dialect.ORACLE:
1568
+ Where.gte(row_count).add(SQL_ROW_NUM, self)
1569
+ if offset > 0:
1570
+ Where.lte(row_count+offset).add(SQL_ROW_NUM, self)
1571
+ return self
1572
+ self.values[LIMIT] = ['{}{}'.format(
1573
+ row_count, f' OFFSET {offset}' if offset > 0 else ''
1574
+ )]
1209
1575
  return self
1210
1576
 
1211
- def match(self, expr: str) -> bool:
1212
- return re.findall(f'\b*{self.alias}[.]', expr) != []
1577
+ def match(self, field: str, key: str) -> bool:
1578
+ '''
1579
+ Recognizes if the field is from the current table
1580
+ '''
1581
+ if key in (ORDER_BY, GROUP_BY) and '.' not in field:
1582
+ return self.has_named_field(field)
1583
+ return re.findall(f'\b*{self.alias}[.]', field) != []
1213
1584
 
1214
1585
  @classmethod
1215
1586
  def parse(cls, txt: str, parser: Parser = SQLParser) -> list[SQLObject]:
@@ -1221,12 +1592,10 @@ class Select(SQLObject):
1221
1592
  for rule in rules:
1222
1593
  rule.apply(self)
1223
1594
 
1224
- def add_fields(self, fields: list, order_by: bool=False, group_by:bool=False):
1225
- class_types = [Field]
1226
- if order_by:
1227
- class_types += [OrderBy]
1228
- if group_by:
1229
- class_types += [GroupBy]
1595
+ def add_fields(self, fields: list, class_types=None):
1596
+ if not class_types:
1597
+ class_types = []
1598
+ class_types += [Field]
1230
1599
  FieldList(fields, class_types).add('', self)
1231
1600
 
1232
1601
  def translate_to(self, language: QueryLanguage) -> str:
@@ -1246,6 +1615,95 @@ class NotSelectIN(SelectIN):
1246
1615
  condition_class = Not
1247
1616
 
1248
1617
 
1618
class CTE(Select):
    """
    A Select preceded by a `WITH <name> AS (...)` clause.

    The body of the WITH clause is made of the queries in `query_list`,
    joined by UNION ALL.
    """
    prefix = ''

    def __init__(self, table_name: str, query_list: list[Select]):
        super().__init__(table_name)
        self.query_list = query_list
        self.break_lines = False
        # Inner queries are rendered on a single line; __str__ wraps them.
        for inner in self.query_list:
            inner.break_lines = False

    def __str__(self) -> str:
        # Total length of the single-line entries of the outer query:
        # long queries are rendered with line breaks.
        total_length = sum(
            len(value)
            for key in USUAL_KEYS
            for value in self.values.get(key, [])
            if '\n' not in value
        )
        if total_length > 70:
            self.break_lines = True

        keywords = '|'.join(KEYWORD)
        splitter = fr'({keywords}|AND|OR|,)'

        def justify(query: Select) -> str:
            # Start a new line at the first token boundary found
            # once the current line has reached 50 characters.
            lines, current = [], ''
            for piece in re.split(splitter, str(query)):
                if len(current) >= 50:
                    lines.append(current)
                    current = ''
                current += piece
            if current:
                lines.append(current)
            return '\n '.join(lines)

        body = '\nUNION ALL\n '.join(
            justify(inner) for inner in self.query_list
        )
        return 'WITH {}{} AS (\n {}\n){}'.format(
            self.prefix, self.table_name, body, super().__str__()
        )

    def join(self, pattern: str, fields: list | str, format: str=''):
        """
        Detect one query per field from `pattern` (repeated once per field)
        and add them, paired with `fields`, to this CTE.

        `fields` may be a list or a comma-separated string.
        Returns the CTE itself.
        """
        names = fields.split(',') if isinstance(fields, str) else fields
        queries = detect(
            pattern * len(names), join_queries=False, format=format
        )
        FieldList(fields, queries, ziped=True).add('', self)
        self.break_lines = True
        return self
1665
+
1666
class Recursive(CTE):
    """
    A CTE rendered as `WITH RECURSIVE ...`, whose last inner query
    also reads from the CTE itself.
    """
    prefix = 'RECURSIVE '

    def __str__(self) -> str:
        if len(self.query_list) > 1:
            # The last query must reference the CTE itself.
            # Add that reference only once: __str__ used to append it on
            # every call, so rendering the query twice duplicated it.
            self_reference = f', {self.table_name} {self.alias}'
            from_list = self.query_list[-1].values[FROM]
            if self_reference not in from_list:
                from_list.append(self_reference)
        return super().__str__()

    @classmethod
    def create(cls, name: str, pattern: str, formula: str, init_value, format: str=''):
        """
        Build a recursive CTE called `name` from two queries detected
        out of `pattern` (the pattern is applied twice).

        * The first query is filtered by `<its first field> = init_value`.
        * `formula` may contain `[n]` placeholders, where n is the 1-based
          position of a field in the second query. The first placeholder
          found becomes the field the formula is applied to and is
          replaced by `%`; the others are replaced by the field name.
          (presumably `Where.formula` puts the field in place of `%`
          -- verify against Where.formula)

        Side effect: resets SQLObject.ALIAS_FUNC to None.
        """
        SQLObject.ALIAS_FUNC = None
        def get_field(obj: SQLObject, pos: int) -> str:
            # Field name without the table alias prefix
            return obj.values[SELECT][pos].split('.')[-1]
        t1, t2 = detect(
            pattern*2, join_queries=False, format=format
        )
        pk_field = get_field(t1, 0)
        foreign_key = ''
        for num in re.findall(r'\[(\d+)\]', formula):
            num = int(num)
            if not foreign_key:
                foreign_key = get_field(t2, num-1)
                formula = formula.replace(f'[{num}]', '%')
            else:
                formula = formula.replace(f'[{num}]', get_field(t2, num-1))
        Where.eq(init_value).add(pk_field, t1)
        Where.formula(formula).add(foreign_key or pk_field, t2)
        return cls(name, [t1, t2])

    def counter(self, name: str, start, increment: str='+1'):
        """
        Add a counter column called `name` to every inner query:
        `start` in the first query, `(name<increment>)` in the others.
        Returns the CTE itself.
        """
        for i, query in enumerate(self.query_list):
            if i == 0:
                Field.add(f'{start} AS {name}', query)
            else:
                Field.add(f'({name}{increment}) AS {name}', query)
        return self
1703
+
1704
+
1705
+ # ----- Rules -----
1706
+
1249
1707
  class RulePutLimit(Rule):
1250
1708
  @classmethod
1251
1709
  def apply(cls, target: Select):
@@ -1309,6 +1767,8 @@ class RuleDateFuncReplace(Rule):
1309
1767
  @classmethod
1310
1768
  def apply(cls, target: Select):
1311
1769
  for i, condition in enumerate(target.values.get(WHERE, [])):
1770
+ if not '(' in condition:
1771
+ continue
1312
1772
  tokens = [
1313
1773
  t.strip() for t in cls.REGEX.split(condition) if t.strip()
1314
1774
  ]
@@ -1320,6 +1780,32 @@ class RuleDateFuncReplace(Rule):
1320
1780
  target.values[WHERE][i] = ' AND '.join(temp.values[WHERE])
1321
1781
 
1322
1782
 
1783
class RuleReplaceJoinBySubselect(Rule):
    """
    Turns joined tables that are only used for filtering into sub-selects.

    A joined query is kept as a join when it selects fields, has no
    WHERE conditions, has no foreign key to the main query, or its
    table is the first element of any ForeignKey reference.
    """
    @classmethod
    def apply(cls, target: Select):
        # Work on a parsed copy; `target` is only touched at the end.
        main, *joined = Select.parse(str(target))
        changed = False
        for query in joined:
            fk_field, primary_k = ForeignKey.find(main, query)
            more_relations = any(
                ref[0] == query.table_name
                for ref in ForeignKey.references
            )
            selects_fields = len(query.values.get(SELECT, [])) > 0
            lacks_conditions = len(query.values.get(WHERE, [])) == 0
            keep_join = (
                selects_fields or lacks_conditions
                or not fk_field or more_relations
            )
            if not keep_join:
                query.__class__ = SubSelect
                Field.add(primary_k, query)
                changed = True
            # Join or sub-select, the query is attached to the main one.
            query.add(fk_field, main)
        if changed:
            target.values = main.values.copy()
1807
+
1808
+
1323
1809
  def parser_class(text: str) -> Parser:
1324
1810
  PARSER_REGEX = [
1325
1811
  (r'select.*from', SQLParser),
@@ -1334,7 +1820,7 @@ def parser_class(text: str) -> Parser:
1334
1820
  return None
1335
1821
 
1336
1822
 
1337
- def detect(text: str) -> Select:
1823
+ def detect(text: str, join_queries: bool = True, format: str='') -> Select | list[Select]:
1338
1824
  from collections import Counter
1339
1825
  parser = parser_class(text)
1340
1826
  if not parser:
@@ -1345,21 +1831,18 @@ def detect(text: str) -> Select:
1345
1831
  continue
1346
1832
  pos = [ f.span() for f in re.finditer(fr'({table})[(]', text) ]
1347
1833
  for begin, end in pos[::-1]:
1348
- new_name = f'{table}_{count}' # See set_table (line 45)
1834
+ new_name = f'{table}_{count}' # See set_table (line 55)
1349
1835
  Select.EQUIVALENT_NAMES[new_name] = table
1350
1836
  text = text[:begin] + new_name + '(' + text[end:]
1351
1837
  count -= 1
1352
1838
  query_list = Select.parse(text, parser)
1839
+ if format:
1840
+ for query in query_list:
1841
+ query.set_file_format(format)
1842
+ if not join_queries:
1843
+ return query_list
1353
1844
  result = query_list[0]
1354
1845
  for query in query_list[1:]:
1355
1846
  result += query
1356
1847
  return result
1357
-
1358
- if __name__ == "__main__":
1359
- OrderBy.sort = SortType.DESC
1360
- query = Select(
1361
- 'order_Detail d',
1362
- customer_id=GroupBy,
1363
- _=Sum('d.unitPrice * d.quantity').As('total', OrderBy)
1364
- )
1365
- print(query)
1848
+ # ===========================================================================================//