django-ormql 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
django_ormql/query.py ADDED
@@ -0,0 +1,863 @@
1
+ import logging
2
+
3
+ from django.conf import settings
4
+ from django.db import models
5
+ from django.db.models import (
6
+ F,
7
+ Value,
8
+ Q,
9
+ ExpressionWrapper,
10
+ BooleanField,
11
+ aggregates,
12
+ OrderBy,
13
+ functions,
14
+ lookups,
15
+ OuterRef,
16
+ )
17
+ from django.db.models.fields.json import KeyTransform
18
+ from django.db.models.functions import Cast
19
+ from sqlglot import parse_one, Dialect, Tokenizer, TokenType, Generator, ParseError
20
+ from sqlglot import expressions
21
+
22
+ from . import db_func
23
+ from .db_func import NumericAwareCase
24
+ from .exceptions import QueryNotSupported, QueryError
25
+
26
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
27
+
28
+
29
class OrmqlDialect(Dialect):
    """sqlglot dialect describing the SQL subset understood by ORMQL.

    String literals use single quotes (the tokenizer also accepts double
    quotes), identifiers use backticks, and ``||`` is string concatenation.
    """

    DPIPE_IS_STRING_CONCAT = True
    QUOTE_START, QUOTE_END = "'", "'"
    IDENTIFIER_START, IDENTIFIER_END = "`", "`"

    class Tokenizer(Tokenizer):
        """Tokenizer restricted to an explicit keyword whitelist."""

        QUOTES = ["'", '"']
        IDENTIFIERS = ["`"]

        # Multi-character operators.
        _ORMQL_OPERATORS = {
            "==": TokenType.EQ,
            "::": TokenType.DCOLON,
            ">=": TokenType.GTE,
            "<=": TokenType.LTE,
            "<>": TokenType.NEQ,
            "!=": TokenType.NEQ,
            "||": TokenType.DPIPE,
            "->": TokenType.ARROW,
        }

        # Reserved words, including the two-word GROUP BY / ORDER BY keywords.
        _ORMQL_RESERVED_WORDS = {
            "AND": TokenType.AND,
            "ASC": TokenType.ASC,
            "AS": TokenType.ALIAS,
            "BETWEEN": TokenType.BETWEEN,
            "CASE": TokenType.CASE,
            "CURRENT_DATE": TokenType.CURRENT_DATE,
            "CURRENT_TIME": TokenType.CURRENT_TIME,
            "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
            "DESC": TokenType.DESC,
            "DISTINCT": TokenType.DISTINCT,
            "ELSE": TokenType.ELSE,
            "END": TokenType.END,
            "EXISTS": TokenType.EXISTS,
            "FALSE": TokenType.FALSE,
            "FILTER": TokenType.FILTER,
            "FIRST": TokenType.FIRST,
            "FROM": TokenType.FROM,
            "GROUP BY": TokenType.GROUP_BY,
            "HAVING": TokenType.HAVING,
            "ILIKE": TokenType.ILIKE,
            "IN": TokenType.IN,
            "IS": TokenType.IS,
            "ISNULL": TokenType.ISNULL,
            "LIKE": TokenType.LIKE,
            "LIMIT": TokenType.LIMIT,
            "NOT": TokenType.NOT,
            "NOTNULL": TokenType.NOTNULL,
            "NULL": TokenType.NULL,
            "OFFSET": TokenType.OFFSET,
            "OR": TokenType.OR,
            "ORDER BY": TokenType.ORDER_BY,
            "SELECT": TokenType.SELECT,
            "THEN": TokenType.THEN,
            "TRUE": TokenType.TRUE,
            "WHEN": TokenType.WHEN,
            "WHERE": TokenType.WHERE,
        }

        # Type names that may appear as cast targets.
        _ORMQL_TYPE_NAMES = {
            "BOOL": TokenType.BOOLEAN,
            "BOOLEAN": TokenType.BOOLEAN,
            "INT": TokenType.INT,
            "BIGINT": TokenType.BIGINT,
            "DECIMAL": TokenType.DECIMAL,
            "FLOAT": TokenType.FLOAT,
            "DOUBLE": TokenType.DOUBLE,
            "JSONB": TokenType.JSONB,
            "TEXT": TokenType.TEXT,
            "TIME": TokenType.TIME,
            "DATE": TokenType.DATE,
            "DATETIME": TokenType.DATETIME,
        }

        # Assigned (not merged with the base class's KEYWORDS), so presumably
        # only these words are treated as keywords by this dialect.
        KEYWORDS = {
            **_ORMQL_OPERATORS,
            **_ORMQL_RESERVED_WORDS,
            **_ORMQL_TYPE_NAMES,
        }

    class Generator(Generator):
        """Default sqlglot SQL generation; no overrides."""
