snowflake-ml-python 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178):
  1. snowflake/ml/_internal/env_utils.py +16 -13
  2. snowflake/ml/_internal/exceptions/modeling_error_messages.py +5 -1
  3. snowflake/ml/feature_store/__init__.py +9 -0
  4. snowflake/ml/feature_store/entity.py +73 -0
  5. snowflake/ml/feature_store/feature_store.py +1657 -0
  6. snowflake/ml/feature_store/feature_view.py +459 -0
  7. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +9 -1
  8. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  9. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +12 -2
  10. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +7 -3
  11. snowflake/ml/model/model_signature.py +72 -16
  12. snowflake/ml/model/type_hints.py +9 -0
  13. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -41
  14. snowflake/ml/modeling/_internal/model_trainer_builder.py +13 -9
  15. snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py} +3 -1
  16. snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py} +3 -1
  17. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +3 -3
  18. snowflake/ml/modeling/cluster/affinity_propagation.py +3 -3
  19. snowflake/ml/modeling/cluster/agglomerative_clustering.py +3 -3
  20. snowflake/ml/modeling/cluster/birch.py +3 -3
  21. snowflake/ml/modeling/cluster/bisecting_k_means.py +3 -3
  22. snowflake/ml/modeling/cluster/dbscan.py +3 -3
  23. snowflake/ml/modeling/cluster/feature_agglomeration.py +3 -3
  24. snowflake/ml/modeling/cluster/k_means.py +3 -3
  25. snowflake/ml/modeling/cluster/mean_shift.py +3 -3
  26. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  27. snowflake/ml/modeling/cluster/optics.py +3 -3
  28. snowflake/ml/modeling/cluster/spectral_biclustering.py +3 -3
  29. snowflake/ml/modeling/cluster/spectral_clustering.py +3 -3
  30. snowflake/ml/modeling/cluster/spectral_coclustering.py +3 -3
  31. snowflake/ml/modeling/compose/column_transformer.py +3 -3
  32. snowflake/ml/modeling/compose/transformed_target_regressor.py +3 -3
  33. snowflake/ml/modeling/covariance/elliptic_envelope.py +3 -3
  34. snowflake/ml/modeling/covariance/empirical_covariance.py +3 -3
  35. snowflake/ml/modeling/covariance/graphical_lasso.py +3 -3
  36. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +3 -3
  37. snowflake/ml/modeling/covariance/ledoit_wolf.py +3 -3
  38. snowflake/ml/modeling/covariance/min_cov_det.py +3 -3
  39. snowflake/ml/modeling/covariance/oas.py +3 -3
  40. snowflake/ml/modeling/covariance/shrunk_covariance.py +3 -3
  41. snowflake/ml/modeling/decomposition/dictionary_learning.py +3 -3
  42. snowflake/ml/modeling/decomposition/factor_analysis.py +3 -3
  43. snowflake/ml/modeling/decomposition/fast_ica.py +3 -3
  44. snowflake/ml/modeling/decomposition/incremental_pca.py +3 -3
  45. snowflake/ml/modeling/decomposition/kernel_pca.py +3 -3
  46. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +3 -3
  47. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +3 -3
  48. snowflake/ml/modeling/decomposition/pca.py +3 -3
  49. snowflake/ml/modeling/decomposition/sparse_pca.py +3 -3
  50. snowflake/ml/modeling/decomposition/truncated_svd.py +3 -3
  51. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +3 -3
  52. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +3 -3
  53. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +3 -3
  54. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +3 -3
  55. snowflake/ml/modeling/ensemble/bagging_classifier.py +3 -3
  56. snowflake/ml/modeling/ensemble/bagging_regressor.py +3 -3
  57. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +3 -3
  58. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +3 -3
  59. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +3 -3
  60. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +3 -3
  61. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +3 -3
  62. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +3 -3
  63. snowflake/ml/modeling/ensemble/isolation_forest.py +3 -3
  64. snowflake/ml/modeling/ensemble/random_forest_classifier.py +3 -3
  65. snowflake/ml/modeling/ensemble/random_forest_regressor.py +3 -3
  66. snowflake/ml/modeling/ensemble/stacking_regressor.py +3 -3
  67. snowflake/ml/modeling/ensemble/voting_classifier.py +3 -3
  68. snowflake/ml/modeling/ensemble/voting_regressor.py +3 -3
  69. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +3 -3
  70. snowflake/ml/modeling/feature_selection/select_fdr.py +3 -3
  71. snowflake/ml/modeling/feature_selection/select_fpr.py +3 -3
  72. snowflake/ml/modeling/feature_selection/select_fwe.py +3 -3
  73. snowflake/ml/modeling/feature_selection/select_k_best.py +3 -3
  74. snowflake/ml/modeling/feature_selection/select_percentile.py +3 -3
  75. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +3 -3
  76. snowflake/ml/modeling/feature_selection/variance_threshold.py +3 -3
  77. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +3 -3
  78. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +3 -3
  79. snowflake/ml/modeling/impute/iterative_imputer.py +3 -3
  80. snowflake/ml/modeling/impute/knn_imputer.py +3 -3
  81. snowflake/ml/modeling/impute/missing_indicator.py +3 -3
  82. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +3 -3
  83. snowflake/ml/modeling/kernel_approximation/nystroem.py +3 -3
  84. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +3 -3
  85. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +3 -3
  86. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +3 -3
  87. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +3 -3
  88. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +3 -3
  89. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +3 -3
  90. snowflake/ml/modeling/linear_model/ard_regression.py +3 -3
  91. snowflake/ml/modeling/linear_model/bayesian_ridge.py +3 -3
  92. snowflake/ml/modeling/linear_model/elastic_net.py +3 -3
  93. snowflake/ml/modeling/linear_model/elastic_net_cv.py +3 -3
  94. snowflake/ml/modeling/linear_model/gamma_regressor.py +3 -3
  95. snowflake/ml/modeling/linear_model/huber_regressor.py +3 -3
  96. snowflake/ml/modeling/linear_model/lars.py +3 -3
  97. snowflake/ml/modeling/linear_model/lars_cv.py +3 -3
  98. snowflake/ml/modeling/linear_model/lasso.py +3 -3
  99. snowflake/ml/modeling/linear_model/lasso_cv.py +3 -3
  100. snowflake/ml/modeling/linear_model/lasso_lars.py +3 -3
  101. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +3 -3
  102. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +3 -3
  103. snowflake/ml/modeling/linear_model/linear_regression.py +3 -3
  104. snowflake/ml/modeling/linear_model/logistic_regression.py +3 -3
  105. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +3 -3
  106. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +3 -3
  107. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +3 -3
  108. snowflake/ml/modeling/linear_model/multi_task_lasso.py +3 -3
  109. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +3 -3
  110. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +3 -3
  111. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +3 -3
  112. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +3 -3
  113. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  114. snowflake/ml/modeling/linear_model/poisson_regressor.py +3 -3
  115. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -3
  116. snowflake/ml/modeling/linear_model/ridge.py +3 -3
  117. snowflake/ml/modeling/linear_model/ridge_classifier.py +3 -3
  118. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +3 -3
  119. snowflake/ml/modeling/linear_model/ridge_cv.py +3 -3
  120. snowflake/ml/modeling/linear_model/sgd_classifier.py +3 -3
  121. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +3 -3
  122. snowflake/ml/modeling/linear_model/sgd_regressor.py +3 -3
  123. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +3 -3
  124. snowflake/ml/modeling/linear_model/tweedie_regressor.py +3 -3
  125. snowflake/ml/modeling/manifold/isomap.py +3 -3
  126. snowflake/ml/modeling/manifold/mds.py +3 -3
  127. snowflake/ml/modeling/manifold/spectral_embedding.py +3 -3
  128. snowflake/ml/modeling/manifold/tsne.py +3 -3
  129. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +3 -3
  130. snowflake/ml/modeling/mixture/gaussian_mixture.py +3 -3
  131. snowflake/ml/modeling/model_selection/grid_search_cv.py +3 -13
  132. snowflake/ml/modeling/model_selection/randomized_search_cv.py +3 -13
  133. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +3 -3
  134. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +3 -3
  135. snowflake/ml/modeling/multiclass/output_code_classifier.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  138. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  139. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +3 -3
  140. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  141. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +3 -3
  142. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +3 -3
  143. snowflake/ml/modeling/neighbors/kernel_density.py +3 -3
  144. snowflake/ml/modeling/neighbors/local_outlier_factor.py +3 -3
  145. snowflake/ml/modeling/neighbors/nearest_centroid.py +3 -3
  146. snowflake/ml/modeling/neighbors/nearest_neighbors.py +3 -3
  147. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +3 -3
  148. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +3 -3
  149. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +3 -3
  150. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +3 -3
  151. snowflake/ml/modeling/neural_network/mlp_classifier.py +3 -3
  152. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -3
  153. snowflake/ml/modeling/preprocessing/polynomial_features.py +3 -3
  154. snowflake/ml/modeling/semi_supervised/label_propagation.py +3 -3
  155. snowflake/ml/modeling/semi_supervised/label_spreading.py +3 -3
  156. snowflake/ml/modeling/svm/linear_svc.py +3 -3
  157. snowflake/ml/modeling/svm/linear_svr.py +3 -3
  158. snowflake/ml/modeling/svm/nu_svc.py +3 -3
  159. snowflake/ml/modeling/svm/nu_svr.py +3 -3
  160. snowflake/ml/modeling/svm/svc.py +3 -3
  161. snowflake/ml/modeling/svm/svr.py +3 -3
  162. snowflake/ml/modeling/tree/decision_tree_classifier.py +3 -3
  163. snowflake/ml/modeling/tree/decision_tree_regressor.py +3 -3
  164. snowflake/ml/modeling/tree/extra_tree_classifier.py +3 -3
  165. snowflake/ml/modeling/tree/extra_tree_regressor.py +3 -3
  166. snowflake/ml/modeling/xgboost/xgb_classifier.py +3 -3
  167. snowflake/ml/modeling/xgboost/xgb_regressor.py +3 -3
  168. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +3 -3
  169. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +3 -3
  170. snowflake/ml/version.py +1 -1
  171. {snowflake_ml_python-1.2.1.dist-info → snowflake_ml_python-1.2.2.dist-info}/METADATA +16 -1
  172. {snowflake_ml_python-1.2.1.dist-info → snowflake_ml_python-1.2.2.dist-info}/RECORD +178 -174
  173. /snowflake/ml/modeling/_internal/{pandas_trainer.py → local_implementations/pandas_trainer.py} +0 -0
  174. /snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py} +0 -0
  175. /snowflake/ml/modeling/_internal/{snowpark_trainer.py → snowpark_implementations/snowpark_trainer.py} +0 -0
  176. {snowflake_ml_python-1.2.1.dist-info → snowflake_ml_python-1.2.2.dist-info}/LICENSE.txt +0 -0
  177. {snowflake_ml_python-1.2.1.dist-info → snowflake_ml_python-1.2.2.dist-info}/WHEEL +0 -0
  178. {snowflake_ml_python-1.2.1.dist-info → snowflake_ml_python-1.2.2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,18 @@
1
1
  import enum
2
+ import json
2
3
  import warnings
3
- from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Type
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ List,
8
+ Literal,
9
+ Optional,
10
+ Sequence,
11
+ Tuple,
12
+ Type,
13
+ Union,
14
+ cast,
15
+ )
4
16
 
