evalscope 0.17.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +4 -1
- evalscope/api/__init__.py +0 -0
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +3 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +683 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +155 -0
- evalscope/api/benchmark/benchmark.py +321 -0
- evalscope/api/benchmark/meta.py +115 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +261 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +355 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +264 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +11 -0
- evalscope/api/messages/chat_message.py +198 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +105 -0
- evalscope/api/mixin/__init__.py +2 -0
- evalscope/api/mixin/dataset_mixin.py +105 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +157 -0
- evalscope/api/model/model.py +383 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +23 -11
- evalscope/app/utils/data_utils.py +42 -26
- evalscope/app/utils/text_utils.py +0 -2
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +6 -7
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -3
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +2 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aigc/i2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/i2i/general_i2i_adapter.py +44 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +53 -55
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +41 -46
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +29 -45
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +34 -44
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +16 -27
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +181 -160
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +94 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +183 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +135 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +136 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +192 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +112 -54
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -265
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +95 -54
- evalscope/constants.py +29 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +277 -423
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +32 -30
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +47 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +123 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +698 -0
- evalscope/perf/benchmark.py +2 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +7 -5
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +8 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -2
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +101 -6
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +26 -44
- evalscope/summarizer.py +1 -1
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +2 -1
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/io_utils.py +100 -5
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +10 -7
- evalscope/utils/multi_choices.py +271 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/METADATA +98 -49
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/RECORD +234 -216
- tests/aigc/test_t2i.py +22 -4
- tests/benchmark/__init__.py +1 -0
- tests/benchmark/test_eval.py +386 -0
- tests/cli/test_all.py +3 -5
- tests/cli/test_collection.py +13 -4
- tests/cli/test_custom.py +22 -15
- tests/rag/test_clip_benchmark.py +1 -0
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/top_level.txt +0 -0

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py
CHANGED

@@ -111,7 +111,8 @@ class BlipVQA(BlipBase):
 
         image_embeds = self.visual_encoder.forward_features(samples['image'])
         encoder_output = self.text_encoder.forward_automask(
-            tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds
+            tokenized_text=samples['tokenized_text'], visual_embeds=image_embeds
+        )
 
         return encoder_output, image_embeds
 
@@ -150,15 +151,17 @@ class BlipVQA(BlipBase):
 
         return loss, answer_output, answer_targets
 
-    def predict_answers(
+    def predict_answers(
+        self,
+        samples,
+        num_beams=3,
+        inference_method='rank',
+        max_len=10,
+        min_len=1,
+        num_ans_candidates=128,
+        answer_list=None,
+        **kwargs
+    ):
         """
         Args:
             samples (dict): A dictionary containing the following keys:
@@ -204,8 +207,8 @@ class BlipVQA(BlipBase):
         if isinstance(samples['text_input'], str):
             samples['text_input'] = [samples['text_input']]
 
-        assert len(samples['text_input']
+        assert len(samples['text_input']
+                   ) == samples['image'].size(0), 'The number of questions must be equal to the batch size.'
 
         if inference_method == 'generate':
             return self._generate_answers(samples, num_beams=num_beams, max_length=max_len, min_length=min_len)
@@ -239,7 +242,8 @@ class BlipVQA(BlipBase):
             num_beams=num_beams,
             eos_token_id=self.tokenizer.sep_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
-            **model_kwargs
+            **model_kwargs
+        )
 
         # collect answers
         answers = []

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py
CHANGED

@@ -10,10 +10,16 @@ import torch
 import torch.utils.checkpoint
 from torch import Tensor, device, nn
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.utils import logging
 from typing import Tuple
@@ -76,8 +82,10 @@ class BertSelfAttention(nn.Module):
         super().__init__()
         self.config = config
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
-            raise ValueError(
+            raise ValueError(
+                'The hidden size (%d) is not a multiple of the number of attention '
+                'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+            )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -465,7 +473,8 @@ class BertEncoder(nn.Module):
 
                 if use_cache:
                     logger.warn(
-                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                    )
                     use_cache = False
 
                 def create_custom_forward(module):
@@ -506,13 +515,15 @@ class BertEncoder(nn.Module):
             all_hidden_states = all_hidden_states + (hidden_states, )
 
         if not return_dict:
-            return tuple(
+            return tuple(
+                v for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ] if v is not None
+            )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
@@ -703,8 +714,11 @@ class BertModel(BertPreTrainedModel):
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
         else:
-            raise ValueError(
+            raise ValueError(
+                'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+                    input_shape, attention_mask.shape
+                )
+            )
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -753,7 +767,8 @@ class BertModel(BertPreTrainedModel):
         """
         output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
 
         if is_decoder:
@@ -786,8 +801,9 @@ class BertModel(BertPreTrainedModel):
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape, device, is_decoder
+        )
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py
CHANGED

@@ -39,7 +39,8 @@ class Bottleneck(nn.Module):
             self.downsample = nn.Sequential(
                 OrderedDict([('-1', nn.AvgPool2d(stride)),
                              ('0', nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
-                             ('1', nn.BatchNorm2d(planes * self.expansion))])
+                             ('1', nn.BatchNorm2d(planes * self.expansion))])
+            )
 
     def forward(self, x: torch.Tensor):
         identity = x
@@ -91,7 +92,8 @@ class AttentionPool2d(nn.Module):
             out_proj_bias=self.c_proj.bias,
             use_separate_proj_weight=True,
             training=self.training,
-            need_weights=False
+            need_weights=False
+        )
 
         return x[0]
 
@@ -120,7 +122,8 @@ class ResidualAttentionBlock(nn.Module):
         self.ln_1 = LayerNorm(d_model)
         self.mlp = nn.Sequential(
             OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), ('gelu', QuickGELU()),
-                         ('c_proj', nn.Linear(d_model * 4, d_model))])
+                         ('c_proj', nn.Linear(d_model * 4, d_model))])
+        )
         self.ln_2 = LayerNorm(d_model)
         self.attn_mask = attn_mask
 
@@ -141,18 +144,16 @@ class ResidualAttentionBlock(nn.Module):
 
 class Transformer(nn.Module):
 
-    def __init__(
-        heads: int,
-        attn_mask: torch.Tensor = None,
-        use_grad_checkpointing=False):
+    def __init__(
+        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False
+    ):
         super().__init__()
         self.width = width
         self.layers = layers
         self.resblocks = nn.Sequential(
             *
-            [ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)]
+            [ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i > 12) for i in range(layers)]
+        )
 
     def forward(self, x: torch.Tensor):
         return self.resblocks(x)
@@ -160,8 +161,9 @@ class Transformer(nn.Module):
 
 class VisionTransformer(nn.Module):
 
-    def __init__(
+    def __init__(
+        self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool
+    ):
         super().__init__()
         self.input_resolution = input_resolution
         self.num_features = width

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py
CHANGED

@@ -72,15 +72,17 @@ class Mlp(nn.Module):
 
 class Attention(nn.Module):
 
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.,
+        proj_drop=0.,
+        window_size=None,
+        attn_head_dim=None
+    ):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -100,8 +102,9 @@ class Attention(nn.Module):
         if window_size:
             self.window_size = window_size
             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
-            self.relative_position_bias_table = nn.Parameter(
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads)
+            )  # 2*Wh-1 * 2*Ww-1, nH
             # cls to token & token 2 cls & cls to cls
 
             # get pair-wise relative position index for each token inside the window
@@ -166,20 +169,22 @@ class Attention(nn.Module):
 
 class Block(nn.Module):
 
-    def __init__(
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.,
+        attn_drop=0.,
+        drop_path=0.,
+        init_values=None,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        window_size=None,
+        attn_head_dim=None
+    ):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(
@@ -190,7 +195,8 @@ class Block(nn.Module):
             attn_drop=attn_drop,
             proj_drop=drop,
             window_size=window_size,
-            attn_head_dim=attn_head_dim
+            attn_head_dim=attn_head_dim
+        )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
@@ -244,8 +250,9 @@ class RelativePositionBias(nn.Module):
         super().__init__()
         self.window_size = window_size
         self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
-        self.relative_position_bias_table = nn.Parameter(
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
         # cls to token & token 2 cls & cls to cls
 
         # get pair-wise relative position index for each token inside the window
@@ -281,28 +288,30 @@ class VisionTransformer(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
 
-    def __init__(
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        num_classes=1000,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.,
+        qkv_bias=False,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.,
+        norm_layer=nn.LayerNorm,
+        init_values=None,
+        use_abs_pos_emb=True,
+        use_rel_pos_bias=False,
+        use_shared_rel_pos_bias=False,
+        use_mean_pooling=True,
+        init_scale=0.001,
+        use_checkpoint=False
+    ):
         super().__init__()
         self.image_size = img_size
         self.num_classes = num_classes
@@ -338,7 +347,8 @@ class VisionTransformer(nn.Module):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 init_values=init_values,
-                window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None
+                window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None
+            ) for i in range(depth)
         ])
         # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
         # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
@@ -450,7 +460,8 @@ def interpolate_pos_embed(model, checkpoint_model):
         pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
         pos_tokens = torch.nn.functional.interpolate(
-            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+        )
         pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
         checkpoint_model['pos_embed'] = new_pos_embed

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py
CHANGED

@@ -20,13 +20,23 @@ from torch.nn import CrossEntropyLoss
 from transformers import BatchEncoding, PreTrainedTokenizer
 from transformers.activations import ACT2FN
 from transformers.file_utils import ModelOutput
-from transformers.modeling_outputs import (
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.utils import logging
 from typing import Optional, Tuple
@@ -102,8 +112,10 @@ class BertSelfAttention(nn.Module):
         super().__init__()
         self.config = config
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, 'embedding_size'):
-            raise ValueError(
+            raise ValueError(
+                'The hidden size (%d) is not a multiple of the number of attention '
+                'heads (%d)' % (config.hidden_size, config.num_attention_heads)
+            )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -408,8 +420,9 @@ class BertLayer(nn.Module):
                 output_attentions=output_attentions,
             )
             attention_output = cross_attention_outputs[0]
-            outputs = (
+            outputs = (
+                outputs + cross_attention_outputs[1:-1]
+            )  # add cross attentions if we output attention weights
         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk,
             self.chunk_size_feed_forward,
@@ -492,7 +505,8 @@ class BertEncoder(nn.Module):
 
                 if use_cache:
                     logger.warn(
-                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                        '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+                    )
                     use_cache = False
 
                 def create_custom_forward(module):
@@ -533,13 +547,15 @@ class BertEncoder(nn.Module):
            all_hidden_states = all_hidden_states + (hidden_states, )
 
         if not return_dict:
-            return tuple(
+            return tuple(
+                v for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ] if v is not None
+            )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
@@ -730,8 +746,11 @@ class BertModel(BertPreTrainedModel):
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
         else:
-            raise ValueError(
+            raise ValueError(
+                'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'.format(
+                    input_shape, attention_mask.shape
+                )
+            )
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -781,7 +800,8 @@ class BertModel(BertPreTrainedModel):
         """
         output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
 
         if is_decoder:
@@ -814,8 +834,9 @@ class BertModel(BertPreTrainedModel):
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, input_shape, device, is_decoder
+        )
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1176,18 +1197,20 @@ class XBertLMHeadDecoder(BertLMHeadModel):
         else:
             return cls(config=med_config)
 
-    def generate_from_encoder(
+    def generate_from_encoder(
+        self,
+        tokenized_prompt,
+        visual_embeds,
+        sep_token_id,
+        pad_token_id,
+        use_nucleus_sampling=False,
+        num_beams=3,
+        max_length=30,
+        min_length=10,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        **kwargs
+    ):
 
         if not use_nucleus_sampling:
             num_beams = num_beams
@@ -1212,7 +1235,8 @@ class XBertLMHeadDecoder(BertLMHeadModel):
                 eos_token_id=sep_token_id,
                 pad_token_id=pad_token_id,
                 repetition_penalty=1.1,
-                **model_kwargs
+                **model_kwargs
+            )
         else:
             # beam search
             outputs = self.generate(
@@ -1223,7 +1247,8 @@ class XBertLMHeadDecoder(BertLMHeadModel):
                 eos_token_id=sep_token_id,
                 pad_token_id=pad_token_id,
                 repetition_penalty=repetition_penalty,
-                **model_kwargs
+                **model_kwargs
+            )
 
         return outputs
 

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py
CHANGED

@@ -343,9 +343,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
         block.attn.qkv.weight.copy_(
-            torch.cat([_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])
+            torch.cat([_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])
+        )
         block.attn.qkv.bias.copy_(
-            torch.cat([_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])
+            torch.cat([_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])
+        )
         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
         for r in range(2):
@@ -394,7 +396,8 @@ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
         pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
         pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
         pos_tokens = torch.nn.functional.interpolate(
-            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False
+        )
         pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
         new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
         print('reshape position embedding from %d to %d' % (orig_size**2, new_size**2))

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py
CHANGED

@@ -7,8 +7,12 @@
 
 from ..common.registry import registry
 from .base_processor import BaseProcessor
-from .blip_processors import (
+from .blip_processors import (
+    Blip2ImageTrainProcessor,
+    BlipCaptionProcessor,
+    BlipImageEvalProcessor,
+    BlipImageTrainProcessor,
+)
 
 __all__ = [
     'BaseProcessor',

evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py
CHANGED

@@ -107,8 +107,9 @@ def color_func(img, factor):
     # np.eye(3) * factor
     # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
     # )[np.newaxis, np.newaxis, :]
-    M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32(
+    M = np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]) * factor + np.float32([[
+        0.114
+    ], [0.587], [0.299]])
     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
     return out
 