snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +101 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/platform_capabilities.py +87 -0
  12. snowflake/ml/_internal/utils/identifier.py +4 -2
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/dataset/dataset.py +0 -1
  19. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  20. snowflake/ml/fileset/fileset.py +24 -18
  21. snowflake/ml/jobs/__init__.py +21 -0
  22. snowflake/ml/jobs/_utils/constants.py +51 -0
  23. snowflake/ml/jobs/_utils/payload_utils.py +352 -0
  24. snowflake/ml/jobs/_utils/spec_utils.py +298 -0
  25. snowflake/ml/jobs/_utils/types.py +39 -0
  26. snowflake/ml/jobs/decorators.py +91 -0
  27. snowflake/ml/jobs/job.py +113 -0
  28. snowflake/ml/jobs/manager.py +298 -0
  29. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  30. snowflake/ml/model/_client/ops/model_ops.py +13 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  32. snowflake/ml/model/_client/sql/model_version.py +11 -0
  33. snowflake/ml/model/_client/sql/service.py +13 -6
  34. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  35. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  37. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  38. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  40. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  41. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  42. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  43. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  44. snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
  45. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  46. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  49. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  50. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  51. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  52. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  53. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  54. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  56. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  57. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  58. snowflake/ml/model/_signatures/base_handler.py +1 -2
  59. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  60. snowflake/ml/model/_signatures/numpy_handler.py +6 -7
  61. snowflake/ml/model/_signatures/pandas_handler.py +3 -3
  62. snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
  63. snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
  64. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  65. snowflake/ml/model/model_signature.py +17 -4
  66. snowflake/ml/model/type_hints.py +1 -0
  67. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  68. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  69. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  70. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  71. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  72. snowflake/ml/modeling/cluster/birch.py +6 -3
  73. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  74. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  75. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  76. snowflake/ml/modeling/cluster/k_means.py +6 -3
  77. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  78. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  79. snowflake/ml/modeling/cluster/optics.py +6 -3
  80. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  81. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  82. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  83. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  84. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  85. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  86. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  87. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  88. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  89. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  90. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  91. snowflake/ml/modeling/covariance/oas.py +6 -3
  92. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  93. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  94. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  95. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  96. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  97. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  98. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  99. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  100. snowflake/ml/modeling/decomposition/pca.py +6 -3
  101. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  102. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  103. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  104. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  105. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  106. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  107. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  108. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  109. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  110. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  111. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  112. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  113. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  114. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  115. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  116. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  117. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  118. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  119. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  120. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  121. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  122. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  123. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  124. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  125. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  126. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  127. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  128. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  129. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  130. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  131. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  132. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  133. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  134. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  135. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  136. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  137. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  138. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  139. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  140. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  141. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  142. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  143. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  144. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  145. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  146. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  148. snowflake/ml/modeling/linear_model/lars.py +6 -3
  149. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  151. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  152. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  153. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  154. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  155. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  156. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  157. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  158. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  159. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  160. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  161. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  162. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  163. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  164. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  165. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  166. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  167. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  168. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  169. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  170. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  171. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  172. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  173. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  174. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  175. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  176. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  177. snowflake/ml/modeling/manifold/isomap.py +6 -3
  178. snowflake/ml/modeling/manifold/mds.py +6 -3
  179. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  180. snowflake/ml/modeling/manifold/tsne.py +6 -3
  181. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  182. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  183. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  184. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  185. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  186. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  187. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  188. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  189. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  190. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  191. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  192. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  193. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  194. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  195. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  196. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  197. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  198. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  199. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  200. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  201. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  202. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  203. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  204. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  205. snowflake/ml/modeling/pipeline/pipeline.py +16 -178
  206. snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  209. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  210. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  211. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  212. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  213. snowflake/ml/modeling/svm/svc.py +6 -3
  214. snowflake/ml/modeling/svm/svr.py +6 -3
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
  223. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  224. snowflake/ml/registry/_manager/model_manager.py +70 -33
  225. snowflake/ml/registry/registry.py +41 -22
  226. snowflake/ml/version.py +1 -1
  227. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
  228. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
  229. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
  230. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  231. snowflake/ml/fileset/parquet_parser.py +0 -170
  232. snowflake/ml/fileset/tf_dataset.py +0 -88
  233. snowflake/ml/fileset/torch_datapipe.py +0 -57
  234. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  235. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  236. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
  237. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,91 @@
