nlpertools 1.0.6.dev0__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nlpertools/__init__.py +3 -4
- nlpertools/cli.py +143 -0
- nlpertools/data_client.py +56 -17
- nlpertools/dataprocess.py +28 -12
- nlpertools/draw/__init__.py +0 -0
- nlpertools/draw/draw.py +81 -0
- nlpertools/draw/math_func.py +33 -0
- nlpertools/get_2fa.py +0 -0
- nlpertools/io/dir.py +35 -3
- nlpertools/io/file.py +17 -11
- nlpertools/ml.py +74 -24
- nlpertools/other.py +152 -24
- {nlpertools-1.0.6.dev0.dist-info → nlpertools-1.0.9.dist-info}/METADATA +33 -10
- {nlpertools-1.0.6.dev0.dist-info → nlpertools-1.0.9.dist-info}/RECORD +18 -12
- {nlpertools-1.0.6.dev0.dist-info → nlpertools-1.0.9.dist-info}/WHEEL +1 -1
- nlpertools-1.0.9.dist-info/entry_points.txt +2 -0
- {nlpertools-1.0.6.dev0.dist-info → nlpertools-1.0.9.dist-info}/LICENSE +0 -0
- {nlpertools-1.0.6.dev0.dist-info → nlpertools-1.0.9.dist-info}/top_level.txt +0 -0
    
        nlpertools/__init__.py
    CHANGED
    
    | @@ -3,6 +3,7 @@ | |
| 3 3 | 
             
            # @Author  : youshu.Ji
         | 
| 4 4 | 
             
            from .algo.kmp import *
         | 
| 5 5 | 
             
            from .data_structure.base_structure import *
         | 
| 6 | 
            +
            from .draw import *
         | 
| 6 7 | 
             
            from .dataprocess import *
         | 
| 7 8 | 
             
            from .io.dir import *
         | 
| 8 9 | 
             
            from .io.file import *
         | 
| @@ -15,10 +16,8 @@ from .reminder import * | |
| 15 16 | 
             
            from .utils_for_nlpertools import *
         | 
| 16 17 | 
             
            from .wrapper import *
         | 
| 17 18 | 
             
            from .monitor import *
         | 
| 19 | 
            +
            from .cli import *
         | 
| 18 20 |  | 
| 19 | 
            -
            import os
         | 
| 20 21 |  | 
| 21 22 |  | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
            __version__ = '1.0.5'
         | 
| 23 | 
            +
            __version__ = '1.0.9'
         | 
    
        nlpertools/cli.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import uuid
         | 
