snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py

@@ -1,18 +1,20 @@
  import logging
  import pathlib
  import textwrap
- from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
+ from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
  from uuid import uuid4

+ import pandas as pd
  import yaml

  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml.jobs import job as jb
- from snowflake.ml.jobs._utils import payload_utils, spec_utils
+ from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
+ from snowflake.snowpark.functions import coalesce, col, lit, when

  logger = logging.getLogger(__name__)

@@ -25,39 +27,127 @@ T = TypeVar("T")
25
27
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
26
28
  def list_jobs(
27
29
  limit: int = 10,
28
- scope: Union[Literal["account", "database", "schema"], str, None] = None,
30
+ database: Optional[str] = None,
31
+ schema: Optional[str] = None,
29
32
  session: Optional[snowpark.Session] = None,
30
- ) -> snowpark.DataFrame:
33
+ ) -> pd.DataFrame:
31
34
  """
32
- Returns a Snowpark DataFrame with the list of jobs in the current session.
35
+ Returns a Pandas DataFrame with the list of jobs in the current session.
33
36
 
34
37
  Args:
35
38
  limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
36
- scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
39
+ database: The database to use. If not specified, uses the current database.
40
+ schema: The schema to use. If not specified, uses the current schema.
37
41
  session: The Snowpark session to use. If none specified, uses active session.
38
42
 
39
43
  Returns:
40
44
  A DataFrame with the list of jobs.
41
45
 
46
+ Raises:
47
+ SnowparkSQLException: if there is an error retrieving the job history.
48
+
42
49
  Examples:
43
50
  >>> from snowflake.ml.jobs import list_jobs
44
- >>> list_jobs(limit=5).show()
51
+ >>> list_jobs(limit=5)
45
52
  """
46
53
  session = session or get_active_session()
54
+ try:
55
+ df = _get_job_history_spcs(
56
+ session,
57
+ limit=limit,
58
+ database=database,
59
+ schema=schema,
60
+ )
61
+ return df.to_pandas()
62
+ except SnowparkSQLException as spcs_error:
63
+ if spcs_error.sql_error_code == 2143:
64
+ logger.debug("Job history is not enabled. Please enable it to use this feature.")
65
+ df = _get_job_services(session, limit=limit, database=database, schema=schema)
66
+ return df.to_pandas()
67
+ raise
68
+
69
+
70
+ def _get_job_services(
71
+ session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
72
+ ) -> snowpark.DataFrame:
47
73
  query = "SHOW JOB SERVICES"
48
74
  query += f" LIKE '{JOB_ID_PREFIX}%'"
49
- if scope:
50
- query += f" IN {scope}"
75
+ database = database or session.get_current_database()
76
+ schema = schema or session.get_current_schema()
77
+ if database is None and schema is None:
78
+ query += "IN account"
79
+ elif not schema:
80
+ query += f" IN DATABASE {database}"
81
+ else:
82
+ query += f" IN {database}.{schema}"
51
83
  if limit > 0:
52
84
  query += f" LIMIT {limit}"
53
85
  df = session.sql(query)
54
86
  df = df.select(
55
87
  df['"name"'],
56
- df['"owner"'],
57
88
  df['"status"'],
58
- df['"created_on"'],
89
+ lit(None).alias('"message"'),
90
+ df['"database_name"'],
91
+ df['"schema_name"'],
92
+ df['"owner"'],
59
93
  df['"compute_pool"'],
60
- ).order_by('"created_on"', ascending=False)
94
+ df['"target_instances"'],
95
+ df['"created_on"'].alias('"created_time"'),
96
+ when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
97
+ .otherwise(lit(None))
98
+ .alias('"completed_time"'),
99
+ ).order_by('"created_time"', ascending=False)
100
+ return df
101
+
102
+
103
+ def _get_job_history_spcs(
104
+ session: snowpark.Session,
105
+ limit: int = 10,
106
+ database: Optional[str] = None,
107
+ schema: Optional[str] = None,
108
+ include_deleted: bool = False,
109
+ created_time_start: Optional[str] = None,
110
+ created_time_end: Optional[str] = None,
111
+ ) -> snowpark.DataFrame:
112
+ query = ["select * from table(snowflake.spcs.get_job_history("]
113
+ query_params = []
114
+ if created_time_start:
115
+ query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
116
+ if created_time_end:
117
+ query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
118
+ query.append(",".join(query_params))
119
+ query.append("))")
120
+ condition = []
121
+ database = database or session.get_current_database()
122
+ schema = schema or session.get_current_schema()
123
+
124
+ # format database and schema identifiers
125
+ if database:
126
+ condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
127
+
128
+ if schema:
129
+ condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
130
+
131
+ if not include_deleted:
132
+ condition.append("DELETED_TIME IS NULL")
133
+
134
+ if len(condition) > 0:
135
+ query.append("WHERE " + " AND ".join(condition))
136
+ if limit > 0:
137
+ query.append(f"LIMIT {limit}")
138
+ df = session.sql("\n".join(query))
139
+ df = df.select(
140
+ df["NAME"].alias('"name"'),
141
+ df["STATUS"].alias('"status"'),
142
+ df["MESSAGE"].alias('"message"'),
143
+ df["DATABASE_NAME"].alias('"database_name"'),
144
+ df["SCHEMA_NAME"].alias('"schema_name"'),
145
+ df["OWNER"].alias('"owner"'),
146
+ df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
147
+ coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
148
+ df["CREATED_TIME"].alias('"created_time"'),
149
+ df["COMPLETED_TIME"].alias('"completed_time"'),
150
+ )
61
151
  return df
62
152
 
63
153
 
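For orientation, a minimal sketch of the reworked listing API: list_jobs now tries SNOWFLAKE.SPCS.GET_JOB_HISTORY first, falls back to SHOW JOB SERVICES when job history is not enabled (SQL error code 2143), and returns a pandas DataFrame instead of a Snowpark DataFrame. The database, schema, and column selection below are illustrative placeholders, not part of the diff.

    # Illustrative usage of the 1.9.x API; "ML_DB" and "ML_SCHEMA" are placeholder names.
    from snowflake.ml.jobs import list_jobs

    jobs_df = list_jobs(limit=20, database="ML_DB", schema="ML_SCHEMA")  # pandas.DataFrame
    print(jobs_df[["name", "status", "compute_pool", "created_time", "completed_time"]])
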
@@ -74,9 +164,9 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
74
164
 
75
165
  job_id = f"{database}.{schema}.{job_name}"
76
166
  try:
77
- # Validate that job exists by doing a status check
167
+ # Validate that job exists by doing a spec lookup
78
168
  job = jb.MLJob[Any](job_id, session=session)
79
- _ = job.status
169
+ _ = job._service_spec
80
170
  return job
81
171
  except SnowparkSQLException as e:
82
172
  if "does not exist" in e.message:
@@ -95,7 +185,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
          logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
      except Exception as e:
          logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()
+     query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))


  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -105,17 +195,11 @@ def submit_file(
      *,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a Python file as a job to the compute pool.
@@ -125,18 +209,20 @@ def submit_file(
          compute_pool: The compute pool to use for the job.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -146,17 +232,11 @@ def submit_file(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )


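The less common options that used to be explicit parameters (env_vars, query_warehouse, spec_overrides, min_instances, enable_metrics, database, schema) are now forwarded through **kwargs. A hedged sketch of a call under the new signature; the payload path, pool, stage, and warehouse names are placeholders and not taken from the diff.

    # Illustrative only: "train.py", "MY_COMPUTE_POOL", "payload_stage", "MY_WAREHOUSE" are placeholders.
    from snowflake.ml.jobs import submit_file

    job = submit_file(
        "train.py",            # payload file
        "MY_COMPUTE_POOL",     # compute pool
        stage_name="payload_stage",
        args=["--epochs", "5"],
        target_instances=2,
        # Formerly explicit parameters, now forwarded through **kwargs:
        min_instances=1,
        env_vars={"LOG_LEVEL": "DEBUG"},
        query_warehouse="MY_WAREHOUSE",
    )
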
@@ -168,17 +248,11 @@ def submit_directory(
      entrypoint: str,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a directory containing Python script(s) as a job to the compute pool.
@@ -189,18 +263,20 @@ def submit_directory(
          entrypoint: The relative path to the entry point script inside the source directory.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -211,17 +287,11 @@ def submit_directory(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )


@@ -233,17 +303,11 @@ def submit_from_stage(
      entrypoint: str,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a directory containing Python script(s) as a job to the compute pool.
@@ -254,19 +318,20 @@ def submit_from_stage(
          entrypoint: a stage path containing the entry point script inside the source directory.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
-
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -277,17 +342,11 @@ def submit_from_stage(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )


@@ -299,17 +358,11 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      ...

@@ -322,17 +375,11 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[T]:
      ...

@@ -345,8 +392,9 @@ def _submit_job(
          # TODO: Log lengths of args, env_vars, and spec_overrides values
          "pip_requirements",
          "external_access_integrations",
+         "num_instances",  # deprecated
          "target_instances",
-         "enable_metrics",
+         "min_instances",
      ],
  )
  def _submit_job(
@@ -356,17 +404,9 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
-     pip_requirements: Optional[list[str]] = None,
-     external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[T]:
      """
      Submit a job to the compute pool.
@@ -377,37 +417,48 @@ def _submit_job(
          stage_name: The name of the stage where the job payload will be uploaded.
          entrypoint: The entry point for the job execution. Required if source is a directory.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
-         pip_requirements: A list of pip requirements for the job.
-         external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
          target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments.

      Returns:
          An object representing the submitted job.

      Raises:
-         RuntimeError: If required Snowflake features are not enabled.
          ValueError: If database or schema value(s) are invalid
+         SnowparkSQLException: If there is an error submitting the job.
      """
+     session = session or get_active_session()
+
+     # Check for deprecated args
+     if "num_instances" in kwargs:
+         logger.warning(
+             "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+         )
+         target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+     # Use kwargs for less common optional parameters
+     database = kwargs.pop("database", None)
+     schema = kwargs.pop("schema", None)
+     min_instances = kwargs.pop("min_instances", target_instances)
+     pip_requirements = kwargs.pop("pip_requirements", None)
+     external_access_integrations = kwargs.pop("external_access_integrations", None)
+     env_vars = kwargs.pop("env_vars", None)
+     spec_overrides = kwargs.pop("spec_overrides", None)
+     enable_metrics = kwargs.pop("enable_metrics", True)
+     query_warehouse = kwargs.pop("query_warehouse", None)
+
+     # Warn if there are unknown kwargs
+     if kwargs:
+         logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+     # Validate parameters
      if database and not schema:
          raise ValueError("Schema must be specified if database is specified.")
      if target_instances < 1:
          raise ValueError("target_instances must be greater than 0.")
-
-     min_instances = target_instances if min_instances is None else min_instances
      if not (0 < min_instances <= target_instances):
          raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-
-     session = session or get_active_session()
-
      if min_instances > 1:
          # Validate min_instances against compute pool max_nodes
          pool_info = jb._get_compute_pool_info(session, compute_pool)
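As the kwargs handling above shows, the deprecated num_instances keyword is still accepted by the submit_* entry points (which forward **kwargs here): it logs a deprecation warning and is folded into target_instances via max(), while any unrecognized keywords are logged and ignored. A small sketch with placeholder payload, pool, and stage names:

    # Sketch of the deprecation path; argument values are placeholders.
    job = submit_file(
        "train.py",
        "MY_COMPUTE_POOL",
        stage_name="payload_stage",
        num_instances=3,  # deprecated: warns, then behaves like target_instances=3
    )
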
@@ -418,13 +469,8 @@ def _submit_job(
                  f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
              )

-     # Validate database and schema identifiers on client side since
-     # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
-     database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
-     schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
-
      job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-     job_id = f"{database}.{schema}.{job_name}"
+     job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
      stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
      stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
      stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
@@ -453,7 +499,48 @@ def _submit_job(
      if spec_overrides:
          spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")

-     # Generate SQL command for job submission
+     query_text, params = _generate_submission_query(
+         spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+     )
+     try:
+         _ = query_helper.run_query(session, query_text, params=params)
+     except SnowparkSQLException as e:
+         if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
+             logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
+             spec["spec"].pop("resourceManagement", None)
+             query_text, params = _generate_submission_query(
+                 spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+             )
+             _ = query_helper.run_query(session, query_text, params=params)
+         else:
+             raise
+     return get_job(job_id, session=session)
+
+
+ def _generate_submission_query(
+     spec: dict[str, Any],
+     external_access_integrations: list[str],
+     query_warehouse: Optional[str],
+     target_instances: int,
+     session: snowpark.Session,
+     compute_pool: str,
+     job_id: str,
+ ) -> tuple[str, list[Any]]:
+     """
+     Generate the SQL query for job submission.
+
+     Args:
+         spec: The service spec for the job.
+         external_access_integrations: The external access integrations for the job.
+         query_warehouse: The query warehouse for the job.
+         target_instances: The number of instances for the job.
+         session: The Snowpark session to use.
+         compute_pool: The compute pool to use for the job.
+         job_id: The ID of the job.
+
+     Returns:
+         A tuple containing the SQL query text and the parameters for the query.
+     """
      query_template = textwrap.dedent(
          """\
          EXECUTE JOB SERVICE
@@ -477,17 +564,5 @@ def _submit_job(
      if target_instances > 1:
          query.append("REPLICAS = ?")
          params.append(target_instances)
-
-     # Submit job
      query_text = "\n".join(line for line in query if line)
-
-     try:
-         _ = session.sql(query_text, params=params).collect()
-     except SnowparkSQLException as e:
-         if "invalid property 'ASYNC'" in e.message:
-             raise RuntimeError(
-                 "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
-             ) from e
-         raise
-
-     return jb.MLJob(job_id, service_spec=spec, session=session)
+     return query_text, params
snowflake/ml/lineage/lineage_node.py

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional, Union

  from snowflake import snowpark
  from snowflake.ml._internal import telemetry
- from snowflake.ml._internal.utils import identifier
+ from snowflake.ml._internal.utils import identifier, mixins

  if TYPE_CHECKING:
      from snowflake.ml import dataset
@@ -15,7 +15,7 @@ _PROJECT = "LINEAGE"
  DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}


- class LineageNode:
+ class LineageNode(mixins.SerializableSessionMixin):
      """
      Represents a node in a lineage graph and serves as the base class for all machine learning objects.
      """
snowflake/ml/model/_client/model/model_version_impl.py

@@ -1,4 +1,5 @@
  import enum
+ import logging
  import pathlib
  import tempfile
  import warnings
@@ -10,7 +11,7 @@ from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.lineage import lineage_node
- from snowflake.ml.model import type_hints as model_types
+ from snowflake.ml.model import task, type_hints
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -401,7 +402,7 @@ class ModelVersion(lineage_node.LineageNode):
          project=_TELEMETRY_PROJECT,
          subproject=_TELEMETRY_SUBPROJECT,
      )
-     def get_model_task(self) -> model_types.Task:
+     def get_model_task(self) -> task.Task:
          statement_params = telemetry.get_statement_params(
              project=_TELEMETRY_PROJECT,
              subproject=_TELEMETRY_SUBPROJECT,
@@ -607,8 +608,8 @@ class ModelVersion(lineage_node.LineageNode):
          self,
          *,
          force: bool = False,
-         options: Optional[model_types.ModelLoadOption] = None,
-     ) -> model_types.SupportedModelType:
+         options: Optional[type_hints.ModelLoadOption] = None,
+     ) -> type_hints.SupportedModelType:
          """Load the underlying original Python object back from a model.

          This operation requires to have the exact the same environment as the one when logging the model, otherwise,
          the model might be not functional or some other problems might occur.
@@ -889,6 +890,17 @@ class ModelVersion(lineage_node.LineageNode):
              project=_TELEMETRY_PROJECT,
              subproject=_TELEMETRY_SUBPROJECT,
          )
+
+         # Check root logger level and emit warning if needed
+         root_logger = logging.getLogger()
+         if root_logger.level in (logging.WARNING, logging.ERROR):
+             warnings.warn(
+                 "Suppressing service logs. Set the log level to INFO if you would like "
+                 "verbose service logs (e.g., logging.getLogger().setLevel(logging.INFO)).",
+                 UserWarning,
+                 stacklevel=2,
+             )
+
          if build_external_access_integration is not None:
              msg = (
                  "`build_external_access_integration` is deprecated. "