evalscope 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +2 -1
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +1 -1
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +16 -9
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +4 -1
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -3
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +3 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +16 -4
- evalscope/config.py +7 -3
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +2 -2
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/utils/benchmark_util.py +2 -2
- evalscope/perf/utils/db_util.py +8 -2
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +117 -67
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +3 -3
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +12 -4
- evalscope/version.py +2 -2
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/METADATA +18 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/RECORD +175 -63
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_run.py +11 -5
- tests/perf/test_perf.py +2 -1
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.14.0.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
|
|
7
|
+
Based on timm code base
|
|
8
|
+
https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import math
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
from functools import partial
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from timm.layers import DropPath, PatchEmbed, trunc_normal_
|
|
19
|
+
from timm.models import adapt_input_conv
|
|
20
|
+
except ImportError:
|
|
21
|
+
pass
|
|
22
|
+
|
|
23
|
+
from ..models.base_model import BaseEncoder
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; falls back to ``in_features`` when falsy.
        out_features: size of the output; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable returning the activation module.
        drop: dropout probability, shared by both dropout applications.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # A single Dropout module is reused after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Attention(nn.Module):
    """Multi-head self-attention that can optionally record its attention map and gradients.

    Args:
        dim: channel size of the input tokens.
        num_heads: number of attention heads; the per-head width is ``dim // num_heads``.
        qkv_bias: whether the fused query/key/value projection has a bias.
        qk_scale: multiplier for the attention logits; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head_width = dim // num_heads
        # NOTE: the scale can be set manually (qk_scale) to stay compatible with older weights.
        self.scale = qk_scale if qk_scale else per_head_width**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Filled in by forward(..., register_hook=True) and the subsequent backward pass.
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention weights (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the most recently stored attention gradients, or None."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention weights computed in forward."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the most recently stored attention weights, or None."""
        return self.attention_map

    def forward(self, x, register_hook=False):
        """Run self-attention over ``x`` of shape (B, N, C).

        When ``register_hook`` is true, the post-dropout attention weights are saved and a
        backward hook is registered on them to capture their gradient.
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index explicitly rather than unpacking the tensor (torchscript-friendly).
        q = packed[0]
        k = packed[1]
        v = packed[2]

        weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(weights)
            weights.register_hook(self.save_attn_gradients)

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual connection.

    Args:
        dim: channel size of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        qk_scale: forwarded to ``Attention``.
        drop: dropout probability for the attention output projection and the MLP.
        attn_drop: dropout probability for the attention weights.
        drop_path: stochastic-depth rate; an identity module is used when it is not > 0.
        act_layer: activation constructor forwarded to ``Mlp``.
        norm_layer: constructor taking ``dim`` and returning a normalization module.
        use_grad_checkpointing: wrap attention and MLP with fairscale's ``checkpoint_wrapper``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        use_grad_checkpointing=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path implements stochastic depth on the residual branches.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        if not use_grad_checkpointing:
            return

        # fairscale is imported lazily so it is only needed when checkpointing is requested.
        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
        self.attn = checkpoint_wrapper(self.attn)
        self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        """Apply the attention and MLP residual branches to ``x``.

        ``register_hook`` is passed through to the attention module.
        """
        attn_branch = self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class VisionTransformer(nn.Module):
    """Vision Transformer encoder.

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929

    ``forward`` returns the normalized token sequence (class token first); no classification
    head is built in this class.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        representation_size=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        use_grad_checkpointing=False,
        ckpt_layer=0,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): accepted for signature compatibility; not used by this constructor
            embed_dim (int): embedding dimension
            depth (int): number of transformer blocks
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): accepted for signature compatibility; not used
                by this constructor
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): maximum stochastic depth rate (reached at the last block)
            norm_layer: (nn.Module): normalization layer constructor; LayerNorm(eps=1e-6) if unset
            use_grad_checkpointing (bool): enable activation checkpointing in the last blocks
            ckpt_layer (int): number of trailing blocks that get checkpointing when enabled
        """
        super().__init__()
        # num_features is kept alongside embed_dim for consistency with other models.
        self.num_features = embed_dim
        self.embed_dim = embed_dim
        if not norm_layer:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        patch_count = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One extra position for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_count + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: rates grow linearly from 0 to drop_path_rate.
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        first_ckpt_index = depth - ckpt_layer
        layers = []
        for index, rate in enumerate(drop_path_rates):
            layers.append(
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=rate,
                    norm_layer=norm_layer,
                    use_grad_checkpointing=(use_grad_checkpointing and index >= first_ckpt_index),
                ))
        self.blocks = nn.ModuleList(layers)
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear weights (truncated normal, zero bias) and LayerNorm (bias 0, weight 1)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        """Encode images ``x`` into a normalized token sequence.

        ``register_blk`` selects the index of the block whose attention map and gradients are
        recorded; the default of -1 matches no block.
        """
        batch = x.shape[0]
        tokens = self.patch_embed(x)

        # Broadcast the single learned class token across the batch and prepend it.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed[:, :tokens.size(1), :])

        for index, block in enumerate(self.blocks):
            tokens = block(tokens, register_blk == index)

        return self.norm(tokens)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        """Load weights from ``checkpoint_path`` into this model via ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '') -> None:
    """Load weights from .npz checkpoints for official Google Brain Flax implementation

    Copies arrays from the checkpoint into ``model`` in place (the function runs under
    ``torch.no_grad()`` so the ``copy_`` calls are not tracked by autograd).

    Args:
        model: the VisionTransformer whose parameters are overwritten.
        checkpoint_path: path passed to ``np.load``.
        prefix: key prefix inside the checkpoint; when empty and the checkpoint contains
            ``'opt/target/embedding/kernel'``, ``'opt/target/'`` is used.
    """
    import numpy as np

    def _n2p(w, t=True):
        # Convert a numpy array to a torch tensor.
        # A 4-D array whose first three dims are all 1 is flattened to 1-D first.
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            # Reorder axes depending on rank — presumably Flax kernel layout to the PyTorch
            # layout (e.g. HWIO -> OIHW for 4-D); TODO confirm against the checkpoint format.
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        # A backbone without a `stem` attribute is treated as being the stem itself.
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    # Checkpoint keys for blocks/units are 1-based.
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # Non-hybrid: adapt the embedding kernel to the model's number of input channels.
        embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        # The copy_ below relies on resize_pos_embed returning the resized tensor.
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w,
            model.pos_embed,
            getattr(model, 'num_tokens', 1),
            model.patch_embed.grid_size,
        )
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # Head / pre-logits loading is disabled; nothing below touches model.head or model.pre_logits.
    #     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
    #         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
    #         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    #     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        # The checkpoint stores query/key/value separately; they are concatenated in that
        # order to fill the model's single fused qkv projection.
        block.attn.qkv.weight.copy_(
            torch.cat([_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(
            torch.cat([_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            # Dense_0 / Dense_1 in the checkpoint map to mlp.fc1 / mlp.fc2.
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    """Resize a checkpoint position embedding so its token count matches a target embedding.

    Rescales the grid of position embeddings when loading from state_dict. Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: position embedding to resize, indexed as (1, tokens, dim). The grid part
            (everything after the first ``num_tokens`` entries) is treated as a square grid.
        posemb_new: target embedding; only its shape is read.
        num_tokens: number of leading non-grid tokens (e.g. the class token) that are
            carried over unchanged. Use 0 when there are none.
        gs_new: target grid size as (height, width). When empty, a square grid is derived
            from the target token count.

    Returns:
        A tensor of shape (1, num_tokens + gs_new[0] * gs_new[1], dim) holding the
        untouched leading tokens followed by the bicubically interpolated grid.
    """
    print(f'Resized position embedding: {posemb.shape} to {posemb_new.shape}')
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    print(f'Position embedding grid-size from {[gs_old, gs_old]} to {gs_new}')
    # (tokens, dim) -> (1, dim, gs_old, gs_old) so the grid can be interpolated like an image.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    # The resized embedding must be returned: the caller copies it into the model.
    return posemb
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    """Resize a checkpoint position embedding to the patch grid of ``visual_encoder``.

    The number of extra (non-patch) tokens is taken from the encoder as
    ``pos_embed.shape[-2] - patch_embed.num_patches``; those leading tokens are kept unchanged
    and only the remaining tokens, treated as a square grid, are bicubically interpolated.

    Returns the checkpoint embedding itself when the grid side length already matches,
    otherwise a new tensor with the extra tokens followed by the resized grid tokens.
    """
    dim = pos_embed_checkpoint.shape[-1]
    patch_count = visual_encoder.patch_embed.num_patches
    extra_count = visual_encoder.pos_embed.shape[-2] - patch_count
    # Side length (height == width) of the checkpoint grid and of the encoder grid.
    old_side = int((pos_embed_checkpoint.shape[-2] - extra_count)**0.5)
    new_side = int(patch_count**0.5)

    if old_side == new_side:
        return pos_embed_checkpoint

    # Class/dist tokens pass through untouched; only grid tokens are interpolated.
    kept_tokens = pos_embed_checkpoint[:, :extra_count]
    grid = pos_embed_checkpoint[:, extra_count:]
    grid = grid.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid,
        size=(new_side, new_side),
        mode='bicubic',
        align_corners=False,
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    resized = torch.cat((kept_tokens, grid), dim=1)
    print('reshape position embedding from %d to %d' % (old_side**2, new_side**2))

    return resized
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class VisionTransformerEncoder(VisionTransformer, BaseEncoder):
    """Vision Transformer that can be constructed from a config mapping."""

    @classmethod
    def from_config(cls, cfg, from_pretrained=False):
        """Build a ``base`` or ``large`` encoder from ``cfg``.

        Config keys read (with their defaults):
            vit_type ('base'): ``'base'`` or ``'large'``.
            image_size (384): passed as ``img_size``.
            vit_ckpt_layer (0): passed as ``ckpt_layer``.
            vit_drop_path_rate (0): drop-path rate; for ``'large'`` a falsy
                value falls back to 0.1.
            vit_layer_norm_epsilon (-1): ``-1`` means "use the model's default
                norm layer"; any other value builds ``nn.LayerNorm`` with that eps.
            vit_grad_ckpt (False): passed as ``use_grad_checkpointing``.

        Args:
            cfg: mapping-like object supporting ``get(key, default)``.
            from_pretrained: when true, load pretrained weights after
                construction (DeiT checkpoint URL for ``base``, timm's
                ``vit_large_patch16_224_in21k`` config for ``large``).

        Returns:
            The constructed encoder, with ``vision_width`` set to its embed dim.

        Raises:
            AssertionError: if ``vit_type`` is neither ``'base'`` nor ``'large'``.
        """
        vit_type = cfg.get('vit_type', 'base')
        image_size = cfg.get('image_size', 384)
        ckpt_layer = cfg.get('vit_ckpt_layer', 0)
        drop_path_rate = cfg.get('vit_drop_path_rate', 0)
        norm_layer_eps = cfg.get('vit_layer_norm_epsilon', -1)
        use_grad_checkpointing = cfg.get('vit_grad_ckpt', False)

        # -1 is the "not configured" sentinel: leave norm_layer unset so the
        # model picks its own default.
        if norm_layer_eps == -1:
            norm_layer = None
        else:
            norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)

        assert vit_type in ['base', 'large'], 'vit parameter must be base or large'
        if vit_type == 'base':
            vision_width = 768
            visual_encoder = cls(
                img_size=image_size,
                patch_size=16,
                embed_dim=vision_width,
                depth=12,
                num_heads=12,
                use_grad_checkpointing=use_grad_checkpointing,
                ckpt_layer=ckpt_layer,
                drop_path_rate=drop_path_rate,
                norm_layer=norm_layer,
            )

            if from_pretrained:
                checkpoint = torch.hub.load_state_dict_from_url(
                    url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
                    map_location='cpu',
                    check_hash=True,
                )
                state_dict = checkpoint['model']
                # The checkpoint grid may differ from this encoder's grid;
                # resize the position embedding before loading.
                state_dict['pos_embed'] = interpolate_pos_embed(state_dict['pos_embed'], visual_encoder)
                # strict=False: keys that do not match are tolerated.
                visual_encoder.load_state_dict(state_dict, strict=False)

        elif vit_type == 'large':
            vision_width = 1024
            visual_encoder = cls(
                img_size=image_size,
                patch_size=16,
                embed_dim=vision_width,
                depth=24,
                num_heads=16,
                use_grad_checkpointing=use_grad_checkpointing,
                ckpt_layer=ckpt_layer,
                # Honor a configured rate and fall back to 0.1 otherwise.
                # The previous `0.1 or drop_path_rate` always yielded 0.1 and
                # silently discarded the configured value.
                drop_path_rate=drop_path_rate or 0.1,
                norm_layer=norm_layer,
            )
            if from_pretrained:
                from timm.models.helpers import load_custom_pretrained
                from timm.models.vision_transformer import default_cfgs

                load_custom_pretrained(visual_encoder, default_cfgs['vit_large_patch16_224_in21k'])

        visual_encoder.vision_width = vision_width
        return visual_encoder

    def forward_features(self, x, register_blk=-1):
        """Delegate to the parent ``forward`` with ``x`` and ``register_blk``."""
        return super().forward(x, register_blk)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from ..common.registry import registry
|
|
9
|
+
from .base_processor import BaseProcessor
|
|
10
|
+
from .blip_processors import (Blip2ImageTrainProcessor, BlipCaptionProcessor, BlipImageEvalProcessor,
|
|
11
|
+
BlipImageTrainProcessor)
|
|
12
|
+
|
|
13
|
+
# Names exported by ``from <package> import *``.
__all__ = [
    'BaseProcessor',
    # BLIP processors
    'BlipImageTrainProcessor',
    'Blip2ImageTrainProcessor',
    'BlipImageEvalProcessor',
    'BlipCaptionProcessor',
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def load_processor(name, cfg=None):
    """
    Build the processor registered under ``name``.

    The processor class is looked up in the registry and constructed through
    its ``from_config`` classmethod with ``cfg``.

    Example

    >>> processor = load_processor("alpro_video_train", cfg=None)
    """
    processor_cls = registry.get_processor_class(name)
    return processor_cls.from_config(cfg)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2022, salesforce.com, inc.
|
|
3
|
+
All rights reserved.
|
|
4
|
+
SPDX-License-Identifier: BSD-3-Clause
|
|
5
|
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from omegaconf import OmegaConf
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseProcessor:
    """Pass-through processor: calling it applies ``self.transform`` to the item."""

    def __init__(self):
        # Identity transform by default.
        self.transform = lambda x: x

    def __call__(self, item):
        """Return ``item`` after applying ``self.transform``."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Construct an instance; ``cfg`` is accepted but not used here."""
        return cls()

    def build(self, **kwargs):
        """Wrap ``kwargs`` in an OmegaConf config and build via ``from_config``."""
        return self.from_config(OmegaConf.create(kwargs))
|