semantio 0.0.1__py3-none-any.whl

semantio/__init__.py ADDED
File without changes
semantio/agent.py ADDED
@@ -0,0 +1,608 @@
+ from typing import Optional, List, Dict, Union, Iterator, Any
+ from pydantic import BaseModel, Field, ConfigDict
+ from PIL import Image as PILImage  # module that provides the open() helper
+ from PIL.Image import Image  # Image class, used for type hints and isinstance checks
+ import requests
+ import logging
+ import re
+ import io
+ import json
+ from .rag import RAG
+ from .llm.base_llm import BaseLLM
+ from .knowledge_base.retriever import Retriever
+ from .knowledge_base.vector_store import VectorStore
+ from sentence_transformers import SentenceTransformer, util
+ from fuzzywuzzy import fuzz
+ from .tools.base_tool import BaseTool
+ from pathlib import Path
+ import importlib
+ import os
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ class Agent(BaseModel):
+     # -*- Agent settings
+     name: Optional[str] = Field(None, description="Name of the agent.")
+     description: Optional[str] = Field(None, description="Description of the agent's role.")
+     instructions: Optional[List[str]] = Field(None, description="List of instructions for the agent.")
+     model: Optional[str] = Field(None, description="Not currently used.")
+     show_tool_calls: bool = Field(False, description="Whether to show tool calls in the response.")
+     markdown: bool = Field(False, description="Whether to format the response in markdown.")
+     tools: Optional[List[BaseTool]] = Field(None, description="List of tools available to the agent.")
+     user_name: Optional[str] = Field("User", description="Name of the user interacting with the agent.")
+     emoji: Optional[str] = Field(":robot:", description="Emoji to represent the agent in the CLI.")
+     rag: Optional[RAG] = Field(None, description="RAG instance for context retrieval.")
+     knowledge_base: Optional[Any] = Field(None, description="Knowledge base for domain-specific information.")
+     llm: Optional[str] = Field(None, description="The LLM provider to use (e.g., 'groq', 'openai', 'anthropic').")
+     llm_model: Optional[str] = Field(None, description="The specific model to use for the LLM provider.")
+     llm_instance: Optional[BaseLLM] = Field(None, description="The LLM instance to use.")
+     json_output: bool = Field(False, description="Whether to format the response as JSON.")
+     api: bool = Field(False, description="Whether to generate an API for the agent.")
+     api_config: Optional[Dict] = Field(
+         None,
+         description="Configuration for the API (e.g., host, port, authentication).",
+     )
+     api_generator: Optional[Any] = Field(None, description="The API generator instance.")
+     expected_output: Optional[Union[str, Dict]] = Field(None, description="The expected format or structure of the output.")
+     semantic_model: Optional[Any] = Field(None, description="SentenceTransformer model for semantic matching.")
+     # Allow arbitrary types
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         # Initialize the model and tools here if needed
+         self._initialize_model()
+         # Automatically discover and register tools
+         self.tools = self._discover_tools()
+         # Pass the LLM instance to each tool
+         for tool in self.tools:
+             tool.llm = self.llm_instance
+         # Initialize the SentenceTransformer model for semantic matching
+         self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+         # Initialize RAG if not provided
+         if self.rag is None:
+             self.rag = self._initialize_default_rag()
+         # Automatically generate API if api=True
+         if self.api:
+             self._generate_api()
+
+     def _discover_tools(self) -> List[BaseTool]:
+         """
+         Automatically discover and register tools from the 'tools' directory.
+         """
+         tools = []
+         tools_dir = Path(__file__).parent / "tools"
+
+         if not tools_dir.exists():
+             logger.warning(f"Tools directory not found: {tools_dir}")
+             return tools
+
+         # Iterate over all Python files in the 'tools' directory
+         for file in tools_dir.glob("*.py"):
+             if file.name == "base_tool.py":
+                 continue  # Skip the base tool file
+
+             try:
+                 # Import the module from this package (installed as 'semantio')
+                 module_name = file.stem
+                 module = importlib.import_module(f"semantio.tools.{module_name}")
+
+                 # Find all classes that inherit from BaseTool
+                 for name, obj in module.__dict__.items():
+                     if isinstance(obj, type) and issubclass(obj, BaseTool) and obj != BaseTool:
+                         # Instantiate the tool and add it to the list
+                         tools.append(obj())
+                         logger.info(f"Registered tool: {obj.__name__}")
+             except Exception as e:
+                 logger.error(f"Failed to load tool from {file}: {e}")
+
+         return tools
+
+     def _get_tool_descriptions(self) -> str:
+         """Generate a description of all available tools for the LLM prompt."""
+         return "\n".join(
+             f"{tool.name}: {tool.description}" for tool in self.tools
+         )
+
+     def _initialize_model(self):
+         """Initialize the model based on the provided configuration."""
+         if self.llm_instance is not None:
+             return  # LLM is already initialized, do nothing
+         if self.llm is None:
+             raise ValueError("llm must be specified.")
+
+         # Get the API key from the environment or the provided configuration
+         api_key = getattr(self, 'api_key', None) or os.getenv(f"{self.llm.upper()}_API_KEY")
+
+         # Map LLM providers to their respective classes and default models
+         llm_providers = {
+             "groq": {
+                 "class": "GroqLlm",
+                 "default_model": "mixtral-8x7b-32768",
+             },
+             "openai": {
+                 "class": "OpenAILlm",
+                 "default_model": "gpt-4",
+             },
+             "anthropic": {
+                 "class": "AnthropicLlm",
+                 "default_model": "claude-2.1",
+             },
+         }
+
+         # Normalize the LLM provider name (case-insensitive)
+         llm_provider = self.llm.lower()
+
+         if llm_provider not in llm_providers:
+             raise ValueError(f"Unsupported LLM provider: {self.llm}. Supported providers are: {list(llm_providers.keys())}")
+
+         # Get the LLM class and default model
+         llm_config = llm_providers[llm_provider]
+         llm_class_name = llm_config["class"]
+         default_model = llm_config["default_model"]
+
+         # Use the user-provided model or fall back to the default model
+         model_to_use = self.llm_model or default_model
+
+         # Dynamically import and initialize the LLM class from this package ('semantio')
+         module_name = f"semantio.llm.{llm_provider}"
+         llm_module = importlib.import_module(module_name)
+         llm_class = getattr(llm_module, llm_class_name)
+         self.llm_instance = llm_class(model=model_to_use, api_key=api_key)
+
+     def _initialize_default_rag(self) -> RAG:
+         """Initialize a default RAG instance with a dummy vector store."""
+         vector_store = VectorStore()
+         retriever = Retriever(vector_store)
+         return RAG(retriever)
+
+     def load_image_from_url(self, image_url: str) -> Image:
+         """Load an image from a URL and return it as a PIL Image."""
+         response = requests.get(image_url)
+         image_bytes = response.content
+         # open() lives on the PIL.Image module, not on the Image class
+         return PILImage.open(io.BytesIO(image_bytes))
+
+     def print_response(
+         self,
+         message: Optional[Union[str, Image, List, Dict]] = None,
+         stream: bool = False,
+         markdown: bool = False,
+         **kwargs,
+     ) -> Union[str, Dict]:
+         """Print the agent's response to the console and return it."""
+         if isinstance(message, Image):
+             # Handle image input
+             message = self._process_image(message)
+
+         if stream:
+             # Handle streaming response
+             response = ""
+             for chunk in self._stream_response(message, markdown=markdown, **kwargs):
+                 print(chunk)
+                 response += chunk
+             return response
+         else:
+             # Generate and return the response
+             response = self._generate_response(message, markdown=markdown, **kwargs)
+             print(response)  # Print the response to the console
+             return response
+
+     def _process_image(self, image: Image) -> str:
+         """Process the image and return a string representation."""
+         # Convert the image to text or extract relevant information
+         # For now, we'll just return a placeholder string
+         return "Image processed. Extracted text: [Placeholder]"
+
+     def _stream_response(self, message: str, markdown: bool = False, **kwargs) -> Iterator[str]:
+         """Stream the agent's response."""
+         # Simulate streaming by yielding chunks of the response
+         response = self._generate_response(message, markdown=markdown, **kwargs)
+         for chunk in response.split():
+             yield chunk + " "
+
+     def register_tool(self, tool: BaseTool):
+         """Register a tool for the agent."""
+         if self.tools is None:
+             self.tools = []
+         self.tools.append(tool)
+
+     def _detect_tool_call(self, message: str) -> Optional[Dict[str, Any]]:
+         """
+         Use the LLM to detect which tool should be called based on the user's query.
+         """
+         if not self.tools:
+             logger.warning("No tools available to detect.")
+             return None
+
+         # Create a prompt for the LLM
+         prompt = f"""
+         You are an AI agent that helps users by selecting the most appropriate tool to answer their query. Below is a list of available tools and their functionalities:
+
+         {self._get_tool_descriptions()}
+
+         Based on the user's query, select the most appropriate tool. Respond with the name of the tool (e.g., "CryptoPriceChecker"). If no tool is suitable, respond with "None".
+
+         User Query: "{message}"
+         """
+
+         try:
+             # Call the LLM to generate the response
+             response = self.llm_instance.generate(prompt=prompt)
+             tool_name = response.strip().replace('"', '').replace("'", "")
+
+             # Find the tool in the list of available tools
+             tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
+             if tool:
+                 logger.info(f"Detected tool call: {tool.name}")
+                 return {
+                     "tool": tool.name,
+                     "input": {"query": message}
+                 }
+         except Exception as e:
+             logger.error(f"Failed to detect tool call: {e}")
+
+         return None
+
+     def _analyze_query_and_select_tools(self, query: str) -> List[Dict[str, Any]]:
+         """
+         Use the LLM to analyze the query and dynamically select tools.
+         Returns a list of tool calls, each with the tool name and input.
+         """
+         # Create a prompt for the LLM to analyze the query and select tools
+         prompt = f"""
+         You are an AI agent that helps analyze user queries and select the most appropriate tools.
+         Below is a list of available tools and their functionalities:
+
+         {self._get_tool_descriptions()}
+
+         For the following query, analyze the intent and select the most appropriate tools.
+         Respond with a JSON array of tool names and their inputs.
+         If no tool is suitable, respond with an empty array.
+
+         Query: "{query}"
+
+         Respond in the following JSON format:
+         [
+             {{
+                 "tool": "tool_name",
+                 "input": {{
+                     "query": "user_query",
+                     "context": "optional_context"
+                 }}
+             }}
+         ]
+         """
+
+         try:
+             # Call the LLM to generate the response
+             response = self.llm_instance.generate(prompt=prompt)
+             # Parse the response as JSON
+             tool_calls = json.loads(response)
+             return tool_calls
+         except Exception as e:
+             logger.error(f"Failed to analyze query and select tools: {e}")
+             return []
+
+
+     def _generate_response(self, message: str, markdown: bool = False, **kwargs) -> str:
+         """Generate the agent's response, including tool execution and context retrieval."""
+         # Use the LLM to analyze the query and dynamically select tools
+         tool_calls = self._analyze_query_and_select_tools(message)
+
+         responses = []
+         tool_outputs = {}  # Store outputs of all tools for collaboration
+
+         # Execute tools if any are detected
+         if tool_calls:
+             for tool_call in tool_calls:
+                 tool_name = tool_call["tool"]
+                 tool_input = tool_call["input"]
+
+                 # Find the tool
+                 tool = next((t for t in self.tools if t.name.lower() == tool_name.lower()), None)
+                 if tool:
+                     try:
+                         # Execute the tool
+                         tool_output = tool.execute(tool_input)
+                         response = f"Tool '{tool_name}' executed. Output: {tool_output}"
+                         if self.show_tool_calls:
+                             response = f"**Tool Called:** {tool_name}\n\n{response}"
+                         responses.append(response)
+
+                         # Store the tool output for collaboration
+                         tool_outputs[tool_name] = tool_output
+                     except Exception as e:
+                         logger.error(f"Tool '{tool_name}' execution failed: {e}")
+                         responses.append(f"An error occurred while executing the tool '{tool_name}': {e}")
+                 else:
+                     responses.append(f"Tool '{tool_name}' not found.")
+
+         # If multiple tools were executed, combine their outputs for analysis
+         if tool_outputs:
+             try:
+                 # Prepare the context for the LLM
+                 context = {
+                     "tool_outputs": tool_outputs,
+                     "rag_context": self.rag.retrieve(message) if self.rag else None,
+                     "knowledge_base_context": self._find_all_relevant_keys(message, self._flatten_data(self.knowledge_base)) if self.knowledge_base else None,
+                 }
+
+                 # Generate a response using the LLM
+                 llm_response = self.llm_instance.generate(prompt=message, context=context, **kwargs)
+                 responses.append(f"**Analysis:**\n\n{llm_response}")
+             except Exception as e:
+                 logger.error(f"Failed to generate LLM response: {e}")
+                 responses.append(f"An error occurred while generating the analysis: {e}")
+
+         # If no tools were executed, proceed with the original logic
+         if not tool_calls:
+             # Retrieve relevant context using RAG
+             rag_context = self.rag.retrieve(message) if self.rag else None
+             # Retrieve relevant context from the knowledge base (API result)
+             knowledge_base_context = None
+             if self.knowledge_base:
+                 # Flatten the knowledge base
+                 flattened_data = self._flatten_data(self.knowledge_base)
+                 # Find all relevant key-value pairs in the knowledge base
+                 relevant_values = self._find_all_relevant_keys(message, flattened_data)
+                 if relevant_values:
+                     knowledge_base_context = ", ".join(relevant_values)
+
+             # Combine both contexts (RAG and knowledge base)
+             context = {
+                 "rag_context": rag_context,
+                 "knowledge_base_context": knowledge_base_context,
+             }
+             # Prepare the prompt with instructions, description, and context
+             prompt = self._build_prompt(message, context)
+
+             # Generate the response using the LLM
+             response = self.llm_instance.generate(prompt=prompt, context=context, **kwargs)
+
+             # Format the response based on the json_output flag
+             if self.json_output:
+                 response = self._format_response_as_json(response)
+
+             # Validate the response against the expected_output
+             if self.expected_output:
+                 response = self._validate_response(response)
+
+             if markdown:
+                 return f"**Response:**\n\n{response}"
+             return response
+
+         # Combine all responses into a single output
+         return "\n\n".join(responses)
+
+     def _build_prompt(self, message: str, context: Optional[Dict]) -> str:
+         """Build the prompt using instructions, description, and context."""
+         prompt_parts = []
+
+         # Add description if available
+         if self.description:
+             prompt_parts.append(f"Description: {self.description}")
+
+         # Add instructions if available
+         if self.instructions:
+             instructions = "\n".join(self.instructions)
+             prompt_parts.append(f"Instructions: {instructions}")
+
+         # Add context if available
+         if context:
+             prompt_parts.append(f"Context: {context}")
+
+         # Add the user's message
+         prompt_parts.append(f"User Input: {message}")
+
+         return "\n\n".join(prompt_parts)
+
+     def _format_response_as_json(self, response: str) -> Union[Dict, str]:
+         """Format the response as JSON if json_output is True."""
+         try:
+             # Use regex to extract JSON from the response (e.g., within ```json ``` blocks)
+             json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
+             if json_match:
+                 # Extract the JSON part and parse it
+                 json_str = json_match.group(1)
+                 return json.loads(json_str)  # Return the parsed JSON object (a dictionary)
+             else:
+                 # If no JSON block is found, try to parse the entire response as JSON
+                 return json.loads(response)  # Return the parsed JSON object (a dictionary)
+         except json.JSONDecodeError:
+             # If the response is not valid JSON, wrap it in a dictionary
+             return {"response": response}  # Return a dictionary with the response as a string
+
+     def normalize_key(self, key: str) -> str:
+         """
+         Normalize a key by converting it to lowercase and replacing spaces with underscores.
+         """
+         return key.lower().replace(" ", "_")
+
+     def match_key(self, expected_key, response_keys, threshold=0.5):
+         """
+         Match an expected key to the closest key in the response using semantic similarity or fuzzy matching.
+         """
+         expected_key_norm = self.normalize_key(expected_key)
+         response_keys_norm = [self.normalize_key(k) for k in response_keys]
+
+         if hasattr(self, 'semantic_model') and self.semantic_model is not None:
+             try:
+                 # Compute embeddings for the expected key and all response keys
+                 expected_embedding = self.semantic_model.encode(expected_key_norm, convert_to_tensor=True)
+                 response_embeddings = self.semantic_model.encode(response_keys_norm, convert_to_tensor=True)
+
+                 # Compute cosine similarity
+                 similarity_scores = util.pytorch_cos_sim(expected_embedding, response_embeddings)[0]
+
+                 # Find the best match
+                 best_score = similarity_scores.max().item()
+                 best_index = similarity_scores.argmax().item()
+
+                 if best_score > threshold:
+                     return response_keys[best_index], best_score
+             except Exception as e:
+                 logging.warning(f"Semantic matching failed: {e}. Falling back to fuzzy matching.")
+
+         # Fallback to fuzzy matching
+         best_match = None
+         best_score = -1
+         for key, key_norm in zip(response_keys, response_keys_norm):
+             score = fuzz.ratio(expected_key_norm, key_norm) / 100
+             if score > best_score:
+                 best_score = score
+                 best_match = key
+
+         return best_match, best_score
+
+     def _validate_response(self, response: Union[str, Dict]) -> Union[str, Dict]:
+         """Validate the response against the expected_output format using semantic similarity or fallback methods."""
+         if isinstance(self.expected_output, dict):
+             if not isinstance(response, dict):
+                 return {"response": response}
+
+             validated_response = {}
+             normalized_expected_keys = {self.normalize_key(k): k for k in self.expected_output.keys()}
+
+             for expected_key_norm, expected_key_orig in normalized_expected_keys.items():
+                 # Find all response keys that match the expected key (case-insensitive and normalized)
+                 matching_response_keys = [
+                     k for k in response.keys()
+                     if self.normalize_key(k) == expected_key_norm
+                 ]
+
+                 # If no exact match, use semantic matching to find similar keys
+                 if not matching_response_keys:
+                     for response_key in response.keys():
+                         best_match, best_score = self.match_key(expected_key_orig, [response_key])
+                         if best_match and best_score > 0.5:  # Use a threshold to determine a valid match
+                             matching_response_keys.append(response_key)
+
+                 # Merge values from all matching keys
+                 merged_values = []
+                 for matching_key in matching_response_keys:
+                     value = response[matching_key]
+                     if isinstance(value, list):
+                         merged_values.extend(value)
+                     else:
+                         merged_values.append(value)
+
+                 # Assign the merged values to the expected key
+                 if merged_values:
+                     validated_response[expected_key_orig] = merged_values
+                 else:
+                     validated_response[expected_key_orig] = "NA"  # Default value for missing keys
+
+                 # Recursively validate nested dictionaries
+                 expected_value = self.expected_output[expected_key_orig]
+                 if isinstance(expected_value, dict) and isinstance(validated_response[expected_key_orig], dict):
+                     validated_response[expected_key_orig] = self._validate_response(validated_response[expected_key_orig])
+
+             return validated_response
+         elif isinstance(self.expected_output, str):
+             if not isinstance(response, str):
+                 return str(response)
+             return response
+
+     def cli_app(
+         self,
+         message: Optional[str] = None,
+         exit_on: Optional[List[str]] = None,
+         **kwargs,
+     ):
+         """Run the agent in a CLI app."""
+         from rich.prompt import Prompt
+
+         if message:
+             self.print_response(message=message, **kwargs)
+
+         _exit_on = exit_on or ["exit", "quit", "bye"]
+         while True:
+             message = Prompt.ask(f"[bold] {self.emoji} {self.user_name} [/bold]")
+             if message in _exit_on:
+                 break
+
+             self.print_response(message=message, **kwargs)
+
+     def _generate_api(self):
+         """Generate an API for the agent if api=True."""
+         from .api.api_generator import APIGenerator
+         self.api_generator = APIGenerator(self)
+         print(f"API generated for agent '{self.name}'. Use `.run_api()` to start the API server.")
+
+     def run_api(self):
+         """Run the API server for the agent."""
+         # api_generator is a declared field (default None), so check its value rather than hasattr
+         if self.api_generator is None:
+             raise ValueError("API is not enabled for this agent. Set `api=True` when initializing the agent.")
+
+         # Get API configuration
+         host = self.api_config.get("host", "0.0.0.0") if self.api_config else "0.0.0.0"
+         port = self.api_config.get("port", 8000) if self.api_config else 8000
+
+         # Run the API server
+         self.api_generator.run(host=host, port=port)
+
+     def _flatten_data(self, data: Union[Dict, List], parent_key: str = "", separator: str = "_") -> List[Dict]:
+         """
+         Recursively flatten a nested dictionary or list into a list of key-value pairs.
+
+         Args:
+             data (Union[Dict, List]): The nested data structure.
+             parent_key (str): The parent key (used for recursion).
+             separator (str): The separator used for nested keys.
+
+         Returns:
+             List[Dict]: A list of flattened key-value pairs.
+         """
+         items = []
+         if isinstance(data, dict):
+             for key, value in data.items():
+                 new_key = f"{parent_key}{separator}{key}" if parent_key else key
+                 if isinstance(value, (dict, list)):
+                     items.extend(self._flatten_data(value, new_key, separator))
+                 else:
+                     items.append({new_key: value})
+                     # Include the value as a key for searching
+                     if isinstance(value, str):
+                         items.append({value: new_key})
+         elif isinstance(data, list):
+             for index, item in enumerate(data):
+                 new_key = f"{parent_key}{separator}{index}" if parent_key else str(index)
+                 if isinstance(item, (dict, list)):
+                     items.extend(self._flatten_data(item, new_key, separator))
+                 else:
+                     items.append({new_key: item})
+                     # Include the value as a key for searching
+                     if isinstance(item, str):
+                         items.append({item: new_key})
+         return items
+
+     def _find_all_relevant_keys(self, query: str, flattened_data: List[Dict], threshold: float = 0.5) -> List[str]:
+         """
+         Find all relevant keys in the flattened data based on semantic similarity to the query.
+
+         Args:
+             query (str): The user's query.
+             flattened_data (List[Dict]): The flattened key-value pairs.
+             threshold (float): The similarity threshold for considering a match.
+
+         Returns:
+             List[str]: A list of relevant values.
+         """
+         if not flattened_data:
+             return []
+
+         # Extract keys from the flattened data
+         keys = [list(item.keys())[0] for item in flattened_data]
+
+         # Compute embeddings for the query and keys
+         query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
+         key_embeddings = self.semantic_model.encode(keys, convert_to_tensor=True)
+
+         # Compute cosine similarity between the query and keys
+         similarities = util.pytorch_cos_sim(query_embedding, key_embeddings)[0]
+
+         # Find all keys with a similarity score above the threshold
+         relevant_indices = [i for i, score in enumerate(similarities) if score > threshold]
+         relevant_values = [flattened_data[i][keys[i]] for i in relevant_indices]
+
+         return relevant_values
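
A minimal usage sketch for the Agent class above (illustrative, not part of the wheel): it assumes a GROQ_API_KEY environment variable and that the bundled GroqLlm provider and its dependencies are installed; tool discovery, the SentenceTransformer model, and the default RAG are all set up inside __init__.

# Example (not part of the package): instantiate an Agent backed by Groq and ask a question.
# Assumes GROQ_API_KEY is set; the model name falls back to the provider default
# ("mixtral-8x7b-32768") when llm_model is omitted.
from semantio.agent import Agent

agent = Agent(
    name="research-assistant",
    description="Answers questions using the registered tools and knowledge base.",
    instructions=["Be concise.", "Cite the tool output you relied on."],
    llm="groq",                        # provider key handled by _initialize_model()
    llm_model="mixtral-8x7b-32768",
    show_tool_calls=True,
    markdown=True,
)

# Prints the answer to the console and returns it (a dict when json_output=True)
answer = agent.print_response("What is the current price of bitcoin?")

# Or run an interactive loop that exits on "exit", "quit", or "bye"
# agent.cli_app()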
File without changes
semantio/api/api_generator.py ADDED
@@ -0,0 +1,23 @@
+ from .fastapi_app import create_fastapi_app  # Import the factory function
+
+ class APIGenerator:
+     def __init__(self, assistant):
+         """
+         Initialize the APIGenerator with the given assistant.
+
+         Args:
+             assistant: The assistant instance for which the API is being created.
+         """
+         self.assistant = assistant
+         self.app = create_fastapi_app(assistant, assistant.api_config)  # Pass api_config to create_fastapi_app
+
+     def run(self, host: str = "0.0.0.0", port: int = 8000):
+         """
+         Run the FastAPI app.
+
+         Args:
+             host (str): The host address to run the API server on. Default is "0.0.0.0".
+             port (int): The port to run the API server on. Default is 8000.
+         """
+         import uvicorn
+         uvicorn.run(self.app, host=host, port=port)
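
A hedged sketch of how the API path fits together, based only on the code above: setting api=True on the Agent builds an APIGenerator in _generate_api(), and run_api() reads host/port from api_config and delegates to APIGenerator.run(), which serves the FastAPI app with uvicorn. The actual routes come from create_fastapi_app, which is not shown in this listing; fastapi and uvicorn must be installed.

# Example (illustrative only): expose an agent over HTTP via the generated API.
from semantio.agent import Agent

agent = Agent(
    name="api-agent",
    llm="openai",                      # requires OPENAI_API_KEY in the environment
    api=True,                          # builds the APIGenerator during __init__
    api_config={"host": "127.0.0.1", "port": 8080},
)

# Blocks and serves the FastAPI app on the configured host/port
agent.run_api()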