snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/utils/db_utils.py +50 -0
  13. snowflake/ml/_internal/utils/service_logger.py +63 -0
  14. snowflake/ml/_internal/utils/sql_identifier.py +25 -1
  15. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  16. snowflake/ml/data/ingestor_utils.py +20 -10
  17. snowflake/ml/feature_store/access_manager.py +3 -3
  18. snowflake/ml/feature_store/feature_store.py +19 -2
  19. snowflake/ml/feature_store/feature_view.py +82 -28
  20. snowflake/ml/fileset/stage_fs.py +2 -1
  21. snowflake/ml/lineage/lineage_node.py +7 -2
  22. snowflake/ml/model/__init__.py +1 -2
  23. snowflake/ml/model/_client/model/model_version_impl.py +78 -9
  24. snowflake/ml/model/_client/ops/model_ops.py +89 -7
  25. snowflake/ml/model/_client/ops/service_ops.py +200 -91
  26. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
  27. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  28. snowflake/ml/model/_client/sql/_base.py +5 -0
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +9 -5
  31. snowflake/ml/model/_client/sql/service.py +47 -13
  32. snowflake/ml/model/_model_composer/model_composer.py +11 -41
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
  34. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
  39. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  40. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
  41. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  42. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
  43. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
  44. snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
  45. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
  46. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
  47. snowflake/ml/model/_packager/model_packager.py +14 -10
  48. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  49. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  50. snowflake/ml/model/type_hints.py +11 -152
  51. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  53. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  54. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
  55. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
  56. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
  57. snowflake/ml/modeling/cluster/birch.py +1 -0
  58. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
  59. snowflake/ml/modeling/cluster/dbscan.py +1 -0
  60. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
  61. snowflake/ml/modeling/cluster/k_means.py +1 -0
  62. snowflake/ml/modeling/cluster/mean_shift.py +1 -0
  63. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
  64. snowflake/ml/modeling/cluster/optics.py +1 -0
  65. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
  66. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
  67. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
  68. snowflake/ml/modeling/compose/column_transformer.py +1 -0
  69. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
  70. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
  71. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
  72. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
  73. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
  74. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
  75. snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
  76. snowflake/ml/modeling/covariance/oas.py +1 -0
  77. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
  78. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
  79. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
  80. snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
  81. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
  85. snowflake/ml/modeling/decomposition/pca.py +1 -0
  86. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
  87. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
  88. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
  89. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
  90. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
  91. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
  92. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
  93. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
  94. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
  95. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
  96. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
  97. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
  99. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
  100. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
  101. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
  102. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
  103. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
  104. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
  105. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
  106. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
  107. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
  108. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
  109. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
  110. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
  111. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
  112. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
  113. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
  116. snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
  117. snowflake/ml/modeling/impute/knn_imputer.py +1 -0
  118. snowflake/ml/modeling/impute/missing_indicator.py +1 -0
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
  127. snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
  129. snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
  133. snowflake/ml/modeling/linear_model/lars.py +1 -0
  134. snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
  135. snowflake/ml/modeling/linear_model/lasso.py +1 -0
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
  140. snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
  150. snowflake/ml/modeling/linear_model/perceptron.py +1 -0
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
  153. snowflake/ml/modeling/linear_model/ridge.py +1 -0
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
  162. snowflake/ml/modeling/manifold/isomap.py +1 -0
  163. snowflake/ml/modeling/manifold/mds.py +1 -0
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
  165. snowflake/ml/modeling/manifold/tsne.py +1 -0
  166. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  167. snowflake/ml/modeling/metrics/ranking.py +0 -3
  168. snowflake/ml/modeling/metrics/regression.py +0 -3
  169. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
  170. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
  181. snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
  191. snowflake/ml/modeling/pipeline/pipeline.py +0 -1
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
  195. snowflake/ml/modeling/svm/linear_svc.py +1 -0
  196. snowflake/ml/modeling/svm/linear_svr.py +1 -0
  197. snowflake/ml/modeling/svm/nu_svc.py +1 -0
  198. snowflake/ml/modeling/svm/nu_svr.py +1 -0
  199. snowflake/ml/modeling/svm/svc.py +1 -0
  200. snowflake/ml/modeling/svm/svr.py +1 -0
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
  209. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  210. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  211. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  212. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  213. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  214. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  215. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  216. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  217. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  218. snowflake/ml/registry/_manager/model_manager.py +4 -4
  219. snowflake/ml/registry/registry.py +165 -6
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
  222. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
  223. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  224. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  225. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  226. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  227. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  228. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  229. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  230. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  231. snowflake/ml/_internal/utils/uri.py +0 -77
  232. snowflake/ml/model/_api.py +0 -568
  233. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  234. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  235. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  236. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  237. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  238. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  239. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  240. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  241. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  242. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  243. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  244. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  245. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  246. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  247. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  248. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  249. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  250. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  251. snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
  252. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  253. snowflake/ml/model/deploy_platforms.py +0 -6
  254. snowflake/ml/model/models/llm.py +0 -106
  255. snowflake/ml/monitoring/monitor.py +0 -203
  256. snowflake/ml/registry/_initial_schema.py +0 -142
  257. snowflake/ml/registry/_schema.py +0 -82
  258. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  259. snowflake/ml/registry/_schema_version_manager.py +0 -163
  260. snowflake/ml/registry/model_registry.py +0 -2048
  261. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  262. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@ import os
2
2
  import pathlib
3
3
  import sys
4
4
  import tempfile
5
- import warnings
6
5
  import zipfile
7
6
  from contextlib import contextmanager
8
7
  from datetime import datetime
@@ -18,7 +17,6 @@ from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
18
17
  from snowflake.ml.model import model_signature, type_hints as model_types
19
18
  from snowflake.ml.model._packager.model_env import model_env
20
19
  from snowflake.ml.model._packager.model_meta import (
21
- _core_requirements,
22
20
  _packaging_requirements,
23
21
  model_blob_meta,
24
22
  model_meta_schema,
@@ -29,14 +27,10 @@ from snowflake.ml.model._packager.model_runtime import model_runtime
29
27
  MODEL_METADATA_FILE = "model.yaml"
30
28
  MODEL_CODE_DIR = "code"
31
29
 
32
- _PACKAGING_CORE_DEPENDENCIES = [
33
- str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r)))
34
- for r in _core_requirements.REQUIREMENTS
35
- ] # Legacy Model only
36
30
  _PACKAGING_REQUIREMENTS = [
37
31
  str(env_utils.get_package_spec_with_supported_ops_only(requirements.Requirement(r)))
38
32
  for r in _packaging_requirements.REQUIREMENTS
39
- ] # New Model only
33
+ ]
40
34
  _SNOWFLAKE_PKG_NAME = "snowflake"
41
35
  _SNOWFLAKE_ML_PKG_NAME = f"{_SNOWFLAKE_PKG_NAME}.ml"
42
36
 
@@ -55,7 +49,7 @@ def create_model_metadata(
55
49
  conda_dependencies: Optional[List[str]] = None,
56
50
  pip_requirements: Optional[List[str]] = None,
57
51
  python_version: Optional[str] = None,
58
- model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
52
+ task: model_types.Task = model_types.Task.UNKNOWN,
59
53
  **kwargs: Any,
60
54
  ) -> Generator["ModelMetadata", None, None]:
61
55
  """Create a generator for model metadata object. Use generator to ensure correct register and unregister for
@@ -75,9 +69,9 @@ def create_model_metadata(
75
69
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
76
70
  python_version: A string of python version where model is run. Used for user override. If specified as None,
77
71
  current version would be captured. Defaults to None.
78
- model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION,
79
- BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to
80
- ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object.
72
+ task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
73
+ TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default,
74
+ it is set to Task.UNKNOWN and may be overridden by inferring from the Model Object.
81
75
  **kwargs: Dict of attributes and values of the metadata. Used when loading from file.
82
76
 
83
77
  Raises:
@@ -88,18 +82,6 @@ def create_model_metadata(
88
82
  """
89
83
  model_dir_path = os.path.normpath(model_dir_path)
90
84
  embed_local_ml_library = kwargs.pop("embed_local_ml_library", False)
91
- legacy_save = kwargs.pop("_legacy_save", False)
92
- if "relax_version" not in kwargs:
93
- warnings.warn(
94
- (
95
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed "
96
- "from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
97
- "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
98
- ),
99
- category=UserWarning,
100
- stacklevel=2,
101
- )
102
- relax_version = kwargs.pop("relax_version", True)
103
85
 
104
86
  if embed_local_ml_library:
105
87
  # Use the last one which is loaded first, that is mean, it is loaded from site-packages.
@@ -122,7 +104,6 @@ def create_model_metadata(
122
104
  pip_requirements=pip_requirements,
123
105
  python_version=python_version,
124
106
  embed_local_ml_library=embed_local_ml_library,
125
- legacy_save=legacy_save,
126
107
  )
127
108
 
128
109
  if embed_local_ml_library:
@@ -135,18 +116,13 @@ def create_model_metadata(
135
116
  model_type=model_type,
136
117
  signatures=signatures,
137
118
  function_properties=function_properties,
138
- model_objective=model_objective,
119
+ task=task,
139
120
  )
140
121
 
141
122
  code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
142
- if (embed_local_ml_library and legacy_save) or code_paths:
123
+ if code_paths:
143
124
  os.makedirs(code_dir_path, exist_ok=True)
144
125
 
145
- if embed_local_ml_library and legacy_save:
146
- snowml_path_in_code = os.path.join(code_dir_path, _SNOWFLAKE_PKG_NAME)
147
- os.makedirs(snowml_path_in_code, exist_ok=True)
148
- file_utils.copy_file_or_tree(path_to_copy, snowml_path_in_code)
149
-
150
126
  if code_paths:
151
127
  for code_path in code_paths:
152
128
  # This part is to prevent users from providing code following our naming and overwrite our code.
@@ -165,8 +141,6 @@ def create_model_metadata(
165
141
  cloudpickle.register_pickle_by_value(mod)
166
142
  imported_modules.append(mod)
167
143
  yield model_meta
168
- if relax_version:
169
- model_meta.env.relax_version()
170
144
  model_meta.save(model_dir_path)
171
145
  finally:
172
146
  for mod in imported_modules:
@@ -179,7 +153,6 @@ def _create_env_for_model_metadata(
179
153
  pip_requirements: Optional[List[str]] = None,
180
154
  python_version: Optional[str] = None,
181
155
  embed_local_ml_library: bool = False,
182
- legacy_save: bool = False,
183
156
  ) -> model_env.ModelEnv:
184
157
  env = model_env.ModelEnv()
185
158
 
@@ -189,7 +162,7 @@ def _create_env_for_model_metadata(
189
162
  env.python_version = python_version # type: ignore[assignment]
190
163
  env.snowpark_ml_version = snowml_env.VERSION
191
164
 
192
- requirements_to_add = _PACKAGING_CORE_DEPENDENCIES if legacy_save else _PACKAGING_REQUIREMENTS
165
+ requirements_to_add = _PACKAGING_REQUIREMENTS
193
166
 
194
167
  if embed_local_ml_library:
195
168
  env.include_if_absent(
@@ -242,7 +215,7 @@ class ModelMetadata:
242
215
  function_properties: A dict mapping function names to dict mapping function property key to value.
243
216
  metadata: User provided key-value metadata of the model. Defaults to None.
244
217
  creation_timestamp: Unix timestamp when the model metadata is created.
245
- model_objective: Model objective like regression, classification etc.
218
+ task: Model task like TABULAR_REGRESSION, tabular_classification, timeseries_forecasting etc.
246
219
  """
247
220
 
248
221
  def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
@@ -266,7 +239,7 @@ class ModelMetadata:
266
239
  min_snowpark_ml_version: Optional[str] = None,
267
240
  models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
268
241
  original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
269
- model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
242
+ task: model_types.Task = model_types.Task.UNKNOWN,
270
243
  explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
271
244
  ) -> None:
272
245
  self.name = name
@@ -292,7 +265,7 @@ class ModelMetadata:
292
265
 
293
266
  self.original_metadata_version = original_metadata_version
294
267
 
295
- self.model_objective: model_types.ModelObjective = model_objective
268
+ self.task: model_types.Task = task
296
269
  self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
297
270
 
298
271
  @property
@@ -309,10 +282,10 @@ class ModelMetadata:
309
282
  if self._runtimes and "cpu" in self._runtimes:
310
283
  return self._runtimes
311
284
  runtimes = {
312
- "cpu": model_runtime.ModelRuntime("cpu", self.env),
285
+ "cpu": model_runtime.ModelRuntime("cpu", self.env, is_warehouse=False),
313
286
  }
314
287
  if self.env.cuda_version:
315
- runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)})
288
+ runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_warehouse=False, is_gpu=True)})
316
289
  return runtimes
317
290
 
318
291
  def save(self, model_dir_path: str) -> None:
@@ -346,7 +319,7 @@ class ModelMetadata:
346
319
  "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
347
320
  "version": model_meta_schema.MODEL_METADATA_VERSION,
348
321
  "min_snowpark_ml_version": self.min_snowpark_ml_version,
349
- "model_objective": self.model_objective.value,
322
+ "task": self.task.value,
350
323
  "explainability": (
351
324
  model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
352
325
  if self.explain_algorithm
@@ -390,7 +363,7 @@ class ModelMetadata:
390
363
  signatures=loaded_meta["signatures"],
391
364
  version=original_loaded_meta_version,
392
365
  min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
393
- model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value),
366
+ task=loaded_meta.get("task", model_types.Task.UNKNOWN.value),
394
367
  explainability=loaded_meta.get("explainability", None),
395
368
  function_properties=loaded_meta.get("function_properties", {}),
396
369
  )
@@ -445,9 +418,7 @@ class ModelMetadata:
445
418
  min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
446
419
  models=models,
447
420
  original_metadata_version=model_dict["version"],
448
- model_objective=model_types.ModelObjective(
449
- model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value)
450
- ),
421
+ task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)),
451
422
  explain_algorithm=explanation_algorithm,
452
423
  function_properties=model_dict.get("function_properties", {}),
453
424
  )
@@ -50,10 +50,6 @@ class LightGBMModelBlobOptions(BaseModelBlobOptions):
50
50
  lightgbm_estimator_type: Required[str]
51
51
 
52
52
 
53
- class LLMModelBlobOptions(BaseModelBlobOptions):
54
- batch_size: Required[int]
55
-
56
-
57
53
  class MLFlowModelBlobOptions(BaseModelBlobOptions):
58
54
  artifact_path: Required[str]
59
55
 
@@ -65,7 +61,6 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
65
61
  ModelBlobOptions = Union[
66
62
  BaseModelBlobOptions,
67
63
  HuggingFacePipelineModelBlobOptions,
68
- LLMModelBlobOptions,
69
64
  MLFlowModelBlobOptions,
70
65
  XgboostModelBlobOptions,
71
66
  ]
@@ -96,7 +91,7 @@ class ModelMetadataDict(TypedDict):
96
91
  signatures: Required[Dict[str, Dict[str, Any]]]
97
92
  version: Required[str]
98
93
  min_snowpark_ml_version: Required[str]
99
- model_objective: Required[str]
94
+ task: Required[str]
100
95
  explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
101
96
  function_properties: NotRequired[Dict[str, Dict[str, Any]]]
102
97
 
@@ -47,8 +47,9 @@ class ModelPackager:
47
47
  ext_modules: Optional[List[ModuleType]] = None,
48
48
  code_paths: Optional[List[str]] = None,
49
49
  options: Optional[model_types.ModelSaveOption] = None,
50
- model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
50
+ task: model_types.Task = model_types.Task.UNKNOWN,
51
51
  ) -> model_meta.ModelMetadata:
52
+
52
53
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
53
54
  raise snowml_exceptions.SnowflakeMLException(
54
55
  error_code=error_codes.INVALID_ARGUMENT,
@@ -57,17 +58,20 @@ class ModelPackager:
57
58
  ),
58
59
  )
59
60
 
60
- if (signatures is not None) and (sample_input_data is not None):
61
- raise snowml_exceptions.SnowflakeMLException(
62
- error_code=error_codes.INVALID_ARGUMENT,
63
- original_exception=ValueError(
64
- "Signatures and sample_input_data both cannot be specified at the same time."
65
- ),
66
- )
67
-
68
61
  if not options:
69
62
  options = model_types.BaseModelSaveOption()
70
63
 
64
+ # here handling the case of enable_explainability is False/None
65
+ enable_explainability = options.get("enable_explainability", None)
66
+ if enable_explainability is False or enable_explainability is None:
67
+ if (signatures is not None) and (sample_input_data is not None):
68
+ raise snowml_exceptions.SnowflakeMLException(
69
+ error_code=error_codes.INVALID_ARGUMENT,
70
+ original_exception=ValueError(
71
+ "Signatures and sample_input_data both cannot be specified at the same time."
72
+ ),
73
+ )
74
+
71
75
  handler = model_handler.find_handler(model)
72
76
  if handler is None:
73
77
  raise snowml_exceptions.SnowflakeMLException(
@@ -85,7 +89,7 @@ class ModelPackager:
85
89
  conda_dependencies=conda_dependencies,
86
90
  pip_requirements=pip_requirements,
87
91
  python_version=python_version,
88
- model_objective=model_objective,
92
+ task=task,
89
93
  **options,
90
94
  ) as meta:
91
95
  model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR)
@@ -36,6 +36,7 @@ class ModelRuntime:
36
36
  name: str,
37
37
  env: model_env.ModelEnv,
38
38
  imports: Optional[List[str]] = None,
39
+ is_warehouse: bool = False,
39
40
  is_gpu: bool = False,
40
41
  loading_from_file: bool = False,
41
42
  ) -> None:
@@ -60,6 +61,16 @@ class ModelRuntime:
60
61
  ],
61
62
  )
62
63
 
64
+ if not is_warehouse and self.embed_local_ml_library:
65
+ self.runtime_env.include_if_absent(
66
+ [
67
+ model_env.ModelDependency(
68
+ requirement="pyarrow",
69
+ pip_name="pyarrow",
70
+ )
71
+ ],
72
+ )
73
+
63
74
  if is_gpu:
64
75
  self.runtime_env.generate_env_for_cuda()
65
76
 
@@ -14,9 +14,10 @@ from snowflake.ml._internal.exceptions import (
14
14
  )
15
15
  from snowflake.ml._internal.utils import identifier
16
16
  from snowflake.ml.model import type_hints as model_types
17
- from snowflake.ml.model._deploy_client.warehouse import infer_template
18
17
  from snowflake.ml.model._signatures import base_handler, core, pandas_handler
19
18
 
19
+ _KEEP_ORDER_COL_NAME = "_ID"
20
+
20
21
 
21
22
  class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
22
23
  @staticmethod
@@ -109,7 +110,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
109
110
  # Role will be no effect on the column index. That is to say, the feature name is the actual column name.
110
111
  if keep_order:
111
112
  df = df.reset_index(drop=True)
112
- df[infer_template._KEEP_ORDER_COL_NAME] = df.index
113
+ df[_KEEP_ORDER_COL_NAME] = df.index
113
114
  sp_df = session.create_dataframe(df)
114
115
  column_names = []
115
116
  columns = []
@@ -1,23 +1,9 @@
1
1
  # mypy: disable-error-code="import"
2
2
  from enum import Enum
3
- from typing import (
4
- TYPE_CHECKING,
5
- Any,
6
- Dict,
7
- List,
8
- Literal,
9
- Optional,
10
- Sequence,
11
- TypedDict,
12
- TypeVar,
13
- Union,
14
- )
3
+ from typing import TYPE_CHECKING, Dict, Literal, Sequence, TypedDict, TypeVar, Union
15
4
 
16
5
  import numpy.typing as npt
17
- from typing_extensions import NotRequired, Required
18
-
19
- from snowflake.ml.model import deploy_platforms
20
- from snowflake.ml.model._signatures import core
6
+ from typing_extensions import NotRequired
21
7
 
22
8
  if TYPE_CHECKING:
23
9
  import catboost
@@ -35,7 +21,6 @@ if TYPE_CHECKING:
35
21
 
36
22
  import snowflake.ml.model.custom_model
37
23
  import snowflake.ml.model.models.huggingface_pipeline
38
- import snowflake.ml.model.models.llm
39
24
  import snowflake.snowpark
40
25
  from snowflake.ml.modeling.framework import base # noqa: F401
41
26
 
@@ -91,7 +76,6 @@ SupportedNoSignatureRequirementsModelType = Union[
91
76
  "transformers.Pipeline",
92
77
  "sentence_transformers.SentenceTransformer",
93
78
  "snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel",
94
- "snowflake.ml.model.models.llm.LLM",
95
79
  ]
96
80
 
97
81
  SupportedModelType = Union[
@@ -134,86 +118,11 @@ SupportedModelHandlerType = Literal[
134
118
  "tensorflow",
135
119
  "torchscript",
136
120
  "xgboost",
137
- "llm",
138
121
  ]
139
122
 
140
123
  _ModelType = TypeVar("_ModelType", bound=SupportedModelType)
141
124
 
142
125
 
143
- class DeployOptions(TypedDict):
144
- """Common Options for deploying to Snowflake."""
145
-
146
- ...
147
-
148
-
149
- class WarehouseDeployOptions(DeployOptions):
150
- """Options for deploying to the Snowflake Warehouse.
151
-
152
-
153
- permanent_udf_stage_location: A Snowflake stage option where the UDF should be persisted. If specified, the model
154
- will be deployed as a permanent UDF, otherwise temporary.
155
- relax_version: Whether or not relax the version constraints of the dependencies if unresolvable. It detects any
156
- ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
157
- replace_udf: Flag to indicate when deploying model as permanent UDF, whether overwriting existed UDF is allowed.
158
- Default to False.
159
- """
160
-
161
- permanent_udf_stage_location: NotRequired[str]
162
- relax_version: NotRequired[bool]
163
- replace_udf: NotRequired[bool]
164
-
165
-
166
- class SnowparkContainerServiceDeployOptions(DeployOptions):
167
- """Deployment options for deploying to SnowService.
168
- When type hint is updated, please ensure the concrete class is updated accordingly at:
169
- //snowflake/ml/model/_deploy_client/snowservice/_deploy_options
170
-
171
- compute_pool[REQUIRED]: SnowService compute pool name. Please refer to official doc for how to create a
172
- compute pool: https://docs.snowflake.com/LIMITEDACCESS/snowpark-containers/reference/compute-pool
173
- image_repo: SnowService image repo path. e.g. "<image_registry>/<db>/<schema>/<repo>". Default to auto
174
- inferred based on session information.
175
- min_instances: Minimum number of service replicas. Default to 1.
176
- max_instances: Maximum number of service replicas. Default to 1.
177
- prebuilt_snowflake_image: When provided, the image-building step is skipped, and the pre-built image from
178
- Snowflake is used as is. This option is for users who consistently use the same image for multiple use
179
- cases, allowing faster deployment. The snowflake image used for deployment is logged to the console for
180
- future use. Default to None.
181
- num_gpus: Number of GPUs to be used for the service. Default to 0.
182
- num_workers: Number of workers used for model inference. Please ensure that the number of workers is set lower than
183
- the total available memory divided by the size of model to prevent memory-related issues. Default is number of
184
- CPU cores * 2 + 1.
185
- enable_remote_image_build: When set to True, will enable image build on a remote SnowService job. Default is True.
186
- force_image_build: When set to True, an image rebuild will occur. The default is False, which means the system
187
- will automatically check whether a previously built image can be reused
188
- model_in_image: When set to True, image would container full model weights. The default if False, which
189
- means image without model weights and we do stage mount to access weights.
190
- debug_mode: When set to True, deployment artifacts will be persisted in a local temp directory.
191
- enable_ingress: When set to True, will expose HTTP endpoint for access to the predict method of the created
192
- service.
193
- external_access_integrations: External Access Integrations name used to build image and deploy the model.
194
- Please refer to the doc for how to create an External Access Integrations: https://docs.snowflake.com/
195
- developer-guide/snowpark-container-services/additional-considerations-services-jobs
196
- #configuring-network-capabilities .
197
- To make sure your image could be built, access to the following endpoint must be allowed.
198
- docker.com:80, docker.com:443, anaconda.com:80, anaconda.com:443, anaconda.org:80, anaconda.org:443,
199
- pypi.org:80, pypi.org:443
200
- """
201
-
202
- compute_pool: str
203
- image_repo: NotRequired[str]
204
- min_instances: NotRequired[int]
205
- max_instances: NotRequired[int]
206
- prebuilt_snowflake_image: NotRequired[str]
207
- num_gpus: NotRequired[int]
208
- num_workers: NotRequired[int]
209
- enable_remote_image_build: NotRequired[bool]
210
- force_image_build: NotRequired[bool]
211
- model_in_image: NotRequired[bool]
212
- debug_mode: NotRequired[bool]
213
- enable_ingress: NotRequired[bool]
214
- external_access_integrations: List[str]
215
-
216
-
217
126
  class ModelMethodSaveOptions(TypedDict):
218
127
  case_sensitive: NotRequired[bool]
219
128
  max_batch_size: NotRequired[int]
@@ -224,13 +133,12 @@ class BaseModelSaveOption(TypedDict):
224
133
  """Options for saving the model.
225
134
 
226
135
  embed_local_ml_library: Embedding local SnowML into the code directory of the folder.
227
- relax_version: Whether or not relax the version constraints of the dependencies if unresolvable. It detects any
228
- ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
136
+ relax_version: Whether or not relax the version constraints of the dependencies if unresolvable in Warehouse.
137
+ It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
229
138
  """
230
139
 
231
140
  embed_local_ml_library: NotRequired[bool]
232
141
  relax_version: NotRequired[bool]
233
- _legacy_save: NotRequired[bool]
234
142
  function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
235
143
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
236
144
  enable_explainability: NotRequired[bool]
@@ -293,10 +201,6 @@ class SentenceTransformersSaveOptions(BaseModelSaveOption):
293
201
  cuda_version: NotRequired[str]
294
202
 
295
203
 
296
- class LLMSaveOptions(BaseModelSaveOption):
297
- cuda_version: NotRequired[str]
298
-
299
-
300
204
  ModelSaveOption = Union[
301
205
  BaseModelSaveOption,
302
206
  CatBoostModelSaveOptions,
@@ -311,7 +215,6 @@ ModelSaveOption = Union[
311
215
  MLFlowSaveOptions,
312
216
  HuggingFaceSaveOptions,
313
217
  SentenceTransformersSaveOptions,
314
- LLMSaveOptions,
315
218
  ]
316
219
 
317
220
 
@@ -369,10 +272,7 @@ class HuggingFaceLoadOptions(BaseModelLoadOption):
369
272
 
370
273
  class SentenceTransformersLoadOptions(BaseModelLoadOption):
371
274
  use_gpu: NotRequired[bool]
372
-
373
-
374
- class LLMLoadOptions(BaseModelLoadOption):
375
- ...
275
+ device: NotRequired[str]
376
276
 
377
277
 
378
278
  ModelLoadOption = Union[
@@ -389,53 +289,12 @@ ModelLoadOption = Union[
389
289
  MLFlowLoadOptions,
390
290
  HuggingFaceLoadOptions,
391
291
  SentenceTransformersLoadOptions,
392
- LLMLoadOptions,
393
- ]
394
-
395
-
396
- class SnowparkContainerServiceDeployDetails(TypedDict):
397
- """
398
- Attributes:
399
- service_info: A snowpark row containing the result of "describe service"
400
- service_function_sql: SQL for service function creation.
401
- """
402
-
403
- service_info: Optional[Dict[str, Any]]
404
- service_function_sql: str
405
-
406
-
407
- class WarehouseDeployDetails(TypedDict):
408
- ...
409
-
410
-
411
- DeployDetails = Union[
412
- SnowparkContainerServiceDeployDetails,
413
- WarehouseDeployDetails,
414
292
  ]
415
293
 
416
294
 
417
- class Deployment(TypedDict):
418
- """Deployment information.
419
-
420
- Attributes:
421
- name: Name of the deployment.
422
- platform: Target platform to deploy the model.
423
- target_method: Target method name.
424
- signature: The signature of the model method.
425
- options: Additional options when deploying the model.
426
- """
427
-
428
- name: Required[str]
429
- platform: Required[deploy_platforms.TargetPlatform]
430
- target_method: Required[str]
431
- signature: core.ModelSignature
432
- options: Required[DeployOptions]
433
- details: NotRequired[DeployDetails]
434
-
435
-
436
- class ModelObjective(Enum):
437
- UNKNOWN = "unknown"
438
- BINARY_CLASSIFICATION = "binary_classification"
439
- MULTI_CLASSIFICATION = "multi_classification"
440
- REGRESSION = "regression"
441
- RANKING = "ranking"
295
+ class Task(Enum):
296
+ UNKNOWN = "UNKNOWN"
297
+ TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
298
+ TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
299
+ TABULAR_REGRESSION = "TABULAR_REGRESSION"
300
+ TABULAR_RANKING = "TABULAR_RANKING"
@@ -377,7 +377,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
377
377
  anonymous=True,
378
378
  imports=imports, # type: ignore[arg-type]
379
379
  statement_params=sproc_statement_params,
380
- execute_as="caller",
381
380
  )
382
381
  def _distributed_search(
383
382
  session: Session,
@@ -783,7 +782,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
783
782
  anonymous=True,
784
783
  imports=imports, # type: ignore[arg-type]
785
784
  statement_params=sproc_statement_params,
786
- execute_as="caller",
787
785
  )
788
786
  def _distributed_search(
789
787
  session: Session,
@@ -230,7 +230,6 @@ class SnowparkModelTrainer:
230
230
  replace=True,
231
231
  session=self.session,
232
232
  statement_params=statement_params,
233
- execute_as="caller",
234
233
  anonymous=anonymous,
235
234
  )
236
235
  return fit_wrapper_sproc
@@ -461,9 +460,7 @@ class SnowparkModelTrainer:
461
460
  session.write_pandas(
462
461
  transformed_pandas_df,
463
462
  fit_transform_result_name,
464
- auto_create_table=True,
465
- table_type="temp",
466
- quote_identifiers=False,
463
+ overwrite=True,
467
464
  )
468
465
 
469
466
  return str(os.path.basename(local_result_file_name))
@@ -488,7 +485,6 @@ class SnowparkModelTrainer:
488
485
  session=self.session,
489
486
  statement_params=statement_params,
490
487
  anonymous=anonymous,
491
- execute_as="caller",
492
488
  )
493
489
 
494
490
  return fit_predict_wrapper_sproc
@@ -510,7 +506,6 @@ class SnowparkModelTrainer:
510
506
  replace=True,
511
507
  session=self.session,
512
508
  statement_params=statement_params,
513
- execute_as="caller",
514
509
  anonymous=anonymous,
515
510
  )
516
511
 
@@ -730,6 +725,22 @@ class SnowparkModelTrainer:
730
725
 
731
726
  fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
732
727
 
728
+ # Create a temp table in advance to store the output
729
+ # This would allow us to use the same table outside the stored procedure
730
+ df_one_line = dataset.limit(1).to_pandas(statement_params=statement_params)
731
+ df_one_line[
732
+ expected_output_cols_list[0]
733
+ ] = "[0]" # Add one column as the output_col; this is a dummy value to represent the OBJECT type
734
+ if drop_input_cols:
735
+ self.session.write_pandas(
736
+ df_one_line[expected_output_cols_list[0]],
737
+ fit_transform_result_name,
738
+ auto_create_table=True,
739
+ table_type="temp",
740
+ )
741
+ else:
742
+ self.session.write_pandas(df_one_line, fit_transform_result_name, auto_create_table=True, table_type="temp")
743
+
733
744
  sproc_export_file_name: str = fit_transform_wrapper_sproc(
734
745
  self.session,
735
746
  queries,
@@ -303,7 +303,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
303
303
  statement_params=statement_params,
304
304
  anonymous=True,
305
305
  imports=list(import_file_paths),
306
- execute_as="caller",
307
306
  ) # type: ignore[misc]
308
307
  def fit_wrapper_sproc(
309
308
  session: Session,
@@ -576,6 +576,7 @@ class CalibratedClassifierCV(BaseTransformer):
576
576
  """
577
577
  self._infer_input_output_cols(dataset)
578
578
  super()._check_dataset_type(dataset)
579
+
579
580
  model_trainer = ModelTrainerBuilder.build_fit_transform(
580
581
  estimator=self._sklearn_object,
581
582
  dataset=dataset,
@@ -555,6 +555,7 @@ class AffinityPropagation(BaseTransformer):
555
555
  """
556
556
  self._infer_input_output_cols(dataset)
557
557
  super()._check_dataset_type(dataset)
558
+
558
559
  model_trainer = ModelTrainerBuilder.build_fit_transform(
559
560
  estimator=self._sklearn_object,
560
561
  dataset=dataset,
@@ -586,6 +586,7 @@ class AgglomerativeClustering(BaseTransformer):
586
586
  """
587
587
  self._infer_input_output_cols(dataset)
588
588
  super()._check_dataset_type(dataset)
589
+
589
590
  model_trainer = ModelTrainerBuilder.build_fit_transform(
590
591
  estimator=self._sklearn_object,
591
592
  dataset=dataset,
@@ -550,6 +550,7 @@ class Birch(BaseTransformer):
550
550
  """
551
551
  self._infer_input_output_cols(dataset)
552
552
  super()._check_dataset_type(dataset)
553
+
553
554
  model_trainer = ModelTrainerBuilder.build_fit_transform(
554
555
  estimator=self._sklearn_object,
555
556
  dataset=dataset,
@@ -599,6 +599,7 @@ class BisectingKMeans(BaseTransformer):
599
599
  """
600
600
  self._infer_input_output_cols(dataset)
601
601
  super()._check_dataset_type(dataset)
602
+
602
603
  model_trainer = ModelTrainerBuilder.build_fit_transform(
603
604
  estimator=self._sklearn_object,
604
605
  dataset=dataset,