alita-sdk 0.3.204__py3-none-any.whl → 0.3.206__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- alita_sdk/runtime/clients/client.py +45 -5
- alita_sdk/runtime/langchain/assistant.py +22 -21
- alita_sdk/runtime/langchain/interfaces/llm_processor.py +1 -4
- alita_sdk/runtime/toolkits/application.py +5 -10
- alita_sdk/runtime/toolkits/tools.py +0 -1
- alita_sdk/runtime/tools/vectorstore.py +157 -13
- alita_sdk/runtime/utils/streamlit.py +33 -30
- alita_sdk/runtime/utils/utils.py +5 -0
- alita_sdk/tools/__init__.py +4 -0
- alita_sdk/tools/ado/repos/repos_wrapper.py +20 -13
- alita_sdk/tools/aws/__init__.py +7 -0
- alita_sdk/tools/aws/delta_lake/__init__.py +136 -0
- alita_sdk/tools/aws/delta_lake/api_wrapper.py +220 -0
- alita_sdk/tools/aws/delta_lake/schemas.py +20 -0
- alita_sdk/tools/aws/delta_lake/tool.py +35 -0
- alita_sdk/tools/bitbucket/api_wrapper.py +5 -5
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +54 -29
- alita_sdk/tools/elitea_base.py +55 -5
- alita_sdk/tools/gitlab/__init__.py +22 -10
- alita_sdk/tools/gitlab/api_wrapper.py +278 -253
- alita_sdk/tools/gitlab/tools.py +354 -376
- alita_sdk/tools/google/__init__.py +7 -0
- alita_sdk/tools/google/bigquery/__init__.py +154 -0
- alita_sdk/tools/google/bigquery/api_wrapper.py +502 -0
- alita_sdk/tools/google/bigquery/schemas.py +102 -0
- alita_sdk/tools/google/bigquery/tool.py +34 -0
- alita_sdk/tools/llm/llm_utils.py +0 -6
- alita_sdk/tools/openapi/__init__.py +14 -3
- alita_sdk/tools/sharepoint/__init__.py +2 -1
- alita_sdk/tools/sharepoint/api_wrapper.py +71 -7
- alita_sdk/tools/testrail/__init__.py +9 -1
- alita_sdk/tools/testrail/api_wrapper.py +154 -5
- alita_sdk/tools/utils/content_parser.py +77 -13
- alita_sdk/tools/zephyr_scale/api_wrapper.py +271 -22
- {alita_sdk-0.3.204.dist-info → alita_sdk-0.3.206.dist-info}/METADATA +3 -1
- {alita_sdk-0.3.204.dist-info → alita_sdk-0.3.206.dist-info}/RECORD +39 -30
- alita_sdk/runtime/llms/alita.py +0 -259
- {alita_sdk-0.3.204.dist-info → alita_sdk-0.3.206.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.204.dist-info → alita_sdk-0.3.206.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.204.dist-info → alita_sdk-0.3.206.dist-info}/top_level.txt +0 -0
alita_sdk/tools/ado/repos/repos_wrapper.py

@@ -250,6 +250,7 @@ class ReposApiWrapper(BaseCodeToolApiWrapper):
     token: Optional[SecretStr]
     _client: Optional[GitClient] = PrivateAttr()
 
+    llm: Optional[Any] = None
     # Vector store configuration
     connection_string: Optional[SecretStr] = None
     collection_name: Optional[str] = None
@@ -303,24 +304,30 @@ class ReposApiWrapper(BaseCodeToolApiWrapper):
 
     def _get_files(
         self,
-
-
+        path: str = "",
+        branch: str = None,
         recursion_level: str = "Full",
     ) -> str:
+        """Get list of files from a repository path and branch.
+
+        Args:
+            path (str): Path within the repository to list files from
+            branch (str): Branch to get files from. Defaults to base_branch if None.
+            recursion_level (str): OneLevel - includes immediate children, Full - includes all items, None - no recursion
+
+        Returns:
+            List[str]: List of file paths
         """
-
-        recursion_level: OneLevel - includes immediate children, Full - includes all items, None - no recursion
-        """
-        branch_name = branch_name if branch_name else self.base_branch
+        branch = branch if branch else self.base_branch
         files: List[str] = []
         try:
             version_descriptor = GitVersionDescriptor(
-                version=
+                version=branch, version_type="branch"
             )
             items = self._client.get_items(
                 repository_id=self.repository_id,
                 project=self.project,
-                scope_path=
+                scope_path=path,
                 recursion_level=recursion_level,
                 version_descriptor=version_descriptor,
                 include_content_metadata=True,
@@ -334,7 +341,7 @@ class ReposApiWrapper(BaseCodeToolApiWrapper):
             item = items.pop(0)
             if item.git_object_type == "blob":
                 files.append(item.path)
-        return str
+        return files  # Changed to return list directly instead of str
 
     def set_active_branch(self, branch_name: str) -> str:
         """
@@ -389,7 +396,7 @@ class ReposApiWrapper(BaseCodeToolApiWrapper):
         logger.error(msg)
         return ToolException(msg)
 
-    def list_files(self, directory_path: str = "", branch_name: str = None) -> str:
+    def list_files(self, directory_path: str = "", branch_name: str = None) -> List[str]:
         """
         Recursively fetches files from a directory in the repo.
 
@@ -398,12 +405,12 @@ class ReposApiWrapper(BaseCodeToolApiWrapper):
             branch_name (str): The name of the branch where the files to be received.
 
         Returns:
-            str: List of file paths, or an error message.
+            List[str]: List of file paths, or an error message.
         """
         self.active_branch = branch_name if branch_name else self.active_branch
         return self._get_files(
-
-
+            path=directory_path,
+            branch=self.active_branch if self.active_branch else self.base_branch,
         )
 
     def parse_pull_request_comments(
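The net effect of these hunks is that `_get_files` and `list_files` now return a real list instead of a stringified result. A minimal sketch of the new call shape; `wrapper` is assumed to be an already-configured `ReposApiWrapper` (its construction is omitted):

```python
# `wrapper` is an already-configured ReposApiWrapper (configuration omitted).
files = wrapper.list_files(directory_path="src", branch_name="main")

# Previously a str; now a List[str] that callers can iterate directly.
for path in files:
    print(path)
```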
alita_sdk/tools/aws/delta_lake/__init__.py

@@ -0,0 +1,136 @@
+
+from functools import lru_cache
+from typing import List, Optional, Type
+
+from langchain_core.tools import BaseTool, BaseToolkit
+from pydantic import BaseModel, Field, SecretStr, computed_field, field_validator
+
+from ...utils import TOOLKIT_SPLITTER, clean_string, get_max_toolkit_length
+from .api_wrapper import DeltaLakeApiWrapper
+from .tool import DeltaLakeAction
+
+name = "delta_lake"
+
+@lru_cache(maxsize=1)
+def get_available_tools() -> dict[str, dict]:
+    api_wrapper = DeltaLakeApiWrapper.model_construct()
+    available_tools: dict = {
+        x["name"]: x["args_schema"].model_json_schema()
+        for x in api_wrapper.get_available_tools()
+    }
+    return available_tools
+
+toolkit_max_length = lru_cache(maxsize=1)(
+    lambda: get_max_toolkit_length(get_available_tools())
+)
+
+class DeltaLakeToolkitConfig(BaseModel):
+    class Config:
+        title = name
+        json_schema_extra = {
+            "metadata": {
+                "hidden": True,
+                "label": "AWS Delta Lake",
+                "icon_url": "delta-lake.svg",
+                "sections": {
+                    "auth": {
+                        "required": False,
+                        "subsections": [
+                            {"name": "AWS Access Key ID", "fields": ["aws_access_key_id"]},
+                            {"name": "AWS Secret Access Key", "fields": ["aws_secret_access_key"]},
+                            {"name": "AWS Session Token", "fields": ["aws_session_token"]},
+                            {"name": "AWS Region", "fields": ["aws_region"]},
+                        ],
+                    },
+                    "connection": {
+                        "required": False,
+                        "subsections": [
+                            {"name": "Delta Lake S3 Path", "fields": ["s3_path"]},
+                            {"name": "Delta Lake Table Path", "fields": ["table_path"]},
+                        ],
+                    },
+                },
+            }
+        }
+
+    aws_access_key_id: Optional[SecretStr] = Field(default=None, description="AWS access key ID", json_schema_extra={"secret": True, "configuration": True})
+    aws_secret_access_key: Optional[SecretStr] = Field(default=None, description="AWS secret access key", json_schema_extra={"secret": True, "configuration": True})
+    aws_session_token: Optional[SecretStr] = Field(default=None, description="AWS session token (optional)", json_schema_extra={"secret": True, "configuration": True})
+    aws_region: Optional[str] = Field(default=None, description="AWS region for Delta Lake storage", json_schema_extra={"configuration": True})
+    s3_path: Optional[str] = Field(default=None, description="S3 path to Delta Lake data (e.g., s3://bucket/path)", json_schema_extra={"configuration": True})
+    table_path: Optional[str] = Field(default=None, description="Delta Lake table path (if not using s3_path)", json_schema_extra={"configuration": True})
+    selected_tools: List[str] = Field(default=[], description="Selected tools", json_schema_extra={"args_schemas": get_available_tools()})
+
+    @field_validator("selected_tools", mode="before", check_fields=False)
+    @classmethod
+    def selected_tools_validator(cls, value: List[str]) -> list[str]:
+        return [i for i in value if i in get_available_tools()]
+
+def _get_toolkit(tool) -> BaseToolkit:
+    return DeltaLakeToolkit().get_toolkit(
+        selected_tools=tool["settings"].get("selected_tools", []),
+        aws_access_key_id=tool["settings"].get("aws_access_key_id", None),
+        aws_secret_access_key=tool["settings"].get("aws_secret_access_key", None),
+        aws_session_token=tool["settings"].get("aws_session_token", None),
+        aws_region=tool["settings"].get("aws_region", None),
+        s3_path=tool["settings"].get("s3_path", None),
+        table_path=tool["settings"].get("table_path", None),
+        toolkit_name=tool.get("toolkit_name"),
+    )
+
+def get_toolkit():
+    return DeltaLakeToolkit.toolkit_config_schema()
+
+def get_tools(tool):
+    return _get_toolkit(tool).get_tools()
+
+class DeltaLakeToolkit(BaseToolkit):
+    tools: List[BaseTool] = []
+    api_wrapper: Optional[DeltaLakeApiWrapper] = Field(default_factory=DeltaLakeApiWrapper.model_construct)
+    toolkit_name: Optional[str] = None
+
+    @computed_field
+    @property
+    def tool_prefix(self) -> str:
+        return (
+            clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER
+            if self.toolkit_name
+            else ""
+        )
+
+    @computed_field
+    @property
+    def available_tools(self) -> List[dict]:
+        return self.api_wrapper.get_available_tools()
+
+    @staticmethod
+    def toolkit_config_schema() -> Type[BaseModel]:
+        return DeltaLakeToolkitConfig
+
+    @classmethod
+    def get_toolkit(
+        cls,
+        selected_tools: list[str] | None = None,
+        toolkit_name: Optional[str] = None,
+        **kwargs,
+    ) -> "DeltaLakeToolkit":
+        delta_lake_api_wrapper = DeltaLakeApiWrapper(**kwargs)
+        instance = cls(
+            tools=[], api_wrapper=delta_lake_api_wrapper, toolkit_name=toolkit_name
+        )
+        if selected_tools:
+            selected_tools = set(selected_tools)
+            for t in instance.available_tools:
+                if t["name"] in selected_tools:
+                    instance.tools.append(
+                        DeltaLakeAction(
+                            api_wrapper=instance.api_wrapper,
+                            name=instance.tool_prefix + t["name"],
+                            description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" + t["description"],
+                            args_schema=t["args_schema"],
+                        )
+                    )
+        return instance
+
+    def get_tools(self):
+        return self.tools
alita_sdk/tools/aws/delta_lake/api_wrapper.py

@@ -0,0 +1,220 @@
+import functools
+import json
+import logging
+from typing import Any, List, Optional
+
+from deltalake import DeltaTable
+from langchain_core.tools import ToolException
+from pydantic import (
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+    model_validator,
+)
+from pydantic_core.core_schema import ValidationInfo
+from ...elitea_base import BaseToolApiWrapper
+from .schemas import ArgsSchema
+
+
+def process_output(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            result = func(self, *args, **kwargs)
+            if isinstance(result, Exception):
+                return ToolException(str(result))
+            if isinstance(result, (dict, list)):
+                return json.dumps(result, default=str)
+            return str(result)
+        except Exception as e:
+            logging.error(f"Error in '{func.__name__}': {str(e)}")
+            return ToolException(str(e))
+    return wrapper
+
+
+class DeltaLakeApiWrapper(BaseToolApiWrapper):
+    """
+    API Wrapper for AWS Delta Lake. Handles authentication, querying, and utility methods.
+    """
+    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)
+
+    aws_access_key_id: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"})
+    aws_secret_access_key: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"})
+    aws_session_token: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"})
+    aws_region: Optional[str] = Field(default=None, json_schema_extra={"env_key": "AWS_REGION"})
+    s3_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"})
+    table_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"})
+    _delta_table: Optional[DeltaTable] = PrivateAttr(default=None)
+
+    @classmethod
+    def model_construct(cls, *args, **kwargs):
+        klass = super().model_construct(*args, **kwargs)
+        klass._delta_table = None
+        return klass
+
+    @field_validator(
+        "aws_access_key_id",
+        "aws_secret_access_key",
+        "aws_session_token",
+        "aws_region",
+        "s3_path",
+        "table_path",
+        mode="before",
+        check_fields=False,
+    )
+    @classmethod
+    def set_from_values_or_env(cls, value, info: ValidationInfo):
+        if value is None:
+            if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra:
+                if env_key := json_schema_extra.get("env_key"):
+                    try:
+                        from langchain_core.utils import get_from_env
+                        return get_from_env(
+                            key=info.field_name,
+                            env_key=env_key,
+                            default=cls.model_fields[info.field_name].default,
+                        )
+                    except Exception:
+                        return None
+        return value
+
+    @model_validator(mode="after")
+    def validate_auth(self) -> "DeltaLakeApiWrapper":
+        if not (self.aws_access_key_id and self.aws_secret_access_key and self.aws_region):
+            raise ValueError("You must provide AWS credentials and region.")
+        if not (self.s3_path or self.table_path):
+            raise ValueError("You must provide either s3_path or table_path.")
+        return self
+
+    @property
+    def delta_table(self) -> DeltaTable:
+        if not self._delta_table:
+            path = self.table_path or self.s3_path
+            if not path:
+                raise ToolException("Delta Lake table path (table_path or s3_path) must be specified.")
+            try:
+                storage_options = {
+                    "AWS_ACCESS_KEY_ID": self.aws_access_key_id.get_secret_value() if self.aws_access_key_id else None,
+                    "AWS_SECRET_ACCESS_KEY": self.aws_secret_access_key.get_secret_value() if self.aws_secret_access_key else None,
+                    "AWS_REGION": self.aws_region,
+                }
+                if self.aws_session_token:
+                    storage_options["AWS_SESSION_TOKEN"] = self.aws_session_token.get_secret_value()
+                storage_options = {k: v for k, v in storage_options.items() if v is not None}
+                self._delta_table = DeltaTable(path, storage_options=storage_options)
+            except Exception as e:
+                raise ToolException(f"Error initializing DeltaTable: {e}")
+        return self._delta_table
+
+    @process_output
+    def query_table(self, query: Optional[str] = None, columns: Optional[List[str]] = None, filters: Optional[dict] = None) -> List[dict]:
+        """
+        Query Delta Lake table. Supports pandas-like filtering, column selection, and SQL-like queries (via pandas.DataFrame.query).
+        Args:
+            query: SQL-like query string (pandas.DataFrame.query syntax)
+            columns: List of columns to select
+            filters: Dict of column:value pairs for pandas-like filtering
+        Returns:
+            List of dicts representing rows
+        """
+        dt = self.delta_table
+        df = dt.to_pandas()
+        if filters:
+            for col, val in filters.items():
+                df = df[df[col] == val]
+        if query:
+            try:
+                df = df.query(query)
+            except Exception as e:
+                raise ToolException(f"Error in query param: {e}")
+        if columns:
+            df = df[columns]
+        return df.to_dict(orient="records")
+
+    @process_output
+    def vector_search(self, embedding: List[float], k: int = 5, embedding_column: str = "embedding") -> List[dict]:
+        """
+        Perform a vector similarity search on the Delta Lake table.
+        Args:
+            embedding: Query embedding vector.
+            k: Number of top results to return.
+            embedding_column: Name of the column containing embeddings.
+        Returns:
+            List of dicts for top k most similar rows.
+        """
+        import numpy as np
+
+        dt = self.delta_table
+        df = dt.to_pandas()
+        if embedding_column not in df.columns:
+            raise ToolException(f"Embedding column '{embedding_column}' not found in table.")
+
+        # Filter out rows with missing embeddings
+        df = df[df[embedding_column].notnull()]
+        if df.empty:
+            return []
+        # Convert embeddings to numpy arrays
+        emb_matrix = np.array(df[embedding_column].tolist())
+        query_vec = np.array(embedding)
+
+        # Normalize for cosine similarity
+        emb_matrix_norm = emb_matrix / np.linalg.norm(emb_matrix, axis=1, keepdims=True)
+        query_vec_norm = query_vec / np.linalg.norm(query_vec)
+        similarities = np.dot(emb_matrix_norm, query_vec_norm)
+
+        # Get top k indices
+        top_k_idx = np.argsort(similarities)[-k:][::-1]
+        top_rows = df.iloc[top_k_idx]
+        return top_rows.to_dict(orient="records")
+
+    @process_output
+    def get_table_schema(self) -> str:
+        dt = self.delta_table
+        return dt.schema().to_pyarrow().to_string()
+
+    def get_available_tools(self) -> List[dict]:
+        return [
+            {
+                "name": "query_table",
+                "description": self.query_table.__doc__,
+                "args_schema": ArgsSchema.QueryTableArgs.value,
+                "ref": self.query_table,
+            },
+            {
+                "name": "vector_search",
+                "description": self.vector_search.__doc__,
+                "args_schema": ArgsSchema.VectorSearchArgs.value,
+                "ref": self.vector_search,
+            },
+            {
+                "name": "get_table_schema",
+                "description": self.get_table_schema.__doc__,
+                "args_schema": ArgsSchema.NoInput.value,
+                "ref": self.get_table_schema,
+            },
+        ]
+
+    def run(self, name: str, *args: Any, **kwargs: Any):
+        for tool in self.get_available_tools():
+            if tool["name"] == name:
+                if len(args) == 1 and isinstance(args[0], dict) and not kwargs:
+                    kwargs = args[0]
+                    args = ()
+                try:
+                    return tool["ref"](*args, **kwargs)
+                except TypeError as e:
+                    if kwargs and not args:
+                        try:
+                            return tool["ref"](**kwargs)
+                        except TypeError:
+                            raise ValueError(
+                                f"Argument mismatch for tool '{name}'. Error: {e}"
+                            ) from e
+                    else:
+                        raise ValueError(
+                            f"Argument mismatch for tool '{name}'. Error: {e}"
+                        ) from e
+        else:
+            raise ValueError(f"Unknown tool name: {name}")
alita_sdk/tools/aws/delta_lake/schemas.py

@@ -0,0 +1,20 @@
+
+from enum import Enum
+from typing import List, Optional
+
+from pydantic import Field, create_model
+
+class ArgsSchema(Enum):
+    NoInput = create_model("NoInput")
+    QueryTableArgs = create_model(
+        "QueryTableArgs",
+        query=(Optional[str], Field(default=None, description="SQL query to execute on Delta Lake table. If None, returns all data.")),
+        columns=(Optional[List[str]], Field(default=None, description="List of columns to select.")),
+        filters=(Optional[dict], Field(default=None, description="Dict of column:value pairs for pandas-like filtering.")),
+    )
+    VectorSearchArgs = create_model(
+        "VectorSearchArgs",
+        embedding=(List[float], Field(description="Embedding vector for similarity search.")),
+        k=(int, Field(default=5, description="Number of top results to return.")),
+        embedding_column=(Optional[str], Field(default="embedding", description="Name of the column containing embeddings.")),
+    )
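Since `create_model` produces ordinary pydantic models, tool inputs can be validated independently of dispatch. A small sketch:

```python
# Each enum member's .value is a generated pydantic model class.
args = ArgsSchema.QueryTableArgs.value(query="price > 100", columns=["id", "price"])
print(args.model_dump())
# {'query': 'price > 100', 'columns': ['id', 'price'], 'filters': None}
```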
alita_sdk/tools/aws/delta_lake/tool.py

@@ -0,0 +1,35 @@
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from pydantic import BaseModel, field_validator, Field
+from langchain_core.tools import BaseTool
+from traceback import format_exc
+from .api_wrapper import DeltaLakeApiWrapper
+
+
+class DeltaLakeAction(BaseTool):
+    """Tool for interacting with the Delta Lake API on AWS."""
+
+    api_wrapper: DeltaLakeApiWrapper = Field(default_factory=DeltaLakeApiWrapper)
+    name: str
+    description: str = ""
+    args_schema: Optional[Type[BaseModel]] = None
+
+    @field_validator('name', mode='before')
+    @classmethod
+    def remove_spaces(cls, v):
+        return v.replace(' ', '')
+
+    def _run(
+        self,
+        *args,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+        **kwargs,
+    ) -> str:
+        """Use the Delta Lake API to run an operation."""
+        try:
+            # Use the tool name to dispatch to the correct API wrapper method
+            return self.api_wrapper.run(self.name, *args, **kwargs)
+        except Exception as e:
+            return f"Error: {format_exc()}"
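A sketch of constructing one action by hand, which the toolkit normally does in `get_toolkit`; `wrapper` is assumed to be a configured `DeltaLakeApiWrapper`:

```python
# `wrapper` is an already-configured DeltaLakeApiWrapper (see api_wrapper.py).
action = DeltaLakeAction(
    api_wrapper=wrapper,
    name="get_table_schema",
    description="Return the schema of the configured Delta table.",
)

# BaseTool.run validates the input, then _run dispatches by tool name.
print(action.run({}))
```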
alita_sdk/tools/bitbucket/api_wrapper.py

@@ -5,7 +5,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from langchain_core.tools import ToolException
-from pydantic import
+from pydantic import model_validator, SecretStr
 from .bitbucket_constants import create_pr_data
 from .cloud_api_wrapper import BitbucketCloudApi, BitbucketServerApi
 from pydantic.fields import PrivateAttr
@@ -172,26 +172,26 @@ class BitbucketAPIWrapper(BaseCodeToolApiWrapper):
         """
         return self._bitbucket.get_pull_requests()
 
-    def get_pull_request(self, pr_id: str) -> Any:
+    def get_pull_request(self, pr_id: str) -> Dict[str, Any]:
         """
         Get details of a pull request
         Parameters:
             pr_id(str): the pull request ID
         Returns:
-
+            dict: Details of the pull request as a dictionary
         """
         try:
             return self._bitbucket.get_pull_request(pr_id=pr_id)
         except Exception as e:
             return ToolException(f"Can't get pull request `{pr_id}` due to error:\n{str(e)}")
 
-    def get_pull_requests_changes(self, pr_id: str) -> Any:
+    def get_pull_requests_changes(self, pr_id: str) -> Dict[str, Any]:
         """
         Get changes of a pull request
         Parameters:
             pr_id(str): the pull request ID
         Returns:
-
+            dict: Changes of the pull request as a dictionary
         """
         try:
             return self._bitbucket.get_pull_requests_changes(pr_id=pr_id)