5
17
  import numpy as np
6
18
  import pandas as pd
@@ -337,6 +349,31 @@ class SnowparkIdentifierRule(enum.Enum):
337
349
  assert_never(self)
338
350
 
339
351
 
352
+ def _get_dataframe_values_range(
353
+ df: snowflake.snowpark.DataFrame,
354
+ ) -> Dict[str, Union[Tuple[int, int], Tuple[float, float]]]:
355
+ columns = [
356
+ F.array_construct(F.min(field.name), F.max(field.name)).as_(field.name)
357
+ for field in df.schema.fields
358
+ if isinstance(field.datatype, spt._NumericType)
359
+ ]
360
+ if not columns:
361
+ return {}
362
+ res = df.select(columns).collect()
363
+ if len(res) != 1:
364
+ raise snowml_exceptions.SnowflakeMLException(
365
+ error_code=error_codes.INTERNAL_SNOWML_ERROR,
366
+ original_exception=ValueError(f"Unable to get the value range of fields {df.columns}"),
367
+ )
368
+ return cast(
369
+ Dict[str, Union[Tuple[int, int], Tuple[float, float]]],
370
+ {
371
+ sql_identifier.SqlIdentifier(k, case_sensitive=True).identifier(): (json.loads(v)[0], json.loads(v)[1])
372
+ for k, v in res[0].as_dict().items()
373
+ },
374
+ )
375
+
376
+
340
377
  def _validate_snowpark_data(
341
378
  data: snowflake.snowpark.DataFrame, features: Sequence[core.BaseFeatureSpec]
342
379
  ) -> SnowparkIdentifierRule:
