alita-sdk 0.3.263__py3-none-any.whl → 0.3.499__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +5 -0
- alita_sdk/cli/agent/default.py +258 -0
- alita_sdk/cli/agent_executor.py +155 -0
- alita_sdk/cli/agent_loader.py +215 -0
- alita_sdk/cli/agent_ui.py +228 -0
- alita_sdk/cli/agents.py +3601 -0
- alita_sdk/cli/callbacks.py +647 -0
- alita_sdk/cli/cli.py +168 -0
- alita_sdk/cli/config.py +306 -0
- alita_sdk/cli/context/__init__.py +30 -0
- alita_sdk/cli/context/cleanup.py +198 -0
- alita_sdk/cli/context/manager.py +731 -0
- alita_sdk/cli/context/message.py +285 -0
- alita_sdk/cli/context/strategies.py +289 -0
- alita_sdk/cli/context/token_estimation.py +127 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +419 -0
- alita_sdk/cli/inventory.py +1256 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +327 -0
- alita_sdk/cli/toolkit_loader.py +85 -0
- alita_sdk/cli/tools/__init__.py +43 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +1751 -0
- alita_sdk/cli/tools/planning.py +389 -0
- alita_sdk/cli/tools/terminal.py +414 -0
- alita_sdk/community/__init__.py +64 -8
- alita_sdk/community/inventory/__init__.py +224 -0
- alita_sdk/community/inventory/config.py +257 -0
- alita_sdk/community/inventory/enrichment.py +2137 -0
- alita_sdk/community/inventory/extractors.py +1469 -0
- alita_sdk/community/inventory/ingestion.py +3172 -0
- alita_sdk/community/inventory/knowledge_graph.py +1457 -0
- alita_sdk/community/inventory/parsers/__init__.py +218 -0
- alita_sdk/community/inventory/parsers/base.py +295 -0
- alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
- alita_sdk/community/inventory/parsers/go_parser.py +851 -0
- alita_sdk/community/inventory/parsers/html_parser.py +389 -0
- alita_sdk/community/inventory/parsers/java_parser.py +593 -0
- alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
- alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
- alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
- alita_sdk/community/inventory/parsers/python_parser.py +604 -0
- alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
- alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
- alita_sdk/community/inventory/parsers/text_parser.py +322 -0
- alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
- alita_sdk/community/inventory/patterns/__init__.py +61 -0
- alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
- alita_sdk/community/inventory/patterns/loader.py +348 -0
- alita_sdk/community/inventory/patterns/registry.py +198 -0
- alita_sdk/community/inventory/presets.py +535 -0
- alita_sdk/community/inventory/retrieval.py +1403 -0
- alita_sdk/community/inventory/toolkit.py +173 -0
- alita_sdk/community/inventory/visualize.py +1370 -0
- alita_sdk/configurations/__init__.py +10 -0
- alita_sdk/configurations/ado.py +4 -2
- alita_sdk/configurations/azure_search.py +1 -1
- alita_sdk/configurations/bigquery.py +1 -1
- alita_sdk/configurations/bitbucket.py +94 -2
- alita_sdk/configurations/browser.py +18 -0
- alita_sdk/configurations/carrier.py +19 -0
- alita_sdk/configurations/confluence.py +96 -1
- alita_sdk/configurations/delta_lake.py +1 -1
- alita_sdk/configurations/figma.py +0 -5
- alita_sdk/configurations/github.py +65 -1
- alita_sdk/configurations/gitlab.py +79 -0
- alita_sdk/configurations/google_places.py +17 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/postman.py +1 -1
- alita_sdk/configurations/qtest.py +1 -3
- alita_sdk/configurations/report_portal.py +19 -0
- alita_sdk/configurations/salesforce.py +19 -0
- alita_sdk/configurations/service_now.py +1 -12
- alita_sdk/configurations/sharepoint.py +19 -0
- alita_sdk/configurations/sonar.py +18 -0
- alita_sdk/configurations/sql.py +20 -0
- alita_sdk/configurations/testio.py +18 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +94 -1
- alita_sdk/configurations/zephyr_enterprise.py +94 -1
- alita_sdk/configurations/zephyr_essential.py +95 -0
- alita_sdk/runtime/clients/artifact.py +12 -2
- alita_sdk/runtime/clients/client.py +235 -66
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +373 -0
- alita_sdk/runtime/langchain/assistant.py +123 -17
- alita_sdk/runtime/langchain/constants.py +8 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +209 -31
- alita_sdk/runtime/langchain/document_loaders/AlitaImageLoader.py +1 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +8 -2
- alita_sdk/runtime/langchain/document_loaders/AlitaMarkdownLoader.py +66 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaPDFLoader.py +79 -10
- alita_sdk/runtime/langchain/document_loaders/AlitaPowerPointLoader.py +52 -15
- alita_sdk/runtime/langchain/document_loaders/AlitaPythonLoader.py +9 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaTableLoader.py +1 -4
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +15 -2
- alita_sdk/runtime/langchain/document_loaders/ImageParser.py +30 -0
- alita_sdk/runtime/langchain/document_loaders/constants.py +187 -40
- alita_sdk/runtime/langchain/interfaces/llm_processor.py +4 -2
- alita_sdk/runtime/langchain/langraph_agent.py +406 -91
- alita_sdk/runtime/langchain/utils.py +51 -8
- alita_sdk/runtime/llms/preloaded.py +2 -6
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/toolkits/__init__.py +26 -0
- alita_sdk/runtime/toolkits/application.py +9 -2
- alita_sdk/runtime/toolkits/artifact.py +19 -7
- alita_sdk/runtime/toolkits/datasource.py +13 -6
- alita_sdk/runtime/toolkits/mcp.py +780 -0
- alita_sdk/runtime/toolkits/planning.py +178 -0
- alita_sdk/runtime/toolkits/subgraph.py +11 -6
- alita_sdk/runtime/toolkits/tools.py +214 -60
- alita_sdk/runtime/toolkits/vectorstore.py +9 -4
- alita_sdk/runtime/tools/__init__.py +22 -0
- alita_sdk/runtime/tools/application.py +16 -4
- alita_sdk/runtime/tools/artifact.py +312 -19
- alita_sdk/runtime/tools/function.py +100 -4
- alita_sdk/runtime/tools/graph.py +81 -0
- alita_sdk/runtime/tools/image_generation.py +212 -0
- alita_sdk/runtime/tools/llm.py +539 -180
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +181 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
- alita_sdk/runtime/tools/planning/__init__.py +36 -0
- alita_sdk/runtime/tools/planning/models.py +246 -0
- alita_sdk/runtime/tools/planning/wrapper.py +607 -0
- alita_sdk/runtime/tools/router.py +2 -1
- alita_sdk/runtime/tools/sandbox.py +375 -0
- alita_sdk/runtime/tools/vectorstore.py +62 -63
- alita_sdk/runtime/tools/vectorstore_base.py +156 -85
- alita_sdk/runtime/utils/AlitaCallback.py +106 -20
- alita_sdk/runtime/utils/mcp_client.py +465 -0
- alita_sdk/runtime/utils/mcp_oauth.py +244 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
- alita_sdk/runtime/utils/mcp_tools_discovery.py +124 -0
- alita_sdk/runtime/utils/streamlit.py +41 -14
- alita_sdk/runtime/utils/toolkit_utils.py +28 -9
- alita_sdk/runtime/utils/utils.py +14 -0
- alita_sdk/tools/__init__.py +78 -35
- alita_sdk/tools/ado/__init__.py +0 -1
- alita_sdk/tools/ado/repos/__init__.py +10 -6
- alita_sdk/tools/ado/repos/repos_wrapper.py +12 -11
- alita_sdk/tools/ado/test_plan/__init__.py +10 -7
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +56 -23
- alita_sdk/tools/ado/wiki/__init__.py +10 -11
- alita_sdk/tools/ado/wiki/ado_wrapper.py +114 -28
- alita_sdk/tools/ado/work_item/__init__.py +10 -11
- alita_sdk/tools/ado/work_item/ado_wrapper.py +63 -10
- alita_sdk/tools/advanced_jira_mining/__init__.py +10 -7
- alita_sdk/tools/aws/delta_lake/__init__.py +13 -11
- alita_sdk/tools/azure_ai/search/__init__.py +11 -7
- alita_sdk/tools/base_indexer_toolkit.py +392 -86
- alita_sdk/tools/bitbucket/__init__.py +18 -11
- alita_sdk/tools/bitbucket/api_wrapper.py +52 -9
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +5 -5
- alita_sdk/tools/browser/__init__.py +40 -16
- alita_sdk/tools/browser/crawler.py +3 -1
- alita_sdk/tools/browser/utils.py +15 -6
- alita_sdk/tools/carrier/__init__.py +17 -17
- alita_sdk/tools/carrier/backend_reports_tool.py +8 -4
- alita_sdk/tools/carrier/excel_reporter.py +8 -4
- alita_sdk/tools/chunkers/__init__.py +3 -1
- alita_sdk/tools/chunkers/code/codeparser.py +1 -1
- alita_sdk/tools/chunkers/sematic/json_chunker.py +1 -0
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +97 -6
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/chunkers/universal_chunker.py +270 -0
- alita_sdk/tools/cloud/aws/__init__.py +9 -6
- alita_sdk/tools/cloud/azure/__init__.py +9 -6
- alita_sdk/tools/cloud/gcp/__init__.py +9 -6
- alita_sdk/tools/cloud/k8s/__init__.py +9 -6
- alita_sdk/tools/code/linter/__init__.py +7 -7
- alita_sdk/tools/code/loaders/codesearcher.py +3 -2
- alita_sdk/tools/code/sonar/__init__.py +18 -12
- alita_sdk/tools/code_indexer_toolkit.py +199 -0
- alita_sdk/tools/confluence/__init__.py +14 -11
- alita_sdk/tools/confluence/api_wrapper.py +198 -58
- alita_sdk/tools/confluence/loader.py +10 -0
- alita_sdk/tools/custom_open_api/__init__.py +9 -4
- alita_sdk/tools/elastic/__init__.py +8 -7
- alita_sdk/tools/elitea_base.py +543 -64
- alita_sdk/tools/figma/__init__.py +10 -8
- alita_sdk/tools/figma/api_wrapper.py +352 -153
- alita_sdk/tools/github/__init__.py +13 -11
- alita_sdk/tools/github/api_wrapper.py +9 -26
- alita_sdk/tools/github/github_client.py +75 -12
- alita_sdk/tools/github/schemas.py +2 -1
- alita_sdk/tools/gitlab/__init__.py +11 -10
- alita_sdk/tools/gitlab/api_wrapper.py +135 -45
- alita_sdk/tools/gitlab_org/__init__.py +11 -9
- alita_sdk/tools/google/bigquery/__init__.py +12 -13
- alita_sdk/tools/google_places/__init__.py +18 -10
- alita_sdk/tools/jira/__init__.py +14 -8
- alita_sdk/tools/jira/api_wrapper.py +315 -168
- alita_sdk/tools/keycloak/__init__.py +8 -7
- alita_sdk/tools/localgit/local_git.py +56 -54
- alita_sdk/tools/memory/__init__.py +27 -11
- alita_sdk/tools/non_code_indexer_toolkit.py +7 -2
- alita_sdk/tools/ocr/__init__.py +8 -7
- alita_sdk/tools/openapi/__init__.py +10 -1
- alita_sdk/tools/pandas/__init__.py +8 -7
- alita_sdk/tools/pandas/api_wrapper.py +7 -25
- alita_sdk/tools/postman/__init__.py +8 -10
- alita_sdk/tools/postman/api_wrapper.py +19 -8
- alita_sdk/tools/postman/postman_analysis.py +8 -1
- alita_sdk/tools/pptx/__init__.py +8 -9
- alita_sdk/tools/qtest/__init__.py +19 -13
- alita_sdk/tools/qtest/api_wrapper.py +1784 -88
- alita_sdk/tools/rally/__init__.py +10 -9
- alita_sdk/tools/report_portal/__init__.py +20 -15
- alita_sdk/tools/salesforce/__init__.py +19 -15
- alita_sdk/tools/servicenow/__init__.py +14 -11
- alita_sdk/tools/sharepoint/__init__.py +14 -13
- alita_sdk/tools/sharepoint/api_wrapper.py +179 -39
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +10 -7
- alita_sdk/tools/sql/__init__.py +19 -18
- alita_sdk/tools/sql/api_wrapper.py +71 -23
- alita_sdk/tools/testio/__init__.py +18 -12
- alita_sdk/tools/testrail/__init__.py +10 -10
- alita_sdk/tools/testrail/api_wrapper.py +213 -45
- alita_sdk/tools/utils/__init__.py +28 -4
- alita_sdk/tools/utils/content_parser.py +181 -61
- alita_sdk/tools/utils/text_operations.py +254 -0
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +83 -27
- alita_sdk/tools/xray/__init__.py +12 -7
- alita_sdk/tools/xray/api_wrapper.py +58 -113
- alita_sdk/tools/zephyr/__init__.py +9 -6
- alita_sdk/tools/zephyr_enterprise/__init__.py +13 -8
- alita_sdk/tools/zephyr_enterprise/api_wrapper.py +17 -7
- alita_sdk/tools/zephyr_essential/__init__.py +13 -9
- alita_sdk/tools/zephyr_essential/api_wrapper.py +289 -47
- alita_sdk/tools/zephyr_essential/client.py +6 -4
- alita_sdk/tools/zephyr_scale/__init__.py +10 -7
- alita_sdk/tools/zephyr_scale/api_wrapper.py +6 -2
- alita_sdk/tools/zephyr_squad/__init__.py +9 -6
- {alita_sdk-0.3.263.dist-info → alita_sdk-0.3.499.dist-info}/METADATA +180 -33
- alita_sdk-0.3.499.dist-info/RECORD +433 -0
- alita_sdk-0.3.499.dist-info/entry_points.txt +2 -0
- alita_sdk-0.3.263.dist-info/RECORD +0 -342
- {alita_sdk-0.3.263.dist-info → alita_sdk-0.3.499.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.263.dist-info → alita_sdk-0.3.499.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.263.dist-info → alita_sdk-0.3.499.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
4
|
+
|
|
5
|
+
from ..utils.utils import clean_string
|
|
6
|
+
from langchain_core.tools import BaseTool
|
|
7
|
+
from langchain_core.messages import BaseMessage, AIMessage, ToolCall
|
|
8
|
+
from typing import Any, Type, Optional, Union
|
|
9
|
+
from pydantic import create_model, field_validator, BaseModel
|
|
10
|
+
from pydantic.fields import FieldInfo
|
|
11
|
+
from ..langchain.mixedAgentRenderes import convert_message_to_json
|
|
12
|
+
from logging import getLogger
|
|
13
|
+
|
|
14
|
+
logger = getLogger(__name__)

# Field definitions for the graph tool's argument schema, kept in declaration
# order so the generated model lists ``input`` before ``chat_history``.
# NOTE(review): the chat_history description advertises role/content dicts,
# but the declared type is list[BaseMessage], so plain dicts will not pass
# validation — confirm which contract is intended.
_GRAPH_TOOL_FIELDS = {
    "input": (str, FieldInfo(description="User Input for Graph")),
    "chat_history": (
        Optional[list[BaseMessage]],
        FieldInfo(
            description="Chat History relevant for Graph in format [{'role': '<user| assistant | etc>', 'content': '<content of the respected message>'}]",
            default=[],
        ),
    ),
}

graphToolSchema = create_model("graphToolSchema", **_GRAPH_TOOL_FIELDS)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def formulate_query(kwargs):
    """Build the state dict handed to the compiled graph from tool kwargs.

    ``chat_history`` is normalized according to the type of its LAST element:

    * ``BaseMessage`` -> converted with ``convert_message_to_json``;
    * ``dict``        -> kept as-is when every entry is a dict with truthy
      ``role`` and ``content``; otherwise each entry is JSON-encoded into an
      ``AIMessage``;
    * ``str``         -> each entry is wrapped in an ``AIMessage``.

    Any other element type yields an empty history. When ``chat_history`` is
    empty or absent, a truthy ``messages`` entry is converted with
    ``convert_message_to_json`` instead.

    Args:
        kwargs: Tool arguments. ``input`` and ``chat_history`` are handled
            specially; every other key is copied through unchanged.

    Returns:
        ``{"input": ..., "chat_history": [...], **other_kwargs}``.
    """
    history = kwargs.get('chat_history')
    chat_history = []
    if history:
        last = history[-1]
        if isinstance(last, BaseMessage):
            chat_history = convert_message_to_json(history[:])
        elif isinstance(last, dict):
            # Guard with isinstance so a mixed list does not crash on .get();
            # such lists fall through to per-entry JSON encoding.
            if all(isinstance(message, dict)
                   and message.get('role')
                   and message.get('content')
                   for message in history):
                chat_history = history[:]
            else:
                chat_history = [AIMessage(json.dumps(each)) for each in history]
        elif isinstance(last, str):
            chat_history = [AIMessage(each) for each in history]
    elif kwargs.get('messages'):
        chat_history = convert_message_to_json(kwargs['messages'][:])

    result = {"input": kwargs.get('input'), "chat_history": chat_history}
    for key, value in kwargs.items():
        if key not in ("input", "chat_history"):
            result[key] = value
    return result
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GraphTool(BaseTool):
    """LangChain tool that runs a compiled LangGraph graph as one tool call.

    Tool arguments are normalized by ``formulate_query`` and passed to
    ``graph.invoke``. The graph's ``"output"`` entry is returned as-is when
    ``return_type == "str"``; for any other ``return_type`` it is wrapped in a
    single assistant message payload.
    """

    name: str = 'GraphTool'
    description: str = 'Graph tool for tools'
    graph: CompiledStateGraph
    args_schema: Type[BaseModel] = graphToolSchema
    # "str" -> return response["output"]; anything else -> {"messages": [...]}.
    return_type: str = "str"

    @field_validator('name', mode='before')
    @classmethod
    def remove_spaces(cls, v):
        """Normalize the incoming tool name with ``clean_string``."""
        return clean_string(v)

    def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
        """Run the graph while preserving every input field, not just ``args_schema``'s.

        Args:
            input: Tool arguments as a mapping. A bare string is treated as
                ``{"input": input}``.
            config: Optional runnable config forwarded to ``graph.invoke``;
                ``None`` is replaced with an empty dict.
            **kwargs: Extra keyword arguments; input fields take precedence.

        Returns:
            Whatever ``_run`` produces for the current ``return_type``.
        """
        if isinstance(input, str):
            # A bare string cannot be splatted with ``**``; treat it as the
            # graph's user input field.
            input = {"input": input}
        if self.args_schema:
            validated = self.args_schema(**input)
            # Read validated attributes directly instead of model_dump():
            # model_dump() recursively turns BaseMessage objects in
            # ``chat_history`` into plain dicts (without a 'role' key), which
            # formulate_query would then JSON-encode into AIMessages instead
            # of converting them as messages.
            schema_values = {
                field: getattr(validated, field)
                for field in type(validated).model_fields
            }
        else:
            schema_values = {}
        extras = {k: v for k, v in input.items() if k not in schema_values}
        # Precedence: validated schema values > extra input keys > kwargs.
        all_kwargs = {**kwargs, **extras, **schema_values}
        if config is None:
            config = {}
        # Pass the config to _run: empty, or the one passed from the parent executor.
        return self._run(config, **all_kwargs)

    def _run(self, *args, **kwargs):
        """Invoke the graph with the normalized query and shape its output.

        ``args[0]``, when present, is the runnable config (``invoke`` passes
        exactly one positional argument). A missing config stays ``None``,
        which is left for LangChain/LangGraph to handle.
        """
        config = args[0] if args else None
        response = self.graph.invoke(formulate_query(kwargs), config=config)
        if self.return_type == "str":
            return response["output"]
        return {"messages": [{"role": "assistant", "content": response["output"]}]}
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Image generation tool for Alita SDK.
|
|
3
|
+
"""
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional, Type, Any, List, Literal
|
|
6
|
+
from langchain_core.tools import BaseTool, BaseToolkit
|
|
7
|
+
from pydantic import BaseModel, Field, create_model, ConfigDict
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(name=__name__)

# Module-level toolkit identifier; the same value get_tools() matches against
# a tool config's 'type' / 'toolkit_name'.
name = 'image_generation'
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_tools(tools_list: list, alita_client=None, llm=None,
              memory_store=None):
    """
    Collect image generation tools for every matching tool configuration.

    A configuration matches when its 'type' or its 'toolkit_name' equals
    'image_generation'. Non-matching configurations are skipped.

    Args:
        tools_list: List of tool configuration dicts
        alita_client: Alita client instance; when falsy, each matching
            configuration is logged as an error and skipped
        llm: Accepted for interface compatibility; not used here
        memory_store: Accepted for interface compatibility; not used here

    Returns:
        A flat list of the tools produced by every matching configuration.

    Raises:
        Any exception raised while building a toolkit, after logging it
        together with the offending configuration.
    """
    collected = []

    for tool_config in tools_list:
        is_image_tool = (tool_config.get('type') == 'image_generation'
                         or tool_config.get('toolkit_name') == 'image_generation')
        if not is_image_tool:
            continue

        try:
            if not alita_client:
                logger.error("Alita client is required for image "
                             "generation tools")
                continue

            toolkit = ImageGenerationToolkit.get_toolkit(
                client=alita_client,
                toolkit_name=tool_config.get('toolkit_name', ''),
            )
            collected.extend(toolkit.get_tools())
        except Exception as exc:
            logger.error(f"Error in image generation toolkit "
                         f"get_tools: {exc}")
            logger.error(f"Tool config: {tool_config}")
            raise

    return collected
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ImageGenerationInput(BaseModel):
    """Arguments accepted by the ``generate_image`` tool."""

    # Free-text description of the desired image (required).
    prompt: str = Field(
        description="Text prompt describing the image to generate",
    )
    # Number of images to request; validated to the inclusive range 1..10.
    n: int = Field(
        default=1,
        ge=1,
        le=10,
        description="Number of images to generate (1-10)",
    )
    # Image dimensions string passed through to the client; defaults to "auto".
    size: str = Field(
        description="Size of the generated image (e.g., '1024x1024')",
        default="auto",
    )
    # Quality level passed through to the client; defaults to "auto".
    quality: str = Field(
        description="Quality of the generated image ('low', 'medium', 'high')",
        default="auto",
    )
    # Optional style hint; None when not supplied.
    style: Optional[str] = Field(
        description="Style of the generated image (optional)",
        default=None,
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ImageGenerationTool(BaseTool):
    """Tool for generating images using the Alita client."""

    name: str = "generate_image"
    description: str = "Generate images from text prompts using AI models"
    args_schema: Type[BaseModel] = ImageGenerationInput
    # Client object whose ``generate_image(prompt=, n=, size=, quality=, style=)``
    # method is called by ``_run``.
    alita_client: Any = None

    def __init__(self, client, **kwargs):
        """Create the tool bound to ``client`` (may be None, e.g. for schema introspection)."""
        super().__init__(**kwargs)
        self.alita_client = client

    def _run(self, prompt: str, n: int = 1, size: str = "auto",
             quality: str = "auto", style: Optional[str] = None) -> list:
        """Generate images and return them as multimodal content chunks.

        Returns:
            A list of content dicts: one leading ``"text"`` chunk followed by
            one ``"image_url"`` chunk per returned image. If the client result
            has no ``'data'`` key, or any error occurs, a single ``"text"``
            chunk describing the outcome is returned instead (errors are
            logged, not raised).
        """
        try:
            logger.info("Generating image with prompt: %s...", prompt[:50])

            result = self.alita_client.generate_image(
                prompt=prompt,
                n=n,
                size=size,
                quality=quality,
                style=style
            )

            # Return multimodal content format for LLM consumption
            if 'data' in result:
                return self._build_content_chunks(prompt, result['data'])

            # Fallback to text response if no images in result
            return [{
                "type": "text",
                "text": f"Image generation completed but no images "
                        f"returned: {result}"
            }]

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error generating image: %s", e)
            return [{
                "type": "text",
                "text": f"Error generating image: {str(e)}"
            }]

    def _build_content_chunks(self, prompt: str, images: list) -> list:
        """Convert the client's ``data`` list into text + image_url chunks."""
        if len(images) == 1:
            summary = f"Generated image for prompt: '{prompt}'"
        else:
            summary = (f"Generated {len(images)} images for "
                       f"prompt: '{prompt}'")
        content_chunks = [{"type": "text", "text": summary}]

        for image_data in images:
            if image_data.get('url'):
                content_chunks.append({
                    "type": "image_url",
                    "image_url": {
                        "url": image_data['url']
                    }
                })
            elif image_data.get('b64_json'):
                # NOTE(review): the MIME type is hard-coded as PNG; confirm the
                # backend always returns PNG for b64_json payloads.
                content_chunks.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,"
                               f"{image_data['b64_json']}"
                    }
                })

        return content_chunks

    async def _arun(self, prompt: str, n: int = 1, size: str = "auto",
                    quality: str = "auto",
                    style: Optional[str] = None) -> list:
        """Async version: run the blocking sync implementation in a worker thread.

        The ``size`` default matches ``_run`` and the input schema ("auto").
        Offloading keeps the blocking client call from stalling the event
        loop; ``asyncio.to_thread`` also propagates context variables.
        """
        import asyncio

        return await asyncio.to_thread(
            self._run, prompt, n, size, quality, style
        )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def create_image_generation_tool(client):
    """Return a new ``ImageGenerationTool`` bound to the given Alita client."""
    tool = ImageGenerationTool(client=client)
    return tool
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ImageGenerationToolkit(BaseToolkit):
    """Toolkit for image generation tools."""
    tools: List[BaseTool] = []

    @staticmethod
    def toolkit_config_schema() -> BaseModel:
        """Build the pydantic config model describing this toolkit.

        The model has one field, ``selected_tools``, restricted to the names
        of the available tools. It carries each tool's argument JSON schema
        plus metadata (label, icon, visibility, categories) in
        ``json_schema_extra``.
        """
        # A throwaway instance is used only to read the tool name and schema;
        # no client is required for that.
        sample_tool = ImageGenerationTool(client=None)
        # model_json_schema() is the pydantic v2 API; .schema() is a deprecated
        # alias returning the same dict but emitting a deprecation warning.
        selected_tools = {
            sample_tool.name: sample_tool.args_schema.model_json_schema()
        }

        return create_model(
            'image_generation',
            selected_tools=(
                List[Literal[tuple(selected_tools)]],
                Field(
                    default=[],
                    json_schema_extra={'args_schemas': selected_tools}
                )
            ),
            __config__=ConfigDict(json_schema_extra={
                'metadata': {
                    "label": "Image Generation",
                    "icon_url": "image_generation.svg",
                    "hidden": True,
                    "categories": ["internal_tool"],
                    "extra_categories": ["image generation"],
                }
            })
        )

    @classmethod
    def get_toolkit(cls, client=None, **kwargs):
        """
        Get toolkit with image generation tools.

        Args:
            client: Alita client instance (required)
            **kwargs: Additional arguments (accepted but not used, e.g.
                ``toolkit_name``)

        Returns:
            A toolkit instance holding a single ``ImageGenerationTool``.

        Raises:
            ValueError: If ``client`` is falsy.
        """
        if not client:
            raise ValueError("Alita client is required for image generation")

        return cls(tools=[ImageGenerationTool(client=client)])

    def get_tools(self):
        """Return the tools held by this toolkit."""
        return self.tools
|