snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +120 -89
- snowflake/ml/model/_client/ops/model_ops.py +4 -26
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -23
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +25 -54
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
- snowflake/ml/model/_signatures/utils.py +130 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
|
@@ -19,7 +19,6 @@ LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
|
|
|
19
19
|
INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
|
|
20
20
|
SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
|
|
21
21
|
ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
|
|
22
|
-
FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class PlatformCapabilities:
|
|
@@ -85,9 +84,6 @@ class PlatformCapabilities:
|
|
|
85
84
|
def is_model_method_signature_parameters_enabled(self) -> bool:
|
|
86
85
|
return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)
|
|
87
86
|
|
|
88
|
-
def is_inference_autocapture_enabled(self) -> bool:
|
|
89
|
-
return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
|
|
90
|
-
|
|
91
87
|
@staticmethod
|
|
92
88
|
def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
|
|
93
89
|
try:
|
|
@@ -3,6 +3,7 @@ import os
|
|
|
3
3
|
from snowflake.ml._internal import init_utils
|
|
4
4
|
|
|
5
5
|
from .access_manager import setup_feature_store
|
|
6
|
+
from .feature import Feature
|
|
6
7
|
|
|
7
8
|
pkg_dir = os.path.dirname(__file__)
|
|
8
9
|
pkg_name = __name__
|
|
@@ -12,4 +13,5 @@ for k, v in exportable_classes.items():
|
|
|
12
13
|
|
|
13
14
|
__all__ = list(exportable_classes.keys()) + [
|
|
14
15
|
"setup_feature_store",
|
|
16
|
+
"Feature",
|
|
15
17
|
]
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
"""Aggregation types and specifications for tile-based feature views.
|
|
2
|
+
|
|
3
|
+
This module provides the building blocks for defining time-series aggregations
|
|
4
|
+
that are computed using a tile-based approach for efficiency and correctness.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AggregationType(Enum):
    """Aggregation functions a tiled feature view can compute.

    Every member belongs to exactly one storage family:

    - Scalar per tile ("simple", see ``is_simple()``): SUM, COUNT, AVG, MIN,
      MAX, STD, VAR, plus the sketch-based APPROX_COUNT_DISTINCT and
      APPROX_PERCENTILE. The sketch members additionally report
      ``is_sketch()``; their per-tile value is a mergeable sketch state.
    - Array per tile ("list", see ``is_list()``): LAST_N, LAST_DISTINCT_N,
      FIRST_N, FIRST_DISTINCT_N.
    """

    SUM = "sum"
    COUNT = "count"
    AVG = "avg"
    MIN = "min"
    MAX = "max"
    STD = "std"
    VAR = "var"
    APPROX_COUNT_DISTINCT = "approx_count_distinct"
    APPROX_PERCENTILE = "approx_percentile"
    LAST_N = "last_n"
    LAST_DISTINCT_N = "last_distinct_n"
    FIRST_N = "first_n"
    FIRST_DISTINCT_N = "first_distinct_n"

    def is_simple(self) -> bool:
        """Report whether a tile stores a single value for this aggregation.

        Sketch aggregations count as simple: like the basic aggregates they
        yield one value per entity per tile boundary.

        Returns:
            True for the basic scalar aggregates and the sketch aggregates,
            False for the list aggregations.
        """
        if self.is_sketch():
            return True
        basic_scalars = {
            AggregationType.SUM,
            AggregationType.COUNT,
            AggregationType.AVG,
            AggregationType.MIN,
            AggregationType.MAX,
            AggregationType.STD,
            AggregationType.VAR,
        }
        return self in basic_scalars

    def is_list(self) -> bool:
        """Report whether a tile stores an array for this aggregation."""
        array_kinds = {
            AggregationType.LAST_N,
            AggregationType.LAST_DISTINCT_N,
            AggregationType.FIRST_N,
            AggregationType.FIRST_DISTINCT_N,
        }
        return self in array_kinds

    def is_sketch(self) -> bool:
        """Report whether this aggregation is sketch based (HLL, T-Digest)."""
        sketch_kinds = {
            AggregationType.APPROX_COUNT_DISTINCT,
            AggregationType.APPROX_PERCENTILE,
        }
        return self in sketch_kinds
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Window keyword selecting an unbounded ("lifetime") aggregation instead of a
# fixed lookback interval.
LIFETIME_WINDOW = "lifetime"

# Matches "<integer><optional spaces><unit>" with optional surrounding
# whitespace, case-insensitively: e.g. "1h", "24 hours", "30 minutes".
_INTERVAL_PATTERN = re.compile(
    r"^\s*(\d+)\s*(s|sec|secs|second|seconds|m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)\s*$",
    re.IGNORECASE,
)

# Lower-case unit alias -> canonical Snowflake date/time part.
# The aliases listed here must stay in sync with the alternation above.
_INTERVAL_UNIT_MAP = {
    alias: canonical_unit
    for canonical_unit, aliases in (
        ("SECOND", ("s", "sec", "secs", "second", "seconds")),
        ("MINUTE", ("m", "min", "mins", "minute", "minutes")),
        ("HOUR", ("h", "hr", "hrs", "hour", "hours")),
        ("DAY", ("d", "day", "days")),
    )
    for alias in aliases
}

# Duration of one canonical unit, expressed in seconds.
_SECONDS_PER_UNIT = dict(SECOND=1, MINUTE=60, HOUR=3600, DAY=86400)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def is_lifetime_window(window: str) -> bool:
    """Tell whether a window string is the special lifetime keyword.

    The comparison ignores letter case and surrounding whitespace, so
    " Lifetime " also qualifies.

    Args:
        window: The window specification to inspect.

    Returns:
        True if the normalized window equals ``LIFETIME_WINDOW``, else False.
    """
    normalized = window.lower().strip()
    return normalized == LIFETIME_WINDOW
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def parse_interval(interval: str) -> tuple[int, str]:
    """Parse an interval string into a (value, unit) tuple.

    Args:
        interval: Interval string like "1h", "24 hours", "30 minutes".
            Note: "lifetime" is NOT a valid interval - use is_lifetime_window() first.

    Returns:
        Tuple of (numeric_value, snowflake_unit), e.g. (1, "HOUR"), where the
        unit is one of "SECOND", "MINUTE", "HOUR", "DAY".

    Raises:
        ValueError: If the interval is the lifetime keyword, its format is
            invalid (including an unrecognized unit), or its numeric value is
            not positive.
    """
    if is_lifetime_window(interval):
        raise ValueError(
            f"'{interval}' is not a numeric interval. " f"Use is_lifetime_window() to check for lifetime windows."
        )

    invalid_format_message = (
        f"Invalid interval format: '{interval}'. "
        f"Expected format: '<number> <unit>' where unit is one of: "
        f"seconds, minutes, hours, days (or abbreviations s, m, h, d)"
    )

    match = _INTERVAL_PATTERN.match(interval)
    if not match:
        raise ValueError(invalid_format_message)

    value = int(match.group(1))
    # re.IGNORECASE performs Unicode case-insensitive matching, so a few
    # non-ASCII letters (e.g. "ſ" for "s", "ı" for "i") satisfy the pattern
    # even though .lower() does not turn them into a key of the alias map.
    # Report those as a format error (ValueError) instead of leaking a
    # KeyError past callers that only catch ValueError.
    unit = _INTERVAL_UNIT_MAP.get(match.group(2).lower())
    if unit is None:
        raise ValueError(invalid_format_message)

    if value <= 0:
        raise ValueError(f"Interval value must be positive, got: {value}")

    return value, unit
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def interval_to_seconds(interval: str) -> int:
    """Express an interval string as a whole number of seconds.

    Args:
        interval: Interval text such as "1h" or "24 hours". The "lifetime"
            keyword has no finite length and yields the sentinel -1.

    Returns:
        The interval length in seconds, or -1 when ``interval`` is lifetime.

    Raises:
        ValueError: Propagated from ``parse_interval`` for malformed or
            non-positive intervals.
    """
    if not is_lifetime_window(interval):
        amount, unit = parse_interval(interval)
        return amount * _SECONDS_PER_UNIT[unit]
    # Sentinel meaning "unbounded"; callers compare against it explicitly.
    return -1
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def format_interval_for_snowflake(interval: str) -> str:
    """Return the Snowflake date/time part for an interval string.

    Only the unit is returned; the numeric amount parsed from ``interval``
    is discarded (e.g. both "1h" and "24 hours" give "HOUR").

    Args:
        interval: Interval text such as "1h" or "24 hours".

    Returns:
        One of "SECOND", "MINUTE", "HOUR", "DAY", for use with
        DATEADD/TIME_SLICE.

    Raises:
        ValueError: Propagated from ``parse_interval`` for invalid input,
            including the "lifetime" keyword.
    """
    return parse_interval(interval)[1]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@dataclass(frozen=True)
class AggregationSpec:
    """Internal representation of an aggregation specification.

    This is the serializable form that gets stored in metadata and used
    for SQL generation. Users interact with the Feature class instead.

    All validation happens in ``__post_init__``, so an inconsistent spec
    raises ``ValueError`` at construction time.

    Note: the dataclass is frozen with the default ``eq=True``, so a
    ``__hash__`` over all fields is generated. Because ``params`` is a dict,
    actually hashing an instance raises ``TypeError``.

    Attributes:
        function: The aggregation function type.
        source_column: The column to aggregate.
        window: The lookback window for the aggregation (e.g., "24h", "7d"),
            or "lifetime" for an unbounded window.
        output_column: The name of the output column.
        offset: Offset to shift the window into the past (e.g., "1d" means
            [t-window-1d, t-1d]). The default "0" means no offset.
        params: Additional parameters (e.g., {"n": 10} for LAST_N,
            {"percentile": 0.5} for APPROX_PERCENTILE).
    """

    function: AggregationType
    source_column: str
    window: str
    output_column: str
    offset: str = "0"
    params: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the aggregation spec after initialization.

        Raises:
            ValueError: If the window or offset is malformed, an offset is
                combined with a lifetime window, a required parameter is
                missing or invalid, or the function does not support
                lifetime windows.
        """
        # Validate window format (allow "lifetime" as special case).
        if not is_lifetime_window(self.window):
            try:
                interval_to_seconds(self.window)
            except ValueError as e:
                raise ValueError(f"Invalid window for aggregation '{self.output_column}': {e}") from e

        # Validate offset format (if not "0").
        # Note: offset is not allowed with lifetime windows.
        if self.offset != "0":
            if is_lifetime_window(self.window):
                raise ValueError(
                    f"Offset is not supported with lifetime windows for aggregation '{self.output_column}'"
                )
            try:
                offset_seconds = interval_to_seconds(self.offset)
                # interval_to_seconds only returns a negative value (the -1
                # sentinel) when the offset itself is the "lifetime" keyword.
                if offset_seconds < 0:
                    raise ValueError("Offset must be non-negative")
            except ValueError as e:
                raise ValueError(f"Invalid offset for aggregation '{self.output_column}': {e}") from e

        # Validate params for list aggregations.
        if self.function.is_list():
            if "n" not in self.params:
                raise ValueError(
                    f"Parameter 'n' is required for {self.function.value} aggregation " f"'{self.output_column}'"
                )
            n = self.params["n"]
            # bool is a subclass of int; reject it explicitly so {"n": True}
            # is not silently accepted as n == 1.
            if isinstance(n, bool) or not isinstance(n, int) or n <= 0:
                raise ValueError(
                    f"Parameter 'n' must be a positive integer for aggregation " f"'{self.output_column}', got: {n}"
                )

        # Validate params for approx_percentile.
        if self.function == AggregationType.APPROX_PERCENTILE:
            if "percentile" not in self.params:
                raise ValueError(
                    f"Parameter 'percentile' is required for approx_percentile aggregation '{self.output_column}'"
                )
            percentile = self.params["percentile"]
            # Reject bool for the same reason as 'n'; NaN fails the range check.
            if (
                isinstance(percentile, bool)
                or not isinstance(percentile, (int, float))
                or not (0.0 <= percentile <= 1.0)
            ):
                raise ValueError(
                    f"Parameter 'percentile' must be a float between 0.0 and 1.0 for aggregation "
                    f"'{self.output_column}', got: {percentile}"
                )

        # Validate lifetime window support.
        # Only simple scalar aggregations support lifetime (O(1) via cumulative columns).
        if is_lifetime_window(self.window):
            supported_lifetime_types = (
                AggregationType.SUM,
                AggregationType.COUNT,
                AggregationType.AVG,
                AggregationType.MIN,
                AggregationType.MAX,
                AggregationType.STD,
                AggregationType.VAR,
            )
            if self.function not in supported_lifetime_types:
                supported_names = ", ".join(t.value.upper() for t in supported_lifetime_types)
                raise ValueError(
                    f"Lifetime window is not supported for {self.function.value} aggregation "
                    f"'{self.output_column}'. Lifetime is only supported for: {supported_names}."
                )

    def is_lifetime(self) -> bool:
        """Check if this aggregation has a lifetime window.

        Returns:
            True if the window is "lifetime", False otherwise.
        """
        return is_lifetime_window(self.window)

    def get_window_seconds(self) -> int:
        """Get the window size in seconds.

        Returns:
            Total seconds for the window, or -1 for lifetime windows.
        """
        return interval_to_seconds(self.window)

    def get_offset_seconds(self) -> int:
        """Get the offset in seconds (0 when no offset is configured)."""
        return interval_to_seconds(self.offset) if self.offset != "0" else 0

    def get_cumulative_column_name(self, partial_type: str) -> str:
        """Get the cumulative column name for lifetime aggregations.

        Similar to get_tile_column_name but with _CUM_ prefix instead of _PARTIAL_.

        NOTE(review): ``.upper()`` is applied to ``source_column`` verbatim;
        if ``source_column`` can be a quoted case-sensitive identifier, the
        result would embed the quotes. Confirm against callers.

        Args:
            partial_type: One of "SUM", "COUNT", "SUM_SQ", "HLL", "TDIGEST", "MIN", "MAX", "FIRST".

        Returns:
            Column name used in the tile table (prefixed with _CUM_).
        """
        return f"_CUM_{partial_type}_{self.source_column.upper()}"

    def get_tile_column_name(self, partial_type: str) -> str:
        """Get the internal tile column name for a base partial aggregate.

        Aggregations are computed from base partials:
        - _PARTIAL_SUM_{col}: SUM(col) - used by SUM, AVG, STD, VAR
        - _PARTIAL_COUNT_{col}: COUNT(col) - used by COUNT, AVG, STD, VAR
        - _PARTIAL_SUM_SQ_{col}: SUM(col*col) - used by STD, VAR
        - _PARTIAL_HLL_{col}: HLL state - used by APPROX_COUNT_DISTINCT
        - _PARTIAL_TDIGEST_{col}: T-Digest state - used by APPROX_PERCENTILE

        This allows sharing columns across aggregation types on the same column.

        Args:
            partial_type: One of "SUM", "COUNT", "SUM_SQ", "HLL", "TDIGEST", "LAST", "FIRST".

        Returns:
            Column name used in the tile table (prefixed with _PARTIAL_).
        """
        return f"_PARTIAL_{partial_type}_{self.source_column.upper()}"

    def get_sql_column_name(self) -> str:
        """Get the output column name formatted for SQL.

        Returns:
            Column name ready for use in SQL. Case-sensitive names are stored
            with quotes (e.g., '"My_Col"'), case-insensitive names are uppercase.
        """
        return self.output_column

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary for JSON serialization.

        ``params`` is shallow-copied so that mutating the returned dict cannot
        alter this frozen spec behind the back of ``__post_init__`` validation.
        """
        return {
            "function": self.function.value,
            "source_column": self.source_column,
            "window": self.window,
            "output_column": self.output_column,
            "offset": self.offset,
            "params": dict(self.params),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AggregationSpec":
        """Create an AggregationSpec from a dictionary produced by ``to_dict``.

        A missing or null ``params`` entry is treated as empty. The params
        mapping is shallow-copied so later changes to ``data`` do not leak
        into the spec.

        Raises:
            KeyError: If a required key is missing from ``data``.
            ValueError: If the stored values fail validation.
        """
        return cls(
            function=AggregationType(data["function"]),
            source_column=data["source_column"],
            window=data["window"],
            output_column=data["output_column"],
            offset=data.get("offset", "0"),
            params=dict(data.get("params") or {}),
        )
|