auto-coder 0.1.354__py3-none-any.whl → 0.1.356__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/METADATA +1 -1
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/RECORD +40 -35
- autocoder/agent/agentic_filter.py +1 -1
- autocoder/agent/auto_learn.py +631 -0
- autocoder/auto_coder.py +8 -0
- autocoder/auto_coder_runner.py +59 -87
- autocoder/chat/conf_command.py +270 -0
- autocoder/chat/models_command.py +485 -0
- autocoder/chat/rules_command.py +458 -0
- autocoder/chat_auto_coder.py +34 -24
- autocoder/chat_auto_coder_lang.py +156 -2
- autocoder/commands/auto_command.py +1 -1
- autocoder/commands/auto_web.py +1 -1
- autocoder/common/__init__.py +2 -0
- autocoder/common/auto_coder_lang.py +9 -1
- autocoder/common/command_completer.py +58 -12
- autocoder/common/command_completer_v2.py +615 -0
- autocoder/common/global_cancel.py +53 -16
- autocoder/common/rulefiles/autocoderrules_utils.py +83 -0
- autocoder/common/v2/agent/agentic_edit.py +4 -4
- autocoder/common/v2/code_agentic_editblock_manager.py +9 -9
- autocoder/common/v2/code_diff_manager.py +2 -2
- autocoder/common/v2/code_editblock_manager.py +11 -10
- autocoder/common/v2/code_strict_diff_manager.py +3 -2
- autocoder/dispacher/actions/action.py +6 -6
- autocoder/dispacher/actions/plugins/action_regex_project.py +2 -2
- autocoder/events/event_manager_singleton.py +1 -1
- autocoder/index/index.py +2 -2
- autocoder/rag/cache/local_byzer_storage_cache.py +1 -1
- autocoder/rag/cache/local_duckdb_storage_cache.py +8 -0
- autocoder/rag/loaders/image_loader.py +25 -13
- autocoder/rag/long_context_rag.py +2 -2
- autocoder/utils/auto_coder_utils/chat_stream_out.py +3 -4
- autocoder/utils/model_provider_selector.py +14 -2
- autocoder/utils/thread_utils.py +9 -27
- autocoder/version.py +1 -1
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.354.dist-info → auto_coder-0.1.356.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import shlex
|
|
3
|
+
import fnmatch # Add fnmatch for wildcard matching
|
|
4
|
+
from typing import Dict, Any
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
from rich.table import Table
|
|
7
|
+
from rich.panel import Panel
|
|
8
|
+
import byzerllm
|
|
9
|
+
from typing import Generator
|
|
10
|
+
from autocoder import models as models_module
|
|
11
|
+
from autocoder.common.printer import Printer
|
|
12
|
+
from autocoder.common.result_manager import ResultManager
|
|
13
|
+
from autocoder.common.model_speed_tester import render_speed_test_in_terminal
|
|
14
|
+
from autocoder.utils.llms import get_single_llm
|
|
15
|
+
|
|
16
|
+
# Subcommand tokens recognised by handle_models_command, checked in this order.
# Each entry maps the token typed by the user to its canonical subcommand.
# Every matching token is removed (first occurrence only) and a later match
# overrides an earlier one; the order matters, e.g. "/add_model" is stripped
# before "/add" is looked for and "/speed-test" before "/speed".
_MODELS_SUBCOMMAND_TOKENS = (
    ("/list", "/list"),
    ("/add_model", "/add_model"),
    ("/add", "/add"),
    ("/activate", "/add"),  # alias of /add
    ("/remove", "/remove"),
    ("/speed-test", "/speed-test"),
    ("/speed_test", "/speed-test"),
    ("/input_price", "/input_price"),
    ("/output_price", "/output_price"),
    ("/speed", "/speed"),
    ("/chat", "/chat"),
)


def _extract_models_subcommand(query: str) -> tuple:
    """Split a /models query into ``(subcmd, remainder)``.

    ``subcmd`` is the canonical subcommand, or ``""`` when no known token is
    present. ``remainder`` is the query with each recognised token removed
    (first occurrence) and surrounding whitespace stripped.
    """
    subcmd = ""
    for token, canonical in _MODELS_SUBCOMMAND_TOKENS:
        if token in query:
            subcmd = canonical
            query = query.replace(token, "", 1).strip()
    return subcmd, query


def _as_float(value, default: float = 0.0) -> float:
    """Coerce a stored model field to float, using ``default`` for missing/non-numeric values."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _record_models_result(result_manager: "ResultManager", content: str, query: str) -> None:
    """Record ``content`` under the ``models`` action, tagged with the (remaining) query."""
    result_manager.add_result(content=content, meta={
        "action": "models",
        "input": {
            "query": query
        }
    })


def _default_models_listing() -> str:
    """Return the names of ``models_module.default_models_list``, one per line."""
    return "\n".join(m["name"] for m in models_module.default_models_list)


def _list_models(console: "Console", printer: "Printer", result_manager: "ResultManager",
                 models_data: list, query: str) -> None:
    """Handle ``/list [pattern]``: print a table of models, optionally fnmatch-filtered by name."""
    pattern = query.strip()
    filtered_models = models_data
    if pattern:
        filtered_models = [
            m for m in models_data if fnmatch.fnmatch(m.get("name") or "", pattern)
        ]

    if not filtered_models:
        if pattern:
            # Filtering removed everything.
            printer.print_in_terminal("models_no_models_matching_pattern", style="yellow", pattern=pattern)
            _record_models_result(result_manager, f"No models found matching pattern: {pattern}", query)
        else:
            # No models exist at all.
            printer.print_in_terminal("models_no_models", style="yellow")
            _record_models_result(result_manager, "No models found", query)
        return

    # Highest average_speed first; sorted(reverse=True) keeps config order among ties.
    sorted_models = sorted(
        filtered_models,
        key=lambda m: _as_float(m.get("average_speed")),
        reverse=True,
    )

    title = printer.get_message_from_key("models_title")
    if pattern:
        title += f" (Filtered by: '{pattern}')"
    table = Table(title=title, expand=True, show_lines=True)
    table.add_column("Name", style="cyan", width=40, overflow="fold", no_wrap=False)
    table.add_column("Model Name", style="magenta", width=30, overflow="fold", no_wrap=False)
    table.add_column("Base URL", style="white", width=30, overflow="fold", no_wrap=False)
    table.add_column("Input Price (M)", style="magenta", width=15, overflow="fold", no_wrap=False)
    table.add_column("Output Price (M)", style="magenta", width=15, overflow="fold", no_wrap=False)
    table.add_column("Speed (s/req)", style="blue", width=15, overflow="fold", no_wrap=False)

    for m in sorted_models:
        name = m.get("name", "")
        # A model that declares an "api_key" field but leaves it blank is
        # reported and marked with a trailing "*".
        if "api_key" in m and not (m.get("api_key") or "").strip():
            printer.print_in_terminal("models_api_key_empty", style="yellow", name=name)
            name = f"{name} *"

        table.add_row(
            name,
            m.get("model_name", ""),
            m.get("base_url", ""),
            f"{_as_float(m.get('input_price')):.2f}",
            f"{_as_float(m.get('output_price')):.2f}",
            f"{_as_float(m.get('average_speed')):.3f}",
        )

    console.print(table)
    _record_models_result(result_manager, json.dumps(sorted_models, ensure_ascii=False), query)


def _handle_models_numeric_update(printer: "Printer", result_manager: "ResultManager", query: str,
                                  update_fn, value_kwarg: str, updated_key: str,
                                  invalid_key: str, usage_key: str) -> None:
    """Handle ``<name> <number>`` subcommands (/input_price, /output_price, /speed).

    Args:
        printer: Printer used for terminal output and message lookup.
        result_manager: Sink for the command result.
        query: Remaining query text, expected as ``<name> <number> ...``.
        update_fn: Updater called as ``update_fn(name, value)``; a truthy
            return means success, falsy means the model was not found.
        value_kwarg: Keyword under which the parsed number is passed to the
            ``updated_key`` message (``"price"`` or ``"speed"``).
        updated_key: Message key printed on success.
        invalid_key: Message key printed when the number cannot be parsed.
        usage_key: Message key printed when arguments are missing.
    """
    args = query.split()
    if len(args) < 2:
        _record_models_result(result_manager, printer.get_message_from_key(usage_key), query)
        printer.print_in_terminal(usage_key, style="red")
        return

    name = args[0]
    # Guard only the conversion: the error message is about an invalid
    # number, so failures inside update_fn must not be reported as one.
    try:
        value = float(args[1])
    except ValueError as e:
        printer.print_in_terminal(invalid_key, style="red", error=str(e))
        _record_models_result(result_manager, f"{invalid_key}: {e}", query)
        return

    if update_fn(name, value):
        printer.print_in_terminal(updated_key, style="green", name=name, **{value_kwarg: value})
        _record_models_result(result_manager, f"{updated_key}: {name} {value}", query)
    else:
        printer.print_in_terminal("models_not_found", style="red", name=name)
        _record_models_result(result_manager, f"models_not_found: {name}", query)


def _run_speed_test(result_manager: "ResultManager", product_mode: str, query: str) -> None:
    """Handle ``/speed-test [rounds] [/long_context]``: run the terminal speed test."""
    enable_long_context = False
    for flag in ("/long_context", "/long-context"):
        if flag in query:
            enable_long_context = True
            query = query.replace(flag, "", 1).strip()

    # Optional positive round count; default is a single round.
    test_rounds = 1
    args = query.split()
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse,
    # e.g. it rejects "²"; non-positive counts keep the default.
    if args and args[0].isdecimal() and int(args[0]) > 0:
        test_rounds = int(args[0])

    render_speed_test_in_terminal(product_mode, test_rounds, enable_long_context=enable_long_context)
    # TODO: record per-model speed-test details instead of a fixed message.
    _record_models_result(result_manager, "models test success", query)


def _activate_model(printer: "Printer", result_manager: "ResultManager", query: str) -> None:
    """Handle ``/add <name> <api_key>``: store an API key for a model via models_module."""
    args = query.split()
    if len(args) != 2:
        models_list = _default_models_listing()
        printer.print_in_terminal("models_add_usage", style="red", models=models_list)
        _record_models_result(
            result_manager,
            printer.get_message_from_key_with_format("models_add_usage", models=models_list),
            query,
        )
        return

    name, api_key = args
    if models_module.update_model_with_api_key(name, api_key):
        _record_models_result(result_manager, f"models_added: {name}", query)
        printer.print_in_terminal("models_added", style="green", name=name)
    else:
        _record_models_result(result_manager, f"models_add_failed: {name}", query)
        printer.print_in_terminal("models_add_failed", style="red", name=name)


def _add_custom_model(printer: "Printer", result_manager: "ResultManager",
                      models_data: list, query: str) -> None:
    """Handle ``/add_model key=value ...``: append a new model definition and save it."""
    try:
        kv_pairs = shlex.split(query)
    except ValueError:
        # e.g. an unbalanced quote: report the expected format instead of crashing.
        printer.print_in_terminal("models_add_model_params", style="red")
        _record_models_result(result_manager, printer.get_message_from_key("models_add_model_params"), query)
        return

    data_dict = {}
    for pair in kv_pairs:
        if "=" not in pair:
            printer.print_in_terminal("models_add_model_params", style="red")
            continue
        k, v = pair.split("=", 1)
        data_dict[k.strip()] = v.strip()

    # A non-empty name is required.
    name = data_dict.get("name", "")
    if not name:
        printer.print_in_terminal("models_add_model_name_required", style="red")
        _record_models_result(
            result_manager, printer.get_message_from_key("models_add_model_name_required"), query
        )
        return

    if any(m.get("name") == name for m in models_data):
        printer.print_in_terminal("models_add_model_exists", style="yellow", name=name)
        _record_models_result(
            result_manager,
            printer.get_message_from_key_with_format("models_add_model_exists", name=name),
            query,
        )
        return

    # Unspecified fields fall back to OpenAI-compatible defaults.
    final_model = {
        "name": name,
        "model_type": data_dict.get("model_type", "saas/openai"),
        "model_name": data_dict.get("model_name", name),
        "base_url": data_dict.get("base_url", "https://api.openai.com/v1"),
        "api_key_path": data_dict.get("api_key_path", "api.openai.com"),
        "description": data_dict.get("description", ""),
        "is_reasoning": data_dict.get("is_reasoning", "false").lower() in ("true", "1"),
    }

    models_data.append(final_model)
    models_module.save_models(models_data)
    printer.print_in_terminal("models_add_model_success", style="green", name=name)
    _record_models_result(result_manager, f"models_add_model_success: {name}", query)


def _remove_model(printer: "Printer", result_manager: "ResultManager",
                  models_data: list, query: str) -> None:
    """Handle ``/remove <name>``: drop every model with that name and save the rest."""
    args = query.split()
    if not args:
        models_list = _default_models_listing()
        printer.print_in_terminal("models_add_usage", style="red", models=models_list)
        _record_models_result(
            result_manager,
            printer.get_message_from_key_with_format("models_add_usage", models=models_list),
            query,
        )
        return

    name = args[0]
    remaining = [m for m in models_data if m.get("name") != name]
    if len(remaining) == len(models_data):
        # Nothing matched: nothing to save.
        printer.print_in_terminal("models_add_model_remove", style="yellow", name=name)
        _record_models_result(
            result_manager,
            printer.get_message_from_key_with_format("models_add_model_remove", name=name),
            query,
        )
        return

    models_module.save_models(remaining)
    printer.print_in_terminal("models_add_model_removed", style="green", name=name)
    _record_models_result(
        result_manager,
        printer.get_message_from_key_with_format("models_add_model_removed", name=name),
        query,
    )


def _chat_with_model(printer: "Printer", result_manager: "ResultManager",
                     product_mode: str, query: str) -> None:
    """Handle ``/chat <model_name> <question>``: stream a model reply to stdout and record it."""
    if not query.strip():
        message = "Please provide content in format: <model_name> <question>"
        printer.print_str_in_terminal(message, style="yellow")
        _record_models_result(result_manager, message, query)
        return

    # Split off the model name; the question keeps any inner spaces.
    parts = query.strip().split(maxsplit=1)
    if len(parts) < 2:
        message = "Correct format should be: <model_name> <question>, where question can contain spaces"
        printer.print_str_in_terminal(message, style="yellow")
        _record_models_result(result_manager, message, query)
        return

    model_name, user_question = parts

    # Top-level boundary for this subcommand: any failure (unknown model,
    # LLM error) is reported to the user and recorded, not raised.
    try:
        llm = get_single_llm(model_name, product_mode=product_mode)

        @byzerllm.prompt()
        def chat_func(content: str) -> Generator[str, None, None]:
            """
            {{ content }}
            """

        chunks = []
        for chunk in chat_func.with_llm(llm).run(user_question):
            chunks.append(chunk)
            print(chunk, end="", flush=True)
        print("\n")

        _record_models_result(result_manager, "".join(chunks), query)
    except Exception as e:
        error_message = f"Error chatting with model: {str(e)}"
        printer.print_str_in_terminal(error_message, style="red")
        _record_models_result(result_manager, error_message, query)


def handle_models_command(query: str, memory: Dict[str, Any]):
    """Handle the /models command and its subcommands.

    Subcommands:
        /models /list [pattern]                     - List models, optionally filtered by an fnmatch pattern on the name
        /models /add <name> <api_key>               - Set the API key of a model (alias: /activate)
        /models /add_model name=xxx base_url=xxx ...  - Add a model with custom params
        /models /remove <name>                      - Remove a model by name
        /models /input_price <name> <price>         - Update a model's input price
        /models /output_price <name> <price>        - Update a model's output price
        /models /speed <name> <speed>               - Update a model's recorded average speed
        /models /speed-test [rounds] [/long_context]  - Run the speed test (aliases: /speed_test, /long-context)
        /models /chat <model_name> <question>       - Chat with a model

    Only available when ``memory["product_mode"]`` is ``"lite"`` (the default
    when unset). Besides terminal output, each subcommand records its outcome
    through ResultManager, tagged with the query text left after the
    subcommand token is removed.
    """
    console = Console()
    printer = Printer(console=console)

    product_mode = memory.get("product_mode", "lite")
    if product_mode != "lite":
        printer.print_in_terminal("models_lite_only", style="red")
        return

    if not query.strip():
        printer.print_in_terminal("models_usage")
        return

    subcmd, query = _extract_models_subcommand(query)
    if not subcmd:
        # Without this return the "unknown subcommand" branch below would
        # print a second, confusing message for an empty subcommand.
        printer.print_in_terminal("models_usage")
        return

    models_data = models_module.load_models()
    result_manager = ResultManager()

    if subcmd == "/list":
        _list_models(console, printer, result_manager, models_data, query)
    elif subcmd == "/input_price":
        _handle_models_numeric_update(
            printer, result_manager, query, models_module.update_model_input_price,
            "price", "models_input_price_updated", "models_invalid_price", "models_input_price_usage",
        )
    elif subcmd == "/output_price":
        _handle_models_numeric_update(
            printer, result_manager, query, models_module.update_model_output_price,
            "price", "models_output_price_updated", "models_invalid_price", "models_output_price_usage",
        )
    elif subcmd == "/speed":
        _handle_models_numeric_update(
            printer, result_manager, query, models_module.update_model_speed,
            "speed", "models_speed_updated", "models_invalid_speed", "models_speed_usage",
        )
    elif subcmd == "/speed-test":
        _run_speed_test(result_manager, product_mode, query)
    elif subcmd == "/add":
        _activate_model(printer, result_manager, query)
    elif subcmd == "/add_model":
        _add_custom_model(printer, result_manager, models_data, query)
    elif subcmd == "/remove":
        _remove_model(printer, result_manager, models_data, query)
    elif subcmd == "/chat":
        _chat_with_model(printer, result_manager, product_mode, query)
    else:
        # Safety net: a token added to _MODELS_SUBCOMMAND_TOKENS without a handler.
        printer.print_in_terminal("models_unknown_subcmd", style="yellow", subcmd=subcmd)
        _record_models_result(
            result_manager,
            printer.get_message_from_key_with_format("models_unknown_subcmd", subcmd=subcmd),
            query,
        )
|