snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (208)
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/platform_capabilities.py +87 -0
  4. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  5. snowflake/ml/_internal/telemetry.py +21 -0
  6. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  7. snowflake/ml/dataset/dataset.py +0 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +6 -0
  11. snowflake/ml/jobs/__init__.py +21 -0
  12. snowflake/ml/jobs/_utils/constants.py +57 -0
  13. snowflake/ml/jobs/_utils/payload_utils.py +438 -0
  14. snowflake/ml/jobs/_utils/spec_utils.py +296 -0
  15. snowflake/ml/jobs/_utils/types.py +39 -0
  16. snowflake/ml/jobs/decorators.py +71 -0
  17. snowflake/ml/jobs/job.py +113 -0
  18. snowflake/ml/jobs/manager.py +298 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  21. snowflake/ml/model/_client/sql/service.py +13 -6
  22. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
  26. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  28. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  30. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  32. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  33. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  36. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  37. snowflake/ml/model/_signatures/base_handler.py +1 -2
  38. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  39. snowflake/ml/model/_signatures/core.py +2 -2
  40. snowflake/ml/model/_signatures/numpy_handler.py +11 -12
  41. snowflake/ml/model/_signatures/pandas_handler.py +11 -9
  42. snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
  43. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  44. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  45. snowflake/ml/model/model_signature.py +25 -4
  46. snowflake/ml/model/type_hints.py +15 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  51. snowflake/ml/modeling/cluster/birch.py +6 -3
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  53. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  55. snowflake/ml/modeling/cluster/k_means.py +6 -3
  56. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  58. snowflake/ml/modeling/cluster/optics.py +6 -3
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  62. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  69. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  70. snowflake/ml/modeling/covariance/oas.py +6 -3
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  74. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  79. snowflake/ml/modeling/decomposition/pca.py +6 -3
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  110. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  111. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  112. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  113. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  114. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  115. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  116. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  117. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  118. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  119. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  120. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  121. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  122. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  123. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  124. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  125. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/lars.py +6 -3
  128. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  129. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  130. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  132. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  133. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  134. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  135. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  136. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  137. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  139. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  140. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  141. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  142. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  143. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  144. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  145. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  146. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  148. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  149. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  151. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  152. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  153. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  154. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  155. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  156. snowflake/ml/modeling/manifold/isomap.py +6 -3
  157. snowflake/ml/modeling/manifold/mds.py +6 -3
  158. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  159. snowflake/ml/modeling/manifold/tsne.py +6 -3
  160. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  161. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  162. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  163. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  174. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  184. snowflake/ml/modeling/pipeline/pipeline.py +28 -3
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  188. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  189. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  190. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  191. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  192. snowflake/ml/modeling/svm/svc.py +6 -3
  193. snowflake/ml/modeling/svm/svr.py +6 -3
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  202. snowflake/ml/registry/registry.py +34 -4
  203. snowflake/ml/version.py +1 -1
  204. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
  205. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
  206. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  207. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  208. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,296 @@
1
+ import logging
2
+ from math import ceil
3
+ from pathlib import PurePath
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ from snowflake import snowpark
7
+ from snowflake.ml._internal.utils import snowflake_env
8
+ from snowflake.ml.jobs._utils import constants, types
9
+
10
+
11
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
    """
    Extract per-node resource information for the specified compute pool.

    Args:
        session: Snowpark session used to query compute pool and region metadata.
        compute_pool: Name of the compute pool to inspect.

    Returns:
        The compute resources of a single node in the pool.

    Raises:
        ValueError: If the compute pool is not found, or its instance family is not
            known for the current cloud.
    """
    # Get the instance family
    # NOTE(review): LIKE treats `_`/`%` as wildcards, so a pool name containing `_`
    # may match other pools; only the first row is used.
    rows = session.sql(f"show compute pools like '{compute_pool}'").collect()
    if not rows:
        raise ValueError(f"Compute pool '{compute_pool}' not found")
    instance_family: str = rows[0]["instance_family"]

    # Instance families shared by all clouds need no region lookup (which costs extra queries)
    common_resources = constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
    if common_resources:
        return common_resources

    # Otherwise resolve the cloud we're using (AWS, Azure, etc) and look up its own families
    region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
    cloud = region["cloud"]
    try:
        return constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
    except KeyError:
        raise ValueError(
            f"Unsupported instance family '{instance_family}' for compute pool '{compute_pool}' (cloud={cloud})"
        ) from None
