snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,18 +1,21 @@
 import logging
 import pathlib
 import textwrap
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
+import pandas as pd
 import yaml
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
-from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.functions import coalesce, col, lit, when
 
 logger = logging.getLogger(__name__)
 
@@ -25,39 +28,127 @@ T = TypeVar("T")
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
 def list_jobs(
     limit: int = 10,
-
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
-) ->
+) -> pd.DataFrame:
     """
-    Returns a
+    Returns a Pandas DataFrame with the list of jobs in the current session.
 
     Args:
         limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
-
+        database: The database to use. If not specified, uses the current database.
+        schema: The schema to use. If not specified, uses the current schema.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
         A DataFrame with the list of jobs.
 
+    Raises:
+        SnowparkSQLException: if there is an error retrieving the job history.
+
     Examples:
         >>> from snowflake.ml.jobs import list_jobs
-        >>> list_jobs(limit=5)
+        >>> list_jobs(limit=5)
     """
     session = session or get_active_session()
+    try:
+        df = _get_job_history_spcs(
+            session,
+            limit=limit,
+            database=database,
+            schema=schema,
+        )
+        return df.to_pandas()
+    except SnowparkSQLException as spcs_error:
+        if spcs_error.sql_error_code == 2143:
+            logger.debug("Job history is not enabled. Please enable it to use this feature.")
+            df = _get_job_services(session, limit=limit, database=database, schema=schema)
+            return df.to_pandas()
+        raise
+
+
+def _get_job_services(
+    session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
+) -> snowpark.DataFrame:
     query = "SHOW JOB SERVICES"
     query += f" LIKE '{JOB_ID_PREFIX}%'"
-
-
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    if database is None and schema is None:
+        query += "IN account"
+    elif not schema:
+        query += f" IN DATABASE {database}"
+    else:
+        query += f" IN {database}.{schema}"
     if limit > 0:
         query += f" LIMIT {limit}"
     df = session.sql(query)
     df = df.select(
         df['"name"'],
-        df['"owner"'],
         df['"status"'],
-
+        lit(None).alias('"message"'),
+        df['"database_name"'],
+        df['"schema_name"'],
+        df['"owner"'],
         df['"compute_pool"'],
-
+        df['"target_instances"'],
+        df['"created_on"'].alias('"created_time"'),
+        when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
+        .otherwise(lit(None))
+        .alias('"completed_time"'),
+    ).order_by('"created_time"', ascending=False)
+    return df
+
+
+def _get_job_history_spcs(
+    session: snowpark.Session,
+    limit: int = 10,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    include_deleted: bool = False,
+    created_time_start: Optional[str] = None,
+    created_time_end: Optional[str] = None,
+) -> snowpark.DataFrame:
+    query = ["select * from table(snowflake.spcs.get_job_history("]
+    query_params = []
+    if created_time_start:
+        query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
+    if created_time_end:
+        query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
+    query.append(",".join(query_params))
+    query.append("))")
+    condition = []
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+
+    # format database and schema identifiers
+    if database:
+        condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
+
+    if schema:
+        condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
+
+    if not include_deleted:
+        condition.append("DELETED_TIME IS NULL")
+
+    if len(condition) > 0:
+        query.append("WHERE " + " AND ".join(condition))
+    if limit > 0:
+        query.append(f"LIMIT {limit}")
+    df = session.sql("\n".join(query))
+    df = df.select(
+        df["NAME"].alias('"name"'),
+        df["STATUS"].alias('"status"'),
+        df["MESSAGE"].alias('"message"'),
+        df["DATABASE_NAME"].alias('"database_name"'),
+        df["SCHEMA_NAME"].alias('"schema_name"'),
+        df["OWNER"].alias('"owner"'),
+        df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
+        coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
+        df["CREATED_TIME"].alias('"created_time"'),
+        df["COMPLETED_TIME"].alias('"completed_time"'),
+    )
     return df
 
 
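Note: a minimal usage sketch of the reworked list_jobs API above, assuming an active Snowpark session; the database and schema names are placeholders. In 1.9.0 the result is a pandas DataFrame, built from SNOWFLAKE.SPCS.GET_JOB_HISTORY when job history is enabled and falling back to SHOW JOB SERVICES (SQL error 2143) otherwise.

    from snowflake.ml.jobs import list_jobs

    # Returns a pandas DataFrame in 1.9.0 (previously a Snowpark DataFrame).
    jobs_df = list_jobs(limit=5, database="MY_DB", schema="MY_SCHEMA")
    print(jobs_df[["name", "status", "compute_pool", "created_time"]])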
@@ -74,12 +165,12 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
     job_id = f"{database}.{schema}.{job_name}"
     try:
-        # Validate that job exists by doing a
+        # Validate that job exists by doing a spec lookup
         job = jb.MLJob[Any](job_id, session=session)
-        _ = job.
+        _ = job._service_spec
         return job
-    except
-        if "does not exist" in e
+    except errors.ProgrammingError as e:
+        if "does not exist" in str(e):
             raise ValueError(f"Job does not exist: {job_id}") from e
         raise
 
@@ -87,13 +178,15 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
     """Delete a job service from the backend. Status and logs will be lost."""
-    if isinstance(job, jb.MLJob)
-
-
-
-
-
-
+    job = job if isinstance(job, jb.MLJob) else get_job(job, session=session)
+    session = job._session
+    try:
+        stage_path = job._stage_path
+        session.sql(f"REMOVE {stage_path}/").collect()
+        logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
+    except Exception as e:
+        logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
+    session._conn.run_query("DROP SERVICE IDENTIFIER(?)", params=(job.id,), _force_qmark_paramstyle=True)
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
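Note: delete_job now attempts to clean up the job's stage files before dropping the service, as shown above. A minimal sketch; the job ID below is a placeholder:

    from snowflake.ml.jobs import delete_job

    # Accepts either an MLJob object or a fully qualified job ID string.
    delete_job("MY_DB.MY_SCHEMA.MLJOB_00000000_0000_0000_0000_000000000000")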
@@ -103,17 +196,11 @@ def submit_file(
     *,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a Python file as a job to the compute pool.
@@ -123,17 +210,20 @@ def submit_file(
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -143,17 +233,11 @@ def submit_file(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
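Note: with the 1.9.0 signature, options such as min_instances, enable_metrics, query_warehouse, spec_overrides, env_vars, database, and schema are forwarded through **kwargs instead of being named parameters. A sketch of an equivalent call; the pool, stage, and warehouse names are placeholders:

    from snowflake.ml.jobs import submit_file

    job = submit_file(
        "train.py",
        "MY_COMPUTE_POOL",
        stage_name="@MY_STAGE",
        args=["--epochs", "10"],
        target_instances=2,
        # Forwarded via **kwargs in 1.9.0:
        min_instances=1,
        query_warehouse="MY_WAREHOUSE",
    )
    print(job.id)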
@@ -165,17 +249,11 @@ def submit_directory(
     entrypoint: str,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
@@ -186,17 +264,20 @@ def submit_directory(
         entrypoint: The relative path to the entry point script inside the source directory.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -207,17 +288,66 @@ def submit_directory(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-
-
+        session=session,
+        **kwargs,
+    )
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT)
+def submit_from_stage(
+    source: str,
+    compute_pool: str,
+    *,
+    entrypoint: str,
+    stage_name: str,
+    args: Optional[list[str]] = None,
+    target_instances: int = 1,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
+    session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
+) -> jb.MLJob[None]:
+    """
+    Submit a directory containing Python script(s) as a job to the compute pool.
+
+    Args:
+        source: a stage path or a stage containing the job payload.
+        compute_pool: The compute pool to use for the job.
+        entrypoint: a stage path containing the entry point script inside the source directory.
+        stage_name: The name of the stage where the job payload will be uploaded.
+        args: A list of arguments to pass to the job.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
+        pip_requirements: A list of pip requirements for the job.
+        external_access_integrations: A list of external access integrations.
+        session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
+
+    Returns:
+        An object representing the submitted job.
+    """
+    return _submit_job(
+        source=source,
+        entrypoint=entrypoint,
+        args=args,
+        compute_pool=compute_pool,
+        stage_name=stage_name,
         target_instances=target_instances,
-
-
-        database=database,
-        schema=schema,
+        pip_requirements=pip_requirements,
+        external_access_integrations=external_access_integrations,
         session=session,
+        **kwargs,
     )
 
 
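Note: a usage sketch of the new submit_from_stage helper added above, assuming it is re-exported from snowflake.ml.jobs like the other submit functions (the two-line addition to snowflake/ml/jobs/__init__.py in the file list suggests this); the stage paths and pool name are placeholders:

    from snowflake.ml import jobs

    job = jobs.submit_from_stage(
        "@payload_stage/src",
        "MY_COMPUTE_POOL",
        entrypoint="@payload_stage/src/main.py",
        stage_name="@job_stage",
        target_instances=1,
    )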
@@ -229,17 +359,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     ...
 
@@ -252,17 +376,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
    pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     ...
 
@@ -275,8 +393,9 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
+        "num_instances",  # deprecated
         "target_instances",
-        "
+        "min_instances",
     ],
 )
 def _submit_job(
@@ -286,17 +405,9 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    pip_requirements: Optional[list[str]] = None,
-    external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     """
     Submit a job to the compute pool.
@@ -307,17 +418,9 @@ def _submit_job(
         stage_name: The name of the stage where the job payload will be uploaded.
         entrypoint: The entry point for the job execution. Required if source is a directory.
         args: A list of arguments to pass to the job.
-        env_vars: Environment variables to set in container
-        pip_requirements: A list of pip requirements for the job.
-        external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments.
 
     Returns:
         An object representing the submitted job.
@@ -325,23 +428,52 @@ def _submit_job(
     Raises:
         RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
+        errors.ProgrammingError: if the SQL query or its parameters are invalid
     """
-    if database and not schema:
-        raise ValueError("Schema must be specified if database is specified.")
-    if target_instances < 1 or min_instances < 1:
-        raise ValueError("target_instances and min_instances must be greater than 0.")
-    if min_instances > target_instances:
-        raise ValueError("min_instances must be less than or equal to target_instances.")
-
     session = session or get_active_session()
 
-    #
-
-
-
+    # Use kwargs for less common optional parameters
+    database = kwargs.pop("database", None)
+    schema = kwargs.pop("schema", None)
+    min_instances = kwargs.pop("min_instances", target_instances)
+    pip_requirements = kwargs.pop("pip_requirements", None)
+    external_access_integrations = kwargs.pop("external_access_integrations", None)
+    env_vars = kwargs.pop("env_vars", None)
+    spec_overrides = kwargs.pop("spec_overrides", None)
+    enable_metrics = kwargs.pop("enable_metrics", True)
+    query_warehouse = kwargs.pop("query_warehouse", None)
+
+    # Check for deprecated args
+    if "num_instances" in kwargs:
+        logger.warning(
+            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+        )
+        target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+    # Warn if there are unknown kwargs
+    if kwargs:
+        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+    # Validate parameters
+    if database and not schema:
+        raise ValueError("Schema must be specified if database is specified.")
+    if target_instances < 1:
+        raise ValueError("target_instances must be greater than 0.")
+    if not (0 < min_instances <= target_instances):
+        raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
+    if min_instances > 1:
+        # Validate min_instances against compute pool max_nodes
+        pool_info = jb._get_compute_pool_info(session, compute_pool)
+        requested_attributes = query_helper.get_attribute_map(session, {"max_nodes": 3})
+        max_nodes = int(pool_info[requested_attributes["max_nodes"]])
+        if min_instances > max_nodes:
+            raise ValueError(
+                f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
+                f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
+            )
 
     job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-    job_id =
+    job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
     stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
     stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
     stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
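Note: the _submit_job changes above follow a pop/validate/warn pattern for **kwargs: known options are popped with defaults, the deprecated num_instances alias is folded into target_instances with a warning, and anything left over is reported as unknown. A self-contained sketch of the same pattern, using hypothetical option names:

    import logging

    logger = logging.getLogger(__name__)

    def configure(target: int = 1, **kwargs) -> int:
        minimum = kwargs.pop("minimum", target)  # less common option passed via kwargs
        if "legacy_count" in kwargs:  # deprecated alias kept for compatibility
            logger.warning("'legacy_count' is deprecated; use 'target' instead.")
            target = max(target, kwargs.pop("legacy_count"))
        if kwargs:  # anything not popped above is unknown
            logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
        if not (0 < minimum <= target):
            raise ValueError("minimum must be between 1 and target.")
        return target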
@@ -399,12 +531,12 @@ def _submit_job(
     query_text = "\n".join(line for line in query if line)
 
     try:
-        _ = session.
-    except
-        if "invalid property 'ASYNC'" in e
+        _ = session._conn.run_query(query_text, params=params, _force_qmark_paramstyle=True)
+    except errors.ProgrammingError as e:
+        if "invalid property 'ASYNC'" in str(e):
             raise RuntimeError(
                 "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
             ) from e
         raise
 
-    return
+    return get_job(job_id, session=session)
snowflake/ml/model/_client/model/model_impl.py
CHANGED
@@ -426,3 +426,61 @@ class Model:
             schema_name=new_schema or self._model_ops._model_client._schema_name,
         )
         self._model_name = new_model
+
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model.
+
+        Returns:
+            str: HTML string containing formatted model details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Get default version
+        default_version = self.default.version_name
+
+        # Get versions info
+        try:
+            versions_df = self.show_versions()
+            versions_html = ""
+
+            for _, row in versions_df.iterrows():
+                versions_html += html_utils.create_version_item(
+                    version_name=row["name"],
+                    created_on=str(row["created_on"]),
+                    comment=str(row.get("comment", "")),
+                    is_default=bool(row["is_default_version"]),
+                )
+        except Exception:
+            versions_html = html_utils.create_error_message("Error retrieving versions")
+
+        # Get tags
+        try:
+            tags = self.show_tags()
+            if not tags:
+                tags_html = html_utils.create_error_message("No tags available")
+            else:
+                tags_html = ""
+                for tag_name, tag_value in tags.items():
+                    tags_html += html_utils.create_tag_item(tag_name, tag_value)
+        except Exception:
+            tags_html = html_utils.create_error_message("Error retrieving tags")
+
+        # Create main content sections
+        main_info = html_utils.create_grid_section(
+            [
+                ("Model Name", self.name),
+                ("Full Name", self.fully_qualified_name),
+                ("Description", self.description),
+                ("Default Version", default_version),
+            ]
+        )
+
+        versions_section = html_utils.create_section_header("Versions") + html_utils.create_content_section(
+            versions_html
+        )
+
+        tags_section = html_utils.create_section_header("Tags") + html_utils.create_content_section(tags_html)
+
+        content = main_info + versions_section + tags_section
+
+        return html_utils.create_base_container("Model Details", content)