alita-sdk 0.3.211__py3-none-any.whl → 0.3.213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/runtime/clients/client.py +2 -2
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +48 -24
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +47 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaImageLoader.py +103 -49
- alita_sdk/runtime/langchain/document_loaders/AlitaPDFLoader.py +63 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaPowerPointLoader.py +54 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +66 -0
- alita_sdk/runtime/langchain/document_loaders/constants.py +13 -19
- alita_sdk/runtime/langchain/document_loaders/utils.py +30 -1
- alita_sdk/runtime/toolkits/artifact.py +5 -0
- alita_sdk/runtime/tools/artifact.py +2 -4
- alita_sdk/runtime/tools/vectorstore.py +2 -1
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +13 -37
- alita_sdk/tools/ado/wiki/ado_wrapper.py +10 -39
- alita_sdk/tools/confluence/api_wrapper.py +2 -0
- alita_sdk/tools/elitea_base.py +24 -3
- alita_sdk/tools/gitlab/__init__.py +3 -2
- alita_sdk/tools/gitlab/api_wrapper.py +45 -18
- alita_sdk/tools/gitlab_org/api_wrapper.py +44 -25
- alita_sdk/tools/sharepoint/api_wrapper.py +13 -13
- alita_sdk/tools/testrail/api_wrapper.py +20 -0
- alita_sdk/tools/utils/content_parser.py +37 -162
- {alita_sdk-0.3.211.dist-info → alita_sdk-0.3.213.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.211.dist-info → alita_sdk-0.3.213.dist-info}/RECORD +27 -24
- {alita_sdk-0.3.211.dist-info → alita_sdk-0.3.213.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.211.dist-info → alita_sdk-0.3.213.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.211.dist-info → alita_sdk-0.3.213.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@ import re
|
|
2
2
|
import string
|
3
3
|
from gensim.parsing import remove_stopwords
|
4
4
|
|
5
|
-
from ..tools.
|
5
|
+
from ..tools.utils import bytes_to_base64
|
6
|
+
from langchain_core.messages import HumanMessage
|
6
7
|
|
7
8
|
|
8
9
|
def cleanse_data(document: str) -> str:
|
@@ -32,3 +33,31 @@ def cleanse_data(document: str) -> str:
|
|
32
33
|
# document = document.replace(kw, "")
|
33
34
|
|
34
35
|
return document
|
36
|
+
|
37
|
+
def perform_llm_prediction_for_image_bytes(image_bytes: bytes, llm, prompt: str) -> str:
|
38
|
+
"""Performs LLM prediction for image content."""
|
39
|
+
base64_string = bytes_to_base64(image_bytes)
|
40
|
+
result = llm.invoke([
|
41
|
+
HumanMessage(
|
42
|
+
content=[
|
43
|
+
{"type": "text", "text": prompt},
|
44
|
+
{
|
45
|
+
"type": "image_url",
|
46
|
+
"image_url": {"url": f"data:image/png;base64,{base64_string}"},
|
47
|
+
},
|
48
|
+
]
|
49
|
+
)
|
50
|
+
])
|
51
|
+
return result.content
|
52
|
+
|
53
|
+
def create_temp_file(file_content: bytes):
|
54
|
+
import tempfile
|
55
|
+
|
56
|
+
# Automatic cleanup with context manager
|
57
|
+
with tempfile.NamedTemporaryFile(mode='w+b', delete=True) as temp_file:
|
58
|
+
# Write data to temp file
|
59
|
+
temp_file.write(file_content)
|
60
|
+
temp_file.flush() # Ensure data is written
|
61
|
+
|
62
|
+
# Get the file path for operations
|
63
|
+
return temp_file.name
|
@@ -26,6 +26,11 @@ class ArtifactToolkit(BaseToolkit):
|
|
26
26
|
connection_string = (Optional[SecretStr], Field(description="Connection string for vectorstore",
|
27
27
|
default=None,
|
28
28
|
json_schema_extra={'secret': True})),
|
29
|
+
# embedding model settings
|
30
|
+
embedding_model=(str, Field(default="HuggingFaceEmbeddings", description="Embedding model to use")),
|
31
|
+
embedding_model_params=(dict, Field(default={"model_name": "sentence-transformers/all-MiniLM-L6-v2"},
|
32
|
+
description="Parameters for embedding model")),
|
33
|
+
|
29
34
|
__config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact", "icon_url": None}})
|
30
35
|
)
|
31
36
|
|
@@ -61,21 +61,19 @@ class ArtifactWrapper(BaseVectorStoreToolApiWrapper):
|
|
61
61
|
def create_new_bucket(self, bucket_name: str, expiration_measure = "weeks", expiration_value = 1):
|
62
62
|
return self.artifact.client.create_bucket(bucket_name, expiration_measure, expiration_value)
|
63
63
|
|
64
|
-
def _base_loader(self, **kwargs) ->
|
64
|
+
def _base_loader(self, **kwargs) -> Generator[Document, None, None]:
|
65
65
|
try:
|
66
66
|
all_files = self.list_files(self.bucket, False)
|
67
67
|
except Exception as e:
|
68
68
|
raise ToolException(f"Unable to extract files: {e}")
|
69
69
|
|
70
|
-
docs: List[Document] = []
|
71
70
|
for file in all_files['rows']:
|
72
71
|
metadata = {
|
73
72
|
("updated_on" if k == "modified" else k): str(v)
|
74
73
|
for k, v in file.items()
|
75
74
|
}
|
76
75
|
metadata['id'] = self.get_hash_from_bucket_and_file_name(self.bucket, file['name'])
|
77
|
-
|
78
|
-
return docs
|
76
|
+
yield Document(page_content="", metadata=metadata)
|
79
77
|
|
80
78
|
def get_hash_from_bucket_and_file_name(self, bucket, file_name):
|
81
79
|
hasher = hashlib.sha256()
|
@@ -197,7 +197,8 @@ class VectorStoreWrapper(BaseToolApiWrapper):
|
|
197
197
|
tool_name="_clean_collection"
|
198
198
|
)
|
199
199
|
data = self.vectoradapter.vectorstore.get(include=['metadatas'])
|
200
|
-
|
200
|
+
if data['ids']:
|
201
|
+
self.vectoradapter.vectorstore.delete(ids=data['ids'])
|
201
202
|
self._log_data(
|
202
203
|
f"Collection '{self.dataset}' has been cleaned. ",
|
203
204
|
tool_name="_clean_collection"
|
@@ -1,22 +1,21 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
|
-
|
4
|
-
|
5
|
-
from langchain_core.documents import Document
|
3
|
+
import xml.etree.ElementTree as ET
|
4
|
+
from typing import Generator, Optional
|
6
5
|
|
7
|
-
from alita_sdk.tools.elitea_base import BaseIndexParams
|
8
6
|
from azure.devops.connection import Connection
|
9
7
|
from azure.devops.v7_0.test_plan.models import TestPlanCreateParams, TestSuiteCreateParams, \
|
10
8
|
SuiteTestCaseCreateUpdateParameters
|
11
9
|
from azure.devops.v7_0.test_plan.test_plan_client import TestPlanClient
|
10
|
+
from langchain_core.documents import Document
|
12
11
|
from langchain_core.tools import ToolException
|
13
12
|
from msrest.authentication import BasicAuthentication
|
14
13
|
from pydantic import create_model, PrivateAttr, model_validator, SecretStr
|
15
14
|
from pydantic.fields import FieldInfo as Field
|
16
|
-
import xml.etree.ElementTree as ET
|
17
15
|
|
18
16
|
from ..work_item import AzureDevOpsApiWrapper
|
19
17
|
from ...elitea_base import BaseVectorStoreToolApiWrapper, extend_with_vector_tools
|
18
|
+
|
20
19
|
try:
|
21
20
|
from alita_sdk.runtime.langchain.interfaces.llm_processor import get_embeddings
|
22
21
|
except ImportError:
|
@@ -164,18 +163,6 @@ TestCasesGetModel = create_model(
|
|
164
163
|
suite_id=(int, Field(description="ID of the test suite for which test cases are requested"))
|
165
164
|
)
|
166
165
|
|
167
|
-
# Schema for indexing ADO Wiki pages into vector store
|
168
|
-
indexData = create_model(
|
169
|
-
"indexData",
|
170
|
-
__base__=BaseIndexParams,
|
171
|
-
plan_id=(int, Field(description="ID of the test plan for which test cases are requested")),
|
172
|
-
suite_ids=(list[int], Field(description="List of test suite IDs for which test cases are requested (can be empty)")),
|
173
|
-
progress_step=(Optional[int], Field(default=None, ge=0, le=100,
|
174
|
-
description="Optional step size for progress reporting during indexing")),
|
175
|
-
clean_index=(Optional[bool], Field(default=False,
|
176
|
-
description="Optional flag to enforce clean existing index before indexing new data")),
|
177
|
-
)
|
178
|
-
|
179
166
|
class TestPlanApiWrapper(BaseVectorStoreToolApiWrapper):
|
180
167
|
__test__ = False
|
181
168
|
organization_url: str
|
@@ -377,19 +364,6 @@ class TestPlanApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
377
364
|
logger.error(f"Error getting test cases: {e}")
|
378
365
|
return ToolException(f"Error getting test cases: {e}")
|
379
366
|
|
380
|
-
def index_data(self,
|
381
|
-
plan_id: str,
|
382
|
-
suite_ids: list[str] = [],
|
383
|
-
collection_suffix: str = '',
|
384
|
-
progress_step: int = None,
|
385
|
-
clean_index: bool = False
|
386
|
-
):
|
387
|
-
"""Load ADO TestCases into the vector store."""
|
388
|
-
docs = self._base_loader(plan_id, suite_ids)
|
389
|
-
embedding = get_embeddings(self.embedding_model, self.embedding_model_params)
|
390
|
-
vs = self._init_vector_store(collection_suffix, embeddings=embedding)
|
391
|
-
return vs.index_documents(docs, progress_step=progress_step, clean_index=clean_index)
|
392
|
-
|
393
367
|
def _base_loader(self, plan_id: str, suite_ids: Optional[list[str]] = []) -> Generator[Document, None, None]:
|
394
368
|
cases = []
|
395
369
|
for sid in suite_ids:
|
@@ -410,7 +384,15 @@ class TestPlanApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
410
384
|
})
|
411
385
|
|
412
386
|
def _process_document(self, document: Document) -> Generator[Document, None, None]:
|
413
|
-
|
387
|
+
if False:
|
388
|
+
yield # Unreachable, but keeps the function a generator
|
389
|
+
|
390
|
+
def _index_tool_params(self):
|
391
|
+
"""Return the parameters for indexing data."""
|
392
|
+
return {
|
393
|
+
"plan_id": (str, Field(description="ID of the test plan for which test cases are requested")),
|
394
|
+
"suite_ids": (str, Field(description="List of test suite IDs for which test cases are requested (can be empty)"))
|
395
|
+
}
|
414
396
|
|
415
397
|
@extend_with_vector_tools
|
416
398
|
def get_available_tools(self):
|
@@ -481,11 +463,5 @@ class TestPlanApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
481
463
|
"description": self.get_test_cases.__doc__,
|
482
464
|
"args_schema": TestCasesGetModel,
|
483
465
|
"ref": self.get_test_cases,
|
484
|
-
},
|
485
|
-
{
|
486
|
-
"name": "index_data",
|
487
|
-
"ref": self.index_data,
|
488
|
-
"description": self.index_data.__doc__,
|
489
|
-
"args_schema": indexData,
|
490
466
|
}
|
491
467
|
]
|
@@ -1,28 +1,26 @@
|
|
1
1
|
import logging
|
2
|
-
from typing import Any, Dict, Generator,
|
2
|
+
from typing import Any, Dict, Generator, Optional
|
3
3
|
|
4
|
-
from alita_sdk.tools.elitea_base import BaseIndexParams
|
5
|
-
from langchain_core.documents import Document
|
6
|
-
|
7
|
-
from ...elitea_base import BaseVectorStoreToolApiWrapper, extend_with_vector_tools
|
8
4
|
from azure.devops.connection import Connection
|
9
5
|
from azure.devops.exceptions import AzureDevOpsServiceError
|
10
6
|
from azure.devops.v7_0.core import CoreClient
|
11
7
|
from azure.devops.v7_0.wiki import WikiClient, WikiPageCreateOrUpdateParameters, WikiCreateParametersV2, \
|
12
8
|
WikiPageMoveParameters
|
13
9
|
from azure.devops.v7_0.wiki.models import GitVersionDescriptor
|
10
|
+
from langchain_core.documents import Document
|
14
11
|
from langchain_core.tools import ToolException
|
15
12
|
from msrest.authentication import BasicAuthentication
|
16
13
|
from pydantic import create_model, PrivateAttr, SecretStr
|
17
14
|
from pydantic import model_validator
|
18
15
|
from pydantic.fields import Field
|
16
|
+
|
17
|
+
from ...elitea_base import BaseVectorStoreToolApiWrapper, extend_with_vector_tools
|
18
|
+
|
19
19
|
try:
|
20
20
|
from alita_sdk.runtime.langchain.interfaces.llm_processor import get_embeddings
|
21
21
|
except ImportError:
|
22
22
|
from alita_sdk.langchain.interfaces.llm_processor import get_embeddings
|
23
23
|
|
24
|
-
from ...elitea_base import BaseToolApiWrapper
|
25
|
-
|
26
24
|
logger = logging.getLogger(__name__)
|
27
25
|
|
28
26
|
GetWikiInput = create_model(
|
@@ -60,17 +58,6 @@ RenamePageInput = create_model(
|
|
60
58
|
version_type=(Optional[str], Field(description="Version type (branch, tag, or commit). Determines how Id is interpreted", default="branch"))
|
61
59
|
)
|
62
60
|
|
63
|
-
# Schema for indexing ADO Wiki pages into vector store
|
64
|
-
indexData = create_model(
|
65
|
-
"indexData",
|
66
|
-
__base__=BaseIndexParams,
|
67
|
-
wiki_identifier=(str, Field(description="Wiki identifier to index, e.g., 'ABCProject.wiki'")),
|
68
|
-
progress_step=(Optional[int], Field(default=None, ge=0, le=100,
|
69
|
-
description="Optional step size for progress reporting during indexing")),
|
70
|
-
clean_index=(Optional[bool], Field(default=False,
|
71
|
-
description="Optional flag to enforce clean existing index before indexing new data")),
|
72
|
-
)
|
73
|
-
|
74
61
|
|
75
62
|
class AzureDevOpsApiWrapper(BaseVectorStoreToolApiWrapper):
|
76
63
|
organization_url: str
|
@@ -242,19 +229,6 @@ class AzureDevOpsApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
242
229
|
logger.error(f"Unable to modify wiki page: {str(e)}")
|
243
230
|
return ToolException(f"Unable to modify wiki page: {str(e)}")
|
244
231
|
|
245
|
-
def index_data(
|
246
|
-
self,
|
247
|
-
wiki_identifier: str,
|
248
|
-
collection_suffix: str = '',
|
249
|
-
progress_step: int = None,
|
250
|
-
clean_index: bool = False
|
251
|
-
):
|
252
|
-
"""Load ADO Wiki pages into the vector store."""
|
253
|
-
docs = self._base_loader(wiki_identifier)
|
254
|
-
embedding = get_embeddings(self.embedding_model, self.embedding_model_params)
|
255
|
-
vs = self._init_vector_store(collection_suffix, embeddings=embedding)
|
256
|
-
return vs.index_documents(docs, progress_step=progress_step, clean_index=clean_index)
|
257
|
-
|
258
232
|
def _base_loader(self, wiki_identifier: str) -> Generator[Document, None, None]:
|
259
233
|
pages = self._client.get_pages_batch(pages_batch_request={}, project=self.project, wiki_identifier=wiki_identifier)
|
260
234
|
#
|
@@ -266,8 +240,11 @@ class AzureDevOpsApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
266
240
|
'updated_on': ''
|
267
241
|
})
|
268
242
|
|
269
|
-
def
|
270
|
-
|
243
|
+
def _index_tool_params(self):
|
244
|
+
"""Return the parameters for indexing data."""
|
245
|
+
return {
|
246
|
+
"wiki_identifier": (str, Field(description="Wiki identifier to index, e.g., 'ABCProject.wiki'"))
|
247
|
+
}
|
271
248
|
|
272
249
|
@extend_with_vector_tools
|
273
250
|
def get_available_tools(self):
|
@@ -314,11 +291,5 @@ class AzureDevOpsApiWrapper(BaseVectorStoreToolApiWrapper):
|
|
314
291
|
"description": self.rename_wiki_page.__doc__,
|
315
292
|
"args_schema": RenamePageInput,
|
316
293
|
"ref": self.rename_wiki_page,
|
317
|
-
},
|
318
|
-
{
|
319
|
-
"name": "index_data",
|
320
|
-
"ref": self.index_data,
|
321
|
-
"description": self.index_data.__doc__,
|
322
|
-
"args_schema": indexData,
|
323
294
|
}
|
324
295
|
]
|
@@ -839,6 +839,8 @@ class ConfluenceAPIWrapper(BaseVectorStoreToolApiWrapper):
|
|
839
839
|
loader = AlitaConfluenceLoader(self.client, self.llm, bins_with_llm, **confluence_loader_params)
|
840
840
|
|
841
841
|
for document in loader._lazy_load(kwargs={}):
|
842
|
+
if 'updated_on' not in document.metadata and 'when' in document.metadata:
|
843
|
+
document.metadata['updated_on'] = document.metadata['when']
|
842
844
|
yield document
|
843
845
|
|
844
846
|
def _process_document(self, document: Document) -> Generator[Document, None, None]:
|
alita_sdk/tools/elitea_base.py
CHANGED
@@ -41,6 +41,11 @@ BaseCodeIndexParams = create_model(
|
|
41
41
|
blacklist=(Optional[List[str]], Field(description="File extensions or paths to exclude. Defaults to no exclusions if None.", default=None)),
|
42
42
|
)
|
43
43
|
|
44
|
+
RemoveIndexParams = create_model(
|
45
|
+
"RemoveIndexParams",
|
46
|
+
collection_suffix=(Optional[str], Field(description="Optional suffix for collection name (max 7 characters)", default="", max_length=7)),
|
47
|
+
)
|
48
|
+
|
44
49
|
BaseSearchParams = create_model(
|
45
50
|
"BaseSearchParams",
|
46
51
|
query=(str, Field(description="Query text to search in the index")),
|
@@ -109,7 +114,7 @@ BaseIndexDataParams = create_model(
|
|
109
114
|
description="Optional step size for progress reporting during indexing")),
|
110
115
|
clean_index=(Optional[bool], Field(default=False,
|
111
116
|
description="Optional flag to enforce clean existing index before indexing new data")),
|
112
|
-
chunking_tool=(Literal['markdown', 'statistical', 'proposal'], Field(description="Name of chunking tool", default=
|
117
|
+
chunking_tool=(Literal['markdown', 'statistical', 'proposal'], Field(description="Name of chunking tool", default=None)),
|
113
118
|
chunking_config=(Optional[dict], Field(description="Chunking tool configuration", default_factory=dict)),
|
114
119
|
)
|
115
120
|
|
@@ -345,7 +350,9 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
|
|
345
350
|
"alita_sdk_options": {
|
346
351
|
"target_schema": collection_name,
|
347
352
|
},
|
348
|
-
"connection_string": self.connection_string.get_secret_value()
|
353
|
+
# "connection_string": self.connection_string.get_secret_value()
|
354
|
+
# 'postgresql+psycopg://project_23_user:Rxu4QtM2InLVNnm62GX7@pgvector:5432/project_23'
|
355
|
+
"connection_string": 'postgresql+psycopg://postgres:yourpassword@localhost:5432/postgres'
|
349
356
|
}
|
350
357
|
elif self.vectorstore_type == 'Chroma':
|
351
358
|
vectorstore_params = {
|
@@ -363,6 +370,13 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
|
|
363
370
|
process_document_func=self._process_documents,
|
364
371
|
)
|
365
372
|
|
373
|
+
def remove_index(self, collection_suffix: str = ""):
|
374
|
+
"""
|
375
|
+
Cleans the indexed data in the collection
|
376
|
+
"""
|
377
|
+
|
378
|
+
self._init_vector_store(collection_suffix)._clean_collection()
|
379
|
+
|
366
380
|
def search_index(self,
|
367
381
|
query: str,
|
368
382
|
collection_suffix: str = "",
|
@@ -463,7 +477,14 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
|
|
463
477
|
"ref": self.stepback_summary_index,
|
464
478
|
"description": self.stepback_summary_index.__doc__,
|
465
479
|
"args_schema": BaseStepbackSearchParams
|
466
|
-
}
|
480
|
+
},
|
481
|
+
{
|
482
|
+
"name": "remove_index",
|
483
|
+
"mode": "remove_index",
|
484
|
+
"ref": self.remove_index,
|
485
|
+
"description": self.remove_index.__doc__,
|
486
|
+
"args_schema": RemoveIndexParams
|
487
|
+
},
|
467
488
|
]
|
468
489
|
|
469
490
|
|
@@ -85,13 +85,14 @@ class AlitaGitlabToolkit(BaseToolkit):
|
|
85
85
|
if selected_tools:
|
86
86
|
if tool["name"] not in selected_tools:
|
87
87
|
continue
|
88
|
+
|
88
89
|
tools.append(BaseAction(
|
89
90
|
api_wrapper=gitlab_api_wrapper,
|
90
91
|
name=prefix + tool["name"],
|
91
|
-
description=tool["description"] +
|
92
|
+
description=tool["description"] + f"\nrepo: {gitlab_api_wrapper.repository}",
|
92
93
|
args_schema=tool["args_schema"]
|
93
94
|
))
|
94
95
|
return cls(tools=tools)
|
95
96
|
|
96
|
-
def get_tools(self):
|
97
|
+
def get_tools(self)-> List[BaseTool]:
|
97
98
|
return self.tools
|
@@ -1,6 +1,7 @@
|
|
1
1
|
# api_wrapper.py
|
2
2
|
from typing import Any, Dict, List, Optional
|
3
|
-
|
3
|
+
import fnmatch
|
4
|
+
from alita_sdk.tools.elitea_base import extend_with_vector_tools
|
4
5
|
from alita_sdk.tools.elitea_base import BaseCodeToolApiWrapper
|
5
6
|
from pydantic import create_model, Field, model_validator, SecretStr, PrivateAttr
|
6
7
|
|
@@ -53,7 +54,11 @@ CreateBranchModel = create_model(
|
|
53
54
|
)
|
54
55
|
ListBranchesInRepoModel = create_model(
|
55
56
|
"ListBranchesInRepoModel",
|
57
|
+
limit=(Optional[int], Field(default=20, description="Maximum number of branches to return. If not provided, all branches will be returned.")),
|
58
|
+
branch_wildcard=(Optional[str], Field(default=None, description="Wildcard pattern to filter branches by name. If not provided, all branches will be returned."))
|
59
|
+
|
56
60
|
)
|
61
|
+
|
57
62
|
ListFilesModel = create_model(
|
58
63
|
"ListFilesModel",
|
59
64
|
path=(Optional[str], Field(description="The path to list files from")),
|
@@ -142,9 +147,30 @@ class GitLabAPIWrapper(BaseCodeToolApiWrapper):
|
|
142
147
|
self._repo_instance.default_branch = branch_name
|
143
148
|
return f"Active branch set to {branch_name}"
|
144
149
|
|
145
|
-
def list_branches_in_repo(self) -> List[str]:
|
146
|
-
|
147
|
-
|
150
|
+
def list_branches_in_repo(self, limit: Optional[int] = 20, branch_wildcard: Optional[str] = None) -> List[str]:
|
151
|
+
"""
|
152
|
+
Lists branches in the repository with optional limit and wildcard filtering.
|
153
|
+
|
154
|
+
Parameters:
|
155
|
+
limit (Optional[int]): Maximum number of branches to return
|
156
|
+
branch_wildcard (Optional[str]): Wildcard pattern to filter branches (e.g., '*dev')
|
157
|
+
|
158
|
+
Returns:
|
159
|
+
List[str]: List containing names of branches
|
160
|
+
"""
|
161
|
+
try:
|
162
|
+
branches = self._repo_instance.branches.list(get_all=True)
|
163
|
+
|
164
|
+
if branch_wildcard:
|
165
|
+
branches = [branch for branch in branches if fnmatch.fnmatch(branch.name, branch_wildcard)]
|
166
|
+
|
167
|
+
if limit is not None:
|
168
|
+
branches = branches[:limit]
|
169
|
+
|
170
|
+
branch_names = [branch.name for branch in branches]
|
171
|
+
return branch_names
|
172
|
+
except Exception as e:
|
173
|
+
return f"Failed to list branches: {str(e)}"
|
148
174
|
|
149
175
|
def list_files(self, path: str = None, recursive: bool = True, branch: str = None) -> List[str]:
|
150
176
|
branch = branch if branch else self._active_branch
|
@@ -404,84 +430,85 @@ class GitLabAPIWrapper(BaseCodeToolApiWrapper):
|
|
404
430
|
for commit in commits
|
405
431
|
]
|
406
432
|
|
433
|
+
@extend_with_vector_tools
|
407
434
|
def get_available_tools(self):
|
408
435
|
return [
|
409
436
|
{
|
410
437
|
"name": "create_branch",
|
411
438
|
"ref": self.create_branch,
|
412
|
-
"description": self.create_branch.__doc__,
|
439
|
+
"description": self.create_branch.__doc__ or "Create a new branch in the repository.",
|
413
440
|
"args_schema": CreateBranchModel,
|
414
441
|
},
|
415
442
|
{
|
416
443
|
"name": "list_branches_in_repo",
|
417
444
|
"ref": self.list_branches_in_repo,
|
418
|
-
"description": self.list_branches_in_repo.__doc__,
|
445
|
+
"description": self.list_branches_in_repo.__doc__ or "List branches in the repository with optional limit and wildcard filtering.",
|
419
446
|
"args_schema": ListBranchesInRepoModel,
|
420
447
|
},
|
421
448
|
{
|
422
449
|
"name": "list_files",
|
423
450
|
"ref": self.list_files,
|
424
|
-
"description": self.list_files.__doc__,
|
451
|
+
"description": self.list_files.__doc__ or "List files in the repository with optional path, recursive search, and branch.",
|
425
452
|
"args_schema": ListFilesModel,
|
426
453
|
},
|
427
454
|
{
|
428
455
|
"name": "list_folders",
|
429
456
|
"ref": self.list_folders,
|
430
|
-
"description": self.list_folders.__doc__,
|
457
|
+
"description": self.list_folders.__doc__ or "List folders in the repository with optional path, recursive search, and branch.",
|
431
458
|
"args_schema": ListFoldersModel,
|
432
459
|
},
|
433
460
|
{
|
434
461
|
"name": "get_issues",
|
435
462
|
"ref": self.get_issues,
|
436
|
-
"description": self.get_issues.__doc__,
|
463
|
+
"description": self.get_issues.__doc__ or "Get all open issues in the repository.",
|
437
464
|
"args_schema": GetIssuesModel,
|
438
465
|
},
|
439
466
|
{
|
440
467
|
"name": "get_issue",
|
441
468
|
"ref": self.get_issue,
|
442
|
-
"description": self.get_issue.__doc__,
|
469
|
+
"description": self.get_issue.__doc__ or "Get details of a specific issue by its number.",
|
443
470
|
"args_schema": GetIssueModel,
|
444
471
|
},
|
445
472
|
{
|
446
473
|
"name": "create_pull_request",
|
447
474
|
"ref": self.create_pull_request,
|
448
|
-
"description": self.create_pull_request.__doc__,
|
475
|
+
"description": self.create_pull_request.__doc__ or "Create a pull request (merge request) in the repository.",
|
449
476
|
"args_schema": CreatePullRequestModel,
|
450
477
|
},
|
451
478
|
{
|
452
479
|
"name": "comment_on_issue",
|
453
480
|
"ref": self.comment_on_issue,
|
454
|
-
"description": self.comment_on_issue.__doc__,
|
481
|
+
"description": self.comment_on_issue.__doc__ or "Comment on an issue in the repository.",
|
455
482
|
"args_schema": CommentOnIssueModel,
|
456
483
|
},
|
457
484
|
{
|
458
485
|
"name": "create_file",
|
459
486
|
"ref": self.create_file,
|
460
|
-
"description": self.create_file.__doc__,
|
487
|
+
"description": self.create_file.__doc__ or "Create a new file in the repository.",
|
461
488
|
"args_schema": CreateFileModel,
|
462
489
|
},
|
463
490
|
{
|
464
491
|
"name": "read_file",
|
465
492
|
"ref": self.read_file,
|
466
|
-
"description": self.read_file.__doc__,
|
493
|
+
"description": self.read_file.__doc__ or "Read the contents of a file in the repository.",
|
467
494
|
"args_schema": ReadFileModel,
|
468
495
|
},
|
469
496
|
{
|
470
497
|
"name": "update_file",
|
471
498
|
"ref": self.update_file,
|
472
|
-
"description": self.update_file.__doc__,
|
499
|
+
"description": self.update_file.__doc__ or "Update the contents of a file in the repository.",
|
473
500
|
"args_schema": UpdateFileModel,
|
474
501
|
},
|
475
502
|
{
|
476
503
|
"name": "append_file",
|
477
504
|
"ref": self.append_file,
|
478
|
-
"description": self.append_file.__doc__,
|
505
|
+
"description": self.append_file.__doc__ or "Append content to a file in the repository.",
|
479
506
|
"args_schema": AppendFileModel,
|
480
507
|
},
|
481
508
|
{
|
482
509
|
"name": "delete_file",
|
483
510
|
"ref": self.delete_file,
|
484
|
-
"description": self.delete_file.__doc__,
|
511
|
+
"description": self.delete_file.__doc__ or "Delete a file from the repository.",
|
485
512
|
"args_schema": DeleteFileModel,
|
486
513
|
},
|
487
514
|
{
|
@@ -507,5 +534,5 @@ class GitLabAPIWrapper(BaseCodeToolApiWrapper):
|
|
507
534
|
"ref": self.get_commits,
|
508
535
|
"description": "Retrieve a list of commits from the repository.",
|
509
536
|
"args_schema": GetCommitsModel,
|
510
|
-
}
|
537
|
+
}
|
511
538
|
] + self._get_vector_search_tools()
|