openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -1,1398 +1,1563 @@
|
|
|
1
|
-
# Scene Text Recognition Model Hub
|
|
2
|
-
# Copyright 2022 Darwin Bautista
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# https://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import math
|
|
17
|
-
from itertools import permutations
|
|
18
|
-
from collections import OrderedDict
|
|
19
|
-
import hashlib
|
|
20
|
-
import os
|
|
21
|
-
import gzip
|
|
22
|
-
import html
|
|
23
|
-
import urllib
|
|
24
|
-
import warnings
|
|
25
|
-
import numpy as np
|
|
26
|
-
import torch
|
|
27
|
-
import torch.nn as nn
|
|
28
|
-
import torch.nn.functional as F
|
|
29
|
-
from torch import Tensor
|
|
30
|
-
from torch.nn.modules import transformer
|
|
31
|
-
from typing import Any, Optional, Tuple, List, Union
|
|
32
|
-
from pkg_resources import packaging
|
|
33
|
-
from PIL import Image
|
|
34
|
-
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
|
35
|
-
from tqdm import tqdm
|
|
36
|
-
from functools import lru_cache
|
|
37
|
-
|
|
38
|
-
import ftfy
|
|
39
|
-
import regex as re
|
|
40
|
-
|
|
41
|
-
try:
|
|
42
|
-
from torchvision.transforms import InterpolationMode
|
|
43
|
-
BICUBIC = InterpolationMode.BICUBIC
|
|
44
|
-
except ImportError:
|
|
45
|
-
BICUBIC = Image.BICUBIC
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@lru_cache()
|
|
49
|
-
def default_bpe():
|
|
50
|
-
return os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
text =
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
for
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
return
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
def
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
self.
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
self.
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
x
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
self.
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
self.
|
|
508
|
-
self.
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
x =
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
self.
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
)
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
x = self.
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
x = x.permute(1, 0, 2) # NLD -> LND
|
|
645
|
-
x = self.transformer(x)
|
|
646
|
-
x = x.permute(1, 0, 2) # LND -> NLD
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
self.
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
self.
|
|
815
|
-
self.
|
|
816
|
-
self.
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
self
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
return
|
|
1248
|
-
|
|
1249
|
-
def
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
logits = self.head(
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
|
|
1
|
+
# Scene Text Recognition Model Hub
|
|
2
|
+
# Copyright 2022 Darwin Bautista
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from itertools import permutations
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
import hashlib
|
|
20
|
+
import os
|
|
21
|
+
import gzip
|
|
22
|
+
import html
|
|
23
|
+
import urllib
|
|
24
|
+
import warnings
|
|
25
|
+
import numpy as np
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn as nn
|
|
28
|
+
import torch.nn.functional as F
|
|
29
|
+
from torch import Tensor
|
|
30
|
+
from torch.nn.modules import transformer
|
|
31
|
+
from typing import Any, Optional, Tuple, List, Union
|
|
32
|
+
from pkg_resources import packaging
|
|
33
|
+
from PIL import Image
|
|
34
|
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
|
35
|
+
from tqdm import tqdm
|
|
36
|
+
from functools import lru_cache
|
|
37
|
+
|
|
38
|
+
import ftfy
|
|
39
|
+
import regex as re
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
from torchvision.transforms import InterpolationMode
|
|
43
|
+
BICUBIC = InterpolationMode.BICUBIC
|
|
44
|
+
except ImportError:
|
|
45
|
+
BICUBIC = Image.BICUBIC
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@lru_cache()
|
|
49
|
+
def default_bpe():
|
|
50
|
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
|
51
|
+
'bpe_simple_vocab_16e6.txt.gz')
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@lru_cache()
|
|
55
|
+
def bytes_to_unicode():
|
|
56
|
+
"""
|
|
57
|
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
|
58
|
+
The reversible bpe codes work on unicode strings.
|
|
59
|
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
|
60
|
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
|
61
|
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
|
62
|
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
|
63
|
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
|
64
|
+
"""
|
|
65
|
+
bs = list(range(ord('!'),
|
|
66
|
+
ord('~') + 1)) + list(range(
|
|
67
|
+
ord('¡'),
|
|
68
|
+
ord('¬') + 1)) + list(range(ord('®'),
|
|
69
|
+
ord('ÿ') + 1))
|
|
70
|
+
cs = bs[:]
|
|
71
|
+
n = 0
|
|
72
|
+
for b in range(2**8):
|
|
73
|
+
if b not in bs:
|
|
74
|
+
bs.append(b)
|
|
75
|
+
cs.append(2**8 + n)
|
|
76
|
+
n += 1
|
|
77
|
+
cs = [chr(n) for n in cs]
|
|
78
|
+
return dict(zip(bs, cs))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_pairs(word):
|
|
82
|
+
"""Return set of symbol pairs in a word.
|
|
83
|
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
|
84
|
+
"""
|
|
85
|
+
pairs = set()
|
|
86
|
+
prev_char = word[0]
|
|
87
|
+
for char in word[1:]:
|
|
88
|
+
pairs.add((prev_char, char))
|
|
89
|
+
prev_char = char
|
|
90
|
+
return pairs
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def basic_clean(text):
|
|
94
|
+
text = ftfy.fix_text(text)
|
|
95
|
+
text = html.unescape(html.unescape(text))
|
|
96
|
+
return text.strip()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def whitespace_clean(text):
|
|
100
|
+
text = re.sub(r'\s+', ' ', text)
|
|
101
|
+
text = text.strip()
|
|
102
|
+
return text
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class SimpleTokenizer(object):
|
|
106
|
+
|
|
107
|
+
def __init__(self, bpe_path: str = default_bpe()):
|
|
108
|
+
self.byte_encoder = bytes_to_unicode()
|
|
109
|
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
|
110
|
+
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
|
111
|
+
merges = merges[1:49152 - 256 - 2 + 1]
|
|
112
|
+
merges = [tuple(merge.split()) for merge in merges]
|
|
113
|
+
vocab = list(bytes_to_unicode().values())
|
|
114
|
+
vocab = vocab + [v + '</w>' for v in vocab]
|
|
115
|
+
for merge in merges:
|
|
116
|
+
vocab.append(''.join(merge))
|
|
117
|
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
|
118
|
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
|
119
|
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
|
120
|
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
|
121
|
+
self.cache = {
|
|
122
|
+
'<|startoftext|>': '<|startoftext|>',
|
|
123
|
+
'<|endoftext|>': '<|endoftext|>'
|
|
124
|
+
}
|
|
125
|
+
self.pat = re.compile(
|
|
126
|
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
|
127
|
+
re.IGNORECASE)
|
|
128
|
+
|
|
129
|
+
def bpe(self, token):
|
|
130
|
+
if token in self.cache:
|
|
131
|
+
return self.cache[token]
|
|
132
|
+
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
|
133
|
+
pairs = get_pairs(word)
|
|
134
|
+
|
|
135
|
+
if not pairs:
|
|
136
|
+
return token + '</w>'
|
|
137
|
+
|
|
138
|
+
while True:
|
|
139
|
+
bigram = min(
|
|
140
|
+
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
|
141
|
+
if bigram not in self.bpe_ranks:
|
|
142
|
+
break
|
|
143
|
+
first, second = bigram
|
|
144
|
+
new_word = []
|
|
145
|
+
i = 0
|
|
146
|
+
while i < len(word):
|
|
147
|
+
try:
|
|
148
|
+
j = word.index(first, i)
|
|
149
|
+
new_word.extend(word[i:j])
|
|
150
|
+
i = j
|
|
151
|
+
except:
|
|
152
|
+
new_word.extend(word[i:])
|
|
153
|
+
break
|
|
154
|
+
|
|
155
|
+
if word[i] == first and i < len(word) - 1 and word[
|
|
156
|
+
i + 1] == second:
|
|
157
|
+
new_word.append(first + second)
|
|
158
|
+
i += 2
|
|
159
|
+
else:
|
|
160
|
+
new_word.append(word[i])
|
|
161
|
+
i += 1
|
|
162
|
+
new_word = tuple(new_word)
|
|
163
|
+
word = new_word
|
|
164
|
+
if len(word) == 1:
|
|
165
|
+
break
|
|
166
|
+
else:
|
|
167
|
+
pairs = get_pairs(word)
|
|
168
|
+
word = ' '.join(word)
|
|
169
|
+
self.cache[token] = word
|
|
170
|
+
return word
|
|
171
|
+
|
|
172
|
+
def encode(self, text):
|
|
173
|
+
bpe_tokens = []
|
|
174
|
+
text = whitespace_clean(basic_clean(text)).lower()
|
|
175
|
+
for token in re.findall(self.pat, text):
|
|
176
|
+
token = ''.join(self.byte_encoder[b]
|
|
177
|
+
for b in token.encode('utf-8'))
|
|
178
|
+
bpe_tokens.extend(self.encoder[bpe_token]
|
|
179
|
+
for bpe_token in self.bpe(token).split(' '))
|
|
180
|
+
return bpe_tokens
|
|
181
|
+
|
|
182
|
+
def decode(self, tokens):
|
|
183
|
+
text = ''.join([self.decoder[token] for token in tokens])
|
|
184
|
+
text = bytearray([self.byte_decoder[c] for c in text
|
|
185
|
+
]).decode('utf-8',
|
|
186
|
+
errors='replace').replace('</w>', ' ')
|
|
187
|
+
return text
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
if packaging.version.parse(
|
|
191
|
+
torch.__version__) < packaging.version.parse('1.7.1'):
|
|
192
|
+
warnings.warn('PyTorch version 1.7.1 or higher is recommended')
|
|
193
|
+
|
|
194
|
+
__all__ = ['available_models', 'load', 'tokenize']
|
|
195
|
+
_tokenizer = SimpleTokenizer()
|
|
196
|
+
|
|
197
|
+
_MODELS = {
|
|
198
|
+
'RN50':
|
|
199
|
+
'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt',
|
|
200
|
+
'RN101':
|
|
201
|
+
'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt',
|
|
202
|
+
'RN50x4':
|
|
203
|
+
'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt',
|
|
204
|
+
'RN50x16':
|
|
205
|
+
'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt',
|
|
206
|
+
'RN50x64':
|
|
207
|
+
'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt',
|
|
208
|
+
'ViT-B/32':
|
|
209
|
+
'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt',
|
|
210
|
+
'ViT-B/16':
|
|
211
|
+
'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt',
|
|
212
|
+
'ViT-L/14':
|
|
213
|
+
'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt',
|
|
214
|
+
'ViT-L/14@336px':
|
|
215
|
+
'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt',
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def convert_weights(model: nn.Module):
    """Cast selected parameters of ``model`` (and its submodules) to fp16 in place.

    Converted tensors:
      * ``weight`` / ``bias`` of every ``nn.Conv1d``, ``nn.Conv2d`` and
        ``nn.Linear``;
      * the input-projection weights/biases and ``bias_k`` / ``bias_v`` of
        every ``nn.MultiheadAttention``;
      * any non-None ``text_projection`` or ``proj`` attribute found on a
        module.
    Any other tensor keeps its dtype.
    """
    conv_and_linear_types = (nn.Conv1d, nn.Conv2d, nn.Linear)
    attention_tensor_names = ('in_proj_weight', 'q_proj_weight',
                              'k_proj_weight', 'v_proj_weight',
                              'in_proj_bias', 'bias_k', 'bias_v')
    projection_names = ('text_projection', 'proj')

    def _halve(tensor):
        tensor.data = tensor.data.half()

    def _convert_module(module):
        if isinstance(module, conv_and_linear_types):
            _halve(module.weight)
            if module.bias is not None:
                _halve(module.bias)

        if isinstance(module, nn.MultiheadAttention):
            # Unused projections (e.g. q/k/v_proj_weight when in_proj_weight
            # is used) are registered as None and are skipped.
            for tensor_name in attention_tensor_names:
                tensor = getattr(module, tensor_name)
                if tensor is not None:
                    _halve(tensor)

        for tensor_name in projection_names:
            if not hasattr(module, tensor_name):
                continue
            tensor = getattr(module, tensor_name)
            if tensor is not None:
                _halve(tensor)

    model.apply(_convert_module)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def build_model(state_dict: dict):
    """Build an fp16, eval-mode CLIP model from a checkpoint state dict.

    Every architecture hyper-parameter is inferred from key names and tensor
    shapes in ``state_dict``. The presence of ``'visual.proj'`` selects the
    VisionTransformer image backbone; otherwise the ModifiedResNet backbone
    is assumed.

    Args:
        state_dict: Mapping from parameter names to tensors. The entries
            ``'input_resolution'``, ``'context_length'`` and ``'vocab_size'``
            are removed from it in place when present.

    Returns:
        The CLIP model with ``state_dict`` loaded, switched to eval mode.
    """
    if 'visual.proj' in state_dict:
        patch_conv_weight = state_dict['visual.conv1.weight']
        vision_width = patch_conv_weight.shape[0]
        # One attention block per key of this form.
        vision_layers = sum(
            1 for key in state_dict
            if key.startswith('visual.') and key.endswith('.attn.in_proj_weight'))
        vision_patch_size = patch_conv_weight.shape[-1]
        # The positional table holds grid**2 patch tokens plus one class token.
        num_patch_tokens = state_dict['visual.positional_embedding'].shape[0] - 1
        grid_size = round(num_patch_tokens**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Number of distinct block indices in each of the four ResNet stages.
        vision_layers = tuple(
            len({
                key.split('.')[2]
                for key in state_dict if key.startswith(f'visual.layer{stage}')
            }) for stage in (1, 2, 3, 4))
        vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
        attnpool_tokens = state_dict[
            'visual.attnpool.positional_embedding'].shape[0]
        output_width = round((attnpool_tokens - 1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == attnpool_tokens
        # ModifiedResNet's attention pool works on an input_resolution // 32 grid.
        image_resolution = output_width * 32

    embed_dim = state_dict['text_projection'].shape[1]
    context_length = state_dict['positional_embedding'].shape[0]
    vocab_size = state_dict['token_embedding.weight'].shape[0]
    transformer_width = state_dict['ln_final.weight'].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len({
        key.split('.')[2]
        for key in state_dict if key.startswith('transformer.resblocks')
    })

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)

    # These entries are not parameters of CLIP; drop them so the strict
    # load below does not report them as unexpected keys.
    for extra_key in ('input_resolution', 'context_length', 'vocab_size'):
        state_dict.pop(extra_key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _download(url: str, root: str):
|
|
301
|
+
os.makedirs(root, exist_ok=True)
|
|
302
|
+
filename = os.path.basename(url)
|
|
303
|
+
|
|
304
|
+
expected_sha256 = url.split('/')[-2]
|
|
305
|
+
download_target = os.path.join(root, filename)
|
|
306
|
+
|
|
307
|
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
|
308
|
+
raise RuntimeError(
|
|
309
|
+
f'{download_target} exists and is not a regular file')
|
|
310
|
+
|
|
311
|
+
if os.path.isfile(download_target):
|
|
312
|
+
if hashlib.sha256(open(download_target,
|
|
313
|
+
'rb').read()).hexdigest() == expected_sha256:
|
|
314
|
+
return download_target
|
|
315
|
+
else:
|
|
316
|
+
warnings.warn(
|
|
317
|
+
f'{download_target} exists, but the SHA256 checksum does not match; re-downloading the file'
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
with urllib.request.urlopen(url) as source, open(download_target,
|
|
321
|
+
'wb') as output:
|
|
322
|
+
with tqdm(total=int(source.info().get('Content-Length')),
|
|
323
|
+
ncols=80,
|
|
324
|
+
unit='iB',
|
|
325
|
+
unit_scale=True,
|
|
326
|
+
unit_divisor=1024) as loop:
|
|
327
|
+
while True:
|
|
328
|
+
buffer = source.read(8192)
|
|
329
|
+
if not buffer:
|
|
330
|
+
break
|
|
331
|
+
|
|
332
|
+
output.write(buffer)
|
|
333
|
+
loop.update(len(buffer))
|
|
334
|
+
|
|
335
|
+
if hashlib.sha256(open(download_target,
|
|
336
|
+
'rb').read()).hexdigest() != expected_sha256:
|
|
337
|
+
raise RuntimeError(
|
|
338
|
+
'Model has been downloaded but the SHA256 checksum does not not match'
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
return download_target
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _convert_image_to_rgb(image):
|
|
345
|
+
return image.convert('RGB')
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _transform(n_px):
    """Build the CLIP image preprocessing pipeline for ``n_px``-sized inputs.

    The image is bicubically resized, center-cropped to ``n_px`` x ``n_px``,
    converted to RGB, turned into a tensor and normalized with CLIP's
    per-channel mean/std.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(pipeline)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def available_models() -> List[str]:
    """Return the names of the downloadable CLIP models, in ``_MODELS`` order."""
    return list(_MODELS)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with anti-aliased striding.

    Every convolution has stride 1. When ``stride > 1``, the spatial
    reduction is done by an average pool placed after the 3x3 convolution,
    and the shortcut branch is avg-pooled before its 1x1 projection.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        widened = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, widened, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # Shortcut projection. The '-1', '0', '1' names keep the
            # checkpoint keys (downsample.0.weight, downsample.1.*) stable.
            shortcut_layers = OrderedDict()
            shortcut_layers['-1'] = nn.AvgPool2d(stride)
            shortcut_layers['0'] = nn.Conv2d(inplanes,
                                             widened,
                                             1,
                                             stride=1,
                                             bias=False)
            shortcut_layers['1'] = nn.BatchNorm2d(widened)
            self.downsample = nn.Sequential(shortcut_layers)

    def forward(self, x: torch.Tensor):
        residual = x if self.downsample is None else self.downsample(x)

        hidden = self.relu1(self.bn1(self.conv1(x)))
        hidden = self.avgpool(self.relu2(self.bn2(self.conv2(hidden))))
        hidden = self.bn3(self.conv3(hidden))

        hidden += residual
        return self.relu3(hidden)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
class AttentionPool2d(nn.Module):
    """Pool an NCHW feature map to one vector per sample using QKV attention.

    The query is the spatial mean token. It attends over itself plus all
    H*W spatial tokens (all with learned positional embeddings), and the
    result is projected by ``c_proj`` to ``output_dim`` (default
    ``embed_dim``).
    """

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        embed_scale = embed_dim**0.5
        num_tokens = spacial_dim**2 + 1
        self.positional_embedding = nn.Parameter(
            torch.randn(num_tokens, embed_dim) / embed_scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(
            tokens.dtype)

        packed_bias = torch.cat(
            [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        # (1, N, C_out) -> (N, C_out)
        return pooled.squeeze(0)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self,
                 layers,
                 output_dim,
                 heads,
                 input_resolution=224,
                 width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        stem_width = width // 2
        # Three-layer stem: only the first convolution is strided.
        self.conv1 = nn.Conv2d(3,
                               stem_width,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(stem_width,
                               stem_width,
                               kernel_size=3,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(stem_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(stem_width,
                               width,
                               kernel_size=3,
                               padding=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # Residual stages. ``_inplanes`` is updated by _make_layer while the
        # stages are built, so their construction order matters.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # channel count produced by layer4
        self.attnpool = AttentionPool2d(input_resolution // 32, feature_dim,
                                        heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) block followed by ``blocks - 1`` more."""
        leading = Bottleneck(self._inplanes, planes, stride)
        self._inplanes = planes * Bottleneck.expansion
        trailing = [
            Bottleneck(self._inplanes, planes) for _ in range(1, blocks)
        ]
        return nn.Sequential(leading, *trailing)

    def forward(self, x):
        # Match the input dtype to the weights (e.g. fp16 after convert_weights).
        x = x.type(self.conv1.weight.dtype)

        stem = ((self.conv1, self.bn1, self.relu1),
                (self.conv2, self.bn2, self.relu2),
                (self.conv3, self.bn3, self.relu3))
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.avgpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return self.attnpool(x)
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(input_dtype)
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``.

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so inputs are expected as (seq, batch, d_model).
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_dim = d_model * 4
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([
                ('c_fc', nn.Linear(d_model, hidden_dim)),
                ('gelu', QuickGELU()),
                ('c_proj', nn.Linear(hidden_dim, d_model)),
            ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive mask. It is a plain attribute, not a buffer, so it
        # is moved to x's dtype/device lazily in attention().
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks sharing one optional ``attn_mask``."""

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        # Call the Sequential itself (not a manual loop) so hooks registered
        # on ``resblocks`` still fire.
        return self.resblocks(x)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
class VisionTransformer(nn.Module):
    """ViT image encoder that returns projected features for every token.

    Unlike a pooled encoder, ``forward`` returns shape
    [batch, grid**2 + 1, output_dim]. Index 0 along dim 1 is the class token.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding.
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        scale = width**-0.5
        num_tokens = (input_resolution // patch_size)**2 + 1
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(num_tokens, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        patches = self.conv1(x)  # [*, width, grid, grid]
        batch, channels = patches.shape[0], patches.shape[1]
        patches = patches.reshape(batch, channels,
                                  -1).permute(0, 2, 1)  # [*, grid ** 2, width]

        cls_token = self.class_embedding.to(patches.dtype) + torch.zeros(
            batch,
            1,
            patches.shape[-1],
            dtype=patches.dtype,
            device=patches.device)
        tokens = torch.cat([cls_token, patches],
                           dim=1)  # [*, grid ** 2 + 1, width]
        tokens = self.ln_pre(tokens +
                             self.positional_embedding.to(tokens.dtype))

        # The transformer expects sequence-first input: NLD -> LND -> NLD.
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        tokens = self.ln_post(tokens)
        if self.proj is not None:
            tokens = tokens @ self.proj

        return tokens
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
class CLIP(nn.Module):
    """CLIP image/text dual encoder.

    In this variant, ``encode_image`` (ViT backbone) and ``encode_text``
    return per-token sequences rather than a single pooled vector; see their
    docstrings. ``forward`` reduces such sequences to their first (pooled)
    token before computing similarity logits.
    """

    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # Per-stage block counts select the ModifiedResNet backbone.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(layers=vision_layers,
                                         output_dim=embed_dim,
                                         heads=vision_heads,
                                         input_resolution=image_resolution,
                                         width=vision_width)
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(input_resolution=image_resolution,
                                            patch_size=vision_patch_size,
                                            width=vision_width,
                                            layers=vision_layers,
                                            heads=vision_heads,
                                            output_dim=embed_dim)

        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=self.build_attention_mask())

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(
            torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Random initialization (overwritten when a checkpoint is loaded)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero the last BN scale of every bottleneck so each residual
            # branch starts out contributing nothing.
            for resnet_block in [
                    self.visual.layer1, self.visual.layer2, self.visual.layer3,
                    self.visual.layer4
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith('bn3.weight'):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection,
                            std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        """Return an additive causal mask of shape [context_length, context_length].

        Entries strictly above the diagonal are -inf and all others are 0, so
        position i can only attend to positions <= i.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Encode images with the visual backbone.

        Returns [batch, embed_dim] for the ModifiedResNet backbone, or
        [batch, n_tokens, embed_dim] (class token first) for the
        VisionTransformer backbone.
        """
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Encode token ids ``text`` of shape [batch_size, n_ctx].

        Returns [batch_size, 1 + n_ctx, D]. Index 0 along dim 1 is the feature
        at each row's largest token id (the EOT token), projected by
        ``text_projection``. The remaining entries are the ``ln_final``
        outputs for every position.

        NOTE(review): the concatenation requires ``embed_dim ==
        transformer_width``. Checkpoints where these differ would fail here;
        confirm only matching models are used.
        """
        x = self.token_embedding(text).type(
            self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        output = x[torch.arange(x.shape[0]),
                   text.argmax(dim=-1)] @ self.text_projection
        output = torch.cat([output.unsqueeze(1), x], dim=1)

        return output

    def forward(self, image, text):
        """Return scaled cosine-similarity logits between images and texts.

        Returns:
            ``(logits_per_image, logits_per_text)`` with shapes
            [n_images, n_texts] and [n_texts, n_images].
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # encode_image (ViT backbone) and encode_text return token sequences
        # whose first entry is the pooled embedding. Reduce them to that
        # entry: Tensor.t() below only accepts tensors with <= 2 dims.
        if image_features.dim() == 3:
            image_features = image_features[:, 0]
        if text_features.dim() == 3:
            text_features = text_features[:, 0]

        # normalized features
        image_features = image_features / image_features.norm(dim=-1,
                                                              keepdim=True)
        text_features = text_features / text_features.norm(dim=-1,
                                                           keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
class FMU(nn.Module):
    """Cross-attention plus feed-forward block (batch-first).

    ``query`` attends to ``memory`` and the attention output is added back
    to ``query``; no normalization is applied before this attention. A
    feed-forward branch (LayerNorm -> linear1 -> activation -> linear2) is
    then added as a second residual.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='gelu',
                 layer_norm_eps=1e-5):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d_model,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=True)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Applied only in front of the feed-forward branch.
        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        # Pickles without an ``activation`` entry fall back to GELU.
        if 'activation' not in state:
            state['activation'] = F.gelu
        super().__setstate__(state)

    def forward(self, query: Tensor, memory: Tensor):
        """Refine ``query`` with information gathered from ``memory``.

        Args:
            query: Tensor of shape (batch, n_query, d_model); the attention
                is built with ``batch_first=True``.
            memory: Tensor of shape (batch, n_memory, d_model).

        Returns:
            Tensor with the same shape as ``query``.
        """
        attended, _ = self.cross_attn(query, memory, memory)
        query = query + self.dropout1(attended)

        ffn_out = self.linear2(
            self.dropout2(self.activation(self.linear1(self.norm(query)))))
        query = query + self.dropout3(ffn_out)

        return query
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
class DecoderLayer(nn.Module):
    """Pre-LN transformer decoder layer with two streams (XLNet-style).

    A query stream and a content stream both self-attend over the normalized
    content, then cross-attend to ``memory``, then pass through a
    feed-forward block. Every sub-layer uses a pre-LayerNorm residual.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='gelu',
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model,
                                               nhead,
                                               dropout=dropout,
                                               batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model,
                                                nhead,
                                                dropout=dropout,
                                                batch_first=True)
        # Feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        # Pickles without an ``activation`` entry fall back to GELU.
        if 'activation' not in state:
            state['activation'] = F.gelu
        super().__setstate__(state)

    def forward_stream(
        self,
        tgt: Tensor,
        tgt_norm: Tensor,
        tgt_kv: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor],
        tgt_key_padding_mask: Optional[Tensor],
    ):
        """Run one stream (query or content) through the layer.

        ``tgt_norm`` is ``tgt`` already LayerNorm'd (passed in so it can be
        shared between streams). ``tgt_kv`` supplies the self-attention
        keys/values and is expected to be normalized too. ``memory`` is
        assumed to be normalized by the encoder.

        Returns:
            ``(tgt, self_attn_weights, cross_attn_weights)``. The
            cross-attention weights are also stored on ``self.attn_map``.
        """
        self_out, sa_weights = self.self_attn(
            tgt_norm,
            tgt_kv,
            tgt_kv,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(self_out)

        cross_out, ca_weights = self.cross_attn(self.norm1(tgt), memory,
                                                memory)
        self.attn_map = ca_weights
        tgt = tgt + self.dropout2(cross_out)

        hidden = self.activation(self.linear1(self.norm2(tgt)))
        tgt = tgt + self.dropout3(self.linear2(self.dropout(hidden)))
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
        update_content: bool = True,
    ):
        """Update the query stream, and the content stream if ``update_content``.

        Returns ``(query, content)``. ``content`` is returned unchanged when
        ``update_content`` is False.
        """
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query, _, _ = self.forward_stream(query, query_norm, content_norm,
                                          memory, query_mask,
                                          content_key_padding_mask)
        if update_content:
            content, _, _ = self.forward_stream(content, content_norm,
                                                content_norm, memory,
                                                content_mask,
                                                content_key_padding_mask)
        return query, content
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
class Decoder(nn.Module):
    """``num_layers`` clones of ``decoder_layer`` followed by ``norm`` on the query stream.

    Every layer except the last also updates the content stream. The last
    layer's content output would never be read, so it is not computed.
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = transformer._get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
    ):
        final_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            query, content = layer(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=index != final_index,
            )
        return self.norm(query)
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
class TokenEmbedding(nn.Module):
    """Embedding lookup scaled by ``sqrt(embed_dim)``."""

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(charset_size, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, tokens: torch.Tensor):
        vectors = self.embedding(tokens)
        return math.sqrt(self.embed_dim) * vectors
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
def load(name: str,
         device: Union[str, torch.device] = 'cuda'
         if torch.cuda.is_available() else 'cpu',
         jit: bool = False,
         download_root: Optional[str] = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    A ``(model, preprocess)`` tuple:

    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If ``name`` is neither a known model name nor an existing file.
    """
    # Resolve ``name`` to a local file: a known model is downloaded (and
    # checksum-verified) into the cache; otherwise it must be an existing path.
    if name in _MODELS:
        model_path = _download(
            _MODELS[name], download_root
            or os.path.expanduser('~/.cache/clip'))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f'Model {name} not found; available models = {available_models()}')

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            # Non-JIT requests load the archive on CPU; its weights are then
            # rebuilt into a plain CLIP module below.
            model = torch.jit.load(
                opened_file, map_location=device if jit else 'cpu').eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(
                    f'File {model_path} is not a JIT archive. Loading as a state dict instead'
                )
                jit = False
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(opened_file, map_location='cpu')

    if not jit:
        # NOTE(review): an empty state dict is falsy and would fall through to
        # ``model.state_dict()`` while ``model`` is unbound here.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == 'cpu':
            # build_model converts weights to fp16; use fp32 on CPU.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    # Trace a tiny function on the target device to obtain a Device constant
    # node, then copy it over every constant in the archive's graphs whose
    # value starts with 'cuda'.
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [
        n for n in device_holder.graph.findAllNodes('prim::Constant')
        if 'Device' in repr(n)
    ][-1]

    def patch_device(module):
        # Some scripted modules raise RuntimeError when .graph is accessed.
        try:
            graphs = [module.graph] if hasattr(module, 'graph') else []
        except RuntimeError:
            graphs = []

        if hasattr(module, 'forward1'):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes('prim::Constant'):
                if 'value' in node.attributeNames() and str(
                        node['value']).startswith('cuda'):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == 'cpu':
        # Trace ``.float()`` to capture the float32 dtype constant fed to
        # aten::to.
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),
                                       example_inputs=[])
        float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, 'graph') else []
            except RuntimeError:
                graphs = []

            if hasattr(module, 'forward1'):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes('aten::to'):
                    inputs = list(node.inputs())
                    for i in [
                            1, 2
                    ]:  # dtype can be the second or third argument to aten::to()
                        # 5 is presumably the ScalarType code for float16
                        # (Half); replace it with the float32 constant.
                        if inputs[i].node()['value'] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    # The JIT archive stores input_resolution as a tensor attribute.
    return model, _transform(model.input_resolution.item())
|
|
1101
|
+
|
|
1102
|
+
|
|
1103
|
+
def tokenize(
        texts: Union[str, List[str]],
        context_length: int = 77,
        truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Convert one or more strings into a padded matrix of CLIP BPE token ids.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize.

    context_length : int
        Number of columns in the output; every row is zero-padded to this width.

    truncate : bool
        If True, over-long encodings are cut to ``context_length`` and their last
        slot is overwritten with the end-of-text token; if False they raise.

    Returns
    -------
    A 2-D tensor of shape [len(texts), context_length]. Its dtype is int64 on
    torch < 1.8.0 (older index_select needs long indices) and int32 otherwise.

    Raises
    ------
    RuntimeError
        If an encoding is longer than ``context_length`` and ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    start_id = _tokenizer.encoder['<|startoftext|>']
    end_id = _tokenizer.encoder['<|endoftext|>']
    # Wrap every encoding with the start/end markers.
    encoded = [[start_id, *_tokenizer.encode(text), end_id] for text in texts]

    legacy_torch = packaging.version.parse(
        torch.__version__) < packaging.version.parse('1.8.0')
    token_dtype = torch.long if legacy_torch else torch.int
    result = torch.zeros(len(encoded), context_length, dtype=token_dtype)

    for row, tokens in enumerate(encoded):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(
                    f'Input {texts[row]} is too long for context length {context_length}'
                )
            tokens = tokens[:context_length]
            # Keep a terminating marker even after truncation.
            tokens[-1] = end_id
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
class DptrParseq(nn.Module):
|
|
1154
|
+
|
|
1155
|
+
def __init__(self,
             in_channels,
             out_channels,
             max_label_length=25,
             embed_dim=512,
             dec_num_heads=8,
             dec_mlp_ratio=4,
             dec_depth=6,
             perm_num=6,
             perm_forward=True,
             perm_mirrored=True,
             decode_ar=True,
             refine_iters=1,
             dropout=0.1,
             is_pretrain=True,
             ORP_path=None,
             background_image_folder='background_mages_folder/path',
             ORP_save_path='save/noise/to/ORP_path',
             **kwargs: Any) -> None:
    """Build the permuted-autoregressive decoder.

    Args:
        in_channels: Not read by this constructor.
        out_channels: Token vocabulary size. ``out_channels - 1`` is <pad>,
            ``out_channels - 2`` is <bos> and 0 is <eos>; the output head
            predicts ``out_channels - 2`` classes (neither <bos> nor <pad>).
        max_label_length: Maximum number of characters; position queries are
            allocated for ``max_label_length + 1`` steps (+1 for <eos>).
        embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, dropout: Decoder
            width, attention heads, MLP expansion ratio, depth and dropout.
        perm_num, perm_forward, perm_mirrored: Permutation settings read by
            ``gen_tgt_perms``.
        decode_ar, refine_iters: Inference settings read by ``forward_test``.
        is_pretrain: If True, a frozen CLIP ViT-B/16 is loaded and background
            features are loaded (``ORP_path``) or computed. If False, a learned
            ``token_query`` and an ``FMU`` module are created instead.
        ORP_path: File of precomputed background features, read with
            ``torch.load``. Only used when ``is_pretrain`` is True.
        background_image_folder: Image folder passed to ``get_noise`` when
            ``is_pretrain`` is True and ``ORP_path`` is None. The default is a
            placeholder and has to be overridden in that case.
        ORP_save_path: Where features computed from
            ``background_image_folder`` are written with ``torch.save``. The
            default is a placeholder.
        **kwargs: Not read by this constructor.
    """
    super().__init__()
    self.pad_id = out_channels - 1
    self.eos_id = 0
    self.bos_id = out_channels - 2
    self.max_label_length = max_label_length
    self.decode_ar = decode_ar
    self.refine_iters = refine_iters
    self.is_pretrain = is_pretrain
    if not is_pretrain:
        # 26 == max_label_length + 1 for the default max_label_length=25.
        # The size is left fixed so existing checkpoints keep their shape.
        self.token_query = nn.Parameter(torch.empty(1, 26, embed_dim))
        self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio,
                       dropout)

    decoder_layer = DecoderLayer(embed_dim, dec_num_heads,
                                 embed_dim * dec_mlp_ratio, dropout)
    self.decoder = Decoder(decoder_layer,
                           num_layers=dec_depth,
                           norm=nn.LayerNorm(embed_dim))

    # Perm/attn mask stuff
    self.rng = np.random.default_rng()
    self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
    self.perm_forward = perm_forward
    self.perm_mirrored = perm_mirrored

    # We don't predict <bos> nor <pad>
    self.head = nn.Linear(embed_dim, out_channels - 2)
    self.text_embed = TokenEmbedding(out_channels, embed_dim)

    # +1 for <eos>
    self.pos_queries = nn.Parameter(
        torch.Tensor(1, max_label_length + 1, embed_dim))
    self.dropout = nn.Dropout(p=dropout)
    # Encoder has its own init. (CLIP is loaded below, after this call, so
    # its pretrained weights are not re-initialised.)
    self.apply(self._init_weights)
    nn.init.trunc_normal_(self.pos_queries, std=0.02)
    if not is_pretrain:
        # self.apply() only visits sub-modules, never bare Parameters, so
        # token_query would otherwise keep uninitialised memory.
        nn.init.trunc_normal_(self.token_query, std=0.02)

    if is_pretrain:
        self.clip_encoder, preprocess = load('ViT-B/16')
        for p in self.clip_encoder.parameters():
            p.requires_grad = False
        if ORP_path is None:
            self.background_features = self.get_noise(
                background_image_folder, preprocess)
            torch.save(self.background_features, ORP_save_path)
        else:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # ORP files.
            self.background_features = torch.load(ORP_path,
                                                  map_location='cpu')
|
|
1221
|
+
|
|
1222
|
+
def _init_weights(self, module: nn.Module):
|
|
1223
|
+
"""Initialize the weights using the typical initialization schemes used
|
|
1224
|
+
in SOTA models."""
|
|
1225
|
+
|
|
1226
|
+
if isinstance(module, nn.Linear):
|
|
1227
|
+
nn.init.trunc_normal_(module.weight, std=0.02)
|
|
1228
|
+
if module.bias is not None:
|
|
1229
|
+
nn.init.zeros_(module.bias)
|
|
1230
|
+
elif isinstance(module, nn.Embedding):
|
|
1231
|
+
nn.init.trunc_normal_(module.weight, std=0.02)
|
|
1232
|
+
if module.padding_idx is not None:
|
|
1233
|
+
module.weight.data[module.padding_idx].zero_()
|
|
1234
|
+
elif isinstance(module, nn.Conv2d):
|
|
1235
|
+
nn.init.kaiming_normal_(module.weight,
|
|
1236
|
+
mode='fan_out',
|
|
1237
|
+
nonlinearity='relu')
|
|
1238
|
+
if module.bias is not None:
|
|
1239
|
+
nn.init.zeros_(module.bias)
|
|
1240
|
+
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
|
|
1241
|
+
nn.init.ones_(module.weight)
|
|
1242
|
+
nn.init.zeros_(module.bias)
|
|
1243
|
+
|
|
1244
|
+
@torch.jit.ignore
def no_weight_decay(self):
    """Return the parameter names that should be excluded from weight decay."""
    return {'pos_queries', 'text_embed.embedding.weight'}
|
|
1248
|
+
|
|
1249
|
+
def get_noise(self, background_image_path, preprocess):
    """Encode every background image in a folder with the CLIP image encoder.

    Args:
        background_image_path: Folder to scan; files ending in .png, .jpg or
            .jpeg (case-insensitive) are used, in sorted order.
        preprocess: Transform applied to each PIL image before encoding; its
            output is batched with ``unsqueeze(0)``.

    Returns:
        numpy array holding ``torch.cat`` of the per-image
        ``clip_encoder.encode_image`` outputs along dim 0.
        NOTE(review): ``training_step`` unpacks this as 3-D ``(B, N, D)``;
        confirm that ``encode_image`` returns per-token features here.

    Raises:
        ValueError: If the folder contains no matching image files.
    """
    image_paths = sorted(
        os.path.join(background_image_path, filename)
        for filename in os.listdir(background_image_path)
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')))
    if not image_paths:
        raise ValueError(
            f'No .png/.jpg/.jpeg images found in {background_image_path}')

    # Inputs must live on the same device as the (frozen) CLIP weights.
    device = next(self.clip_encoder.parameters()).device
    features = []
    with torch.no_grad():
        for image_path in image_paths:
            # The context manager closes the file even if preprocessing fails.
            with Image.open(image_path) as image:
                pixels = preprocess(image).unsqueeze(0).to(device)
            features.append(self.clip_encoder.encode_image(pixels))
    return torch.cat(features).cpu().numpy()
|
|
1264
|
+
|
|
1265
|
+
def clip_encode(self, labels):
    """Encode labels as CLIP text features using the prompt "a photo of a '<label>'".

    Args:
        labels: Iterable of label strings; each one is tokenized separately.

    Returns:
        The output of ``self.clip_encoder.encode_text`` on the concatenated
        token ids.
    """
    # The module never defines `_device`; put the tokens on the device that
    # holds the CLIP weights instead.
    device = next(self.clip_encoder.parameters()).device
    text_inputs = torch.cat(
        [tokenize(f"a photo of a '{c}'") for c in labels]).to(device)

    return self.clip_encoder.encode_text(text_inputs)
|
|
1270
|
+
|
|
1271
|
+
def decode(
    self,
    tgt: torch.Tensor,
    memory: torch.Tensor,
    tgt_mask: Optional[Tensor] = None,
    tgt_padding_mask: Optional[Tensor] = None,
    tgt_query: Optional[Tensor] = None,
    tgt_query_mask: Optional[Tensor] = None,
    pos_query: torch.Tensor = None,
):
    """Run the decoder stack on a token context.

    Args:
        tgt: 2-D tensor of token ids whose first column is <bos>.
        memory: Decoder memory, passed through unchanged.
        tgt_mask, tgt_padding_mask: Content-stream masks, passed through.
        tgt_query: Query stream. Defaults to the first ``tgt.shape[1]``
            position queries.
        tgt_query_mask: Query-stream mask, passed through.
        pos_query: Position queries indexed along dim 1.

    Returns:
        ``self.decoder(query, context, memory, tgt_query_mask, tgt_mask,
        tgt_padding_mask)``.
    """
    _, seq_len = tgt.shape
    # <bos> is the null context: it gets no positional information.
    bos_embedding = self.text_embed(tgt[:, :1])

    if tgt_query is None:
        tgt_query = pos_query[:, :seq_len]
    # Tokens after <bos> are shifted-by-one position queries plus embeddings.
    token_embeddings = pos_query[:, :seq_len - 1] + self.text_embed(tgt[:, 1:])
    context = self.dropout(torch.cat([bos_embedding, token_embeddings], dim=1))

    query = self.dropout(tgt_query)
    return self.decoder(query, context, memory, tgt_query_mask, tgt_mask,
                        tgt_padding_mask)
|
|
1293
|
+
|
|
1294
|
+
def forward(self, memory, data=None, pos_query=None):
    """Route to ``training_step`` (training mode) or ``forward_test`` (eval).

    In pre-training mode the ``memory`` argument is forwarded in the
    ``clip_ids`` slot and the feature slot receives None; otherwise
    ``memory`` fills the feature slot and ``clip_ids`` is None. During
    training, ``data[0]`` supplies the target ids.
    """
    if self.is_pretrain:
        features, clip_ids = None, memory
    else:
        features, clip_ids = memory, None

    if self.training:
        return self.training_step(features, pos_query, data[0], clip_ids)
    return self.forward_test(features, clip_ids, pos_query)
|
|
1304
|
+
|
|
1305
|
+
def forward_test(self,
                 memory: Tensor,
                 clip_ids,
                 pos_query: Tensor = None,
                 max_length: Optional[int] = None) -> Tensor:
    """Decode character probabilities at inference time.

    Args:
        memory: Visual features, mapped through ``self.fmu``. Ignored in
            pre-training mode, where memory is
            ``clip_encoder.encode_text(clip_ids)``.
        clip_ids: Token ids for the CLIP text encoder (pre-training only).
        pos_query: Optional override of the position queries; defaults to
            the first ``max_length + 1`` learned ``pos_queries``.
        max_length: Optional cap on decoded characters (clamped to
            ``self.max_label_length``). Early stopping on <eos> is only
            enabled when it is None.

    Returns:
        ``F.softmax`` over the last dim of the ``self.head`` logits.
        Side effect: ``self.attn_maps`` is reset and, in AR mode, filled with
        the last decoder layer's ``attn_map`` at every step.
    """
    testing = max_length is None
    max_length = (self.max_label_length if max_length is None else min(
        max_length, self.max_label_length))

    if self.is_pretrain:
        memory = self.clip_encoder.encode_text(clip_ids)
    else:
        bs = memory.shape[0]
        token_query = self.token_query.expand(bs, -1, -1)
        memory = self.fmu(token_query, memory)
    # `.device` instead of `.get_device()`: the latter returns -1 for CPU
    # tensors, which torch rejects as a `device=` argument.
    device = memory.device
    bs = memory.shape[0]
    # +1 for <eos> at end of sequence.
    num_steps = max_length + 1

    # Query positions up to `num_steps`
    if pos_query is None:
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
    else:
        pos_queries = pos_query

    # Special case for the forward permutation. Faster than using `generate_attn_masks()`
    # NOTE(review): tgt_mask and query_mask are the same tensor, so the
    # in-place cloze edit in the refinement stage below also changes tgt_mask.
    tgt_mask = query_mask = torch.triu(
        torch.full((num_steps, num_steps), float('-inf'), device=device), 1)
    self.attn_maps = []
    if self.decode_ar:
        tgt_in = torch.full((bs, num_steps),
                            self.pad_id,
                            dtype=torch.long,
                            device=device)
        tgt_in[:, 0] = self.bos_id

        logits = []
        for i in range(num_steps):
            j = i + 1  # next token index
            # Efficient decoding: feed the context up to the ith token with a
            # single query at position i. The lookahead (forward AR) mask
            # means earlier outputs never change once computed.
            tgt_out = self.decode(
                tgt_in[:, :j],
                memory,
                tgt_mask[:j, :j],
                tgt_query=pos_queries[:, i:j],
                tgt_query_mask=query_mask[i:j, :j],
                pos_query=pos_queries,
            )
            self.attn_maps.append(self.decoder.layers[-1].attn_map)
            # the next token probability is in the output's ith token position
            p_i = self.head(tgt_out)
            logits.append(p_i)
            if j < num_steps:
                # Greedy decode. p_i holds one query step, so index it
                # explicitly instead of relying on squeeze().
                tgt_in[:, j] = p_i[:, -1].argmax(-1)
                # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                if testing and (tgt_in == self.eos_id).any(dim=-1).all():
                    break

        logits = torch.cat(logits, dim=1)
    else:
        # No prior context, so input is just <bos>. We query all positions.
        tgt_in = torch.full((bs, 1),
                            self.bos_id,
                            dtype=torch.long,
                            device=device)
        tgt_out = self.decode(tgt_in,
                              memory,
                              tgt_query=pos_queries,
                              pos_query=pos_queries)
        logits = self.head(tgt_out)

    if self.refine_iters:
        # For iterative refinement, we always use a 'cloze' mask.
        # We can derive it from the AR forward mask by unmasking the token context to the right.
        query_mask[torch.triu(
            torch.ones(num_steps, num_steps, dtype=torch.bool, device=device),
            2)] = 0
        bos = torch.full((bs, 1),
                         self.bos_id,
                         dtype=torch.long,
                         device=device)
        for i in range(self.refine_iters):
            # Prior context is the previous output.
            tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
            tgt_len = tgt_in.shape[1]
            # mask tokens beyond the first EOS token.
            tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum(-1) > 0
            tgt_out = self.decode(
                tgt_in,
                memory,
                tgt_mask[:tgt_len, :tgt_len],
                tgt_padding_mask,
                tgt_query=pos_queries,
                tgt_query_mask=query_mask[:, :tgt_len],
                pos_query=pos_queries,
            )
            logits = self.head(tgt_out)

    return F.softmax(logits, -1)
|
|
1413
|
+
|
|
1414
|
+
def gen_tgt_perms(self, tgt, _device):
    """Generate shared permutations for the whole batch.

    This works because the same attention mask can be used for the shorter
    sequences because of the padding mask.

    Args:
        tgt: Target ids; only ``tgt.shape[1]`` is read, and
            ``tgt.shape[1] - 2`` is taken as the maximum number of characters
            (one slot each for <bos> and <eos>).
        _device: Device on which the permutation tensors are created.

    Returns:
        Long tensor of shape ``(num_perms, max_num_chars + 2)``. Column 0 is
        the <bos> position (0) and the last column the <eos> position. When
        there is more than one permutation, row 1 is the reversed order.
    """
    # We don't permute the position of BOS, we permute EOS separately
    max_num_chars = tgt.shape[1] - 2
    # Special handling for 1-character sequences
    if max_num_chars == 1:
        return torch.arange(3, device=_device).unsqueeze(0)
    perms = [torch.arange(max_num_chars, device=_device)
             ] if self.perm_forward else []
    # Additional permutations if needed
    max_perms = math.factorial(max_num_chars)
    if self.perm_mirrored:
        max_perms //= 2
    num_gen_perms = min(self.max_gen_perms, max_perms)
    # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions
    # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars.
    if max_num_chars < 5:
        # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
        # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
        if max_num_chars == 4 and self.perm_mirrored:
            selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
        else:
            selector = list(range(max_perms))
        perm_pool = torch.as_tensor(list(
            permutations(range(max_num_chars), max_num_chars)),
                                    device=_device)[selector]
        # If the forward permutation is always selected, no need to add it to the pool for sampling
        if self.perm_forward:
            perm_pool = perm_pool[1:]
        if perms:
            perms = torch.stack(perms)
        else:
            # perm_forward=False leaves nothing to stack (torch.stack([])
            # raises); start from an empty (0, max_num_chars) long tensor.
            perms = torch.empty((0, max_num_chars),
                                dtype=torch.long,
                                device=_device)
        if len(perm_pool):
            i = self.rng.choice(len(perm_pool),
                                size=num_gen_perms - len(perms),
                                replace=False)
            perms = torch.cat([perms, perm_pool[i]])
    else:
        perms.extend([
            torch.randperm(max_num_chars, device=_device)
            for _ in range(num_gen_perms - len(perms))
        ])
        perms = torch.stack(perms)
    if self.perm_mirrored:
        # Add complementary pairs
        comp = perms.flip(-1)
        # Stack in such a way that the pairs are next to each other.
        perms = torch.stack([perms, comp
                             ]).transpose(0, 1).reshape(-1, max_num_chars)
    # NOTE:
    # The only meaningful way of permuting the EOS position is by moving it one character position at a time.
    # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS
    # positions will always be much less than the number of permutations (unless a low perm_num is set).
    # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly
    # distribute it across the chosen number of permutations.
    # Add position indices of BOS and EOS
    bos_idx = perms.new_zeros((len(perms), 1))
    eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1)
    perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1)
    # Special handling for the reverse direction. This does two things:
    # 1. Reverse context for the characters
    # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode)
    if len(perms) > 1:
        perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1,
                                                        device=_device)
    return perms
|
|
1482
|
+
|
|
1483
|
+
def generate_attn_masks(self, perm, _device):
    """Build the additive attention masks for one decoding order.

    ``perm`` lists positions in decoding order, with ``perm[0]`` always the
    <bos> position. No position may attend to a position that appears later
    in ``perm`` (0 = allowed, ``-inf`` = blocked).

    :param perm: the permutation sequence (1-D), including <bos> and <eos>
    :return: ``(content_mask, query_mask)``. ``content_mask`` is the mask
        without its last row/column. ``query_mask`` additionally blocks each
        position from itself and drops the first row and the last column.
    """
    size = perm.shape[0]
    full_mask = torch.zeros((size, size), device=_device)
    for order, position in enumerate(perm):
        # Hide every position decoded after `position` from it.
        full_mask[position, perm[order + 1:]] = float('-inf')
    content_mask = full_mask[:-1, :-1].clone()
    full_mask.fill_diagonal_(float('-inf'))  # mask "self" for the query stream
    query_mask = full_mask[1:, :-1]
    return content_mask, query_mask
|
|
1501
|
+
|
|
1502
|
+
def training_step(self, memory, pos_query, tgt_ids, clip_ids):
    """Compute the permutation-language-modelling loss for one batch.

    Args:
        memory: Visual features fed through ``self.fmu`` (ignored in
            pre-training mode).
        pos_query: Optional position queries; defaults to ``self.pos_queries``
            expanded to the batch.
        tgt_ids: 2-D target ids (<bos> first). Inputs are ``tgt_ids[:, :-1]``
            and labels ``tgt_ids[:, 1:]``.
        clip_ids: Token ids for the CLIP text encoder (pre-training only). In
            that mode small random background features are added to the text
            memory.

    Returns:
        ``[loss, logits]``. ``loss`` is the cross-entropy averaged over
        non-pad targets across all permutations; ``logits`` are the head
        outputs for the first permutation.
    """
    bs = tgt_ids.shape[0]
    if self.is_pretrain:
        memory = self.clip_encoder.encode_text(clip_ids)
        n = memory.shape[1]
        # NOTE(review): assumes background_features is 3-D (B, N, D).
        B, N, D = self.background_features.shape
        random_B = np.random.choice(B, bs, replace=False)
        random_N = np.random.choice(N, n, replace=False)
        noise = self.background_features[random_B][:, random_N]
        # as_tensor accepts both a numpy array (get_noise) and a tensor
        # (e.g. whatever torch.load returned for ORP_path).
        noise = torch.as_tensor(noise).to(memory.device)
        memory = memory + noise * 1e-1
    else:
        token_query = self.token_query.expand(bs, -1, -1)
        memory = self.fmu(token_query, memory)
    # `.device` instead of `.get_device()`, which returns -1 on CPU.
    device = memory.device

    if pos_query is None:
        pos_query = self.pos_queries.expand(bs, -1, -1)
    # Prepare the target sequences (input and output)
    tgt_perms = self.gen_tgt_perms(tgt_ids, device)
    tgt_in = tgt_ids[:, :-1]
    tgt_out = tgt_ids[:, 1:]

    # The [EOS] token is not depended upon by any other token in any permutation ordering
    tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)

    loss = 0
    loss_numel = 0
    n = (tgt_out != self.pad_id).sum().item()
    for i, perm in enumerate(tgt_perms):
        tgt_mask, query_mask = self.generate_attn_masks(perm, device)
        out = self.decode(
            tgt_in,
            memory,
            tgt_mask,
            tgt_padding_mask,
            tgt_query_mask=query_mask,
            pos_query=pos_query,
        )
        logits = self.head(out)
        if i == 0:
            final_out = logits
        # With every target ignored, cross_entropy returns NaN and 0 * NaN
        # would poison the total loss, so skip target-free permutations.
        if n > 0:
            loss += n * F.cross_entropy(logits.flatten(end_dim=1),
                                        tgt_out.flatten(),
                                        ignore_index=self.pad_id)
            loss_numel += n
        # After the second iteration (i.e. done with canonical and reverse orderings),
        # remove the [EOS] tokens for the succeeding perms
        if i == 1:
            tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id,
                                  tgt_out)
            n = (tgt_out != self.pad_id).sum().item()
    loss /= loss_numel

    return [loss, final_out]
|