snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +24 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +73 -31
  9. snowflake/ml/jobs/decorators.py +3 -0
  10. snowflake/ml/jobs/manager.py +5 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  13. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +8 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/metrics/ranking.py +3 -0
  52. snowflake/ml/modeling/metrics/regression.py +3 -0
  53. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  54. snowflake/ml/registry/_manager/model_manager.py +55 -7
  55. snowflake/ml/registry/registry.py +18 -0
  56. snowflake/ml/version.py +1 -1
  57. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -98,6 +98,9 @@ class ModelRuntime:
98
98
  dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
99
99
  conda=env_dict["conda"],
100
100
  pip=env_dict["pip"],
101
+ artifact_repository_map=env_dict["artifact_repository_map"]
102
+ if env_dict.get("artifact_repository_map") is not None
103
+ else {},
101
104
  ),
102
105
  )
103
106
 
@@ -112,6 +115,7 @@ class ModelRuntime:
112
115
  env.python_version = meta_env.python_version
113
116
  env.cuda_version = meta_env.cuda_version
114
117
  env.snowpark_ml_version = meta_env.snowpark_ml_version
118
+ env.artifact_repository_map = meta_env.artifact_repository_map
115
119
 
116
120
  conda_env_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["conda"])
117
121
  pip_requirements_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["pip"])
@@ -14,21 +14,32 @@ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
14
14
 
15
15
 
16
16
  class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
17
+ @staticmethod
18
+ def _can_handle_element(
19
+ element: model_types._SupportedBuiltins,
20
+ ) -> TypeGuard[model_types._SupportedBuiltins]:
21
+ if isinstance(element, abc.Sequence) and not isinstance(element, str):
22
+ for sub_element in element:
23
+ if not ListOfBuiltinHandler._can_handle_element(sub_element):
24
+ return False
25
+ return True
26
+ elif isinstance(element, abc.Mapping):
27
+ for key, value in element.items():
28
+ if not isinstance(key, str):
29
+ return False
30
+ if not ListOfBuiltinHandler._can_handle_element(value):
31
+ return False
32
+ return True
33
+ else:
34
+ return isinstance(element, (int, float, bool, str, datetime.datetime))
35
+
17
36
  @staticmethod
18
37
  def can_handle(data: model_types.SupportedDataType) -> TypeGuard[model_types._SupportedBuiltinsList]:
19
38
  if not isinstance(data, abc.Sequence) or isinstance(data, str):
20
39
  return False
21
40
  if len(data) == 0:
22
41
  return False
23
- can_handle = True
24
- for element in data:
25
- # String is a Sequence but we take them as an whole
26
- if isinstance(element, abc.Sequence) and not isinstance(element, str):
27
- can_handle = ListOfBuiltinHandler.can_handle(element)
28
- elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
29
- can_handle = False
30
- break
31
- return can_handle
42
+ return ListOfBuiltinHandler._can_handle_element(data)
32
43
 
33
44
  @staticmethod
34
45
  def count(data: model_types._SupportedBuiltinsList) -> int:
@@ -199,9 +199,16 @@ class DataType(Enum):
199
199
  class BaseFeatureSpec(ABC):
200
200
  """Abstract Class for specification of a feature."""
201
201
 
202
- def __init__(self, name: str) -> None:
202
+ def __init__(self, name: str, shape: Optional[Tuple[int, ...]]) -> None:
203
203
  self._name = name
204
204
 
205
+ if shape and not isinstance(shape, tuple):
206
+ raise snowml_exceptions.SnowflakeMLException(
207
+ error_code=error_codes.INVALID_TYPE,
208
+ original_exception=TypeError("Shape should be a tuple if presented."),
209
+ )
210
+ self._shape = shape
211
+
205
212
  @final
206
213
  @property
207
214
  def name(self) -> str:
@@ -213,6 +220,11 @@ class BaseFeatureSpec(ABC):
213
220
  """Convert to corresponding Snowpark Type."""
214
221
  pass
215
222
 
223
+ @abstractmethod
224
+ def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
225
+ """Convert to corresponding local Type."""
226
+ pass
227
+
216
228
  @abstractmethod
217
229
  def to_dict(self) -> Dict[str, Any]:
218
230
  """Serialization"""
@@ -256,7 +268,7 @@ class FeatureSpec(BaseFeatureSpec):
256
268
  SnowflakeMLException: TypeError: When the dtype input type is incorrect.
257
269
  SnowflakeMLException: TypeError: When the shape input type is incorrect.
258
270
  """
259
- super().__init__(name=name)
271
+ super().__init__(name=name, shape=shape)
260
272
 
261
273
  if not isinstance(dtype, DataType):
262
274
  raise snowml_exceptions.SnowflakeMLException(
@@ -265,13 +277,6 @@ class FeatureSpec(BaseFeatureSpec):
265
277
  )
266
278
  self._dtype = dtype
267
279
 
268
- if shape and not isinstance(shape, tuple):
269
- raise snowml_exceptions.SnowflakeMLException(
270
- error_code=error_codes.INVALID_TYPE,
271
- original_exception=TypeError("Shape should be a tuple if presented."),
272
- )
273
- self._shape = shape
274
-
275
280
  self._nullable = nullable
276
281
 
277
282
  def as_snowpark_type(self) -> spt.DataType:
@@ -386,15 +391,23 @@ class FeatureSpec(BaseFeatureSpec):
386
391
  class FeatureGroupSpec(BaseFeatureSpec):
387
392
  """Specification of a group of features in Snowflake native model packaging."""
388
393
 
389
- def __init__(self, name: str, specs: List[FeatureSpec]) -> None:
394
+ def __init__(self, name: str, specs: List[BaseFeatureSpec], shape: Optional[Tuple[int, ...]] = None) -> None:
390
395
  """Initialize a feature group.
391
396
 
392
397
  Args:
393
398
  name: Name of the feature group.
394
399
  specs: A list of feature specifications that composes the group. All children feature specs have to have
395
400
  name. And all of them should have the same type.
401
+ shape: Used to represent scalar feature, 1-d feature list,
402
+ or n-d tensor. Use -1 to represent variable length. Defaults to None.
403
+
404
+ Examples:
405
+ - None: scalar
406
+ - (2,): 1d list with a fixed length of 2.
407
+ - (-1,): 1d list with variable length, used for ragged tensor representation.
408
+ - (d1, d2, d3): 3d tensor.
396
409
  """
397
- super().__init__(name=name)
410
+ super().__init__(name=name, shape=shape)
398
411
  self._specs = specs
399
412
  self._validate()
400
413
 
@@ -409,47 +422,52 @@ class FeatureGroupSpec(BaseFeatureSpec):
409
422
  error_code=error_codes.INVALID_ARGUMENT,
410
423
  original_exception=ValueError("All children feature specs have to have name."),
411
424
  )
412
- if not (all(s._shape is None for s in self._specs) or all(s._shape is not None for s in self._specs)):
413
- raise snowml_exceptions.SnowflakeMLException(
414
- error_code=error_codes.INVALID_ARGUMENT,
415
- original_exception=ValueError("All children feature specs have to have same shape."),
416
- )
417
- first_type = self._specs[0]._dtype
418
- if not all(s._dtype == first_type for s in self._specs):
419
- raise snowml_exceptions.SnowflakeMLException(
420
- error_code=error_codes.INVALID_ARGUMENT,
421
- original_exception=ValueError("All children feature specs have to have same type."),
422
- )
423
425
 
424
426
  def as_snowpark_type(self) -> spt.DataType:
425
- first_type = self._specs[0].as_snowpark_type()
426
- return spt.MapType(spt.StringType(), first_type)
427
+ spt_type = spt.StructType(
428
+ fields=[
429
+ spt.StructField(
430
+ s._name, datatype=s.as_snowpark_type(), nullable=s._nullable if isinstance(s, FeatureSpec) else True
431
+ )
432
+ for s in self._specs
433
+ ]
434
+ )
435
+ if not self._shape:
436
+ return spt_type
437
+ return spt.ArrayType(spt_type)
427
438
 
428
439
  def __eq__(self, other: object) -> bool:
429
440
  if isinstance(other, FeatureGroupSpec):
430
- return self._specs == other._specs
441
+ return self._name == other._name and self._specs == other._specs and self._shape == other._shape
431
442
  else:
432
443
  return False
433
444
 
434
445
  def __repr__(self) -> str:
435
446
  spec_strs = ",\n\t\t".join(repr(spec) for spec in self._specs)
447
+ shape_str = f", shape={repr(self._shape)}" if self._shape else ""
436
448
  return textwrap.dedent(
437
449
  f"""FeatureGroupSpec(
438
450
  name={repr(self._name)},
439
451
  specs=[
440
452
  {spec_strs}
441
- ]
453
+ ]{shape_str}
442
454
  )
