snowflake-ml-python 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. snowflake/ml/_internal/telemetry.py +4 -2
  2. snowflake/ml/_internal/utils/import_utils.py +31 -0
  3. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  4. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  5. snowflake/ml/data/data_connector.py +1 -1
  6. snowflake/ml/data/torch_utils.py +33 -14
  7. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  8. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  9. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  10. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  11. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  12. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  13. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  14. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  15. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  16. snowflake/ml/feature_store/feature_store.py +1 -2
  17. snowflake/ml/feature_store/feature_view.py +5 -1
  18. snowflake/ml/model/_client/model/model_version_impl.py +144 -10
  19. snowflake/ml/model/_client/ops/model_ops.py +25 -6
  20. snowflake/ml/model/_client/ops/service_ops.py +33 -28
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  23. snowflake/ml/model/_client/sql/model.py +14 -0
  24. snowflake/ml/model/_client/sql/service.py +6 -18
  25. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  29. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -1
  30. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -6
  31. snowflake/ml/model/_packager/model_handlers/custom.py +2 -0
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  33. snowflake/ml/model/_packager/model_handlers/lightgbm.py +3 -6
  34. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  35. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -6
  36. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -65
  37. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  38. snowflake/ml/model/_packager/model_packager.py +0 -11
  39. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +13 -25
  40. snowflake/ml/model/_signatures/pandas_handler.py +16 -0
  41. snowflake/ml/model/custom_model.py +47 -7
  42. snowflake/ml/model/model_signature.py +2 -0
  43. snowflake/ml/model/type_hints.py +8 -0
  44. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  45. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  46. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  49. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  51. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  53. snowflake/ml/modeling/cluster/k_means.py +14 -19
  54. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  55. snowflake/ml/modeling/cluster/optics.py +6 -6
  56. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  57. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  61. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  62. snowflake/ml/modeling/covariance/oas.py +1 -1
  63. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  64. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  65. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  66. snowflake/ml/modeling/decomposition/pca.py +28 -15
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  80. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  81. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  82. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  83. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  84. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  85. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  86. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  87. snowflake/ml/modeling/linear_model/lars.py +0 -10
  88. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  89. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  90. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  91. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  92. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  93. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  94. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  95. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  96. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  97. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  98. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  99. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  100. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  101. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  102. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  103. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  104. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  105. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  106. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  107. snowflake/ml/modeling/manifold/isomap.py +1 -1
  108. snowflake/ml/modeling/manifold/mds.py +3 -3
  109. snowflake/ml/modeling/manifold/tsne.py +10 -4
  110. snowflake/ml/modeling/metrics/classification.py +12 -16
  111. snowflake/ml/modeling/metrics/ranking.py +3 -3
  112. snowflake/ml/modeling/metrics/regression.py +3 -3
  113. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  114. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  115. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  116. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  117. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  118. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  119. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  120. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  121. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  122. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  123. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  124. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  125. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  126. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  127. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  128. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  129. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  130. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  131. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  132. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  133. snowflake/ml/modeling/svm/svc.py +9 -5
  134. snowflake/ml/modeling/svm/svr.py +3 -1
  135. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  136. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  137. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  138. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  139. snowflake/ml/monitoring/_client/{monitor_sql_client.py → model_monitor_sql_client.py} +1 -1
  140. snowflake/ml/monitoring/{_client → _manager}/model_monitor_manager.py +9 -8
  141. snowflake/ml/monitoring/{_client/model_monitor.py → model_monitor.py} +3 -3
  142. snowflake/ml/registry/_manager/model_manager.py +15 -1
  143. snowflake/ml/registry/registry.py +15 -8
  144. snowflake/ml/version.py +1 -1
  145. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/METADATA +81 -9
  146. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/RECORD +150 -150
  147. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/WHEEL +1 -1
  148. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  149. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/LICENSE.txt +0 -0
  150. {snowflake_ml_python-1.6.3.dist-info → snowflake_ml_python-1.7.0.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,9 @@ from snowflake.snowpark._internal import utils as snowpark_utils
32
32
 
33
33
 
34
34
  class ModelOperator:
35
+ INFERENCE_SERVICE_NAME_COL_NAME = "service_name"
36
+ INFERENCE_SERVICE_ENDPOINT_COL_NAME = "endpoints"
37
+
35
38
  def __init__(
36
39
  self,
37
40
  session: session.Session,
@@ -522,7 +525,7 @@ class ModelOperator:
522
525
  model_name: sql_identifier.SqlIdentifier,
523
526
  version_name: sql_identifier.SqlIdentifier,
524
527
  statement_params: Optional[Dict[str, Any]] = None,
525
- ) -> List[str]:
528
+ ) -> Dict[str, List[str]]:
526
529
  res = self._model_client.show_versions(
527
530
  database_name=database_name,
528
531
  schema_name=schema_name,
@@ -530,8 +533,8 @@ class ModelOperator:
530
533
  version_name=version_name,
531
534
  statement_params=statement_params,
532
535
  )
533
- col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
534
- if col_name not in res[0]:
536
+ service_col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
537
+ if service_col_name not in res[0]:
535
538
  # User need to opt into BCR 2024_08
536
539
  raise exceptions.SnowflakeMLException(
537
540
  error_code=error_codes.OPT_IN_REQUIRED,
@@ -540,9 +543,24 @@ class ModelOperator:
540
543
  "https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
541
544
  ),
542
545
  )
543
- json_array = json.loads(res[0][col_name])
546
+
547
+ json_array = json.loads(res[0][service_col_name])
544
548
  # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
545
- return [str(service) for service in json_array if "MODEL_BUILD_" not in service]
549
+ services = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
550
+ endpoint_col_name = self._model_client.MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME
551
+
552
+ services_col, endpoints_col = [], []
553
+ for service in services:
554
+ res = self._model_client.show_endpoints(service_name=service)
555
+ endpoints = [endpoint[endpoint_col_name] for endpoint in res]
556
+ for endpoint in endpoints:
557
+ services_col.append(service)
558
+ endpoints_col.append(endpoint)
559
+
560
+ return {
561
+ self.INFERENCE_SERVICE_NAME_COL_NAME: services_col,
562
+ self.INFERENCE_SERVICE_ENDPOINT_COL_NAME: endpoints_col,
563
+ }
546
564
 
547
565
  def delete_service(
548
566
  self,
@@ -566,7 +584,8 @@ class ModelOperator:
566
584
  db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
567
585
  )
568
586
 
569
- for service in services:
587
+ service_col_name = self.INFERENCE_SERVICE_NAME_COL_NAME
588
+ for service in services[service_col_name]:
570
589
  if service == fully_qualified_service_name:
571
590
  self._service_client.drop_service(
572
591
  database_name=db,
@@ -100,11 +100,13 @@ class ServiceOperator:
100
100
  image_repo_name: sql_identifier.SqlIdentifier,
101
101
  ingress_enabled: bool,
102
102
  max_instances: int,
103
+ cpu_requests: Optional[str],
104
+ memory_requests: Optional[str],
103
105
  gpu_requests: Optional[str],
104
106
  num_workers: Optional[int],
105
107
  max_batch_rows: Optional[int],
106
108
  force_rebuild: bool,
107
- build_external_access_integration: sql_identifier.SqlIdentifier,
109
+ build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
108
110
  statement_params: Optional[Dict[str, Any]] = None,
109
111
  ) -> str:
110
112
  # create a temp stage
@@ -119,6 +121,14 @@ class ServiceOperator:
119
121
  )
120
122
  stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
121
123
 
124
+ # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
125
+ if (
126
+ snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
127
+ < version.parse("8.40.0")
128
+ and build_external_access_integrations is None
129
+ ):
130
+ raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
131
+
122
132
  self._model_deployment_spec.save(
123
133
  database_name=database_name or self._database_name,
124
134
  schema_name=schema_name or self._schema_name,
@@ -134,11 +144,13 @@ class ServiceOperator:
134
144
  image_repo_name=image_repo_name,
135
145
  ingress_enabled=ingress_enabled,
136
146
  max_instances=max_instances,
147
+ cpu=cpu_requests,
148
+ memory=memory_requests,
137
149
  gpu=gpu_requests,
138
150
  num_workers=num_workers,
139
151
  max_batch_rows=max_batch_rows,
140
152
  force_rebuild=force_rebuild,
141
- external_access_integration=build_external_access_integration,
153
+ external_access_integrations=build_external_access_integrations,
142
154
  )
