snowflake-ml-python 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +2 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +58 -4
- snowflake/ml/jobs/_utils/spec_utils.py +0 -31
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +109 -32
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +45 -2
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +81 -61
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +24 -9
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +4 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +30 -29
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +62 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +85 -0
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +95 -1
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +52 -46
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.21.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,11 +1,5 @@
-import json
 import logging
-import os
-import pathlib
-import sys
-from pathlib import PurePath
 from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
-from uuid import uuid4

 import pandas as pd

@@ -13,13 +7,8 @@ from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs import job as jb
-from snowflake.ml.jobs._utils import (
-    constants,
-    feature_flags,
-    payload_utils,
-    query_helper,
-    types,
-)
+from snowflake.ml.jobs._utils import query_helper
+from snowflake.ml.jobs.job_definition import MLJobDefinition
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -457,7 +446,6 @@ def _submit_job(
         An object representing the submitted job.

     Raises:
-        ValueError: If database or schema value(s) are invalid
         RuntimeError: If schema is not specified in session context or job submission
     """
     session = _ensure_session(session)
@@ -469,94 +457,30 @@ def _submit_job(
         )
         target_instances = max(target_instances, kwargs.pop("num_instances"))

-    imports = None
     if "additional_payloads" in kwargs:
         logger.warning(
             "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
         )
-        imports = kwargs.pop("additional_payloads", None)
+        if "imports" not in kwargs:
+            imports = kwargs.pop("additional_payloads", None)
+            kwargs.update({"imports": imports})

     if "runtime_environment" in kwargs:
         logger.warning("'runtime_environment' is in private preview since 1.15.0, do not use it in production.")

-
-
-
-
-
-
-
-
-
-
-    imports = kwargs.pop("imports", None) or imports
-    # if the mljob is submitted from a notebook, we use the same image tag as the notebook
-    runtime_environment = kwargs.pop("runtime_environment", os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR, None))
-
-    # Warn if there are unknown kwargs
-    if kwargs:
-        logger.warning(f"Ignoring unknown kwargs: {kwargs.keys()}")
-
-    # Validate parameters
-    if database and not schema:
-        raise ValueError("Schema must be specified if database is specified.")
-    if target_instances < 1:
-        raise ValueError("target_instances must be greater than 0.")
-    if not (0 < min_instances <= target_instances):
-        raise ValueError("min_instances must be greater than 0 and less than or equal to target_instances.")
-    if min_instances > 1:
-        # Validate min_instances against compute pool max_nodes
-        pool_info = jb._get_compute_pool_info(session, compute_pool)
-        max_nodes = int(pool_info["max_nodes"])
-        if min_instances > max_nodes:
-            raise ValueError(
-                f"The requested min_instances ({min_instances}) exceeds the max_nodes ({max_nodes}) "
-                f"of compute pool '{compute_pool}'. Reduce min_instances or increase max_nodes."
-            )
-
-    job_name = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
-    job_id = identifier.get_schema_level_object_identifier(database, schema, job_name)
-    stage_path_parts = identifier.parse_snowflake_stage_path(stage_name.lstrip("@"))
-    stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
-    stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
-
-    try:
-        # Upload payload
-        uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements, imports=imports
-        ).upload(session, stage_path)
-    except SnowparkSQLException as e:
-        if e.sql_error_code == 90106:
-            raise RuntimeError(
-                "Please specify a schema, either in the session context or as a parameter in the job submission"
-            )
-        elif e.sql_error_code == 3001 and "schema" in str(e).lower():
-            raise RuntimeError(
-                "please grant privileges on schema before submitting a job, see",
-                "https://docs.snowflake.com/en/developer-guide/snowflake-ml/ml-jobs/access-control-requirements",
-                " for more details",
-            ) from e
-        raise
-
-    combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+    job_definition = MLJobDefinition.register(
+        source,
+        compute_pool,
+        stage_name,
+        session or get_active_session(),
+        entrypoint,
+        target_instances,
+        generate_suffix=True,
+        **kwargs,
+    )

     try:
-        return _do_submit_job(
-            session=session,
-            payload=uploaded_payload,
-            args=args,
-            env_vars=combined_env_vars,
-            spec_overrides=spec_overrides,
-            compute_pool=compute_pool,
-            job_id=job_id,
-            external_access_integrations=external_access_integrations,
-            query_warehouse=query_warehouse,
-            target_instances=target_instances,
-            min_instances=min_instances,
-            enable_metrics=enable_metrics,
-            use_async=True,
-            runtime_environment=runtime_environment,
-        )
+        return job_definition(*(args or []))
     except SnowparkSQLException as e:
         if e.sql_error_code == 3001 and "schema" in str(e).lower():
             raise RuntimeError(
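Note the shim above copies `additional_payloads` into `imports` only when the caller did not already pass `imports`, so an explicit `imports` always wins. New code should pass `imports` directly. A minimal sketch, assuming the public `submit_file` entry point forwards these keyword arguments to `_submit_job` (all values illustrative):

```python
from snowflake.ml import jobs

# `imports` supersedes the deprecated `additional_payloads` keyword.
# Assumes submit_file forwards extra kwargs down to _submit_job;
# the payload paths here are illustrative.
job = jobs.submit_file(
    "train.py",
    "MY_COMPUTE_POOL",
    stage_name="@MY_STAGE/payloads",
    imports=["./utils"],
)
```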
@@ -567,91 +491,6 @@ def _submit_job(
         raise


-def _do_submit_job(
-    session: snowpark.Session,
-    payload: types.UploadedPayload,
-    args: Optional[list[str]],
-    env_vars: dict[str, str],
-    spec_overrides: dict[str, Any],
-    compute_pool: str,
-    job_id: Optional[str] = None,
-    external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = True,
-    use_async: bool = True,
-    runtime_environment: Optional[str] = None,
-) -> jb.MLJob[Any]:
-    """
-    Generate the SQL query for job submission.
-
-    Args:
-        session: The Snowpark session to use.
-        payload: The uploaded job payload.
-        args: Arguments to pass to the entrypoint script.
-        env_vars: Environment variables to set in the job container.
-        spec_overrides: Custom service specification overrides.
-        compute_pool: The compute pool to use for job execution.
-        job_id: The ID of the job.
-        external_access_integrations: Optional list of external access integrations.
-        query_warehouse: Optional query warehouse to use.
-        target_instances: Number of instances for multi-node job.
-        min_instances: Minimum number of instances required to start the job.
-        enable_metrics: Whether to enable platform metrics for the job.
-        use_async: Whether to run the job asynchronously.
-        runtime_environment: image tag or full image URL to use for the job.
-
-    Returns:
-        The job object.
-    """
-    args = [(v.as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint] + (args or [])
-    spec_options = {
-        "STAGE_PATH": payload.stage_path.as_posix(),
-        "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
-        "ARGS": args,
-        "ENV_VARS": env_vars,
-        "ENABLE_METRICS": enable_metrics,
-        "SPEC_OVERRIDES": spec_overrides,
-    }
-    if runtime_environment:
-        # for the image tag or full image URL, we use that directly
-        spec_options["RUNTIME"] = runtime_environment
-    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
-        # when feature flag is enabled, we get the local python version and wrap it in a dict
-        # in system function, we can know whether it is python version or image tag or full image URL through the format
-        spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"})
-
-    job_options = {
-        "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
-        "QUERY_WAREHOUSE": query_warehouse,
-        "TARGET_INSTANCES": target_instances,
-        "MIN_INSTANCES": min_instances,
-        "ASYNC": use_async,
-    }
-
-    if feature_flags.FeatureFlags.ENABLE_STAGE_MOUNT_V2.is_enabled(default=True):
-        spec_options["ENABLE_STAGE_MOUNT_V2"] = True
-    if payload.payload_name:
-        job_options["GENERATE_SUFFIX"] = True
-    job_options = {k: v for k, v in job_options.items() if v is not None}
-
-    query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
-    if job_id:
-        database, schema, _ = identifier.parse_schema_level_object_identifier(job_id)
-    params = [
-        job_id
-        if payload.payload_name is None
-        else identifier.get_schema_level_object_identifier(database, schema, payload.payload_name) + "_",
-        compute_pool,
-        json.dumps(spec_options),
-        json.dumps(job_options),
-    ]
-    actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
-
-    return get_job(actual_job_id, session=session)
-
-
 def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
     try:
         session = session or get_active_session()
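The net effect of these hunks: `_submit_job` no longer builds stage paths, uploads payloads, and calls `SYSTEM$EXECUTE_ML_JOB` itself; it registers an `MLJobDefinition` and invokes it. A minimal sketch of the new flow from the caller's side, based only on the call shape visible in the hunk above (all argument values are illustrative):

```python
from snowflake.ml.jobs.job_definition import MLJobDefinition
from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Register a reusable definition (this replaces the old inline payload
# upload). Positional order follows the call shown in _submit_job above;
# all values here are illustrative.
job_definition = MLJobDefinition.register(
    "src/train.py",        # source
    "MY_COMPUTE_POOL",     # compute_pool
    "@MY_STAGE/jobs",      # stage_name
    session,
    "train.py",            # entrypoint
    1,                     # target_instances
    generate_suffix=True,
)

# Calling the definition submits a run; positional args reach the entrypoint.
job = job_definition("--epochs", "10")
```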
snowflake/ml/lineage/lineage_node.py
CHANGED

@@ -7,7 +7,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier, mixins

 if TYPE_CHECKING:
-    from snowflake.ml import dataset
+    from snowflake.ml.dataset import dataset
     from snowflake.ml.feature_store import feature_view
     from snowflake.ml.model._client.model import model_version_impl

snowflake/ml/model/__init__.py
CHANGED
@@ -2,16 +2,20 @@ import sys
 import warnings

 from snowflake.ml.model._client.model.batch_inference_specs import (
+    ColumnHandlingOptions,
+    FileEncoding,
     JobSpec,
     OutputSpec,
     SaveMode,
 )
 from snowflake.ml.model._client.model.model_impl import Model
 from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
+from snowflake.ml.model.code_path import CodePath
 from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
 from snowflake.ml.model.volatility import Volatility

 __all__ = [
+    "CodePath",
     "Model",
     "ModelVersion",
     "ExportMode",
@@ -20,6 +24,8 @@ __all__ = [
     "OutputSpec",
     "SaveMode",
     "Volatility",
+    "FileEncoding",
+    "ColumnHandlingOptions",
 ]

 _deprecation_warning_msg_for_3_9 = (
snowflake/ml/model/_client/model/batch_inference_specs.py
CHANGED

@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Optional

 from pydantic import BaseModel
+from typing_extensions import TypedDict


 class SaveMode(str, Enum):
@@ -18,6 +19,20 @@ class SaveMode(str, Enum):
     ERROR = "error"


+class FileEncoding(str, Enum):
+    """The encoding of the file content that will be passed to the custom model."""
+
+    RAW_BYTES = "raw_bytes"
+    BASE64 = "base64"
+    BASE64_DATA_URL = "base64_data_url"
+
+
+class ColumnHandlingOptions(TypedDict):
+    """Options for handling specific columns during run_batch for file I/O."""
+
+    encoding: FileEncoding
+
+
 class OutputSpec(BaseModel):
     """Specification for batch inference output.

@@ -74,7 +89,7 @@ class JobSpec(BaseModel):
             the memory of the node.
         gpu_requests (Optional[str]): The gpu limit for GPU based inference. Can be integer or
             string values. Use CPU if None.
-        replicas (Optional[int]): Number of job
+        replicas (Optional[int]): Number of SPCS job nodes used for distributed inference.
             If not specified, defaults to 1 replica.

    Example:
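Since `ColumnHandlingOptions` is a plain `TypedDict`, a column-handling map is just a dict literal keyed by column name. A minimal sketch (the column name is illustrative):

```python
from snowflake.ml.model import ColumnHandlingOptions, FileEncoding

# Map a file-reference column to the encoding the model expects;
# "IMAGE_FILE" is an illustrative column name.
column_handling: dict[str, ColumnHandlingOptions] = {
    "IMAGE_FILE": {"encoding": FileEncoding.BASE64_DATA_URL},
}
```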
snowflake/ml/model/_client/model/model_version_impl.py
CHANGED

@@ -30,6 +30,10 @@ _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
 _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
 _BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
+VLLM_SUPPORTED_TASKS = [
+    "text-generation",
+    "image-text-to-text",
+]


 class ExportMode(enum.Enum):
@@ -495,6 +499,7 @@ class ModelVersion(lineage_node.LineageNode):
         function_name: Optional[str] = None,
         partition_column: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
         """Invoke a method in a model version object.

@@ -505,6 +510,8 @@ class ModelVersion(lineage_node.LineageNode):
             partition_column: The partition column name to partition by.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.
         """
         ...

@@ -516,6 +523,7 @@ class ModelVersion(lineage_node.LineageNode):
         service_name: str,
         function_name: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
         """Invoke a method in a model version object via a service.

@@ -525,6 +533,8 @@ class ModelVersion(lineage_node.LineageNode):
             function_name: The function name to run. It is the name used to call a function in SQL.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.
         """
         ...

@@ -541,6 +551,7 @@ class ModelVersion(lineage_node.LineageNode):
         function_name: Optional[str] = None,
         partition_column: Optional[str] = None,
         strict_input_validation: bool = False,
+        params: Optional[dict[str, Any]] = None,
     ) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
         """Invoke a method in a model version object via the warehouse or a service.

@@ -552,6 +563,8 @@ class ModelVersion(lineage_node.LineageNode):
             partition_column: The partition column name to partition by.
             strict_input_validation: Enable stricter validation for the input data. This will result value range based
                 type validation to make sure your input data won't overflow when providing to the model.
+            params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
+                These are passed as keyword arguments to the model's inference method. Defaults to None.

         Returns:
             The prediction data. It would be the same type dataframe as your input.
@@ -582,6 +595,7 @@ class ModelVersion(lineage_node.LineageNode):
                 service_name=service_name_id,
                 strict_input_validation=strict_input_validation,
                 statement_params=statement_params,
+                params=params,
             )
         else:
             manifest = self._get_model_manifest(statement_params=statement_params)
@@ -621,6 +635,7 @@ class ModelVersion(lineage_node.LineageNode):
                 statement_params=statement_params,
                 is_partitioned=target_function_info["is_partitioned"],
                 explain_case_sensitive=explain_case_sensitive,
+                params=params,
             )

     def _determine_explain_case_sensitivity(
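With `params` threaded through every `run` overload, inference-time parameters can now vary per call instead of being fixed at logging time. A minimal sketch, assuming a `ModelVersion` handle `mv` and an input DataFrame `df` (the function name and parameter names are illustrative; valid keys depend on the model):

```python
# Sketch only: keys in `params` are forwarded as keyword arguments to the
# model's inference method, so the accepted names depend on the model.
result = mv.run(
    df,
    function_name="predict",
    params={"temperature": 0.2, "top_k": 5},
)
```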
@@ -651,6 +666,9 @@ class ModelVersion(lineage_node.LineageNode):
         input_spec: dataframe.DataFrame,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
+        params: Optional[dict[str, Any]] = None,
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
+        inference_engine_options: Optional[dict[str, Any]] = None,
     ) -> job.MLJob[Any]:
         """Execute batch inference on datasets as an SPCS job.

@@ -664,6 +682,16 @@ class ModelVersion(lineage_node.LineageNode):
             job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
                 execution parameters such as compute resources, worker counts, and job naming.
                 If None, default values will be used.
+            params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
+                (e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
+                model's inference method. Defaults to None.
+            column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
+                specifying how to handle specific columns during file I/O. Maps column names to their
+                file encoding configuration.
+            inference_engine_options: Options for the service creation with custom inference engine.
+                Supports `engine` and `engine_args_override`.
+                `engine` is the type of the inference engine to use.
+                `engine_args_override` is a list of string arguments to pass to the inference engine.

         Returns:
             job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
@@ -722,6 +750,15 @@ class ModelVersion(lineage_node.LineageNode):
         if job_spec is None:
             job_spec = batch_inference_specs.JobSpec()

+        # Validate GPU support if GPU resources are requested
+        self._throw_error_if_gpu_is_not_supported(job_spec.gpu_requests, statement_params)
+
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            job_spec.gpu_requests,
+            statement_params,
+        )
+
         warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
         if warehouse is None:
             raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
@@ -746,12 +783,14 @@ class ModelVersion(lineage_node.LineageNode):
         else:
             job_name = job_spec.job_name

+        target_function_info = self._get_function_info(function_name=job_spec.function_name)
+
         return self._service_ops.invoke_batch_job_method(
             # model version info
             model_name=self._model_name,
             version_name=self._version_name,
             # job spec
-            function_name=job_spec.function_name,
+            function_name=target_function_info["target_method"],
             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
             force_rebuild=job_spec.force_rebuild,
             image_repo_name=job_spec.image_repo,
@@ -766,10 +805,14 @@ class ModelVersion(lineage_node.LineageNode):
             # input and output
             input_stage_location=input_stage_location,
             input_file_pattern="*",
+            column_handling=column_handling,
+            params=params,
+            signature_params=target_function_info["signature"].params,
             output_stage_location=output_stage_location,
             completion_filename="_SUCCESS",
             # misc
             statement_params=statement_params,
+            inference_engine_args=inference_engine_args,
         )

     def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
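Together these hunks let one `run_batch` call carry per-request model parameters, per-column file handling, and a custom inference engine. A hedged sketch, assuming a `ModelVersion` handle `mv`, a Snowpark DataFrame `input_df`, and pre-built `output_spec`/`job_spec` objects; the `"vllm"` engine name and its argument values are assumptions, not confirmed by this diff:

```python
from snowflake.ml.model import FileEncoding

# Sketch only. `mv`, `input_df`, `output_spec`, and `job_spec` are assumed
# to exist; only the three new keyword arguments are taken from this diff.
batch_job = mv.run_batch(
    input_spec=input_df,
    output_spec=output_spec,
    job_spec=job_spec,
    # Forwarded as keyword arguments to the model's inference method.
    params={"temperature": 0.7},
    # Decode this column's file content as base64 before inference.
    column_handling={"IMAGE_FILE": {"encoding": FileEncoding.BASE64}},
    # `engine` and `engine_args_override` per the docstring above;
    # the engine name and args here are illustrative.
    inference_engine_options={
        "engine": "vllm",
        "engine_args_override": ["--max-model-len", "4096"],
    },
)
```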
@@ -985,20 +1028,55 @@ class ModelVersion(lineage_node.LineageNode):
             " the `log_model` function."
         )

-    def _check_huggingface_text_generation_model(
+    def _prepare_inference_engine_args(
+        self,
+        inference_engine_options: Optional[dict[str, Any]],
+        gpu_requests: Optional[Union[str, int]],
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> Optional[service_ops.InferenceEngineArgs]:
+        """Prepare and validate inference engine arguments.
+
+        This method handles the common logic for processing inference engine options:
+        1. Parse inference engine options into InferenceEngineArgs
+        2. Validate that the model is a HuggingFace text-generation model (if inference engine is specified)
+        3. Enrich inference engine args
+
+        Args:
+            inference_engine_options: Optional dictionary containing inference engine configuration.
+            gpu_requests: GPU resource request string (e.g., "4").
+            statement_params: Optional dictionary of statement parameters for SQL commands.
+
+        Returns:
+            Prepared InferenceEngineArgs or None if no inference engine is specified.
+        """
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
+
+        if inference_engine_args is not None:
+            # Validate that model is HuggingFace vLLM supported model and is logged with
+            # OpenAI compatible signature.
+            self._check_huggingface_vllm_supported_model(statement_params)
+            # Enrich with GPU configuration
+            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                inference_engine_args,
+                gpu_requests,
+            )
+
+        return inference_engine_args
+
+    def _check_huggingface_vllm_supported_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
-        """Check if the model is a HuggingFace pipeline with text-generation task
-        and is logged with OPENAI_CHAT_SIGNATURE.
+        """Check if the model is a HuggingFace pipeline with vLLM supported task
+        and is logged with OpenAI compatible signature.

         Args:
             statement_params: Optional dictionary of statement parameters to include
                 in the SQL command to fetch model spec.

         Raises:
-            ValueError: If the model is not a HuggingFace text-generation model or
-                if the model is not logged with OPENAI_CHAT_SIGNATURE.
+            ValueError: If the model is not a HuggingFace vLLM supported model or
+                if the model is not logged with OpenAI compatible signature.
         """
         # Fetch model spec
         model_spec = self._get_model_spec(statement_params)
@@ -1007,34 +1085,37 @@ class ModelVersion(lineage_node.LineageNode):
         model_type = model_spec.get("model_type")
         if model_type != "huggingface_pipeline":
             raise ValueError(
-                f"Inference engine is only supported for HuggingFace text-generation models. "
+                f"Inference engine is only supported for HuggingFace vLLM supported models. "
                 f"Found model_type: {model_type}"
             )

-        # Check if model supports text-generation task
+        # Check if model supports vLLM supported task
         # There should only be one model in the list because we don't support multiple models in a single model spec
         models = model_spec.get("models", {})
-        is_text_generation_task = False
+        is_vllm_supported_task = False
         found_tasks: list[str] = []

-        # As long as the model supports text-generation task, we can use it
+        # As long as the model supports vLLM supported task, we can use it
         for _, model_info in models.items():
             options = model_info.get("options", {})
             task = options.get("task")
             if task:
                 found_tasks.append(str(task))
-                if task == "text-generation":
-                    is_text_generation_task = True
+                if task in VLLM_SUPPORTED_TASKS:
+                    is_vllm_supported_task = True
                     break

-        if not is_text_generation_task:
+        if not is_vllm_supported_task:
             tasks_str = ", ".join(found_tasks)
             found_tasks_str = (
                 f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
             )
-            raise ValueError(f"Inference engine is only supported for text-generation task. {found_tasks_str}")
+            supported_tasks_str = ", ".join(VLLM_SUPPORTED_TASKS)
+            raise ValueError(
+                f"Inference engine is only supported for vLLM supported tasks. {supported_tasks_str}. {found_tasks_str}"
+            )

-        # Check if the model is logged with OPENAI_CHAT_SIGNATURE.
+        # Check if the model is logged with OpenAI compatible signature.
         signatures_dict = model_spec.get("signatures", {})

         # Deserialize signatures from model spec to ModelSignature objects for proper semantic comparison.
@@ -1042,11 +1123,16 @@ class ModelVersion(lineage_node.LineageNode):
             func_name: core.ModelSignature.from_dict(sig_dict) for func_name, sig_dict in signatures_dict.items()
         }

-        if deserialized_signatures != openai_signatures.OPENAI_CHAT_SIGNATURE:
+        if deserialized_signatures not in [
+            openai_signatures.OPENAI_CHAT_SIGNATURE,
+            openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING,
+        ]:
             raise ValueError(
-                "Inference engine requires the model to be logged with OPENAI_CHAT_SIGNATURE. "
+                "Inference engine requires the model to be logged with openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING. "
                 f"Found signatures: {signatures_dict}. "
-                "Please log the model with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE"
+                "Please log the model again with: signatures=openai_signatures.OPENAI_CHAT_SIGNATURE or "
+                "signatures=openai_signatures.OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING"
             )

     @overload
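Since the check rejects anything else, a model intended for the custom inference engine must be logged with one of the two OpenAI-compatible signatures up front. A minimal sketch, assuming a registry handle `reg` and a HuggingFace pipeline model `pipe` (other `log_model` arguments elided):

```python
from snowflake.ml.model import openai_signatures

# Sketch only: `reg` is a snowflake.ml.registry.Registry and `pipe` is a
# HuggingFace pipeline model; remaining log_model arguments are elided.
mv = reg.log_model(
    pipe,
    model_name="my_chat_model",
    signatures=openai_signatures.OPENAI_CHAT_SIGNATURE,
)
```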
@@ -1287,20 +1373,11 @@ class ModelVersion(lineage_node.LineageNode):
         # Validate GPU support if GPU resources are requested
         self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)

-        inference_engine_args = inference_engine_utils._get_inference_engine_args(inference_engine_options)
-
-
-
-
-        if inference_engine_args is not None:
-            self._check_huggingface_text_generation_model(statement_params)
-
-        # Enrich inference engine args if inference engine is specified
-        if inference_engine_args is not None:
-            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
-                inference_engine_args,
-                gpu_requests,
-            )
+        inference_engine_args = self._prepare_inference_engine_args(
+            inference_engine_options,
+            gpu_requests,
+            statement_params,
+        )

         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
snowflake/ml/model/_client/ops/deployment_step.py
ADDED

@@ -0,0 +1,36 @@
+import enum
+import hashlib
+from typing import Optional
+
+
+class DeploymentStep(enum.Enum):
+    MODEL_BUILD = ("model-build", "model_build_")
+    MODEL_INFERENCE = ("model-inference", None)
+    MODEL_LOGGING = ("model-logging", "model_logging_")
+
+    def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
+        self._container_name = container_name
+        self._service_name_prefix = service_name_prefix
+
+    @property
+    def container_name(self) -> str:
+        """Get the container name for the deployment step."""
+        return self._container_name
+
+    @property
+    def service_name_prefix(self) -> Optional[str]:
+        """Get the service name prefix for the deployment step."""
+        return self._service_name_prefix
+
+
+def get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
+    """Get the service ID through the server-side logic."""
+    uuid = query_id.replace("-", "")
+    big_int = int(uuid, 16)
+    md5_hash = hashlib.md5(str(big_int).encode(), usedforsecurity=False).hexdigest()
+    identifier = md5_hash[:8]
+    service_name_prefix = deployment_step.service_name_prefix
+    if service_name_prefix is None:
+        # raise an exception if the service name prefix is None
+        raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
+    return (service_name_prefix + identifier).upper()