@@ -361,6 +398,7 @@ def _validate_snowpark_data(
361
398
  SnowparkIdentifierRule.NORMALIZED: [],
362
399
  }
363
400
  schema = data.schema
401
+ values_range = _get_dataframe_values_range(data)
364
402
  for identifier_rule in errors.keys():
365
403
  for feature in features:
366
404
  try:
@@ -401,8 +439,11 @@ def _validate_snowpark_data(
401
439
  + f"Feature is a scalar feature, while {field.name} is not."
402
440
  ),
403
441
  )
442
+ continue
404
443
  try:
405
- _validate_snowpark_type_feature(data, field, ft_type, feature.name)
444
+ _validate_snowpark_type_feature(
445
+ data, field, ft_type, feature.name, values_range.get(field.name, None)
446
+ )
406
447
  except snowml_exceptions.SnowflakeMLException as e:
407
448
  errors[identifier_rule].append(e.original_exception)
408
449
  break
@@ -433,17 +474,12 @@ If using the inferred names from model signatures, there are the following error
433
474
 
434
475
 
435
476
  def _validate_snowpark_type_feature(
436
- df: snowflake.snowpark.DataFrame, field: spt.StructField, ft_type: DataType, ft_name: str
477
+ df: snowflake.snowpark.DataFrame,
478
+ field: spt.StructField,
479
+ ft_type: DataType,
480
+ ft_name: str,
481
+ value_range: Optional[Union[Tuple[int, int], Tuple[float, float]]],
437
482
  ) -> None:
