lionagi 0.14.4__py3-none-any.whl → 0.14.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/fields/instruct.py +3 -17
- lionagi/libs/concurrency/__init__.py +25 -1
- lionagi/libs/concurrency/cancel.py +1 -1
- lionagi/libs/concurrency/patterns.py +145 -138
- lionagi/libs/concurrency/primitives.py +145 -97
- lionagi/libs/concurrency/resource_tracker.py +182 -0
- lionagi/libs/concurrency/task.py +4 -2
- lionagi/operations/builder.py +9 -0
- lionagi/operations/flow.py +163 -60
- lionagi/protocols/generic/pile.py +7 -10
- lionagi/protocols/generic/processor.py +53 -26
- lionagi/service/connections/providers/_claude_code/__init__.py +3 -0
- lionagi/service/connections/providers/_claude_code/models.py +235 -0
- lionagi/service/connections/providers/_claude_code/stream_cli.py +350 -0
- lionagi/service/connections/providers/claude_code_.py +13 -223
- lionagi/service/connections/providers/claude_code_cli.py +38 -343
- lionagi/service/rate_limited_processor.py +53 -35
- lionagi/session/branch.py +6 -51
- lionagi/session/session.py +26 -8
- lionagi/utils.py +56 -174
- lionagi/version.py +1 -1
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/METADATA +6 -2
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/RECORD +25 -21
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/WHEEL +0 -0
- {lionagi-0.14.4.dist-info → lionagi-0.14.6.dist-info}/licenses/LICENSE +0 -0
@@ -6,238 +6,28 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
import json
|
8
8
|
import warnings
|
9
|
-
from pathlib import Path
|
10
|
-
from typing import Any, Literal
|
11
9
|
|
12
|
-
from pydantic import BaseModel
|
10
|
+
from pydantic import BaseModel
|
13
11
|
|
14
12
|
from lionagi.libs.schema.as_readable import as_readable
|
15
13
|
from lionagi.service.connections.endpoint import Endpoint
|
16
14
|
from lionagi.service.connections.endpoint_config import EndpointConfig
|
17
15
|
from lionagi.utils import is_import_installed, to_dict, to_list
|
18
16
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
ClaudePermission
|
23
|
-
|
24
|
-
"acceptEdits",
|
25
|
-
"bypassPermissions",
|
26
|
-
"dangerously-skip-permissions",
|
27
|
-
]
|
28
|
-
|
29
|
-
CLAUDE_CODE_OPTION_PARAMS = {
|
30
|
-
"allowed_tools",
|
31
|
-
"max_thinking_tokens",
|
32
|
-
"mcp_tools",
|
33
|
-
"mcp_servers",
|
34
|
-
"permission_mode",
|
35
|
-
"continue_conversation",
|
36
|
-
"resume",
|
37
|
-
"max_turns",
|
38
|
-
"disallowed_tools",
|
39
|
-
"model",
|
40
|
-
"permission_prompt_tool_name",
|
41
|
-
"cwd",
|
42
|
-
"system_prompt",
|
43
|
-
"append_system_prompt",
|
44
|
-
}
|
45
|
-
|
46
|
-
|
47
|
-
# --------------------------------------------------------------------------- request model
|
48
|
-
class ClaudeCodeRequest(BaseModel):
|
49
|
-
# -- conversational bits -------------------------------------------------
|
50
|
-
prompt: str = Field(description="The prompt for Claude Code")
|
51
|
-
system_prompt: str | None = None
|
52
|
-
append_system_prompt: str | None = None
|
53
|
-
max_turns: int | None = None
|
54
|
-
continue_conversation: bool = False
|
55
|
-
resume: str | None = None
|
56
|
-
|
57
|
-
# -- repo / workspace ----------------------------------------------------
|
58
|
-
repo: Path = Field(default_factory=Path.cwd, exclude=True)
|
59
|
-
ws: str | None = None # sub-directory under repo
|
60
|
-
add_dir: str | None = None # extra read-only mount
|
61
|
-
allowed_tools: list[str] | None = None
|
62
|
-
|
63
|
-
# -- runtime & safety ----------------------------------------------------
|
64
|
-
model: Literal["sonnet", "opus"] | str | None = "sonnet"
|
65
|
-
max_thinking_tokens: int | None = None
|
66
|
-
mcp_tools: list[str] = Field(default_factory=list)
|
67
|
-
mcp_servers: dict[str, Any] = Field(default_factory=dict)
|
68
|
-
permission_mode: ClaudePermission | None = None
|
69
|
-
permission_prompt_tool_name: str | None = None
|
70
|
-
disallowed_tools: list[str] = Field(default_factory=list)
|
71
|
-
|
72
|
-
# -- internal use --------------------------------------------------------
|
73
|
-
auto_finish: bool = Field(
|
74
|
-
default=True,
|
75
|
-
exclude=True,
|
76
|
-
description="Automatically finish the conversation after the first response",
|
77
|
-
)
|
78
|
-
verbose_output: bool = Field(default=False, exclude=True)
|
79
|
-
cli_display_theme: Literal["light", "dark"] = "light"
|
80
|
-
|
81
|
-
# ------------------------ validators & helpers --------------------------
|
82
|
-
@field_validator("permission_mode", mode="before")
|
83
|
-
def _norm_perm(cls, v):
|
84
|
-
if v in {
|
85
|
-
"dangerously-skip-permissions",
|
86
|
-
"--dangerously-skip-permissions",
|
87
|
-
}:
|
88
|
-
return "bypassPermissions"
|
89
|
-
return v
|
90
|
-
|
91
|
-
# Workspace path derived from repo + ws
|
92
|
-
def cwd(self) -> Path:
|
93
|
-
if not self.ws:
|
94
|
-
return self.repo
|
95
|
-
|
96
|
-
# Convert to Path object for proper validation
|
97
|
-
ws_path = Path(self.ws)
|
98
|
-
|
99
|
-
# Check for absolute paths or directory traversal attempts
|
100
|
-
if ws_path.is_absolute():
|
101
|
-
raise ValueError(
|
102
|
-
f"Workspace path must be relative, got absolute: {self.ws}"
|
103
|
-
)
|
104
|
-
|
105
|
-
if ".." in ws_path.parts:
|
106
|
-
raise ValueError(
|
107
|
-
f"Directory traversal detected in workspace path: {self.ws}"
|
108
|
-
)
|
109
|
-
|
110
|
-
# Resolve paths to handle symlinks and normalize
|
111
|
-
repo_resolved = self.repo.resolve()
|
112
|
-
result = (self.repo / ws_path).resolve()
|
113
|
-
|
114
|
-
# Ensure the resolved path is within the repository bounds
|
115
|
-
try:
|
116
|
-
result.relative_to(repo_resolved)
|
117
|
-
except ValueError:
|
118
|
-
raise ValueError(
|
119
|
-
f"Workspace path escapes repository bounds. "
|
120
|
-
f"Repository: {repo_resolved}, Workspace: {result}"
|
121
|
-
)
|
122
|
-
|
123
|
-
return result
|
124
|
-
|
125
|
-
@model_validator(mode="after")
|
126
|
-
def _check_perm_workspace(self):
|
127
|
-
if self.permission_mode == "bypassPermissions":
|
128
|
-
# Use secure path validation with resolved paths
|
129
|
-
repo_resolved = self.repo.resolve()
|
130
|
-
cwd_resolved = self.cwd().resolve()
|
131
|
-
|
132
|
-
# Check if cwd is within repo bounds using proper path methods
|
133
|
-
try:
|
134
|
-
cwd_resolved.relative_to(repo_resolved)
|
135
|
-
except ValueError:
|
136
|
-
raise ValueError(
|
137
|
-
f"With bypassPermissions, workspace must be within repository bounds. "
|
138
|
-
f"Repository: {repo_resolved}, Workspace: {cwd_resolved}"
|
139
|
-
)
|
140
|
-
return self
|
141
|
-
|
142
|
-
# ------------------------ CLI helpers -----------------------------------
|
143
|
-
def as_cmd_args(self) -> list[str]:
|
144
|
-
"""Build argument list for the *Node* `claude` CLI."""
|
145
|
-
args: list[str] = ["-p", self.prompt, "--output-format", "stream-json"]
|
146
|
-
if self.allowed_tools:
|
147
|
-
args.append("--allowedTools")
|
148
|
-
for tool in self.allowed_tools:
|
149
|
-
args.append(f'"{tool}"')
|
150
|
-
|
151
|
-
if self.disallowed_tools:
|
152
|
-
args.append("--disallowedTools")
|
153
|
-
for tool in self.disallowed_tools:
|
154
|
-
args.append(f'"{tool}"')
|
155
|
-
|
156
|
-
if self.resume:
|
157
|
-
args += ["--resume", self.resume]
|
158
|
-
elif self.continue_conversation:
|
159
|
-
args.append("--continue")
|
160
|
-
|
161
|
-
if self.max_turns:
|
162
|
-
# +1 because CLI counts *pairs*
|
163
|
-
args += ["--max-turns", str(self.max_turns + 1)]
|
164
|
-
|
165
|
-
if self.permission_mode == "bypassPermissions":
|
166
|
-
args += ["--dangerously-skip-permissions"]
|
167
|
-
|
168
|
-
if self.add_dir:
|
169
|
-
args += ["--add-dir", self.add_dir]
|
170
|
-
|
171
|
-
args += ["--model", self.model or "sonnet", "--verbose"]
|
172
|
-
return args
|
173
|
-
|
174
|
-
# ------------------------ SDK helpers -----------------------------------
|
175
|
-
def as_claude_options(self):
|
176
|
-
from claude_code_sdk import ClaudeCodeOptions
|
177
|
-
|
178
|
-
data = {
|
179
|
-
k: v
|
180
|
-
for k, v in self.model_dump(exclude_none=True).items()
|
181
|
-
if k in CLAUDE_CODE_OPTION_PARAMS
|
182
|
-
}
|
183
|
-
return ClaudeCodeOptions(**data)
|
184
|
-
|
185
|
-
# ------------------------ convenience constructor -----------------------
|
186
|
-
@classmethod
|
187
|
-
def create(
|
188
|
-
cls,
|
189
|
-
messages: list[dict[str, Any]],
|
190
|
-
resume: str | None = None,
|
191
|
-
continue_conversation: bool | None = None,
|
192
|
-
**kwargs,
|
193
|
-
):
|
194
|
-
if not messages:
|
195
|
-
raise ValueError("messages may not be empty")
|
196
|
-
|
197
|
-
prompt = ""
|
198
|
-
|
199
|
-
# 1. if resume or continue_conversation, use the last message
|
200
|
-
if resume or continue_conversation:
|
201
|
-
continue_conversation = True
|
202
|
-
prompt = messages[-1]["content"]
|
203
|
-
if isinstance(prompt, (dict, list)):
|
204
|
-
prompt = json.dumps(prompt)
|
205
|
-
|
206
|
-
# 2. else, use entire messages except system message
|
207
|
-
else:
|
208
|
-
prompts = []
|
209
|
-
continue_conversation = False
|
210
|
-
for message in messages:
|
211
|
-
if message["role"] != "system":
|
212
|
-
content = message["content"]
|
213
|
-
prompts.append(
|
214
|
-
json.dumps(content)
|
215
|
-
if isinstance(content, (dict, list))
|
216
|
-
else content
|
217
|
-
)
|
218
|
-
|
219
|
-
prompt = "\n".join(prompts)
|
220
|
-
|
221
|
-
# 3. assemble the request data
|
222
|
-
data: dict[str, Any] = dict(
|
223
|
-
prompt=prompt,
|
224
|
-
resume=resume,
|
225
|
-
continue_conversation=bool(continue_conversation),
|
226
|
-
)
|
227
|
-
|
228
|
-
# 4. extract system prompt if available
|
229
|
-
if (messages[0]["role"] == "system") and (
|
230
|
-
resume or continue_conversation
|
231
|
-
):
|
232
|
-
data["system_prompt"] = messages[0]["content"]
|
233
|
-
if kwargs.get("append_system_prompt"):
|
234
|
-
data["append_system_prompt"] = str(
|
235
|
-
kwargs.get("append_system_prompt")
|
236
|
-
)
|
17
|
+
from ._claude_code.models import (
|
18
|
+
CLAUDE_CODE_OPTION_PARAMS,
|
19
|
+
ClaudeCodeRequest,
|
20
|
+
ClaudePermission,
|
21
|
+
)
|
237
22
|
|
238
|
-
|
239
|
-
|
23
|
+
__all__ = (
|
24
|
+
"ClaudeCodeRequest",
|
25
|
+
"CLAUDE_CODE_OPTION_PARAMS", # backward compatibility
|
26
|
+
"ClaudePermission", # backward compatibility
|
27
|
+
"ClaudeCodeEndpoint",
|
28
|
+
)
|
240
29
|
|
30
|
+
HAS_CLAUDE_CODE_SDK = is_import_installed("claude_code_sdk")
|
241
31
|
|
242
32
|
# --------------------------------------------------------------------------- SDK endpoint
|
243
33
|
ENDPOINT_CONFIG = EndpointConfig(
|
@@ -1,352 +1,23 @@
|
|
1
|
+
# Copyright (c) 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
1
5
|
from __future__ import annotations
|
2
6
|
|
3
|
-
import
|
4
|
-
import codecs
|
5
|
-
import contextlib
|
6
|
-
import dataclasses
|
7
|
-
import json
|
8
|
-
import logging
|
9
|
-
import shutil
|
10
|
-
from collections.abc import AsyncIterator, Callable
|
11
|
-
from datetime import datetime
|
12
|
-
from functools import partial
|
13
|
-
from textwrap import shorten
|
14
|
-
from typing import Any
|
7
|
+
from collections.abc import AsyncIterator
|
15
8
|
|
16
|
-
from json_repair import repair_json
|
17
9
|
from pydantic import BaseModel
|
18
10
|
|
19
|
-
from lionagi.libs.schema.as_readable import as_readable
|
20
11
|
from lionagi.service.connections.endpoint import Endpoint, EndpointConfig
|
21
12
|
from lionagi.utils import to_dict
|
22
13
|
|
23
|
-
from .
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
logging.basicConfig(level=logging.INFO)
|
31
|
-
log = logging.getLogger("claude-cli")
|
32
|
-
|
33
|
-
|
34
|
-
@dataclasses.dataclass
|
35
|
-
class ClaudeChunk:
|
36
|
-
"""Low-level wrapper around every NDJSON object coming from the CLI."""
|
37
|
-
|
38
|
-
raw: dict[str, Any]
|
39
|
-
type: str
|
40
|
-
# convenience views
|
41
|
-
thinking: str | None = None
|
42
|
-
text: str | None = None
|
43
|
-
tool_use: dict[str, Any] | None = None
|
44
|
-
tool_result: dict[str, Any] | None = None
|
45
|
-
|
46
|
-
|
47
|
-
@dataclasses.dataclass
|
48
|
-
class ClaudeSession:
|
49
|
-
"""Aggregated view of a whole CLI conversation."""
|
50
|
-
|
51
|
-
session_id: str | None = None
|
52
|
-
model: str | None = None
|
53
|
-
|
54
|
-
# chronological log
|
55
|
-
chunks: list[ClaudeChunk] = dataclasses.field(default_factory=list)
|
56
|
-
|
57
|
-
# materialised views
|
58
|
-
thinking_log: list[str] = dataclasses.field(default_factory=list)
|
59
|
-
messages: list[dict[str, Any]] = dataclasses.field(default_factory=list)
|
60
|
-
tool_uses: list[dict[str, Any]] = dataclasses.field(default_factory=list)
|
61
|
-
tool_results: list[dict[str, Any]] = dataclasses.field(
|
62
|
-
default_factory=list
|
63
|
-
)
|
64
|
-
|
65
|
-
# final summary
|
66
|
-
result: str = ""
|
67
|
-
usage: dict[str, Any] = dataclasses.field(default_factory=dict)
|
68
|
-
total_cost_usd: float | None = None
|
69
|
-
num_turns: int | None = None
|
70
|
-
duration_ms: int | None = None
|
71
|
-
duration_api_ms: int | None = None
|
72
|
-
is_error: bool = False
|
73
|
-
|
74
|
-
|
75
|
-
# --------------------------------------------------------------------------- helpers
|
76
|
-
|
77
|
-
|
78
|
-
async def ndjson_from_cli(request: ClaudeCodeRequest):
|
79
|
-
"""
|
80
|
-
Yields each JSON object emitted by the *claude-code* CLI.
|
81
|
-
|
82
|
-
• Robust against UTF‑8 splits across chunks (incremental decoder).
|
83
|
-
• Robust against braces inside strings (uses json.JSONDecoder.raw_decode)
|
84
|
-
• Falls back to `json_repair.repair_json` when necessary.
|
85
|
-
"""
|
86
|
-
workspace = request.cwd()
|
87
|
-
workspace.mkdir(parents=True, exist_ok=True)
|
88
|
-
|
89
|
-
proc = await asyncio.create_subprocess_exec(
|
90
|
-
CLAUDE,
|
91
|
-
*request.as_cmd_args(),
|
92
|
-
cwd=str(workspace),
|
93
|
-
stdout=asyncio.subprocess.PIPE,
|
94
|
-
stderr=asyncio.subprocess.PIPE,
|
95
|
-
)
|
96
|
-
|
97
|
-
decoder = codecs.getincrementaldecoder("utf-8")()
|
98
|
-
json_decoder = json.JSONDecoder()
|
99
|
-
buffer: str = "" # text buffer that may hold >1 JSON objects
|
100
|
-
|
101
|
-
try:
|
102
|
-
while True:
|
103
|
-
chunk = await proc.stdout.read(4096)
|
104
|
-
if not chunk:
|
105
|
-
break
|
106
|
-
|
107
|
-
# 1) decode *incrementally* so we never split multibyte chars
|
108
|
-
buffer += decoder.decode(chunk)
|
109
|
-
|
110
|
-
# 2) try to peel off as many complete JSON objs as possible
|
111
|
-
while buffer:
|
112
|
-
buffer = buffer.lstrip() # remove leading spaces/newlines
|
113
|
-
if not buffer:
|
114
|
-
break
|
115
|
-
try:
|
116
|
-
obj, idx = json_decoder.raw_decode(buffer)
|
117
|
-
yield obj
|
118
|
-
buffer = buffer[idx:] # keep remainder for next round
|
119
|
-
except json.JSONDecodeError:
|
120
|
-
# incomplete → need more bytes
|
121
|
-
break
|
122
|
-
|
123
|
-
# 3) flush any tail bytes in the incremental decoder
|
124
|
-
buffer += decoder.decode(b"", final=True)
|
125
|
-
buffer = buffer.strip()
|
126
|
-
if buffer:
|
127
|
-
try:
|
128
|
-
obj, idx = json_decoder.raw_decode(buffer)
|
129
|
-
yield obj
|
130
|
-
except json.JSONDecodeError:
|
131
|
-
try:
|
132
|
-
fixed = repair_json(buffer)
|
133
|
-
yield json.loads(fixed)
|
134
|
-
log.warning(
|
135
|
-
"Repaired malformed JSON fragment at stream end"
|
136
|
-
)
|
137
|
-
except Exception:
|
138
|
-
log.error(
|
139
|
-
"Skipped unrecoverable JSON tail: %.120s…", buffer
|
140
|
-
)
|
141
|
-
|
142
|
-
# 4) propagate non‑zero exit code
|
143
|
-
if await proc.wait() != 0:
|
144
|
-
err = (await proc.stderr.read()).decode().strip()
|
145
|
-
raise RuntimeError(err or "CLI exited non‑zero")
|
146
|
-
|
147
|
-
finally:
|
148
|
-
with contextlib.suppress(ProcessLookupError):
|
149
|
-
proc.terminate()
|
150
|
-
await proc.wait()
|
151
|
-
|
152
|
-
|
153
|
-
# --------------------------------------------------------------------------- SSE route
|
154
|
-
async def stream_events(request: ClaudeCodeRequest):
|
155
|
-
async for obj in ndjson_from_cli(request):
|
156
|
-
yield obj
|
157
|
-
yield {"type": "done"}
|
158
|
-
|
159
|
-
|
160
|
-
print_readable = partial(as_readable, md=True, display_str=True)
|
161
|
-
|
162
|
-
|
163
|
-
def _pp_system(sys_obj: dict[str, Any], theme) -> None:
|
164
|
-
txt = (
|
165
|
-
f"◼️ **Claude Code Session** \n"
|
166
|
-
f"- id: `{sys_obj.get('session_id', '?')}` \n"
|
167
|
-
f"- model: `{sys_obj.get('model', '?')}` \n"
|
168
|
-
f"- tools: {', '.join(sys_obj.get('tools', [])[:8])}"
|
169
|
-
+ ("…" if len(sys_obj.get("tools", [])) > 8 else "")
|
170
|
-
)
|
171
|
-
print_readable(txt, border=False, theme=theme)
|
172
|
-
|
173
|
-
|
174
|
-
def _pp_thinking(thought: str, theme) -> None:
|
175
|
-
text = f"""
|
176
|
-
🧠 Thinking:
|
177
|
-
{thought}
|
178
|
-
"""
|
179
|
-
print_readable(text, border=True, theme=theme)
|
180
|
-
|
181
|
-
|
182
|
-
def _pp_assistant_text(text: str, theme) -> None:
|
183
|
-
txt = f"""
|
184
|
-
> 🗣️ Claude:
|
185
|
-
{text}
|
186
|
-
"""
|
187
|
-
print_readable(txt, theme=theme)
|
188
|
-
|
189
|
-
|
190
|
-
def _pp_tool_use(tu: dict[str, Any], theme) -> None:
|
191
|
-
preview = shorten(str(tu["input"]).replace("\n", " "), 130)
|
192
|
-
body = f"- 🔧 Tool Use — {tu['name']}({tu['id']}) - input: {preview}"
|
193
|
-
print_readable(body, border=False, panel=False, theme=theme)
|
194
|
-
|
195
|
-
|
196
|
-
def _pp_tool_result(tr: dict[str, Any], theme) -> None:
|
197
|
-
body_preview = shorten(str(tr["content"]).replace("\n", " "), 130)
|
198
|
-
status = "ERR" if tr.get("is_error") else "OK"
|
199
|
-
body = f"- 📄 Tool Result({tr['tool_use_id']}) - {status}\n\n\tcontent: {body_preview}"
|
200
|
-
print_readable(body, border=False, panel=False, theme=theme)
|
201
|
-
|
202
|
-
|
203
|
-
def _pp_final(sess: ClaudeSession, theme) -> None:
|
204
|
-
usage = sess.usage or {}
|
205
|
-
txt = (
|
206
|
-
f"### ✅ Session complete - {datetime.utcnow().isoformat(timespec='seconds')} UTC\n"
|
207
|
-
f"**Result:**\n\n{sess.result or ''}\n\n"
|
208
|
-
f"- cost: **${sess.total_cost_usd:.4f}** \n"
|
209
|
-
f"- turns: **{sess.num_turns}** \n"
|
210
|
-
f"- duration: **{sess.duration_ms} ms** (API {sess.duration_api_ms} ms) \n"
|
211
|
-
f"- tokens in/out: {usage.get('input_tokens', 0)}/{usage.get('output_tokens', 0)}"
|
212
|
-
)
|
213
|
-
print_readable(txt, theme=theme)
|
214
|
-
|
215
|
-
|
216
|
-
# --------------------------------------------------------------------------- internal utils
|
217
|
-
|
218
|
-
|
219
|
-
async def _maybe_await(func, *args, **kw):
|
220
|
-
"""Call func which may be sync or async."""
|
221
|
-
res = func(*args, **kw) if func else None
|
222
|
-
if asyncio.iscoroutine(res):
|
223
|
-
await res
|
224
|
-
|
225
|
-
|
226
|
-
# --------------------------------------------------------------------------- main parser
|
227
|
-
|
228
|
-
|
229
|
-
async def stream_claude_code_cli( # noqa: C901 (complexity from branching is fine here)
|
230
|
-
request: ClaudeCodeRequest,
|
231
|
-
session: ClaudeSession = ClaudeSession(),
|
232
|
-
*,
|
233
|
-
on_system: Callable[[dict[str, Any]], None] | None = None,
|
234
|
-
on_thinking: Callable[[str], None] | None = None,
|
235
|
-
on_text: Callable[[str], None] | None = None,
|
236
|
-
on_tool_use: Callable[[dict[str, Any]], None] | None = None,
|
237
|
-
on_tool_result: Callable[[dict[str, Any]], None] | None = None,
|
238
|
-
on_final: Callable[[ClaudeSession], None] | None = None,
|
239
|
-
) -> AsyncIterator[ClaudeChunk | dict | ClaudeSession]:
|
240
|
-
"""
|
241
|
-
Consume the ND‑JSON stream produced by ndjson_from_cli()
|
242
|
-
and return a fully‑populated ClaudeSession.
|
243
|
-
|
244
|
-
If callbacks are omitted a default pretty‑print is emitted.
|
245
|
-
"""
|
246
|
-
stream = ndjson_from_cli(request)
|
247
|
-
theme = request.cli_display_theme or "light"
|
248
|
-
|
249
|
-
async for obj in stream:
|
250
|
-
typ = obj.get("type", "unknown")
|
251
|
-
chunk = ClaudeChunk(raw=obj, type=typ)
|
252
|
-
session.chunks.append(chunk)
|
253
|
-
|
254
|
-
# ------------------------ SYSTEM -----------------------------------
|
255
|
-
if typ == "system":
|
256
|
-
data = obj
|
257
|
-
session.session_id = data.get("session_id", session.session_id)
|
258
|
-
session.model = data.get("model", session.model)
|
259
|
-
await _maybe_await(on_system, data)
|
260
|
-
if request.verbose_output and on_system is None:
|
261
|
-
_pp_system(data, theme)
|
262
|
-
yield data
|
263
|
-
|
264
|
-
# ------------------------ ASSISTANT --------------------------------
|
265
|
-
elif typ == "assistant":
|
266
|
-
msg = obj["message"]
|
267
|
-
session.messages.append(msg)
|
268
|
-
|
269
|
-
for blk in msg.get("content", []):
|
270
|
-
btype = blk.get("type")
|
271
|
-
if btype == "thinking":
|
272
|
-
thought = blk.get("thinking", "").strip()
|
273
|
-
chunk.thinking = thought
|
274
|
-
session.thinking_log.append(thought)
|
275
|
-
await _maybe_await(on_thinking, thought)
|
276
|
-
if request.verbose_output and on_thinking is None:
|
277
|
-
_pp_thinking(thought, theme)
|
278
|
-
|
279
|
-
elif btype == "text":
|
280
|
-
text = blk.get("text", "")
|
281
|
-
chunk.text = text
|
282
|
-
await _maybe_await(on_text, text)
|
283
|
-
if request.verbose_output and on_text is None:
|
284
|
-
_pp_assistant_text(text, theme)
|
285
|
-
|
286
|
-
elif btype == "tool_use":
|
287
|
-
tu = {
|
288
|
-
"id": blk["id"],
|
289
|
-
"name": blk["name"],
|
290
|
-
"input": blk["input"],
|
291
|
-
}
|
292
|
-
chunk.tool_use = tu
|
293
|
-
session.tool_uses.append(tu)
|
294
|
-
await _maybe_await(on_tool_use, tu)
|
295
|
-
if request.verbose_output and on_tool_use is None:
|
296
|
-
_pp_tool_use(tu, theme)
|
297
|
-
|
298
|
-
elif btype == "tool_result":
|
299
|
-
tr = {
|
300
|
-
"tool_use_id": blk["tool_use_id"],
|
301
|
-
"content": blk["content"],
|
302
|
-
"is_error": blk.get("is_error", False),
|
303
|
-
}
|
304
|
-
chunk.tool_result = tr
|
305
|
-
session.tool_results.append(tr)
|
306
|
-
await _maybe_await(on_tool_result, tr)
|
307
|
-
if request.verbose_output and on_tool_result is None:
|
308
|
-
_pp_tool_result(tr, theme)
|
309
|
-
yield chunk
|
310
|
-
|
311
|
-
# ------------------------ USER (tool_result containers) ------------
|
312
|
-
elif typ == "user":
|
313
|
-
msg = obj["message"]
|
314
|
-
session.messages.append(msg)
|
315
|
-
for blk in msg.get("content", []):
|
316
|
-
if blk.get("type") == "tool_result":
|
317
|
-
tr = {
|
318
|
-
"tool_use_id": blk["tool_use_id"],
|
319
|
-
"content": blk["content"],
|
320
|
-
"is_error": blk.get("is_error", False),
|
321
|
-
}
|
322
|
-
chunk.tool_result = tr
|
323
|
-
session.tool_results.append(tr)
|
324
|
-
await _maybe_await(on_tool_result, tr)
|
325
|
-
if request.verbose_output and on_tool_result is None:
|
326
|
-
_pp_tool_result(tr, theme)
|
327
|
-
yield chunk
|
328
|
-
|
329
|
-
# ------------------------ RESULT -----------------------------------
|
330
|
-
elif typ == "result":
|
331
|
-
session.result = obj.get("result", "").strip()
|
332
|
-
session.usage = obj.get("usage", {})
|
333
|
-
session.total_cost_usd = obj.get("total_cost_usd")
|
334
|
-
session.num_turns = obj.get("num_turns")
|
335
|
-
session.duration_ms = obj.get("duration_ms")
|
336
|
-
session.duration_api_ms = obj.get("duration_api_ms")
|
337
|
-
session.is_error = obj.get("is_error", False)
|
338
|
-
|
339
|
-
# ------------------------ DONE -------------------------------------
|
340
|
-
elif typ == "done":
|
341
|
-
break
|
342
|
-
|
343
|
-
# final pretty print
|
344
|
-
await _maybe_await(on_final, session)
|
345
|
-
if request.verbose_output and on_final is None:
|
346
|
-
_pp_final(session, theme)
|
347
|
-
|
348
|
-
yield session
|
349
|
-
|
14
|
+
from ._claude_code.models import ClaudeCodeRequest
|
15
|
+
from ._claude_code.stream_cli import (
|
16
|
+
ClaudeChunk,
|
17
|
+
ClaudeSession,
|
18
|
+
log,
|
19
|
+
stream_claude_code_cli,
|
20
|
+
)
|
350
21
|
|
351
22
|
ENDPOINT_CONFIG = EndpointConfig(
|
352
23
|
name="claude_code_cli",
|
@@ -363,13 +34,27 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
363
34
|
def __init__(self, config: EndpointConfig = ENDPOINT_CONFIG, **kwargs):
|
364
35
|
super().__init__(config=config, **kwargs)
|
365
36
|
|
37
|
+
@property
|
38
|
+
def claude_handlers(self):
|
39
|
+
handlers = {
|
40
|
+
"on_thinking": None,
|
41
|
+
"on_text": None,
|
42
|
+
"on_tool_use": None,
|
43
|
+
"on_tool_result": None,
|
44
|
+
"on_system": None,
|
45
|
+
"on_final": None,
|
46
|
+
}
|
47
|
+
return self.config.kwargs.get("claude_handlers", handlers)
|
48
|
+
|
366
49
|
def create_payload(self, request: dict | BaseModel, **kwargs):
|
367
50
|
req_dict = {**self.config.kwargs, **to_dict(request), **kwargs}
|
368
51
|
messages = req_dict.pop("messages")
|
369
52
|
req_obj = ClaudeCodeRequest.create(messages=messages, **req_dict)
|
370
53
|
return {"request": req_obj}, {}
|
371
54
|
|
372
|
-
async def stream(
|
55
|
+
async def stream(
|
56
|
+
self, request: dict | BaseModel, **kwargs
|
57
|
+
) -> AsyncIterator[ClaudeChunk | dict | ClaudeSession]:
|
373
58
|
payload, _ = self.create_payload(request, **kwargs)["request"]
|
374
59
|
async for chunk in stream_claude_code_cli(payload):
|
375
60
|
yield chunk
|
@@ -386,7 +71,9 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
386
71
|
system: dict = None
|
387
72
|
|
388
73
|
# 1. stream the Claude Code response
|
389
|
-
async for chunk in stream_claude_code_cli(
|
74
|
+
async for chunk in stream_claude_code_cli(
|
75
|
+
request, session, **self.claude_handlers, **kwargs
|
76
|
+
):
|
390
77
|
if isinstance(chunk, dict):
|
391
78
|
system = chunk
|
392
79
|
responses.append(chunk)
|
@@ -395,6 +82,7 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
395
82
|
responses[-1], ClaudeSession
|
396
83
|
):
|
397
84
|
req2 = request.model_copy(deep=True)
|
85
|
+
req2.prompt = "Please provide a the final result message only"
|
398
86
|
req2.max_turns = 1
|
399
87
|
req2.continue_conversation = True
|
400
88
|
if system:
|
@@ -407,4 +95,11 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
407
95
|
log.info(
|
408
96
|
f"Session {session.session_id} finished with {len(responses)} chunks"
|
409
97
|
)
|
98
|
+
texts = []
|
99
|
+
for i in session.chunks:
|
100
|
+
if i.text is not None:
|
101
|
+
texts.append(i.text)
|
102
|
+
|
103
|
+
texts.append(session.result)
|
104
|
+
session.result = "\n".join(texts)
|
410
105
|
return to_dict(session, recursive=True)
|