pyxllib 0.3.96__py3-none-any.whl → 0.3.197__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- pyxllib/algo/geo.py +12 -0
- pyxllib/algo/intervals.py +1 -1
- pyxllib/algo/matcher.py +78 -0
- pyxllib/algo/pupil.py +187 -19
- pyxllib/algo/specialist.py +2 -1
- pyxllib/algo/stat.py +38 -2
- {pyxlpr → pyxllib/autogui}/__init__.py +1 -1
- pyxllib/autogui/activewin.py +246 -0
- pyxllib/autogui/all.py +9 -0
- pyxllib/{ext/autogui → autogui}/autogui.py +40 -11
- pyxllib/autogui/uiautolib.py +362 -0
- pyxllib/autogui/wechat.py +827 -0
- pyxllib/autogui/wechat_msg.py +421 -0
- pyxllib/autogui/wxautolib.py +84 -0
- pyxllib/cv/slidercaptcha.py +137 -0
- pyxllib/data/echarts.py +123 -12
- pyxllib/data/jsonlib.py +89 -0
- pyxllib/data/pglib.py +514 -30
- pyxllib/data/sqlite.py +231 -4
- pyxllib/ext/JLineViewer.py +14 -1
- pyxllib/ext/drissionlib.py +277 -0
- pyxllib/ext/kq5034lib.py +0 -1594
- pyxllib/ext/robustprocfile.py +497 -0
- pyxllib/ext/unixlib.py +6 -5
- pyxllib/ext/utools.py +108 -95
- pyxllib/ext/webhook.py +32 -14
- pyxllib/ext/wjxlib.py +88 -0
- pyxllib/ext/wpsapi.py +124 -0
- pyxllib/ext/xlwork.py +9 -0
- pyxllib/ext/yuquelib.py +1003 -71
- pyxllib/file/docxlib.py +1 -1
- pyxllib/file/libreoffice.py +165 -0
- pyxllib/file/movielib.py +9 -0
- pyxllib/file/packlib/__init__.py +112 -75
- pyxllib/file/pdflib.py +1 -1
- pyxllib/file/pupil.py +1 -1
- pyxllib/file/specialist/dirlib.py +1 -1
- pyxllib/file/specialist/download.py +10 -3
- pyxllib/file/specialist/filelib.py +266 -55
- pyxllib/file/xlsxlib.py +205 -50
- pyxllib/file/xlsyncfile.py +341 -0
- pyxllib/prog/cachetools.py +64 -0
- pyxllib/prog/filelock.py +42 -0
- pyxllib/prog/multiprogs.py +940 -0
- pyxllib/prog/newbie.py +9 -2
- pyxllib/prog/pupil.py +129 -60
- pyxllib/prog/specialist/__init__.py +176 -2
- pyxllib/prog/specialist/bc.py +5 -2
- pyxllib/prog/specialist/browser.py +11 -2
- pyxllib/prog/specialist/datetime.py +68 -0
- pyxllib/prog/specialist/tictoc.py +12 -13
- pyxllib/prog/specialist/xllog.py +5 -5
- pyxllib/prog/xlosenv.py +7 -0
- pyxllib/text/airscript.js +744 -0
- pyxllib/text/charclasslib.py +17 -5
- pyxllib/text/jiebalib.py +6 -3
- pyxllib/text/jinjalib.py +32 -0
- pyxllib/text/jsa_ai_prompt.md +271 -0
- pyxllib/text/jscode.py +159 -4
- pyxllib/text/nestenv.py +1 -1
- pyxllib/text/newbie.py +12 -0
- pyxllib/text/pupil/common.py +26 -0
- pyxllib/text/specialist/ptag.py +2 -2
- pyxllib/text/templates/echart_base.html +11 -0
- pyxllib/text/templates/highlight_code.html +17 -0
- pyxllib/text/templates/latex_editor.html +103 -0
- pyxllib/text/xmllib.py +76 -14
- pyxllib/xl.py +2 -1
- pyxllib-0.3.197.dist-info/METADATA +48 -0
- pyxllib-0.3.197.dist-info/RECORD +126 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info}/WHEEL +1 -2
- pyxllib/ext/autogui/__init__.py +0 -8
- pyxllib-0.3.96.dist-info/METADATA +0 -51
- pyxllib-0.3.96.dist-info/RECORD +0 -333
- pyxllib-0.3.96.dist-info/top_level.txt +0 -2
- pyxlpr/ai/__init__.py +0 -5
- pyxlpr/ai/clientlib.py +0 -1281
- pyxlpr/ai/specialist.py +0 -286
- pyxlpr/ai/torch_app.py +0 -172
- pyxlpr/ai/xlpaddle.py +0 -655
- pyxlpr/ai/xltorch.py +0 -705
- pyxlpr/data/__init__.py +0 -11
- pyxlpr/data/coco.py +0 -1325
- pyxlpr/data/datacls.py +0 -365
- pyxlpr/data/datasets.py +0 -200
- pyxlpr/data/gptlib.py +0 -1291
- pyxlpr/data/icdar/__init__.py +0 -96
- pyxlpr/data/icdar/deteval.py +0 -377
- pyxlpr/data/icdar/icdar2013.py +0 -341
- pyxlpr/data/icdar/iou.py +0 -340
- pyxlpr/data/icdar/rrc_evaluation_funcs_1_1.py +0 -463
- pyxlpr/data/imtextline.py +0 -473
- pyxlpr/data/labelme.py +0 -866
- pyxlpr/data/removeline.py +0 -179
- pyxlpr/data/specialist.py +0 -57
- pyxlpr/eval/__init__.py +0 -85
- pyxlpr/paddleocr.py +0 -776
- pyxlpr/ppocr/__init__.py +0 -15
- pyxlpr/ppocr/configs/rec/multi_language/generate_multi_language_configs.py +0 -226
- pyxlpr/ppocr/data/__init__.py +0 -135
- pyxlpr/ppocr/data/imaug/ColorJitter.py +0 -26
- pyxlpr/ppocr/data/imaug/__init__.py +0 -67
- pyxlpr/ppocr/data/imaug/copy_paste.py +0 -170
- pyxlpr/ppocr/data/imaug/east_process.py +0 -437
- pyxlpr/ppocr/data/imaug/gen_table_mask.py +0 -244
- pyxlpr/ppocr/data/imaug/iaa_augment.py +0 -114
- pyxlpr/ppocr/data/imaug/label_ops.py +0 -789
- pyxlpr/ppocr/data/imaug/make_border_map.py +0 -184
- pyxlpr/ppocr/data/imaug/make_pse_gt.py +0 -106
- pyxlpr/ppocr/data/imaug/make_shrink_map.py +0 -126
- pyxlpr/ppocr/data/imaug/operators.py +0 -433
- pyxlpr/ppocr/data/imaug/pg_process.py +0 -906
- pyxlpr/ppocr/data/imaug/randaugment.py +0 -143
- pyxlpr/ppocr/data/imaug/random_crop_data.py +0 -239
- pyxlpr/ppocr/data/imaug/rec_img_aug.py +0 -533
- pyxlpr/ppocr/data/imaug/sast_process.py +0 -777
- pyxlpr/ppocr/data/imaug/text_image_aug/__init__.py +0 -17
- pyxlpr/ppocr/data/imaug/text_image_aug/augment.py +0 -120
- pyxlpr/ppocr/data/imaug/text_image_aug/warp_mls.py +0 -168
- pyxlpr/ppocr/data/lmdb_dataset.py +0 -115
- pyxlpr/ppocr/data/pgnet_dataset.py +0 -104
- pyxlpr/ppocr/data/pubtab_dataset.py +0 -107
- pyxlpr/ppocr/data/simple_dataset.py +0 -372
- pyxlpr/ppocr/losses/__init__.py +0 -61
- pyxlpr/ppocr/losses/ace_loss.py +0 -52
- pyxlpr/ppocr/losses/basic_loss.py +0 -135
- pyxlpr/ppocr/losses/center_loss.py +0 -88
- pyxlpr/ppocr/losses/cls_loss.py +0 -30
- pyxlpr/ppocr/losses/combined_loss.py +0 -67
- pyxlpr/ppocr/losses/det_basic_loss.py +0 -208
- pyxlpr/ppocr/losses/det_db_loss.py +0 -80
- pyxlpr/ppocr/losses/det_east_loss.py +0 -63
- pyxlpr/ppocr/losses/det_pse_loss.py +0 -149
- pyxlpr/ppocr/losses/det_sast_loss.py +0 -121
- pyxlpr/ppocr/losses/distillation_loss.py +0 -272
- pyxlpr/ppocr/losses/e2e_pg_loss.py +0 -140
- pyxlpr/ppocr/losses/kie_sdmgr_loss.py +0 -113
- pyxlpr/ppocr/losses/rec_aster_loss.py +0 -99
- pyxlpr/ppocr/losses/rec_att_loss.py +0 -39
- pyxlpr/ppocr/losses/rec_ctc_loss.py +0 -44
- pyxlpr/ppocr/losses/rec_enhanced_ctc_loss.py +0 -70
- pyxlpr/ppocr/losses/rec_nrtr_loss.py +0 -30
- pyxlpr/ppocr/losses/rec_sar_loss.py +0 -28
- pyxlpr/ppocr/losses/rec_srn_loss.py +0 -47
- pyxlpr/ppocr/losses/table_att_loss.py +0 -109
- pyxlpr/ppocr/metrics/__init__.py +0 -44
- pyxlpr/ppocr/metrics/cls_metric.py +0 -45
- pyxlpr/ppocr/metrics/det_metric.py +0 -82
- pyxlpr/ppocr/metrics/distillation_metric.py +0 -73
- pyxlpr/ppocr/metrics/e2e_metric.py +0 -86
- pyxlpr/ppocr/metrics/eval_det_iou.py +0 -274
- pyxlpr/ppocr/metrics/kie_metric.py +0 -70
- pyxlpr/ppocr/metrics/rec_metric.py +0 -75
- pyxlpr/ppocr/metrics/table_metric.py +0 -50
- pyxlpr/ppocr/modeling/architectures/__init__.py +0 -32
- pyxlpr/ppocr/modeling/architectures/base_model.py +0 -88
- pyxlpr/ppocr/modeling/architectures/distillation_model.py +0 -60
- pyxlpr/ppocr/modeling/backbones/__init__.py +0 -54
- pyxlpr/ppocr/modeling/backbones/det_mobilenet_v3.py +0 -268
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd.py +0 -246
- pyxlpr/ppocr/modeling/backbones/det_resnet_vd_sast.py +0 -285
- pyxlpr/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +0 -265
- pyxlpr/ppocr/modeling/backbones/kie_unet_sdmgr.py +0 -186
- pyxlpr/ppocr/modeling/backbones/rec_mobilenet_v3.py +0 -138
- pyxlpr/ppocr/modeling/backbones/rec_mv1_enhance.py +0 -258
- pyxlpr/ppocr/modeling/backbones/rec_nrtr_mtb.py +0 -48
- pyxlpr/ppocr/modeling/backbones/rec_resnet_31.py +0 -210
- pyxlpr/ppocr/modeling/backbones/rec_resnet_aster.py +0 -143
- pyxlpr/ppocr/modeling/backbones/rec_resnet_fpn.py +0 -307
- pyxlpr/ppocr/modeling/backbones/rec_resnet_vd.py +0 -286
- pyxlpr/ppocr/modeling/heads/__init__.py +0 -54
- pyxlpr/ppocr/modeling/heads/cls_head.py +0 -52
- pyxlpr/ppocr/modeling/heads/det_db_head.py +0 -118
- pyxlpr/ppocr/modeling/heads/det_east_head.py +0 -121
- pyxlpr/ppocr/modeling/heads/det_pse_head.py +0 -37
- pyxlpr/ppocr/modeling/heads/det_sast_head.py +0 -128
- pyxlpr/ppocr/modeling/heads/e2e_pg_head.py +0 -253
- pyxlpr/ppocr/modeling/heads/kie_sdmgr_head.py +0 -206
- pyxlpr/ppocr/modeling/heads/multiheadAttention.py +0 -163
- pyxlpr/ppocr/modeling/heads/rec_aster_head.py +0 -393
- pyxlpr/ppocr/modeling/heads/rec_att_head.py +0 -202
- pyxlpr/ppocr/modeling/heads/rec_ctc_head.py +0 -88
- pyxlpr/ppocr/modeling/heads/rec_nrtr_head.py +0 -826
- pyxlpr/ppocr/modeling/heads/rec_sar_head.py +0 -402
- pyxlpr/ppocr/modeling/heads/rec_srn_head.py +0 -280
- pyxlpr/ppocr/modeling/heads/self_attention.py +0 -406
- pyxlpr/ppocr/modeling/heads/table_att_head.py +0 -246
- pyxlpr/ppocr/modeling/necks/__init__.py +0 -32
- pyxlpr/ppocr/modeling/necks/db_fpn.py +0 -111
- pyxlpr/ppocr/modeling/necks/east_fpn.py +0 -188
- pyxlpr/ppocr/modeling/necks/fpn.py +0 -138
- pyxlpr/ppocr/modeling/necks/pg_fpn.py +0 -314
- pyxlpr/ppocr/modeling/necks/rnn.py +0 -92
- pyxlpr/ppocr/modeling/necks/sast_fpn.py +0 -284
- pyxlpr/ppocr/modeling/necks/table_fpn.py +0 -110
- pyxlpr/ppocr/modeling/transforms/__init__.py +0 -28
- pyxlpr/ppocr/modeling/transforms/stn.py +0 -135
- pyxlpr/ppocr/modeling/transforms/tps.py +0 -308
- pyxlpr/ppocr/modeling/transforms/tps_spatial_transformer.py +0 -156
- pyxlpr/ppocr/optimizer/__init__.py +0 -61
- pyxlpr/ppocr/optimizer/learning_rate.py +0 -228
- pyxlpr/ppocr/optimizer/lr_scheduler.py +0 -49
- pyxlpr/ppocr/optimizer/optimizer.py +0 -160
- pyxlpr/ppocr/optimizer/regularizer.py +0 -52
- pyxlpr/ppocr/postprocess/__init__.py +0 -55
- pyxlpr/ppocr/postprocess/cls_postprocess.py +0 -33
- pyxlpr/ppocr/postprocess/db_postprocess.py +0 -234
- pyxlpr/ppocr/postprocess/east_postprocess.py +0 -143
- pyxlpr/ppocr/postprocess/locality_aware_nms.py +0 -200
- pyxlpr/ppocr/postprocess/pg_postprocess.py +0 -52
- pyxlpr/ppocr/postprocess/pse_postprocess/__init__.py +0 -15
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/__init__.py +0 -29
- pyxlpr/ppocr/postprocess/pse_postprocess/pse/setup.py +0 -14
- pyxlpr/ppocr/postprocess/pse_postprocess/pse_postprocess.py +0 -118
- pyxlpr/ppocr/postprocess/rec_postprocess.py +0 -654
- pyxlpr/ppocr/postprocess/sast_postprocess.py +0 -355
- pyxlpr/ppocr/tools/__init__.py +0 -14
- pyxlpr/ppocr/tools/eval.py +0 -83
- pyxlpr/ppocr/tools/export_center.py +0 -77
- pyxlpr/ppocr/tools/export_model.py +0 -129
- pyxlpr/ppocr/tools/infer/predict_cls.py +0 -151
- pyxlpr/ppocr/tools/infer/predict_det.py +0 -300
- pyxlpr/ppocr/tools/infer/predict_e2e.py +0 -169
- pyxlpr/ppocr/tools/infer/predict_rec.py +0 -414
- pyxlpr/ppocr/tools/infer/predict_system.py +0 -204
- pyxlpr/ppocr/tools/infer/utility.py +0 -629
- pyxlpr/ppocr/tools/infer_cls.py +0 -83
- pyxlpr/ppocr/tools/infer_det.py +0 -134
- pyxlpr/ppocr/tools/infer_e2e.py +0 -122
- pyxlpr/ppocr/tools/infer_kie.py +0 -153
- pyxlpr/ppocr/tools/infer_rec.py +0 -146
- pyxlpr/ppocr/tools/infer_table.py +0 -107
- pyxlpr/ppocr/tools/program.py +0 -596
- pyxlpr/ppocr/tools/test_hubserving.py +0 -117
- pyxlpr/ppocr/tools/train.py +0 -163
- pyxlpr/ppocr/tools/xlprog.py +0 -748
- pyxlpr/ppocr/utils/EN_symbol_dict.txt +0 -94
- pyxlpr/ppocr/utils/__init__.py +0 -24
- pyxlpr/ppocr/utils/dict/ar_dict.txt +0 -117
- pyxlpr/ppocr/utils/dict/arabic_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/be_dict.txt +0 -145
- pyxlpr/ppocr/utils/dict/bg_dict.txt +0 -140
- pyxlpr/ppocr/utils/dict/chinese_cht_dict.txt +0 -8421
- pyxlpr/ppocr/utils/dict/cyrillic_dict.txt +0 -163
- pyxlpr/ppocr/utils/dict/devanagari_dict.txt +0 -167
- pyxlpr/ppocr/utils/dict/en_dict.txt +0 -63
- pyxlpr/ppocr/utils/dict/fa_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/french_dict.txt +0 -136
- pyxlpr/ppocr/utils/dict/german_dict.txt +0 -143
- pyxlpr/ppocr/utils/dict/hi_dict.txt +0 -162
- pyxlpr/ppocr/utils/dict/it_dict.txt +0 -118
- pyxlpr/ppocr/utils/dict/japan_dict.txt +0 -4399
- pyxlpr/ppocr/utils/dict/ka_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/korean_dict.txt +0 -3688
- pyxlpr/ppocr/utils/dict/latin_dict.txt +0 -185
- pyxlpr/ppocr/utils/dict/mr_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/ne_dict.txt +0 -153
- pyxlpr/ppocr/utils/dict/oc_dict.txt +0 -96
- pyxlpr/ppocr/utils/dict/pu_dict.txt +0 -130
- pyxlpr/ppocr/utils/dict/rs_dict.txt +0 -91
- pyxlpr/ppocr/utils/dict/rsc_dict.txt +0 -134
- pyxlpr/ppocr/utils/dict/ru_dict.txt +0 -125
- pyxlpr/ppocr/utils/dict/ta_dict.txt +0 -128
- pyxlpr/ppocr/utils/dict/table_dict.txt +0 -277
- pyxlpr/ppocr/utils/dict/table_structure_dict.txt +0 -2759
- pyxlpr/ppocr/utils/dict/te_dict.txt +0 -151
- pyxlpr/ppocr/utils/dict/ug_dict.txt +0 -114
- pyxlpr/ppocr/utils/dict/uk_dict.txt +0 -142
- pyxlpr/ppocr/utils/dict/ur_dict.txt +0 -137
- pyxlpr/ppocr/utils/dict/xi_dict.txt +0 -110
- pyxlpr/ppocr/utils/dict90.txt +0 -90
- pyxlpr/ppocr/utils/e2e_metric/Deteval.py +0 -574
- pyxlpr/ppocr/utils/e2e_metric/polygon_fast.py +0 -83
- pyxlpr/ppocr/utils/e2e_utils/extract_batchsize.py +0 -87
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_fast.py +0 -457
- pyxlpr/ppocr/utils/e2e_utils/extract_textpoint_slow.py +0 -592
- pyxlpr/ppocr/utils/e2e_utils/pgnet_pp_utils.py +0 -162
- pyxlpr/ppocr/utils/e2e_utils/visual.py +0 -162
- pyxlpr/ppocr/utils/en_dict.txt +0 -95
- pyxlpr/ppocr/utils/gen_label.py +0 -81
- pyxlpr/ppocr/utils/ic15_dict.txt +0 -36
- pyxlpr/ppocr/utils/iou.py +0 -54
- pyxlpr/ppocr/utils/logging.py +0 -69
- pyxlpr/ppocr/utils/network.py +0 -84
- pyxlpr/ppocr/utils/ppocr_keys_v1.txt +0 -6623
- pyxlpr/ppocr/utils/profiler.py +0 -110
- pyxlpr/ppocr/utils/save_load.py +0 -150
- pyxlpr/ppocr/utils/stats.py +0 -72
- pyxlpr/ppocr/utils/utility.py +0 -80
- pyxlpr/ppstructure/__init__.py +0 -13
- pyxlpr/ppstructure/predict_system.py +0 -187
- pyxlpr/ppstructure/table/__init__.py +0 -13
- pyxlpr/ppstructure/table/eval_table.py +0 -72
- pyxlpr/ppstructure/table/matcher.py +0 -192
- pyxlpr/ppstructure/table/predict_structure.py +0 -136
- pyxlpr/ppstructure/table/predict_table.py +0 -221
- pyxlpr/ppstructure/table/table_metric/__init__.py +0 -16
- pyxlpr/ppstructure/table/table_metric/parallel.py +0 -51
- pyxlpr/ppstructure/table/table_metric/table_metric.py +0 -247
- pyxlpr/ppstructure/table/tablepyxl/__init__.py +0 -13
- pyxlpr/ppstructure/table/tablepyxl/style.py +0 -283
- pyxlpr/ppstructure/table/tablepyxl/tablepyxl.py +0 -118
- pyxlpr/ppstructure/utility.py +0 -71
- pyxlpr/xlai.py +0 -10
- /pyxllib/{ext/autogui → autogui}/virtualkey.py +0 -0
- {pyxllib-0.3.96.dist-info → pyxllib-0.3.197.dist-info/licenses}/LICENSE +0 -0
pyxlpr/data/gptlib.py
DELETED
@@ -1,1291 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# @Author : 陈坤泽
-# @Email : 877362867@qq.com
-# @Date : 2023/07/13 14:26
-
-from pyxllib.prog.pupil import check_install_package
-
-# check_install_package('transformers', 'transformers')
-
-import ast
-from collections import OrderedDict
-from collections import Counter
-import contextlib
-import copy
-import datetime
-import heapq
-import html
-import json
-import math
-import random
-import re
-from urllib.parse import unquote
-import io
-import logging
-import warnings
-
-from jinja2 import Template
-from openpyxl import Workbook
-import pandas as pd
-import requests
-from tqdm import tqdm
-
-try:
-    from transformers import AutoTokenizer, GPT2TokenizerFast
-except ModuleNotFoundError:
-    pass
-
-from pyxllib.prog.pupil import OutputLogger
-from pyxllib.prog.specialist import browser, TicToc
-from pyxllib.algo.pupil import ValuesStat
-from pyxllib.file.specialist import XlPath, JsonlDataFile, JsonlDataDir, TwinDirs, ensure_localdir
-from pyxllib.file.xlsxlib import extract_workbook_summary
-
-
-def __1_生成提问数据():
-    pass
-
-
-class Tokenizer:
-    _tokenizer = None
-
-    @classmethod
-    def get_tokenizer(cls):
-        """ 获取tokenizer,第一次调用时进行初始化 """
-
-        if cls._tokenizer is None:
-            # 根本没必要每次都尝试连接官网,本地有就不要老是sb的尝试连接huggingface
-            # 而且官网连接也不稳,这里换成我自己的服务器中转
-            # gpt2_dir = XlPath.tempdir() / 'huggingface_gpt2'
-            # ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/huggingface_gpt2.zip')
-            # Tokenizer._tokenizer = GPT2TokenizerFast.from_pretrained(gpt2_dir)
-            # 240103周三21:23,hx给过的新评测模型
-            gpt2_dir = XlPath.tempdir() / 'Atom-CL-SS'
-            ensure_localdir(gpt2_dir, 'https://xmutpriu.com/download/Atom-CL-SS.zip')
-            cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, trust_remote_code=True)
-        return cls._tokenizer
-
-    @classmethod
-    def tokenize(cls, paragraph, max_length=500):
-        """ 对段落进行tokenize
-
-        :param str paragraph: 待分词的段落
-        :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
-        :return list: 分词后的列表
-
-        >>> Tokenizer.tokenize('Hello, world! 汉字 123.14 35')
-        ['Hello', ',', 'Ġworld', '!', 'Ġæ', '±', 'ī', 'åŃ', 'Ĺ', 'Ġ123', '.', '14', 'Ġ35']
-        """
-        tokenizer = cls.get_tokenizer()
-
-        # 对段落进行切分
-        paragraph_slices = [paragraph[i:i + max_length] for i in range(0, len(paragraph), max_length)]
-
-        # 对每个切分的子段进行分词,并将结果拼接在一起
-        tokens = []
-        for slice in paragraph_slices:
-            tokens += tokenizer.tokenize(slice)
-
-        return tokens
-
-    @classmethod
-    def count_tokens(cls, paragraph, max_length=500):
-        """ 获取段落的token数量
-
-        :param str paragraph: 待分词的段落
-        :param int max_length: 单次处理的最大分词数,为了防止超过GPT2的限制,默认设置为500
-        :return int: token的数量
-
-        >>> Tokenizer.count_tokens('Hello, world!')
-        5
-        """
-        return len(cls.tokenize(paragraph, max_length))
-
-
-def print_statistics(data, indent_level=1):
-    """ 计算字符串长度,并且计算关键的一些token数
-
-    :param data: data应该是一个嵌套结构,表示会话与消息
-    """
-    fmts = ['g', '.0f', '.0f', 'd', 'd']
-    stat_len = ValuesStat([len(str(x)) for x in data])
-
-    indent = '\t' * indent_level
-    print(f'{indent} {stat_len.summary(fmts)}')
-
-
-def check_conversation_lengths(all_texts, n_values=(4, 4),
-                               compute_tokens=False, ids=None):
-    """ 分析会话长度 """
-
-    # 0 预处理
-    for i, texts in enumerate(all_texts):
-        if isinstance(texts, str):
-            all_texts[i] = [texts]
-
-    # 如果没有提供ID,则使用默认的range(n)
-    if ids is None:
-        ids = list(range(len(all_texts)))
-
-    # 处理n_values的重叠
-    if sum(n_values) >= len(all_texts):
-        n_values = [len(all_texts), 0]  # 将所有数据视为最短数据,不再考虑最长数据
-
-    # 1 消息长度统计
-    fmts = [None, '.0f', '.0f', 'd', 'd']
-    lengths = [len(t) for texts in all_texts for t in texts]
-    print(f'1、消息长度统计 {ValuesStat(lengths).summary(fmts)}')
-
-    # 2 每组会话消息数目
-    ct = Counter(len(texts) for texts in all_texts)
-    sorted_ct = {k: v for k, v in sorted(ct.items(), key=lambda x: x[0])}
-    print(f'2、每组消息数目: {sorted_ct}')
-
-    # 3 找出消息总长度最短和最长的会话
-    total_lengths = [(i, sum(len(t) for t in texts)) for i, texts in enumerate(all_texts)]
-    shortest_indices = [item[0] for item in heapq.nsmallest(n_values[0], total_lengths, key=lambda x: x[1])]
-    longest_indices = [item[0] for item in heapq.nlargest(n_values[1], total_lengths, key=lambda x: x[1])]
-    longest_indices = longest_indices[::-1]  # 从小到大排序
-
-    parts = []
-    if shortest_indices:
-        parts.append(', '.join(map(str, [ids[i] for i in shortest_indices])))
-    if longest_indices:
-        parts.append(', '.join(map(str, [ids[i] for i in longest_indices])))
-    print(f'3、最短最长会话的id:', ', ..., '.join(parts))
-
-    # 4 计算token
-    if compute_tokens:
-        # 4.1 代表性样本的tokens数
-        s_texts = [' '.join([x for x in all_texts[i]]) for i in shortest_indices]
-        l_texts = [' '.join([x for x in all_texts[i]]) for i in longest_indices]
-
-        s_lens = [[len(x), Tokenizer.count_tokens(x)] for x in s_texts]
-        l_lens = [[len(x), Tokenizer.count_tokens(x)] for x in l_texts]
-
-        parts = []
-        if s_lens:
-            parts.append(', '.join(map(str, [x[1] for x in s_lens])))
-        if l_lens:
-            parts.append(', '.join(map(str, [x[1] for x in l_lens])))
-        # 仅计算3中代表性样本
-        print(f'4、tokens数量:', ', ..., '.join(parts))
-
-        # 4.2 token的比率规律
-        ratios = []
-        for x in s_lens + l_lens:
-            ratios.append(x[1] / x[0])
-        fmts = [None, '.0%', '.0%', '.0%', '.0%']
-        print(f'token/len比率统计 {ValuesStat(ratios).summary(fmts)}')
-        # 比率越大,代表越接近中文场景,汉字越多,要注意len的控制不要让token某些场合超出长度
-
-
-def set_template(s, *args, **kwargs):
-    """ todo 这个名字会不会太容易冲突了? """
-    return Template(s.strip(), *args, **kwargs)
-
-
-def set_meta_template(s, meta_start='[[', meta_end=']]', **kwargs):
-    """ 支持预先用某些格式渲染后,再返回标准渲染模板 """
-    t = Template(s.strip(), variable_start_string=meta_start,
-                 variable_end_string=meta_end).render(**kwargs)
-    return Template(t)
-
-
-class StyleParser:
-    def __init__(self, text):
-        # 使用正则表达式拆分文本,并获取权重和风格
-        self.styles = []
-        self.weights = []
-        matches = re.findall(r'<风格变换\d+(\s+(\d+))?[^>]*>\s*(.*?)\s*(?=<风格变换|$)', text, re.DOTALL)
-        for match in matches:
-            self.styles.append(match[2])
-            # 提取权重
-            weight = match[1]
-            if weight:
-                self.weights.append(int(weight))
-            else:
-                self.weights.append(100)  # 默认权重
-
-    def random_pick(self):
-        """ 随机选择一个风格,并返回其下标和内容
-
-        :return tuple: (下标, 风格内容)
-
-        >>> sp = StyleParser("...")  # 按照之前的格式传入一个字符串
-        >>> index, style = sp.random_pick()  # 随机选择一个风格
-        """
-        index = random.choices(range(len(self.styles)), weights=self.weights, k=1)[0]
-        return index, self.styles[index]
-
-
-class GptChatJsonl(JsonlDataFile):
-    """ GPT问答批量执行脚本的jsonl生成、读取器 """
-
-    def __init__(self, file=None, num_records=None, *, start_id=None):
-        from datetime import datetime
-
-        super().__init__(file, num_records)
-        if start_id is None:
-            # 230821周一02:02,原本只有日期标记,后面发现这样id很容易出现重复,还是加上小时分钟更不容易引起一些没必要的麻烦
-            today = datetime.now().strftime("%Y%m%d%H%M")
-            self.start_id = int(today + "000000")
-        else:
-            self.start_id = start_id
-
-    def read_jsonl(self, file):
-        """ 从一个文件加载数据
-        """
-        self.records = XlPath(file).read_jsonl()
-        try:
-            self.start_id = self.records[-1]['id']
-        except KeyError:
-            pass
-
-    def split_and_add_prompt(self, text, max_word_length=None, prompt=None):
-        """
-        :param text: 要插入的文本(纯文本,不能是字典格式的 {'content': ...})
-        :param max_word_length: 如果设置了该值,则会对输入内容进行长度控制,拆成多个片段
-            之前测试,全英文长度大概在32500,中文在5000内
-        :param prompt: max_word_length开启时才会生效,每个part要附加的提示规则
-            可以写个字符串,默认每段都是开头加这个提示
-            也可以参考gen_prompt2,一些生成函数写法进行自定义
-        :return:
-        """
-        # 0 如果没有输入max_word_length,就不用特地处理了
-        if max_word_length is None:
-            return [text]
-
-        # 1 工具函数
-        def gen_prompt1(n, i, text):
-            """ 一共n条,当前第i条的,当前内容是text """
-            if n == 1:
-                return text
-            if n > 1:
-                if i == 0:
-                    return f'【注意,由于本次提问过长,这里拆分成{n}个片段分开输入,目前是第{1}个片段,你只需暂时回复"收到"即可】\n' + text
-                elif i < n - 1:
-                    return f'【这是{n}个片段中的第{i + 1}个片段,先回复"收到"即可】\n' + text
-                else:
-                    return f'【这是{n}个片段的最后一个片段,请开始回复内容】\n' + text
-
-        def gen_prompt2(n, i, text):
-            return prompt + text
-
-        if prompt is None:
-            gen_prompt = gen_prompt1
-        elif isinstance(prompt, str):
-            gen_prompt = gen_prompt2
-        else:  # callable
-            gen_prompt = prompt
-
-        # 2 拆分重拼接
-        # 首先要重新调整max_word_length,在确定要拆分为几个片段的情况下,尽量保证这些片段之间的均匀性
-        num = len(text) // max_word_length + 1
-        max_word_length = math.ceil(len(text) / num)
-
-        # 2.1 检查是否有超长单行文本,要提前切分成多行
-        lines = text.rstrip().split('\n')
-        new_lines = []
-        for line in lines:
-            if len(line) < max_word_length:
-                new_lines.append(line)
-            else:  # 单行就已经爆限制长度的比较特别
-                n = max_word_length - 10  # 将这个长文本按照n的长度再拆分成多个片段加入new_lines
-                parts = [line[i:i + n] for i in range(0, len(line), n)]
-                new_lines += parts
-
-        # 2.2 拼接new_lines
-        fragments = []
-        current_fragment = []
-        current_fragment_total_length = 0
-        for line in new_lines:
-            if current_fragment_total_length + len(line) <= max_word_length:
-                current_fragment.append(line)
-                current_fragment_total_length += len(line)
-            else:
-                fragments.append('\n'.join(current_fragment))
-                current_fragment = [line]
-                current_fragment_total_length = len(line)
-        if current_fragment:
-            fragments.append('\n'.join(current_fragment))
-
-        n = len(fragments)
-        fragments = [gen_prompt(n, i, x).strip() for i, x in enumerate(fragments)]
-
-        for i, fragment in enumerate(fragments):
-            fragment = {"content": fragment}
-            fragments[i] = fragment
-        return fragments
-
-    def split_texts(self, texts, max_word_length=None, prompt=None):
-        """ 长对话自动拆分成多轮对话 """
-        new_texts = []
-        for text in texts:
-            pure_text = text['content']
-            new_texts += self.split_and_add_prompt(pure_text, max_word_length=max_word_length, prompt=prompt)
-            if 'file_paths' in text:  # 如果有文件,自动放在最后一轮插入
-                new_texts[-1]['file_paths'] = text['file_paths']
-        return new_texts
-
-    def add_record(self, texts, *, extra=None,
-                   record_id=0, max_word_length=None, prompt=None):
-        """
-        :param texts:
-            str -> list[str],可以只输入一个str,默认一轮对话
-            list[str] -> list[{'content': ..., 'file_paths': [...]}]
-                content: 文本内容
-                file_paths: 注意可以设置本地电脑其他来源,会自动移到该任务的upload_files里
-        :param record_id: 可以自定义这个session的id
-        :param max_word_length: 是否设置一个约束长度,自动切分会话中太长的消息
-            gpt4是8192个token,大概len就是8192/0.6=13653,一般建议如果要设就设10000左右
-        :param prompt: 自动分段后
-            None,自动配置的一套提示
-            '', 不用提示
-        :return:
-        """
-        # 1 变成标准的list + 字典结构,方便后面统一处理
-        if not isinstance(texts, list):
-            texts = [texts]
-
-        for i, text in enumerate(texts):
-            if isinstance(text, str):
-                texts[i] = {'content': text}
-
-        # 2 如果设置了每次最大会话长度,要进行拆分
-        if max_word_length:
-            texts = self.split_texts(texts, max_word_length=max_word_length, prompt=prompt)
-
-        for i, text in enumerate(texts):
-            texts[i]['content'] = text['content'].strip()
-
-        # 3 添加会话conversation
-        self.start_id += 1
-        item = {'id': str(record_id or self.start_id),  # 要转成字符串类型,不然容易出问题
-                'text': texts,
-                'first_text_length': len(texts[0]['content'])}
-        if extra:
-            item['extra'] = extra
-        self.records.append(item)
-        return item
-
-    def fix_file_paths(self, save_dir):
-        """ 修正records中设置的file_paths
-
-        这些路径可能在设置的时候图方便,设置的是非项目目录下的路径
-        这个函数会对这些路径进行修正,为了修正,需要输入一个该jsonl所保存的目录位置
-        """
-        save_dir = XlPath(save_dir)
-        for i, record in tqdm(enumerate(self.records), desc='修复文件路径'):
-            dst_dir = save_dir / 'upload_files' / str(record['id'])
-            for j, text in enumerate(record['text']):
-                for k, fn in enumerate(text.get('file_paths', [])):
-                    src_file = XlPath(fn)
-                    src_file2 = src_file.as_posix()
-                    if src_file2.startswith(f'upload_files/{record["id"]}/'):
-                        continue
-                    dst_file = dst_dir / src_file.name
-                    dst_file2 = dst_file.relpath(save_dir).as_posix()
-                    if src_file.is_file():
-                        if src_file2 != dst_file2:
-                            dst_dir.mkdir(parents=True, exist_ok=True)
-                            src_file.copy(dst_file, if_exists='replace')
-                    else:  # 既然设置了,原文件目录应该在
-                        raise FileNotFoundError(f'{src_file}')
-                    text['file_paths'][k] = dst_file2
-
-    def clean_file_paths(self):
-        """ 清除records中的file_paths
-        一般用于把一些相关文件移到对应会话后,实际提问gpt的时候并不上传文件
-        """
-        for x in self.records:
-            for t in x['text']:
-                if 'file_paths' in t:
-                    del t['file_paths']
-
-    def find_indices_by_qlength(self):
-        """ 返回提问(q,question)内容从短到长的数据下标 """
-        lens = [(i, len(''.join([t['content'] for t in x['text']]))) for i, x in enumerate(self.records)]
-        # 根据长度进行排序,得到的元组的第一个元素为原列表的下标,第二个元素为对应的长度
-        sorted_lens = sorted(lens, key=lambda x: x[1])
-        # 取出排序后的下标
-        sorted_indexs = [i for i, _ in sorted_lens]
-        return sorted_indexs
-
-    def browse_record(self, index=None, paths=None, **kwargs):
-        """ 检查第i次会话的内容
-        """
-        # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
-        if index is None:
-            index = self.find_index(paths, **kwargs)
-            if index is None:
-                raise ValueError('No matching record found')
-        session = self.records[index]
-
-        # 构建HTML内容
-        html_content = "<html><body>"
-
-        # 输出除了text和all_answers以外的所有键值信息
-        html_content += "<h2>会话信息:</h2>"
-        html_content += "<ul>"
-        for key, value in session.items():
-            if key not in ["text", "all_answers"]:
-                html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
-        html_content += "</ul>"
-
-        # 输出text和all_answers的内容
-        texts = self.get_text_texts(session.get("text", []))
-        all_answers = self.get_all_answers_texts(session.get("all_answers", []))
-
-        max_length = max(len(texts), len(all_answers))
-        for idx in range(max_length):
-            html_content += f"<h3>第{idx + 1}次询问:</h3>"
-            if idx < len(texts):
-                html_content += f"<pre>{html.escape(texts[idx])}</pre>"
-            if idx < len(all_answers):
-                html_content += f"<h3>第{idx + 1}次回答:</h3>"
-                html_content += f"<pre>{html.escape(str(all_answers[idx]))}</pre>"
-
-        html_content += "</body></html>"
-        html_file = (XlPath.tempdir() / (str(session.get('id', index)) + '.html'))
-        html_file.write_text(html_content)
-        browser.html(html_file)
-
-        # 返回HTML字符串
-        return html_content
-
-    def get_text_texts(self, text):
-        """ 从text字段获得所有的文本内容
-        因为里面可能是dict
-        """
-        ls = []
-        for t in text:
-            if isinstance(t, str):
-                ls.append(t)
-            else:
-                if "file_path" in t:
-                    ls.append(("filep_path=" + str(t["file_path"]) + "\n\n") + t["content"])
-                else:
-                    ls.append(t["content"])
-        return ls
-
-    def get_all_answers_texts(self, all_answers):
-        ls = []
-        for t in all_answers:
-            if isinstance(t, dict):
-                t = json.dumps(t, ensure_ascii=False, indent=2)
-            ls.append(str(t))
-        return ls
-
-    def check(self):
-        """ 检查会话、消息长度等信息 """
-        # 1 提问的内容
-        all_texts = [self.get_text_texts(session.get('text', []))
-                     for session in self.records]
-        print('【提问的内容】')
-        check_conversation_lengths(all_texts,
-                                   compute_tokens=True,
-                                   ids=[x['id'] for x in self.records])
-
-        # 2 回复的内容
-        all_texts = [self.get_all_answers_texts(session.get('all_answers', []))
-                     for session in self.records]
-        # 过滤空值,并相应地更新ids
-        filtered_texts = [(text, session['id']) for text, session in zip(all_texts, self.records) if text]
-        all_texts, ids = zip(*filtered_texts) if filtered_texts else ([], [])
-        if all_texts:
-            print('【回复的内容】')
-            check_conversation_lengths(all_texts,
-                                       compute_tokens=True,
-                                       ids=ids)
-
-    def filter_records_without_answers(self):
-        """ 过滤掉没有 'all_answers' 字段的sessions """
-
-        # 输出过滤前的sessions数量
-        print(f"过滤前的sessions数量:{len(self.records)}")
-
-        # 使用列表推导式过滤出包含 'all_answers' 字段的sessions
-        self.records = [s for s in self.records
-                        if (''.join(map(str, s.get('all_answers', []))))]
-
-        # 输出过滤后的sessions数量
-        print(f"过滤后的sessions数量:{len(self.records)}")
-
-    @classmethod
-    def _parse_single_record_answer_contents(cls, record):
-        """ 注意本函数不做record备份 """
-        for answer in record.get('all_answers', []):
-            if isinstance(answer, dict) and 'contents' in answer:
-                n = len(answer['contents'])
-                for i in range(n - 1, -1, -1):
-                    message = answer['contents'][i]['message']
-                    if message and 'content' in message and 'error' not in message:
-                        break
-                else:
-                    answer['contents'] = ''
-                    continue
-
-                content = message['content']
-                if 'parts' in content:
-                    content = '\n'.join(content['parts'])
-                else:
-                    content = content['text']
-                answer['contents'] = content
-
-    @classmethod
-    def _parse_single_record_answer_downloads(cls, record):
-        for answer in record.get('all_answers', []):
-            if 'downloads' in answer:
-                for i, link in enumerate(answer['downloads']):
-                    m = re.search(r'filename%3D(.+?)&sig=', link)
-                    if m:
-                        answer['downloads'][i] = unquote(unquote(m.group(1)))
-
-    @classmethod
-    def parse_single_record_answer(cls, record):
-        cls._parse_single_record_answer_contents(record)
-        cls._parse_single_record_answer_downloads(record)
-
-    def parse_answer_contents(self):
-        """ 简化解释器返回结果中,contents的结构信息 """
-        for record in self.records:
-            self._parse_single_record_answer_contents(record)
-
-    def parse_answer_downloads(self):
-        """ 解析,简化下载链接的表达形式 """
-        for record in self.records:
-            self._parse_single_record_answer_downloads(record)
-
-        # 目录里的文件名也同理做精简
-        for f in self.infile.parent.glob_files():
-            if f.name.startswith('OpenAI-download-'):
-                f.rename2(f.parent / re.sub(r'OpenAI-download-\d+-', '', f.name),
-                          if_exists='replace')
-
-    def filter_to_rechat(self, check_func, rechat_path=None):
-        """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
-        这个不是简单的找出得不到all_answers的,而是可以很精细,包含复杂post、verify的情况
-
-        :param check_func: 一个函数,接收一个record,返回True或False
-            True,表示这个record是对的
-            False,表示这个record是错的,要挑选出来
-        :param rechat_path: 把挑选出来的数据放到新路径
-        """
-        if rechat_path is None:
-            rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat/in.jsonl')
-
-        rechat_path = XlPath(rechat_path)
-        td = TwinDirs(self.infile.parent, rechat_path.parent)
-
-        gcj = type(self)()
-        for record in self.records:
-            if not check_func(record):
-                record2 = {}
-                for k in ['id', 'text', 'first_text_length', 'extra']:
-                    record2[k] = record[k]
-                gcj.records.append(record2)
-                for x in record['text']:
-                    if 'file_path' in x:
-                        td.copy_file(td.src_dir / x['file_path'])
-
-        gcj.save(rechat_path)
-        return gcj
-
-    def update_from_rechat(self, check_func, rechat_path=None):
-        """ 从另一份rechat的数据,更新回主master数据
-
-        :param check_func: 原chatted没过,但是rechatted通过的,需要把数据更新过来
-        :param rechat_path: 注意只能传路径,因为可能涉及到文件操作,需要知道目录所在
-            依据这个文件里的record记录更新回self
-        """
-        if rechat_path is None:
-            rechat_path = XlPath(self.infile.parent.as_posix() + '_rechat') / 'out.jsonl'
-
-        rechat_path = XlPath(rechat_path)
-        td = TwinDirs(rechat_path.parent, self.infile.parent)
-
-        id2index = {x['id']: i for i, x in enumerate(self.records)}
-
-        gcj = type(self)(rechat_path)
-        gcj.parse_answer_contents()
-        gcj.parse_answer_downloads()
-
-        # 需要处理下下载链接名称
-        self.parse_answer_downloads()
-        gcj.parse_answer_downloads()
-
-        for y in gcj.records:
-            index = id2index[y['id']]
-            x = self.records[index]
-            if not check_func(x) and check_func(y):
-                # 先把x相关的数据删掉
-                if 'all_answers' in x:
-                    for answer in x['all_answers']:
-                        for fname in answer.get('downloads', []):
-                            (XlPath(self.infile.parent) / fname).delete()
-                # 再把y拷贝过来
-                for answer in y['all_answers']:
-                    for fname in answer.get('downloads', []):
-                        td.copy_file(td.src_dir / fname)
-                self.records[index] = y
-        return gcj
-
-
-GptQuestionJsonl = GptChatJsonl  # 名称向下兼容
-
-
-def __2_数据后处理():
-    """ 一些常用的文本、后处理功能也放到这里 """
-
-
-def try_eval_json(resp_json):
-    try:
-        resp_json = ast.literal_eval(resp_json)
-        if isinstance(resp_json, dict):
-            resp_json = resp_json[resp_json.keys()[0]]
-    except:
-        pass
-    return resp_json
-
-
-def try_load_json(resp_json):
-    if isinstance(resp_json, str):
-        try:
-            resp_json = json.loads(resp_json)
-            if isinstance(resp_json, dict):
-                resp_json = resp_json[resp_json.keys()[0]]
-        except:
-            pass
-    return resp_json
-
-
-def try_parse_json(resp_json):
-    if isinstance(resp_json, dict):
-        try:
-            resp_json = '\n'.join(resp_json['contents'][-1]['message']['content'].get('parts', []))
-        except TypeError:
-            return ''
-
-    resp_json = try_eval_json(resp_json)
-    if isinstance(resp_json, str):
-        return try_load_json(resp_json)
-    return resp_json
-
-
-def extract_code_blocks_from_md(markdown_text, *, sort_by_length=False):
-    """ 可以输入str,也可以输入list[str]
-
-    :param sort_by_length: 按代码长度从短到长排序
-        常用在比较确信有效代码段应该只有一段,但是有些短小的片段有干扰
-        此时可以排序后,选取最长的一个代码片段作为正确代码
-    """
-    if isinstance(markdown_text, str):
-        markdown_text = [markdown_text]
-
-    matches = []
-    pattern = re.compile(r'^```[^\n]*\n(.+?)\n^```', re.MULTILINE | re.DOTALL)
-    for text in markdown_text:
-        matches += pattern.findall(text)
-
-    if sort_by_length:
-        matches = sorted(matches, key=len)
-
-    return matches
-
-
-def extract_airscript_code_from_answers(all_answers):
-    """ 从多轮回答的最后一次回答中提取求解代码 """
-    contents = all_answers[-1]['contents']
-    text = contents[-1]['text']
-    code_blocks = extract_code_blocks_from_md(text, sort_by_length=True)
-
-    if code_blocks:
-        return code_blocks[-1]
-    else:
-        return ''
-
-
-def merge_answers_contents(answers):
-    """ 对一组answers结果中,相同type的contents进行合并 """
-    for answer in answers:
-        contents = []
-        for content in answer['contents']:
-            if len(contents) == 0:
-                contents.append(content)
-            else:
-                if contents[-1]['type'] == content['type']:
-                    contents[-1]['text'] += '\n' + content['text']
-                else:
-                    contents.append(content)
-        answer['contents'] = contents
-
-
-def refine_content_title(content, tag, dst_title=None):
-    """ 将内容中的标题描述形式标准化
-
-    :param tag: 原标题相关字符
-    :param content: 文本内容
-    :param dst_title: 目标标题格式
-    :return: 处理后的字符串
-    """
-    if dst_title is None:
-        dst_title = f'<{tag}>'
-    content_lines = content.splitlines()
-    chars_str = re.compile(tag.replace(':', '[:的]?'))
-    chinese_chars = re.compile(r'[\u4e00-\u9fa5]')
-
-    res = []
-    for line in content_lines:
-        # 使用正则表达式查找匹配的部分
-        new_line = chars_str.sub('', line)
-        if new_line != line and not chinese_chars.search(new_line):
-            res.append(dst_title)
-        else:
-            # 如果不满足条件,不进行替换
-            res.append(line)
-    return '\n'.join(res)
-
-
-def refine_block_name(record, block_names, preproc=None):
-    """ 优化模块的标题名,方便后续结构化提取数据
-
-    感觉这个系列解析是比较通用的,就放在标准库中
-    """
-    # if preproc is None:
-    #     def preproc(x):
-    #         return x
-
-    for answer in record['all_answers']:
-        for content in answer['contents']:
-            if content['type'] == 'text':
-                text = old_text = content['text']
-                if preproc is not None:
-                    text = preproc(text)
-
-                for block_name in block_names:
-                    text = refine_content_title(text, block_name)
-                text = refine_content_title(text, '---', '')
-                # 一般不要直接修改原数据,但post里会有备份,所以这里verify可以直接修改了
-                # if 'answer' not in curr_record['extra']:
-                #     curr_record['extra']['answer'] = []
-                # curr_record['extra']['answer'].append(text)
-                content['text'] = text
-                # 可以借助bc调试
-                # bcompare(old_text, text)
-
-
-def extract_block_content(record, block_name):
-    """ 从record的all_answers中,从后往前检索 <block_name> 的内容,
-    返回第一个匹配结果,如果找不到则返回空字符串
-    """
-    for answer in record['all_answers'][::-1]:
-        for content in answer['contents'][::-1]:
-            if content['type'] == 'text':
-                matches = list(re.finditer(rf'^<{block_name}>\n((.|\n)+?)(?=^<.+?>\n)',
-                                           content['text'] + '\n<test>\n',  # 末尾补一个<test>,方便对齐
-                                           flags=re.MULTILINE))
-                if matches:
-                    s = matches[-1].group(1).strip()
-                    blocks = extract_code_blocks_from_md(s, sort_by_length=True)
-                    if blocks:
-                        return blocks[-1]
-                    if s:
-                        return s
-    return ''  # 提取不到
-
-
-def __3_生成最后训练用的数据():
-    pass
-
-
-def texts2train_record(texts):
-    """ user和assistant的轮询对话,转为训练集格式 """
-    messages = []
-    for i, text in enumerate(texts):
-        role = 'assistant' if i % 2 else 'user'
-        messages.append({'role': role, 'content': text})
-    return {'messages': messages}
-
-
-class GptTrainJsonl(JsonlDataFile):
-    """
-    record: dict
-        messages: list
-            dict: role='user', content=...
-            dict: role='assistant', content=...
-    """
-
-    def analyze_text_length(self):
-        # 1 先将数据统计到df
-        ls = []
-        columns = ['role', 'content']
-        for x in self.records:
-            for t in x['messages']:
-                ls.append([t['role'], t['content']])
-        df = pd.DataFrame.from_records(ls, columns=columns)
-
-        # 2 再从df筛选出不同的统计数据
-        print('【user和assistant】')
-        print_statistics(df['content'])
-        print('【user】')
-        print_statistics(df[df['role'] == 'user']['content'])
-        print('【assistant】')
-        print_statistics(df[df['role'] == 'assistant']['content'])
-
-    def check(self):
-        """ 检查会话、消息长度等信息 """
-        # 1. 提取'user'角色的content
-        user_texts = [[message['content']
-                       for message in record['messages']
-                       if message['role'] == 'user']
-                      for record in self.records]
-        if not user_texts:
-            print('空数据')
-            return
-
-        print('【User的内容】')
-        check_conversation_lengths(user_texts, compute_tokens=True,
-                                   # 因为一般是使用JLineViewer进行查看,跟那个软件对称使用1开始编号
-                                   ids=list(range(1, len(user_texts) + 1)))
-
-        # 2. 提取'assistant'角色的content
-        assistant_texts = [[message['content']
-                            for message in record['messages']
-                            if message['role'] == 'assistant']
-                           for record in self.records]
-        print('【Assistant的内容】')
-        check_conversation_lengths(assistant_texts, compute_tokens=True,
-                                   ids=list(range(1, len(assistant_texts) + 1)))
-
-        # 3. 将整个record视为一个完整的会话
-        full_conversations = [' '.join([message['content'] for message in record['messages']])
-                              for record in self.records]
-        print('【完整的会话】')
-        check_conversation_lengths(full_conversations, compute_tokens=True,
-                                   ids=list(range(1, len(full_conversations) + 1)))
-
-    def browse_record(self, index=None, paths=None, **kwargs):
-        """ 显示第i次会话的内容 """
-        # 如果未提供索引,则尝试使用查询参数找到第一个匹配的记录
-        if index is None:
-            index = self.find_index(paths, **kwargs)
-            if index is None:
-                raise ValueError('No matching record found')
-        session = self.records[index]
-
-        # 构建HTML内容
-        html_content = "<html><body>"
-
-        # 输出除了messages以外的所有键值信息
-        html_content += "<h2>会话信息:</h2>"
-        html_content += "<ul>"
-        for key, value in session.items():
-            if key != "messages":
-                html_content += f"<li>{html.escape(key)}: {html.escape(str(value))}</li>"
-        html_content += "</ul>"
-
-        # 输出messages的内容
-        messages = session.get("messages", [])
-
-        for idx, message in enumerate(messages):
-            role = message.get('role', 'unknown')
-            content = message.get('content', '')
-            html_content += f"<h3>第{(idx // 2) + 1}次{role}的发言:</h3>"
-            html_content += f"<pre>{html.escape(content)}</pre>"
-
-        html_content += "</body></html>"
-        html_file = (XlPath.tempdir() / (f'session_{index}.html'))  # 创建临时文件名,防止覆盖现有文件
-        html_file.write_text(html_content)
-        browser.html(html_file)  # 在浏览器中打开HTML文件
-
-        # 或者返回HTML字符串
-        return html_content
-
-    def add_record(self, texts):
-        messages = []
-        for i, text in enumerate(texts):
-            role = 'assistant' if i % 2 else 'user'
-            messages.append({'role': role, 'content': text})
-        self.records.append({'messages': messages})
-
-    def add_from_texts(self, texts):
-        record = texts2train_record(texts)
-        self.records.append(record)
-
-
-def __4_综合集成类():
-    pass
-
-
-class GptChatDir:
-    """ 一个目录,包含了一个任务的所有数据,包括in、out、post等文件 """
-
-    def __init__(self, root=None, lines_per_file=10000):
-        if root is None:
-            root = self.__class__.__name__.lower()
-
-        self.root = root = XlPath(root)
-        self.lines_per_file = lines_per_file
-
-        self.chat_file = root / 'in.jsonl'
-        self.chatted_file = root / 'out.jsonl'
-        self.post_file = root / 'post.jsonl'
-        self.verify_file = root / 'verify.jsonl'
-        self.train_file = root / 'train.jsonl'
-
-        # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
-        self.update_dir()
-
-        self.upload_files_dir = root / 'upload_files'
-        self.download_files_dir = root / 'download_files'
-
-        # todo 把 1chat 改名 in,2chatted 改名 out
-        # for f in self.root.glob_files('*1chat*.jsonl'):
-        #     f.rename2(f.parent / 'in.jsonl')
-
-        # for dir_path in [self.root, self.upload_files_dir, self.download_files_dir]:
-        for dir_path in [self.root]:
-            if not dir_path.is_dir():
-                dir_path.mkdir(parents=True, exist_ok=True)
-
-        # 这个类经常要并发处理,不能把一个不能序列化的类放到这里~
-        # self.logger = OutputLogger(log_file=self.root / 'log.txt')
-
-    def update_dir(self):
-        """ 目录结构有些更新后,一些成员变量要跟着改变 """
-        # 如果有目录文件,会优先以目录为准。如果没有,则会从单文件拆分创建。
-        self.chat_dir = JsonlDataDir.init_from_file(self.chat_file, self.lines_per_file)
-        self.chatted_dir = JsonlDataDir.init_from_file(self.chatted_file, self.lines_per_file)
-        self.post_dir = JsonlDataDir.init_from_file(self.post_file, self.lines_per_file)
-        self.verify_dir = JsonlDataDir.init_from_file(self.verify_file, self.lines_per_file)
-        self.train_dir = JsonlDataDir.init_from_file(self.train_file, self.lines_per_file)
-
-    def summary_records(self):
-        """ 一些统计信息 """
-        # 1 chat信息
-        gcd1 = self.chatted_dir or self.chat_dir
-        if not gcd1:
-            print('请确认是否有生成初始的chat数据')
-            return
-
-        print(f'【{self.root.name}】')
-        texts = [len(x['text']) for x in gcd1.yield_record()]
-        n, m = len(texts), sum(texts)
-        print(f'1、chat:{n}条会话*{m / n:.2g}条消息')
-        gcj1 = GptChatJsonl(gcd1.files[0])  # 统计一个文件就够了,不然太多了
-        gcj1.check_records()
-        print()
-
-        # 2 chatted信息
-        filter_records = [x for x in gcd1.yield_record() if 'all_answers' in x]
-        if filter_records:
-            print(f'2、chatted:已获得{len(filter_records)}条会话')
-        else:
-            print('2、chatted:暂未获得生成数据')
-
-        # 3 post信息
-        if self.post_dir:
-            print(f'3、post:{self.post_dir.count_records()}条会话')
-
-        # 4 verify(这一步有时候会集成到post中)
-        if self.verify_dir:
-            print(f'4、verify:{self.verify_dir.count_records()}条会话')
-
-        # 5 train 生成的训练数据
-        # print('5、train:')
-        # gtj = GptTrainJsonl(self.train_file)
-        # gtj.analyze_text_length()
-
-    def summary_downloads(self):
-        """ 统计下载的文件情况 """
-        print('【每个目录文件数量】')
-        files_each_dir = []
-        for d in self.download_files_dir.glob_dirs():
-            files_each_dir.append(len(list(d.rglob_files())))
-        print(ValuesStat(files_each_dir).summary())
-        print(Counter(files_each_dir))
-
-        print('【每个文件大小】')
-        filesizes_each_dir = []
-        for d in self.download_files_dir.glob_dirs():
-            for f in d.rglob_files():
-                filesizes_each_dir.append(f.size())
-        print(ValuesStat(filesizes_each_dir).summary())
-
-    def create_chat(self):
-        """ 生成chat数据,具体内容方式跟业务有关 """
-        raise NotImplementedError
-
-    def browse_chatted_record(self, index=None, paths=None, **kwargs):
-        """ 显示第i次会话的内容 """
-        f = self.chatted_file if self.chatted_file.is_file() else self.chat_file
-        return GptChatJsonl(f, 100).browse_record(index, paths, **kwargs)
-
-    def chatted2post_record(self, chatted_record):
-        """ 后处理,解析
-
-        一般会保留基本的all_answers结果,供检查上游一些基本情况
-        然后把一些结构化结果存储到extra字段
-
-        :return: 会返回新的dict结构的一个post_record,如果解析失败,会返回None
-        """
-        # 0 基本情况判断
-        if 'all_answers' not in chatted_record:
-            return
-
-        post_record = copy.deepcopy(chatted_record)
-
-        # 1 删掉一些没卵用的字段
-        for name in ['all_questions', 'first_text_length', 'second_text_length']:
-            if name in post_record:
-                del post_record[name]
-
-        # 2 解析all_answers:这个结构太复杂,进行内容整理精简
-        # 2.1 contents:这个结构太复杂,搁这俄罗斯套娃呢~ 稍微精简下更方便后处理
-        for k, answer in enumerate(post_record['all_answers']):
-            if isinstance(answer, dict) and 'contents' in answer:
-                new_contents = []
-                for i, x in enumerate(answer['contents']):
-                    if not x['message']:
-                        # Error in message stream
-                        # print(f'{post_record["id"]} answer[{k}] contents[{i}] message为空')
-                        continue
-
-                    content = x['message']['content']
-                    tp = content['content_type']
-                    new_content = {'type': content['content_type']}
-                    if tp == 'text':
-                        new_content['text'] = '\n'.join(content['parts'])
-                    elif tp == 'code':
-                        new_content['text'] = content['text']
-                    elif tp == 'execution_output':
-                        new_content['text'] = content['text']
-                    elif tp == 'system_error':
-                        continue
-                    else:
-                        print(f'{post_record["id"]} answer[{k}] contents[{i}] content_type={tp} 未见类型')
-                        continue
-
-                    new_contents.append(new_content)
-                answer['contents'] = new_contents
-            elif isinstance(answer, str):  # 普通模式也转成解释器风格,方便统一处理
-                post_record['all_answers'][k] = {'contents': [{'type': 'text',
-                                                               'text': answer}]}
-
-        # 2.2 downloads:下载链接精简下,并把关联的文件也顺带整理一下
-        for answer in post_record['all_answers']:
-            if 'downloads' not in answer:
-                continue
-            for i, link in enumerate(answer['downloads']):
-                m = re.search(r'filename%3D(.+?)&sig=', link)
-                if m:
-                    answer['downloads'][i] = str(post_record['id']) + '/' + unquote(unquote(m.group(1)))
-                # 对应的文件不存在的不要,有数据超过50M的也不要
-                file = self.download_files_dir / link
-                if not file.exists() and file.size() > 50 * 1024 * 1024:
-                    return
-
-            # 理论上下载的文件不应该有重复,虽然不知道为什么会拿到重复,但去掉重复比较好
-            answer['downloads'] = list(OrderedDict.fromkeys(answer['downloads']))
-
-        # 2.3 删掉answer里其他没用的字段
-        for answer in post_record['all_answers']:
-            for name in ['created', 'message_id', 'conversation_id', 'end_turn']:
-                if name in answer:
-                    del answer[name]
-
-        # 返回处理结果
-        return post_record
-
-    @staticmethod
-    def post2verify_record(post_record):
-        """ 这个一般是要具体任务定制的,没有通用操作方式
-
-        注意,如果要使用create_verify的多进程功能,这个函数必须是静态的,并且里面也不能使用其他"类静态方法"
-        否则写成类方法或对象方法都可以
-
-        """
-        raise NotImplementedError
-
-    def verify2train_record(self, verify_record):
-        """ 这个一般是要具体任务定制的,没有通用操作方式 """
-        raise NotImplementedError
-
-    def organize_downloaded_files(self):
-        # 把下载的文件整理的更清晰些
-        for f in tqdm(list(self.root.glob_files('OpenAI-download-*')),
-                      desc='整理下载的文件'):
-            new_name = re.sub(r'OpenAI-download-\d+-', '', f.name)
-            new_name = new_name.replace('-', '/', 1)
-            try:
-                (self.download_files_dir / new_name).parent.mkdir(exist_ok=True)
-                f.rename2(self.download_files_dir / new_name, if_exists='replace')
-            except FileExistsError as e:
-                # 有的文件会移动不了
-                print(e)
-
-        # 会剩一些特殊的处理不了的文件,可以看一眼后手动删掉
-        # 这些相关的records,默认的chatted2post_record会把这些记录过滤掉
-
-    def create_post(self, **kwargs):
-        """ 建议初步跑的时候,先串行debug,等比较稳定后,再开并发跑
-        """
-        if 'dst_dir' not in kwargs:
-            kwargs['dst_dir'] = self.post_dir.root
-        self.chatted_dir.process_each_record(self.chatted2post_record, **kwargs)
-        self.post_dir.update_subfiles()
-        num1, num2 = self.chatted_dir.count_records(), self.post_dir.count_records()
-        print(f'chatted有{num1}条,转换post有{num2}条,转换率{num2 / num1:.2%}')
-
-    def create_verify(self, **kwargs):
-        """ 有时候create_verify是有cpu密集运算场景的,可以开多进程
-        """
-        if 'dst_dir' not in kwargs:
-            kwargs['dst_dir'] = self.verify_dir.root
-        self.post_dir.process_each_record(self.post2verify_record, **kwargs)
-        self.verify_dir.update_subfiles()
-        num1, num2 = self.post_dir.count_records(), self.verify_dir.count_records()
-        num1 = num1 or -1
-        print(f'post有{num1}条,转换verify有{num2}条,转换率{num2 / num1:.2%}')
-
-    def refine_verify(self, print_mode=1, **kwargs):
-        """ 重复检查verify数据
-
-        这个函数可以重复执行,但前提是self.post2verify_record里的设计有增量规则部分
-        """
-        self.verify_dir.process_each_record(self.post2verify_record, print_mode=print_mode,
-                                            inplace=True, desc='refine_verify', **kwargs)
-
-    @classmethod
-    def texts2train_record(cls, texts):
-        """ user和assistant的轮询对话,转为训练集格式 """
-        messages = []
-        for i, text in enumerate(texts):
-            role = 'assistant' if i % 2 else 'user'
-            messages.append({'role': role, 'content': text})
-        return {'messages': messages}
-
-    def create_train(self, **kwargs):
-        if 'dst_dir' not in kwargs:
-            kwargs['dst_dir'] = self.train_dir.root
-        self.post_dir.process_each_record(self.verify2train_record, **kwargs)
-        self.train_dir.update_subfiles()
-
-    def check_chatted_record(self, chatted_record):
-        """ 检查chatted数据的有效性 """
-        x = chatted_record
-        x = self.chatted2post_record(x)
-        # x = self.post2verify_record(x)
-        # 针对verify可以再进一步定制规则
-        return bool(x)
-
-    def create_rechat(self, rechat_path):
-        """ 筛选失败的数据到一个新的目录,常用于对chatted数据筛选出未成功的样例,上池子重跑
-
-        :param rechat_path: 把挑选出来的数据放到新路径
-        """
-        gcd = GptChatDir(rechat_path)
-        f = open(gcd.chat_file, 'w', encoding='utf-8')
-
-        for record in tqdm(self.chatted_dir.yield_record(), '检查待重新生成的问题'):
-            if not self.check_chatted_record(record):
-                continue
-            # 否则把这个条目放到rechat,准备拿去重新提问
-            if 'error' in record:
-                del record['error']
-            f.write(json.dumps(record, ensure_ascii=False) + '\n')
-            # 如果有文件,也要对应移动
-            src_dir = self.upload_files_dir / str(record['id'])
-            if src_dir.is_dir():
-                src_dir.copy(gcd.upload_files_dir / src_dir.name, if_exists='skip')
-
-        f.close()
-        return gcd
-
-    def update_chatted(self, rechat_path):
-        """ 从另一个rechat数据,更新数据条目过来
-
-        self依然叫src,rechat叫dst,虽然其实数据是从rechat更新流向self
-
-        注意:这个函数还没有比较严格地进行调试~
-        """
-        # 1 读取有效记录
-        gcd = GptChatDir(rechat_path)
-        gcd.organize_downloaded_files()
-        # 请确保内存充足哦,这个函数会从rechat的chatted读取所有通过的记录保存起来
-        dst_records = {}
-        for record in gcd.chatted_dir.yield_record():
-            # 找到有all_answers的挑出来
-            post_record = self.chatted2post_record(record)
-            if post_record:
-                dst_records[record['id']] = record
-
-        # 2 更新记录
-        def update_each_record(x):
-            if x['id'] in dst_records:
-                # 除了返回record,还得拷贝目录数据呢
-                # 上传的目录一般没变,但最好重置下
-                src_dir = self.upload_files_dir / x['id']
-                dst_dir = gcd.upload_files_dir / x['id']
-                dst_dir.copy(src_dir, if_exists='replace')
-                # 下载的目录
-                src_dir = self.download_files_dir / x['id']
-                dst_dir = gcd.download_files_dir / x['id']
-                dst_dir.copy(src_dir, if_exists='replace')
-                return dst_records[x['id']]
-            else:
-                return x
-
-        self.chatted_dir.update_each_record(update_each_record)
-
-
-def __5_bdchat():
-    """ 百度相关api """
-
-
-class BaiduChatbot:
-    def __init__(self, api_key, secret_key, file_path=None):
-        self.API_KEY = api_key
-        self.SECRET_KEY = secret_key
-        self.ACCESS_TOKEN = self._get_access_token()
-        self.base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token="
-        self.file_path = file_path  # 文件路径为可选参数
-
-    def _get_access_token(self):
-        """
-        使用 AK,SK 生成鉴权签名(Access Token)
-        :return: access_token,或是None(如果错误)
-        """
-        url = "https://aip.baidubce.com/oauth/2.0/token"
-        params = {
-            "grant_type": "client_credentials",
-            "client_id": self.API_KEY,
-            "client_secret": self.SECRET_KEY
-        }
-        return str(requests.post(url, params=params).json().get("access_token"))
-
-    def chat(self, user_message):
-        """ 向Baidu API发送用户消息并返回API的回复
-        注意user_message的token不要超过3k
-        """
-        url = self.base_url + self.ACCESS_TOKEN
-        payload = json.dumps({
-            "messages": [{"role": "user", "content": user_message}]
-        })
-        headers = {'Content-Type': 'application/json'}
-        response = requests.post(url, headers=headers, data=payload)
-        response_json = response.json()
-        response_json['user_message'] = user_message
-        response_json['timestamp'] = datetime.datetime.now().isoformat()
-
-        # 如果指定了文件路径,自动保存记录
-        if self.file_path:
-            self._save_to_file(response_json)
-
-        return response_json.get('result', '')
-
-    def _save_to_file(self, response):
-        with open(self.file_path, 'a', encoding='utf-8') as file:
-            file.write(json.dumps(response, ensure_ascii=False) + '\n')