snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +38 -18
- snowflake/ml/jobs/_utils/query_helper.py +8 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
- snowflake/ml/jobs/_utils/stage_utils.py +2 -2
- snowflake/ml/jobs/_utils/types.py +22 -2
- snowflake/ml/jobs/job_definition.py +232 -0
- snowflake/ml/jobs/manager.py +16 -177
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +120 -89
- snowflake/ml/model/_client/ops/model_ops.py +4 -26
- snowflake/ml/model/_client/ops/param_utils.py +124 -0
- snowflake/ml/model/_client/ops/service_ops.py +63 -23
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +25 -54
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
- snowflake/ml/model/_signatures/utils.py +130 -0
- snowflake/ml/model/openai_signatures.py +97 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
- snowflake/ml/experiment/callback/__init__.py +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,6 @@
+import base64
 import dataclasses
+import json
 import logging
 import pathlib
 import re
@@ -6,7 +8,9 @@ import tempfile
 import threading
 import time
 import warnings
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Sequence, Union, cast
+
+from pydantic import TypeAdapter
 
 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -14,9 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
 from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
-from snowflake.ml.model._client.ops import deployment_step
+from snowflake.ml.model._client.ops import deployment_step, param_utils
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -150,7 +155,6 @@ class ServiceOperator:
         self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
             workspace_path=pathlib.Path(self._workspace.name)
         )
-        self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()
 
     def __eq__(self, __value: object) -> bool:
         if not isinstance(__value, ServiceOperator):
@@ -211,10 +215,6 @@ class ServiceOperator:
         progress_status.update("preparing deployment artifacts...")
         progress_status.increment()
 
-        # If autocapture param is disabled, don't allow create service with autocapture
-        if not self._inference_autocapture_enabled and autocapture:
-            raise ValueError("Invalid Argument: Autocapture feature is not supported.")
-
         if self._workspace:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
@@ -582,15 +582,10 @@
             )
             for status in statuses:
                 if status.instance_id is not None:
-                    instance_status, container_status = None, None
-                    if status.instance_status is not None:
-                        instance_status = status.instance_status.value
-                    if status.container_status is not None:
-                        container_status = status.container_status.value
                     module_logger.info(
                         f"Instance[{status.instance_id}]: "
-                        f"instance status: {instance_status}, "
-                        f"container status: {container_status}, "
+                        f"instance status: {status.instance_status}, "
+                        f"container status: {status.container_status}, "
                         f"message: {status.message}"
                     )
             time.sleep(5)
@@ -930,6 +925,38 @@
         except exceptions.SnowparkSQLException:
             return False
 
+    @staticmethod
+    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+        """Encode params dictionary to a base64 string.
+
+        Args:
+            params: Optional dictionary of model inference parameters.
+
+        Returns:
+            Base64 encoded JSON string of the params, or None if input is None.
+        """
+        if params is None:
+            return None
+        return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+    @staticmethod
+    def _encode_column_handling(
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+    ) -> Optional[str]:
+        """Validate and encode column_handling to a base64 string.
+
+        Args:
+            column_handling: Optional dictionary mapping column names to file encoding options.
+
+        Returns:
+            Base64 encoded JSON string of the column handling options, or None if input is None.
+        """
+        if column_handling is None:
+            return None
+        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+        validated_input = adapter.validate_python(column_handling)
+        return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+
     def invoke_batch_job_method(
         self,
         *,
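The two `_encode_*` helpers added above (this hunk appears to belong to snowflake/ml/model/_client/ops/service_ops.py, whose import hunks add `base64`, `json`, and pydantic's `TypeAdapter`) serialize caller-supplied inference params and column-handling options before they go into the deployment spec. Below is a minimal sketch of the round trip, assuming the receiving side simply reverses the encoding; the parameter names are illustrative, not part of the package API.

# Sketch of the encoding performed by the new _encode_params helper, plus an assumed inverse.
import base64
import json
from typing import Any, Optional


def encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
    # Same shape as the helper in the hunk: JSON-serialize, then base64-encode as UTF-8 text.
    if params is None:
        return None
    return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")


def decode_params(encoded: Optional[str]) -> Optional[dict[str, Any]]:
    # Assumed inverse, shown for illustration only; the diff does not show the consumer.
    if encoded is None:
        return None
    return json.loads(base64.b64decode(encoded).decode("utf-8"))


encoded = encode_params({"temperature": 0.2, "max_tokens": 128})
assert decode_params(encoded) == {"temperature": 0.2, "max_tokens": 128}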
@@ -942,8 +969,9 @@
         image_repo_name: Optional[str],
         input_stage_location: str,
         input_file_pattern: str,
-        column_handling: Optional[str],
-        params: Optional[str],
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+        params: Optional[dict[str, Any]],
+        signature_params: Optional[Sequence[core.BaseParamSpec]],
         output_stage_location: str,
         completion_filename: str,
         force_rebuild: bool,
@@ -954,7 +982,13 @@
         gpu_requests: Optional[str],
         replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> job.MLJob[Any]:
+        # Validate and encode params
+        param_utils.validate_params(params, signature_params)
+        params_encoded = self._encode_params(params)
+        column_handling_encoded = self._encode_column_handling(column_handling)
+
         database_name = self._database_name
         schema_name = self._schema_name
 
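`param_utils.validate_params` lives in the new `_client/ops/param_utils.py` (+124 lines), whose body this diff does not show, so its exact behavior is unknown. A plausible, purely illustrative sketch of that kind of check follows, with a stand-in `ParamSpec` class in place of `core.BaseParamSpec`:

# Illustrative only: validate that every supplied inference param is declared in the signature.
from typing import Any, Optional, Sequence


class ParamSpec:
    # Stand-in for snowflake.ml.model._signatures.core.BaseParamSpec (assumed shape).
    def __init__(self, name: str) -> None:
        self.name = name


def validate_params(params: Optional[dict[str, Any]], specs: Optional[Sequence[ParamSpec]]) -> None:
    if not params:
        return
    declared = {spec.name for spec in (specs or [])}
    unknown = set(params) - declared
    if unknown:
        raise ValueError(f"Unknown inference parameters: {sorted(unknown)}")


validate_params({"temperature": 0.2}, [ParamSpec("temperature"), ParamSpec("max_tokens")])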
@@ -980,8 +1014,8 @@
             max_batch_rows=max_batch_rows,
             input_stage_location=input_stage_location,
             input_file_pattern=input_file_pattern,
-            column_handling=
-            params=
+            column_handling=column_handling_encoded,
+            params=params_encoded,
             output_stage_location=output_stage_location,
             completion_filename=completion_filename,
             function_name=function_name,
@@ -992,11 +1026,17 @@
             replicas=replicas,
         )
 
-
-
-
-
-
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+        else:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=compute_pool_name,
+                fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+                force_rebuild=force_rebuild,
+            )
 
         spec_yaml_str_or_path = self._model_deployment_spec.save()
 
@@ -363,7 +363,7 @@ class ModelDeploymentSpec:
         inference_engine: inference_engine_module.InferenceEngine,
         inference_engine_args: Optional[list[str]] = None,
     ) -> "ModelDeploymentSpec":
-        """Add inference engine specification. This must be called after self.add_service_spec().
+        """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().
 
         Args:
             inference_engine: Inference engine.
@@ -376,9 +376,10 @@ class ModelDeploymentSpec:
             ValueError: If inference engine specification is called before add_service_spec().
             ValueError: If the argument does not have a '--' prefix.
         """
-
-
-
+        if self._service is None and self._job is None:
+            raise ValueError(
+                "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+            )
 
         if inference_engine_args is None:
             inference_engine_args = []
@@ -431,11 +432,17 @@ class ModelDeploymentSpec:
 
         inference_engine_args = filtered_args
 
-
+        inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
             # convert to string to be saved in the deployment spec
             inference_engine_name=inference_engine.value,
             inference_engine_args=inference_engine_args,
         )
+
+        if self._service:
+            self._service.inference_engine_spec = inference_engine_spec
+        elif self._job:
+            self._job.inference_engine_spec = inference_engine_spec
+
         return self
 
     def save(self) -> str:
@@ -47,22 +47,6 @@ class ServiceStatus(enum.Enum):
     INTERNAL_ERROR = "INTERNAL_ERROR"
 
 
-class InstanceStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    FAILED = "FAILED"
-    TERMINATING = "TERMINATING"
-    SUCCEEDED = "SUCCEEDED"
-
-
-class ContainerStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    DONE = "DONE"
-    FAILED = "FAILED"
-    UNKNOWN = "UNKNOWN"
-
-
 @dataclasses.dataclass
 class ServiceStatusInfo:
     """
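With `InstanceStatus` and `ContainerStatus` removed, `ServiceStatusInfo` (changed in the next hunk) carries instance and container statuses as plain strings taken straight from the DESCRIBE SERVICE row, so callers compare against string literals instead of enum members. A small sketch of the resulting usage; the status literals and `RUNNING` member are illustrative:

# Sketch of the effect of dropping the InstanceStatus/ContainerStatus enums.
import dataclasses
import enum
from typing import Optional


class ServiceStatus(enum.Enum):
    RUNNING = "RUNNING"  # illustrative member; only INTERNAL_ERROR is visible in this hunk
    INTERNAL_ERROR = "INTERNAL_ERROR"


@dataclasses.dataclass
class ServiceStatusInfo:
    service_status: ServiceStatus
    instance_id: Optional[int] = None
    instance_status: Optional[str] = None
    container_status: Optional[str] = None
    message: Optional[str] = None


info = ServiceStatusInfo(ServiceStatus.RUNNING, instance_id=0, instance_status="READY", container_status="READY")
print(info.instance_status == "READY")  # plain string comparison instead of InstanceStatus.READY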
@@ -72,8 +56,8 @@ class ServiceStatusInfo:
 
     service_status: ServiceStatus
     instance_id: Optional[int] = None
-    instance_status: Optional[
-    container_status: Optional[
+    instance_status: Optional[str] = None
+    container_status: Optional[str] = None
     message: Optional[str] = None
 
 
@@ -91,10 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
     DESC_SERVICE_SPEC_COL_NAME = "spec"
     DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
     DESC_SERVICE_NAME_SPEC_NAME = "name"
-
-    PROXY_CONTAINER_NAME = "proxy"
+    DESC_SERVICE_ENV_SPEC_NAME = "env"
     MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
-    FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
 
     @contextlib.contextmanager
     def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -272,17 +254,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         statuses = []
         for r in rows:
-            instance_status, container_status = None, None
-            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
-                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
-            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
-                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
             statuses.append(
                 ServiceStatusInfo(
                     service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
                     instance_id=r[ServiceSQLClient.INSTANCE_ID],
-                    instance_status=
-                    container_status=
+                    instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
+                    container_status=r[ServiceSQLClient.CONTAINER_STATUS],
                     message=r[ServiceSQLClient.MESSAGE] if include_message else None,
                 )
             )
@@ -306,39 +283,33 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return rows[0]
 
-    def
-        """Extract whether service has autocapture enabled
+    def is_autocapture_enabled(self, row: row.Row) -> bool:
+        """Extract whether service has autocapture enabled in any container from service spec.
 
         Args:
             row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
 
         Returns:
-            True if autocapture is enabled in
-            False if disabled or not set in
-            False if service doesn't have proxy container
+            True if autocapture is enabled in any container.
+            False if autocapture is disabled or not set in any container.
         """
-
-
-        if spec_yaml is None:
-            return False
-        spec_raw = yaml.safe_load(spec_yaml)
-        if spec_raw is None:
-            return False
-        spec = cast(dict[str, Any], spec_raw)
-
-        proxy_container_spec = next(
-            container
-            for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
-                ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
-            ]
-            if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
-        )
-        env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
-        autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
-        return str(autocapture_enabled).lower() == "true"
-
-        except StopIteration:
+        spec_yaml = row.as_dict().get(ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME)
+        if spec_yaml is None:
             return False
+        spec_raw = yaml.safe_load(spec_yaml)
+        if spec_raw is None:
+            return False
+        spec = cast(dict[str, Any], spec_raw)
+
+        containers = spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
+            ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
+        ]
+        for container in containers:
+            env = container.get(ServiceSQLClient.DESC_SERVICE_ENV_SPEC_NAME, {})
+            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
+            if str(autocapture_enabled).lower() == "true":
+                return True
+        return False
 
     def drop_service(
         self,
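The rewritten `is_autocapture_enabled` no longer singles out a proxy container; it scans the env block of every container in the parsed service spec for the `SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED` flag. A self-contained sketch of that check against an illustrative spec (not a real DESCRIBE SERVICE output; the container names are made up):

# Sketch of the per-container env scan performed by the new is_autocapture_enabled.
import yaml

sample_spec_yaml = """
spec:
  containers:
    - name: model-inference
      env:
        SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED: "true"
    - name: proxy
      env: {}
"""

spec = yaml.safe_load(sample_spec_yaml)
enabled = any(
    str(container.get("env", {}).get("SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED", "false")).lower() == "true"
    for container in spec["spec"]["containers"]
)
print(enabled)  # True, because one container sets the flag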
@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+
 
 # Actual function
 @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
 def {function_name}(df: pd.DataFrame) -> dict:
-    df.columns = input_cols
-    input_df = df.astype(dtype=dtype_map)
-
+    df.columns = input_cols + param_cols
+    input_df = df[input_cols].astype(dtype=dtype_map)
+
+    # Extract runtime param values, using defaults if None
+    method_params = {{}}
+    for col in param_cols:
+        val = df[col].iloc[0]
+        if val is None or pd.isna(val):
+            method_params[col] = param_defaults[col]
+        else:
+            method_params[col] = val
+
+    predictions_df = runner(input_df, **method_params)
     return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
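The template change above (mirrored in the partitioned and table-function templates below) appends the signature's parameter columns after the input columns, then resolves each parameter from the first row of the batch, falling back to the signature default when the value is missing. A small sketch of that split, with illustrative column names and defaults:

# Sketch of how a batch splits into typed inputs plus per-call keyword params.
import pandas as pd

input_cols = ["TEXT"]
param_cols = ["temperature"]
param_defaults = {"temperature": 0.7}
dtype_map = {"TEXT": str}

df = pd.DataFrame([["hello", None], ["world", 0.2]], columns=input_cols + param_cols)

input_df = df[input_cols].astype(dtype=dtype_map)
method_params = {}
for col in param_cols:
    val = df[col].iloc[0]  # as in the template, only the first row's value is consulted
    method_params[col] = param_defaults[col] if val is None or pd.isna(val) else val

print(method_params)  # {'temperature': 0.7} for this batch, since the first row left it unset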
@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+
 
 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
     def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
 
+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+
 
 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
     def process(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
@@ -156,10 +156,12 @@ class ModelMethod:
                 f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
                 "Try specifying `case_sensitive` as True."
             ) from e
+        # Convert None to "NULL" string so MANIFEST parser can interpret it as SQL NULL
+        default_value = "NULL" if param_spec.default_value is None else str(param_spec.default_value)
         return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
             name=param_name.resolved(),
             type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
-            default=
+            default=default_value,
         )
 
     def save(
@@ -1,3 +1,4 @@
+import io
 import json
 import logging
 import os
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
-from snowflake.ml.model._signatures import
+from snowflake.ml.model._signatures import (
+    core as model_signature_core,
+    utils as model_signature_utils,
+)
 from snowflake.ml.model.models import (
     huggingface as huggingface_base,
     huggingface_pipeline,
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
             # verify when the target method is __call__ and
             # if the signature is default text-generation signature
             # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
-            if
+            if (
+                signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
+                or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
+            ):
                 wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)
 
                 temp_res = X.apply(
@@ -554,6 +561,39 @@ class TransformersPipelineHandler(
             else:
                 input_data = X[signature.inputs[0].name].to_list()
                 temp_res = getattr(raw_model, target_method)(input_data)
+        elif isinstance(raw_model, transformers.ImageClassificationPipeline):
+            # Image classification expects PIL Images. Convert bytes to PIL Images.
+            from PIL import Image
+
+            input_col = signature.inputs[0].name
+            images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
+            temp_res = getattr(raw_model, target_method)(images)
+        elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
+            # ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
+            # not a list. Process each audio input individually.
+            input_col = signature.inputs[0].name
+            audio_inputs = X[input_col].to_list()
+            temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
+        elif isinstance(raw_model, transformers.VideoClassificationPipeline):
+            # Video classification expects file paths. Write bytes to temp files,
+            # process them, and clean up.
+            import tempfile
+
+            input_col = signature.inputs[0].name
+            video_bytes_list = X[input_col].to_list()
+            temp_file_paths = []
+            temp_files = []
+            try:
+                # TODO: parallelize this if needed
+                for video_bytes in video_bytes_list:
+                    temp_file = tempfile.NamedTemporaryFile()
+                    temp_file.write(video_bytes)
+                    temp_file_paths.append(temp_file.name)
+                    temp_files.append(temp_file)
+                temp_res = getattr(raw_model, target_method)(temp_file_paths)
+            finally:
+                for f in temp_files:
+                    f.close()
         else:
             # TODO: remove conversational pipeline code
             # For others, we could offer the whole dataframe as a list.
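For `ImageClassificationPipeline`, the handler now rebuilds PIL images from the raw bytes stored in the input column before calling the pipeline. A self-contained sketch of that conversion; the tiny generated PNG only stands in for one cell of the input column:

# Sketch of the bytes-to-PIL conversion used for image classification inputs.
import io

from PIL import Image

buf = io.BytesIO()
Image.new("RGB", (2, 2), color=(255, 0, 0)).save(buf, format="PNG")
img_bytes = buf.getvalue()  # stands in for one cell of the input column

images = [Image.open(io.BytesIO(b)) for b in [img_bytes]]
print(images[0].size)  # (2, 2); this list is what gets handed to the pipeline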
@@ -615,11 +655,14 @@
             temp_res = [[conv.generated_responses] for conv in temp_res]
 
         # To concat those who outputs a list with one input.
-        if
-
-
-
-
+        # if `signature.outputs` is single valued and is a FeatureGroupSpec,
+        # we create a DataFrame with one column and the values are stored as a dictionary.
+        # Otherwise, we create a DataFrame with the output as the column.
+        if len(signature.outputs) == 1 and isinstance(
+            signature.outputs[0], model_signature_core.FeatureGroupSpec
+        ):
+            # creating a dataframe with one column
+            res = pd.DataFrame({signature.outputs[0].name: temp_res})
         else:
             res = pd.DataFrame(temp_res)
 
@@ -702,7 +745,6 @@ class HuggingFaceOpenAICompatibleModel:
         self.pipeline = pipeline
         self.model = self.pipeline.model
         self.tokenizer = self.pipeline.tokenizer
-
         self.model_name = self.pipeline.model.name_or_path
 
         if self.tokenizer.pad_token is None:
@@ -724,11 +766,33 @@
         Returns:
             The formatted prompt string ready for model input.
         """
+
+        final_messages = []
+        for message in messages:
+            if isinstance(message.get("content", ""), str):
+                final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
+            else:
+                # extract only the text from the content
+                # sample data:
+                # {
+                #     "role": "user",
+                #     "content": [
+                #         {"type": "text", "text": "Hello, how are you?"},  # extracted
+                #         {"type": "image", "image": "https://example.com/image.png"},  # not extracted
+                #     ],
+                # }
+                for content_part in message.get("content", []):
+                    if content_part.get("type", "") == "text":
+                        final_messages.append(
+                            {"role": message.get("role", "user"), "content": content_part.get("text", "")}
+                        )
+                    # TODO: implement other content types
+
         # Use the tokenizer's apply_chat_template method.
         # We ensured a template exists in __init__.
         if hasattr(self.tokenizer, "apply_chat_template"):
             return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
-
+                final_messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
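The added normalization accepts OpenAI-style messages whose `content` is a list of typed parts: string content passes through unchanged, while list content is reduced to its text parts before the chat template is applied (image and other part types are dropped for now). A sketch of that flattening with illustrative messages:

# Sketch of the message normalization applied before prompt formatting.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "image": "https://example.com/image.png"},  # dropped for now
        ],
    },
]

final_messages = []
for message in messages:
    content = message.get("content", "")
    if isinstance(content, str):
        final_messages.append({"role": message.get("role", "user"), "content": content})
    else:
        for part in content:
            if part.get("type", "") == "text":
                final_messages.append({"role": message.get("role", "user"), "content": part.get("text", "")})

print(final_messages)
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Describe this image.'}]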
@@ -736,7 +800,7 @@
         # Fallback for very old transformers without apply_chat_template
         # Manually apply ChatML-like formatting
         prompt = ""
-        for message in
+        for message in final_messages:
             role = message.get("role", "user")
             content = message.get("content", "")
             prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
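When the tokenizer lacks `apply_chat_template`, the fallback path formats the normalized messages as ChatML-style blocks. A sketch of the string it produces, with illustrative messages:

# Sketch of the ChatML-style fallback prompt construction.
final_messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "Doing well, thanks!"},
]

prompt = ""
for message in final_messages:
    role = message.get("role", "user")
    content = message.get("content", "")
    prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"

print(prompt)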