ygo 1.0.2__py3-none-any.whl → 1.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ygo/__init__.py CHANGED
@@ -1,10 +1,31 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- ---------------------------------------------
4
- Created on 2025/4/28 15:25
5
- @author: ZhangYundi
6
- @email: yundi.xxii@outlook.com
7
- ---------------------------------------------
8
- """
1
from .exceptions import FailTaskError, WarnException
from .pool import pool
from .delay import delay
from .utils import (
    fn_params,
    fn_signature_params,
    fn_path,
    fn_code,
    fn_info,
    module_from_str,
    fn_from_str,
    locate,
)
from .lazy import lazy_import

__version__ = "1.2.12"

# Public API. NOTE: `locate` is imported for re-export above; it was missing
# from __all__, so `from ygo import *` silently dropped it — added here.
__all__ = [
    "FailTaskError",
    "delay",
    "WarnException",
    "fn_params",
    "fn_signature_params",
    "fn_path",
    "fn_code",
    "fn_info",
    "fn_from_str",
    "module_from_str",
    "locate",
    "pool",
    "lazy_import",
]
ygo/delay.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2025/5/26 20:12
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+
10
+ import functools
11
+ import inspect
12
+
13
+
14
class DelayedFunction:
    """Bind keyword arguments to a callable now; run it later.

    Calling an instance does not execute ``func``. Instead it records the
    keyword arguments and returns a new callable carrying a
    ``stored_kwargs`` attribute, which lets a later ``DelayedFunction``
    inherit the bindings (step-by-step argument passing).
    """

    def __init__(self, func):
        self.func = func
        signature = inspect.signature(func)
        # Names func actually accepts; unknown keywords are silently ignored.
        self._fn_params_k = signature.parameters.keys()
        # Seed the binding map with declared defaults...
        self.stored_kwargs = {
            name: param.default
            for name, param in signature.parameters.items()
            if param.default is not inspect.Parameter.empty
        }
        # ...then inherit bindings from a previously delayed callable.
        if hasattr(func, 'stored_kwargs'):
            self.stored_kwargs.update(func.stored_kwargs)

    def __call__(self, *args, **kwargs):
        # Record the new keyword bindings first.
        self._remember(**kwargs)

        @functools.wraps(self.func)
        def delayed(*call_args, **late_kwargs):
            # Merge stored bindings with call-time overrides, keeping only
            # keywords the wrapped callable accepts.
            merged = dict(self.stored_kwargs)
            merged.update(
                (name, value)
                for name, value in late_kwargs.items()
                if name in self._fn_params_k
            )
            return self.func(*call_args, **merged)

        # Expose bindings so delay(delayed_fn) can chain.
        delayed.stored_kwargs = self.stored_kwargs
        return delayed

    def _remember(self, **kwargs):
        # Keep only keywords that func accepts.
        for name, value in kwargs.items():
            if name in self._fn_params_k:
                self.stored_kwargs[name] = value
50
+
51
+
52
def delay(func):
    """
    Defer execution of *func*.

    Parameters
    ----------
    func: Callable
        The object whose execution should be deferred; must be callable.

    Returns
    -------
    DelayedFunction
        Wraps the pre-set arguments into the original callable.

    Examples
    --------

    Case 1: basic usage

    >>> fn = delay(lambda a, b: a+b)(a=1, b=2)
    >>> fn()
    3

    Case 2: passing arguments step by step

    >>> fn1 = delay(lambda a, b, c: a+b+c)(a=1)
    >>> fn2 = delay(fn1)(b=2)
    >>> fn2(c=3)
    6

    Case 3: overriding arguments

    >>> fn1 = delay(lambda a, b, c: a+b+c)(a=1, b=2)
    >>> fn2 = delay(fn1)(c=3, b=5)
    >>> fn2()
    9
    """
    return DelayedFunction(func)
89
+
ygo/exceptions.py CHANGED
@@ -7,7 +7,23 @@ Created on 2024/12/18 下午7:01
7
7
  ---------------------------------------------
8
8
  """
9
9
 
10
+ from dataclasses import dataclass
11
+
10
12
class WarnException(Exception):
    """Custom exception class used only to signal warnings."""

    def __init__(self, message):
        # Delegate straight to the base Exception constructor.
        super().__init__(message)
16
+
17
@dataclass
class FailTaskError:
    """Record of a failed task: its name and the exception it raised.

    Not an Exception subclass — used purely as a structured log message.
    """

    task_name: str
    error: Exception

    def __str__(self):
        message = f"""
        [失败任务]: {self.task_name}
        [错误信息]: \n{self.error}
        """
        return message

    def __repr__(self):
        return self.__str__()
