alita-sdk 0.3.203__py3-none-any.whl → 0.3.205__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- alita_sdk/runtime/clients/client.py +3 -3
- alita_sdk/runtime/tools/vectorstore.py +143 -13
- alita_sdk/tools/__init__.py +2 -0
- alita_sdk/tools/aws/__init__.py +7 -0
- alita_sdk/tools/aws/delta_lake/__init__.py +136 -0
- alita_sdk/tools/aws/delta_lake/api_wrapper.py +220 -0
- alita_sdk/tools/aws/delta_lake/schemas.py +20 -0
- alita_sdk/tools/aws/delta_lake/tool.py +35 -0
- alita_sdk/tools/elitea_base.py +49 -4
- alita_sdk/tools/google/__init__.py +7 -0
- alita_sdk/tools/google/bigquery/__init__.py +154 -0
- alita_sdk/tools/google/bigquery/api_wrapper.py +502 -0
- alita_sdk/tools/google/bigquery/schemas.py +102 -0
- alita_sdk/tools/google/bigquery/tool.py +34 -0
- alita_sdk/tools/postman/api_wrapper.py +15 -8
- alita_sdk/tools/sharepoint/api_wrapper.py +60 -4
- alita_sdk/tools/testrail/__init__.py +9 -1
- alita_sdk/tools/testrail/api_wrapper.py +132 -6
- alita_sdk/tools/zephyr_scale/api_wrapper.py +271 -22
- {alita_sdk-0.3.203.dist-info → alita_sdk-0.3.205.dist-info}/METADATA +3 -1
- {alita_sdk-0.3.203.dist-info → alita_sdk-0.3.205.dist-info}/RECORD +24 -14
- {alita_sdk-0.3.203.dist-info → alita_sdk-0.3.205.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.203.dist-info → alita_sdk-0.3.205.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.203.dist-info → alita_sdk-0.3.205.dist-info}/top_level.txt +0 -0
alita_sdk/tools/google/bigquery/api_wrapper.py (new file)

@@ -0,0 +1,502 @@

```python
import functools
import json
import logging
from typing import Any, Dict, List, Optional, Union

from google.cloud import bigquery
from langchain_core.tools import ToolException
from pydantic import (
    ConfigDict,
    Field,
    PrivateAttr,
    SecretStr,
    field_validator,
    model_validator,
)
from pydantic_core.core_schema import ValidationInfo

from ...elitea_base import BaseToolApiWrapper
from .schemas import ArgsSchema


def process_output(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            result = func(self, *args, **kwargs)
            if isinstance(result, Exception):
                return ToolException(str(result))
            if isinstance(result, (dict, list)):
                return json.dumps(result, default=str)
            return str(result)
        except Exception as e:
            logging.error(f"Error in '{func.__name__}': {str(e)}")
            return ToolException(str(e))

    return wrapper


class BigQueryApiWrapper(BaseToolApiWrapper):
    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)

    api_key: Optional[SecretStr] = Field(
        default=None, json_schema_extra={"env_key": "BIGQUERY_API_KEY"}
    )
    project: Optional[str] = Field(
        default=None, json_schema_extra={"env_key": "BIGQUERY_PROJECT"}
    )
    location: Optional[str] = Field(
        default=None, json_schema_extra={"env_key": "BIGQUERY_LOCATION"}
    )
    dataset: Optional[str] = Field(
        default=None, json_schema_extra={"env_key": "BIGQUERY_DATASET"}
    )
    table: Optional[str] = Field(
        default=None, json_schema_extra={"env_key": "BIGQUERY_TABLE"}
    )
    embedding: Optional[Any] = None
    _client: Optional[bigquery.Client] = PrivateAttr(default=None)

    @classmethod
    def model_construct(cls, *args, **kwargs):
        klass = super().model_construct(*args, **kwargs)
        klass._client = None
        return klass

    @field_validator(
        "api_key",
        "project",
        "location",
        "dataset",
        "table",
        mode="before",
        check_fields=False,
    )
    @classmethod
    def set_from_values_or_env(cls, value, info: ValidationInfo):
        if value is None:
            if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra:
                if env_key := json_schema_extra.get("env_key"):
                    try:
                        from langchain_core.utils import get_from_env

                        return get_from_env(
                            key=info.field_name,
                            env_key=env_key,
                            default=cls.model_fields[info.field_name].default,
                        )
                    except Exception:
                        return None
        return value

    @model_validator(mode="after")
    def validate_auth(self) -> "BigQueryApiWrapper":
        if not self.api_key:
            raise ValueError("You must provide a BigQuery API key.")
        return self

    @property
    def bigquery_client(self) -> bigquery.Client:
        if not self._client:
            api_key = self.api_key.get_secret_value() if self.api_key else None
            if not api_key:
                raise ToolException("BigQuery API key is not set.")
            try:
                api_key_dict = json.loads(api_key)
                credentials = bigquery.Client.from_service_account_info(
                    api_key_dict
                )._credentials
                self._client = bigquery.Client(
                    credentials=credentials,
                    project=self.project,
                    location=self.location,
                )
            except Exception as e:
                raise ToolException(f"Error initializing GCP credentials: {str(e)}")
        return self._client

    def _get_table_id(self):
        if not (self.project and self.dataset and self.table):
            raise ToolException("Project, dataset, and table must be specified.")
        return f"{self.project}.{self.dataset}.{self.table}"

    def _create_filters(
        self, filter: Optional[Union[Dict[str, Any], str]] = None
    ) -> str:
        if filter:
            if isinstance(filter, dict):
                filter_expressions = []
                for k, v in filter.items():
                    if isinstance(v, (int, float)):
                        filter_expressions.append(f"{k} = {v}")
                    else:
                        filter_expressions.append(f"{k} = '{v}'")
                return " AND ".join(filter_expressions)
            else:
                return filter
        return "TRUE"

    def job_stats(self, job_id: str) -> Dict:
        return self.bigquery_client.get_job(job_id)._properties.get("statistics", {})

    def create_vector_index(self):
        table_id = self._get_table_id()
        index_name = f"{self.table}_langchain_index"
        sql = f"""
        CREATE VECTOR INDEX IF NOT EXISTS
        `{index_name}`
        ON `{table_id}`
        (embedding)
        OPTIONS(distance_type="EUCLIDEAN", index_type="IVF")
        """
        try:
            self.bigquery_client.query(sql).result()
            return f"Vector index '{index_name}' created or already exists."
        except Exception as ex:
            logging.error(f"Vector index creation failed: {ex}")
            return ToolException(f"Vector index creation failed: {ex}")

    @process_output
    def get_documents(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[Union[Dict[str, Any], str]] = None,
    ):
        table_id = self._get_table_id()
        job_config = None
        id_expr = "TRUE"
        if ids:
            job_config = bigquery.QueryJobConfig(
                query_parameters=[bigquery.ArrayQueryParameter("ids", "STRING", ids)]
            )
            id_expr = "doc_id IN UNNEST(@ids)"
        where_filter_expr = self._create_filters(filter)
        query = f"SELECT * FROM `{table_id}` WHERE {id_expr} AND {where_filter_expr}"
        job = self.bigquery_client.query(query, job_config=job_config)
        return [dict(row) for row in job]

    @process_output
    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Union[Dict[str, Any], str]] = None,
    ):
        """Search for top `k` docs most similar to input query using vector similarity search."""
        if not hasattr(self, "embedding") or self.embedding is None:
            raise ToolException("Embedding model is not set on the wrapper.")
        embedding_vector = self.embedding.embed_query(query)
        # Prepare the vector search query
        table_id = self._get_table_id()
        where_filter_expr = "TRUE"
        if filter:
            if isinstance(filter, dict):
                filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()]
                where_filter_expr = " AND ".join(filter_expressions)
            else:
                where_filter_expr = filter
        # BigQuery vector search SQL (using VECTOR_SEARCH if available)
        sql = f"""
        SELECT *,
            VECTOR_DISTANCE(embedding, @query_embedding) AS score
        FROM `{table_id}`
        WHERE {where_filter_expr}
        ORDER BY score ASC
        LIMIT {k}
        """
        job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ArrayQueryParameter(
                    "query_embedding", "FLOAT64", embedding_vector
                )
            ]
        )
        job = self.bigquery_client.query(sql, job_config=job_config)
        return [dict(row) for row in job]

    @process_output
    def batch_search(
        self,
        queries: Optional[List[str]] = None,
        embeddings: Optional[List[List[float]]] = None,
        k: int = 5,
        filter: Optional[Union[Dict[str, Any], str]] = None,
    ):
        """Batch vector similarity search. Accepts either queries (to embed) or embeddings."""
        if queries is not None and embeddings is not None:
            raise ToolException("Provide only one of 'queries' or 'embeddings'.")
        if queries is not None:
            if not hasattr(self, "embedding") or self.embedding is None:
                raise ToolException("Embedding model is not set on the wrapper.")
            embeddings = [self.embedding.embed_query(q) for q in queries]
        if not embeddings:
            raise ToolException("No embeddings or queries provided.")
        table_id = self._get_table_id()
        where_filter_expr = "TRUE"
        if filter:
            if isinstance(filter, dict):
                filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()]
                where_filter_expr = " AND ".join(filter_expressions)
            else:
                where_filter_expr = filter
        results = []
        for emb in embeddings:
            sql = f"""
            SELECT *,
                VECTOR_DISTANCE(embedding, @query_embedding) AS score
            FROM `{table_id}`
            WHERE {where_filter_expr}
            ORDER BY score ASC
            LIMIT {k}
            """
            job_config = bigquery.QueryJobConfig(
                query_parameters=[
                    bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", emb)
                ]
            )
            job = self.bigquery_client.query(sql, job_config=job_config)
            results.append([dict(row) for row in job])
        return results

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 5, **kwargs
    ) -> List[Dict]:
        """Return docs most similar to embedding vector."""
        table_id = self._get_table_id()
        sql = f"""
        SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score
        FROM `{table_id}`
        ORDER BY score ASC
        LIMIT {k}
        """
        job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding)
            ]
        )
        job = self.bigquery_client.query(sql, job_config=job_config)
        return [self._row_to_document(row) for row in job]

    def similarity_search_by_vector_with_score(
        self,
        embedding: List[float],
        filter: Optional[Union[Dict[str, Any], str]] = None,
        k: int = 5,
        **kwargs,
    ) -> List[Dict]:
        """Return docs most similar to embedding vector with scores."""
        table_id = self._get_table_id()
        where_filter_expr = self._create_filters(filter)
        sql = f"""
        SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score
        FROM `{table_id}`
        WHERE {where_filter_expr}
        ORDER BY score ASC
        LIMIT {k}
        """
        job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding)
            ]
        )
        job = self.bigquery_client.query(sql, job_config=job_config)
        return [self._row_to_document(row) for row in job]

    def similarity_search_with_score(
        self,
        query: str,
        filter: Optional[Union[Dict[str, Any], str]] = None,
        k: int = 5,
        **kwargs,
    ) -> List[Dict]:
        """Search for top `k` docs most similar to input query, returns both docs and scores."""
        embedding = self.embedding.embed_query(query)
        return self.similarity_search_by_vector_with_score(
            embedding, filter=filter, k=k, **kwargs
        )

    def similarity_search_by_vectors(
        self,
        embeddings: List[List[float]],
        filter: Optional[Union[Dict[str, Any], str]] = None,
        k: int = 5,
        with_scores: bool = False,
        with_embeddings: bool = False,
        **kwargs,
    ) -> Any:
        """Core similarity search function. Handles a list of embedding vectors, optionally returning scores and embeddings."""
        results = []
        for emb in embeddings:
            docs = self.similarity_search_by_vector_with_score(
                emb, filter=filter, k=k, **kwargs
            )
            if not with_scores and not with_embeddings:
                docs = [d for d in docs]
            elif not with_embeddings:
                docs = [{**d, "score": d.get("score")} for d in docs]
            elif not with_scores:
                docs = [{**d, "embedding": emb} for d in docs]
            results.append(docs)
        return results

    def execute(self, method: str, *args, **kwargs):
        """
        Universal method to call any method from google.cloud.bigquery.Client.
        Args:
            method: Name of the method to call on the BigQuery client.
            *args: Positional arguments for the method.
            **kwargs: Keyword arguments for the method.
        Returns:
            The result of the called method.
        Raises:
            ToolException: If the client is not initialized or method does not exist.
        """
        if not self._client:
            raise ToolException("BigQuery client is not initialized.")
        if not hasattr(self._client, method):
            raise ToolException(f"BigQuery client has no method '{method}'")
        func = getattr(self._client, method)
        try:
            result = func(*args, **kwargs)
            return result
        except Exception as e:
            logging.error(f"Error executing '{method}': {e}")
            raise ToolException(f"Error executing '{method}': {e}")

    @process_output
    def create_delta_lake_table(
        self,
        table_name: str,
        dataset: Optional[str] = None,
        connection_id: str = None,
        source_uris: list = None,
        autodetect: bool = True,
        project: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a Delta Lake external table in BigQuery using the google.cloud.bigquery library.
        Args:
            table_name: Name of the Delta Lake table to create in BigQuery.
            dataset: BigQuery dataset to contain the table (defaults to self.dataset).
            connection_id: Fully qualified connection ID (project.region.connection_id).
            source_uris: List of GCS URIs (prefixes) for the Delta Lake table.
            autodetect: Whether to autodetect schema (default: True).
            project: GCP project ID (defaults to self.project).
        Returns:
            API response as dict.
        """
        dataset = dataset or self.dataset
        project = project or self.project
        if not (project and dataset and table_name and connection_id and source_uris):
            raise ToolException("project, dataset, table_name, connection_id, and source_uris are required.")
        client = self.bigquery_client
        table_ref = bigquery.TableReference(
            bigquery.DatasetReference(project, dataset), table_name
        )
        external_config = bigquery.ExternalConfig("DELTA_LAKE")
        external_config.autodetect = autodetect
        external_config.source_uris = source_uris
        external_config.connection_id = connection_id
        table = bigquery.Table(table_ref)
        table.external_data_configuration = external_config
        try:
            created_table = client.create_table(table, exists_ok=True)
            return created_table.to_api_repr()
        except Exception as e:
            raise ToolException(f"Failed to create Delta Lake table: {e}")

    def get_available_tools(self) -> List[Dict[str, Any]]:
        return [
            {
                "name": "get_documents",
                "description": self.get_documents.__doc__,
                "args_schema": ArgsSchema.GetDocuments.value,
                "ref": self.get_documents,
            },
            {
                "name": "similarity_search",
                "description": self.similarity_search.__doc__,
                "args_schema": ArgsSchema.SimilaritySearch.value,
                "ref": self.similarity_search,
            },
            {
                "name": "batch_search",
                "description": self.batch_search.__doc__,
                "args_schema": ArgsSchema.BatchSearch.value,
                "ref": self.batch_search,
            },
            {
                "name": "create_vector_index",
                "description": self.create_vector_index.__doc__,
                "args_schema": ArgsSchema.NoInput.value,
                "ref": self.create_vector_index,
            },
            {
                "name": "job_stats",
                "description": self.job_stats.__doc__,
                "args_schema": ArgsSchema.JobStatsArgs.value,
                "ref": self.job_stats,
            },
            {
                "name": "similarity_search_by_vector",
                "description": self.similarity_search_by_vector.__doc__,
                "args_schema": ArgsSchema.SimilaritySearchByVectorArgs.value,
                "ref": self.similarity_search_by_vector,
            },
            {
                "name": "similarity_search_by_vector_with_score",
                "description": self.similarity_search_by_vector_with_score.__doc__,
                "args_schema": ArgsSchema.SimilaritySearchByVectorWithScoreArgs.value,
                "ref": self.similarity_search_by_vector_with_score,
            },
            {
                "name": "similarity_search_with_score",
                "description": self.similarity_search_with_score.__doc__,
                "args_schema": ArgsSchema.SimilaritySearchWithScoreArgs.value,
                "ref": self.similarity_search_with_score,
            },
            {
                "name": "similarity_search_by_vectors",
                "description": self.similarity_search_by_vectors.__doc__,
                "args_schema": ArgsSchema.SimilaritySearchByVectorsArgs.value,
                "ref": self.similarity_search_by_vectors,
            },
            {
                "name": "execute",
                "description": self.execute.__doc__,
                "args_schema": ArgsSchema.ExecuteArgs.value,
                "ref": self.execute,
            },
            {
                "name": "create_delta_lake_table",
                "description": self.create_delta_lake_table.__doc__,
                "args_schema": ArgsSchema.CreateDeltaLakeTable.value,
                "ref": self.create_delta_lake_table,
            },
        ]

    def run(self, name: str, *args: Any, **kwargs: Any):
        for tool in self.get_available_tools():
            if tool["name"] == name:
                # Handle potential dictionary input for args when only one dict is passed
                if len(args) == 1 and isinstance(args[0], dict) and not kwargs:
                    kwargs = args[0]
                    args = ()  # Clear args
                try:
                    return tool["ref"](*args, **kwargs)
                except TypeError as e:
                    # Attempt to call with kwargs only if args fail and kwargs exist
                    if kwargs and not args:
                        try:
                            return tool["ref"](**kwargs)
                        except TypeError:
                            raise ValueError(
                                f"Argument mismatch for tool '{name}'. Error: {e}"
                            ) from e
                    else:
                        raise ValueError(
                            f"Argument mismatch for tool '{name}'. Error: {e}"
                        ) from e
        else:
            raise ValueError(f"Unknown tool name: {name}")
```
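To make the wrapper's intended call pattern concrete, here is a minimal usage sketch. It assumes a service-account JSON string for `api_key` (which `bigquery_client` parses via `bigquery.Client.from_service_account_info`) and substitutes a stub embedder for a real embedding model; the credentials, project, dataset, and table names are placeholders, not values from this release.

```python
import json
from alita_sdk.tools.google.bigquery.api_wrapper import BigQueryApiWrapper

class StubEmbedding:
    """Stand-in for a real embedding model; the wrapper only needs embed_query()."""
    def embed_query(self, text: str) -> list[float]:
        return [0.0] * 768  # fixed-size dummy vector

# Hypothetical service-account credentials (normally loaded from a secret store).
creds_json = json.dumps({"type": "service_account", "project_id": "my-project"})

wrapper = BigQueryApiWrapper(
    api_key=creds_json,       # also settable via BIGQUERY_API_KEY
    project="my-project",     # BIGQUERY_PROJECT
    dataset="my_dataset",     # BIGQUERY_DATASET
    table="my_table",         # BIGQUERY_TABLE
    embedding=StubEmbedding(),
)

# run() dispatches by tool name; a single dict argument is unpacked into kwargs.
docs = wrapper.run("get_documents", {"ids": ["doc-1"], "filter": {"source": "wiki"}})
hits = wrapper.run("similarity_search", {"query": "quarterly revenue", "k": 3})
```

One thing worth noting: `similarity_search_by_vector` and `similarity_search_by_vector_with_score` call a `_row_to_document` helper that does not appear anywhere in this file, so those two tools would fail at runtime unless `BaseToolApiWrapper` supplies that method.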
alita_sdk/tools/google/bigquery/schemas.py (new file)

@@ -0,0 +1,102 @@

```python
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import Field, create_model


class ArgsSchema(Enum):
    NoInput = create_model("NoInput")
    GetDocuments = create_model(
        "GetDocuments",
        ids=(
            Optional[List[str]],
            Field(default=None, description="List of document IDs to retrieve."),
        ),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
    )
    SimilaritySearch = create_model(
        "SimilaritySearch",
        query=(str, Field(description="Text query to search for similar documents.")),
        k=(int, Field(default=5, description="Number of top results to return.")),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
    )
    BatchSearch = create_model(
        "BatchSearch",
        queries=(
            Optional[List[str]],
            Field(default=None, description="List of text queries."),
        ),
        embeddings=(
            Optional[List[List[float]]],
            Field(default=None, description="List of embedding vectors."),
        ),
        k=(int, Field(default=5, description="Number of top results to return.")),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
    )
    JobStatsArgs = create_model(
        "JobStatsArgs", job_id=(str, Field(description="BigQuery job ID."))
    )
    SimilaritySearchByVectorArgs = create_model(
        "SimilaritySearchByVectorArgs",
        embedding=(List[float], Field(description="Embedding vector.")),
        k=(int, Field(default=5, description="Number of top results to return.")),
    )
    SimilaritySearchByVectorWithScoreArgs = create_model(
        "SimilaritySearchByVectorWithScoreArgs",
        embedding=(List[float], Field(description="Embedding vector.")),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
        k=(int, Field(default=5, description="Number of top results to return.")),
    )
    SimilaritySearchWithScoreArgs = create_model(
        "SimilaritySearchWithScoreArgs",
        query=(str, Field(description="Text query.")),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
        k=(int, Field(default=5, description="Number of top results to return.")),
    )
    SimilaritySearchByVectorsArgs = create_model(
        "SimilaritySearchByVectorsArgs",
        embeddings=(List[List[float]], Field(description="List of embedding vectors.")),
        filter=(
            Optional[Union[Dict[str, Any], str]],
            Field(default=None, description="Filter as dict or SQL WHERE clause."),
        ),
        k=(int, Field(default=5, description="Number of top results to return.")),
        with_scores=(bool, Field(default=False)),
        with_embeddings=(bool, Field(default=False)),
    )
    ExecuteArgs = create_model(
        "ExecuteArgs",
        method=(str, Field(description="Name of the BigQuery client method to call.")),
        args=(
            Optional[List[Any]],
            Field(default=None, description="Positional arguments for the method."),
        ),
        kwargs=(
            Optional[Dict[str, Any]],
            Field(default=None, description="Keyword arguments for the method."),
        ),
    )
    CreateDeltaLakeTable = create_model(
        "CreateDeltaLakeTable",
        table_name=(str, Field(description="Name of the Delta Lake table to create in BigQuery.")),
        dataset=(Optional[str], Field(default=None, description="BigQuery dataset to contain the table (defaults to self.dataset).")),
        connection_id=(str, Field(description="Fully qualified connection ID (project.region.connection_id).")),
        source_uris=(list, Field(description="List of GCS URIs (prefixes) for the Delta Lake table.")),
        autodetect=(bool, Field(default=True, description="Whether to autodetect schema (default: True).")),
        project=(Optional[str], Field(default=None, description="GCP project ID (defaults to self.project).")),
    )
```
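Because each `ArgsSchema` member wraps a model class built by `create_model`, a consumer can validate tool arguments before dispatch. A minimal sketch (the argument values are made up):

```python
from alita_sdk.tools.google.bigquery.schemas import ArgsSchema

# Each enum value is a dynamically built pydantic model class.
SearchArgs = ArgsSchema.SimilaritySearch.value
args = SearchArgs(query="quarterly revenue", k=3, filter={"source": "wiki"})
print(args.model_dump())
# {'query': 'quarterly revenue', 'k': 3, 'filter': {'source': 'wiki'}}
```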
alita_sdk/tools/google/bigquery/tool.py (new file)

@@ -0,0 +1,34 @@

```python
from typing import Optional, Type

from langchain_core.callbacks import CallbackManagerForToolRun
from pydantic import BaseModel, field_validator, Field
from langchain_core.tools import BaseTool
from traceback import format_exc
from .api_wrapper import BigQueryApiWrapper


class BigQueryAction(BaseTool):
    """Tool for interacting with the BigQuery API."""

    api_wrapper: BigQueryApiWrapper = Field(default_factory=BigQueryApiWrapper)
    name: str
    mode: str = ""
    description: str = ""
    args_schema: Optional[Type[BaseModel]] = None

    @field_validator('name', mode='before')
    @classmethod
    def remove_spaces(cls, v):
        return v.replace(' ', '')

    def _run(
        self,
        *args,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs,
    ) -> str:
        """Use the BigQuery API to run an operation."""
        try:
            return self.api_wrapper.run(self.mode, *args, **kwargs)
        except Exception as e:
            return f"Error: {format_exc()}"
```
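For context, here is a sketch of how this adapter might be wired to the wrapper's tool registry. The assumption that the toolkit's `__init__` builds one `BigQueryAction` per `get_available_tools()` entry is ours, since that module's hunk is not shown in this diff; `wrapper` is a configured `BigQueryApiWrapper` as in the earlier sketch.

```python
from alita_sdk.tools.google.bigquery.tool import BigQueryAction

tools = [
    BigQueryAction(
        api_wrapper=wrapper,
        name=entry["name"],                      # spaces stripped by the field validator
        mode=entry["name"],                      # _run forwards mode to wrapper.run as the tool name
        description=entry["description"] or "",  # some tools have no docstring
        args_schema=entry["args_schema"],
    )
    for entry in wrapper.get_available_tools()
]

# Invoked through the standard LangChain BaseTool entry point:
result = tools[0].run({"ids": ["doc-1"]})
```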