snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -101,13 +101,5 @@ class ModelMetadataDict(TypedDict):
101
101
  function_properties: NotRequired[Dict[str, Dict[str, Any]]]
102
102
 
103
103
 
104
- class ModelObjective(Enum):
105
- UNKNOWN = "unknown"
106
- BINARY_CLASSIFICATION = "binary_classification"
107
- MULTI_CLASSIFICATION = "multi_classification"
108
- REGRESSION = "regression"
109
- RANKING = "ranking"
110
-
111
-
112
104
  class ModelExplainAlgorithm(Enum):
113
105
  SHAP = "shap"
@@ -47,6 +47,7 @@ class ModelPackager:
47
47
  ext_modules: Optional[List[ModuleType]] = None,
48
48
  code_paths: Optional[List[str]] = None,
49
49
  options: Optional[model_types.ModelSaveOption] = None,
50
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
50
51
  ) -> model_meta.ModelMetadata:
51
52
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
52
53
  raise snowml_exceptions.SnowflakeMLException(
@@ -84,6 +85,7 @@ class ModelPackager:
84
85
  conda_dependencies=conda_dependencies,
85
86
  pip_requirements=pip_requirements,
86
87
  python_version=python_version,
88
+ model_objective=model_objective,
87
89
  **options,
88
90
  ) as meta:
89
91
  model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR)
@@ -30,7 +30,7 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
30
30
 
31
31
  @staticmethod
32
32
  def count(data: Sequence["torch.Tensor"]) -> int:
33
- return min(data_col.shape[0] for data_col in data)
33
+ return min(data_col.shape[0] for data_col in data) # type: ignore[no-any-return]
34
34
 
35
35
  @staticmethod
36
36
  def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
@@ -110,6 +110,15 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
110
110
  # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ConversationalPipeline
111
111
  # Needs to convert to conversation object.
112
112
  if task == "conversational":
