snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,22 +59,17 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = payload_utils.create_function_payload(func, *args, **kwargs)
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                env_vars=env_vars,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
                 session=payload.session or session,
+                **kwargs,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
snowflake/ml/jobs/job.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -9,7 +10,8 @@ import yaml
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
-from snowflake.ml.jobs._utils import constants, interop_utils, types
+from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
+from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
@@ -21,7 +23,7 @@ T = TypeVar("T")
 logger = logging.getLogger(__name__)
 
 
-class MLJob(Generic[T]):
+class MLJob(Generic[T], SerializableSessionMixin):
     def __init__(
         self,
         id: str,
@@ -220,6 +222,19 @@ class MLJob(Generic[T]):
             return cast(T, self._result.result)
         raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
+    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    def cancel(self) -> None:
+        """
+        Cancel the running job.
+        Raises:
+            RuntimeError: If cancellation fails. # noqa: DAR401
+        """
+        try:
+            self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect()
+            logger.debug(f"Cancellation requested for job {self.id}")
+        except SnowparkSQLException as e:
+            raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
 def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
@@ -262,6 +277,7 @@ def _get_logs(
 
     Raises:
         RuntimeError: if failed to get head instance_id
+        SnowparkSQLException: if there is an error retrieving logs from SPCS interface.
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
@@ -279,30 +295,55 @@ def _get_logs(
     if limit > 0:
         params.append(limit)
     try:
-        (row,) = session.sql(
+        (row,) = query_helper.run_query(
+            session,
             f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
             params=params,
-        ).collect()
+        )
+        full_log = str(row[0])
     except SnowparkSQLException as e:
         if "Container Status: PENDING" in e.message:
             logger.warning("Waiting for container to start. Logs will be shown when available.")
             return ""
         else:
-            #
-            #
-
-
-
-
-
-
-
-
-
+            # Fallback plan:
+            # 1. Try SPCS Interface (doesn't require event table permission)
+            # 2. If the interface call fails, query Event Table (requires permission)
+            logger.debug("falling back to SPCS Interface for logs")
+            try:
+                logs = _get_logs_spcs(
+                    session,
+                    job_id,
+                    limit=limit,
+                    instance_id=instance_id if instance_id else 0,
+                    container_name=constants.DEFAULT_CONTAINER_NAME,
                 )
-
-
-
+                full_log = os.linesep.join(row[0] for row in logs)
+
+            except SnowparkSQLException as spcs_error:
+                if spcs_error.sql_error_code == 2143:
+                    logger.debug("persistent logs may not be enabled, falling back to event table")
+                else:
+                    # If SPCS Interface fails for any other reason,
+                    # for example, incorrect argument format, raise the error directly
+                    raise
+            # event table accepts job name, not fully qualified name
+            db, schema, name = identifier.parse_schema_level_object_identifier(job_id)
+            db = db or session.get_current_database()
+            schema = schema or session.get_current_schema()
+            event_table_logs = _get_service_log_from_event_table(
+                session,
+                name,
+                database=db,
+                schema=schema,
+                instance_id=instance_id if instance_id else 0,
+                limit=limit,
+            )
+            if len(event_table_logs) == 0:
+                raise RuntimeError(
+                    "No logs were found. Please verify that the database, schema, and job ID are correct."
+                )
+            full_log = os.linesep.join(json.loads(row[0]) for row in event_table_logs)
 
     # If verbose is True, return the complete log
     if verbose:
@@ -340,13 +381,21 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         RuntimeError: If the instances died or if some instances disappeared.
     """
 
-    target_instances = _get_target_instances(session, job_id)
+    try:
+        target_instances = _get_target_instances(session, job_id)
+    except SnowparkSQLException:
+        # service may be deleted
+        raise RuntimeError("Couldn’t retrieve service information")
 
     if target_instances == 1:
         return 0
 
     try:
-        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        rows = query_helper.run_query(
+            session,
+            "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)",
+            params=(job_id,),
+        )
     except SnowparkSQLException:
         # service may be deleted
         raise RuntimeError("Couldn’t retrieve instances")
@@ -373,19 +422,29 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
 
 
 def _get_service_log_from_event_table(
-    session: snowpark.Session,
+    session: snowpark.Session,
+    name: str,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
+    instance_id: Optional[int] = None,
+    limit: int = -1,
 ) -> list[Row]:
+    event_table_name = session.sql("SHOW PARAMETERS LIKE 'event_table' IN ACCOUNT").collect()[0]["value"]
+    query = [
+        "SELECT VALUE FROM IDENTIFIER(?)",
+        'WHERE RESOURCE_ATTRIBUTES:"snow.service.name" = ?',
+    ]
     params: list[Any] = [
-
-        schema,
+        event_table_name,
         name,
     ]
-
-    "
-
-
-
-
+    if database:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.database.name" = ?')
+        params.append(database)
+
+    if schema:
+        query.append('AND RESOURCE_ATTRIBUTES:"snow.schema.name" = ?')
+        params.append(schema)
 
     if instance_id:
         query.append('AND RESOURCE_ATTRIBUTES:"snow.service.container.instance" = ?')
@@ -398,16 +457,18 @@ def _get_service_log_from_event_table(
     if limit > 0:
         query.append("LIMIT ?")
         params.append(limit)
-
-    rows = session.sql(
+    # the wrap used in query_helper does not have return type.
+    # sticking a # type: ignore[no-any-return] is to pass type check
+    rows = query_helper.run_query(
+        session,
         "\n".join(line for line in query if line),
         params=params,
-    ).collect()
-    return rows
+    )
+    return rows  # type: ignore[no-any-return]
 
 
-def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
-    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+def _get_service_info(session: snowpark.Session, job_id: str) -> Any:
+    (row,) = query_helper.run_query(session, "DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,))
     return row
 
 
@@ -426,8 +487,10 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
         ValueError: If the compute pool is not found.
     """
     try:
-        (pool_info,) = session.sql("SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,)).collect()
-        return pool_info
+        # the wrap used in query_helper does not have return type.
+        # sticking a # type: ignore[no-any-return] is to pass type check
+        (pool_info,) = query_helper.run_query(session, "SHOW COMPUTE POOLS LIKE ?", params=(compute_pool,))
+        return pool_info  # type: ignore[no-any-return]
     except ValueError as e:
         if "not enough values to unpack" in str(e):
             raise ValueError(f"Compute pool '{compute_pool}' not found")
@@ -438,3 +501,40 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row:
 def _get_target_instances(session: snowpark.Session, job_id: str) -> int:
     row = _get_service_info(session, job_id)
     return int(row["target_instances"])
+
+
+def _get_logs_spcs(
+    session: snowpark.Session,
+    fully_qualified_name: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: Optional[str] = None,
+    start_time: Optional[str] = None,
+    end_time: Optional[str] = None,
+) -> list[Row]:
+    query = [
+        f"SELECT LOG FROM table({fully_qualified_name}!spcs_get_logs(",
+    ]
+    conditions_params = []
+    if start_time:
+        conditions_params.append(f"start_time => TO_TIMESTAMP_LTZ('{start_time}')")
+    if end_time:
+        conditions_params.append(f"end_time => TO_TIMESTAMP_LTZ('{end_time}')")
+    if len(conditions_params) > 0:
+        query.append(", ".join(conditions_params))
+
+    query.append("))")
+
+    query_params = []
+    if instance_id is not None:
+        query_params.append(f"INSTANCE_ID = {instance_id}")
+    if container_name:
+        query_params.append(f"CONTAINER_NAME = '{container_name}'")
+    if len(query_params) > 0:
+        query.append("WHERE " + " AND ".join(query_params))
+
+    query.append("ORDER BY TIMESTAMP ASC")
+    if limit > 0:
+        query.append(f" LIMIT {limit};")
+    rows = session.sql("\n".join(query)).collect()
+    return rows