snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (78)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  19. snowflake/ml/fileset/fileset.py +18 -18
  20. snowflake/ml/model/_client/model/model_version_impl.py +24 -8
  21. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +12 -7
  23. snowflake/ml/model/_client/sql/model_version.py +11 -0
  24. snowflake/ml/model/_client/sql/stage.py +1 -1
  25. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  29. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  31. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  32. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  33. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  34. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  39. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  40. snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
  41. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  42. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  45. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  46. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  48. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  49. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  50. snowflake/ml/model/_signatures/utils.py +0 -1
  51. snowflake/ml/model/type_hints.py +1 -0
  52. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  53. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  54. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  55. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  56. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  57. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  58. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  59. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
  60. snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
  61. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
  62. snowflake/ml/monitoring/model_monitor.py +26 -11
  63. snowflake/ml/registry/_manager/model_manager.py +70 -33
  64. snowflake/ml/registry/registry.py +53 -34
  65. snowflake/ml/utils/authentication.py +75 -0
  66. snowflake/ml/version.py +1 -1
  67. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
  68. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
  69. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  70. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  71. snowflake/ml/fileset/parquet_parser.py +0 -170
  72. snowflake/ml/fileset/tf_dataset.py +0 -88
  73. snowflake/ml/fileset/torch_datapipe.py +0 -57
  74. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  75. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  76. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  77. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  78. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
@@ -68,21 +68,45 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         return cast("BaseEstimator", model)
 
     @classmethod
-    def _get_supported_object_for_explainability(cls, estimator: "BaseEstimator") -> Any:
+    def _get_supported_object_for_explainability(
+        cls,
+        estimator: "BaseEstimator",
+        background_data: Optional[model_types.SupportedDataType],
+        enable_explainability: Optional[bool],
+    ) -> Any:
         from snowflake.ml.modeling import pipeline as snowml_pipeline
 
         # handle pipeline objects separately
         if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
             return None
 
-        methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-        for method_name in methods:
+        tree_methods = ["to_xgboost", "to_lightgbm"]
+        non_tree_methods = ["to_sklearn"]
+        for method_name in tree_methods:
+            if hasattr(estimator, method_name):
+                try:
+                    result = getattr(estimator, method_name)()
+                    return result
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+        for method_name in non_tree_methods:
             if hasattr(estimator, method_name):
                 try:
                     result = getattr(estimator, method_name)()
+                    if enable_explainability is None and background_data is None:
+                        return None  # cannot get explain without background data
+                    elif enable_explainability and background_data is None:
+                        raise ValueError(
+                            "Provide `sample_input_data` to generate explanations for sklearn Snowpark ML models."
+                        )
                     return result
                 except exceptions.SnowflakeMLException:
                     pass  # Do nothing and continue to the next method
+
+        if enable_explainability:
+            raise ValueError(
+                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
+            )
         return None
 
     @classmethod
@@ -127,34 +151,39 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                     raise ValueError(f"Target method {method_name} does not exist in the model.")
             model_meta.signatures = temp_model_signature_dict
 
-        if enable_explainability or enable_explainability is None:
-            python_base_obj = cls._get_supported_object_for_explainability(model)
-            if python_base_obj is None:
-                if enable_explainability:  # if user set enable_explainability to True, throw error else silently skip
-                    raise ValueError(
-                        "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-                    )
+        python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
+        explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
+
+        if enable_explainability:
+            if explain_target_method is None:
+                raise ValueError(
+                    "The model must have one of the following methods to enable explainability: "
+                    + ", ".join(cls.EXPLAIN_TARGET_METHODS)
+                )
+        if enable_explainability is None:
+            if python_base_obj is None or explain_target_method is None:
                 # set None to False so we don't include shap in the environment
                 enable_explainability = False
             else:
-                model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-                model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
-                explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
-                model_meta = handlers_utils.add_explain_method_signature(
-                    model_meta=model_meta,
-                    explain_method="explain",
-                    target_method=explain_target_method,
-                    output_return_type=model_task_and_output_type.output_type,
-                )
                 enable_explainability = True
-
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
+        if enable_explainability:
+            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                python_base_obj, model_meta.task
+            )
+            model_meta.task = model_task_and_output_type.task
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method=explain_target_method,
+                output_return_type=model_task_and_output_type.output_type,
+            )
+            background_data = handlers_utils.get_explainability_supported_background(
+                sample_input_data, model_meta, explain_target_method
+            )
+            if background_data is not None:
+                handlers_utils.save_background_data(
+                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
                 )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
-                )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
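
The net effect of this hunk is a three-state switch on `enable_explainability`. A condensed sketch of the gating (illustrative only; the helper name and shape below are not the handler's actual code):

def resolve_enable_explainability(enable_explainability, python_base_obj, explain_target_method):
    # True: fail fast when the model has no explain target method.
    if enable_explainability and explain_target_method is None:
        raise ValueError("model lacks an explain target method")
    # None (the default): auto-enable only when both prerequisites exist.
    if enable_explainability is None:
        return python_base_obj is not None and explain_target_method is not None
    # False: skip explain entirely, so shap is not added to the environment.
    return bool(enable_explainability)
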
@@ -237,8 +266,17 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
             import shap
 
-            methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
-            for method_name in methods:
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn"]
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(raw_model, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
+                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:
                 try:
                     base_model = getattr(raw_model, method_name)()
                     explainer = shap.Explainer(base_model, masker=background_data)
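
The explain function now tries the fast `shap.TreeExplainer` for tree models before falling back to the generic masker-based `shap.Explainer`. A minimal standalone sketch of the same pattern, assuming shap and scikit-learn are installed (the model and data here are illustrative, not the handler's own):

import shap
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
model = LogisticRegression().fit(X, y)

try:
    explainer = shap.TreeExplainer(model)  # raises: not a tree-based model
except Exception:
    explainer = shap.Explainer(model, masker=X)  # generic fallback
print(type(explainer).__name__)
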
snowflake/ml/model/_packager/model_handlers/tensorflow.py
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
 
 import numpy as np
 import pandas as pd
+from packaging import version
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -73,13 +74,42 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")
 
+        # When tensorflow is installed, keras is also installed.
+        import keras
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
 
         is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "tf_keras.Model"
+            "keras.Model"
         ).isinstance(model)
+        is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
+        is_keras_functional_or_sequential_model = (
+            getattr(model, "_is_graph_network", False)
+            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
+            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
+        )
+
+        assert isinstance(model, tensorflow.Module)
+
+        keras_version = version.parse(keras.__version__)
+
+        # Tensorflow and keras model save formats differ:
+        # - Keras functional or sequential models are saved in the keras format.
+        # - Other Keras v3 models are saved using cloudpickle.
+        # - Other Keras v2 models are saved using the tensorflow saved model format.
+        # - Tensorflow models are saved using the tensorflow saved model format.
+
+        if is_keras_model or is_tf_keras_model:
+            if is_keras_functional_or_sequential_model:
+                save_format = "keras"
+            elif keras_version.major == 2 or is_tf_keras_model:
+                save_format = "keras_tf"
+            else:
+                save_format = "cloudpickle"
+        else:
+            save_format = "tf"
 
         if is_keras_model:
             default_target_methods = ["predict"]
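
The comments above encode a small decision table. A standalone sketch of the same selection logic, assuming Keras 3 is installed alongside TensorFlow (the helper name and the simplified isinstance checks are illustrative, not the handler's exact code):

import keras
import tensorflow as tf
from packaging import version

def pick_save_format(model: tf.Module) -> str:
    is_keras = isinstance(model, keras.Model)
    is_functional_or_sequential = getattr(model, "_is_graph_network", False) or isinstance(
        model, keras.Sequential
    )
    if is_keras:
        if is_functional_or_sequential:
            return "keras"  # native Keras archive
        if version.parse(keras.__version__).major == 2:
            return "keras_tf"  # TensorFlow SavedModel written via model.save
        return "cloudpickle"  # subclassed Keras v3 models are pickled instead
    return "tf"  # plain tf.Module -> tensorflow.saved_model.save

# A Sequential model resolves to the native Keras format:
seq = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
assert pick_save_format(seq) == "keras"
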
@@ -93,6 +123,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )
 
+        if is_keras_model and len(target_methods) > 1:
+            raise ValueError("Keras model can only have one target method.")
+
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
@@ -122,31 +155,43 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        if is_keras_model:
-            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
-            model_meta.env.include_if_absent(
-                [
-                    model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
-                ],
-                check_local_version=False,
-            )
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        if save_format == "keras":
+            model.save(save_path, save_format="keras")
+        elif save_format == "keras_tf":
+            model.save(save_path, save_format="tf")
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(save_path, "wb") as f:
+                cloudpickle.dump(model, f)
         else:
-            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+            tensorflow.saved_model.save(
+                model,
+                save_path,
+                options=tensorflow.saved_model.SaveOptions(experimental_custom_gradients=False),
+            )
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
+            options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        dependencies = [
+            model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+        ]
+        if is_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+        elif is_tf_keras_model:
+            dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
+
         model_meta.env.include_if_absent(
-            [
-                model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
-            ],
+            dependencies,
             check_local_version=True,
         )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
@@ -166,10 +211,18 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
-        if model_blob_options.get("is_keras_model", False):
-            m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        save_format = model_blob_options.get("save_format", "tf")
+        if save_format == "keras" or save_format == "keras_tf":
+            m = tensorflow.keras.models.load_model(load_path)
+        elif save_format == "cloudpickle":
+            import cloudpickle
+
+            with open(load_path, "rb") as f:
+                m = cloudpickle.load(f)
         else:
-            m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
+            m = tensorflow.saved_model.load(load_path)
+
         return cast(tensorflow.Module, m)
 
     @classmethod
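
Loading mirrors the save path by dispatching on the persisted `save_format` value. A minimal sketch of that dispatch (the function name is illustrative):

import cloudpickle
import tensorflow as tf

def load_blob(path: str, save_format: str):
    if save_format in ("keras", "keras_tf"):
        return tf.keras.models.load_model(path)
    if save_format == "cloudpickle":
        with open(path, "rb") as f:
            return cloudpickle.load(f)
    return tf.saved_model.load(path)  # default: TensorFlow SavedModel
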
snowflake/ml/model/_packager/model_handlers/xgboost.py
@@ -117,8 +117,8 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
+        model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
+        model_meta.task = model_task_and_output.task
         if enable_explainability:
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,
@@ -254,7 +254,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             import shap
 
             explainer = shap.TreeExplainer(raw_model)
-            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
         if target_method == "explain":
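
This swaps the `Explanation`-object call style (`explainer(X).values`) for shap's classic `shap_values` API. A minimal reproduction of the new call, assuming xgboost and shap are installed (model and data are illustrative):

import numpy as np
import shap
import xgboost

X = np.random.rand(200, 4)
y = (X[:, 0] > 0.5).astype(int)
model = xgboost.XGBClassifier(n_estimators=10).fit(X, y)

explainer = shap.TreeExplainer(model)
values = explainer.shap_values(X)  # per-sample, per-feature attributions
print(np.asarray(values).shape)
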
snowflake/ml/model/_packager/model_meta/model_meta.py
@@ -215,6 +215,7 @@ class ModelMetadata:
         function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         creation_timestamp: Unix timestamp when the model metadata is created.
+        user_files: Dict mapping subdirectories to extra artifact file paths for files to include in the model.
         task: Model task like TABULAR_REGRESSION, tabular_classification, timeseries_forecasting etc.
     """
 
@@ -234,6 +235,7 @@ class ModelMetadata:
         runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
         function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         metadata: Optional[Dict[str, str]] = None,
         creation_timestamp: Optional[str] = None,
         min_snowpark_ml_version: Optional[str] = None,
@@ -247,6 +249,7 @@ class ModelMetadata:
         if signatures:
             self.signatures = signatures
         self.function_properties = function_properties or {}
+        self.user_files = user_files
         self.metadata = metadata
         self.model_type = model_type
         self.env = env
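
Per the docstring added above, the new `user_files` field maps subdirectory names to lists of local file paths to bundle with the model. A hypothetical value (the paths are illustrative):

user_files = {
    "docs": ["./local/README.md"],
    "configs": ["./local/inference_config.yml"],
}
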
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
@@ -59,7 +59,11 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
 
 
 class TensorflowModelBlobOptions(BaseModelBlobOptions):
-    is_keras_model: Required[bool]
+    save_format: Required[str]
+
+
+class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
+    batch_size: Required[int]
 
 
 ModelBlobOptions = Union[
  ModelBlobOptions = Union[
@@ -68,6 +72,7 @@ ModelBlobOptions = Union[
68
72
  MLFlowModelBlobOptions,
69
73
  XgboostModelBlobOptions,
70
74
  TensorflowModelBlobOptions,
75
+ SentenceTransformersModelBlobOptions,
71
76
  ]
72
77
 
73
78
 
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_task/model_task_utils.py
@@ -149,8 +149,9 @@ def _get_model_task(model: Any) -> type_hints.Task:
     raise ValueError(f"Model type {type(model)} is not supported")
 
 
-def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
-    task = _get_model_task(model)
+def resolve_model_task_and_output_type(model: Any, passed_model_task: type_hints.Task) -> ModelTaskAndOutputType:
+    inferred_task = _get_model_task(model)
+    task = handlers_utils.validate_model_task(passed_model_task, inferred_task)
     output_type = model_signature.DataType.DOUBLE
     if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
         output_type = model_signature.DataType.STRING
snowflake/ml/model/_signatures/pandas_handler.py
@@ -224,6 +224,6 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         df_col_dtypes = [df[col].dtype for col in df.columns]
         for df_col, df_col_dtype in zip(df_cols, df_col_dtypes):
             if df_col_dtype == np.dtype("O"):
-                if isinstance(df[df_col][0], np.ndarray):
+                if isinstance(df[df_col].iloc[0], np.ndarray):
                     df[df_col] = df[df_col].map(np.ndarray.tolist)
         return df
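
The fix matters because `df[df_col][0]` is label-based and breaks when the DataFrame does not carry the default RangeIndex, while `.iloc[0]` is positional and always reads the first row. A quick illustration:

import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [np.array([1, 2]), np.array([3, 4])]}, index=[7, 9])
print(isinstance(df["a"].iloc[0], np.ndarray))  # True
# df["a"][0] would raise KeyError: 0 on this non-default index
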
snowflake/ml/model/_signatures/snowpark_handler.py
@@ -1,5 +1,5 @@
 import json
-from typing import Literal, Optional, Sequence, cast
+from typing import Any, Literal, Optional, Sequence, cast
 
 import numpy as np
 import pandas as pd
@@ -73,14 +73,20 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             dtype_map[feature.name] = feature.as_dtype()
         df_local = data.to_pandas()
+
         # This is because Array will become string (Even though the correct schema is set)
         # and object will become variant type and requires an additional loads
         # to get correct data otherwise it would be string.
+        def load_if_not_null(x: str) -> Optional[Any]:
+            if x is None:
+                return None
+            return json.loads(x)
+
         for field in data.schema.fields:
             if isinstance(field.datatype, spt.ArrayType):
                 df_local[identifier.get_unescaped_names(field.name)] = df_local[
                     identifier.get_unescaped_names(field.name)
-                ].map(json.loads)
+                ].map(load_if_not_null)
         # Only when the feature is not from inference, we are confident to do the type casting.
         # Otherwise, dtype_map will be empty.
         # Errors are ignored to make sure None won't be converted and won't raise Error
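
The guard exists because `json.loads(None)` raises a TypeError, so SQL NULLs surfaced by `to_pandas()` must pass through unchanged. The helper's behavior in isolation:

import json
from typing import Any, Optional

def load_if_not_null(x: Optional[str]) -> Optional[Any]:
    return None if x is None else json.loads(x)

print(load_if_not_null("[1, 2, 3]"))  # [1, 2, 3]
print(load_if_not_null(None))         # None
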
snowflake/ml/model/_signatures/utils.py
@@ -118,7 +118,6 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
         category=DeprecationWarning,
         stacklevel=1,
     )
-
     return core.ModelSignature(
         inputs=[
             core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
snowflake/ml/model/type_hints.py
@@ -199,6 +199,7 @@ class HuggingFaceSaveOptions(BaseModelSaveOption):
 class SentenceTransformersSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    batch_size: NotRequired[int]
 
 
 ModelSaveOption = Union[
snowflake/ml/modeling/_internal/model_trainer_builder.py
@@ -1,11 +1,9 @@
-import os
 from typing import List, Optional, Union
 
 import pandas as pd
 from sklearn import model_selection
 
 from snowflake.ml._internal.exceptions import error_codes, exceptions
-from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
 from snowflake.ml.modeling._internal.estimator_utils import (
     get_module_name,
     is_single_node,
@@ -13,9 +11,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
 from snowflake.ml.modeling._internal.local_implementations.pandas_trainer import (
     PandasModelTrainer,
 )
-from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_trainer import (
-    MLRuntimeModelTrainer,
-)
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.snowpark_implementations.distributed_hpo_trainer import (
     DistributedHPOTrainer,
@@ -107,9 +102,6 @@ class ModelTrainerBuilder:
             "autogenerated": autogenerated,
             "subproject": subproject,
         }
-        if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
-            return MLRuntimeModelTrainer(**init_args)  # type: ignore[arg-type, return-value]
-
         trainer_klass = SnowparkModelTrainer
 
         assert dataset._session is not None  # Make MyPy happy
snowflake/ml/modeling/_internal/model_transformer_builder.py
@@ -1,16 +1,11 @@
-import os
 from typing import Optional, Union
 
 import pandas as pd
 
 from snowflake import snowpark
-from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
 from snowflake.ml.modeling._internal.local_implementations.pandas_handlers import (
     PandasTransformHandlers,
 )
-from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_handlers import (
-    MLRuntimeTransformHandlers,
-)
 from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
     SnowparkTransformHandlers,
 )
@@ -63,14 +58,6 @@ class ModelTransformerBuilder:
             )
 
         elif isinstance(dataset, snowpark.DataFrame):
-            if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
-                return MLRuntimeTransformHandlers(
-                    dataset=dataset,
-                    estimator=estimator,
-                    class_name=class_name,
-                    subproject=subproject,
-                    autogenerated=autogenerated,
-                )
             return SnowparkTransformHandlers(
                 dataset=dataset,
                 estimator=estimator,