snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/monitoring/_client/model_monitor_sql_client.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Mapping, Optional

 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
 MODEL_JSON_VERSION_NAME_FIELD = "version_name"


-def _build_sql_list_from_columns(columns:
+def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
     sql_list = ", ".join([f"'{column}'" for column in columns])
     return f"({sql_list})"

@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
     function_name: str,
     warehouse_name: sql_identifier.SqlIdentifier,
     timestamp_column: sql_identifier.SqlIdentifier,
-    id_columns:
-    prediction_score_columns:
-    prediction_class_columns:
-    actual_score_columns:
-    actual_class_columns:
+    id_columns: list[sql_identifier.SqlIdentifier],
+    prediction_score_columns: list[sql_identifier.SqlIdentifier],
+    prediction_class_columns: list[sql_identifier.SqlIdentifier],
+    actual_score_columns: list[sql_identifier.SqlIdentifier],
+    actual_class_columns: list[sql_identifier.SqlIdentifier],
     refresh_interval: str,
     aggregation_window: str,
     baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
     baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
     baseline: Optional[sql_identifier.SqlIdentifier] = None,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     baseline_sql = ""
     if baseline:
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
     database_name: Optional[sql_identifier.SqlIdentifier] = None,
     schema_name: Optional[sql_identifier.SqlIdentifier] = None,
     monitor_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     search_database_name = database_name or self._database_name
     search_schema_name = schema_name or self._schema_name
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
 def show_model_monitors(
     self,
     *,
-    statement_params: Optional[
-) ->
+    statement_params: Optional[dict[str, Any]] = None,
+) -> list[snowpark.Row]:
     fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
     return (
         query_result_checker.SqlResultValidator(
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
     database_name: Optional[sql_identifier.SqlIdentifier] = None,
     schema_name: Optional[sql_identifier.SqlIdentifier] = None,
     monitor_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> bool:
     search_database_name = database_name or self._database_name
     search_schema_name = schema_name or self._schema_name
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
 def validate_monitor_warehouse(
     self,
     warehouse_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     """Validate warehouse provided for monitoring exists.

@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
     *,
     source_column_schema: Mapping[str, types.DataType],
     timestamp_column: sql_identifier.SqlIdentifier,
-    prediction_score_columns:
-    prediction_class_columns:
-    actual_score_columns:
-    actual_class_columns:
-    id_columns:
+    prediction_score_columns: list[sql_identifier.SqlIdentifier],
+    prediction_class_columns: list[sql_identifier.SqlIdentifier],
+    actual_score_columns: list[sql_identifier.SqlIdentifier],
+    actual_class_columns: list[sql_identifier.SqlIdentifier],
+    id_columns: list[sql_identifier.SqlIdentifier],
 ) -> None:
     """Ensures all columns exist in the source table.

@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
     source_schema: Optional[sql_identifier.SqlIdentifier],
     source: sql_identifier.SqlIdentifier,
     timestamp_column: sql_identifier.SqlIdentifier,
-    prediction_score_columns:
-    prediction_class_columns:
-    actual_score_columns:
-    actual_class_columns:
-    id_columns:
+    prediction_score_columns: list[sql_identifier.SqlIdentifier],
+    prediction_class_columns: list[sql_identifier.SqlIdentifier],
+    actual_score_columns: list[sql_identifier.SqlIdentifier],
+    actual_class_columns: list[sql_identifier.SqlIdentifier],
+    id_columns: list[sql_identifier.SqlIdentifier],
 ) -> None:
     source_database = source_database or self._database_name
     source_schema = source_schema or self._schema_name
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
     self,
     operation: str,
     monitor_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     if operation not in {"SUSPEND", "RESUME"}:
         raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
 def suspend_monitor(
     self,
     monitor_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     self._alter_monitor(
         operation="SUSPEND",
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
 def resume_monitor(
     self,
     monitor_name: sql_identifier.SqlIdentifier,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     self._alter_monitor(
         operation="RESUME",
snowflake/ml/monitoring/_manager/model_monitor_manager.py
CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Any,
+from typing import Any, Optional

 from snowflake import snowpark
 from snowflake.ml._internal.utils import sql_identifier
@@ -20,7 +20,7 @@ class ModelMonitorManager:
     database_name: sql_identifier.SqlIdentifier,
     schema_name: sql_identifier.SqlIdentifier,
     *,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     """
     Opens a ModelMonitorManager for a given database and schema.
@@ -64,7 +64,7 @@ class ModelMonitorManager:
         f"Found: {existing_target_methods}."
     )

-def _build_column_list_from_input(self, columns: Optional[
+def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
     return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []

 def add_monitor(
@@ -172,7 +172,7 @@ class ModelMonitorManager:
     """
     rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)

-    def model_match_fn(model_details:
+    def model_match_fn(model_details: dict[str, str]) -> bool:
         return (
             model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
             and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
@@ -215,7 +215,7 @@ class ModelMonitorManager:
     name=monitor_name_id,
 )

-def show_model_monitors(self) ->
+def show_model_monitors(self) -> list[snowpark.Row]:
     """Show all model monitors in the registry.

     Returns:
snowflake/ml/monitoring/entities/model_monitor_config.py
CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import
+from typing import Optional

 from snowflake.ml.model._client.model import model_version_impl

@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
 timestamp_column: str
 """Name of column in the source containing timestamp."""

-id_columns:
+id_columns: list[str]
 """List of columns in the source containing unique identifiers."""

-prediction_score_columns: Optional[
+prediction_score_columns: Optional[list[str]] = None
 """List of columns in the source containing prediction scores.
 Can be regression scores for regression models and probability scores for classification models."""

-prediction_class_columns: Optional[
+prediction_class_columns: Optional[list[str]] = None
 """List of columns in the source containing prediction classes for classification models."""

-actual_score_columns: Optional[
+actual_score_columns: Optional[list[str]] = None
 """List of columns in the source containing actual scores."""

-actual_class_columns: Optional[
+actual_class_columns: Optional[list[str]] = None
 """List of columns in the source containing actual classes for classification models."""

 baseline: Optional[str] = None
snowflake/ml/registry/_manager/model_manager.py
CHANGED
@@ -1,10 +1,10 @@
 from types import ModuleType
-from typing import Any,
+from typing import Any, Optional, Union

 import pandas as pd
 from absl.logging import logging

-from snowflake.ml._internal import platform_capabilities, telemetry
+from snowflake.ml._internal import env, platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
@@ -43,20 +43,21 @@ class ModelManager:
     model_name: str,
     version_name: Optional[str] = None,
     comment: Optional[str] = None,
-    metrics: Optional[
-    conda_dependencies: Optional[
-    pip_requirements: Optional[
-    artifact_repository_map: Optional[
-    target_platforms: Optional[
+    metrics: Optional[dict[str, Any]] = None,
+    conda_dependencies: Optional[list[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    artifact_repository_map: Optional[dict[str, str]] = None,
+    resource_constraint: Optional[dict[str, str]] = None,
+    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
     python_version: Optional[str] = None,
-    signatures: Optional[
+    signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     sample_input_data: Optional[model_types.SupportedDataType] = None,
-    user_files: Optional[
-    code_paths: Optional[
-    ext_modules: Optional[
+    user_files: Optional[dict[str, list[str]]] = None,
+    code_paths: Optional[list[str]] = None,
+    ext_modules: Optional[list[ModuleType]] = None,
     task: model_types.Task = model_types.Task.UNKNOWN,
     options: Optional[model_types.ModelSaveOption] = None,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> model_version_impl.ModelVersion:

     database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -129,6 +130,7 @@ class ModelManager:
     conda_dependencies=conda_dependencies,
     pip_requirements=pip_requirements,
     artifact_repository_map=artifact_repository_map,
+    resource_constraint=resource_constraint,
     target_platforms=target_platforms,
     python_version=python_version,
     signatures=signatures,
@@ -148,20 +150,21 @@ class ModelManager:
     model_name: str,
     version_name: str,
     comment: Optional[str] = None,
-    metrics: Optional[
-    conda_dependencies: Optional[
-    pip_requirements: Optional[
-    artifact_repository_map: Optional[
-    target_platforms: Optional[
+    metrics: Optional[dict[str, Any]] = None,
+    conda_dependencies: Optional[list[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    artifact_repository_map: Optional[dict[str, str]] = None,
+    resource_constraint: Optional[dict[str, str]] = None,
+    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
     python_version: Optional[str] = None,
-    signatures: Optional[
+    signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     sample_input_data: Optional[model_types.SupportedDataType] = None,
-    user_files: Optional[
-    code_paths: Optional[
-    ext_modules: Optional[
+    user_files: Optional[dict[str, list[str]]] = None,
+    code_paths: Optional[list[str]] = None,
+    ext_modules: Optional[list[ModuleType]] = None,
     task: model_types.Task = model_types.Task.UNKNOWN,
     options: Optional[model_types.ModelSaveOption] = None,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> model_version_impl.ModelVersion:
     database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
     version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -208,6 +211,14 @@ class ModelManager:
     if target_platforms:
         # Convert any string target platforms to TargetPlatform objects
         platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
+    else:
+        # Default the target platform to SPCS if not specified when running in ML runtime
+        if env.IN_ML_RUNTIME:
+            logger.info(
+                "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
+                'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
+            )
+            platforms = [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]

     if artifact_repository_map:
         for channel, artifact_repository_name in artifact_repository_map.items():
@@ -223,8 +234,17 @@ class ModelManager:

     logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")

+    # Extract save_location from options if present
+    save_location = None
+    if options and "save_location" in options:
+        save_location = options.get("save_location")
+        logger.info(f"Model will be saved to local directory: {save_location}")
+
     mc = model_composer.ModelComposer(
-        self._model_ops._session,
+        self._model_ops._session,
+        stage_path=stage_path,
+        statement_params=statement_params,
+        save_location=save_location,
     )
     model_metadata: model_meta.ModelMetadata = mc.save(
         name=model_name_id.resolved(),
@@ -234,6 +254,7 @@ class ModelManager:
     conda_dependencies=conda_dependencies,
     pip_requirements=pip_requirements,
     artifact_repository_map=artifact_repository_map,
+    resource_constraint=resource_constraint,
     target_platforms=platforms,
     python_version=python_version,
     user_files=user_files,
@@ -295,7 +316,7 @@ class ModelManager:
     self,
     model_name: str,
     *,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> model_impl.Model:
     database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
     if self._model_ops.validate_existence(
@@ -323,8 +344,8 @@ class ModelManager:
 def models(
     self,
     *,
-    statement_params: Optional[
-) ->
+    statement_params: Optional[dict[str, Any]] = None,
+) -> list[model_impl.Model]:
     model_names = self._model_ops.list_models_or_versions(
         database_name=None,
         schema_name=None,
@@ -342,7 +363,7 @@ class ModelManager:
 def show_models(
     self,
     *,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> pd.DataFrame:
     rows = self._model_ops.show_models_or_versions(
         database_name=None,
@@ -355,7 +376,7 @@ class ModelManager:
     self,
     model_name: str,
     *,
-    statement_params: Optional[
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> None:
     database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)

@@ -368,7 +389,7 @@ class ModelManager:

 def _parse_fully_qualified_name(
     self, model_name: str
-) ->
+) -> tuple[
     Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
 ]:
     try:
snowflake/ml/registry/registry.py
CHANGED
@@ -1,6 +1,6 @@
 import warnings
 from types import ModuleType
-from typing import Any,
+from typing import Any, Optional, Union, overload

 import pandas as pd

@@ -36,7 +36,7 @@ class Registry:
     *,
     database_name: Optional[str] = None,
     schema_name: Optional[str] = None,
-    options: Optional[
+    options: Optional[dict[str, Any]] = None,
 ) -> None:
     """Opens a registry within a pre-created Snowflake schema.

@@ -75,7 +75,9 @@ class Registry:
     )

     self._model_manager = model_manager.ModelManager(
-        session,
+        session,
+        database_name=self._database_name,
+        schema_name=self._schema_name,
     )

     self.enable_monitoring = options.get("enable_monitoring", True) if options else True
@@ -105,17 +107,18 @@ class Registry:
     model_name: str,
     version_name: Optional[str] = None,
     comment: Optional[str] = None,
-    metrics: Optional[
-    conda_dependencies: Optional[
-    pip_requirements: Optional[
-    artifact_repository_map: Optional[
-    target_platforms: Optional[
+    metrics: Optional[dict[str, Any]] = None,
+    conda_dependencies: Optional[list[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    artifact_repository_map: Optional[dict[str, str]] = None,
+    resource_constraint: Optional[dict[str, str]] = None,
+    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
     python_version: Optional[str] = None,
-    signatures: Optional[
+    signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     sample_input_data: Optional[model_types.SupportedDataType] = None,
-    user_files: Optional[
-    code_paths: Optional[
-    ext_modules: Optional[
+    user_files: Optional[dict[str, list[str]]] = None,
+    code_paths: Optional[list[str]] = None,
+    ext_modules: Optional[list[ModuleType]] = None,
     task: model_types.Task = model_types.Task.UNKNOWN,
     options: Optional[model_types.ModelSaveOption] = None,
 ) -> ModelVersion:
@@ -150,6 +153,7 @@ class Registry:
     Format: {channel_name: artifact_repository_name}, where:
     - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
     - artifact_repository_name: The name or URL of the repository to fetch packages from.
+resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
 target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
     {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
 python_version: Python version in which the model is run. Defaults to None.
@@ -181,6 +185,7 @@ class Registry:
 - target_methods: List of target methods to register when logging the model.
     This option is not used in MLFlow models. Defaults to None, in which case the model handler's
     default target methods will be used.
+- save_location: Location to save the model and metadata.
 - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
     values with the desired options.

@@ -229,6 +234,7 @@ class Registry:
     "conda_dependencies",
     "pip_requirements",
     "artifact_repository_map",
+    "resource_constraint",
     "target_platforms",
     "python_version",
     "signatures",
@@ -241,17 +247,18 @@ class Registry:
     model_name: str,
     version_name: Optional[str] = None,
     comment: Optional[str] = None,
-    metrics: Optional[
-    conda_dependencies: Optional[
-    pip_requirements: Optional[
-    artifact_repository_map: Optional[
-    target_platforms: Optional[
+    metrics: Optional[dict[str, Any]] = None,
+    conda_dependencies: Optional[list[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    artifact_repository_map: Optional[dict[str, str]] = None,
+    resource_constraint: Optional[dict[str, str]] = None,
+    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
     python_version: Optional[str] = None,
-    signatures: Optional[
+    signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
     sample_input_data: Optional[model_types.SupportedDataType] = None,
-    user_files: Optional[
-    code_paths: Optional[
-    ext_modules: Optional[
+    user_files: Optional[dict[str, list[str]]] = None,
+    code_paths: Optional[list[str]] = None,
+    ext_modules: Optional[list[ModuleType]] = None,
     task: model_types.Task = model_types.Task.UNKNOWN,
     options: Optional[model_types.ModelSaveOption] = None,
 ) -> ModelVersion:
@@ -286,6 +293,7 @@ class Registry:
     Format: {channel_name: artifact_repository_name}, where:
     - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
     - artifact_repository_name: The name or URL of the repository to fetch packages from.
+resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
 target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
     {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
 python_version: Python version in which the model is run. Defaults to None.
@@ -317,6 +325,7 @@ class Registry:
 - target_methods: List of target methods to register when logging the model.
     This option is not used in MLFlow models. Defaults to None, in which case the model handler's
     default target methods will be used.
+- save_location: Location to save the model and metadata.
 - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
     values with the desired options. See the example below.

@@ -369,6 +378,7 @@ class Registry:
     conda_dependencies,
     pip_requirements,
     artifact_repository_map,
+    resource_constraint,
     target_platforms,
     python_version,
     signatures,
@@ -403,6 +413,7 @@ class Registry:
     conda_dependencies=conda_dependencies,
     pip_requirements=pip_requirements,
     artifact_repository_map=artifact_repository_map,
+    resource_constraint=resource_constraint,
     target_platforms=target_platforms,
     python_version=python_version,
     signatures=signatures,
@@ -438,7 +449,7 @@ class Registry:
     project=_TELEMETRY_PROJECT,
     subproject=_MODEL_TELEMETRY_SUBPROJECT,
 )
-def models(self) ->
+def models(self) -> list[Model]:
     """Get all models in the schema where the registry is opened.

     Returns:
@@ -564,7 +575,7 @@ class Registry:
     subproject=telemetry.TelemetrySubProject.MONITORING.value,
 )
 @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
-def show_model_monitors(self) ->
+def show_model_monitors(self) -> list[snowpark.Row]:
     """Show all model monitors in the registry.

     Returns:
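To illustrate the user-facing additions visible in the registry.py diff above (the new resource_constraint argument and the save_location key in options on Registry.log_model), here is a minimal, hypothetical usage sketch; the session, model object, and names are placeholders and are not taken from this diff:

    # Hypothetical example of the 1.8.3 log_model options; `session` is an existing
    # Snowpark session and `model` is any supported model object (placeholders).
    from snowflake.ml.registry import Registry

    registry = Registry(session, database_name="ML_DB", schema_name="PUBLIC")
    model_version = registry.log_model(
        model,
        model_name="MY_MODEL",
        version_name="V2",
        resource_constraint={"architecture": "x86"},           # new in this release
        target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
        options={"save_location": "/tmp/my_model_artifacts"},  # new save_location option
    )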
snowflake/ml/utils/authentication.py
CHANGED
@@ -1,7 +1,7 @@
 import http
 import logging
 from datetime import timedelta
-from typing import
+from typing import Optional

 import requests
 from cryptography.hazmat.primitives.asymmetric import types
@@ -10,7 +10,7 @@ from requests import auth
 from snowflake.ml._internal.utils import jwt_generator

 logger = logging.getLogger(__name__)
-_JWT_TOKEN_CACHE:
+_JWT_TOKEN_CACHE: dict[str, dict[int, str]] = {}


 def get_jwt_token_generator(
snowflake/ml/utils/connection_params.py
CHANGED
@@ -1,6 +1,6 @@
 import configparser
 import os
-from typing import
+from typing import Optional, Union

 from absl import logging
 from cryptography.hazmat import backends
@@ -76,7 +76,7 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
 )


-def _connection_properties_from_env() ->
+def _connection_properties_from_env() -> dict[str, str]:
     """Returns a dict with all possible login related env variables."""
     sf_conn_prop = {
         # Mandatory fields
@@ -104,7 +104,7 @@ def _connection_properties_from_env() -> Dict[str, str]:
     return sf_conn_prop


-def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") ->
+def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
     """Loads the dictionary from snowsql config file."""
     snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
     if not os.path.exists(snowsql_config_file):
@@ -133,7 +133,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -


 @snowpark._internal.utils.private_preview(version="0.2.0")
-def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) ->
+def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
     """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.

     NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
@@ -164,7 +164,7 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
     Raises:
         Exception: if none of config file and environment variable are present.
     """
-    conn_prop:
+    conn_prop: dict[str, Union[str, bytes]] = {}
     login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
     # If login file exists, use this exclusively.
     if os.path.exists(login_file):
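Per the SnowflakeLoginOptions docstring quoted in the diff above, the returned dict feeds directly into a Snowpark session builder. A minimal sketch of that common pattern, assuming a configured connection named "my_connection" (a placeholder) in the SnowSQL config file or matching environment variables:

    # Sketch only: build a Snowpark session from the connection dict.
    from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
    from snowflake.snowpark import Session

    session = Session.builder.configs(SnowflakeLoginOptions("my_connection")).create()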
snowflake/ml/utils/sparse.py
CHANGED
@@ -1,6 +1,6 @@
 import collections
 import json
-from typing import
+from typing import Optional

 import pandas as pd
 from pandas import arrays as pandas_arrays
@@ -9,7 +9,7 @@ from pandas.core.arrays import sparse as pandas_sparse
 from snowflake.snowpark import DataFrame


-def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols:
+def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: list[str]) -> Optional[pd.DataFrame]:
     """Convert the pandas df into pandas df with multiple SparseArray columns."""
     num_rows = pandas_df.shape[0]
     if num_rows == 0:
@@ -52,8 +52,9 @@ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) ->
     return pandas_df


-def to_pandas_with_sparse(df: DataFrame, sparse_cols:
-    """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
+def to_pandas_with_sparse(df: DataFrame, sparse_cols: list[str]) -> pd.DataFrame:
+    """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
+    columns.

     For example, for below input:
     ----------------------------------------------
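For reference, a minimal, hypothetical call against the updated to_pandas_with_sparse signature shown above; snowpark_df and the "FEATURES" column (JSON-encoded sparse vectors) are placeholders:

    # Sketch: convert a Snowpark DataFrame with a JSON-encoded sparse column
    # into a pandas DataFrame whose sparse column becomes SparseArray columns.
    from snowflake.ml.utils.sparse import to_pandas_with_sparse

    sparse_pdf = to_pandas_with_sparse(snowpark_df, sparse_cols=["FEATURES"])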
snowflake/ml/utils/sql_client.py
CHANGED
@@ -1,5 +1,4 @@
 from enum import Enum
-from typing import Dict


 class CreationOption(Enum):
@@ -13,7 +12,7 @@ class CreationMode:
     self.if_not_exists = if_not_exists
     self.or_replace = or_replace

-def get_ddl_phrases(self) ->
+def get_ddl_phrases(self) -> dict[CreationOption, str]:
     if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
     or_replace_sql = " OR REPLACE" if self.or_replace else ""
     return {
snowflake/ml/version.py
CHANGED
@@ -1 +1,2 @@
-VERSION="1.8.1"
+# This is parsed by regex in conda recipe meta file. Make sure not to break it.
+VERSION = "1.8.3"
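A quick sanity check that an installed environment matches the release this diff targets, using the constant shown above:

    # Confirm the installed package version.
    from snowflake.ml.version import VERSION

    assert VERSION == "1.8.3"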