1
+ import copy
2
+ import functools
3
+ import inspect
4
+ from typing import Callable, Dict, List, Optional, TypeVar
5
+
6
+ from typing_extensions import ParamSpec
7
+
8
+ from snowflake import snowpark
9
+ from snowflake.ml._internal import telemetry
10
+ from snowflake.ml.jobs import job as jb, manager as jm
11
+ from snowflake.ml.jobs._utils import payload_utils
12
+
13
# Telemetry project name under which usage of the ``remote`` decorator is reported.
_PROJECT = "MLJob"

# Type variables used so the decorator keeps the decorated function's parameter
# signature (``_Args``) while replacing its return type (``_ReturnValue``) with MLJob.
_Args = ParamSpec("_Args")
_ReturnValue = TypeVar("_ReturnValue")
17
+
18
+
19
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def remote(
    compute_pool: str,
    stage_name: str,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    env_vars: Optional[Dict[str, str]] = None,
    session: Optional[snowpark.Session] = None,
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
    """
    Submit a job to the compute pool.

    Calling the decorated function does not run it locally. Each call submits a new
    job that runs the function remotely. Positional arguments are passed to the job
    as ``str(value)``. Keyword arguments are passed as ``--name str(value)`` pairs.

    Args:
        compute_pool: The compute pool to use for the job.
        stage_name: The name of the stage where the job payload will be uploaded.
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        env_vars: Environment variables to set in container
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        Decorator that dispatches invocations of the decorated function as remote jobs.
    """
    import types  # only needed to build a real copy of the decorated function

    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
        # Build a genuine copy of ``func`` whose code object starts one line later,
        # so that source extracted from the copy excludes the decorator line.
        # NOTE: ``copy.copy`` must not be used here: it returns function objects
        # unchanged, so assigning ``__code__`` on its result would mutate the
        # caller's function. That would shift the function's traceback line numbers,
        # and the shift would compound if the same function were decorated again.
        # NOTE(review): the "+ 1" assumes the decorator expression occupies exactly one line.
        original_code = func.__code__
        wrapped_func = types.FunctionType(
            original_code.replace(co_firstlineno=original_code.co_firstlineno + 1),
            func.__globals__,
            func.__name__,
            func.__defaults__,
            func.__closure__,
        )
        wrapped_func.__kwdefaults__ = copy.copy(func.__kwdefaults__)
        wrapped_func.__qualname__ = func.__qualname__
        wrapped_func.__module__ = func.__module__
        wrapped_func.__doc__ = func.__doc__
        wrapped_func.__annotations__ = dict(func.__annotations__)
        wrapped_func.__dict__.update(func.__dict__)

        # Validate the declared parameter types once, at decoration time, and record
        # the names of parameters that can be bound positionally (for error messages).
        signature = inspect.signature(func)
        pos_arg_names: List[str] = []
        for name, param in signature.parameters.items():
            param_type = payload_utils.get_parameter_type(param)
            if param_type is not None:
                payload_utils.validate_parameter_type(param_type, name)
            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                pos_arg_names.append(name)

        @functools.wraps(func)
        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
            # Validate positional args
            for i, arg in enumerate(args):
                arg_name = pos_arg_names[i] if i < len(pos_arg_names) else f"args[{i}]"
                payload_utils.validate_parameter_type(type(arg), arg_name)

            # Validate keyword args
            for k, v in kwargs.items():
                payload_utils.validate_parameter_type(type(v), k)

            # Positional args are passed through as-is; keyword args become
            # "--name value" pairs on the job's command line.
            arg_list = [str(v) for v in args] + [x for k, v in kwargs.items() for x in (f"--{k}", str(v))]
            job = jm._submit_job(
                source=wrapped_func,
                args=arg_list,
                stage_name=stage_name,
                compute_pool=compute_pool,
                pip_requirements=pip_requirements,
                external_access_integrations=external_access_integrations,
                query_warehouse=query_warehouse,
                env_vars=env_vars,
                session=session,
            )
            assert isinstance(job, jb.MLJob)
            return job

        return wrapper

    return decorator
@@ -0,0 +1,113 @@
1
+ import time
2
+ from typing import Any, List, Optional, cast
3
+
4
+ from snowflake import snowpark
5
+ from snowflake.ml._internal import telemetry
6
+ from snowflake.ml.jobs._utils import constants, types
7
+ from snowflake.snowpark.context import get_active_session
8
+
9
+ _PROJECT = "MLJob"
10
+ TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
11
+
12
+
13
+ class MLJob:
14
+ def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
15
+ self._id = id
16
+ self._session = session or get_active_session()
17
+ self._status: types.JOB_STATUS = "PENDING"
18
+
19
+ @property
20
+ def id(self) -> str:
21
+ """Get the unique job ID"""
22
+ return self._id
23
+
24
+ @property
25
+ def status(self) -> types.JOB_STATUS:
26
+ """Get the job's execution status."""
27
+ if self._status not in TERMINAL_JOB_STATUSES:
28
+ # Query backend for job status if not in terminal state
29
+ self._status = _get_status(self._session, self.id)
30
+ return self._status
31
+
32
+ @snowpark._internal.utils.private_preview(version="1.7.4")
33
+ def get_logs(self, limit: int = -1) -> str:
34
+ """
35
+ Return the job's execution logs.
36
+
37
+ Args:
38
+ limit: The maximum number of lines to return. Negative values are treated as no limit.
39
+
40
+ Returns:
41
+ The job's execution logs.
42
+ """
43
+ logs = _get_logs(self._session, self.id, limit)
44
+ assert isinstance(logs, str) # mypy
45
+ return logs
46
+
47
+ @snowpark._internal.utils.private_preview(version="1.7.4")
48
+ def show_logs(self, limit: int = -1) -> None:
49
+ """
50
+ Display the job's execution logs.
51
+
52
+ Args:
53
+ limit: The maximum number of lines to display. Negative values are treated as no limit.
54
+ """
55
+ print(self.get_logs(limit)) # noqa: T201: we need to print here.
56
+
57
+ @snowpark._internal.utils.private_preview(version="1.7.4")
58
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
59
+ def wait(self, timeout: float = -1) -> types.JOB_STATUS:
60
+ """
61
+ Block until completion. Returns completion status.
62
+
63
+ Args:
64
+ timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
65
+
66
+ Returns:
67
+ The job's completion status.
68
+
69
+ Raises:
70
+ TimeoutError: If the job does not complete within the specified timeout.
71
+ """
72
+ delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS # Start with 100ms delay
73
+ start_time = time.monotonic()
74
+ while self.status not in TERMINAL_JOB_STATUSES:
75
+ if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
76
+ raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
77
+ time.sleep(delay)
78
+ delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
79
+ return self.status
80
+
81
+
82
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
83
+ def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
84
+ """Retrieve job execution status."""
85
+ # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
86
+ # `DESCRIBE` queries with bind variables
87
+ # Switch to use bind variables instead of client side formatting after
88
+ # updating to snowflake-snowpark-python>=1.24.0
89
+ (row,) = session.sql(f"DESCRIBE SERVICE {job_id}").collect()
90
+ return cast(types.JOB_STATUS, row["status"])
91
+
92
+
93
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
94
+ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
95
+ """
96
+ Retrieve the job's execution logs.
97
+
98
+ Args:
99
+ job_id: The job ID.
100
+ limit: The maximum number of lines to return. Negative values are treated as no limit.
101
+ session: The Snowpark session to use. If none specified, uses active session.
102
+
103
+ Returns:
104
+ The job's execution logs.
105
+ """
106
+ params: List[Any] = [job_id]
107
+ if limit > 0:
108
+ params.append(limit)
109
+ (row,) = session.sql(
110
+ f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{f', ?' if limit > 0 else ''})",
111
+ params=params,
112
+ ).collect()
113
+ return str(row[0])
@@ -0,0 +1,298 @@
1
+ import pathlib
2
+ import textwrap
3
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
4
+ from uuid import uuid4
5
+
6
+ import yaml
7
+
8
+ from snowflake import snowpark
9
+ from snowflake.ml._internal import telemetry
10
+ from snowflake.ml._internal.utils import identifier
11
+ from snowflake.ml.jobs import job as jb
12
+ from snowflake.ml.jobs._utils import payload_utils, spec_utils
13
+ from snowflake.snowpark.context import get_active_session
14
+ from snowflake.snowpark.exceptions import SnowparkSQLException
15
+
16
# Telemetry project name for the job management APIs in this module.
_PROJECT = "MLJob"
# Prefix of every job (service) name generated by _submit_job; list_jobs
# filters services with a LIKE pattern built from it.
JOB_ID_PREFIX = "MLJOB_"
18
+
19
+
20
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
def list_jobs(
    limit: int = 10,
    scope: Union[Literal["account", "database", "schema"], str, None] = None,
    session: Optional[snowpark.Session] = None,
) -> snowpark.DataFrame:
    """
    Returns a Snowpark DataFrame with the list of jobs in the current session.

    Only services whose names start with the job ID prefix are listed. The result
    has the columns id, owner, status, created_on and compute_pool, ordered by
    created_on, newest first.

    Args:
        limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
        scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        A DataFrame with the list of jobs.

    Examples:
        >>> from snowflake.ml.jobs import list_jobs
        >>> list_jobs(limit=5).show()
    """
    active_session = session or get_active_session()

    # Assemble the SHOW statement clause by clause; optional clauses are only
    # appended when requested.
    clauses = [f"SHOW JOB SERVICES LIKE '{JOB_ID_PREFIX}%'"]
    if scope:
        clauses.append(f"IN {scope}")
    if limit > 0:
        clauses.append(f"LIMIT {limit}")
    services = active_session.sql(" ".join(clauses))

    projected_columns = [
        services['"name"'].alias('"id"'),
        services['"owner"'],
        services['"status"'],
        services['"created_on"'],
        services['"compute_pool"'],
    ]
    return services.select(*projected_columns).order_by('"created_on"', ascending=False)
