alita-sdk 0.3.554__py3-none-any.whl → 0.3.602__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/agent_executor.py +2 -1
- alita_sdk/cli/agent_loader.py +34 -4
- alita_sdk/cli/agents.py +433 -203
- alita_sdk/configurations/openapi.py +227 -15
- alita_sdk/runtime/clients/client.py +4 -2
- alita_sdk/runtime/langchain/_constants_bkup.py +1318 -0
- alita_sdk/runtime/langchain/assistant.py +61 -11
- alita_sdk/runtime/langchain/constants.py +419 -171
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -2
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +5 -2
- alita_sdk/runtime/langchain/langraph_agent.py +106 -21
- alita_sdk/runtime/langchain/utils.py +30 -14
- alita_sdk/runtime/toolkits/__init__.py +3 -0
- alita_sdk/runtime/toolkits/artifact.py +2 -1
- alita_sdk/runtime/toolkits/mcp.py +6 -3
- alita_sdk/runtime/toolkits/mcp_config.py +1048 -0
- alita_sdk/runtime/toolkits/skill_router.py +2 -2
- alita_sdk/runtime/toolkits/tools.py +64 -2
- alita_sdk/runtime/toolkits/vectorstore.py +1 -1
- alita_sdk/runtime/tools/artifact.py +15 -0
- alita_sdk/runtime/tools/data_analysis.py +183 -0
- alita_sdk/runtime/tools/llm.py +30 -11
- alita_sdk/runtime/tools/mcp_server_tool.py +6 -3
- alita_sdk/runtime/tools/router.py +2 -4
- alita_sdk/runtime/tools/sandbox.py +9 -6
- alita_sdk/runtime/utils/constants.py +5 -1
- alita_sdk/runtime/utils/mcp_client.py +1 -1
- alita_sdk/runtime/utils/mcp_sse_client.py +1 -1
- alita_sdk/runtime/utils/toolkit_utils.py +2 -0
- alita_sdk/tools/__init__.py +3 -1
- alita_sdk/tools/ado/repos/__init__.py +26 -8
- alita_sdk/tools/ado/repos/repos_wrapper.py +78 -52
- alita_sdk/tools/ado/test_plan/__init__.py +3 -2
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +23 -1
- alita_sdk/tools/ado/utils.py +1 -18
- alita_sdk/tools/ado/wiki/__init__.py +2 -1
- alita_sdk/tools/ado/wiki/ado_wrapper.py +23 -1
- alita_sdk/tools/ado/work_item/__init__.py +3 -2
- alita_sdk/tools/ado/work_item/ado_wrapper.py +23 -1
- alita_sdk/tools/advanced_jira_mining/__init__.py +2 -1
- alita_sdk/tools/aws/delta_lake/__init__.py +2 -1
- alita_sdk/tools/azure_ai/search/__init__.py +2 -1
- alita_sdk/tools/azure_ai/search/api_wrapper.py +1 -1
- alita_sdk/tools/base_indexer_toolkit.py +15 -6
- alita_sdk/tools/bitbucket/__init__.py +2 -1
- alita_sdk/tools/bitbucket/api_wrapper.py +1 -1
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +3 -3
- alita_sdk/tools/browser/__init__.py +1 -1
- alita_sdk/tools/carrier/__init__.py +1 -1
- alita_sdk/tools/chunkers/code/treesitter/treesitter.py +37 -13
- alita_sdk/tools/cloud/aws/__init__.py +2 -1
- alita_sdk/tools/cloud/azure/__init__.py +2 -1
- alita_sdk/tools/cloud/gcp/__init__.py +2 -1
- alita_sdk/tools/cloud/k8s/__init__.py +2 -1
- alita_sdk/tools/code/linter/__init__.py +2 -1
- alita_sdk/tools/code/sonar/__init__.py +2 -1
- alita_sdk/tools/code_indexer_toolkit.py +19 -2
- alita_sdk/tools/confluence/__init__.py +7 -6
- alita_sdk/tools/confluence/api_wrapper.py +2 -2
- alita_sdk/tools/custom_open_api/__init__.py +2 -1
- alita_sdk/tools/elastic/__init__.py +2 -1
- alita_sdk/tools/elitea_base.py +28 -9
- alita_sdk/tools/figma/__init__.py +52 -6
- alita_sdk/tools/figma/api_wrapper.py +1158 -123
- alita_sdk/tools/figma/figma_client.py +73 -0
- alita_sdk/tools/figma/toon_tools.py +2748 -0
- alita_sdk/tools/github/__init__.py +2 -1
- alita_sdk/tools/github/github_client.py +56 -92
- alita_sdk/tools/github/schemas.py +4 -4
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab/api_wrapper.py +118 -38
- alita_sdk/tools/gitlab_org/__init__.py +2 -1
- alita_sdk/tools/gitlab_org/api_wrapper.py +60 -62
- alita_sdk/tools/google/bigquery/__init__.py +2 -1
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +2 -1
- alita_sdk/tools/keycloak/__init__.py +2 -1
- alita_sdk/tools/localgit/__init__.py +2 -1
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/ocr/__init__.py +2 -1
- alita_sdk/tools/openapi/__init__.py +227 -15
- alita_sdk/tools/openapi/api_wrapper.py +1287 -802
- alita_sdk/tools/pandas/__init__.py +11 -5
- alita_sdk/tools/pandas/api_wrapper.py +38 -25
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -1
- alita_sdk/tools/qtest/__init__.py +21 -2
- alita_sdk/tools/qtest/api_wrapper.py +430 -13
- alita_sdk/tools/rally/__init__.py +2 -1
- alita_sdk/tools/rally/api_wrapper.py +1 -1
- alita_sdk/tools/report_portal/__init__.py +2 -1
- alita_sdk/tools/salesforce/__init__.py +2 -1
- alita_sdk/tools/servicenow/__init__.py +2 -1
- alita_sdk/tools/sharepoint/__init__.py +2 -1
- alita_sdk/tools/sharepoint/api_wrapper.py +2 -2
- alita_sdk/tools/slack/__init__.py +3 -2
- alita_sdk/tools/slack/api_wrapper.py +2 -2
- alita_sdk/tools/sql/__init__.py +3 -2
- alita_sdk/tools/testio/__init__.py +2 -1
- alita_sdk/tools/testrail/__init__.py +2 -1
- alita_sdk/tools/utils/content_parser.py +77 -3
- alita_sdk/tools/utils/text_operations.py +163 -71
- alita_sdk/tools/xray/__init__.py +3 -2
- alita_sdk/tools/yagmail/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +2 -1
- alita_sdk/tools/zephyr_essential/__init__.py +2 -1
- alita_sdk/tools/zephyr_scale/__init__.py +3 -2
- alita_sdk/tools/zephyr_scale/api_wrapper.py +2 -2
- alita_sdk/tools/zephyr_squad/__init__.py +2 -1
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/METADATA +7 -6
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/RECORD +116 -111
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/entry_points.txt +0 -0
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.554.dist-info → alita_sdk-0.3.602.dist-info}/top_level.txt +0 -0
|
@@ -8,7 +8,7 @@ from langchain_core.tools import ToolException
|
|
|
8
8
|
from pydantic import model_validator, PrivateAttr, create_model, SecretStr
|
|
9
9
|
from pydantic.fields import Field
|
|
10
10
|
|
|
11
|
-
from ..elitea_base import BaseToolApiWrapper
|
|
11
|
+
from ..elitea_base import BaseToolApiWrapper, BaseCodeToolApiWrapper
|
|
12
12
|
from ..gitlab.utils import get_diff_w_position, get_position
|
|
13
13
|
|
|
14
14
|
logger = logging.getLogger(__name__)
|
|
@@ -24,7 +24,7 @@ GitLabCreateBranch = create_model(
|
|
|
24
24
|
GitLabListBranches = create_model(
|
|
25
25
|
"GitLabListBranchesModel",
|
|
26
26
|
repository=(Optional[str], Field(description="Name of the repository", default=None)),
|
|
27
|
-
limit=(Optional[int], Field(description="Maximum number of branches to return. If not provided, all branches will be returned.", default=20)),
|
|
27
|
+
limit=(Optional[int], Field(description="Maximum number of branches to return. If not provided, all branches will be returned.", default=20, gt=0)),
|
|
28
28
|
branch_wildcard=(Optional[str], Field(description="Wildcard pattern to filter branches by name. If not provided, all branches will be returned.", default=None))
|
|
29
29
|
)
|
|
30
30
|
|
|
@@ -159,6 +159,9 @@ class GitLabWorkspaceAPIWrapper(BaseToolApiWrapper):
|
|
|
159
159
|
repo_instances: Dict[str, Any] = {}
|
|
160
160
|
_active_branch: Optional[str] = PrivateAttr(default='main')
|
|
161
161
|
|
|
162
|
+
# Reuse common file helpers from BaseCodeToolApiWrapper where applicable
|
|
163
|
+
edit_file = BaseCodeToolApiWrapper.edit_file
|
|
164
|
+
|
|
162
165
|
class Config:
|
|
163
166
|
arbitrary_types_allowed = True
|
|
164
167
|
|
|
@@ -371,51 +374,76 @@ class GitLabWorkspaceAPIWrapper(BaseToolApiWrapper):
|
|
|
371
374
|
except Exception as e:
|
|
372
375
|
return ToolException(e)
|
|
373
376
|
|
|
374
|
-
def
|
|
375
|
-
"""
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
377
|
+
def _read_file(self, file_path: str, branch: str, **kwargs) -> str:
|
|
378
|
+
"""
|
|
379
|
+
Internal read_file used by BaseCodeToolApiWrapper.edit_file.
|
|
380
|
+
Delegates to the public `read_file` implementation which supports an optional repository argument.
|
|
381
|
+
The repository may be passed via kwargs or provided earlier through `update_file` which sets
|
|
382
|
+
a temporary attribute `_tmp_repository_for_edit`.
|
|
383
|
+
"""
|
|
384
|
+
# Repository from temporary context, then None
|
|
385
|
+
repository = getattr(self, "_tmp_repository_for_edit", None)
|
|
386
|
+
try:
|
|
387
|
+
# Public read_file signature: read_file(file_path, branch, repository=None)
|
|
388
|
+
return self.read_file(file_path, branch, repository)
|
|
389
|
+
except Exception as e:
|
|
390
|
+
raise ToolException(f"Can't extract file content (`{file_path}`) due to error:\n{str(e)}")
|
|
391
|
+
|
|
392
|
+
def _write_file(self, file_path: str, content: str, branch: str = None, commit_message: str = None) -> str:
|
|
393
|
+
"""
|
|
394
|
+
Write content to a file (update only) in the specified GitLab repository.
|
|
395
|
+
|
|
396
|
+
This implementation follows the same commit flow as the previous `update_file`:
|
|
397
|
+
it does not attempt to create the file when it is missing — it will always
|
|
398
|
+
create a commit with a single `update` action. If the file does not exist on
|
|
399
|
+
the target branch, the underlying GitLab API will typically return an error.
|
|
389
400
|
"""
|
|
390
401
|
try:
|
|
402
|
+
branch = branch if branch else (self._active_branch if self._active_branch else self.branch)
|
|
403
|
+
# pick repository from temporary edit context
|
|
404
|
+
repository = getattr(self, "_tmp_repository_for_edit", None)
|
|
391
405
|
repo_instance = self._get_repo(repository)
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
for old, new in self.extract_old_new_pairs(update_query):
|
|
395
|
-
if not old.strip():
|
|
396
|
-
continue
|
|
397
|
-
updated_file_content = updated_file_content.replace(old, new)
|
|
398
|
-
if file_content == updated_file_content:
|
|
399
|
-
return (
|
|
400
|
-
"File content was not updated because old content was not found or empty."
|
|
401
|
-
"It may be helpful to use the read_file action to get "
|
|
402
|
-
"the current file contents."
|
|
403
|
-
)
|
|
406
|
+
|
|
407
|
+
# Always perform an 'update' action commit (do not create file when missing)
|
|
404
408
|
commit = {
|
|
405
409
|
"branch": branch,
|
|
406
|
-
"commit_message": "Update "
|
|
410
|
+
"commit_message": commit_message or f"Update {file_path}",
|
|
407
411
|
"actions": [
|
|
408
412
|
{
|
|
409
413
|
"action": "update",
|
|
410
414
|
"file_path": file_path,
|
|
411
|
-
"content":
|
|
415
|
+
"content": content,
|
|
412
416
|
}
|
|
413
417
|
],
|
|
414
418
|
}
|
|
415
419
|
repo_instance.commits.create(commit)
|
|
416
|
-
return "Updated file "
|
|
420
|
+
return f"Updated file {file_path}"
|
|
421
|
+
except ToolException:
|
|
422
|
+
raise
|
|
423
|
+
except Exception as e:
|
|
424
|
+
return ToolException(f"Unable to write file due to error: {str(e)}")
|
|
425
|
+
|
|
426
|
+
def update_file(self, file_path: str, update_query: str, branch: str, repository: Optional[str] = None) -> str:
|
|
427
|
+
"""Updates a file with new content using OLD/NEW markers by delegating to `edit_file`.
|
|
428
|
+
|
|
429
|
+
The method sets a temporary repository context so that `edit_file`'s internal
|
|
430
|
+
calls to `_read_file` and `_write_file` operate on the requested repository.
|
|
431
|
+
"""
|
|
432
|
+
# Set temporary repository context used by _read_file/_write_file
|
|
433
|
+
self._tmp_repository_for_edit = repository
|
|
434
|
+
try:
|
|
435
|
+
commit_message = f"Update {file_path}"
|
|
436
|
+
return self.edit_file(file_path=file_path, file_query=update_query, branch=branch, commit_message=commit_message)
|
|
437
|
+
except ToolException as e:
|
|
438
|
+
return str(e)
|
|
417
439
|
except Exception as e:
|
|
418
440
|
return ToolException(f"Unable to update file due to error: {str(e)}")
|
|
441
|
+
finally:
|
|
442
|
+
# Clear temporary context
|
|
443
|
+
try:
|
|
444
|
+
delattr(self, "_tmp_repository_for_edit")
|
|
445
|
+
except Exception:
|
|
446
|
+
self._tmp_repository_for_edit = None
|
|
419
447
|
|
|
420
448
|
def delete_file(self, file_path: str, branch: str, repository: Optional[str] = None) -> str:
|
|
421
449
|
"""Deletes a file from the repo."""
|
|
@@ -428,36 +456,6 @@ class GitLabWorkspaceAPIWrapper(BaseToolApiWrapper):
|
|
|
428
456
|
except Exception as e:
|
|
429
457
|
return ToolException(f"Unable to delete file due to error: {str(e)}")
|
|
430
458
|
|
|
431
|
-
def extract_old_new_pairs(self, file_query):
|
|
432
|
-
"""Extract old and new content pairs from the file query."""
|
|
433
|
-
code_lines = file_query.split("\n")
|
|
434
|
-
old_contents = []
|
|
435
|
-
new_contents = []
|
|
436
|
-
in_old_section = False
|
|
437
|
-
in_new_section = False
|
|
438
|
-
current_section_content = []
|
|
439
|
-
for line in code_lines:
|
|
440
|
-
if "OLD <<<" in line:
|
|
441
|
-
in_old_section = True
|
|
442
|
-
current_section_content = []
|
|
443
|
-
continue
|
|
444
|
-
if ">>>> OLD" in line:
|
|
445
|
-
in_old_section = False
|
|
446
|
-
old_contents.append("\n".join(current_section_content).strip())
|
|
447
|
-
current_section_content = []
|
|
448
|
-
continue
|
|
449
|
-
if "NEW <<<" in line:
|
|
450
|
-
in_new_section = True
|
|
451
|
-
current_section_content = []
|
|
452
|
-
continue
|
|
453
|
-
if ">>>> NEW" in line:
|
|
454
|
-
in_new_section = False
|
|
455
|
-
new_contents.append("\n".join(current_section_content).strip())
|
|
456
|
-
current_section_content = []
|
|
457
|
-
continue
|
|
458
|
-
if in_old_section or in_new_section:
|
|
459
|
-
current_section_content.append(line)
|
|
460
|
-
return list(zip(old_contents, new_contents))
|
|
461
459
|
|
|
462
460
|
def append_file(self, file_path: str, content: str, branch: str, repository: Optional[str] = None) -> str:
|
|
463
461
|
"""
|
|
@@ -8,6 +8,7 @@ from ....configurations.bigquery import BigQueryConfiguration
|
|
|
8
8
|
from ...utils import clean_string, get_max_toolkit_length
|
|
9
9
|
from .api_wrapper import BigQueryApiWrapper
|
|
10
10
|
from .tool import BigQueryAction
|
|
11
|
+
from ....runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
11
12
|
|
|
12
13
|
name = "bigquery"
|
|
13
14
|
|
|
@@ -129,7 +130,7 @@ class BigQueryToolkit(BaseToolkit):
|
|
|
129
130
|
name=t["name"],
|
|
130
131
|
description=description,
|
|
131
132
|
args_schema=t["args_schema"],
|
|
132
|
-
metadata={
|
|
133
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: t["name"]} if toolkit_name else {TOOL_NAME_META: t["name"]}
|
|
133
134
|
)
|
|
134
135
|
)
|
|
135
136
|
return instance
|
|
@@ -8,6 +8,7 @@ from ..base.tool import BaseAction
|
|
|
8
8
|
from ..elitea_base import filter_missconfigured_index_tools
|
|
9
9
|
from ..utils import clean_string, get_max_toolkit_length
|
|
10
10
|
from ...configurations.google_places import GooglePlacesConfiguration
|
|
11
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
11
12
|
|
|
12
13
|
name = "google_places"
|
|
13
14
|
|
|
@@ -67,7 +68,7 @@ class GooglePlacesToolkit(BaseToolkit):
|
|
|
67
68
|
name=tool["name"],
|
|
68
69
|
description=description,
|
|
69
70
|
args_schema=tool["args_schema"],
|
|
70
|
-
metadata={
|
|
71
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool["name"]} if toolkit_name else {TOOL_NAME_META: tool["name"]}
|
|
71
72
|
))
|
|
72
73
|
return cls(tools=tools)
|
|
73
74
|
|
alita_sdk/tools/jira/__init__.py
CHANGED
|
@@ -9,6 +9,7 @@ from ..elitea_base import filter_missconfigured_index_tools
|
|
|
9
9
|
from ..utils import clean_string, get_max_toolkit_length, parse_list, check_connection_response
|
|
10
10
|
from ...configurations.jira import JiraConfiguration
|
|
11
11
|
from ...configurations.pgvector import PgVectorConfiguration
|
|
12
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOLKIT_TYPE_META, TOOL_NAME_META
|
|
12
13
|
|
|
13
14
|
name = "jira"
|
|
14
15
|
|
|
@@ -126,7 +127,7 @@ class JiraToolkit(BaseToolkit):
|
|
|
126
127
|
name=tool["name"],
|
|
127
128
|
description=description,
|
|
128
129
|
args_schema=tool["args_schema"],
|
|
129
|
-
metadata={
|
|
130
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool["name"]} if toolkit_name else {TOOL_NAME_META: tool["name"]}
|
|
130
131
|
))
|
|
131
132
|
return cls(tools=tools)
|
|
132
133
|
|
|
@@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict, create_model, Field, SecretStr
|
|
|
6
6
|
from .api_wrapper import KeycloakApiWrapper
|
|
7
7
|
from ..base.tool import BaseAction
|
|
8
8
|
from ..utils import clean_string, get_max_toolkit_length
|
|
9
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
9
10
|
|
|
10
11
|
name = "keycloak"
|
|
11
12
|
|
|
@@ -54,7 +55,7 @@ class KeycloakToolkit(BaseToolkit):
|
|
|
54
55
|
name=tool["name"],
|
|
55
56
|
description=description,
|
|
56
57
|
args_schema=tool["args_schema"],
|
|
57
|
-
metadata={
|
|
58
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool["name"]} if toolkit_name else {TOOL_NAME_META: tool["name"]}
|
|
58
59
|
))
|
|
59
60
|
return cls(tools=tools)
|
|
60
61
|
|
|
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, create_model, Field
|
|
|
5
5
|
|
|
6
6
|
from .local_git import LocalGit
|
|
7
7
|
from .tool import LocalGitAction
|
|
8
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
8
9
|
|
|
9
10
|
name = "localgit"
|
|
10
11
|
|
|
@@ -55,7 +56,7 @@ class AlitaLocalGitToolkit(BaseToolkit):
|
|
|
55
56
|
mode=tool["mode"],
|
|
56
57
|
description=description,
|
|
57
58
|
args_schema=tool["args_schema"],
|
|
58
|
-
metadata={
|
|
59
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool["name"]} if toolkit_name else {TOOL_NAME_META: tool["name"]}
|
|
59
60
|
))
|
|
60
61
|
return cls(tools=tools)
|
|
61
62
|
|
|
@@ -118,7 +118,7 @@ class MemoryToolkit(BaseToolkit):
|
|
|
118
118
|
# Add metadata to tools if toolkit_name is provided
|
|
119
119
|
if toolkit_name:
|
|
120
120
|
for tool in tools:
|
|
121
|
-
tool.metadata = {"toolkit_name": toolkit_name}
|
|
121
|
+
tool.metadata = {"toolkit_name": toolkit_name, "toolkit_type": name}
|
|
122
122
|
|
|
123
123
|
return cls(tools=tools)
|
|
124
124
|
|
alita_sdk/tools/ocr/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from pydantic import create_model, BaseModel, ConfigDict, Field
|
|
|
6
6
|
from .api_wrapper import OCRApiWrapper
|
|
7
7
|
from ..base.tool import BaseAction
|
|
8
8
|
from ..utils import clean_string, get_max_toolkit_length
|
|
9
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
9
10
|
|
|
10
11
|
name = "ocr"
|
|
11
12
|
|
|
@@ -59,7 +60,7 @@ class OCRToolkit(BaseToolkit):
|
|
|
59
60
|
name=tool["name"],
|
|
60
61
|
description=description,
|
|
61
62
|
args_schema=tool["args_schema"],
|
|
62
|
-
metadata={
|
|
63
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool["name"]} if toolkit_name else {TOOL_NAME_META: tool["name"]}
|
|
63
64
|
))
|
|
64
65
|
return cls(tools=tools)
|
|
65
66
|
|
|
@@ -1,19 +1,218 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import base64
|
|
3
4
|
import json
|
|
4
|
-
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
from urllib.parse import urlparse
|
|
5
10
|
|
|
6
11
|
from langchain_core.tools import BaseTool, BaseToolkit
|
|
7
12
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
13
|
+
import requests
|
|
8
14
|
import yaml
|
|
9
15
|
|
|
10
16
|
from .api_wrapper import _get_base_url_from_spec, build_wrapper
|
|
11
17
|
from .tool import OpenApiAction
|
|
12
18
|
from ..elitea_base import filter_missconfigured_index_tools
|
|
13
19
|
from ...configurations.openapi import OpenApiConfiguration
|
|
20
|
+
from ...runtime.utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META, TOOLKIT_TYPE_META
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
14
23
|
|
|
15
24
|
name = 'openapi'
|
|
16
25
|
|
|
26
|
+
# Module-level token cache: {cache_key: (access_token, expires_at_timestamp)}
|
|
27
|
+
# Protected by _oauth_token_cache_lock for thread-safe access
|
|
28
|
+
_oauth_token_cache: Dict[str, Tuple[str, float]] = {}
|
|
29
|
+
_oauth_token_cache_lock = threading.Lock()
|
|
30
|
+
|
|
31
|
+
# Token expiry buffer in seconds (refresh 60 seconds before actual expiry)
|
|
32
|
+
_TOKEN_EXPIRY_BUFFER = 60
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_oauth_cache_key(client_id: str, token_url: str, scope: Optional[str]) -> str:
|
|
36
|
+
"""Generate a cache key for OAuth tokens."""
|
|
37
|
+
return f"{client_id}:{token_url}:{scope or ''}"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_cached_token(cache_key: str) -> Optional[str]:
|
|
41
|
+
"""Get a cached token if it exists and is not expired. Thread-safe."""
|
|
42
|
+
with _oauth_token_cache_lock:
|
|
43
|
+
if cache_key not in _oauth_token_cache:
|
|
44
|
+
return None
|
|
45
|
+
token, expires_at = _oauth_token_cache[cache_key]
|
|
46
|
+
if time.time() >= expires_at - _TOKEN_EXPIRY_BUFFER:
|
|
47
|
+
# Token expired or about to expire
|
|
48
|
+
del _oauth_token_cache[cache_key]
|
|
49
|
+
return None
|
|
50
|
+
return token
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _cache_token(cache_key: str, token: str, expires_in: Optional[int]) -> None:
|
|
54
|
+
"""Cache a token with its expiry time. Thread-safe."""
|
|
55
|
+
# Default to 1 hour if expires_in not provided
|
|
56
|
+
expires_in = expires_in or 3600
|
|
57
|
+
expires_at = time.time() + expires_in
|
|
58
|
+
with _oauth_token_cache_lock:
|
|
59
|
+
_oauth_token_cache[cache_key] = (token, expires_at)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _obtain_oauth_token(
|
|
63
|
+
client_id: str,
|
|
64
|
+
client_secret: str,
|
|
65
|
+
token_url: str,
|
|
66
|
+
scope: Optional[str] = None,
|
|
67
|
+
method: str = 'default',
|
|
68
|
+
timeout: int = 30,
|
|
69
|
+
) -> Tuple[str, Optional[str]]:
|
|
70
|
+
"""
|
|
71
|
+
Obtain an OAuth2 access token using client credentials flow.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
client_id: OAuth client ID
|
|
75
|
+
client_secret: OAuth client secret
|
|
76
|
+
token_url: OAuth token endpoint URL
|
|
77
|
+
scope: Optional OAuth scope(s), space-separated if multiple
|
|
78
|
+
method: Token exchange method - 'default' (POST body) or 'Basic' (Basic auth header)
|
|
79
|
+
timeout: Request timeout in seconds
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
Tuple of (access_token, error_message)
|
|
83
|
+
On success: (token, None)
|
|
84
|
+
On failure: (None, error_message)
|
|
85
|
+
"""
|
|
86
|
+
try:
|
|
87
|
+
headers = {
|
|
88
|
+
'Content-Type': 'application/x-www-form-urlencoded',
|
|
89
|
+
'Accept': 'application/json',
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
# Build form data
|
|
93
|
+
data: Dict[str, str] = {
|
|
94
|
+
'grant_type': 'client_credentials',
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
if method == 'Basic':
|
|
98
|
+
# Use Basic auth header for client credentials
|
|
99
|
+
credentials = f"{client_id}:{client_secret}"
|
|
100
|
+
encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
|
|
101
|
+
headers['Authorization'] = f'Basic {encoded_credentials}'
|
|
102
|
+
else:
|
|
103
|
+
# Default: include credentials in POST body
|
|
104
|
+
data['client_id'] = client_id
|
|
105
|
+
data['client_secret'] = client_secret
|
|
106
|
+
|
|
107
|
+
if scope:
|
|
108
|
+
data['scope'] = scope
|
|
109
|
+
|
|
110
|
+
# Log only the domain to avoid exposing sensitive path parameters (e.g., tenant IDs)
|
|
111
|
+
token_domain = urlparse(token_url).netloc or 'unknown'
|
|
112
|
+
logger.debug(f"OAuth token request to {token_domain} using method '{method}'")
|
|
113
|
+
|
|
114
|
+
response = requests.post(
|
|
115
|
+
token_url,
|
|
116
|
+
headers=headers,
|
|
117
|
+
data=data,
|
|
118
|
+
timeout=timeout,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
if response.status_code == 200:
|
|
122
|
+
try:
|
|
123
|
+
token_data = response.json()
|
|
124
|
+
access_token = token_data.get('access_token')
|
|
125
|
+
if not access_token:
|
|
126
|
+
return None, "OAuth response did not contain 'access_token'"
|
|
127
|
+
|
|
128
|
+
# Cache the token
|
|
129
|
+
cache_key = _get_oauth_cache_key(client_id, token_url, scope)
|
|
130
|
+
expires_in = token_data.get('expires_in')
|
|
131
|
+
_cache_token(cache_key, access_token, expires_in)
|
|
132
|
+
|
|
133
|
+
logger.debug(f"OAuth token obtained successfully (expires_in: {expires_in})")
|
|
134
|
+
return access_token, None
|
|
135
|
+
except json.JSONDecodeError as e:
|
|
136
|
+
return None, f"Failed to parse OAuth token response as JSON: {e}"
|
|
137
|
+
|
|
138
|
+
# Handle error responses
|
|
139
|
+
error_msg = f"OAuth token request failed with status {response.status_code}"
|
|
140
|
+
try:
|
|
141
|
+
error_data = response.json()
|
|
142
|
+
if 'error' in error_data:
|
|
143
|
+
error_msg = f"{error_msg}: {error_data.get('error')}"
|
|
144
|
+
if 'error_description' in error_data:
|
|
145
|
+
error_msg = f"{error_msg} - {error_data.get('error_description')}"
|
|
146
|
+
except Exception:
|
|
147
|
+
if response.text:
|
|
148
|
+
error_msg = f"{error_msg}: {response.text[:500]}"
|
|
149
|
+
|
|
150
|
+
return None, error_msg
|
|
151
|
+
|
|
152
|
+
except requests.exceptions.Timeout:
|
|
153
|
+
return None, f"OAuth token request to {token_url} timed out"
|
|
154
|
+
except requests.exceptions.ConnectionError as e:
|
|
155
|
+
return None, f"Failed to connect to OAuth token endpoint {token_url}: {e}"
|
|
156
|
+
except requests.exceptions.RequestException as e:
|
|
157
|
+
return None, f"OAuth token request failed: {e}"
|
|
158
|
+
except Exception as e:
|
|
159
|
+
return None, f"Unexpected error during OAuth token exchange: {e}"
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _secret_to_str(value: Any) -> Optional[str]:
|
|
163
|
+
"""Convert a secret value to string, handling SecretStr and other types."""
|
|
164
|
+
if value is None:
|
|
165
|
+
return None
|
|
166
|
+
if hasattr(value, 'get_secret_value'):
|
|
167
|
+
try:
|
|
168
|
+
value = value.get_secret_value()
|
|
169
|
+
except Exception:
|
|
170
|
+
pass
|
|
171
|
+
if isinstance(value, str):
|
|
172
|
+
return value
|
|
173
|
+
return str(value)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _get_oauth_access_token(settings: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
|
|
177
|
+
"""
|
|
178
|
+
Get an OAuth access token from settings, using cache if available.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
settings: Dictionary containing OAuth configuration
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
Tuple of (access_token, error_message)
|
|
185
|
+
On success: (token, None)
|
|
186
|
+
On failure: (None, error_message)
|
|
187
|
+
If OAuth not configured: (None, None)
|
|
188
|
+
"""
|
|
189
|
+
client_id = settings.get('client_id')
|
|
190
|
+
client_secret = _secret_to_str(settings.get('client_secret'))
|
|
191
|
+
token_url = settings.get('token_url')
|
|
192
|
+
|
|
193
|
+
# Check if OAuth is configured
|
|
194
|
+
if not client_id or not client_secret or not token_url:
|
|
195
|
+
return None, None # OAuth not configured
|
|
196
|
+
|
|
197
|
+
scope = settings.get('scope')
|
|
198
|
+
method = settings.get('method', 'default') or 'default'
|
|
199
|
+
|
|
200
|
+
# Try to get cached token
|
|
201
|
+
cache_key = _get_oauth_cache_key(client_id, token_url, scope)
|
|
202
|
+
cached_token = _get_cached_token(cache_key)
|
|
203
|
+
if cached_token:
|
|
204
|
+
logger.debug("Using cached OAuth token")
|
|
205
|
+
return cached_token, None
|
|
206
|
+
|
|
207
|
+
# Obtain new token
|
|
208
|
+
return _obtain_oauth_token(
|
|
209
|
+
client_id=client_id,
|
|
210
|
+
client_secret=client_secret,
|
|
211
|
+
token_url=token_url,
|
|
212
|
+
scope=scope,
|
|
213
|
+
method=method,
|
|
214
|
+
)
|
|
215
|
+
|
|
17
216
|
|
|
18
217
|
def get_toolkit(tool) -> BaseToolkit:
|
|
19
218
|
settings = tool.get('settings', {}) or {}
|
|
@@ -221,7 +420,7 @@ class AlitaOpenAPIToolkit(BaseToolkit):
|
|
|
221
420
|
name=tool_def['name'],
|
|
222
421
|
description=description,
|
|
223
422
|
args_schema=tool_def.get('args_schema'),
|
|
224
|
-
metadata={
|
|
423
|
+
metadata={TOOLKIT_NAME_META: toolkit_name, TOOLKIT_TYPE_META: name, TOOL_NAME_META: tool_def["name"]} if toolkit_name else {TOOL_NAME_META: tool_def["name"]},
|
|
225
424
|
)
|
|
226
425
|
)
|
|
227
426
|
|
|
@@ -249,22 +448,35 @@ def _coerce_selected_tool_names(selected_tools: Any) -> list[str]:
|
|
|
249
448
|
return []
|
|
250
449
|
|
|
251
450
|
|
|
252
|
-
def _secret_to_str(value: Any) -> Optional[str]:
|
|
253
|
-
if value is None:
|
|
254
|
-
return None
|
|
255
|
-
if hasattr(value, 'get_secret_value'):
|
|
256
|
-
try:
|
|
257
|
-
value = value.get_secret_value()
|
|
258
|
-
except Exception:
|
|
259
|
-
pass
|
|
260
|
-
if isinstance(value, str):
|
|
261
|
-
return value
|
|
262
|
-
return str(value)
|
|
263
|
-
|
|
264
|
-
|
|
265
451
|
def _build_headers_from_settings(settings: Dict[str, Any]) -> Dict[str, str]:
|
|
452
|
+
"""
|
|
453
|
+
Build HTTP headers from settings, supporting API key and OAuth authentication.
|
|
454
|
+
|
|
455
|
+
Authentication priority:
|
|
456
|
+
1. OAuth (client credentials flow) - if client_id, client_secret, and token_url are provided
|
|
457
|
+
2. API Key - if api_key is provided
|
|
458
|
+
3. Legacy authentication structure (for backward compatibility)
|
|
459
|
+
|
|
460
|
+
Args:
|
|
461
|
+
settings: Dictionary containing authentication settings
|
|
462
|
+
|
|
463
|
+
Returns:
|
|
464
|
+
Dictionary of HTTP headers to include in requests
|
|
465
|
+
"""
|
|
266
466
|
headers: Dict[str, str] = {}
|
|
267
467
|
|
|
468
|
+
# First, try OAuth authentication (client credentials flow)
|
|
469
|
+
# This takes priority because it's more secure and commonly used with modern APIs
|
|
470
|
+
oauth_token, oauth_error = _get_oauth_access_token(settings)
|
|
471
|
+
if oauth_token:
|
|
472
|
+
headers['Authorization'] = f'Bearer {oauth_token}'
|
|
473
|
+
logger.debug("Using OAuth Bearer token for authentication")
|
|
474
|
+
return headers
|
|
475
|
+
elif oauth_error:
|
|
476
|
+
# OAuth was configured but failed - log the error
|
|
477
|
+
# We'll still try API key auth as fallback
|
|
478
|
+
logger.warning(f"OAuth token exchange failed: {oauth_error}")
|
|
479
|
+
|
|
268
480
|
# Legacy structure used by the custom OpenAPI UI
|
|
269
481
|
auth = settings.get('authentication')
|
|
270
482
|
if isinstance(auth, dict) and auth.get('type') == 'api_key':
|