snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/model/custom_model.py
CHANGED
@@ -1,6 +1,6 @@
 import functools
 import inspect
-from typing import Any, Callable, Coroutine,
+from typing import Any, Callable, Coroutine, Generator, Optional, Union

 import anyio
 import pandas as pd
@@ -78,7 +78,7 @@ class ModelRef:
 return MethodRef(self, method_name)
 raise AttributeError(f"Method {method_name} not found in model {self._name}.")

-def __getstate__(self) ->
+def __getstate__(self) -> dict[str, Any]:
 state = self.__dict__.copy()
 del state["_model"]
 return state
@@ -113,8 +113,8 @@ class ModelContext:
 def __init__(
 self,
 *,
-artifacts: Optional[Union[
-models: Optional[Union[
+artifacts: Optional[Union[dict[str, str], str, model_types.SupportedModelType]] = None,
+models: Optional[Union[dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
 **kwargs: Optional[Union[str, model_types.SupportedModelType]],
 ) -> None:
 """Initialize the model context.
@@ -130,8 +130,8 @@ class ModelContext:
 ValueError: Raised when the model name is duplicated.
 """

-self.artifacts:
-self.model_refs:
+self.artifacts: dict[str, str] = dict()
+self.model_refs: dict[str, ModelRef] = dict()

 # In case that artifacts is a dictionary, assume the original usage,
 # which is to pass in a dictionary of artifacts.
@@ -185,7 +185,7 @@ class ModelContext:
 return self.model_refs[name]

 def __getitem__(self, key: str) -> Union[str, ModelRef]:
-combined:
+combined: dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
 if key not in combined:
 raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
 return combined[key]
@@ -226,7 +226,7 @@ class CustomModel:
 else:
 raise TypeError("A non-method inference API function is not supported.")

-def _get_partitioned_infer_methods(self) ->
+def _get_partitioned_infer_methods(self) -> list[str]:
 """Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
 rv = []
 for cls_method_str in dir(self):
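The hunks above touch the ModelContext / CustomModel surface that custom-model authors use directly. Below is a minimal usage sketch assuming only what the signatures above show; the AddOne class, the artifact path and the column names are illustrative, not part of this release.

import pandas as pd

from snowflake.ml.model import custom_model


class AddOne(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    @custom_model.inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # ModelContext.__getitem__ (hunk around line 188) merges artifacts and
        # model_refs, so the registered artifact path comes back as a plain string.
        config_path = self.context["config"]
        return pd.DataFrame({"output": input["feature"] + 1, "config": config_path})


ctx = custom_model.ModelContext(artifacts={"config": "configs/demo.json"})
model = AddOne(ctx)
print(model.predict(pd.DataFrame({"feature": [1, 2, 3]})))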
snowflake/ml/model/model_signature.py
CHANGED
@@ -1,18 +1,7 @@
 import enum
 import json
 import warnings
-from typing import
-Any,
-Dict,
-List,
-Literal,
-Optional,
-Sequence,
-Tuple,
-Type,
-Union,
-cast,
-)
+from typing import Any, Literal, Optional, Sequence, Union, cast

 import numpy as np
 import pandas as pd
@@ -30,7 +19,7 @@ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
 from snowflake.ml.model import type_hints as model_types
 from snowflake.ml.model._signatures import (
 base_handler,
-builtins_handler
+builtins_handler,
 core,
 dmatrix_handler,
 numpy_handler,
@@ -48,7 +37,7 @@ FeatureGroupSpec = core.FeatureGroupSpec
 ModelSignature = core.ModelSignature


-_LOCAL_DATA_HANDLERS:
+_LOCAL_DATA_HANDLERS: list[type[base_handler.BaseDataHandler[Any]]] = [
 pandas_handler.PandasDataFrameHandler,
 numpy_handler.NumpyArrayHandler,
 builtins_handler.ListOfBuiltinHandler,
@@ -414,7 +403,7 @@ class SnowparkIdentifierRule(enum.Enum):

 def _get_dataframe_values_range(
 df: snowflake.snowpark.DataFrame,
-) ->
+) -> dict[str, Union[tuple[int, int], tuple[float, float]]]:
 columns = [
 F.array_construct(F.min(field.name), F.max(field.name)).as_(field.name)
 for field in df.schema.fields
@@ -429,7 +418,7 @@ def _get_dataframe_values_range(
 original_exception=ValueError(f"Unable to get the value range of fields {df.columns}"),
 )
 return cast(
-
+dict[str, Union[tuple[int, int], tuple[float, float]]],
 {
 sql_identifier.SqlIdentifier(k, case_sensitive=True).identifier(): (json.loads(v)[0], json.loads(v)[1])
 for k, v in res[0].as_dict().items()
@@ -456,7 +445,7 @@ def _validate_snowpark_data(
 - inferred: signature `a` - Snowpark DF `"a"`, use `get_inferred_name`
 - normalized: signature `a` - Snowpark DF `A`, use `resolve_identifier`
 """
-errors:
+errors: dict[SnowparkIdentifierRule, list[Exception]] = {
 SnowparkIdentifierRule.INFERRED: [],
 SnowparkIdentifierRule.NORMALIZED: [],
 }
@@ -549,7 +538,7 @@ def _validate_snowpark_type_feature(
 field: spt.StructField,
 ft_type: DataType,
 ft_name: str,
-value_range: Optional[Union[
+value_range: Optional[Union[tuple[int, int], tuple[float, float]]],
 strict: bool = False,
 ) -> None:
 field_data_type = field.datatype
@@ -716,8 +705,8 @@ def _convert_and_validate_local_data(
 def infer_signature(
 input_data: model_types.SupportedLocalDataType,
 output_data: model_types.SupportedLocalDataType,
-input_feature_names: Optional[
-output_feature_names: Optional[
+input_feature_names: Optional[list[str]] = None,
+output_feature_names: Optional[list[str]] = None,
 input_data_limit: Optional[int] = 100,
 output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
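For reference, a short sketch of calling the public infer_signature API whose parameters are re-annotated in the last hunk above; the numeric data and feature names are invented for illustration.

import numpy as np

from snowflake.ml.model import model_signature

features = np.array([[5.1, 3.5], [6.2, 2.9]])  # two rows, two input features
predictions = np.array([0, 1])                 # one output value per row

sig = model_signature.infer_signature(
    input_data=features,
    output_data=predictions,
    input_feature_names=["sepal_len", "sepal_wid"],  # Optional[list[str]] per the hunk above
    output_feature_names=["species"],
)
print(sig)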
snowflake/ml/model/models/huggingface_pipeline.py
CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any,
+from typing import Any, Optional

 from packaging import version

@@ -13,7 +13,7 @@ class HuggingFacePipelineModel:
 revision: Optional[str] = None,
 token: Optional[str] = None,
 trust_remote_code: Optional[bool] = None,
-model_kwargs: Optional[
+model_kwargs: Optional[dict[str, Any]] = None,
 **kwargs: Any,
 ) -> None:
 """
@@ -65,6 +65,7 @@ class HuggingFacePipelineModel:
 warnings.warn(
 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
 FutureWarning,
+stacklevel=2,
 )
 if token is not None:
 raise ValueError(
@@ -183,7 +184,8 @@ class HuggingFacePipelineModel:
 warnings.warn(
 f"No model was supplied, defaulted to {model} and revision"
 f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
-"Using a pipeline without specifying a model name and revision in production is not recommended."
+"Using a pipeline without specifying a model name and revision in production is not recommended.",
+stacklevel=2,
 )
 if config is None and isinstance(model, str):
 config_obj = transformers.AutoConfig.from_pretrained(
@@ -200,7 +202,8 @@ class HuggingFacePipelineModel:
 if kwargs.get("device", None) is not None:
 warnings.warn(
 "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
-" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
+" will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
+stacklevel=2,
 )

 # ==== End pipeline logic from transformers ====
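The only behavioral change in this file is the stacklevel=2 argument added to the existing warnings.warn calls. A standalone stdlib illustration of the effect, unrelated to the package itself:

import warnings


def deprecated_api() -> None:
    # With stacklevel=2 the warning is attributed to the caller of deprecated_api,
    # not to this warn() line inside the library.
    warnings.warn("deprecated_api is deprecated", FutureWarning, stacklevel=2)


deprecated_api()  # the reported filename and line number point at this call site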
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,6 +1,6 @@
 # mypy: disable-error-code="import"
 from enum import Enum
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union

 import numpy.typing as npt
 from typing_extensions import NotRequired
@@ -32,7 +32,7 @@ _SupportedBuiltins = Union[
 bool,
 str,
 bytes,
-
+dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
 "_SupportedBuiltinsList",
 ]
 _SupportedNumpyDtype = Union[
@@ -153,7 +153,7 @@ class BaseModelSaveOption(TypedDict):
 embed_local_ml_library: NotRequired[bool]
 relax_version: NotRequired[bool]
 function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
-method_options: NotRequired[
+method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
 enable_explainability: NotRequired[bool]
 save_location: NotRequired[str]

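A small self-contained sketch of the TypedDict / NotRequired pattern used by BaseModelSaveOption above, showing that built-in dict[...] generics now appear directly inside the annotations. The field names are taken from the hunk above; the values and the SaveOption class name are invented.

from typing import TypedDict

from typing_extensions import NotRequired


class SaveOption(TypedDict):
    relax_version: NotRequired[bool]
    method_options: NotRequired[dict[str, str]]  # built-in generic, no typing.Dict


opt: SaveOption = {"relax_version": True, "method_options": {"predict": "TABLE_FUNCTION"}}
print(opt)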
snowflake/ml/modeling/_internal/estimator_utils.py
CHANGED
@@ -1,7 +1,7 @@
 import inspect
 import numbers
 import os
-from typing import Any, Callable
+from typing import Any, Callable

 import cloudpickle as cp
 import numpy as np
@@ -16,7 +16,7 @@ from snowflake.snowpark import Session
 from snowflake.snowpark._internal import utils as snowpark_utils


-def validate_sklearn_args(args:
+def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -> dict[str, Any]:
 """Validate if all the keyword args are supported by current version of SKLearn/XGBoost object.

 Args:
@@ -71,7 +71,7 @@ def transform_snowml_obj_to_sklearn_obj(obj: Any) -> Any:
 return obj


-def gather_dependencies(obj: Any) ->
+def gather_dependencies(obj: Any) -> set[str]:
 """Gathers dependencies from the SnowML Estimator and Transformer objects.

 Args:
@@ -82,7 +82,7 @@ def gather_dependencies(obj: Any) -> Set[str]:
 """

 if isinstance(obj, list) or isinstance(obj, tuple):
-deps:
+deps: set[str] = set()
 for elem in obj:
 deps = deps | set(gather_dependencies(elem))
 return deps
@@ -167,8 +167,8 @@ def get_module_name(model: object) -> str:


 def handle_inference_result(
-inference_res: Any, output_cols:
-) ->
+inference_res: Any, output_cols: list[str], inference_method: str, within_udf: bool = False
+) -> tuple[npt.NDArray[Any], list[str]]:
 if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
 # In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
 # ndarrays. We need to concatenate them.
@@ -248,7 +248,7 @@ def create_temp_stage(session: Session) -> str:


 def upload_model_to_stage(
-stage_name: str, estimator: object, session: Session, statement_params:
+stage_name: str, estimator: object, session: Session, statement_params: dict[str, str]
 ) -> str:
 """Util method to pickle and upload the model to a temp Snowflake stage.

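Most edits in this file, as in the rest of the release, swap typing.Dict / List / Tuple / Set for the built-in generics of PEP 585 (Python 3.9+). A tiny, self-contained illustration of that annotation style, using an invented helper rather than the package's own functions:

from typing import Any, Optional


def overridden_args(args: dict[str, tuple[Any, Any]], keep: Optional[list[str]] = None) -> set[str]:
    # args maps a keyword name to (passed_value, default_value); report the names
    # whose passed value differs from the default, optionally filtered by `keep`.
    names = {name for name, (value, default) in args.items() if value != default}
    return names if keep is None else names & set(keep)


print(overridden_args({"max_depth": (6, None), "n_jobs": (None, None)}))  # {'max_depth'}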
snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py
CHANGED
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any,
+from typing import Any, Optional

 import pandas as pd

@@ -38,9 +38,9 @@ class PandasTransformHandlers:
 def batch_inference(
 self,
 inference_method: str,
-input_cols:
-expected_output_cols:
-snowpark_input_cols: Optional[
+input_cols: list[str],
+expected_output_cols: list[str],
+snowpark_input_cols: Optional[list[str]] = None,
 drop_input_cols: Optional[bool] = False,
 *args: Any,
 **kwargs: Any,
@@ -147,8 +147,8 @@ class PandasTransformHandlers:

 def score(
 self,
-input_cols:
-label_cols:
+input_cols: list[str],
+label_cols: list[str],
 sample_weight_col: Optional[str],
 *args: Any,
 **kwargs: Any,
snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py
CHANGED
@@ -1,5 +1,5 @@
 import inspect
-from typing import
+from typing import Optional

 import pandas as pd

@@ -15,8 +15,8 @@ class PandasModelTrainer:
 self,
 estimator: object,
 dataset: pd.DataFrame,
-input_cols:
-label_cols: Optional[
+input_cols: list[str],
+label_cols: Optional[list[str]],
 sample_weight_col: Optional[str],
 ) -> None:
 """
@@ -57,10 +57,10 @@ class PandasModelTrainer:

 def train_fit_predict(
 self,
-expected_output_cols_list:
+expected_output_cols_list: list[str],
 drop_input_cols: Optional[bool] = False,
 example_output_pd_df: Optional[pd.DataFrame] = None,
-) ->
+) -> tuple[pd.DataFrame, object]:
 """Trains the model using specified features and target columns from the dataset.
 This API is different from fit itself because it would also provide the predict
 output.
@@ -92,9 +92,9 @@ class PandasModelTrainer:

 def train_fit_transform(
 self,
-expected_output_cols_list:
+expected_output_cols_list: list[str],
 drop_input_cols: Optional[bool] = False,
-) ->
+) -> tuple[pd.DataFrame, object]:
 """Trains the model using specified features and target columns from the dataset.
 This API is different from fit itself because it would also provide the transform
 output.
snowflake/ml/modeling/_internal/model_specifications.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import List
-
 import cloudpickle as cp
 import numpy as np

@@ -11,7 +9,7 @@ class ModelSpecifications:
 A dataclass to define model based specifications like required imports, and package dependencies for Sproc/Udfs.
 """

-def __init__(self, imports:
+def __init__(self, imports: list[str], pkgDependencies: list[str]) -> None:
 self.imports = imports
 self.pkgDependencies = pkgDependencies

@@ -20,7 +18,7 @@ class SKLearnModelSpecifications(ModelSpecifications):
 def __init__(self) -> None:
 import sklearn

-imports:
+imports: list[str] = ["sklearn"]
 # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
 pkgDependencies = [
 f"numpy=={np.__version__}",
@@ -56,8 +54,8 @@ class XGBoostModelSpecifications(ModelSpecifications):
 import sklearn
 import xgboost

-imports:
-pkgDependencies:
+imports: list[str] = ["xgboost"]
+pkgDependencies: list[str] = [
 f"numpy=={np.__version__}",
 f"scikit-learn=={sklearn.__version__}",
 f"xgboost=={xgboost.__version__}",
@@ -71,8 +69,8 @@ class LightGBMModelSpecifications(ModelSpecifications):
 import lightgbm
 import sklearn

-imports:
-pkgDependencies:
+imports: list[str] = ["lightgbm"]
+pkgDependencies: list[str] = [
 f"numpy=={np.__version__}",
 f"scikit-learn=={sklearn.__version__}",
 f"lightgbm=={lightgbm.__version__}",
@@ -86,8 +84,8 @@ class SklearnModelSelectionModelSpecifications(ModelSpecifications):
 import sklearn
 import xgboost

-imports:
-pkgDependencies:
+imports: list[str] = ["sklearn", "xgboost"]
+pkgDependencies: list[str] = [
 f"numpy=={np.__version__}",
 f"scikit-learn=={sklearn.__version__}",
 f"cloudpickle=={cp.__version__}",
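The specification classes above pin package dependencies to the locally installed versions via f-strings. A standalone sketch of the same pattern (the particular package list is arbitrary):

import cloudpickle as cp
import numpy as np

pkg_dependencies: list[str] = [
    f"numpy=={np.__version__}",
    f"cloudpickle=={cp.__version__}",
]
print(pkg_dependencies)  # e.g. ['numpy==1.26.4', 'cloudpickle==2.2.1'], depending on the environment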
snowflake/ml/modeling/_internal/model_trainer.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Protocol, Union

 import pandas as pd

@@ -18,15 +18,15 @@ class ModelTrainer(Protocol):

 def train_fit_predict(
 self,
-expected_output_cols_list:
+expected_output_cols_list: list[str],
 drop_input_cols: Optional[bool] = False,
 example_output_pd_df: Optional[pd.DataFrame] = None,
-) ->
+) -> tuple[Union[DataFrame, pd.DataFrame], object]:
 raise NotImplementedError

 def train_fit_transform(
 self,
-expected_output_cols_list:
+expected_output_cols_list: list[str],
 drop_input_cols: Optional[bool] = False,
-) ->
+) -> tuple[Union[DataFrame, pd.DataFrame], object]:
 raise NotImplementedError
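ModelTrainer above is a typing.Protocol, so concrete trainers satisfy it structurally rather than by inheritance. A generic, self-contained sketch of that mechanism (the Trainer and LocalTrainer names are invented):

from typing import Protocol


class Trainer(Protocol):
    def train(self) -> object:
        raise NotImplementedError


class LocalTrainer:  # no inheritance needed; the matching method signature is enough
    def train(self) -> object:
        return "fitted estimator"


def run(trainer: Trainer) -> object:
    return trainer.train()


print(run(LocalTrainer()))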
snowflake/ml/modeling/_internal/model_trainer_builder.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union

 import pandas as pd
 from sklearn import model_selection
@@ -71,8 +71,8 @@ class ModelTrainerBuilder:
 cls,
 estimator: object,
 dataset: Union[DataFrame, pd.DataFrame],
-input_cols: Optional[
-label_cols: Optional[
+input_cols: Optional[list[str]] = None,
+label_cols: Optional[list[str]] = None,
 sample_weight_col: Optional[str] = None,
 autogenerated: bool = False,
 subproject: str = "",
@@ -130,7 +130,7 @@ class ModelTrainerBuilder:
 cls,
 estimator: object,
 dataset: Union[DataFrame, pd.DataFrame],
-input_cols:
+input_cols: list[str],
 autogenerated: bool = False,
 subproject: str = "",
 ) -> ModelTrainer:
@@ -169,8 +169,8 @@ class ModelTrainerBuilder:
 cls,
 estimator: object,
 dataset: Union[DataFrame, pd.DataFrame],
-input_cols:
-label_cols: Optional[
+input_cols: list[str],
+label_cols: Optional[list[str]] = None,
 sample_weight_col: Optional[str] = None,
 autogenerated: bool = False,
 subproject: str = "",
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
CHANGED
@@ -5,7 +5,7 @@ import os
 import posixpath
 import sys
 import uuid
-from typing import Any,
+from typing import Any, Optional, Union

 import cloudpickle as cp
 import numpy as np
@@ -50,11 +50,11 @@ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}
 def construct_cv_results(
 estimator: Union[GridSearchCV, RandomizedSearchCV],
 n_split: int,
-param_grid:
-cv_results_raw_hex:
+param_grid: list[dict[str, Any]],
+cv_results_raw_hex: list[Row],
 cross_validator_indices_length: int,
 parameter_grid_length: int,
-) ->
+) -> tuple[bool, dict[str, Any]]:
 """Construct the cross validation result from the UDF. Because we accelerate the process
 by the number of cross validation number, and the combination of parameter grids.
 Therefore, we need to stick them back together instead of returning the raw result
@@ -158,11 +158,11 @@ def construct_cv_results(
 def construct_cv_results_memory_efficient_version(
 estimator: Union[GridSearchCV, RandomizedSearchCV],
 n_split: int,
-param_grid:
-cv_results_raw_hex:
+param_grid: list[dict[str, Any]],
+cv_results_raw_hex: list[Row],
 cross_validator_indices_length: int,
 parameter_grid_length: int,
-) ->
+) -> tuple[Any, dict[str, Any]]:
 """Construct the cross validation result from the UDF.
 The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
 This function need to decode the string and then call _format_result to stick them back together
@@ -210,7 +210,7 @@ def construct_cv_results_memory_efficient_version(
 # because original SearchCV is ranked by parameter first and cv second,
 # to make the memory efficient, we implemented by fitting on cv first and parameter second
 # when retrieving the results back, the ordering should revert back to remain the same result as original SearchCV
-def generate_the_order_by_parameter_index(all_combination_length: int) ->
+def generate_the_order_by_parameter_index(all_combination_length: int) -> list[int]:
 pattern = []
 for i in range(all_combination_length):
 if i % parameter_grid_length == 0:
@@ -221,7 +221,7 @@ def construct_cv_results_memory_efficient_version(
 pattern.append(j)
 return pattern

-def rerank_array(original_array:
+def rerank_array(original_array: list[Any], pattern: list[int]) -> list[Any]:
 reranked_array = []
 for index in pattern:
 reranked_array.append(original_array[index])
@@ -251,8 +251,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 estimator: object,
 dataset: DataFrame,
 session: Session,
-input_cols:
-label_cols: Optional[
+input_cols: list[str],
+label_cols: Optional[list[str]],
 sample_weight_col: Optional[str],
 autogenerated: bool = False,
 subproject: str = "",
@@ -289,10 +289,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 dataset: DataFrame,
 session: Session,
 estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-dependencies:
-udf_imports:
-input_cols:
-label_cols: Optional[
+dependencies: list[str],
+udf_imports: list[str],
+input_cols: list[str],
+label_cols: Optional[list[str]],
 sample_weight_col: Optional[str],
 ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
 from itertools import product
@@ -382,10 +382,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 )
 def _distributed_search(
 session: Session,
-imports:
+imports: list[str],
 stage_estimator_file_name: str,
-input_cols:
-label_cols: Optional[
+input_cols: list[str],
+label_cols: Optional[list[str]],
 ) -> str:
 import os
 import time
@@ -455,12 +455,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 assert estimator is not None

 @cachetools.cached(cache={})
-def _load_data_into_udf() ->
-
+def _load_data_into_udf() -> tuple[
+dict[str, pd.DataFrame],
 Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
 pd.DataFrame,
 int,
-
+list[dict[str, Any]],
 ]:
 import pyarrow.parquet as pq

@@ -512,7 +512,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 self.data_length = data_length
 self.params_to_evaluate = params_to_evaluate

-def process(self, params_idx: int, cv_idx: int) -> Iterator[
+def process(self, params_idx: int, cv_idx: int) -> Iterator[tuple[str]]:
 # Assign parameter to GridSearchCV
 if hasattr(estimator, "param_grid"):
 self.estimator.param_grid = self.params_to_evaluate[params_idx]
@@ -699,10 +699,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 dataset: DataFrame,
 session: Session,
 estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-dependencies:
-udf_imports:
-input_cols:
-label_cols: Optional[
+dependencies: list[str],
+udf_imports: list[str],
+input_cols: list[str],
+label_cols: Optional[list[str]],
 sample_weight_col: Optional[str],
 ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
 from itertools import product
@@ -727,7 +727,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 # Create a temp file and dump the estimator to that file.
 estimator_file_name = temp_file_utils.get_temp_file_path()
 params_to_evaluate = list(param_grid)
-CONSTANTS:
+CONSTANTS: dict[str, Any] = dict()
 CONSTANTS["dataset_snowpark_cols"] = dataset.columns
 CONSTANTS["n_candidates"] = len(params_to_evaluate)
 CONSTANTS["_N_JOBS"] = estimator.n_jobs
@@ -791,10 +791,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 )
 def _distributed_search(
 session: Session,
-imports:
+imports: list[str],
 stage_estimator_file_name: str,
-input_cols:
-label_cols: Optional[
+input_cols: list[str],
+label_cols: Optional[list[str]],
 ) -> str:
 import os
 import time
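The hunks above annotate two small helpers used to reorder distributed HPO results: as the in-code comment explains, fits run cv-split-first for memory efficiency and are then put back into the parameter-first order that scikit-learn's SearchCV expects. A standalone sketch of the re-ranking step, with invented sample values:

from typing import Any


def rerank_array(original_array: list[Any], pattern: list[int]) -> list[Any]:
    # Same logic as the helper above: emit elements in the order given by `pattern`.
    reranked_array = []
    for index in pattern:
        reranked_array.append(original_array[index])
    return reranked_array


print(rerank_array(["r0", "r1", "r2", "r3"], [0, 2, 1, 3]))  # -> ['r0', 'r2', 'r1', 'r3']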