snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ class ModelTrainer(Protocol):
20
20
  self,
21
21
  expected_output_cols_list: List[str],
22
22
  drop_input_cols: Optional[bool] = False,
23
+ example_output_pd_df: Optional[pd.DataFrame] = None,
23
24
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
24
25
  raise NotImplementedError
25
26
 
@@ -377,7 +377,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
377
377
  anonymous=True,
378
378
  imports=imports, # type: ignore[arg-type]
379
379
  statement_params=sproc_statement_params,
380
- execute_as="caller",
381
380
  )
382
381
  def _distributed_search(
383
382
  session: Session,
@@ -495,7 +494,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
495
494
  label_arg_name = "Y" if "Y" in argspec.args else "y"
496
495
  args[label_arg_name] = df[label_cols].squeeze()
497
496
 
498
- if sample_weight_col is not None and "sample_weight" in argspec.args:
497
+ if sample_weight_col is not None:
499
498
  args["sample_weight"] = df[sample_weight_col].squeeze()
500
499
  return args, estimator, indices, len(df), params_to_evaluate
501
500
 
@@ -783,7 +782,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
783
782
  anonymous=True,
784
783
  imports=imports, # type: ignore[arg-type]
785
784
  statement_params=sproc_statement_params,
786
- execute_as="caller",
787
785
  )
788
786
  def _distributed_search(
789
787
  session: Session,
@@ -1061,7 +1059,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1061
1059
  if label_cols:
1062
1060
  label_arg_name = "Y" if "Y" in argspec.args else "y"
1063
1061
  args[label_arg_name] = y
1064
- if sample_weight_col is not None and "sample_weight" in argspec.args:
1062
+ if sample_weight_col is not None:
1065
1063
  args["sample_weight"] = df[sample_weight_col].squeeze()
1066
1064
  # estimator.refit = original_refit
1067
1065
  refit_start_time = time.time()
@@ -318,19 +318,19 @@ class SnowparkTransformHandlers:
318
318
  with open(local_score_file_name_path, mode="r+b") as local_score_file_obj:
319
319
  estimator = cp.load(local_score_file_obj)
320
320
 
321
- argspec = inspect.getfullargspec(estimator.score)
322
- if "X" in argspec.args:
321
+ params = inspect.signature(estimator.score).parameters
322
+ if "X" in params:
323
323
  args = {"X": df[input_cols]}
324
- elif "X_test" in argspec.args:
324
+ elif "X_test" in params:
325
325
  args = {"X_test": df[input_cols]}
326
326
  else:
327
327
  raise RuntimeError("Neither 'X' or 'X_test' exist in argument")
328
328
 
329
329
  if label_cols:
330
- label_arg_name = "Y" if "Y" in argspec.args else "y"
330
+ label_arg_name = "Y" if "Y" in params else "y"
331
331
  args[label_arg_name] = df[label_cols].squeeze()
332
332
 
333
- if sample_weight_col is not None and "sample_weight" in argspec.args:
333
+ if sample_weight_col is not None and "sample_weight" in params:
334
334
  args["sample_weight"] = df[sample_weight_col].squeeze()
335
335
 
336
336
  result: float = estimator.score(**args)
@@ -35,6 +35,7 @@ cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
35
35
 
36
36
  _PROJECT = "ModelDevelopment"
37
37
  _ENABLE_ANONYMOUS_SPROC = False
38
+ _ENABLE_TRACER = True
38
39
 
39
40
 
40
41
  class SnowparkModelTrainer:
@@ -119,6 +120,8 @@ class SnowparkModelTrainer:
119
120
  A callable that can be registered as a stored procedure.
120
121
  """
121
122
  imports = model_spec.imports # In order for the sproc to not resolve this reference in snowflake.ml
123
+ method_name = "fit"
124
+ tracer_name = f"snowpark.ml.modeling.{self._class_name.lower()}.{method_name}"
122
125
 
123
126
  def fit_wrapper_function(
124
127
  session: Session,
@@ -138,110 +141,97 @@ class SnowparkModelTrainer:
138
141
  for import_name in imports:
139
142
  importlib.import_module(import_name)
140
143
 
141
- # Execute snowpark queries and obtain the results as pandas dataframe
142
- # NB: this implies that the result data must fit into memory.
143
- for query in sql_queries[:-1]:
144
- _ = session.sql(query).collect(statement_params=statement_params)
145
- sp_df = session.sql(sql_queries[-1])
146
- df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
147
- df.columns = sp_df.columns
144
+ def fit_and_return_estimator() -> str:
145
+ """This is a helper function within the sproc to download the data, fit the model, and upload the model.
146
+
147
+ Returns:
148
+ The name of the file in session's temp stage (temp_stage_name) that contains the serialized model.
149
+ """
150
+ # Execute snowpark queries and obtain the results as pandas dataframe
151
+ # NB: this implies that the result data must fit into memory.
152
+ for query in sql_queries[:-1]:
153
+ _ = session.sql(query).collect(statement_params=statement_params)
154
+ sp_df = session.sql(sql_queries[-1])
155
+ df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
156
+ df.columns = sp_df.columns
157
+
158
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
159
+
160
+ session.file.get(
161
+ stage_location=temp_stage_name,
162
+ target_directory=local_transform_file_name,
163
+ statement_params=statement_params,
164
+ )
148
165
 
149
- local_transform_file_name = temp_file_utils.get_temp_file_path()
166
+ local_transform_file_path = os.path.join(
167
+ local_transform_file_name, os.listdir(local_transform_file_name)[0]
168
+ )
169
+ with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
170
+ estimator = cp.load(local_transform_file_obj)
150
171
 
151
- session.file.get(
152
- stage_location=temp_stage_name,
153
- target_directory=local_transform_file_name,
154
- statement_params=statement_params,
155
- )
172
+ params = inspect.signature(estimator.fit).parameters
173
+ args = {"X": df[input_cols]}
174
+ if label_cols:
175
+ label_arg_name = "Y" if "Y" in params else "y"
176
+ args[label_arg_name] = df[label_cols].squeeze()
156
177
 
157
- local_transform_file_path = os.path.join(
158
- local_transform_file_name, os.listdir(local_transform_file_name)[0]
159
- )
160
- with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
161
- estimator = cp.load(local_transform_file_obj)
178
+ if sample_weight_col is not None and "sample_weight" in params:
179
+ args["sample_weight"] = df[sample_weight_col].squeeze()
162
180
 
163
- argspec = inspect.getfullargspec(estimator.fit)
164
- args = {"X": df[input_cols]}
165
- if label_cols:
166
- label_arg_name = "Y" if "Y" in argspec.args else "y"
167
- args[label_arg_name] = df[label_cols].squeeze()
181
+ estimator.fit(**args)
168
182
 
169
- if sample_weight_col is not None and "sample_weight" in argspec.args:
170
- args["sample_weight"] = df[sample_weight_col].squeeze()
183
+ local_result_file_name = temp_file_utils.get_temp_file_path()
171
184
 
172
- estimator.fit(**args)
185
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
186
+ cp.dump(estimator, local_result_file_obj)
173
187
 
174
- local_result_file_name = temp_file_utils.get_temp_file_path()
188
+ session.file.put(
189
+ local_file_name=local_result_file_name,
190
+ stage_location=temp_stage_name,
191
+ auto_compress=False,
192
+ overwrite=True,
193
+ statement_params=statement_params,
194
+ )
195
+ return local_result_file_name
175
196
 
176
- with open(local_result_file_name, mode="w+b") as local_result_file_obj:
177
- cp.dump(estimator, local_result_file_obj)
197
+ if _ENABLE_TRACER:
178
198
 
179
- session.file.put(
180
- local_file_name=local_result_file_name,
181
- stage_location=temp_stage_name,
182
- auto_compress=False,
183
- overwrite=True,
184
- statement_params=statement_params,
185
- )
199
+ # Use opentelemetry to trace the dist and span of the fit operation.
200
+ # This would allow user to see the trace in the Snowflake UI.
201
+ from opentelemetry import trace
186
202
 
187
- # Note: you can add something like + "|" + str(df) to the return string
188
- # to pass debug information to the caller.
189
- return str(os.path.basename(local_result_file_name))
203
+ tracer = trace.get_tracer(tracer_name)
204
+ with tracer.start_as_current_span("fit"):
205
+ local_result_file_name = fit_and_return_estimator()
206
+ # Note: you can add something like + "|" + str(df) to the return string
207
+ # to pass debug information to the caller.
208
+ return str(os.path.basename(local_result_file_name))
209
+ else:
210
+ local_result_file_name = fit_and_return_estimator()
211
+ return str(os.path.basename(local_result_file_name))
190
212
 
191
213
  return fit_wrapper_function
192
214
 
193
- def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
215
+ def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
194
216
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
195
- fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
196
-
197
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
198
- pkg_versions=model_spec.pkgDependencies, session=self.session
199
- )
200
-
201
- fit_wrapper_sproc = self.session.sproc.register(
202
- func=self._build_fit_wrapper_sproc(model_spec=model_spec),
203
- is_permanent=False,
204
- name=fit_sproc_name,
205
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
206
- replace=True,
207
- session=self.session,
208
- statement_params=statement_params,
209
- anonymous=True,
210
- execute_as="caller",
211
- )
212
-
213
- return fit_wrapper_sproc
214
-
215
- def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
216
- # If the sproc already exists, don't register.
217
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
218
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
219
-
220
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
221
- fit_sproc_key = model_spec.__class__.__name__
222
- if fit_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
223
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
224
- return fit_sproc
225
217
 
226
218
  fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
227
219
 
228
220
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
229
221
  pkg_versions=model_spec.pkgDependencies, session=self.session
230
222
  )
223
+ packages = ["snowflake-snowpark-python", "snowflake-telemetry-python"] + relaxed_dependencies
231
224
 
232
225
  fit_wrapper_sproc = self.session.sproc.register(
233
226
  func=self._build_fit_wrapper_sproc(model_spec=model_spec),
234
227
  is_permanent=False,
235
228
  name=fit_sproc_name,
236
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
229
+ packages=packages, # type: ignore[arg-type]
237
230
  replace=True,
238
231
  session=self.session,
239
232
  statement_params=statement_params,
240
- execute_as="caller",
233
+ anonymous=anonymous,
241
234
  )
242
-
243
- self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc # type: ignore[attr-defined]
244
-
245
235
  return fit_wrapper_sproc
246
236
 
247
237
  def _build_fit_predict_wrapper_sproc(
@@ -333,7 +323,9 @@ class SnowparkModelTrainer:
333
323
 
334
324
  # write into a temp table in sproc and load the table from outside
335
325
  session.write_pandas(
336
- fit_predict_result_pd, fit_predict_result_name, auto_create_table=True, table_type="temp"
326
+ fit_predict_result_pd,
327
+ fit_predict_result_name,
328
+ overwrite=True,
337
329
  )
338
330
 
339
331
  # Note: you can add something like + "|" + str(df) to the return string
@@ -414,13 +406,13 @@ class SnowparkModelTrainer:
414
406
  with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
415
407
  estimator = cp.load(local_transform_file_obj)
416
408
 
417
- argspec = inspect.getfullargspec(estimator.fit)
409
+ params = inspect.signature(estimator.fit).parameters
418
410
  args = {"X": df[input_cols]}
419
411
  if label_cols:
420
- label_arg_name = "Y" if "Y" in argspec.args else "y"
412
+ label_arg_name = "Y" if "Y" in params else "y"
421
413
  args[label_arg_name] = df[label_cols].squeeze()
422
414
 
423
- if sample_weight_col is not None and "sample_weight" in argspec.args:
415
+ if sample_weight_col is not None and "sample_weight" in params:
424
416
  args["sample_weight"] = df[sample_weight_col].squeeze()
425
417
 
426
418
  fit_transform_result = estimator.fit_transform(**args)
@@ -468,16 +460,14 @@ class SnowparkModelTrainer:
468
460
  session.write_pandas(
469
461
  transformed_pandas_df,
470
462
  fit_transform_result_name,
471
- auto_create_table=True,
472
- table_type="temp",
473
- quote_identifiers=False,
463
+ overwrite=True,
474
464
  )
475
465
 
476
466
  return str(os.path.basename(local_result_file_name))
477
467
 
478
468
  return fit_transform_wrapper_function
479
469
 
480
- def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
470
+ def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
481
471
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
482
472
 
483
473
  fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -494,49 +484,12 @@ class SnowparkModelTrainer:
494
484
  replace=True,
495
485
  session=self.session,
496
486
  statement_params=statement_params,
497
- anonymous=True,
498
- execute_as="caller",
487
+ anonymous=anonymous,
499
488
  )
500
489
 
501
490
  return fit_predict_wrapper_sproc
502
491
 
503
- def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
504
- # If the sproc already exists, don't register.
505
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
506
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
507
-
508
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
509
- fit_predict_sproc_key = model_spec.__class__.__name__ + "_fit_predict"
510
- if fit_predict_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
511
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
512
- fit_predict_sproc_key
513
- ]
514
- return fit_sproc
515
-
516
- fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
517
-
518
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
519
- pkg_versions=model_spec.pkgDependencies, session=self.session
520
- )
521
-
522
- fit_predict_wrapper_sproc = self.session.sproc.register(
523
- func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
524
- is_permanent=False,
525
- name=fit_predict_sproc_name,
526
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
527
- replace=True,
528
- session=self.session,
529
- statement_params=statement_params,
530
- execute_as="caller",
531
- )
532
-
533
- self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
534
- fit_predict_sproc_key
535
- ] = fit_predict_wrapper_sproc
536
-
537
- return fit_predict_wrapper_sproc
538
-
539
- def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
492
+ def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
540
493
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
541
494
 
542
495
  fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -553,44 +506,8 @@ class SnowparkModelTrainer:
553
506
  replace=True,
554
507
  session=self.session,
555
508
  statement_params=statement_params,
556
- anonymous=True,
557
- execute_as="caller",
509
+ anonymous=anonymous,
558
510
  )
559
- return fit_transform_wrapper_sproc
560
-
561
- def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
562
- # If the sproc already exists, don't register.
563
- if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
564
- self.session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {} # type: ignore[attr-defined, misc]
565
-
566
- model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
567
- fit_transform_sproc_key = model_spec.__class__.__name__ + "_fit_transform"
568
- if fit_transform_sproc_key in self.session._FIT_WRAPPER_SPROCS: # type: ignore[attr-defined]
569
- fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
570
- fit_transform_sproc_key
571
- ]
572
- return fit_sproc
573
-
574
- fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
575
-
576
- relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
577
- pkg_versions=model_spec.pkgDependencies, session=self.session
578
- )
579
-
580
- fit_transform_wrapper_sproc = self.session.sproc.register(
581
- func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
582
- is_permanent=False,
583
- name=fit_transform_sproc_name,
584
- packages=["snowflake-snowpark-python"] + relaxed_dependencies, # type: ignore[arg-type]
585
- replace=True,
586
- session=self.session,
587
- statement_params=statement_params,
588
- execute_as="caller",
589
- )
590
-
591
- self.session._FIT_WRAPPER_SPROCS[ # type: ignore[attr-defined]
592
- fit_transform_sproc_key
593
- ] = fit_transform_wrapper_sproc
594
511
 
595
512
  return fit_transform_wrapper_sproc
596
513
 
@@ -629,9 +546,9 @@ class SnowparkModelTrainer:
629
546
  # Call fit sproc
630
547
 
631
548
  if _ENABLE_ANONYMOUS_SPROC:
632
- fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
549
+ fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=True)
633
550
  else:
634
- fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
551
+ fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params, anonymous=False)
635
552
 
636
553
  try:
637
554
  sproc_export_file_name: str = fit_wrapper_sproc(
@@ -665,6 +582,7 @@ class SnowparkModelTrainer:
665
582
  self,
666
583
  expected_output_cols_list: List[str],
667
584
  drop_input_cols: Optional[bool] = False,
585
+ example_output_pd_df: Optional[pd.DataFrame] = None,
668
586
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
669
587
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
670
588
  This API is different from fit itself because it would also provide the predict
@@ -675,6 +593,11 @@ class SnowparkModelTrainer:
675
593
  name as a list. Defaults to None.
676
594
  drop_input_cols (Optional[bool]): Boolean to determine drop
677
595
  the input columns from the output dataset or not
596
+ example_output_pd_df (Optional[pd.DataFrame]): Example output dataframe
597
+ This is to create a temp table in the client side with df_one_row. This can maintain the same column
598
+ name and data type as the output dataframe. Within the sproc, we don't need to create another temp table
599
+ again - instead, we overwrite into this table without changing the schema.
600
+ This is not used in PandasModelTrainer.
678
601
 
679
602
  Returns:
680
603
  Tuple[Union[DataFrame, pd.DataFrame], object]: [predicted dataset, estimator]
@@ -702,12 +625,35 @@ class SnowparkModelTrainer:
702
625
 
703
626
  # Call fit sproc
704
627
  if _ENABLE_ANONYMOUS_SPROC:
705
- fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
628
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
629
+ statement_params=statement_params, anonymous=True
630
+ )
706
631
  else:
707
- fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
632
+ fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(
633
+ statement_params=statement_params, anonymous=False
634
+ )
708
635
 
709
636
  fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
710
637
 
638
+ # Create a temp table in advance to store the output
639
+ # This would allow us to use the same table outside the stored procedure
640
+ if not drop_input_cols:
641
+ assert example_output_pd_df is not None
642
+ remove_dataset_col_name_exist_in_output_col = list(set(dataset.columns) - set(example_output_pd_df.columns))
643
+ pd_df_one_row = (
644
+ dataset.select(remove_dataset_col_name_exist_in_output_col)
645
+ .limit(1)
646
+ .to_pandas(statement_params=statement_params)
647
+ )
648
+ example_output_pd_df = pd.concat([pd_df_one_row, example_output_pd_df], axis=1)
649
+
650
+ self.session.write_pandas(
651
+ example_output_pd_df,
652
+ fit_predict_result_name,
653
+ auto_create_table=True,
654
+ table_type="temp",
655
+ )
656
+
711
657
  sproc_export_file_name: str = fit_predict_wrapper_sproc(
712
658
  self.session,
713
659
  queries,
@@ -769,14 +715,32 @@ class SnowparkModelTrainer:
769
715
 
770
716
  # Call fit sproc
771
717
  if _ENABLE_ANONYMOUS_SPROC:
772
- fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
773
- statement_params=statement_params
718
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
719
+ statement_params=statement_params, anonymous=True
774
720
  )
775
721
  else:
776
- fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
722
+ fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(
723
+ statement_params=statement_params, anonymous=False
724
+ )
777
725
 
778
726
  fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
779
727
 
728
+ # Create a temp table in advance to store the output
729
+ # This would allow us to use the same table outside the stored procedure
730
+ df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params)
731
+ df_one_line[
732
+ expected_output_cols_list[0]
733
+ ] = "[0]" # Add one column as the output_col; this is a dummy value to represent the OBJECT type
734
+ if drop_input_cols:
735
+ self.session.write_pandas(
736
+ df_one_line[expected_output_cols_list[0]],
737
+ fit_transform_result_name,
738
+ auto_create_table=True,
739
+ table_type="temp",
740
+ )
741
+ else:
742
+ self.session.write_pandas(df_one_line, fit_transform_result_name, auto_create_table=True, table_type="temp")
743
+
780
744
  sproc_export_file_name: str = fit_transform_wrapper_sproc(
781
745
  self.session,
782
746
  queries,
@@ -303,7 +303,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
303
303
  statement_params=statement_params,
304
304
  anonymous=True,
305
305
  imports=list(import_file_paths),
306
- execute_as="caller",
307
306
  ) # type: ignore[misc]
308
307
  def fit_wrapper_sproc(
309
308
  session: Session,