| 4 | 
            +
            import sys
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            """
         | 
| 7 | 
            +
            如何Debug cli.py
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def git_push(max_retries=None):
                """
                Run ``git push --set-upstream origin main`` until it succeeds.

                Pushing to GitHub from mainland China fails often, so the command is
                simply retried after every failure.

                :param max_retries: number of retries allowed after the first attempt;
                    ``None`` (the default) retries without limit.
                :return: True once a push succeeds, False if the retries run out.
                :raises FileNotFoundError: if the ``git`` executable cannot be found.
                """
                import subprocess

                num = 0
                while True:
                    print("retry num: {}".format(num))
                    # Success is decided by git's exit status. git's "fatal: ..." text is
                    # written to the terminal and is never returned to Python, so it
                    # cannot be inspected here.
                    returncode = subprocess.run(
                        ["git", "push", "--set-upstream", "origin", "main"]
                    ).returncode
                    print(str(returncode))
                    if returncode == 0:
                        print("scucess")
                        return True
                    if max_retries is not None and num >= max_retries:
                        return False
                    num += 1
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def git_pull(max_retries=None):
                """
                Run ``git pull`` until it succeeds.

                Pulling from GitHub from mainland China fails often, so the command is
                simply retried after every failure.

                :param max_retries: number of retries allowed after the first attempt;
                    ``None`` (the default) retries without limit.
                :return: True once a pull succeeds, False if the retries run out.
                :raises FileNotFoundError: if the ``git`` executable cannot be found.
                """
                import subprocess

                num = 0
                while True:
                    print("retry num: {}".format(num))
                    # Success is decided by git's exit status. git's "fatal: ..." and
                    # "error: ..." text is written to the terminal and is never returned
                    # to Python, so it cannot be inspected here.
                    returncode = subprocess.run(["git", "pull"]).returncode
                    print(str(returncode))
                    if returncode == 0:
                        print("scucess")
                        return True
                    if max_retries is not None and num >= max_retries:
                        return False
                    num += 1
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def get_mac_address():
                """
                Print and return the hardware address reported by ``uuid.getnode()``.

                The address is formatted as six colon-separated, lowercase hex pairs.
                A notice that the value may be inaccurate is printed first.

                :return: the formatted address string.
                """
                node_hex = uuid.UUID(int=uuid.getnode()).hex[-12:]
                pairs = []
                for start in range(0, 12, 2):
                    pairs.append(node_hex[start:start + 2])
                mac_address = ":".join(pairs)
                print("mac address 不一定准确")
                print(mac_address)
                return mac_address
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            def get_2af_value(key):
                """
                Print and return the current TOTP code for a 2FA secret.

                The original note on this function says the key should be 7 characters
                long; that is not checked here.

                :param key: secret handed to ``pyotp.TOTP``.
                :return: the value of ``pyotp.TOTP(key).now()``.
                """
                # Imported here so pyotp is only required when this command is used.
                import pyotp

                # NOTE(review): this echoes the 2FA secret to stdout. Kept so the
                # command's output is unchanged; consider removing it.
                print(key)
                code = pyotp.TOTP(key).now()
                print(code)
                return code
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def start_gpu_usage_notify_server(host="0.0.0.0", port=5000):
                """
                Start a Flask server whose ``GET /notify`` endpoint reports idle GPUs.

                Each request runs ``nvidia-smi --query-gpu=memory.used --format=csv`` and
                answers with the number of output rows that start with "0".

                :param host: interface passed to ``app.run``.
                :param port: port passed to ``app.run``.
                """
                from flask import Flask

                app = Flask(__name__)

                def count_idle_gpus():
                    # The ``with`` block closes the pipe; the first row is the CSV header.
                    with os.popen("nvidia-smi --query-gpu=memory.used --format=csv") as pipe:
                        rows = pipe.read().split("\n")[1:]
                    # A row starting with "0" is counted as an idle card
                    # (presumably "0 MiB" -- confirm against real nvidia-smi output).
                    return sum(1 for row in rows if row.startswith("0"))

                @app.route("/notify", methods=["GET"])
                def notify():
                    res = count_idle_gpus()
                    print(res)
                    return str(res), 200

                app.run(host=host, port=port)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def start_gpu_usage_notify_client(server_url='http://127.0.0.1:5000/notify', interval=1, timeout=5):
                """
                Poll the GPU notify server forever and raise a desktop notification
                whenever it reports at least one card.

                :param server_url: URL of the server's ``/notify`` endpoint.
                :param interval: seconds to sleep between two polls.
                :param timeout: seconds to wait for each HTTP request, so that a stalled
                    connection cannot block the polling loop forever.
                """
                import time

                import requests
                from plyer import notification

                def notify(text):
                    # Send the desktop notification through plyer.
                    notification.notify(
                        title='远程通知',
                        message=text,
                        timeout=10  # show the notification for 10 seconds
                    )

                while True:
                    try:
                        response = requests.get(server_url, timeout=timeout)
                        if response.status_code == 200:
                            num = int(response.text)
                            if num > 0:
                                notify(f"服务器有{num}张卡")
                            print(f"服务器有{num}张卡")
                        else:
                            print("服务器没有新通知")
                    except Exception as e:
                        # Deliberately broad: this is a best-effort loop that must keep
                        # polling whatever goes wrong in a single round.
                        print(f"与服务器连接失败: {e}")

                    time.sleep(interval)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            def main(argv=None):
                """
                Command-line entry point: parse the flags and run the first one set.

                Flags are checked in a fixed order (gitpush, gitpull, mac_address,
                monitor_gpu_cli, monitor_gpu_ser, get_2fa), so only one operation runs
                per invocation.

                :param argv: argument list to parse; ``None`` (the default) makes
                    argparse read ``sys.argv[1:]``.
                """
                parser = argparse.ArgumentParser(description="CLI tool for git operations and getting MAC address.")
                parser.add_argument('--gitpush', action='store_true', help='Perform git push operation.')
                parser.add_argument('--gitpull', action='store_true', help='Perform git pull operation.')
                parser.add_argument('--mac_address', action='store_true', help='Get the MAC address.')

                parser.add_argument('--get_2fa', action='store_true', help='Get the 2fa value.')
                parser.add_argument('--get_2fa_key', type=str, help='Secret key used with --get_2fa.')
                parser.add_argument('--monitor_gpu_cli', action='store_true', help='Start the GPU usage notify client.')
                parser.add_argument('--monitor_gpu_ser', action='store_true', help='Start the GPU usage notify server.')

                args = parser.parse_args(argv)

                if args.gitpush:
                    git_push()
                elif args.gitpull:
                    git_pull()
                elif args.mac_address:
                    get_mac_address()
                elif args.monitor_gpu_cli:
                    start_gpu_usage_notify_client()
                elif args.monitor_gpu_ser:
                    start_gpu_usage_notify_server()
                elif args.get_2fa:
                    if args.get_2fa_key:
                        get_2af_value(args.get_2fa_key)
                    else:
                        print("Please provide a key as an argument.")
                else:
                    print("No operation specified.")
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            if __name__ == '__main__':
         | 
| 143 | 
            +
                main()
         | 
    
        nlpertools/data_client.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
| 1 | 
            +
            #encoding=utf-8
         | 
| 1 2 | 
             
            # !/usr/bin/python3.8
         | 
| 2 3 | 
             
            # -*- coding: utf-8 -*-
         | 
| 3 4 | 
             
            # @Author  : youshu.Ji
         | 
| @@ -5,9 +6,11 @@ import datetime | |
| 5 6 | 
             
            import json
         | 
| 6 7 | 
             
            import logging
         | 
| 7 8 |  | 
| 8 | 
            -
            from . import DB_CONFIG_FILE
         | 
| 9 9 | 
             
            from .io.file import read_yaml
         | 
| 10 10 | 
             
            from .utils.package import *
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            DB_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "default_db_config.yml")
         | 
| 11 14 |  | 
| 12 15 | 
             
            # import aioredis
         | 
| 13 16 | 
             
            # import happybase
         | 
| @@ -28,21 +31,24 @@ class Neo4jOps(object): | |
| 28 31 | 
             
                NEO4J_TIMEOUT = 0.3
         | 
| 29 32 | 
             
                pass
         | 
| 30 33 |  | 
| 34 | 
            +
             | 
| 31 35 | 
             
            class SqliteOps(object):
         | 
| 32 | 
            -
                 | 
| 33 | 
            -
                 | 
| 34 | 
            -
                 | 
| 35 | 
            -
                 | 
| 36 | 
            -
                 | 
| 37 | 
            -
                 | 
| 38 | 
            -
                 | 
| 39 | 
            -
                 | 
| 40 | 
            -
                 | 
| 41 | 
            -
                 | 
| 42 | 
            -
                 | 
| 43 | 
            -
             | 
| 44 | 
            -
                 | 
| 45 | 
            -
                conn. | 
| 36 | 
            +
                pass
         | 
| 37 | 
            +
                # import sqlite3
         | 
| 38 | 
            +
                # database_path = r'xx.db'
         | 
| 39 | 
            +
                # conn = sqlite3.connect(database_path)
         | 
| 40 | 
            +
                # c = conn.cursor()
         | 
| 41 | 
            +
                # sql = "select name from sqlite_master where type='table' order by name"
         | 
| 42 | 
            +
                # c.execute(sql)
         | 
| 43 | 
            +
                # print(c.fetchall())
         | 
| 44 | 
            +
                # sql = "select * from typecho_contents"
         | 
| 45 | 
            +
                # c.execute(sql)
         | 
| 46 | 
            +
                # res = c.fetchall()
         | 
| 47 | 
            +
                # print(res[3])
         | 
| 48 | 
            +
                #
         | 
| 49 | 
            +
                # conn.commit()
         | 
| 50 | 
            +
                # conn.close()
         | 
| 51 | 
            +
             | 
| 46 52 |  | 
| 47 53 | 
             
            class MysqlOps(object):
         | 
| 48 54 | 
             
                import pandas as pd
         | 
| @@ -116,6 +122,41 @@ class EsOps(object): | |
| 116 122 | 
             
                    print(f"批量保存数据: {_res}")
         | 
| 117 123 |  | 
| 118 124 |  | 
| 125 | 
            +
            class MongoDB_BETA:
                """
                Minimal wrapper around a single MongoDB collection.

                Call :meth:`connect` (or use the instance as a context manager) before
                inserting or querying; call :meth:`close` when done.
                """

                def __init__(self, host='localhost', port=27017, db_name=None, collection_name=None):
                    """
                    Store the connection settings; no connection is opened here.

                    :param host: MongoDB host.
                    :param port: MongoDB port.
                    :param db_name: name of the database to use.
                    :param collection_name: name of the collection to use.
                    """
                    self.host = host
                    self.port = port
                    self.db_name = db_name
                    self.collection_name = collection_name
                    self.client = None
                    self.db = None
                    self.collection = None

                def __enter__(self):
                    self.connect()
                    return self

                def __exit__(self, exc_type, exc_value, traceback):
                    self.close()
                    return False

                def connect(self):
                    """Open the client and select the configured database and collection."""
                    # Imported here so the class does not rely on MongoClient being a
                    # module-level name.
                    from pymongo import MongoClient

                    self.client = MongoClient(self.host, self.port)
                    self.db = self.client[self.db_name]
                    self.collection = self.db[self.collection_name]

                def close(self):
                    """Close the client if one was opened."""
                    if self.client is not None:
                        self.client.close()

                def _require_collection(self):
                    # Compare with None: pymongo collections do not support truth testing.
                    if self.collection is None:
                        raise RuntimeError("MongoDB_BETA is not connected; call connect() first.")
                    return self.collection

                def insert_data(self, data):
                    """
                    Insert one document, or many when ``data`` is a list.

                    :param data: a document, or a list of documents.
                    :raises RuntimeError: if :meth:`connect` has not been called.
                    """
                    collection = self._require_collection()
                    if isinstance(data, list):
                        collection.insert_many(data)
                    else:
                        collection.insert_one(data)

                def check_data_exists(self, query):
                    """
                    Check whether any document matches ``query``.

                    :param query: query filter.
                    :return: True if at least one document matches, else False.
                    :raises RuntimeError: if :meth:`connect` has not been called.
                    """
                    return self._require_collection().count_documents(query) > 0
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 119 160 | 
             
            class MongoOps(object):
         | 
| 120 161 | 
             
                from pymongo import MongoClient
         | 
| 121 162 | 
             
                def __init__(self, config=global_db_config["mongo"]):
         | 
| @@ -348,8 +389,6 @@ class KafkaOps(object): | |
| 348 389 | 
             
                        print(recv)
         | 
| 349 390 |  | 
| 350 391 |  | 
| 351 | 
            -
             | 
| 352 | 
            -
             | 
| 353 392 | 
             
            class MilvusOps(object):
         | 
| 354 393 | 
             
                def __init__(self, config=global_db_config.milvus):
         | 
| 355 394 | 
             
                    from pymilvus import connections, Collection
         | 
    
        nlpertools/dataprocess.py
    CHANGED
    
    | @@ -55,9 +55,9 @@ class Pattern: | |
| 55 55 | 
             
                # 中文人名
         | 
| 56 56 | 
             
                chinese_name_pattern = "(?:[\u4e00-\u9fa5·]{2,3})"
         | 
| 57 57 | 
             
                # 英文人名
         | 
| 58 | 
            -
                english_name_pattern = "(^[a-zA-Z][a-zA-Z\s]{0,20}[a-zA-Z]$)"
         | 
| 58 | 
            +
                english_name_pattern = r"(^[a-zA-Z][a-zA-Z\s]{0,20}[a-zA-Z]$)"
         | 
| 59 59 | 
             
                # 纯数字
         | 
| 60 | 
            -
                pure_num_pattern = "\d+"
         | 
| 60 | 
            +
                pure_num_pattern = r"\d+"
         | 
| 61 61 | 
             
                # xxxx图/表 之类的表述
         | 
| 62 62 | 
             
                pic_table_descript_pattern = ".{1,15}图"
         | 
| 63 63 |  | 
| @@ -66,20 +66,20 @@ class Pattern: | |
| 66 66 | 
             
                hlink_pattern = (
         | 
| 67 67 | 
             
                    r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]"
         | 
| 68 68 | 
             
                )
         | 
| 69 | 
            -
                http_pattern = "(http|https):\/\/([\w.]+\/?)\S*/\S*"
         | 
| 69 | 
            +
                http_pattern = r"(http|https):\/\/([\w.]+\/?)\S*/\S*"
         | 
| 70 70 | 
             
                # 邮箱
         | 
| 71 | 
            -
                email_pattern = "[A-Za-z0-9\u4e00-\u9fa5]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+"
         | 
| 71 | 
            +
                email_pattern = r"[A-Za-z0-9\u4e00-\u9fa5]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+"
         | 
| 72 72 | 
             
                # html 可能过于严格了
         | 
| 73 | 
            -
                html_pattern = "<[\s\S]*?>"
         | 
| 73 | 
            +
                html_pattern = r"<[\s\S]*?>"
         | 
| 74 74 | 
             
                # 重复 “asdasdasdasd”
         | 
| 75 75 | 
             
                repeat_pattern = "(.)\1+"
         | 
| 76 76 | 
             
                # 日期
         | 
| 77 | 
            -
                day_time_pattern = "\d{1,4}(-)(1[0-2]|0?[1-9])\1(0?[1-9]|[1-2]\d|30|31)"
         | 
| 77 | 
            +
                day_time_pattern = r"\d{1,4}(-)(1[0-2]|0?[1-9])\1(0?[1-9]|[1-2]\d|30|31)"
         | 
| 78 78 | 
             
                # 小时
         | 
| 79 | 
            -
                hour_time_pattern = "(?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d"
         | 
| 79 | 
            +
                hour_time_pattern = r"(?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d"
         | 
| 80 80 | 
             
                # 股票
         | 
| 81 81 | 
             
                stock_pattern = (
         | 
| 82 | 
            -
                    "(s[hz]|S[HZ])(000[\d]{3}|002[\d]{3}|300[\d]{3}|600[\d]{3}|60[\d]{4})"
         | 
| 82 | 
            +
                    r"(s[hz]|S[HZ])(000[\d]{3}|002[\d]{3}|300[\d]{3}|600[\d]{3}|60[\d]{4})"
         | 
| 83 83 | 
             
                )
         | 
| 84 84 |  | 
| 85 85 | 
             
                # 一般是需要替换的
         | 
| @@ -91,7 +91,7 @@ class Pattern: | |
| 91 91 | 
             
                # 微博视频等
         | 
| 92 92 | 
             
                weibo_pattern = r"([\s]\w+(的微博视频)|#|【|】|转发微博)"
         | 
| 93 93 | 
             
                # @
         | 
| 94 | 
            -
                at_pattern = "@\w+"
         | 
| 94 | 
            +
                at_pattern = r"@\w+"
         | 
| 95 95 |  | 
| 96 96 | 
             
                # from https://github.com/bigscience-workshop/data-preparation pii
         | 
| 97 97 | 
             
                year_patterns = [
         | 
| @@ -116,7 +116,7 @@ class Pattern: | |
| 116 116 | 
             
                ipv4_pattern = r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}'
         | 
| 117 117 | 
             
                ipv6_pattern = r'(?:[0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|(?:[0-9a-fA-F]{1,4}:){1,4}(?::[0-9a-fA-F]{1,4}){1,3}|(?:[0-9a-fA-F]{1,4}:){1,3}(?::[0-9a-fA-F]{1,4}){1,4}|(?:[0-9a-fA-F]{1,4}:){1,2}(?::[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:(?:(?::[0-9a-fA-F]{1,4}){1,6})|:(?:(?::[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(?::[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(?:ffff(?::0{1,4}){0,1}:){0,1}(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])|(?:[0-9a-fA-F]{1,4}:){1,4}:(?:(?:25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(?:2[0-4]|1{0,1}[0-9]){0,1}[0-9])'
         | 
| 118 118 | 
             
                ip_pattern = r"(?:^|[\b\s@?,!;:\'\")(.\p{Han}])(" + r"|".join(
         | 
| 119 | 
            -
                    [ipv4_pattern, ipv6_pattern]) + ")(?:$|[\s@,?!;:\'\"(.\p{Han}])"
         | 
| 119 | 
            +
                    [ipv4_pattern, ipv6_pattern]) + r")(?:$|[\s@,?!;:\'\"(.\p{Han}])"
         | 
| 120 120 |  | 
| 121 121 | 
             
                # https://regex101.com/r/EpA5B7/1
         | 
| 122 122 | 
             
                email_line_pattern = r'''
         | 
| @@ -466,7 +466,7 @@ class TextProcess(object): | |
| 466 466 | 
             
                    p = re.compile(pattern, re.S)
         | 
| 467 467 | 
             
                    text = p.sub("", text)
         | 
| 468 468 |  | 
| 469 | 
            -
                    dr = re.compile("@\w+", re.S)
         | 
| 469 | 
            +
                    dr = re.compile(r"@\w+", re.S)
         | 
| 470 470 | 
             
                    text = dr.sub("", text)
         | 
| 471 471 |  | 
| 472 472 | 
             
                    return text
         | 
| @@ -527,7 +527,7 @@ class TextProcess(object): | |
| 527 527 | 
             
                        text = re.sub(pattern, replace, text)
         | 
| 528 528 | 
             
                    return text
         | 
| 529 529 |  | 
| 530 | 
            -
                def calc_proportion_zh(self,text):
         | 
| 530 | 
            +
                def calc_proportion_zh(self, text):
         | 
| 531 531 | 
             
                    text = text.strip()
         | 
| 532 532 | 
             
                    # 如果是中国英文的情况,并且英文有空格分开
         | 
| 533 533 | 
             
                    if " " in text:
         | 
| @@ -538,6 +538,8 @@ class TextProcess(object): | |
| 538 538 | 
             
                            chinese_count += 1
         | 
| 539 539 | 
             
                        else:
         | 
| 540 540 | 
             
                            pass
         | 
| 541 | 
            +
             | 
| 542 | 
            +
             | 
| 541 543 | 
             
            class CopyFunc():
         | 
| 542 544 | 
             
                # from https://github.com/lemon234071/clean-dialog
         | 
| 543 545 | 
             
                def is_chinese_char(cp):
         | 
| @@ -597,6 +599,20 @@ def convert_basic2fullwidth(sentence): | |
| 597 599 | 
             
                    new_sentence += char
         | 
| 598 600 | 
             
                return new_sentence
         | 
| 599 601 |  | 
| 602 | 
            +
             | 
| 603 | 
            +
            def clean_illegal_chars_for_excel(df):
                """
                Return a copy of ``df`` with control characters removed from every string cell.

                openpyxl refuses to write certain control characters to an Excel file,
                so they are stripped first. Non-string cells are returned unchanged.

                :param df: pandas DataFrame to clean.
                :return: a new DataFrame with the cleaned values.
                """
                # Control characters in the ranges removed here; tab (0x09), line feed
                # (0x0A) and carriage return (0x0D) are kept.
                illegal_chars = re.compile(r'[\x00-\x08\x0B\x0C\x0E-\x1F]')

                def remove_illegal_chars(s):
                    if isinstance(s, str):
                        return illegal_chars.sub('', s)
                    return s

                # DataFrame.map was added in pandas 2.1; older versions only provide
                # the equivalent applymap.
                elementwise = getattr(df, "map", None)
                if elementwise is None:
                    elementwise = df.applymap
                return elementwise(remove_illegal_chars)
         | 
| 614 | 
            +
             | 
| 615 | 
            +
             | 
| 600 616 | 
             
            if __name__ == "__main__":
         | 
| 601 617 | 
             
                pattern_for_filter = [
         | 
| 602 618 | 
             
                    Pattern.redundancy_space_pattern,
         | 
| 
            File without changes
         | 
    
        nlpertools/draw/draw.py
    ADDED
    
    | @@ -0,0 +1,81 @@ | |
| 1 | 
            +
            #!/usr/bin/python3.8
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            # @Author  : youshu.Ji
         | 
| 4 | 
            +
            from ..utils.package import plt
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def confused_matrix(confuse_matrix, ticklabels=None, save_path='tmp.jpg'):
                """
                Draw a confusion matrix as a seaborn heatmap, save it and show it.

                :param confuse_matrix: matrix of values to plot.
                :param ticklabels: labels for both axes; ``None`` (the default) keeps the
                    previous hard-coded labels ``["l1", "l2", "l31"]``.
                :param save_path: file the figure is written to.
                """
                import seaborn as sns

                if ticklabels is None:
                    ticklabels = ["l1", "l2", "l31"]

                sns.set()
                f, ax = plt.subplots()
                # Draw the heatmap.
                sns.heatmap(confuse_matrix, annot=True, fmt=".3g", ax=ax, cmap='rainbow',
                            xticklabels=ticklabels, yticklabels=ticklabels)

                ax.set_title('confusion matrix')
                ax.set_xlabel('predict')
                ax.set_ylabel('true')

                # Save before showing: once an interactive window has been closed the
                # figure may no longer hold the drawing, which can give an empty file.
                f.savefig(save_path, bbox_inches='tight')
                plt.show()
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def plot_histogram(data, bin_size, max_bin, title='Module Line Number Distribution',
                               xlabel='module line number', ylabel='frequency'):
                """
                Plot a histogram of ``data`` with one extra overflow bin labelled "∞".

                Values greater than ``max_bin`` are all counted in the last bin. The mean
                of the unclipped data is marked with a dashed red line, and every bar is
                annotated with its count.

                :param data: sequence of numbers to plot.
                :param bin_size: width of each bin.
                :param max_bin: largest value shown on its own; anything above it goes
                    into the overflow bin.
                :param title: figure title.
                :param xlabel: x-axis label.
                :param ylabel: y-axis label.
                :return: None; the figure is shown with ``plt.show()``.
                """
                import matplotlib.pyplot as plt
                import numpy as np
                from matplotlib.ticker import MaxNLocator

                # Bin edges; the last bin is the overflow bin.
                bins = np.arange(0, max_bin + 2 * bin_size, bin_size)
                # Put overflow values in the middle of the last bin. A fixed offset such
                # as max_bin + 3 lands outside every bin when bin_size is small, and
                # those values would be silently dropped from the plot.
                overflow_value = bins[-2] + bin_size / 2
                processed_data = [length if length <= max_bin else overflow_value for length in data]

                plt.figure(figsize=(12, 8))
                # Draw the histogram.
                n, new_bins, patches = plt.hist(processed_data, bins=bins, edgecolor='black', color='skyblue', alpha=0.7,
                                                linewidth=0)

                # Label every bin edge, replacing the last one with "∞".
                plt.gca().set_xticks(bins)
                plt.gca().set_xticklabels([str(i) for i in plt.xticks()[0][:-1]] + ["∞"])

                # The mean is computed on the original, unclipped data.
                mean_val = np.mean(data)
                plt.axvline(mean_val, color='red', linestyle='dashed', linewidth=1)
                plt.text(mean_val + bin_size / 10, max(n) * 0.9, f'Mean: {mean_val:.2f}', color='red')

                # Title and axis labels.
                plt.title(title, fontsize=16, fontweight='bold')
                plt.xlabel(xlabel, fontsize=14)
                plt.ylabel(ylabel, fontsize=14)

                plt.grid(True, linestyle='--', alpha=0.6)

                plt.xticks(fontsize=12)
                plt.yticks(fontsize=12)

                # Write the count above every bar.
                for patch, count in zip(patches, n):
                    plt.text(patch.get_x() + patch.get_width() / 2, patch.get_height(),
                             str(int(count)), ha='center', va='bottom', fontsize=12)
                plt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
                # Show the figure.
                plt.show()
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            if __name__ == '__main__':
                # Manual demo, only executed when this module is run as a script.
                # Width of each histogram bin.
                bin_size = 50
                # Sample module-length data, including values above max_bin.
                plot_histogram([1, 100, 999, 1000, 1002, 1100, 1150], bin_size, max_bin=1000)
         | 
| @@ -0,0 +1,33 @@ | |
| 1 | 
            +
            # 数学函数
         | 
| 2 | 
            +
            def draw_log():
                """Plot the natural logarithm on [0.1, 10] and display the figure.

                Dashed gray guide lines are drawn at x=1 and y=0.
                Calls ``plt.show()``; nothing is returned.
                """
                import matplotlib.pyplot as plt
                import numpy as np
                from matplotlib.ticker import MultipleLocator, FormatStrFormatter

                # Sample points; np.log is the natural logarithm (base e).
                xs = np.linspace(0.1, 10, 100)

                _figure, axes = plt.subplots()
                axes.plot(xs, np.log(xs))

                # Title and axis labels.
                axes.set_title("Logarithmic Function")
                axes.set_xlabel("x")
                axes.set_ylabel("log(x)")

                # Major x ticks every 1 unit, rendered with one decimal place.
                x_axis = axes.xaxis
                x_axis.set_major_locator(MultipleLocator(1))
                x_axis.set_major_formatter(FormatStrFormatter("%.1f"))

                # Dashed guide lines at x=1 (vertical) and y=0 (horizontal).
                guide_style = {"linestyle": "--", "color": "gray"}
                axes.axvline(x=1, **guide_style)
                axes.axhline(y=0, **guide_style)

                plt.show()
         | 
    
        nlpertools/get_2fa.py
    ADDED
    
    | 
            File without changes
         | 
    
        nlpertools/io/dir.py
    CHANGED
    
    | @@ -10,7 +10,30 @@ def j_mkdir(name): | |
| 10 10 | 
             
                os.makedirs(name, exist_ok=True)
         | 
| 11 11 |  | 
| 12 12 |  | 
| 13 | 
            -
            def  | 
| 13 | 
            +
            def j_walk(name, suffix=None):
                """Recursively collect the paths of all files below ``name``.

                :param name: root directory handed to ``os.walk``
                :param suffix: optional ending checked with ``str.endswith`` against the
                    joined path; when falsy, every file is kept
                :return: list of paths in ``os.walk`` order
                """
                collected = []
                for folder, _subdirs, filenames in os.walk(name):
                    candidates = (os.path.join(folder, filename) for filename in filenames)
                    collected.extend(
                        candidate for candidate in candidates
                        if not suffix or candidate.endswith(suffix)
                    )
                return collected
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def windows_to_wsl_path(windows_path):
                r"""Convert a Windows path into the matching WSL path.

                A drive-qualified path such as ``C:\Users\me`` or ``C:/Users/me`` becomes
                ``/mnt/c/Users/me`` (drive letter lower-cased). Any other path only has
                its backslashes turned into forward slashes.

                :param windows_path: path string using Windows conventions
                :return: path string using WSL conventions
                """
                # Unify separators first so both "C:\..." and "C:/..." hit the drive branch.
                normalized = windows_path.replace('\\', '/')
                if normalized[1:3] == ':/':
                    drive_letter = normalized[0].lower()
                    # normalized[2:] keeps the leading "/" that followed the colon.
                    return f'/mnt/{drive_letter}{normalized[2:]}'
                return normalized
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def get_filename(path, suffix=True) -> str:
         | 
| 14 37 | 
             
                """
         | 
| 15 38 | 
             
                返回路径最后的文件名
         | 
| 16 39 | 
             
                :param path:
         | 
| @@ -18,11 +41,20 @@ def get_filename(path) -> str: | |
| 18 41 | 
             
                """
         | 
| 19 42 | 
             
                # path = r'***/**/***.txt'
         | 
| 20 43 | 
             
                filename = os.path.split(path)[-1]
         | 
| 44 | 
            +
                if not suffix:
         | 
| 45 | 
            +
                    filename = filename.split('.')[0]
         | 
| 21 46 | 
             
                return filename
         | 
| 22 47 |  | 
| 23 48 |  | 
| 24 | 
            -
            def  | 
| 25 | 
            -
                 | 
| 49 | 
            +
            def listdir(dir_name, including_dir=True):
                """List the entries of ``dir_name``.

                :param dir_name: directory to list
                :param including_dir: when true, each entry is joined with ``dir_name``;
                    otherwise the bare entry names are returned
                :return: list of entries in ``os.listdir`` order
                """
                entries = os.listdir(dir_name)
                if not including_dir:
                    return list(entries)
                return [os.path.join(dir_name, entry) for entry in entries]
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def listdir_yield(dir_name, including_dir=True):
         | 
| 26 58 | 
             
                filenames = os.listdir(dir_name)
         | 
| 27 59 | 
             
                for filename in filenames:
         | 
| 28 60 | 
             
                    if including_dir:
         | 
    
        nlpertools/io/file.py
    CHANGED
    
    | @@ -5,7 +5,6 @@ import codecs | |
| 5 5 | 
             
            import json
         | 
| 6 6 | 
             
            import pickle
         | 
| 7 7 | 
             
            import random
         | 
| 8 | 
            -
            import time
         | 
| 9 8 | 
             
            from itertools import (takewhile, repeat)
         | 
| 10 9 | 
             
            import pandas as pd
         | 
| 11 10 | 
             
            # import omegaconf
         | 
| @@ -15,10 +14,16 @@ from ..utils.package import * | |
| 15 14 | 
             
            LARGE_FILE_THRESHOLD = 1e5
         | 
| 16 15 |  | 
| 17 16 |  | 
| 17 | 
            +
            def safe_filename(filename: str) -> str:
                """Return ``filename`` with each of the characters ``\\ / : * ? " < > |`` replaced by ``_``."""
                forbidden = '\\/:*?"<>|'
                # One translation table swaps every forbidden character in a single pass.
                return filename.translate(str.maketrans(forbidden, '_' * len(forbidden)))
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 18 23 | 
             
            def read_yaml(path, omega=False):
                """Load a YAML file.

                :param path: path of the YAML file
                :param omega: when true, delegate to ``omegaconf.OmegaConf.load`` and
                    return its result
                :return: the parsed content; without ``omega`` this is the result of
                    ``yaml.load`` with ``yaml.FullLoader`` on the UTF-8 decoded file
                """
                if omega:
                    return omegaconf.OmegaConf.load(path)
                # Context manager closes the handle even when parsing raises.
                with codecs.open(path, encoding='utf-8') as f:
                    return yaml.load(f, Loader=yaml.FullLoader)
         | 
| 22 27 |  | 
| 23 28 |  | 
| 24 29 | 
             
            def _merge_file(filelist, save_filename, shuffle=False):
         | 
| @@ -52,7 +57,7 @@ load_from_json | |
| 52 57 |  | 
| 53 58 |  | 
| 54 59 | 
             
            # 读txt文件 一次全读完 返回list 去换行
         | 
| 55 | 
            -
            def readtxt_list_all_strip(path, encoding='utf-8'):
         | 
| 60 | 
            +
            def readtxt_list_all_strip(path, encoding='utf-8') -> list:
         | 
| 56 61 | 
             
                file_line_num = iter_count(path)
         | 
| 57 62 | 
             
                lines = []
         | 
| 58 63 | 
             
                with codecs.open(path, 'r', encoding) as r:
         | 
| @@ -67,7 +72,7 @@ def readtxt_list_all_strip(path, encoding='utf-8'): | |
| 67 72 |  | 
| 68 73 |  | 
| 69 74 | 
             
            # 读txt 一次读一行 最后返回list
         | 
| 70 | 
            -
            def readtxt_list_each(path):
         | 
| 75 | 
            +
            def readtxt_list_each(path) -> list:
         | 
| 71 76 | 
             
                lines = []
         | 
| 72 77 | 
             
                with codecs.open(path, 'r', 'utf-8') as r:
         | 
| 73 78 | 
             
                    line = r.readline()
         | 
| @@ -77,7 +82,7 @@ def readtxt_list_each(path): | |
| 77 82 | 
             
                return lines
         | 
| 78 83 |  | 
| 79 84 |  | 
| 80 | 
            -
            def readtxt_list_each_strip(path):
         | 
| 85 | 
            +
            def readtxt_list_each_strip(path) -> list:
         | 
| 81 86 | 
             
                """
         | 
| 82 87 | 
             
                yield方法
         | 
| 83 88 | 
             
                """
         | 
| @@ -89,14 +94,14 @@ def readtxt_list_each_strip(path): | |
| 89 94 |  | 
| 90 95 |  | 
| 91 96 | 
             
            # 读txt文件 一次全读完 返回list
         | 
| 92 | 
            -
            def readtxt_list_all(path):
         | 
| 97 | 
            +
            def readtxt_list_all(path) -> list:
         | 
| 93 98 | 
             
                with codecs.open(path, 'r', 'utf-8') as r:
         | 
| 94 99 | 
             
                    lines = r.readlines()
         | 