ygo/lazy.py ADDED
@@ -0,0 +1,50 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Copyright (c) 2025 ZhangYundi
5
+ Licensed under the MIT License.
6
+ Created on 2025/6/29 10:36
7
+ Email: yundi.xxii@outlook.com
8
+ Description:
9
+ ---------------------------------------------
10
+ """
11
+
12
+ from typing import Any
13
+ from .utils import locate
14
+
15
class LazyImport:
    """Proxy that defers importing ``full_name`` until first use.

    Attribute access, ``dir()`` and calls all trigger resolution through
    ``locate`` and then forward to the resolved object.
    """

    def __init__(self, full_name: str):
        # Dotted path of the target, e.g. "package.module.attr".
        self._full_name = full_name
        # Resolved target; populated lazily by _load().
        self._obj = None

    def _load(self) -> Any:
        # NOTE(review): a target that legitimately resolves to None is
        # re-located on every access; harmless, since None is not a useful
        # lazy-import target.
        if self._obj is None:
            self._obj = locate(self._full_name)
        return self._obj

    def __getattr__(self, attr: str) -> Any:
        # Only reached for attributes not defined on the proxy itself.
        obj = self._load()
        return getattr(obj, attr)

    def __dir__(self) -> list[str]:
        return dir(self._load())

    def __call__(self, *args, **kwargs) -> Any:
        obj = self._load()
        # Classes are callable, so a single callable() check covers both
        # instantiation and plain function calls (previously two redundant
        # branches).
        if not callable(obj):
            raise TypeError(f"The target `{self._full_name}` is not callable or instantiable.")
        return obj(*args, **kwargs)

    def __str__(self):
        return self._full_name

    def __repr__(self):
        return self._full_name
47
+
48
def lazy_import(full_name: str):
    """Lazily import a module or attribute; resolution happens on first use."""
    return LazyImport(full_name)
ygo/pool.py ADDED
@@ -0,0 +1,286 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ ---------------------------------------------
4
+ Created on 2024/11/4 下午2:10
5
+ @author: ZhangYundi
6
+ @email: yundi.xxii@outlook.com
7
+ ---------------------------------------------
8
+ """
9
+ import functools
10
+ import multiprocessing
11
+ import os
12
+ import sys
13
+ import threading
14
+ import warnings
15
+ from typing import Literal
16
+
17
+ import logair
18
+ from joblib import Parallel, delayed
19
+
20
+ from .delay import delay
21
+ from .exceptions import WarnException, FailTaskError
22
+
23
+ with warnings.catch_warnings():
24
+ warnings.simplefilter("ignore")
25
+ from tqdm.auto import tqdm
26
+
27
+ logger = logair.get_logger("ygo")
28
+
29
+
30
def run_job(job, task_id, queue):
    """Run one job, report completion on *queue*, and return its result.

    Failures are logged rather than raised so a single bad job cannot abort
    the whole batch: a ``WarnException`` is logged as a warning, any other
    exception as an error, and the job's result becomes ``None``. In every
    case ``(task_id, 1)`` is pushed onto *queue* for the progress tracker.
    """
    result = None
    try:
        result = job()
    except WarnException as warn:
        logger.warning(FailTaskError(task_name=job.task_name, error=warn))
    except Exception as exc:
        logger.error(FailTaskError(task_name=job.task_name, error=exc), exc_info=exc)
    queue.put((task_id, 1))
    return result
42
+
43
+
44
def update_progress_bars(tqdm_objects: "list[tqdm]",
                         task_ids,
                         queue: multiprocessing.Queue,
                         num_tasks: int,
                         task_counts: dict):
    """Drain *queue* and advance per-task tqdm bars until every task is done.

    Parameters
    ----------
    tqdm_objects : list[tqdm]
        One progress bar per task, indexed by task id. (Annotation is a
        string so defining this function does not require tqdm at def time.)
    task_ids : iterable
        Ids of all tasks being tracked.
    queue : multiprocessing.Queue
        Source of ``(task_id, progress_value)`` messages pushed by run_job.
    num_tasks : int
        Number of distinct tasks; the loop exits once all are complete.
    task_counts : dict
        Expected number of jobs per task id.
    """
    completed_tasks = 0
    completed_task_jobs = {id_: 0 for id_ in task_ids}
    while completed_tasks < num_tasks:
        task_id, progress_value = queue.get()  # blocks until a job reports in
        completed_task_jobs[task_id] += 1
        if completed_task_jobs[task_id] >= task_counts[task_id]:
            completed_tasks += 1
        tqdm_objects[task_id].update(progress_value)
    # Close bars with a plain loop (was a side-effect list comprehension).
    for tqdm_object in tqdm_objects:
        tqdm_object.close()