143
155
  file_utils.upload_directory_to_stage(
144
156
  self._session,
@@ -163,32 +175,25 @@ class ServiceOperator:
163
175
  statement_params=statement_params,
164
176
  )
165
177
 
166
- # TODO(hayu): Remove the version check after Snowflake 8.37.0 release
167
- if snowflake_env.get_current_snowflake_version(
168
- self._session, statement_params=statement_params
169
- ) >= version.parse("8.37.0"):
170
- # stream service logs in a thread
171
- model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
172
- model_build_service = ServiceLogInfo(
173
- database_name=service_database_name,
174
- schema_name=service_schema_name,
175
- service_name=model_build_service_name,
176
- container_name="model-build",
177
- )
178
- model_inference_service = ServiceLogInfo(
179
- database_name=service_database_name,
180
- schema_name=service_schema_name,
181
- service_name=service_name,
182
- container_name="model-inference",
183
- )
184
- services = [model_build_service, model_inference_service]
185
- log_thread = self._start_service_log_streaming(
186
- async_job, services, model_inference_service_exists, force_rebuild, statement_params
187
- )
188
- log_thread.join()
189
- else:
190
- while not async_job.is_done():
191
- time.sleep(5)
178
+ # stream service logs in a thread
179
+ model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
180
+ model_build_service = ServiceLogInfo(
181
+ database_name=service_database_name,
182
+ schema_name=service_schema_name,
183
+ service_name=model_build_service_name,
184
+ container_name="model-build",
185
+ )
186
+ model_inference_service = ServiceLogInfo(
187
+ database_name=service_database_name,
188
+ schema_name=service_schema_name,
189
+ service_name=service_name,
190
+ container_name="model-inference",
191
+ )
192
+ services = [model_build_service, model_inference_service]
193
+ log_thread = self._start_service_log_streaming(
194
+ async_job, services, model_inference_service_exists, force_rebuild, statement_params
195
+ )
196
+ log_thread.join()
192
197
 
193
198
  res = cast(str, cast(List[row.Row], async_job.result())[0][0])
194
199
  module_logger.info(f"Inference service {service_name} deployment complete: {res}")
@@ -1,5 +1,5 @@
1
1
  import pathlib
2
- from typing import Optional
2
+ from typing import List, Optional
3
3
 
4
4
  import yaml
5
5
 
@@ -36,11 +36,13 @@ class ModelDeploymentSpec:
36
36
  image_repo_name: sql_identifier.SqlIdentifier,
37
37
  ingress_enabled: bool,
38
38
  max_instances: int,
39
+ cpu: Optional[str],
40
+ memory: Optional[str],
39
41
  gpu: Optional[str],
40
42
  num_workers: Optional[int],
41
43
  max_batch_rows: Optional[int],
42
44
  force_rebuild: bool,
43
- external_access_integration: sql_identifier.SqlIdentifier,
45
+ external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
44
46
  ) -> None:
45
47
  # create the deployment spec
46
48
  # models spec
@@ -55,12 +57,15 @@ class ModelDeploymentSpec:
55
57
  fq_image_repo_name = identifier.get_schema_level_object_identifier(
56
58
  saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
57
59
  )
58
- image_build_dict = model_deployment_spec_schema.ImageBuildDict(
59
- compute_pool=image_build_compute_pool_name.identifier(),
60
- image_repo=fq_image_repo_name,
61
- force_rebuild=force_rebuild,
62
- external_access_integrations=[external_access_integration.identifier()],
63
- )
60
+ image_build_dict: model_deployment_spec_schema.ImageBuildDict = {
61
+ "compute_pool": image_build_compute_pool_name.identifier(),
62
+ "image_repo": fq_image_repo_name,
63
+ "force_rebuild": force_rebuild,
64
+ }
65
+ if external_access_integrations is not None:
66
+ image_build_dict["external_access_integrations"] = [
67
+ eai.identifier() for eai in external_access_integrations
68
+ ]
64
69
 
