crfm-helm 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/METADATA +41 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/RECORD +197 -152
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +32 -31
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
- helm/benchmark/adaptation/common_adapter_specs.py +2 -0
- helm/benchmark/annotation/air_bench_annotator.py +64 -0
- helm/benchmark/annotation/annotator_factory.py +6 -0
- helm/benchmark/annotation/anthropic_red_team_annotator.py +70 -0
- helm/benchmark/annotation/call_center_annotator.py +247 -0
- helm/benchmark/annotation/financebench_annotator.py +79 -0
- helm/benchmark/annotation/harm_bench_annotator.py +68 -0
- helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
- helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
- helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
- helm/benchmark/annotation/live_qa_annotator.py +71 -0
- helm/benchmark/annotation/medication_qa_annotator.py +68 -0
- helm/benchmark/annotation/model_as_judge.py +45 -0
- helm/benchmark/annotation/simple_safety_tests_annotator.py +64 -0
- helm/benchmark/annotation/xstest_annotator.py +110 -0
- helm/benchmark/augmentations/translate_perturbation.py +1 -0
- helm/benchmark/huggingface_registration.py +16 -6
- helm/benchmark/metrics/air_bench_metrics.py +56 -0
- helm/benchmark/metrics/annotation_metrics.py +108 -0
- helm/benchmark/metrics/bhasa_metrics.py +188 -0
- helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
- helm/benchmark/metrics/code_metrics_helper.py +11 -1
- helm/benchmark/metrics/fin_qa_metrics.py +60 -0
- helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
- helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
- helm/benchmark/metrics/live_qa_metrics.py +23 -0
- helm/benchmark/metrics/medication_qa_metrics.py +23 -0
- helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
- helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
- helm/benchmark/metrics/safety_metrics.py +57 -0
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +20 -10
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +30 -72
- helm/benchmark/metrics/vision_language/image_utils.py +1 -1
- helm/benchmark/model_metadata_registry.py +3 -3
- helm/benchmark/presentation/schema.py +54 -4
- helm/benchmark/presentation/test_run_entry.py +1 -0
- helm/benchmark/presentation/test_schema.py +11 -0
- helm/benchmark/run.py +31 -2
- helm/benchmark/run_expander.py +113 -10
- helm/benchmark/run_spec_factory.py +4 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
- helm/benchmark/run_specs/bhasa_run_specs.py +638 -0
- helm/benchmark/run_specs/call_center_run_specs.py +152 -0
- helm/benchmark/run_specs/classic_run_specs.py +15 -11
- helm/benchmark/run_specs/decodingtrust_run_specs.py +11 -9
- helm/benchmark/run_specs/experimental_run_specs.py +85 -0
- helm/benchmark/run_specs/finance_run_specs.py +110 -0
- helm/benchmark/run_specs/safety_run_specs.py +154 -0
- helm/benchmark/run_specs/vlm_run_specs.py +251 -57
- helm/benchmark/scenarios/air_bench_scenario.py +50 -0
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
- helm/benchmark/scenarios/banking77_scenario.py +51 -0
- helm/benchmark/scenarios/bhasa_scenario.py +1798 -0
- helm/benchmark/scenarios/call_center_scenario.py +84 -0
- helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
- helm/benchmark/scenarios/ewok_scenario.py +116 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +119 -0
- helm/benchmark/scenarios/financebench_scenario.py +53 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
- helm/benchmark/scenarios/scenario.py +1 -1
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
- helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
- helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
- helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
- helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
- helm/benchmark/scenarios/test_math_scenario.py +2 -8
- helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
- helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
- helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +13 -2
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -7
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -5
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +44 -13
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +7 -6
- helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +5 -5
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +98 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
- helm/benchmark/scenarios/xstest_scenario.py +35 -0
- helm/benchmark/server.py +1 -6
- helm/benchmark/static/schema_air_bench.yaml +3149 -0
- helm/benchmark/static/schema_bhasa.yaml +709 -0
- helm/benchmark/static/schema_call_center.yaml +232 -0
- helm/benchmark/static/schema_classic.yaml +3 -59
- helm/benchmark/static/schema_cleva.yaml +768 -0
- helm/benchmark/static/schema_decodingtrust.yaml +444 -0
- helm/benchmark/static/schema_ewok.yaml +367 -0
- helm/benchmark/static/schema_finance.yaml +189 -0
- helm/benchmark/static/schema_image2struct.yaml +588 -0
- helm/benchmark/static/schema_instruction_following.yaml +3 -52
- helm/benchmark/static/schema_lite.yaml +3 -61
- helm/benchmark/static/schema_medical.yaml +255 -0
- helm/benchmark/static/schema_mmlu.yaml +3 -61
- helm/benchmark/static/schema_safety.yaml +247 -0
- helm/benchmark/static/schema_tables.yaml +317 -0
- helm/benchmark/static/schema_thai.yaml +244 -0
- helm/benchmark/static/schema_unitxt.yaml +3 -61
- helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +304 -298
- helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
- helm/benchmark/static_build/assets/index-58f97dcd.js +10 -0
- helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
- helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/test_openai_window_service.py +8 -8
- helm/clients/ai21_client.py +71 -1
- helm/clients/anthropic_client.py +50 -28
- helm/clients/auto_client.py +11 -0
- helm/clients/client.py +24 -7
- helm/clients/cohere_client.py +98 -3
- helm/clients/huggingface_client.py +79 -19
- helm/clients/nvidia_nim_client.py +35 -0
- helm/clients/openai_client.py +11 -5
- helm/clients/palmyra_client.py +25 -0
- helm/clients/perspective_api_client.py +11 -6
- helm/clients/reka_client.py +189 -0
- helm/clients/test_client.py +7 -9
- helm/clients/test_huggingface_client.py +19 -3
- helm/clients/test_together_client.py +72 -2
- helm/clients/together_client.py +129 -23
- helm/clients/vertexai_client.py +62 -18
- helm/clients/vision_language/huggingface_vlm_client.py +1 -0
- helm/clients/vision_language/open_flamingo_client.py +1 -2
- helm/clients/vision_language/paligemma_client.py +146 -0
- helm/clients/vision_language/palmyra_vision_client.py +99 -0
- helm/clients/yi_client.py +31 -0
- helm/common/critique_request.py +10 -1
- helm/common/images_utils.py +25 -0
- helm/common/mongo_key_value_store.py +2 -1
- helm/common/request.py +16 -0
- helm/config/model_deployments.yaml +740 -363
- helm/config/model_metadata.yaml +824 -128
- helm/config/tokenizer_configs.yaml +207 -10
- helm/proxy/critique/model_critique_client.py +32 -4
- helm/proxy/example_queries.py +14 -21
- helm/proxy/services/server_service.py +2 -3
- helm/proxy/token_counters/test_auto_token_counter.py +2 -2
- helm/tokenizers/ai21_tokenizer.py +51 -59
- helm/tokenizers/auto_tokenizer.py +1 -1
- helm/tokenizers/cohere_tokenizer.py +29 -62
- helm/tokenizers/huggingface_tokenizer.py +35 -13
- helm/tokenizers/test_ai21_tokenizer.py +48 -0
- helm/tokenizers/test_cohere_tokenizer.py +39 -0
- helm/tokenizers/test_huggingface_tokenizer.py +5 -1
- helm/benchmark/static/benchmarking.css +0 -156
- helm/benchmark/static/benchmarking.js +0 -1705
- helm/benchmark/static/config.js +0 -3
- helm/benchmark/static/general.js +0 -122
- helm/benchmark/static/images/crfm-logo.png +0 -0
- helm/benchmark/static/images/helm-logo-simple.png +0 -0
- helm/benchmark/static/images/helm-logo.png +0 -0
- helm/benchmark/static/images/language-model-helm.png +0 -0
- helm/benchmark/static/images/organizations/ai21.png +0 -0
- helm/benchmark/static/images/organizations/anthropic.png +0 -0
- helm/benchmark/static/images/organizations/bigscience.png +0 -0
- helm/benchmark/static/images/organizations/cohere.png +0 -0
- helm/benchmark/static/images/organizations/eleutherai.png +0 -0
- helm/benchmark/static/images/organizations/google.png +0 -0
- helm/benchmark/static/images/organizations/meta.png +0 -0
- helm/benchmark/static/images/organizations/microsoft.png +0 -0
- helm/benchmark/static/images/organizations/nvidia.png +0 -0
- helm/benchmark/static/images/organizations/openai.png +0 -0
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
- helm/benchmark/static/images/organizations/yandex.png +0 -0
- helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
- helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
- helm/benchmark/static/index.html +0 -68
- helm/benchmark/static/info-icon.png +0 -0
- helm/benchmark/static/json-urls.js +0 -69
- helm/benchmark/static/plot-captions.js +0 -27
- helm/benchmark/static/schema_image2structure.yaml +0 -304
- helm/benchmark/static/utils.js +0 -285
- helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
- helm/benchmark/static_build/assets/index-878a1094.css +0 -1
- helm/benchmark/window_services/ai21_window_service.py +0 -247
- helm/benchmark/window_services/cohere_window_service.py +0 -101
- helm/benchmark/window_services/test_ai21_window_service.py +0 -163
- helm/benchmark/window_services/test_cohere_window_service.py +0 -75
- helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
- helm/benchmark/window_services/test_ice_window_service.py +0 -327
- helm/tokenizers/ice_tokenizer.py +0 -30
- helm/tokenizers/test_ice_tokenizer.py +0 -57
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.3.dist-info}/top_level.txt +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
helm/clients/vertexai_client.py
CHANGED

@@ -1,7 +1,7 @@
 import requests
 from abc import ABC, abstractmethod
 from threading import Lock
-from typing import Any, Dict, Optional, List, Union
+from typing import Any, Dict, Mapping, Optional, List, Union
 
 from helm.common.cache import CacheConfig
 from helm.common.media_object import TEXT_TYPE
@@ -26,22 +26,62 @@ class VertexAIContentBlockedError(Exception):
     pass
 
 
+class SafetySettingPresets:
+    BLOCK_NONE = "block_none"  # Disable all blocking
+    DEFAULT = "default"  # Use default safety settings
+
+
+def _get_safety_settings_for_preset(
+    safety_settings_preset: Optional[str],
+) -> Optional[Dict[HarmCategory, SafetySetting.HarmBlockThreshold]]:
+    """Get the safety settings for the safety_settings_preset.
+
+    If safety_settings_preset is None, use the default value of BLOCK_NONE (*not* DEFAULT)."""
+    if safety_settings_preset is None or safety_settings_preset == SafetySettingPresets.BLOCK_NONE:
+        return {
+            harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
+            for harm_category in iter(HarmCategory)
+        }
+    elif safety_settings_preset == SafetySettingPresets.DEFAULT:
+        return None
+    else:
+        raise ValueError(f"Unknown safety_settings_preset: {safety_settings_preset}")
+
+
+def _get_model_name_for_request(request: Request) -> str:
+    # We have to strip "-safety-" suffixes from model names because they are not part of the Vertex AI model name
+    # TODO: Clean up this hack
+    return request.model_engine.split("-safety-")[0]
+
+
 class VertexAIClient(CachingClient, ABC):
     """Client for Vertex AI models"""
 
-    def __init__(self, cache_config: CacheConfig, project_id: str, location: str) -> None:
+    def __init__(
+        self, cache_config: CacheConfig, project_id: str, location: str, safety_settings_preset: Optional[str] = None
+    ) -> None:
         super().__init__(cache_config=cache_config)
         self.project_id = project_id
         self.location = location
 
-
-        self.safety_settings: Dict[HarmCategory, SafetySetting.HarmBlockThreshold] = {
-            harm_category: SafetySetting.HarmBlockThreshold(SafetySetting.HarmBlockThreshold.BLOCK_NONE)
-            for harm_category in iter(HarmCategory)
-        }
+        self.safety_settings_preset = safety_settings_preset
+        self.safety_settings = _get_safety_settings_for_preset(safety_settings_preset)
 
         vertexai.init(project=self.project_id, location=self.location)
 
+    def make_cache_key_with_safety_settings_preset(self, raw_request: Mapping, request: Request) -> Mapping:
+        """Construct the key for the cache using the raw request.
+
+        Add `self.safety_settings_preset` to the key, if not None."""
+        if self.safety_settings_preset is not None:
+            assert "safety_settings_preset" not in raw_request
+            return {
+                **CachingClient.make_cache_key(raw_request, request),
+                "safety_settings_preset": self.safety_settings_preset,
+            }
+        else:
+            return CachingClient.make_cache_key(raw_request, request)
+
     @abstractmethod
     def make_request(self, request: Request) -> RequestResult:
         raise NotImplementedError
@@ -71,7 +111,7 @@ class VertexAITextClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request.model_engine
+        model_name: str = _get_model_name_for_request(request)
 
         try:
 
@@ -87,9 +127,9 @@ class VertexAITextClient(VertexAIClient):
         # We need to include the engine's name to differentiate among requests made for different model
         # engines since the engine name is not included in the request itself.
         # Same for the prompt.
-        cache_key = CachingClient.make_cache_key(
+        cache_key = self.make_cache_key_with_safety_settings_preset(
             {
-                "engine": request.model_engine,
+                "engine": model_name,
                 "prompt": request.prompt,
                 **parameters,
             },
@@ -177,7 +217,7 @@ class VertexAIChatClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request.model_engine
+        model_name: str = _get_model_name_for_request(request)
         model = self.get_model(model_name)
 
         try:
@@ -197,7 +237,7 @@ class VertexAIChatClient(VertexAIClient):
 
             # Depending on the version of the Vertex AI library and the type of prompt blocking,
             # prompt blocking can show up in many ways, so this defensively handles most of these ways
-            if response.prompt_feedback.block_reason:
+            if response.prompt_feedback and response.prompt_feedback.block_reason:
                 raise VertexAIContentBlockedError(
                     f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
                 )
@@ -209,8 +249,10 @@ class VertexAIChatClient(VertexAIClient):
                 # content blocking can show up in many ways, so this defensively handles most of these ways
                 if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
                     raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content:
+                    raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
                 if not candidate.content.parts:
-                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+                    raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
                 predictions.append({"text": candidate.content.text})
             # TODO: Extract more information from the response
             return {"predictions": predictions}
@@ -218,7 +260,7 @@ class VertexAIChatClient(VertexAIClient):
         # We need to include the engine's name to differentiate among requests made for different model
         # engines since the engine name is not included in the request itself.
         # Same for the prompt.
-        cache_key = CachingClient.make_cache_key(
+        cache_key = self.make_cache_key_with_safety_settings_preset(
             {
                 "model_name": model_name,
                 "prompt": request.prompt,
@@ -313,7 +355,7 @@ class VertexAIChatClient(VertexAIClient):
         }
 
         completions: List[GeneratedOutput] = []
-        model_name: str = request.model_engine
+        model_name: str = _get_model_name_for_request(request)
         model = self.get_model(model_name)
 
         request_time = 0
@@ -330,7 +372,7 @@ class VertexAIChatClient(VertexAIClient):
                 )
                 # Depending on the version of the Vertex AI library and the type of prompt blocking,
                 # prompt blocking can show up in many ways, so this defensively handles most of these ways
-                if response.prompt_feedback.block_reason:
+                if response.prompt_feedback and response.prompt_feedback.block_reason:
                     raise VertexAIContentBlockedError(
                         f"Prompt blocked with reason: {response.prompt_feedback.block_reason}"
                     )
@@ -345,15 +387,17 @@ class VertexAIChatClient(VertexAIClient):
                 # content blocking can show up in many ways, so this defensively handles most of these ways
                 if candidate.finish_reason in VertexAIChatClient.CONTENT_BLOCKED_FINISH_REASONS:
                     raise VertexAIContentBlockedError(f"Content blocked with reason: {candidate.finish_reason}")
+                if not candidate.content:
+                    raise VertexAIContentBlockedError(f"No content in candidate: {candidate}")
                 if not candidate.content.parts:
-                    raise VertexAIContentBlockedError(f"No parts in candidate: {candidate}")
+                    raise VertexAIContentBlockedError(f"No content parts in candidate: {candidate}")
                 return {"predictions": [{"text": candidate.text}]}
 
             raw_cache_key = {"model_name": model_name, "prompt": prompt_key, **parameters}
             if completion_index > 0:
                 raw_cache_key["completion_index"] = completion_index
 
-            cache_key = CachingClient.make_cache_key(raw_cache_key, request)
+            cache_key = self.make_cache_key_with_safety_settings_preset(raw_cache_key, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except requests.exceptions.RequestException as e:
             error: str = f"Gemini Vision error: {e}"
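Note on the new safety_settings_preset plumbing above: the preset is also folded into the cache key so that runs with different presets do not collide. A minimal sketch of that behavior; the model name and prompt are made-up values, only the key shape comes from the code above:

base_key = {"model_name": "gemini-1.0-pro", "prompt": "Hello"}  # illustrative values
preset = "default"
# Mirrors make_cache_key_with_safety_settings_preset: the preset is added only
# when set, so existing cache entries for the unset (None) case stay valid.
cache_key = {**base_key, "safety_settings_preset": preset} if preset is not None else base_key
print(cache_key)
# {'model_name': 'gemini-1.0-pro', 'prompt': 'Hello', 'safety_settings_preset': 'default'}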
helm/clients/vision_language/huggingface_vlm_client.py
CHANGED

@@ -38,6 +38,7 @@ class HuggingFaceVLMClient:
         "huggingface/llava-v1.6-vicuna-13b-hf": "llava-hf/llava-v1.6-vicuna-13b-hf",
         "huggingface/llava-v1.6-mistral-7b-hf": "llava-hf/llava-v1.6-mistral-7b-hf",
         "huggingface/llava-v1.6-34b-hf": "llava-hf/llava-v1.6-34b-hf",
+        "huggingface/prometheus-vision-13b-v1.0-hf": "PahaII/prometheus-vision-13b-v1.0-hf",
     }
 
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
helm/clients/vision_language/open_flamingo_client.py
CHANGED

@@ -82,13 +82,12 @@ class OpenFlamingoClient:
         # Build the prompt
         prompt_text: str = ""
        images: List[Image.Image] = []
+        request.validate()
         for media_object in request.multimodal_prompt.media_objects:
             if media_object.is_type("image") and media_object.location:
                 images.append(open_image(media_object.location))
                 prompt_text += self.IMAGE_TOKEN
             elif media_object.is_type(TEXT_TYPE):
-                if media_object.text is None:
-                    raise ValueError("MediaObject of text type has missing text field value")
                 prompt_text += media_object.text
             else:
                 raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
helm/clients/vision_language/paligemma_client.py
ADDED

@@ -0,0 +1,146 @@
+from threading import Lock
+from typing import Any, Dict, List, Optional
+
+import torch
+from dataclasses import dataclass
+from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+from helm.common.cache import CacheConfig
+from helm.common.images_utils import open_image
+from helm.common.gpu_utils import get_torch_device_name
+from helm.common.hierarchical_logger import hlog, htrack_block
+from helm.common.media_object import TEXT_TYPE
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+from helm.common.tokenization_request import TokenizationRequest
+from helm.common.request import wrap_request_time
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
+
+try:
+    from PIL import Image
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["images"])
+
+# Added to solve: cutlassF: no kernel found to launch!
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_flash_sdp(False)
+
+
+@dataclass(frozen=True)
+class LoadedPaliGemmaForConditionalGeneration:
+    """Loaded model and processor for PaliGemma."""
+
+    model: PaliGemmaForConditionalGeneration
+    processor: AutoProcessor
+
+
+_models_lock: Lock = Lock()
+_models: Dict[str, Optional[LoadedPaliGemmaForConditionalGeneration]] = {}
+
+
+class PaliGemmaClient(CachingClient):
+    """
+    PaliGemma is a versatile and lightweight vision-language model (VLM) inspired by PaLI-3
+    and based on open components such as the SigLIP vision model and the Gemma language model.
+    It takes both image and text as input and generates text as output, supporting multiple languages.
+    It is designed for class-leading fine-tune performance on a wide range of vision-language tasks
+    such as image and short video caption, visual question answering, text reading, object detection
+    and object segmentation.
+    """
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._device: str = get_torch_device_name()
+
+    def _get_model(self, checkpoint: str) -> LoadedPaliGemmaForConditionalGeneration:
+        global _models_lock
+        global _models
+
+        # Ensure that only one thread is loading the model at a time
+        with _models_lock:
+            if checkpoint not in _models or _models[checkpoint] is None:
+                hlog(f"Loading model {checkpoint} and caching in memory...")
+                model = PaliGemmaForConditionalGeneration.from_pretrained(
+                    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+                ).eval()
+                processor = AutoProcessor.from_pretrained(checkpoint)
+                _models[checkpoint] = LoadedPaliGemmaForConditionalGeneration(model, processor)
+            loaded_model_processor = _models[checkpoint]
+
+        assert loaded_model_processor is not None
+        return loaded_model_processor
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        loaded_model_processor: LoadedPaliGemmaForConditionalGeneration = self._get_model(request.model_deployment)
+        model = loaded_model_processor.model
+        processor = loaded_model_processor.processor
+        generation_args = {"max_new_tokens": request.max_tokens}
+
+        images: List[Image.Image] = []
+        prompt_pieces: List[str] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                images += [open_image(media_object.location).convert("RGB")]
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+                prompt_pieces.append(media_object.text)
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+        prompt_text: str = "\n".join(prompt_pieces)
+        model_inputs = processor(text=prompt_text, images=images, return_tensors="pt").to(self._device)
+        input_len = model_inputs["input_ids"].shape[-1]
+
+        completions: List[GeneratedOutput] = []
+        with htrack_block(f"Generating for prompt: {prompt_text}"):
+            try:
+                concat_results = []
+                for i_completion in range(request.num_completions):
+
+                    def do_it() -> Dict[str, Any]:
+                        with torch.inference_mode():
+                            generation = model.generate(
+                                **model_inputs, max_new_tokens=request.max_tokens, do_sample=False
+                            )[0]
+                            if not request.echo_prompt:
+                                generation = generation[input_len:]
+                            decoded = processor.decode(generation, skip_special_tokens=True)
+                        return {"output": decoded}
+
+                    # Include the prompt and model name in the cache key
+                    cache_key = CachingClient.make_cache_key(
+                        raw_request={
+                            "n": request.num_completions,
+                            "i": i_completion,
+                            "model": request.model,
+                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                            **generation_args,
+                        },
+                        request=request,
+                    )
+                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+                    concat_results.append(result)
+            except RuntimeError as model_error:
+                return RequestResult(success=False, cached=False, error=str(model_error), completions=[], embedding=[])
+
+            for result in concat_results:
+                text = result["output"]
+                hlog(f"Generated text: {text}")
+                tokenization_result = self.tokenizer.tokenize(
+                    TokenizationRequest(text, tokenizer=self.tokenizer_name, encode=False)
+                )
+                tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
+                completions.append(GeneratedOutput(text=text, logprob=0, tokens=tokens))
+
+            return RequestResult(
+                success=True,
+                cached=cached,
+                request_time=result["request_time"],
+                completions=completions,
+                embedding=[],
+            )
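For context, the generation path that PaliGemmaClient wraps can be exercised directly through the public transformers API. A minimal standalone sketch; the checkpoint name and image path are examples, not values from this diff:

import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

checkpoint = "google/paligemma-3b-mix-224"  # example checkpoint
model = PaliGemmaForConditionalGeneration.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(checkpoint)

image = Image.open("example.jpg").convert("RGB")  # assumed local file
inputs = processor(text="caption en", images=[image], return_tensors="pt").to(model.device)
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=30, do_sample=False)[0]
# Strip the prompt tokens, as the client does when echo_prompt is off.
print(processor.decode(output[inputs["input_ids"].shape[-1]:], skip_special_tokens=True))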
helm/clients/vision_language/palmyra_vision_client.py
ADDED

@@ -0,0 +1,99 @@
+from typing import Dict, List
+import json
+
+import requests
+
+from helm.common.cache import CacheConfig
+from helm.common.images_utils import encode_base64
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import Request, RequestResult, GeneratedOutput, ErrorFlags
+from helm.common.request import wrap_request_time
+from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt, truncate_and_tokenize_response_text
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class PalmyraVisionContentBlockedError(Exception):
+    pass
+
+
+class PalmyraVisionClient(CachingClient):
+    CONTENT_BLOCKED_ERROR: str = "fail.input.content.moderation"
+
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, endpoint: str, cache_config: CacheConfig):
+        super().__init__(cache_config)
+        self.tokenizer: Tokenizer = tokenizer
+        self.tokenizer_name: str = tokenizer_name
+
+        # Currently, the Palmyra Vision model does not have a public API, so we need to use a secret endpoint
+        self.endpoint: str = endpoint
+
+    def make_request(self, request: Request) -> RequestResult:
+        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+        # Build the prompt
+        prompt: List[Dict[str, str]] = []
+        for media_object in request.multimodal_prompt.media_objects:
+            if media_object.is_type("image") and media_object.location:
+                prompt.append(
+                    {
+                        "type": "InlineData",
+                        "value": encode_base64(media_object.location, format="JPEG"),
+                        "contentType": "image/jpeg",
+                    }
+                )
+            elif media_object.is_type(TEXT_TYPE):
+                if media_object.text is None:
+                    raise ValueError("MediaObject of text type has missing text field value")
+                prompt.append({"type": "Text", "value": media_object.text})
+            else:
+                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+        # Generate
+        try:
+
+            def do_it():
+                response = requests.post(
+                    self.endpoint, headers={"Content-Type": "application/json"}, data=json.dumps({"parts": prompt})
+                )
+                json_response = json.loads(response.text)
+
+                # Check for content blocked error
+                if (
+                    "errors" in json_response
+                    and "tpe" in json_response
+                    and json_response["tpe"] == self.CONTENT_BLOCKED_ERROR
+                ):
+                    raise PalmyraVisionContentBlockedError(json_response["errors"])
+
+                # Hard fail if the `choices` is missing from the response
+                assert "choices" in json_response, f"Invalid response: {response.text}"
+
+                return json_response
+
+            cache_key = CachingClient.make_cache_key(
+                raw_request={"prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt)},
+                request=request,
+            )
+            result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except PalmyraVisionContentBlockedError as ex:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=f"Content blocked: {str(ex)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
+
+        # The internal endpoint doesn't support any other parameters, so we have to truncate ourselves
+        completions: List[GeneratedOutput] = [
+            truncate_and_tokenize_response_text(choice["text"], request, self.tokenizer, self.tokenizer_name)
+            for choice in result["choices"]
+        ]
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=result["request_time"],
+            completions=completions,
+            embedding=[],
+        )
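For reference, the JSON body that PalmyraVisionClient posts can be reconstructed from the code above; the base64 value and question text below are placeholders, since the endpoint itself is private:

import json

parts = [
    {"type": "InlineData", "value": "<base64-encoded JPEG>", "contentType": "image/jpeg"},
    {"type": "Text", "value": "Describe this image."},
]
body = json.dumps({"parts": parts})  # posted with Content-Type: application/json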
helm/clients/yi_client.py
ADDED

@@ -0,0 +1,31 @@
+from typing import Optional
+
+from helm.clients.openai_client import OpenAIClient
+from helm.common.cache import CacheConfig
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class YiChatClient(OpenAIClient):
+
+    BASE_URL = "http://api.01ww.xyz/v1"
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url=YiChatClient.BASE_URL,
+        )
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        return True
helm/common/critique_request.py
CHANGED

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional
+from helm.common.media_object import MediaObject
 
 
 class QuestionType:
@@ -34,6 +35,11 @@ class CritiqueQuestionTemplate:
 
     Can contain placeholders like {{placeholder}} that will be interpolated using the fields in CritiqueRequest."""
 
+    media_object: Optional[MediaObject] = None
+    """Path of image for multimodal input.
+
+    Image path or URL of the question."""
+
 
 @dataclass(frozen=True)
 class CritiqueTaskTemplate:
@@ -53,6 +59,9 @@ class CritiqueTaskTemplate:
     questions: List[CritiqueQuestionTemplate]
     """List of templates for the questions."""
 
+    max_tokens: Optional[int] = None
+    """Max token to be generated for the free-end generation."""
+
 
 @dataclass(frozen=True)
 class CritiqueRequest:
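A sketch of the two new fields in use. The constructor arguments other than media_object and max_tokens follow the existing dataclass fields, and every value shown here is invented for illustration:

from helm.common.critique_request import CritiqueQuestionTemplate, CritiqueTaskTemplate, QuestionType
from helm.common.media_object import MediaObject

question = CritiqueQuestionTemplate(
    name="image_quality",
    question_type=QuestionType.FREE_RESPONSE,
    text="Describe any problems with the response to this image.",
    options=[],
    media_object=MediaObject(content_type="image/jpeg", location="https://example.com/question.jpg"),
)
template = CritiqueTaskTemplate(
    name="example_task",
    instructions="Review the model response: {{response}}",
    num_respondents=1,
    questions=[question],
    max_tokens=256,  # new field: cap on tokens generated for free-form answers
)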
helm/common/images_utils.py
CHANGED

@@ -1,5 +1,7 @@
+from hashlib import md5
 import base64
 import io
+import os
 
 import requests
 import shutil
@@ -43,6 +45,11 @@ def encode_base64(image_location: str, format="JPEG") -> str:
         return base64.b64encode(image_file.getvalue()).decode("ascii")
 
 
+def generate_hash(image: Image.Image) -> str:
+    """Generates a hash for the image."""
+    return md5(image.tobytes()).hexdigest()
+
+
 def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optional[int] = None) -> None:
     """
     Copies the image file from `src` path to `dest` path. If dimensions `width` and `height`
@@ -57,6 +64,24 @@ def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optiona
         shutil.copy(src, dest)
 
 
+def resize_image_to_max_file_size(src: str, dest: str, max_size_in_bytes: int, step=10):
+    # Open an image file
+    with Image.open(src) as img:
+        width, height = img.size
+
+        # Reduce dimensions iteratively until the file size is under the limit
+        while True:
+            # Save the image temporarily to check the file size
+            img.save(dest, quality=95)  # Start with high quality
+            if os.path.getsize(dest) < max_size_in_bytes:
+                break
+
+            # Reduce dimensions
+            width -= step
+            height -= step
+            img = img.resize((width, height), Image.Resampling.LANCZOS)
+
+
 def is_blacked_out_image(image_location: str) -> bool:
     """Returns True if the image is all black. False otherwise."""
     try:
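Example use of the two new helpers above (file paths are placeholders):

from PIL import Image
from helm.common.images_utils import generate_hash, resize_image_to_max_file_size

# Shrink an image until the file written to disk is under 1 MB.
resize_image_to_max_file_size("photo.jpg", "photo_small.jpg", max_size_in_bytes=1024 * 1024)

# Hash the pixel data, e.g. to deduplicate images regardless of filename.
with Image.open("photo_small.jpg") as image:
    print(generate_hash(image))  # md5 hex digest of image.tobytes()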
helm/common/request.py
CHANGED

@@ -72,6 +72,22 @@ class Request:
     image_generation_parameters: Optional[ImageGenerationParameters] = None
     """Parameters for image generation."""
 
+    def validate(self):
+        if (
+            (self.messages and self.prompt)
+            or (self.messages and self.multimodal_prompt)
+            or (self.prompt and self.multimodal_prompt)
+        ):
+            raise ValueError("Exactly one of the messages, prompt, multimodal_prompt fields should be set")
+
+        if self.multimodal_prompt:
+            for media_object in self.multimodal_prompt.media_objects:
+                if media_object.content_type == "text" and media_object.text is None:
+                    raise ValueError("Media object with text content type must have text set")
+
+                if media_object.content_type == "image" and media_object.location is None:
+                    raise ValueError("Media object with image content type must have location set")
+
     @property
     def model_host(self) -> str:
         """Returns the model host (referring to the deployment).
|