snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +25 -1
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +95 -31
  9. snowflake/ml/jobs/decorators.py +7 -0
  10. snowflake/ml/jobs/manager.py +20 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +113 -17
  13. snowflake/ml/model/_client/ops/service_ops.py +16 -5
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +10 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  52. snowflake/ml/modeling/metrics/ranking.py +3 -0
  53. snowflake/ml/modeling/metrics/regression.py +3 -0
  54. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  55. snowflake/ml/registry/_manager/model_manager.py +55 -7
  56. snowflake/ml/registry/registry.py +59 -1
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
@@ -48,6 +48,7 @@ def create_model_metadata(
48
48
  ext_modules: Optional[List[ModuleType]] = None,
49
49
  conda_dependencies: Optional[List[str]] = None,
50
50
  pip_requirements: Optional[List[str]] = None,
51
+ artifact_repository_map: Optional[Dict[str, str]] = None,
51
52
  python_version: Optional[str] = None,
52
53
  task: model_types.Task = model_types.Task.UNKNOWN,
53
54
  **kwargs: Any,
@@ -67,6 +68,7 @@ def create_model_metadata(
67
68
  ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
68
69
  conda_dependencies: List of conda requirements for running the model. Defaults to None.
69
70
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
71
+ artifact_repository_map: A dict mapping from package channel to artifact repository name.
70
72
  python_version: A string of python version where model is run. Used for user override. If specified as None,
71
73
  current version would be captured. Defaults to None.
72
74
  task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
@@ -102,6 +104,7 @@ def create_model_metadata(
102
104
  env = _create_env_for_model_metadata(
103
105
  conda_dependencies=conda_dependencies,
104
106
  pip_requirements=pip_requirements,
107
+ artifact_repository_map=artifact_repository_map,
105
108
  python_version=python_version,
106
109
  embed_local_ml_library=embed_local_ml_library,
107
110
  )
@@ -151,6 +154,7 @@ def _create_env_for_model_metadata(
151
154
  *,
152
155
  conda_dependencies: Optional[List[str]] = None,
153
156
  pip_requirements: Optional[List[str]] = None,
157
+ artifact_repository_map: Optional[Dict[str, str]] = None,
154
158
  python_version: Optional[str] = None,
155
159
  embed_local_ml_library: bool = False,
156
160
  ) -> model_env.ModelEnv:
@@ -159,6 +163,7 @@ def _create_env_for_model_metadata(
159
163
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
160
164
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
161
165
  env.pip_requirements = pip_requirements # type: ignore[assignment]
166
+ env.artifact_repository_map = artifact_repository_map
162
167
  env.python_version = python_version # type: ignore[assignment]
163
168
  env.snowpark_ml_version = snowml_env.VERSION
164
169
 
@@ -331,7 +336,6 @@ class ModelMetadata:
331
336
  "function_properties": self.function_properties,
332
337
  }
333
338
  )
334
-
335
339
  with open(model_yaml_path, "w", encoding="utf-8") as out:
336
340
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
337
341
  yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
@@ -18,6 +18,7 @@ class FunctionProperties(Enum):
18
18
  class ModelRuntimeDependenciesDict(TypedDict):
19
19
  conda: Required[str]
20
20
  pip: Required[str]
21
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
21
22
 
22
23
 
23
24
  class ModelRuntimeDict(TypedDict):
@@ -28,6 +29,7 @@ class ModelRuntimeDict(TypedDict):
28
29
  class ModelEnvDict(TypedDict):
29
30
  conda: Required[str]
30
31
  pip: Required[str]
32
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
31
33
  python_version: Required[str]
32
34
  cuda_version: NotRequired[Optional[str]]
33
35
  snowpark_ml_version: Required[str]
@@ -61,8 +63,17 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
61
63
  xgb_estimator_type: Required[str]
62
64
 
63
65
 
66
+ class PyTorchModelBlobOptions(BaseModelBlobOptions):
67
+ multiple_inputs: Required[bool]
68
+
69
+
70
+ class TorchScriptModelBlobOptions(BaseModelBlobOptions):
71
+ multiple_inputs: Required[bool]
72
+
73
+
64
74
  class TensorflowModelBlobOptions(BaseModelBlobOptions):
65
75
  save_format: Required[str]
76
+ multiple_inputs: Required[bool]
66
77
 
67
78
 
68
79
  class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
@@ -71,9 +82,12 @@ class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
71
82
 
72
83
  ModelBlobOptions = Union[
73
84
  BaseModelBlobOptions,
85
+ CatBoostModelBlobOptions,
74
86
  HuggingFacePipelineModelBlobOptions,
75
87
  MLFlowModelBlobOptions,
76
88
  XgboostModelBlobOptions,
89
+ PyTorchModelBlobOptions,
90
+ TorchScriptModelBlobOptions,
77
91
  TensorflowModelBlobOptions,
78
92
  SentenceTransformersModelBlobOptions,
79
93
  ]
@@ -43,13 +43,13 @@ class ModelPackager:
43
43
  metadata: Optional[Dict[str, str]] = None,
44
44
  conda_dependencies: Optional[List[str]] = None,
45
45
  pip_requirements: Optional[List[str]] = None,
46
+ artifact_repository_map: Optional[Dict[str, str]] = None,
46
47
  python_version: Optional[str] = None,
47
48
  ext_modules: Optional[List[ModuleType]] = None,
48
49
  code_paths: Optional[List[str]] = None,
49
- options: Optional[model_types.ModelSaveOption] = None,
50
+ options: model_types.ModelSaveOption,
50
51
  task: model_types.Task = model_types.Task.UNKNOWN,
51
52
  ) -> model_meta.ModelMetadata:
52
-
53
53
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
54
54
  raise snowml_exceptions.SnowflakeMLException(
55
55
  error_code=error_codes.INVALID_ARGUMENT,
@@ -58,9 +58,6 @@ class ModelPackager:
58
58
  ),
59
59
  )
