snowflake-ml-python 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. snowflake/ml/jobs/__init__.py +2 -0
  2. snowflake/ml/jobs/_utils/constants.py +1 -0
  3. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  4. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  5. snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
  6. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  7. snowflake/ml/jobs/_utils/types.py +22 -2
  8. snowflake/ml/jobs/job_definition.py +232 -0
  9. snowflake/ml/jobs/manager.py +16 -177
  10. snowflake/ml/model/_client/model/model_version_impl.py +90 -76
  11. snowflake/ml/model/_client/ops/model_ops.py +2 -18
  12. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  13. snowflake/ml/model/_client/ops/service_ops.py +63 -18
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
  15. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  16. snowflake/ml/model/_client/sql/service.py +4 -25
  17. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  18. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  19. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
  21. snowflake/ml/model/_packager/model_handlers/huggingface.py +54 -10
  22. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +52 -16
  23. snowflake/ml/model/_signatures/utils.py +55 -0
  24. snowflake/ml/model/openai_signatures.py +97 -0
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/METADATA +67 -1
  28. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/RECORD +31 -29
  29. snowflake/ml/experiment/callback/__init__.py +0 -0
  30. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/WHEEL +0 -0
  31. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/licenses/LICENSE.txt +0 -0
  32. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.23.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/param_utils.py (new file)

@@ -0,0 +1,124 @@
+"""Utility functions for model parameter validation and resolution."""
+
+from typing import Any, Optional, Sequence
+
+from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._signatures import core
+
+
+def validate_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> None:
+    """Validate user-provided params against signature params.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Raises:
+        SnowflakeMLException: If params are provided but signature has no params,
+            or if unknown params are provided, or if param types are invalid,
+            or if duplicate params are provided with different cases.
+    """
+    # Params provided but signature has no params defined
+    if params and not signature_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Parameters were provided ({sorted(params.keys())}), "
+                "but this method does not accept any parameters."
+            ),
+        )
+
+    if not signature_params or not params:
+        return
+
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Check for duplicate params with different cases (e.g., "temperature" and "TEMPERATURE")
+    normalized_names = [name.upper() for name in params]
+    if len(normalized_names) != len(set(normalized_names)):
+        # Find the duplicate params to raise an error
+        param_seen: dict[str, list[str]] = {}
+        for param_name in params:
+            param_seen.setdefault(param_name.upper(), []).append(param_name)
+        duplicate_param_names = [param_names for param_names in param_seen.values() if len(param_names) > 1]
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Duplicate parameter(s) provided with different cases: {duplicate_param_names}. "
+                "Parameter names are case-insensitive."
+            ),
+        )
+
+    # Validate user-provided params exist (case-insensitive)
+    invalid_params = [name for name in params if name.upper() not in param_spec_lookup]
+    if invalid_params:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                f"Unknown parameter(s): {sorted(invalid_params)}. "
+                f"Valid parameters are: {sorted(ps.name for ps in signature_params)}"
+            ),
+        )
+
+    # Validate types for each provided param
+    for param_name, default_value in params.items():
+        param_spec = param_spec_lookup[param_name.upper()]
+        if isinstance(param_spec, core.ParamSpec):
+            core.ParamSpec._validate_default_value(param_spec.dtype, default_value, param_spec.shape)
+
+
+def resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Sequence[core.BaseParamSpec],
+) -> list[tuple[sql_identifier.SqlIdentifier, Any]]:
+    """Resolve final method parameters by applying user-provided params over signature defaults.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation.
+    """
+    # Case-insensitive lookup: normalized_name -> param_spec
+    param_spec_lookup = {ps.name.upper(): ps for ps in signature_params}
+
+    # Start with defaults from signature
+    final_params: dict[str, Any] = {}
+    for param_spec in signature_params:
+        if hasattr(param_spec, "default_value"):
+            final_params[param_spec.name] = param_spec.default_value
+
+    # Override with provided runtime parameters (using signature's original param names)
+    if params:
+        for param_name, override_value in params.items():
+            canonical_name = param_spec_lookup[param_name.upper()].name
+            final_params[canonical_name] = override_value
+
+    return [(sql_identifier.SqlIdentifier(param_name), param_value) for param_name, param_value in final_params.items()]
+
+
+def validate_and_resolve_params(
+    params: Optional[dict[str, Any]],
+    signature_params: Optional[Sequence[core.BaseParamSpec]],
+) -> Optional[list[tuple[sql_identifier.SqlIdentifier, Any]]]:
+    """Validate user-provided params against signature params and return method parameters.
+
+    Args:
+        params: User-provided parameter dictionary (runtime values).
+        signature_params: Parameter specifications from the model signature.
+
+    Returns:
+        List of tuples (SqlIdentifier, value) for method invocation, or None if no params.
+    """
+    validate_params(params, signature_params)
+
+    if not signature_params:
+        return None

+    return resolve_params(params, signature_params)
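For orientation, a minimal standalone sketch of the resolution semantics above: user-supplied names are matched case-insensitively against the signature's canonical names, and defaults fill in whatever the caller omits. `FakeSpec` is a stand-in for illustration only, not the real `core.ParamSpec` API.

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class FakeSpec:  # stand-in for core.ParamSpec, for illustration only
        name: str
        default_value: Any

    def resolve(params: dict[str, Any], specs: list[FakeSpec]) -> dict[str, Any]:
        lookup = {s.name.upper(): s for s in specs}       # case-insensitive lookup
        final = {s.name: s.default_value for s in specs}  # start from signature defaults
        for name, value in params.items():
            final[lookup[name.upper()].name] = value      # override under the canonical name
        return final

    specs = [FakeSpec("temperature", 0.7), FakeSpec("max_tokens", 256)]
    print(resolve({"TEMPERATURE": 0.2}, specs))  # {'temperature': 0.2, 'max_tokens': 256}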
snowflake/ml/model/_client/ops/service_ops.py

@@ -1,4 +1,6 @@
+import base64
 import dataclasses
+import json
 import logging
 import pathlib
 import re
@@ -6,7 +8,9 @@ import tempfile
 import threading
 import time
 import warnings
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Sequence, Union, cast
+
+from pydantic import TypeAdapter

 from snowflake import snowpark
 from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -14,9 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
 from snowflake.ml.jobs import job
 from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
 from snowflake.ml.model._client.model import batch_inference_specs
-from snowflake.ml.model._client.ops import deployment_step
+from snowflake.ml.model._client.ops import deployment_step, param_utils
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+from snowflake.ml.model._signatures import core
 from snowflake.snowpark import async_job, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils

@@ -582,15 +587,10 @@ class ServiceOperator:
                 )
                 for status in statuses:
                     if status.instance_id is not None:
-                        instance_status, container_status = None, None
-                        if status.instance_status is not None:
-                            instance_status = status.instance_status.value
-                        if status.container_status is not None:
-                            container_status = status.container_status.value
                         module_logger.info(
                             f"Instance[{status.instance_id}]: "
-                            f"instance status: {instance_status}, "
-                            f"container status: {container_status}, "
+                            f"instance status: {status.instance_status}, "
+                            f"container status: {status.container_status}, "
                             f"message: {status.message}"
                         )
                 time.sleep(5)
@@ -930,6 +930,38 @@ class ServiceOperator:
         except exceptions.SnowparkSQLException:
             return False

+    @staticmethod
+    def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+        """Encode params dictionary to a base64 string.
+
+        Args:
+            params: Optional dictionary of model inference parameters.
+
+        Returns:
+            Base64 encoded JSON string of the params, or None if input is None.
+        """
+        if params is None:
+            return None
+        return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+    @staticmethod
+    def _encode_column_handling(
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+    ) -> Optional[str]:
+        """Validate and encode column_handling to a base64 string.
+
+        Args:
+            column_handling: Optional dictionary mapping column names to file encoding options.
+
+        Returns:
+            Base64 encoded JSON string of the column handling options, or None if input is None.
+        """
+        if column_handling is None:
+            return None
+        adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+        validated_input = adapter.validate_python(column_handling)
+        return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+
     def invoke_batch_job_method(
         self,
         *,
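The helpers above are plain JSON wrapped in base64, so the job side can decode them with the standard library alone. A minimal round-trip sketch (assuming nothing beyond `base64` and `json`):

    import base64
    import json

    params = {"temperature": 0.2, "max_tokens": 256}
    encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
    assert json.loads(base64.b64decode(encoded)) == params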
@@ -942,8 +974,9 @@ class ServiceOperator:
         image_repo_name: Optional[str],
         input_stage_location: str,
         input_file_pattern: str,
-        column_handling: Optional[str],
-        params: Optional[str],
+        column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+        params: Optional[dict[str, Any]],
+        signature_params: Optional[Sequence[core.BaseParamSpec]],
         output_stage_location: str,
         completion_filename: str,
         force_rebuild: bool,
@@ -954,7 +987,13 @@ class ServiceOperator:
         gpu_requests: Optional[str],
         replicas: Optional[int],
         statement_params: Optional[dict[str, Any]] = None,
+        inference_engine_args: Optional[InferenceEngineArgs] = None,
     ) -> job.MLJob[Any]:
+        # Validate and encode params
+        param_utils.validate_params(params, signature_params)
+        params_encoded = self._encode_params(params)
+        column_handling_encoded = self._encode_column_handling(column_handling)
+
         database_name = self._database_name
         schema_name = self._schema_name

@@ -980,8 +1019,8 @@ class ServiceOperator:
             max_batch_rows=max_batch_rows,
             input_stage_location=input_stage_location,
             input_file_pattern=input_file_pattern,
-            column_handling=column_handling,
-            params=params,
+            column_handling=column_handling_encoded,
+            params=params_encoded,
             output_stage_location=output_stage_location,
             completion_filename=completion_filename,
             function_name=function_name,
@@ -992,11 +1031,17 @@ class ServiceOperator:
             replicas=replicas,
         )

-        self._model_deployment_spec.add_image_build_spec(
-            image_build_compute_pool_name=compute_pool_name,
-            fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
-            force_rebuild=force_rebuild,
-        )
+        if inference_engine_args:
+            self._model_deployment_spec.add_inference_engine_spec(
+                inference_engine=inference_engine_args.inference_engine,
+                inference_engine_args=inference_engine_args.inference_engine_args_override,
+            )
+        else:
+            self._model_deployment_spec.add_image_build_spec(
+                image_build_compute_pool_name=compute_pool_name,
+                fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+                force_rebuild=force_rebuild,
+            )

         spec_yaml_str_or_path = self._model_deployment_spec.save()

snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -363,7 +363,7 @@ class ModelDeploymentSpec:
         inference_engine: inference_engine_module.InferenceEngine,
         inference_engine_args: Optional[list[str]] = None,
     ) -> "ModelDeploymentSpec":
-        """Add inference engine specification. This must be called after self.add_service_spec().
+        """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().

         Args:
             inference_engine: Inference engine.
@@ -376,9 +376,10 @@ class ModelDeploymentSpec:
             ValueError: If inference engine specification is called before add_service_spec().
             ValueError: If the argument does not have a '--' prefix.
         """
-        # TODO: needs to eventually support job deployment spec
-        if self._service is None:
-            raise ValueError("Inference engine specification must be called after add_service_spec().")
+        if self._service is None and self._job is None:
+            raise ValueError(
+                "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+            )

         if inference_engine_args is None:
             inference_engine_args = []
@@ -431,11 +432,17 @@ class ModelDeploymentSpec:

         inference_engine_args = filtered_args

-        self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+        inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
             # convert to string to be saved in the deployment spec
             inference_engine_name=inference_engine.value,
             inference_engine_args=inference_engine_args,
         )
+
+        if self._service:
+            self._service.inference_engine_spec = inference_engine_spec
+        elif self._job:
+            self._job.inference_engine_spec = inference_engine_spec
+
         return self

     def save(self) -> str:
snowflake/ml/model/_client/service/model_deployment_spec_schema.py

@@ -61,6 +61,7 @@ class Job(BaseModel):
     input: Input
     output: Output
     replicas: Optional[int] = None
+    inference_engine_spec: Optional[InferenceEngineSpec] = None


 class LogModelArgs(BaseModel):
snowflake/ml/model/_client/sql/service.py

@@ -47,22 +47,6 @@ class ServiceStatus(enum.Enum):
     INTERNAL_ERROR = "INTERNAL_ERROR"


-class InstanceStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    FAILED = "FAILED"
-    TERMINATING = "TERMINATING"
-    SUCCEEDED = "SUCCEEDED"
-
-
-class ContainerStatus(enum.Enum):
-    PENDING = "PENDING"
-    READY = "READY"
-    DONE = "DONE"
-    FAILED = "FAILED"
-    UNKNOWN = "UNKNOWN"
-
-
 @dataclasses.dataclass
 class ServiceStatusInfo:
     """
@@ -72,8 +56,8 @@ class ServiceStatusInfo:

     service_status: ServiceStatus
     instance_id: Optional[int] = None
-    instance_status: Optional[InstanceStatus] = None
-    container_status: Optional[ContainerStatus] = None
+    instance_status: Optional[str] = None
+    container_status: Optional[str] = None
     message: Optional[str] = None


@@ -272,17 +256,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         statuses = []
         for r in rows:
-            instance_status, container_status = None, None
-            if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
-                instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
-            if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
-                container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
             statuses.append(
                 ServiceStatusInfo(
                     service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
                     instance_id=r[ServiceSQLClient.INSTANCE_ID],
-                    instance_status=instance_status,
-                    container_status=container_status,
+                    instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
+                    container_status=r[ServiceSQLClient.CONTAINER_STATUS],
                     message=r[ServiceSQLClient.MESSAGE] if include_message else None,
                 )
            )
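One plausible reading of the enum-to-string change (the diff does not state a rationale): a status value this client has never seen no longer breaks status polling, because constructing an enum from an unrecognized server value raises while a plain string passes through. A standalone sketch of that failure mode:

    import enum

    class ContainerStatus(enum.Enum):  # illustrative copy of the removed enum
        PENDING = "PENDING"
        READY = "READY"

    ContainerStatus("READY")              # fine
    try:
        ContainerStatus("SUSPENDED")      # a value added server-side after the client shipped
    except ValueError as e:
        print(e)                          # 'SUSPENDED' is not a valid ContainerStatus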
snowflake/ml/model/_model_composer/model_method/infer_function.py_template

@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual function
 @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
 def {function_name}(df: pd.DataFrame) -> dict:
-    df.columns = input_cols
-    input_df = df.astype(dtype=dtype_map)
-    predictions_df = runner(input_df[input_cols])
+    df.columns = input_cols + param_cols
+    input_df = df[input_cols].astype(dtype=dtype_map)
+
+    # Extract runtime param values, using defaults if None
+    method_params = {{}}
+    for col in param_cols:
+        val = df[col].iloc[0]
+        if val is None or pd.isna(val):
+            method_params[col] = param_defaults[col]
+        else:
+            method_params[col] = val
+
+    predictions_df = runner(input_df, **method_params)
     return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
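This template and the two below follow the same convention: each signature param arrives as an extra trailing column whose value is constant within a batch, so the handler reads the first row and falls back to the default on None/NaN. A self-contained pandas sketch, with a toy `runner` standing in for the packaged model's target method:

    import pandas as pd

    input_cols, param_cols = ["text"], ["temperature"]
    param_defaults = {"temperature": 0.7}

    def runner(df: pd.DataFrame, temperature: float = 0.7) -> pd.DataFrame:
        # toy stand-in for the model's target method
        return pd.DataFrame({"output": df["text"] + f" (t={temperature})"})

    df = pd.DataFrame([["hello", None], ["world", None]])
    df.columns = input_cols + param_cols
    input_df = df[input_cols]

    method_params = {}
    for col in param_cols:
        val = df[col].iloc[0]  # constant per batch, so the first row suffices
        method_params[col] = param_defaults[col] if val is None or pd.isna(val) else val

    print(runner(input_df, **method_params))  # both rows use the default t=0.7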
snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template

@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
     def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-        return runner(input_df[input_cols])
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template

@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
 input_cols = [feature.name for feature in features]
 dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+# Load inference parameters from method signature (if any)
+param_cols = []
+param_defaults = {{}}
+if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+    for param_spec in meta.signatures[TARGET_METHOD].params:
+        param_cols.append(param_spec.name)
+        param_defaults[param_spec.name] = param_spec.default_value
+

 # Actual table function
 class {function_name}:
     @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
     def process(self, df: pd.DataFrame) -> pd.DataFrame:
-        df.columns = input_cols
-        input_df = df.astype(dtype=dtype_map)
-        return runner(input_df[input_cols])
+        df.columns = input_cols + param_cols
+        input_df = df[input_cols].astype(dtype=dtype_map)
+
+        # Extract runtime param values, using defaults if None
+        method_params = {{}}
+        for col in param_cols:
+            val = df[col].iloc[0]
+            if val is None or pd.isna(val):
+                method_params[col] = param_defaults[col]
+            else:
+                method_params[col] = val
+
+        return runner(input_df, **method_params)
snowflake/ml/model/_model_composer/model_method/model_method.py

@@ -156,10 +156,11 @@ class ModelMethod:
                 f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
                 "Try specifying `case_sensitive` as True."
             ) from e
+        default_value = param_spec.default_value if param_spec.default_value is None else str(param_spec.default_value)
         return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
             name=param_name.resolved(),
             type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
-            default=param_spec.default_value,
+            default=default_value,
         )

     def save(
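The net effect of the new `default_value` line: `None` is preserved, everything else is stringified before landing in the manifest (presumably so manifest defaults are uniformly string-typed regardless of the parameter's SQL type):

    for raw in (None, 0.7, 256, True):
        print(repr(raw if raw is None else str(raw)))
    # None, '0.7', '256', 'True'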
snowflake/ml/model/_packager/model_handlers/huggingface.py

@@ -1,3 +1,4 @@
+import io
 import json
 import logging
 import os
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
-from snowflake.ml.model._signatures import utils as model_signature_utils
+from snowflake.ml.model._signatures import (
+    core as model_signature_core,
+    utils as model_signature_utils,
+)
 from snowflake.ml.model.models import (
     huggingface as huggingface_base,
     huggingface_pipeline,
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
             # verify when the target method is __call__ and
             # if the signature is default text-generation signature
             # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
-            if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+            if (
+                signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
+                or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
+            ):
                 wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)

                 temp_res = X.apply(
@@ -554,6 +561,19 @@ class TransformersPipelineHandler(
             else:
                 input_data = X[signature.inputs[0].name].to_list()
                 temp_res = getattr(raw_model, target_method)(input_data)
+        elif isinstance(raw_model, transformers.ImageClassificationPipeline):
+            # Image classification expects PIL Images. Convert bytes to PIL Images.
+            from PIL import Image
+
+            input_col = signature.inputs[0].name
+            images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
+            temp_res = getattr(raw_model, target_method)(images)
+        elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
+            # ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
+            # not a list. Process each audio input individually.
+            input_col = signature.inputs[0].name
+            audio_inputs = X[input_col].to_list()
+            temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
         else:
             # TODO: remove conversational pipeline code
             # For others, we could offer the whole dataframe as a list.
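A quick round-trip showing the bytes-to-PIL conversion the image branch performs (requires Pillow, which the branch itself imports lazily):

    import io
    from PIL import Image

    # encode a tiny image to PNG bytes, then decode it the way the new branch does
    buf = io.BytesIO()
    Image.new("RGB", (2, 2), "red").save(buf, format="PNG")
    img_bytes = buf.getvalue()

    images = [Image.open(io.BytesIO(b)) for b in [img_bytes]]
    print(images[0].size)  # (2, 2)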
@@ -615,11 +635,14 @@ class TransformersPipelineHandler(
             temp_res = [[conv.generated_responses] for conv in temp_res]

         # To concat those who outputs a list with one input.
-        if isinstance(temp_res[0], list):
-            if isinstance(temp_res[0][0], dict):
-                res = pd.DataFrame({0: temp_res})
-            else:
-                res = pd.DataFrame(temp_res)
+        # if `signature.outputs` is single valued and is a FeatureGroupSpec,
+        # we create a DataFrame with one column and the values are stored as a dictionary.
+        # Otherwise, we create a DataFrame with the output as the column.
+        if len(signature.outputs) == 1 and isinstance(
+            signature.outputs[0], model_signature_core.FeatureGroupSpec
+        ):
+            # creating a dataframe with one column
+            res = pd.DataFrame({signature.outputs[0].name: temp_res})
         else:
             res = pd.DataFrame(temp_res)

@@ -702,7 +725,6 @@ class HuggingFaceOpenAICompatibleModel:
         self.pipeline = pipeline
         self.model = self.pipeline.model
         self.tokenizer = self.pipeline.tokenizer
-
         self.model_name = self.pipeline.model.name_or_path

         if self.tokenizer.pad_token is None:
@@ -724,11 +746,33 @@ class HuggingFaceOpenAICompatibleModel:
         Returns:
             The formatted prompt string ready for model input.
         """
+
+        final_messages = []
+        for message in messages:
+            if isinstance(message.get("content", ""), str):
+                final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
+            else:
+                # extract only the text from the content
+                # sample data:
+                # {
+                #     "role": "user",
+                #     "content": [
+                #         {"type": "text", "text": "Hello, how are you?"},  # extracted
+                #         {"type": "image", "image": "https://example.com/image.png"},  # not extracted
+                #     ],
+                # }
+                for content_part in message.get("content", []):
+                    if content_part.get("type", "") == "text":
+                        final_messages.append(
+                            {"role": message.get("role", "user"), "content": content_part.get("text", "")}
+                        )
+                # TODO: implement other content types
+
         # Use the tokenizer's apply_chat_template method.
         # We ensured a template exists in __init__.
         if hasattr(self.tokenizer, "apply_chat_template"):
             return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
-                messages,
+                final_messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
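A standalone trace of the flattening above: string content passes through unchanged, list content keeps only its "text" parts:

    messages = [
        {"role": "system", "content": "You are helpful."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello, how are you?"},
                {"type": "image", "image": "https://example.com/image.png"},
            ],
        },
    ]

    final_messages = []
    for message in messages:
        content = message.get("content", "")
        if isinstance(content, str):
            final_messages.append({"role": message.get("role", "user"), "content": content})
        else:
            for part in content:
                if part.get("type", "") == "text":
                    final_messages.append({"role": message.get("role", "user"), "content": part.get("text", "")})

    print(final_messages)
    # [{'role': 'system', 'content': 'You are helpful.'},
    #  {'role': 'user', 'content': 'Hello, how are you?'}]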
@@ -736,7 +780,7 @@ class HuggingFaceOpenAICompatibleModel:
         # Fallback for very old transformers without apply_chat_template
         # Manually apply ChatML-like formatting
         prompt = ""
-        for message in messages:
+        for message in final_messages:
             role = message.get("role", "user")
             content = message.get("content", "")
             prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
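For reference, what that ChatML fallback produces for a flattened message list:

    prompt = ""
    for message in [{"role": "user", "content": "Hello"}]:
        prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    print(prompt)
    # <|im_start|>user
    # Hello<|im_end|>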