27
+
28
+
29
def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
    """
    Build the container image spec for a job running on the given compute pool.

    Args:
        session: Snowpark session used to query compute pool info and account parameters.
        compute_pool: Name of the compute pool the job will run on.

    Returns:
        Image spec whose resource requests and limits both equal the pool's per-node resources.
    """
    # Retrieve compute pool node resources
    resources = _get_node_resources(session, compute_pool=compute_pool)

    # Use MLRuntime image, choosing the GPU variant when the pool's nodes have GPUs
    image_repo = constants.DEFAULT_IMAGE_REPO
    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
    image_tag = constants.DEFAULT_IMAGE_TAG

    # Try to pull latest image tag from server side if possible.
    # The parameter name must be interpolated into the query. Quoting the Python
    # expression text itself ('constants.RUNTIME_BASE_IMAGE_TAG') matches no
    # parameter, so the server-side override would never apply.
    # NOTE(review): falls back to the literal parameter name in case `constants`
    # does not define RUNTIME_BASE_IMAGE_TAG; confirm against constants.py.
    tag_param = getattr(constants, "RUNTIME_BASE_IMAGE_TAG", "RUNTIME_BASE_IMAGE_TAG")
    query_result = session.sql(f"SHOW PARAMETERS LIKE '{tag_param}' IN ACCOUNT").collect()
    # An empty value would yield an image reference without a tag; keep the default instead
    if query_result and query_result[0]["value"]:
        image_tag = query_result[0]["value"]

    # TODO: Should each instance consume the entire pod?
    return types.ImageSpec(
        repo=image_repo,
        image_name=image_name,
        image_tag=image_tag,
        resource_requests=resources,
        resource_limits=resources,
    )
