snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/cortex/_classify_text.py +3 -3
  3. snowflake/cortex/_complete.py +23 -24
  4. snowflake/cortex/_embed_text_1024.py +4 -4
  5. snowflake/cortex/_embed_text_768.py +4 -4
  6. snowflake/cortex/_finetune.py +8 -8
  7. snowflake/cortex/_util.py +8 -12
  8. snowflake/ml/_internal/env.py +4 -3
  9. snowflake/ml/_internal/env_utils.py +63 -34
  10. snowflake/ml/_internal/file_utils.py +10 -21
  11. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  12. snowflake/ml/_internal/init_utils.py +2 -3
  13. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  14. snowflake/ml/_internal/platform_capabilities.py +18 -16
  15. snowflake/ml/_internal/telemetry.py +39 -52
  16. snowflake/ml/_internal/type_utils.py +3 -3
  17. snowflake/ml/_internal/utils/db_utils.py +2 -2
  18. snowflake/ml/_internal/utils/identifier.py +10 -10
  19. snowflake/ml/_internal/utils/import_utils.py +2 -2
  20. snowflake/ml/_internal/utils/parallelize.py +7 -7
  21. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  22. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  23. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  24. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  25. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  26. snowflake/ml/_internal/utils/table_manager.py +9 -9
  27. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  28. snowflake/ml/data/data_connector.py +15 -36
  29. snowflake/ml/data/data_ingestor.py +4 -15
  30. snowflake/ml/data/data_source.py +2 -2
  31. snowflake/ml/data/ingestor_utils.py +3 -3
  32. snowflake/ml/data/torch_utils.py +5 -5
  33. snowflake/ml/dataset/dataset.py +11 -11
  34. snowflake/ml/dataset/dataset_metadata.py +8 -8
  35. snowflake/ml/dataset/dataset_reader.py +7 -7
  36. snowflake/ml/feature_store/__init__.py +1 -1
  37. snowflake/ml/feature_store/access_manager.py +7 -7
  38. snowflake/ml/feature_store/entity.py +6 -6
  39. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  41. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  45. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  48. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  51. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  52. snowflake/ml/feature_store/feature_store.py +52 -64
  53. snowflake/ml/feature_store/feature_view.py +24 -24
  54. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  55. snowflake/ml/fileset/fileset.py +5 -5
  56. snowflake/ml/fileset/sfcfs.py +13 -13
  57. snowflake/ml/fileset/stage_fs.py +15 -15
  58. snowflake/ml/jobs/_utils/constants.py +1 -1
  59. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  60. snowflake/ml/jobs/_utils/payload_utils.py +45 -46
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -5
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  64. snowflake/ml/jobs/_utils/spec_utils.py +18 -29
  65. snowflake/ml/jobs/_utils/types.py +2 -2
  66. snowflake/ml/jobs/decorators.py +10 -5
  67. snowflake/ml/jobs/job.py +87 -30
  68. snowflake/ml/jobs/manager.py +86 -56
  69. snowflake/ml/lineage/lineage_node.py +5 -5
  70. snowflake/ml/model/_client/model/model_impl.py +3 -3
  71. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  72. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  73. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  74. snowflake/ml/model/_client/ops/service_ops.py +217 -32
  75. snowflake/ml/model/_client/service/model_deployment_spec.py +359 -65
  76. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +69 -24
  77. snowflake/ml/model/_client/sql/model.py +8 -8
  78. snowflake/ml/model/_client/sql/model_version.py +26 -26
  79. snowflake/ml/model/_client/sql/service.py +17 -26
  80. snowflake/ml/model/_client/sql/stage.py +2 -2
  81. snowflake/ml/model/_client/sql/tag.py +6 -6
  82. snowflake/ml/model/_model_composer/model_composer.py +58 -32
  83. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  85. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  86. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  87. snowflake/ml/model/_packager/model_handler.py +4 -4
  88. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  89. snowflake/ml/model/_packager/model_handlers/_utils.py +47 -5
  90. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  91. snowflake/ml/model/_packager/model_handlers/custom.py +9 -5
  92. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  93. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  95. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  96. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -6
  97. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  98. snowflake/ml/model/_packager/model_handlers/sklearn.py +104 -46
  99. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  100. snowflake/ml/model/_packager/model_handlers/tensorflow.py +11 -8
  101. snowflake/ml/model/_packager/model_handlers/torchscript.py +6 -6
  102. snowflake/ml/model/_packager/model_handlers/xgboost.py +21 -22
  103. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  104. snowflake/ml/model/_packager/model_meta/model_meta.py +39 -38
  105. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -11
  106. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  107. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  109. snowflake/ml/model/_packager/model_packager.py +11 -9
  110. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  111. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  112. snowflake/ml/model/_signatures/core.py +16 -24
  113. snowflake/ml/model/_signatures/dmatrix_handler.py +17 -4
  114. snowflake/ml/model/_signatures/utils.py +6 -6
  115. snowflake/ml/model/custom_model.py +24 -11
  116. snowflake/ml/model/model_signature.py +12 -23
  117. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  118. snowflake/ml/model/type_hints.py +3 -3
  119. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  120. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  122. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  123. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  124. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  125. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  126. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  128. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  129. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  130. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  131. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  132. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  133. snowflake/ml/modeling/cluster/birch.py +9 -1
  134. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  135. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  136. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  137. snowflake/ml/modeling/cluster/k_means.py +9 -1
  138. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  139. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  140. snowflake/ml/modeling/cluster/optics.py +9 -1
  141. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  142. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  143. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  144. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  145. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  146. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  147. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  148. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  149. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  150. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  151. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  152. snowflake/ml/modeling/covariance/oas.py +9 -1
  153. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  154. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  155. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  156. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  157. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  158. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  159. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  160. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  161. snowflake/ml/modeling/decomposition/pca.py +9 -1
  162. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  163. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  164. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  165. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  166. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  167. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  168. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  169. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  170. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  171. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  172. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  173. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  174. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  175. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  176. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  177. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  178. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  179. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  180. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  181. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  182. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  183. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  184. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  185. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  186. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  187. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  188. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  189. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  190. snowflake/ml/modeling/framework/_utils.py +10 -10
  191. snowflake/ml/modeling/framework/base.py +32 -32
  192. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  193. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  194. snowflake/ml/modeling/impute/__init__.py +1 -1
  195. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  196. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  197. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  198. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  199. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  200. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  201. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  202. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  203. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  204. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  205. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  206. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  207. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  208. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  209. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  210. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  211. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  212. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  213. snowflake/ml/modeling/linear_model/lars.py +9 -1
  214. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  215. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  216. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  217. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  218. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  219. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  220. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  221. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  222. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  223. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  224. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  225. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  226. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  227. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  228. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  229. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  230. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  231. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  232. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  233. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  234. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  235. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  236. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  237. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  238. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  239. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  240. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  241. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  242. snowflake/ml/modeling/manifold/isomap.py +9 -1
  243. snowflake/ml/modeling/manifold/mds.py +9 -1
  244. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  245. snowflake/ml/modeling/manifold/tsne.py +9 -1
  246. snowflake/ml/modeling/metrics/__init__.py +1 -1
  247. snowflake/ml/modeling/metrics/classification.py +39 -39
  248. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  249. snowflake/ml/modeling/metrics/ranking.py +7 -7
  250. snowflake/ml/modeling/metrics/regression.py +13 -13
  251. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  252. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  253. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  254. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  255. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  256. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  257. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  258. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  259. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  260. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  261. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  262. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  263. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  264. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  265. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  266. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  267. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  268. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  269. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  270. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  271. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  272. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  273. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  274. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  275. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  276. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  277. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  278. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  279. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  280. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  281. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  282. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  283. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  284. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  285. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  286. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  287. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  288. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  289. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  290. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  291. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  292. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  293. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  294. snowflake/ml/modeling/svm/svc.py +9 -1
  295. snowflake/ml/modeling/svm/svr.py +9 -1
  296. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  297. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  298. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  299. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  300. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  301. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  302. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  303. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  304. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  305. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  306. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  307. snowflake/ml/monitoring/explain_visualize.py +286 -0
  308. snowflake/ml/registry/_manager/model_manager.py +55 -32
  309. snowflake/ml/registry/registry.py +39 -31
  310. snowflake/ml/utils/authentication.py +2 -2
  311. snowflake/ml/utils/connection_params.py +5 -5
  312. snowflake/ml/utils/sparse.py +5 -4
  313. snowflake/ml/utils/sql_client.py +1 -2
  314. snowflake/ml/version.py +2 -1
  315. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +55 -14
  316. snowflake_ml_python-1.8.4.dist-info/RECORD +419 -0
  317. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  318. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  319. snowflake/ml/modeling/_internal/constants.py +0 -2
  320. snowflake_ml_python-1.8.2.dist-info/RECORD +0 -420
  321. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  322. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,10 @@
1
1
  from snowflake.cortex._classify_text import ClassifyText, classify_text
2
- from snowflake.cortex._complete import Complete, CompleteOptions, complete
2
+ from snowflake.cortex._complete import (
3
+ Complete,
4
+ CompleteOptions,
5
+ ConversationMessage,
6
+ complete,
7
+ )
3
8
  from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
4
9
  from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
5
10
  from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
14
19
  "Complete",
15
20
  "complete",
16
21
  "CompleteOptions",
22
+ "ConversationMessage",
17
23
  "EmbedText768",
18
24
  "embed_text_768",
19
25
  "EmbedText1024",
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -12,7 +12,7 @@ from snowflake.ml._internal import telemetry
12
12
  )
13
13
  def classify_text(
14
14
  str_input: Union[str, snowpark.Column],
15
- categories: Union[List[str], snowpark.Column],
15
+ categories: Union[list[str], snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
17
  ) -> Union[str, snowpark.Column]:
18
18
  """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
@@ -32,7 +32,7 @@ def classify_text(
32
32
  def _classify_text_impl(
33
33
  function: str,
34
34
  str_input: Union[str, snowpark.Column],
35
- categories: Union[List[str], snowpark.Column],
35
+ categories: Union[list[str], snowpark.Column],
36
36
  session: Optional[snowpark.Session] = None,
37
37
  ) -> Union[str, snowpark.Column]:
38
38
  return cast(Union[str, snowpark.Column], call_sql_function(function, session, str_input, categories))
@@ -3,7 +3,7 @@ import logging
3
3
  import time
4
4
  import typing
5
5
  from io import BytesIO
6
- from typing import Any, Callable, Dict, Iterator, List, Optional, TypedDict, Union, cast
6
+ from typing import Any, Callable, Iterator, Optional, TypedDict, Union, cast
7
7
  from urllib.parse import urlunparse
8
8
 
9
9
  import requests
@@ -30,7 +30,7 @@ class ResponseFormat(TypedDict):
30
30
 
31
31
  type: str
32
32
  """The response format type (e.g. "json")"""
33
- schema: Dict[str, Any]
33
+ schema: dict[str, Any]
34
34
  """The schema defining the structure of the response. For json it should be a valid json schema object"""
35
35
 
36
36
 
@@ -71,12 +71,11 @@ class CompleteOptions(TypedDict):
71
71
  class ResponseParseException(Exception):
72
72
  """This exception is raised when the server response cannot be parsed."""
73
73
 
74
- pass
75
-
76
74
 
77
75
  class MidStreamException(Exception):
78
76
  """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
79
- using the “error” event type. This exception is raised when there is such a mid-stream error."""
77
+ using the “error” event type. This exception is raised when there is such a mid-stream error.
78
+ """
80
79
 
81
80
  def __init__(
82
81
  self,
@@ -135,7 +134,7 @@ def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Resp
135
134
  return inner
136
135
 
137
136
 
138
- def _make_common_request_headers() -> Dict[str, str]:
137
+ def _make_common_request_headers() -> dict[str, str]:
139
138
  headers = {
140
139
  "Content-Type": "application/json",
141
140
  "Accept": "application/json, text/event-stream",
@@ -143,7 +142,7 @@ def _make_common_request_headers() -> Dict[str, str]:
143
142
  return headers
144
143
 
145
144
 
146
- def _get_request_id(resp: Dict[str, Any]) -> Optional[Any]:
145
+ def _get_request_id(resp: dict[str, Any]) -> Optional[Any]:
147
146
  request_id = None
148
147
  if "headers" in resp:
149
148
  for key, value in resp["headers"].items():
@@ -183,14 +182,14 @@ def _validate_response_format_object(options: CompleteOptions) -> None:
183
182
 
184
183
  def _make_request_body(
185
184
  model: str,
186
- prompt: Union[str, List[ConversationMessage]],
185
+ prompt: Union[str, list[ConversationMessage]],
187
186
  options: Optional[CompleteOptions] = None,
188
- ) -> Dict[str, Any]:
187
+ ) -> dict[str, Any]:
189
188
  data = {
190
189
  "model": model,
191
190
  "stream": True,
192
191
  }
193
- if isinstance(prompt, List):
192
+ if isinstance(prompt, list):
194
193
  data["messages"] = prompt
195
194
  else:
196
195
  data["messages"] = [{"content": prompt}]
@@ -217,7 +216,7 @@ def _make_request_body(
217
216
 
218
217
  # XP endpoint returns a dict response which needs to be converted to a format which can
219
218
  # be consumed by the SSEClient. This method does that.
220
- def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
219
+ def _xp_dict_to_response(raw_resp: dict[str, Any]) -> requests.Response:
221
220
 
222
221
  response = requests.Response()
223
222
  response.status_code = int(raw_resp["status"])
@@ -251,9 +250,9 @@ def _xp_dict_to_response(raw_resp: Dict[str, Any]) -> requests.Response:
251
250
 
252
251
  @retry
253
252
  def _call_complete_xp(
254
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
253
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
255
254
  model: str,
256
- prompt: Union[str, List[ConversationMessage]],
255
+ prompt: Union[str, list[ConversationMessage]],
257
256
  options: Optional[CompleteOptions] = None,
258
257
  deadline: Optional[float] = None,
259
258
  ) -> requests.Response:
@@ -267,7 +266,7 @@ def _call_complete_xp(
267
266
  @retry
268
267
  def _call_complete_rest(
269
268
  model: str,
270
- prompt: Union[str, List[ConversationMessage]],
269
+ prompt: Union[str, list[ConversationMessage]],
271
270
  options: Optional[CompleteOptions] = None,
272
271
  session: Optional[snowpark.Session] = None,
273
272
  ) -> requests.Response:
@@ -340,9 +339,9 @@ def _complete_call_sql_function_snowpark(
340
339
 
341
340
 
342
341
  def _complete_non_streaming_immediate(
343
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
342
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
344
343
  model: str,
345
- prompt: Union[str, List[ConversationMessage]],
344
+ prompt: Union[str, list[ConversationMessage]],
346
345
  options: Optional[CompleteOptions],
347
346
  session: Optional[snowpark.Session] = None,
348
347
  deadline: Optional[float] = None,
@@ -359,10 +358,10 @@ def _complete_non_streaming_immediate(
359
358
 
360
359
 
361
360
  def _complete_non_streaming_impl(
362
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
361
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
363
362
  function: str,
364
363
  model: Union[str, snowpark.Column],
365
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
364
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
366
365
  options: Optional[Union[CompleteOptions, snowpark.Column]],
367
366
  session: Optional[snowpark.Session] = None,
368
367
  deadline: Optional[float] = None,
@@ -389,9 +388,9 @@ def _complete_non_streaming_impl(
389
388
 
390
389
 
391
390
  def _complete_rest(
392
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]],
391
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]],
393
392
  model: str,
394
- prompt: Union[str, List[ConversationMessage]],
393
+ prompt: Union[str, list[ConversationMessage]],
395
394
  options: Optional[CompleteOptions] = None,
396
395
  session: Optional[snowpark.Session] = None,
397
396
  deadline: Optional[float] = None,
@@ -414,8 +413,8 @@ def _complete_rest(
414
413
 
415
414
  def _complete_impl(
416
415
  model: Union[str, snowpark.Column],
417
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
418
- snow_api_xp_request_handler: Optional[Callable[..., Dict[str, Any]]] = None,
416
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
417
+ snow_api_xp_request_handler: Optional[Callable[..., dict[str, Any]]] = None,
419
418
  function: str = "snowflake.cortex.complete",
420
419
  options: Optional[CompleteOptions] = None,
421
420
  session: Optional[snowpark.Session] = None,
@@ -430,7 +429,7 @@ def _complete_impl(
430
429
  if stream:
431
430
  if not isinstance(model, str):
432
431
  raise ValueError("in REST mode, 'model' must be a string")
433
- if not isinstance(prompt, str) and not isinstance(prompt, List):
432
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
434
433
  raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
435
434
  return _complete_rest(
436
435
  snow_api_xp_request_handler=snow_api_xp_request_handler,
@@ -456,7 +455,7 @@ def _complete_impl(
456
455
  )
457
456
  def complete(
458
457
  model: Union[str, snowpark.Column],
459
- prompt: Union[str, List[ConversationMessage], snowpark.Column],
458
+ prompt: Union[str, list[ConversationMessage], snowpark.Column],
460
459
  *,
461
460
  options: Optional[CompleteOptions] = None,
462
461
  session: Optional[snowpark.Session] = None,
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -14,7 +14,7 @@ def embed_text_1024(
14
14
  model: Union[str, snowpark.Column],
15
15
  text: Union[str, snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
- ) -> Union[List[float], snowpark.Column]:
17
+ ) -> Union[list[float], snowpark.Column]:
18
18
  """Calls into the LLM inference service to embed the text.
19
19
 
20
20
  Args:
@@ -35,8 +35,8 @@ def _embed_text_1024_impl(
35
35
  model: Union[str, snowpark.Column],
36
36
  text: Union[str, snowpark.Column],
37
37
  session: Optional[snowpark.Session] = None,
38
- ) -> Union[List[float], snowpark.Column]:
39
- return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
38
+ ) -> Union[list[float], snowpark.Column]:
39
+ return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
40
40
 
41
41
 
42
42
  EmbedText1024 = deprecated(
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Union, cast
1
+ from typing import Optional, Union, cast
2
2
 
3
3
  from typing_extensions import deprecated
4
4
 
@@ -14,7 +14,7 @@ def embed_text_768(
14
14
  model: Union[str, snowpark.Column],
15
15
  text: Union[str, snowpark.Column],
16
16
  session: Optional[snowpark.Session] = None,
17
- ) -> Union[List[float], snowpark.Column]:
17
+ ) -> Union[list[float], snowpark.Column]:
18
18
  """Calls into the LLM inference service to embed the text.
19
19
 
20
20
  Args:
@@ -35,8 +35,8 @@ def _embed_text_768_impl(
35
35
  model: Union[str, snowpark.Column],
36
36
  text: Union[str, snowpark.Column],
37
37
  session: Optional[snowpark.Session] = None,
38
- ) -> Union[List[float], snowpark.Column]:
39
- return cast(Union[List[float], snowpark.Column], call_sql_function(function, session, model, text))
38
+ ) -> Union[list[float], snowpark.Column]:
39
+ return cast(Union[list[float], snowpark.Column], call_sql_function(function, session, model, text))
40
40
 
41
41
 
42
42
  EmbedText768 = deprecated(
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from dataclasses import dataclass
3
- from typing import Any, Dict, List, Optional, Union, cast
3
+ from typing import Any, Optional, Union, cast
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.cortex._util import (
@@ -53,7 +53,7 @@ class FinetuneStatus:
53
53
  created_on: Optional[int] = None
54
54
  """Creation timestamp of the Fine-tuning job in milliseconds."""
55
55
 
56
- error: Optional[Dict[str, Any]] = None
56
+ error: Optional[dict[str, Any]] = None
57
57
  """Error message propagated from the job."""
58
58
 
59
59
  finished_on: Optional[int] = None
@@ -62,7 +62,7 @@ class FinetuneStatus:
62
62
  progress: Optional[float] = None
63
63
  """Progress made as a fraction of total [0.0,1.0]."""
64
64
 
65
- training_result: Optional[List[Dict[str, Any]]] = None
65
+ training_result: Optional[list[dict[str, Any]]] = None
66
66
  """Detailed metrics report for a completed training."""
67
67
 
68
68
  trained_tokens: Optional[int] = None
@@ -135,7 +135,7 @@ class FinetuneJob:
135
135
  """
136
136
  result_string = _finetune_impl(operation="DESCRIBE", session=self._session, function_args=[self.status.id])
137
137
 
138
- result = FinetuneStatus(**cast(Dict[str, Any], _try_load_json(result_string)))
138
+ result = FinetuneStatus(**cast(dict[str, Any], _try_load_json(result_string)))
139
139
  return result
140
140
 
141
141
 
@@ -167,7 +167,7 @@ class Finetune:
167
167
  base_model: str,
168
168
  training_data: Union[str, snowpark.DataFrame],
169
169
  validation_data: Optional[Union[str, snowpark.DataFrame]] = None,
170
- options: Optional[Dict[str, Any]] = None,
170
+ options: Optional[dict[str, Any]] = None,
171
171
  ) -> FinetuneJob:
172
172
  """Create a new fine-tuning runs.
173
173
 
@@ -240,7 +240,7 @@ class Finetune:
240
240
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
241
241
  subproject=CORTEX_FINETUNE_TELEMETRY_SUBPROJECT,
242
242
  )
243
- def list_jobs(self) -> List["FinetuneJob"]:
243
+ def list_jobs(self) -> list["FinetuneJob"]:
244
244
  """Show current and past fine-tuning runs.
245
245
 
246
246
  Returns:
@@ -253,7 +253,7 @@ class Finetune:
253
253
  return [FinetuneJob(session=self._session, status=FinetuneStatus(**run_status)) for run_status in result]
254
254
 
255
255
 
256
- def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]:
256
+ def _try_load_json(json_string: str) -> Union[dict[Any, Any], list[Any]]:
257
257
  try:
258
258
  result = json.loads(str(json_string))
259
259
  except json.JSONDecodeError as e:
@@ -269,5 +269,5 @@ def _try_load_json(json_string: str) -> Union[Dict[Any, Any], List[Any]]:
269
269
  return result
270
270
 
271
271
 
272
- def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: List[Any]) -> str:
272
+ def _finetune_impl(operation: str, session: Optional[snowpark.Session], function_args: list[Any]) -> str:
273
273
  return call_sql_function_literals(_CORTEX_FINETUNE_SYSTEM_FUNCTION_NAME, session, operation, *function_args)
snowflake/cortex/_util.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional, Union, cast
1
+ from typing import Any, Optional, Union, cast
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.ml._internal.exceptions import error_codes, exceptions
@@ -11,22 +11,18 @@ CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
11
11
  class SnowflakeAuthenticationException(Exception):
12
12
  """This exception is raised when there is an issue with Snowflake's configuration."""
13
13
 
14
- pass
15
-
16
14
 
17
15
  class SnowflakeConfigurationException(Exception):
18
16
  """This exception is raised when there is an issue with Snowflake's configuration."""
19
17
 
20
- pass
21
-
22
18
 
23
19
  # Calls a sql function, handling both immediate (e.g. python types) and batch
24
20
  # (e.g. snowpark column and literal type modes).
25
21
  def call_sql_function(
26
22
  function: str,
27
23
  session: Optional[snowpark.Session],
28
- *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
29
- ) -> Union[str, List[float], snowpark.Column]:
24
+ *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
25
+ ) -> Union[str, list[float], snowpark.Column]:
30
26
  handle_as_column = False
31
27
 
32
28
  for arg in args:
@@ -34,15 +30,15 @@ def call_sql_function(
34
30
  handle_as_column = True
35
31
 
36
32
  if handle_as_column:
37
- return cast(Union[str, List[float], snowpark.Column], _call_sql_function_column(function, *args))
33
+ return cast(Union[str, list[float], snowpark.Column], _call_sql_function_column(function, *args))
38
34
  return cast(
39
- Union[str, List[float], snowpark.Column],
35
+ Union[str, list[float], snowpark.Column],
40
36
  _call_sql_function_immediate(function, session, *args),
41
37
  )
42
38
 
43
39
 
44
40
  def _call_sql_function_column(
45
- function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
41
+ function: str, *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]]
46
42
  ) -> snowpark.Column:
47
43
  return cast(snowpark.Column, functions.builtin(function)(*args))
48
44
 
@@ -50,8 +46,8 @@ def _call_sql_function_column(
50
46
  def _call_sql_function_immediate(
51
47
  function: str,
52
48
  session: Optional[snowpark.Session],
53
- *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
54
- ) -> Union[str, List[float]]:
49
+ *args: Union[str, list[str], snowpark.Column, dict[str, Union[int, float]]],
50
+ ) -> Union[str, list[float]]:
55
51
  session = session or context.get_active_session()
56
52
  if session is None:
57
53
  raise SnowflakeAuthenticationException(
@@ -1,8 +1,9 @@
1
+ import os
1
2
  import platform
2
3
 
3
- from snowflake.ml import version
4
-
5
4
  SOURCE = "SnowML"
6
- VERSION = version.VERSION
7
5
  PYTHON_VERSION = platform.python_version()
8
6
  OS = platform.system()
7
+ IN_ML_RUNTIME_ENV_VAR = "IN_SPCS_ML_RUNTIME"
8
+ IN_ML_RUNTIME = os.getenv(IN_ML_RUNTIME_ENV_VAR)
9
+ USE_OPTIMIZED_DATA_INGESTOR = "USE_OPTIMIZED_DATA_INGESTOR"
@@ -6,12 +6,13 @@ import textwrap
6
6
  import warnings
7
7
  from enum import Enum
8
8
  from importlib import metadata as importlib_metadata
9
- from typing import Any, DefaultDict, Dict, List, Optional, Tuple
9
+ from typing import Any, DefaultDict, Optional
10
10
 
11
11
  import yaml
12
12
  from packaging import requirements, specifiers, version
13
13
 
14
14
  import snowflake.connector
15
+ from snowflake.ml import version as snowml_version
15
16
  from snowflake.ml._internal import env as snowml_env, relax_version_strategy
16
17
  from snowflake.ml._internal.utils import query_result_checker
17
18
  from snowflake.snowpark import context, exceptions, session
@@ -27,8 +28,8 @@ class CONDA_OS(Enum):
27
28
 
28
29
 
29
30
  _NODEFAULTS = "nodefaults"
30
- _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
31
- _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
31
+ _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
32
+ _SNOWFLAKE_CONDA_PACKAGE_CACHE: dict[str, list[version.Version]] = {}
32
33
  _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
33
34
 
34
35
  DEFAULT_CHANNEL_NAME = ""
@@ -64,7 +65,7 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
64
65
  return r
65
66
 
66
67
 
67
- def _validate_conda_dependency_string(dep_str: str) -> Tuple[str, requirements.Requirement]:
68
+ def _validate_conda_dependency_string(dep_str: str) -> tuple[str, requirements.Requirement]:
68
69
  """Validate conda dependency string like `pytorch == 1.12.1` or `conda-forge::transformer` and split the channel
69
70
  name before the double colon and requirement specification after that.
70
71
 
@@ -115,7 +116,7 @@ class DuplicateDependencyInMultipleChannelsError(Exception):
115
116
  ...
116
117
 
117
118
 
118
- def append_requirement_list(req_list: List[requirements.Requirement], p_req: requirements.Requirement) -> None:
119
+ def append_requirement_list(req_list: list[requirements.Requirement], p_req: requirements.Requirement) -> None:
119
120
  """Append a requirement to an existing requirement list. If need and able to merge, merge it, otherwise, append it.
120
121
 
121
122
  Args:
@@ -134,7 +135,7 @@ def append_requirement_list(req_list: List[requirements.Requirement], p_req: req
134
135
 
135
136
 
136
137
  def append_conda_dependency(
137
- conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], p_chan_dep: Tuple[str, requirements.Requirement]
138
+ conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], p_chan_dep: tuple[str, requirements.Requirement]
138
139
  ) -> None:
139
140
  """Append a conda dependency to an existing conda dependencies dict, if not existed in any channel.
140
141
  To avoid making unnecessary modification to dict, we check the existence first, then try to merge, then append,
@@ -164,45 +165,73 @@ def append_conda_dependency(
164
165
  conda_chan_deps[p_channel].append(p_req)
165
166
 
166
167
 
167
- def validate_pip_requirement_string_list(req_str_list: List[str]) -> List[requirements.Requirement]:
168
- """Validate the a list of pip requirement string according to PEP 508.
168
+ def validate_pip_requirement_string_list(
169
+ req_str_list: list[str], add_local_version_specifier: bool = False
170
+ ) -> list[requirements.Requirement]:
171
+ """Validate the list of pip requirement strings according to PEP 508.
169
172
 
170
173
  Args:
171
- req_str_list: The list of string contains the pip requirement specification.
174
+ req_str_list: The list of strings containing the pip requirement specification.
175
+ add_local_version_specifier: if True, add the version specifier of the locally installed package version to
176
+ requirements without version specifiers.
172
177
 
173
178
  Returns:
174
179
  A requirements.Requirement list containing the requirement information.
175
180
  """
176
- seen_pip_requirement_list: List[requirements.Requirement] = []
181
+ seen_pip_requirement_list: list[requirements.Requirement] = []
177
182
  for req_str in req_str_list:
178
183
  append_requirement_list(seen_pip_requirement_list, _validate_pip_requirement_string(req_str=req_str))
179
184
 
185
+ if add_local_version_specifier:
186
+ # For any requirement string that does not contain a specifier, add the specifier of a locally installed version
187
+ # if it exists.
188
+ seen_pip_requirement_list = list(
189
+ map(
190
+ lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req),
191
+ seen_pip_requirement_list,
192
+ )
193
+ )
194
+
180
195
  return seen_pip_requirement_list
181
196
 
182
197
 
183
- def validate_conda_dependency_string_list(dep_str_list: List[str]) -> DefaultDict[str, List[requirements.Requirement]]:
198
+ def validate_conda_dependency_string_list(
199
+ dep_str_list: list[str], add_local_version_specifier: bool = False
200
+ ) -> DefaultDict[str, list[requirements.Requirement]]:
184
201
  """Validate a list of conda dependency string, find any duplicate package across different channel and create a dict
185
202
  to represent the whole dependencies.
186
203
 
187
204
  Args:
188
205
  dep_str_list: The list of string contains the conda dependency specification.
206
+ add_local_version_specifier: if True, add the version specifier of the locally installed package version to
207
+ requirements without version specifiers.
189
208
 
190
209
  Returns:
191
210
  A dict mapping from the channel name to the list of requirements from that channel.
192
211
  """
193
212
  validated_conda_dependency_list = list(map(_validate_conda_dependency_string, dep_str_list))
194
- ret_conda_dependency_dict: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
213
+ ret_conda_dependency_dict: DefaultDict[str, list[requirements.Requirement]] = collections.defaultdict(list)
195
214
  for p_channel, p_req in validated_conda_dependency_list:
196
215
  append_conda_dependency(ret_conda_dependency_dict, (p_channel, p_req))
197
216
 
217
+ if add_local_version_specifier:
218
+ # For any conda dependency string that does not contain a specifier, add the specifier of a locally installed
219
+ # version if it exists. This is best-effort: if the conda package does not have the same name as the pip
220
+ # package, it won't be found in the local environment.
221
+ for channel_str, reqs in ret_conda_dependency_dict.items():
222
+ reqs = list(
223
+ map(lambda req: req if req.specifier else get_local_installed_version_of_pip_package(req), reqs)
224
+ )
225
+ ret_conda_dependency_dict[channel_str] = reqs
226
+
198
227
  return ret_conda_dependency_dict
199
228
 
200
229
 
201
230
  def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement) -> requirements.Requirement:
202
231
  """Get the local installed version of a given pip package requirement.
203
- If the package is locally installed, and the local version meet the specifier of the requirements, return a new
232
+ If the package is locally installed, and the local version meets the specifier of the requirements, return a new
204
233
  requirement specifier that pins the version.
205
- If the local version does not meet the specifier of the requirements, a warn will be omitted and returns
234
+ If the local version does not meet the specifier of the requirements, a warning will be emitted and returns
206
235
  the original package requirement.
207
236
  If the package is not locally installed or not found, the original package requirement is returned.
208
237
 
@@ -217,7 +246,7 @@ def get_local_installed_version_of_pip_package(pip_req: requirements.Requirement
217
246
  local_dist_version = local_dist.version
218
247
  except importlib_metadata.PackageNotFoundError:
219
248
  if pip_req.name == SNOWPARK_ML_PKG_NAME:
220
- local_dist_version = snowml_env.VERSION
249
+ local_dist_version = snowml_version.VERSION
221
250
  else:
222
251
  return pip_req
223
252
  new_pip_req = copy.deepcopy(pip_req)
@@ -372,8 +401,8 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
372
401
 
373
402
 
374
403
  def get_matched_package_versions_in_information_schema_with_active_session(
375
- reqs: List[requirements.Requirement], python_version: str
376
- ) -> Dict[str, List[version.Version]]:
404
+ reqs: list[requirements.Requirement], python_version: str
405
+ ) -> dict[str, list[version.Version]]:
377
406
  try:
378
407
  session = context.get_active_session()
379
408
  except exceptions.SnowparkSessionException:
@@ -383,10 +412,10 @@ def get_matched_package_versions_in_information_schema_with_active_session(
383
412
 
384
413
  def get_matched_package_versions_in_information_schema(
385
414
  session: session.Session,
386
- reqs: List[requirements.Requirement],
415
+ reqs: list[requirements.Requirement],
387
416
  python_version: str,
388
- statement_params: Optional[Dict[str, Any]] = None,
389
- ) -> Dict[str, List[version.Version]]:
417
+ statement_params: Optional[dict[str, Any]] = None,
418
+ ) -> dict[str, list[version.Version]]:
390
419
  """Look up the information_schema table to check if a package with the specified specifier exists in the Snowflake
391
420
  Conda channel. Note that this is not the source of truth due to the potential delay caused by a package that might
392
421
  exist in the information_schema table but has not yet become available in the Snowflake Conda channel.
@@ -400,8 +429,8 @@ def get_matched_package_versions_in_information_schema(
400
429
  Returns:
401
430
  A Dict, whose key is the package name, and value is a list of versions match the requirements.
402
431
  """
403
- ret_dict: Dict[str, List[version.Version]] = {}
404
- reqs_to_request: List[requirements.Requirement] = []
432
+ ret_dict: dict[str, list[version.Version]] = {}
433
+ reqs_to_request: list[requirements.Requirement] = []
405
434
  for req in reqs:
406
435
  if req.name in _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE:
407
436
  available_versions = list(
@@ -457,7 +486,7 @@ def get_matched_package_versions_in_information_schema(
457
486
 
458
487
  def save_conda_env_file(
459
488
  path: pathlib.Path,
460
- conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
489
+ conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
461
490
  python_version: str,
462
491
  cuda_version: Optional[str] = None,
463
492
  default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
@@ -478,7 +507,7 @@ def save_conda_env_file(
478
507
  """
479
508
  assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
480
509
  path.parent.mkdir(parents=True, exist_ok=True)
481
- env: Dict[str, Any] = dict()
510
+ env: dict[str, Any] = dict()
482
511
  env["name"] = "snow-env"
483
512
  # Get all channels in the dependencies, ordered by the number of the packages which belongs to and put into
484
513
  # channels section.
@@ -505,7 +534,7 @@ def save_conda_env_file(
505
534
  yaml.safe_dump(env, stream=f, default_flow_style=False)
506
535
 
507
536
 
508
- def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requirement]) -> None:
537
+ def save_requirements_file(path: pathlib.Path, pip_deps: list[requirements.Requirement]) -> None:
509
538
  """Generate Python requirements.txt file in the given directory path.
510
539
 
511
540
  Args:
@@ -521,9 +550,9 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi
521
550
 
522
551
  def load_conda_env_file(
523
552
  path: pathlib.Path,
524
- ) -> Tuple[
525
- DefaultDict[str, List[requirements.Requirement]],
526
- Optional[List[requirements.Requirement]],
553
+ ) -> tuple[
554
+ DefaultDict[str, list[requirements.Requirement]],
555
+ Optional[list[requirements.Requirement]],
527
556
  Optional[str],
528
557
  Optional[str],
529
558
  ]:
@@ -601,7 +630,7 @@ def load_conda_env_file(
601
630
  return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
602
631
 
603
632
 
604
- def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
633
+ def load_requirements_file(path: pathlib.Path) -> list[requirements.Requirement]:
605
634
  """Load Python requirements.txt file from the given directory path.
606
635
 
607
636
  Args:
@@ -641,8 +670,8 @@ def parse_python_version_string(dep: str) -> Optional[str]:
641
670
 
642
671
 
643
672
  def _find_conda_dep_spec(
644
- conda_chan_deps: DefaultDict[str, List[requirements.Requirement]], pkg_name: str
645
- ) -> Optional[Tuple[str, requirements.Requirement]]:
673
+ conda_chan_deps: DefaultDict[str, list[requirements.Requirement]], pkg_name: str
674
+ ) -> Optional[tuple[str, requirements.Requirement]]:
646
675
  for channel in conda_chan_deps:
647
676
  spec = next(filter(lambda req: req.name == pkg_name, conda_chan_deps[channel]), None)
648
677
  if spec:
@@ -650,14 +679,14 @@ def _find_conda_dep_spec(
650
679
  return None
651
680
 
652
681
 
653
- def _find_pip_req_spec(pip_reqs: List[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
682
+ def _find_pip_req_spec(pip_reqs: list[requirements.Requirement], pkg_name: str) -> Optional[requirements.Requirement]:
654
683
  spec = next(filter(lambda req: req.name == pkg_name, pip_reqs), None)
655
684
  return spec
656
685
 
657
686
 
658
687
  def find_dep_spec(
659
- conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
660
- pip_reqs: List[requirements.Requirement],
688
+ conda_chan_deps: DefaultDict[str, list[requirements.Requirement]],
689
+ pip_reqs: list[requirements.Requirement],
661
690
  conda_pkg_name: str,
662
691
  pip_pkg_name: Optional[str] = None,
663
692
  remove_spec: bool = False,