65
70
  # service spec
66
71
  saved_service_database = service_database_name or database_name
@@ -74,6 +79,12 @@ class ModelDeploymentSpec:
74
79
  ingress_enabled=ingress_enabled,
75
80
  max_instances=max_instances,
76
81
  )
82
+ if cpu:
83
+ service_dict["cpu"] = cpu
84
+
85
+ if memory:
86
+ service_dict["memory"] = memory
87
+
77
88
  if gpu:
78
89
  service_dict["gpu"] = gpu
79
90
 
@@ -12,7 +12,7 @@ class ImageBuildDict(TypedDict):
12
12
  compute_pool: Required[str]
13
13
  image_repo: Required[str]
14
14
  force_rebuild: Required[bool]
15
- external_access_integrations: Required[List[str]]
15
+ external_access_integrations: NotRequired[List[str]]
16
16
 
17
17
 
18
18
  class ServiceDict(TypedDict):
@@ -20,6 +20,8 @@ class ServiceDict(TypedDict):
20
20
  compute_pool: Required[str]
21
21
  ingress_enabled: Required[bool]
22
22
  max_instances: Required[int]
23
+ cpu: NotRequired[str]
24
+ memory: NotRequired[str]
23
25
  gpu: NotRequired[str]
24
26
  num_workers: NotRequired[int]
25
27
  max_batch_rows: NotRequired[int]
@@ -17,6 +17,8 @@ class ModelSQLClient(_base._BaseSQLClient):
17
17
  MODEL_VERSION_ALIASES_COL_NAME = "aliases"
18
18
  MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"
19
19
 
