semantio-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
semantio/__init__.py ADDED
File without changes
semantio/agent.py ADDED
@@ -0,0 +1,608 @@
+ from typing import Optional, List, Dict, Union, Iterator, Any
+ from pydantic import BaseModel, Field, ConfigDict
+ from PIL import Image
+ import requests
+ import logging
+ import re
+ import io
+ import json
+ from .rag import RAG
+ from .llm.base_llm import BaseLLM
+ from .knowledge_base.retriever import Retriever
+ from .knowledge_base.vector_store import VectorStore
+ from sentence_transformers import SentenceTransformer, util
+ from fuzzywuzzy import fuzz
+ from .tools.base_tool import BaseTool
+ from pathlib import Path
+ import importlib
+ import os
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ class Agent(BaseModel):
+     # -*- Agent settings
+     name: Optional[str] = Field(None, description="Name of the agent.")
+     description: Optional[str] = Field(None, description="Description of the agent's role.")
+     instructions: Optional[List[str]] = Field(None, description="List of instructions for the agent.")
+     model: Optional[str] = Field(None, description="Not currently used.")
+     show_tool_calls: bool = Field(False, description="Whether to show tool calls in the response.")
+     markdown: bool = Field(False, description="Whether to format the response in markdown.")
+     tools: Optional[List[BaseTool]] = Field(None, description="List of tools available to the agent.")
+     user_name: Optional[str] = Field("User", description="Name of the user interacting with the agent.")
+     emoji: Optional[str] = Field(":robot:", description="Emoji to represent the agent in the CLI.")
+     rag: Optional[RAG] = Field(None, description="RAG instance for context retrieval.")
+     knowledge_base: Optional[Any] = Field(None, description="Knowledge base for domain-specific information.")
+     llm: Optional[str] = Field(None, description="The LLM provider to use (e.g., 'groq', 'openai', 'anthropic').")
+     llm_model: Optional[str] = Field(None, description="The specific model to use for the LLM provider.")
+     llm_instance: Optional[BaseLLM] = Field(None, description="The LLM instance to use.")
+     json_output: bool = Field(False, description="Whether to format the response as JSON.")
+     api: bool = Field(False, description="Whether to generate an API for the agent.")
+     api_config: Optional[Dict] = Field(
+         None,
+         description="Configuration for the API (e.g., host, port, authentication).",
+     )
+     api_generator: Optional[Any] = Field(None, description="The API generator instance.")
+     expected_output: Optional[Union[str, Dict]] = Field(None, description="The expected format or structure of the output.")
+     semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
+     # Allow arbitrary types
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         # Initialize the model and tools here if needed
+         self._initialize_model()
+         # Automatically discover and register tools
+         self.tools = self._discover_tools()
+         # Pass the LLM instance to each tool
+         for tool in self.tools:
+             tool.llm = self.llm_instance
+         # Initialize the SentenceTransformer model for semantic matching
+         self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+         # Initialize RAG if not provided
+         if self.rag is None:
+             self.rag = self._initialize_default_rag()
+         # Automatically generate API if api=True
+         if self.api:
+             self._generate_api()
+
+     def _discover_tools(self) -> List[BaseTool]:
+         """
+         Automatically discover and register tools from the 'tools' directory.
+         """
+         tools = []
+         tools_dir = Path(__file__).parent / "tools"
+
+         if not tools_dir.exists():
+             logger.warning(f"Tools directory not found: {tools_dir}")
+             return tools
+
+         # Iterate over all Python files in the 'tools' directory
+         for file in tools_dir.glob("*.py"):
+             if file.name == "base_tool.py":
+                 continue  # Skip the base tool file
+
+             try:
+                 # Import the module
+                 module_name = file.stem
+                 module = importlib.import_module(f"semantio.tools.{module_name}")
+
+                 # Find all classes that inherit from BaseTool
+                 for name, obj in module.__dict__.items():
+                     if isinstance(obj, type) and issubclass(obj, BaseTool) and obj != BaseTool:
+                         # Instantiate the tool and add it to the list
+                         tools.append(obj())
+                         logger.info(f"Registered tool: {obj.__name__}")
+             except Exception as e:
+                 logger.error(f"Failed to load tool from {file}: {e}")
+
+         return tools
+
+     def _get_tool_descriptions(self) -> str:
+         """Generate a description of all available tools for the LLM prompt."""
+         return "\n".join(
+             f"{tool.name}: {tool.description}" for tool in self.tools
+         )
+
+     def _initialize_model(self):
+         """Initialize the model based on the provided configuration."""
+         if self.llm_instance is not None:
+             return  # LLM is already initialized, do nothing
+         if self.llm is None:
+             raise ValueError("llm must be specified.")
+
+         # Get the API key from the environment or the provided configuration
+         api_key = getattr(self, 'api_key', None) or os.getenv(f"{self.llm.upper()}_API_KEY")
+
+         # Map LLM providers to their respective classes and default models
+         llm_providers = {
+             "groq": {
+                 "class": "GroqLlm",
+                 "default_model": "mixtral-8x7b-32768",
+             },
+             "openai": {
+                 "class": "OpenAILlm",
+                 "default_model": "gpt-4",
+             },
+             "anthropic": {
+                 "class": "AnthropicLlm",
+                 "default_model": "claude-2.1",
+             },
+         }
+
+         # Normalize the LLM provider name (case-insensitive)
+         llm_provider = self.llm.lower()
+
+         if llm_provider not in llm_providers:
+             raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers are: {list(llm_providers.keys())}")
+
+         # Get the LLM class and default model
+         llm_config = llm_providers[llm_provider]
+         llm_class_name = llm_config["class"]
+         default_model = llm_config["default_model"]
+
+         # Use the user-provided model or fall back to the default model
+         model_to_use = self.llm_model or default_model
+
+         # Dynamically import and initialize the LLM class
+         module_name = f"semantio.llm.{llm_provider}"
+         llm_module = importlib.import_module(module_name)
+         llm_class = getattr(llm_module, llm_class_name)
+         self.llm_instance = llm_class(model=model_to_use, api_key=api_key)
+
+     def _initialize_default_rag(self) -> RAG:
+         """Initialize a default RAG instance with a dummy vector store."""
+         vector_store = VectorStore()
+         retriever = Retriever(vector_store)
+         return RAG(retriever)
+
+     def load_image_from_url(self, image_url: str) -> Image.Image:
+         """Load an image from a URL and return it as a PIL Image."""
+         response = requests.get(image_url)
+         image_bytes = response.content
+         return Image.open(io.BytesIO(image_bytes))
+
+     def print_response(
+         self,
+         message: Optional[Union[str, Image.Image, List, Dict]] = None,
+         stream: bool = False,
+         markdown: bool = False,
+         **kwargs,
+     ) -> Union[str, Dict]:
+         """Print the agent's response to the console and return it."""
+         if isinstance(message, Image.Image):
+             # Handle image input
+             message = self._process_image(message)
+
+         if stream:
+             # Handle streaming response
+             response = ""
+             for chunk in self._stream_response(message, markdown=markdown, **kwargs):
+                 print(chunk)
+                 response += chunk
+             return response
+         else:
+             # Generate and return the response
+             response = self._generate_response(message, markdown=markdown, **kwargs)
+             print(response)  # Print the response to the console
+             return response
+
+     def _process_image(self, image: Image.Image) -> str:
+         """Process the image and return a string representation."""
+         # Convert the image to text or extract relevant information
+         # For now, we'll just return a placeholder string
+         return "Image processed. Extracted text: [Placeholder]"
+
+     def _stream_response(self, message: str, markdown: bool = False, **kwargs) -> Iterator[str]:
+         """Stream the agent's response."""
+         # Simulate streaming by yielding chunks of the response
+         response = self._generate_response(message, markdown=markdown, **kwargs)
+         for chunk in response.split():
+             yield chunk + " "
+
+     def register_tool(self, tool: BaseTool):
+         """Register a tool for the agent."""
+         if self.tools is None:
+             self.tools = []
+         self.tools.append(tool)
+
+     def _detect_tool_call(self, message: str) -> Optional[Dict[str, Any]]:
+         """
+         Use the LLM to detect which tool should be called based on the user's query.
+         """
+         if not self.tools:
+             logger.warning("No tools available to detect.")
+             return None
+
+         # Create a prompt for the LLM
+         prompt = f"""
+         You are an AI agent that helps users by selecting the most appropriate tool to answer their query. Below is a list of available tools and their functionalities:
+
+         {self._get_tool_descriptions()}
+
+         Based on the user's query, select the most appropriate tool. Respond with the name of the tool (e.g., "CryptoPriceChecker"). If no tool is suitable, respond with "None".
+
+         User Query: "{message}"
+         """
+
+         try:
+             # Call the LLM to generate the response
+             response = self.llm_instance.generate(prompt=prompt)
+             tool_name = response.strip().replace('"', '').replace("'", "")
+
+             # Find the tool in the list of available tools
+             tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
+             if tool:
+                 logger.info(f"Detected tool call: {tool.name}")
+                 return {
+                     "tool": tool.name,
+                     "input": {"query": message}
+                 }
+         except Exception as e:
+             logger.error(f"Failed to detect tool call: {e}")
+
+         return None
+
+     def _analyze_query_and_select_tools(self, query: str) -> List[Dict[str, Any]]:
+         """
+         Use the LLM to analyze the query and dynamically select tools.
+         Returns a list of tool calls, each with the tool name and input.
+         """
+         # Create a prompt for the LLM to analyze the query and select tools
+         prompt = f"""
+         You are an AI agent that helps analyze user queries and select the most appropriate tools.
+         Below is a list of available tools and their functionalities:
+
+         {self._get_tool_descriptions()}
+
+         For the following query, analyze the intent and select the most appropriate tools.
+         Respond with a JSON array of tool names and their inputs.
+         If no tool is suitable, respond with an empty array.
+
+         Query: "{query}"
+
+         Respond in the following JSON format:
+         [
+             {{
+                 "tool": "tool_name",
+                 "input": {{
+                     "query": "user_query",
+                     "context": "optional_context"
+                 }}
+             }}
+         ]
+         """
+
+         try:
+             # Call the LLM to generate the response
+             response = self.llm_instance.generate(prompt=prompt)
+             # Parse the response as JSON
+             tool_calls = json.loads(response)
+             return tool_calls
+         except Exception as e:
+             logger.error(f"Failed to analyze query and select tools: {e}")
+             return []
+
+
+     def _generate_response(self, message: str, markdown: bool = False, **kwargs) -> str:
+         """Generate the agent's response, including tool execution and context retrieval."""
+         # Use the LLM to analyze the query and dynamically select tools
+         tool_calls = self._analyze_query_and_select_tools(message)
+
+         responses = []
+         tool_outputs = {}  # Store outputs of all tools for collaboration
+
+         # Execute tools if any are detected
+         if tool_calls:
+             for tool_call in tool_calls:
+                 tool_name = tool_call["tool"]
+                 tool_input = tool_call["input"]
+
+                 # Find the tool
+                 tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
+                 if tool:
+                     try:
+                         # Execute the tool
+                         tool_output = tool.execute(tool_input)
+                         response = f"Tool '{tool_name}' executed. Output: {tool_output}"
+                         if self.show_tool_calls:
+                             response = f"**Tool Called:** {tool_name}\n\n{response}"
+                         responses.append(response)
+
+                         # Store the tool output for collaboration
+                         tool_outputs[tool_name] = tool_output
+                     except Exception as e:
+                         logger.error(f"Tool '{tool_name}' execution failed: {e}")
+                         responses.append(f"An error occurred while executing the tool '{tool_name}': {e}")
+                 else:
+                     responses.append(f"Tool '{tool_name}' not found.")
+
+         # If multiple tools were executed, combine their outputs for analysis
+         if tool_outputs:
+             try:
+                 # Prepare the context for the LLM
+                 context = {
+                     "tool_outputs": tool_outputs,
+                     "rag_context": self.rag.retrieve(message) if self.rag else None,
+                     "knowledge_base_context": self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base)) if self.knowledge_base else None,
+                 }
+
+                 # Generate a response using the LLM
+                 llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
+                 responses.append(f"**Analysis:**\n\n{llm_response}")
+             except Exception as e:
+                 logger.error(f"Failed to generate LLM response: {e}")
+                 responses.append(f"An error occurred while generating the analysis: {e}")
+
+         # If no tools were executed, proceed with the original logic
+         if not tool_calls:
+             # Retrieve relevant context using RAG
+             rag_context = self.rag.retrieve(message) if self.rag else None
+             # Retrieve relevant context from the knowledge base (API result)
+             knowledge_base_context = None
+             if self.knowledge_base:
+                 # Flatten the knowledge base
+                 flattened_data = self._flatten_data(self.knowledge_base)
+                 # Find all relevant key-value pairs in the knowledge base
+                 relevant_values = self._find_all_relevant_keys(message, flattened_data)
+                 if relevant_values:
+                     knowledge_base_context = ", ".join(relevant_values)
+
+             # Combine both contexts (RAG and knowledge base)
+             context = {
+                 "rag_context": rag_context,
+                 "knowledge_base_context": knowledge_base_context,
+             }
+             # Prepare the prompt with instructions, description, and context
+             prompt = self._build_prompt(message, context)
+
+             # Generate the response using the LLM
+             response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)
+
+             # Format the response based on the json_output flag
+             if self.json_output:
+                 response = self._format_response_as_json(response)
+
+             # Validate the response against the expected_output
+             if self.expected_output:
+                 response = self._validate_response(response)
+
+             if markdown:
+                 return f"**Response:**\n\n{response}"
+             return response
+
+         # Combine all responses into a single output
+         return "\n\n".join(responses)
+
+     def _build_prompt(self, message: str, context: Optional[Dict]) -> str:
+         """Build the prompt using instructions, description, and context."""
+         prompt_parts = []
+
+         # Add description if available
+         if self.description:
+             prompt_parts.append(f"Description: {self.description}")
+
+         # Add instructions if available
+         if self.instructions:
+             instructions = "\n".join(self.instructions)
+             prompt_parts.append(f"Instructions: {instructions}")
+
+         # Add context if available
+         if context:
+             prompt_parts.append(f"Context: {context}")
+
+         # Add the user's message
+         prompt_parts.append(f"User Input: {message}")
+
+         return "\n\n".join(prompt_parts)
+
+     def _format_response_as_json(self, response: str) -> Union[Dict, str]:
+         """Format the response as JSON if json_output is True."""
+         try:
+             # Use regex to extract JSON from the response (e.g., within ```json ``` blocks)
+             json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
+             if json_match:
+                 # Extract the JSON part and parse it
+                 json_str = json_match.group(1)
+                 return json.loads(json_str)  # Return the parsed JSON object (a dictionary)
+             else:
+                 # If no JSON block is found, try to parse the entire response as JSON
+                 return json.loads(response)  # Return the parsed JSON object (a dictionary)
+         except json.JSONDecodeError:
+             # If the response is not valid JSON, wrap it in a dictionary
+             return {"response": response}  # Return a dictionary with the response as a string
+
+     def normalize_key(self, key: str) -> str:
+         """
+         Normalize a key by converting it to lowercase and replacing spaces with underscores.
+         """
+         return key.lower().replace(" ", "_")
+
+     def match_key(self, expected_key, response_keys, threshold=0.5):
+         """
+         Match an expected key to the closest key in the response using semantic similarity or fuzzy matching.
+         """
+         expected_key_norm = self.normalize_key(expected_key)
+         response_keys_norm = [self.normalize_key(k) for k in response_keys]
+
+         if hasattr(self, 'semantic_model') and self.semantic_model is not None:
+             try:
+                 # Compute embeddings for the expected key and all response keys
+                 expected_embedding = self.semantic_model.encode(expected_key_norm, convert_to_tensor=True)
+                 response_embeddings = self.semantic_model.encode(response_keys_norm, convert_to_tensor=True)
+
+                 # Compute cosine similarity
+                 similarity_scores = util.pytorch_cos_sim(expected_embedding, response_embeddings)[0]
+
+                 # Find the best match
+                 best_score = similarity_scores.max().item()
+                 best_index = similarity_scores.argmax().item()
+
+                 if best_score > threshold:
+                     return response_keys[best_index], best_score
+             except Exception as e:
+                 logging.warning(f"Semantic matching failed: {e}. Falling back to fuzzy matching.")
+
+         # Fallback to fuzzy matching
+         best_match = None
+         best_score = -1
+         for key, key_norm in zip(response_keys, response_keys_norm):
+             score = fuzz.ratio(expected_key_norm, key_norm) / 100
+             if score > best_score:
+                 best_score = score
+                 best_match = key
+
+         return best_match, best_score
+
+     def _validate_response(self, response: Union[str, Dict]) -> Union[str, Dict]:
+         """Validate the response against the expected_output format using semantic similarity or fallback methods."""
+         if isinstance(self.expected_output, dict):
+             if not isinstance(response, dict):
+                 return {"response": response}
+
+             validated_response = {}
+             normalized_expected_keys = {self.normalize_key(k): k for k in self.expected_output.keys()}
+
+             for expected_key_norm, expected_key_orig in normalized_expected_keys.items():
+                 # Find all response keys that match the expected key (case-insensitive and normalized)
+                 matching_response_keys = [
+                     k for k in response.keys()
+                     if self.normalize_key(k) == expected_key_norm
+                 ]
+
+                 # If no exact match, use semantic matching to find similar keys
+                 if not matching_response_keys:
+                     for response_key in response.keys():
+                         best_match, best_score = self.match_key(expected_key_orig, [response_key])
+                         if best_match and best_score > 0.5:  # Use a threshold to determine a valid match
+                             matching_response_keys.append(response_key)
+
+                 # Merge values from all matching keys
+                 merged_values = []
+                 for matching_key in matching_response_keys:
+                     value = response[matching_key]
+                     if isinstance(value, list):
+                         merged_values.extend(value)
+                     else:
+                         merged_values.append(value)
+
+                 # Assign the merged values to the expected key
+                 if merged_values:
+                     validated_response[expected_key_orig] = merged_values
+                 else:
+                     validated_response[expected_key_orig] = "NA"  # Default value for missing keys
+
+                 # Recursively validate nested dictionaries
+                 expected_value = self.expected_output[expected_key_orig]
+                 if isinstance(expected_value, dict) and isinstance(validated_response[expected_key_orig], dict):
+                     validated_response[expected_key_orig] = self._validate_response(validated_response[expected_key_orig])
+
+             return validated_response
+         elif isinstance(self.expected_output, str):
+             if not isinstance(response, str):
+                 return str(response)
+             return response
+
+     def cli_app(
+         self,
+         message: Optional[str] = None,
+         exit_on: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         """Run the agent in a CLI app."""
+         from rich.prompt import Prompt
+
+         if message:
+             self.print_response(message=message, **kwargs)
+
+         _exit_on = exit_on or ["exit", "quit", "bye"]
+         while True:
+             message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
+             if message in _exit_on:
+                 break
+
+             self.print_response(message=message, **kwargs)
+
+     def _generate_api(self):
+         """Generate an API for the agent if api=True."""
+         from .api.api_generator import APIGenerator
+         self.api_generator = APIGenerator(self)
+         print(f"API generated for agent '{self.name}'. Use `.run_api()` to start the API server.")
+
+     def run_api(self):
+         """Run the API server for the agent."""
+         if self.api_generator is None:
+             raise ValueError("API is not enabled for this agent. Set `api=True` when initializing the agent.")
+
+         # Get API configuration
+         host = self.api_config.get("host", "0.0.0.0") if self.api_config else "0.0.0.0"
+         port = self.api_config.get("port", 8000) if self.api_config else 8000
+
+         # Run the API server
+         self.api_generator.run(host=host, port=port)
+
+     def _flatten_data(self, data: Union[Dict, List], parent_key: str = "", separator: str = "_") -> List[Dict]:
+         """
+         Recursively flatten a nested dictionary or list into a list of key-value pairs.
+
+         Args:
+             data (Union[Dict, List]): The nested data structure.
+             parent_key (str): The parent key (used for recursion).
+             separator (str): The separator used for nested keys.
+
+         Returns:
+             List[Dict]: A list of flattened key-value pairs.
+         """
+         items = []
+         if isinstance(data, dict):
+             for key, value in data.items():
+                 new_key = f"{parent_key}{separator}{key}" if parent_key else key
+                 if isinstance(value, (dict, list)):
+                     items.extend(self._flatten_data(value, new_key, separator))
+                 else:
+                     items.append({new_key: value})
+                     # Include the value as a key for searching
+                     if isinstance(value, str):
+                         items.append({value: new_key})
+         elif isinstance(data, list):
+             for index, item in enumerate(data):
+                 new_key = f"{parent_key}{separator}{index}" if parent_key else str(index)
+                 if isinstance(item, (dict, list)):
+                     items.extend(self._flatten_data(item, new_key, separator))
+                 else:
+                     items.append({new_key: item})
+                     # Include the value as a key for searching
+                     if isinstance(item, str):
+                         items.append({item: new_key})
+         return items
+
+     def _find_all_relevant_keys(self, query: str, flattened_data: List[Dict], threshold: float = 0.5) -> List[str]:
+         """
+         Find all relevant keys in the flattened data based on semantic similarity to the query.
+
+         Args:
+             query (str): The user's query.
+             flattened_data (List[Dict]): The flattened key-value pairs.
+             threshold (float): The similarity threshold for considering a match.
+
+         Returns:
+             List[str]: A list of relevant values.
+         """
+         if not flattened_data:
+             return []
+
+         # Extract keys from the flattened data
+         keys = [list(item.keys())[0] for item in flattened_data]
+
+         # Compute embeddings for the query and keys
+         query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
+         key_embeddings = self.semantic_model.encode(keys, convert_to_tensor=True)
+
+         # Compute cosine similarity between the query and keys
+         similarities = util.pytorch_cos_sim(query_embedding, key_embeddings)[0]
+
+         # Find all keys with a similarity score above the threshold
+         relevant_indices = [i for i, score in enumerate(similarities) if score > threshold]
+         relevant_values = [flattened_data[i][keys[i]] for i in relevant_indices]
+
+         return relevant_values
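For orientation, a minimal usage sketch of the Agent class defined in agent.py above. The agent name, instructions, and query are illustrative only, and a GROQ_API_KEY environment variable (or the key matching whichever llm provider is chosen) is assumed to be set; any tool modules shipped under semantio/tools are auto-discovered at construction time.

from semantio.agent import Agent

# Hypothetical configuration; llm_model is optional and falls back to the provider default.
agent = Agent(
    name="crypto-helper",
    description="Answers questions about cryptocurrency prices.",
    instructions=["Be concise."],
    llm="groq",                      # or "openai" / "anthropic"
    show_tool_calls=True,
    markdown=True,
)

# Prints the reply to the console and also returns it.
reply = agent.print_response("What is the current price of bitcoin?")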
File without changes
semantio/api/api_generator.py ADDED
@@ -0,0 +1,23 @@
+ from .fastapi_app import create_fastapi_app  # Import the factory function
+
+ class APIGenerator:
+     def __init__(self, assistant):
+         """
+         Initialize the APIGenerator with the given assistant.
+
+         Args:
+             assistant: The assistant instance for which the API is being created.
+         """
+         self.assistant = assistant
+         self.app = create_fastapi_app(assistant, assistant.api_config)  # Pass api_config to create_fastapi_app
+
+     def run(self, host: str = "0.0.0.0", port: int = 8000):
+         """
+         Run the FastAPI app.
+
+         Args:
+             host (str): The host address to run the API server on. Default is "0.0.0.0".
+             port (int): The port to run the API server on. Default is 8000.
+         """
+         import uvicorn
+         uvicorn.run(self.app, host=host, port=port)
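A similarly hedged sketch of the API path: with api=True the Agent builds an APIGenerator in _generate_api(), and run_api() passes the host/port from api_config to the uvicorn server shown above. The host and port values here are only examples.

from semantio.agent import Agent

api_agent = Agent(
    name="crypto-helper-api",
    llm="groq",
    api=True,                                        # triggers _generate_api() in __init__
    api_config={"host": "127.0.0.1", "port": 8080},  # example values; defaults are 0.0.0.0:8000
)

# Blocks and serves the FastAPI app created by create_fastapi_app().
api_agent.run_api()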