snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/platform_capabilities.py +87 -0
  4. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  5. snowflake/ml/_internal/telemetry.py +21 -0
  6. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  7. snowflake/ml/dataset/dataset.py +0 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +6 -0
  11. snowflake/ml/jobs/__init__.py +21 -0
  12. snowflake/ml/jobs/_utils/constants.py +57 -0
  13. snowflake/ml/jobs/_utils/payload_utils.py +438 -0
  14. snowflake/ml/jobs/_utils/spec_utils.py +296 -0
  15. snowflake/ml/jobs/_utils/types.py +39 -0
  16. snowflake/ml/jobs/decorators.py +71 -0
  17. snowflake/ml/jobs/job.py +113 -0
  18. snowflake/ml/jobs/manager.py +298 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  21. snowflake/ml/model/_client/sql/service.py +13 -6
  22. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
  26. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  28. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  30. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  32. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  33. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  36. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  37. snowflake/ml/model/_signatures/base_handler.py +1 -2
  38. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  39. snowflake/ml/model/_signatures/core.py +2 -2
  40. snowflake/ml/model/_signatures/numpy_handler.py +11 -12
  41. snowflake/ml/model/_signatures/pandas_handler.py +11 -9
  42. snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
  43. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  44. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  45. snowflake/ml/model/model_signature.py +25 -4
  46. snowflake/ml/model/type_hints.py +15 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  51. snowflake/ml/modeling/cluster/birch.py +6 -3
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  53. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  55. snowflake/ml/modeling/cluster/k_means.py +6 -3
  56. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  58. snowflake/ml/modeling/cluster/optics.py +6 -3
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  62. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  69. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  70. snowflake/ml/modeling/covariance/oas.py +6 -3
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  74. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  79. snowflake/ml/modeling/decomposition/pca.py +6 -3
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  110. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  111. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  112. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  113. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  114. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  115. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  116. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  117. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  118. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  119. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  120. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  121. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  122. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  123. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  124. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  125. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/lars.py +6 -3
  128. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  129. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  130. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  132. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  133. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  134. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  135. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  136. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  137. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  139. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  140. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  141. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  142. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  143. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  144. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  145. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  146. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  148. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  149. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  151. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  152. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  153. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  154. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  155. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  156. snowflake/ml/modeling/manifold/isomap.py +6 -3
  157. snowflake/ml/modeling/manifold/mds.py +6 -3
  158. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  159. snowflake/ml/modeling/manifold/tsne.py +6 -3
  160. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  161. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  162. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  163. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  174. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  184. snowflake/ml/modeling/pipeline/pipeline.py +28 -3
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  188. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  189. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  190. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  191. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  192. snowflake/ml/modeling/svm/svc.py +6 -3
  193. snowflake/ml/modeling/svm/svr.py +6 -3
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  202. snowflake/ml/registry/registry.py +34 -4
  203. snowflake/ml/version.py +1 -1
  204. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
  205. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
  206. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  207. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  208. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py (new file)
@@ -0,0 +1,298 @@
+ import pathlib
+ import textwrap
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
+ from uuid import uuid4
+
+ import yaml
+
+ from snowflake import snowpark
+ from snowflake.ml._internal import telemetry
+ from snowflake.ml._internal.utils import identifier
+ from snowflake.ml.jobs import job as jb
+ from snowflake.ml.jobs._utils import payload_utils, spec_utils
+ from snowflake.snowpark.context import get_active_session
+ from snowflake.snowpark.exceptions import SnowparkSQLException
+
+ _PROJECT = "MLJob"
+ JOB_ID_PREFIX = "MLJOB_"
+
+
+ @snowpark._internal.utils.private_preview(version="1.7.4")
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
+ def list_jobs(
+     limit: int = 10,
+     scope: Union[Literal["account", "database", "schema"], str, None] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> snowpark.DataFrame:
+     """
+     Returns a Snowpark DataFrame with the list of jobs in the current session.
+
+     Args:
+         limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
+         scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+         session: The Snowpark session to use. If none specified, uses active session.
+
+     Returns:
+         A DataFrame with the list of jobs.
+
+     Examples:
+         >>> from snowflake.ml.jobs import list_jobs
+         >>> list_jobs(limit=5).show()
+     """
+     session = session or get_active_session()
+     query = "SHOW JOB SERVICES"
+     query += f" LIKE '{JOB_ID_PREFIX}%'"
+     if scope:
+         query += f" IN {scope}"
+     if limit > 0:
+         query += f" LIMIT {limit}"
+     df = session.sql(query)
+     df = df.select(
+         df['"name"'].alias('"id"'),
+         df['"owner"'],
+         df['"status"'],
+         df['"created_on"'],
+         df['"compute_pool"'],
+     ).order_by('"created_on"', ascending=False)
+     return df
+
+
+ @snowpark._internal.utils.private_preview(version="1.7.4")
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+     """Retrieve a job service from the backend."""
+     session = session or get_active_session()
+
+     try:
+         # Validate job_id
+         job_id = identifier.resolve_identifier(job_id)
+     except ValueError as e:
+         raise ValueError(f"Invalid job ID: {job_id}") from e
+
+     try:
+         # Validate that job exists by doing a status check
+         job = jb.MLJob(job_id, session=session)
+         _ = job.status
+         return job
+     except SnowparkSQLException as e:
+         if "does not exist" in e.message:
+             raise ValueError(f"Job does not exist: {job_id}") from e
+         raise
+
+
+ @snowpark._internal.utils.private_preview(version="1.7.4")
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+     """Delete a job service from the backend. Status and logs will be lost."""
+     if isinstance(job, jb.MLJob):
+         job_id = job.id
+         session = job._session or session
+     else:
+         job_id = job
+     session = session or get_active_session()
+     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+
+ @snowpark._internal.utils.private_preview(version="1.7.4")
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def submit_file(
+     file_path: str,
+     compute_pool: str,
+     *,
+     stage_name: str,
+     args: Optional[List[str]] = None,
+     env_vars: Optional[Dict[str, str]] = None,
+     pip_requirements: Optional[List[str]] = None,
+     external_access_integrations: Optional[List[str]] = None,
+     query_warehouse: Optional[str] = None,
+     spec_overrides: Optional[Dict[str, Any]] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob:
+     """
+     Submit a Python file as a job to the compute pool.
+
+     Args:
+         file_path: The path to the file containing the source code for the job.
+         compute_pool: The compute pool to use for the job.
+         stage_name: The name of the stage where the job payload will be uploaded.
+         args: A list of arguments to pass to the job.
+         env_vars: Environment variables to set in container
+         pip_requirements: A list of pip requirements for the job.
+         external_access_integrations: A list of external access integrations.
+         query_warehouse: The query warehouse to use. Defaults to session warehouse.
+         spec_overrides: Custom service specification overrides to apply.
+         session: The Snowpark session to use. If none specified, uses active session.
+
+     Returns:
+         An object representing the submitted job.
+     """
+     return _submit_job(
+         source=file_path,
+         args=args,
+         compute_pool=compute_pool,
+         stage_name=stage_name,
+         env_vars=env_vars,
+         pip_requirements=pip_requirements,
+         external_access_integrations=external_access_integrations,
+         query_warehouse=query_warehouse,
+         spec_overrides=spec_overrides,
+         session=session,
+     )
+
+
+ @snowpark._internal.utils.private_preview(version="1.7.4")
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def submit_directory(
+     dir_path: str,
+     compute_pool: str,
+     *,
+     entrypoint: str,
+     stage_name: str,
+     args: Optional[List[str]] = None,
+     env_vars: Optional[Dict[str, str]] = None,
+     pip_requirements: Optional[List[str]] = None,
+     external_access_integrations: Optional[List[str]] = None,
+     query_warehouse: Optional[str] = None,
+     spec_overrides: Optional[Dict[str, Any]] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob:
+     """
+     Submit a directory containing Python script(s) as a job to the compute pool.
+
+     Args:
+         dir_path: The path to the directory containing the job payload.
+         compute_pool: The compute pool to use for the job.
+         entrypoint: The relative path to the entry point script inside the source directory.
+         stage_name: The name of the stage where the job payload will be uploaded.
+         args: A list of arguments to pass to the job.
+         env_vars: Environment variables to set in container
+         pip_requirements: A list of pip requirements for the job.
+         external_access_integrations: A list of external access integrations.
+         query_warehouse: The query warehouse to use. Defaults to session warehouse.
+         spec_overrides: Custom service specification overrides to apply.
+         session: The Snowpark session to use. If none specified, uses active session.
+
+     Returns:
+         An object representing the submitted job.
+     """
+     return _submit_job(
+         source=dir_path,
+         entrypoint=entrypoint,
+         args=args,
+         compute_pool=compute_pool,
+         stage_name=stage_name,
+         env_vars=env_vars,
+         pip_requirements=pip_requirements,
+         external_access_integrations=external_access_integrations,
+         query_warehouse=query_warehouse,
+         spec_overrides=spec_overrides,
+         session=session,
+     )
+
+
+ @telemetry.send_api_usage_telemetry(
+     project=_PROJECT,
+     func_params_to_log=[
+         # TODO: Log the source type (callable, file, directory, etc)
+         # TODO: Log instance type of compute pool used
+         # TODO: Log lengths of args, env_vars, and spec_overrides values
+         "pip_requirements",
+         "external_access_integrations",
+     ],
+ )
+ def _submit_job(
+     source: Union[str, Callable[..., Any]],
+     compute_pool: str,
+     *,
+     stage_name: str,
+     entrypoint: Optional[str] = None,
+     args: Optional[List[str]] = None,
+     env_vars: Optional[Dict[str, str]] = None,
+     pip_requirements: Optional[List[str]] = None,
+     external_access_integrations: Optional[List[str]] = None,
+     query_warehouse: Optional[str] = None,
+     spec_overrides: Optional[Dict[str, Any]] = None,
+     session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob:
+     """
+     Submit a job to the compute pool.
+
+     Args:
+         source: The file/directory path containing payload source code or a serializable Python callable.
+         compute_pool: The compute pool to use for the job.
+         stage_name: The name of the stage where the job payload will be uploaded.
+         entrypoint: The entry point for the job execution. Required if source is a directory.
+         args: A list of arguments to pass to the job.
+         env_vars: Environment variables to set in container
+         pip_requirements: A list of pip requirements for the job.
+         external_access_integrations: A list of external access integrations.
+         query_warehouse: The query warehouse to use. Defaults to session warehouse.
+         spec_overrides: Custom service specification overrides to apply.
+         session: The Snowpark session to use. If none specified, uses active session.
+
+     Returns:
+         An object representing the submitted job.
+
+     Raises:
+         RuntimeError: If required Snowflake features are not enabled.
+     """
+     session = session or get_active_session()
+     job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+     stage_name = "@" + stage_name.lstrip("@").rstrip("/")
+     stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
+
+     # Upload payload
+     uploaded_payload = payload_utils.JobPayload(
+         source,
+         entrypoint=entrypoint,
+         pip_requirements=pip_requirements,
+     ).upload(session, stage_path)
+
+     # Generate service spec
+     spec = spec_utils.generate_service_spec(
+         session,
+         compute_pool=compute_pool,
+         payload=uploaded_payload,
+         args=args,
+     )
+     spec_overrides = spec_utils.generate_spec_overrides(
+         environment_vars=env_vars,
+         custom_overrides=spec_overrides,
+     )
+     if spec_overrides:
+         spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
+
+     # Generate SQL command for job submission
+     query_template = textwrap.dedent(
+         f"""\
+         EXECUTE JOB SERVICE
+         IN COMPUTE POOL {compute_pool}
+         FROM SPECIFICATION $$
+         {{}}
+         $$
+         NAME = {job_id}
+         ASYNC = TRUE
+         """
+     )
+     query = query_template.format(yaml.dump(spec)).splitlines()
+     if external_access_integrations:
+         external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
+         query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
+     query_warehouse = query_warehouse or session.get_current_warehouse()
+     if query_warehouse:
+         query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+
+     # Submit job
+     query_text = "\n".join(line for line in query if line)
+
+     try:
+         _ = session.sql(query_text).collect()
+     except SnowparkSQLException as e:
+         if "invalid property 'ASYNC'" in e.message:
+             raise RuntimeError(
+                 "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
+             ) from e
+         raise
+
+     # TODO: Wrap snowflake.core.service.JobService object
+     return jb.MLJob(job_id, session=session)
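
Taken together, the new snowflake.ml.jobs manager exposes a small job-lifecycle API: submit, list, fetch, delete. A minimal usage sketch, assuming an active Snowpark session plus an existing compute pool and stage; all names below are placeholders:

    from snowflake.ml import jobs

    # Submit a local script as an async SPCS job (private preview as of 1.7.4).
    job = jobs.submit_file(
        "train.py",                   # placeholder script path
        "MY_COMPUTE_POOL",            # placeholder compute pool
        stage_name="PAYLOAD_STAGE",   # payload is uploaded under @PAYLOAD_STAGE/<job_id>
        args=["--epochs", "10"],
        pip_requirements=["xgboost"],
    )

    # Inspect recent jobs, re-fetch this one by ID, then clean up.
    jobs.list_jobs(limit=5).show()
    same_job = jobs.get_job(job.id)
    print(same_job.status)
    jobs.delete_job(same_job)  # drops the backing service; status and logs are lost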
snowflake/ml/model/_client/ops/model_ops.py
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils

  class ServiceInfo(TypedDict):
      name: str
+     status: str
      inference_endpoint: Optional[str]


@@ -550,9 +551,13 @@ class ModelOperator:
          fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]

          result = []
-         ingress_url: Optional[str] = None
+
          for fully_qualified_service_name in fully_qualified_service_names:
+             ingress_url: Optional[str] = None
              db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+             service_status, _ = self._service_client.get_service_status(
+                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+             )
              for res_row in self._service_client.show_endpoints(
                  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
              ):
@@ -566,7 +571,11 @@ class ModelOperator:
                  )
                  if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
                      ingress_url = None
-             result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+             result.append(
+                 ServiceInfo(
+                     name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                 )
+             )

          return result
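
With this change, each ServiceInfo returned by the services listing carries a status value (taken from get_service_status) alongside the name and the optional ingress endpoint. A sketch of consuming the result, where infos is a hypothetical variable bound to the returned list:

    for info in infos:
        # inference_endpoint is None when no public ingress endpoint is exposed
        print(info["name"], info["status"], info["inference_endpoint"])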
 
snowflake/ml/model/_client/ops/service_ops.py
@@ -8,11 +8,9 @@ import threading
  import time
  from typing import Any, Dict, List, Optional, Tuple, Union, cast

- from packaging import version
-
  from snowflake import snowpark
  from snowflake.ml._internal import file_utils
- from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
+ from snowflake.ml._internal.utils import service_logger, sql_identifier
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
  from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
          )
          stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)

-         # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
-         if (
-             snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
-             < version.parse("8.40.0")
-             and build_external_access_integrations is None
-         ):
-             raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
-
          self._model_deployment_spec.save(
              database_name=database_name,
              schema_name=schema_name,
snowflake/ml/model/_client/sql/service.py
@@ -4,6 +4,7 @@ import textwrap
  from typing import Any, Dict, List, Optional, Tuple

  from snowflake import snowpark
+ from snowflake.ml._internal import platform_capabilities
  from snowflake.ml._internal.utils import (
      identifier,
      query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
              args_sql_list.append(input_arg_value)
          args_sql = ", ".join(args_sql_list)

-         function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-         fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-             actual_database_name.identifier(),
-             actual_schema_name.identifier(),
-             function_name,
-         )
+         if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
+             fully_qualified_service_name = self.fully_qualified_object_name(
+                 actual_database_name, actual_schema_name, service_name
+             )
+             fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
+         else:
+             function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
+             fully_qualified_function_name = identifier.get_schema_level_object_identifier(
+                 actual_database_name.identifier(),
+                 actual_schema_name.identifier(),
+                 function_name,
+             )

          sql = textwrap.dedent(
              f"""{with_sql}
snowflake/ml/model/_packager/model_env/model_env.py
@@ -113,7 +113,33 @@ class ModelEnv:
          self._snowpark_ml_version = version.parse(snowpark_ml_version)

      def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
-         """Append requirements into model env if absent.
+         """Append requirements into model env if absent. Depending on the environment, requirements may be added
+         to either the pip requirements or conda dependencies.
+
+         Args:
+             pkgs: A list of ModelDependency namedtuple to be appended.
+             check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
+         """
+         if self.pip_requirements and not self.conda_dependencies and pkgs:
+             pip_pkg_reqs: List[str] = []
+             warnings.warn(
+                 (
+                     "Dependencies specified from pip requirements."
+                     " This may prevent model deploying to Snowflake Warehouse."
+                 ),
+                 category=UserWarning,
+                 stacklevel=2,
+             )
+             for conda_req_str, pip_name in pkgs:
+                 _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
+                 pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
+                 pip_pkg_reqs.append(str(pip_req))
+             self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
+         else:
+             self._include_if_absent_conda(pkgs, check_local_version)
+
+     def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+         """Append requirements into model env conda dependencies if absent.

          Args:
              pkgs: A list of ModelDependency namedtuple to be appended.
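
The routing rule above: if the environment was specified purely via pip requirements, new dependencies are re-expressed as pip requirements (reusing the conda specifier); otherwise they are added to the conda dependencies as before. A standalone sketch of the conversion step under stated assumptions (a channel-free spec string, using packaging directly instead of the internal env_utils validator):

    from packaging import requirements

    conda_req_str, pip_name = "xgboost==1.7.6", "xgboost"  # hypothetical ModelDependency
    specifier = requirements.Requirement(conda_req_str).specifier
    pip_req = requirements.Requirement(f"{pip_name}{specifier}")
    print(str(pip_req))  # xgboost==1.7.6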
@@ -134,8 +160,8 @@ class ModelEnv:
              if show_warning_message:
                  warnings.warn(
                      (
-                         f"Basic dependency {req_to_add.name} specified from PIP requirements."
-                         + " This may prevent model deploying to Snowflake Warehouse."
+                         f"Basic dependency {req_to_add.name} specified from pip requirements."
+                         " This may prevent model deploying to Snowflake Warehouse."
                      ),
                      category=UserWarning,
                      stacklevel=2,
@@ -157,11 +183,11 @@ class ModelEnv:
                  stacklevel=2,
              )

-     def include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
-         """Append pip requirements into model env if absent.
+     def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
+         """Append pip requirements into model env pip requirements if absent.

          Args:
-             pkgs: A list of string to be appended in pip requirement.
+             pkgs: A list of strings to be appended to pip environment.
              check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
          """

@@ -187,25 +213,6 @@ class ModelEnv:
              self._conda_dependencies[channel].remove(spec)

      def generate_env_for_cuda(self) -> None:
-         if self.cuda_version is None:
-             return
-
-         cuda_spec = env_utils.find_dep_spec(
-             self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
-         )
-         if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
-             raise ValueError(
-                 "The CUDA requirement you specified in your conda dependencies or pip requirements is"
-                 " conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
-                 " dependencies or pip requirements."
-             )
-
-         if not cuda_spec:
-             self.include_if_absent(
-                 [ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
-                 check_local_version=False,
-             )
-
          xgboost_spec = env_utils.find_dep_spec(
              self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
          )
@@ -236,7 +243,7 @@ class ModelEnv:
              check_local_version=False,
          )

-         self.include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
+         self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)

      def relax_version(self) -> None:
          """Relax the version requirements for both conda dependencies and pip requirements.
@@ -252,7 +259,9 @@ class ModelEnv:
          self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))

      def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
-         conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(conda_env_path)
+         conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
+             conda_env_path
+         )

          for channel, channel_dependencies in conda_dependencies_dict.items():
              if channel != env_utils.DEFAULT_CHANNEL_NAME:
@@ -310,6 +319,9 @@ class ModelEnv:
          if python_version:
              self.python_version = python_version

+         if cuda_version:
+             self.cuda_version = cuda_version
+
      def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
          pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)

@@ -342,12 +354,17 @@ class ModelEnv:
          self.snowpark_ml_version = env_dict["snowpark_ml_version"]

      def save_as_dict(
-         self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+         self,
+         base_dir: pathlib.Path,
+         default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
+         is_gpu: Optional[bool] = False,
      ) -> model_meta_schema.ModelEnvDict:
+         cuda_version = self.cuda_version if is_gpu else None
          env_utils.save_conda_env_file(
              pathlib.Path(base_dir / self.conda_env_rel_path),
              self._conda_dependencies,
              self.python_version,
+             cuda_version,
              default_channel_override=default_channel_override,
          )
          env_utils.save_requirements_file(
snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -38,13 +38,17 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
      return callable(getattr(model, method_name, None))


- def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
-     trunc_sample_input = model_signature._truncate_data(sample_input_data)
+ def get_truncated_sample_data(
+     sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
+ ) -> model_types.SupportedLocalDataType:
+     trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
      local_sample_input: model_types.SupportedLocalDataType = None
      if isinstance(sample_input_data, SnowparkDataFrame):
          # Added because of Any from missing stubs.
          trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
          local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
+         if is_for_modeling_model:
+             local_sample_input.columns = trunc_sample_input.columns
      else:
          local_sample_input = trunc_sample_input
      return local_sample_input
@@ -56,13 +60,15 @@ def validate_signature(
      target_methods: Iterable[str],
      sample_input_data: Optional[model_types.SupportedDataType],
      get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+     is_for_modeling_model: bool = False,
  ) -> model_meta.ModelMetadata:
      if model_meta.signatures:
          validate_target_methods(model, list(model_meta.signatures.keys()))
          if sample_input_data is not None:
-             local_sample_input = get_truncated_sample_data(sample_input_data)
+             local_sample_input = get_truncated_sample_data(
+                 sample_input_data, is_for_modeling_model=is_for_modeling_model
+             )
          for target_method in model_meta.signatures.keys():
-
              model_signature_inst = model_meta.signatures.get(target_method)
              if model_signature_inst is not None:
                  # strict validation the input signature
@@ -75,10 +81,17 @@ def validate_signature(
      assert (
          sample_input_data is not None
      ), "Model signature and sample input are None at the same time. This should not happen with local model."
-     local_sample_input = get_truncated_sample_data(sample_input_data)
+     local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
      for target_method in target_methods:
          predictions_df = get_prediction_fn(target_method, local_sample_input)
-         sig = model_signature.infer_signature(local_sample_input, predictions_df)
+         sig = model_signature.infer_signature(
+             sample_input_data,
+             predictions_df,
+             input_feature_names=None,
+             output_feature_names=None,
+             input_data_limit=100,
+             output_data_limit=100,
+         )
          model_meta.signatures[target_method] = sig

      return model_meta
snowflake/ml/model/_packager/model_handlers/custom.py
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
          sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)

      if inspect.iscoroutinefunction(target_method):
-         with anyio.start_blocking_portal() as portal:
+         with anyio.from_thread.start_blocking_portal() as portal:
              predictions_df = portal.call(target_method, model, sample_input_data)
      else:
          predictions_df = target_method(model, sample_input_data)
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
      if model.context.model_refs:
          for sub_name, model_ref in model.context.model_refs.items():
              handler = model_handler.find_handler(model_ref.model)
-             assert handler is not None
              if handler is None:
                  raise TypeError("Your input type to custom model is not currently supported")
              sub_model = handler.cast_model(model_ref.model)
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -146,6 +146,10 @@ class HuggingFacePipelineHandler(
          framework = getattr(model, "framework", None)
          batch_size = getattr(model, "batch_size", None)

+         has_tokenizer = getattr(model, "tokenizer", None) is not None
+         has_feature_extractor = getattr(model, "feature_extractor", None) is not None
+         has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
+
          if type_utils.LazyType("transformers.Pipeline").isinstance(model):
              params = {
                  **model._preprocess_params,  # type:ignore[attr-defined]
@@ -234,6 +238,9 @@ class HuggingFacePipelineHandler(
                  {
                      "task": task,
                      "batch_size": batch_size if batch_size is not None else 1,
+                     "has_tokenizer": has_tokenizer,
+                     "has_feature_extractor": has_feature_extractor,
+                     "has_image_preprocessor": has_image_preprocessor,
                  }
              ),
          )
@@ -308,6 +315,14 @@ class HuggingFacePipelineHandler(
          if os.path.isdir(model_blob_file_or_dir_path):
              import transformers

+             additional_pipeline_params = {}
+             if model_blob_options.get("has_tokenizer", False):
+                 additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
+             if model_blob_options.get("has_feature_extractor", False):
+                 additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
+             if model_blob_options.get("has_image_preprocessor", False):
+                 additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
+
              with open(
                  os.path.join(
                      model_blob_file_or_dir_path,
@@ -323,6 +338,8 @@ class HuggingFacePipelineHandler(
                  model_blob_options["task"],
                  model=model_blob_file_or_dir_path,
                  trust_remote_code=True,
+                 torch_dtype="auto",
+                 **additional_pipeline_params,
                  **device_config,
              )