snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +16 -2
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +8 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +15 -0
- snowflake/ml/jobs/job.py +186 -40
- snowflake/ml/jobs/manager.py +48 -39
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +168 -18
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
|
@@ -7,6 +7,7 @@ from typing import Any, Callable, Optional, Union, overload
|
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
|
|
10
|
+
from snowflake import snowpark
|
|
10
11
|
from snowflake.ml import jobs
|
|
11
12
|
from snowflake.ml._internal import telemetry
|
|
12
13
|
from snowflake.ml._internal.utils import sql_identifier
|
|
@@ -19,7 +20,9 @@ from snowflake.ml.model._client.model import (
|
|
|
19
20
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
20
21
|
from snowflake.ml.model._model_composer import model_composer
|
|
21
22
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
23
|
+
from snowflake.ml.model._model_composer.model_method import utils as model_method_utils
|
|
22
24
|
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
|
25
|
+
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
|
23
26
|
from snowflake.snowpark import Session, async_job, dataframe
|
|
24
27
|
|
|
25
28
|
_TELEMETRY_PROJECT = "MLOps"
|
|
@@ -41,6 +44,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
41
44
|
_model_name: sql_identifier.SqlIdentifier
|
|
42
45
|
_version_name: sql_identifier.SqlIdentifier
|
|
43
46
|
_functions: list[model_manifest_schema.ModelFunctionInfo]
|
|
47
|
+
_model_spec: Optional[model_meta_schema.ModelMetadataDict]
|
|
44
48
|
|
|
45
49
|
def __init__(self) -> None:
|
|
46
50
|
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
|
@@ -150,6 +154,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
150
154
|
self._model_name = model_name
|
|
151
155
|
self._version_name = version_name
|
|
152
156
|
self._functions = self._get_functions()
|
|
157
|
+
self._model_spec = None
|
|
153
158
|
super(cls, cls).__init__(
|
|
154
159
|
self,
|
|
155
160
|
session=model_ops._session,
|
|
@@ -437,6 +442,26 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
437
442
|
"""
|
|
438
443
|
return self._functions
|
|
439
444
|
|
|
445
|
+
def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict:
    """Return the model spec of this model version, fetching it on first use.

    The spec is requested through ``self._model_ops._fetch_model_spec`` only while
    ``self._model_spec`` is still ``None``; the fetched value is stored on
    ``self._model_spec`` and returned directly by later calls (which then do not use
    their ``statement_params``).

    Args:
        statement_params: Optional dictionary of statement parameters forwarded to the
            SQL command that fetches the model spec when it is not cached yet.

    Returns:
        The model spec as a dictionary for this model version.
    """
    cached_spec = self._model_spec
    if cached_spec is not None:
        # Already fetched for this instance: avoid another round trip.
        return cached_spec

    fetched_spec = self._model_ops._fetch_model_spec(
        database_name=None,
        schema_name=None,
        model_name=self._model_name,
        version_name=self._version_name,
        statement_params=statement_params,
    )
    self._model_spec = fetched_spec
    return fetched_spec
|
|
464
|
+
|
|
440
465
|
@overload
|
|
441
466
|
def run(
|
|
442
467
|
self,
|
|
@@ -531,6 +556,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
531
556
|
statement_params=statement_params,
|
|
532
557
|
)
|
|
533
558
|
else:
|
|
559
|
+
explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params)
|
|
560
|
+
|
|
534
561
|
return self._model_ops.invoke_method(
|
|
535
562
|
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
|
536
563
|
method_function_type=target_function_info["target_method_function_type"],
|
|
@@ -544,8 +571,20 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
544
571
|
partition_column=partition_column,
|
|
545
572
|
statement_params=statement_params,
|
|
546
573
|
is_partitioned=target_function_info["is_partitioned"],
|
|
574
|
+
explain_case_sensitive=explain_case_sensitive,
|
|
547
575
|
)
|
|
548
576
|
|
|
577
|
+
def _determine_explain_case_sensitivity(
    self,
    target_function_info: model_manifest_schema.ModelFunctionInfo,
    statement_params: Optional[dict[str, Any]] = None,
) -> bool:
    """Decide whether the target function's outputs should be resolved case-sensitively.

    Reads the ``method_options`` section of this version's model spec (an empty mapping
    when the key is absent) and hands it, together with the target function's name, to
    ``model_method_utils.determine_explain_case_sensitive_from_method_options``.

    Args:
        target_function_info: Manifest info of the function being invoked; only its
            ``"name"`` entry is read here.
        statement_params: Optional statement parameters used if the model spec has to be
            fetched.

    Returns:
        The case-sensitivity flag computed by the helper.
    """
    spec_method_options = self._get_model_spec(statement_params).get("method_options", {})
    function_name = target_function_info["name"]
    return model_method_utils.determine_explain_case_sensitive_from_method_options(
        spec_method_options,
        function_name,
    )
|
|
587
|
+
|
|
549
588
|
@telemetry.send_api_usage_telemetry(
|
|
550
589
|
project=_TELEMETRY_PROJECT,
|
|
551
590
|
subproject=_TELEMETRY_SUBPROJECT,
|
|
@@ -555,7 +594,8 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
555
594
|
"job_spec",
|
|
556
595
|
],
|
|
557
596
|
)
|
|
558
|
-
|
|
597
|
+
@snowpark._internal.utils.private_preview(version="1.18.0")
|
|
598
|
+
def run_batch(
|
|
559
599
|
self,
|
|
560
600
|
*,
|
|
561
601
|
compute_pool: str,
|
|
@@ -563,6 +603,68 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
563
603
|
output_spec: batch_inference_specs.OutputSpec,
|
|
564
604
|
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
565
605
|
) -> jobs.MLJob[Any]:
|
|
606
|
+
"""Execute batch inference on datasets as an SPCS job.
|
|
607
|
+
|
|
608
|
+
Args:
|
|
609
|
+
compute_pool (str): Name of the compute pool to use for building the image containers and batch
|
|
610
|
+
inference execution.
|
|
611
|
+
input_spec (dataframe.DataFrame): Snowpark DataFrame containing the input data for inference.
|
|
612
|
+
The DataFrame should contain all required features for model prediction and passthrough columns.
|
|
613
|
+
output_spec (batch_inference_specs.OutputSpec): Configuration for where and how to save
|
|
614
|
+
the inference results. Specifies the stage location and file handling behavior.
|
|
615
|
+
job_spec (Optional[batch_inference_specs.JobSpec]): Optional configuration for job
|
|
616
|
+
execution parameters such as compute resources, worker counts, and job naming.
|
|
617
|
+
If None, default values will be used.
|
|
618
|
+
|
|
619
|
+
Returns:
|
|
620
|
+
jobs.MLJob[Any]: A batch inference job object that can be used to monitor progress and manage the job
|
|
621
|
+
lifecycle.
|
|
622
|
+
|
|
623
|
+
Raises:
|
|
624
|
+
ValueError: If warehouse is not set in job_spec and no current warehouse is available.
|
|
625
|
+
RuntimeError: If the input_spec cannot be processed or written to the staging location.
|
|
626
|
+
|
|
627
|
+
Example:
|
|
628
|
+
>>> # Prepare input data - Example 1: From a table
|
|
629
|
+
>>> input_df = session.table("my_input_table")
|
|
630
|
+
>>>
|
|
631
|
+
>>> # Prepare input data - Example 2: From a SQL query
|
|
632
|
+
>>> input_df = session.sql(
|
|
633
|
+
... "SELECT id, feature_1, feature_2 FROM feature_table WHERE feature_1 > 100"
|
|
634
|
+
... )
|
|
635
|
+
>>>
|
|
636
|
+
>>> # Prepare input data - Example 3: From Parquet files in a stage
|
|
637
|
+
>>> input_df = session.read.option("pattern", ".*\\.parquet").parquet(
|
|
638
|
+
... "@my_stage/input_data/"
|
|
639
|
+
... ).select("id", "feature_1", "feature_2")
|
|
640
|
+
>>>
|
|
641
|
+
>>> # Configure output location
|
|
642
|
+
>>> output_spec = OutputSpec(
|
|
643
|
+
... stage_location='@My_DB.PUBLIC.MY_STAGE/someth/path/',
|
|
644
|
+
... mode=SaveMode.OVERWRITE
|
|
645
|
+
... )
|
|
646
|
+
>>>
|
|
647
|
+
>>> # Configure job parameters
|
|
648
|
+
>>> job_spec = JobSpec(
|
|
649
|
+
... job_name="my_batch_inference",
|
|
650
|
+
... num_workers=4,
|
|
651
|
+
... cpu_requests="2",
|
|
652
|
+
... memory_requests="8Gi"
|
|
653
|
+
... )
|
|
654
|
+
>>>
|
|
655
|
+
>>> # Run batch inference
|
|
656
|
+
>>> job = model_version.run_batch(
|
|
657
|
+
... compute_pool="my_compute_pool",
|
|
658
|
+
... input_spec=input_df,
|
|
659
|
+
... output_spec=output_spec,
|
|
660
|
+
... job_spec=job_spec
|
|
661
|
+
... )
|
|
662
|
+
|
|
663
|
+
Note:
|
|
664
|
+
This method is currently in private preview and requires Snowflake version 1.18.0 or later.
|
|
665
|
+
The input data is temporarily stored in the output stage location under /_temporary before
|
|
666
|
+
inference execution.
|
|
667
|
+
"""
|
|
566
668
|
statement_params = telemetry.get_statement_params(
|
|
567
669
|
project=_TELEMETRY_PROJECT,
|
|
568
670
|
subproject=_TELEMETRY_SUBPROJECT,
|
|
@@ -789,6 +891,51 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
789
891
|
version_name=sql_identifier.SqlIdentifier(version),
|
|
790
892
|
)
|
|
791
893
|
|
|
894
|
+
def _can_run_on_gpu(
    self,
    statement_params: Optional[dict[str, Any]] = None,
) -> bool:
    """Report whether this model version was logged with a GPU runtime.

    Args:
        statement_params: Optional dictionary of statement parameters to include
            in the SQL command to fetch model spec.

    Returns:
        True when the model spec's ``runtimes`` section contains a ``"gpu"`` entry;
        False otherwise, including when the spec has no ``runtimes`` section at all.
    """
    spec_runtimes = self._get_model_spec(statement_params).get("runtimes", {})
    has_gpu_runtime = "gpu" in spec_runtimes
    return has_gpu_runtime
|
|
913
|
+
|
|
914
|
+
def _throw_error_if_gpu_is_not_supported(
    self,
    gpu_requests: Optional[Union[str, int]] = None,
    statement_params: Optional[dict[str, Any]] = None,
) -> None:
    """Reject a GPU request for a model version that has no GPU runtime.

    Nothing is checked (and the model spec is not consulted) when ``gpu_requests`` is
    None. Otherwise ``self._can_run_on_gpu`` is queried and a ValueError is raised if it
    reports no GPU runtime.

    Args:
        gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or
            string values. Use CPU if None.
        statement_params: Optional dictionary of statement parameters to include
            in the SQL command to fetch model spec.

    Raises:
        ValueError: If GPU resources are requested but the model does not have GPU
            runtime support.
    """
    if gpu_requests is None:
        # CPU-only request: there is nothing to validate.
        return
    if self._can_run_on_gpu(statement_params):
        return

    error_message = (
        f"GPU resources requested (gpu_requests={gpu_requests}), but the model "
        f"{self.fully_qualified_model_name} version {self.version_name} does not have GPU runtime support. "
        "Please ensure the model was logged with GPU runtime configuration or do not provide gpu_requests. "
        "To log the model with GPU runtime configuration, provide `cuda_version` in the `options` while calling"
        " the `log_model` function."
    )
    raise ValueError(error_message)
|
|
938
|
+
|
|
792
939
|
def _check_huggingface_text_generation_model(
|
|
793
940
|
self,
|
|
794
941
|
statement_params: Optional[dict[str, Any]] = None,
|
|
@@ -803,13 +950,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
803
950
|
ValueError: If the model is not a HuggingFace text-generation model.
|
|
804
951
|
"""
|
|
805
952
|
# Fetch model spec
|
|
806
|
-
model_spec = self.
|
|
807
|
-
database_name=None,
|
|
808
|
-
schema_name=None,
|
|
809
|
-
model_name=self._model_name,
|
|
810
|
-
version_name=self._version_name,
|
|
811
|
-
statement_params=statement_params,
|
|
812
|
-
)
|
|
953
|
+
model_spec = self._get_model_spec(statement_params)
|
|
813
954
|
|
|
814
955
|
# Check if model_type is huggingface_pipeline
|
|
815
956
|
model_type = model_spec.get("model_type")
|
|
@@ -894,9 +1035,10 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
894
1035
|
When it is ``False``, this function executes the underlying service creation asynchronously
|
|
895
1036
|
and returns an :class:`AsyncJob`.
|
|
896
1037
|
experimental_options: Experimental options for the service creation with custom inference engine.
|
|
897
|
-
Currently,
|
|
1038
|
+
Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
|
|
898
1039
|
`inference_engine` is the name of the inference engine to use.
|
|
899
1040
|
`inference_engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1041
|
+
`autocapture` is a boolean to enable/disable inference table.
|
|
900
1042
|
"""
|
|
901
1043
|
...
|
|
902
1044
|
|
|
@@ -952,9 +1094,10 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
952
1094
|
When it is ``False``, this function executes the underlying service creation asynchronously
|
|
953
1095
|
and returns an :class:`AsyncJob`.
|
|
954
1096
|
experimental_options: Experimental options for the service creation with custom inference engine.
|
|
955
|
-
Currently,
|
|
1097
|
+
Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
|
|
956
1098
|
`inference_engine` is the name of the inference engine to use.
|
|
957
1099
|
`inference_engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1100
|
+
`autocapture` is a boolean to enable/disable inference table.
|
|
958
1101
|
"""
|
|
959
1102
|
...
|
|
960
1103
|
|
|
@@ -1027,21 +1170,20 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1027
1170
|
When it is False, this function executes the underlying service creation asynchronously
|
|
1028
1171
|
and returns an AsyncJob.
|
|
1029
1172
|
experimental_options: Experimental options for the service creation with custom inference engine.
|
|
1030
|
-
Currently,
|
|
1173
|
+
Currently, `inference_engine`, `inference_engine_args_override`, and `autocapture` are supported.
|
|
1031
1174
|
`inference_engine` is the name of the inference engine to use.
|
|
1032
1175
|
`inference_engine_args_override` is a list of string arguments to pass to the inference engine.
|
|
1176
|
+
`autocapture` is a boolean to enable/disable inference table.
|
|
1033
1177
|
|
|
1034
1178
|
|
|
1035
1179
|
Raises:
|
|
1036
|
-
ValueError: Illegal external access integration arguments
|
|
1180
|
+
ValueError: Illegal external access integration arguments, or if GPU resources are requested
|
|
1181
|
+
but the model does not have GPU runtime support.
|
|
1037
1182
|
exceptions.SnowparkSQLException: if service already exists.
|
|
1038
1183
|
|
|
1039
1184
|
Returns:
|
|
1040
1185
|
If `block=True`, return result information about service creation from server.
|
|
1041
1186
|
Otherwise, return the service creation AsyncJob.
|
|
1042
|
-
|
|
1043
|
-
Raises:
|
|
1044
|
-
ValueError: Illegal external access integration arguments.
|
|
1045
1187
|
"""
|
|
1046
1188
|
statement_params = telemetry.get_statement_params(
|
|
1047
1189
|
project=_TELEMETRY_PROJECT,
|
|
@@ -1064,12 +1206,16 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1064
1206
|
|
|
1065
1207
|
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
|
1066
1208
|
|
|
1067
|
-
#
|
|
1068
|
-
|
|
1069
|
-
self._check_huggingface_text_generation_model(statement_params)
|
|
1209
|
+
# Validate GPU support if GPU resources are requested
|
|
1210
|
+
self._throw_error_if_gpu_is_not_supported(gpu_requests, statement_params)
|
|
1070
1211
|
|
|
1071
1212
|
inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)
|
|
1072
1213
|
|
|
1214
|
+
# Check if model is HuggingFace text-generation before doing inference engine checks
|
|
1215
|
+
# Only validate if inference engine is actually specified
|
|
1216
|
+
if inference_engine_args is not None:
|
|
1217
|
+
self._check_huggingface_text_generation_model(statement_params)
|
|
1218
|
+
|
|
1073
1219
|
# Enrich inference engine args if inference engine is specified
|
|
1074
1220
|
if inference_engine_args is not None:
|
|
1075
1221
|
inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
|
|
@@ -1077,6 +1223,9 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1077
1223
|
gpu_requests,
|
|
1078
1224
|
)
|
|
1079
1225
|
|
|
1226
|
+
# Extract autocapture from experimental_options
|
|
1227
|
+
autocapture = experimental_options.get("autocapture") if experimental_options else None
|
|
1228
|
+
|
|
1080
1229
|
from snowflake.ml.model import event_handler
|
|
1081
1230
|
from snowflake.snowpark import exceptions
|
|
1082
1231
|
|
|
@@ -1116,6 +1265,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1116
1265
|
statement_params=statement_params,
|
|
1117
1266
|
progress_status=status,
|
|
1118
1267
|
inference_engine_args=inference_engine_args,
|
|
1268
|
+
autocapture=autocapture,
|
|
1119
1269
|
)
|
|
1120
1270
|
status.update(label="Model service created successfully", state="complete", expanded=False)
|
|
1121
1271
|
return result
|
|
@@ -952,6 +952,7 @@ class ModelOperator:
|
|
|
952
952
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
|
953
953
|
statement_params: Optional[dict[str, str]] = None,
|
|
954
954
|
is_partitioned: Optional[bool] = None,
|
|
955
|
+
explain_case_sensitive: bool = False,
|
|
955
956
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
956
957
|
...
|
|
957
958
|
|
|
@@ -967,6 +968,7 @@ class ModelOperator:
|
|
|
967
968
|
service_name: sql_identifier.SqlIdentifier,
|
|
968
969
|
strict_input_validation: bool = False,
|
|
969
970
|
statement_params: Optional[dict[str, str]] = None,
|
|
971
|
+
explain_case_sensitive: bool = False,
|
|
970
972
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
971
973
|
...
|
|
972
974
|
|
|
@@ -986,6 +988,7 @@ class ModelOperator:
|
|
|
986
988
|
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
|
987
989
|
statement_params: Optional[dict[str, str]] = None,
|
|
988
990
|
is_partitioned: Optional[bool] = None,
|
|
991
|
+
explain_case_sensitive: bool = False,
|
|
989
992
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
|
990
993
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
|
991
994
|
|
|
@@ -1068,6 +1071,7 @@ class ModelOperator:
|
|
|
1068
1071
|
version_name=version_name,
|
|
1069
1072
|
statement_params=statement_params,
|
|
1070
1073
|
is_partitioned=is_partitioned or False,
|
|
1074
|
+
explain_case_sensitive=explain_case_sensitive,
|
|
1071
1075
|
)
|
|
1072
1076
|
|
|
1073
1077
|
if keep_order:
|
|
@@ -206,6 +206,8 @@ class ServiceOperator:
|
|
|
206
206
|
hf_model_args: Optional[HFModelArgs] = None,
|
|
207
207
|
# inference engine model
|
|
208
208
|
inference_engine_args: Optional[InferenceEngineArgs] = None,
|
|
209
|
+
# inference table
|
|
210
|
+
autocapture: Optional[bool] = None,
|
|
209
211
|
) -> Union[str, async_job.AsyncJob]:
|
|
210
212
|
|
|
211
213
|
# Generate operation ID for this deployment
|
|
@@ -261,6 +263,7 @@ class ServiceOperator:
|
|
|
261
263
|
gpu=gpu_requests,
|
|
262
264
|
num_workers=num_workers,
|
|
263
265
|
max_batch_rows=max_batch_rows,
|
|
266
|
+
autocapture=autocapture,
|
|
264
267
|
)
|
|
265
268
|
if hf_model_args:
|
|
266
269
|
# hf model
|
|
@@ -146,6 +146,7 @@ class ModelDeploymentSpec:
|
|
|
146
146
|
gpu: Optional[Union[str, int]] = None,
|
|
147
147
|
num_workers: Optional[int] = None,
|
|
148
148
|
max_batch_rows: Optional[int] = None,
|
|
149
|
+
autocapture: Optional[bool] = None,
|
|
149
150
|
) -> "ModelDeploymentSpec":
|
|
150
151
|
"""Add service specification to the deployment spec.
|
|
151
152
|
|
|
@@ -161,6 +162,7 @@ class ModelDeploymentSpec:
|
|
|
161
162
|
gpu: GPU requirement.
|
|
162
163
|
num_workers: Number of workers.
|
|
163
164
|
max_batch_rows: Maximum batch rows for inference.
|
|
165
|
+
autocapture: Whether to enable inference table.
|
|
164
166
|
|
|
165
167
|
Raises:
|
|
166
168
|
ValueError: If a job spec already exists.
|
|
@@ -186,6 +188,7 @@ class ModelDeploymentSpec:
|
|
|
186
188
|
compute_pool=inference_compute_pool_name.identifier(),
|
|
187
189
|
ingress_enabled=ingress_enabled,
|
|
188
190
|
max_instances=max_instances,
|
|
191
|
+
autocapture=autocapture,
|
|
189
192
|
**self._inference_spec,
|
|
190
193
|
)
|
|
191
194
|
return self
|
|
@@ -438,6 +438,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
438
438
|
partition_column: Optional[sql_identifier.SqlIdentifier],
|
|
439
439
|
statement_params: Optional[dict[str, Any]] = None,
|
|
440
440
|
is_partitioned: bool = True,
|
|
441
|
+
explain_case_sensitive: bool = False,
|
|
441
442
|
) -> dataframe.DataFrame:
|
|
442
443
|
with_statements = []
|
|
443
444
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
|
@@ -505,7 +506,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
|
505
506
|
cols_to_drop = []
|
|
506
507
|
|
|
507
508
|
for output_name, output_type, output_col_name in returns:
|
|
508
|
-
|
|
509
|
+
case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive
|
|
510
|
+
output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier()
|
|
509
511
|
if output_identifier != output_col_name:
|
|
510
512
|
cols_to_drop.append(output_identifier)
|
|
511
513
|
output_cols.append(F.col(output_identifier).astype(output_type))
|
|
@@ -87,7 +87,9 @@ class ModelManifest:
|
|
|
87
87
|
model_meta_schema.FunctionProperties.PARTITIONED.value, False
|
|
88
88
|
),
|
|
89
89
|
wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
|
|
90
|
-
options=model_method.get_model_method_options_from_options(
|
|
90
|
+
options=model_method.get_model_method_options_from_options(
|
|
91
|
+
options, target_method, model_meta.model_type
|
|
92
|
+
),
|
|
91
93
|
)
|
|
92
94
|
|
|
93
95
|
self.methods.append(method)
|
|
@@ -11,6 +11,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_sch
|
|
|
11
11
|
from snowflake.ml.model._model_composer.model_method import (
|
|
12
12
|
constants,
|
|
13
13
|
function_generator,
|
|
14
|
+
utils,
|
|
14
15
|
)
|
|
15
16
|
from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
|
16
17
|
from snowflake.ml.model.volatility import Volatility
|
|
@@ -31,12 +32,19 @@ class ModelMethodOptions(TypedDict):
|
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
def get_model_method_options_from_options(
|
|
34
|
-
options: type_hints.ModelSaveOption, target_method: str
|
|
35
|
+
options: type_hints.ModelSaveOption, target_method: str, model_type: Optional[str] = None
|
|
35
36
|
) -> ModelMethodOptions:
|
|
36
37
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
|
38
|
+
method_option = options.get("method_options", {}).get(target_method, {})
|
|
39
|
+
case_sensitive = method_option.get("case_sensitive", False)
|
|
37
40
|
if target_method == "explain":
|
|
38
41
|
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
|
39
|
-
|
|
42
|
+
case_sensitive = utils.determine_explain_case_sensitive_from_method_options(
|
|
43
|
+
options.get("method_options", {}), target_method
|
|
44
|
+
)
|
|
45
|
+
elif model_type == "prophet":
|
|
46
|
+
# Prophet models always require TABLE_FUNCTION because they need entire time series context
|
|
47
|
+
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
|
40
48
|
global_function_type = options.get("function_type", default_function_type)
|
|
41
49
|
function_type = method_option.get("function_type", global_function_type)
|
|
42
50
|
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
|
@@ -48,7 +56,7 @@ def get_model_method_options_from_options(
|
|
|
48
56
|
|
|
49
57
|
# Only include volatility if explicitly provided in method options
|
|
50
58
|
result: ModelMethodOptions = ModelMethodOptions(
|
|
51
|
-
case_sensitive=
|
|
59
|
+
case_sensitive=case_sensitive,
|
|
52
60
|
function_type=function_type,
|
|
53
61
|
)
|
|
54
62
|
if resolved_volatility:
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Mapping, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def determine_explain_case_sensitive_from_method_options(
    method_options: Mapping[str, Optional[Mapping[str, Any]]],
    target_method: str,
) -> bool:
    """Infer an explain method's case sensitivity from the options of its predict methods.

    Only targets whose name contains ``"explain"`` can be case sensitive. For those, the
    first of ``predict_proba``, ``predict`` and ``predict_log_proba`` (in that order) that
    has non-None options decides the result via its ``"case_sensitive"`` entry.

    Args:
        method_options: Mapping from method name to its options. Each option may
            contain ``"case_sensitive"`` to indicate SQL identifier sensitivity.
        target_method: The target method name being resolved (e.g., an ``explain_*``
            method).

    Returns:
        True if the explain method should be treated as case sensitive; otherwise False
        (also False for non-explain targets or when no predict method has options).
    """
    if "explain" not in target_method:
        return False

    source_methods = ("predict_proba", "predict", "predict_log_proba")
    # Lazily look up each candidate in priority order and stop at the first one with options.
    first_options = next(
        (opts for opts in map(method_options.get, source_methods) if opts is not None),
        None,
    )
    if first_options is None:
        return False
    return bool(first_options.get("case_sensitive", False))
|
|
@@ -240,14 +240,31 @@ class ModelEnv:
|
|
|
240
240
|
self._conda_dependencies[channel].remove(spec)
|
|
241
241
|
|
|
242
242
|
def generate_env_for_cuda(self) -> None:
|
|
243
|
+
|
|
244
|
+
# Insert py-xgboost-gpu only for XGBoost versions < 3.0.0
|
|
243
245
|
xgboost_spec = env_utils.find_dep_spec(
|
|
244
|
-
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=
|
|
246
|
+
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False
|
|
245
247
|
)
|
|
246
248
|
if xgboost_spec:
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
249
|
+
# Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3.
|
|
250
|
+
pinned_major: Optional[int] = None
|
|
251
|
+
for spec in xgboost_spec.specifier:
|
|
252
|
+
if spec.operator in ("==", "===", ">", ">="):
|
|
253
|
+
try:
|
|
254
|
+
pinned_major = version.parse(spec.version).major
|
|
255
|
+
except version.InvalidVersion:
|
|
256
|
+
pinned_major = None
|
|
257
|
+
break
|
|
258
|
+
|
|
259
|
+
if pinned_major is not None and pinned_major < 3:
|
|
260
|
+
xgboost_spec = env_utils.find_dep_spec(
|
|
261
|
+
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
|
|
262
|
+
)
|
|
263
|
+
if xgboost_spec:
|
|
264
|
+
self.include_if_absent(
|
|
265
|
+
[ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
|
|
266
|
+
check_local_version=False,
|
|
267
|
+
)
|
|
251
268
|
|
|
252
269
|
tf_spec = env_utils.find_dep_spec(
|
|
253
270
|
self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
|
|
@@ -305,3 +305,73 @@ def get_default_cuda_version() -> str:
|
|
|
305
305
|
|
|
306
306
|
return torch.version.cuda or model_env.DEFAULT_CUDA_VERSION
|
|
307
307
|
return model_env.DEFAULT_CUDA_VERSION
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def normalize_column_name(column_name: str) -> str:
    """Normalize a column name to be a valid unquoted Snowflake SQL identifier.

    Converts column names with spaces and special characters (e.g., "Christmas Day")
    into valid lowercase unquoted SQL identifiers (e.g., "christmas_day") that can be used
    without quotes in SQL queries. This follows Snowflake's unquoted identifier rules:
    https://docs.snowflake.com/en/sql-reference/identifiers-syntax

    The normalization approach is preferred over quoted identifiers because:
    - Unquoted identifiers are simpler and more readable
    - They don't require special handling in SQL contexts
    - They avoid case-sensitivity complications
    - Lowercase convention improves readability and follows Python/pandas conventions

    This utility is useful for model handlers that need to ensure output column names
    from models (e.g., Prophet holiday columns, feature names) are SQL-safe.

    Args:
        column_name: Original column name (may contain spaces, special chars, etc.)

    Returns:
        Normalized lowercase column name that is a valid unquoted SQL identifier matching
        the pattern [a-z_][a-z0-9_$]*. Degenerate inputs map to fixed names: a value that
        normalizes to a single "_" is returned as "_", and a value that normalizes to
        nothing (e.g. "", "___", "!!") becomes "column".

    Examples:
        >>> normalize_column_name("Christmas Day")
        'christmas_day'
        >>> normalize_column_name("New Year's Day")
        'new_year_s_day'
        >>> normalize_column_name("__private")
        '__private'
        >>> normalize_column_name("2023_data")
        '_2023_data'
        >>> normalize_column_name("$price")
        '_$price'
    """
    import re

    # Lowercase for readability, then replace every character that is not allowed in an
    # unquoted identifier (anything outside a-z, 0-9, "_" and "$") with an underscore.
    normalized = re.sub(r"[^a-z0-9_$]", "_", column_name.lower())

    # Collapse runs of underscores (e.g. "a  b" -> "a__b" -> "a_b") while preserving any
    # leading underscores, so "__name" stays "__name".
    if len(normalized) > 1:
        leading_underscores = len(normalized) - len(normalized.lstrip("_"))
        rest_collapsed = re.sub(r"_+", "_", normalized[leading_underscores:])
        normalized = "_" * leading_underscores + rest_collapsed

    if normalized == "_":
        return "_"

    normalized = normalized.rstrip("_")

    # Unquoted identifiers must start with a letter or underscore. Digits and "$" are only
    # valid after the first character, so prefix an underscore when the name starts with one.
    if normalized and not re.match(r"[a-z_]", normalized):
        normalized = "_" + normalized

    # Nothing usable remained (e.g. the input was empty or consisted only of underscores
    # and/or special characters), so fall back to a fixed name.
    if not normalized:
        normalized = "column"

    return normalized
|