snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +25 -1
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +95 -31
  9. snowflake/ml/jobs/decorators.py +7 -0
  10. snowflake/ml/jobs/manager.py +20 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +113 -17
  13. snowflake/ml/model/_client/ops/service_ops.py +16 -5
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +10 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  52. snowflake/ml/modeling/metrics/ranking.py +3 -0
  53. snowflake/ml/modeling/metrics/regression.py +3 -0
  54. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  55. snowflake/ml/registry/_manager/model_manager.py +55 -7
  56. snowflake/ml/registry/registry.py +59 -1
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
@@ -106,6 +106,8 @@ def submit_file(
106
106
  external_access_integrations: Optional[List[str]] = None,
107
107
  query_warehouse: Optional[str] = None,
108
108
  spec_overrides: Optional[Dict[str, Any]] = None,
109
+ num_instances: Optional[int] = None,
110
+ enable_metrics: bool = False,
109
111
  session: Optional[snowpark.Session] = None,
110
112
  ) -> jb.MLJob:
111
113
  """
@@ -121,6 +123,8 @@ def submit_file(
121
123
  external_access_integrations: A list of external access integrations.
122
124
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
123
125
  spec_overrides: Custom service specification overrides to apply.
126
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
127
+ enable_metrics: Whether to enable metrics publishing for the job.
124
128
  session: The Snowpark session to use. If none specified, uses active session.
125
129
 
126
130
  Returns:
@@ -136,6 +140,8 @@ def submit_file(
136
140
  external_access_integrations=external_access_integrations,
137
141
  query_warehouse=query_warehouse,
138
142
  spec_overrides=spec_overrides,
143
+ num_instances=num_instances,
144
+ enable_metrics=enable_metrics,
139
145
  session=session,
140
146
  )
141
147
 
@@ -154,6 +160,8 @@ def submit_directory(
154
160
  external_access_integrations: Optional[List[str]] = None,
155
161
  query_warehouse: Optional[str] = None,
156
162
  spec_overrides: Optional[Dict[str, Any]] = None,
163
+ num_instances: Optional[int] = None,
164
+ enable_metrics: bool = False,
157
165
  session: Optional[snowpark.Session] = None,
158
166
  ) -> jb.MLJob:
159
167
  """
