pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (306)
  1. pyxllib/algo/geo.py +12 -0
  2. pyxllib/algo/intervals.py +1 -1
  3. pyxllib/algo/matcher.py +78 -0
  4. pyxllib/algo/pupil.py +187 -19
  5. pyxllib/algo/specialist.py +2 -1
  6. pyxllib/algo/stat.py +38 -2
  7. {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
  8. pyxllib/autogui/activewin.py +246 -0
  9. pyxllib/autogui/all.py +9 -0
  10. pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
  11. pyxllib/autogui/uiautolib.py +362 -0
  12. pyxllib/autogui/wechat.py +827 -0
  13. pyxllib/autogui/wechat_msg.py +421 -0
  14. pyxllib/autogui/wxautolib.py +84 -0
  15. pyxllib/cv/slidercaptcha.py +137 -0
  16. pyxllib/data/echarts.py +123 -12
  17. pyxllib/data/jsonlib.py +89 -0
  18. pyxllib/data/pglib.py +514 -30
  19. pyxllib/data/sqlite.py +231 -4
  20. pyxllib/ext/JLineViewer.py +14 -1
  21. pyxllib/ext/drissionlib.py +277 -0
  22. pyxllib/ext/kq5034lib.py +0 -1594
  23. pyxllib/ext/robustprocfile.py +497 -0
  24. pyxllib/ext/unixlib.py +6 -5
  25. pyxllib/ext/utools.py +108 -95
  26. pyxllib/ext/webhook.py +32 -14
  27. pyxllib/ext/wjxlib.py +88 -0
  28. pyxllib/ext/wpsapi.py +124 -0
  29. pyxllib/ext/xlwork.py +9 -0
  30. pyxllib/ext/yuquelib.py +1003 -71
  31. pyxllib/file/docxlib.py +1 -1
  32. pyxllib/file/libreoffice.py +165 -0
  33. pyxllib/file/movielib.py +9 -0
  34. pyxllib/file/packlib/__init__.py +112 -75
  35. pyxllib/file/pdflib.py +1 -1
  36. pyxllib/file/pupil.py +1 -1
  37. pyxllib/file/specialist/dirlib.py +1 -1
  38. pyxllib/file/specialist/download.py +10 -3
  39. pyxllib/file/specialist/filelib.py +266 -55
  40. pyxllib/file/xlsxlib.py +205 -50
  41. pyxllib/file/xlsyncfile.py +341 -0
  42. pyxllib/prog/cachetools.py +64 -0
  43. pyxllib/prog/filelock.py +42 -0
  44. pyxllib/prog/multiprogs.py +940 -0
  45. pyxllib/prog/newbie.py +9 -2
  46. pyxllib/prog/pupil.py +129 -60
  47. pyxllib/prog/specialist/__init__.py +176 -2
  48. pyxllib/prog/specialist/bc.py +5 -2
  49. pyxllib/prog/specialist/browser.py +11 -2
  50. pyxllib/prog/specialist/datetime.py +68 -0
  51. pyxllib/prog/specialist/tictoc.py +12 -13
  52. pyxllib/prog/specialist/xllog.py +5 -5
  53. pyxllib/prog/xlosenv.py +7 -0
  54. pyxllib/text/airscript.js +744 -0
  55. pyxllib/text/charclasslib.py +17 -5
  56. pyxllib/text/jiebalib.py +6 -3
  57. pyxllib/text/jinjalib.py +32 -0
  58. pyxllib/text/jsa_ai_prompt.md +271 -0
  59. pyxllib/text/jscode.py +159 -4
  60. pyxllib/text/nestenv.py +1 -1
  61. pyxllib/text/newbie.py +12 -0
  62. pyxllib/text/pupil/common.py +26 -0
  63. pyxllib/text/specialist/ptag.py +2 -2
  64. pyxllib/text/templates/echart_base.html +11 -0
  65. pyxllib/text/templates/highlight_code.html +17 -0
  66. pyxllib/text/templates/latex_editor.html +103 -0
  67. pyxllib/text/xmllib.py +76 -14
  68. pyxllib/xl.py +2 -1
  69. pyxllib-0.3.197.dist-info/METADATA +48 -0
  70. pyxllib-0.3.197.dist-info/RECORD +126 -0
  71. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
  72. pyxllib/ext/autogui/__init__.py +0 -8
  73. pyxllib-0.3.96.dist-info/METADATA +0 -51
  74. pyxllib-0.3.96.dist-info/RECORD +0 -333
  75. pyxllib-0.3.96.dist-info/top_level.txt +0 -2
  76. pyxlpr/ai/__init__.py +0 -5
  77. pyxlpr/ai/clientlib.py +0 -1281
  78. pyxlpr/ai/specialist.py +0 -286
  79. pyxlpr/ai/torch_app.py +0 -172
  80. pyxlpr/ai/xlpaddle.py +0 -655
  81. pyxlpr/ai/xltorch.py +0 -705
  82. pyxlpr/data/__init__.py +0 -11
  83. pyxlpr/data/coco.py +0 -1325
  84. pyxlpr/data/datacls.py +0 -365
  85. pyxlpr/data/datasets.py +0 -200
  86. pyxlpr/data/gptlib.py +0 -1291
  87. pyxlpr/data/icdar/__init__.py +0 -96
  88. pyxlpr/data/icdar/deteval.py +0 -377
  89. pyxlpr/data/icdar/icdar2013.py +0 -341
  90. pyxlpr/data/icdar/iou.py +0 -340
  91. pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
  92. pyxlpr/data/imtextline.py +0 -473
  93. pyxlpr/data/labelme.py +0 -866
  94. pyxlpr/data/removeline.py +0 -179
  95. pyxlpr/data/specialist.py +0 -57
  96. pyxlpr/eval/__init__.py +0 -85
  97. pyxlpr/paddleocr.py +0 -776
  98. pyxlpr/ppocr/__init__.py +0 -15
  99. pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
  100. pyxlpr/ppocr/data/__init__.py +0 -135
  101. pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
  102. pyxlpr/ppocr/data/imaug/__init__.py +0 -67
  103. pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
  104. pyxlpr/ppocr/data/imaug/east_process.py +0 -437
  105. pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
  106. pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
  107. pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
  108. pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
  109. pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
  110. pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
  111. pyxlpr/ppocr/data/imaug/operators.py +0 -433
  112. pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
  113. pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
  114. pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
  115. pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
  116. pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
  117. pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
  118. pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
  119. pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
  120. pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
  121. pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
  122. pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
  123. pyxlpr/ppocr/data/simple_dataset.py +0 -372
  124. pyxlpr/ppocr/losses/__init__.py +0 -61
  125. pyxlpr/ppocr/losses/ace_loss.py +0 -52
  126. pyxlpr/ppocr/losses/basic_loss.py +0 -135
  127. pyxlpr/ppocr/losses/center_loss.py +0 -88
  128. pyxlpr/ppocr/losses/cls_loss.py +0 -30
  129. pyxlpr/ppocr/losses/combined_loss.py +0 -67
  130. pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
  131. pyxlpr/ppocr/losses/det_db_loss.py +0 -80
  132. pyxlpr/ppocr/losses/det_east_loss.py +0 -63
  133. pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
  134. pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
  135. pyxlpr/ppocr/losses/distillation_loss.py +0 -272
  136. pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
  137. pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
  138. pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
  139. pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
  140. pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
  141. pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
  142. pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
  143. pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
  144. pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
  145. pyxlpr/ppocr/losses/table_att_loss.py +0 -109
  146. pyxlpr/ppocr/metrics/__init__.py +0 -44
  147. pyxlpr/ppocr/metrics/cls_metric.py +0 -45
  148. pyxlpr/ppocr/metrics/det_metric.py +0 -82
  149. pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
  150. pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
  151. pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
  152. pyxlpr/ppocr/metrics/kie_metric.py +0 -70
  153. pyxlpr/ppocr/metrics/rec_metric.py +0 -75
  154. pyxlpr/ppocr/metrics/table_metric.py +0 -50
  155. pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
  156. pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
  157. pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
  158. pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
  159. pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
  160. pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
  161. pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
  162. pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
  163. pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
  164. pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
  165. pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
  166. pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
  167. pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
  168. pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
  169. pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
  170. pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
  171. pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
  172. pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
  173. pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
  174. pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
  175. pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
  176. pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
  177. pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
  178. pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
  179. pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
  180. pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
  181. pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
  182. pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
  183. pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
  184. pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
  185. pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
  186. pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
  187. pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
  188. pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
  189. pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
  190. pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
  191. pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
  192. pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
  193. pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
  194. pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
  195. pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
  196. pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
  197. pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
  198. pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
  199. pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
  200. pyxlpr/ppocr/optimizer/__init__.py +0 -61
  201. pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
  202. pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
  203. pyxlpr/ppocr/optimizer/optimizer.py +0 -160
  204. pyxlpr/ppocr/optimizer/regularizer.py +0 -52
  205. pyxlpr/ppocr/postprocess/__init__.py +0 -55
  206. pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
  207. pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
  208. pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
  209. pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
  210. pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
  211. pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
  212. pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
  213. pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
  214. pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
  215. pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
  216. pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
  217. pyxlpr/ppocr/tools/__init__.py +0 -14
  218. pyxlpr/ppocr/tools/eval.py +0 -83
  219. pyxlpr/ppocr/tools/export_center.py +0 -77
  220. pyxlpr/ppocr/tools/export_model.py +0 -129
  221. pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
  222. pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
  223. pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
  224. pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
  225. pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
  226. pyxlpr/ppocr/tools/infer/utility.py +0 -629
  227. pyxlpr/ppocr/tools/infer_cls.py +0 -83
  228. pyxlpr/ppocr/tools/infer_det.py +0 -134
  229. pyxlpr/ppocr/tools/infer_e2e.py +0 -122
  230. pyxlpr/ppocr/tools/infer_kie.py +0 -153
  231. pyxlpr/ppocr/tools/infer_rec.py +0 -146
  232. pyxlpr/ppocr/tools/infer_table.py +0 -107
  233. pyxlpr/ppocr/tools/program.py +0 -596
  234. pyxlpr/ppocr/tools/test_hubserving.py +0 -117
  235. pyxlpr/ppocr/tools/train.py +0 -163
  236. pyxlpr/ppocr/tools/xlprog.py +0 -748
  237. pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
  238. pyxlpr/ppocr/utils/__init__.py +0 -24
  239. pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
  240. pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
  241. pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
  242. pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
  243. pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
  244. pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
  245. pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
  246. pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
  247. pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
  248. pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
  249. pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
  250. pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
  251. pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
  252. pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
  253. pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
  254. pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
  255. pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
  256. pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
  257. pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
  258. pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
  259. pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
  260. pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
  261. pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
  262. pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
  263. pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
  264. pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
  265. pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
  266. pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
  267. pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
  268. pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
  269. pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
  270. pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
  271. pyxlpr/ppocr/utils/dict90.txt +0 -90
  272. pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
  273. pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
  274. pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
  275. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
  276. pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
  277. pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
  278. pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
  279. pyxlpr/ppocr/utils/en_dict.txt +0 -95
  280. pyxlpr/ppocr/utils/gen_label.py +0 -81
  281. pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
  282. pyxlpr/ppocr/utils/iou.py +0 -54
  283. pyxlpr/ppocr/utils/logging.py +0 -69
  284. pyxlpr/ppocr/utils/network.py +0 -84
  285. pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
  286. pyxlpr/ppocr/utils/profiler.py +0 -110
  287. pyxlpr/ppocr/utils/save_load.py +0 -150
  288. pyxlpr/ppocr/utils/stats.py +0 -72
  289. pyxlpr/ppocr/utils/utility.py +0 -80
  290. pyxlpr/ppstructure/__init__.py +0 -13
  291. pyxlpr/ppstructure/predict_system.py +0 -187
  292. pyxlpr/ppstructure/table/__init__.py +0 -13
  293. pyxlpr/ppstructure/table/eval_table.py +0 -72
  294. pyxlpr/ppstructure/table/matcher.py +0 -192
  295. pyxlpr/ppstructure/table/predict_structure.py +0 -136
  296. pyxlpr/ppstructure/table/predict_table.py +0 -221
  297. pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
  298. pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
  299. pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
  300. pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
  301. pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
  302. pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
  303. pyxlpr/ppstructure/utility.py +0 -71
  304. pyxlpr/xlai.py +0 -10
  305. /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
  306. {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
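Judging from the renames in the list above (items 7, 10, and 305), the autogui helpers moved from pyxllib/ext/autogui/ to a top-level pyxllib/autogui/ package, and the whole pyxlpr tree was dropped from the wheel. A hedged sketch of how downstream imports would change is shown below; it assumes only that module paths mirror the file paths listed above.

# Import-path changes implied by the file moves listed above
# (assumption: Python module paths mirror the wheel's file paths).

# pyxllib <= 0.3.96:
# from pyxllib.ext.autogui import autogui, virtualkey

# pyxllib >= 0.3.197:
from pyxllib.autogui import autogui, virtualkey

# The pyxlpr tree (pyxlpr.data.gptlib, pyxlpr.paddleocr, pyxlpr.ppocr, ...) is no
# longer shipped in this wheel, so code using it must stay pinned to <= 0.3.96.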
pyxlpr/data/gptlib.py DELETED
@@ -1,1291 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # @Author : 陈坤泽
4
- # @Email : 877362867@qq.com
5
- # @Date : 2023/07/13 14:26
6
-
7
- from pyxllib.prog.pupil import check_install_package
8
-
9
- # check_install_package('transformers', 'transformers')
10
-
11
- import ast
12
- from collections import OrderedDict
13
- from collections import Counter
14
- import contextlib
15
- import copy
16
- import datetime
17
- import heapq
18
- import html
19
- import json
20
- import math
21
- import random
22
- import re
23
- from urllib.parse import unquote
24
- import io
25
- import logging
26
- import warnings
27
-
28
- from jinja2 import Template
29
- from openpyxl import Workbook
30
- import pandas as pd
31
- import requests
32
- from tqdm import tqdm
33
-
34
- try:
35
- from transformers import AutoTokenizer, GPT2TokenizerFast
36
- except ModuleNotFoundError:
37
- pass
38
-
39
- from pyxllib.prog.pupil import OutputLogger
40
- from pyxllib.prog.specialist import browser, TicToc
41
- from pyxllib.algo.pupil import ValuesStat
42
- from pyxllib.file.specialist import XlPath, JsonlDataFile, JsonlDataDir, TwinDirs, ensure_localdir
43
- from pyxllib.file.xlsxlib import extract_workbook_summary
44
-
45
-
46
- def __1_生成提问数据():
47
- pass
48
-
49
-
50
- class Tokenizer:
51
- _tokenizer = None
52
-
53
- @classmethod
54
- def get_tokenizer(cls):
55
- """ 获取tokenizer,第一次调用时进行初始化 """
56
-
57
- if cls._tokenizer is None:
58
- # 根本没必要每次都尝试连接官网,本地有就不要老是sb的尝试连接huggingface
59
- # 而且官网连接也不稳,这里换成我自己的服务器中转
60
- # gpt2_dir = XlPath.tempdir() / 'huggingface_gpt2'
61
- # ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/huggingface_gpt2.zip')
62
- # Tokenizer._tokenizer = GPT2TokenizerFast.from_pretrained(gpt2_dir)
63
- # 240103周三21:23,hx给过的新评测模型
64
- gpt2_dir = XlPath.tempdir() / 'Atom-CL-SS'
65
- ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/Atom-CL-SS.zip')
66
- cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, trust_remote_code=True)
67
- return cls._tokenizer
68
-
69
- @classmethod
70
- def tokenize(cls, paragraph, max_length=500):
71
- """ 对段落进行tokenize
72
-
73
- :param str paragraph: 待分词的段落
74
- :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
75
- :return list: 分词后的列表
76
-
77
- >>> Tokenizer.tokenize('Hello, world! 汉字 123.14 35')
78
- ['Hello', ',', 'Ġworld', '!', 'Ġæ', '±', 'ī', 'åŃ', 'Ĺ', 'Ġ123', '.', '14', 'Ġ35']
79
- """
80
- tokenizer = cls.get_tokenizer()
81
-
82
- # 对段落进行切分
83
- paragraph_slices = [paragraph[i:i + max_length] for i in range(0, len(paragraph), max_length)]
84
-
85
- # 对每个切分的子段进行分词,并将结果拼接在一起
86
- tokens = []
87
- for slice in paragraph_slices:
88
- tokens += tokenizer.tokenize(slice)
89
-
90
- return tokens
91
-
92
- @classmethod
93
- def count_tokens(cls, paragraph, max_length=500):
94
- """ 获取段落的token数量
95
-
96
- :param str paragraph: 待分词的段落
97
- :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
98
- :return int: token的数量
99
-
100
- >>> Tokenizer.count_tokens('Hello, world!')
101
- 5
102
- """
103
- return len(cls.tokenize(paragraph, max_length))
104
-
105
-
106
- def print_statistics(data, indent_level=1):
107
- """ 计算字符串长度,并且计算关键的一些token数
108
-
109
- :param data: data应该是一个嵌套结构,表示会话与消息
110
- """
111
- fmts = ['g', '.0f', '.0f', 'd', 'd']
112
- stat_len = ValuesStat([len(str(x)) for x in data])
113
-
114
- indent = '\t' * indent_level
115
- print(f'{indent} {stat_len.summary(fmts)}')
116
-
117
-
118
- def check_conversation_lengths(all_texts, n_values=(4, 4),
119
- compute_tokens=False, ids=None):
120
- """ 分析会话长度 """
121
-
122
- # 0 预处理
123
- for i, texts in enumerate(all_texts):
124
- if isinstance(texts, str):
125
- all_texts[i] = [texts]
126
-
127
- # 如果没有提供ID,则使用默认的range(n)
128
- if ids is None:
129
- ids = list(range(len(all_texts)))
130
-
131
- # 处理n_values的重叠
132
- if sum(n_values) >= len(all_texts):
133
- n_values = [len(all_texts), 0] # 将所有数据视为最短数据,不再考虑最长数据
134
-
135
- # 1 消息长度统计
136
- fmts = [None, '.0f', '.0f', 'd', 'd']
137
- lengths = [len(t) for texts in all_texts for t in texts]
138
- print(f'1、消息长度统计 {ValuesStat(lengths).summary(fmts)}')
139
-
140
- # 2 每组会话消息数目
141
- ct = Counter(len(texts) for texts in all_texts)
142
- sorted_ct = {k: v for k, v in sorted(ct.items(), key=lambda x: x[0])}
143
- print(f'2、每组消息数目: {sorted_ct}')
144
-
145
- # 3 找出消息总长度最短和最长的会话
146
- total_lengths = [(i, sum(len(t) for t in texts)) for i, texts in enumerate(all_texts)]
147
- shortest_indices = [item[0] for item in heapq.nsmallest(n_values[0], total_lengths, key=lambda x: x[1])]
148
- longest_indices = [item[0] for item in heapq.nlargest(n_values[1], total_lengths, key=lambda x: x[1])]
149
- longest_indices = longest_indices[::-1] # 从小到大排序
150
-
151
- parts = []
152
- if shortest_indices:
153
- parts.append(', '.join(map(str, [ids[i] for i in shortest_indices])))
154
- if longest_indices:
155
- parts.append(', '.join(map(str, [ids[i] for i in longest_indices])))
156
- print(f'3、最短最长会话的id:', ', ..., '.join(parts))
157
-
158
- # 4 计算token
159
- if compute_tokens:
160
- # 4.1 代表性样本的tokens数
161
- s_texts = [' '.join([x for x in all_texts[i]]) for i in shortest_indices]
162
- l_texts = [' '.join([x for x in all_texts[i]]) for i in longest_indices]
163
-
164
- s_lens = [[len(x), Tokenizer.count_tokens(x)] for x in s_texts]
165
- l_lens = [[len(x), Tokenizer.count_tokens(x)] for x in l_texts]
166
-
167
- parts = []
168
- if s_lens:
169
- parts.append(', '.join(map(str, [x[1] for x in s_lens])))
170
- if l_lens:
171
- parts.append(', '.join(map(str, [x[1] for x in l_lens])))
172
- # 仅计算3中代表性样本
173
- print(f'4、tokens数量:', ', ..., '.join(parts))
174
-
175
- # 4.2 token的比率规律
176
- ratios = []
177
- for x in s_lens + l_lens:
178
- ratios.append(x[1] / x[0])
179
- fmts = [None, '.0%', '.0%', '.0%', '.0%']
180
- print(f'token/len比率统计 {ValuesStat(ratios).summary(fmts)}')
181
- # 比率越大,代表越接近中文场景,汉字越多,要注意len的控制不要让token某些场合超出长度
182
-
183
-
184
- def set_template(s, *args, **kwargs):
185
- """ todo 这个名字会不会太容易冲突了? """
186
- return Template(s.strip(), *args, **kwargs)
187
-
188
-
189
- def set_meta_template(s, meta_start='[[', meta_end=']]', **kwargs):
190
- """ 支持预先用某些格式渲染后,再返回标准渲染模板 """
191
- t = Template(s.strip(), variable_start_string=meta_start,
192
- variable_end_string=meta_end).render(**kwargs)
193
- return Template(t)
194
-
195
-
196
- class StyleParser:
197
- def __init__(self, text):
198
- # 使用正则表达式拆分文本,并获取权重和风格
199
- self.styles = []
200
- self.weights = []
201
- matches = re.findall(r'<风格变换\d+(\s+(\d+))?[^>]*>\s*(.*?)\s*(?=<风格变换|$)', text, re.DOTALL)
202
- for match in matches:
203
- self.styles.append(match[2])
204
- # 提取权重
205
- weight = match[1]
206
- if weight:
207
- self.weights.append(int(weight))
208
- else:
209
- self.weights.append(100) # 默认权重
210
-
211
- def random_pick(self):
212
- """ 随机选择一个风格,并返回其下标和内容
213
-
214
- :return tuple: (下标, 风格内容)
215
-
216
- >>> sp = StyleParser("...") # 按照之前的格式传入一个字符串
217
- >>> index, style = sp.random_pick() # 随机选择一个风格
218
- """
219
- index = random.choices(range(len(self.styles)), weights=self.weights, k=1)[0]
220
- return index, self.styles[index]
221
-
222
-
223
- class GptChatJsonl(JsonlDataFile):
224
- """ GPT问答批量执行脚本的jsonl生成、读取器 """
225
-
226
- def __init__(self, file=None, num_records=None, *, start_id=None):
227
- from datetime import datetime
228
-
229
- super().__init__(file, num_records)
230
- if start_id is None:
231
- # 230821周一02:02,原本只有日期标记,后面发现这样id很容易出现重复,还是加上小时分钟更不容易引起一些没必要的麻烦
232
- today = datetime.now().strftime("%Y%m%d%H%M")
233
- self.start_id = int(today + "000000")
234
- else:
235
- self.start_id = start_id
236
-
237
- def read_jsonl(self, file):
238
- """ 从一个文件加载数据
239
- """
240
- self.records = XlPath(file).read_jsonl()
241
- try:
242
- self.start_id = self.records[-1]['id']
243
- except KeyError:
244
- pass
245
-
246
- def split_and_add_prompt(self, text, max_word_length=None, prompt=None):
247
- """
248
- :param text: 要插入的文本(纯文本,不能是字典格式的 {'content': ...})
249
- :param max_word_length: 如果设置了该值,则会对输入内容进行长度控制,拆成多个片段
250
- 之前测试,全英文长度大概在32500,中文在5000内
251
- :param prompt: max_word_length开启时才会生效,每个part要附加的提示规则
252
- 可以写个字符串,默认每段都是开头加这个提示
253
- 也可以参考gen_prompt2,一些生成函数写法进行自定义
254
- :return:
255
- """
256
- # 0 如果没有输入max_word_length,就不用特地处理了
257
- if max_word_length is None:
258
- return [text]
259
-
260
- # 1 工具函数
261
- def gen_prompt1(n, i, text):
262
- """ 一共n条,当前第i条的,当前内容是text """
263
- if n == 1:
264
- return text
265
- if n > 1:
266
- if i == 0:
267
- return f'【注意,由于本次提问过长,这里拆分成{n}个片段分开输入,目前是第{1}个片段,你只需暂时回复"收到"即可】\n' + text
268
- elif i < n - 1:
269
- return f'【这是{n}个片段中的第{i + 1}个片段,先回复"收到"即可】\n' + text
270
- else:
271
- return f'【这是{n}个片段的最后一个片段,请开始回复内容】\n' + text
272
-
273
- def gen_prompt2(n, i, text):
274
- return prompt + text
275
-
276
- if prompt is None:
277
- gen_prompt = gen_prompt1
278
- elif isinstance(prompt, str):
279
- gen_prompt = gen_prompt2
280
- else: # callable
281
- gen_prompt = prompt
282
-
283
- # 2 拆分重拼接
284
- # 首先要重新调整max_word_length,在确定要拆分为几个片段的情况下,尽量保证这些片段之间的均匀性
285
- num = len(text) // max_word_length + 1
286
- max_word_length = math.ceil(len(text) / num)
287
-
288
- # 2.1 检查是否有超长单行文本,要提前切分成多行
289
- lines = text.rstrip().split('\n')
290
- new_lines = []
291
- for line in lines:
292
- if len(line) < max_word_length:
293
- new_lines.append(line)
294
- else: # 单行就已经爆限制长度的比较特别
295
- n = max_word_length - 10 # 将这个长文本按照n的长度再拆分成多个片段加入new_lines
296
- parts = [line[i:i + n] for i in range(0, len(line), n)]
297
- new_lines += parts
298
-
299
- # 2.2 拼接new_lines
300
- fragments = []
301
- current_fragment = []
302
- current_fragment_total_length = 0
303
- for line in new_lines:
304
- if current_fragment_total_length + len(line) <= max_word_length:
305
- current_fragment.append(line)
306
- current_fragment_total_length += len(line)
307
- else:
308
- fragments.append('\n'.join(current_fragment))
309
- current_fragment = [line]
310
- current_fragment_total_length = len(line)
311
- if current_fragment:
312
- fragments.append('\n'.join(current_fragment))
313
-
314
- n = len(fragments)
315
- fragments = [gen_prompt(n, i, x).strip() for i, x in enumerate(fragments)]
316
-
317
- for i, fragment in enumerate(fragments):
318
- fragment = {"content": fragment}
319
- fragments[i] = fragment
320
- return fragments
321
-
322
- def split_texts(self, texts, max_word_length=None, prompt=None):
323
- """ 长对话自动拆分成多轮对话 """
324
- new_texts = []
325
- for text in texts:
326
- pure_text = text['content']
327
- new_texts += self.split_and_add_prompt(pure_text, max_word_length=max_word_length, prompt=prompt)
328
- if 'file_paths' in text: # 如果有文件,自动放在最后一轮插入
329
- new_texts[-1]['file_paths'] = text['file_paths']
330
- return new_texts
331
-
332
- def add_record(self, texts, *, extra=None,
333
- record_id=0, max_word_length=None, prompt=None):
334
- """
335
- :param texts:
336
- str -> list[str],可以只输入一个str,默认一轮对话
337
- list[str] -> list[{'content': ..., 'file_paths': [...]}]
338
- content: 文本内容
339
- file_paths: 注意可以设置本地电脑其他来源,会自动移到该任务的upload_files里
340
- :param record_id: 可以自定义这个session的id
341
- :param max_word_length: 是否设置一个约束长度,自动切分会话中太长的消息
342
- gpt4是8192个token,大概len就是8192/0.6=13653,一般建议如果要设就设10000左右
343
- :param prompt: 自动分段后
344
- None,自动配置的一套提示
345
- '', 不用提示
346
- :return:
347
- """
348
- # 1 变成标准的list + 字典结构,方便后面统一处理
349
- if not isinstance(texts, list):
350
- texts = [texts]
351
-
352
- for i, text in enumerate(texts):
353
- if isinstance(text, str):
354
- texts[i] = {'content': text}
355
-
356
- # 2 如果设置了每次最大会话长度,要进行拆分
357
- if max_word_length:
358
- texts = self.split_texts(texts, max_word_length=max_word_length, prompt=prompt)
359
-
360
- for i, text in enumerate(texts):
361
- texts[i]['content'] = text['content'].strip()
362
-
363
- # 3 添加会话conversation
364
- self.start_id += 1
365
- item = {'id': str(record_id or self.start_id), # 要转成字符串类型,不然容易出问题
366
- 'text': texts,
367
- 'first_text_length': len(texts[0]['content'])}
368
- if extra:
369
- item['extra'] = extra
370
- self.records.append(item)
371
- return item
372
-
373
- def fix_file_paths(self, save_dir):
374
- """ 修正records中设置的file_paths
375
-
376
- 这些路径可能在设置的时候图方便,设置的是非项目目录下的路径
377
- 这个函数会对这些路径进行修正,为了修正,需要输入一个该jsonl所保存的目录位置
378
- """
379
- save_dir = XlPath(save_dir)
380
- for i, record in tqdm(enumerate(self.records), desc='修复文件路径'):
381
- dst_dir = save_dir / 'upload_files' / str(record['id'])
382
- for j, text in enumerate(record['text']):
383
- for k, fn in enumerate(text.get('file_paths', [])):
384
- src_file = XlPath(fn)
385
- src_file2 = src_file.as_posix()
386
- if src_file2.startswith(f'upload_files/{record["id"]}/'):
387
- continue
388
- dst_file = dst_dir / src_file.name
389
- dst_file2 = dst_file.relpath(save_dir).as_posix()
390
- if src_file.is_file():
391
- if src_file2 != dst_file2:
392
- dst_dir.mkdir(parents=True, exist_ok=True)
393
- src_file.copy(dst_file, if_exists='replace')
394
- else: # 既然设置了,原文件目录应该在
395
- raise FileNotFoundError(f'{src_file}')
396
- text['file_paths'][k] = dst_file2
397
-
398
- def clean_file_paths(self):
399
- """ 清除records中的file_paths
400
- 一般用于把一些相关文件移到对应会话后,实际提问gpt的时候并不上传文件
401
- """
402
- for x in self.records:
403
- for t in x['text']:
404
- if 'file_paths' in t:
405
- del t['file_paths']
406
-
407
- def find_indices_by_qlength(self):
408
- """ 返回提问(q,question)内容从短到长的数据下标 """
409
- lens = [(i, len(''.join([t['content'] for t in x['text']]))) for i, x in enumerate(self.records)]
410
- # 根据长度进行排序,得到的元组的第一个元素为原列表的下标,第二个元素为对应的长度
411
- sorted_lens = sorted(lens, key=lambda x: x[1])
412
- # 取出排序后的下标
413
- sorted_indexs = [i for i, _ in sorted_lens]
414
- return sorted_indexs
415
-
416
- def browse_record(self, index=None, paths=None, **kwargs):
417
- """ 检查第i次会话的内容
418
- """
419
- # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
420
- if index is None:
421
- index = self.find_index(paths, **kwargs)
422
- if index is None:
423
- raise ValueError('No matching record found')
424
- session = self.records[index]
425
-
426
- # 构建HTML内容
427
- html_content = "<html><body>"
428
-
429
- # 输出除了text和all_answers以外的所有键值信息
430
- html_content += "<h2>会话信息:</h2>"
431
- html_content += "<ul>"
432
- for key, value in session.items():
433
- if key not in ["text", "all_answers"]:
434
- html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
435
- html_content += "</ul>"
436
-
437
- # 输出text和all_answers的内容
438
- texts = self.get_text_texts(session.get("text", []))
439
- all_answers = self.get_all_answers_texts(session.get("all_answers", []))
440
-
441
- max_length = max(len(texts), len(all_answers))
442
- for idx in range(max_length):
443
- html_content += f"<h3>第{idx + 1}次询问:</h3>"
444
- if idx < len(texts):
445
- html_content += f"<pre>{html.escape(texts[idx])}</pre>"
446
- if idx < len(all_answers):
447
- html_content += f"<h3>第{idx + 1}次回答:</h3>"
448
- html_content += f"<pre>{html.escape(str(all_answers[idx]))}</pre>"
449
-
450
- html_content += "</body></html>"
451
- html_file = (XlPath.tempdir() / (str(session.get('id', index)) + '.html'))
452
- html_file.write_text(html_content)
453
- browser.html(html_file)
454
-
455
- # 返回HTML字符串
456
- return html_content
457
-
458
- def get_text_texts(self, text):
459
- """ 从text字段获得所有的文本内容
460
- 因为里面可能是dict
461
- """
462
- ls = []
463
- for t in text:
464
- if isinstance(t, str):
465
- ls.append(t)
466
- else:
467
- if "file_path" in t:
468
- ls.append(("filep_path=" + str(t["file_path"]) + "\n\n") + t["content"])
469
- else:
470
- ls.append(t["content"])
471
- return ls
472
-
473
- def get_all_answers_texts(self, all_answers):
474
- ls = []
475
- for t in all_answers:
476
- if isinstance(t, dict):
477
- t = json.dumps(t, ensure_ascii=False, indent=2)
478
- ls.append(str(t))
479
- return ls
480
-
481
- def check(self):
482
- """ 检查会话、消息长度等信息 """
483
- # 1 提问的内容
484
- all_texts = [self.get_text_texts(session.get('text', []))
485
- for session in self.records]
486
- print('【提问的内容】')
487
- check_conversation_lengths(all_texts,
488
- compute_tokens=True,
489
- ids=[x['id'] for x in self.records])
490
-
491
- # 2 回复的内容
492
- all_texts = [self.get_all_answers_texts(session.get('all_answers', []))
493
- for session in self.records]
494
- # 过滤空值,并相应地更新ids
495
- filtered_texts = [(text, session['id']) for text, session in zip(all_texts, self.records) if text]
496
- all_texts, ids = zip(*filtered_texts) if filtered_texts else ([], [])
497
- if all_texts:
498
- print('【回复的内容】')
499
- check_conversation_lengths(all_texts,
500
- compute_tokens=True,
501
- ids=ids)
502
-
503
- def filter_records_without_answers(self):
504
- """ 过滤掉没有 'all_answers' 字段的sessions """
505
-
506
- # 输出过滤前的sessions数量
507
- print(f"过滤前的sessions数量:{len(self.records)}")
508
-
509
- # 使用列表推导式过滤出包含 'all_answers' 字段的sessions
510
- self.records = [s for s in self.records
511
- if (''.join(map(str, s.get('all_answers', []))))]
512
-
513
- # 输出过滤后的sessions数量
514
- print(f"过滤后的sessions数量:{len(self.records)}")
515
-
516
- @classmethod
517
- def _parse_single_record_answer_contents(cls, record):
518
- """ 注意本函数不做record备份 """
519
- for answer in record.get('all_answers', []):
520
- if isinstance(answer, dict) and 'contents' in answer:
521
- n = len(answer['contents'])
522
- for i in range(n - 1, -1, -1):
523
- message = answer['contents'][i]['message']
524
- if message and 'content' in message and 'error' not in message:
525
- break
526
- else:
527
- answer['contents'] = ''
528
- continue
529
-
530
- content = message['content']
531
- if 'parts' in content:
532
- content = '\n'.join(content['parts'])
533
- else:
534
- content = content['text']
535
- answer['contents'] = content
536
-
537
- @classmethod
538
- def _parse_single_record_answer_downloads(cls, record):
539
- for answer in record.get('all_answers', []):
540
- if 'downloads' in answer:
541
- for i, link in enumerate(answer['downloads']):
542
- m = re.search(r'filename%3D(.+?)&sig=', link)
543
- if m:
544
- answer['downloads'][i] = unquote(unquote(m.group(1)))
545
-
546
- @classmethod
547
- def parse_single_record_answer(cls, record):
548
- cls._parse_single_record_answer_contents(record)
549
- cls._parse_single_record_answer_downloads(record)
550
-
551
- def parse_answer_contents(self):
552
- """ 简化解释器返回结果中,contents的结构信息 """
553
- for record in self.records:
554
- self._parse_single_record_answer_contents(record)
555
-
556
- def parse_answer_downloads(self):
557
- """ 解析,简化下载链接的表达形式 """
558
- for record in self.records:
559
- self._parse_single_record_answer_downloads(record)
560
-
561
- # 目录里的文件名也同理做精简
562
- for f in self.infile.parent.glob_files():
563
- if f.name.startswith('OpenAI-download-'):
564
- f.rename2(f.parent / re.sub(r'OpenAI-download-\d+-', '', f.name),
565
- if_exists='replace')
566
-
567
- def filter_to_rechat(self, check_func, rechat_path=None):
568
- """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
569
- 这个不是简单的找出得不到all_answers的,而是可以很精细,包含复杂post、verify的情况
570
-
571
- :param check_func: 一个函数,接收一个record,返回True或False
572
- True,表示这个record是对的
573
- False,表示这个record是错的,要挑选出来
574
- :param rechat_path: 把挑选出来的数据放到新路径
575
- """
576
- if rechat_path is None:
577
- rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat/in.jsonl')
578
-
579
- rechat_path = XlPath(rechat_path)
580
- td = TwinDirs(self.infile.parent, rechat_path.parent)
581
-
582
- gcj = type(self)()
583
- for record in self.records:
584
- if not check_func(record):
585
- record2 = {}
586
- for k in ['id', 'text', 'first_text_length', 'extra']:
587
- record2[k] = record[k]
588
- gcj.records.append(record2)
589
- for x in record['text']:
590
- if 'file_path' in x:
591
- td.copy_file(td.src_dir / x['file_path'])
592
-
593
- gcj.save(rechat_path)
594
- return gcj
595
-
596
- def update_from_rechat(self, check_func, rechat_path=None):
597
- """ 从另一份rechat的数据,更新回主master数据
598
-
599
- :param check_func: 原chatted没过,但是rechatted通过的,需要把数据更新过来
600
- :param rechat_path: 注意只能传路径,因为可能涉及到文件操作,需要知道目录所在
601
- 依据这个文件里的record记录更新回self
602
- """
603
- if rechat_path is None:
604
- rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat') / 'out.jsonl'
605
-
606
- rechat_path = XlPath(rechat_path)
607
- td = TwinDirs(rechat_path.parent, self.infile.parent)
608
-
609
- id2index = {x['id']: i for i, x in enumerate(self.records)}
610
-
611
- gcj = type(self)(rechat_path)
612
- gcj.parse_answer_contents()
613
- gcj.parse_answer_downloads()
614
-
615
- # 需要处理下下载链接名称
616
- self.parse_answer_downloads()
617
- gcj.parse_answer_downloads()
618
-
619
- for y in gcj.records:
620
- index = id2index[y['id']]
621
- x = self.records[index]
622
- if not check_func(x) and check_func(y):
623
- # 先把x相关的数据删掉
624
- if 'all_answers' in x:
625
- for answer in x['all_answers']:
626
- for fname in answer.get('downloads', []):
627
- (XlPath(self.infile.parent) / fname).delete()
628
- # 再把y拷贝过来
629
- for answer in y['all_answers']:
630
- for fname in answer.get('downloads', []):
631
- td.copy_file(td.src_dir / fname)
632
- self.records[index] = y
633
- return gcj
634
-
635
-
636
- GptQuestionJsonl = GptChatJsonl # 名称向下兼容
637
-
638
-
639
- def __2_数据后处理():
640
- """ 一些常用的文本、后处理功能也放到这里 """
641
-
642
-
643
- def try_eval_json(resp_json):
644
- try:
645
- resp_json = ast.literal_eval(resp_json)
646
- if isinstance(resp_json, dict):
647
- resp_json = resp_json[list(resp_json.keys())[0]]
648
- except:
649
- pass
650
- return resp_json
651
-
652
-
653
- def try_load_json(resp_json):
654
- if isinstance(resp_json, str):
655
- try:
656
- resp_json = json.loads(resp_json)
657
- if isinstance(resp_json, dict):
658
- resp_json = resp_json[list(resp_json.keys())[0]]
659
- except:
660
- pass
661
- return resp_json
662
-
663
-
664
- def try_parse_json(resp_json):
665
- if isinstance(resp_json, dict):
666
- try:
667
- resp_json = '\n'.join(resp_json['contents'][-1]['message']['content'].get('parts', []))
668
- except TypeError:
669
- return ''
670
-
671
- resp_json = try_eval_json(resp_json)
672
- if isinstance(resp_json, str):
673
- return try_load_json(resp_json)
674
- return resp_json
675
-
676
-
677
- def extract_code_blocks_from_md(markdown_text, *, sort_by_length=False):
678
- """ 可以输入str,也可以输入list[str]
679
-
680
- :param sort_by_length: 按代码长度从短到长排序
681
- 常用在比较确信有效代码段应该只有一段,但是有些短小的片段有干扰
682
- 此时可以排序后,选取最长的一个代码片段作为正确代码
683
- """
684
- if isinstance(markdown_text, str):
685
- markdown_text = [markdown_text]
686
-
687
- matches = []
688
- pattern = re.compile(r'^```[^\n]*\n(.+?)\n^```', re.MULTILINE | re.DOTALL)
689
- for text in markdown_text:
690
- matches += pattern.findall(text)
691
-
692
- if sort_by_length:
693
- matches = sorted(matches, key=len)
694
-
695
- return matches
696
-
697
-
698
- def extract_airscript_code_from_answers(all_answers):
699
- """ 从多轮回答的最后一次回答中提取求解代码 """
700
- contents = all_answers[-1]['contents']
701
- text = contents[-1]['text']
702
- code_blocks = extract_code_blocks_from_md(text, sort_by_length=True)
703
-
704
- if code_blocks:
705
- return code_blocks[-1]
706
- else:
707
- return ''
708
-
709
-
710
- def merge_answers_contents(answers):
711
- """ 对一组answers结果中,相同type的contents进行合并 """
712
- for answer in answers:
713
- contents = []
714
- for content in answer['contents']:
715
- if len(contents) == 0:
716
- contents.append(content)
717
- else:
718
- if contents[-1]['type'] == content['type']:
719
- contents[-1]['text'] += '\n' + content['text']
720
- else:
721
- contents.append(content)
722
- answer['contents'] = contents
723
-
724
-
725
- def refine_content_title(content, tag, dst_title=None):
726
- """ 将内容中的标题描述形式标准化
727
-
728
- :param tag: 原标题相关字符
729
- :param content: 文本内容
730
- :param dst_title: 目标标题格式
731
- :return: 处理后的字符串
732
- """
733
- if dst_title is None:
734
- dst_title = f'<{tag}>'
735
- content_lines = content.splitlines()
736
- chars_str = re.compile(tag.replace(':', '[:的]?'))
737
- chinese_chars = re.compile(r'[\u4e00-\u9fa5]')
738
-
739
- res = []
740
- for line in content_lines:
741
- # 使用正则表达式查找匹配的部分
742
- new_line = chars_str.sub('', line)
743
- if new_line != line and not chinese_chars.search(new_line):
744
- res.append(dst_title)
745
- else:
746
- # 如果不满足条件,不进行替换
747
- res.append(line)
748
- return '\n'.join(res)
749
-
750
-
751
- def refine_block_name(record, block_names, preproc=None):
752
- """ 优化模块的标题名,方便后续结构化提取数据
753
-
754
- 感觉这个系列解析是比较通用的,就放在标准库中
755
- """
756
- # if preproc is None:
757
- # def preproc(x):
758
- # return x
759
-
760
- for answer in record['all_answers']:
761
- for content in answer['contents']:
762
- if content['type'] == 'text':
763
- text = old_text = content['text']
764
- if preproc is not None:
765
- text = preproc(text)
766
-
767
- for block_name in block_names:
768
- text = refine_content_title(text, block_name)
769
- text = refine_content_title(text, '---', '')
770
- # 一般不要直接修改原数据,但post里会有备份,所以这里verify可以直接修改了
771
- # if 'answer' not in curr_record['extra']:
772
- # curr_record['extra']['answer'] = []
773
- # curr_record['extra']['answer'].append(text)
774
- content['text'] = text
775
- # 可以借助bc调试
776
- # bcompare(old_text, text)
777
-
778
-
779
- def extract_block_content(record, block_name):
780
- """ 从record的all_answers中,从后往前检索 <block_name> 的内容,
781
- 返回第一个匹配结果,如果找不到则返回空字符串
782
- """
783
- for answer in record['all_answers'][::-1]:
784
- for content in answer['contents'][::-1]:
785
- if content['type'] == 'text':
786
- matches = list(re.finditer(rf'^<{block_name}>\n((.|\n)+?)(?=^<.+?>\n)',
787
- content['text'] + '\n<test>\n', # 末尾补一个<test>,方便对齐
788
- flags=re.MULTILINE))
789
- if matches:
790
- s = matches[-1].group(1).strip()
791
- blocks = extract_code_blocks_from_md(s, sort_by_length=True)
792
- if blocks:
793
- return blocks[-1]
794
- if s:
795
- return s
796
- return '' # 提取不到
797
-
798
-
799
- def __3_生成最后训练用的数据():
800
- pass
801
-
802
-
803
- def texts2train_record(texts):
804
- """ user和assistant的轮询对话,转为训练集格式 """
805
- messages = []
806
- for i, text in enumerate(texts):
807
- role = 'assistant' if i % 2 else 'user'
808
- messages.append({'role': role, 'content': text})
809
- return {'messages': messages}
810
-
811
-
812
- class GptTrainJsonl(JsonlDataFile):
813
- """
814
- record: dict
815
- messages: list
816
- dict: role='user', content=...
817
- dict: role='assistant', content=...
818
- """
819
-
820
- def analyze_text_length(self):
821
- # 1 先将数据统计到df
822
- ls = []
823
- columns = ['role', 'content']
824
- for x in self.records:
825
- for t in x['messages']:
826
- ls.append([t['role'], t['content']])
827
- df = pd.DataFrame.from_records(ls, columns=columns)
828
-
829
- # 2 再从df筛选出不同的统计数据
830
- print('【user和assistant】')
831
- print_statistics(df['content'])
832
- print('【user】')
833
- print_statistics(df[df['role'] == 'user']['content'])
834
- print('【assistant】')
835
- print_statistics(df[df['role'] == 'assistant']['content'])
836
-
837
- def check(self):
838
- """ 检查会话、消息长度等信息 """
839
- # 1. 提取'user'角色的content
840
- user_texts = [[message['content']
841
- for message in record['messages']
842
- if message['role'] == 'user']
843
- for record in self.records]
844
- if not user_texts:
845
- print('空数据')
846
- return
847
-
848
- print('【User的内容】')
849
- check_conversation_lengths(user_texts, compute_tokens=True,
850
- # 因为一般是使用JLineViewer进行查看,跟那个软件对称使用1开始编号
851
- ids=list(range(1, len(user_texts) + 1)))
852
-
853
- # 2. 提取'assistant'角色的content
854
- assistant_texts = [[message['content']
855
- for message in record['messages']
856
- if message['role'] == 'assistant']
857
- for record in self.records]
858
- print('【Assistant的内容】')
859
- check_conversation_lengths(assistant_texts, compute_tokens=True,
860
- ids=list(range(1, len(assistant_texts) + 1)))
861
-
862
- # 3. 将整个record视为一个完整的会话
863
- full_conversations = [' '.join([message['content'] for message in record['messages']])
864
- for record in self.records]
865
- print('【完整的会话】')
866
- check_conversation_lengths(full_conversations, compute_tokens=True,
867
- ids=list(range(1, len(full_conversations) + 1)))
868
-
869
- def browse_record(self, index=None, paths=None, **kwargs):
870
- """ 显示第i次会话的内容 """
871
- # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
872
- if index is None:
873
- index = self.find_index(paths, **kwargs)
874
- if index is None:
875
- raise ValueError('No matching record found')
876
- session = self.records[index]
877
-
878
- # 构建HTML内容
879
- html_content = "<html><body>"
880
-
881
- # 输出除了messages以外的所有键值信息
882
- html_content += "<h2>会话信息:</h2>"
883
- html_content += "<ul>"
884
- for key, value in session.items():
885
- if key != "messages":
886
- html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
887
- html_content += "</ul>"
888
-
889
- # 输出messages的内容
890
- messages = session.get("messages", [])
891
-
892
- for idx, message in enumerate(messages):
893
- role = message.get('role', 'unknown')
894
- content = message.get('content', '')
895
- html_content += f"<h3>第{(idx // 2) + 1}次{role}的发言:</h3>"
896
- html_content += f"<pre>{html.escape(content)}</pre>"
897
-
898
- html_content += "</body></html>"
899
- html_file = (XlPath.tempdir() / (f'session_{index}.html')) # 创建临时文件名,防止覆盖现有文件
900
- html_file.write_text(html_content)
901
- browser.html(html_file) # 在浏览器中打开HTML文件
902
-
903
- # 或者返回HTML字符串
904
- return html_content
905
-
906
- def add_record(self, texts):
907
- messages = []
908
- for i, text in enumerate(texts):
909
- role = 'assistant' if i % 2 else 'user'
910
- messages.append({'role': role, 'content': text})
911
- self.records.append({'messages': messages})
912
-
913
- def add_from_texts(self, texts):
914
- record = texts2train_record(texts)
915
- self.records.append(record)
916
-
917
-
918
- def __4_综合集成类():
919
- pass
920
-
921
-
922
- class GptChatDir:
923
- """ 一个目录,包含了一个任务的所有数据,包括in、out、post等文件 """
924
-
925
- def __init__(self, root=None, lines_per_file=10000):
926
- if root is None:
927
- root = self.__class__.__name__.lower()
928
-
929
- self.root = root = XlPath(root)
930
- self.lines_per_file = lines_per_file
931
-
932
- self.chat_file = root / 'in.jsonl'
933
- self.chatted_file = root / 'out.jsonl'
934
- self.post_file = root / 'post.jsonl'
935
- self.verify_file = root / 'verify.jsonl'
936
- self.train_file = root / 'train.jsonl'
937
-
938
- # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
939
- self.update_dir()
940
-
941
- self.upload_files_dir = root / 'upload_files'
942
- self.download_files_dir = root / 'download_files'
943
-
944
- # todo 把 1chat 改名 in,2chatted 改名 out
945
- # for f in self.root.glob_files('*1chat*.jsonl'):
946
- # f.rename2(f.parent / 'in.jsonl')
947
-
948
- # for dir_path in [self.root, self.upload_files_dir, self.download_files_dir]:
949
- for dir_path in [self.root]:
950
- if not dir_path.is_dir():
951
- dir_path.mkdir(parents=True, exist_ok=True)
952
-
953
- # 这个类经常要并发处理,不能把一个不能序列化的类放到这里~
954
- # self.logger = OutputLogger(log_file=self.root / 'log.txt')
955
-
956
- def update_dir(self):
957
- """ 目录结构有些更新后,一些成员变量要跟着改变 """
958
- # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
959
- self.chat_dir = JsonlDataDir.init_from_file(self.chat_file, self.lines_per_file)
960
- self.chatted_dir = JsonlDataDir.init_from_file(self.chatted_file, self.lines_per_file)
961
- self.post_dir = JsonlDataDir.init_from_file(self.post_file, self.lines_per_file)
962
- self.verify_dir = JsonlDataDir.init_from_file(self.verify_file, self.lines_per_file)
963
- self.train_dir = JsonlDataDir.init_from_file(self.train_file, self.lines_per_file)
964
-
965
- def summary_records(self):
966
- """ 一些统计信息 """
967
- # 1 chat信息
968
- gcd1 = self.chatted_dir or self.chat_dir
969
- if not gcd1:
970
- print('请确认是否有生成初始的chat数据')
971
- return
972
-
973
- print(f'【{self.root.name}】')
974
- texts = [len(x['text']) for x in gcd1.yield_record()]
975
- n, m = len(texts), sum(texts)
976
- print(f'1、chat:{n}条会话*{m / n:.2g}条消息')
977
- gcj1 = GptChatJsonl(gcd1.files[0]) # 统计一个文件就够了,不然太多了
978
- gcj1.check()
979
- print()
980
-
981
- # 2 chatted信息
982
- filter_records = [x for x in gcd1.yield_record() if 'all_answers' in x]
983
- if filter_records:
984
- print(f'2、chatted:已获得{len(filter_records)}条会话')
985
- else:
986
- print('2、chatted:暂未获得生成数据')
987
-
988
- # 3 post信息
989
- if self.post_dir:
990
- print(f'3、post:{self.post_dir.count_records()}条会话')
991
-
992
- # 4 verify(这一步有时候会集成到post中)
993
- if self.verify_dir:
994
- print(f'4、verify:{self.verify_dir.count_records()}条会话')
995
-
996
- # 5 train 生成的训练数据
997
- # print('5、train:')
998
- # gtj = GptTrainJsonl(self.train_file)
999
- # gtj.analyze_text_length()
1000
-
1001
- def summary_downloads(self):
1002
- """ 统计下载的文件情况 """
1003
- print('【每个目录文件数量】')
1004
- files_each_dir = []
1005
- for d in self.download_files_dir.glob_dirs():
1006
- files_each_dir.append(len(list(d.rglob_files())))
1007
- print(ValuesStat(files_each_dir).summary())
1008
- print(Counter(files_each_dir))
1009
-
1010
- print('【每个文件大小】')
1011
- filesizes_each_dir = []
1012
- for d in self.download_files_dir.glob_dirs():
1013
- for f in d.rglob_files():
1014
- filesizes_each_dir.append(f.size())
1015
- print(ValuesStat(filesizes_each_dir).summary())
1016
-
1017
- def create_chat(self):
1018
- """ 生成chat数据,具体内容方式跟业务有关 """
1019
- raise NotImplementedError
1020
-
1021
- def browse_chatted_record(self, index=None, paths=None, **kwargs):
1022
- """ 显示第i次会话的内容 """
1023
- f = self.chatted_file if self.chatted_file.is_file() else self.chat_file
1024
- return GptChatJsonl(f, 100).browse_record(index, paths, **kwargs)
1025
-
1026
- def chatted2post_record(self, chatted_record):
1027
- """ 后处理,解析
1028
-
1029
- 一般会保留基本的all_answers结果,供检查上游一些基本情况
1030
- 然后把一些结构化结果存储到extra字段
1031
-
1032
- :return: 会返回新的dict结构的一个post_record,如果解析失败,会返回None
1033
- """
1034
- # 0 基本情况判断
1035
- if 'all_answers' not in chatted_record:
1036
- return
1037
-
1038
- post_record = copy.deepcopy(chatted_record)
1039
-
1040
- # 1 删掉一些没卵用的字段
1041
- for name in ['all_questions', 'first_text_length', 'second_text_length']:
1042
- if name in post_record:
1043
- del post_record[name]
1044
-
1045
- # 2 解析all_answers:这个结构太复杂,进行内容整理精简
1046
- # 2.1 contents:这个结构太复杂,搁这俄罗斯套娃呢~ 稍微精简下更方便后处理
1047
- for k, answer in enumerate(post_record['all_answers']):
1048
- if isinstance(answer, dict) and 'contents' in answer:
1049
- new_contents = []
1050
- for i, x in enumerate(answer['contents']):
1051
- if not x['message']:
1052
- # Error in message stream
1053
- # print(f'{post_record["id"]} answer[{k}] contents[{i}] message为空')
1054
- continue
1055
-
1056
- content = x['message']['content']
1057
- tp = content['content_type']
1058
- new_content = {'type': content['content_type']}
1059
- if tp == 'text':
1060
- new_content['text'] = '\n'.join(content['parts'])
1061
- elif tp == 'code':
1062
- new_content['text'] = content['text']
1063
- elif tp == 'execution_output':
1064
- new_content['text'] = content['text']
1065
- elif tp == 'system_error':
1066
- continue
1067
- else:
1068
- print(f'{post_record["id"]} answer[{k}] contents[{i}] content_type={tp} 未见类型')
1069
- continue
1070
-
1071
- new_contents.append(new_content)
1072
- answer['contents'] = new_contents
1073
- elif isinstance(answer, str): # 普通模式也转成解释器风格,方便统一处理
1074
- post_record['all_answers'][k] = {'contents': [{'type': 'text',
1075
- 'text': answer}]}
1076
-
1077
- # 2.2 downloads:下载链接精简下,并把关联的文件也顺带整理一下
1078
- for answer in post_record['all_answers']:
1079
- if 'downloads' not in answer:
1080
- continue
1081
- for i, link in enumerate(answer['downloads']):
1082
- m = re.search(r'filename%3D(.+?)&sig=', link)
1083
- if m:
1084
- answer['downloads'][i] = str(post_record['id']) + '/' + unquote(unquote(m.group(1)))
1085
- # 对应的文件不存在的不要,有数据超过50M的也不要
1086
- file = self.download_files_dir / answer['downloads'][i]
1087
- if not file.exists() or file.size() > 50 * 1024 * 1024:
1088
- return
1089
-
1090
- # 理论上下载的文件不应该有重复,虽然不知道为什么会拿到重复,但去掉重复比较好
1091
- answer['downloads'] = list(OrderedDict.fromkeys(answer['downloads']))
1092
-
1093
- # 2.3 删掉answer里其他没用的字段
1094
- for answer in post_record['all_answers']:
1095
- for name in ['created', 'message_id', 'conversation_id', 'end_turn']:
1096
- if name in answer:
1097
- del answer[name]
1098
-
1099
- # 返回处理结果
1100
- return post_record
1101
-
1102
- @staticmethod
1103
- def post2verify_record(post_record):
1104
- """ 这个一般是要具体任务定制的,没有通用操作方式
1105
-
1106
- 注意,如果要使用create_verify的多进程功能,这个函数必须是静态的,并且里面也不能使用其他"类静态方法"
1107
- 否则写成类方法或对象方法都可以
1108
-
1109
- """
1110
- raise NotImplementedError
1111
-
1112
- def verify2train_record(self, verify_record):
1113
- """ 这个一般是要具体任务定制的,没有通用操作方式 """
1114
- raise NotImplementedError
1115
-
1116
- def organize_downloaded_files(self):
1117
- # 把下载的文件整理的更清晰些
1118
- for f in tqdm(list(self.root.glob_files('OpenAI-download-*')),
1119
- desc='整理下载的文件'):
1120
- new_name = re.sub(r'OpenAI-download-\d+-', '', f.name)
1121
- new_name = new_name.replace('-', '/', 1)
1122
- try:
1123
- (self.download_files_dir / new_name).parent.mkdir(exist_ok=True)
1124
- f.rename2(self.download_files_dir / new_name, if_exists='replace')
1125
- except FileExistsError as e:
1126
- # 有的文件会移动不了
1127
- print(e)
1128
-
1129
- # 会剩一些特殊的处理不了的文件,可以看一眼后手动删掉
1130
- # 这些相关的records,默认的chatted2post_record会把这些记录过滤掉
1131
-
1132
- def create_post(self, **kwargs):
1133
- """ 建议初步跑的时候,先串行debug,等比较稳定后,再开并发跑
1134
- """
1135
- if 'dst_dir' not in kwargs:
1136
- kwargs['dst_dir'] = self.post_dir.root
1137
- self.chatted_dir.process_each_record(self.chatted2post_record, **kwargs)
1138
- self.post_dir.update_subfiles()
1139
- num1, num2 = self.chatted_dir.count_records(), self.post_dir.count_records()
1140
- print(f'chatted有{num1}条,转换post有{num2}条,转换率{num2 / num1:.2%}')
1141
-
1142
- def create_verify(self, **kwargs):
1143
- """ 有时候create_verify是有cpu密集运算场景的,可以开多进程
1144
- """
1145
- if 'dst_dir' not in kwargs:
1146
- kwargs['dst_dir'] = self.verify_dir.root
1147
- self.post_dir.process_each_record(self.post2verify_record, **kwargs)
1148
- self.verify_dir.update_subfiles()
1149
- num1, num2 = self.post_dir.count_records(), self.verify_dir.count_records()
1150
- num1 = num1 or -1
1151
- print(f'post有{num1}条,转换verify有{num2}条,转换率{num2 / num1:.2%}')
1152
-
1153
- def refine_verify(self, print_mode=1, **kwargs):
1154
- """ 重复检查verify数据
1155
-
1156
- 这个函数可以重复执行,但前提是self.post2verify_record里的设计有增量规则部分
1157
- """
1158
- self.verify_dir.process_each_record(self.post2verify_record, print_mode=print_mode,
1159
- inplace=True, desc='refine_verify', **kwargs)
1160
-
1161
- @classmethod
1162
- def texts2train_record(cls, texts):
1163
- """ user和assistant的轮询对话,转为训练集格式 """
1164
- messages = []
1165
- for i, text in enumerate(texts):
1166
- role = 'assistant' if i % 2 else 'user'
1167
- messages.append({'role': role, 'content': text})
1168
- return {'messages': messages}
1169
-
1170
- def create_train(self, **kwargs):
1171
- if 'dst_dir' not in kwargs:
1172
- kwargs['dst_dir'] = self.train_dir.root
1173
- self.post_dir.process_each_record(self.verify2train_record, **kwargs)
1174
- self.train_dir.update_subfiles()
1175
-
1176
- def check_chatted_record(self, chatted_record):
1177
- """ 检查chatted数据的有效性 """
1178
- x = chatted_record
1179
- x = self.chatted2post_record(x)
1180
- # x = self.post2verify_record(x)
1181
- # 针对verify可以再进一步定制规则
1182
- return bool(x)
1183
-
1184
- def create_rechat(self, rechat_path):
1185
- """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
1186
-
1187
- :param rechat_path: 把挑选出来的数据放到新路径
1188
- """
1189
- gcd = GptChatDir(rechat_path)
1190
- f = open(gcd.chat_file, 'w', encoding='utf-8')
1191
-
1192
- for record in tqdm(self.chatted_dir.yield_record(), '检查待重新生成的问题'):
1193
- if self.check_chatted_record(record):
1194
- continue
1195
- # 否则把这个条目放到rechat,准备拿去重新提问
1196
- if 'error' in record:
1197
- del record['error']
1198
- f.write(json.dumps(record, ensure_ascii=False) + '\n')
1199
- # 如果有文件,也要对应移动
1200
- src_dir = self.upload_files_dir / str(record['id'])
1201
- if src_dir.is_dir():
1202
- src_dir.copy(gcd.upload_files_dir / src_dir.name, if_exists='skip')
1203
-
1204
- f.close()
1205
- return gcd
1206
-
1207
- def update_chatted(self, rechat_path):
1208
- """ 从另一个rechat数据,更新数据条目过来
1209
-
1210
- self依然叫src,rechat叫dst,虽然其实数据是从rechat更新流向self
1211
-
1212
- 注意:这个函数还没有比较严格地进行调试~
1213
- """
1214
- # 1 读取有效记录
1215
- gcd = GptChatDir(rechat_path)
1216
- gcd.organize_downloaded_files()
1217
- # 请确保内存充足哦,这个函数会从rechat的chatted读取所有通过的记录保存起来
1218
- dst_records = {}
1219
- for record in gcd.chatted_dir.yield_record():
1220
- # 找到有all_answers的挑出来
1221
- post_record = self.chatted2post_record(record)
1222
- if post_record:
1223
- dst_records[record['id']] = record
1224
-
1225
- # 2 更新记录
1226
- def update_each_record(x):
1227
- if x['id'] in dst_records:
1228
- # 除了返回record,还得拷贝目录数据呢
1229
- # 上传的目录一般没变,但最好重置下
1230
- src_dir = self.upload_files_dir / x['id']
1231
- dst_dir = gcd.upload_files_dir / x['id']
1232
- dst_dir.copy(src_dir, if_exists='replace')
1233
- # 下载的目录
1234
- src_dir = self.download_files_dir / x['id']
1235
- dst_dir = gcd.download_files_dir / x['id']
1236
- dst_dir.copy(src_dir, if_exists='replace')
1237
- return dst_records[x['id']]
1238
- else:
1239
- return x
1240
-
1241
- self.chatted_dir.update_each_record(update_each_record)
1242
-
1243
-
1244
- def __5_bdchat():
1245
- """ 百度相关api """
1246
-
1247
-
1248
- class BaiduChatbot:
1249
- def __init__(self, api_key, secret_key, file_path=None):
1250
- self.API_KEY = api_key
1251
- self.SECRET_KEY = secret_key
1252
- self.ACCESS_TOKEN = self._get_access_token()
1253
- self.base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
1254
- self.file_path = file_path # 文件路径为可选参数
1255
-
1256
- def _get_access_token(self):
1257
- """
1258
- 使用 AK,SK 生成鉴权签名(Access Token)
1259
- :return: access_token,或是None(如果错误)
1260
- """
1261
- url = "https://aip.baidubce.com/oauth/2.0/token"
1262
- params = {
1263
- "grant_type": "client_credentials",
1264
- "client_id": self.API_KEY,
1265
- "client_secret": self.SECRET_KEY
1266
- }
1267
- return str(requests.post(url, params=params).json().get("access_token"))
1268
-
1269
- def chat(self, user_message):
1270
- """ 向Baidu API发送用户消息并返回API的回复
1271
- 注意user_message的token不要超过3k
1272
- """
1273
- url = self.base_url + self.ACCESS_TOKEN
1274
- payload = json.dumps({
1275
- "messages": [{"role": "user", "content": user_message}]
1276
- })
1277
- headers = {'Content-Type': 'application/json'}
1278
- response = requests.post(url, headers=headers, data=payload)
1279
- response_json = response.json()
1280
- response_json['user_message'] = user_message
1281
- response_json['timestamp'] = datetime.datetime.now().isoformat()
1282
-
1283
- # 如果指定了文件路径,自动保存记录
1284
- if self.file_path:
1285
- self._save_to_file(response_json)
1286
-
1287
- return response_json.get('result', '')
1288
-
1289
- def _save_to_file(self, response):
1290
- with open(self.file_path, 'a', encoding='utf-8') as file:
1291
- file.write(json.dumps(response, ensure_ascii=False) + '\n')
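For reference, the BaiduChatbot class removed above is a thin wrapper around Baidu's wenxinworkshop chat-completion endpoint: it trades the AK/SK pair for an access token, sends a single-turn messages payload, and optionally appends every raw response to a JSONL log. A minimal usage sketch follows; the key values and log path are placeholders, and it only applies to pyxllib 0.3.96 or earlier, where pyxlpr/data/gptlib.py still exists.

# Usage sketch for the removed BaiduChatbot class (pyxllib <= 0.3.96 only).
# "YOUR_AK"/"YOUR_SK" and the log path are placeholder values, not real credentials.
from pyxlpr.data.gptlib import BaiduChatbot

bot = BaiduChatbot(api_key="YOUR_AK", secret_key="YOUR_SK", file_path="baidu_chat_log.jsonl")
reply = bot.chat("Summarize what an access token is in one sentence.")  # single-turn, keep under ~3k tokens
print(reply)  # the 'result' field of the API response; '' if the call failed
# When file_path is set, each raw response (plus user_message and timestamp) is appended
# as one JSON line to baidu_chat_log.jsonl by _save_to_file().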