| 95 100 | 
             
                    return lines
         | 
| 96 101 |  | 
| 97 102 |  | 
| 98 103 | 
             
            # 读byte文件 读成一条string
         | 
| 99 | 
            -
            def readtxt_byte(path, encoding="utf-8"):
         | 
| 104 | 
            +
            def readtxt_byte(path, encoding="utf-8") -> str:
         | 
| 100 105 | 
             
                with codecs.open(path, 'rb') as r:
         | 
| 101 106 | 
             
                    lines = r.read()
         | 
| 102 107 | 
             
                    lines = lines.decode(encoding)
         | 
| @@ -104,7 +109,7 @@ def readtxt_byte(path, encoding="utf-8"): | |
| 104 109 |  | 
| 105 110 |  | 
| 106 111 | 
             
            # 读txt文件 读成一条string
         | 
| 107 | 
            -
            def readtxt_string(path, encoding="utf-8"):
         | 
| 112 | 
            +
            def readtxt_string(path, encoding="utf-8") -> str:
         | 
| 108 113 | 
             
                with codecs.open(path, 'r', encoding) as r:
         | 
| 109 114 | 
             
                    lines = r.read()
         | 
| 110 115 | 
             
                    return lines.replace('\r', '')
         | 
| @@ -236,12 +241,12 @@ def load_from_jsonl(path): | |
| 236 241 | 
             
                    return corpus
         | 
| 237 242 |  | 
| 238 243 |  | 
| 239 | 
            -
            def  | 
| 244 | 
            +
            def save_pkl(data, path):
                """Pickle ``data`` into the file at ``path`` (opened in binary write mode)."""
                with open(path, 'wb') as sink:
                    pickle.dump(data, sink)
         | 
| 242 247 |  | 
| 243 248 |  | 
| 244 | 
            -
            def  | 
| 249 | 
            +
            def load_pkl(path):
                """Read the file at ``path`` and return the object unpickled from it.

                NOTE(review): ``pickle.load`` can execute arbitrary code; only use this
                on files from a trusted source.
                """
                with open(path, 'rb') as source:
                    return pickle.load(source)
         | 
| @@ -261,6 +266,7 @@ def save_to_mongo(): | |
| 261 266 | 
             
                """
         | 
| 262 267 | 
             
                pass
         | 
| 263 268 |  | 
| 269 | 
            +
             | 
| 264 270 | 
             
            def load_from_mongo():
         | 
| 265 271 | 
             
                pass
         | 
| 266 272 |  | 
| @@ -274,4 +280,4 @@ def unmerge_cells_df(df) -> pd.DataFrame: | |
| 274 280 | 
             
                        else:
         | 
| 275 281 | 
             
                            values.append(i)
         | 
| 276 282 | 
             
                    df[column] = values
         | 
| 277 | 
            -
                return df
         | 
| 283 | 
            +
                return df
         | 
    
        nlpertools/ml.py
    CHANGED
    
    | @@ -17,10 +17,31 @@ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv | |
| 17 17 | 
             
            from .utils.package import *
         | 
| 18 18 |  | 
| 19 19 |  | 
| 20 | 
            +
            def estimate_pass_at_k(num_samples: list, num_correct: list, k):
                """Estimate pass@k of each problem and return the values in an array.

                Copied from https://huggingface.co/spaces/evaluate-metric/code_eval/blob/main/code_eval.py

                :param num_samples: number of generated samples per problem; either a single
                    int applied to every problem or a sequence as long as ``num_correct``
                :param num_correct: number of correct samples per problem
                :param k: the ``k`` of pass@k
                :return: numpy array holding one estimate per problem
                """

                def _pass_at_k(total, correct):
                    # Evaluates 1 - comb(total - correct, k) / comb(total, k) as a product.
                    if total - correct < k:
                        # Fewer than k incorrect samples: the estimate is 1.0.
                        return 1.0
                    return 1.0 - np.prod(1.0 - k / np.arange(total - correct + 1, total + 1))

                if isinstance(num_samples, int):
                    totals = itertools.repeat(num_samples, len(num_correct))
                else:
                    assert len(num_samples) == len(num_correct)
                    totals = iter(num_samples)

                estimates = []
                for total, correct in zip(totals, num_correct):
                    estimates.append(_pass_at_k(int(total), int(correct)))
                return np.array(estimates)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 20 42 | 
             
            def calc_llm_train_activation_memory(
         | 
| 21 | 
            -
             | 
| 43 | 
            +
                    model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
         | 
| 22 44 | 
             
            ):
         | 
| 23 | 
            -
             | 
| 24 45 | 
             
                """
         | 
| 25 46 | 
             
                return bytes
         | 
| 26 47 |  | 
| @@ -33,18 +54,18 @@ def calc_llm_train_activation_memory( | |
| 33 54 | 
             
                # FFN
         | 
| 34 55 | 
             
                # Layer Norm
         | 
| 35 56 | 
             
                r1 = (
         | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 57 | 
            +
                        sequence_length
         | 
| 58 | 
            +
                        * batch_size
         | 
| 59 | 
            +
                        * hidden_dim
         | 
| 60 | 
            +
                        * lay_number
         | 
| 61 | 
            +
                        * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
         | 
| 41 62 | 
             
                )
         | 
| 42 63 | 
             
                # reference2
         | 
| 43 64 | 
             
                r2 = (
         | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 65 | 
            +
                        lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
         | 
| 66 | 
            +
                        * sequence_length
         | 
| 67 | 
            +
                        * batch_size
         | 
| 68 | 
            +
                        / gpu_num
         | 
| 48 69 | 
             
                )
         | 
| 49 70 | 
             
                print(r1)
         | 
| 50 71 | 
             
                print(r2)
         | 
| @@ -80,7 +101,7 @@ class DataStructure: | |
| 80 101 | 
             
                }
         | 
| 81 102 | 
             
                ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
         | 
| 82 103 | 
             
                ner_label_example = (
         | 
| 83 | 
            -
             | 
| 104 | 
            +
                        list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
         | 
| 84 105 | 
             
                )
         | 
| 85 106 |  | 
| 86 107 |  | 
| @@ -135,7 +156,7 @@ class STEM(object): | |
| 135 156 | 
             
                        if each_srl:
         | 
| 136 157 | 
             
                            args = []
         | 
| 137 158 | 
             
                            for arg in each_srl:
         | 
