pyxllib-0.3.96-py3-none-any.whl → pyxllib-0.3.200-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. pyxllib/__init__.py +21 -21
  2. pyxllib/algo/__init__.py +8 -8
  3. pyxllib/algo/disjoint.py +54 -54
  4. pyxllib/algo/geo.py +541 -529
  5. pyxllib/algo/intervals.py +964 -964
  6. pyxllib/algo/matcher.py +389 -311
  7. pyxllib/algo/newbie.py +166 -166
  8. pyxllib/algo/pupil.py +629 -461
  9. pyxllib/algo/shapelylib.py +67 -67
  10. pyxllib/algo/specialist.py +241 -240
  11. pyxllib/algo/stat.py +494 -458
  12. pyxllib/algo/treelib.py +149 -149
  13. pyxllib/algo/unitlib.py +66 -66
  14. {pyxlpr → pyxllib/autogui}/__init__.py +5 -5
  15. pyxllib/autogui/activewin.py +246 -0
  16. pyxllib/autogui/all.py +9 -0
  17. pyxllib/{ext/autogui → autogui}/autogui.py +852 -823
  18. pyxllib/autogui/uiautolib.py +362 -0
  19. pyxllib/{ext/autogui → autogui}/virtualkey.py +102 -102
  20. pyxllib/autogui/wechat.py +827 -0
  21. pyxllib/autogui/wechat_msg.py +421 -0
  22. pyxllib/autogui/wxautolib.py +84 -0
  23. pyxllib/cv/__init__.py +5 -5
  24. pyxllib/cv/expert.py +267 -267
  25. pyxllib/cv/imfile.py +159 -159
  26. pyxllib/cv/imhash.py +39 -39
  27. pyxllib/cv/pupil.py +9 -9
  28. pyxllib/cv/rgbfmt.py +1525 -1525
  29. pyxllib/cv/slidercaptcha.py +137 -0
  30. pyxllib/cv/trackbartools.py +251 -251
  31. pyxllib/cv/xlcvlib.py +1040 -1040
  32. pyxllib/cv/xlpillib.py +423 -423
  33. pyxllib/data/echarts.py +240 -129
  34. pyxllib/data/jsonlib.py +89 -0
  35. pyxllib/data/oss.py +72 -72
  36. pyxllib/data/pglib.py +1127 -643
  37. pyxllib/data/sqlite.py +568 -341
  38. pyxllib/data/sqllib.py +297 -297
  39. pyxllib/ext/JLineViewer.py +505 -492
  40. pyxllib/ext/__init__.py +6 -6
  41. pyxllib/ext/demolib.py +246 -246
  42. pyxllib/ext/drissionlib.py +277 -0
  43. pyxllib/ext/kq5034lib.py +12 -1606
  44. pyxllib/ext/old.py +663 -663
  45. pyxllib/ext/qt.py +449 -449
  46. pyxllib/ext/robustprocfile.py +497 -0
  47. pyxllib/ext/seleniumlib.py +76 -76
  48. pyxllib/ext/tk.py +173 -173
  49. pyxllib/ext/unixlib.py +827 -826
  50. pyxllib/ext/utools.py +351 -338
  51. pyxllib/ext/webhook.py +124 -101
  52. pyxllib/ext/win32lib.py +40 -40
  53. pyxllib/ext/wjxlib.py +88 -0
  54. pyxllib/ext/wpsapi.py +124 -0
  55. pyxllib/ext/xlwork.py +9 -0
  56. pyxllib/ext/yuquelib.py +1105 -173
  57. pyxllib/file/__init__.py +17 -17
  58. pyxllib/file/docxlib.py +761 -761
  59. pyxllib/file/gitlib.py +309 -309
  60. pyxllib/file/libreoffice.py +165 -0
  61. pyxllib/file/movielib.py +148 -139
  62. pyxllib/file/newbie.py +10 -10
  63. pyxllib/file/onenotelib.py +1469 -1469
  64. pyxllib/file/packlib/__init__.py +330 -293
  65. pyxllib/file/packlib/zipfile.py +2441 -2441
  66. pyxllib/file/pdflib.py +426 -426
  67. pyxllib/file/pupil.py +185 -185
  68. pyxllib/file/specialist/__init__.py +685 -685
  69. pyxllib/file/specialist/dirlib.py +799 -799
  70. pyxllib/file/specialist/download.py +193 -186
  71. pyxllib/file/specialist/filelib.py +2829 -2618
  72. pyxllib/file/xlsxlib.py +3131 -2976
  73. pyxllib/file/xlsyncfile.py +341 -0
  74. pyxllib/prog/__init__.py +5 -5
  75. pyxllib/prog/cachetools.py +64 -0
  76. pyxllib/prog/deprecatedlib.py +233 -233
  77. pyxllib/prog/filelock.py +42 -0
  78. pyxllib/prog/ipyexec.py +253 -253
  79. pyxllib/prog/multiprogs.py +940 -0
  80. pyxllib/prog/newbie.py +451 -444
  81. pyxllib/prog/pupil.py +1197 -1128
  82. pyxllib/prog/sitepackages.py +33 -33
  83. pyxllib/prog/specialist/__init__.py +391 -217
  84. pyxllib/prog/specialist/bc.py +203 -200
  85. pyxllib/prog/specialist/browser.py +497 -488
  86. pyxllib/prog/specialist/common.py +347 -347
  87. pyxllib/prog/specialist/datetime.py +199 -131
  88. pyxllib/prog/specialist/tictoc.py +240 -241
  89. pyxllib/prog/specialist/xllog.py +180 -180
  90. pyxllib/prog/xlosenv.py +108 -101
  91. pyxllib/stdlib/__init__.py +17 -17
  92. pyxllib/stdlib/tablepyxl/__init__.py +10 -10
  93. pyxllib/stdlib/tablepyxl/style.py +303 -303
  94. pyxllib/stdlib/tablepyxl/tablepyxl.py +130 -130
  95. pyxllib/text/__init__.py +8 -8
  96. pyxllib/text/ahocorasick.py +39 -39
  97. pyxllib/text/airscript.js +744 -0
  98. pyxllib/text/charclasslib.py +121 -109
  99. pyxllib/text/jiebalib.py +267 -264
  100. pyxllib/text/jinjalib.py +32 -0
  101. pyxllib/text/jsa_ai_prompt.md +271 -0
  102. pyxllib/text/jscode.py +922 -767
  103. pyxllib/text/latex/__init__.py +158 -158
  104. pyxllib/text/levenshtein.py +303 -303
  105. pyxllib/text/nestenv.py +1215 -1215
  106. pyxllib/text/newbie.py +300 -288
  107. pyxllib/text/pupil/__init__.py +8 -8
  108. pyxllib/text/pupil/common.py +1121 -1095
  109. pyxllib/text/pupil/xlalign.py +326 -326
  110. pyxllib/text/pycode.py +47 -47
  111. pyxllib/text/specialist/__init__.py +8 -8
  112. pyxllib/text/specialist/common.py +112 -112
  113. pyxllib/text/specialist/ptag.py +186 -186
  114. pyxllib/text/spellchecker.py +172 -172
  115. pyxllib/text/templates/echart_base.html +11 -0
  116. pyxllib/text/templates/highlight_code.html +17 -0
  117. pyxllib/text/templates/latex_editor.html +103 -0
  118. pyxllib/text/vbacode.py +17 -17
  119. pyxllib/text/xmllib.py +747 -685
  120. pyxllib/xl.py +42 -38
  121. pyxllib/xlcv.py +17 -17
  122. pyxllib-0.3.200.dist-info/METADATA +48 -0
  123. pyxllib-0.3.200.dist-info/RECORD +126 -0
  124. {pyxllib-0.3.96.dist-info → pyxllib-0.3.200.dist-info}/WHEEL +1 -2
  125. {pyxllib-0.3.96.dist-info → pyxllib-0.3.200.dist-info/licenses}/LICENSE +190 -190
  126. pyxllib/ext/autogui/__init__.py +0 -8
  127. pyxllib-0.3.96.dist-info/METADATA +0 -51
  128. pyxllib-0.3.96.dist-info/RECORD +0 -333
  129. pyxllib-0.3.96.dist-info/top_level.txt +0 -2
  130. pyxlpr/ai/__init__.py +0 -5
  131. pyxlpr/ai/clientlib.py +0 -1281
  132. pyxlpr/ai/specialist.py +0 -286
  133. pyxlpr/ai/torch_app.py +0 -172
  134. pyxlpr/ai/xlpaddle.py +0 -655
  135. pyxlpr/ai/xltorch.py +0 -705
  136. pyxlpr/data/__init__.py +0 -11
  137. pyxlpr/data/coco.py +0 -1325
  138. pyxlpr/data/datacls.py +0 -365
  139. pyxlpr/data/datasets.py +0 -200
  140. pyxlpr/data/gptlib.py +0 -1291
  141. pyxlpr/data/icdar/__init__.py +0 -96
  142. pyxlpr/data/icdar/deteval.py +0 -377
  143. pyxlpr/data/icdar/icdar2013.py +0 -341
  144. pyxlpr/data/icdar/iou.py +0 -340
  145. pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
  146. pyxlpr/data/imtextline.py +0 -473
  147. pyxlpr/data/labelme.py +0 -866
  148. pyxlpr/data/removeline.py +0 -179
  149. pyxlpr/data/specialist.py +0 -57
  150. pyxlpr/eval/__init__.py +0 -85
  151. pyxlpr/paddleocr.py +0 -776
  152. pyxlpr/ppocr/__init__.py +0 -15
  153. pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
  154. pyxlpr/ppocr/data/__init__.py +0 -135
  155. pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
  156. pyxlpr/ppocr/data/imaug/__init__.py +0 -67
  157. pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
  158. pyxlpr/ppocr/data/imaug/east_process.py +0 -437
  159. pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
  160. pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
  161. pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
  162. pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
  163. pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
  164. pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
  165. pyxlpr/ppocr/data/imaug/operators.py +0 -433
  166. pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
  167. pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
  168. pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
  169. pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
  170. pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
  171. pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
  172. pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
  173. pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
  174. pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
  175. pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
  176. pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
  177. pyxlpr/ppocr/data/simple_dataset.py +0 -372
  178. pyxlpr/ppocr/losses/__init__.py +0 -61
  179. pyxlpr/ppocr/losses/ace_loss.py +0 -52
  180. pyxlpr/ppocr/losses/basic_loss.py +0 -135
  181. pyxlpr/ppocr/losses/center_loss.py +0 -88
  182. pyxlpr/ppocr/losses/cls_loss.py +0 -30
  183. pyxlpr/ppocr/losses/combined_loss.py +0 -67
  184. pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
  185. pyxlpr/ppocr/losses/det_db_loss.py +0 -80
  186. pyxlpr/ppocr/losses/det_east_loss.py +0 -63
  187. pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
  188. pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
  189. pyxlpr/ppocr/losses/distillation_loss.py +0 -272
  190. pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
  191. pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
  192. pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
  193. pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
  194. pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
  195. pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
  196. pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
  197. pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
  198. pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
  199. pyxlpr/ppocr/losses/table_att_loss.py +0 -109
  200. pyxlpr/ppocr/metrics/__init__.py +0 -44
  201. pyxlpr/ppocr/metrics/cls_metric.py +0 -45
  202. pyxlpr/ppocr/metrics/det_metric.py +0 -82
  203. pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
  204. pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
  205. pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
  206. pyxlpr/ppocr/metrics/kie_metric.py +0 -70
  207. pyxlpr/ppocr/metrics/rec_metric.py +0 -75
  208. pyxlpr/ppocr/metrics/table_metric.py +0 -50
  209. pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
  210. pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
  211. pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
  212. pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
  213. pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
  214. pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
  215. pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
  216. pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
  217. pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
  218. pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
  219. pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
  220. pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
  221. pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
  222. pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
  223. pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
  224. pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
  225. pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
  226. pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
  227. pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
  228. pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
  229. pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
  230. pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
  231. pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
  232. pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
  233. pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
  234. pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
  235. pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
  236. pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
  237. pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
  238. pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
  239. pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
  240. pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
  241. pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
  242. pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
  243. pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
  244. pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
  245. pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
  246. pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
  247. pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
  248. pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
  249. pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
  250. pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
  251. pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
  252. pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
  253. pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
  254. pyxlpr/ppocr/optimizer/__init__.py +0 -61
  255. pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
  256. pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
  257. pyxlpr/ppocr/optimizer/optimizer.py +0 -160
  258. pyxlpr/ppocr/optimizer/regularizer.py +0 -52
  259. pyxlpr/ppocr/postprocess/__init__.py +0 -55
  260. pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
  261. pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
  262. pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
  263. pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
  264. pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
  265. pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
  266. pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
  267. pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
  268. pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
  269. pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
  270. pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
  271. pyxlpr/ppocr/tools/__init__.py +0 -14
  272. pyxlpr/ppocr/tools/eval.py +0 -83
  273. pyxlpr/ppocr/tools/export_center.py +0 -77
  274. pyxlpr/ppocr/tools/export_model.py +0 -129
  275. pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
  276. pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
  277. pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
  278. pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
  279. pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
  280. pyxlpr/ppocr/tools/infer/utility.py +0 -629
  281. pyxlpr/ppocr/tools/infer_cls.py +0 -83
  282. pyxlpr/ppocr/tools/infer_det.py +0 -134
  283. pyxlpr/ppocr/tools/infer_e2e.py +0 -122
  284. pyxlpr/ppocr/tools/infer_kie.py +0 -153
  285. pyxlpr/ppocr/tools/infer_rec.py +0 -146
  286. pyxlpr/ppocr/tools/infer_table.py +0 -107
  287. pyxlpr/ppocr/tools/program.py +0 -596
  288. pyxlpr/ppocr/tools/test_hubserving.py +0 -117
  289. pyxlpr/ppocr/tools/train.py +0 -163
  290. pyxlpr/ppocr/tools/xlprog.py +0 -748
  291. pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
  292. pyxlpr/ppocr/utils/__init__.py +0 -24
  293. pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
  294. pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
  295. pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
  296. pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
  297. pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
  298. pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
  299. pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
  300. pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
  301. pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
  302. pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
  303. pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
  304. pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
  305. pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
  306. pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
  307. pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
  308. pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
  309. pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
  310. pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
  311. pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
  312. pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
  313. pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
  314. pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
  315. pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
  316. pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
  317. pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
  318. pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
  319. pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
  320. pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
  321. pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
  322. pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
  323. pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
  324. pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
  325. pyxlpr/ppocr/utils/dict90.txt +0 -90
  326. pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
  327. pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
  328. pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
  329. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
  330. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
  331. pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
  332. pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
  333. pyxlpr/ppocr/utils/en_dict.txt +0 -95
  334. pyxlpr/ppocr/utils/gen_label.py +0 -81
  335. pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
  336. pyxlpr/ppocr/utils/iou.py +0 -54
  337. pyxlpr/ppocr/utils/logging.py +0 -69
  338. pyxlpr/ppocr/utils/network.py +0 -84
  339. pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
  340. pyxlpr/ppocr/utils/profiler.py +0 -110
  341. pyxlpr/ppocr/utils/save_load.py +0 -150
  342. pyxlpr/ppocr/utils/stats.py +0 -72
  343. pyxlpr/ppocr/utils/utility.py +0 -80
  344. pyxlpr/ppstructure/__init__.py +0 -13
  345. pyxlpr/ppstructure/predict_system.py +0 -187
  346. pyxlpr/ppstructure/table/__init__.py +0 -13
  347. pyxlpr/ppstructure/table/eval_table.py +0 -72
  348. pyxlpr/ppstructure/table/matcher.py +0 -192
  349. pyxlpr/ppstructure/table/predict_structure.py +0 -136
  350. pyxlpr/ppstructure/table/predict_table.py +0 -221
  351. pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
  352. pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
  353. pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
  354. pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
  355. pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
  356. pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
  357. pyxlpr/ppstructure/utility.py +0 -71
  358. pyxlpr/xlai.py +0 -10
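The list above reflects two structural changes in addition to ordinary edits: the pyxllib/ext/autogui modules now live in a top-level pyxllib/autogui package (joined by new automation modules such as uiautolib and wechat), and the entire pyxlpr namespace (the paddleocr wrapper, ppocr, ppstructure, and the ai/data helpers) is no longer shipped in the wheel. As a rough migration sketch, assuming the moved files keep their module names and that 0.3.200 provides no compatibility shims, imports would change along these lines:

# pyxllib 0.3.96 (old layout)
# from pyxllib.ext.autogui import autogui, virtualkey   # lived under pyxllib.ext
# from pyxlpr.data import gptlib                         # pyxlpr shipped in the same wheel

# pyxllib 0.3.200 (new layout), hypothetical equivalents based only on the paths listed above
from pyxllib.autogui import autogui, virtualkey

# The pyxlpr/ppocr/ppstructure OCR stack is gone entirely; pin pyxllib==0.3.96
# (or vendor the removed modules) if you still depend on it.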
pyxlpr/data/gptlib.py DELETED
@@ -1,1291 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- # @Author : 陈坤泽
- # @Email : 877362867@qq.com
- # @Date : 2023/07/13 14:26
-
- from pyxllib.prog.pupil import check_install_package
-
- # check_install_package('transformers', 'transformers')
-
- import ast
- from collections import OrderedDict
- from collections import Counter
- import contextlib
- import copy
- import datetime
- import heapq
- import html
- import json
- import math
- import random
- import re
- from urllib.parse import unquote
- import io
- import logging
- import warnings
-
- from jinja2 import Template
- from openpyxl import Workbook
- import pandas as pd
- import requests
- from tqdm import tqdm
-
- try:
-     from transformers import AutoTokenizer, GPT2TokenizerFast
- except ModuleNotFoundError:
-     pass
-
- from pyxllib.prog.pupil import OutputLogger
- from pyxllib.prog.specialist import browser, TicToc
- from pyxllib.algo.pupil import ValuesStat
- from pyxllib.file.specialist import XlPath, JsonlDataFile, JsonlDataDir, TwinDirs, ensure_localdir
- from pyxllib.file.xlsxlib import extract_workbook_summary
-
-
46
- def __1_生成提问数据():
47
- pass
48
-
49
-
50
- class Tokenizer:
51
- _tokenizer = None
52
-
53
- @classmethod
54
- def get_tokenizer(cls):
55
- """ 获取tokenizer,第一次调用时进行初始化 """
56
-
57
- if cls._tokenizer is None:
58
- # 根本没必要每次都尝试连接官网,本地有就不要老是sb的尝试连接huggingface
59
- # 而且官网连接也不稳,这里换成我自己的服务器中转
60
- # gpt2_dir = XlPath.tempdir() / 'huggingface_gpt2'
61
- # ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/huggingface_gpt2.zip')
62
- # Tokenizer._tokenizer = GPT2TokenizerFast.from_pretrained(gpt2_dir)
63
- # 240103周三21:23,hx给过的新评测模型
64
- gpt2_dir = XlPath.tempdir() / 'Atom-CL-SS'
65
- ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/Atom-CL-SS.zip')
66
- cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, trust_remote_code=True)
67
- return cls._tokenizer
68
-
69
- @classmethod
70
- def tokenize(cls, paragraph, max_length=500):
71
- """ 对段落进行tokenize
72
-
73
- :param str paragraph: 待分词的段落
74
- :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
75
- :return list: 分词后的列表
76
-
77
- >>> Tokenizer.tokenize('Hello, world! 汉字 123.14 35')
78
- ['Hello', ',', 'Ġworld', '!', 'Ġæ', '±', 'ī', 'åŃ', 'Ĺ', 'Ġ123', '.', '14', 'Ġ35']
79
- """
80
- tokenizer = cls.get_tokenizer()
81
-
82
- # 对段落进行切分
83
- paragraph_slices = [paragraph[i:i + max_length] for i in range(0, len(paragraph), max_length)]
84
-
85
- # 对每个切分的子段进行分词,并将结果拼接在一起
86
- tokens = []
87
- for slice in paragraph_slices:
88
- tokens += tokenizer.tokenize(slice)
89
-
90
- return tokens
91
-
92
- @classmethod
93
- def count_tokens(cls, paragraph, max_length=500):
94
- """ 获取段落的token数量
95
-
96
- :param str paragraph: 待分词的段落
97
- :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
98
- :return int: token的数量
99
-
100
- >>> Tokenizer.count_tokens('Hello, world!')
101
- 5
102
- """
103
- return len(cls.tokenize(paragraph, max_length))
104
-
105
-
106
- def print_statistics(data, indent_level=1):
107
- """ 计算字符串长度,并且计算关键的一些token数
108
-
109
- :param data: data应该是一个嵌套结构,表示会话与消息
110
- """
111
- fmts = ['g', '.0f', '.0f', 'd', 'd']
112
- stat_len = ValuesStat([len(str(x)) for x in data])
113
-
114
- indent = '\t' * indent_level
115
- print(f'{indent} {stat_len.summary(fmts)}')
116
-
117
-
118
- def check_conversation_lengths(all_texts, n_values=(4, 4),
119
- compute_tokens=False, ids=None):
120
- """ 分析会话长度 """
121
-
122
- # 0 预处理
123
- for i, texts in enumerate(all_texts):
124
- if isinstance(texts, str):
125
- all_texts[i] = [texts]
126
-
127
- # 如果没有提供ID,则使用默认的range(n)
128
- if ids is None:
129
- ids = list(range(len(all_texts)))
130
-
131
- # 处理n_values的重叠
132
- if sum(n_values) >= len(all_texts):
133
- n_values = [len(all_texts), 0] # 将所有数据视为最短数据,不再考虑最长数据
134
-
135
- # 1 消息长度统计
136
- fmts = [None, '.0f', '.0f', 'd', 'd']
137
- lengths = [len(t) for texts in all_texts for t in texts]
138
- print(f'1、消息长度统计 {ValuesStat(lengths).summary(fmts)}')
139
-
140
- # 2 每组会话消息数目
141
- ct = Counter(len(texts) for texts in all_texts)
142
- sorted_ct = {k: v for k, v in sorted(ct.items(), key=lambda x: x[0])}
143
- print(f'2、每组消息数目: {sorted_ct}')
144
-
145
- # 3 找出消息总长度最短和最长的会话
146
- total_lengths = [(i, sum(len(t) for t in texts)) for i, texts in enumerate(all_texts)]
147
- shortest_indices = [item[0] for item in heapq.nsmallest(n_values[0], total_lengths, key=lambda x: x[1])]
148
- longest_indices = [item[0] for item in heapq.nlargest(n_values[1], total_lengths, key=lambda x: x[1])]
149
- longest_indices = longest_indices[::-1] # 从小到大排序
150
-
151
- parts = []
152
- if shortest_indices:
153
- parts.append(', '.join(map(str, [ids[i] for i in shortest_indices])))
154
- if longest_indices:
155
- parts.append(', '.join(map(str, [ids[i] for i in longest_indices])))
156
- print(f'3、最短最长会话的id:', ', ..., '.join(parts))
157
-
158
- # 4 计算token
159
- if compute_tokens:
160
- # 4.1 代表性样本的tokens数
161
- s_texts = [' '.join([x for x in all_texts[i]]) for i in shortest_indices]
162
- l_texts = [' '.join([x for x in all_texts[i]]) for i in longest_indices]
163
-
164
- s_lens = [[len(x), Tokenizer.count_tokens(x)] for x in s_texts]
165
- l_lens = [[len(x), Tokenizer.count_tokens(x)] for x in l_texts]
166
-
167
- parts = []
168
- if s_lens:
169
- parts.append(', '.join(map(str, [x[1] for x in s_lens])))
170
- if l_lens:
171
- parts.append(', '.join(map(str, [x[1] for x in l_lens])))
172
- # 仅计算3中代表性样本
173
- print(f'4、tokens数量:', ', ..., '.join(parts))
174
-
175
- # 4.2 token的比率规律
176
- ratios = []
177
- for x in s_lens + l_lens:
178
- ratios.append(x[1] / x[0])
179
- fmts = [None, '.0%', '.0%', '.0%', '.0%']
180
- print(f'token/len比率统计 {ValuesStat(ratios).summary(fmts)}')
181
- # 比率越大,代表越接近中文场景,汉字越多,要注意len的控制不要让token某些场合超出长度
182
-
183
-
184
- def set_template(s, *args, **kwargs):
185
- """ todo 这个名字会不会太容易冲突了? """
186
- return Template(s.strip(), *args, **kwargs)
187
-
188
-
189
- def set_meta_template(s, meta_start='[[', meta_end=']]', **kwargs):
190
- """ 支持预先用某些格式渲染后,再返回标准渲染模板 """
191
- t = Template(s.strip(), variable_start_string=meta_start,
192
- variable_end_string=meta_end).render(**kwargs)
193
- return Template(t)
194
-
195
-
196
- class StyleParser:
197
- def __init__(self, text):
198
- # 使用正则表达式拆分文本,并获取权重和风格
199
- self.styles = []
200
- self.weights = []
201
- matches = re.findall(r'<风格变换\d+(\s+(\d+))?[^>]*>\s*(.*?)\s*(?=<风格变换|$)', text, re.DOTALL)
202
- for match in matches:
203
- self.styles.append(match[2])
204
- # 提取权重
205
- weight = match[1]
206
- if weight:
207
- self.weights.append(int(weight))
208
- else:
209
- self.weights.append(100) # 默认权重
210
-
211
- def random_pick(self):
212
- """ 随机选择一个风格,并返回其下标和内容
213
-
214
- :return tuple: (下标, 风格内容)
215
-
216
- >>> sp = StyleParser("...") # 按照之前的格式传入一个字符串
217
- >>> index, style = sp.random_pick() # 随机选择一个风格
218
- """
219
- index = random.choices(range(len(self.styles)), weights=self.weights, k=1)[0]
220
- return index, self.styles[index]
221
-
222
-
223
- class GptChatJsonl(JsonlDataFile):
224
- """ GPT问答批量执行脚本的jsonl生成、读取器 """
225
-
226
- def __init__(self, file=None, num_records=None, *, start_id=None):
227
- from datetime import datetime
228
-
229
- super().__init__(file, num_records)
230
- if start_id is None:
231
- # 230821周一02:02,原本只有日期标记,后面发现这样id很容易出现重复,还是加上小时分钟更不容易引起一些没必要的麻烦
232
- today = datetime.now().strftime("%Y%m%d%H%M")
233
- self.start_id = int(today + "000000")
234
- else:
235
- self.start_id = start_id
236
-
237
- def read_jsonl(self, file):
238
- """ 从一个文件加载数据
239
- """
240
- self.records = XlPath(file).read_jsonl()
241
- try:
242
- self.start_id = self.records[-1]['id']
243
- except KeyError:
244
- pass
245
-
246
- def split_and_add_prompt(self, text, max_word_length=None, prompt=None):
247
- """
248
- :param text: 要插入的文本(纯文本,不能是字典格式的 {'content': ...})
249
- :param max_word_length: 如果设置了该值,则会对输入内容进行长度控制,拆成多个片段
250
- 之前测试,全英文长度大概在32500,中文在5000内
251
- :param prompt: max_word_length开启时才会生效,每个part要附加的提示规则
252
- 可以写个字符串,默认每段都是开头加这个提示
253
- 也可以参考gen_prompt2,一些生成函数写法进行自定义
254
- :return:
255
- """
256
- # 0 如果没有输入max_word_length,就不用特地处理了
257
- if max_word_length is None:
258
- return [text]
259
-
260
- # 1 工具函数
261
- def gen_prompt1(n, i, text):
262
- """ 一共n条,当前第i条的,当前内容是text """
263
- if n == 1:
264
- return text
265
- if n > 1:
266
- if i == 0:
267
- return f'【注意,由于本次提问过长,这里拆分成{n}个片段分开输入,目前是第{1}个片段,你只需暂时回复"收到"即可】\n' + text
268
- elif i < n - 1:
269
- return f'【这是{n}个片段中的第{i + 1}个片段,先回复"收到"即可】\n' + text
270
- else:
271
- return f'【这是{n}个片段的最后一个片段,请开始回复内容】\n' + text
272
-
273
- def gen_prompt2(n, i, text):
274
- return prompt + text
275
-
276
- if prompt is None:
277
- gen_prompt = gen_prompt1
278
- elif isinstance(prompt, str):
279
- gen_prompt = gen_prompt2
280
- else: # callable
281
- gen_prompt = prompt
282
-
283
- # 2 拆分重拼接
284
- # 首先要重新调整max_word_length,在确定要拆分为几个片段的情况下,尽量保证这些片段之间的均匀性
285
- num = len(text) // max_word_length + 1
286
- max_word_length = math.ceil(len(text) / num)
287
-
288
- # 2.1 检查是否有超长单行文本,要提前切分成多行
289
- lines = text.rstrip().split('\n')
290
- new_lines = []
291
- for line in lines:
292
- if len(line) < max_word_length:
293
- new_lines.append(line)
294
- else: # 单行就已经爆限制长度的比较特别
295
- n = max_word_length - 10 # 将这个长文本按照n的长度再拆分成多个片段加入new_lines
296
- parts = [line[i:i + n] for i in range(0, len(line), n)]
297
- new_lines += parts
298
-
299
- # 2.2 拼接new_lines
300
- fragments = []
301
- current_fragment = []
302
- current_fragment_total_length = 0
303
- for line in new_lines:
304
- if current_fragment_total_length + len(line) <= max_word_length:
305
- current_fragment.append(line)
306
- current_fragment_total_length += len(line)
307
- else:
308
- fragments.append('\n'.join(current_fragment))
309
- current_fragment = [line]
310
- current_fragment_total_length = len(line)
311
- if current_fragment:
312
- fragments.append('\n'.join(current_fragment))
313
-
314
- n = len(fragments)
315
- fragments = [gen_prompt(n, i, x).strip() for i, x in enumerate(fragments)]
316
-
317
- for i, fragment in enumerate(fragments):
318
- fragment = {"content": fragment}
319
- fragments[i] = fragment
320
- return fragments
321
-
322
- def split_texts(self, texts, max_word_length=None, prompt=None):
323
- """ 长对话自动拆分成多轮对话 """
324
- new_texts = []
325
- for text in texts:
326
- pure_text = text['content']
327
- new_texts += self.split_and_add_prompt(pure_text, max_word_length=max_word_length, prompt=prompt)
328
- if 'file_paths' in text: # 如果有文件,自动放在最后一轮插入
329
- new_texts[-1]['file_paths'] = text['file_paths']
330
- return new_texts
331
-
332
- def add_record(self, texts, *, extra=None,
333
- record_id=0, max_word_length=None, prompt=None):
334
- """
335
- :param texts:
336
- str -> list[str],可以只输入一个str,默认一轮对话
337
- list[str] -> list[{'content': ..., 'file_paths': [...]}]
338
- content: 文本内容
339
- file_paths: 注意可以设置本地电脑其他来源,会自动移到该任务的upload_files里
340
- :param record_id: 可以自定义这个session的id
341
- :param max_word_length: 是否设置一个约束长度,自动切分会话中太长的消息
342
- gpt4是8192个token,大概len就是8192/0.6=13653,一般建议如果要设就设10000左右
343
- :param prompt: 自动分段后
344
- None,自动配置的一套提示
345
- '', 不用提示
346
- :return:
347
- """
348
- # 1 变成标准的list + 字典结构,方便后面统一处理
349
- if not isinstance(texts, list):
350
- texts = [texts]
351
-
352
- for i, text in enumerate(texts):
353
- if isinstance(text, str):
354
- texts[i] = {'content': text}
355
-
356
- # 2 如果设置了每次最大会话长度,要进行拆分
357
- if max_word_length:
358
- texts = self.split_texts(texts, max_word_length=max_word_length, prompt=prompt)
359
-
360
- for i, text in enumerate(texts):
361
- texts[i]['content'] = text['content'].strip()
362
-
363
- # 3 添加会话conversation
364
- self.start_id += 1
365
- item = {'id': str(record_id or self.start_id), # 要转成字符串类型,不然容易出问题
366
- 'text': texts,
367
- 'first_text_length': len(texts[0]['content'])}
368
- if extra:
369
- item['extra'] = extra
370
- self.records.append(item)
371
- return item
372
-
373
- def fix_file_paths(self, save_dir):
374
- """ 修正records中设置的file_paths
375
-
376
- 这些路径可能在设置的时候图方便,设置的是非项目目录下的路径
377
- 这个函数会对这些路径进行修正,为了修正,需要输入一个该jsonl所保存的目录位置
378
- """
379
- save_dir = XlPath(save_dir)
380
- for i, record in tqdm(enumerate(self.records), desc='修复文件路径'):
381
- dst_dir = save_dir / 'upload_files' / str(record['id'])
382
- for j, text in enumerate(record['text']):
383
- for k, fn in enumerate(text.get('file_paths', [])):
384
- src_file = XlPath(fn)
385
- src_file2 = src_file.as_posix()
386
- if src_file2.startswith(f'upload_files/{record["id"]}/'):
387
- continue
388
- dst_file = dst_dir / src_file.name
389
- dst_file2 = dst_file.relpath(save_dir).as_posix()
390
- if src_file.is_file():
391
- if src_file2 != dst_file2:
392
- dst_dir.mkdir(parents=True, exist_ok=True)
393
- src_file.copy(dst_file, if_exists='replace')
394
- else: # 既然设置了,原文件目录应该在
395
- raise FileNotFoundError(f'{src_file}')
396
- text['file_paths'][k] = dst_file2
397
-
398
- def clean_file_paths(self):
399
- """ 清除records中的file_paths
400
- 一般用于把一些相关文件移到对应会话后,实际提问gpt的时候并不上传文件
401
- """
402
- for x in self.records:
403
- for t in x['text']:
404
- if 'file_paths' in t:
405
- del t['file_paths']
406
-
407
- def find_indices_by_qlength(self):
408
- """ 返回提问(q,question)内容从短到长的数据下标 """
409
- lens = [(i, len(''.join([t['content'] for t in x['text']]))) for i, x in enumerate(self.records)]
410
- # 根据长度进行排序,得到的元组的第一个元素为原列表的下标,第二个元素为对应的长度
411
- sorted_lens = sorted(lens, key=lambda x: x[1])
412
- # 取出排序后的下标
413
- sorted_indexs = [i for i, _ in sorted_lens]
414
- return sorted_indexs
415
-
416
- def browse_record(self, index=None, paths=None, **kwargs):
417
- """ 检查第i次会话的内容
418
- """
419
- # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
420
- if index is None:
421
- index = self.find_index(paths, **kwargs)
422
- if index is None:
423
- raise ValueError('No matching record found')
424
- session = self.records[index]
425
-
426
- # 构建HTML内容
427
- html_content = "<html><body>"
428
-
429
- # 输出除了text和all_answers以外的所有键值信息
430
- html_content += "<h2>会话信息:</h2>"
431
- html_content += "<ul>"
432
- for key, value in session.items():
433
- if key not in ["text", "all_answers"]:
434
- html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
435
- html_content += "</ul>"
436
-
437
- # 输出text和all_answers的内容
438
- texts = self.get_text_texts(session.get("text", []))
439
- all_answers = self.get_all_answers_texts(session.get("all_answers", []))
440
-
441
- max_length = max(len(texts), len(all_answers))
442
- for idx in range(max_length):
443
- html_content += f"<h3>第{idx + 1}次询问:</h3>"
444
- if idx < len(texts):
445
- html_content += f"<pre>{html.escape(texts[idx])}</pre>"
446
- if idx < len(all_answers):
447
- html_content += f"<h3>第{idx + 1}次回答:</h3>"
448
- html_content += f"<pre>{html.escape(str(all_answers[idx]))}</pre>"
449
-
450
- html_content += "</body></html>"
451
- html_file = (XlPath.tempdir() / (str(session.get('id', index)) + '.html'))
452
- html_file.write_text(html_content)
453
- browser.html(html_file)
454
-
455
- # 返回HTML字符串
456
- return html_content
457
-
458
- def get_text_texts(self, text):
459
- """ 从text字段获得所有的文本内容
460
- 因为里面可能是dict
461
- """
462
- ls = []
463
- for t in text:
464
- if isinstance(t, str):
465
- ls.append(t)
466
- else:
467
- if "file_path" in t:
468
- ls.append(("filep_path=" + str(t["file_path"]) + "\n\n") + t["content"])
469
- else:
470
- ls.append(t["content"])
471
- return ls
472
-
473
- def get_all_answers_texts(self, all_answers):
474
- ls = []
475
- for t in all_answers:
476
- if isinstance(t, dict):
477
- t = json.dumps(t, ensure_ascii=False, indent=2)
478
- ls.append(str(t))
479
- return ls
480
-
481
- def check(self):
482
- """ 检查会话、消息长度等信息 """
483
- # 1 提问的内容
484
- all_texts = [self.get_text_texts(session.get('text', []))
485
- for session in self.records]
486
- print('【提问的内容】')
487
- check_conversation_lengths(all_texts,
488
- compute_tokens=True,
489
- ids=[x['id'] for x in self.records])
490
-
491
- # 2 回复的内容
492
- all_texts = [self.get_all_answers_texts(session.get('all_answers', []))
493
- for session in self.records]
494
- # 过滤空值,并相应地更新ids
495
- filtered_texts = [(text, session['id']) for text, session in zip(all_texts, self.records) if text]
496
- all_texts, ids = zip(*filtered_texts) if filtered_texts else ([], [])
497
- if all_texts:
498
- print('【回复的内容】')
499
- check_conversation_lengths(all_texts,
500
- compute_tokens=True,
501
- ids=ids)
502
-
503
- def filter_records_without_answers(self):
504
- """ 过滤掉没有 'all_answers' 字段的sessions """
505
-
506
- # 输出过滤前的sessions数量
507
- print(f"过滤前的sessions数量:{len(self.records)}")
508
-
509
- # 使用列表推导式过滤出包含 'all_answers' 字段的sessions
510
- self.records = [s for s in self.records
511
- if (''.join(map(str, s.get('all_answers', []))))]
512
-
513
- # 输出过滤后的sessions数量
514
- print(f"过滤后的sessions数量:{len(self.records)}")
515
-
516
- @classmethod
517
- def _parse_single_record_answer_contents(cls, record):
518
- """ 注意本函数不做record备份 """
519
- for answer in record.get('all_answers', []):
520
- if isinstance(answer, dict) and 'contents' in answer:
521
- n = len(answer['contents'])
522
- for i in range(n - 1, -1, -1):
523
- message = answer['contents'][i]['message']
524
- if message and 'content' in message and 'error' not in message:
525
- break
526
- else:
527
- answer['contents'] = ''
528
- continue
529
-
530
- content = message['content']
531
- if 'parts' in content:
532
- content = '\n'.join(content['parts'])
533
- else:
534
- content = content['text']
535
- answer['contents'] = content
536
-
537
- @classmethod
538
- def _parse_single_record_answer_downloads(cls, record):
539
- for answer in record.get('all_answers', []):
540
- if 'downloads' in answer:
541
- for i, link in enumerate(answer['downloads']):
542
- m = re.search(r'filename%3D(.+?)&sig=', link)
543
- if m:
544
- answer['downloads'][i] = unquote(unquote(m.group(1)))
545
-
546
- @classmethod
547
- def parse_single_record_answer(cls, record):
548
- cls._parse_single_record_answer_contents(record)
549
- cls._parse_single_record_answer_downloads(record)
550
-
551
- def parse_answer_contents(self):
552
- """ 简化解释器返回结果中,contents的结构信息 """
553
- for record in self.records:
554
- self._parse_single_record_answer_contents(record)
555
-
556
- def parse_answer_downloads(self):
557
- """ 解析,简化下载链接的表达形式 """
558
- for record in self.records:
559
- self._parse_single_record_answer_downloads(record)
560
-
561
- # 目录里的文件名也同理做精简
562
- for f in self.infile.parent.glob_files():
563
- if f.name.startswith('OpenAI-download-'):
564
- f.rename2(f.parent / re.sub(r'OpenAI-download-\d+-', '', f.name),
565
- if_exists='replace')
566
-
567
- def filter_to_rechat(self, check_func, rechat_path=None):
568
- """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
569
- 这个不是简单的找出得不到all_answers的,而是可以很精细,包含复杂post、verify的情况
570
-
571
- :param check_func: 一个函数,接收一个record,返回True或False
572
- True,表示这个record是对的
573
- False,表示这个record是错的,要挑选出来
574
- :param rechat_path: 把挑选出来的数据放到新路径
575
- """
576
- if rechat_path is None:
577
- rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat/in.jsonl')
578
-
579
- rechat_path = XlPath(rechat_path)
580
- td = TwinDirs(self.infile.parent, rechat_path.parent)
581
-
582
- gcj = type(self)()
583
- for record in self.records:
584
- if not check_func(record):
585
- record2 = {}
586
- for k in ['id', 'text', 'first_text_length', 'extra']:
587
- record2[k] = record[k]
588
- gcj.records.append(record2)
589
- for x in record['text']:
590
- if 'file_path' in x:
591
- td.copy_file(td.src_dir / x['file_path'])
592
-
593
- gcj.save(rechat_path)
594
- return gcj
595
-
596
- def update_from_rechat(self, check_func, rechat_path=None):
597
- """ 从另一份rechat的数据,更新回主master数据
598
-
599
- :param check_func: 原chatted没过,但是rechatted通过的,需要把数据更新过来
600
- :param rechat_path: 注意只能传路径,因为可能涉及到文件操作,需要知道目录所在
601
- 依据这个文件里的record记录更新回self
602
- """
603
- if rechat_path is None:
604
- rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat') / 'out.jsonl'
605
-
606
- rechat_path = XlPath(rechat_path)
607
- td = TwinDirs(rechat_path.parent, self.infile.parent)
608
-
609
- id2index = {x['id']: i for i, x in enumerate(self.records)}
610
-
611
- gcj = type(self)(rechat_path)
612
- gcj.parse_answer_contents()
613
- gcj.parse_answer_downloads()
614
-
615
- # 需要处理下下载链接名称
616
- self.parse_answer_downloads()
617
- gcj.parse_answer_downloads()
618
-
619
- for y in gcj.records:
620
- index = id2index[y['id']]
621
- x = self.records[index]
622
- if not check_func(x) and check_func(y):
623
- # 先把x相关的数据删掉
624
- if 'all_answers' in x:
625
- for answer in x['all_answers']:
626
- for fname in answer.get('downloads', []):
627
- (XlPath(self.infile.parent) / fname).delete()
628
- # 再把y拷贝过来
629
- for answer in y['all_answers']:
630
- for fname in answer.get('downloads', []):
631
- td.copy_file(td.src_dir / fname)
632
- self.records[index] = y
633
- return gcj
634
-
635
-
636
- GptQuestionJsonl = GptChatJsonl # 名称向下兼容
637
-
638
-
639
- def __2_数据后处理():
640
- """ 一些常用的文本、后处理功能也放到这里 """
641
-
642
-
643
- def try_eval_json(resp_json):
644
- try:
645
- resp_json = ast.literal_eval(resp_json)
646
- if isinstance(resp_json, dict):
647
- resp_json = resp_json[resp_json.keys()[0]]
648
- except:
649
- pass
650
- return resp_json
651
-
652
-
653
- def try_load_json(resp_json):
654
- if isinstance(resp_json, str):
655
- try:
656
- resp_json = json.loads(resp_json)
657
- if isinstance(resp_json, dict):
658
- resp_json = resp_json[resp_json.keys()[0]]
659
- except:
660
- pass
661
- return resp_json
662
-
663
-
664
- def try_parse_json(resp_json):
665
- if isinstance(resp_json, dict):
666
- try:
667
- resp_json = '\n'.join(resp_json['contents'][-1]['message']['content'].get('parts', []))
668
- except TypeError:
669
- return ''
670
-
671
- resp_json = try_eval_json(resp_json)
672
- if isinstance(resp_json, str):
673
- return try_load_json(resp_json)
674
- return resp_json
675
-
676
-
677
- def extract_code_blocks_from_md(markdown_text, *, sort_by_length=False):
678
- """ 可以输入str,也可以输入list[str]
679
-
680
- :param sort_by_length: 按代码长度从短到长排序
681
- 常用在比较确信有效代码段应该只有一段,但是有些短小的片段有干扰
682
- 此时可以排序后,选取最长的一个代码片段作为正确代码
683
- """
684
- if isinstance(markdown_text, str):
685
- markdown_text = [markdown_text]
686
-
687
- matches = []
688
- pattern = re.compile(r'^```[^\n]*\n(.+?)\n^```', re.MULTILINE | re.DOTALL)
689
- for text in markdown_text:
690
- matches += pattern.findall(text)
691
-
692
- if sort_by_length:
693
- matches = sorted(matches, key=len)
694
-
695
- return matches
696
-
697
-
698
- def extract_airscript_code_from_answers(all_answers):
699
- """ 从多轮回答的最后一次回答中提取求解代码 """
700
- contents = all_answers[-1]['contents']
701
- text = contents[-1]['text']
702
- code_blocks = extract_code_blocks_from_md(text, sort_by_length=True)
703
-
704
- if code_blocks:
705
- return code_blocks[-1]
706
- else:
707
- return ''
708
-
709
-
710
- def merge_answers_contents(answers):
711
- """ 对一组answers结果中,相同type的contents进行合并 """
712
- for answer in answers:
713
- contents = []
714
- for content in answer['contents']:
715
- if len(contents) == 0:
716
- contents.append(content)
717
- else:
718
- if contents[-1]['type'] == content['type']:
719
- contents[-1]['text'] += '\n' + content['text']
720
- else:
721
- contents.append(content)
722
- answer['contents'] = contents
723
-
724
-
725
- def refine_content_title(content, tag, dst_title=None):
726
- """ 将内容中的标题描述形式标准化
727
-
728
- :param tag: 原标题相关字符
729
- :param content: 文本内容
730
- :param dst_title: 目标标题格式
731
- :return: 处理后的字符串
732
- """
733
- if dst_title is None:
734
- dst_title = f'<{tag}>'
735
- content_lines = content.splitlines()
736
- chars_str = re.compile(tag.replace(':', '[:的]?'))
737
- chinese_chars = re.compile(r'[\u4e00-\u9fa5]')
738
-
739
- res = []
740
- for line in content_lines:
741
- # 使用正则表达式查找匹配的部分
742
- new_line = chars_str.sub('', line)
743
- if new_line != line and not chinese_chars.search(new_line):
744
- res.append(dst_title)
745
- else:
746
- # 如果不满足条件,不进行替换
747
- res.append(line)
748
- return '\n'.join(res)
749
-
750
-
751
- def refine_block_name(record, block_names, preproc=None):
752
- """ 优化模块的标题名,方便后续结构化提取数据
753
-
754
- 感觉这个系列解析是比较通用的,就放在标准库中
755
- """
756
- # if preproc is None:
757
- # def preproc(x):
758
- # return x
759
-
760
- for answer in record['all_answers']:
761
- for content in answer['contents']:
762
- if content['type'] == 'text':
763
- text = old_text = content['text']
764
- if preproc is not None:
765
- text = preproc(text)
766
-
767
- for block_name in block_names:
768
- text = refine_content_title(text, block_name)
769
- text = refine_content_title(text, '---', '')
770
- # 一般不要直接修改原数据,但post里会有备份,所以这里verify可以直接修改了
771
- # if 'answer' not in curr_record['extra']:
772
- # curr_record['extra']['answer'] = []
773
- # curr_record['extra']['answer'].append(text)
774
- content['text'] = text
775
- # 可以借助bc调试
776
- # bcompare(old_text, text)
777
-
778
-
779
- def extract_block_content(record, block_name):
780
- """ 从record的all_answers中,从后往前检索 <block_name> 的内容,
781
- 返回第一个匹配结果,如果找不到则返回空字符串
782
- """
783
- for answer in record['all_answers'][::-1]:
784
- for content in answer['contents'][::-1]:
785
- if content['type'] == 'text':
786
- matches = list(re.finditer(rf'^<{block_name}>\n((.|\n)+?)(?=^<.+?>\n)',
787
- content['text'] + '\n<test>\n', # 末尾补一个<test>,方便对齐
788
- flags=re.MULTILINE))
789
- if matches:
790
- s = matches[-1].group(1).strip()
791
- blocks = extract_code_blocks_from_md(s, sort_by_length=True)
792
- if blocks:
793
- return blocks[-1]
794
- if s:
795
- return s
796
- return '' # 提取不到
797
-
798
-
799
- def __3_生成最后训练用的数据():
800
- pass
801
-
802
-
803
- def texts2train_record(texts):
804
- """ user和assistant的轮询对话,转为训练集格式 """
805
- messages = []
806
- for i, text in enumerate(texts):
807
- role = 'assistant' if i % 2 else 'user'
808
- messages.append({'role': role, 'content': text})
809
- return {'messages': messages}
810
-
811
-
812
- class GptTrainJsonl(JsonlDataFile):
813
- """
814
- record: dict
815
- messages: list
816
- dict: role='user', content=...
817
- dict: role='assistant', content=...
818
- """
819
-
820
- def analyze_text_length(self):
821
- # 1 先将数据统计到df
822
- ls = []
823
- columns = ['role', 'content']
824
- for x in self.records:
825
- for t in x['messages']:
826
- ls.append([t['role'], t['content']])
827
- df = pd.DataFrame.from_records(ls, columns=columns)
828
-
829
- # 2 再从df筛选出不同的统计数据
830
- print('【user和assistant】')
831
- print_statistics(df['content'])
832
- print('【user】')
833
- print_statistics(df[df['role'] == 'user']['content'])
834
- print('【assistant】')
835
- print_statistics(df[df['role'] == 'assistant']['content'])
836
-
837
- def check(self):
838
- """ 检查会话、消息长度等信息 """
839
- # 1. 提取'user'角色的content
840
- user_texts = [[message['content']
841
- for message in record['messages']
842
- if message['role'] == 'user']
843
- for record in self.records]
844
- if not user_texts:
845
- print('空数据')
846
- return
847
-
848
- print('【User的内容】')
849
- check_conversation_lengths(user_texts, compute_tokens=True,
850
- # 因为一般是使用JLineViewer进行查看,跟那个软件对称使用1开始编号
851
- ids=list(range(1, len(user_texts) + 1)))
852
-
853
- # 2. 提取'assistant'角色的content
854
- assistant_texts = [[message['content']
855
- for message in record['messages']
856
- if message['role'] == 'assistant']
857
- for record in self.records]
858
- print('【Assistant的内容】')
859
- check_conversation_lengths(assistant_texts, compute_tokens=True,
860
- ids=list(range(1, len(assistant_texts) + 1)))
861
-
862
- # 3. 将整个record视为一个完整的会话
863
- full_conversations = [' '.join([message['content'] for message in record['messages']])
864
- for record in self.records]
865
- print('【完整的会话】')
866
- check_conversation_lengths(full_conversations, compute_tokens=True,
867
- ids=list(range(1, len(full_conversations) + 1)))
868
-
869
- def browse_record(self, index=None, paths=None, **kwargs):
870
- """ 显示第i次会话的内容 """
871
- # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
872
- if index is None:
873
- index = self.find_index(paths, **kwargs)
874
- if index is None:
875
- raise ValueError('No matching record found')
876
- session = self.records[index]
877
-
878
- # 构建HTML内容
879
- html_content = "<html><body>"
880
-
881
- # 输出除了messages以外的所有键值信息
882
- html_content += "<h2>会话信息:</h2>"
883
- html_content += "<ul>"
884
- for key, value in session.items():
885
- if key != "messages":
886
- html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
887
- html_content += "</ul>"
888
-
889
- # 输出messages的内容
890
- messages = session.get("messages", [])
891
-
892
- for idx, message in enumerate(messages):
893
- role = message.get('role', 'unknown')
894
- content = message.get('content', '')
895
- html_content += f"<h3>第{(idx // 2) + 1}次{role}的发言:</h3>"
896
- html_content += f"<pre>{html.escape(content)}</pre>"
897
-
898
- html_content += "</body></html>"
899
- html_file = (XlPath.tempdir() / (f'session_{index}.html')) # 创建临时文件名,防止覆盖现有文件
900
- html_file.write_text(html_content)
901
- browser.html(html_file) # 在浏览器中打开HTML文件
902
-
903
- # 或者返回HTML字符串
904
- return html_content
905
-
906
- def add_record(self, texts):
907
- messages = []
908
- for i, text in enumerate(texts):
909
- role = 'assistant' if i % 2 else 'user'
910
- messages.append({'role': role, 'content': text})
911
- self.records.append({'messages': messages})
912
-
913
- def add_from_texts(self, texts):
914
- record = texts2train_record(texts)
915
- self.records.append(record)
916
-
917
-
918
- def __4_综合集成类():
919
- pass
920
-
921
-
922
- class GptChatDir:
923
- """ 一个目录,包含了一个任务的所有数据,包括in、out、post等文件 """
924
-
925
- def __init__(self, root=None, lines_per_file=10000):
926
- if root is None:
927
- root = self.__class__.__name__.lower()
928
-
929
- self.root = root = XlPath(root)
930
- self.lines_per_file = lines_per_file
931
-
932
- self.chat_file = root / 'in.jsonl'
933
- self.chatted_file = root / 'out.jsonl'
934
- self.post_file = root / 'post.jsonl'
935
- self.verify_file = root / 'verify.jsonl'
936
- self.train_file = root / 'train.jsonl'
937
-
938
- # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
939
- self.update_dir()
940
-
941
- self.upload_files_dir = root / 'upload_files'
942
- self.download_files_dir = root / 'download_files'
943
-
944
- # todo 把 1chat 改名 in,2chatted 改名 out
945
- # for f in self.root.glob_files('*1chat*.jsonl'):
946
- # f.rename2(f.parent / 'in.jsonl')
947
-
948
- # for dir_path in [self.root, self.upload_files_dir, self.download_files_dir]:
949
- for dir_path in [self.root]:
950
- if not dir_path.is_dir():
951
- dir_path.mkdir(parents=True, exist_ok=True)
952
-
953
- # 这个类经常要并发处理,不能把一个不能序列化的类放到这里~
954
- # self.logger = OutputLogger(log_file=self.root / 'log.txt')
955
-
956
- def update_dir(self):
957
- """ 目录结构有些更新后,一些成员变量要跟着改变 """
958
- # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
959
- self.chat_dir = JsonlDataDir.init_from_file(self.chat_file, self.lines_per_file)
960
- self.chatted_dir = JsonlDataDir.init_from_file(self.chatted_file, self.lines_per_file)
961
- self.post_dir = JsonlDataDir.init_from_file(self.post_file, self.lines_per_file)
962
- self.verify_dir = JsonlDataDir.init_from_file(self.verify_file, self.lines_per_file)
963
- self.train_dir = JsonlDataDir.init_from_file(self.train_file, self.lines_per_file)
964
-
965
- def summary_records(self):
966
- """ 一些统计信息 """
967
- # 1 chat信息
968
- gcd1 = self.chatted_dir or self.chat_dir
969
- if not gcd1:
970
- print('请确认是否有生成初始的chat数据')
971
- return
972
-
973
- print(f'【{self.root.name}】')
974
- texts = [len(x['text']) for x in gcd1.yield_record()]
975
- n, m = len(texts), sum(texts)
976
- print(f'1、chat:{n}条会话*{m / n:.2g}条消息')
977
- gcj1 = GptChatJsonl(gcd1.files[0]) # 统计一个文件就够了,不然太多了
978
- gcj1.check_records()
979
- print()
980
-
981
- # 2 chatted信息
982
- filter_records = [x for x in gcd1.yield_record() if 'all_answers' in x]
983
- if filter_records:
984
- print(f'2、chatted:已获得{len(filter_records)}条会话')
985
- else:
986
- print('2、chatted:暂未获得生成数据')
987
-
988
- # 3 post信息
989
- if self.post_dir:
990
- print(f'3、post:{self.post_dir.count_records()}条会话')
991
-
992
- # 4 verify(这一步有时候会集成到post中)
993
- if self.verify_dir:
994
- print(f'4、verify:{self.verify_dir.count_records()}条会话')
995
-
996
- # 5 train 生成的训练数据
997
- # print('5、train:')
998
- # gtj = GptTrainJsonl(self.train_file)
999
- # gtj.analyze_text_length()
1000
-
1001
- def summary_downloads(self):
1002
- """ 统计下载的文件情况 """
1003
- print('【每个目录文件数量】')
1004
- files_each_dir = []
1005
- for d in self.download_files_dir.glob_dirs():
1006
- files_each_dir.append(len(list(d.rglob_files())))
1007
- print(ValuesStat(files_each_dir).summary())
1008
- print(Counter(files_each_dir))
1009
-
1010
- print('【每个文件大小】')
1011
- filesizes_each_dir = []
1012
- for d in self.download_files_dir.glob_dirs():
1013
- for f in d.rglob_files():
1014
- filesizes_each_dir.append(f.size())
1015
- print(ValuesStat(filesizes_each_dir).summary())
1016
-
1017
- def create_chat(self):
1018
- """ 生成chat数据,具体内容方式跟业务有关 """
1019
- raise NotImplementedError
1020
-
1021
- def browse_chatted_record(self, index=None, paths=None, **kwargs):
1022
- """ 显示第i次会话的内容 """
1023
- f = self.chatted_file if self.chatted_file.is_file() else self.chat_file
1024
- return GptChatJsonl(f, 100).browse_record(index, paths, **kwargs)
1025
-
1026
- def chatted2post_record(self, chatted_record):
1027
- """ 后处理,解析
1028
-
1029
- 一般会保留基本的all_answers结果,供检查上游一些基本情况
1030
- 然后把一些结构化结果存储到extra字段
1031
-
1032
- :return: 会返回新的dict结构的一个post_record,如果解析失败,会返回None
1033
- """
1034
- # 0 基本情况判断
1035
- if 'all_answers' not in chatted_record:
1036
- return
1037
-
1038
- post_record = copy.deepcopy(chatted_record)
1039
-
1040
- # 1 删掉一些没卵用的字段
1041
- for name in ['all_questions', 'first_text_length', 'second_text_length']:
1042
- if name in post_record:
1043
- del post_record[name]
1044
-
1045
- # 2 解析all_answers:这个结构太复杂,进行内容整理精简
1046
- # 2.1 contents:这个结构太复杂,搁这俄罗斯套娃呢~ 稍微精简下更方便后处理
1047
- for k, answer in enumerate(post_record['all_answers']):
1048
- if isinstance(answer, dict) and 'contents' in answer:
1049
- new_contents = []
1050
- for i, x in enumerate(answer['contents']):
1051
- if not x['message']:
1052
- # Error in message stream
1053
- # print(f'{post_record["id"]} answer[{k}] contents[{i}] message为空')
1054
- continue
1055
-
1056
- content = x['message']['content']
1057
- tp = content['content_type']
1058
- new_content = {'type': content['content_type']}
1059
- if tp == 'text':
1060
- new_content['text'] = '\n'.join(content['parts'])
1061
- elif tp == 'code':
1062
- new_content['text'] = content['text']
1063
- elif tp == 'execution_output':
1064
- new_content['text'] = content['text']
1065
- elif tp == 'system_error':
1066
- continue
1067
- else:
1068
- print(f'{post_record["id"]} answer[{k}] contents[{i}] content_type={tp} 未见类型')
1069
- continue
1070
-
1071
- new_contents.append(new_content)
1072
- answer['contents'] = new_contents
1073
- elif isinstance(answer, str): # 普通模式也转成解释器风格,方便统一处理
1074
- post_record['all_answers'][k] = {'contents': [{'type': 'text',
1075
- 'text': answer}]}
1076
-
1077
- # 2.2 downloads:下载链接精简下,并把关联的文件也顺带整理一下
1078
- for answer in post_record['all_answers']:
1079
- if 'downloads' not in answer:
1080
- continue
1081
- for i, link in enumerate(answer['downloads']):
1082
- m = re.search(r'filename%3D(.+?)&sig=', link)
1083
- if m:
1084
- answer['downloads'][i] = str(post_record['id']) + '/' + unquote(unquote(m.group(1)))
1085
- # 对应的文件不存在的不要,有数据超过50M的也不要
1086
- file = self.download_files_dir / link
1087
- if not file.exists() and file.size() > 50 * 1024 * 1024:
1088
- return
1089
-
1090
- # 理论上下载的文件不应该有重复,虽然不知道为什么会拿到重复,但去掉重复比较好
1091
- answer['downloads'] = list(OrderedDict.fromkeys(answer['downloads']))
1092
-
1093
- # 2.3 删掉answer里其他没用的字段
1094
- for answer in post_record['all_answers']:
1095
- for name in ['created', 'message_id', 'conversation_id', 'end_turn']:
1096
- if name in answer:
1097
- del answer[name]
1098
-
1099
- # 返回处理结果
1100
- return post_record
1101
-
1102
- @staticmethod
1103
- def post2verify_record(post_record):
1104
- """ 这个一般是要具体任务定制的,没有通用操作方式
1105
-
1106
- 注意,如果要使用create_verify的多进程功能,这个函数必须是静态的,并且里面也不能使用其他"类静态方法"
1107
- 否则写成类方法或对象方法都可以
1108
-
1109
- """
1110
- raise NotImplementedError
1111
-
1112
- def verify2train_record(self, verify_record):
1113
- """ 这个一般是要具体任务定制的,没有通用操作方式 """
1114
- raise NotImplementedError
1115
-
1116
- def organize_downloaded_files(self):
1117
- # 把下载的文件整理的更清晰些
1118
- for f in tqdm(list(self.root.glob_files('OpenAI-download-*')),
1119
- desc='整理下载的文件'):
1120
- new_name = re.sub(r'OpenAI-download-\d+-', '', f.name)
1121
- new_name = new_name.replace('-', '/', 1)
1122
- try:
1123
- (self.download_files_dir / new_name).parent.mkdir(exist_ok=True)
1124
- f.rename2(self.download_files_dir / new_name, if_exists='replace')
1125
- except FileExistsError as e:
1126
- # 有的文件会移动不了
1127
- print(e)
1128
-
1129
- # 会剩一些特殊的处理不了的文件,可以看一眼后手动删掉
1130
- # 这些相关的records,默认的chatted2post_record会把这些记录过滤掉
1131
-
1132
- def create_post(self, **kwargs):
1133
- """ 建议初步跑的时候,先串行debug,等比较稳定后,再开并发跑
1134
- """
1135
- if 'dst_dir' not in kwargs:
1136
- kwargs['dst_dir'] = self.post_dir.root
1137
- self.chatted_dir.process_each_record(self.chatted2post_record, **kwargs)
1138
- self.post_dir.update_subfiles()
1139
- num1, num2 = self.chatted_dir.count_records(), self.post_dir.count_records()
1140
- print(f'chatted有{num1}条,转换post有{num2}条,转换率{num2 / num1:.2%}')
1141
-
1142
- def create_verify(self, **kwargs):
1143
- """ 有时候create_verify是有cpu密集运算场景的,可以开多进程
1144
- """
1145
- if 'dst_dir' not in kwargs:
1146
- kwargs['dst_dir'] = self.verify_dir.root
1147
- self.post_dir.process_each_record(self.post2verify_record, **kwargs)
1148
- self.verify_dir.update_subfiles()
1149
- num1, num2 = self.post_dir.count_records(), self.verify_dir.count_records()
1150
- num1 = num1 or -1
1151
- print(f'post有{num1}条,转换verify有{num2}条,转换率{num2 / num1:.2%}')
1152
-
1153
- def refine_verify(self, print_mode=1, **kwargs):
1154
- """ 重复检查verify数据
1155
-
1156
- 这个函数可以重复执行,但前提是self.post2verify_record里的设计有增量规则部分
1157
- """
1158
- self.verify_dir.process_each_record(self.post2verify_record, print_mode=print_mode,
1159
- inplace=True, desc='refine_verify', **kwargs)
1160
-
1161
- @classmethod
1162
- def texts2train_record(cls, texts):
1163
- """ user和assistant的轮询对话,转为训练集格式 """
1164
- messages = []
1165
- for i, text in enumerate(texts):
1166
- role = 'assistant' if i % 2 else 'user'
1167
- messages.append({'role': role, 'content': text})
1168
- return {'messages': messages}
1169
-
1170
- def create_train(self, **kwargs):
1171
- if 'dst_dir' not in kwargs:
1172
- kwargs['dst_dir'] = self.train_dir.root
1173
- self.post_dir.process_each_record(self.verify2train_record, **kwargs)
1174
- self.train_dir.update_subfiles()
1175
-
1176
- def check_chatted_record(self, chatted_record):
1177
- """ 检查chatted数据的有效性 """
1178
- x = chatted_record
1179
- x = self.chatted2post_record(x)
1180
- # x = self.post2verify_record(x)
1181
- # 针对verify可以再进一步定制规则
1182
- return bool(x)
1183
-
1184
- def create_rechat(self, rechat_path):
1185
- """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
1186
-
1187
- :param rechat_path: 把挑选出来的数据放到新路径
1188
- """
1189
- gcd = GptChatDir(rechat_path)
1190
- f = open(gcd.chat_file, 'w', encoding='utf-8')
1191
-
1192
- for record in tqdm(self.chatted_dir.yield_record(), '检查待重新生成的问题'):
1193
- if not self.check_chatted_record(record):
1194
- continue
1195
- # 否则把这个条目放到rechat,准备拿去重新提问
1196
- if 'error' in record:
1197
- del record['error']
1198
- f.write(json.dumps(record, ensure_ascii=False) + '\n')
1199
- # 如果有文件,也要对应移动
1200
- src_dir = self.upload_files_dir / str(record['id'])
1201
- if src_dir.is_dir():
1202
- src_dir.copy(gcd.upload_files_dir / src_dir.name, if_exists='skip')
1203
-
1204
- f.close()
1205
- return gcd
1206
-
1207
- def update_chatted(self, rechat_path):
1208
- """ 从另一个rechat数据,更新数据条目过来
1209
-
1210
- self依然叫src,rechat叫dst,虽然其实数据是从rechat更新流向self
1211
-
1212
- 注意:这个函数还没有比较严格地进行调试~
1213
- """
1214
- # 1 读取有效记录
1215
- gcd = GptChatDir(rechat_path)
1216
- gcd.organize_downloaded_files()
1217
- # 请确保内存充足哦,这个函数会从rechat的chatted读取所有通过的记录保存起来
1218
- dst_records = {}
1219
- for record in gcd.chatted_dir.yield_record():
1220
- # 找到有all_answers的挑出来
1221
- post_record = self.chatted2post_record(record)
1222
- if post_record:
1223
- dst_records[record['id']] = record
1224
-
1225
- # 2 更新记录
1226
- def update_each_record(x):
1227
- if x['id'] in dst_records:
1228
- # 除了返回record,还得拷贝目录数据呢
1229
- # 上传的目录一般没变,但最好重置下
1230
- src_dir = self.upload_files_dir / x['id']
1231
- dst_dir = gcd.upload_files_dir / x['id']
1232
- dst_dir.copy(src_dir, if_exists='replace')
1233
- # 下载的目录
1234
- src_dir = self.download_files_dir / x['id']
1235
- dst_dir = gcd.download_files_dir / x['id']
1236
- dst_dir.copy(src_dir, if_exists='replace')
1237
- return dst_records[x['id']]
1238
- else:
1239
- return x
1240
-
1241
- self.chatted_dir.update_each_record(update_each_record)
1242
-
1243
-
1244
- def __5_bdchat():
1245
- """ 百度相关api """
1246
-
1247
-
1248
- class BaiduChatbot:
1249
- def __init__(self, api_key, secret_key, file_path=None):
1250
- self.API_KEY = api_key
1251
- self.SECRET_KEY = secret_key
1252
- self.ACCESS_TOKEN = self._get_access_token()
1253
- self.base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
1254
- self.file_path = file_path # 文件路径为可选参数
1255
-
1256
- def _get_access_token(self):
1257
- """
1258
- 使用 AK,SK 生成鉴权签名(Access Token)
1259
- :return: access_token,或是None(如果错误)
1260
- """
1261
- url = "https://aip.baidubce.com/oauth/2.0/token"
1262
- params = {
1263
- "grant_type": "client_credentials",
1264
- "client_id": self.API_KEY,
1265
- "client_secret": self.SECRET_KEY
1266
- }
1267
- return str(requests.post(url, params=params).json().get("access_token"))
1268
-
1269
- def chat(self, user_message):
1270
- """ 向Baidu API发送用户消息并返回API的回复
1271
- 注意user_message的token不要超过3k
1272
- """
1273
- url = self.base_url + self.ACCESS_TOKEN
1274
- payload = json.dumps({
1275
- "messages": [{"role": "user", "content": user_message}]
1276
- })
1277
- headers = {'Content-Type': 'application/json'}
1278
- response = requests.post(url, headers=headers, data=payload)
1279
- response_json = response.json()
1280
- response_json['user_message'] = user_message
1281
- response_json['timestamp'] = datetime.datetime.now().isoformat()
1282
-
1283
- # 如果指定了文件路径,自动保存记录
1284
- if self.file_path:
1285
- self._save_to_file(response_json)
1286
-
1287
- return response_json.get('result', '')
1288
-
1289
- def _save_to_file(self, response):
1290
- with open(self.file_path, 'a', encoding='utf-8') as file:
1291
- file.write(json.dumps(response, ensure_ascii=False) + '\n')
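Note that gptlib.py above was the only home of some small, generally useful helpers, such as texts2train_record, which turns an alternating user/assistant message list into the {'messages': [...]} record format used for chat fine-tuning. For projects that depended on it and upgrade past 0.3.96, a minimal standalone equivalent, reconstructed from the deleted source above, looks like this:

import json

def texts2train_record(texts):
    """ Alternating user/assistant texts -> one chat fine-tuning record
    (mirrors the helper removed together with pyxlpr/data/gptlib.py). """
    messages = []
    for i, text in enumerate(texts):
        role = 'assistant' if i % 2 else 'user'
        messages.append({'role': role, 'content': text})
    return {'messages': messages}

# Example: append one two-turn conversation to a training JSONL file
record = texts2train_record(['What is 2 + 2?', '2 + 2 = 4.'])
with open('train.jsonl', 'a', encoding='utf-8') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')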