yaicli 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyproject.toml +5 -3
- yaicli/chat.py +396 -0
- yaicli/cli.py +250 -251
- yaicli/client.py +385 -0
- yaicli/config.py +31 -24
- yaicli/console.py +2 -2
- yaicli/const.py +28 -2
- yaicli/entry.py +68 -39
- yaicli/exceptions.py +8 -36
- yaicli/functions/__init__.py +39 -0
- yaicli/functions/buildin/execute_shell_command.py +47 -0
- yaicli/printer.py +145 -225
- yaicli/render.py +1 -1
- yaicli/role.py +231 -0
- yaicli/schemas.py +31 -0
- yaicli/tools.py +103 -0
- yaicli/utils.py +5 -2
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/METADATA +164 -87
- yaicli-0.5.0.dist-info/RECORD +24 -0
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/entry_points.txt +1 -1
- yaicli/chat_manager.py +0 -290
- yaicli/providers/__init__.py +0 -34
- yaicli/providers/base.py +0 -51
- yaicli/providers/cohere.py +0 -136
- yaicli/providers/openai.py +0 -176
- yaicli/roles.py +0 -276
- yaicli-0.4.0.dist-info/RECORD +0 -23
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/WHEEL +0 -0
- {yaicli-0.4.0.dist-info → yaicli-0.5.0.dist-info}/licenses/LICENSE +0 -0
yaicli/chat_manager.py
DELETED
@@ -1,290 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import time
|
3
|
-
from abc import ABC, abstractmethod
|
4
|
-
from datetime import datetime
|
5
|
-
from pathlib import Path
|
6
|
-
from typing import Any, Dict, List, Optional, TypedDict, Union
|
7
|
-
|
8
|
-
from rich.console import Console
|
9
|
-
|
10
|
-
from yaicli.config import Config, cfg
|
11
|
-
from yaicli.console import get_console
|
12
|
-
from yaicli.utils import option_callback
|
13
|
-
|
14
|
-
|
15
|
-
class ChatFileInfo(TypedDict):
    """Listing record for one saved chat, as built by FileChatManager._parse_filename."""

    index: int  # 1-based position in the chat list
    path: str  # location of the chat's JSON file
    title: str  # title recovered from the file name
    date: str  # "YYYY-MM-DD HH:MM" for display, "" when the name has no date
    timestamp: int  # save time as a Unix timestamp, 0 when unknown
|
23
|
-
|
24
|
-
|
25
|
-
class ChatsMap(TypedDict):
    """Cached lookup tables over the listed chats; both map to the same ChatFileInfo records."""

    title: Dict[str, ChatFileInfo]  # title (as found in the file name) -> chat info
    index: Dict[int, ChatFileInfo]  # 1-based list position -> chat info
|
30
|
-
|
31
|
-
|
32
|
-
class ChatManager(ABC):
    """Contract for chat-history storage backends.

    A backend saves chat histories, lists them, and loads or deletes them by list index or by title.
    """

    @abstractmethod
    def make_chat_title(self, prompt: Optional[str] = None) -> str:
        """Derive a chat title from the full prompt text."""

    @abstractmethod
    def save_chat(self, history: List[Dict[str, Any]], title: Optional[str] = None) -> str:
        """Persist *history* (optionally under *title*) and return the chat title."""

    @abstractmethod
    def list_chats(self) -> List[ChatFileInfo]:
        """Return every saved chat."""

    @abstractmethod
    def refresh_chats(self) -> None:
        """Re-read the chat list, discarding anything cached."""

    @abstractmethod
    def load_chat_by_index(self, index: int) -> Union[ChatFileInfo, Dict]:
        """Return the data of the chat at list position *index*."""

    @abstractmethod
    def load_chat_by_title(self, title: str) -> Union[ChatFileInfo, Dict]:
        """Return the data of the chat named *title*."""

    @abstractmethod
    def delete_chat(self, index: int) -> bool:
        """Remove the chat at list position *index*; report whether it worked."""

    @abstractmethod
    def validate_chat_index(self, index: int) -> bool:
        """Tell whether *index* refers to an existing chat."""