443
455
  """
444
456
  )
445
457
 
458
+ def as_dtype(self, force_numpy_dtype: bool = False) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
459
+ return np.object_
460
+
446
461
  def to_dict(self) -> Dict[str, Any]:
447
462
  """Serialize the feature group into a dict.
448
463
 
449
464
  Returns:
450
465
  A dict that serializes the feature group.
451
466
  """
452
- return {"feature_group": {"name": self._name, "specs": [s.to_dict() for s in self._specs]}}
467
+ base_dict: Dict[str, Any] = {"name": self._name, "specs": [s.to_dict() for s in self._specs]}
468
+ if self._shape is not None:
469
+ base_dict["shape"] = self._shape
470
+ return base_dict
453
471
 
454
472
  @classmethod
455
473
  def from_dict(cls, input_dict: Dict[str, Any]) -> "FeatureGroupSpec":
@@ -462,10 +480,13 @@ class FeatureGroupSpec(BaseFeatureSpec):
462
480
  A feature group instance deserialized and created from the dict.
463
481
  """
464
482
  specs = []
465
- for e in input_dict["feature_group"]["specs"]:
466
- spec = FeatureSpec.from_dict(e)
483
+ for e in input_dict["specs"]:
484
+ spec = FeatureGroupSpec.from_dict(e) if "specs" in e else FeatureSpec.from_dict(e)
467
485
  specs.append(spec)
468
- return FeatureGroupSpec(name=input_dict["feature_group"]["name"], specs=specs)
486
+ shape = input_dict.get("shape", None)
487
+ if shape:
488
+ shape = tuple(shape)
489
+ return FeatureGroupSpec(name=input_dict["name"], specs=specs, shape=shape)
469
490
 
470
491
 
471
492
  class ModelSignature:
@@ -525,7 +546,7 @@ class ModelSignature:
525
546
  sig_inputs = loaded["inputs"]
526
547
 
527
548
  deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
528
- FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec)
549
+ FeatureGroupSpec.from_dict(sig_spec) if "specs" in sig_spec else FeatureSpec.from_dict(sig_spec)
529
550
  )
530
551
 
