crfm-helm 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.
This release of crfm-helm has been flagged as potentially problematic.
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/METADATA +81 -112
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/RECORD +165 -155
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
- helm/benchmark/adaptation/common_adapter_specs.py +2 -0
- helm/benchmark/annotation/anthropic_red_team_annotator.py +57 -0
- helm/benchmark/annotation/call_center_annotator.py +258 -0
- helm/benchmark/annotation/financebench_annotator.py +79 -0
- helm/benchmark/annotation/harm_bench_annotator.py +55 -0
- helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
- helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
- helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
- helm/benchmark/annotation/live_qa_annotator.py +37 -45
- helm/benchmark/annotation/medication_qa_annotator.py +36 -44
- helm/benchmark/annotation/model_as_judge.py +96 -0
- helm/benchmark/annotation/simple_safety_tests_annotator.py +50 -0
- helm/benchmark/annotation/xstest_annotator.py +100 -0
- helm/benchmark/metrics/annotation_metrics.py +108 -0
- helm/benchmark/metrics/bhasa_metrics.py +188 -0
- helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
- helm/benchmark/metrics/code_metrics_helper.py +11 -1
- helm/benchmark/metrics/safety_metrics.py +79 -0
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +17 -3
- helm/benchmark/metrics/vision_language/image_metrics.py +7 -3
- helm/benchmark/metrics/vision_language/image_utils.py +1 -1
- helm/benchmark/model_metadata_registry.py +3 -3
- helm/benchmark/presentation/create_plots.py +1 -1
- helm/benchmark/presentation/schema.py +3 -0
- helm/benchmark/presentation/summarize.py +106 -256
- helm/benchmark/presentation/test_run_entry.py +1 -0
- helm/benchmark/presentation/test_summarize.py +145 -3
- helm/benchmark/run.py +15 -0
- helm/benchmark/run_expander.py +83 -30
- helm/benchmark/run_specs/bhasa_run_specs.py +652 -0
- helm/benchmark/run_specs/call_center_run_specs.py +152 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +8 -8
- helm/benchmark/run_specs/experimental_run_specs.py +52 -0
- helm/benchmark/run_specs/finance_run_specs.py +82 -1
- helm/benchmark/run_specs/safety_run_specs.py +154 -0
- helm/benchmark/run_specs/vlm_run_specs.py +100 -24
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
- helm/benchmark/scenarios/banking77_scenario.py +51 -0
- helm/benchmark/scenarios/bhasa_scenario.py +1942 -0
- helm/benchmark/scenarios/call_center_scenario.py +84 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
- helm/benchmark/scenarios/ewok_scenario.py +116 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +2 -0
- helm/benchmark/scenarios/financebench_scenario.py +53 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
- helm/benchmark/scenarios/raft_scenario.py +1 -1
- helm/benchmark/scenarios/scenario.py +1 -1
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
- helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
- helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
- helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
- helm/benchmark/scenarios/test_math_scenario.py +2 -8
- helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
- helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
- helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +41 -12
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +6 -3
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
- helm/benchmark/scenarios/xstest_scenario.py +35 -0
- helm/benchmark/server.py +1 -6
- helm/benchmark/static/schema_air_bench.yaml +750 -750
- helm/benchmark/static/schema_bhasa.yaml +709 -0
- helm/benchmark/static/schema_call_center.yaml +232 -0
- helm/benchmark/static/schema_cleva.yaml +768 -0
- helm/benchmark/static/schema_decodingtrust.yaml +444 -0
- helm/benchmark/static/schema_ewok.yaml +367 -0
- helm/benchmark/static/schema_finance.yaml +55 -9
- helm/benchmark/static/{schema_image2structure.yaml → schema_image2struct.yaml} +231 -90
- helm/benchmark/static/schema_legal.yaml +566 -0
- helm/benchmark/static/schema_safety.yaml +266 -0
- helm/benchmark/static/schema_tables.yaml +149 -8
- helm/benchmark/static/schema_thai.yaml +21 -0
- helm/benchmark/static/schema_vhelm.yaml +137 -101
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
- helm/benchmark/static_build/assets/index-3ee38b3d.js +10 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/vhelm-aspects-1437d673.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-a1ca3f3f.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-8afb7616.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/test_openai_window_service.py +8 -8
- helm/benchmark/window_services/tokenizer_service.py +0 -5
- helm/clients/ai21_client.py +71 -1
- helm/clients/anthropic_client.py +7 -19
- helm/clients/huggingface_client.py +38 -37
- helm/clients/nvidia_nim_client.py +35 -0
- helm/clients/openai_client.py +18 -4
- helm/clients/palmyra_client.py +24 -0
- helm/clients/perspective_api_client.py +11 -6
- helm/clients/test_client.py +4 -6
- helm/clients/together_client.py +22 -0
- helm/clients/vision_language/open_flamingo_client.py +1 -2
- helm/clients/vision_language/palmyra_vision_client.py +28 -13
- helm/common/cache.py +8 -30
- helm/common/images_utils.py +6 -0
- helm/common/key_value_store.py +9 -9
- helm/common/mongo_key_value_store.py +5 -4
- helm/common/request.py +16 -0
- helm/common/test_cache.py +1 -48
- helm/common/tokenization_request.py +0 -9
- helm/config/model_deployments.yaml +444 -329
- helm/config/model_metadata.yaml +513 -111
- helm/config/tokenizer_configs.yaml +140 -11
- helm/proxy/example_queries.py +14 -21
- helm/proxy/server.py +0 -9
- helm/proxy/services/remote_service.py +0 -6
- helm/proxy/services/server_service.py +6 -20
- helm/proxy/services/service.py +0 -6
- helm/proxy/token_counters/test_auto_token_counter.py +2 -2
- helm/tokenizers/ai21_tokenizer.py +51 -59
- helm/tokenizers/cohere_tokenizer.py +0 -75
- helm/tokenizers/huggingface_tokenizer.py +0 -1
- helm/tokenizers/test_ai21_tokenizer.py +48 -0
- helm/benchmark/data_overlap/data_overlap_spec.py +0 -86
- helm/benchmark/data_overlap/export_scenario_text.py +0 -119
- helm/benchmark/data_overlap/light_scenario.py +0 -60
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/static/benchmarking.css +0 -156
- helm/benchmark/static/benchmarking.js +0 -1705
- helm/benchmark/static/config.js +0 -3
- helm/benchmark/static/general.js +0 -122
- helm/benchmark/static/images/crfm-logo.png +0 -0
- helm/benchmark/static/images/helm-logo-simple.png +0 -0
- helm/benchmark/static/images/helm-logo.png +0 -0
- helm/benchmark/static/images/language-model-helm.png +0 -0
- helm/benchmark/static/images/organizations/ai21.png +0 -0
- helm/benchmark/static/images/organizations/anthropic.png +0 -0
- helm/benchmark/static/images/organizations/bigscience.png +0 -0
- helm/benchmark/static/images/organizations/cohere.png +0 -0
- helm/benchmark/static/images/organizations/eleutherai.png +0 -0
- helm/benchmark/static/images/organizations/google.png +0 -0
- helm/benchmark/static/images/organizations/meta.png +0 -0
- helm/benchmark/static/images/organizations/microsoft.png +0 -0
- helm/benchmark/static/images/organizations/nvidia.png +0 -0
- helm/benchmark/static/images/organizations/openai.png +0 -0
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
- helm/benchmark/static/images/organizations/yandex.png +0 -0
- helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
- helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
- helm/benchmark/static/index.html +0 -68
- helm/benchmark/static/info-icon.png +0 -0
- helm/benchmark/static/json-urls.js +0 -69
- helm/benchmark/static/plot-captions.js +0 -27
- helm/benchmark/static/utils.js +0 -285
- helm/benchmark/static_build/assets/index-30dbceba.js +0 -10
- helm/benchmark/static_build/assets/index-66b02d40.css +0 -1
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/window_services/ai21_window_service.py +0 -247
- helm/benchmark/window_services/cohere_window_service.py +0 -101
- helm/benchmark/window_services/test_ai21_window_service.py +0 -163
- helm/benchmark/window_services/test_cohere_window_service.py +0 -75
- helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
- helm/benchmark/window_services/test_ice_window_service.py +0 -327
- helm/tokenizers/ice_tokenizer.py +0 -30
- helm/tokenizers/test_ice_tokenizer.py +0 -57
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/top_level.txt +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
- /helm/benchmark/{data_overlap → scenarios/vision_language/image2struct}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct/webpage}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
helm/clients/huggingface_client.py
CHANGED

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, TypedDict
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import htrack_block, hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,

@@ -58,60 +59,61 @@ class HuggingFaceServer:
         self,
         pretrained_model_name_or_path: str,
         wrapped_tokenizer: WrappedPreTrainedTokenizer,
-        openvino=False,
+        openvino: bool = False,
         **kwargs,
     ):
-        if torch.cuda.is_available():
-            hlog("CUDA is available, initializing with a GPU...")
-            self.device: str = "cuda:0"
+        self.device: Optional[str]
+        if "device_map" in kwargs:
+            try:
+                import accelerate  # noqa: F401
+            except ModuleNotFoundError as e:
+                handle_module_not_found_error(e, ["accelerate"])
+            hlog(f'Hugging Face device_map set to "{kwargs["device_map"]}".')
+            self.device = None
+        elif torch.cuda.is_available():
+            hlog('Hugging Face device set to "cuda:0" because CUDA is available.')
+            self.device = "cuda:0"
         else:
+            hlog('Hugging Face device set to "cpu" because CUDA is unavailable.')
             self.device = "cpu"
+
+        # Security issue: currently we trust remote code by default.
+        # We retain this temporarily to maintain reverse compatibility.
+        # TODO: Delete if-else and don't set trust_remote_code=True
+        if "trust_remote_code" not in kwargs:
+            kwargs["trust_remote_code"] = True
+
         with htrack_block(f"Loading Hugging Face model {pretrained_model_name_or_path}"):
             # WARNING this may fail if your GPU does not have enough memory
             if openvino:
-                """
-                Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
-                OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
-                Intel® architectures using OpenVINO™ runtime.
-                """
-                from helm.common.optional_dependencies import handle_module_not_found_error
-
+                # Optimum Intel provides a simple interface to optimize Transformer models and convert them to \
+                # OpenVINO™ Intermediate Representation (IR) format to accelerate end-to-end pipelines on \
+                # Intel® architectures using OpenVINO™ runtime.
                 try:
                     from optimum.intel.openvino import OVModelForCausalLM
                 except ModuleNotFoundError as e:
                     handle_module_not_found_error(e, ["openvino"])
 
                 self.device = "cpu"
-                # Security issue: currently we trust remote code by default.
-                # We retain this temporarily to maintain reverse compatibility.
-                # TODO: Delete if-else and don't set trust_remote_code=True
-                if "trust_remote_code" in kwargs:
-                    self.model = OVModelForCausalLM.from_pretrained(
-                        pretrained_model_name_or_path, export=True, **kwargs
-                    ).to(self.device)
-                else:
-                    self.model = OVModelForCausalLM.from_pretrained(
-                        pretrained_model_name_or_path, export=True, trust_remote_code=True, **kwargs
-                    ).to(self.device)
+                self.model = OVModelForCausalLM.from_pretrained(
+                    pretrained_model_name_or_path, export=True, **kwargs
+                ).to(self.device)
+            elif self.device is None:
+                # kwargs contains device_map=auto
+                # Do not call to() because accelerate will take care of model device placement.
+                self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
             else:
-                # Security issue: currently we trust remote code by default.
-                # We retain this temporarily to maintain reverse compatibility.
-                # TODO: Delete if-else and don't set trust_remote_code=True
-                if "trust_remote_code" in kwargs:
-                    self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
-                        self.device
-                    )
-                else:
-                    self.model = AutoModelForCausalLM.from_pretrained(
-                        pretrained_model_name_or_path, trust_remote_code=True, **kwargs
-                    ).to(self.device)
+                self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs).to(
+                    self.device
+                )
         self.wrapped_tokenizer = wrapped_tokenizer
 
     def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
         with self.wrapped_tokenizer as tokenizer:
             encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
-                self.device
+                0 if self.device is None else self.device
             )
+
             stopping_criteria: Optional[StoppingCriteriaList] = None
             optional_args = {}
             if len(raw_request["stop_sequences"]) > 0:

@@ -249,9 +251,8 @@ def _process_huggingface_client_kwargs(raw_kwargs: Dict[str, Any]):
     # e.g. the string "torch.bfloat16" is converted to torch.bfloat16
     torch_dtype = processed_kwargs.get(TORCH_DTYPE_KEY)
     if torch_dtype and isinstance(torch_dtype, str):
-        if not torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
-            raise ValueError(f"Unknown torch_dtype: {torch_dtype}")
-        processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
+        if torch_dtype.startswith(TORCH_DTYPE_VALUE_PREFIX):
+            processed_kwargs[TORCH_DTYPE_KEY] = getattr(torch, torch_dtype[len(TORCH_DTYPE_VALUE_PREFIX) :])
 
     return processed_kwargs
 
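Note: the new device logic above defers placement to accelerate whenever device_map is passed. A minimal sketch of the two loading paths (the model name "gpt2" is illustrative, not taken from this diff):

from transformers import AutoModelForCausalLM

# With accelerate installed, device_map="auto" lets Hugging Face place model
# shards across available devices, so no explicit .to(device) call is needed.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# Without device_map, placement stays explicit, matching the else branch above.
model_cpu = AutoModelForCausalLM.from_pretrained("gpt2").to("cpu")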
helm/clients/nvidia_nim_client.py
ADDED

@@ -0,0 +1,35 @@
+from typing import Optional
+
+from helm.clients.openai_client import OpenAIClient
+from helm.common.cache import CacheConfig
+from helm.common.request import Request
+from helm.tokenizers.tokenizer import Tokenizer
+
+
+class NvidiaNimClient(OpenAIClient):
+
+    BASE_URL = "https://integrate.api.nvidia.com/v1"
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url=NvidiaNimClient.BASE_URL,
+        )
+
+    def _get_model_for_request(self, request: Request) -> str:
+        return request.model
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        return True
helm/clients/openai_client.py
CHANGED

@@ -12,8 +12,8 @@ from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
 )
-from helm.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence, generate_uid_for_multimodal_prompt
+from helm.tokenizers.tokenizer import Tokenizer
 
 try:
     import openai

@@ -51,7 +51,7 @@ class OpenAIClient(CachingClient):
     def _is_chat_model_engine(self, model_engine: str) -> bool:
         if model_engine == "gpt-3.5-turbo-instruct":
             return False
-        elif model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4"):
+        elif model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4") or model_engine.startswith("o1"):
             return True
         return False

@@ -132,6 +132,7 @@ class OpenAIClient(CachingClient):
         content: Union[str, List[Union[str, Any]]]
         if request.multimodal_prompt is not None:
             content = []
+            request.validate()
             for media_object in request.multimodal_prompt.media_objects:
                 if media_object.is_type("image") and media_object.location:
                     from helm.common.images_utils import encode_base64

@@ -140,8 +141,6 @@ class OpenAIClient(CachingClient):
                     image_object: Dict[str, str] = {"url": f"data:image/jpeg;base64,{base64_image}"}
                     content.append({"type": "image_url", "image_url": image_object})
                 elif media_object.is_type(TEXT_TYPE):
-                    if media_object.text is None:
-                        raise ValueError("MediaObject of text type has missing text field value")
                     content.append({"type": media_object.type, "text": media_object.text})
                 else:
                     raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

@@ -170,6 +169,21 @@ class OpenAIClient(CachingClient):
         if is_vlm(request.model) and raw_request["stop"] is None:
             raw_request.pop("stop")
 
+        # Special handling for o1 models.
+        # Refer to the "Reasoning models" documentation further discussion of o1 model limitations:
+        # https://platform.openai.com/docs/guides/reasoning
+        if request.model_engine.startswith("o1"):
+            # Avoid error:
+            # "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead."  # noqa: E501
+            # Note that openai>=1.45 is needed for this
+            if raw_request["max_tokens"]:
+                raw_request["max_completion_tokens"] = raw_request["max_tokens"]
+                raw_request.pop("max_tokens")
+            # Avoid error:
+            # "Invalid type for 'stop': expected an unsupported value, but got null instead."
+            if raw_request["stop"] is None:
+                raw_request.pop("stop")
+
         def do_it() -> Dict[str, Any]:
             return self.client.chat.completions.create(**raw_request).model_dump(mode="json")
 
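Note: the o1 branch above only rewrites the outgoing request dictionary. A minimal sketch of the transformation (values are illustrative):

raw_request = {"model": "o1-preview", "max_tokens": 512, "stop": None}

# Mirrors the diff: rename max_tokens and drop a null stop before the API call.
if raw_request["max_tokens"]:
    raw_request["max_completion_tokens"] = raw_request.pop("max_tokens")
if raw_request["stop"] is None:
    raw_request.pop("stop")

assert raw_request == {"model": "o1-preview", "max_completion_tokens": 512}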
helm/clients/palmyra_client.py
CHANGED

@@ -3,6 +3,7 @@ import json
 import requests
 from typing import Any, Dict, List
 
+from helm.clients.openai_client import OpenAIClient
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import hlog
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token, ErrorFlags

@@ -142,3 +143,26 @@ class PalmyraClient(CachingClient):
             completions=completions,
             embedding=[],
         )
+
+
+class PalmyraChatClient(OpenAIClient):
+    """Sends request to a Palmyra model using a OpenAI-compatible Chat API."""
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: str,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            tokenizer_name=tokenizer_name,
+            cache_config=cache_config,
+            api_key=api_key,
+            org_id=None,
+            base_url="https://api.writer.com/v1/chat",
+        )
+
+    def _is_chat_model_engine(self, model_engine: str) -> bool:
+        return True
helm/clients/perspective_api_client.py
CHANGED

@@ -4,16 +4,21 @@ from dataclasses import asdict
 from typing import Any, List, Dict, Optional
 
 from dacite import from_dict
-from googleapiclient import discovery
-from googleapiclient.errors import BatchError, HttpError
-from googleapiclient.http import BatchHttpRequest
-from httplib2 import HttpLib2Error
+
 from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.proxy.retry import NonRetriableException
-
 from helm.common.cache import Cache, CacheConfig
 from helm.common.perspective_api_request import ToxicityAttributes, PerspectiveAPIRequest, PerspectiveAPIRequestResult
-
+
+try:
+    from googleapiclient import discovery
+    from googleapiclient.errors import BatchError, HttpError
+    from googleapiclient.http import BatchHttpRequest
+    from httplib2 import HttpLib2Error
+    from google.auth.exceptions import DefaultCredentialsError
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["metrics"])
 
 
 class PerspectiveAPIClientCredentialsError(NonRetriableException):
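Note: moving the googleapiclient imports inside a try/except follows HELM's optional-dependency pattern, so the Perspective API client only fails for users who actually need it. A sketch of the idea (the error message is illustrative; the real helper lives in helm.common.optional_dependencies):

try:
    from googleapiclient import discovery  # noqa: F401
except ModuleNotFoundError as e:
    # Fail with an actionable hint instead of a bare import error at module load.
    raise ModuleNotFoundError(
        f"Optional dependency '{e.name}' is missing; try `pip install 'crfm-helm[metrics]'`"
    ) from e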
helm/clients/test_client.py
CHANGED

@@ -23,30 +23,28 @@ def test_truncate_sequence():
     # echo_prompt = True, nothing gets truncated
     truncate_sequence_helper(
         ["a", "b", "c"],
-        Request(
-            model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", prompt="abc", echo_prompt=True
-        ),
+        Request(model="openai/gpt2", model_deployment="huggingface/gpt2", prompt="abc", echo_prompt=True),
         ["a", "b", "c"],
     )
 
     # Nothing gets truncated
     truncate_sequence_helper(
         ["hello", " world"],
-        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["#"]),
+        Request(model="openai/gpt2", model_deployment="huggingface/gpt2", stop_sequences=["#"]),
         ["hello", " world"],
     )
 
     # Truncate using stop sequences
     truncate_sequence_helper(
         ["hello", " world", "\n", "what"],
-        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["\n"]),
+        Request(model="openai/gpt2", model_deployment="huggingface/gpt2", stop_sequences=["\n"]),
         ["hello", " world"],
     )
 
     # Truncate using max tokens
     truncate_sequence_helper(
         ["a", "b", "c"],
-        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", max_tokens=2),
+        Request(model="openai/gpt2", model_deployment="huggingface/gpt2", max_tokens=2),
         ["a", "b"],
     )
 
helm/clients/together_client.py
CHANGED

@@ -7,6 +7,7 @@ import requests
 from retrying import retry
 
 from helm.common.cache import CacheConfig
+from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
 from helm.clients.client import CachingClient, truncate_sequence, cleanup_str

@@ -323,8 +324,29 @@ class TogetherChatClient(CachingClient):
         self._together_model = together_model
 
     def convert_to_raw_chat_request(self, request: Request) -> TogetherRawChatRequest:
+        request.validate()
+        messages: List[Dict[str, Any]]
         if request.messages:
             messages = request.messages
+        elif request.multimodal_prompt:
+            message_contents = []
+            for media_object in request.multimodal_prompt.media_objects:
+                if media_object.is_type(IMAGE_TYPE) and media_object.location:
+                    assert media_object.location
+                    if media_object.is_local_file:
+                        from helm.common.images_utils import encode_base64
+
+                        base64_image: str = encode_base64(media_object.location)
+                        image_url = f"data:image/jpeg;base64,{base64_image}"
+                    else:
+                        image_url = media_object.location
+                    message_contents.append({"type": "image_url", "image_url": {"url": image_url}})
+                elif media_object.is_type(TEXT_TYPE):
+                    assert media_object.text
+                    message_contents.append({"type": "text", "text": media_object.text})
+                else:
+                    raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+            messages = [{"role": "user", "content": message_contents}]
         else:
             messages = [{"role": "user", "content": request.prompt}]
         if self._together_model is not None:
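Note: for a multimodal prompt, the conversion above yields an OpenAI-style message list. A sketch of the resulting payload (the URL and text are illustrative):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/..."}},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]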
helm/clients/vision_language/open_flamingo_client.py
CHANGED

@@ -82,13 +82,12 @@ class OpenFlamingoClient(CachingClient):
         # Build the prompt
         prompt_text: str = ""
         images: List[Image.Image] = []
+        request.validate()
         for media_object in request.multimodal_prompt.media_objects:
             if media_object.is_type("image") and media_object.location:
                 images.append(open_image(media_object.location))
                 prompt_text += self.IMAGE_TOKEN
             elif media_object.is_type(TEXT_TYPE):
-                if media_object.text is None:
-                    raise ValueError("MediaObject of text type has missing text field value")
                 prompt_text += media_object.text
             else:
                 raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
helm/clients/vision_language/palmyra_vision_client.py
CHANGED

@@ -6,13 +6,19 @@ import requests
 from helm.common.cache import CacheConfig
 from helm.common.images_utils import encode_base64
 from helm.common.media_object import TEXT_TYPE
-from helm.common.request import Request, RequestResult, GeneratedOutput
+from helm.common.request import Request, RequestResult, GeneratedOutput, ErrorFlags
 from helm.common.request import wrap_request_time
 from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt, truncate_and_tokenize_response_text
 from helm.tokenizers.tokenizer import Tokenizer
 
 
+class PalmyraVisionContentBlockedError(Exception):
+    pass
+
+
 class PalmyraVisionClient(CachingClient):
+    CONTENT_BLOCKED_ERROR: str = "fail.input.content.moderation"
+
     def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, endpoint: str, cache_config: CacheConfig):
         super().__init__(cache_config)
         self.tokenizer: Tokenizer = tokenizer

@@ -49,17 +55,19 @@ class PalmyraVisionClient(CachingClient):
             response = requests.post(
                 self.endpoint, headers={"Content-Type": "application/json"}, data=json.dumps({"parts": prompt})
             )
-            if response.status_code != 200:
-                curl_command: str = (
-                    f"curl --location '{self.endpoint}' --header 'Content-Type: application/json' "
-                    f"--data '{json.dumps({'parts': prompt})}'"
-                )
-                assert False, f"Got status code {response.status_code}. Try {curl_command}"
-
             json_response = json.loads(response.text)
-            assert (
-                "choices" in json_response
-            ), f"Invalid response: {response.text}"
+
+            # Check for content blocked error
+            if (
+                "errors" in json_response
+                and "tpe" in json_response
+                and json_response["tpe"] == self.CONTENT_BLOCKED_ERROR
+            ):
+                raise PalmyraVisionContentBlockedError(json_response["errors"])
+
+            # Hard fail if the `choices` is missing from the response
+            assert "choices" in json_response, f"Invalid response: {response.text}"
+
             return json_response
 
         cache_key = CachingClient.make_cache_key(

@@ -67,8 +75,15 @@ class PalmyraVisionClient(CachingClient):
                 request=request,
             )
             result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except Exception as e:
-            return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[])
+        except PalmyraVisionContentBlockedError as ex:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=f"Content blocked: {str(ex)}",
+                completions=[],
+                embedding=[],
+                error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
+            )
 
         # The internal endpoint doesn't support any other parameters, so we have to truncate ourselves
         completions: List[GeneratedOutput] = [
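Note: returning ErrorFlags(is_retriable=False, is_fatal=False) marks a moderation block as a skippable failure rather than one that should be retried or abort the run. A sketch of how a caller might treat such a result (the helper name is hypothetical):

from helm.common.request import RequestResult

def handle_result(result: RequestResult) -> None:  # hypothetical helper
    if result.success:
        return
    if result.error_flags is not None and not result.error_flags.is_fatal:
        # e.g. content moderation: record the failure and move on.
        print(f"Skipping blocked instance: {result.error}")
    else:
        raise RuntimeError(result.error)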
helm/common/cache.py
CHANGED

@@ -1,6 +1,6 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, Callable, Generator, Mapping, Optional, Tuple
+from typing import Dict, Callable, Generator, Mapping, Tuple
 import json
 import threading
 

@@ -38,6 +38,12 @@
 class KeyValueStoreCacheConfig(CacheConfig):
     """Configuration for a cache backed by a key-value store."""
 
+    # This was originally to distinguish between "primitive" cache configs
+    # and "compound" cache configs. But we don't have any "compound" cache configs currently.
+    # Hypthetical "compound" example: ReadOnlyCacheConfig(SqliteCacheConfig("path"))
+    # TODO: Maybe remove this eventually?
+    pass
+
 
 @dataclass(frozen=True)
 class SqliteCacheConfig(KeyValueStoreCacheConfig):

@@ -78,24 +84,6 @@ class MongoCacheConfig(KeyValueStoreCacheConfig):
         return f"{self.uri}/{self.collection_name}"
 
 
-@dataclass(frozen=True)
-class WithFollowerCacheConfig(CacheConfig):
-    """Configuration of a cache backed by a main cache and a follower cache."""
-
-    # Configuration for the main cache.
-    # Responses will be written to and served out of this cache.
-    main: KeyValueStoreCacheConfig
-
-    # Configuration for the follower cache.
-    # The follower cache is a write-only cache. Responses will be written to this cache,
-    # but not served out of this cache.
-    follower: KeyValueStoreCacheConfig
-
-    @property
-    def cache_stats_key(self) -> str:
-        return self.main.cache_stats_key
-
-
 def get_all_from_sqlite(path: str) -> Generator[Tuple[Dict, Dict], None, None]:
     """Yields all decoded key, value pairs from the SQLite cache.
 

@@ -126,7 +114,7 @@ def create_key_value_store(config: KeyValueStoreCacheConfig) -> KeyValueStore:
     elif isinstance(config, BlackHoleCacheConfig):
         return BlackHoleKeyValueStore()
     else:
-        raise ValueError(f"
+        raise ValueError(f"CacheConfig with unknown type: {config}")
 
 
 @retry

@@ -189,14 +177,8 @@ class Cache(object):
 
     def __init__(self, config: CacheConfig):
         hlog(f"Created cache with config: {config}")
-        self.config: KeyValueStoreCacheConfig
-        self.follower_config: Optional[KeyValueStoreCacheConfig]
         if isinstance(config, KeyValueStoreCacheConfig):
             self.config = config
-            self.follower_config = None
-        elif isinstance(config, WithFollowerCacheConfig):
-            self.config = config.main
-            self.follower_config = config.follower
         else:
             raise ValueError(f"CacheConfig with unknown type: {config}")
 

@@ -216,8 +198,4 @@ class Cache(object):
             response = compute()
 
             write_to_key_value_store(key_value_store, request, response)
-            if self.follower_config is not None:
-                # TODO: Initialize follower_key_value_store in constructor
-                with create_key_value_store(self.follower_config) as follower_key_value_store:
-                    write_to_key_value_store(follower_key_value_store, request, response)
         return response, cached
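Note: with the follower-cache path removed, a Cache always wraps exactly one key-value store. A minimal usage sketch (the path and request dict are illustrative):

from helm.common.cache import Cache, SqliteCacheConfig

cache = Cache(SqliteCacheConfig("cache.sqlite"))

def compute():
    return {"completion": "expensive result"}  # only computed on a cache miss

response, cached = cache.get({"prompt": "hi"}, compute)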
helm/common/images_utils.py
CHANGED

@@ -1,3 +1,4 @@
+from hashlib import md5
 import base64
 import io
 import os

@@ -44,6 +45,11 @@ def encode_base64(image_location: str, format="JPEG") -> str:
     return base64.b64encode(image_file.getvalue()).decode("ascii")
 
 
+def generate_hash(image: Image.Image) -> str:
+    """Generates a hash for the image."""
+    return md5(image.tobytes()).hexdigest()
+
+
 def copy_image(src: str, dest: str, width: Optional[int] = None, height: Optional[int] = None) -> None:
     """
     Copies the image file from `src` path to `dest` path. If dimensions `width` and `height`
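Note: generate_hash gives a stable content fingerprint for PIL images. A usage sketch (the image here is synthetic):

from PIL import Image
from helm.common.images_utils import generate_hash

image = Image.new("RGB", (8, 8), color="white")
print(generate_hash(image))  # MD5 hex digest of the raw pixel bytes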
helm/common/key_value_store.py
CHANGED

@@ -15,11 +15,11 @@ class KeyValueStore(contextlib.AbstractContextManager):
     """Key value store that persists writes."""
 
     @abstractmethod
-    def contains(self, key: Dict) -> bool:
+    def contains(self, key: Mapping) -> bool:
         pass
 
     @abstractmethod
-    def get(self, key: Dict) -> Optional[Dict]:
+    def get(self, key: Mapping) -> Optional[Dict]:
         pass
 
     @abstractmethod

@@ -35,7 +35,7 @@ class KeyValueStore(contextlib.AbstractContextManager):
         pass
 
     @abstractmethod
-    def remove(self, key: Dict) -> None:
+    def remove(self, key: Mapping) -> None:
         pass
 
 

@@ -53,10 +53,10 @@ class SqliteKeyValueStore(KeyValueStore):
     def __exit__(self, exc_type, exc_value, traceback) -> None:
         self._sqlite_dict.__exit__(exc_type, exc_value, traceback)
 
-    def contains(self, key: Dict) -> bool:
+    def contains(self, key: Mapping) -> bool:
         return request_to_key(key) in self._sqlite_dict
 
-    def get(self, key: Dict) -> Optional[Dict]:
+    def get(self, key: Mapping) -> Optional[Dict]:
         key_string = request_to_key(key)
         result = self._sqlite_dict.get(key_string)
         if result is not None:

@@ -77,7 +77,7 @@ class SqliteKeyValueStore(KeyValueStore):
         for key, value in pairs:
             self.put(key, value)
 
-    def remove(self, key: Dict) -> None:
+    def remove(self, key: Mapping) -> None:
         del self._sqlite_dict[key]
         self._sqlite_dict.commit()
 

@@ -91,10 +91,10 @@ class BlackHoleKeyValueStore(KeyValueStore):
     def __exit__(self, exc_type, exc_value, traceback) -> None:
         pass
 
-    def contains(self, key: Dict) -> bool:
+    def contains(self, key: Mapping) -> bool:
         return False
 
-    def get(self, key: Dict) -> Optional[Dict]:
+    def get(self, key: Mapping) -> Optional[Dict]:
         return None
 
     def get_all(self) -> Generator[Tuple[Dict, Dict], None, None]:

@@ -109,5 +109,5 @@ class BlackHoleKeyValueStore(KeyValueStore):
     def multi_put(self, pairs: Iterable[Tuple[Dict, Dict]]) -> None:
         return None
 
-    def remove(self, key: Dict) -> None:
+    def remove(self, key: Mapping) -> None:
         return None
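Note: widening the key parameter from Dict to Mapping means any read-only mapping is accepted, not just a mutable dict. A small sketch of the difference:

from types import MappingProxyType
from typing import Mapping

def contains(key: Mapping) -> bool:
    return len(key) > 0

# A read-only view type-checks against Mapping but would be rejected by a Dict annotation.
frozen_key = MappingProxyType({"prompt": "hi"})
print(contains(frozen_key))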
helm/common/mongo_key_value_store.py
CHANGED

@@ -39,11 +39,11 @@ class MongoKeyValueStore(KeyValueStore):
         serialized = json.dumps(key, sort_keys=True)
         return json.loads(serialized, object_pairs_hook=SON)
 
-    def contains(self, key: Dict) -> bool:
+    def contains(self, key: Mapping) -> bool:
         query = {self._REQUEST_KEY: self._canonicalize_key(key)}
         return self._collection.find_one(query) is not None
 
-    def get(self, key: Dict) -> Optional[Dict]:
+    def get(self, key: Mapping) -> Optional[Dict]:
         query = {self._REQUEST_KEY: self._canonicalize_key(key)}
         document = self._collection.find_one(query)
         if document is not None:

@@ -84,5 +84,6 @@ class MongoKeyValueStore(KeyValueStore):
         # Note: unlike put, multi_put does not support documents with null bytes in keys.
         self._collection.bulk_write(operations)
 
-    def remove(self, key: Dict) -> None:
-        self._collection.delete_one({self._REQUEST_KEY: self._canonicalize_key(key)})
+    def remove(self, key: Mapping) -> None:
+        query = {self._REQUEST_KEY: self._canonicalize_key(key)}
+        self._collection.delete_one(query)
helm/common/request.py
CHANGED

@@ -72,6 +72,22 @@ class Request:
     image_generation_parameters: Optional[ImageGenerationParameters] = None
     """Parameters for image generation."""
 
+    def validate(self):
+        if (
+            (self.messages and self.prompt)
+            or (self.messages and self.multimodal_prompt)
+            or (self.prompt and self.multimodal_prompt)
+        ):
+            raise ValueError("Exactly one of the messages, prompt, multimodal_prompt fields should be set")
+
+        if self.multimodal_prompt:
+            for media_object in self.multimodal_prompt.media_objects:
+                if media_object.content_type == "text" and media_object.text is None:
+                    raise ValueError("Media object with text content type must have text set")
+
+                if media_object.content_type == "image" and media_object.location is None:
+                    raise ValueError("Media object with image content type must have location set")
+
     @property
     def model_host(self) -> str:
         """Returns the model host (referring to the deployment).
|