snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. snowflake/ml/_internal/utils/identifier.py +1 -1
  2. snowflake/ml/_internal/utils/mixins.py +61 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -1
  4. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  5. snowflake/ml/jobs/_utils/payload_utils.py +6 -5
  6. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  7. snowflake/ml/jobs/_utils/spec_utils.py +6 -4
  8. snowflake/ml/jobs/decorators.py +18 -25
  9. snowflake/ml/jobs/job.py +179 -58
  10. snowflake/ml/jobs/manager.py +194 -145
  11. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  12. snowflake/ml/model/_client/ops/service_ops.py +4 -2
  13. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  14. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  15. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  17. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  18. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  19. snowflake/ml/model/target_platform.py +11 -0
  20. snowflake/ml/model/task.py +9 -0
  21. snowflake/ml/model/type_hints.py +5 -13
  22. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  23. snowflake/ml/registry/_manager/model_manager.py +30 -15
  24. snowflake/ml/registry/registry.py +119 -42
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
  27. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
  28. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py

@@ -1,18 +1,21 @@
  import logging
  import pathlib
  import textwrap
- from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast, overload
+ from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
  from uuid import uuid4

+ import pandas as pd
  import yaml

  from snowflake import snowpark
+ from snowflake.connector import errors
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml.jobs import job as jb
- from snowflake.ml.jobs._utils import payload_utils, spec_utils
+ from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
+ from snowflake.snowpark.functions import coalesce, col, lit, when

  logger = logging.getLogger(__name__)

@@ -25,39 +28,127 @@ T = TypeVar("T")
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
  def list_jobs(
      limit: int = 10,
-     scope: Union[Literal["account", "database", "schema"], str, None] = None,
+     database: Optional[str] = None,
+     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
- ) -> snowpark.DataFrame:
+ ) -> pd.DataFrame:
      """
-     Returns a Snowpark DataFrame with the list of jobs in the current session.
+     Returns a Pandas DataFrame with the list of jobs in the current session.

      Args:
          limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
-         scope: The scope to list jobs from, such as "schema" or "compute pool <pool_name>".
+         database: The database to use. If not specified, uses the current database.
+         schema: The schema to use. If not specified, uses the current schema.
          session: The Snowpark session to use. If none specified, uses active session.

      Returns:
          A DataFrame with the list of jobs.

+     Raises:
+         SnowparkSQLException: if there is an error retrieving the job history.
+
      Examples:
          >>> from snowflake.ml.jobs import list_jobs
-         >>> list_jobs(limit=5).show()
+         >>> list_jobs(limit=5)
      """
      session = session or get_active_session()
+     try:
+         df = _get_job_history_spcs(
+             session,
+             limit=limit,
+             database=database,
+             schema=schema,
+         )
+         return df.to_pandas()
+     except SnowparkSQLException as spcs_error:
+         if spcs_error.sql_error_code == 2143:
+             logger.debug("Job history is not enabled. Please enable it to use this feature.")
+             df = _get_job_services(session, limit=limit, database=database, schema=schema)
+             return df.to_pandas()
+         raise
+
+
+ def _get_job_services(
+     session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
+ ) -> snowpark.DataFrame:
      query = "SHOW JOB SERVICES"
      query += f" LIKE '{JOB_ID_PREFIX}%'"
-     if scope:
-         query += f" IN {scope}"
+     database = database or session.get_current_database()
+     schema = schema or session.get_current_schema()
+     if database is None and schema is None:
+         query += "IN account"
+     elif not schema:
+         query += f" IN DATABASE {database}"
+     else:
+         query += f" IN {database}.{schema}"
      if limit > 0:
          query += f" LIMIT {limit}"
      df = session.sql(query)
      df = df.select(
          df['"name"'],
-         df['"owner"'],
          df['"status"'],
-         df['"created_on"'],
+         lit(None).alias('"message"'),
+         df['"database_name"'],
+         df['"schema_name"'],
+         df['"owner"'],
          df['"compute_pool"'],
-     ).order_by('"created_on"', ascending=False)
+         df['"target_instances"'],
+         df['"created_on"'].alias('"created_time"'),
+         when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
+         .otherwise(lit(None))
+         .alias('"completed_time"'),
+     ).order_by('"created_time"', ascending=False)
+     return df
+
+
+ def _get_job_history_spcs(
+     session: snowpark.Session,
+     limit: int = 10,
+     database: Optional[str] = None,
+     schema: Optional[str] = None,
+     include_deleted: bool = False,
+     created_time_start: Optional[str] = None,
+     created_time_end: Optional[str] = None,
+ ) -> snowpark.DataFrame:
+     query = ["select * from table(snowflake.spcs.get_job_history("]
+     query_params = []
+     if created_time_start:
+         query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
+     if created_time_end:
+         query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
+     query.append(",".join(query_params))
+     query.append("))")
+     condition = []
+     database = database or session.get_current_database()
+     schema = schema or session.get_current_schema()
+
+     # format database and schema identifiers
+     if database:
+         condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
+
+     if schema:
+         condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
+
+     if not include_deleted:
+         condition.append("DELETED_TIME IS NULL")
+
+     if len(condition) > 0:
+         query.append("WHERE " + " AND ".join(condition))
+     if limit > 0:
+         query.append(f"LIMIT {limit}")
+     df = session.sql("\n".join(query))
+     df = df.select(
+         df["NAME"].alias('"name"'),
+         df["STATUS"].alias('"status"'),
+         df["MESSAGE"].alias('"message"'),
+         df["DATABASE_NAME"].alias('"database_name"'),
+         df["SCHEMA_NAME"].alias('"schema_name"'),
+         df["OWNER"].alias('"owner"'),
+         df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
+         coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
+         df["CREATED_TIME"].alias('"created_time"'),
+         df["COMPLETED_TIME"].alias('"completed_time"'),
+     )
      return df
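Note: a minimal usage sketch of the reworked list_jobs API from this hunk. The `scope` argument is replaced by `database`/`schema`, and the result is now a pandas DataFrame (backed by SPCS job history, falling back to SHOW JOB SERVICES when history is unavailable). The database, schema, and connection setup below are illustrative placeholders, not values from the diff.

from snowflake.ml.jobs import list_jobs
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()  # assumes a default Snowflake connection is configured

# 1.8.6: list_jobs(limit=5, scope="schema").show()  -> Snowpark DataFrame
# 1.9.0: returns a pandas DataFrame; column names follow the select list in the hunk above
jobs = list_jobs(limit=5, database="MY_DB", schema="MY_SCHEMA", session=session)
print(jobs[["name", "status", "compute_pool", "created_time", "completed_time"]])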

@@ -74,12 +165,12 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob

      job_id = f"{database}.{schema}.{job_name}"
      try:
-         # Validate that job exists by doing a status check
+         # Validate that job exists by doing a spec lookup
          job = jb.MLJob[Any](job_id, session=session)
-         _ = job.status
+         _ = job._service_spec
          return job
-     except SnowparkSQLException as e:
-         if "does not exist" in e.message:
+     except errors.ProgrammingError as e:
+         if "does not exist" in str(e):
              raise ValueError(f"Job does not exist: {job_id}") from e
          raise
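Note: existence validation now goes through the job's service spec and surfaces a missing job as ValueError instead of leaking a SnowparkSQLException. A hedged sketch of caller-side handling this enables; the fully qualified job ID is a placeholder, and get_job is assumed to be importable from the package root as in the docstring examples above.

from snowflake.ml.jobs import get_job

try:
    job = get_job("MY_DB.MY_SCHEMA.MLJOB_01234567_ABCD")  # placeholder job ID
    print(job.status)
except ValueError as exc:
    # Raised when the referenced job service does not exist
    print(f"Job not found: {exc}")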

@@ -95,7 +186,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
          logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
      except Exception as e:
          logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job.id,)).collect()
+     session._conn.run_query("DROP SERVICE IDENTIFIER(?)", params=(job.id,), _force_qmark_paramstyle=True)


  @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -105,17 +196,11 @@ def submit_file(
      *,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a Python file as a job to the compute pool.
@@ -125,18 +210,20 @@ def submit_file(
          compute_pool: The compute pool to use for the job.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -146,17 +233,11 @@ def submit_file(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )
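Note: a hedged sketch of submitting a file under the 1.9.0 signature, where options such as min_instances, env_vars, and query_warehouse move into keyword arguments as documented above. The pool, stage, warehouse, and file names are placeholders, and an existing compute pool and stage are assumed.

from snowflake.ml.jobs import submit_file

job = submit_file(
    "train.py",                                   # placeholder local script
    "MY_COMPUTE_POOL",                            # placeholder compute pool
    stage_name="@MY_DB.MY_SCHEMA.PAYLOAD_STAGE",  # placeholder payload stage
    args=["--epochs", "10"],
    target_instances=2,
    pip_requirements=["xgboost"],
    # Formerly explicit parameters, now passed through **kwargs:
    min_instances=1,
    env_vars={"LOG_LEVEL": "DEBUG"},
    query_warehouse="MY_WAREHOUSE",
)
print(job.id, job.status)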

@@ -168,17 +249,11 @@ def submit_directory(
      entrypoint: str,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a directory containing Python script(s) as a job to the compute pool.
@@ -189,18 +264,20 @@ def submit_directory(
          entrypoint: The relative path to the entry point script inside the source directory.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -211,17 +288,11 @@ def submit_directory(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )
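Note: the same kwargs migration applies to submit_directory. A brief, hedged sketch with placeholder names, assuming a local ./src directory whose entry point is main.py and an existing external access integration.

from snowflake.ml.jobs import submit_directory

job = submit_directory(
    "./src",                                      # placeholder source directory
    "MY_COMPUTE_POOL",
    entrypoint="main.py",                         # relative to ./src
    stage_name="@MY_DB.MY_SCHEMA.PAYLOAD_STAGE",
    target_instances=1,
    external_access_integrations=["MY_PYPI_EAI"],
    database="MY_DB",                             # database/schema now travel via **kwargs
    schema="MY_SCHEMA",
)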

@@ -233,17 +304,11 @@ def submit_from_stage(
      entrypoint: str,
      stage_name: str,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      """
      Submit a directory containing Python script(s) as a job to the compute pool.
@@ -254,19 +319,20 @@ def submit_from_stage(
          entrypoint: a stage path containing the entry point script inside the source directory.
          stage_name: The name of the stage where the job payload will be uploaded.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
+         target_instances: The number of nodes in the job. If none specified, create a single node job.
          pip_requirements: A list of pip requirements for the job.
          external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
-         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
-
+         kwargs: Additional keyword arguments. Supported arguments:
+             database (str): The database to use for the job.
+             schema (str): The schema to use for the job.
+             min_instances (int): The minimum number of nodes required to start the job.
+                 If none specified, defaults to target_instances. If set, the job
+                 will not start until the minimum number of nodes is available.
+             env_vars (dict): Environment variables to set in container.
+             enable_metrics (bool): Whether to enable metrics publishing for the job.
+             query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+             spec_overrides (dict): A dictionary of overrides for the service spec.

      Returns:
          An object representing the submitted job.
@@ -277,17 +343,11 @@ def submit_from_stage(
          args=args,
          compute_pool=compute_pool,
          stage_name=stage_name,
-         env_vars=env_vars,
+         target_instances=target_instances,
          pip_requirements=pip_requirements,
          external_access_integrations=external_access_integrations,
-         query_warehouse=query_warehouse,
-         spec_overrides=spec_overrides,
-         target_instances=target_instances,
-         min_instances=min_instances,
-         enable_metrics=enable_metrics,
-         database=database,
-         schema=schema,
          session=session,
+         **kwargs,
      )
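Note: a hedged sketch for submit_from_stage, which takes a staged source path rather than a local one; the stage, pool, and path names are placeholders, and the kwarg names follow the Supported arguments list documented above.

from snowflake.ml.jobs import submit_from_stage

job = submit_from_stage(
    "@MY_DB.MY_SCHEMA.CODE_STAGE/app",            # placeholder staged source path
    "MY_COMPUTE_POOL",
    entrypoint="@MY_DB.MY_SCHEMA.CODE_STAGE/app/main.py",
    stage_name="@MY_DB.MY_SCHEMA.PAYLOAD_STAGE",
    target_instances=3,
    min_instances=2,      # kwarg: job waits until at least 2 nodes are available
    enable_metrics=True,  # kwarg
)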

@@ -299,17 +359,11 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[None]:
      ...

@@ -322,17 +376,11 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
+     target_instances: int = 1,
      pip_requirements: Optional[list[str]] = None,
      external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
-     target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[T]:
      ...

@@ -345,8 +393,9 @@ def _submit_job(
          # TODO: Log lengths of args, env_vars, and spec_overrides values
          "pip_requirements",
          "external_access_integrations",
+         "num_instances", # deprecated
          "target_instances",
-         "enable_metrics",
+         "min_instances",
      ],
  )
  def _submit_job(
@@ -356,17 +405,9 @@ def _submit_job(
      stage_name: str,
      entrypoint: Optional[str] = None,
      args: Optional[list[str]] = None,
-     env_vars: Optional[dict[str, str]] = None,
-     pip_requirements: Optional[list[str]] = None,
-     external_access_integrations: Optional[list[str]] = None,
-     query_warehouse: Optional[str] = None,
-     spec_overrides: Optional[dict[str, Any]] = None,
      target_instances: int = 1,
-     min_instances: Optional[int] = None,
-     enable_metrics: bool = False,
-     database: Optional[str] = None,
-     schema: Optional[str] = None,
      session: Optional[snowpark.Session] = None,
+     **kwargs: Any,
  ) -> jb.MLJob[T]:
      """
      Submit a job to the compute pool.
@@ -377,18 +418,9 @@ def _submit_job(
          stage_name: The name of the stage where the job payload will be uploaded.
          entrypoint: The entry point for the job execution. Required if source is a directory.
          args: A list of arguments to pass to the job.
-         env_vars: Environment variables to set in container
-         pip_requirements: A list of pip requirements for the job.
-         external_access_integrations: A list of external access integrations.
-         query_warehouse: The query warehouse to use. Defaults to session warehouse.
-         spec_overrides: Custom service specification overrides to apply.
          target_instances: The number of instances to use for the job. If none specified, single node job is created.
-         min_instances: The minimum number of nodes required to start the job. If none specified,
-             defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-         enable_metrics: Whether to enable metrics publishing for the job.
-         database: The database to use.
-         schema: The schema to use.
          session: The Snowpark session to use. If none specified, uses active session.
+         kwargs: Additional keyword arguments.

      Returns:
          An object representing the submitted job.
@@ -396,35 +428,52 @@ def _submit_job(
      Raises:
          RuntimeError: If required Snowflake features are not enabled.
          ValueError: If database or schema value(s) are invalid
+         errors.ProgrammingError: if the SQL query or its parameters are invalid
      """
+     session = session or get_active_session()
+
+     # Use kwargs for less common optional parameters
+     database = kwargs.pop("database", None)
+     schema = kwargs.pop("schema", None)
+     min_instances = kwargs.pop("min_instances", target_instances)
+     pip_requirements = kwargs.pop("pip_requirements", None)
+     external_access_integrations = kwargs.pop("external_access_integrations", None)
+     env_vars = kwargs.pop("env_vars", None)
+     spec_overrides = kwargs.pop("spec_overrides", None)
+     enable_metrics = kwargs.pop("enable_metrics", True)
+     query_warehouse = kwargs.pop("query_warehouse", None)
+
+     # Check for deprecated args
+     if "num_instances" in kwargs:
+         logger.warning(
+             "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+         )
+         target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+     # Warn if there are unknown kwargs
+     if kwargs:
+         logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+     # Validate parameters
      if database and not schema:
          raise ValueError("Schema must be specified if database is specified.")
      if target_instances < 1:
          raise ValueError("target_instances must be greater than 0.")
-
-     min_instances = target_instances if min_instances is None else min_instances
      if not (0 < min_instances <= target_instances):
          raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-
-     session = session or get_active_session()
-
      if min_instances > 1:
          # Validate min_instances against compute pool max_nodes
          pool_info = jb._get_compute_pool_info(session, compute_pool)
-         max_nodes = int(pool_info["max_nodes"])
+         requested_attributes = query_helper.get_attribute_map(session, {"max_nodes": 3})
+         max_nodes = int(pool_info[requested_attributes["max_nodes"]])
          if min_instances > max_nodes:
              raise ValueError(
                  f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
                  f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
              )

-     # Validate database and schema identifiers on client side since
-     # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
-     database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
-     schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
-
      job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-     job_id = f"{database}.{schema}.{job_name}"
+     job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
      stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
      stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
      stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
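Note: the kwargs handling above keeps the deprecated num_instances argument working by folding it into target_instances and logging a warning, and it warns about any unrecognized kwargs. A standalone sketch of that pattern (not the package's code) to show the resulting behavior:

import logging
from typing import Any

logger = logging.getLogger(__name__)

def _resolve_instances(target_instances: int = 1, **kwargs: Any) -> int:
    """Illustration of the deprecation handling shown in the hunk above."""
    if "num_instances" in kwargs:
        logger.warning(
            "'num_instances' is deprecated and will be removed in a future release. "
            "Use 'target_instances' instead."
        )
        # The larger of the two values wins, so num_instances=4 with the default
        # target_instances=1 still yields a 4-node job.
        target_instances = max(target_instances, kwargs.pop("num_instances"))
    if kwargs:
        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
    return target_instances

assert _resolve_instances(num_instances=4) == 4
assert _resolve_instances(target_instances=3, num_instances=2) == 3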

@@ -482,12 +531,12 @@
      query_text = "\n".join(line for line in query if line)

      try:
-         _ = session.sql(query_text, params=params).collect()
-     except SnowparkSQLException as e:
-         if "invalid property 'ASYNC'" in e.message:
+         _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
+     except errors.ProgrammingError as e:
+         if "invalid property 'ASYNC'" in str(e):
              raise RuntimeError(
                  "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
              ) from e
          raise

-     return jb.MLJob(job_id, service_spec=spec, session=session)
+     return get_job(job_id, session=session)

snowflake/ml/model/_client/ops/model_ops.py

@@ -955,7 +955,7 @@ class ModelOperator:
              output_with_input_features = False
              df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation)
              s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
-                 self._session, df, keep_order=keep_order, features=signature.inputs
+                 self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
              )
          else:
              keep_order = False
@@ -969,9 +969,16 @@ class ModelOperator:

          # Compose input and output names
          input_args = []
+         quoted_identifiers_ignore_case = (
+             snowpark_handler.SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+                 self._session, statement_params
+             )
+         )
+
          for input_feature in signature.inputs:
              col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name)
-
+             if quoted_identifiers_ignore_case:
+                 col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
              input_args.append(col_name)

          returns = []
@@ -1051,7 +1058,9 @@ class ModelOperator:

          # Get final result
          if not isinstance(X, dataframe.DataFrame):
-             return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
+             return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+                 df_res, features=signature.outputs, statement_params=statement_params
+             )
          else:
              return df_res
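Note: these hunks thread statement_params through the Snowpark conversions and add handling for sessions where QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, upper-casing input identifiers so inference on local DataFrames still matches the signature. A hedged sketch of a session setup that exercises this path; the registry model, version, feature names, and function name are placeholders, and a model is assumed to already be logged in the registry.

import pandas as pd

from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()  # assumes a configured default connection

# 1.9.0 detects this session parameter and upper-cases input identifiers (see the
# hunk above); earlier versions did not special-case it.
session.sql("ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE").collect()

reg = Registry(session=session)
mv = reg.get_model("MY_MODEL").version("V1")  # placeholder model and version
preds = mv.run(pd.DataFrame({"FEATURE_1": [1.0], "FEATURE_2": [2.0]}), function_name="predict")
print(preds)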

snowflake/ml/model/_client/ops/service_ops.py

@@ -518,7 +518,7 @@ class ServiceOperator:
              output_with_input_features = False
              df = model_signature._convert_and_validate_local_data(X, signature.inputs)
              s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
-                 self._session, df, keep_order=keep_order, features=signature.inputs
+                 self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
              )
          else:
              keep_order = False
@@ -630,7 +630,9 @@ class ServiceOperator:

          # get final result
          if not isinstance(X, dataframe.DataFrame):
-             return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
+             return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+                 df_res, features=signature.outputs, statement_params=statement_params
+             )
          else:
              return df_res

snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -2,6 +2,8 @@ from typing import Optional

  from pydantic import BaseModel

+ BaseModel.model_config["protected_namespaces"] = ()
+

  class Model(BaseModel):
      name: str
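Note: Pydantic v2 reserves the "model_" prefix as a protected namespace and warns when fields such as model_name are declared; clearing protected_namespaces on the base config suppresses that for the deployment-spec models. A small self-contained illustration of the effect; the class and field names here are hypothetical, not taken from the schema file.

from pydantic import BaseModel

# Without this, Pydantic v2 warns that fields starting with "model_" conflict
# with its protected "model_" namespace.
BaseModel.model_config["protected_namespaces"] = ()

class DeploymentSpec(BaseModel):  # hypothetical example class
    model_name: str
    model_version: str

spec = DeploymentSpec(model_name="MY_MODEL", model_version="V1")
print(spec.model_dump())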