snowflake-ml-python 1.23.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +0 -4
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/feature_store/__init__.py +2 -0
- snowflake/ml/feature_store/aggregation.py +367 -0
- snowflake/ml/feature_store/feature.py +366 -0
- snowflake/ml/feature_store/feature_store.py +234 -20
- snowflake/ml/feature_store/feature_view.py +189 -4
- snowflake/ml/feature_store/metadata_manager.py +425 -0
- snowflake/ml/feature_store/tile_sql_generator.py +1079 -0
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/__init__.py +4 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +38 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -28
- snowflake/ml/model/_client/ops/model_ops.py +2 -8
- snowflake/ml/model/_client/ops/service_ops.py +6 -11
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/service.py +21 -29
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -1
- snowflake/ml/model/_packager/model_handlers/huggingface.py +20 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +70 -14
- snowflake/ml/model/_signatures/utils.py +76 -1
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +79 -2
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +47 -44
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.23.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/sql/service.py
CHANGED
@@ -75,10 +75,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
     DESC_SERVICE_SPEC_COL_NAME = "spec"
     DESC_SERVICE_CONTAINERS_SPEC_NAME = "containers"
     DESC_SERVICE_NAME_SPEC_NAME = "name"
-    DESC_SERVICE_PROXY_SPEC_ENV_NAME = "env"
-    PROXY_CONTAINER_NAME = "proxy"
+    DESC_SERVICE_ENV_SPEC_NAME = "env"
     MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME = "SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED"
-    FEATURE_MODEL_INFERENCE_AUTOCAPTURE = "FEATURE_MODEL_INFERENCE_AUTOCAPTURE"
 
     @contextlib.contextmanager
     def _qmark_paramstyle(self) -> Generator[None, None, None]:
@@ -285,39 +283,33 @@
         )
         return rows[0]
 
-    def
-        """Extract whether service has autocapture enabled
+    def is_autocapture_enabled(self, row: row.Row) -> bool:
+        """Extract whether service has autocapture enabled in any container from service spec.
 
         Args:
            row: A row.Row object from DESCRIBE SERVICE containing the service YAML spec.
 
         Returns:
-            True if autocapture is enabled in the proxy container.
-            False if disabled or not set in the proxy container.
-            False if service doesn't have proxy container.
+            True if autocapture is enabled in any container.
+            False if autocapture is disabled or not set in any container.
         """
-        try:
-            spec_yaml = row.as_dict().get(ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME)
-            if spec_yaml is None:
-                return False
-            spec_raw = yaml.safe_load(spec_yaml)
-            if spec_raw is None:
-                return False
-            spec = cast(dict[str, Any], spec_raw)
-
-            proxy_container_spec = next(
-                container
-                for container in spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
-                    ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
-                ]
-                if container[ServiceSQLClient.DESC_SERVICE_NAME_SPEC_NAME] == ServiceSQLClient.PROXY_CONTAINER_NAME
-            )
-            env = proxy_container_spec.get(ServiceSQLClient.DESC_SERVICE_PROXY_SPEC_ENV_NAME, {})
-            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
-            return str(autocapture_enabled).lower() == "true"
-
-        except StopIteration:
+        spec_yaml = row.as_dict().get(ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME)
+        if spec_yaml is None:
             return False
+        spec_raw = yaml.safe_load(spec_yaml)
+        if spec_raw is None:
+            return False
+        spec = cast(dict[str, Any], spec_raw)
+
+        containers = spec[ServiceSQLClient.DESC_SERVICE_SPEC_COL_NAME][
+            ServiceSQLClient.DESC_SERVICE_CONTAINERS_SPEC_NAME
+        ]
+        for container in containers:
+            env = container.get(ServiceSQLClient.DESC_SERVICE_ENV_SPEC_NAME, {})
+            autocapture_enabled = env.get(ServiceSQLClient.MODEL_INFERENCE_AUTOCAPTURE_ENV_NAME, "false")
+            if str(autocapture_enabled).lower() == "true":
+                return True
+        return False
 
     def drop_service(
         self,
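The rewritten check no longer singles out a dedicated proxy container; it scans every container's env block for the autocapture flag. A minimal standalone sketch of the same scan (constant names inlined, yaml is PyYAML; the spec value is invented for illustration):

    import yaml

    spec_yaml = """
    spec:
      containers:
      - name: model-inference
        env:
          SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED: "true"
      - name: proxy
    """

    def autocapture_enabled(spec_yaml: str) -> bool:
        spec = yaml.safe_load(spec_yaml) or {}
        for container in spec.get("spec", {}).get("containers", []):
            env = container.get("env") or {}
            flag = env.get("SPCS_MODEL_INFERENCE_SERVER__AUTOCAPTURE_ENABLED", "false")
            if str(flag).lower() == "true":
                return True
        return False

    print(autocapture_enabled(spec_yaml))  # True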
snowflake/ml/model/_model_composer/model_method/model_method.py
CHANGED
@@ -156,7 +156,8 @@ class ModelMethod:
                 f"Your parameter {param_spec.name} cannot be resolved as valid SQL identifier. "
                 "Try specifying `case_sensitive` as True."
             ) from e
-
+        # Convert None to "NULL" string so MANIFEST parser can interpret it as SQL NULL
+        default_value = "NULL" if param_spec.default_value is None else str(param_spec.default_value)
         return model_manifest_schema.ModelMethodSignatureFieldWithNameAndDefault(
             name=param_name.resolved(),
             type=type_utils.convert_sp_to_sf_type(param_spec.dtype.as_snowpark_type()),
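The added lines make a parameter default of None serialize as the literal string "NULL" in the MANIFEST, which its parser then reads back as SQL NULL. The mapping in isolation (a toy illustration, not the real manifest writer):

    def manifest_default(default_value: object) -> str:
        # None becomes the string "NULL" so the MANIFEST parser emits SQL NULL;
        # everything else is stringified as before.
        return "NULL" if default_value is None else str(default_value)

    print(manifest_default(None))   # NULL
    print(manifest_default(0.7))    # 0.7
    print(manifest_default("low"))  # low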
snowflake/ml/model/_packager/model_handlers/huggingface.py
CHANGED
@@ -574,6 +574,26 @@ class TransformersPipelineHandler(
                 input_col = signature.inputs[0].name
                 audio_inputs = X[input_col].to_list()
                 temp_res = [getattr(raw_model, target_method)(audio) for audio in audio_inputs]
+            elif isinstance(raw_model, transformers.VideoClassificationPipeline):
+                # Video classification expects file paths. Write bytes to temp files,
+                # process them, and clean up.
+                import tempfile
+
+                input_col = signature.inputs[0].name
+                video_bytes_list = X[input_col].to_list()
+                temp_file_paths = []
+                temp_files = []
+                try:
+                    # TODO: parallelize this if needed
+                    for video_bytes in video_bytes_list:
+                        temp_file = tempfile.NamedTemporaryFile()
+                        temp_file.write(video_bytes)
+                        temp_file_paths.append(temp_file.name)
+                        temp_files.append(temp_file)
+                    temp_res = getattr(raw_model, target_method)(temp_file_paths)
+                finally:
+                    for f in temp_files:
+                        f.close()
             else:
                 # TODO: remove conversational pipeline code
                 # For others, we could offer the whole dataframe as a list.
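The new branch materializes each video payload as a named temporary file because the pipeline API takes paths. A minimal sketch of the same pattern; note the explicit flush() before handing the path to another reader, since NamedTemporaryFile writes are buffered (illustrative only, not the package's exact code):

    import tempfile

    def classify_videos(classify, video_bytes_list):
        """`classify` is assumed to accept a list of file paths."""
        temp_files, temp_file_paths = [], []
        try:
            for video_bytes in video_bytes_list:
                f = tempfile.NamedTemporaryFile(suffix=".mp4")
                f.write(video_bytes)
                f.flush()  # make sure buffered bytes hit disk before the path is read
                temp_files.append(f)
                temp_file_paths.append(f.name)
            return classify(temp_file_paths)
        finally:
            for f in temp_files:
                f.close()  # NamedTemporaryFile deletes the file on close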
snowflake/ml/model/_packager/model_handlers/sentence_transformers.py
CHANGED
@@ -16,6 +16,7 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
+from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 if TYPE_CHECKING:
@@ -24,10 +25,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 # Allowlist of supported target methods for SentenceTransformer models.
-_ALLOWED_TARGET_METHODS = ["encode", "encode_queries", "encode_documents"]
+# Note: sentence-transformers >= 3.0 uses singular names (encode_query, encode_document)
+# while older versions may use plural names (encode_queries, encode_documents).
+_ALLOWED_TARGET_METHODS = ["encode", "encode_query", "encode_document", "encode_queries", "encode_documents"]
 
 
-def _validate_sentence_transformers_signatures(sigs: dict[str, model_signature.ModelSignature]) -> None:
+def _validate_sentence_transformers_signatures(
+    sigs: dict[str, model_signature.ModelSignature],
+) -> None:
     """Validate signatures for SentenceTransformer models.
 
     Args:
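Newer sentence-transformers releases expose encode_query/encode_document while some releases used the plural forms, so the allowlist now carries both spellings. A quick way to see which of them a given model actually provides (a sketch; requires sentence-transformers installed, model name is just an example):

    from sentence_transformers import SentenceTransformer

    _ALLOWED = ["encode", "encode_query", "encode_document", "encode_queries", "encode_documents"]

    model = SentenceTransformer("all-MiniLM-L6-v2")  # any small model works here
    available = [m for m in _ALLOWED if callable(getattr(model, m, None))]
    print(available)  # e.g. ['encode', 'encode_query', 'encode_document'] on newer releases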
@@ -82,7 +87,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     MODEL_BLOB_FILE_OR_DIR = "model"
-    DEFAULT_TARGET_METHODS = ["encode"]
+    # Default to singular names which are used in sentence-transformers >= 3.0
+    DEFAULT_TARGET_METHODS = ["encode", "encode_query", "encode_document"]
+    IS_AUTO_SIGNATURE = True
 
     @classmethod
     def can_handle(
@@ -138,7 +145,8 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
            raise ValueError(f"target_methods {target_methods} must be a subset of {_ALLOWED_TARGET_METHODS}.")
 
         def get_prediction(
-            target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
+            target_method_name: str,
+            sample_input_data: model_types.SupportedLocalDataType,
         ) -> model_types.SupportedLocalDataType:
             if not isinstance(sample_input_data, pd.DataFrame):
                 sample_input_data = model_signature._convert_local_data_to_df(data=sample_input_data)
@@ -149,8 +157,13 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
             )
             X_list = sample_input_data.iloc[:, 0].tolist()
 
-            assert callable(getattr(model, target_method_name, None))
-            return pd.DataFrame({0: getattr(model, target_method_name)(X_list, batch_size=batch_size).tolist()})
+            # Call the appropriate method based on target_method_name
+            method_to_call = getattr(model, target_method_name, None)
+            if not callable(method_to_call):
+                raise ValueError(
+                    f"SentenceTransformer model does not have a callable method '{target_method_name}'."
+                )
+            return pd.DataFrame({0: method_to_call(X_list, batch_size=batch_size).tolist()})
 
         if model_meta.signatures:
             handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
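The save path now resolves the target method by name and fails with a readable error rather than asserting. The pattern in isolation (a sketch of the general dispatch-by-name idiom):

    def resolve_method(obj, name):
        method = getattr(obj, name, None)
        if not callable(method):
            raise ValueError(f"{type(obj).__name__} does not have a callable method '{name}'.")
        return method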
@@ -171,6 +184,36 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+        else:
+            # Auto-infer signature from model when no sample_input_data is provided
+            # Get the embedding dimension from the model
+            embedding_dim = model.get_sentence_embedding_dimension()
+            if embedding_dim is None:
+                raise ValueError(
+                    "Unable to auto-infer signature: model.get_sentence_embedding_dimension() returned None. "
+                    "Please provide sample_input_data or signatures explicitly."
+                )
+
+            for target_method in target_methods:
+                # target_methods are already validated as callable by get_target_methods()
+                inferred_sig = model_signature_utils.sentence_transformers_signature_auto_infer(
+                    target_method=target_method,
+                    embedding_dim=embedding_dim,
+                )
+                if inferred_sig is None:
+                    raise ValueError(
+                        f"Unable to auto-infer signature for method '{target_method}'. "
+                        "Please provide sample_input_data or signatures explicitly."
+                    )
+                model_meta.signatures[target_method] = inferred_sig
+
+            # Ensure at least one method was successfully inferred
+            if not model_meta.signatures:
+                raise ValueError(
+                    "No valid target methods found on the model. "
+                    "Please provide sample_input_data or signatures explicitly, "
+                    "or specify target_methods that exist on your model."
+                )
 
         _validate_sentence_transformers_signatures(model_meta.signatures)
 
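With this branch, a SentenceTransformer can be saved without sample_input_data: each requested method gets a fixed string-in / fixed-width-vector-out signature derived from the model's embedding width. Roughly what that enables at the registry level (a sketch; `reg` is assumed to be an existing snowflake.ml.registry.Registry, model name is just an example):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")
    print(model.get_sentence_embedding_dimension())  # 384 -> inferred output shape (384,)

    # No sample_input_data needed; signatures for "encode" (and
    # encode_query/encode_document where available) are inferred.
    mv = reg.log_model(model, model_name="minilm_embedder", version_name="v1")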
@@ -196,7 +239,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
 
         model_meta.env.include_if_absent(
             [
-                model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
+                model_env.ModelDependency(
+                    requirement="sentence-transformers",
+                    pip_name="sentence-transformers",
+                ),
                 model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
                 model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
             ],
@@ -205,7 +251,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
         model_meta.env.cuda_version = kwargs.get("cuda_version", handlers_utils.get_default_cuda_version())
 
     @staticmethod
-    def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
+    def _get_device_config(
+        **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
+    ) -> Optional[str]:
         if kwargs.get("device", None) is not None:
             return kwargs["device"]
         elif kwargs.get("use_gpu", False):
@@ -262,7 +310,8 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
         model_meta: model_meta_api.ModelMetadata,
     ) -> type[custom_model.CustomModel]:
         batch_size = cast(
-            model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
+            model_meta_schema.SentenceTransformersModelBlobOptions,
+            model_meta.models[model_meta.name].options,
         ).get("batch_size", None)
 
         def get_prediction(
@@ -270,12 +319,20 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
             signature: model_signature.ModelSignature,
             target_method: str,
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+            # Capture target_method in closure to call the correct model method
+            method_to_call = getattr(raw_model, target_method, None)
+            if not callable(method_to_call):
+                raise ValueError(
+                    f"SentenceTransformer model does not have a callable method '{target_method}'. "
+                    f"This method may not be available in your version of sentence-transformers."
+                )
+
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                 X_list = X.iloc[:, 0].tolist()
 
                 return pd.DataFrame(
-                    {signature.outputs[0].name: raw_model.encode(X_list, batch_size=batch_size).tolist()}
+                    {signature.outputs[0].name: method_to_call(X_list, batch_size=batch_size).tolist()}
                 )
 
             return fn
@@ -298,7 +355,6 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
             model = raw_model
 
         _SentenceTransformer = _create_custom_model(model, model_meta)
-        sentence_transformers_SentenceTransformer_model = _SentenceTransformer(
-            custom_model.ModelContext()
-        )
-        return sentence_transformers_SentenceTransformer_model
+        sentence_transformers_model = _SentenceTransformer(custom_model.ModelContext())
+
+        return sentence_transformers_model
snowflake/ml/model/_signatures/utils.py
CHANGED
@@ -298,6 +298,24 @@ def huggingface_pipeline_signature_auto_infer(
                             shape=(-1,),  # Variable length list of chunks
                         ),
                     ],
+                ),
+            ],
+        )
+
+    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.VideoClassificationPipeline
+    if task == "video-classification":
+        return core.ModelSignature(
+            inputs=[
+                core.FeatureSpec(name="video", dtype=core.DataType.BYTES),
+            ],
+            outputs=[
+                core.FeatureGroupSpec(
+                    name="labels",
+                    specs=[
+                        core.FeatureSpec(name="label", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    ],
+                    shape=(-1,),
                 ),
             ],
         )
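Under the new signature, callers pass raw video bytes in a single BYTES column and get back a variable-length list of label/score structs. A sketch of the input frame and the shape of one output row (file name and values invented):

    import pandas as pd

    video_bytes = open("clip.mp4", "rb").read()  # hypothetical local file
    X = pd.DataFrame({"video": [video_bytes]})

    # One output row mirrors FeatureGroupSpec("labels", shape=(-1,)):
    # a variable-length list of {label, score} entries (values invented).
    row = {"labels": [{"label": "surfing", "score": 0.93},
                      {"label": "swimming", "score": 0.04}]}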
@@ -333,7 +351,11 @@
         )
 
     # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ImageTextToTextPipeline
-    if task == "image-text-to-text":
+    if task in [
+        "image-text-to-text",
+        "video-text-to-text",
+        "audio-text-to-text",
+    ]:
         if params.get("return_tensors", False):
             raise NotImplementedError(
                 f"Auto deployment for HuggingFace pipeline {task} "
@@ -461,3 +483,56 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
 
 def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
     return series is None or series.empty
+
+
+def sentence_transformers_signature_auto_infer(
+    target_method: str,
+    embedding_dim: int,
+) -> Optional[core.ModelSignature]:
+    """Auto-infer signature for SentenceTransformer models.
+
+    SentenceTransformer models have a simple signature: they take a string input
+    and return an embedding vector (array of floats).
+
+    Args:
+        target_method: The target method name. Supported methods:
+            - "encode": General encoding method
+            - "encode_query" / "encode_queries": Query encoding for asymmetric search
+            - "encode_document" / "encode_documents": Document encoding for asymmetric search
+        embedding_dim: The dimension of the embedding vector output by the model.
+
+    Returns:
+        A ModelSignature for the target method, or None if the method is not supported.
+
+    Note:
+        sentence-transformers >= 3.0 uses singular names (encode_query, encode_document)
+        while older versions may use plural names (encode_queries, encode_documents).
+        Both naming conventions are supported for backward compatibility.
+    """
+    # Support both singular (new) and plural (old) naming conventions
+    supported_methods = [
+        "encode",
+        "encode_query",
+        "encode_document",
+        "encode_queries",
+        "encode_documents",
+    ]
+
+    if target_method not in supported_methods:
+        return None
+
+    # All SentenceTransformer encode methods have the same signature pattern:
+    # - Input: a single string column
+    # - Output: a single column containing embedding vectors (array of floats)
+    return core.ModelSignature(
+        inputs=[
+            core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+        ],
+        outputs=[
+            core.FeatureSpec(
+                name="output",
+                dtype=core.DataType.DOUBLE,
+                shape=(embedding_dim,),
+            ),
+        ],
+    )
snowflake/ml/model/models/huggingface_pipeline.py
CHANGED
@@ -105,6 +105,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
         image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
+        min_instances: int = 0,
         max_instances: int = 1,
         cpu_requests: Optional[str] = None,
         memory_requests: Optional[str] = None,
@@ -133,6 +134,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
            image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                the service compute pool if None.
            ingress_enabled: Whether ingress is enabled. Defaults to False.
+           min_instances: Minimum number of instances. Defaults to 0.
            max_instances: Maximum number of instances. Defaults to 1.
            cpu_requests: CPU requests configuration. Defaults to None.
            memory_requests: Memory requests configuration. Defaults to None.
@@ -225,6 +227,7 @@ class HuggingFacePipelineModel(huggingface.TransformersPipeline):
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
             image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
+            min_instances=min_instances,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
             memory_requests=memory_requests,
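min_instances now threads through alongside the existing max_instances knob, so a service can keep warm capacity instead of scaling to zero. A hypothetical call showing just the two knobs (the entry point name and other arguments are assumptions, not taken from this diff):

    # Sketch: keep one warm instance while allowing scale-out to three.
    service = pipeline_model.create_service(   # hypothetical entry point
        service_name="hf_pipeline_svc",
        service_compute_pool="my_pool",
        ingress_enabled=True,
        min_instances=1,
        max_instances=3,
    )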
snowflake/ml/model/openai_signatures.py
CHANGED
@@ -88,6 +88,96 @@ _OPENAI_CHAT_SIGNATURE_SPEC = core.ModelSignature(
     ],
 )
 
+_OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureGroupSpec(
+                    name="content",
+                    specs=[
+                        core.FeatureSpec(name="type", dtype=core.DataType.STRING),
+                        # Text prompts
+                        core.FeatureSpec(name="text", dtype=core.DataType.STRING),
+                        # Image URL prompts
+                        core.FeatureGroupSpec(
+                            name="image_url",
+                            specs=[
+                                # Base64 encoded image URL or image URL
+                                core.FeatureSpec(name="url", dtype=core.DataType.STRING),
+                                # Image detail level (e.g., "low", "high", "auto")
+                                core.FeatureSpec(name="detail", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                        # Video URL prompts
+                        core.FeatureGroupSpec(
+                            name="video_url",
+                            specs=[
+                                # Base64 encoded video URL
+                                core.FeatureSpec(name="url", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                        # Audio prompts
+                        core.FeatureGroupSpec(
+                            name="input_audio",
+                            specs=[
+                                core.FeatureSpec(name="data", dtype=core.DataType.STRING),
+                                core.FeatureSpec(name="format", dtype=core.DataType.STRING),
+                            ],
+                        ),
+                    ],
+                    shape=(-1,),
+                ),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+    params=[
+        core.ParamSpec(name="temperature", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="max_completion_tokens", dtype=core.DataType.INT64, default_value=250),
+        core.ParamSpec(name="stop", dtype=core.DataType.STRING, default_value=""),
+        core.ParamSpec(name="n", dtype=core.DataType.INT32, default_value=1),
+        core.ParamSpec(name="stream", dtype=core.DataType.BOOL, default_value=False),
+        core.ParamSpec(name="top_p", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+        core.ParamSpec(name="presence_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+    ],
+)
+
 _OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
     inputs=[
         core.FeatureGroupSpec(
@@ -142,6 +232,62 @@ _OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
     ],
 )
 
+_OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC_WITH_CONTENT_FORMAT_STRING = core.ModelSignature(
+    inputs=[
+        core.FeatureGroupSpec(
+            name="messages",
+            specs=[
+                core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="title", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+    ],
+    outputs=[
+        core.FeatureSpec(name="id", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="object", dtype=core.DataType.STRING),
+        core.FeatureSpec(name="created", dtype=core.DataType.FLOAT),
+        core.FeatureSpec(name="model", dtype=core.DataType.STRING),
+        core.FeatureGroupSpec(
+            name="choices",
+            specs=[
+                core.FeatureSpec(name="index", dtype=core.DataType.INT32),
+                core.FeatureGroupSpec(
+                    name="message",
+                    specs=[
+                        core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="name", dtype=core.DataType.STRING),
+                        core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    ],
+                ),
+                core.FeatureSpec(name="logprobs", dtype=core.DataType.STRING),
+                core.FeatureSpec(name="finish_reason", dtype=core.DataType.STRING),
+            ],
+            shape=(-1,),
+        ),
+        core.FeatureGroupSpec(
+            name="usage",
+            specs=[
+                core.FeatureSpec(name="completion_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="prompt_tokens", dtype=core.DataType.INT32),
+                core.FeatureSpec(name="total_tokens", dtype=core.DataType.INT32),
+            ],
+        ),
+    ],
+    params=[
+        core.ParamSpec(name="temperature", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="max_completion_tokens", dtype=core.DataType.INT64, default_value=250),
+        core.ParamSpec(name="stop", dtype=core.DataType.STRING, default_value=""),
+        core.ParamSpec(name="n", dtype=core.DataType.INT32, default_value=1),
+        core.ParamSpec(name="stream", dtype=core.DataType.BOOL, default_value=False),
+        core.ParamSpec(name="top_p", dtype=core.DataType.DOUBLE, default_value=1.0),
+        core.ParamSpec(name="frequency_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+        core.ParamSpec(name="presence_penalty", dtype=core.DataType.DOUBLE, default_value=0.0),
+    ],
+)
+
 
 # Refer vLLM documentation: https://docs.vllm.ai/en/stable/serving/openai_compatible_server/#chat-template
 
@@ -152,3 +298,11 @@ OPENAI_CHAT_SIGNATURE_WITH_CONTENT_FORMAT_STRING = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC_WITH_CONTENT_FORMAT_STRING}
 # This is the default signature.
 # The content format allows vLLM to handler content parts like text, image, video, audio, file, etc.
 OPENAI_CHAT_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_SPEC}
+
+# Use this signature to leverage ParamSpec with the default ChatML template.
+OPENAI_CHAT_WITH_PARAMS_SIGNATURE = {"__call__": _OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC}
+
+# Use this signature to leverage ParamSpec with the content format string.
+OPENAI_CHAT_WITH_PARAMS_SIGNATURE_WITH_CONTENT_FORMAT_STRING = {
+    "__call__": _OPENAI_CHAT_SIGNATURE_WITH_PARAMS_SPEC_WITH_CONTENT_FORMAT_STRING
+}
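The new exports let a caller opt into ParamSpec-driven inference parameters; per the comments in the hunk, OPENAI_CHAT_WITH_PARAMS_SIGNATURE pairs with the default ChatML template and the _WITH_CONTENT_FORMAT_STRING variant with string content parts. Passing one through as the model signature might look like this (a sketch; `reg` is an assumed Registry instance and `pipeline` the chat model being logged):

    from snowflake.ml.model import openai_signatures

    mv = reg.log_model(
        pipeline,
        model_name="chat_model",
        version_name="v1",
        signatures=openai_signatures.OPENAI_CHAT_WITH_PARAMS_SIGNATURE,
    )

    # At inference time the ParamSpecs become overridable arguments with defaults,
    # e.g. temperature=1.0, max_completion_tokens=250, top_p=1.0.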
snowflake/ml/registry/_manager/model_parameter_reconciler.py
CHANGED
@@ -193,12 +193,11 @@ class ModelParameterReconciler:
         if enable_explainability:
             if only_spcs or not is_warehouse_runnable:
                 raise ValueError(
-                    "`enable_explainability` cannot be set to True when the model cannot run in Warehouse "
-                    "or the target platforms include SPCS."
+                    "`enable_explainability` cannot be set to True when the model cannot run in Warehouse."
                 )
             elif has_both_platforms:
                 warnings.warn(
-                    ("Explain function will only be available for model deployed to warehouse."),
+                    ("Explain function will only be available for model deployed to Warehouse."),
                     category=UserWarning,
                     stacklevel=2,
                 )
snowflake/ml/version.py
CHANGED
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.23.0"
+VERSION = "1.25.0"