yaicli 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +6 -3
- yaicli/chat.py +396 -0
- yaicli/cli.py +251 -251
- yaicli/client.py +385 -0
- yaicli/config.py +32 -20
- yaicli/console.py +2 -2
- yaicli/const.py +46 -21
- yaicli/entry.py +68 -39
- yaicli/exceptions.py +8 -36
- yaicli/functions/__init__.py +39 -0
- yaicli/functions/buildin/execute_shell_command.py +47 -0
- yaicli/printer.py +145 -225
- yaicli/render.py +1 -1
- yaicli/role.py +231 -0
- yaicli/schemas.py +31 -0
- yaicli/tools.py +103 -0
- yaicli/utils.py +5 -2
- {yaicli-0.3.3.dist-info → yaicli-0.5.0.dist-info}/METADATA +172 -132
- yaicli-0.5.0.dist-info/RECORD +24 -0
- {yaicli-0.3.3.dist-info → yaicli-0.5.0.dist-info}/entry_points.txt +1 -1
- yaicli/api.py +0 -316
- yaicli/chat_manager.py +0 -290
- yaicli/roles.py +0 -248
- yaicli-0.3.3.dist-info/RECORD +0 -20
- {yaicli-0.3.3.dist-info → yaicli-0.5.0.dist-info}/WHEEL +0 -0
- {yaicli-0.3.3.dist-info → yaicli-0.5.0.dist-info}/licenses/LICENSE +0 -0
yaicli/client.py
ADDED
@@ -0,0 +1,385 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from typing import Any, Dict, Generator, List, Optional, Union, cast
|
4
|
+
|
5
|
+
import litellm
|
6
|
+
from json_repair import repair_json
|
7
|
+
from litellm.types.utils import Choices
|
8
|
+
from litellm.types.utils import Message as ChoiceMessage
|
9
|
+
from litellm.types.utils import ModelResponse
|
10
|
+
from rich.panel import Panel
|
11
|
+
|
12
|
+
from .config import cfg
|
13
|
+
from .console import get_console
|
14
|
+
from .schemas import LLMResponse, ChatMessage, ToolCall
|
15
|
+
from .tools import (
|
16
|
+
Function,
|
17
|
+
FunctionName,
|
18
|
+
get_function,
|
19
|
+
get_openai_schemas,
|
20
|
+
list_functions,
|
21
|
+
)
|
22
|
+
|
23
|
+
litellm.drop_params = True
|
24
|
+
console = get_console()
|
25
|
+
|
26
|
+
|
27
|
+
class RefreshLive:
    """Stateless marker object.

    ``LitellmClient.completion`` yields an instance right before it executes a
    tool call, to let the live display know that the next run is starting.
    """
|
29
|
+
|
30
|
+
|
31
|
+
class StopLive:
    """Stateless marker object meant to tell the consumer to stop the live display.

    NOTE(review): the only use visible in this module is a commented-out
    ``yield StopLive()``; confirm whether other modules rely on it.
    """
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
|
36
|
+
class LitellmClient:
|
37
|
+
"""OpenAI provider implementation"""
|
38
|
+
|
39
|
+
api_key: str = field(default_factory=lambda: cfg["API_KEY"])
|
40
|
+
model: str = field(default_factory=lambda: f"{cfg['PROVIDER']}/{cfg['MODEL']}")
|
41
|
+
base_url: Optional[str] = field(default_factory=lambda: cfg["BASE_URL"])
|
42
|
+
timeout: int = field(default_factory=lambda: cfg["TIMEOUT"])
|
43
|
+
|
44
|
+
verbose: bool = False
|
45
|
+
|
46
|
+
def __post_init__(self) -> None:
|
47
|
+
"""Initialize OpenAI client"""
|
48
|
+
self.pre_tool_call_id = None
|
49
|
+
|
50
|
+
def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
|
51
|
+
"""Convert message format to OpenAI API required format"""
|
52
|
+
openai_messages = []
|
53
|
+
for msg in messages:
|
54
|
+
if msg.tool_call_id:
|
55
|
+
openai_messages.append(
|
56
|
+
{"role": msg.role, "content": msg.content, "tool_call_id": msg.tool_call_id, "name": msg.name}
|
57
|
+
)
|
58
|
+
else:
|
59
|
+
openai_messages.append({"role": msg.role, "content": msg.content})
|
60
|
+
return openai_messages
|
61
|
+
|
62
|
+
def _convert_functions(self, _: List[Function]) -> List[Dict[str, Any]]:
    """Return the tool schemas to attach to a request.

    The argument is accepted but ignored: the schemas always come from
    ``get_openai_schemas()``.
    """
    schemas = get_openai_schemas()
    return schemas
