snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import time
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -18,11 +18,11 @@ class MLJob(Generic[T]):
|
|
18
18
|
def __init__(
|
19
19
|
self,
|
20
20
|
id: str,
|
21
|
-
service_spec: Optional[
|
21
|
+
service_spec: Optional[dict[str, Any]] = None,
|
22
22
|
session: Optional[snowpark.Session] = None,
|
23
23
|
) -> None:
|
24
24
|
self._id = id
|
25
|
-
self._service_spec_cached: Optional[
|
25
|
+
self._service_spec_cached: Optional[dict[str, Any]] = service_spec
|
26
26
|
self._session = session or sp_context.get_active_session()
|
27
27
|
|
28
28
|
self._status: types.JOB_STATUS = "PENDING"
|
@@ -42,18 +42,18 @@ class MLJob(Generic[T]):
|
|
42
42
|
return self._status
|
43
43
|
|
44
44
|
@property
|
45
|
-
def _service_spec(self) ->
|
45
|
+
def _service_spec(self) -> dict[str, Any]:
|
46
46
|
"""Get the job's service spec."""
|
47
47
|
if not self._service_spec_cached:
|
48
48
|
self._service_spec_cached = _get_service_spec(self._session, self.id)
|
49
49
|
return self._service_spec_cached
|
50
50
|
|
51
51
|
@property
|
52
|
-
def _container_spec(self) ->
|
52
|
+
def _container_spec(self) -> dict[str, Any]:
|
53
53
|
"""Get the job's main container spec."""
|
54
54
|
containers = self._service_spec["spec"]["containers"]
|
55
55
|
container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
|
56
|
-
return cast(
|
56
|
+
return cast(dict[str, Any], container_spec)
|
57
57
|
|
58
58
|
@property
|
59
59
|
def _stage_path(self) -> str:
|
@@ -70,8 +70,17 @@ class MLJob(Generic[T]):
|
|
70
70
|
raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
|
71
71
|
return f"{self._stage_path}/{result_path}"
|
72
72
|
|
73
|
-
@
|
74
|
-
def get_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> str:
|
73
|
+
@overload
|
74
|
+
def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
|
75
|
+
...
|
76
|
+
|
77
|
+
@overload
|
78
|
+
def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
|
79
|
+
...
|
80
|
+
|
81
|
+
def get_logs(
|
82
|
+
self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
|
83
|
+
) -> Union[str, list[str]]:
|
75
84
|
"""
|
76
85
|
Return the job's execution logs.
|
77
86
|
|
@@ -79,15 +88,17 @@ class MLJob(Generic[T]):
|
|
79
88
|
limit: The maximum number of lines to return. Negative values are treated as no limit.
|
80
89
|
instance_id: Optional instance ID to get logs from a specific instance.
|
81
90
|
If not provided, returns logs from the head node.
|
91
|
+
as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
|
82
92
|
|
83
93
|
Returns:
|
84
94
|
The job's execution logs.
|
85
95
|
"""
|
86
96
|
logs = _get_logs(self._session, self.id, limit, instance_id)
|
87
97
|
assert isinstance(logs, str) # mypy
|
98
|
+
if as_list:
|
99
|
+
return logs.splitlines()
|
88
100
|
return logs
|
89
101
|
|
90
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
91
102
|
def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
|
92
103
|
"""
|
93
104
|
Display the job's execution logs.
|
@@ -97,9 +108,8 @@ class MLJob(Generic[T]):
|
|
97
108
|
instance_id: Optional instance ID to get logs from a specific instance.
|
98
109
|
If not provided, displays logs from the head node.
|
99
110
|
"""
|
100
|
-
print(self.get_logs(limit, instance_id)) # noqa: T201: we need to print here.
|
111
|
+
print(self.get_logs(limit, instance_id, as_list=False)) # noqa: T201: we need to print here.
|
101
112
|
|
102
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
103
113
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
|
104
114
|
def wait(self, timeout: float = -1) -> types.JOB_STATUS:
|
105
115
|
"""
|
@@ -167,10 +177,10 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
|
|
167
177
|
|
168
178
|
|
169
179
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
170
|
-
def _get_service_spec(session: snowpark.Session, job_id: str) ->
|
180
|
+
def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
|
171
181
|
"""Retrieve job execution service spec."""
|
172
182
|
(row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
|
173
|
-
return cast(
|
183
|
+
return cast(dict[str, Any], yaml.safe_load(row["spec"]))
|
174
184
|
|
175
185
|
|
176
186
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
|
@@ -192,7 +202,7 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
|
|
192
202
|
instance_id = _get_head_instance_id(session, job_id)
|
193
203
|
|
194
204
|
# Assemble params: [job_id, instance_id, container_name, (optional) limit]
|
195
|
-
params:
|
205
|
+
params: list[Any] = [
|
196
206
|
job_id,
|
197
207
|
0 if instance_id is None else instance_id,
|
198
208
|
constants.DEFAULT_CONTAINER_NAME,
|
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,16 +1,7 @@
|
|
1
|
+
import logging
|
1
2
|
import pathlib
|
2
3
|
import textwrap
|
3
|
-
from typing import
|
4
|
-
Any,
|
5
|
-
Callable,
|
6
|
-
Dict,
|
7
|
-
List,
|
8
|
-
Literal,
|
9
|
-
Optional,
|
10
|
-
TypeVar,
|
11
|
-
Union,
|
12
|
-
overload,
|
13
|
-
)
|
4
|
+
from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
|
14
5
|
from uuid import uuid4
|
15
6
|
|
16
7
|
import yaml
|
@@ -23,13 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
|
|
23
14
|
from snowflake.snowpark.context import get_active_session
|
24
15
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
25
16
|
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
26
19
|
_PROJECT = "MLJob"
|
27
20
|
JOB_ID_PREFIX = "MLJOB_"
|
28
21
|
|
29
22
|
T = TypeVar("T")
|
30
23
|
|
31
24
|
|
32
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
33
25
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
|
34
26
|
def list_jobs(
|
35
27
|
limit: int = 10,
|
@@ -69,7 +61,6 @@ def list_jobs(
|
|
69
61
|
return df
|
70
62
|
|
71
63
|
|
72
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
73
64
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
74
65
|
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
|
75
66
|
"""Retrieve a job service from the backend."""
|
@@ -93,7 +84,6 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
|
|
93
84
|
raise
|
94
85
|
|
95
86
|
|
96
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
97
87
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
98
88
|
def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
|
99
89
|
"""Delete a job service from the backend. Status and logs will be lost."""
|
@@ -106,19 +96,18 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
|
|
106
96
|
session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
107
97
|
|
108
98
|
|
109
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
110
99
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
111
100
|
def submit_file(
|
112
101
|
file_path: str,
|
113
102
|
compute_pool: str,
|
114
103
|
*,
|
115
104
|
stage_name: str,
|
116
|
-
args: Optional[
|
117
|
-
env_vars: Optional[
|
118
|
-
pip_requirements: Optional[
|
119
|
-
external_access_integrations: Optional[
|
105
|
+
args: Optional[list[str]] = None,
|
106
|
+
env_vars: Optional[dict[str, str]] = None,
|
107
|
+
pip_requirements: Optional[list[str]] = None,
|
108
|
+
external_access_integrations: Optional[list[str]] = None,
|
120
109
|
query_warehouse: Optional[str] = None,
|
121
|
-
spec_overrides: Optional[
|
110
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
122
111
|
num_instances: Optional[int] = None,
|
123
112
|
enable_metrics: bool = False,
|
124
113
|
session: Optional[snowpark.Session] = None,
|
@@ -159,7 +148,6 @@ def submit_file(
|
|
159
148
|
)
|
160
149
|
|
161
150
|
|
162
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
163
151
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
164
152
|
def submit_directory(
|
165
153
|
dir_path: str,
|
@@ -167,12 +155,12 @@ def submit_directory(
|
|
167
155
|
*,
|
168
156
|
entrypoint: str,
|
169
157
|
stage_name: str,
|
170
|
-
args: Optional[
|
171
|
-
env_vars: Optional[
|
172
|
-
pip_requirements: Optional[
|
173
|
-
external_access_integrations: Optional[
|
158
|
+
args: Optional[list[str]] = None,
|
159
|
+
env_vars: Optional[dict[str, str]] = None,
|
160
|
+
pip_requirements: Optional[list[str]] = None,
|
161
|
+
external_access_integrations: Optional[list[str]] = None,
|
174
162
|
query_warehouse: Optional[str] = None,
|
175
|
-
spec_overrides: Optional[
|
163
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
176
164
|
num_instances: Optional[int] = None,
|
177
165
|
enable_metrics: bool = False,
|
178
166
|
session: Optional[snowpark.Session] = None,
|
@@ -222,12 +210,12 @@ def _submit_job(
|
|
222
210
|
*,
|
223
211
|
stage_name: str,
|
224
212
|
entrypoint: Optional[str] = None,
|
225
|
-
args: Optional[
|
226
|
-
env_vars: Optional[
|
227
|
-
pip_requirements: Optional[
|
228
|
-
external_access_integrations: Optional[
|
213
|
+
args: Optional[list[str]] = None,
|
214
|
+
env_vars: Optional[dict[str, str]] = None,
|
215
|
+
pip_requirements: Optional[list[str]] = None,
|
216
|
+
external_access_integrations: Optional[list[str]] = None,
|
229
217
|
query_warehouse: Optional[str] = None,
|
230
|
-
spec_overrides: Optional[
|
218
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
231
219
|
num_instances: Optional[int] = None,
|
232
220
|
enable_metrics: bool = False,
|
233
221
|
session: Optional[snowpark.Session] = None,
|
@@ -242,12 +230,12 @@ def _submit_job(
|
|
242
230
|
*,
|
243
231
|
stage_name: str,
|
244
232
|
entrypoint: Optional[str] = None,
|
245
|
-
args: Optional[
|
246
|
-
env_vars: Optional[
|
247
|
-
pip_requirements: Optional[
|
248
|
-
external_access_integrations: Optional[
|
233
|
+
args: Optional[list[str]] = None,
|
234
|
+
env_vars: Optional[dict[str, str]] = None,
|
235
|
+
pip_requirements: Optional[list[str]] = None,
|
236
|
+
external_access_integrations: Optional[list[str]] = None,
|
249
237
|
query_warehouse: Optional[str] = None,
|
250
|
-
spec_overrides: Optional[
|
238
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
251
239
|
num_instances: Optional[int] = None,
|
252
240
|
enable_metrics: bool = False,
|
253
241
|
session: Optional[snowpark.Session] = None,
|
@@ -263,6 +251,8 @@ def _submit_job(
|
|
263
251
|
# TODO: Log lengths of args, env_vars, and spec_overrides values
|
264
252
|
"pip_requirements",
|
265
253
|
"external_access_integrations",
|
254
|
+
"num_instances",
|
255
|
+
"enable_metrics",
|
266
256
|
],
|
267
257
|
)
|
268
258
|
def _submit_job(
|
@@ -271,12 +261,12 @@ def _submit_job(
|
|
271
261
|
*,
|
272
262
|
stage_name: str,
|
273
263
|
entrypoint: Optional[str] = None,
|
274
|
-
args: Optional[
|
275
|
-
env_vars: Optional[
|
276
|
-
pip_requirements: Optional[
|
277
|
-
external_access_integrations: Optional[
|
264
|
+
args: Optional[list[str]] = None,
|
265
|
+
env_vars: Optional[dict[str, str]] = None,
|
266
|
+
pip_requirements: Optional[list[str]] = None,
|
267
|
+
external_access_integrations: Optional[list[str]] = None,
|
278
268
|
query_warehouse: Optional[str] = None,
|
279
|
-
spec_overrides: Optional[
|
269
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
280
270
|
num_instances: Optional[int] = None,
|
281
271
|
enable_metrics: bool = False,
|
282
272
|
session: Optional[snowpark.Session] = None,
|
@@ -305,6 +295,12 @@ def _submit_job(
|
|
305
295
|
Raises:
|
306
296
|
RuntimeError: If required Snowflake features are not enabled.
|
307
297
|
"""
|
298
|
+
# Display warning about PrPr parameters
|
299
|
+
if num_instances is not None:
|
300
|
+
logger.warning(
|
301
|
+
"_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
|
302
|
+
)
|
303
|
+
|
308
304
|
session = session or get_active_session()
|
309
305
|
job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
|
310
306
|
stage_name = "@" + stage_name.lstrip("@").rstrip("/")
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from datetime import datetime
|
3
|
-
from typing import TYPE_CHECKING,
|
3
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal import telemetry
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
12
12
|
from snowflake.ml.model._client.model import model_version_impl
|
13
13
|
|
14
14
|
_PROJECT = "LINEAGE"
|
15
|
-
DOMAIN_LINEAGE_REGISTRY:
|
15
|
+
DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
|
16
16
|
|
17
17
|
|
18
18
|
class LineageNode:
|
@@ -87,8 +87,8 @@ class LineageNode:
|
|
87
87
|
def lineage(
|
88
88
|
self,
|
89
89
|
direction: Literal["upstream", "downstream"] = "downstream",
|
90
|
-
domain_filter: Optional[
|
91
|
-
) ->
|
90
|
+
domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
|
91
|
+
) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
|
92
92
|
"""
|
93
93
|
Retrieves the lineage nodes connected to this node.
|
94
94
|
|
@@ -109,7 +109,7 @@ class LineageNode:
|
|
109
109
|
if domain_filter is not None:
|
110
110
|
domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
|
111
111
|
|
112
|
-
lineage_nodes:
|
112
|
+
lineage_nodes: list["LineageNode"] = []
|
113
113
|
for row in df.collect():
|
114
114
|
lineage_object = (
|
115
115
|
json.loads(row["TARGET_OBJECT"])
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
@@ -224,7 +224,7 @@ class Model:
|
|
224
224
|
project=_TELEMETRY_PROJECT,
|
225
225
|
subproject=_TELEMETRY_SUBPROJECT,
|
226
226
|
)
|
227
|
-
def versions(self) ->
|
227
|
+
def versions(self) -> list[model_version_impl.ModelVersion]:
|
228
228
|
"""Get all versions in the model.
|
229
229
|
|
230
230
|
Returns:
|
@@ -298,7 +298,7 @@ class Model:
|
|
298
298
|
project=_TELEMETRY_PROJECT,
|
299
299
|
subproject=_TELEMETRY_SUBPROJECT,
|
300
300
|
)
|
301
|
-
def show_tags(self) ->
|
301
|
+
def show_tags(self) -> dict[str, str]:
|
302
302
|
"""Get a dictionary showing the tag and its value attached to the model.
|
303
303
|
|
304
304
|
Returns:
|
@@ -2,10 +2,11 @@ import enum
|
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
4
|
import warnings
|
5
|
-
from typing import Any, Callable,
|
5
|
+
from typing import Any, Callable, Optional, Union, overload
|
6
6
|
|
7
7
|
import pandas as pd
|
8
8
|
|
9
|
+
from snowflake import snowpark
|
9
10
|
from snowflake.ml._internal import telemetry
|
10
11
|
from snowflake.ml._internal.utils import sql_identifier
|
11
12
|
from snowflake.ml.lineage import lineage_node
|
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
32
33
|
_service_ops: service_ops.ServiceOperator
|
33
34
|
_model_name: sql_identifier.SqlIdentifier
|
34
35
|
_version_name: sql_identifier.SqlIdentifier
|
35
|
-
_functions:
|
36
|
+
_functions: list[model_manifest_schema.ModelFunctionInfo]
|
36
37
|
|
37
38
|
def __init__(self) -> None:
|
38
39
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
152
153
|
project=_TELEMETRY_PROJECT,
|
153
154
|
subproject=_TELEMETRY_SUBPROJECT,
|
154
155
|
)
|
155
|
-
def show_metrics(self) ->
|
156
|
+
def show_metrics(self) -> dict[str, Any]:
|
156
157
|
"""Show all metrics logged with the model version.
|
157
158
|
|
158
159
|
Returns:
|
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
293
294
|
statement_params=statement_params,
|
294
295
|
)
|
295
296
|
|
296
|
-
def _get_functions(self) ->
|
297
|
+
def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
297
298
|
statement_params = telemetry.get_statement_params(
|
298
299
|
project=_TELEMETRY_PROJECT,
|
299
300
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
327
328
|
project=_TELEMETRY_PROJECT,
|
328
329
|
subproject=_TELEMETRY_SUBPROJECT,
|
329
330
|
)
|
330
|
-
def show_functions(self) ->
|
331
|
+
def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
|
331
332
|
"""Show all functions information in a model version that is callable.
|
332
333
|
|
333
334
|
Returns:
|
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
405
406
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
406
407
|
type validation to make sure your input data won't overflow when providing to the model.
|
407
408
|
|
408
|
-
Raises:
|
409
|
-
ValueError: When no method with the corresponding name is available.
|
410
|
-
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
411
|
-
ValueError: When the partition column is not a valid Snowflake identifier.
|
412
|
-
|
413
409
|
Returns:
|
414
410
|
The prediction data. It would be the same type dataframe as your input.
|
415
411
|
"""
|
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
422
418
|
# Partition column must be a valid identifier
|
423
419
|
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
424
420
|
|
425
|
-
|
426
|
-
|
427
|
-
if function_name:
|
428
|
-
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
429
|
-
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
430
|
-
lambda method: method["name"] == req_method_name
|
431
|
-
)
|
432
|
-
target_function_info = next(
|
433
|
-
filter(find_method, functions),
|
434
|
-
None,
|
435
|
-
)
|
436
|
-
if target_function_info is None:
|
437
|
-
raise ValueError(
|
438
|
-
f"There is no method with name {function_name} available in the model"
|
439
|
-
f" {self.fully_qualified_model_name} version {self.version_name}"
|
440
|
-
)
|
441
|
-
elif len(functions) != 1:
|
442
|
-
raise ValueError(
|
443
|
-
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
444
|
-
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
445
|
-
)
|
446
|
-
else:
|
447
|
-
target_function_info = functions[0]
|
421
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
448
422
|
|
449
423
|
if service_name:
|
450
424
|
database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
|
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
|
|
475
449
|
is_partitioned=target_function_info["is_partitioned"],
|
476
450
|
)
|
477
451
|
|
452
|
+
def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
|
453
|
+
functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
|
454
|
+
|
455
|
+
if function_name:
|
456
|
+
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
457
|
+
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
458
|
+
lambda method: method["name"] == req_method_name
|
459
|
+
)
|
460
|
+
target_function_info = next(
|
461
|
+
filter(find_method, functions),
|
462
|
+
None,
|
463
|
+
)
|
464
|
+
if target_function_info is None:
|
465
|
+
raise ValueError(
|
466
|
+
f"There is no method with name {function_name} available in the model"
|
467
|
+
f" {self.fully_qualified_model_name} version {self.version_name}"
|
468
|
+
)
|
469
|
+
elif len(functions) != 1:
|
470
|
+
raise ValueError(
|
471
|
+
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
472
|
+
f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
|
473
|
+
)
|
474
|
+
else:
|
475
|
+
target_function_info = functions[0]
|
476
|
+
|
477
|
+
return target_function_info
|
478
|
+
|
478
479
|
@telemetry.send_api_usage_telemetry(
|
479
480
|
project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
|
480
481
|
)
|
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
684
685
|
num_workers: Optional[int] = None,
|
685
686
|
max_batch_rows: Optional[int] = None,
|
686
687
|
force_rebuild: bool = False,
|
687
|
-
build_external_access_integrations: Optional[
|
688
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
688
689
|
block: bool = True,
|
689
690
|
) -> Union[str, async_job.AsyncJob]:
|
690
691
|
"""Create an inference service with the given spec.
|
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
751
752
|
max_batch_rows: Optional[int] = None,
|
752
753
|
force_rebuild: bool = False,
|
753
754
|
build_external_access_integration: Optional[str] = None,
|
754
|
-
build_external_access_integrations: Optional[
|
755
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
755
756
|
block: bool = True,
|
756
757
|
) -> Union[str, async_job.AsyncJob]:
|
757
758
|
"""Create an inference service with the given spec.
|
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
|
|
914
915
|
statement_params=statement_params,
|
915
916
|
)
|
916
917
|
|
918
|
+
@snowpark._internal.utils.private_preview(version="1.8.3")
|
919
|
+
@telemetry.send_api_usage_telemetry(
|
920
|
+
project=_TELEMETRY_PROJECT,
|
921
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
922
|
+
)
|
923
|
+
def run_job(
|
924
|
+
self,
|
925
|
+
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
926
|
+
*,
|
927
|
+
job_name: str,
|
928
|
+
compute_pool: str,
|
929
|
+
image_repo: str,
|
930
|
+
output_table_name: str,
|
931
|
+
function_name: Optional[str] = None,
|
932
|
+
cpu_requests: Optional[str] = None,
|
933
|
+
memory_requests: Optional[str] = None,
|
934
|
+
gpu_requests: Optional[Union[str, int]] = None,
|
935
|
+
num_workers: Optional[int] = None,
|
936
|
+
max_batch_rows: Optional[int] = None,
|
937
|
+
force_rebuild: bool = False,
|
938
|
+
build_external_access_integrations: Optional[list[str]] = None,
|
939
|
+
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
940
|
+
statement_params = telemetry.get_statement_params(
|
941
|
+
project=_TELEMETRY_PROJECT,
|
942
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
943
|
+
)
|
944
|
+
target_function_info = self._get_function_info(function_name=function_name)
|
945
|
+
job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
|
946
|
+
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
947
|
+
output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
|
948
|
+
output_table_name
|
949
|
+
)
|
950
|
+
warehouse = self._service_ops._session.get_current_warehouse()
|
951
|
+
assert warehouse, "No active warehouse selected in the current session."
|
952
|
+
return self._service_ops.invoke_job_method(
|
953
|
+
target_method=target_function_info["target_method"],
|
954
|
+
signature=target_function_info["signature"],
|
955
|
+
X=X,
|
956
|
+
database_name=None,
|
957
|
+
schema_name=None,
|
958
|
+
model_name=self._model_name,
|
959
|
+
version_name=self._version_name,
|
960
|
+
job_database_name=job_db_id,
|
961
|
+
job_schema_name=job_schema_id,
|
962
|
+
job_name=job_id,
|
963
|
+
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
964
|
+
warehouse_name=sql_identifier.SqlIdentifier(warehouse),
|
965
|
+
image_repo_database_name=image_repo_db_id,
|
966
|
+
image_repo_schema_name=image_repo_schema_id,
|
967
|
+
image_repo_name=image_repo_id,
|
968
|
+
output_table_database_name=output_table_db_id,
|
969
|
+
output_table_schema_name=output_table_schema_id,
|
970
|
+
output_table_name=output_table_id,
|
971
|
+
cpu_requests=cpu_requests,
|
972
|
+
memory_requests=memory_requests,
|
973
|
+
gpu_requests=gpu_requests,
|
974
|
+
num_workers=num_workers,
|
975
|
+
max_batch_rows=max_batch_rows,
|
976
|
+
force_rebuild=force_rebuild,
|
977
|
+
build_external_access_integrations=(
|
978
|
+
None
|
979
|
+
if build_external_access_integrations is None
|
980
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
981
|
+
),
|
982
|
+
statement_params=statement_params,
|
983
|
+
)
|
984
|
+
|
917
985
|
|
918
986
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, TypedDict
|
3
3
|
|
4
4
|
from typing_extensions import NotRequired
|
5
5
|
|
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
|
|
14
14
|
|
15
15
|
|
16
16
|
class ModelVersionMetadataSchema(TypedDict):
|
17
|
-
metrics: NotRequired[
|
17
|
+
metrics: NotRequired[dict[str, Any]]
|
18
18
|
|
19
19
|
|
20
20
|
class MetadataOperator:
|
@@ -44,7 +44,7 @@ class MetadataOperator:
|
|
44
44
|
)
|
45
45
|
|
46
46
|
@staticmethod
|
47
|
-
def _parse(metadata_dict:
|
47
|
+
def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
|
48
48
|
loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
|
49
49
|
if loaded_metadata_schema_version is None:
|
50
50
|
return ModelVersionMetadataSchema(metrics={})
|
@@ -65,8 +65,8 @@ class MetadataOperator:
|
|
65
65
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
66
66
|
model_name: sql_identifier.SqlIdentifier,
|
67
67
|
version_name: sql_identifier.SqlIdentifier,
|
68
|
-
statement_params: Optional[
|
69
|
-
) ->
|
68
|
+
statement_params: Optional[dict[str, Any]] = None,
|
69
|
+
) -> dict[str, Any]:
|
70
70
|
version_info_list = self._model_client.show_versions(
|
71
71
|
database_name=database_name,
|
72
72
|
schema_name=schema_name,
|
@@ -89,7 +89,7 @@ class MetadataOperator:
|
|
89
89
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
90
90
|
model_name: sql_identifier.SqlIdentifier,
|
91
91
|
version_name: sql_identifier.SqlIdentifier,
|
92
|
-
statement_params: Optional[
|
92
|
+
statement_params: Optional[dict[str, Any]] = None,
|
93
93
|
) -> ModelVersionMetadataSchema:
|
94
94
|
metadata_dict = self._get_current_metadata_dict(
|
95
95
|
database_name=database_name,
|
@@ -108,7 +108,7 @@ class MetadataOperator:
|
|
108
108
|
schema_name: Optional[sql_identifier.SqlIdentifier],
|
109
109
|
model_name: sql_identifier.SqlIdentifier,
|
110
110
|
version_name: sql_identifier.SqlIdentifier,
|
111
|
-
statement_params: Optional[
|
111
|
+
statement_params: Optional[dict[str, Any]] = None,
|
112
112
|
) -> None:
|
113
113
|
metadata_dict = self._get_current_metadata_dict(
|
114
114
|
database_name=database_name,
|