snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,10 @@
1
1
  import inspect
2
2
  from typing import Any, List, Optional
3
3
 
4
- import numpy as np
5
4
  import pandas as pd
6
5
 
7
6
  from snowflake.ml._internal.exceptions import error_codes, exceptions
7
+ from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
8
8
 
9
9
 
10
10
  class PandasTransformHandlers:
@@ -107,48 +107,9 @@ class PandasTransformHandlers:
107
107
 
108
108
  inference_res = getattr(self.estimator, inference_method)(input_df, *args, **kwargs)
109
109
 
110
- if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
111
- # In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
112
- # ndarrays. We need to concatenate them.
113
-
114
- # First compute output column names
115
- if len(output_cols) == len(inference_res):
116
- actual_output_cols = []
117
- for idx, np_arr in enumerate(inference_res):
118
- for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]):
119
- actual_output_cols.append(f"{output_cols[idx]}_{i}")
120
- output_cols = actual_output_cols
121
-
122
- # Concatenate np arrays
123
- transformed_numpy_array = np.concatenate(inference_res, axis=1)
124
- elif isinstance(inference_res, tuple) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
125
- # In case of kneighbors, functions return a tuple of ndarrays.
126
- transformed_numpy_array = np.stack(inference_res, axis=1)
127
- else:
128
- transformed_numpy_array = inference_res
129
-
130
- if (len(transformed_numpy_array.shape) == 3) and inference_method != "kneighbors":
131
- # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes)
132
- # when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms,
133
- # so we ignore flatten_transform flag and flatten the results.
134
- transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload]
135
-
136
- if len(transformed_numpy_array.shape) == 1:
137
- transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))
138
-
139
- shape = transformed_numpy_array.shape
140
- if shape[1] != len(output_cols):
141
- if len(output_cols) != 1:
142
- raise exceptions.SnowflakeMLException(
143
- error_code=error_codes.INVALID_ARGUMENT,
144
- original_exception=TypeError(
145
- "expected_output_cols must be same length as transformed array or " "should be of length 1"
146
- ),
147
- )
148
- actual_output_cols = []
149
- for i in range(shape[1]):
150
- actual_output_cols.append(f"{output_cols[0]}_{i}")
151
- output_cols = actual_output_cols
110
+ transformed_numpy_array, output_cols = handle_inference_result(
111
+ inference_res=inference_res, output_cols=output_cols, inference_method=inference_method
112
+ )
152
113
 
153
114
  if inference_method == "kneighbors":
154
115
  if len(transformed_numpy_array.shape) == 3: # return_distance=True
@@ -55,18 +55,18 @@ class PandasModelTrainer:
55
55
 
56
56
  def train_fit_predict(
57
57
  self,
58
- pass_through_columns: List[str],
59
58
  expected_output_cols_list: List[str],
59
+ drop_input_cols: Optional[bool] = False,
60
60
  ) -> Tuple[pd.DataFrame, object]:
61
61
  """Trains the model using specified features and target columns from the dataset.
62
62
  This API is different from fit itself because it would also provide the predict
63
63
  output.
64
64
 
65
65
  Args:
66
- pass_through_columns (List[str]): The column names that would
67
- display in the returned dataset.
68
66
  expected_output_cols_list (List[str]): The output columns
69
67
  name as a list. Defaults to None.
68
+ drop_input_cols (Optional[bool]): Boolean to determine whether to
69
+ drop the input columns from the output dataset.
70
70
 
71
71
  Returns:
72
72
  Tuple[pd.DataFrame, object]: [predicted dataset, estimator]
@@ -75,7 +75,7 @@ class PandasModelTrainer:
75
75
  args = {"X": self.dataset[self.input_cols]}
76
76
  result = self.estimator.fit_predict(**args)
77
77
  result_df = pd.DataFrame(data=result, columns=expected_output_cols_list)
78
- if len(pass_through_columns) == 0:
78
+ if drop_input_cols:
79
79
  result_df = result_df
80
80
  else:
81
81
  result_df = pd.concat([self.dataset, result_df], axis=1)
@@ -1,5 +1,8 @@
1
1
  from typing import Any, List, Optional
2
2
 
3
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
4
+ SnowparkTransformHandlers,
5
+ )
3
6
  from snowflake.snowpark import DataFrame, Session
4
7
 
5
8
 
@@ -42,14 +45,16 @@ class MLRuntimeTransformHandlers:
42
45
  inference_method: str,
43
46
  input_cols: List[str],
44
47
  expected_output_cols: List[str],
45
- pass_through_cols: List[str],
46
48
  session: Session,
47
49
  dependencies: List[str],
50
+ drop_input_cols: Optional[bool] = False,
48
51
  expected_output_cols_type: Optional[str] = "",
49
52
  *args: Any,
50
53
  **kwargs: Any,
51
54
  ) -> DataFrame:
52
55
  """Run batch inference on the given dataset.
56
+ Temporary workaround - pushdown implementation is not currently ready for batch_inference.
57
+ We use a SnowparkTransformHandlers until we have a way to use the runtime client.
53
58
 
54
59
  Args:
55
60
  inference_method: the name of the method used by `estimator` to run inference.
@@ -57,7 +62,7 @@ class MLRuntimeTransformHandlers:
57
62
  session: An active Snowpark Session.
58
63
  dependencies: List of dependencies for the transformer.
59
64
  expected_output_cols: column names (in order) of the output dataset.
60
- pass_through_cols: columns in the dataset not used in inference.
65
+ drop_input_cols: Boolean to determine whether to drop the input columns from the output dataset.
61
66
  expected_output_cols_type: Expected type of the output columns.
62
67
  args: additional positional arguments.
63
68
  kwargs: additional keyword args.
@@ -65,27 +70,26 @@ class MLRuntimeTransformHandlers:
65
70
  Returns:
66
71
  A new dataset of the same type as the input dataset.
67
72
 
68
- Raises:
69
- TypeError: The ML Runtimes client returned a non-DataFrame result.
70
73
  """
71
- output_df = self.client.batch_inference(
72
- inference_method=inference_method,
74
+
75
+ handler = SnowparkTransformHandlers(
73
76
  dataset=self.dataset,
74
77
  estimator=self.estimator,
75
- input_cols=input_cols,
76
- expected_output_cols=expected_output_cols,
77
- pass_through_cols=pass_through_cols,
78
- session=session,
79
- dependencies=dependencies,
80
- expected_output_cols_type=expected_output_cols_type,
78
+ class_name=self._class_name,
79
+ subproject=self._subproject,
80
+ autogenerated=self._autogenerated,
81
+ )
82
+ return handler.batch_inference(
83
+ inference_method,
84
+ input_cols,
85
+ expected_output_cols,
86
+ session,
87
+ dependencies,
88
+ drop_input_cols,
89
+ expected_output_cols_type,
81
90
  *args,
82
91
  **kwargs,
83
92
  )
84
- if not isinstance(output_df, DataFrame):
85
- raise TypeError(
86
- f"The ML Runtimes Client did not return a DataFrame a non-float value Returned type: {type(output_df)}"
87
- )
88
- return output_df
89
93
 
90
94
  def score(
91
95
  self,
@@ -29,7 +29,7 @@ class SKLearnModelSpecifications(ModelSpecifications):
29
29
  ]
30
30
 
31
31
  # A change from previous implementation.
32
- # When reusing the Sprocs for all the fit() call in the session, the static dpendencies list should include
32
+ # When reusing the Sprocs for all the fit() call in the session, the static dependencies list should include
33
33
  # all the possible dependencies required during the lifetime.
34
34
 
35
35
  # Include XGBoost in the dependencies if it is installed.
@@ -67,10 +67,12 @@ class XGBoostModelSpecifications(ModelSpecifications):
67
67
  class LightGBMModelSpecifications(ModelSpecifications):
68
68
  def __init__(self) -> None:
69
69
  import lightgbm
70
+ import sklearn
70
71
 
71
72
  imports: List[str] = ["lightgbm"]
72
73
  pkgDependencies: List[str] = [
73
74
  f"numpy=={np.__version__}",
75
+ f"scikit-learn=={sklearn.__version__}",
74
76
  f"lightgbm=={lightgbm.__version__}",
75
77
  f"cloudpickle=={cp.__version__}",
76
78
  ]
@@ -1,4 +1,4 @@
1
- from typing import List, Protocol, Tuple, Union
1
+ from typing import List, Optional, Protocol, Tuple, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -18,7 +18,7 @@ class ModelTrainer(Protocol):
18
18
 
19
19
  def train_fit_predict(
20
20
  self,
21
- pass_through_columns: List[str],
22
21
  expected_output_cols_list: List[str],
22
+ drop_input_cols: Optional[bool] = False,
23
23
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
24
24
  raise NotImplementedError