snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -507,12 +504,23 @@ class BernoulliRBM(BaseTransformer):
507
504
  autogenerated=self._autogenerated,
508
505
  subproject=_SUBPROJECT,
509
506
  )
510
- output_result, fitted_estimator = model_trainer.train_fit_predict(
511
- drop_input_cols=self._drop_input_cols,
512
- expected_output_cols_list=(
513
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
514
- ),
507
+ expected_output_cols = (
508
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
515
509
  )
510
+ if isinstance(dataset, DataFrame):
511
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
512
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
513
+ )
514
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=expected_output_cols,
517
+ example_output_pd_df=example_output_pd_df,
518
+ )
519
+ else:
520
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=expected_output_cols,
523
+ )
516
524
  self._sklearn_object = fitted_estimator
517
525
  self._is_fitted = True
518
526
  return output_result
@@ -593,12 +601,41 @@ class BernoulliRBM(BaseTransformer):
593
601
 
594
602
  return rv
595
603
 
596
- def _align_expected_output_names(
597
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
598
- ) -> List[str]:
604
+ def _align_expected_output(
605
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
606
+ ) -> Tuple[List[str], pd.DataFrame]:
607
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
608
+ and output dataframe with 1 line.
609
+ If the method is fit_predict, run 2 lines of data.
610
+ """
599
611
  # in case the inferred output column names dimension is different
600
612
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
601
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
613
+
614
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
615
+ # so change the minimum of number of rows to 2
616
+ num_examples = 2
617
+ statement_params = telemetry.get_function_usage_statement_params(
618
+ project=_PROJECT,
619
+ subproject=_SUBPROJECT,
620
+ function_name=telemetry.get_statement_params_full_func_name(
621
+ inspect.currentframe(), BernoulliRBM.__class__.__name__
622
+ ),
623
+ api_calls=[Session.call],
624
+ custom_tags={"autogen": True} if self._autogenerated else None,
625
+ )
626
+ if output_cols_prefix == "fit_predict_":
627
+ if hasattr(self._sklearn_object, "n_clusters"):
628
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
629
+ num_examples = self._sklearn_object.n_clusters
630
+ elif hasattr(self._sklearn_object, "min_samples"):
631
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
632
+ num_examples = self._sklearn_object.min_samples
633
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
634
+ # LocalOutlierFactor expects n_neighbors <= n_samples
635
+ num_examples = self._sklearn_object.n_neighbors
636
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
637
+ else:
638
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
602
639
 
603
640
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
604
641
  # seen during the fit.
@@ -610,12 +647,14 @@ class BernoulliRBM(BaseTransformer):
610
647
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
611
648
  if self.sample_weight_col:
612
649
  output_df_columns_set -= set(self.sample_weight_col)
650
+
613
651
  # if the dimension of inferred output column names is correct; use it
614
652
  if len(expected_output_cols_list) == len(output_df_columns_set):
615
- return expected_output_cols_list
653
+ return expected_output_cols_list, output_df_pd
616
654
  # otherwise, use the sklearn estimator's output
617
655
  else:
618
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
656
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
657
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
619
658
 
620
659
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
621
660
  @telemetry.send_api_usage_telemetry(
@@ -661,7 +700,7 @@ class BernoulliRBM(BaseTransformer):
661
700
  drop_input_cols=self._drop_input_cols,
662
701
  expected_output_cols_type="float",
663
702
  )
664
- expected_output_cols = self._align_expected_output_names(
703
+ expected_output_cols, _ = self._align_expected_output(
665
704
  inference_method, dataset, expected_output_cols, output_cols_prefix
666
705
  )
667
706
 
@@ -727,7 +766,7 @@ class BernoulliRBM(BaseTransformer):
727
766
  drop_input_cols=self._drop_input_cols,
728
767
  expected_output_cols_type="float",
729
768
  )
730
- expected_output_cols = self._align_expected_output_names(
769
+ expected_output_cols, _ = self._align_expected_output(
731
770
  inference_method, dataset, expected_output_cols, output_cols_prefix
732
771
  )
733
772
  elif isinstance(dataset, pd.DataFrame):
@@ -790,7 +829,7 @@ class BernoulliRBM(BaseTransformer):
790
829
  drop_input_cols=self._drop_input_cols,
791
830
  expected_output_cols_type="float",
792
831
  )
793
- expected_output_cols = self._align_expected_output_names(
832
+ expected_output_cols, _ = self._align_expected_output(
794
833
  inference_method, dataset, expected_output_cols, output_cols_prefix
795
834
  )
796
835
 
@@ -857,7 +896,7 @@ class BernoulliRBM(BaseTransformer):
857
896
  drop_input_cols = self._drop_input_cols,
858
897
  expected_output_cols_type="float",
859
898
  )
860
- expected_output_cols = self._align_expected_output_names(
899
+ expected_output_cols, _ = self._align_expected_output(
861
900
  inference_method, dataset, expected_output_cols, output_cols_prefix
862
901
  )
863
902
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -660,12 +657,23 @@ class MLPClassifier(BaseTransformer):
660
657
  autogenerated=self._autogenerated,
661
658
  subproject=_SUBPROJECT,
662
659
  )
663
- output_result, fitted_estimator = model_trainer.train_fit_predict(
664
- drop_input_cols=self._drop_input_cols,
665
- expected_output_cols_list=(
666
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
667
- ),
660
+ expected_output_cols = (
661
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
668
662
  )
663
+ if isinstance(dataset, DataFrame):
664
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
665
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
666
+ )
667
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
668
+ drop_input_cols=self._drop_input_cols,
669
+ expected_output_cols_list=expected_output_cols,
670
+ example_output_pd_df=example_output_pd_df,
671
+ )
672
+ else:
673
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
674
+ drop_input_cols=self._drop_input_cols,
675
+ expected_output_cols_list=expected_output_cols,
676
+ )
669
677
  self._sklearn_object = fitted_estimator
670
678
  self._is_fitted = True
671
679
  return output_result
@@ -744,12 +752,41 @@ class MLPClassifier(BaseTransformer):
744
752
 
745
753
  return rv
746
754
 
747
- def _align_expected_output_names(
748
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
749
- ) -> List[str]:
755
+ def _align_expected_output(
756
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
757
+ ) -> Tuple[List[str], pd.DataFrame]:
758
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
759
+ and output dataframe with 1 line.
760
+ If the method is fit_predict, run 2 lines of data.
761
+ """
750
762
  # in case the inferred output column names dimension is different
751
763
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
752
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
764
+
765
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
766
+ # so change the minimum of number of rows to 2
767
+ num_examples = 2
768
+ statement_params = telemetry.get_function_usage_statement_params(
769
+ project=_PROJECT,
770
+ subproject=_SUBPROJECT,
771
+ function_name=telemetry.get_statement_params_full_func_name(
772
+ inspect.currentframe(), MLPClassifier.__class__.__name__
773
+ ),
774
+ api_calls=[Session.call],
775
+ custom_tags={"autogen": True} if self._autogenerated else None,
776
+ )
777
+ if output_cols_prefix == "fit_predict_":
778
+ if hasattr(self._sklearn_object, "n_clusters"):
779
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
780
+ num_examples = self._sklearn_object.n_clusters
781
+ elif hasattr(self._sklearn_object, "min_samples"):
782
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
783
+ num_examples = self._sklearn_object.min_samples
784
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
785
+ # LocalOutlierFactor expects n_neighbors <= n_samples
786
+ num_examples = self._sklearn_object.n_neighbors
787
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
788
+ else:
789
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
753
790
 
754
791
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
755
792
  # seen during the fit.
@@ -761,12 +798,14 @@ class MLPClassifier(BaseTransformer):
761
798
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
762
799
  if self.sample_weight_col:
763
800
  output_df_columns_set -= set(self.sample_weight_col)
801
+
764
802
  # if the dimension of inferred output column names is correct; use it
765
803
  if len(expected_output_cols_list) == len(output_df_columns_set):
766
- return expected_output_cols_list
804
+ return expected_output_cols_list, output_df_pd
767
805
  # otherwise, use the sklearn estimator's output
768
806
  else:
769
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
807
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
808
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
770
809
 
771
810
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
772
811
  @telemetry.send_api_usage_telemetry(
@@ -814,7 +853,7 @@ class MLPClassifier(BaseTransformer):
814
853
  drop_input_cols=self._drop_input_cols,
815
854
  expected_output_cols_type="float",
816
855
  )