60
60
 
61
- if not options:
62
- options = model_types.BaseModelSaveOption()
63
-
64
61
  handler = model_handler.find_handler(model)
65
62
  if handler is None:
66
63
  raise snowml_exceptions.SnowflakeMLException(
@@ -77,6 +74,7 @@ class ModelPackager:
77
74
  ext_modules=ext_modules,
78
75
  conda_dependencies=conda_dependencies,
79
76
  pip_requirements=pip_requirements,
77
+ artifact_repository_map=artifact_repository_map,
80
78
  python_version=python_version,
81
79
  task=task,
82
80
  **options,
@@ -1,2 +1 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
2
- ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'keras>=2.0.0,<4', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<3', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.7.0,<3', 'sentencepiece>=0.1.95,<0.2.0', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.17.0,<3', 'tokenizers>=0.15.1,<1', 'torchdata>=0.4,<1', 'transformers>=4.37.2,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
1
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0,<3', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.12.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
@@ -98,6 +98,9 @@ class ModelRuntime:
98
98
  dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
99
99
  conda=env_dict["conda"],
100
100
  pip=env_dict["pip"],
101
+ artifact_repository_map=env_dict["artifact_repository_map"]
102
+ if env_dict.get("artifact_repository_map") is not None
103
+ else {},
101
104
  ),
102
105
  )
103
106
 
@@ -112,6 +115,7 @@ class ModelRuntime:
112
115
  env.python_version = meta_env.python_version
113
116
  env.cuda_version = meta_env.cuda_version
114
117
  env.snowpark_ml_version = meta_env.snowpark_ml_version
118
+ env.artifact_repository_map = meta_env.artifact_repository_map
115
119
 
116
120
  conda_env_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["conda"])
117
121
  pip_requirements_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["pip"])
