semantio 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

semantio/agent.py CHANGED
@@ -22,19 +22,27 @@ from .memory import Memory
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 class Agent(BaseModel):
-    # -*- Agent settings
+    """
+    An intelligent agent that combines LLM capabilities with dynamic knowledge base integration,
+    tool usage, and conversation memory. The agent can ingest external domain-specific content (via a dynamic document loader)
+    so that it answers queries based on that information.
+    """
     name: Optional[str] = Field(None, description="Name of the agent.")
     description: Optional[str] = Field(None, description="Description of the agent's role.")
     instructions: Optional[List[str]] = Field(None, description="List of instructions for the agent.")
-    model: Optional[str] = Field(None, description="This one is not in the use.")
+    model: Optional[str] = Field(None, description="This one is not in use.")
     show_tool_calls: bool = Field(False, description="Whether to show tool calls in the response.")
     markdown: bool = Field(False, description="Whether to format the response in markdown.")
     tools: Optional[List[BaseTool]] = Field(None, description="List of tools available to the agent.")
     user_name: Optional[str] = Field("User", description="Name of the user interacting with the agent.")
     emoji: Optional[str] = Field(":robot:", description="Emoji to represent the agent in the CLI.")
     rag: Optional[RAG] = Field(None, description="RAG instance for context retrieval.")
-    knowledge_base: Optional[Any] = Field(None, description="Knowledge base for domain-specific information.")
+    knowledge_base: Optional[Any] = Field(
+        None,
+        description="Domain-specific knowledge base content (e.g., loaded via a dynamic document loader)."
+    )
     llm: Optional[str] = Field(None, description="The LLM provider to use (e.g., 'groq', 'openai', 'anthropic').")
     llm_model: Optional[str] = Field(None, description="The specific model to use for the LLM provider.")
     llm_instance: Optional[BaseLLM] = Field(None, description="The LLM instance to use.")
@@ -57,120 +65,46 @@ class Agent(BaseModel):
         }
     )
 
-    # Allow arbitrary types
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        # Initialize the model and tools here if needed
+        # Initialize the LLM model and tools if needed.
         self._initialize_model()
-        # Initialize memory with config
+        # Initialize conversation memory with configuration.
         self.memory = Memory(
             max_context_length=self.memory_config.get("max_context_length", 4000),
             summarization_threshold=self.memory_config.get("summarization_threshold", 3000)
         )
-        # Initialize tools as an empty list if not provided
+        # Initialize tools as an empty list if not provided.
         if self.tools is None:
             self.tools = []
-        # Automatically discover and register tools if auto tool is enabled
+        # Automatically discover and register tools if auto_tool is enabled.
         if self.auto_tool and not self.tools:
             self.tools = self._discover_tools()
-        # Pass the LLM instance to each tool
+        # Pass the LLM instance to each tool.
         for tool in self.tools:
             tool.llm = self.llm_instance
-        # Initialize the SentenceTransformer model for semantic matching
+        # Initialize the SentenceTransformer model for semantic matching.
         self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-        # Initialize RAG if not provided
+        # Initialize default RAG if not provided.
         if self.rag is None:
             self.rag = self._initialize_default_rag()
-        # Automatically generate API if api=True
+        # Automatically generate API if api=True.
         if self.api:
             self._generate_api()
 
-
-    def _generate_response_from_image(self,message: str, image: Union[str, Image], markdown: bool = False, **kwargs) -> str:
-        """
-        Send the image to the LLM for analysis if the LLM supports vision.
-        Supports both local images (PIL.Image) and image URLs.
-        """
-        try:
-            # Check if the LLM supports vision
-            if not self.llm_instance or not self.llm_instance.supports_vision:
-                raise ValueError("Vision is not supported for the current model.")
-            prompt = self._build_prompt(message, context=None)
-            # Handle image URL
-            if isinstance(image, str) and image.startswith("http"):
-                # Directly pass the URL to the LLM
-                return self.llm_instance.generate_from_image_url(prompt,image, **kwargs)
-
-            # Handle local image (PIL.Image)
-            elif isinstance(image, Image):
-                # Convert the image to bytes
-                if image.mode == "RGBA":
-                    image = image.convert("RGB")  # Convert RGBA to RGB
-                image_bytes = io.BytesIO()
-                image.save(image_bytes, format="JPEG")  # Save as PNG (or any supported format)
-                image_bytes = image_bytes.getvalue()
-
-                # Generate response using base64-encoded image bytes
-                return self.llm_instance.generate_from_image(prompt,image_bytes, **kwargs)
-
-            else:
-                raise ValueError("Unsupported image type. Provide either a URL or a PIL.Image.")
-
-        except Exception as e:
-            logger.error(f"Failed to generate response from image: {e}")
-            return f"An error occurred while processing the image: {e}"
-
-    def _discover_tools(self) -> List[BaseTool]:
-        """
-        Automatically discover and register tools from the 'tools' directory.
-        """
-        tools = []
-        tools_dir = Path(__file__).parent / "tools"
-
-        if not tools_dir.exists():
-            logger.warning(f"Tools directory not found: {tools_dir}")
-            return tools
-
-        # Iterate over all Python files in the 'tools' directory
-        for file in tools_dir.glob("*.py"):
-            if file.name == "base_tool.py":
-                continue  # Skip the base tool file
-
-            try:
-                # Import the module
-                module_name = file.stem
-                module = importlib.import_module(f"semantio.tools.{module_name}")
-
-                # Find all classes that inherit from BaseTool
-                for name, obj in module.__dict__.items():
-                    if isinstance(obj, type) and issubclass(obj, BaseTool) and obj != BaseTool:
-                        # Instantiate the tool and add it to the list
-                        tools.append(obj())
-                        logger.info(f"Registered tool: {obj.__name__}")
-            except Exception as e:
-                logger.error(f"Failed to load tool from {file}: {e}")
-
-        return tools
-
-    def _get_tool_descriptions(self) -> str:
-        """Generate a description of all available tools for the LLM prompt."""
-        return "\n".join(
-            f"{tool.name}: {tool.description}" for tool in self.tools
-        )
-
     def _initialize_model(self):
-        """Initialize the model based on the provided configuration."""
+        """Initialize the LLM model based on the provided configuration."""
         if self.llm_instance is not None:
-            return  # LLM is already initialized, do nothing
+            return  # Already initialized.
         if self.llm is None:
             raise ValueError("llm must be specified.")
 
-        # Get the API key from the environment or the provided configuration
+        # Retrieve API key from configuration or environment variable.
         api_key = getattr(self, 'api_key', None) or os.getenv(f"{self.llm.upper()}_API_KEY")
 
-        # Map LLM providers to their respective classes and default models
+        # Map LLM providers to their respective classes and default models.
         llm_providers = {
             "groq": {
                 "class": "GroqLlm",
@@ -198,28 +132,23 @@ class Agent(BaseModel):
             },
         }
 
-        # Normalize the LLM provider name (case-insensitive)
         llm_provider = self.llm.lower()
-
         if llm_provider not in llm_providers:
-            raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers are: {list(llm_providers.keys())}")
+            raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers: {list(llm_providers.keys())}")
 
-        # Get the LLM class and default model
         llm_config = llm_providers[llm_provider]
         llm_class_name = llm_config["class"]
         default_model = llm_config["default_model"]
-
-        # Use the user-provided model or fallback to the default model
         model_to_use = self.llm_model or default_model
 
-        # Dynamically import and initialize the LLM class
+        # Dynamically import and initialize the LLM class.
         module_name = f"semantio.llm.{llm_provider}"
         llm_module = importlib.import_module(module_name)
         llm_class = getattr(llm_module, llm_class_name)
         self.llm_instance = llm_class(model=model_to_use, api_key=api_key)
 
     def _initialize_default_rag(self) -> RAG:
-        """Initialize a default RAG instance with a dummy vector store."""
+        """Initialize a default RAG instance using a dummy vector store."""
         vector_store = VectorStore()
         retriever = Retriever(vector_store)
         return RAG(retriever)
@@ -227,45 +156,129 @@ class Agent(BaseModel):
     def print_response(
         self,
         message: Optional[Union[str, Image, List, Dict]] = None,
+        image: Optional[Union[str, Image]] = None,
         stream: bool = False,
         markdown: bool = False,
         team: Optional[List['Agent']] = None,
         **kwargs,
     ) -> Union[str, Dict]:
-        """Print the agent's response to the console and return it."""
-
-        # Store user message if provided
+        """
+        Generate and print the agent's response while storing conversation history.
+        If an image is provided (either via the 'image' parameter or if 'message' is a PIL.Image),
+        the agent processes it accordingly.
+        If a team is provided (or if self.team is set), only the aggregated final response is returned.
+        """
+        # Handle image input first.
+        if image is not None:
+            response = self._generate_response_from_image(message or "", image, markdown=markdown, **kwargs)
+            print(response)
+            if response:
+                self.memory.add_message(role="agent", content=response)
+            return response
+
+        if isinstance(message, Image):
+            response = self._generate_response_from_image("", message, markdown=markdown, **kwargs)
+            print(response)
+            if response:
+                self.memory.add_message(role="agent", content=response)
+            return response
+
+        # For text input, add the user message to memory.
         if message and isinstance(message, str):
             self.memory.add_message(role="user", content=message)
 
+        # If a team is provided (or if self.team exists), generate an aggregated final response.
+        if team is None and self.team is not None:
+            team = self.team
+
+        if team is not None:
+            # Instead of printing individual team outputs, call each agent's _generate_response
+            # to capture their outputs silently.
+            aggregated_responses = []
+            for agent in team:
+                resp = agent._generate_response(message, markdown=markdown, **kwargs)
+                aggregated_responses.append(f"**{agent.name}:**\n\n{resp}")
+            final_response = "\n\n".join(aggregated_responses)
+            print(final_response)
+            self.memory.add_message(role="agent", content=final_response)
+            return final_response
+
+        # Standard text response processing.
         if stream:
-            # Handle streaming response
            response = ""
             for chunk in self._stream_response(message, markdown=markdown, **kwargs):
                 print(chunk, end="", flush=True)
                 response += chunk
-            # Store agent response
             if response:
-                self.memory.add_message(role="assistant", content=response)
-            print()  # New line after streaming
+                self.memory.add_message(role="agent", content=response)
+            print()
             return response
         else:
-            # Generate and return the response
-            response = self._generate_response(message, markdown=markdown, team=team, **kwargs)
-            print(response)  # Print the response to the console
-            # Store agent response
+            response = self._generate_response(message, markdown=markdown, **kwargs)
+            print(response)
             if response:
-                self.memory.add_message(role="assistant", content=response)
+                self.memory.add_message(role="agent", content=response)
             return response
 
-
     def _stream_response(self, message: str, markdown: bool = False, **kwargs) -> Iterator[str]:
-        """Stream the agent's response."""
-        # Simulate streaming by yielding chunks of the response
+        """Simulate streaming of the agent's response."""
         response = self._generate_response(message, markdown=markdown, **kwargs)
         for chunk in response.split():
             yield chunk + " "
 
+    def _generate_response_from_image(self, message: str, image: Union[str, Image], markdown: bool = False, **kwargs) -> str:
+        """
+        Process an image by sending it to the LLM for analysis if the LLM supports vision.
+        Supports both image URLs and local PIL.Image objects.
+        """
+        try:
+            if not self.llm_instance or not getattr(self.llm_instance, "supports_vision", False):
+                raise ValueError("Vision is not supported for the current model.")
+            prompt = self._build_prompt(message, context=None)
+            if isinstance(image, str) and image.startswith("http"):
+                return self.llm_instance.generate_from_image_url(prompt, image, **kwargs)
+            elif isinstance(image, Image):
+                if image.mode == "RGBA":
+                    image = image.convert("RGB")
+                image_bytes = io.BytesIO()
+                image.save(image_bytes, format="JPEG")
+                image_bytes = image_bytes.getvalue()
+                return self.llm_instance.generate_from_image(prompt, image_bytes, **kwargs)
+            else:
+                raise ValueError("Unsupported image type. Provide either a URL or a PIL.Image.")
+        except Exception as e:
+            logger.error(f"Failed to generate response from image: {e}")
+            return f"An error occurred while processing the image: {e}"
+
+    def _discover_tools(self) -> List[BaseTool]:
+        """
+        Automatically discover and register tools from the 'tools' directory.
+        """
+        tools = []
+        tools_dir = Path(__file__).parent / "tools"
+        if not tools_dir.exists():
+            logger.warning(f"Tools directory not found: {tools_dir}")
+            return tools
+        for file in tools_dir.glob("*.py"):
+            if file.name == "base_tool.py":
+                continue  # Skip the base tool file.
+            try:
+                module_name = file.stem
+                module = importlib.import_module(f"semantio.tools.{module_name}")
+                for name, obj in module.__dict__.items():
+                    if isinstance(obj, type) and issubclass(obj, BaseTool) and obj != BaseTool:
+                        tools.append(obj())
+                        logger.info(f"Registered tool: {obj.__name__}")
+            except Exception as e:
+                logger.error(f"Failed to load tool from {file}: {e}")
+        return tools
+
+    def _get_tool_descriptions(self) -> str:
+        """
+        Generate a description of all available tools for inclusion in the LLM prompt.
+        """
+        return "\n".join(f"{tool.name}: {tool.description}" for tool in self.tools)
+
     def register_tool(self, tool: BaseTool):
         """Register a tool for the agent."""
         if self.tools is None:
@@ -274,10 +287,9 @@ class Agent(BaseModel):
 
     def _analyze_query_and_select_tools(self, query: str) -> List[Dict[str, Any]]:
         """
-        Use the LLM to analyze the query and dynamically select tools.
-        Returns a list of tool calls, each with the tool name and input.
+        Use the LLM to analyze the query and dynamically select the most appropriate tools.
+        Returns a list of tool calls (tool name and input).
         """
-        # Create a prompt for the LLM to analyze the query and select tools
         prompt = f"""
         You are an AI agent that helps analyze user queries and select the most appropriate tools.
         Below is a list of available tools and their functionalities:
@@ -301,211 +313,191 @@ class Agent(BaseModel):
         }}
         ]
         """
-
         try:
-            # Call the LLM to generate the response
             response = self.llm_instance.generate(prompt=prompt)
-            # Parse the response as JSON
             tool_calls = json.loads(response)
             return tool_calls
         except Exception as e:
             logger.error(f"Failed to analyze query and select tools: {e}")
             return []
 
-
     def _generate_response(self, message: str, markdown: bool = False, team: Optional[List['Agent']] = None, **kwargs) -> str:
         """Generate the agent's response, including tool execution and context retrieval."""
-        # Use the specified team if provided
         if team is not None:
             return self._generate_team_response(message, team, markdown=markdown, **kwargs)
-        # Initialize tool_outputs as an empty dictionary
+
         tool_outputs = {}
         responses = []
         tool_calls = []
-        # Use the LLM to analyze the query and dynamically select tools when auto_tool is enabled
+
         if self.auto_tool:
             tool_calls = self._analyze_query_and_select_tools(message)
         else:
-            # Check if tools are provided
             if self.tools:
                 tool_calls = [
                     {
                         "tool": tool.name,
-                        "input": {
-                            "query": message,  # Use the message as the query
-                            "context": None,  # No context provided by default
-                        }
+                        "input": {"query": message, "context": None}
                     }
                     for tool in self.tools
                 ]
 
-        # Execute tools if any are detected
         if tool_calls:
             for tool_call in tool_calls:
                 tool_name = tool_call["tool"]
                 tool_input = tool_call["input"]
-
-                # Find the tool
                 tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
                 if tool:
                     try:
-                        # Execute the tool
                         tool_output = tool.execute(tool_input)
-                        response = f"Tool '{tool_name}' executed. Output: {tool_output}"
+                        response_text = f"Tool '{tool_name}' executed. Output: {tool_output}"
                         if self.show_tool_calls:
-                            response = f"**Tool Called:** {tool_name}\n\n{response}"
-                        responses.append(response)
-
-                        # Store the tool output for collaboration
+                            response_text = f"**Tool Called:** {tool_name}\n\n{response_text}"
+                        responses.append(response_text)
                         tool_outputs[tool_name] = tool_output
                     except Exception as e:
-                        logger.error(f"Tool called:** {tool_name}\n\n{response}")
+                        logger.error(f"Error executing tool '{tool_name}': {e}")
                         responses.append(f"An error occurred while executing the tool '{tool_name}': {e}")
                 else:
                     responses.append(f"Tool '{tool_name}' not found.")
 
-            # If multiple tools were executed, combine their outputs for analysis
             if tool_outputs:
                 try:
-                    # Prepare the context for the LLM
                     context = {
                         "conversation_history": self.memory.get_context(self.llm_instance),
                         "tool_outputs": tool_outputs,
                         "rag_context": self.rag.retrieve(message) if self.rag else None,
                         "knowledge_base": self._get_knowledge_context(message) if self.knowledge_base else None,
                     }
-                    # 3. Build a memory-aware prompt.
                     prompt = self._build_memory_prompt(message, context)
-                    # To (convert MemoryEntry objects to dicts and remove metadata):
                    memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]
-                    # Generate a response using the LLM
                     llm_response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
                     responses.append(f"**Analysis:**\n\n{llm_response}")
                 except Exception as e:
                     logger.error(f"Failed to generate LLM response: {e}")
                     responses.append(f"An error occurred while generating the analysis: {e}")
-        if not self.tools and not tool_calls:
-            # If no tools were executed, proceed with the original logic
-            # Retrieve relevant context using RAG
-            rag_context = self.rag.retrieve(message) if self.rag else None
-            # Retrieve relevant context from the knowledge base (API result)
-            # knowledge_base_context = None
-            # if self.knowledge_base:
-            #     # Flatten the knowledge base
-            #     flattened_data = self._flatten_data(self.knowledge_base)
-            #     # Find all relevant key-value pairs in the knowledge base
-            #     relevant_values = self._find_all_relevant_keys(message, flattened_data)
-            #     if relevant_values:
-            #         knowledge_base_context = ", ".join(relevant_values)
-
-            # Combine both contexts (RAG and knowledge base)
+        elif not self.tools and not tool_calls:
             context = {
                 "conversation_history": self.memory.get_context(self.llm_instance),
-                "rag_context": rag_context,
+                "rag_context": self.rag.retrieve(message) if self.rag else None,
                 "knowledge_base": self._get_knowledge_context(message),
             }
-            # Prepare the prompt with instructions, description, and context
-            # 3. Build a memory-aware prompt.
             prompt = self._build_memory_prompt(message, context)
-            # To (convert MemoryEntry objects to dicts and remove metadata):
             memory_entries = [{"role": e.role, "content": e.content} for e in self.memory.storage.retrieve()]
-
-            # Generate the response using the LLM
             response = self.llm_instance.generate(prompt=prompt, context=context, memory=memory_entries, **kwargs)
-
-
-            # Format the response based on the json_output flag
             if self.json_output:
                 response = self._format_response_as_json(response)
-
-            # Validate the response against the expected_output
             if self.expected_output:
                 response = self._validate_response(response)
-
             if markdown:
                 return f"**Response:**\n\n{response}"
             return response
         return "\n\n".join(responses)
 
-    # Modified prompt construction with memory integration
+    def _generate_team_response(self, message: str, team: List['Agent'], markdown: bool = False, **kwargs) -> str:
+        """
+        Generate a final aggregated response using a team of assistants.
+        This method calls each team member's internal _generate_response (without printing)
+        and aggregates the results into a single output.
+        """
+        team_responses = []
+        for agent in team:
+            resp = agent._generate_response(message, markdown=markdown, **kwargs)
+            team_responses.append(f"**{agent.name}:**\n\n{resp}")
+        return "\n\n".join(team_responses)
+
     def _build_memory_prompt(self, user_input: str, context: dict) -> str:
-        """Enhanced prompt builder with memory context."""
+        """Construct a prompt that incorporates role, instructions, conversation history, and external context."""
         prompt_parts = []
-
         if self.description:
             prompt_parts.append(f"# ROLE\n{self.description}")
-
         if self.instructions:
-            prompt_parts.append(f"# INSTRUCTIONS\n" + "\n".join(f"- {i}" for i in self.instructions))
-
-        if context['conversation_history']:
+            prompt_parts.append("# INSTRUCTIONS\n" + "\n".join(f"- {i}" for i in self.instructions))
+        if context.get('conversation_history'):
             prompt_parts.append(f"# CONVERSATION HISTORY\n{context['conversation_history']}")
-
-        if context['knowledge_base']:
+        if context.get('knowledge_base'):
             prompt_parts.append(f"# KNOWLEDGE BASE\n{context['knowledge_base']}")
-
         prompt_parts.append(f"# USER INPUT\n{user_input}")
-
         return "\n\n".join(prompt_parts)
 
+    def _summarize_text(self, text: str) -> str:
+        """
+        Summarize the provided text using the LLM.
+        Adjust the prompt as needed.
+        """
+        prompt = f"Summarize the following text concisely:\n\n{text}\n\nSummary:"
+        summary = self.llm_instance.generate(prompt=prompt)
+        return summary.strip()
+
     def _get_knowledge_context(self, message: str) -> str:
-        """Retrieve and format knowledge base context."""
+        """
+        Retrieve context from the knowledge base.
+        For JSON documents, use the "flattened" field.
+        For other documents (e.g., website, YouTube) use the "text" field.
+        If the combined text is too long, break it into chunks and summarize each chunk.
+        """
         if not self.knowledge_base:
             return ""
+        texts = []
+        for doc in self.knowledge_base:
+            if isinstance(doc, dict):
+                if "flattened" in doc:
+                    # Join all values from the flattened key/value pairs.
+                    flattened_text = " ".join(str(v) for item in doc["flattened"] for v in item.values())
+                    texts.append(flattened_text)
+                elif "text" in doc:
+                    texts.append(doc["text"])
+                else:
+                    texts.append(" ".join(str(v) for v in doc.values()))
+            else:
+                texts.append(str(doc))
+        combined_text = "\n".join(texts)
 
-        flattened = self._flatten_data(self.knowledge_base)
-        relevant = self._find_all_relevant_keys(message, flattened)
-        return "\n".join(f"- {item}" for item in relevant) if relevant else ""
-    def _generate_team_response(self, message: str, team: List['Agent'], markdown: bool = False, **kwargs) -> str:
-        """Generate a response using a team of assistants."""
-        responses = []
-        for agent in team:
-            response = agent.print_response(message, markdown=markdown, **kwargs)
-            responses.append(f"**{agent.name}:**\n\n{response}")
-        return "\n\n".join(responses)
+        # If the combined text is very long, break it into chunks and summarize.
+        max_words = 1000
+        words = combined_text.split()
+        if len(words) > max_words:
+            chunks = []
+            for i in range(0, len(words), max_words):
+                chunk = " ".join(words[i:i+max_words])
+                chunks.append(chunk)
+            # Summarize each chunk.
+            summaries = [self._summarize_text(chunk) for chunk in chunks]
+            final_context = "\n".join(summaries)
+            return final_context
+        else:
+            return combined_text
+
 
+
+
+
     def _build_prompt(self, message: str, context: Optional[List[Dict]]) -> str:
-        """Build the prompt using instructions, description, and context."""
+        """Build a basic prompt including description, instructions, context, and user input."""
         prompt_parts = []
-
-        # Add description if available
         if self.description:
             prompt_parts.append(f"Description: {self.description}")
-
-        # Add instructions if available
         if self.instructions:
-            instructions = "\n".join(self.instructions)
-            prompt_parts.append(f"Instructions: {instructions}")
-
-        # Add context if available
+            prompt_parts.append("Instructions: " + "\n".join(self.instructions))
         if context:
             prompt_parts.append(f"Context: {context}")
-
-        # Add the user's message
         prompt_parts.append(f"User Input: {message}")
-
         return "\n\n".join(prompt_parts)
 
     def _format_response_as_json(self, response: str) -> Union[Dict, str]:
-        """Format the response as JSON if json_output is True."""
+        """Attempt to extract and format a JSON response."""
         try:
-            # Use regex to extract JSON from the response (e.g., within ```json ``` blocks)
             json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
             if json_match:
-                # Extract the JSON part and parse it
                 json_str = json_match.group(1)
-                return json.loads(json_str)  # Return the parsed JSON object (a dictionary)
+                return json.loads(json_str)
             else:
-                # If no JSON block is found, try to parse the entire response as JSON
-                return json.loads(response)  # Return the parsed JSON object (a dictionary)
+                return json.loads(response)
         except json.JSONDecodeError:
-            # If the response is not valid JSON, wrap it in a dictionary
-            return {"response": response}  # Return a dictionary with the response as a string
+            return {"response": response}
 
     def normalize_key(self, key: str) -> str:
-        """
-        Normalize a key by converting it to lowercase and replacing spaces with underscores.
-        """
+        """Normalize a key by converting to lowercase and replacing spaces with underscores."""
         return key.lower().replace(" ", "_")
 
     def match_key(self, expected_key, response_keys, threshold=0.5):
@@ -543,31 +535,22 @@ class Agent(BaseModel):
                 best_match = key
 
         return best_match, best_score
-
     def _validate_response(self, response: Union[str, Dict]) -> Union[str, Dict]:
-        """Validate the response against the expected_output format using semantic similarity or fallback methods."""
+        """
+        Validate and structure the response based on the expected_output using semantic matching.
+        """
         if isinstance(self.expected_output, dict):
             if not isinstance(response, dict):
                 return {"response": response}
-
             validated_response = {}
             normalized_expected_keys = {self.normalize_key(k): k for k in self.expected_output.keys()}
-
             for expected_key_norm, expected_key_orig in normalized_expected_keys.items():
-                # Find all response keys that match the expected key (case-insensitive and normalized)
-                matching_response_keys = [
-                    k for k in response.keys()
-                    if self.normalize_key(k) == expected_key_norm
-                ]
-
-                # If no exact match, use semantic matching to find similar keys
+                matching_response_keys = [k for k in response.keys() if self.normalize_key(k) == expected_key_norm]
                 if not matching_response_keys:
                     for response_key in response.keys():
                         best_match, best_score = self.match_key(expected_key_orig, [response_key])
-                        if best_match and best_score > 0.5:  # Use a threshold to determine a valid match
+                        if best_match and best_score > 0.5:
                             matching_response_keys.append(response_key)
-
-                # Merge values from all matching keys
                 merged_values = []
                 for matching_key in matching_response_keys:
                     value = response[matching_key]
@@ -575,50 +558,41 @@ class Agent(BaseModel):
                         merged_values.extend(value)
                     else:
                         merged_values.append(value)
-
-                # Assign the merged values to the expected key
-                if merged_values:
-                    validated_response[expected_key_orig] = merged_values
-                else:
-                    validated_response[expected_key_orig] = "NA"  # Default value for missing keys
-
-                # Recursively validate nested dictionaries
+                validated_response[expected_key_orig] = merged_values if merged_values else "NA"
                 expected_value = self.expected_output[expected_key_orig]
                 if isinstance(expected_value, dict) and isinstance(validated_response[expected_key_orig], dict):
                     validated_response[expected_key_orig] = self._validate_response(validated_response[expected_key_orig])
-
             return validated_response
         elif isinstance(self.expected_output, str):
             if not isinstance(response, str):
                 return str(response)
             return response
-
+
     def cli_app(
         self,
         message: Optional[str] = None,
         exit_on: Optional[List[str]] = None,
         **kwargs,
     ):
-        """Run the agent in a CLI app."""
+        """Run the agent as a command-line application."""
         from rich.prompt import Prompt
 
-        # Print initial message if provided
         if message:
             self.print_response(message=message, **kwargs)
 
         _exit_on = exit_on or ["exit", "quit", "bye"]
         while True:
             try:
-                message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
-                if message in _exit_on:
+                user_input = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
+                if user_input in _exit_on:
                     break
-                self.print_response(message=message, **kwargs)
+                self.print_response(message=user_input, **kwargs)
             except KeyboardInterrupt:
                 print("\n\nSession ended. Goodbye!")
                 break
 
     def _generate_api(self):
-        """Generate an API for the agent if api=True."""
+        """Generate an API for the agent if API mode is enabled."""
         from .api.api_generator import APIGenerator
         self.api_generator = APIGenerator(self)
         print(f"API generated for agent '{self.name}'. Use `.run_api()` to start the API server.")
@@ -627,76 +601,7 @@ class Agent(BaseModel):
         """Run the API server for the agent."""
         if not hasattr(self, 'api_generator'):
             raise ValueError("API is not enabled for this agent. Set `api=True` when initializing the agent.")
-
-        # Get API configuration
         host = self.api_config.get("host", "0.0.0.0") if self.api_config else "0.0.0.0"
         port = self.api_config.get("port", 8000) if self.api_config else 8000
-
-        # Run the API server
         self.api_generator.run(host=host, port=port)
 
-    def _flatten_data(self, data: Union[Dict, List], parent_key: str = "", separator: str = "_") -> List[Dict]:
-        """
-        Recursively flatten a nested dictionary or list into a list of key-value pairs.
-
-        Args:
-            data (Union[Dict, List]): The nested data structure.
-            parent_key (str): The parent key (used for recursion).
-            separator (str): The separator used for nested keys.
-
-        Returns:
-            List[Dict]: A list of flattened key-value pairs.
-        """
-        items = []
-        if isinstance(data, dict):
-            for key, value in data.items():
-                new_key = f"{parent_key}{separator}{key}" if parent_key else key
-                if isinstance(value, (dict, list)):
-                    items.extend(self._flatten_data(value, new_key, separator))
-                else:
-                    items.append({new_key: value})
-                    # Include the value as a key for searching
-                    if isinstance(value, str):
-                        items.append({value: new_key})
-        elif isinstance(data, list):
-            for index, item in enumerate(data):
-                new_key = f"{parent_key}{separator}{index}" if parent_key else str(index)
-                if isinstance(item, (dict, list)):
-                    items.extend(self._flatten_data(item, new_key, separator))
-                else:
-                    items.append({new_key: item})
-                    # Include the value as a key for searching
-                    if isinstance(item, str):
-                        items.append({item: new_key})
-        return items
-
-    def _find_all_relevant_keys(self, query: str, flattened_data: List[Dict], threshold: float = 0.5) -> List[str]:
-        """
-        Find all relevant keys in the flattened data based on semantic similarity to the query.
-
-        Args:
-            query (str): The user's query.
-            flattened_data (List[Dict]): The flattened key-value pairs.
-            threshold (float): The similarity threshold for considering a match.
-
-        Returns:
-            List[str]: A list of relevant values.
-        """
-        if not flattened_data:
-            return []
-
-        # Extract keys from the flattened data
-        keys = [list(item.keys())[0] for item in flattened_data]
-
-        # Compute embeddings for the query and keys
-        query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
-        key_embeddings = self.semantic_model.encode(keys, convert_to_tensor=True)
-
-        # Compute cosine similarity between the query and keys
-        similarities = util.pytorch_cos_sim(query_embedding, key_embeddings)[0]
-
-        # Find all keys with a similarity score above the threshold
-        relevant_indices = [i for i, score in enumerate(similarities) if score > threshold]
-        relevant_values = [flattened_data[i][keys[i]] for i in relevant_indices]
-
-        return relevant_values
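
Below is a minimal usage sketch of the updated Agent API from this diff, assuming the class is importable as semantio.agent.Agent and that a Groq API key is available in the environment; the sample knowledge-base document, URL, and agent names are illustrative only. The image, team, knowledge_base, llm, and print_response names all come from the code above.

# Hypothetical usage sketch of the Agent API shown in the diff (not part of the package's documented examples).
from semantio.agent import Agent  # assumed import path

researcher = Agent(
    name="Researcher",
    description="Answers questions using the provided knowledge base.",
    instructions=["Cite the knowledge base when it is relevant."],
    llm="groq",  # provider key from the llm_providers map; the provider's default model is used when llm_model is omitted
    knowledge_base=[{"text": "Semantio is an agent framework."}],  # sample document shape handled by _get_knowledge_context
)

# Plain text query: stored in conversation memory and answered with RAG / knowledge-base context.
researcher.print_response("What is Semantio?")

# Image input via the dedicated 'image' parameter added in this diff (requires a vision-capable LLM instance).
researcher.print_response("Describe this picture.", image="https://example.com/photo.jpg")

# Team mode: each member's _generate_response is collected silently and aggregated into a single reply.
reviewer = Agent(name="Reviewer", llm="groq")
researcher.print_response("Summarize the findings.", team=[researcher, reviewer])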