snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +120 -89
- snowflake/ml/model/_client/ops/model_ops.py +4 -26
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -23
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +25 -54
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
- snowflake/ml/model/_signatures/utils.py +130 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/tile_sql_generator.py (new file)
@@ -0,0 +1,1079 @@
"""SQL generators for tile-based aggregations.

This module provides SQL generation for:
1. TilingSqlGenerator: Creates the DT query that computes partial aggregations (tiles)
2. MergingSqlGenerator: Creates the CTEs for merging tiles during dataset generation
"""

from __future__ import annotations

from snowflake.ml.feature_store.aggregation import (
    AggregationSpec,
    AggregationType,
    interval_to_seconds,
    parse_interval,
)

# Maximum number of elements to store in array columns to avoid 128MB limit
# Assuming ~1KB per value, 100,000 values ≈ 100MB (leaving buffer)
_MAX_ARRAY_ELEMENTS = 100000


class TilingSqlGenerator:
    """Generates SQL for creating tile Dynamic Tables.

    The tiling query:
    1. Computes TIME_SLICE to bucket rows into tiles
    2. Computes partial aggregations per (join_keys, tile_start)
    3. For simple aggregations: stores SUM/COUNT as scalars
    4. For list aggregations: stores pre-sorted arrays (ARRAY_AGG with ORDER BY)
    """

    def __init__(
        self,
        source_query: str,
        join_keys: list[str],
        timestamp_col: str,
        feature_granularity: str,
        features: list[AggregationSpec],
    ) -> None:
        """Initialize the TilingSqlGenerator.

        Args:
            source_query: The source query providing raw event data.
            join_keys: List of column names used as join keys (from entities).
            timestamp_col: The timestamp column name.
            feature_granularity: The tile interval (e.g., "1h", "1d").
            features: List of aggregation specifications.
        """
        self._source_query = source_query
        self._join_keys = join_keys
        self._timestamp_col = timestamp_col
        self._feature_granularity = feature_granularity
        self._features = features

        # Parse interval for SQL generation
        self._interval_value, self._interval_unit = parse_interval(feature_granularity)

        # Track if we have any lifetime features (need cumulative columns)
        self._has_lifetime_features = any(f.is_lifetime() for f in features)

    def generate(self) -> str:
        """Generate the complete tiling SQL query.

        Returns:
            SQL query for creating the tile Dynamic Table.
        """
        tile_columns = self._generate_tile_columns()
        # Join keys and timestamp_col are already properly formatted by SqlIdentifier
        join_keys_str = ", ".join(self._join_keys)
        ts_col = self._timestamp_col

        if not self._has_lifetime_features:
            # Simple case: no lifetime features, just partial aggregations
            query = f"""
                SELECT
                    {join_keys_str},
                    TIME_SLICE({ts_col}, {self._interval_value}, '{self._interval_unit}', 'START') AS TILE_START,
                    {', '.join(tile_columns)}
                FROM ({self._source_query})
                GROUP BY {join_keys_str}, TILE_START
            """
        else:
            # With lifetime features: need cumulative columns via window functions
            # Structure: SELECT *, cumulative_columns FROM (SELECT partial_columns GROUP BY)
            cumulative_columns = self._generate_cumulative_columns()

            query = f"""
                SELECT
                    base.*,
                    {', '.join(cumulative_columns)}
                FROM (
                    SELECT
                        {join_keys_str},
                        TIME_SLICE({ts_col}, {self._interval_value}, '{self._interval_unit}', 'START') AS TILE_START,
                        {', '.join(tile_columns)}
                    FROM ({self._source_query})
                    GROUP BY {join_keys_str}, TILE_START
                ) base
            """
        return query.strip()
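    # Illustrative output (not emitted verbatim): with join_keys=["USER_ID"],
    # timestamp_col="EVENT_TS", feature_granularity="1h" (assuming
    # parse_interval("1h") yields (1, 'HOUR')) and a single SUM(AMOUNT)
    # feature, generate() produces a query shaped like:
    #
    #   SELECT
    #       USER_ID,
    #       TIME_SLICE(EVENT_TS, 1, 'HOUR', 'START') AS TILE_START,
    #       SUM(AMOUNT) AS _PARTIAL_SUM_AMOUNT
    #   FROM (SELECT * FROM RAW_EVENTS)
    #   GROUP BY USER_ID, TILE_START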

    def _generate_tile_columns(self) -> list[str]:
        """Generate the tile column expressions for all features.

        All simple aggregations share base partial columns:
        - _PARTIAL_SUM_{col}: SUM(col) - used by SUM, AVG, STD, VAR
        - _PARTIAL_COUNT_{col}: COUNT(col) - used by COUNT, AVG, STD, VAR
        - _PARTIAL_SUM_SQ_{col}: SUM(col*col) - used by STD, VAR

        This allows maximum column reuse. For example, SUM(amount) + AVG(amount)
        only creates 2 columns, not 3.

        Returns:
            List of SQL column expressions for the tile table.
        """
        # Track unique tile columns by their full name to avoid duplicates
        seen_columns: set[str] = set()
        columns = []
        ts_col = self._timestamp_col

        for spec in self._features:
            src_col = spec.source_column

            if spec.function == AggregationType.SUM:
                # SUM needs _PARTIAL_SUM
                col_name = spec.get_tile_column_name("SUM")
                if col_name not in seen_columns:
                    columns.append(f"SUM({src_col}) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.COUNT:
                # COUNT needs _PARTIAL_COUNT
                col_name = spec.get_tile_column_name("COUNT")
                if col_name not in seen_columns:
                    columns.append(f"COUNT({src_col}) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.AVG:
                # AVG needs _PARTIAL_SUM and _PARTIAL_COUNT
                sum_col = spec.get_tile_column_name("SUM")
                count_col = spec.get_tile_column_name("COUNT")
                if sum_col not in seen_columns:
                    columns.append(f"SUM({src_col}) AS {sum_col}")
                    seen_columns.add(sum_col)
                if count_col not in seen_columns:
                    columns.append(f"COUNT({src_col}) AS {count_col}")
                    seen_columns.add(count_col)

            elif spec.function == AggregationType.MIN:
                # MIN needs _PARTIAL_MIN
                col_name = spec.get_tile_column_name("MIN")
                if col_name not in seen_columns:
                    columns.append(f"MIN({src_col}) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.MAX:
                # MAX needs _PARTIAL_MAX
                col_name = spec.get_tile_column_name("MAX")
                if col_name not in seen_columns:
                    columns.append(f"MAX({src_col}) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function in (AggregationType.STD, AggregationType.VAR):
                # STD/VAR need _PARTIAL_SUM, _PARTIAL_COUNT, and _PARTIAL_SUM_SQ
                sum_col = spec.get_tile_column_name("SUM")
                count_col = spec.get_tile_column_name("COUNT")
                sum_sq_col = spec.get_tile_column_name("SUM_SQ")
                if sum_col not in seen_columns:
                    columns.append(f"SUM({src_col}) AS {sum_col}")
                    seen_columns.add(sum_col)
                if count_col not in seen_columns:
                    columns.append(f"COUNT({src_col}) AS {count_col}")
                    seen_columns.add(count_col)
                if sum_sq_col not in seen_columns:
                    columns.append(f"SUM({src_col} * {src_col}) AS {sum_sq_col}")
                    seen_columns.add(sum_sq_col)

            elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
                # APPROX_COUNT_DISTINCT uses HLL (HyperLogLog) state
                col_name = spec.get_tile_column_name("HLL")
                if col_name not in seen_columns:
                    columns.append(f"HLL_EXPORT(HLL_ACCUMULATE({src_col})) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.APPROX_PERCENTILE:
                # APPROX_PERCENTILE uses T-Digest state
                col_name = spec.get_tile_column_name("TDIGEST")
                if col_name not in seen_columns:
                    columns.append(f"APPROX_PERCENTILE_ACCUMULATE({src_col}) AS {col_name}")
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.LAST_N:
                col_name = spec.get_tile_column_name("LAST")
                if col_name not in seen_columns:
                    columns.append(
                        f"ARRAY_SLICE("
                        f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} DESC), "
                        f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
                    )
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.LAST_DISTINCT_N:
                # Uses same tile column as LAST_N (dedup happens at merge time)
                col_name = spec.get_tile_column_name("LAST")
                if col_name not in seen_columns:
                    columns.append(
                        f"ARRAY_SLICE("
                        f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} DESC), "
                        f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
                    )
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.FIRST_N:
                col_name = spec.get_tile_column_name("FIRST")
                if col_name not in seen_columns:
                    columns.append(
                        f"ARRAY_SLICE("
                        f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} ASC), "
                        f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
                    )
                    seen_columns.add(col_name)

            elif spec.function == AggregationType.FIRST_DISTINCT_N:
                # Uses same tile column as FIRST_N (dedup happens at merge time)
                col_name = spec.get_tile_column_name("FIRST")
                if col_name not in seen_columns:
                    columns.append(
                        f"ARRAY_SLICE("
                        f"ARRAY_AGG({src_col}) WITHIN GROUP (ORDER BY {ts_col} ASC), "
                        f"0, {_MAX_ARRAY_ELEMENTS}) AS {col_name}"
                    )
                    seen_columns.add(col_name)

        return columns
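    # Column-reuse example from the docstring above: requesting both
    # SUM(AMOUNT) and AVG(AMOUNT) emits only two partial columns, because the
    # AVG branch finds the SUM column already present in seen_columns:
    #
    #   SUM(AMOUNT) AS _PARTIAL_SUM_AMOUNT
    #   COUNT(AMOUNT) AS _PARTIAL_COUNT_AMOUNT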

    def _generate_cumulative_columns(self) -> list[str]:
        """Generate cumulative column expressions for lifetime aggregations.

        These columns use window functions to compute running totals per entity.
        They are computed over the partial columns from the inner GROUP BY.

        Returns:
            List of SQL column expressions for cumulative columns.
        """
        seen_columns: set[str] = set()
        columns = []
        join_keys_str = ", ".join(self._join_keys)

        # Only process lifetime features
        lifetime_features = [f for f in self._features if f.is_lifetime()]

        for spec in lifetime_features:
            if spec.function == AggregationType.SUM:
                # Cumulative SUM
                partial_col = spec.get_tile_column_name("SUM")
                cum_col = spec.get_cumulative_column_name("SUM")
                if cum_col not in seen_columns:
                    columns.append(
                        f"SUM({partial_col}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_col}"
                    )
                    seen_columns.add(cum_col)

            elif spec.function == AggregationType.COUNT:
                # Cumulative COUNT
                partial_col = spec.get_tile_column_name("COUNT")
                cum_col = spec.get_cumulative_column_name("COUNT")
                if cum_col not in seen_columns:
                    columns.append(
                        f"SUM({partial_col}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_col}"
                    )
                    seen_columns.add(cum_col)

            elif spec.function == AggregationType.AVG:
                # Cumulative AVG needs cumulative SUM and COUNT
                partial_sum = spec.get_tile_column_name("SUM")
                partial_count = spec.get_tile_column_name("COUNT")
                cum_sum = spec.get_cumulative_column_name("SUM")
                cum_count = spec.get_cumulative_column_name("COUNT")

                if cum_sum not in seen_columns:
                    columns.append(
                        f"SUM({partial_sum}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_sum}"
                    )
                    seen_columns.add(cum_sum)
                if cum_count not in seen_columns:
                    columns.append(
                        f"SUM({partial_count}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_count}"
                    )
                    seen_columns.add(cum_count)

            elif spec.function == AggregationType.MIN:
                # Cumulative MIN (running minimum)
                partial_col = spec.get_tile_column_name("MIN")
                cum_col = spec.get_cumulative_column_name("MIN")
                if cum_col not in seen_columns:
                    columns.append(
                        f"MIN({partial_col}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_col}"
                    )
                    seen_columns.add(cum_col)

            elif spec.function == AggregationType.MAX:
                # Cumulative MAX (running maximum)
                partial_col = spec.get_tile_column_name("MAX")
                cum_col = spec.get_cumulative_column_name("MAX")
                if cum_col not in seen_columns:
                    columns.append(
                        f"MAX({partial_col}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_col}"
                    )
                    seen_columns.add(cum_col)

            elif spec.function in (AggregationType.STD, AggregationType.VAR):
                # Cumulative STD/VAR needs cumulative SUM, COUNT, and SUM_SQ
                partial_sum = spec.get_tile_column_name("SUM")
                partial_count = spec.get_tile_column_name("COUNT")
                partial_sum_sq = spec.get_tile_column_name("SUM_SQ")
                cum_sum = spec.get_cumulative_column_name("SUM")
                cum_count = spec.get_cumulative_column_name("COUNT")
                cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")

                if cum_sum not in seen_columns:
                    columns.append(
                        f"SUM({partial_sum}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_sum}"
                    )
                    seen_columns.add(cum_sum)
                if cum_count not in seen_columns:
                    columns.append(
                        f"SUM({partial_count}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_count}"
                    )
                    seen_columns.add(cum_count)
                if cum_sum_sq not in seen_columns:
                    columns.append(
                        f"SUM({partial_sum_sq}) OVER ("
                        f"PARTITION BY {join_keys_str} "
                        f"ORDER BY TILE_START "
                        f"ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
                        f") AS {cum_sum_sq}"
                    )
                    seen_columns.add(cum_sum_sq)

        # Note: APPROX_COUNT_DISTINCT (HLL_COMBINE) and APPROX_PERCENTILE (APPROX_PERCENTILE_COMBINE)
        # do NOT support cumulative window frames in Snowflake.
        # These will be handled at merge time by aggregating all tiles.

        # Note: FIRST_N, FIRST_DISTINCT_N, LAST_N, LAST_DISTINCT_N lifetime
        # are handled at merge time by scanning tiles, not via cumulative columns

        return columns
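    # Illustrative cumulative column for a lifetime SUM over AMOUNT (the
    # cumulative column name is an assumption about what
    # get_cumulative_column_name returns):
    #
    #   SUM(_PARTIAL_SUM_AMOUNT) OVER (
    #       PARTITION BY USER_ID ORDER BY TILE_START
    #       ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    #   ) AS _CUMULATIVE_SUM_AMOUNT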


class MergingSqlGenerator:
    """Generates CTEs for merging tiles during dataset generation.

    The merging process:
    1. TILES_JOINED_FVi: Join tiles with spine, filtering by window and complete tiles only
    2. SIMPLE_MERGED_FVi: Aggregate simple features (SUM, COUNT, AVG)
    3. LIST_MERGED_FVi: Flatten and aggregate list features (LAST_N, etc.)
    4. FVi: Combine simple and list results
    """

    def __init__(
        self,
        tile_table: str,
        join_keys: list[str],
        timestamp_col: str,
        feature_granularity: str,
        features: list[AggregationSpec],
        spine_timestamp_col: str,
        fv_index: int,
    ) -> None:
        """Initialize the MergingSqlGenerator.

        Args:
            tile_table: Fully qualified name of the tile table.
            join_keys: List of join key column names.
            timestamp_col: The timestamp column from the feature view.
            feature_granularity: The tile interval.
            features: List of aggregation specifications.
            spine_timestamp_col: The timestamp column from the spine.
            fv_index: Index of this feature view (for CTE naming).
        """
        self._tile_table = tile_table
        self._join_keys = join_keys
        self._timestamp_col = timestamp_col
        self._feature_granularity = feature_granularity
        self._features = features
        self._spine_timestamp_col = spine_timestamp_col
        self._fv_index = fv_index

        # Separate lifetime from non-lifetime features
        self._lifetime_features = [f for f in features if f.is_lifetime()]
        self._non_lifetime_features = [f for f in features if not f.is_lifetime()]

        # Separate non-lifetime features by type (simple vs list)
        self._simple_features = [f for f in self._non_lifetime_features if f.function.is_simple()]
        self._list_features = [f for f in self._non_lifetime_features if f.function.is_list()]

        # Lifetime features are all simple (validation ensures only SUM, COUNT, AVG, MIN, MAX, STD, VAR)
        self._lifetime_simple_features = self._lifetime_features

        # Parse interval
        self._interval_value, self._interval_unit = parse_interval(feature_granularity)
        self._interval_seconds = interval_to_seconds(feature_granularity)

        # Calculate max window in tiles for filtering (only for non-lifetime features)
        if self._non_lifetime_features:
            # Max tiles needed is the max of (window + offset) across all non-lifetime features
            max_lookback_seconds = max(
                f.get_window_seconds() + f.get_offset_seconds() for f in self._non_lifetime_features
            )
            self._max_tiles_needed = (max_lookback_seconds + self._interval_seconds - 1) // self._interval_seconds
        else:
            self._max_tiles_needed = 0

    def generate_all_ctes(self) -> list[tuple[str, str]]:
        """Generate all CTEs needed for this feature view.

        The optimization flow:
        1. SPINE_BOUNDARY: Add truncated tile boundary to spine
        2. UNIQUE_BOUNDS: Get distinct (entity, boundary) pairs
        3. TILES_JOINED: Join tiles to unique boundaries (for non-lifetime features)
        4. SIMPLE_MERGED: Aggregate simple non-lifetime features per boundary
        5. LIST_MERGED: Aggregate list non-lifetime features per boundary
        6. LIFETIME_MERGED: ASOF JOIN for lifetime features (O(1) per boundary)
        7. LIFETIME_LIST_MERGED: Scan tiles for lifetime list features
        8. FV: Join back to spine to expand results

        Returns:
            List of (cte_name, cte_body) tuples.
        """
        ctes = []

        # CTE 1: Spine with tile boundary (for join-back)
        ctes.append(self._generate_spine_boundary_cte())

        # CTE 2: Unique boundaries (optimization - reduce aggregation work)
        ctes.append(self._generate_unique_boundaries_cte())

        # CTE 3: Join tiles with unique boundaries (only if we have non-lifetime features)
        if self._non_lifetime_features:
            ctes.append(self._generate_tiles_joined_cte())

        # CTE 4: Simple aggregations (if any non-lifetime)
        if self._simple_features:
            ctes.append(self._generate_simple_merged_cte())

        # CTE 5: List aggregations (if any non-lifetime)
        if self._list_features:
            ctes.append(self._generate_list_merged_cte())

        # CTE 6: Lifetime simple aggregations (using ASOF JOIN on cumulative columns)
        if self._lifetime_simple_features:
            ctes.append(self._generate_lifetime_merged_cte())

        # CTE 7: Combine all results and join back to spine
        ctes.append(self._generate_combined_cte())

        return ctes
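    # The (cte_name, cte_body) pairs are assembled into a WITH clause by the
    # caller (dataset-generation code outside this class); for fv_index=0 the
    # resulting query is shaped roughly like:
    #
    #   WITH SPINE AS (...),
    #       SPINE_BOUNDARY_FV0 AS (...),
    #       UNIQUE_BOUNDS_FV0 AS (...),
    #       TILES_JOINED_FV0 AS (...),
    #       SIMPLE_MERGED_FV0 AS (...),
    #       FV000 AS (...)
    #   SELECT ... FROM FV000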

    def _generate_spine_boundary_cte(self) -> tuple[str, str]:
        """Generate CTE that adds tile boundary to deduplicated spine.

        The tile boundary is the truncated timestamp that determines which
        complete tiles are visible. All spine rows with the same boundary
        will have identical feature values.

        Returns:
            Tuple of (cte_name, cte_body).
        """
        cte_name = f"SPINE_BOUNDARY_FV{self._fv_index}"

        # Quote column names to preserve case-sensitivity from spine dataframe
        # The spine_timestamp_col is passed as-is from the user (e.g., "query_ts")
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        quoted_spine_ts = f'"{self._spine_timestamp_col}"'

        # Select all columns plus the tile boundary
        all_cols = quoted_join_keys + [quoted_spine_ts]
        select_cols = ", ".join(all_cols)

        # DATE_TRUNC to tile granularity gives us the boundary
        # All timestamps in the same granularity window see the same complete tiles
        cte_body = f"""
            SELECT DISTINCT {select_cols},
                DATE_TRUNC('{self._interval_unit.lower()}', {quoted_spine_ts}) AS TILE_BOUNDARY
            FROM SPINE
        """
        return cte_name, cte_body.strip()

    def _generate_unique_boundaries_cte(self) -> tuple[str, str]:
        """Generate CTE with unique (entity, boundary) pairs.

        This is the key optimization: instead of computing features for each
        spine row, we compute once per unique boundary and join back.

        Returns:
            Tuple of (cte_name, cte_body).
        """
        cte_name = f"UNIQUE_BOUNDS_FV{self._fv_index}"

        # Quote column names to handle case-sensitivity
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        keys_str = ", ".join(quoted_join_keys)

        cte_body = f"""
            SELECT DISTINCT {keys_str}, TILE_BOUNDARY
            FROM SPINE_BOUNDARY_FV{self._fv_index}
        """
        return cte_name, cte_body.strip()

    def _generate_tiles_joined_cte(self) -> tuple[str, str]:
        """Generate the CTE that joins tiles with unique boundaries.

        This joins tiles to UNIQUE_BOUNDS (not full spine), which is much smaller
        when there are many spine rows per tile boundary.

        Returns:
            Tuple of (cte_name, cte_body) for the tiles joined CTE.
        """
        cte_name = f"TILES_JOINED_FV{self._fv_index}"

        # Quote column names for spine columns (case-sensitive)
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]

        # Tile table column names match the join keys (already SqlIdentifier-formatted)
        tile_join_keys = list(self._join_keys)
        tile_keys_str = ", ".join(tile_join_keys)

        # Join conditions: quoted spine columns to uppercase tile columns
        join_conditions = [f"UB.{qk} = TILES.{tk}" for qk, tk in zip(quoted_join_keys, tile_join_keys)]

        # Get all tile columns we need (deduplicated)
        tile_columns_set: set[str] = set()
        for spec in self._features:
            if spec.function == AggregationType.SUM:
                tile_columns_set.add(spec.get_tile_column_name("SUM"))
            elif spec.function == AggregationType.COUNT:
                tile_columns_set.add(spec.get_tile_column_name("COUNT"))
            elif spec.function == AggregationType.AVG:
                tile_columns_set.add(spec.get_tile_column_name("SUM"))
                tile_columns_set.add(spec.get_tile_column_name("COUNT"))
            elif spec.function == AggregationType.MIN:
                tile_columns_set.add(spec.get_tile_column_name("MIN"))
            elif spec.function == AggregationType.MAX:
                tile_columns_set.add(spec.get_tile_column_name("MAX"))
            elif spec.function in (AggregationType.STD, AggregationType.VAR):
                tile_columns_set.add(spec.get_tile_column_name("SUM"))
                tile_columns_set.add(spec.get_tile_column_name("COUNT"))
                tile_columns_set.add(spec.get_tile_column_name("SUM_SQ"))
            elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
                tile_columns_set.add(spec.get_tile_column_name("HLL"))
            elif spec.function == AggregationType.APPROX_PERCENTILE:
                tile_columns_set.add(spec.get_tile_column_name("TDIGEST"))
            elif spec.function in (AggregationType.LAST_N, AggregationType.LAST_DISTINCT_N):
                tile_columns_set.add(spec.get_tile_column_name("LAST"))
            elif spec.function in (AggregationType.FIRST_N, AggregationType.FIRST_DISTINCT_N):
                tile_columns_set.add(spec.get_tile_column_name("FIRST"))
        tile_columns = sorted(tile_columns_set)  # Sort for deterministic output

        tile_columns_str = ", ".join(f"TILES.{col}" for col in tile_columns)

        # Window filter: only include tiles within the max window and complete tiles
        # Complete tiles: tile_end <= tile_boundary (not spine timestamp)
        # tile_end = DATEADD(interval_unit, interval_value, tile_start)
        cte_body = f"""
            SELECT
                UB.*,
                TILES.TILE_START,
                {tile_columns_str}
            FROM UNIQUE_BOUNDS_FV{self._fv_index} UB
            LEFT JOIN (
                SELECT {tile_keys_str}, TILE_START, {', '.join(tile_columns)}
                FROM {self._tile_table}
            ) TILES
            ON {' AND '.join(join_conditions)}
            -- Window filter: tiles within max window from tile boundary
            AND TILES.TILE_START >= DATEADD(
                {self._interval_unit}, -{self._max_tiles_needed * self._interval_value}, UB.TILE_BOUNDARY
            )
            -- Complete tiles only: tile_end <= tile_boundary
            AND DATEADD({self._interval_unit}, {self._interval_value}, TILES.TILE_START) <= UB.TILE_BOUNDARY
        """
        return cte_name, cte_body.strip()
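    # Completeness check, worked example: with feature_granularity="1d" and
    # UB.TILE_BOUNDARY = 2024-01-15, the tile starting 2024-01-14 qualifies
    # (DATEADD(day, 1, 2024-01-14) = 2024-01-15 <= boundary), while the
    # still-open 2024-01-15 tile is excluded.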

    def _get_tile_filter_condition(self, spec: AggregationSpec) -> str:
        """Generate the CASE WHEN condition for filtering tiles by window and offset.

        For a feature with window W and offset O, we want tiles where:
        - TILE_START >= TILE_BOUNDARY - W - O (start of window)
        - TILE_START < TILE_BOUNDARY - O (end of window, before offset)

        Args:
            spec: The aggregation specification.

        Returns:
            SQL condition string for use in CASE WHEN.
        """
        window_tiles = (spec.get_window_seconds() + self._interval_seconds - 1) // self._interval_seconds
        offset_tiles = spec.get_offset_seconds() // self._interval_seconds

        if offset_tiles == 0:
            # No offset: just filter by window start
            return (
                f"TILE_START >= DATEADD({self._interval_unit}, "
                f"-{window_tiles * self._interval_value}, TILE_BOUNDARY)"
            )
        else:
            # With offset: filter by both window start and end (shifted by offset)
            window_start_tiles = window_tiles + offset_tiles
            return (
                f"TILE_START >= DATEADD({self._interval_unit}, "
                f"-{window_start_tiles * self._interval_value}, TILE_BOUNDARY) "
                f"AND TILE_START < DATEADD({self._interval_unit}, "
                f"-{offset_tiles * self._interval_value}, TILE_BOUNDARY)"
            )
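    # Worked example: granularity "1h", window "24h", offset "2h" gives
    # window_tiles = 24 and offset_tiles = 2, so window_start_tiles = 26 and
    # the emitted condition is:
    #   TILE_START >= DATEADD(HOUR, -26, TILE_BOUNDARY)
    #   AND TILE_START < DATEADD(HOUR, -2, TILE_BOUNDARY)
    # i.e. 24 complete tiles ending two tiles before the boundary.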

    def _get_list_tile_filter_condition(self, spec: AggregationSpec) -> str:
        """Generate filter condition for list aggregations (with table prefix).

        Similar to _get_tile_filter_condition but uses t. prefix for table references.

        Args:
            spec: The aggregation specification.

        Returns:
            SQL condition string for use in WHERE clause.
        """
        window_tiles = (spec.get_window_seconds() + self._interval_seconds - 1) // self._interval_seconds
        offset_tiles = spec.get_offset_seconds() // self._interval_seconds

        if offset_tiles == 0:
            # No offset: just filter by window start
            return (
                f"t.TILE_START >= DATEADD({self._interval_unit}, "
                f"-{window_tiles * self._interval_value}, t.TILE_BOUNDARY)"
            )
        else:
            # With offset: filter by both window start and end (shifted by offset)
            window_start_tiles = window_tiles + offset_tiles
            return (
                f"t.TILE_START >= DATEADD({self._interval_unit}, "
                f"-{window_start_tiles * self._interval_value}, t.TILE_BOUNDARY) "
                f"AND t.TILE_START < DATEADD({self._interval_unit}, "
                f"-{offset_tiles * self._interval_value}, t.TILE_BOUNDARY)"
            )

    def _generate_simple_merged_cte(self) -> tuple[str, str]:
        """Generate the CTE for simple aggregations (SUM, COUNT, AVG).

        Groups by entity keys + TILE_BOUNDARY (not spine timestamp) for efficiency.

        Returns:
            Tuple of (cte_name, cte_body) for the simple merged CTE.
        """
        cte_name = f"SIMPLE_MERGED_FV{self._fv_index}"

        # Quote column names for case-sensitivity (from UNIQUE_BOUNDS which inherits from spine)
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        # Group by entity keys + TILE_BOUNDARY (optimization)
        group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]
        group_by_str = ", ".join(group_by_cols)

        agg_columns = []
        for spec in self._simple_features:
            output_col = spec.get_sql_column_name()
            tile_filter = self._get_tile_filter_condition(spec)

            if spec.function == AggregationType.SUM:
                col_name = spec.get_tile_column_name("SUM")
                agg_columns.append(f"SUM(CASE WHEN {tile_filter} THEN {col_name} ELSE 0 END) AS {output_col}")

            elif spec.function == AggregationType.COUNT:
                col_name = spec.get_tile_column_name("COUNT")
                agg_columns.append(f"SUM(CASE WHEN {tile_filter} THEN {col_name} ELSE 0 END) AS {output_col}")

            elif spec.function == AggregationType.MIN:
                col_name = spec.get_tile_column_name("MIN")
                agg_columns.append(f"MIN(CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END) AS {output_col}")

            elif spec.function == AggregationType.MAX:
                col_name = spec.get_tile_column_name("MAX")
                agg_columns.append(f"MAX(CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END) AS {output_col}")

            elif spec.function == AggregationType.AVG:
                # AVG = SUM(partial_sums) / SUM(partial_counts)
                sum_col = spec.get_tile_column_name("SUM")
                count_col = spec.get_tile_column_name("COUNT")
                agg_columns.append(
                    f"CASE WHEN SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END) > 0 "
                    f"THEN SUM(CASE WHEN {tile_filter} "
                    f"THEN {sum_col} ELSE 0 END) / "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END) "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.VAR:
                # VAR = (SUM_SQ / COUNT) - (SUM / COUNT)^2
                # Using parallel variance formula
                # GREATEST(0, ...) clamps to non-negative to handle floating-point errors
                sum_col = spec.get_tile_column_name("SUM")
                count_col = spec.get_tile_column_name("COUNT")
                sum_sq_col = spec.get_tile_column_name("SUM_SQ")
                agg_columns.append(
                    f"CASE WHEN SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END) > 0 "
                    f"THEN GREATEST(0, ("
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {sum_sq_col} ELSE 0 END) / "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END)"
                    f") - POWER("
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {sum_col} ELSE 0 END) / "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END), 2)) "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.STD:
                # STD = SQRT(VAR) = SQRT((SUM_SQ / COUNT) - (SUM / COUNT)^2)
                # GREATEST(0, ...) clamps variance to non-negative to handle floating-point errors
                # that can cause sqrt of tiny negative numbers like -4.54747e-13
                sum_col = spec.get_tile_column_name("SUM")
                count_col = spec.get_tile_column_name("COUNT")
                sum_sq_col = spec.get_tile_column_name("SUM_SQ")
                agg_columns.append(
                    f"CASE WHEN SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END) > 0 "
                    f"THEN SQRT(GREATEST(0, "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {sum_sq_col} ELSE 0 END) / "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END)"
                    f" - POWER("
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {sum_col} ELSE 0 END) / "
                    f"SUM(CASE WHEN {tile_filter} "
                    f"THEN {count_col} ELSE 0 END), 2))) "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
                # Combine HLL states and estimate count
                # HLL_ESTIMATE(HLL_COMBINE(HLL_IMPORT(state))) gives the approximate count
                col_name = spec.get_tile_column_name("HLL")
                agg_columns.append(
                    f"HLL_ESTIMATE(HLL_COMBINE("
                    f"CASE WHEN {tile_filter} THEN HLL_IMPORT({col_name}) ELSE NULL END"
                    f")) AS {output_col}"
                )

            elif spec.function == AggregationType.APPROX_PERCENTILE:
                # Combine T-Digest states and estimate percentile
                # APPROX_PERCENTILE_ESTIMATE(APPROX_PERCENTILE_COMBINE(state), percentile)
                col_name = spec.get_tile_column_name("TDIGEST")
                percentile = spec.params.get("percentile", 0.5)
                agg_columns.append(
                    f"APPROX_PERCENTILE_ESTIMATE(APPROX_PERCENTILE_COMBINE("
                    f"CASE WHEN {tile_filter} THEN {col_name} ELSE NULL END"
                    f"), {percentile}) AS {output_col}"
                )

        cte_body = f"""
            SELECT
                {group_by_str},
                {', '.join(agg_columns)}
            FROM TILES_JOINED_FV{self._fv_index}
            GROUP BY {group_by_str}
        """
        return cte_name, cte_body.strip()
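    # The VAR/STD branches above apply the identity Var(X) = E[X^2] - E[X]^2
    # to re-merged partials: with S = SUM of partial sums, Q = SUM of partial
    # sums of squares and C = SUM of partial counts within the window, the
    # generated SQL computes GREATEST(0, Q/C - POWER(S/C, 2)), clamping at
    # zero so floating-point noise cannot drive SQRT negative for STD.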

    def _generate_list_merged_cte(self) -> tuple[str, str]:
        """Generate the CTE for list aggregations using LATERAL FLATTEN."""
        cte_name = f"LIST_MERGED_FV{self._fv_index}"

        # Generate a more efficient single-pass CTE
        # Each list feature gets its own lateral flatten in the FROM clause
        cte_body = self._generate_list_cte_body()

        return cte_name, cte_body.strip()

    def _generate_lifetime_merged_cte(self) -> tuple[str, str]:
        """Generate the CTE for lifetime simple aggregations using ASOF JOIN.

        Uses ASOF JOIN on cumulative columns for O(1) lookup per boundary.
        This is much faster than aggregating all tiles from the beginning of time.

        Returns:
            Tuple of (cte_name, cte_body) for the lifetime merged CTE.
        """
        cte_name = f"LIFETIME_MERGED_FV{self._fv_index}"

        # Quote column names for case-sensitivity
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]

        # Tile table column names match the join keys (already SqlIdentifier-formatted)
        tile_join_keys = list(self._join_keys)

        # Build ASOF JOIN condition: match the most recent tile before the boundary
        asof_match = "UB.TILE_BOUNDARY > TILES.TILE_START"
        join_conditions = [f"UB.{qk} = TILES.{tk}" for qk, tk in zip(quoted_join_keys, tile_join_keys)]
        join_conditions_str = " AND ".join(join_conditions)

        # Build select columns for lifetime features
        select_cols = []
        for spec in self._lifetime_simple_features:
            output_col = spec.get_sql_column_name()

            if spec.function == AggregationType.SUM:
                cum_col = spec.get_cumulative_column_name("SUM")
                select_cols.append(f"TILES.{cum_col} AS {output_col}")

            elif spec.function == AggregationType.COUNT:
                cum_col = spec.get_cumulative_column_name("COUNT")
                select_cols.append(f"TILES.{cum_col} AS {output_col}")

            elif spec.function == AggregationType.AVG:
                cum_sum = spec.get_cumulative_column_name("SUM")
                cum_count = spec.get_cumulative_column_name("COUNT")
                select_cols.append(
                    f"CASE WHEN TILES.{cum_count} > 0 "
                    f"THEN TILES.{cum_sum} / TILES.{cum_count} "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.MIN:
                cum_col = spec.get_cumulative_column_name("MIN")
                select_cols.append(f"TILES.{cum_col} AS {output_col}")

            elif spec.function == AggregationType.MAX:
                cum_col = spec.get_cumulative_column_name("MAX")
                select_cols.append(f"TILES.{cum_col} AS {output_col}")

            elif spec.function == AggregationType.VAR:
                cum_sum = spec.get_cumulative_column_name("SUM")
                cum_count = spec.get_cumulative_column_name("COUNT")
                cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")
                select_cols.append(
                    f"CASE WHEN TILES.{cum_count} > 0 "
                    f"THEN GREATEST(0, "
                    f"TILES.{cum_sum_sq} / TILES.{cum_count} "
                    f"- POWER(TILES.{cum_sum} / TILES.{cum_count}, 2)) "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.STD:
                cum_sum = spec.get_cumulative_column_name("SUM")
                cum_count = spec.get_cumulative_column_name("COUNT")
                cum_sum_sq = spec.get_cumulative_column_name("SUM_SQ")
                select_cols.append(
                    f"CASE WHEN TILES.{cum_count} > 0 "
                    f"THEN SQRT(GREATEST(0, "
                    f"TILES.{cum_sum_sq} / TILES.{cum_count} "
                    f"- POWER(TILES.{cum_sum} / TILES.{cum_count}, 2))) "
                    f"ELSE NULL END AS {output_col}"
                )

            elif spec.function == AggregationType.APPROX_COUNT_DISTINCT:
                cum_col = spec.get_cumulative_column_name("HLL")
                select_cols.append(f"HLL_ESTIMATE(HLL_IMPORT(TILES.{cum_col})) AS {output_col}")

            elif spec.function == AggregationType.APPROX_PERCENTILE:
                cum_col = spec.get_cumulative_column_name("TDIGEST")
                percentile = spec.params.get("percentile", 0.5)
                select_cols.append(f"APPROX_PERCENTILE_ESTIMATE(TILES.{cum_col}, {percentile}) AS {output_col}")

        select_cols_str = ", ".join(select_cols)

        # Qualify group by columns with UB alias to avoid ambiguity in ASOF JOIN
        qualified_group_cols = [f"UB.{col}" for col in group_by_cols]
        qualified_group_str = ", ".join(qualified_group_cols)

        cte_body = f"""
            SELECT
                {qualified_group_str},
                {select_cols_str}
            FROM UNIQUE_BOUNDS_FV{self._fv_index} UB
            ASOF JOIN {self._tile_table} TILES
            MATCH_CONDITION ({asof_match})
            ON {join_conditions_str}
        """
        return cte_name, cte_body.strip()
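    # Illustrative emitted shape for a single join key USER_ID and a lifetime
    # SUM feature (the cumulative column name is an assumption):
    #
    #   SELECT UB."USER_ID", UB.TILE_BOUNDARY,
    #       TILES._CUMULATIVE_SUM_AMOUNT AS <output_col>
    #   FROM UNIQUE_BOUNDS_FV0 UB
    #   ASOF JOIN <tile_table> TILES
    #   MATCH_CONDITION (UB.TILE_BOUNDARY > TILES.TILE_START)
    #   ON UB."USER_ID" = TILES.USER_ID
    #
    # The ASOF JOIN picks the single most recent qualifying tile, so each
    # boundary reads one row of cumulative state instead of all history.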

    def _generate_list_cte_body(self) -> str:
        """Generate the body of the list merged CTE.

        Groups by entity keys + TILE_BOUNDARY (not spine timestamp) for efficiency.

        Returns:
            SQL string for the list merged CTE body.
        """
        # Quote column names for case-sensitivity (from UNIQUE_BOUNDS which inherits from spine)
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        # Group by entity keys + TILE_BOUNDARY (optimization)
        group_by_cols = quoted_join_keys + ["TILE_BOUNDARY"]
        group_by_str = ", ".join(group_by_cols)
        select_group_cols = ", ".join(f"t.{col}" for col in group_by_cols)

        if not self._list_features:
            return f"SELECT {group_by_str} FROM TILES_JOINED_FV{self._fv_index} WHERE 1=0"

        # Build subqueries for each list feature
        feature_subqueries = []

        for _idx, spec in enumerate(self._list_features):
            # Determine the tile column name based on aggregation type
            if spec.function in (AggregationType.LAST_N, AggregationType.LAST_DISTINCT_N):
                col_name = spec.get_tile_column_name("LAST")
                order_clause = "t.TILE_START DESC, flat.INDEX ASC"
            else:  # FIRST_N, FIRST_DISTINCT_N
                col_name = spec.get_tile_column_name("FIRST")
                order_clause = "t.TILE_START ASC, flat.INDEX ASC"

            output_col = spec.get_sql_column_name()
            n_value = spec.params["n"]

            # Calculate tile filter condition for window and offset
            tile_filter = self._get_list_tile_filter_condition(spec)

            is_distinct = spec.function in (AggregationType.LAST_DISTINCT_N, AggregationType.FIRST_DISTINCT_N)

            if is_distinct:
                # For distinct, use QUALIFY to keep first occurrence of each value
                # Note: Inner query uses t. prefix, outer query uses bare column names
                subquery = f"""
                    (SELECT {group_by_str},
                        ARRAY_AGG(val) WITHIN GROUP (ORDER BY rn) AS {output_col}
                    FROM (
                        SELECT {select_group_cols}, flat.VALUE AS val,
                            ROW_NUMBER() OVER (PARTITION BY {select_group_cols} ORDER BY {order_clause}) AS rn,
                            ROW_NUMBER() OVER (PARTITION BY {select_group_cols}, flat.VALUE ORDER BY {order_clause}) AS dup_rn
                        FROM TILES_JOINED_FV{self._fv_index} t,
                            LATERAL FLATTEN(INPUT => t.{col_name}) flat
                        WHERE {tile_filter}
                            AND flat.VALUE IS NOT NULL
                    ) ranked
                    WHERE dup_rn = 1 AND rn <= {n_value}
                    GROUP BY {group_by_str}
                    )"""
            else:
                # Non-distinct: straightforward flatten and aggregate
                subquery = f"""
                    (SELECT {select_group_cols},
                        ARRAY_SLICE(
                            ARRAY_AGG(flat.VALUE) WITHIN GROUP (ORDER BY {order_clause}),
                            0, {n_value}
                        ) AS {output_col}
                    FROM TILES_JOINED_FV{self._fv_index} t,
                        LATERAL FLATTEN(INPUT => t.{col_name}) flat
                    WHERE {tile_filter}
                        AND flat.VALUE IS NOT NULL
                    GROUP BY {select_group_cols}
                    )"""

            feature_subqueries.append((output_col, subquery))

        # Combine all subqueries with JOINs
        if len(feature_subqueries) == 1:
            return feature_subqueries[0][1]

        # Multiple list features: join them together
        first_name, first_query = feature_subqueries[0]
        result = "SELECT sq0.*, "
        for i, (name, _) in enumerate(feature_subqueries[1:], 1):
            result += f"sq{i}.{name}"
            if i < len(feature_subqueries) - 1:
                result += ", "

        result += f"\n    FROM {first_query} sq0"

        for i, (_name, query) in enumerate(feature_subqueries[1:], 1):
            join_cond = " AND ".join(f"sq0.{col} = sq{i}.{col}" for col in group_by_cols)
            result += f"\n    LEFT JOIN {query} sq{i}\n    ON {join_cond}"

        return result

    def _generate_combined_cte(self) -> tuple[str, str]:
        """Generate the final CTE that combines and expands results.

        This joins the merged results (grouped by entity + TILE_BOUNDARY) back to
        the original spine (SPINE_BOUNDARY) to expand to per-spine-row output.

        Returns:
            Tuple of (cte_name, cte_body) for the combined CTE.
        """
        cte_name = f"FV{self._fv_index:03d}"

        # Quote column names for case-sensitivity
        quoted_join_keys = [f'"{k}"' for k in self._join_keys]
        quoted_spine_ts = f'"{self._spine_timestamp_col}"'
        spine_output_cols = quoted_join_keys + [quoted_spine_ts]
        # Merged results are grouped by entity + TILE_BOUNDARY
        boundary_group_cols = quoted_join_keys + ["TILE_BOUNDARY"]

        # Select columns from spine (entity keys + original timestamp)
        select_cols = [f"s.{col}" for col in spine_output_cols]

        # Add simple feature columns (non-lifetime)
        for spec in self._simple_features:
            select_cols.append(f"simple.{spec.get_sql_column_name()}")

        # Add list feature columns (non-lifetime)
        for spec in self._list_features:
            select_cols.append(f"list_agg.{spec.get_sql_column_name()}")

        # Add lifetime simple feature columns
        for spec in self._lifetime_simple_features:
            select_cols.append(f"lifetime.{spec.get_sql_column_name()}")

        # Build FROM clause with all necessary JOINs
        from_clause = f"SPINE_BOUNDARY_FV{self._fv_index} s"
        joins = []

        # Join condition template
        def make_join_cond(alias: str) -> str:
            return " AND ".join(f"s.{col} = {alias}.{col}" for col in boundary_group_cols)

        # Add JOINs for each CTE that has features
        if self._simple_features:
            joins.append(f"LEFT JOIN SIMPLE_MERGED_FV{self._fv_index} simple ON {make_join_cond('simple')}")

        if self._list_features:
            joins.append(f"LEFT JOIN LIST_MERGED_FV{self._fv_index} list_agg ON {make_join_cond('list_agg')}")

        if self._lifetime_simple_features:
            joins.append(f"LEFT JOIN LIFETIME_MERGED_FV{self._fv_index} lifetime ON {make_join_cond('lifetime')}")

        if joins:
            cte_body = f"""
            SELECT {', '.join(select_cols)}
            FROM {from_clause}
            {chr(10).join('            ' + j for j in joins)}
            """
        else:
            # No features (shouldn't happen)
            cte_body = f"""
            SELECT DISTINCT {', '.join(spine_output_cols)}
            FROM SPINE
            """

        return cte_name, cte_body.strip()

    def get_output_columns(self) -> list[str]:
        """Get the list of output column names from this feature view."""
        return [spec.output_column for spec in self._features]