20
+ MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
21
+
20
22
  def show_models(
21
23
  self,
22
24
  *,
@@ -83,6 +85,18 @@ class ModelSQLClient(_base._BaseSQLClient):
83
85
 
84
86
  return res.validate()
85
87
 
88
+ def show_endpoints(
89
+ self,
90
+ *,
91
+ service_name: str,
92
+ ) -> List[row.Row]:
93
+ res = query_result_checker.SqlResultValidator(
94
+ self._session,
95
+ (f"SHOW ENDPOINTS IN SERVICE {service_name}"),
96
+ ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
97
+
98
+ return res.validate()
99
+
86
100
  def set_comment(
87
101
  self,
88
102
  *,
@@ -3,13 +3,10 @@ import json
3
3
  import textwrap
4
4
  from typing import Any, Dict, List, Optional, Tuple
5
5
 
6
- from packaging import version
7
-
8
6
  from snowflake import snowpark
9
7
  from snowflake.ml._internal.utils import (
10
8
  identifier,
11
9
  query_result_checker,
12
- snowflake_env,
13
10
  sql_identifier,
14
11
  )
15
12
  from snowflake.ml.model._client.sql import _base
@@ -120,21 +117,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
120
117
  args_sql_list.append(input_arg_value)
121
118
  args_sql = ", ".join(args_sql_list)
122
119
 
123
- if snowflake_env.get_current_snowflake_version(
124
- self._session, statement_params=statement_params
125
- ) >= version.parse("8.39.0"):
126
- fully_qualified_service_name = self.fully_qualified_object_name(
127
- actual_database_name, actual_schema_name, service_name
128
- )
129
- fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
130
-
131
- else:
132
- function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
133
- fully_qualified_function_name = identifier.get_schema_level_object_identifier(
134
- actual_database_name.identifier(),
135
- actual_schema_name.identifier(),
136
- function_name,
137
- )
120
+ function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
121
+ fully_qualified_function_name = identifier.get_schema_level_object_identifier(
122
+ actual_database_name.identifier(),
123
+ actual_schema_name.identifier(),
124
+ function_name,
125
+ )
138
126
 
139
127
  sql = textwrap.dedent(
140
128
  f"""{with_sql}
@@ -86,6 +86,7 @@ class ModelComposer:
86
86
  metadata: Optional[Dict[str, str]] = None,
87
87
  conda_dependencies: Optional[List[str]] = None,
88
88
  pip_requirements: Optional[List[str]] = None,
89
+ target_platforms: Optional[List[model_types.TargetPlatform]] = None,
89
90
  python_version: Optional[str] = None,
90
91
  ext_modules: Optional[List[ModuleType]] = None,
91
92
  code_paths: Optional[List[str]] = None,
@@ -131,6 +132,7 @@ class ModelComposer:
131
132
  model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
132
133
  options=options,
133
134
  data_sources=self._get_data_sources(model, sample_input_data),
135
+ target_platforms=target_platforms,
134
136
  )
135
137
 
136
138
  file_utils.upload_directory_to_stage(
@@ -44,6 +44,7 @@ class ModelManifest:
44
44
  model_rel_path: pathlib.PurePosixPath,
45
45
  options: Optional[type_hints.ModelSaveOption] = None,
46
46
  data_sources: Optional[List[data_source.DataSource]] = None,
47
+ target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
47
48
  ) -> None:
48
49
  if options is None:
49
50
  options = {}
@@ -132,6 +133,9 @@ class ModelManifest:
132
133
  if lineage_sources:
133
134
  manifest_dict["lineage_sources"] = lineage_sources
134
135
 
136
+ if target_platforms:
137
+ manifest_dict["target_platforms"] = [platform.value for platform in target_platforms]
138
+
135
139
  with (self.workspace_path / ModelManifest.MANIFEST_FILE_REL_PATH).open("w", encoding="utf-8") as f:
136
140
  # Anchors are not supported in the server, avoid that.
137
141
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
@@ -95,3 +95,4 @@ class ModelManifestDict(TypedDict):
95
95
  methods: Required[List[ModelMethodDict]]
96
96
  user_data: NotRequired[Dict[str, Any]]
97
97
  lineage_sources: NotRequired[List[LineageSourceDict]]
98
+ target_platforms: NotRequired[List[str]]
@@ -27,7 +27,7 @@ def get_model_method_options_from_options(
27
27
  options: type_hints.ModelSaveOption, target_method: str
28
28
  ) -> ModelMethodOptions:
29
29
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
30
- if options.get("enable_explainability", False) and target_method.startswith("explain"):
30
+ if target_method == "explain":
31
31
  default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
32
32
  method_option = options.get("method_options", {}).get(target_method, {})
33
33
  global_function_type = options.get("function_type", default_function_type)
@@ -191,7 +191,11 @@ def convert_explanations_to_2D_df(
191
191
  # convert to object or numpy creates strings of fixed length
192
192
  return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
193
193
 
194
- exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
194
+ # convert to dict only for multiclass
195
+ if len(classes_list) > 2:
196
+ exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
197
+ else: # assumes index 1 is positive class always
198
+ exp_2d = np.apply_along_axis(lambda arr: arr[1], -1, explanations)
195
199
 
196
200
  return pd.DataFrame(exp_2d)
197
201
 
@@ -9,17 +9,14 @@ from typing_extensions import TypeGuard, Unpack
9
9
  from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
11
  from snowflake.ml.model._packager.model_env import model_env
12
- from snowflake.ml.model._packager.model_handlers import (
13
- _base,
14
- _utils as handlers_utils,
15
- model_objective_utils,
16
- )
12
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
17
13
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
18
14
  from snowflake.ml.model._packager.model_meta import (
19
15
  model_blob_meta,
20
16
  model_meta as model_meta_api,
21
17
  model_meta_schema,
22
18
  )
19
+ from snowflake.ml.model._packager.model_task import model_task_utils
23
20
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
24
21
 
25
22
  if TYPE_CHECKING:
@@ -97,7 +94,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
97
94
  sample_input_data=sample_input_data,
98
95
  get_prediction_fn=get_prediction,
99
96
  )
100
- model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
97
+ model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
101
98
  model_meta.task = model_task_and_output.task
102
99
  if enable_explainability:
103
100
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
@@ -99,6 +99,8 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
99
99
  for sub_name, model_ref in model.context.model_refs.items():
100
100
  handler = model_handler.find_handler(model_ref.model)
101
101
  assert handler is not None
102
+ if handler is None:
103
+ raise TypeError("Your input type to custom model is not currently supported")
102
104
  sub_model = handler.cast_model(model_ref.model)
103
105
  handler.save_model(
104
106
  name=sub_name,
@@ -256,12 +256,20 @@ class HuggingFacePipelineHandler(
256
256
  @staticmethod
257
257
  def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
258
258
  device_config: Dict[str, Any] = {}
259
+ cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
260
+ gpu_nums = 0
261
+ if cuda_visible_devices is not None:
262
+ gpu_nums = len(cuda_visible_devices.split(","))
259
263
  if (
260
264
  kwargs.get("use_gpu", False)
261
265
  and kwargs.get("device_map", None) is None
262
266
  and kwargs.get("device", None) is None
263
267
  ):
264
- device_config["device_map"] = "auto"
268
+ if gpu_nums == 0 or gpu_nums > 1:
269
+ # Use accelerator if there are multiple GPUs or no GPU
270
+ device_config["device_map"] = "auto"
271
+ else:
272
+ device_config["device"] = "cuda"
265
273
  elif kwargs.get("device_map", None) is not None:
266
274
  device_config["device_map"] = kwargs["device_map"]
267
275
  elif kwargs.get("device", None) is not None:
@@ -310,6 +318,7 @@ class HuggingFacePipelineHandler(
310
318
  m = transformers.pipeline(
311
319
  model_blob_options["task"],
312
320
  model=model_blob_file_or_dir_path,
321
+ trust_remote_code=True,
313
322
  **device_config,
314
323
  )
315
324
 
@@ -20,17 +20,14 @@ from typing_extensions import TypeGuard, Unpack
20
20
  from snowflake.ml._internal import type_utils
21
21
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
22
22
  from snowflake.ml.model._packager.model_env import model_env
23
- from snowflake.ml.model._packager.model_handlers import (
24
- _base,
25
- _utils as handlers_utils,
26
- model_objective_utils,
27
- )
23
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
28
24
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
29
25
  from snowflake.ml.model._packager.model_meta import (
30
26
  model_blob_meta,
31
27
  model_meta as model_meta_api,
32
28
  model_meta_schema,
33
29
  )
30
+ from snowflake.ml.model._packager.model_task import model_task_utils
34
31
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
35
32
 
36
33
  if TYPE_CHECKING:
@@ -113,7 +110,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
113
110
  sample_input_data=sample_input_data,
114
111
  get_prediction_fn=get_prediction,
115
112
  )
116
- model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
113
+ model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
117
114
  model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
118
115
  if enable_explainability:
119
116
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import logging
2
3
  import os
3
4
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
@@ -155,8 +156,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
155
156
  model_blob_filename = model_blob_metadata.path
156
157
  model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
157
158
 
159
+ additional_kwargs = {}
160
+ if "trust_remote_code" in inspect.signature(sentence_transformers.SentenceTransformer).parameters:
161
+ additional_kwargs["trust_remote_code"] = True
162
+
158
163
  model = sentence_transformers.SentenceTransformer(
159
- model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
164
+ model_blob_file_or_dir_path,
165
+ device=cls._get_device_config(**kwargs),
166
+ **additional_kwargs,
160
167
  )
161
168
  return model
162
169
 
@@ -10,17 +10,14 @@ from typing_extensions import TypeGuard, Unpack
10
10
  from snowflake.ml._internal import type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import (
14
- _base,
15
- _utils as handlers_utils,
16
- model_objective_utils,
17
- )
13
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
18
14
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
19
15
  from snowflake.ml.model._packager.model_meta import (
20
16
  model_blob_meta,
21
17
  model_meta as model_meta_api,
22
18
  model_meta_schema,
23
19
  )
20
+ from snowflake.ml.model._packager.model_task import model_task_utils
24
21
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
25
22
 
26
23
  if TYPE_CHECKING:
@@ -137,7 +134,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
137
134
  sample_input_data, model_meta, explain_target_method
138
135
  )
139
136
 
140
- model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
137
+ model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
141
138
  model_meta.task = model_task_and_output_type.task
142
139
 
143
140
  # if users did not ask then we enable if we have background data
@@ -5,24 +5,20 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, fin
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
- from packaging import version
9
8
  from typing_extensions import TypeGuard, Unpack
10
9
 
11
10
  from snowflake.ml._internal import type_utils
12
11
  from snowflake.ml._internal.exceptions import exceptions
13
12
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
14
13
  from snowflake.ml.model._packager.model_env import model_env
15
- from snowflake.ml.model._packager.model_handlers import (
16
- _base,
17
- _utils as handlers_utils,
18
- model_objective_utils,
19
- )
14
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
20
15
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
21
16
  from snowflake.ml.model._packager.model_meta import (
22
17
  model_blob_meta,
23
18
  model_meta as model_meta_api,
24
19
  model_meta_schema,
25
20
  )
21
+ from snowflake.ml.model._packager.model_task import model_task_utils
26
22
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
27
23
 
28
24
  if TYPE_CHECKING:
@@ -72,41 +68,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
72
68
  return cast("BaseEstimator", model)
73
69
 
74
70
  @classmethod
75
- def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
76
- from importlib import metadata as importlib_metadata
77
-
78
- from packaging import version
79
-
80
- local_version = None
81
-
82
- try:
83
- local_dist = importlib_metadata.distribution(pkg_name)
84
- local_version = version.parse(local_dist.version)
85
- except importlib_metadata.PackageNotFoundError:
86
- pass
87
-
88
- return local_version
89
-
90
- @classmethod
91
- def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
92
-
93
- local_xgb_version = cls._get_local_version_package("xgboost")
94
-
95
- if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
96
- if enable_explainability:
97
- warnings.warn(
98
- f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
99
- + "If you want model explanations, lower the xgboost version to <2.1.0.",
100
- category=UserWarning,
101
- stacklevel=1,
102
- )
103
- return False
104
- return True
105
-
106
- @classmethod
107
- def _get_supported_object_for_explainability(
108
- cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
109
- ) -> Any:
71
+ def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
110
72
  from snowflake.ml.modeling import pipeline as snowml_pipeline
111
73
 
112
74
  # handle pipeline objects separately
@@ -118,8 +80,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
118
80
  if hasattr(estimator, method_name):
119
81
  try:
120
82
  result = getattr(estimator, method_name)()
121
- if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
122
- return None
123
83
  return result
124
84
  except exceptions.SnowflakeMLException:
125
85
  pass # Do nothing and continue to the next method
@@ -168,7 +128,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
168
128
  model_meta.signatures = temp_model_signature_dict
169
129
 
170
130
  if enable_explainability or enable_explainability is None:
171
- python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
131
+ python_base_obj = cls._get_supported_object_for_explainability(model)
172
132
  if python_base_obj is None:
173
133
  if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
174
134
  raise ValueError(
@@ -177,7 +137,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
177
137
  # set None to False so we don't include shap in the environment
178
138
  enable_explainability = False
179
139
  else:
180
- model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
140
+ model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
181
141
  model_meta.task = model_task_and_output_type.task
182
142
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
183
143
  model_meta = handlers_utils.add_explain_method_signature(
@@ -213,28 +173,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
213
173
  model_dependencies = model._get_dependencies()
214
174
  for dep in model_dependencies:
215
175
  pkg_name = dep.split("==")[0]
216
- if pkg_name != "xgboost":
217
- _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
218
- continue
219
-
220
- local_xgb_version = cls._get_local_version_package("xgboost")
221
- if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
222
- model_meta.env.include_if_absent(
223
- [
224
- model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
225
- ],
226
- check_local_version=False,
227
- )
228
- else:
229
- model_meta.env.include_if_absent(
230
- [
231
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
232
- ],
233
- check_local_version=True,
234
- )
176
+ _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
235
177
 
236
178
  if enable_explainability:
237
- model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
179
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
238
180
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
239
181
  model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
240
182