113
+ warnings.warn(
114
+ (
115
+ "Conversational pipeline is removed from transformers since 4.42.0. "
116
+ "Support will be removed from snowflake-ml-python soon."
117
+ ),
118
+ category=DeprecationWarning,
119
+ stacklevel=1,
120
+ )
121
+
113
122
  return core.ModelSignature(
114
123
  inputs=[
115
124
  core.FeatureSpec(name="user_inputs", dtype=core.DataType.STRING, shape=(-1,)),
@@ -70,7 +70,9 @@ class LLM:
70
70
 
71
71
  import peft
72
72
 
73
- peft_config = peft.PeftConfig.from_pretrained(model_id_or_path, **hub_kwargs) # type: ignore[attr-defined]
73
+ peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
74
+ model_id_or_path, **hub_kwargs
75
+ )
74
76
  if peft_config.peft_type != peft.PeftType.LORA: # type: ignore[attr-defined]
75
77
  raise ValueError("Only LORA is supported.")
76
78
  if peft_config.task_type != peft.TaskType.CAUSAL_LM: # type: ignore[attr-defined]
@@ -1,4 +1,5 @@
1
1
  # mypy: disable-error-code="import"
2
+ from enum import Enum
2
3
  from typing import (
3
4
  TYPE_CHECKING,
4
5
  Any,
@@ -232,7 +233,6 @@ class BaseModelSaveOption(TypedDict):
232
233
  _legacy_save: NotRequired[bool]
233
234
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
234
235
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
235
- include_pip_dependencies: NotRequired[bool]
236
236
  enable_explainability: NotRequired[bool]
237
237
 
238
238
 
@@ -431,3 +431,11 @@ class Deployment(TypedDict):
431
431
  signature: core.ModelSignature
432
432
  options: Required[DeployOptions]
433
433
  details: NotRequired[DeployDetails]
434
+
435
+
436
+ class ModelObjective(Enum):
437
+ UNKNOWN = "unknown"
438
+ BINARY_CLASSIFICATION = "binary_classification"
439
+ MULTI_CLASSIFICATION = "multi_classification"
440
+ REGRESSION = "regression"
441
+ RANKING = "ranking"
@@ -1 +1,2 @@
1
1
  IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME"
2
+ USE_OPTIMIZED_DATA_INGESTOR = "USE_OPTIMIZED_DATA_INGESTOR"
@@ -166,10 +166,10 @@ class PandasTransformHandlers:
166
166
  SnowflakeMLException: The input column list does not have one of `X` and `X_test`.
167
167
  """
168
168
  assert hasattr(self.estimator, "score") # make type checker happy
169
- argspec = inspect.getfullargspec(self.estimator.score)
170
- if "X" in argspec.args:
169
+ params = inspect.signature(self.estimator.score).parameters
170
+ if "X" in params:
171
171
  score_args = {"X": self.dataset[input_cols]}
172
- elif "X_test" in argspec.args:
172
+ elif "X_test" in params:
173
173
  score_args = {"X_test": self.dataset[input_cols]}
174
174
  else:
175
175
  raise exceptions.SnowflakeMLException(
@@ -178,10 +178,10 @@ class PandasTransformHandlers:
178
178
  )
179
179
 
180
180
  if len(label_cols) > 0:
181
- label_arg_name = "Y" if "Y" in argspec.args else "y"
181
+ label_arg_name = "Y" if "Y" in params else "y"
182
182
  score_args[label_arg_name] = self.dataset[label_cols].squeeze()
183
183
 
184
- if sample_weight_col is not None and "sample_weight" in argspec.args:
184
+ if sample_weight_col is not None and "sample_weight" in params:
185
185
  score_args["sample_weight"] = self.dataset[sample_weight_col].squeeze()
186
186
 
187
187
  score = self.estimator.score(**score_args)
@@ -43,14 +43,14 @@ class PandasModelTrainer:
43
43
  Trained model
44
44
  """
45
45
  assert hasattr(self.estimator, "fit") # Keep mypy happy
46
- argspec = inspect.getfullargspec(self.estimator.fit)
46
+ params = inspect.signature(self.estimator.fit).parameters
47
47
  args = {"X": self.dataset[self.input_cols]}
48
48
 
49
49
  if self.label_cols:
50
- label_arg_name = "Y" if "Y" in argspec.args else "y"
50
+ label_arg_name = "Y" if "Y" in params else "y"
51
51
  args[label_arg_name] = self.dataset[self.label_cols].squeeze()
52
52
 
53
- if self.sample_weight_col is not None and "sample_weight" in argspec.args:
53
+ if self.sample_weight_col is not None and "sample_weight" in params:
54
54
  args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
55
55
 
56
56
  return self.estimator.fit(**args)
@@ -59,6 +59,7 @@ class PandasModelTrainer:
59
59
  self,
60
60
  expected_output_cols_list: List[str],
61
61
  drop_input_cols: Optional[bool] = False,
62
+ example_output_pd_df: Optional[pd.DataFrame] = None,
62
63
  ) -> Tuple[pd.DataFrame, object]:
63
64
  """Trains the model using specified features and target columns from the dataset.
64
65
  This API is different from fit itself because it would also provide the predict
@@ -69,6 +70,8 @@ class PandasModelTrainer:
69
70
  name as a list. Defaults to None.
70
71
  drop_input_cols (Optional[bool]): Boolean to determine whether to
71
72
  drop the input columns from the output dataset.
73
+ example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
74
+ This is not used in PandasModelTrainer. It is used in SnowparkModelTrainer.
72
75
 
73
76
  Returns:
74
77
  Tuple[pd.DataFrame, object]: [predicted dataset, estimator]
@@ -108,13 +111,13 @@ class PandasModelTrainer:
108
111
  assert hasattr(self.estimator, "fit") # make type checker happy
109
112
  assert hasattr(self.estimator, "fit_transform") # make type checker happy
110
113
 
111
- argspec = inspect.getfullargspec(self.estimator.fit)
114
+ params = inspect.signature(self.estimator.fit).parameters
112
115
  args = {"X": self.dataset[self.input_cols]}
113
116
  if self.label_cols:
114
- label_arg_name = "Y" if "Y" in argspec.args else "y"
117
+ label_arg_name = "Y" if "Y" in params else "y"
115
118
  args[label_arg_name] = self.dataset[self.label_cols].squeeze()
116
119
 
117
- if self.sample_weight_col is not None and "sample_weight" in argspec.args:
120
+ if self.sample_weight_col is not None and "sample_weight" in params:
118
121
  args["sample_weight"] = self.dataset[self.sample_weight_col].squeeze()
119
122
 
120
123
  inference_res = self.estimator.fit_transform(**args)
@@ -53,11 +53,13 @@ class SKLearnModelSpecifications(ModelSpecifications):
53
53
 
54
54
  class XGBoostModelSpecifications(ModelSpecifications):
55
55
  def __init__(self) -> None:
56
+ import sklearn
56
57
  import xgboost
57
58
 
58
59
  imports: List[str] = ["xgboost"]
59
60
  pkgDependencies: List[str] = [
60
61
  f"numpy=={np.__version__}",
62
+ f"scikit-learn=={sklearn.__version__}",
61
63
  f"xgboost=={xgboost.__version__}",
62
64
  f"cloudpickle=={cp.__version__}",
63
65
  ]
@@ -20,6 +20,7 @@ class ModelTrainer(Protocol):
20
20
  self,
21
21
  expected_output_cols_list: List[str],
22
22
  drop_input_cols: Optional[bool] = False,
23
+ example_output_pd_df: Optional[pd.DataFrame] = None,
23
24
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
24
25
  raise NotImplementedError
25
26
 
@@ -495,7 +495,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
495
495
  label_arg_name = "Y" if "Y" in argspec.args else "y"
496
496
  args[label_arg_name] = df[label_cols].squeeze()
497
497
 
498
- if sample_weight_col is not None and "sample_weight" in argspec.args:
498
+ if sample_weight_col is not None:
499
499
  args["sample_weight"] = df[sample_weight_col].squeeze()
500
500
  return args, estimator, indices, len(df), params_to_evaluate
501
501
 
@@ -1061,7 +1061,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1061
1061
  if label_cols:
1062
1062
  label_arg_name = "Y" if "Y" in argspec.args else "y"
1063
1063
  args[label_arg_name] = y
1064
- if sample_weight_col is not None and "sample_weight" in argspec.args:
1064
+ if sample_weight_col is not None:
1065
1065
  args["sample_weight"] = df[sample_weight_col].squeeze()
1066
1066
  # estimator.refit = original_refit
1067
1067
  refit_start_time = time.time()
@@ -318,19 +318,19 @@ class SnowparkTransformHandlers:
318
318
  with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
319
319
  estimator = cp.load(local_score_file_obj)
320
320
 
321
- argspec = inspect.getfullargspec(estimator.score)
322
- if "X" in argspec.args:
321
+ params = inspect.signature(estimator.score).parameters
322
+ if "X" in params:
323
323
  args = {"X": df[input_cols]}
324
- elif "X_test" in argspec.args:
324
+ elif "X_test" in params:
325
325
  args = {"X_test": df[input_cols]}
326
326
  else:
327
327
  raise RuntimeError("Neither 'X' or 'X_test' exist in argument")
328
328
 
329
329
  if label_cols:
330
- label_arg_name = "Y" if "Y" in argspec.args else "y"
330
+ label_arg_name = "Y" if "Y" in params else "y"
331
331
  args[label_arg_name] = df[label_cols].squeeze()
332
332
 
333
- if sample_weight_col is not None and "sample_weight" in argspec.args:
333
+ if sample_weight_col is not None and "sample_weight" in params:
334
334
  args["sample_weight"] = df[sample_weight_col].squeeze()
335
335
 
336
336
  result: float = estimator.score(**args)
@@ -35,6 +35,7 @@ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
35
35
 
36
36
  _PROJECT = "ModelDevelopment"
37
37
  _ENABLE_ANONYMOUS_SPROC = False
38
+ _ENABLE_TRACER = True
38
39
 
39
40
 
40
41
  class SnowparkModelTrainer:
@@ -119,6 +120,8 @@ class SnowparkModelTrainer:
119
120
  A callable that can be registered as a stored procedure.
120
121
  """
121
122
  imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
123
+ method_name = "fit"
124
+ tracer_name = f"snowpark.ml.modeling.{self._class_name.lower()}.{method_name}"
122
125
 
123
126
  def fit_wrapper_function(
124
127
  session: Session,
@@ -138,110 +141,98 @@ class SnowparkModelTrainer:
138
141
  for import_name in imports:
139
142
  importlib.import_module(import_name)
140
143
 
141
- # Execute snowpark queries and obtain the results as pandas dataframe
142
- # NB: this implies that the result data must fit into memory.
143
- for query in sql_queries[:-1]:
144
- _ = session.sql(query).collect(statement_params=statement_params)
145
- sp_df = session.sql(sql_queries[-1])
146
- df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
147
- df.columns = sp_df.columns
144
+ def fit_and_return_estimator() -> str:
145
+ """This is a helper function within the sproc to download the data, fit the model, and upload the model.
146
+
147
+ Returns:
148
+ The name of the file in session's temp stage (temp_stage_name) that contains the serialized model.
149
+ """
150
+ # Execute snowpark queries and obtain the results as pandas dataframe
151
+ # NB: this implies that the result data must fit into memory.
152
+ for query in sql_queries[:-1]:
153
+ _ = session.sql(query).collect(statement_params=statement_params)
154
+ sp_df = session.sql(sql_queries[-1])
155
+ df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
156
+ df.columns = sp_df.columns
157
+
158
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
159
+
160
+ session.file.get(
161
+ stage_location=temp_stage_name,
162
+ target_directory=local_transform_file_name,
163
+ statement_params=statement_params,
164
+ )
148
165
 
149
- local_transform_file_name = temp_file_utils.get_temp_file_path()
166
+ local_transform_file_path = os.path.join(
167
+ local_transform_file_name, os.listdir(local_transform_file_name)[0]
168
+ )
169
+ with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
170
+ estimator = cp.load(local_transform_file_obj)
150
171
 
151
- session.file.get(
152
- stage_location=temp_stage_name,
153
- target_directory=local_transform_file_name,
154
- statement_params=statement_params,
155
- )
172
+ params = inspect.signature(estimator.fit).parameters
173
+ args = {"X": df[input_cols]}
174
+ if label_cols:
175
+ label_arg_name = "Y" if "Y" in params else "y"
176
+ args[label_arg_name] = df[label_cols].squeeze()
156
177
 
157
- local_transform_file_path = os.path.join(
158
- local_transform_file_name, os.listdir(local_transform_file_name)[0]
159
- )
160
- with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
161
- estimator = cp.load(local_transform_file_obj)
178
+ if sample_weight_col is not None and "sample_weight" in params:
179
+ args["sample_weight"] = df[sample_weight_col].squeeze()
162
180
 
163
- argspec = inspect.getfullargspec(estimator.fit)
164
- args = {"X": df[input_cols]}
165
- if label_cols:
166
- label_arg_name = "Y" if "Y" in argspec.args else "y"
167
- args[label_arg_name] = df[label_cols].squeeze()
181
+ estimator.fit(**args)
168
182
 
169
- if sample_weight_col is not None and "sample_weight" in argspec.args:
170
- args["sample_weight"] = df[sample_weight_col].squeeze()
183
+ local_result_file_name = temp_file_utils.get_temp_file_path()
171
184
 
172
- estimator.fit(**args)
185
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
186
+ cp.dump(estimator, local_result_file_obj)
173
187
 
174
- local_result_file_name = temp_file_utils.get_temp_file_path()
188
+ session.file.put(
189
+ local_file_name=local_result_file_name,
190
+ stage_location=temp_stage_name,
191
+ auto_compress=False,
192
+ overwrite=True,
193
+ statement_params=statement_params,
194
+ )
195
+ return local_result_file_name
175
196
 
176
- with open(local_result_file_name, mode="w+b") as local_result_file_obj:
177
- cp.dump(estimator, local_result_file_obj)
197
+ if _ENABLE_TRACER:
178
198
 
179
- session.file.put(
180
- local_file_name=local_result_file_name,
181
- stage_location=temp_stage_name,
182
- auto_compress=False,
183
- overwrite=True,
184
- statement_params=statement_params,
185
- )
199
+ # Use opentelemetry to trace the dist and span of the fit operation.
200
+ # This would allow user to see the trace in the Snowflake UI.
201
+ from opentelemetry import trace
186
202
 
187
- # Note: you can add something like + "|" + str(df) to the return string
188
- # to pass debug information to the caller.
189
- return str(os.path.basename(local_result_file_name))
203
+ tracer = trace.get_tracer(tracer_name)
204
+ with tracer.start_as_current_span("fit"):
205
+ local_result_file_name = fit_and_return_estimator()
206
+ # Note: you can add something like + "|" + str(df) to the return string
207
+ # to pass debug information to the caller.
208
+ return str(os.path.basename(local_result_file_name))
209
+ else:
210
+ local_result_file_name = fit_and_return_estimator()
211
+ return str(os.path.basename(local_result_file_name))
190
212
 
191
213
  return fit_wrapper_function
192
214
 
193
- def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
215
+ def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
194
216
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
195
- fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
196
-
197
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
198
- pkg_versions=model_spec.pkgDependencies, session=self.session
199
- )
200
-
201
- fit_wrapper_sproc = self.session.sproc.register(
202
- func=self._build_fit_wrapper_sproc(model_spec=model_spec),
203
- is_permanent=False,
204
- name=fit_sproc_name,
205
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
206
- replace=True,
207
- session=self.session,
208
- statement_params=statement_params,
209
- anonymous=True,
210
- execute_as="caller",
211
- )
212
-
213
- return fit_wrapper_sproc
214
-
215
- def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
216
- # If the sproc already exists, don't register.
217
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
218
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
219
-
220
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
221
- fit_sproc_key = model_spec.__class__.__name__
222
- if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
223
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
224
- return fit_sproc
225
217
 
226
218
  fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
227
219
 
228
220
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
229
221
  pkg_versions=model_spec.pkgDependencies, session=self.session
230
222
  )