@@ -14,21 +14,32 @@ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
14
14
 
15
15
 
16
16
  class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
17
+ @staticmethod
18
+ def _can_handle_element(
19
+ element: model_types._SupportedBuiltins,
20
+ ) -> TypeGuard[model_types._SupportedBuiltins]:
21
+ if isinstance(element, abc.Sequence) and not isinstance(element, str):
22
+ for sub_element in element:
23
+ if not ListOfBuiltinHandler._can_handle_element(sub_element):
24
+ return False
25
+ return True
26
+ elif isinstance(element, abc.Mapping):
27
+ for key, value in element.items():
28
+ if not isinstance(key, str):
29
+ return False
30
+ if not ListOfBuiltinHandler._can_handle_element(value):
31
+ return False
32
+ return True
33
+ else:
34
+ return isinstance(element, (int, float, bool, str, datetime.datetime))
35
+
17
36
  @staticmethod
18
37
  def can_handle(data: model_types.SupportedDataType) -> TypeGuard[model_types._SupportedBuiltinsList]:
19
38
  if not isinstance(data, abc.Sequence) or isinstance(data, str):
20
39
  return False
21
40
  if len(data) == 0:
22
41
  return False
23
- can_handle = True
24
- for element in data:
25
- # String is a Sequence but we take them as an whole
26
- if isinstance(element, abc.Sequence) and not isinstance(element, str):
27
- can_handle = ListOfBuiltinHandler.can_handle(element)
28
- elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
29
- can_handle = False
30
- break
31
- return can_handle
42
+ return ListOfBuiltinHandler._can_handle_element(data)
32
43
 
33
44
  @staticmethod
34
45
  def count(data: model_types._SupportedBuiltinsList) -> int:
@@ -199,9 +199,16 @@ class DataType(Enum):
199
199
  class BaseFeatureSpec(ABC):
200
200
  """Abstract Class for specification of a feature."""
201
201
 
202
- def __init__(self, name: str) -> None:
202
+ def __init__(self, name: str, shape: Optional[Tuple[int, ...]]) -> None:
203
203
  self._name = name
204
204
 
205
+ if shape and not isinstance(shape, tuple):
206
+ raise snowml_exceptions.SnowflakeMLException(
207
+ error_code=error_codes.INVALID_TYPE,
208
+ original_exception=TypeError("Shape should be a tuple if presented."),
209
+ )
210
+ self._shape = shape
211
+
205
212
  @final
206
213
  @property
207
214
  def name(self) -> str:
@@ -213,6 +220,11 @@ class BaseFeatureSpec(ABC):
213
220
  """Convert to corresponding Snowpark Type."""
214
221
  pass
215
222
 
223
+ @abstractmethod
224
+ def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
225
+ """Convert to corresponding local Type."""
226
+ pass
227
+
216
228
  @abstractmethod
217
229
  def to_dict(self) -> Dict[str, Any]:
218
230
  """Serialization"""
@@ -256,7 +268,7 @@ class FeatureSpec(BaseFeatureSpec):
256
268
  SnowflakeMLException: TypeError: When the dtype input type is incorrect.
257
269
  SnowflakeMLException: TypeError: When the shape input type is incorrect.
258
270
  """
259
- super().__init__(name=name)
271
+ super().__init__(name=name, shape=shape)
260
272
 
261
273
  if not isinstance(dtype, DataType):
262
274
  raise snowml_exceptions.SnowflakeMLException(
@@ -265,13 +277,6 @@ class FeatureSpec(BaseFeatureSpec):
265
277
  )
266
278
  self._dtype = dtype
267
279
 
268
- if shape and not isinstance(shape, tuple):
269
- raise snowml_exceptions.SnowflakeMLException(
270
- error_code=error_codes.INVALID_TYPE,
271
- original_exception=TypeError("Shape should be a tuple if presented."),
272
- )
273
- self._shape = shape
274
-
275
280
  self._nullable = nullable
276
281
 
277
282
  def as_snowpark_type(self) -> spt.DataType:
@@ -386,15 +391,23 @@ class FeatureSpec(BaseFeatureSpec):
386
391
  class FeatureGroupSpec(BaseFeatureSpec):
387
392
  """Specification of a group of features in Snowflake native model packaging."""
388
393
 
389
- def __init__(self, name: str, specs: List[FeatureSpec]) -> None:
394
+ def __init__(self, name: str, specs: List[BaseFeatureSpec], shape: Optional[Tuple[int, ...]] = None) -> None:
390
395
  """Initialize a feature group.