58
+
59
+
60
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
    """Retrieve a job service from the backend.

    Args:
        job_id: The job ID. It is normalized via identifier.resolve_identifier.
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        A handle to the existing job.

    Raises:
        ValueError: If the ID is not a valid identifier, or if the backend reports
            that the job does not exist.
    """
    active_session = session or get_active_session()

    # Normalize and validate the identifier before touching the backend.
    try:
        resolved_id = identifier.resolve_identifier(job_id)
    except ValueError as err:
        raise ValueError(f"Invalid job ID: {job_id}") from err

    # Accessing ``status`` issues a backend query, which confirms the job exists.
    try:
        job = jb.MLJob(resolved_id, session=active_session)
        _ = job.status
    except SnowparkSQLException as err:
        if "does not exist" not in err.message:
            raise
        raise ValueError(f"Job does not exist: {resolved_id}") from err
    return job
81
+
82
+
83
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
    """Delete a job service from the backend. Status and logs will be lost.

    Args:
        job: The job ID, or an MLJob handle.
        session: The Snowpark session to use. When ``job`` is an MLJob, the job's own
            session takes precedence over this one. Otherwise, if none specified,
            uses active session.
    """
    if isinstance(job, jb.MLJob):
        target_id, target_session = job.id, job._session or session
    else:
        target_id, target_session = job, session or get_active_session()
    drop_statement = "DROP SERVICE IDENTIFIER(?)"
    target_session.sql(drop_statement, params=(target_id,)).collect()