@@ -170,6 +178,8 @@ def submit_directory(
170
178
  external_access_integrations: A list of external access integrations.
171
179
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
172
180
  spec_overrides: Custom service specification overrides to apply.
181
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
182
+ enable_metrics: Whether to enable metrics publishing for the job.
173
183
  session: The Snowpark session to use. If none specified, uses active session.
174
184
 
175
185
  Returns:
@@ -186,6 +196,8 @@ def submit_directory(
186
196
  external_access_integrations=external_access_integrations,
187
197
  query_warehouse=query_warehouse,
188
198
  spec_overrides=spec_overrides,
199
+ num_instances=num_instances,
200
+ enable_metrics=enable_metrics,
189
201
  session=session,
190
202
  )
191
203
 
@@ -212,6 +224,8 @@ def _submit_job(
212
224
  external_access_integrations: Optional[List[str]] = None,
213
225
  query_warehouse: Optional[str] = None,
214
226
  spec_overrides: Optional[Dict[str, Any]] = None,
227
+ num_instances: Optional[int] = None,
228
+ enable_metrics: bool = False,
215
229
  session: Optional[snowpark.Session] = None,
216
230
  ) -> jb.MLJob:
217
231
  """
@@ -228,6 +242,8 @@ def _submit_job(
228
242
  external_access_integrations: A list of external access integrations.
229
243
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
230
244
  spec_overrides: Custom service specification overrides to apply.
245
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
246
+ enable_metrics: Whether to enable metrics publishing for the job.
231
247
  session: The Snowpark session to use. If none specified, uses active session.
232
248
 
233
249
  Returns:
@@ -254,6 +270,8 @@ def _submit_job(
254
270
  compute_pool=compute_pool,
255
271
  payload=uploaded_payload,
256
272
  args=args,
273
+ num_instances=num_instances,
274
+ enable_metrics=enable_metrics,
257
275
  )
258
276
  spec_overrides = spec_utils.generate_spec_overrides(
259
277
  environment_vars=env_vars,
@@ -281,6 +299,8 @@ def _submit_job(
281
299
  query_warehouse = query_warehouse or session.get_current_warehouse()
282
300
  if query_warehouse:
283
301
  query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
302
+ if num_instances:
303
+ query.append(f"REPLICAS = {num_instances}")
284
304
 
285
305
  # Submit job
286
306
  query_text = "\n".join(line for line in query if line)
@@ -746,7 +746,7 @@ class ModelVersion(lineage_node.LineageNode):
746
746
  max_instances: int = 1,
747
747
  cpu_requests: Optional[str] = None,
748
748
  memory_requests: Optional[str] = None,
749
- gpu_requests: Optional[str] = None,
749
+ gpu_requests: Optional[Union[str, int]] = None,
750
750
  num_workers: Optional[int] = None,
751
751
  max_batch_rows: Optional[int] = None,
752
752
  force_rebuild: bool = False,
@@ -1,3 +1,4 @@
1
+ import enum
1
2
  import json
2
3
  import os
3
4
  import pathlib
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
31
32
  from snowflake.snowpark._internal import utils as snowpark_utils
32
33
 
33
34
 
35
+ # An enum class to represent Create Or Alter Model SQL command.
36
+ class ModelAction(enum.Enum):
37
+ CREATE = "CREATE"
38
+ ALTER = "ALTER"
39
+
40
+
34
41
  class ServiceInfo(TypedDict):
35
42
  name: str
36
43
  status: str
@@ -92,7 +99,7 @@ class ModelOperator:
92
99
  and self._model_version_client == __value._model_version_client
93
100
  )
94
101
 
95
- def prepare_model_stage_path(
102
+ def prepare_model_temp_stage_path(
96
103
  self,
97
104
  *,
98
105
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -110,17 +117,28 @@ class ModelOperator:
110
117
  )
111
118
  return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
112
119
 
113
- def create_from_stage(
120
+ def get_model_version_stage_path(
121
+ self,
122
+ *,
123
+ database_name: Optional[sql_identifier.SqlIdentifier],
124
+ schema_name: Optional[sql_identifier.SqlIdentifier],
125
+ model_name: sql_identifier.SqlIdentifier,
126
+ version_name: sql_identifier.SqlIdentifier,
127
+ ) -> str:
128
+ return (
129
+ f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
130
+ f"/versions/{version_name}/"
131
+ )
132
+
133
+ def get_model_action_from_model_name_and_version(
114
134
  self,
115
- composed_model: model_composer.ModelComposer,
116
135
  *,
117
136
  database_name: Optional[sql_identifier.SqlIdentifier],
118
137
  schema_name: Optional[sql_identifier.SqlIdentifier],
119
138
  model_name: sql_identifier.SqlIdentifier,
120
139
  version_name: sql_identifier.SqlIdentifier,
121
140
  statement_params: Optional[Dict[str, Any]] = None,
122
- ) -> None:
123
- stage_path = str(composed_model.stage_path)
141
+ ) -> ModelAction:
124
142
  if self.validate_existence(
125
143
  database_name=database_name,
126
144
  schema_name=schema_name,
@@ -140,6 +158,79 @@ class ModelOperator:
140
158
  f" version {version_name} already existed."
141
159
  )
142
160
  else:
161
+ return ModelAction.ALTER
162
+ else:
163
+ return ModelAction.CREATE
164
+
165
+ def add_or_create_live_version(
166
+ self,
167
+ *,
168
+ database_name: Optional[sql_identifier.SqlIdentifier],
169
+ schema_name: Optional[sql_identifier.SqlIdentifier],
170
+ model_name: sql_identifier.SqlIdentifier,
171
+ version_name: sql_identifier.SqlIdentifier,
172
+ statement_params: Optional[Dict[str, Any]] = None,
173
+ ) -> None:
174
+ model_action = self.get_model_action_from_model_name_and_version(
175
+ database_name=database_name,
176
+ schema_name=schema_name,
177
+ model_name=model_name,
178
+ version_name=version_name,
179
+ statement_params=statement_params,
180
+ )
181
+ if model_action == ModelAction.CREATE:
182
+ self._model_version_client.create_live_version(
183
+ database_name=database_name,
184
+ schema_name=schema_name,
185
+ model_name=model_name,
186
+ version_name=version_name,
187
+ statement_params=statement_params,
188
+ )
189
+ elif model_action == ModelAction.ALTER:
190
+ self._model_version_client.add_live_version(
191
+ database_name=database_name,
192
+ schema_name=schema_name,
193
+ model_name=model_name,
194
+ version_name=version_name,
195
+ statement_params=statement_params,
196
+ )
197
+ else:
198
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
199
+
200
+ def create_from_stage(
201
+ self,
202
+ composed_model: model_composer.ModelComposer,
203
+ *,
204
+ database_name: Optional[sql_identifier.SqlIdentifier],
205
+ schema_name: Optional[sql_identifier.SqlIdentifier],
206
+ model_name: sql_identifier.SqlIdentifier,
207
+ version_name: sql_identifier.SqlIdentifier,
208
+ statement_params: Optional[Dict[str, Any]] = None,
209
+ use_live_commit: Optional[bool] = False,
210
+ ) -> None:
211
+
212
+ if use_live_commit:
213
+ # if the model version is live, we can only commit the version
214
+ self._model_version_client.commit_version(
215
+ database_name=database_name,
216
+ schema_name=schema_name,
217
+ model_name=model_name,
218
+ version_name=version_name,
219
+ statement_params=statement_params,
220
+ )
221
+ else:
222
+ stage_path = str(composed_model.stage_path)
223
+ # if the model version is not live,
224
+ # find whether the model exists and whether the version exists
225
+ # and then decide whether to create or alter the model
226
+ model_action = self.get_model_action_from_model_name_and_version(
227
+ database_name=database_name,
228
+ schema_name=schema_name,
229
+ model_name=model_name,
230
+ version_name=version_name,
231
+ statement_params=statement_params,
232
+ )
233
+ if model_action == ModelAction.ALTER:
143
234
  self._model_version_client.add_version_from_stage(
144
235
  database_name=database_name,
145
236
  schema_name=schema_name,
@@ -148,15 +239,17 @@ class ModelOperator:
148
239
  version_name=version_name,
149
240
  statement_params=statement_params,
150
241
  )
151
- else:
152
- self._model_version_client.create_from_stage(
153
- database_name=database_name,
154
- schema_name=schema_name,
155
- stage_path=stage_path,
156
- model_name=model_name,
157
- version_name=version_name,
158
- statement_params=statement_params,
159
- )
242
+ elif model_action == ModelAction.CREATE:
243
+ self._model_version_client.create_from_stage(
244
+ database_name=database_name,
245
+ schema_name=schema_name,
246
+ stage_path=stage_path,
247
+ model_name=model_name,
248
+ version_name=version_name,
249
+ statement_params=statement_params,
250
+ )
251
+ else:
252
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
160
253
 
161
254
  def create_from_model_version(
162
255
  self,
@@ -696,14 +789,17 @@ class ModelOperator:
696
789
  version_name: sql_identifier.SqlIdentifier,
697
790
  statement_params: Optional[Dict[str, Any]] = None,
698
791
  ) -> type_hints.Task:
699
- model_spec = self._fetch_model_spec(
792
+ model_version = self._model_client.show_versions(
700
793
  database_name=database_name,
701
794
  schema_name=schema_name,
702
795
  model_name=model_name,
703
796
  version_name=version_name,
797
+ validate_result=True,
704
798
  statement_params=statement_params,
705
- )
706
- task_val = model_spec.get("task", type_hints.Task.UNKNOWN.value)
799
+ )[0]
800
+
801
+ model_attributes = json.loads(model_version.model_attributes)
802
+ task_val = model_attributes.get("task", type_hints.Task.UNKNOWN.value)
707
803
  return type_hints.Task(task_val)
708
804
 
709
805
  def get_functions(
@@ -100,7 +100,7 @@ class ServiceOperator:
100
100
  max_instances: int,
101
101
  cpu_requests: Optional[str],
102
102
  memory_requests: Optional[str],
103
- gpu_requests: Optional[str],
103
+ gpu_requests: Optional[Union[int, str]],
104
104
  num_workers: Optional[int],
105
105
  max_batch_rows: Optional[int],
106
106
  force_rebuild: bool,
@@ -161,12 +161,16 @@ class ServiceOperator:
161
161
  statement_params=statement_params,
162
162
  )
163
163
 
164
- # check if the inference service is already running
164
+ # check if the inference service is already running/suspended
165
165
  model_inference_service_exists = self._check_if_service_exists(
166
166
  database_name=service_database_name,
167
167
  schema_name=service_schema_name,
168
168
  service_name=service_name,
169
- service_status_list_if_exists=[service_sql.ServiceStatus.READY],
169
+ service_status_list_if_exists=[
170
+ service_sql.ServiceStatus.READY,
171
+ service_sql.ServiceStatus.SUSPENDING,
172
+ service_sql.ServiceStatus.SUSPENDED,
173
+ ],
170
174
  statement_params=statement_params,
171
175
  )
172
176
 
@@ -309,7 +313,10 @@ class ServiceOperator:
309
313
  set_service_log_metadata_to_model_inference(
310
314
  service_log_meta,
311
315
  model_inference_service,
312
- "Model Inference image build is not rebuilding the image and using previously built image.",
316
+ (
317
+ "Model Inference image build is not rebuilding the image, but using a previously built "
318
+ "image."
319
+ ),
313
320
  )
314
321
  continue
315
322
 
@@ -366,7 +373,9 @@ class ServiceOperator:
366
373
  time.sleep(5)
367
374
 
368
375
  if model_inference_service_exists:
369
- module_logger.info(f"Inference service {model_inference_service.display_service_name} is already RUNNING.")
376
+ module_logger.info(
377
+ f"Inference service {model_inference_service.display_service_name} has already been deployed."
378
+ )
370
379
  else:
371
380
  self._finalize_logs(
372
381
  service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params
@@ -416,6 +425,8 @@ class ServiceOperator:
416
425
  service_status_list_if_exists = [
417
426
  service_sql.ServiceStatus.PENDING,
418
427
  service_sql.ServiceStatus.READY,
428
+ service_sql.ServiceStatus.SUSPENDING,
429
+ service_sql.ServiceStatus.SUSPENDED,
419
430
  service_sql.ServiceStatus.DONE,
420
431
  service_sql.ServiceStatus.FAILED,
421
432
  ]
@@ -1,5 +1,5 @@
1
1
  import pathlib
2
- from typing import List, Optional
2
+ from typing import List, Optional, Union
3
3
 
4
4
  import yaml
5
5
 
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
38
38
  max_instances: int,
39
39
  cpu: Optional[str],
40
40
  memory: Optional[str],
41
- gpu: Optional[str],
41
+ gpu: Optional[Union[str, int]],
42
42
  num_workers: Optional[int],
43
43
  max_batch_rows: Optional[int],
44
44
  force_rebuild: bool,
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
86
86
  service_dict["memory"] = memory
87
87
 
88
88
  if gpu:
89
- service_dict["gpu"] = gpu
89
+ if isinstance(gpu, int):
90
+ gpu_str = str(gpu)
91
+ else:
92
+ gpu_str = gpu
93
+ service_dict["gpu"] = gpu_str
90
94
 
91
95
  if num_workers:
92
96
  service_dict["num_workers"] = num_workers
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
71
71
  statement_params=statement_params,
72
72
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
73
73
 
74
+ def create_live_version(
75
+ self,
76
+ *,
77
+ database_name: Optional[sql_identifier.SqlIdentifier],
78
+ schema_name: Optional[sql_identifier.SqlIdentifier],
79
+ model_name: sql_identifier.SqlIdentifier,
80
+ version_name: sql_identifier.SqlIdentifier,
81
+ statement_params: Optional[Dict[str, Any]] = None,
82
+ ) -> None:
83
+ sql = (
84
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
85
+ f" WITH LIVE VERSION {version_name.identifier()}"
86
+ )
87
+ query_result_checker.SqlResultValidator(
88
+ self._session,
89
+ sql,
90
+ statement_params=statement_params,
91
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
92
+
93
+ def add_live_version(
94
+ self,
95
+ *,
96
+ database_name: Optional[sql_identifier.SqlIdentifier],
97
+ schema_name: Optional[sql_identifier.SqlIdentifier],
98
+ model_name: sql_identifier.SqlIdentifier,
99
+ version_name: sql_identifier.SqlIdentifier,
100
+ statement_params: Optional[Dict[str, Any]] = None,
101
+ ) -> None:
102
+ sql = (
103
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
104
+ f" ADD LIVE VERSION {version_name.identifier()}"
105
+ )
106
+ query_result_checker.SqlResultValidator(
107
+ self._session,
108
+ sql,
109
+ statement_params=statement_params,
110
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
+
112
+ def commit_version(
113
+ self,
114
+ *,
115
+ database_name: Optional[sql_identifier.SqlIdentifier],
116
+ schema_name: Optional[sql_identifier.SqlIdentifier],
117
+ model_name: sql_identifier.SqlIdentifier,
118
+ version_name: sql_identifier.SqlIdentifier,
119
+ statement_params: Optional[Dict[str, Any]] = None,
120
+ ) -> None:
121
+ sql = (
122
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
123
+ f" COMMIT VERSION {version_name.identifier()}"
124
+ )
125
+
126
+ query_result_checker.SqlResultValidator(
127
+ self._session,
128
+ sql,
129
+ statement_params=statement_params,
130
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
131
+
74
132
  # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
75
133
  def add_version_from_stage(
76
134
  self,
@@ -1,7 +1,7 @@
1
1
  import enum
2
2
  import json
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  from snowflake import snowpark
7
7
  from snowflake.ml._internal import platform_capabilities
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
11
11
  sql_identifier,
12
12
  )
13
13
  from snowflake.ml.model._client.sql import _base
14
+ from snowflake.ml.model._model_composer.model_method import constants
14
15
  from snowflake.snowpark import dataframe, functions as F, row, types as spt
15
16
  from snowflake.snowpark._internal import utils as snowpark_utils
16
17
 
@@ -19,6 +20,8 @@ class ServiceStatus(enum.Enum):
19
20
  UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
20
21
  PENDING = "PENDING" # resource set is being created, can't be used yet
21
22
  READY = "READY" # resource set has been deployed.
23
+ SUSPENDING = "SUSPENDING" # the service is set to suspended but the resource set is still in deleting state
24
+ SUSPENDED = "SUSPENDED" # the service is suspended and the resource set is deleted
22
25
  DELETING = "DELETING" # resource set is being deleted
23
26
  FAILED = "FAILED" # resource set has failed and cannot be used anymore
24
27
  DONE = "DONE" # resource set has finished running
@@ -41,7 +44,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
41
44
  image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
42
45
  image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
43
46
  image_repo_name: sql_identifier.SqlIdentifier,
44
- gpu: Optional[str],
47
+ gpu: Optional[Union[str, int]],
45
48
  force_rebuild: bool,
46
49
  external_access_integration: sql_identifier.SqlIdentifier,
47
50
  statement_params: Optional[Dict[str, Any]] = None,
@@ -121,6 +124,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
121
124
  args_sql_list.append(input_arg_value)
122
125
  args_sql = ", ".join(args_sql_list)
123
126
 
127
+ wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
128
+ if wide_input:
129
+ input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
130
+ args_sql = f"object_construct_keep_null({input_args_sql})"
131
+
124
132
  if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
125
133
  fully_qualified_service_name = self.fully_qualified_object_name(
126
134
  actual_database_name, actual_schema_name, service_name
@@ -1,8 +1,10 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import uuid
4
+ import warnings
4
5
  from types import ModuleType
5
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Dict, List, Optional, Union
7
+ from urllib import parse
6
8
 
7
9
  from absl import logging
8
10
  from packaging import requirements
@@ -44,7 +46,13 @@ class ModelComposer:
44
46
  statement_params: Optional[Dict[str, Any]] = None,
45
47
  ) -> None:
46
48
  self.session = session
47
- self.stage_path = pathlib.PurePosixPath(stage_path)
49
+ self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
50
+ if stage_path.startswith("snow://"):
51
+ # The stage path is a snowflake internal stage path
52
+ self.stage_path = parse.urlparse(stage_path)
53
+ else:
54
+ # The stage path is a user stage path
55
+ self.stage_path = pathlib.PurePosixPath(stage_path)
48
56
 
49
57
  self._workspace = tempfile.TemporaryDirectory()
50
58
  self._packager_workspace = tempfile.TemporaryDirectory()
@@ -70,7 +78,20 @@ class ModelComposer:
70
78
 
71
79
  @property
72
80
  def model_stage_path(self) -> str:
73
- return (self.stage_path / self.model_file_rel_path).as_posix()
81
+ if isinstance(self.stage_path, parse.ParseResult):
82
+ model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
83
+ new_url = parse.ParseResult(
84
+ scheme=self.stage_path.scheme,
85
+ netloc=self.stage_path.netloc,
86
+ path=str(model_file_path),
87
+ params=self.stage_path.params,
88
+ query=self.stage_path.query,
89
+ fragment=self.stage_path.fragment,
90
+ )
91
+ return str(parse.urlunparse(new_url))
92
+ else:
93
+ assert isinstance(self.stage_path, pathlib.PurePosixPath)
94
+ return (self.stage_path / self.model_file_rel_path).as_posix()
74
95
 
75
96
  @property
76
97
  def model_local_path(self) -> str:
@@ -86,6 +107,7 @@ class ModelComposer:
86
107
  metadata: Optional[Dict[str, str]] = None,
87
108
  conda_dependencies: Optional[List[str]] = None,
88
109
  pip_requirements: Optional[List[str]] = None,
110
+ artifact_repository_map: Optional[Dict[str, str]] = None,
89
111
  target_platforms: Optional[List[model_types.TargetPlatform]] = None,
90
112
  python_version: Optional[str] = None,
91
113
  user_files: Optional[Dict[str, List[str]]] = None,
@@ -94,8 +116,32 @@ class ModelComposer:
94
116
  task: model_types.Task = model_types.Task.UNKNOWN,
95
117
  options: Optional[model_types.ModelSaveOption] = None,
96
118
  ) -> model_meta.ModelMetadata:
119
+ # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
120
+ conda_dep_dict = env_utils.validate_conda_dependency_string_list(
121
+ conda_dependencies if conda_dependencies else []
122
+ )
123
+ is_warehouse_runnable = (
124
+ not conda_dep_dict
125
+ or all(
126
+ chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
127
+ for chan in conda_dep_dict
128
+ )
129
+ ) and (not pip_requirements)
130
+ disable_explainability = (
131
+ target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
132
+ ) or (not is_warehouse_runnable)
133
+
134
+ if disable_explainability and options and options.get("enable_explainability", False):
135
+ warnings.warn(
136
+ ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
137
+ category=UserWarning,
138
+ stacklevel=2,
139
+ )
140
+
97
141
  if not options:
98
142
  options = model_types.BaseModelSaveOption()
143
+ if disable_explainability:
144
+ options["enable_explainability"] = False
99
145
 
100
146
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
101
147
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
@@ -120,6 +166,7 @@ class ModelComposer:
120
166
  metadata=metadata,
121
167
  conda_dependencies=conda_dependencies,
122
168
  pip_requirements=pip_requirements,
169
+ artifact_repository_map=artifact_repository_map,
123
170
  python_version=python_version,
124
171
  ext_modules=ext_modules,
125
172
  code_paths=code_paths,
@@ -36,7 +36,6 @@ class ModelManifest:
36
36
  """
37
37
 
38
38
  MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
39
- _ENABLE_USER_FILES = False
40
39
  _DEFAULT_RUNTIME_NAME = "python_runtime"
41
40
 
42
41
  def __init__(self, workspace_path: pathlib.Path) -> None:
@@ -78,6 +77,7 @@ class ModelManifest:
78
77
  logger.info("Relaxing version constraints for dependencies in the model.")
79
78
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
80
79
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
80
+ logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
81
  runtime_dict = runtime_to_use.save(
82
82
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
83
  )
@@ -124,6 +124,9 @@ class ModelManifest:
124
124
  if len(model_meta.env.pip_requirements) > 0:
125
125
  dependencies["pip"] = runtime_dict["dependencies"]["pip"]
126
126
 
127
+ if model_meta.env.artifact_repository_map:
128
+ dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
129
+
127
130
  manifest_dict = model_manifest_schema.ModelManifestDict(
128
131
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
129
132
  runtimes={
@@ -145,7 +148,7 @@ class ModelManifest:
145
148
  ],
146
149
  )
147
150
 
148
- if self._ENABLE_USER_FILES:
151
+ if self.user_files:
149
152
  manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
150
153
 
151
154
  lineage_sources = self._extract_lineage_info(data_sources)
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, TypedDict, Union
3
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
24
 
24
25
 
25
26
  class ModelRuntimeDict(TypedDict):
@@ -98,7 +98,6 @@ class ModelMethod:
98
98
  def _get_method_arg_from_feature(
99
99
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
100
100
  ) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
101
- assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
102
101
  try:
103
102
  feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
104
103
  except ValueError as e:
@@ -3,7 +3,7 @@ import itertools
3
3
  import os
4
4
  import pathlib
5
5
  import warnings
6
- from typing import DefaultDict, List, Optional
6
+ from typing import DefaultDict, Dict, List, Optional
7
7
 
8
8
  from packaging import requirements, version
9
9
 
@@ -36,6 +36,7 @@ class ModelEnv:
36
36
  pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
37
37
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
38
38
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
39
+ self.artifact_repository_map: Optional[Dict[str, str]] = None
39
40
  self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
40
41
  self._pip_requirements: List[requirements.Requirement] = []
41
42
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
@@ -345,6 +346,7 @@ class ModelEnv:
345
346
  def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
346
347
  self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
347
348
  self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
349
+ self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
348
350
 
349
351
  self.load_from_conda_file(base_dir / self.conda_env_rel_path)
350
352
  self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
@@ -373,6 +375,7 @@ class ModelEnv:
373
375
  return {
374
376
  "conda": self.conda_env_rel_path.as_posix(),
375
377
  "pip": self.pip_requirements_rel_path.as_posix(),
378
+ "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
376
379
  "python_version": self.python_version,
377
380
  "cuda_version": self.cuda_version,
378
381
  "snowpark_ml_version": self.snowpark_ml_version,