openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -71,6 +71,7 @@ class SSMatchLayer(nn.Module):
71
71
  drop_path=0.0,
72
72
  act_layer=nn.GELU,
73
73
  epsilon=1e-6,
74
+ is_last_layer=False,
74
75
  ):
75
76
  super().__init__()
76
77
  self.dim = dim
@@ -96,17 +97,18 @@ class SSMatchLayer(nn.Module):
96
97
  proj_drop=drop)
97
98
  # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
99
  self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
100
+ self.is_last_layer = is_last_layer
99
101
 
100
102
  def forward(self, question_f, prompt_f, visual_f, mask=None):
101
-
102
103
  question_f = question_f + self.drop_path(
103
104
  self.images_to_question_cross_attn(self.normq1(question_f),
104
105
  self.normkv1(prompt_f), mask))
105
106
  question_f = question_f.reshape(visual_f.shape[0], -1, self.dim)
106
107
  question_f = self.question_to_images_cross_attn(
107
108
  self.normq2(question_f), self.normkv2(visual_f))
108
-
109
- return question_f
109
+ if self.is_last_layer:
110
+ return question_f
111
+ return question_f.flatten(0, 1).unsqueeze(1)
110
112
 
111
113
 
112
114
  class SMTRDecoder(nn.Module):
@@ -152,7 +154,9 @@ class SMTRDecoder(nn.Module):
152
154
  dynq2img_heads=dynq2img_heads,
153
155
  mlp_ratio=4.0,
154
156
  qkv_bias=True,
155
- drop_path=dpr[i]) for i in range(num_layer)
157
+ drop_path=dpr[i],
158
+ is_last_layer=i == num_layer - 1)
159
+ for i in range(num_layer)
156
160
  ])
157
161
 
158
162
  self.ds = ds
@@ -351,7 +355,6 @@ class SMTRDecoder(nn.Module):
351
355
  else:
352
356
  return torch.concat(next_logits_all + pre_logits_all[::-1], 1)
353
357
 
354
-
355
358
  def forward_test_bi_attn(self, x):
356
359
  self.attn_maps = []
357
360
  if not self.ds:
@@ -366,81 +369,122 @@ class SMTRDecoder(nn.Module):
366
369
  next = self.next_token
367
370
  pre = self.pre_token
368
371
  next_pre = torch.concat([next, pre], 0)
369
- next_pre = next_pre.squeeze(1) #2, 1, dim
372
+ next_pre = next_pre.squeeze(1) #2, 1, dim
370
373
 
371
374
  prompt_next_embed = self.prompt_next_embed.squeeze(1)
372
375
  prompt_pre_embed = self.prompt_pre_embed.squeeze(1)
373
376
 
374
- next_id = torch.full([1, self.sub_str_len], self.bos_next, dtype=torch.long, device=x.device)
375
- pre_id = torch.full([1, self.sub_str_len], self.bos_pre, dtype=torch.long, device=x.device)
377
+ next_id = torch.full([1, self.sub_str_len],
378
+ self.bos_next,
379
+ dtype=torch.long,
380
+ device=x.device)
381
+ pre_id = torch.full([1, self.sub_str_len],
382
+ self.bos_pre,
383
+ dtype=torch.long,
384
+ device=x.device)
376
385
  # prompt_next_bos = self.char_embed(prompt_id)
377
386
  # pred_prob_list = torch.full([bs, self.sub_str_len], self.ignore_index, dtype=torch.long, device=x.device)
378
- next_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device)
379
- pre_pred_id_list = torch.full([1, self.max_len], self.ignore_index, dtype=torch.long, device=x.device)
387
+ next_pred_id_list = torch.full([1, self.max_len],
388
+ self.ignore_index,
389
+ dtype=torch.long,
390
+ device=x.device)
391
+ pre_pred_id_list = torch.full([1, self.max_len],
392
+ self.ignore_index,
393
+ dtype=torch.long,
394
+ device=x.device)
380
395
  next_logits_all = []