|
65
|
+
|
66
|
+
def _execute_tool_call(self, tool_call: ToolCall) -> tuple[str, bool]:
    """Look up, decode arguments for, and run the function named by ``tool_call``.

    Args:
        tool_call: The tool call to run; its ``name`` and raw ``arguments`` are used.

    Returns:
        ``(text, ok)``. On success ``text`` is what the function returned and ``ok``
        is True. On any failure ``text`` is the error message that was printed and
        ``ok`` is False; no exception is propagated for lookup, decoding or
        execution errors.
    """

    def fail(message: str) -> tuple[str, bool]:
        # Every failure path prints the message in red and hands it back to the caller.
        console.print(message, style="red")
        return message, False

    console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")

    # Step 1: resolve the function by name.
    try:
        target = get_function(FunctionName(tool_call.name))
    except ValueError as exc:
        return fail(f"Function '{tool_call.name!r}' not exists: {exc}")

    # Step 2: decode the raw argument string; only a JSON object is usable as kwargs.
    try:
        kwargs = repair_json(tool_call.arguments, return_objects=True)
        if not isinstance(kwargs, dict):
            return fail(f"Invalid arguments type: {kwargs!r}, should be JSON object")
        kwargs = cast(dict, kwargs)
    except Exception as exc:
        return fail(f"Invalid arguments from llm: {exc}\nRaw arguments: {tool_call.arguments!r}")

    # Step 3: run the function and optionally show its output in a panel.
    try:
        output = target.execute(**kwargs)
        if cfg["SHOW_FUNCTION_OUTPUT"]:
            console.print(
                Panel(
                    output,
                    title="Function output",
                    title_align="left",
                    expand=False,
                    border_style="blue",
                    style="dim",
                )
            )
        return output, True
    except Exception as exc:
        return fail(f"Call function error: {exc}\nFunction name: {tool_call.name!r}\nArguments: {kwargs!r}")
|
109
|
+
|
110
|
+
def completion(
    self,
    messages: List[ChatMessage],
    stream: bool = False,
    recursion_depth: int = 0,
) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
    """Send ``messages`` through litellm and yield the response piece by piece.

    When a yielded response carries a tool call, the tool is executed, its result
    is appended to ``messages`` (the caller's list is mutated in place) and the
    request is re-issued recursively so the model can use the result.

    Args:
        messages: Conversation so far; extended with tool results.
        stream: Whether to request and handle a streaming response.
        recursion_depth: Current nesting level of tool-call follow-ups; once it
            reaches 5 a warning is printed and the generator stops.

    Yields:
        ``LLMResponse`` objects, plus a ``RefreshLive`` marker before each tool run.
    """
    if self.verbose:
        console.print(messages)
    payload = self._convert_messages(messages)

    # Request parameters passed straight to litellm.completion.
    request: Dict[str, Any] = {
        "model": self.model,
        "messages": payload,
        "temperature": cfg["TEMPERATURE"],
        "top_p": cfg["TOP_P"],
        "stream": stream,
        # Openai: This value is now deprecated in favor of max_completion_tokens.
        "max_tokens": cfg["MAX_TOKENS"],
        "max_completion_tokens": cfg["MAX_TOKENS"],
        # litellm api params
        "api_key": self.api_key,
        "base_url": self.base_url,
        "reasoning_effort": cfg["REASONING_EFFORT"],
    }
    if cfg["ENABLE_FUNCTIONS"]:
        request.update(
            tools=self._convert_functions(list_functions()),
            tool_choice="auto",
            parallel_tool_calls=False,
        )

    raw = litellm.completion(**request)
    if stream:
        replies = self._handle_stream_response(cast(litellm.CustomStreamWrapper, raw))
    else:
        replies = self._handle_normal_response(cast(ModelResponse, raw))

    for reply in replies:
        yield reply
        call = reply.tool_call
        if not call:
            continue

        # Remember the first tool call id; a later response repeating that id is skipped.
        if not self.pre_tool_call_id:
            self.pre_tool_call_id = call.id
        elif self.pre_tool_call_id == call.id:
            continue

        # Let live display know we are in next run
        yield RefreshLive()

        result, _ok = self._execute_tool_call(call)
        messages.append(
            ChatMessage(
                role=self.detect_tool_role(cfg["PROVIDER"]),
                content=result,
                name=call.name,
                tool_call_id=call.id,
            )
        )

        if recursion_depth >= 5:
            console.print("Maximum recursion depth (5) reached, stopping further tool calls", style="yellow")
            return

        # Same follow-up call for streaming and non-streaming requests.
        yield from self.completion(messages, stream=stream, recursion_depth=recursion_depth + 1)
    # yield StopLive()