391
396
 
392
397
  Args:
393
398
  name: Name of the feature group.
394
399
  specs: A list of feature specifications that composes the group. All children feature specs have to have
395
400
  name. And all of them should have the same type.
401
+ shape: Used to represent scalar feature, 1-d feature list,
402
+ or n-d tensor. Use -1 to represent variable length. Defaults to None.
403
+
404
+ Examples:
405
+ - None: scalar
406
+ - (2,): 1d list with a fixed length of 2.
407
+ - (-1,): 1d list with variable length, used for ragged tensor representation.
408
+ - (d1, d2, d3): 3d tensor.
396
409
  """
397
- super().__init__(name=name)
410
+ super().__init__(name=name, shape=shape)
398
411
  self._specs = specs
399
412
  self._validate()
400
413
 
@@ -409,47 +422,52 @@ class FeatureGroupSpec(BaseFeatureSpec):
409
422
  error_code=error_codes.INVALID_ARGUMENT,
410
423
  original_exception=ValueError("All children feature specs have to have name."),
411
424
  )
412
- if not (all(s._shape is None for s in self._specs) or all(s._shape is not None for s in self._specs)):
413
- raise snowml_exceptions.SnowflakeMLException(
414
- error_code=error_codes.INVALID_ARGUMENT,
415
- original_exception=ValueError("All children feature specs have to have same shape."),
416
- )
417
- first_type = self._specs[0]._dtype
418
- if not all(s._dtype == first_type for s in self._specs):
419
- raise snowml_exceptions.SnowflakeMLException(
420
- error_code=error_codes.INVALID_ARGUMENT,
421
- original_exception=ValueError("All children feature specs have to have same type."),
422
- )
423
425
 
424
426
  def as_snowpark_type(self) -> spt.DataType:
425
- first_type = self._specs[0].as_snowpark_type()
426
- return spt.MapType(spt.StringType(), first_type)
427
+ spt_type = spt.StructType(
428
+ fields=[
429
+ spt.StructField(
430
+ s._name, datatype=s.as_snowpark_type(), nullable=s._nullable if isinstance(s, FeatureSpec) else True
431
+ )
432
+ for s in self._specs
433
+ ]
434
+ )
435
+ if not self._shape:
436
+ return spt_type
437
+ return spt.ArrayType(spt_type)
427
438
 
428
439
  def __eq__(self, other: object) -> bool:
429
440
  if isinstance(other, FeatureGroupSpec):
430
- return self._specs == other._specs
441
+ return self._name == other._name and self._specs == other._specs and self._shape == other._shape
431
442
  else:
432
443
  return False
433
444
 
434
445
  def __repr__(self) -> str:
435
446
  spec_strs = ",\n\t\t".join(repr(spec) for spec in self._specs)
447
+ shape_str = f", shape={repr(self._shape)}" if self._shape else ""
436
448
  return textwrap.dedent(
437
449
  f"""FeatureGroupSpec(
438
450
  name={repr(self._name)},
439
451
  specs=[
440
452
  {spec_strs}
441
- ]
453
+ ]{shape_str}
442
454
  )
443
455
  """
444
456
  )
445
457
 
458
+ def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
459
+ return np.object_
460
+
446
461
  def to_dict(self) -> Dict[str, Any]:
447
462
  """Serialize the feature group into a dict.
