snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/fileset/sfcfs.py
CHANGED
@@ -1,7 +1,7 @@
 import collections
 import logging
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
 
 import fsspec
 
@@ -100,7 +100,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
             raise ValueError("Either sf_connection or snowpark_session has to be non-empty!")
         self._conn = self._session._conn._conn  # Telemetry wrappers expect connection under `conn_attr_name="_conn"``
         self._kwargs = kwargs
-        self._stage_fs_set: Dict[Tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
+        self._stage_fs_set: dict[tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
 
         super().__init__(**kwargs)
 
@@ -133,7 +133,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         assert isinstance(session, snowpark.Session)
         return session
 
-    def __reduce__(self) -> Tuple[Callable[[], Type["SFFileSystem"]], Tuple[()], Dict[str, Any]]:
+    def __reduce__(self) -> tuple[Callable[[], type["SFFileSystem"]], tuple[()], dict[str, Any]]:
        """Returns a state dictionary for use in serialization.
 
         Returns:
@@ -145,7 +145,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
 
         return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary
 
-    def __setstate__(self, state_dict: Dict[str, Any]) -> None:
+    def __setstate__(self, state_dict: dict[str, Any]) -> None:
         """Sets the dictionary state at deserialization time, and rebuilds a snowflake connection.
 
         Args:
@@ -191,7 +191,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         func_params_to_log=["detail"],
         conn_attr_name="_conn",
     )
-    def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], List[Dict[str, Any]]]:
+    def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[list[str], list[dict[str, Any]]]:
         """Override fsspec `ls` method. List single "directory" with or without details.
 
         Args:
@@ -214,14 +214,14 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         file_path = self._parse_file_path(path)
         stage_fs = self._get_stage_fs(file_path)
         stage_path_list = stage_fs.ls(file_path.filepath, detail=True, **kwargs)
-        stage_path_list = cast(List[Dict[str, Any]], stage_path_list)
+        stage_path_list = cast(list[dict[str, Any]], stage_path_list)
         return self._decorate_ls_res(stage_fs, stage_path_list, detail)
 
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         conn_attr_name="_conn",
     )
-    def optimize_read(self, files: Optional[List[str]] = None) -> None:
+    def optimize_read(self, files: Optional[list[str]] = None) -> None:
         """Prefetch and cache the presigned urls for all the given files to speed up the file opening.
 
         All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -232,8 +232,8 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         """
         if not files:
             return
-        stage_fs_dict: Dict[str, stage_fs.SFStageFileSystem] = {}
-        stage_file_paths: Dict[str, List[str]] = collections.defaultdict(list)
+        stage_fs_dict: dict[str, stage_fs.SFStageFileSystem] = {}
+        stage_file_paths: dict[str, list[str]] = collections.defaultdict(list)
         for file in files:
             path_info = self._parse_file_path(file)
             fs = self._get_stage_fs(path_info)
@@ -271,11 +271,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         project=_PROJECT,
         conn_attr_name="_conn",
     )
-    def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
+    def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
         """Override fsspec `info` method. Give details of entry at path."""
         file_path = self._parse_file_path(path)
         stage_fs = self._get_stage_fs(file_path)
-        res: Dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
+        res: dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
         if res:
             res["name"] = self._stage_path_to_absolute_path(stage_fs, res["name"])
         return res
@@ -283,9 +283,9 @@ class SFFileSystem(fsspec.AbstractFileSystem):
     def _decorate_ls_res(
         self,
         stage_fs: stage_fs.SFStageFileSystem,
-        stage_path_list: List[Dict[str, Any]],
+        stage_path_list: list[dict[str, Any]],
         detail: bool,
-    ) -> Union[List[str], List[Dict[str, Any]]]:
+    ) -> Union[list[str], list[dict[str, Any]]]:
         """Add the stage location as the prefix of file names returned by ls() of stagefs"""
         for path in stage_path_list:
             path["name"] = self._stage_path_to_absolute_path(stage_fs, path["name"])
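Nearly every hunk in this release follows the same pattern: deprecated `typing.Dict`/`List`/`Tuple`/`Type` aliases are replaced by the builtin generics standardized in PEP 585 (Python 3.9+). A minimal before/after sketch, using a made-up function for illustration:

```python
from typing import Any, Union

# 1.8.2 style required extra imports:
#   from typing import Dict, List
#   def ls(path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]: ...

# 1.8.3 style: builtin generics carry the same meaning with no extra imports.
def ls(path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
    """Toy stand-in for the fsspec-style ls() signatures touched in this diff."""
    entries: list[dict[str, Any]] = [{"name": path, "type": "directory", "size": 0}]
    return entries if detail else [entry["name"] for entry in entries]
```

Runtime behavior is unchanged; only the annotations differ.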
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -2,7 +2,7 @@ import inspect
 import logging
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
 
 import fsspec
 from fsspec.implementations import http as httpfs
@@ -44,7 +44,7 @@ class _PresignedUrl:
         return not self.expire_at or time.time() > self.expire_at - headroom_sec
 
 
-def _get_httpfs_kwargs(**kwargs: Any) -> Dict[str, Any]:
+def _get_httpfs_kwargs(**kwargs: Any) -> dict[str, Any]:
     """Extract kwargs that are meaningful to HTTPFileSystem."""
     httpfs_related_keys = [
         "block_size",
@@ -124,7 +124,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         self._db = db
         self._schema = schema
         self._stage = stage
-        self._url_cache: Dict[str, _PresignedUrl] = {}
+        self._url_cache: dict[str, _PresignedUrl] = {}
 
         httpfs_kwargs = _get_httpfs_kwargs(**kwargs)
         self._fs = httpfs.HTTPFileSystem(**httpfs_kwargs)
@@ -145,7 +145,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         project=_PROJECT,
         func_params_to_log=["detail"],
     )
-    def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]:
+    def ls(self, path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
         """Override fsspec `ls` method. List single "directory" with or without details.
 
         Args:
@@ -169,7 +169,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
             loc = self.stage_name
             path = path.lstrip("/")
             async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
-            objects: List[snowpark.Row] = _resolve_async_job(async_job)
+            objects: list[snowpark.Row] = _resolve_async_job(async_job)
         except snowpark_exceptions.SnowparkSQLException as e:
             if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
                 raise snowml_exceptions.SnowflakeMLException(
@@ -192,7 +192,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
-    def optimize_read(self, files: Optional[List[str]] = None) -> None:
+    def optimize_read(self, files: Optional[list[str]] = None) -> None:
         """Prefetch and cache the presigned urls for all the given files to speed up the read performance.
 
         All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -271,7 +271,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
                 original_exception=fileset_errors.StageFileNotFoundError(f"Stage file {path} doesn't exist."),
             )
 
-    def _open_with_snowpark(self, path: str, **kwargs: Dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
+    def _open_with_snowpark(self, path: str, **kwargs: dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
         """Open the a file for reading using snowflake.snowpark.file_operation
 
         Args:
@@ -299,7 +299,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
                 original_exception=e,
             )
 
-    def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
+    def _parse_list_result(self, list_result: list[snowpark.Row], search_path: str) -> list[dict[str, Any]]:
         """Convert the result from LIST query to the expected format of fsspec ls() method.
 
         Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
@@ -318,7 +318,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         Returns:
             A list of dict, where each dict contains key-value pairs as the properties of a file.
         """
-        files: Dict[str, Dict[str, Any]] = {}
+        files: dict[str, dict[str, Any]] = {}
         search_path = search_path.strip("/")
         for row in list_result:
             name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
@@ -360,7 +360,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
 
     def _add_file_info_helper(
         self,
-        files: Dict[str, Dict[str, Any]],
+        files: dict[str, dict[str, Any]],
         object_path: str,
         file_size: int,
         file_type: str,
@@ -379,12 +379,12 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         )
 
     def _fetch_presigned_urls(
-        self, files: List[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
-    ) -> List[Tuple[str, str]]:
+        self, files: list[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
+    ) -> list[tuple[str, str]]:
         """Fetch presigned urls for the given files."""
         file_df = self._session.create_dataframe(files).to_df("name")
         try:
-            presigned_urls: List[Tuple[str, str]] = file_df.select_expr(
+            presigned_urls: list[tuple[str, str]] = file_df.select_expr(
                 f"name, get_presigned_url('{self.stage_name}', name, {url_lifetime}) as url"
             ).collect(
                 statement_params=telemetry.get_function_usage_statement_params(
@@ -418,10 +418,10 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code: int) -> bool:
 
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
-def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
+def _resolve_async_job(async_job: snowpark.AsyncJob) -> list[snowpark.Row]:
     # Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
     try:
-        query_result = cast(List[snowpark.Row], async_job.result("row"))
+        query_result = cast(list[snowpark.Row], async_job.result("row"))
         return query_result
     except snowpark_errors.DatabaseError as e:
         # HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
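The `_url_cache: dict[str, _PresignedUrl]` annotation above belongs to the presigned-URL cache whose expiry check is visible in the first hunk. A self-contained sketch of that headroom-based check (names beyond `expire_at`/`headroom_sec` are assumptions):

```python
import time
from dataclasses import dataclass

@dataclass
class PresignedUrl:
    # Simplified stand-in for stage_fs._PresignedUrl.
    url: str
    expire_at: float  # expiry time in epoch seconds; falsy means unknown

    def is_expiring(self, headroom_sec: float = 30.0) -> bool:
        # Mirrors the check in the diff context: report the URL as stale a bit
        # before its real expiry so an in-flight read doesn't hit a dead link.
        return not self.expire_at or time.time() > self.expire_at - headroom_sec

url_cache: dict[str, PresignedUrl] = {}  # keyed by stage file path, as in SFStageFileSystem
```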
snowflake/ml/jobs/_utils/interop_utils.py
CHANGED
@@ -10,7 +10,7 @@ import traceback
 from collections import namedtuple
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
 
 from snowflake import snowpark
 from snowflake.snowpark import exceptions as sp_exceptions
@@ -33,7 +33,7 @@ class ExecutionResult:
     def success(self) -> bool:
         return self.exception is None
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Return the serializable dictionary."""
         if isinstance(self.exception, BaseException):
             exc_type = type(self.exception)
@@ -50,7 +50,7 @@ class ExecutionResult:
         }
 
     @classmethod
-    def from_dict(cls, result_dict: Dict[str, Any]) -> "ExecutionResult":
+    def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
         if not isinstance(result_dict.get("success"), bool):
             raise ValueError("Invalid result dictionary")
 
@@ -242,11 +242,11 @@ def _install_sys_excepthook() -> None:
     original_excepthook = sys.excepthook
 
     def custom_excepthook(
-        exc_type: Type[BaseException],
+        exc_type: type[BaseException],
         exc_value: BaseException,
         exc_tb: Optional[TracebackType],
         *,
-        seen_exc_ids: Optional[Set[int]] = None,
+        seen_exc_ids: Optional[set[int]] = None,
     ) -> None:
         if seen_exc_ids is None:
             seen_exc_ids = set()
@@ -331,7 +331,7 @@ def _install_ipython_hook() -> bool:
     except ImportError:
         return False
 
-    def parse_traceback_str(traceback_str: str) -> List[Tuple[str, int, str, str]]:
+    def parse_traceback_str(traceback_str: str) -> list[tuple[str, int, str, str]]:
         return [
             (m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
             for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
@@ -342,13 +342,13 @@ def _install_ipython_hook() -> bool:
 
     def custom_format_exception_as_a_whole(
         self: VerboseTB,
-        etype: Type[BaseException],
+        etype: type[BaseException],
         evalue: Optional[BaseException],
         etb: Optional[TracebackType],
         number_of_lines_of_context: int,
         tb_offset: Optional[int],
         **kwargs: Any,
-    ) -> List[List[str]]:
+    ) -> list[list[str]]:
         if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
             # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
             head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
@@ -388,7 +388,7 @@ def _install_ipython_hook() -> bool:
         etb: Optional[TracebackType],
         tb_offset: Optional[int] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
             tb_list = [
                 (m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
@@ -400,7 +400,7 @@ def _install_ipython_hook() -> bool:
                 "(most recent call last)",
                 "(from remote execution)",
             )
-            return cast(List[str], out_list)
+            return cast(list[str], out_list)
         return original_structured_traceback(  # type: ignore[no-any-return]
             self, etype, evalue, etb, tb_offset, **kwargs
         )
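Beyond the annotation changes, interop_utils keeps its hook-chaining approach for surfacing remote errors: a custom hook is installed over `sys.excepthook` and delegates to the saved original. A minimal sketch of that pattern, with the remote-error rendering elided:

```python
import sys
from types import TracebackType
from typing import Optional

def install_excepthook() -> None:
    original_excepthook = sys.excepthook  # keep a reference so behavior can be chained

    def custom_excepthook(
        exc_type: type[BaseException],
        exc_value: BaseException,
        exc_tb: Optional[TracebackType],
        *,
        seen_exc_ids: Optional[set[int]] = None,
    ) -> None:
        if seen_exc_ids is None:
            seen_exc_ids = set()  # guards recursion through __cause__/__context__ chains
        # ... custom rendering of remote error payloads would go here ...
        original_excepthook(exc_type, exc_value, exc_tb)

    sys.excepthook = custom_excepthook
```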
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -6,17 +6,7 @@ import pickle
 import sys
 import textwrap
 from pathlib import Path, PurePath
-from typing import (
-    Any,
-    Callable,
-    List,
-    Optional,
-    Type,
-    Union,
-    cast,
-    get_args,
-    get_origin,
-)
+from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
 import cloudpickle as cp
 
@@ -277,7 +267,7 @@ class JobPayload:
         source: Union[str, Path, Callable[..., Any]],
         entrypoint: Optional[Union[str, Path]] = None,
         *,
-        pip_requirements: Optional[List[str]] = None,
+        pip_requirements: Optional[list[str]] = None,
     ) -> None:
         self.source = Path(source) if isinstance(source, str) else source
         self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
@@ -364,7 +354,7 @@ class JobPayload:
             auto_compress=False,
         )
 
-        python_entrypoint: List[Union[str, PurePath]] = [
+        python_entrypoint: list[Union[str, PurePath]] = [
             PurePath("mljob_launcher.py"),
             entrypoint.file_path.relative_to(source),
         ]
@@ -381,7 +371,7 @@ class JobPayload:
         )
 
 
-def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+def _get_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
     # Unwrap Optional type annotations
     param_type = param.annotation
     if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -390,10 +380,10 @@ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     # Return None for empty type annotations
     if param_type == inspect.Parameter.empty:
         return None
-    return cast(Type[object], param_type)
+    return cast(type[object], param_type)
 
 
-def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+def _validate_parameter_type(param_type: type[object], param_name: str) -> None:
     # Validate param_type is a supported type
     if param_type not in _SUPPORTED_ARG_TYPES:
         raise ValueError(
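The `_get_parameter_type` change is annotation-only, but the unwrapping rule it keeps is worth spelling out: `Optional[X]` is just `Union[X, None]`, so a two-argument Union containing `NoneType` collapses to `X`. A runnable sketch of the same rule (the helper name here is made up):

```python
import inspect
from typing import Optional, Union, get_args, get_origin

def unwrap_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
    param_type = param.annotation
    # Optional[X] == Union[X, None]: strip the NoneType arm if present.
    if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
        param_type = next(t for t in get_args(param_type) if t is not type(None))
    if param_type is inspect.Parameter.empty:
        return None  # unannotated parameter
    return param_type

def f(x: Optional[int] = None) -> None: ...
assert unwrap_parameter_type(inspect.signature(f).parameters["x"]) is int
```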
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import copy
 import importlib.util
 import json
 import os
@@ -7,7 +8,7 @@ import sys
 import traceback
 import warnings
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import cloudpickle
 
@@ -27,7 +28,7 @@ except ImportError:
     from dataclasses import dataclass
 
     @dataclass(frozen=True)
-    class ExecutionResult:
+    class ExecutionResult:  # type: ignore[no-redef]
         result: Optional[Any] = None
         exception: Optional[BaseException] = None
 
@@ -35,7 +36,7 @@ except ImportError:
         def success(self) -> bool:
             return self.exception is None
 
-        def to_dict(self) -> Dict[str, Any]:
+        def to_dict(self) -> dict[str, Any]:
             """Return the serializable dictionary."""
             if isinstance(self.exception, BaseException):
                 exc_type = type(self.exception)
@@ -136,7 +137,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None
         while tb and tb.tb_frame.f_code.co_filename in skip_files:
             # Skip any frames preceding user script execution
             tb = tb.tb_next
-
+        cleaned_ex = copy.copy(e)  # Need to create a mutable copy of exception to set __traceback__
+        cleaned_ex = cleaned_ex.with_traceback(tb)
+        result_obj = ExecutionResult(exception=cleaned_ex)
         raise
     finally:
         result_dict = result_obj.to_dict()
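The new lines in `main()` strip launcher frames from the front of the traceback and attach the trimmed traceback to a shallow copy, so the original exception (re-raised on the next line) keeps its full stack. A condensed sketch of that idea:

```python
import copy
from types import TracebackType
from typing import Optional

def strip_launcher_frames(exc: BaseException, skip_files: set[str]) -> BaseException:
    tb: Optional[TracebackType] = exc.__traceback__
    while tb and tb.tb_frame.f_code.co_filename in skip_files:
        tb = tb.tb_next  # skip frames preceding user script execution
    # copy.copy() gives a mutable twin; calling with_traceback() on the original
    # would also rewrite the traceback of the exception being re-raised.
    cleaned = copy.copy(exc)
    return cleaned.with_traceback(tb)
```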
snowflake/ml/jobs/_utils/scripts/signal_workers.py
CHANGED
@@ -9,7 +9,7 @@ import logging
 import socket
 import sys
 import time
-from typing import Any, Dict, List, Set
+from typing import Any
 
 import ray
 from constants import (
@@ -33,34 +33,34 @@ class ShutdownSignal:
         self.acknowledged_workers = set()
         logging.info(f"ShutdownSignal actor created on {self.hostname}")
 
-    def request_shutdown(self) -> Dict[str, Any]:
+    def request_shutdown(self) -> dict[str, Any]:
         """Signal workers to shut down"""
         self.shutdown_requested = True
         self.timestamp = time.time()
         logging.info(f"Shutdown requested by head node at {self.timestamp}")
         return {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
 
-    def should_shutdown(self) -> Dict[str, Any]:
+    def should_shutdown(self) -> dict[str, Any]:
         """Check if shutdown has been requested"""
         return {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
 
-    def ping(self) -> Dict[str, Any]:
+    def ping(self) -> dict[str, Any]:
         """Simple method to test connectivity"""
         return {"status": "alive", "host": self.hostname}
 
-    def acknowledge_shutdown(self, worker_id: str) -> Dict[str, Any]:
+    def acknowledge_shutdown(self, worker_id: str) -> dict[str, Any]:
         """Worker acknowledges it has received the shutdown signal and is terminating"""
         self.acknowledged_workers.add(worker_id)
         logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {len(self.acknowledged_workers)}")
 
         return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": len(self.acknowledged_workers)}
 
-    def get_acknowledgment_workers(self) -> Set[str]:
+    def get_acknowledgment_workers(self) -> set[str]:
         """Get the set of workers who have acknowledged shutdown"""
         return self.acknowledged_workers
 
 
-def get_worker_node_ids() -> List[str]:
+def get_worker_node_ids() -> list[str]:
     """Get the IDs of all active worker nodes.
 
     Returns:
@@ -127,7 +127,7 @@ def verify_shutdown(shutdown_signal: ActorHandle) -> None:
         logging.debug(f"Shutdown status check: {check}")
 
 
-def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: List[str], wait_time: int) -> None:
+def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: list[str], wait_time: int) -> None:
     """Wait for workers to acknowledge shutdown.
 
     Args:
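For context, `ShutdownSignal` is a Ray actor: the head node requests shutdown, and workers poll and acknowledge. A minimal sketch of that handshake (the cluster wiring and the `worker-1` name are hypothetical):

```python
import ray

@ray.remote
class ShutdownSignal:
    # Condensed version of the class above; see the diff for the full methods.
    def __init__(self) -> None:
        self.shutdown_requested = False
        self.acknowledged = set()

    def request_shutdown(self) -> dict:
        self.shutdown_requested = True
        return {"status": "shutdown_requested"}

    def should_shutdown(self) -> dict:
        return {"shutdown": self.shutdown_requested}

    def acknowledge_shutdown(self, worker_id: str) -> dict:
        self.acknowledged.add(worker_id)
        return {"status": "acknowledged", "worker_id": worker_id}

ray.init()
signal = ShutdownSignal.remote()
ray.get(signal.request_shutdown.remote())                  # head node side
if ray.get(signal.should_shutdown.remote())["shutdown"]:   # worker side
    ray.get(signal.acknowledge_shutdown.remote("worker-1"))
```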
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 from math import ceil
 from pathlib import PurePath
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
@@ -15,10 +15,7 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
-
-    # Get the cloud we're using (AWS, Azure, etc)
-    region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
-    cloud = region["cloud"]
+    cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
@@ -26,22 +23,14 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     )
 
 
-def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
+def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-
-    # Try to pull latest image tag from server side if possible
-    if not image_tag:
-        query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
-        if query_result:
-            image_tag = query_result[0]["value"]
-
-    if image_tag is None:
-        image_tag = constants.DEFAULT_IMAGE_TAG
+    image_tag = constants.DEFAULT_IMAGE_TAG
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -54,9 +43,9 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
 
 
 def generate_spec_overrides(
-    environment_vars: Optional[Dict[str, str]] = None,
-    custom_overrides: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
+    environment_vars: Optional[dict[str, str]] = None,
+    custom_overrides: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
     """
     Generate a dictionary of service specification overrides.
 
@@ -68,7 +57,7 @@ def generate_spec_overrides(
         Resulting service specifiation patch dict. Empty if no overrides were supplied.
     """
     # Generate container level overrides
-    container_spec: Dict[str, Any] = {
+    container_spec: dict[str, Any] = {
         "name": constants.DEFAULT_CONTAINER_NAME,
     }
     if environment_vars:
@@ -95,10 +84,10 @@ def generate_service_spec(
     session: snowpark.Session,
     compute_pool: str,
     payload: types.UploadedPayload,
-    args: Optional[List[str]] = None,
+    args: Optional[list[str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Generate a service specification for a job.
 
@@ -117,11 +106,11 @@ def generate_service_spec(
     image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    resource_requests: Dict[str, Union[str, int]] = {
+    resource_requests: dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
     }
-    resource_limits: Dict[str, Union[str, int]] = {
+    resource_limits: dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
     }
@@ -130,8 +119,8 @@ def generate_service_spec(
         resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu
 
     # Add local volumes for ephemeral logs and artifacts
-    volumes: List[Dict[str, str]] = []
-    volume_mounts: List[Dict[str, str]] = []
+    volumes: list[dict[str, str]] = []
+    volume_mounts: list[dict[str, str]] = []
     for volume_name, mount_path in [
         ("system-logs", "/var/log/managedservices/system/mlrs"),
         ("user-logs", "/var/log/managedservices/user/mlrs"),
@@ -302,11 +291,11 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
 
 
 def _merge_lists_of_dicts(
-    base: List[Dict[str, Any]],
-    patch: List[Dict[str, Any]],
+    base: list[dict[str, Any]],
+    patch: list[dict[str, Any]],
     merge_key: str = "name",
     display_name: str = "",
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Attempts to merge lists of dicts by matching on a merge key (default "name").
     - If the merge key is missing, the behavior falls back to overwriting the list.
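The `_merge_lists_of_dicts` docstring describes the fallback behavior its new signature preserves. A sketch of that documented rule only (the real helper also threads `display_name` through its error messages):

```python
from typing import Any

def merge_lists_of_dicts(
    base: list[dict[str, Any]],
    patch: list[dict[str, Any]],
    merge_key: str = "name",
) -> list[dict[str, Any]]:
    # Fall back to overwriting the whole list when any entry lacks the key.
    if not all(merge_key in d for d in base + patch):
        return patch
    merged = {d[merge_key]: d for d in base}
    for d in patch:
        merged[d[merge_key]] = {**merged.get(d[merge_key], {}), **d}  # patch wins per field
    return list(merged.values())

spec = merge_lists_of_dicts(
    [{"name": "main", "image": "cpu:old"}],
    [{"name": "main", "image": "gpu:new"}],
)
assert spec == [{"name": "main", "image": "gpu:new"}]
```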
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import PurePath
-from typing import List, Literal, Optional, Union
+from typing import Literal, Optional, Union
 
 JOB_STATUS = Literal[
     "PENDING",
@@ -21,7 +21,7 @@ class PayloadEntrypoint:
 class UploadedPayload:
     # TODO: Include manifest of payload files for validation
     stage_path: PurePath
-    entrypoint: List[Union[str, PurePath]]
+    entrypoint: list[Union[str, PurePath]]
 
 
 @dataclass(frozen=True)
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Dict, List, Optional, TypeVar
+from typing import Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -15,16 +15,15 @@ _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def remote(
     compute_pool: str,
     *,
     stage_name: str,
-    pip_requirements: Optional[List[str]] = None,
-    external_access_integrations: Optional[List[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
-    env_vars: Optional[Dict[str, str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
|