|
183
|
+
|
184
|
+
def stream_completion(self, messages: List[ChatMessage], stream: bool = True) -> Generator[LLMResponse, None, None]:
    """Send ``messages`` through litellm and yield streamed ``LLMResponse`` chunks.

    When a chunk carries a tool call, the tool is executed, its result is appended
    to ``messages`` (mutated in place) and the request is re-issued recursively.

    NOTE(review): ``stream`` is forwarded to litellm, but the response is always
    consumed with the streaming handler, and the recursive follow-up call does not
    pass ``stream`` on. Unlike ``completion`` there is no recursion depth limit here.
    """
    payload = self._convert_messages(messages)
    request: Dict[str, Any] = {
        "model": self.model,
        "messages": payload,
        "temperature": cfg["TEMPERATURE"],
        "top_p": cfg["TOP_P"],
        "stream": stream,
        # Openai: This value is now deprecated in favor of max_completion_tokens.
        "max_tokens": cfg["MAX_TOKENS"],
        "max_completion_tokens": cfg["MAX_TOKENS"],
        # litellm api params
        "api_key": self.api_key,
        "base_url": self.base_url,
    }
    if cfg["ENABLE_FUNCTIONS"]:
        request.update(
            tools=self._convert_functions(list_functions()),
            tool_choice="auto",
            parallel_tool_calls=False,
        )

    raw = litellm.completion(**request)
    chunks = self._handle_stream_response(cast(litellm.CustomStreamWrapper, raw))
    for piece in chunks:
        yield piece
        call = piece.tool_call
        if not call:
            continue

        # Remember the first tool call id; a later chunk repeating that id is skipped.
        if not self.pre_tool_call_id:
            self.pre_tool_call_id = call.id
        elif self.pre_tool_call_id == call.id:
            continue

        result, _ok = self._execute_tool_call(call)
        messages.append(
            ChatMessage(
                role=self.detect_tool_role(cfg["PROVIDER"]),
                content=result,
                name=call.name,
                tool_call_id=call.id,
            )
        )

        yield from self.stream_completion(messages)
|
231
|
+
|
232
|
+
def _handle_normal_response(self, response: ModelResponse) -> Generator[LLMResponse, None, None]:
    """Handle a normal (non-streaming) response.

    Only the first choice is inspected. Reasoning text is taken, in order of
    precedence, from a leading ``<think>...</think>`` block in the content, from
    the message's ``model_extra``, or from the message's ``reasoning_content``.

    Args:
        response: The complete model response returned by litellm.

    Yields:
        Exactly one ``LLMResponse`` with:
            - reasoning: the thinking/reasoning content (if any)
            - content: the normal response content (``<think>`` block removed)
            - finish_reason: the choice's finish reason
            - tool_call: the first tool call when ``finish_reason`` is ``"tool_calls"``
    """
    open_tag, close_tag = "<think>", "</think>"

    choice = response.choices[0]
    message = choice.message  # type: ignore
    content = message.content or ""
    # Read with a default so a message without this attribute does not raise AttributeError.
    reasoning = getattr(message, "reasoning_content", None)
    finish_reason = choice.finish_reason
    tool_call: Optional[ToolCall] = None

    if open_tag in content and close_tag in content:
        # Reasoning embedded in the content: only a block at the very start is extracted.
        content = content.lstrip()
        if content.startswith(open_tag):
            think_end = content.find(close_tag)
            if think_end != -1:
                reasoning = content[len(open_tag) : think_end].strip()
                # Remove the <think> block from the main content
                content = content[think_end + len(close_tag) :].strip()
    else:
        model_extra = getattr(message, "model_extra", None)
        if model_extra:
            # Only override when model_extra actually provides reasoning; otherwise
            # keep what reasoning_content already gave us instead of resetting it to None.
            reasoning = self._get_reasoning_content(model_extra) or reasoning

    if finish_reason == "tool_calls":
        if '{"index":' in content or '"tool_calls":' in content:
            # Tool call data may be in content after the <think> block, e.g.
            # {"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
            tool_index = content.find('{"index":')
            if tool_index != -1:
                try:
                    choice = self.parse_choice_from_content(content[tool_index:])
                except ValueError:
                    # Best effort: fall back to the structured tool_calls of the original choice.
                    pass
        tool_calls = getattr(getattr(choice, "message", None), "tool_calls", None)
        if tool_calls:
            tool = tool_calls[0]
            tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)

    yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
|
277
|
+
|
278
|
+
def _handle_stream_response(self, response: litellm.CustomStreamWrapper) -> Generator[LLMResponse, None, None]:
    """Handle a streaming response.

    Args:
        response: Iterable of streamed chunks; the first choice of each chunk is used.

    Yields:
        One ``LLMResponse`` per chunk with:
            - reasoning: the reasoning content of this chunk (if any)
            - content: the content delta of this chunk
            - tool_call: the tool call accumulated so far (id, name, concatenated arguments),
              or None while no tool call data has been seen
            - finish_reason: the chunk's finish reason
    """
    full_content = ""
    tool_id = ""
    tool_call_name = ""
    arguments = ""
    tool_call: Optional[ToolCall] = None
    for chunk in response:
        choice = chunk.choices[0]  # type: ignore
        delta = choice.delta
        finish_reason = choice.finish_reason

        # Per-chunk content, also accumulated so embedded tool-call JSON can be located.
        content = delta.content or ""
        full_content += content

        reasoning = self._get_reasoning_content(delta)

        if finish_reason == "tool_calls" or '{"index":' in content or '"tool_calls":' in content:
            # Tool call data may be in content after the <think> block, e.g.
            # {"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
            tool_index = full_content.find('{"index":')
            if tool_index != -1:
                try:
                    choice = self.parse_choice_from_content(full_content[tool_index:])
                except ValueError:
                    # Not (yet) valid JSON: keep using the structured delta of this chunk.
                    pass

        # `choice` may have been replaced by a parsed object, so read `delta` defensively.
        delta_tool_calls = getattr(getattr(choice, "delta", None), "tool_calls", None)
        if delta_tool_calls:
            # Continuation chunks may carry an empty id/name with only an argument
            # fragment; keep the values captured earlier instead of blanking them.
            tool_id = delta_tool_calls[0].id or tool_id
            for tool in delta_tool_calls:
                if not tool.function:
                    continue
                tool_call_name = tool.function.name or tool_call_name
                arguments += tool.function.arguments or ""
            tool_call = ToolCall(tool_id, tool_call_name, arguments)
        yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
|
330
|
+
|
331
|
+
def _get_reasoning_content(self, delta: Any) -> Optional[str]:
|
332
|
+
"""Extract reasoning content from delta if available based on specific keys.
|
333
|
+
|
334
|
+
This method checks for various keys that might contain reasoning content
|
335
|
+
in different API implementations.
|
336
|
+
|
337
|
+
Args:
|
338
|
+
delta: The delta/model_extra from the API response
|
339
|
+
|
340
|
+
Returns:
|
341
|
+
The reasoning content string if found, None otherwise
|
342
|
+
"""
|
343
|
+
if not delta:
|
344
|
+
return None
|
345
|
+
# Reasoning content keys from API:
|
346
|
+
# reasoning_content: deepseek/infi-ai
|
347
|
+
# reasoning: openrouter
|
348
|
+
# <think> block implementation not in here
|
349
|
+
for key in ("reasoning_content", "reasoning"):
|
350
|
+
# Check if the key exists and its value is a non-empty string
|
351
|
+
if hasattr(delta, key):
|
352
|
+
return getattr(delta, key)
|
353
|
+
|
354
|
+
return None # Return None if no relevant key with a string value is found
|
355
|
+
|
356
|
+
def parse_choice_from_content(self, content: str) -> Choices:
    """Parse a choice object from JSON text found in the content after a <think>...</think> block.

    Args:
        content: JSON text from the LLM response; it must decode to a JSON object.
            A ``"delta"`` entry, when present, is validated as a message first.

    Returns:
        The validated choice object.

    Raises:
        ValueError: If the content is not valid JSON, is not a JSON object, or
            fails validation. The underlying error is chained as ``__cause__``
            where one exists.
    """
    error_text = f"Invalid message from LLM: {content}"
    try:
        content_dict = json.loads(content)
    except json.JSONDecodeError as e:
        raise ValueError(error_text) from e
    # Valid JSON that is not an object (number, list, ...) must also surface as ValueError,
    # since callers only catch ValueError.
    if not isinstance(content_dict, dict):
        raise ValueError(error_text)
    try:
        if "delta" in content_dict:
            content_dict["delta"] = ChoiceMessage.model_validate(content_dict["delta"])
        return Choices.model_validate(content_dict)
    except Exception as e:
        raise ValueError(error_text) from e