448
463
 
449
464
  Returns:
450
465
  A dict that serializes the feature group.
451
466
  """
452
- return {"feature_group": {"name": self._name, "specs": [s.to_dict() for s in self._specs]}}
467
+ base_dict: Dict[str, Any] = {"name": self._name, "specs": [s.to_dict() for s in self._specs]}
468
+ if self._shape is not None:
469
+ base_dict["shape"] = self._shape
470
+ return base_dict
453
471
 
454
472
  @classmethod
455
473
  def from_dict(cls, input_dict: Dict[str, Any]) -> "FeatureGroupSpec":
@@ -462,10 +480,13 @@ class FeatureGroupSpec(BaseFeatureSpec):
462
480
  A feature group instance deserialized and created from the dict.
463
481
  """
464
482
  specs = []
465
- for e in input_dict["feature_group"]["specs"]:
466
- spec = FeatureSpec.from_dict(e)
483
+ for e in input_dict["specs"]:
484
+ spec = FeatureGroupSpec.from_dict(e) if "specs" in e else FeatureSpec.from_dict(e)
467
485
  specs.append(spec)
468
- return FeatureGroupSpec(name=input_dict["feature_group"]["name"], specs=specs)
486
+ shape = input_dict.get("shape", None)
487
+ if shape:
488
+ shape = tuple(shape)
489
+ return FeatureGroupSpec(name=input_dict["name"], specs=specs, shape=shape)
469
490
 
470
491
 
471
492
  class ModelSignature:
@@ -525,7 +546,7 @@ class ModelSignature:
525
546
  sig_inputs = loaded["inputs"]
526
547
 
527
548
  deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
528
- FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec)
549
+ FeatureGroupSpec.from_dict(sig_spec) if "specs" in sig_spec else FeatureSpec.from_dict(sig_spec)
529
550
  )
530
551
 