531
552
  return ModelSignature(
@@ -0,0 +1,98 @@
1
+ from typing import TYPE_CHECKING, List, Literal, Optional, Sequence
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from typing_extensions import TypeGuard
6
+
7
+ from snowflake.ml._internal import type_utils
8
+ from snowflake.ml._internal.exceptions import (
9
+ error_codes,
10
+ exceptions as snowml_exceptions,
11
+ )
12
+ from snowflake.ml.model import type_hints as model_types
13
+ from snowflake.ml.model._signatures import base_handler, core
14
+
15
+ if TYPE_CHECKING:
16
+ import xgboost
17
+
18
+
19
+ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
20
+ @staticmethod
21
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard["xgboost.DMatrix"]:
22
+ return type_utils.LazyType("xgboost.DMatrix").isinstance(data)
23
+
24
+ @staticmethod
25
+ def count(data: "xgboost.DMatrix") -> int:
26
+ return data.num_row()
27
+
28
+ @staticmethod
29
+ def truncate(data: "xgboost.DMatrix", length: int) -> "xgboost.DMatrix":
30
+
31
+ num_rows = min(
32
+ XGBoostDMatrixHandler.count(data),
33
+ length,
34
+ )
35
+ return data.slice(list(range(num_rows)))
36
+
37
+ @staticmethod
38
+ def validate(data: "xgboost.DMatrix") -> None:
39
+ if data.num_row() == 0:
40
+ raise snowml_exceptions.SnowflakeMLException(
41
+ error_code=error_codes.INVALID_DATA,
42
+ original_exception=ValueError("Data Validation Error: Empty data is found."),
43
+ )
44
+
45
+ @staticmethod
46
+ def infer_signature(data: "xgboost.DMatrix", role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]:
47
+ feature_prefix = f"{XGBoostDMatrixHandler.FEATURE_PREFIX}_"
48
+ features: List[core.BaseFeatureSpec] = []
49
+ role_prefix = (
50
+ XGBoostDMatrixHandler.INPUT_PREFIX if role == "input" else XGBoostDMatrixHandler.OUTPUT_PREFIX
51
+ ) + "_"
52
+
53
+ feature_names = data.feature_names or []
54
+ feature_types = data.feature_types or []
55
+
56
+ for i, (feature_name, dtype) in enumerate(zip(feature_names, feature_types)):
57
+ if not feature_name:
58
+ ft_name = f"{role_prefix}{feature_prefix}{i}"
59
+ else:
60
+ ft_name = feature_name
61
+
62
+ features.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(np.dtype(dtype)), name=ft_name))
63
+ return features
64
+
65
+ @staticmethod
66
+ def convert_to_df(data: "xgboost.DMatrix", ensure_serializable: bool = True) -> pd.DataFrame:
67
+ df = pd.DataFrame(data.get_data().toarray(), columns=data.feature_names)
68
+
69
+ feature_types = data.feature_types or []
70
+
71
+ if feature_types:
72
+ for idx, col in enumerate(df.columns):
73
+ dtype = feature_types[idx]
74
+ df[col] = df[col].astype(dtype)
75
+
76
+ return df
77
+
78
+ @staticmethod
79
+ def convert_from_df(
80
+ df: pd.DataFrame, features: Optional[Sequence[core.BaseFeatureSpec]] = None
81
+ ) -> "xgboost.DMatrix":
82
+ import xgboost as xgb
83
+
84
+ if not features:
85
+ return xgb.DMatrix(df)
86
+ else:
87
+ feature_names = []
88
+ feature_types = []
89
+ for feature in features:
90
+ if isinstance(feature, core.FeatureGroupSpec):
91
+ raise snowml_exceptions.SnowflakeMLException(
92
+ error_code=error_codes.NOT_IMPLEMENTED,
93
+ original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
94
+ )
95
+ assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
96
+ feature_names.append(feature.name)
97
+ feature_types.append(feature._dtype._numpy_type)
98
+ return xgb.DMatrix(df, feature_names=feature_names, feature_types=feature_types)
@@ -1,5 +1,5 @@
1
1
  from collections import abc
2
- from typing import List, Literal, Sequence
2
+ from typing import Literal, Sequence
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -10,7 +10,7 @@ from snowflake.ml._internal.exceptions import (
10
10
  exceptions as snowml_exceptions,
11
11
  )
12
12
  from snowflake.ml.model import type_hints as model_types
13
- from snowflake.ml.model._signatures import base_handler, core
13
+ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
14
14
 
15
15
 
16
16
  class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpyArray]):
@@ -46,6 +46,10 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpy
46
46
  def infer_signature(
47
47
  data: model_types._SupportedNumpyArray, role: Literal["input", "output"]
48
48
  ) -> Sequence[core.BaseFeatureSpec]:
49
+ if data.dtype == np.object_:
50
+ return pandas_handler.PandasDataFrameHandler.infer_signature(
51
+ NumpyArrayHandler.convert_to_df(data), role=role
52
+ )
49
53
  feature_prefix = f"{NumpyArrayHandler.FEATURE_PREFIX}_"
50
54
  dtype = core.DataType.from_numpy_type(data.dtype)
51
55
  role_prefix = (NumpyArrayHandler.INPUT_PREFIX if role == "input" else NumpyArrayHandler.OUTPUT_PREFIX) + "_"
@@ -108,21 +112,9 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._
108
112
  def infer_signature(
109
113
  data: Sequence[model_types._SupportedNumpyArray], role: Literal["input", "output"]
110
114
  ) -> Sequence[core.BaseFeatureSpec]:
111
- feature_prefix = f"{SeqOfNumpyArrayHandler.FEATURE_PREFIX}_"
112
- features: List[core.BaseFeatureSpec] = []
113
- role_prefix = (
114
- SeqOfNumpyArrayHandler.INPUT_PREFIX if role == "input" else SeqOfNumpyArrayHandler.OUTPUT_PREFIX
115
- ) + "_"
116
-
117
- for i, data_col in enumerate(data):
118
- dtype = core.DataType.from_numpy_type(data_col.dtype)
119
- ft_name = f"{role_prefix}{feature_prefix}{i}"
120
- if len(data_col.shape) == 1:
121
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
122
- else:
123
- ft_shape = tuple(data_col.shape[1:])
124
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
125
- return features
115
+ return pandas_handler.PandasDataFrameHandler.infer_signature(
116
+ SeqOfNumpyArrayHandler.convert_to_df(data, ensure_serializable=False), role=role
117
+ )
126
118
 
127
119
  @staticmethod
128
120
  def convert_to_df(
@@ -97,25 +97,7 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
97
97
  ),
98
98
  )
99
99
 
100
- if isinstance(df_col_data.iloc[0], list):
101
- arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
102
- arr_dtype = core.DataType.from_numpy_type(arr.dtype)
103
-
104
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
105
-
106
- if not all(
107
- core.DataType.from_numpy_type(converted_data.dtype) == arr_dtype
108
- for converted_data in converted_data_list
109
- ):
110
- raise snowml_exceptions.SnowflakeMLException(
111
- error_code=error_codes.INVALID_DATA,
112
- original_exception=ValueError(
113
- "Data Validation Error: "
114
- + f"Inconsistent type of element in object found in column data {df_col_data}."
115
- ),
116
- )
117
-
118
- elif isinstance(df_col_data.iloc[0], np.ndarray):
100
+ if isinstance(df_col_data.iloc[0], np.ndarray):
119
101
  arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
120
102
 
121
103
  if not all(core.DataType.from_numpy_type(data_row.dtype) == arr_dtype for data_row in df_col_data):
@@ -126,7 +108,7 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
126
108
  + f"Inconsistent type of element in object found in column data {df_col_data}."
127
109
  ),
128
110
  )
129
- elif not isinstance(df_col_data.iloc[0], (str, bytes)):
111
+ elif not isinstance(df_col_data.iloc[0], (str, bytes, dict, list)):
130
112
  raise snowml_exceptions.SnowflakeMLException(
131
113
  error_code=error_codes.INVALID_DATA,
132
114
  original_exception=ValueError(
@@ -171,16 +153,23 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
171
153
 
172
154
  if df_col_dtype == np.dtype("O"):
173
155
  if isinstance(df_col_data.iloc[0], list):
174
- arr = utils.convert_list_to_ndarray(df_col_data.iloc[0])
175
- arr_dtype = core.DataType.from_numpy_type(arr.dtype)
176
- ft_shape = np.shape(df_col_data.iloc[0])
177
-
178
- converted_data_list = [utils.convert_list_to_ndarray(data_row) for data_row in df_col_data]
179
-
180
- if not all(np.shape(converted_data) == ft_shape for converted_data in converted_data_list):
181
- ft_shape = (-1,)
182
-
183
- specs.append(core.FeatureSpec(dtype=arr_dtype, name=ft_name, shape=ft_shape))
156
+ spec_0 = utils.infer_list(ft_name, df_col_data.iloc[0])
157
+ for i in range(1, len(df_col_data)):
158
+ spec = utils.infer_list(ft_name, df_col_data.iloc[i])
159
+ if spec._shape != spec_0._shape:
160
+ spec_0._shape = (-1,)
161
+ spec._shape = (-1,)
162
+ if spec != spec_0:
163
+ raise snowml_exceptions.SnowflakeMLException(
164
+ error_code=error_codes.INVALID_DATA,
165
+ original_exception=ValueError(
166
+ "Unable to construct signature: "
167
+ f"Ragged nested or Unsupported list-like data {df_col_data} confronted."
168
+ ),
169
+ )
170
+ specs.append(spec_0)
171
+ elif isinstance(df_col_data.iloc[0], dict):
172
+ specs.append(utils.infer_dict(ft_name, df_col_data.iloc[0]))
184
173
  elif isinstance(df_col_data.iloc[0], np.ndarray):
185
174
  arr_dtype = core.DataType.from_numpy_type(df_col_data.iloc[0].dtype)
186
175
  ft_shape = np.shape(df_col_data.iloc[0])
@@ -1,5 +1,5 @@
1
1
  from collections import abc
2
- from typing import TYPE_CHECKING, List, Literal, Optional, Sequence
2
+ from typing import TYPE_CHECKING, Literal, Optional, Sequence
3
3
 
4
4
  import numpy as np
5
5
  import pandas as pd
@@ -11,12 +11,54 @@ from snowflake.ml._internal.exceptions import (
11
11
  exceptions as snowml_exceptions,
12
12
  )
13
13
  from snowflake.ml.model import type_hints as model_types
14
- from snowflake.ml.model._signatures import base_handler, core
14
+ from snowflake.ml.model._signatures import base_handler, core, numpy_handler
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  import torch
18
18
 
19
19
 
20
+ class PyTorchTensorHandler(base_handler.BaseDataHandler["torch.Tensor"]):
21
+ @staticmethod
22
+ def can_handle(data: model_types.SupportedDataType) -> TypeGuard["torch.Tensor"]:
23
+ return type_utils.LazyType("torch.Tensor").isinstance(data)
24
+
25
+ @staticmethod
26
+ def count(data: "torch.Tensor") -> int:
27
+ return data.shape[0]
28
+
29
+ @staticmethod
30
+ def truncate(data: "torch.Tensor", length: int) -> "torch.Tensor":
31
+ return data[: min(PyTorchTensorHandler.count(data), length)]
32
+
33
+ @staticmethod
34
+ def validate(data: "torch.Tensor") -> None:
35
+ return numpy_handler.NumpyArrayHandler.validate(data.detach().cpu().numpy())
36
+
37
+ @staticmethod
38
+ def infer_signature(data: "torch.Tensor", role: Literal["input", "output"]) -> Sequence[core.BaseFeatureSpec]:
39
+ return numpy_handler.NumpyArrayHandler.infer_signature(data.detach().cpu().numpy(), role=role)
40
+
41
+ @staticmethod
42
+ def convert_to_df(data: "torch.Tensor", ensure_serializable: bool = True) -> pd.DataFrame:
43
+ return numpy_handler.NumpyArrayHandler.convert_to_df(
44
+ data.detach().cpu().numpy(), ensure_serializable=ensure_serializable
45
+ )
46
+
47
+ @staticmethod
48
+ def convert_from_df(df: pd.DataFrame, features: Optional[Sequence[core.BaseFeatureSpec]] = None) -> "torch.Tensor":
49
+ import torch
50
+
51
+ if features is None:
52
+ if any(dtype == np.dtype("O") for dtype in df.dtypes):
53
+ return torch.from_numpy(np.array(df.to_numpy().tolist()))
54
+ return torch.from_numpy(df.to_numpy())
55
+
56
+ assert isinstance(features[0], core.FeatureSpec)
57
+ return torch.from_numpy(
58
+ np.array(df.to_numpy().tolist(), dtype=features[0]._dtype._numpy_type),
59
+ )
60
+
61
+
20
62
  class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Tensor"]]):
21
63
  @staticmethod
22
64
  def can_handle(data: model_types.SupportedDataType) -> TypeGuard[Sequence["torch.Tensor"]]:
@@ -24,56 +66,28 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
24
66
  return False
25
67
  if len(data) == 0:
26
68
  return False
27
- if type_utils.LazyType("torch.Tensor").isinstance(data[0]):
28
- return all(type_utils.LazyType("torch.Tensor").isinstance(data_col) for data_col in data)
29
- return False
69
+ return all(PyTorchTensorHandler.can_handle(data_col) for data_col in data)
30
70
 
31
71
  @staticmethod
32
72
  def count(data: Sequence["torch.Tensor"]) -> int:
33
- return min(data_col.shape[0] for data_col in data)
73
+ return min(PyTorchTensorHandler.count(data_col) for data_col in data)
34
74
 
35
75
  @staticmethod
36
76
  def truncate(data: Sequence["torch.Tensor"], length: int) -> Sequence["torch.Tensor"]:
37
- return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), 10)] for data_col in data]
77
+ return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), length)] for data_col in data]
38
78
 