|
379
|
+
|
380
|
+
def detect_tool_role(self, provider: str) -> str:
    """Return the message role used for tool results: "user" for the "gemini" provider, "tool" otherwise."""
    return "user" if provider == "gemini" else "tool"
|
yaicli/config.py
CHANGED
@@ -1,20 +1,21 @@
|
|
1
1
|
import configparser
|
2
|
+
from dataclasses import dataclass
|
2
3
|
from functools import lru_cache
|
3
4
|
from os import getenv
|
4
|
-
from typing import Optional
|
5
|
+
from typing import Any, Optional
|
5
6
|
|
6
7
|
from rich import get_console
|
7
8
|
from rich.console import Console
|
8
9
|
|
9
|
-
from
|
10
|
+
from .const import (
|
10
11
|
CONFIG_PATH,
|
11
12
|
DEFAULT_CHAT_HISTORY_DIR,
|
12
13
|
DEFAULT_CONFIG_INI,
|
13
14
|
DEFAULT_CONFIG_MAP,
|
14
|
-
|
15
|
-
|
15
|
+
DEFAULT_TEMPERATURE,
|
16
|
+
DEFAULT_TOP_P,
|
16
17
|
)
|
17
|
-
from
|
18
|
+
from .utils import str2bool
|
18
19
|
|
19
20
|
|
20
21
|
class CasePreservingConfigParser(configparser.RawConfigParser):
|
@@ -24,6 +25,17 @@ class CasePreservingConfigParser(configparser.RawConfigParser):
|
|
24
25
|
return optionstr
|
25
26
|
|
26
27
|
|
28
|
+
@dataclass
class ProviderConfig:
    """Provider configuration record.

    ``api_key`` and ``model`` are required; the other fields have defaults.
    """

    # Required fields (no default).
    api_key: str
    model: str
    # Optional fields; the numeric defaults come from the constants module.
    base_url: Optional[str] = None
    temperature: float = DEFAULT_TEMPERATURE
    top_p: float = DEFAULT_TOP_P