223
+ packages = ["snowflake-snowpark-python", "snowflake-telemetry-python"] + relaxed_dependencies
231
224
 
232
225
  fit_wrapper_sproc = self.session.sproc.register(
233
226
  func=self._build_fit_wrapper_sproc(model_spec=model_spec),
234
227
  is_permanent=False,
235
228
  name=fit_sproc_name,
236
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
229
+ packages=packages, # type: ignore[arg-type]
237
230
  replace=True,
238
231
  session=self.session,
239
232
  statement_params=statement_params,
240
233
  execute_as="caller",
234
+ anonymous=anonymous,
241
235
  )
242
-
243
- self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
244
-
245
236
  return fit_wrapper_sproc
246
237
 
247
238
  def _build_fit_predict_wrapper_sproc(
@@ -333,7 +324,9 @@ class SnowparkModelTrainer:
333
324
 
334
325
  # write into a temp table in sproc and load the table from outside
335
326
  session.write_pandas(
336
- fit_predict_result_pd, fit_predict_result_name, auto_create_table=True, table_type="temp"
327
+ fit_predict_result_pd,
328
+ fit_predict_result_name,
329
+ overwrite=True,
337
330
  )
338
331
 
339
332
  # Note: you can add something like + "|" + str(df) to the return string
@@ -414,13 +407,13 @@ class SnowparkModelTrainer:
414
407
  with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
415
408
  estimator = cp.load(local_transform_file_obj)
416
409
 
417
- argspec = inspect.getfullargspec(estimator.fit)
410
+ params = inspect.signature(estimator.fit).parameters
418
411
  args = {"X": df[input_cols]}
419
412
  if label_cols:
420
- label_arg_name = "Y" if "Y" in argspec.args else "y"
413
+ label_arg_name = "Y" if "Y" in params else "y"
421
414
  args[label_arg_name] = df[label_cols].squeeze()
422
415
 
423
- if sample_weight_col is not None and "sample_weight" in argspec.args:
416
+ if sample_weight_col is not None and "sample_weight" in params:
424
417
  args["sample_weight"] = df[sample_weight_col].squeeze()
425
418
 
426
419
  fit_transform_result = estimator.fit_transform(**args)
@@ -477,7 +470,7 @@ class SnowparkModelTrainer:
477
470
 
478
471
  return fit_transform_wrapper_function
479
472
 
480
- def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
473
+ def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
481
474
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
482
475
 
483
476
  fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -494,82 +487,14 @@ class SnowparkModelTrainer:
494
487
  replace=True,
495
488
  session=self.session,
496
489
  statement_params=statement_params,
497
- anonymous=True,
490
+ anonymous=anonymous,
498
491
  execute_as="caller",
499
492
  )
500
493
 
501
494
  return fit_predict_wrapper_sproc
502
495
 
503
- def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
504
- # If the sproc already exists, don't register.
505
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
506
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
507
-
508
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
509
- fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
510
- if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
511
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
512
- fit_predict_sproc_key
513
- ]
514
- return fit_sproc
515
-
516
- fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
517
-
518
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
519
- pkg_versions=model_spec.pkgDependencies, session=self.session
520
- )
521
-
522
- fit_predict_wrapper_sproc = self.session.sproc.register(
523
- func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
524
- is_permanent=False,
525
- name=fit_predict_sproc_name,
526
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
527
- replace=True,
528
- session=self.session,
529
- statement_params=statement_params,
530
- execute_as="caller",
531
- )
532
-
533
- self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
534
- fit_predict_sproc_key
535
- ] = fit_predict_wrapper_sproc
536
-
537
- return fit_predict_wrapper_sproc
538
-
539
- def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
540
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
541
-
542
- fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
543
-
544
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
545
- pkg_versions=model_spec.pkgDependencies, session=self.session
546
- )
547
-
548
- fit_transform_wrapper_sproc = self.session.sproc.register(
549
- func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
550
- is_permanent=False,
551
- name=fit_transform_sproc_name,
552
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
553
- replace=True,
554
- session=self.session,
555
- statement_params=statement_params,
556
- anonymous=True,
557
- execute_as="caller",
558
- )
559
- return fit_transform_wrapper_sproc
560
-
561
- def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
562
- # If the sproc already exists, don't register.
563
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
564
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
565
-
496
+ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
566
497
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
567
- fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
568
- if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
569
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
570
- fit_transform_sproc_key
571
- ]
572
- return fit_sproc
573
498
 
574
499
  fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
575
500
 
@@ -586,12 +511,9 @@ class SnowparkModelTrainer:
586
511
  session=self.session,
587
512
  statement_params=statement_params,
588
513
  execute_as="caller",
514
+ anonymous=anonymous,
589
515
  )
590
516
 
591
- self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
592
- fit_transform_sproc_key
593
- ] = fit_transform_wrapper_sproc
594
-
595
517
  return fit_transform_wrapper_sproc
596
518
 
597
519
  def train(self) -> object:
@@ -629,9 +551,9 @@ class SnowparkModelTrainer:
629
551
  # Call fit sproc
630
552
 
631
553
  if _ENABLE_ANONYMOUS_SPROC:
632
- fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
554
+ fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=True)
633
555
  else:
634
- fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
556
+ fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=False)
635
557
 
636
558
  try:
637
559
  sproc_export_file_name: str = fit_wrapper_sproc(
@@ -665,6 +587,7 @@ class SnowparkModelTrainer:
665
587
  self,
666
588
  expected_output_cols_list: List[str],
667
589
  drop_input_cols: Optional[bool] = False,
590
+ example_output_pd_df: Optional[pd.DataFrame] = None,
668
591
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
669
592
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
670
593
  This API is different from fit itself because it would also provide the predict
@@ -675,6 +598,11 @@ class SnowparkModelTrainer:
675
598
  name as a list. Defaults to None.
676
599
  drop_input_cols (Optional[bool]): Boolean to determine drop
677
600
  the input columns from the output dataset or not
601
+ example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
602
+ This is to create a temp table in the client side with df_one_row. This can maintain the same column
603
+ name and data type as the output dataframe. Within the sproc, we don't need to create another temp table
604
+ again - instead, we overwrite into this table without changing the schema.
605
+ This is not used in PandasModelTrainer.
678
606
 
679
607
  Returns:
680
608
  Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator]
@@ -702,12 +630,35 @@ class SnowparkModelTrainer:
702
630
 
703
631
  # Call fit sproc
704
632
  if _ENABLE_ANONYMOUS_SPROC:
705
- fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
633
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
634
+ statement_params=statement_params, anonymous=True
635
+ )
706
636
  else:
707
- fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
637
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
638
+ statement_params=statement_params, anonymous=False
639
+ )
708
640
 
709
641
  fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
710
642
 
643
+ # Create a temp table in advance to store the output
644
+ # This would allow us to use the same table outside the stored procedure
645
+ if not drop_input_cols:
646
+ assert example_output_pd_df is not None
647
+ remove_dataset_col_name_exist_in_output_col = list(set(dataset.columns) - set(example_output_pd_df.columns))
648
+ pd_df_one_row = (
649
+ dataset.select(remove_dataset_col_name_exist_in_output_col)
650
+ .limit(1)
651
+ .to_pandas(statement_params=statement_params)
652
+ )
653
+ example_output_pd_df = pd.concat([pd_df_one_row, example_output_pd_df], axis=1)
654
+
655
+ self.session.write_pandas(
656
+ example_output_pd_df,
657
+ fit_predict_result_name,
658
+ auto_create_table=True,
659
+ table_type="temp",
660
+ )
661
+
711
662
  sproc_export_file_name: str = fit_predict_wrapper_sproc(
712
663
  self.session,
713
664
  queries,
@@ -769,11 +720,13 @@ class SnowparkModelTrainer:
769
720
 
770
721
  # Call fit sproc
771
722
  if _ENABLE_ANONYMOUS_SPROC:
772
- fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
773
- statement_params=statement_params
723
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
724
+ statement_params=statement_params, anonymous=True
774
725
  )
775
726
  else:
776
- fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
727
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
728
+ statement_params=statement_params, anonymous=False
729
+ )
777
730
 
778
731
  fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
779
732