snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -0
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/telemetry.py +56 -7
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/_entities/run.py +15 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +123 -13
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/access_manager.py +1 -0
- snowflake/ml/feature_store/feature_store.py +1 -1
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/feature_flags.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +360 -357
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +2 -406
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +8 -9
- snowflake/ml/jobs/manager.py +64 -129
- snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
- snowflake/ml/model/_client/model/model_version_impl.py +109 -28
- snowflake/ml/model/_client/ops/model_ops.py +32 -6
- snowflake/ml/model/_client/ops/service_ops.py +9 -4
- snowflake/ml/model/_client/sql/service.py +69 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/core.py +305 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +25 -215
- snowflake/ml/model/type_hints.py +5 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,27 +1,25 @@
 import json
 import logging
+import os
 import pathlib
 import sys
-import textwrap
 from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
 from uuid import uuid4
 
 import pandas as pd
-import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
 from snowflake.ml.jobs._utils import (
+    constants,
     feature_flags,
     payload_utils,
     query_helper,
-    spec_utils,
     types,
 )
-from snowflake.snowpark._internal import utils as sp_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -259,7 +257,7 @@ def submit_directory(
     dir_path: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -274,7 +272,11 @@ def submit_directory(
     Args:
         dir_path: The path to the directory containing the job payload.
        compute_pool: The compute pool to use for the job.
-        entrypoint: The
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
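
A minimal usage sketch of the widened entrypoint parameter, assuming the public snowflake.ml.jobs API re-exports submit_directory as in prior releases; pool, stage, and package names are placeholders, and the pip_requirements pass-through is assumed from the _submit_job signature below:

    from snowflake.ml import jobs

    # String form: a script path resolved inside the uploaded directory.
    job1 = jobs.submit_directory(
        "./src",
        "MY_COMPUTE_POOL",           # placeholder compute pool
        entrypoint="main.py",
        stage_name="payload_stage",  # placeholder stage
    )

    # List form (new): a command passed through as-is, e.g. a console
    # script installed via pip_requirements rather than a file in ./src.
    job2 = jobs.submit_directory(
        "./src",
        "MY_COMPUTE_POOL",
        entrypoint=["arctic_training"],
        stage_name="payload_stage",
        args=["--config", "train.yaml"],       # placeholder args
        pip_requirements=["arctic_training"],  # assumed supported kwarg
    )
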
@@ -315,7 +317,7 @@ def submit_from_stage(
     source: str,
     compute_pool: str,
     *,
-    entrypoint: str,
+    entrypoint: Union[str, list[str]],
     stage_name: str,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
@@ -330,7 +332,11 @@ def submit_from_stage(
     Args:
         source: a stage path or a stage containing the job payload.
         compute_pool: The compute pool to use for the job.
-        entrypoint:
+        entrypoint: The entry point for job execution. Can be:
+            - A string path to the entry point script inside the source directory.
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         stage_name: The name of the stage where the job payload will be uploaded.
         args: A list of arguments to pass to the job.
         target_instances: The number of nodes in the job. If none specified, create a single node job.
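
The same two entrypoint forms apply when the payload already lives on a stage; a sketch with placeholder stage paths:

    job = jobs.submit_from_stage(
        "@payload_stage/src",            # placeholder stage path holding the payload
        "MY_COMPUTE_POOL",
        entrypoint=["arctic_training"],  # command list, passed through as-is
        stage_name="job_stage",          # placeholder stage for job artifacts
        args=["--config", "train.yaml"],
    )
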
@@ -375,7 +381,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -392,7 +398,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
@@ -424,7 +430,7 @@ def _submit_job(
     compute_pool: str,
     *,
     stage_name: str,
-    entrypoint: Optional[str] = None,
+    entrypoint: Optional[Union[str, list[str]]] = None,
     args: Optional[list[str]] = None,
     target_instances: int = 1,
     session: Optional[snowpark.Session] = None,
@@ -437,7 +443,11 @@ def _submit_job(
         source: The file/directory path containing payload source code or a serializable Python callable.
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
-        entrypoint: The entry point for the job execution.
+        entrypoint: The entry point for the job execution. Can be:
+            - A string path to a Python script (required if source is a directory).
+            - A list of strings representing a custom command (e.g., ["arctic_training"])
+              which is passed through as-is without local resolution or validation.
+              This is useful for entrypoints that are installed via pip_requirements.
         args: A list of arguments to pass to the job.
         target_instances: The number of instances to use for the job. If none specified, single node job is created.
         session: The Snowpark session to use. If none specified, uses active session.
@@ -449,7 +459,6 @@ def _submit_job(
     Raises:
         ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
-        SnowparkSQLException: if failed to upload payload
     """
     session = _ensure_session(session)
 
@@ -481,7 +490,8 @@ def _submit_job(
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
     imports = kwargs.pop("imports", None) or imports
-
+    # if the mljob is submitted from a notebook, we use the same image tag as the notebook
+    runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
 
     # Warn if there are unknown kwargs
     if kwargs:
@@ -513,128 +523,51 @@ def _submit_job(
     try:
         # Upload payload
         uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements,
+            source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
         ).upload(session, stage_path)
     except SnowparkSQLException as e:
         if e.sql_error_code == 90106:
             raise RuntimeError(
                 "Please specify a schema, either in the session context or as a parameter in the job submission"
             )
+        elif e.sql_error_code == 3001 and "schema" in str(e).lower():
+            raise RuntimeError(
+                "please grant privileges on schema before submitting a job, see",
+                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements",
+                " for more details",
+            ) from e
         raise
 
-
-    # Add default env vars (extracted from spec_utils.generate_service_spec)
-    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
-
-    try:
-        return _do_submit_job_v2(
-            session=session,
-            payload=uploaded_payload,
-            args=args,
-            env_vars=combined_env_vars,
-            spec_overrides=spec_overrides,
-            compute_pool=compute_pool,
-            job_id=job_id,
-            external_access_integrations=external_access_integrations,
-            query_warehouse=query_warehouse,
-            target_instances=target_instances,
-            min_instances=min_instances,
-            enable_metrics=enable_metrics,
-            use_async=True,
-            runtime_environment=runtime_environment,
-        )
-    except SnowparkSQLException as e:
-        if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()):  # type: ignore[no-untyped-call]
-            raise
-        # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights
-        # stored procedures. This will be fixed in an upcoming release.
-        logger.warning(
-            "Job submission using V2 failed with error {}. Falling back to V1.".format(
-                str(e).split("\n", 1)[0],
-            )
-        )
-
-    # Fall back to v1
-    # Generate service spec
-    spec = spec_utils.generate_service_spec(
-        session,
-        compute_pool=compute_pool,
-        payload=uploaded_payload,
-        args=args,
-        target_instances=target_instances,
-        min_instances=min_instances,
-        enable_metrics=enable_metrics,
-        runtime_environment=runtime_environment,
-    )
-
-    # Generate spec overrides
-    spec_overrides = spec_utils.generate_spec_overrides(
-        environment_vars=env_vars,
-        custom_overrides=spec_overrides,
-    )
-    if spec_overrides:
-        spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
-
-    return _do_submit_job_v1(
-        session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
-    )
-
-
-def _do_submit_job_v1(
-    session: snowpark.Session,
-    spec: dict[str, Any],
-    external_access_integrations: list[str],
-    query_warehouse: Optional[str],
-    target_instances: int,
-    compute_pool: str,
-    job_id: str,
-) -> jb.MLJob[Any]:
-    """
-    Generate the SQL query for job submission.
-
-    Args:
-        session: The Snowpark session to use.
-        spec: The service spec for the job.
-        external_access_integrations: The external access integrations for the job.
-        query_warehouse: The query warehouse for the job.
-        target_instances: The number of instances for the job.
-        session: The Snowpark session to use.
-        compute_pool: The compute pool to use for the job.
-        job_id: The ID of the job.
-
-    Returns:
-        The job object.
-    """
-    query_template = textwrap.dedent(
-        """\
-        EXECUTE JOB SERVICE
-        IN COMPUTE POOL IDENTIFIER(?)
-        FROM SPECIFICATION $$
-        {}
-        $$
-        NAME = IDENTIFIER(?)
-        ASYNC = TRUE
-        """
-    )
-    params: list[Any] = [compute_pool, job_id]
-    query = query_template.format(yaml.dump(spec)).splitlines()
-    if external_access_integrations:
-        external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
-        query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
-    if query_warehouse:
-        query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
-        params.append(query_warehouse)
-    if target_instances > 1:
-        query.append("REPLICAS = ?")
-        params.append(target_instances)
-
-    query_text = "\n".join(line for line in query if line)
-    _ = query_helper.run_query(session, query_text, params=params)
+    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
 
-
+    try:
+        return _do_submit_job(
+            session=session,
+            payload=uploaded_payload,
+            args=args,
+            env_vars=combined_env_vars,
+            spec_overrides=spec_overrides,
+            compute_pool=compute_pool,
+            job_id=job_id,
+            external_access_integrations=external_access_integrations,
+            query_warehouse=query_warehouse,
+            target_instances=target_instances,
+            min_instances=min_instances,
+            enable_metrics=enable_metrics,
+            use_async=True,
+            runtime_environment=runtime_environment,
+        )
+    except SnowparkSQLException as e:
+        if e.sql_error_code == 3001 and "schema" in str(e).lower():
+            raise RuntimeError(
+                "please grant privileges on schema before submitting a job, see",
+                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements"
+                " for more details",
+            ) from e
+        raise
 
 
-def _do_submit_job_v2(
+def _do_submit_job(
     session: snowpark.Session,
     payload: types.UploadedPayload,
     args: Optional[list[str]],
@@ -672,9 +605,7 @@ def _do_submit_job_v2(
     Returns:
         The job object.
     """
-    args = [
-        (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
-    ] + (args or [])
+    args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
     spec_options = {
         "STAGE_PATH": payload.stage_path.as_posix(),
         "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
@@ -683,13 +614,14 @@ def _do_submit_job_v2(
         "ENABLE_METRICS": enable_metrics,
         "SPEC_OVERRIDES": spec_overrides,
     }
-    # for the image tag or full image URL, we use that directly
     if runtime_environment:
+        # for the image tag or full image URL, we use that directly
         spec_options["RUNTIME"] = runtime_environment
     elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         # when feature flag is enabled, we get the local python version and wrap it in a dict
         # in system function, we can know whether it is python version or image tag or full image URL through the format
         spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
+
     job_options = {
         "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
         "QUERY_WAREHOUSE": query_warehouse,
@@ -697,6 +629,9 @@ def _do_submit_job_v2(
         "MIN_INSTANCES": min_instances,
         "ASYNC": use_async,
     }
+
+    if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
+        spec_options["ENABLE_STAGE_MOUNT_V2"] = True
     if payload.payload_name:
         job_options["GENERATE_SUFFIX"] = True
     job_options = {k: v for k, v in job_options.items() if v is not None}
snowflake/ml/model/_client/model/inference_engine_utils.py
CHANGED
@@ -4,14 +4,18 @@ from snowflake.ml.model._client.ops import service_ops
 
 
 def _get_inference_engine_args(
-
+    inference_engine_options: Optional[dict[str, Any]],
 ) -> Optional[service_ops.InferenceEngineArgs]:
-
+
+    if not inference_engine_options:
         return None
 
+    if "engine" not in inference_engine_options:
+        raise ValueError("'engine' field is required in inference_engine_options")
+
     return service_ops.InferenceEngineArgs(
-        inference_engine=
-        inference_engine_args_override=
+        inference_engine=inference_engine_options["engine"],
+        inference_engine_args_override=inference_engine_options.get("engine_args_override"),
    )
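
A sketch of the dict shape this helper now validates; the engine name is a placeholder, since the set of accepted engine identifiers is not shown in this diff:

    inference_engine_options = {
        "engine": "vllm",  # required key; placeholder engine identifier
        "engine_args_override": ["--max-model-len", "4096"],  # optional
    }

    # Returns service_ops.InferenceEngineArgs(inference_engine="vllm", ...).
    # None or {} returns None; a dict without "engine" raises ValueError.
    args = _get_inference_engine_args(inference_engine_options)
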
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED
@@ -8,11 +8,11 @@ from typing import Any, Callable, Optional, Union, overload
 import pandas as pd
 
 from snowflake import snowpark
-from snowflake.ml import jobs
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.jobs import job
 from snowflake.ml.lineage import lineage_node
-from snowflake.ml.model import task, type_hints
+from snowflake.ml.model import openai_signatures, task, type_hints
 from snowflake.ml.model._client.model import (
     batch_inference_specs,
     inference_engine_utils,
@@ -23,6 +23,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
 from snowflake.ml.model._packager.model_handlers import snowmlmodel
 from snowflake.ml.model._packager.model_meta import model_meta_schema
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import Session, async_job, dataframe
 
 _TELEMETRY_PROJECT = "MLOps"
@@ -45,6 +46,7 @@ class ModelVersion(lineage_node.LineageNode):
     _version_name: sql_identifier.SqlIdentifier
     _functions: list[model_manifest_schema.ModelFunctionInfo]
     _model_spec: Optional[model_meta_schema.ModelMetadataDict]
+    _model_manifest: Optional[model_manifest_schema.ModelManifestDict]
 
     def __init__(self) -> None:
         raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -155,6 +157,7 @@ class ModelVersion(lineage_node.LineageNode):
         self._version_name = version_name
         self._functions = self._get_functions()
         self._model_spec = None
+        self._model_manifest = None
         super(cls, cls).__init__(
             self,
             session=model_ops._session,
@@ -462,6 +465,28 @@ class ModelVersion(lineage_node.LineageNode):
             )
         return self._model_spec
 
+    def _get_model_manifest(
+        self, statement_params: Optional[dict[str, Any]] = None
+    ) -> model_manifest_schema.ModelManifestDict:
+        """Fetch and cache the model manifest for this model version.
+
+        Args:
+            statement_params: Optional dictionary of statement parameters to include
+                in the SQL command to fetch the model manifest.
+
+        Returns:
+            The model manifest as a dictionary for this model version.
+        """
+        if self._model_manifest is None:
+            self._model_manifest = self._model_ops.get_model_version_manifest(
+                database_name=None,
+                schema_name=None,
+                model_name=self._model_name,
+                version_name=self._version_name,
+                statement_params=statement_params,
+            )
+        return self._model_manifest
+
     @overload
     def run(
         self,
@@ -530,6 +555,9 @@ class ModelVersion(lineage_node.LineageNode):
 
         Returns:
             The prediction data. It would be the same type dataframe as your input.
+
+        Raises:
+            ValueError: When the model does not support running on warehouse and no service name is provided.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
@@ -556,6 +584,27 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
             )
         else:
+            manifest = self._get_model_manifest(statement_params=statement_params)
+            target_platforms = manifest.get("target_platforms", None)
+            if (
+                target_platforms is not None
+                and len(target_platforms) > 0
+                and type_hints.TargetPlatform.WAREHOUSE.value not in target_platforms
+            ):
+                raise ValueError(
+                    f"The model {self.fully_qualified_model_name} version {self.version_name} "
+                    "is not logged for inference in Warehouse. "
+                    "To run the model in Warehouse, please log the model again using `log_model` API with "
+                    '`target_platforms=["WAREHOUSE"]` or '
+                    '`target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]` and rerun the command. '
+                    "To run the model in Snowpark Container Services, the `service_name` argument must be provided. "
+                    "You can create a service using the `create_service` API. "
+                    "For inference in Warehouse, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/warehouse#inference-from-python. "
+                    "For inference in Snowpark Container Services, see https://docs.snowflake.com/en/developer-guide/"
+                    "snowflake-ml/model-registry/container#python."
+                )
+
             explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
 
             return self._model_ops.invoke_method(
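
The new check means run() without service_name now fails fast for models not logged for Warehouse; a re-logging sketch following the error message above, assuming an existing snowflake.ml.registry.Registry handle reg and placeholder names:

    # reg = Registry(session=session)  # assumed pre-existing registry handle
    mv = reg.log_model(
        my_model,                      # placeholder model object
        model_name="MY_MODEL",
        version_name="V2",
        target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"],
    )
    mv.run(input_df, function_name="PREDICT")  # Warehouse inference
    # For SPCS inference instead, pass service_name to run().
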
@@ -602,7 +651,7 @@ class ModelVersion(lineage_node.LineageNode):
         input_spec: dataframe.DataFrame,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
-    ) -> jobs.MLJob[Any]:
+    ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.
 
         Args:
@@ -617,7 +666,7 @@ class ModelVersion(lineage_node.LineageNode):
                 If None, default values will be used.
 
         Returns:
-
+            job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
                 lifecycle.
 
         Raises:
@@ -940,14 +989,16 @@ class ModelVersion(lineage_node.LineageNode):
         self,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
-        """Check if the model is a HuggingFace pipeline with text-generation task
+        """Check if the model is a HuggingFace pipeline with text-generation task
+        and is logged with OPENAI_CHAT_SIGNATURE.
 
         Args:
             statement_params: Optional dictionary of statement parameters to include
                 in the SQL command to fetch model spec.
 
         Raises:
-            ValueError: If the model is not a HuggingFace text-generation model
+            ValueError: If the model is not a HuggingFace text-generation model or
+                if the model is not logged with OPENAI_CHAT_SIGNATURE.
         """
         # Fetch model spec
         model_spec = self._get_model_spec(statement_params)
@@ -983,6 +1034,21 @@ class ModelVersion(lineage_node.LineageNode):
             )
             raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
 
+        # Check if the model is logged with OPENAI_CHAT_SIGNATURE
+        signatures_dict = model_spec.get("signatures", {})
+
+        # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
+        deserialized_signatures = {
+            func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
+        }
+
+        if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+            raise ValueError(
+                "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+                f"Found signatures: {signatures_dict}. "
+                "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+            )
+
     @overload
     def create_service(
         self,
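
Per this check, an inference-engine deployment only works for a HuggingFace text-generation model logged with exactly OPENAI_CHAT_SIGNATURE; a sketch, again assuming a Registry handle reg and a placeholder pipeline object:

    from snowflake.ml.model import openai_signatures

    mv = reg.log_model(
        text_generation_pipeline,  # placeholder HF text-generation pipeline
        model_name="MY_LLM",
        version_name="V1",
        signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
    )
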
@@ -1001,6 +1067,8 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integration: Optional[str] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1034,11 +1102,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
-
-
-
-
-
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
         """
         ...
 
@@ -1060,6 +1130,8 @@ class ModelVersion(lineage_node.LineageNode):
         force_rebuild: bool = False,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1093,11 +1165,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is ``False``, this function executes the underlying service creation asynchronously
                 and returns an :class:`AsyncJob`.
-
-
-
-
-
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
         """
         ...
 
@@ -1134,6 +1208,8 @@ class ModelVersion(lineage_node.LineageNode):
         build_external_access_integration: Optional[str] = None,
         build_external_access_integrations: Optional[list[str]] = None,
         block: bool = True,
+        autocapture: bool = False,
+        inference_engine_options: Optional[dict[str, Any]] = None,
         experimental_options: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
         """Create an inference service with the given spec.
@@ -1169,11 +1245,13 @@ class ModelVersion(lineage_node.LineageNode):
             block: A bool value indicating whether this function will wait until the service is available.
                 When it is False, this function executes the underlying service creation asynchronously
                 and returns an AsyncJob.
-
-
-
-
-
+            autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
+                captured in the model inference table.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.
+            experimental_options: Experimental options for the service creation.
 
 
         Raises:
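
A sketch combining the two new create_service parameters; every identifier below is a placeholder, and the remaining keyword names are assumed from the existing create_service overloads rather than shown in this diff:

    mv.create_service(
        service_name="MY_LLM_SERVICE",
        service_compute_pool="MY_GPU_POOL",   # assumed existing kwarg
        gpu_requests="1",                     # assumed existing kwarg
        autocapture=True,  # capture inference data in the model inference table
        inference_engine_options={
            "engine": "vllm",                 # placeholder engine identifier
            "engine_args_override": ["--max-model-len", "4096"],
        },
    )
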
@@ -1209,9 +1287,10 @@ class ModelVersion(lineage_node.LineageNode):
         # Validate GPU support if GPU resources are requested
         self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
 
-        inference_engine_args = inference_engine_utils._get_inference_engine_args(
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
 
-        # Check if model is HuggingFace text-generation
+        # Check if model is HuggingFace text-generation and is logged with
+        # OPENAI_CHAT_SIGNATURE before doing inference engine checks
         # Only validate if inference engine is actually specified
         if inference_engine_args is not None:
             self._check_huggingface_text_generation_model(statement_params)
@@ -1223,9 +1302,6 @@ class ModelVersion(lineage_node.LineageNode):
             gpu_requests,
         )
 
-        # Extract autocapture from experimental_options
-        autocapture = experimental_options.get("autocapture") if experimental_options else None
-
         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
@@ -1292,8 +1368,13 @@ class ModelVersion(lineage_node.LineageNode):
         """List all the service names using this model version.
 
         Returns:
-            List of
-
+            List of details about all the services associated with this model version. The details include:
+                name: The name of the service.
+                status: The status of the service.
+                inference_endpoint: The public endpoint of the service, if enabled and services is not in PENDING state.
+                    This will give privatelink endpoint if the session is created with privatelink connection
+                internal_endpoint: The internal endpoint of the service, if services is not in PENDING state.
+                autocapture_enabled: Whether service has autocapture enabled, if it is set in service proxy spec.
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,