|
37
|
+
|
38
|
+
|
27
39
|
class Config(dict):
|
28
40
|
"""Configuration class that loads settings on initialization.
|
29
41
|
|
@@ -49,7 +61,7 @@ class Config(dict):
|
|
49
61
|
"""
|
50
62
|
# Start with defaults
|
51
63
|
self.clear()
|
52
|
-
self.
|
64
|
+
self._load_defaults()
|
53
65
|
|
54
66
|
# Load from config file
|
55
67
|
self._load_from_file()
|
@@ -58,13 +70,15 @@ class Config(dict):
|
|
58
70
|
self._load_from_env()
|
59
71
|
self._apply_type_conversion()
|
60
72
|
|
61
|
-
def _load_defaults(self) -> dict[str,
|
73
|
+
def _load_defaults(self) -> dict[str, Any]:
    """Seed this config with the built-in default values.

    Returns:
        A new dict mapping every key of ``DEFAULT_CONFIG_MAP`` to that entry's
        ``"value"``; the same pairs are also merged into ``self``.
    """
    defaults: dict[str, Any] = {}
    for key, info in DEFAULT_CONFIG_MAP.items():
        defaults[key] = info["value"]
    self.update(defaults)
    return defaults
|
68
82
|
|
69
83
|
def _ensure_version_updated_config_keys(self):
|
70
84
|
"""Ensure configuration keys added in version updates exist in the config file.
|
@@ -74,10 +88,6 @@ class Config(dict):
|
|
74
88
|
config_content = f.read()
|
75
89
|
if "CHAT_HISTORY_DIR" not in config_content.strip(): # Check for empty lines
|
76
90
|
f.write(f"\nCHAT_HISTORY_DIR={DEFAULT_CHAT_HISTORY_DIR}\n")
|
77
|
-
if "MAX_SAVED_CHATS" not in config_content.strip(): # Check for empty lines
|
78
|
-
f.write(f"\nMAX_SAVED_CHATS={DEFAULT_MAX_SAVED_CHATS}\n")
|
79
|
-
if "JUSTIFY" not in config_content.strip():
|
80
|
-
f.write(f"\nJUSTIFY={DEFAULT_JUSTIFY}\n")
|
81
91
|
|
82
92
|
def _load_from_file(self) -> None:
|
83
93
|
"""Load configuration from the config file.
|
@@ -85,7 +95,7 @@ class Config(dict):
|
|
85
95
|
Creates default config file if it doesn't exist.
|
86
96
|
"""
|
87
97
|
if not CONFIG_PATH.exists():
|
88
|
-
self.console.print("Creating default configuration file.", style="bold yellow", justify=self
|
98
|
+
self.console.print("Creating default configuration file.", style="bold yellow", justify=self["JUSTIFY"])
|
89
99
|
CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
90
100
|
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
91
101
|
f.write(DEFAULT_CONFIG_INI)
|
@@ -102,7 +112,7 @@ class Config(dict):
|
|
102
112
|
if not config_parser["core"].get(k, "").strip():
|
103
113
|
config_parser["core"][k] = v
|
104
114
|
|
105
|
-
self.update(config_parser["core"])
|
115
|
+
self.update(dict(config_parser["core"]))
|
106
116
|
|
107
117
|
# Check if keys added in version updates are missing and add them
|
108
118
|
self._ensure_version_updated_config_keys()
|
@@ -123,24 +133,26 @@ class Config(dict):
|
|
123
133
|
Updates the configuration dictionary in-place with properly typed values.
|
124
134
|
Falls back to default values if conversion fails.
|
125
135
|
"""
|
126
|
-
default_values_str =
|
136
|
+
default_values_str = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}
|
127
137
|
|
128
138
|
for key, config_info in DEFAULT_CONFIG_MAP.items():
|
129
139
|
target_type = config_info["type"]
|
130
|
-
raw_value = self
|
140
|
+
raw_value = self[key]
|
131
141
|
converted_value = None
|
132
142
|
|
133
143
|
try:
|
144
|
+
if raw_value is None:
|
145
|
+
raw_value = default_values_str.get(key, "")
|
134
146
|
if target_type is bool:
|
135
147
|
converted_value = str2bool(raw_value)
|
136
148
|
elif target_type in (int, float, str):
|
137
149
|
converted_value = target_type(raw_value)
|
138
150
|
except (ValueError, TypeError) as e:
|
139
151
|
self.console.print(
|
140
|
-
f"[yellow]Warning:[/
|
152
|
+
f"[yellow]Warning:[/] Invalid value '{raw_value}' for '{key}'. "
|
141
153
|
f"Expected type '{target_type.__name__}'. Using default value '{default_values_str[key]}'. Error: {e}",
|
142
154
|
style="dim",
|
143
|
-
justify=self
|
155
|
+
justify=self["JUSTIFY"],
|
144
156
|
)
|
145
157
|
# Fallback to default string value if conversion fails
|
146
158
|
try:
|
@@ -153,14 +165,14 @@ class Config(dict):
|
|
153
165
|
self.console.print(
|
154
166
|
f"[red]Error:[/red] Could not convert default value for '{key}'. Using raw value.",
|
155
167
|
style="error",
|
156
|
-
justify=self
|
168
|
+
justify=self["JUSTIFY"],
|
157
169
|
)
|
158
170
|
converted_value = raw_value # Or assign a hardcoded safe default
|
159
171
|
|
160
172
|
self[key] = converted_value
|
161
173
|
|
162
174
|
|
163
|
-
@lru_cache(
|
175
|
+
@lru_cache(maxsize=1)
def get_config() -> Config:
    """Return the shared ``Config`` instance; it is built on the first call and cached afterwards."""
    settings = Config()
    return settings
|
yaicli/console.py
CHANGED
@@ -3,8 +3,8 @@ from typing import Any, Optional, Union
|
|
3
3
|
from rich.console import Console, JustifyMethod, OverflowMethod
|
4
4
|
from rich.style import Style
|
5
5
|
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from .config import cfg
|
7
|
+
from .const import DEFAULT_JUSTIFY
|
8
8
|
|
9
9
|
_console = None
|
10
10
|
|