381
396
  pre_logits_all = []
382
397
  attn_map_next = []
383
398
  attn_map_pre = []
384
- mask_pad = torch.zeros([bs, 1], dtype=torch.float32, device=x.device)
385
- for j in range(0, min(70, self.max_len-1)):
386
-
387
- prompt_char_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim
388
- prompt_char_pre = torch.concat([prompt_pre_embed[:, :1, :], prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)], 1) # b, sub_l, dim
389
- prompt_char = torch.concat([prompt_char_next, prompt_char_pre], 0) #2, 6, dim
399
+ mask_pad = torch.zeros([bs, 1],
400
+ dtype=torch.float32,
401
+ device=x.device)
402
+ for j in range(0, min(70, self.max_len - 1)):
403
+
404
+ prompt_char_next = torch.concat([
405
+ prompt_next_embed[:, :1, :],
406
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
407
+ ], 1) # b, sub_l, dim
408
+ prompt_char_pre = torch.concat([
409
+ prompt_pre_embed[:, :1, :],
410
+ prompt_pre_embed[:, 1:, :] + self.char_embed(pre_id)
411
+ ], 1) # b, sub_l, dim
412
+ prompt_char = torch.concat([prompt_char_next, prompt_char_pre],
413
+ 0) #2, 6, dim
390
414
  # prompt_char = prompt_char.flatten(0, 1)
391
415
 
392
- mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0) # b, subs_l
393
- mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'), 0) # b, subs_l
394
- mask = torch.concat([mask_next, mask_pre], 0) #2, 5
395
- mask = torch.concat([mask_pad, mask], 1) # 2, 6
416
+ mask_next = torch.where(next_id == self.bos_next,
417
+ float('-inf'), 0) # b, subs_l
418
+ mask_pre = torch.where(pre_id == self.bos_pre, float('-inf'),
419
+ 0) # b, subs_l
420
+ mask = torch.concat([mask_next, mask_pre], 0) #2, 5
421
+ mask = torch.concat([mask_pad, mask], 1) # 2, 6
396
422
  pred_token = next_pre
397
- visual_f_i = visual_f[:2] # 2 l dim
423
+ visual_f_i = visual_f[:2] # 2 l dim
398
424
  for layer in self.cmff_decoder:
399
- pred_token = layer(pred_token, prompt_char, visual_f_i, mask.unsqueeze(1))
400
-
401
-
425
+ pred_token = layer(pred_token, prompt_char, visual_f_i,
426
+ mask.unsqueeze(1))
427
+
402
428
  logits_next_i = self.ques1_head(self.norm_pred(pred_token))
403
429
  logits = F.softmax(logits_next_i, -1)
404
- pred_id_i = logits.argmax(-1) #2, 1
430
+ pred_id_i = logits.argmax(-1) #2, 1
405
431
  # print(pred_id_i.shape)
406
-
407
- next_pred_id_list[:, j:j+1] = pred_id_i[:1]
408
- pre_pred_id_list[:, j:j+1] = pred_id_i[1:2]
432
+
433
+ next_pred_id_list[:, j:j + 1] = pred_id_i[:1]
434
+ pre_pred_id_list[:, j:j + 1] = pred_id_i[1:2]
409
435
  if not (next_pred_id_list == self.eos).any(dim=-1).all():
410
436
  next_logits_all.append(logits[:1])
411
- attn_map_next.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
437
+ attn_map_next.append(
438
+ self.cmff_decoder[-1].question_to_images_cross_attn.
439
+ attn_map[0])
412
440
  next_id = torch.concat([next_id[:, 1:], pred_id_i[:1]], 1)
413
441
  if not (pre_pred_id_list == self.eos).any(dim=-1).all():
414
442
  pre_logits_all.append(logits[1:2])
415
- attn_map_pre.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[1])
443
+ attn_map_pre.append(
444
+ self.cmff_decoder[-1].question_to_images_cross_attn.
445
+ attn_map[1])
416
446
  pre_id = torch.concat([pred_id_i[1:2], pre_id[:, :-1]], 1)
417
-
418
- if (next_pred_id_list == self.eos).any(dim=-1).all() and (pre_pred_id_list == self.eos).any(dim=-1).all():
447
+
448
+ if (next_pred_id_list == self.eos).any(dim=-1).all() and (
449
+ pre_pred_id_list == self.eos).any(dim=-1).all():
419
450
  break
420
451
  # print(next_id, pre_id)
421
452
  # exit(0)
422
- if len(next_logits_all) > self.sub_str_len and len(pre_logits_all) > self.sub_str_len:
423
- next_logits_all_ = torch.concat(next_logits_all[:-1], 1) # 1, l
424
- pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1], 1) #1, l
453
+ if len(next_logits_all) > self.sub_str_len and len(
454
+ pre_logits_all) > self.sub_str_len:
455
+ next_logits_all_ = torch.concat(next_logits_all[:-1],
456
+ 1) # 1, l
457
+ pre_logits_all_ = torch.concat(pre_logits_all[:-1][::-1],
458
+ 1) #1, l
425
459
 
426
460
  next_id = next_logits_all_.argmax(-1)[:, -self.sub_str_len:]
427
461
  pre_id = pre_logits_all_.argmax(-1)[:, :self.sub_str_len]
428
462
  next_logits_all_mid = []
429
463
  attn_map_next_mid = []
430
464
  ques_next = self.next_token.tile([1, 1, 1, 1]).squeeze(1)
431
- mask_pad = torch.zeros([1, 1], dtype=torch.float32, device=x.device)
432
- for j in range(0, min(70, self.max_len-1)):
433
-
434
- prompt_next = torch.concat([prompt_next_embed[:, :1, :], prompt_next_embed[:, 1:, :] + self.char_embed(next_id)], 1) # b, sub_l, dim
435
- mask_next = torch.where(next_id == self.bos_next, float('-inf'), 0) # b, subs_l
465
+ mask_pad = torch.zeros([1, 1],
466
+ dtype=torch.float32,
467
+ device=x.device)
468
+ for j in range(0, min(70, self.max_len - 1)):
469
+
470
+ prompt_next = torch.concat([
471
+ prompt_next_embed[:, :1, :],
472
+ prompt_next_embed[:, 1:, :] + self.char_embed(next_id)
473
+ ], 1) # b, sub_l, dim
474
+ mask_next = torch.where(next_id == self.bos_next,
475
+ float('-inf'), 0) # b, subs_l
436
476
  mask = torch.concat([mask_pad, mask_next], 1)
437
477
  # prompt_next = self.char_embed(prompt_id)
438
478
  ques_next_i = ques_next
439
479
  visual_f_i = visual_f[2:3]
440
480
  for layer in self.cmff_decoder:
441
- ques_next_i = layer(ques_next_i, prompt_next, visual_f_i, mask.unsqueeze(1))
442
- logits_next_i = self.ques1_head(self.norm_pred(ques_next_i))
443
- attn_map_next_mid.append(self.cmff_decoder[-1].question_to_images_cross_attn.attn_map[0])
481
+ ques_next_i = layer(ques_next_i, prompt_next,
482
+ visual_f_i, mask.unsqueeze(1))
483
+ logits_next_i = self.ques1_head(
484
+ self.norm_pred(ques_next_i))
485
+ attn_map_next_mid.append(
486
+ self.cmff_decoder[-1].question_to_images_cross_attn.
487
+ attn_map[0])
444
488
  logits = F.softmax(logits_next_i, -1)
445
489
  pred_id_i = logits.argmax(-1)
446
490
  next_logits_all_mid.append(logits)
@@ -449,12 +493,19 @@ class SMTRDecoder(nn.Module):
449
493
  break
450
494
  next_logits_all_mid = torch.concat(next_logits_all_mid, 1)
