yaicli 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +5 -3
- yaicli/chat.py +396 -0
- yaicli/cli.py +250 -251
- yaicli/client.py +385 -0
- yaicli/config.py +31 -24
- yaicli/console.py +2 -2
- yaicli/const.py +28 -2
- yaicli/entry.py +68 -39
- yaicli/exceptions.py +8 -36
- yaicli/functions/__init__.py +39 -0
- yaicli/functions/buildin/execute_shell_command.py +47 -0
- yaicli/printer.py +145 -225
- yaicli/render.py +1 -1
- yaicli/role.py +231 -0
- yaicli/schemas.py +31 -0
- yaicli/tools.py +103 -0
- yaicli/utils.py +5 -2
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/METADATA +164 -87
- yaicli-0.5.0.dist-info/RECORD +24 -0
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/entry_points.txt +1 -1
- yaicli/chat_manager.py +0 -290
- yaicli/providers/__init__.py +0 -34
- yaicli/providers/base.py +0 -51
- yaicli/providers/cohere.py +0 -136
- yaicli/providers/openai.py +0 -176
- yaicli/roles.py +0 -276
- yaicli-0.4.0.dist-info/RECORD +0 -23
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/WHEEL +0 -0
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/licenses/LICENSE +0 -0
yaicli/client.py
ADDED
@@ -0,0 +1,385 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from typing import Any, Dict, Generator, List, Optional, Union, cast
|
4
|
+
|
5
|
+
import litellm
|
6
|
+
from json_repair import repair_json
|
7
|
+
from litellm.types.utils import Choices
|
8
|
+
from litellm.types.utils import Message as ChoiceMessage
|
9
|
+
from litellm.types.utils import ModelResponse
|
10
|
+
from rich.panel import Panel
|
11
|
+
|
12
|
+
from .config import cfg
|
13
|
+
from .console import get_console
|
14
|
+
from .schemas import LLMResponse, ChatMessage, ToolCall
|
15
|
+
from .tools import (
|
16
|
+
Function,
|
17
|
+
FunctionName,
|
18
|
+
get_function,
|
19
|
+
get_openai_schemas,
|
20
|
+
list_functions,
|
21
|
+
)
|
22
|
+
|
23
|
+
# Module-level litellm setting. NOTE(review): per litellm's docs, drop_params=True makes
# litellm drop request parameters the target provider does not support (e.g. the
# reasoning_effort / parallel_tool_calls sent by LitellmClient) instead of raising — confirm
# for the pinned litellm version.
litellm.drop_params = True
# Shared console used in this module for function-call traces, function output panels and errors.
console = get_console()
|
25
|
+
|
26
|
+
|
27
|
+
class RefreshLive:
    """Marker yielded by the client so the live display refreshes before the next run."""

    pass
|
29
|
+
|
30
|
+
|
31
|
+
class StopLive:
    """Marker meant to tell the live display to stop."""

    pass
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class LitellmClient:
    """Chat-completion client built on litellm.

    Sends the conversation (streaming or not) and, when functions are enabled,
    executes the tool call requested by the model, appends its result to the
    conversation and asks the model again, at most ``MAX_TOOL_CALL_DEPTH`` extra times.
    """

    api_key: str = field(default_factory=lambda: cfg["API_KEY"])
    model: str = field(default_factory=lambda: f"{cfg['PROVIDER']}/{cfg['MODEL']}")
    base_url: Optional[str] = field(default_factory=lambda: cfg["BASE_URL"])
    timeout: int = field(default_factory=lambda: cfg["TIMEOUT"])

    verbose: bool = False

    # Follow-up requests allowed after tool calls. Deliberately unannotated so the
    # dataclass treats it as a class attribute, not a constructor field.
    MAX_TOOL_CALL_DEPTH = 5

    def __post_init__(self) -> None:
        """Initialize tool-call bookkeeping."""
        # Id of the most recently executed tool call.
        self.pre_tool_call_id: Optional[str] = None
        # Tool calls executed by this client, keyed by id; used to rebuild the assistant
        # "tool_calls" message that must precede each tool result in later requests.
        self._executed_tool_calls: Dict[str, ToolCall] = {}

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """Convert messages to the format required by the OpenAI API.

        OpenAI-compatible APIs only accept a ``tool`` message that answers a preceding
        assistant message listing that call in ``tool_calls``. ChatMessage, as used here,
        only stores role/content/name/tool_call_id, so that assistant request is re-created
        before each ``tool`` result. Results sent with another role (Gemini uses ``user``,
        see detect_tool_role) are passed through unchanged.
        """
        openai_messages: List[Dict[str, Any]] = []
        for msg in messages:
            if not msg.tool_call_id:
                openai_messages.append({"role": msg.role, "content": msg.content})
                continue
            if msg.role == "tool" and not self._is_requested(openai_messages, msg.tool_call_id):
                openai_messages.append(self._tool_call_request(msg))
            openai_messages.append(
                {"role": msg.role, "content": msg.content, "tool_call_id": msg.tool_call_id, "name": msg.name}
            )
        return openai_messages

    @staticmethod
    def _is_requested(openai_messages: List[Dict[str, Any]], tool_call_id: str) -> bool:
        """Return True if the last converted message is an assistant message requesting ``tool_call_id``."""
        if not openai_messages or openai_messages[-1].get("role") != "assistant":
            return False
        requested = openai_messages[-1].get("tool_calls") or []
        return any(call.get("id") == tool_call_id for call in requested)

    def _tool_call_request(self, msg: ChatMessage) -> Dict[str, Any]:
        """Build the assistant message that requested the tool call answered by ``msg``."""
        executed = self._executed_tool_calls.get(msg.tool_call_id or "")
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": msg.tool_call_id,
                    "type": "function",
                    "function": {
                        "name": executed.name if executed else (msg.name or ""),
                        # Arguments are unknown for calls this client did not execute itself
                        # (e.g. messages created elsewhere); "{}" keeps the request well-formed.
                        "arguments": executed.arguments if executed else "{}",
                    },
                }
            ],
        }

    def _convert_functions(self, _: List[Function]) -> List[Dict[str, Any]]:
        """Convert function format to OpenAI API required format"""
        return get_openai_schemas()

    def _build_params(self, messages: List[ChatMessage], stream: bool) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``litellm.completion``."""
        params: Dict[str, Any] = {
            "model": self.model,
            "messages": self._convert_messages(messages),
            "temperature": cfg["TEMPERATURE"],
            "top_p": cfg["TOP_P"],
            "stream": stream,
            # Openai: This value is now deprecated in favor of max_completion_tokens.
            "max_tokens": cfg["MAX_TOKENS"],
            "max_completion_tokens": cfg["MAX_TOKENS"],
            # litellm api params
            "api_key": self.api_key,
            "base_url": self.base_url,
            "timeout": self.timeout,
        }
        # An unset REASONING_EFFORT can arrive as None, "" or the string "None" (the config
        # loader str()-converts a None default); none of these is a valid effort level.
        reasoning_effort = cfg["REASONING_EFFORT"]
        if isinstance(reasoning_effort, str) and reasoning_effort.strip().lower() not in ("", "none"):
            params["reasoning_effort"] = reasoning_effort.strip()

        if cfg["ENABLE_FUNCTIONS"]:
            params["tools"] = self._convert_functions(list_functions())
            params["tool_choice"] = "auto"
            params["parallel_tool_calls"] = False
        return params

    def _execute_tool_call(self, tool_call: ToolCall) -> tuple[str, bool]:
        """Execute the function requested by ``tool_call``.

        Returns:
            ``(function output, True)`` on success, or ``(error message, False)`` when the
            function is unknown, its arguments are not a JSON object, or it raises.
            Errors are also printed to the console.
        """
        console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")

        # 1. Get function
        try:
            function = get_function(FunctionName(tool_call.name))
        except ValueError as e:
            error_msg = f"Function {tool_call.name!r} does not exist: {e}"
            console.print(error_msg, style="red")
            return error_msg, False

        # 2. Parse function arguments (repair_json attempts to fix malformed JSON from the model)
        try:
            arguments = repair_json(tool_call.arguments, return_objects=True)
        except Exception as e:
            error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
            console.print(error_msg, style="red")
            return error_msg, False
        if not isinstance(arguments, dict):
            error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
            console.print(error_msg, style="red")
            return error_msg, False

        # 3. Execute function
        try:
            function_result = function.execute(**arguments)
            if cfg["SHOW_FUNCTION_OUTPUT"]:
                panel = Panel(
                    function_result,
                    title="Function output",
                    title_align="left",
                    expand=False,
                    border_style="blue",
                    style="dim",
                )
                console.print(panel)
            return function_result, True
        except Exception as e:
            error_msg = f"Call function error: {e}\nFunction name: {tool_call.name!r}\nArguments: {arguments!r}"
            console.print(error_msg, style="red")
            return error_msg, False

    def _run_tool_call(self, messages: List[ChatMessage], tool_call: ToolCall) -> None:
        """Execute ``tool_call`` and append its result to ``messages`` (in place)."""
        self.pre_tool_call_id = tool_call.id
        self._executed_tool_calls[tool_call.id] = tool_call
        function_result, _ = self._execute_tool_call(tool_call)
        messages.append(
            ChatMessage(
                role=self.detect_tool_role(cfg["PROVIDER"]),
                content=function_result,
                name=tool_call.name,
                tool_call_id=tool_call.id,
            )
        )

    def completion(
        self,
        messages: List[ChatMessage],
        stream: bool = False,
        recursion_depth: int = 0,
    ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
        """Send ``messages`` to the model and yield its response.

        Yields LLMResponse items as they arrive (one per chunk when streaming). If the
        response requests a tool call, a RefreshLive marker is yielded, the call is executed
        once with its fully accumulated arguments, the result is appended to ``messages``
        (in place) and the model is asked again. At most MAX_TOOL_CALL_DEPTH follow-up
        requests are made.
        """
        if self.verbose:
            console.print(messages)

        response = litellm.completion(**self._build_params(messages, stream))
        if stream:
            llm_content_generator = self._handle_stream_response(cast(litellm.CustomStreamWrapper, response))
        else:
            llm_content_generator = self._handle_normal_response(cast(ModelResponse, response))

        # While streaming, the tool call is yielded before its arguments are complete; only
        # the last snapshot is whole, so execute after the response has been consumed.
        tool_call: Optional[ToolCall] = None
        for llm_content in llm_content_generator:
            yield llm_content
            if llm_content.tool_call:
                tool_call = llm_content.tool_call
        if tool_call is None:
            return

        # Let live display know we are in next run
        yield RefreshLive()
        self._run_tool_call(messages, tool_call)

        # Check if we've exceeded the maximum recursion depth
        if recursion_depth >= self.MAX_TOOL_CALL_DEPTH:
            console.print(
                f"Maximum recursion depth ({self.MAX_TOOL_CALL_DEPTH}) reached, stopping further tool calls",
                style="yellow",
            )
            return
        yield from self.completion(messages, stream=stream, recursion_depth=recursion_depth + 1)

    def stream_completion(
        self, messages: List[ChatMessage], stream: bool = True, recursion_depth: int = 0
    ) -> Generator[LLMResponse, None, None]:
        """Like completion(), but yields only LLMResponse items (RefreshLive markers are dropped).

        Uses the same request parameters, tool-call handling and recursion limit as completion().
        """
        for llm_content in self.completion(messages, stream=stream, recursion_depth=recursion_depth):
            if isinstance(llm_content, RefreshLive):
                continue
            yield llm_content

    def _parse_tool_calls_from_content(self, content: str) -> Optional[List[Any]]:
        """Return tool calls embedded as JSON in ``content`` (e.g. after a <think> block).

        Example of embedded data:
        {"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}

        Returns:
            The ``tool_calls`` of the parsed choice's ``delta`` (or ``message``), or None when
            no parseable ``{"index":`` object is present.
        """
        tool_index = content.find('{"index":')
        if tool_index == -1:
            return None
        try:
            parsed = self.parse_choice_from_content(content[tool_index:])
        except ValueError:
            return None
        source = getattr(parsed, "delta", None)
        if source is None:
            source = getattr(parsed, "message", None)
        return getattr(source, "tool_calls", None) or None

    def _handle_normal_response(self, response: ModelResponse) -> Generator[LLMResponse, None, None]:
        """Handle normal (non-streaming) response

        Yields:
            One LLMResponse with:
            - reasoning: from ``reasoning_content``/``reasoning``, a leading <think> block,
              or ``model_extra`` (if any)
            - content: the normal response content (without the <think> block)
            - finish_reason and the requested tool call (if any)
        """
        choice = response.choices[0]
        message = choice.message  # type: ignore
        content = message.content or ""
        # getattr-based lookup: the attribute may be missing when there is no reasoning
        reasoning = self._get_reasoning_content(message)
        finish_reason = choice.finish_reason
        tool_call: Optional[ToolCall] = None

        think_open, think_close = "<think>", "</think>"
        if think_open in content and think_close in content:
            content = content.lstrip()
            if content.startswith(think_open):
                # content starts with <think> and contains </think>, so think_end >= len(think_open)
                think_end = content.find(think_close)
                reasoning = content[len(think_open) : think_end].strip()
                # Remove the <think> block from the main content
                content = content[think_end + len(think_close) :].strip()
        else:
            # Some providers return reasoning in model_extra; only override when found there
            model_extra = getattr(message, "model_extra", None)
            if model_extra:
                extra_reasoning = self._get_reasoning_content(model_extra)
                if extra_reasoning:
                    reasoning = extra_reasoning

        if finish_reason == "tool_calls":
            # Tool call data may be in the content after the <think> block
            tool_calls = self._parse_tool_calls_from_content(content) or getattr(message, "tool_calls", None)
            if tool_calls:
                tool = tool_calls[0]
                tool_call = ToolCall(tool.id or "", tool.function.name or "", tool.function.arguments or "")

        yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)

    def _handle_stream_response(self, response: litellm.CustomStreamWrapper) -> Generator[LLMResponse, None, None]:
        """Handle streaming response

        Yields:
            One LLMResponse per chunk with that chunk's reasoning and content. Once a tool call
            starts, each item carries a ToolCall snapshot whose arguments accumulate across
            chunks, so the last one yielded holds the complete call.
        """
        full_content = ""
        tool_id = ""
        tool_call_name = ""
        arguments = ""
        tool_call: Optional[ToolCall] = None
        for chunk in response:
            # A chunk may arrive without choices; skip it instead of raising IndexError
            if not chunk.choices:  # type: ignore
                continue
            choice = chunk.choices[0]  # type: ignore
            delta = choice.delta
            finish_reason = choice.finish_reason

            content = getattr(delta, "content", None) or ""
            full_content += content
            reasoning = self._get_reasoning_content(delta)

            tool_deltas = getattr(delta, "tool_calls", None)
            replace_arguments = False
            if finish_reason == "tool_calls" or '{"index":' in content or '"tool_calls":' in content:
                # Tool call data may be in the content after the <think> block
                parsed_calls = self._parse_tool_calls_from_content(full_content)
                if parsed_calls:
                    tool_deltas = parsed_calls
                    # Parsed content holds the complete arguments, not a fragment
                    replace_arguments = True

            if tool_deltas:
                # Later chunks of a tool call may omit id/name; keep the values already seen
                tool_id = tool_deltas[0].id or tool_id
                if replace_arguments:
                    arguments = ""
                for tool in tool_deltas:
                    if not tool.function:
                        continue
                    tool_call_name = tool.function.name or tool_call_name
                    arguments += tool.function.arguments or ""
                tool_call = ToolCall(tool_id, tool_call_name, arguments)

            yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)

    def _get_reasoning_content(self, delta: Any) -> Optional[str]:
        """Extract reasoning content from a delta/message object or a ``model_extra`` dict.

        Reasoning content keys from API (checked in order):
            reasoning_content: deepseek/infi-ai
            reasoning: openrouter
        <think> block implementation not in here.

        Args:
            delta: Response delta or message (attribute lookup), or a ``model_extra``
                dict (key lookup).

        Returns:
            The first non-empty string found, None otherwise.
        """
        if delta is None:
            return None
        for key in ("reasoning_content", "reasoning"):
            value = delta.get(key) if isinstance(delta, dict) else getattr(delta, key, None)
            if isinstance(value, str) and value:
                return value
        return None

    def parse_choice_from_content(self, content: str) -> Choices:
        """
        Parse the choice from the content after <think>...</think> block.

        Args:
            content: JSON text taken from the LLM response content

        Returns:
            The choice object (a ``delta`` entry is validated as a Message first)

        Raises:
            ValueError: if the content is not a JSON object or does not validate as a choice
        """
        try:
            content_dict = json.loads(content)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid message from LLM: {content}") from e
        if not isinstance(content_dict, dict):
            raise ValueError(f"Invalid message from LLM: {content}")
        if "delta" in content_dict:
            try:
                content_dict["delta"] = ChoiceMessage.model_validate(content_dict["delta"])
            except Exception as e:
                raise ValueError(f"Invalid message from LLM: {content}") from e
        try:
            return Choices.model_validate(content_dict)
        except Exception as e:
            raise ValueError(f"Invalid message from LLM: {content}") from e

    def detect_tool_role(self, provider: str) -> str:
        """Return the role used for tool results: "user" for gemini, "tool" otherwise."""
        if provider == "gemini":
            return "user"
        return "tool"
|
yaicli/config.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1
1
|
import configparser
|
2
|
+
from dataclasses import dataclass
|
2
3
|
from functools import lru_cache
|
3
4
|
from os import getenv
|
4
|
-
from typing import Optional
|
5
|
+
from typing import Any, Optional
|
5
6
|
|
6
7
|
from rich import get_console
|
7
8
|
from rich.console import Console
|
8
9
|
|
9
|
-
from
|
10
|
+
from .const import (
|
10
11
|
CONFIG_PATH,
|
11
12
|
DEFAULT_CHAT_HISTORY_DIR,
|
12
13
|
DEFAULT_CONFIG_INI,
|
13
14
|
DEFAULT_CONFIG_MAP,
|
14
|
-
|
15
|
-
|
16
|
-
DEFAULT_ROLE_MODIFY_WARNING,
|
15
|
+
DEFAULT_TEMPERATURE,
|
16
|
+
DEFAULT_TOP_P,
|
17
17
|
)
|
18
|
-
from
|
18
|
+
from .utils import str2bool
|
19
19
|
|
20
20
|
|
21
21
|
class CasePreservingConfigParser(configparser.RawConfigParser):
|
@@ -25,6 +25,17 @@ class CasePreservingConfigParser(configparser.RawConfigParser):
|
|
25
25
|
return optionstr
|
26
26
|
|
27
27
|
|
28
|
+
@dataclass(init=True, repr=True, eq=True)
class ProviderConfig:
    """Settings describing how to talk to one LLM provider."""

    # Required: credential and model identifier
    api_key: str
    model: str
    # Optional endpoint override (None when not set)
    base_url: Optional[str] = None
    # Sampling parameters, defaulting to the package-wide constants
    temperature: float = DEFAULT_TEMPERATURE
    top_p: float = DEFAULT_TOP_P
|
37
|
+
|
38
|
+
|
28
39
|
class Config(dict):
|
29
40
|
"""Configuration class that loads settings on initialization.
|
30
41
|
|
@@ -50,7 +61,7 @@ class Config(dict):
|
|
50
61
|
"""
|
51
62
|
# Start with defaults
|
52
63
|
self.clear()
|
53
|
-
self.
|
64
|
+
self._load_defaults()
|
54
65
|
|
55
66
|
# Load from config file
|
56
67
|
self._load_from_file()
|
@@ -59,13 +70,15 @@ class Config(dict):
|
|
59
70
|
self._load_from_env()
|
60
71
|
self._apply_type_conversion()
|
61
72
|
|
62
|
-
def _load_defaults(self) -> dict[str,
|
73
|
+
def _load_defaults(self) -> dict[str, Any]:
    """Seed the configuration with the built-in default values.

    Takes each ``"value"`` entry of ``DEFAULT_CONFIG_MAP`` as-is. These are not all
    strings (floats, ints and ``None`` appear, e.g. TOP_P, MAX_TOKENS, REASONING_EFFORT);
    types are normalized later by ``_apply_type_conversion``.

    Returns:
        Dictionary with default configuration values (also merged into ``self``)
    """
    defaults = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}
    # Merge into the config dict itself so later loaders (file, env) override these
    self.update(defaults)
    return defaults
|
69
82
|
|
70
83
|
def _ensure_version_updated_config_keys(self):
|
71
84
|
"""Ensure configuration keys added in version updates exist in the config file.
|
@@ -75,14 +88,6 @@ class Config(dict):
|
|
75
88
|
config_content = f.read()
|
76
89
|
if "CHAT_HISTORY_DIR" not in config_content.strip(): # Check for empty lines
|
77
90
|
f.write(f"\nCHAT_HISTORY_DIR={DEFAULT_CHAT_HISTORY_DIR}\n")
|
78
|
-
if "MAX_SAVED_CHATS" not in config_content.strip(): # Check for empty lines
|
79
|
-
f.write(f"\nMAX_SAVED_CHATS={DEFAULT_MAX_SAVED_CHATS}\n")
|
80
|
-
if "JUSTIFY" not in config_content.strip():
|
81
|
-
f.write(f"\nJUSTIFY={DEFAULT_JUSTIFY}\n")
|
82
|
-
if "ROLE_MODIFY_WARNING" not in config_content.strip():
|
83
|
-
f.write(
|
84
|
-
f"\n# Set to false to disable warnings about modified built-in roles\nROLE_MODIFY_WARNING={DEFAULT_ROLE_MODIFY_WARNING}\n"
|
85
|
-
)
|
86
91
|
|
87
92
|
def _load_from_file(self) -> None:
|
88
93
|
"""Load configuration from the config file.
|
@@ -90,7 +95,7 @@ class Config(dict):
|
|
90
95
|
Creates default config file if it doesn't exist.
|
91
96
|
"""
|
92
97
|
if not CONFIG_PATH.exists():
|
93
|
-
self.console.print("Creating default configuration file.", style="bold yellow", justify=self
|
98
|
+
self.console.print("Creating default configuration file.", style="bold yellow", justify=self["JUSTIFY"])
|
94
99
|
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
95
100
|
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
96
101
|
f.write(DEFAULT_CONFIG_INI)
|
@@ -107,7 +112,7 @@ class Config(dict):
|
|
107
112
|
if not config_parser["core"].get(k, "").strip():
|
108
113
|
config_parser["core"][k] = v
|
109
114
|
|
110
|
-
self.update(config_parser["core"])
|
115
|
+
self.update(dict(config_parser["core"]))
|
111
116
|
|
112
117
|
# Check if keys added in version updates are missing and add them
|
113
118
|
self._ensure_version_updated_config_keys()
|
@@ -128,24 +133,26 @@ class Config(dict):
|
|
128
133
|
Updates the configuration dictionary in-place with properly typed values.
|
129
134
|
Falls back to default values if conversion fails.
|
130
135
|
"""
|
131
|
-
default_values_str =
|
136
|
+
default_values_str = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}
|
132
137
|
|
133
138
|
for key, config_info in DEFAULT_CONFIG_MAP.items():
|
134
139
|
target_type = config_info["type"]
|
135
|
-
raw_value = self
|
140
|
+
raw_value = self[key]
|
136
141
|
converted_value = None
|
137
142
|
|
138
143
|
try:
|
144
|
+
if raw_value is None:
|
145
|
+
raw_value = default_values_str.get(key, "")
|
139
146
|
if target_type is bool:
|
140
147
|
converted_value = str2bool(raw_value)
|
141
148
|
elif target_type in (int, float, str):
|
142
149
|
converted_value = target_type(raw_value)
|
143
150
|
except (ValueError, TypeError) as e:
|
144
151
|
self.console.print(
|
145
|
-
f"[yellow]Warning:[/
|
152
|
+
f"[yellow]Warning:[/] Invalid value '{raw_value}' for '{key}'. "
|
146
153
|
f"Expected type '{target_type.__name__}'. Using default value '{default_values_str[key]}'. Error: {e}",
|
147
154
|
style="dim",
|
148
|
-
justify=self
|
155
|
+
justify=self["JUSTIFY"],
|
149
156
|
)
|
150
157
|
# Fallback to default string value if conversion fails
|
151
158
|
try:
|
@@ -158,7 +165,7 @@ class Config(dict):
|
|
158
165
|
self.console.print(
|
159
166
|
f"[red]Error:[/red] Could not convert default value for '{key}'. Using raw value.",
|
160
167
|
style="error",
|
161
|
-
justify=self
|
168
|
+
justify=self["JUSTIFY"],
|
162
169
|
)
|
163
170
|
converted_value = raw_value # Or assign a hardcoded safe default
|
164
171
|
|
yaicli/console.py
CHANGED
@@ -3,8 +3,8 @@ from typing import Any, Optional, Union
|
|
3
3
|
from rich.console import Console, JustifyMethod, OverflowMethod
|
4
4
|
from rich.style import Style
|
5
5
|
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from .config import cfg
|
7
|
+
from .const import DEFAULT_JUSTIFY
|
8
8
|
|
9
9
|
_console = None
|
10
10
|
|
yaicli/const.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from enum import StrEnum
|
2
2
|
from pathlib import Path
|
3
3
|
from tempfile import gettempdir
|
4
|
-
from typing import Any, Literal
|
4
|
+
from typing import Any, Literal, Optional
|
5
5
|
|
6
6
|
from rich.console import JustifyMethod
|
7
7
|
|
@@ -24,14 +24,17 @@ CMD_SAVE_CHAT = "/save"
|
|
24
24
|
CMD_LOAD_CHAT = "/load"
|
25
25
|
CMD_LIST_CHATS = "/list"
|
26
26
|
CMD_DELETE_CHAT = "/del"
|
27
|
+
CMD_HELP = ("/help", "?")
|
27
28
|
|
28
29
|
EXEC_MODE = "exec"
|
29
30
|
CHAT_MODE = "chat"
|
30
31
|
TEMP_MODE = "temp"
|
31
32
|
CODE_MODE = "code"
|
32
33
|
|
34
|
+
HISTORY_FILE = Path("~/.yaicli_history").expanduser()
|
33
35
|
CONFIG_PATH = Path("~/.config/yaicli/config.ini").expanduser()
|
34
36
|
ROLES_DIR = CONFIG_PATH.parent / "roles"
|
37
|
+
FUNCTIONS_DIR = CONFIG_PATH.parent / "functions"
|
35
38
|
|
36
39
|
# Default configuration values
|
37
40
|
DEFAULT_CODE_THEME = "monokai"
|
@@ -41,7 +44,7 @@ DEFAULT_MODEL = "gpt-4o"
|
|
41
44
|
DEFAULT_SHELL_NAME = "auto"
|
42
45
|
DEFAULT_OS_NAME = "auto"
|
43
46
|
DEFAULT_STREAM: BOOL_STR = "true"
|
44
|
-
DEFAULT_TEMPERATURE: float = 0.
|
47
|
+
DEFAULT_TEMPERATURE: float = 0.5
|
45
48
|
DEFAULT_TOP_P: float = 1.0
|
46
49
|
DEFAULT_MAX_TOKENS: int = 1024
|
47
50
|
DEFAULT_MAX_HISTORY: int = 500
|
@@ -53,6 +56,9 @@ DEFAULT_CHAT_HISTORY_DIR: Path = Path(gettempdir()) / "yaicli/chats"
|
|
53
56
|
DEFAULT_MAX_SAVED_CHATS = 20
|
54
57
|
DEFAULT_JUSTIFY: JustifyMethod = "default"
|
55
58
|
DEFAULT_ROLE_MODIFY_WARNING: BOOL_STR = "true"
|
59
|
+
DEFAULT_ENABLE_FUNCTIONS: BOOL_STR = "true"
|
60
|
+
DEFAULT_SHOW_FUNCTION_OUTPUT: BOOL_STR = "true"
|
61
|
+
DEFAULT_REASONING_EFFORT: Optional[Literal["low", "high", "medium"]] = None
|
56
62
|
|
57
63
|
|
58
64
|
class EventTypeEnum(StrEnum):
|
@@ -63,6 +69,11 @@ class EventTypeEnum(StrEnum):
|
|
63
69
|
REASONING_END = "reasoning_end"
|
64
70
|
CONTENT = "content"
|
65
71
|
FINISH = "finish"
|
72
|
+
TOOL_CALL_START = "tool_call_start"
|
73
|
+
TOOL_CALL_DELTA = "tool_call_delta"
|
74
|
+
TOOL_CALL_END = "tool_call_end"
|
75
|
+
TOOL_RESULT = "tool_result"
|
76
|
+
TOOL_CALLS_FINISH = "tool_calls_finish"
|
66
77
|
|
67
78
|
|
68
79
|
SHELL_PROMPT = """You are YAICLI, a shell command generator.
|
@@ -119,6 +130,7 @@ DEFAULT_CONFIG_MAP = {
|
|
119
130
|
"TOP_P": {"value": DEFAULT_TOP_P, "env_key": "YAI_TOP_P", "type": float},
|
120
131
|
"MAX_TOKENS": {"value": DEFAULT_MAX_TOKENS, "env_key": "YAI_MAX_TOKENS", "type": int},
|
121
132
|
"TIMEOUT": {"value": DEFAULT_TIMEOUT, "env_key": "YAI_TIMEOUT", "type": int},
|
133
|
+
"REASONING_EFFORT": {"value": DEFAULT_REASONING_EFFORT, "env_key": "YAI_REASONING_EFFORT", "type": str},
|
122
134
|
"INTERACTIVE_ROUND": {
|
123
135
|
"value": DEFAULT_INTERACTIVE_ROUND,
|
124
136
|
"env_key": "YAI_INTERACTIVE_ROUND",
|
@@ -135,6 +147,13 @@ DEFAULT_CONFIG_MAP = {
|
|
135
147
|
"MAX_SAVED_CHATS": {"value": DEFAULT_MAX_SAVED_CHATS, "env_key": "YAI_MAX_SAVED_CHATS", "type": int},
|
136
148
|
# Role settings
|
137
149
|
"ROLE_MODIFY_WARNING": {"value": DEFAULT_ROLE_MODIFY_WARNING, "env_key": "YAI_ROLE_MODIFY_WARNING", "type": bool},
|
150
|
+
# Function settings
|
151
|
+
"ENABLE_FUNCTIONS": {"value": DEFAULT_ENABLE_FUNCTIONS, "env_key": "YAI_ENABLE_FUNCTIONS", "type": bool},
|
152
|
+
"SHOW_FUNCTION_OUTPUT": {
|
153
|
+
"value": DEFAULT_SHOW_FUNCTION_OUTPUT,
|
154
|
+
"env_key": "YAI_SHOW_FUNCTION_OUTPUT",
|
155
|
+
"type": bool,
|
156
|
+
},
|
138
157
|
}
|
139
158
|
|
140
159
|
DEFAULT_CONFIG_INI = f"""[core]
|
@@ -155,6 +174,7 @@ TEMPERATURE={DEFAULT_CONFIG_MAP["TEMPERATURE"]["value"]}
|
|
155
174
|
TOP_P={DEFAULT_CONFIG_MAP["TOP_P"]["value"]}
|
156
175
|
MAX_TOKENS={DEFAULT_CONFIG_MAP["MAX_TOKENS"]["value"]}
|
157
176
|
TIMEOUT={DEFAULT_CONFIG_MAP["TIMEOUT"]["value"]}
|
177
|
+
REASONING_EFFORT=
|
158
178
|
|
159
179
|
# Interactive mode parameters
|
160
180
|
INTERACTIVE_ROUND={DEFAULT_CONFIG_MAP["INTERACTIVE_ROUND"]["value"]}
|
@@ -176,4 +196,10 @@ MAX_SAVED_CHATS={DEFAULT_CONFIG_MAP["MAX_SAVED_CHATS"]["value"]}
|
|
176
196
|
# Role settings
|
177
197
|
# Set to false to disable warnings about modified built-in roles
|
178
198
|
ROLE_MODIFY_WARNING={DEFAULT_CONFIG_MAP["ROLE_MODIFY_WARNING"]["value"]}
|
199
|
+
|
200
|
+
# Function settings
|
201
|
+
# Set to false to disable sending functions in API requests
|
202
|
+
ENABLE_FUNCTIONS={DEFAULT_CONFIG_MAP["ENABLE_FUNCTIONS"]["value"]}
|
203
|
+
# Set to false to disable showing function output in the response
|
204
|
+
SHOW_FUNCTION_OUTPUT={DEFAULT_CONFIG_MAP["SHOW_FUNCTION_OUTPUT"]["value"]}
|
179
205
|
"""
|