51
+
52
+
53
def generate_spec_overrides(
    environment_vars: Optional[Dict[str, str]] = None,
    custom_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build a service specification patch from user-supplied overrides.

    Args:
        environment_vars: Environment variables to set in the primary container.
        custom_overrides: Arbitrary service specification overrides, merged on top of
            any container-level overrides.

    Returns:
        The resulting service specification patch dict; empty when no overrides were supplied.
    """
    overrides: Dict[str, Any] = {}

    # A container entry is only emitted when it carries more than its name
    if environment_vars:
        # TODO: Validate environment variables
        primary_container: Dict[str, Any] = {
            "name": constants.DEFAULT_CONTAINER_NAME,
            "env": environment_vars,
        }
        overrides = {"spec": {"containers": [primary_container]}}
    else:
        # Evaluated for parity with building the container entry unconditionally
        _ = constants.DEFAULT_CONTAINER_NAME

    # Layer caller-provided overrides on top
    if custom_overrides:
        overrides = merge_patch(overrides, custom_overrides, display_name="custom_overrides")

    return overrides
89
+
90
+
91
def generate_service_spec(
    session: snowpark.Session,
    compute_pool: str,
    payload: types.UploadedPayload,
    args: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Generate a service specification for a job.

    Args:
        session: Snowflake session
        compute_pool: Compute pool for job execution
        payload: Uploaded job payload
        args: Arguments to pass to entrypoint script

    Returns:
        Job service specification
    """
    image_spec = _get_image_spec(session, compute_pool)
    requested = image_spec.resource_requests
    limited = image_spec.resource_limits

    # Set resource requests/limits, including nvidia.com/gpu quantity if applicable.
    # Requests are taken from `resource_requests` and limits from `resource_limits`,
    # so the spec stays correct if the two ever differ.
    resource_requests: Dict[str, Union[str, int]] = {
        "cpu": f"{int(requested.cpu * 1000)}m",
        "memory": f"{requested.memory}Gi",
    }
    resource_limits: Dict[str, Union[str, int]] = {
        "cpu": f"{int(limited.cpu * 1000)}m",
        "memory": f"{limited.memory}Gi",
    }
    if limited.gpu > 0:
        resource_requests["nvidia.com/gpu"] = requested.gpu
        resource_limits["nvidia.com/gpu"] = limited.gpu

    # Add local volumes for ephemeral logs and artifacts
    volumes: List[Dict[str, str]] = []
    volume_mounts: List[Dict[str, str]] = []
    for volume_name, mount_path in [
        ("system-logs", "/var/log/managedservices/system/mlrs"),
        ("user-logs", "/var/log/managedservices/user/mlrs"),
    ]:
        volume_mounts.append({"name": volume_name, "mountPath": mount_path})
        volumes.append({"name": volume_name, "source": "local"})

    # Mount a memory-backed volume at /dev/shm, sized as a fraction
    # (constants.MEMORY_VOLUME_SIZE) of the memory limit and capped at the requested memory
    memory_volume_size = min(
        ceil(limited.memory * constants.MEMORY_VOLUME_SIZE),
        requested.memory,
    )
    volume_mounts.append(
        {
            "name": constants.MEMORY_VOLUME_NAME,
            "mountPath": "/dev/shm",
        }
    )
    volumes.append(
        {
            "name": constants.MEMORY_VOLUME_NAME,
            "source": "memory",
            "size": f"{memory_volume_size}Gi",
        }
    )

    # Mount payload as volume
    stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
    volume_mounts.append(
        {
            "name": constants.STAGE_VOLUME_NAME,
            "mountPath": stage_mount.as_posix(),
        }
    )
    volumes.append(
        {
            "name": constants.STAGE_VOLUME_NAME,
            "source": payload.stage_path.as_posix(),
        }
    )

    # TODO: Add hooks for endpoints for integration with TensorBoard etc

    # Entrypoint items given as paths are resolved relative to the payload mount point
    entrypoint_args = [stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v for v in payload.entrypoint]

    # Assemble into service specification dict
    spec = {
        "spec": {
            "containers": [
                {
                    "name": constants.DEFAULT_CONTAINER_NAME,
                    "image": image_spec.full_name,
                    "command": ["/usr/local/bin/_entrypoint.sh"],
                    "args": entrypoint_args + (args or []),
                    "env": {
                        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
                    },
                    "volumeMounts": volume_mounts,
                    "resources": {
                        "requests": resource_requests,
                        "limits": resource_limits,
                    },
                },
            ],
            "volumes": volumes,
        }
    }

    return spec
206
+
207
+
208
def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
    """
    Apply a modified RFC 7386 JSON Merge Patch to ``base``.

    See https://datatracker.ietf.org/doc/html/rfc7386

    Differences from the RFC:
      1. Nested dicts that end up empty after patching are treated as None and pruned.
      2. Lists whose items are all dicts are merged item-by-item on a merge key
         (default "name"); see _merge_lists_of_dicts for the details.

    Args:
        base: The object being patched. It is never mutated.
        patch: The patch to apply.
        display_name: Dotted path of the value being merged, used in log messages.

    Returns:
        The patched object.
    """
    if type(base) is not type(patch):
        # A None base just means "nothing to merge into"; only warn on real conflicts
        if base is not None:
            logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
        return patch

    if isinstance(patch, list):
        # Lists of dicts are merged by key; any other list replaces the base wholesale
        if all(isinstance(item, dict) for item in base + patch):
            # TODO: Should we prune empty lists?
            return _merge_lists_of_dicts(base, patch, display_name=display_name)
        return patch

    if not isinstance(patch, dict) or not patch:
        return patch

    merged = dict(base)  # Shallow copy so the caller's base is left untouched
    for key, value in patch.items():
        if value is None:
            merged.pop(key, None)
            continue
        child = merge_patch(merged.get(key), value, display_name=f"{display_name}.{key}")
        if isinstance(child, dict) and not child:
            merged.pop(key, None)
        else:
            merged[key] = child
    return merged
248
+
249
+
250
def _merge_lists_of_dicts(
    base: List[Dict[str, Any]], patch: List[Dict[str, Any]], merge_key: str = "name", display_name: str = ""
) -> List[Dict[str, Any]]:
    """
    Merge two lists of dicts by matching entries on ``merge_key``.

    Behavior:
      - If any entry in either list lacks ``merge_key``, or ``base`` has duplicate
        keys, the patch list simply replaces the base list.
      - Otherwise base entries keep their order. Patch entries whose key matches a
        base entry are merged into it via merge_patch, and unmatched patch entries
        are appended.
      - A patch entry containing a ``None`` key, e.g. ``{"name": "foo", None: None}``,
        removes the matching entry instead.

    Args:
        base: The base list of dicts.
        patch: The patch list of dicts.
        merge_key: Key whose value identifies matching entries.
        display_name: Path of the list being merged, used in log messages.

    Returns:
        The merged list, or ``patch`` itself when falling back to overwrite.
    """
    if not all(merge_key in entry for entry in base + patch):
        logging.warning(f"Missing merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
        return patch

    # Index base entries by merge key; insertion order preserves the base ordering
    by_key: Dict[Any, Dict[str, Any]] = {}
    for entry in base:
        by_key[entry[merge_key]] = entry
    if len(by_key) != len(base):
        logging.warning(f"Duplicate merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
        return patch

    for entry in patch:
        key = entry[merge_key]
        if None in entry:
            # Removal marker: drop the matching entry, if any
            by_key.pop(key, None)
        elif key in by_key:
            # TODO: Should we drop the item if the patch result is empty save for the merge key?
            by_key[key] = merge_patch(by_key[key], entry, display_name=f"{display_name}[{merge_key}={key}]")
        else:
            by_key[key] = entry

    return list(by_key.values())
@@ -0,0 +1,39 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import PurePath
3
+ from typing import List, Literal, Optional, Union
4
+
5
# Execution states a job can report; FAILED, DONE and INTERNAL_ERROR are terminal
JOB_STATUS = Literal["PENDING", "RUNNING", "FAILED", "DONE", "INTERNAL_ERROR"]
+ ]
12
+
13
+
14
@dataclass(frozen=True)
class UploadedPayload:
    """Immutable record of a job payload that has been uploaded to a stage."""

    # TODO: Include manifest of payload files for validation

    # Stage location holding the uploaded payload files
    stage_path: PurePath
    # Entrypoint command items; PurePath items are presumably relative to the payload root
    entrypoint: List[Union[str, PurePath]]
19
+
20
+
21
@dataclass(frozen=True)
class ComputeResources:
    """Immutable description of the compute capacity of a single node."""

    cpu: float  # vCPU cores
    memory: float  # Memory, in GiB
    gpu: int = 0  # GPU count; 0 means a CPU-only node
    gpu_type: Optional[str] = None  # Presumably the GPU model name; None when not specified
27
+
28
+
29
@dataclass(frozen=True)
class ImageSpec:
    """Container image reference plus the resources requested and allowed for it."""

    repo: str
    image_name: str
    image_tag: str
    resource_requests: ComputeResources
    resource_limits: ComputeResources

    @property
    def full_name(self) -> str:
        """The image reference in ``<repo>/<image_name>:<image_tag>`` form."""
        untagged = f"{self.repo}/{self.image_name}"
        return f"{untagged}:{self.image_tag}"
@@ -0,0 +1,71 @@
1
+ import copy
2
+ import functools
3
+ from typing import Callable, Dict, List, Optional, TypeVar
4
+
5
+ from typing_extensions import ParamSpec
6
+
7
+ from snowflake import snowpark
8
+ from snowflake.ml._internal import telemetry
9
+ from snowflake.ml.jobs import job as jb, manager as jm
10
+ from snowflake.ml.jobs._utils import constants
11
+
12
+ _PROJECT = "MLJob"
13
+
14
+ _Args = ParamSpec("_Args")
15
+ _ReturnValue = TypeVar("_ReturnValue")
16
+
17
+
18
@snowpark._internal.utils.private_preview(version="1.7.4")
@telemetry.send_api_usage_telemetry(project=_PROJECT)
def remote(
    compute_pool: str,
    stage_name: str,
    pip_requirements: Optional[List[str]] = None,
    external_access_integrations: Optional[List[str]] = None,
    query_warehouse: Optional[str] = None,
    env_vars: Optional[Dict[str, str]] = None,
    session: Optional[snowpark.Session] = None,
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
    """
    Create a decorator that runs the decorated function as a job on a compute pool.

    Each call to the decorated function submits a new job with the call's
    arguments bound, instead of running the function locally.

    Args:
        compute_pool: The compute pool to use for the job.
        stage_name: The name of the stage where the job payload will be uploaded.
        pip_requirements: A list of pip requirements for the job.
        external_access_integrations: A list of external access integrations.
        query_warehouse: The query warehouse to use. Defaults to session warehouse.
        env_vars: Environment variables to set in container
        session: The Snowpark session to use. If none specified, uses active session.

    Returns:
        Decorator that dispatches invocations of the decorated function as remote jobs.
    """

    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
        import types as pytypes

        # Build a genuine copy of the function whose first line number is shifted by
        # one, so the copied source excludes the decorator line. `copy.copy` returns the
        # *same* object for functions, so patching its `__code__` would instead mutate
        # the caller's function (shifting it again on every re-decoration).
        # NOTE(review): the +1 shift assumes the decorator occupies the single line
        # directly above `def`.
        shifted_code = func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
        wrapped_func = pytypes.FunctionType(
            shifted_code,
            func.__globals__,
            func.__name__,
            func.__defaults__,
            func.__closure__,
        )
        wrapped_func.__kwdefaults__ = func.__kwdefaults__
        wrapped_func.__qualname__ = func.__qualname__
        wrapped_func.__module__ = func.__module__
        wrapped_func.__doc__ = func.__doc__
        wrapped_func.__dict__ = copy.copy(func.__dict__)

        @functools.wraps(func)
        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
            # Bind this call's arguments; the bound callable is submitted as the job payload
            payload = functools.partial(wrapped_func, *args, **kwargs)
            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
            job = jm._submit_job(
                source=payload,
                stage_name=stage_name,
                compute_pool=compute_pool,
                pip_requirements=pip_requirements,
                external_access_integrations=external_access_integrations,
                query_warehouse=query_warehouse,
                env_vars=env_vars,
                session=session,
            )
            assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
            return job

        return wrapper

    return decorator
@@ -0,0 +1,113 @@
1
+ import time
2
+ from typing import Any, List, Optional, cast
3
+
4
+ from snowflake import snowpark
5
+ from snowflake.ml._internal import telemetry
6
+ from snowflake.ml.jobs._utils import constants, types
7
+ from snowflake.snowpark import context as sp_context
8
+
9
# Telemetry project name reported by APIs in this module
_PROJECT = "MLJob"
# Statuses after which a job no longer changes state, so they can be cached client-side
TERMINAL_JOB_STATUSES = {"DONE", "FAILED", "INTERNAL_ERROR"}
11
+
12
+
13
class MLJob:
    """Client-side handle for a submitted job, identified by its service ID."""

    def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
        self._id = id
        self._session = session or sp_context.get_active_session()
        # Cached status; refreshed from the backend until a terminal status is seen
        self._status: types.JOB_STATUS = "PENDING"

    @property
    def id(self) -> str:
        """Get the unique job ID"""
        return self._id

    @property
    def status(self) -> types.JOB_STATUS:
        """Get the job's execution status."""
        if self._status not in TERMINAL_JOB_STATUSES:
            # Query backend for job status if not in terminal state
            self._status = _get_status(self._session, self.id)
        return self._status

    @snowpark._internal.utils.private_preview(version="1.7.4")
    def get_logs(self, limit: int = -1) -> str:
        """
        Return the job's execution logs.

        Args:
            limit: The maximum number of lines to return. Zero or negative values are treated as no limit.

        Returns:
            The job's execution logs.
        """
        logs = _get_logs(self._session, self.id, limit)
        assert isinstance(logs, str)  # mypy
        return logs

    @snowpark._internal.utils.private_preview(version="1.7.4")
    def show_logs(self, limit: int = -1) -> None:
        """
        Display the job's execution logs.

        Args:
            limit: The maximum number of lines to display. Zero or negative values are treated as no limit.
        """
        print(self.get_logs(limit))  # noqa: T201: we need to print here.

    @snowpark._internal.utils.private_preview(version="1.7.4")
    @telemetry.send_api_usage_telemetry(project=_PROJECT)
    def wait(self, timeout: float = -1) -> types.JOB_STATUS:
        """
        Block until completion. Returns completion status.

        Args:
            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.

        Returns:
            The job's completion status.

        Raises:
            TimeoutError: If the job does not complete within the specified timeout.
        """
        delay = constants.JOB_POLL_INITIAL_DELAY_SECONDS  # Initial poll interval
        start_time = time.monotonic()
        while self.status not in TERMINAL_JOB_STATUSES:
            elapsed = time.monotonic() - start_time
            if timeout >= 0 and elapsed >= timeout:
                raise TimeoutError(f"Job {self.id} did not complete within {elapsed} seconds")
            # Never sleep past the deadline, so a timeout is reported promptly
            time.sleep(delay if timeout < 0 else min(delay, timeout - elapsed))
            delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
        return self.status
80
+
81
+
82
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
    """
    Fetch a job's current execution status from its service description.

    Raises ValueError (from unpacking) unless exactly one row is returned.
    """
    # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
    # `DESCRIBE` queries with bind variables
    # Switch to use bind variables instead of client side formatting after
    # updating to snowflake-snowpark-python>=1.24.0
    describe_query = f"DESCRIBE SERVICE {job_id}"
    [service_row] = session.sql(describe_query).collect()
    status = service_row["status"]
    return cast(types.JOB_STATUS, status)
91
+
92
+
93
@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
    """
    Retrieve the job's execution logs.

    Args:
        session: The Snowpark session used to run the query (required).
        job_id: The job ID.
        limit: The maximum number of lines to return. Zero or negative values are treated as no limit.

    Returns:
        The job's execution logs.
    """
    # Logs are read from instance 0 of the job's main container; the line limit is
    # passed as an extra bind argument only when one was requested.
    bind_values: List[Any] = [job_id]
    limit_placeholder = ""
    if limit > 0:
        bind_values.append(limit)
        limit_placeholder = ", ?"
    query = f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{limit_placeholder})"
    [log_row] = session.sql(query, params=bind_values).collect()
    return str(log_row[0])