451
495
  # next_logits_all_ = torch.concat([next_logits_all_, next_logits_all], 1)
452
- self.attn_maps = [attn_map_next, attn_map_next_mid, attn_map_pre[::-1]]
453
- return [torch.concat(next_logits_all, 1), next_logits_all_mid, torch.concat(pre_logits_all[::-1], 1)]
496
+ self.attn_maps = [
497
+ attn_map_next, attn_map_next_mid, attn_map_pre[::-1]
498
+ ]
499
+ return [
500
+ torch.concat(next_logits_all, 1), next_logits_all_mid,
501
+ torch.concat(pre_logits_all[::-1], 1)
502
+ ]
454
503
  else:
455
504
  self.attn_maps = [attn_map_next, attn_map_pre[::-1]]
456
- return [torch.concat(next_logits_all, 1), torch.concat(pre_logits_all[::-1], 1)]
457
-
505
+ return [
506
+ torch.concat(next_logits_all, 1),
507
+ torch.concat(pre_logits_all[::-1], 1)
508
+ ]
458
509
 
459
510
  def forward_test(self, x):
460
511
  self.attn_maps = []
@@ -579,7 +630,7 @@ class SMTRDecoder(nn.Module):
579
630
  prompt_char_next = torch.concat([
580
631
  prompt_next_embed[:, :, :1, :],
581
632
  prompt_next_embed[:, :, 1:, :] + self.char_embed(subs)
582
- ], 2) # b, n, sub_l, dim
633
+ ], 2) # b, n, subs_l, dim
583
634
  next = self.next_token.tile([bs, max_len_curr, 1, 1])
584
635
 
585
636
  max_len_curr_pre = targets[6].max()
@@ -0,0 +1,166 @@
1
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """UniRecConfig model configuration"""
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+
18
+
19
class UniRecConfig(PretrainedConfig):
    r"""
    Configuration class for the UniRec model.

    Every constructor argument listed below is stored on the instance under
    the same name. The special-token ids, `is_encoder_decoder`,
    `decoder_start_token_id` and any extra keyword arguments are forwarded to
    [`PretrainedConfig`].

    Configuration objects inherit from [`PretrainedConfig`]. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50000):
            Vocabulary size of the model.
        max_position_embeddings (`int`, *optional*, defaults to 3072):
            The maximum sequence length that this model might ever be used
            with.
        decoder_layers (`int`, *optional*, defaults to 6):
            Number of decoder layers.
        decoder_ffn_dim (`int`, *optional*, defaults to 1536):
            Dimensionality of the "intermediate" (often named feed-forward)
            layer in the decoder.
        decoder_attention_heads (`int`, *optional*, defaults to 6):
            Number of attention heads for each attention layer in the
            decoder.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the
            [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
            details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the
            [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
            details.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to [`PretrainedConfig`].
        activation_function (`str`, *optional*, defaults to `"relu"`):
            Name of the non-linear activation function.
        d_model (`int`, *optional*, defaults to 384):
            Dimensionality of the layers; also exposed as `hidden_size`
            through `attribute_map`.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for fully connected layers.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected
            layer.
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation used for weight initialization.
        decoder_start_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to [`PretrainedConfig`].
        scale_embedding (`bool`, *optional*, defaults to `True`):
            Scale embeddings by `sqrt(d_model)` if `True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Forwarded unchanged to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Forwarded unchanged to [`PretrainedConfig`].
        depths (`list[int]`, *optional*, defaults to `[2, 2, 9, 2]`):
            Stored as `depths`. Presumably the number of blocks per visual
            encoder stage -- confirm against the model implementation.
        dims (`list[int]`, *optional*, defaults to `[64, 128, 256, 384]`):
            Stored as `dims`. Presumably the channel width per visual
            encoder stage -- confirm against the model implementation.
        mixer (`list[list[str]]`, *optional*):
            Stored as `mixer`. One list of mixer names (`'Conv'`,
            `'FGlobal'`, `'Global'` in the default) per entry of `depths`.
        num_heads (`list[int]`, *optional*, defaults to `[2, 4, 4, 6]`):
            Stored as `num_heads`. Presumably attention heads per visual
            encoder stage -- confirm against the model implementation.
        sub_k (`list[list[int]]`, *optional*, defaults to `[[2, 2]] * 4`):
            Stored as `sub_k`; meaning is defined by the model
            implementation, which is not visible from this file.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Stored as `mlp_ratio`.
        kernel_size (`list[int]`, *optional*, defaults to `[3, 3]`):
            Stored as `kernel_size`.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stored as `drop_path_rate`.
        label_smoothing (`float`, *optional*, defaults to 0.1):
            Stored as `label_smoothing`.
        torch_dtype (`str`, *optional*, defaults to `"bfloat16"`):
            Stored as `torch_dtype` before the base-class initializer runs
            (see the review note in `__init__`).

    Example:

    ```python
    >>> # Initializing a configuration with the default values
    >>> configuration = UniRecConfig()

    >>> # Accessing a stored hyper-parameter
    >>> configuration.d_model
    384
    ```"""

    # NOTE(review): the model type string is 'm2m_100', not a UniRec-specific
    # name; changing it would affect config (de)serialization, so it is kept.
    model_type = 'm2m_100'
    keys_to_ignore_at_inference = ['past_key_values']
    # NOTE(review): 'num_attention_heads' maps to 'encoder_attention_heads',
    # which __init__ never assigns -- confirm whether
    # 'decoder_attention_heads' was intended.
    attribute_map = {
        'num_attention_heads': 'encoder_attention_heads',
        'hidden_size': 'd_model'
    }

    def __init__(
            self,
            vocab_size=50000,
            max_position_embeddings=3072,
            decoder_layers=6,
            decoder_ffn_dim=1536,
            decoder_attention_heads=6,
            encoder_layerdrop=0.0,
            decoder_layerdrop=0.0,
            use_cache=True,
            is_encoder_decoder=True,
            activation_function='relu',
            d_model=384,
            dropout=0.1,
            attention_dropout=0.1,
            activation_dropout=0.0,
            init_std=0.02,
            decoder_start_token_id=0,
            scale_embedding=True,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            depths=[2, 2, 9, 2],
            dims=[64, 128, 256, 384],
            mixer=[['Conv'] * 2, ['Conv'] * 2,
                   ['Conv'] * 6 + ['FGlobal', 'Global', 'Global'],
                   ['Global'] * 2],
            num_heads=[2, 4, 4, 6],
            sub_k=[[2, 2], [2, 2], [2, 2], [2, 2]],
            mlp_ratio=4,
            kernel_size=[3, 3],
            drop_path_rate=0.1,
            label_smoothing=0.1,
            torch_dtype='bfloat16',
            **kwargs,
    ):

        def _detached(value):
            # Duplicate (nested) lists so the instance never aliases the
            # shared default-argument objects or a caller-owned list; any
            # non-list value is returned untouched.
            if isinstance(value, list):
                return [_detached(item) for item in value]
            return value

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        # The list-valued defaults in the signature are evaluated once, so
        # they are copied here to keep configs independent of each other.
        self.depths = _detached(depths)
        self.dims = _detached(dims)
        self.mixer = _detached(mixer)
        self.num_heads = _detached(num_heads)
        self.sub_k = _detached(sub_k)
        self.mlp_ratio = mlp_ratio
        self.kernel_size = _detached(kernel_size)
        self.drop_path_rate = drop_path_rate
        self.label_smoothing = label_smoothing
        # NOTE(review): this is assigned before super().__init__() and is not
        # forwarded to it; depending on the installed transformers version the
        # base initializer may set torch_dtype itself and overwrite this value
        # -- confirm.
        self.torch_dtype = torch_dtype

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
164
+
165
+
166
+ __all__ = ['UniRecConfig']