snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import os
4
4
  import pathlib
5
5
  import tempfile
6
6
  import warnings
7
- from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast, overload
7
+ from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
8
8
 
9
9
  import yaml
10
10
 
@@ -104,7 +104,7 @@ class ModelOperator:
104
104
  *,
105
105
  database_name: Optional[sql_identifier.SqlIdentifier],
106
106
  schema_name: Optional[sql_identifier.SqlIdentifier],
107
- statement_params: Optional[Dict[str, Any]] = None,
107
+ statement_params: Optional[dict[str, Any]] = None,
108
108
  ) -> str:
109
109
  stage_name = sql_identifier.SqlIdentifier(
110
110
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
@@ -137,7 +137,7 @@ class ModelOperator:
137
137
  schema_name: Optional[sql_identifier.SqlIdentifier],
138
138
  model_name: sql_identifier.SqlIdentifier,
139
139
  version_name: sql_identifier.SqlIdentifier,
140
- statement_params: Optional[Dict[str, Any]] = None,
140
+ statement_params: Optional[dict[str, Any]] = None,
141
141
  ) -> ModelAction:
142
142
  if self.validate_existence(
143
143
  database_name=database_name,
@@ -169,7 +169,7 @@ class ModelOperator:
169
169
  schema_name: Optional[sql_identifier.SqlIdentifier],
170
170
  model_name: sql_identifier.SqlIdentifier,
171
171
  version_name: sql_identifier.SqlIdentifier,
172
- statement_params: Optional[Dict[str, Any]] = None,
172
+ statement_params: Optional[dict[str, Any]] = None,
173
173
  ) -> None:
174
174
  model_action = self.get_model_action_from_model_name_and_version(
175
175
  database_name=database_name,
@@ -205,7 +205,7 @@ class ModelOperator:
205
205
  schema_name: Optional[sql_identifier.SqlIdentifier],
206
206
  model_name: sql_identifier.SqlIdentifier,
207
207
  version_name: sql_identifier.SqlIdentifier,
208
- statement_params: Optional[Dict[str, Any]] = None,
208
+ statement_params: Optional[dict[str, Any]] = None,
209
209
  use_live_commit: Optional[bool] = False,
210
210
  ) -> None:
211
211
 
@@ -263,7 +263,7 @@ class ModelOperator:
263
263
  model_name: sql_identifier.SqlIdentifier,
264
264
  version_name: sql_identifier.SqlIdentifier,
265
265
  model_exists: bool,
266
- statement_params: Optional[Dict[str, Any]] = None,
266
+ statement_params: Optional[dict[str, Any]] = None,
267
267
  ) -> None:
268
268
  if model_exists:
269
269
  return self._model_version_client.add_version_from_model_version(
@@ -296,8 +296,8 @@ class ModelOperator:
296
296
  database_name: Optional[sql_identifier.SqlIdentifier],
297
297
  schema_name: Optional[sql_identifier.SqlIdentifier],
298
298
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
299
- statement_params: Optional[Dict[str, Any]] = None,
300
- ) -> List[row.Row]:
299
+ statement_params: Optional[dict[str, Any]] = None,
300
+ ) -> list[row.Row]:
301
301
  if model_name:
302
302
  return self._model_client.show_versions(
303
303
  database_name=database_name,
@@ -320,8 +320,8 @@ class ModelOperator:
320
320
  database_name: Optional[sql_identifier.SqlIdentifier],
321
321
  schema_name: Optional[sql_identifier.SqlIdentifier],
322
322
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
323
- statement_params: Optional[Dict[str, Any]] = None,
324
- ) -> List[sql_identifier.SqlIdentifier]:
323
+ statement_params: Optional[dict[str, Any]] = None,
324
+ ) -> list[sql_identifier.SqlIdentifier]:
325
325
  res = self.show_models_or_versions(
326
326
  database_name=database_name,
327
327
  schema_name=schema_name,
@@ -341,7 +341,7 @@ class ModelOperator:
341
341
  schema_name: Optional[sql_identifier.SqlIdentifier],
342
342
  model_name: sql_identifier.SqlIdentifier,
343
343
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
344
- statement_params: Optional[Dict[str, Any]] = None,
344
+ statement_params: Optional[dict[str, Any]] = None,
345
345
  ) -> bool:
346
346
  if version_name:
347
347
  res = self._model_client.show_versions(
@@ -369,7 +369,7 @@ class ModelOperator:
369
369
  schema_name: Optional[sql_identifier.SqlIdentifier],
370
370
  model_name: sql_identifier.SqlIdentifier,
371
371
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
372
- statement_params: Optional[Dict[str, Any]] = None,
372
+ statement_params: Optional[dict[str, Any]] = None,
373
373
  ) -> str:
374
374
  if version_name:
375
375
  res = self._model_client.show_versions(
@@ -398,7 +398,7 @@ class ModelOperator:
398
398
  schema_name: Optional[sql_identifier.SqlIdentifier],
399
399
  model_name: sql_identifier.SqlIdentifier,
400
400
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
401
- statement_params: Optional[Dict[str, Any]] = None,
401
+ statement_params: Optional[dict[str, Any]] = None,
402
402
  ) -> None:
403
403
  if version_name:
404
404
  self._model_version_client.set_comment(
@@ -426,7 +426,7 @@ class ModelOperator:
426
426
  schema_name: Optional[sql_identifier.SqlIdentifier],
427
427
  model_name: sql_identifier.SqlIdentifier,
428
428
  version_name: sql_identifier.SqlIdentifier,
429
- statement_params: Optional[Dict[str, Any]] = None,
429
+ statement_params: Optional[dict[str, Any]] = None,
430
430
  ) -> None:
431
431
  self._model_version_client.set_alias(
432
432
  alias_name=alias_name,
@@ -444,7 +444,7 @@ class ModelOperator:
444
444
  database_name: Optional[sql_identifier.SqlIdentifier],
445
445
  schema_name: Optional[sql_identifier.SqlIdentifier],
446
446
  model_name: sql_identifier.SqlIdentifier,
447
- statement_params: Optional[Dict[str, Any]] = None,
447
+ statement_params: Optional[dict[str, Any]] = None,
448
448
  ) -> None:
449
449
  self._model_version_client.unset_alias(
450
450
  database_name=database_name,
@@ -461,7 +461,7 @@ class ModelOperator:
461
461
  schema_name: Optional[sql_identifier.SqlIdentifier],
462
462
  model_name: sql_identifier.SqlIdentifier,
463
463
  version_name: sql_identifier.SqlIdentifier,
464
- statement_params: Optional[Dict[str, Any]] = None,
464
+ statement_params: Optional[dict[str, Any]] = None,
465
465
  ) -> None:
466
466
  if not self.validate_existence(
467
467
  database_name=database_name,
@@ -485,7 +485,7 @@ class ModelOperator:
485
485
  database_name: Optional[sql_identifier.SqlIdentifier],
486
486
  schema_name: Optional[sql_identifier.SqlIdentifier],
487
487
  model_name: sql_identifier.SqlIdentifier,
488
- statement_params: Optional[Dict[str, Any]] = None,
488
+ statement_params: Optional[dict[str, Any]] = None,
489
489
  ) -> sql_identifier.SqlIdentifier:
490
490
  res = self._model_client.show_models(
491
491
  database_name=database_name,
@@ -504,7 +504,7 @@ class ModelOperator:
504
504
  schema_name: Optional[sql_identifier.SqlIdentifier],
505
505
  model_name: sql_identifier.SqlIdentifier,
506
506
  alias_name: sql_identifier.SqlIdentifier,
507
- statement_params: Optional[Dict[str, Any]] = None,
507
+ statement_params: Optional[dict[str, Any]] = None,
508
508
  ) -> Optional[sql_identifier.SqlIdentifier]:
509
509
  res = self._model_client.show_versions(
510
510
  database_name=database_name,
@@ -528,7 +528,7 @@ class ModelOperator:
528
528
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
529
529
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
530
530
  tag_name: sql_identifier.SqlIdentifier,
531
- statement_params: Optional[Dict[str, Any]] = None,
531
+ statement_params: Optional[dict[str, Any]] = None,
532
532
  ) -> Optional[str]:
533
533
  r = self._tag_client.get_tag_value(
534
534
  database_name=database_name,
@@ -550,15 +550,15 @@ class ModelOperator:
550
550
  database_name: Optional[sql_identifier.SqlIdentifier],
551
551
  schema_name: Optional[sql_identifier.SqlIdentifier],
552
552
  model_name: sql_identifier.SqlIdentifier,
553
- statement_params: Optional[Dict[str, Any]] = None,
554
- ) -> Dict[str, str]:
553
+ statement_params: Optional[dict[str, Any]] = None,
554
+ ) -> dict[str, str]:
555
555
  tags_info = self._tag_client.get_tag_list(
556
556
  database_name=database_name,
557
557
  schema_name=schema_name,
558
558
  model_name=model_name,
559
559
  statement_params=statement_params,
560
560
  )
561
- res: Dict[str, str] = {
561
+ res: dict[str, str] = {
562
562
  identifier.get_schema_level_object_identifier(
563
563
  sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True),
564
564
  sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True),
@@ -578,7 +578,7 @@ class ModelOperator:
578
578
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
579
579
  tag_name: sql_identifier.SqlIdentifier,
580
580
  tag_value: str,
581
- statement_params: Optional[Dict[str, Any]] = None,
581
+ statement_params: Optional[dict[str, Any]] = None,
582
582
  ) -> None:
583
583
  self._tag_client.set_tag_on_model(
584
584
  database_name=database_name,
@@ -600,7 +600,7 @@ class ModelOperator:
600
600
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
601
601
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
602
602
  tag_name: sql_identifier.SqlIdentifier,
603
- statement_params: Optional[Dict[str, Any]] = None,
603
+ statement_params: Optional[dict[str, Any]] = None,
604
604
  ) -> None:
605
605
  self._tag_client.unset_tag_on_model(
606
606
  database_name=database_name,
@@ -619,8 +619,8 @@ class ModelOperator:
619
619
  schema_name: Optional[sql_identifier.SqlIdentifier],
620
620
  model_name: sql_identifier.SqlIdentifier,
621
621
  version_name: sql_identifier.SqlIdentifier,
622
- statement_params: Optional[Dict[str, Any]] = None,
623
- ) -> List[ServiceInfo]:
622
+ statement_params: Optional[dict[str, Any]] = None,
623
+ ) -> list[ServiceInfo]:
624
624
  res = self._model_client.show_versions(
625
625
  database_name=database_name,
626
626
  schema_name=schema_name,
@@ -682,7 +682,7 @@ class ModelOperator:
682
682
  service_database_name: Optional[sql_identifier.SqlIdentifier],
683
683
  service_schema_name: Optional[sql_identifier.SqlIdentifier],
684
684
  service_name: sql_identifier.SqlIdentifier,
685
- statement_params: Optional[Dict[str, Any]] = None,
685
+ statement_params: Optional[dict[str, Any]] = None,
686
686
  ) -> None:
687
687
  services = self.show_services(
688
688
  database_name=database_name,
@@ -724,7 +724,7 @@ class ModelOperator:
724
724
  schema_name: Optional[sql_identifier.SqlIdentifier],
725
725
  model_name: sql_identifier.SqlIdentifier,
726
726
  version_name: sql_identifier.SqlIdentifier,
727
- statement_params: Optional[Dict[str, Any]] = None,
727
+ statement_params: Optional[dict[str, Any]] = None,
728
728
  ) -> model_manifest_schema.ModelManifestDict:
729
729
  with tempfile.TemporaryDirectory() as tmpdir:
730
730
  self._model_version_client.get_file(
@@ -741,9 +741,9 @@ class ModelOperator:
741
741
 
742
742
  @staticmethod
743
743
  def _match_model_spec_with_sql_functions(
744
- sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
745
- ) -> Dict[sql_identifier.SqlIdentifier, str]:
746
- res: Dict[sql_identifier.SqlIdentifier, str] = {}
744
+ sql_functions_names: list[sql_identifier.SqlIdentifier], target_methods: list[str]
745
+ ) -> dict[sql_identifier.SqlIdentifier, str]:
746
+ res: dict[sql_identifier.SqlIdentifier, str] = {}
747
747
 
748
748
  for target_method in target_methods:
749
749
  # Here we need to find the SQL function corresponding to the Python function.
@@ -766,7 +766,7 @@ class ModelOperator:
766
766
  schema_name: Optional[sql_identifier.SqlIdentifier],
767
767
  model_name: sql_identifier.SqlIdentifier,
768
768
  version_name: sql_identifier.SqlIdentifier,
769
- statement_params: Optional[Dict[str, Any]] = None,
769
+ statement_params: Optional[dict[str, Any]] = None,
770
770
  ) -> model_meta_schema.ModelMetadataDict:
771
771
  raw_model_spec_res = self._model_client.show_versions(
772
772
  database_name=database_name,
@@ -787,7 +787,7 @@ class ModelOperator:
787
787
  schema_name: Optional[sql_identifier.SqlIdentifier],
788
788
  model_name: sql_identifier.SqlIdentifier,
789
789
  version_name: sql_identifier.SqlIdentifier,
790
- statement_params: Optional[Dict[str, Any]] = None,
790
+ statement_params: Optional[dict[str, Any]] = None,
791
791
  ) -> type_hints.Task:
792
792
  model_version = self._model_client.show_versions(
793
793
  database_name=database_name,
@@ -809,8 +809,8 @@ class ModelOperator:
809
809
  schema_name: Optional[sql_identifier.SqlIdentifier],
810
810
  model_name: sql_identifier.SqlIdentifier,
811
811
  version_name: sql_identifier.SqlIdentifier,
812
- statement_params: Optional[Dict[str, Any]] = None,
813
- ) -> List[model_manifest_schema.ModelFunctionInfo]:
812
+ statement_params: Optional[dict[str, Any]] = None,
813
+ ) -> list[model_manifest_schema.ModelFunctionInfo]:
814
814
  model_spec = self._fetch_model_spec(
815
815
  database_name=database_name,
816
816
  schema_name=schema_name,
@@ -907,7 +907,7 @@ class ModelOperator:
907
907
  version_name: sql_identifier.SqlIdentifier,
908
908
  strict_input_validation: bool = False,
909
909
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
910
- statement_params: Optional[Dict[str, str]] = None,
910
+ statement_params: Optional[dict[str, str]] = None,
911
911
  is_partitioned: Optional[bool] = None,
912
912
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
913
913
  ...
@@ -923,7 +923,7 @@ class ModelOperator:
923
923
  schema_name: Optional[sql_identifier.SqlIdentifier],
924
924
  service_name: sql_identifier.SqlIdentifier,
925
925
  strict_input_validation: bool = False,
926
- statement_params: Optional[Dict[str, str]] = None,
926
+ statement_params: Optional[dict[str, str]] = None,
927
927
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
928
928
  ...
929
929
 
@@ -941,7 +941,7 @@ class ModelOperator:
941
941
  service_name: Optional[sql_identifier.SqlIdentifier] = None,
942
942
  strict_input_validation: bool = False,
943
943
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
944
- statement_params: Optional[Dict[str, str]] = None,
944
+ statement_params: Optional[dict[str, str]] = None,
945
945
  is_partitioned: Optional[bool] = None,
946
946
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
947
947
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
@@ -1059,7 +1059,7 @@ class ModelOperator:
1059
1059
  schema_name: Optional[sql_identifier.SqlIdentifier],
1060
1060
  model_name: sql_identifier.SqlIdentifier,
1061
1061
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
1062
- statement_params: Optional[Dict[str, Any]] = None,
1062
+ statement_params: Optional[dict[str, Any]] = None,
1063
1063
  ) -> None:
1064
1064
  if version_name:
1065
1065
  self._model_version_client.drop_version(
@@ -1086,7 +1086,7 @@ class ModelOperator:
1086
1086
  new_model_db: Optional[sql_identifier.SqlIdentifier],
1087
1087
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
1088
1088
  new_model_name: sql_identifier.SqlIdentifier,
1089
- statement_params: Optional[Dict[str, Any]] = None,
1089
+ statement_params: Optional[dict[str, Any]] = None,
1090
1090
  ) -> None:
1091
1091
  self._model_client.rename(
1092
1092
  database_name=database_name,
@@ -1121,7 +1121,7 @@ class ModelOperator:
1121
1121
  version_name: sql_identifier.SqlIdentifier,
1122
1122
  target_path: pathlib.Path,
1123
1123
  mode: Literal["full", "model", "minimal"] = "model",
1124
- statement_params: Optional[Dict[str, Any]] = None,
1124
+ statement_params: Optional[dict[str, Any]] = None,
1125
1125
  ) -> None:
1126
1126
  for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
1127
1127
  list_file_res = self._model_version_client.list_file(
@@ -6,14 +6,16 @@ import re
6
6
  import tempfile
7
7
  import threading
8
8
  import time
9
- from typing import Any, Dict, List, Optional, Tuple, Union, cast
9
+ from typing import Any, Optional, Union, cast
10
10
 
11
11
  from snowflake import snowpark
12
12
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
13
- from snowflake.ml._internal.utils import service_logger, sql_identifier
13
+ from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
14
+ from snowflake.ml.model import model_signature, type_hints
14
15
  from snowflake.ml.model._client.service import model_deployment_spec
15
16
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
16
- from snowflake.snowpark import async_job, exceptions, row, session
17
+ from snowflake.ml.model._signatures import snowpark_handler
18
+ from snowflake.snowpark import async_job, dataframe, exceptions, row, session
17
19
  from snowflake.snowpark._internal import utils as snowpark_utils
18
20
 
19
21
  module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -104,9 +106,9 @@ class ServiceOperator:
104
106
  num_workers: Optional[int],
105
107
  max_batch_rows: Optional[int],
106
108
  force_rebuild: bool,
107
- build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
109
+ build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
108
110
  block: bool,
109
- statement_params: Optional[Dict[str, Any]] = None,
111
+ statement_params: Optional[dict[str, Any]] = None,
110
112
  ) -> Union[str, async_job.AsyncJob]:
111
113
 
112
114
  # Fall back to the registry's database and schema if not provided
@@ -120,32 +122,28 @@ class ServiceOperator:
120
122
  image_repo_database_name = image_repo_database_name or database_name or self._database_name
121
123
  image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
122
124
  if self._workspace:
123
- # create a temp stage
124
- stage_name = sql_identifier.SqlIdentifier(
125
- snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
126
- )
127
- self._stage_client.create_tmp_stage(
128
- database_name=database_name,
129
- schema_name=schema_name,
130
- stage_name=stage_name,
131
- statement_params=statement_params,
132
- )
133
- stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
125
+ stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
134
126
  else:
135
127
  stage_path = None
136
- spec_yaml_str_or_path = self._model_deployment_spec.save(
128
+ self._model_deployment_spec.add_model_spec(
137
129
  database_name=database_name,
138
130
  schema_name=schema_name,
139
131
  model_name=model_name,
140
132
  version_name=version_name,
141
- service_database_name=service_database_name,
142
- service_schema_name=service_schema_name,
143
- service_name=service_name,
133
+ )
134
+ self._model_deployment_spec.add_image_build_spec(
144
135
  image_build_compute_pool_name=image_build_compute_pool_name,
145
- service_compute_pool_name=service_compute_pool_name,
146
136
  image_repo_database_name=image_repo_database_name,
147
137
  image_repo_schema_name=image_repo_schema_name,
148
138
  image_repo_name=image_repo_name,
139
+ force_rebuild=force_rebuild,
140
+ external_access_integrations=build_external_access_integrations,
141
+ )
142
+ self._model_deployment_spec.add_service_spec(
143
+ service_database_name=service_database_name,
144
+ service_schema_name=service_schema_name,
145
+ service_name=service_name,
146
+ inference_compute_pool_name=service_compute_pool_name,
149
147
  ingress_enabled=ingress_enabled,
150
148
  max_instances=max_instances,
151
149
  cpu=cpu_requests,
@@ -153,9 +151,8 @@ class ServiceOperator:
153
151
  gpu=gpu_requests,
154
152
  num_workers=num_workers,
155
153
  max_batch_rows=max_batch_rows,
156
- force_rebuild=force_rebuild,
157
- external_access_integrations=build_external_access_integrations,
158
154
  )
155
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
159
156
  if self._workspace:
160
157
  assert stage_path is not None
161
158
  file_utils.upload_directory_to_stage(
@@ -210,7 +207,7 @@ class ServiceOperator:
210
207
  if block:
211
208
  log_thread.join()
212
209
 
213
- res = cast(str, cast(List[row.Row], async_job.result())[0][0])
210
+ res = cast(str, cast(list[row.Row], async_job.result())[0][0])
214
211
  module_logger.info(f"Inference service {service_name} deployment complete: {res}")
215
212
  return res
216
213
  else:
@@ -219,10 +216,10 @@ class ServiceOperator:
219
216
  def _start_service_log_streaming(
220
217
  self,
221
218
  async_job: snowpark.AsyncJob,
222
- services: List[ServiceLogInfo],
219
+ services: list[ServiceLogInfo],
223
220
  model_inference_service_exists: bool,
224
221
  force_rebuild: bool,
225
- statement_params: Optional[Dict[str, Any]] = None,
222
+ statement_params: Optional[dict[str, Any]] = None,
226
223
  ) -> threading.Thread:
227
224
  """Start the service log streaming in a separate thread."""
228
225
  log_thread = threading.Thread(
@@ -241,14 +238,14 @@ class ServiceOperator:
241
238
  def _stream_service_logs(
242
239
  self,
243
240
  async_job: snowpark.AsyncJob,
244
- services: List[ServiceLogInfo],
241
+ services: list[ServiceLogInfo],
245
242
  model_inference_service_exists: bool,
246
243
  force_rebuild: bool,
247
- statement_params: Optional[Dict[str, Any]] = None,
244
+ statement_params: Optional[dict[str, Any]] = None,
248
245
  ) -> None:
249
246
  """Stream service logs while the async job is running."""
250
247
 
251
- def fetch_logs(service: ServiceLogInfo, offset: int) -> Tuple[str, int]:
248
+ def fetch_logs(service: ServiceLogInfo, offset: int) -> tuple[str, int]:
252
249
  service_logs = self._service_client.get_service_logs(
253
250
  database_name=service.database_name,
254
251
  schema_name=service.schema_name,
@@ -393,7 +390,7 @@ class ServiceOperator:
393
390
  service_logger: logging.Logger,
394
391
  service: ServiceLogInfo,
395
392
  offset: int,
396
- statement_params: Optional[Dict[str, Any]] = None,
393
+ statement_params: Optional[dict[str, Any]] = None,
397
394
  ) -> None:
398
395
  """Fetch service logs after the async job is done to ensure no logs are missed."""
399
396
  try:
@@ -425,8 +422,8 @@ class ServiceOperator:
425
422
  database_name: Optional[sql_identifier.SqlIdentifier],
426
423
  schema_name: Optional[sql_identifier.SqlIdentifier],
427
424
  service_name: sql_identifier.SqlIdentifier,
428
- service_status_list_if_exists: Optional[List[service_sql.ServiceStatus]] = None,
429
- statement_params: Optional[Dict[str, Any]] = None,
425
+ service_status_list_if_exists: Optional[list[service_sql.ServiceStatus]] = None,
426
+ statement_params: Optional[dict[str, Any]] = None,
430
427
  ) -> bool:
431
428
  if service_status_list_if_exists is None:
432
429
  service_status_list_if_exists = [
@@ -448,3 +445,191 @@ class ServiceOperator:
448
445
  return any(service_status == status for status in service_status_list_if_exists)
449
446
  except exceptions.SnowparkSQLException:
450
447
  return False
448
+
449
+ def invoke_job_method(
450
+ self,
451
+ target_method: str,
452
+ signature: model_signature.ModelSignature,
453
+ X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
454
+ database_name: Optional[sql_identifier.SqlIdentifier],
455
+ schema_name: Optional[sql_identifier.SqlIdentifier],
456
+ model_name: sql_identifier.SqlIdentifier,
457
+ version_name: sql_identifier.SqlIdentifier,
458
+ job_database_name: Optional[sql_identifier.SqlIdentifier],
459
+ job_schema_name: Optional[sql_identifier.SqlIdentifier],
460
+ job_name: sql_identifier.SqlIdentifier,
461
+ compute_pool_name: sql_identifier.SqlIdentifier,
462
+ warehouse_name: sql_identifier.SqlIdentifier,
463
+ image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
464
+ image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
465
+ image_repo_name: sql_identifier.SqlIdentifier,
466
+ output_table_database_name: Optional[sql_identifier.SqlIdentifier],
467
+ output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
468
+ output_table_name: sql_identifier.SqlIdentifier,
469
+ cpu_requests: Optional[str],
470
+ memory_requests: Optional[str],
471
+ gpu_requests: Optional[Union[int, str]],
472
+ num_workers: Optional[int],
473
+ max_batch_rows: Optional[int],
474
+ force_rebuild: bool,
475
+ build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
476
+ statement_params: Optional[dict[str, Any]] = None,
477
+ ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
478
+ # fall back to the registry's database and schema if not provided
479
+ database_name = database_name or self._database_name
480
+ schema_name = schema_name or self._schema_name
481
+
482
+ # fall back to the model's database and schema if not provided then to the registry's database and schema
483
+ job_database_name = job_database_name or database_name or self._database_name
484
+ job_schema_name = job_schema_name or schema_name or self._schema_name
485
+
486
+ image_repo_database_name = image_repo_database_name or database_name or self._database_name
487
+ image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
488
+
489
+ input_table_database_name = job_database_name
490
+ input_table_schema_name = job_schema_name
491
+ output_table_database_name = output_table_database_name or database_name or self._database_name
492
+ output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
493
+
494
+ if self._workspace:
495
+ stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
496
+ else:
497
+ stage_path = None
498
+
499
+ # validate and prepare input
500
+ if not isinstance(X, dataframe.DataFrame):
501
+ keep_order = True
502
+ output_with_input_features = False
503
+ df = model_signature._convert_and_validate_local_data(X, signature.inputs)
504
+ s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
505
+ self._session, df, keep_order=keep_order, features=signature.inputs
506
+ )
507
+ else:
508
+ keep_order = False
509
+ output_with_input_features = True
510
+ s_df = X
511
+
512
+ # only write the index and feature input columns
513
+ cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
514
+ cols += [
515
+ sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
516
+ ]
517
+ s_df = s_df.select(cols)
518
+ original_cols = s_df.columns
519
+
520
+ # input/output tables
521
+ fq_output_table_name = identifier.get_schema_level_object_identifier(
522
+ output_table_database_name.identifier(),
523
+ output_table_schema_name.identifier(),
524
+ output_table_name.identifier(),
525
+ )
526
+ tmp_input_table_id = sql_identifier.SqlIdentifier(
527
+ snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
528
+ )
529
+ fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
530
+ job_database_name.identifier(),
531
+ job_schema_name.identifier(),
532
+ tmp_input_table_id.identifier(),
533
+ )
534
+ s_df.write.save_as_table(
535
+ table_name=fq_tmp_input_table_name,
536
+ mode="errorifexists",
537
+ statement_params=statement_params,
538
+ )
539
+
540
+ try:
541
+ # save the spec
542
+ self._model_deployment_spec.add_model_spec(
543
+ database_name=database_name,
544
+ schema_name=schema_name,
545
+ model_name=model_name,
546
+ version_name=version_name,
547
+ )
548
+ self._model_deployment_spec.add_job_spec(
549
+ job_database_name=job_database_name,
550
+ job_schema_name=job_schema_name,
551
+ job_name=job_name,
552
+ inference_compute_pool_name=compute_pool_name,
553
+ cpu=cpu_requests,
554
+ memory=memory_requests,
555
+ gpu=gpu_requests,
556
+ num_workers=num_workers,
557
+ max_batch_rows=max_batch_rows,
558
+ warehouse=warehouse_name,
559
+ target_method=target_method,
560
+ input_table_database_name=input_table_database_name,
561
+ input_table_schema_name=input_table_schema_name,
562
+ input_table_name=tmp_input_table_id,
563
+ output_table_database_name=output_table_database_name,
564
+ output_table_schema_name=output_table_schema_name,
565
+ output_table_name=output_table_name,
566
+ )
567
+
568
+ self._model_deployment_spec.add_image_build_spec(
569
+ image_build_compute_pool_name=compute_pool_name,
570
+ image_repo_database_name=image_repo_database_name,
571
+ image_repo_schema_name=image_repo_schema_name,
572
+ image_repo_name=image_repo_name,
573
+ force_rebuild=force_rebuild,
574
+ external_access_integrations=build_external_access_integrations,
575
+ )
576
+
577
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
578
+ if self._workspace:
579
+ assert stage_path is not None
580
+ file_utils.upload_directory_to_stage(
581
+ self._session,
582
+ local_path=pathlib.Path(self._workspace.name),
583
+ stage_path=pathlib.PurePosixPath(stage_path),
584
+ statement_params=statement_params,
585
+ )
586
+
587
+ # deploy the job
588
+ query_id, async_job = self._service_client.deploy_model(
589
+ stage_path=stage_path if self._workspace else None,
590
+ model_deployment_spec_file_rel_path=(
591
+ model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
592
+ ),
593
+ model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
594
+ statement_params=statement_params,
595
+ )
596
+
597
+ while not async_job.is_done():
598
+ time.sleep(5)
599
+ finally:
600
+ self._session.table(fq_tmp_input_table_name).drop_table()
601
+
602
+ # handle the output
603
+ df_res = self._session.table(fq_output_table_name)
604
+ if keep_order:
605
+ df_res = df_res.sort(
606
+ snowpark_handler._KEEP_ORDER_COL_NAME,
607
+ ascending=True,
608
+ )
609
+ df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
610
+
611
+ if not output_with_input_features:
612
+ df_res = df_res.drop(*original_cols)
613
+
614
+ # get final result
615
+ if not isinstance(X, dataframe.DataFrame):
616
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
617
+ else:
618
+ return df_res
619
+
620
+ def _create_temp_stage(
621
+ self,
622
+ database_name: Optional[sql_identifier.SqlIdentifier],
623
+ schema_name: Optional[sql_identifier.SqlIdentifier],
624
+ statement_params: Optional[dict[str, Any]] = None,
625
+ ) -> str:
626
+ stage_name = sql_identifier.SqlIdentifier(
627
+ snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
628
+ )
629
+ self._stage_client.create_tmp_stage(
630
+ database_name=database_name,
631
+ schema_name=schema_name,
632
+ stage_name=stage_name,
633
+ statement_params=statement_params,
634
+ )
635
+ return self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name) # stage path