semantio-0.0.1-py3-none-any.whl
- semantio/__init__.py +0 -0
- semantio/agent.py +608 -0
- semantio/api/__init__.py +0 -0
- semantio/api/api_generator.py +23 -0
- semantio/api/fastapi_app.py +70 -0
- semantio/cli/__init__.py +0 -0
- semantio/cli/main.py +31 -0
- semantio/knowledge_base/__init__.py +5 -0
- semantio/knowledge_base/document_loader.py +61 -0
- semantio/knowledge_base/retriever.py +41 -0
- semantio/knowledge_base/vector_store.py +35 -0
- semantio/llm/__init__.py +17 -0
- semantio/llm/anthropic.py +39 -0
- semantio/llm/base_llm.py +12 -0
- semantio/llm/groq.py +39 -0
- semantio/llm/llama.py +0 -0
- semantio/llm/openai.py +26 -0
- semantio/memory.py +11 -0
- semantio/rag.py +18 -0
- semantio/storage/__init__.py +0 -0
- semantio/storage/cloud_storage.py +0 -0
- semantio/storage/local_storage.py +0 -0
- semantio/tools/__init__.py +0 -0
- semantio/tools/base_tool.py +12 -0
- semantio/tools/crypto.py +133 -0
- semantio/tools/duckduckgo.py +128 -0
- semantio/tools/stocks.py +131 -0
- semantio/tools/web_browser.py +153 -0
- semantio/utils/__init__.py +7 -0
- semantio/utils/config.py +41 -0
- semantio/utils/date_utils.py +44 -0
- semantio/utils/file_utils.py +56 -0
- semantio/utils/logger.py +20 -0
- semantio/utils/validation_utils.py +44 -0
- semantio-0.0.1.dist-info/LICENSE +21 -0
- semantio-0.0.1.dist-info/METADATA +163 -0
- semantio-0.0.1.dist-info/RECORD +40 -0
- semantio-0.0.1.dist-info/WHEEL +5 -0
- semantio-0.0.1.dist-info/entry_points.txt +3 -0
- semantio-0.0.1.dist-info/top_level.txt +1 -0
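
For orientation before the per-file diffs: a quick sanity-check sketch (not taken from the package's own docs) that imports the public classes referenced below, assuming the wheel above is installed.

# Sketch: verify the module layout of semantio 0.0.1 after installing this wheel.
# The class names come from the imports used in semantio/agent.py shown below.
from semantio.agent import Agent                # main agent class (semantio/agent.py)
from semantio.llm.base_llm import BaseLLM       # LLM interface (semantio/llm/base_llm.py)
from semantio.tools.base_tool import BaseTool   # tool interface (semantio/tools/base_tool.py)

print(Agent, BaseLLM, BaseTool)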
semantio/__init__.py
ADDED
File without changes
semantio/agent.py
ADDED
@@ -0,0 +1,608 @@
from typing import Optional, List, Dict, Union, Iterator, Any
from pydantic import BaseModel, Field, ConfigDict
from PIL import Image
import requests
import logging
import re
import io
import json
from .rag import RAG
from .llm.base_llm import BaseLLM
from .knowledge_base.retriever import Retriever
from .knowledge_base.vector_store import VectorStore
from sentence_transformers import SentenceTransformer, util
from fuzzywuzzy import fuzz
from .tools.base_tool import BaseTool
from pathlib import Path
import importlib
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Agent(BaseModel):
    # -*- Agent settings
    name: Optional[str] = Field(None, description="Name of the agent.")
    description: Optional[str] = Field(None, description="Description of the agent's role.")
    instructions: Optional[List[str]] = Field(None, description="List of instructions for the agent.")
    model: Optional[str] = Field(None, description="Not currently used.")
    show_tool_calls: bool = Field(False, description="Whether to show tool calls in the response.")
    markdown: bool = Field(False, description="Whether to format the response in markdown.")
    tools: Optional[List[BaseTool]] = Field(None, description="List of tools available to the agent.")
    user_name: Optional[str] = Field("User", description="Name of the user interacting with the agent.")
    emoji: Optional[str] = Field(":robot:", description="Emoji to represent the agent in the CLI.")
    rag: Optional[RAG] = Field(None, description="RAG instance for context retrieval.")
    knowledge_base: Optional[Any] = Field(None, description="Knowledge base for domain-specific information.")
    llm: Optional[str] = Field(None, description="The LLM provider to use (e.g., 'groq', 'openai', 'anthropic').")
    llm_model: Optional[str] = Field(None, description="The specific model to use for the LLM provider.")
    llm_instance: Optional[BaseLLM] = Field(None, description="The LLM instance to use.")
    json_output: bool = Field(False, description="Whether to format the response as JSON.")
    api: bool = Field(False, description="Whether to generate an API for the agent.")
    api_config: Optional[Dict] = Field(
        None,
        description="Configuration for the API (e.g., host, port, authentication).",
    )
    api_generator: Optional[Any] = Field(None, description="The API generator instance.")
    expected_output: Optional[Union[str, Dict]] = Field(None, description="The expected format or structure of the output.")
    semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
    # Allow arbitrary types
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Initialize the model and tools here if needed
        self._initialize_model()
        # Automatically discover and register tools
        self.tools = self._discover_tools()
        # Pass the LLM instance to each tool
        for tool in self.tools:
            tool.llm = self.llm_instance
        # Initialize the SentenceTransformer model for semantic matching
        self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Initialize RAG if not provided
        if self.rag is None:
            self.rag = self._initialize_default_rag()
        # Automatically generate API if api=True
        if self.api:
            self._generate_api()

    def _discover_tools(self) -> List[BaseTool]:
        """
        Automatically discover and register tools from the 'tools' directory.
        """
        tools = []
        tools_dir = Path(__file__).parent / "tools"

        if not tools_dir.exists():
            logger.warning(f"Tools directory not found: {tools_dir}")
            return tools

        # Iterate over all Python files in the 'tools' directory
        for file in tools_dir.glob("*.py"):
            if file.name == "base_tool.py":
                continue  # Skip the base tool file

            try:
                # Import the module
                module_name = file.stem
                module = importlib.import_module(f"semantio.tools.{module_name}")

                # Find all classes that inherit from BaseTool
                for name, obj in module.__dict__.items():
                    if isinstance(obj, type) and issubclass(obj, BaseTool) and obj != BaseTool:
                        # Instantiate the tool and add it to the list
                        tools.append(obj())
                        logger.info(f"Registered tool: {obj.__name__}")
            except Exception as e:
                logger.error(f"Failed to load tool from {file}: {e}")

        return tools

    def _get_tool_descriptions(self) -> str:
        """Generate a description of all available tools for the LLM prompt."""
        return "\n".join(
            f"{tool.name}: {tool.description}" for tool in self.tools
        )

    def _initialize_model(self):
        """Initialize the model based on the provided configuration."""
        if self.llm_instance is not None:
            return  # LLM is already initialized, do nothing
        if self.llm is None:
            raise ValueError("llm must be specified.")

        # Get the API key from the environment or the provided configuration
        api_key = getattr(self, 'api_key', None) or os.getenv(f"{self.llm.upper()}_API_KEY")

        # Map LLM providers to their respective classes and default models
        llm_providers = {
            "groq": {
                "class": "GroqLlm",
                "default_model": "mixtral-8x7b-32768",
            },
            "openai": {
                "class": "OpenAILlm",
                "default_model": "gpt-4",
            },
            "anthropic": {
                "class": "AnthropicLlm",
                "default_model": "claude-2.1",
            },
        }

        # Normalize the LLM provider name (case-insensitive)
        llm_provider = self.llm.lower()

        if llm_provider not in llm_providers:
            raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers are: {list(llm_providers.keys())}")

        # Get the LLM class and default model
        llm_config = llm_providers[llm_provider]
        llm_class_name = llm_config["class"]
        default_model = llm_config["default_model"]

        # Use the user-provided model or fall back to the default model
        model_to_use = self.llm_model or default_model

        # Dynamically import and initialize the LLM class
        module_name = f"semantio.llm.{llm_provider}"
        llm_module = importlib.import_module(module_name)
        llm_class = getattr(llm_module, llm_class_name)
        self.llm_instance = llm_class(model=model_to_use, api_key=api_key)

    def _initialize_default_rag(self) -> RAG:
        """Initialize a default RAG instance with a dummy vector store."""
        vector_store = VectorStore()
        retriever = Retriever(vector_store)
        return RAG(retriever)

    def load_image_from_url(self, image_url: str) -> Image.Image:
        """Load an image from a URL and return it as a PIL Image."""
        response = requests.get(image_url)
        image_bytes = response.content
        return Image.open(io.BytesIO(image_bytes))

    def print_response(
        self,
        message: Optional[Union[str, Image.Image, List, Dict]] = None,
        stream: bool = False,
        markdown: bool = False,
        **kwargs,
    ) -> Union[str, Dict]:
        """Print the agent's response to the console and return it."""
        if isinstance(message, Image.Image):
            # Handle image input
            message = self._process_image(message)

        if stream:
            # Handle streaming response
            response = ""
            for chunk in self._stream_response(message, markdown=markdown, **kwargs):
                print(chunk)
                response += chunk
            return response
        else:
            # Generate and return the response
            response = self._generate_response(message, markdown=markdown, **kwargs)
            print(response)  # Print the response to the console
            return response

    def _process_image(self, image: Image.Image) -> str:
        """Process the image and return a string representation."""
        # Convert the image to text or extract relevant information
        # For now, we'll just return a placeholder string
        return "Image processed. Extracted text: [Placeholder]"

    def _stream_response(self, message: str, markdown: bool = False, **kwargs) -> Iterator[str]:
        """Stream the agent's response."""
        # Simulate streaming by yielding chunks of the response
        response = self._generate_response(message, markdown=markdown, **kwargs)
        for chunk in response.split():
            yield chunk + " "

    def register_tool(self, tool: BaseTool):
        """Register a tool for the agent."""
        if self.tools is None:
            self.tools = []
        self.tools.append(tool)

    def _detect_tool_call(self, message: str) -> Optional[Dict[str, Any]]:
        """
        Use the LLM to detect which tool should be called based on the user's query.
        """
        if not self.tools:
            logger.warning("No tools available to detect.")
            return None

        # Create a prompt for the LLM
        prompt = f"""
        You are an AI agent that helps users by selecting the most appropriate tool to answer their query. Below is a list of available tools and their functionalities:

        {self._get_tool_descriptions()}

        Based on the user's query, select the most appropriate tool. Respond with the name of the tool (e.g., "CryptoPriceChecker"). If no tool is suitable, respond with "None".

        User Query: "{message}"
        """

        try:
            # Call the LLM to generate the response
            response = self.llm_instance.generate(prompt=prompt)
            tool_name = response.strip().replace('"', '').replace("'", "")

            # Find the tool in the list of available tools
            tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
            if tool:
                logger.info(f"Detected tool call: {tool.name}")
                return {
                    "tool": tool.name,
                    "input": {"query": message}
                }
        except Exception as e:
            logger.error(f"Failed to detect tool call: {e}")

        return None

    def _analyze_query_and_select_tools(self, query: str) -> List[Dict[str, Any]]:
        """
        Use the LLM to analyze the query and dynamically select tools.
        Returns a list of tool calls, each with the tool name and input.
        """
        # Create a prompt for the LLM to analyze the query and select tools
        prompt = f"""
        You are an AI agent that helps analyze user queries and select the most appropriate tools.
        Below is a list of available tools and their functionalities:

        {self._get_tool_descriptions()}

        For the following query, analyze the intent and select the most appropriate tools.
        Respond with a JSON array of tool names and their inputs.
        If no tool is suitable, respond with an empty array.

        Query: "{query}"

        Respond in the following JSON format:
        [
            {{
                "tool": "tool_name",
                "input": {{
                    "query": "user_query",
                    "context": "optional_context"
                }}
            }}
        ]
        """

        try:
            # Call the LLM to generate the response
            response = self.llm_instance.generate(prompt=prompt)
            # Parse the response as JSON
            tool_calls = json.loads(response)
            return tool_calls
        except Exception as e:
            logger.error(f"Failed to analyze query and select tools: {e}")
            return []

    def _generate_response(self, message: str, markdown: bool = False, **kwargs) -> str:
        """Generate the agent's response, including tool execution and context retrieval."""
        # Use the LLM to analyze the query and dynamically select tools
        tool_calls = self._analyze_query_and_select_tools(message)

        responses = []
        tool_outputs = {}  # Store outputs of all tools for collaboration

        # Execute tools if any are detected
        if tool_calls:
            for tool_call in tool_calls:
                tool_name = tool_call["tool"]
                tool_input = tool_call["input"]

                # Find the tool
                tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
                if tool:
                    try:
                        # Execute the tool
                        tool_output = tool.execute(tool_input)
                        response = f"Tool '{tool_name}' executed. Output: {tool_output}"
                        if self.show_tool_calls:
                            response = f"**Tool Called:** {tool_name}\n\n{response}"
                        responses.append(response)

                        # Store the tool output for collaboration
                        tool_outputs[tool_name] = tool_output
                    except Exception as e:
                        logger.error(f"Tool '{tool_name}' execution failed: {e}")
                        responses.append(f"An error occurred while executing the tool '{tool_name}': {e}")
                else:
                    responses.append(f"Tool '{tool_name}' not found.")

            # If tools were executed, combine their outputs for analysis
            if tool_outputs:
                try:
                    # Prepare the context for the LLM
                    context = {
                        "tool_outputs": tool_outputs,
                        "rag_context": self.rag.retrieve(message) if self.rag else None,
                        "knowledge_base_context": self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base)) if self.knowledge_base else None,
                    }

                    # Generate a response using the LLM
                    llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
                    responses.append(f"**Analysis:**\n\n{llm_response}")
                except Exception as e:
                    logger.error(f"Failed to generate LLM response: {e}")
                    responses.append(f"An error occurred while generating the analysis: {e}")

        # If no tools were executed, proceed with the original logic
        if not tool_calls:
            # Retrieve relevant context using RAG
            rag_context = self.rag.retrieve(message) if self.rag else None
            # Retrieve relevant context from the knowledge base (API result)
            knowledge_base_context = None
            if self.knowledge_base:
                # Flatten the knowledge base
                flattened_data = self._flatten_data(self.knowledge_base)
                # Find all relevant key-value pairs in the knowledge base
                relevant_values = self._find_all_relevant_keys(message, flattened_data)
                if relevant_values:
                    knowledge_base_context = ", ".join(relevant_values)

            # Combine both contexts (RAG and knowledge base)
            context = {
                "rag_context": rag_context,
                "knowledge_base_context": knowledge_base_context,
            }
            # Prepare the prompt with instructions, description, and context
            prompt = self._build_prompt(message, context)

            # Generate the response using the LLM
            response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)

            # Format the response based on the json_output flag
            if self.json_output:
                response = self._format_response_as_json(response)

            # Validate the response against the expected_output
            if self.expected_output:
                response = self._validate_response(response)

            if markdown:
                return f"**Response:**\n\n{response}"
            return response

        # Combine all responses into a single output
        return "\n\n".join(responses)

    def _build_prompt(self, message: str, context: Optional[Dict]) -> str:
        """Build the prompt using instructions, description, and context."""
        prompt_parts = []

        # Add description if available
        if self.description:
            prompt_parts.append(f"Description: {self.description}")

        # Add instructions if available
        if self.instructions:
            instructions = "\n".join(self.instructions)
            prompt_parts.append(f"Instructions: {instructions}")

        # Add context if available
        if context:
            prompt_parts.append(f"Context: {context}")

        # Add the user's message
        prompt_parts.append(f"User Input: {message}")

        return "\n\n".join(prompt_parts)

    def _format_response_as_json(self, response: str) -> Union[Dict, str]:
        """Format the response as JSON if json_output is True."""
        try:
            # Use regex to extract JSON from the response (e.g., within ```json ``` blocks)
            json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
            if json_match:
                # Extract the JSON part and parse it
                json_str = json_match.group(1)
                return json.loads(json_str)  # Return the parsed JSON object (a dictionary)
            else:
                # If no JSON block is found, try to parse the entire response as JSON
                return json.loads(response)  # Return the parsed JSON object (a dictionary)
        except json.JSONDecodeError:
            # If the response is not valid JSON, wrap it in a dictionary
            return {"response": response}  # Return a dictionary with the response as a string

    def normalize_key(self, key: str) -> str:
        """
        Normalize a key by converting it to lowercase and replacing spaces with underscores.
        """
        return key.lower().replace(" ", "_")

    def match_key(self, expected_key, response_keys, threshold=0.5):
        """
        Match an expected key to the closest key in the response using semantic similarity or fuzzy matching.
        """
        expected_key_norm = self.normalize_key(expected_key)
        response_keys_norm = [self.normalize_key(k) for k in response_keys]

        if hasattr(self, 'semantic_model') and self.semantic_model is not None:
            try:
                # Compute embeddings for the expected key and all response keys
                expected_embedding = self.semantic_model.encode(expected_key_norm, convert_to_tensor=True)
                response_embeddings = self.semantic_model.encode(response_keys_norm, convert_to_tensor=True)

                # Compute cosine similarity
                similarity_scores = util.pytorch_cos_sim(expected_embedding, response_embeddings)[0]

                # Find the best match
                best_score = similarity_scores.max().item()
                best_index = similarity_scores.argmax().item()

                if best_score > threshold:
                    return response_keys[best_index], best_score
            except Exception as e:
                logging.warning(f"Semantic matching failed: {e}. Falling back to fuzzy matching.")

        # Fall back to fuzzy matching
        best_match = None
        best_score = -1
        for key, key_norm in zip(response_keys, response_keys_norm):
            score = fuzz.ratio(expected_key_norm, key_norm) / 100
            if score > best_score:
                best_score = score
                best_match = key

        return best_match, best_score

    def _validate_response(self, response: Union[str, Dict]) -> Union[str, Dict]:
        """Validate the response against the expected_output format using semantic similarity or fallback methods."""
        if isinstance(self.expected_output, dict):
            if not isinstance(response, dict):
                return {"response": response}

            validated_response = {}
            normalized_expected_keys = {self.normalize_key(k): k for k in self.expected_output.keys()}

            for expected_key_norm, expected_key_orig in normalized_expected_keys.items():
                # Find all response keys that match the expected key (case-insensitive and normalized)
                matching_response_keys = [
                    k for k in response.keys()
                    if self.normalize_key(k) == expected_key_norm
                ]

                # If no exact match, use semantic matching to find similar keys
                if not matching_response_keys:
                    for response_key in response.keys():
                        best_match, best_score = self.match_key(expected_key_orig, [response_key])
                        if best_match and best_score > 0.5:  # Use a threshold to determine a valid match
                            matching_response_keys.append(response_key)

                # Merge values from all matching keys
                merged_values = []
                for matching_key in matching_response_keys:
                    value = response[matching_key]
                    if isinstance(value, list):
                        merged_values.extend(value)
                    else:
                        merged_values.append(value)

                # Assign the merged values to the expected key
                if merged_values:
                    validated_response[expected_key_orig] = merged_values
                else:
                    validated_response[expected_key_orig] = "NA"  # Default value for missing keys

                # Recursively validate nested dictionaries
                expected_value = self.expected_output[expected_key_orig]
                if isinstance(expected_value, dict) and isinstance(validated_response[expected_key_orig], dict):
                    validated_response[expected_key_orig] = self._validate_response(validated_response[expected_key_orig])

            return validated_response
        elif isinstance(self.expected_output, str):
            if not isinstance(response, str):
                return str(response)
        return response

    def cli_app(
        self,
        message: Optional[str] = None,
        exit_on: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run the agent in a CLI app."""
        from rich.prompt import Prompt

        if message:
            self.print_response(message=message, **kwargs)

        _exit_on = exit_on or ["exit", "quit", "bye"]
        while True:
            message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
            if message in _exit_on:
                break

            self.print_response(message=message, **kwargs)

    def _generate_api(self):
        """Generate an API for the agent if api=True."""
        from .api.api_generator import APIGenerator
        self.api_generator = APIGenerator(self)
        print(f"API generated for agent '{self.name}'. Use `.run_api()` to start the API server.")

    def run_api(self):
        """Run the API server for the agent."""
        if self.api_generator is None:
            raise ValueError("API is not enabled for this agent. Set `api=True` when initializing the agent.")

        # Get API configuration
        host = self.api_config.get("host", "0.0.0.0") if self.api_config else "0.0.0.0"
        port = self.api_config.get("port", 8000) if self.api_config else 8000

        # Run the API server
        self.api_generator.run(host=host, port=port)

    def _flatten_data(self, data: Union[Dict, List], parent_key: str = "", separator: str = "_") -> List[Dict]:
        """
        Recursively flatten a nested dictionary or list into a list of key-value pairs.

        Args:
            data (Union[Dict, List]): The nested data structure.
            parent_key (str): The parent key (used for recursion).
            separator (str): The separator used for nested keys.

        Returns:
            List[Dict]: A list of flattened key-value pairs.
        """
        items = []
        if isinstance(data, dict):
            for key, value in data.items():
                new_key = f"{parent_key}{separator}{key}" if parent_key else key
                if isinstance(value, (dict, list)):
                    items.extend(self._flatten_data(value, new_key, separator))
                else:
                    items.append({new_key: value})
                    # Include the value as a key for searching
                    if isinstance(value, str):
                        items.append({value: new_key})
        elif isinstance(data, list):
            for index, item in enumerate(data):
                new_key = f"{parent_key}{separator}{index}" if parent_key else str(index)
                if isinstance(item, (dict, list)):
                    items.extend(self._flatten_data(item, new_key, separator))
                else:
                    items.append({new_key: item})
                    # Include the value as a key for searching
                    if isinstance(item, str):
                        items.append({item: new_key})
        return items

    def _find_all_relevant_keys(self, query: str, flattened_data: List[Dict], threshold: float = 0.5) -> List[str]:
        """
        Find all relevant keys in the flattened data based on semantic similarity to the query.

        Args:
            query (str): The user's query.
            flattened_data (List[Dict]): The flattened key-value pairs.
            threshold (float): The similarity threshold for considering a match.

        Returns:
            List[str]: A list of relevant values.
        """
        if not flattened_data:
            return []

        # Extract keys from the flattened data
        keys = [list(item.keys())[0] for item in flattened_data]

        # Compute embeddings for the query and keys
        query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
        key_embeddings = self.semantic_model.encode(keys, convert_to_tensor=True)

        # Compute cosine similarity between the query and keys
        similarities = util.pytorch_cos_sim(query_embedding, key_embeddings)[0]

        # Find all keys with a similarity score above the threshold
        relevant_indices = [i for i, score in enumerate(similarities) if score > threshold]
        relevant_values = [flattened_data[i][keys[i]] for i in relevant_indices]

        return relevant_values
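
For orientation, a minimal usage sketch pieced together from the Agent class above (not taken from the package's own documentation). The provider, model name, and query are illustrative placeholders; the GROQ_API_KEY variable follows the "{PROVIDER}_API_KEY" convention used in _initialize_model.

# Sketch (assumptions noted above): construct an Agent and ask a question.
import os
from semantio.agent import Agent

os.environ.setdefault("GROQ_API_KEY", "<your-key>")  # read by _initialize_model

agent = Agent(
    name="ResearchAssistant",
    description="Answers questions using the bundled tools.",
    instructions=["Be concise."],
    llm="groq",                      # one of: groq, openai, anthropic
    llm_model="mixtral-8x7b-32768",  # optional; falls back to the provider default
    show_tool_calls=True,
)

# Prints the response to the console and returns it (a dict when json_output=True).
answer = agent.print_response("What is the current price of bitcoin?", markdown=True)

# Or run an interactive loop that exits on "exit", "quit", or "bye".
# agent.cli_app()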
semantio/api/__init__.py
ADDED
File without changes

semantio/api/api_generator.py
ADDED
@@ -0,0 +1,23 @@
from .fastapi_app import create_fastapi_app  # Import the factory function

class APIGenerator:
    def __init__(self, assistant):
        """
        Initialize the APIGenerator with the given assistant.

        Args:
            assistant: The assistant instance for which the API is being created.
        """
        self.assistant = assistant
        self.app = create_fastapi_app(assistant, assistant.api_config)  # Pass api_config to create_fastapi_app

    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """
        Run the FastAPI app.

        Args:
            host (str): The host address to run the API server on. Default is "0.0.0.0".
            port (int): The port to run the API server on. Default is 8000.
        """
        import uvicorn
        uvicorn.run(self.app, host=host, port=port)