snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py
CHANGED
@@ -1,20 +1,32 @@
|
|
1
1
|
import time
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
|
3
|
+
|
4
|
+
import yaml
|
3
5
|
|
4
6
|
from snowflake import snowpark
|
5
7
|
from snowflake.ml._internal import telemetry
|
6
|
-
from snowflake.ml.jobs._utils import constants, types
|
8
|
+
from snowflake.ml.jobs._utils import constants, interop_utils, types
|
7
9
|
from snowflake.snowpark import context as sp_context
|
8
10
|
|
9
11
|
_PROJECT = "MLJob"
|
10
12
|
TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
|
11
13
|
|
14
|
+
T = TypeVar("T")
|
15
|
+
|
12
16
|
|
13
|
-
class MLJob:
|
14
|
-
def __init__(
|
17
|
+
class MLJob(Generic[T]):
|
18
|
+
def __init__(
|
19
|
+
self,
|
20
|
+
id: str,
|
21
|
+
service_spec: Optional[dict[str, Any]] = None,
|
22
|
+
session: Optional[snowpark.Session] = None,
|
23
|
+
) -> None:
|
15
24
|
self._id = id
|
25
|
+
self._service_spec_cached: Optional[dict[str, Any]] = service_spec
|
16
26
|
self._session = session or sp_context.get_active_session()
|
27
|
+
|
17
28
|
self._status: types.JOB_STATUS = "PENDING"
|
29
|
+
self._result: Optional[interop_utils.ExecutionResult] = None
|
18
30
|
|
19
31
|
@property
|
20
32
|
def id(self) -> str:
|
@@ -29,33 +41,76 @@ class MLJob:
|
|
29
41
|
self._status = _get_status(self._session, self.id)
|
30
42
|
return self._status
|
31
43
|
|
32
|
-
@
|
33
|
-
def
|
44
|
+
@property
|
45
|
+
def _service_spec(self) -> dict[str, Any]:
|
46
|
+
"""Get the job's service spec."""
|
47
|
+
if not self._service_spec_cached:
|
48
|
+
self._service_spec_cached = _get_service_spec(self._session, self.id)
|
49
|
+
return self._service_spec_cached
|
50
|
+
|
51
|
+
@property
|
52
|
+
def _container_spec(self) -> dict[str, Any]:
|
53
|
+
"""Get the job's main container spec."""
|
54
|
+
containers = self._service_spec["spec"]["containers"]
|
55
|
+
container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
|
56
|
+
return cast(dict[str, Any], container_spec)
|
57
|
+
|
58
|
+
@property
|
59
|
+
def _stage_path(self) -> str:
|
60
|
+
"""Get the job's artifact storage stage location."""
|
61
|
+
volumes = self._service_spec["spec"]["volumes"]
|
62
|
+
stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
|
63
|
+
return cast(str, stage_path)
|
64
|
+
|
65
|
+
@property
|
66
|
+
def _result_path(self) -> str:
|
67
|
+
"""Get the job's result file location."""
|
68
|
+
result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
|
69
|
+
if result_path is None:
|
70
|
+
raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
|
71
|
+
return f"{self._stage_path}/{result_path}"
|
72
|
+
|
73
|
+
@overload
|
74
|
+
def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
|
75
|
+
...
|
76
|
+
|
77
|
+
@overload
|
78
|
+
def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
|
79
|
+
...
|
80
|
+
|
81
|
+
def get_logs(
|
82
|
+
self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
|
83
|
+
) -> Union[str, list[str]]:
|
34
84
|
"""
|
35
85
|
Return the job's execution logs.
|
36
86
|
|
37
87
|
Args:
|
38
88
|
limit: The maximum number of lines to return. Negative values are treated as no limit.
|
89
|
+
instance_id: Optional instance ID to get logs from a specific instance.
|
90
|
+
If not provided, returns logs from the head node.
|
91
|
+
as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
|
39
92
|
|
40
93
|
Returns:
|
41
94
|
The job's execution logs.
|
42
95
|
"""
|
43
|
-
logs = _get_logs(self._session, self.id, limit)
|
96
|
+
logs = _get_logs(self._session, self.id, limit, instance_id)
|
44
97
|
assert isinstance(logs, str) # mypy
|
98
|
+
if as_list:
|
99
|
+
return logs.splitlines()
|
45
100
|
return logs
|
46
101
|
|
47
|
-
|
48
|
-
def show_logs(self, limit: int = -1) -> None:
|
102
|
+
def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
|
49
103
|
"""
|
50
104
|
Display the job's execution logs.
|
51
105
|
|
52
106
|
Args:
|
53
107
|
limit: The maximum number of lines to display. Negative values are treated as no limit.
|
108
|
+
instance_id: Optional instance ID to get logs from a specific instance.
|
109
|
+
If not provided, displays logs from the head node.
|
54
110
|
"""
|
55
|
-
print(self.get_logs(limit)) # noqa: T201: we need to print here.
|
111
|
+
print(self.get_logs(limit, instance_id, as_list=False)) # noqa: T201: we need to print here.
|
56
112
|
|
57
|
-
@
|
58
|
-
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
113
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
|
59
114
|
def wait(self, timeout: float = -1) -> types.JOB_STATUS:
|
60
115
|
"""
|
61
116
|
Block until completion. Returns completion status.
|
@@ -78,20 +133,58 @@ class MLJob:
|
|
78
133
|
delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
|
79
134
|
return self.status
|
80
135
|
|
136
|
+
@snowpark._internal.utils.private_preview(version="1.8.2")
|
137
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
|
138
|
+
def result(self, timeout: float = -1) -> T:
|
139
|
+
"""
|
140
|
+
Block until completion. Returns job execution result.
|
141
|
+
|
142
|
+
Args:
|
143
|
+
timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
|
144
|
+
|
145
|
+
Returns:
|
146
|
+
T: The deserialized job result. # noqa: DAR401
|
147
|
+
|
148
|
+
Raises:
|
149
|
+
RuntimeError: If the job failed or if the job doesn't have a result to retrieve.
|
150
|
+
TimeoutError: If the job does not complete within the specified timeout. # noqa: DAR402
|
151
|
+
"""
|
152
|
+
if self._result is None:
|
153
|
+
self.wait(timeout)
|
154
|
+
try:
|
155
|
+
self._result = interop_utils.fetch_result(self._session, self._result_path)
|
156
|
+
except Exception as e:
|
157
|
+
raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
|
158
|
+
|
159
|
+
if self._result.success:
|
160
|
+
return cast(T, self._result.result)
|
161
|
+
raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
|
162
|
+
|
163
|
+
|
164
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
|
165
|
+
def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
|
166
|
+
"""Retrieve job or job instance execution status."""
|
167
|
+
if instance_id is not None:
|
168
|
+
# Get specific instance status
|
169
|
+
rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
170
|
+
for row in rows:
|
171
|
+
if row["instance_id"] == str(instance_id):
|
172
|
+
return cast(types.JOB_STATUS, row["status"])
|
173
|
+
raise ValueError(f"Instance {instance_id} not found in job {job_id}")
|
174
|
+
else:
|
175
|
+
(row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
176
|
+
return cast(types.JOB_STATUS, row["status"])
|
177
|
+
|
81
178
|
|
82
179
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
83
|
-
def
|
84
|
-
"""Retrieve job execution
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
|
94
|
-
def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
|
180
|
+
def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
|
181
|
+
"""Retrieve job execution service spec."""
|
182
|
+
(row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
|
183
|
+
return cast(dict[str, Any], yaml.safe_load(row["spec"]))
|
184
|
+
|
185
|
+
|
186
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
|
187
|
+
def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
|
95
188
|
"""
|
96
189
|
Retrieve the job's execution logs.
|
97
190
|
|
@@ -99,15 +192,54 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
|
|
99
192
|
job_id: The job ID.
|
100
193
|
limit: The maximum number of lines to return. Negative values are treated as no limit.
|
101
194
|
session: The Snowpark session to use. If none specified, uses active session.
|
195
|
+
instance_id: Optional instance ID to get logs from a specific instance.
|
102
196
|
|
103
197
|
Returns:
|
104
198
|
The job's execution logs.
|
105
199
|
"""
|
106
|
-
|
200
|
+
# If instance_id is not specified, try to get the head instance ID
|
201
|
+
if instance_id is None:
|
202
|
+
instance_id = _get_head_instance_id(session, job_id)
|
203
|
+
|
204
|
+
# Assemble params: [job_id, instance_id, container_name, (optional) limit]
|
205
|
+
params: list[Any] = [
|
206
|
+
job_id,
|
207
|
+
0 if instance_id is None else instance_id,
|
208
|
+
constants.DEFAULT_CONTAINER_NAME,
|
209
|
+
]
|
107
210
|
if limit > 0:
|
108
211
|
params.append(limit)
|
212
|
+
|
109
213
|
(row,) = session.sql(
|
110
|
-
f"SELECT SYSTEM$GET_SERVICE_LOGS(?,
|
214
|
+
f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
|
111
215
|
params=params,
|
112
216
|
).collect()
|
113
217
|
return str(row[0])
|
218
|
+
|
219
|
+
|
220
|
+
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
|
221
|
+
def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
|
222
|
+
"""
|
223
|
+
Retrieve the head instance ID of a job.
|
224
|
+
|
225
|
+
Args:
|
226
|
+
session: The Snowpark session to use.
|
227
|
+
job_id: The job ID.
|
228
|
+
|
229
|
+
Returns:
|
230
|
+
The head instance ID of the job. Returns None if the head instance has not started yet.
|
231
|
+
"""
|
232
|
+
rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
233
|
+
if not rows:
|
234
|
+
return None
|
235
|
+
|
236
|
+
# Sort by start_time first, then by instance_id
|
237
|
+
sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
|
238
|
+
head_instance = sorted_instances[0]
|
239
|
+
if not head_instance["start_time"]:
|
240
|
+
# If head instance hasn't started yet, return None
|
241
|
+
return None
|
242
|
+
try:
|
243
|
+
return int(head_instance["instance_id"])
|
244
|
+
except (ValueError, TypeError):
|
245
|
+
return 0
|
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
|
+
import logging
|
1
2
|
import pathlib
|
2
3
|
import textwrap
|
3
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
|
4
5
|
from uuid import uuid4
|
5
6
|
|
6
7
|
import yaml
|
@@ -13,11 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
|
|
13
14
|
from snowflake.snowpark.context import get_active_session
|
14
15
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
15
16
|
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
16
19
|
_PROJECT = "MLJob"
|
17
20
|
JOB_ID_PREFIX = "MLJOB_"
|
18
21
|
|
22
|
+
T = TypeVar("T")
|
23
|
+
|
19
24
|
|
20
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
21
25
|
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
|
22
26
|
def list_jobs(
|
23
27
|
limit: int = 10,
|
@@ -57,9 +61,8 @@ def list_jobs(
|
|
57
61
|
return df
|
58
62
|
|
59
63
|
|
60
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
61
64
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
62
|
-
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
|
65
|
+
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
|
63
66
|
"""Retrieve a job service from the backend."""
|
64
67
|
session = session or get_active_session()
|
65
68
|
|
@@ -71,7 +74,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
|
|
71
74
|
|
72
75
|
try:
|
73
76
|
# Validate that job exists by doing a status check
|
74
|
-
|
77
|
+
# FIXME: Retrieve return path
|
78
|
+
job = jb.MLJob[Any](job_id, session=session)
|
75
79
|
_ = job.status
|
76
80
|
return job
|
77
81
|
except SnowparkSQLException as e:
|
@@ -80,9 +84,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
|
|
80
84
|
raise
|
81
85
|
|
82
86
|
|
83
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
84
87
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
85
|
-
def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
|
88
|
+
def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
|
86
89
|
"""Delete a job service from the backend. Status and logs will be lost."""
|
87
90
|
if isinstance(job, jb.MLJob):
|
88
91
|
job_id = job.id
|
@@ -93,23 +96,22 @@ def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] =
|
|
93
96
|
session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
|
94
97
|
|
95
98
|
|
96
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
97
99
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
98
100
|
def submit_file(
|
99
101
|
file_path: str,
|
100
102
|
compute_pool: str,
|
101
103
|
*,
|
102
104
|
stage_name: str,
|
103
|
-
args: Optional[
|
104
|
-
env_vars: Optional[
|
105
|
-
pip_requirements: Optional[
|
106
|
-
external_access_integrations: Optional[
|
105
|
+
args: Optional[list[str]] = None,
|
106
|
+
env_vars: Optional[dict[str, str]] = None,
|
107
|
+
pip_requirements: Optional[list[str]] = None,
|
108
|
+
external_access_integrations: Optional[list[str]] = None,
|
107
109
|
query_warehouse: Optional[str] = None,
|
108
|
-
spec_overrides: Optional[
|
110
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
109
111
|
num_instances: Optional[int] = None,
|
110
112
|
enable_metrics: bool = False,
|
111
113
|
session: Optional[snowpark.Session] = None,
|
112
|
-
) -> jb.MLJob:
|
114
|
+
) -> jb.MLJob[None]:
|
113
115
|
"""
|
114
116
|
Submit a Python file as a job to the compute pool.
|
115
117
|
|
@@ -146,7 +148,6 @@ def submit_file(
|
|
146
148
|
)
|
147
149
|
|
148
150
|
|
149
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
150
151
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
151
152
|
def submit_directory(
|
152
153
|
dir_path: str,
|
@@ -154,16 +155,16 @@ def submit_directory(
|
|
154
155
|
*,
|
155
156
|
entrypoint: str,
|
156
157
|
stage_name: str,
|
157
|
-
args: Optional[
|
158
|
-
env_vars: Optional[
|
159
|
-
pip_requirements: Optional[
|
160
|
-
external_access_integrations: Optional[
|
158
|
+
args: Optional[list[str]] = None,
|
159
|
+
env_vars: Optional[dict[str, str]] = None,
|
160
|
+
pip_requirements: Optional[list[str]] = None,
|
161
|
+
external_access_integrations: Optional[list[str]] = None,
|
161
162
|
query_warehouse: Optional[str] = None,
|
162
|
-
spec_overrides: Optional[
|
163
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
163
164
|
num_instances: Optional[int] = None,
|
164
165
|
enable_metrics: bool = False,
|
165
166
|
session: Optional[snowpark.Session] = None,
|
166
|
-
) -> jb.MLJob:
|
167
|
+
) -> jb.MLJob[None]:
|
167
168
|
"""
|
168
169
|
Submit a directory containing Python script(s) as a job to the compute pool.
|
169
170
|
|
@@ -202,6 +203,46 @@ def submit_directory(
|
|
202
203
|
)
|
203
204
|
|
204
205
|
|
206
|
+
@overload
|
207
|
+
def _submit_job(
|
208
|
+
source: str,
|
209
|
+
compute_pool: str,
|
210
|
+
*,
|
211
|
+
stage_name: str,
|
212
|
+
entrypoint: Optional[str] = None,
|
213
|
+
args: Optional[list[str]] = None,
|
214
|
+
env_vars: Optional[dict[str, str]] = None,
|
215
|
+
pip_requirements: Optional[list[str]] = None,
|
216
|
+
external_access_integrations: Optional[list[str]] = None,
|
217
|
+
query_warehouse: Optional[str] = None,
|
218
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
219
|
+
num_instances: Optional[int] = None,
|
220
|
+
enable_metrics: bool = False,
|
221
|
+
session: Optional[snowpark.Session] = None,
|
222
|
+
) -> jb.MLJob[None]:
|
223
|
+
...
|
224
|
+
|
225
|
+
|
226
|
+
@overload
|
227
|
+
def _submit_job(
|
228
|
+
source: Callable[..., T],
|
229
|
+
compute_pool: str,
|
230
|
+
*,
|
231
|
+
stage_name: str,
|
232
|
+
entrypoint: Optional[str] = None,
|
233
|
+
args: Optional[list[str]] = None,
|
234
|
+
env_vars: Optional[dict[str, str]] = None,
|
235
|
+
pip_requirements: Optional[list[str]] = None,
|
236
|
+
external_access_integrations: Optional[list[str]] = None,
|
237
|
+
query_warehouse: Optional[str] = None,
|
238
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
239
|
+
num_instances: Optional[int] = None,
|
240
|
+
enable_metrics: bool = False,
|
241
|
+
session: Optional[snowpark.Session] = None,
|
242
|
+
) -> jb.MLJob[T]:
|
243
|
+
...
|
244
|
+
|
245
|
+
|
205
246
|
@telemetry.send_api_usage_telemetry(
|
206
247
|
project=_PROJECT,
|
207
248
|
func_params_to_log=[
|
@@ -210,24 +251,26 @@ def submit_directory(
|
|
210
251
|
# TODO: Log lengths of args, env_vars, and spec_overrides values
|
211
252
|
"pip_requirements",
|
212
253
|
"external_access_integrations",
|
254
|
+
"num_instances",
|
255
|
+
"enable_metrics",
|
213
256
|
],
|
214
257
|
)
|
215
258
|
def _submit_job(
|
216
|
-
source: Union[str, Callable[...,
|
259
|
+
source: Union[str, Callable[..., T]],
|
217
260
|
compute_pool: str,
|
218
261
|
*,
|
219
262
|
stage_name: str,
|
220
263
|
entrypoint: Optional[str] = None,
|
221
|
-
args: Optional[
|
222
|
-
env_vars: Optional[
|
223
|
-
pip_requirements: Optional[
|
224
|
-
external_access_integrations: Optional[
|
264
|
+
args: Optional[list[str]] = None,
|
265
|
+
env_vars: Optional[dict[str, str]] = None,
|
266
|
+
pip_requirements: Optional[list[str]] = None,
|
267
|
+
external_access_integrations: Optional[list[str]] = None,
|
225
268
|
query_warehouse: Optional[str] = None,
|
226
|
-
spec_overrides: Optional[
|
269
|
+
spec_overrides: Optional[dict[str, Any]] = None,
|
227
270
|
num_instances: Optional[int] = None,
|
228
271
|
enable_metrics: bool = False,
|
229
272
|
session: Optional[snowpark.Session] = None,
|
230
|
-
) -> jb.MLJob:
|
273
|
+
) -> jb.MLJob[T]:
|
231
274
|
"""
|
232
275
|
Submit a job to the compute pool.
|
233
276
|
|
@@ -252,6 +295,12 @@ def _submit_job(
|
|
252
295
|
Raises:
|
253
296
|
RuntimeError: If required Snowflake features are not enabled.
|
254
297
|
"""
|
298
|
+
# Display warning about PrPr parameters
|
299
|
+
if num_instances is not None:
|
300
|
+
logger.warning(
|
301
|
+
"_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
|
302
|
+
)
|
303
|
+
|
255
304
|
session = session or get_active_session()
|
256
305
|
job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
|
257
306
|
stage_name = "@" + stage_name.lstrip("@").rstrip("/")
|
@@ -314,5 +363,4 @@ def _submit_job(
|
|
314
363
|
) from e
|
315
364
|
raise
|
316
365
|
|
317
|
-
|
318
|
-
return jb.MLJob(job_id, session=session)
|
366
|
+
return jb.MLJob(job_id, service_spec=spec, session=session)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from datetime import datetime
|
3
|
-
from typing import TYPE_CHECKING,
|
3
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal import telemetry
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
12
12
|
from snowflake.ml.model._client.model import model_version_impl
|
13
13
|
|
14
14
|
_PROJECT = "LINEAGE"
|
15
|
-
DOMAIN_LINEAGE_REGISTRY:
|
15
|
+
DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
|
16
16
|
|
17
17
|
|
18
18
|
class LineageNode:
|
@@ -87,8 +87,8 @@ class LineageNode:
|
|
87
87
|
def lineage(
|
88
88
|
self,
|
89
89
|
direction: Literal["upstream", "downstream"] = "downstream",
|
90
|
-
domain_filter: Optional[
|
91
|
-
) ->
|
90
|
+
domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
|
91
|
+
) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
|
92
92
|
"""
|
93
93
|
Retrieves the lineage nodes connected to this node.
|
94
94
|
|
@@ -109,7 +109,7 @@ class LineageNode:
|
|
109
109
|
if domain_filter is not None:
|
110
110
|
domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
|
111
111
|
|
112
|
-
lineage_nodes:
|
112
|
+
lineage_nodes: list["LineageNode"] = []
|
113
113
|
for row in df.collect():
|
114
114
|
lineage_object = (
|
115
115
|
json.loads(row["TARGET_OBJECT"])
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
@@ -224,7 +224,7 @@ class Model:
|
|
224
224
|
project=_TELEMETRY_PROJECT,
|
225
225
|
subproject=_TELEMETRY_SUBPROJECT,
|
226
226
|
)
|
227
|
-
def versions(self) ->
|
227
|
+
def versions(self) -> list[model_version_impl.ModelVersion]:
|
228
228
|
"""Get all versions in the model.
|
229
229
|
|
230
230
|
Returns:
|
@@ -298,7 +298,7 @@ class Model:
|
|
298
298
|
project=_TELEMETRY_PROJECT,
|
299
299
|
subproject=_TELEMETRY_SUBPROJECT,
|
300
300
|
)
|
301
|
-
def show_tags(self) ->
|
301
|
+
def show_tags(self) -> dict[str, str]:
|
302
302
|
"""Get a dictionary showing the tag and its value attached to the model.
|
303
303
|
|
304
304
|
Returns:
|