snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/ml/_internal/platform_capabilities.py +0 -4
  2. snowflake/ml/_internal/utils/mixins.py +26 -1
  3. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  4. snowflake/ml/data/data_connector.py +2 -2
  5. snowflake/ml/data/data_ingestor.py +2 -1
  6. snowflake/ml/experiment/_experiment_info.py +3 -3
  7. snowflake/ml/feature_store/__init__.py +2 -0
  8. snowflake/ml/feature_store/aggregation.py +367 -0
  9. snowflake/ml/feature_store/feature.py +366 -0
  10. snowflake/ml/feature_store/feature_store.py +234 -20
  11. snowflake/ml/feature_store/feature_view.py +189 -4
  12. snowflake/ml/feature_store/metadata_manager.py +425 -0
  13. snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
  14. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  15. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  16. snowflake/ml/jobs/_interop/protocols.py +124 -7
  17. snowflake/ml/jobs/_interop/utils.py +92 -33
  18. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  19. snowflake/ml/jobs/_utils/constants.py +4 -0
  20. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  21. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  22. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  23. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  24. snowflake/ml/jobs/decorators.py +17 -22
  25. snowflake/ml/jobs/job.py +25 -10
  26. snowflake/ml/jobs/job_definition.py +100 -8
  27. snowflake/ml/model/__init__.py +4 -0
  28. snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
  29. snowflake/ml/model/_client/model/model_version_impl.py +56 -28
  30. snowflake/ml/model/_client/ops/model_ops.py +2 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +6 -11
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  34. snowflake/ml/model/_client/sql/service.py +21 -29
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
  36. snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
  37. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
  38. snowflake/ml/model/_signatures/utils.py +76 -1
  39. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  40. snowflake/ml/model/openai_signatures.py +154 -0
  41. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  42. snowflake/ml/version.py +1 -1
  43. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
  44. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
  45. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  46. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  47. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  48. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/sql/service.py CHANGED
@@ -75,10 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
     DESC_SERVICE_SPEC_COL_NAME = "spec"
     DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
     DESC_SERVICE_NAME_SPEC_NAME = "name"
-    DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
-    PROXY_CONTAINER_NAME = "proxy"
+    DESC_SERVICE_ENV_SPEC_NAME = "env"
     MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
-    FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
 
     @contextlib.contextmanager
     def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -285,39 +283,33 @@ class ServiceSQLClient(_base._BaseSQLClient):
         )
         return rows[0]
 
-    def get_proxy_container_autocapture(self, row: row.Row) -> bool:
-        """Extract whether service has autocapture enabled from proxy container spec.
+    def is_autocapture_enabled(self, row: row.Row) -> bool:
+        """Extract whether service has autocapture enabled in any container from service spec.
 
         Args:
             row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
 
         Returns:
-            True if autocapture is enabled in proxy spec
-            False if disabled or not set in proxy spec
-            False if service doesn't have proxy container
+            True if autocapture is enabled in any container.
+            False if autocapture is disabled or not set in any container.
         """
-        try:
-            spec_yaml = row[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME]
-            if spec_yaml is None:
-                return False
-            spec_raw = yaml.safe_load(spec_yaml)
-            if spec_raw is None:
-                return False
-            spec = cast(dict[str, Any], spec_raw)
-
-            proxy_container_spec = next(
-                container
-                for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
-                    ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
-                ]
-                if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
-            )
-            env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
-            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
-            return str(autocapture_enabled).lower() == "true"
-
-        except StopIteration:
+        spec_yaml = row.as_dict().get(ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME)
+        if spec_yaml is None:
             return False
+        spec_raw = yaml.safe_load(spec_yaml)
+        if spec_raw is None:
+            return False
+        spec = cast(dict[str, Any], spec_raw)
+
+        containers = spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
+            ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
+        ]
+        for container in containers:
+            env = container.get(ServiceSQLClient.DESC_SERVICE_ENV_SPEC_NAME, {})
+            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
+            if str(autocapture_enabled).lower() == "true":
+                return True
+        return False
 
     def drop_service(
         self,
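
For reference, the replacement reduces to a YAML walk over every container's env block. A minimal standalone sketch, assuming spec_yaml is the text of the "spec" column returned by DESCRIBE SERVICE:

    import yaml

    AUTOCAPTURE_ENV = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"

    def autocapture_enabled(spec_yaml):
        # Missing or empty spec means autocapture is off.
        if not spec_yaml:
            return False
        spec = yaml.safe_load(spec_yaml) or {}
        containers = spec.get("spec", {}).get("containers", [])
        # Enabled as soon as any container sets the env var to "true" (case-insensitive).
        return any(
            str(c.get("env", {}).get(AUTOCAPTURE_ENV, "false")).lower() == "true"
            for c in containers
        )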
snowflake/ml/model/_model_composer/model_method/model_method.py CHANGED
@@ -156,7 +156,8 @@ class ModelMethod:
                 f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
                 "Try specifying `case_sensitive` as True."
             ) from e
-        default_value = param_spec.default_value if param_spec.default_value is None else str(param_spec.default_value)
+        # Convert None to "NULL" string so MANIFEST parser can interpret it as SQL NULL
+        default_value = "NULL" if param_spec.default_value is None else str(param_spec.default_value)
         return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
             name=param_name.resolved(),
             type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
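
Taken on its own, the changed line maps defaults like so (sample values, for illustration only):

    for default in (None, 0.7, 250, True):
        print("NULL" if default is None else str(default))
    # -> NULL, 0.7, 250, True; the MANIFEST parser reads "NULL" as SQL NULL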
snowflake/ml/model/_packager/model_handlers/huggingface.py CHANGED
@@ -574,6 +574,26 @@ class TransformersPipelineHandler(
             input_col = signature.inputs[0].name
             audio_inputs = X[input_col].to_list()
             temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
+        elif isinstance(raw_model, transformers.VideoClassificationPipeline):
+            # Video classification expects file paths. Write bytes to temp files,
+            # process them, and clean up.
+            import tempfile
+
+            input_col = signature.inputs[0].name
+            video_bytes_list = X[input_col].to_list()
+            temp_file_paths = []
+            temp_files = []
+            try:
+                # TODO: parallelize this if needed
+                for video_bytes in video_bytes_list:
+                    temp_file = tempfile.NamedTemporaryFile()
+                    temp_file.write(video_bytes)
+                    temp_file_paths.append(temp_file.name)
+                    temp_files.append(temp_file)
+                temp_res = getattr(raw_model, target_method)(temp_file_paths)
+            finally:
+                for f in temp_files:
+                    f.close()
         else:
             # TODO: remove conversational pipeline code
             # For others, we could offer the whole dataframe as a list.
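
The same bytes-to-path pattern in isolation: a sketch assuming pipe is a transformers video-classification pipeline. The ".mp4" suffix and the explicit flush() are additions here, not part of the handler above:

    import tempfile

    def classify_videos(pipe, video_bytes_list):
        temp_files, paths = [], []
        try:
            for data in video_bytes_list:
                f = tempfile.NamedTemporaryFile(suffix=".mp4")
                f.write(data)
                f.flush()  # make sure the decoder sees the complete file
                temp_files.append(f)
                paths.append(f.name)
            return pipe(paths)
        finally:
            for f in temp_files:
                f.close()  # NamedTemporaryFile unlinks the file on close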
snowflake/ml/model/_packager/model_handlers/sentence_transformers.py CHANGED
@@ -16,6 +16,7 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 if TYPE_CHECKING:
@@ -24,10 +25,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 # Allowlist of supported target methods for SentenceTransformer models.
-_ALLOWED_TARGET_METHODS = ["encode", "encode_queries", "encode_documents"]
+# Note: sentence-transformers >= 3.0 uses singular names (encode_query, encode_document)
+# while older versions may use plural names (encode_queries, encode_documents).
+_ALLOWED_TARGET_METHODS = ["encode", "encode_query", "encode_document", "encode_queries", "encode_documents"]
 
 
-def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
+def _validate_sentence_transformers_signatures(
+    sigs: dict[str, model_signature.ModelSignature],
+) -> None:
     """Validate signatures for SentenceTransformer models.
 
     Args:
@@ -82,7 +87,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     MODEL_BLOB_FILE_OR_DIR = "model"
-    DEFAULT_TARGET_METHODS = ["encode", "encode_queries", "encode_documents"]
+    # Default to singular names which are used in sentence-transformers >= 3.0
+    DEFAULT_TARGET_METHODS = ["encode", "encode_query", "encode_document"]
+    IS_AUTO_SIGNATURE = True
 
     @classmethod
     def can_handle(
@@ -138,7 +145,8 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             raise ValueError(f"target_methods {target_methods} must be a subset of {_ALLOWED_TARGET_METHODS}.")
 
         def get_prediction(
-            target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
+            target_method_name: str,
+            sample_input_data: model_types.SupportedLocalDataType,
         ) -> model_types.SupportedLocalDataType:
             if not isinstance(sample_input_data, pd.DataFrame):
                 sample_input_data = model_signature._convert_local_data_to_df(data=sample_input_data)
@@ -149,8 +157,13 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             )
             X_list = sample_input_data.iloc[:, 0].tolist()
 
-            assert callable(getattr(model, "encode", None))
-            return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
+            # Call the appropriate method based on target_method_name
+            method_to_call = getattr(model, target_method_name, None)
+            if not callable(method_to_call):
+                raise ValueError(
+                    f"SentenceTransformer model does not have a callable method '{target_method_name}'."
+                )
+            return pd.DataFrame({0: method_to_call(X_list, batch_size=batch_size).tolist()})
 
         if model_meta.signatures:
             handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
@@ -171,6 +184,36 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+        else:
+            # Auto-infer signature from model when no sample_input_data is provided
+            # Get the embedding dimension from the model
+            embedding_dim = model.get_sentence_embedding_dimension()
+            if embedding_dim is None:
+                raise ValueError(
+                    "Unable to auto-infer signature: model.get_sentence_embedding_dimension() returned None. "
+                    "Please provide sample_input_data or signatures explicitly."
+                )
+
+            for target_method in target_methods:
+                # target_methods are already validated as callable by get_target_methods()
+                inferred_sig = model_signature_utils.sentence_transformers_signature_auto_infer(
+                    target_method=target_method,
+                    embedding_dim=embedding_dim,
+                )
+                if inferred_sig is None:
+                    raise ValueError(
+                        f"Unable to auto-infer signature for method '{target_method}'. "
+                        "Please provide sample_input_data or signatures explicitly."
+                    )
+                model_meta.signatures[target_method] = inferred_sig
+
+            # Ensure at least one method was successfully inferred
+            if not model_meta.signatures:
+                raise ValueError(
+                    "No valid target methods found on the model. "
+                    "Please provide sample_input_data or signatures explicitly, "
+                    "or specify target_methods that exist on your model."
+                )
 
         _validate_sentence_transformers_signatures(model_meta.signatures)
 
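With IS_AUTO_SIGNATURE and this new else branch, a SentenceTransformer can be logged without sample data. A sketch, assuming an existing Snowpark session and the public Registry API:

    from sentence_transformers import SentenceTransformer
    from snowflake.ml.registry import Registry

    model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dim embeddings

    reg = Registry(session=session)
    mv = reg.log_model(
        model,
        model_name="minilm_embedder",
        version_name="v1",
        # no sample_input_data: the handler infers one STRING -> DOUBLE[384]
        # signature per target method from get_sentence_embedding_dimension()
    )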
@@ -196,7 +239,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
 
         model_meta.env.include_if_absent(
             [
-                model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
+                model_env.ModelDependency(
+                    requirement="sentence-transformers",
+                    pip_name="sentence-transformers",
+                ),
                 model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
                 model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
             ],
@@ -205,7 +251,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
 
     @staticmethod
-    def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
+    def _get_device_config(
+        **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
+    ) -> Optional[str]:
         if kwargs.get("device", None) is not None:
             return kwargs["device"]
         elif kwargs.get("use_gpu", False):
@@ -262,7 +310,8 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         model_meta: model_meta_api.ModelMetadata,
     ) -> type[custom_model.CustomModel]:
         batch_size = cast(
-            model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
+            model_meta_schema.SentenceTransformersModelBlobOptions,
+            model_meta.models[model_meta.name].options,
         ).get("batch_size", None)
 
         def get_prediction(
@@ -270,12 +319,20 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             signature: model_signature.ModelSignature,
             target_method: str,
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+            # Capture target_method in closure to call the correct model method
+            method_to_call = getattr(raw_model, target_method, None)
+            if not callable(method_to_call):
+                raise ValueError(
+                    f"SentenceTransformer model does not have a callable method '{target_method}'. "
+                    f"This method may not be available in your version of sentence-transformers."
+                )
+
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                 X_list = X.iloc[:, 0].tolist()
 
                 return pd.DataFrame(
-                    {signature.outputs[0].name: raw_model.encode(X_list, batch_size=batch_size).tolist()}
+                    {signature.outputs[0].name: method_to_call(X_list, batch_size=batch_size).tolist()}
                 )
 
             return fn
@@ -298,7 +355,6 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
             model = raw_model
 
         _SentenceTransformer = _create_custom_model(model, model_meta)
-        sentence_transformers_SentenceTransformer_model = _SentenceTransformer(custom_model.ModelContext())
-        predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
-        assert callable(predict_method)
-        return sentence_transformers_SentenceTransformer_model
+        sentence_transformers_model = _SentenceTransformer(custom_model.ModelContext())
+
+        return sentence_transformers_model
snowflake/ml/model/_signatures/utils.py CHANGED
@@ -298,6 +298,24 @@ def huggingface_pipeline_signature_auto_infer(
                             shape=(-1,),  # Variable length list of chunks
                         ),
                     ],
+                )
+            ],
+        )
+
+    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.VideoClassificationPipeline
+    if task == "video-classification":
+        return core.ModelSignature(
+            inputs=[
+                core.FeatureSpec(name="video", dtype=core.DataType.BYTES),
+            ],
+            outputs=[
+                core.FeatureGroupSpec(
+                    name="labels",
+                    specs=[
+                        core.FeatureSpec(name="label", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    ],
+                    shape=(-1,),
                 ),
             ],
         )
@@ -333,7 +351,11 @@ def huggingface_pipeline_signature_auto_infer(
         )
 
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ImageTextToTextPipeline
-    if task == "image-text-to-text":
+    if task in [
+        "image-text-to-text",
+        "video-text-to-text",
+        "audio-text-to-text",
+    ]:
         if params.get("return_tensors", False):
             raise NotImplementedError(
                 f"Auto deployment for HuggingFace pipeline {task} "
@@ -461,3 +483,56 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
 
 def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
     return series is None or series.empty
+
+
+def sentence_transformers_signature_auto_infer(
+    target_method: str,
+    embedding_dim: int,
+) -> Optional[core.ModelSignature]:
+    """Auto-infer signature for SentenceTransformer models.
+
+    SentenceTransformer models have a simple signature: they take a string input
+    and return an embedding vector (array of floats).
+
+    Args:
+        target_method: The target method name. Supported methods:
+            - "encode": General encoding method
+            - "encode_query" / "encode_queries": Query encoding for asymmetric search
+            - "encode_document" / "encode_documents": Document encoding for asymmetric search
+        embedding_dim: The dimension of the embedding vector output by the model.
+
+    Returns:
+        A ModelSignature for the target method, or None if the method is not supported.
+
+    Note:
+        sentence-transformers >= 3.0 uses singular names (encode_query, encode_document)
+        while older versions may use plural names (encode_queries, encode_documents).
+        Both naming conventions are supported for backward compatibility.
+    """
+    # Support both singular (new) and plural (old) naming conventions
+    supported_methods = [
+        "encode",
+        "encode_query",
+        "encode_document",
+        "encode_queries",
+        "encode_documents",
+    ]
+
+    if target_method not in supported_methods:
+        return None
+
+    # All SentenceTransformer encode methods have the same signature pattern:
+    # - Input: a single string column
+    # - Output: a single column containing embedding vectors (array of floats)
+    return core.ModelSignature(
+        inputs=[
+            core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+        ],
+        outputs=[
+            core.FeatureSpec(
+                name="output",
+                dtype=core.DataType.DOUBLE,
+                shape=(embedding_dim,),
+            ),
+        ],
+    )
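
Called directly, the helper pins the fixed text-in, vector-out shape; for example:

    from snowflake.ml.model._signatures.utils import sentence_transformers_signature_auto_infer

    sig = sentence_transformers_signature_auto_infer(target_method="encode", embedding_dim=384)
    # sig.inputs  -> [FeatureSpec(name="text", dtype=STRING)]
    # sig.outputs -> [FeatureSpec(name="output", dtype=DOUBLE, shape=(384,))]
    assert sentence_transformers_signature_auto_infer("predict", 384) is None  # unsupported method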
snowflake/ml/model/models/huggingface_pipeline.py CHANGED
@@ -105,6 +105,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
         image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -133,6 +134,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
+            min_instances: Minimum number of instances. Defaults to 0.
             max_instances: Maximum number of instances. Defaults to 1.
             cpu_requests: CPU requests configuration. Defaults to None.
             memory_requests: Memory requests configuration. Defaults to None.
@@ -225,6 +227,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
snowflake/ml/model/openai_signatures.py CHANGED
@@ -88,6 +88,96 @@ _OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
     ],
 )
 
+_OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureGroupSpec(
+                    name="content",
+                    specs=[
+                        core.FeatureSpec(name="type", dtype=core.DataType.STRING),
+                        # Text prompts
+                        core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+                        # Image URL prompts
+                        core.FeatureGroupSpec(
+                            name="image_url",
+                            specs=[
+                                # Base64 encoded image URL or image URL
+                                core.FeatureSpec(name="url", dtype=core.DataType.STRING),
+                                # Image detail level (e.g., "low", "high", "auto")
+                                core.FeatureSpec(name="detail", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                        # Video URL prompts
+                        core.FeatureGroupSpec(
+                            name="video_url",
+                            specs=[
+                                # Base64 encoded video URL
+                                core.FeatureSpec(name="url", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                        # Audio prompts
+                        core.FeatureGroupSpec(
+                            name="input_audio",
+                            specs=[
+                                core.FeatureSpec(name="data", dtype=core.DataType.STRING),
+                                core.FeatureSpec(name="format", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                    ],
+                    shape=(-1,),
+                ),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+    params=[
+        core.ParamSpec(name="temperature", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="max_completion_tokens", dtype=core.DataType.INT64, default_value=250),
+        core.ParamSpec(name="stop", dtype=core.DataType.STRING, default_value=""),
+        core.ParamSpec(name="n", dtype=core.DataType.INT32, default_value=1),
+        core.ParamSpec(name="stream", dtype=core.DataType.BOOL, default_value=False),
+        core.ParamSpec(name="top_p", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+        core.ParamSpec(name="presence_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+    ],
+)
+
 _OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
     inputs=[
         core.FeatureGroupSpec(
@@ -142,6 +232,62 @@ _OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
     ],
 )
 
+_OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+    params=[
+        core.ParamSpec(name="temperature", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="max_completion_tokens", dtype=core.DataType.INT64, default_value=250),
+        core.ParamSpec(name="stop", dtype=core.DataType.STRING, default_value=""),
+        core.ParamSpec(name="n", dtype=core.DataType.INT32, default_value=1),
+        core.ParamSpec(name="stream", dtype=core.DataType.BOOL, default_value=False),
+        core.ParamSpec(name="top_p", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+        core.ParamSpec(name="presence_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+    ],
+)
+
 
 # Refer vLLM documentation: https://docs.vllm.ai/en/stable/serving/openai_compatible_server/#chat-template
 
@@ -152,3 +298,11 @@ OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING = {"__call__": _OPENAI_CHAT_SIG
 # This is the default signature.
 # The content format allows vLLM to handler content parts like text, image, video, audio, file, etc.
 OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
+
+# Use this signature to leverage ParamSpec with the default ChatML template.
+OPENAI_CHAT_WITH_PARAMS_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC}
+
+# Use this signature to leverage ParamSpec with the content format string.
+OPENAI_CHAT_WITH_PARAMS_SIGNATURE_WITH_CONTENT_FORMAT_STRING = {
+    "__call__": _OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC_WITH_CONTENT_FORMAT_STRING
+}
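
A sketch of opting into the params-enabled variant at log time, assuming a Registry instance reg and an OpenAI-compatible chat model already in hand:

    from snowflake.ml.model import openai_signatures

    mv = reg.log_model(
        model,
        model_name="chat_model",
        version_name="v1",
        signatures=openai_signatures.OPENAI_CHAT_WITH_PARAMS_SIGNATURE,
    )
    # Requests can then override temperature, top_p, max_completion_tokens, etc.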
snowflake/ml/registry/_manager/model_parameter_reconciler.py CHANGED
@@ -193,12 +193,11 @@ class ModelParameterReconciler:
         if enable_explainability:
             if only_spcs or not is_warehouse_runnable:
                 raise ValueError(
-                    "`enable_explainability` cannot be set to True when the model is not runnable in WH "
-                    "or the target platforms include SPCS."
+                    "`enable_explainability` cannot be set to True when the model cannot run in Warehouse."
                 )
             elif has_both_platforms:
                 warnings.warn(
-                    ("Explain function will only be available for model deployed to warehouse."),
+                    ("Explain function will only be available for model deployed to Warehouse."),
                     category=UserWarning,
                     stacklevel=2,
                 )
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.23.0"
+VERSION = "1.25.0"