102
+
103
+
104
# Comparison/predicate nodes. Each value is called with the converted
# (left, right) operands; the result is wrapped in an ExpressionWrapper with a
# BooleanField output by Query._expression_to_django.
boolean_expression_nodes = {
    expressions.EQ: db_func.Equal,
    expressions.NEQ: db_func.NotEqual,
    expressions.GT: db_func.GreaterThan,
    expressions.GTE: db_func.GreaterEqualThan,
    expressions.LT: db_func.LowerThan,
    expressions.LTE: db_func.LowerEqualThan,
    expressions.Is: db_func.Is,
    expressions.Like: db_func.Like,
    # ILIKE is emulated as LIKE on upper-cased operands.
    expressions.ILike: lambda a, b: db_func.Like(
        functions.Upper(a), functions.Upper(b)
    ),
}

# Binary arithmetic nodes, called with the converted (left, right) operands.
# NOTE(review): expressions.Mod is also listed in function_nodes, which
# Query._expression_to_django checks first, so the Mod entry here appears to be
# unreachable — confirm which translation is intended.
math_binary_nodes = {
    expressions.Mul: db_func.Mul,
    expressions.Add: db_func.Add,
    expressions.Sub: db_func.Sub,
    expressions.Div: db_func.Div,
    expressions.Mod: db_func.Mod,
}

# Aggregate functions; each value is instantiated with one argument plus the
# ``distinct`` (and optionally ``filter``) keyword arguments.
aggregate_nodes = {
    expressions.Avg: aggregates.Avg,
    expressions.Count: aggregates.Count,
    expressions.Max: aggregates.Max,
    expressions.Min: aggregates.Min,
    expressions.Stddev: aggregates.StdDev,
    expressions.Variance: aggregates.Variance,
    expressions.Sum: aggregates.Sum,
}

# Scalar functions whose arguments come from the node's this/expression/
# expressions args, passed positionally to the Django function.
# SUBSTRING_INDEX is deliberately not mapped: Django's StrIndex returns a
# position, not the substring SUBSTRING_INDEX produces. String-position
# functions are translated from expressions.StrPosition instead.
function_nodes = {
    expressions.Coalesce: functions.Coalesce,
    expressions.Concat: functions.Concat,
    expressions.Greatest: functions.Greatest,
    expressions.Least: functions.Least,
    expressions.Abs: functions.Abs,
    expressions.Ceil: functions.Ceil,
    expressions.Floor: functions.Floor,
    expressions.Mod: functions.Mod,
    expressions.Left: functions.Left,
    expressions.Right: functions.Right,
    expressions.Length: functions.Length,
    expressions.Lower: functions.Lower,
    expressions.Upper: functions.Upper,
}

# CAST target types and the Django field used as the cast's output_field.
types = {
    expressions.DataType.Type.BIGDECIMAL: models.DecimalField(
        max_digits=20, decimal_places=2
    ),  # TODO variable?
    expressions.DataType.Type.DECIMAL: models.DecimalField(
        max_digits=20, decimal_places=2
    ),  # TODO variable?
    expressions.DataType.Type.BIGINT: models.BigIntegerField(),
    expressions.DataType.Type.BIGSERIAL: models.BigIntegerField(),
    expressions.DataType.Type.INT: models.IntegerField(),
    expressions.DataType.Type.BOOLEAN: models.BooleanField(),
    expressions.DataType.Type.JSON: models.JSONField(),
    expressions.DataType.Type.JSONB: models.JSONField(),
    expressions.DataType.Type.DOUBLE: models.FloatField(),
    expressions.DataType.Type.FLOAT: models.FloatField(),
    expressions.DataType.Type.TEXT: models.TextField(),
    expressions.DataType.Type.TIME: models.TimeField(),
    expressions.DataType.Type.TIMESTAMPTZ: models.DateTimeField(),
    expressions.DataType.Type.DATETIME: models.DateTimeField(),
    expressions.DataType.Type.DATE: models.DateField(),
}
174
+
175
+
176
class Query:
    """Translate one ORMQL ``SELECT`` statement into Django ORM calls and run it.

    The SQL text is parsed with :class:`OrmqlDialect`; supported sqlglot nodes
    are converted into Django expressions resolved against the table objects
    in ``tables``.

    :param sql: ORMQL statement text.
    :param tables: Mapping of table name to table object. This class uses a
        table object's ``base_qs`` attribute (the starting queryset) and its
        ``resolve_column_path(path)`` method.
    :param placeholders: Mapping of placeholder name to value, or ``None``.
    :param timezone: tzinfo passed to the date/time extraction and truncation
        functions.
    :param default_limit: Row limit applied to the top-level query when it has
        no ``LIMIT``/``OFFSET``; a falsy value disables it.
    """

    def __init__(self, sql, tables, placeholders, timezone, default_limit):
        self.sql = sql
        self.tables = tables
        self.timezone = timezone
        self.placeholders = placeholders or {}
        self.default_limit = default_limit

    @staticmethod
    def _unique_name(name, used_names):
        """Return ``name`` with ``_`` appended until it is not in ``used_names``."""
        while name in used_names:
            name += "_"
        return name

    def _to_column_path(self, expression):
        """
        Hack an expression like part1.part2.part3.part4.part5 to [part1, part2, part3, part4, part5]

        Raises TypeError for anything other than a Column or Dot node.
        """
        if isinstance(expression, expressions.Dot):
            return [
                *self._to_column_path(expression.this),
                expression.expression.this,
            ]
        elif isinstance(expression, expressions.Column):
            # A Column stores up to four leading name parts in these args.
            return [
                part.this
                for part in (
                    expression.args.get("catalog"),
                    expression.args.get("db"),
                    expression.args.get("table"),
                    expression.args.get("this"),
                )
                if part
            ]
        else:
            raise TypeError("Invalid type")

    def _outer_to_django(self, expression, parent_table_stack):
        """Translate ``OUTER(column)`` into a chain of ``OuterRef`` objects.

        ``OUTER(x)`` resolves ``x`` against the nearest enclosing query's table
        (the last entry of ``parent_table_stack``); each additional nested
        ``OUTER(...)`` climbs one more level and adds one more ``OuterRef``.

        Raises QueryError for a wrong argument count or too-deep nesting, and
        QueryNotSupported for invalid arguments.
        """

        def _argument(call):
            if len(call.expressions) != 1:
                raise QueryError("Function OUTER takes exactly one argument")
            return call.expressions[0]

        def _resolve(e, parent_stack, depth):
            if isinstance(e, expressions.Anonymous) and e.this.lower() == "outer":
                if not parent_stack:
                    raise QueryError("OUTER nested too far")
                return _resolve(_argument(e), parent_stack[:-1], depth + 1)
            elif isinstance(e, (expressions.Column, expressions.Dot)):
                if not parent_stack:
                    raise QueryError("OUTER nested too far")
                return self._to_column_path(e), parent_stack[-1], depth
            else:
                raise QueryNotSupported("Invalid argument to OUTER()")

        cp, lookup_table, depth = _resolve(
            _argument(expression), parent_table_stack, 1
        )
        resolved = lookup_table.resolve_column_path(cp)
        if not isinstance(resolved, F):
            raise QueryNotSupported(f"Cannot use '{cp}' in OUTER()")
        ref = resolved.name
        for _ in range(depth):
            ref = OuterRef(ref)
        return ref

    def _aggregate_arguments(self, argument, **kwargs):
        """Convert an aggregate's argument, unwrapping ``DISTINCT``.

        Returns ``(args, distinct)``. Raises QueryNotSupported when more than
        one argument is given.
        """
        if isinstance(argument, expressions.Distinct):
            args = [
                self._expression_to_django(e, **kwargs) for e in argument.expressions
            ]
            distinct = True
        else:
            args = [self._expression_to_django(argument, **kwargs)]
            distinct = False
        if len(args) > 1:
            raise QueryNotSupported(
                "Multiple arguments to aggregate expression not supported"
            )
        return args, distinct

    def _expression_to_django(self, expression, **kwargs):
        """Convert a sqlglot expression node into a Django ORM expression.

        Keyword arguments:

        * ``table`` (required): table object column references resolve against.
        * ``aggregate_names`` (required): mapping of user-visible aggregate
          names to internal annotation names; a bare column whose name is a key
          here becomes ``F(<internal name>)``.
        * ``parent_table_stack`` (optional): tables of the enclosing queries,
          outermost first, used for ``OUTER(...)`` and nested subqueries.

        Raises QueryNotSupported for unsupported syntax and QueryError for
        malformed input.
        """
        table = kwargs["table"]
        aggregate_names = kwargs["aggregate_names"]
        parent_table_stack = kwargs.get("parent_table_stack", [])
        if isinstance(expression, (expressions.Column, expressions.Dot)):
            cp = self._to_column_path(expression)
            if len(cp) == 1 and aggregate_names and cp[0] in aggregate_names:
                return F(aggregate_names[cp[0]])
            return table.resolve_column_path(cp)
        elif (
            isinstance(expression, expressions.Anonymous)
            and expression.this.lower() == "outer"
        ):
            return self._outer_to_django(expression, parent_table_stack)
        elif isinstance(expression, expressions.Alias):
            return self._expression_to_django(expression.this, **kwargs)
        elif isinstance(expression, expressions.Literal):
            return Value(expression.to_py())
        elif isinstance(expression, expressions.Boolean):
            return Value(expression.this)
        elif isinstance(expression, expressions.Star):
            return "*"
        elif isinstance(expression, expressions.Cast):
            inner = self._expression_to_django(expression.this, **kwargs)
            target = expression.to.this
            if target not in types:
                raise QueryNotSupported(
                    f"Unsupported cast type '{expression.to.sql()}'"
                )
            return Cast(inner, output_field=types[target])
        elif isinstance(expression, expressions.Extract):
            if isinstance(expression.this, expressions.Var):
                lookup_name = expression.this.this.lower()
            else:
                try:
                    lookup_name = expression.this.to_py()
                except ValueError:
                    raise QueryNotSupported("Unsupported extract value") from None
            if lookup_name not in (
                "year",
                "iso_year",
                "quarter",
                "month",
                "day",
                "week",
                "week_day",
                "iso_week_day",
                "hour",
                "minute",
                "second",
            ):
                raise QueryNotSupported(f"Unsupported extract value '{lookup_name}'")
            return functions.Extract(
                self._expression_to_django(expression.expression, **kwargs),
                lookup_name=lookup_name,
                tzinfo=self.timezone,
            )
        elif (
            isinstance(expression, expressions.Anonymous)
            and expression.this.lower() == "datetrunc"
        ):
            if len(expression.expressions) != 2:
                raise QueryError("Function datetrunc takes exactly two arguments")
            try:
                lookup_name = expression.expressions[0].to_py()
            except ValueError:
                raise QueryNotSupported("Unsupported truncation type") from None
            if lookup_name not in (
                "year",
                "quarter",
                "month",
                "day",
                "week",
                "hour",
                "minute",
                "second",
            ):
                raise QueryNotSupported(f"Unsupported truncation type '{lookup_name}'")
            return functions.Trunc(
                self._expression_to_django(expression.expressions[1], **kwargs),
                lookup_name,
                tzinfo=self.timezone,
            )
        elif type(expression) in function_nodes:
            if expression.args.get("this"):
                args = [self._expression_to_django(expression.this, **kwargs)]
            else:
                args = []
            if expression.args.get("expression"):
                args += [self._expression_to_django(expression.expression, **kwargs)]
            args += [
                self._expression_to_django(e, **kwargs) for e in expression.expressions
            ]
            cls = function_nodes[type(expression)]
            # Reject arity mismatches and any sqlglot arg this translation ignores.
            if (cls.arity and cls.arity != len(args)) or any(
                v is not None
                and k
                not in (
                    "this",
                    "expression",
                    "expressions",
                    "ignore_nulls",
                    "safe",
                    "coalesce",
                )
                for k, v in expression.args.items()
            ):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            return cls(*args)
        elif isinstance(expression, expressions.Round):
            args = [self._expression_to_django(expression.this, **kwargs)]
            if expression.args.get("decimals"):
                args.append(
                    self._expression_to_django(expression.args["decimals"], **kwargs)
                )
            if any(
                v is not None and k not in ("this", "decimals")
                for k, v in expression.args.items()
            ):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            return functions.Round(*args)
        elif isinstance(expression, expressions.Pad):
            args = [
                self._expression_to_django(expression.this, **kwargs),
                self._expression_to_django(expression.expression, **kwargs),
            ]
            if expression.args.get("fill_pattern"):
                args.append(
                    self._expression_to_django(
                        expression.args["fill_pattern"], **kwargs
                    )
                )
            if any(
                v is not None
                and k not in ("this", "expression", "fill_pattern", "is_left")
                for k, v in expression.args.items()
            ):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            if expression.args["is_left"]:
                return functions.LPad(*args)
            else:
                return functions.RPad(*args)
        elif isinstance(expression, expressions.StrPosition):
            if not expression.args.get("substr"):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            args = [
                self._expression_to_django(expression.this, **kwargs),
                self._expression_to_django(expression.args["substr"], **kwargs),
            ]
            if any(
                v is not None and k not in ("this", "substr")
                for k, v in expression.args.items()
            ):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            return functions.StrIndex(*args)
        elif isinstance(expression, expressions.Substring):
            if not expression.args.get("start"):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            args = [
                self._expression_to_django(expression.this, **kwargs),
                self._expression_to_django(expression.args["start"], **kwargs),
            ]
            if expression.args.get("length"):
                args.append(
                    self._expression_to_django(expression.args["length"], **kwargs)
                )
            if any(
                v is not None and k not in ("this", "start", "length")
                for k, v in expression.args.items()
            ):
                raise QueryNotSupported(
                    f"Wrong number of arguments for function {expression.sql()}"
                )
            return functions.Substr(*args)
        elif isinstance(expression, expressions.Replace):
            args = [
                self._expression_to_django(expression.this, **kwargs),
                self._expression_to_django(expression.expression, **kwargs),
            ]
            if expression.args.get("replacement"):
                args.append(
                    self._expression_to_django(expression.args["replacement"], **kwargs)
                )
            return functions.Replace(*args)
        elif isinstance(expression, expressions.DPipe):
            return functions.Concat(
                self._expression_to_django(expression.this, **kwargs),
                self._expression_to_django(expression.expression, **kwargs),
            )
        elif (
            isinstance(expression, expressions.Filter)
            and type(expression.this) in aggregate_nodes
        ):
            # AGG(...) FILTER (WHERE cond)
            args, distinct = self._aggregate_arguments(
                expression.this.this, **kwargs
            )
            return aggregate_nodes[type(expression.this)](
                *args,
                distinct=distinct,
                filter=self._expression_to_django(expression.expression.this, **kwargs),
            )
        elif type(expression) in aggregate_nodes:
            args, distinct = self._aggregate_arguments(expression.this, **kwargs)
            return aggregate_nodes[type(expression)](*args, distinct=distinct)
        elif type(expression) in math_binary_nodes:
            lhs = self._expression_to_django(expression.left, **kwargs)
            rhs = self._expression_to_django(expression.right, **kwargs)
            return math_binary_nodes[type(expression)](lhs, rhs)
        elif isinstance(expression, expressions.Order):
            raise QueryNotSupported("ORDER not supported in expression")
        elif isinstance(expression, expressions.Null):
            # TODO do we need to guess output_field better?
            return Value(None, output_field=models.TextField(null=True))
        elif isinstance(expression, (expressions.NullSafeEQ, expressions.NullSafeNEQ)):
            raise QueryNotSupported("IS (NOT) DISTINCT not supported")
        elif isinstance(expression, expressions.Paren):
            return self._expression_to_django(expression.this, **kwargs)
        elif isinstance(expression, expressions.Neg):
            return -self._expression_to_django(expression.this, **kwargs)
        elif isinstance(
            expression,
            (
                expressions.BitwiseNot,
                expressions.BitwiseOr,
                expressions.BitwiseXor,
                expressions.BitwiseAnd,
                expressions.BitwiseCount,
                expressions.BitwiseLeftShift,
                expressions.BitwiseRightShift,
            ),
        ):
            raise QueryNotSupported("Bitwise operations not supported")
        elif type(expression) in boolean_expression_nodes:
            return ExpressionWrapper(
                boolean_expression_nodes[type(expression)](
                    self._expression_to_django(expression.left, **kwargs),
                    self._expression_to_django(expression.right, **kwargs),
                ),
                output_field=BooleanField(),
            )
        elif isinstance(expression, expressions.Between):
            return Q(
                ExpressionWrapper(
                    db_func.GreaterEqualThan(
                        self._expression_to_django(expression.this, **kwargs),
                        self._expression_to_django(expression.args["low"], **kwargs),
                    ),
                    output_field=BooleanField(),
                )
            ) & Q(
                ExpressionWrapper(
                    db_func.LowerEqualThan(
                        self._expression_to_django(expression.this, **kwargs),
                        self._expression_to_django(expression.args["high"], **kwargs),
                    ),
                    output_field=BooleanField(),
                )
            )
        elif isinstance(expression, expressions.In):
            if expression.args.get("query"):
                rhs = self._expression_to_django(expression.args["query"], **kwargs)
            else:
                rhs = [
                    self._expression_to_django(e, **kwargs)
                    for e in expression.expressions
                ]
            return ExpressionWrapper(
                lookups.In(
                    self._expression_to_django(expression.this, **kwargs)
                    if False
                    else self._expression_to_django(expression.this, **kwargs),
                    rhs,
                )
                if False
                else lookups.In(
                    self._expression_to_django(expression.this, **kwargs), rhs
                ),
                output_field=BooleanField(),
            ) if False else self._in_to_django(expression, **kwargs)
        elif isinstance(expression, expressions.And):
            return self._expression_to_django(
                expression.left, **kwargs
            ) & self._expression_to_django(expression.right, **kwargs)
        elif isinstance(expression, expressions.Or):
            return self._expression_to_django(
                expression.left, **kwargs
            ) | self._expression_to_django(expression.right, **kwargs)
        elif isinstance(expression, expressions.Not):
            return ~self._expression_to_django(expression.this, **kwargs)
        elif isinstance(expression, expressions.Case):
            whens = []
            for w in expression.args.get("ifs") or []:
                if expression.this:
                    # Simple CASE: CASE x WHEN v THEN ... compares x = v.
                    condition = db_func.Equal(
                        self._expression_to_django(expression.this, **kwargs),
                        self._expression_to_django(w.this, **kwargs),
                    )
                else:
                    condition = self._expression_to_django(w.this, **kwargs)
                whens.append(
                    models.When(
                        condition,
                        then=self._expression_to_django(w.args["true"], **kwargs),
                    )
                )
            default = None
            if expression.args.get("default"):
                default = self._expression_to_django(
                    expression.args["default"], **kwargs
                )
            return NumericAwareCase(*whens, default=default)
        elif isinstance(expression, expressions.CurrentDate):
            return functions.TruncDate(functions.Now(), tzinfo=self.timezone)
        elif isinstance(expression, expressions.CurrentTime):
            return functions.TruncTime(functions.Now(), tzinfo=self.timezone)
        elif isinstance(expression, expressions.CurrentTimestamp):
            return functions.Now()
        elif isinstance(expression, expressions.Subquery):
            if not isinstance(expression.this, expressions.Select):
                raise QueryNotSupported("Only SELECT subqueries are supported")
            qs, _ = self._select_to_qs(
                expression.this, parent_table_stack=parent_table_stack + [table]
            )
            return db_func.AutoTypedSubquery(qs)
        elif isinstance(expression, expressions.Exists):
            if not isinstance(expression.this, expressions.Select):
                raise QueryNotSupported("Only SELECT subqueries are supported")
            qs, _ = self._select_to_qs(
                expression.this, parent_table_stack=parent_table_stack + [table]
            )
            return models.Exists(qs)
        elif isinstance(expression, expressions.Placeholder):
            if expression.name == "?":
                raise QueryError("Placeholder must be named")
            if expression.name not in self.placeholders:
                raise QueryError(f"Placeholder '{expression.name}' not filled")
            return Value(self.placeholders[expression.name])
        elif isinstance(expression, expressions.JSONExtract):
            path = expression.expression
            if isinstance(path, expressions.JSONPath):
                k = self._expression_to_django(expression.this, **kwargs)
                for pathel in path.expressions:
                    if isinstance(pathel, expressions.JSONPathRoot):
                        pass
                    elif isinstance(pathel, expressions.JSONPathKey):
                        k = KeyTransform(pathel.this, k)
                    else:
                        raise QueryNotSupported("Advanced JSON path is not supported")
                return k
            elif isinstance(path, expressions.Literal):
                key = path.this
            elif isinstance(path, expressions.Column):
                # Column.this is an Identifier; its .this is the key string.
                key = path.this.this
            elif isinstance(path, expressions.Identifier):
                # Identifier.this is already the key string.
                key = path.this
            else:
                raise QueryNotSupported(f"Unsupported JSON path: {expression.sql()}")
            return KeyTransform(
                key, self._expression_to_django(expression.this, **kwargs)
            )
        else:
            raise QueryNotSupported(f"Unsupported expression: {expression.sql()}")

    def _in_to_django(self, expression, **kwargs):
        """Translate ``x IN (subquery)`` or ``x IN (v1, v2, ...)`` to a boolean expression."""
        lhs = self._expression_to_django(expression.this, **kwargs)
        if expression.args.get("query"):
            rhs = self._expression_to_django(expression.args["query"], **kwargs)
        else:
            rhs = [self._expression_to_django(e, **kwargs) for e in expression.expressions]
        return ExpressionWrapper(lookups.In(lhs, rhs), output_field=BooleanField())

    def _expression_to_name(self, expression):
        """Return the user-visible output column name for a SELECT-list expression."""
        if isinstance(expression, (expressions.Column, expressions.Dot)):
            return ".".join(self._to_column_path(expression))
        elif isinstance(expression, expressions.Literal):
            return str(expression.this)
        elif isinstance(expression, expressions.Alias):
            return expression.output_name
        # Anything else (aggregates, function calls, operators) is named by
        # its SQL text.
        return expression.sql()

    def _where_to_django(self, node, **kwargs):
        """Convert a WHERE/HAVING condition; same keyword arguments as _expression_to_django."""
        return self._expression_to_django(node, **kwargs)

    def _resolve_from_table(self, root):
        """Return the table object for the single plain table in ``root``'s FROM clause.

        Raises QueryNotSupported for a missing FROM, non-table sources, table
        aliases, database-qualified names, joins and unknown tables.
        """
        from_clause = root.args.get("from_")
        if from_clause is None:
            raise QueryNotSupported("SELECT without FROM not supported")
        table = from_clause.this
        if not isinstance(table, expressions.Table):
            raise QueryNotSupported("Unsupported FROM statement")
        if table.args.get("alias"):
            raise QueryNotSupported("Table alias not supported")
        if table.args.get("db"):
            raise QueryNotSupported("Database names not supported")
        if root.args.get("joins"):
            raise QueryNotSupported("SELECT from multiple tables not supported")
        if table.this.this not in self.tables:
            raise QueryNotSupported(f"Table {table.this} not found")
        return self.tables[table.this.this]

    def _order_by_to_django(self, root, table, parent_table_stack, name_to_aggregation):
        """Build the list of ``OrderBy`` expressions for ``root``'s ORDER BY clause.

        A bare column whose name is a selected aggregate's output name orders
        by that aggregate's internal annotation.
        """
        order_by = []
        if not root.args.get("order"):
            return order_by
        for ordered in root.args["order"].args["expressions"]:
            target = ordered.this
            if (
                isinstance(target, expressions.Column)
                and isinstance(target.this, expressions.Identifier)
                and target.this.this in name_to_aggregation
            ):
                django_e = F(name_to_aggregation[target.this.this])
            else:
                django_e = self._expression_to_django(
                    target,
                    table=table,
                    aggregate_names={},
                    parent_table_stack=parent_table_stack,
                )
            nulls_first = ordered.args.get("nulls_first")
            order_by.append(
                OrderBy(
                    django_e,
                    descending=ordered.args.get("desc"),
                    nulls_first=True if nulls_first else None,
                    nulls_last=None if nulls_first else True,
                )
            )
        return order_by

    @staticmethod
    def _literal_int(node, clause):
        """Return the integer literal held by a LIMIT/OFFSET node.

        Raises QueryNotSupported if the value is not a literal number.
        """
        value = node.expression
        if not isinstance(value, expressions.Literal):
            raise QueryNotSupported(f"{clause} may only contain literal numbers")
        try:
            return int(value.this)
        except ValueError:
            raise QueryNotSupported(
                f"{clause} may only contain literal numbers"
            ) from None

    def _apply_limit_offset(self, qs, root, is_subquery):
        """Slice ``qs`` according to LIMIT/OFFSET, or the default limit for top-level queries."""
        limit_node = root.args.get("limit")
        offset_node = root.args.get("offset")
        if limit_node and limit_node.args.get("offset"):
            # "LIMIT offset, count" would otherwise silently lose its offset.
            raise QueryNotSupported("LIMIT with comma-separated offset not supported")
        if limit_node and offset_node:
            limit = self._literal_int(limit_node, "LIMIT")
            offset = self._literal_int(offset_node, "OFFSET")
            return qs[offset : offset + limit]
        if limit_node:
            return qs[: self._literal_int(limit_node, "LIMIT")]
        if offset_node:
            return qs[self._literal_int(offset_node, "OFFSET") :]
        if self.default_limit and not is_subquery:
            return qs[: self.default_limit]
        return qs

    def _select_to_qs(self, root, parent_table_stack):
        """Translate a sqlglot ``Select`` node into a Django queryset.

        :param root: The ``Select`` node.
        :param parent_table_stack: Tables of the enclosing queries (empty for
            the top-level query). A non-empty stack marks a subquery, which
            must produce exactly one column and gets no default limit.
        :returns: ``(qs, values_names)``. ``qs`` is a ``.values()`` queryset,
            or a plain dict when a top-level query aggregates without GROUP
            BY. ``values_names`` maps the internal ``exprN`` keys to the
            user-visible output column names.
        """
        table = self._resolve_from_table(root)
        qs = table.base_qs

        if root.args.get("where"):
            qs = qs.filter(
                self._where_to_django(
                    root.args["where"].this,
                    table=table,
                    aggregate_names={},
                    parent_table_stack=parent_table_stack,
                )
            )

        group_args = []
        if root.args.get("group"):
            for e in root.args["group"]:
                group_args.append(
                    self._expression_to_django(
                        e,
                        table=table,
                        aggregate_names={},
                        parent_table_stack=parent_table_stack,
                    )
                )

        if root.args.get("having") and not group_args:
            # HAVING is only applied to grouped querysets below; refuse it
            # otherwise instead of silently dropping the condition.
            raise QueryNotSupported("HAVING without GROUP BY not supported")

        values_args = {}
        values_names = {}
        aggregations = {}
        name_to_aggregation = {}
        used_names = set()
        for i, e in enumerate(root.args["expressions"]):
            if isinstance(e, expressions.Star):
                raise QueryNotSupported("SELECT * is not supported")
            # Output names must be unique among the names already chosen,
            # otherwise later columns overwrite earlier ones in each row dict.
            n = self._unique_name(self._expression_to_name(e), used_names)
            used_names.add(n)

            # TODO We should validate that everything selected in a GROUP BY query is either an aggregate, part of
            # the grouping, or a literal. However, I have not found a safe way to validate yet and it's not a big deal.
            django_e = self._expression_to_django(
                e,
                table=table,
                aggregate_names={},
                parent_table_stack=parent_table_stack,
            )
            # We do not use the alias names given by the user as annotation
            # names, first to ensure uniqueness, but also because Django has had
            # SQL injection vulns affecting user-chosen annotate targets. Output
            # names are remapped via values_names instead.
            key = f"expr{i}"
            if isinstance(django_e, aggregates.Aggregate):
                aggregations[key] = django_e
                name_to_aggregation[n] = key
            else:
                values_args[key] = django_e
            values_names[key] = n

        if root.args.get("distinct"):
            qs = qs.distinct()

        order_by = self._order_by_to_django(
            root, table, parent_table_stack, name_to_aggregation
        )

        if parent_table_stack and len(values_args) + len(aggregations) != 1:
            raise QueryError("Subquery must return exactly 1 column")

        if group_args and not aggregations:
            # Django will not do proper group by without any aggregations, so we need to do trickery
            aggregations = {"_grp_trick": models.Count("*")}

        if aggregations:
            if group_args:
                group_names = [f"grp{i}" for i in range(len(group_args))]
                qs = (
                    qs.order_by()
                    .annotate(**dict(zip(group_names, group_args)), **values_args)
                    .values(*group_names, *values_args)
                    .annotate(**aggregations)
                )

                if root.args.get("having"):
                    qs = qs.filter(
                        self._where_to_django(
                            root.args["having"].this,
                            table=table,
                            aggregate_names=name_to_aggregation,
                            parent_table_stack=parent_table_stack,
                        )
                    )

                if parent_table_stack:
                    # A subquery may only return its single selected column;
                    # drop the grpN / _grp_trick helper columns from the SELECT
                    # list (the GROUP BY established above is kept).
                    qs = qs.values(*values_names)
            elif parent_table_stack:
                # Django can't use .aggregate() in subqueries, we need to do trickery
                qs = (
                    qs.annotate(_agg_trick=Value("1"))
                    .values("_agg_trick")
                    .annotate(**aggregations)
                    .values(next(iter(aggregations)))
                )
            else:
                qs = qs.aggregate(**aggregations)
        else:
            qs = qs.values(**values_args)

        if not isinstance(qs, dict):
            if order_by:
                qs = qs.order_by(*order_by)
            qs = self._apply_limit_offset(qs, root, bool(parent_table_stack))

        return qs, values_names

    def evaluate(self):
        """Parse, translate and run the query, yielding one dict per result row.

        Each yielded dict maps output column names to values; a top-level
        aggregate query without GROUP BY yields a single dict.

        This is a generator, so QueryNotSupported / QueryError raised while
        parsing or translating surface on the first iteration. Unexpected
        exceptions during translation are wrapped in QueryError.
        """
        try:
            ast = parse_one(self.sql, dialect=OrmqlDialect)
        except ParseError as e:
            raise QueryNotSupported(str(e)) from e

        if settings.DEBUG:
            logger.debug("Parsed statement: %r", ast)

        if not isinstance(ast, expressions.Select):
            raise QueryNotSupported("Only SELECT queries are supported")

        try:
            qs, values_names = self._select_to_qs(ast, [])
        except (QueryError, QueryNotSupported):
            # Our own errors already carry a meaningful message.
            raise
        except Exception as e:
            raise QueryError("Query parsing failed") from e

        if isinstance(qs, dict):
            yield {values_names[k]: v for k, v in qs.items()}
        else:
            if settings.DEBUG:
                logger.debug("Generated statement: %s", qs.query)
            for row in qs:
                # Skip helper columns (grpN, _grp_trick) that are not user output.
                yield {values_names[k]: v for k, v in row.items() if k in values_names}