snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,18 +1,20 @@
 import logging
 import pathlib
 import textwrap
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
+import pandas as pd
 import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
-from snowflake.ml.jobs._utils import payload_utils, spec_utils
+from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
+from snowflake.snowpark.functions import coalesce, col, lit, when
 
 logger = logging.getLogger(__name__)
 
@@ -25,39 +27,127 @@ T = TypeVar("T")
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
 def list_jobs(
     limit: int = 10,
-
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
-) ->
+) -> pd.DataFrame:
     """
-    Returns a
+    Returns a Pandas DataFrame with the list of jobs in the current session.
 
     Args:
         limit: The maximum number of jobs to return. Non-positive values are treated as no limit.
-
+        database: The database to use. If not specified, uses the current database.
+        schema: The schema to use. If not specified, uses the current schema.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
         A DataFrame with the list of jobs.
 
+    Raises:
+        SnowparkSQLException: if there is an error retrieving the job history.
+
     Examples:
         >>> from snowflake.ml.jobs import list_jobs
-        >>> list_jobs(limit=5)
+        >>> list_jobs(limit=5)
     """
     session = session or get_active_session()
+    try:
+        df = _get_job_history_spcs(
+            session,
+            limit=limit,
+            database=database,
+            schema=schema,
+        )
+        return df.to_pandas()
+    except SnowparkSQLException as spcs_error:
+        if spcs_error.sql_error_code == 2143:
+            logger.debug("Job history is not enabled. Please enable it to use this feature.")
+            df = _get_job_services(session, limit=limit, database=database, schema=schema)
+            return df.to_pandas()
+        raise
+
+
+def _get_job_services(
+    session: snowpark.Session, limit: int = 10, database: Optional[str] = None, schema: Optional[str] = None
+) -> snowpark.DataFrame:
     query = "SHOW JOB SERVICES"
     query += f" LIKE '{JOB_ID_PREFIX}%'"
-
-
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+    if database is None and schema is None:
+        query += "IN account"
+    elif not schema:
+        query += f" IN DATABASE {database}"
+    else:
+        query += f" IN {database}.{schema}"
     if limit > 0:
         query += f" LIMIT {limit}"
     df = session.sql(query)
     df = df.select(
         df['"name"'],
-        df['"owner"'],
         df['"status"'],
-
+        lit(None).alias('"message"'),
+        df['"database_name"'],
+        df['"schema_name"'],
+        df['"owner"'],
         df['"compute_pool"'],
-
+        df['"target_instances"'],
+        df['"created_on"'].alias('"created_time"'),
+        when(col('"status"').isin(jb.TERMINAL_JOB_STATUSES), col('"updated_on"'))
+        .otherwise(lit(None))
+        .alias('"completed_time"'),
+    ).order_by('"created_time"', ascending=False)
+    return df
+
+
+def _get_job_history_spcs(
+    session: snowpark.Session,
+    limit: int = 10,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    include_deleted: bool = False,
+    created_time_start: Optional[str] = None,
+    created_time_end: Optional[str] = None,
+) -> snowpark.DataFrame:
+    query = ["select * from table(snowflake.spcs.get_job_history("]
+    query_params = []
+    if created_time_start:
+        query_params.append(f"created_time_start => TO_TIMESTAMP_LTZ('{created_time_start}')")
+    if created_time_end:
+        query_params.append(f"created_time_end => TO_TIMESTAMP_LTZ('{created_time_end}')")
+    query.append(",".join(query_params))
+    query.append("))")
+    condition = []
+    database = database or session.get_current_database()
+    schema = schema or session.get_current_schema()
+
+    # format database and schema identifiers
+    if database:
+        condition.append(f"DATABASE_NAME = '{identifier.resolve_identifier(database)}'")
+
+    if schema:
+        condition.append(f"SCHEMA_NAME = '{identifier.resolve_identifier(schema)}'")
+
+    if not include_deleted:
+        condition.append("DELETED_TIME IS NULL")
+
+    if len(condition) > 0:
+        query.append("WHERE " + " AND ".join(condition))
+    if limit > 0:
+        query.append(f"LIMIT {limit}")
+    df = session.sql("\n".join(query))
+    df = df.select(
+        df["NAME"].alias('"name"'),
+        df["STATUS"].alias('"status"'),
+        df["MESSAGE"].alias('"message"'),
+        df["DATABASE_NAME"].alias('"database_name"'),
+        df["SCHEMA_NAME"].alias('"schema_name"'),
+        df["OWNER"].alias('"owner"'),
+        df["COMPUTE_POOL_NAME"].alias('"compute_pool"'),
+        coalesce(df["PARAMETERS"]["REPLICAS"], lit(1)).alias('"target_instances"'),
+        df["CREATED_TIME"].alias('"created_time"'),
+        df["COMPLETED_TIME"].alias('"completed_time"'),
+    )
     return df
 
 
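Taken together, this hunk replaces the old SHOW JOB SERVICES-based listing with an SPCS job-history lookup and a pandas result. A minimal usage sketch of the new API (not part of the diff; the database and schema names are placeholders):

from snowflake.ml.jobs import list_jobs

# Returns a pandas DataFrame; rows come from snowflake.spcs.get_job_history(),
# falling back to SHOW JOB SERVICES when job history is not enabled (SQL error 2143).
jobs = list_jobs(limit=5, database="MY_DB", schema="MY_SCHEMA")
print(jobs.head())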
@@ -74,9 +164,9 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
     job_id = f"{database}.{schema}.{job_name}"
     try:
-        # Validate that job exists by doing a
+        # Validate that job exists by doing a spec lookup
         job = jb.MLJob[Any](job_id, session=session)
-        _ = job.
+        _ = job._service_spec
         return job
     except SnowparkSQLException as e:
         if "does not exist" in e.message:
@@ -95,7 +185,7 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
             logger.info(f"Successfully cleaned up stage files for job {job.id} at {stage_path}")
         except Exception as e:
             logger.warning(f"Failed to clean up stage files for job {job.id}: {e}")
-
+    query_helper.run_query(session, "DROP SERVICE IDENTIFIER(?)", params=(job.id,))
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
@@ -105,17 +195,11 @@ def submit_file(
     *,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a Python file as a job to the compute pool.
@@ -125,18 +209,20 @@ def submit_file(
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
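submit_file (and, as the following hunks show, submit_directory and submit_from_stage) now routes the less common options through **kwargs instead of named parameters. A hedged usage sketch, assuming the leading positional parameters (script path and compute pool) are unchanged from 1.8.6; every identifier below is a placeholder:

from snowflake.ml.jobs import submit_file

job = submit_file(
    "train.py",                       # placeholder script path
    "MY_COMPUTE_POOL",                # placeholder compute pool
    stage_name="@MY_STAGE/payloads",  # placeholder stage
    args=["--epochs", "10"],
    target_instances=2,
    # Formerly named parameters, now passed as keyword arguments:
    database="MY_DB",
    schema="MY_SCHEMA",
    min_instances=1,
    env_vars={"TRAINING_MODE": "dev"},
    enable_metrics=True,
    query_warehouse="MY_WAREHOUSE",
)
print(job.id)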
@@ -146,17 +232,11 @@ def submit_file(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -168,17 +248,11 @@ def submit_directory(
     entrypoint: str,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
@@ -189,18 +263,20 @@ def submit_directory(
         entrypoint: The relative path to the entry point script inside the source directory.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -211,17 +287,11 @@ def submit_directory(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -233,17 +303,11 @@ def submit_from_stage(
     entrypoint: str,
     stage_name: str,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
@@ -254,19 +318,20 @@ def submit_from_stage(
         entrypoint: a stage path containing the entry point script inside the source directory.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
-
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
-        target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
-
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         An object representing the submitted job.
@@ -277,17 +342,11 @@ def submit_from_stage(
         args=args,
         compute_pool=compute_pool,
         stage_name=stage_name,
-
+        target_instances=target_instances,
         pip_requirements=pip_requirements,
         external_access_integrations=external_access_integrations,
-        query_warehouse=query_warehouse,
-        spec_overrides=spec_overrides,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        database=database,
-        schema=schema,
         session=session,
+        **kwargs,
     )
 
 
@@ -299,17 +358,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[None]:
     ...
 
@@ -322,17 +375,11 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     ...
 
@@ -345,8 +392,9 @@ def _submit_job(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
+        "num_instances",  # deprecated
         "target_instances",
-        "
+        "min_instances",
     ],
 )
 def _submit_job(
@@ -356,17 +404,9 @@ def _submit_job(
     stage_name: str,
     entrypoint: Optional[str] = None,
     args: Optional[list[str]] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    pip_requirements: Optional[list[str]] = None,
-    external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[dict[str, Any]] = None,
     target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> jb.MLJob[T]:
     """
     Submit a job to the compute pool.
@@ -377,37 +417,48 @@ def _submit_job(
         stage_name: The name of the stage where the job payload will be uploaded.
         entrypoint: The entry point for the job execution. Required if source is a directory.
         args: A list of arguments to pass to the job.
-        env_vars: Environment variables to set in container
-        pip_requirements: A list of pip requirements for the job.
-        external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        spec_overrides: Custom service specification overrides to apply.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use.
-        schema: The schema to use.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments.
 
     Returns:
         An object representing the submitted job.
 
     Raises:
-        RuntimeError: If required Snowflake features are not enabled.
         ValueError: If database or schema value(s) are invalid
+        SnowparkSQLException: If there is an error submitting the job.
     """
+    session = session or get_active_session()
+
+    # Check for deprecated args
+    if "num_instances" in kwargs:
+        logger.warning(
+            "'num_instances' is deprecated and will be removed in a future release. Use 'target_instances' instead."
+        )
+        target_instances = max(target_instances, kwargs.pop("num_instances"))
+
+    # Use kwargs for less common optional parameters
+    database = kwargs.pop("database", None)
+    schema = kwargs.pop("schema", None)
+    min_instances = kwargs.pop("min_instances", target_instances)
+    pip_requirements = kwargs.pop("pip_requirements", None)
+    external_access_integrations = kwargs.pop("external_access_integrations", None)
+    env_vars = kwargs.pop("env_vars", None)
+    spec_overrides = kwargs.pop("spec_overrides", None)
+    enable_metrics = kwargs.pop("enable_metrics", True)
+    query_warehouse = kwargs.pop("query_warehouse", None)
+
+    # Warn if there are unknown kwargs
+    if kwargs:
+        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
+
+    # Validate parameters
     if database and not schema:
         raise ValueError("Schema must be specified if database is specified.")
     if target_instances < 1:
         raise ValueError("target_instances must be greater than 0.")
-
-    min_instances = target_instances if min_instances is None else min_instances
     if not (0 < min_instances <= target_instances):
         raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-
-    session = session or get_active_session()
-
     if min_instances > 1:
         # Validate min_instances against compute pool max_nodes
         pool_info = jb._get_compute_pool_info(session, compute_pool)
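The body above folds the deprecated num_instances keyword into target_instances and warns once. A self-contained sketch of that handling for reference; _resolve_instances is an illustrative name, not a function in the package:

import logging

logger = logging.getLogger(__name__)


def _resolve_instances(target_instances: int = 1, **kwargs) -> int:
    # Mirror the new _submit_job behavior: accept the deprecated kwarg, warn,
    # and keep the larger of the two values.
    if "num_instances" in kwargs:
        logger.warning(
            "'num_instances' is deprecated and will be removed in a future release. "
            "Use 'target_instances' instead."
        )
        target_instances = max(target_instances, kwargs.pop("num_instances"))
    if kwargs:
        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
    return target_instances


assert _resolve_instances(num_instances=3) == 3
assert _resolve_instances(target_instances=4, num_instances=2) == 4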
@@ -418,13 +469,8 @@ def _submit_job(
                 f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
             )
 
-    # Validate database and schema identifiers on client side since
-    # SQL parser for EXECUTE JOB SERVICE seems to struggle with this
-    database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
-    schema = identifier.resolve_identifier(cast(str, schema or session.get_current_schema()))
-
     job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-    job_id =
+    job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
     stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
     stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
     stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
@@ -453,7 +499,48 @@ def _submit_job(
     if spec_overrides:
         spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
 
-
+    query_text, params = _generate_submission_query(
+        spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+    )
+    try:
+        _ = query_helper.run_query(session, query_text, params=params)
+    except SnowparkSQLException as e:
+        if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message:
+            logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.")
+            spec["spec"].pop("resourceManagement", None)
+            query_text, params = _generate_submission_query(
+                spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+            )
+            _ = query_helper.run_query(session, query_text, params=params)
+        else:
+            raise
+    return get_job(job_id, session=session)
+
+
+def _generate_submission_query(
+    spec: dict[str, Any],
+    external_access_integrations: list[str],
+    query_warehouse: Optional[str],
+    target_instances: int,
+    session: snowpark.Session,
+    compute_pool: str,
+    job_id: str,
+) -> tuple[str, list[Any]]:
+    """
+    Generate the SQL query for job submission.
+
+    Args:
+        spec: The service spec for the job.
+        external_access_integrations: The external access integrations for the job.
+        query_warehouse: The query warehouse for the job.
+        target_instances: The number of instances for the job.
+        session: The Snowpark session to use.
+        compute_pool: The compute pool to use for the job.
+        job_id: The ID of the job.
+
+    Returns:
+        A tuple containing the SQL query text and the parameters for the query.
+    """
     query_template = textwrap.dedent(
         """\
         EXECUTE JOB SERVICE
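The new submission path builds the EXECUTE JOB SERVICE statement in _generate_submission_query and retries once without the resourceManagement block when the account rejects it. A generic sketch of that degrade-and-retry pattern; submit_with_fallback and run_query are illustrative stand-ins, not package APIs:

from typing import Any, Callable


def submit_with_fallback(spec: dict[str, Any], run_query: Callable[[dict[str, Any]], None]) -> None:
    # Try the full spec first; if the optional block is rejected, drop it and retry once.
    try:
        run_query(spec)
    except Exception as exc:  # the real code narrows this to SnowparkSQLException
        if "unknown option 'resourceManagement'" in str(exc):
            spec["spec"].pop("resourceManagement", None)
            run_query(spec)
        else:
            raise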
@@ -477,17 +564,5 @@
     if target_instances > 1:
         query.append("REPLICAS = ?")
         params.append(target_instances)
-
-    # Submit job
     query_text = "\n".join(line for line in query if line)
-
-    try:
-        _ = session.sql(query_text, params=params).collect()
-    except SnowparkSQLException as e:
-        if "invalid property 'ASYNC'" in e.message:
-            raise RuntimeError(
-                "SPCS Async Jobs not enabled. Set parameter `ENABLE_SNOWSERVICES_ASYNC_JOBS = TRUE` to enable."
-            ) from e
-        raise
-
-    return jb.MLJob(job_id, service_spec=spec, session=session)
+    return query_text, params
snowflake/ml/lineage/lineage_node.py
CHANGED
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import identifier
+from snowflake.ml._internal.utils import identifier, mixins
 
 if TYPE_CHECKING:
     from snowflake.ml import dataset
@@ -15,7 +15,7 @@ _PROJECT = "LINEAGE"
 DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
 
 
-class LineageNode:
+class LineageNode(mixins.SerializableSessionMixin):
     """
     Represents a node in a lineage graph and serves as the base class for all machine learning objects.
     """
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -1,4 +1,5 @@
 import enum
+import logging
 import pathlib
 import tempfile
 import warnings
@@ -10,7 +11,7 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.lineage import lineage_node
-from snowflake.ml.model import type_hints
+from snowflake.ml.model import task, type_hints
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -401,7 +402,7 @@ class ModelVersion(lineage_node.LineageNode):
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
     )
-    def get_model_task(self) ->
+    def get_model_task(self) -> task.Task:
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
@@ -607,8 +608,8 @@ class ModelVersion(lineage_node.LineageNode):
         self,
         *,
         force: bool = False,
-        options: Optional[
-    ) ->
+        options: Optional[type_hints.ModelLoadOption] = None,
+    ) -> type_hints.SupportedModelType:
         """Load the underlying original Python object back from a model.
         This operation requires to have the exact the same environment as the one when logging the model, otherwise,
         the model might be not functional or some other problems might occur.
@@ -889,6 +890,17 @@ class ModelVersion(lineage_node.LineageNode):
             project=_TELEMETRY_PROJECT,
             subproject=_TELEMETRY_SUBPROJECT,
         )
+
+        # Check root logger level and emit warning if needed
+        root_logger = logging.getLogger()
+        if root_logger.level in (logging.WARNING, logging.ERROR):
+            warnings.warn(
+                "Suppressing service logs. Set the log level to INFO if you would like "
+                "verbose service logs (e.g., logging.getLogger().setLevel(logging.INFO)).",
+                UserWarning,
+                stacklevel=2,
+            )
+
         if build_external_access_integration is not None:
             msg = (
                 "`build_external_access_integration` is deprecated. "
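The root-logger check added above warns that service logs will be suppressed when the root logger sits at WARNING or ERROR. A short sketch of opting into verbose logs, following the warning text; the create_service call itself is elided because its arguments are not part of this hunk:

import logging

# Raise the root logger to INFO before deploying so build/deploy service logs are emitted.
logging.getLogger().setLevel(logging.INFO)

# mv is a ModelVersion obtained from the registry; arguments elided.
# mv.create_service(...)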