| 138 | 
            -
                                args.extend(seg[arg[1] | 
| 159 | 
            +
                                args.extend(seg[arg[1]: arg[2] + 1])
         | 
| 139 160 | 
             
                            # 添加上谓词
         | 
| 140 161 | 
             
                            args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
         | 
| 141 162 | 
             
                            events.append(args)
         | 
| @@ -174,7 +195,7 @@ def subject_object_labeling(spo_list, text): | |
| 174 195 | 
             
                    q_list_length = len(q_list)
         | 
| 175 196 | 
             
                    k_list_length = len(k_list)
         | 
| 176 197 | 
             
                    for idx in range(k_list_length - q_list_length + 1):
         | 
| 177 | 
            -
                        t = [q == k for q, k in zip(q_list, k_list[idx | 
| 198 | 
            +
                        t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
         | 
| 178 199 | 
             
                        # print(idx, t)
         | 
| 179 200 | 
             
                        if all(t):
         | 
| 180 201 | 
             
                            # print(idx)
         | 
| @@ -187,8 +208,8 @@ def subject_object_labeling(spo_list, text): | |
| 187 208 | 
             
                    if len(spo) == 2:
         | 
| 188 209 | 
             
                        labeling_list[idx_start + 1] = "I-" + spo_type
         | 
| 189 210 | 
             
                    elif len(spo) >= 3:
         | 
| 190 | 
            -
                        labeling_list[idx_start + 1 | 
| 191 | 
            -
             | 
| 211 | 
            +
                        labeling_list[idx_start + 1: idx_start + len(spo)] = ["I-" + spo_type] * (
         | 
| 212 | 
            +
                                len(spo) - 1
         | 
| 192 213 | 
             
                        )
         | 
| 193 214 | 
             
                    else:
         | 
| 194 215 | 
             
                        pass
         | 
| @@ -239,12 +260,12 @@ def convert_crf_format_10_fold(corpus, objdir_path): | |
| 239 260 | 
             
                split_position = int(len(corpus) / 10)
         | 
| 240 261 | 
             
                for k in range(0, 10):
         | 
| 241 262 | 
             
                    if k == 9:
         | 
| 242 | 
            -
                        dev_set = corpus[k * split_position | 
| 263 | 
            +
                        dev_set = corpus[k * split_position:]
         | 
| 243 264 | 
             
                        train_set = corpus[: k * split_position]
         | 
| 244 265 | 
             
                    else:
         | 
| 245 | 
            -
                        dev_set = corpus[k * split_position | 
| 266 | 
            +
                        dev_set = corpus[k * split_position: (k + 1) * split_position]
         | 
| 246 267 | 
             
                        train_set = (
         | 
| 247 | 
            -
             | 
| 268 | 
            +
                                corpus[: k * split_position] + corpus[(k + 1) * split_position:]
         | 
| 248 269 | 
             
                        )
         | 
| 249 270 | 
             
                    writetxt_w_list(
         | 
| 250 271 | 
             
                        train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
         | 
| @@ -292,12 +313,41 @@ def kfold_txt(corpus, path, k=9, is_shuffle=True): | |
| 292 313 | 
             
                if is_shuffle:
         | 
| 293 314 | 
             
                    random.shuffle(corpus)
         | 
| 294 315 | 
             
                split_position = int(len(corpus) / 10)
         | 
| 295 | 
            -
                train_set, dev_set = corpus[: k * split_position], corpus[k * split_position | 
| 316 | 
            +
                train_set, dev_set = corpus[: k * split_position], corpus[k * split_position:]
         | 
| 296 317 | 
             
                writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
         | 
| 297 318 | 
             
                writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
         | 
| 298 319 | 
             
                writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
         | 
| 299 320 |  | 
| 300 321 |  | 
| 322 | 
            +
            def sample():
                """Demonstrate a stratified 90/10 train/test split on a toy DataFrame.

                Builds a 100-row frame, stratifies on column ``y`` (ten classes of ten
                rows each) and prints the frame plus the size of both splits.
                Nothing is returned.
                """
                import pandas as pd
                from sklearn.model_selection import StratifiedShuffleSplit

                # Toy data standing in for a real DataFrame.
                df = pd.DataFrame({
                    "count_line": list(range(100)),
                    "x": list(range(100)),
                    "y": [i // 10 for i in range(100)],
                })
                print(df)

                # One shuffle split with a test share of 0.1, stratified on column 'y'.
                split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
                train_index, test_index = next(split.split(df, df['y']))

                # split() yields positional indices, so select rows with iloc, not loc.
                train_df = df.iloc[train_index]
                test_df = df.iloc[test_index]

                # Report the number of rows in each split.
                print("训练集行数:", len(train_df))
                print("测试集行数:", len(test_df))
         | 
| 349 | 
            +
             | 
| 350 | 
            +
             | 
| 301 351 | 
             
            def kfold_df(df, save_dir=None):
         | 
| 302 352 | 
             
                """
         | 
| 303 353 | 
             
                划分train test val集, 写为windows可读的csv。
         | 
| @@ -389,7 +439,7 @@ def split_sentence(sentence, language="chinese", cross_line=True): | |
| 389 439 | 
             
                for idx, char in enumerate(sentence):
         | 
| 390 440 | 
             
                    if idx == len(sentence) - 1:
         | 
| 391 441 | 
             
                        if char in split_signs:
         | 
| 392 | 
            -
                            sentences.append(sentence[start_idx | 
| 442 | 
            +
                            sentences.append(sentence[start_idx: idx + 1].strip())
         | 
| 393 443 | 
             
                            start_idx = idx + 1
         | 
| 394 444 | 
             
                        else:
         | 
| 395 445 | 
             
                            sentences.append(sentence[start_idx:].strip())
         | 
| @@ -399,10 +449,10 @@ def split_sentence(sentence, language="chinese", cross_line=True): | |
| 399 449 | 
             
                                if idx < len(sentence) - 2:
         | 
| 400 450 | 
             
                                    # 处理。”。
         | 
| 401 451 | 
             
                                    if sentence[idx + 2] not in split_signs:
         | 
| 402 | 
            -
                                        sentences.append(sentence[start_idx | 
| 452 | 
            +
                                        sentences.append(sentence[start_idx: idx + 2].strip())
         | 
| 403 453 | 
             
                                        start_idx = idx + 2
         | 
| 404 454 | 
             
                            elif sentence[idx + 1] not in split_signs:
         | 
| 405 | 
            -
                                sentences.append(sentence[start_idx | 
| 455 | 
            +
                                sentences.append(sentence[start_idx: idx + 1].strip())
         | 
| 406 456 | 
             
                                start_idx = idx + 1
         | 
| 407 457 | 
             
                return sentences
         | 
| 408 458 |  | 
| @@ -480,4 +530,4 @@ if __name__ == "__main__": | |
| 480 530 | 
             
                    attention_heads_num=32,
         | 
| 481 531 | 
             
                    gpu_num=1
         | 
| 482 532 | 
             
                )
         | 
| 483 | 
            -
                print(res, "G")
         | 
| 533 | 
            +
                print(res, "G")
         | 
    
        nlpertools/other.py
    CHANGED
    
    | @@ -5,10 +5,13 @@ import itertools | |
| 5 5 | 
             
            import os
         | 
| 6 6 | 
             
            import re
         | 
| 7 7 | 
             
            import string
         | 
| 8 | 
            +
            import subprocess
         | 
| 9 | 
            +
            import threading
         | 
| 8 10 | 
             
            from concurrent.futures import ThreadPoolExecutor
         | 
| 9 11 | 
             
            from functools import reduce
         | 
| 10 12 | 
             
            import math
         | 
| 11 13 | 
             
            import datetime
         | 
| 14 | 
            +
            import difflib
         | 
| 12 15 | 
             
            import psutil
         | 
| 13 16 | 
             
            from .io.file import writetxt_w_list, writetxt_a
         | 
| 14 17 | 
             
            # import numpy as np
         | 
| @@ -27,6 +30,149 @@ ENGLISH_PUNCTUATION = list(',.;:\'"!?<>()') | |
| 27 30 | 
             
            OTHER_PUNCTUATION = list('!@#$%^&*')
         | 
| 28 31 |  | 
| 29 32 |  | 
| 33 | 
            +
            def setup_logging(log_file):
                """
                Configure the root logger through ``logging.basicConfig``.

                Args:
                    log_file (str): Path to the log file.
                """
                options = {
                    "filename": log_file,
                    "level": logging.INFO,
                    "format": '%(asctime)s - %(levelname)s - %(message)s',
                    "datefmt": '%Y-%m-%d %H:%M:%S',
                }
                logging.basicConfig(**options)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def get_diff_parts(str1, str2):
                """Return the differing segments between ``str1`` and ``str2``.

                :param str1: first sequence
                :param str2: second sequence
                :return: list of ``(tag, part_of_str1, part_of_str2)`` tuples for every
                    ``difflib.SequenceMatcher`` opcode tagged replace, delete or insert
                """
                changed_tags = ('replace', 'delete', 'insert')
                opcodes = difflib.SequenceMatcher(None, str1, str2).get_opcodes()
                return [
                    (tag, str1[a_start:a_end], str2[b_start:b_end])
                    for tag, a_start, a_end, b_start, b_end in opcodes
                    if tag in changed_tags
                ]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def run_cmd_with_timeout(cmd, timeout):
                """
                Run a shell command and give up once it has run for more than ``timeout`` seconds.

                Reference: https://juejin.cn/post/7391703459803086848

                Args:
                    cmd (str): Command line; it is executed through the shell (``shell=True``),
                        so it must not be built from untrusted input.
                    timeout (float | None): Maximum number of seconds to wait. ``None`` waits
                        until the command finishes.

                Returns:
                    tuple: ``(True, stdout)`` when the command exits with status 0.
                    ``(False, output)`` when it exits with a non-zero status, where ``output``
                    is stdout, or stderr when stdout is empty.
                    ``(False, "<cmd> is running over <timeout>s")`` when the time limit is hit.
                """
                import os
                import signal

                is_posix = os.name == "posix"
                # On POSIX the command gets its own session so that, on timeout, the whole
                # process group (the shell AND anything it spawned) can be killed. Killing only
                # the shell can leave a child holding the pipes open, which blocks communicate().
                process = subprocess.Popen(cmd, shell=True, encoding="utf-8", errors="ignore", stdout=subprocess.PIPE,
                                           stderr=subprocess.PIPE, start_new_session=is_posix)
                try:
                    stdout, stderr = process.communicate(timeout=timeout)
                except subprocess.TimeoutExpired:
                    print(f"Terminating {cmd}")
                    if is_posix:
                        try:
                            os.killpg(process.pid, signal.SIGKILL)
                        except ProcessLookupError:
                            # The process group exited between the timeout and the kill.
                            pass
                    else:
                        process.kill()
                    # Reap the process and drain/close its pipes.
                    process.communicate()
                    print("Terminated successfully")
                    return False, f"{cmd} is running over {timeout}s"

                if process.returncode == 0:
                    return True, stdout
                # Failing commands usually explain themselves on stderr; fall back to it when
                # stdout carries nothing.
                return False, stdout or stderr
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            def print_three_line_table(df):
                """
                Render ``df`` as a three-line (top rule, header rule, bottom rule) HTML table.

                The generated HTML document is printed to stdout, written to ``temp.html`` in the
                current working directory (UTF-8 encoded) and opened with the default web browser.

                Args:
                    df: Object providing ``to_html(index=False)`` (presumably a pandas
                        ``DataFrame``).
                """
                # TODO: add support for rendering cells in red, as can be done in Excel.
                import webbrowser
                from pathlib import Path

                # Convert the DataFrame to an HTML table and drop the default cell borders.
                html_table = df.to_html(index=False)
                html_table = html_table.replace('border="1"', 'border="0"')

                first_line_px = str(2)
                second_line_px = str(1)
                third_line_px = str(2)
                # CSS for the three-line table: thead = header, tr = row, td = cell.
                # Placeholders are substituted with str.replace because the CSS braces would
                # clash with str.format / f-string syntax.
                style = """
                <style>

                    table {
                        border-collapse: collapse;
                    }

                    tr, td, th {
                        text-align: center; /* center text horizontally */
                        vertical-align: middle; /* center text vertically */
                    }
                    thead tr {
                        border-top: (first_line_px)px solid black;
                        border-bottom: (second_line_px)px solid black;
                    }

                    thead th {
                        border-bottom: (second_line_px)px solid black;
                    }

                    tbody tr td {
                        border-bottom: 0px solid black;
                    }

                    tbody tr:last-child td {
                        border-bottom: (third_line_px)px solid black;
                    }
                </style>"""
                style = style.replace("(first_line_px)", first_line_px).replace("(second_line_px)", second_line_px).replace(
                    "(third_line_px)", third_line_px)

                # Assemble a complete document so the declared charset actually reaches the browser.
                html = (
                    "<!DOCTYPE html>\n"
                    '<html lang="zh">\n'
                    "<head>\n"
                    '    <meta charset="UTF-8">\n'
                    "    <title>页面标题</title>\n"
                    f"{style}\n"
                    "</head>\n"
                    "<body>\n"
                    f"{html_table}\n"
                    "</body>\n"
                    "</html>\n"
                )
                print(html)

                temp_file_path = "temp.html"
                # Write with the same encoding the <meta charset> tag announces.
                with open(temp_file_path, "w", encoding="utf-8") as f:
                    f.write(html)
                # as_uri() yields a well-formed file:// URI on every platform (drive letters,
                # spaces and non-ASCII characters are handled).
                webbrowser.open(Path(temp_file_path).resolve().as_uri())
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            def jprint(obj, depth=0):
                """
                Print a nested structure of dicts and lists in an indented, readable form.

                For a dict, every key is printed on its own line, prefixed with ``depth`` ``|``
                characters and surrounded by dash separators that get shorter as ``depth`` grows;
                the value is then printed one level deeper. For a list, every element is printed
                one level deeper. Any other object is printed with ``print``.

                Args:
                    obj: Object to print.
                    depth (int): Current nesting level; callers normally leave the default.
                """
                if isinstance(obj, dict):
                    sep = "-" * (10 - depth * 3)
                    for k, v in obj.items():
                        print(depth * "|", sep, k, sep)
                        # Propagate the nesting level so nested keys are visually indented
                        # (previously the level was reset to 0 here).
                        jprint(v, depth + 1)
                elif isinstance(obj, list):
                    for v in obj:
                        jprint(v, depth + 1)
                else:
                    print(obj)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            def print_split(sign="=", num=20):
                """
                Print a separator line made of ``sign`` repeated ``num`` times.

                Args:
                    sign (str): Text to repeat.
                    num (int): Number of repetitions.
                """
                divider = sign * num
                print(divider)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
             | 
| 30 176 | 
             
            def seed_everything():
         | 
| 31 177 | 
             
                import torch
         | 
| 32 178 | 
             
                # seed everything
         | 
| @@ -82,21 +228,6 @@ def convert_np_to_py(obj): | |
| 82 228 | 
             
                    return obj
         | 
| 83 229 |  | 
| 84 230 |  | 
| 85 | 
            -
            def git_push():
         | 
| 86 | 
            -
                """
         | 
| 87 | 
            -
                针对国内提交github经常失败,自动提交
         | 
| 88 | 
            -
                """
         | 
| 89 | 
            -
                num = -1
         | 
| 90 | 
            -
                while 1:
         | 
| 91 | 
            -
                    num += 1
         | 
| 92 | 
            -
                    print("retry num: {}".format(num))
         | 
| 93 | 
            -
                    info = os.system("git push --set-upstream origin main")
         | 
| 94 | 
            -
                    print(str(info))
         | 
| 95 | 
            -
                    if not str(info).startswith("fatal"):
         | 
| 96 | 
            -
                        print("scucess")
         | 
| 97 | 
            -
                        break
         | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 231 | 
             
            def snake_to_camel(s: str) -> str:
         | 
| 101 232 | 
             
                """
         | 
| 102 233 | 
             
                author: u
         | 
| @@ -235,25 +366,22 @@ def stress_test(func, ipts): | |
| 235 366 | 
             
                return results
         | 
| 236 367 |  | 
| 237 368 |  | 
| 238 | 
            -
            def get_substring_loc(text, subtext):
         | 
| 239 | 
            -
                res = re.finditer(
         | 
| 240 | 
            -
                    subtext.replace('\\', '\\\\').replace('?', '\?').replace('(', '\(').replace(')', '\)').replace(']',
         | 
| 241 | 
            -
                                                                                                                   '\]').replace(
         | 
| 242 | 
            -
                        '[', '\[').replace('+', '\+'), text)
         | 
| 243 | 
            -
                l, r = [i for i in res][0].regs[0]
         | 
| 244 | 
            -
                return l, r
         | 
| 245 | 
            -
             | 
| 246 | 
            -
             | 
| 247 369 | 
             
            def squeeze_list(high_dim_list):
         | 
| 248 370 | 
             
                return list(itertools.chain.from_iterable(high_dim_list))
         | 
| 249 371 |  | 
| 250 372 |  | 
| 251 373 | 
             
            def unsqueeze_list(flatten_list, each_element_len):
                """
                Regroup a flat list into consecutive chunks of ``each_element_len`` items.

                NOTE: the author marks this function as incorrect and superseded by
                ``split_list``. Only ``len(flatten_list) // each_element_len`` full chunks are
                produced, so trailing elements that do not fill a whole chunk are dropped.

                Args:
                    flatten_list: Sequence to regroup.
                    each_element_len (int): Number of items per chunk.

                Returns:
                    list: The full chunks, in order.
                """
                full_chunks = len(flatten_list) // each_element_len
                two_dim_list = []
                for idx in range(full_chunks):
                    begin = idx * each_element_len
                    two_dim_list.append(flatten_list[begin:(idx + 1) * each_element_len])
                return two_dim_list
         | 
| 255 378 |  | 
| 256 379 |  | 
| 380 | 
            +
            def split_list(input_list, chunk_size):
                """
                Split ``input_list`` into consecutive chunks of at most ``chunk_size`` items.

                Args:
                    input_list: Sequence to split.
                    chunk_size (int): Number of items per chunk; the last chunk may be shorter.

                Returns:
                    list: The chunks, in order.
                """
                chunks = []
                for start in range(0, len(input_list), chunk_size):
                    chunks.append(input_list[start:start + chunk_size])
                return chunks
         | 
| 383 | 
            +
             | 
| 384 | 
            +
             | 
| 257 385 | 
             
            def auto_close():
         | 
| 258 386 | 
             
                """
         | 
| 259 387 | 
             
                针对企业微信15分钟会显示离开的机制,假装自己还在上班
         | 
| @@ -1,6 +1,6 @@ | |
| 1 | 
            -
            Metadata-Version: 2. | 
| 1 | 
            +
            Metadata-Version: 2.2
         | 
| 2 2 | 
             
            Name: nlpertools
         | 
| 3 | 
            -
            Version: 1.0. | 
| 3 | 
            +
            Version: 1.0.9
         | 
| 4 4 | 
             
            Summary: A small package about small basic IO operation when coding
         | 
| 5 5 | 
             
            Home-page: https://github.com/lvzii/nlpertools
         | 
| 6 6 | 
             
            Author: youshuJi
         | 
| @@ -12,6 +12,13 @@ Classifier: Operating System :: OS Independent | |
| 12 12 | 
             
            Requires-Python: >=3.6
         | 
| 13 13 | 
             
            Description-Content-Type: text/markdown
         | 
| 14 14 | 
             
            License-File: LICENSE
         | 
| 15 | 
            +
            Requires-Dist: numpy
         | 
| 16 | 
            +
            Requires-Dist: pandas
         | 
| 17 | 
            +
            Requires-Dist: psutil
         | 
| 18 | 
            +
            Provides-Extra: torch
         | 
| 19 | 
            +
            Requires-Dist: torch; extra == "torch"
         | 
| 20 | 
            +
            Dynamic: provides-extra
         | 
| 21 | 
            +
            Dynamic: requires-dist
         | 
| 15 22 |  | 
| 16 23 | 
             
            <div align="center">
         | 
| 17 24 | 
             
              <h4 align="center">
         | 
| @@ -23,9 +30,6 @@ License-File: LICENSE | |
| 23 30 | 
             
            </div>
         | 
| 24 31 |  | 
| 25 32 |  | 
| 26 | 
            -
            # 当前版本
         | 
| 27 | 
            -
             | 
| 28 | 
            -
            1.0.5
         | 
| 29 33 |  | 
| 30 34 | 
             
            # 说明
         | 
| 31 35 |  | 
| @@ -33,7 +37,7 @@ License-File: LICENSE | |
| 33 37 |  | 
| 34 38 | 
             
            它解决了什么问题:
         | 
| 35 39 |  | 
| 36 | 
            -
            - 很多函数是记不住的,  | 
| 40 | 
            +
            - 很多函数是记不住的, 每次写都要~~搜~~问大模型 ,例如pandas排序
         | 
| 37 41 | 
             
            - 刷题的时候,树结构的题目很难调试
         | 
| 38 42 |  | 
| 39 43 |  | 
| @@ -75,9 +79,9 @@ https://nlpertools.readthedocs.io/en/latest/ | |
| 75 79 | 
             
                  def __init__(self, IPT_MODEL_PATH):
         | 
| 76 80 | 
             
                      self.ltp = LTP(IPT_MODEL_PATH)
         | 
| 77 81 | 
             
              ```
         | 
| 78 | 
            -
               | 
| 82 | 
            +
              通过`pyinstrument`判断,超过1s的包即采用这种方式
         | 
| 79 83 | 
             
              - 2s+ happybase、seaborn、torch、jieba
         | 
| 80 | 
            -
              - 1s+
         | 
| 84 | 
            +
              - 1s+ /
         | 
| 81 85 | 
             
              - 0.5s+ pandas elasticsearch transformers xgboost nltk mongo
         | 
| 82 86 |  | 
| 83 87 |  | 
| @@ -85,6 +89,8 @@ https://nlpertools.readthedocs.io/en/latest/ | |
| 85 89 |  | 
| 86 90 | 
             
            - [readthedoc 检查文档构建状况](https://readthedocs.org/projects/nlpertools/builds)
         | 
| 87 91 |  | 
| 92 | 
            +
            - [打包发布指南](https://juejin.cn/post/7369413136224878644)
         | 
| 93 | 
            +
             | 
| 88 94 | 
             
            - 发布版本需要加tag
         | 
| 89 95 |  | 
| 90 96 | 
             
            ## 开发哲学
         | 
| @@ -106,6 +112,23 @@ b = nlpertools.io.file.readtxt_list_all_strip('res.txt') | |
| 106 112 | 
             
            ```
         | 
| 107 113 |  | 
| 108 114 | 
             
            ```bash
         | 
| 109 | 
            -
            #  | 
| 110 | 
            -
            python -m nlpertools
         | 
| 115 | 
            +
            # 生成pypi双因素认证的实时密钥(需要提供key)
         | 
| 116 | 
            +
            python -m nlpertools.cli --get_2fa --get_2fa_key your_key
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            ## git
         | 
| 119 | 
            +
            python -m nlpertools.cli --git_push
         | 
| 120 | 
            +
            python -m nlpertools.cli --git_pull
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            # 以下功能被nvitop替代,不推荐使用
         | 
| 123 | 
            +
            ## 监控gpu显存
         | 
| 124 | 
            +
            python -m nlpertools.monitor.gpu
         | 
| 125 | 
            +
            ## 监控cpu
         | 
| 126 | 
            +
            python -m nlpertools.monitor.memory
         | 
| 111 127 | 
             
            ```
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            ## 一些常用项目
         | 
| 130 | 
            +
             | 
| 131 | 
            +
            nvitop
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            ydata-profiling
         | 
| 134 | 
            +
             | 
| @@ -1,12 +1,14 @@ | |
| 1 | 
            -
            nlpertools/__init__.py,sha256= | 
| 2 | 
            -
            nlpertools/ | 
| 3 | 
            -
            nlpertools/ | 
| 1 | 
            +
            nlpertools/__init__.py,sha256=5ka-NeGW2AUDJ4YZ12DD64xcxuxf9PlQUurxDp5DHbQ,483
         | 
| 2 | 
            +
            nlpertools/cli.py,sha256=4Ik1NyFaoZpZLsYLAFRLk6xuYQk0IvexPr1Ieq08viE,3932
         | 
| 3 | 
            +
            nlpertools/data_client.py,sha256=esX8lUQrTui4uVkqPfhpHVok7Eq6ywpuemKjLeqoglc,14674
         | 
| 4 | 
            +
            nlpertools/dataprocess.py,sha256=v1mobuYN7I3dT6xIKlNOHVtcg31YtjF6FwNPTxeBFFY,23153
         | 
| 4 5 | 
             
            nlpertools/default_db_config.yml,sha256=E1K9k_xzXVlsf-HJQh8kyHXHYuvTpD12jD4Hfe5rUk8,606
         | 
| 5 | 
            -
            nlpertools/ | 
| 6 | 
            +
            nlpertools/get_2fa.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 7 | 
            +
            nlpertools/ml.py,sha256=qhUBCLuHfcFy8g5ZHNGYq4eH2vYWiGetyKucv8n60-A,18523
         | 
| 6 8 | 
             
            nlpertools/movie.py,sha256=rkyOnAXdsbWfMSbi1sE1VNRT7f66Hp9BnZsN_58Afmw,897
         | 
| 7 9 | 
             
            nlpertools/nlpertools_config.yml,sha256=ksXejxFs7pxR47tNAsrN88_4gvq9PCA2ZMO07H-dJXY,26
         | 
| 8 10 | 
             
            nlpertools/open_api.py,sha256=uyTY00OUlM57Cn0Wm0yZXcIS8vAszy9rKnDMBEWfWJM,1744
         | 
| 9 | 
            -
            nlpertools/other.py,sha256= | 
| 11 | 
            +
            nlpertools/other.py,sha256=JWJiXHRI8mhiUV3k4CZ4kQQS9QN3mw67SmGgTqZFtjs,15026
         | 
| 10 12 | 
             
            nlpertools/pic.py,sha256=13aaFJh3USGYGs4Y9tAKTvWjmdQR4YDjl3LlIhJheOA,9906
         | 
| 11 13 | 
             
            nlpertools/plugin.py,sha256=LB7j9GdoQi6TITddH-6EglHlOa0WIHLUT7X5vb_aIZY,1168
         | 
| 12 14 | 
             
            nlpertools/reminder.py,sha256=wiXwZQmxMck5vY3EvG8_oakP3FAdjGTikAIOiTPUQrs,2977
         | 
| @@ -22,9 +24,12 @@ nlpertools/algo/template.py,sha256=9vsHr4g3jZZ5KVU_2I9i97o8asRXq-8pSaCXIv0sHeM,2 | |
| 22 24 | 
             
            nlpertools/algo/union.py,sha256=0l7lGZbw1qIfW1z5TE8Oo3tybL1bKIP5rzpa5ZT-vLQ,249
         | 
| 23 25 | 
             
            nlpertools/data_structure/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 24 26 | 
             
            nlpertools/data_structure/base_structure.py,sha256=gVUvJZ5jsCAswRETTpMwcEjLKoageWiTuCKNEwIWKWk,2641
         | 
| 27 | 
            +
            nlpertools/draw/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 28 | 
            +
            nlpertools/draw/draw.py,sha256=19dskkr0wrgczxPJnphEszliwYshEh5SjD8Zz07nlk0,2615
         | 
| 29 | 
            +
            nlpertools/draw/math_func.py,sha256=0NQ22Dfi9DFG6Bg_hXnCT27w65-dqpOOIgZX7oUIW-Q,881
         | 
| 25 30 | 
             
            nlpertools/io/__init__.py,sha256=YMuKtC2Ddh5dL5MvXjyUKYOOuqzFYUhBPFaP2kyFG9I,68
         | 
| 26 | 
            -
            nlpertools/io/dir.py,sha256= | 
| 27 | 
            -
            nlpertools/io/file.py,sha256= | 
| 31 | 
            +
            nlpertools/io/dir.py,sha256=FPY62COQN8Ji72pk0dYRoXkrORYaUlybKNcL4474uUI,2263
         | 
| 32 | 
            +
            nlpertools/io/file.py,sha256=mLWl09IEi0rWPN4tTq3LwdYMvAjj4e_QsjEMhufuPPo,7192
         | 
| 28 33 | 
             
            nlpertools/monitor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 29 34 | 
             
            nlpertools/monitor/gpu.py,sha256=M59O6i0hlew7AzXZlaVZqbZA5IR93OhBY2WI0-T_HtY,531
         | 
| 30 35 | 
             
            nlpertools/monitor/memory.py,sha256=9t6q9BC8VVx4o3G4sBCn7IoQRx272zMPjSnL3yvTBAQ,657
         | 
| @@ -36,8 +41,9 @@ nlpertools/utils/package.py,sha256=wLg_M8j7Y6ReRjWHWCWoZJHrzEwuAr9TyG2jvb7OQCo,3 | |
| 36 41 | 
             
            nlpertools/utils/package_v1.py,sha256=sqgFb-zbTdMd5ziJLY6YUPqR49qUNZjxBH35DnyR5Wg,3542
         | 
| 37 42 | 
             
            nlpertools/utils/package_v2.py,sha256=WOcsguWfUd4XSAfmPgCtL8HtUbqJ6GRSMHb0OsB47r0,3932
         | 
| 38 43 | 
             
            nlpertools_helper/__init__.py,sha256=obxRUdZDctvcvK_iA1Dx2HmQFMlMzJto-xDPryq1lJ0,198
         | 
| 39 | 
            -
            nlpertools-1.0. | 
| 40 | 
            -
            nlpertools-1.0. | 
| 41 | 
            -
            nlpertools-1.0. | 
| 42 | 
            -
            nlpertools-1.0. | 
| 43 | 
            -
            nlpertools-1.0. | 
| 44 | 
            +
            nlpertools-1.0.9.dist-info/LICENSE,sha256=SBcMozykvTbZJ--MqSiKUmHLLROdnr25V70xCQgEwqw,11331
         | 
| 45 | 
            +
            nlpertools-1.0.9.dist-info/METADATA,sha256=lcKmxc7_mtYH47mPj8UHOM8-5T5YtrDwhHWVZkfHZXU,3330
         | 
| 46 | 
            +
            nlpertools-1.0.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
         | 
| 47 | 
            +
            nlpertools-1.0.9.dist-info/entry_points.txt,sha256=XEazQ4vUwJMoMAgAwk1Lq4PRQGklPkPBaFkiP0zN_JE,45
         | 
| 48 | 
            +
            nlpertools-1.0.9.dist-info/top_level.txt,sha256=_4q4MIFvMr4cAUbhWKWYdRXIXsF4PJDg4BUsZvgk94s,29
         | 
| 49 | 
            +
            nlpertools-1.0.9.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |