snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -28
- snowflake/ml/model/_client/ops/model_ops.py +2 -8
- snowflake/ml/model/_client/ops/service_ops.py +6 -11
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +21 -29
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
- snowflake/ml/model/_signatures/utils.py +76 -1
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/platform_capabilities.py

@@ -19,7 +19,6 @@ LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
 INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
 SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
 ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS = "ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS"
-FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"


 class PlatformCapabilities:
@@ -85,9 +84,6 @@ class PlatformCapabilities:
     def is_model_method_signature_parameters_enabled(self) -> bool:
         return self._get_bool_feature(ENABLE_MODEL_METHOD_SIGNATURE_PARAMETERS, False)

-    def is_inference_autocapture_enabled(self) -> bool:
-        return self._is_feature_enabled(FEATURE_MODEL_INFERENCE_AUTOCAPTURE)
-
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
         try:
snowflake/ml/_internal/utils/mixins.py

@@ -9,6 +9,7 @@ _SESSION_ACCOUNT_KEY = "session$account"
 _SESSION_ROLE_KEY = "session$role"
 _SESSION_DATABASE_KEY = "session$database"
 _SESSION_SCHEMA_KEY = "session$schema"
+_SESSION_STATE_ATTR = "_session_state"


 def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
@@ -61,7 +62,7 @@ class SerializableSessionMixin:
         else:
             self.__dict__.update(state)

-        self
+        setattr(self, _SESSION_STATE_ATTR, session_state)

     def _set_session(self, session_state: _SessionState) -> None:
@@ -86,3 +87,27 @@ class SerializableSessionMixin:
             ),
         ),
     )
+
+    @property
+    def session(self) -> Optional[snowpark_session.Session]:
+        if _SESSION_KEY not in self.__dict__:
+            session_state = getattr(self, _SESSION_STATE_ATTR, None)
+            if session_state is not None:
+                self._set_session(session_state)
+        return self.__dict__.get(_SESSION_KEY)
+
+    @session.setter
+    def session(self, value: Optional[snowpark_session.Session]) -> None:
+        self.__dict__[_SESSION_KEY] = value
+
+    # Normal attribute lookup resolves, in order:
+    # 1. Data descriptors (like @property with setter) from the class hierarchy
+    # 2. Instance __dict__ (e.g., self.x = 10)
+    # 3. Non-data descriptors (methods, @property without setter) from the class hierarchy
+    # __getattr__ is only called if steps 1-3 all fail.
+    def __getattr__(self, name: str) -> Any:
+        if name == _SESSION_KEY:
+            return self.session
+        if hasattr(super(), "__getattr__"):
+            return super().__getattr__(name)  # type: ignore[misc]
+        raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")
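The new property restores a dropped session lazily on first access after unpickling. A minimal, self-contained sketch of that pattern (the `SerializableThing` class and the fabricated session string are illustrative stand-ins, not snowflake-ml internals):

```python
import pickle
from typing import Any, Optional

_SESSION_KEY = "_session"
_SESSION_STATE_ATTR = "_session_state"


class SerializableThing:
    def __getstate__(self) -> dict:
        state = dict(self.__dict__)
        state.pop(_SESSION_KEY, None)  # never pickle the live session
        state[_SESSION_STATE_ATTR] = {"role": "ANALYST"}  # keep identifiers only
        return state

    @property
    def session(self) -> Optional[Any]:
        if _SESSION_KEY not in self.__dict__:
            saved = getattr(self, _SESSION_STATE_ATTR, None)
            if saved is not None:
                # The real mixin's _set_session() re-resolves a live Snowpark
                # session from the saved identifiers; here we just fabricate one.
                self.__dict__[_SESSION_KEY] = f"session-for-{saved['role']}"
        return self.__dict__.get(_SESSION_KEY)


thing = pickle.loads(pickle.dumps(SerializableThing()))
print(thing.session)  # restored lazily on first access: session-for-ANALYST
```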
snowflake/ml/data/_internal/arrow_ingestor.py

@@ -73,15 +73,19 @@ class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin)
         self._schema: Optional[pa.Schema] = None

     @classmethod
-    def from_sources(
+    def from_sources(
+        cls, session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
+    ) -> "ArrowIngestor":
         if session is None:
             raise ValueError("Session is required")
+        # Skipping kwargs until needed to avoid impacting other workflows.
         return cls(session, sources)

     @classmethod
     def from_ray_dataset(
         cls,
         ray_ds: "ray.data.Dataset",
+        **kwargs: Any,
     ) -> "ArrowIngestor":
         raise NotImplementedError

snowflake/ml/data/data_connector.py

@@ -94,7 +94,7 @@ class DataConnector:
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
-        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds, **kwargs)
         return cls(ray_ingestor, **kwargs)

     @classmethod
@@ -111,7 +111,7 @@ class DataConnector:
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
-        ingestor = ingestor_class.from_sources(session, sources)
+        ingestor = ingestor_class.from_sources(session, sources, **kwargs)
         return cls(ingestor, **kwargs)

     @property
snowflake/ml/data/data_ingestor.py

@@ -16,7 +16,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
+        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource], **kwargs: Any
     ) -> DataIngestorType:
         raise NotImplementedError

@@ -24,6 +24,7 @@ class DataIngestor(Protocol):
     def from_ray_dataset(
         cls: type[DataIngestorType],
         ray_ds: "ray.data.Dataset",
+        **kwargs: Any,
     ) -> DataIngestorType:
         raise NotImplementedError

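Together with the `DataConnector` changes above, this threads arbitrary `**kwargs` from the connector factory methods down to the ingestor. A sketch of a custom ingestor written against the widened protocol (`CustomIngestor` and its `batch_size` parameter are hypothetical, not part of the package):

```python
from typing import Any, Sequence


class CustomIngestor:
    def __init__(self, session: Any, sources: Sequence[Any], batch_size: int = 1024) -> None:
        self._session = session
        self._sources = list(sources)
        self._batch_size = batch_size

    @classmethod
    def from_sources(cls, session: Any, sources: Sequence[Any], **kwargs: Any) -> "CustomIngestor":
        # kwargs such as batch_size now arrive via DataConnector.from_sources(..., batch_size=...).
        return cls(session, sources, **kwargs)

    @classmethod
    def from_ray_dataset(cls, ray_ds: Any, **kwargs: Any) -> "CustomIngestor":
        raise NotImplementedError


ing = CustomIngestor.from_sources(session="sess", sources=[], batch_size=256)
```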
snowflake/ml/experiment/_experiment_info.py

@@ -3,7 +3,7 @@ import functools
 import types
 from typing import Callable, Optional

-from snowflake.ml import
+from snowflake.ml.model._client.model import model_version_impl
 from snowflake.ml.registry._manager import model_manager


@@ -23,7 +23,7 @@ class ExperimentInfoPatcher:
     """

     # Store original method at class definition time to avoid recursive patching
-    _original_log_model: Callable[...,
+    _original_log_model: Callable[..., model_version_impl.ModelVersion] = model_manager.ModelManager.log_model

     # Stack of active experiment_info contexts for nested experiment support
     _experiment_info_stack: list[ExperimentInfo] = []
@@ -36,7 +36,7 @@ class ExperimentInfoPatcher:
         if not ExperimentInfoPatcher._experiment_info_stack:

             @functools.wraps(ExperimentInfoPatcher._original_log_model)
-            def patched(*args, **kwargs) ->
+            def patched(*args, **kwargs) -> model_version_impl.ModelVersion:  # type: ignore[no-untyped-def]
                 # Use the most recent (top of stack) experiment_info for nested contexts
                 current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
                 return ExperimentInfoPatcher._original_log_model(
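The patcher swaps `ModelManager.log_model` for a wrapper while preserving its metadata via `functools.wraps`. A generic, runnable sketch of that patch-and-delegate pattern (the `Manager` class here is a stand-in, not the snowflake-ml internals):

```python
import functools


class Manager:
    def log_model(self, name: str) -> str:
        return f"logged {name}"


# Capture the original once, before patching, to avoid recursive wrapping.
_original = Manager.log_model


@functools.wraps(_original)
def patched(self, name: str) -> str:
    # Inject extra context, then delegate to the original implementation.
    return _original(self, name) + " (with experiment info)"


Manager.log_model = patched
print(Manager().log_model("m1"))  # -> logged m1 (with experiment info)
```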
snowflake/ml/feature_store/__init__.py

@@ -3,6 +3,7 @@ import os
 from snowflake.ml._internal import init_utils

 from .access_manager import setup_feature_store
+from .feature import Feature

 pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
@@ -12,4 +13,5 @@ for k, v in exportable_classes.items():

 __all__ = list(exportable_classes.keys()) + [
     "setup_feature_store",
+    "Feature",
 ]
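With this export, the new `Feature` class is importable from the package root:

```python
# Assumes snowflake-ml-python >= 1.25.0.
from snowflake.ml.feature_store import Feature
```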
snowflake/ml/feature_store/aggregation.py (new file)

@@ -0,0 +1,367 @@
+"""Aggregation types and specifications for tile-based feature views.
+
+This module provides the building blocks for defining time-series aggregations
+that are computed using a tile-based approach for efficiency and correctness.
+"""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+
+class AggregationType(Enum):
+    """Supported aggregation functions for tiled feature views.
+
+    These aggregation types are classified into categories:
+    - Simple aggregations (SUM, COUNT, AVG, MIN, MAX, STD, VAR): Stored as scalar partial results in tiles
+    - Sketch aggregations (APPROX_COUNT_DISTINCT, APPROX_PERCENTILE): Stored as mergeable state in tiles
+    - List aggregations (LAST_N, LAST_DISTINCT_N, FIRST_N, FIRST_DISTINCT_N): Stored as arrays in tiles
+    """
+
+    SUM = "sum"
+    COUNT = "count"
+    AVG = "avg"
+    MIN = "min"
+    MAX = "max"
+    STD = "std"
+    VAR = "var"
+    APPROX_COUNT_DISTINCT = "approx_count_distinct"
+    APPROX_PERCENTILE = "approx_percentile"
+    LAST_N = "last_n"
+    LAST_DISTINCT_N = "last_distinct_n"
+    FIRST_N = "first_n"
+    FIRST_DISTINCT_N = "first_distinct_n"
+
+    def is_simple(self) -> bool:
+        """Check if this is a simple aggregation (scalar result per tile).
+
+        Simple aggregations include both basic aggregates (SUM, COUNT, etc.)
+        and sketch-based aggregates (APPROX_COUNT_DISTINCT, APPROX_PERCENTILE)
+        because they all produce a single value per entity per tile boundary.
+
+        Returns:
+            True if this is a simple aggregation type, False otherwise.
+        """
+        return self in (
+            AggregationType.SUM,
+            AggregationType.COUNT,
+            AggregationType.AVG,
+            AggregationType.MIN,
+            AggregationType.MAX,
+            AggregationType.STD,
+            AggregationType.VAR,
+            AggregationType.APPROX_COUNT_DISTINCT,
+            AggregationType.APPROX_PERCENTILE,
+        )
+
+    def is_list(self) -> bool:
+        """Check if this is a list aggregation (array result per tile)."""
+        return self in (
+            AggregationType.LAST_N,
+            AggregationType.LAST_DISTINCT_N,
+            AggregationType.FIRST_N,
+            AggregationType.FIRST_DISTINCT_N,
+        )
+
+    def is_sketch(self) -> bool:
+        """Check if this is a sketch-based aggregation (HLL, T-Digest)."""
+        return self in (
+            AggregationType.APPROX_COUNT_DISTINCT,
+            AggregationType.APPROX_PERCENTILE,
+        )
+
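A quick illustration of the three categories:

```python
from snowflake.ml.feature_store.aggregation import AggregationType

assert AggregationType.SUM.is_simple()
# Sketches count as "simple" too: one mergeable value per entity per tile.
assert AggregationType.APPROX_PERCENTILE.is_simple() and AggregationType.APPROX_PERCENTILE.is_sketch()
assert AggregationType.LAST_N.is_list() and not AggregationType.LAST_N.is_simple()
```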
+
+# Special window value for lifetime aggregations
+LIFETIME_WINDOW = "lifetime"
+
+# Regex pattern for interval parsing: "1h", "24 hours", "30 minutes", etc.
+_INTERVAL_PATTERN = re.compile(
+    r"^\s*(\d+)\s*(s|sec|secs|second|seconds|m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)\s*$",
+    re.IGNORECASE,
+)
+
+# Mapping of interval unit aliases to canonical Snowflake units
+_INTERVAL_UNIT_MAP = {
+    "s": "SECOND",
+    "sec": "SECOND",
+    "secs": "SECOND",
+    "second": "SECOND",
+    "seconds": "SECOND",
+    "m": "MINUTE",
+    "min": "MINUTE",
+    "mins": "MINUTE",
+    "minute": "MINUTE",
+    "minutes": "MINUTE",
+    "h": "HOUR",
+    "hr": "HOUR",
+    "hrs": "HOUR",
+    "hour": "HOUR",
+    "hours": "HOUR",
+    "d": "DAY",
+    "day": "DAY",
+    "days": "DAY",
+}
+
+# Seconds per unit for calculating window sizes
+_SECONDS_PER_UNIT = {
+    "SECOND": 1,
+    "MINUTE": 60,
+    "HOUR": 3600,
+    "DAY": 86400,
+}
+
+
+def is_lifetime_window(window: str) -> bool:
+    """Check if a window string represents a lifetime aggregation.
+
+    Args:
+        window: The window string to check.
+
+    Returns:
+        True if the window is "lifetime", False otherwise.
+    """
+    return window.lower().strip() == LIFETIME_WINDOW
+
+
+def parse_interval(interval: str) -> tuple[int, str]:
+    """Parse an interval string into (value, unit) tuple.
+
+    Args:
+        interval: Interval string like "1h", "24 hours", "30 minutes".
+            Note: "lifetime" is NOT a valid interval - use is_lifetime_window() first.
+
+    Returns:
+        Tuple of (numeric_value, snowflake_unit) e.g. (1, "HOUR").
+
+    Raises:
+        ValueError: If the interval format is invalid.
+    """
+    if is_lifetime_window(interval):
+        raise ValueError(
+            f"'{interval}' is not a numeric interval. Use is_lifetime_window() to check for lifetime windows."
+        )
+
+    match = _INTERVAL_PATTERN.match(interval)
+    if not match:
+        raise ValueError(
+            f"Invalid interval format: '{interval}'. "
+            f"Expected format: '<number> <unit>' where unit is one of: "
+            f"seconds, minutes, hours, days (or abbreviations s, m, h, d)"
+        )
+
+    value = int(match.group(1))
+    unit_str = match.group(2).lower()
+    unit = _INTERVAL_UNIT_MAP[unit_str]
+
+    if value <= 0:
+        raise ValueError(f"Interval value must be positive, got: {value}")
+
+    return value, unit
+
+
+def interval_to_seconds(interval: str) -> int:
+    """Convert an interval string to total seconds.
+
+    Args:
+        interval: Interval string like "1h", "24 hours".
+            Note: "lifetime" is NOT supported - returns -1 as sentinel.
+
+    Returns:
+        Total seconds represented by the interval, or -1 for lifetime.
+    """
+    if is_lifetime_window(interval):
+        return -1  # Sentinel value for lifetime
+    value, unit = parse_interval(interval)
+    return value * _SECONDS_PER_UNIT[unit]
+
+
+def format_interval_for_snowflake(interval: str) -> str:
+    """Format an interval string for use in Snowflake SQL.
+
+    Args:
+        interval: Interval string like "1h", "24 hours".
+
+    Returns:
+        Formatted string like "HOUR" or "DAY" for use with DATEADD/TIME_SLICE.
+    """
+    _, unit = parse_interval(interval)
+    return unit
+
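Expected behavior of the interval helpers above; the results follow directly from the unit tables:

```python
from snowflake.ml.feature_store.aggregation import (
    format_interval_for_snowflake,
    interval_to_seconds,
    parse_interval,
)

parse_interval("24 hours")            # (24, "HOUR")
parse_interval("30 min")              # (30, "MINUTE")
interval_to_seconds("7d")             # 604800  (7 * 86400)
interval_to_seconds("lifetime")       # -1 sentinel, no exception
format_interval_for_snowflake("1h")   # "HOUR"
# parse_interval("lifetime") raises ValueError; check is_lifetime_window() first.
```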
+
+@dataclass(frozen=True)
+class AggregationSpec:
+    """Internal representation of an aggregation specification.
+
+    This is the serializable form that gets stored in metadata and used
+    for SQL generation. Users interact with the Feature class instead.
+
+    Attributes:
+        function: The aggregation function type.
+        source_column: The column to aggregate.
+        window: The lookback window for the aggregation (e.g., "24h", "7d").
+        output_column: The name of the output column.
+        offset: Offset to shift the window into the past (e.g., "1d" means [t-window-1d, t-1d]).
+        params: Additional parameters (e.g., {"n": 10} for LAST_N).
+    """
+
+    function: AggregationType
+    source_column: str
+    window: str
+    output_column: str
+    offset: str = "0"
+    params: dict[str, Any] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        """Validate the aggregation spec after initialization."""
+        # Validate window format (allow "lifetime" as special case)
+        if not is_lifetime_window(self.window):
+            try:
+                interval_to_seconds(self.window)
+            except ValueError as e:
+                raise ValueError(f"Invalid window for aggregation '{self.output_column}': {e}") from e
+
+        # Validate offset format (if not "0")
+        # Note: offset is not allowed with lifetime windows
+        if self.offset != "0":
+            if is_lifetime_window(self.window):
+                raise ValueError(
+                    f"Offset is not supported with lifetime windows for aggregation '{self.output_column}'"
+                )
+            try:
+                offset_seconds = interval_to_seconds(self.offset)
+                if offset_seconds < 0:
+                    raise ValueError("Offset must be non-negative")
+            except ValueError as e:
+                raise ValueError(f"Invalid offset for aggregation '{self.output_column}': {e}") from e
+
+        # Validate params for list aggregations
+        if self.function.is_list():
+            if "n" not in self.params:
+                raise ValueError(
+                    f"Parameter 'n' is required for {self.function.value} aggregation '{self.output_column}'"
+                )
+            n = self.params["n"]
+            if not isinstance(n, int) or n <= 0:
+                raise ValueError(
+                    f"Parameter 'n' must be a positive integer for aggregation '{self.output_column}', got: {n}"
+                )
+
+        # Validate params for approx_percentile
+        if self.function == AggregationType.APPROX_PERCENTILE:
+            if "percentile" not in self.params:
+                raise ValueError(
+                    f"Parameter 'percentile' is required for approx_percentile aggregation '{self.output_column}'"
+                )
+            percentile = self.params["percentile"]
+            if not isinstance(percentile, (int, float)) or not (0.0 <= percentile <= 1.0):
+                raise ValueError(
+                    f"Parameter 'percentile' must be a float between 0.0 and 1.0 for aggregation "
+                    f"'{self.output_column}', got: {percentile}"
+                )
+
+        # Validate lifetime window support
+        # Only simple scalar aggregations support lifetime (O(1) via cumulative columns)
+        if is_lifetime_window(self.window):
+            supported_lifetime_types = (
+                AggregationType.SUM,
+                AggregationType.COUNT,
+                AggregationType.AVG,
+                AggregationType.MIN,
+                AggregationType.MAX,
+                AggregationType.STD,
+                AggregationType.VAR,
+            )
+            if self.function not in supported_lifetime_types:
+                supported_names = ", ".join(t.value.upper() for t in supported_lifetime_types)
+                raise ValueError(
+                    f"Lifetime window is not supported for {self.function.value} aggregation "
+                    f"'{self.output_column}'. Lifetime is only supported for: {supported_names}."
+                )
+
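How the `__post_init__` validation plays out (import path per the file list at the top of this diff):

```python
from snowflake.ml.feature_store.aggregation import AggregationSpec, AggregationType

# Valid: a 7-day SUM shifted one day into the past, i.e. the window [t-8d, t-1d].
spec = AggregationSpec(
    function=AggregationType.SUM,
    source_column="amount",
    window="7d",
    output_column="AMOUNT_SUM_7D",
    offset="1d",
)

# Invalid: list aggregations require params={"n": ...}.
try:
    AggregationSpec(AggregationType.LAST_N, "amount", "24h", "LAST_AMOUNTS")
except ValueError as e:
    print(e)  # Parameter 'n' is required for last_n aggregation 'LAST_AMOUNTS'
```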
+    def is_lifetime(self) -> bool:
+        """Check if this aggregation has a lifetime window.
+
+        Returns:
+            True if the window is "lifetime", False otherwise.
+        """
+        return is_lifetime_window(self.window)
+
+    def get_window_seconds(self) -> int:
+        """Get the window size in seconds.
+
+        Returns:
+            Total seconds for the window, or -1 for lifetime windows.
+        """
+        return interval_to_seconds(self.window)
+
+    def get_offset_seconds(self) -> int:
+        """Get the offset in seconds."""
+        return interval_to_seconds(self.offset) if self.offset != "0" else 0
+
+    def get_cumulative_column_name(self, partial_type: str) -> str:
+        """Get the cumulative column name for lifetime aggregations.
+
+        Similar to get_tile_column_name but with _CUM_ prefix instead of _PARTIAL_.
+
+        Args:
+            partial_type: One of "SUM", "COUNT", "SUM_SQ", "HLL", "TDIGEST", "MIN", "MAX", "FIRST".
+
+        Returns:
+            Column name used in the tile table (prefixed with _CUM_).
+        """
+        return f"_CUM_{partial_type}_{self.source_column.upper()}"
+
+    def get_tile_column_name(self, partial_type: str) -> str:
+        """Get the internal tile column name for a base partial aggregate.
+
+        Aggregations are computed from base partials:
+        - _PARTIAL_SUM_{col}: SUM(col) - used by SUM, AVG, STD, VAR
+        - _PARTIAL_COUNT_{col}: COUNT(col) - used by COUNT, AVG, STD, VAR
+        - _PARTIAL_SUM_SQ_{col}: SUM(col*col) - used by STD, VAR
+        - _PARTIAL_HLL_{col}: HLL state - used by APPROX_COUNT_DISTINCT
+        - _PARTIAL_TDIGEST_{col}: T-Digest state - used by APPROX_PERCENTILE
+
+        This allows sharing columns across aggregation types on the same column.
+
+        Args:
+            partial_type: One of "SUM", "COUNT", "SUM_SQ", "HLL", "TDIGEST", "LAST", "FIRST".
+
+        Returns:
+            Column name used in the tile table (prefixed with _PARTIAL_).
+        """
+        return f"_PARTIAL_{partial_type}_{self.source_column.upper()}"
+
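Because tile columns are keyed by the base partial and the source column, not by the user-facing feature, different aggregations over the same column share tile storage. Continuing the example above:

```python
avg_spec = AggregationSpec(AggregationType.AVG, "amount", "24h", "AMOUNT_AVG_24H")

avg_spec.get_tile_column_name("SUM")        # "_PARTIAL_SUM_AMOUNT" (shared with SUM specs)
avg_spec.get_tile_column_name("COUNT")      # "_PARTIAL_COUNT_AMOUNT" (shared with COUNT specs)
avg_spec.get_cumulative_column_name("SUM")  # "_CUM_SUM_AMOUNT" (lifetime variant)
```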
+    def get_sql_column_name(self) -> str:
+        """Get the output column name formatted for SQL.
+
+        Returns:
+            Column name ready for use in SQL. Case-sensitive names are stored
+            with quotes (e.g., '"My_Col"'), case-insensitive names are uppercase.
+        """
+        return self.output_column
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to a dictionary for JSON serialization."""
+        return {
+            "function": self.function.value,
+            "source_column": self.source_column,
+            "window": self.window,
+            "output_column": self.output_column,
+            "offset": self.offset,
+            "params": self.params,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> AggregationSpec:
+        """Create an AggregationSpec from a dictionary."""
+        return cls(
+            function=AggregationType(data["function"]),
+            source_column=data["source_column"],
+            window=data["window"],
+            output_column=data["output_column"],
+            offset=data.get("offset", "0"),
+            params=data.get("params", {}),
+        )