438
- def get_value_range(field_name: str) -> Tuple[int, int]:
439
- res = df.select(F.min(field_name).as_("MIN"), F.max(field_name).as_("MAX")).collect()
440
- if len(res) != 1:
441
- raise snowml_exceptions.SnowflakeMLException(
442
- error_code=error_codes.INTERNAL_SNOWML_ERROR,
443
- original_exception=ValueError(f"Unable to get the value range of field {field_name}"),
444
- )
445
- return res[0].MIN, res[0].MAX
446
-
447
483
  field_data_type = field.datatype
448
484
  col_name = identifier.get_unescaped_names(field.name)
449
485
 
@@ -465,16 +501,27 @@ def _validate_snowpark_type_feature(
465
501
  error_code=error_codes.INVALID_DATA,
466
502
  original_exception=ValueError(
467
503
  f"Data Validation Error in feature {ft_name}: "
468
- + f"Feature type {ft_type} is not met by column {col_name}."
504
+ f"Feature type {ft_type} is not met by column {col_name} "
505
+ f"because of its original type {field_data_type}"
506
+ ),
507
+ )
508
+ if value_range is None:
509
+ raise snowml_exceptions.SnowflakeMLException(
510
+ error_code=error_codes.INVALID_DATA,
511
+ original_exception=ValueError(
512
+ f"Data Validation Error in feature {ft_name}: "
513
+ f"Feature type {ft_type} is not met by column {col_name} "
514
+ f"because of its original type {field_data_type} is non-Numeric."
469
515
  ),
470
516
  )
471
- min_v, max_v = get_value_range(field.name)
517
+ min_v, max_v = value_range
472
518
  if max_v > np.iinfo(ft_type._numpy_type).max or min_v < np.iinfo(ft_type._numpy_type).min:
473
519
  raise snowml_exceptions.SnowflakeMLException(
474
520
  error_code=error_codes.INVALID_DATA,
475
521
  original_exception=ValueError(
476
522
  f"Data Validation Error in feature {ft_name}: "
477
- + f"Feature type {ft_type} is not met by column {col_name}."
523
+ f"Feature type {ft_type} is not met by column {col_name} "
524
+ f"because it overflows with min"
478
525
  ),
479
526
  )
