snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -280,7 +280,7 @@ def _get_or_create_image_repo(session: Session, *, service_func_name: str, image
280
280
  conn = session._conn._conn
281
281
  # We try to use the same db and schema where the service function is located, since we can retrieve that information
282
282
  # if that is a fully qualified one. If not we use the current session one.
283
- (_db, _schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name)
283
+ (_db, _schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
284
284
  db = _db if _db is not None else conn._database
285
285
  schema = _schema if _schema is not None else conn._schema
286
286
  assert isinstance(db, str) and isinstance(schema, str)
@@ -343,7 +343,7 @@ class SnowServiceDeployment:
343
343
  self.model_zip_stage_path = model_zip_stage_path
344
344
  self.options = options
345
345
  self.target_method = target_method
346
- (db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name)
346
+ (db, schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
347
347
 
348
348
  self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
349
349
  self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
@@ -503,7 +503,7 @@ class SnowServiceDeployment:
503
503
  norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@"))
504
504
  # Ensure model stage path has root prefix, as the stage mount will mount it to root.
505
505
  absolute_model_stage_path = os.path.join("/", norm_stage_path)
506
- (db, schema, stage, path) = identifier.parse_schema_level_object_identifier(norm_stage_path)
506
+ (db, schema, stage, path) = identifier.parse_snowflake_stage_path(norm_stage_path)
507
507
  substitutes = {
508
508
  "image": image,
509
509
  "predict_endpoint_name": constants.PREDICT,
@@ -10,6 +10,7 @@ from absl import logging
10
10
  from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
+ from snowflake import snowpark
13
14
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
15
  from snowflake.ml._internal.lineage import lineage_utils
15
16
  from snowflake.ml.data import data_source
@@ -91,6 +92,7 @@ class ModelComposer:
91
92
  python_version: Optional[str] = None,
92
93
  ext_modules: Optional[List[ModuleType]] = None,
93
94
  code_paths: Optional[List[str]] = None,
95
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
94
96
  options: Optional[model_types.ModelSaveOption] = None,
95
97
  ) -> model_meta.ModelMetadata:
96
98
  if not options:
@@ -119,6 +121,7 @@ class ModelComposer:
119
121
  python_version=python_version,
120
122
  ext_modules=ext_modules,
121
123
  code_paths=code_paths,
124
+ model_objective=model_objective,
122
125
  options=options,
123
126
  )
124
127
  assert self.packager.meta is not None
@@ -185,4 +188,6 @@ class ModelComposer:
185
188
  data_sources = lineage_utils.get_data_sources(model)
186
189
  if not data_sources and sample_input_data is not None:
187
190
  data_sources = lineage_utils.get_data_sources(sample_input_data)
191
+ if not data_sources and isinstance(sample_input_data, snowpark.DataFrame):
192
+ data_sources = [data_source.DataFrameInfo(sample_input_data.queries["queries"][-1])]
188
193
  return data_sources
@@ -1,11 +1,11 @@
1
1
  import collections
2
2
  import copy
3
3
  import pathlib
4
- import warnings
5
4
  from typing import List, Optional, cast
6
5
 
7
6
  import yaml
8
7
 
8
+ from snowflake.ml._internal import env_utils
9
9
  from snowflake.ml.data import data_source
10
10
  from snowflake.ml.model import type_hints
11
11
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -47,7 +47,9 @@ class ModelManifest:
47
47
  runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
48
48
  runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
49
49
  runtime_to_use.imports.append(str(model_rel_path) + "/")
50
- runtime_dict = runtime_to_use.save(self.workspace_path)
50
+ runtime_dict = runtime_to_use.save(
51
+ self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
52
+ )
51
53
 
52
54
  self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
53
55
  self.methods: List[model_method.ModelMethod] = []
@@ -75,13 +77,9 @@ class ModelManifest:
75
77
  )
76
78
 
77
79
  dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
78
- if options.get("include_pip_dependencies"):
79
- warnings.warn(
80
- "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
81
- "be warehouse-compabible. The model may need to be run in SPCS.",
82
- category=UserWarning,
83
- stacklevel=1,
84
- )
80
+
81
+ # We only want to include pip dependencies file if there are any pip requirements.
82
+ if len(model_meta.env.pip_requirements) > 0:
85
83
  dependencies["pip"] = runtime_dict["dependencies"]["pip"]
86
84
 
87
85
  manifest_dict = model_manifest_schema.ModelManifestDict(
@@ -137,10 +135,15 @@ class ModelManifest:
137
135
  if isinstance(source, data_source.DatasetInfo):
138
136
  result.append(
139
137
  model_manifest_schema.LineageSourceDict(
140
- # Currently, we only support lineage from Dataset.
141
138
  type=model_manifest_schema.LineageSourceTypes.DATASET.value,
142
139
  entity=source.fully_qualified_name,
143
140
  version=source.version,
144
141
  )
145
142
  )
143
+ elif isinstance(source, data_source.DataFrameInfo):
144
+ result.append(
145
+ model_manifest_schema.LineageSourceDict(
146
+ type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
147
+ )
148
+ )
146
149
  return result
@@ -57,12 +57,14 @@ class ModelFunctionInfo(TypedDict):
57
57
  target_method: actual target method name to be called.
58
58
  target_method_function_type: target method function type (FUNCTION or TABLE_FUNCTION).
59
59
  signature: The signature of the model method.
60
+ is_partitioned: Whether the function is partitioned.
60
61
  """
61
62
 
62
63
  name: Required[str]
63
64
  target_method: Required[str]
64
65
  target_method_function_type: Required[str]
65
66
  signature: Required[model_signature.ModelSignature]
67
+ is_partitioned: Required[bool]
66
68
 
67
69
 
68
70
  class ModelFunctionInfoDict(TypedDict):
@@ -78,6 +80,7 @@ class SnowparkMLDataDict(TypedDict):
78
80
 
79
81
  class LineageSourceTypes(enum.Enum):
80
82
  DATASET = "DATASET"
83
+ QUERY = "QUERY"
81
84
 
82
85
 
83
86
  class LineageSourceDict(TypedDict):
@@ -363,9 +363,14 @@ class ModelEnv:
363
363
  self.cuda_version = env_dict.get("cuda_version", None)
364
364
  self.snowpark_ml_version = env_dict["snowpark_ml_version"]
365
365
 
366
- def save_as_dict(self, base_dir: pathlib.Path) -> model_meta_schema.ModelEnvDict:
366
+ def save_as_dict(
367
+ self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
368
+ ) -> model_meta_schema.ModelEnvDict:
367
369
  env_utils.save_conda_env_file(
368
- pathlib.Path(base_dir / self.conda_env_rel_path), self._conda_dependencies, self.python_version
370
+ pathlib.Path(base_dir / self.conda_env_rel_path),
371
+ self._conda_dependencies,
372
+ self.python_version,
373
+ default_channel_override=default_channel_override,
369
374
  )
370
375
  env_utils.save_requirements_file(
371
376
  pathlib.Path(base_dir / self.pip_requirements_rel_path), self._pip_requirements
@@ -1,7 +1,8 @@
1
+ import os
1
2
  from abc import abstractmethod
2
- from enum import Enum
3
3
  from typing import Dict, Generic, Optional, Protocol, Type, final
4
4
 
5
+ import pandas as pd
5
6
  from typing_extensions import TypeGuard, Unpack
6
7
 
7
8
  from snowflake.ml.model import custom_model, type_hints as model_types
@@ -9,15 +10,6 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
9
10
  from snowflake.ml.model._packager.model_meta import model_meta
10
11
 
11
12
 
12
- class ModelObjective(Enum):
13
- # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
14
- UNKNOWN = "unknown"
15
- BINARY_CLASSIFICATION = "binary_classification"
16
- MULTI_CLASSIFICATION = "multi_classification"
17
- REGRESSION = "regression"
18
- RANKING = "ranking"
19
-
20
-
21
13
  class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
22
14
  HANDLER_TYPE: model_types.SupportedModelHandlerType
23
15
  HANDLER_VERSION: str
@@ -106,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
106
98
  cls,
107
99
  raw_model: model_types._ModelType,
108
100
  model_meta: model_meta.ModelMetadata,
101
+ background_data: Optional[pd.DataFrame] = None,
109
102
  **kwargs: Unpack[model_types.BaseModelLoadOption],
110
103
  ) -> custom_model.CustomModel:
111
104
  """Create a custom model class wrap for unified interface when being deployed. The predict method will be
@@ -114,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
114
107
  Args:
115
108
  raw_model: original model object,
116
109
  model_meta: The model metadata.
110
+ background_data: The background data used for the model explanations.
117
111
  kwargs: Options when converting the model.
118
112
 
119
113
  Raises:
@@ -131,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
131
125
  _MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
132
126
  _HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.
133
127
 
134
- MODELE_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
128
+ MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
129
+ BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
135
130
  MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
136
131
  DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
137
132
  ["predict"]
@@ -139,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
139
134
  inputting sample data or model signature. Default to False.
140
135
  """
141
136
 
142
- MODELE_BLOB_FILE_OR_DIR = "model.pkl"
137
+ MODEL_BLOB_FILE_OR_DIR = "model.pkl"
138
+ BG_DATA_FILE_SUFFIX = "_background_data.pqt"
143
139
  MODEL_ARTIFACTS_DIR = "artifacts"
140
+ EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
144
141
  DEFAULT_TARGET_METHODS = ["predict"]
145
142
  IS_AUTO_SIGNATURE = False
146
143
 
@@ -169,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
169
166
  model_meta=model_meta,
170
167
  model_blobs_dir_path=model_blobs_dir_path,
171
168
  )
169
+
170
+ @classmethod
171
+ @final
172
+ def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
173
+ """Load the background data into memory.
174
+
175
+ Args:
176
+ name: Name of the model.
177
+ model_blobs_dir_path: Directory path to the whole model.
178
+
179
+ Returns:
180
+ Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
181
+ """
182
+ data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
183
+ if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
184
+ return None
185
+ with open(data_blob_path, "rb") as f:
186
+ background_data = pd.read_parquet(f)
187
+
188
+ return background_data
@@ -1,9 +1,11 @@
1
1
  import json
2
+ import warnings
2
3
  from typing import Any, Callable, Iterable, Optional, Sequence, cast
3
4
 
4
5
  import numpy as np
5
6
  import numpy.typing as npt
6
7
  import pandas as pd
8
+ from absl import logging
7
9
 
8
10
  from snowflake.ml.model import model_signature, type_hints as model_types
9
11
  from snowflake.ml.model._packager.model_meta import model_meta
@@ -11,6 +13,17 @@ from snowflake.ml.model._signatures import snowpark_handler
11
13
  from snowflake.snowpark import DataFrame as SnowparkDataFrame
12
14
 
13
15
 
16
+ class NumpyEncoder(json.JSONEncoder):
17
+ def default(self, obj: Any) -> Any:
18
+ if isinstance(obj, np.integer):
19
+ return int(obj)
20
+ if isinstance(obj, np.floating):
21
+ return float(obj)
22
+ if isinstance(obj, np.ndarray):
23
+ return obj.tolist()
24
+ return super().default(obj)
25
+
26
+
14
27
  def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
15
28
  return callable(getattr(model, method_name, None))
16
29
 
@@ -93,23 +106,42 @@ def convert_explanations_to_2D_df(
93
106
  return pd.DataFrame(explanations)
94
107
 
95
108
  if hasattr(model, "classes_"):
96
- classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
109
+ classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
97
110
  len_classes = len(classes_list)
98
111
  if explanations.shape[2] != len_classes:
99
112
  raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
100
113
  else:
101
- classes_list = [i for i in range(explanations.shape[2])]
102
- exp_2d = []
103
- # TODO (SNOW-1549044): Optimize this
104
- for row in explanations:
105
- col_list = []
106
- for column in row:
107
- class_explanations = {}
108
- for cl, cl_exp in zip(classes_list, column):
109
- if isinstance(cl, (int, np.integer)):
110
- cl = int(cl)
111
- class_explanations[cl] = cl_exp
112
- col_list.append(json.dumps(class_explanations))
113
- exp_2d.append(col_list)
114
+ classes_list = [str(i) for i in range(explanations.shape[2])]
115
+
116
+ def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]:
117
+ """Converts a single row to a dictionary."""
118
+ # convert to object or numpy creates strings of fixed length
119
+ return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
120
+
121
+ exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
114
122
 
115
123
  return pd.DataFrame(exp_2d)
124
+
125
+
126
+ def validate_model_objective(
127
+ passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective
128
+ ) -> model_types.ModelObjective:
129
+ if (
130
+ passed_model_objective != model_types.ModelObjective.UNKNOWN
131
+ and inferred_model_objective != model_types.ModelObjective.UNKNOWN
132
+ ):
133
+ if passed_model_objective != inferred_model_objective:
134
+ warnings.warn(
135
+ f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
136
+ f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored",
137
+ category=UserWarning,
138
+ stacklevel=1,
139
+ )
140
+ return inferred_model_objective
141
+ elif inferred_model_objective != model_types.ModelObjective.UNKNOWN:
142
+ logging.info(
143
+ f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
144
+ f"version"
145
+ )
146
+ return inferred_model_objective
147
+ return passed_model_objective
@@ -30,24 +30,24 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
30
30
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
31
31
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
32
32
 
33
- MODELE_BLOB_FILE_OR_DIR = "model.bin"
33
+ MODEL_BLOB_FILE_OR_DIR = "model.bin"
34
34
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
35
35
 
36
36
  @classmethod
37
- def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
37
+ def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective:
38
38
  import catboost
39
39
 
40
40
  if isinstance(model, catboost.CatBoostClassifier):
41
41
  num_classes = handlers_utils.get_num_classes_if_exists(model)
42
42
  if num_classes == 2:
43
- return _base.ModelObjective.BINARY_CLASSIFICATION
44
- return _base.ModelObjective.MULTI_CLASSIFICATION
43
+ return model_types.ModelObjective.BINARY_CLASSIFICATION
44
+ return model_types.ModelObjective.MULTI_CLASSIFICATION
45
45
  if isinstance(model, catboost.CatBoostRanker):
46
- return _base.ModelObjective.RANKING
46
+ return model_types.ModelObjective.RANKING
47
47
  if isinstance(model, catboost.CatBoostRegressor):
48
- return _base.ModelObjective.REGRESSION
48
+ return model_types.ModelObjective.REGRESSION
49
49
  # TODO: Find out model type from the generic Catboost Model
50
- return _base.ModelObjective.UNKNOWN
50
+ return model_types.ModelObjective.UNKNOWN
51
51
 
52
52
  @classmethod
53
53
  def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
@@ -77,6 +77,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
77
77
  is_sub_model: Optional[bool] = False,
78
78
  **kwargs: Unpack[model_types.CatBoostModelSaveOptions],
79
79
  ) -> None:
80
+ enable_explainability = kwargs.get("enable_explainability", True)
81
+
80
82
  import catboost
81
83
 
82
84
  assert isinstance(model, catboost.CatBoost)
@@ -105,9 +107,14 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
105
107
  sample_input_data=sample_input_data,
106
108
  get_prediction_fn=get_prediction,
107
109
  )
108
- if kwargs.get("enable_explainability", False):
110
+ inferred_model_objective = cls.get_model_objective_and_output_type(model)
111
+ model_meta.model_objective = handlers_utils.validate_model_objective(
112
+ model_meta.model_objective, inferred_model_objective
113
+ )
114
+ model_objective = model_meta.model_objective
115
+ if enable_explainability:
109
116
  output_type = model_signature.DataType.DOUBLE
110
- if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
117
+ if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
111
118
  output_type = model_signature.DataType.STRING
112
119
  model_meta = handlers_utils.add_explain_method_signature(
113
120
  model_meta=model_meta,
@@ -115,10 +122,13 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
115
122
  target_method="predict",
116
123
  output_return_type=output_type,
117
124
  )
125
+ model_meta.function_properties = {
126
+ "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
127
+ }
118
128
 
119
129
  model_blob_path = os.path.join(model_blobs_dir_path, name)
120
130
  os.makedirs(model_blob_path, exist_ok=True)
121
- model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
131
+ model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
122
132
 
123
133
  model.save_model(model_save_path)
124
134
 
@@ -126,7 +136,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
126
136
  name=name,
127
137
  model_type=cls.HANDLER_TYPE,
128
138
  handler_version=cls.HANDLER_VERSION,
129
- path=cls.MODELE_BLOB_FILE_OR_DIR,
139
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
130
140
  options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
131
141
  )
132
142
  model_meta.models[name] = base_meta
@@ -138,11 +148,9 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
138
148
  ],
139
149
  check_local_version=True,
140
150
  )
141
- if kwargs.get("enable_explainability", False):
142
- model_meta.env.include_if_absent(
143
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
144
- check_local_version=True,
145
- )
151
+ if enable_explainability:
152
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
153
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
146
154
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
147
155
 
148
156
  return None
@@ -188,6 +196,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
188
196
  cls,
189
197
  raw_model: "catboost.CatBoost",
190
198
  model_meta: model_meta_api.ModelMetadata,
199
+ background_data: Optional[pd.DataFrame] = None,
191
200
  **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
192
201
  ) -> custom_model.CustomModel:
193
202
  import catboost
@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
51
51
  **kwargs: Unpack[model_types.CustomModelSaveOption],
52
52
  ) -> None:
53
53
  assert isinstance(model, custom_model.CustomModel)
54
+ enable_explainability = kwargs.get("enable_explainability", False)
55
+ if enable_explainability:
56
+ raise NotImplementedError("Explainability is not supported for custom model.")
54
57
 
55
58
  def get_prediction(
56
59
  target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
108
111
  # Make sure that the module where the model is defined get pickled by value as well.
109
112
  cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
110
113
  pickled_obj = (model.__class__, model.context)
111
- with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
114
+ with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
112
115
  cloudpickle.dump(pickled_obj, f)
113
116
  # model meta will be saved by the context manager
114
117
  model_meta.models[name] = model_blob_meta.ModelBlobMeta(
115
118
  name=name,
116
119
  model_type=cls.HANDLER_TYPE,
117
- path=cls.MODELE_BLOB_FILE_OR_DIR,
120
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
118
121
  handler_version=cls.HANDLER_VERSION,
119
122
  function_properties=model_meta.function_properties,
120
123
  artifacts={
@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
183
186
  cls,
184
187
  raw_model: custom_model.CustomModel,
185
188
  model_meta: model_meta_api.ModelMetadata,
189
+ background_data: Optional[pd.DataFrame] = None,
186
190
  **kwargs: Unpack[model_types.CustomModelLoadOption],
187
191
  ) -> custom_model.CustomModel:
188
192
  return raw_model
@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
89
89
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
90
90
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
91
91
 
92
- MODELE_BLOB_FILE_OR_DIR = "model"
92
+ MODEL_BLOB_FILE_OR_DIR = "model"
93
93
  ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
94
94
  DEFAULT_TARGET_METHODS = ["__call__"]
95
95
  IS_AUTO_SIGNATURE = True
@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
133
133
  is_sub_model: Optional[bool] = False,
134
134
  **kwargs: Unpack[model_types.HuggingFaceSaveOptions],
135
135
  ) -> None:
136
+ enable_explainability = kwargs.get("enable_explainability", False)
137
+ if enable_explainability:
138
+ raise NotImplementedError("Explainability is not supported for huggingface model.")
136
139
  if type_utils.LazyType("transformers.Pipeline").isinstance(model):
137
140
  task = model.task # type:ignore[attr-defined]
138
141
  framework = model.framework # type:ignore[attr-defined]
@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(
193
196
 
194
197
  if type_utils.LazyType("transformers.Pipeline").isinstance(model):
195
198
  model.save_pretrained( # type:ignore[attr-defined]
196
- os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
199
+ os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
197
200
  )
198
201
  pipeline_params = {
199
202
  "_batch_size": model._batch_size, # type:ignore[attr-defined]
@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
205
208
  with open(
206
209
  os.path.join(
207
210
  model_blob_path,
208
- cls.MODELE_BLOB_FILE_OR_DIR,
211
+ cls.MODEL_BLOB_FILE_OR_DIR,
209
212
  cls.ADDITIONAL_CONFIG_FILE,
210
213
  ),
211
214
  "wb",
@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
213
216
  cloudpickle.dump(pipeline_params, f)
214
217
  else:
215
218
  with open(
216
- os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR),
219
+ os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
217
220
  "wb",
218
221
  ) as f:
219
222
  cloudpickle.dump(model, f)
@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
222
225
  name=name,
223
226
  model_type=cls.HANDLER_TYPE,
224
227
  handler_version=cls.HANDLER_VERSION,
225
- path=cls.MODELE_BLOB_FILE_OR_DIR,
228
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
226
229
  options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
227
230
  {
228
231
  "task": task,
@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
329
332
  cls,
330
333
  raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
331
334
  model_meta: model_meta_api.ModelMetadata,
335
+ background_data: Optional[pd.DataFrame] = None,
332
336
  **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
333
337
  ) -> custom_model.CustomModel:
334
338
  import transformers
@@ -365,7 +369,9 @@ class HuggingFacePipelineHandler(
365
369
  else:
366
370
  # For others, we could offer the whole dataframe as a list.
367
371
  # Some of them may need some conversion
368
- if isinstance(raw_model, transformers.ConversationalPipeline):
372
+ if hasattr(transformers, "ConversationalPipeline") and isinstance(
373
+ raw_model, transformers.ConversationalPipeline
374
+ ):
369
375
  input_data = [
370
376
  transformers.Conversation(
371
377
  text=conv_data["user_inputs"][0],
@@ -387,27 +393,33 @@ class HuggingFacePipelineHandler(
387
393
  # Making it not aligned with the auto-inferred signature.
388
394
  # If the output is a dict, we could blindly create a list containing that.
389
395
  # Otherwise, creating pandas DataFrame won't succeed.
390
- if isinstance(temp_res, (dict, transformers.Conversation)) or (
391
- # For some pipeline that is expected to generate a list of dict per input
392
- # When it omit outer list, it becomes list of dict instead of list of list of dict.
393
- # We need to distinguish them from those pipelines that designed to output a dict per input
394
- # So we need to check the pipeline type.
395
- isinstance(
396
- raw_model,
397
- (
398
- transformers.FillMaskPipeline,
399
- transformers.QuestionAnsweringPipeline,
400
- ),
396
+ if (
397
+ (hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation))
398
+ or isinstance(temp_res, dict)
399
+ or (
400
+ # For some pipeline that is expected to generate a list of dict per input
401
+ # When it omit outer list, it becomes list of dict instead of list of list of dict.
402
+ # We need to distinguish them from those pipelines that designed to output a dict per input
403
+ # So we need to check the pipeline type.
404
+ isinstance(
405
+ raw_model,
406
+ (
407
+ transformers.FillMaskPipeline,
408
+ transformers.QuestionAnsweringPipeline,
409
+ ),
410
+ )
411
+ and X.shape[0] == 1
412
+ and isinstance(temp_res[0], dict)
401
413
  )
402
- and X.shape[0] == 1
403
- and isinstance(temp_res[0], dict)
404
414
  ):
405
415
  temp_res = [temp_res]
406
416
 
407
417
  if len(temp_res) == 0:
408
418
  return pd.DataFrame()
409
419
 
410
- if isinstance(raw_model, transformers.ConversationalPipeline):
420
+ if hasattr(transformers, "ConversationalPipeline") and isinstance(
421
+ raw_model, transformers.ConversationalPipeline
422
+ ):
411
423
  temp_res = [[conv.generated_responses] for conv in temp_res]
412
424
 
413
425
  # To concat those who outputs a list with one input.