yaicli 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
yaicli/client.py ADDED
@@ -0,0 +1,385 @@
+ import json
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, Generator, List, Optional, Union, cast
+
+ import litellm
+ from json_repair import repair_json
+ from litellm.types.utils import Choices
+ from litellm.types.utils import Message as ChoiceMessage
+ from litellm.types.utils import ModelResponse
+ from rich.panel import Panel
+
+ from .config import cfg
+ from .console import get_console
+ from .schemas import LLMResponse, ChatMessage, ToolCall
+ from .tools import (
+     Function,
+     FunctionName,
+     get_function,
+     get_openai_schemas,
+     list_functions,
+ )
+
+ litellm.drop_params = True
+ console = get_console()
+
+
+ class RefreshLive:
+     """Refresh live display"""
+
+
+ class StopLive:
+     """Stop live display"""
+
+
+ @dataclass
+ class LitellmClient:
+     """OpenAI provider implementation"""
+
+     api_key: str = field(default_factory=lambda: cfg["API_KEY"])
+     model: str = field(default_factory=lambda: f"{cfg['PROVIDER']}/{cfg['MODEL']}")
+     base_url: Optional[str] = field(default_factory=lambda: cfg["BASE_URL"])
+     timeout: int = field(default_factory=lambda: cfg["TIMEOUT"])
+
+     verbose: bool = False
+
+     def __post_init__(self) -> None:
+         """Initialize OpenAI client"""
+         self.pre_tool_call_id = None
+
+     def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+         """Convert message format to OpenAI API required format"""
+         openai_messages = []
+         for msg in messages:
+             if msg.tool_call_id:
+                 openai_messages.append(
+                     {"role": msg.role, "content": msg.content, "tool_call_id": msg.tool_call_id, "name": msg.name}
+                 )
+             else:
+                 openai_messages.append({"role": msg.role, "content": msg.content})
+         return openai_messages
+
+     def _convert_functions(self, _: List[Function]) -> List[Dict[str, Any]]:
+         """Convert function format to OpenAI API required format"""
+         return get_openai_schemas()
+
+     def _execute_tool_call(self, tool_call: ToolCall) -> tuple[str, bool]:
+         """Call function and return result"""
+         console.print(f"@Function call: {tool_call.name}({tool_call.arguments})", style="blue")
+
+         # 1. Get function
+         try:
+             function = get_function(FunctionName(tool_call.name))
+         except ValueError as e:
+             error_msg = f"Function '{tool_call.name!r}' not exists: {e}"
+             console.print(error_msg, style="red")
+             return error_msg, False
+
+         # 2. Parse function arguments
+         try:
+             arguments = repair_json(tool_call.arguments, return_objects=True)
+             if not isinstance(arguments, dict):
+                 error_msg = f"Invalid arguments type: {arguments!r}, should be JSON object"
+                 console.print(error_msg, style="red")
+                 return error_msg, False
+             arguments = cast(dict, arguments)
+         except Exception as e:
+             error_msg = f"Invalid arguments from llm: {e}\nRaw arguments: {tool_call.arguments!r}"
+             console.print(error_msg, style="red")
+             return error_msg, False
+
+         # 3. execute function
+         try:
+             function_result = function.execute(**arguments)
+             if cfg["SHOW_FUNCTION_OUTPUT"]:
+                 panel = Panel(
+                     function_result,
+                     title="Function output",
+                     title_align="left",
+                     expand=False,
+                     border_style="blue",
+                     style="dim",
+                 )
+                 console.print(panel)
+             return function_result, True
+         except Exception as e:
+             error_msg = f"Call function error: {e}\nFunction name: {tool_call.name!r}\nArguments: {arguments!r}"
+             console.print(error_msg, style="red")
+             return error_msg, False
+
+     def completion(
+         self,
+         messages: List[ChatMessage],
+         stream: bool = False,
+         recursion_depth: int = 0,
+     ) -> Generator[Union[LLMResponse, RefreshLive], None, None]:
+         """Send message to OpenAI with a maximum recursion depth of 5"""
+         if self.verbose:
+             console.print(messages)
+         openai_messages = self._convert_messages(messages)
+
+         # Prepare request parameters
+         params: Dict[str, Any] = {
+             "model": self.model,
+             "messages": openai_messages,
+             "temperature": cfg["TEMPERATURE"],
+             "top_p": cfg["TOP_P"],
+             "stream": stream,
+             # Openai: This value is now deprecated in favor of max_completion_tokens.
+             "max_tokens": cfg["MAX_TOKENS"],
+             "max_completion_tokens": cfg["MAX_TOKENS"],
+             # litellm api params
+             "api_key": self.api_key,
+             "base_url": self.base_url,
+             "reasoning_effort": cfg["REASONING_EFFORT"],
+         }
+
+         # Add optional parameters
+         if cfg["ENABLE_FUNCTIONS"]:
+             params["tools"] = self._convert_functions(list_functions())
+             params["tool_choice"] = "auto"
+             params["parallel_tool_calls"] = False
+         # Send request
+         response = litellm.completion(**params)
+         if stream:
+             response = cast(litellm.CustomStreamWrapper, response)
+             llm_content_generator = self._handle_stream_response(response)
+         else:
+             response = cast(ModelResponse, response)
+             llm_content_generator = self._handle_normal_response(response)
+         for llm_content in llm_content_generator:
+             yield llm_content
+             if llm_content.tool_call:
+                 if not self.pre_tool_call_id:
+                     self.pre_tool_call_id = llm_content.tool_call.id
+                 elif self.pre_tool_call_id == llm_content.tool_call.id:
+                     continue
+                 # Let live display know we are in next run
+                 yield RefreshLive()
+
+                 # execute function call
+                 function_result, _ = self._execute_tool_call(llm_content.tool_call)
+
+                 # add function call result
+                 messages.append(
+                     ChatMessage(
+                         role=self.detect_tool_role(cfg["PROVIDER"]),
+                         content=function_result,
+                         name=llm_content.tool_call.name,
+                         tool_call_id=llm_content.tool_call.id,
+                     )
+                 )
+                 # Check if we've exceeded the maximum recursion depth
+                 if recursion_depth >= 5:
+                     console.print("Maximum recursion depth (5) reached, stopping further tool calls", style="yellow")
+                     return
+
+                 # Continue with recursion if within limits
+                 if stream:
+                     yield from self.completion(messages, stream=stream, recursion_depth=recursion_depth + 1)
+                 else:
+                     yield from self.completion(messages, stream=stream, recursion_depth=recursion_depth + 1)
+         # yield StopLive()
+
+     def stream_completion(self, messages: List[ChatMessage], stream: bool = True) -> Generator[LLMResponse, None, None]:
+         openai_messages = self._convert_messages(messages)
+         params: Dict[str, Any] = {
+             "model": self.model,
+             "messages": openai_messages,
+             "temperature": cfg["TEMPERATURE"],
+             "top_p": cfg["TOP_P"],
+             "stream": stream,
+             # Openai: This value is now deprecated in favor of max_completion_tokens.
+             "max_tokens": cfg["MAX_TOKENS"],
+             "max_completion_tokens": cfg["MAX_TOKENS"],
+             # litellm api params
+             "api_key": self.api_key,
+             "base_url": self.base_url,
+         }
+         # Add optional parameters
+         if cfg["ENABLE_FUNCTIONS"]:
+             params["tools"] = self._convert_functions(list_functions())
+             params["tool_choice"] = "auto"
+             params["parallel_tool_calls"] = False
+
+         # Send request
+         response = litellm.completion(**params)
+         response = cast(litellm.CustomStreamWrapper, response)
+         llm_content_generator = self._handle_stream_response(response)
+         for llm_content in llm_content_generator:
+             yield llm_content
+             if llm_content.tool_call:
+                 if not self.pre_tool_call_id:
+                     self.pre_tool_call_id = llm_content.tool_call.id
+                 elif self.pre_tool_call_id == llm_content.tool_call.id:
+                     continue
+
+                 # execute function
+                 function_result, _ = self._execute_tool_call(llm_content.tool_call)
+
+                 # add function call result
+                 messages.append(
+                     ChatMessage(
+                         role=self.detect_tool_role(cfg["PROVIDER"]),
+                         content=function_result,
+                         name=llm_content.tool_call.name,
+                         tool_call_id=llm_content.tool_call.id,
+                     )
+                 )
+
+                 yield from self.stream_completion(messages)
+
+     def _handle_normal_response(self, response: ModelResponse) -> Generator[LLMResponse, None, None]:
+         """Handle normal (non-streaming) response
+
+         Returns:
+             LLMContent object with:
+             - reasoning: The thinking/reasoning content (if any)
+             - content: The normal response content
+         """
+         choice = response.choices[0]
+         content = choice.message.content or ""  # type: ignore
+         reasoning = choice.message.reasoning_content  # type: ignore
+         finish_reason = choice.finish_reason
+         tool_call: Optional[ToolCall] = None
+
+         # Check if the response contains reasoning content
+         if "<think>" in content and "</think>" in content:
+             # Extract reasoning content
+             content = content.lstrip()
+             if content.startswith("<think>"):
+                 think_end = content.find("</think>")
+                 if think_end != -1:
+                     reasoning = content[7:think_end].strip()  # Start after <think>
+                     # Remove the <think> block from the main content
+                     content = content[think_end + 8 :].strip()  # Start after </think>
+         # Check if the response contains reasoning content in model_extra
+         elif hasattr(choice.message, "model_extra") and choice.message.model_extra:  # type: ignore
+             model_extra = choice.message.model_extra  # type: ignore
+             reasoning = self._get_reasoning_content(model_extra)
+         if finish_reason == "tool_calls":
+             if '{"index":' in content or '"tool_calls":' in content:
+                 # Tool call data may in content after the <think> block
+                 # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                 tool_index = content.find('{"index":')
+                 if tool_index != -1:
+                     tmp_content = content[tool_index:]
+                     # Tool call data may in content after the <think> block
+                     try:
+                         choice = self.parse_choice_from_content(tmp_content)
+                     except ValueError:
+                         pass
+             if hasattr(choice, "message") and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:  # type: ignore
+                 tool = choice.message.tool_calls[0]  # type: ignore
+                 tool_call = ToolCall(tool.id, tool.function.name or "", tool.function.arguments)
+
+         yield LLMResponse(reasoning=reasoning, content=content, finish_reason=finish_reason, tool_call=tool_call)
+
+     def _handle_stream_response(self, response: litellm.CustomStreamWrapper) -> Generator[LLMResponse, None, None]:
+         """Handle streaming response
+
+         Returns:
+             Generator yielding LLMContent objects with:
+             - reasoning: The thinking/reasoning content (if any)
+             - content: The normal response content
+         """
+         full_reasoning = ""
+         full_content = ""
+         content = ""
+         reasoning = ""
+         tool_id = ""
+         tool_call_name = ""
+         arguments = ""
+         tool_call: Optional[ToolCall] = None
+         for chunk in response:
+             # Check if the response contains reasoning content
+             choice = chunk.choices[0]  # type: ignore
+             delta = choice.delta
+             finish_reason = choice.finish_reason
+
+             # Concat content
+             content = delta.content or ""
+             full_content += content
+
+             # Concat reasoning
+             reasoning = self._get_reasoning_content(delta)
+             full_reasoning += reasoning or ""
+
+             if finish_reason:
+                 pass
+             if finish_reason == "tool_calls" or ('{"index":' in content or '"tool_calls":' in content):
+                 # Tool call data may in content after the <think> block
+                 # >/n{"index": 0, "tool_call_id": "call_1", "function": {"name": "name", "arguments": "{}"}, "output": null}
+                 tool_index = full_content.find('{"index":')
+                 if tool_index != -1:
+                     tmp_content = full_content[tool_index:]
+                     try:
+                         choice = self.parse_choice_from_content(tmp_content)
+                     except ValueError:
+                         pass
+                 if hasattr(choice.delta, "tool_calls") and choice.delta.tool_calls:  # type: ignore
+                     # Handle tool calls
+                     tool_id = choice.delta.tool_calls[0].id or ""  # type: ignore
+                     for tool in choice.delta.tool_calls:  # type: ignore
+                         if not tool.function:
+                             continue
+                         tool_call_name = tool.function.name or ""
+                         arguments += tool.function.arguments or ""
+                     tool_call = ToolCall(tool_id, tool_call_name, arguments)
+             yield LLMResponse(reasoning=reasoning, content=content, tool_call=tool_call, finish_reason=finish_reason)
+
+     def _get_reasoning_content(self, delta: Any) -> Optional[str]:
+         """Extract reasoning content from delta if available based on specific keys.
+
+         This method checks for various keys that might contain reasoning content
+         in different API implementations.
+
+         Args:
+             delta: The delta/model_extra from the API response
+
+         Returns:
+             The reasoning content string if found, None otherwise
+         """
+         if not delta:
+             return None
+         # Reasoning content keys from API:
+         # reasoning_content: deepseek/infi-ai
+         # reasoning: openrouter
+         # <think> block implementation not in here
+         for key in ("reasoning_content", "reasoning"):
+             # Check if the key exists and its value is a non-empty string
+             if hasattr(delta, key):
+                 return getattr(delta, key)
+
+         return None  # Return None if no relevant key with a string value is found
+
+     def parse_choice_from_content(self, content: str) -> Choices:
+         """
+         Parse the choice from the content after <think>...</think> block.
+         Args:
+             content: The content from the LLM response
+             choice_cls: The class to use to parse the choice
+         Returns:
+             The choice object
+         Raises ValueError if the content is not valid JSON
+         """
+         try:
+             content_dict = json.loads(content)
+         except json.JSONDecodeError:
+             raise ValueError(f"Invalid message from LLM: {content}")
+         if "delta" in content_dict:
+             try:
+                 content_dict["delta"] = ChoiceMessage.model_validate(content_dict["delta"])
+             except Exception as e:
+                 raise ValueError(f"Invalid message from LLM: {content}") from e
+         try:
+             return Choices.model_validate(content_dict)
+         except Exception as e:
+             raise ValueError(f"Invalid message from LLM: {content}") from e
+
+     def detect_tool_role(self, provider: str) -> str:
+         """Detect the role of the tool call"""
+         if provider == "gemini":
+             return "user"
+         else:
+             return "tool"
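The new client is generator-driven: completion() yields LLMResponse chunks and interleaves a RefreshLive marker whenever a tool call triggers a follow-up request. The sketch below is a hypothetical consumer loop written for illustration, not code from the package; it assumes only the LitellmClient, RefreshLive, and ChatMessage names added above.

# Hypothetical consumer loop for the generator-based client above
# (illustration only; not part of yaicli).
from yaicli.client import LitellmClient, RefreshLive
from yaicli.schemas import ChatMessage

client = LitellmClient()
messages = [ChatMessage(role="user", content="What does this repo do?")]

buffer = ""
for event in client.completion(messages, stream=True):
    if isinstance(event, RefreshLive):
        # A tool call fired and a follow-up completion is about to
        # stream; a UI would reset its live panel here.
        buffer = ""
        continue
    buffer += event.content or ""  # each event is an LLMResponse chunk
print(buffer)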
yaicli/config.py CHANGED
@@ -1,20 +1,21 @@
  import configparser
+ from dataclasses import dataclass
  from functools import lru_cache
  from os import getenv
- from typing import Optional
+ from typing import Any, Optional

  from rich import get_console
  from rich.console import Console

- from yaicli.const import (
+ from .const import (
      CONFIG_PATH,
      DEFAULT_CHAT_HISTORY_DIR,
      DEFAULT_CONFIG_INI,
      DEFAULT_CONFIG_MAP,
-     DEFAULT_JUSTIFY,
-     DEFAULT_MAX_SAVED_CHATS,
+     DEFAULT_TEMPERATURE,
+     DEFAULT_TOP_P,
  )
- from yaicli.utils import str2bool
+ from .utils import str2bool


  class CasePreservingConfigParser(configparser.RawConfigParser):
@@ -24,6 +25,17 @@ class CasePreservingConfigParser(configparser.RawConfigParser):
          return optionstr


+ @dataclass
+ class ProviderConfig:
+     """Provider configuration"""
+
+     api_key: str
+     model: str
+     base_url: Optional[str] = None
+     temperature: float = DEFAULT_TEMPERATURE
+     top_p: float = DEFAULT_TOP_P
+
+
  class Config(dict):
      """Configuration class that loads settings on initialization.

@@ -49,7 +61,7 @@ class Config(dict):
          """
          # Start with defaults
          self.clear()
-         self.update(self._load_defaults())
+         self._load_defaults()

          # Load from config file
          self._load_from_file()
@@ -58,13 +70,15 @@ class Config(dict):
          self._load_from_env()
          self._apply_type_conversion()

-     def _load_defaults(self) -> dict[str, str]:
+     def _load_defaults(self) -> dict[str, Any]:
          """Load default configuration values as strings.

          Returns:
              Dictionary with default configuration values
          """
-         return {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}
+         defaults = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}
+         self.update(defaults)
+         return defaults

      def _ensure_version_updated_config_keys(self):
          """Ensure configuration keys added in version updates exist in the config file.
@@ -74,10 +88,6 @@ class Config(dict):
              config_content = f.read()
              if "CHAT_HISTORY_DIR" not in config_content.strip():  # Check for empty lines
                  f.write(f"\nCHAT_HISTORY_DIR={DEFAULT_CHAT_HISTORY_DIR}\n")
-             if "MAX_SAVED_CHATS" not in config_content.strip():  # Check for empty lines
-                 f.write(f"\nMAX_SAVED_CHATS={DEFAULT_MAX_SAVED_CHATS}\n")
-             if "JUSTIFY" not in config_content.strip():
-                 f.write(f"\nJUSTIFY={DEFAULT_JUSTIFY}\n")

      def _load_from_file(self) -> None:
          """Load configuration from the config file.
@@ -85,7 +95,7 @@ class Config(dict):
          Creates default config file if it doesn't exist.
          """
          if not CONFIG_PATH.exists():
-             self.console.print("Creating default configuration file.", style="bold yellow", justify=self.get("JUSTIFY"))
+             self.console.print("Creating default configuration file.", style="bold yellow", justify=self["JUSTIFY"])
              CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
              with open(CONFIG_PATH, "w", encoding="utf-8") as f:
                  f.write(DEFAULT_CONFIG_INI)
@@ -102,7 +112,7 @@ class Config(dict):
              if not config_parser["core"].get(k, "").strip():
                  config_parser["core"][k] = v

-         self.update(config_parser["core"])
+         self.update(dict(config_parser["core"]))

          # Check if keys added in version updates are missing and add them
          self._ensure_version_updated_config_keys()
@@ -123,24 +133,26 @@ class Config(dict):
          Updates the configuration dictionary in-place with properly typed values.
          Falls back to default values if conversion fails.
          """
-         default_values_str = self._load_defaults()
+         default_values_str = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}

          for key, config_info in DEFAULT_CONFIG_MAP.items():
              target_type = config_info["type"]
-             raw_value = self.get(key, default_values_str.get(key))
+             raw_value = self[key]
              converted_value = None

              try:
+                 if raw_value is None:
+                     raw_value = default_values_str.get(key, "")
                  if target_type is bool:
                      converted_value = str2bool(raw_value)
                  elif target_type in (int, float, str):
                      converted_value = target_type(raw_value)
              except (ValueError, TypeError) as e:
                  self.console.print(
-                     f"[yellow]Warning:[/yellow] Invalid value '{raw_value}' for '{key}'. "
+                     f"[yellow]Warning:[/] Invalid value '{raw_value}' for '{key}'. "
                      f"Expected type '{target_type.__name__}'. Using default value '{default_values_str[key]}'. Error: {e}",
                      style="dim",
-                     justify=self.get("JUSTIFY"),
+                     justify=self["JUSTIFY"],
                  )
                  # Fallback to default string value if conversion fails
                  try:
@@ -153,14 +165,14 @@ class Config(dict):
                      self.console.print(
                          f"[red]Error:[/red] Could not convert default value for '{key}'. Using raw value.",
                          style="error",
-                         justify=self.get("JUSTIFY"),
+                         justify=self["JUSTIFY"],
                      )
                      converted_value = raw_value  # Or assign a hardcoded safe default

             self[key] = converted_value


- @lru_cache(maxsize=1)
+ @lru_cache(1)
  def get_config() -> Config:
      """Get the configuration singleton"""
      return Config()
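Config applies its sources in a fixed order: hard-coded defaults, then the ini file, then environment variables, with type conversion last. The toy script below mimics that order to show why the refactored _apply_type_conversion can index self[key] directly: after _load_defaults runs, every key already exists in the dict. The three keys and the YAI_* environment names are invented for this example; the real DEFAULT_CONFIG_MAP lives in yaicli's const module.

# Toy illustration of the Config load order shown above:
# defaults -> config file -> environment -> type conversion.
# The three keys and YAI_* env names are invented for this example.
import os

DEFAULT_CONFIG_MAP = {
    "TEMPERATURE": {"value": "0.7", "env_key": "YAI_TEMPERATURE", "type": float},
    "MAX_TOKENS": {"value": "1024", "env_key": "YAI_MAX_TOKENS", "type": int},
    "STREAM": {"value": "true", "env_key": "YAI_STREAM", "type": bool},
}


def str2bool(value: str) -> bool:
    return str(value).strip().lower() in ("1", "true", "yes", "on")


cfg: dict = {k: v["value"] for k, v in DEFAULT_CONFIG_MAP.items()}  # 1. defaults
cfg.update({"TEMPERATURE": "0.2"})  # 2. values parsed from the ini file
for key, info in DEFAULT_CONFIG_MAP.items():  # 3. environment overrides
    env_value = os.getenv(info["env_key"])
    if env_value is not None:
        cfg[key] = env_value
for key, info in DEFAULT_CONFIG_MAP.items():  # 4. type conversion
    target = info["type"]
    cfg[key] = str2bool(cfg[key]) if target is bool else target(cfg[key])

print(cfg)  # {'TEMPERATURE': 0.2, 'MAX_TOKENS': 1024, 'STREAM': True}

Because step 1 now writes every key into the dict itself, the later self[key] lookup can no longer raise KeyError, which is what the _load_defaults change above guarantees.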
yaicli/console.py CHANGED
@@ -3,8 +3,8 @@ from typing import Any, Optional, Union
  from rich.console import Console, JustifyMethod, OverflowMethod
  from rich.style import Style

- from yaicli.config import cfg
- from yaicli.const import DEFAULT_JUSTIFY
+ from .config import cfg
+ from .const import DEFAULT_JUSTIFY


  _console = None