quantalogic 0.30.8__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- quantalogic/__init__.py +17 -7
- quantalogic/agent.py +75 -29
- quantalogic/agent_config.py +10 -0
- quantalogic/agent_factory.py +66 -11
- quantalogic/config.py +15 -0
- quantalogic/generative_model.py +17 -98
- quantalogic/get_model_info.py +26 -0
- quantalogic/interactive_text_editor.py +276 -102
- quantalogic/llm.py +135 -0
- quantalogic/main.py +60 -11
- quantalogic/prompts.py +66 -41
- quantalogic/task_runner.py +26 -39
- quantalogic/tool_manager.py +66 -0
- quantalogic/tools/replace_in_file_tool.py +1 -1
- quantalogic/tools/search_definition_names.py +2 -0
- quantalogic/tools/sql_query_tool.py +4 -2
- quantalogic/utils/get_all_models.py +20 -0
- {quantalogic-0.30.8.dist-info → quantalogic-0.31.0.dist-info}/METADATA +6 -1
- {quantalogic-0.30.8.dist-info → quantalogic-0.31.0.dist-info}/RECORD +22 -19
- {quantalogic-0.30.8.dist-info → quantalogic-0.31.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.30.8.dist-info → quantalogic-0.31.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.30.8.dist-info → quantalogic-0.31.0.dist-info}/entry_points.txt +0 -0
quantalogic/interactive_text_editor.py
@@ -1,4 +1,4 @@
-
+from typing import List
 
 from prompt_toolkit import PromptSession
 from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
@@ -9,70 +9,70 @@ from rich.panel import Panel
 
 
 class InputHistoryManager:
-    """Manages the history of input states for undo functionality."""
-
     def __init__(self):
-        """Initialize the InputHistoryManager with an empty history stack and current state."""
         self.history_stack = []
         self.current_state = []
 
-    def push_state(self, lines):
-        """Push the current state to the history stack and update the current state.
-
-        Args:
-            lines (list): The current lines of input to be saved in the history.
-        """
+    def push_state(self, lines: List[str]) -> None:
         self.history_stack.append(self.current_state.copy())
         self.current_state = lines.copy()
 
-    def undo(self, lines):
-        """Revert to the previous state from the history stack.
-
-        Args:
-            lines (list): The current lines of input to be updated to the previous state.
-
-        Returns:
-            bool: True if undo was successful, False otherwise.
-        """
+    def undo(self, lines: List[str]) -> bool:
         if self.history_stack:
             self.current_state = self.history_stack.pop()
             lines[:] = self.current_state
             return True
         return False
 
-
-def
-
-
-
-
-
-
-
-
-
+class CommandRegistry:
+    def __init__(self):
+        self.commands = {}
+
+    def register(self, name: str, help_text: str = "") -> callable:
+        def decorator(func: callable) -> callable:
+            self.commands[name] = {
+                "handler": func,
+                "help": func.__doc__ or help_text
+            }
+            return func
+        return decorator
+
+registry = CommandRegistry()
+
+@registry.register("/help", "Show available commands")
+def handle_help_command(lines: List[str], args: List[str], console: Console,
+                        session: PromptSession, history_manager: InputHistoryManager) -> None:
+    """Display auto-generated help from registered commands."""
+    help_content = "\n".join([f" {name}: {cmd['help']}" for name, cmd in registry.commands.items()])
+    console.print(Panel(f"Available commands:\n{help_content}", title="Help Menu", border_style="green"))
+
+@registry.register("/date", "Show current date/time")
+def handle_date_command(lines: List[str], args: List[str], console: Console,
+                        session: PromptSession, history_manager: InputHistoryManager) -> None:
+    """Display current date and time."""
+    from datetime import datetime
+    console.print(f"[bold #ffaa00]Current datetime: {datetime.now().isoformat()}[/bold #ffaa00]")
+
+@registry.register("/edit", "Edit specific line: /edit <line_number>")
+def handle_edit_command(lines: List[str], args: List[str], console: Console,
+                        session: PromptSession, history_manager: InputHistoryManager) -> None:
+    """Edit a specific line in the input buffer."""
     try:
         edit_line_num = int(args[0]) - 1
         if 0 <= edit_line_num < len(lines):
-            console.print(f"[bold]Editing Line {edit_line_num + 1}:[/bold] {lines[edit_line_num]}")
+            console.print(f"[bold #1d3557]Editing Line {edit_line_num + 1}:[/bold #1d3557] {lines[edit_line_num]}")  # Dark blue
             new_line = session.prompt("New content: ")
             history_manager.push_state(lines)
             lines[edit_line_num] = new_line
         else:
-            console.print("[
+            console.print("[bold #ff4444]Invalid line number.[/bold #ff4444]")
     except (ValueError, IndexError):
-
+        console.print("[bold #ff4444]Invalid edit command. Usage: /edit <line_number>[/bold #ff4444]")
 
-
-def handle_delete_command(lines, args, console,
-
-
-    Args:
-        lines (list): The current lines of input.
-        args (list): The arguments provided with the command.
-        console (Console): The console object for output.
-        history_manager (InputHistoryManager): The history manager for state tracking.
-    """
+@registry.register("/delete", "Delete specific line: /delete <line_number>")
+def handle_delete_command(lines: List[str], args: List[str], console: Console,
+                          session: PromptSession, history_manager: InputHistoryManager) -> None:
+    """Delete a specific line from the input buffer."""
     try:
         delete_line_num = int(args[0]) - 1
         if 0 <= delete_line_num < len(lines):
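The hunk above replaces the module's ad-hoc command handling with a decorator-driven CommandRegistry. As a rough illustration of that pattern (not code shipped in either version), an extra slash command could be registered as sketched below, assuming registry and InputHistoryManager remain importable module-level names in quantalogic/interactive_text_editor.py; the /upper command itself is hypothetical.

from typing import List

from prompt_toolkit import PromptSession
from rich.console import Console

from quantalogic.interactive_text_editor import InputHistoryManager, registry


@registry.register("/upper", "Uppercase every line in the buffer")
def handle_upper_command(lines: List[str], args: List[str], console: Console,
                         session: PromptSession, history_manager: InputHistoryManager) -> None:
    """Uppercase all lines currently in the input buffer."""
    history_manager.push_state(lines)  # snapshot first so Ctrl-Z can undo the change
    lines[:] = [line.upper() for line in lines]
    console.print("[bold #00cc66]Buffer uppercased.[/bold #00cc66]")

Every registered handler shares the (lines, args, console, session, history_manager) signature, which is what the dispatch loop added later in this file passes in.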
@@ -82,71 +82,89 @@ def handle_delete_command(lines, args, console, history_manager):
         else:
             console.print("[red]Invalid line number.[/red]")
     except (ValueError, IndexError):
-
-
-
-def handle_replace_command(lines, args, console, history_manager):
-    """Handle the 'replace' command to search and replace text in all lines.
+        console.print("[bold #ff4444]Invalid delete command. Usage: /delete <line_number>[/bold #ff4444]")
 
-
-
-
-        console (Console): The console object for output.
-        history_manager (InputHistoryManager): The history manager for state tracking.
-    """
+@registry.register("/replace", "Search and replace: /replace <search> <replace>")
+def handle_replace_command(lines: List[str], args: List[str], console: Console,
+                           session: PromptSession, history_manager: InputHistoryManager) -> None:
     try:
-        search_str
+        search_str = args[0]
+        replace_str = args[1]
         history_manager.push_state(lines)
         for i in range(len(lines)):
             lines[i] = lines[i].replace(search_str, replace_str)
-
-    except ValueError:
-
-
-
-commands = {"edit": handle_edit_command, "delete": handle_delete_command, "replace": handle_replace_command}
-
-
-def handle_command(line, lines, console, session, history_manager):
-    """Handle a command entered by the user.
-
-    Args:
-        line (str): The command line entered by the user.
-        lines (list): The current lines of input.
-        console (Console): The console object for output.
-        session (PromptSession): The prompt session for user input.
-        history_manager (InputHistoryManager): The history manager for state tracking.
+        console.print("[bold #00cc66]Search and replace completed.[/bold #00cc66]")
+    except (ValueError, IndexError):
+        console.print("[bold #ff4444]Invalid replace command. Usage: /replace <search_str> <replace_str>[/bold #ff4444]")
 
-
-
-
-
-
-
-
-
-
-
-
+@registry.register("/model", "Show current AI model")
+def handle_model_command(lines: List[str], args: List[str], console: Console,
+                         session: PromptSession, history_manager: InputHistoryManager) -> None:
+    from quantalogic.agent_factory import AgentRegistry
+    try:
+        current_agent = AgentRegistry.get_agent("main_agent")
+        if current_agent:
+            console.print(f"[bold #ffaa00]Current AI model: {current_agent.model_name}[/bold #ffaa00]")
+        else:
+            console.print("[bold #ffaa00]No active agent found.[/bold #ffaa00]")
+    except ValueError as e:
+        console.print(f"[bold #ff4444]Error: {str(e)}[/bold #ff4444]")
+
+@registry.register("/setmodel", "Set AI model name: /setmodel <name>")
+def handle_set_model_command(lines: List[str], args: List[str], console: Console,
+                             session: PromptSession, history_manager: InputHistoryManager) -> None:
+    from quantalogic.agent_factory import AgentRegistry
+    try:
+        if len(args) < 1:
+            console.print("[bold #ff4444]Error: Model name required. Usage: /setmodel <name>[/bold #ff4444]")
+            return
+
+        model_name = args[0]
+        current_agent = AgentRegistry.get_agent("main_agent")
+        if current_agent:
+            current_agent.model_name = model_name
+            console.print(f"[bold #00cc66]Model name updated to: {model_name}[/bold #00cc66]")
+        else:
+            console.print("[yellow]No active agent found.[/yellow]")
+    except ValueError as e:
+        console.print(f"[red]Error: {str(e)}[/red]")
+
+@registry.register("/models", "List all available AI models")
+def handle_models_command(lines: List[str], args: List[str], console: Console,
+                          session: PromptSession, history_manager: InputHistoryManager) -> None:
+    """Display all available AI models supported by the system."""
+    from quantalogic.utils.get_all_models import get_all_models
+    try:
+        models = get_all_models()
+        if models:
+            # Group models by provider
+            provider_groups = {}
+            for model in models:
+                provider = model.split('/')[0] if '/' in model else 'default'
+                if provider not in provider_groups:
+                    provider_groups[provider] = []
+                provider_groups[provider].append(model)
+
+            # Create formatted output
+            output = "[bold #00cc66]Available AI Models:[/bold #00cc66]\n"
+            for provider, model_list in provider_groups.items():
+                output += f"\n[bold #ffaa00]{provider.upper()}[/bold #ffaa00]\n"
+                for model in sorted(model_list):
+                    output += f"  • {model}\n"
+
+            console.print(Panel(output, border_style="green"))
+        else:
+            console.print("[yellow]No models available.[/yellow]")
+    except Exception as e:
+        console.print(f"[red]Error retrieving models: {str(e)}[/red]")
 
 
 def get_multiline_input(console: Console) -> str:
-    """Get multiline input
-
-    Args:
-        console (Console): The console object for output.
-
-    Returns:
-        str: The multiline input provided by the user.
-    """
+    """Get multiline input with slash command support."""
     console.print(
         Panel(
             "Enter your task. Press [bold]Enter[/bold] twice to submit.\n"
-            "
-            " edit <line_number> - Edit a specific line\n"
-            " delete <line_number> - Delete a specific line\n"
-            " replace <search_str> <replace_str> - Replace text in all lines",
+            "Type /help for available commands",
             title="Multi-line Input",
             border_style="blue",
         )
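The /model and /setmodel handlers above reach the running agent through AgentRegistry instead of holding a direct reference to it. A minimal sketch of that lookup path, assuming an agent named "main_agent" has already been registered by the CLI; the model id is a placeholder, not a value taken from the package.

from quantalogic.agent_factory import AgentRegistry

try:
    agent = AgentRegistry.get_agent("main_agent")  # same lookup the /setmodel handler performs
    if agent:
        agent.model_name = "openrouter/deepseek/deepseek-chat"  # placeholder model id
        print(f"Model switched to {agent.model_name}")
except ValueError as exc:  # the handlers above catch ValueError the same way
    print(f"No agent available: {exc}")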
@@ -162,18 +180,174 @@ def get_multiline_input(console: Console) -> str:
     @bindings.add("c-z")
     def _(event):
         if history_manager.undo(lines):
-            console.print("[bold]Undo successful.[/bold]")
+            console.print("[bold #00cc66]Undo successful.[/bold #00cc66]")
+
+
+    from prompt_toolkit.completion import Completer, Completion
+    from prompt_toolkit.styles import Style
+
+    class CommandCompleter(Completer):
+        def get_completions(self, document, complete_event):
+            text = document.text_before_cursor
+            line_parts = text.split()
+
+            if len(line_parts) == 0 or (len(line_parts) == 1 and text.endswith(' ')):
+                for cmd, details in registry.commands.items():
+                    yield Completion(
+                        cmd,
+                        start_position=-len(text),
+                        display=f"[{cmd}] {details['help']}",
+                        style="fg:ansicyan bold",
+                    )
+            elif line_parts[0] in registry.commands:
+                cmd = registry.commands[line_parts[0]]
+                arg_index = len(line_parts) - 1
+                if arg_index == 1:
+                    doc = cmd['handler'].__doc__ or ""
+                    args_hint = next((line.split(':')[1].strip() for line in doc.split('\n')
+                                      if 'Args:' in line), "")
+                    if args_hint:
+                        yield Completion(
+                            f"<{args_hint}>",
+                            start_position=-len(text.split()[-1]),
+                            style="fg:ansimagenta italic",
+                            display=f"Expected argument: {args_hint}",
+                        )
+            else:
+                if text.startswith('/'):
+                    partial = text[1:].lstrip('/')
+                    exact_matches = []
+                    prefix_matches = []
+
+                    for cmd in registry.commands:
+                        cmd_without_slash = cmd[1:]
+                        if cmd_without_slash.lower() == partial.lower():
+                            exact_matches.append(cmd)
+                        elif cmd_without_slash.lower().startswith(partial.lower()):
+                            prefix_matches.append(cmd)
+
+                    # Prioritize exact matches first
+                    for match in exact_matches:
+                        remaining = match[len('/' + partial):]
+                        yield Completion(
+                            remaining,
+                            start_position=0,  # Corrected from -len(partial)
+                            display=f"{match} - {registry.commands[match]['help']}",
+                            style="fg:ansiyellow bold",
+                        )
+
+                    # Then prefix matches
+                    for match in prefix_matches:
+                        remaining = match[len('/' + partial):]
+                        yield Completion(
+                            remaining,
+                            start_position=0,  # Corrected from -len(partial)
+                            display=f"{match} - {registry.commands[match]['help']}",
+                            style="fg:ansiyellow bold",
+                        )
+        def get_completions(self, document, complete_event):
+            text = document.text_before_cursor
+            line_parts = text.split()
+
+            if len(line_parts) == 0 or (len(line_parts) == 1 and text.endswith(' ')):
+                for cmd, details in registry.commands.items():
+                    yield Completion(
+                        cmd,
+                        start_position=-len(text),
+                        display=f"[{cmd}] {details['help']}",
+                        style="fg:ansicyan bold",
+                    )
+            elif line_parts[0] in registry.commands:
+                cmd = registry.commands[line_parts[0]]
+                arg_index = len(line_parts) - 1
+                if arg_index == 1:
+                    doc = cmd['handler'].__doc__ or ""
+                    args_hint = next((line.split(':')[1].strip() for line in doc.split('\n')
+                                      if 'Args:' in line), "")
+                    if args_hint:
+                        yield Completion(
+                            f"<{args_hint}>",
+                            start_position=-len(text.split()[-1]),
+                            style="fg:ansimagenta italic",
+                            display=f"Expected argument: {args_hint}",
+                        )
+            else:
+                if text.startswith('/'):
+                    partial = text[1:].lstrip('/')
+                    exact_matches = []
+                    prefix_matches = []
+
+                    for cmd in registry.commands:
+                        cmd_without_slash = cmd[1:]
+                        if cmd_without_slash.lower() == partial.lower():
+                            exact_matches.append(cmd)
+                        elif cmd_without_slash.lower().startswith(partial.lower()):
+                            prefix_matches.append(cmd)
+
+                    for match in exact_matches:
+                        remaining = match[len('/' + partial):]
+                        yield Completion(
+                            remaining,
+                            start_position=-len(partial),
+                            display=f"{match} - {registry.commands[match]['help']}",
+                            style="fg:ansiyellow bold",
+                        )
+
+                    for match in prefix_matches:
+                        remaining = match[len('/' + partial):]
+                        yield Completion(
+                            remaining,
+                            start_position=-len(partial),
+                            display=f"{match} - {registry.commands[match]['help']}",
+                            style="fg:ansiyellow bold",
+                        )
+
+    command_completer = CommandCompleter()
+
+    def get_command_help(cmd_name: str) -> str:
+        """Get formatted help text for a command"""
+        if cmd := registry.commands.get(cmd_name):
+            return f"[bold]{cmd_name}[/bold]: {cmd['help']}\n{cmd['handler'].__doc__ or ''}"
+        return ""
 
-
+    custom_style = Style.from_dict({
+        'completion-menu.completion': 'bg:#005577 #ffffff',
+        'completion-menu.completion.current': 'bg:#007799 #ffffff bold',
+        'scrollbar.background': 'bg:#6699aa',
+        'scrollbar.button': 'bg:#444444',
+        'documentation': 'bg:#003366 #ffffff',
+    })
+
+    session = PromptSession(
+        history=InMemoryHistory(),
+        auto_suggest=AutoSuggestFromHistory(),
+        key_bindings=bindings,
+        completer=command_completer,
+        complete_while_typing=True,
+        style=custom_style,
+        bottom_toolbar=lambda: get_command_help(session.default_buffer.document.text.split()[0][1:]
+                                                if session.default_buffer.document.text.startswith('/')
+                                                else ""),
+    )
 
     try:
         while True:
             prompt_text = f"{line_number:>3}: "
             line = session.prompt(prompt_text, rprompt="Press Enter twice to submit")
 
-            # Handle commands with single return
             if line.strip().startswith('/'):
-
+                cmd_parts = line.strip().split()
+                cmd_name = cmd_parts[0].lower()
+                args = cmd_parts[1:]
+
+                if cmd_handler := registry.commands.get(cmd_name):
+                    try:
+                        cmd_handler["handler"](lines, args, console, session, history_manager)
+                    except Exception as e:
+                        console.print(f"[red]Error executing {cmd_name}: {str(e)}[/red]")
+                else:
+                    console.print(f"[red]Unknown command: {cmd_name}[/red]")
+                continue
 
             if line.strip() == "":
                 blank_lines += 1
@@ -181,7 +355,7 @@ def get_multiline_input(console: Console) -> str:
                     break
             else:
                 blank_lines = 0
-                if not
+                if not any(line.strip().startswith(cmd) for cmd in registry.commands):
                     history_manager.push_state(lines)
                     lines.append(line)
                 line_number += 1
@@ -191,4 +365,4 @@ def get_multiline_input(console: Console) -> str:
         console.print("\n[bold]Input cancelled by user.[/bold]")
         return ""
 
-    return "\n".join(lines)
+    return "\n".join(lines)
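Taken together, these editor changes turn get_multiline_input into a small command-aware prompt with slash commands, completion, and a help toolbar. A hedged usage sketch follows; the surrounding wiring is illustrative and not taken from the package.

from rich.console import Console

from quantalogic.interactive_text_editor import get_multiline_input

console = Console()
# The user can now type /help, /edit 2, /delete 3, /setmodel <name>, /models, ...
# while composing a task; two blank lines still submit the buffer.
task_text = get_multiline_input(console)
print(task_text)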
quantalogic/llm.py
ADDED
@@ -0,0 +1,135 @@
+"""LLM wrapper module for handling LiteLLM operations."""
+
+__all__ = [
+    "generate_completion",
+    "generate_image",
+    "count_tokens",
+    "get_model_max_input_tokens",
+    "get_model_max_output_tokens",
+]
+
+import os
+from typing import Any, Dict, List
+
+from litellm import (
+    completion,
+    image_generation,
+    token_counter,
+)
+from loguru import logger
+
+from quantalogic.get_model_info import (
+    get_max_input_tokens,
+    get_max_output_tokens,
+    model_info,
+)
+
+
+def get_model_info(model_name: str) -> dict | None:
+    """Get model information for a given model name."""
+    return model_info.get(model_name, None)
+
+
+def generate_completion(**kwargs: Dict[str, Any]) -> Any:
+    """Wraps litellm completion with proper type hints."""
+    model = kwargs.get("model", "")
+    if model.startswith("dashscope/"):
+        # Remove prefix and configure for OpenAI-compatible endpoint
+        kwargs["model"] = model.replace("dashscope/", "")
+        kwargs["custom_llm_provider"] = "openai"  # Explicitly specify OpenAI provider
+        kwargs["base_url"] = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
+        api_key = os.getenv("DASHSCOPE_API_KEY")
+        if not api_key:
+            raise ValueError("DASHSCOPE_API_KEY is not set in the environment variables.")
+        kwargs["api_key"] = api_key
+    return completion(**kwargs)
+
+
+def generate_image(**kwargs: Dict[str, Any]) -> Any:
+    """Wraps litellm image_generation with proper type hints."""
+    return image_generation(**kwargs)
+
+
+def count_tokens(model: str, messages: List[Dict[str, Any]]) -> int:
+    """Wraps litellm token_counter with proper type hints."""
+    return token_counter(model=model, messages=messages)
+
+
+def _get_model_info_impl(model_name: str) -> dict:
+    """Get information about the model with prefix fallback logic."""
+    original_model = model_name
+    tried_models = [model_name]
+
+    while True:
+        try:
+            logger.debug(f"Attempting to retrieve model info for: {model_name}")
+            # Try direct lookup from model_info dictionary first
+            if model_name in model_info:
+                logger.debug(f"Found model info for {model_name} in model_info")
+                return model_info[model_name]
+
+            # Try get_model_info as fallback
+            info = get_model_info(model_name)
+            if info:
+                logger.debug(f"Found model info for {model_name} via get_model_info")
+                return info
+        except Exception as e:
+            logger.debug(f"Failed to get model info for {model_name}: {str(e)}")
+            pass
+
+        # Try removing one prefix level
+        parts = model_name.split("/")
+        if len(parts) <= 1:
+            break
+        model_name = "/".join(parts[1:])
+        tried_models.append(model_name)
+
+    error_msg = f"Could not find model info for {original_model} after trying: {' → '.join(tried_models)}"
+    logger.error(error_msg)
+    raise ValueError(error_msg)
+
+
+def get_model_max_input_tokens(model_name: str) -> int | None:
+    """Get the maximum number of input tokens for the model."""
+    try:
+        # First try direct lookup
+        max_tokens = get_max_input_tokens(model_name)
+        if max_tokens is not None:
+            return max_tokens
+
+        # If not found, try getting from model info
+        model_info = _get_model_info_impl(model_name)
+        max_input = model_info.get("max_input_tokens")
+        if max_input is not None:
+            return max_input
+
+        # If still not found, log warning and return default
+        logger.warning(f"No max input tokens found for {model_name}. Using default.")
+        return 8192  # A reasonable default for many models
+
+    except Exception as e:
+        logger.error(f"Error getting max input tokens for {model_name}: {e}")
+        return None
+
+
+def get_model_max_output_tokens(model_name: str) -> int | None:
+    """Get the maximum number of output tokens for the model."""
+    try:
+        # First try direct lookup
+        max_tokens = get_max_output_tokens(model_name)
+        if max_tokens is not None:
+            return max_tokens
+
+        # If not found, try getting from model info
+        model_info = _get_model_info_impl(model_name)
+        max_output = model_info.get("max_output_tokens")
+        if max_output is not None:
+            return max_output
+
+        # If still not found, log warning and return default
+        logger.warning(f"No max output tokens found for {model_name}. Using default.")
+        return 4096  # A reasonable default for many models
+
+    except Exception as e:
+        logger.error(f"Error getting max output tokens for {model_name}: {e}")
+        return None
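The new quantalogic/llm.py centralises the LiteLLM calls and the DashScope rerouting behind a small wrapper API. A hedged usage sketch follows; the model ids are examples only, and whether a given id resolves depends on the tables in quantalogic.get_model_info.

from quantalogic.llm import count_tokens, generate_completion, get_model_max_input_tokens

messages = [{"role": "user", "content": "Say hello."}]

# Token limits come from quantalogic.get_model_info, with a prefix-stripping fallback;
# the call returns None if the model cannot be resolved at all.
print(get_model_max_input_tokens("gpt-4o-mini"))
print(count_tokens(model="gpt-4o-mini", messages=messages))

# Models prefixed with "dashscope/" are rewritten to the OpenAI-compatible
# DashScope endpoint and require DASHSCOPE_API_KEY to be set in the environment.
response = generate_completion(model="dashscope/qwen-max", messages=messages)
print(response.choices[0].message.content)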