snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/_internal/utils/mixins.py +26 -1
  3. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  4. snowflake/ml/data/data_connector.py +2 -2
  5. snowflake/ml/data/data_ingestor.py +2 -1
  6. snowflake/ml/experiment/_experiment_info.py +3 -3
  7. snowflake/ml/feature_store/__init__.py +2 -0
  8. snowflake/ml/feature_store/aggregation.py +367 -0
  9. snowflake/ml/feature_store/feature.py +366 -0
  10. snowflake/ml/feature_store/feature_store.py +234 -20
  11. snowflake/ml/feature_store/feature_view.py +189 -4
  12. snowflake/ml/feature_store/metadata_manager.py +425 -0
  13. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  14. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  15. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  16. snowflake/ml/jobs/_interop/protocols.py +124 -7
  17. snowflake/ml/jobs/_interop/utils.py +92 -33
  18. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  19. snowflake/ml/jobs/_utils/constants.py +4 -0
  20. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  21. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  22. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  23. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  24. snowflake/ml/jobs/decorators.py +17 -22
  25. snowflake/ml/jobs/job.py +25 -10
  26. snowflake/ml/jobs/job_definition.py +100 -8
  27. snowflake/ml/model/__init__.py +4 -0
  28. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  29. snowflake/ml/model/_client/model/model_version_impl.py +56 -28
  30. snowflake/ml/model/_client/ops/model_ops.py +2 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +6 -11
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  34. snowflake/ml/model/_client/sql/service.py +21 -29
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
  36. snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
  37. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
  38. snowflake/ml/model/_signatures/utils.py +76 -1
  39. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  40. snowflake/ml/model/openai_signatures.py +154 -0
  41. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  42. snowflake/ml/version.py +1 -1
  43. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
  44. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
  45. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  46. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  47. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  48. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/tile_sql_generator.py
@@ -0,0 +1,1079 @@
+ """SQL generators for tile-based aggregations.
+
+ This module provides SQL generation for:
+ 1. TilingSqlGenerator: Creates the Dynamic Table (DT) query that computes partial aggregations (tiles)
+ 2. MergingSqlGenerator: Creates the CTEs for merging tiles during dataset generation
+ """
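For orientation, a minimal usage sketch of the two generators added here. Constructor arguments are taken from the signatures below; `specs`, the table names, and column names are placeholders, not values from this release:

    tiling = TilingSqlGenerator(
        source_query="SELECT * FROM RAW_EVENTS",  # placeholder source
        join_keys=["USER_ID"],
        timestamp_col="EVENT_TS",
        feature_granularity="1h",
        features=specs,  # list[AggregationSpec]
    )
    tile_table_sql = tiling.generate()  # query body for the tile Dynamic Table

    merging = MergingSqlGenerator(
        tile_table="DB.SCHEMA.MY_TILES",  # placeholder tile table
        join_keys=["USER_ID"],
        timestamp_col="EVENT_TS",
        feature_granularity="1h",
        features=specs,
        spine_timestamp_col="query_ts",
        fv_index=0,
    )
    ctes = merging.generate_all_ctes()  # [(cte_name, cte_body), ...]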
+
+ from __future__ import annotations
+
+ from snowflake.ml.feature_store.aggregation import (
+     AggregationSpec,
+     AggregationType,
+     interval_to_seconds,
+     parse_interval,
+ )
+
+ # Maximum number of elements to store in array columns to avoid 128MB limit
+ # Assuming ~1KB per value, 100,000 values ≈ 100MB (leaving buffer)
+ _MAX_ARRAY_ELEMENTS = 100000
+
+
+ class TilingSqlGenerator:
+     """Generates SQL for creating tile Dynamic Tables.
+
+     The tiling query:
+     1. Computes TIME_SLICE to bucket rows into tiles
+     2. Computes partial aggregations per (join_keys, tile_start)
+     3. For simple aggregations: stores SUM/COUNT as scalars
+     4. For list aggregations: stores pre-sorted arrays (ARRAY_AGG with ORDER BY)
+     """
+
+     def __init__(
+         self,
+         source_query: str,
+         join_keys: list[str],
+         timestamp_col: str,
+         feature_granularity: str,
+         features: list[AggregationSpec],
+     ) -> None:
+         """Initialize the TilingSqlGenerator.
+
+         Args:
+             source_query: The source query providing raw event data.
+             join_keys: List of column names used as join keys (from entities).
+             timestamp_col: The timestamp column name.
+             feature_granularity: The tile interval (e.g., "1h", "1d").
+             features: List of aggregation specifications.
+         """
+         self._source_query = source_query
+         self._join_keys = join_keys
+         self._timestamp_col = timestamp_col
+         self._feature_granularity = feature_granularity
+         self._features = features
+
+         # Parse interval for SQL generation
+         self._interval_value, self._interval_unit = parse_interval(feature_granularity)
+
+         # Track if we have any lifetime features (need cumulative columns)
+         self._has_lifetime_features = any(f.is_lifetime() for f in features)
+
+     def generate(self) -> str:
+         """Generate the complete tiling SQL query.
+
+         Returns:
+             SQL query for creating the tile Dynamic Table.
+         """
+         tile_columns = self._generate_tile_columns()
+         # Join keys and timestamp_col are already properly formatted by SqlIdentifier
+         join_keys_str = ", ".join(self._join_keys)
+         ts_col = self._timestamp_col
+
+         if not self._has_lifetime_features:
+             # Simple case: no lifetime features, just partial aggregations
+             query = f"""
+             SELECT
+                 {join_keys_str},
+                 TIME_SLICE({ts_col}, {self._interval_value}, '{self._interval_unit}', 'START') AS TILE_START,
+                 {', '.join(tile_columns)}
+             FROM ({self._source_query})
+             GROUP BY {join_keys_str}, TILE_START
+             """
+         else:
+             # With lifetime features: need cumulative columns via window functions
+             # Structure: SELECT *, cumulative_columns FROM (SELECT partial_columns GROUP BY)
+             cumulative_columns = self._generate_cumulative_columns()
+
+             query = f"""
+             SELECT
+                 base.*,
+                 {', '.join(cumulative_columns)}
+             FROM (
+                 SELECT
+                     {join_keys_str},
+                     TIME_SLICE({ts_col}, {self._interval_value}, '{self._interval_unit}', 'START') AS TILE_START,
+                     {', '.join(tile_columns)}
+                 FROM ({self._source_query})
+                 GROUP BY {join_keys_str}, TILE_START
+             ) base
+             """
+         return query.strip()
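For a single non-lifetime SUM(AMOUNT) feature at feature_granularity="1h", the first branch above renders roughly the following (assuming `parse_interval("1h")` returns `(1, 'HOUR')` and that `get_tile_column_name` follows the `_PARTIAL_SUM_{col}` pattern described in `_generate_tile_columns`; both are assumptions about code not shown in this diff):

    SELECT
        USER_ID,
        TIME_SLICE(EVENT_TS, 1, 'HOUR', 'START') AS TILE_START,
        SUM(AMOUNT) AS _PARTIAL_SUM_AMOUNT
    FROM (SELECT * FROM RAW_EVENTS)
    GROUP BY USER_ID, TILE_START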
+
+     def _generate_tile_columns(self) -> list[str]:
+         """Generate the tile column expressions for all features.
+
+         All simple aggregations share base partial columns:
+         - _PARTIAL_SUM_{col}: SUM(col) - used by SUM, AVG, STD, VAR
+         - _PARTIAL_COUNT_{col}: COUNT(col) - used by COUNT, AVG, STD, VAR
+         - _PARTIAL_SUM_SQ_{col}: SUM(col*col) - used by STD, VAR
+
+         This allows maximum column reuse. For example, SUM(amount) + AVG(amount)
+         only creates 2 columns, not 3.
+
+         Returns:
+             List of SQL column expressions for the tile table.
+         """
+         # Track unique tile columns by their full name to avoid duplicates
+         seen_columns: set[str] = set()
+         columns = []
+         ts_col = self._timestamp_col
+
+         for spec in self._features:
+             src_col = spec.source_column
+
+             if spec.function == AggregationType.SUM:
+                 # SUM needs _PARTIAL_SUM
+                 col_name = spec.get_tile_column_name("SUM")
+                 if col_name not in seen_columns:
+                     columns.append(f"SUM({src_col}) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.COUNT:
+                 # COUNT needs _PARTIAL_COUNT
+                 col_name = spec.get_tile_column_name("COUNT")
+                 if col_name not in seen_columns:
+                     columns.append(f"COUNT({src_col}) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.AVG:
+                 # AVG needs _PARTIAL_SUM and _PARTIAL_COUNT
+                 sum_col = spec.get_tile_column_name("SUM")
+                 count_col = spec.get_tile_column_name("COUNT")
+                 if sum_col not in seen_columns:
+                     columns.append(f"SUM({src_col}) AS {sum_col}")
+                     seen_columns.add(sum_col)
+                 if count_col not in seen_columns:
+                     columns.append(f"COUNT({src_col}) AS {count_col}")
+                     seen_columns.add(count_col)
+
+             elif spec.function == AggregationType.MIN:
+                 # MIN needs _PARTIAL_MIN
+                 col_name = spec.get_tile_column_name("MIN")
+                 if col_name not in seen_columns:
+                     columns.append(f"MIN({src_col}) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.MAX:
+                 # MAX needs _PARTIAL_MAX
+                 col_name = spec.get_tile_column_name("MAX")
+                 if col_name not in seen_columns:
+                     columns.append(f"MAX({src_col}) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function in (AggregationType.STD, AggregationType.VAR):
+                 # STD/VAR need _PARTIAL_SUM, _PARTIAL_COUNT, and _PARTIAL_SUM_SQ
+                 sum_col = spec.get_tile_column_name("SUM")
+                 count_col = spec.get_tile_column_name("COUNT")
+                 sum_sq_col = spec.get_tile_column_name("SUM_SQ")
+                 if sum_col not in seen_columns:
+                     columns.append(f"SUM({src_col}) AS {sum_col}")
+                     seen_columns.add(sum_col)
+                 if count_col not in seen_columns:
+                     columns.append(f"COUNT({src_col}) AS {count_col}")
+                     seen_columns.add(count_col)
+                 if sum_sq_col not in seen_columns:
+                     columns.append(f"SUM({src_col} * {src_col}) AS {sum_sq_col}")
+                     seen_columns.add(sum_sq_col)
+
+             elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
+                 # APPROX_COUNT_DISTINCT uses HLL (HyperLogLog) state
+                 col_name = spec.get_tile_column_name("HLL")
+                 if col_name not in seen_columns:
+                     columns.append(f"HLL_EXPORT(HLL_ACCUMULATE({src_col})) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.APPROX_PERCENTILE:
+                 # APPROX_PERCENTILE uses T-Digest state
+                 col_name = spec.get_tile_column_name("TDIGEST")
+                 if col_name not in seen_columns:
+                     columns.append(f"APPROX_PERCENTILE_ACCUMULATE({src_col}) AS {col_name}")
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.LAST_N:
+                 col_name = spec.get_tile_column_name("LAST")
+                 if col_name not in seen_columns:
+                     columns.append(
+                         f"ARRAY_SLICE("
+                         f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} DESC), "
+                         f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
+                     )
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.LAST_DISTINCT_N:
+                 # Uses same tile column as LAST_N (dedup happens at merge time)
+                 col_name = spec.get_tile_column_name("LAST")
+                 if col_name not in seen_columns:
+                     columns.append(
+                         f"ARRAY_SLICE("
+                         f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} DESC), "
+                         f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
+                     )
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.FIRST_N:
+                 col_name = spec.get_tile_column_name("FIRST")
+                 if col_name not in seen_columns:
+                     columns.append(
+                         f"ARRAY_SLICE("
+                         f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} ASC), "
+                         f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
+                     )
+                     seen_columns.add(col_name)
+
+             elif spec.function == AggregationType.FIRST_DISTINCT_N:
+                 # Uses same tile column as FIRST_N (dedup happens at merge time)
+                 col_name = spec.get_tile_column_name("FIRST")
+                 if col_name not in seen_columns:
+                     columns.append(
+                         f"ARRAY_SLICE("
+                         f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} ASC), "
+                         f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
+                     )
+                     seen_columns.add(col_name)
+
+         return columns
+
+     def _generate_cumulative_columns(self) -> list[str]:
+         """Generate cumulative column expressions for lifetime aggregations.
+
+         These columns use window functions to compute running totals per entity.
+         They are computed over the partial columns from the inner GROUP BY.
+
+         Returns:
+             List of SQL column expressions for cumulative columns.
+         """
+         seen_columns: set[str] = set()
+         columns = []
+         join_keys_str = ", ".join(self._join_keys)
+
+         # Only process lifetime features
+         lifetime_features = [f for f in self._features if f.is_lifetime()]
+
+         for spec in lifetime_features:
+             if spec.function == AggregationType.SUM:
+                 # Cumulative SUM
+                 partial_col = spec.get_tile_column_name("SUM")
+                 cum_col = spec.get_cumulative_column_name("SUM")
+                 if cum_col not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_col}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_col}"
+                     )
+                     seen_columns.add(cum_col)
+
+             elif spec.function == AggregationType.COUNT:
+                 # Cumulative COUNT
+                 partial_col = spec.get_tile_column_name("COUNT")
+                 cum_col = spec.get_cumulative_column_name("COUNT")
+                 if cum_col not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_col}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_col}"
+                     )
+                     seen_columns.add(cum_col)
+
+             elif spec.function == AggregationType.AVG:
+                 # Cumulative AVG needs cumulative SUM and COUNT
+                 partial_sum = spec.get_tile_column_name("SUM")
+                 partial_count = spec.get_tile_column_name("COUNT")
+                 cum_sum = spec.get_cumulative_column_name("SUM")
+                 cum_count = spec.get_cumulative_column_name("COUNT")
+
+                 if cum_sum not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_sum}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_sum}"
+                     )
+                     seen_columns.add(cum_sum)
+                 if cum_count not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_count}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_count}"
+                     )
+                     seen_columns.add(cum_count)
+
+             elif spec.function == AggregationType.MIN:
+                 # Cumulative MIN (running minimum)
+                 partial_col = spec.get_tile_column_name("MIN")
+                 cum_col = spec.get_cumulative_column_name("MIN")
+                 if cum_col not in seen_columns:
+                     columns.append(
+                         f"MIN({partial_col}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_col}"
+                     )
+                     seen_columns.add(cum_col)
+
+             elif spec.function == AggregationType.MAX:
+                 # Cumulative MAX (running maximum)
+                 partial_col = spec.get_tile_column_name("MAX")
+                 cum_col = spec.get_cumulative_column_name("MAX")
+                 if cum_col not in seen_columns:
+                     columns.append(
+                         f"MAX({partial_col}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_col}"
+                     )
+                     seen_columns.add(cum_col)
+
+             elif spec.function in (AggregationType.STD, AggregationType.VAR):
+                 # Cumulative STD/VAR needs cumulative SUM, COUNT, and SUM_SQ
+                 partial_sum = spec.get_tile_column_name("SUM")
+                 partial_count = spec.get_tile_column_name("COUNT")
+                 partial_sum_sq = spec.get_tile_column_name("SUM_SQ")
+                 cum_sum = spec.get_cumulative_column_name("SUM")
+                 cum_count = spec.get_cumulative_column_name("COUNT")
+                 cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")
+
+                 if cum_sum not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_sum}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_sum}"
+                     )
+                     seen_columns.add(cum_sum)
+                 if cum_count not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_count}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_count}"
+                     )
+                     seen_columns.add(cum_count)
+                 if cum_sum_sq not in seen_columns:
+                     columns.append(
+                         f"SUM({partial_sum_sq}) OVER ("
+                         f"PARTITION BY {join_keys_str} "
+                         f"ORDER BY TILE_START "
+                         f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
+                         f") AS {cum_sum_sq}"
+                     )
+                     seen_columns.add(cum_sum_sq)
+
+         # Note: APPROX_COUNT_DISTINCT (HLL_COMBINE) and APPROX_PERCENTILE (APPROX_PERCENTILE_COMBINE)
+         # do NOT support cumulative window frames in Snowflake.
+         # These will be handled at merge time by aggregating all tiles.
+
+         # Note: FIRST_N, FIRST_DISTINCT_N, LAST_N, LAST_DISTINCT_N lifetime
+         # are handled at merge time by scanning tiles, not via cumulative columns
+
+         return columns
+
+
+ class MergingSqlGenerator:
+     """Generates CTEs for merging tiles during dataset generation.
+
+     The merging process (see generate_all_ctes for the complete CTE flow,
+     including the boundary-deduplication CTEs):
+     1. TILES_JOINED_FVi: Join tiles with spine, filtering by window and complete tiles only
+     2. SIMPLE_MERGED_FVi: Aggregate simple features (SUM, COUNT, AVG)
+     3. LIST_MERGED_FVi: Flatten and aggregate list features (LAST_N, etc.)
+     4. FVi: Combine simple and list results
+     """
+
+     def __init__(
+         self,
+         tile_table: str,
+         join_keys: list[str],
+         timestamp_col: str,
+         feature_granularity: str,
+         features: list[AggregationSpec],
+         spine_timestamp_col: str,
+         fv_index: int,
+     ) -> None:
+         """Initialize the MergingSqlGenerator.
+
+         Args:
+             tile_table: Fully qualified name of the tile table.
+             join_keys: List of join key column names.
+             timestamp_col: The timestamp column from the feature view.
+             feature_granularity: The tile interval.
+             features: List of aggregation specifications.
+             spine_timestamp_col: The timestamp column from the spine.
+             fv_index: Index of this feature view (for CTE naming).
+         """
+         self._tile_table = tile_table
+         self._join_keys = join_keys
+         self._timestamp_col = timestamp_col
+         self._feature_granularity = feature_granularity
+         self._features = features
+         self._spine_timestamp_col = spine_timestamp_col
+         self._fv_index = fv_index
+
+         # Separate lifetime from non-lifetime features
+         self._lifetime_features = [f for f in features if f.is_lifetime()]
+         self._non_lifetime_features = [f for f in features if not f.is_lifetime()]
+
+         # Separate non-lifetime features by type (simple vs list)
+         self._simple_features = [f for f in self._non_lifetime_features if f.function.is_simple()]
+         self._list_features = [f for f in self._non_lifetime_features if f.function.is_list()]
+
+         # Lifetime features are all simple (validation ensures only SUM, COUNT, AVG, MIN, MAX, STD, VAR)
+         self._lifetime_simple_features = self._lifetime_features
+
+         # Parse interval
+         self._interval_value, self._interval_unit = parse_interval(feature_granularity)
+         self._interval_seconds = interval_to_seconds(feature_granularity)
+
+         # Calculate max window in tiles for filtering (only for non-lifetime features)
+         if self._non_lifetime_features:
+             # Max tiles needed is the max of (window + offset) across all non-lifetime features
+             max_lookback_seconds = max(
+                 f.get_window_seconds() + f.get_offset_seconds() for f in self._non_lifetime_features
+             )
+             self._max_tiles_needed = (max_lookback_seconds + self._interval_seconds - 1) // self._interval_seconds
+         else:
+             self._max_tiles_needed = 0
+
+     def generate_all_ctes(self) -> list[tuple[str, str]]:
+         """Generate all CTEs needed for this feature view.
+
+         The optimization flow:
+         1. SPINE_BOUNDARY: Add truncated tile boundary to spine
+         2. UNIQUE_BOUNDS: Get distinct (entity, boundary) pairs
+         3. TILES_JOINED: Join tiles to unique boundaries (for non-lifetime features)
+         4. SIMPLE_MERGED: Aggregate simple non-lifetime features per boundary
+         5. LIST_MERGED: Aggregate list non-lifetime features per boundary
+         6. LIFETIME_MERGED: ASOF JOIN for lifetime features (O(1) per boundary)
+         7. LIFETIME_LIST_MERGED: Scan tiles for lifetime list features (currently never
+            emitted: validation restricts lifetime aggregations to simple functions)
+         8. FV: Join back to spine to expand results
+
+         Returns:
+             List of (cte_name, cte_body) tuples.
+         """
+         ctes = []
+
+         # CTE 1: Spine with tile boundary (for join-back)
+         ctes.append(self._generate_spine_boundary_cte())
+
+         # CTE 2: Unique boundaries (optimization - reduce aggregation work)
+         ctes.append(self._generate_unique_boundaries_cte())
+
+         # CTE 3: Join tiles with unique boundaries (only if we have non-lifetime features)
+         if self._non_lifetime_features:
+             ctes.append(self._generate_tiles_joined_cte())
+
+         # CTE 4: Simple aggregations (if any non-lifetime)
+         if self._simple_features:
+             ctes.append(self._generate_simple_merged_cte())
+
+         # CTE 5: List aggregations (if any non-lifetime)
+         if self._list_features:
+             ctes.append(self._generate_list_merged_cte())
+
+         # CTE 6: Lifetime simple aggregations (using ASOF JOIN on cumulative columns)
+         if self._lifetime_simple_features:
+             ctes.append(self._generate_lifetime_merged_cte())
+
+         # CTE 7: Combine all results and join back to spine
+         ctes.append(self._generate_combined_cte())
+
+         return ctes
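The caller is expected to splice these (name, body) pairs into a WITH clause; a hypothetical assembly sketch (the generated bodies reference a CTE named SPINE directly, and with fv_index=0 the final CTE is FV000):

    parts = ["SPINE AS (" + spine_sql + ")"]  # spine_sql: caller-provided spine query
    parts += [f"{name} AS ({body})" for name, body in merging.generate_all_ctes()]
    final_query = "WITH " + ",\n".join(parts) + "\nSELECT * FROM FV000"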
+
+     def _generate_spine_boundary_cte(self) -> tuple[str, str]:
+         """Generate CTE that adds tile boundary to deduplicated spine.
+
+         The tile boundary is the truncated timestamp that determines which
+         complete tiles are visible. All spine rows with the same boundary
+         will have identical feature values.
+
+         Returns:
+             Tuple of (cte_name, cte_body).
+         """
+         cte_name = f"SPINE_BOUNDARY_FV{self._fv_index}"
+
+         # Quote column names to preserve case-sensitivity from spine dataframe
+         # The spine_timestamp_col is passed as-is from the user (e.g., "query_ts")
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         quoted_spine_ts = f'"{self._spine_timestamp_col}"'
+
+         # Select all columns plus the tile boundary
+         all_cols = quoted_join_keys + [quoted_spine_ts]
+         select_cols = ", ".join(all_cols)
+
+         # DATE_TRUNC to tile granularity gives us the boundary
+         # All timestamps in the same granularity window see the same complete tiles
+         cte_body = f"""
+         SELECT DISTINCT {select_cols},
+             DATE_TRUNC('{self._interval_unit.lower()}', {quoted_spine_ts}) AS TILE_BOUNDARY
+         FROM SPINE
+         """
+         return cte_name, cte_body.strip()
+
+     def _generate_unique_boundaries_cte(self) -> tuple[str, str]:
+         """Generate CTE with unique (entity, boundary) pairs.
+
+         This is the key optimization: instead of computing features for each
+         spine row, we compute once per unique boundary and join back.
+
+         Returns:
+             Tuple of (cte_name, cte_body).
+         """
+         cte_name = f"UNIQUE_BOUNDS_FV{self._fv_index}"
+
+         # Quote column names to handle case-sensitivity
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         keys_str = ", ".join(quoted_join_keys)
+
+         cte_body = f"""
+         SELECT DISTINCT {keys_str}, TILE_BOUNDARY
+         FROM SPINE_BOUNDARY_FV{self._fv_index}
+         """
+         return cte_name, cte_body.strip()
+
+     def _generate_tiles_joined_cte(self) -> tuple[str, str]:
+         """Generate the CTE that joins tiles with unique boundaries.
+
+         This joins tiles to UNIQUE_BOUNDS (not full spine), which is much smaller
+         when there are many spine rows per tile boundary.
+
+         Returns:
+             Tuple of (cte_name, cte_body) for the tiles joined CTE.
+         """
+         cte_name = f"TILES_JOINED_FV{self._fv_index}"
+
+         # Quote column names for spine columns (case-sensitive)
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+
+         # Tile table column names match the join keys (already SqlIdentifier-formatted)
+         tile_join_keys = list(self._join_keys)
+         tile_keys_str = ", ".join(tile_join_keys)
+
+         # Join conditions: quoted spine columns to uppercase tile columns
+         join_conditions = [f"UB.{qk} = TILES.{tk}" for qk, tk in zip(quoted_join_keys, tile_join_keys)]
+
+         # Get all tile columns we need (deduplicated)
+         tile_columns_set: set[str] = set()
+         for spec in self._features:
+             if spec.function == AggregationType.SUM:
+                 tile_columns_set.add(spec.get_tile_column_name("SUM"))
+             elif spec.function == AggregationType.COUNT:
+                 tile_columns_set.add(spec.get_tile_column_name("COUNT"))
+             elif spec.function == AggregationType.AVG:
+                 tile_columns_set.add(spec.get_tile_column_name("SUM"))
+                 tile_columns_set.add(spec.get_tile_column_name("COUNT"))
+             elif spec.function == AggregationType.MIN:
+                 tile_columns_set.add(spec.get_tile_column_name("MIN"))
+             elif spec.function == AggregationType.MAX:
+                 tile_columns_set.add(spec.get_tile_column_name("MAX"))
+             elif spec.function in (AggregationType.STD, AggregationType.VAR):
+                 tile_columns_set.add(spec.get_tile_column_name("SUM"))
+                 tile_columns_set.add(spec.get_tile_column_name("COUNT"))
+                 tile_columns_set.add(spec.get_tile_column_name("SUM_SQ"))
+             elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
+                 tile_columns_set.add(spec.get_tile_column_name("HLL"))
+             elif spec.function == AggregationType.APPROX_PERCENTILE:
+                 tile_columns_set.add(spec.get_tile_column_name("TDIGEST"))
+             elif spec.function in (AggregationType.LAST_N, AggregationType.LAST_DISTINCT_N):
+                 tile_columns_set.add(spec.get_tile_column_name("LAST"))
+             elif spec.function in (AggregationType.FIRST_N, AggregationType.FIRST_DISTINCT_N):
+                 tile_columns_set.add(spec.get_tile_column_name("FIRST"))
+         tile_columns = sorted(tile_columns_set)  # Sort for deterministic output
+
+         tile_columns_str = ", ".join(f"TILES.{col}" for col in tile_columns)
+
+         # Window filter: only include tiles within the max window and complete tiles
+         # Complete tiles: tile_end <= tile_boundary (not spine timestamp)
+         # tile_end = DATEADD(interval_unit, interval_value, tile_start)
+         cte_body = f"""
+         SELECT
+             UB.*,
+             TILES.TILE_START,
+             {tile_columns_str}
+         FROM UNIQUE_BOUNDS_FV{self._fv_index} UB
+         LEFT JOIN (
+             SELECT {tile_keys_str}, TILE_START, {', '.join(tile_columns)}
+             FROM {self._tile_table}
+         ) TILES
+         ON {' AND '.join(join_conditions)}
+         -- Window filter: tiles within max window from tile boundary
+         AND TILES.TILE_START >= DATEADD(
+             {self._interval_unit}, -{self._max_tiles_needed * self._interval_value}, UB.TILE_BOUNDARY
+         )
+         -- Complete tiles only: tile_end <= tile_boundary
+         AND DATEADD({self._interval_unit}, {self._interval_value}, TILES.TILE_START) <= UB.TILE_BOUNDARY
+         """
+         return cte_name, cte_body.strip()
+
+     def _get_tile_filter_condition(self, spec: AggregationSpec) -> str:
+         """Generate the CASE WHEN condition for filtering tiles by window and offset.
+
+         For a feature with window W and offset O, we want tiles where:
+         - TILE_START >= TILE_BOUNDARY - W - O (start of window)
+         - TILE_START < TILE_BOUNDARY - O (end of window, before offset)
+
+         Args:
+             spec: The aggregation specification.
+
+         Returns:
+             SQL condition string for use in CASE WHEN.
+         """
+         window_tiles = (spec.get_window_seconds() + self._interval_seconds - 1) // self._interval_seconds
+         offset_tiles = spec.get_offset_seconds() // self._interval_seconds
+
+         if offset_tiles == 0:
+             # No offset: just filter by window start
+             return (
+                 f"TILE_START >= DATEADD({self._interval_unit}, "
+                 f"-{window_tiles * self._interval_value}, TILE_BOUNDARY)"
+             )
+         else:
+             # With offset: filter by both window start and end (shifted by offset)
+             window_start_tiles = window_tiles + offset_tiles
+             return (
+                 f"TILE_START >= DATEADD({self._interval_unit}, "
+                 f"-{window_start_tiles * self._interval_value}, TILE_BOUNDARY) "
+                 f"AND TILE_START < DATEADD({self._interval_unit}, "
+                 f"-{offset_tiles * self._interval_value}, TILE_BOUNDARY)"
+             )
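A worked instance of the tile arithmetic above, assuming `parse_interval("1h")` yields `(1, 'HOUR')` (an assumption; `parse_interval` is not shown in this diff): for a 24h window with a 2h offset, `window_tiles = 24`, `offset_tiles = 2`, so `window_start_tiles = 26` and the emitted condition is:

    TILE_START >= DATEADD(HOUR, -26, TILE_BOUNDARY)
    AND TILE_START < DATEADD(HOUR, -2, TILE_BOUNDARY)

That is, the 24 tiles covering the window that ends 2 hours before the boundary.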
+
+     def _get_list_tile_filter_condition(self, spec: AggregationSpec) -> str:
+         """Generate filter condition for list aggregations (with table prefix).
+
+         Similar to _get_tile_filter_condition but uses t. prefix for table references.
+
+         Args:
+             spec: The aggregation specification.
+
+         Returns:
+             SQL condition string for use in WHERE clause.
+         """
+         window_tiles = (spec.get_window_seconds() + self._interval_seconds - 1) // self._interval_seconds
+         offset_tiles = spec.get_offset_seconds() // self._interval_seconds
+
+         if offset_tiles == 0:
+             # No offset: just filter by window start
+             return (
+                 f"t.TILE_START >= DATEADD({self._interval_unit}, "
+                 f"-{window_tiles * self._interval_value}, t.TILE_BOUNDARY)"
+             )
+         else:
+             # With offset: filter by both window start and end (shifted by offset)
+             window_start_tiles = window_tiles + offset_tiles
+             return (
+                 f"t.TILE_START >= DATEADD({self._interval_unit}, "
+                 f"-{window_start_tiles * self._interval_value}, t.TILE_BOUNDARY) "
+                 f"AND t.TILE_START < DATEADD({self._interval_unit}, "
+                 f"-{offset_tiles * self._interval_value}, t.TILE_BOUNDARY)"
+             )
+
+     def _generate_simple_merged_cte(self) -> tuple[str, str]:
+         """Generate the CTE for simple aggregations (SUM, COUNT, AVG).
+
+         Groups by entity keys + TILE_BOUNDARY (not spine timestamp) for efficiency.
+
+         Returns:
+             Tuple of (cte_name, cte_body) for the simple merged CTE.
+         """
+         cte_name = f"SIMPLE_MERGED_FV{self._fv_index}"
+
+         # Quote column names for case-sensitivity (from UNIQUE_BOUNDS which inherits from spine)
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         # Group by entity keys + TILE_BOUNDARY (optimization)
+         group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]
+         group_by_str = ", ".join(group_by_cols)
+
+         agg_columns = []
+         for spec in self._simple_features:
+             output_col = spec.get_sql_column_name()
+             tile_filter = self._get_tile_filter_condition(spec)
+
+             if spec.function == AggregationType.SUM:
+                 col_name = spec.get_tile_column_name("SUM")
+                 agg_columns.append(f"SUM(CASE WHEN {tile_filter} THEN {col_name} ELSE 0 END) AS {output_col}")
+
+             elif spec.function == AggregationType.COUNT:
+                 col_name = spec.get_tile_column_name("COUNT")
+                 agg_columns.append(f"SUM(CASE WHEN {tile_filter} THEN {col_name} ELSE 0 END) AS {output_col}")
+
+             elif spec.function == AggregationType.MIN:
+                 col_name = spec.get_tile_column_name("MIN")
+                 agg_columns.append(f"MIN(CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END) AS {output_col}")
+
+             elif spec.function == AggregationType.MAX:
+                 col_name = spec.get_tile_column_name("MAX")
+                 agg_columns.append(f"MAX(CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END) AS {output_col}")
+
+             elif spec.function == AggregationType.AVG:
+                 # AVG = SUM(partial_sums) / SUM(partial_counts)
+                 sum_col = spec.get_tile_column_name("SUM")
+                 count_col = spec.get_tile_column_name("COUNT")
+                 agg_columns.append(
+                     f"CASE WHEN SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END) > 0 "
+                     f"THEN SUM(CASE WHEN {tile_filter} "
+                     f"THEN {sum_col} ELSE 0 END) / "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END) "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.VAR:
+                 # VAR = (SUM_SQ / COUNT) - (SUM / COUNT)^2
+                 # Using parallel variance formula
+                 # GREATEST(0, ...) clamps to non-negative to handle floating-point errors
+                 sum_col = spec.get_tile_column_name("SUM")
+                 count_col = spec.get_tile_column_name("COUNT")
+                 sum_sq_col = spec.get_tile_column_name("SUM_SQ")
+                 agg_columns.append(
+                     f"CASE WHEN SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END) > 0 "
+                     f"THEN GREATEST(0, ("
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {sum_sq_col} ELSE 0 END) / "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END)"
+                     f") - POWER("
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {sum_col} ELSE 0 END) / "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END), 2)) "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.STD:
+                 # STD = SQRT(VAR) = SQRT((SUM_SQ / COUNT) - (SUM / COUNT)^2)
+                 # GREATEST(0, ...) clamps variance to non-negative to handle floating-point errors
+                 # that can cause sqrt of tiny negative numbers like -4.54747e-13
+                 sum_col = spec.get_tile_column_name("SUM")
+                 count_col = spec.get_tile_column_name("COUNT")
+                 sum_sq_col = spec.get_tile_column_name("SUM_SQ")
+                 agg_columns.append(
+                     f"CASE WHEN SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END) > 0 "
+                     f"THEN SQRT(GREATEST(0, "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {sum_sq_col} ELSE 0 END) / "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END)"
+                     f" - POWER("
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {sum_col} ELSE 0 END) / "
+                     f"SUM(CASE WHEN {tile_filter} "
+                     f"THEN {count_col} ELSE 0 END), 2))) "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
+                 # Combine HLL states and estimate count
+                 # HLL_ESTIMATE(HLL_COMBINE(HLL_IMPORT(state))) gives the approximate count
+                 col_name = spec.get_tile_column_name("HLL")
+                 agg_columns.append(
+                     f"HLL_ESTIMATE(HLL_COMBINE("
+                     f"CASE WHEN {tile_filter} THEN HLL_IMPORT({col_name}) ELSE NULL END"
+                     f")) AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.APPROX_PERCENTILE:
+                 # Combine T-Digest states and estimate percentile
+                 # APPROX_PERCENTILE_ESTIMATE(APPROX_PERCENTILE_COMBINE(state), percentile)
+                 col_name = spec.get_tile_column_name("TDIGEST")
+                 percentile = spec.params.get("percentile", 0.5)
+                 agg_columns.append(
+                     f"APPROX_PERCENTILE_ESTIMATE(APPROX_PERCENTILE_COMBINE("
+                     f"CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END"
+                     f"), {percentile}) AS {output_col}"
+                 )
+
+         cte_body = f"""
+         SELECT
+             {group_by_str},
+             {', '.join(agg_columns)}
+         FROM TILES_JOINED_FV{self._fv_index}
+         GROUP BY {group_by_str}
+         """
+         return cte_name, cte_body.strip()
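The VAR/STD expressions rely on the identity Var(X) = E[X²] − E[X]², which is what makes per-tile (SUM, COUNT, SUM_SQ) triples exactly mergeable across tiles; a quick sanity check of the identity for population variance:

    xs = [2.0, 4.0, 4.0, 6.0]
    n, s, sq = len(xs), sum(xs), sum(x * x for x in xs)
    mean = s / n
    # E[X^2] - E[X]^2 equals the directly computed population variance (here, 2.0)
    assert abs((sq / n - mean ** 2) - sum((x - mean) ** 2 for x in xs) / n) < 1e-12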
+
+     def _generate_list_merged_cte(self) -> tuple[str, str]:
+         """Generate the CTE for list aggregations using LATERAL FLATTEN."""
+         cte_name = f"LIST_MERGED_FV{self._fv_index}"
+
+         # Generate a more efficient single-pass CTE
+         # Each list feature gets its own lateral flatten in the FROM clause
+         cte_body = self._generate_list_cte_body()
+
+         return cte_name, cte_body.strip()
+
+     def _generate_lifetime_merged_cte(self) -> tuple[str, str]:
+         """Generate the CTE for lifetime simple aggregations using ASOF JOIN.
+
+         Uses ASOF JOIN on cumulative columns for O(1) lookup per boundary.
+         This is much faster than aggregating all tiles from the beginning of time.
+
+         Returns:
+             Tuple of (cte_name, cte_body) for the lifetime merged CTE.
+         """
+         cte_name = f"LIFETIME_MERGED_FV{self._fv_index}"
+
+         # Quote column names for case-sensitivity
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]
+
+         # Tile table column names match the join keys (already SqlIdentifier-formatted)
+         tile_join_keys = list(self._join_keys)
+
+         # Build ASOF JOIN condition: match the most recent tile before the boundary
+         asof_match = "UB.TILE_BOUNDARY > TILES.TILE_START"
+         join_conditions = [f"UB.{qk} = TILES.{tk}" for qk, tk in zip(quoted_join_keys, tile_join_keys)]
+         join_conditions_str = " AND ".join(join_conditions)
+
+         # Build select columns for lifetime features
+         select_cols = []
+         for spec in self._lifetime_simple_features:
+             output_col = spec.get_sql_column_name()
+
+             if spec.function == AggregationType.SUM:
+                 cum_col = spec.get_cumulative_column_name("SUM")
+                 select_cols.append(f"TILES.{cum_col} AS {output_col}")
+
+             elif spec.function == AggregationType.COUNT:
+                 cum_col = spec.get_cumulative_column_name("COUNT")
+                 select_cols.append(f"TILES.{cum_col} AS {output_col}")
+
+             elif spec.function == AggregationType.AVG:
+                 cum_sum = spec.get_cumulative_column_name("SUM")
+                 cum_count = spec.get_cumulative_column_name("COUNT")
+                 select_cols.append(
+                     f"CASE WHEN TILES.{cum_count} > 0 "
+                     f"THEN TILES.{cum_sum} / TILES.{cum_count} "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.MIN:
+                 cum_col = spec.get_cumulative_column_name("MIN")
+                 select_cols.append(f"TILES.{cum_col} AS {output_col}")
+
+             elif spec.function == AggregationType.MAX:
+                 cum_col = spec.get_cumulative_column_name("MAX")
+                 select_cols.append(f"TILES.{cum_col} AS {output_col}")
+
+             elif spec.function == AggregationType.VAR:
+                 cum_sum = spec.get_cumulative_column_name("SUM")
+                 cum_count = spec.get_cumulative_column_name("COUNT")
+                 cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")
+                 select_cols.append(
+                     f"CASE WHEN TILES.{cum_count} > 0 "
+                     f"THEN GREATEST(0, "
+                     f"TILES.{cum_sum_sq} / TILES.{cum_count} "
+                     f"- POWER(TILES.{cum_sum} / TILES.{cum_count}, 2)) "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.STD:
+                 cum_sum = spec.get_cumulative_column_name("SUM")
+                 cum_count = spec.get_cumulative_column_name("COUNT")
+                 cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")
+                 select_cols.append(
+                     f"CASE WHEN TILES.{cum_count} > 0 "
+                     f"THEN SQRT(GREATEST(0, "
+                     f"TILES.{cum_sum_sq} / TILES.{cum_count} "
+                     f"- POWER(TILES.{cum_sum} / TILES.{cum_count}, 2))) "
+                     f"ELSE NULL END AS {output_col}"
+                 )
+
+             elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
+                 cum_col = spec.get_cumulative_column_name("HLL")
+                 select_cols.append(f"HLL_ESTIMATE(HLL_IMPORT(TILES.{cum_col})) AS {output_col}")
+
+             elif spec.function == AggregationType.APPROX_PERCENTILE:
+                 cum_col = spec.get_cumulative_column_name("TDIGEST")
+                 percentile = spec.params.get("percentile", 0.5)
+                 select_cols.append(f"APPROX_PERCENTILE_ESTIMATE(TILES.{cum_col}, {percentile}) AS {output_col}")
+
+         select_cols_str = ", ".join(select_cols)
+
+         # Qualify group by columns with UB alias to avoid ambiguity in ASOF JOIN
+         qualified_group_cols = [f"UB.{col}" for col in group_by_cols]
+         qualified_group_str = ", ".join(qualified_group_cols)
+
+         cte_body = f"""
+         SELECT
+             {qualified_group_str},
+             {select_cols_str}
+         FROM UNIQUE_BOUNDS_FV{self._fv_index} UB
+         ASOF JOIN {self._tile_table} TILES
+         MATCH_CONDITION ({asof_match})
+         ON {join_conditions_str}
+         """
+         return cte_name, cte_body.strip()
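With fv_index=0 and a single lifetime SUM on AMOUNT keyed by USER_ID, the body above would render roughly as follows (the cumulative and output column names are hypothetical, since `get_cumulative_column_name` and `get_sql_column_name` are not shown in this diff):

    SELECT
        UB."USER_ID", UB.TILE_BOUNDARY,
        TILES._CUMULATIVE_SUM_AMOUNT AS AMOUNT_SUM_LIFETIME  -- hypothetical names
    FROM UNIQUE_BOUNDS_FV0 UB
    ASOF JOIN DB.SCHEMA.MY_TILES TILES
    MATCH_CONDITION (UB.TILE_BOUNDARY > TILES.TILE_START)
    ON UB."USER_ID" = TILES.USER_ID

Per (entity, boundary), the ASOF JOIN reads the running total from the single latest tile starting before the boundary, rather than summing every historical tile.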
+
+     def _generate_list_cte_body(self) -> str:
+         """Generate the body of the list merged CTE.
+
+         Groups by entity keys + TILE_BOUNDARY (not spine timestamp) for efficiency.
+
+         Returns:
+             SQL string for the list merged CTE body.
+         """
+         # Quote column names for case-sensitivity (from UNIQUE_BOUNDS which inherits from spine)
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         # Group by entity keys + TILE_BOUNDARY (optimization)
+         group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]
+         group_by_str = ", ".join(group_by_cols)
+         select_group_cols = ", ".join(f"t.{col}" for col in group_by_cols)
+
+         if not self._list_features:
+             return f"SELECT {group_by_str} FROM TILES_JOINED_FV{self._fv_index} WHERE 1=0"
+
+         # Build subqueries for each list feature
+         feature_subqueries = []
+
+         for _idx, spec in enumerate(self._list_features):
+             # Determine the tile column name based on aggregation type
+             if spec.function in (AggregationType.LAST_N, AggregationType.LAST_DISTINCT_N):
+                 col_name = spec.get_tile_column_name("LAST")
+                 order_clause = "t.TILE_START DESC, flat.INDEX ASC"
+             else:  # FIRST_N, FIRST_DISTINCT_N
+                 col_name = spec.get_tile_column_name("FIRST")
+                 order_clause = "t.TILE_START ASC, flat.INDEX ASC"
+
+             output_col = spec.get_sql_column_name()
+             n_value = spec.params["n"]
+
+             # Calculate tile filter condition for window and offset
+             tile_filter = self._get_list_tile_filter_condition(spec)
+
+             is_distinct = spec.function in (AggregationType.LAST_DISTINCT_N, AggregationType.FIRST_DISTINCT_N)
+
+             if is_distinct:
+                 # For distinct, rank rows twice with ROW_NUMBER and keep only the first
+                 # occurrence of each value (dup_rn = 1) within the first n rows (rn <= n)
+                 # Note: Inner query uses t. prefix, outer query uses bare column names
+                 subquery = f"""
+         (SELECT {group_by_str},
+                 ARRAY_AGG(val) WITHIN GROUP (ORDER BY rn) AS {output_col}
+          FROM (
+              SELECT {select_group_cols}, flat.VALUE AS val,
+                  ROW_NUMBER() OVER (PARTITION BY {select_group_cols} ORDER BY {order_clause}) AS rn,
+                  ROW_NUMBER() OVER (PARTITION BY {select_group_cols}, flat.VALUE ORDER BY {order_clause}) AS dup_rn
+              FROM TILES_JOINED_FV{self._fv_index} t,
+                  LATERAL FLATTEN(INPUT => t.{col_name}) flat
+              WHERE {tile_filter}
+                  AND flat.VALUE IS NOT NULL
+          ) ranked
+          WHERE dup_rn = 1 AND rn <= {n_value}
+          GROUP BY {group_by_str}
+         )"""
+             else:
+                 # Non-distinct: straightforward flatten and aggregate
+                 subquery = f"""
+         (SELECT {select_group_cols},
+                 ARRAY_SLICE(
+                     ARRAY_AGG(flat.VALUE) WITHIN GROUP (ORDER BY {order_clause}),
+                     0, {n_value}
+                 ) AS {output_col}
+          FROM TILES_JOINED_FV{self._fv_index} t,
+              LATERAL FLATTEN(INPUT => t.{col_name}) flat
+          WHERE {tile_filter}
+              AND flat.VALUE IS NOT NULL
+          GROUP BY {select_group_cols}
+         )"""
+
+             feature_subqueries.append((output_col, subquery))
+
+         # Combine all subqueries with JOINs
+         if len(feature_subqueries) == 1:
+             return feature_subqueries[0][1]
+
+         # Multiple list features: join them together
+         first_name, first_query = feature_subqueries[0]
+         result = "SELECT sq0.*, "
+         for i, (name, _) in enumerate(feature_subqueries[1:], 1):
+             result += f"sq{i}.{name}"
+             if i < len(feature_subqueries) - 1:
+                 result += ", "
+
+         result += f"\n FROM {first_query} sq0"
+
+         for i, (_name, query) in enumerate(feature_subqueries[1:], 1):
+             join_cond = " AND ".join(f"sq0.{col} = sq{i}.{col}" for col in group_by_cols)
+             result += f"\n LEFT JOIN {query} sq{i}\n ON {join_cond}"
+
+         return result
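The distinct branch's semantics (first occurrence of each value, in recency order, considering only the first n flattened rows) mirror this small Python equivalent, shown purely as an illustration of LAST_DISTINCT_N:

    def last_distinct_n(values_newest_first: list, n: int) -> list:
        out, seen = [], set()
        for v in values_newest_first[:n]:  # rn <= n
            if v not in seen:              # dup_rn = 1
                seen.add(v)
                out.append(v)
        return out

    assert last_distinct_n(["a", "b", "a", "c"], 3) == ["a", "b"]

Note that, as written, the result can hold fewer than n distinct values: "c" above is excluded because its first occurrence falls outside the first n flattened rows.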
+
+     def _generate_combined_cte(self) -> tuple[str, str]:
+         """Generate the final CTE that combines and expands results.
+
+         This joins the merged results (grouped by entity + TILE_BOUNDARY) back to
+         the original spine (SPINE_BOUNDARY) to expand to per-spine-row output.
+
+         Returns:
+             Tuple of (cte_name, cte_body) for the combined CTE.
+         """
+         cte_name = f"FV{self._fv_index:03d}"
+
+         # Quote column names for case-sensitivity
+         quoted_join_keys = [f'"{k}"' for k in self._join_keys]
+         quoted_spine_ts = f'"{self._spine_timestamp_col}"'
+         spine_output_cols = quoted_join_keys + [quoted_spine_ts]
+         # Merged results are grouped by entity + TILE_BOUNDARY
+         boundary_group_cols = quoted_join_keys + ["TILE_BOUNDARY"]
+
+         # Select columns from spine (entity keys + original timestamp)
+         select_cols = [f"s.{col}" for col in spine_output_cols]
+
+         # Add simple feature columns (non-lifetime)
+         for spec in self._simple_features:
+             select_cols.append(f"simple.{spec.get_sql_column_name()}")
+
+         # Add list feature columns (non-lifetime)
+         for spec in self._list_features:
+             select_cols.append(f"list_agg.{spec.get_sql_column_name()}")
+
+         # Add lifetime simple feature columns
+         for spec in self._lifetime_simple_features:
+             select_cols.append(f"lifetime.{spec.get_sql_column_name()}")
+
+         # Build FROM clause with all necessary JOINs
+         from_clause = f"SPINE_BOUNDARY_FV{self._fv_index} s"
+         joins = []
+
+         # Join condition template
+         def make_join_cond(alias: str) -> str:
+             return " AND ".join(f"s.{col} = {alias}.{col}" for col in boundary_group_cols)
+
+         # Add JOINs for each CTE that has features
+         if self._simple_features:
+             joins.append(f"LEFT JOIN SIMPLE_MERGED_FV{self._fv_index} simple ON {make_join_cond('simple')}")
+
+         if self._list_features:
+             joins.append(f"LEFT JOIN LIST_MERGED_FV{self._fv_index} list_agg ON {make_join_cond('list_agg')}")
+
+         if self._lifetime_simple_features:
+             joins.append(f"LEFT JOIN LIFETIME_MERGED_FV{self._fv_index} lifetime ON {make_join_cond('lifetime')}")
+
+         if joins:
+             cte_body = f"""
+         SELECT {', '.join(select_cols)}
+         FROM {from_clause}
+         {chr(10).join(' ' + j for j in joins)}
+         """
+         else:
+             # No features (shouldn't happen)
+             cte_body = f"""
+         SELECT DISTINCT {', '.join(spine_output_cols)}
+         FROM SPINE
+         """
+
+         return cte_name, cte_body.strip()
+
+     def get_output_columns(self) -> list[str]:
+         """Get the list of output column names from this feature view."""
+         return [spec.output_column for spec in self._features]