snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +24 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +73 -31
  9. snowflake/ml/jobs/decorators.py +3 -0
  10. snowflake/ml/jobs/manager.py +5 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  13. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +8 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/metrics/ranking.py +3 -0
  52. snowflake/ml/modeling/metrics/regression.py +3 -0
  53. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  54. snowflake/ml/registry/_manager/model_manager.py +55 -7
  55. snowflake/ml/registry/registry.py +18 -0
  56. snowflake/ml/version.py +1 -1
  57. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import enum
1
2
  import json
2
3
  import os
3
4
  import pathlib
@@ -31,6 +32,12 @@ from snowflake.snowpark import dataframe, row, session
31
32
  from snowflake.snowpark._internal import utils as snowpark_utils
32
33
 
33
34
 
35
+ # An enum class to represent Create Or Alter Model SQL command.
36
+ class ModelAction(enum.Enum):
37
+ CREATE = "CREATE"
38
+ ALTER = "ALTER"
39
+
40
+
34
41
  class ServiceInfo(TypedDict):
35
42
  name: str
36
43
  status: str
@@ -92,7 +99,7 @@ class ModelOperator:
92
99
  and self._model_version_client == __value._model_version_client
93
100
  )
94
101
 
95
- def prepare_model_stage_path(
102
+ def prepare_model_temp_stage_path(
96
103
  self,
97
104
  *,
98
105
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -110,17 +117,28 @@ class ModelOperator:
110
117
  )
111
118
  return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
112
119
 
113
- def create_from_stage(
120
+ def get_model_version_stage_path(
121
+ self,
122
+ *,
123
+ database_name: Optional[sql_identifier.SqlIdentifier],
124
+ schema_name: Optional[sql_identifier.SqlIdentifier],
125
+ model_name: sql_identifier.SqlIdentifier,
126
+ version_name: sql_identifier.SqlIdentifier,
127
+ ) -> str:
128
+ return (
129
+ f"snow://model/{self._stage_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
130
+ f"/versions/{version_name}/"
131
+ )
132
+
133
+ def get_model_action_from_model_name_and_version(
114
134
  self,
115
- composed_model: model_composer.ModelComposer,
116
135
  *,
117
136
  database_name: Optional[sql_identifier.SqlIdentifier],
118
137
  schema_name: Optional[sql_identifier.SqlIdentifier],
119
138
  model_name: sql_identifier.SqlIdentifier,
120
139
  version_name: sql_identifier.SqlIdentifier,
121
140
  statement_params: Optional[Dict[str, Any]] = None,
122
- ) -> None:
123
- stage_path = str(composed_model.stage_path)
141
+ ) -> ModelAction:
124
142
  if self.validate_existence(
125
143
  database_name=database_name,
126
144
  schema_name=schema_name,
@@ -140,6 +158,79 @@ class ModelOperator:
140
158
  f" version {version_name} already existed."
141
159
  )
142
160
  else:
161
+ return ModelAction.ALTER
162
+ else:
163
+ return ModelAction.CREATE
164
+
165
+ def add_or_create_live_version(
166
+ self,
167
+ *,
168
+ database_name: Optional[sql_identifier.SqlIdentifier],
169
+ schema_name: Optional[sql_identifier.SqlIdentifier],
170
+ model_name: sql_identifier.SqlIdentifier,
171
+ version_name: sql_identifier.SqlIdentifier,
172
+ statement_params: Optional[Dict[str, Any]] = None,
173
+ ) -> None:
174
+ model_action = self.get_model_action_from_model_name_and_version(
175
+ database_name=database_name,
176
+ schema_name=schema_name,
177
+ model_name=model_name,
178
+ version_name=version_name,
179
+ statement_params=statement_params,
180
+ )
181
+ if model_action == ModelAction.CREATE:
182
+ self._model_version_client.create_live_version(
183
+ database_name=database_name,
184
+ schema_name=schema_name,
185
+ model_name=model_name,
186
+ version_name=version_name,
187
+ statement_params=statement_params,
188
+ )
189
+ elif model_action == ModelAction.ALTER:
190
+ self._model_version_client.add_live_version(
191
+ database_name=database_name,
192
+ schema_name=schema_name,
193
+ model_name=model_name,
194
+ version_name=version_name,
195
+ statement_params=statement_params,
196
+ )
197
+ else:
198
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
199
+
200
+ def create_from_stage(
201
+ self,
202
+ composed_model: model_composer.ModelComposer,
203
+ *,
204
+ database_name: Optional[sql_identifier.SqlIdentifier],
205
+ schema_name: Optional[sql_identifier.SqlIdentifier],
206
+ model_name: sql_identifier.SqlIdentifier,
207
+ version_name: sql_identifier.SqlIdentifier,
208
+ statement_params: Optional[Dict[str, Any]] = None,
209
+ use_live_commit: Optional[bool] = False,
210
+ ) -> None:
211
+
212
+ if use_live_commit:
213
+ # if the model version is live, we can only commit the version
214
+ self._model_version_client.commit_version(
215
+ database_name=database_name,
216
+ schema_name=schema_name,
217
+ model_name=model_name,
218
+ version_name=version_name,
219
+ statement_params=statement_params,
220
+ )
221
+ else:
222
+ stage_path = str(composed_model.stage_path)
223
+ # if the model version is not live,
224
+ # find whether the model exists and whether the version exists
225
+ # and then decide whether to create or alter the model
226
+ model_action = self.get_model_action_from_model_name_and_version(
227
+ database_name=database_name,
228
+ schema_name=schema_name,
229
+ model_name=model_name,
230
+ version_name=version_name,
231
+ statement_params=statement_params,
232
+ )
233
+ if model_action == ModelAction.ALTER:
143
234
  self._model_version_client.add_version_from_stage(
144
235
  database_name=database_name,
145
236
  schema_name=schema_name,
@@ -148,15 +239,17 @@ class ModelOperator:
148
239
  version_name=version_name,
149
240
  statement_params=statement_params,
150
241
  )
151
- else:
152
- self._model_version_client.create_from_stage(
153
- database_name=database_name,
154
- schema_name=schema_name,
155
- stage_path=stage_path,
156
- model_name=model_name,
157
- version_name=version_name,
158
- statement_params=statement_params,
159
- )
242
+ elif model_action == ModelAction.CREATE:
243
+ self._model_version_client.create_from_stage(
244
+ database_name=database_name,
245
+ schema_name=schema_name,
246
+ stage_path=stage_path,
247
+ model_name=model_name,
248
+ version_name=version_name,
249
+ statement_params=statement_params,
250
+ )
251
+ else:
252
+ raise AssertionError(f"The model_action is {model_action}. Expected CREATE or ALTER.")
160
253
 
161
254
  def create_from_model_version(
162
255
  self,
@@ -100,7 +100,7 @@ class ServiceOperator:
100
100
  max_instances: int,
101
101
  cpu_requests: Optional[str],
102
102
  memory_requests: Optional[str],
103
- gpu_requests: Optional[str],
103
+ gpu_requests: Optional[Union[int, str]],
104
104
  num_workers: Optional[int],
105
105
  max_batch_rows: Optional[int],
106
106
  force_rebuild: bool,
@@ -1,5 +1,5 @@
1
1
  import pathlib
2
- from typing import List, Optional
2
+ from typing import List, Optional, Union
3
3
 
4
4
  import yaml
5
5
 
@@ -38,7 +38,7 @@ class ModelDeploymentSpec:
38
38
  max_instances: int,
39
39
  cpu: Optional[str],
40
40
  memory: Optional[str],
41
- gpu: Optional[str],
41
+ gpu: Optional[Union[str, int]],
42
42
  num_workers: Optional[int],
43
43
  max_batch_rows: Optional[int],
44
44
  force_rebuild: bool,
@@ -86,7 +86,11 @@ class ModelDeploymentSpec:
86
86
  service_dict["memory"] = memory
87
87
 
88
88
  if gpu:
89
- service_dict["gpu"] = gpu
89
+ if isinstance(gpu, int):
90
+ gpu_str = str(gpu)
91
+ else:
92
+ gpu_str = gpu
93
+ service_dict["gpu"] = gpu_str
90
94
 
91
95
  if num_workers:
92
96
  service_dict["num_workers"] = num_workers
@@ -71,6 +71,64 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
71
71
  statement_params=statement_params,
72
72
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
73
73
 
74
+ def create_live_version(
75
+ self,
76
+ *,
77
+ database_name: Optional[sql_identifier.SqlIdentifier],
78
+ schema_name: Optional[sql_identifier.SqlIdentifier],
79
+ model_name: sql_identifier.SqlIdentifier,
80
+ version_name: sql_identifier.SqlIdentifier,
81
+ statement_params: Optional[Dict[str, Any]] = None,
82
+ ) -> None:
83
+ sql = (
84
+ f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
85
+ f" WITH LIVE VERSION {version_name.identifier()}"
86
+ )
87
+ query_result_checker.SqlResultValidator(
88
+ self._session,
89
+ sql,
90
+ statement_params=statement_params,
91
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
92
+
93
+ def add_live_version(
94
+ self,
95
+ *,
96
+ database_name: Optional[sql_identifier.SqlIdentifier],
97
+ schema_name: Optional[sql_identifier.SqlIdentifier],
98
+ model_name: sql_identifier.SqlIdentifier,
99
+ version_name: sql_identifier.SqlIdentifier,
100
+ statement_params: Optional[Dict[str, Any]] = None,
101
+ ) -> None:
102
+ sql = (
103
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
104
+ f" ADD LIVE VERSION {version_name.identifier()}"
105
+ )
106
+ query_result_checker.SqlResultValidator(
107
+ self._session,
108
+ sql,
109
+ statement_params=statement_params,
110
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
+
112
+ def commit_version(
113
+ self,
114
+ *,
115
+ database_name: Optional[sql_identifier.SqlIdentifier],
116
+ schema_name: Optional[sql_identifier.SqlIdentifier],
117
+ model_name: sql_identifier.SqlIdentifier,
118
+ version_name: sql_identifier.SqlIdentifier,
119
+ statement_params: Optional[Dict[str, Any]] = None,
120
+ ) -> None:
121
+ sql = (
122
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
123
+ f" COMMIT VERSION {version_name.identifier()}"
124
+ )
125
+
126
+ query_result_checker.SqlResultValidator(
127
+ self._session,
128
+ sql,
129
+ statement_params=statement_params,
130
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
131
+
74
132
  # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
75
133
  def add_version_from_stage(
76
134
  self,
@@ -1,7 +1,7 @@
1
1
  import enum
2
2
  import json
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  from snowflake import snowpark
7
7
  from snowflake.ml._internal import platform_capabilities
@@ -11,6 +11,7 @@ from snowflake.ml._internal.utils import (
11
11
  sql_identifier,
12
12
  )
13
13
  from snowflake.ml.model._client.sql import _base
14
+ from snowflake.ml.model._model_composer.model_method import constants
14
15
  from snowflake.snowpark import dataframe, functions as F, row, types as spt
15
16
  from snowflake.snowpark._internal import utils as snowpark_utils
16
17
 
@@ -41,7 +42,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
41
42
  image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
42
43
  image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
43
44
  image_repo_name: sql_identifier.SqlIdentifier,
44
- gpu: Optional[str],
45
+ gpu: Optional[Union[str, int]],
45
46
  force_rebuild: bool,
46
47
  external_access_integration: sql_identifier.SqlIdentifier,
47
48
  statement_params: Optional[Dict[str, Any]] = None,
@@ -121,6 +122,11 @@ class ServiceSQLClient(_base._BaseSQLClient):
121
122
  args_sql_list.append(input_arg_value)
122
123
  args_sql = ", ".join(args_sql_list)
123
124
 
125
+ wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
126
+ if wide_input:
127
+ input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
128
+ args_sql = f"object_construct_keep_null({input_args_sql})"
129
+
124
130
  if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
125
131
  fully_qualified_service_name = self.fully_qualified_object_name(
126
132
  actual_database_name, actual_schema_name, service_name
@@ -1,8 +1,10 @@
1
1
  import pathlib
2
2
  import tempfile
3
3
  import uuid
4
+ import warnings
4
5
  from types import ModuleType
5
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Dict, List, Optional, Union
7
+ from urllib import parse
6
8
 
7
9
  from absl import logging
8
10
  from packaging import requirements
@@ -44,7 +46,13 @@ class ModelComposer:
44
46
  statement_params: Optional[Dict[str, Any]] = None,
45
47
  ) -> None:
46
48
  self.session = session
47
- self.stage_path = pathlib.PurePosixPath(stage_path)
49
+ self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
50
+ if stage_path.startswith("snow://"):
51
+ # The stage path is a snowflake internal stage path
52
+ self.stage_path = parse.urlparse(stage_path)
53
+ else:
54
+ # The stage path is a user stage path
55
+ self.stage_path = pathlib.PurePosixPath(stage_path)
48
56
 
49
57
  self._workspace = tempfile.TemporaryDirectory()
50
58
  self._packager_workspace = tempfile.TemporaryDirectory()
@@ -70,7 +78,20 @@ class ModelComposer:
70
78
 
71
79
  @property
72
80
  def model_stage_path(self) -> str:
73
- return (self.stage_path / self.model_file_rel_path).as_posix()
81
+ if isinstance(self.stage_path, parse.ParseResult):
82
+ model_file_path = (pathlib.PosixPath(self.stage_path.path) / self.model_file_rel_path).as_posix()
83
+ new_url = parse.ParseResult(
84
+ scheme=self.stage_path.scheme,
85
+ netloc=self.stage_path.netloc,
86
+ path=str(model_file_path),
87
+ params=self.stage_path.params,
88
+ query=self.stage_path.query,
89
+ fragment=self.stage_path.fragment,
90
+ )
91
+ return str(parse.urlunparse(new_url))
92
+ else:
93
+ assert isinstance(self.stage_path, pathlib.PurePosixPath)
94
+ return (self.stage_path / self.model_file_rel_path).as_posix()
74
95
 
75
96
  @property
76
97
  def model_local_path(self) -> str:
@@ -86,6 +107,7 @@ class ModelComposer:
86
107
  metadata: Optional[Dict[str, str]] = None,
87
108
  conda_dependencies: Optional[List[str]] = None,
88
109
  pip_requirements: Optional[List[str]] = None,
110
+ artifact_repository_map: Optional[Dict[str, str]] = None,
89
111
  target_platforms: Optional[List[model_types.TargetPlatform]] = None,
90
112
  python_version: Optional[str] = None,
91
113
  user_files: Optional[Dict[str, List[str]]] = None,
@@ -94,8 +116,32 @@ class ModelComposer:
94
116
  task: model_types.Task = model_types.Task.UNKNOWN,
95
117
  options: Optional[model_types.ModelSaveOption] = None,
96
118
  ) -> model_meta.ModelMetadata:
119
+ # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
120
+ conda_dep_dict = env_utils.validate_conda_dependency_string_list(
121
+ conda_dependencies if conda_dependencies else []
122
+ )
123
+ is_warehouse_runnable = (
124
+ not conda_dep_dict
125
+ or all(
126
+ chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
127
+ for chan in conda_dep_dict
128
+ )
129
+ ) and (not pip_requirements)
130
+ disable_explainability = (
131
+ target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
132
+ ) or (not is_warehouse_runnable)
133
+
134
+ if disable_explainability and options and options.get("enable_explainability", False):
135
+ warnings.warn(
136
+ ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
137
+ category=UserWarning,
138
+ stacklevel=2,
139
+ )
140
+
97
141
  if not options:
98
142
  options = model_types.BaseModelSaveOption()
143
+ if disable_explainability:
144
+ options["enable_explainability"] = False
99
145
 
100
146
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
101
147
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
@@ -120,6 +166,7 @@ class ModelComposer:
120
166
  metadata=metadata,
121
167
  conda_dependencies=conda_dependencies,
122
168
  pip_requirements=pip_requirements,
169
+ artifact_repository_map=artifact_repository_map,
123
170
  python_version=python_version,
124
171
  ext_modules=ext_modules,
125
172
  code_paths=code_paths,
@@ -78,6 +78,7 @@ class ModelManifest:
78
78
  logger.info("Relaxing version constraints for dependencies in the model.")
79
79
  logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
80
80
  logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
81
+ logger.info(f"artifact_repository_map: {runtime_to_use.runtime_env.artifact_repository_map}")
81
82
  runtime_dict = runtime_to_use.save(
82
83
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
83
84
  )
@@ -124,6 +125,9 @@ class ModelManifest:
124
125
  if len(model_meta.env.pip_requirements) > 0:
125
126
  dependencies["pip"] = runtime_dict["dependencies"]["pip"]
126
127
 
128
+ if model_meta.env.artifact_repository_map:
129
+ dependencies["artifact_repository_map"] = runtime_dict["dependencies"]["artifact_repository_map"]
130
+
127
131
  manifest_dict = model_manifest_schema.ModelManifestDict(
128
132
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
129
133
  runtimes={
@@ -1,6 +1,6 @@
1
1
  # This files contains schema definition of what will be written into MANIFEST.yml
2
2
  import enum
3
- from typing import Any, Dict, List, Literal, TypedDict, Union
3
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
4
4
 
5
5
  from typing_extensions import NotRequired, Required
6
6
 
@@ -20,6 +20,7 @@ class ModelMethodFunctionTypes(enum.Enum):
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
21
  conda: NotRequired[str]
22
22
  pip: NotRequired[str]
23
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
23
24
 
24
25
 
25
26
  class ModelRuntimeDict(TypedDict):
@@ -98,7 +98,6 @@ class ModelMethod:
98
98
  def _get_method_arg_from_feature(
99
99
  feature: model_signature.BaseFeatureSpec, case_sensitive: bool = False
100
100
  ) -> model_manifest_schema.ModelMethodSignatureFieldWithName:
101
- assert isinstance(feature, model_signature.FeatureSpec), "FeatureGroupSpec is not supported."
102
101
  try:
103
102
  feature_name = sql_identifier.SqlIdentifier(feature.name, case_sensitive=case_sensitive)
104
103
  except ValueError as e:
@@ -3,7 +3,7 @@ import itertools
3
3
  import os
4
4
  import pathlib
5
5
  import warnings
6
- from typing import DefaultDict, List, Optional
6
+ from typing import DefaultDict, Dict, List, Optional
7
7
 
8
8
  from packaging import requirements, version
9
9
 
@@ -36,6 +36,7 @@ class ModelEnv:
36
36
  pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
37
37
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
38
38
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
39
+ self.artifact_repository_map: Optional[Dict[str, str]] = None
39
40
  self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
40
41
  self._pip_requirements: List[requirements.Requirement] = []
41
42
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
@@ -345,6 +346,7 @@ class ModelEnv:
345
346
  def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
346
347
  self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
347
348
  self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
349
+ self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
348
350
 
349
351
  self.load_from_conda_file(base_dir / self.conda_env_rel_path)
350
352
  self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
@@ -373,6 +375,7 @@ class ModelEnv:
373
375
  return {
374
376
  "conda": self.conda_env_rel_path.as_posix(),
375
377
  "pip": self.pip_requirements_rel_path.as_posix(),
378
+ "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
376
379
  "python_version": self.python_version,
377
380
  "cuda_version": self.cuda_version,
378
381
  "snowpark_ml_version": self.snowpark_ml_version,
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
30
30
  model_meta as model_meta_api,
31
31
  model_meta_schema,
32
32
  )
33
- from snowflake.ml.model._signatures import (
34
- builtins_handler,
35
- utils as model_signature_utils,
36
- )
33
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
34
  from snowflake.ml.model.models import huggingface_pipeline
38
35
  from snowflake.snowpark._internal import utils as snowpark_utils
39
36
 
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
66
63
  return []
67
64
 
68
65
 
69
- class NumpyEncoder(json.JSONEncoder):
70
- # This is a JSON encoder class to ensure the output from Huggingface pipeline is JSON serializable.
71
- # What it covers is numpy object.
72
- def default(self, z: object) -> object:
73
- if isinstance(z, np.number):
74
- if np.can_cast(z, np.int64, casting="safe"):
75
- return int(z)
76
- elif np.can_cast(z, np.float64, casting="safe"):
77
- return z.astype(np.float64)
78
- return super().default(z)
66
+ def sanitize_output(data: Any) -> Any:
67
+ if isinstance(data, np.number):
68
+ return data.item()
69
+ if isinstance(data, np.ndarray):
70
+ return sanitize_output(data.tolist())
71
+ if isinstance(data, list):
72
+ return [sanitize_output(x) for x in data]
73
+ if isinstance(data, dict):
74
+ return {k: sanitize_output(v) for k, v in data.items()}
75
+ return data
79
76
 
80
77
 
81
78
  @final
@@ -410,13 +407,17 @@ class HuggingFacePipelineHandler(
410
407
  )
411
408
  for conv_data in X.to_dict("records")
412
409
  ]
413
- elif len(signature.inputs) == 1:
414
- input_data = X.to_dict("list")[signature.inputs[0].name]
415
410
  else:
416
411
  if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
417
412
  X["table"] = X["table"].apply(json.loads)
418
413
 
419
- input_data = X.to_dict("records")
414
+ # Most pipelines if it is expecting more than one arguments,
415
+ # it is expecting a list of dict, where each dict has keys corresponding to the argument.
416
+ if len(signature.inputs) > 1:
417
+ input_data = X.to_dict("records")
418
+ # If it is only expecting one argument, Then it is expecting a list of something.
419
+ else:
420
+ input_data = X[signature.inputs[0].name].to_list()
420
421
  temp_res = getattr(raw_model, target_method)(input_data)
421
422
 
422
423
  # Some huggingface pipeline will omit the outer list when there is only 1 input.
@@ -439,7 +440,6 @@ class HuggingFacePipelineHandler(
439
440
  ),
440
441
  )
441
442
  and X.shape[0] == 1
442
- and isinstance(temp_res[0], dict)
443
443
  )
444
444
  ):
445
445
  temp_res = [temp_res]
@@ -453,14 +453,18 @@ class HuggingFacePipelineHandler(
453
453
  temp_res = [[conv.generated_responses] for conv in temp_res]
454
454
 
455
455
  # To concat those who outputs a list with one input.
456
- if builtins_handler.ListOfBuiltinHandler.can_handle(temp_res):
457
- res = builtins_handler.ListOfBuiltinHandler.convert_to_df(temp_res)
458
- elif isinstance(temp_res[0], dict):
456
+ if isinstance(temp_res[0], list):
457
+ if isinstance(temp_res[0][0], dict):
458
+ res = pd.DataFrame({0: temp_res})
459
+ else:
460
+ res = pd.DataFrame(temp_res)
461
+ else:
459
462
  res = pd.DataFrame(temp_res)
460
- elif isinstance(temp_res[0], list):
461
- res = pd.DataFrame([json.dumps(output, cls=NumpyEncoder) for output in temp_res])
463
+
464
+ if hasattr(res, "map"):
465
+ res = res.map(sanitize_output)
462
466
  else:
463
- raise ValueError(f"Cannot parse output {temp_res} from pipeline object")
467
+ res = res.applymap(sanitize_output)
464
468
 
465
469
  return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
466
470
 
@@ -191,11 +191,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
191
191
  signature: model_signature.ModelSignature,
192
192
  target_method: str,
193
193
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
194
- dtype_map = {
195
- spec.name: spec.as_dtype(force_numpy_dtype=True)
196
- for spec in signature.inputs
197
- if isinstance(spec, model_signature.FeatureSpec)
198
- }
194
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
199
195
 
200
196
  @custom_model.inference_api
201
197
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame: