snowflake-ml-python 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/_client/model/model_version_impl.py +90 -76
- snowflake/ml/model/_client/ops/model_ops.py +2 -18
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -18
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +4 -25
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
- snowflake/ml/model/_signatures/utils.py +55 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +67 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +31 -29
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/param_utils.py
@@ -0,0 +1,124 @@
+"""Utility functions for model parameter validation and resolution."""
+
+from typing import Any, Optional, Sequence
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._signatures import core
+
+
+def validate_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> None:
+    """Validate user-provided params against signature params.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Raises:
+        SnowflakeMLException: If params are provided but signature has no params,
+            or if unknown params are provided, or if param types are invalid,
+            or if duplicate params are provided with different cases.
+    """
+    # Params provided but signature has no params defined
+    if params and not signature_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Parameters were provided ({sorted(params.keys())}), "
+                "but this method does not accept any parameters."
+            ),
+        )
+
+    if not signature_params or not params:
+        return
+
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+    normalized_names = [name.upper() for name in params]
+    if len(normalized_names) != len(set(normalized_names)):
+        # Find the duplicate params to raise an error
+        param_seen: dict[str, list[str]] = {}
+        for param_name in params:
+            param_seen.setdefault(param_name.upper(), []).append(param_name)
+        duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+                "Parameter names are case-insensitive."
+            ),
+        )
+
+    # Validate user-provided params exist (case-insensitive)
+    invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+    if invalid_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Unknown parameter(s): {sorted(invalid_params)}. "
+                f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+            ),
+        )
+
+    # Validate types for each provided param
+    for param_name, default_value in params.items():
+        param_spec = param_spec_lookup[param_name.upper()]
+        if isinstance(param_spec, core.ParamSpec):
+            core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+def resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Sequence[core.BaseParamSpec],
+) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+    """Resolve final method parameters by applying user-provided params over signature defaults.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation.
+    """
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Start with defaults from signature
+    final_params: dict[str, Any] = {}
+    for param_spec in signature_params:
+        if hasattr(param_spec, "default_value"):
+            final_params[param_spec.name] = param_spec.default_value
+
+    # Override with provided runtime parameters (using signature's original param names)
+    if params:
+        for param_name, override_value in params.items():
+            canonical_name = param_spec_lookup[param_name.upper()].name
+            final_params[canonical_name] = override_value
+
+    return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+def validate_and_resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+    """Validate user-provided params against signature params and return method parameters.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+    """
+    validate_params(params, signature_params)
+
+    if not signature_params:
+        return None
+
+    return resolve_params(params, signature_params)
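For reference, the case-insensitive resolution the new module implements reduces to the stand-alone sketch below. The `ParamSpec` dataclass here is a minimal stand-in for `core.ParamSpec`, kept small so the example runs on its own:

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence

@dataclass
class ParamSpec:  # stand-in for snowflake.ml.model._signatures.core.ParamSpec
    name: str
    default_value: Any

def resolve(params: Optional[dict[str, Any]], signature_params: Sequence[ParamSpec]) -> list[tuple[str, Any]]:
    # Case-insensitive matching, as in param_utils: user-supplied names are
    # upper-cased for lookup and mapped back to the signature's canonical names.
    lookup = {ps.name.upper(): ps for ps in signature_params}
    final = {ps.name: ps.default_value for ps in signature_params}
    for name, value in (params or {}).items():
        final[lookup[name.upper()].name] = value
    return list(final.items())

specs = [ParamSpec("temperature", 0.7), ParamSpec("max_tokens", 128)]
print(resolve({"TEMPERATURE": 0.1}, specs))
# [('temperature', 0.1), ('max_tokens', 128)]
```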
snowflake/ml/model/_client/ops/service_ops.py
@@ -1,4 +1,6 @@
+import base64
 import dataclasses
+import json
 import logging
 import pathlib
 import re
@@ -6,7 +8,9 @@ import tempfile
 import threading
 import time
 import warnings
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Sequence, Union, cast
+
+from pydantic import TypeAdapter

 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -14,9 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
 from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
-from snowflake.ml.model._client.ops import deployment_step
+from snowflake.ml.model._client.ops import deployment_step, param_utils
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils

@@ -582,15 +587,10 @@ class ServiceOperator:
                 )
                 for status in statuses:
                     if status.instance_id is not None:
-                        instance_status, container_status = None, None
-                        if status.instance_status is not None:
-                            instance_status = status.instance_status.value
-                        if status.container_status is not None:
-                            container_status = status.container_status.value
                         module_logger.info(
                             f"Instance[{status.instance_id}]: "
-                            f"instance status: {instance_status}, "
-                            f"container status: {container_status}, "
+                            f"instance status: {status.instance_status}, "
+                            f"container status: {status.container_status}, "
                             f"message: {status.message}"
                         )
                 time.sleep(5)
@@ -930,6 +930,38 @@ class ServiceOperator:
         except exceptions.SnowparkSQLException:
             return False

+    @staticmethod
+    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+        """Encode params dictionary to a base64 string.
+
+        Args:
+            params: Optional dictionary of model inference parameters.
+
+        Returns:
+            Base64 encoded JSON string of the params, or None if input is None.
+        """
+        if params is None:
+            return None
+        return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+    @staticmethod
+    def _encode_column_handling(
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+    ) -> Optional[str]:
+        """Validate and encode column_handling to a base64 string.
+
+        Args:
+            column_handling: Optional dictionary mapping column names to file encoding options.
+
+        Returns:
+            Base64 encoded JSON string of the column handling options, or None if input is None.
+        """
+        if column_handling is None:
+            return None
+        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+        validated_input = adapter.validate_python(column_handling)
+        return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+
     def invoke_batch_job_method(
         self,
         *,
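Both helpers are straight JSON-then-base64 round trips; `_encode_column_handling` additionally validates through pydantic's `TypeAdapter` before serializing. A sketch of the mechanics (the `ColumnHandlingOptions` values below are invented stand-ins, not the library's real ones):

```python
import base64
import json
from typing import Literal

from pydantic import TypeAdapter

# Plain params: JSON-encode, then base64-encode.
params = {"temperature": 0.2, "max_tokens": 256}
encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
assert json.loads(base64.b64decode(encoded)) == params

# Column handling: validate the mapping first, then serialize the validated form.
ColumnHandlingOptions = Literal["BASE64", "PASSTHROUGH"]  # hypothetical values
adapter = TypeAdapter(dict[str, ColumnHandlingOptions])
validated = adapter.validate_python({"image": "BASE64"})
encoded_cols = base64.b64encode(adapter.dump_json(validated)).decode("utf-8")
```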
@@ -942,8 +974,9 @@ class ServiceOperator:
         image_repo_name: Optional[str],
         input_stage_location: str,
         input_file_pattern: str,
-        column_handling: Optional[str],
-        params: Optional[str],
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+        params: Optional[dict[str, Any]],
+        signature_params: Optional[Sequence[core.BaseParamSpec]],
         output_stage_location: str,
         completion_filename: str,
         force_rebuild: bool,
@@ -954,7 +987,13 @@ class ServiceOperator:
         gpu_requests: Optional[str],
         replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> job.MLJob[Any]:
+        # Validate and encode params
+        param_utils.validate_params(params, signature_params)
+        params_encoded = self._encode_params(params)
+        column_handling_encoded = self._encode_column_handling(column_handling)
+
         database_name = self._database_name
         schema_name = self._schema_name

@@ -980,8 +1019,8 @@ class ServiceOperator:
             max_batch_rows=max_batch_rows,
             input_stage_location=input_stage_location,
             input_file_pattern=input_file_pattern,
-            column_handling=column_handling,
-            params=params,
+            column_handling=column_handling_encoded,
+            params=params_encoded,
             output_stage_location=output_stage_location,
             completion_filename=completion_filename,
             function_name=function_name,
@@ -992,11 +1031,17 @@ class ServiceOperator:
             replicas=replicas,
         )

-        self._model_deployment_spec.add_image_build_spec(
-            image_build_compute_pool_name=compute_pool_name,
-            fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
-            force_rebuild=force_rebuild,
-        )
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+        else:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=compute_pool_name,
+                fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+                force_rebuild=force_rebuild,
+            )

         spec_yaml_str_or_path = self._model_deployment_spec.save()

snowflake/ml/model/_client/service/model_deployment_spec.py
@@ -363,7 +363,7 @@ class ModelDeploymentSpec:
         inference_engine: inference_engine_module.InferenceEngine,
         inference_engine_args: Optional[list[str]] = None,
     ) -> "ModelDeploymentSpec":
-        """Add inference engine specification. This must be called after self.add_service_spec().
+        """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().

         Args:
             inference_engine: Inference engine.
@@ -376,9 +376,10 @@ class ModelDeploymentSpec:
             ValueError: If inference engine specification is called before add_service_spec().
             ValueError: If the argument does not have a '--' prefix.
         """
-        if self._service is None:
-            raise ValueError("Inference engine specification must be called after add_service_spec().")
-
+        if self._service is None and self._job is None:
+            raise ValueError(
+                "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+            )

         if inference_engine_args is None:
             inference_engine_args = []
@@ -431,11 +432,17 @@ class ModelDeploymentSpec:

         inference_engine_args = filtered_args

-        self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+        inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
             # convert to string to be saved in the deployment spec
             inference_engine_name=inference_engine.value,
             inference_engine_args=inference_engine_args,
         )
+
+        if self._service:
+            self._service.inference_engine_spec = inference_engine_spec
+        elif self._job:
+            self._job.inference_engine_spec = inference_engine_spec
+
         return self

     def save(self) -> str:
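The new guard is the usual builder-ordering check: the dependent step refuses to run until one of its prerequisite specs has been recorded, then attaches its result to whichever spec exists. A reduced, hypothetical sketch of the same shape (dicts stand in for the real spec models):

```python
class DeploymentSpecSketch:
    def __init__(self) -> None:
        self._service = None
        self._job = None

    def add_job_spec(self, **kwargs):
        self._job = dict(kwargs)  # stand-in for the real job spec model
        return self

    def add_inference_engine_spec(self, inference_engine: str, inference_engine_args=None):
        if self._service is None and self._job is None:
            raise ValueError("Must call add_service_spec() or add_job_spec() first.")
        spec = {"name": inference_engine, "args": inference_engine_args or []}
        # Attach to whichever prerequisite spec was configured.
        (self._service if self._service is not None else self._job)["inference_engine_spec"] = spec
        return self

DeploymentSpecSketch().add_job_spec(name="batch-job").add_inference_engine_spec("vllm")
```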
snowflake/ml/model/_client/sql/service.py
@@ -47,22 +47,6 @@ class ServiceStatus(enum.Enum):
     INTERNAL_ERROR = "INTERNAL_ERROR"


-class InstanceStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    FAILED = "FAILED"
-    TERMINATING = "TERMINATING"
-    SUCCEEDED = "SUCCEEDED"
-
-
-class ContainerStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    DONE = "DONE"
-    FAILED = "FAILED"
-    UNKNOWN = "UNKNOWN"
-
-
 @dataclasses.dataclass
 class ServiceStatusInfo:
     """
@@ -72,8 +56,8 @@ class ServiceStatusInfo:

     service_status: ServiceStatus
     instance_id: Optional[int] = None
-    instance_status: Optional[InstanceStatus] = None
-    container_status: Optional[ContainerStatus] = None
+    instance_status: Optional[str] = None
+    container_status: Optional[str] = None
     message: Optional[str] = None


@@ -272,17 +256,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         statuses = []
         for r in rows:
-            instance_status, container_status = None, None
-            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
-                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
-            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
-                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
             statuses.append(
                 ServiceStatusInfo(
                     service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
                     instance_id=r[ServiceSQLClient.INSTANCE_ID],
-                    instance_status=instance_status,
-                    container_status=container_status,
+                    instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
+                    container_status=r[ServiceSQLClient.CONTAINER_STATUS],
                     message=r[ServiceSQLClient.MESSAGE] if include_message else None,
                 )
             )
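Dropping the `InstanceStatus`/`ContainerStatus` enums means server-side status strings now pass through verbatim, so a value the client has never seen no longer raises `ValueError` inside `InstanceStatus(...)`/`ContainerStatus(...)`. A stand-in showing the effect (`service_status` simplified to a plain string here, and `"EVICTED"` is an invented backend value):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ServiceStatusInfoSketch:  # mirrors the simplified ServiceStatusInfo above
    service_status: str
    instance_id: Optional[int] = None
    instance_status: Optional[str] = None
    container_status: Optional[str] = None
    message: Optional[str] = None

# Before: ContainerStatus("EVICTED") -> ValueError. Now the raw string is kept and logged.
info = ServiceStatusInfoSketch("RUNNING", instance_id=0, container_status="EVICTED")
print(info.container_status)  # EVICTED
```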
snowflake/ml/model/_model_composer/model_method/infer_function.py_template
@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual function
 @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
 def {function_name}(df: pd.DataFrame) -> dict:
-    df.columns = input_cols
-    input_df = df.astype(dtype=dtype_map)
-    predictions_df = runner(input_df)
+    df.columns = input_cols + param_cols
+    input_df = df[input_cols].astype(dtype=dtype_map)
+
+    # Extract runtime param values, using defaults if None
+    method_params = {{}}
+    for col in param_cols:
+        val = df[col].iloc[0]
+        if val is None or pd.isna(val):
+            method_params[col] = param_defaults[col]
+        else:
+            method_params[col] = val
+
+    predictions_df = runner(input_df, **method_params)
     return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
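The same pattern recurs in the partitioned and table-function templates below: parameter columns ride after the input columns, and each batch reads the parameter from its first row, falling back to the signature default when the value is null. A pandas-only sketch of that extraction, with made-up column names:

```python
import pandas as pd

input_cols = ["prompt"]
param_cols = ["temperature"]
param_defaults = {"temperature": 0.7}

# A batch as the generated handler receives it: inputs first, then param columns.
df = pd.DataFrame([["hello", None], ["world", None]], columns=input_cols + param_cols)

method_params = {}
for col in param_cols:
    val = df[col].iloc[0]  # params are constant per batch, so the first row suffices
    method_params[col] = param_defaults[col] if val is None or pd.isna(val) else val

print(method_params)  # {'temperature': 0.7} -- default applied, the column was null
```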
snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template
@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
     def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-        return runner(input_df)
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template
@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
     def process(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-        return runner(input_df)
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
snowflake/ml/model/_model_composer/model_method/model_method.py
@@ -156,10 +156,11 @@ class ModelMethod:
                     f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
                     "Try specifying `case_sensitive` as True."
                 ) from e
+            default_value = param_spec.default_value if param_spec.default_value is None else str(param_spec.default_value)
            return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
                 name=param_name.resolved(),
                 type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
-                default=param_spec.default_value,
+                default=default_value,
             )

     def save(
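The extra line normalizes defaults before they go into the manifest, presumably because the manifest field holds string defaults: non-null values are stringified while `None` (meaning "no default") must survive unchanged. In isolation:

```python
def manifest_default(default_value):
    # None means "no default" and is preserved; anything else becomes a string.
    return default_value if default_value is None else str(default_value)

assert manifest_default(0.7) == "0.7"
assert manifest_default(None) is None
```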
snowflake/ml/model/_packager/model_handlers/huggingface.py
@@ -1,3 +1,4 @@
+import io
 import json
 import logging
 import os
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
-from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model._signatures import (
+    core as model_signature_core,
+    utils as model_signature_utils,
+)
 from snowflake.ml.model.models import (
     huggingface as huggingface_base,
     huggingface_pipeline,
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
         # verify when the target method is __call__ and
         # if the signature is default text-generation signature
         # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
-        if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+        if (
+            signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
+            or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
+        ):
             wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)

             temp_res = X.apply(
@@ -554,6 +561,19 @@ class TransformersPipelineHandler(
             else:
                 input_data = X[signature.inputs[0].name].to_list()
                 temp_res = getattr(raw_model, target_method)(input_data)
+        elif isinstance(raw_model, transformers.ImageClassificationPipeline):
+            # Image classification expects PIL Images. Convert bytes to PIL Images.
+            from PIL import Image
+
+            input_col = signature.inputs[0].name
+            images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
+            temp_res = getattr(raw_model, target_method)(images)
+        elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
+            # ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
+            # not a list. Process each audio input individually.
+            input_col = signature.inputs[0].name
+            audio_inputs = X[input_col].to_list()
+            temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
         else:
             # TODO: remove conversational pipeline code
             # For others, we could offer the whole dataframe as a list.
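Both new branches adapt column values into the shapes those pipelines accept: image bytes become PIL Images, and audio inputs are fed one call at a time since the ASR pipeline takes a single input. A reduced sketch, assuming Pillow is installed:

```python
import io

from PIL import Image

def bytes_to_pil_images(blobs: list) -> list:
    # Decode raw PNG/JPEG bytes into PIL Images for an image-classification pipeline.
    return [Image.open(io.BytesIO(b)) for b in blobs]

def run_asr_one_by_one(asr_pipeline, audio_inputs: list) -> list:
    # The ASR pipeline accepts one audio input (bytes, path, or ndarray) per call.
    return [asr_pipeline(audio) for audio in audio_inputs]
```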
@@ -615,11 +635,14 @@ class TransformersPipelineHandler(
             temp_res = [[conv.generated_responses] for conv in temp_res]

         # To concat those who outputs a list with one input.
-        if
-
-
-
-
+        # if `signature.outputs` is single valued and is a FeatureGroupSpec,
+        # we create a DataFrame with one column and the values are stored as a dictionary.
+        # Otherwise, we create a DataFrame with the output as the column.
+        if len(signature.outputs) == 1 and isinstance(
+            signature.outputs[0], model_signature_core.FeatureGroupSpec
+        ):
+            # creating a dataframe with one column
+            res = pd.DataFrame({signature.outputs[0].name: temp_res})
         else:
             res = pd.DataFrame(temp_res)

@@ -702,7 +725,6 @@ class HuggingFaceOpenAICompatibleModel:
         self.pipeline = pipeline
         self.model = self.pipeline.model
         self.tokenizer = self.pipeline.tokenizer
-
         self.model_name = self.pipeline.model.name_or_path

         if self.tokenizer.pad_token is None:
@@ -724,11 +746,33 @@ class HuggingFaceOpenAICompatibleModel:
         Returns:
             The formatted prompt string ready for model input.
         """
+
+        final_messages = []
+        for message in messages:
+            if isinstance(message.get("content", ""), str):
+                final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
+            else:
+                # extract only the text from the content
+                # sample data:
+                # {
+                #     "role": "user",
+                #     "content": [
+                #         {"type": "text", "text": "Hello, how are you?"},  # extracted
+                #         {"type": "image", "image": "https://example.com/image.png"},  # not extracted
+                #     ],
+                # }
+                for content_part in message.get("content", []):
+                    if content_part.get("type", "") == "text":
+                        final_messages.append(
+                            {"role": message.get("role", "user"), "content": content_part.get("text", "")}
+                        )
+                # TODO: implement other content types
+
         # Use the tokenizer's apply_chat_template method.
         # We ensured a template exists in __init__.
         if hasattr(self.tokenizer, "apply_chat_template"):
             return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
-                messages,
+                final_messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
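The new loop normalizes OpenAI-style messages before templating: string contents pass through, while multimodal content lists are flattened to their text parts only. A stand-alone version of the same logic:

```python
def flatten_messages(messages: list) -> list:
    final_messages = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        if isinstance(content, str):
            final_messages.append({"role": role, "content": content})
        else:
            # Keep only the "text" parts of a multimodal content list.
            for part in content:
                if part.get("type", "") == "text":
                    final_messages.append({"role": role, "content": part.get("text", "")})
    return final_messages

msgs = [{"role": "user", "content": [
    {"type": "text", "text": "Hello"},
    {"type": "image", "image": "https://example.com/image.png"},
]}]
print(flatten_messages(msgs))  # [{'role': 'user', 'content': 'Hello'}]
```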
@@ -736,7 +780,7 @@ class HuggingFaceOpenAICompatibleModel:
         # Fallback for very old transformers without apply_chat_template
         # Manually apply ChatML-like formatting
         prompt = ""
-        for message in messages:
+        for message in final_messages:
             role = message.get("role", "user")
             content = message.get("content", "")
             prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"