531
552
  return ModelSignature(
@@ -0,0 +1,98 @@
1
+ from typing import TYPE_CHECKING, List, Literal, Optional, Sequence
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from typing_extensions import TypeGuard
6
+
7
+ from snowflake.ml._internal import type_utils
8
+ from snowflake.ml._internal.exceptions import (
9
+ error_codes,
10
+ exceptions as snowml_exceptions,
11
+ )
12
+ from snowflake.ml.model import type_hints as model_types
13
+ from snowflake.ml.model._signatures import base_handler, core
14
+
15
+ if TYPE_CHECKING:
16
+ import xgboost
17
+
18
+
19
+ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
20
+ @staticmethod
21
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard["xgboost.DMatrix"]:
22
+ return type_utils.LazyType("xgboost.DMatrix").isinstance(data)
23
+
24
+ @staticmethod
25
+ def count(data: "xgboost.DMatrix") -> int:
26
+ return data.num_row()
27
+
28
+ @staticmethod
29
+ def truncate(data: "xgboost.DMatrix", length: int) -> "xgboost.DMatrix":
30
+
31
+ num_rows = min(
32
+ XGBoostDMatrixHandler.count(data),
33
+ length,
34
+ )
35
+ return data.slice(list(range(num_rows)))
36
+
37
+ @staticmethod
38
+ def validate(data: "xgboost.DMatrix") -> None:
39
+ if data.num_row() == 0:
40
+ raise snowml_exceptions.SnowflakeMLException(
41
+ error_code=error_codes.INVALID_DATA,
42
+ original_exception=ValueError("Data Validation Error: Empty data is found."),
43
+ )
44
+
45
+ @staticmethod
46
+ def infer_signature(data: "xgboost.DMatrix", role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]:
47
+ feature_prefix = f"{XGBoostDMatrixHandler.FEATURE_PREFIX}_"
48
+ features: List[core.BaseFeatureSpec] = []
49
+ role_prefix = (
50
+ XGBoostDMatrixHandler.INPUT_PREFIX if role == "input" else XGBoostDMatrixHandler.OUTPUT_PREFIX
51
+ ) + "_"
52
+
53
+ feature_names = data.feature_names or []
54
+ feature_types = data.feature_types or []
55
+
56
+ for i, (feature_name, dtype) in enumerate(zip(feature_names, feature_types)):
57
+ if not feature_name:
58
+ ft_name = f"{role_prefix}{feature_prefix}{i}"
59
+ else:
60
+ ft_name = feature_name
61
+
62
+ features.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(np.dtype(dtype)), name=ft_name))
63
+ return features
64
+
65
+ @staticmethod
66
+ def convert_to_df(data: "xgboost.DMatrix", ensure_serializable: bool = True) -> pd.DataFrame:
67
+ df = pd.DataFrame(data.get_data().toarray(), columns=data.feature_names)
68
+
69
+ feature_types = data.feature_types or []
70
+
71
+ if feature_types:
72
+ for idx, col in enumerate(df.columns):
73
+ dtype = feature_types[idx]
74
+ df[col] = df[col].astype(dtype)
75
+
76
+ return df
77
+
78
+ @staticmethod
79
+ def convert_from_df(
80
+ df: pd.DataFrame, features: Optional[Sequence[core.BaseFeatureSpec]] = None
81
+ ) -> "xgboost.DMatrix":
82
+ import xgboost as xgb
83
+
84
+ if not features:
85
+ return xgb.DMatrix(df)
86
+ else:
87
+ feature_names = []
88
+ feature_types = []
89
+ for feature in features:
90
+ if isinstance(feature, core.FeatureGroupSpec):
91
+ raise snowml_exceptions.SnowflakeMLException(
92
+ error_code=error_codes.NOT_IMPLEMENTED,
93
+ original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
94
+ )
95
+ assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
96
+ feature_names.append(feature.name)
97
+ feature_types.append(feature._dtype._numpy_type)
98
+ return xgb.DMatrix(df, feature_names=feature_names, feature_types=feature_types)
@@ -1,5 +1,5 @@
1
1
  from collections import abc
2
- from typing import List, Literal, Sequence
2
+ from typing import Literal, Sequence
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -10,7 +10,7 @@ from snowflake.ml._internal.exceptions import (
10
10
  exceptions as snowml_exceptions,
11
11
  )
12
12
  from snowflake.ml.model import type_hints as model_types
13
- from snowflake.ml.model._signatures import base_handler, core
13
+ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
14
14
 
15
15
 
16
16
  class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpyArray]):
@@ -46,6 +46,10 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
46
46
  def infer_signature(
47
47
  data: model_types._SupportedNumpyArray, role: Literal["input", "output"]
48
48
  ) -> Sequence[core.BaseFeatureSpec]:
49
+ if data.dtype == np.object_:
50
+ return pandas_handler.PandasDataFrameHandler.infer_signature(
51
+ NumpyArrayHandler.convert_to_df(data), role=role
52
+ )
49
53
  feature_prefix = f"{NumpyArrayHandler.FEATURE_PREFIX}_"
50
54
  dtype = core.DataType.from_numpy_type(data.dtype)
51
55
  role_prefix = (NumpyArrayHandler.INPUT_PREFIX if role == "input" else NumpyArrayHandler.OUTPUT_PREFIX) + "_"
@@ -108,21 +112,9 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._
108
112
  def infer_signature(
109
113
  data: Sequence[model_types._SupportedNumpyArray], role: Literal["input", "output"]
110
114
  ) -> Sequence[core.BaseFeatureSpec]:
