alita-sdk 0.3.374__py3-none-any.whl → 0.3.423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic.
- alita_sdk/configurations/bitbucket.py +95 -0
- alita_sdk/configurations/confluence.py +96 -1
- alita_sdk/configurations/gitlab.py +79 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +93 -0
- alita_sdk/configurations/zephyr_enterprise.py +93 -0
- alita_sdk/configurations/zephyr_essential.py +75 -0
- alita_sdk/runtime/clients/client.py +3 -2
- alita_sdk/runtime/clients/sandbox_client.py +8 -0
- alita_sdk/runtime/langchain/assistant.py +56 -40
- alita_sdk/runtime/langchain/constants.py +4 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
- alita_sdk/runtime/langchain/document_loaders/constants.py +28 -12
- alita_sdk/runtime/langchain/langraph_agent.py +92 -28
- alita_sdk/runtime/langchain/utils.py +24 -4
- alita_sdk/runtime/toolkits/application.py +8 -1
- alita_sdk/runtime/toolkits/tools.py +80 -49
- alita_sdk/runtime/tools/__init__.py +7 -2
- alita_sdk/runtime/tools/application.py +7 -0
- alita_sdk/runtime/tools/function.py +28 -23
- alita_sdk/runtime/tools/graph.py +10 -4
- alita_sdk/runtime/tools/image_generation.py +104 -8
- alita_sdk/runtime/tools/llm.py +146 -114
- alita_sdk/runtime/tools/sandbox.py +166 -63
- alita_sdk/runtime/tools/vectorstore.py +22 -21
- alita_sdk/runtime/tools/vectorstore_base.py +16 -15
- alita_sdk/runtime/utils/utils.py +1 -0
- alita_sdk/tools/__init__.py +43 -31
- alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
- alita_sdk/tools/base_indexer_toolkit.py +102 -93
- alita_sdk/tools/code_indexer_toolkit.py +15 -5
- alita_sdk/tools/confluence/api_wrapper.py +30 -8
- alita_sdk/tools/confluence/loader.py +10 -0
- alita_sdk/tools/elitea_base.py +22 -22
- alita_sdk/tools/gitlab/api_wrapper.py +8 -9
- alita_sdk/tools/jira/api_wrapper.py +1 -1
- alita_sdk/tools/non_code_indexer_toolkit.py +2 -2
- alita_sdk/tools/openapi/__init__.py +10 -1
- alita_sdk/tools/qtest/api_wrapper.py +298 -51
- alita_sdk/tools/sharepoint/api_wrapper.py +104 -33
- alita_sdk/tools/sharepoint/authorization_helper.py +175 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/utils/content_parser.py +27 -16
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +38 -25
- {alita_sdk-0.3.374.dist-info → alita_sdk-0.3.423.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.374.dist-info → alita_sdk-0.3.423.dist-info}/RECORD +51 -51
- {alita_sdk-0.3.374.dist-info → alita_sdk-0.3.423.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.374.dist-info → alita_sdk-0.3.423.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.374.dist-info → alita_sdk-0.3.423.dist-info}/top_level.txt +0 -0
alita_sdk/tools/confluence/api_wrapper.py
CHANGED

@@ -7,12 +7,14 @@ from json import JSONDecodeError
 from typing import Optional, List, Any, Dict, Callable, Generator, Literal
 
 import requests
+from atlassian.errors import ApiError
 from langchain_community.document_loaders.confluence import ContentFormat
 from langchain_core.documents import Document
 from langchain_core.messages import HumanMessage
 from langchain_core.tools import ToolException
 from markdownify import markdownify
 from pydantic import Field, PrivateAttr, model_validator, create_model, SecretStr
+from requests import HTTPError
 from tenacity import retry, stop_after_attempt, wait_exponential, before_sleep_log
 
 from alita_sdk.tools.non_code_indexer_toolkit import NonCodeIndexerToolkit
@@ -194,6 +196,7 @@ class ConfluenceAPIWrapper(NonCodeIndexerToolkit):
     keep_markdown_format: Optional[bool] = True
     ocr_languages: Optional[str] = None
     keep_newlines: Optional[bool] = True
+    _errors: Optional[list[str]] = None
     _image_cache: ImageDescriptionCache = PrivateAttr(default_factory=ImageDescriptionCache)
 
     @model_validator(mode='before')
@@ -498,7 +501,9 @@ class ConfluenceAPIWrapper(NonCodeIndexerToolkit):
         restrictions = self.client.get_all_restrictions_for_content(page["id"])
 
         return (
-            page["status"] == "current"
+            (page["status"] == "current"
+             # allow user to see archived content if needed
+             or page["status"] == "archived")
             and not restrictions["read"]["restrictions"]["user"]["results"]
             and not restrictions["read"]["restrictions"]["group"]["results"]
         )
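The only behavioral change in this hunk is that archived pages now pass the visibility check alongside current ones. A condensed sketch of the predicate (the restriction lookup is taken as input here; names are illustrative, not the SDK's exact method):

def is_public_page(page: dict, restrictions: dict) -> bool:
    # archived pages now count as visible, matching the hunk above
    return (
        page["status"] in ("current", "archived")
        and not restrictions["read"]["restrictions"]["user"]["results"]
        and not restrictions["read"]["restrictions"]["group"]["results"]
    )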
@@ -518,18 +523,35 @@ class ConfluenceAPIWrapper(NonCodeIndexerToolkit):
             ),
             before_sleep=before_sleep_log(logger, logging.WARNING),
         )(self.client.get_page_by_id)
-        page = get_page(
-            page_id=page_id, expand=f"{self.content_format.value},version"
-        )
-        if not self.include_restricted_content and not self.is_public_page(page):
-            continue
+        try:
+            page = get_page(
+                page_id=page_id, expand=f"{self.content_format.value},version"
+            )
+        except (ApiError, HTTPError) as e:
+            logger.error(f"Error fetching page with ID {page_id}: {e}")
+            page_content_temp = f"Confluence API Error: cannot fetch the page with ID {page_id}: {e}"
+            # store errors
+            if self._errors is None:
+                self._errors = []
+            self._errors.append(page_content_temp)
+            return Document(page_content=page_content_temp,
+                            metadata={})
+        # TODO: update on toolkit advanced settings level as a separate feature
+        # if not self.include_restricted_content and not self.is_public_page(page):
+        #     continue
         yield self.process_page(page, skip_images)
 
+    def _log_errors(self):
+        """ Log errors encountered during toolkit execution. """
+        if self._errors:
+            logger.info(f"Errors encountered during toolkit execution: {self._errors}")
+
     def read_page_by_id(self, page_id: str, skip_images: bool = False):
         """Reads a page by its id in the Confluence space. If id is not available, but there is a title - use get_page_id first."""
         result = list(self.get_pages_by_id([page_id], skip_images))
         if not result:
-            return "Pages not found or you do not have access to them."
+            return f"Pages not found. Errors: {self._errors}" if self._errors \
+                else "Pages not found or you do not have access to them."
         return result[0].page_content
         # return self._strip_base64_images(result[0].page_content) if skip_images else result[0].page_content
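Net effect of this hunk: a failed page fetch no longer aborts get_pages_by_id. The error is logged, accumulated in the new _errors list, and returned as a placeholder Document, which read_page_by_id can then surface to the caller. A minimal standalone sketch of the capture-and-degrade pattern (class and names are illustrative, not the SDK's):

import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

@dataclass
class PageFetcher:
    # collected error messages, mirroring the _errors attribute added above
    errors: list[str] = field(default_factory=list)

    def fetch(self, client, page_id: str) -> str:
        try:
            return client.get_page_by_id(page_id)
        except Exception as e:  # the SDK narrows this to (ApiError, HTTPError)
            message = f"cannot fetch the page with ID {page_id}: {e}"
            logger.error(message)
            self.errors.append(message)  # keep going; report at the end
            return message  # placeholder instead of raising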
@@ -1674,7 +1696,7 @@ class ConfluenceAPIWrapper(NonCodeIndexerToolkit):
                 description="List of file extensions to skip when processing attachments: i.e. ['*.png', '*.jpg']",
                 default=[])),
             "include_comments": (Optional[bool], Field(description="Include comments.", default=False)),
-            "include_labels": (Optional[bool], Field(description="Include labels.", default=True)),
+            "include_labels": (Optional[bool], Field(description="Include labels.", default=False)),
             "ocr_languages": (Optional[str], Field(description="OCR languages for processing attachments.", default='eng')),
             "keep_markdown_format": (Optional[bool], Field(description="Keep the markdown format.", default=True)),
             "keep_newlines": (Optional[bool], Field(description="Keep newlines in the content.", default=True)),
alita_sdk/tools/confluence/loader.py
CHANGED

@@ -3,6 +3,7 @@ from typing import Optional, List
 from logging import getLogger
 
 import requests
+from langchain_core.documents import Document
 
 logger = getLogger(__name__)
 from PIL import Image
@@ -193,6 +194,15 @@ class AlitaConfluenceLoader(ConfluenceLoader):
         else:
             return super().process_image(link, ocr_languages)
 
+    def process_page(self, page: dict, include_attachments: bool, include_comments: bool, include_labels: bool,
+                     content_format: ContentFormat, ocr_languages: Optional[str] = None,
+                     keep_markdown_format: Optional[bool] = False, keep_newlines: bool = False) -> Document:
+        if not page.get("title"):
+            # if 'include_restricted_content' set to True, draft pages are loaded and can have no title
+            page["title"] = "Untitled"
+        return super().process_page(page, include_attachments, include_comments, include_labels, content_format,
+                                    ocr_languages, keep_markdown_format, keep_newlines)
+
     # TODO review usage
     # def process_svg(
     #     self,
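The new override handles draft pages, which can arrive without a title when restricted content is loaded, by normalizing the input before delegating to the parent loader. A minimal sketch of the normalize-then-delegate pattern (the base class here is hypothetical):

class BaseLoader:
    def process_page(self, page: dict) -> str:
        return page["title"]  # parent assumes a title is always present

class SafeLoader(BaseLoader):
    def process_page(self, page: dict) -> str:
        # normalize draft pages that arrive without a title
        if not page.get("title"):
            page["title"] = "Untitled"
        return super().process_page(page)

assert SafeLoader().process_page({}) == "Untitled"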
alita_sdk/tools/elitea_base.py
CHANGED
@@ -33,12 +33,12 @@ LoaderSchema = create_model(
 # Base Vector Store Schema Models
 BaseIndexParams = create_model(
     "BaseIndexParams",
-    collection_suffix=(str, Field(description="Suffix for collection name (max 7 characters)", min_length=1, max_length=7)),
+    index_name=(str, Field(description="Index name (max 7 characters)", min_length=1, max_length=7)),
 )
 
 BaseCodeIndexParams = create_model(
     "BaseCodeIndexParams",
-    collection_suffix=(str, Field(description="Suffix for collection name (max 7 characters)", min_length=1, max_length=7)),
+    index_name=(str, Field(description="Index name (max 7 characters)", min_length=1, max_length=7)),
     clean_index=(Optional[bool], Field(default=False, description="Optional flag to enforce clean existing index before indexing new data")),
     progress_step=(Optional[int], Field(default=5, ge=0, le=100,
                                         description="Optional step size for progress reporting during indexing")),
@@ -50,14 +50,14 @@ BaseCodeIndexParams = create_model(
 
 RemoveIndexParams = create_model(
     "RemoveIndexParams",
-    collection_suffix=(Optional[str], Field(description="Optional suffix for collection name (max 7 characters)", default="", max_length=7)),
+    index_name=(Optional[str], Field(description="Optional index name (max 7 characters)", default="", max_length=7)),
 )
 
 BaseSearchParams = create_model(
     "BaseSearchParams",
     query=(str, Field(description="Query text to search in the index")),
-    collection_suffix=(Optional[str], Field(
-        description="Optional suffix for collection name (max 7 characters). Leave empty to search across all datasets",
+    index_name=(Optional[str], Field(
+        description="Optional index name (max 7 characters). Leave empty to search across all datasets",
         default="", max_length=7)),
     filter=(Optional[dict], Field(
         description="Filter to apply to the search results. Can be a dictionary or a JSON string.",
@@ -87,7 +87,7 @@ BaseSearchParams = create_model(
 BaseStepbackSearchParams = create_model(
     "BaseStepbackSearchParams",
     query=(str, Field(description="Query text to search in the index")),
-    collection_suffix=(Optional[str], Field(description="Optional suffix for collection name (max 7 characters)", default="", max_length=7)),
+    index_name=(Optional[str], Field(description="Optional index name (max 7 characters)", default="", max_length=7)),
     messages=(Optional[List], Field(description="Chat messages for stepback search context", default=[])),
     filter=(Optional[dict], Field(
         description="Filter to apply to the search results. Can be a dictionary or a JSON string.",
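These schema hunks are one systematic rename, collection_suffix to index_name, with the 7-character cap unchanged. A standalone sketch of how a create_model schema like these enforces the cap (assuming pydantic v2; not the SDK's actual models):

from pydantic import Field, ValidationError, create_model

IndexParams = create_model(
    "IndexParams",
    index_name=(str, Field(description="Index name (max 7 characters)",
                           min_length=1, max_length=7)),
)

IndexParams(index_name="docs")  # accepted
try:
    IndexParams(index_name="much-too-long")
except ValidationError as e:
    print(e)  # string_too_long: at most 7 characters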
@@ -324,12 +324,12 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
         #
         docs = base_chunker(file_content_generator=docs, config=base_chunking_config)
         #
-        collection_suffix = kwargs.get("collection_suffix")
+        index_name = kwargs.get("index_name")
         progress_step = kwargs.get("progress_step")
         clean_index = kwargs.get("clean_index")
         vs = self._init_vector_store()
         #
-        return vs.index_documents(docs, collection_suffix=collection_suffix, progress_step=progress_step, clean_index=clean_index)
+        return vs.index_documents(docs, index_name=index_name, progress_step=progress_step, clean_index=clean_index)
 
     def _process_documents(self, documents: List[Document]) -> Generator[Document, None, None]:
         """
@@ -399,10 +399,10 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
         )
         return self._vector_store
 
-    def remove_index(self, collection_suffix: str = ""):
+    def remove_index(self, index_name: str = ""):
         """Cleans the indexed data in the collection."""
-        self._init_vector_store()._clean_collection(collection_suffix=collection_suffix)
-        return (f"Collection '{collection_suffix}' has been removed from the vector store.\n"
+        self._init_vector_store()._clean_collection(index_name=index_name)
+        return (f"Collection '{index_name}' has been removed from the vector store.\n"
                 f"Available collections: {self.list_collections()}")
 
     def list_collections(self):
@@ -410,19 +410,19 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
         vectorstore_wrapper = self._init_vector_store()
         return vectorstore_wrapper.list_collections()
 
-    def _build_collection_filter(self, filter: dict | str, collection_suffix: str = "") -> dict:
+    def _build_collection_filter(self, filter: dict | str, index_name: str = "") -> dict:
         """Builds a filter for the collection based on the provided suffix."""
 
         filter = filter if isinstance(filter, dict) else json.loads(filter)
-        if collection_suffix:
+        if index_name:
             filter.update({"collection": {
-                "$eq": collection_suffix.strip()
+                "$eq": index_name.strip()
             }})
         return filter
 
     def search_index(self,
                      query: str,
-                     collection_suffix: str = "",
+                     index_name: str = "",
                      filter: dict | str = {}, cut_off: float = 0.5,
                      search_top: int = 10, reranker: dict = {},
                      full_text_search: Optional[Dict[str, Any]] = None,
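_build_collection_filter is the pivot of the rename: it merges the index restriction into whatever filter the caller supplied, accepting either a dict or a JSON string. A standalone sketch mirroring the logic in the hunk above:

import json

def build_collection_filter(filter: dict | str, index_name: str = "") -> dict:
    # accept either a dict or a JSON string, as the method above does
    filter = filter if isinstance(filter, dict) else json.loads(filter)
    if index_name:
        filter.update({"collection": {"$eq": index_name.strip()}})
    return filter

print(build_collection_filter('{"author": "me"}', " idx1 "))
# {'author': 'me', 'collection': {'$eq': 'idx1'}}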
@@ -431,7 +431,7 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
                      **kwargs):
         """ Searches indexed documents in the vector store."""
         vectorstore = self._init_vector_store()
-        filter = self._build_collection_filter(filter, collection_suffix)
+        filter = self._build_collection_filter(filter, index_name)
         found_docs = vectorstore.search_documents(
             query,
             doctype=self.doctype,
@@ -448,7 +448,7 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
     def stepback_search_index(self,
                               query: str,
                               messages: List[Dict[str, Any]] = [],
-                              collection_suffix: str = "",
+                              index_name: str = "",
                              filter: dict | str = {}, cut_off: float = 0.5,
                              search_top: int = 10, reranker: dict = {},
                              full_text_search: Optional[Dict[str, Any]] = None,
@@ -457,7 +457,7 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
                               **kwargs):
         """ Searches indexed documents in the vector store."""
 
-        filter = self._build_collection_filter(filter, collection_suffix)
+        filter = self._build_collection_filter(filter, index_name)
         vectorstore = self._init_vector_store()
         found_docs = vectorstore.stepback_search(
             query,
@@ -475,7 +475,7 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
     def stepback_summary_index(self,
                                query: str,
                                messages: List[Dict[str, Any]] = [],
-                               collection_suffix: str = "",
+                               index_name: str = "",
                                filter: dict | str = {}, cut_off: float = 0.5,
                                search_top: int = 10, reranker: dict = {},
                                full_text_search: Optional[Dict[str, Any]] = None,
@@ -484,7 +484,7 @@ class BaseVectorStoreToolApiWrapper(BaseToolApiWrapper):
                                **kwargs):
         """ Generates a summary of indexed documents using stepback technique."""
         vectorstore = self._init_vector_store()
-        filter = self._build_collection_filter(filter, collection_suffix)
+        filter = self._build_collection_filter(filter, index_name)
 
         found_docs = vectorstore.stepback_summary(
             query,
@@ -655,7 +655,7 @@ class BaseCodeToolApiWrapper(BaseVectorStoreToolApiWrapper):
         return parse_code_files_for_db(file_content_generator())
 
     def index_data(self,
-                   collection_suffix: str,
+                   index_name: str,
                    branch: Optional[str] = None,
                    whitelist: Optional[List[str]] = None,
                    blacklist: Optional[List[str]] = None,
@@ -669,7 +669,7 @@ class BaseCodeToolApiWrapper(BaseVectorStoreToolApiWrapper):
         )
         vectorstore = self._init_vector_store()
         clean_index = kwargs.get('clean_index', False)
-        return vectorstore.index_documents(documents, collection_suffix=collection_suffix,
+        return vectorstore.index_documents(documents, index_name=index_name,
                                            clean_index=clean_index, is_code=True,
                                            progress_step=kwargs.get('progress_step', 5))
 
alita_sdk/tools/gitlab/api_wrapper.py
CHANGED

@@ -115,9 +115,8 @@ class GitLabAPIWrapper(CodeIndexerToolkit):
         """Remove trailing slash from URL if present."""
         return url.rstrip('/') if url else url
 
-    @model_validator(mode='before')
-    @classmethod
-    def validate_toolkit(cls, values: Dict) -> Dict:
+    @model_validator(mode='after')
+    def validate_toolkit(self):
         try:
             import gitlab
         except ImportError:
@@ -125,17 +124,17 @@ class GitLabAPIWrapper(CodeIndexerToolkit):
                 "python-gitlab is not installed. "
                 "Please install it with `pip install python-gitlab`"
             )
-        values['repository'] = cls._sanitize_url(values.get('repository'))
+        self.repository = self._sanitize_url(self.repository)
         g = gitlab.Gitlab(
-            url=cls._sanitize_url(values.get('url')),
-            private_token=values.get('private_token'),
+            url=self._sanitize_url(self.url),
+            private_token=self.private_token.get_secret_value(),
             keep_base_url=True,
         )
 
         g.auth()
-        values['_git'] = g
-        values['_active_branch'] = values.get('branch')
-        return values
+        self._git = g
+        self._active_branch = self.branch
+        return self
 
     @property
     def repo_instance(self):
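The GitLab wrapper swaps a mode='before' classmethod, which receives a raw values dict, for a mode='after' validator that runs on the constructed instance, where typed fields and SecretStr.get_secret_value() are available. A minimal sketch of the pattern (assuming pydantic v2; the fields are illustrative):

from pydantic import BaseModel, SecretStr, model_validator

class Config(BaseModel):
    url: str
    private_token: SecretStr

    @model_validator(mode='after')
    def connect(self):
        # runs after field validation: typed attribute access works here,
        # including SecretStr.get_secret_value()
        self.url = self.url.rstrip('/')
        assert self.private_token.get_secret_value()
        return self  # mode='after' validators must return the instance

Config(url="https://gitlab.example.com/", private_token="t0ken")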
alita_sdk/tools/jira/api_wrapper.py
CHANGED

@@ -563,7 +563,7 @@ class JiraApiWrapper(NonCodeIndexerToolkit):
         Use the appropriate issue link type (e.g., "Test", "Relates", "Blocks").
         If we use "Test" linktype, the test is inward issue, the story/other issue is outward issue.."""
 
-        comment = "
+        comment = f"Issue {inward_issue_key} was linked to {outward_issue_key}."
         comment_body = {"content": [{"content": [{"text": comment,"type": "text"}],"type": "paragraph"}],"type": "doc","version": 1} if self.api_version == "3" else comment
         link_data = {
             "type": {"name": f"{linktype}"},
alita_sdk/tools/non_code_indexer_toolkit.py
CHANGED

@@ -6,11 +6,11 @@ from alita_sdk.tools.base_indexer_toolkit import BaseIndexerToolkit
 
 
 class NonCodeIndexerToolkit(BaseIndexerToolkit):
-    def _get_indexed_data(self, collection_suffix: str):
+    def _get_indexed_data(self, index_name: str):
         if not self.vector_adapter:
             raise ToolException("Vector adapter is not initialized. "
                                 "Check your configuration: embedding_model and vectorstore_type.")
-        return self.vector_adapter.get_indexed_data(self, collection_suffix)
+        return self.vector_adapter.get_indexed_data(self, index_name)
 
     def key_fn(self, document: Document):
         return document.metadata.get('id')
alita_sdk/tools/openapi/__init__.py
CHANGED

@@ -1,6 +1,7 @@
 import json
 import re
 import logging
+import yaml
 from typing import List, Any, Optional, Dict
 from langchain_core.tools import BaseTool, BaseToolkit, ToolException
 from requests_openapi import Operation, Client, Server
@@ -101,7 +102,15 @@ class AlitaOpenAPIToolkit(BaseToolkit):
         else:
             tools_set = {}
         if isinstance(openapi_spec, str):
-            openapi_spec = json.loads(openapi_spec)
+            # Try to detect if it's YAML or JSON by attempting to parse as JSON first
+            try:
+                openapi_spec = json.loads(openapi_spec)
+            except json.JSONDecodeError:
+                # If JSON parsing fails, try YAML
+                try:
+                    openapi_spec = yaml.safe_load(openapi_spec)
+                except yaml.YAMLError as e:
+                    raise ToolException(f"Failed to parse OpenAPI spec as JSON or YAML: {e}")
         c = Client()
         c.load_spec(openapi_spec)
         if headers: