snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Mapping, Optional
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
4
|
from snowflake.ml._internal.utils import (
|
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
|
|
15
15
|
MODEL_JSON_VERSION_NAME_FIELD = "version_name"
|
16
16
|
|
17
17
|
|
18
|
-
def _build_sql_list_from_columns(columns:
|
18
|
+
def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
|
19
19
|
sql_list = ", ".join([f"'{column}'" for column in columns])
|
20
20
|
return f"({sql_list})"
|
21
21
|
|
@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
|
|
60
60
|
function_name: str,
|
61
61
|
warehouse_name: sql_identifier.SqlIdentifier,
|
62
62
|
timestamp_column: sql_identifier.SqlIdentifier,
|
63
|
-
id_columns:
|
64
|
-
prediction_score_columns:
|
65
|
-
prediction_class_columns:
|
66
|
-
actual_score_columns:
|
67
|
-
actual_class_columns:
|
63
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
64
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
65
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
66
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
67
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
68
68
|
refresh_interval: str,
|
69
69
|
aggregation_window: str,
|
70
70
|
baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
|
71
71
|
baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
|
72
72
|
baseline: Optional[sql_identifier.SqlIdentifier] = None,
|
73
|
-
statement_params: Optional[
|
73
|
+
statement_params: Optional[dict[str, Any]] = None,
|
74
74
|
) -> None:
|
75
75
|
baseline_sql = ""
|
76
76
|
if baseline:
|
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
|
|
103
103
|
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
104
104
|
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
105
105
|
monitor_name: sql_identifier.SqlIdentifier,
|
106
|
-
statement_params: Optional[
|
106
|
+
statement_params: Optional[dict[str, Any]] = None,
|
107
107
|
) -> None:
|
108
108
|
search_database_name = database_name or self._database_name
|
109
109
|
search_schema_name = schema_name or self._schema_name
|
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
|
|
116
116
|
def show_model_monitors(
|
117
117
|
self,
|
118
118
|
*,
|
119
|
-
statement_params: Optional[
|
120
|
-
) ->
|
119
|
+
statement_params: Optional[dict[str, Any]] = None,
|
120
|
+
) -> list[snowpark.Row]:
|
121
121
|
fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
122
122
|
return (
|
123
123
|
query_result_checker.SqlResultValidator(
|
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
|
|
135
135
|
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
136
136
|
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
137
137
|
monitor_name: sql_identifier.SqlIdentifier,
|
138
|
-
statement_params: Optional[
|
138
|
+
statement_params: Optional[dict[str, Any]] = None,
|
139
139
|
) -> bool:
|
140
140
|
search_database_name = database_name or self._database_name
|
141
141
|
search_schema_name = schema_name or self._schema_name
|
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
|
|
153
153
|
def validate_monitor_warehouse(
|
154
154
|
self,
|
155
155
|
warehouse_name: sql_identifier.SqlIdentifier,
|
156
|
-
statement_params: Optional[
|
156
|
+
statement_params: Optional[dict[str, Any]] = None,
|
157
157
|
) -> None:
|
158
158
|
"""Validate warehouse provided for monitoring exists.
|
159
159
|
|
@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
|
|
177
177
|
*,
|
178
178
|
source_column_schema: Mapping[str, types.DataType],
|
179
179
|
timestamp_column: sql_identifier.SqlIdentifier,
|
180
|
-
prediction_score_columns:
|
181
|
-
prediction_class_columns:
|
182
|
-
actual_score_columns:
|
183
|
-
actual_class_columns:
|
184
|
-
id_columns:
|
180
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
181
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
182
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
183
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
184
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
185
185
|
) -> None:
|
186
186
|
"""Ensures all columns exist in the source table.
|
187
187
|
|
@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
|
|
221
221
|
source_schema: Optional[sql_identifier.SqlIdentifier],
|
222
222
|
source: sql_identifier.SqlIdentifier,
|
223
223
|
timestamp_column: sql_identifier.SqlIdentifier,
|
224
|
-
prediction_score_columns:
|
225
|
-
prediction_class_columns:
|
226
|
-
actual_score_columns:
|
227
|
-
actual_class_columns:
|
228
|
-
id_columns:
|
224
|
+
prediction_score_columns: list[sql_identifier.SqlIdentifier],
|
225
|
+
prediction_class_columns: list[sql_identifier.SqlIdentifier],
|
226
|
+
actual_score_columns: list[sql_identifier.SqlIdentifier],
|
227
|
+
actual_class_columns: list[sql_identifier.SqlIdentifier],
|
228
|
+
id_columns: list[sql_identifier.SqlIdentifier],
|
229
229
|
) -> None:
|
230
230
|
source_database = source_database or self._database_name
|
231
231
|
source_schema = source_schema or self._schema_name
|
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
|
|
250
250
|
self,
|
251
251
|
operation: str,
|
252
252
|
monitor_name: sql_identifier.SqlIdentifier,
|
253
|
-
statement_params: Optional[
|
253
|
+
statement_params: Optional[dict[str, Any]] = None,
|
254
254
|
) -> None:
|
255
255
|
if operation not in {"SUSPEND", "RESUME"}:
|
256
256
|
raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
|
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
|
|
263
263
|
def suspend_monitor(
|
264
264
|
self,
|
265
265
|
monitor_name: sql_identifier.SqlIdentifier,
|
266
|
-
statement_params: Optional[
|
266
|
+
statement_params: Optional[dict[str, Any]] = None,
|
267
267
|
) -> None:
|
268
268
|
self._alter_monitor(
|
269
269
|
operation="SUSPEND",
|
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
|
|
274
274
|
def resume_monitor(
|
275
275
|
self,
|
276
276
|
monitor_name: sql_identifier.SqlIdentifier,
|
277
|
-
statement_params: Optional[
|
277
|
+
statement_params: Optional[dict[str, Any]] = None,
|
278
278
|
) -> None:
|
279
279
|
self._alter_monitor(
|
280
280
|
operation="RESUME",
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
from snowflake import snowpark
|
5
5
|
from snowflake.ml._internal.utils import sql_identifier
|
@@ -20,7 +20,7 @@ class ModelMonitorManager:
|
|
20
20
|
database_name: sql_identifier.SqlIdentifier,
|
21
21
|
schema_name: sql_identifier.SqlIdentifier,
|
22
22
|
*,
|
23
|
-
statement_params: Optional[
|
23
|
+
statement_params: Optional[dict[str, Any]] = None,
|
24
24
|
) -> None:
|
25
25
|
"""
|
26
26
|
Opens a ModelMonitorManager for a given database and schema.
|
@@ -64,7 +64,7 @@ class ModelMonitorManager:
|
|
64
64
|
f"Found: {existing_target_methods}."
|
65
65
|
)
|
66
66
|
|
67
|
-
def _build_column_list_from_input(self, columns: Optional[
|
67
|
+
def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
|
68
68
|
return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
|
69
69
|
|
70
70
|
def add_monitor(
|
@@ -172,7 +172,7 @@ class ModelMonitorManager:
|
|
172
172
|
"""
|
173
173
|
rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
|
174
174
|
|
175
|
-
def model_match_fn(model_details:
|
175
|
+
def model_match_fn(model_details: dict[str, str]) -> bool:
|
176
176
|
return (
|
177
177
|
model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
|
178
178
|
and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
|
@@ -215,7 +215,7 @@ class ModelMonitorManager:
|
|
215
215
|
name=monitor_name_id,
|
216
216
|
)
|
217
217
|
|
218
|
-
def show_model_monitors(self) ->
|
218
|
+
def show_model_monitors(self) -> list[snowpark.Row]:
|
219
219
|
"""Show all model monitors in the registry.
|
220
220
|
|
221
221
|
Returns:
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
|
-
from typing import
|
2
|
+
from typing import Optional
|
3
3
|
|
4
4
|
from snowflake.ml.model._client.model import model_version_impl
|
5
5
|
|
@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
|
|
14
14
|
timestamp_column: str
|
15
15
|
"""Name of column in the source containing timestamp."""
|
16
16
|
|
17
|
-
id_columns:
|
17
|
+
id_columns: list[str]
|
18
18
|
"""List of columns in the source containing unique identifiers."""
|
19
19
|
|
20
|
-
prediction_score_columns: Optional[
|
20
|
+
prediction_score_columns: Optional[list[str]] = None
|
21
21
|
"""List of columns in the source containing prediction scores.
|
22
22
|
Can be regression scores for regression models and probability scores for classification models."""
|
23
23
|
|
24
|
-
prediction_class_columns: Optional[
|
24
|
+
prediction_class_columns: Optional[list[str]] = None
|
25
25
|
"""List of columns in the source containing prediction classes for classification models."""
|
26
26
|
|
27
|
-
actual_score_columns: Optional[
|
27
|
+
actual_score_columns: Optional[list[str]] = None
|
28
28
|
"""List of columns in the source containing actual scores."""
|
29
29
|
|
30
|
-
actual_class_columns: Optional[
|
30
|
+
actual_class_columns: Optional[list[str]] = None
|
31
31
|
"""List of columns in the source containing actual classes for classification models."""
|
32
32
|
|
33
33
|
baseline: Optional[str] = None
|
@@ -1,11 +1,10 @@
|
|
1
|
-
import os
|
2
1
|
from types import ModuleType
|
3
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, Union
|
4
3
|
|
5
4
|
import pandas as pd
|
6
5
|
from absl.logging import logging
|
7
6
|
|
8
|
-
from snowflake.ml._internal import platform_capabilities, telemetry
|
7
|
+
from snowflake.ml._internal import env, platform_capabilities, telemetry
|
9
8
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
10
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
11
10
|
from snowflake.ml._internal.utils import sql_identifier
|
@@ -14,7 +13,6 @@ from snowflake.ml.model._client.model import model_impl, model_version_impl
|
|
14
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
15
14
|
from snowflake.ml.model._model_composer import model_composer
|
16
15
|
from snowflake.ml.model._packager.model_meta import model_meta
|
17
|
-
from snowflake.ml.modeling._internal import constants
|
18
16
|
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
19
17
|
|
20
18
|
logger = logging.getLogger(__name__)
|
@@ -45,20 +43,21 @@ class ModelManager:
|
|
45
43
|
model_name: str,
|
46
44
|
version_name: Optional[str] = None,
|
47
45
|
comment: Optional[str] = None,
|
48
|
-
metrics: Optional[
|
49
|
-
conda_dependencies: Optional[
|
50
|
-
pip_requirements: Optional[
|
51
|
-
artifact_repository_map: Optional[
|
52
|
-
|
46
|
+
metrics: Optional[dict[str, Any]] = None,
|
47
|
+
conda_dependencies: Optional[list[str]] = None,
|
48
|
+
pip_requirements: Optional[list[str]] = None,
|
49
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
50
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
51
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
53
52
|
python_version: Optional[str] = None,
|
54
|
-
signatures: Optional[
|
53
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
55
54
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
56
|
-
user_files: Optional[
|
57
|
-
code_paths: Optional[
|
58
|
-
ext_modules: Optional[
|
55
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
56
|
+
code_paths: Optional[list[str]] = None,
|
57
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
59
58
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
60
59
|
options: Optional[model_types.ModelSaveOption] = None,
|
61
|
-
statement_params: Optional[
|
60
|
+
statement_params: Optional[dict[str, Any]] = None,
|
62
61
|
) -> model_version_impl.ModelVersion:
|
63
62
|
|
64
63
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
@@ -131,6 +130,7 @@ class ModelManager:
|
|
131
130
|
conda_dependencies=conda_dependencies,
|
132
131
|
pip_requirements=pip_requirements,
|
133
132
|
artifact_repository_map=artifact_repository_map,
|
133
|
+
resource_constraint=resource_constraint,
|
134
134
|
target_platforms=target_platforms,
|
135
135
|
python_version=python_version,
|
136
136
|
signatures=signatures,
|
@@ -150,20 +150,21 @@ class ModelManager:
|
|
150
150
|
model_name: str,
|
151
151
|
version_name: str,
|
152
152
|
comment: Optional[str] = None,
|
153
|
-
metrics: Optional[
|
154
|
-
conda_dependencies: Optional[
|
155
|
-
pip_requirements: Optional[
|
156
|
-
artifact_repository_map: Optional[
|
157
|
-
|
153
|
+
metrics: Optional[dict[str, Any]] = None,
|
154
|
+
conda_dependencies: Optional[list[str]] = None,
|
155
|
+
pip_requirements: Optional[list[str]] = None,
|
156
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
157
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
158
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
158
159
|
python_version: Optional[str] = None,
|
159
|
-
signatures: Optional[
|
160
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
160
161
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
161
|
-
user_files: Optional[
|
162
|
-
code_paths: Optional[
|
163
|
-
ext_modules: Optional[
|
162
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
163
|
+
code_paths: Optional[list[str]] = None,
|
164
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
164
165
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
165
166
|
options: Optional[model_types.ModelSaveOption] = None,
|
166
|
-
statement_params: Optional[
|
167
|
+
statement_params: Optional[dict[str, Any]] = None,
|
167
168
|
) -> model_version_impl.ModelVersion:
|
168
169
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
169
170
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
@@ -212,7 +213,7 @@ class ModelManager:
|
|
212
213
|
platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
|
213
214
|
else:
|
214
215
|
# Default the target platform to SPCS if not specified when running in ML runtime
|
215
|
-
if
|
216
|
+
if env.IN_ML_RUNTIME:
|
216
217
|
logger.info(
|
217
218
|
"Logging the model on Container Runtime for ML without specifying `target_platforms`. "
|
218
219
|
'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
|
@@ -253,6 +254,7 @@ class ModelManager:
|
|
253
254
|
conda_dependencies=conda_dependencies,
|
254
255
|
pip_requirements=pip_requirements,
|
255
256
|
artifact_repository_map=artifact_repository_map,
|
257
|
+
resource_constraint=resource_constraint,
|
256
258
|
target_platforms=platforms,
|
257
259
|
python_version=python_version,
|
258
260
|
user_files=user_files,
|
@@ -314,7 +316,7 @@ class ModelManager:
|
|
314
316
|
self,
|
315
317
|
model_name: str,
|
316
318
|
*,
|
317
|
-
statement_params: Optional[
|
319
|
+
statement_params: Optional[dict[str, Any]] = None,
|
318
320
|
) -> model_impl.Model:
|
319
321
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
320
322
|
if self._model_ops.validate_existence(
|
@@ -342,8 +344,8 @@ class ModelManager:
|
|
342
344
|
def models(
|
343
345
|
self,
|
344
346
|
*,
|
345
|
-
statement_params: Optional[
|
346
|
-
) ->
|
347
|
+
statement_params: Optional[dict[str, Any]] = None,
|
348
|
+
) -> list[model_impl.Model]:
|
347
349
|
model_names = self._model_ops.list_models_or_versions(
|
348
350
|
database_name=None,
|
349
351
|
schema_name=None,
|
@@ -361,7 +363,7 @@ class ModelManager:
|
|
361
363
|
def show_models(
|
362
364
|
self,
|
363
365
|
*,
|
364
|
-
statement_params: Optional[
|
366
|
+
statement_params: Optional[dict[str, Any]] = None,
|
365
367
|
) -> pd.DataFrame:
|
366
368
|
rows = self._model_ops.show_models_or_versions(
|
367
369
|
database_name=None,
|
@@ -374,7 +376,7 @@ class ModelManager:
|
|
374
376
|
self,
|
375
377
|
model_name: str,
|
376
378
|
*,
|
377
|
-
statement_params: Optional[
|
379
|
+
statement_params: Optional[dict[str, Any]] = None,
|
378
380
|
) -> None:
|
379
381
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
380
382
|
|
@@ -387,7 +389,7 @@ class ModelManager:
|
|
387
389
|
|
388
390
|
def _parse_fully_qualified_name(
|
389
391
|
self, model_name: str
|
390
|
-
) ->
|
392
|
+
) -> tuple[
|
391
393
|
Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
|
392
394
|
]:
|
393
395
|
try:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import warnings
|
2
2
|
from types import ModuleType
|
3
|
-
from typing import Any,
|
3
|
+
from typing import Any, Optional, Union, overload
|
4
4
|
|
5
5
|
import pandas as pd
|
6
6
|
|
@@ -36,7 +36,7 @@ class Registry:
|
|
36
36
|
*,
|
37
37
|
database_name: Optional[str] = None,
|
38
38
|
schema_name: Optional[str] = None,
|
39
|
-
options: Optional[
|
39
|
+
options: Optional[dict[str, Any]] = None,
|
40
40
|
) -> None:
|
41
41
|
"""Opens a registry within a pre-created Snowflake schema.
|
42
42
|
|
@@ -107,17 +107,18 @@ class Registry:
|
|
107
107
|
model_name: str,
|
108
108
|
version_name: Optional[str] = None,
|
109
109
|
comment: Optional[str] = None,
|
110
|
-
metrics: Optional[
|
111
|
-
conda_dependencies: Optional[
|
112
|
-
pip_requirements: Optional[
|
113
|
-
artifact_repository_map: Optional[
|
114
|
-
|
110
|
+
metrics: Optional[dict[str, Any]] = None,
|
111
|
+
conda_dependencies: Optional[list[str]] = None,
|
112
|
+
pip_requirements: Optional[list[str]] = None,
|
113
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
114
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
115
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
115
116
|
python_version: Optional[str] = None,
|
116
|
-
signatures: Optional[
|
117
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
117
118
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
118
|
-
user_files: Optional[
|
119
|
-
code_paths: Optional[
|
120
|
-
ext_modules: Optional[
|
119
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
120
|
+
code_paths: Optional[list[str]] = None,
|
121
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
121
122
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
122
123
|
options: Optional[model_types.ModelSaveOption] = None,
|
123
124
|
) -> ModelVersion:
|
@@ -152,6 +153,7 @@ class Registry:
|
|
152
153
|
Format: {channel_name: artifact_repository_name}, where:
|
153
154
|
- channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
|
154
155
|
- artifact_repository_name: The name or URL of the repository to fetch packages from.
|
156
|
+
resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
|
155
157
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
156
158
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
157
159
|
python_version: Python version in which the model is run. Defaults to None.
|
@@ -232,6 +234,7 @@ class Registry:
|
|
232
234
|
"conda_dependencies",
|
233
235
|
"pip_requirements",
|
234
236
|
"artifact_repository_map",
|
237
|
+
"resource_constraint",
|
235
238
|
"target_platforms",
|
236
239
|
"python_version",
|
237
240
|
"signatures",
|
@@ -244,17 +247,18 @@ class Registry:
|
|
244
247
|
model_name: str,
|
245
248
|
version_name: Optional[str] = None,
|
246
249
|
comment: Optional[str] = None,
|
247
|
-
metrics: Optional[
|
248
|
-
conda_dependencies: Optional[
|
249
|
-
pip_requirements: Optional[
|
250
|
-
artifact_repository_map: Optional[
|
251
|
-
|
250
|
+
metrics: Optional[dict[str, Any]] = None,
|
251
|
+
conda_dependencies: Optional[list[str]] = None,
|
252
|
+
pip_requirements: Optional[list[str]] = None,
|
253
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
254
|
+
resource_constraint: Optional[dict[str, str]] = None,
|
255
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
252
256
|
python_version: Optional[str] = None,
|
253
|
-
signatures: Optional[
|
257
|
+
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
254
258
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
255
|
-
user_files: Optional[
|
256
|
-
code_paths: Optional[
|
257
|
-
ext_modules: Optional[
|
259
|
+
user_files: Optional[dict[str, list[str]]] = None,
|
260
|
+
code_paths: Optional[list[str]] = None,
|
261
|
+
ext_modules: Optional[list[ModuleType]] = None,
|
258
262
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
259
263
|
options: Optional[model_types.ModelSaveOption] = None,
|
260
264
|
) -> ModelVersion:
|
@@ -289,6 +293,7 @@ class Registry:
|
|
289
293
|
Format: {channel_name: artifact_repository_name}, where:
|
290
294
|
- channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
|
291
295
|
- artifact_repository_name: The name or URL of the repository to fetch packages from.
|
296
|
+
resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
|
292
297
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
293
298
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
294
299
|
python_version: Python version in which the model is run. Defaults to None.
|
@@ -373,6 +378,7 @@ class Registry:
|
|
373
378
|
conda_dependencies,
|
374
379
|
pip_requirements,
|
375
380
|
artifact_repository_map,
|
381
|
+
resource_constraint,
|
376
382
|
target_platforms,
|
377
383
|
python_version,
|
378
384
|
signatures,
|
@@ -407,6 +413,7 @@ class Registry:
|
|
407
413
|
conda_dependencies=conda_dependencies,
|
408
414
|
pip_requirements=pip_requirements,
|
409
415
|
artifact_repository_map=artifact_repository_map,
|
416
|
+
resource_constraint=resource_constraint,
|
410
417
|
target_platforms=target_platforms,
|
411
418
|
python_version=python_version,
|
412
419
|
signatures=signatures,
|
@@ -442,7 +449,7 @@ class Registry:
|
|
442
449
|
project=_TELEMETRY_PROJECT,
|
443
450
|
subproject=_MODEL_TELEMETRY_SUBPROJECT,
|
444
451
|
)
|
445
|
-
def models(self) ->
|
452
|
+
def models(self) -> list[Model]:
|
446
453
|
"""Get all models in the schema where the registry is opened.
|
447
454
|
|
448
455
|
Returns:
|
@@ -568,7 +575,7 @@ class Registry:
|
|
568
575
|
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
569
576
|
)
|
570
577
|
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
571
|
-
def show_model_monitors(self) ->
|
578
|
+
def show_model_monitors(self) -> list[snowpark.Row]:
|
572
579
|
"""Show all model monitors in the registry.
|
573
580
|
|
574
581
|
Returns:
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import http
|
2
2
|
import logging
|
3
3
|
from datetime import timedelta
|
4
|
-
from typing import
|
4
|
+
from typing import Optional
|
5
5
|
|
6
6
|
import requests
|
7
7
|
from cryptography.hazmat.primitives.asymmetric import types
|
@@ -10,7 +10,7 @@ from requests import auth
|
|
10
10
|
from snowflake.ml._internal.utils import jwt_generator
|
11
11
|
|
12
12
|
logger = logging.getLogger(__name__)
|
13
|
-
_JWT_TOKEN_CACHE:
|
13
|
+
_JWT_TOKEN_CACHE: dict[str, dict[int, str]] = {}
|
14
14
|
|
15
15
|
|
16
16
|
def get_jwt_token_generator(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import configparser
|
2
2
|
import os
|
3
|
-
from typing import
|
3
|
+
from typing import Optional, Union
|
4
4
|
|
5
5
|
from absl import logging
|
6
6
|
from cryptography.hazmat import backends
|
@@ -76,7 +76,7 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
|
|
76
76
|
)
|
77
77
|
|
78
78
|
|
79
|
-
def _connection_properties_from_env() ->
|
79
|
+
def _connection_properties_from_env() -> dict[str, str]:
|
80
80
|
"""Returns a dict with all possible login related env variables."""
|
81
81
|
sf_conn_prop = {
|
82
82
|
# Mandatory fields
|
@@ -104,7 +104,7 @@ def _connection_properties_from_env() -> Dict[str, str]:
|
|
104
104
|
return sf_conn_prop
|
105
105
|
|
106
106
|
|
107
|
-
def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") ->
|
107
|
+
def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
|
108
108
|
"""Loads the dictionary from snowsql config file."""
|
109
109
|
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
110
110
|
if not os.path.exists(snowsql_config_file):
|
@@ -133,7 +133,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
133
133
|
|
134
134
|
|
135
135
|
@snowpark._internal.utils.private_preview(version="0.2.0")
|
136
|
-
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) ->
|
136
|
+
def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
|
137
137
|
"""Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
|
138
138
|
|
139
139
|
NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
|
@@ -164,7 +164,7 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
|
|
164
164
|
Raises:
|
165
165
|
Exception: if none of config file and environment variable are present.
|
166
166
|
"""
|
167
|
-
conn_prop:
|
167
|
+
conn_prop: dict[str, Union[str, bytes]] = {}
|
168
168
|
login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
169
169
|
# If login file exists, use this exclusively.
|
170
170
|
if os.path.exists(login_file):
|
snowflake/ml/utils/sparse.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import collections
|
2
2
|
import json
|
3
|
-
from typing import
|
3
|
+
from typing import Optional
|
4
4
|
|
5
5
|
import pandas as pd
|
6
6
|
from pandas import arrays as pandas_arrays
|
@@ -9,7 +9,7 @@ from pandas.core.arrays import sparse as pandas_sparse
|
|
9
9
|
from snowflake.snowpark import DataFrame
|
10
10
|
|
11
11
|
|
12
|
-
def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols:
|
12
|
+
def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: list[str]) -> Optional[pd.DataFrame]:
|
13
13
|
"""Convert the pandas df into pandas df with multiple SparseArray columns."""
|
14
14
|
num_rows = pandas_df.shape[0]
|
15
15
|
if num_rows == 0:
|
@@ -52,8 +52,9 @@ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) ->
|
|
52
52
|
return pandas_df
|
53
53
|
|
54
54
|
|
55
|
-
def to_pandas_with_sparse(df: DataFrame, sparse_cols:
|
56
|
-
"""Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
|
55
|
+
def to_pandas_with_sparse(df: DataFrame, sparse_cols: list[str]) -> pd.DataFrame:
|
56
|
+
"""Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
|
57
|
+
columns.
|
57
58
|
|
58
59
|
For example, for below input:
|
59
60
|
----------------------------------------------
|
snowflake/ml/utils/sql_client.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import Dict
|
3
2
|
|
4
3
|
|
5
4
|
class CreationOption(Enum):
|
@@ -13,7 +12,7 @@ class CreationMode:
|
|
13
12
|
self.if_not_exists = if_not_exists
|
14
13
|
self.or_replace = or_replace
|
15
14
|
|
16
|
-
def get_ddl_phrases(self) ->
|
15
|
+
def get_ddl_phrases(self) -> dict[CreationOption, str]:
|
17
16
|
if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
|
18
17
|
or_replace_sql = " OR REPLACE" if self.or_replace else ""
|
19
18
|
return {
|
snowflake/ml/version.py
CHANGED
@@ -1 +1,2 @@
|
|
1
|
-
|
1
|
+
# This is parsed by regex in conda recipe meta file. Make sure not to break it.
|
2
|
+
VERSION = "1.8.3"
|