111
- feature_prefix = f"{SeqOfNumpyArrayHandler.FEATURE_PREFIX}_"
112
- features: List[core.BaseFeatureSpec] = []
113
- role_prefix = (
114
- SeqOfNumpyArrayHandler.INPUT_PREFIX if role == "input" else SeqOfNumpyArrayHandler.OUTPUT_PREFIX
115
- ) + "_"
116
-
117
- for i, data_col in enumerate(data):
118
- dtype = core.DataType.from_numpy_type(data_col.dtype)
119
- ft_name = f"{role_prefix}{feature_prefix}{i}"
120
- if len(data_col.shape) == 1:
121
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
122
- else:
123
- ft_shape = tuple(data_col.shape[1:])
124
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
125
- return features
115
+ return pandas_handler.PandasDataFrameHandler.infer_signature(
116
+ SeqOfNumpyArrayHandler.convert_to_df(data, ensure_serializable=False), role=role
117
+ )
126
118
 
127
119
  @staticmethod
128
120
  def convert_to_df(
@@ -97,25 +97,7 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
97
97
  ),
98
98
  )
99
99
 
100
- if isinstance(df_col_data.iloc[0], list):
101
- arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
102
- arr_dtype = core.DataType.from_numpy_type(arr.dtype)
103
-
104
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
105
-
106
- if not all(
107
- core.DataType.from_numpy_type(converted_data.dtype) == arr_dtype
108
- for converted_data in converted_data_list
109
- ):
110
- raise snowml_exceptions.SnowflakeMLException(
111
- error_code=error_codes.INVALID_DATA,
112
- original_exception=ValueError(
113
- "Data Validation Error: "
114
- + f"Inconsistent type of element in object found in column data {df_col_data}."
115
- ),
116
- )
117
-
118
- elif isinstance(df_col_data.iloc[0], np.ndarray):
100
+ if isinstance(df_col_data.iloc[0], np.ndarray):
119
101
  arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
120
102
 
121
103
  if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in df_col_data):
@@ -126,7 +108,7 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
126
108
  + f"Inconsistent type of element in object found in column data {df_col_data}."
127
109
  ),
128
110
  )
129
- elif not isinstance(df_col_data.iloc[0], (str, bytes)):
111
+ elif not isinstance(df_col_data.iloc[0], (str, bytes, dict, list)):
130
112
  raise snowml_exceptions.SnowflakeMLException(
131
113
  error_code=error_codes.INVALID_DATA,
132
114
  original_exception=ValueError(
@@ -171,16 +153,23 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
171
153
 
172
154
  if df_col_dtype == np.dtype("O"):
173
155
  if isinstance(df_col_data.iloc[0], list):
174
- arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
175
- arr_dtype = core.DataType.from_numpy_type(arr.dtype)
176
- ft_shape = np.shape(df_col_data.iloc[0])
177
-
178
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
179
-
180
- if not all(np.shape(converted_data) == ft_shape for converted_data in converted_data_list):
181
- ft_shape = (-1,)
182
-
183
- specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
156
+ spec_0 = utils.infer_list(ft_name, df_col_data.iloc[0])
157
+ for i in range(1, len(df_col_data)):
158
+ spec = utils.infer_list(ft_name, df_col_data.iloc[i])
159
+ if spec._shape != spec_0._shape:
160
+ spec_0._shape = (-1,)
161
+ spec._shape = (-1,)
162
+ if spec != spec_0:
163
+ raise snowml_exceptions.SnowflakeMLException(
164
+ error_code=error_codes.INVALID_DATA,
165
+ original_exception=ValueError(
166
+ "Unable to construct signature: "
167
+ f"Ragged nested or Unsupported list-like data {df_col_data} confronted."
168
+ ),
169
+ )
170
+ specs.append(spec_0)
171
+ elif isinstance(df_col_data.iloc[0], dict):
172
+ specs.append(utils.infer_dict(ft_name, df_col_data.iloc[0]))
184
173
  elif isinstance(df_col_data.iloc[0], np.ndarray):
185
174
  arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
186
175
  ft_shape = np.shape(df_col_data.iloc[0])