snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/data/__init__.py +5 -0
  8. snowflake/ml/model/_client/model/model_version_impl.py +7 -7
  9. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  10. snowflake/ml/model/_client/ops/service_ops.py +13 -2
  11. snowflake/ml/model/_client/sql/model.py +0 -14
  12. snowflake/ml/model/_client/sql/service.py +25 -1
  13. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  14. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  15. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  16. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  17. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  18. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  19. snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
  20. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  22. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  23. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  24. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  26. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  27. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  28. snowflake/ml/model/_signatures/core.py +63 -16
  29. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  30. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  31. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  32. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  33. snowflake/ml/model/_signatures/utils.py +4 -0
  34. snowflake/ml/model/model_signature.py +38 -9
  35. snowflake/ml/model/type_hints.py +1 -1
  36. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  37. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  38. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
  39. snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
  40. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  41. snowflake/ml/monitoring/model_monitor.py +7 -96
  42. snowflake/ml/registry/registry.py +17 -29
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
  45. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
  46. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  47. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -19,12 +19,26 @@ from snowflake.ml.model._packager.model_meta import (
  )
  from snowflake.ml.model._packager.model_task import model_task_utils
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+ from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR

  if TYPE_CHECKING:
      import sklearn.base
      import sklearn.pipeline


+ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
+     new_steps = []
+     for step_name, step in model.steps:
+         new_reg = step
+         if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None:
+             # Unpack estimator to open source.
+             new_reg = step._sklearn_estimator
+         new_steps.append((step_name, new_reg))
+
+     model.steps = new_steps
+     return model
+
+
  @final
  class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
      """Handler for scikit-learn based model.

@@ -101,6 +115,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
          if sample_input_data is None:
              raise ValueError("Sample input data is required to enable explainability.")

+         # If this is a pipeline and we are in the container runtime, check for distributed estimator.
+         if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
+             model = _unpack_container_runtime_pipeline(model)
+
          if not is_sub_model:
              target_methods = handlers_utils.get_target_methods(
                  model=model,

@@ -135,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
          )

          model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
-         model_meta.task = model_task_and_output_type.task
+         model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)

          # if users did not ask then we enable if we have background data
          if enable_explainability is None:

@@ -177,6 +195,35 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
          model_meta.models[name] = base_meta
          model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

+         # if model instance is a pipeline, check the pipeline steps
+         if isinstance(model, sklearn.pipeline.Pipeline):
+             for _, pipeline_step in model.steps:
+                 if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType(
+                     "lightgbm.Booster"
+                 ).isinstance(pipeline_step):
+                     model_meta.env.include_if_absent(
+                         [
+                             model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
+                         ],
+                         check_local_version=True,
+                     )
+                 elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType(
+                     "xgboost.Booster"
+                 ).isinstance(pipeline_step):
+                     model_meta.env.include_if_absent(
+                         [
+                             model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
+                         ],
+                         check_local_version=True,
+                     )
+                 elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step):
+                     model_meta.env.include_if_absent(
+                         [
+                             model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
+                         ],
+                         check_local_version=True,
+                     )
+
          if enable_explainability:
              model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
              model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
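
Note: a quick usage sketch (not part of the diff) of what the new pipeline-step scan buys. A pipeline containing an XGBoost step was previously packaged without an `xgboost` requirement; with this change the handler records it automatically:

    # Illustrative only: a scikit-learn pipeline whose final step is XGBoost.
    # Saving it through SKLModelHandler should now add an `xgboost` dependency
    # (pinned to the local version) to the model environment.
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from xgboost import XGBClassifier

    pipe = Pipeline([("scale", StandardScaler()), ("clf", XGBClassifier(n_estimators=10))])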
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

@@ -138,7 +138,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
              enable_explainability = False
          else:
              model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-             model_meta.task = model_task_and_output_type.task
+             model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
              explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
              model_meta = handlers_utils.add_explain_method_signature(
                  model_meta=model_meta,
snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -13,6 +13,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
  from snowflake.ml.model._packager.model_meta import (
      model_blob_meta,
      model_meta as model_meta_api,
+     model_meta_schema,
  )
  from snowflake.ml.model._signatures import (
      numpy_handler,

@@ -76,7 +77,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):

          assert isinstance(model, tensorflow.Module)

-         if isinstance(model, tensorflow.keras.Model):
+         is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
+             "tf_keras.Model"
+         ).isinstance(model)
+
+         if is_keras_model:
              default_target_methods = ["predict"]
          else:
              default_target_methods = cls.DEFAULT_TARGET_METHODS

@@ -117,8 +122,14 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):

          model_blob_path = os.path.join(model_blobs_dir_path, name)
          os.makedirs(model_blob_path, exist_ok=True)
-         if isinstance(model, tensorflow.keras.Model):
+         if is_keras_model:
              tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+             model_meta.env.include_if_absent(
+                 [
+                     model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
+                 ],
+                 check_local_version=False,
+             )
          else:
              tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

@@ -127,12 +138,16 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
              model_type=cls.HANDLER_TYPE,
              handler_version=cls.HANDLER_VERSION,
              path=cls.MODEL_BLOB_FILE_OR_DIR,
+             options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
          )
          model_meta.models[name] = base_meta
          model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

          model_meta.env.include_if_absent(
-             [model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")], check_local_version=True
+             [
+                 model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+             ],
+             check_local_version=True,
          )
          model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

@@ -150,9 +165,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
          model_blobs_metadata = model_meta.models
          model_blob_metadata = model_blobs_metadata[name]
          model_blob_filename = model_blob_metadata.path
-         m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
-         if isinstance(m, tensorflow.keras.Model):
-             return m
+         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
+         if model_blob_options.get("is_keras_model", False):
+             m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
+         else:
+             m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
          return cast(tensorflow.Module, m)

      @classmethod
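
Note: the handler now persists which save path was used (`is_keras_model`) instead of probing with `load_model` at load time. A minimal sketch of the two underlying TensorFlow paths, assuming keras<3 per the new pin (this is plain TensorFlow, not the handler code itself):

    import tensorflow as tf

    # Keras path, mirroring the is_keras_model=True branch above.
    keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    keras_model.build(input_shape=(None, 4))
    tf.keras.models.save_model(keras_model, "/tmp/keras_blob")
    reloaded = tf.keras.models.load_model("/tmp/keras_blob", compile=False)

    # Generic tf.Module path, mirroring is_keras_model=False.
    module = tf.Module()
    tf.saved_model.save(module, "/tmp/module_blob")
    reloaded_module = tf.saved_model.load("/tmp/module_blob")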
snowflake/ml/model/_packager/model_handlers/torchscript.py

@@ -23,7 +23,7 @@ if TYPE_CHECKING:


  @final
- class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
+ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
      """Handler for PyTorch JIT based model.

      Currently torch.jit.ScriptModule based classes are supported.

@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
      def can_handle(
          cls,
          model: model_types.SupportedModelType,
-     ) -> TypeGuard["torch.jit.ScriptModule"]:  # type:ignore[name-defined]
+     ) -> TypeGuard["torch.jit.ScriptModule"]:
          return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)

      @classmethod
      def cast_model(
          cls,
          model: model_types.SupportedModelType,
-     ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
+     ) -> "torch.jit.ScriptModule":
          import torch

-         assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+         assert isinstance(model, torch.jit.ScriptModule)

-         return cast(torch.jit.ScriptModule, model)  # type:ignore[name-defined]
+         return cast(torch.jit.ScriptModule, model)

      @classmethod
      def save_model(
          cls,
          name: str,
-         model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+         model: "torch.jit.ScriptModule",
          model_meta: model_meta_api.ModelMetadata,
          model_blobs_dir_path: str,
          sample_input_data: Optional[model_types.SupportedDataType] = None,

@@ -72,7 +72,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]

          import torch

-         assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+         assert isinstance(model, torch.jit.ScriptModule)

          if not is_sub_model:
              target_methods = handlers_utils.get_target_methods(

@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
          model_blob_path = os.path.join(model_blobs_dir_path, name)
          os.makedirs(model_blob_path, exist_ok=True)
          with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
-             torch.jit.save(model, f)  # type:ignore[no-untyped-call, attr-defined]
+             torch.jit.save(model, f)  # type:ignore[no-untyped-call]
          base_meta = model_blob_meta.ModelBlobMeta(
              name=name,
              model_type=cls.HANDLER_TYPE,

@@ -133,7 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
          model_meta: model_meta_api.ModelMetadata,
          model_blobs_dir_path: str,
          **kwargs: Unpack[model_types.TorchScriptLoadOptions],
-     ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
+     ) -> "torch.jit.ScriptModule":
          import torch

          model_blob_path = os.path.join(model_blobs_dir_path, name)

@@ -141,10 +141,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
          model_blob_metadata = model_blobs_metadata[name]
          model_blob_filename = model_blob_metadata.path
          with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
-             m = torch.jit.load(  # type:ignore[no-untyped-call, attr-defined]
+             m = torch.jit.load(  # type:ignore[no-untyped-call]
                  f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
              )
-         assert isinstance(m, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+         assert isinstance(m, torch.jit.ScriptModule)

          if kwargs.get("use_gpu", False):
              m = m.cuda()

@@ -154,7 +154,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
      @classmethod
      def convert_as_custom_model(
          cls,
-         raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+         raw_model: "torch.jit.ScriptModule",
          model_meta: model_meta_api.ModelMetadata,
          background_data: Optional[pd.DataFrame] = None,
          **kwargs: Unpack[model_types.TorchScriptLoadOptions],

@@ -162,11 +162,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
          from snowflake.ml.model import custom_model

          def _create_custom_model(
-             raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+             raw_model: "torch.jit.ScriptModule",
              model_meta: model_meta_api.ModelMetadata,
          ) -> Type[custom_model.CustomModel]:
              def fn_factory(
-                 raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+                 raw_model: "torch.jit.ScriptModule",
                  signature: model_signature.ModelSignature,
                  target_method: str,
              ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
snowflake/ml/model/_packager/model_meta/_packaging_requirements.py

@@ -1,3 +1,2 @@
- REQUIREMENTS = [
-     "cloudpickle>=2.0.0"
- ]
+ REQUIREMENTS = ['cloudpickle>=2.0.0']
+ ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
      xgb_estimator_type: Required[str]


+ class TensorflowModelBlobOptions(BaseModelBlobOptions):
+     is_keras_model: Required[bool]
+
+
  ModelBlobOptions = Union[
      BaseModelBlobOptions,
      HuggingFacePipelineModelBlobOptions,
      MLFlowModelBlobOptions,
      XgboostModelBlobOptions,
+     TensorflowModelBlobOptions,
  ]

snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -1,10 +1,2 @@
- REQUIREMENTS = [
-     "absl-py>=0.15,<2",
-     "anyio>=3.5.0,<4",
-     "numpy>=1.23,<2",
-     "packaging>=20.9,<24",
-     "pandas>=1.0.0,<3",
-     "pyyaml>=6.0,<7",
-     "snowflake-snowpark-python>=1.17.0,<2",
-     "typing-extensions>=4.1.0,<5"
- ]
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_runtime/model_runtime.py

@@ -17,6 +17,8 @@ _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES = [
      for r in _snowml_inference_alternative_requirements.REQUIREMENTS
  ]

+ PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]
+

  class ModelRuntime:
      """Class to represent runtime in a model, which controls the runtime and version, imports and dependencies.

@@ -61,15 +63,8 @@ class ModelRuntime:
              ],
          )

-         if not is_warehouse and self.embed_local_ml_library:
-             self.runtime_env.include_if_absent(
-                 [
-                     model_env.ModelDependency(
-                         requirement="pyarrow",
-                         pip_name="pyarrow",
-                     )
-                 ],
-             )
+         if is_warehouse and self.embed_local_ml_library:
+             self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)

          if is_gpu:
              self.runtime_env.generate_env_for_cuda()
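
Note: the condition is inverted on purpose. Previously `pyarrow` was added for non-warehouse runtimes; now, when the local library is embedded and the target is a warehouse, packages that cannot run there are stripped instead. A sketch of the pruning outcome (the real call is `remove_if_present_conda`; the comprehension below only mimics its effect and matches by name prefix for illustration):

    PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]

    conda_deps = ["numpy>=1.23,<2", "pyarrow", "snowflake-connector-python>=3.5.0,<4"]
    pruned = [
        dep
        for dep in conda_deps
        if not any(dep.startswith(pkg) for pkg in PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)
    ]
    assert pruned == ["numpy>=1.23,<2"]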
snowflake/ml/model/_packager/model_task/model_task_utils.py

@@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel
      if type_utils.LazyType("lightgbm.Booster").isinstance(model):
          model_task = model.params["objective"]  # type: ignore[attr-defined]
      elif hasattr(model, "objective_"):
-         model_task = model.objective_
+         model_task = model.objective_  # type: ignore[assignment]
      if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
          return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
      if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
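
Note: the added ignore acknowledges that `objective_` is not statically typed as `str` on the sklearn-style wrapper. The runtime behavior the detector relies on, sketched:

    import numpy as np
    from lightgbm import LGBMClassifier

    # A fitted LGBMClassifier exposes `objective_`, e.g. "binary" for a
    # two-class target, which is then matched against the objective sets.
    X, y = np.random.rand(40, 3), np.random.randint(0, 2, 40)
    clf = LGBMClassifier(n_estimators=5).fit(X, y)
    print(clf.objective_)  # "binary" -> Task.TABULAR_BINARY_CLASSIFICATION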
snowflake/ml/model/_signatures/core.py

@@ -14,10 +14,12 @@ from typing import (
      Type,
      Union,
      final,
+     get_args,
  )

  import numpy as np
  import numpy.typing as npt
+ import pandas as pd

  import snowflake.snowpark.types as spt
  from snowflake.ml._internal.exceptions import (

@@ -29,6 +31,21 @@ if TYPE_CHECKING:
      import mlflow
      import torch

+ PandasExtensionTypes = Union[
+     pd.Int8Dtype,
+     pd.Int16Dtype,
+     pd.Int32Dtype,
+     pd.Int64Dtype,
+     pd.UInt8Dtype,
+     pd.UInt16Dtype,
+     pd.UInt32Dtype,
+     pd.UInt64Dtype,
+     pd.Float32Dtype,
+     pd.Float64Dtype,
+     pd.BooleanDtype,
+     pd.StringDtype,
+ ]
+

  class DataType(Enum):
      def __init__(self, value: str, snowpark_type: Type[spt.DataType], numpy_type: npt.DTypeLike) -> None:
@@ -67,11 +84,11 @@ class DataType(Enum):
          return f"DataType.{self.name}"

      @classmethod
-     def from_numpy_type(cls, np_type: npt.DTypeLike) -> "DataType":
+     def from_numpy_type(cls, input_type: Union[npt.DTypeLike, PandasExtensionTypes]) -> "DataType":
          """Translate numpy dtype to DataType for signature definition.

          Args:
-             np_type: The numpy dtype.
+             input_type: The numpy dtype or Pandas Extension Dtype

          Raises:
              SnowflakeMLException: NotImplementedError: Raised when the given numpy type is not supported.

@@ -79,6 +96,10 @@ class DataType(Enum):
          Returns:
              Corresponding DataType.
          """
+         # To support pandas extension dtype
+         if isinstance(input_type, get_args(PandasExtensionTypes)):
+             input_type = input_type.type
+
          np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}

          # Add datetime types:

@@ -88,12 +109,12 @@ class DataType(Enum):
              np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ

          for potential_type in np_to_snowml_type_mapping.keys():
-             if np.can_cast(np_type, potential_type, casting="no"):
+             if np.can_cast(input_type, potential_type, casting="no"):
                  # This is used since the same dtype might represented in different ways.
                  return np_to_snowml_type_mapping[potential_type]
          raise snowml_exceptions.SnowflakeMLException(
              error_code=error_codes.NOT_IMPLEMENTED,
-             original_exception=NotImplementedError(f"Type {np_type} is not supported as a DataType."),
+             original_exception=NotImplementedError(f"Type {input_type} is not supported as a DataType."),
          )

      @classmethod
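
Note: why unwrapping with `.type` is sufficient — each pandas extension dtype exposes its underlying numpy scalar type, which `np.can_cast` can compare against the mapping. A quick check (not part of the diff):

    import numpy as np
    import pandas as pd

    assert pd.Int64Dtype().type is np.int64
    assert pd.Float32Dtype().type is np.float32
    assert pd.BooleanDtype().type is np.bool_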
@@ -212,6 +233,7 @@ class FeatureSpec(BaseFeatureSpec):
          name: str,
          dtype: DataType,
          shape: Optional[Tuple[int, ...]] = None,
+         nullable: bool = True,
      ) -> None:
          """
          Initialize a feature.

@@ -219,6 +241,7 @@ class FeatureSpec(BaseFeatureSpec):
          Args:
              name: Name of the feature.
              dtype: Type of the elements in the feature.
+             nullable: Whether the feature is nullable. Defaults to True.
              shape: Used to represent scalar feature, 1-d feature list,
                  or n-d tensor. Use -1 to represent variable length. Defaults to None.

@@ -227,6 +250,7 @@ class FeatureSpec(BaseFeatureSpec):
                  - (2,): 1d list with a fixed length of 2.
                  - (-1,): 1d list with variable length, used for ragged tensor representation.
                  - (d1, d2, d3): 3d tensor.
+             nullable: Whether the feature is nullable. Defaults to True.

          Raises:
              SnowflakeMLException: TypeError: When the dtype input type is incorrect.

@@ -248,6 +272,8 @@ class FeatureSpec(BaseFeatureSpec):
          )
          self._shape = shape

+         self._nullable = nullable
+
      def as_snowpark_type(self) -> spt.DataType:
          result_type = self._dtype.as_snowpark_type()
          if not self._shape:
@@ -256,13 +282,34 @@ class FeatureSpec(BaseFeatureSpec):
              result_type = spt.ArrayType(result_type)
          return result_type

-     def as_dtype(self) -> Union[npt.DTypeLike, str]:
+     def as_dtype(self) -> Union[npt.DTypeLike, str, PandasExtensionTypes]:
          """Convert to corresponding local Type."""
+
          if not self._shape:
              # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
              if "datetime64" in self._dtype._value:
                  return self._dtype._value
-             return self._dtype._numpy_type
+
+             np_type = self._dtype._numpy_type
+             if self._nullable:
+                 np_to_pd_dtype_mapping = {
+                     np.int8: pd.Int8Dtype(),
+                     np.int16: pd.Int16Dtype(),
+                     np.int32: pd.Int32Dtype(),
+                     np.int64: pd.Int64Dtype(),
+                     np.uint8: pd.UInt8Dtype(),
+                     np.uint16: pd.UInt16Dtype(),
+                     np.uint32: pd.UInt32Dtype(),
+                     np.uint64: pd.UInt64Dtype(),
+                     np.float32: pd.Float32Dtype(),
+                     np.float64: pd.Float64Dtype(),
+                     np.bool_: pd.BooleanDtype(),
+                     np.str_: pd.StringDtype(),
+                 }
+
+                 return np_to_pd_dtype_mapping.get(np_type, np_type)  # type: ignore[arg-type]
+
+             return np_type
          return np.object_

      def __eq__(self, other: object) -> bool:
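
Note: the visible effect of the new flag, sketched with the public classes (import path assumed from this package):

    from snowflake.ml.model.model_signature import DataType, FeatureSpec

    # nullable defaults to True, so scalar features now map to pandas
    # nullable dtypes; nullable=False keeps the old numpy dtype behavior.
    print(FeatureSpec(name="age", dtype=DataType.INT64).as_dtype())                  # Int64
    print(FeatureSpec(name="age", dtype=DataType.INT64, nullable=False).as_dtype())  # <class 'numpy.int64'>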
@@ -273,7 +320,10 @@ class FeatureSpec(BaseFeatureSpec):

      def __repr__(self) -> str:
          shape_str = f", shape={repr(self._shape)}" if self._shape else ""
-         return f"FeatureSpec(dtype={repr(self._dtype)}, name={repr(self._name)}{shape_str})"
+         return (
+             f"FeatureSpec(dtype={repr(self._dtype)}, "
+             f"name={repr(self._name)}{shape_str}, nullable={repr(self._nullable)})"
+         )

      def to_dict(self) -> Dict[str, Any]:
          """Serialize the feature group into a dict.

@@ -281,10 +331,7 @@ class FeatureSpec(BaseFeatureSpec):
          Returns:
              A dict that serializes the feature group.
          """
-         base_dict: Dict[str, Any] = {
-             "type": self._dtype.name,
-             "name": self._name,
-         }
+         base_dict: Dict[str, Any] = {"type": self._dtype.name, "name": self._name, "nullable": self._nullable}
          if self._shape is not None:
              base_dict["shape"] = self._shape
          return base_dict

@@ -304,7 +351,9 @@ class FeatureSpec(BaseFeatureSpec):
          if shape:
              shape = tuple(shape)
          type = DataType[input_dict["type"]]
-         return FeatureSpec(name=name, dtype=type, shape=shape)
+         # If nullable is not provided, default to False for backward compatibility.
+         nullable = input_dict.get("nullable", False)
+         return FeatureSpec(name=name, dtype=type, shape=shape, nullable=nullable)

      @classmethod
      def from_mlflow_spec(
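
Note: a round-trip sketch of the compatibility default — dicts written by 1.7.0 carry no "nullable" key and deserialize as non-nullable, while specs created here serialize the new default:

    from snowflake.ml.model.model_signature import DataType, FeatureSpec

    legacy = {"type": "INT64", "name": "age"}   # pre-1.7.1 dict, no "nullable"
    old_spec = FeatureSpec.from_dict(legacy)    # deserializes with nullable=False

    new_spec = FeatureSpec(name="age", dtype=DataType.INT64)
    assert new_spec.to_dict()["nullable"] is True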
@@ -475,10 +524,8 @@ class ModelSignature:
          sig_outs = loaded["outputs"]
          sig_inputs = loaded["inputs"]

-         deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = (
-             lambda sig_spec: FeatureGroupSpec.from_dict(sig_spec)
-             if "feature_group" in sig_spec
-             else FeatureSpec.from_dict(sig_spec)
+         deserialize_spec: Callable[[Dict[str, Any]], BaseFeatureSpec] = lambda sig_spec: (
+             FeatureGroupSpec.from_dict(sig_spec) if "feature_group" in sig_spec else FeatureSpec.from_dict(sig_spec)
          )

          return ModelSignature(