59
+
60
+
61
class pool:
    """
    Parallel job runner built on joblib, with one tqdm progress bar per task.

    Each run of a fn counts as one job, and each job is given a job_name
    (a default task name is assigned when none is provided). Jobs sharing
    the same job_name are grouped into one task, which is itself named
    after that job_name — i.e. one task contains several job fns to run:
        task1 <job_fn1, job_fn2, ...>
        task2 <job_fn3, job_fn4, ...>
    All job fns are executed in parallel through joblib.
    """

    def __init__(self,
                 n_jobs: int = -1,
                 show_progress: bool = True,
                 backend: Literal['loky', 'threading', 'multiprocessing'] = 'loky',
                 ):
        """backend: loky/threading/multiprocessing"""
        self.show_progress = show_progress
        if n_jobs < 0:
            n_jobs = multiprocessing.cpu_count() - 1  # leave one CPU core for the system
        self._n_jobs = n_jobs

        default_kwargs = {
            # 'verbose': 10,
            'batch_size': 'auto',
            'pre_dispatch': f"2*{n_jobs}",
            'max_nbytes': '1M',  # keep inter-process communication small
            'timeout': None,
            # 'prefer': 'processes',
        }
        # Run sequentially (self._parallel is None) when only one worker is requested.
        self._parallel = Parallel(n_jobs=self._n_jobs, verbose=0, backend=backend,
                                  **default_kwargs) if self._n_jobs > 1 else None
        self._default_task_name = "GO-JOB"
        self._all_jobs = list()  # list[job]
        self._all_tasks = list()  # list[task_name]
        self._task_ids = dict()  # {task_name1: 0, task_name2: 1, ...}
        self._task_counts = dict()  # jobs submitted per task_name
        self._leave_mapping = dict()  # task_name -> keep its bar on screen after completion
        self._id_counts = dict()  # jobs submitted per task_id
        self._manager = None  # multiprocessing.Manager; created lazily in do()

        # 1. Configure the environment
        self._configure_environment()

        # 2. Grab CPU cores
        if self._n_jobs > 1:
            self._steal_cpu_cores()

        # 3. Pre-allocate memory (if enabled)
        # if self.memory_aware:
        #     self._preallocate_memory_pool()

    def _configure_environment(self):
        """Configure the environment to maximize resource usage."""
        # Disable all competing sources of intra-process parallelism so the
        # joblib workers do not fight their own BLAS/OpenMP thread pools.
        env_vars = {
            'OMP_NUM_THREADS': '1',
            'MKL_NUM_THREADS': '1',
            'OPENBLAS_NUM_THREADS': '1',
            'VECLIB_MAXIMUM_THREADS': '1',
            'NUMEXPR_NUM_THREADS': '1',
            'NUMBA_NUM_THREADS': '1',
            'PYTHON_GIL': '0',  # try to disable the GIL (if supported)
        }

        for key, value in env_vars.items():
            os.environ[key] = value

        # Python-level tuning
        sys.setrecursionlimit(1000000)

        # joblib parameters
        os.environ['LOKY_MAX_CPU_COUNT'] = str(self._n_jobs)
        os.environ['JOBLIB_START_METHOD'] = 'forkserver'  # faster process start-up

    def _steal_cpu_cores(self):
        """Reserve CPU cores for the worker processes (best-effort)."""
        try:
            import psutil

            # Sample current per-core CPU load.
            cpu_percent = psutil.cpu_percent(interval=0.2, percpu=True)

            # Assign one core per worker.
            worker_cores = []
            for i in range(self._n_jobs):
                # Look for the least-loaded core not yet taken.
                available_cores = [j for j in range(len(cpu_percent))
                                   if j not in worker_cores]

                if available_cores:
                    # Pick the available core with the lowest load.
                    available_loads = [(j, cpu_percent[j]) for j in available_cores]
                    available_loads.sort(key=lambda x: x[1])
                    chosen_core = available_loads[0][0]
                    worker_cores.append(chosen_core)

            # NOTE(review): this mapping is recorded but never applied to the
            # workers (no affinity call) — verify intent upstream.
            self.worker_affinity = {
                i: [core] for i, core in enumerate(worker_cores)
            }

            # print(f"🎯 为 {self.n_jobs} 个worker分配核心: {worker_cores}")

        except Exception as e:
            print(f"⚠️ CPU核心分配失败: {e}")

    def submit(self, fn, job_name: str = "", postfix: str = "", leave: bool = True):
        """
        Submit a job for parallel execution.

        Parameters
        ----------
        fn: callable
            The callable to run in parallel.
        job_name: str
            Name of the task this job is submitted to; each distinct task
            gets its own progress bar.
        postfix: str
            Progress-bar postfix.
        leave: bool
            Whether the progress bar stays on screen after completion,
            default True.
        Returns
        -------
        callable
            A collector: call it with the job's keyword arguments to enqueue
            the job; the actual execution happens in do().

        Examples
        --------
        import time
        import ygo
        >>> def task_fn1(a, b):
                time.sleep(3)
                return a+b
        >>> def task_fn2():
                time.sleep(5)
                return 0
        >>> with ygo.pool() as go:
                go.submit(task_fn1, job_name="task1")(a=1, b=2)
                go.submit(task_fn2, job_name="task2")()

                go.do()
        """

        # Classify the submitted job under its task id and wrap it so that,
        # after running, its progress update is pushed onto the queue.
        @functools.wraps(fn)
        def collect(**kwargs):
            """Gather every job into its corresponding task."""
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                job = delay(fn)(**kwargs)
            task_name = self._default_task_name if not job_name else job_name
            task_name = f"{task_name}::{postfix}"
            if task_name not in self._task_ids:
                # First job for this task: register the task bookkeeping.
                self._task_ids[task_name] = len(self._all_tasks)
                self._task_counts[task_name] = 0
                self._all_tasks.append(task_name)
                self._id_counts[self._task_ids[task_name]] = 0
            self._leave_mapping[task_name] = leave
            self._task_counts[task_name] += 1
            self._id_counts[self._task_ids[task_name]] += 1
            job.task_id = self._task_ids[task_name]
            job.job_id = self._task_counts[task_name]
            job.task_name = task_name
            self._all_jobs.append(job)
            return job

        return collect

    def do(self):
        """Run every submitted job, reset internal state, and return the results
        in submission order (None for failed jobs — see run_job)."""
        if self.show_progress:
            # Message queue for worker -> progress-thread communication.
            self._manager = multiprocessing.Manager()
            queue = self._manager.Queue()
            tqdm_bars = [tqdm(total=self._task_counts[task_name],
                              desc=f"{str(task_name.split('::')[0])}",
                              postfix=task_name.split("::")[1],
                              leave=self._leave_mapping.get(task_name, True)) for task_name in
                         self._all_tasks]
            # One progress bar per task, initialized above.
            task_ids = [task_id for task_id in range(len(self._all_tasks))]
            # Create and start the thread that updates the progress bars.
            progress_thread = threading.Thread(target=update_progress_bars, args=(
                tqdm_bars, task_ids, queue, len(self._all_tasks), self._id_counts))
            progress_thread.start()
            if self._parallel is not None:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    # Execute the jobs in parallel.
                    result = self._parallel(delayed(run_job)(job=job,
                                                             task_id=job.task_id,
                                                             queue=queue) for job in self._all_jobs)
            else:
                result = [run_job(job=job, task_id=job.task_id, queue=queue) for job in self._all_jobs]
            # Wait for the progress-update thread to finish.
            progress_thread.join()
        else:
            if self._parallel is not None:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    result = self._parallel(delayed(job)() for job in self._all_jobs)
            else:
                result = [job() for job in self._all_jobs]
        self._all_jobs = list()  # list[job]
        self._all_tasks = list()  # list[task_name]
        self._task_ids = dict()  # {task_name1: 0, task_name2: 1, ...}
        self._task_counts = dict()
        self._id_counts = dict()
        self._leave_mapping = dict()
        self.close()
        return result

    def close(self):
        """Release all resources (joblib pool and multiprocessing manager)."""
        if hasattr(self, '_parallel') and self._parallel is not None:
            try:
                self._parallel.__exit__(None, None, None)
            except Exception as e:
                logger.warning(f"Failed to close Parallel: {e}")

        if hasattr(self, '_manager') and self._manager is not None:
            try:
                self._manager.shutdown()
            except Exception as e:
                logger.warning(f"Failed to shutdown Manager: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release worker processes and resources.
        self.close()