quantalogic 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- quantalogic/__init__.py +3 -2
- quantalogic/agent.py +57 -36
- quantalogic/agent_config.py +18 -13
- quantalogic/coding_agent.py +6 -2
- quantalogic/{print_event.py → console_print_events.py} +1 -3
- quantalogic/console_print_token.py +16 -0
- quantalogic/docs_cli.py +50 -0
- quantalogic/generative_model.py +80 -77
- quantalogic/main.py +81 -15
- quantalogic/server/agent_server.py +2 -2
- quantalogic/tools/llm_tool.py +52 -11
- quantalogic/tools/llm_vision_tool.py +23 -7
- quantalogic/xml_parser.py +109 -49
- {quantalogic-0.2.16.dist-info → quantalogic-0.2.17.dist-info}/METADATA +18 -147
- {quantalogic-0.2.16.dist-info → quantalogic-0.2.17.dist-info}/RECORD +18 -16
- quantalogic-0.2.17.dist-info/entry_points.txt +6 -0
- quantalogic-0.2.16.dist-info/entry_points.txt +0 -3
- {quantalogic-0.2.16.dist-info → quantalogic-0.2.17.dist-info}/LICENSE +0 -0
- {quantalogic-0.2.16.dist-info → quantalogic-0.2.17.dist-info}/WHEEL +0 -0
quantalogic/generative_model.py
CHANGED
```diff
@@ -7,9 +7,12 @@ from litellm import completion, exceptions, get_max_tokens, get_model_info, token_counter
 from loguru import logger
 from pydantic import BaseModel, Field, field_validator
 
+from quantalogic.event_emitter import EventEmitter  # Importing the EventEmitter class
+
 MIN_RETRIES = 1
 
 
+# Define the Message class for conversation handling
 class Message(BaseModel):
     """Represents a message in a conversation with a specific role and content."""
 
@@ -70,21 +73,22 @@ class GenerativeModel:
         self,
         model: str = "ollama/qwen2.5-coder:14b",
         temperature: float = 0.7,
+        event_emitter: EventEmitter = None,  # EventEmitter instance
     ) -> None:
         """Initialize a generative model with configurable parameters.
 
-        Configure the generative model with specified model,
-        temperature, and maximum token settings.
-
         Args:
-            model: Model identifier.
-            ...
+            model: Model identifier. Defaults to "ollama/qwen2.5-coder:14b".
+            temperature: Temperature parameter for controlling randomness in generation.
+                Higher values (e.g. 0.8) make output more random, lower values (e.g. 0.2)
+                make it more deterministic. Defaults to 0.7.
+            event_emitter: Optional event emitter instance for handling asynchronous events
+                and callbacks during text generation. Defaults to None.
         """
         logger.debug(f"Initializing GenerativeModel with model={model}, temperature={temperature}")
         self.model = model
         self.temperature = temperature
+        self.event_emitter = event_emitter or EventEmitter()  # Initialize event emitter
         self._get_model_info_cached = functools.lru_cache(maxsize=32)(self._get_model_info_impl)
 
         # Define retriable exceptions based on LiteLLM's exception mapping
@@ -109,28 +113,20 @@ class GenerativeModel:
         exceptions.PermissionDeniedError,
     )
 
-    #
+    # Generate a response with conversation history and optional streaming
     def generate_with_history(
-        self, messages_history: list[Message], prompt: str, image_url: str | None = None
+        self, messages_history: list[Message], prompt: str, image_url: str | None = None, streaming: bool = False
     ) -> ResponseStats:
         """Generate a response with conversation history and optional image.
 
-        Generates a response based on previous conversation messages,
-        a new user prompt, and an optional image URL.
-
         Args:
             messages_history: Previous conversation messages.
             prompt: Current user prompt.
             image_url: Optional image URL for visual queries.
+            streaming: Whether to stream the response.
 
         Returns:
-            Detailed response statistics.
-
-        Raises:
-            openai.AuthenticationError: If authentication fails.
-            openai.InvalidRequestError: If the request is invalid (e.g., context length exceeded).
-            openai.APIError: For content policy violations or other API errors.
-            Exception: For other unexpected errors.
+            Detailed response statistics or a generator in streaming mode.
         """
         messages = [{"role": msg.role, "content": str(msg.content)} for msg in messages_history]
 
@@ -147,6 +143,10 @@ class GenerativeModel:
         else:
             messages.append({"role": "user", "content": str(prompt)})
 
+        if streaming:
+            self.event_emitter.emit("stream_start")  # Emit stream start event
+            return self._stream_response(messages)  # Return generator
+
         try:
             logger.debug(f"Generating response for prompt: {prompt}")
 
@@ -171,54 +171,68 @@ class GenerativeModel:
             )
 
         except Exception as e:
-            ...
-                raise openai.AuthenticationError(
-                    f"Authentication failed with provider {error_details['provider']}"
-                ) from e
-
-            # Handle context window errors
-            if isinstance(e, self.CONTEXT_EXCEPTIONS):
-                raise openai.InvalidRequestError(f"Context window exceeded or invalid request: {str(e)}") from e
-
-            # Handle content policy violations
-            if isinstance(e, self.POLICY_EXCEPTIONS):
-                raise openai.APIError(f"Content policy violation: {str(e)}") from e
-
-            # For other exceptions, preserve the original error type if it's from OpenAI
-            if isinstance(e, openai.OpenAIError):
-                raise
-
-            # Wrap unknown errors in APIError
-            raise openai.APIError(f"Unexpected error during generation: {str(e)}") from e
-
-    def generate(self, prompt: str, image_url: str | None = None) -> ResponseStats:
-        """Generate a response without conversation history.
+            self._handle_generation_exception(e)
+
+    def _stream_response(self, messages):
+        """Private method to handle streaming responses."""
+        try:
+            for chunk in completion(
+                temperature=self.temperature,
+                model=self.model,
+                messages=messages,
+                num_retries=MIN_RETRIES,
+                stream=True,  # Enable streaming
+            ):
+                if chunk.choices[0].delta.content is not None:
+                    self.event_emitter.emit("stream_chunk", chunk.choices[0].delta.content)
+                    yield chunk.choices[0].delta.content  # Yield each chunk of content
 
+            self.event_emitter.emit("stream_end")  # Emit stream end event
+        except Exception as e:
+            logger.error(f"Streaming error: {str(e)}")
+            raise
+
+    def generate(self, prompt: str, image_url: str | None = None, streaming: bool = False) -> ResponseStats:
+        """Generate a response without conversation history.
 
         Args:
             prompt: User prompt.
             image_url: Optional image URL for visual queries.
+            streaming: Whether to stream the response.
 
         Returns:
-            Detailed response statistics.
+            Detailed response statistics or a generator in streaming mode.
         """
-        return self.generate_with_history([], prompt, image_url)
+        return self.generate_with_history([], prompt, image_url, streaming)
+
+    def _handle_generation_exception(self, e):
+        """Handle exceptions during generation."""
+        error_details = {
+            "error_type": type(e).__name__,
+            "message": str(e),
+            "model": self.model,
+            "provider": getattr(e, "llm_provider", "unknown"),
+            "status_code": getattr(e, "status_code", None),
+        }
+
+        logger.error("LLM Generation Error: {}", error_details)
+        logger.debug(f"Error details: {error_details}")
+        logger.debug(f"Model: {self.model}, Temperature: {self.temperature}")
+
+        if isinstance(e, self.AUTH_EXCEPTIONS):
+            logger.debug("Authentication error occurred")
+            raise openai.AuthenticationError(f"Authentication failed with provider {error_details['provider']}") from e
+
+        if isinstance(e, self.CONTEXT_EXCEPTIONS):
+            raise openai.InvalidRequestError(f"Context window exceeded or invalid request: {str(e)}") from e
+
+        if isinstance(e, self.POLICY_EXCEPTIONS):
+            raise openai.APIError(f"Content policy violation: {str(e)}") from e
+
+        if isinstance(e, openai.OpenAIError):
+            raise
+
+        raise openai.APIError(f"Unexpected error during generation: {str(e)}") from e
 
     def get_max_tokens(self) -> int:
         """Get the maximum number of tokens that can be generated by the model."""
@@ -239,17 +253,9 @@ class GenerativeModel:
         return token_counter(model=self.model, messages=litellm_messages)
 
     def _get_model_info_impl(self, model_name: str) -> dict:
-        """Get information about the model with prefix fallback logic.
-
-        Attempts to find model info by progressively removing provider prefixes.
-        Raises ValueError if no valid model configuration is found.
-        Results are cached to improve performance.
-
-        Example:
-            openrouter/openai/gpt-4o-mini → openai/gpt-4o-mini → gpt-4o-mini
-        """
+        """Get information about the model with prefix fallback logic."""
         original_model = model_name
-
+
         while True:
             try:
                 logger.debug(f"Attempting to retrieve model info for: {model_name}")
@@ -259,22 +265,19 @@ class GenerativeModel:
                 return model_info
             except Exception:
                 pass
-
+
             # Try removing one prefix level
-            parts = model_name.split(
+            parts = model_name.split("/")
             if len(parts) <= 1:
                 break
-            model_name =
-
+            model_name = "/".join(parts[1:])
+
         error_msg = f"Could not find model info for {original_model} after trying: {self.model} → {model_name}"
         logger.error(error_msg)
         raise ValueError(error_msg)
 
     def get_model_info(self, model_name: str = None) -> dict:
-        """Get cached information about the model.
-
-        If no model name is provided, uses the current model.
-        """
+        """Get cached information about the model."""
         if model_name is None:
             model_name = self.model
         return self._get_model_info_cached(model_name)
```
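Taken together, these changes give `GenerativeModel` an opt-in streaming path: calling `generate_with_history(..., streaming=True)` emits `stream_start` and returns a generator whose iteration fires `stream_chunk` for each piece of content and `stream_end` at the end. A minimal usage sketch built only from the code above; the prompt text is made up, and the listener signature is copied from `handle_stream_chunk()` in `main.py`, so adjust it if your `EventEmitter` passes only the payload:

```python
from quantalogic.generative_model import GenerativeModel, Message


def on_chunk(event: str, data: str) -> None:
    """Same listener shape as handle_stream_chunk() in main.py (assumed)."""
    if data is not None:
        print(data, end="", flush=True)


# The default model from the diff; any litellm-style identifier should work here.
model = GenerativeModel(model="ollama/qwen2.5-coder:14b", temperature=0.7)
model.event_emitter.on("stream_chunk", on_chunk)

history = [Message(role="system", content="You are a concise assistant.")]

# With streaming=True the call emits stream_start immediately and returns a
# generator of text chunks instead of ResponseStats; iterating it drives the
# stream_chunk and stream_end events.
chunks = model.generate_with_history(history, "Say hello in five words.", streaming=True)
full_text = "".join(chunks)
```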
quantalogic/main.py
CHANGED
```diff
@@ -10,15 +10,20 @@ from typing import Optional
 import click
 from loguru import logger
 
+from quantalogic.console_print_events import console_print_events
+from quantalogic.console_print_token import console_print_token
 from quantalogic.utils.check_version import check_if_is_latest_version
 from quantalogic.version import get_version
 
 # Configure logger
 logger.remove()  # Remove default logger
 
+from threading import Lock  # noqa: E402
+
 from rich.console import Console  # noqa: E402
 from rich.panel import Panel  # noqa: E402
 from rich.prompt import Confirm  # noqa: E402
+from rich.spinner import Spinner  # noqa: E402
 
 from quantalogic.agent import Agent  # noqa: E402
 
@@ -31,30 +36,31 @@ from quantalogic.agent_config import (  # noqa: E402
     create_orchestrator_agent,
 )
 from quantalogic.interactive_text_editor import get_multiline_input  # noqa: E402
-from quantalogic.print_event import console_print_events  # noqa: E402
 from quantalogic.search_agent import create_search_agent  # noqa: E402
 
 AGENT_MODES = ["code", "basic", "interpreter", "full", "code-basic", "search", "search-full"]
 
 
-def create_agent_for_mode(mode: str, model_name: str, vision_model_name: str | None) -> Agent:
+def create_agent_for_mode(mode: str, model_name: str, vision_model_name: str | None, no_stream: bool = False) -> Agent:
     """Create an agent based on the specified mode."""
     logger.debug(f"Creating agent for mode: {mode} with model: {model_name}")
+    logger.debug(f"Using vision model: {vision_model_name}")
+    logger.debug(f"Using no_stream: {no_stream}")
     if mode == "code":
         logger.debug("Creating code agent without basic mode")
-        return create_coding_agent(model_name, vision_model_name, basic=False)
+        return create_coding_agent(model_name, vision_model_name, basic=False, no_stream=no_stream)
     if mode == "code-basic":
-        return create_coding_agent(model_name, vision_model_name, basic=True)
+        return create_coding_agent(model_name, vision_model_name, basic=True, no_stream=no_stream)
     elif mode == "basic":
-        return create_orchestrator_agent(model_name, vision_model_name)
+        return create_orchestrator_agent(model_name, vision_model_name, no_stream=no_stream)
     elif mode == "full":
-        return create_full_agent(model_name, vision_model_name)
+        return create_full_agent(model_name, vision_model_name, no_stream=no_stream)
     elif mode == "interpreter":
-        return create_interpreter_agent(model_name, vision_model_name)
+        return create_interpreter_agent(model_name, vision_model_name, no_stream=no_stream)
     elif mode == "search":
-        return create_search_agent(model_name)
+        return create_search_agent(model_name, no_stream=no_stream)
     if mode == "search-full":
-        return create_search_agent(model_name, mode_full=True)
+        return create_search_agent(model_name, mode_full=True, no_stream=no_stream)
     else:
         raise ValueError(f"Unknown agent mode: {mode}")
 
@@ -126,6 +132,27 @@ def get_task_from_file(file_path: str) -> str:
         raise Exception(f"Unexpected error reading file: {e}")
 
 
+# Spinner control
+spinner_lock = Lock()
+current_spinner = None
+
+def start_spinner(console: Console) -> None:
+    """Start the thinking spinner."""
+    global current_spinner
+    with spinner_lock:
+        if current_spinner is None:
+            current_spinner = console.status("[yellow]Thinking...", spinner="dots")
+            current_spinner.start()
+
+def stop_spinner(console: Console) -> None:
+    """Stop the thinking spinner."""
+    global current_spinner
+    with spinner_lock:
+        if current_spinner is not None:
+            current_spinner.stop()
+            current_spinner = None
+
+
 def display_welcome_message(
     console: Console, model_name: str, vision_model_name: str | None, max_iterations: int = 50
 ) -> None:
@@ -159,7 +186,7 @@ def display_welcome_message(
 @click.option(
     "--model-name",
     default=MODEL_NAME,
-    help='Specify the model to use (litellm format, e.g. "openrouter/deepseek-chat").',
+    help='Specify the model to use (litellm format, e.g. "openrouter/deepseek/deepseek-chat").',
 )
 @click.option(
     "--log",
@@ -213,7 +240,7 @@ def cli(
 @click.option(
     "--model-name",
     default=MODEL_NAME,
-    help='Specify the model to use (litellm format, e.g. "openrouter/deepseek-chat").',
+    help='Specify the model to use (litellm format, e.g. "openrouter/deepseek/deepseek-chat").',
 )
 @click.option("--verbose", is_flag=True, help="Enable verbose output.")
 @click.option("--mode", type=click.Choice(AGENT_MODES), default="code", help="Agent mode (code/search/full).")
@@ -234,6 +261,11 @@ def cli(
     default=30,
     help="Maximum number of iterations for task solving (default: 30).",
 )
+@click.option(
+    "--no-stream",
+    is_flag=True,
+    help="Disable streaming output (default: streaming enabled).",
+)
 @click.argument("task", required=False)
 def task(
     file: Optional[str],
@@ -244,6 +276,7 @@ def task(
     vision_model_name: str | None,
     task: Optional[str],
     max_iterations: int,
+    no_stream: bool,
 ) -> None:
     """Execute a task with the QuantaLogic AI Assistant."""
     console = Console()
@@ -286,9 +319,13 @@ def task(
             )
         )
 
-        logger.debug(
-            ...
+        logger.debug(
+            f"Creating agent for mode: {mode} with model: {model_name}, vision model: {vision_model_name}, no_stream: {no_stream}"
+        )
+        agent = create_agent_for_mode(mode, model_name, vision_model_name=vision_model_name, no_stream=no_stream)
+        logger.debug(
+            f"Created agent for mode: {mode} with model: {model_name}, vision model: {vision_model_name}, no_stream: {no_stream}"
+        )
 
         events = [
             "task_start",
@@ -302,16 +339,45 @@ def task(
             "memory_compacted",
             "memory_summary",
         ]
+        # Add spinner control to event handlers
+        def handle_task_think_start(*args, **kwargs):
+            start_spinner(console)
+
+        def handle_task_think_end(*args, **kwargs):
+            stop_spinner(console)
+
+        def handle_stream_chunk(event: str, data: str) -> None:
+            if current_spinner:
+                stop_spinner(console)
+            if data is not None:
+                console.print(data, end="", markup=False)
+
         agent.event_emitter.on(
             event=events,
             listener=console_print_events,
         )
+
+        agent.event_emitter.on(
+            event="task_think_start",
+            listener=handle_task_think_start,
+        )
+
+        agent.event_emitter.on(
+            event="task_think_end",
+            listener=handle_task_think_end,
+        )
+
+        agent.event_emitter.on(
+            event="stream_chunk",
+            listener=handle_stream_chunk,
+        )
+
         logger.debug("Registered event handlers for agent events with events: {events}")
 
         logger.debug(f"Solving task with agent: {task_content}")
         if max_iterations < 1:
             raise ValueError("max_iterations must be greater than 0")
-        result = agent.solve_task(task=task_content, max_iterations=max_iterations)
+        result = agent.solve_task(task=task_content, max_iterations=max_iterations, streaming=not no_stream)
         logger.debug(f"Task solved with result: {result} using {max_iterations} iterations")
 
         console.print(
```
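The CLI now streams by default (`--no-stream` turns it off) and wires spinner and token handlers onto the agent's event emitter. For reference, a rough programmatic equivalent of that wiring, assembled only from calls that appear in this diff (`create_coding_agent`, `event_emitter.on`, `solve_task`); the model name, task text, and the event subset are placeholders:

```python
from quantalogic.agent_config import create_coding_agent
from quantalogic.console_print_events import console_print_events


def print_token(event: str, data: str) -> None:
    # Same shape as handle_stream_chunk() in the CLI above (assumed listener signature).
    if data is not None:
        print(data, end="", flush=True)


# Placeholder model; no vision model; streaming left enabled (no_stream=False).
agent = create_coding_agent("openrouter/deepseek/deepseek-chat", None, basic=False, no_stream=False)

# A subset of the event names visible in the diff, plus the new stream_chunk hook.
agent.event_emitter.on(event=["task_start", "memory_compacted", "memory_summary"], listener=console_print_events)
agent.event_emitter.on(event="stream_chunk", listener=print_token)

# Mirrors the CLI's `streaming=not no_stream`.
result = agent.solve_task(task="Write a one-line docstring for a hello() function.", max_iterations=30, streaming=True)
print(result)
```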
quantalogic/server/agent_server.py
CHANGED
```diff
@@ -30,7 +30,7 @@ from quantalogic.agent_config import (
     create_coding_agent,  # noqa: F401
     create_orchestrator_agent,  # noqa: F401
 )
-from quantalogic.print_event import console_print_events
+from quantalogic.console_print_events import console_print_events
 
 # Configure logger
 logger.remove()
@@ -246,7 +246,7 @@ class AgentState:
     def initialize_agent_with_sse_validation(self, model_name: str = MODEL_NAME):
         """Initialize agent with SSE-based user validation."""
         try:
-            self.agent = create_agent(model_name)
+            self.agent = create_agent(model_name, None)
 
             # Comprehensive list of agent events to track
             agent_events = [
```
quantalogic/tools/llm_tool.py
CHANGED
```diff
@@ -1,9 +1,11 @@
 """LLM Tool for generating answers to questions using a language model."""
 
-import
+from typing import Callable
 
+from loguru import logger
 from pydantic import ConfigDict, Field
 
+from quantalogic.console_print_token import console_print_token
 from quantalogic.generative_model import GenerativeModel, Message
 from quantalogic.tools.tool import Tool, ToolArgument
 
@@ -53,15 +55,42 @@ class LLMTool(Tool):
     )
 
     model_name: str = Field(..., description="The name of the language model to use")
-    generative_model: GenerativeModel | None = Field(default=None)
     system_prompt: str | None = Field(default=None)
+    on_token: Callable | None = Field(default=None, exclude=True)
+    generative_model: GenerativeModel | None = Field(default=None, exclude=True)
+
+    def __init__(
+        self,
+        model_name: str,
+        system_prompt: str | None = None,
+        on_token: Callable | None = None,
+        name: str = "llm_tool",
+        generative_model: GenerativeModel | None = None,
+    ):
+        # Use dict to pass validated data to parent constructor
+        super().__init__(
+            **{
+                "model_name": model_name,
+                "system_prompt": system_prompt,
+                "on_token": on_token,
+                "name": name,
+                "generative_model": generative_model,
+            }
+        )
+
+        # Initialize the generative model
+        self.model_post_init(None)
 
     def model_post_init(self, __context):
         """Initialize the generative model after model initialization."""
         if self.generative_model is None:
             self.generative_model = GenerativeModel(model=self.model_name)
-
+            logger.debug(f"Initialized LLMTool with model: {self.model_name}")
 
+        # Only set up event listener if on_token is provided
+        if self.on_token is not None:
+            logger.debug(f"Setting up event listener for LLMTool with model: {self.model_name}")
+            self.generative_model.event_emitter.on("stream_chunk", self.on_token)
 
     def execute(
         self, system_prompt: str | None = None, prompt: str | None = None, temperature: str | None = None
@@ -85,7 +114,7 @@ class LLMTool(Tool):
             if not (0.0 <= temp <= 1.0):
                 raise ValueError("Temperature must be between 0 and 1.")
         except ValueError as ve:
-            ...
+            logger.error(f"Invalid temperature value: {temperature}")
             raise ValueError(f"Invalid temperature value: {temperature}") from ve
 
         used_system_prompt = self.system_prompt if self.system_prompt else system_prompt
@@ -96,20 +125,29 @@ class LLMTool(Tool):
             Message(role="user", content=prompt),
         ]
 
+        is_streaming = self.on_token is not None
+
         # Set the model's temperature
         if self.generative_model:
             self.generative_model.temperature = temp
 
         # Generate the response using the generative model
         try:
-            ...
-                messages_history=messages_history, prompt=
+            result = self.generative_model.generate_with_history(
+                messages_history=messages_history, prompt=prompt, streaming=is_streaming
             )
-            ...
+
+            if is_streaming:
+                response = ""
+                for chunk in result:
+                    response += chunk
+            else:
+                response = result.response
+
+            logger.debug(f"Generated response: {response}")
             return response
         except Exception as e:
-            ...
+            logger.error(f"Error generating response: {e}")
             raise Exception(f"Error generating response: {e}") from e
         else:
             raise ValueError("Generative model not initialized")
@@ -123,6 +161,9 @@ if __name__ == "__main__":
     temperature = "0.7"
     answer = tool.execute(system_prompt=system_prompt, prompt=question, temperature=temperature)
     print(answer)
-    pirate = LLMTool(
+    pirate = LLMTool(
+        model_name="openrouter/openai/gpt-4o-mini", system_prompt="You are a pirate.", on_token=console_print_token
+    )
     pirate_answer = pirate.execute(system_prompt=system_prompt, prompt=question, temperature=temperature)
-    print(
+    print("\n")
+    print(f"Anwser: {pirate_answer}")
```
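The new `__init__` and `on_token` hook mean callers can get token-by-token output from `LLMTool` without touching the event emitter directly: when `on_token` is set, `execute()` streams internally, forwards each `stream_chunk` to the callback, and still returns the concatenated text. A small sketch along the lines of the `__main__` block above (the model name and prompts are placeholders):

```python
from quantalogic.console_print_token import console_print_token
from quantalogic.tools.llm_tool import LLMTool

# on_token subscribes console_print_token to the model's stream_chunk events,
# so tokens are printed as they arrive; execute() still returns the full answer.
tool = LLMTool(
    model_name="openrouter/openai/gpt-4o-mini",  # placeholder model
    system_prompt="You are a concise assistant.",
    on_token=console_print_token,
)

answer = tool.execute(prompt="Explain streaming in one sentence.", temperature="0.7")
print()
print(answer)
```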
quantalogic/tools/llm_vision_tool.py
CHANGED
```diff
@@ -1,8 +1,8 @@
 """LLM Vision Tool for analyzing images using a language model."""
 
-import logging
 from typing import Optional
 
+from loguru import logger
 from pydantic import ConfigDict, Field
 
 from quantalogic.generative_model import GenerativeModel, Message
@@ -65,7 +65,12 @@ class LLMVisionTool(Tool):
         """Initialize the generative model after model initialization."""
         if self.generative_model is None:
             self.generative_model = GenerativeModel(model=self.model_name)
-
+            logger.debug(f"Initialized LLMVisionTool with model: {self.model_name}")
+
+        # Only set up event listener if on_token is provided
+        if self.on_token is not None:
+            logger.debug(f"Setting up event listener for LLMVisionTool with model: {self.model_name}")
+            self.generative_model.event_emitter.on("stream_chunk", self.on_token)
 
     def execute(self, system_prompt: str, prompt: str, image_url: str, temperature: str = "0.7") -> str:
         """Execute the tool to analyze an image and generate a response.
@@ -88,7 +93,7 @@ class LLMVisionTool(Tool):
             if not (0.0 <= temp <= 1.0):
                 raise ValueError("Temperature must be between 0 and 1.")
         except ValueError as ve:
-            ...
+            logger.error(f"Invalid temperature value: {temperature}")
            raise ValueError(f"Invalid temperature value: {temperature}") from ve
 
         if not image_url.startswith(("http://", "https://")):
@@ -105,14 +110,25 @@ class LLMVisionTool(Tool):
         self.generative_model.temperature = temp
 
         try:
+            is_streaming = self.on_token is not None
             response_stats = self.generative_model.generate_with_history(
-                messages_history=messages_history,
+                messages_history=messages_history,
+                prompt=prompt,
+                image_url=image_url,
+                streaming=is_streaming
             )
-            ...
+
+            if is_streaming:
+                response = ""
+                for chunk in response_stats:
+                    response += chunk
+            else:
+                response = response_stats.response.strip()
+
+            logger.info(f"Generated response: {response}")
             return response
         except Exception as e:
-            ...
+            logger.error(f"Error generating response: {e}")
             raise Exception(f"Error generating response: {e}") from e
 
 
```