snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +6 -5
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/spec_utils.py +6 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +179 -58
- snowflake/ml/jobs/manager.py +194 -145
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +4 -2
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +119 -42
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,18 +1,21 @@
 import logging
 import pathlib
 import textwrap
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
+import pandas as pd
 import yaml
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
-from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.functions import coalesce, col, lit, when
 
 logger = logging.getLogger(__name__)
 
@@ -25,39 +28,127 @@ T = TypeVar("T")
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
 def list_jobs(
     limit: int = 10,
-
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
-) ->
+) -> pd.DataFrame:
     """
-    Returns a
+    Returns a Pandas DataFrame with the list of jobs in the current session.
 
     Args:
         limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
-
+        database: The database to use. If not specified, uses the current database.
+        schema: The schema to use. If not specified, uses the current schema.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
         A DataFrame with the list of jobs.
 
+    Raises:
+        SnowparkSQLException: if there is an error retrieving the job history.
+
     Examples:
         >>> from snowflake.ml.jobs import list_jobs
-        >>> list_jobs(limit=5)
+        >>> list_jobs(limit=5)
     """
     session = session or get_active_session()
+    try:
+        df = _get_job_history_spcs(
+            session,
+            limit=limit,
+            database=database,
+            schema=schema,
+        )
+        return df.to_pandas()
+    except SnowparkSQLException as spcs_error:
+        if spcs_error.sql_error_code == 2143:
+            logger.debug("Job history is not enabled. Please enable it to use this feature.")
+            df = _get_job_services(session, limit=limit, database=database, schema=schema)
+            return df.to_pandas()
+        raise
+
+
+def _get_job_services(
+    session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
+) -> snowpark.DataFrame:
     query = "SHOW JOB SERVICES"
     query += f" LIKE '{JOB_ID_PREFIX}%'"
-
-
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    if database is None and schema is None:
+        query += "IN account"
+    elif not schema:
+        query += f" IN DATABASE {database}"
+    else:
+        query += f" IN {database}.{schema}"
     if limit > 0:
         query += f" LIMIT {limit}"
     df = session.sql(query)
     df = df.select(
         df['"name"'],
-        df['"owner"'],
         df['"status"'],
-
+        lit(None).alias('"message"'),
+        df['"database_name"'],
+        df['"schema_name"'],
+        df['"owner"'],
         df['"compute_pool"'],
-
+        df['"target_instances"'],
+        df['"created_on"'].alias('"created_time"'),
+        when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
+        .otherwise(lit(None))
+        .alias('"completed_time"'),
+    ).order_by('"created_time"', ascending=False)
+    return df
+
+
+def _get_job_history_spcs(
+    session: snowpark.Session,
+    limit: int = 10,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    include_deleted: bool = False,
+    created_time_start: Optional[str] = None,
+    created_time_end: Optional[str] = None,
+) -> snowpark.DataFrame:
+    query = ["select * from table(snowflake.spcs.get_job_history("]
+    query_params = []
+    if created_time_start:
+        query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
+    if created_time_end:
+        query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
+    query.append(",".join(query_params))
+    query.append("))")
+    condition = []
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+
+    # format database and schema identifiers
+    if database:
+        condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
+
+    if schema:
+        condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
+
+    if not include_deleted:
+        condition.append("DELETED_TIME IS NULL")
+
+    if len(condition) > 0:
+        query.append("WHERE " + " AND ".join(condition))
+    if limit > 0:
+        query.append(f"LIMIT {limit}")
+    df = session.sql("\n".join(query))
+    df = df.select(
+        df["NAME"].alias('"name"'),
+        df["STATUS"].alias('"status"'),
+        df["MESSAGE"].alias('"message"'),
+        df["DATABASE_NAME"].alias('"database_name"'),
+        df["SCHEMA_NAME"].alias('"schema_name"'),
+        df["OWNER"].alias('"owner"'),
+        df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
+        coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
+        df["CREATED_TIME"].alias('"created_time"'),
+        df["COMPLETED_TIME"].alias('"completed_time"'),
+    )
     return df
 
 
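Net effect of this hunk: `list_jobs` now returns a pandas DataFrame, accepts `database`/`schema` filters, and prefers the SPCS `get_job_history` table function, falling back to `SHOW JOB SERVICES` when job history is not enabled (SQL error code 2143). A minimal usage sketch, assuming an active Snowpark session; the database and schema names are hypothetical:

```python
from snowflake.ml.jobs import list_jobs

# Returns a pandas DataFrame filtered to the given scope.
jobs = list_jobs(limit=5, database="ML_DB", schema="JOBS")  # hypothetical names

# Column names come from the projections above.
print(jobs[["name", "status", "compute_pool", "created_time", "completed_time"]])
```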
@@ -74,12 +165,12 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
     job_id = f"{database}.{schema}.{job_name}"
     try:
-        # Validate that job exists by doing a
+        # Validate that job exists by doing a spec lookup
        job = jb.MLJob[Any](job_id, session=session)
-        _ = job.
+        _ = job._service_spec
         return job
-    except
-        if "does not exist" in e
+    except errors.ProgrammingError as e:
+        if "does not exist" in str(e):
             raise ValueError(f"Job does not exist: {job_id}") from e
         raise
 
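With the switch to `errors.ProgrammingError`, a nonexistent job surfaces from `get_job` as a `ValueError` rather than a raw connector error. For example (the job ID is hypothetical):

```python
from snowflake.ml.jobs import get_job

try:
    job = get_job("ML_DB.JOBS.MLJOB_00000000")  # hypothetical job ID
except ValueError as e:
    print(e)  # "Job does not exist: ..."
```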
@@ -95,7 +186,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
             logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
         except Exception as e:
             logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-    session.
+    session._conn.run_query("DROP SERVICE IDENTIFIER(?)", params=(job.id,), _force_qmark_paramstyle=True)
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
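`delete_job` now issues the `DROP SERVICE` through a parameter-bound query; note that `session._conn.run_query` is private connector API, so user code should stay on the public helper. Usage is unchanged (job ID hypothetical):

```python
from snowflake.ml.jobs import delete_job

# Stage cleanup failures are logged as warnings; the DROP SERVICE still runs.
delete_job("ML_DB.JOBS.MLJOB_00000000")  # hypothetical job ID
```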
@@ -105,17 +196,11 @@ def submit_file(
     *,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a Python file as a job to the compute pool.
@@ -125,18 +210,20 @@ def submit_file(
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
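Callers that previously used the removed keyword parameters can keep passing them by name, since they now travel through `**kwargs` as documented above. A sketch, assuming an active Snowpark session; the script path, pool, stage, and warehouse names are hypothetical:

```python
from snowflake.ml.jobs import submit_file

job = submit_file(
    "train.py",                         # hypothetical script path
    "MY_COMPUTE_POOL",                  # hypothetical compute pool
    stage_name="@ML_DB.JOBS.PAYLOADS",  # hypothetical stage
    args=["--epochs", "10"],
    target_instances=2,
    # Routed through **kwargs per the docstring above:
    min_instances=1,
    query_warehouse="MY_WH",            # hypothetical warehouse
    enable_metrics=True,
)
print(job.id)
```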
@@ -146,17 +233,11 @@ def submit_file(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -168,17 +249,11 @@ def submit_directory(
     entrypoint: str,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
@@ -189,18 +264,20 @@ def submit_directory(
         entrypoint: The relative path to the entry point script inside the source directory.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -211,17 +288,11 @@ def submit_directory(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -233,17 +304,11 @@ def submit_from_stage(
     entrypoint: str,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
@@ -254,19 +319,20 @@ def submit_from_stage(
         entrypoint: a stage path containing the entry point script inside the source directory.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
-
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -277,17 +343,11 @@ def submit_from_stage(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -299,17 +359,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     ...
 
@@ -322,17 +376,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     ...
 
@@ -345,8 +393,9 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
+        "num_instances",  # deprecated
         "target_instances",
-        "
+        "min_instances",
     ],
 )
 def _submit_job(
@@ -356,17 +405,9 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    pip_requirements: Optional[list[str]] = None,
-    external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     """
     Submit a job to the compute pool.
@@ -377,18 +418,9 @@ def _submit_job(
         stage_name: The name of the stage where the job payload will be uploaded.
         entrypoint: The entry point for the job execution. Required if source is a directory.
         args: A list of arguments to pass to the job.
-        env_vars: Environment variables to set in container
-        pip_requirements: A list of pip requirements for the job.
-        external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments.
 
     Returns:
         An object representing the submitted job.
@@ -396,35 +428,52 @@ def _submit_job(
     Raises:
         RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
+        errors.ProgrammingError: if the SQL query or its parameters are invalid
     """
+    session = session or get_active_session()
+
+    # Use kwargs for less common optional parameters
+    database = kwargs.pop("database", None)
+    schema = kwargs.pop("schema", None)
+    min_instances = kwargs.pop("min_instances", target_instances)
+    pip_requirements = kwargs.pop("pip_requirements", None)
+    external_access_integrations = kwargs.pop("external_access_integrations", None)
+    env_vars = kwargs.pop("env_vars", None)
+    spec_overrides = kwargs.pop("spec_overrides", None)
+    enable_metrics = kwargs.pop("enable_metrics", True)
+    query_warehouse = kwargs.pop("query_warehouse", None)
+
+    # Check for deprecated args
+    if "num_instances" in kwargs:
+        logger.warning(
+            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+        )
+        target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+    # Warn if there are unknown kwargs
+    if kwargs:
+        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+    # Validate parameters
     if database and not schema:
         raise ValueError("Schema must be specified if database is specified.")
     if target_instances < 1:
         raise ValueError("target_instances must be greater than 0.")
-
-    min_instances = target_instances if min_instances is None else min_instances
     if not (0 < min_instances <= target_instances):
         raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-
-    session = session or get_active_session()
-
     if min_instances > 1:
         # Validate min_instances against compute pool max_nodes
         pool_info = jb._get_compute_pool_info(session, compute_pool)
-
+        requested_attributes = query_helper.get_attribute_map(session, {"max_nodes": 3})
+        max_nodes = int(pool_info[requested_attributes["max_nodes"]])
         if min_instances > max_nodes:
             raise ValueError(
                 f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
                 f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
             )
 
-    # Validate database and schema identifiers on client side since
-    # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
-    database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
-    schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
-
     job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-    job_id =
+    job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
     stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
     stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
     stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
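Two behavioral details are worth flagging in the block above: `min_instances` is popped with `target_instances` as its default before the deprecated `num_instances` is applied, and `num_instances` is merged with `max(...)` rather than overriding. Also note that `enable_metrics` now defaults to `True` via `kwargs.pop`, where the removed keyword parameter defaulted to `False`. A standalone sketch of the resolution logic (not the library source; names mirror the diff):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_instances(target_instances: int = 1, **kwargs):
    # min_instances defaults to target_instances *before* num_instances is seen.
    min_instances = kwargs.pop("min_instances", target_instances)
    if "num_instances" in kwargs:  # deprecated alias
        logger.warning("'num_instances' is deprecated. Use 'target_instances' instead.")
        target_instances = max(target_instances, kwargs.pop("num_instances"))
    if kwargs:
        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
    return target_instances, min_instances


# resolve_instances(1, num_instances=3) -> (3, 1): the bump does not raise min_instances.
```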
@@ -482,12 +531,12 @@ def _submit_job(
     query_text = "\n".join(line for line in query if line)
 
     try:
-        _ = session.
-    except
-        if "invalid property 'ASYNC'" in e
+        _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
+    except errors.ProgrammingError as e:
+        if "invalid property 'ASYNC'" in str(e):
             raise RuntimeError(
                 "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
             ) from e
         raise
 
-    return
+    return get_job(job_id, session=session)
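When SPCS async jobs are disabled, submission now fails fast with the `RuntimeError` above. Enabling the parameter it names would look roughly like this (a sketch; it assumes a role privileged to set the parameter, and the appropriate scope may differ per account):

```python
from snowflake.snowpark.context import get_active_session

session = get_active_session()
# Parameter name taken from the error message above.
session.sql("ALTER ACCOUNT SET ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE").collect()
```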
snowflake/ml/model/_client/ops/model_ops.py
CHANGED
@@ -955,7 +955,7 @@ class ModelOperator:
             output_with_input_features = False
             df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation)
             s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
-                self._session, df, keep_order=keep_order, features=signature.inputs
+                self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
             )
         else:
             keep_order = False
@@ -969,9 +969,16 @@ class ModelOperator:
 
         # Compose input and output names
         input_args = []
+        quoted_identifiers_ignore_case = (
+            snowpark_handler.SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+                self._session, statement_params
+            )
+        )
+
         for input_feature in signature.inputs:
             col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name)
-
+            if quoted_identifiers_ignore_case:
+                col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
             input_args.append(col_name)
 
         returns = []
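Background for this hunk: with the `QUOTED_IDENTIFIERS_IGNORE_CASE` session parameter enabled, Snowflake resolves even double-quoted identifiers case-insensitively (effectively upper-cased), so generated SQL must reference upper-cased column names. An illustration of the identifier the new branch produces (not library code):

```python
# With QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE, a feature named "input_col"
# must be addressed as "INPUT_COL" in generated SQL.
feature_name = "input_col"
col_name = f'"{feature_name.upper()}"'  # '"INPUT_COL"'
```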
@@ -1051,7 +1058,9 @@ class ModelOperator:
 
         # Get final result
         if not isinstance(X, dataframe.DataFrame):
-            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+                df_res, features=signature.outputs, statement_params=statement_params
+            )
         else:
             return df_res
 
snowflake/ml/model/_client/ops/service_ops.py
CHANGED
@@ -518,7 +518,7 @@ class ServiceOperator:
             output_with_input_features = False
             df = model_signature._convert_and_validate_local_data(X, signature.inputs)
             s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
-                self._session, df, keep_order=keep_order, features=signature.inputs
+                self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
             )
         else:
             keep_order = False
@@ -630,7 +630,9 @@ class ServiceOperator:
 
         # get final result
         if not isinstance(X, dataframe.DataFrame):
-            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
+                df_res, features=signature.outputs, statement_params=statement_params
+            )
         else:
             return df_res
 