817
- expected_output_cols = self._align_expected_output_names(
856
+ expected_output_cols, _ = self._align_expected_output(
818
857
  inference_method, dataset, expected_output_cols, output_cols_prefix
819
858
  )
820
859
 
@@ -882,7 +921,7 @@ class MLPClassifier(BaseTransformer):
882
921
  drop_input_cols=self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
  elif isinstance(dataset, pd.DataFrame):
@@ -945,7 +984,7 @@ class MLPClassifier(BaseTransformer):
945
984
  drop_input_cols=self._drop_input_cols,
946
985
  expected_output_cols_type="float",
947
986
  )
948
- expected_output_cols = self._align_expected_output_names(
987
+ expected_output_cols, _ = self._align_expected_output(
949
988
  inference_method, dataset, expected_output_cols, output_cols_prefix
950
989
  )
951
990
 
@@ -1010,7 +1049,7 @@ class MLPClassifier(BaseTransformer):
1010
1049
  drop_input_cols = self._drop_input_cols,
1011
1050
  expected_output_cols_type="float",
1012
1051
  )
1013
- expected_output_cols = self._align_expected_output_names(
1052
+ expected_output_cols, _ = self._align_expected_output(
1014
1053
  inference_method, dataset, expected_output_cols, output_cols_prefix
1015
1054
  )
1016
1055
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -656,12 +653,23 @@ class MLPRegressor(BaseTransformer):
656
653
  autogenerated=self._autogenerated,
657
654
  subproject=_SUBPROJECT,
658
655
  )
659
- output_result, fitted_estimator = model_trainer.train_fit_predict(
660
- drop_input_cols=self._drop_input_cols,
661
- expected_output_cols_list=(
662
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
663
- ),
656
+ expected_output_cols = (
657
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
664
658
  )
659
+ if isinstance(dataset, DataFrame):
660
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
661
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
662
+ )
663
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
664
+ drop_input_cols=self._drop_input_cols,
665
+ expected_output_cols_list=expected_output_cols,
666
+ example_output_pd_df=example_output_pd_df,
667
+ )
668
+ else:
669
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
670
+ drop_input_cols=self._drop_input_cols,
671
+ expected_output_cols_list=expected_output_cols,
672
+ )
665
673
  self._sklearn_object = fitted_estimator
666
674
  self._is_fitted = True
667
675
  return output_result
@@ -740,12 +748,41 @@ class MLPRegressor(BaseTransformer):
740
748
 
741
749
  return rv
742
750
 
743
- def _align_expected_output_names(
744
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
745
- ) -> List[str]:
751
+ def _align_expected_output(
752
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
753
+ ) -> Tuple[List[str], pd.DataFrame]:
754
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
755
+ and output dataframe with 1 line.
756
+ If the method is fit_predict, run 2 lines of data.
757
+ """
746
758
  # in case the inferred output column names dimension is different
747
759
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
748
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
760
+
761
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
762
+ # so change the minimum of number of rows to 2
763
+ num_examples = 2
764
+ statement_params = telemetry.get_function_usage_statement_params(
765
+ project=_PROJECT,
766
+ subproject=_SUBPROJECT,
767
+ function_name=telemetry.get_statement_params_full_func_name(
768
+ inspect.currentframe(), MLPRegressor.__class__.__name__
769
+ ),
770
+ api_calls=[Session.call],
771
+ custom_tags={"autogen": True} if self._autogenerated else None,
772
+ )
773
+ if output_cols_prefix == "fit_predict_":
774
+ if hasattr(self._sklearn_object, "n_clusters"):
775
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
776
+ num_examples = self._sklearn_object.n_clusters
777
+ elif hasattr(self._sklearn_object, "min_samples"):
778
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
779
+ num_examples = self._sklearn_object.min_samples
780
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
781
+ # LocalOutlierFactor expects n_neighbors <= n_samples
782
+ num_examples = self._sklearn_object.n_neighbors
783
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
784
+ else:
785
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
749
786
 
750
787
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
751
788
  # seen during the fit.
@@ -757,12 +794,14 @@ class MLPRegressor(BaseTransformer):
757
794
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
758
795
  if self.sample_weight_col:
759
796
  output_df_columns_set -= set(self.sample_weight_col)
797
+
760
798
  # if the dimension of inferred output column names is correct; use it
761
799
  if len(expected_output_cols_list) == len(output_df_columns_set):
762
- return expected_output_cols_list
800
+ return expected_output_cols_list, output_df_pd
763
801
  # otherwise, use the sklearn estimator's output
764
802
  else:
765
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
803
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
804
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
766
805
 
767
806
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
768
807
  @telemetry.send_api_usage_telemetry(
@@ -808,7 +847,7 @@ class MLPRegressor(BaseTransformer):
808
847
  drop_input_cols=self._drop_input_cols,
809
848
  expected_output_cols_type="float",
810
849
  )
811
- expected_output_cols = self._align_expected_output_names(
850
+ expected_output_cols, _ = self._align_expected_output(
812
851
  inference_method, dataset, expected_output_cols, output_cols_prefix
813
852
  )
814
853
 
@@ -874,7 +913,7 @@ class MLPRegressor(BaseTransformer):
874
913
  drop_input_cols=self._drop_input_cols,
875
914
  expected_output_cols_type="float",
876
915
  )
877
- expected_output_cols = self._align_expected_output_names(
916
+ expected_output_cols, _ = self._align_expected_output(
878
917
  inference_method, dataset, expected_output_cols, output_cols_prefix
879
918
  )
880
919
  elif isinstance(dataset, pd.DataFrame):
@@ -937,7 +976,7 @@ class MLPRegressor(BaseTransformer):
937
976
  drop_input_cols=self._drop_input_cols,
938
977
  expected_output_cols_type="float",
939
978
  )
940
- expected_output_cols = self._align_expected_output_names(
979
+ expected_output_cols, _ = self._align_expected_output(
941
980
  inference_method, dataset, expected_output_cols, output_cols_prefix
942
981
  )
943
982
 
@@ -1002,7 +1041,7 @@ class MLPRegressor(BaseTransformer):
1002
1041
  drop_input_cols = self._drop_input_cols,
1003
1042
  expected_output_cols_type="float",
1004
1043
  )
1005
- expected_output_cols = self._align_expected_output_names(
1044
+ expected_output_cols, _ = self._align_expected_output(
1006
1045
  inference_method, dataset, expected_output_cols, output_cols_prefix
1007
1046
  )
1008
1047
 
@@ -0,0 +1,5 @@
1
+ """Disables the snowpark observability tracer when running modeling fit"""
2
+
3
+ from snowflake.ml.modeling._internal.snowpark_implementations import snowpark_trainer
4
+
5
+ snowpark_trainer._ENABLE_TRACER = False
@@ -19,6 +19,7 @@ from snowflake.ml._internal import file_utils, telemetry
19
19
  from snowflake.ml._internal.exceptions import error_codes, exceptions
20
20
  from snowflake.ml._internal.lineage import lineage_utils
21
21
  from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
22
+ from snowflake.ml.data import data_source
22
23
  from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
23
24
  from snowflake.ml.modeling._internal.model_transformer_builder import (
24
25
  ModelTransformerBuilder,
@@ -417,9 +418,6 @@ class Pipeline(base.BaseTransformer):
417
418
 
418
419
  Returns:
419
420
  Fitted pipeline.
420
-
421
- Raises:
422
- ValueError: A pipeline incompatible with sklearn is used on MLRS
423
421
  """
424
422
 
425
423
  self._validate_steps()
@@ -431,11 +429,11 @@ class Pipeline(base.BaseTransformer):
431
429
 
432
430
  # Extract lineage information here since we're overriding fit() directly
433
431
  data_sources = lineage_utils.get_data_sources(dataset)
432
+ if not data_sources and isinstance(dataset, snowpark.DataFrame):
433
+ data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
434
434
  lineage_utils.set_data_sources(self, data_sources)
435
435
 
436
436
  if self._can_be_trained_in_ml_runtime(dataset):
437
- if not self._is_convertible_to_sklearn:
438
- raise ValueError("This pipeline cannot be converted to an sklearn pipeline.")
439
437
  self._fit_ml_runtime(dataset)
440
438
 
441
439
  elif squash and isinstance(dataset, snowpark.DataFrame):
@@ -608,14 +606,8 @@ class Pipeline(base.BaseTransformer):
608
606
 
609
607
  Returns:
610
608
  Output dataset.
611
-
612
- Raises:
613
- ValueError: An sklearn object has not been fit and stored before calling this function.
614
609
  """
615
- if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
616
- if self._sklearn_object is None:
617
- raise ValueError("Model must be fit before inference.")
618
-
610
+ if os.environ.get(IN_ML_RUNTIME_ENV_VAR) and self._sklearn_object is not None:
619
611
  expected_output_cols = self._infer_output_cols()
620
612
  handler = ModelTransformerBuilder.build(
621
613
  dataset=dataset,