snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (49)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/_internal/utils/identifier.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +61 -0
  5. snowflake/ml/jobs/__init__.py +2 -0
  6. snowflake/ml/jobs/_utils/constants.py +3 -2
  7. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  8. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  9. snowflake/ml/jobs/_utils/payload_utils.py +89 -40
  10. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  11. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  13. snowflake/ml/jobs/_utils/spec_utils.py +29 -5
  14. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  15. snowflake/ml/jobs/_utils/types.py +5 -1
  16. snowflake/ml/jobs/decorators.py +20 -28
  17. snowflake/ml/jobs/job.py +197 -61
  18. snowflake/ml/jobs/manager.py +253 -121
  19. snowflake/ml/model/_client/model/model_impl.py +58 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  21. snowflake/ml/model/_client/ops/model_ops.py +18 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +23 -6
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  24. snowflake/ml/model/_client/sql/service.py +68 -20
  25. snowflake/ml/model/_client/sql/stage.py +5 -2
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  27. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  28. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  31. snowflake/ml/model/_signatures/core.py +24 -0
  32. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  33. snowflake/ml/model/target_platform.py +11 -0
  34. snowflake/ml/model/task.py +9 -0
  35. snowflake/ml/model/type_hints.py +5 -13
  36. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  37. snowflake/ml/monitoring/explain_visualize.py +2 -2
  38. snowflake/ml/monitoring/model_monitor.py +0 -4
  39. snowflake/ml/registry/_manager/model_manager.py +30 -15
  40. snowflake/ml/registry/registry.py +144 -47
  41. snowflake/ml/utils/connection_params.py +1 -1
  42. snowflake/ml/utils/html_utils.py +263 -0
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
  45. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
  46. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  47. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  48. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
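
For orientation, the most visible API change in this release is in snowflake/ml/jobs. The sketch below is assembled from the signatures in the manager.py diff further down; the database, schema, stage, and compute pool names are placeholders, an active Snowpark session is assumed, and submit_from_stage is assumed to be exported from snowflake.ml.jobs (the two lines added to jobs/__init__.py suggest it is).

from snowflake.ml.jobs import list_jobs, submit_from_stage

# list_jobs() now filters by database/schema (the former `scope` argument is gone)
# and returns a pandas DataFrame instead of a Snowpark DataFrame.
jobs_df = list_jobs(limit=5, database="MY_DB", schema="MY_SCHEMA")
print(jobs_df[["name", "status", "created_time"]])

# New in 1.9.0: submit a payload that already lives on a stage.
job = submit_from_stage(
    "@MY_STAGE/payload/",
    "MY_POOL",
    entrypoint="@MY_STAGE/payload/main.py",
    stage_name="@MY_STAGE",
    target_instances=1,
)
print(job.id)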
snowflake/ml/jobs/manager.py
@@ -1,18 +1,21 @@
  import logging
  import pathlib
  import textwrap
- from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
+ from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
  from uuid import uuid4

+ import pandas as pd
  import yaml

  from snowflake import snowpark
+ from snowflake.connector import errors
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml.jobs import job as jb
- from snowflake.ml.jobs._utils import payload_utils, spec_utils
+ from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
+ from snowflake.snowpark.functions import coalesce, col, lit, when

  logger = logging.getLogger(__name__)

@@ -25,39 +28,127 @@ T = TypeVar("T")
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
  def list_jobs(
  limit: int = 10,
- scope: Union[Literal["account", "database", "schema"], str, None] = None,
+ database: Optional[str] = None,
+ schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
- ) -> snowpark.DataFrame:
+ ) -> pd.DataFrame:
  """
- Returns a Snowpark DataFrame with the list of jobs in the current session.
+ Returns a Pandas DataFrame with the list of jobs in the current session.

  Args:
  limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
- scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+ database: The database to use. If not specified, uses the current database.
+ schema: The schema to use. If not specified, uses the current schema.
  session: The Snowpark session to use. If none specified, uses active session.

  Returns:
  A DataFrame with the list of jobs.

+ Raises:
+ SnowparkSQLException: if there is an error retrieving the job history.
+
  Examples:
  >>> from snowflake.ml.jobs import list_jobs
- >>> list_jobs(limit=5).show()
+ >>> list_jobs(limit=5)
  """
  session = session or get_active_session()
+ try:
+ df = _get_job_history_spcs(
+ session,
+ limit=limit,
+ database=database,
+ schema=schema,
+ )
+ return df.to_pandas()
+ except SnowparkSQLException as spcs_error:
+ if spcs_error.sql_error_code == 2143:
+ logger.debug("Job history is not enabled. Please enable it to use this feature.")
+ df = _get_job_services(session, limit=limit, database=database, schema=schema)
+ return df.to_pandas()
+ raise
+
+
+ def _get_job_services(
+ session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
+ ) -> snowpark.DataFrame:
  query = "SHOW JOB SERVICES"
  query += f" LIKE '{JOB_ID_PREFIX}%'"
- if scope:
- query += f" IN {scope}"
+ database = database or session.get_current_database()
+ schema = schema or session.get_current_schema()
+ if database is None and schema is None:
+ query += "IN account"
+ elif not schema:
+ query += f" IN DATABASE {database}"
+ else:
+ query += f" IN {database}.{schema}"
  if limit > 0:
  query += f" LIMIT {limit}"
  df = session.sql(query)
  df = df.select(
  df['"name"'],
- df['"owner"'],
  df['"status"'],
- df['"created_on"'],
+ lit(None).alias('"message"'),
+ df['"database_name"'],
+ df['"schema_name"'],
+ df['"owner"'],
  df['"compute_pool"'],
- ).order_by('"created_on"', ascending=False)
+ df['"target_instances"'],
+ df['"created_on"'].alias('"created_time"'),
+ when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
+ .otherwise(lit(None))
+ .alias('"completed_time"'),
+ ).order_by('"created_time"', ascending=False)
+ return df
+
+
+ def _get_job_history_spcs(
+ session: snowpark.Session,
+ limit: int = 10,
+ database: Optional[str] = None,
+ schema: Optional[str] = None,
+ include_deleted: bool = False,
+ created_time_start: Optional[str] = None,
+ created_time_end: Optional[str] = None,
+ ) -> snowpark.DataFrame:
+ query = ["select * from table(snowflake.spcs.get_job_history("]
+ query_params = []
+ if created_time_start:
+ query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
+ if created_time_end:
+ query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
+ query.append(",".join(query_params))
+ query.append("))")
+ condition = []
+ database = database or session.get_current_database()
+ schema = schema or session.get_current_schema()
+
+ # format database and schema identifiers
+ if database:
+ condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
+
+ if schema:
+ condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
+
+ if not include_deleted:
+ condition.append("DELETED_TIME IS NULL")
+
+ if len(condition) > 0:
+ query.append("WHERE " + " AND ".join(condition))
+ if limit > 0:
+ query.append(f"LIMIT {limit}")
+ df = session.sql("\n".join(query))
+ df = df.select(
+ df["NAME"].alias('"name"'),
+ df["STATUS"].alias('"status"'),
+ df["MESSAGE"].alias('"message"'),
+ df["DATABASE_NAME"].alias('"database_name"'),
+ df["SCHEMA_NAME"].alias('"schema_name"'),
+ df["OWNER"].alias('"owner"'),
+ df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
+ coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
+ df["CREATED_TIME"].alias('"created_time"'),
+ df["COMPLETED_TIME"].alias('"completed_time"'),
+ )
  return df

@@ -74,12 +165,12 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob

  job_id = f"{database}.{schema}.{job_name}"
  try:
- # Validate that job exists by doing a status check
+ # Validate that job exists by doing a spec lookup
  job = jb.MLJob[Any](job_id, session=session)
- _ = job.status
+ _ = job._service_spec
  return job
- except SnowparkSQLException as e:
- if "does not exist" in e.message:
+ except errors.ProgrammingError as e:
+ if "does not exist" in str(e):
  raise ValueError(f"Job does not exist: {job_id}") from e
  raise

@@ -87,13 +178,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
  """Delete a job service from the backend. Status and logs will be lost."""
- if isinstance(job, jb.MLJob):
- job_id = job.id
- session = job._session or session
- else:
- job_id = job
- session = session or get_active_session()
- session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+ job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
+ session = job._session
+ try:
+ stage_path = job._stage_path
+ session.sql(f"REMOVE {stage_path}/").collect()
+ logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
+ except Exception as e:
+ logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+ session._conn.run_query("DROP SERVICE IDENTIFIER(?)", params=(job.id,), _force_qmark_paramstyle=True)


  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -103,17 +196,11 @@ def submit_file(
  *,
  stage_name: str,
  args: Optional[list[str]] = None,
- env_vars: Optional[dict[str, str]] = None,
+ target_instances: int = 1,
  pip_requirements: Optional[list[str]] = None,
  external_access_integrations: Optional[list[str]] = None,
- query_warehouse: Optional[str] = None,
- spec_overrides: Optional[dict[str, Any]] = None,
- target_instances: int = 1,
- min_instances: int = 1,
- enable_metrics: bool = False,
- database: Optional[str] = None,
- schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
  ) -> jb.MLJob[None]:
  """
  Submit a Python file as a job to the compute pool.
@@ -123,17 +210,20 @@ def submit_file(
  compute_pool: The compute pool to use for the job.
  stage_name: The name of the stage where the job payload will be uploaded.
  args: A list of arguments to pass to the job.
- env_vars: Environment variables to set in container
+ target_instances: The number of nodes in the job. If none specified, create a single node job.
  pip_requirements: A list of pip requirements for the job.
  external_access_integrations: A list of external access integrations.
- query_warehouse: The query warehouse to use. Defaults to session warehouse.
- spec_overrides: Custom service specification overrides to apply.
- target_instances: The number of instances to use for the job. If none specified, single node job is created.
- min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
- enable_metrics: Whether to enable metrics publishing for the job.
- database: The database to use.
- schema: The schema to use.
  session: The Snowpark session to use. If none specified, uses active session.
+ kwargs: Additional keyword arguments. Supported arguments:
+ database (str): The database to use for the job.
+ schema (str): The schema to use for the job.
+ min_instances (int): The minimum number of nodes required to start the job.
+ If none specified, defaults to target_instances. If set, the job
+ will not start until the minimum number of nodes is available.
+ env_vars (dict): Environment variables to set in container.
+ enable_metrics (bool): Whether to enable metrics publishing for the job.
+ query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+ spec_overrides (dict): A dictionary of overrides for the service spec.

  Returns:
  An object representing the submitted job.
@@ -143,17 +233,11 @@ def submit_file(
  args=args,
  compute_pool=compute_pool,
  stage_name=stage_name,
- env_vars=env_vars,
+ target_instances=target_instances,
  pip_requirements=pip_requirements,
  external_access_integrations=external_access_integrations,
- query_warehouse=query_warehouse,
- spec_overrides=spec_overrides,
- target_instances=target_instances,
- min_instances=min_instances,
- enable_metrics=enable_metrics,
- database=database,
- schema=schema,
  session=session,
+ **kwargs,
  )


@@ -165,17 +249,11 @@ def submit_directory(
  entrypoint: str,
  stage_name: str,
  args: Optional[list[str]] = None,
- env_vars: Optional[dict[str, str]] = None,
+ target_instances: int = 1,
  pip_requirements: Optional[list[str]] = None,
  external_access_integrations: Optional[list[str]] = None,
- query_warehouse: Optional[str] = None,
- spec_overrides: Optional[dict[str, Any]] = None,
- target_instances: int = 1,
- min_instances: int = 1,
- enable_metrics: bool = False,
- database: Optional[str] = None,
- schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
  ) -> jb.MLJob[None]:
  """
  Submit a directory containing Python script(s) as a job to the compute pool.
@@ -186,17 +264,20 @@ def submit_directory(
  entrypoint: The relative path to the entry point script inside the source directory.
  stage_name: The name of the stage where the job payload will be uploaded.
  args: A list of arguments to pass to the job.
- env_vars: Environment variables to set in container
+ target_instances: The number of nodes in the job. If none specified, create a single node job.
  pip_requirements: A list of pip requirements for the job.
  external_access_integrations: A list of external access integrations.
- query_warehouse: The query warehouse to use. Defaults to session warehouse.
- spec_overrides: Custom service specification overrides to apply.
- target_instances: The number of instances to use for the job. If none specified, single node job is created.
- min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
- enable_metrics: Whether to enable metrics publishing for the job.
- database: The database to use.
- schema: The schema to use.
  session: The Snowpark session to use. If none specified, uses active session.
+ kwargs: Additional keyword arguments. Supported arguments:
+ database (str): The database to use for the job.
+ schema (str): The schema to use for the job.
+ min_instances (int): The minimum number of nodes required to start the job.
+ If none specified, defaults to target_instances. If set, the job
+ will not start until the minimum number of nodes is available.
+ env_vars (dict): Environment variables to set in container.
+ enable_metrics (bool): Whether to enable metrics publishing for the job.
+ query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+ spec_overrides (dict): A dictionary of overrides for the service spec.

  Returns:
  An object representing the submitted job.
@@ -207,17 +288,66 @@ def submit_directory(
  args=args,
  compute_pool=compute_pool,
  stage_name=stage_name,
- env_vars=env_vars,
+ target_instances=target_instances,
  pip_requirements=pip_requirements,
  external_access_integrations=external_access_integrations,
- query_warehouse=query_warehouse,
- spec_overrides=spec_overrides,
+ session=session,
+ **kwargs,
+ )
+
+
+ @telemetry.send_api_usage_telemetry(project=_PROJECT)
+ def submit_from_stage(
+ source: str,
+ compute_pool: str,
+ *,
+ entrypoint: str,
+ stage_name: str,
+ args: Optional[list[str]] = None,
+ target_instances: int = 1,
+ pip_requirements: Optional[list[str]] = None,
+ external_access_integrations: Optional[list[str]] = None,
+ session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
+ ) -> jb.MLJob[None]:
+ """
+ Submit a directory containing Python script(s) as a job to the compute pool.
+
+ Args:
+ source: a stage path or a stage containing the job payload.
+ compute_pool: The compute pool to use for the job.
+ entrypoint: a stage path containing the entry point script inside the source directory.
+ stage_name: The name of the stage where the job payload will be uploaded.
+ args: A list of arguments to pass to the job.
+ target_instances: The number of nodes in the job. If none specified, create a single node job.
+ pip_requirements: A list of pip requirements for the job.
+ external_access_integrations: A list of external access integrations.
+ session: The Snowpark session to use. If none specified, uses active session.
+ kwargs: Additional keyword arguments. Supported arguments:
+ database (str): The database to use for the job.
+ schema (str): The schema to use for the job.
+ min_instances (int): The minimum number of nodes required to start the job.
+ If none specified, defaults to target_instances. If set, the job
+ will not start until the minimum number of nodes is available.
+ env_vars (dict): Environment variables to set in container.
+ enable_metrics (bool): Whether to enable metrics publishing for the job.
+ query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+ spec_overrides (dict): A dictionary of overrides for the service spec.
+
+ Returns:
+ An object representing the submitted job.
+ """
+ return _submit_job(
+ source=source,
+ entrypoint=entrypoint,
+ args=args,
+ compute_pool=compute_pool,
+ stage_name=stage_name,
  target_instances=target_instances,
- min_instances=min_instances,
- enable_metrics=enable_metrics,
- database=database,
- schema=schema,
+ pip_requirements=pip_requirements,
+ external_access_integrations=external_access_integrations,
  session=session,
+ **kwargs,
  )


@@ -229,17 +359,11 @@ def _submit_job(
  stage_name: str,
  entrypoint: Optional[str] = None,
  args: Optional[list[str]] = None,
- env_vars: Optional[dict[str, str]] = None,
+ target_instances: int = 1,
  pip_requirements: Optional[list[str]] = None,
  external_access_integrations: Optional[list[str]] = None,
- query_warehouse: Optional[str] = None,
- spec_overrides: Optional[dict[str, Any]] = None,
- target_instances: int = 1,
- min_instances: int = 1,
- enable_metrics: bool = False,
- database: Optional[str] = None,
- schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
  ) -> jb.MLJob[None]:
  ...

@@ -252,17 +376,11 @@ def _submit_job(
  stage_name: str,
  entrypoint: Optional[str] = None,
  args: Optional[list[str]] = None,
- env_vars: Optional[dict[str, str]] = None,
+ target_instances: int = 1,
  pip_requirements: Optional[list[str]] = None,
  external_access_integrations: Optional[list[str]] = None,
- query_warehouse: Optional[str] = None,
- spec_overrides: Optional[dict[str, Any]] = None,
- target_instances: int = 1,
- min_instances: int = 1,
- enable_metrics: bool = False,
- database: Optional[str] = None,
- schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
  ) -> jb.MLJob[T]:
  ...

@@ -275,8 +393,9 @@ def _submit_job(
  # TODO: Log lengths of args, env_vars, and spec_overrides values
  "pip_requirements",
  "external_access_integrations",
+ "num_instances", # deprecated
  "target_instances",
- "enable_metrics",
+ "min_instances",
  ],
  )
  def _submit_job(
@@ -286,17 +405,9 @@ def _submit_job(
  stage_name: str,
  entrypoint: Optional[str] = None,
  args: Optional[list[str]] = None,
- env_vars: Optional[dict[str, str]] = None,
- pip_requirements: Optional[list[str]] = None,
- external_access_integrations: Optional[list[str]] = None,
- query_warehouse: Optional[str] = None,
- spec_overrides: Optional[dict[str, Any]] = None,
  target_instances: int = 1,
- min_instances: int = 1,
- enable_metrics: bool = False,
- database: Optional[str] = None,
- schema: Optional[str] = None,
  session: Optional[snowpark.Session] = None,
+ **kwargs: Any,
  ) -> jb.MLJob[T]:
  """
  Submit a job to the compute pool.
@@ -307,17 +418,9 @@ def _submit_job(
  stage_name: The name of the stage where the job payload will be uploaded.
  entrypoint: The entry point for the job execution. Required if source is a directory.
  args: A list of arguments to pass to the job.
- env_vars: Environment variables to set in container
- pip_requirements: A list of pip requirements for the job.
- external_access_integrations: A list of external access integrations.
- query_warehouse: The query warehouse to use. Defaults to session warehouse.
- spec_overrides: Custom service specification overrides to apply.
  target_instances: The number of instances to use for the job. If none specified, single node job is created.
- min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
- enable_metrics: Whether to enable metrics publishing for the job.
- database: The database to use.
- schema: The schema to use.
  session: The Snowpark session to use. If none specified, uses active session.
+ kwargs: Additional keyword arguments.

  Returns:
  An object representing the submitted job.
@@ -325,23 +428,52 @@
  Raises:
  RuntimeError: If required Snowflake features are not enabled.
  ValueError: If database or schema value(s) are invalid
+ errors.ProgrammingError: if the SQL query or its parameters are invalid
  """
- if database and not schema:
- raise ValueError("Schema must be specified if database is specified.")
- if target_instances < 1 or min_instances < 1:
- raise ValueError("target_instances and min_instances must be greater than 0.")
- if min_instances > target_instances:
- raise ValueError("min_instances must be less than or equal to target_instances.")
-
  session = session or get_active_session()

- # Validate database and schema identifiers on client side since
- # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
- database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
- schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
+ # Use kwargs for less common optional parameters
+ database = kwargs.pop("database", None)
+ schema = kwargs.pop("schema", None)
+ min_instances = kwargs.pop("min_instances", target_instances)
+ pip_requirements = kwargs.pop("pip_requirements", None)
+ external_access_integrations = kwargs.pop("external_access_integrations", None)
+ env_vars = kwargs.pop("env_vars", None)
+ spec_overrides = kwargs.pop("spec_overrides", None)
+ enable_metrics = kwargs.pop("enable_metrics", True)
+ query_warehouse = kwargs.pop("query_warehouse", None)
+
+ # Check for deprecated args
+ if "num_instances" in kwargs:
+ logger.warning(
+ "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+ )
+ target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+ # Warn if there are unknown kwargs
+ if kwargs:
+ logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+ # Validate parameters
+ if database and not schema:
+ raise ValueError("Schema must be specified if database is specified.")
+ if target_instances < 1:
+ raise ValueError("target_instances must be greater than 0.")
+ if not (0 < min_instances <= target_instances):
+ raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
+ if min_instances > 1:
+ # Validate min_instances against compute pool max_nodes
+ pool_info = jb._get_compute_pool_info(session, compute_pool)
+ requested_attributes = query_helper.get_attribute_map(session, {"max_nodes": 3})
+ max_nodes = int(pool_info[requested_attributes["max_nodes"]])
+ if min_instances > max_nodes:
+ raise ValueError(
+ f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+ f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+ )

  job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
- job_id = f"{database}.{schema}.{job_name}"
+ job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
  stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
  stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
  stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
@@ -399,12 +531,12 @@
  query_text = "\n".join(line for line in query if line)

  try:
- _ = session.sql(query_text, params=params).collect()
- except SnowparkSQLException as e:
- if "invalid property 'ASYNC'" in e.message:
+ _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
+ except errors.ProgrammingError as e:
+ if "invalid property 'ASYNC'" in str(e):
  raise RuntimeError(
  "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
  ) from e
  raise

- return jb.MLJob(job_id, service_spec=spec, session=session)
+ return get_job(job_id, session=session)
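
To make the reworked parameter handling above concrete from the caller's side, here is a hedged sketch; the file, stage, and compute pool names are placeholders and an active Snowpark session is assumed.

from snowflake.ml.jobs import submit_file

# Less common options now travel through **kwargs; min_instances defaults to
# target_instances, must satisfy 0 < min_instances <= target_instances, and
# may not exceed the compute pool's max_nodes.
job = submit_file(
    "train.py",
    "MY_POOL",
    stage_name="@MY_STAGE",
    target_instances=4,
    min_instances=2,             # kwarg in 1.9.0
    env_vars={"MY_FLAG": "1"},   # kwarg in 1.9.0
    query_warehouse="MY_WH",     # kwarg in 1.9.0
)

# The old `num_instances` keyword is still accepted but logs a deprecation
# warning and is folded into target_instances.
job = submit_file("train.py", "MY_POOL", stage_name="@MY_STAGE", num_instances=2)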
snowflake/ml/model/_client/model/model_impl.py
@@ -426,3 +426,61 @@ class Model:
  schema_name=new_schema or self._model_ops._model_client._schema_name,
  )
  self._model_name = new_model
+
+ def _repr_html_(self) -> str:
+ """Generate an HTML representation of the model.
+
+ Returns:
+ str: HTML string containing formatted model details.
+ """
+ from snowflake.ml.utils import html_utils
+
+ # Get default version
+ default_version = self.default.version_name
+
+ # Get versions info
+ try:
+ versions_df = self.show_versions()
+ versions_html = ""
+
+ for _, row in versions_df.iterrows():
+ versions_html += html_utils.create_version_item(
+ version_name=row["name"],
+ created_on=str(row["created_on"]),
+ comment=str(row.get("comment", "")),
+ is_default=bool(row["is_default_version"]),
+ )
+ except Exception:
+ versions_html = html_utils.create_error_message("Error retrieving versions")
+
+ # Get tags
+ try:
+ tags = self.show_tags()
+ if not tags:
+ tags_html = html_utils.create_error_message("No tags available")
+ else:
+ tags_html = ""
+ for tag_name, tag_value in tags.items():
+ tags_html += html_utils.create_tag_item(tag_name, tag_value)
+ except Exception:
+ tags_html = html_utils.create_error_message("Error retrieving tags")
+
+ # Create main content sections
+ main_info = html_utils.create_grid_section(
+ [
+ ("Model Name", self.name),
+ ("Full Name", self.fully_qualified_name),
+ ("Description", self.description),
+ ("Default Version", default_version),
+ ]
+ )
+
+ versions_section = html_utils.create_section_header("Versions") + html_utils.create_content_section(
+ versions_html
+ )
+
+ tags_section = html_utils.create_section_header("Tags") + html_utils.create_content_section(tags_html)
+
+ content = main_info + versions_section + tags_section
+
+ return html_utils.create_base_container("Model Details", content)
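
The new _repr_html_ hook is what notebook frontends call automatically when a Model object is the last expression in a cell. A minimal sketch of how it surfaces, assuming an existing snowpark.Session and a registered model whose name ("MY_MODEL") is a placeholder:

from snowflake.ml.registry import Registry

reg = Registry(session)            # `session` is an existing snowpark.Session
model = reg.get_model("MY_MODEL")  # placeholder model name

model                       # in Jupyter, renders the HTML "Model Details" card built above
html = model._repr_html_()  # the same HTML as a plain string, outside a notebook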