snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/env_utils.py

@@ -6,12 +6,13 @@ import textwrap
 import warnings
 from enum import Enum
 from importlib import metadata as importlib_metadata
-from typing import Any, DefaultDict, Dict, List, Optional, Tuple
+from typing import Any, DefaultDict, Optional

 import yaml
 from packaging import requirements, specifiers, version

 import snowflake.connector
+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env as snowml_env, relax_version_strategy
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import context, exceptions, session
@@ -27,8 +28,8 @@ class CONDA_OS(Enum):


 _NODEFAULTS = "nodefaults"
-_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
-_SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
+_SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
+_SNOWFLAKE_CONDA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
 _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]

 DEFAULT_CHANNEL_NAME = ""
@@ -64,7 +65,7 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
     return r


-def _validate_conda_dependency_string(dep_str: str) -> Tuple[str, requirements.Requirement]:
+def _validate_conda_dependency_string(dep_str: str) -> tuple[str, requirements.Requirement]:
     """Validate conda dependency string like `pytorch == 1.12.1` or `conda-forge::transformer` and split the channel
     name before the double colon and requirement specification after that.

@@ -115,7 +116,7 @@ class DuplicateDependencyInMultipleChannelsError(Exception):
     ...


-def append_requirement_list(req_list: List[requirements.Requirement], p_req: requirements.Requirement) -> None:
+def append_requirement_list(req_list: list[requirements.Requirement], p_req: requirements.Requirement) -> None:
     """Append a requirement to an existing requirement list. If need and able to merge, merge it, otherwise, append it.

     Args:
@@ -134,7 +135,7 @@ def append_requirement_list(req_list: List[requirements.Requirement], p_req: req


 def append_conda_dependency(
-    conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], p_chan_dep: Tuple[str, requirements.Requirement]
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], p_chan_dep: tuple[str, requirements.Requirement]
 ) -> None:
     """Append a conda dependency to an existing conda dependencies dict, if not existed in any channel.
     To avoid making unnecessary modification to dict, we check the existence first, then try to merge, then append,
@@ -164,45 +165,73 @@ def append_conda_dependency(
     conda_chan_deps[p_channel].append(p_req)


-def validate_pip_requirement_string_list(req_str_list: List[str]) -> List[requirements.Requirement]:
-    """Validate the list of pip requirement string according to PEP 508.
+def validate_pip_requirement_string_list(
+    req_str_list: list[str], add_local_version_specifier: bool = False
+) -> list[requirements.Requirement]:
+    """Validate the list of pip requirement strings according to PEP 508.

     Args:
-        req_str_list: The list of string contains the pip requirement specification.
+        req_str_list: The list of strings containing the pip requirement specification.
+        add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+            requirements without version specifiers.

     Returns:
         A requirements.Requirement list containing the requirement information.
     """
-    seen_pip_requirement_list: List[requirements.Requirement] = []
+    seen_pip_requirement_list: list[requirements.Requirement] = []
     for req_str in req_str_list:
         append_requirement_list(seen_pip_requirement_list, _validate_pip_requirement_string(req_str=req_str))

+    if add_local_version_specifier:
+        # For any requirement string that does not contain a specifier, add the specifier of a locally installed version
+        # if it exists.
+        seen_pip_requirement_list = list(
+            map(
+                lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req),
+                seen_pip_requirement_list,
+            )
+        )
+
     return seen_pip_requirement_list


-def validate_conda_dependency_string_list(dep_str_list: List[str]) -> DefaultDict[str, List[requirements.Requirement]]:
+def validate_conda_dependency_string_list(
+    dep_str_list: list[str], add_local_version_specifier: bool = False
+) -> DefaultDict[str, list[requirements.Requirement]]:
     """Validate a list of conda dependency string, find any duplicate package across different channel and create a dict
     to represent the whole dependencies.

     Args:
         dep_str_list: The list of string contains the conda dependency specification.
+        add_local_version_specifier: if True, add the version specifier of the locally installed package version to
+            requirements without version specifiers.

     Returns:
         A dict mapping from the channel name to the list of requirements from that channel.
     """
     validated_conda_dependency_list = list(map(_validate_conda_dependency_string, dep_str_list))
-    ret_conda_dependency_dict: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
+    ret_conda_dependency_dict: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
     for p_channel, p_req in validated_conda_dependency_list:
         append_conda_dependency(ret_conda_dependency_dict, (p_channel, p_req))

+    if add_local_version_specifier:
+        # For any conda dependency string that does not contain a specifier, add the specifier of a locally installed
+        # version if it exists. This is best-effort: if the conda package does not have the same name as the pip
+        # package, it won't be found in the local environment.
+        for channel_str, reqs in ret_conda_dependency_dict.items():
+            reqs = list(
+                map(lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req), reqs)
+            )
+            ret_conda_dependency_dict[channel_str] = reqs
+
     return ret_conda_dependency_dict


 def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement) -> requirements.Requirement:
     """Get the local installed version of a given pip package requirement.
-    If the package is locally installed, and the local version …
+    If the package is locally installed, and the local version meets the specifier of the requirements, return a new
     requirement specifier that pins the version.
-    If the local version does not meet the specifier of the requirements, a …
+    If the local version does not meet the specifier of the requirements, a warning will be emitted and returns
     the original package requirement.
     If the package is not locally installed or not found, the original package requirement is returned.
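To make the new flag concrete, here is a hedged usage sketch (ours, not from the package docs); the module path follows the file list above, and the pinned version shown is only an example of whatever happens to be installed locally:

```python
# Sketch: validate pip requirement strings; with add_local_version_specifier=True,
# entries without a specifier get pinned to the locally installed version.
from snowflake.ml._internal import env_utils

reqs = env_utils.validate_pip_requirement_string_list(
    ["numpy", "pandas>=2.0"], add_local_version_specifier=True
)
print([str(r) for r in reqs])
# e.g. ['numpy==1.26.4', 'pandas>=2.0'] -- 'numpy' is pinned to the local install
# (1.26.4 is illustrative); 'pandas>=2.0' already has a specifier and is unchanged.
```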
@@ -217,7 +246,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
         local_dist_version = local_dist.version
     except importlib_metadata.PackageNotFoundError:
         if pip_req.name == SNOWPARK_ML_PKG_NAME:
-            local_dist_version = …
+            local_dist_version = snowml_version.VERSION
         else:
             return pip_req
     new_pip_req = copy.deepcopy(pip_req)
@@ -372,8 +401,8 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req


 def get_matched_package_versions_in_information_schema_with_active_session(
-    reqs: List[requirements.Requirement], python_version: str
-) -> Dict[str, List[version.Version]]:
+    reqs: list[requirements.Requirement], python_version: str
+) -> dict[str, list[version.Version]]:
     try:
         session = context.get_active_session()
     except exceptions.SnowparkSessionException:
@@ -383,10 +412,10 @@ def get_matched_package_versions_in_information_schema_with_active_session(

 def get_matched_package_versions_in_information_schema(
     session: session.Session,
-    reqs: List[requirements.Requirement],
+    reqs: list[requirements.Requirement],
     python_version: str,
-    statement_params: Optional[Dict[str, Any]] = None,
-) -> Dict[str, List[version.Version]]:
+    statement_params: Optional[dict[str, Any]] = None,
+) -> dict[str, list[version.Version]]:
     """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
     Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
     exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -400,8 +429,8 @@ def get_matched_package_versions_in_information_schema(
     Returns:
         A Dict, whose key is the package name, and value is a list of versions match the requirements.
     """
-    ret_dict: Dict[str, List[version.Version]] = {}
-    reqs_to_request: List[requirements.Requirement] = []
+    ret_dict: dict[str, list[version.Version]] = {}
+    reqs_to_request: list[requirements.Requirement] = []
     for req in reqs:
         if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
             available_versions = list(
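A hedged sketch of how the session-based lookup above might be called; `session` is assumed to be an existing Snowpark session, and the result shape follows the docstring (package name mapped to matching versions):

```python
# Sketch: ask the information_schema which xgboost versions match a requirement.
from packaging import requirements
from snowflake.ml._internal import env_utils

matched = env_utils.get_matched_package_versions_in_information_schema(
    session,  # an existing snowflake.snowpark Session (assumed)
    reqs=[requirements.Requirement("xgboost>=1.7")],
    python_version="3.10",
)
# e.g. {"xgboost": [<Version('1.7.6')>, ...]}; per the hunk above, results feed
# _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE so repeat lookups skip the query.
```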
@@ -457,7 +486,7 @@ def get_matched_package_versions_in_information_schema(

 def save_conda_env_file(
     path: pathlib.Path,
-    conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
     python_version: str,
     cuda_version: Optional[str] = None,
     default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
@@ -478,7 +507,7 @@ def save_conda_env_file(
     """
     assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
     path.parent.mkdir(parents=True, exist_ok=True)
-    env: Dict[str, Any] = dict()
+    env: dict[str, Any] = dict()
     env["name"] = "snow-env"
     # Get all channels in the dependencies, ordered by the number of the packages which belongs to and put into
     # channels section.
@@ -505,7 +534,7 @@ def save_conda_env_file(
         yaml.safe_dump(env, stream=f, default_flow_style=False)


-def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requirement]) -> None:
+def save_requirements_file(path: pathlib.Path, pip_deps: list[requirements.Requirement]) -> None:
     """Generate Python requirements.txt file in the given directory path.

     Args:
@@ -521,9 +550,9 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi

 def load_conda_env_file(
     path: pathlib.Path,
-) -> Tuple[
-    DefaultDict[str, List[requirements.Requirement]],
-    Optional[List[requirements.Requirement]],
+) -> tuple[
+    DefaultDict[str, list[requirements.Requirement]],
+    Optional[list[requirements.Requirement]],
     Optional[str],
     Optional[str],
 ]:
@@ -601,7 +630,7 @@ def load_conda_env_file(
     return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version


-def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
+def load_requirements_file(path: pathlib.Path) -> list[requirements.Requirement]:
     """Load Python requirements.txt file from the given directory path.

     Args:
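A minimal round-trip sketch for the env-file helpers above; the path and pinned package are illustrative, and the empty-string channel key is `DEFAULT_CHANNEL_NAME` from the constants hunk earlier:

```python
# Sketch: write a conda env file, then read it back with load_conda_env_file.
import collections
import pathlib

from packaging import requirements

from snowflake.ml._internal import env_utils

deps = collections.defaultdict(list)
deps[""].append(requirements.Requirement("numpy==1.26.4"))  # "" == DEFAULT_CHANNEL_NAME

path = pathlib.Path("/tmp/snow-env.yml")  # illustrative path
env_utils.save_conda_env_file(path, deps, python_version="3.10")
conda_deps, pip_deps, py_ver, cuda_ver = env_utils.load_conda_env_file(path)
```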
@@ -641,8 +670,8 @@ def parse_python_version_string(dep: str) -> Optional[str]:


 def _find_conda_dep_spec(
-    conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], pkg_name: str
-) -> Optional[Tuple[str, requirements.Requirement]]:
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], pkg_name: str
+) -> Optional[tuple[str, requirements.Requirement]]:
     for channel in conda_chan_deps:
         spec = next(filter(lambda req: req.name == pkg_name, conda_chan_deps[channel]), None)
         if spec:
@@ -650,14 +679,14 @@ def _find_conda_dep_spec(
     return None


-def _find_pip_req_spec(pip_reqs: List[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
+def _find_pip_req_spec(pip_reqs: list[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
     spec = next(filter(lambda req: req.name == pkg_name, pip_reqs), None)
     return spec


 def find_dep_spec(
-    conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
-    pip_reqs: List[requirements.Requirement],
+    conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
+    pip_reqs: list[requirements.Requirement],
     conda_pkg_name: str,
     pip_pkg_name: Optional[str] = None,
     remove_spec: bool = False,
snowflake/ml/_internal/file_utils.py

@@ -11,18 +11,7 @@ import sys
 import tarfile
 import tempfile
 import zipfile
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    List,
-    Literal,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, Generator, Literal, Optional, Union
 from urllib import parse

 import cloudpickle
@@ -37,7 +26,7 @@ GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")
 def copytree(
     src: "Union[str, os.PathLike[str]]",
     dst: "Union[str, os.PathLike[str]]",
-    ignore: Optional[Callable[..., Set[str]]] = None,
+    ignore: Optional[Callable[..., set[str]]] = None,
     dirs_exist_ok: bool = False,
 ) -> "Union[str, os.PathLike[str]]":
     """This is a forked version of shutil.copytree that remove all copystat, to make sure it works in Sproc.
@@ -170,7 +159,7 @@ def zip_python_package(zipfile_path: str, package_name: str, ignore_generated_py


 def hash_directory(
-    directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[List[str]] = None
+    directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[list[str]] = None
 ) -> str:
     """Hash the **content** of a folder recursively using SHA-1.

@@ -186,7 +175,7 @@ def hash_directory(
         excluded_files = []

     def _update_hash_from_dir(
-        directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: List[str]
+        directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: list[str]
     ) -> "hashlib._Hash":
         assert pathlib.Path(directory).is_dir(), "Provided path is not a directory."
         for path in sorted(pathlib.Path(directory).iterdir(), key=lambda p: str(p).lower()):
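A quick sketch against the `hash_directory` signature above; the directory and excluded file name are assumptions, and the exact matching semantics of `excluded_files` are not shown in the hunk:

```python
# Sketch: SHA-1 content hash of a directory tree, skipping hidden entries.
import pathlib

from snowflake.ml._internal import file_utils

digest = file_utils.hash_directory(
    pathlib.Path("my_model_dir"),      # any existing directory (assumed)
    ignore_hidden=True,
    excluded_files=["metadata.yaml"],  # illustrative exclusion
)
print(digest)  # 40-char hex digest
```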
@@ -208,7 +197,7 @@ def hash_directory(
     ).hexdigest()


-def get_all_modules(dirname: str, prefix: str = "") -> List[str]:
+def get_all_modules(dirname: str, prefix: str = "") -> list[str]:
     modules = [mod.name for mod in pkgutil.iter_modules([dirname], prefix=prefix)]
     subdirs = [f.path for f in os.scandir(dirname) if f.is_dir()]
     for sub_dirname in subdirs:
@@ -248,7 +237,7 @@ def _create_tar_gz_stream(source_dir: str, arcname: Optional[str] = None) -> Gen
     yield output_stream


-def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> Tuple[str, str]:
+def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> tuple[str, str]:
     """[Obsolete]Return the path to where a package is defined and its start location.
     Example 1: snowflake.ml -> path/to/site-packages/snowflake/ml, path/to/site-packages
     Example 2: zip_imported_module -> path/to/some/zipfile.zip/zip_imported_module, path/to/some/zipfile.zip
@@ -267,7 +256,7 @@ def get_package_path(package_name: str, strategy: Literal["first", "last"] = "fi
     return pkg_path, pkg_start_path


-def stage_object(session: snowpark.Session, object: object, stage_location: str) -> List[snowpark.PutResult]:
+def stage_object(session: snowpark.Session, object: object, stage_location: str) -> list[snowpark.PutResult]:
     temp_file = tempfile.NamedTemporaryFile(delete=False)
     temp_file_path = temp_file.name
     temp_file.close()
@@ -279,7 +268,7 @@ def stage_object(session: snowpark.Session, object: object, stage_location: str)


 def stage_file_exists(
-    session: snowpark.Session, stage_location: str, file_name: str, statement_params: Dict[str, Any]
+    session: snowpark.Session, stage_location: str, file_name: str, statement_params: dict[str, Any]
 ) -> bool:
     try:
         res = session.sql(f"list {stage_location}/{file_name}").collect(statement_params=statement_params)
@@ -297,7 +286,7 @@ def upload_directory_to_stage(
     local_path: pathlib.Path,
     stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
     *,
-    statement_params: Optional[Dict[str, Any]] = None,
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     """Upload a local folder recursively to a stage and keep the structure.

@@ -350,7 +339,7 @@ def download_directory_from_stage(
     stage_path: pathlib.PurePosixPath,
     local_path: pathlib.Path,
     *,
-    statement_params: Optional[Dict[str, Any]] = None,
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     """Upload a folder in stage recursively to a folder in local and keep the structure.

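And a hedged sketch of the stage-helper calling convention per the `stage_file_exists` hunk; the session, stage name, and file name are all assumptions:

```python
# Sketch: probe a stage for a file before downloading anything.
from snowflake.ml._internal import file_utils

exists = file_utils.stage_file_exists(
    session,             # an existing snowflake.snowpark Session (assumed)
    "@MY_STAGE/models",  # stage location (illustrative)
    "model.zip",         # file to look for (illustrative)
    statement_params={"QUERY_TAG": "stage-check"},
)
```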
snowflake/ml/_internal/human_readable_id/hrid_generator_base.py

@@ -15,7 +15,6 @@ In this module you will find:

 import math
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple


 class HRIDBase(ABC):
@@ -28,12 +27,11 @@ class HRIDBase(ABC):
     @abstractmethod
     def __id_generator__(self) -> int:
         """The generator to use to generate new IDs. The implementer needs to provide this."""
-        pass

-    __hrid_structure__: Tuple[str, ...]
+    __hrid_structure__: tuple[str, ...]
     """The HRID structure to be generated. The implementer needs to provide this."""

-    __hrid_words__: Dict[str, Tuple[str, ...]]
+    __hrid_words__: dict[str, tuple[str, ...]]
     """The mapping between the HRID parts and the words to use. The implementer needs to provide this."""

     __separator__ = "_"
@@ -82,7 +80,7 @@ class HRIDBase(ABC):
         hrid.append(str(values[idxs[i]]))
         return self.__separator__.join(hrid)

-    def generate(self) -> Tuple[int, str]:
+    def generate(self) -> tuple[int, str]:
         """Generate an ID and the corresponding HRID.

         Returns:
@@ -92,7 +90,7 @@ class HRIDBase(ABC):
         hrid = self.id_to_hrid(id)
         return (id, hrid)

-    def _id_to_idxs(self, id: int) -> List[int]:
+    def _id_to_idxs(self, id: int) -> list[int]:
         """Take the ID and convert it to indices into the HRID words.

         Args:
@@ -109,7 +107,7 @@ class HRIDBase(ABC):
         idxs.append((id & mask) >> shift)
         return idxs

-    def _hrid_to_idxs(self, hrid: str) -> List[int]:
+    def _hrid_to_idxs(self, hrid: str) -> list[int]:
         """Take the HRID and convert it to indices into the HRID words.

         Args:
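To make the annotated surface concrete, a purely illustrative subclass sketch follows. The class name and word lists are ours, and it assumes the base class derives per-part bit widths from the (power-of-two) word-list sizes; the `__init__` contract is not shown in these hunks:

```python
# Illustrative only: a toy HRIDBase subclass exercising the typed surface above.
import random

from snowflake.ml._internal.human_readable_id.hrid_generator_base import HRIDBase


class TinyHRID(HRIDBase):
    __hrid_structure__ = ("adjective", "animal")        # tuple[str, ...]
    __hrid_words__ = {                                  # dict[str, tuple[str, ...]]
        "adjective": ("quick", "calm", "brave", "shy"),
        "animal": ("fox", "owl", "elk", "cat"),
    }

    def __id_generator__(self) -> int:
        return random.getrandbits(4)  # 2 bits per 4-word part (assumed)


id_, hrid = TinyHRID().generate()  # -> tuple[int, str], e.g. (11, "brave_cat")
```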
snowflake/ml/_internal/init_utils.py

@@ -2,10 +2,9 @@ import importlib
 import inspect
 import pkgutil
 from types import FunctionType
-from typing import Dict


-def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, type]:
+def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, type]:
     """Finds classes defined all the python modules in the given package directory.

     Args:
@@ -36,7 +35,7 @@ def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[s
     return exportable_classes


-def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, FunctionType]:
+def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, FunctionType]:
     """Finds functions defined all the python modules in the given package directory.

     Args:
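A short sketch of the discovery helpers above; `snowflake.cortex` is used only as a conveniently installed package from this same wheel:

```python
# Sketch: collect the classes defined across a package directory's modules.
import os

import snowflake.cortex
from snowflake.ml._internal import init_utils

classes = init_utils.fetch_classes_from_modules_in_pkg_dir(
    pkg_dir=os.path.dirname(snowflake.cortex.__file__),
    pkg_name="snowflake.cortex",
)
print(sorted(classes))  # mapping of class name -> class object
```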
snowflake/ml/_internal/lineage/lineage_utils.py

@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Any, Callable, List, Optional, get_args
+from typing import Any, Callable, Optional, get_args

 from snowflake import snowpark
 from snowflake.ml.data import data_source
@@ -9,7 +9,7 @@ _DATA_SOURCES_ATTR = "_data_sources"


 def _wrap_func(
-    fn: Callable[..., snowpark.DataFrame], data_sources: List[data_source.DataSource]
+    fn: Callable[..., snowpark.DataFrame], data_sources: list[data_source.DataSource]
 ) -> Callable[..., snowpark.DataFrame]:
     """Wrap a DataFrame transform function to propagate data_sources to derived DataFrames."""

@@ -34,9 +34,9 @@ def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., sno
     return wrapped


-def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
+def get_data_sources(*args: Any) -> Optional[list[data_source.DataSource]]:
     """Helper method for extracting data sources attribute from DataFrames in an argument list"""
-    result: Optional[List[data_source.DataSource]] = None
+    result: Optional[list[data_source.DataSource]] = None
     for arg in args:
         srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
         if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
@@ -46,7 +46,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
     return result


-def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
+def set_data_sources(obj: Any, data_sources: Optional[list[data_source.DataSource]]) -> None:
     """Helper method for attaching data sources to an object"""
     if data_sources:
         assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
@@ -54,7 +54,7 @@ def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSourc


 def patch_dataframe(
-    df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
+    df: snowpark.DataFrame, data_sources: list[data_source.DataSource], inplace: bool = False
 ) -> snowpark.DataFrame:
     """
     Monkey patch a DataFrame to add attach the provided data_sources as an attribute of the DataFrame.
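For orientation, a hedged sketch of how these helpers compose; `df` is assumed to be an existing Snowpark DataFrame whose `_data_sources` attribute was populated upstream:

```python
# Sketch: carry lineage from one DataFrame onto a derived one.
from snowflake.ml._internal.lineage import lineage_utils

sources = lineage_utils.get_data_sources(df)  # df: existing DataFrame (assumed)
derived = df.filter("LABEL IS NOT NULL")      # a plain transform drops lineage
if sources:
    derived = lineage_utils.patch_dataframe(derived, sources)
    assert lineage_utils.get_data_sources(derived) == sources
```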
snowflake/ml/_internal/platform_capabilities.py

@@ -1,5 +1,6 @@
 import json
-from typing import Any, Dict, Optional
+from contextlib import contextmanager
+from typing import Any, Optional

 from absl import logging

@@ -27,21 +28,50 @@ class PlatformCapabilities:
     """

     _instance: Optional["PlatformCapabilities"] = None
+    # Used for unittesting only. This is to avoid the need to mock the session object or reaching out to Snowflake
+    _mock_features: Optional[dict[str, Any]] = None

     @classmethod
     def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
+        # Used for unittesting only. In this situation, _instance is not initialized.
+        if cls._mock_features is not None:
+            return cls(features=cls._mock_features)
         if not cls._instance:
-            cls._instance = cls(session)
+            cls._instance = cls(session=session)
         return cls._instance

+    @classmethod
+    def set_mock_features(cls, features: Optional[dict[str, Any]] = None) -> None:
+        cls._mock_features = features
+
+    @classmethod
+    def clear_mock_features(cls) -> None:
+        cls._mock_features = None
+
+    # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
+    # Python 3.11. So, we are ignoring the type for this method.
+    @classmethod  # type: ignore[arg-type]
+    @contextmanager
+    def mock_features(cls, features: dict[str, Any]) -> None:  # type: ignore[misc]
+        logging.debug(f"Setting mock features: {features}")
+        cls.set_mock_features(features)
+        try:
+            yield
+        finally:
+            logging.debug(f"Clearing mock features: {features}")
+            cls.clear_mock_features()
+
     def is_nested_function_enabled(self) -> bool:
         return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)

+    def is_inlined_deployment_spec_enabled(self) -> bool:
+        return self._get_bool_feature("ENABLE_INLINE_DEPLOYMENT_SPEC", False)
+
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)

     @staticmethod
-    def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
+    def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
         try:
             result = (
                 query_result_checker.SqlResultValidator(
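The testing hook above suggests usage along these lines in a unit test (a sketch; the flag name is taken from the hunk, everything else is ours):

```python
# Sketch: temporarily inject capability flags without a Snowflake session.
from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

with PlatformCapabilities.mock_features({"ENABLE_INLINE_DEPLOYMENT_SPEC": True}):
    pc = PlatformCapabilities.get_instance()  # served from the mock, no session
    assert pc.is_inlined_deployment_spec_enabled()
# On exit, clear_mock_features() runs and get_instance() resolves a real session again.
```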
@@ -68,11 +98,17 @@ class PlatformCapabilities:
         # This can happen is server side is older than 9.2. That is fine.
         return {}

-    def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
+    def __init__(
+        self, *, session: Optional[snowpark_session.Session] = None, features: Optional[dict[str, Any]] = None
+    ) -> None:
+        # This is for testing purposes only.
+        if features:
+            self.features = features
+            return
         if not session:
             session = next(iter(snowpark_session._get_active_sessions()))
         assert session, "Missing active session object"
-        self.features …
+        self.features = PlatformCapabilities._get_features(session)

     def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
         value = self.features.get(feature_name, default_value)
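The reworked constructor also allows building an instance directly from a features dict, which the mock path above relies on; a final sketch (flag name from the hunks, the rest assumed):

```python
# Sketch: construct PlatformCapabilities from explicit features, bypassing any session.
from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

pc = PlatformCapabilities(features={"ENABLE_BUNDLE_MODULE_CHECKOUT": True})
assert pc.is_live_commit_enabled()
```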