94
+
95
+
96
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def submit_file(
    file_path: str,
    compute_pool: str,
    *,
    stage_name: str,
    args: Optional[List[str]] = None,
    env_vars: Optional[Dict[str, str]] = None,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    spec_overrides: Optional[Dict[str, Any]] = None,
    session: Optional[snowpark.Session] = None,
) -> jb.MLJob:
    """
    Submit a single Python file as a job to the compute pool.

    This is a thin wrapper around the shared submission routine, with the file
    as the payload source.

    Args:
        file_path: The path to the file containing the source code for the job.
        compute_pool: The compute pool to use for the job.
        stage_name: The name of the stage where the job payload will be uploaded.
        args: A list of arguments to pass to the job.
        env_vars: Environment variables to set in container
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        spec_overrides: Custom service specification overrides to apply.
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        An object representing the submitted job.
    """
    submitted = _submit_job(
        source=file_path,
        compute_pool=compute_pool,
        stage_name=stage_name,
        args=args,
        env_vars=env_vars,
        pip_requirements=pip_requirements,
        external_access_integrations=external_access_integrations,
        query_warehouse=query_warehouse,
        spec_overrides=spec_overrides,
        session=session,
    )
    return submitted
141
+
142
+
143
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def submit_directory(
    dir_path: str,
    compute_pool: str,
    *,
    entrypoint: str,
    stage_name: str,
    args: Optional[List[str]] = None,
    env_vars: Optional[Dict[str, str]] = None,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    spec_overrides: Optional[Dict[str, Any]] = None,
    session: Optional[snowpark.Session] = None,
) -> jb.MLJob:
    """
    Submit a directory of Python script(s) as a job to the compute pool.

    This is a thin wrapper around the shared submission routine, with the
    directory as the payload source and ``entrypoint`` naming the script to run.

    Args:
        dir_path: The path to the directory containing the job payload.
        compute_pool: The compute pool to use for the job.
        entrypoint: The relative path to the entry point script inside the source directory.
        stage_name: The name of the stage where the job payload will be uploaded.
        args: A list of arguments to pass to the job.
        env_vars: Environment variables to set in container
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        spec_overrides: Custom service specification overrides to apply.
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        An object representing the submitted job.
    """
    submitted = _submit_job(
        source=dir_path,
        compute_pool=compute_pool,
        entrypoint=entrypoint,
        stage_name=stage_name,
        args=args,
        env_vars=env_vars,
        pip_requirements=pip_requirements,
        external_access_integrations=external_access_integrations,
        query_warehouse=query_warehouse,
        spec_overrides=spec_overrides,
        session=session,
    )
    return submitted
