snowflake-ml-python 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +144 -10
- snowflake/ml/model/_client/ops/model_ops.py +25 -6
- snowflake/ml/model/_client/ops/service_ops.py +33 -28
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/model.py +14 -0
- snowflake/ml/model/_client/sql/service.py +6 -18
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -6
- snowflake/ml/model/_packager/model_handlers/custom.py +2 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +3 -6
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -65
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +13 -25
- snowflake/ml/model/_signatures/pandas_handler.py +16 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +8 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/{monitor_sql_client.py → model_monitor_sql_client.py} +1 -1
- snowflake/ml/monitoring/{_client → _manager}/model_monitor_manager.py +9 -8
- snowflake/ml/monitoring/{_client/model_monitor.py → model_monitor.py} +3 -3
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +15 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/METADATA +81 -9
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/RECORD +150 -150
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/WHEEL +1 -1
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,9 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
32
32
|
|
33
33
|
|
34
34
|
class ModelOperator:
|
35
|
+
INFERENCE_SERVICE_NAME_COL_NAME = "service_name"
|
36
|
+
INFERENCE_SERVICE_ENDPOINT_COL_NAME = "endpoints"
|
37
|
+
|
35
38
|
def __init__(
|
36
39
|
self,
|
37
40
|
session: session.Session,
|
@@ -522,7 +525,7 @@ class ModelOperator:
|
|
522
525
|
model_name: sql_identifier.SqlIdentifier,
|
523
526
|
version_name: sql_identifier.SqlIdentifier,
|
524
527
|
statement_params: Optional[Dict[str, Any]] = None,
|
525
|
-
) -> List[str]:
|
528
|
+
) -> Dict[str, List[str]]:
|
526
529
|
res = self._model_client.show_versions(
|
527
530
|
database_name=database_name,
|
528
531
|
schema_name=schema_name,
|
@@ -530,8 +533,8 @@ class ModelOperator:
|
|
530
533
|
version_name=version_name,
|
531
534
|
statement_params=statement_params,
|
532
535
|
)
|
533
|
-
|
534
|
-
if
|
536
|
+
service_col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
|
537
|
+
if service_col_name not in res[0]:
|
535
538
|
# User need to opt into BCR 2024_08
|
536
539
|
raise exceptions.SnowflakeMLException(
|
537
540
|
error_code=error_codes.OPT_IN_REQUIRED,
|
@@ -540,9 +543,24 @@ class ModelOperator:
|
|
540
543
|
"https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
|
541
544
|
),
|
542
545
|
)
|
543
|
-
|
546
|
+
|
547
|
+
json_array = json.loads(res[0][service_col_name])
|
544
548
|
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
545
|
-
|
549
|
+
services = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
550
|
+
endpoint_col_name = self._model_client.MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME
|
551
|
+
|
552
|
+
services_col, endpoints_col = [], []
|
553
|
+
for service in services:
|
554
|
+
res = self._model_client.show_endpoints(service_name=service)
|
555
|
+
endpoints = [endpoint[endpoint_col_name] for endpoint in res]
|
556
|
+
for endpoint in endpoints:
|
557
|
+
services_col.append(service)
|
558
|
+
endpoints_col.append(endpoint)
|
559
|
+
|
560
|
+
return {
|
561
|
+
self.INFERENCE_SERVICE_NAME_COL_NAME: services_col,
|
562
|
+
self.INFERENCE_SERVICE_ENDPOINT_COL_NAME: endpoints_col,
|
563
|
+
}
|
546
564
|
|
547
565
|
def delete_service(
|
548
566
|
self,
|
@@ -566,7 +584,8 @@ class ModelOperator:
|
|
566
584
|
db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
|
567
585
|
)
|
568
586
|
|
569
|
-
|
587
|
+
service_col_name = self.INFERENCE_SERVICE_NAME_COL_NAME
|
588
|
+
for service in services[service_col_name]:
|
570
589
|
if service == fully_qualified_service_name:
|
571
590
|
self._service_client.drop_service(
|
572
591
|
database_name=db,
|
@@ -100,11 +100,13 @@ class ServiceOperator:
|
|
100
100
|
image_repo_name: sql_identifier.SqlIdentifier,
|
101
101
|
ingress_enabled: bool,
|
102
102
|
max_instances: int,
|
103
|
+
cpu_requests: Optional[str],
|
104
|
+
memory_requests: Optional[str],
|
103
105
|
gpu_requests: Optional[str],
|
104
106
|
num_workers: Optional[int],
|
105
107
|
max_batch_rows: Optional[int],
|
106
108
|
force_rebuild: bool,
|
107
|
-
|
109
|
+
build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
|
108
110
|
statement_params: Optional[Dict[str, Any]] = None,
|
109
111
|
) -> str:
|
110
112
|
# create a temp stage
|
@@ -119,6 +121,14 @@ class ServiceOperator:
|
|
119
121
|
)
|
120
122
|
stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
|
121
123
|
|
124
|
+
# TODO(hayu): Remove the version check after Snowflake 8.40.0 release
|
125
|
+
if (
|
126
|
+
snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
|
127
|
+
< version.parse("8.40.0")
|
128
|
+
and build_external_access_integrations is None
|
129
|
+
):
|
130
|
+
raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
|
131
|
+
|
122
132
|
self._model_deployment_spec.save(
|
123
133
|
database_name=database_name or self._database_name,
|
124
134
|
schema_name=schema_name or self._schema_name,
|
@@ -134,11 +144,13 @@ class ServiceOperator:
|
|
134
144
|
image_repo_name=image_repo_name,
|
135
145
|
ingress_enabled=ingress_enabled,
|
136
146
|
max_instances=max_instances,
|
147
|
+
cpu=cpu_requests,
|
148
|
+
memory=memory_requests,
|
137
149
|
gpu=gpu_requests,
|
138
150
|
num_workers=num_workers,
|
139
151
|
max_batch_rows=max_batch_rows,
|
140
152
|
force_rebuild=force_rebuild,
|
141
|
-
|
153
|
+
external_access_integrations=build_external_access_integrations,
|
142
154
|
)
|
143
155
|
file_utils.upload_directory_to_stage(
|
144
156
|
self._session,
|
@@ -163,32 +175,25 @@ class ServiceOperator:
|
|
163
175
|
statement_params=statement_params,
|
164
176
|
)
|
165
177
|
|
166
|
-
#
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
model_build_service_name
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
log_thread = self._start_service_log_streaming(
|
186
|
-
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
187
|
-
)
|
188
|
-
log_thread.join()
|
189
|
-
else:
|
190
|
-
while not async_job.is_done():
|
191
|
-
time.sleep(5)
|
178
|
+
# stream service logs in a thread
|
179
|
+
model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
|
180
|
+
model_build_service = ServiceLogInfo(
|
181
|
+
database_name=service_database_name,
|
182
|
+
schema_name=service_schema_name,
|
183
|
+
service_name=model_build_service_name,
|
184
|
+
container_name="model-build",
|
185
|
+
)
|
186
|
+
model_inference_service = ServiceLogInfo(
|
187
|
+
database_name=service_database_name,
|
188
|
+
schema_name=service_schema_name,
|
189
|
+
service_name=service_name,
|
190
|
+
container_name="model-inference",
|
191
|
+
)
|
192
|
+
services = [model_build_service, model_inference_service]
|
193
|
+
log_thread = self._start_service_log_streaming(
|
194
|
+
async_job, services, model_inference_service_exists, force_rebuild, statement_params
|
195
|
+
)
|
196
|
+
log_thread.join()
|
192
197
|
|
193
198
|
res = cast(str, cast(List[row.Row], async_job.result())[0][0])
|
194
199
|
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import Optional
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
import yaml
|
5
5
|
|
@@ -36,11 +36,13 @@ class ModelDeploymentSpec:
|
|
36
36
|
image_repo_name: sql_identifier.SqlIdentifier,
|
37
37
|
ingress_enabled: bool,
|
38
38
|
max_instances: int,
|
39
|
+
cpu: Optional[str],
|
40
|
+
memory: Optional[str],
|
39
41
|
gpu: Optional[str],
|
40
42
|
num_workers: Optional[int],
|
41
43
|
max_batch_rows: Optional[int],
|
42
44
|
force_rebuild: bool,
|
43
|
-
|
45
|
+
external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
|
44
46
|
) -> None:
|
45
47
|
# create the deployment spec
|
46
48
|
# models spec
|
@@ -55,12 +57,15 @@ class ModelDeploymentSpec:
|
|
55
57
|
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
56
58
|
saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
|
57
59
|
)
|
58
|
-
image_build_dict
|
59
|
-
compute_pool
|
60
|
-
image_repo
|
61
|
-
force_rebuild
|
62
|
-
|
63
|
-
|
60
|
+
image_build_dict: model_deployment_spec_schema.ImageBuildDict = {
|
61
|
+
"compute_pool": image_build_compute_pool_name.identifier(),
|
62
|
+
"image_repo": fq_image_repo_name,
|
63
|
+
"force_rebuild": force_rebuild,
|
64
|
+
}
|
65
|
+
if external_access_integrations is not None:
|
66
|
+
image_build_dict["external_access_integrations"] = [
|
67
|
+
eai.identifier() for eai in external_access_integrations
|
68
|
+
]
|
64
69
|
|
65
70
|
# service spec
|
66
71
|
saved_service_database = service_database_name or database_name
|
@@ -74,6 +79,12 @@ class ModelDeploymentSpec:
|
|
74
79
|
ingress_enabled=ingress_enabled,
|
75
80
|
max_instances=max_instances,
|
76
81
|
)
|
82
|
+
if cpu:
|
83
|
+
service_dict["cpu"] = cpu
|
84
|
+
|
85
|
+
if memory:
|
86
|
+
service_dict["memory"] = memory
|
87
|
+
|
77
88
|
if gpu:
|
78
89
|
service_dict["gpu"] = gpu
|
79
90
|
|
@@ -12,7 +12,7 @@ class ImageBuildDict(TypedDict):
|
|
12
12
|
compute_pool: Required[str]
|
13
13
|
image_repo: Required[str]
|
14
14
|
force_rebuild: Required[bool]
|
15
|
-
external_access_integrations:
|
15
|
+
external_access_integrations: NotRequired[List[str]]
|
16
16
|
|
17
17
|
|
18
18
|
class ServiceDict(TypedDict):
|
@@ -20,6 +20,8 @@ class ServiceDict(TypedDict):
|
|
20
20
|
compute_pool: Required[str]
|
21
21
|
ingress_enabled: Required[bool]
|
22
22
|
max_instances: Required[int]
|
23
|
+
cpu: NotRequired[str]
|
24
|
+
memory: NotRequired[str]
|
23
25
|
gpu: NotRequired[str]
|
24
26
|
num_workers: NotRequired[int]
|
25
27
|
max_batch_rows: NotRequired[int]
|
@@ -17,6 +17,8 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
17
17
|
MODEL_VERSION_ALIASES_COL_NAME = "aliases"
|
18
18
|
MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
|
19
19
|
|
20
|
+
MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
|
21
|
+
|
20
22
|
def show_models(
|
21
23
|
self,
|
22
24
|
*,
|
@@ -83,6 +85,18 @@ class ModelSQLClient(_base._BaseSQLClient):
|
|
83
85
|
|
84
86
|
return res.validate()
|
85
87
|
|
88
|
+
def show_endpoints(
|
89
|
+
self,
|
90
|
+
*,
|
91
|
+
service_name: str,
|
92
|
+
) -> List[row.Row]:
|
93
|
+
res = query_result_checker.SqlResultValidator(
|
94
|
+
self._session,
|
95
|
+
(f"SHOW ENDPOINTS IN SERVICE {service_name}"),
|
96
|
+
).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
97
|
+
|
98
|
+
return res.validate()
|
99
|
+
|
86
100
|
def set_comment(
|
87
101
|
self,
|
88
102
|
*,
|
@@ -3,13 +3,10 @@ import json
|
|
3
3
|
import textwrap
|
4
4
|
from typing import Any, Dict, List, Optional, Tuple
|
5
5
|
|
6
|
-
from packaging import version
|
7
|
-
|
8
6
|
from snowflake import snowpark
|
9
7
|
from snowflake.ml._internal.utils import (
|
10
8
|
identifier,
|
11
9
|
query_result_checker,
|
12
|
-
snowflake_env,
|
13
10
|
sql_identifier,
|
14
11
|
)
|
15
12
|
from snowflake.ml.model._client.sql import _base
|
@@ -120,21 +117,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
120
117
|
args_sql_list.append(input_arg_value)
|
121
118
|
args_sql = ", ".join(args_sql_list)
|
122
119
|
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
|
130
|
-
|
131
|
-
else:
|
132
|
-
function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
|
133
|
-
fully_qualified_function_name = identifier.get_schema_level_object_identifier(
|
134
|
-
actual_database_name.identifier(),
|
135
|
-
actual_schema_name.identifier(),
|
136
|
-
function_name,
|
137
|
-
)
|
120
|
+
function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
|
121
|
+
fully_qualified_function_name = identifier.get_schema_level_object_identifier(
|
122
|
+
actual_database_name.identifier(),
|
123
|
+
actual_schema_name.identifier(),
|
124
|
+
function_name,
|
125
|
+
)
|
138
126
|
|
139
127
|
sql = textwrap.dedent(
|
140
128
|
f"""{with_sql}
|
@@ -86,6 +86,7 @@ class ModelComposer:
|
|
86
86
|
metadata: Optional[Dict[str, str]] = None,
|
87
87
|
conda_dependencies: Optional[List[str]] = None,
|
88
88
|
pip_requirements: Optional[List[str]] = None,
|
89
|
+
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
89
90
|
python_version: Optional[str] = None,
|
90
91
|
ext_modules: Optional[List[ModuleType]] = None,
|
91
92
|
code_paths: Optional[List[str]] = None,
|
@@ -131,6 +132,7 @@ class ModelComposer:
|
|
131
132
|
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
132
133
|
options=options,
|
133
134
|
data_sources=self._get_data_sources(model, sample_input_data),
|
135
|
+
target_platforms=target_platforms,
|
134
136
|
)
|
135
137
|
|
136
138
|
file_utils.upload_directory_to_stage(
|
@@ -44,6 +44,7 @@ class ModelManifest:
|
|
44
44
|
model_rel_path: pathlib.PurePosixPath,
|
45
45
|
options: Optional[type_hints.ModelSaveOption] = None,
|
46
46
|
data_sources: Optional[List[data_source.DataSource]] = None,
|
47
|
+
target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
|
47
48
|
) -> None:
|
48
49
|
if options is None:
|
49
50
|
options = {}
|
@@ -132,6 +133,9 @@ class ModelManifest:
|
|
132
133
|
if lineage_sources:
|
133
134
|
manifest_dict["lineage_sources"] = lineage_sources
|
134
135
|
|
136
|
+
if target_platforms:
|
137
|
+
manifest_dict["target_platforms"] = [platform.value for platform in target_platforms]
|
138
|
+
|
135
139
|
with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
|
136
140
|
# Anchors are not supported in the server, avoid that.
|
137
141
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
@@ -27,7 +27,7 @@ def get_model_method_options_from_options(
|
|
27
27
|
options: type_hints.ModelSaveOption, target_method: str
|
28
28
|
) -> ModelMethodOptions:
|
29
29
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
30
|
-
if
|
30
|
+
if target_method == "explain":
|
31
31
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
32
32
|
method_option = options.get("method_options", {}).get(target_method, {})
|
33
33
|
global_function_type = options.get("function_type", default_function_type)
|
@@ -191,7 +191,11 @@ def convert_explanations_to_2D_df(
|
|
191
191
|
# convert to object or numpy creates strings of fixed length
|
192
192
|
return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
|
193
193
|
|
194
|
-
|
194
|
+
# convert to dict only for multiclass
|
195
|
+
if len(classes_list) > 2:
|
196
|
+
exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
|
197
|
+
else: # assumes index 1 is positive class always
|
198
|
+
exp_2d = np.apply_along_axis(lambda arr: arr[1], -1, explanations)
|
195
199
|
|
196
200
|
return pd.DataFrame(exp_2d)
|
197
201
|
|
@@ -9,17 +9,14 @@ from typing_extensions import TypeGuard, Unpack
|
|
9
9
|
from snowflake.ml._internal import type_utils
|
10
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
11
|
from snowflake.ml.model._packager.model_env import model_env
|
12
|
-
from snowflake.ml.model._packager.model_handlers import
|
13
|
-
_base,
|
14
|
-
_utils as handlers_utils,
|
15
|
-
model_objective_utils,
|
16
|
-
)
|
12
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
17
13
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
18
14
|
from snowflake.ml.model._packager.model_meta import (
|
19
15
|
model_blob_meta,
|
20
16
|
model_meta as model_meta_api,
|
21
17
|
model_meta_schema,
|
22
18
|
)
|
19
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
23
20
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
24
21
|
|
25
22
|
if TYPE_CHECKING:
|
@@ -97,7 +94,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
97
94
|
sample_input_data=sample_input_data,
|
98
95
|
get_prediction_fn=get_prediction,
|
99
96
|
)
|
100
|
-
model_task_and_output =
|
97
|
+
model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
|
101
98
|
model_meta.task = model_task_and_output.task
|
102
99
|
if enable_explainability:
|
103
100
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
@@ -99,6 +99,8 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
101
|
assert handler is not None
|
102
|
+
if handler is None:
|
103
|
+
raise TypeError("Your input type to custom model is not currently supported")
|
102
104
|
sub_model = handler.cast_model(model_ref.model)
|
103
105
|
handler.save_model(
|
104
106
|
name=sub_name,
|
@@ -256,12 +256,20 @@ class HuggingFacePipelineHandler(
|
|
256
256
|
@staticmethod
|
257
257
|
def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
|
258
258
|
device_config: Dict[str, Any] = {}
|
259
|
+
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
260
|
+
gpu_nums = 0
|
261
|
+
if cuda_visible_devices is not None:
|
262
|
+
gpu_nums = len(cuda_visible_devices.split(","))
|
259
263
|
if (
|
260
264
|
kwargs.get("use_gpu", False)
|
261
265
|
and kwargs.get("device_map", None) is None
|
262
266
|
and kwargs.get("device", None) is None
|
263
267
|
):
|
264
|
-
|
268
|
+
if gpu_nums == 0 or gpu_nums > 1:
|
269
|
+
# Use accelerator if there are multiple GPUs or no GPU
|
270
|
+
device_config["device_map"] = "auto"
|
271
|
+
else:
|
272
|
+
device_config["device"] = "cuda"
|
265
273
|
elif kwargs.get("device_map", None) is not None:
|
266
274
|
device_config["device_map"] = kwargs["device_map"]
|
267
275
|
elif kwargs.get("device", None) is not None:
|
@@ -310,6 +318,7 @@ class HuggingFacePipelineHandler(
|
|
310
318
|
m = transformers.pipeline(
|
311
319
|
model_blob_options["task"],
|
312
320
|
model=model_blob_file_or_dir_path,
|
321
|
+
trust_remote_code=True,
|
313
322
|
**device_config,
|
314
323
|
)
|
315
324
|
|
@@ -20,17 +20,14 @@ from typing_extensions import TypeGuard, Unpack
|
|
20
20
|
from snowflake.ml._internal import type_utils
|
21
21
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
22
22
|
from snowflake.ml.model._packager.model_env import model_env
|
23
|
-
from snowflake.ml.model._packager.model_handlers import
|
24
|
-
_base,
|
25
|
-
_utils as handlers_utils,
|
26
|
-
model_objective_utils,
|
27
|
-
)
|
23
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
28
24
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
29
25
|
from snowflake.ml.model._packager.model_meta import (
|
30
26
|
model_blob_meta,
|
31
27
|
model_meta as model_meta_api,
|
32
28
|
model_meta_schema,
|
33
29
|
)
|
30
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
34
31
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
35
32
|
|
36
33
|
if TYPE_CHECKING:
|
@@ -113,7 +110,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
113
110
|
sample_input_data=sample_input_data,
|
114
111
|
get_prediction_fn=get_prediction,
|
115
112
|
)
|
116
|
-
model_task_and_output =
|
113
|
+
model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
|
117
114
|
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
118
115
|
if enable_explainability:
|
119
116
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import inspect
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
@@ -155,8 +156,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
155
156
|
model_blob_filename = model_blob_metadata.path
|
156
157
|
model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
|
157
158
|
|
159
|
+
additional_kwargs = {}
|
160
|
+
if "trust_remote_code" in inspect.signature(sentence_transformers.SentenceTransformer).parameters:
|
161
|
+
additional_kwargs["trust_remote_code"] = True
|
162
|
+
|
158
163
|
model = sentence_transformers.SentenceTransformer(
|
159
|
-
model_blob_file_or_dir_path,
|
164
|
+
model_blob_file_or_dir_path,
|
165
|
+
device=cls._get_device_config(**kwargs),
|
166
|
+
**additional_kwargs,
|
160
167
|
)
|
161
168
|
return model
|
162
169
|
|
@@ -10,17 +10,14 @@ from typing_extensions import TypeGuard, Unpack
|
|
10
10
|
from snowflake.ml._internal import type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
14
|
-
_base,
|
15
|
-
_utils as handlers_utils,
|
16
|
-
model_objective_utils,
|
17
|
-
)
|
13
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
18
14
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
19
15
|
from snowflake.ml.model._packager.model_meta import (
|
20
16
|
model_blob_meta,
|
21
17
|
model_meta as model_meta_api,
|
22
18
|
model_meta_schema,
|
23
19
|
)
|
20
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
24
21
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
25
22
|
|
26
23
|
if TYPE_CHECKING:
|
@@ -137,7 +134,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
137
134
|
sample_input_data, model_meta, explain_target_method
|
138
135
|
)
|
139
136
|
|
140
|
-
model_task_and_output_type =
|
137
|
+
model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
|
141
138
|
model_meta.task = model_task_and_output_type.task
|
142
139
|
|
143
140
|
# if users did not ask then we enable if we have background data
|
@@ -5,24 +5,20 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, fin
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
-
from packaging import version
|
9
8
|
from typing_extensions import TypeGuard, Unpack
|
10
9
|
|
11
10
|
from snowflake.ml._internal import type_utils
|
12
11
|
from snowflake.ml._internal.exceptions import exceptions
|
13
12
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
14
13
|
from snowflake.ml.model._packager.model_env import model_env
|
15
|
-
from snowflake.ml.model._packager.model_handlers import
|
16
|
-
_base,
|
17
|
-
_utils as handlers_utils,
|
18
|
-
model_objective_utils,
|
19
|
-
)
|
14
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
20
15
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
21
16
|
from snowflake.ml.model._packager.model_meta import (
|
22
17
|
model_blob_meta,
|
23
18
|
model_meta as model_meta_api,
|
24
19
|
model_meta_schema,
|
25
20
|
)
|
21
|
+
from snowflake.ml.model._packager.model_task import model_task_utils
|
26
22
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
27
23
|
|
28
24
|
if TYPE_CHECKING:
|
@@ -72,41 +68,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
72
68
|
return cast("BaseEstimator", model)
|
73
69
|
|
74
70
|
@classmethod
|
75
|
-
def
|
76
|
-
from importlib import metadata as importlib_metadata
|
77
|
-
|
78
|
-
from packaging import version
|
79
|
-
|
80
|
-
local_version = None
|
81
|
-
|
82
|
-
try:
|
83
|
-
local_dist = importlib_metadata.distribution(pkg_name)
|
84
|
-
local_version = version.parse(local_dist.version)
|
85
|
-
except importlib_metadata.PackageNotFoundError:
|
86
|
-
pass
|
87
|
-
|
88
|
-
return local_version
|
89
|
-
|
90
|
-
@classmethod
|
91
|
-
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
|
92
|
-
|
93
|
-
local_xgb_version = cls._get_local_version_package("xgboost")
|
94
|
-
|
95
|
-
if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
|
96
|
-
if enable_explainability:
|
97
|
-
warnings.warn(
|
98
|
-
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
99
|
-
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
100
|
-
category=UserWarning,
|
101
|
-
stacklevel=1,
|
102
|
-
)
|
103
|
-
return False
|
104
|
-
return True
|
105
|
-
|
106
|
-
@classmethod
|
107
|
-
def _get_supported_object_for_explainability(
|
108
|
-
cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
|
109
|
-
) -> Any:
|
71
|
+
def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
|
110
72
|
from snowflake.ml.modeling import pipeline as snowml_pipeline
|
111
73
|
|
112
74
|
# handle pipeline objects separately
|
@@ -118,8 +80,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
118
80
|
if hasattr(estimator, method_name):
|
119
81
|
try:
|
120
82
|
result = getattr(estimator, method_name)()
|
121
|
-
if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
|
122
|
-
return None
|
123
83
|
return result
|
124
84
|
except exceptions.SnowflakeMLException:
|
125
85
|
pass # Do nothing and continue to the next method
|
@@ -168,7 +128,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
168
128
|
model_meta.signatures = temp_model_signature_dict
|
169
129
|
|
170
130
|
if enable_explainability or enable_explainability is None:
|
171
|
-
python_base_obj = cls._get_supported_object_for_explainability(model
|
131
|
+
python_base_obj = cls._get_supported_object_for_explainability(model)
|
172
132
|
if python_base_obj is None:
|
173
133
|
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
174
134
|
raise ValueError(
|
@@ -177,7 +137,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
177
137
|
# set None to False so we don't include shap in the environment
|
178
138
|
enable_explainability = False
|
179
139
|
else:
|
180
|
-
model_task_and_output_type =
|
140
|
+
model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
|
181
141
|
model_meta.task = model_task_and_output_type.task
|
182
142
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
183
143
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -213,28 +173,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
213
173
|
model_dependencies = model._get_dependencies()
|
214
174
|
for dep in model_dependencies:
|
215
175
|
pkg_name = dep.split("==")[0]
|
216
|
-
|
217
|
-
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
218
|
-
continue
|
219
|
-
|
220
|
-
local_xgb_version = cls._get_local_version_package("xgboost")
|
221
|
-
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
222
|
-
model_meta.env.include_if_absent(
|
223
|
-
[
|
224
|
-
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
225
|
-
],
|
226
|
-
check_local_version=False,
|
227
|
-
)
|
228
|
-
else:
|
229
|
-
model_meta.env.include_if_absent(
|
230
|
-
[
|
231
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
232
|
-
],
|
233
|
-
check_local_version=True,
|
234
|
-
)
|
176
|
+
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
235
177
|
|
236
178
|
if enable_explainability:
|
237
|
-
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
179
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
|
238
180
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
239
181
|
model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
|
240
182
|
|