39
79
  @staticmethod
40
80
  def validate(data: Sequence["torch.Tensor"]) -> None:
41
- import torch
42
-
43
81
  for data_col in data:
44
- if data_col.shape == torch.Size([0]):
45
- # Empty array
46
- raise snowml_exceptions.SnowflakeMLException(
47
- error_code=error_codes.INVALID_DATA,
48
- original_exception=ValueError("Data Validation Error: Empty data is found."),
49
- )
50
-
51
- if data_col.shape == torch.Size([1]):
52
- # scalar
53
- raise snowml_exceptions.SnowflakeMLException(
54
- error_code=error_codes.INVALID_DATA,
55
- original_exception=ValueError("Data Validation Error: Scalar data is found."),
56
- )
82
+ PyTorchTensorHandler.validate(data_col)
57
83
 
58
84
  @staticmethod
59
85
  def infer_signature(
60
86
  data: Sequence["torch.Tensor"], role: Literal["input", "output"]
61
87
  ) -> Sequence[core.BaseFeatureSpec]:
62
- feature_prefix = f"{SeqOfPyTorchTensorHandler.FEATURE_PREFIX}_"
63
- features: List[core.BaseFeatureSpec] = []
64
- role_prefix = (
65
- SeqOfPyTorchTensorHandler.INPUT_PREFIX if role == "input" else SeqOfPyTorchTensorHandler.OUTPUT_PREFIX
66
- ) + "_"
67
-
68
- for i, data_col in enumerate(data):
69
- dtype = core.DataType.from_torch_type(data_col.dtype)
70
- ft_name = f"{role_prefix}{feature_prefix}{i}"
71
- if len(data_col.shape) == 1:
72
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
73
- else:
74
- ft_shape = tuple(data_col.shape[1:])
75
- features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
76
- return features
88
+ return numpy_handler.SeqOfNumpyArrayHandler.infer_signature(
89
+ [data_col.detach().cpu().numpy() for data_col in data], role=role
90
+ )
77
91
 
