crfm-helm 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +19 -5
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +121 -76
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
- helm/benchmark/annotation/live_qa_annotator.py +84 -0
- helm/benchmark/annotation/medication_qa_annotator.py +81 -0
- helm/benchmark/augmentations/perturbation.py +17 -1
- helm/benchmark/augmentations/test_perturbation.py +30 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/efficiency_metrics.py +9 -2
- helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +104 -21
- helm/benchmark/model_metadata_registry.py +5 -1
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +16 -2
- helm/benchmark/run_expander.py +112 -63
- helm/benchmark/run_spec_factory.py +15 -10
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
- helm/benchmark/run_specs/experimental_run_specs.py +33 -0
- helm/benchmark/run_specs/finance_run_specs.py +33 -0
- helm/benchmark/run_specs/vlm_run_specs.py +444 -65
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/math_scenario.py +1 -1
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +5 -3
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +247 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_finance.yaml +143 -0
- helm/benchmark/static/schema_image2structure.yaml +447 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_tables.yaml +200 -0
- helm/benchmark/static/schema_thai.yaml +223 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/schema_vhelm.yaml +824 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +109 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
- helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/clients/anthropic_client.py +78 -14
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +71 -12
- helm/clients/openai_client.py +11 -5
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +3 -3
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +199 -2
- helm/clients/vertexai_client.py +117 -64
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +12 -4
- helm/clients/vision_language/idefics_client.py +2 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +84 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +29 -3
- helm/config/model_deployments.yaml +504 -12
- helm/config/model_metadata.yaml +579 -52
- helm/config/tokenizer_configs.yaml +100 -1
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/services/server_service.py +1 -1
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +44 -2
- helm/tokenizers/huggingface_tokenizer.py +36 -13
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/schema_vlm.yaml +0 -576
- helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
- helm/benchmark/static_build/assets/index-d839df55.js +0 -9
- helm/benchmark/test_model_deployment_definition.py +0 -90
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/clients/vertexai_client.py
CHANGED
@@ -1,10 +1,9 @@
 import requests
 from abc import ABC, abstractmethod
 from threading import Lock
-from typing import Any, Dict, Optional, List, Union
+from typing import Any, Dict, Mapping, Optional, List, Union
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hlog
 from helm.common.media_object import TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, ErrorFlags
@@ -27,22 +26,62 @@ class VertexAIContentBlockedError(Exception):
     pass
 
 
+class SafetySettingPresets:
+    BLOCK_NONE = "block_none"  # Disable all blocking
+    DEFAULT = "default"  # Use default safety settings
+
+
+def _get_safety_settings_for_preset(
+    safety_settings_preset: Optional[str],
+) -> Optional[Dict[HarmCategory, SafetySetting.HarmBlockThreshold]]:
+    """Get the safety settings for the safety_settings_preset.
+
+    If safety_settings_preset is None, use the default value of BLOCK_NONE (*not* DEFAULT)."""
+    if safety_settings_preset is None or safety_settings_preset == SafetySettingPresets.BLOCK_NONE:
+        return {
+            harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
+            for harm_category in iter(HarmCategory)
+        }
+    elif safety_settings_preset == SafetySettingPresets.DEFAULT:
+        return None
+    else:
+        raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
+
+
+def _get_model_name_for_request(request: Request) -> str:
+    # We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
+    # TODO: Clean up this hack
+    return request.model_engine.split("-safety-")[0]
+
+
 class VertexAIClient(CachingClient, ABC):
     """Client for Vertex AI models"""
 
-    def __init__(
+    def __init__(
+        self, cache_config: CacheConfig, project_id: str, location: str, safety_settings_preset: Optional[str] = None
+    ) -> None:
         super().__init__(cache_config=cache_config)
         self.project_id = project_id
         self.location = location
 
-
-        self.safety_settings
-            harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
-            for harm_category in iter(HarmCategory)
-        }
+        self.safety_settings_preset = safety_settings_preset
+        self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)
 
         vertexai.init(project=self.project_id, location=self.location)
 
+    def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
+        """Construct the key for the cache using the raw request.
+
+        Add `self.safety_settings_preset` to the key, if not None."""
+        if self.safety_settings_preset is not None:
+            assert "safety_settings_preset" not in raw_request
+            return {
+                **CachingClient.make_cache_key(raw_request, request),
+                "safety_settings_preset": self.safety_settings_preset,
+            }
+        else:
+            return CachingClient.make_cache_key(raw_request, request)
+
     @abstractmethod
     def make_request(self, request: Request) -> RequestResult:
         raise NotImplementedError
@@ -72,7 +111,7 @@ class VertexAITextClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request
+        model_name: str = _get_model_name_for_request(request)
 
         try:
 
@@ -88,9 +127,9 @@ class VertexAITextClient(VertexAIClient):
             # We need to include the engine's name to differentiate among requests made for different model
             # engines since the engine name is not included in the request itself.
             # Same for the prompt.
-            cache_key =
+            cache_key = self.make_cache_key_with_safety_settings_preset(
                 {
-                    "engine":
+                    "engine": model_name,
                     "prompt": request.prompt,
                     **parameters,
                 },
@@ -131,12 +170,6 @@ class VertexAITextClient(VertexAIClient):
 class VertexAIChatClient(VertexAIClient):
     """Client for Vertex AI chat models (e.g., Gemini). Supports multimodal prompts."""
 
-    # Set the finish reason to this if the prompt violates the content policy
-    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = "The prompt violates Google's content policy."
-
-    # Gemini returns this error for certain valid requests
-    CONTENT_HAS_NO_PARTS_ERROR: str = "Content has no parts."
-
     # Enum taken from:
     # https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1beta1#google.cloud.aiplatform.v1beta1.Candidate.FinishReason
     # We don't directly import this enum because it can differ between different Vertex AI library versions.
@@ -149,7 +182,7 @@ class VertexAIChatClient(VertexAIClient):
     ]
 
     @staticmethod
-    def get_model(model_name: str) ->
+    def get_model(model_name: str) -> GenerativeModel:
         global _models_lock
         global _models
 
@@ -184,7 +217,7 @@ class VertexAIChatClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request
+        model_name: str = _get_model_name_for_request(request)
         model = self.get_model(model_name)
 
         try:
@@ -202,21 +235,24 @@ class VertexAIChatClient(VertexAIClient):
                 )
                 candidates: List[Candidate] = response.candidates
 
-                # Depending on the version of the Vertex AI library and the type of
-                #
+                # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                # prompt blocking can show up in many ways, so this defensively handles most of these ways
+                if response.prompt_feedback and response.prompt_feedback.block_reason:
+                    raise VertexAIContentBlockedError(
+                        f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                    )
                 if not candidates:
-                    raise VertexAIContentBlockedError("No candidates in response
+                    raise VertexAIContentBlockedError(f"No candidates in response: {response}")
                 predictions: List[Dict[str, Any]] = []
                 for candidate in candidates:
-
-
-
-
-
-
-
-
-                        raise VertexAIContentBlockedError("Content has no parts due to content blocking")
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # content blocking can show up in many ways, so this defensively handles most of these ways
+                    if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                        raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                    if not candidate.content:
+                        raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
+                    if not candidate.content.parts:
+                        raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
                     predictions.append({"text": candidate.content.text})
                     # TODO: Extract more information from the response
                 return {"predictions": predictions}
@@ -224,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
             # We need to include the engine's name to differentiate among requests made for different model
             # engines since the engine name is not included in the request itself.
             # Same for the prompt.
-            cache_key =
+            cache_key = self.make_cache_key_with_safety_settings_preset(
                 {
                     "model_name": model_name,
                     "prompt": request.prompt,
@@ -234,11 +270,11 @@ class VertexAIChatClient(VertexAIClient):
             )
 
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except VertexAIContentBlockedError:
+        except VertexAIContentBlockedError as e:
             return RequestResult(
                 success=False,
                 cached=False,
-                error="
+                error=f"Content blocked: {str(e)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -252,7 +288,7 @@ class VertexAIChatClient(VertexAIClient):
             return RequestResult(
                 success=False,
                 cached=False,
-                error="
+                error=f"Content blocked error in cached response: {str(response)}",
                 completions=[],
                 embedding=[],
                 error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -266,7 +302,7 @@ class VertexAIChatClient(VertexAIClient):
                 return RequestResult(
                     success=False,
                     cached=False,
-                    error="
+                    error=f"Content blocked error in cached prediction: {str(prediction)}",
                     completions=[],
                     embedding=[],
                     error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
@@ -291,21 +327,6 @@ class VertexAIChatClient(VertexAIClient):
         )
 
     def _make_multimodal_request(self, request: Request) -> RequestResult:
-        def complete_for_valid_error(error_message: str) -> RequestResult:
-            empty_completion = GeneratedOutput(
-                text="",
-                logprob=0,
-                tokens=[],
-                finish_reason={"reason": error_message},
-            )
-            return RequestResult(
-                success=True,
-                cached=False,
-                request_time=0,
-                completions=[empty_completion] * request.num_completions,
-                embedding=[],
-            )
-
         # Contents can either be text or a list of multimodal content made up of text, images or other content
         contents: Union[str, List[Union[str, Any]]] = request.prompt
         # Used to generate a unique cache key for this specific request
@@ -334,7 +355,7 @@ class VertexAIChatClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request
+        model_name: str = _get_model_name_for_request(request)
         model = self.get_model(model_name)
 
         request_time = 0
@@ -346,30 +367,62 @@ class VertexAIChatClient(VertexAIClient):
             try:
 
                 def do_it() -> Dict[str, Any]:
-
+                    response: GenerationResponse = model.generate_content(
                         contents, generation_config=parameters, safety_settings=self.safety_settings
                     )
-
-
-
-
-
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # prompt blocking can show up in many ways, so this defensively handles most of these ways
+                    if response.prompt_feedback and response.prompt_feedback.block_reason:
+                        raise VertexAIContentBlockedError(
+                            f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
+                        )
+                    if not response.candidates:
+                        raise VertexAIContentBlockedError(f"No candidates in response: {response}")
+                    # We should only have one candidate
+                    assert (
+                        len(response.candidates) == 1
+                    ), f"Expected 1 candidate since candidate_count is 1, got {len(response.candidates)}."
+                    candidate = response.candidates[0]
+                    # Depending on the version of the Vertex AI library and the type of prompt blocking,
+                    # content blocking can show up in many ways, so this defensively handles most of these ways
+                    if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
+                        raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                    if not candidate.content:
+                        raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
+                    if not candidate.content.parts:
+                        raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
+                    return {"predictions": [{"text": candidate.text}]}
 
                 raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
                 if completion_index > 0:
                     raw_cache_key["completion_index"] = completion_index
 
-                cache_key =
+                cache_key = self.make_cache_key_with_safety_settings_preset(raw_cache_key, request)
                 response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-            except
-                if str(e) == self.CONTENT_HAS_NO_PARTS_ERROR:
-                    return complete_for_valid_error(self.CONTENT_HAS_NO_PARTS_ERROR)
-
+            except requests.exceptions.RequestException as e:
                 error: str = f"Gemini Vision error: {e}"
                 return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+            except VertexAIContentBlockedError as e:
+                return RequestResult(
+                    success=False,
+                    cached=False,
+                    error=f"Content blocked: {str(e)}",
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                )
 
             if "error" in response:
-                return
+                return RequestResult(
+                    success=False,
+                    cached=True,
+                    error=f"Content blocked error in cached response: {str(response)}",
+                    completions=[],
+                    embedding=[],
+                    error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+                    request_time=response["request_time"],
+                    request_datetime=response["request_datetime"],
+                )
 
             response_text = response["predictions"][0]["text"]
             completion = GeneratedOutput(text=response_text, logprob=0, tokens=[])
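Note: the new safety-settings handling above boils down to two pieces: mapping a preset name to per-category thresholds, and folding the preset into the cache key only when it is set (so existing cache entries stay valid). The following is a minimal standalone sketch of that logic, not the client itself; it uses plain strings and dictionaries instead of the Vertex AI HarmCategory/SafetySetting types, and HARM_CATEGORIES, safety_settings_for_preset, and cache_key_with_preset are illustrative names introduced only for this example.

from typing import Dict, Mapping, Optional

# Illustrative stand-ins for the Vertex AI enums used in the real client.
HARM_CATEGORIES = ["HARASSMENT", "HATE_SPEECH", "SEXUALLY_EXPLICIT", "DANGEROUS_CONTENT"]
BLOCK_NONE_THRESHOLD = "BLOCK_NONE"


def safety_settings_for_preset(preset: Optional[str]) -> Optional[Dict[str, str]]:
    """Mirror of the preset logic: None or "block_none" disables blocking, "default" defers to the API."""
    if preset is None or preset == "block_none":
        return {category: BLOCK_NONE_THRESHOLD for category in HARM_CATEGORIES}
    elif preset == "default":
        return None  # Let Vertex AI apply its default safety settings
    raise ValueError(f"Unknown safety_settings_preset: {preset}")


def cache_key_with_preset(base_key: Mapping, preset: Optional[str]) -> Mapping:
    """Only a non-None preset changes the cache key."""
    if preset is not None:
        assert "safety_settings_preset" not in base_key
        return {**base_key, "safety_settings_preset": preset}
    return base_key


print(safety_settings_for_preset("block_none"))  # every category mapped to BLOCK_NONE
print(cache_key_with_preset({"engine": "gemini-pro", "prompt": "hi"}, "default"))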
helm/clients/vision_language/huggingface_vision2seq_client.py
ADDED
@@ -0,0 +1,145 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+from dataclasses import dataclass
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers.image_utils import load_image
+import torch
+
+from helm.common.cache import CacheConfig
+from helm.common.gpu_utils import get_torch_device_name, is_cuda_available
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.request import wrap_request_time
+from helm.common.tokenization_request import TokenizationRequest
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+@dataclass(frozen=True)
+class Vision2SeqModelProcessor:
+    """Loaded model and processor."""
+
+    model: AutoModelForVision2Seq
+    processor: AutoProcessor
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[Vision2SeqModelProcessor]] = {
+    "HuggingFaceM4/idefics2-8b": None,
+}
+
+
+class HuggingFaceVision2SeqClient(CachingClient):
+    """
+    Models for Vision2Seq models from HuggingFace.
+    """
+
+    ASSISTANT_PREFIX: str = "Assistant:"
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, checkpoint: str) -> Vision2SeqModelProcessor:
+        global _models_lock
+        global _models
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            loaded_model_processor = _models[checkpoint]
+            if loaded_model_processor is None:
+                hlog(f"Loading model {checkpoint} and caching in memory...")
+                torch_dtype: torch.dtype = torch.float16 if is_cuda_available() else torch.float32
+                model = AutoModelForVision2Seq.from_pretrained(checkpoint, torch_dtype=torch_dtype).to(self._device)
+                processor = AutoProcessor.from_pretrained(checkpoint)
+
+                _models[checkpoint] = Vision2SeqModelProcessor(model, processor)
+                loaded_model_processor = _models[checkpoint]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.model_deployment in _models, f"Not a valid model for this client: {request.model_deployment}"
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: Vision2SeqModelProcessor = self._get_model(request.model_deployment)
+        model = loaded_model_processor.model
+        processor = loaded_model_processor.processor
+
+        generation_args: Dict[str, Any] = {
+            "max_new_tokens": request.max_tokens,
+        }
+
+        image_paths: List[str] = []
+        multimodal_prompt: List[Dict[str, str]] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                image_paths.append(media_object.location)
+                multimodal_prompt.append({"type": "image"})
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+
+                multimodal_prompt.append({"type": "text", "text": media_object.text})
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {request.multimodal_prompt.text}"):
+            try:
+
+                def do_it() -> Dict[str, Any]:
+                    messages = [{"role": "user", "content": multimodal_prompt}]
+                    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+                    inputs = processor(
+                        text=[prompt] * request.num_completions,
+                        images=[
+                            [load_image(image_path) for image_path in image_paths]
+                            for _ in range(request.num_completions)
+                        ],
+                        return_tensors="pt",
+                    )
+                    inputs = {k: v.to(self._device) for k, v in inputs.items()}
+
+                    # Generate
+                    generated_ids = model.generate(**inputs, **generation_args)
+                    generated_texts: List[str] = processor.batch_decode(generated_ids, skip_special_tokens=True)
+                    return {"output": generated_texts}
+
+                # Include the prompt and model name in the cache key
+                cache_key = CachingClient.make_cache_key(
+                    raw_request={
+                        "n": request.num_completions,
+                        "model": request.model,
+                        "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                        **generation_args,
+                    },
+                    request=request,
+                )
+                result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+            for text in result["output"]:
+                hlog(f"Generated text: {text}")
+                assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+                text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
+                hlog(f"Truncated: {text}")
+                tokenization_result = self.tokenizer.tokenize(
+                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+                )
+                tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )
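Note: the new client builds a Hugging Face chat-template message from the HELM multimodal prompt: one {"type": "image"} placeholder per image (the images themselves are loaded and passed to the processor separately) plus {"type": "text", ...} entries. Below is a small illustrative sketch of that assembly step only; it uses hypothetical (kind, value) tuples as stand-ins for HELM's media objects rather than the real MediaObject class, and no model is loaded.

from typing import Dict, List, Tuple

# Hypothetical stand-in for a HELM multimodal prompt: (kind, value) pairs.
media_objects: List[Tuple[str, str]] = [
    ("image", "/tmp/cat.png"),
    ("text", "What animal is shown in the picture?"),
]

image_paths: List[str] = []
content: List[Dict[str, str]] = []
for kind, value in media_objects:
    if kind == "image":
        image_paths.append(value)          # images go to the processor as a separate argument
        content.append({"type": "image"})  # the chat template only needs an image placeholder
    else:
        content.append({"type": "text", "text": value})

# This is the structure handed to processor.apply_chat_template(..., add_generation_prompt=True).
messages = [{"role": "user", "content": content}]
print(messages)
print(image_paths)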
helm/clients/vision_language/huggingface_vlm_client.py
CHANGED
@@ -25,7 +25,7 @@ except ModuleNotFoundError as e:
 
 class HuggingFaceVLMClient(CachingClient):
     """
-    General
+    General client for VLM models from HuggingFace.
     """
 
     _models_lock: Lock = Lock()
@@ -34,6 +34,11 @@ class HuggingFaceVLMClient(CachingClient):
         "huggingface/llava-1.5-7b-hf": "llava-hf/llava-1.5-7b-hf",
         "huggingface/llava-1.5-13b-hf": "llava-hf/llava-1.5-13b-hf",
         "huggingface/bakLlava-v1-hf": "llava-hf/bakLlava-v1-hf",
+        "huggingface/llava-v1.6-vicuna-7b-hf": "llava-hf/llava-v1.6-vicuna-7b-hf",
+        "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
+        "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
+        "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
+        "huggingface/prometheus-vision-13b-v1.0-hf": "PahaII/prometheus-vision-13b-v1.0-hf",
     }
 
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
@@ -45,7 +50,7 @@ class HuggingFaceVLMClient(CachingClient):
         with self._models_lock:
             model_id: str = self._models_aliases.get(model_name, model_name)
             if model_id not in self._models:
-                self._models[model_id] = pipeline("image-to-text", model=model_id)
+                self._models[model_id] = pipeline("image-to-text", model=model_id, device_map="auto")
             return self._models[model_id]
 
     def make_request(self, request: Request) -> RequestResult:
@@ -90,11 +95,14 @@ class HuggingFaceVLMClient(CachingClient):
         except RuntimeError as e:
             return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
 
+        output: str = result["generated_text"]
+        if "ASSISTANT: " in output:
+            output = output.split("ASSISTANT: ")[1]
         tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
-            TokenizationRequest(
+            TokenizationRequest(output, tokenizer=self.tokenizer_name)
        )
         tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
-        completions: List[GeneratedOutput] = [GeneratedOutput(text=
+        completions: List[GeneratedOutput] = [GeneratedOutput(text=output, logprob=0, tokens=tokens)]
         return RequestResult(
             success=True,
             cached=cached,
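Note: the llava-style conversation templates echo the prompt, so the change above keeps only the text after the first "ASSISTANT: " marker before tokenizing. A tiny standalone sketch of that post-processing step (the raw_output string is made up for illustration):

# Hypothetical raw output from an image-to-text pipeline that echoes the prompt.
raw_output = "USER: <image>\nWhat is in the photo? ASSISTANT: A golden retriever lying on grass."

output = raw_output
if "ASSISTANT: " in output:
    # Keep everything after the marker, matching the split(...)[1] in the diff above.
    output = output.split("ASSISTANT: ")[1]

print(output)  # -> "A golden retriever lying on grass."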
helm/clients/vision_language/idefics_client.py
CHANGED
@@ -88,7 +88,7 @@ class IDEFICSClient(CachingClient):
 
         input_args: Dict[str, Union[str, bool]] = {"return_tensors": "pt"}
         generation_args = {
-            "
+            "max_new_tokens": request.max_tokens,
             "bad_words_ids": processor.tokenizer(self.BAD_WORD_TOKENS, add_special_tokens=False).input_ids,
         }
 
@@ -140,7 +140,7 @@ class IDEFICSClient(CachingClient):
 
             # Truncate the output text as IDEFICS outputs the entire sequence including the prompt
             if "instruct" in request.model:
-                assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output"
+                assert self.ASSISTANT_PREFIX in text, f"Expected {self.ASSISTANT_PREFIX} in the output: {text}"
                 text = text.rpartition(self.ASSISTANT_PREFIX)[-1]
             else:
                 # Best we can do is to remove the text portion of the prompt from the output
helm/clients/vision_language/paligemma_client.py
ADDED
@@ -0,0 +1,146 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+import torch
+from dataclasses import dataclass
+from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+from helm.common.cache import CacheConfig
+from helm.common.images_utils import open_image
+from helm.common.gpu_utils import get_torch_device_name
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.tokenization_request import TokenizationRequest
+from helm.common.request import wrap_request_time
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
+
+try:
+    from PIL import Image
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["images"])
+
+# Added to solve: cutlassF: no kernel found to launch!
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_flash_sdp(False)
+
+
+@dataclass(frozen=True)
+class LoadedPaliGemmaForConditionalGeneration:
+    """Loaded model and processor for PaliGemma."""
+
+    model: PaliGemmaForConditionalGeneration
+    processor: AutoProcessor
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[LoadedPaliGemmaForConditionalGeneration]] = {}
+
+
+class PaliGemmaClient(CachingClient):
+    """
+    PaliGemma is a versatile and lightweight vision-language model (VLM) inspired by PaLI-3
+    and based on open components such as the SigLIP vision model and the Gemma language model.
+    It takes both image and text as input and generates text as output, supporting multiple languages.
+    It is designed for class-leading fine-tune performance on a wide range of vision-language tasks
+    such as image and short video caption, visual question answering, text reading, object detection
+    and object segmentation.
+    """
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, checkpoint: str) -> LoadedPaliGemmaForConditionalGeneration:
+        global _models_lock
+        global _models
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            if checkpoint not in _models or _models[checkpoint] is None:
+                hlog(f"Loading model {checkpoint} and caching in memory...")
+                model = PaliGemmaForConditionalGeneration.from_pretrained(
+                    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+                ).eval()
+                processor = AutoProcessor.from_pretrained(checkpoint)
+                _models[checkpoint] = LoadedPaliGemmaForConditionalGeneration(model, processor)
+            loaded_model_processor = _models[checkpoint]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: LoadedPaliGemmaForConditionalGeneration = self._get_model(request.model_deployment)
+        model = loaded_model_processor.model
+        processor = loaded_model_processor.processor
+        generation_args = {"max_new_tokens": request.max_tokens}
+
+        images: List[Image.Image] = []
+        prompt_pieces: List[str] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                images += [open_image(media_object.location).convert("RGB")]
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+                prompt_pieces.append(media_object.text)
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+        prompt_text: str = "\n".join(prompt_pieces)
+        model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
+        input_len = model_inputs["input_ids"].shape[-1]
+
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {prompt_text}"):
+            try:
+                concat_results = []
+                for i_completion in range(request.num_completions):
+
+                    def do_it() -> Dict[str, Any]:
+                        with torch.inference_mode():
+                            generation = model.generate(
+                                **model_inputs, max_new_tokens=request.max_tokens, do_sample=False
+                            )[0]
+                            if not request.echo_prompt:
+                                generation = generation[input_len:]
+                            decoded = processor.decode(generation, skip_special_tokens=True)
+                        return {"output": decoded}
+
+                    # Include the prompt and model name in the cache key
+                    cache_key = CachingClient.make_cache_key(
+                        raw_request={
+                            "n": request.num_completions,
+                            "i": i_completion,
+                            "model": request.model,
+                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                            **generation_args,
+                        },
+                        request=request,
+                    )
+                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+                    concat_results.append(result)
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+            for result in concat_results:
+                text = result["output"]
+                hlog(f"Generated text: {text}")
+                tokenization_result = self.tokenizer.tokenize(
+                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+                )
+                tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )
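Note: PaliGemma's generate call returns the prompt token ids followed by the newly generated ids, so when echo_prompt is false the client slices off the first input_len ids before decoding. A minimal sketch of just that trimming step, using made-up token ids in place of a real tensor:

from typing import List

# Made-up token ids: the first four stand in for the prompt, the rest for newly generated tokens.
generation: List[int] = [101, 2023, 2003, 1037, 7592, 2088, 102]
input_len = 4  # corresponds to model_inputs["input_ids"].shape[-1] in the real client

echo_prompt = False
if not echo_prompt:
    generation = generation[input_len:]  # keep only the newly generated ids

print(generation)  # -> [7592, 2088, 102]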