|
74
|
-
|
75
|
-
|
76
|
-
class FileChatManager(ChatManager):
    """File system based chat manager implementation.

    Each chat is one JSON file ``{"history": [...], "title": "..."}`` inside ``CHAT_HISTORY_DIR``, named
    ``"<YYYYMMDD>-<HHMMSS>-title-<title>.json"``. Chats are numbered from 1, newest file name first, and
    only the newest ``MAX_SAVED_CHATS`` files are listed or reachable by number or title.
    """

    console: Console = get_console()
    config: Config = cfg
    chat_dir = Path(config["CHAT_HISTORY_DIR"])
    max_saved_chats = config["MAX_SAVED_CHATS"]
    # Executed once, when the class body runs (i.e. when this module is imported).
    chat_dir.mkdir(parents=True, exist_ok=True)

    # Path separators plus the characters Windows reserves in file names. They become "_" in the
    # file-name form of a title, so a title can neither make saving fail nor point outside `chat_dir`.
    _UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'
    # Many file systems (e.g. ext4) cap a single file name at 255 bytes; leave room for the
    # "YYYYMMDD-HHMMSS-title-" prefix and the ".json" suffix.
    _MAX_FILENAME_TITLE_BYTES = 255 - len("YYYYMMDD-HHMMSS-title-") - len(".json")

    def __init__(self):
        self._chats_map: Optional[ChatsMap] = None  # Cache for chat map, filled lazily by `chats_map`

    @property
    def chats_map(self) -> ChatsMap:
        """Get the map of chats, loading from disk only when needed"""
        if self._chats_map is None:
            self._load_chats()
        return self._chats_map or {"index": {}, "title": {}}

    @classmethod
    def _listed_chat_files(cls) -> List[Path]:
        """Return the chat files that get numbers: newest file name first, at most ``max_saved_chats``.

        The fixed-width "YYYYMMDD-HHMMSS" prefix written by `save_chat` makes reverse name order equal to
        newest-first save order.
        """
        return sorted(cls.chat_dir.glob("*.json"), reverse=True)[: cls.max_saved_chats]

    @classmethod
    @option_callback
    def print_list_option(cls, _: Any):
        """Print the list of chats.

        Uses the same order and limit as `_load_chats`, so every printed number is exactly the index that
        `load_chat_by_index` and `delete_chat` accept for that chat.
        """
        cls.console.print("Finding Chats...")
        chat_files = cls._listed_chat_files()
        if not chat_files:
            cls.console.print("No chats found", style="dim")
            return
        for index, chat_file in enumerate(chat_files, start=1):
            info: ChatFileInfo = cls._parse_filename(chat_file, index)
            cls.console.print(f"{index}. {info['title']} ({info['date']})")

    def make_chat_title(self, prompt: Optional[str] = None) -> str:
        """Make a chat title from a given full prompt (its first 100 characters), or a timestamped default."""
        if prompt:
            return prompt[:100]
        return f"Chat-{int(time.time())}"

    def validate_chat_index(self, index: int) -> bool:
        """Validate a chat index and return success status"""
        return index > 0 and index in self.chats_map["index"]

    def refresh_chats(self) -> None:
        """Force refresh the chat list from disk"""
        self._load_chats()

    @classmethod
    def _filename_title(cls, title: str) -> str:
        """Return the form of *title* that is safe to embed in a chat file name.

        Path separators, Windows-reserved characters and non-printable characters (e.g. newlines in a
        prompt-derived title) are replaced by ``_``, and the result is cut to fit the file-name length
        limit. Titles without such characters that already fit are returned unchanged.
        """
        safe = "".join(
            "_" if char in cls._UNSAFE_FILENAME_CHARS or not char.isprintable() else char for char in title
        )
        # Truncate by bytes, not characters (CJK text is 3 bytes per character in UTF-8); a multi-byte
        # character cut in half by the slice is dropped by errors="ignore".
        return safe.encode("utf-8")[: cls._MAX_FILENAME_TITLE_BYTES].decode("utf-8", errors="ignore")

    @staticmethod
    def _parse_filename(chat_file: Path, index: int) -> ChatFileInfo:
        """Parse a chat filename and extract metadata.

        ``"20250421-214005-title-meaning of life.json"`` gives title ``"meaning of life"``, date
        ``"2025-04-21 21:40"`` and the Unix timestamp of that local time (0 if the date/time parts do not
        parse). A name without that layout keeps its whole stem as title, with date ``""`` and timestamp 0.
        The title here is only the file-name form; the full title is stored inside the JSON file and is
        returned by `_load_chat_data`.
        """
        filename = chat_file.stem
        parts = filename.split("-")

        if len(parts) < 4 or "title" not in parts:
            # Fallback for files that don't match expected format
            return {"index": index, "path": str(chat_file), "title": filename, "date": "", "timestamp": 0}

        # "20250421-214005-title-meaning of life" ==> "meaning of life"
        # ("title" is one of the parts, so find() always succeeds here.)
        title = filename[filename.find("title") + len("title-") :]
        date_, time_ = parts[0], parts[1]
        date_str = f"{date_[:4]}-{date_[4:6]}-{date_[6:]} {time_[:2]}:{time_[2:4]}"
        try:
            timestamp = int(datetime.strptime(f"{date_}{time_}", "%Y%m%d%H%M%S").timestamp())
        except (ValueError, OverflowError, OSError):
            # Unparseable digits, or a date the platform cannot convert to a timestamp
            timestamp = 0

        return {
            "index": index,
            "path": str(chat_file),
            "title": title,
            "date": date_str,
            "timestamp": timestamp,
        }

    def _load_chats(self) -> None:
        """Rebuild the chat cache from disk, reporting and skipping files whose names cannot be parsed."""
        chats_map: ChatsMap = {"title": {}, "index": {}}

        for index, chat_file in enumerate(self._listed_chat_files(), start=1):
            try:
                info = self._parse_filename(chat_file, index)
            except Exception as e:
                # Log the error but continue processing other files
                self.console.print(f"Error parsing session file {chat_file}: {e}", style="dim")
                continue
            chats_map["title"][info["title"]] = info
            chats_map["index"][index] = info

        self._chats_map = chats_map

    def list_chats(self) -> List[ChatFileInfo]:
        """List all saved chats and return the chat list"""
        return list(self.chats_map["index"].values())

    def save_chat(self, history: List[Dict[str, Any]], title: Optional[str] = None) -> str:
        """Save chat history to the file system, overwriting existing chats with the same title.

        If no title is provided, the chat will be saved with a default title.
        The default title is "Chat-{current timestamp}".

        The full title is stored in the JSON file. The file name uses a sanitized, length-limited form
        of it (see `_filename_title`). A previous chat with the same title is removed only after the new
        file has been written, so a failed save never destroys the old copy.

        Args:
            history (List[Dict[str, Any]]): The chat history to save
            title (Optional[str]): The title of the chat provided by the user

        Returns:
            str: The title of the saved chat, or "" if the file could not be written
        """
        history = history or []
        save_title = self.make_chat_title(title or f"Chat-{int(time.time())}")
        file_title = self._filename_title(save_title)

        # The cache is keyed by file-name titles, so look up the existing chat the same way.
        existing_chat = self.chats_map["title"].get(file_title)

        timestamp = datetime.now().astimezone().strftime("%Y%m%d-%H%M%S")
        filepath = self.chat_dir / f"{timestamp}-title-{file_title}.json"

        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump({"history": history, "title": save_title}, f, ensure_ascii=False, indent=2)
        except Exception as e:
            self.console.print(f"Error saving chat '{save_title}': {e}", style="dim")
            return ""

        if existing_chat:
            existing_path = Path(existing_chat["path"])
            # Two saves within the same second share a path; the old file was already overwritten in place.
            if existing_path != filepath:
                try:
                    existing_path.unlink()
                except OSError as e:
                    self.console.print(
                        f"Warning: Could not delete existing chat file {existing_chat['path']}: {e}",
                        style="dim",
                    )

        # Force refresh the chat list after saving
        self.refresh_chats()
        return save_title

    def _load_chat_data(self, chat_info: Optional[ChatFileInfo]) -> Union[ChatFileInfo, Dict]:
        """Read the chat file behind *chat_info* and return ``{"title", "timestamp", "history"}``.

        The title comes from the JSON file when present (the file-name title is only a fallback).
        Returns {} when *chat_info* is empty or the file is missing, invalid JSON, or unreadable;
        each failure is reported on the console.
        """
        if not chat_info:
            return {}

        try:
            chat_file = Path(chat_info["path"])
            with open(chat_file, "r", encoding="utf-8") as f:
                chat_data = json.load(f)

            return {
                "title": chat_data.get("title", chat_info["title"]),
                "timestamp": chat_info["timestamp"],
                "history": chat_data.get("history", []),
            }
        except FileNotFoundError:
            self.console.print(f"Chat file not found: {chat_info['path']}", style="dim")
            return {}
        except json.JSONDecodeError as e:
            self.console.print(f"Invalid JSON in chat file {chat_info['path']}: {e}", style="dim")
            return {}
        except Exception as e:
            self.console.print(f"Error loading chat from {chat_info['path']}: {e}", style="dim")
            return {}

    def load_chat_by_index(self, index: int) -> Union[ChatFileInfo, Dict]:
        """Load a chat by index and return the chat data ({} for an invalid index)"""
        if not self.validate_chat_index(index):
            return {}
        chat_info = self.chats_map.get("index", {}).get(index)
        return self._load_chat_data(chat_info)

    def load_chat_by_title(self, title: str) -> Union[ChatFileInfo, Dict]:
        """Load a chat by title and return the chat data ({} when no listed chat matches).

        The cache is keyed by the title as it appears in the file name, so *title* is also tried in the
        sanitized form that `save_chat` uses for file names.
        """
        titles = self.chats_map.get("title", {})
        chat_info = titles.get(title) or titles.get(self._filename_title(title))
        return self._load_chat_data(chat_info)

    def delete_chat(self, index: int) -> bool:
        """Delete a chat by index and return success status"""
        if not self.validate_chat_index(index):
            return False

        chat_info = self.chats_map["index"].get(index)
        if not chat_info:
            return False

        try:
            Path(chat_info["path"]).unlink()
            # Force refresh the chat list
            self.refresh_chats()
            return True
        except FileNotFoundError:
            self.console.print(f"Chat file not found: {chat_info['path']}", style="dim")
            return False
        except Exception as e:
            self.console.print(f"Error deleting chat {index}: {e}", style="dim")
            return False
|
yaicli/providers/__init__.py
DELETED
@@ -1,34 +0,0 @@
|
|
1
|
-
from yaicli.const import DEFAULT_PROVIDER
|
2
|
-
from yaicli.providers.base import BaseClient
|
3
|
-
from yaicli.providers.cohere import CohereClient
|
4
|
-
from yaicli.providers.openai import OpenAIClient
|
5
|
-
|
6
|
-
|
7
|
-
def create_api_client(config, console, verbose):
    """Factory function to create the appropriate API client based on provider.

    ``config["PROVIDER"]`` is matched case-insensitively, ignoring surrounding whitespace. A missing,
    ``None`` or blank value means ``DEFAULT_PROVIDER``. Unrecognised providers get the OpenAI client,
    after a notice on *console*.

    Args:
        config: The configuration dictionary
        console: The rich console for output
        verbose: Whether to enable verbose output

    Returns:
        An instance of the appropriate ApiClient implementation
    """
    # `config.get(...).lower()` used to crash on a None value; normalize to a clean lowercase name first.
    provider = str(config.get("PROVIDER") or "").strip().lower() or DEFAULT_PROVIDER.lower()

    client_classes = {
        "openai": OpenAIClient,
        "cohere": CohereClient,
    }
    client_cls = client_classes.get(provider)
    if client_cls is None:
        # Fallback to openai client
        console.print(f"Using generic HTTP client for provider: {provider}", style="yellow")
        client_cls = OpenAIClient
    return client_cls(config, console, verbose)
|
32
|
-
|
33
|
-
|
34
|
-
# Public names of yaicli.providers: the client base class, the two concrete clients and the factory.
__all__ = ["BaseClient", "OpenAIClient", "CohereClient", "create_api_client"]
|
yaicli/providers/base.py
DELETED
@@ -1,51 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
3
|
-
|
4
|
-
from rich.console import Console
|
5
|
-
|
6
|
-
|
7
|
-
class BaseClient(ABC):
    """Common base for the LLM API clients.

    Holds the settings every client needs: the config mapping, the rich console used for diagnostics,
    the verbose flag and the request timeout (``config["TIMEOUT"]``). Subclasses supply the actual
    request logic.
    """

    def __init__(self, config: Dict[str, Any], console: Console, verbose: bool):
        """Store the shared settings; a config without ``TIMEOUT`` raises ``KeyError``."""
        self.config = config
        self.console = console
        self.verbose = verbose
        self.timeout = config["TIMEOUT"]

    @abstractmethod
    def completion(self, messages: List[Dict[str, str]]) -> Tuple[Optional[str], Optional[str]]:
        """Send *messages* and return the whole (non-streamed) answer as ``(content, reasoning)``."""

    @abstractmethod
    def stream_completion(self, messages: List[Dict[str, str]]) -> Iterator[Dict[str, Any]]:
        """Send *messages* and yield the parsed stream events one by one."""

    def _get_reasoning_content(self, delta: dict) -> Optional[str]:
        """Pick the reasoning text out of a response delta, if the API put one there.

        Providers use different field names: ``reasoning_content`` (deepseek / infi-ai) and
        ``reasoning`` (openrouter) are checked in that order. ``<think>`` blocks embedded in normal
        content are not handled here.

        Args:
            delta: Mapping of extra fields from one API response delta.

        Returns:
            The first of those fields holding a non-empty string, else None.
        """
        if not delta:
            return None
        candidates = (delta.get(field) for field in ("reasoning_content", "reasoning"))
        return next((text for text in candidates if isinstance(text, str) and text), None)
