openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -1,1398 +1,1563 @@
1
- # Scene Text Recognition Model Hub
2
- # Copyright 2022 Darwin Bautista
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # https://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import math
17
- from itertools import permutations
18
- from collections import OrderedDict
19
- import hashlib
20
- import os
21
- import gzip
22
- import html
23
- import urllib
24
- import warnings
25
- import numpy as np
26
- import torch
27
- import torch.nn as nn
28
- import torch.nn.functional as F
29
- from torch import Tensor
30
- from torch.nn.modules import transformer
31
- from typing import Any, Optional, Tuple, List, Union
32
- from pkg_resources import packaging
33
- from PIL import Image
34
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
35
- from tqdm import tqdm
36
- from functools import lru_cache
37
-
38
- import ftfy
39
- import regex as re
40
-
41
- try:
42
- from torchvision.transforms import InterpolationMode
43
- BICUBIC = InterpolationMode.BICUBIC
44
- except ImportError:
45
- BICUBIC = Image.BICUBIC
46
-
47
-
48
@lru_cache()
def default_bpe():
    """Return the absolute path of the gzipped BPE merge table next to this module.

    The path is memoised by ``lru_cache``; the file itself is not opened here.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    vocab_name = "bpe_simple_vocab_16e6.txt.gz"
    return os.path.join(module_dir, vocab_name)
51
-
52
-
53
@lru_cache()
def bytes_to_unicode():
    """Build a reversible mapping from every byte value (0-255) to one unicode character.

    Byte-level BPE works on unicode strings, so every raw UTF-8 byte needs a
    visible stand-in character. This avoids needing a very large unicode
    vocabulary just to cover arbitrary text without UNK tokens.

    - Bytes in the three printable ranges ``!``..``~``, ``¡``..``¬`` and
      ``®``..``ÿ`` map to themselves.
    - Every other byte (whitespace and control characters the BPE code cannot
      handle) is assigned a consecutive code point starting at 256.

    Returns:
        dict[int, str]: byte value -> single-character string. Iteration order
        lists the printable bytes first, then the remapped bytes in ascending
        order.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    already_printable = set(printable)
    remapped = [byte for byte in range(256) if byte not in already_printable]

    byte_values = printable + remapped
    code_points = printable + [256 + offset for offset in range(len(remapped))]
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
74
-
75
-
76
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sequence of symbols (variable-length strings), typically a
    tuple. An empty ``word`` raises ``IndexError``.
    """
    anchor = word[0]  # indexing first keeps the IndexError for empty input
    followers = word[1:]
    return set(zip((anchor, *followers), followers))
86
-
87
-
88
def basic_clean(text):
    """Repair mojibake with ftfy, HTML-unescape twice and strip outer whitespace.

    The unescape runs twice so that double-escaped entities (``&amp;lt;``)
    are fully decoded.
    """
    cleaned = ftfy.fix_text(text)
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()
92
-
93
-
94
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
98
-
99
-
100
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP.

    ``encode`` cleans and lower-cases the text and splits it with ``self.pat``.
    It then maps every UTF-8 byte to a printable stand-in via
    ``bytes_to_unicode`` and applies the BPE merges loaded from ``bpe_path``.
    ``decode`` reverses the mapping.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the BPE merge table and build the token <-> id vocabularies.

        Args:
            bpe_path: path to a gzip-compressed text file with one merge (two
                space-separated symbols) per line. The first line is skipped.
                The default is ``default_bpe()``, evaluated once when the class
                is defined.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the gzip handle deterministically instead of leaving it to the
        # garbage collector.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Drop the first line (presumably a header) and keep at most
        # 49152 - 256 - 2 merges.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary layout:
        #   - the 256 byte symbols,
        #   - the same 256 symbols with the "</w>" end-of-word suffix,
        #   - one entry per merge,
        #   - the two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Rank of each merge; bpe() always applies the lowest-ranked known pair first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens map to themselves. bpe() memoises every other
        # multi-symbol token here (unbounded).
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the BPE merges to one byte-encoded token.

        Returns the resulting symbols joined by single spaces. The last symbol
        carries the ``</w>`` end-of-word suffix. Tokens of more than one symbol
        are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the adjacent pair with the lowest rank. Stop once no pair
            # is a known merge.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Return the list of BPE token ids for ``text``.

        The text is cleaned and lower-cased before splitting.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text.

        Each ``</w>`` marker becomes a space. Invalid UTF-8 byte sequences are
        replaced with U+FFFD (``errors="replace"``).
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
171
-
172
-
173
# Warn (but do not fail) when running on a PyTorch older than 1.7.1.
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


# Public API of this vendored CLIP code.
# NOTE(review): `load` and `tokenize` are not defined in the visible part of
# this file; confirm they exist.
__all__ = ["available_models", "load", "tokenize"]

# Module-wide tokenizer. Constructing it reads the BPE file returned by
# ``default_bpe()``, so this happens at import time.
_tokenizer = SimpleTokenizer()

# Checkpoint name -> download URL. ``_download`` treats the second-to-last URL
# path segment as the expected SHA-256 digest of the file.
_MODELS = {
    # ResNet visual backbones
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    # Vision Transformer visual backbones
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
191
-
192
-
193
def convert_weights(model: nn.Module):
    """Cast the applicable parameters of ``model`` to fp16, in place.

    Every submodule is visited via ``model.apply``, and the following are halved:

    - weight and bias of ``Conv1d``, ``Conv2d`` and ``Linear`` layers;
    - the input-projection weights and bias, and ``bias_k``/``bias_v``, of
      ``MultiheadAttention``;
    - any ``text_projection`` or ``proj`` attribute that is not ``None``.

    Other parameters, such as LayerNorm and Embedding weights, keep their dtype.
    """
    mha_tensor_names = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )

    def _halve(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            for tensor_name in mha_tensor_names:
                tensor = getattr(module, tensor_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for proj_name in ("text_projection", "proj"):
            projection = getattr(module, proj_name, None)
            if projection is not None:
                projection.data = projection.data.half()

    model.apply(_halve)
215
-
216
-
217
def build_model(state_dict: dict):
    """Build a CLIP model whose hyper-parameters are inferred from ``state_dict``.

    The visual tower is a ViT when ``"visual.proj"`` is present, otherwise a
    ModifiedResNet. The model is converted to fp16, loaded with
    ``state_dict`` and returned in eval mode.

    Side effect: the metadata keys ``input_resolution``, ``context_length`` and
    ``vocab_size`` are removed from ``state_dict`` in place.
    """
    if "visual.proj" in state_dict:
        conv_weight = state_dict["visual.conv1.weight"]
        vision_width = conv_weight.shape[0]
        vision_layers = sum(
            1 for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv_weight.shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Per-stage block counts, from the distinct block indices of layer1..layer4.
        vision_layers = tuple(
            len({k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_positions
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len({k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")})

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
    )

    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
255
-
256
-
257
def _sha256_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download ``url`` into directory ``root`` and verify its SHA-256 digest.

    The expected digest is the second-to-last path segment of ``url``. If a
    file with the same name already exists and its digest matches, it is
    reused without any network access. If the digest does not match, a
    warning is issued and the file is downloaded again.

    Args:
        url: download URL of the form ``.../<sha256>/<filename>``.
        root: cache directory; created if missing.

    Returns:
        str: path of the verified file inside ``root``.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the freshly downloaded file fails the checksum.
    """
    # `import urllib` alone does not load the `urllib.request` submodule.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length may be absent; tqdm then shows a progress bar with no total.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
287
-
288
-
289
def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method (PIL-style)."""
    rgb_image = image.convert("RGB")
    return rgb_image
291
-
292
-
293
def _transform(n_px):
    """Return CLIP's image preprocessing pipeline for ``n_px`` x ``n_px`` inputs.

    The steps, in order:

    1. bicubic resize (``Resize(n_px)``);
    2. center crop to ``n_px``;
    3. RGB conversion;
    4. tensor conversion;
    5. per-channel normalisation with CLIP's mean/std constants.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
301
-
302
-
303
def available_models() -> List[str]:
    """Return the CLIP checkpoint names known to ``_MODELS``, in definition order."""
    return [name for name in _MODELS]
306
-
307
-
308
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with anti-aliased striding.

    - Every convolution has stride 1.
    - When ``stride > 1``, an ``AvgPool2d(stride)`` is applied after the 3x3
      conv instead.
    - The shortcut becomes ``AvgPool2d(stride)`` + 1x1 conv + BN whenever
      ``stride > 1`` or the channel count changes.
    - Output channels are ``planes * expansion``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Creation order is kept stable: it fixes the random init sequence and
        # the state_dict layout used by pretrained checkpoints.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.Identity() if stride <= 1 else nn.AvgPool2d(stride)

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu3 = nn.ReLU(inplace=True)

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Pool first, then a stride-1 1x1 projection (keys "-1", "0", "1").
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes)),
            ]))
        else:
            self.downsample = None
        self.stride = stride

    def forward(self, x: torch.Tensor):
        residual = self.relu1(self.bn1(self.conv1(x)))
        residual = self.relu2(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(self.avgpool(residual)))

        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu3(residual)
354
-
355
-
356
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map to one vector with single-query multi-head attention.

    The mean of all spatial positions is prepended as an extra token, learned
    positional embeddings are added, and that mean token attends over all
    tokens. The output is ``(N, output_dim or embed_dim)``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One slot per spatial cell plus one for the prepended mean token.
        num_positions = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(torch.randn(num_positions, embed_dim) / embed_dim ** 0.5)
        # Creation order (k, q, v, c) is kept so random init is reproducible.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
390
-
391
-
392
class ModifiedResNet(nn.Module):
    """CLIP's ResNet variant, differing from torchvision's in three ways.

    - Stem: three 3x3 convolutions followed by a 2x average pool, instead of
      one conv and a max pool.
    - Strided blocks: anti-aliased, with an average pool prepended to strided
      convolutions (see ``Bottleneck``).
    - Head: the final global pooling is ``AttentionPool2d`` rather than an
      average pool.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        half_width = width // 2
        # Three-conv stem (module creation order is kept for init/state_dict).
        self.conv1 = nn.Conv2d(3, half_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(half_width, half_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # Running input-channel count, advanced by _make_layer while building.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # channels produced by layer4
        self.attnpool = AttentionPool2d(input_resolution // 32, feature_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a possibly strided first block, then ``blocks - 1`` plain blocks."""
        head_block = Bottleneck(self._inplanes, planes, stride)
        self._inplanes = planes * Bottleneck.expansion
        tail_blocks = [Bottleneck(self._inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(head_block, *tail_blocks)

    def forward(self, x):
        x = x.type(self.conv1.weight.dtype)
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.avgpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
453
-
454
-
455
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises in float32 and casts back to the input dtype.

    This lets fp16 activations pass through without fp16 normalisation.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalised = super().forward(x.to(torch.float32))
        return normalised.to(input_dtype)
462
-
463
-
464
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
467
-
468
-
469
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    The attention is ``nn.MultiheadAttention`` with its default
    ``batch_first=False``, so inputs are sequence-first. The MLP widens to
    ``4 * d_model`` with QuickGELU.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_dim = d_model * 4
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_dim)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden_dim, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x``, returning only the attended values."""
        # Cast the stored mask to x's dtype/device and keep the cast copy on self.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
491
-
492
-
493
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks.

    Every block is constructed with the same ``attn_mask``.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *(ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers))
        )

    def forward(self, x: torch.Tensor):
        out = self.resblocks(x)
        return out
502
-
503
-
504
class VisionTransformer(nn.Module):
    """ViT image encoder.

    The image is split into patches by a strided conv and a class token is
    prepended. The tokens go through a Transformer, and every token (not only
    the class token) is projected to ``output_dim``. Output is
    ``[N, grid ** 2 + 1, output_dim]``.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        num_patches = (input_resolution // patch_size) ** 2
        # Parameter creation order is kept for reproducible random init.
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        patches = self.conv1(x)  # [*, width, grid, grid]
        batch, channels = patches.shape[0], patches.shape[1]
        patches = patches.reshape(batch, channels, -1).permute(0, 2, 1)  # [*, grid ** 2, width]

        cls_token = self.class_embedding.to(patches.dtype) + torch.zeros(
            batch, 1, patches.shape[-1], dtype=patches.dtype, device=patches.device
        )
        tokens = torch.cat([cls_token, patches], dim=1)  # [*, grid ** 2 + 1, width]
        tokens = self.ln_pre(tokens + self.positional_embedding.to(tokens.dtype))

        # The Transformer is sequence-first: NLD -> LND -> NLD.
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        tokens = self.ln_post(tokens)
        if self.proj is not None:
            tokens = tokens @ self.proj
        return tokens
538
-
539
-
540
class CLIP(nn.Module):
    """CLIP dual encoder, vendored here with a modified ``encode_text``.

    Visual tower:

    - ``ModifiedResNet`` when ``vision_layers`` is a tuple/list of per-stage
      block counts;
    - otherwise a ``VisionTransformer`` with ``vision_layers`` layers.

    Text tower: a causally masked ``Transformer`` over token and positional
    embeddings.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        # Must be set before build_attention_mask(), which reads it.
        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # The ResNet feature dim is vision_width * 32; one head per 64 channels.
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_width * 32 // 64,
                input_resolution=image_resolution,
                width=vision_width,
            )
        else:
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_width // 64,
                output_dim=embed_dim,
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        # Allocated uninitialised; filled in initialize_parameters().
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable logit temperature, stored in log space; starts at log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Re-initialise weights with normal draws and zeroed ResNet ``bn3`` gammas.

        The order of the random draws is fixed, so seeded runs are reproducible.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        visual = self.visual
        if isinstance(visual, ModifiedResNet):
            pool = visual.attnpool
            if pool is not None:
                pool_std = pool.c_proj.in_features ** -0.5
                for projection in (pool.q_proj, pool.k_proj, pool.v_proj, pool.c_proj):
                    nn.init.normal_(projection.weight, std=pool_std)

            # Zero the last BN gamma of every bottleneck so each residual
            # branch initially contributes nothing.
            for stage in (visual.layer1, visual.layer2, visual.layer3, visual.layer4):
                for param_name, param in stage.named_parameters():
                    if param_name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        width = self.transformer.width
        depth = self.transformer.layers
        proj_std = (width ** -0.5) * ((2 * depth) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=width ** -0.5)

    def build_attention_mask(self):
        """Return an additive causal mask of shape (context_length, context_length).

        Entries strictly above the diagonal are ``-inf``, so future positions
        are masked. The diagonal and everything below it are 0.
        """
        size = self.context_length
        return torch.empty(size, size).fill_(float("-inf")).triu_(1)

    @property
    def dtype(self):
        """Dtype of the visual stem weights, used as the model's compute dtype."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Run the visual tower on ``image`` cast to ``self.dtype``."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Encode token ids ``text`` of shape [batch_size, n_ctx].

        Index 0 along dim 1 of the result holds the projected feature taken at
        each sequence's argmax token (the EOT token has the highest id).
        Indices 1.. hold the per-token final-layer features.

        NOTE(review): the concatenation requires embed_dim == transformer_width.
        """
        tokens = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        tokens = tokens + self.positional_embedding.type(self.dtype)

        hidden = self.transformer(tokens.permute(1, 0, 2))  # NLD -> LND
        hidden = self.ln_final(hidden.permute(1, 0, 2)).type(self.dtype)  # LND -> NLD

        eot_positions = text.argmax(dim=-1)
        pooled = hidden[torch.arange(hidden.shape[0]), eot_positions] @ self.text_projection
        return torch.cat([pooled.unsqueeze(1), hidden], dim=1)

    def forward(self, image, text):
        """Return ``(logits_per_image, logits_per_text)``: scaled cosine similarities.

        NOTE(review): ``encode_text`` above returns a 3-D tensor, and ``.t()``
        below only accepts tensors with at most 2 dims. This method looks
        unused by the recognizer; confirm before relying on it.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # L2-normalise along dim 1.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        scale = self.logit_scale.exp()
        logits_per_image = scale * image_features @ text_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_image.t()
670
-
671
-
672
class FMU(nn.Module):
    """Cross-attention followed by a feed-forward block.

    ``query`` (batch-first) attends to ``memory`` and the result is added
    residually. A LayerNorm -> Linear -> activation -> Linear feed-forward is
    then applied to the updated query and also added residually. The
    cross-attention input is not normalised inside this module.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
                 layer_norm_eps=1e-5):
        super().__init__()
        # Submodule creation order is kept for reproducible init / state_dict layout.
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Private torch helper that resolves the activation name to a callable.
        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        # Presumably for unpickling instances saved without an `activation`
        # attribute: fall back to GELU.
        if 'activation' not in state:
            state['activation'] = F.gelu
        super().__setstate__(state)

    def forward(self, query: Tensor, memory: Tensor):
        """Update ``query`` with cross-attention over ``memory`` and a feed-forward block.

        Both tensors are batch-first (``batch_first=True``). The returned tensor
        has the same shape as ``query``.
        """
        attended, _ = self.cross_attn(query, memory, memory)
        query = query + self.dropout1(attended)

        hidden = self.activation(self.linear1(self.norm(query)))
        ffn_out = self.linear2(self.dropout2(hidden))
        return query + self.dropout3(ffn_out)
710
-
711
-
712
class DecoderLayer(nn.Module):
    """Pre-LN Transformer decoder layer with two-stream (query/content) attention.

    Both streams share this layer's weights. Each stream runs three
    pre-LayerNorm residual sub-layers:
    1. masked self-attention whose keys/values are the normalized content;
    2. cross-attention to ``memory``;
    3. a feed-forward network.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='gelu',
        layer_norm_eps=1e-5,
    ):
        super().__init__()

        def make_attention():
            return nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        def make_norm():
            return nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.self_attn = make_attention()
        self.cross_attn = make_attention()
        # Feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        self.norm_q = make_norm()
        self.norm_c = make_norm()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        # Pickles created without an `activation` entry fall back to GELU.
        state.setdefault('activation', F.gelu)
        super().__setstate__(state)

    def forward_stream(
        self,
        tgt: Tensor,
        tgt_norm: Tensor,
        tgt_kv: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor],
        tgt_key_padding_mask: Optional[Tensor],
    ):
        """Run one stream (query or content) through the layer.

        ``tgt_norm`` is the LayerNorm'd ``tgt``. It is passed in separately
        so the caller can reuse it. ``tgt_kv`` and ``memory`` are expected
        to be LayerNorm'd already.

        Returns ``(tgt, self_attn_weights, cross_attn_weights)``.
        """
        self_out, sa_weights = self.self_attn(
            tgt_norm, tgt_kv, tgt_kv,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(self_out)

        cross_out, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        # Kept on the module so callers can read the latest cross-attention map.
        self.attn_map = ca_weights
        tgt = tgt + self.dropout2(cross_out)

        ffn_hidden = self.dropout(self.activation(self.linear1(self.norm2(tgt))))
        tgt = tgt + self.dropout3(self.linear2(ffn_hidden))
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
        update_content: bool = True,
    ):
        """Advance the query stream and, if ``update_content``, the content stream.

        Both streams attend over the same normalized content. When
        ``update_content`` is False, ``content`` is returned unchanged.
        """
        content_norm = self.norm_c(content)
        query_out = self.forward_stream(
            query, self.norm_q(query), content_norm, memory,
            query_mask, content_key_padding_mask)[0]
        if not update_content:
            return query_out, content
        content_out = self.forward_stream(
            content, content_norm, content_norm, memory,
            content_mask, content_key_padding_mask)[0]
        return query_out, content_out
807
-
808
-
809
class Decoder(nn.Module):
    """Stack of cloned decoder layers plus a final norm on the query stream."""

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = transformer._get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
    ):
        """Run every layer and return the normalized query stream."""
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            # The content stream is not used after the last layer,
            # so it is not updated there.
            query, content = layer(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=index != last_index,
            )
        return self.norm(query)
840
-
841
-
842
class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(embed_dim)``."""

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor):
        scale = math.sqrt(self.embed_dim)
        return self.embedding(tokens) * scale
851
-
852
-
853
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: Optional[str] = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model. The default is chosen once, when
        this module is imported ("cuda" if available, else "cpu").

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If ``name`` is neither a known model name nor an existing file.
    """
    # Resolve `name` to a local checkpoint path (downloading known models).
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict (the file is not a TorchScript archive)
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        # Non-JIT path: rebuild an nn.Module from the weights. If a JIT
        # archive was loaded, its state_dict is used for the rebuild.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    # Trace a tiny graph on the target device and grab its Device constant;
    # it is copied over every constant in the loaded graphs that starts with "cuda".
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Collect the module's graph(s); accessing `.graph` can raise RuntimeError.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        # Trace `.float()` to obtain a graph constant holding the float32 dtype.
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            # Same graph collection as patch_device.
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # NOTE(review): 5 is presumably the ScalarType code for
                        # float16 (half); such casts are redirected to float32.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
954
-
955
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one string or a list of strings into a padded id tensor.

    Each text becomes ``[<|startoftext|>, *bpe_ids, <|endoftext|>]``,
    written left-aligned into a zero-padded row.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings.

    context_length : int
        Row length of the output; all CLIP models use 77.

    truncate : bool
        If True, over-long encodings are cut to ``context_length`` and the
        last kept id is replaced by ``<|endoftext|>``. If False, they raise.

    Returns
    -------
    Tensor of shape ``[len(texts), context_length]``. The dtype is
    ``torch.long`` on PyTorch < 1.8.0 (older ``index_select`` requires long
    indices) and ``torch.int`` otherwise.

    Raises
    ------
    RuntimeError
        If an encoding exceeds ``context_length`` and ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot = _tokenizer.encoder["<|startoftext|>"]
    eot = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[sot, *_tokenizer.encode(text), eot] for text in texts]

    legacy_torch = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    out_dtype = torch.long if legacy_torch else torch.int
    result = torch.zeros(len(encoded), context_length, dtype=out_dtype)

    for row, (text, ids) in enumerate(zip(texts, encoded)):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {text} is too long for context length {context_length}")
            ids = ids[:context_length]
            ids[-1] = eot
        result[row, :len(ids)] = torch.tensor(ids)

    return result
996
-
997
class DptrParseq(nn.Module):
    """PARSeq-style permuted autoregressive decoder with an optional CLIP front end.

    Two modes, selected by ``is_pretrain``:

    * ``is_pretrain=True``: the decoder memory is the output of a frozen
      CLIP ``ViT-B/16`` text encoder run on ``clip_ids``. During training
      it is perturbed with randomly sampled ``background_features``.
    * ``is_pretrain=False``: the incoming memory is condensed by ``FMU``
      cross-attention driven by a learnable ``token_query``.

    Special token ids: ``eos_id = 0``, ``bos_id = out_channels - 2``,
    ``pad_id = out_channels - 1``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_label_length=25,
                 embed_dim=512,
                 dec_num_heads=8,
                 dec_mlp_ratio=4,
                 dec_depth=6,
                 perm_num=6,
                 perm_forward=True,
                 perm_mirrored=True,
                 decode_ar=True,
                 refine_iters=1,
                 dropout=0.1,
                 is_pretrain=True,
                 ORP_path=None,
                 **kwargs: Any) -> None:
        super().__init__()
        self.pad_id = out_channels - 1
        self.eos_id = 0
        self.bos_id = out_channels - 2
        self.max_label_length = max_label_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters
        self.is_pretrain = is_pretrain
        if not is_pretrain:
            # 26 learnable queries that FMU uses to pool the incoming memory.
            self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim))
            self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)

        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer,
                               num_layers=dec_depth,
                               norm=nn.LayerNorm(embed_dim))

        # Perm/attn mask stuff
        self.rng = np.random.default_rng()
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored

        # We don't predict <bos> nor <pad>
        self.head = nn.Linear(embed_dim, out_channels - 2)
        self.text_embed = TokenEmbedding(out_channels, embed_dim)

        # +1 for <eos>
        self.pos_queries = nn.Parameter(
            torch.Tensor(1, max_label_length + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)
        # Encoder has its own init.
        self.apply(self._init_weights)
        nn.init.trunc_normal_(self.pos_queries, std=0.02)

        if not is_pretrain:
            # torch.Tensor(...) allocates uninitialized memory, and
            # _init_weights does not visit bare Parameters, so token_query
            # must be initialized explicitly (same scheme as pos_queries).
            nn.init.trunc_normal_(self.token_query, std=0.02)
        else:
            self.clip_encoder, preprocess = load("ViT-B/16")
            for p in self.clip_encoder.parameters():
                p.requires_grad = False
            if ORP_path is None:
                # NOTE(review): both paths below look like placeholders that
                # must be edited before this branch can run.
                background_image_folder_path = 'background_mages_folder/path'
                self.background_features = self.get_noise(background_image_folder_path, preprocess)
                torch.save(self.background_features, 'save/noise/to/ORP_path')
            else:
                self.background_features = torch.load(ORP_path, map_location='cpu')

    def _init_weights(self, module: nn.Module):
        """Initialize the weights using the typical initialization schemes used
        in SOTA models."""

        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        return param_names

    def _clip_device(self):
        """Return the device that holds the frozen CLIP encoder's parameters."""
        return next(self.clip_encoder.parameters()).device

    def get_noise(self, background_image_path, preprocess):
        """Encode every .png/.jpg/.jpeg file in a folder with the CLIP image encoder.

        Args:
            background_image_path: folder to scan (non-recursive).
            preprocess: transform that turns a PIL image into a model input tensor.

        Returns:
            NumPy array of the encoded features concatenated along dim 0.

        Raises:
            RuntimeError: if the folder contains no matching image files.
        """
        device = self._clip_device()
        image_paths = [
            os.path.join(background_image_path, filename)
            for filename in os.listdir(background_image_path)
            if filename.endswith(('.png', '.jpg', '.jpeg'))
        ]
        if not image_paths:
            raise RuntimeError(f'No background images found in {background_image_path}')
        features = []
        for image_path in image_paths:
            # Preprocess inside `with` so PIL's lazy loading happens before close.
            with Image.open(image_path) as image:
                input_tensor = preprocess(image).unsqueeze(0).to(device)
            with torch.no_grad():
                feature = self.clip_encoder.encode_image(input_tensor)
            features.append(feature)
        return torch.cat(features).cpu().numpy()

    def clip_encode(self, labels):
        """Encode each label as the prompt "a photo of a '<label>'" with CLIP."""
        text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(self._clip_device())

        return self.clip_encoder.encode_text(text_inputs)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_padding_mask: Optional[Tensor] = None,
        tgt_query: Optional[Tensor] = None,
        tgt_query_mask: Optional[Tensor] = None,
        pos_query: torch.Tensor = None,
    ):
        """Embed the context ``tgt`` and run the two-stream decoder against ``memory``."""
        N, L = tgt.shape
        # <bos> stands for the null context. We only supply position information for characters after <bos>.
        null_ctx = self.text_embed(tgt[:, :1])

        if tgt_query is None:
            tgt_query = pos_query[:, :L]
        tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))

        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
                            tgt_mask, tgt_padding_mask)

    def forward(self, memory, data=None, pos_query=None):
        """Dispatch to ``training_step`` or ``forward_test``.

        In pre-training mode the ``memory`` argument carries the CLIP token
        ids, which are re-encoded inside the step.
        """
        if self.training:
            if self.is_pretrain:
                return self.training_step(None, pos_query, data[0], memory)
            return self.training_step(memory, pos_query, data[0], None)
        else:
            if self.is_pretrain:
                return self.forward_test(None, memory, pos_query)
            return self.forward_test(memory, None, pos_query)

    def forward_test(self,
                     memory: Tensor, clip_ids,
                     pos_query: Tensor = None,
                     max_length: Optional[int] = None) -> Tensor:
        """Decode greedily (AR) or in parallel (NAR), with optional refinement.

        Returns per-position softmax probabilities over the output classes.
        """
        testing = max_length is None
        max_length = (self.max_label_length if max_length is None else min(
            max_length, self.max_label_length))

        if self.is_pretrain:
            memory = self.clip_encoder.encode_text(clip_ids)
        else:
            bs = memory.shape[0]
            token_query = self.token_query.expand(bs, -1, -1)
            memory = self.fmu(token_query, memory)
        # Use the device object: Tensor.get_device() returns -1 for CPU
        # tensors, which is not a valid `device=` argument.
        _device = memory.device
        bs = memory.shape[0]
        # +1 for <eos> at end of sequence.
        num_steps = max_length + 1

        # Query positions up to `num_steps`
        if pos_query is None:
            pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
        else:
            pos_queries = pos_query

        # Special case for the forward permutation. Faster than using `generate_attn_masks()`
        # NOTE: tgt_mask and query_mask alias the same tensor, so the in-place
        # cloze edit in the refinement step below also changes tgt_mask.
        tgt_mask = query_mask = torch.triu(
            torch.full((num_steps, num_steps), float('-inf'), device=_device),
            1)
        self.attn_maps = []
        if self.decode_ar:
            tgt_in = torch.full((bs, num_steps),
                                self.pad_id,
                                dtype=torch.long,
                                device=_device)
            tgt_in[:, 0] = self.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1  # next token index
                # Efficient decoding:
                # Input the context up to the ith token. We use only one query (at position = i) at a time.
                # This works because of the lookahead masking effect of the canonical (forward) AR context.
                # Past tokens have no access to future tokens, hence are fixed once computed.
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                    pos_query=pos_queries,
                )
                self.attn_maps.append(self.decoder.layers[-1].attn_map)
                # the next token probability is in the output's ith token position
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # greedy decode. add the next token index to the target input
                    # (squeeze only the single-query axis so batch/class axes survive)
                    tgt_in[:, j] = p_i.squeeze(1).argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if testing and (tgt_in == self.eos_id).any(dim=-1).all():
                        break

            logits = torch.cat(logits, dim=1)
        else:
            # No prior context, so input is just <bos>. We query all positions.
            tgt_in = torch.full((bs, 1),
                                self.bos_id,
                                dtype=torch.long,
                                device=_device)
            tgt_out = self.decode(tgt_in,
                                  memory,
                                  tgt_query=pos_queries,
                                  pos_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # For iterative refinement, we always use a 'cloze' mask.
            # We can derive it from the AR forward mask by unmasking the token context to the right.
            query_mask[torch.triu(
                torch.ones(num_steps,
                           num_steps,
                           dtype=torch.bool,
                           device=_device), 2)] = 0
            bos = torch.full((bs, 1),
                             self.bos_id,
                             dtype=torch.long,
                             device=_device)
            for i in range(self.refine_iters):
                # Prior context is the previous output.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                tgt_len = tgt_in.shape[1]
                tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
                    -1) > 0  # mask tokens beyond the first EOS token.
                tgt_out = self.decode(
                    tgt_in,
                    memory,
                    tgt_mask[:tgt_len, :tgt_len],
                    tgt_padding_mask,
                    tgt_query=pos_queries,
                    tgt_query_mask=query_mask[:, :tgt_len],
                    pos_query=pos_queries,
                )
                logits = self.head(tgt_out)

        return F.softmax(logits, -1)

    def gen_tgt_perms(self, tgt, _device):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the shorter
        sequences because of the padding mask.
        """
        # We don't permute the position of BOS, we permute EOS separately
        max_num_chars = tgt.shape[1] - 2
        # Special handling for 1-character sequences
        if max_num_chars == 1:
            return torch.arange(3, device=_device).unsqueeze(0)
        perms = [torch.arange(max_num_chars, device=_device)
                 ] if self.perm_forward else []
        # Additional permutations if needed
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
        # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
        if max_num_chars < 5:
            # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
            # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = torch.as_tensor(list(
                permutations(range(max_num_chars), max_num_chars)),
                device=_device)[selector]
            # If the forward permutation is always selected, no need to add it to the pool for sampling
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            perms = torch.stack(perms)
            if len(perm_pool):
                i = self.rng.choice(len(perm_pool),
                                    size=num_gen_perms - len(perms),
                                    replace=False)
                perms = torch.cat([perms, perm_pool[i]])
        else:
            perms.extend([
                torch.randperm(max_num_chars, device=_device)
                for _ in range(num_gen_perms - len(perms))
            ])
            perms = torch.stack(perms)
        if self.perm_mirrored:
            # Add complementary pairs
            comp = perms.flip(-1)
            # Stack in such a way that the pairs are next to each other.
            perms = torch.stack([perms, comp
                                 ]).transpose(0, 1).reshape(-1, max_num_chars)
        # NOTE:
        # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
        # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
        # positions will always be much less than the number of permutations (unless a low perm_num is set).
        # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
        # distribute it across the chosen number of permutations.
        # Add position indices of BOS and EOS
        bos_idx = perms.new_zeros((len(perms), 1))
        eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
        perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
        # Special handling for the reverse direction. This does two things:
        # 1. Reverse context for the characters
        # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
                                                            device=_device)
        return perms

    def generate_attn_masks(self, perm, _device):
        """Generate attention masks given a sequence permutation (includes pos.
        for bos and eos tokens)

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks
        """
        sz = perm.shape[0]
        mask = torch.zeros((sz, sz), device=_device)
        for i in range(sz):
            query_idx = perm[i]
            masked_keys = perm[i + 1:]
            mask[query_idx, masked_keys] = float('-inf')
        content_mask = mask[:-1, :-1].clone()
        mask[torch.eye(sz, dtype=torch.bool,
                       device=_device)] = float('-inf')  # mask "self"
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

    def training_step(self, memory, pos_query, tgt_ids, clip_ids):
        """Compute the permutation-averaged cross-entropy loss.

        Returns ``[loss, logits of the first (canonical) permutation]``.
        """
        bs = tgt_ids.shape[0]
        if self.is_pretrain:
            memory = self.clip_encoder.encode_text(clip_ids)
            n = memory.shape[1]
            B, N, D = self.background_features.shape
            random_B = np.random.choice(B, bs, replace=False)
            random_N = np.random.choice(N, n, replace=False)
            noise = self.background_features[random_B][:, random_N]
            # `.to(memory.get_device())` would be `.to(-1)` for CPU tensors.
            noise = torch.from_numpy(noise).to(memory.device)
            memory = memory + noise * 1e-1
        else:
            token_query = self.token_query.expand(bs, -1, -1)
            memory = self.fmu(token_query, memory)

        device = memory.device
        if pos_query is None:
            pos_query = self.pos_queries.expand(bs, -1, -1)
        # Prepare the target sequences (input and output)
        tgt_perms = self.gen_tgt_perms(tgt_ids, device)
        tgt_in = tgt_ids[:, :-1]
        tgt_out = tgt_ids[:, 1:]

        # The [EOS] token is not depended upon by any other token in any permutation ordering
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)

        loss = 0
        loss_numel = 0
        n = (tgt_out != self.pad_id).sum().item()
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm, device)
            out = self.decode(
                tgt_in,
                memory,
                tgt_mask,
                tgt_padding_mask,
                tgt_query_mask=query_mask,
                pos_query=pos_query,
            )
            logits = self.head(out)
            if i == 0:
                final_out = logits
            loss += n * F.cross_entropy(logits.flatten(end_dim=1),
                                        tgt_out.flatten(),
                                        ignore_index=self.pad_id)
            loss_numel += n
            # After the second iteration (i.e. done with canonical and reverse orderings),
            # remove the [EOS] tokens for the succeeding perms
            if i == 1:
                tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
                                      tgt_out)
                n = (tgt_out != self.pad_id).sum().item()
        loss /= loss_numel

        return [loss, final_out]
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from itertools import permutations
18
+ from collections import OrderedDict
19
+ import hashlib
20
+ import os
21
+ import gzip
22
+ import html
23
+ import urllib
24
+ import warnings
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch import Tensor
30
+ from torch.nn.modules import transformer
31
+ from typing import Any, Optional, Tuple, List, Union
32
+ from pkg_resources import packaging
33
+ from PIL import Image
34
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
35
+ from tqdm import tqdm
36
+ from functools import lru_cache
37
+
38
+ import ftfy
39
+ import regex as re
40
+
41
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    # torchvision without InterpolationMode: fall back to PIL's constant.
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC
46
+
47
+
48
@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE merges file stored next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, 'bpe_simple_vocab_16e6.txt.gz')
52
+
53
+
54
@lru_cache()
def bytes_to_unicode():
    """
    Return a reversible mapping from every byte value (0-255) to a unicode character.

    The reversible BPE codes work on unicode strings. Covering raw bytes
    with real unicode characters would need a large vocabulary (roughly 5K
    for a 10B-token dataset), which is a significant share of a typical
    32K BPE vocab. Instead, bytes that are already printable (``!``..``~``,
    ``¡``..``¬``, ``®``..``ÿ``) map to themselves. Every other byte
    (whitespace, control characters, ...) is assigned, in increasing byte
    order, the next code point starting at 256. This also keeps the BPE
    code away from whitespace/control characters.
    """
    printable = [
        *range(ord('!'), ord('~') + 1),
        *range(ord('¡'), ord('¬') + 1),
        *range(ord('®'), ord('ÿ') + 1),
    ]
    byte_values = list(printable)
    code_points = list(printable)
    already_mapped = set(printable)
    next_free = 2**8
    for byte in range(2**8):
        if byte in already_mapped:
            continue
        byte_values.append(byte)
        code_points.append(next_free)
        next_free += 1
    return {b: chr(cp) for b, cp in zip(byte_values, code_points)}
79
+
80
+
81
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sliceable sequence of symbols, e.g. a tuple of
    variable-length strings. Sequences with fewer than two symbols,
    including the empty sequence, have no pairs and yield an empty set.
    """
    # zip stops at the shorter input, so empty/single-symbol words are safe.
    return set(zip(word, word[1:]))
91
+
92
+
93
def basic_clean(text):
    """Fix mojibake with ftfy, HTML-unescape twice, and strip outer whitespace."""
    cleaned = ftfy.fix_text(text)
    # Unescape twice so double-escaped entities (e.g. "&amp;amp;") are decoded.
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()
97
+
98
+
99
def whitespace_clean(text):
    """Collapse each run of whitespace to one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
103
+
104
+
105
class SimpleTokenizer(object):
    """CLIP-style byte-level BPE tokenizer.

    Text is cleaned, lower-cased and split with ``self.pat``. Each piece is
    mapped byte-by-byte to printable characters (``bytes_to_unicode``) and
    then merged using the ranked BPE merges read from ``bpe_path``.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read inside `with` so the gzip handle is closed deterministically.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode('utf-8').split('\n')
        # Skip the first line (presumably a header) and keep at most
        # 49152 - 256 - 2 merges.
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Base vocab: every byte character, alone and with the end-of-word marker.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = earlier merge = applied first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to ``token``.

        Returns the resulting sub-tokens joined by spaces. The last symbol
        carries the ``</w>`` end-of-word marker. Multi-symbol results are
        memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Merge the lowest-ranked known pair; stop when none is known.
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the remainder as-is.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text.

        Undecodable bytes are replaced, and each ``</w>`` marker becomes a space.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text
                          ]).decode('utf-8',
                                    errors='replace').replace('</w>', ' ')
        return text
188
+
189
+
190
+ if packaging.version.parse(
191
+ torch.__version__) < packaging.version.parse('1.7.1'):
192
+ warnings.warn('PyTorch version 1.7.1 or higher is recommended')
193
+
194
+ __all__ = ['available_models', 'load', 'tokenize']
195
+ _tokenizer = SimpleTokenizer()
196
+
197
+ _MODELS = {
198
+ 'RN50':
199
+ 'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
200
+ 'RN101':
201
+ 'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
202
+ 'RN50x4':
203
+ 'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
204
+ 'RN50x16':
205
+ 'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
206
+ 'RN50x64':
207
+ 'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
208
+ 'ViT-B/32':
209
+ 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
210
+ 'ViT-B/16':
211
+ 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
212
+ 'ViT-L/14':
213
+ 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
214
+ 'ViT-L/14@336px':
215
+ 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
216
+ }
217
+
218
+
219
+ def convert_weights(model: nn.Module):
220
+ """Convert applicable model parameters to fp16"""
221
+
222
+ def _convert_weights_to_fp16(l):
223
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
224
+ l.weight.data = l.weight.data.half()
225
+ if l.bias is not None:
226
+ l.bias.data = l.bias.data.half()
227
+
228
+ if isinstance(l, nn.MultiheadAttention):
229
+ for attr in [
230
+ *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
231
+ 'in_proj_bias', 'bias_k', 'bias_v'
232
+ ]:
233
+ tensor = getattr(l, attr)
234
+ if tensor is not None:
235
+ tensor.data = tensor.data.half()
236
+
237
+ for name in ['text_projection', 'proj']:
238
+ if hasattr(l, name):
239
+ attr = getattr(l, name)
240
+ if attr is not None:
241
+ attr.data = attr.data.half()
242
+
243
+ model.apply(_convert_weights_to_fp16)
244
+
245
+
246
+ def build_model(state_dict: dict):
247
+ vit = 'visual.proj' in state_dict
248
+
249
+ if vit:
250
+ vision_width = state_dict['visual.conv1.weight'].shape[0]
251
+ vision_layers = len([
252
+ k for k in state_dict.keys()
253
+ if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
254
+ ])
255
+ vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
256
+ grid_size = round(
257
+ (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
258
+ image_resolution = vision_patch_size * grid_size
259
+ else:
260
+ counts: list = [
261
+ len(
262
+ set(
263
+ k.split('.')[2] for k in state_dict
264
+ if k.startswith(f'visual.layer{b}')))
265
+ for b in [1, 2, 3, 4]
266
+ ]
267
+ vision_layers = tuple(counts)
268
+ vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
269
+ output_width = round(
270
+ (state_dict['visual.attnpool.positional_embedding'].shape[0] -
271
+ 1)**0.5)
272
+ vision_patch_size = None
273
+ assert output_width**2 + 1 == state_dict[
274
+ 'visual.attnpool.positional_embedding'].shape[0]
275
+ image_resolution = output_width * 32
276
+
277
+ embed_dim = state_dict['text_projection'].shape[1]
278
+ context_length = state_dict['positional_embedding'].shape[0]
279
+ vocab_size = state_dict['token_embedding.weight'].shape[0]
280
+ transformer_width = state_dict['ln_final.weight'].shape[0]
281
+ transformer_heads = transformer_width // 64
282
+ transformer_layers = len(
283
+ set(
284
+ k.split('.')[2] for k in state_dict
285
+ if k.startswith('transformer.resblocks')))
286
+
287
+ model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
288
+ vision_patch_size, context_length, vocab_size,
289
+ transformer_width, transformer_heads, transformer_layers)
290
+
291
+ for key in ['input_resolution', 'context_length', 'vocab_size']:
292
+ if key in state_dict:
293
+ del state_dict[key]
294
+
295
+ convert_weights(model)
296
+ model.load_state_dict(state_dict)
297
+ return model.eval()
298
+
299
+
300
+ def _download(url: str, root: str):
301
+ os.makedirs(root, exist_ok=True)
302
+ filename = os.path.basename(url)
303
+
304
+ expected_sha256 = url.split('/')[-2]
305
+ download_target = os.path.join(root, filename)
306
+
307
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
308
+ raise RuntimeError(
309
+ f'{download_target} exists and is not a regular file')
310
+
311
+ if os.path.isfile(download_target):
312
+ if hashlib.sha256(open(download_target,
313
+ 'rb').read()).hexdigest() == expected_sha256:
314
+ return download_target
315
+ else:
316
+ warnings.warn(
317
+ f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
318
+ )
319
+
320
+ with urllib.request.urlopen(url) as source, open(download_target,
321
+ 'wb') as output:
322
+ with tqdm(total=int(source.info().get('Content-Length')),
323
+ ncols=80,
324
+ unit='iB',
325
+ unit_scale=True,
326
+ unit_divisor=1024) as loop:
327
+ while True:
328
+ buffer = source.read(8192)
329
+ if not buffer:
330
+ break
331
+
332
+ output.write(buffer)
333
+ loop.update(len(buffer))
334
+
335
+ if hashlib.sha256(open(download_target,
336
+ 'rb').read()).hexdigest() != expected_sha256:
337
+ raise RuntimeError(
338
+ 'Model has been downloaded but the SHA256 checksum does not not match'
339
+ )
340
+
341
+ return download_target
342
+
343
+
344
+ def _convert_image_to_rgb(image):
345
+ return image.convert('RGB')
346
+
347
+
348
+ def _transform(n_px):
349
+ return Compose([
350
+ Resize(n_px, interpolation=BICUBIC),
351
+ CenterCrop(n_px),
352
+ _convert_image_to_rgb,
353
+ ToTensor(),
354
+ Normalize((0.48145466, 0.4578275, 0.40821073),
355
+ (0.26862954, 0.26130258, 0.27577711)),
356
+ ])
357
+
358
+
359
+ def available_models() -> List[str]:
360
+ """Returns the names of available CLIP models"""
361
+ return list(_MODELS.keys())
362
+
363
+
364
+ class Bottleneck(nn.Module):
365
+ expansion = 4
366
+
367
+ def __init__(self, inplanes, planes, stride=1):
368
+ super().__init__()
369
+
370
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
371
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
372
+ self.bn1 = nn.BatchNorm2d(planes)
373
+ self.relu1 = nn.ReLU(inplace=True)
374
+
375
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
376
+ self.bn2 = nn.BatchNorm2d(planes)
377
+ self.relu2 = nn.ReLU(inplace=True)
378
+
379
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
380
+
381
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
382
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
383
+ self.relu3 = nn.ReLU(inplace=True)
384
+
385
+ self.downsample = None
386
+ self.stride = stride
387
+
388
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
389
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
390
+ self.downsample = nn.Sequential(
391
+ OrderedDict([('-1', nn.AvgPool2d(stride)),
392
+ ('0',
393
+ nn.Conv2d(inplanes,
394
+ planes * self.expansion,
395
+ 1,
396
+ stride=1,
397
+ bias=False)),
398
+ ('1', nn.BatchNorm2d(planes * self.expansion))]))
399
+
400
+ def forward(self, x: torch.Tensor):
401
+ identity = x
402
+
403
+ out = self.relu1(self.bn1(self.conv1(x)))
404
+ out = self.relu2(self.bn2(self.conv2(out)))
405
+ out = self.avgpool(out)
406
+ out = self.bn3(self.conv3(out))
407
+
408
+ if self.downsample is not None:
409
+ identity = self.downsample(x)
410
+
411
+ out += identity
412
+ out = self.relu3(out)
413
+ return out
414
+
415
+
416
+ class AttentionPool2d(nn.Module):
417
+
418
+ def __init__(self,
419
+ spacial_dim: int,
420
+ embed_dim: int,
421
+ num_heads: int,
422
+ output_dim: int = None):
423
+ super().__init__()
424
+ self.positional_embedding = nn.Parameter(
425
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
426
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
427
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
428
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
429
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
430
+ self.num_heads = num_heads
431
+
432
+ def forward(self, x):
433
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
434
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
435
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
436
+ x, _ = F.multi_head_attention_forward(
437
+ query=x[:1],
438
+ key=x,
439
+ value=x,
440
+ embed_dim_to_check=x.shape[-1],
441
+ num_heads=self.num_heads,
442
+ q_proj_weight=self.q_proj.weight,
443
+ k_proj_weight=self.k_proj.weight,
444
+ v_proj_weight=self.v_proj.weight,
445
+ in_proj_weight=None,
446
+ in_proj_bias=torch.cat(
447
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
448
+ bias_k=None,
449
+ bias_v=None,
450
+ add_zero_attn=False,
451
+ dropout_p=0,
452
+ out_proj_weight=self.c_proj.weight,
453
+ out_proj_bias=self.c_proj.bias,
454
+ use_separate_proj_weight=True,
455
+ training=self.training,
456
+ need_weights=False)
457
+ return x.squeeze(0)
458
+
459
+
460
+ class ModifiedResNet(nn.Module):
461
+ """
462
+ A ResNet class that is similar to torchvision's but contains the following changes:
463
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
464
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
465
+ - The final pooling layer is a QKV attention instead of an average pool
466
+ """
467
+
468
+ def __init__(self,
469
+ layers,
470
+ output_dim,
471
+ heads,
472
+ input_resolution=224,
473
+ width=64):
474
+ super().__init__()
475
+ self.output_dim = output_dim
476
+ self.input_resolution = input_resolution
477
+
478
+ # the 3-layer stem
479
+ self.conv1 = nn.Conv2d(3,
480
+ width // 2,
481
+ kernel_size=3,
482
+ stride=2,
483
+ padding=1,
484
+ bias=False)
485
+ self.bn1 = nn.BatchNorm2d(width // 2)
486
+ self.relu1 = nn.ReLU(inplace=True)
487
+ self.conv2 = nn.Conv2d(width // 2,
488
+ width // 2,
489
+ kernel_size=3,
490
+ padding=1,
491
+ bias=False)
492
+ self.bn2 = nn.BatchNorm2d(width // 2)
493
+ self.relu2 = nn.ReLU(inplace=True)
494
+ self.conv3 = nn.Conv2d(width // 2,
495
+ width,
496
+ kernel_size=3,
497
+ padding=1,
498
+ bias=False)
499
+ self.bn3 = nn.BatchNorm2d(width)
500
+ self.relu3 = nn.ReLU(inplace=True)
501
+ self.avgpool = nn.AvgPool2d(2)
502
+
503
+ # residual layers
504
+ self._inplanes = width # this is a *mutable* variable used during construction
505
+ self.layer1 = self._make_layer(width, layers[0])
506
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
507
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
508
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
509
+
510
+ embed_dim = width * 32 # the ResNet feature dimension
511
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
512
+ heads, output_dim)
513
+
514
+ def _make_layer(self, planes, blocks, stride=1):
515
+ layers = [Bottleneck(self._inplanes, planes, stride)]
516
+
517
+ self._inplanes = planes * Bottleneck.expansion
518
+ for _ in range(1, blocks):
519
+ layers.append(Bottleneck(self._inplanes, planes))
520
+
521
+ return nn.Sequential(*layers)
522
+
523
+ def forward(self, x):
524
+
525
+ def stem(x):
526
+ x = self.relu1(self.bn1(self.conv1(x)))
527
+ x = self.relu2(self.bn2(self.conv2(x)))
528
+ x = self.relu3(self.bn3(self.conv3(x)))
529
+ x = self.avgpool(x)
530
+ return x
531
+
532
+ x = x.type(self.conv1.weight.dtype)
533
+ x = stem(x)
534
+ x = self.layer1(x)
535
+ x = self.layer2(x)
536
+ x = self.layer3(x)
537
+ x = self.layer4(x)
538
+ x = self.attnpool(x)
539
+
540
+ return x
541
+
542
+
543
+ class LayerNorm(nn.LayerNorm):
544
+ """Subclass torch's LayerNorm to handle fp16."""
545
+
546
+ def forward(self, x: torch.Tensor):
547
+ orig_type = x.dtype
548
+ ret = super().forward(x.type(torch.float32))
549
+ return ret.type(orig_type)
550
+
551
+
552
+ class QuickGELU(nn.Module):
553
+
554
+ def forward(self, x: torch.Tensor):
555
+ return x * torch.sigmoid(1.702 * x)
556
+
557
+
558
+ class ResidualAttentionBlock(nn.Module):
559
+
560
+ def __init__(self,
561
+ d_model: int,
562
+ n_head: int,
563
+ attn_mask: torch.Tensor = None):
564
+ super().__init__()
565
+
566
+ self.attn = nn.MultiheadAttention(d_model, n_head)
567
+ self.ln_1 = LayerNorm(d_model)
568
+ self.mlp = nn.Sequential(
569
+ OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
570
+ ('gelu', QuickGELU()),
571
+ ('c_proj', nn.Linear(d_model * 4, d_model))]))
572
+ self.ln_2 = LayerNorm(d_model)
573
+ self.attn_mask = attn_mask
574
+
575
+ def attention(self, x: torch.Tensor):
576
+ self.attn_mask = self.attn_mask.to(
577
+ dtype=x.dtype,
578
+ device=x.device) if self.attn_mask is not None else None
579
+ return self.attn(x, x, x, need_weights=False,
580
+ attn_mask=self.attn_mask)[0]
581
+
582
+ def forward(self, x: torch.Tensor):
583
+ x = x + self.attention(self.ln_1(x))
584
+ x = x + self.mlp(self.ln_2(x))
585
+ return x
586
+
587
+
588
+ class Transformer(nn.Module):
589
+
590
+ def __init__(self,
591
+ width: int,
592
+ layers: int,
593
+ heads: int,
594
+ attn_mask: torch.Tensor = None):
595
+ super().__init__()
596
+ self.width = width
597
+ self.layers = layers
598
+ self.resblocks = nn.Sequential(*[
599
+ ResidualAttentionBlock(width, heads, attn_mask)
600
+ for _ in range(layers)
601
+ ])
602
+
603
+ def forward(self, x: torch.Tensor):
604
+ return self.resblocks(x)
605
+
606
+
607
+ class VisionTransformer(nn.Module):
608
+
609
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
610
+ layers: int, heads: int, output_dim: int):
611
+ super().__init__()
612
+ self.input_resolution = input_resolution
613
+ self.output_dim = output_dim
614
+ self.conv1 = nn.Conv2d(in_channels=3,
615
+ out_channels=width,
616
+ kernel_size=patch_size,
617
+ stride=patch_size,
618
+ bias=False)
619
+
620
+ scale = width**-0.5
621
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
622
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
623
+ (input_resolution // patch_size)**2 + 1, width))
624
+ self.ln_pre = LayerNorm(width)
625
+
626
+ self.transformer = Transformer(width, layers, heads)
627
+
628
+ self.ln_post = LayerNorm(width)
629
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
630
+
631
+ def forward(self, x: torch.Tensor):
632
+ x = self.conv1(x) # shape = [*, width, grid, grid]
633
+ x = x.reshape(x.shape[0], x.shape[1],
634
+ -1) # shape = [*, width, grid ** 2]
635
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
636
+ x = torch.cat([
637
+ self.class_embedding.to(x.dtype) + torch.zeros(
638
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
639
+ ],
640
+ dim=1) # shape = [*, grid ** 2 + 1, width]
641
+ x = x + self.positional_embedding.to(x.dtype)
642
+ x = self.ln_pre(x)
643
+
644
+ x = x.permute(1, 0, 2) # NLD -> LND
645
+ x = self.transformer(x)
646
+ x = x.permute(1, 0, 2) # LND -> NLD
647
+
648
+ x = self.ln_post(x)
649
+ if self.proj is not None:
650
+ x = x @ self.proj
651
+
652
+ return x
653
+
654
+
655
+ class CLIP(nn.Module):
656
+
657
+ def __init__(
658
+ self,
659
+ embed_dim: int,
660
+ # vision
661
+ image_resolution: int,
662
+ vision_layers: Union[Tuple[int, int, int, int], int],
663
+ vision_width: int,
664
+ vision_patch_size: int,
665
+ # text
666
+ context_length: int,
667
+ vocab_size: int,
668
+ transformer_width: int,
669
+ transformer_heads: int,
670
+ transformer_layers: int):
671
+ super().__init__()
672
+
673
+ self.context_length = context_length
674
+
675
+ if isinstance(vision_layers, (tuple, list)):
676
+ vision_heads = vision_width * 32 // 64
677
+ self.visual = ModifiedResNet(layers=vision_layers,
678
+ output_dim=embed_dim,
679
+ heads=vision_heads,
680
+ input_resolution=image_resolution,
681
+ width=vision_width)
682
+ else:
683
+ vision_heads = vision_width // 64
684
+ self.visual = VisionTransformer(input_resolution=image_resolution,
685
+ patch_size=vision_patch_size,
686
+ width=vision_width,
687
+ layers=vision_layers,
688
+ heads=vision_heads,
689
+ output_dim=embed_dim)
690
+
691
+ self.transformer = Transformer(width=transformer_width,
692
+ layers=transformer_layers,
693
+ heads=transformer_heads,
694
+ attn_mask=self.build_attention_mask())
695
+
696
+ self.vocab_size = vocab_size
697
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
698
+ self.positional_embedding = nn.Parameter(
699
+ torch.empty(self.context_length, transformer_width))
700
+ self.ln_final = LayerNorm(transformer_width)
701
+
702
+ self.text_projection = nn.Parameter(
703
+ torch.empty(transformer_width, embed_dim))
704
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
705
+
706
+ self.initialize_parameters()
707
+
708
+ def initialize_parameters(self):
709
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
710
+ nn.init.normal_(self.positional_embedding, std=0.01)
711
+
712
+ if isinstance(self.visual, ModifiedResNet):
713
+ if self.visual.attnpool is not None:
714
+ std = self.visual.attnpool.c_proj.in_features**-0.5
715
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
716
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
717
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
718
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
719
+
720
+ for resnet_block in [
721
+ self.visual.layer1, self.visual.layer2, self.visual.layer3,
722
+ self.visual.layer4
723
+ ]:
724
+ for name, param in resnet_block.named_parameters():
725
+ if name.endswith('bn3.weight'):
726
+ nn.init.zeros_(param)
727
+
728
+ proj_std = (self.transformer.width**-0.5) * (
729
+ (2 * self.transformer.layers)**-0.5)
730
+ attn_std = self.transformer.width**-0.5
731
+ fc_std = (2 * self.transformer.width)**-0.5
732
+ for block in self.transformer.resblocks:
733
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
734
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
735
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
736
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
737
+
738
+ if self.text_projection is not None:
739
+ nn.init.normal_(self.text_projection,
740
+ std=self.transformer.width**-0.5)
741
+
742
+ def build_attention_mask(self):
743
+ # lazily create causal attention mask, with full attention between the vision tokens
744
+ # pytorch uses additive attention mask; fill with -inf
745
+ mask = torch.empty(self.context_length, self.context_length)
746
+ mask.fill_(float('-inf'))
747
+ mask.triu_(1) # zero out the lower diagonal
748
+ return mask
749
+
750
+ @property
751
+ def dtype(self):
752
+ return self.visual.conv1.weight.dtype
753
+
754
+ def encode_image(self, image):
755
+ return self.visual(image.type(self.dtype))
756
+
757
+ def encode_text(self, text):
758
+ x = self.token_embedding(text).type(
759
+ self.dtype) # [batch_size, n_ctx, d_model]
760
+
761
+ x = x + self.positional_embedding.type(self.dtype)
762
+ x = x.permute(1, 0, 2) # NLD -> LND
763
+ x = self.transformer(x)
764
+ x = x.permute(1, 0, 2) # LND -> NLD
765
+ x = self.ln_final(x).type(self.dtype)
766
+
767
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
768
+ output = x[torch.arange(x.shape[0]),
769
+ text.argmax(dim=-1)] @ self.text_projection
770
+ output = torch.cat([output.unsqueeze(1), x], dim=1)
771
+
772
+ return output
773
+
774
+ def forward(self, image, text):
775
+ image_features = self.encode_image(image)
776
+ text_features = self.encode_text(text)
777
+
778
+ # normalized features
779
+ image_features = image_features / image_features.norm(dim=1,
780
+ keepdim=True)
781
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
782
+
783
+ # cosine similarity as logits
784
+ logit_scale = self.logit_scale.exp()
785
+ logits_per_image = logit_scale * image_features @ text_features.t()
786
+ logits_per_text = logits_per_image.t()
787
+
788
+ # shape = [global_batch_size, global_batch_size]
789
+ return logits_per_image, logits_per_text
790
+
791
+
792
+ class FMU(nn.Module):
793
+ """A Transformer decoder layer supporting two-stream attention (XLNet)
794
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
795
+
796
+ def __init__(self,
797
+ d_model,
798
+ nhead,
799
+ dim_feedforward=2048,
800
+ dropout=0.1,
801
+ activation='gelu',
802
+ layer_norm_eps=1e-5):
803
+ super().__init__()
804
+ self.cross_attn = nn.MultiheadAttention(d_model,
805
+ nhead,
806
+ dropout=dropout,
807
+ batch_first=True)
808
+ # Implementation of Feedforward model
809
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
810
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
811
+
812
+ self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
813
+
814
+ self.dropout1 = nn.Dropout(dropout)
815
+ self.dropout2 = nn.Dropout(dropout)
816
+ self.dropout3 = nn.Dropout(dropout)
817
+
818
+ self.activation = transformer._get_activation_fn(activation)
819
+
820
+ def __setstate__(self, state):
821
+ if 'activation' not in state:
822
+ state['activation'] = F.gelu
823
+ super().__setstate__(state)
824
+
825
+ def forward(self, query: Tensor, memory: Tensor):
826
+ """Forward pass for a single stream (i.e. content or query)
827
+ tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
828
+ Both tgt_kv and memory are expected to be LayerNorm'd too.
829
+ memory is LayerNorm'd by ViT.
830
+ """
831
+ query1, ca_weights = self.cross_attn(query, memory, memory)
832
+ query = query + self.dropout1(query1)
833
+
834
+ query2 = self.linear2(
835
+ self.dropout2(self.activation(self.linear1(self.norm(query)))))
836
+ query = query + self.dropout3(query2)
837
+
838
+ return query
839
+
840
+
841
+ class DecoderLayer(nn.Module):
842
+ """A Transformer decoder layer supporting two-stream attention (XLNet) This
843
+ implements a pre-LN decoder, as opposed to the post-LN default in
844
+ PyTorch."""
845
+
846
+ def __init__(
847
+ self,
848
+ d_model,
849
+ nhead,
850
+ dim_feedforward=2048,
851
+ dropout=0.1,
852
+ activation='gelu',
853
+ layer_norm_eps=1e-5,
854
+ ):
855
+ super().__init__()
856
+ self.self_attn = nn.MultiheadAttention(d_model,
857
+ nhead,
858
+ dropout=dropout,
859
+ batch_first=True)
860
+ self.cross_attn = nn.MultiheadAttention(d_model,
861
+ nhead,
862
+ dropout=dropout,
863
+ batch_first=True)
864
+ # Implementation of Feedforward model
865
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
866
+ self.dropout = nn.Dropout(dropout)
867
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
868
+
869
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
870
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
871
+ self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
872
+ self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
873
+ self.dropout1 = nn.Dropout(dropout)
874
+ self.dropout2 = nn.Dropout(dropout)
875
+ self.dropout3 = nn.Dropout(dropout)
876
+
877
+ self.activation = transformer._get_activation_fn(activation)
878
+
879
+ def __setstate__(self, state):
880
+ if 'activation' not in state:
881
+ state['activation'] = F.gelu
882
+ super().__setstate__(state)
883
+
884
+ def forward_stream(
885
+ self,
886
+ tgt: Tensor,
887
+ tgt_norm: Tensor,
888
+ tgt_kv: Tensor,
889
+ memory: Tensor,
890
+ tgt_mask: Optional[Tensor],
891
+ tgt_key_padding_mask: Optional[Tensor],
892
+ ):
893
+ """Forward pass for a single stream (i.e. content or query) tgt_norm is
894
+ just a LayerNorm'd tgt.
895
+
896
+ Added as a separate parameter for efficiency. Both tgt_kv and memory
897
+ are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT.
898
+ """
899
+ tgt2, sa_weights = self.self_attn(
900
+ tgt_norm,
901
+ tgt_kv,
902
+ tgt_kv,
903
+ attn_mask=tgt_mask,
904
+ key_padding_mask=tgt_key_padding_mask)
905
+
906
+ tgt = tgt + self.dropout1(tgt2)
907
+
908
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
909
+ self.attn_map = ca_weights
910
+ tgt = tgt + self.dropout2(tgt2)
911
+
912
+ tgt2 = self.linear2(
913
+ self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
914
+ tgt = tgt + self.dropout3(tgt2)
915
+ return tgt, sa_weights, ca_weights
916
+
917
+ def forward(
918
+ self,
919
+ query,
920
+ content,
921
+ memory,
922
+ query_mask: Optional[Tensor] = None,
923
+ content_mask: Optional[Tensor] = None,
924
+ content_key_padding_mask: Optional[Tensor] = None,
925
+ update_content: bool = True,
926
+ ):
927
+ query_norm = self.norm_q(query)
928
+ content_norm = self.norm_c(content)
929
+ query = self.forward_stream(query, query_norm, content_norm, memory,
930
+ query_mask, content_key_padding_mask)[0]
931
+ if update_content:
932
+ content = self.forward_stream(content, content_norm, content_norm,
933
+ memory, content_mask,
934
+ content_key_padding_mask)[0]
935
+ return query, content
936
+
937
+
938
+ class Decoder(nn.Module):
939
+ __constants__ = ['norm']
940
+
941
+ def __init__(self, decoder_layer, num_layers, norm):
942
+ super().__init__()
943
+ self.layers = transformer._get_clones(decoder_layer, num_layers)
944
+ self.num_layers = num_layers
945
+ self.norm = norm
946
+
947
+ def forward(
948
+ self,
949
+ query,
950
+ content,
951
+ memory,
952
+ query_mask: Optional[Tensor] = None,
953
+ content_mask: Optional[Tensor] = None,
954
+ content_key_padding_mask: Optional[Tensor] = None,
955
+ ):
956
+ for i, mod in enumerate(self.layers):
957
+ last = i == len(self.layers) - 1
958
+ query, content = mod(
959
+ query,
960
+ content,
961
+ memory,
962
+ query_mask,
963
+ content_mask,
964
+ content_key_padding_mask,
965
+ update_content=not last,
966
+ )
967
+ query = self.norm(query)
968
+ return query
969
+
970
+
971
+ class TokenEmbedding(nn.Module):
972
+
973
+ def __init__(self, charset_size: int, embed_dim: int):
974
+ super().__init__()
975
+ self.embedding = nn.Embedding(charset_size, embed_dim)
976
+ self.embed_dim = embed_dim
977
+
978
+ def forward(self, tokens: torch.Tensor):
979
+ return math.sqrt(self.embed_dim) * self.embedding(tokens)
980
+
981
+
982
+ def load(name: str,
983
+ device: Union[str, torch.device] = 'cuda'
984
+ if torch.cuda.is_available() else 'cpu',
985
+ jit: bool = False,
986
+ download_root: str = None):
987
+ """Load a CLIP model
988
+
989
+ Parameters
990
+ ----------
991
+ name : str
992
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
993
+
994
+ device : Union[str, torch.device]
995
+ The device to put the loaded model
996
+
997
+ jit : bool
998
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
999
+
1000
+ download_root: str
1001
+ path to download the model files; by default, it uses "~/.cache/clip"
1002
+
1003
+ Returns
1004
+ -------
1005
+ model : torch.nn.Module
1006
+ The CLIP model
1007
+
1008
+ preprocess : Callable[[PIL.Image], torch.Tensor]
1009
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
1010
+ """
1011
+ if name in _MODELS:
1012
+ model_path = _download(
1013
+ _MODELS[name], download_root
1014
+ or os.path.expanduser('~/.cache/clip'))
1015
+ elif os.path.isfile(name):
1016
+ model_path = name
1017
+ else:
1018
+ raise RuntimeError(
1019
+ f'Model {name} not found; available models = {available_models()}')
1020
+
1021
+ with open(model_path, 'rb') as opened_file:
1022
+ try:
1023
+ # loading JIT archive
1024
+ model = torch.jit.load(
1025
+ opened_file, map_location=device if jit else 'cpu').eval()
1026
+ state_dict = None
1027
+ except RuntimeError:
1028
+ # loading saved state dict
1029
+ if jit:
1030
+ warnings.warn(
1031
+ f'File {model_path} is not a JIT archive. Loading as a state dict instead'
1032
+ )
1033
+ jit = False
1034
+ state_dict = torch.load(opened_file, map_location='cpu')
1035
+
1036
+ if not jit:
1037
+ model = build_model(state_dict or model.state_dict()).to(device)
1038
+ if str(device) == 'cpu':
1039
+ model.float()
1040
+ return model, _transform(model.visual.input_resolution)
1041
+
1042
+ # patch the device names
1043
+ device_holder = torch.jit.trace(
1044
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
1045
+ device_node = [
1046
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
1047
+ if 'Device' in repr(n)
1048
+ ][-1]
1049
+
1050
+ def patch_device(module):
1051
+ try:
1052
+ graphs = [module.graph] if hasattr(module, 'graph') else []
1053
+ except RuntimeError:
1054
+ graphs = []
1055
+
1056
+ if hasattr(module, 'forward1'):
1057
+ graphs.append(module.forward1.graph)
1058
+
1059
+ for graph in graphs:
1060
+ for node in graph.findAllNodes('prim::Constant'):
1061
+ if 'value' in node.attributeNames() and str(
1062
+ node['value']).startswith('cuda'):
1063
+ node.copyAttributes(device_node)
1064
+
1065
+ model.apply(patch_device)
1066
+ patch_device(model.encode_image)
1067
+ patch_device(model.encode_text)
1068
+
1069
+ # patch dtype to float32 on CPU
1070
+ if str(device) == 'cpu':
1071
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
1072
+ example_inputs=[])
1073
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
1074
+ float_node = float_input.node()
1075
+
1076
+ def patch_float(module):
1077
+ try:
1078
+ graphs = [module.graph] if hasattr(module, 'graph') else []
1079
+ except RuntimeError:
1080
+ graphs = []
1081
+
1082
+ if hasattr(module, 'forward1'):
1083
+ graphs.append(module.forward1.graph)
1084
+
1085
+ for graph in graphs:
1086
+ for node in graph.findAllNodes('aten::to'):
1087
+ inputs = list(node.inputs())
1088
+ for i in [
1089
+ 1, 2
1090
+ ]: # dtype can be the second or third argument to aten::to()
1091
+ if inputs[i].node()['value'] == 5:
1092
+ inputs[i].node().copyAttributes(float_node)
1093
+
1094
+ model.apply(patch_float)
1095
+ patch_float(model.encode_image)
1096
+ patch_float(model.encode_text)
1097
+
1098
+ model.float()
1099
+
1100
+ return model, _transform(model.input_resolution.item())
1101
+
1102
+
1103
def tokenize(
        texts: Union[str, List[str]],
        context_length: int = 77,
        truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """Convert one or more strings into fixed-length CLIP token id rows.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize.

    context_length : int
        Width of every output row; all CLIP models use 77.

    truncate : bool
        If True, sequences longer than ``context_length`` are cut down and
        their last slot is replaced by the end-of-text id. If False, such a
        sequence raises ``RuntimeError``.

    Returns
    -------
    A zero-padded tensor of shape ``[len(texts), context_length]``. Its dtype
    is ``torch.long`` on torch < 1.8.0 (older ``index_select`` needs long
    indices) and ``torch.int`` otherwise.
    """
    if isinstance(texts, str):
        texts = [texts]

    start_id = _tokenizer.encoder['<|startoftext|>']
    end_id = _tokenizer.encoder['<|endoftext|>']
    sequences = []
    for text in texts:
        # Wrap every encoded string in start/end-of-text markers.
        sequences.append([start_id] + _tokenizer.encode(text) + [end_id])

    legacy_torch = packaging.version.parse(
        torch.__version__) < packaging.version.parse('1.8.0')
    row_dtype = torch.long if legacy_torch else torch.int
    result = torch.zeros(len(sequences), context_length, dtype=row_dtype)

    for row, seq in enumerate(sequences):
        if len(seq) > context_length:
            if not truncate:
                raise RuntimeError(
                    f'Input {texts[row]} is too long for context length {context_length}'
                )
            seq = seq[:context_length]
            seq[-1] = end_id
        result[row, :len(seq)] = torch.tensor(seq)

    return result
1151
+
1152
+
1153
+ class DptrParseq(nn.Module):
1154
+
1155
+ def __init__(self,
1156
+ in_channels,
1157
+ out_channels,
1158
+ max_label_length=25,
1159
+ embed_dim=512,
1160
+ dec_num_heads=8,
1161
+ dec_mlp_ratio=4,
1162
+ dec_depth=6,
1163
+ perm_num=6,
1164
+ perm_forward=True,
1165
+ perm_mirrored=True,
1166
+ decode_ar=True,
1167
+ refine_iters=1,
1168
+ dropout=0.1,
1169
+ is_pretrain=True,
1170
+ ORP_path=None,
1171
+ **kwargs: Any) -> None:
1172
+ super().__init__()
1173
+ self.pad_id = out_channels - 1
1174
+ self.eos_id = 0
1175
+ self.bos_id = out_channels - 2
1176
+ self.max_label_length = max_label_length
1177
+ self.decode_ar = decode_ar
1178
+ self.refine_iters = refine_iters
1179
+ self.is_pretrain = is_pretrain
1180
+ if not is_pretrain:
1181
+ self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim))
1182
+ self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio,
1183
+ dropout)
1184
+
1185
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads,
1186
+ embed_dim * dec_mlp_ratio, dropout)
1187
+ self.decoder = Decoder(decoder_layer,
1188
+ num_layers=dec_depth,
1189
+ norm=nn.LayerNorm(embed_dim))
1190
+
1191
+ # Perm/attn mask stuff
1192
+ self.rng = np.random.default_rng()
1193
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
1194
+ self.perm_forward = perm_forward
1195
+ self.perm_mirrored = perm_mirrored
1196
+
1197
+ # We don't predict <bos> nor <pad>
1198
+ self.head = nn.Linear(embed_dim, out_channels - 2)
1199
+ self.text_embed = TokenEmbedding(out_channels, embed_dim)
1200
+
1201
+ # +1 for <eos>
1202
+ self.pos_queries = nn.Parameter(
1203
+ torch.Tensor(1, max_label_length + 1, embed_dim))
1204
+ self.dropout = nn.Dropout(p=dropout)
1205
+ # Encoder has its own init.
1206
+ self.apply(self._init_weights)
1207
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
1208
+
1209
+ if is_pretrain:
1210
+ self.clip_encoder, preprocess = load('ViT-B/16')
1211
+ for p in self.clip_encoder.parameters():
1212
+ p.requires_grad = False
1213
+ if ORP_path is None:
1214
+ background_image_folder_path = 'background_mages_folder/path'
1215
+ self.background_features = self.get_noise(
1216
+ background_image_folder_path, preprocess)
1217
+ torch.save(self.background_features, 'save/noise/to/ORP_path')
1218
+ else:
1219
+ self.background_features = torch.load(ORP_path,
1220
+ map_location='cpu')
1221
+
1222
+ def _init_weights(self, module: nn.Module):
1223
+ """Initialize the weights using the typical initialization schemes used
1224
+ in SOTA models."""
1225
+
1226
+ if isinstance(module, nn.Linear):
1227
+ nn.init.trunc_normal_(module.weight, std=0.02)
1228
+ if module.bias is not None:
1229
+ nn.init.zeros_(module.bias)
1230
+ elif isinstance(module, nn.Embedding):
1231
+ nn.init.trunc_normal_(module.weight, std=0.02)
1232
+ if module.padding_idx is not None:
1233
+ module.weight.data[module.padding_idx].zero_()
1234
+ elif isinstance(module, nn.Conv2d):
1235
+ nn.init.kaiming_normal_(module.weight,
1236
+ mode='fan_out',
1237
+ nonlinearity='relu')
1238
+ if module.bias is not None:
1239
+ nn.init.zeros_(module.bias)
1240
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
1241
+ nn.init.ones_(module.weight)
1242
+ nn.init.zeros_(module.bias)
1243
+
1244
+ @torch.jit.ignore
1245
+ def no_weight_decay(self):
1246
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
1247
+ return param_names
1248
+
1249
+ def get_noise(self, background_image_path, preprocess):
1250
+ image_paths = [
1251
+ os.path.join(background_image_path, filename)
1252
+ for filename in os.listdir(image_folder_path)
1253
+ if filename.endswith(('.png', '.jpg', '.jpeg'))
1254
+ ]
1255
+ features = []
1256
+ for image_path in image_paths:
1257
+ image = Image.open(image_path)
1258
+ input = preprocess(image).unsqueeze(0).to(self._device)
1259
+ with torch.no_grad():
1260
+ feature = self.clip_encoder.encode_image(input)
1261
+ features.append(feature)
1262
+ image.close()
1263
+ return torch.cat(features).cpu().numpy()
1264
+
1265
+ def clip_encode(self, labels):
1266
+ text_inputs = torch.cat(
1267
+ [tokenize(f"a photo of a '{c}'") for c in labels]).to(self._device)
1268
+
1269
+ return self.clip_encoder.encode_text(text_inputs)
1270
+
1271
+ def decode(
1272
+ self,
1273
+ tgt: torch.Tensor,
1274
+ memory: torch.Tensor,
1275
+ tgt_mask: Optional[Tensor] = None,
1276
+ tgt_padding_mask: Optional[Tensor] = None,
1277
+ tgt_query: Optional[Tensor] = None,
1278
+ tgt_query_mask: Optional[Tensor] = None,
1279
+ pos_query: torch.Tensor = None,
1280
+ ):
1281
+ N, L = tgt.shape
1282
+ # <bos> stands for the null context. We only supply position information for characters after <bos>.
1283
+ null_ctx = self.text_embed(tgt[:, :1])
1284
+
1285
+ if tgt_query is None:
1286
+ tgt_query = pos_query[:, :L]
1287
+ tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:])
1288
+ tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
1289
+
1290
+ tgt_query = self.dropout(tgt_query)
1291
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask,
1292
+ tgt_mask, tgt_padding_mask)
1293
+
1294
+ def forward(self, memory, data=None, pos_query=None):
1295
+ # print(memory.shape, data[0].shape)
1296
+ if self.training:
1297
+ if self.is_pretrain:
1298
+ return self.training_step(None, pos_query, data[0], memory)
1299
+ return self.training_step(memory, pos_query, data[0], None)
1300
+ else:
1301
+ if self.is_pretrain:
1302
+ return self.forward_test(None, memory, pos_query)
1303
+ return self.forward_test(memory, None, pos_query)
1304
+
1305
+ def forward_test(self,
1306
+ memory: Tensor,
1307
+ clip_ids,
1308
+ pos_query: Tensor = None,
1309
+ max_length: Optional[int] = None) -> Tensor:
1310
+ testing = max_length is None
1311
+ max_length = (self.max_label_length if max_length is None else min(
1312
+ max_length, self.max_label_length))
1313
+
1314
+ if self.is_pretrain:
1315
+ memory = self.clip_encoder.encode_text(clip_ids)
1316
+ else:
1317
+ bs = memory.shape[0]
1318
+ token_query = self.token_query.expand(bs, -1, -1)
1319
+ memory = self.fmu(token_query, memory)
1320
+ _device = memory.get_device()
1321
+ bs = memory.shape[0]
1322
+ # +1 for <eos> at end of sequence.
1323
+ num_steps = max_length + 1
1324
+ # memory = self.encode(images)
1325
+
1326
+ # Query positions up to `num_steps`
1327
+ if pos_query is None:
1328
+ pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
1329
+ else:
1330
+ pos_queries = pos_query
1331
+
1332
+ # Special case for the forward permutation. Faster than using `generate_attn_masks()`
1333
+ tgt_mask = query_mask = torch.triu(
1334
+ torch.full((num_steps, num_steps), float('-inf'), device=_device),
1335
+ 1)
1336
+ self.attn_maps = []
1337
+ if self.decode_ar:
1338
+ tgt_in = torch.full((bs, num_steps),
1339
+ self.pad_id,
1340
+ dtype=torch.long,
1341
+ device=_device)
1342
+ tgt_in[:, 0] = self.bos_id
1343
+
1344
+ logits = []
1345
+ for i in range(num_steps):
1346
+ j = i + 1 # next token index
1347
+ # Efficient decoding:
1348
+ # Input the context up to the ith token. We use only one query (at position = i) at a time.
1349
+ # This works because of the lookahead masking effect of the canonical (forward) AR context.
1350
+ # Past tokens have no access to future tokens, hence are fixed once computed.
1351
+ tgt_out = self.decode(
1352
+ tgt_in[:, :j],
1353
+ memory,
1354
+ tgt_mask[:j, :j],
1355
+ tgt_query=pos_queries[:, i:j],
1356
+ tgt_query_mask=query_mask[i:j, :j],
1357
+ pos_query=pos_queries,
1358
+ )
1359
+ self.attn_maps.append(self.decoder.layers[-1].attn_map)
1360
+ # the next token probability is in the output's ith token position
1361
+ p_i = self.head(tgt_out)
1362
+ logits.append(p_i)
1363
+ if j < num_steps:
1364
+ # greedy decode. add the next token index to the target input
1365
+ tgt_in[:, j] = p_i.squeeze().argmax(-1)
1366
+ # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
1367
+ if testing and (tgt_in == self.eos_id).any(dim=-1).all():
1368
+ break
1369
+
1370
+ logits = torch.cat(logits, dim=1)
1371
+ else:
1372
+ # No prior context, so input is just <bos>. We query all positions.
1373
+ tgt_in = torch.full((bs, 1),
1374
+ self.bos_id,
1375
+ dtype=torch.long,
1376
+ device=_device)
1377
+ tgt_out = self.decode(tgt_in,
1378
+ memory,
1379
+ tgt_query=pos_queries,
1380
+ pos_query=pos_queries)
1381
+ logits = self.head(tgt_out)
1382
+
1383
+ if self.refine_iters:
1384
+ # For iterative refinement, we always use a 'cloze' mask.
1385
+ # We can derive it from the AR forward mask by unmasking the token context to the right.
1386
+ query_mask[torch.triu(
1387
+ torch.ones(num_steps,
1388
+ num_steps,
1389
+ dtype=torch.bool,
1390
+ device=_device), 2)] = 0
1391
+ bos = torch.full((bs, 1),
1392
+ self.bos_id,
1393
+ dtype=torch.long,
1394
+ device=_device)
1395
+ for i in range(self.refine_iters):
1396
+ # Prior context is the previous output.
1397
+ tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
1398
+ tgt_len = tgt_in.shape[1]
1399
+ tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(
1400
+ -1) > 0 # mask tokens beyond the first EOS token.
1401
+ tgt_out = self.decode(
1402
+ tgt_in,
1403
+ memory,
1404
+ tgt_mask[:tgt_len, :tgt_len],
1405
+ tgt_padding_mask,
1406
+ tgt_query=pos_queries,
1407
+ tgt_query_mask=query_mask[:, :tgt_len],
1408
+ pos_query=pos_queries,
1409
+ )
1410
+ logits = self.head(tgt_out)
1411
+
1412
+ return F.softmax(logits, -1)
1413
+
1414
    def gen_tgt_perms(self, tgt, _device):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the shorter
        sequences because of the padding mask.

        Args:
            tgt: label ids; ``tgt.shape[1] - 2`` is taken as the longest
                character count (one slot each is reserved for BOS and EOS).
            _device: device on which the permutation tensors are created.

        Returns:
            A tensor of position orders, one row per permutation, each of
            length ``max_num_chars + 2``. Every row starts with 0 (BOS).
            Character positions are ``1..max_num_chars`` and EOS is
            ``max_num_chars + 1``. When more than one row is produced, row 1
            is overwritten with the fully reversed order
            ``[0, max_num_chars + 1, ..., 1]``. For 1-character labels the
            single row ``[0, 1, 2]`` is returned.
        """
        # We don't permute the position of BOS, we permute EOS separately
        max_num_chars = tgt.shape[1] - 2
        # Special handling for 1-character sequences
        if max_num_chars == 1:
            return torch.arange(3, device=_device).unsqueeze(0)
        perms = [torch.arange(max_num_chars, device=_device)
                 ] if self.perm_forward else []
        # Additional permutations if needed
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            # Only half the pool is needed; the other half comes from flipping.
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
        # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
        if max_num_chars < 5:
            # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
            # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = torch.as_tensor(list(
                permutations(range(max_num_chars), max_num_chars)),
                                        device=_device)[selector]
            # If the forward permutation is always selected, no need to add it to the pool for sampling
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            # NOTE(review): when perm_forward is False, `perms` is an empty
            # list here and torch.stack raises — confirm that combination is
            # unsupported.
            perms = torch.stack(perms)
            if len(perm_pool):
                # Sample without replacement so no permutation repeats.
                i = self.rng.choice(len(perm_pool),
                                    size=num_gen_perms - len(perms),
                                    replace=False)
                perms = torch.cat([perms, perm_pool[i]])
        else:
            # Long labels: random permutations (collisions are possible here).
            perms.extend([
                torch.randperm(max_num_chars, device=_device)
                for _ in range(num_gen_perms - len(perms))
            ])
            perms = torch.stack(perms)
        if self.perm_mirrored:
            # Add complementary pairs
            comp = perms.flip(-1)
            # Stack in such a way that the pairs are next to each other.
            perms = torch.stack([perms, comp
                                 ]).transpose(0, 1).reshape(-1, max_num_chars)
        # NOTE:
        # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
        # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
        # positions will always be much less than the number of permutations (unless a low perm_num is set).
        # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
        # distribute it across the chosen number of permutations.
        # Add position indices of BOS and EOS
        bos_idx = perms.new_zeros((len(perms), 1))
        eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
        perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
        # Special handling for the reverse direction. This does two things:
        # 1. Reverse context for the characters
        # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
                                                            device=_device)
        return perms
1482
+
1483
+ def generate_attn_masks(self, perm, _device):
1484
+ """Generate attention masks given a sequence permutation (includes pos.
1485
+ for bos and eos tokens)
1486
+
1487
+ :param perm: the permutation sequence. i = 0 is always the BOS
1488
+ :return: lookahead attention masks
1489
+ """
1490
+ sz = perm.shape[0]
1491
+ mask = torch.zeros((sz, sz), device=_device)
1492
+ for i in range(sz):
1493
+ query_idx = perm[i]
1494
+ masked_keys = perm[i + 1:]
1495
+ mask[query_idx, masked_keys] = float('-inf')
1496
+ content_mask = mask[:-1, :-1].clone()
1497
+ mask[torch.eye(sz, dtype=torch.bool,
1498
+ device=_device)] = float('-inf') # mask "self"
1499
+ query_mask = mask[1:, :-1]
1500
+ return content_mask, query_mask
1501
+
1502
+ def training_step(self, memory, pos_query, tgt_ids, clip_ids):
1503
+ bs = tgt_ids.shape[0]
1504
+ if self.is_pretrain:
1505
+ memory = self.clip_encoder.encode_text(clip_ids)
1506
+ n = memory.shape[1]
1507
+ B, N, D = self.background_features.shape
1508
+ random_B = np.random.choice(B, bs, replace=False)
1509
+ random_N = np.random.choice(N, n, replace=False)
1510
+ noise = self.background_features[random_B][:, random_N]
1511
+ noise = torch.from_numpy(noise).to(memory.get_device())
1512
+ memory = memory + noise * 1e-1
1513
+ else:
1514
+ token_query = self.token_query.expand(bs, -1, -1)
1515
+ memory = self.fmu(token_query, memory)
1516
+
1517
+ if pos_query is None:
1518
+ pos_query = self.pos_queries.expand(bs, -1, -1)
1519
+ # Prepare the target sequences (input and output)
1520
+ tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device())
1521
+ tgt_in = tgt_ids[:, :-1]
1522
+ tgt_out = tgt_ids[:, 1:]
1523
+
1524
+ # The [EOS] token is not depended upon by any other token in any permutation ordering
1525
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
1526
+
1527
+ loss = 0
1528
+ loss_numel = 0
1529
+ n = (tgt_out != self.pad_id).sum().item()
1530
+ for i, perm in enumerate(tgt_perms):
1531
+ tgt_mask, query_mask = self.generate_attn_masks(
1532
+ perm, memory.get_device())
1533
+ # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask)
1534
+ # print('tgt_mask:', tgt_mask)
1535
+ # print('query_mask:', query_mask)
1536
+ # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape)
1537
+ out = self.decode(
1538
+ tgt_in,
1539
+ memory,
1540
+ tgt_mask,
1541
+ tgt_padding_mask,
1542
+ tgt_query_mask=query_mask,
1543
+ pos_query=pos_query,
1544
+ )
1545
+ # print('out:', out)
1546
+ logits = self.head(out)
1547
+ # print('logits:', logits)
1548
+ if i == 0:
1549
+ final_out = logits
1550
+ loss += n * F.cross_entropy(logits.flatten(end_dim=1),
1551
+ tgt_out.flatten(),
1552
+ ignore_index=self.pad_id)
1553
+ loss_numel += n
1554
+ # After the second iteration (i.e. done with canonical and reverse orderings),
1555
+ # remove the [EOS] tokens for the succeeding perms
1556
+ if i == 1:
1557
+ tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
1558
+ tgt_out)
1559
+ n = (tgt_out != self.pad_id).sum().item()
1560
+ loss /= loss_numel
1561
+
1562
+ # self.log('loss', loss)
1563
+ return [loss, final_out]