snowflake-ml-python 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -0
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/telemetry.py +56 -7
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/_entities/run.py +15 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +123 -13
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/access_manager.py +1 -0
- snowflake/ml/feature_store/feature_store.py +1 -1
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/feature_flags.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +360 -357
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +2 -406
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +8 -9
- snowflake/ml/jobs/manager.py +64 -129
- snowflake/ml/model/_client/model/inference_engine_utils.py +8 -4
- snowflake/ml/model/_client/model/model_version_impl.py +109 -28
- snowflake/ml/model/_client/ops/model_ops.py +32 -6
- snowflake/ml/model/_client/ops/service_ops.py +9 -4
- snowflake/ml/model/_client/sql/service.py +69 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_signatures/core.py +305 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +25 -215
- snowflake/ml/model/type_hints.py +5 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/METADATA +94 -7
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/RECORD +52 -48
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.19.0.dist-info → snowflake_ml_python-1.21.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py}

@@ -29,12 +29,16 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta_schema,
 )
 from snowflake.ml.model._signatures import utils as model_signature_utils
-from snowflake.ml.model.models import huggingface_pipeline
+from snowflake.ml.model.models import (
+    huggingface as huggingface_base,
+    huggingface_pipeline,
+)
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    import torch
     import transformers
 
 DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
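The hunk above adds a `DEFAULT_CHAT_TEMPLATE` constant in ChatML form. As a point of reference (not code from this package), a Jinja2 template of this shape renders an OpenAI-style message list into a prompt string; `CHATML_TEMPLATE` below is an illustrative copy and `jinja2` is assumed to be installed:

```python
import jinja2

# Illustrative ChatML template of the same shape as DEFAULT_CHAT_TEMPLATE above.
CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is deep learning?"},
]

# Renders each message between <|im_start|>/<|im_end|> markers, then opens an assistant turn.
prompt = jinja2.Template(CHATML_TEMPLATE).render(messages=messages, add_generation_prompt=True)
print(prompt)
```

Rendering it yields the `<|im_start|>`/`<|im_end|>`-delimited prompt that the manual fallback added later in this file also produces by hand.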
@@ -77,8 +81,14 @@ def sanitize_output(data: Any) -> Any:
 
 
 @final
-class HuggingFacePipelineHandler(
-    _base.BaseModelHandler[
+class TransformersPipelineHandler(
+    _base.BaseModelHandler[
+        Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ]
+    ]
 ):
     """Handler for custom model."""
 
@@ -97,35 +107,48 @@ class HuggingFacePipelineHandler(
     def can_handle(
         cls,
         model: model_types.SupportedModelType,
-    ) -> TypeGuard[
+    ) -> TypeGuard[
+        Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ]
+    ]:
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             return True
         if isinstance(model, huggingface_pipeline.HuggingFacePipelineModel):
             return True
+        if isinstance(model, huggingface_base.TransformersPipeline):
+            return True
         return False
 
     @classmethod
     def cast_model(
         cls,
         model: model_types.SupportedModelType,
-    ) -> Union[
-
-
-
-
-
-        except ImportError:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+    ) -> Union[
+        huggingface_base.TransformersPipeline,
+        huggingface_pipeline.HuggingFacePipelineModel,
+        "transformers.Pipeline",
+    ]:
+        if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             return model
-
-
+        elif isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+            model, huggingface_base.TransformersPipeline
+        ):
             return model
+        else:
+            raise ValueError(f"Model {model} is not a valid Hugging Face model.")
 
     @classmethod
     def save_model(
         cls,
         name: str,
-        model: Union[
+        model: Union[
+            huggingface_base.TransformersPipeline,
+            huggingface_pipeline.HuggingFacePipelineModel,
+            "transformers.Pipeline",
+        ],
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         sample_input_data: Optional[model_types.SupportedDataType] = None,
@@ -140,7 +163,9 @@ class HuggingFacePipelineHandler(
             framework = model.framework  # type:ignore[attr-defined]
             batch_size = model._batch_size  # type:ignore[attr-defined]
         else:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+                model, huggingface_base.TransformersPipeline
+            )
             task = model.task
             framework = getattr(model, "framework", None)
             batch_size = getattr(model, "batch_size", None)
@@ -156,7 +181,9 @@ class HuggingFacePipelineHandler(
                 **model._postprocess_params,  # type:ignore[attr-defined]
             }
         else:
-            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel)
+            assert isinstance(model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+                model, huggingface_base.TransformersPipeline
+            )
             params = {**model.__dict__, **model.model_kwargs}
 
         inferred_pipe_sig = model_signature_utils.huggingface_pipeline_signature_auto_infer(
@@ -177,7 +204,7 @@ class HuggingFacePipelineHandler(
         else:
             warnings.warn(
                 "It is impossible to validate your model signatures when using a"
-                "
+                f" {type(model).__name__} object. "
                 "Please make sure you are providing correct model signatures.",
                 UserWarning,
                 stacklevel=2,
@@ -302,14 +329,16 @@ class HuggingFacePipelineHandler(
     def _load_pickle_model(
         pickle_file: str,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
-    ) -> huggingface_pipeline.HuggingFacePipelineModel:
+    ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline]:
         with open(pickle_file, "rb") as f:
             m = cloudpickle.load(f)
-        assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
+        assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+            m, huggingface_base.TransformersPipeline
+        )
         torch_dtype: Optional[str] = None
         device_config = None
         if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
-            device_config =
+            device_config = TransformersPipelineHandler._get_device_config(**kwargs)
             m.__dict__.update(device_config)
 
         if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
@@ -326,7 +355,9 @@ class HuggingFacePipelineHandler(
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
-    ) -> Union[
+    ) -> Union[
+        huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline, "transformers.Pipeline"
+    ]:
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             # We need to redirect the some folders to a writable location in the sandbox.
             os.environ["HF_HOME"] = "/tmp"
@@ -369,7 +400,7 @@ class HuggingFacePipelineHandler(
         ) as f:
             pipeline_params = cloudpickle.load(f)
 
-        device_config =
+        device_config = TransformersPipelineHandler._get_device_config(**kwargs)
 
         m = transformers.pipeline(
             model_blob_options["task"],
@@ -402,7 +433,7 @@ class HuggingFacePipelineHandler(
 
     def _create_pipeline_from_model(
         model_blob_file_or_dir_path: str,
-        m: huggingface_pipeline.HuggingFacePipelineModel,
+        m: Union[huggingface_pipeline.HuggingFacePipelineModel, huggingface_base.TransformersPipeline],
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> "transformers.Pipeline":
         import transformers
@@ -414,7 +445,7 @@ class HuggingFacePipelineHandler(
             torch_dtype=getattr(m, "torch_dtype", None),
             revision=m.revision,
             # pass device or device_map when creating the pipeline
-            **
+            **TransformersPipelineHandler._get_device_config(**kwargs),
             # pass other model_kwargs to transformers.pipeline.from_pretrained method
             **m.model_kwargs,
         )
@@ -455,7 +486,11 @@ class HuggingFacePipelineHandler(
     @classmethod
     def convert_as_custom_model(
         cls,
-        raw_model: Union[
+        raw_model: Union[
+            huggingface_pipeline.HuggingFacePipelineModel,
+            huggingface_base.TransformersPipeline,
+            "transformers.Pipeline",
+        ],
         model_meta: model_meta_api.ModelMetadata,
         background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
@@ -609,7 +644,9 @@ class HuggingFacePipelineHandler(
 
             return _HFPipelineModel
 
-        if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel)
+        if isinstance(raw_model, huggingface_pipeline.HuggingFacePipelineModel) or isinstance(
+            raw_model, huggingface_base.TransformersPipeline
+        ):
             if version.parse(transformers.__version__) < version.parse("4.32.0"):
                 # Backward compatibility since HF interface change.
                 raw_model.__dict__["use_auth_token"] = raw_model.__dict__["token"]
@@ -668,43 +705,64 @@ class HuggingFaceOpenAICompatibleModel:
 
         self.model_name = self.pipeline.model.name_or_path
 
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Ensure the tokenizer has a chat template.
+        # If not, we inject the default ChatML template which supports prompt generation.
+        if not getattr(self.tokenizer, "chat_template", None):
+            logger.warning(f"No chat template found for {self.model_name}. Using default ChatML template.")
+            self.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
+
     def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str:
         """
         Applies a chat template to a list of messages.
-        If the tokenizer has a chat template, it uses that.
-        Otherwise, it falls back to a simple concatenation.
 
         Args:
-            messages (list[dict]): A list of message dictionaries
-                [{"role": "user", "content": "Hello!"}, ...]
+            messages (list[dict]): A list of message dictionaries.
 
         Returns:
             The formatted prompt string ready for model input.
         """
-
-
-
-        # `tokenize=False` means it returns a string, not token IDs
+        # Use the tokenizer's apply_chat_template method.
+        # We ensured a template exists in __init__.
+        if hasattr(self.tokenizer, "apply_chat_template"):
             return self.tokenizer.apply_chat_template(  # type: ignore[no-any-return]
                 messages,
                 tokenize=False,
                 add_generation_prompt=True,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Fallback for very old transformers without apply_chat_template
+        # Manually apply ChatML-like formatting
+        prompt = ""
+        for message in messages:
+            role = message.get("role", "user")
+            content = message.get("content", "")
+            prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+        prompt += "<|im_start|>assistant\n"
+        return prompt
+
+    def _get_stopping_criteria(self, stop_strings: list[str]) -> "transformers.StoppingCriteriaList":
+
+        import transformers
+
+        class StopStringsStoppingCriteria(transformers.StoppingCriteria):
+            def __init__(self, stop_strings: list[str], tokenizer: Any) -> None:
+                self.stop_strings = stop_strings
+                self.tokenizer = tokenizer
+
+            def __call__(self, input_ids: "torch.Tensor", scores: "torch.Tensor", **kwargs: Any) -> bool:
+                # Decode the generated text for each sequence
+                for i in range(input_ids.shape[0]):
+                    generated_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True)
+                    # Check if any stop string appears in the generated text
+                    for stop_str in self.stop_strings:
+                        if stop_str in generated_text:
+                            return True
+                return False
+
+        return transformers.StoppingCriteriaList([StopStringsStoppingCriteria(stop_strings, self.tokenizer)])
 
     def generate_chat_completion(
         self,
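The fallback branch above builds the ChatML prompt by hand when `apply_chat_template` is unavailable. A minimal standalone sketch of that formatting (the helper name `format_chatml` is illustrative, not a package function):

```python
# Illustrative sketch of the manual ChatML fallback shown in the diff above.
def format_chatml(messages: list[dict[str, str]]) -> str:
    prompt = ""
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
    # Trailing assistant header asks the model to continue as the assistant.
    prompt += "<|im_start|>assistant\n"
    return prompt


print(format_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is deep learning?"},
]))
```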
@@ -727,18 +785,17 @@ class HuggingFaceOpenAICompatibleModel:
                 {"role": "user", "content": "What is deep learning?"}]
             max_completion_tokens (int): The maximum number of completion tokens to generate.
             stop_strings (list[str]): A list of strings to stop generation.
-            temperature (float): The temperature for sampling.
-            top_p (float): The top-p value for sampling.
-            stream (bool): Whether to stream the generation.
-            frequency_penalty (float): The frequency penalty for sampling.
-            presence_penalty (float): The presence penalty for sampling.
+            temperature (float): The temperature for sampling. 0 means greedy decoding.
+            top_p (float): The top-p value for nucleus sampling.
+            stream (bool): Whether to stream the generation (not yet supported).
+            frequency_penalty (float): The frequency penalty for sampling (maps to repetition_penalty).
+            presence_penalty (float): The presence penalty for sampling (not directly supported).
             n (int): The number of samples to generate.
 
         Returns:
             dict: An OpenAI-compatible dictionary representing the chat completion.
         """
         # Apply chat template to convert messages into a single prompt string
-
         prompt_text = self._apply_chat_template(messages)
 
         # Tokenize the prompt
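The docstring above says the method returns "an OpenAI-compatible dictionary". For orientation, the standard OpenAI chat-completions shape looks like this; the values are made up for illustration and the exact fields emitted by this handler are not shown in the diff:

```python
# Generic OpenAI-style chat completion payload for reference; not output captured from this package.
example_completion = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "model": "my-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Deep learning is ..."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 24, "completion_tokens": 12, "total_tokens": 36},
}
print(example_completion["choices"][0]["message"]["content"])
```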
@@ -749,42 +806,112 @@ class HuggingFaceOpenAICompatibleModel:
         ).to(self.model.device)
         prompt_tokens = inputs.input_ids.shape[1]
 
-
+        if stream:
+            logger.warning(
+                "Streaming is not supported using transformers.Pipeline implementation. Ignoring stream=True."
+            )
+            stream = False
+
+        if presence_penalty is not None:
+            logger.warning(
+                "Presence penalty is not supported using transformers.Pipeline implementation."
+                " Ignoring presence_penalty."
+            )
+            presence_penalty = None
+
+        import transformers
+
+        transformers_version = version.parse(transformers.__version__)
+
+        # Stop strings are supported in transformers >= 4.43.0
+        can_handle_stop_strings = transformers_version >= version.parse("4.43.0")
 
-
-
-
-
+        # Determine sampling based on temperature (following serve.py logic)
+        # Default temperature to 1.0 if not specified
+        actual_temperature = temperature if temperature is not None else 1.0
+        do_sample = actual_temperature > 0.0
+
+        # Set up generation config following best practices from serve.py
+        generation_config = transformers.GenerationConfig(
+            max_new_tokens=max_completion_tokens if max_completion_tokens is not None else 1024,
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=self.tokenizer.eos_token_id,
-
-            stream=stream,
-            num_return_sequences=n,
-            num_beams=max(1, n),  # must be >1
-            repetition_penalty=frequency_penalty,
-            # TODO: Handle diversity_penalty and num_beam_groups
-            # not all models support them making it hard to support any huggingface model
-            # diversity_penalty=presence_penalty if n > 1 else None,
-            # num_beam_groups=max(2, n) if presence_penalty else 1,
-            do_sample=False,
+            do_sample=do_sample,
         )
 
+        # Only set temperature and top_p if sampling is enabled
+        if do_sample:
+            generation_config.temperature = actual_temperature
+            if top_p is not None:
+                generation_config.top_p = top_p
+
+        # Handle repetition penalty (mapped from frequency_penalty)
+        if frequency_penalty is not None:
+            # OpenAI's frequency_penalty is typically in range [-2.0, 2.0]
+            # HuggingFace's repetition_penalty is typically > 0, with 1.0 = no penalty
+            # We need to convert: frequency_penalty=0 -> repetition_penalty=1.0
+            # Higher frequency_penalty should increase repetition_penalty
+            generation_config.repetition_penalty = 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)
+
+        # For multiple completions (n > 1), use sampling not beam search
+        if n > 1:
+            generation_config.num_return_sequences = n
+            # Force sampling on for multiple sequences
+            if not do_sample:
+                logger.warning("Forcing do_sample=True for n>1. Consider setting temperature > 0 for better diversity.")
+                generation_config.do_sample = True
+                generation_config.temperature = 1.0
+        else:
+            generation_config.num_return_sequences = 1
+
+        # Handle stop strings if provided
+        stopping_criteria = None
+        if stop_strings and not can_handle_stop_strings:
+            logger.warning("Stop strings are not supported in transformers < 4.41.0. Ignoring stop strings.")
+
+        if stop_strings and can_handle_stop_strings:
+            stopping_criteria = self._get_stopping_criteria(stop_strings)
+            output_ids = self.model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                generation_config=generation_config,
+                # Pass tokenizer for proper handling of stop strings
+                tokenizer=self.tokenizer,
+                stopping_criteria=stopping_criteria,
+            )
+        else:
+            output_ids = self.model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                generation_config=generation_config,
+            )
+
         # Generate text
-
-
-
-            generation_config=generation_config,
-        )
+        # Handle the case where output might be 1D if n=1
+        if output_ids.dim() == 1:
+            output_ids = output_ids.unsqueeze(0)
 
         generated_texts = []
         completion_tokens = 0
         total_tokens = prompt_tokens
+
         for output_id in output_ids:
             # The output_ids include the input prompt
             # Decode the generated text, excluding the input prompt
             # so we slice to get only new tokens
             generated_tokens = output_id[prompt_tokens:]
             generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+            # Trim stop strings from generated text if they appear
+            # The stop criteria would stop generating further tokens, so we need to trim the generated text
+            if stop_strings and can_handle_stop_strings:
+                for stop_str in stop_strings:
+                    if stop_str in generated_text:
+                        # Find the first occurrence and trim everything from there
+                        stop_idx = generated_text.find(stop_str)
+                        generated_text = generated_text[:stop_idx]
+                        break  # Stop after finding the first stop string
+
             generated_texts.append(generated_text)
 
         # Calculate completion tokens
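The sampling and penalty handling above maps OpenAI-style request parameters onto Hugging Face generation settings. A small self-contained restatement of that mapping (the helper name `map_openai_params` is illustrative, not part of the package):

```python
from typing import Optional


# Illustrative restatement of the OpenAI -> Hugging Face parameter mapping shown in the diff above.
def map_openai_params(temperature: Optional[float], frequency_penalty: Optional[float]) -> dict:
    # temperature=None defaults to 1.0; temperature == 0 selects greedy decoding (no sampling).
    actual_temperature = temperature if temperature is not None else 1.0
    do_sample = actual_temperature > 0.0

    config: dict = {"do_sample": do_sample}
    if do_sample:
        config["temperature"] = actual_temperature
    if frequency_penalty is not None:
        # OpenAI frequency_penalty (~[-2, 2]) maps onto HF repetition_penalty (1.0 = no penalty);
        # non-positive penalties collapse to no penalty.
        config["repetition_penalty"] = 1.0 + (frequency_penalty if frequency_penalty > 0 else 0)
    return config


print(map_openai_params(temperature=0.0, frequency_penalty=0.5))
# {'do_sample': False, 'repetition_penalty': 1.5}
print(map_openai_params(temperature=0.7, frequency_penalty=None))
# {'do_sample': True, 'temperature': 0.7}
```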
snowflake/ml/model/_packager/model_handlers/mlflow.py

@@ -148,12 +148,17 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
 
         file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
+        # MLflow 3.x may return file:// URIs for artifact_path; extract just the last path component
+        artifact_path = model_info.artifact_path
+        if artifact_path.startswith("file://"):
+            artifact_path = artifact_path.rstrip("/").split("/")[-1]
+
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path":
+            options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": artifact_path}),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|