quantalogic 0.30.6__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/__init__.py +17 -7
- quantalogic/agent.py +87 -37
- quantalogic/agent_config.py +10 -0
- quantalogic/agent_factory.py +66 -11
- quantalogic/config.py +15 -0
- quantalogic/generative_model.py +17 -98
- quantalogic/get_model_info.py +26 -0
- quantalogic/interactive_text_editor.py +276 -102
- quantalogic/llm.py +135 -0
- quantalogic/main.py +60 -11
- quantalogic/prompts.py +66 -41
- quantalogic/task_runner.py +26 -39
- quantalogic/tool_manager.py +110 -33
- quantalogic/tools/replace_in_file_tool.py +1 -1
- quantalogic/tools/search_definition_names.py +2 -0
- quantalogic/tools/sql_query_tool.py +4 -2
- quantalogic/utils/get_all_models.py +20 -0
- {quantalogic-0.30.6.dist-info → quantalogic-0.31.0.dist-info}/METADATA +6 -1
- {quantalogic-0.30.6.dist-info → quantalogic-0.31.0.dist-info}/RECORD +22 -19
- {quantalogic-0.30.6.dist-info → quantalogic-0.31.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.30.6.dist-info → quantalogic-0.31.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.30.6.dist-info → quantalogic-0.31.0.dist-info}/entry_points.txt +0 -0
quantalogic/__init__.py
CHANGED
@@ -1,4 +1,5 @@
-
+"""QuantaLogic package initialization."""
+
 import warnings
 
 # Suppress specific warnings related to Pydantic's V2 configuration changes
@@ -9,13 +10,22 @@ warnings.filterwarnings(
     message=".*config keys have changed in V2:.*|.*'fields' config key is removed in V2.*",
 )
 
-
+# Import public API
+from .llm import generate_completion, generate_image, count_tokens  # noqa: E402
 from .agent import Agent  # noqa: E402
-from .console_print_events import console_print_events  # noqa: E402
-from .console_print_token import console_print_token  # noqa: E402
 from .event_emitter import EventEmitter  # noqa: E402
 from .memory import AgentMemory, VariableMemory  # noqa: E402
+from .console_print_events import console_print_events  # noqa: E402
+from .console_print_token import console_print_token  # noqa: E402
 
-
-
-
+__all__ = [
+    "Agent",
+    "EventEmitter",
+    "AgentMemory",
+    "VariableMemory",
+    "console_print_events",
+    "console_print_token",
+    "generate_completion",
+    "generate_image",
+    "count_tokens"
+]
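With the new `__all__` and re-exports, the low-level LLM helpers are importable from the package root. A minimal sketch of the new surface; the keyword names follow the internal call sites visible in the generative_model.py diff below, while the model name and message content are illustrative, not taken from this diff:

    from quantalogic import count_tokens, generate_completion

    # Hypothetical usage of the re-exported helpers.
    messages = [{"role": "user", "content": "Say hello"}]
    n = count_tokens(model="deepseek/deepseek-chat", messages=messages)
    reply = generate_completion(temperature=0.7, model="deepseek/deepseek-chat", messages=messages)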
quantalogic/agent.py
CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
 from typing import Any
 
 from loguru import logger
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, PrivateAttr
 
 from quantalogic.event_emitter import EventEmitter
 from quantalogic.generative_model import GenerativeModel, ResponseStats, TokenUsage
@@ -52,12 +52,16 @@ class ObserveResponseResult(BaseModel):
 class Agent(BaseModel):
     """Enhanced QuantaLogic agent implementing ReAct framework."""
 
-    model_config = ConfigDict(
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        validate_assignment=True,
+        extra="forbid"
+    )
 
     specific_expertise: str
     model: GenerativeModel
-    memory: AgentMemory = AgentMemory()
-    variable_store: VariableMemory = VariableMemory()
+    memory: AgentMemory = AgentMemory()  # A list User / Assistant Messages
+    variable_store: VariableMemory = VariableMemory()  # A dictionary of variables
     tools: ToolManager = ToolManager()
     event_emitter: EventEmitter = EventEmitter()
     config: AgentConfig
@@ -71,8 +75,9 @@ class Agent(BaseModel):
     max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS
     max_iterations: int = 30
     system_prompt: str = ""
-    compact_every_n_iterations: int | None = None
-    max_tokens_working_memory: int | None = None
+    compact_every_n_iterations: int | None = None
+    max_tokens_working_memory: int | None = None
+    _model_name: str = PrivateAttr(default="")
 
     def __init__(
         self,
@@ -84,17 +89,18 @@ class Agent(BaseModel):
         task_to_solve: str = "",
         specific_expertise: str = "General AI assistant with coding and problem-solving capabilities",
         get_environment: Callable[[], str] = get_environment,
-        compact_every_n_iterations: int | None = None,
-        max_tokens_working_memory: int | None = None,
+        compact_every_n_iterations: int | None = None,
+        max_tokens_working_memory: int | None = None,
     ):
         """Initialize the agent with model, memory, tools, and configurations."""
         try:
             logger.debug("Initializing agent...")
-
+
+            # Create event emitter
             event_emitter = EventEmitter()
 
             # Add TaskCompleteTool to the tools list if not already present
-            if TaskCompleteTool
+            if not any(isinstance(t, TaskCompleteTool) for t in tools):
                 tools.append(TaskCompleteTool())
 
             tool_manager = ToolManager(tools={tool.name: tool for tool in tools})
@@ -114,32 +120,50 @@
                 system_prompt=system_prompt_text,
             )
 
-
+            # Initialize using Pydantic's model_validate
             super().__init__(
+                specific_expertise=specific_expertise,
                 model=GenerativeModel(model=model_name, event_emitter=event_emitter),
                 memory=memory,
                 variable_store=variable_store,
                 tools=tool_manager,
+                event_emitter=event_emitter,
                 config=config,
-                ask_for_user_validation=ask_for_user_validation,
                 task_to_solve=task_to_solve,
-
-
+                task_to_solve_summary="",
+                ask_for_user_validation=ask_for_user_validation,
+                last_tool_call={},
+                total_tokens=0,
+                current_iteration=0,
+                max_input_tokens=DEFAULT_MAX_INPUT_TOKENS,
+                max_output_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
+                max_iterations=30,
+                system_prompt="",
+                compact_every_n_iterations=compact_every_n_iterations or 30,
+                max_tokens_working_memory=max_tokens_working_memory,
             )
 
-
-            self.compact_every_n_iterations = compact_every_n_iterations or self.max_iterations
-            logger.debug(f"Memory will be compacted every {self.compact_every_n_iterations} iterations")
+            self._model_name = model_name
 
-
-            self.max_tokens_working_memory = max_tokens_working_memory
+            logger.debug(f"Memory will be compacted every {self.compact_every_n_iterations} iterations")
             logger.debug(f"Max tokens for working memory set to: {self.max_tokens_working_memory}")
-
             logger.debug("Agent initialized successfully.")
         except Exception as e:
             logger.error(f"Failed to initialize agent: {str(e)}")
             raise
 
+    @property
+    def model_name(self) -> str:
+        """Get the current model name."""
+        return self._model_name
+
+    @model_name.setter
+    def model_name(self, value: str) -> None:
+        """Set the model name."""
+        self._model_name = value
+        # Update the model instance with the new name
+        self.model = GenerativeModel(model=value, event_emitter=self.event_emitter)
+
     def clear_memory(self):
         """Clear the memory and reset the session."""
         self._reset_session(clear_memory=True)
@@ -533,7 +557,10 @@
         question_validation: str = (
             "Do you permit the execution of this tool?\n"
             f"Tool: {tool_name}\n"
-
+            "Arguments:\n"
+            "<arguments>\n"
+            + "\n".join([f"    <{key}>{value}</{key}>" for key, value in arguments_with_values.items()])
+            + "\n</arguments>\n"
             "Yes or No"
         )
         permission_granted = self.ask_for_user_validation(question_validation)
@@ -559,20 +586,24 @@
             key: self._interpolate_variables(value) for key, value in arguments_with_values.items()
         }
 
-
-        try:
-            converted_args = self.tools._convert_kwargs_types(tool_name, **arguments_with_values_interpolated)
-        except ValueError as e:
-            logger.error(f"Type conversion failed: {str(e)}")
-            return "", f"Error: Type conversion failed for tool '{tool_name}': {str(e)}"
+        arguments_with_values_interpolated = arguments_with_values_interpolated
 
         # test if tool need variables in context
         if tool.need_variables:
             # Inject variables into the tool if needed
-
+            arguments_with_values_interpolated["variables"] = self.variable_store
         if tool.need_caller_context_memory:
             # Inject caller context into the tool if needed
-
+            arguments_with_values_interpolated["caller_context_memory"] = self.memory.memory
+
+        try:
+            # Convert arguments to proper types
+            converted_args = self.tools.validate_and_convert_arguments(
+                tool_name,
+                arguments_with_values_interpolated
+            )
+        except ValueError as e:
+            return "", f"Argument Error: {str(e)}"
 
         # Add injectable variables
         injectable_properties = tool.get_injectable_properties_in_execution()
@@ -599,10 +630,14 @@
         return executed_tool, response
 
     def _interpolate_variables(self, text: str) -> str:
-        """Interpolate variables using $
+        """Interpolate variables using $var$ syntax in the given text."""
         try:
+            import re
             for var in self.variable_store.keys():
-
+                # Escape the variable name for regex, but use raw value for replacement
+                pattern = rf'\${re.escape(var)}\$'
+                replacement = self.variable_store[var]
+                text = re.sub(pattern, replacement, text)
             return text
         except Exception as e:
             logger.error(f"Error in _interpolate_variables: {str(e)}")
@@ -641,6 +676,7 @@ class Agent(BaseModel):
             "1. Select ONE tool per message\n"
             "2. You will receive the tool's output in the next user response\n"
             "3. Choose the most appropriate tool for each step\n"
+            "4. Use task_complete tool to confirm task completion\n"
         )
         return prompt_use_tools
 
@@ -702,23 +738,32 @@
         return summary.response
 
     def _generate_task_summary(self, content: str) -> str:
-        """Generate a concise summary
+        """Generate a concise task-focused summary using the generative model.
 
         Args:
             content (str): The content to summarize
 
         Returns:
-            str: Generated summary
+            str: Generated task summary
         """
         try:
            prompt = (
-                "
-
-                "
+                "Create an ultra-concise task summary that captures ONLY: \n"
+                "1. Primary objective/purpose\n"
+                "2. Core actions/requirements\n"
+                "3. Desired end-state/outcome\n\n"
+                "Guidelines:\n"
+                "- Use imperative voice\n"
+                "- Exclude background, explanations, and examples\n"
+                "- Compress information using semantic density\n"
+                "- Strict 2-3 sentence maximum (under 50 words)\n"
+                "- Format: 'Concise Task Summary: [Your summary]'\n\n"
+                f"Input Task Description:\n{content}\n\n"
+                "Concise Task Summary:"
            )
            result = self.model.generate(prompt=prompt)
            logger.debug(f"Generated summary: {result.response}")
-            return result.response
+            return result.response.strip() + "\n🚨 The FULL task is in <task> tag in the previous messages.\n"
        except Exception as e:
            logger.error(f"Error generating summary: {str(e)}")
            return f"Summary generation failed: {str(e)}"
@@ -743,3 +788,8 @@ class Agent(BaseModel):
             "session_add_message",
             {"role": "assistant", "content": assistant_content},
         )
+
+    def update_model(self, new_model_name: str) -> None:
+        """Update the model name and recreate the model instance."""
+        self.model_name = new_model_name
+        self.model = GenerativeModel(model=new_model_name, event_emitter=self.event_emitter)
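The new `model_name` property and the `update_model` helper both rebuild the underlying `GenerativeModel` with the agent's existing `event_emitter`, so a live agent can be re-pointed at a different backend. A sketch assuming an already constructed `agent`; the model names are illustrative:

    # Either path recreates the GenerativeModel behind the agent.
    agent.model_name = "openai/gpt-4o-mini"        # property setter
    agent.update_model("deepseek/deepseek-chat")   # convenience wrapper
    print(agent.model_name)                        # "deepseek/deepseek-chat"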
quantalogic/agent_config.py
CHANGED
@@ -37,6 +37,14 @@ load_dotenv()
 MODEL_NAME = "deepseek/deepseek-chat"
 
 
+_current_model_name: str = ""
+
+def get_current_model() -> str:
+    """Retrieve the currently active model name."""
+    if not _current_model_name:
+        raise ValueError("No model initialized")
+    return _current_model_name
+
 def create_agent(
     model_name: str,
     vision_model_name: str | None,
@@ -44,6 +52,8 @@ def create_agent(
     compact_every_n_iteration: int | None = None,
     max_tokens_working_memory: int | None = None
 ) -> Agent:
+    global _current_model_name
+    _current_model_name = model_name
     """Create an agent with the specified model and tools.
 
     Args:
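`create_agent` now records the chosen model in the module-level `_current_model_name`, which `get_current_model` exposes, raising `ValueError` before any agent has been created. An illustrative flow; the argument values are hypothetical:

    from quantalogic.agent_config import create_agent, get_current_model

    agent = create_agent("deepseek/deepseek-chat", vision_model_name=None)
    assert get_current_model() == "deepseek/deepseek-chat"  # recorded as a side effect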
quantalogic/agent_factory.py
CHANGED
@@ -1,6 +1,4 @@
-
-
-from typing import Optional
+from typing import Dict, Optional
 
 from loguru import logger
 
@@ -11,7 +9,57 @@ from quantalogic.agent_config import (
     create_interpreter_agent,
 )
 from quantalogic.coding_agent import create_coding_agent
-from quantalogic.search_agent import create_search_agent
+from quantalogic.search_agent import create_search_agent  # noqa: E402
+
+
+class AgentRegistry:
+    """Registry for managing agent instances by name."""
+
+    _instance = None
+    _agents: Dict[str, Agent] = {}
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    @classmethod
+    def register_agent(cls, name: str, agent: Agent) -> None:
+        """Register an agent instance with a name.
+
+        Args:
+            name: Unique name for the agent
+            agent: Agent instance to register
+        """
+        if name in cls._agents:
+            raise ValueError(f"Agent with name {name} already exists")
+        cls._agents[name] = agent
+
+    @classmethod
+    def get_agent(cls, name: str) -> Agent:
+        """Retrieve a registered agent by name.
+
+        Args:
+            name: Name of the agent to retrieve
+
+        Returns:
+            Registered Agent instance
+
+        Raises:
+            KeyError: If no agent with that name exists
+        """
+        return cls._agents[name]
+
+    @classmethod
+    def list_agents(cls) -> Dict[str, str]:
+        """List all registered agents.
+
+        Returns:
+            Dictionary mapping agent names to their types
+        """
+        return {name: type(agent).__name__ for name, agent in cls._agents.items()}
+
+
 """Agent factory module for creating different types of agents."""
 
 
 def create_agent_for_mode(
@@ -46,7 +94,7 @@ def create_agent_for_mode(
 
     if mode == "code":
         logger.debug("Creating code agent without basic mode")
-
+        agent = create_coding_agent(
             model_name,
             vision_model_name,
             basic=False,
@@ -54,8 +102,9 @@ def create_agent_for_mode(
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     if mode == "code-basic":
-
+        agent = create_coding_agent(
             model_name,
             vision_model_name,
             basic=True,
@@ -63,44 +112,50 @@ def create_agent_for_mode(
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     elif mode == "basic":
-
+        agent = create_basic_agent(
             model_name,
             vision_model_name,
             no_stream=no_stream,
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     elif mode == "full":
-
+        agent = create_full_agent(
             model_name,
             vision_model_name,
             no_stream=no_stream,
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     elif mode == "interpreter":
-
+        agent = create_interpreter_agent(
             model_name,
             vision_model_name,
             no_stream=no_stream,
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     elif mode == "search":
-
+        agent = create_search_agent(
             model_name,
             no_stream=no_stream,
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     if mode == "search-full":
-
+        agent = create_search_agent(
             model_name,
             mode_full=True,
             no_stream=no_stream,
             compact_every_n_iteration=compact_every_n_iteration,
             max_tokens_working_memory=max_tokens_working_memory
         )
+        return agent
     else:
         raise ValueError(f"Unknown agent mode: {mode}")
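`AgentRegistry` is a process-wide singleton keyed by name, and `create_agent_for_mode` now returns the constructed agent explicitly, which pairs naturally with registration. A hedged sketch; the parameter order of `create_agent_for_mode` is assumed from the call sites above, as its full signature is outside this hunk:

    from quantalogic.agent_factory import AgentRegistry, create_agent_for_mode

    agent = create_agent_for_mode("basic", "deepseek/deepseek-chat", vision_model_name=None)
    AgentRegistry.register_agent("default", agent)   # raises ValueError on duplicate names
    print(AgentRegistry.list_agents())               # e.g. {"default": "Agent"}
    same = AgentRegistry.get_agent("default")        # KeyError if never registered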
quantalogic/config.py
ADDED
@@ -0,0 +1,15 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class QLConfig:
+    """Central configuration for QuantaLogic agent parameters."""
+    model_name: str
+    verbose: bool
+    mode: str
+    log: str
+    vision_model_name: str | None
+    max_iterations: int
+    compact_every_n_iteration: int | None
+    max_tokens_working_memory: int | None
+    no_stream: bool
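Since `QLConfig` is a plain dataclass with no field defaults, every field must be supplied at construction. A representative instantiation; the values are illustrative:

    from quantalogic.config import QLConfig

    config = QLConfig(
        model_name="deepseek/deepseek-chat",
        verbose=False,
        mode="basic",
        log="info",
        vision_model_name=None,
        max_iterations=30,
        compact_every_n_iteration=None,
        max_tokens_working_memory=None,
        no_stream=False,
    )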
quantalogic/generative_model.py
CHANGED
@@ -1,24 +1,22 @@
 """Generative model module for AI-powered text generation."""
 
-import functools
 from datetime import datetime
 from typing import Any, Dict, List
 
 import litellm
 import openai
-from litellm import
+from litellm import exceptions
 from loguru import logger
 from pydantic import BaseModel, Field, field_validator
 
 from quantalogic.event_emitter import EventEmitter  # Importing the EventEmitter class
-from quantalogic.get_model_info import get_max_input_tokens, get_max_output_tokens,
+from quantalogic.get_model_info import get_max_input_tokens, get_max_output_tokens, get_max_tokens
+from quantalogic.llm import count_tokens, generate_completion, generate_image
 
 MIN_RETRIES = 1
 
 
-
-litellm.suppress_debug_info = True  # Very important to suppress prints don't remove
-
+litellm.suppress_debug_info = True  # Very important to suppress prints don't remove
 
 
 # Define the Message class for conversation handling
@@ -90,17 +88,16 @@
 
         Args:
             model: Model identifier. Defaults to "ollama/qwen2.5-coder:14b".
-            temperature: Temperature parameter for controlling randomness in generation.
-                Higher values (e.g. 0.8) make output more random, lower values (e.g. 0.2)
+            temperature: Temperature parameter for controlling randomness in generation.
+                Higher values (e.g. 0.8) make output more random, lower values (e.g. 0.2)
                 make it more deterministic. Defaults to 0.7.
-            event_emitter: Optional event emitter instance for handling asynchronous events
+            event_emitter: Optional event emitter instance for handling asynchronous events
                 and callbacks during text generation. Defaults to None.
         """
         logger.debug(f"Initializing GenerativeModel with model={model}, temperature={temperature}")
         self.model = model
         self.temperature = temperature
         self.event_emitter = event_emitter or EventEmitter()  # Initialize event emitter
-        self._get_model_info_cached = functools.lru_cache(maxsize=32)(self._get_model_info_impl)
 
         # Define retriable exceptions based on LiteLLM's exception mapping
         RETRIABLE_EXCEPTIONS = (
@@ -161,7 +158,7 @@
         try:
             logger.debug(f"Generating response for prompt: {prompt}")
 
-            response =
+            response = generate_completion(
                 temperature=self.temperature,
                 model=self.model,
                 messages=messages,
@@ -187,7 +184,7 @@
     def _stream_response(self, messages):
         """Private method to handle streaming responses."""
         try:
-            for chunk in
+            for chunk in generate_completion(
                 temperature=self.temperature,
                 model=self.model,
                 messages=messages,
@@ -253,96 +250,21 @@
         """Count the number of tokens in a list of messages."""
         logger.debug(f"Counting tokens for {len(messages)} messages using model {self.model}")
         litellm_messages = [{"role": msg.role, "content": str(msg.content)} for msg in messages]
-
-        logger.debug(f"Token count: {token_count}")
-        return token_count
+        return count_tokens(model=self.model, messages=litellm_messages)
 
     def token_counter_with_history(self, messages_history: list[Message], prompt: str) -> int:
         """Count the number of tokens in a list of messages and a prompt."""
         litellm_messages = [{"role": msg.role, "content": str(msg.content)} for msg in messages_history]
         litellm_messages.append({"role": "user", "content": str(prompt)})
-        return
-
-    def _get_model_info_impl(self, model_name: str) -> dict:
-        """Get information about the model with prefix fallback logic."""
-        original_model = model_name
-        tried_models = [model_name]
-
-        while True:
-            try:
-                logger.debug(f"Attempting to retrieve model info for: {model_name}")
-                # Try direct lookup from model_info dictionary first
-                if model_name in model_info:
-                    logger.debug(f"Found model info for {model_name} in model_info")
-                    return model_info[model_name]
-
-                # Try get_model_info as fallback
-                info = get_model_info(model_name)
-                if info:
-                    logger.debug(f"Found model info for {model_name} via get_model_info")
-                    return info
-            except Exception as e:
-                logger.debug(f"Failed to get model info for {model_name}: {str(e)}")
-                pass
-
-            # Try removing one prefix level
-            parts = model_name.split("/")
-            if len(parts) <= 1:
-                break
-            model_name = "/".join(parts[1:])
-            tried_models.append(model_name)
-
-        error_msg = f"Could not find model info for {original_model} after trying: {' → '.join(tried_models)}"
-        logger.error(error_msg)
-        raise ValueError(error_msg)
-
-    def get_model_info(self, model_name: str = None) -> dict:
-        """Get cached information about the model."""
-        if model_name is None:
-            model_name = self.model
-        return self._get_model_info_cached(model_name)
+        return count_tokens(model=self.model, messages=litellm_messages)
 
     def get_model_max_input_tokens(self) -> int | None:
         """Get the maximum number of input tokens for the model."""
-
-            # First try direct lookup
-            max_tokens = get_max_input_tokens(self.model)
-            if max_tokens is not None:
-                return max_tokens
-
-            # If not found, try getting from model info
-            model_info = self.get_model_info()
-            if model_info:
-                return model_info.get("max_input_tokens")
-
-            # If still not found, log warning and return default
-            logger.warning(f"No max input tokens found for {self.model}. Using default.")
-            return 8192  # A reasonable default for many models
-
-        except Exception as e:
-            logger.error(f"Error getting max input tokens for {self.model}: {e}")
-            return None
+        return get_max_input_tokens(self.model)
 
     def get_model_max_output_tokens(self) -> int | None:
         """Get the maximum number of output tokens for the model."""
-
-            # First try direct lookup
-            max_tokens = get_max_output_tokens(self.model)
-            if max_tokens is not None:
-                return max_tokens
-
-            # If not found, try getting from model info
-            model_info = self.get_model_info()
-            if model_info:
-                return model_info.get("max_output_tokens")
-
-            # If still not found, log warning and return default
-            logger.warning(f"No max output tokens found for {self.model}. Using default.")
-            return 4096  # A reasonable default for many models
-
-        except Exception as e:
-            logger.error(f"Error getting max output tokens for {self.model}: {e}")
-            return None
+        return get_max_output_tokens(self.model)
 
     def generate_image(self, prompt: str, params: Dict[str, Any]) -> ResponseStats:
         """Generate an image using the specified model and parameters.
@@ -366,16 +288,13 @@
         """
         try:
             logger.debug(f"Generating image with params: {params}")
-
+
             # Ensure prompt is in params
             generation_params = {**params}
             generation_params["prompt"] = prompt
-
+
             # Call litellm's image generation function
-            response =
-                model=generation_params.pop("model"),
-                **generation_params
-            )
+            response = generate_image(model=generation_params.pop("model"), **generation_params)
 
             # Convert response data to list of dictionaries with string values
             if hasattr(response, "data"):
@@ -407,7 +326,7 @@
                 usage=TokenUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
                 model=str(params["model"]),
                 data=data,
-                created=created
+                created=created,
             )
 
         except Exception as e:
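The token counters and max-token lookups now delegate to the shared helpers in quantalogic.llm and quantalogic.get_model_info instead of carrying their own caching and fallback logic. A sketch of the slimmed-down surface; the `token_counter` method name and `Message` fields are assumptions inferred from context, since their definitions fall outside these hunks:

    from quantalogic.generative_model import GenerativeModel, Message

    model = GenerativeModel(model="deepseek/deepseek-chat")
    history = [Message(role="user", content="Hello")]
    print(model.token_counter(history))                       # delegates to count_tokens(...)
    print(model.token_counter_with_history(history, "More"))  # appends the prompt, then counts
    print(model.get_model_max_input_tokens())                 # delegates to get_max_input_tokens(...)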