snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  8. snowflake/ml/data/__init__.py +5 -0
  9. snowflake/ml/model/_client/model/model_version_impl.py +26 -12
  10. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  11. snowflake/ml/model/_client/ops/service_ops.py +25 -9
  12. snowflake/ml/model/_client/sql/model.py +0 -14
  13. snowflake/ml/model/_client/sql/service.py +25 -1
  14. snowflake/ml/model/_client/sql/stage.py +1 -1
  15. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  16. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  17. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  18. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  19. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  20. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
  22. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  23. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  24. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  25. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  28. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  29. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  30. snowflake/ml/model/_signatures/core.py +63 -16
  31. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  32. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  33. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  34. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  35. snowflake/ml/model/_signatures/utils.py +4 -1
  36. snowflake/ml/model/model_signature.py +38 -9
  37. snowflake/ml/model/type_hints.py +1 -1
  38. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  39. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  40. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
  41. snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
  42. snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
  43. snowflake/ml/monitoring/model_monitor.py +12 -86
  44. snowflake/ml/registry/registry.py +28 -40
  45. snowflake/ml/utils/authentication.py +75 -0
  46. snowflake/ml/version.py +1 -1
  47. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
  50. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  51. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  52. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
  53. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py

@@ -6,7 +6,7 @@ import re
 import tempfile
 import threading
 import time
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast

 from packaging import version

@@ -15,7 +15,7 @@ from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import exceptions, row, session
+from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils

 module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -107,8 +107,20 @@ class ServiceOperator:
         max_batch_rows: Optional[int],
         force_rebuild: bool,
         build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+        block: bool,
         statement_params: Optional[Dict[str, Any]] = None,
-    ) -> str:
+    ) -> Union[str, async_job.AsyncJob]:
+
+        # Fall back to the registry's database and schema if not provided
+        database_name = database_name or self._database_name
+        schema_name = schema_name or self._schema_name
+
+        # Fall back to the model's database and schema if not provided then to the registry's database and schema
+        service_database_name = service_database_name or database_name or self._database_name
+        service_schema_name = service_schema_name or schema_name or self._schema_name
+
+        image_repo_database_name = image_repo_database_name or database_name or self._database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
         # create a temp stage
         stage_name = sql_identifier.SqlIdentifier(
             snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -130,8 +142,8 @@ class ServiceOperator:
             raise ValueError("External access integrations are required in Snowflake < 8.40.0.")

         self._model_deployment_spec.save(
-            database_name=database_name or self._database_name,
-            schema_name=schema_name or self._schema_name,
+            database_name=database_name,
+            schema_name=schema_name,
             model_name=model_name,
             version_name=version_name,
             service_database_name=service_database_name,
@@ -193,11 +205,15 @@ class ServiceOperator:
         log_thread = self._start_service_log_streaming(
             async_job, services, model_inference_service_exists, force_rebuild, statement_params
         )
-        log_thread.join()

-        res = cast(str, cast(List[row.Row], async_job.result())[0][0])
-        module_logger.info(f"Inference service {service_name} deployment complete: {res}")
-        return res
+        if block:
+            log_thread.join()
+
+            res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+            module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+            return res
+        else:
+            return async_job

     def _start_service_log_streaming(
         self,
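The new block flag is what enables non-blocking deployments. A hedged usage sketch at the public API level, assuming ModelVersion.create_service forwards block and treating the pool and repository names below as placeholders:

from snowflake.snowpark import async_job

mv = registry.get_model("MY_MODEL").version("V1")  # registry: an existing snowflake.ml.registry.Registry
result = mv.create_service(
    service_name="MY_INFERENCE_SERVICE",
    service_compute_pool="SERVE_POOL",       # placeholder compute pool
    image_repo="MY_IMAGE_REPO",              # placeholder image repository
    block=False,                             # return immediately instead of streaming build/deploy logs
)
if isinstance(result, async_job.AsyncJob):
    # Non-blocking path: the Snowpark AsyncJob can be polled or waited on explicitly.
    print(result.is_done())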
snowflake/ml/model/_client/sql/model.py

@@ -17,8 +17,6 @@ class ModelSQLClient(_base._BaseSQLClient):
     MODEL_VERSION_ALIASES_COL_NAME = "aliases"
     MODEL_VERSION_INFERENCE_SERVICES_COL_NAME = "inference_services"

-    MODEL_INFERENCE_SERVICE_ENDPOINT_COL_NAME = "name"
-
     def show_models(
         self,
         *,
@@ -85,18 +83,6 @@ class ModelSQLClient(_base._BaseSQLClient):

         return res.validate()

-    def show_endpoints(
-        self,
-        *,
-        service_name: str,
-    ) -> List[row.Row]:
-        res = query_result_checker.SqlResultValidator(
-            self._session,
-            (f"SHOW ENDPOINTS IN SERVICE {service_name}"),
-        ).has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
-
-        return res.validate()
-
     def set_comment(
         self,
         *,
snowflake/ml/model/_client/sql/service.py

@@ -10,7 +10,7 @@ from snowflake.ml._internal.utils import (
     sql_identifier,
 )
 from snowflake.ml.model._client.sql import _base
-from snowflake.snowpark import dataframe, functions as F, types as spt
+from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils


@@ -26,6 +26,9 @@ class ServiceStatus(enum.Enum):


 class ServiceSQLClient(_base._BaseSQLClient):
+    MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
+    MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+
     def build_model_container(
         self,
         *,
@@ -216,3 +219,24 @@ class ServiceSQLClient(_base._BaseSQLClient):
             f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_endpoints(
+        self,
+        *,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_name: sql_identifier.SqlIdentifier,
+        statement_params: Optional[Dict[str, Any]] = None,
+    ) -> List[row.Row]:
+        fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        res = (
+            query_result_checker.SqlResultValidator(
+                self._session,
+                (f"SHOW ENDPOINTS IN SERVICE {fully_qualified_service_name}"),
+                statement_params=statement_params,
+            )
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
+            .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
+        )
+
+        return res.validate()
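A brief sketch of calling the relocated show_endpoints, assuming an existing ServiceSQLClient instance; passing None for database and schema is assumed to fall back to the client's defaults, as elsewhere in this SQL client:

rows = service_sql_client.show_endpoints(
    database_name=None,
    schema_name=None,
    service_name=sql_identifier.SqlIdentifier("MY_INFERENCE_SERVICE"),
)
ingress_urls = [
    r[ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] for r in rows
]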
snowflake/ml/model/_client/sql/stage.py

@@ -15,6 +15,6 @@ class StageSQLClient(_base._BaseSQLClient):
     ) -> None:
         query_result_checker.SqlResultValidator(
             self._session,
-            f"CREATE TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
+            f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
             statement_params=statement_params,
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
snowflake/ml/model/_model_composer/model_method/infer_function.py_template

@@ -5,6 +5,7 @@ import sys

 import anyio
 import pandas as pd
+import numpy as np
 from _snowflake import vectorized

 from snowflake.ml.model._packager import model_packager
@@ -47,4 +48,4 @@ def {function_name}(df: pd.DataFrame) -> dict:
     df.columns = input_cols
     input_df = df.astype(dtype=dtype_map)
     predictions_df = runner(input_df[input_cols])
-    return predictions_df.to_dict("records")
+    return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
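The template now maps missing values to None before building the records payload, so they come back as NULLs rather than NaN. The effect, shown in plain pandas outside the template (the doubled braces above only escape literal braces for the template's format step):

import numpy as np
import pandas as pd

predictions_df = pd.DataFrame({"output": [1.0, np.nan, 3.0]})
records = predictions_df.replace({pd.NA: None, np.nan: None}).to_dict("records")
# [{'output': 1.0}, {'output': None}, {'output': 3.0}]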
snowflake/ml/model/_packager/model_env/model_env.py

@@ -174,6 +174,18 @@ class ModelEnv:
             except env_utils.DuplicateDependencyError:
                 pass

+    def remove_if_present_conda(self, conda_pkgs: List[str]) -> None:
+        """Remove conda requirements from model env if present.
+
+        Args:
+            conda_pkgs: A list of package name to be removed from conda requirements.
+        """
+        for pkg_name in conda_pkgs:
+            spec_conda = env_utils._find_conda_dep_spec(self._conda_dependencies, pkg_name)
+            if spec_conda:
+                channel, spec = spec_conda
+                self._conda_dependencies[channel].remove(spec)
+
     def generate_env_for_cuda(self) -> None:
         if self.cuda_version is None:
             return
snowflake/ml/model/_packager/model_handlers/_utils.py

@@ -179,7 +179,7 @@ def convert_explanations_to_2D_df(
         return pd.DataFrame(explanations)

     if hasattr(model, "classes_"):
-        classes_list = [str(cl) for cl in model.classes_]  # type:ignore[union-attr]
+        classes_list = [str(cl) for cl in model.classes_]
         len_classes = len(classes_list)
         if explanations.shape[2] != len_classes:
             raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
snowflake/ml/model/_packager/model_handlers/catboost.py

@@ -95,7 +95,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
                 get_prediction_fn=get_prediction,
             )
         model_task_and_output = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
         if enable_explainability:
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
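With validate_model_task, a task declared at log time is reconciled with the task inferred from the model object itself. A hedged illustration at the registry level (names are placeholders; a fitted CatBoost classifier and a feature DataFrame are assumed):

from snowflake.ml.model import type_hints

reg.log_model(
    model=catboost_classifier,                           # assumed: a fitted catboost.CatBoostClassifier
    model_name="CHURN_MODEL",
    version_name="V2",
    sample_input_data=train_df.head(100),                # assumed: a pandas DataFrame of features
    task=type_hints.Task.TABULAR_BINARY_CLASSIFICATION,  # explicit task, checked against the inferred one
)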
snowflake/ml/model/_packager/model_handlers/custom.py

@@ -2,7 +2,7 @@ import inspect
 import os
 import pathlib
 import sys
-from typing import Dict, Optional, Type, final
+from typing import Dict, Optional, Type, cast, final

 import anyio
 import cloudpickle
@@ -108,6 +108,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                     model_meta=model_meta,
                     model_blobs_dir_path=model_blobs_dir_path,
                     is_sub_model=True,
+                    **cast(model_types.BaseModelSaveOption, kwargs),
                 )

         # Make sure that the module where the model is defined get pickled by value as well.
@@ -175,6 +176,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 name=sub_model_name,
                 model_meta=model_meta,
                 model_blobs_dir_path=model_blobs_dir_path,
+                **cast(model_types.BaseModelLoadOption, kwargs),
             )
             models[sub_model_name] = sub_model
         reconstructed_context = custom_model.ModelContext(artifacts=artifacts, models=models)
snowflake/ml/model/_packager/model_handlers/lightgbm.py

@@ -196,13 +196,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         with open(model_blob_file_path, "rb") as f:
             model = cloudpickle.load(f)
         assert isinstance(model, getattr(lightgbm, lightgbm_estimator_type))
+        assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)

         return model

     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
+        raw_model: Union["lightgbm.Booster", "lightgbm.LGBMModel"],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -19,12 +19,26 @@ from snowflake.ml.model._packager.model_meta import (
 )
 from snowflake.ml.model._packager.model_task import model_task_utils
 from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR

 if TYPE_CHECKING:
     import sklearn.base
     import sklearn.pipeline


+def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
+    new_steps = []
+    for step_name, step in model.steps:
+        new_reg = step
+        if hasattr(step, "_sklearn_estimator") and step._sklearn_estimator is not None:
+            # Unpack estimator to open source.
+            new_reg = step._sklearn_estimator
+        new_steps.append((step_name, new_reg))
+
+    model.steps = new_steps
+    return model
+
+
 @final
 class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
     """Handler for scikit-learn based model.
@@ -101,6 +115,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             if sample_input_data is None:
                 raise ValueError("Sample input data is required to enable explainability.")

+        # If this is a pipeline and we are in the container runtime, check for distributed estimator.
+        if os.getenv(IN_ML_RUNTIME_ENV_VAR) and isinstance(model, sklearn.pipeline.Pipeline):
+            model = _unpack_container_runtime_pipeline(model)
+
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
@@ -135,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             )

         model_task_and_output_type = model_task_utils.get_model_task_and_output_type(model)
-        model_meta.task = model_task_and_output_type.task
+        model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)

         # if users did not ask then we enable if we have background data
         if enable_explainability is None:
@@ -146,6 +164,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     stacklevel=1,
                 )
                 enable_explainability = False
+            elif model_meta.task == model_types.Task.UNKNOWN:
+                enable_explainability = False
             else:
                 enable_explainability = True
         if enable_explainability:
@@ -177,6 +197,35 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

+        # if model instance is a pipeline, check the pipeline steps
+        if isinstance(model, sklearn.pipeline.Pipeline):
+            for _, pipeline_step in model.steps:
+                if type_utils.LazyType("lightgbm.LGBMModel").isinstance(pipeline_step) or type_utils.LazyType(
+                    "lightgbm.Booster"
+                ).isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="lightgbm", pip_name="lightgbm"),
+                        ],
+                        check_local_version=True,
+                    )
+                elif type_utils.LazyType("xgboost.XGBModel").isinstance(pipeline_step) or type_utils.LazyType(
+                    "xgboost.Booster"
+                ).isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
+                        ],
+                        check_local_version=True,
+                    )
+                elif type_utils.LazyType("catboost.CatBoost").isinstance(pipeline_step):
+                    model_meta.env.include_if_absent(
+                        [
+                            model_env.ModelDependency(requirement="catboost", pip_name="catboost"),
+                        ],
+                        check_local_version=True,
+                    )
+
         if enable_explainability:
             model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
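One visible effect of the new pipeline-step scan: logging a scikit-learn Pipeline that wraps a LightGBM, XGBoost, or CatBoost estimator should record the matching package as a model dependency without the caller listing it. A sketch under those assumptions (training data and the registry handle are placeholders):

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier

pipe = Pipeline([("scale", StandardScaler()), ("clf", XGBClassifier())])
pipe.fit(X_train, y_train)          # X_train / y_train: assumed training data

mv = reg.log_model(
    model=pipe,
    model_name="PIPELINE_MODEL",
    version_name="V1",
    sample_input_data=X_train[:100],
)
# The logged environment is expected to include an xgboost requirement picked up from the pipeline step.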
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

@@ -138,7 +138,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
             enable_explainability = False
         else:
             model_task_and_output_type = model_task_utils.get_model_task_and_output_type(python_base_obj)
-            model_meta.task = model_task_and_output_type.task
+            model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output_type.task)
             explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
             model_meta = handlers_utils.add_explain_method_signature(
                 model_meta=model_meta,
snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -13,6 +13,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
 )
 from snowflake.ml.model._signatures import (
     numpy_handler,
@@ -76,7 +77,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):

         assert isinstance(model, tensorflow.Module)

-        if isinstance(model, tensorflow.keras.Model):
+        is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
+            "tf_keras.Model"
+        ).isinstance(model)
+
+        if is_keras_model:
             default_target_methods = ["predict"]
         else:
             default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -117,8 +122,14 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        if isinstance(model, tensorflow.keras.Model):
+        if is_keras_model:
             tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
+            model_meta.env.include_if_absent(
+                [
+                    model_env.ModelDependency(requirement="keras<3", pip_name="keras"),
+                ],
+                check_local_version=False,
+            )
         else:
             tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

@@ -127,12 +138,16 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=model_meta_schema.TensorflowModelBlobOptions(is_keras_model=is_keras_model),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

         model_meta.env.include_if_absent(
-            [model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow")], check_local_version=True
+            [
+                model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
+            ],
+            check_local_version=True,
         )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

@@ -150,9 +165,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_metadata = model_meta.models
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
-        m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
-        if isinstance(m, tensorflow.keras.Model):
-            return m
+        model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
+        if model_blob_options.get("is_keras_model", False):
+            m = tensorflow.keras.models.load_model(os.path.join(model_blob_path, model_blob_filename), compile=False)
+        else:
+            m = tensorflow.saved_model.load(os.path.join(model_blob_path, model_blob_filename))
         return cast(tensorflow.Module, m)

     @classmethod
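The Keras branch is now decided by the recorded is_keras_model option rather than by attempting a Keras load. A brief, hedged sketch of logging a Keras 2.x model (the handler pins keras<3); names and the sample data frame are placeholders:

import tensorflow as tf

keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
keras_model.compile(optimizer="adam", loss="mse")

mv = reg.log_model(
    model=keras_model,
    model_name="KERAS_REGRESSOR",
    version_name="V1",
    sample_input_data=sample_df,    # assumed: a small pandas DataFrame with 4 numeric feature columns
)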
snowflake/ml/model/_packager/model_handlers/torchscript.py

@@ -23,7 +23,7 @@ if TYPE_CHECKING:


 @final
-class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # type:ignore[name-defined]
+class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
     """Handler for PyTorch JIT based model.

     Currently torch.jit.ScriptModule based classes are supported.
@@ -41,25 +41,25 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
     def can_handle(
         cls,
         model: model_types.SupportedModelType,
-    ) -> TypeGuard["torch.jit.ScriptModule"]:  # type:ignore[name-defined]
+    ) -> TypeGuard["torch.jit.ScriptModule"]:
         return type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)

     @classmethod
     def cast_model(
         cls,
         model: model_types.SupportedModelType,
-    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
+    ) -> "torch.jit.ScriptModule":
         import torch

-        assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+        assert isinstance(model, torch.jit.ScriptModule)

-        return cast(torch.jit.ScriptModule, model)  # type:ignore[name-defined]
+        return cast(torch.jit.ScriptModule, model)

     @classmethod
     def save_model(
         cls,
         name: str,
-        model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+        model: "torch.jit.ScriptModule",
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -72,7 +72,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t

         import torch

-        assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+        assert isinstance(model, torch.jit.ScriptModule)

         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
-            torch.jit.save(model, f)  # type:ignore[no-untyped-call, attr-defined]
+            torch.jit.save(model, f)  # type:ignore[no-untyped-call]
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
@@ -133,7 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
-    ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
+    ) -> "torch.jit.ScriptModule":
         import torch

         model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -141,10 +141,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         model_blob_metadata = model_blobs_metadata[name]
         model_blob_filename = model_blob_metadata.path
         with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
-            m = torch.jit.load(  # type:ignore[no-untyped-call, attr-defined]
+            m = torch.jit.load(  # type:ignore[no-untyped-call]
                 f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
             )
-        assert isinstance(m, torch.jit.ScriptModule)  # type:ignore[attr-defined]
+        assert isinstance(m, torch.jit.ScriptModule)

         if kwargs.get("use_gpu", False):
             m = m.cuda()
@@ -154,7 +154,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+        raw_model: "torch.jit.ScriptModule",
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
@@ -162,11 +162,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
         from snowflake.ml.model import custom_model

         def _create_custom_model(
-            raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+            raw_model: "torch.jit.ScriptModule",
             model_meta: model_meta_api.ModelMetadata,
         ) -> Type[custom_model.CustomModel]:
             def fn_factory(
-                raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
+                raw_model: "torch.jit.ScriptModule",
                 signature: model_signature.ModelSignature,
                 target_method: str,
             ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
snowflake/ml/model/_packager/model_meta/_packaging_requirements.py

@@ -1,3 +1,2 @@
-REQUIREMENTS = [
-    "cloudpickle>=2.0.0"
-]
+REQUIREMENTS = ['cloudpickle>=2.0.0']
+ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -58,11 +58,16 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
     xgb_estimator_type: Required[str]


+class TensorflowModelBlobOptions(BaseModelBlobOptions):
+    is_keras_model: Required[bool]
+
+
 ModelBlobOptions = Union[
     BaseModelBlobOptions,
     HuggingFacePipelineModelBlobOptions,
     MLFlowModelBlobOptions,
     XgboostModelBlobOptions,
+    TensorflowModelBlobOptions,
 ]
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -1,10 +1,2 @@
-REQUIREMENTS = [
-    "absl-py>=0.15,<2",
-    "anyio>=3.5.0,<4",
-    "numpy>=1.23,<2",
-    "packaging>=20.9,<24",
-    "pandas>=1.0.0,<3",
-    "pyyaml>=6.0,<7",
-    "snowflake-snowpark-python>=1.17.0,<2",
-    "typing-extensions>=4.1.0,<5"
-]
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2022.11,<2024', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.1.0,<2.4', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2022.11,<2024', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.10,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_runtime/model_runtime.py

@@ -17,6 +17,8 @@ _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES = [
     for r in _snowml_inference_alternative_requirements.REQUIREMENTS
 ]

+PACKAGES_NOT_ALLOWED_IN_WAREHOUSE = ["snowflake-connector-python", "pyarrow"]
+

 class ModelRuntime:
     """Class to represent runtime in a model, which controls the runtime and version, imports and dependencies.
@@ -61,15 +63,8 @@ class ModelRuntime:
                 ],
             )

-        if not is_warehouse and self.embed_local_ml_library:
-            self.runtime_env.include_if_absent(
-                [
-                    model_env.ModelDependency(
-                        requirement="pyarrow",
-                        pip_name="pyarrow",
-                    )
-                ],
-            )
+        if is_warehouse and self.embed_local_ml_library:
+            self.runtime_env.remove_if_present_conda(PACKAGES_NOT_ALLOWED_IN_WAREHOUSE)

         if is_gpu:
             self.runtime_env.generate_env_for_cuda()
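For the warehouse path, packages that cannot be installed in a warehouse are now stripped from the conda requirements, rather than pyarrow being added for non-warehouse runtimes. A minimal sketch of the new helper in isolation, assuming ModelEnv accepts requirement strings through its conda_dependencies setter:

from snowflake.ml.model._packager.model_env import model_env

env = model_env.ModelEnv()
env.conda_dependencies = ["snowflake-connector-python>=3.5.0", "pyarrow", "pandas>=1.0"]
env.remove_if_present_conda(["snowflake-connector-python", "pyarrow"])
print(env.conda_dependencies)    # expected: only the pandas requirement remains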
snowflake/ml/model/_packager/model_task/model_task_utils.py

@@ -84,7 +84,7 @@ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel
     if type_utils.LazyType("lightgbm.Booster").isinstance(model):
         model_task = model.params["objective"]  # type: ignore[attr-defined]
     elif hasattr(model, "objective_"):
-        model_task = model.objective_
+        model_task = model.objective_  # type: ignore[assignment]
     if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
         return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
     if model_task in _MULTI_CLASSIFICATION_OBJECTIVES: