openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -71,6 +71,7 @@ class SSMatchLayer(nn.Module):
|
|
|
71
71
|
drop_path=0.0,
|
|
72
72
|
act_layer=nn.GELU,
|
|
73
73
|
epsilon=1e-6,
|
|
74
|
+
is_last_layer=False,
|
|
74
75
|
):
|
|
75
76
|
super().__init__()
|
|
76
77
|
self.dim = dim
|
|
@@ -96,17 +97,18 @@ class SSMatchLayer(nn.Module):
|
|
|
96
97
|
proj_drop=drop)
|
|
97
98
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
|
98
99
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
|
100
|
+
self.is_last_layer = is_last_layer
|
|
99
101
|
|
|
100
102
|
def forward(self, question_f, prompt_f, visual_f, mask=None):
    """Run one matching layer over the question features.

    Two steps are applied in order:
      1. a residual, drop-path-wrapped cross-attention from the (normalized)
         question features to the (normalized) prompt features, using ``mask``;
      2. a cross-attention from the regrouped question features to the
         (normalized) visual features.

    Args:
        question_f: question features, updated residually in step 1.
        prompt_f: prompt features, used as keys/values in step 1.
        visual_f: visual features; ``visual_f.shape[0]`` is used as the
            leading dimension when regrouping ``question_f`` for step 2.
        mask: optional mask forwarded to the step-1 cross-attention.

    Returns:
        If ``self.is_last_layer`` is true, the step-2 output unchanged.
        Otherwise, that output with its first two dims merged and a singleton
        dim inserted at position 1.
    """
    attended = self.images_to_question_cross_attn(
        self.normq1(question_f), self.normkv1(prompt_f), mask)
    question_f = question_f + self.drop_path(attended)

    # Regroup so the first dimension matches the visual features.
    grouped = question_f.reshape(visual_f.shape[0], -1, self.dim)
    fused = self.question_to_images_cross_attn(self.normq2(grouped),
                                               self.normkv2(visual_f))
    # Only the last layer keeps the grouped layout. Earlier layers flatten it
    # back (presumably to the per-query layout the next layer expects;
    # NOTE(review): confirm against the callers in SMTRDecoder).
    return fused if self.is_last_layer else fused.flatten(0, 1).unsqueeze(1)
|
|
110
112
|
|
|
111
113
|
|
|
112
114
|
class SMTRDecoder(nn.Module):
|
|
@@ -152,7 +154,9 @@ class SMTRDecoder(nn.Module):
|
|
|
152
154
|
dynq2img_heads=dynq2img_heads,
|
|
153
155
|
mlp_ratio=4.0,
|
|
154
156
|
qkv_bias=True,
|
|
155
|
-
drop_path=dpr[i]
|
|
157
|
+
drop_path=dpr[i],
|
|
158
|
+
is_last_layer=i == num_layer - 1)
|
|
159
|
+
for i in range(num_layer)
|
|
156
160
|
])
|
|
157
161
|
|
|
158
162
|
self.ds = ds
|
|
@@ -351,7 +355,6 @@ class SMTRDecoder(nn.Module):
|
|
|
351
355
|
else:
|
|
352
356
|
return torch.concat(next_logits_all + pre_logits_all[::-1], 1)
|
|
353
357
|
|
|
354
|
-
|
|
355
358
|
def forward_test_bi_attn(self, x):
|
|
356
359
|
self.attn_maps = []
|
|
357
360
|
if not self.ds:
|
|
@@ -366,81 +369,122 @@ class SMTRDecoder(nn.Module):
|
|
|
366
369
|
next = self.next_token
|
|
367
370
|
pre = self.pre_token
|
|
368
371
|
next_pre = torch.concat([next, pre], 0)
|
|
369
|
-
next_pre = next_pre.squeeze(1)
|
|
372
|
+
next_pre = next_pre.squeeze(1) #2, 1, dim
|
|
370
373
|
|
|
371
374
|
prompt_next_embed = self.prompt_next_embed.squeeze(1)
|
|
372
375
|
prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
|
|
373
376
|
|
|
374
|
-
next_id = torch.full([1, self.sub_str_len],
|
|
375
|
-
|
|
377
|
+
next_id = torch.full([1, self.sub_str_len],
|
|
378
|
+
self.bos_next,
|
|
379
|
+
dtype=torch.long,
|
|
380
|
+
device=x.device)
|
|
381
|
+
pre_id = torch.full([1, self.sub_str_len],
|
|
382
|
+
self.bos_pre,
|
|
383
|
+
dtype=torch.long,
|
|
384
|
+
device=x.device)
|
|
376
385
|
# prompt_next_bos = self.char_embed(prompt_id)
|
|
377
386
|
# pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device)
|
|
378
|
-
next_pred_id_list = torch.full([1, self.max_len],
|
|
379
|
-
|
|
387
|
+
next_pred_id_list = torch.full([1, self.max_len],
|
|
388
|
+
self.ignore_index,
|
|
389
|
+
dtype=torch.long,
|
|
390
|
+
device=x.device)
|
|
391
|
+
pre_pred_id_list = torch.full([1, self.max_len],
|
|
392
|
+
self.ignore_index,
|
|
393
|
+
dtype=torch.long,
|
|
394
|
+
device=x.device)
|
|
380
395
|
next_logits_all = []
|
|
381
396
|
pre_logits_all = []
|
|
382
397
|
attn_map_next = []
|
|
383
398
|
attn_map_pre = []
|
|
384
|
-
mask_pad = torch.zeros([bs, 1],
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
399
|
+
mask_pad = torch.zeros([bs, 1],
|
|
400
|
+
dtype=torch.float32,
|
|
401
|
+
device=x.device)
|
|
402
|
+
for j in range(0, min(70, self.max_len - 1)):
|
|
403
|
+
|
|
404
|
+
prompt_char_next = torch.concat([
|
|
405
|
+
prompt_next_embed[:, :1, :],
|
|
406
|
+
prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
|
|
407
|
+
], 1) # b, sub_l, dim
|
|
408
|
+
prompt_char_pre = torch.concat([
|
|
409
|
+
prompt_pre_embed[:, :1, :],
|
|
410
|
+
prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
|
|
411
|
+
], 1) # b, sub_l, dim
|
|
412
|
+
prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
|
|
413
|
+
0) #2, 6, dim
|
|
390
414
|
# prompt_char = prompt_char.flatten(0, 1)
|
|
391
415
|
|
|
392
|
-
mask_next = torch.where(next_id == self.bos_next,
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
416
|
+
mask_next = torch.where(next_id == self.bos_next,
|
|
417
|
+
float('-inf'), 0) # b, subs_l
|
|
418
|
+
mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
|
|
419
|
+
0) # b, subs_l
|
|
420
|
+
mask = torch.concat([mask_next, mask_pre], 0) #2, 5
|
|
421
|
+
mask = torch.concat([mask_pad, mask], 1) # 2, 6
|
|
396
422
|
pred_token = next_pre
|
|
397
|
-
visual_f_i = visual_f[:2]
|
|
423
|
+
visual_f_i = visual_f[:2] # 2 l dim
|
|
398
424
|
for layer in self.cmff_decoder:
|
|
399
|
-
pred_token = layer(pred_token, prompt_char, visual_f_i,
|
|
400
|
-
|
|
401
|
-
|
|
425
|
+
pred_token = layer(pred_token, prompt_char, visual_f_i,
|
|
426
|
+
mask.unsqueeze(1))
|
|
427
|
+
|
|
402
428
|
logits_next_i = self.ques1_head(self.norm_pred(pred_token))
|
|
403
429
|
logits = F.softmax(logits_next_i, -1)
|
|
404
|
-
pred_id_i = logits.argmax(-1)
|
|
430
|
+
pred_id_i = logits.argmax(-1) #2, 1
|
|
405
431
|
# print(pred_id_i.shape)
|
|
406
|
-
|
|
407
|
-
next_pred_id_list[:, j:j+1] = pred_id_i[:1]
|
|
408
|
-
pre_pred_id_list[:, j:j+1] = pred_id_i[1:2]
|
|
432
|
+
|
|
433
|
+
next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
|
|
434
|
+
pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
|
|
409
435
|
if not (next_pred_id_list == self.eos).any(dim=-1).all():
|
|
410
436
|
next_logits_all.append(logits[:1])
|
|
411
|
-
attn_map_next.append(
|
|
437
|
+
attn_map_next.append(
|
|
438
|
+
self.cmff_decoder[-1].question_to_images_cross_attn.
|
|
439
|
+
attn_map[0])
|
|
412
440
|
next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
|
|
413
441
|
if not (pre_pred_id_list == self.eos).any(dim=-1).all():
|
|
414
442
|
pre_logits_all.append(logits[1:2])
|
|
415
|
-
attn_map_pre.append(
|
|
443
|
+
attn_map_pre.append(
|
|
444
|
+
self.cmff_decoder[-1].question_to_images_cross_attn.
|
|
445
|
+
attn_map[1])
|
|
416
446
|
pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
|
|
417
|
-
|
|
418
|
-
if (next_pred_id_list == self.eos).any(dim=-1).all() and (
|
|
447
|
+
|
|
448
|
+
if (next_pred_id_list == self.eos).any(dim=-1).all() and (
|
|
449
|
+
pre_pred_id_list == self.eos).any(dim=-1).all():
|
|
419
450
|
break
|
|
420
451
|
# print(next_id, pre_id)
|
|
421
452
|
# exit(0)
|
|
422
|
-
if len(next_logits_all) > self.sub_str_len and len(
|
|
423
|
-
|
|
424
|
-
|
|
453
|
+
if len(next_logits_all) > self.sub_str_len and len(
|
|
454
|
+
pre_logits_all) > self.sub_str_len:
|
|
455
|
+
next_logits_all_ = torch.concat(next_logits_all[:-1],
|
|
456
|
+
1) # 1, l
|
|
457
|
+
pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
|
|
458
|
+
1) #1, l
|
|
425
459
|
|
|
426
460
|
next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
|
|
427
461
|
pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
|
|
428
462
|
next_logits_all_mid = []
|
|
429
463
|
attn_map_next_mid = []
|
|
430
464
|
ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
|
|
431
|
-
mask_pad = torch.zeros([1, 1],
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
465
|
+
mask_pad = torch.zeros([1, 1],
|
|
466
|
+
dtype=torch.float32,
|
|
467
|
+
device=x.device)
|
|
468
|
+
for j in range(0, min(70, self.max_len - 1)):
|
|
469
|
+
|
|
470
|
+
prompt_next = torch.concat([
|
|
471
|
+
prompt_next_embed[:, :1, :],
|
|
472
|
+
prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
|
|
473
|
+
], 1) # b, sub_l, dim
|
|
474
|
+
mask_next = torch.where(next_id == self.bos_next,
|
|
475
|
+
float('-inf'), 0) # b, subs_l
|
|
436
476
|
mask = torch.concat([mask_pad, mask_next], 1)
|
|
437
477
|
# prompt_next = self.char_embed(prompt_id)
|
|
438
478
|
ques_next_i = ques_next
|
|
439
479
|
visual_f_i = visual_f[2:3]
|
|
440
480
|
for layer in self.cmff_decoder:
|
|
441
|
-
ques_next_i = layer(ques_next_i, prompt_next,
|
|
442
|
-
|
|
443
|
-
|
|
481
|
+
ques_next_i = layer(ques_next_i, prompt_next,
|
|
482
|
+
visual_f_i, mask.unsqueeze(1))
|
|
483
|
+
logits_next_i = self.ques1_head(
|
|
484
|
+
self.norm_pred(ques_next_i))
|
|
485
|
+
attn_map_next_mid.append(
|
|
486
|
+
self.cmff_decoder[-1].question_to_images_cross_attn.
|
|
487
|
+
attn_map[0])
|
|
444
488
|
logits = F.softmax(logits_next_i, -1)
|
|
445
489
|
pred_id_i = logits.argmax(-1)
|
|
446
490
|
next_logits_all_mid.append(logits)
|
|
@@ -449,12 +493,19 @@ class SMTRDecoder(nn.Module):
|
|
|
449
493
|
break
|
|
450
494
|
next_logits_all_mid = torch.concat(next_logits_all_mid, 1)
|
|
451
495
|
# next_logits_all_ = torch.concat([next_logits_all_, next_logits_all], 1)
|
|
452
|
-
self.attn_maps = [
|
|
453
|
-
|
|
496
|
+
self.attn_maps = [
|
|
497
|
+
attn_map_next, attn_map_next_mid, attn_map_pre[::-1]
|
|
498
|
+
]
|
|
499
|
+
return [
|
|
500
|
+
torch.concat(next_logits_all, 1), next_logits_all_mid,
|
|
501
|
+
torch.concat(pre_logits_all[::-1], 1)
|
|
502
|
+
]
|
|
454
503
|
else:
|
|
455
504
|
self.attn_maps = [attn_map_next, attn_map_pre[::-1]]
|
|
456
|
-
return [
|
|
457
|
-
|
|
505
|
+
return [
|
|
506
|
+
torch.concat(next_logits_all, 1),
|
|
507
|
+
torch.concat(pre_logits_all[::-1], 1)
|
|
508
|
+
]
|
|
458
509
|
|
|
459
510
|
def forward_test(self, x):
|
|
460
511
|
self.attn_maps = []
|
|
@@ -579,7 +630,7 @@ class SMTRDecoder(nn.Module):
|
|
|
579
630
|
prompt_char_next = torch.concat([
|
|
580
631
|
prompt_next_embed[:, :, :1, :],
|
|
581
632
|
prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
|
|
582
|
-
], 2) # b, n,
|
|
633
|
+
], 2) # b, n, subs_l, dim
|
|
583
634
|
next = self.next_token.tile([bs, max_len_curr, 1, 1])
|
|
584
635
|
|
|
585
636
|
max_len_curr_pre = targets[6].max()
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""UniRecConfig model configuration"""
|
|
15
|
+
|
|
16
|
+
from transformers.configuration_utils import PretrainedConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class UniRecConfig(PretrainedConfig):
    r"""
    Configuration class that stores the hyper-parameters of a UniRec model.

    The decoder-side fields (and ``model_type`` / ``attribute_map``) follow the
    naming of the HuggingFace M2M100 configuration this class was derived from.
    The vision-encoder fields (``depths``, ``dims``, ``mixer``, ``num_heads``,
    ``sub_k``, ``mlp_ratio``, ``kernel_size``, ``drop_path_rate``) are stored
    as given for the modeling code to consume. Instantiating the class without
    arguments yields the defaults listed below.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50000):
            Number of distinct output tokens.
        max_position_embeddings (`int`, *optional*, defaults to 3072):
            Maximum sequence length the model might ever be used with.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_ffn_dim (`int`, *optional*, defaults to 1536):
            Dimensionality of the feed-forward layer in the decoder.
        decoder_attention_heads (`int`, *optional*, defaults to 6):
            Number of attention heads per decoder attention layer. Also
            readable as ``num_attention_heads``.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            LayerDrop probability for the encoder.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            LayerDrop probability for the decoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/value attentions.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Forwarded to [`PretrainedConfig`].
        activation_function (`str`, *optional*, defaults to `"relu"`):
            Non-linear activation used by the decoder.
        d_model (`int`, *optional*, defaults to 384):
            Hidden size of the decoder. Also readable as ``hidden_size``.
        dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability for fully connected layers.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            Dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for activations inside the feed-forward layer.
        init_std (`float`, *optional*, defaults to 0.02):
            Standard deviation used to initialize weight matrices.
        decoder_start_token_id (`int`, *optional*, defaults to 0),
        pad_token_id (`int`, *optional*, defaults to 1),
        bos_token_id (`int`, *optional*, defaults to 0),
        eos_token_id (`int`, *optional*, defaults to 2):
            Special token ids, forwarded to [`PretrainedConfig`].
        scale_embedding (`bool`, *optional*, defaults to `True`):
            If `True`, the embedding scale factor is ``sqrt(d_model)``.
        depths (`list[int]`, *optional*, defaults to `[2, 2, 9, 2]`),
        dims (`list[int]`, *optional*, defaults to `[64, 128, 256, 384]`),
        mixer (`list[list[str]]`, *optional*, defaults to
            `[['Conv'] * 2, ['Conv'] * 2,
              ['Conv'] * 6 + ['FGlobal', 'Global', 'Global'], ['Global'] * 2]`),
        num_heads (`list[int]`, *optional*, defaults to `[2, 4, 4, 6]`),
        sub_k (`list[list[int]]`, *optional*, defaults to
            `[[2, 2], [2, 2], [2, 2], [2, 2]]`),
        mlp_ratio (`int`, *optional*, defaults to 4),
        kernel_size (`list[int]`, *optional*, defaults to `[3, 3]`),
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Vision-encoder hyper-parameters, stored as given. The four-element
            lists presumably hold one entry per encoder stage (TODO: confirm
            against the modeling code). For the list-valued ones, passing
            `None` selects the default, and a fresh list is built per instance.
        label_smoothing (`float`, *optional*, defaults to 0.1):
            Stored as given; presumably consumed by the training loss.
        torch_dtype (`str`, *optional*, defaults to `"bfloat16"`):
            Forwarded to [`PretrainedConfig`], which owns this attribute.

    Example:

    ```python
    >>> config = UniRecConfig()
    >>> config.hidden_size  # alias of d_model
    384
    >>> config.num_attention_heads  # alias of decoder_attention_heads
    6
    ```"""

    # NOTE: ``model_type`` comes from the M2M100 template this config was
    # derived from. It is kept as-is because configs that were already saved
    # record this value.
    model_type = 'm2m_100'
    keys_to_ignore_at_inference = ['past_key_values']
    attribute_map = {
        # This config never defines ``encoder_attention_heads``. Mapping to it
        # made ``config.num_attention_heads`` raise AttributeError, so the
        # alias now points at the decoder heads.
        'num_attention_heads': 'decoder_attention_heads',
        'hidden_size': 'd_model'
    }

    def __init__(
        self,
        vocab_size=50000,
        max_position_embeddings=3072,
        decoder_layers=6,
        decoder_ffn_dim=1536,
        decoder_attention_heads=6,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function='relu',
        d_model=384,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=0,
        scale_embedding=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        depths=None,
        dims=None,
        mixer=None,
        num_heads=None,
        sub_k=None,
        mlp_ratio=4,
        kernel_size=None,
        drop_path_rate=0.1,
        label_smoothing=0.1,
        torch_dtype='bfloat16',
        **kwargs,
    ):
        # ``None`` selects the documented default. A fresh list is built on
        # every call so two configs never share (and cannot mutate) one
        # default list object.
        if depths is None:
            depths = [2, 2, 9, 2]
        if dims is None:
            dims = [64, 128, 256, 384]
        if mixer is None:
            mixer = [['Conv'] * 2, ['Conv'] * 2,
                     ['Conv'] * 6 + ['FGlobal', 'Global', 'Global'],
                     ['Global'] * 2]
        if num_heads is None:
            num_heads = [2, 4, 4, 6]
        if sub_k is None:
            sub_k = [[2, 2], [2, 2], [2, 2], [2, 2]]
        if kernel_size is None:
            kernel_size = [3, 3]

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.depths = depths
        self.dims = dims
        self.mixer = mixer
        self.num_heads = num_heads
        self.sub_k = sub_k
        self.mlp_ratio = mlp_ratio
        self.kernel_size = kernel_size
        self.drop_path_rate = drop_path_rate
        self.label_smoothing = label_smoothing

        # ``torch_dtype`` must be handed to the base class.
        # PretrainedConfig.__init__ sets it from its own kwargs (defaulting to
        # None), so assigning ``self.torch_dtype`` before this call was always
        # overwritten, even for a dtype loaded from a saved config.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            torch_dtype=torch_dtype,
            **kwargs,
        )


__all__ = ['UniRecConfig']
|