snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff compares publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (187)
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/platform_capabilities.py +87 -0
  3. snowflake/ml/dataset/dataset.py +0 -1
  4. snowflake/ml/fileset/fileset.py +6 -0
  5. snowflake/ml/jobs/__init__.py +21 -0
  6. snowflake/ml/jobs/_utils/constants.py +51 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +352 -0
  8. snowflake/ml/jobs/_utils/spec_utils.py +298 -0
  9. snowflake/ml/jobs/_utils/types.py +39 -0
  10. snowflake/ml/jobs/decorators.py +91 -0
  11. snowflake/ml/jobs/job.py +113 -0
  12. snowflake/ml/jobs/manager.py +298 -0
  13. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  14. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  15. snowflake/ml/model/_client/sql/service.py +13 -6
  16. snowflake/ml/model/_packager/model_handlers/_utils.py +12 -3
  17. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  18. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -0
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  20. snowflake/ml/model/_signatures/base_handler.py +1 -2
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  22. snowflake/ml/model/_signatures/numpy_handler.py +6 -7
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -2
  24. snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
  25. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  26. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  27. snowflake/ml/model/model_signature.py +17 -4
  28. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  29. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  30. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  31. snowflake/ml/modeling/cluster/birch.py +6 -3
  32. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  33. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  34. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  35. snowflake/ml/modeling/cluster/k_means.py +6 -3
  36. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  37. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  38. snowflake/ml/modeling/cluster/optics.py +6 -3
  39. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  40. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  41. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  42. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  43. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  44. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  45. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  46. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  47. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  48. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  49. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  50. snowflake/ml/modeling/covariance/oas.py +6 -3
  51. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  52. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  53. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  54. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  55. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  56. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  57. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  58. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  59. snowflake/ml/modeling/decomposition/pca.py +6 -3
  60. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  61. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  62. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  63. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  64. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  65. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  66. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  67. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  68. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  69. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  70. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  71. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  72. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  73. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  74. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  75. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  76. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  77. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  78. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  79. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  80. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  81. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  82. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  83. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  84. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  85. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  86. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  87. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  88. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  89. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  90. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  91. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  92. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  93. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  94. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  95. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  96. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  97. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  98. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  99. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  100. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  101. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  102. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  103. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  104. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  105. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  106. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  107. snowflake/ml/modeling/linear_model/lars.py +6 -3
  108. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  109. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  114. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  115. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  116. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  117. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  118. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  119. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  120. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  121. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  122. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  123. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  124. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  125. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  128. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  129. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  130. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  132. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  133. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  134. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  135. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  136. snowflake/ml/modeling/manifold/isomap.py +6 -3
  137. snowflake/ml/modeling/manifold/mds.py +6 -3
  138. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  139. snowflake/ml/modeling/manifold/tsne.py +6 -3
  140. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  141. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  142. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  143. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  144. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  145. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  146. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  147. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  148. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  149. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  150. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  151. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  152. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  153. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  154. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  155. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  156. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  157. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  158. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  159. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  160. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  161. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  162. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  163. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  164. snowflake/ml/modeling/pipeline/pipeline.py +10 -2
  165. snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
  166. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  167. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  168. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  169. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  170. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  171. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  172. snowflake/ml/modeling/svm/svc.py +6 -3
  173. snowflake/ml/modeling/svm/svr.py +6 -3
  174. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  175. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  176. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  177. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  178. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  179. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  180. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  181. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  182. snowflake/ml/version.py +1 -1
  183. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +29 -14
  184. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +187 -178
  185. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
  186. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +0 -0
  187. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py (new file)
@@ -0,0 +1,298 @@
+import pathlib
+import textwrap
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from uuid import uuid4
+
+import yaml
+
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
+from snowflake.ml.jobs import job as jb
+from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.snowpark.context import get_active_session
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+_PROJECT = "MLJob"
+JOB_ID_PREFIX = "MLJOB_"
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
+def list_jobs(
+    limit: int = 10,
+    scope: Union[Literal["account", "database", "schema"], str, None] = None,
+    session: Optional[snowpark.Session] = None,
+) -> snowpark.DataFrame:
+    """
+    Returns a Snowpark DataFrame with the list of jobs in the current session.
+
+    Args:
+        limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
+        scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        A DataFrame with the list of jobs.
+
+    Examples:
+        >>> from snowflake.ml.jobs import list_jobs
+        >>> list_jobs(limit=5).show()
+    """
+    session = session or get_active_session()
+    query = "SHOW JOB SERVICES"
+    query += f" LIKE '{JOB_ID_PREFIX}%'"
+    if scope:
+        query += f" IN {scope}"
+    if limit > 0:
+        query += f" LIMIT {limit}"
+    df = session.sql(query)
+    df = df.select(
+        df['"name"'].alias('"id"'),
+        df['"owner"'],
+        df['"status"'],
+        df['"created_on"'],
+        df['"compute_pool"'],
+    ).order_by('"created_on"', ascending=False)
+    return df
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+    """Retrieve a job service from the backend."""
+    session = session or get_active_session()
+
+    try:
+        # Validate job_id
+        job_id = identifier.resolve_identifier(job_id)
+    except ValueError as e:
+        raise ValueError(f"Invalid job ID: {job_id}") from e
+
+    try:
+        # Validate that job exists by doing a status check
+        job = jb.MLJob(job_id, session=session)
+        _ = job.status
+        return job
+    except SnowparkSQLException as e:
+        if "does not exist" in e.message:
+            raise ValueError(f"Job does not exist: {job_id}") from e
+        raise
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+    """Delete a job service from the backend. Status and logs will be lost."""
+    if isinstance(job, jb.MLJob):
+        job_id = job.id
+        session = job._session or session
+    else:
+        job_id = job
+    session = session or get_active_session()
+    session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_file(
+    file_path: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a Python file as a job to the compute pool.
+
+    Args:
+        file_path: The path to the file containing the source code for the job.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in the container.
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=file_path,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@snowpark._internal.utils.private_preview(version="1.7.4")
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_directory(
+    dir_path: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        dir_path: The path to the directory containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: The relative path to the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in the container.
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=dir_path,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
+        env_vars=env_vars,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
+        query_warehouse=query_warehouse,
+        spec_overrides=spec_overrides,
+        session=session,
+    )
+
+
+@telemetry.send_api_usage_telemetry(
+    project=_PROJECT,
+    func_params_to_log=[
+        # TODO: Log the source type (callable, file, directory, etc)
+        # TODO: Log instance type of compute pool used
+        # TODO: Log lengths of args, env_vars, and spec_overrides values
+        "pip_requirements",
+        "external_access_integrations",
+    ],
+)
+def _submit_job(
+    source: Union[str, Callable[..., Any]],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob:
+    """
+    Submit a job to the compute pool.
+
+    Args:
+        source: The file/directory path containing payload source code or a serializable Python callable.
+        compute_pool: The compute pool to use for the job.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        entrypoint: The entry point for the job execution. Required if source is a directory.
+        args: A list of arguments to pass to the job.
+        env_vars: Environment variables to set in the container.
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        query_warehouse: The query warehouse to use. Defaults to session warehouse.
+        spec_overrides: Custom service specification overrides to apply.
+        session: The Snowpark session to use. If none specified, uses active session.
+
+    Returns:
+        An object representing the submitted job.
+
+    Raises:
+        RuntimeError: If required Snowflake features are not enabled.
+    """
+    session = session or get_active_session()
+    job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
+    stage_name = "@" + stage_name.lstrip("@").rstrip("/")
+    stage_path = pathlib.PurePosixPath(f"{stage_name}/{job_id}")
+
+    # Upload payload
+    uploaded_payload = payload_utils.JobPayload(
+        source,
+        entrypoint=entrypoint,
+        pip_requirements=pip_requirements,
+    ).upload(session, stage_path)
+
+    # Generate service spec
+    spec = spec_utils.generate_service_spec(
+        session,
+        compute_pool=compute_pool,
+        payload=uploaded_payload,
+        args=args,
+    )
+    spec_overrides = spec_utils.generate_spec_overrides(
+        environment_vars=env_vars,
+        custom_overrides=spec_overrides,
+    )
+    if spec_overrides:
+        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
+
+    # Generate SQL command for job submission
+    query_template = textwrap.dedent(
+        f"""\
+        EXECUTE JOB SERVICE
+        IN COMPUTE POOL {compute_pool}
+        FROM SPECIFICATION $$
+        {{}}
+        $$
+        NAME = {job_id}
+        ASYNC = TRUE
+        """
+    )
+    query = query_template.format(yaml.dump(spec)).splitlines()
+    if external_access_integrations:
+        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
+        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
+    query_warehouse = query_warehouse or session.get_current_warehouse()
+    if query_warehouse:
+        query.append(f"QUERY_WAREHOUSE = {query_warehouse}")
+
+    # Submit job
+    query_text = "\n".join(line for line in query if line)
+
+    try:
+        _ = session.sql(query_text).collect()
+    except SnowparkSQLException as e:
+        if "invalid property 'ASYNC'" in e.message:
+            raise RuntimeError(
+                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
+            ) from e
+        raise
+
+    # TODO: Wrap snowflake.core.service.JobService object
+    return jb.MLJob(job_id, session=session)
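
A usage sketch of this new private-preview API. The compute pool MY_POOL and stage PAYLOAD_STAGE are hypothetical names; the imports assume the functions are re-exported from snowflake.ml.jobs (a new __init__.py also ships in this release, and list_jobs' own docstring imports from there):

    from snowflake.ml.jobs import get_job, list_jobs, submit_file

    # Upload train.py to @PAYLOAD_STAGE and run it asynchronously on MY_POOL.
    job = submit_file(
        "train.py",
        "MY_POOL",
        stage_name="PAYLOAD_STAGE",
        args=["--epochs", "10"],
        env_vars={"LOG_LEVEL": "DEBUG"},
    )

    list_jobs(limit=5).show()   # newest MLJOB_* services first
    same_job = get_job(job.id)  # re-attach by id; raises ValueError if the job is gone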

snowflake/ml/model/_client/ops/model_ops.py
@@ -33,6 +33,7 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 
 class ServiceInfo(TypedDict):
     name: str
+    status: str
     inference_endpoint: Optional[str]
 
 
@@ -550,9 +551,13 @@ class ModelOperator:
         fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
 
         result = []
-        ingress_url: Optional[str] = None
+
         for fully_qualified_service_name in fully_qualified_service_names:
+            ingress_url: Optional[str] = None
             db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
+            service_status, _ = self._service_client.get_service_status(
+                database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
+            )
             for res_row in self._service_client.show_endpoints(
                 database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
             ):
@@ -566,7 +571,11 @@ class ModelOperator:
                     )
                 if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
                     ingress_url = None
-            result.append(ServiceInfo(name=fully_qualified_service_name, inference_endpoint=ingress_url))
+            result.append(
+                ServiceInfo(
+                    name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                )
+            )
 
         return result
 
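Since each ServiceInfo entry now carries a status alongside the endpoint, callers can surface both. A minimal sketch, assuming `services` holds the list returned by the method above:

    for info in services:
        # ServiceInfo fields per the diff: name, status, inference_endpoint
        print(info["name"], info["status"], info["inference_endpoint"])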

snowflake/ml/model/_client/ops/service_ops.py
@@ -8,11 +8,9 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
-from packaging import version
-
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils
-from snowflake.ml._internal.utils import service_logger, snowflake_env, sql_identifier
+from snowflake.ml._internal.utils import service_logger, sql_identifier
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
 from snowflake.snowpark import async_job, exceptions, row, session
@@ -133,14 +131,6 @@ class ServiceOperator:
         )
         stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
 
-        # TODO(hayu): Remove the version check after Snowflake 8.40.0 release
-        if (
-            snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
-            < version.parse("8.40.0")
-            and build_external_access_integrations is None
-        ):
-            raise ValueError("External access integrations are required in Snowflake < 8.40.0.")
-
         self._model_deployment_spec.save(
             database_name=database_name,
             schema_name=schema_name,

snowflake/ml/model/_client/sql/service.py
@@ -4,6 +4,7 @@ import textwrap
 from typing import Any, Dict, List, Optional, Tuple
 
 from snowflake import snowpark
+from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -120,12 +121,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
             args_sql_list.append(input_arg_value)
         args_sql = ", ".join(args_sql_list)
 
-        function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-        fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-            actual_database_name.identifier(),
-            actual_schema_name.identifier(),
-            function_name,
-        )
+        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
+            fully_qualified_service_name = self.fully_qualified_object_name(
+                actual_database_name, actual_schema_name, service_name
+            )
+            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
+        else:
+            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
+            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
+                actual_database_name.identifier(),
+                actual_schema_name.identifier(),
+                function_name,
+            )
 
         sql = textwrap.dedent(
             f"""{with_sql}
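For context, the two branches above build differently shaped call targets; a hedged sketch with hypothetical names (DB, SCHEMA, MY_SERVICE, PREDICT):

    # Nested service functions enabled -> the method is addressed on the service itself:
    #     DB.SCHEMA.MY_SERVICE!PREDICT
    # Fallback -> a schema-level function whose name concatenates service and method:
    #     DB.SCHEMA.MY_SERVICE_PREDICT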

snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -38,8 +38,10 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
     return callable(getattr(model, method_name, None))
 
 
-def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
-    trunc_sample_input = model_signature._truncate_data(sample_input_data)
+def get_truncated_sample_data(
+    sample_input_data: model_types.SupportedDataType, length: int = 100
+) -> model_types.SupportedLocalDataType:
+    trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
     local_sample_input: model_types.SupportedLocalDataType = None
     if isinstance(sample_input_data, SnowparkDataFrame):
         # Added because of Any from missing stubs.
@@ -78,7 +80,14 @@ def validate_signature(
         local_sample_input = get_truncated_sample_data(sample_input_data)
     for target_method in target_methods:
         predictions_df = get_prediction_fn(target_method, local_sample_input)
-        sig = model_signature.infer_signature(local_sample_input, predictions_df)
+        sig = model_signature.infer_signature(
+            sample_input_data,
+            predictions_df,
+            input_feature_names=None,
+            output_feature_names=None,
+            input_data_limit=100,
+            output_data_limit=100,
+        )
         model_meta.signatures[target_method] = sig
 
     return model_meta

snowflake/ml/model/_packager/model_handlers/custom.py
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
             sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
 
         if inspect.iscoroutinefunction(target_method):
-            with anyio.start_blocking_portal() as portal:
+            with anyio.from_thread.start_blocking_portal() as portal:
                 predictions_df = portal.call(target_method, model, sample_input_data)
         else:
             predictions_df = target_method(model, sample_input_data)
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         if model.context.model_refs:
             for sub_name, model_ref in model.context.model_refs.items():
                 handler = model_handler.find_handler(model_ref.model)
-                assert handler is not None
                 if handler is None:
                     raise TypeError("Your input type to custom model is not currently supported")
                 sub_model = handler.cast_model(model_ref.model)

snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -323,6 +323,7 @@ class HuggingFacePipelineHandler(
                 model_blob_options["task"],
                 model=model_blob_file_or_dir_path,
                 trust_remote_code=True,
+                torch_dtype="auto",
                 **device_config,
             )
 

snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -1,2 +1,2 @@
-REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<4', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']

snowflake/ml/model/_signatures/base_handler.py
@@ -12,7 +12,6 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
     FEATURE_PREFIX: Final[str] = "feature"
     INPUT_PREFIX: Final[str] = "input"
     OUTPUT_PREFIX: Final[str] = "output"
-    SIG_INFER_ROWS_COUNT_LIMIT: Final[int] = 10
 
     @staticmethod
     @abstractmethod
@@ -26,7 +25,7 @@ class BaseDataHandler(ABC, Generic[model_types._DataType]):
 
     @staticmethod
     @abstractmethod
-    def truncate(data: model_types._DataType) -> model_types._DataType:
+    def truncate(data: model_types._DataType, length: int) -> model_types._DataType:
         ...
 
     @staticmethod

snowflake/ml/model/_signatures/builtins_handler.py
@@ -35,8 +35,8 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
         return len(data)
 
     @staticmethod
-    def truncate(data: model_types._SupportedBuiltinsList) -> model_types._SupportedBuiltinsList:
-        return data[: min(ListOfBuiltinHandler.count(data), ListOfBuiltinHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedBuiltinsList, length: int) -> model_types._SupportedBuiltinsList:
+        return data[: min(ListOfBuiltinHandler.count(data), length)]
 
     @staticmethod
     def validate(data: model_types._SupportedBuiltinsList) -> None:

snowflake/ml/model/_signatures/numpy_handler.py
@@ -23,8 +23,8 @@ class NumpyArrayHandler(base_handler.BaseDataHandler[model_types._SupportedNumpyArray]):
         return data.shape[0]
 
     @staticmethod
-    def truncate(data: model_types._SupportedNumpyArray) -> model_types._SupportedNumpyArray:
-        return data[: min(NumpyArrayHandler.count(data), NumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
+    def truncate(data: model_types._SupportedNumpyArray, length: int) -> model_types._SupportedNumpyArray:
+        return data[: min(NumpyArrayHandler.count(data), length)]
 
     @staticmethod
     def validate(data: model_types._SupportedNumpyArray) -> None:
@@ -94,11 +94,10 @@ class SeqOfNumpyArrayHandler(base_handler.BaseDataHandler[Sequence[model_types._SupportedNumpyArray]]):
         return min(NumpyArrayHandler.count(data_col) for data_col in data)
 
     @staticmethod
-    def truncate(data: Sequence[model_types._SupportedNumpyArray]) -> Sequence[model_types._SupportedNumpyArray]:
-        return [
-            data_col[: min(SeqOfNumpyArrayHandler.count(data), SeqOfNumpyArrayHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(
+        data: Sequence[model_types._SupportedNumpyArray], length: int
+    ) -> Sequence[model_types._SupportedNumpyArray]:
+        return [data_col[: min(SeqOfNumpyArrayHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[model_types._SupportedNumpyArray]) -> None:

snowflake/ml/model/_signatures/pandas_handler.py
@@ -23,8 +23,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
         return len(data.index)
 
     @staticmethod
-    def truncate(data: pd.DataFrame) -> pd.DataFrame:
-        return data.head(min(PandasDataFrameHandler.count(data), PandasDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: pd.DataFrame, length: int) -> pd.DataFrame:
+        return data.head(min(PandasDataFrameHandler.count(data), length))
 
     @staticmethod
     def validate(data: Union[pd.DataFrame, pd.Series]) -> None:

snowflake/ml/model/_signatures/pytorch_handler.py
@@ -33,11 +33,8 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Tensor"]]):
         return min(data_col.shape[0] for data_col in data)  # type: ignore[no-any-return]
 
     @staticmethod
-    def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
-        return [
-            data_col[: min(SeqOfPyTorchTensorHandler.count(data), SeqOfPyTorchTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)]
-            for data_col in data
-        ]
+    def truncate(data: Sequence["torch.Tensor"], length: int) -> Sequence["torch.Tensor"]:
+        return [data_col[: min(SeqOfPyTorchTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence["torch.Tensor"]) -> None:

snowflake/ml/model/_signatures/snowpark_handler.py
@@ -29,8 +29,8 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
         return data.count()
 
     @staticmethod
-    def truncate(data: snowflake.snowpark.DataFrame) -> snowflake.snowpark.DataFrame:
-        return cast(snowflake.snowpark.DataFrame, data.limit(SnowparkDataFrameHandler.SIG_INFER_ROWS_COUNT_LIMIT))
+    def truncate(data: snowflake.snowpark.DataFrame, length: int) -> snowflake.snowpark.DataFrame:
+        return cast(snowflake.snowpark.DataFrame, data.limit(length))
 
     @staticmethod
     def validate(data: snowflake.snowpark.DataFrame) -> None:
@@ -52,7 +52,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.DataFrame]):
         data: snowflake.snowpark.DataFrame, role: Literal["input", "output"]
     ) -> Sequence[core.BaseFeatureSpec]:
         return pandas_handler.PandasDataFrameHandler.infer_signature(
-            SnowparkDataFrameHandler.convert_to_df(data.limit(n=1)), role=role
+            SnowparkDataFrameHandler.convert_to_df(data), role=role
         )
 
     @staticmethod

snowflake/ml/model/_signatures/tensorflow_handler.py
@@ -60,14 +60,9 @@ class SeqOfTensorflowTensorHandler(
 
     @staticmethod
     def truncate(
-        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]
+        data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], length: int
     ) -> Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
-        return [
-            data_col[
-                : min(SeqOfTensorflowTensorHandler.count(data), SeqOfTensorflowTensorHandler.SIG_INFER_ROWS_COUNT_LIMIT)
-            ]
-            for data_col in data
-        ]
+        return [data_col[: min(SeqOfTensorflowTensorHandler.count(data), length)] for data_col in data]
 
     @staticmethod
     def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:

snowflake/ml/model/model_signature.py
@@ -59,11 +59,16 @@ _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]
 
 def _truncate_data(
     data: model_types.SupportedDataType,
+    length: Optional[int] = 100,
 ) -> model_types.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
+            # If length is None, return the original data
+            if length is None:
+                return data
+
             row_count = handler.count(data)
-            if row_count <= handler.SIG_INFER_ROWS_COUNT_LIMIT:
+            if row_count <= length:
                 return data
 
             warnings.warn(
@@ -77,7 +82,7 @@ def _truncate_data(
                 category=UserWarning,
                 stacklevel=1,
             )
-            return handler.truncate(data)
+            return handler.truncate(data, length)
     raise snowml_exceptions.SnowflakeMLException(
         error_code=error_codes.NOT_IMPLEMENTED,
         original_exception=NotImplementedError(
@@ -687,6 +692,8 @@ def infer_signature(
     output_data: model_types.SupportedLocalDataType,
     input_feature_names: Optional[List[str]] = None,
    output_feature_names: Optional[List[str]] = None,
+    input_data_limit: Optional[int] = 100,
+    output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
     """
     Infer model signature from given input and output sample data.
@@ -710,12 +717,18 @@ def infer_signature(
         output_data: Sample output data for the model.
         input_feature_names: Names for input features. Defaults to None.
         output_feature_names: Names for output features. Defaults to None.
+        input_data_limit: Limit the number of rows to be used in signature inference in the input data. Defaults to
+            100. If None, all rows are used. If the number of rows in the input data is less than the limit, all rows
+            are used.
+        output_data_limit: Limit the number of rows to be used in signature inference in the output data. Defaults to
+            100. If None, all rows are used. If the number of rows in the output data is less than the limit, all rows
+            are used.
 
     Returns:
         A model signature inferred from the given input and output sample data.
     """
-    inputs = _infer_signature(input_data, role="input")
+    inputs = _infer_signature(_truncate_data(input_data, input_data_limit), role="input")
     inputs = utils.rename_features(inputs, input_feature_names)
-    outputs = _infer_signature(output_data, role="output")
+    outputs = _infer_signature(_truncate_data(output_data, output_data_limit), role="output")
     outputs = utils.rename_features(outputs, output_feature_names)
     return core.ModelSignature(inputs, outputs)
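
A short usage sketch of the new row limits (the DataFrames here are hypothetical):

    import pandas as pd
    from snowflake.ml.model import model_signature

    input_df = pd.DataFrame({"feature": range(10_000)})
    output_df = pd.DataFrame({"prediction": range(10_000)})

    # By default only the first 100 rows of each side feed signature inference.
    sig = model_signature.infer_signature(input_df, output_df)

    # Pass None to use every row, or another int to change the cap per side.
    sig_full = model_signature.infer_signature(
        input_df, output_df, input_data_limit=None, output_data_limit=None
    )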