crfm-helm 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/METADATA +27 -13
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/RECORD +203 -156
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +1 -1
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +1 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +6 -8
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +33 -12
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +2 -1
- helm/benchmark/presentation/summarize.py +76 -59
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +78 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/long_context_run_specs.py +67 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +169 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification.py +103 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +110 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +78 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +109 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +8 -5
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +12 -7
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/numeracy_scenario.py +2 -1
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +63 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-94295e78.js +10 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +197 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +100 -54
- helm/clients/openai_responses_client.py +174 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/together_client.py +31 -4
- helm/clients/vertexai_client.py +6 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +46 -3
- helm/common/local_context.py +140 -0
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/config/model_deployments.yaml +864 -193
- helm/config/model_metadata.yaml +667 -53
- helm/config/tokenizer_configs.yaml +144 -3
- helm/proxy/cli.py +3 -1
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +53 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.6.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/benchmark/static_build/index.html
@@ -2,16 +2,16 @@
   <html lang="en">
   <head>
     <meta charset="UTF-8" />
-   <link rel="icon" type="image/svg+xml" href="
+   <link rel="icon" type="image/svg+xml" href="https://crfm.stanford.edu/helm/helm.svg" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>Holistic Evaluation of Language Models (HELM)</title>
     <meta name="description" content="The Holistic Evaluation of Language Models (HELM) serves as a living benchmark for transparency in language models. Providing broad coverage and recognizing incompleteness, multi-metric measurements, and standardization. All data and analysis are freely accessible on the website for exploration and study." />
     <script type="text/javascript" src="./config.js"></script>
-   <script type="module" crossorigin src="./assets/index-262903c1.js"></script>
+   <script type="module" crossorigin src="./assets/index-94295e78.js"></script>
     <link rel="modulepreload" crossorigin href="./assets/react-f82877fd.js">
     <link rel="modulepreload" crossorigin href="./assets/recharts-4037aff0.js">
-   <link rel="modulepreload" crossorigin href="./assets/tremor-9cefc3c5.js">
-   <link rel="stylesheet" href="./assets/index-42060d71.css">
+   <link rel="modulepreload" crossorigin href="./assets/tremor-38a10867.js">
+   <link rel="stylesheet" href="./assets/index-b9779128.css">
   </head>
   <body class="block">
     <div id="root"></div>
helm/benchmark/window_services/encoder_decoder_window_service.py
@@ -1,6 +1,6 @@
  from abc import ABC

- from helm.common.hierarchical_logger import
+ from helm.common.hierarchical_logger import hwarn
  from helm.benchmark.window_services.local_window_service import LocalWindowService


@@ -21,8 +21,8 @@ class EncoderDecoderWindowService(LocalWindowService, ABC):
          vs. the completions, we check the two values separately.
          """
          if expected_completion_token_length > self.max_output_length:
-
-             f"
+             hwarn(
+                 f"The expected completion token length ({expected_completion_token_length}) "
                  f"exceeds the max output length ({self.max_output_length})."
              )
          return self.get_num_tokens(text) <= self.max_request_length
helm/benchmark/window_services/test_utils.py
@@ -1,8 +1,7 @@
  from typing import List

- from helm.common.
+ from helm.common.local_context import LocalContext
  from helm.common.cache_backend_config import CacheBackendConfig
- from helm.proxy.services.server_service import ServerService
  from helm.benchmark.metrics.metric_service import MetricService
  from helm.benchmark.window_services.tokenizer_service import TokenizerService

@@ -229,5 +228,5 @@ GPT4_TEST_TOKENS: List[str] = [


  def get_tokenizer_service(local_path: str, cache_backend_config: CacheBackendConfig) -> TokenizerService:
-
-     return MetricService(
+     context = LocalContext(base_path=local_path, cache_backend_config=cache_backend_config)
+     return MetricService(context)
helm/benchmark/window_services/tokenizer_service.py
@@ -1,26 +1,25 @@
- from helm.common.
+ from helm.common.context import Context
  from helm.common.tokenization_request import (
      TokenizationRequest,
      TokenizationRequestResult,
      DecodeRequest,
      DecodeRequestResult,
  )
- from helm.proxy.services.service import Service


+ # TODO: Rename this to TokenizerContext
  class TokenizerService:
      """
-     A wrapper around `
+     A wrapper around `Context` that makes only necessary server requests to tokenize.
      """

-     def __init__(self,
-         self.
-         self._auth: Authentication = auth
+     def __init__(self, context: Context):
+         self._context: Context = context

      def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
          """Tokenize via an API."""
-         return self.
+         return self._context.tokenize(request)

      def decode(self, request: DecodeRequest) -> DecodeRequestResult:
          """Decode via an API."""
-         return self.
+         return self._context.decode(request)
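Note: the window-service changes above replace the old ServerService/Authentication wiring with the new Context abstraction (helm/common/context.py, local_context.py and remote_context.py in the file list). A minimal sketch of how a TokenizerService would now be constructed locally; the BlackHoleCacheBackendConfig class and the tokenizer name used here are assumptions for illustration, not taken from this diff:

# Minimal sketch, assuming BlackHoleCacheBackendConfig and the "huggingface/gpt2" tokenizer are available.
from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
from helm.common.local_context import LocalContext
from helm.common.tokenization_request import TokenizationRequest
from helm.benchmark.window_services.tokenizer_service import TokenizerService

context = LocalContext(base_path="prod_env", cache_backend_config=BlackHoleCacheBackendConfig())
tokenizer_service = TokenizerService(context)  # previously took a ServerService plus an Authentication object
result = tokenizer_service.tokenize(TokenizationRequest(tokenizer="huggingface/gpt2", text="Hello, world!"))
print(result.tokens)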
helm/clients/anthropic_client.py
CHANGED
@@ -1,3 +1,4 @@
+ import dataclasses
  from typing import Any, Dict, List, Optional, TypedDict, Union, cast
  import json
  import os
@@ -7,10 +8,11 @@ import time
  import urllib.parse

  from helm.common.cache import CacheConfig
- from helm.common.hierarchical_logger import htrack_block, hlog
+ from helm.common.hierarchical_logger import htrack_block, hlog, hwarn
  from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
  from helm.common.optional_dependencies import handle_module_not_found_error
  from helm.common.request import (
+     Thinking,
      wrap_request_time,
      EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
      Request,
@@ -30,8 +32,12 @@ from helm.clients.client import CachingClient, truncate_sequence, truncate_and_t
  try:
      from anthropic import Anthropic, BadRequestError
      from anthropic.types import MessageParam
+     from anthropic.types.message import Message
+     from anthropic.types.text_block import TextBlock
+     from anthropic.types.thinking_block import ThinkingBlock
      from anthropic.types.image_block_param import ImageBlockParam
      from anthropic.types.text_block_param import TextBlockParam
+     from anthropic.types.thinking_config_enabled_param import ThinkingConfigEnabledParam
      import websocket
  except ModuleNotFoundError as e:
      handle_module_not_found_error(e, ["anthropic"])
@@ -231,30 +237,41 @@ class AnthropicMessagesRequest(TypedDict, total=False):
      temperature: float
      top_k: int
      top_p: float
+     thinking: ThinkingConfigEnabledParam


  class AnthropicMessagesRequestError(NonRetriableException):
      pass


- class
+ class AnthropicMessagesEmptyContentError(Exception):
      pass


  class AnthropicMessagesClient(CachingClient):
      # Source: https://docs.anthropic.com/claude/docs/models-overview
-     MAX_OUTPUT_TOKENS: int =
+     MAX_OUTPUT_TOKENS: int = 64000

      MAX_IMAGE_SIZE_BYTES: int = 5242880  # 5MB

      def __init__(
-         self,
+         self,
+         tokenizer: Tokenizer,
+         tokenizer_name: str,
+         cache_config: CacheConfig,
+         thinking_budget_tokens: Optional[int] = None,
+         anthropic_model_name: Optional[str] = None,
+         api_key: Optional[str] = None,
+         stream: Optional[bool] = None,
      ):
          super().__init__(cache_config=cache_config)
          self.tokenizer = tokenizer
          self.tokenizer_name = tokenizer_name
          self.client = Anthropic(api_key=api_key)
          self.api_key: Optional[str] = api_key
+         self.anthropic_model_name: Optional[str] = anthropic_model_name
+         self.thinking_budget_tokens: Optional[int] = thinking_budget_tokens
+         self.stream: Optional[bool] = stream

      def make_request(self, request: Request) -> RequestResult:
          if request.max_tokens > AnthropicMessagesClient.MAX_OUTPUT_TOKENS:
@@ -293,8 +310,8 @@ class AnthropicMessagesClient(CachingClient):
              image_width > AnthropicClient.MAX_IMAGE_DIMENSION
              or image_height > AnthropicClient.MAX_IMAGE_DIMENSION
          ):
-
-             f"
+             hwarn(
+                 f"Image {image_location} exceeds max allowed size: "
                  f"{AnthropicClient.MAX_IMAGE_DIMENSION} pixels"
              )
              # Save the resized image to a temporary file
@@ -309,8 +326,8 @@ class AnthropicMessagesClient(CachingClient):
              base64_image = encode_base64(temp_file.name, format="JPEG")

          elif os.path.getsize(image_location) > AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES:
-
-             f"
+             hwarn(
+                 f"Image {image_location} exceeds max allowed size: "
                  f"{AnthropicMessagesClient.MAX_IMAGE_SIZE_BYTES} bytes"
              )
              # Resize the image so it is smaller than the max allowed size
@@ -351,7 +368,7 @@ class AnthropicMessagesClient(CachingClient):

          raw_request: AnthropicMessagesRequest = {
              "messages": messages,
-             "model": request.model_engine,
+             "model": self.anthropic_model_name or request.model_engine,
              "stop_sequences": request.stop_sequences,
              "max_tokens": request.max_tokens,
              "temperature": request.temperature,
@@ -360,6 +377,15 @@ class AnthropicMessagesClient(CachingClient):
          }
          if system_message is not None:
              raw_request["system"] = cast(str, system_message["content"])
+         if self.thinking_budget_tokens:
+             raw_request["thinking"] = {
+                 "type": "enabled",
+                 "budget_tokens": self.thinking_budget_tokens,
+             }
+             # Avoid error:
+             # `top_k` must be unset when thinking is enabled. Please consult our documentation at https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking  # noqa: E501
+             del raw_request["top_k"]
+
          completions: List[GeneratedOutput] = []

          # `num_completions` is not supported, so instead make `num_completions` separate requests.
@@ -367,11 +393,15 @@ class AnthropicMessagesClient(CachingClient):

              def do_it() -> Dict[str, Any]:
                  try:
-
+                     if self.stream:
+                         with self.client.messages.stream(**raw_request) as message_stream:
+                             result = message_stream.get_final_message().model_dump()
+                     else:
+                         result = self.client.messages.create(**raw_request).model_dump()
                      if "content" not in result or not result["content"]:
-                         raise
-                     elif "text" not in result["content"][
-                         raise
+                         raise AnthropicMessagesEmptyContentError(f"Anthropic response has empty content: {result}")
+                     elif "text" not in result["content"][-1]:
+                         raise AnthropicMessagesEmptyContentError(f"Anthropic response has non-text content: {result}")
                      return result
                  except BadRequestError as e:
                      response = e.response.json()
@@ -387,9 +417,10 @@ class AnthropicMessagesClient(CachingClient):
                      },
                      request,
                  )
-
-
-
+                 raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+
+             except AnthropicMessagesEmptyContentError:
+                 hwarn("Anthropic response has empty content")
                  return RequestResult(
                      success=False,
                      cached=False,
@@ -399,32 +430,41 @@ class AnthropicMessagesClient(CachingClient):
                      error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
                  )

-             if _is_content_moderation_failure(
-
-                 f"WARNING: Returning empty request for {request.model_deployment} "
-                 "due to content moderation filter"
-                 )
+             if _is_content_moderation_failure(raw_response):
+                 hwarn(f"Returning empty request for {request.model_deployment} " "due to content moderation filter")
                  return RequestResult(
                      success=False,
                      cached=cached,
-                     error=
+                     error=raw_response["error"]["message"],
                      completions=[],
                      embedding=[],
                      error_flags=ErrorFlags(is_retriable=False, is_fatal=False),
-                     request_time=
-                     request_datetime=
+                     request_time=raw_response["request_time"],
+                     request_datetime=raw_response["request_datetime"],
                  )

+             response_message: Message = Message.model_validate(raw_response)
+             response_text: Optional[str] = None
+             response_thinking: Optional[str] = None
+             for content in response_message.content:
+                 if isinstance(content, TextBlock):
+                     response_text = content.text
+                 elif isinstance(content, ThinkingBlock):
+                     response_thinking = content.thinking
+             if response_text is None:
+                 raise Exception("Anthropic response did not contain text block")
              completion = truncate_and_tokenize_response_text(
-
+                 response_text, request, self.tokenizer, self.tokenizer_name, original_finish_reason=""
              )
+             if response_thinking is not None:
+                 completion = dataclasses.replace(completion, thinking=Thinking(text=response_thinking))
              completions.append(completion)

          return RequestResult(
              success=True,
              cached=cached,
-             request_time=
-             request_datetime=
+             request_time=raw_response["request_time"],
+             request_datetime=raw_response["request_datetime"],
              completions=completions,
              embedding=[],
          )
@@ -617,8 +657,8 @@ class AnthropicLegacyClient(CachingClient):
          if logprobs["tokens"] != tokens:
              # This is a known limitation with the Anthropic API. For now keep track of the
              # entries with the mismatch.
-
-             f"
+             hwarn(
+                 f"naive truncation for logprobs did not work."
                  f"\nRequest:{raw_request}\nExpected: {tokens}\nActual: {logprobs['tokens']}"
              )
              check_logprobs = True
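Note: the net effect of the anthropic_client.py changes is that the Messages client can run Claude models with extended thinking enabled and attach the thinking text to the completion. A minimal sketch of the request body the client builds when thinking_budget_tokens is set; the model name and budget below are placeholders, not values from this diff:

# Sketch of the raw Messages request with extended thinking enabled,
# mirroring the diff above: "thinking" is added and "top_k" must be removed.
raw_request = {
    "model": "claude-3-7-sonnet-20250219",  # placeholder model name
    "messages": [{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
    "max_tokens": 4096,
    "temperature": 1.0,
    "top_k": -1,
    "thinking": {"type": "enabled", "budget_tokens": 2048},  # placeholder budget
}
del raw_request["top_k"]  # Anthropic rejects top_k when extended thinking is enabled

The client then calls either client.messages.stream(**raw_request) or client.messages.create(**raw_request) and, if the response contains a ThinkingBlock, copies it into the completion via dataclasses.replace(completion, thinking=Thinking(text=...)).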
helm/clients/audio_language/diva_llama_client.py
@@ -96,9 +96,11 @@ class DivaLlamaClient(CachingClient):
              with _LOCK:
                  audio_input, text_input = DivaLlamaClient._get_generate_input(request)
                  if text_input is None:
-                     return {"completions": self.pre_trained_model.generate([audio_input])}
+                     return {"completions": self.pre_trained_model.generate([audio_input])}  # type: ignore
                  else:
-                     return {
+                     return {
+                         "completions": self.pre_trained_model.generate([audio_input], [text_input])  # type: ignore
+                     }

          cache_key = CachingClient.make_cache_key(raw_request, request)
          response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
helm/clients/audio_language/qwen2_5_omni_client.py
@@ -0,0 +1,197 @@
+ from threading import Lock
+ import torch
+ from typing import Any, Dict, List, Optional
+
+ from dataclasses import dataclass
+ from helm.clients.audio_language.qwen_omni.modeling_qwen2_5_omni import Qwen2_5OmniModel
+ from helm.clients.audio_language.qwen_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessor
+ from helm.clients.audio_language.qwen_omni.qwen2_5_omni_utils.v2_5 import process_mm_info
+
+ from helm.common.cache import CacheConfig
+ from helm.common.gpu_utils import get_torch_device_name
+ from helm.common.hierarchical_logger import hlog, htrack_block
+ from helm.common.media_object import TEXT_TYPE
+ from helm.common.request import Request, RequestResult, GeneratedOutput, Token
+ from helm.common.request import wrap_request_time
+ from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt
+
+
+ @dataclass(frozen=True)
+ class LoadedQwen2_5OmniModelProcessor:
+     """Loaded model and processor for Qwen."""
+
+     model: Qwen2_5OmniModel
+     tokenizer: Qwen2_5OmniProcessor
+
+
+ _models_lock: Lock = Lock()
+ _models: Dict[str, Optional[LoadedQwen2_5OmniModelProcessor]] = {
+     "Qwen/Qwen2.5-Omni-7B": None,
+ }
+
+
+ class Qwen2_5OmniAudioLMClient(CachingClient):
+     """
+     From https://huggingface.co/Qwen/Qwen2.5-Omni-7B,
+     Qwen2.5-Omni is an end-to-end multimodal model designed to perceive diverse modalities, including text,
+     images, audio, and video, while simultaneously generating text and natural speech responses in a streaming manner.
+
+     Paper: https://arxiv.org/abs/2503.20215
+     """
+
+     END_OF_TEXT_TOKEN: str = "<|endoftext|>>"
+
+     def __init__(self, cache_config: CacheConfig):
+         super().__init__(cache_config=cache_config)
+         self._device: str = get_torch_device_name()
+
+     def _get_model(self, helm_model_name: str) -> LoadedQwen2_5OmniModelProcessor:
+         global _models_lock
+         global _models
+
+         model_name: str
+         if helm_model_name == "qwen2.5-omni-7b":
+             model_name = "Qwen/Qwen2.5-Omni-7B"
+         else:
+             raise ValueError(f"Unhandled model name: {helm_model_name}")
+
+         # Ensure that only one thread is loading the model at a time
+         with _models_lock:
+             loaded_model_processor = _models[model_name]
+             if loaded_model_processor is None:
+                 hlog(f"Loading model {model_name} and caching in memory...")
+                 model = Qwen2_5OmniModel.from_pretrained(
+                     model_name,
+                     attn_implementation="flash_attention_2",
+                     torch_dtype=torch.bfloat16,
+                     device_map=self._device,
+                 ).eval()
+                 tokenizer = Qwen2_5OmniProcessor.from_pretrained(
+                     model_name,
+                 )
+                 _models[model_name] = LoadedQwen2_5OmniModelProcessor(model, tokenizer)
+                 loaded_model_processor = _models[model_name]
+
+         assert loaded_model_processor is not None
+         return loaded_model_processor
+
+     def make_request(self, request: Request) -> RequestResult:
+         assert request.multimodal_prompt is not None, "Multimodal prompt is required"
+
+         loaded_model_processor: LoadedQwen2_5OmniModelProcessor = self._get_model(request.model_engine)
+         model = loaded_model_processor.model
+         tokenizer = loaded_model_processor.tokenizer
+
+         input_query: List[Dict[str, Any]] = []
+         query: List[Dict[str, str]] = []
+         prompt_text: str = ""
+
+         input_query.append(
+             {
+                 "role": "system",
+                 "content": (
+                     "You are Qwen, a virtual human developed by the Qwen Team,"
+                     " Alibaba Group, capable of perceiving auditory and visual inputs,"
+                     " as well as generating text and speech."
+                 ),
+             }
+         )
+         # prompt_text += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+         for media_num, media_object in enumerate(request.multimodal_prompt.media_objects):
+             if media_object.is_type("audio") and media_object.location:
+                 assert media_object.is_local_file, "Only local audio files are supported"
+                 query.append({"type": "audio", "audio": media_object.location})
+
+                 # prompt_text += f"<|im_start|>user\nAudio {media_num+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+             elif media_object.is_type(TEXT_TYPE):
+                 if media_object.text is None:
+                     raise ValueError("MediaObject of text type has missing text field value")
+                 query.append({"type": "text", "text": media_object.text})
+                 # prompt_text += media_object.text
+             else:
+                 raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+         # prompt_text += "<|im_end|>\n<|im_start|>assistant\n"
+
+         input_query.append({"role": "user", "content": query})
+
+         completions: List[GeneratedOutput] = []
+         request_time: float = 0
+         request_datetime: Optional[int] = None
+         all_cached: bool = True
+
+         with htrack_block(f"Generating for prompt: {prompt_text}"):
+             for completion_index in range(request.num_completions):
+                 try:
+
+                     def do_it() -> Dict[str, Any]:
+                         # Refer to the official Qwen2.5-Omni documentation for the format of the input query
+                         # https://huggingface.co/Qwen/Qwen2.5-Omni-7B
+                         USE_AUDIO_IN_VIDEO = True
+                         text = tokenizer.apply_chat_template(input_query, add_generation_prompt=True, tokenize=False)
+                         audios, images, videos = process_mm_info(input_query, use_audio_in_video=USE_AUDIO_IN_VIDEO)
+                         inputs = tokenizer(
+                             text=text,
+                             audios=audios,
+                             images=images,
+                             videos=videos,
+                             return_tensors="pt",
+                             padding=True,
+                             use_audio_in_video=USE_AUDIO_IN_VIDEO,
+                         )
+                         inputs = inputs.to(self._device, torch.bfloat16)
+                         input_seq_length = len(inputs.input_ids[0])
+                         # The model runs into errors when setting thinker_max_new_tokens to 1
+                         if request.max_tokens != 1:
+                             pred, _ = model.generate(**inputs, thinker_max_new_tokens=request.max_tokens)
+                             pred_decode = pred.cpu()[0][input_seq_length:]
+                         else:
+                             pred, _ = model.generate(**inputs)
+                             pred_decode = pred.cpu()[0][input_seq_length : input_seq_length + 1]
+                         completion = tokenizer.decode(
+                             pred_decode,
+                             skip_special_tokens=True,
+                             clean_up_tokenization_spaces=False,
+                         )
+                         # The processor of Qwen2-Audio-Instruct consists an AutoTokenizer and a WhisperFeatureExtractor
+                         tokens: List[str] = tokenizer.tokenizer.tokenize(completion)  # type: ignore
+                         return {"output": (completion, tokens)}
+
+                     # Include the prompt and model name in the cache key
+                     cache_key = CachingClient.make_cache_key(
+                         raw_request={
+                             "completion_index": completion_index,
+                             "model": request.model,
+                             "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
+                             "max_tokens": request.max_tokens,
+                         },
+                         request=request,
+                     )
+                     result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+                 except RuntimeError as model_error:
+                     return RequestResult(
+                         success=False, cached=False, error=str(model_error), completions=[], embedding=[]
+                     )
+
+                 text, tokens = result["output"]
+                 hlog(f"Generated: {text}")
+
+                 # Tokenize truncated text to get the list of tokens
+                 completions.append(
+                     GeneratedOutput(
+                         text=text, logprob=0, tokens=[Token(text=str(token), logprob=0) for token in tokens]
+                     )
+                 )
+
+                 request_time += result["request_time"]
+                 # Use the datetime from the first completion because that's when the request was fired
+                 request_datetime = request_datetime or result.get("request_datetime")
+                 all_cached = all_cached and cached
+
+         return RequestResult(
+             success=True,
+             cached=all_cached,
+             request_time=request_time,
+             request_datetime=request_datetime,
+             completions=completions,
+             embedding=[],
+         )
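Note: like the other local audio clients, the new Qwen2.5-Omni client keeps the loaded checkpoint in a module-level dict guarded by a lock, so the model is loaded once per process and reused across requests. A stripped-down sketch of that caching pattern; load_model below is a hypothetical stand-in for the from_pretrained call:

from threading import Lock
from typing import Dict, Optional

_models_lock: Lock = Lock()
_models: Dict[str, Optional[object]] = {"Qwen/Qwen2.5-Omni-7B": None}

def load_model(model_name: str) -> object:
    # Hypothetical stand-in for the expensive Qwen2_5OmniModel.from_pretrained(...) call.
    return object()

def get_cached_model(model_name: str) -> object:
    # Only one thread loads at a time; later callers reuse the cached instance.
    with _models_lock:
        if _models[model_name] is None:
            _models[model_name] = load_model(model_name)
    return _models[model_name]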
helm/clients/audio_language/qwen2_audiolm_client.py
@@ -113,7 +113,9 @@ class Qwen2AudioLMClient(CachingClient):
              try:

                  def do_it() -> Dict[str, Any]:
-                     inputs = tokenizer.apply_chat_template(
+                     inputs = tokenizer.apply_chat_template(  # type: ignore
+                         input_query, add_generation_prompt=True, tokenize=False
+                     )
                      audios: List[Any] = []
                      # Refer to the official Qwen2-Audio documentation for the format of the input query
                      # https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct
@@ -124,13 +126,13 @@ class Qwen2AudioLMClient(CachingClient):
                      audios.append(
                          librosa.load(
                              element["audio_url"],
-                             sr=tokenizer.feature_extractor.sampling_rate,
+                             sr=tokenizer.feature_extractor.sampling_rate,  # type: ignore
                          )[0]
                      )
-                     inputs = tokenizer(
+                     inputs = tokenizer(  # type: ignore
                          text=inputs,
                          audios=audios,
-                         sampling_rate=tokenizer.feature_extractor.sampling_rate,
+                         sampling_rate=tokenizer.feature_extractor.sampling_rate,  # type: ignore
                          return_tensors="pt",
                          padding=True,
                      )
@@ -140,11 +142,11 @@ class Qwen2AudioLMClient(CachingClient):
                      inputs = inputs.to(self._device)
                      pred = model.generate(**inputs, max_length=request.max_tokens + input_length)[:, input_length:]

-                     completion = tokenizer.decode(
+                     completion = tokenizer.decode(  # type: ignore
                          pred.cpu()[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
                      )
                      # The processor of Qwen2-Audio-Instruct consists an AutoTokenizer and a WhisperFeatureExtractor
-                     tokens: List[str] = tokenizer.tokenizer.tokenize(completion)
+                     tokens: List[str] = tokenizer.tokenizer.tokenize(completion)  # type: ignore
                      return {"output": (completion, tokens)}

                  # Include the prompt and model name in the cache key
helm/clients/audio_language/qwen_audiolm_client.py
@@ -106,8 +106,10 @@ class QwenAudioLMClient(CachingClient):
              try:

                  def do_it() -> Dict[str, Any]:
-                     completion, _ = model.chat(
-
+                     completion, _ = model.chat(  # type: ignore
+                         tokenizer, query=tokenizer.from_list_format(query), history=None  # type: ignore
+                     )
+                     tokens: List[str] = tokenizer.tokenize(completion)  # type: ignore
                      return {"output": (completion, tokens)}

                  # Include the prompt and model name in the cache key
helm/clients/audio_language/test.py
@@ -0,0 +1,62 @@
+ import soundfile as sf
+
+ from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor  # type: ignore
+ from qwen_omni_utils import process_mm_info
+
+ # default: Load the model on the available device(s)
+ model = Qwen2_5OmniModel.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto")
+
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving.
+ # model = Qwen2_5OmniModel.from_pretrained(
+ #     "Qwen/Qwen2.5-Omni-7B",
+ #     torch_dtype="auto",
+ #     device_map="auto",
+ #     attn_implementation="flash_attention_2",
+ # )
+
+ processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
+
+ conversation = [
+     {
+         "role": "system",
+         "content": (
+             "You are Qwen, a virtual human developed by the Qwen Team,"
+             " Alibaba Group, capable of perceiving auditory and visual"
+             " inputs, as well as generating text and speech."
+         ),
+     },
+     {
+         "role": "user",
+         "content": [
+             {"type": "video", "video": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw.mp4"},
+         ],
+     },
+ ]
+
+ # set use audio in video
+ USE_AUDIO_IN_VIDEO = True
+
+ # Preparation for inference
+ text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+ audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
+ inputs = processor(
+     text=text,
+     audios=audios,
+     images=images,
+     videos=videos,
+     return_tensors="pt",
+     padding=True,
+     use_audio_in_video=USE_AUDIO_IN_VIDEO,
+ )
+ inputs = inputs.to(model.device).to(model.dtype)
+
+ # Inference: Generation of the output text and audio
+ text_ids, audio = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO)
+
+ text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ print(text)
+ sf.write(
+     "output.wav",
+     audio.reshape(-1).detach().cpu().numpy(),
+     samplerate=24000,
+ )