|
yaicli/providers/cohere.py
DELETED
@@ -1,136 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, TypeVar
|
2
|
-
|
3
|
-
from cohere import ChatResponse, ClientV2, StreamedChatResponseV2
|
4
|
-
from cohere.core.api_error import ApiError
|
5
|
-
|
6
|
-
from yaicli.const import EventTypeEnum
|
7
|
-
from yaicli.providers.base import BaseClient
|
8
|
-
|
9
|
-
# `type` names of chunks in a Cohere V2 chat stream, kept for typing purposes
# (neither this alias nor `T` is referenced elsewhere in this module).
ChunkType = Literal[
    # message / content lifecycle
    "message-start", "content-start", "content-delta", "content-end",
    # tool usage
    "tool-plan-delta", "tool-call-start", "tool-call-delta", "tool-call-end",
    # citations, end of message, diagnostics
    "citation-start", "citation-end", "message-end", "debug",
]

# Type variable for chunks that have delta attribute
T = TypeVar("T", bound=StreamedChatResponseV2)
|
26
|
-
|
27
|
-
|
28
|
-
class CohereClient(BaseClient):
    """Cohere API client implementation using the official Cohere Python library."""

    def __init__(self, config: Dict[str, Any], console, verbose: bool):
        """Initialize the Cohere API client with configuration.

        ``BASE_URL`` is used only when it contains "cohere". When it is empty or points elsewhere,
        ``https://api.cohere.com`` is used. Trailing slashes and a trailing ``/v1`` or ``/v2`` path
        segment are removed before the URL is passed to ``ClientV2``.
        """
        super().__init__(config, console, verbose)
        self.api_key = config["API_KEY"]
        self.model = config["MODEL"]

        base_url = config["BASE_URL"]
        if not base_url or "cohere" not in base_url:
            # BASE_URL can be empty, in which case we use the default base_url
            base_url = "https://api.cohere.com"
        base_url = base_url.rstrip("/")
        # Strip only a whole "/v1" or "/v2" segment (not the tail of e.g. ".../apiv2"), and do not leave
        # the separating slash behind.
        if base_url.endswith(("/v1", "/v2")):
            base_url = base_url[: -len("/v1")].rstrip("/")
        self.base_url = base_url

        # Initialize the Cohere client with our custom configuration
        self.client = ClientV2(
            api_key=self.api_key,
            base_url=self.base_url,
            client_name="Yaicli",
            timeout=self.timeout,
        )

    def _prepare_request_params(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Prepare the common request parameters for Cohere API calls.

        Cohere accepts ``p`` in the closed range [0.01, 0.99], and its default is 0.75. A ``TOP_P``
        outside that range is replaced by 0.75.
        """
        top_p = self.config["TOP_P"]
        # Inclusive bounds: 0.01 and 0.99 are themselves valid and are passed through unchanged.
        p = top_p if 0.01 <= top_p <= 0.99 else 0.75
        return {
            "messages": messages,
            "model": self.model,
            "temperature": self.config["TEMPERATURE"],
            "max_tokens": self.config["MAX_TOKENS"],
            "p": p,
        }

    def _process_completion_response(self, response: ChatResponse) -> Tuple[Optional[str], Optional[str]]:
        """Process a non-streamed Cohere response.

        Returns ``(text, None)`` built from the first content item. Returns ``(None, None)`` if there is
        no text, or if the response cannot be processed (the error is then printed).
        """
        try:
            content = response.message.content
            if not content:
                return None, None
            text = content[0].text
            if not text:
                return None, None
            return text, None
        except Exception as e:
            self.console.print(f"Error processing Cohere response: {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {response}")
            return None, None

    def completion(self, messages: List[Dict[str, str]]) -> Tuple[Optional[str], Optional[str]]:
        """Get a complete non-streamed response from the Cohere API.

        Every failure is reported on the console and yields ``(None, None)``, matching
        `stream_completion` and the OpenAI client, instead of letting e.g. network errors escape.
        """
        params = self._prepare_request_params(messages)

        try:
            response: ChatResponse = self.client.chat(**params)
            return self._process_completion_response(response)
        except ApiError as e:
            self.console.print(f"Cohere API error: {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {e.body}")
            return None, None
        except Exception as e:
            self.console.print(f"Unexpected error during Cohere completion: {e}", style="red")
            if self.verbose:
                import traceback

                traceback.print_exc()
            return None, None

    def stream_completion(self, messages: List[Dict[str, str]]) -> Iterator[Dict[str, Any]]:
        """Connect to the Cohere API and yield parsed stream events.

        Yields ``{"type": EventTypeEnum.CONTENT, "chunk": text}`` for every chunk that carries text at
        ``chunk.delta.message.content.text``. On failure it yields ``{"type": EventTypeEnum.ERROR,
        "message": ...}`` after printing the error.
        """
        params = self._prepare_request_params(messages)

        try:
            for chunk in self.client.v2.chat_stream(**params):
                # Skip message start/end events
                if chunk.type in ("message-start", "message-end", "content-end"):  # type: ignore
                    continue

                # Follow chunk.delta.message.content.text; any missing or None link means "no text here".
                text: Any = chunk
                for attr in ("delta", "message", "content", "text"):
                    text = getattr(text, attr, None)
                    if text is None:
                        break
                if text:
                    yield {"type": EventTypeEnum.CONTENT, "chunk": text}

        except ApiError as e:
            self.console.print(f"Cohere API error during streaming: {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {e.body}")
            yield {"type": EventTypeEnum.ERROR, "message": str(e)}
        except Exception as e:
            self.console.print(f"Unexpected error during Cohere streaming: {e}", style="red")
            if self.verbose:
                import traceback

                traceback.print_exc()
            yield {"type": EventTypeEnum.ERROR, "message": f"Unexpected stream error: {e}"}
|
yaicli/providers/openai.py
DELETED
@@ -1,176 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
2
|
-
|
3
|
-
import openai
|
4
|
-
from openai import OpenAI
|
5
|
-
from openai.types.chat.chat_completion import ChatCompletion
|
6
|
-
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
|
7
|
-
|
8
|
-
from yaicli.const import EventTypeEnum
|
9
|
-
from yaicli.providers.base import BaseClient
|
10
|
-
|
11
|
-
|
12
|
-
class OpenAIClient(BaseClient):
    """OpenAI API client implementation using the official OpenAI Python library.

    `create_api_client` also uses it as the fallback for providers it does not recognise.
    """

    def __init__(self, config: Dict[str, Any], console, verbose: bool):
        """Initialize the OpenAI API client with configuration (API key, model, base URL, timeout)."""
        super().__init__(config, console, verbose)
        self.api_key = config["API_KEY"]
        self.model = config["MODEL"]
        self.base_url = config["BASE_URL"]

        # Initialize the OpenAI client with our custom configuration; the timeout set here applies to
        # every request made through self.client.
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=self.timeout,
            default_headers={"X-Title": "Yaicli"},
            max_retries=2,  # Add retry logic for resilience
        )

    def _prepare_request_params(self, messages: List[Dict[str, str]], stream: bool) -> Dict[str, Any]:
        """Prepare the common request parameters for OpenAI API calls."""
        return {
            "messages": messages,
            "model": self.model,
            "stream": stream,
            "temperature": self.config["TEMPERATURE"],
            "top_p": self.config["TOP_P"],
            # Openai: This value is now deprecated in favor of max_completion_tokens
            "max_tokens": self.config["MAX_TOKENS"],
            "max_completion_tokens": self.config["MAX_TOKENS"],
        }

    def _process_completion_response(self, response: ChatCompletion) -> Tuple[Optional[str], Optional[str]]:
        """Process the response from a non-streamed OpenAI completion request.

        Returns ``(content, reasoning)``. Reasoning is read from the same extra message fields as the
        streaming path (``reasoning_content`` or ``reasoning``, via `_get_reasoning_content`). If there
        is none, a leading ``<think>...</think>`` block is split off the content instead. On a processing
        error the error is printed and ``(None, None)`` is returned.
        """
        try:
            # OpenAI SDK returns structured objects
            message = response.choices[0].message
            content = message.content
            extra = getattr(message, "model_extra", None)
            reasoning = self._get_reasoning_content(extra) if extra else None

            # If no reasoning in the extra fields, try extracting from <think> tags
            if reasoning is None and isinstance(content, str):
                content = content.lstrip()
                if content.startswith("<think>"):
                    think_end = content.find("</think>")
                    if think_end != -1:
                        reasoning = content[len("<think>") : think_end].strip()
                        # Remove the <think> block from the main content
                        content = content[think_end + len("</think>") :].strip()

            return content, reasoning
        except Exception as e:
            self.console.print(f"Error processing OpenAI response: {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {response}")
            return None, None

    def completion(self, messages: List[Dict[str, str]]) -> Tuple[Optional[str], Optional[str]]:
        """Get a complete non-streamed response from the OpenAI API.

        Every failure is reported on the console and yields ``(None, None)``.
        """
        params = self._prepare_request_params(messages, stream=False)

        try:
            # Call the shared client directly. The timeout is already configured on it, and using a
            # `with_options()` copy as a context manager would close the HTTP client that the copy
            # shares with self.client, so every later request would fail.
            response: ChatCompletion = self.client.chat.completions.create(**params)
            return self._process_completion_response(response)
        except openai.APIConnectionError as e:
            self.console.print(f"OpenAI connection error: {e}", style="red")
            if self.verbose:
                self.console.print(f"Underlying error: {e.__cause__}")
            return None, None
        except openai.RateLimitError as e:
            self.console.print(f"OpenAI rate limit error (429): {e}", style="red")
            return None, None
        except openai.APIStatusError as e:
            self.console.print(f"OpenAI API error (status {e.status_code}): {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {e.response}")
            return None, None
        except Exception as e:
            self.console.print(f"Unexpected error during OpenAI completion: {e}", style="red")
            if self.verbose:
                import traceback

                traceback.print_exc()
            return None, None

    def stream_completion(self, messages: List[Dict[str, str]]) -> Iterator[Dict[str, Any]]:
        """Connect to the OpenAI API and yield parsed stream events.

        Args:
            messages: The list of message dictionaries to send to the API

        Yields:
            Event dictionaries with the following structure:
            - type: The event type (from EventTypeEnum)
            - chunk/message/reason: The content of the event

            REASONING chunks are followed by one REASONING_END before the first CONTENT chunk, or
            before FINISH if no content arrives. Reasoning, content and finish reason carried by the
            same delta are all emitted.
        """
        params: Dict[str, Any] = self._prepare_request_params(messages, stream=True)
        in_reasoning: bool = False

        try:
            # Use context manager to ensure proper cleanup
            with self.client.chat.completions.create(**params) as stream:
                for chunk in stream:
                    choices: List[Choice] = chunk.choices
                    if not choices:
                        # Some APIs may return empty choices upon reaching the end of content.
                        continue
                    choice: Choice = choices[0]
                    delta: ChoiceDelta = choice.delta
                    finish_reason: Optional[str] = choice.finish_reason

                    # Reasoning arrives in non-standard delta fields. No `continue` afterwards: content
                    # or a finish reason in the same delta must not be dropped.
                    extra = getattr(delta, "model_extra", None)
                    reasoning: Optional[str] = self._get_reasoning_content(extra) if extra else None
                    if reasoning:
                        yield {"type": EventTypeEnum.REASONING, "chunk": reasoning}
                        in_reasoning = True

                    content_chunk = getattr(delta, "content", None)
                    if content_chunk:
                        if in_reasoning:
                            # Send reasoning end signal before content
                            in_reasoning = False
                            yield {"type": EventTypeEnum.REASONING_END, "chunk": ""}
                        yield {"type": EventTypeEnum.CONTENT, "chunk": content_chunk}

                    if finish_reason:
                        # Send reasoning end signal if still in reasoning state
                        if in_reasoning:
                            in_reasoning = False
                            yield {"type": EventTypeEnum.REASONING_END, "chunk": ""}
                        yield {"type": EventTypeEnum.FINISH, "reason": finish_reason}

        except openai.APIConnectionError as e:
            self.console.print(f"OpenAI connection error during streaming: {e}", style="red")
            if self.verbose:
                self.console.print(f"Underlying error: {e.__cause__}")
            yield {"type": EventTypeEnum.ERROR, "message": str(e)}
        except openai.RateLimitError as e:
            self.console.print(f"OpenAI rate limit error (429) during streaming: {e}", style="red")
            yield {"type": EventTypeEnum.ERROR, "message": str(e)}
        except openai.APIStatusError as e:
            self.console.print(f"OpenAI API error (status {e.status_code}) during streaming: {e}", style="red")
            if self.verbose:
                self.console.print(f"Response: {e.response}")
            yield {"type": EventTypeEnum.ERROR, "message": str(e)}
        except Exception as e:
            self.console.print(f"Unexpected error during OpenAI streaming: {e}", style="red")
            if self.verbose:
                import traceback

                traceback.print_exc()
            yield {"type": EventTypeEnum.ERROR, "message": f"Unexpected stream error: {e}"}
|