auto-coder 0.1.250__py3-none-any.whl → 0.1.252__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/METADATA +2 -2
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/RECORD +31 -29
- autocoder/auto_coder.py +36 -4
- autocoder/auto_coder_rag.py +198 -35
- autocoder/chat_auto_coder.py +58 -5
- autocoder/chat_auto_coder_lang.py +21 -3
- autocoder/common/__init__.py +2 -1
- autocoder/common/auto_coder_lang.py +11 -5
- autocoder/common/code_auto_generate.py +10 -0
- autocoder/common/code_auto_generate_diff.py +10 -0
- autocoder/common/code_auto_generate_editblock.py +22 -8
- autocoder/common/code_auto_generate_strict_diff.py +10 -0
- autocoder/common/code_modification_ranker.py +3 -3
- autocoder/common/global_cancel.py +21 -0
- autocoder/common/printer.py +4 -1
- autocoder/dispacher/actions/action.py +29 -8
- autocoder/dispacher/actions/plugins/action_regex_project.py +17 -5
- autocoder/index/filter/quick_filter.py +4 -6
- autocoder/index/index.py +17 -6
- autocoder/models.py +87 -6
- autocoder/rag/doc_filter.py +1 -3
- autocoder/rag/long_context_rag.py +7 -5
- autocoder/rag/token_limiter.py +1 -3
- autocoder/utils/auto_coder_utils/chat_stream_out.py +13 -2
- autocoder/utils/llms.py +15 -1
- autocoder/utils/thread_utils.py +201 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.250.dist-info → auto_coder-0.1.252.dist-info}/top_level.txt +0 -0
autocoder/index/index.py
CHANGED
|
@@ -22,7 +22,8 @@ from autocoder.index.types import (
|
|
|
22
22
|
TargetFile,
|
|
23
23
|
FileList,
|
|
24
24
|
)
|
|
25
|
-
|
|
25
|
+
from autocoder.common.global_cancel import global_cancel
|
|
26
|
+
from autocoder.utils.llms import get_llm_names
|
|
26
27
|
class IndexManager:
|
|
27
28
|
def __init__(
|
|
28
29
|
self, llm: byzerllm.ByzerLLM, sources: List[SourceCode], args: AutoCoderArgs
|
|
@@ -195,7 +196,10 @@ class IndexManager:
|
|
|
195
196
|
return True
|
|
196
197
|
return False
|
|
197
198
|
|
|
198
|
-
def build_index_for_single_source(self, source: SourceCode):
|
|
199
|
+
def build_index_for_single_source(self, source: SourceCode):
|
|
200
|
+
if global_cancel.requested:
|
|
201
|
+
return None
|
|
202
|
+
|
|
199
203
|
file_path = source.module_name
|
|
200
204
|
if not os.path.exists(file_path):
|
|
201
205
|
return None
|
|
@@ -205,9 +209,7 @@ class IndexManager:
|
|
|
205
209
|
|
|
206
210
|
md5 = hashlib.md5(source.source_code.encode("utf-8")).hexdigest()
|
|
207
211
|
|
|
208
|
-
model_name =
|
|
209
|
-
if not model_name:
|
|
210
|
-
model_name = "unknown(without default model name)"
|
|
212
|
+
model_name = ",".join(get_llm_names(self.index_llm))
|
|
211
213
|
|
|
212
214
|
try:
|
|
213
215
|
start_time = time.monotonic()
|
|
@@ -314,6 +316,9 @@ class IndexManager:
|
|
|
314
316
|
):
|
|
315
317
|
wait_to_build_files.append(source)
|
|
316
318
|
|
|
319
|
+
# Remove duplicates based on module_name
|
|
320
|
+
wait_to_build_files = list({source.module_name: source for source in wait_to_build_files}.values())
|
|
321
|
+
|
|
317
322
|
counter = 0
|
|
318
323
|
num_files = len(wait_to_build_files)
|
|
319
324
|
total_files = len(self.sources)
|
|
@@ -329,6 +334,8 @@ class IndexManager:
|
|
|
329
334
|
for source in wait_to_build_files
|
|
330
335
|
]
|
|
331
336
|
for future in as_completed(futures):
|
|
337
|
+
if global_cancel.requested:
|
|
338
|
+
break
|
|
332
339
|
result = future.result()
|
|
333
340
|
if result is not None:
|
|
334
341
|
counter += 1
|
|
@@ -341,7 +348,11 @@ class IndexManager:
|
|
|
341
348
|
module_name = result["module_name"]
|
|
342
349
|
index_data[module_name] = result
|
|
343
350
|
updated_sources.append(module_name)
|
|
344
|
-
|
|
351
|
+
if len(updated_sources) > 5:
|
|
352
|
+
with open(self.index_file, "w") as file:
|
|
353
|
+
json.dump(index_data, file, ensure_ascii=False, indent=2)
|
|
354
|
+
updated_sources = []
|
|
355
|
+
|
|
345
356
|
# 如果 updated_sources 或 keys_to_remove 有值,则保存索引文件
|
|
346
357
|
if updated_sources or keys_to_remove:
|
|
347
358
|
with open(self.index_file, "w") as file:
|
autocoder/models.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
|
2
2
|
import json
|
|
3
3
|
from typing import List, Dict
|
|
4
4
|
from urllib.parse import urlparse
|
|
5
|
-
from autocoder.common.auto_coder_lang import get_message_with_format
|
|
6
5
|
|
|
7
6
|
MODELS_JSON = os.path.expanduser("~/.auto-coder/keys/models.json")
|
|
8
7
|
|
|
@@ -15,7 +14,10 @@ default_models_list = [
|
|
|
15
14
|
"model_type": "saas/openai",
|
|
16
15
|
"base_url": "https://api.deepseek.com/v1",
|
|
17
16
|
"api_key_path": "api.deepseek.com",
|
|
18
|
-
"is_reasoning": True
|
|
17
|
+
"is_reasoning": True,
|
|
18
|
+
"input_price": 0.0, # 单位:M/百万 input tokens
|
|
19
|
+
"output_price": 0.0, # 单位:M/百万 output tokens
|
|
20
|
+
"average_speed": 0.0 # 单位:秒/请求
|
|
19
21
|
},
|
|
20
22
|
{
|
|
21
23
|
"name": "deepseek_chat",
|
|
@@ -24,7 +26,10 @@ default_models_list = [
|
|
|
24
26
|
"model_type": "saas/openai",
|
|
25
27
|
"base_url": "https://api.deepseek.com/v1",
|
|
26
28
|
"api_key_path": "api.deepseek.com",
|
|
27
|
-
"is_reasoning": False
|
|
29
|
+
"is_reasoning": False,
|
|
30
|
+
"input_price": 0.0,
|
|
31
|
+
"output_price": 0.0,
|
|
32
|
+
"average_speed": 0.0
|
|
28
33
|
},
|
|
29
34
|
{
|
|
30
35
|
"name":"o1",
|
|
@@ -33,7 +38,10 @@ default_models_list = [
|
|
|
33
38
|
"model_type": "saas/openai",
|
|
34
39
|
"base_url": "https://api.openai.com/v1",
|
|
35
40
|
"api_key_path": "",
|
|
36
|
-
"is_reasoning": True
|
|
41
|
+
"is_reasoning": True,
|
|
42
|
+
"input_price": 0.0,
|
|
43
|
+
"output_price": 0.0,
|
|
44
|
+
"average_speed": 0.0
|
|
37
45
|
}
|
|
38
46
|
]
|
|
39
47
|
|
|
@@ -106,6 +114,7 @@ def get_model_by_name(name: str) -> Dict:
|
|
|
106
114
|
"""
|
|
107
115
|
根据模型名称查找模型
|
|
108
116
|
"""
|
|
117
|
+
from autocoder.common.auto_coder_lang import get_message_with_format
|
|
109
118
|
models = load_models()
|
|
110
119
|
v = [m for m in models if m["name"] == name.strip()]
|
|
111
120
|
|
|
@@ -114,6 +123,78 @@ def get_model_by_name(name: str) -> Dict:
|
|
|
114
123
|
return v[0]
|
|
115
124
|
|
|
116
125
|
|
|
126
|
+
def update_model_input_price(name: str, price: float) -> bool:
|
|
127
|
+
"""更新模型输入价格
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
name: 模型名称
|
|
131
|
+
price: 输入价格(M/百万input tokens)
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
bool: 是否更新成功
|
|
135
|
+
"""
|
|
136
|
+
if price < 0:
|
|
137
|
+
raise ValueError("Price cannot be negative")
|
|
138
|
+
|
|
139
|
+
models = load_models()
|
|
140
|
+
updated = False
|
|
141
|
+
for model in models:
|
|
142
|
+
if model["name"] == name:
|
|
143
|
+
model["input_price"] = float(price)
|
|
144
|
+
updated = True
|
|
145
|
+
break
|
|
146
|
+
if updated:
|
|
147
|
+
save_models(models)
|
|
148
|
+
return updated
|
|
149
|
+
|
|
150
|
+
def update_model_output_price(name: str, price: float) -> bool:
|
|
151
|
+
"""更新模型输出价格
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
name: 模型名称
|
|
155
|
+
price: 输出价格(M/百万output tokens)
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
bool: 是否更新成功
|
|
159
|
+
"""
|
|
160
|
+
if price < 0:
|
|
161
|
+
raise ValueError("Price cannot be negative")
|
|
162
|
+
|
|
163
|
+
models = load_models()
|
|
164
|
+
updated = False
|
|
165
|
+
for model in models:
|
|
166
|
+
if model["name"] == name:
|
|
167
|
+
model["output_price"] = float(price)
|
|
168
|
+
updated = True
|
|
169
|
+
break
|
|
170
|
+
if updated:
|
|
171
|
+
save_models(models)
|
|
172
|
+
return updated
|
|
173
|
+
|
|
174
|
+
def update_model_speed(name: str, speed: float) -> bool:
|
|
175
|
+
"""更新模型平均速度
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
name: 模型名称
|
|
179
|
+
speed: 速度(秒/请求)
|
|
180
|
+
|
|
181
|
+
Returns:
|
|
182
|
+
bool: 是否更新成功
|
|
183
|
+
"""
|
|
184
|
+
if speed <= 0:
|
|
185
|
+
raise ValueError("Speed must be positive")
|
|
186
|
+
|
|
187
|
+
models = load_models()
|
|
188
|
+
updated = False
|
|
189
|
+
for model in models:
|
|
190
|
+
if model["name"] == name:
|
|
191
|
+
model["average_speed"] = float(speed)
|
|
192
|
+
updated = True
|
|
193
|
+
break
|
|
194
|
+
if updated:
|
|
195
|
+
save_models(models)
|
|
196
|
+
return updated
|
|
197
|
+
|
|
117
198
|
def check_model_exists(name: str) -> bool:
|
|
118
199
|
"""
|
|
119
200
|
检查模型是否存在
|
|
@@ -124,14 +205,14 @@ def check_model_exists(name: str) -> bool:
|
|
|
124
205
|
def update_model_with_api_key(name: str, api_key: str) -> Dict:
|
|
125
206
|
"""
|
|
126
207
|
根据模型名称查找并更新模型的 api_key_path。
|
|
127
|
-
|
|
208
|
+
如果找到模型,会根据其 base_url 处理 api_key_path。
|
|
128
209
|
|
|
129
210
|
Args:
|
|
130
211
|
name: 模型名称
|
|
131
212
|
api_key: API密钥
|
|
132
213
|
|
|
133
214
|
Returns:
|
|
134
|
-
Dict:
|
|
215
|
+
Dict: 更新后的模型信息,如果未找到则返回None
|
|
135
216
|
"""
|
|
136
217
|
models = load_models()
|
|
137
218
|
|
autocoder/rag/doc_filter.py
CHANGED
|
@@ -91,9 +91,7 @@ class DocFilter:
|
|
|
91
91
|
def _run(conversations, docs):
|
|
92
92
|
submit_time_1 = time.time()
|
|
93
93
|
try:
|
|
94
|
-
llm =
|
|
95
|
-
llm.skip_nontext_check = True
|
|
96
|
-
llm.setup_default_model_name(self.recall_llm.default_model_name)
|
|
94
|
+
llm = self.recall_llm
|
|
97
95
|
|
|
98
96
|
v = (
|
|
99
97
|
_check_relevance_with_conversation.with_llm(
|
|
autocoder/rag/long_context_rag.py
CHANGED
|
@@ -52,11 +52,13 @@ class LongContextRAG:
|
|
|
52
52
|
) -> None:
|
|
53
53
|
self.llm = llm
|
|
54
54
|
self.args = args
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
55
|
+
if args.product_mode == "pro":
|
|
56
|
+
self.index_model = byzerllm.ByzerLLM()
|
|
57
|
+
self.index_model.setup_default_model_name(
|
|
58
|
+
args.index_model or self.llm.default_model_name
|
|
59
|
+
)
|
|
60
|
+
else:
|
|
61
|
+
self.index_model = self.llm
|
|
60
62
|
|
|
61
63
|
self.path = path
|
|
62
64
|
self.relevant_score = self.args.rag_doc_filter_relevance or 5
|
autocoder/rag/token_limiter.py
CHANGED
|
@@ -224,9 +224,7 @@ class TokenLimiter:
|
|
|
224
224
|
for idx, line in enumerate(source_code_lines):
|
|
225
225
|
source_code_with_line_number += f"{idx+1} {line}\n"
|
|
226
226
|
|
|
227
|
-
llm =
|
|
228
|
-
llm.skip_nontext_check = True
|
|
229
|
-
llm.setup_default_model_name(self.chunk_llm.default_model_name)
|
|
227
|
+
llm = self.chunk_llm
|
|
230
228
|
|
|
231
229
|
extracted_info = (
|
|
232
230
|
self.extract_relevance_range_from_docs_with_conversation.options(
|
|
autocoder/utils/auto_coder_utils/chat_stream_out.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from rich.console import Console
|
|
2
|
+
from autocoder.common.printer import Printer
|
|
2
3
|
from rich.live import Live
|
|
3
4
|
from rich.panel import Panel
|
|
4
5
|
from rich.markdown import Markdown
|
|
@@ -11,6 +12,7 @@ from autocoder.utils.request_queue import request_queue
|
|
|
11
12
|
import time
|
|
12
13
|
from byzerllm.utils.types import SingleOutputMeta
|
|
13
14
|
from autocoder.common import AutoCoderArgs
|
|
15
|
+
from autocoder.common.global_cancel import global_cancel
|
|
14
16
|
|
|
15
17
|
MAX_HISTORY_LINES = 40 # 最大保留历史行数
|
|
16
18
|
|
|
@@ -172,7 +174,9 @@ def stream_out(
|
|
|
172
174
|
current_line = "" # 当前行
|
|
173
175
|
assistant_response = ""
|
|
174
176
|
last_meta = None
|
|
175
|
-
panel_title = title if title is not None else f"Response[ {model_name} ]"
|
|
177
|
+
panel_title = title if title is not None else f"Response[ {model_name} ]"
|
|
178
|
+
first_token_time = 0.0
|
|
179
|
+
first_token_time_start = time.time()
|
|
176
180
|
try:
|
|
177
181
|
with Live(
|
|
178
182
|
Panel("", title=panel_title, border_style="green"),
|
|
@@ -180,6 +184,10 @@ def stream_out(
|
|
|
180
184
|
console=console
|
|
181
185
|
) as live:
|
|
182
186
|
for res in stream_generator:
|
|
187
|
+
if global_cancel.requested:
|
|
188
|
+
printer = Printer(console)
|
|
189
|
+
printer.print_in_terminal("generation_cancelled")
|
|
190
|
+
break
|
|
183
191
|
last_meta = res[1]
|
|
184
192
|
content = res[0]
|
|
185
193
|
reasoning_content = last_meta.reasoning_content
|
|
@@ -187,6 +195,9 @@ def stream_out(
|
|
|
187
195
|
if reasoning_content == "" and content == "":
|
|
188
196
|
continue
|
|
189
197
|
|
|
198
|
+
if first_token_time == 0.0:
|
|
199
|
+
first_token_time = time.time() - first_token_time_start
|
|
200
|
+
|
|
190
201
|
if keep_reasoning_content:
|
|
191
202
|
# 处理思考内容
|
|
192
203
|
if reasoning_content:
|
|
@@ -280,5 +291,5 @@ def stream_out(
|
|
|
280
291
|
status=RequestOption.COMPLETED
|
|
281
292
|
),
|
|
282
293
|
)
|
|
283
|
-
|
|
294
|
+
last_meta.first_token_time = first_token_time
|
|
284
295
|
return assistant_response, last_meta
|
autocoder/utils/llms.py
CHANGED
|
@@ -1,7 +1,21 @@
|
|
|
1
1
|
import byzerllm
|
|
2
|
-
from
|
|
2
|
+
from typing import Union,Optional
|
|
3
|
+
|
|
4
|
+
def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],target_model_type:Optional[str]=None):
|
|
5
|
+
if target_model_type is None:
|
|
6
|
+
return [llm.default_model_name for llm in [llm] if llm.default_model_name]
|
|
7
|
+
llms = llm.get_sub_client(target_model_type)
|
|
8
|
+
if llms is None:
|
|
9
|
+
return [llm.default_model_name for llm in [llm] if llm.default_model_name]
|
|
10
|
+
elif isinstance(llms, list):
|
|
11
|
+
return [llm.default_model_name for llm in llms if llm.default_model_name]
|
|
12
|
+
elif isinstance(llms,str) and llms:
|
|
13
|
+
return llms.split(",")
|
|
14
|
+
else:
|
|
15
|
+
return [llm.default_model_name for llm in [llms] if llm.default_model_name]
|
|
3
16
|
|
|
4
17
|
def get_single_llm(model_names: str, product_mode: str):
|
|
18
|
+
from autocoder import models as models_module
|
|
5
19
|
if product_mode == "pro":
|
|
6
20
|
if "," in model_names:
|
|
7
21
|
# Multiple code models specified
|
|
autocoder/utils/thread_utils.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
from concurrent.futures import ThreadPoolExecutor, TimeoutError, CancelledError
|
|
2
|
+
from threading import Event
|
|
3
|
+
from inspect import signature
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
import threading
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
from autocoder.common.global_cancel import global_cancel
|
|
10
|
+
|
|
11
|
+
class CancellationRequested(Exception):
|
|
12
|
+
"""Raised when a task is requested to be cancelled."""
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def run_in_thread(timeout: Optional[float] = None):
|
|
17
|
+
"""Decorator that runs a function in a thread with signal handling.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
timeout (float, optional): Maximum time to wait for thread completion in seconds.
|
|
21
|
+
If None, will wait indefinitely.
|
|
22
|
+
|
|
23
|
+
The decorated function will run in a separate thread and can be interrupted by
|
|
24
|
+
signals like Ctrl+C (KeyboardInterrupt). When interrupted, it will log the event
|
|
25
|
+
and clean up gracefully.
|
|
26
|
+
"""
|
|
27
|
+
def decorator(func):
|
|
28
|
+
@wraps(func)
|
|
29
|
+
def wrapper(*args, **kwargs):
|
|
30
|
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
31
|
+
future = executor.submit(func, *args, **kwargs)
|
|
32
|
+
start_time = time.time()
|
|
33
|
+
|
|
34
|
+
while True:
|
|
35
|
+
try:
|
|
36
|
+
# 使用较短的超时时间进行轮询,确保能够响应中断信号
|
|
37
|
+
poll_timeout = 0.1
|
|
38
|
+
if timeout is not None:
|
|
39
|
+
remaining = timeout - (time.time() - start_time)
|
|
40
|
+
if remaining <= 0:
|
|
41
|
+
future.cancel()
|
|
42
|
+
raise TimeoutError(f"Timeout after {timeout}s in {func.__name__}")
|
|
43
|
+
poll_timeout = min(poll_timeout, remaining)
|
|
44
|
+
|
|
45
|
+
try:
|
|
46
|
+
return future.result(timeout=poll_timeout)
|
|
47
|
+
except TimeoutError:
|
|
48
|
+
continue # 继续轮询
|
|
49
|
+
|
|
50
|
+
except KeyboardInterrupt:
|
|
51
|
+
logging.warning("KeyboardInterrupt received, attempting to cancel task...")
|
|
52
|
+
future.cancel()
|
|
53
|
+
raise
|
|
54
|
+
except Exception as e:
|
|
55
|
+
logging.error(f"Error occurred in thread: {str(e)}")
|
|
56
|
+
raise
|
|
57
|
+
return wrapper
|
|
58
|
+
return decorator
|
|
59
|
+
|
|
60
|
+
def run_in_thread_with_cancel(timeout: Optional[float] = None):
|
|
61
|
+
"""Decorator that runs a function in a thread with explicit cancellation support.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
timeout (float, optional): Maximum time to wait for thread completion in seconds.
|
|
65
|
+
If None, will wait indefinitely.
|
|
66
|
+
|
|
67
|
+
The decorated function MUST accept 'cancel_event' as its first parameter.
|
|
68
|
+
This cancel_event is a threading.Event object that can be used to check if
|
|
69
|
+
cancellation has been requested.
|
|
70
|
+
|
|
71
|
+
The decorated function can be called with an external cancel_event passed as a keyword argument.
|
|
72
|
+
If not provided, a new Event will be created.
|
|
73
|
+
|
|
74
|
+
Example:
|
|
75
|
+
@run_in_thread_with_cancel(timeout=10)
|
|
76
|
+
def long_task(cancel_event, arg1, arg2):
|
|
77
|
+
while not cancel_event.is_set():
|
|
78
|
+
# do work
|
|
79
|
+
if cancel_event.is_set():
|
|
80
|
+
raise CancellationRequested()
|
|
81
|
+
|
|
82
|
+
# 使用外部传入的cancel_event
|
|
83
|
+
external_cancel = Event()
|
|
84
|
+
try:
|
|
85
|
+
result = long_task(arg1, arg2, cancel_event=external_cancel)
|
|
86
|
+
except CancelledError:
|
|
87
|
+
print("Task was cancelled")
|
|
88
|
+
|
|
89
|
+
# 在其他地方取消任务
|
|
90
|
+
external_cancel.set()
|
|
91
|
+
"""
|
|
92
|
+
def decorator(func):
|
|
93
|
+
# 检查函数签名
|
|
94
|
+
sig = signature(func)
|
|
95
|
+
params = list(sig.parameters.keys())
|
|
96
|
+
if not params or params[0] != 'cancel_event':
|
|
97
|
+
raise ValueError(
|
|
98
|
+
f"Function {func.__name__} must have 'cancel_event' as its first parameter. "
|
|
99
|
+
f"Current parameters: {params}"
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
@wraps(func)
|
|
103
|
+
def wrapper(*args, **kwargs):
|
|
104
|
+
# 从kwargs中提取或创建cancel_event
|
|
105
|
+
cancel_event = kwargs.pop('cancel_event', None) or Event()
|
|
106
|
+
|
|
107
|
+
def cancellable_task():
|
|
108
|
+
try:
|
|
109
|
+
return func(cancel_event, *args, **kwargs)
|
|
110
|
+
except CancellationRequested:
|
|
111
|
+
logging.info(f"Task {func.__name__} was cancelled")
|
|
112
|
+
raise
|
|
113
|
+
except Exception as e:
|
|
114
|
+
logging.error(f"Error in {func.__name__}: {str(e)}")
|
|
115
|
+
raise
|
|
116
|
+
|
|
117
|
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
118
|
+
future = executor.submit(cancellable_task)
|
|
119
|
+
start_time = time.time()
|
|
120
|
+
|
|
121
|
+
while True:
|
|
122
|
+
try:
|
|
123
|
+
# 使用较短的超时时间进行轮询,确保能够响应中断信号
|
|
124
|
+
poll_timeout = 0.1
|
|
125
|
+
if timeout is not None:
|
|
126
|
+
remaining = timeout - (time.time() - start_time)
|
|
127
|
+
if remaining <= 0:
|
|
128
|
+
cancel_event.set()
|
|
129
|
+
future.cancel()
|
|
130
|
+
raise TimeoutError(f"Timeout after {timeout}s in {func.__name__}")
|
|
131
|
+
poll_timeout = min(poll_timeout, remaining)
|
|
132
|
+
|
|
133
|
+
try:
|
|
134
|
+
return future.result(timeout=poll_timeout)
|
|
135
|
+
except TimeoutError:
|
|
136
|
+
continue # 继续轮询
|
|
137
|
+
|
|
138
|
+
except KeyboardInterrupt:
|
|
139
|
+
logging.warning(f"KeyboardInterrupt received, cancelling {func.__name__}...")
|
|
140
|
+
cancel_event.set()
|
|
141
|
+
future.cancel()
|
|
142
|
+
raise CancelledError("Task cancelled by user")
|
|
143
|
+
except CancellationRequested:
|
|
144
|
+
logging.info(f"Task {func.__name__} was cancelled")
|
|
145
|
+
raise CancelledError("Task cancelled by request")
|
|
146
|
+
except Exception as e:
|
|
147
|
+
logging.error(f"Error occurred in thread: {str(e)}")
|
|
148
|
+
raise
|
|
149
|
+
|
|
150
|
+
return wrapper
|
|
151
|
+
return decorator
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def run_in_raw_thread():
|
|
155
|
+
"""A decorator that runs a function in a separate thread and handles exceptions.
|
|
156
|
+
|
|
157
|
+
Args:
|
|
158
|
+
func: The function to run in a thread
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
A wrapper function that executes the decorated function in a thread
|
|
162
|
+
|
|
163
|
+
The decorator will:
|
|
164
|
+
1. Run the function in a separate thread
|
|
165
|
+
2. Handle KeyboardInterrupt properly
|
|
166
|
+
3. Propagate exceptions from the thread
|
|
167
|
+
4. Support function arguments
|
|
168
|
+
5. Preserve function metadata
|
|
169
|
+
"""
|
|
170
|
+
def decorator(func):
|
|
171
|
+
|
|
172
|
+
@wraps(func)
|
|
173
|
+
def wrapper(*args, **kwargs):
|
|
174
|
+
# Store thread results
|
|
175
|
+
result = []
|
|
176
|
+
exception = []
|
|
177
|
+
def worker():
|
|
178
|
+
try:
|
|
179
|
+
ret = func(*args, **kwargs)
|
|
180
|
+
result.append(ret)
|
|
181
|
+
global_cancel.reset()
|
|
182
|
+
except Exception as e:
|
|
183
|
+
global_cancel.reset()
|
|
184
|
+
raise
|
|
185
|
+
|
|
186
|
+
# Create and start thread with a meaningful name
|
|
187
|
+
thread = threading.Thread(target=worker, name=f"{func.__name__}_thread")
|
|
188
|
+
thread.daemon = True # Make thread daemon so it doesn't prevent program exit
|
|
189
|
+
|
|
190
|
+
try:
|
|
191
|
+
thread.start()
|
|
192
|
+
while thread.is_alive():
|
|
193
|
+
thread.join(0.1)
|
|
194
|
+
|
|
195
|
+
return result[0] if result else None
|
|
196
|
+
except KeyboardInterrupt:
|
|
197
|
+
global_cancel.set()
|
|
198
|
+
raise KeyboardInterrupt("Task was cancelled by user")
|
|
199
|
+
|
|
200
|
+
return wrapper
|
|
201
|
+
return decorator
|
autocoder/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.250"
|
|
1
|
+
__version__ = "0.1.252"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|