191
+
192
+
193
@telemetry.send_api_usage_telemetry(
    project=_PROJECT,
    func_params_to_log=[
        # TODO: Log the source type (callable, file, directory, etc)
        # TODO: Log instance type of compute pool used
        # TODO: Log lengths of args, env_vars, and spec_overrides values
        "pip_requirements",
        "external_access_integrations",
    ],
)
def _submit_job(
    source: Union[str, Callable[..., Any]],
    compute_pool: str,
    *,
    stage_name: str,
    entrypoint: Optional[str] = None,
    args: Optional[List[str]] = None,
    env_vars: Optional[Dict[str, str]] = None,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    spec_overrides: Optional[Dict[str, Any]] = None,
    session: Optional[snowpark.Session] = None,
) -> jb.MLJob:
    """
    Submit a job to the compute pool.

    The payload is uploaded to ``<stage_name>/<job_id>``, a service spec is
    generated (with env vars and custom overrides merged in), and the job is
    started asynchronously with ``EXECUTE JOB SERVICE``.

    Args:
        source: The file/directory path containing payload source code or a serializable Python callable.
        compute_pool: The compute pool to use for the job.
        stage_name: The name of the stage where the job payload will be uploaded.
        entrypoint: The entry point for the job execution. Required if source is a directory.
        args: A list of arguments to pass to the job.
        env_vars: Environment variables to set in container
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        spec_overrides: Custom service specification overrides to apply.
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        An object representing the submitted job.

    Raises:
        RuntimeError: If required Snowflake features are not enabled.
    """
    session = session or get_active_session()
    job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
    # Normalize to exactly one leading "@" and no trailing "/".
    stage_name = "@" + stage_name.lstrip("@").rstrip("/")
    stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")

    # Upload payload
    uploaded_payload = payload_utils.JobPayload(
        source,
        entrypoint=entrypoint,
        pip_requirements=pip_requirements,
    ).upload(session, stage_path)

    # Generate service spec
    spec = spec_utils.generate_service_spec(
        session,
        compute_pool=compute_pool,
        payload=uploaded_payload,
        args=args,
    )
    spec_overrides = spec_utils.generate_spec_overrides(
        environment_vars=env_vars,
        custom_overrides=spec_overrides,
    )
    if spec_overrides:
        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")

    # Generate SQL command for job submission.
    # The YAML spec must be embedded verbatim. PyYAML writes each newline inside a
    # quoted scalar (e.g. a multi-line arg or env var value) as an empty line, so
    # filtering out empty lines would silently fold those newlines into spaces.
    # The query is therefore assembled line by line, and only the document's
    # final line terminator is dropped.
    spec_yaml = yaml.dump(spec)
    if spec_yaml.endswith("\n"):
        spec_yaml = spec_yaml[:-1]
    query = [
        "EXECUTE JOB SERVICE",
        f"IN COMPUTE POOL {compute_pool}",
        "FROM SPECIFICATION $$",
        spec_yaml,
        "$$",
        f"NAME = {job_id}",
        "ASYNC = TRUE",
    ]
    if external_access_integrations:
        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({','.join(external_access_integrations)})")
    query_warehouse = query_warehouse or session.get_current_warehouse()
    if query_warehouse:
        query.append(f"QUERY_WAREHOUSE = {query_warehouse}")

    # Submit job
    query_text = "\n".join(query)

    try:
        _ = session.sql(query_text).collect()
    except SnowparkSQLException as e:
        if "invalid property 'ASYNC'" in e.message:
            raise RuntimeError(
                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
            ) from e
        raise

    # TODO: Wrap snowflake.core.service.JobService object
    return jb.MLJob(job_id, session=session)
@@ -447,13 +447,15 @@ class ModelVersion(lineage_node.LineageNode):
447
447
  target_function_info = functions[0]
448
448
 
449
449
  if service_name:
450
+ database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
451
+
450
452
  return self._model_ops.invoke_method(
451
453
  method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
452
454
  signature=target_function_info["signature"],
453
455
  X=X,
454
- database_name=None,
455
- schema_name=None,
456
- service_name=sql_identifier.SqlIdentifier(service_name),
456
+ database_name=database_name_id,
457
+ schema_name=schema_name_id,
458
+ service_name=service_name_id,
457
459
  strict_input_validation=strict_input_validation,
458
460
  statement_params=statement_params,
459
461
  )
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils
33
33
 
34
34
  class ServiceInfo(TypedDict):
35
35
  name: str
36
+ status: str
36
37
  inference_endpoint: Optional[str]
37
38
 
38
39
 
@@ -168,14 +169,10 @@ class ModelOperator:
168
169
  schema_name: Optional[sql_identifier.SqlIdentifier],
169
170
  model_name: sql_identifier.SqlIdentifier,
170
171
  version_name: sql_identifier.SqlIdentifier,
172
+ model_exists: bool,
171
173
  statement_params: Optional[Dict[str, Any]] = None,
172
174
  ) -> None:
173
- if self.validate_existence(
174
- database_name=database_name,
175
- schema_name=schema_name,
176
- model_name=model_name,
177
- statement_params=statement_params,
178
- ):
175
+ if model_exists:
179
176
  return self._model_version_client.add_version_from_model_version(
180
177
  source_database_name=source_database_name,
181
178
  source_schema_name=source_schema_name,
@@ -554,9 +551,13 @@ class ModelOperator:
554
551
  fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
555
552
 
556
553
  result = []
557
- ingress_url: Optional[str] = None
554
+
558
555
  for fully_qualified_service_name in fully_qualified_service_names:
556
+ ingress_url: Optional[str] = None
559
557
  db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
558
+ service_status, _ = self._service_client.get_service_status(
559
+ database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
560
+ )
560
561
  for res_row in self._service_client.show_endpoints(
561
562
  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
562
563
  ):
@@ -570,7 +571,11 @@ class ModelOperator:
570
571
  )
571
572
  if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
572
573
  ingress_url = None
573
- result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
574
+ result.append(
575
+ ServiceInfo(
576
+ name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
577
+ )
578
+ )
574
579
 
575
580
  return result
576
581
 
@@ -8,11 +8,9 @@ import threading
8
8
  import time
9
9
  from typing import Any, Dict, List, Optional, Tuple, Union, cast
10
10
 
11
- from packaging import version
12
-
13
11
  from snowflake import snowpark
14
12
  from snowflake.ml._internal import file_utils
15
- from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
13
+ from snowflake.ml._internal.utils import service_logger, sql_identifier
16
14
  from snowflake.ml.model._client.service import model_deployment_spec
17
15
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
18
16
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
133
131
  )
134
132
  stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
135
133
 
136
- # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
137
- if (
138
- snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
139
- < version.parse("8.40.0")
140
- and build_external_access_integrations is None
141
- ):
142
- raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
143
-
144
134
  self._model_deployment_spec.save(
145
135
  database_name=database_name,
146
136
  schema_name=schema_name,
@@ -10,6 +10,7 @@ from snowflake.ml._internal.utils import (
10
10
  sql_identifier,
11
11
  )
12
12
  from snowflake.ml.model._client.sql import _base
13
+ from snowflake.ml.model._model_composer.model_method import constants
13
14
  from snowflake.snowpark import dataframe, functions as F, row, types as spt
14
15
  from snowflake.snowpark._internal import utils as snowpark_utils
15
16
 
@@ -333,6 +334,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
333
334
 
334
335
  args_sql = ", ".join(args_sql_list)
335
336
 
337
+ wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
338
+ if wide_input:
339
+ input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
340
+ args_sql = f"object_construct_keep_null({input_args_sql})"
341
+
336
342
  sql = textwrap.dedent(
337
343
  f"""WITH {','.join(with_statements)}
338
344
  SELECT *,
@@ -412,6 +418,11 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
412
418
 
413
419
  args_sql = ", ".join(args_sql_list)
414
420
 
421
+ wide_input = len(input_args) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT
422
+ if wide_input:
423
+ input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
424
+ args_sql = f"object_construct_keep_null({input_args_sql})"
425
+
415
426
  sql = textwrap.dedent(
416
427
  f"""WITH {','.join(with_statements)}
417
428
  SELECT *,
@@ -4,6 +4,7 @@ import textwrap
4
4
  from typing import Any, Dict, List, Optional, Tuple
5
5
 
6
6
  from snowflake import snowpark
7
+ from snowflake.ml._internal import platform_capabilities
7
8
  from snowflake.ml._internal.utils import (
8
9
  identifier,
9
10
  query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
120
121
  args_sql_list.append(input_arg_value)
121
122
  args_sql = ", ".join(args_sql_list)
122
123
 
123
- function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
124
- fully_qualified_function_name = identifier.get_schema_level_object_identifier(
125
- actual_database_name.identifier(),
126
- actual_schema_name.identifier(),
127
- function_name,
128
- )
124
+ if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
125
+ fully_qualified_service_name = self.fully_qualified_object_name(
126
+ actual_database_name, actual_schema_name, service_name
127
+ )
128
+ fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
129
+ else:
130
+ function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
131
+ fully_qualified_function_name = identifier.get_schema_level_object_identifier(
132
+ actual_database_name.identifier(),
133
+ actual_schema_name.identifier(),
134
+ function_name,
135
+ )
129
136
 
130
137
  sql = textwrap.dedent(
131
138
  f"""{with_sql}