pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (306)
  1. pyxllib/algo/geo.py +12 -0
  2. pyxllib/algo/intervals.py +1 -1
  3. pyxllib/algo/matcher.py +78 -0
  4. pyxllib/algo/pupil.py +187 -19
  5. pyxllib/algo/specialist.py +2 -1
  6. pyxllib/algo/stat.py +38 -2
  7. {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
  8. pyxllib/autogui/activewin.py +246 -0
  9. pyxllib/autogui/all.py +9 -0
  10. pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
  11. pyxllib/autogui/uiautolib.py +362 -0
  12. pyxllib/autogui/wechat.py +827 -0
  13. pyxllib/autogui/wechat_msg.py +421 -0
  14. pyxllib/autogui/wxautolib.py +84 -0
  15. pyxllib/cv/slidercaptcha.py +137 -0
  16. pyxllib/data/echarts.py +123 -12
  17. pyxllib/data/jsonlib.py +89 -0
  18. pyxllib/data/pglib.py +514 -30
  19. pyxllib/data/sqlite.py +231 -4
  20. pyxllib/ext/JLineViewer.py +14 -1
  21. pyxllib/ext/drissionlib.py +277 -0
  22. pyxllib/ext/kq5034lib.py +0 -1594
  23. pyxllib/ext/robustprocfile.py +497 -0
  24. pyxllib/ext/unixlib.py +6 -5
  25. pyxllib/ext/utools.py +108 -95
  26. pyxllib/ext/webhook.py +32 -14
  27. pyxllib/ext/wjxlib.py +88 -0
  28. pyxllib/ext/wpsapi.py +124 -0
  29. pyxllib/ext/xlwork.py +9 -0
  30. pyxllib/ext/yuquelib.py +1003 -71
  31. pyxllib/file/docxlib.py +1 -1
  32. pyxllib/file/libreoffice.py +165 -0
  33. pyxllib/file/movielib.py +9 -0
  34. pyxllib/file/packlib/__init__.py +112 -75
  35. pyxllib/file/pdflib.py +1 -1
  36. pyxllib/file/pupil.py +1 -1
  37. pyxllib/file/specialist/dirlib.py +1 -1
  38. pyxllib/file/specialist/download.py +10 -3
  39. pyxllib/file/specialist/filelib.py +266 -55
  40. pyxllib/file/xlsxlib.py +205 -50
  41. pyxllib/file/xlsyncfile.py +341 -0
  42. pyxllib/prog/cachetools.py +64 -0
  43. pyxllib/prog/filelock.py +42 -0
  44. pyxllib/prog/multiprogs.py +940 -0
  45. pyxllib/prog/newbie.py +9 -2
  46. pyxllib/prog/pupil.py +129 -60
  47. pyxllib/prog/specialist/__init__.py +176 -2
  48. pyxllib/prog/specialist/bc.py +5 -2
  49. pyxllib/prog/specialist/browser.py +11 -2
  50. pyxllib/prog/specialist/datetime.py +68 -0
  51. pyxllib/prog/specialist/tictoc.py +12 -13
  52. pyxllib/prog/specialist/xllog.py +5 -5
  53. pyxllib/prog/xlosenv.py +7 -0
  54. pyxllib/text/airscript.js +744 -0
  55. pyxllib/text/charclasslib.py +17 -5
  56. pyxllib/text/jiebalib.py +6 -3
  57. pyxllib/text/jinjalib.py +32 -0
  58. pyxllib/text/jsa_ai_prompt.md +271 -0
  59. pyxllib/text/jscode.py +159 -4
  60. pyxllib/text/nestenv.py +1 -1
  61. pyxllib/text/newbie.py +12 -0
  62. pyxllib/text/pupil/common.py +26 -0
  63. pyxllib/text/specialist/ptag.py +2 -2
  64. pyxllib/text/templates/echart_base.html +11 -0
  65. pyxllib/text/templates/highlight_code.html +17 -0
  66. pyxllib/text/templates/latex_editor.html +103 -0
  67. pyxllib/text/xmllib.py +76 -14
  68. pyxllib/xl.py +2 -1
  69. pyxllib-0.3.197.dist-info/METADATA +48 -0
  70. pyxllib-0.3.197.dist-info/RECORD +126 -0
  71. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
  72. pyxllib/ext/autogui/__init__.py +0 -8
  73. pyxllib-0.3.96.dist-info/METADATA +0 -51
  74. pyxllib-0.3.96.dist-info/RECORD +0 -333
  75. pyxllib-0.3.96.dist-info/top_level.txt +0 -2
  76. pyxlpr/ai/__init__.py +0 -5
  77. pyxlpr/ai/clientlib.py +0 -1281
  78. pyxlpr/ai/specialist.py +0 -286
  79. pyxlpr/ai/torch_app.py +0 -172
  80. pyxlpr/ai/xlpaddle.py +0 -655
  81. pyxlpr/ai/xltorch.py +0 -705
  82. pyxlpr/data/__init__.py +0 -11
  83. pyxlpr/data/coco.py +0 -1325
  84. pyxlpr/data/datacls.py +0 -365
  85. pyxlpr/data/datasets.py +0 -200
  86. pyxlpr/data/gptlib.py +0 -1291
  87. pyxlpr/data/icdar/__init__.py +0 -96
  88. pyxlpr/data/icdar/deteval.py +0 -377
  89. pyxlpr/data/icdar/icdar2013.py +0 -341
  90. pyxlpr/data/icdar/iou.py +0 -340
  91. pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
  92. pyxlpr/data/imtextline.py +0 -473
  93. pyxlpr/data/labelme.py +0 -866
  94. pyxlpr/data/removeline.py +0 -179
  95. pyxlpr/data/specialist.py +0 -57
  96. pyxlpr/eval/__init__.py +0 -85
  97. pyxlpr/paddleocr.py +0 -776
  98. pyxlpr/ppocr/__init__.py +0 -15
  99. pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
  100. pyxlpr/ppocr/data/__init__.py +0 -135
  101. pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
  102. pyxlpr/ppocr/data/imaug/__init__.py +0 -67
  103. pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
  104. pyxlpr/ppocr/data/imaug/east_process.py +0 -437
  105. pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
  106. pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
  107. pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
  108. pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
  109. pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
  110. pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
  111. pyxlpr/ppocr/data/imaug/operators.py +0 -433
  112. pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
  113. pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
  114. pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
  115. pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
  116. pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
  117. pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
  118. pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
  119. pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
  120. pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
  121. pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
  122. pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
  123. pyxlpr/ppocr/data/simple_dataset.py +0 -372
  124. pyxlpr/ppocr/losses/__init__.py +0 -61
  125. pyxlpr/ppocr/losses/ace_loss.py +0 -52
  126. pyxlpr/ppocr/losses/basic_loss.py +0 -135
  127. pyxlpr/ppocr/losses/center_loss.py +0 -88
  128. pyxlpr/ppocr/losses/cls_loss.py +0 -30
  129. pyxlpr/ppocr/losses/combined_loss.py +0 -67
  130. pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
  131. pyxlpr/ppocr/losses/det_db_loss.py +0 -80
  132. pyxlpr/ppocr/losses/det_east_loss.py +0 -63
  133. pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
  134. pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
  135. pyxlpr/ppocr/losses/distillation_loss.py +0 -272
  136. pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
  137. pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
  138. pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
  139. pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
  140. pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
  141. pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
  142. pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
  143. pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
  144. pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
  145. pyxlpr/ppocr/losses/table_att_loss.py +0 -109
  146. pyxlpr/ppocr/metrics/__init__.py +0 -44
  147. pyxlpr/ppocr/metrics/cls_metric.py +0 -45
  148. pyxlpr/ppocr/metrics/det_metric.py +0 -82
  149. pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
  150. pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
  151. pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
  152. pyxlpr/ppocr/metrics/kie_metric.py +0 -70
  153. pyxlpr/ppocr/metrics/rec_metric.py +0 -75
  154. pyxlpr/ppocr/metrics/table_metric.py +0 -50
  155. pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
  156. pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
  157. pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
  158. pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
  159. pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
  160. pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
  161. pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
  162. pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
  163. pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
  164. pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
  165. pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
  166. pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
  167. pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
  168. pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
  169. pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
  170. pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
  171. pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
  172. pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
  173. pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
  174. pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
  175. pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
  176. pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
  177. pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
  178. pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
  179. pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
  180. pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
  181. pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
  182. pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
  183. pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
  184. pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
  185. pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
  186. pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
  187. pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
  188. pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
  189. pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
  190. pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
  191. pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
  192. pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
  193. pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
  194. pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
  195. pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
  196. pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
  197. pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
  198. pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
  199. pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
  200. pyxlpr/ppocr/optimizer/__init__.py +0 -61
  201. pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
  202. pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
  203. pyxlpr/ppocr/optimizer/optimizer.py +0 -160
  204. pyxlpr/ppocr/optimizer/regularizer.py +0 -52
  205. pyxlpr/ppocr/postprocess/__init__.py +0 -55
  206. pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
  207. pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
  208. pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
  209. pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
  210. pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
  211. pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
  212. pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
  213. pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
  214. pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
  215. pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
  216. pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
  217. pyxlpr/ppocr/tools/__init__.py +0 -14
  218. pyxlpr/ppocr/tools/eval.py +0 -83
  219. pyxlpr/ppocr/tools/export_center.py +0 -77
  220. pyxlpr/ppocr/tools/export_model.py +0 -129
  221. pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
  222. pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
  223. pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
  224. pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
  225. pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
  226. pyxlpr/ppocr/tools/infer/utility.py +0 -629
  227. pyxlpr/ppocr/tools/infer_cls.py +0 -83
  228. pyxlpr/ppocr/tools/infer_det.py +0 -134
  229. pyxlpr/ppocr/tools/infer_e2e.py +0 -122
  230. pyxlpr/ppocr/tools/infer_kie.py +0 -153
  231. pyxlpr/ppocr/tools/infer_rec.py +0 -146
  232. pyxlpr/ppocr/tools/infer_table.py +0 -107
  233. pyxlpr/ppocr/tools/program.py +0 -596
  234. pyxlpr/ppocr/tools/test_hubserving.py +0 -117
  235. pyxlpr/ppocr/tools/train.py +0 -163
  236. pyxlpr/ppocr/tools/xlprog.py +0 -748
  237. pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
  238. pyxlpr/ppocr/utils/__init__.py +0 -24
  239. pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
  240. pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
  241. pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
  242. pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
  243. pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
  244. pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
  245. pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
  246. pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
  247. pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
  248. pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
  249. pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
  250. pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
  251. pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
  252. pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
  253. pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
  254. pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
  255. pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
  256. pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
  257. pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
  258. pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
  259. pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
  260. pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
  261. pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
  262. pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
  263. pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
  264. pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
  265. pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
  266. pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
  267. pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
  268. pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
  269. pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
  270. pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
  271. pyxlpr/ppocr/utils/dict90.txt +0 -90
  272. pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
  273. pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
  274. pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
  275. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
  276. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
  277. pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
  278. pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
  279. pyxlpr/ppocr/utils/en_dict.txt +0 -95
  280. pyxlpr/ppocr/utils/gen_label.py +0 -81
  281. pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
  282. pyxlpr/ppocr/utils/iou.py +0 -54
  283. pyxlpr/ppocr/utils/logging.py +0 -69
  284. pyxlpr/ppocr/utils/network.py +0 -84
  285. pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
  286. pyxlpr/ppocr/utils/profiler.py +0 -110
  287. pyxlpr/ppocr/utils/save_load.py +0 -150
  288. pyxlpr/ppocr/utils/stats.py +0 -72
  289. pyxlpr/ppocr/utils/utility.py +0 -80
  290. pyxlpr/ppstructure/__init__.py +0 -13
  291. pyxlpr/ppstructure/predict_system.py +0 -187
  292. pyxlpr/ppstructure/table/__init__.py +0 -13
  293. pyxlpr/ppstructure/table/eval_table.py +0 -72
  294. pyxlpr/ppstructure/table/matcher.py +0 -192
  295. pyxlpr/ppstructure/table/predict_structure.py +0 -136
  296. pyxlpr/ppstructure/table/predict_table.py +0 -221
  297. pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
  298. pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
  299. pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
  300. pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
  301. pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
  302. pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
  303. pyxlpr/ppstructure/utility.py +0 -71
  304. pyxlpr/xlai.py +0 -10
  305. /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
  306. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
@@ -1,826 +0,0 @@
1
- # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
- import paddle
17
- import copy
18
- from paddle import nn
19
- import paddle.nn.functional as F
20
- from paddle.nn import LayerList
21
- from paddle.nn.initializer import XavierNormal as xavier_uniform_
22
- from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
23
- import numpy as np
24
- from pyxlpr.ppocr.modeling.heads.multiheadAttention import MultiheadAttention
25
- from paddle.nn.initializer import Constant as constant_
26
- from paddle.nn.initializer import XavierNormal as xavier_normal_
27
-
28
# Shared constant initializers: zeros_ fills a parameter with 0.0, ones_ with 1.0.
zeros_, ones_ = constant_(value=0.0), constant_(value=1.0)
30
-
31
-
32
class Transformer(nn.Layer):
    """Transformer recognition head built on "Attention Is All You Need".

    Reference: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017.
    Attention is all you need. In Advances in Neural Information Processing
    Systems, pages 6000-6010. Users may modify the attributes as needed.

    Args:
        d_model: number of expected features in the encoder/decoder inputs
            (default=512).
        nhead: number of heads in the multi-head attention modules (default=8).
        num_encoder_layers: number of encoder layers; when 0 and no
            custom_encoder is given, no encoder is built (default=6).
        beam_size: beam width at inference; 0 selects greedy decoding
            (default=0).
        num_decoder_layers: number of decoder layers (default=6).
        dim_feedforward: dimension of the feed-forward sub-layers
            (default=1024).
        attention_dropout_rate: dropout rate passed to the attention layers
            (default=0.0).
        residual_dropout_rate: dropout rate passed to the encoder/decoder
            layers and the positional encoding (default=0.1).
        custom_encoder: optional encoder used instead of the built-in one
            (default=None).
        custom_decoder: optional decoder used instead of the built-in one
            (default=None).
        in_channels: not used by this class; accepted so configs that pass it
            still work (default=0).
        out_channels: size of the output vocabulary, excluding the extra
            index this class adds (default=0).
        scale_embedding: forwarded to ``Embeddings`` (default=True).
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 custom_encoder=None,
                 custom_decoder=None,
                 in_channels=0,
                 out_channels=0,
                 scale_embedding=True):
        super(Transformer, self).__init__()
        # The head predicts out_channels + 1 classes; index 0 is the
        # embedding's padding_idx.
        self.out_channels = out_channels + 1
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model, )
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            if num_encoder_layers > 0:
                encoder_layer = TransformerEncoderLayer(
                    d_model, nhead, dim_feedforward, attention_dropout_rate,
                    residual_dropout_rate)
                self.encoder = TransformerEncoder(encoder_layer,
                                                  num_encoder_layers)
            else:
                self.encoder = None

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, attention_dropout_rate,
                residual_dropout_rate)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)

        # Runs before tgt_word_prj exists, so the projection below keeps the
        # explicit normal(0, d_model**-0.5) initialisation.
        self._reset_parameters()
        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(
            d_model, self.out_channels, bias_attr=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(w0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Conv2D sublayers: normal-Xavier weights, zero bias."""
        if isinstance(m, nn.Conv2D):
            xavier_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward_train(self, src, tgt):
        """Teacher-forced decoding pass used during training.

        The last target position is dropped before decoding. Target positions
        equal to 0 are passed to the decoder as key padding.

        Returns:
            Logits from ``tgt_word_prj`` with the batch dimension first.
        """
        tgt = tgt[:, :-1]

        tgt_key_padding_mask = self.generate_padding_mask(tgt)
        tgt = self.embedding(tgt).transpose([1, 0, 2])
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])

        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            # No encoder: drop axis 2 and move the last axis to the front so
            # the memory is sequence-first.
            memory = src.squeeze(2).transpose([2, 0, 1])
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=None)
        output = output.transpose([1, 0, 2])
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, targets=None):
        """Dispatch to training, greedy or beam-search decoding.

        Args:
            src: features fed to the encoder (or used directly as memory when
                there is no encoder).
            targets: required in training mode. targets[0] holds the target
                token ids and targets[1] per-sample lengths (presumably label
                lengths - confirm against the data pipeline). The target is
                truncated to ``2 + max(targets[1])`` positions.

        Returns:
            Training: logits from ``forward_train``.
            Inference: ``forward_beam(src)`` if ``beam_size > 0``, otherwise
            ``forward_test(src)``.
        """
        if self.training:
            max_len = targets[1].max()
            tgt = targets[0][:, :2 + max_len]
            return self.forward_train(src, tgt)
        else:
            if self.beam_size > 0:
                return self.forward_beam(src)
            else:
                return self.forward_test(src)

    def forward_test(self, src):
        """Greedy autoregressive decoding.

        Starts every sequence with token id 2 and runs at most 24 steps. It
        stops early when the argmax prediction for every sample is token id 3.

        Returns:
            [dec_seq, dec_prob]: int64 token ids (including the leading 2) and
            float32 per-step max probabilities (leading 1.0).
        """
        bs = paddle.shape(src)[0]
        if self.encoder is not None:
            src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
        dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
        dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
        for _ in range(1, 25):
            dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                paddle.shape(dec_seq_embed)[0])
            output = self.decoder(
                dec_seq_embed,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None)
            dec_output = paddle.transpose(output, [1, 0, 2])
            # Only the newest position is needed to predict the next token.
            dec_output = dec_output[:, -1, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
            preds_idx = paddle.argmax(word_prob, axis=1)
            if paddle.equal_all(
                    preds_idx,
                    paddle.full(
                        paddle.shape(preds_idx), 3, dtype='int64')):
                break
            preds_prob = paddle.max(word_prob, axis=1)
            dec_seq = paddle.concat(
                [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
            dec_prob = paddle.concat(
                [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
        return [dec_seq, dec_prob]

    def forward_beam(self, images):
        """Beam-search decoding over one batch.

        NOTE(review): only one ``Beam`` instance is created (``range(1)``), so
        this path appears to assume a batch size of 1 - confirm with callers.

        Returns:
            [hyps, scores]: int64 hypotheses padded with token id 3 to length
            25, and the length-normalised beam score repeated 25 times per
            sample.
        """

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            """Map each instance index to its row position in the tensor."""
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            """Keep only the tensor rows that belong to active instances."""
            beamed_tensor_shape = paddle.shape(beamed_tensor)
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
                         beamed_tensor_shape[2])

            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            beamed_tensor = beamed_tensor.index_select(
                curr_active_inst_idx, axis=0)
            beamed_tensor = beamed_tensor.reshape(new_shape)

            return beamed_tensor

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            # n_bm is resolved from the enclosing scope at call time.
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            """Decode one step, update beams, return still-active indices."""

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(
                    paddle.shape(dec_seq)[0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=memory_key_padding_mask, )
                dec_output = paddle.transpose(dec_output, [1, 0, 2])
                dec_output = dec_output[:,
                                        -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
                word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[
                        inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     None)
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)
            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            """Return the n_best hypotheses and scores for every instance."""
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]
                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with paddle.no_grad():
            # -- Encode. The memory must be sequence-first (S, N, E), as in
            # forward_train/forward_test: paddle.tile below repeats along
            # axis 1, and collect_active_part transposes [1, 0, 2].
            if self.encoder is not None:
                src = self.positional_encoding(images.transpose([1, 0, 2]))
                src_enc = self.encoder(src)
            else:
                # [2, 0, 1] (not [0, 2, 1]) keeps this branch sequence-first,
                # consistent with the other decoding paths.
                src_enc = images.squeeze(2).transpose([2, 0, 1])

            n_bm = self.beam_size
            inst_dec_beams = [Beam(n_bm) for _ in range(1)]
            active_inst_idx_list = list(range(1))
            # Repeat data for beam search
            src_enc = paddle.tile(src_enc, [1, n_bm, 1])
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # -- Decode (at most 24 steps)
            for len_dec_seq in range(1, 25):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)
        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
                                                                1)
        result_hyp = []
        hyp_scores = []
        for bs_hyp, score in zip(batch_hyp, batch_scores):
            l = len(bs_hyp[0])
            # Pad the best hypothesis to 25 positions with token id 3.
            bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
            result_hyp.append(bs_hyp_pad)
            # Length-normalised score, repeated once per output position.
            score = float(score) / l
            hyp_score = [score for _ in range(25)]
            hyp_scores.append(hyp_score)
        return [
            paddle.to_tensor(
                np.array(result_hyp), dtype=paddle.int64),
            paddle.to_tensor(hyp_scores)
        ]

    def generate_square_subsequent_mask(self, sz):
        """Build an (sz, sz) causal mask.

        Entries above the diagonal are float('-inf') (masked). Entries on and
        below the diagonal are 0.0 (unmasked).
        """
        mask = paddle.zeros([sz, sz], dtype='float32')
        mask_inf = paddle.triu(
            paddle.full(
                shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)
        mask = mask + mask_inf
        return mask

    def generate_padding_mask(self, x):
        """Return a boolean mask that is True where x equals 0."""
        padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
        return padding_mask

    def _reset_parameters(self):
        """Re-initialise every parameter with more than one dimension.

        NOTE: the module-level name ``xavier_uniform_`` is bound to
        ``XavierNormal`` in this file's imports, so this is a normal-Xavier
        init despite the name.
        """
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
369
-
370
-
371
class TransformerEncoder(nn.Layer):
    """A stack of ``num_layers`` encoder layers applied in sequence.

    Args:
        encoder_layer: a TransformerEncoderLayer instance; passed to
            ``_get_clones`` together with num_layers to build the stack
            (required).
        num_layers: number of encoder layers in the stack (required).
    """

    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src):
        """Pass the input through each encoder layer in turn.

        Args:
            src: the sequence fed to the encoder (required).

        Returns:
            The output of the last layer. Every layer is called with
            ``src_mask=None`` and ``src_key_padding_mask=None``; this
            encoder applies no masks.
        """
        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output,
                                    src_mask=None,
                                    src_key_padding_mask=None)

        return output
399
-
400
-
401
class TransformerDecoder(nn.Layer):
    """A stack of ``num_layers`` decoder layers applied in sequence.

    Args:
        decoder_layer: a TransformerDecoderLayer instance; passed to
            ``_get_clones`` together with num_layers to build the stack
            (required).
        num_layers: number of decoder layers in the stack (required).
    """

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and masks) through each decoder layer in turn.

        Every mask argument is forwarded unchanged to every layer.

        Args:
            tgt: the sequence fed to the decoder (required).
            memory: the sequence from the last encoder layer (required).
            tgt_mask: mask for the tgt sequence (optional).
            memory_mask: mask for the memory sequence (optional).
            tgt_key_padding_mask: mask for the tgt keys per batch (optional).
            memory_key_padding_mask: mask for the memory keys per batch
                (optional).

        Returns:
            The output of the last decoder layer.
        """
        output = tgt
        for i in range(self.num_layers):
            output = self.layers[i](
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)

        return output
444
-
445
-
446
class TransformerEncoderLayer(nn.Layer):
    """Post-norm Transformer encoder layer: self-attention, then feed-forward.

    Based on the standard encoder layer of "Attention Is All You Need"
    (Vaswani et al., NeurIPS 2017, pp. 6000-6010). The position-wise
    feed-forward network is built from two 1x1 ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of attention heads (required).
        dim_feedforward: hidden width of the feed-forward network
            (default=2048).
        attention_dropout_rate: dropout used inside self-attention
            (default=0.0).
        residual_dropout_rate: dropout applied to each sub-layer output
            before its residual addition (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        # Feed-forward network expressed as two 1x1 convolutions.
        self.conv1 = Conv2D(in_channels=d_model,
                            out_channels=dim_feedforward,
                            kernel_size=(1, 1))
        self.conv2 = Conv2D(in_channels=dim_feedforward,
                            out_channels=d_model,
                            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
            src: the sequence fed to the encoder layer (required); its last
                axis must have size ``d_model``.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch
                (optional).
        """
        attended = self.self_attn(src,
                                  src,
                                  src,
                                  attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))

        ffn_out = self._feed_forward(src)
        return self.norm2(src + self.dropout2(ffn_out))

    def _feed_forward(self, x):
        """Apply ``conv2(relu(conv1(.)))`` over the last axis of a 3-D tensor.

        ``[d0, d1, d2]`` is permuted to ``[d1, d2, 1, d0]`` so the 1x1
        convolutions treat ``d2`` as channels, and the result is permuted
        back to ``[d0, d1, d2]``.
        """
        as_image = paddle.unsqueeze(paddle.transpose(x, [1, 2, 0]), 2)
        as_image = self.conv2(F.relu(self.conv1(as_image)))
        return paddle.transpose(paddle.squeeze(as_image, 2), [2, 0, 1])
513
-
514
-
515
class TransformerDecoderLayer(nn.Layer):
    """Post-norm Transformer decoder layer.

    Self-attention, then attention over the encoder memory, then a
    feed-forward network. Based on the standard decoder layer of
    "Attention Is All You Need" (Vaswani et al., NeurIPS 2017,
    pp. 6000-6010). The feed-forward network is built from two 1x1
    ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of attention heads (required).
        dim_feedforward: hidden width of the feed-forward network
            (default=2048).
        attention_dropout_rate: dropout used inside both attention blocks
            (default=0.0).
        residual_dropout_rate: dropout applied to each sub-layer output
            before its residual addition (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)
        self.multihead_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        # Feed-forward network expressed as two 1x1 convolutions.
        self.conv1 = Conv2D(in_channels=d_model,
                            out_channels=dim_feedforward,
                            kernel_size=(1, 1))
        self.conv2 = Conv2D(in_channels=dim_feedforward,
                            out_channels=d_model,
                            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and masks) through the decoder layer.

        Args:
            tgt: the sequence fed to the decoder layer (required); its last
                axis must have size ``d_model``.
            memory: the output sequence of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch
                (optional).
            memory_key_padding_mask: the mask for the memory keys per batch
                (optional).
        """
        self_attended = self.self_attn(tgt,
                                       tgt,
                                       tgt,
                                       attn_mask=tgt_mask,
                                       key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(self_attended))

        cross_attended = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attended))

        ffn_out = self._feed_forward(tgt)
        return self.norm3(tgt + self.dropout3(ffn_out))

    def _feed_forward(self, x):
        """Apply ``conv2(relu(conv1(.)))`` over the last axis of a 3-D tensor.

        ``[d0, d1, d2]`` is permuted to ``[d1, d2, 1, d0]`` so the 1x1
        convolutions treat ``d2`` as channels, and the result is permuted
        back to ``[d0, d1, d2]``.
        """
        as_image = paddle.unsqueeze(paddle.transpose(x, [1, 2, 0]), 2)
        as_image = self.conv2(F.relu(self.conv1(as_image)))
        return paddle.transpose(paddle.squeeze(as_image, 2), [2, 0, 1])
606
-
607
-
608
def _get_clones(module, N):
    """Return a ``LayerList`` of ``N`` independent deep copies of ``module``."""
    copies = [copy.deepcopy(module) for _ in range(N)]
    return LayerList(copies)
610
-
611
-
612
class PositionalEncoding(nn.Layer):
    r"""Add fixed sinusoidal position information to a sequence.

    The encodings have the same width as the embeddings, so the two can be
    summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PE}(pos, 2i) = \sin\left(pos / 10000^{2i / \text{dim}}\right)

        \text{PE}(pos, 2i + 1) = \cos\left(pos / 10000^{2i / \text{dim}}\right)

    where ``pos`` is the position in the sequence and ``i`` indexes the
    embedding dimension.

    Args:
        dropout: dropout probability applied to the output (required).
        dim: the embedding dimension (required).
        max_len: the maximum supported sequence length (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        # div_term has ceil(dim / 2) entries, one per even column.
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        # There are only dim // 2 odd columns (one fewer than div_term when
        # dim is odd); slicing is a no-op for even dim.
        pe[:, 1::2] = paddle.cos(position * div_term[:dim // 2])
        # [max_len, dim] -> [max_len, 1, dim] so a slice broadcasts over batch.
        pe = paddle.unsqueeze(pe, 0)
        pe = paddle.transpose(pe, [1, 0, 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder (required).

        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:paddle.shape(x)[0], :]
        return self.dropout(x)
656
-
657
-
658
class PositionalEncoding_2d(nn.Layer):
    r"""Add gated 2-D sinusoidal position information to a feature map.

    One sinusoidal table is shared by both spatial axes. Rows indexed by
    width position and by height position are each scaled per channel by a
    linear projection of the globally average-pooled input. They are then
    broadcast over the map and added to it, and the result is flattened into
    a sequence.

    .. math::
        \text{PE}(pos, 2i) = \sin\left(pos / 10000^{2i / \text{dim}}\right)

        \text{PE}(pos, 2i + 1) = \cos\left(pos / 10000^{2i / \text{dim}}\right)

    Args:
        dropout: dropout probability applied to the output (required).
        dim: channel dimension of the input feature map (required).
        max_len: maximum supported height and width (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=512)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        # div_term has ceil(dim / 2) entries, one per even column.
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        # There are only dim // 2 odd columns (one fewer than div_term when
        # dim is odd); slicing is a no-op for even dim.
        pe[:, 1::2] = paddle.cos(position * div_term[:dim // 2])
        # Stored as [max_len, 1, dim] so a slice broadcasts over the batch axis.
        pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        # Channel gates for the width / height encodings; weights start at 1.
        self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.)
        self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.)

    def forward(self, x):
        """Add the gated 2-D encoding to ``x`` and flatten it to a sequence.

        Args:
            x: feature map of shape ``[batch, dim, height, width]``. The
                channel axis must equal ``dim``, and height and width must not
                exceed ``max_len``.

        Returns:
            Tensor of shape ``[height * width, batch, dim]`` after dropout.
        """
        x_shape = paddle.shape(x)

        # Width: [W, 1, dim] * [1, B, dim] -> [W, B, dim] -> [B, dim, 1, W].
        w_pe = self.pe[:x_shape[-1], :]
        # Squeeze only the pooled spatial axes so batch size 1 keeps its axis.
        w1 = self.linear1(self.avg_pool_1(x).squeeze(axis=[2, 3])).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = paddle.transpose(w_pe, [1, 2, 0])
        w_pe = paddle.unsqueeze(w_pe, 2)

        # Height: [H, 1, dim] * [1, B, dim] -> [H, B, dim] -> [B, dim, H, 1].
        # Index the shape tensor itself; ``paddle.shape(x).shape`` is the
        # one-element shape of that tensor, so ``[-2]`` on it raises.
        h_pe = self.pe[:x_shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze(axis=[2, 3])).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = paddle.transpose(h_pe, [1, 2, 0])
        h_pe = paddle.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        # [B, dim, H, W] -> [B, dim, H * W] -> [H * W, B, dim]
        x = paddle.transpose(
            paddle.reshape(x,
                           [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1])

        return self.dropout(x)
725
-
726
-
727
class Embeddings(nn.Layer):
    """Token embedding lookup, optionally scaled by ``sqrt(d_model)``.

    Args:
        d_model: width of each embedding vector.
        vocab: number of rows in the embedding table.
        padding_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        scale_embedding: when truthy, looked-up vectors are multiplied by
            ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab, padding_idx, scale_embedding):
        super().__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        # Overwrite the whole table with samples from N(0, d_model ** -0.5).
        init_std = d_model**-0.5
        init_weights = np.random.normal(
            0.0, init_std, (vocab, d_model)).astype(np.float32)
        self.embedding.weight.set_value(init_weights)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
        looked_up = self.embedding(x)
        if not self.scale_embedding:
            return looked_up
        return looked_up * math.sqrt(self.d_model)
742
-
743
-
744
class Beam():
    """Book-keeping for beam search over a single sequence.

    Two token ids are hard-coded:

    - Id ``2`` seeds slot 0 of the first step and is prepended to every
      hypothesis returned by ``get_tentative_hypothesis``.
    - Id ``3`` ends the search: it is marked done once the top-of-beam entry
      emits it (treated as EOS in ``advance``).

    Args:
        size: the beam width.
        device: unused; accepted only for call-site compatibility.
    """

    def __init__(self, size, device=False):
        self.size = size
        self._done = False
        # The score for each hypothesis currently on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        # Snapshots of ``self.scores`` taken during ``advance``.
        self.all_scores = []
        # The backpointers at each time-step: which entry each token extends.
        self.prev_ks = []
        # The outputs at each time-step; step 0 holds id 2 in slot 0 only.
        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
        self.next_ys[0][0] = 2

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        """Whether the top-of-beam entry has emitted token id 3."""
        return self._done

    def advance(self, word_prob):
        """Extend the beam by one step and report whether the search finished.

        Args:
            word_prob: per-beam scores over the vocabulary; axis 1 is the
                vocabulary axis. The values are added to the running beam
                scores, so presumably log-probabilities (TODO: confirm with
                the caller).

        Returns:
            bool: ``True`` once the best entry's newest token is id 3.
        """
        num_words = word_prob.shape[1]

        if len(self.prev_ks) > 0:
            # Add each beam entry's running score to all of its candidates.
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            # First step: expand from row 0 only (the only slot seeded with
            # the start token).
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        # topk(k, axis=0, largest=True, sorted=True): descending order.
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
                                                        True)
        self.all_scores.append(self.scores)
        self.scores = best_scores

        # best_scores_id indexes the flattened (beam x word) array, so
        # recover which beam entry and which word each score came from.
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)

        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == 3:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        """Return the beam scores together with their indices ``0..size-1``.

        No sorting is performed here. After ``advance`` the scores come from
        a sorted (descending) ``topk``, so they are already ordered.
        """
        return self.scores, paddle.to_tensor(
            [i for i in range(int(self.scores.shape[0]))], dtype='int32')

    def get_the_best_score_and_idx(self):
        """Get the score and index of the best entry on the beam.

        The scores are kept in descending order (see ``advance``), so the
        best entry is at position 0. Position 1 would be the runner-up and
        does not exist when ``size == 1``.
        """
        scores, ids = self.sort_scores()
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        """Get the decoded sequence of every beam entry for the current step.

        Before any ``advance`` this is the initial token column, shape
        ``[size, 1]``. Afterwards each row is id 2 followed by the tokens
        recovered by ``get_hypothesis``.
        """
        if len(self.next_ys) == 1:
            dec_seq = self.next_ys[0].unsqueeze(1)
        else:
            _, keys = self.sort_scores()
            hyps = [self.get_hypothesis(k) for k in keys]
            hyps = [[2] + h for h in hyps]
            dec_seq = paddle.to_tensor(hyps, dtype='int64')
        return dec_seq

    def get_hypothesis(self, k):
        """Walk the backpointers from entry ``k`` to rebuild its token list.

        Returns:
            list: Python scalars in chronological order. The seed token from
            step 0 is not included.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            k = self.prev_ks[j][k]
        return [token.item() for token in reversed(hyp)]