snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73):
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,5 @@
1
1
  import copy
2
2
  import functools
3
- import inspect
4
3
  from typing import Callable, Dict, List, Optional, TypeVar
5
4
 
6
5
  from typing_extensions import ParamSpec
@@ -8,7 +7,7 @@ from typing_extensions import ParamSpec
8
7
  from snowflake import snowpark
9
8
  from snowflake.ml._internal import telemetry
10
9
  from snowflake.ml.jobs import job as jb, manager as jm
11
- from snowflake.ml.jobs._utils import payload_utils
10
+ from snowflake.ml.jobs._utils import constants
12
11
 
13
12
  _PROJECT = "MLJob"
14
13
 
@@ -26,6 +25,7 @@ def remote(
26
25
  query_warehouse: Optional[str] = None,
27
26
  env_vars: Optional[Dict[str, str]] = None,
28
27
  session: Optional[snowpark.Session] = None,
28
+ num_instances: Optional[int] = None,
29
29
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
30
30
  """
31
31
  Submit a job to the compute pool.
@@ -38,6 +38,7 @@ def remote(
38
38
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
39
39
  env_vars: Environment variables to set in container
40
40
  session: The Snowpark session to use. If none specified, uses active session.
41
+ num_instances: The number of nodes in the job. If none specified, create a single node job.
41
42
 
42
43
  Returns:
43
44
  Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -50,31 +51,12 @@ def remote(
50
51
  wrapped_func = copy.copy(func)
51
52
  wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
52
53
 
53
- # Validate function arguments based on signature
54
- signature = inspect.signature(func)
55
- pos_arg_names = []
56
- for name, param in signature.parameters.items():
57
- param_type = payload_utils.get_parameter_type(param)
58
- if param_type is not None:
59
- payload_utils.validate_parameter_type(param_type, name)
60
- if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
61
- pos_arg_names.append(name)
62
-
63
54
  @functools.wraps(func)
64
55
  def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
65
- # Validate positional args
66
- for i, arg in enumerate(args):
67
- arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
68
- payload_utils.validate_parameter_type(type(arg), arg_name)
69
-
70
- # Validate keyword args
71
- for k, v in kwargs.items():
72
- payload_utils.validate_parameter_type(type(v), k)
73
-
74
- arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
56
+ payload = functools.partial(func, *args, **kwargs)
57
+ setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
75
58
  job = jm._submit_job(
76
- source=wrapped_func,
77
- args=arg_list,
59
+ source=payload,
78
60
  stage_name=stage_name,
79
61
  compute_pool=compute_pool,
80
62
  pip_requirements=pip_requirements,
@@ -82,8 +64,9 @@ def remote(
82
64
  query_warehouse=query_warehouse,
83
65
  env_vars=env_vars,
84
66
  session=session,
67
+ num_instances=num_instances,
85
68
  )
86
- assert isinstance(job, jb.MLJob)
69
+ assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
87
70
  return job
88
71
 
89
72
  return wrapper
snowflake/ml/jobs/job.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, cast
4
4
  from snowflake import snowpark
5
5
  from snowflake.ml._internal import telemetry
6
6
  from snowflake.ml.jobs._utils import constants, types
7
- from snowflake.snowpark.context import get_active_session
7
+ from snowflake.snowpark import context as sp_context
8
8
 
9
9
  _PROJECT = "MLJob"
10
10
  TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
@@ -13,7 +13,7 @@ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
13
13
  class MLJob:
14
14
  def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
15
15
  self._id = id
16
- self._session = session or get_active_session()
16
+ self._session = session or sp_context.get_active_session()
17
17
  self._status: types.JOB_STATUS = "PENDING"
18
18
 
19
19
  @property
@@ -79,7 +79,7 @@ class MLJob:
79
79
  return self.status
80
80
 
81
81
 
82
- @telemetry.send_api_usage_telemetry(project=_PROJECT)
82
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
83
83
  def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
84
84
  """Retrieve job execution status."""
85
85
  # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
@@ -90,7 +90,7 @@ def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
90
90
  return cast(types.JOB_STATUS, row["status"])
91
91
 
92
92
 
93
- @telemetry.send_api_usage_telemetry(project=_PROJECT)
93
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
94
94
  def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
95
95
  """
96
96
  Retrieve the job's execution logs.
@@ -213,6 +213,7 @@ def _submit_job(
213
213
  query_warehouse: Optional[str] = None,
214
214
  spec_overrides: Optional[Dict[str, Any]] = None,
215
215
  session: Optional[snowpark.Session] = None,
216
+ num_instances: Optional[int] = None,
216
217
  ) -> jb.MLJob:
217
218
  """
218
219
  Submit a job to the compute pool.
@@ -229,6 +230,7 @@ def _submit_job(
229
230
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
230
231
  spec_overrides: Custom service specification overrides to apply.
231
232
  session: The Snowpark session to use. If none specified, uses active session.
233
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
232
234
 
233
235
  Returns:
234
236
  An object representing the submitted job.
@@ -254,6 +256,7 @@ def _submit_job(
254
256
  compute_pool=compute_pool,
255
257
  payload=uploaded_payload,
256
258
  args=args,
259
+ num_instances=num_instances,
257
260
  )
258
261
  spec_overrides = spec_utils.generate_spec_overrides(
259
262
  environment_vars=env_vars,
@@ -281,6 +284,8 @@ def _submit_job(
281
284
  query_warehouse = query_warehouse or session.get_current_warehouse()
282
285
  if query_warehouse:
283
286
  query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
287
+ if num_instances:
288
+ query.append(f"REPLICAS = {num_instances}")
284
289
 
285
290
  # Submit job
286
291
  query_text = "\n".join(line for line in query if line)
@@ -746,7 +746,7 @@ class ModelVersion(lineage_node.LineageNode):
746
746
  max_instances: int = 1,
747
747
  cpu_requests: Optional[str] = None,
748
748
  memory_requests: Optional[str] = None,
749
- gpu_requests: Optional[str] = None,
749
+ gpu_requests: Optional[Union[str, int]] = None,
750
750
  num_workers: Optional[int] = None,
751
751
  max_batch_rows: Optional[int] = None,
752
752
  force_rebuild: bool = False,
@@ -1,3 +1,4 @@
1
+ import enum
1
2
  import json
2
3
  import os
3
4
  import pathlib
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
31
32
  from snowflake.snowpark._internal import utils as snowpark_utils
32
33
 
33
34
 
35
+ # An enum class to represent Create Or Alter Model SQL command.
36
+ class ModelAction(enum.Enum):
37
+ CREATE = "CREATE"
38
+ ALTER = "ALTER"
39
+
40
+
34
41
  class ServiceInfo(TypedDict):
35
42
  name: str
36
43
  status: str
@@ -92,7 +99,7 @@ class ModelOperator:
92
99
  and self._model_version_client == __value._model_version_client
93
100
  )
94
101
 
95
- def prepare_model_stage_path(
102
+ def prepare_model_temp_stage_path(
96
103
  self,
97
104
  *,
98
105
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -110,17 +117,28 @@ class ModelOperator:
110
117
  )
111
118
  return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
112
119
 
113
- def create_from_stage(
120
+ def get_model_version_stage_path(
121
+ self,
122
+ *,
123
+ database_name: Optional[sql_identifier.SqlIdentifier],
124
+ schema_name: Optional[sql_identifier.SqlIdentifier],
125
+ model_name: sql_identifier.SqlIdentifier,
126
+ version_name: sql_identifier.SqlIdentifier,
127
+ ) -> str:
128
+ return (
129
+ f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
130
+ f"/versions/{version_name}/"
131
+ )
132
+
133
+ def get_model_action_from_model_name_and_version(
114
134
  self,
115
- composed_model: model_composer.ModelComposer,
116
135
  *,
117
136
  database_name: Optional[sql_identifier.SqlIdentifier],
118
137
  schema_name: Optional[sql_identifier.SqlIdentifier],
119
138
  model_name: sql_identifier.SqlIdentifier,
120
139
  version_name: sql_identifier.SqlIdentifier,
121
140
  statement_params: Optional[Dict[str, Any]] = None,
122
- ) -> None:
123
- stage_path = str(composed_model.stage_path)
141
+ ) -> ModelAction:
124
142
  if self.validate_existence(
125
143
  database_name=database_name,
126
144
  schema_name=schema_name,
@@ -140,6 +158,79 @@ class ModelOperator:
140
158
  f" version {version_name} already existed."
141
159
  )
142
160
  else:
161
+ return ModelAction.ALTER
162
+ else:
163
+ return ModelAction.CREATE
164
+
165
+ def add_or_create_live_version(
166
+ self,
167
+ *,
168
+ database_name: Optional[sql_identifier.SqlIdentifier],
169
+ schema_name: Optional[sql_identifier.SqlIdentifier],
170
+ model_name: sql_identifier.SqlIdentifier,
171
+ version_name: sql_identifier.SqlIdentifier,
172
+ statement_params: Optional[Dict[str, Any]] = None,
173
+ ) -> None:
174
+ model_action = self.get_model_action_from_model_name_and_version(
175
+ database_name=database_name,
176
+ schema_name=schema_name,
177
+ model_name=model_name,
178
+ version_name=version_name,
179
+ statement_params=statement_params,
180
+ )
181
+ if model_action == ModelAction.CREATE:
182
+ self._model_version_client.create_live_version(
183
+ database_name=database_name,
184
+ schema_name=schema_name,
185
+ model_name=model_name,
186
+ version_name=version_name,
187
+ statement_params=statement_params,
188
+ )
189
+ elif model_action == ModelAction.ALTER:
190
+ self._model_version_client.add_live_version(
191
+ database_name=database_name,
192
+ schema_name=schema_name,
193
+ model_name=model_name,
194
+ version_name=version_name,
195
+ statement_params=statement_params,
196
+ )
197
+ else:
198
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
199
+
200
+ def create_from_stage(
201
+ self,
202
+ composed_model: model_composer.ModelComposer,
203
+ *,
204
+ database_name: Optional[sql_identifier.SqlIdentifier],
205
+ schema_name: Optional[sql_identifier.SqlIdentifier],
206
+ model_name: sql_identifier.SqlIdentifier,
207
+ version_name: sql_identifier.SqlIdentifier,
208
+ statement_params: Optional[Dict[str, Any]] = None,
209
+ use_live_commit: Optional[bool] = False,
210
+ ) -> None:
211
+
212
+ if use_live_commit:
213
+ # if the model version is live, we can only commit the version
214
+ self._model_version_client.commit_version(
215
+ database_name=database_name,
216
+ schema_name=schema_name,
217
+ model_name=model_name,
218
+ version_name=version_name,
219
+ statement_params=statement_params,
220
+ )
221
+ else:
222
+ stage_path = str(composed_model.stage_path)
223
+ # if the model version is not live,
224
+ # find whether the model exists and whether the version exists
225
+ # and then decide whether to create or alter the model
226
+ model_action = self.get_model_action_from_model_name_and_version(
227
+ database_name=database_name,
228
+ schema_name=schema_name,
229
+ model_name=model_name,
230
+ version_name=version_name,
231
+ statement_params=statement_params,
232
+ )
233
+ if model_action == ModelAction.ALTER:
143
234
  self._model_version_client.add_version_from_stage(
144
235
  database_name=database_name,
145
236
  schema_name=schema_name,
@@ -148,15 +239,17 @@ class ModelOperator:
148
239
  version_name=version_name,
149
240
  statement_params=statement_params,
150
241
  )
151
- else:
152
- self._model_version_client.create_from_stage(
153
- database_name=database_name,
154
- schema_name=schema_name,
155
- stage_path=stage_path,
156
- model_name=model_name,
157
- version_name=version_name,
158
- statement_params=statement_params,
159
- )
242
+ elif model_action == ModelAction.CREATE:
243
+ self._model_version_client.create_from_stage(
244
+ database_name=database_name,
245
+ schema_name=schema_name,
246
+ stage_path=stage_path,
247
+ model_name=model_name,
248
+ version_name=version_name,
249
+ statement_params=statement_params,
250
+ )
251
+ else:
252
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
160
253
 
161
254
  def create_from_model_version(
162
255
  self,
@@ -100,7 +100,7 @@ class ServiceOperator:
100
100
  max_instances: int,
101
101
  cpu_requests: Optional[str],
102
102
  memory_requests: Optional[str],
103
- gpu_requests: Optional[str],
103
+ gpu_requests: Optional[Union[int, str]],
104
104
  num_workers: Optional[int],
105
105
  max_batch_rows: Optional[int],
106
106
  force_rebuild: bool,
@@ -1,5 +1,5 @@
1
1
  import pathlib
2
- from typing import List, Optional
2
+ from typing import List, Optional, Union
3
3
 
4
4
  import yaml
5
5
 
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
38
38
  max_instances: int,
39
39
  cpu: Optional[str],
40
40
  memory: Optional[str],
41
- gpu: Optional[str],
41
+ gpu: Optional[Union[str, int]],
42
42
  num_workers: Optional[int],
43
43
  max_batch_rows: Optional[int],
44
44
  force_rebuild: bool,
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
86
86
  service_dict["memory"] = memory
87
87
 
88
88
  if gpu:
89
- service_dict["gpu"] = gpu
89
+ if isinstance(gpu, int):
90
+ gpu_str = str(gpu)
91
+ else:
92
+ gpu_str = gpu
93
+ service_dict["gpu"] = gpu_str
90
94
 
91
95
  if num_workers:
92
96
  service_dict["num_workers"] = num_workers
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
71
71
  statement_params=statement_params,
72
72
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
73
73
 
74
+ def create_live_version(
75
+ self,
76
+ *,
77
+ database_name: Optional[sql_identifier.SqlIdentifier],
78
+ schema_name: Optional[sql_identifier.SqlIdentifier],
79
+ model_name: sql_identifier.SqlIdentifier,
80
+ version_name: sql_identifier.SqlIdentifier,
81
+ statement_params: Optional[Dict[str, Any]] = None,
82
+ ) -> None:
83
+ sql = (
84
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
85
+ f" WITH LIVE VERSION {version_name.identifier()}"
86
+ )
87
+ query_result_checker.SqlResultValidator(
88
+ self._session,
89
+ sql,
90
+ statement_params=statement_params,
91
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
92
+
93
+ def add_live_version(
94
+ self,
95
+ *,
96
+ database_name: Optional[sql_identifier.SqlIdentifier],
97
+ schema_name: Optional[sql_identifier.SqlIdentifier],
98
+ model_name: sql_identifier.SqlIdentifier,
99
+ version_name: sql_identifier.SqlIdentifier,
100
+ statement_params: Optional[Dict[str, Any]] = None,
101
+ ) -> None:
102
+ sql = (
103
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
104
+ f" ADD LIVE VERSION {version_name.identifier()}"
105
+ )
106
+ query_result_checker.SqlResultValidator(
107
+ self._session,
108
+ sql,
109
+ statement_params=statement_params,
110
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
+
112
+ def commit_version(
113
+ self,
114
+ *,
115
+ database_name: Optional[sql_identifier.SqlIdentifier],
116
+ schema_name: Optional[sql_identifier.SqlIdentifier],
117
+ model_name: sql_identifier.SqlIdentifier,
118
+ version_name: sql_identifier.SqlIdentifier,
119
+ statement_params: Optional[Dict[str, Any]] = None,
120
+ ) -> None:
121
+ sql = (
122
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
123
+ f" COMMIT VERSION {version_name.identifier()}"
124
+ )
125
+
126
+ query_result_checker.SqlResultValidator(
127
+ self._session,
128
+ sql,
129
+ statement_params=statement_params,
130
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
131
+
74
132
  # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
75
133
  def add_version_from_stage(
76
134
  self,
@@ -1,7 +1,7 @@
1
1
  import enum
2
2
  import json
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  from snowflake import snowpark
7
7
  from snowflake.ml._internal import platform_capabilities
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
11
11
  sql_identifier,
12
12
  )
13
13
  from snowflake.ml.model._client.sql import _base
14
+ from snowflake.ml.model._model_composer.model_method import constants
14
15
  from snowflake.snowpark import dataframe, functions as F, row, types as spt
15
16
  from snowflake.snowpark._internal import utils as snowpark_utils
16
17
 
@@ -41,7 +42,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
41
42
  image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
42
43
  image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
43
44
  image_repo_name: sql_identifier.SqlIdentifier,
44
- gpu: Optional[str],
45
+ gpu: Optional[Union[str, int]],
45
46
  force_rebuild: bool,
46
47
  external_access_integration: sql_identifier.SqlIdentifier,
47
48
  statement_params: Optional[Dict[str, Any]] = None,
@@ -121,6 +122,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
121
122
  args_sql_list.append(input_arg_value)
122
123
  args_sql = ", ".join(args_sql_list)
123
124
 
125
+ wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
126
+ if wide_input:
127
+ input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
128
+ args_sql = f"object_construct_keep_null({input_args_sql})"
129
+
124
130
  if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
125
131
  fully_qualified_service_name = self.fully_qualified_object_name(
126
132
  actual_database_name, actual_schema_name, service_name
@@ -1,8 +1,10 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import uuid
4
+ import warnings
4
5
  from types import ModuleType
5
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Dict, List, Optional, Union
7
+ from urllib import parse
6
8
 
7
9
  from absl import logging
8
10
  from packaging import requirements
@@ -44,7 +46,13 @@ class ModelComposer:
44
46
  statement_params: Optional[Dict[str, Any]] = None,
45
47
  ) -> None:
46
48
  self.session = session
47
- self.stage_path = pathlib.PurePosixPath(stage_path)
49
+ self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
50
+ if stage_path.startswith("snow://"):
51
+ # The stage path is a snowflake internal stage path
52
+ self.stage_path = parse.urlparse(stage_path)
53
+ else:
54
+ # The stage path is a user stage path
55
+ self.stage_path = pathlib.PurePosixPath(stage_path)
48
56
 
49
57
  self._workspace = tempfile.TemporaryDirectory()
50
58
  self._packager_workspace = tempfile.TemporaryDirectory()
@@ -70,7 +78,20 @@ class ModelComposer:
70
78
 
71
79
  @property
72
80
  def model_stage_path(self) -> str:
73
- return (self.stage_path / self.model_file_rel_path).as_posix()
81
+ if isinstance(self.stage_path, parse.ParseResult):
82
+ model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
83
+ new_url = parse.ParseResult(
84
+ scheme=self.stage_path.scheme,
85
+ netloc=self.stage_path.netloc,
86
+ path=str(model_file_path),
87
+ params=self.stage_path.params,
88
+ query=self.stage_path.query,
89
+ fragment=self.stage_path.fragment,
90
+ )
91
+ return str(parse.urlunparse(new_url))
92
+ else:
93
+ assert isinstance(self.stage_path, pathlib.PurePosixPath)
94
+ return (self.stage_path / self.model_file_rel_path).as_posix()
74
95
 
75
96
  @property
76
97
  def model_local_path(self) -> str:
@@ -86,6 +107,7 @@ class ModelComposer:
86
107
  metadata: Optional[Dict[str, str]] = None,
87
108
  conda_dependencies: Optional[List[str]] = None,
88
109
  pip_requirements: Optional[List[str]] = None,
110
+ artifact_repository_map: Optional[Dict[str, str]] = None,
89
111
  target_platforms: Optional[List[model_types.TargetPlatform]] = None,
90
112
  python_version: Optional[str] = None,
91
113
  user_files: Optional[Dict[str, List[str]]] = None,
@@ -94,8 +116,32 @@ class ModelComposer:
94
116
  task: model_types.Task = model_types.Task.UNKNOWN,
95
117
  options: Optional[model_types.ModelSaveOption] = None,
96
118
  ) -> model_meta.ModelMetadata:
119
+ # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
120
+ conda_dep_dict = env_utils.validate_conda_dependency_string_list(
121
+ conda_dependencies if conda_dependencies else []
122
+ )
123
+ is_warehouse_runnable = (
124
+ not conda_dep_dict
125
+ or all(
126
+ chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
127
+ for chan in conda_dep_dict
128
+ )
129
+ ) and (not pip_requirements)
130
+ disable_explainability = (
131
+ target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
132
+ ) or (not is_warehouse_runnable)
133
+
134
+ if disable_explainability and options and options.get("enable_explainability", False):
135
+ warnings.warn(
136
+ ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
137
+ category=UserWarning,
138
+ stacklevel=2,
139
+ )
140
+
97
141
  if not options:
98
142
  options = model_types.BaseModelSaveOption()
143
+ if disable_explainability:
144
+ options["enable_explainability"] = False
99
145
 
100
146
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
101
147
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
@@ -120,6 +166,7 @@ class ModelComposer:
120
166
  metadata=metadata,
121
167
  conda_dependencies=conda_dependencies,
122
168
  pip_requirements=pip_requirements,
169
+ artifact_repository_map=artifact_repository_map,
123
170
  python_version=python_version,
124
171
  ext_modules=ext_modules,
125
172
  code_paths=code_paths,
@@ -78,6 +78,7 @@ class ModelManifest:
78
78
  logger.info("Relaxing version constraints for dependencies in the model.")
79
79
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
80
80
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
81
+ logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
82
  runtime_dict = runtime_to_use.save(
82
83
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
84
  )
@@ -124,6 +125,9 @@ class ModelManifest:
124
125
  if len(model_meta.env.pip_requirements) > 0:
125
126
  dependencies["pip"] = runtime_dict["dependencies"]["pip"]
126
127
 
128
+ if model_meta.env.artifact_repository_map:
129
+ dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
130
+
127
131
  manifest_dict = model_manifest_schema.ModelManifestDict(
128
132
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
129
133
  runtimes={
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, TypedDict, Union
3
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
24
 
24
25
 
25
26
  class ModelRuntimeDict(TypedDict):
@@ -98,7 +98,6 @@ class ModelMethod:
98
98
  def _get_method_arg_from_feature(
99
99
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
100
100
  ) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
101
- assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
102
101
  try:
103
102
  feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
104
103
  except ValueError as e: