alita-sdk 0.3.203__py3-none-any.whl → 0.3.205__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -58,8 +58,8 @@ class AlitaClient:
         self.list_apps_url = f"{self.base_url}{self.api_path}/applications/applications/prompt_lib/{self.project_id}"
         self.integration_details = f"{self.base_url}{self.api_path}/integrations/integration/{self.project_id}"
         self.secrets_url = f"{self.base_url}{self.api_path}/secrets/secret/{self.project_id}"
-        self.artifacts_url = f"{self.base_url}{self.api_path}/artifacts/artifacts/{self.project_id}"
-        self.artifact_url = f"{self.base_url}{self.api_path}/artifacts/artifact/{self.project_id}"
+        self.artifacts_url = f"{self.base_url}{self.api_path}/artifacts/artifacts/default/{self.project_id}"
+        self.artifact_url = f"{self.base_url}{self.api_path}/artifacts/artifact/default/{self.project_id}"
         self.bucket_url = f"{self.base_url}{self.api_path}/artifacts/buckets/{self.project_id}"
         self.configurations_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=configurations&unsecret=true'
         self.ai_section_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=ai'
@@ -291,7 +291,7 @@ class AlitaClient:
         return self._process_requst(data)

     def download_artifact(self, bucket_name, artifact_name):
-        url = f'{self.artifact_url}/{bucket_name}/{artifact_name}'
+        url = f'{self.artifact_url}/{bucket_name.lower()}/{artifact_name}'
         data = requests.get(url, headers=self.headers, verify=False)
         if data.status_code == 403:
             return {"error": "You are not authorized to access this resource"}
@@ -1,13 +1,14 @@
 import json
-from json import dumps
-from typing import Any, Optional, List, Dict
+import math
+from typing import Any, Optional, List, Dict, Callable
 from pydantic import BaseModel, model_validator, Field
-from langchain_core.tools import ToolException
 from ..langchain.tools.vector import VectorAdapter
 from langchain_core.messages import HumanMessage
 from alita_sdk.tools.elitea_base import BaseToolApiWrapper
 from logging import getLogger

+from ..utils.logging import dispatch_custom_event
+
 logger = getLogger(__name__)

 class IndexDocumentsModel(BaseModel):
@@ -139,6 +140,7 @@ class VectorStoreWrapper(BaseToolApiWrapper):
     vectoradapter: Any = None
     pg_helper: Any = None
     embeddings: Any = None
+    process_document_func: Optional[Callable] = None

     @model_validator(mode='before')
     @classmethod
@@ -182,18 +184,122 @@ class VectorStoreWrapper(BaseToolApiWrapper):
         except Exception as e:
             logger.error(f"Failed to initialize PGVectorSearch: {str(e)}")

-    def index_documents(self, documents):
+    def _get_indexed_data(self, store):
+        """ Get all indexed data from vectorstore """
+
+        # get already indexed data
+        result = {}
+        try:
+            self._log_data("Retrieving already indexed data from vectorstore",
+                           tool_name="index_documents")
+            data = store.get(include=['documents', 'metadatas'])
+            # re-structure data to be more usable
+            for doc_str, meta, db_id in zip(data['documents'], data['metadatas'], data['ids']):
+                doc = json.loads(doc_str)
+                doc_id = str(meta['id'])
+                result[doc_id] = {
+                    'metadata': meta,
+                    'document': doc,
+                    'id': db_id
+                }
+        except Exception as e:
+            logger.error(f"Failed to get indexed data from vectorstore: {str(e)}. Continuing with empty index.")
+        return result
+
+    def _reduce_duplicates(self, documents, store) -> List[Any]:
+        """Remove documents already indexed in the vectorstore based on metadata 'id' and 'updated_on' fields."""
+
+        self._log_data("Verification of documents to index started", tool_name="index_documents")
+
+        data = self._get_indexed_data(store)
+        indexed_ids = set(data.keys())
+        if not indexed_ids:
+            self._log_data("Vectorstore is empty, indexing all incoming documents", tool_name="index_documents")
+            return documents
+
+        final_docs = []
+        docs_to_remove = []
+
+        for document in documents:
+            doc_id = document.metadata.get('id')
+            # get document's metadata and id and check if already indexed
+            if doc_id in indexed_ids:
+                # document has been indexed already, then verify `updated_on`
+                to_index_updated_on = document.metadata.get('updated_on')
+                indexed_meta = data[doc_id]['metadata']
+                indexed_updated_on = indexed_meta.get('updated_on')
+                if to_index_updated_on and indexed_updated_on and to_index_updated_on == indexed_updated_on:
+                    # same updated_on, skip indexing
+                    continue
+                # if updated_on is missing or different, we will re-index the document and remove old one
+                docs_to_remove.append(data[doc_id]['id'])
+            else:
+                final_docs.append(document)
+
+        if docs_to_remove:
+            self._log_data(
+                f"Removing {len(docs_to_remove)} documents from vectorstore that are already indexed with different updated_on.",
+                tool_name="index_documents"
+            )
+            store.delete(ids=docs_to_remove)
+
+        return final_docs
+
+    def index_documents(self, documents, progress_step: int = 20, clean_index: bool = True):
+        """ Index documents in the vectorstore.
+
+        Args:
+            documents (Any): Generator or list of documents to index.
+            document_processing_func (Optional[Callable]): Function to process documents after duplicates removal and before indexing.
+            progress_step (int): Step for progress reporting, default is 20.
+            clean_index (bool): If True, clean the index before re-indexing all documents.
+        """
+
         from ..langchain.interfaces.llm_processor import add_documents
+
+        # pre-process documents if needed (find duplicates, etc.)
+        if clean_index:
+            logger.info("Cleaning index before re-indexing all documents.")
+            self._log_data("Cleaning index before re-indexing all documents. Previous index will be removed", tool_name="index_documents")
+            try:
+                self.vectoradapter.delete_dataset(self.dataset)
+                self.vectoradapter.persist()
+                self.vectoradapter.vacuum()
+                self._log_data("Previous index has been removed",
+                               tool_name="index_documents")
+            except Exception as e:
+                logger.warning(f"Failed to clean index: {str(e)}. Continuing with re-indexing.")
+        else:
+            # remove duplicates based on metadata 'id' and 'updated_on' fields
+            documents = self._reduce_duplicates(documents, self.vectoradapter.vectorstore)
+
+
+        if not documents or len(documents) == 0:
+            logger.info("No new documents to index after duplicate check.")
+            return {"status": "ok", "message": "No new documents to index."}
+
+        # if func is provided, apply it to documents
+        # used for processing of documents before indexing,
+        # e.g. to avoid time-consuming operations for documents that are already indexed
+        self.process_document_func(documents) if self.process_document_func else None
+
+        # notify user about missed required metadata fields: id, updated_on
+        # it is not required to have them, but it is recommended to have them for proper re-indexing and duplicate detection
+        for doc in documents:
+            if 'id' not in doc.metadata or 'updated_on' not in doc.metadata:
+                logger.warning(f"Document is missing required metadata field 'id' or 'updated_on': {doc.metadata}")
+
         logger.debug(f"Indexing documents: {documents}")
         logger.debug(self.vectoradapter)
-        self.vectoradapter.delete_dataset(self.dataset)
-        self.vectoradapter.persist()
-        logger.debug(f"Deleted Dataset")
-        #
-        self.vectoradapter.vacuum()
-        #
+
+        documents = list(documents)
+        total_docs = len(documents)
         documents_count = 0
         _documents = []
+
+        # set default progress step to 20 if out of 0...100 or None
+        progress_step = 20 if progress_step not in range(0, 100) else progress_step
+        next_progress_point = progress_step
         for document in documents:
             documents_count += 1
             # logger.debug(f"Indexing document: {document}")
@@ -203,7 +309,14 @@ class VectorStoreWrapper(BaseToolApiWrapper):
                     add_documents(vectorstore=self.vectoradapter.vectorstore, documents=_documents)
                     self.vectoradapter.persist()
                     _documents = []
-            except Exception as e:
+
+                percent = math.floor((documents_count / total_docs) * 100)
+                if percent >= next_progress_point:
+                    msg = f"Indexing progress: {percent}%. Processed {documents_count} of {total_docs} documents."
+                    logger.debug(msg)
+                    self._log_data(msg)
+                    next_progress_point += progress_step
+            except Exception:
                 from traceback import format_exc
                 logger.error(f"Error: {format_exc()}")
                 return {"status": "error", "message": f"Error: {format_exc()}"}
@@ -383,9 +496,11 @@ class VectorStoreWrapper(BaseToolApiWrapper):
         combined_items = [item for item in combined_items if abs(item[1]) >= cut_off]

         # Sort by score and limit results
-        combined_items.sort(key=lambda x: x[1], reverse=True)
+
+        # for chroma we want ascending order (lower score is better), for others descending
+        combined_items.sort(key=lambda x: x[1], reverse=self.vectorstore_type.lower() != 'chroma')
         combined_items = combined_items[:search_top]
-
+
         # Format output based on doctype
         if doctype == 'code':
             return code_format(combined_items)
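To illustrate the ordering change (a standalone sketch, not code from the package): Chroma returns distance-like scores where smaller is better, so its results are kept in ascending order; other stores return similarity scores where larger is better, so they are sorted descending.

# Illustrative only: the same sort call under both settings.
items = [("doc_a", 0.12), ("doc_b", 0.87), ("doc_c", 0.45)]

chroma_order = sorted(items, key=lambda x: x[1], reverse=False)  # ascending: lowest distance first
other_order = sorted(items, key=lambda x: x[1], reverse=True)    # descending: highest similarity first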
@@ -498,6 +613,21 @@ class VectorStoreWrapper(BaseToolApiWrapper):
         ])
         return result.content

+    def _log_data(self, message: str, tool_name: str = "index_data"):
+        """Log data and dispatch custom event for indexing progress"""
+
+        try:
+            dispatch_custom_event(
+                name="thinking_step",
+                data={
+                    "message": message,
+                    "tool_name": tool_name,
+                    "toolkit": "vectorstore",
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Failed to dispatch progress event: {str(e)}")
+
     def get_available_tools(self):
         return [
             {
@@ -80,6 +80,8 @@ _safe_import_tool('postman', 'postman', 'get_tools', 'PostmanToolkit')
 _safe_import_tool('memory', 'memory', 'get_tools', 'MemoryToolkit')
 _safe_import_tool('zephyr_squad', 'zephyr_squad', 'get_tools', 'ZephyrSquadToolkit')
 _safe_import_tool('slack', 'slack', 'get_tools', 'SlackToolkit')
+_safe_import_tool('bigquery', 'google.bigquery', 'get_tools', 'BigQueryToolkit')
+_safe_import_tool('delta_lake', 'aws.delta_lake', 'get_tools', 'DeltaLakeToolkit')

 # Log import summary
 available_count = len(AVAILABLE_TOOLS)
@@ -0,0 +1,7 @@
+from .delta_lake import DeltaLakeToolkit
+
+name = "aws"
+
+def get_tools(tool_type, tool):
+    if tool_type == 'delta_lake':
+        return DeltaLakeToolkit().get_toolkit().get_tools()
@@ -0,0 +1,136 @@
+
+from functools import lru_cache
+from typing import List, Optional, Type
+
+from langchain_core.tools import BaseTool, BaseToolkit
+from pydantic import BaseModel, Field, SecretStr, computed_field, field_validator
+
+from ...utils import TOOLKIT_SPLITTER, clean_string, get_max_toolkit_length
+from .api_wrapper import DeltaLakeApiWrapper
+from .tool import DeltaLakeAction
+
+name = "delta_lake"
+
+@lru_cache(maxsize=1)
+def get_available_tools() -> dict[str, dict]:
+    api_wrapper = DeltaLakeApiWrapper.model_construct()
+    available_tools: dict = {
+        x["name"]: x["args_schema"].model_json_schema()
+        for x in api_wrapper.get_available_tools()
+    }
+    return available_tools
+
+toolkit_max_length = lru_cache(maxsize=1)(
+    lambda: get_max_toolkit_length(get_available_tools())
+)
+
+class DeltaLakeToolkitConfig(BaseModel):
+    class Config:
+        title = name
+        json_schema_extra = {
+            "metadata": {
+                "hidden": True,
+                "label": "AWS Delta Lake",
+                "icon_url": "delta-lake.svg",
+                "sections": {
+                    "auth": {
+                        "required": False,
+                        "subsections": [
+                            {"name": "AWS Access Key ID", "fields": ["aws_access_key_id"]},
+                            {"name": "AWS Secret Access Key", "fields": ["aws_secret_access_key"]},
+                            {"name": "AWS Session Token", "fields": ["aws_session_token"]},
+                            {"name": "AWS Region", "fields": ["aws_region"]},
+                        ],
+                    },
+                    "connection": {
+                        "required": False,
+                        "subsections": [
+                            {"name": "Delta Lake S3 Path", "fields": ["s3_path"]},
+                            {"name": "Delta Lake Table Path", "fields": ["table_path"]},
+                        ],
+                    },
+                },
+            }
+        }
+
+    aws_access_key_id: Optional[SecretStr] = Field(default=None, description="AWS access key ID", json_schema_extra={"secret": True, "configuration": True})
+    aws_secret_access_key: Optional[SecretStr] = Field(default=None, description="AWS secret access key", json_schema_extra={"secret": True, "configuration": True})
+    aws_session_token: Optional[SecretStr] = Field(default=None, description="AWS session token (optional)", json_schema_extra={"secret": True, "configuration": True})
+    aws_region: Optional[str] = Field(default=None, description="AWS region for Delta Lake storage", json_schema_extra={"configuration": True})
+    s3_path: Optional[str] = Field(default=None, description="S3 path to Delta Lake data (e.g., s3://bucket/path)", json_schema_extra={"configuration": True})
+    table_path: Optional[str] = Field(default=None, description="Delta Lake table path (if not using s3_path)", json_schema_extra={"configuration": True})
+    selected_tools: List[str] = Field(default=[], description="Selected tools", json_schema_extra={"args_schemas": get_available_tools()})
+
+    @field_validator("selected_tools", mode="before", check_fields=False)
+    @classmethod
+    def selected_tools_validator(cls, value: List[str]) -> list[str]:
+        return [i for i in value if i in get_available_tools()]
+
+def _get_toolkit(tool) -> BaseToolkit:
+    return DeltaLakeToolkit().get_toolkit(
+        selected_tools=tool["settings"].get("selected_tools", []),
+        aws_access_key_id=tool["settings"].get("aws_access_key_id", None),
+        aws_secret_access_key=tool["settings"].get("aws_secret_access_key", None),
+        aws_session_token=tool["settings"].get("aws_session_token", None),
+        aws_region=tool["settings"].get("aws_region", None),
+        s3_path=tool["settings"].get("s3_path", None),
+        table_path=tool["settings"].get("table_path", None),
+        toolkit_name=tool.get("toolkit_name"),
+    )
+
+def get_toolkit():
+    return DeltaLakeToolkit.toolkit_config_schema()
+
+def get_tools(tool):
+    return _get_toolkit(tool).get_tools()
+
+class DeltaLakeToolkit(BaseToolkit):
+    tools: List[BaseTool] = []
+    api_wrapper: Optional[DeltaLakeApiWrapper] = Field(default_factory=DeltaLakeApiWrapper.model_construct)
+    toolkit_name: Optional[str] = None
+
+    @computed_field
+    @property
+    def tool_prefix(self) -> str:
+        return (
+            clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER
+            if self.toolkit_name
+            else ""
+        )
+
+    @computed_field
+    @property
+    def available_tools(self) -> List[dict]:
+        return self.api_wrapper.get_available_tools()
+
+    @staticmethod
+    def toolkit_config_schema() -> Type[BaseModel]:
+        return DeltaLakeToolkitConfig
+
+    @classmethod
+    def get_toolkit(
+        cls,
+        selected_tools: list[str] | None = None,
+        toolkit_name: Optional[str] = None,
+        **kwargs,
+    ) -> "DeltaLakeToolkit":
+        delta_lake_api_wrapper = DeltaLakeApiWrapper(**kwargs)
+        instance = cls(
+            tools=[], api_wrapper=delta_lake_api_wrapper, toolkit_name=toolkit_name
+        )
+        if selected_tools:
+            selected_tools = set(selected_tools)
+            for t in instance.available_tools:
+                if t["name"] in selected_tools:
+                    instance.tools.append(
+                        DeltaLakeAction(
+                            api_wrapper=instance.api_wrapper,
+                            name=instance.tool_prefix + t["name"],
+                            description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" + t["description"],
+                            args_schema=t["args_schema"],
+                        )
+                    )
+        return instance
+
+    def get_tools(self):
+        return self.tools
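A hedged configuration sketch for the new toolkit, assuming the settings-dict shape that `_get_toolkit` reads above; the concrete values are placeholders and do not come from the diff.

tool_config = {
    "toolkit_name": "my_delta",
    "settings": {
        "selected_tools": ["query_table", "get_table_schema"],
        "aws_access_key_id": "AKIA...",            # placeholder credentials
        "aws_secret_access_key": "...",
        "aws_region": "us-east-1",
        "s3_path": "s3://my-bucket/delta/events",  # placeholder table location
    },
}

tools = get_tools(tool_config)  # module-level helper defined above
for t in tools:
    print(t.name)               # tool names are prefixed with the toolkit name and TOOLKIT_SPLITTER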
@@ -0,0 +1,220 @@
+import functools
+import json
+import logging
+from typing import Any, List, Optional
+
+from deltalake import DeltaTable
+from langchain_core.tools import ToolException
+from pydantic import (
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+    model_validator,
+)
+from pydantic_core.core_schema import ValidationInfo
+from ...elitea_base import BaseToolApiWrapper
+from .schemas import ArgsSchema
+
+
+def process_output(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            result = func(self, *args, **kwargs)
+            if isinstance(result, Exception):
+                return ToolException(str(result))
+            if isinstance(result, (dict, list)):
+                return json.dumps(result, default=str)
+            return str(result)
+        except Exception as e:
+            logging.error(f"Error in '{func.__name__}': {str(e)}")
+            return ToolException(str(e))
+    return wrapper
+
+
+class DeltaLakeApiWrapper(BaseToolApiWrapper):
+    """
+    API Wrapper for AWS Delta Lake. Handles authentication, querying, and utility methods.
+    """
+    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)
+
+    aws_access_key_id: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"})
+    aws_secret_access_key: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"})
+    aws_session_token: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"})
+    aws_region: Optional[str] = Field(default=None, json_schema_extra={"env_key": "AWS_REGION"})
+    s3_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"})
+    table_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"})
+    _delta_table: Optional[DeltaTable] = PrivateAttr(default=None)
+
+    @classmethod
+    def model_construct(cls, *args, **kwargs):
+        klass = super().model_construct(*args, **kwargs)
+        klass._delta_table = None
+        return klass
+
+    @field_validator(
+        "aws_access_key_id",
+        "aws_secret_access_key",
+        "aws_session_token",
+        "aws_region",
+        "s3_path",
+        "table_path",
+        mode="before",
+        check_fields=False,
+    )
+    @classmethod
+    def set_from_values_or_env(cls, value, info: ValidationInfo):
+        if value is None:
+            if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra:
+                if env_key := json_schema_extra.get("env_key"):
+                    try:
+                        from langchain_core.utils import get_from_env
+                        return get_from_env(
+                            key=info.field_name,
+                            env_key=env_key,
+                            default=cls.model_fields[info.field_name].default,
+                        )
+                    except Exception:
+                        return None
+        return value
+
+    @model_validator(mode="after")
+    def validate_auth(self) -> "DeltaLakeApiWrapper":
+        if not (self.aws_access_key_id and self.aws_secret_access_key and self.aws_region):
+            raise ValueError("You must provide AWS credentials and region.")
+        if not (self.s3_path or self.table_path):
+            raise ValueError("You must provide either s3_path or table_path.")
+        return self
+
+    @property
+    def delta_table(self) -> DeltaTable:
+        if not self._delta_table:
+            path = self.table_path or self.s3_path
+            if not path:
+                raise ToolException("Delta Lake table path (table_path or s3_path) must be specified.")
+            try:
+                storage_options = {
+                    "AWS_ACCESS_KEY_ID": self.aws_access_key_id.get_secret_value() if self.aws_access_key_id else None,
+                    "AWS_SECRET_ACCESS_KEY": self.aws_secret_access_key.get_secret_value() if self.aws_secret_access_key else None,
+                    "AWS_REGION": self.aws_region,
+                }
+                if self.aws_session_token:
+                    storage_options["AWS_SESSION_TOKEN"] = self.aws_session_token.get_secret_value()
+                storage_options = {k: v for k, v in storage_options.items() if v is not None}
+                self._delta_table = DeltaTable(path, storage_options=storage_options)
+            except Exception as e:
+                raise ToolException(f"Error initializing DeltaTable: {e}")
+        return self._delta_table
+
+    @process_output
+    def query_table(self, query: Optional[str] = None, columns: Optional[List[str]] = None, filters: Optional[dict] = None) -> List[dict]:
+        """
+        Query Delta Lake table. Supports pandas-like filtering, column selection, and SQL-like queries (via pandas.DataFrame.query).
+        Args:
+            query: SQL-like query string (pandas.DataFrame.query syntax)
+            columns: List of columns to select
+            filters: Dict of column:value pairs for pandas-like filtering
+        Returns:
+            List of dicts representing rows
+        """
+        dt = self.delta_table
+        df = dt.to_pandas()
+        if filters:
+            for col, val in filters.items():
+                df = df[df[col] == val]
+        if query:
+            try:
+                df = df.query(query)
+            except Exception as e:
+                raise ToolException(f"Error in query param: {e}")
+        if columns:
+            df = df[columns]
+        return df.to_dict(orient="records")
+
+    @process_output
+    def vector_search(self, embedding: List[float], k: int = 5, embedding_column: str = "embedding") -> List[dict]:
+        """
+        Perform a vector similarity search on the Delta Lake table.
+        Args:
+            embedding: Query embedding vector.
+            k: Number of top results to return.
+            embedding_column: Name of the column containing embeddings.
+        Returns:
+            List of dicts for top k most similar rows.
+        """
+        import numpy as np
+
+        dt = self.delta_table
+        df = dt.to_pandas()
+        if embedding_column not in df.columns:
+            raise ToolException(f"Embedding column '{embedding_column}' not found in table.")
+
+        # Filter out rows with missing embeddings
+        df = df[df[embedding_column].notnull()]
+        if df.empty:
+            return []
+        # Convert embeddings to numpy arrays
+        emb_matrix = np.array(df[embedding_column].tolist())
+        query_vec = np.array(embedding)
+
+        # Normalize for cosine similarity
+        emb_matrix_norm = emb_matrix / np.linalg.norm(emb_matrix, axis=1, keepdims=True)
+        query_vec_norm = query_vec / np.linalg.norm(query_vec)
+        similarities = np.dot(emb_matrix_norm, query_vec_norm)
+
+        # Get top k indices
+        top_k_idx = np.argsort(similarities)[-k:][::-1]
+        top_rows = df.iloc[top_k_idx]
+        return top_rows.to_dict(orient="records")
+
+    @process_output
+    def get_table_schema(self) -> str:
+        dt = self.delta_table
+        return dt.schema().to_pyarrow().to_string()
+
+    def get_available_tools(self) -> List[dict]:
+        return [
+            {
+                "name": "query_table",
+                "description": self.query_table.__doc__,
+                "args_schema": ArgsSchema.QueryTableArgs.value,
+                "ref": self.query_table,
+            },
+            {
+                "name": "vector_search",
+                "description": self.vector_search.__doc__,
+                "args_schema": ArgsSchema.VectorSearchArgs.value,
+                "ref": self.vector_search,
+            },
+            {
+                "name": "get_table_schema",
+                "description": self.get_table_schema.__doc__,
+                "args_schema": ArgsSchema.NoInput.value,
+                "ref": self.get_table_schema,
+            },
+        ]
+
+    def run(self, name: str, *args: Any, **kwargs: Any):
+        for tool in self.get_available_tools():
+            if tool["name"] == name:
+                if len(args) == 1 and isinstance(args[0], dict) and not kwargs:
+                    kwargs = args[0]
+                    args = ()
+                try:
+                    return tool["ref"](*args, **kwargs)
+                except TypeError as e:
+                    if kwargs and not args:
+                        try:
+                            return tool["ref"](**kwargs)
+                        except TypeError:
+                            raise ValueError(
+                                f"Argument mismatch for tool '{name}'. Error: {e}"
+                            ) from e
+                    else:
+                        raise ValueError(
+                            f"Argument mismatch for tool '{name}'. Error: {e}"
+                        ) from e
+        else:
+            raise ValueError(f"Unknown tool name: {name}")
@@ -0,0 +1,20 @@
+
+from enum import Enum
+from typing import List, Optional
+
+from pydantic import Field, create_model
+
+class ArgsSchema(Enum):
+    NoInput = create_model("NoInput")
+    QueryTableArgs = create_model(
+        "QueryTableArgs",
+        query=(Optional[str], Field(default=None, description="SQL query to execute on Delta Lake table. If None, returns all data.")),
+        columns=(Optional[List[str]], Field(default=None, description="List of columns to select.")),
+        filters=(Optional[dict], Field(default=None, description="Dict of column:value pairs for pandas-like filtering.")),
+    )
+    VectorSearchArgs = create_model(
+        "VectorSearchArgs",
+        embedding=(List[float], Field(description="Embedding vector for similarity search.")),
+        k=(int, Field(default=5, description="Number of top results to return.")),
+        embedding_column=(Optional[str], Field(default="embedding", description="Name of the column containing embeddings.")),
+    )
@@ -0,0 +1,35 @@
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from pydantic import BaseModel, field_validator, Field
+from langchain_core.tools import BaseTool
+from traceback import format_exc
+from .api_wrapper import DeltaLakeApiWrapper
+
+
+class DeltaLakeAction(BaseTool):
+    """Tool for interacting with the Delta Lake API on AWS."""
+
+    api_wrapper: DeltaLakeApiWrapper = Field(default_factory=DeltaLakeApiWrapper)
+    name: str
+    description: str = ""
+    args_schema: Optional[Type[BaseModel]] = None
+
+    @field_validator('name', mode='before')
+    @classmethod
+    def remove_spaces(cls, v):
+        return v.replace(' ', '')
+
+    def _run(
+        self,
+        *args,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+        **kwargs,
+    ) -> str:
+        """Use the Delta Lake API to run an operation."""
+        try:
+            # Use the tool name to dispatch to the correct API wrapper method
+            return self.api_wrapper.run(self.name, *args, **kwargs)
+        except Exception as e:
+            return f"Error: {format_exc()}"