78
92
  @staticmethod
79
93
  def convert_to_df(data: Sequence["torch.Tensor"], ensure_serializable: bool = True) -> pd.DataFrame:
@@ -81,8 +95,8 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
81
95
  # the content is still numpy array so that the type could be preserved.
82
96
  # But that would not serializable and cannot use as UDF input and output.
83
97
  if ensure_serializable:
84
- return pd.DataFrame({i: data_col.detach().to("cpu").numpy().tolist() for i, data_col in enumerate(data)})
85
- return pd.DataFrame({i: list(data_col.detach().to("cpu").numpy()) for i, data_col in enumerate(data)})
98
+ return pd.DataFrame({i: data_col.detach().cpu().numpy().tolist() for i, data_col in enumerate(data)})
99
+ return pd.DataFrame({i: list(data_col.detach().cpu().numpy()) for i, data_col in enumerate(data)})
86
100
 
87
101
  @staticmethod
88
102
  def convert_from_df(
@@ -95,8 +109,10 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
95
109
  for feature in features:
96
110
  if isinstance(feature, core.FeatureGroupSpec):
97
111
  raise snowml_exceptions.SnowflakeMLException(
98
- error_code=error_codes.NOT_IMPLEMENTED,
99
- original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
112
+ error_code=error_codes.INVALID_DATA_TYPE,
113
+ original_exception=NotImplementedError(
114
+ "FeatureGroupSpec is not supported when converting to Tensorflow tensor."
115
+ ),
100
116
  )
101
117
  assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
102
118
  res.append(torch.from_numpy(np.stack(df[feature.name].to_numpy()).astype(feature._dtype._numpy_type)))
@@ -65,12 +65,6 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
65
65
  dtype_map = {}
66
66
  if features:
67
67
  for feature in features:
68
- if isinstance(feature, core.FeatureGroupSpec):
69
- raise snowml_exceptions.SnowflakeMLException(
70
- error_code=error_codes.NOT_IMPLEMENTED,
71
- original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
72
- )
73
- assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
74
68
  dtype_map[feature.name] = feature.as_dtype()
75
69
  df_local = data.to_pandas()
76
70
 
@@ -122,12 +116,6 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
122
116
  column_names = []
123
117
  columns = []
124
118
  for feature in features:
125
- if isinstance(feature, core.FeatureGroupSpec):
126
- raise snowml_exceptions.SnowflakeMLException(
127
- error_code=error_codes.NOT_IMPLEMENTED,
128
- original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
129
- )
130
- assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
131
119
  column_names.append(identifier.get_inferred_name(feature.name))
132
120
  columns.append(F.col(identifier.get_inferred_name(feature.name)).cast(feature.as_snowpark_type()))
133
121