snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/_internal/distributed_hpo_trainer.py

@@ -8,11 +8,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cloudpickle as cp
 import numpy as np
-from scipy.stats import rankdata
 from sklearn import model_selection
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml._internal.utils import (
+    identifier,
+    pkg_version_utils,
+    snowpark_dataframe_utils,
+)
 from snowflake.ml._internal.utils.temp_file_utils import (
     cleanup_temp_files,
     get_temp_file_path,
@@ -26,7 +30,8 @@ from snowflake.snowpark._internal.utils import (
     TempObjectType,
     random_name_for_temp_object,
 )
-from snowflake.snowpark.functions import col, sproc, udtf
+from snowflake.snowpark.functions import sproc, udtf
+from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
 
 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
@@ -36,6 +41,117 @@ _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
 
 
+def construct_cv_results(
+    estimator: Union[GridSearchCV, RandomizedSearchCV],
+    n_split: int,
+    param_grid: List[Dict[str, Any]],
+    cv_results_raw_hex: List[Row],
+    cross_validator_indices_length: int,
+    parameter_grid_length: int,
+) -> Tuple[bool, Dict[str, Any]]:
+    """Construct the cross validation result from the UDF. Because the process is
+    parallelized across both the cross validation folds and the parameter grid
+    combinations, the partial results have to be stitched back together, rather than
+    returned raw, to align with the original sklearn result format.
+
+    Args:
+        estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn estimator object,
+            either GridSearchCV or RandomizedSearchCV.
+        n_split (int): The number of splits, as determined by build_cross_validator.get_n_splits(X, y, groups).
+        param_grid (List[Dict[str, Any]]): The list of parameter grids or parameter samplers.
+        cv_results_raw_hex (List[Row]): The list of cv_results from each CV fold and parameter grid
+            combination. Because a UDxF can only return strings, and numpy arrays/masked arrays
+            cannot be encoded as JSON, each cv_result is encoded as a hex string.
+        cross_validator_indices_length (int): The length of the cross validator indices.
+        parameter_grid_length (int): The number of parameter grid combinations.
+
+    Raises:
+        ValueError: Retrieved empty cross validation results.
+        ValueError: Cross validator index length is 0.
+        ValueError: Parameter index length is 0.
+        ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF.
+        RuntimeError: Cross validation results are unexpectedly empty for one fold.
+
+    Returns:
+        Tuple[bool, Dict[str, Any]]: Returns multimetric, cv_results_.
+    """
+    # Filter corner cases: either the snowpark dataframe result is empty, or an index length is 0
+    if len(cv_results_raw_hex) == 0:
+        raise ValueError(
+            "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
+        )
+    if cross_validator_indices_length == 0:
+        raise ValueError("Cross validator index length is 0. Was the CV iterator empty?")
+    if parameter_grid_length == 0:
+        raise ValueError("Parameter index length is 0. Were there no candidates?")
+
+    # cv_result maintains the original order
+    multimetric = False
+    # retrieve the cv_results from the udtf table; results are encoded with hex and cloudpickle;
+    # reconstruct the raw information back into its original form
+    if len(cv_results_raw_hex) != cross_validator_indices_length * parameter_grid_length:
+        raise ValueError(
+            "Retrieved incorrect dataframe dimension from Snowpark's UDTF. "
+            f"Expected {cross_validator_indices_length * parameter_grid_length}, got {len(cv_results_raw_hex)}. "
+            "Please retry or contact snowflake support."
+        )
+
+    out = []
+
+    for each_cv_result_hex in cv_results_raw_hex:
+        # convert the hex string back to cv_results_
+        hex_str = bytes.fromhex(each_cv_result_hex[0])
+        with io.BytesIO(hex_str) as f_reload:
+            each_cv_result = cp.load(f_reload)
+        if not each_cv_result:
+            raise RuntimeError(
+                "Cross validation response is empty. This issue may be temporary - please try again."
+            )
+        temp_dict = dict()
+        """
+        This dictionary has the following keys
+        train_scores : dict of scorer name -> float
+            Score on training set (for all the scorers),
+            returned only if `return_train_score` is `True`.
+        test_scores : dict of scorer name -> float
+            Score on testing set (for all the scorers).
+        fit_time : float
+            Time spent for fitting in seconds.
+        score_time : float
+            Time spent for scoring in seconds.
+        """
+        if estimator.return_train_score:
+            if each_cv_result.get("split0_train_score", None):
+                # for a single scorer, split0_train_score contains an array with one value
+                temp_dict["train_scores"] = each_cv_result["split0_train_score"][0]
+            else:
+                # in the multimetric case, the format is
+                # {metric_name1: value, metric_name2: value, ...}
+                temp_dict["train_scores"] = {}
+                # For multi-metric evaluation, the scores for all the scorers are available in the
+                # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
+                # instead of '_score'.
+                for k, v in each_cv_result.items():
+                    if "split0_train_" in k:
+                        temp_dict["train_scores"][k[len("split0_train_") :]] = v
+        if isinstance(each_cv_result.get("split0_test_score"), np.ndarray):
+            temp_dict["test_scores"] = each_cv_result["split0_test_score"][0]
+        else:
+            temp_dict["test_scores"] = {}
+            for k, v in each_cv_result.items():
+                if "split0_test_" in k:
+                    temp_dict["test_scores"][k[len("split0_test_") :]] = v
+        temp_dict["fit_time"] = each_cv_result["mean_fit_time"][0]
+        temp_dict["score_time"] = each_cv_result["mean_score_time"][0]
+        out.append(temp_dict)
+    first_test_score = out[0]["test_scores"]
+    multimetric = isinstance(first_test_score, dict)
+    return multimetric, estimator._format_results(param_grid, n_split, out)
+
+
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
+
+
 class DistributedHPOTrainer(SnowparkModelTrainer):
     """
     A class for performing distributed hyperparameter optimization (HPO) using Snowpark.
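
Note: the function above relies on a hex-over-string transport: each UDTF worker cloudpickles its `cv_results_` dict and hex-encodes the bytes, since a UDxF column can only carry strings and numpy arrays do not survive JSON encoding. A minimal round-trip sketch of that encoding (the values are illustrative, not taken from the package):

```python
import io

import cloudpickle as cp
import numpy as np

# Worker side: serialize one fold's cv_results_ dict and hex-encode it.
cv_result = {"mean_fit_time": np.array([0.12]), "split0_test_score": np.array([0.91])}
with io.BytesIO() as f:
    cp.dump(cv_result, f)
    hex_payload = f.getvalue().hex()  # a plain str, safe to return from a UDTF

# Driver side (what construct_cv_results does per row): decode and unpickle.
restored = cp.load(io.BytesIO(bytes.fromhex(hex_payload)))
assert restored["split0_test_score"][0] == 0.91
```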
@@ -105,7 +221,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
        temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
        session.sql(temp_stage_creation_query).collect()
 
-        # Stage data.
+        # Stage data as parquet file
        dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
        remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet"
        dataset.write.copy_into_location(  # type:ignore[call-overload]
@@ -114,6 +230,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
        imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()]
 
        # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again
+        # refit variable can be boolean, string or callable
        original_refit = estimator.refit
 
        # Create a temp file and dump the estimator to that file.
@@ -136,7 +253,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                inspect.currentframe(), self.__class__.__name__
            ),
            api_calls=[sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
        )
        udtf_statement_params = telemetry.get_function_usage_statement_params(
            project=_PROJECT,
@@ -145,7 +261,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                inspect.currentframe(), self.__class__.__name__
            ),
            api_calls=[udtf],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
+            custom_tags=dict([("hpo_udtf", True)]),
        )
 
        # Put locally serialized estimator on stage.
@@ -208,7 +324,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                for file_name in data_files
            ]
            df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(col) for col in df.columns]
+            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
 
            X = df[input_cols]
            y = df[label_cols].squeeze() if label_cols else None
@@ -222,11 +338,16 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
        with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
            estimator = cp.load(local_estimator_file_obj)["estimator"]
 
-        cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
-        indices = [test for _, test in cv_orig.split(X, y)]
+        build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
+        from sklearn.utils.validation import indexable
+
+        X, y, _ = indexable(X, y, None)
+        n_splits = build_cross_validator.get_n_splits(X, y, None)
+        # store the cross_validator's test indices only to save space
+        cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
        local_indices_file_name = get_temp_file_path()
        with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
-            cp.dump(indices, local_indices_file_obj)
+            cp.dump(cross_validator_indices, local_indices_file_obj)
 
        # Put locally serialized indices on stage.
        put_result = session.file.put(
@@ -237,7 +358,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
        )
        indices_location = put_result[0].target
        imports.append(f"@{temp_stage_name}/{indices_location}")
-        indices_len = len(indices)
+        cross_validator_indices_length = int(len(cross_validator_indices))
+        parameter_grid_length = len(param_grid)
 
        assert estimator is not None
 
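As the added comment notes, only each fold's test indices are serialized and staged; the matching train indices are rebuilt inside the UDTF from the full index range. A small standalone sketch of that reconstruction (KFold is an arbitrary choice here):

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(20).reshape(10, 2)
# What gets staged: the test indices only, one array per fold.
test_indices = [test for _, test in KFold(n_splits=5).split(X)]

# What process() recomputes per fold: train indices as the complement.
full_indices = np.arange(len(X))
for test_indice in test_indices:
    train_indice = np.setdiff1d(full_indices, test_indice)
    assert len(train_indice) + len(test_indice) == len(X)
```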
@@ -261,7 +383,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                for file_name in data_files
            ]
            df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(col) for col in df.columns]
+            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
 
            # load estimator
            local_estimator_file_path = os.path.join(
@@ -299,16 +421,30 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                self.data_length = data_length
                self.params_to_evaluate = params_to_evaluate
 
-            def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]:
+            def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]:
+                # Assign parameters to GridSearchCV
                if hasattr(estimator, "param_grid"):
                    self.estimator.param_grid = self.params_to_evaluate[params_idx]
+                # Assign parameters to RandomizedSearchCV
                else:
                    self.estimator.param_distributions = self.params_to_evaluate[params_idx]
+                # cross validator's indices: we stored test indices only (to save space);
+                # use the full index range to reconstruct the train indices.
                full_indices = np.array([i for i in range(self.data_length)])
-                test_indice = self.indices[idx]
+                test_indice = self.indices[cv_idx]
                train_indice = np.setdiff1d(full_indices, test_indice)
+                # assign the tuple of train and test indices to the estimator's original cross validator
                self.estimator.cv = [(train_indice, test_indice)]
                self.estimator.fit(**self.args)
+                # If cv_results_ is empty, the udtf table would have a different number of output rows
+                # than input rows, so raise an error here.
+                if not self.estimator.cv_results_:
+                    raise RuntimeError(
+                        """Cross validation results are unexpectedly empty for one fold.
+                        This issue may be temporary - please try again."""
+                    )
+                # Encode the cv_results_ dictionary as binary (in hex format) to send it back,
+                # because the udtf doesn't allow numpy within json
                binary_cv_results = None
                with io.BytesIO() as f:
                    cp.dump(self.estimator.cv_results_, f)
@@ -333,96 +469,44 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 
        HP_TUNING = F.table_function(random_udtf_name)
 
-        idx_length = int(indices_len)
-        params_length = len(param_grid)
-        idxs = [i for i in range(idx_length)]
-        param_indices, training_indices = [], []
-        for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs):
+        # param_indices is the index for each parameter grid;
+        # cv_indices is the index for each cross_validator's fold;
+        # param_cv_indices is the index for the product of (len(param_indices) * len(cv_indices))
+        param_indices, cv_indices = [], []
+        for param_idx, cv_idx in product(
+            [param_index for param_index in range(parameter_grid_length)],
+            [cv_index for cv_index in range(cross_validator_indices_length)],
+        ):
            param_indices.append(param_idx)
-            training_indices.append(cv_idx)
+            cv_indices.append(cv_idx)
 
-        pd_df = pd.DataFrame(
+        indices_info_pandas = pd.DataFrame(
            {
-                "PARAMS": param_indices,
-                "TRAIN_IND": training_indices,
-                "PARAM_INDEX": [i for i in range(idx_length * params_length)],
+                "PARAM_IND": param_indices,
+                "CV_IND": cv_indices,
+                "PARAM_CV_IND": [i for i in range(cross_validator_indices_length * parameter_grid_length)],
            }
        )
-        df = session.create_dataframe(pd_df)
-        results = df.select(
-            F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"),
-            (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])),
+        indices_info_sp = session.create_dataframe(indices_info_pandas)
+        # execute the udtf by querying the HP_TUNING table function
+        HP_raw_results = indices_info_sp.select(
+            F.cast(indices_info_sp["PARAM_CV_IND"], IntegerType()).as_("PARAM_CV_IND"),
+            (
+                HP_TUNING(indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
+                    partition_by=indices_info_sp["PARAM_CV_IND"]
+                )
+            ),
        )
-
-        # cv_result maintains the original order
-        multimetric = False
-        cv_results_ = dict()
-        scorers = set()
-        for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()):
-            # retrieved string had one more double quote in the front and end of the string.
-            # use [1:-1] to remove the extra double quotes
-            hex_str = bytes.fromhex(val[0])
-            with io.BytesIO(hex_str) as f_reload:
-                each_cv_result = cp.load(f_reload)
-            for k, v in each_cv_result.items():
-                cur_cv = i % idx_length
-                key = k
-                if "split0_test_" in k:
-                    # For multi-metric evaluation, the scores for all the scorers are available in the
-                    # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
-                    # instead of '_score'.
-                    scorers.add(k[len("split0_test_") :])
-                    key = k.replace("split0_test", f"split{cur_cv}_test")
-                elif k.startswith("param"):
-                    if cur_cv != 0:
-                        key = False
-                if key:
-                    if key not in cv_results_:
-                        cv_results_[key] = v
-                    else:
-                        cv_results_[key] = np.concatenate([cv_results_[key], v])
-
-        multimetric = len(scorers) > 1
-        # Use numpy to re-calculate all the information in cv_results_ again
-        # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
-        # and average them by the idx_length;
-        # idx_length is the number of cv folds; params_length is the number of parameter combinations
-        scores = [
-            np.reshape(
-                np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]),
-                (idx_length, -1),
-            )
-            for score in scorers
-        ]
-
-        fit_score_test_matrix = np.stack(
-            [
-                np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)),
-                np.reshape(cv_results_["mean_score_time"], (idx_length, -1)),
-            ]
-            + scores
+        # multimetric, cv_results_, best_param_index, scorers
+        multimetric, cv_results_ = construct_cv_results(
+            estimator,
+            n_splits,
+            list(param_grid),
+            HP_raw_results.select("CV_RESULTS").sort(F.col("PARAM_CV_IND")).collect(),
+            cross_validator_indices_length,
+            parameter_grid_length,
        )
 
-        mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
-        std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
-        cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
-        cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
-        cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
-        cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
-        for idx, score in enumerate(scorers):
-            cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
-            cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
-            # re-compute the ranking again with mean_test_<score>.
-            cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
-            # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
-            # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
-            # In that case, default to first index.
-            best_param_index = (
-                np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
-                if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
-                else 0
-            )
-
        estimator.cv_results_ = cv_results_
        estimator.multimetric_ = multimetric
 
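The renamed dispatch table is easier to follow with a concrete shape in mind: one row per (parameter grid, CV fold) pair, with `PARAM_CV_IND` as the partition key so each pair runs in its own UDTF partition. A toy version with made-up sizes:

```python
from itertools import product

import pandas as pd

parameter_grid_length, cross_validator_indices_length = 2, 3
param_indices, cv_indices = [], []
for param_idx, cv_idx in product(range(parameter_grid_length), range(cross_validator_indices_length)):
    param_indices.append(param_idx)
    cv_indices.append(cv_idx)

indices_info_pandas = pd.DataFrame(
    {
        "PARAM_IND": param_indices,  # which parameter-grid entry to fit
        "CV_IND": cv_indices,  # which fold's test indices to use
        "PARAM_CV_IND": list(range(len(param_indices))),  # unique partition key, 0..5
    }
)
print(indices_info_pandas)  # 6 rows: params 0..1 crossed with folds 0..2
```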
@@ -452,7 +536,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
            # With a non-custom callable, we can select the best score
            # based on the best index
            estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
-            estimator.best_params_ = cv_results_["params"][best_param_index]
+            estimator.best_params_ = cv_results_["params"][estimator.best_index_]
 
        if original_refit:
            estimator.best_estimator_ = clone(estimator.estimator).set_params(
@@ -541,12 +625,15 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                n_iter=self.estimator.n_iter,
                random_state=self.estimator.random_state,
            )
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
        return self.fit_search_snowpark(
            param_grid=param_grid,
            dataset=self.dataset,
            session=self.session,
            estimator=self.estimator,
-            dependencies=model_spec.pkgDependencies,
+            dependencies=relaxed_dependencies,
            udf_imports=["sklearn"],
            input_cols=self.input_cols,
            label_cols=self.label_cols,
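
Both this trainer and SnowparkModelTrainer (see the last file in this diff) now pass their package pins through `pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel` before registering the sproc, so pinned versions missing from the Snowflake conda channel no longer fail sproc creation. A hedged sketch of the call site; the pins are hypothetical and an active Snowpark `session` is assumed:

```python
from snowflake.ml._internal.utils import pkg_version_utils

pkg_dependencies = ["scikit-learn==1.3.0", "xgboost==1.7.6"]  # hypothetical pins
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
    pkg_versions=pkg_dependencies, session=session  # `session` is assumed to exist
)
packages = ["snowflake-snowpark-python"] + relaxed_dependencies
```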
snowflake/ml/modeling/_internal/estimator_utils.py

@@ -132,3 +132,24 @@ def is_single_node(session: Session) -> bool:
    # If current session cannot retrieve the warehouse name back,
    # Default as True; Let HPO fall back to stored procedure implementation
    return True
+
+
+def get_module_name(model: object) -> str:
+    """Returns the source module of the given object.
+
+    Args:
+        model: Object to inspect.
+
+    Returns:
+        Source module of the given object.
+
+    Raises:
+        SnowflakeMLException: If the source module of the given object is not found.
+    """
+    module = inspect.getmodule(model)
+    if module is None:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_TYPE,
+            original_exception=ValueError(f"Unable to infer the source module of the given object {model}."),
+        )
+    return module.__name__
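
The new helper is a thin wrapper around `inspect.getmodule` that raises a typed error instead of returning `None`; callers then dispatch on the root package name. Roughly:

```python
import inspect

from sklearn.linear_model import LinearRegression

module = inspect.getmodule(LinearRegression())
assert module is not None
print(module.__name__)  # e.g. "sklearn.linear_model._base"
print(module.__name__.split(".")[0])  # "sklearn", the root name the builders dispatch on
```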
snowflake/ml/modeling/_internal/model_specifications.py

@@ -1,10 +1,9 @@
-import inspect
 from typing import List
 
 import cloudpickle as cp
 import numpy as np
 
-from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml.modeling._internal.estimator_utils import get_module_name
 
 
 class ModelSpecifications:
@@ -120,16 +119,10 @@ class ModelSpecificationsBuilder:
            Appropriate ModelSpecification object
 
        Raises:
-            SnowflakeMLException: Raises an exception the module of given model can't be determined.
            TypeError: Raises the exception for unsupported modules.
        """
-        module = inspect.getmodule(model)
-        if module is None:
-            raise exceptions.SnowflakeMLException(
-                error_code=error_codes.INVALID_TYPE,
-                original_exception=ValueError("Unable to infer model type of the given native model object."),
-            )
-        root_module_name = module.__name__.split(".")[0]
+        module_name = get_module_name(model=model)
+        root_module_name = module_name.split(".")[0]
        if root_module_name == "sklearn":
            from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
snowflake/ml/modeling/_internal/model_trainer_builder.py

@@ -3,13 +3,20 @@ from typing import List, Optional, Union
 import pandas as pd
 from sklearn import model_selection
 
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.modeling._internal.distributed_hpo_trainer import (
     DistributedHPOTrainer,
 )
-from snowflake.ml.modeling._internal.estimator_utils import is_single_node
+from snowflake.ml.modeling._internal.estimator_utils import (
+    get_module_name,
+    is_single_node,
+)
 from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
 from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer
 from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import (
+    XGBoostExternalMemoryTrainer,
+)
 from snowflake.snowpark import DataFrame, Session
 
 _PROJECT = "ModelDevelopment"
@@ -30,6 +37,31 @@ class ModelTrainerBuilder:
    def _check_if_distributed_hpo_enabled(cls, session: Session) -> bool:
        return not is_single_node(session) and ModelTrainerBuilder._ENABLE_DISTRIBUTED is True
 
+    @classmethod
+    def _validate_external_memory_params(cls, estimator: object, batch_size: int) -> None:
+        """
+        Validate that the params are set appropriately for external memory training.
+
+        Args:
+            estimator: Model object
+            batch_size: Number of rows in each batch of data processed during training.
+
+        Raises:
+            SnowflakeMLException: If the params are not appropriate for the external memory training feature.
+        """
+        module_name = get_module_name(model=estimator)
+        root_module_name = module_name.split(".")[0]
+        if root_module_name != "xgboost":
+            raise exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=RuntimeError("External memory training is only supported for XGBoost models."),
+            )
+        if batch_size <= 0:
+            raise exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=RuntimeError("Batch size must be > 0 when using the external memory training feature."),
+            )
+
    @classmethod
    def build(
        cls,
@@ -40,6 +72,8 @@ class ModelTrainerBuilder:
        sample_weight_col: Optional[str] = None,
        autogenerated: bool = False,
        subproject: str = "",
+        use_external_memory_version: bool = False,
+        batch_size: int = -1,
    ) -> ModelTrainer:
        """
        Builder method that creates an appropriate ModelTrainer instance based on the given params.
@@ -55,22 +89,32 @@ class ModelTrainerBuilder:
            )
        elif isinstance(dataset, DataFrame):
            trainer_klass = SnowparkModelTrainer
+            init_args = {
+                "estimator": estimator,
+                "dataset": dataset,
+                "session": dataset._session,
+                "input_cols": input_cols,
+                "label_cols": label_cols,
+                "sample_weight_col": sample_weight_col,
+                "autogenerated": autogenerated,
+                "subproject": subproject,
+            }
+
            assert dataset._session is not None  # Make MyPy happy
            if isinstance(estimator, model_selection.GridSearchCV) or isinstance(
                estimator, model_selection.RandomizedSearchCV
            ):
                if ModelTrainerBuilder._check_if_distributed_hpo_enabled(session=dataset._session):
                    trainer_klass = DistributedHPOTrainer
-            return trainer_klass(
-                estimator=estimator,
-                dataset=dataset,
-                session=dataset._session,
-                input_cols=input_cols,
-                label_cols=label_cols,
-                sample_weight_col=sample_weight_col,
-                autogenerated=autogenerated,
-                subproject=subproject,
-            )
+            elif use_external_memory_version:
+                ModelTrainerBuilder._validate_external_memory_params(
+                    estimator=estimator,
+                    batch_size=batch_size,
+                )
+                trainer_klass = XGBoostExternalMemoryTrainer
+                init_args["batch_size"] = batch_size
+
+            return trainer_klass(**init_args)  # type: ignore[arg-type]
        else:
            raise TypeError(
                f"Unexpected dataset type: {type(dataset)}."
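
A self-contained approximation of the external-memory gate wired in above (exception types simplified; the real code wraps errors in `SnowflakeMLException`):

```python
import inspect

import xgboost as xgb
from sklearn.linear_model import LinearRegression


def validate_external_memory_params(estimator: object, batch_size: int) -> None:
    module = inspect.getmodule(estimator)
    root_module_name = (module.__name__ if module else "").split(".")[0]
    if root_module_name != "xgboost":
        raise RuntimeError("External memory training is only supported for XGBoost models.")
    if batch_size <= 0:
        raise RuntimeError("Batch size must be > 0 when using external memory training.")


validate_external_memory_params(xgb.XGBClassifier(), batch_size=10_000)  # passes
try:
    validate_external_memory_params(LinearRegression(), batch_size=10_000)
except RuntimeError as err:
    print(err)  # External memory training is only supported for XGBoost models.
```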
snowflake/ml/modeling/_internal/snowpark_handlers.py

@@ -306,7 +306,7 @@ class SnowparkHandlers:
            input_cols: List[str],
            label_cols: List[str],
            sample_weight_col: Optional[str],
-            statement_params: Dict[str, str],
+            score_statement_params: Dict[str, str],
        ) -> float:
            import inspect
            import os
@@ -317,13 +317,13 @@ class SnowparkHandlers:
                importlib.import_module(import_name)
 
            for query in sql_queries[:-1]:
-                _ = session.sql(query).collect(statement_params=statement_params)
+                _ = session.sql(query).collect(statement_params=score_statement_params)
            sp_df = session.sql(sql_queries[-1])
-            df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
+            df: pd.DataFrame = sp_df.to_pandas(statement_params=score_statement_params)
            df.columns = sp_df.columns
 
            local_score_file_name = get_temp_file_path()
-            session.file.get(stage_score_file_name, local_score_file_name, statement_params=statement_params)
+            session.file.get(stage_score_file_name, local_score_file_name, statement_params=score_statement_params)
 
            local_score_file_name_path = os.path.join(local_score_file_name, os.listdir(local_score_file_name)[0])
            with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
@@ -348,7 +348,7 @@ class SnowparkHandlers:
            return result
 
        # Call score sproc
-        statement_params = telemetry.get_function_usage_statement_params(
+        score_statement_params = telemetry.get_function_usage_statement_params(
            project=_PROJECT,
            subproject=self._subproject,
            function_name=telemetry.get_statement_params_full_func_name(
@@ -357,6 +357,8 @@ class SnowparkHandlers:
            api_calls=[Session.call],
            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
        )
+
+        kwargs = telemetry.get_sproc_statement_params_kwargs(score_wrapper_sproc, score_statement_params)
        score: float = score_wrapper_sproc(
            session,
            queries,
@@ -364,7 +366,8 @@ class SnowparkHandlers:
            input_cols,
            label_cols,
            sample_weight_col,
-            statement_params,
+            score_statement_params,
+            **kwargs,
        )
 
        cleanup_temp_files([local_score_file_name])
snowflake/ml/modeling/_internal/snowpark_trainer.py

@@ -12,7 +12,11 @@ from snowflake.ml._internal.exceptions import (
     exceptions,
     modeling_error_messages,
 )
-from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
+from snowflake.ml._internal.utils import (
+    identifier,
+    pkg_version_utils,
+    snowpark_dataframe_utils,
+)
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml._internal.utils.temp_file_utils import (
     cleanup_temp_files,
@@ -253,11 +257,15 @@ class SnowparkModelTrainer:
 
        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
 
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
        fit_wrapper_sproc = self.session.sproc.register(
            func=self._build_fit_wrapper_sproc(model_spec=model_spec),
            is_permanent=False,
            name=fit_sproc_name,
-            packages=["snowflake-snowpark-python"] + model_spec.pkgDependencies,  # type: ignore[arg-type]
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
            replace=True,
            session=self.session,
            statement_params=statement_params,