python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (170)
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
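The pattern in the list above is the headline change of this release: every TensorFlow-backend module (entries 136-166) is deleted while its `pytorch.py` counterpart remains, i.e. python-doctr 1.0 drops the TensorFlow backend and ships PyTorch-only. A minimal sketch of the surviving high-level API, assuming the public entry points keep their 0.12 signatures (the file name below is illustrative):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # End-to-end OCR on the PyTorch backend, the only backend left in 1.0.x
    doc = DocumentFile.from_pdf("sample.pdf")   # or DocumentFile.from_images([...])
    predictor = ocr_predictor(pretrained=True)  # detection + recognition pipeline
    result = predictor(doc)
    print(result.render())                      # plain-text export of the OCR output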
python_doctr-1.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,149 @@
+ doctr/__init__.py,sha256=sdqGeYFfPLRsRH54PsedllScz5FD8yWwyekcsOq3JNc,110
+ doctr/file_utils.py,sha256=DsQvazaicyDSM1bNuI70p1DNewO0IOxQ5TRBgXSju3w,999
+ doctr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ doctr/version.py,sha256=6p2_N2LdNEx0s2rC1vd5LYuVCJ2ya740g-riPQhSMfQ,23
+ doctr/contrib/__init__.py,sha256=EKeAGa3TOuJVWR4H_HJvuKO5VPEnJWXW305JMw3ufic,40
+ doctr/contrib/artefacts.py,sha256=fP-UMWBYVibCv0aMHXFd-D3HF4yQIbaWfV0EDl4_iAU,5291
+ doctr/contrib/base.py,sha256=uKV-B_9Y02jR_X6of55N3aJCcn8e1asMTHybsTQzRbw,3273
+ doctr/datasets/__init__.py,sha256=QzWe91FU1d3Vi5h4ayPtJUE-Y_SSs_4Xp6MAkZ8XKTc,504
+ doctr/datasets/coco_text.py,sha256=3-D4EjqOdWF-ioMXgdNj_plYWGrIa-gZt4FasphXAGA,5496
+ doctr/datasets/cord.py,sha256=uRGYIDtRsq01UTRShLKEUyOeWphydHyMQA9Rows-1z0,5315
+ doctr/datasets/detection.py,sha256=jeTNggksKEYsVF7_Km1LSJNc7VHTcES4BLuYt_UhXpg,3531
+ doctr/datasets/doc_artefacts.py,sha256=tYTdwfC5d2f_wkxhJ-2-XMf7Bk8pS4LQ_zLrk-JTVvQ,3230
+ doctr/datasets/funsd.py,sha256=siX7xUP38dHT3ECL3ReQ78v4RI-q0mwcuaX1TADkwzI,4733
+ doctr/datasets/ic03.py,sha256=l_f0ACbR9DD2_mCe57eLO8dnwFgb2JNP0jjpPmvXCP4,5564
+ doctr/datasets/ic13.py,sha256=EWv67V5TwkKcgOzE9qCz38Rnff6b5Bqo7FwXNECDy-I,4528
+ doctr/datasets/iiit5k.py,sha256=Yz3brhWUVCqXhAAg75xRU3qWJeX87j2lEG70zkkMrlo,4583
+ doctr/datasets/iiithws.py,sha256=8iMytN32T23BpIUMoAEzHrsUgzII_thjoMxHzAEu-nk,2768
+ doctr/datasets/imgur5k.py,sha256=MJwrMlBFsxlfJs65rgkKzRjhepNf04kez7kP90Tlt74,7525
+ doctr/datasets/mjsynth.py,sha256=SBlhKk7s1onMXF8AJj2WMnIF2n5XXBzTJTvOZxcW0TI,4075
+ doctr/datasets/ocr.py,sha256=61SxWHtDOty7ZtkAkuiPAw68n5wZYXI2YZJxP4LhrHc,2523
+ doctr/datasets/orientation.py,sha256=b5XsuSWChDxmU_zJ33LtBz76_Rkz5N6r6NogW9uAMZc,1091
+ doctr/datasets/recognition.py,sha256=FdaTGAoRq0VD5g7oo8NKphsI2YU8fSZzSEj9IppiT-Y,1858
+ doctr/datasets/sroie.py,sha256=9moG-O_PiFU0uXuYzsbzb-TxXKwCOK0NSyRSde0FlT4,4430
+ doctr/datasets/svhn.py,sha256=duB_HMWRLY3FKPEgMt2x9k3x4qEjCvbZftcsfuayRWE,5784
+ doctr/datasets/svt.py,sha256=B_39RFw3qWYWqGbpVcd51SyhUwi2vjBSCFQKhE1PpiA,5046
+ doctr/datasets/synthtext.py,sha256=F_5BS1vJ5b194kRaYpofsSWWIQqBoon9URF742jME5w,6301
+ doctr/datasets/utils.py,sha256=rt5-1B_ZbqPpZmfTRAMCqaUPpi3d2O5vfGfiVWkzp04,8106
+ doctr/datasets/vocabs.py,sha256=vKGhlVxA8KuJz1X15QcX2i8AwLT1u0eHxvQSUp0zFVA,99032
+ doctr/datasets/wildreceipt.py,sha256=vqGYVaaSS97GxSHuvYvpeCIShACAEYhgScJhpA943R0,5213
+ doctr/datasets/datasets/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/datasets/datasets/base.py,sha256=hyNTt3HCvMpeRA1rYMneLhAbeSTvAbC5JE8MKQBFk5g,4826
+ doctr/datasets/datasets/pytorch.py,sha256=U9xtyOSgXLnlGu9vG9ECvp0cf1SvXZ2SABVk2v0cp6g,2030
+ doctr/datasets/generator/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/datasets/generator/base.py,sha256=xiVAvAzP_r4fTpZy75O_daN93fzsjjpaN-vxn-GtyLU,5706
+ doctr/datasets/generator/pytorch.py,sha256=vvpbdu6PgDLbUkKmXYqZYPsoGaRQIV2NqAS7t99jP7E,2116
+ doctr/io/__init__.py,sha256=kS7tKGFvzxOCWBOun-Y8n9CsziwRKNynjwpZEUUI03M,106
+ doctr/io/elements.py,sha256=74pRYfxbe4yCB8S1It-PymGWbDG0S4qdbjf_P57_M0s,26474
+ doctr/io/html.py,sha256=gKu4hxMYaEgRDPcd45HBlctNXw0Fnk3CY5Jgr9JZRMw,698
+ doctr/io/pdf.py,sha256=LYxPh7auAIlPh3Xdh462peVnBbCMMjB7ouDfiQq7eok,1308
+ doctr/io/reader.py,sha256=gwQmziIkF-cozyE8K4KisZYmFmtyoZUSl4LWyI_CPCI,2732
+ doctr/io/image/__init__.py,sha256=bJLj2I8OOTYLuTDjdinao0nkOIWQOLbzIuww23EX3gw,43
+ doctr/io/image/base.py,sha256=EgAl_IaklfK67h4nY9nhedTMu1_MUqIqzNcJSH1IGzs,1681
+ doctr/io/image/pytorch.py,sha256=UYlll8CPBaOZ7tDyWN-xQ2DMVqGRnBI173Q4NF8ij7A,3201
+ doctr/models/__init__.py,sha256=yn_mXUL8B5L27Uaat1rLGRQHgLR8VLVxzBuPfNuN1YE,124
+ doctr/models/_utils.py,sha256=U66VEsTHHmJEgauqX_t6tl6fRUTgWiDPKmDL4i-Ptu0,8455
+ doctr/models/builder.py,sha256=LkleWWqwhkFdBMMAFWD9TWGDKXSIEx-cn3hi1CToIxc,20370
+ doctr/models/core.py,sha256=KPTn_zbiwTV7OoPDv6o-xx7Tzb-QnAfLne1HuNQlWu4,482
+ doctr/models/zoo.py,sha256=uPgHXZP_wcRaZXmwZyIzAcLaDHhzo-B_8a0cx64mOVk,9276
+ doctr/models/classification/__init__.py,sha256=piTKyOGbTiKBzZSAHQLXf622mqjxaAJPicNtH3b-A0k,173
+ doctr/models/classification/zoo.py,sha256=dXgrlE9V7-zNM6yd0hfgpcpTQpSED9X7KuR2jGbjcd0,4265
+ doctr/models/classification/magc_resnet/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/magc_resnet/pytorch.py,sha256=yVnkta1vNSXB5tF8C4TLMeAptWjrbNGbtDmQ4z_lJ6w,5480
+ doctr/models/classification/mobilenet/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/mobilenet/pytorch.py,sha256=33ou3bbirjFhSUe-ws9C6foPosiO_cse-RgDBfxYL0Q,9826
+ doctr/models/classification/predictor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/predictor/pytorch.py,sha256=bu2dWcwrOBD_0XloiNXzyJ7JrjKiczRv0YWjh46iHbQ,2523
+ doctr/models/classification/resnet/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/resnet/pytorch.py,sha256=nmD1sg9-Vd5u_NRaVaES5lENi4d06jcy00TSl82RT50,13250
+ doctr/models/classification/textnet/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/textnet/pytorch.py,sha256=8a4i33cNgb5zdGjJa8nDgZfHyEMDa4mvtz1rJ153PPM,10424
+ doctr/models/classification/vgg/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/vgg/pytorch.py,sha256=1fO451jkdTLMnc4rDEkvsvkYVapyfK19TKEE0mHoPHY,3679
+ doctr/models/classification/vip/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/vip/pytorch.py,sha256=4IXmfXOil4tEp-g5VmoQIbXWlud2aFXsc0wuIDrGG58,16134
+ doctr/models/classification/vip/layers/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/vip/layers/pytorch.py,sha256=lK01uxaC6QAbTf3DkgxEaYwNFhcPk7rT9TPs6FEkUI4,21019
+ doctr/models/classification/vit/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/classification/vit/pytorch.py,sha256=rrBLNjHvpmxqsNwSZenhrX81u4SBeUiGenFYVQKuaWc,6371
+ doctr/models/detection/__init__.py,sha256=RqSz5beehLiqhW0PwFLFmCfTyMjofO-0umcQJLDMHjY,105
+ doctr/models/detection/core.py,sha256=3qHxwPRXD_Pcpk3uvUtoCXtW07xagLOjbvsVFzRli7g,3444
+ doctr/models/detection/zoo.py,sha256=fmTcKXfOioUk9faOJSnr_1_c_Xviv1JAoPs0FMhZKng,3671
+ doctr/models/detection/_utils/__init__.py,sha256=bJLj2I8OOTYLuTDjdinao0nkOIWQOLbzIuww23EX3gw,43
+ doctr/models/detection/_utils/base.py,sha256=1Ozf6nQoXt_OhKEt-QPhP1iJ0HJoMugbdDFNYK2wHG8,2475
+ doctr/models/detection/_utils/pytorch.py,sha256=NyQxrS4e_EVQ5FriZSQYbvTvtenmAsvSC5QGFbZUhV4,1021
+ doctr/models/detection/differentiable_binarization/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/detection/differentiable_binarization/base.py,sha256=gMfQBrpuxflQUF46CyBFA1WNztU11fdkQxL2J3M3CIQ,15788
+ doctr/models/detection/differentiable_binarization/pytorch.py,sha256=sioxicRYT_kr5BsqzfdH6Pa2GTPtmycA2LzQrurKvC0,16453
+ doctr/models/detection/fast/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/detection/fast/base.py,sha256=BQ1Z9ARWakLbTH0VtHjUvmKZKhITFG-RTZkGKucjCyA,10225
+ doctr/models/detection/fast/pytorch.py,sha256=fzUoDDPrrKS9H_WyO8jcPdUZrrqVMZHL4d04DIJq_yY,16702
+ doctr/models/detection/linknet/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/detection/linknet/base.py,sha256=EKU5rPreZ6HO3brR93R8x5ssBmEu1ynbtc8iffHgQ0g,10007
+ doctr/models/detection/linknet/pytorch.py,sha256=KrKuZj16Xiwgk89oSaaVWSgz9jItvkCVv20rjoaqfDw,14333
+ doctr/models/detection/predictor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/detection/predictor/pytorch.py,sha256=jdU5A10FDBXMLXyu1zJsVxlqSmt860u-oT30H7cIoKM,2627
+ doctr/models/factory/__init__.py,sha256=cKPoH2V2157lLMTR2zsljG3_IQHziodqR-XK_LG0D_I,19
+ doctr/models/factory/hub.py,sha256=F-CG-9dG3Iqbsj2mKkxf0JqKuZbg4hOfDoFGTO6fAIs,6889
+ doctr/models/kie_predictor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/kie_predictor/base.py,sha256=oLVDjIZ5sKq4ma5bopWVYUzq6q4ZAxF6mtEEQNxwBxk,2291
+ doctr/models/kie_predictor/pytorch.py,sha256=pjFtZcjDDO9mpC7Zscu9CYcmkfF87J1ioJ8N_dzUSq4,7704
+ doctr/models/modules/__init__.py,sha256=pouP7obVTu4p6aHkyaqa1yHKbynpvT0Hgo-LO_1U2R4,83
+ doctr/models/modules/layers/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/modules/layers/pytorch.py,sha256=tjTZYcQt30TSpFwpzBfwEc7QN7bQdxnYMGOcBI4MqwE,8667
+ doctr/models/modules/transformer/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/modules/transformer/pytorch.py,sha256=AVFeQCElFgg7ondhTYc_XFPocTKMt3uFJBHUFSEUevk,7657
+ doctr/models/modules/vision_transformer/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/modules/vision_transformer/pytorch.py,sha256=ltnBAevcfJl612yblI6WhcJAPEgUO0w-5jFL7vBhXoA,3943
+ doctr/models/predictor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/predictor/base.py,sha256=YYvEepdYY7bphhOmBdFsVGFI6EaBZDip9hut8qz2uA8,8541
+ doctr/models/predictor/pytorch.py,sha256=XplCn6sCdeREsINz6foxqLVTTJ0jiO2vMJgcVF0_Dxw,6241
+ doctr/models/preprocessor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/preprocessor/pytorch.py,sha256=vcMwCb04UlJdCJ5vh-ZXPzoy_-U0sPOn-OWgREvL42g,4395
+ doctr/models/recognition/__init__.py,sha256=bgAvbwjO14Z2RQFD2XKZcSYJNsyxPa96SGHmX_nqbAQ,145
+ doctr/models/recognition/core.py,sha256=InV-HLYkbQG65B6BYIud_JCWdbeoacC2MURjPZqceoo,1524
+ doctr/models/recognition/utils.py,sha256=WLUr6Qa3Kq803xyuLWmlFBS2Vf3P5GntyZVXW-y1LqU,3739
+ doctr/models/recognition/zoo.py,sha256=SeJr1IqZsFXzRGYrExgFb2aDyWp0fZWhR6Rr1RJTQHQ,2991
+ doctr/models/recognition/crnn/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/crnn/pytorch.py,sha256=jy5LVkG0NdLObFIaW2lWsdv6LWYYLIkZDZS-aTjFa6I,12234
+ doctr/models/recognition/master/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/master/base.py,sha256=KU94CO3YZYyYh83gW6F65IjeJfoAkwuandY3A_DxfM8,1471
+ doctr/models/recognition/master/pytorch.py,sha256=mL9o3O_L-iVw-Wq71tF-Y3a_G6qoGlUv2_3SvcWBu2I,12756
+ doctr/models/recognition/parseq/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/parseq/base.py,sha256=SMGEETj94ZIJUdSLe2d2PhLTpeptrW5MB7cW3WtVOhQ,1465
+ doctr/models/recognition/parseq/pytorch.py,sha256=hI_6qBDrKqe4358TKKz0YzXm06K9FKyGn2_6s8caHf8,20801
+ doctr/models/recognition/predictor/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/predictor/_utils.py,sha256=8q5ate9L7xeqOY-45foX1itPlHUBWw3RpLfCgYJ2ZJo,4695
+ doctr/models/recognition/predictor/pytorch.py,sha256=amUzwltBWhQNaCKB2hjqsQuRHtxfeknZO1uwSX3ASzY,2753
+ doctr/models/recognition/sar/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/sar/pytorch.py,sha256=_YMeCgu7p8-v0N75AiT6dR2O3ww7UFT8E82IWdF9lQM,15639
+ doctr/models/recognition/viptr/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/viptr/pytorch.py,sha256=IQQeuVyBbknN0BbEYy2Kak8WeBU4dDcpP0wUvjPO330,9355
+ doctr/models/recognition/vitstr/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/recognition/vitstr/base.py,sha256=tRMhC5i_l0S207_YFEs7wFAtkznyAJoH8sUZVMcZUBk,1419
+ doctr/models/recognition/vitstr/pytorch.py,sha256=UMXLvGHi065j3XQxbVwrqN2AL3YShfJ1pl-sKkRU0Cs,10080
+ doctr/models/utils/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/models/utils/pytorch.py,sha256=BHNEDG89XSW9TYNRwtGuf2RcpM2OfjnBb_92-OpYRr8,5804
+ doctr/transforms/__init__.py,sha256=0VHbvUZ7llFd1e_7_JdWTaxYMCvaR6KbUJaYJequmQI,23
+ doctr/transforms/functional/__init__.py,sha256=hxLu7FmRYn01ysoOQalrMa5K_qg4rLNf8OAuoemzd6I,23
+ doctr/transforms/functional/base.py,sha256=LRIM3L6nZQaZ3qyQKEFQ0ZLJRRg7rMYtshADK_ZCOi8,6851
+ doctr/transforms/functional/pytorch.py,sha256=mSToyPGn8O3WgK9rQJoGZcOHyIm06MQtWua1YA7p1kU,5102
+ doctr/transforms/modules/__init__.py,sha256=bJLj2I8OOTYLuTDjdinao0nkOIWQOLbzIuww23EX3gw,43
+ doctr/transforms/modules/base.py,sha256=fAJzcvq6rZv9oj8NRZ3Veo9z2-wFR40w3JjNmIdsP7w,7679
+ doctr/transforms/modules/pytorch.py,sha256=XTUbZPC3aahoB0cqwIPlLtsG8dbc7eReG7SAQsjYWJg,11802
+ doctr/utils/__init__.py,sha256=uQY9ibZ24V896fmihIsK23QOIZdKtk0HyKoCVJ_lLuM,95
+ doctr/utils/common_types.py,sha256=qK-gSvbAGJAhtqTUCO1DJ3Vr04KVFUW5I3Pn0T4Btp8,534
+ doctr/utils/data.py,sha256=eXLIuO5TJk6hjlVtodADQFKkt-dlEmLjEhIQcFVF2Zs,4174
+ doctr/utils/fonts.py,sha256=fxEZ8o67iyceZpiw87SskP6v8lcC5iHVcfWLPo_m6rQ,1265
+ doctr/utils/geometry.py,sha256=y4Xz-heV_sA7OfqoP40vHzJe4bKAgmzIBEA5FvV7Pgc,18678
+ doctr/utils/metrics.py,sha256=7u2Fz7KnY32PlV-Lo9M9iYKfHHMXEE6UdSTq-rgJAvU,20249
+ doctr/utils/multithreading.py,sha256=Z7cz0M-Zsqq5jRMrAW4Q5L7n2pNvC5qPKOfkXX4w05A,1972
+ doctr/utils/reconstitution.py,sha256=i6wqDQmCmu5WPRcIGFxGvGegy0mUv42lLK90DKUFtjc,7256
+ doctr/utils/repr.py,sha256=ShDOvCiPjdj1_v1AfLQ7AN_tunu07p0S_scFZYip7GY,2087
+ doctr/utils/visualization.py,sha256=sAB283iegS86gq0btK9H6PFLxP7SPXpuX9JWT_v53Iw,13118
+ python_doctr-1.0.1.dist-info/licenses/LICENSE,sha256=75RTSsXOsAYhGpxsHc9U41ep6GS7vrUPufeekgoeOXM,11336
+ python_doctr-1.0.1.dist-info/METADATA,sha256=gOwRkDU1ufra-cglBOjbW0OpxTwHBvVgFFM9YY_6iMM,32370
+ python_doctr-1.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ python_doctr-1.0.1.dist-info/top_level.txt,sha256=lCgp4pmjPI3HYph62XhfzA3jRwM715kGtJPmqIUJ9t8,6
+ python_doctr-1.0.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ python_doctr-1.0.1.dist-info/RECORD,,
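Each RECORD entry follows the wheel convention (PEP 376 / PEP 627): `path,sha256=<unpadded urlsafe-base64 digest>,<size in bytes>`. A short sketch to recompute an entry, handy for verifying an installed file; note that the zero-byte `doctr/py.typed` above carries the well-known empty-file digest `47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU`:

    import base64
    import hashlib
    from pathlib import Path

    def record_entry(path: str) -> str:
        """Rebuild a RECORD line for a file on disk (path is illustrative)."""
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
        return f"{path},sha256={digest},{len(data)}"

    # Run from site-packages, this should reproduce the entry above
    print(record_entry("doctr/py.typed"))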
{python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
doctr/datasets/datasets/tensorflow.py DELETED
@@ -1,59 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import os
- from copy import deepcopy
- from typing import Any
-
- import numpy as np
- import tensorflow as tf
-
- from doctr.io import read_img_as_tensor, tensor_from_numpy
-
- from .base import _AbstractDataset, _VisionDataset
-
- __all__ = ["AbstractDataset", "VisionDataset"]
-
-
- class AbstractDataset(_AbstractDataset):
-     """Abstract class for all datasets"""
-
-     def _read_sample(self, index: int) -> tuple[tf.Tensor, Any]:
-         img_name, target = self.data[index]
-
-         # Check target
-         if isinstance(target, dict):
-             assert "boxes" in target, "Target should contain 'boxes' key"
-             assert "labels" in target, "Target should contain 'labels' key"
-         elif isinstance(target, tuple):
-             assert len(target) == 2
-             assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), (
-                 "first element of the tuple should be a string or a numpy array"
-             )
-             assert isinstance(target[1], list), "second element of the tuple should be a list"
-         else:
-             assert isinstance(target, str) or isinstance(target, np.ndarray), (
-                 "Target should be a string or a numpy array"
-             )
-
-         # Read image
-         img = (
-             tensor_from_numpy(img_name, dtype=tf.float32)
-             if isinstance(img_name, np.ndarray)
-             else read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32)
-         )
-
-         return img, deepcopy(target)
-
-     @staticmethod
-     def collate_fn(samples: list[tuple[tf.Tensor, Any]]) -> tuple[tf.Tensor, list[Any]]:
-         images, targets = zip(*samples)
-         images = tf.stack(images, axis=0)
-
-         return images, list(targets)
-
-
- class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
-     pass
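Only the TensorFlow wrapper is removed here; `doctr/datasets/datasets/base.py` and `pytorch.py` stay in the wheel (see the RECORD above), so the dataset API survives on torch tensors. A hedged usage sketch, assuming the PyTorch `AbstractDataset` keeps the same `_read_sample`/`collate_fn` contract with `torch.stack` in place of `tf.stack`:

    from doctr.datasets import FUNSD  # any VisionDataset subclass works the same way

    ds = FUNSD(train=True, download=True)
    img, target = ds[0]  # torch.Tensor image and a deep-copied target
    images, targets = ds.collate_fn([ds[0], ds[1]])  # stacked batch + list of targets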
doctr/datasets/generator/tensorflow.py DELETED
@@ -1,58 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import tensorflow as tf
-
- from .base import _CharacterGenerator, _WordGenerator
-
- __all__ = ["CharacterGenerator", "WordGenerator"]
-
-
- class CharacterGenerator(_CharacterGenerator):
-     """Implements a character image generation dataset
-
-     >>> from doctr.datasets import CharacterGenerator
-     >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
-     >>> img, target = ds[0]
-
-     Args:
-         vocab: vocabulary to take the character from
-         num_samples: number of samples that will be generated iterating over the dataset
-         cache_samples: whether generated images should be cached firsthand
-         font_family: font to use to generate the text images
-         img_transforms: composable transformations that will be applied to each image
-         sample_transforms: composable transformations that will be applied to both the image and the target
-     """
-
-     def __init__(self, *args, **kwargs) -> None:
-         super().__init__(*args, **kwargs)
-
-     @staticmethod
-     def collate_fn(samples):
-         images, targets = zip(*samples)
-         images = tf.stack(images, axis=0)
-
-         return images, tf.convert_to_tensor(targets)
-
-
- class WordGenerator(_WordGenerator):
-     """Implements a character image generation dataset
-
-     >>> from doctr.datasets import WordGenerator
-     >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
-     >>> img, target = ds[0]
-
-     Args:
-         vocab: vocabulary to take the character from
-         min_chars: minimum number of characters in a word
-         max_chars: maximum number of characters in a word
-         num_samples: number of samples that will be generated iterating over the dataset
-         cache_samples: whether generated images should be cached firsthand
-         font_family: font to use to generate the text images
-         img_transforms: composable transformations that will be applied to each image
-         sample_transforms: composable transformations that will be applied to both the image and the target
-     """
-
-     pass
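The PyTorch generators (`doctr/datasets/generator/pytorch.py`, still listed in the RECORD) expose the constructors shown in the docstrings above, so those doctest examples carry over unchanged; only the collate step moves to torch. A hedged sketch, assuming the PyTorch class also provides a `collate_fn`:

    from torch.utils.data import DataLoader
    from doctr.datasets import CharacterGenerator

    ds = CharacterGenerator(vocab="abcdef", num_samples=100)
    loader = DataLoader(ds, batch_size=8, collate_fn=ds.collate_fn)  # collate_fn assumed
    images, targets = next(iter(loader))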
doctr/datasets/loader.py DELETED
@@ -1,94 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import math
- from collections.abc import Callable
-
- import numpy as np
- import tensorflow as tf
-
- __all__ = ["DataLoader"]
-
-
- def default_collate(samples):
-     """Collate multiple elements into batches
-
-     Args:
-         samples: list of N tuples containing M elements
-
-     Returns:
-         tuple of M sequences containing N elements each
-     """
-     batch_data = zip(*samples)
-
-     tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data)
-
-     return tf_data
-
-
- class DataLoader:
-     """Implements a dataset wrapper for fast data loading
-
-     >>> from doctr.datasets import CORD, DataLoader
-     >>> train_set = CORD(train=True, download=True)
-     >>> train_loader = DataLoader(train_set, batch_size=32)
-     >>> train_iter = iter(train_loader)
-     >>> images, targets = next(train_iter)
-
-     Args:
-         dataset: the dataset
-         shuffle: whether the samples should be shuffled before passing it to the iterator
-         batch_size: number of elements in each batch
-         drop_last: if `True`, drops the last batch if it isn't full
-         collate_fn: function to merge samples into a batch
-     """
-
-     def __init__(
-         self,
-         dataset,
-         shuffle: bool = True,
-         batch_size: int = 1,
-         drop_last: bool = False,
-         collate_fn: Callable | None = None,
-     ) -> None:
-         self.dataset = dataset
-         self.shuffle = shuffle
-         self.batch_size = batch_size
-         nb = len(self.dataset) / batch_size
-         self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
-         if collate_fn is None:
-             self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
-         else:
-             self.collate_fn = collate_fn
-         self.reset()
-
-     def __len__(self) -> int:
-         return self.num_batches
-
-     def reset(self) -> None:
-         # Updates indices after each epoch
-         self._num_yielded = 0
-         self.indices = np.arange(len(self.dataset))
-         if self.shuffle is True:
-             np.random.shuffle(self.indices)
-
-     def __iter__(self):
-         self.reset()
-         return self
-
-     def __next__(self):
-         if self._num_yielded < self.num_batches:
-             # Get next indices
-             idx = self._num_yielded * self.batch_size
-             indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
-
-             samples = list(map(self.dataset.__getitem__, indices))
-
-             batch_data = self.collate_fn(samples)
-
-             self._num_yielded += 1
-             return batch_data
-         else:
-             raise StopIteration
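The removed class was a small re-implementation of batching for the TF backend: reshuffle indices on each `reset()`, slice them per batch, map `__getitem__`, then collate. On the PyTorch side the standard `torch.utils.data.DataLoader` covers the same contract, so the docstring example migrates roughly as follows (hedged sketch; assumes the dataset still exposes a `collate_fn`):

    from torch.utils.data import DataLoader
    from doctr.datasets import CORD

    train_set = CORD(train=True, download=True)
    train_loader = DataLoader(
        train_set,
        batch_size=32,
        shuffle=True,      # reshuffles indices each epoch, like reset() above
        drop_last=False,   # keep the final partial batch, mirroring the floor/ceil logic
        collate_fn=train_set.collate_fn,
    )
    images, targets = next(iter(train_loader))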
doctr/io/image/tensorflow.py DELETED
@@ -1,101 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-
- import numpy as np
- import tensorflow as tf
- from PIL import Image
- from tensorflow.keras.utils import img_to_array
-
- from doctr.utils.common_types import AbstractPath
-
- __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
-
-
- def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
-     """Convert a PIL Image to a TensorFlow tensor
-
-     Args:
-         pil_img: a PIL image
-         dtype: the output tensor data type
-
-     Returns:
-         decoded image as tensor
-     """
-     npy_img = img_to_array(pil_img)
-
-     return tensor_from_numpy(npy_img, dtype)
-
-
- def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
-     """Read an image file as a TensorFlow tensor
-
-     Args:
-         img_path: location of the image file
-         dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
-
-     Returns:
-         decoded image as a tensor
-     """
-     if dtype not in (tf.uint8, tf.float16, tf.float32):
-         raise ValueError("insupported value for dtype")
-
-     img = tf.io.read_file(img_path)
-     img = tf.image.decode_jpeg(img, channels=3)
-
-     if dtype != tf.uint8:
-         img = tf.image.convert_image_dtype(img, dtype=dtype)
-         img = tf.clip_by_value(img, 0, 1)
-
-     return img
-
-
- def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
-     """Read a byte stream as a TensorFlow tensor
-
-     Args:
-         img_content: bytes of a decoded image
-         dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
-
-     Returns:
-         decoded image as a tensor
-     """
-     if dtype not in (tf.uint8, tf.float16, tf.float32):
-         raise ValueError("insupported value for dtype")
-
-     img = tf.io.decode_image(img_content, channels=3)
-
-     if dtype != tf.uint8:
-         img = tf.image.convert_image_dtype(img, dtype=dtype)
-         img = tf.clip_by_value(img, 0, 1)
-
-     return img
-
-
- def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
-     """Read an image file as a TensorFlow tensor
-
-     Args:
-         npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
-         dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
-
-     Returns:
-         same image as a tensor of shape (H, W, C)
-     """
-     if dtype not in (tf.uint8, tf.float16, tf.float32):
-         raise ValueError("insupported value for dtype")
-
-     if dtype == tf.uint8:
-         img = tf.convert_to_tensor(npy_img, dtype=dtype)
-     else:
-         img = tf.image.convert_image_dtype(npy_img, dtype=dtype)
-         img = tf.clip_by_value(img, 0, 1)
-
-     return img
-
-
- def get_img_shape(img: tf.Tensor) -> tuple[int, int]:
-     """Get the shape of an image"""
-     return img.shape[:2]
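The same five helpers keep living in `doctr/io/image/pytorch.py` (still in the RECORD above) and are re-exported through `doctr.io`. A hedged sketch of the torch-side equivalents; the main behavioural difference to expect is channels-first output, and the file name is illustrative:

    import torch
    from doctr.io import decode_img_as_tensor, read_img_as_tensor

    img = read_img_as_tensor("sample.jpg", dtype=torch.float32)  # (C, H, W), floats scaled to [0, 1]
    raw = open("sample.jpg", "rb").read()
    same = decode_img_as_tensor(raw, dtype=torch.float32)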
doctr/models/classification/magc_resnet/tensorflow.py DELETED
@@ -1,196 +0,0 @@
- # Copyright (C) 2021-2025, Mindee.
-
- # This program is licensed under the Apache License 2.0.
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
- import math
- from copy import deepcopy
- from functools import partial
- from typing import Any
-
- import tensorflow as tf
- from tensorflow.keras import activations, layers
- from tensorflow.keras.models import Sequential
-
- from doctr.datasets import VOCABS
-
- from ...utils import _build_model
- from ..resnet.tensorflow import ResNet
-
- __all__ = ["magc_resnet31"]
-
-
- default_cfgs: dict[str, dict[str, Any]] = {
-     "magc_resnet31": {
-         "mean": (0.694, 0.695, 0.693),
-         "std": (0.299, 0.296, 0.301),
-         "input_shape": (32, 32, 3),
-         "classes": list(VOCABS["french"]),
-         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
-     },
- }
-
-
- class MAGC(layers.Layer):
-     """Implements the Multi-Aspect Global Context Attention, as described in
-     <https://arxiv.org/pdf/1910.02562.pdf>`_.
-
-     Args:
-         inplanes: input channels
-         headers: number of headers to split channels
-         attn_scale: if True, re-scale attention to counteract the variance distibutions
-         ratio: bottleneck ratio
-         **kwargs
-     """
-
-     def __init__(
-         self,
-         inplanes: int,
-         headers: int = 8,
-         attn_scale: bool = False,
-         ratio: float = 0.0625,  # bottleneck ratio of 1/16 as described in paper
-         **kwargs,
-     ) -> None:
-         super().__init__(**kwargs)
-
-         self.headers = headers  # h
-         self.inplanes = inplanes  # C
-         self.attn_scale = attn_scale
-         self.ratio = ratio
-         self.planes = int(inplanes * ratio)
-
-         self.single_header_inplanes = int(inplanes / headers)  # C / h
-
-         self.conv_mask = layers.Conv2D(filters=1, kernel_size=1, kernel_initializer=tf.initializers.he_normal())
-
-         self.transform = Sequential(
-             [
-                 layers.Conv2D(filters=self.planes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
-                 layers.LayerNormalization([1, 2, 3]),
-                 layers.ReLU(),
-                 layers.Conv2D(filters=self.inplanes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
-             ],
-             name="transform",
-         )
-
-     def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
-         b, h, w, c = (tf.shape(inputs)[i] for i in range(4))
-
-         # B, H, W, C -->> B*h, H, W, C/h
-         x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes))
-         x = tf.transpose(x, perm=(0, 3, 1, 2, 4))
-         x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes))
-
-         # Compute shorcut
-         shortcut = x
-         # B*h, 1, H*W, C/h
-         shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes))
-         # B*h, 1, C/h, H*W
-         shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2])
-
-         # Compute context mask
-         # B*h, H, W, 1
-         context_mask = self.conv_mask(x)
-         # B*h, 1, H*W, 1
-         context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1))
-         # scale variance
-         if self.attn_scale and self.headers > 1:
-             context_mask = context_mask / math.sqrt(self.single_header_inplanes)
-         # B*h, 1, H*W, 1
-         context_mask = activations.softmax(context_mask, axis=2)
-
-         # Compute context
-         # B*h, 1, C/h, 1
-         context = tf.matmul(shortcut, context_mask)
-         context = tf.reshape(context, shape=(b, 1, c, 1))
-         # B, 1, 1, C
-         context = tf.transpose(context, perm=(0, 1, 3, 2))
-         # Set shape to resolve shape when calling this module in the Sequential MAGCResnet
-         batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1]
-         context.set_shape([batch, 1, 1, chan])
-         return context
-
-     def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
-         # Context modeling: B, H, W, C -> B, 1, 1, C
-         context = self.context_modeling(inputs)
-         # Transform: B, 1, 1, C -> B, 1, 1, C
-         transformed = self.transform(context, **kwargs)
-         return inputs + transformed
-
-
- def _magc_resnet(
-     arch: str,
-     pretrained: bool,
-     num_blocks: list[int],
-     output_channels: list[int],
-     stage_downsample: list[bool],
-     stage_conv: list[bool],
-     stage_pooling: list[tuple[int, int] | None],
-     origin_stem: bool = True,
-     **kwargs: Any,
- ) -> ResNet:
-     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-     kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
-     kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-     _cfg = deepcopy(default_cfgs[arch])
-     _cfg["num_classes"] = kwargs["num_classes"]
-     _cfg["classes"] = kwargs["classes"]
-     _cfg["input_shape"] = kwargs["input_shape"]
-     kwargs.pop("classes")
-
-     # Build the model
-     model = ResNet(
-         num_blocks,
-         output_channels,
-         stage_downsample,
-         stage_conv,
-         stage_pooling,
-         origin_stem,
-         attn_module=partial(MAGC, headers=8, attn_scale=True),
-         cfg=_cfg,
-         **kwargs,
-     )
-     _build_model(model)
-
-     # Load pretrained parameters
-     if pretrained:
-         # The number of classes is not the same as the number of classes in the pretrained model =>
-         # skip the mismatching layers for fine tuning
-         model.from_pretrained(
-             default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
-         )
-
-     return model
-
-
- def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
-     """Resnet31 architecture with Multi-Aspect Global Context Attention as described in
-     `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition",
-     <https://arxiv.org/pdf/1910.02562.pdf>`_.
-
-     >>> import tensorflow as tf
-     >>> from doctr.models import magc_resnet31
-     >>> model = magc_resnet31(pretrained=False)
-     >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32)
-     >>> out = model(input_tensor)
-
-     Args:
-         pretrained: boolean, True if model is pretrained
-         **kwargs: keyword arguments of the ResNet architecture
-
-     Returns:
-         A feature extractor model
-     """
-     return _magc_resnet(
-         "magc_resnet31",
-         pretrained,
-         [1, 2, 5, 3],
-         [256, 256, 512, 512],
-         [False] * 4,
-         [True] * 4,
-         [(2, 2), (2, 1), None, None],
-         False,
-         stem_channels=128,
-         **kwargs,
-     )
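The PyTorch implementation of this backbone (`doctr/models/classification/magc_resnet/pytorch.py`, touched by only 3 changed lines in this release) remains the supported path. A hedged counterpart of the removed doctest, with channels-first input and the 32×32 shape from the config above:

    import torch
    from doctr.models import magc_resnet31

    model = magc_resnet31(pretrained=False)
    input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    out = model(input_tensor)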