waxsql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waxsql/gen/expr.py ADDED
@@ -0,0 +1,723 @@
1
+ """Typed expression generator.
2
+
3
+ Role: the type-driven heart of the query generator (pillar 1 in the
4
+ project's architectural pillars). Every other generator that needs an
5
+ expression — SELECT targets, WHERE predicates, JOIN ON conditions,
6
+ HAVING, FILTER, window-function arguments, etc. — funnels through
7
+ `gen_expr` so that type compatibility is enforced by construction
8
+ rather than checked after the fact.
9
+
10
+ The core recursion of the query generator. Given a `GenContext` and a
11
+ target `PgType`, produce an `Expr` whose actual type implicitly casts
12
+ to the target. The candidate productions are:
13
+
14
+ * **Column reference** — pull a binding from the current scope whose
15
+ type implicitly casts to the target. Available at any depth.
16
+
17
+ * **Literal** — synthesize a typed constant via `gen_literal`. Always
18
+ available (the catch-all when nothing else fits).
19
+
20
+ * **Function call** — pick a scalar function from the catalog whose
21
+ return type matches; recursively generate one arg per signature
22
+ parameter. Recursive — only available when the depth budget
23
+ allows.
24
+
25
+ * **Binary operator** — same as function call but for binary ops.
26
+
27
+ * **Aggregate function call** — only when `ctx.allow_aggregates`
28
+ is True AND we're not already inside an aggregate (`in_aggregate`
29
+ False). Same shape as a scalar function call, but the recursive
30
+ arg-generation context has `in_aggregate=True`, which prevents
31
+ nested aggregates like `sum(count(...))` (PG syntax error).
32
+ Recursive, gated on the depth budget.
33
+
34
+ * **Subquery productions** — `(SELECT ...)` (scalar), `EXISTS (...)`,
35
+ or `<expr> IN (...)`. Treated as leaves of the *outer* expression
36
+ tree (no expression-depth budget consumed for them) but rationed
37
+ by their own `subquery_depth_remaining` budget. EXISTS and IN are
38
+ only candidates when target_type is BOOL; scalar subqueries are
39
+ candidates for any target type. All three require `allow_subquery`
40
+ AND the matching feature flag.
41
+
42
+ The depth budget rations recursion. Every recursive call descends by
43
+ one; when `ctx.at_leaf()`, only the leaf productions (column refs,
44
+ literals, and separately budgeted subqueries) are considered. Recursive productions are *additionally* downweighted as
45
+ depth shrinks (`(depth_remaining/max_depth) ** leaf_bias`), so trees
46
+ don't always max out their budget — they collapse early often enough
47
+ to keep output varied.
48
+
49
+ Determinism rules respected here:
50
+
51
+ * All randomness comes from `ctx.rng`, never the global `random`.
52
+ * Catalog lookup methods return sorted lists — iteration order is
53
+ stable across Python builds.
54
+ * Scope's `visible_columns` returns bindings in insertion order,
55
+ which is also stable for any deterministic FROM-clause walk.
56
+ """
57
+ from __future__ import annotations
58
+
59
+ import random
60
+ from collections.abc import Callable
61
+ from dataclasses import replace
62
+
63
+ from ..ast import (
64
+ BinaryOp, Cast, ColumnRef, Expr, FuncCall, Literal, OrderByItem,
65
+ )
66
+ from ..catalog import FuncKind, FuncSig
67
+ from ..config import (
68
+ FEATURE_EXISTS, FEATURE_IN_SUBQUERY, FEATURE_SCALAR_SUBQUERY,
69
+ FEATURE_WINDOW_FUNCTION,
70
+ )
71
+ from ..context import GenContext
72
+ from ..types import (
73
+ BOOL, DATE, FLOAT8, INT4, INT8, INTERVAL, JSONB, NUMERIC,
74
+ PgType, TEXT, TIMESTAMPTZ, UUID, VARCHAR, implicitly_castable,
75
+ )
76
+ from .window import gen_window_spec
77
+
78
+
79
+ # Probability of emitting a typed NULL when the literal production is
80
+ # selected. Keeps a steady NULL rate flowing through expressions so
81
+ # that the printer's `NULL::T` path stays exercised, without drowning
82
+ # the output in NULLs.
83
+ _NULL_PROB: float = 0.05
84
+
85
+ # Probability of emitting `count(*)` instead of `count(<arg>)` when
86
+ # the chosen aggregate happens to be count. Both forms are common in
87
+ # real SQL — `count(col)` counts non-NULL values, `count(*)` counts
88
+ # all rows including all-NULL. Half-and-half is a reasonable balance;
89
+ # tune if outputs feel monotonous in either direction.
90
+ _P_COUNT_STAR: float = 0.5
91
+
92
+
93
+ # Valid unit keywords for date_trunc's first arg. PG accepts these
94
+ # 13 strings (and a few abbreviations); arbitrary TEXT errors at
95
+ # runtime ("unit X not recognized"). The list is fixed at PG-spec
96
+ # values — see https://www.postgresql.org/docs/current/functions-
97
+ # datetime.html#FUNCTIONS-DATETIME-TRUNC for canonical names.
98
+ _DATE_TRUNC_UNITS: tuple[str, ...] = (
99
+ "microseconds", "milliseconds", "second", "minute",
100
+ "hour", "day", "week", "month", "quarter", "year",
101
+ "decade", "century", "millennium",
102
+ )
103
+
104
+ # Fixed JSONB-array literals for the jsonb_array_length arg-rewrite.
105
+ # Ordinary JSONB literals (`_JSONB_LIT`, defined further below) are objects; passing
106
+ # them to jsonb_array_length errors with "cannot get array length of
107
+ # a non-array". A small pool gives output variety while guaranteeing
108
+ # array-ness.
109
+ _JSONB_ARRAY_LITS: tuple[str, ...] = (
110
+ "[1, 2, 3]", '["alpha", "beta"]', "[true, false]", "[]",
111
+ )
112
+
113
+ # Ordered-set aggregates: must be called with a WITHIN GROUP clause
114
+ # (PG rejects them otherwise). Set membership drives the agg-branch
115
+ # special-case in gen_expr (and the parallel sites in gen/select.py)
116
+ # that routes through `gen_ordered_set_agg` rather than the regular
117
+ # construction path. Public name (no leading underscore) because
118
+ # select.py imports it.
119
+ ORDERED_SET_AGGREGATES: frozenset[str] = frozenset({
120
+ "percentile_cont", "percentile_disc",
121
+ })
122
+
123
+
124
def gen_ordered_set_agg(ctx: GenContext, f: FuncSig) -> FuncCall:
    """Assemble an ordered-set aggregate call including its mandatory
    WITHIN GROUP clause.

    Shape produced: ``f.name(<fraction>) WITHIN GROUP (ORDER BY <expr>
    ASC|DESC)``.

    * The direct argument is a FLOAT8 literal picked from the fixed pool
      (0.25, 0.5, 0.75) so the output stays readable.
    * The WITHIN GROUP sort key is a FLOAT8 expression, explicitly
      coerced to FLOAT8 so it lines up with the return type the catalog
      declares for these functions.
    * No FILTER or OVER clause is attached here.

    Args:
        ctx: generation context; supplies the rng and the depth budget.
        f: catalog signature of the ordered-set aggregate being called.

    Returns:
        A `FuncCall` typed as `f.returns` carrying a one-item
        `within_group` tuple.
    """
    dice = ctx.rng

    # Direct argument: the percentile fraction, always inside [0, 1].
    direct_args = (Literal(FLOAT8, dice.choice((0.25, 0.5, 0.75))),)

    # The sort key is evaluated per input row, so it obeys the same
    # restrictions as an ordinary aggregate argument: no nested
    # aggregates (in_aggregate=True) and no window calls
    # (allow_window=False, in_window=False).
    sort_ctx = replace(
        ctx.descend(),
        in_aggregate=True,
        allow_window=False,
        in_window=False,
    )
    sort_key = coerce_to_param_type(gen_expr(sort_ctx, FLOAT8), FLOAT8)

    # Direction is drawn last, after the sort key has consumed its
    # share of the rng stream.
    ordering = OrderByItem(
        expr=sort_key,
        direction=dice.choice(("ASC", "DESC")),
    )

    return FuncCall(
        f.returns, f.name, direct_args, within_group=(ordering,),
    )
174
+
175
+
176
def _replace_literal_zero(expr: Expr, expected_type: PgType) -> Expr:
    """Swap a bare numeric zero literal for a literal one.

    Called where a divisor is emitted (`/`, `%`, `mod`) so the trivial
    constant division by zero is never generated. Only a direct
    `Literal` is inspected; zeros hidden behind arithmetic or other
    wrappers are intentionally left alone.

    Args:
        expr: the candidate divisor expression.
        expected_type: type the replacement literal must carry (the
            operator's / function's declared divisor parameter type).

    Returns:
        `Literal(expected_type, 1)` (or `1.0` for non-integer types)
        when `expr` is a non-NULL numeric literal equal to zero;
        otherwise `expr` itself, untouched.
    """
    # Short-circuit order matters: `.value` / `.pg_type` are only read
    # once we know we are looking at a Literal.
    is_bare_numeric_zero = (
        isinstance(expr, Literal)
        and expr.value is not None
        and expr.pg_type in (INT4, INT8, NUMERIC, FLOAT8)
        and expr.value == 0
    )
    if not is_bare_numeric_zero:
        return expr

    # Integer targets get an int 1, everything else a float 1.0, so the
    # replacement's Python value matches the flavour of its PG type.
    one = 1 if expected_type in (INT4, INT8) else 1.0
    return Literal(expected_type, one)
203
+
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Per-function argument rewriters
207
+ # ---------------------------------------------------------------------------
208
+ #
209
+ # A few catalog functions declare a generic arg type (TEXT / JSONB) but
210
+ # require a CONSTRAINED value at PLAN time: date_trunc's first arg must be
211
+ # a unit keyword, jsonb_array_length's arg must be a JSON array, mod's
212
+ # divisor must not constant-fold to zero. The catalog can't express "this
213
+ # TEXT must be one of these keywords", so the generator patches the
214
+ # generated args. Centralizing the "which function needs which rewrite"
215
+ # knowledge in one registry (vs scattered `if f.name == ...` branches at
216
+ # the call site) gives that knowledge a single home — a future catalog
217
+ # change updates one table instead of hunting call sites. Each rewriter
218
+ # mutates `args_list` in place and consumes `ctx.rng` identically to its
219
+ # former inline form; at most one fires per call (catalog names are
220
+ # unique), so determinism is unchanged.
221
+
222
+
223
+ def _rewrite_mod_args(args_list: list[Expr], ctx: GenContext, f: FuncSig) -> None:
224
+ # `mod(x, 0)` constant-folds to a 22012 division-by-zero error at PLAN
225
+ # time. Same suppression as the `/` / `%` operator branch — replace a
226
+ # bare-literal-zero second arg with 1.
227
+ if len(args_list) >= 2:
228
+ args_list[1] = _replace_literal_zero(args_list[1], f.args[1])
229
+
230
+
231
+ def _rewrite_date_trunc_args(args_list: list[Expr], ctx: GenContext, f: FuncSig) -> None:
232
+ # date_trunc's first arg must be a recognized UNIT keyword. The catalog
233
+ # signature is generic TEXT; replace whatever gen_expr produced with a
234
+ # fresh literal from the whitelist. PG validates the unit at run AND
235
+ # PLAN time (when not behind a cast), so a generic text literal would
236
+ # make EXECUTE fail without this rewrite.
237
+ if len(args_list) >= 1:
238
+ args_list[0] = Literal(TEXT, ctx.rng.choice(_DATE_TRUNC_UNITS))
239
+
240
+
241
+ def _rewrite_jsonb_array_length_args(
242
+ args_list: list[Expr], ctx: GenContext, f: FuncSig
243
+ ) -> None:
244
+ # jsonb_array_length requires its arg to be a JSON ARRAY, not an
245
+ # object. Our default JSONB literal is `{"k":"v"}` (object), and
246
+ # jsonb_build_object produces objects — replace with a guaranteed-array
247
+ # literal. Same safety pattern as date_trunc, different constraint.
248
+ if len(args_list) >= 1:
249
+ args_list[0] = Literal(JSONB, ctx.rng.choice(_JSONB_ARRAY_LITS))
250
+
251
+
252
+ # Function name -> in-place argument rewriter. Resolved by exact-name
253
+ # lookup (never iterated), so this dict does not violate the
254
+ # no-iteration-over-collections-in-RNG-paths determinism rule.
255
+ _ARG_REWRITERS: dict[str, Callable[[list[Expr], GenContext, FuncSig], None]] = {
256
+ "mod": _rewrite_mod_args,
257
+ "date_trunc": _rewrite_date_trunc_args,
258
+ "jsonb_array_length": _rewrite_jsonb_array_length_args,
259
+ }
260
+
261
+
262
def coerce_to_param_type(arg: Expr, param_type: PgType) -> Expr:
    """Return `arg` typed exactly as `param_type`, casting only if needed.

    The catalog's declared return type for a function or operator is
    only trustworthy when the call site's argument types match the
    declared parameter types exactly: with a merely implicitly-castable
    argument, PG's overload resolution may choose a different overload
    with a different return type (e.g. `floor(int8)` resolving to the
    double-precision overload rather than the numeric one). An explicit
    `::param_type` cast pins the overload, and also gives otherwise
    'unknown'-typed string literals a concrete type.

    Shared by the expression-construction branches in `gen_expr` so
    they all coerce identically.

    Args:
        arg: the already-generated argument expression.
        param_type: the declared type of the parameter slot.

    Returns:
        `arg` itself when `arg.pg_type == param_type`; otherwise a
        `Cast` node wrapping `arg` with `param_type` as both its
        result type and cast target.
    """
    return (
        arg
        if arg.pg_type == param_type
        else Cast(pg_type=param_type, expr=arg, target_type=param_type)
    )
