semantio 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- semantio/__init__.py +0 -0
- semantio/agent.py +608 -0
- semantio/api/__init__.py +0 -0
- semantio/api/api_generator.py +23 -0
- semantio/api/fastapi_app.py +70 -0
- semantio/cli/__init__.py +0 -0
- semantio/cli/main.py +31 -0
- semantio/knowledge_base/__init__.py +5 -0
- semantio/knowledge_base/document_loader.py +61 -0
- semantio/knowledge_base/retriever.py +41 -0
- semantio/knowledge_base/vector_store.py +35 -0
- semantio/llm/__init__.py +17 -0
- semantio/llm/anthropic.py +39 -0
- semantio/llm/base_llm.py +12 -0
- semantio/llm/groq.py +39 -0
- semantio/llm/llama.py +0 -0
- semantio/llm/openai.py +26 -0
- semantio/memory.py +11 -0
- semantio/rag.py +18 -0
- semantio/storage/__init__.py +0 -0
- semantio/storage/cloud_storage.py +0 -0
- semantio/storage/local_storage.py +0 -0
- semantio/tools/__init__.py +0 -0
- semantio/tools/base_tool.py +12 -0
- semantio/tools/crypto.py +133 -0
- semantio/tools/duckduckgo.py +128 -0
- semantio/tools/stocks.py +131 -0
- semantio/tools/web_browser.py +153 -0
- semantio/utils/__init__.py +7 -0
- semantio/utils/config.py +41 -0
- semantio/utils/date_utils.py +44 -0
- semantio/utils/file_utils.py +56 -0
- semantio/utils/logger.py +20 -0
- semantio/utils/validation_utils.py +44 -0
- semantio-0.0.1.dist-info/LICENSE +21 -0
- semantio-0.0.1.dist-info/METADATA +163 -0
- semantio-0.0.1.dist-info/RECORD +40 -0
- semantio-0.0.1.dist-info/WHEEL +5 -0
- semantio-0.0.1.dist-info/entry_points.txt +3 -0
- semantio-0.0.1.dist-info/top_level.txt +1 -0
semantio/__init__.py
ADDED
File without changes
|
semantio/agent.py
ADDED
@@ -0,0 +1,608 @@
|
|
1
|
+
from typing import Optional, List, Dict, Union, Iterator, Any
|
2
|
+
from pydantic import BaseModel, Field, ConfigDict
|
3
|
+
from PIL.Image import Image
|
4
|
+
import requests
|
5
|
+
import logging
|
6
|
+
import re
|
7
|
+
import io
|
8
|
+
import json
|
9
|
+
from .rag import RAG
|
10
|
+
from .llm.base_llm import BaseLLM
|
11
|
+
from .knowledge_base.retriever import Retriever
|
12
|
+
from .knowledge_base.vector_store import VectorStore
|
13
|
+
from sentence_transformers import SentenceTransformer, util
|
14
|
+
from fuzzywuzzy import fuzz
|
15
|
+
from .tools.base_tool import BaseTool
|
16
|
+
from pathlib import Path
|
17
|
+
import importlib
|
18
|
+
import os
|
19
|
+
|
20
|
+
# Configure logging
# NOTE(review): basicConfig runs at import time and configures the *root* logger
# for any application that imports this module; libraries usually leave this to
# the caller and only create a module logger.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the Agent class below.
logger = logging.getLogger(__name__)
|
23
|
+
class Agent(BaseModel):
    """LLM-backed agent with tool selection, RAG / knowledge-base context and an optional API.

    Construction (see ``__init__``) resolves the LLM client, registers tools
    (user-supplied ones plus those discovered in this package's ``tools``
    directory), loads a SentenceTransformer used for semantic key matching,
    creates a default RAG pipeline when none is supplied and, if ``api=True``,
    builds an API wrapper via ``APIGenerator``.
    """

    # -*- Agent settings
    name: Optional[str] = Field(None, description="Name of the agent.")
    description: Optional[str] = Field(None, description="Description of the agent's role.")
    instructions: Optional[List[str]] = Field(None, description="List of instructions for the agent.")
    # NOTE(review): not read anywhere in this class; ``llm_model`` selects the model.
    model: Optional[str] = Field(None, description="This one is not in the use.")
    show_tool_calls: bool = Field(False, description="Whether to show tool calls in the response.")
    markdown: bool = Field(False, description="Whether to format the response in markdown.")
    tools: Optional[List[BaseTool]] = Field(None, description="List of tools available to the agent.")
    user_name: Optional[str] = Field("User", description="Name of the user interacting with the agent.")
    emoji: Optional[str] = Field(":robot:", description="Emoji to represent the agent in the CLI.")
    rag: Optional[RAG] = Field(None, description="RAG instance for context retrieval.")
    knowledge_base: Optional[Any] = Field(None, description="Knowledge base for domain-specific information.")
    llm: Optional[str] = Field(None, description="The LLM provider to use (e.g., 'groq', 'openai', 'anthropic').")
    llm_model: Optional[str] = Field(None, description="The specific model to use for the LLM provider.")
    llm_instance: Optional[BaseLLM] = Field(None, description="The LLM instance to use.")
    json_output: bool = Field(False, description="Whether to format the response as JSON.")
    api: bool = Field(False, description="Whether to generate an API for the agent.")
    api_config: Optional[Dict] = Field(
        None,
        description="Configuration for the API (e.g., host, port, authentication).",
    )
    api_generator: Optional[Any] = Field(None, description="The API generator instance.")
    expected_output: Optional[Union[str, Dict]] = Field(None, description="The expected format or structure of the output.")
    semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
    # Allow arbitrary (non-pydantic) types such as BaseTool / RAG / BaseLLM as field types
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs):
        """Validate fields, then wire up the LLM, tools, semantic model, RAG and API."""
        super().__init__(**kwargs)
        self._initialize_model()

        # Keep tools passed via ``tools=`` and add discovered ones whose class is
        # not already represented (the discovered list used to replace them).
        user_tools = list(self.tools or [])
        user_tool_types = {type(tool) for tool in user_tools}
        discovered = [t for t in self._discover_tools() if type(t) not in user_tool_types]
        self.tools = user_tools + discovered

        # Give every tool access to the agent's LLM client.
        for tool in self.tools:
            tool.llm = self.llm_instance

        # Only load the default embedding model when the caller did not supply one.
        if self.semantic_model is None:
            self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        if self.rag is None:
            self.rag = self._initialize_default_rag()

        if self.api:
            self._generate_api()

    def _package_name(self) -> str:
        """Return the import name of the package containing this module (e.g. ``semantio``)."""
        return __package__ or Path(__file__).parent.name

    def _discover_tools(self) -> List[BaseTool]:
        """Instantiate every ``BaseTool`` subclass defined in this package's ``tools`` directory.

        Private modules (``__init__.py`` etc.) and ``base_tool.py`` are skipped.
        Classes are only picked up from the module that defines them, so a tool
        imported into another tool module is not registered twice. Import or
        instantiation failures are logged and skipped.
        """
        tools: List[BaseTool] = []
        tools_dir = Path(__file__).parent / "tools"

        if not tools_dir.exists():
            logger.warning(f"Tools directory not found: {tools_dir}")
            return tools

        package = self._package_name()
        for file in sorted(tools_dir.glob("*.py")):
            if file.name == "base_tool.py" or file.name.startswith("_"):
                continue

            try:
                # Resolve relative to this package (was hard-coded to ``hashai``).
                module = importlib.import_module(f"{package}.tools.{file.stem}")
            except Exception as e:
                logger.error(f"Failed to load tool from {file}: {e}")
                continue

            for obj in vars(module).values():
                if not (
                    isinstance(obj, type)
                    and issubclass(obj, BaseTool)
                    and obj is not BaseTool
                    and obj.__module__ == module.__name__
                ):
                    continue
                try:
                    tools.append(obj())
                    logger.info(f"Registered tool: {obj.__name__}")
                except Exception as e:
                    logger.error(f"Failed to load tool from {file}: {e}")

        return tools

    def _get_tool_descriptions(self) -> str:
        """Return one ``"<name>: <description>"`` line per registered tool, for LLM prompts."""
        return "\n".join(
            f"{tool.name}: {tool.description}" for tool in (self.tools or [])
        )

    def _find_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Return the registered tool whose ``name`` matches ``tool_name`` case-insensitively, or None."""
        wanted = str(tool_name).lower()
        return next((t for t in (self.tools or []) if t.name.lower() == wanted), None)

    def _initialize_model(self):
        """Create ``self.llm_instance`` from ``self.llm`` / ``self.llm_model`` unless one was given.

        The API key is taken from an ``api_key`` attribute if present, otherwise
        from the ``<PROVIDER>_API_KEY`` environment variable.

        Raises:
            ValueError: if neither ``llm_instance`` nor ``llm`` is set, or the
                provider is not one of the supported names.
        """
        if self.llm_instance is not None:
            return  # LLM is already initialized, do nothing
        if self.llm is None:
            raise ValueError("llm must be specified.")

        api_key = getattr(self, 'api_key', None) or os.getenv(f"{self.llm.upper()}_API_KEY")

        # Map LLM providers to their respective classes and default models
        llm_providers = {
            "groq": {
                "class": "GroqLlm",
                "default_model": "mixtral-8x7b-32768",
            },
            "openai": {
                "class": "OpenAILlm",
                "default_model": "gpt-4",
            },
            "anthropic": {
                "class": "AnthropicLlm",
                "default_model": "claude-2.1",
            },
        }

        # Provider names are matched case-insensitively
        llm_provider = self.llm.lower()
        if llm_provider not in llm_providers:
            raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers are: {list(llm_providers.keys())}")

        llm_config = llm_providers[llm_provider]
        model_to_use = self.llm_model or llm_config["default_model"]

        # Import the provider module from this package (was hard-coded to ``hashai``).
        llm_module = importlib.import_module(f"{self._package_name()}.llm.{llm_provider}")
        llm_class = getattr(llm_module, llm_config["class"])
        self.llm_instance = llm_class(model=model_to_use, api_key=api_key)

    def _initialize_default_rag(self) -> RAG:
        """Build a RAG pipeline over a default-constructed ``VectorStore``."""
        vector_store = VectorStore()
        retriever = Retriever(vector_store)
        return RAG(retriever)

    def load_image_from_url(self, image_url: str, timeout: float = 30.0) -> Image:
        """Download ``image_url`` and return it as a PIL Image.

        Args:
            image_url: URL of the image.
            timeout: Seconds to wait for the HTTP request before giving up.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # The module-level ``Image`` is the PIL Image *class*, which has no
        # ``open``; the loader lives on the ``PIL.Image`` module.
        from PIL import Image as PILImage

        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        return PILImage.open(io.BytesIO(response.content))

    def print_response(
        self,
        message: Optional[Union[str, Image, List, Dict]] = None,
        stream: bool = False,
        markdown: bool = False,
        **kwargs,
    ) -> Union[str, Dict]:
        """Generate a response to ``message``, print it to stdout and return it.

        Args:
            message: User input; a PIL Image is first converted via ``_process_image``.
            stream: Print the response chunk by chunk on a single line.
            markdown: Prefix the answer with a markdown header; also enabled by
                the agent's ``markdown`` field.
            **kwargs: Forwarded to the LLM ``generate`` call.
        """
        markdown = markdown or self.markdown
        if isinstance(message, Image):
            message = self._process_image(message)

        if stream:
            chunks = []
            for chunk in self._stream_response(message, markdown=markdown, **kwargs):
                # Keep streamed chunks on one line (previously one word per line).
                print(chunk, end="", flush=True)
                chunks.append(chunk)
            print()
            return "".join(chunks)

        response = self._generate_response(message, markdown=markdown, **kwargs)
        print(response)
        return response

    def _process_image(self, image: Image) -> str:
        """Return a text stand-in for ``image`` (currently a fixed placeholder string)."""
        return "Image processed. Extracted text: [Placeholder]"

    def _stream_response(self, message: str, markdown: bool = False, **kwargs) -> Iterator[str]:
        """Simulate streaming: generate the full response, then yield it word by word.

        Each yielded chunk is a whitespace-separated word followed by one space.
        Non-string responses (e.g. JSON output) are serialized first.
        """
        response = self._generate_response(message, markdown=markdown, **kwargs)
        if not isinstance(response, str):
            response = json.dumps(response, ensure_ascii=False, default=str)
        for chunk in response.split():
            yield chunk + " "

    def register_tool(self, tool: BaseTool):
        """Add ``tool`` to the agent and give it the agent's LLM client."""
        if self.tools is None:
            self.tools = []
        tool.llm = self.llm_instance
        self.tools.append(tool)

    def _detect_tool_call(self, message: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM for the single most appropriate tool for ``message``.

        Returns:
            ``{"tool": <name>, "input": {"query": message}}`` when the LLM names a
            registered tool, otherwise None (including on LLM errors).
        """
        if not self.tools:
            logger.warning("No tools available to detect.")
            return None

        prompt = f"""
        You are an AI agent that helps users by selecting the most appropriate tool to answer their query. Below is a list of available tools and their functionalities:

        {self._get_tool_descriptions()}

        Based on the user's query, select the most appropriate tool. Respond with the name of the tool (e.g., "CryptoPriceChecker"). If no tool is suitable, respond with "None".

        User Query: "{message}"
        """

        try:
            response = self.llm_instance.generate(prompt=prompt)
            tool_name = response.strip().replace('"', '').replace("'", "")
            tool = self._find_tool(tool_name)
            if tool:
                logger.info(f"Detected tool call: {tool.name}")
                return {
                    "tool": tool.name,
                    "input": {"query": message}
                }
        except Exception as e:
            logger.error(f"Failed to detect tool call: {e}")

        return None

    def _strip_code_fence(self, text: str) -> str:
        """Return the body of the first ``` / ```json fenced block in ``text``, else ``text`` stripped."""
        match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        return match.group(1) if match else text.strip()

    def _analyze_query_and_select_tools(self, query: str) -> List[Dict[str, Any]]:
        """Ask the LLM which tools to run for ``query``.

        Returns:
            A list of ``{"tool": name, "input": {...}}`` dicts. Entries that are
            not dicts or lack a ``tool`` key are dropped; an empty list is
            returned when no tools are registered or the reply cannot be parsed.
        """
        if not self.tools:
            return []

        prompt = f"""
        You are an AI agent that helps analyze user queries and select the most appropriate tools.
        Below is a list of available tools and their functionalities:

        {self._get_tool_descriptions()}

        For the following query, analyze the intent and select the most appropriate tools.
        Respond with a JSON array of tool names and their inputs.
        If no tool is suitable, respond with an empty array.

        Query: "{query}"

        Respond in the following JSON format:
        [
            {{
                "tool": "tool_name",
                "input": {{
                    "query": "user_query",
                    "context": "optional_context"
                }}
            }}
        ]
        """

        try:
            response = self.llm_instance.generate(prompt=prompt)
            # LLMs frequently wrap JSON in ```json fences; unwrap before parsing.
            tool_calls = json.loads(self._strip_code_fence(response))
        except Exception as e:
            logger.error(f"Failed to analyze query and select tools: {e}")
            return []

        if isinstance(tool_calls, dict):
            tool_calls = [tool_calls]
        if not isinstance(tool_calls, list):
            logger.error(f"Failed to analyze query and select tools: unexpected response {tool_calls!r}")
            return []
        return [call for call in tool_calls if isinstance(call, dict) and call.get("tool")]

    def _knowledge_base_context(self, message: str) -> Optional[str]:
        """Return knowledge-base values relevant to ``message`` joined by ", ", or None if there are none."""
        if not self.knowledge_base:
            return None
        relevant_values = self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base))
        if not relevant_values:
            return None
        # Values may be numbers/bools; join() needs strings.
        return ", ".join(str(value) for value in relevant_values)

    def _generate_response(self, message: str, markdown: bool = False, **kwargs) -> Union[str, Dict]:
        """Generate the agent's answer, running LLM-selected tools when any are chosen.

        With tool calls: each tool is executed, its output reported, and (if any
        succeeded) the LLM is asked for an analysis over the tool outputs plus
        RAG / knowledge-base context; all parts are joined with blank lines.
        Without tool calls: see ``_respond_without_tools``.
        """
        tool_calls = self._analyze_query_and_select_tools(message)
        if not tool_calls:
            return self._respond_without_tools(message, markdown=markdown, **kwargs)

        responses: List[str] = []
        tool_outputs: Dict[str, Any] = {}  # outputs of successful tools, keyed by requested name

        for tool_call in tool_calls:
            tool_name = str(tool_call.get("tool", ""))
            tool_input = tool_call.get("input") or {"query": message}

            tool = self._find_tool(tool_name)
            if tool is None:
                responses.append(f"Tool '{tool_name}' not found.")
                continue

            try:
                tool_output = tool.execute(tool_input)
                response = f"Tool '{tool_name}' executed. Output: {tool_output}"
                if self.show_tool_calls:
                    response = f"**Tool Called:** {tool_name}\n\n{response}"
                responses.append(response)
                tool_outputs[tool_name] = tool_output
            except Exception as e:
                # Must not reference ``response`` here: it is unbound when the
                # first tool fails, which used to raise UnboundLocalError.
                logger.error(f"Failed to execute tool '{tool_name}': {e}")
                responses.append(f"An error occurred while executing the tool '{tool_name}': {e}")

        if tool_outputs:
            try:
                context = {
                    "tool_outputs": tool_outputs,
                    "rag_context": self.rag.retrieve(message) if self.rag else None,
                    "knowledge_base_context": self._knowledge_base_context(message),
                }
                llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
                responses.append(f"**Analysis:**\n\n{llm_response}")
            except Exception as e:
                logger.error(f"Failed to generate LLM response: {e}")
                responses.append(f"An error occurred while generating the analysis: {e}")

        return "\n\n".join(responses)

    def _respond_without_tools(self, message: str, markdown: bool = False, **kwargs) -> Union[str, Dict]:
        """Answer ``message`` directly with the LLM using RAG and knowledge-base context.

        The raw answer is optionally parsed as JSON (``json_output``), coerced to
        ``expected_output`` and prefixed with a markdown header.
        """
        context = {
            "rag_context": self.rag.retrieve(message) if self.rag else None,
            "knowledge_base_context": self._knowledge_base_context(message),
        }
        prompt = self._build_prompt(message, context)
        response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)

        if self.json_output:
            response = self._format_response_as_json(response)
        if self.expected_output:
            response = self._validate_response(response)
        if markdown:
            return f"**Response:**\n\n{response}"
        return response

    def _build_prompt(self, message: str, context: Optional[Dict[str, Any]]) -> str:
        """Join description, instructions, context and the user message into one prompt."""
        prompt_parts = []

        if self.description:
            prompt_parts.append(f"Description: {self.description}")

        if self.instructions:
            instructions = "\n".join(self.instructions)
            prompt_parts.append(f"Instructions: {instructions}")

        if context:
            prompt_parts.append(f"Context: {context}")

        prompt_parts.append(f"User Input: {message}")

        return "\n\n".join(prompt_parts)

    def _format_response_as_json(self, response: str) -> Union[Dict, List, str]:
        """Parse ``response`` as JSON, preferring the contents of a ``` fenced block.

        Already-parsed dicts/lists are returned unchanged; text that is not
        valid JSON is wrapped as ``{"response": response}``.
        """
        if isinstance(response, (dict, list)):
            return response
        try:
            return json.loads(self._strip_code_fence(response))
        except (json.JSONDecodeError, TypeError):
            return {"response": response}

    def normalize_key(self, key: str) -> str:
        """Lower-case ``key`` (stringified first) and replace spaces with underscores."""
        return str(key).lower().replace(" ", "_")

    def match_key(self, expected_key, response_keys, threshold=0.5):
        """Find the response key closest to ``expected_key``.

        Semantic (embedding cosine) similarity is tried first; if it is
        unavailable, fails, or no score exceeds ``threshold``, fuzzy string
        ratio (0..1) is used instead.

        Returns:
            ``(best_key, score)``; ``(None, -1)`` when ``response_keys`` is empty.
        """
        response_keys = list(response_keys)
        if not response_keys:
            return None, -1

        expected_key_norm = self.normalize_key(expected_key)
        response_keys_norm = [self.normalize_key(k) for k in response_keys]

        if getattr(self, 'semantic_model', None) is not None:
            try:
                expected_embedding = self.semantic_model.encode(expected_key_norm, convert_to_tensor=True)
                response_embeddings = self.semantic_model.encode(response_keys_norm, convert_to_tensor=True)
                similarity_scores = util.pytorch_cos_sim(expected_embedding, response_embeddings)[0]

                best_score = similarity_scores.max().item()
                best_index = similarity_scores.argmax().item()
                if best_score > threshold:
                    return response_keys[best_index], best_score
            except Exception as e:
                logger.warning(f"Semantic matching failed: {e}. Falling back to fuzzy matching.")

        # Fallback to fuzzy matching
        best_match = None
        best_score = -1
        for key, key_norm in zip(response_keys, response_keys_norm):
            score = fuzz.ratio(expected_key_norm, key_norm) / 100
            if score > best_score:
                best_score = score
                best_match = key

        return best_match, best_score

    def _validate_response(self, response: Union[str, Dict]) -> Union[str, Dict]:
        """Coerce ``response`` to the shape of ``self.expected_output`` (see ``_validate_against``)."""
        return self._validate_against(response, self.expected_output)

    def _validate_against(self, response: Any, expected: Any) -> Any:
        """Coerce ``response`` to ``expected``: a str schema or a (possibly nested) dict schema.

        For a dict schema, each expected key collects the values of response
        keys that match it exactly after normalization, or, failing that,
        semantically/fuzzily (score > 0.5). List values are flattened into the
        collected list; missing keys become ``"NA"``. Collected dict values are
        validated recursively against the nested schema for that key.
        """
        if isinstance(expected, str):
            return response if isinstance(response, str) else str(response)
        if not isinstance(expected, dict):
            return response
        if not isinstance(response, dict):
            return {"response": response}

        validated_response = {}
        normalized_expected_keys = {self.normalize_key(k): k for k in expected.keys()}

        for expected_key_norm, expected_key_orig in normalized_expected_keys.items():
            matching_response_keys = [
                k for k in response.keys() if self.normalize_key(k) == expected_key_norm
            ]

            # No exact match: accept any response key that is similar enough.
            if not matching_response_keys:
                for response_key in response.keys():
                    best_match, best_score = self.match_key(expected_key_orig, [response_key])
                    if best_match and best_score > 0.5:
                        matching_response_keys.append(response_key)

            merged_values = []
            for matching_key in matching_response_keys:
                value = response[matching_key]
                if isinstance(value, list):
                    merged_values.extend(value)
                else:
                    merged_values.append(value)

            # Recurse with the *nested* schema (previously unreachable and it
            # would have re-used the top-level schema).
            nested_expected = expected[expected_key_orig]
            if isinstance(nested_expected, dict):
                merged_values = [
                    self._validate_against(value, nested_expected) if isinstance(value, dict) else value
                    for value in merged_values
                ]

            validated_response[expected_key_orig] = merged_values if merged_values else "NA"

        return validated_response

    def cli_app(
        self,
        message: Optional[str] = None,
        exit_on: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run an interactive prompt loop, printing a response for each input.

        Args:
            message: Optional first message answered before the loop starts.
            exit_on: Words that end the loop (case-insensitive); defaults to
                ``exit``, ``quit`` and ``bye``.
            **kwargs: Forwarded to ``print_response``.
        """
        from rich.prompt import Prompt

        if message:
            self.print_response(message=message, **kwargs)

        exit_words = {word.strip().lower() for word in (exit_on or ["exit", "quit", "bye"])}
        while True:
            message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
            if message.strip().lower() in exit_words:
                break
            if not message.strip():
                continue  # don't send blank lines to the LLM
            self.print_response(message=message, **kwargs)

    def _generate_api(self):
        """Create ``self.api_generator`` (an ``APIGenerator`` wrapping this agent)."""
        from .api.api_generator import APIGenerator
        self.api_generator = APIGenerator(self)
        print(f"API generated for agent '{self.name}'. Use `.run_api()` to start the API server.")

    def run_api(self):
        """Start the API server using ``host``/``port`` from ``api_config`` (default 0.0.0.0:8000).

        Raises:
            ValueError: if the agent was not created with ``api=True``.
        """
        # ``api_generator`` is a declared field, so hasattr() was always True;
        # check its value instead.
        if getattr(self, 'api_generator', None) is None:
            raise ValueError("API is not enabled for this agent. Set `api=True` when initializing the agent.")

        config = self.api_config or {}
        host = config.get("host", "0.0.0.0")
        port = config.get("port", 8000)

        self.api_generator.run(host=host, port=port)

    def _flatten_data(self, data: Union[Dict, List], parent_key: str = "", separator: str = "_") -> List[Dict]:
        """Recursively flatten nested dicts/lists into single-entry ``{path: value}`` dicts.

        Nested keys are joined with ``separator`` (list items use their index).
        For every string leaf a reverse ``{value: path}`` entry is also emitted
        so values can be searched like keys. Inputs that are neither dict nor
        list produce an empty list.

        Args:
            data: The nested data structure.
            parent_key: Path prefix used during recursion.
            separator: Separator placed between path components.
        """
        items = []
        if isinstance(data, dict):
            for key, value in data.items():
                new_key = f"{parent_key}{separator}{key}" if parent_key else key
                if isinstance(value, (dict, list)):
                    items.extend(self._flatten_data(value, new_key, separator))
                else:
                    items.append({new_key: value})
                    if isinstance(value, str):
                        items.append({value: new_key})
        elif isinstance(data, list):
            for index, item in enumerate(data):
                new_key = f"{parent_key}{separator}{index}" if parent_key else str(index)
                if isinstance(item, (dict, list)):
                    items.extend(self._flatten_data(item, new_key, separator))
                else:
                    items.append({new_key: item})
                    if isinstance(item, str):
                        items.append({item: new_key})
        return items

    def _find_all_relevant_keys(self, query: str, flattened_data: List[Dict], threshold: float = 0.5) -> List[Any]:
        """Return values whose key is semantically similar to ``query``.

        Args:
            query: The user's query.
            flattened_data: Single-entry dicts as produced by ``_flatten_data``.
            threshold: Minimum cosine similarity for a key to count as relevant.

        Returns:
            The values (in input order) whose keys score above ``threshold``;
            empty when there is no data or no semantic model.
        """
        if not flattened_data or self.semantic_model is None:
            return []

        pairs = [next(iter(item.items())) for item in flattened_data if item]
        if not pairs:
            return []

        # Keys may be ints (e.g. dict keys); the encoder needs text.
        keys = [str(key) for key, _ in pairs]
        query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
        key_embeddings = self.semantic_model.encode(keys, convert_to_tensor=True)
        similarities = util.pytorch_cos_sim(query_embedding, key_embeddings)[0].tolist()

        return [value for (_, value), score in zip(pairs, similarities) if score > threshold]
|
semantio/api/__init__.py
ADDED
File without changes
|
semantio/api/api_generator.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from .fastapi_app import create_fastapi_app # Import the factory function
|
2
|
+
|
3
|
+
class APIGenerator:
    """Expose an assistant (agent) through the app returned by ``create_fastapi_app``."""

    def __init__(self, assistant):
        """Store ``assistant`` and build its app.

        Args:
            assistant: The instance being exposed; its ``api_config`` attribute
                is handed to ``create_fastapi_app`` together with the instance.
        """
        self.assistant = assistant
        api_config = assistant.api_config
        self.app = create_fastapi_app(assistant, api_config)

    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """Serve ``self.app`` with uvicorn.

        Args:
            host (str): Interface address to bind; defaults to "0.0.0.0".
            port (int): TCP port to listen on; defaults to 8000.
        """
        from uvicorn import run as serve_app

        serve_app(self.app, host=host, port=port)
|