sequrity 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sequrity/__init__.py +14 -0
- sequrity/_version.py +34 -0
- sequrity/client.py +35 -0
- sequrity/constants.py +27 -0
- sequrity/control/__init__.py +23 -0
- sequrity/control/chat_completion.py +116 -0
- sequrity/control/langgraph/__init__.py +3 -0
- sequrity/control/langgraph/graph_executor.py +282 -0
- sequrity/control/langgraph/run.py +184 -0
- sequrity/control/sqrt/__init__.py +3 -0
- sequrity/control/sqrt/grammar.lark +351 -0
- sequrity/control/sqrt/parser.py +209 -0
- sequrity/control/types/__init__.py +22 -0
- sequrity/control/types/headers/__init__.py +12 -0
- sequrity/control/types/headers/feature_headers.py +234 -0
- sequrity/control/types/headers/policy_headers.py +163 -0
- sequrity/control/types/headers/session_config_headers.py +228 -0
- sequrity/control/types/langgraph.py +20 -0
- sequrity/control/types/results.py +37 -0
- sequrity/control/types/value_with_meta.py +28 -0
- sequrity/control/wrapper.py +231 -0
- sequrity/service_provider.py +16 -0
- sequrity/types/__init__.py +0 -0
- sequrity/types/chat_completion/__init__.py +0 -0
- sequrity/types/chat_completion/request.py +377 -0
- sequrity/types/chat_completion/response.py +203 -0
- sequrity-0.0.1.dist-info/METADATA +64 -0
- sequrity-0.0.1.dist-info/RECORD +30 -0
- sequrity-0.0.1.dist-info/WHEEL +4 -0
- sequrity-0.0.1.dist-info/licenses/LICENSE +192 -0
sequrity/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Sequrity API Python client.
|
|
2
|
+
|
|
3
|
+
This package provides a Python client for interacting with the Sequrity API,
|
|
4
|
+
enabling secure LLM interactions with policy enforcement.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from sequrity.client import SequrityClient
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from sequrity._version import __version__
|
|
11
|
+
except ImportError:
|
|
12
|
+
__version__ = "0.0.0.dev0"
|
|
13
|
+
|
|
14
|
+
__all__ = ["SequrityClient", "__version__"]
|
sequrity/_version.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# setuptools-scm avoids importing `typing` at runtime: TYPE_CHECKING is a
# plain False constant here, so the aliases below resolve to real typing
# constructs only for static type checkers, and to bare `object` at runtime.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '0.0.1'
__version_tuple__ = version_tuple = (0, 0, 1)

__commit_id__ = commit_id = None
|
sequrity/client.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Sequrity API client module.
|
|
2
|
+
|
|
3
|
+
This module provides the main client class for interacting with the Sequrity API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from .constants import SEQURITY_API_URL
|
|
9
|
+
from .control.wrapper import ControlApiWrapper
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SequrityClient:
    """Main client for interacting with the Sequrity API.

    The SequrityClient provides a high-level interface for accessing Sequrity's
    security features, including chat completion with security policies and
    LangGraph integration.

    The client owns an ``httpx.Client`` connection pool. Call :meth:`close`
    when done, or use the instance as a context manager, to release pooled
    connections deterministically.

    Attributes:
        control: The Control API wrapper for chat completions and LangGraph operations.
    """

    def __init__(self, api_key: str, base_url: str | None = None, timeout: int = 300):
        """Initialize the Sequrity client.

        Args:
            api_key: Your Sequrity API key for authentication.
            base_url: Optional custom base URL for the Sequrity API.
                Defaults to the production Sequrity API URL.
            timeout: Request timeout in seconds. Defaults to 300.
        """
        self._api_key = api_key
        self._base_url = base_url or SEQURITY_API_URL
        self._client = httpx.Client(timeout=timeout)
        self.control = ControlApiWrapper(client=self._client, base_url=self._base_url, api_key=self._api_key)

    def close(self) -> None:
        """Close the underlying HTTP client and release pooled connections.

        Fix: the original never closed the ``httpx.Client``, leaking the
        connection pool for the lifetime of the process.
        """
        self._client.close()

    def __enter__(self) -> "SequrityClient":
        """Enter a context-manager scope; returns this client unchanged."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the client when the ``with`` block exits."""
        self.close()
|
sequrity/constants.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SequrityProductEnum(StrEnum):
|
|
5
|
+
CONTROL = "control"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Production endpoint for the hosted Sequrity API; used as the default
# base URL by SequrityClient.
SEQURITY_API_URL = "https://api.sequrity.ai"

# URL path templates for the Control API, keyed by operation then variant.
# Entries containing "{service_provider}" are templates and must be filled
# with str.format(service_provider=...) before building the request URL.
CONTROL_API_PATHS = {
    "chat_completions": {
        "default": "/control/v1/chat/completions",
        "with_service_provider": "/control/{service_provider}/v1/chat/completions",
    },
    "responses": {
        "default": "/control/v1/responses",
    },
    "generate_policy": {
        "default": "/control/v1/generate_policy",
    },
    "vscode_chat_completions": {
        "default": "/control/vscode/{service_provider}/v1/chat/completions",
    },
    "langgraph_chat_completions": {
        "default": "/control/lang-graph/{service_provider}/v1/chat/completions",
    },
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from .types import (
    ControlFlowMetaPolicy,
    FeaturesHeader,
    FineGrainedConfigHeader,
    InternalPolicyPreset,
    MetaData,
    ResponseContentJsonSchema,
    ResponseFormat,
    SecurityPolicyHeader,
    ValueWithMeta,
)

# Public re-exports: the control-plane type definitions that callers are
# expected to import from `sequrity.control` directly.
__all__ = [
    "FeaturesHeader",
    "FineGrainedConfigHeader",
    "MetaData",
    "ResponseContentJsonSchema",
    "SecurityPolicyHeader",
    "ValueWithMeta",
    "InternalPolicyPreset",
    "ResponseFormat",
    "ControlFlowMetaPolicy",
]
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from urllib.parse import urljoin
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
|
|
6
|
+
from ..constants import CONTROL_API_PATHS
|
|
7
|
+
from ..service_provider import LlmServiceProviderEnum
|
|
8
|
+
from ..types.chat_completion.request import ChatCompletionRequest, Message, ReasoningEffort, ResponseFormat, Tool
|
|
9
|
+
from ..types.chat_completion.response import ChatCompletionResponse
|
|
10
|
+
from .types.headers import FeaturesHeader, FineGrainedConfigHeader, SecurityPolicyHeader
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_chat_completion_sync(
    client: httpx.Client,
    base_url: str,
    api_key: str,
    messages: list[Message] | list[dict],
    model: str,
    llm_api_key: str | None,
    features: FeaturesHeader | None = None,
    security_policy: SecurityPolicyHeader | None = None,
    fine_grained_config: FineGrainedConfigHeader | None = None,
    provider: LlmServiceProviderEnum | Literal["default"] = "default",
    session_id: str | None = None,
    reasoning_effort: ReasoningEffort | None = None,
    response_format: ResponseFormat | None = None,
    seed: int | None = None,
    stream: bool | None = None,
    temperature: float | None = None,
    tools: list[Tool] | None = None,
    top_p: float | None = None,
    return_type: Literal["python", "json"] = "python",
) -> ChatCompletionResponse | dict:
    """Send a chat completion request to the Sequrity secure orchestrator.

    Args:
        client: HTTP client for making requests.
        base_url: Base URL of the Sequrity API.
        api_key: Sequrity API key for authentication.
        messages: List of chat messages.
        model: Model name to use for completion.
        llm_api_key: Optional LLM provider API key (uses server default if None).
        features: Security features to enable.
        security_policy: Security policy configuration.
        fine_grained_config: Fine-grained security settings.
        provider: LLM service provider (openai, openrouter, etc.).
        session_id: Explicit session ID for continuing an existing conversation.
            If None and no tool messages in the request, a new session is created.
        reasoning_effort: Reasoning effort level for supported models.
        response_format: Response format specification.
        seed: Random seed for reproducibility.
        stream: Whether to stream the response.
        temperature: Sampling temperature.
        tools: List of tools available to the model.
        top_p: Nucleus sampling parameter.
        return_type: Return as "python" (ChatCompletionResponse) or "json" (dict).

    Returns:
        ChatCompletionResponse if return_type="python", dict if return_type="json".

    Raises:
        ValueError: If return_type is not "python" or "json".
        httpx.HTTPStatusError: If the request fails.
    """
    # Fix: validate return_type up front. The original only checked it after
    # the network round-trip, so a bad argument cost a full API call (and a
    # billed completion) before failing.
    if return_type not in ("python", "json"):
        raise ValueError(f"Invalid return_type: {return_type}")

    # Construct the URL based on service provider.
    if provider == "default":
        path = CONTROL_API_PATHS["chat_completions"]["default"]
    else:
        path = CONTROL_API_PATHS["chat_completions"]["with_service_provider"].format(service_provider=provider)
    url = urljoin(base_url, path)

    # Validate the payload client-side via the request model, then drop unset
    # fields so the server only sees explicitly provided parameters.
    payload = ChatCompletionRequest.model_validate(
        {
            "messages": messages,
            "model": model,
            "reasoning_effort": reasoning_effort,
            "response_format": response_format,
            "seed": seed,
            "stream": stream,
            "temperature": temperature,
            "tools": tools,
            "top_p": top_p,
        }
    ).model_dump(exclude_none=True)

    # All security/session configuration travels in custom headers, not the body.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    if llm_api_key:
        headers["X-Api-Key"] = llm_api_key
    if features:
        headers["X-Security-Features"] = features.dump_for_headers(mode="json_str")
    if security_policy:
        headers["X-Security-Policy"] = security_policy.dump_for_headers(mode="json_str")
    if fine_grained_config:
        headers["X-Security-Config"] = fine_grained_config.dump_for_headers(mode="json_str")
    if session_id:
        headers["X-Session-Id"] = session_id

    # Make the HTTP request.
    http_response = client.post(url, json=payload, headers=headers)
    http_response.raise_for_status()
    response_data = http_response.json()
    # The orchestrator returns the (possibly newly created) session via a
    # response header. Fix: keep it in its own variable instead of clobbering
    # the caller-supplied `session_id` parameter as the original did.
    returned_session_id = http_response.headers.get("X-Session-Id")

    # Parse and return the response.
    completion = ChatCompletionResponse.model_validate(response_data)
    completion.session_id = returned_session_id
    if return_type == "json":
        return completion.model_dump(mode="json")
    return completion
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""LangGraph integration for SequrityClient.
|
|
2
|
+
|
|
3
|
+
This module provides functionality to execute LangGraph StateGraphs securely
|
|
4
|
+
through the Sequrity orchestrator.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from typing import Callable
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from langgraph.graph import END, START, StateGraph
|
|
12
|
+
|
|
13
|
+
LANGGRAPH_AVAILABLE = True
|
|
14
|
+
except ImportError:
|
|
15
|
+
LANGGRAPH_AVAILABLE = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LangGraphExecutor:
|
|
19
|
+
"""Executor for running LangGraph StateGraphs through Sequrity.
|
|
20
|
+
|
|
21
|
+
This class handles:
|
|
22
|
+
1. Converting LangGraph to executable Python code
|
|
23
|
+
2. Mapping nodes to tools (external) or internal functions
|
|
24
|
+
3. Executing code via secure orchestrator
|
|
25
|
+
4. Handling tool calls for external nodes
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
graph: "StateGraph",
|
|
31
|
+
node_functions: dict[str, Callable] | None = None,
|
|
32
|
+
internal_node_mapping: dict[str, str] | None = None,
|
|
33
|
+
):
|
|
34
|
+
"""Initialize the executor.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
graph: LangGraph StateGraph to execute.
|
|
38
|
+
node_functions: Dict mapping node names to their functions.
|
|
39
|
+
internal_node_mapping: Map of node names to internal tool names
|
|
40
|
+
(e.g., {"agent": "parse_with_ai"}).
|
|
41
|
+
|
|
42
|
+
Raises:
|
|
43
|
+
RuntimeError: If LangGraph is not installed.
|
|
44
|
+
"""
|
|
45
|
+
if not LANGGRAPH_AVAILABLE:
|
|
46
|
+
raise RuntimeError("LangGraph is not installed. Install it with: pip install langgraph")
|
|
47
|
+
|
|
48
|
+
self.graph = graph
|
|
49
|
+
|
|
50
|
+
# Extract node functions if not provided
|
|
51
|
+
if node_functions is None:
|
|
52
|
+
self.node_functions = self._extract_function_map(graph)
|
|
53
|
+
else:
|
|
54
|
+
self.node_functions = node_functions
|
|
55
|
+
|
|
56
|
+
# Determine which nodes are internal vs external
|
|
57
|
+
self.internal_node_mapping = internal_node_mapping or {}
|
|
58
|
+
self.external_nodes: set[str] = set()
|
|
59
|
+
for node_name in self.node_functions.keys():
|
|
60
|
+
if node_name not in self.internal_node_mapping:
|
|
61
|
+
self.external_nodes.add(node_name)
|
|
62
|
+
|
|
63
|
+
# Generate code once at initialization
|
|
64
|
+
self.generated_code = self._graph_to_code(graph, self.node_functions)
|
|
65
|
+
|
|
66
|
+
def _extract_function_map(self, graph: "StateGraph") -> dict[str, Callable]:
|
|
67
|
+
"""Extract node functions from LangGraph StateGraph.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
graph: The StateGraph to extract functions from.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Dict mapping node names to their callable functions.
|
|
74
|
+
"""
|
|
75
|
+
function_map = {}
|
|
76
|
+
|
|
77
|
+
# Access internal graph structure to get nodes
|
|
78
|
+
if hasattr(graph, "nodes"):
|
|
79
|
+
for node_name, node_data in graph.nodes.items():
|
|
80
|
+
if hasattr(node_data, "func"):
|
|
81
|
+
function_map[node_name] = node_data.func
|
|
82
|
+
elif callable(node_data):
|
|
83
|
+
function_map[node_name] = node_data
|
|
84
|
+
|
|
85
|
+
return function_map
|
|
86
|
+
|
|
87
|
+
def build_tool_definitions(self) -> list[dict]:
|
|
88
|
+
"""Build OpenAI-style tool definitions for external nodes.
|
|
89
|
+
|
|
90
|
+
Each external node becomes a tool that the orchestrator can call.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
List of tool definition dicts in OpenAI function calling format.
|
|
94
|
+
"""
|
|
95
|
+
tools = []
|
|
96
|
+
|
|
97
|
+
for node_name in self.external_nodes:
|
|
98
|
+
func = self.node_functions.get(node_name)
|
|
99
|
+
|
|
100
|
+
# Get description from function if available
|
|
101
|
+
if func:
|
|
102
|
+
description = getattr(func, "__doc__", f"Execute node: {node_name}")
|
|
103
|
+
else:
|
|
104
|
+
description = f"Execute node: {node_name}"
|
|
105
|
+
|
|
106
|
+
# Build tool schema
|
|
107
|
+
tool_def = {
|
|
108
|
+
"type": "function",
|
|
109
|
+
"function": {
|
|
110
|
+
"name": node_name,
|
|
111
|
+
"description": description,
|
|
112
|
+
"parameters": {
|
|
113
|
+
"type": "object",
|
|
114
|
+
"properties": {"state": {"type": "object", "description": "Current state dict"}},
|
|
115
|
+
"required": ["state"],
|
|
116
|
+
},
|
|
117
|
+
},
|
|
118
|
+
}
|
|
119
|
+
tools.append(tool_def)
|
|
120
|
+
|
|
121
|
+
return tools
|
|
122
|
+
|
|
123
|
+
def execute_tool_call(self, tool_call: dict) -> dict:
|
|
124
|
+
"""Execute a tool call (external node).
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
tool_call: Tool call dict with id, name, and arguments.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
The result dict from executing the node function.
|
|
131
|
+
|
|
132
|
+
Raises:
|
|
133
|
+
RuntimeError: If no function is found for the specified tool name.
|
|
134
|
+
"""
|
|
135
|
+
tool_name = tool_call.get("function", {}).get("name")
|
|
136
|
+
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
|
|
137
|
+
|
|
138
|
+
# Parse arguments
|
|
139
|
+
try:
|
|
140
|
+
arguments = json.loads(arguments_str)
|
|
141
|
+
except json.JSONDecodeError:
|
|
142
|
+
arguments = {}
|
|
143
|
+
|
|
144
|
+
# Get the node function
|
|
145
|
+
node_func = self.node_functions.get(tool_name)
|
|
146
|
+
if not node_func:
|
|
147
|
+
raise RuntimeError(f"No function found for external node: {tool_name}")
|
|
148
|
+
|
|
149
|
+
# Execute the node function
|
|
150
|
+
state = arguments.get("state", {})
|
|
151
|
+
result = node_func(state)
|
|
152
|
+
|
|
153
|
+
return result
|
|
154
|
+
|
|
155
|
+
def _graph_to_code(self, graph: "StateGraph", function_map: dict[str, Callable]) -> str:
|
|
156
|
+
"""Convert a LangGraph StateGraph into executable Python code.
|
|
157
|
+
|
|
158
|
+
Generates:
|
|
159
|
+
- Module-level code with state = initial_state
|
|
160
|
+
- Linear flow with if-else for conditional routing
|
|
161
|
+
- Uses keyword arguments for function calls
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
graph: The StateGraph to convert.
|
|
165
|
+
function_map: Dict mapping node names to their callable functions.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
Generated Python code as a string.
|
|
169
|
+
"""
|
|
170
|
+
nodes = graph.nodes
|
|
171
|
+
edges = list(graph.edges)
|
|
172
|
+
branches = dict(graph.branches)
|
|
173
|
+
|
|
174
|
+
code_lines = []
|
|
175
|
+
code_lines.append("# Module-level code - assumes initial_state is predefined")
|
|
176
|
+
code_lines.append("state = initial_state")
|
|
177
|
+
code_lines.append("")
|
|
178
|
+
|
|
179
|
+
# Find starting node
|
|
180
|
+
start_edges = [target for source, target in edges if source == START]
|
|
181
|
+
if not start_edges:
|
|
182
|
+
code_lines.append("# Extract final result for user")
|
|
183
|
+
code_lines.append("final_return_value = state.get('result', state)")
|
|
184
|
+
return "\n".join(code_lines)
|
|
185
|
+
|
|
186
|
+
# Generate code for each node
|
|
187
|
+
visited = set()
|
|
188
|
+
self._generate_node_code(start_edges[0], nodes, edges, branches, function_map, code_lines, visited, indent=0)
|
|
189
|
+
|
|
190
|
+
code_lines.append("")
|
|
191
|
+
code_lines.append("# Extract final result for user")
|
|
192
|
+
code_lines.append("final_return_value = state.get('result', state)")
|
|
193
|
+
|
|
194
|
+
return "\n".join(code_lines)
|
|
195
|
+
|
|
196
|
+
def _generate_node_code(self, node_name, nodes, edges, branches, function_map, code_lines, visited, indent=0):
|
|
197
|
+
"""Recursively generate code for a node and its successors.
|
|
198
|
+
|
|
199
|
+
Args:
|
|
200
|
+
node_name: Name of the current node to generate code for.
|
|
201
|
+
nodes: Dict of all nodes in the graph.
|
|
202
|
+
edges: List of edge tuples (source, target).
|
|
203
|
+
branches: Dict of conditional branch specifications.
|
|
204
|
+
function_map: Dict mapping node names to their callable functions.
|
|
205
|
+
code_lines: List to append generated code lines to.
|
|
206
|
+
visited: Set of already visited node names.
|
|
207
|
+
indent: Current indentation level.
|
|
208
|
+
"""
|
|
209
|
+
if node_name in visited or node_name == END or node_name == "__end__":
|
|
210
|
+
return
|
|
211
|
+
|
|
212
|
+
visited.add(node_name)
|
|
213
|
+
indent_str = " " * indent
|
|
214
|
+
|
|
215
|
+
# Get function name
|
|
216
|
+
node_spec = nodes.get(node_name)
|
|
217
|
+
if not node_spec:
|
|
218
|
+
return
|
|
219
|
+
|
|
220
|
+
node_func = node_spec.runnable
|
|
221
|
+
if function_map and node_name in function_map:
|
|
222
|
+
func_name = function_map[node_name].__name__
|
|
223
|
+
else:
|
|
224
|
+
func_name = getattr(node_func, "__name__", getattr(node_func, "name", str(node_name)))
|
|
225
|
+
|
|
226
|
+
# Generate node execution code with keyword arguments
|
|
227
|
+
code_lines.append(f"{indent_str}# Node: {node_name}")
|
|
228
|
+
code_lines.append(f"{indent_str}{node_name}_result = {func_name}(state=state)")
|
|
229
|
+
code_lines.append(f"{indent_str}# Update state")
|
|
230
|
+
code_lines.append(f"{indent_str}for key, value in {node_name}_result.items():")
|
|
231
|
+
code_lines.append(f"{indent_str} if key in state:")
|
|
232
|
+
code_lines.append(f"{indent_str} if isinstance(value, list) and isinstance(state[key], list):")
|
|
233
|
+
code_lines.append(f"{indent_str} state[key].extend(value)")
|
|
234
|
+
code_lines.append(f"{indent_str} else:")
|
|
235
|
+
code_lines.append(f"{indent_str} state[key] = value")
|
|
236
|
+
code_lines.append(f"{indent_str} else:")
|
|
237
|
+
code_lines.append(f"{indent_str} state[key] = value")
|
|
238
|
+
code_lines.append("")
|
|
239
|
+
|
|
240
|
+
# Check for conditional edges
|
|
241
|
+
if node_name in branches and branches[node_name]:
|
|
242
|
+
# Has conditional routing
|
|
243
|
+
branch_name = list(branches[node_name].keys())[0]
|
|
244
|
+
branch_spec = branches[node_name][branch_name]
|
|
245
|
+
|
|
246
|
+
# Use the branch name as the condition function name
|
|
247
|
+
condition_name = branch_name
|
|
248
|
+
|
|
249
|
+
# Get possible next nodes
|
|
250
|
+
possible_next = set()
|
|
251
|
+
if hasattr(branch_spec, "ends"):
|
|
252
|
+
ends = branch_spec.ends
|
|
253
|
+
possible_next.update(ends.values() if isinstance(ends, dict) else ends)
|
|
254
|
+
|
|
255
|
+
# Also check edges
|
|
256
|
+
for source, target in edges:
|
|
257
|
+
if source == node_name and target not in (END, "__end__"):
|
|
258
|
+
possible_next.add(target)
|
|
259
|
+
|
|
260
|
+
code_lines.append(f"{indent_str}# Conditional routing")
|
|
261
|
+
code_lines.append(f"{indent_str}next_node = {condition_name}(state=state)")
|
|
262
|
+
code_lines.append("")
|
|
263
|
+
|
|
264
|
+
# Generate if-else for each possible path
|
|
265
|
+
if possible_next:
|
|
266
|
+
first = True
|
|
267
|
+
for next_node in sorted(possible_next):
|
|
268
|
+
if_keyword = "if" if first else "elif"
|
|
269
|
+
first = False
|
|
270
|
+
|
|
271
|
+
code_lines.append(f"{indent_str}{if_keyword} next_node == '{next_node}':")
|
|
272
|
+
branch_visited = visited.copy()
|
|
273
|
+
self._generate_node_code(
|
|
274
|
+
next_node, nodes, edges, branches, function_map, code_lines, branch_visited, indent + 1
|
|
275
|
+
)
|
|
276
|
+
else:
|
|
277
|
+
# Regular edge - no condition
|
|
278
|
+
next_nodes = [target for source, target in edges if source == node_name]
|
|
279
|
+
if next_nodes and next_nodes[0] != END and next_nodes[0] != "__end__":
|
|
280
|
+
self._generate_node_code(
|
|
281
|
+
next_nodes[0], nodes, edges, branches, function_map, code_lines, visited, indent
|
|
282
|
+
)
|