quantalogic 0.30.8__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quantalogic/__init__.py CHANGED
@@ -1,4 +1,5 @@
- # QuantaLogic package initialization
+ """QuantaLogic package initialization."""
+

  import warnings

  # Suppress specific warnings related to Pydantic's V2 configuration changes
@@ -9,13 +10,22 @@ warnings.filterwarnings(
  message=".*config keys have changed in V2:.*|.*'fields' config key is removed in V2.*",
  )

-
+ # Import public API
+ from .llm import generate_completion, generate_image, count_tokens # noqa: E402
  from .agent import Agent # noqa: E402
- from .console_print_events import console_print_events # noqa: E402
- from .console_print_token import console_print_token # noqa: E402
  from .event_emitter import EventEmitter # noqa: E402
  from .memory import AgentMemory, VariableMemory # noqa: E402
+ from .console_print_events import console_print_events # noqa: E402
+ from .console_print_token import console_print_token # noqa: E402

- """QuantaLogic package for AI-powered generative models."""
-
- __all__ = ["Agent", "EventEmitter", "AgentMemory", "VariableMemory", "console_print_events","console_print_token"]
+ __all__ = [
+ "Agent",
+ "EventEmitter",
+ "AgentMemory",
+ "VariableMemory",
+ "console_print_events",
+ "console_print_token",
+ "generate_completion",
+ "generate_image",
+ "count_tokens"
+ ]
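
The 0.31.0 __init__ re-exports the LLM helpers at package level. A minimal usage sketch, assuming generate_completion and count_tokens accept litellm-style keyword arguments as their call sites in generative_model.py suggest; model name and messages are placeholders:

from quantalogic import count_tokens, generate_completion

messages = [{"role": "user", "content": "Say hello in one word."}]

# Count tokens before sending the prompt (model name is illustrative).
n_tokens = count_tokens(model="deepseek/deepseek-chat", messages=messages)
print(f"Prompt size: {n_tokens} tokens")

# Generate a completion through the package-level helper; the response object
# is assumed to follow litellm's completion response shape.
response = generate_completion(
    model="deepseek/deepseek-chat",
    messages=messages,
    temperature=0.7,
)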
quantalogic/agent.py CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
  from typing import Any

  from loguru import logger
- from pydantic import BaseModel, ConfigDict
+ from pydantic import BaseModel, ConfigDict, PrivateAttr

  from quantalogic.event_emitter import EventEmitter
  from quantalogic.generative_model import GenerativeModel, ResponseStats, TokenUsage
@@ -52,12 +52,16 @@ class ObserveResponseResult(BaseModel):
  class Agent(BaseModel):
  """Enhanced QuantaLogic agent implementing ReAct framework."""

- model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True, extra="forbid")
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ validate_assignment=True,
+ extra="forbid"
+ )

  specific_expertise: str
  model: GenerativeModel
- memory: AgentMemory = AgentMemory() # A list User / Assistant Messages
- variable_store: VariableMemory = VariableMemory() # A dictionary of variables (var1: value1, var2: value2)
+ memory: AgentMemory = AgentMemory()  # A list User / Assistant Messages
+ variable_store: VariableMemory = VariableMemory()  # A dictionary of variables
  tools: ToolManager = ToolManager()
  event_emitter: EventEmitter = EventEmitter()
  config: AgentConfig
@@ -71,8 +75,9 @@ class Agent(BaseModel):
  max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS
  max_iterations: int = 30
  system_prompt: str = ""
- compact_every_n_iterations: int | None = None # Add this to the class attributes
- max_tokens_working_memory: int | None = None # Add max_tokens_working_memory attribute
+ compact_every_n_iterations: int | None = None
+ max_tokens_working_memory: int | None = None
+ _model_name: str = PrivateAttr(default="")

  def __init__(
  self,
@@ -84,17 +89,18 @@ class Agent(BaseModel):
  task_to_solve: str = "",
  specific_expertise: str = "General AI assistant with coding and problem-solving capabilities",
  get_environment: Callable[[], str] = get_environment,
- compact_every_n_iterations: int | None = None, # if set the memory will be compacted every n iterations
- max_tokens_working_memory: int | None = None, # if set the memory will be compacted each time the max_tokens_working_memory is reached
+ compact_every_n_iterations: int | None = None,
+ max_tokens_working_memory: int | None = None,
  ):
  """Initialize the agent with model, memory, tools, and configurations."""
  try:
  logger.debug("Initializing agent...")
- # Create event emitter first
+
+ # Create event emitter
  event_emitter = EventEmitter()

  # Add TaskCompleteTool to the tools list if not already present
- if TaskCompleteTool() not in tools:
+ if not any(isinstance(t, TaskCompleteTool) for t in tools):
  tools.append(TaskCompleteTool())

  tool_manager = ToolManager(tools={tool.name: tool for tool in tools})
@@ -114,32 +120,50 @@ class Agent(BaseModel):
  system_prompt=system_prompt_text,
  )

- logger.debug("Base class init started ...")
+ # Initialize using Pydantic's model_validate
  super().__init__(
+ specific_expertise=specific_expertise,
  model=GenerativeModel(model=model_name, event_emitter=event_emitter),
  memory=memory,
  variable_store=variable_store,
  tools=tool_manager,
+ event_emitter=event_emitter,
  config=config,
- ask_for_user_validation=ask_for_user_validation,
  task_to_solve=task_to_solve,
- specific_expertise=specific_expertise,
- event_emitter=event_emitter,
+ task_to_solve_summary="",
+ ask_for_user_validation=ask_for_user_validation,
+ last_tool_call={},
+ total_tokens=0,
+ current_iteration=0,
+ max_input_tokens=DEFAULT_MAX_INPUT_TOKENS,
+ max_output_tokens=DEFAULT_MAX_OUTPUT_TOKENS,
+ max_iterations=30,
+ system_prompt="",
+ compact_every_n_iterations=compact_every_n_iterations or 30,
+ max_tokens_working_memory=max_tokens_working_memory,
  )

- # Set the new compact_every_n_iterations parameter
- self.compact_every_n_iterations = compact_every_n_iterations or self.max_iterations
- logger.debug(f"Memory will be compacted every {self.compact_every_n_iterations} iterations")
+ self._model_name = model_name

- # Set the max_tokens_working_memory parameter
- self.max_tokens_working_memory = max_tokens_working_memory
+ logger.debug(f"Memory will be compacted every {self.compact_every_n_iterations} iterations")
  logger.debug(f"Max tokens for working memory set to: {self.max_tokens_working_memory}")
-
  logger.debug("Agent initialized successfully.")
  except Exception as e:
  logger.error(f"Failed to initialize agent: {str(e)}")
  raise

+ @property
+ def model_name(self) -> str:
+ """Get the current model name."""
+ return self._model_name
+
+ @model_name.setter
+ def model_name(self, value: str) -> None:
+ """Set the model name."""
+ self._model_name = value
+ # Update the model instance with the new name
+ self.model = GenerativeModel(model=value, event_emitter=self.event_emitter)
+
  def clear_memory(self):
  """Clear the memory and reset the session."""
  self._reset_session(clear_memory=True)
@@ -533,7 +557,10 @@ class Agent(BaseModel):
  question_validation: str = (
  "Do you permit the execution of this tool?\n"
  f"Tool: {tool_name}\n"
- f"Arguments: {arguments_with_values}\n"
+ "Arguments:\n"
+ "<arguments>\n"
+ + "\n".join([f" <{key}>{value}</{key}>" for key, value in arguments_with_values.items()])
+ + "\n</arguments>\n"
  "Yes or No"
  )
  permission_granted = self.ask_for_user_validation(question_validation)
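
For reference, an illustrative rendering of the reworked validation prompt for a hypothetical tool call (the tool name and arguments are made up; the construction mirrors the lines above):

tool_name = "write_file"
arguments_with_values = {"path": "notes.txt", "content": "hello"}

question_validation = (
    "Do you permit the execution of this tool?\n"
    f"Tool: {tool_name}\n"
    "Arguments:\n"
    "<arguments>\n"
    + "\n".join([f" <{key}>{value}</{key}>" for key, value in arguments_with_values.items()])
    + "\n</arguments>\n"
    "Yes or No"
)
print(question_validation)
# Do you permit the execution of this tool?
# Tool: write_file
# Arguments:
# <arguments>
#  <path>notes.txt</path>
#  <content>hello</content>
# </arguments>
# Yes or No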
@@ -603,10 +630,14 @@ class Agent(BaseModel):
  return executed_tool, response

  def _interpolate_variables(self, text: str) -> str:
- """Interpolate variables using $var1$ syntax in the given text."""
+ """Interpolate variables using $var$ syntax in the given text."""
  try:
+ import re
  for var in self.variable_store.keys():
- text = text.replace(f"${var}$", self.variable_store[var])
+ # Escape the variable name for regex, but use raw value for replacement
+ pattern = rf'\${re.escape(var)}\$'
+ replacement = self.variable_store[var]
+ text = re.sub(pattern, replacement, text)
  return text
  except Exception as e:
  logger.error(f"Error in _interpolate_variables: {str(e)}")
@@ -645,6 +676,7 @@ class Agent(BaseModel):
  "1. Select ONE tool per message\n"
  "2. You will receive the tool's output in the next user response\n"
  "3. Choose the most appropriate tool for each step\n"
+ "4. Use task_complete tool to confirm task completion\n"
  )
  return prompt_use_tools

@@ -706,23 +738,32 @@ class Agent(BaseModel):
  return summary.response

  def _generate_task_summary(self, content: str) -> str:
- """Generate a concise summary of the given content using the generative model.
+ """Generate a concise task-focused summary using the generative model.

  Args:
  content (str): The content to summarize

  Returns:
- str: Generated summary
+ str: Generated task summary
  """
  try:
  prompt = (
- "Rewrite this task in a precise, dense, and concise manner:\n"
- f"{content}\n"
- "Summary should be 2-3 sentences maximum. No extra comments should be added.\n"
+ "Create an ultra-concise task summary that captures ONLY: \n"
+ "1. Primary objective/purpose\n"
+ "2. Core actions/requirements\n"
+ "3. Desired end-state/outcome\n\n"
+ "Guidelines:\n"
+ "- Use imperative voice\n"
+ "- Exclude background, explanations, and examples\n"
+ "- Compress information using semantic density\n"
+ "- Strict 2-3 sentence maximum (under 50 words)\n"
+ "- Format: 'Concise Task Summary: [Your summary]'\n\n"
+ f"Input Task Description:\n{content}\n\n"
+ "Concise Task Summary:"
  )
  result = self.model.generate(prompt=prompt)
  logger.debug(f"Generated summary: {result.response}")
- return result.response
+ return result.response.strip() + "\n🚨 The FULL task is in <task> tag in the previous messages.\n"
  except Exception as e:
  logger.error(f"Error generating summary: {str(e)}")
  return f"Summary generation failed: {str(e)}"
@@ -747,3 +788,8 @@ class Agent(BaseModel):
  "session_add_message",
  {"role": "assistant", "content": assistant_content},
  )
+
+ def update_model(self, new_model_name: str) -> None:
+ """Update the model name and recreate the model instance."""
+ self.model_name = new_model_name
+ self.model = GenerativeModel(model=new_model_name, event_emitter=self.event_emitter)
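
Together, the private _model_name attribute, the model_name property, and update_model let callers swap the backing model at runtime. A hedged sketch (model names are placeholders; Agent construction arguments beyond model_name are assumed to keep their defaults):

from quantalogic import Agent

agent = Agent(model_name="deepseek/deepseek-chat")
print(agent.model_name)  # -> "deepseek/deepseek-chat"

# Rebuilds the underlying GenerativeModel against a different backend,
# reusing the agent's existing event emitter.
agent.update_model("openrouter/deepseek/deepseek-r1")
print(agent.model_name)  # -> "openrouter/deepseek/deepseek-r1"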
@@ -37,6 +37,14 @@ load_dotenv()
  MODEL_NAME = "deepseek/deepseek-chat"


+ _current_model_name: str = ""
+
+ def get_current_model() -> str:
+ """Retrieve the currently active model name."""
+ if not _current_model_name:
+ raise ValueError("No model initialized")
+ return _current_model_name
+
  def create_agent(
  model_name: str,
  vision_model_name: str | None,
@@ -44,6 +52,8 @@ def create_agent(
  compact_every_n_iteration: int | None = None,
  max_tokens_working_memory: int | None = None
  ) -> Agent:
+ global _current_model_name
+ _current_model_name = model_name
  """Create an agent with the specified model and tools.

  Args:
@@ -1,6 +1,4 @@
- """Agent factory module for creating different types of agents."""
-
- from typing import Optional
+ from typing import Dict, Optional

  from loguru import logger

@@ -11,7 +9,57 @@ from quantalogic.agent_config import (
  create_interpreter_agent,
  )
  from quantalogic.coding_agent import create_coding_agent
- from quantalogic.search_agent import create_search_agent
+ from quantalogic.search_agent import create_search_agent # noqa: E402
+
+
+ class AgentRegistry:
+ """Registry for managing agent instances by name."""
+
+ _instance = None
+ _agents: Dict[str, Agent] = {}
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ @classmethod
+ def register_agent(cls, name: str, agent: Agent) -> None:
+ """Register an agent instance with a name.
+
+ Args:
+ name: Unique name for the agent
+ agent: Agent instance to register
+ """
+ if name in cls._agents:
+ raise ValueError(f"Agent with name {name} already exists")
+ cls._agents[name] = agent
+
+ @classmethod
+ def get_agent(cls, name: str) -> Agent:
+ """Retrieve a registered agent by name.
+
+ Args:
+ name: Name of the agent to retrieve
+
+ Returns:
+ Registered Agent instance
+
+ Raises:
+ KeyError: If no agent with that name exists
+ """
+ return cls._agents[name]
+
+ @classmethod
+ def list_agents(cls) -> Dict[str, str]:
+ """List all registered agents.
+
+ Returns:
+ Dictionary mapping agent names to their types
+ """
+ return {name: type(agent).__name__ for name, agent in cls._agents.items()}
+
+ """Agent factory module for creating different types of agents."""


  def create_agent_for_mode(
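
A minimal sketch of the new AgentRegistry (the import path is assumed from this diff's factory module; the agent name and construction arguments are placeholders):

from quantalogic import Agent
from quantalogic.agent_factory import AgentRegistry  # module path assumed

agent = Agent(model_name="deepseek/deepseek-chat")
AgentRegistry.register_agent("reviewer", agent)   # raises ValueError on duplicate names

same_agent = AgentRegistry.get_agent("reviewer")  # raises KeyError if unknown
print(AgentRegistry.list_agents())                # {'reviewer': 'Agent'}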
@@ -46,7 +94,7 @@ def create_agent_for_mode(

  if mode == "code":
  logger.debug("Creating code agent without basic mode")
- return create_coding_agent(
+ agent = create_coding_agent(
  model_name,
  vision_model_name,
  basic=False,
@@ -54,8 +102,9 @@ def create_agent_for_mode(
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  if mode == "code-basic":
- return create_coding_agent(
+ agent = create_coding_agent(
  model_name,
  vision_model_name,
  basic=True,
@@ -63,44 +112,50 @@ def create_agent_for_mode(
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  elif mode == "basic":
- return create_basic_agent(
+ agent = create_basic_agent(
  model_name,
  vision_model_name,
  no_stream=no_stream,
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  elif mode == "full":
- return create_full_agent(
+ agent = create_full_agent(
  model_name,
  vision_model_name,
  no_stream=no_stream,
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  elif mode == "interpreter":
- return create_interpreter_agent(
+ agent = create_interpreter_agent(
  model_name,
  vision_model_name,
  no_stream=no_stream,
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  elif mode == "search":
- return create_search_agent(
+ agent = create_search_agent(
  model_name,
  no_stream=no_stream,
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  if mode == "search-full":
- return create_search_agent(
+ agent = create_search_agent(
  model_name,
  mode_full=True,
  no_stream=no_stream,
  compact_every_n_iteration=compact_every_n_iteration,
  max_tokens_working_memory=max_tokens_working_memory
  )
+ return agent
  else:
  raise ValueError(f"Unknown agent mode: {mode}")
quantalogic/config.py ADDED
@@ -0,0 +1,15 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class QLConfig:
+ """Central configuration for QuantaLogic agent parameters."""
+ model_name: str
+ verbose: bool
+ mode: str
+ log: str
+ vision_model_name: str | None
+ max_iterations: int
+ compact_every_n_iteration: int | None
+ max_tokens_working_memory: int | None
+ no_stream: bool
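
QLConfig declares no defaults, so every field must be supplied explicitly. A construction sketch with placeholder values:

from quantalogic.config import QLConfig

config = QLConfig(
    model_name="deepseek/deepseek-chat",
    verbose=False,
    mode="code",
    log="info",
    vision_model_name=None,
    max_iterations=30,
    compact_every_n_iteration=None,
    max_tokens_working_memory=None,
    no_stream=False,
)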
@@ -1,24 +1,22 @@
  """Generative model module for AI-powered text generation."""

- import functools
  from datetime import datetime
  from typing import Any, Dict, List

  import litellm
  import openai
- from litellm import completion, exceptions, get_max_tokens, get_model_info, image_generation, token_counter
+ from litellm import exceptions
  from loguru import logger
  from pydantic import BaseModel, Field, field_validator

  from quantalogic.event_emitter import EventEmitter # Importing the EventEmitter class
- from quantalogic.get_model_info import get_max_input_tokens, get_max_output_tokens, model_info
+ from quantalogic.get_model_info import get_max_input_tokens, get_max_output_tokens, get_max_tokens
+ from quantalogic.llm import count_tokens, generate_completion, generate_image

  MIN_RETRIES = 1


-
- litellm.suppress_debug_info = True # Very important to suppress prints don't remove
-
+ litellm.suppress_debug_info = True  # Very important to suppress prints don't remove


  # Define the Message class for conversation handling
@@ -90,17 +88,16 @@ class GenerativeModel:

  Args:
  model: Model identifier. Defaults to "ollama/qwen2.5-coder:14b".
- temperature: Temperature parameter for controlling randomness in generation.
- Higher values (e.g. 0.8) make output more random, lower values (e.g. 0.2)
+ temperature: Temperature parameter for controlling randomness in generation.
+ Higher values (e.g. 0.8) make output more random, lower values (e.g. 0.2)
  make it more deterministic. Defaults to 0.7.
- event_emitter: Optional event emitter instance for handling asynchronous events
+ event_emitter: Optional event emitter instance for handling asynchronous events
  and callbacks during text generation. Defaults to None.
  """
  logger.debug(f"Initializing GenerativeModel with model={model}, temperature={temperature}")
  self.model = model
  self.temperature = temperature
  self.event_emitter = event_emitter or EventEmitter() # Initialize event emitter
- self._get_model_info_cached = functools.lru_cache(maxsize=32)(self._get_model_info_impl)

  # Define retriable exceptions based on LiteLLM's exception mapping
  RETRIABLE_EXCEPTIONS = (
@@ -161,7 +158,7 @@ class GenerativeModel:
  try:
  logger.debug(f"Generating response for prompt: {prompt}")

- response = completion(
+ response = generate_completion(
  temperature=self.temperature,
  model=self.model,
  messages=messages,
@@ -187,7 +184,7 @@ class GenerativeModel:
  def _stream_response(self, messages):
  """Private method to handle streaming responses."""
  try:
- for chunk in completion(
+ for chunk in generate_completion(
  temperature=self.temperature,
  model=self.model,
  messages=messages,
@@ -253,96 +250,21 @@ class GenerativeModel:
  """Count the number of tokens in a list of messages."""
  logger.debug(f"Counting tokens for {len(messages)} messages using model {self.model}")
  litellm_messages = [{"role": msg.role, "content": str(msg.content)} for msg in messages]
- token_count = token_counter(model=self.model, messages=litellm_messages)
- logger.debug(f"Token count: {token_count}")
- return token_count
+ return count_tokens(model=self.model, messages=litellm_messages)

  def token_counter_with_history(self, messages_history: list[Message], prompt: str) -> int:
  """Count the number of tokens in a list of messages and a prompt."""
  litellm_messages = [{"role": msg.role, "content": str(msg.content)} for msg in messages_history]
  litellm_messages.append({"role": "user", "content": str(prompt)})
- return token_counter(model=self.model, messages=litellm_messages)
-
- def _get_model_info_impl(self, model_name: str) -> dict:
- """Get information about the model with prefix fallback logic."""
- original_model = model_name
- tried_models = [model_name]
-
- while True:
- try:
- logger.debug(f"Attempting to retrieve model info for: {model_name}")
- # Try direct lookup from model_info dictionary first
- if model_name in model_info:
- logger.debug(f"Found model info for {model_name} in model_info")
- return model_info[model_name]
-
- # Try get_model_info as fallback
- info = get_model_info(model_name)
- if info:
- logger.debug(f"Found model info for {model_name} via get_model_info")
- return info
- except Exception as e:
- logger.debug(f"Failed to get model info for {model_name}: {str(e)}")
- pass
-
- # Try removing one prefix level
- parts = model_name.split("/")
- if len(parts) <= 1:
- break
- model_name = "/".join(parts[1:])
- tried_models.append(model_name)
-
- error_msg = f"Could not find model info for {original_model} after trying: {' → '.join(tried_models)}"
- logger.error(error_msg)
- raise ValueError(error_msg)
-
- def get_model_info(self, model_name: str = None) -> dict:
- """Get cached information about the model."""
- if model_name is None:
- model_name = self.model
- return self._get_model_info_cached(model_name)
+ return count_tokens(model=self.model, messages=litellm_messages)

  def get_model_max_input_tokens(self) -> int | None:
  """Get the maximum number of input tokens for the model."""
- try:
- # First try direct lookup
- max_tokens = get_max_input_tokens(self.model)
- if max_tokens is not None:
- return max_tokens
-
- # If not found, try getting from model info
- model_info = self.get_model_info()
- if model_info:
- return model_info.get("max_input_tokens")
-
- # If still not found, log warning and return default
- logger.warning(f"No max input tokens found for {self.model}. Using default.")
- return 8192 # A reasonable default for many models
-
- except Exception as e:
- logger.error(f"Error getting max input tokens for {self.model}: {e}")
- return None
+ return get_max_input_tokens(self.model)

  def get_model_max_output_tokens(self) -> int | None:
  """Get the maximum number of output tokens for the model."""
- try:
- # First try direct lookup
- max_tokens = get_max_output_tokens(self.model)
- if max_tokens is not None:
- return max_tokens
-
- # If not found, try getting from model info
- model_info = self.get_model_info()
- if model_info:
- return model_info.get("max_output_tokens")
-
- # If still not found, log warning and return default
- logger.warning(f"No max output tokens found for {self.model}. Using default.")
- return 4096 # A reasonable default for many models
-
- except Exception as e:
- logger.error(f"Error getting max output tokens for {self.model}: {e}")
- return None
+ return get_max_output_tokens(self.model)

  def generate_image(self, prompt: str, params: Dict[str, Any]) -> ResponseStats:
  """Generate an image using the specified model and parameters.
@@ -366,16 +288,13 @@ class GenerativeModel:
  """
  try:
  logger.debug(f"Generating image with params: {params}")
-
+
  # Ensure prompt is in params
  generation_params = {**params}
  generation_params["prompt"] = prompt
-
+
  # Call litellm's image generation function
- response = image_generation(
- model=generation_params.pop("model"),
- **generation_params
- )
+ response = generate_image(model=generation_params.pop("model"), **generation_params)

  # Convert response data to list of dictionaries with string values
  if hasattr(response, "data"):
@@ -407,7 +326,7 @@ class GenerativeModel:
  usage=TokenUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
  model=str(params["model"]),
  data=data,
- created=created
+ created=created,
  )

  except Exception as e:
@@ -1,4 +1,7 @@
  model_info = {
+ "dashscope/qwen-max": {"max_output_tokens": 8 * 1024, "max_input_tokens": 32 * 1024},
+ "dashscope/qwen-plus": {"max_output_tokens": 8 * 1024, "max_input_tokens": 131072},
+ "dashscope/qwen-turbo": {"max_output_tokens": 8 * 1024, "max_input_tokens": 1000000},
  "deepseek-reasoner": {"max_output_tokens": 8 * 1024, "max_input_tokens": 1024 * 128},
  "openrouter/deepseek/deepseek-r1": {"max_output_tokens": 8 * 1024, "max_input_tokens": 1024 * 128},
  "openrouter/mistralai/mistral-large-2411": {"max_output_tokens": 128 * 1024, "max_input_tokens": 1024 * 128},
@@ -6,6 +9,17 @@ model_info = {
  }


+ def print_model_info():
+ for model, info in model_info.items():
+ print(f"\n{model}:")
+ print(f" Max Input Tokens: {info['max_input_tokens']:,}")
+ print(f" Max Output Tokens: {info['max_output_tokens']:,}")
+
+
+ if __name__ == "__main__":
+ print_model_info()
+
+
  def get_max_output_tokens(model_name: str) -> int | None:
  """Get the maximum output tokens for a given model name."""
  return model_info.get(model_name, {}).get("max_output_tokens", None)
@@ -14,3 +28,15 @@ def get_max_output_tokens(model_name: str) -> int | None:
  def get_max_input_tokens(model_name: str) -> int | None:
  """Get the maximum input tokens for a given model name."""
  return model_info.get(model_name, {}).get("max_input_tokens", None)
+
+
+ def get_max_tokens(model_name: str) -> int | None:
+ """Get the maximum total tokens (input + output) for a given model name."""
+ model_data = model_info.get(model_name, {})
+ max_input = model_data.get("max_input_tokens")
+ max_output = model_data.get("max_output_tokens")
+
+ if max_input is None or max_output is None:
+ return None
+
+ return max_input + max_output
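
A short sketch of the new helper next to the existing lookups; the expected values follow directly from the model_info table above:

from quantalogic.get_model_info import get_max_input_tokens, get_max_output_tokens, get_max_tokens

print(get_max_input_tokens("dashscope/qwen-plus"))   # 131072
print(get_max_output_tokens("dashscope/qwen-plus"))  # 8192
print(get_max_tokens("dashscope/qwen-plus"))         # 139264 (input + output)
print(get_max_tokens("unknown/model"))               # None (model not in the table)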