288
+
289
+
290
def should_emit_count_star(rng: random.Random) -> bool:
    """Decide whether a `count` aggregate is emitted as `count(*)`.

    Draws exactly one value from `rng` and compares it against
    `_P_COUNT_STAR`, keeping the substitution rate defined in a single
    place. This helper does not look at the function at all — callers
    must have established `f.name == "count"` before asking.

    Returns:
        True with probability `_P_COUNT_STAR`.
    """
    roll = rng.random()
    return roll < _P_COUNT_STAR
298
+
299
+
300
+ # Probability of attaching a `FILTER (WHERE ...)` clause to an
301
+ # aggregate call. Lower than _P_COUNT_STAR because count(*) replaces
302
+ # an existing form (every count picks one form or the other), while
303
+ # FILTER is purely additive — every aggregate with a filter is
304
+ # strictly noisier than the same aggregate without one. 20% strikes
305
+ # a balance: visible but not dominant.
306
+ _P_FILTER: float = 0.2
307
+
308
+
309
def should_emit_filter(rng: random.Random) -> bool:
    """Decide whether an aggregate call gets a `FILTER (WHERE ...)`.

    Draws exactly one value from `rng` and compares it against
    `_P_FILTER`. FILTER is only legal on aggregate calls (including
    aggregates used as window functions), so callers must check
    `f.kind == FuncKind.AGGREGATE` themselves before rolling.

    Returns:
        True with probability `_P_FILTER`.
    """
    roll = rng.random()
    return roll < _P_FILTER
316
+
317
+
318
def gen_filter_predicate(ctx: GenContext) -> "Expr":
    """Generate the BOOL predicate for a `FILTER (WHERE ...)` clause.

    The predicate is generated one level deeper than `ctx` (so it draws
    on the depth budget like any child expression) and with the flags
    set so that neither aggregates nor window calls can appear inside
    it: `in_aggregate=True`, `allow_window=False`, `in_window=False`.

    Args:
        ctx: context of the aggregate call the FILTER attaches to.

    Returns:
        An expression generated by `gen_expr` for target type BOOL.
    """
    one_level_down = ctx.descend()
    restricted_ctx = replace(
        one_level_down,
        in_aggregate=True,
        allow_window=False,
        in_window=False,
    )
    return gen_expr(restricted_ctx, BOOL)
334
+
335
+
336
+ # Curated word list for string literals. Deliberately small and free
337
+ # of apostrophes / non-ASCII so the printer's literal escaper has
338
+ # nothing surprising to handle. Add more here if string-heavy output
339
+ # starts feeling repetitive — there's no determinism implication.
340
+ _WORDS: tuple[str, ...] = (
341
+ "alpha", "beta", "gamma", "delta", "epsilon", "zeta",
342
+ "foo", "bar", "baz", "quux", "lorem", "ipsum",
343
+ "north", "south", "east", "west",
344
+ "one", "two", "three", "four", "five",
345
+ )
346
+
347
+ # Small "nice" numbers. Generators that want a richer numeric range
348
+ # can use the operator pool to compose them; this just gives the
349
+ # literal-leaf production something deterministic to draw from.
350
+ _INTS: tuple[int, ...] = (0, 1, 2, 3, 5, 7, 10, 42, 100, 1000)
351
+ _FLOATS: tuple[float, ...] = (0.0, 0.5, 1.0, 1.5, 2.0, 3.14, 10.0, 100.0)
352
+
353
+ # Fixed temporal/UUID/JSONB literals. The schema and operator pool
354
+ # provide the variety; the literal leaf just needs *some* parseable
355
+ # value of the right type.
356
+ _DATE_LIT = "2024-01-01"
357
+ _TIMESTAMPTZ_LIT = "2024-01-01 12:00:00+00"
358
+ _INTERVAL_LIT = "1 day"
359
+ _UUID_LIT = "00000000-0000-0000-0000-000000000000"
360
+ _JSONB_LIT = '{"k":"v"}'
361
+
362
+
363
+ # ---------------------------------------------------------------------------
364
+ # Literal generation
365
+ # ---------------------------------------------------------------------------
366
+
367
def gen_literal(rng: random.Random, t: PgType) -> Literal:
    """Produce a literal of type `t`.

    The first rng draw decides NULL-ness: with probability `_NULL_PROB`
    the result is `Literal(t, None)`. Otherwise the value depends on
    the type:

    * INT4 / INT8        -> one draw from `_INTS`
    * NUMERIC / FLOAT8   -> one draw from `_FLOATS`
    * TEXT / VARCHAR     -> one draw from `_WORDS`
    * BOOL               -> one more `rng.random()` coin flip
    * DATE, TIMESTAMPTZ, INTERVAL, UUID, JSONB -> a fixed string,
      no further rng use
    * anything else      -> typed NULL

    Never raises for an unrecognized type and never returns None.
    """
    if rng.random() < _NULL_PROB:
        return Literal(t, None)

    # Types whose value is drawn from a curated pool. Checked in the
    # order listed; each match costs exactly one `rng.choice`.
    pooled = (
        ((INT4, INT8), _INTS),
        ((NUMERIC, FLOAT8), _FLOATS),
        ((TEXT, VARCHAR), _WORDS),
    )
    for members, pool in pooled:
        if any(t == member for member in members):
            return Literal(t, rng.choice(pool))

    if t == BOOL:
        # A plain coin flip gives True/False equal weight.
        return Literal(t, rng.random() < 0.5)

    # Types with a single fixed, parseable value.
    fixed = (
        (DATE, _DATE_LIT),
        (TIMESTAMPTZ, _TIMESTAMPTZ_LIT),
        (INTERVAL, _INTERVAL_LIT),
        (UUID, _UUID_LIT),
        (JSONB, _JSONB_LIT),
    )
    for fixed_type, text in fixed:
        if t == fixed_type:
            return Literal(t, text)

    # No rule for this type (an array, a future user type, ...):
    # a typed NULL is the catch-all.
    return Literal(t, None)
405
+
406
+
407
+ # ---------------------------------------------------------------------------
408
+ # Expression generation
409
+ # ---------------------------------------------------------------------------
410
+
411
def gen_expr(ctx: GenContext, target_type: PgType) -> Expr:
    """Generate an expression whose value type implicitly casts to
    `target_type`.

    A single weighted `ctx.rng.choices` draw picks one production from
    the candidates available in `ctx`:

    * ``col`` / ``lit`` — leaves. The literal is always a candidate;
      a column ref only when a compatible binding is in scope.
    * ``scalar_sub`` / ``exists`` / ``in_sub`` — require
      `ctx.allow_subquery`, a non-exhausted subquery budget
      (`ctx.at_subquery_leaf()` false) and the matching feature flag;
      EXISTS and IN additionally require `target_type == BOOL`.
    * ``func`` / ``op`` / ``agg`` / ``window`` — recursive. Only
      offered when `ctx.at_leaf()` is false, and downweighted by
      `(depth_remaining / max_expr_depth) ** leaf_bias`.

    The function is total. If the configured weights leave the
    candidate list without a positive total weight (for example
    `literal_weight == 0` at a leaf with no column in scope), a literal
    is returned directly instead of letting `random.choices` raise
    `ValueError`. When the total weight is positive, rng consumption is
    identical to a plain weighted draw.

    Args:
        ctx: generation context (rng, scope, catalog, config, budgets
            and the allow_*/in_* flags).
        target_type: type the produced expression must implicitly cast
            to.

    Returns:
        The generated expression node.

    Raises:
        RuntimeError: if the picked kind matches no dispatch branch —
            unreachable unless a future change breaks the dispatch.
    """
    # Column-ref candidates. This narrowing must happen before any
    # column is picked: inside an aggregate's args within a correlated
    # subquery, only LOCAL bindings are eligible, because an
    # outer-column ref in that position makes PG infer implicit
    # grouping on the OUTER query (error 42803). Outside aggregates,
    # outer-column refs are the whole point of correlation, so the
    # full visible-column walk is used.
    if ctx.in_aggregate and ctx.in_correlated_subquery:
        cols = ctx.scope.local_bindings(of_type=target_type)
    else:
        cols = ctx.scope.visible_columns(of_type=target_type)
    funcs = ctx.catalog.scalar_funcs_returning(target_type)
    ops = ctx.catalog.binary_ops_returning(target_type)
    # Aggregates are only candidates when the surrounding context
    # allows them AND we're not already inside one (no nested aggs).
    aggs = (
        ctx.catalog.aggs_returning(target_type)
        if (ctx.allow_aggregates and not ctx.in_aggregate)
        else []
    )
    # Window-style calls: WINDOW-only functions plus AGGREGATE
    # functions (usable as windows via OVER). Ordered-set aggregates
    # are excluded — they must go through the WITHIN GROUP path in the
    # regular agg branch. The candidate is the FuncSig itself; the
    # WindowSpec is attached at dispatch time.
    window_candidates: list[FuncSig] = []
    if (ctx.allow_window
            and not ctx.in_window
            and FEATURE_WINDOW_FUNCTION in ctx.config.feature_flags):
        window_candidates = [
            sig for sig in ctx.catalog.functions
            if (sig.kind in (FuncKind.WINDOW, FuncKind.AGGREGATE)
                and sig.name not in ORDERED_SET_AGGREGATES
                and implicitly_castable(sig.returns, target_type))
        ]

    # (weight, kind) candidates. DETERMINISM: the append order below is
    # fixed (col, lit, subquery kinds, recursive kinds), which is what
    # makes the weighted draw reproducible for a given seed and ctx.
    candidates: list[tuple[float, str]] = []

    # Leaves.
    if cols:
        candidates.append((ctx.config.column_ref_weight, "col"))
    candidates.append((ctx.config.literal_weight, "lit"))

    # Subqueries count as leaves of the OUTER expression tree: they are
    # rationed by the subquery budget (`at_subquery_leaf()`), not by
    # the expression-depth budget (`at_leaf()`), so they get no
    # leaf_factor downweighting.
    if ctx.allow_subquery and not ctx.at_subquery_leaf():
        flags = ctx.config.feature_flags
        if FEATURE_SCALAR_SUBQUERY in flags:
            candidates.append(
                (ctx.config.scalar_subquery_weight, "scalar_sub")
            )
        # EXISTS and IN produce BOOL exactly, so the gate is a plain
        # target-type equality check.
        if target_type == BOOL:
            if FEATURE_EXISTS in flags:
                candidates.append((ctx.config.exists_weight, "exists"))
            if FEATURE_IN_SUBQUERY in flags:
                candidates.append(
                    (ctx.config.in_subquery_weight, "in_sub")
                )

    # Recursive productions: only while the depth budget allows, and
    # downweighted as it shrinks. `leaf_bias` of 1.0 is linear decay,
    # >1.0 favors leaves more aggressively, <1.0 favors recursion.
    if not ctx.at_leaf():
        max_d = max(1, ctx.config.max_expr_depth)  # avoid /0
        leaf_factor = (ctx.depth_remaining / max_d) ** ctx.config.leaf_bias
        if funcs:
            candidates.append(
                (ctx.config.func_call_weight * leaf_factor, "func")
            )
        if ops:
            candidates.append(
                (ctx.config.binary_op_weight * leaf_factor, "op")
            )
        if aggs:
            candidates.append(
                (ctx.config.aggregate_call_weight * leaf_factor, "agg")
            )
        if window_candidates:
            candidates.append(
                (ctx.config.window_call_weight * leaf_factor, "window")
            )

    weights = [w for w, _ in candidates]
    kinds = [k for _, k in candidates]

    # `random.choices` raises ValueError when the weights do not sum to
    # a positive number. Honor the "always returns an Expr" contract by
    # falling back to the catch-all literal production in that case.
    # The check runs before the draw, so it costs no rng state.
    if not sum(weights) > 0:
        return gen_literal(ctx.rng, target_type)

    kind = ctx.rng.choices(kinds, weights=weights, k=1)[0]

    if kind == "col":
        b = ctx.rng.choice(cols)
        return ColumnRef(b.type, b.table_alias, b.column)

    if kind == "lit":
        return gen_literal(ctx.rng, target_type)

    if kind == "func":
        f = ctx.rng.choice(funcs)
        # Every argument is generated one level below the parent (depth
        # tracks tree depth, not argument count) and explicitly coerced
        # to its declared parameter type so the catalog's claimed
        # return type matches what PG resolves.
        child_ctx = ctx.descend()
        args_list = [
            coerce_to_param_type(gen_expr(child_ctx, arg_t), arg_t)
            for arg_t in f.args
        ]
        # Some functions need their generated args patched for
        # validity (unit keyword, JSON array, non-zero divisor). Lookup
        # is by exact name, so at most one rewriter fires.
        rewriter = _ARG_REWRITERS.get(f.name)
        if rewriter is not None:
            rewriter(args_list, ctx, f)
        return FuncCall(f.returns, f.name, tuple(args_list))

    if kind == "op":
        o = ctx.rng.choice(ops)
        child_ctx = ctx.descend()
        # Internal invariant (also narrows the types for checkers):
        # binary_ops_returning is expected to have filtered out unary
        # operators, which would have a None side.
        assert o.left is not None and o.right is not None
        # Same coercion rationale as the func branch. Left is generated
        # before right; the order is part of the rng contract.
        left = coerce_to_param_type(gen_expr(child_ctx, o.left), o.left)
        right = coerce_to_param_type(gen_expr(child_ctx, o.right), o.right)
        # Suppress the bare `x / 0` and `x % 0` cases. Zeros hidden
        # behind arithmetic are not detected.
        if o.symbol in ("/", "%"):
            right = _replace_literal_zero(right, o.right)
        return BinaryOp(o.returns, o.symbol, left, right)

    if kind == "agg":
        # Both gates (allow_aggregates, not in_aggregate) were applied
        # when `aggs` was built.
        f = ctx.rng.choice(aggs)
        # Ordered-set aggregates (the names in ORDERED_SET_AGGREGATES)
        # must carry a WITHIN GROUP clause, so they take a dedicated
        # construction path.
        if f.name in ORDERED_SET_AGGREGATES:
            return gen_ordered_set_agg(ctx, f)
        # FILTER is decided up-front so it composes with the star form
        # (`count(*) FILTER (WHERE ...)`). The coin is flipped first;
        # the predicate is only generated on success.
        filter_expr = (
            gen_filter_predicate(ctx)
            if should_emit_filter(ctx.rng)
            else None
        )
        # `count(*)` is only valid for count, hence the name-specific
        # gate in front of the coin flip.
        if f.name == "count" and should_emit_count_star(ctx.rng):
            return FuncCall(
                f.returns, "count", (), star=True, filter_=filter_expr,
            )
        # Args are generated under in_aggregate=True, which removes
        # aggregates from the recursive call's candidate pool and so
        # blocks nested forms like sum(count(...)).
        child_ctx = replace(ctx.descend(), in_aggregate=True)
        args = tuple(
            coerce_to_param_type(gen_expr(child_ctx, arg_t), arg_t)
            for arg_t in f.args
        )
        return FuncCall(f.returns, f.name, args, filter_=filter_expr)

    if kind == "window":
        f = ctx.rng.choice(window_candidates)
        # FILTER is only valid on AGGREGATE-kind window calls; the
        # kind check short-circuits so no coin is flipped (and no rng
        # consumed) for pure window functions.
        filter_expr = (
            gen_filter_predicate(ctx)
            if (f.kind == FuncKind.AGGREGATE
                and should_emit_filter(ctx.rng))
            else None
        )
        # `count(*) OVER (...)`: same name-specific gate as the agg
        # path. No args are generated for the star form; the OVER spec
        # is generated against the outer ctx.
        if (f.kind == FuncKind.AGGREGATE
                and f.name == "count"
                and should_emit_count_star(ctx.rng)):
            spec = gen_window_spec(ctx)
            return FuncCall(
                f.returns, "count", (),
                over=spec, star=True, filter_=filter_expr,
            )
        # Window args must contain neither nested windows nor
        # aggregates; the flag resets below cover both.
        arg_ctx = replace(
            ctx.descend(),
            in_window=True,
            allow_window=False,
            allow_aggregates=False,
            in_aggregate=True,  # also blocks aggs in window args
        )
        args = tuple(
            coerce_to_param_type(gen_expr(arg_ctx, arg_t), arg_t)
            for arg_t in f.args
        )
        # The spec is generated AFTER the args and against the outer
        # ctx, whose scope matches the surrounding SELECT list.
        spec = gen_window_spec(ctx)
        return FuncCall(
            f.returns, f.name, args, over=spec, filter_=filter_expr,
        )

    if kind in ("scalar_sub", "exists", "in_sub"):
        # Lazy import: gen/subquery.py imports gen_expr at module top
        # level, so a top-level import here would close the cycle.
        from . import subquery as _sq
        # 50/50 correlated vs uncorrelated, decided per subquery.
        correlated = ctx.rng.random() < 0.5
        if kind == "scalar_sub":
            return _sq.gen_scalar_subquery(
                ctx, target_type, correlated=correlated,
            )
        if kind == "exists":
            return _sq.gen_exists_subquery(ctx, correlated=correlated)
        # kind == "in_sub"
        return _sq.gen_in_subquery(ctx, correlated=correlated)

    # Unreachable: every kind that can be appended above has a dispatch
    # branch. Kept to satisfy type-checkers and to fail loudly if a
    # future change adds a kind without a branch.
    raise RuntimeError(f"no production picked (kinds={kinds})")