480
527
  elif ft_type in [core.DataType.FLOAT, core.DataType.DOUBLE]:
@@ -494,7 +541,16 @@ def _validate_snowpark_type_feature(
494
541
  + f"Feature type {ft_type} is not met by column {col_name}."
495
542
  ),
496
543
  )
497
- min_v, max_v = get_value_range(field.name)
544
+ if value_range is None:
545
+ raise snowml_exceptions.SnowflakeMLException(
546
+ error_code=error_codes.INVALID_DATA,
547
+ original_exception=ValueError(
548
+ f"Data Validation Error in feature {ft_name}: "
549
+ f"Feature type {ft_type} is not met by column {col_name} "
550
+ f"because of its original type {field_data_type} is non-Numeric."
551
+ ),
552
+ )
553
+ min_v, max_v = value_range
498
554
  if (
499
555
  max_v > np.finfo(ft_type._numpy_type).max # type: ignore[arg-type]
500
556
  or min_v < np.finfo(ft_type._numpy_type).min # type: ignore[arg-type]
@@ -3,6 +3,7 @@ from typing import (
3
3
  TYPE_CHECKING,
4
4
  Any,
5
5
  Dict,
6
+ List,
6
7
  Literal,
7
8
  Optional,
8
9
  Sequence,
@@ -173,6 +174,13 @@ class SnowparkContainerServiceDeployOptions(DeployOptions):
173
174
  debug_mode: When set to True, deployment artifacts will be persisted in a local temp directory.
174
175
  enable_ingress: When set to True, will expose HTTP endpoint for access to the predict method of the created
175
176
  service.
177
+ external_access_integrations: External Access Integrations name used to build image and deploy the model.
178
+ Please refer to the doc for how to create an External Access Integrations: https://docs.snowflake.com/
179
+ developer-guide/snowpark-container-services/additional-considerations-services-jobs
180
+ #configuring-network-capabilities .
181
+ To make sure your image could be built, access to the following endpoint must be allowed.
182
+ docker.com:80, docker.com:443, anaconda.com:80, anaconda.com:443, anaconda.org:80, anaconda.org:443,
183
+ pypi.org:80, pypi.org:443
176
184
  """
177
185
 
178
186
  compute_pool: str
@@ -187,6 +195,7 @@ class SnowparkContainerServiceDeployOptions(DeployOptions):
187
195
  model_in_image: NotRequired[bool]
188
196
  debug_mode: NotRequired[bool]
189
197
  enable_ingress: NotRequired[bool]
198
+ external_access_integrations: List[str]
190
199
 
191
200
 
192
201
  class ModelMethodSaveOptions(TypedDict):
@@ -6,47 +6,7 @@ from snowflake.snowpark import DataFrame, Session
6
6
 
7
7
 
8
8
  # TODO: Add more specific entities to type hint estimators instead of using `object`.
9
- class FitPredictHandlers(Protocol):
10
- def batch_inference(
11
- self,
12
- dataset: DataFrame,
13
- session: Session,
14
- estimator: object,
15
- dependencies: List[str],
16
- inference_method: str,
17
- input_cols: List[str],
18
- pass_through_columns: List[str],
19
- expected_output_cols_list: List[str],
20
- expected_output_cols_type: str = "",
21
- ) -> DataFrame:
22
- raise NotImplementedError
23
-
24
- def score_pandas(
25
- self,
26
- dataset: pd.DataFrame,
27
- estimator: object,
28
- input_cols: List[str],
29
- label_cols: List[str],
30
- sample_weight_col: Optional[str],
31
- ) -> float:
32
- raise NotImplementedError
33
-
34
- def score_snowpark(
35
- self,
36
- dataset: DataFrame,
37
- session: Session,
38
- estimator: object,
39
- dependencies: List[str],
40
- score_sproc_imports: List[str],
41
- input_cols: List[str],
42
- label_cols: List[str],
43
- sample_weight_col: Optional[str],
44
- ) -> float:
45
- raise NotImplementedError
46
-
47
-
48
- # TODO: Add more specific entities to type hint estimators instead of using `object`.
49
- class CVHandlers(Protocol):
9
+ class TransformerHandlers(Protocol):
50
10
  def batch_inference(
51
11
  self,
52
12
  dataset: DataFrame,
@@ -4,17 +4,21 @@ import pandas as pd
4
4
  from sklearn import model_selection
5
5
 
6
6
  from snowflake.ml._internal.exceptions import error_codes, exceptions
7
- from snowflake.ml.modeling._internal.distributed_hpo_trainer import (
8
- DistributedHPOTrainer,
9
- )
10
7
  from snowflake.ml.modeling._internal.estimator_utils import (
11
8
  get_module_name,
12
9
  is_single_node,
13
10
  )
11
+ from snowflake.ml.modeling._internal.local_implementations.pandas_trainer import (
12
+ PandasModelTrainer,
13
+ )
14
14
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
15
- from snowflake.ml.modeling._internal.pandas_trainer import PandasModelTrainer
16
- from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
17
- from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import (
15
+ from snowflake.ml.modeling._internal.snowpark_implementations.distributed_hpo_trainer import (
16
+ DistributedHPOTrainer,
17
+ )
18
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_trainer import (
19
+ SnowparkModelTrainer,
20
+ )
21
+ from snowflake.ml.modeling._internal.snowpark_implementations.xgboost_external_memory_trainer import (
18
22
  XGBoostExternalMemoryTrainer,
19
23
  )
20
24
  from snowflake.snowpark import DataFrame, Session
@@ -76,9 +80,9 @@ class ModelTrainerBuilder:
76
80
  batch_size: int = -1,
77
81
  ) -> ModelTrainer:
78
82
  """
79
- Builder method that creates an approproiate ModelTrainer instance based on the given params.
83
+ Builder method that creates an appropriate ModelTrainer instance based on the given params.
80
84
  """
81
- assert input_cols is not None # Make MyPy happpy
85
+ assert input_cols is not None # Make MyPy happy
82
86
  if isinstance(dataset, pd.DataFrame):
83
87
  return PandasModelTrainer(
84
88
  estimator=estimator,
@@ -100,7 +104,7 @@ class ModelTrainerBuilder:
100
104
  "subproject": subproject,
101
105
  }
102
106
 
103
- assert dataset._session is not None # Make MyPy happpy
107
+ assert dataset._session is not None # Make MyPy happy
104
108
  if isinstance(estimator, model_selection.GridSearchCV) or isinstance(
105
109
  estimator, model_selection.RandomizedSearchCV
106
110
  ):
@@ -24,7 +24,9 @@ from snowflake.ml._internal.utils.temp_file_utils import (
24
24
  from snowflake.ml.modeling._internal.model_specifications import (
25
25
  ModelSpecificationsBuilder,
26
26
  )
27
- from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
27
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_trainer import (
28
+ SnowparkModelTrainer,
29
+ )
28
30
  from snowflake.snowpark import DataFrame, Session, functions as F
29
31
  from snowflake.snowpark._internal.utils import (
30
32
  TempObjectType,
@@ -23,7 +23,9 @@ from snowflake.ml.modeling._internal.model_specifications import (
23
23
  ModelSpecifications,
24
24
  ModelSpecificationsBuilder,
25
25
  )
26
- from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
26
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_trainer import (
27
+ SnowparkModelTrainer,
28
+ )
27
29
  from snowflake.snowpark import (
28
30
  DataFrame,
29
31
  Session,
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -229,7 +229,7 @@ class CalibratedClassifierCV(BaseTransformer):
229
229
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
230
230
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
231
231
  self._snowpark_cols: Optional[List[str]] = self.input_cols
232
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
232
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=CalibratedClassifierCV.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
233
233
  self._autogenerated = True
234
234
 
235
235
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -204,7 +204,7 @@ class AffinityPropagation(BaseTransformer):
204
204
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
205
205
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
206
206
  self._snowpark_cols: Optional[List[str]] = self.input_cols
207
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=AffinityPropagation.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
207
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=AffinityPropagation.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
208
208
  self._autogenerated = True
209
209
 
210
210
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -237,7 +237,7 @@ class AgglomerativeClustering(BaseTransformer):
237
237
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
238
238
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
239
239
  self._snowpark_cols: Optional[List[str]] = self.input_cols
240
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=AgglomerativeClustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
240
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=AgglomerativeClustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
241
241
  self._autogenerated = True
242
242
 
243
243
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -195,7 +195,7 @@ class Birch(BaseTransformer):
195
195
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
196
196
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
197
197
  self._snowpark_cols: Optional[List[str]] = self.input_cols
198
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=Birch.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
198
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=Birch.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
199
199
  self._autogenerated = True
200
200
 
201
201
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -244,7 +244,7 @@ class BisectingKMeans(BaseTransformer):
244
244
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
245
245
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
246
246
  self._snowpark_cols: Optional[List[str]] = self.input_cols
247
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=BisectingKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
247
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=BisectingKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
248
248
  self._autogenerated = True
249
249
 
250
250
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -212,7 +212,7 @@ class DBSCAN(BaseTransformer):
212
212
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
213
213
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
214
214
  self._snowpark_cols: Optional[List[str]] = self.input_cols
215
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=DBSCAN.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
215
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=DBSCAN.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
216
216
  self._autogenerated = True
217
217
 
218
218
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -244,7 +244,7 @@ class FeatureAgglomeration(BaseTransformer):
244
244
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
245
245
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
246
246
  self._snowpark_cols: Optional[List[str]] = self.input_cols
247
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=FeatureAgglomeration.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
247
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=FeatureAgglomeration.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
248
248
  self._autogenerated = True
249
249
 
250
250
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -239,7 +239,7 @@ class KMeans(BaseTransformer):
239
239
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
240
240
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
241
241
  self._snowpark_cols: Optional[List[str]] = self.input_cols
242
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=KMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
242
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=KMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
243
243
  self._autogenerated = True
244
244
 
245
245
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -215,7 +215,7 @@ class MeanShift(BaseTransformer):
215
215
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
216
216
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
217
217
  self._snowpark_cols: Optional[List[str]] = self.input_cols
218
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=MeanShift.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
218
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=MeanShift.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
219
219
  self._autogenerated = True
220
220
 
221
221
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -265,7 +265,7 @@ class MiniBatchKMeans(BaseTransformer):
265
265
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
266
266
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
267
267
  self._snowpark_cols: Optional[List[str]] = self.input_cols
268
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=MiniBatchKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
268
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=MiniBatchKMeans.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
269
269
  self._autogenerated = True
270
270
 
271
271
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -285,7 +285,7 @@ class OPTICS(BaseTransformer):
285
285
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
286
286
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
287
287
  self._snowpark_cols: Optional[List[str]] = self.input_cols
288
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=OPTICS.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
288
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=OPTICS.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
289
289
  self._autogenerated = True
290
290
 
291
291
  def _get_rand_id(self) -> str:
@@ -26,7 +26,7 @@ from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
26
  from snowflake.ml._internal.utils import pkg_version_utils, identifier
27
27
  from snowflake.snowpark import DataFrame, Session
28
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
- from snowflake.ml.modeling._internal.snowpark_handlers import SnowparkHandlers as HandlersImpl
29
+ from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import SnowparkHandlers as HandlersImpl
30
30
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
31
  from snowflake.ml.modeling._internal.model_trainer import ModelTrainer
32
32
  from snowflake.ml.modeling._internal.estimator_utils import (
@@ -35,7 +35,7 @@ from snowflake.ml.modeling._internal.estimator_utils import (
35
35
  transform_snowml_obj_to_sklearn_obj,
36
36
  validate_sklearn_args,
37
37
  )
38
- from snowflake.ml.modeling._internal.estimator_protocols import FitPredictHandlers
38
+ from snowflake.ml.modeling._internal.estimator_protocols import TransformerHandlers
39
39
 
40
40
  from snowflake.ml.model.model_signature import (
41
41
  DataType,
@@ -223,7 +223,7 @@ class SpectralBiclustering(BaseTransformer):
223
223
  self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
224
224
  # If user used snowpark dataframe during fit, here it stores the snowpark input_cols, otherwise the processed input_cols
225
225
  self._snowpark_cols: Optional[List[str]] = self.input_cols
226
- self._handlers: FitPredictHandlers = HandlersImpl(class_name=SpectralBiclustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
226
+ self._handlers: TransformerHandlers = HandlersImpl(class_name=SpectralBiclustering.__class__.__name__, subproject=_SUBPROJECT, autogenerated=True)
227
227
  self._autogenerated = True
228
228
 
229
229
  def _get_rand_id(self) -> str: