snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,7 @@
1
+ import logging
1
2
  import pathlib
2
3
  import textwrap
3
- from typing import (
4
- Any,
5
- Callable,
6
- Dict,
7
- List,
8
- Literal,
9
- Optional,
10
- TypeVar,
11
- Union,
12
- overload,
13
- )
4
+ from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
14
5
  from uuid import uuid4
15
6
 
16
7
  import yaml
@@ -23,13 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
23
14
  from snowflake.snowpark.context import get_active_session
24
15
  from snowflake.snowpark.exceptions import SnowparkSQLException
25
16
 
17
+ logger = logging.getLogger(__name__)
18
+
26
19
  _PROJECT = "MLJob"
27
20
  JOB_ID_PREFIX = "MLJOB_"
28
21
 
29
22
  T = TypeVar("T")
30
23
 
31
24
 
32
- @snowpark._internal.utils.private_preview(version="1.7.4")
33
25
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
34
26
  def list_jobs(
35
27
  limit: int = 10,
@@ -60,7 +52,7 @@ def list_jobs(
60
52
  query += f" LIMIT {limit}"
61
53
  df = session.sql(query)
62
54
  df = df.select(
63
- df['"name"'].alias('"id"'),
55
+ df['"name"'],
64
56
  df['"owner"'],
65
57
  df['"status"'],
66
58
  df['"created_on"'],
@@ -69,21 +61,20 @@ def list_jobs(
69
61
  return df
70
62
 
71
63
 
72
- @snowpark._internal.utils.private_preview(version="1.7.4")
73
64
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
74
65
  def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
75
66
  """Retrieve a job service from the backend."""
76
67
  session = session or get_active_session()
77
-
78
68
  try:
79
- # Validate job_id
80
- job_id = identifier.resolve_identifier(job_id)
69
+ database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
70
+ database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
71
+ schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
81
72
  except ValueError as e:
82
73
  raise ValueError(f"Invalid job ID: {job_id}") from e
83
74
 
75
+ job_id = f"{database}.{schema}.{job_name}"
84
76
  try:
85
77
  # Validate that job exists by doing a status check
86
- # FIXME: Retrieve return path
87
78
  job = jb.MLJob[Any](job_id, session=session)
88
79
  _ = job.status
89
80
  return job
@@ -93,7 +84,6 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
93
84
  raise
94
85
 
95
86
 
96
- @snowpark._internal.utils.private_preview(version="1.7.4")
97
87
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
98
88
  def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
99
89
  """Delete a job service from the backend. Status and logs will be lost."""
@@ -106,21 +96,22 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
106
96
  session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
107
97
 
108
98
 
109
- @snowpark._internal.utils.private_preview(version="1.7.4")
110
99
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
111
100
  def submit_file(
112
101
  file_path: str,
113
102
  compute_pool: str,
114
103
  *,
115
104
  stage_name: str,
116
- args: Optional[List[str]] = None,
117
- env_vars: Optional[Dict[str, str]] = None,
118
- pip_requirements: Optional[List[str]] = None,
119
- external_access_integrations: Optional[List[str]] = None,
105
+ args: Optional[list[str]] = None,
106
+ env_vars: Optional[dict[str, str]] = None,
107
+ pip_requirements: Optional[list[str]] = None,
108
+ external_access_integrations: Optional[list[str]] = None,
120
109
  query_warehouse: Optional[str] = None,
121
- spec_overrides: Optional[Dict[str, Any]] = None,
110
+ spec_overrides: Optional[dict[str, Any]] = None,
122
111
  num_instances: Optional[int] = None,
123
112
  enable_metrics: bool = False,
113
+ database: Optional[str] = None,
114
+ schema: Optional[str] = None,
124
115
  session: Optional[snowpark.Session] = None,
125
116
  ) -> jb.MLJob[None]:
126
117
  """
@@ -138,6 +129,8 @@ def submit_file(
138
129
  spec_overrides: Custom service specification overrides to apply.
139
130
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
140
131
  enable_metrics: Whether to enable metrics publishing for the job.
132
+ database: The database to use.
133
+ schema: The schema to use.
141
134
  session: The Snowpark session to use. If none specified, uses active session.
142
135
 
143
136
  Returns:
@@ -155,11 +148,12 @@ def submit_file(
155
148
  spec_overrides=spec_overrides,
156
149
  num_instances=num_instances,
157
150
  enable_metrics=enable_metrics,
151
+ database=database,
152
+ schema=schema,
158
153
  session=session,
159
154
  )
160
155
 
161
156
 
162
- @snowpark._internal.utils.private_preview(version="1.7.4")
163
157
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
164
158
  def submit_directory(
165
159
  dir_path: str,
@@ -167,14 +161,16 @@ def submit_directory(
167
161
  *,
168
162
  entrypoint: str,
169
163
  stage_name: str,
170
- args: Optional[List[str]] = None,
171
- env_vars: Optional[Dict[str, str]] = None,
172
- pip_requirements: Optional[List[str]] = None,
173
- external_access_integrations: Optional[List[str]] = None,
164
+ args: Optional[list[str]] = None,
165
+ env_vars: Optional[dict[str, str]] = None,
166
+ pip_requirements: Optional[list[str]] = None,
167
+ external_access_integrations: Optional[list[str]] = None,
174
168
  query_warehouse: Optional[str] = None,
175
- spec_overrides: Optional[Dict[str, Any]] = None,
169
+ spec_overrides: Optional[dict[str, Any]] = None,
176
170
  num_instances: Optional[int] = None,
177
171
  enable_metrics: bool = False,
172
+ database: Optional[str] = None,
173
+ schema: Optional[str] = None,
178
174
  session: Optional[snowpark.Session] = None,
179
175
  ) -> jb.MLJob[None]:
180
176
  """
@@ -193,6 +189,8 @@ def submit_directory(
193
189
  spec_overrides: Custom service specification overrides to apply.
194
190
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
195
191
  enable_metrics: Whether to enable metrics publishing for the job.
192
+ database: The database to use.
193
+ schema: The schema to use.
196
194
  session: The Snowpark session to use. If none specified, uses active session.
197
195
 
198
196
  Returns:
@@ -211,6 +209,8 @@ def submit_directory(
211
209
  spec_overrides=spec_overrides,
212
210
  num_instances=num_instances,
213
211
  enable_metrics=enable_metrics,
212
+ database=database,
213
+ schema=schema,
214
214
  session=session,
215
215
  )
216
216
 
@@ -222,14 +222,16 @@ def _submit_job(
222
222
  *,
223
223
  stage_name: str,
224
224
  entrypoint: Optional[str] = None,
225
- args: Optional[List[str]] = None,
226
- env_vars: Optional[Dict[str, str]] = None,
227
- pip_requirements: Optional[List[str]] = None,
228
- external_access_integrations: Optional[List[str]] = None,
225
+ args: Optional[list[str]] = None,
226
+ env_vars: Optional[dict[str, str]] = None,
227
+ pip_requirements: Optional[list[str]] = None,
228
+ external_access_integrations: Optional[list[str]] = None,
229
229
  query_warehouse: Optional[str] = None,
230
- spec_overrides: Optional[Dict[str, Any]] = None,
230
+ spec_overrides: Optional[dict[str, Any]] = None,
231
231
  num_instances: Optional[int] = None,
232
232
  enable_metrics: bool = False,
233
+ database: Optional[str] = None,
234
+ schema: Optional[str] = None,
233
235
  session: Optional[snowpark.Session] = None,
234
236
  ) -> jb.MLJob[None]:
235
237
  ...
@@ -242,14 +244,16 @@ def _submit_job(
242
244
  *,
243
245
  stage_name: str,
244
246
  entrypoint: Optional[str] = None,
245
- args: Optional[List[str]] = None,
246
- env_vars: Optional[Dict[str, str]] = None,
247
- pip_requirements: Optional[List[str]] = None,
248
- external_access_integrations: Optional[List[str]] = None,
247
+ args: Optional[list[str]] = None,
248
+ env_vars: Optional[dict[str, str]] = None,
249
+ pip_requirements: Optional[list[str]] = None,
250
+ external_access_integrations: Optional[list[str]] = None,
249
251
  query_warehouse: Optional[str] = None,
250
- spec_overrides: Optional[Dict[str, Any]] = None,
252
+ spec_overrides: Optional[dict[str, Any]] = None,
251
253
  num_instances: Optional[int] = None,
252
254
  enable_metrics: bool = False,
255
+ database: Optional[str] = None,
256
+ schema: Optional[str] = None,
253
257
  session: Optional[snowpark.Session] = None,
254
258
  ) -> jb.MLJob[T]:
255
259
  ...
@@ -263,6 +267,8 @@ def _submit_job(
263
267
  # TODO: Log lengths of args, env_vars, and spec_overrides values
264
268
  "pip_requirements",
265
269
  "external_access_integrations",
270
+ "num_instances",
271
+ "enable_metrics",
266
272
  ],
267
273
  )
268
274
  def _submit_job(
@@ -271,14 +277,16 @@ def _submit_job(
271
277
  *,
272
278
  stage_name: str,
273
279
  entrypoint: Optional[str] = None,
274
- args: Optional[List[str]] = None,
275
- env_vars: Optional[Dict[str, str]] = None,
276
- pip_requirements: Optional[List[str]] = None,
277
- external_access_integrations: Optional[List[str]] = None,
280
+ args: Optional[list[str]] = None,
281
+ env_vars: Optional[dict[str, str]] = None,
282
+ pip_requirements: Optional[list[str]] = None,
283
+ external_access_integrations: Optional[list[str]] = None,
278
284
  query_warehouse: Optional[str] = None,
279
- spec_overrides: Optional[Dict[str, Any]] = None,
285
+ spec_overrides: Optional[dict[str, Any]] = None,
280
286
  num_instances: Optional[int] = None,
281
287
  enable_metrics: bool = False,
288
+ database: Optional[str] = None,
289
+ schema: Optional[str] = None,
282
290
  session: Optional[snowpark.Session] = None,
283
291
  ) -> jb.MLJob[T]:
284
292
  """
@@ -297,6 +305,8 @@ def _submit_job(
297
305
  spec_overrides: Custom service specification overrides to apply.
298
306
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
299
307
  enable_metrics: Whether to enable metrics publishing for the job.
308
+ database: The database to use.
309
+ schema: The schema to use.
300
310
  session: The Snowpark session to use. If none specified, uses active session.
301
311
 
302
312
  Returns:
@@ -304,11 +314,28 @@ def _submit_job(
304
314
 
305
315
  Raises:
306
316
  RuntimeError: If required Snowflake features are not enabled.
317
+ ValueError: If database or schema value(s) are invalid
307
318
  """
319
+ # Display warning about PrPr parameters
320
+ if num_instances is not None:
321
+ logger.warning(
322
+ "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
323
+ )
324
+ if database and not schema:
325
+ raise ValueError("Schema must be specified if database is specified.")
326
+
308
327
  session = session or get_active_session()
309
- job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
310
- stage_name = "@" + stage_name.lstrip("@").rstrip("/")
311
- stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
328
+
329
+ # Validate database and schema identifiers on client side since
330
+ # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
331
+ database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
332
+ schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
333
+
334
+ job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
335
+ job_id = f"{database}.{schema}.{job_name}"
336
+ stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
337
+ stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
338
+ stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
312
339
 
313
340
  # Upload payload
314
341
  uploaded_payload = payload_utils.JobPayload(
@@ -335,31 +362,34 @@ def _submit_job(
335
362
 
336
363
  # Generate SQL command for job submission
337
364
  query_template = textwrap.dedent(
338
- f"""\
365
+ """\
339
366
  EXECUTE JOB SERVICE
340
- IN COMPUTE POOL {compute_pool}
367
+ IN COMPUTE POOL IDENTIFIER(?)
341
368
  FROM SPECIFICATION $$
342
- {{}}
369
+ {}
343
370
  $$
344
- NAME = {job_id}
371
+ NAME = IDENTIFIER(?)
345
372
  ASYNC = TRUE
346
373
  """
347
374
  )
375
+ params: list[Any] = [compute_pool, job_id]
348
376
  query = query_template.format(yaml.dump(spec)).splitlines()
349
377
  if external_access_integrations:
350
378
  external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
351
379
  query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
352
380
  query_warehouse = query_warehouse or session.get_current_warehouse()
353
381
  if query_warehouse:
354
- query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
382
+ query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
383
+ params.append(query_warehouse)
355
384
  if num_instances:
356
- query.append(f"REPLICAS = {num_instances}")
385
+ query.append("REPLICAS = ?")
386
+ params.append(num_instances)
357
387
 
358
388
  # Submit job
359
389
  query_text = "\n".join(line for line in query if line)
360
390
 
361
391
  try:
362
- _ = session.sql(query_text).collect()
392
+ _ = session.sql(query_text, params=params).collect()
363
393
  except SnowparkSQLException as e:
364
394
  if "invalid property 'ASYNC'" in e.message:
365
395
  raise RuntimeError(
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from datetime import datetime
3
- from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
3
+ from typing import TYPE_CHECKING, Literal, Optional, Union
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml._internal import telemetry
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
  from snowflake.ml.model._client.model import model_version_impl
13
13
 
14
14
  _PROJECT = "LINEAGE"
15
- DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
15
+ DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
16
16
 
17
17
 
18
18
  class LineageNode:
@@ -87,8 +87,8 @@ class LineageNode:
87
87
  def lineage(
88
88
  self,
89
89
  direction: Literal["upstream", "downstream"] = "downstream",
90
- domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
- ) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
90
+ domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
+ ) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
92
92
  """
93
93
  Retrieves the lineage nodes connected to this node.
94
94
 
@@ -109,7 +109,7 @@ class LineageNode:
109
109
  if domain_filter is not None:
110
110
  domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
111
111
 
112
- lineage_nodes: List["LineageNode"] = []
112
+ lineage_nodes: list["LineageNode"] = []
113
113
  for row in df.collect():
114
114
  lineage_object = (
115
115
  json.loads(row["TARGET_OBJECT"])
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Optional, Union
1
+ from typing import Optional, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -224,7 +224,7 @@ class Model:
224
224
  project=_TELEMETRY_PROJECT,
225
225
  subproject=_TELEMETRY_SUBPROJECT,
226
226
  )
227
- def versions(self) -> List[model_version_impl.ModelVersion]:
227
+ def versions(self) -> list[model_version_impl.ModelVersion]:
228
228
  """Get all versions in the model.
229
229
 
230
230
  Returns:
@@ -298,7 +298,7 @@ class Model:
298
298
  project=_TELEMETRY_PROJECT,
299
299
  subproject=_TELEMETRY_SUBPROJECT,
300
300
  )
301
- def show_tags(self) -> Dict[str, str]:
301
+ def show_tags(self) -> dict[str, str]:
302
302
  """Get a dictionary showing the tag and its value attached to the model.
303
303
 
304
304
  Returns:
@@ -2,10 +2,11 @@ import enum
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Callable, Dict, List, Optional, Union, overload
5
+ from typing import Any, Callable, Optional, Union, overload
6
6
 
7
7
  import pandas as pd
8
8
 
9
+ from snowflake import snowpark
9
10
  from snowflake.ml._internal import telemetry
10
11
  from snowflake.ml._internal.utils import sql_identifier
11
12
  from snowflake.ml.lineage import lineage_node
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
32
33
  _service_ops: service_ops.ServiceOperator
33
34
  _model_name: sql_identifier.SqlIdentifier
34
35
  _version_name: sql_identifier.SqlIdentifier
35
- _functions: List[model_manifest_schema.ModelFunctionInfo]
36
+ _functions: list[model_manifest_schema.ModelFunctionInfo]
36
37
 
37
38
  def __init__(self) -> None:
38
39
  raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
152
153
  project=_TELEMETRY_PROJECT,
153
154
  subproject=_TELEMETRY_SUBPROJECT,
154
155
  )
155
- def show_metrics(self) -> Dict[str, Any]:
156
+ def show_metrics(self) -> dict[str, Any]:
156
157
  """Show all metrics logged with the model version.
157
158
 
158
159
  Returns:
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
293
294
  statement_params=statement_params,
294
295
  )
295
296
 
296
- def _get_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
297
+ def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
297
298
  statement_params = telemetry.get_statement_params(
298
299
  project=_TELEMETRY_PROJECT,
299
300
  subproject=_TELEMETRY_SUBPROJECT,
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
327
328
  project=_TELEMETRY_PROJECT,
328
329
  subproject=_TELEMETRY_SUBPROJECT,
329
330
  )
330
- def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
331
+ def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
331
332
  """Show all functions information in a model version that is callable.
332
333
 
333
334
  Returns:
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
405
406
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
406
407
  type validation to make sure your input data won't overflow when providing to the model.
407
408
 
408
- Raises:
409
- ValueError: When no method with the corresponding name is available.
410
- ValueError: When there are more than 1 target methods available in the model but no function name specified.
411
- ValueError: When the partition column is not a valid Snowflake identifier.
412
-
413
409
  Returns:
414
410
  The prediction data. It would be the same type dataframe as your input.
415
411
  """
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
422
418
  # Partition column must be a valid identifier
423
419
  partition_column = sql_identifier.SqlIdentifier(partition_column)
424
420
 
425
- functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
426
-
427
- if function_name:
428
- req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
429
- find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
430
- lambda method: method["name"] == req_method_name
431
- )
432
- target_function_info = next(
433
- filter(find_method, functions),
434
- None,
435
- )
436
- if target_function_info is None:
437
- raise ValueError(
438
- f"There is no method with name {function_name} available in the model"
439
- f" {self.fully_qualified_model_name} version {self.version_name}"
440
- )
441
- elif len(functions) != 1:
442
- raise ValueError(
443
- f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
444
- f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
445
- )
446
- else:
447
- target_function_info = functions[0]
421
+ target_function_info = self._get_function_info(function_name=function_name)
448
422
 
449
423
  if service_name:
450
424
  database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
475
449
  is_partitioned=target_function_info["is_partitioned"],
476
450
  )
477
451
 
452
+ def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
453
+ functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
454
+
455
+ if function_name:
456
+ req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
457
+ find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
458
+ lambda method: method["name"] == req_method_name
459
+ )
460
+ target_function_info = next(
461
+ filter(find_method, functions),
462
+ None,
463
+ )
464
+ if target_function_info is None:
465
+ raise ValueError(
466
+ f"There is no method with name {function_name} available in the model"
467
+ f" {self.fully_qualified_model_name} version {self.version_name}"
468
+ )
469
+ elif len(functions) != 1:
470
+ raise ValueError(
471
+ f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
472
+ f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
473
+ )
474
+ else:
475
+ target_function_info = functions[0]
476
+
477
+ return target_function_info
478
+
478
479
  @telemetry.send_api_usage_telemetry(
479
480
  project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
480
481
  )
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
684
685
  num_workers: Optional[int] = None,
685
686
  max_batch_rows: Optional[int] = None,
686
687
  force_rebuild: bool = False,
687
- build_external_access_integrations: Optional[List[str]] = None,
688
+ build_external_access_integrations: Optional[list[str]] = None,
688
689
  block: bool = True,
689
690
  ) -> Union[str, async_job.AsyncJob]:
690
691
  """Create an inference service with the given spec.
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
751
752
  max_batch_rows: Optional[int] = None,
752
753
  force_rebuild: bool = False,
753
754
  build_external_access_integration: Optional[str] = None,
754
- build_external_access_integrations: Optional[List[str]] = None,
755
+ build_external_access_integrations: Optional[list[str]] = None,
755
756
  block: bool = True,
756
757
  ) -> Union[str, async_job.AsyncJob]:
757
758
  """Create an inference service with the given spec.
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
914
915
  statement_params=statement_params,
915
916
  )
916
917
 
918
+ @snowpark._internal.utils.private_preview(version="1.8.3")
919
+ @telemetry.send_api_usage_telemetry(
920
+ project=_TELEMETRY_PROJECT,
921
+ subproject=_TELEMETRY_SUBPROJECT,
922
+ )
923
+ def _run_job(
924
+ self,
925
+ X: Union[pd.DataFrame, "dataframe.DataFrame"],
926
+ *,
927
+ job_name: str,
928
+ compute_pool: str,
929
+ image_repo: str,
930
+ output_table_name: str,
931
+ function_name: Optional[str] = None,
932
+ cpu_requests: Optional[str] = None,
933
+ memory_requests: Optional[str] = None,
934
+ gpu_requests: Optional[Union[str, int]] = None,
935
+ num_workers: Optional[int] = None,
936
+ max_batch_rows: Optional[int] = None,
937
+ force_rebuild: bool = False,
938
+ build_external_access_integrations: Optional[list[str]] = None,
939
+ ) -> Union[pd.DataFrame, dataframe.DataFrame]:
940
+ statement_params = telemetry.get_statement_params(
941
+ project=_TELEMETRY_PROJECT,
942
+ subproject=_TELEMETRY_SUBPROJECT,
943
+ )
944
+ target_function_info = self._get_function_info(function_name=function_name)
945
+ job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
946
+ image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
947
+ output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
948
+ output_table_name
949
+ )
950
+ warehouse = self._service_ops._session.get_current_warehouse()
951
+ assert warehouse, "No active warehouse selected in the current session."
952
+ return self._service_ops.invoke_job_method(
953
+ target_method=target_function_info["target_method"],
954
+ signature=target_function_info["signature"],
955
+ X=X,
956
+ database_name=None,
957
+ schema_name=None,
958
+ model_name=self._model_name,
959
+ version_name=self._version_name,
960
+ job_database_name=job_db_id,
961
+ job_schema_name=job_schema_id,
962
+ job_name=job_id,
963
+ compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
964
+ warehouse_name=sql_identifier.SqlIdentifier(warehouse),
965
+ image_repo_database_name=image_repo_db_id,
966
+ image_repo_schema_name=image_repo_schema_id,
967
+ image_repo_name=image_repo_id,
968
+ output_table_database_name=output_table_db_id,
969
+ output_table_schema_name=output_table_schema_id,
970
+ output_table_name=output_table_id,
971
+ cpu_requests=cpu_requests,
972
+ memory_requests=memory_requests,
973
+ gpu_requests=gpu_requests,
974
+ num_workers=num_workers,
975
+ max_batch_rows=max_batch_rows,
976
+ force_rebuild=force_rebuild,
977
+ build_external_access_integrations=(
978
+ None
979
+ if build_external_access_integrations is None
980
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
981
+ ),
982
+ statement_params=statement_params,
983
+ )
984
+
917
985
 
918
986
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, Optional, TypedDict
2
+ from typing import Any, Optional, TypedDict
3
3
 
4
4
  from typing_extensions import NotRequired
5
5
 
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
14
14
 
15
15
 
16
16
  class ModelVersionMetadataSchema(TypedDict):
17
- metrics: NotRequired[Dict[str, Any]]
17
+ metrics: NotRequired[dict[str, Any]]
18
18
 
19
19
 
20
20
  class MetadataOperator:
@@ -44,7 +44,7 @@ class MetadataOperator:
44
44
  )
45
45
 
46
46
  @staticmethod
47
- def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
47
+ def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
48
48
  loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
49
49
  if loaded_metadata_schema_version is None:
50
50
  return ModelVersionMetadataSchema(metrics={})
@@ -65,8 +65,8 @@ class MetadataOperator:
65
65
  schema_name: Optional[sql_identifier.SqlIdentifier],
66
66
  model_name: sql_identifier.SqlIdentifier,
67
67
  version_name: sql_identifier.SqlIdentifier,
68
- statement_params: Optional[Dict[str, Any]] = None,
69
- ) -> Dict[str, Any]:
68
+ statement_params: Optional[dict[str, Any]] = None,
69
+ ) -> dict[str, Any]:
70
70
  version_info_list = self._model_client.show_versions(
71
71
  database_name=database_name,
72
72
  schema_name=schema_name,
@@ -89,7 +89,7 @@ class MetadataOperator:
89
89
  schema_name: Optional[sql_identifier.SqlIdentifier],
90
90
  model_name: sql_identifier.SqlIdentifier,
91
91
  version_name: sql_identifier.SqlIdentifier,
92
- statement_params: Optional[Dict[str, Any]] = None,
92
+ statement_params: Optional[dict[str, Any]] = None,
93
93
  ) -> ModelVersionMetadataSchema:
94
94
  metadata_dict = self._get_current_metadata_dict(
95
95
  database_name=database_name,
@@ -108,7 +108,7 @@ class MetadataOperator:
108
108
  schema_name: Optional[sql_identifier.SqlIdentifier],
109
109
  model_name: sql_identifier.SqlIdentifier,
110
110
  version_name: sql_identifier.SqlIdentifier,
111
- statement_params: Optional[Dict[str, Any]] = None,
111
+ statement_params: Optional[dict[str, Any]] = None,
112
112
  ) -> None:
113
113
  metadata_dict = self._get_current_metadata_dict(
114
114
  database_name=database_name,