snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import enum
|
|
3
|
+
import json
|
|
2
4
|
import pathlib
|
|
3
5
|
import tempfile
|
|
4
6
|
import uuid
|
|
@@ -6,11 +8,12 @@ import warnings
|
|
|
6
8
|
from typing import Any, Callable, Optional, Union, overload
|
|
7
9
|
|
|
8
10
|
import pandas as pd
|
|
11
|
+
from pydantic import TypeAdapter
|
|
9
12
|
|
|
10
13
|
from snowflake import snowpark
|
|
11
|
-
from snowflake.ml import jobs
|
|
12
14
|
from snowflake.ml._internal import telemetry
|
|
13
15
|
from snowflake.ml._internal.utils import sql_identifier
|
|
16
|
+
from snowflake.ml.jobs import job
|
|
14
17
|
from snowflake.ml.lineage import lineage_node
|
|
15
18
|
from snowflake.ml.model import openai_signatures, task, type_hints
|
|
16
19
|
from snowflake.ml.model._client.model import (
|
|
@@ -30,6 +33,7 @@ _TELEMETRY_PROJECT = "MLOps"
|
|
|
30
33
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
|
31
34
|
_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
|
|
32
35
|
_BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"
|
|
36
|
+
_UTF8_ENCODING = "utf-8"
|
|
33
37
|
|
|
34
38
|
|
|
35
39
|
class ExportMode(enum.Enum):
|
|
@@ -46,6 +50,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
46
50
|
_version_name: sql_identifier.SqlIdentifier
|
|
47
51
|
_functions: list[model_manifest_schema.ModelFunctionInfo]
|
|
48
52
|
_model_spec: Optional[model_meta_schema.ModelMetadataDict]
|
|
53
|
+
_model_manifest: Optional[model_manifest_schema.ModelManifestDict]
|
|
49
54
|
|
|
50
55
|
def __init__(self) -> None:
|
|
51
56
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
|
@@ -156,6 +161,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
156
161
|
self._version_name = version_name
|
|
157
162
|
self._functions = self._get_functions()
|
|
158
163
|
self._model_spec = None
|
|
164
|
+
self._model_manifest = None
|
|
159
165
|
super(cls, cls).__init__(
|
|
160
166
|
self,
|
|
161
167
|
session=model_ops._session,
|
|
@@ -463,6 +469,28 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
463
469
|
)
|
|
464
470
|
return self._model_spec
|
|
465
471
|
|
|
472
|
+
def _get_model_manifest(
    self, statement_params: Optional[dict[str, Any]] = None
) -> "model_manifest_schema.ModelManifestDict":
    """Return this model version's manifest, loading it lazily on first use.

    The manifest is fetched through ``self._model_ops.get_model_version_manifest`` only when
    nothing has been cached yet. Later calls return the cached value without issuing another
    query, so ``statement_params`` only affects the first (fetching) call.

    Args:
        statement_params: Optional dictionary of statement parameters to include
            in the SQL command that fetches the model manifest.

    Returns:
        The model manifest as a dictionary for this model version.
    """
    cached_manifest = self._model_manifest
    if cached_manifest is not None:
        return cached_manifest

    fetched_manifest = self._model_ops.get_model_version_manifest(
        database_name=None,
        schema_name=None,
        model_name=self._model_name,
        version_name=self._version_name,
        statement_params=statement_params,
    )
    self._model_manifest = fetched_manifest
    return self._model_manifest
|
|
493
|
+
|
|
466
494
|
@overload
|
|
467
495
|
def run(
|
|
468
496
|
self,
|
|
@@ -471,6 +499,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
471
499
|
function_name: Optional[str] = None,
|
|
472
500
|
partition_column: Optional[str] = None,
|
|
473
501
|
strict_input_validation: bool = False,
|
|
502
|
+
params: Optional[dict[str, Any]] = None,
|
|
474
503
|
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
|
475
504
|
"""Invoke a method in a model version object.
|
|
476
505
|
|
|
@@ -481,6 +510,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
481
510
|
partition_column: The partition column name to partition by.
|
|
482
511
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
|
483
512
|
type validation to make sure your input data won't overflow when providing to the model.
|
|
513
|
+
params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
|
|
514
|
+
These are passed as keyword arguments to the model's inference method. Defaults to None.
|
|
484
515
|
"""
|
|
485
516
|
...
|
|
486
517
|
|
|
@@ -492,6 +523,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
492
523
|
service_name: str,
|
|
493
524
|
function_name: Optional[str] = None,
|
|
494
525
|
strict_input_validation: bool = False,
|
|
526
|
+
params: Optional[dict[str, Any]] = None,
|
|
495
527
|
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
|
496
528
|
"""Invoke a method in a model version object via a service.
|
|
497
529
|
|
|
@@ -501,6 +533,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
501
533
|
function_name: The function name to run. It is the name used to call a function in SQL.
|
|
502
534
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
|
503
535
|
type validation to make sure your input data won't overflow when providing to the model.
|
|
536
|
+
params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
|
|
537
|
+
These are passed as keyword arguments to the model's inference method. Defaults to None.
|
|
504
538
|
"""
|
|
505
539
|
...
|
|
506
540
|
|
|
@@ -517,6 +551,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
517
551
|
function_name: Optional[str] = None,
|
|
518
552
|
partition_column: Optional[str] = None,
|
|
519
553
|
strict_input_validation: bool = False,
|
|
554
|
+
params: Optional[dict[str, Any]] = None,
|
|
520
555
|
) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
|
|
521
556
|
"""Invoke a method in a model version object via the warehouse or a service.
|
|
522
557
|
|
|
@@ -528,9 +563,14 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
528
563
|
partition_column: The partition column name to partition by.
|
|
529
564
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
|
530
565
|
type validation to make sure your input data won't overflow when providing to the model.
|
|
566
|
+
params: Optional dictionary of model inference parameters (e.g., temperature, top_k for LLMs).
|
|
567
|
+
These are passed as keyword arguments to the model's inference method. Defaults to None.
|
|
531
568
|
|
|
532
569
|
Returns:
|
|
533
570
|
The prediction data. It would be the same type dataframe as your input.
|
|
571
|
+
|
|
572
|
+
Raises:
|
|
573
|
+
ValueError: When the model does not support running on warehouse and no service name is provided.
|
|
534
574
|
"""
|
|
535
575
|
statement_params = telemetry.get_statement_params(
|
|
536
576
|
project=_TELEMETRY_PROJECT,
|
|
@@ -555,8 +595,30 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
555
595
|
service_name=service_name_id,
|
|
556
596
|
strict_input_validation=strict_input_validation,
|
|
557
597
|
statement_params=statement_params,
|
|
598
|
+
params=params,
|
|
558
599
|
)
|
|
559
600
|
else:
|
|
601
|
+
manifest = self._get_model_manifest(statement_params=statement_params)
|
|
602
|
+
target_platforms = manifest.get("target_platforms", None)
|
|
603
|
+
if (
|
|
604
|
+
target_platforms is not None
|
|
605
|
+
and len(target_platforms) > 0
|
|
606
|
+
and type_hints.TargetPlatform.WAREHOUSE.value not in target_platforms
|
|
607
|
+
):
|
|
608
|
+
raise ValueError(
|
|
609
|
+
f"The model {self.fully_qualified_model_name} version {self.version_name} "
|
|
610
|
+
"is not logged for inference in Warehouse. "
|
|
611
|
+
"To run the model in Warehouse, please log the model again using `log_model` API with "
|
|
612
|
+
'`target_platforms=["WAREHOUSE"]` or '
|
|
613
|
+
'`target_platforms=["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]` and rerun the command. '
|
|
614
|
+
"To run the model in Snowpark Container Services, the `service_name` argument must be provided. "
|
|
615
|
+
"You can create a service using the `create_service` API. "
|
|
616
|
+
"For inference in Warehouse, see https://docs.snowflake.com/en/developer-guide/"
|
|
617
|
+
"snowflake-ml/model-registry/warehouse#inference-from-python. "
|
|
618
|
+
"For inference in Snowpark Container Services, see https://docs.snowflake.com/en/developer-guide/"
|
|
619
|
+
"snowflake-ml/model-registry/container#python."
|
|
620
|
+
)
|
|
621
|
+
|
|
560
622
|
explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
|
|
561
623
|
|
|
562
624
|
return self._model_ops.invoke_method(
|
|
@@ -573,6 +635,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
573
635
|
statement_params=statement_params,
|
|
574
636
|
is_partitioned=target_function_info["is_partitioned"],
|
|
575
637
|
explain_case_sensitive=explain_case_sensitive,
|
|
638
|
+
params=params,
|
|
576
639
|
)
|
|
577
640
|
|
|
578
641
|
def _determine_explain_case_sensitivity(
|
|
@@ -586,6 +649,41 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
586
649
|
method_options, target_function_info["name"]
|
|
587
650
|
)
|
|
588
651
|
|
|
652
|
+
@staticmethod
def _encode_column_handling(
    column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
) -> Optional[str]:
    """Validate per-column file handling options and serialize them as a base64 string.

    The mapping is checked against ``dict[str, ColumnHandlingOptions]`` with a pydantic
    ``TypeAdapter``. The validated value is dumped to JSON bytes by the same adapter and
    then base64 encoded.

    Args:
        column_handling: Optional dictionary mapping column names to file encoding options.

    Returns:
        Base64 encoded JSON string of the column handling options, or None if input is None.
    """
    # TODO: validation for column names
    if column_handling is not None:
        options_adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
        # TODO: throw error if the validate_python function fails
        # NOTE(review): presumably validate_python already raises pydantic's ValidationError
        # on invalid input (per pydantic docs) — confirm whether a friendlier error is wanted.
        checked_options = options_adapter.validate_python(column_handling)
        json_bytes = options_adapter.dump_json(checked_options)
        return base64.b64encode(json_bytes).decode(_UTF8_ENCODING)
    return None
|
|
671
|
+
|
|
672
|
+
@staticmethod
def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
    """Serialize model inference parameters as a base64 string.

    Args:
        params: Optional dictionary of model inference parameters.

    Returns:
        The base64 encoding of ``json.dumps(params)`` (UTF-8), or None if input is None.
    """
    if params is not None:
        # TODO: validation for param names, types
        json_payload = json.dumps(params).encode(_UTF8_ENCODING)
        return base64.b64encode(json_payload).decode(_UTF8_ENCODING)
    return None
|
|
686
|
+
|
|
589
687
|
@telemetry.send_api_usage_telemetry(
|
|
590
688
|
project=_TELEMETRY_PROJECT,
|
|
591
689
|
subproject=_TELEMETRY_SUBPROJECT,
|
|
@@ -603,7 +701,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
603
701
|
input_spec: dataframe.DataFrame,
|
|
604
702
|
output_spec: batch_inference_specs.OutputSpec,
|
|
605
703
|
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
606
|
-
|
|
704
|
+
params: Optional[dict[str, Any]] = None,
|
|
705
|
+
column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]] = None,
|
|
706
|
+
) -> job.MLJob[Any]:
|
|
607
707
|
"""Execute batch inference on datasets as an SPCS job.
|
|
608
708
|
|
|
609
709
|
Args:
|
|
@@ -616,9 +716,15 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
616
716
|
job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
|
|
617
717
|
execution parameters such as compute resources, worker counts, and job naming.
|
|
618
718
|
If None, default values will be used.
|
|
719
|
+
params (Optional[dict[str, Any]]): Optional dictionary of model inference parameters
|
|
720
|
+
(e.g., temperature, top_k for LLMs). These are passed as keyword arguments to the
|
|
721
|
+
model's inference method. Defaults to None.
|
|
722
|
+
column_handling (Optional[dict[str, batch_inference_specs.FileEncoding]]): Optional dictionary
|
|
723
|
+
specifying how to handle specific columns during file I/O. Maps column names to their
|
|
724
|
+
file encoding configuration.
|
|
619
725
|
|
|
620
726
|
Returns:
|
|
621
|
-
|
|
727
|
+
job.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
|
|
622
728
|
lifecycle.
|
|
623
729
|
|
|
624
730
|
Raises:
|
|
@@ -671,6 +777,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
671
777
|
subproject=_TELEMETRY_SUBPROJECT,
|
|
672
778
|
)
|
|
673
779
|
|
|
780
|
+
column_handling_as_string = self._encode_column_handling(column_handling)
|
|
781
|
+
params_as_string = self._encode_params(params)
|
|
782
|
+
|
|
674
783
|
if job_spec is None:
|
|
675
784
|
job_spec = batch_inference_specs.JobSpec()
|
|
676
785
|
|
|
@@ -718,6 +827,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
718
827
|
# input and output
|
|
719
828
|
input_stage_location=input_stage_location,
|
|
720
829
|
input_file_pattern="*",
|
|
830
|
+
column_handling=column_handling_as_string,
|
|
831
|
+
params=params_as_string,
|
|
721
832
|
output_stage_location=output_stage_location,
|
|
722
833
|
completion_filename="_SUCCESS",
|
|
723
834
|
# misc
|
|
@@ -1019,6 +1130,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1019
1130
|
force_rebuild: bool = False,
|
|
1020
1131
|
build_external_access_integration: Optional[str] = None,
|
|
1021
1132
|
block: bool = True,
|
|
1133
|
+
autocapture: bool = False,
|
|
1022
1134
|
inference_engine_options: Optional[dict[str, Any]] = None,
|
|
1023
1135
|
experimental_options: Optional[dict[str, Any]] = None,
|
|
1024
1136
|
) -> Union[str, async_job.AsyncJob]:
|
|
@@ -1053,13 +1165,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1053
1165
|
block: A bool value indicating whether this function will wait until the service is available.
|
|
1054
1166
|
When it is ``False``, this function executes the underlying service creation asynchronously
|
|
1055
1167
|
and returns an :class:`AsyncJob`.
|
|
1168
|
+
autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
|
|
1169
|
+
captured in the model inference table.
|
|
1056
1170
|
inference_engine_options: Options for the service creation with custom inference engine.
|
|
1057
1171
|
Supports `engine` and `engine_args_override`.
|
|
1058
1172
|
`engine` is the type of the inference engine to use.
|
|
1059
1173
|
`engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1060
1174
|
experimental_options: Experimental options for the service creation.
|
|
1061
|
-
Currently only `autocapture` is supported.
|
|
1062
|
-
`autocapture` is a boolean to enable/disable inference table.
|
|
1063
1175
|
"""
|
|
1064
1176
|
...
|
|
1065
1177
|
|
|
@@ -1081,6 +1193,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1081
1193
|
force_rebuild: bool = False,
|
|
1082
1194
|
build_external_access_integrations: Optional[list[str]] = None,
|
|
1083
1195
|
block: bool = True,
|
|
1196
|
+
autocapture: bool = False,
|
|
1084
1197
|
inference_engine_options: Optional[dict[str, Any]] = None,
|
|
1085
1198
|
experimental_options: Optional[dict[str, Any]] = None,
|
|
1086
1199
|
) -> Union[str, async_job.AsyncJob]:
|
|
@@ -1115,13 +1228,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1115
1228
|
block: A bool value indicating whether this function will wait until the service is available.
|
|
1116
1229
|
When it is ``False``, this function executes the underlying service creation asynchronously
|
|
1117
1230
|
and returns an :class:`AsyncJob`.
|
|
1231
|
+
autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
|
|
1232
|
+
captured in the model inference table.
|
|
1118
1233
|
inference_engine_options: Options for the service creation with custom inference engine.
|
|
1119
1234
|
Supports `engine` and `engine_args_override`.
|
|
1120
1235
|
`engine` is the type of the inference engine to use.
|
|
1121
1236
|
`engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1122
1237
|
experimental_options: Experimental options for the service creation.
|
|
1123
|
-
Currently only `autocapture` is supported.
|
|
1124
|
-
`autocapture` is a boolean to enable/disable inference table.
|
|
1125
1238
|
"""
|
|
1126
1239
|
...
|
|
1127
1240
|
|
|
@@ -1158,6 +1271,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1158
1271
|
build_external_access_integration: Optional[str] = None,
|
|
1159
1272
|
build_external_access_integrations: Optional[list[str]] = None,
|
|
1160
1273
|
block: bool = True,
|
|
1274
|
+
autocapture: bool = False,
|
|
1161
1275
|
inference_engine_options: Optional[dict[str, Any]] = None,
|
|
1162
1276
|
experimental_options: Optional[dict[str, Any]] = None,
|
|
1163
1277
|
) -> Union[str, async_job.AsyncJob]:
|
|
@@ -1194,13 +1308,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1194
1308
|
block: A bool value indicating whether this function will wait until the service is available.
|
|
1195
1309
|
When it is False, this function executes the underlying service creation asynchronously
|
|
1196
1310
|
and returns an AsyncJob.
|
|
1311
|
+
autocapture: Whether inference autocapture is enabled on the service. If true, inference data will be
|
|
1312
|
+
captured in the model inference table.
|
|
1197
1313
|
inference_engine_options: Options for the service creation with custom inference engine.
|
|
1198
1314
|
Supports `engine` and `engine_args_override`.
|
|
1199
1315
|
`engine` is the type of the inference engine to use.
|
|
1200
1316
|
`engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1201
1317
|
experimental_options: Experimental options for the service creation.
|
|
1202
|
-
Currently only `autocapture` is supported.
|
|
1203
|
-
`autocapture` is a boolean to enable/disable inference table.
|
|
1204
1318
|
|
|
1205
1319
|
|
|
1206
1320
|
Raises:
|
|
@@ -1251,9 +1365,6 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1251
1365
|
gpu_requests,
|
|
1252
1366
|
)
|
|
1253
1367
|
|
|
1254
|
-
# Extract autocapture from experimental_options
|
|
1255
|
-
autocapture = experimental_options.get("autocapture") if experimental_options else None
|
|
1256
|
-
|
|
1257
1368
|
from snowflake.ml.model import event_handler
|
|
1258
1369
|
from snowflake.snowpark import exceptions
|
|
1259
1370
|
|
|
@@ -1320,8 +1431,13 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1320
1431
|
"""List all the service names using this model version.
|
|
1321
1432
|
|
|
1322
1433
|
Returns:
|
|
1323
|
-
List of
|
|
1324
|
-
|
|
1434
|
+
List of details about all the services associated with this model version. The details include:
|
|
1435
|
+
name: The name of the service.
|
|
1436
|
+
status: The status of the service.
|
|
1437
|
+
inference_endpoint: The public endpoint of the service, if enabled and services is not in PENDING state.
|
|
1438
|
+
This will give privatelink endpoint if the session is created with privatelink connection
|
|
1439
|
+
internal_endpoint: The internal endpoint of the service, if services is not in PENDING state.
|
|
1440
|
+
autocapture_enabled: Whether service has autocapture enabled, if it is set in service proxy spec.
|
|
1325
1441
|
"""
|
|
1326
1442
|
statement_params = telemetry.get_statement_params(
|
|
1327
1443
|
project=_TELEMETRY_PROJECT,
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import hashlib
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DeploymentStep(enum.Enum):
    """Deployment steps, each tagged with a container name and an optional service name prefix.

    Each member's value is a ``(container_name, service_name_prefix)`` tuple, which is
    unpacked into ``__init__``. ``MODEL_INFERENCE`` carries no prefix (``None``).
    """

    MODEL_BUILD = ("model-build", "model_build_")
    MODEL_INFERENCE = ("model-inference", None)
    MODEL_LOGGING = ("model-logging", "model_logging_")

    def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
        self._container_name, self._service_name_prefix = container_name, service_name_prefix

    @property
    def container_name(self) -> str:
        """The container name associated with this deployment step."""
        return self._container_name

    @property
    def service_name_prefix(self) -> Optional[str]:
        """The service name prefix for this deployment step, or ``None`` if it has none."""
        return self._service_name_prefix
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_service_id_from_deployment_step(query_id: str, deployment_step: "DeploymentStep") -> str:
    """Get the service ID through the server-side logic.

    The query ID (with dashes removed) is read as one hexadecimal integer. The decimal string
    of that integer is MD5-hashed, and the first 8 hex characters of the digest are appended
    to the step's service name prefix. The whole result is upper-cased.

    Args:
        query_id: The query ID; expected to be hexadecimal digits, optionally separated by dashes.
        deployment_step: The deployment step whose ``service_name_prefix`` starts the service ID.

    Returns:
        The upper-cased service ID (service name prefix followed by 8 hex characters).

    Raises:
        ValueError: If the deployment step has no service name prefix, or if ``query_id`` is
            not a hexadecimal string.
    """
    service_name_prefix = deployment_step.service_name_prefix
    if service_name_prefix is None:
        # Steps whose prefix is None (e.g. MODEL_INFERENCE) cannot be mapped; fail before hashing.
        raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")

    hex_digits = query_id.replace("-", "")
    try:
        query_id_as_int = int(hex_digits, 16)
    except ValueError as err:
        raise ValueError(f"Query id {query_id!r} is not a hexadecimal string.") from err

    # MD5 is used only to derive a short name, not for security.
    digest = hashlib.md5(str(query_id_as_int).encode(), usedforsecurity=False).hexdigest()
    return (service_name_prefix + digest[:8]).upper()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import json
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
import pathlib
|
|
5
6
|
import tempfile
|
|
@@ -7,11 +8,13 @@ import warnings
|
|
|
7
8
|
from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
|
|
8
9
|
|
|
9
10
|
import yaml
|
|
11
|
+
from typing_extensions import NotRequired
|
|
10
12
|
|
|
13
|
+
from snowflake.ml._internal import platform_capabilities
|
|
11
14
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
12
|
-
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
|
15
|
+
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier, url
|
|
13
16
|
from snowflake.ml.model import model_signature, type_hints
|
|
14
|
-
from snowflake.ml.model._client.ops import metadata_ops
|
|
17
|
+
from snowflake.ml.model._client.ops import deployment_step, metadata_ops
|
|
15
18
|
from snowflake.ml.model._client.sql import (
|
|
16
19
|
model as model_sql,
|
|
17
20
|
model_version as model_version_sql,
|
|
@@ -31,6 +34,8 @@ from snowflake.ml.model._signatures import snowpark_handler
|
|
|
31
34
|
from snowflake.snowpark import dataframe, row, session
|
|
32
35
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
33
36
|
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
34
39
|
|
|
35
40
|
# An enum class to represent Create Or Alter Model SQL command.
|
|
36
41
|
class ModelAction(enum.Enum):
|
|
@@ -42,6 +47,8 @@ class ServiceInfo(TypedDict):
|
|
|
42
47
|
name: str
|
|
43
48
|
status: str
|
|
44
49
|
inference_endpoint: Optional[str]
|
|
50
|
+
internal_endpoint: Optional[str]
|
|
51
|
+
autocapture_enabled: NotRequired[bool]
|
|
45
52
|
|
|
46
53
|
|
|
47
54
|
class ModelOperator:
|
|
@@ -651,6 +658,13 @@ class ModelOperator:
|
|
|
651
658
|
url_str = str(url_value)
|
|
652
659
|
return url_str if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in url_str else None
|
|
653
660
|
|
|
661
|
+
def _extract_and_validate_port(self, res_row: "row.Row") -> Optional[int]:
    """Extract and validate the endpoint port from an endpoint result row.

    Args:
        res_row: A row yielded by ``self._service_client.show_endpoints``. The port is read from
            the column named ``self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME``.

    Returns:
        The port as an ``int`` when it is present and within the TCP range 1-65535. ``None`` when
        the column is empty or holds a value that is not a usable port.
    """
    port_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PORT_COL_NAME]
    if port_value is None:
        return None
    try:
        port = int(port_value)
    except (TypeError, ValueError):
        # Like the sibling URL extractors, an unusable value yields None instead of
        # aborting the whole service listing.
        logging.getLogger(__name__).warning("Ignoring non-integer endpoint port value %r.", port_value)
        return None
    if not 0 < port <= 65535:
        logging.getLogger(__name__).warning("Ignoring out-of-range endpoint port %d.", port)
        return None
    return port
|
|
667
|
+
|
|
654
668
|
def show_services(
|
|
655
669
|
self,
|
|
656
670
|
*,
|
|
@@ -684,8 +698,12 @@ class ModelOperator:
|
|
|
684
698
|
|
|
685
699
|
result: list[ServiceInfo] = []
|
|
686
700
|
is_privatelink_connection = self._is_privatelink_connection()
|
|
701
|
+
is_autocapture_param_enabled = (
|
|
702
|
+
platform_capabilities.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
|
|
703
|
+
)
|
|
687
704
|
|
|
688
705
|
for fully_qualified_service_name in fully_qualified_service_names:
|
|
706
|
+
port: Optional[int] = None
|
|
689
707
|
inference_endpoint: Optional[str] = None
|
|
690
708
|
db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
|
|
691
709
|
statuses = self._service_client.get_service_container_statuses(
|
|
@@ -695,6 +713,11 @@ class ModelOperator:
|
|
|
695
713
|
return result
|
|
696
714
|
|
|
697
715
|
service_status = statuses[0].service_status
|
|
716
|
+
service_description = self._service_client.describe_service(
|
|
717
|
+
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
|
718
|
+
)
|
|
719
|
+
internal_dns = str(service_description[self._service_client.DESC_SERVICE_INTERNAL_DNS_COL_NAME])
|
|
720
|
+
|
|
698
721
|
for res_row in self._service_client.show_endpoints(
|
|
699
722
|
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
|
700
723
|
):
|
|
@@ -706,19 +729,25 @@ class ModelOperator:
|
|
|
706
729
|
|
|
707
730
|
ingress_url = self._extract_and_validate_ingress_url(res_row)
|
|
708
731
|
privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
|
|
732
|
+
port = self._extract_and_validate_port(res_row)
|
|
709
733
|
|
|
710
734
|
if is_privatelink_connection and privatelink_ingress_url is not None:
|
|
711
735
|
inference_endpoint = privatelink_ingress_url
|
|
712
736
|
else:
|
|
713
737
|
inference_endpoint = ingress_url
|
|
714
738
|
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
)
|
|
739
|
+
service_info = ServiceInfo(
|
|
740
|
+
name=fully_qualified_service_name,
|
|
741
|
+
status=service_status.value,
|
|
742
|
+
inference_endpoint=inference_endpoint,
|
|
743
|
+
internal_endpoint=f"http://{internal_dns}:{port}" if port is not None else None,
|
|
721
744
|
)
|
|
745
|
+
if is_autocapture_param_enabled and self._service_client.DESC_SERVICE_SPEC_COL_NAME in service_description:
|
|
746
|
+
# Include column only if parameter is enabled and spec exists for service owner caller
|
|
747
|
+
autocapture_enabled = self._service_client.get_proxy_container_autocapture(service_description)
|
|
748
|
+
service_info["autocapture_enabled"] = autocapture_enabled
|
|
749
|
+
|
|
750
|
+
result.append(service_info)
|
|
722
751
|
|
|
723
752
|
return result
|
|
724
753
|
|
|
@@ -960,6 +989,7 @@ class ModelOperator:
|
|
|
960
989
|
statement_params: Optional[dict[str, str]] = None,
|
|
961
990
|
is_partitioned: Optional[bool] = None,
|
|
962
991
|
explain_case_sensitive: bool = False,
|
|
992
|
+
params: Optional[dict[str, Any]] = None,
|
|
963
993
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
964
994
|
...
|
|
965
995
|
|
|
@@ -976,6 +1006,7 @@ class ModelOperator:
|
|
|
976
1006
|
strict_input_validation: bool = False,
|
|
977
1007
|
statement_params: Optional[dict[str, str]] = None,
|
|
978
1008
|
explain_case_sensitive: bool = False,
|
|
1009
|
+
params: Optional[dict[str, Any]] = None,
|
|
979
1010
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
980
1011
|
...
|
|
981
1012
|
|
|
@@ -996,6 +1027,7 @@ class ModelOperator:
|
|
|
996
1027
|
statement_params: Optional[dict[str, str]] = None,
|
|
997
1028
|
is_partitioned: Optional[bool] = None,
|
|
998
1029
|
explain_case_sensitive: bool = False,
|
|
1030
|
+
params: Optional[dict[str, Any]] = None,
|
|
999
1031
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
1000
1032
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
|
1001
1033
|
|
|
@@ -1031,6 +1063,24 @@ class ModelOperator:
|
|
|
1031
1063
|
col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
|
|
1032
1064
|
input_args.append(col_name)
|
|
1033
1065
|
|
|
1066
|
+
method_parameters: Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]] = None
|
|
1067
|
+
if signature.params:
|
|
1068
|
+
# Start with defaults from signature
|
|
1069
|
+
final_params = {}
|
|
1070
|
+
for param_spec in signature.params:
|
|
1071
|
+
if hasattr(param_spec, "default_value"):
|
|
1072
|
+
final_params[param_spec.name] = param_spec.default_value
|
|
1073
|
+
|
|
1074
|
+
# Override with provided runtime parameters
|
|
1075
|
+
if params:
|
|
1076
|
+
final_params.update(params)
|
|
1077
|
+
|
|
1078
|
+
# Convert to list of tuples with SqlIdentifier for parameter names
|
|
1079
|
+
method_parameters = [
|
|
1080
|
+
(sql_identifier.SqlIdentifier(param_name), param_value)
|
|
1081
|
+
for param_name, param_value in final_params.items()
|
|
1082
|
+
]
|
|
1083
|
+
|
|
1034
1084
|
returns = []
|
|
1035
1085
|
for output_feature in signature.outputs:
|
|
1036
1086
|
output_name = identifier_rule.get_sql_identifier_from_feature(output_feature.name)
|
|
@@ -1049,6 +1099,7 @@ class ModelOperator:
|
|
|
1049
1099
|
schema_name=schema_name,
|
|
1050
1100
|
service_name=service_name,
|
|
1051
1101
|
statement_params=statement_params,
|
|
1102
|
+
params=method_parameters,
|
|
1052
1103
|
)
|
|
1053
1104
|
else:
|
|
1054
1105
|
assert model_name is not None
|
|
@@ -1064,6 +1115,7 @@ class ModelOperator:
|
|
|
1064
1115
|
model_name=model_name,
|
|
1065
1116
|
version_name=version_name,
|
|
1066
1117
|
statement_params=statement_params,
|
|
1118
|
+
params=method_parameters,
|
|
1067
1119
|
)
|
|
1068
1120
|
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
|
1069
1121
|
df_res = self._model_version_client.invoke_table_function_method(
|
|
@@ -1079,6 +1131,7 @@ class ModelOperator:
|
|
|
1079
1131
|
statement_params=statement_params,
|
|
1080
1132
|
is_partitioned=is_partitioned or False,
|
|
1081
1133
|
explain_case_sensitive=explain_case_sensitive,
|
|
1134
|
+
params=method_parameters,
|
|
1082
1135
|
)
|
|
1083
1136
|
|
|
1084
1137
|
if keep_order:
|
|
@@ -1212,3 +1265,35 @@ class ModelOperator:
|
|
|
1212
1265
|
target_path=local_file_dir,
|
|
1213
1266
|
statement_params=statement_params,
|
|
1214
1267
|
)
|
|
1268
|
+
|
|
1269
|
+
def run_import_model_query(
    self,
    *,
    database_name: str,
    schema_name: str,
    yaml_content: str,
    statement_params: Optional[dict[str, Any]] = None,
) -> None:
    """Run ``SYSTEM$IMPORT_MODEL`` with the given YAML spec and wait for it to complete.

    The query is submitted with ``block=False`` so its query ID is available immediately.
    That ID is used to derive the MODEL_LOGGING service name, and a link to the job
    monitoring page is logged before this method blocks on the query result.

    Args:
        database_name: Database name used in the job monitoring URL.
        schema_name: Schema name used in the job monitoring URL.
        yaml_content: The import specification passed to ``SYSTEM$IMPORT_MODEL``. Single quotes
            are escaped before it is embedded in the SQL string literal.
        statement_params: Optional statement parameters for the query.
    """
    escaped_spec = snowpark_utils.escape_single_quotes(yaml_content)  # type: ignore[no-untyped-call]
    import_query = f"SELECT SYSTEM$IMPORT_MODEL('{escaped_spec}')"

    pending = self._session.sql(import_query).collect(block=False, statement_params=statement_params)
    query_id = pending.query_id  # type: ignore[attr-defined]
    logger.info(f"Remotely importing model, with the query id: {query_id}")

    logging_service_id = deployment_step.get_service_id_from_deployment_step(
        query_id,
        deployment_step.DeploymentStep.MODEL_LOGGING,
    )
    logger_name = sql_identifier.SqlIdentifier(logging_service_id).identifier()

    monitoring_url = url.get_snowflake_url(
        session=self._session,
        url_path=f"{url.JOB_URL_PREFIX}/{database_name}/{schema_name}/{logger_name}",
    )
    logger.info(
        f"To monitor the progress of the model logging job, head to the job monitoring page {monitoring_url}"
    )

    # Block until the import query finishes.
    pending.result()  # type: ignore[attr-defined]
|