snowflake-ml-python 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/feature_store/__init__.py +2 -0
  3. snowflake/ml/feature_store/aggregation.py +367 -0
  4. snowflake/ml/feature_store/feature.py +366 -0
  5. snowflake/ml/feature_store/feature_store.py +234 -20
  6. snowflake/ml/feature_store/feature_view.py +189 -4
  7. snowflake/ml/feature_store/metadata_manager.py +425 -0
  8. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  9. snowflake/ml/jobs/__init__.py +2 -0
  10. snowflake/ml/jobs/_utils/constants.py +1 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +38 -18
  12. snowflake/ml/jobs/_utils/query_helper.py +8 -1
  13. snowflake/ml/jobs/_utils/runtime_env_utils.py +117 -0
  14. snowflake/ml/jobs/_utils/stage_utils.py +2 -2
  15. snowflake/ml/jobs/_utils/types.py +22 -2
  16. snowflake/ml/jobs/job_definition.py +232 -0
  17. snowflake/ml/jobs/manager.py +16 -177
  18. snowflake/ml/model/__init__.py +4 -0
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  20. snowflake/ml/model/_client/model/model_version_impl.py +120 -89
  21. snowflake/ml/model/_client/ops/model_ops.py +4 -26
  22. snowflake/ml/model/_client/ops/param_utils.py +124 -0
  23. snowflake/ml/model/_client/ops/service_ops.py +63 -23
  24. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -5
  25. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  26. snowflake/ml/model/_client/sql/service.py +25 -54
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +21 -3
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +21 -3
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +21 -3
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -1
  31. snowflake/ml/model/_packager/model_handlers/huggingface.py +74 -10
  32. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +121 -29
  33. snowflake/ml/model/_signatures/utils.py +130 -0
  34. snowflake/ml/model/openai_signatures.py +97 -0
  35. snowflake/ml/registry/_manager/model_parameter_reconciler.py +1 -1
  36. snowflake/ml/version.py +1 -1
  37. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/METADATA +105 -1
  38. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/RECORD +41 -35
  39. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/WHEEL +1 -1
  40. snowflake/ml/experiment/callback/__init__.py +0 -0
  41. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/licenses/LICENSE.txt +0 -0
  42. {snowflake_ml_python-1.22.0.dist-info → snowflake_ml_python-1.24.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,6 @@
+ import base64
  import dataclasses
+ import json
  import logging
  import pathlib
  import re
@@ -6,7 +8,9 @@ import tempfile
  import threading
  import time
  import warnings
- from typing import Any, Optional, Union, cast
+ from typing import Any, Optional, Sequence, Union, cast
+
+ from pydantic import TypeAdapter

  from snowflake import snowpark
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
@@ -14,9 +18,10 @@ from snowflake.ml._internal.utils import identifier, service_logger, sql_identif
  from snowflake.ml.jobs import job
  from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
  from snowflake.ml.model._client.model import batch_inference_specs
- from snowflake.ml.model._client.ops import deployment_step
+ from snowflake.ml.model._client.ops import deployment_step, param_utils
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
+ from snowflake.ml.model._signatures import core
  from snowflake.snowpark import async_job, exceptions, row, session
  from snowflake.snowpark._internal import utils as snowpark_utils

@@ -150,7 +155,6 @@ class ServiceOperator:
  self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
  workspace_path=pathlib.Path(self._workspace.name)
  )
- self._inference_autocapture_enabled = pc.PlatformCapabilities.get_instance().is_inference_autocapture_enabled()

  def __eq__(self, __value: object) -> bool:
  if not isinstance(__value, ServiceOperator):
@@ -211,10 +215,6 @@ class ServiceOperator:
  progress_status.update("preparing deployment artifacts...")
  progress_status.increment()

- # If autocapture param is disabled, don't allow create service with autocapture
- if not self._inference_autocapture_enabled and autocapture:
- raise ValueError("Invalid Argument: Autocapture feature is not supported.")
-
  if self._workspace:
  stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
  else:
@@ -582,15 +582,10 @@ class ServiceOperator:
  )
  for status in statuses:
  if status.instance_id is not None:
- instance_status, container_status = None, None
- if status.instance_status is not None:
- instance_status = status.instance_status.value
- if status.container_status is not None:
- container_status = status.container_status.value
  module_logger.info(
  f"Instance[{status.instance_id}]: "
- f"instance status: {instance_status}, "
- f"container status: {container_status}, "
+ f"instance status: {status.instance_status}, "
+ f"container status: {status.container_status}, "
  f"message: {status.message}"
  )
  time.sleep(5)
@@ -930,6 +925,38 @@ class ServiceOperator:
  except exceptions.SnowparkSQLException:
  return False

+ @staticmethod
+ def _encode_params(params: Optional[dict[str, Any]]) -> Optional[str]:
+ """Encode params dictionary to a base64 string.
+
+ Args:
+ params: Optional dictionary of model inference parameters.
+
+ Returns:
+ Base64 encoded JSON string of the params, or None if input is None.
+ """
+ if params is None:
+ return None
+ return base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")
+
+ @staticmethod
+ def _encode_column_handling(
+ column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+ ) -> Optional[str]:
+ """Validate and encode column_handling to a base64 string.
+
+ Args:
+ column_handling: Optional dictionary mapping column names to file encoding options.
+
+ Returns:
+ Base64 encoded JSON string of the column handling options, or None if input is None.
+ """
+ if column_handling is None:
+ return None
+ adapter = TypeAdapter(dict[str, batch_inference_specs.ColumnHandlingOptions])
+ validated_input = adapter.validate_python(column_handling)
+ return base64.b64encode(adapter.dump_json(validated_input)).decode("utf-8")
+
  def invoke_batch_job_method(
  self,
  *,
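For reference, a minimal sketch of the round trip the new helpers perform, assuming a plain params dictionary; the pydantic TypeAdapter validation used for column_handling is omitted here, and the params values are made up for illustration:

    import base64
    import json

    params = {"temperature": 0.2, "max_tokens": 128}  # hypothetical inference params

    # Encode exactly as _encode_params does: JSON, then base64, both UTF-8
    encoded = base64.b64encode(json.dumps(params).encode("utf-8")).decode("utf-8")

    # The consuming side can recover the original dictionary
    decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
    assert decoded == params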
@@ -942,8 +969,9 @@ class ServiceOperator:
  image_repo_name: Optional[str],
  input_stage_location: str,
  input_file_pattern: str,
- column_handling: Optional[str],
- params: Optional[str],
+ column_handling: Optional[dict[str, batch_inference_specs.ColumnHandlingOptions]],
+ params: Optional[dict[str, Any]],
+ signature_params: Optional[Sequence[core.BaseParamSpec]],
  output_stage_location: str,
  completion_filename: str,
  force_rebuild: bool,
@@ -954,7 +982,13 @@ class ServiceOperator:
  gpu_requests: Optional[str],
  replicas: Optional[int],
  statement_params: Optional[dict[str, Any]] = None,
+ inference_engine_args: Optional[InferenceEngineArgs] = None,
  ) -> job.MLJob[Any]:
+ # Validate and encode params
+ param_utils.validate_params(params, signature_params)
+ params_encoded = self._encode_params(params)
+ column_handling_encoded = self._encode_column_handling(column_handling)
+
  database_name = self._database_name
  schema_name = self._schema_name

@@ -980,8 +1014,8 @@ class ServiceOperator:
  max_batch_rows=max_batch_rows,
  input_stage_location=input_stage_location,
  input_file_pattern=input_file_pattern,
- column_handling=column_handling,
- params=params,
+ column_handling=column_handling_encoded,
+ params=params_encoded,
  output_stage_location=output_stage_location,
  completion_filename=completion_filename,
  function_name=function_name,
@@ -992,11 +1026,17 @@ class ServiceOperator:
  replicas=replicas,
  )

- self._model_deployment_spec.add_image_build_spec(
- image_build_compute_pool_name=compute_pool_name,
- fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
- force_rebuild=force_rebuild,
- )
+ if inference_engine_args:
+ self._model_deployment_spec.add_inference_engine_spec(
+ inference_engine=inference_engine_args.inference_engine,
+ inference_engine_args=inference_engine_args.inference_engine_args_override,
+ )
+ else:
+ self._model_deployment_spec.add_image_build_spec(
+ image_build_compute_pool_name=compute_pool_name,
+ fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
+ force_rebuild=force_rebuild,
+ )

  spec_yaml_str_or_path = self._model_deployment_spec.save()

@@ -363,7 +363,7 @@ class ModelDeploymentSpec:
  inference_engine: inference_engine_module.InferenceEngine,
  inference_engine_args: Optional[list[str]] = None,
  ) -> "ModelDeploymentSpec":
- """Add inference engine specification. This must be called after self.add_service_spec().
+ """Add inference engine specification. This must be called after self.add_service_spec() or self.add_job_spec().

  Args:
  inference_engine: Inference engine.
@@ -376,9 +376,10 @@ class ModelDeploymentSpec:
  ValueError: If inference engine specification is called before add_service_spec().
  ValueError: If the argument does not have a '--' prefix.
  """
- # TODO: needs to eventually support job deployment spec
- if self._service is None:
- raise ValueError("Inference engine specification must be called after add_service_spec().")
+ if self._service is None and self._job is None:
+ raise ValueError(
+ "Inference engine specification must be called after add_service_spec() or add_job_spec()."
+ )

  if inference_engine_args is None:
  inference_engine_args = []
@@ -431,11 +432,17 @@ class ModelDeploymentSpec:

  inference_engine_args = filtered_args

- self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
+ inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec(
  # convert to string to be saved in the deployment spec
  inference_engine_name=inference_engine.value,
  inference_engine_args=inference_engine_args,
  )
+
+ if self._service:
+ self._service.inference_engine_spec = inference_engine_spec
+ elif self._job:
+ self._job.inference_engine_spec = inference_engine_spec
+
  return self

  def save(self) -> str:
@@ -61,6 +61,7 @@ class Job(BaseModel):
  input: Input
  output: Output
  replicas: Optional[int] = None
+ inference_engine_spec: Optional[InferenceEngineSpec] = None


  class LogModelArgs(BaseModel):
@@ -47,22 +47,6 @@ class ServiceStatus(enum.Enum):
  INTERNAL_ERROR = "INTERNAL_ERROR"


- class InstanceStatus(enum.Enum):
- PENDING = "PENDING"
- READY = "READY"
- FAILED = "FAILED"
- TERMINATING = "TERMINATING"
- SUCCEEDED = "SUCCEEDED"
-
-
- class ContainerStatus(enum.Enum):
- PENDING = "PENDING"
- READY = "READY"
- DONE = "DONE"
- FAILED = "FAILED"
- UNKNOWN = "UNKNOWN"
-
-
  @dataclasses.dataclass
  class ServiceStatusInfo:
  """
@@ -72,8 +56,8 @@ class ServiceStatusInfo:

  service_status: ServiceStatus
  instance_id: Optional[int] = None
- instance_status: Optional[InstanceStatus] = None
- container_status: Optional[ContainerStatus] = None
+ instance_status: Optional[str] = None
+ container_status: Optional[str] = None
  message: Optional[str] = None


@@ -91,10 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
  DESC_SERVICE_SPEC_COL_NAME = "spec"
  DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
  DESC_SERVICE_NAME_SPEC_NAME = "name"
- DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
- PROXY_CONTAINER_NAME = "proxy"
+ DESC_SERVICE_ENV_SPEC_NAME = "env"
  MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
- FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"

  @contextlib.contextmanager
  def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -272,17 +254,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
  )
  statuses = []
  for r in rows:
- instance_status, container_status = None, None
- if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
- instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
- if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
- container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
  statuses.append(
  ServiceStatusInfo(
  service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
  instance_id=r[ServiceSQLClient.INSTANCE_ID],
- instance_status=instance_status,
- container_status=container_status,
+ instance_status=r[ServiceSQLClient.INSTANCE_STATUS],
+ container_status=r[ServiceSQLClient.CONTAINER_STATUS],
  message=r[ServiceSQLClient.MESSAGE] if include_message else None,
  )
  )
@@ -306,39 +283,33 @@ class ServiceSQLClient(_base._BaseSQLClient):
  )
  return rows[0]

- def get_proxy_container_autocapture(self, row: row.Row) -> bool:
- """Extract whether service has autocapture enabled from proxy container spec.
+ def is_autocapture_enabled(self, row: row.Row) -> bool:
+ """Extract whether service has autocapture enabled in any container from service spec.

  Args:
  row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.

  Returns:
- True if autocapture is enabled in proxy spec
- False if disabled or not set in proxy spec
- False if service doesn't have proxy container
+ True if autocapture is enabled in any container.
+ False if autocapture is disabled or not set in any container.
  """
- try:
- spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
- if spec_yaml is None:
- return False
- spec_raw = yaml.safe_load(spec_yaml)
- if spec_raw is None:
- return False
- spec = cast(dict[str, Any], spec_raw)
-
- proxy_container_spec = next(
- container
- for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
- ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
- ]
- if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
- )
- env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
- autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
- return str(autocapture_enabled).lower() == "true"
-
- except StopIteration:
+ spec_yaml = row.as_dict().get(ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME)
+ if spec_yaml is None:
  return False
+ spec_raw = yaml.safe_load(spec_yaml)
+ if spec_raw is None:
+ return False
+ spec = cast(dict[str, Any], spec_raw)
+
+ containers = spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
+ ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
+ ]
+ for container in containers:
+ env = container.get(ServiceSQLClient.DESC_SERVICE_ENV_SPEC_NAME, {})
+ autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
+ if str(autocapture_enabled).lower() == "true":
+ return True
+ return False

  def drop_service(
  self,
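For illustration, a minimal sketch of the parsed DESCRIBE SERVICE spec shape that is_autocapture_enabled() now scans; the container names and the first container are hypothetical, only the key names ("spec", "containers", "env") and the env variable come from the constants above:

    spec = {
        "spec": {
            "containers": [
                {"name": "model-inference", "env": {}},  # hypothetical container without the flag
                {"name": "proxy", "env": {"SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED": "true"}},
            ]
        }
    }

    # Same scan as the new implementation: any container with the flag set to "true" counts
    enabled = any(
        str(c.get("env", {}).get("SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED", "false")).lower() == "true"
        for c in spec["spec"]["containers"]
    )
    assert enabled is True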
@@ -41,11 +41,29 @@ features = meta.signatures[TARGET_METHOD].inputs
  input_cols = [feature.name for feature in features]
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+ # Load inference parameters from method signature (if any)
+ param_cols = []
+ param_defaults = {{}}
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+ for param_spec in meta.signatures[TARGET_METHOD].params:
+ param_cols.append(param_spec.name)
+ param_defaults[param_spec.name] = param_spec.default_value
+

  # Actual function
  @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
  def {function_name}(df: pd.DataFrame) -> dict:
- df.columns = input_cols
- input_df = df.astype(dtype=dtype_map)
- predictions_df = runner(input_df[input_cols])
+ df.columns = input_cols + param_cols
+ input_df = df[input_cols].astype(dtype=dtype_map)
+
+ # Extract runtime param values, using defaults if None
+ method_params = {{}}
+ for col in param_cols:
+ val = df[col].iloc[0]
+ if val is None or pd.isna(val):
+ method_params[col] = param_defaults[col]
+ else:
+ method_params[col] = val
+
+ predictions_df = runner(input_df, **method_params)
  return predictions_df.replace({{pd.NA: None, np.nan: None}}).to_dict("records")
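A small standalone sketch of the parameter-extraction rule that this template and the two table-function templates below now share; the column names, default value, and stand-in runner are made up for illustration, only the first-row lookup and the None/NaN fallback to the signature default mirror the template:

    import pandas as pd

    input_cols = ["text"]                  # hypothetical feature column
    param_cols = ["temperature"]           # hypothetical inference parameter
    param_defaults = {"temperature": 0.7}  # default taken from the method signature

    def runner(input_df, **method_params):  # stand-in for the packaged model method
        return input_df.assign(temperature=method_params["temperature"])

    df = pd.DataFrame([["hello", None], ["world", None]])
    df.columns = input_cols + param_cols
    input_df = df[input_cols]

    # Param value comes from the first row; None/NaN falls back to the signature default
    method_params = {}
    for col in param_cols:
        val = df[col].iloc[0]
        method_params[col] = param_defaults[col] if val is None or pd.isna(val) else val

    print(runner(input_df, **method_params))  # temperature column filled with 0.7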
@@ -45,11 +45,29 @@ features = meta.signatures[TARGET_METHOD].inputs
  input_cols = [feature.name for feature in features]
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+ # Load inference parameters from method signature (if any)
+ param_cols = []
+ param_defaults = {{}}
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+ for param_spec in meta.signatures[TARGET_METHOD].params:
+ param_cols.append(param_spec.name)
+ param_defaults[param_spec.name] = param_spec.default_value
+

  # Actual table function
  class {function_name}:
  @vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
  def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
- df.columns = input_cols
- input_df = df.astype(dtype=dtype_map)
- return runner(input_df[input_cols])
+ df.columns = input_cols + param_cols
+ input_df = df[input_cols].astype(dtype=dtype_map)
+
+ # Extract runtime param values, using defaults if None
+ method_params = {{}}
+ for col in param_cols:
+ val = df[col].iloc[0]
+ if val is None or pd.isna(val):
+ method_params[col] = param_defaults[col]
+ else:
+ method_params[col] = val
+
+ return runner(input_df, **method_params)
@@ -40,11 +40,29 @@ features = meta.signatures[TARGET_METHOD].inputs
  input_cols = [feature.name for feature in features]
  dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

+ # Load inference parameters from method signature (if any)
+ param_cols = []
+ param_defaults = {{}}
+ if hasattr(meta.signatures[TARGET_METHOD], "params") and meta.signatures[TARGET_METHOD].params:
+ for param_spec in meta.signatures[TARGET_METHOD].params:
+ param_cols.append(param_spec.name)
+ param_defaults[param_spec.name] = param_spec.default_value
+

  # Actual table function
  class {function_name}:
  @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
  def process(self, df: pd.DataFrame) -> pd.DataFrame:
- df.columns = input_cols
- input_df = df.astype(dtype=dtype_map)
- return runner(input_df[input_cols])
+ df.columns = input_cols + param_cols
+ input_df = df[input_cols].astype(dtype=dtype_map)
+
+ # Extract runtime param values, using defaults if None
+ method_params = {{}}
+ for col in param_cols:
+ val = df[col].iloc[0]
+ if val is None or pd.isna(val):
+ method_params[col] = param_defaults[col]
+ else:
+ method_params[col] = val
+
+ return runner(input_df, **method_params)
@@ -156,10 +156,12 @@ class ModelMethod:
  f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
  "Try specifying `case_sensitive` as True."
  ) from e
+ # Convert None to "NULL" string so MANIFEST parser can interpret it as SQL NULL
+ default_value = "NULL" if param_spec.default_value is None else str(param_spec.default_value)
  return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
  name=param_name.resolved(),
  type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
- default=param_spec.default_value,
+ default=default_value,
  )

  def save(
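A tiny standalone check of that default-value rule; the sample defaults are hypothetical, only the None-to-"NULL" conversion and the stringification come from the diff:

    # None becomes the literal string "NULL" so the MANIFEST parser reads it as SQL NULL;
    # every other default is stringified before being written to the manifest.
    for default, expected in [(None, "NULL"), (0.7, "0.7"), (128, "128")]:
        default_value = "NULL" if default is None else str(default)
        assert default_value == expected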
@@ -1,3 +1,4 @@
+ import io
  import json
  import logging
  import os
@@ -28,7 +29,10 @@ from snowflake.ml.model._packager.model_meta import (
  model_meta as model_meta_api,
  model_meta_schema,
  )
- from snowflake.ml.model._signatures import utils as model_signature_utils
+ from snowflake.ml.model._signatures import (
+ core as model_signature_core,
+ utils as model_signature_utils,
+ )
  from snowflake.ml.model.models import (
  huggingface as huggingface_base,
  huggingface_pipeline,
@@ -530,7 +534,10 @@ class TransformersPipelineHandler(
  # verify when the target method is __call__ and
  # if the signature is default text-generation signature
  # then use the HuggingFaceOpenAICompatibleModel to wrap the pipeline
- if signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC:
+ if (
+ signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC
+ or signature == openai_signatures._OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING
+ ):
  wrapped_model = HuggingFaceOpenAICompatibleModel(pipeline=raw_model)

  temp_res = X.apply(
@@ -554,6 +561,39 @@ class TransformersPipelineHandler(
  else:
  input_data = X[signature.inputs[0].name].to_list()
  temp_res = getattr(raw_model, target_method)(input_data)
+ elif isinstance(raw_model, transformers.ImageClassificationPipeline):
+ # Image classification expects PIL Images. Convert bytes to PIL Images.
+ from PIL import Image
+
+ input_col = signature.inputs[0].name
+ images = [Image.open(io.BytesIO(img_bytes)) for img_bytes in X[input_col].to_list()]
+ temp_res = getattr(raw_model, target_method)(images)
+ elif isinstance(raw_model, transformers.AutomaticSpeechRecognitionPipeline):
+ # ASR pipeline accepts a single audio input (bytes, str, np.ndarray, or dict),
+ # not a list. Process each audio input individually.
+ input_col = signature.inputs[0].name
+ audio_inputs = X[input_col].to_list()
+ temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
+ elif isinstance(raw_model, transformers.VideoClassificationPipeline):
+ # Video classification expects file paths. Write bytes to temp files,
+ # process them, and clean up.
+ import tempfile
+
+ input_col = signature.inputs[0].name
+ video_bytes_list = X[input_col].to_list()
+ temp_file_paths = []
+ temp_files = []
+ try:
+ # TODO: parallelize this if needed
+ for video_bytes in video_bytes_list:
+ temp_file = tempfile.NamedTemporaryFile()
+ temp_file.write(video_bytes)
+ temp_file_paths.append(temp_file.name)
+ temp_files.append(temp_file)
+ temp_res = getattr(raw_model, target_method)(temp_file_paths)
+ finally:
+ for f in temp_files:
+ f.close()
  else:
  # TODO: remove conversational pipeline code
  # For others, we could offer the whole dataframe as a list.
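A minimal, self-contained sketch of the bytes-to-PIL conversion the new image-classification branch relies on; the tiny in-memory PNG stands in for the raw bytes stored in the input column:

    import io

    from PIL import Image

    # Build a small in-memory PNG to stand in for the image bytes in the input column
    buf = io.BytesIO()
    Image.new("RGB", (4, 4), color=(255, 0, 0)).save(buf, format="PNG")
    img_bytes = buf.getvalue()

    # Same conversion used before handing images to the ImageClassificationPipeline
    image = Image.open(io.BytesIO(img_bytes))
    assert image.size == (4, 4)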
@@ -615,11 +655,14 @@ class TransformersPipelineHandler(
  temp_res = [[conv.generated_responses] for conv in temp_res]

  # To concat those who outputs a list with one input.
- if isinstance(temp_res[0], list):
- if isinstance(temp_res[0][0], dict):
- res = pd.DataFrame({0: temp_res})
- else:
- res = pd.DataFrame(temp_res)
+ # if `signature.outputs` is single valued and is a FeatureGroupSpec,
+ # we create a DataFrame with one column and the values are stored as a dictionary.
+ # Otherwise, we create a DataFrame with the output as the column.
+ if len(signature.outputs) == 1 and isinstance(
+ signature.outputs[0], model_signature_core.FeatureGroupSpec
+ ):
+ # creating a dataframe with one column
+ res = pd.DataFrame({signature.outputs[0].name: temp_res})
  else:
  res = pd.DataFrame(temp_res)

@@ -702,7 +745,6 @@ class HuggingFaceOpenAICompatibleModel:
  self.pipeline = pipeline
  self.model = self.pipeline.model
  self.tokenizer = self.pipeline.tokenizer
-
  self.model_name = self.pipeline.model.name_or_path

  if self.tokenizer.pad_token is None:
@@ -724,11 +766,33 @@ class HuggingFaceOpenAICompatibleModel:
  Returns:
  The formatted prompt string ready for model input.
  """
+
+ final_messages = []
+ for message in messages:
+ if isinstance(message.get("content", ""), str):
+ final_messages.append({"role": message.get("role", "user"), "content": message.get("content", "")})
+ else:
+ # extract only the text from the content
+ # sample data:
+ # {
+ # "role": "user",
+ # "content": [
+ # {"type": "text", "text": "Hello, how are you?"}, # extracted
+ # {"type": "image", "image": "https://example.com/image.png"}, # not extracted
+ # ],
+ # }
+ for content_part in message.get("content", []):
+ if content_part.get("type", "") == "text":
+ final_messages.append(
+ {"role": message.get("role", "user"), "content": content_part.get("text", "")}
+ )
+ # TODO: implement other content types
+
  # Use the tokenizer's apply_chat_template method.
  # We ensured a template exists in __init__.
  if hasattr(self.tokenizer, "apply_chat_template"):
  return self.tokenizer.apply_chat_template( # type: ignore[no-any-return]
- messages,
+ final_messages,
  tokenize=False,
  add_generation_prompt=True,
  )
@@ -736,7 +800,7 @@ class HuggingFaceOpenAICompatibleModel:
  # Fallback for very old transformers without apply_chat_template
  # Manually apply ChatML-like formatting
  prompt = ""
- for message in messages:
+ for message in final_messages:
  role = message.get("role", "user")
  content = message.get("content", "")
  prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
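A quick illustration of the new content normalization, reusing the message shape from the sample-data comment above; the snippet simply replays the same extraction logic outside the class, so the expected result follows directly from the diff:

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello, how are you?"},
                {"type": "image", "image": "https://example.com/image.png"},
            ],
        }
    ]

    final_messages = []
    for message in messages:
        content = message.get("content", "")
        if isinstance(content, str):
            final_messages.append({"role": message.get("role", "user"), "content": content})
        else:
            # only the text parts survive; image and other content types are dropped for now
            for content_part in content:
                if content_part.get("type", "") == "text":
                    final_messages.append({"role": message.get("role", "user"), "content": content_part.get("text", "")})

    assert final_messages == [{"role": "user", "content": "Hello, how are you?"}]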