lionagi 0.14.8__py3-none-any.whl → 0.14.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/_errors.py +120 -11
- lionagi/_types.py +0 -6
- lionagi/config.py +3 -1
- lionagi/fields/reason.py +1 -1
- lionagi/libs/concurrency/throttle.py +79 -0
- lionagi/libs/parse.py +2 -1
- lionagi/libs/unstructured/__init__.py +0 -0
- lionagi/libs/unstructured/pdf_to_image.py +45 -0
- lionagi/libs/unstructured/read_image_to_base64.py +33 -0
- lionagi/libs/validate/to_num.py +378 -0
- lionagi/libs/validate/xml_parser.py +203 -0
- lionagi/models/operable_model.py +8 -3
- lionagi/operations/flow.py +0 -1
- lionagi/protocols/generic/event.py +2 -0
- lionagi/protocols/generic/log.py +26 -10
- lionagi/protocols/operatives/step.py +1 -1
- lionagi/protocols/types.py +9 -1
- lionagi/service/__init__.py +22 -1
- lionagi/service/connections/api_calling.py +57 -2
- lionagi/service/connections/endpoint_config.py +1 -1
- lionagi/service/connections/header_factory.py +4 -2
- lionagi/service/connections/match_endpoint.py +10 -10
- lionagi/service/connections/providers/anthropic_.py +5 -2
- lionagi/service/connections/providers/claude_code_.py +13 -17
- lionagi/service/connections/providers/claude_code_cli.py +51 -16
- lionagi/service/connections/providers/exa_.py +5 -3
- lionagi/service/connections/providers/oai_.py +116 -81
- lionagi/service/connections/providers/ollama_.py +38 -18
- lionagi/service/connections/providers/perplexity_.py +36 -14
- lionagi/service/connections/providers/types.py +30 -0
- lionagi/service/hooks/__init__.py +25 -0
- lionagi/service/hooks/_types.py +52 -0
- lionagi/service/hooks/_utils.py +85 -0
- lionagi/service/hooks/hook_event.py +67 -0
- lionagi/service/hooks/hook_registry.py +221 -0
- lionagi/service/imodel.py +120 -34
- lionagi/service/third_party/claude_code.py +715 -0
- lionagi/service/third_party/openai_model_names.py +198 -0
- lionagi/service/third_party/pplx_models.py +16 -8
- lionagi/service/types.py +21 -0
- lionagi/session/branch.py +1 -4
- lionagi/tools/base.py +1 -3
- lionagi/tools/file/reader.py +1 -1
- lionagi/tools/memory/tools.py +2 -2
- lionagi/utils.py +12 -775
- lionagi/version.py +1 -1
- {lionagi-0.14.8.dist-info → lionagi-0.14.10.dist-info}/METADATA +6 -2
- {lionagi-0.14.8.dist-info → lionagi-0.14.10.dist-info}/RECORD +50 -40
- lionagi/service/connections/providers/_claude_code/__init__.py +0 -3
- lionagi/service/connections/providers/_claude_code/models.py +0 -244
- lionagi/service/connections/providers/_claude_code/stream_cli.py +0 -359
- lionagi/service/third_party/openai_models.py +0 -18241
- {lionagi-0.14.8.dist-info → lionagi-0.14.10.dist-info}/WHEEL +0 -0
- {lionagi-0.14.8.dist-info → lionagi-0.14.10.dist-info}/licenses/LICENSE +0 -0
lionagi/protocols/generic/log.py
CHANGED
@@ -13,18 +13,19 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
|
13
13
|
|
14
14
|
from lionagi.utils import create_path, to_dict
|
15
15
|
|
16
|
-
from .._concepts import Manager
|
17
16
|
from .element import Element
|
18
17
|
from .pile import Pile
|
19
18
|
|
20
19
|
__all__ = (
|
20
|
+
"DataLoggerConfig",
|
21
21
|
"LogManagerConfig",
|
22
22
|
"Log",
|
23
|
+
"DataLogger",
|
23
24
|
"LogManager",
|
24
25
|
)
|
25
26
|
|
26
27
|
|
27
|
-
class
|
28
|
+
class DataLoggerConfig(BaseModel):
|
28
29
|
persist_dir: str | Path = "./data/logs"
|
29
30
|
subfolder: str | None = None
|
30
31
|
file_prefix: str | None = None
|
@@ -101,7 +102,7 @@ class Log(Element):
|
|
101
102
|
return cls(content=content)
|
102
103
|
|
103
104
|
|
104
|
-
class
|
105
|
+
class DataLogger:
|
105
106
|
"""
|
106
107
|
Manages a collection of logs, optionally auto-dumping them
|
107
108
|
to CSV or JSON when capacity is reached or at program exit.
|
@@ -111,7 +112,7 @@ class LogManager(Manager):
|
|
111
112
|
self,
|
112
113
|
*,
|
113
114
|
logs: Any = None,
|
114
|
-
_config:
|
115
|
+
_config: DataLoggerConfig = None,
|
115
116
|
**kwargs,
|
116
117
|
):
|
117
118
|
"""
|
@@ -128,7 +129,7 @@ class LogManager(Manager):
|
|
128
129
|
clear_after_dump: Whether to clear logs after saving.
|
129
130
|
"""
|
130
131
|
if _config is None:
|
131
|
-
_config =
|
132
|
+
_config = DataLoggerConfig(**kwargs)
|
132
133
|
|
133
134
|
if isinstance(logs, dict):
|
134
135
|
self.logs = Pile.from_dict(logs)
|
@@ -188,8 +189,15 @@ class LogManager(Manager):
|
|
188
189
|
if do_clear:
|
189
190
|
self.logs.clear()
|
190
191
|
except Exception as e:
|
191
|
-
|
192
|
-
|
192
|
+
# Check if it's a JSON serialization error with complex objects
|
193
|
+
if "JSON serializable" in str(e):
|
194
|
+
logging.debug(f"Could not serialize logs to JSON: {e}")
|
195
|
+
# Don't raise for JSON serialization issues during dumps
|
196
|
+
if clear is not False:
|
197
|
+
self.logs.clear() # Still clear if requested
|
198
|
+
else:
|
199
|
+
logging.error(f"Failed to dump logs: {e}")
|
200
|
+
raise
|
193
201
|
|
194
202
|
async def adump(
|
195
203
|
self,
|
@@ -222,16 +230,24 @@ class LogManager(Manager):
|
|
222
230
|
try:
|
223
231
|
self.dump(clear=self._config.clear_after_dump)
|
224
232
|
except Exception as e:
|
225
|
-
|
233
|
+
# Only log debug level for JSON serialization errors during exit
|
234
|
+
# These are non-critical and often occur with complex objects
|
235
|
+
if "JSON serializable" in str(e):
|
236
|
+
logging.debug(f"Could not serialize logs to JSON: {e}")
|
237
|
+
else:
|
238
|
+
logging.error(f"Failed to save logs on exit: {e}")
|
226
239
|
|
227
240
|
@classmethod
|
228
241
|
def from_config(
|
229
|
-
cls, config:
|
230
|
-
) ->
|
242
|
+
cls, config: DataLoggerConfig, logs: Any = None
|
243
|
+
) -> DataLogger:
|
231
244
|
"""
|
232
245
|
Construct a LogManager from a LogManagerConfig.
|
233
246
|
"""
|
234
247
|
return cls(_config=config, logs=logs)
|
235
248
|
|
236
249
|
|
250
|
+
LogManagerConfig = DataLoggerConfig
|
251
|
+
LogManager = DataLogger
|
252
|
+
|
237
253
|
# File: lionagi/protocols/generic/log.py
|
@@ -168,7 +168,7 @@ class Step:
|
|
168
168
|
ACTION_REQUESTS_FIELD,
|
169
169
|
]
|
170
170
|
)
|
171
|
-
if "reason" in operative.response_model.model_fields:
|
171
|
+
if "reason" in type(operative.response_model).model_fields:
|
172
172
|
field_models.extend([REASON_FIELD])
|
173
173
|
|
174
174
|
operative = Step._create_response_type(
|
lionagi/protocols/types.py
CHANGED
@@ -18,7 +18,13 @@ from .forms.flow import FlowDefinition, FlowStep
|
|
18
18
|
from .forms.report import BaseForm, Form, Report
|
19
19
|
from .generic.element import ID, Element, IDError, IDType, validate_order
|
20
20
|
from .generic.event import Event, EventStatus, Execution
|
21
|
-
from .generic.log import
|
21
|
+
from .generic.log import (
|
22
|
+
DataLogger,
|
23
|
+
DataLoggerConfig,
|
24
|
+
Log,
|
25
|
+
LogManager,
|
26
|
+
LogManagerConfig,
|
27
|
+
)
|
22
28
|
from .generic.pile import Pile, to_list_type
|
23
29
|
from .generic.processor import Executor, Processor
|
24
30
|
from .generic.progression import Progression, prog
|
@@ -107,4 +113,6 @@ __all__ = (
|
|
107
113
|
"FunctionCalling",
|
108
114
|
"ToolRef",
|
109
115
|
"MailManager",
|
116
|
+
"DataLogger",
|
117
|
+
"DataLoggerConfig",
|
110
118
|
)
|
lionagi/service/__init__.py
CHANGED
@@ -1 +1,22 @@
|
|
1
|
-
from .
|
1
|
+
from .connections.api_calling import APICalling
|
2
|
+
from .connections.endpoint import Endpoint, EndpointConfig
|
3
|
+
from .hooks import *
|
4
|
+
from .imodel import iModel
|
5
|
+
from .manager import iModelManager
|
6
|
+
from .rate_limited_processor import RateLimitedAPIExecutor
|
7
|
+
from .token_calculator import TokenCalculator
|
8
|
+
|
9
|
+
__all__ = (
|
10
|
+
"APICalling",
|
11
|
+
"Endpoint",
|
12
|
+
"EndpointConfig",
|
13
|
+
"RateLimitedAPIExecutor",
|
14
|
+
"TokenCalculator",
|
15
|
+
"iModel",
|
16
|
+
"iModelManager",
|
17
|
+
"HookEventTypes",
|
18
|
+
"HookDict",
|
19
|
+
"AssosiatedEventInfo",
|
20
|
+
"HookEvent",
|
21
|
+
"HookRegistry",
|
22
|
+
)
|
@@ -5,10 +5,13 @@
|
|
5
5
|
import asyncio
|
6
6
|
import logging
|
7
7
|
|
8
|
-
from
|
8
|
+
from anyio import get_cancelled_exc_class
|
9
|
+
from pydantic import Field, PrivateAttr, model_validator
|
9
10
|
from typing_extensions import Self
|
10
11
|
|
11
12
|
from lionagi.protocols.generic.event import Event, EventStatus
|
13
|
+
from lionagi.protocols.types import Log
|
14
|
+
from lionagi.service.hooks import HookEvent, HookEventTypes, global_hook_logger
|
12
15
|
from lionagi.service.token_calculator import TokenCalculator
|
13
16
|
|
14
17
|
from .endpoint import Endpoint
|
@@ -51,6 +54,9 @@ class APICalling(Event):
|
|
51
54
|
exclude=True,
|
52
55
|
)
|
53
56
|
|
57
|
+
_pre_invoke_hook_event: HookEvent = PrivateAttr(None)
|
58
|
+
_post_invoke_hook_event: HookEvent = PrivateAttr(None)
|
59
|
+
|
54
60
|
@model_validator(mode="after")
|
55
61
|
def _validate_streaming(self) -> Self:
|
56
62
|
"""Validate streaming configuration and add token usage if requested."""
|
@@ -162,6 +168,13 @@ class APICalling(Event):
|
|
162
168
|
|
163
169
|
try:
|
164
170
|
self.execution.status = EventStatus.PROCESSING
|
171
|
+
if h_ev := self._pre_invoke_hook_event:
|
172
|
+
await h_ev.invoke()
|
173
|
+
if h_ev._should_exit:
|
174
|
+
raise h_ev._exit_cause or RuntimeError(
|
175
|
+
"Pre-invocation hook requested exit without a cause"
|
176
|
+
)
|
177
|
+
await global_hook_logger.alog(Log.create(h_ev))
|
165
178
|
|
166
179
|
# Make the API call with skip_payload_creation=True since payload is already prepared
|
167
180
|
response = await self.endpoint.call(
|
@@ -171,10 +184,18 @@ class APICalling(Event):
|
|
171
184
|
extra_headers=self.headers if self.headers else None,
|
172
185
|
)
|
173
186
|
|
187
|
+
if h_ev := self._post_invoke_hook_event:
|
188
|
+
await h_ev.invoke()
|
189
|
+
if h_ev._should_exit:
|
190
|
+
raise h_ev._exit_cause or RuntimeError(
|
191
|
+
"Post-invocation hook requested exit without a cause"
|
192
|
+
)
|
193
|
+
await global_hook_logger.alog(Log.create(h_ev))
|
194
|
+
|
174
195
|
self.execution.response = response
|
175
196
|
self.execution.status = EventStatus.COMPLETED
|
176
197
|
|
177
|
-
except
|
198
|
+
except get_cancelled_exc_class():
|
178
199
|
self.execution.error = "API call cancelled"
|
179
200
|
self.execution.status = EventStatus.FAILED
|
180
201
|
raise
|
@@ -228,3 +249,37 @@ class APICalling(Event):
|
|
228
249
|
def response(self):
|
229
250
|
"""Get the response from the execution."""
|
230
251
|
return self.execution.response if self.execution else None
|
252
|
+
|
253
|
+
def create_pre_invoke_hook(
|
254
|
+
self,
|
255
|
+
hook_registry,
|
256
|
+
exit_hook: bool = None,
|
257
|
+
hook_timeout: float = 30.0,
|
258
|
+
hook_params: dict = None,
|
259
|
+
):
|
260
|
+
h_ev = HookEvent(
|
261
|
+
hook_type=HookEventTypes.PreInvokation,
|
262
|
+
event_like=self,
|
263
|
+
registry=hook_registry,
|
264
|
+
exit=exit_hook,
|
265
|
+
timeout=hook_timeout,
|
266
|
+
params=hook_params or {},
|
267
|
+
)
|
268
|
+
self._pre_invoke_hook_event = h_ev
|
269
|
+
|
270
|
+
def create_post_invoke_hook(
|
271
|
+
self,
|
272
|
+
hook_registry,
|
273
|
+
exit_hook: bool = None,
|
274
|
+
hook_timeout: float = 30.0,
|
275
|
+
hook_params: dict = None,
|
276
|
+
):
|
277
|
+
h_ev = HookEvent(
|
278
|
+
hook_type=HookEventTypes.PostInvokation,
|
279
|
+
event_like=self,
|
280
|
+
registry=hook_registry,
|
281
|
+
exit=exit_hook,
|
282
|
+
timeout=hook_timeout,
|
283
|
+
params=hook_params or {},
|
284
|
+
)
|
285
|
+
self._post_invoke_hook_event = h_ev
|
@@ -32,7 +32,7 @@ class EndpointConfig(BaseModel):
|
|
32
32
|
endpoint_params: list[str] | None = None
|
33
33
|
method: str = "POST"
|
34
34
|
params: dict[str, str] = Field(default_factory=dict)
|
35
|
-
content_type: str = "application/json"
|
35
|
+
content_type: str | None = "application/json"
|
36
36
|
auth_type: AUTH_TYPES = "bearer"
|
37
37
|
default_headers: dict = {}
|
38
38
|
request_options: B | None = None
|
@@ -27,11 +27,13 @@ class HeaderFactory:
|
|
27
27
|
@staticmethod
|
28
28
|
def get_header(
|
29
29
|
auth_type: AUTH_TYPES,
|
30
|
-
content_type: str = "application/json",
|
30
|
+
content_type: str | None = "application/json",
|
31
31
|
api_key: str | SecretStr | None = None,
|
32
32
|
default_headers: dict[str, str] | None = None,
|
33
33
|
) -> dict[str, str]:
|
34
|
-
dict_ =
|
34
|
+
dict_ = {}
|
35
|
+
if content_type is not None:
|
36
|
+
dict_ = HeaderFactory.get_content_type_header(content_type)
|
35
37
|
|
36
38
|
if auth_type == "none":
|
37
39
|
# No authentication needed
|
@@ -16,49 +16,49 @@ def match_endpoint(
|
|
16
16
|
if "chat" in endpoint:
|
17
17
|
from .providers.oai_ import OpenaiChatEndpoint
|
18
18
|
|
19
|
-
return OpenaiChatEndpoint(**kwargs)
|
19
|
+
return OpenaiChatEndpoint(None, **kwargs)
|
20
20
|
if "response" in endpoint:
|
21
21
|
from .providers.oai_ import OpenaiResponseEndpoint
|
22
22
|
|
23
|
-
return OpenaiResponseEndpoint(**kwargs)
|
23
|
+
return OpenaiResponseEndpoint(None, **kwargs)
|
24
24
|
if provider == "openrouter" and "chat" in endpoint:
|
25
25
|
from .providers.oai_ import OpenrouterChatEndpoint
|
26
26
|
|
27
|
-
return OpenrouterChatEndpoint(**kwargs)
|
27
|
+
return OpenrouterChatEndpoint(None, **kwargs)
|
28
28
|
if provider == "ollama" and "chat" in endpoint:
|
29
29
|
from .providers.ollama_ import OllamaChatEndpoint
|
30
30
|
|
31
|
-
return OllamaChatEndpoint(**kwargs)
|
31
|
+
return OllamaChatEndpoint(None, **kwargs)
|
32
32
|
if provider == "exa" and "search" in endpoint:
|
33
33
|
from .providers.exa_ import ExaSearchEndpoint
|
34
34
|
|
35
|
-
return ExaSearchEndpoint(**kwargs)
|
35
|
+
return ExaSearchEndpoint(None, **kwargs)
|
36
36
|
if provider == "anthropic" and (
|
37
37
|
"messages" in endpoint or "chat" in endpoint
|
38
38
|
):
|
39
39
|
from .providers.anthropic_ import AnthropicMessagesEndpoint
|
40
40
|
|
41
|
-
return AnthropicMessagesEndpoint(**kwargs)
|
41
|
+
return AnthropicMessagesEndpoint(None, **kwargs)
|
42
42
|
if provider == "groq" and "chat" in endpoint:
|
43
43
|
from .providers.oai_ import GroqChatEndpoint
|
44
44
|
|
45
|
-
return GroqChatEndpoint(**kwargs)
|
45
|
+
return GroqChatEndpoint(None, **kwargs)
|
46
46
|
if provider == "perplexity" and "chat" in endpoint:
|
47
47
|
from .providers.perplexity_ import PerplexityChatEndpoint
|
48
48
|
|
49
|
-
return PerplexityChatEndpoint(**kwargs)
|
49
|
+
return PerplexityChatEndpoint(None, **kwargs)
|
50
50
|
if provider == "claude_code":
|
51
51
|
if "cli" in endpoint:
|
52
52
|
from .providers.claude_code_cli import ClaudeCodeCLIEndpoint
|
53
53
|
|
54
|
-
return ClaudeCodeCLIEndpoint(**kwargs)
|
54
|
+
return ClaudeCodeCLIEndpoint(None, **kwargs)
|
55
55
|
|
56
56
|
if "query" in endpoint or "code" in endpoint:
|
57
57
|
from lionagi.service.connections.providers.claude_code_ import (
|
58
58
|
ClaudeCodeEndpoint,
|
59
59
|
)
|
60
60
|
|
61
|
-
return ClaudeCodeEndpoint(**kwargs)
|
61
|
+
return ClaudeCodeEndpoint(None, **kwargs)
|
62
62
|
|
63
63
|
from .providers.oai_ import OpenaiChatEndpoint
|
64
64
|
|
@@ -9,7 +9,7 @@ from lionagi.service.connections.endpoint import Endpoint
|
|
9
9
|
from lionagi.service.connections.endpoint_config import EndpointConfig
|
10
10
|
from lionagi.service.third_party.anthropic_models import CreateMessageRequest
|
11
11
|
|
12
|
-
|
12
|
+
_get_config = lambda: EndpointConfig(
|
13
13
|
name="anthropic_messages",
|
14
14
|
provider="anthropic",
|
15
15
|
base_url="https://api.anthropic.com/v1",
|
@@ -22,13 +22,16 @@ ANTHROPIC_MESSAGES_ENDPOINT_CONFIG = EndpointConfig(
|
|
22
22
|
request_options=CreateMessageRequest,
|
23
23
|
)
|
24
24
|
|
25
|
+
ANTHROPIC_MESSAGES_ENDPOINT_CONFIG = _get_config() # backward compatibility
|
26
|
+
|
25
27
|
|
26
28
|
class AnthropicMessagesEndpoint(Endpoint):
|
27
29
|
def __init__(
|
28
30
|
self,
|
29
|
-
config: EndpointConfig =
|
31
|
+
config: EndpointConfig = None,
|
30
32
|
**kwargs,
|
31
33
|
):
|
34
|
+
config = config or _get_config()
|
32
35
|
super().__init__(config, **kwargs)
|
33
36
|
|
34
37
|
def create_payload(
|
@@ -12,12 +12,14 @@ from pydantic import BaseModel
|
|
12
12
|
from lionagi.libs.schema.as_readable import as_readable
|
13
13
|
from lionagi.service.connections.endpoint import Endpoint
|
14
14
|
from lionagi.service.connections.endpoint_config import EndpointConfig
|
15
|
-
from lionagi.utils import
|
15
|
+
from lionagi.utils import to_dict, to_list
|
16
16
|
|
17
|
-
from .
|
17
|
+
from ...third_party.claude_code import (
|
18
18
|
CLAUDE_CODE_OPTION_PARAMS,
|
19
|
+
HAS_CLAUDE_CODE_SDK,
|
19
20
|
ClaudeCodeRequest,
|
20
21
|
ClaudePermission,
|
22
|
+
stream_cc_sdk_events,
|
21
23
|
)
|
22
24
|
|
23
25
|
__all__ = (
|
@@ -27,24 +29,27 @@ __all__ = (
|
|
27
29
|
"ClaudeCodeEndpoint",
|
28
30
|
)
|
29
31
|
|
30
|
-
HAS_CLAUDE_CODE_SDK = is_import_installed("claude_code_sdk")
|
31
32
|
|
32
33
|
# --------------------------------------------------------------------------- SDK endpoint
|
33
|
-
|
34
|
+
|
35
|
+
_get_config = lambda: EndpointConfig(
|
34
36
|
name="claude_code",
|
35
37
|
provider="claude_code",
|
36
38
|
base_url="internal",
|
37
39
|
endpoint="query",
|
38
|
-
api_key="dummy",
|
39
40
|
request_options=ClaudeCodeRequest,
|
40
41
|
timeout=3000,
|
42
|
+
api_key="dummy-key",
|
41
43
|
)
|
42
44
|
|
43
45
|
|
46
|
+
ENDPOINT_CONFIG = _get_config() # backward compatibility
|
47
|
+
|
48
|
+
|
44
49
|
class ClaudeCodeEndpoint(Endpoint):
|
45
50
|
"""Direct Python-SDK (non-CLI) endpoint - unchanged except for bug-fixes."""
|
46
51
|
|
47
|
-
def __init__(self, config: EndpointConfig =
|
52
|
+
def __init__(self, config: EndpointConfig = None, **kwargs):
|
48
53
|
if not HAS_CLAUDE_CODE_SDK:
|
49
54
|
raise ImportError(
|
50
55
|
"claude_code_sdk is not installed. "
|
@@ -56,6 +61,7 @@ class ClaudeCodeEndpoint(Endpoint):
|
|
56
61
|
DeprecationWarning,
|
57
62
|
)
|
58
63
|
|
64
|
+
config = config or _get_config()
|
59
65
|
super().__init__(config=config, **kwargs)
|
60
66
|
|
61
67
|
def create_payload(self, request: dict | BaseModel, **kwargs):
|
@@ -64,16 +70,9 @@ class ClaudeCodeEndpoint(Endpoint):
|
|
64
70
|
req_obj = ClaudeCodeRequest.create(messages=messages, **req_dict)
|
65
71
|
return {"request": req_obj}, {}
|
66
72
|
|
67
|
-
def _stream_claude_code(self, request: ClaudeCodeRequest):
|
68
|
-
from claude_code_sdk import query as sdk_query
|
69
|
-
|
70
|
-
return sdk_query(
|
71
|
-
prompt=request.prompt, options=request.as_claude_options()
|
72
|
-
)
|
73
|
-
|
74
73
|
async def stream(self, request: dict | BaseModel, **kwargs):
|
75
74
|
payload, _ = self.create_payload(request, **kwargs)
|
76
|
-
async for chunk in
|
75
|
+
async for chunk in stream_cc_sdk_events(payload["request"]):
|
77
76
|
yield chunk
|
78
77
|
|
79
78
|
def _parse_claude_code_response(self, responses: list) -> dict:
|
@@ -204,9 +203,6 @@ class ClaudeCodeEndpoint(Endpoint):
|
|
204
203
|
|
205
204
|
responses.append(chunk)
|
206
205
|
|
207
|
-
# 3. Parse the responses into a clean format
|
208
|
-
return self._parse_claude_code_response(responses)
|
209
|
-
|
210
206
|
|
211
207
|
def _display_message(chunk, theme):
|
212
208
|
from claude_code_sdk import types as cc_types
|
@@ -4,48 +4,78 @@
|
|
4
4
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
-
from collections.abc import AsyncIterator
|
7
|
+
from collections.abc import AsyncIterator, Callable
|
8
8
|
|
9
9
|
from pydantic import BaseModel
|
10
10
|
|
11
11
|
from lionagi.service.connections.endpoint import Endpoint, EndpointConfig
|
12
12
|
from lionagi.utils import to_dict
|
13
13
|
|
14
|
-
from .
|
15
|
-
from ._claude_code.stream_cli import (
|
14
|
+
from ...third_party.claude_code import (
|
16
15
|
ClaudeChunk,
|
16
|
+
ClaudeCodeRequest,
|
17
17
|
ClaudeSession,
|
18
|
-
|
18
|
+
)
|
19
|
+
from ...third_party.claude_code import log as cc_log
|
20
|
+
from ...third_party.claude_code import (
|
19
21
|
stream_claude_code_cli,
|
20
22
|
)
|
21
23
|
|
22
|
-
|
24
|
+
_get_config = lambda: EndpointConfig(
|
23
25
|
name="claude_code_cli",
|
24
26
|
provider="claude_code",
|
25
27
|
base_url="internal",
|
26
28
|
endpoint="query_cli",
|
27
|
-
api_key="dummy",
|
29
|
+
api_key="dummy-key",
|
28
30
|
request_options=ClaudeCodeRequest,
|
29
31
|
timeout=18000, # 30 mins
|
30
32
|
)
|
31
33
|
|
34
|
+
ENDPOINT_CONFIG = _get_config() # backward compatibility
|
35
|
+
|
36
|
+
|
37
|
+
_CLAUDE_HANDLER_PARAMS = (
|
38
|
+
"on_thinking",
|
39
|
+
"on_text",
|
40
|
+
"on_tool_use",
|
41
|
+
"on_tool_result",
|
42
|
+
"on_system",
|
43
|
+
"on_final",
|
44
|
+
)
|
45
|
+
|
46
|
+
|
47
|
+
def _validate_handlers(handlers: dict[str, Callable | None], /) -> None:
|
48
|
+
if not isinstance(handlers, dict):
|
49
|
+
raise ValueError("Handlers must be a dictionary")
|
50
|
+
for k, v in handlers.items():
|
51
|
+
if k not in _CLAUDE_HANDLER_PARAMS:
|
52
|
+
raise ValueError(f"Invalid handler key: {k}")
|
53
|
+
if not (v is None or callable(v)):
|
54
|
+
raise ValueError(
|
55
|
+
f"Handler value must be callable or None, got {type(v)}"
|
56
|
+
)
|
57
|
+
|
32
58
|
|
33
59
|
class ClaudeCodeCLIEndpoint(Endpoint):
|
34
|
-
def __init__(self, config: EndpointConfig =
|
60
|
+
def __init__(self, config: EndpointConfig = None, **kwargs):
|
61
|
+
config = config or _get_config()
|
35
62
|
super().__init__(config=config, **kwargs)
|
36
63
|
|
37
64
|
@property
|
38
65
|
def claude_handlers(self):
|
39
|
-
handlers = {
|
40
|
-
"on_thinking": None,
|
41
|
-
"on_text": None,
|
42
|
-
"on_tool_use": None,
|
43
|
-
"on_tool_result": None,
|
44
|
-
"on_system": None,
|
45
|
-
"on_final": None,
|
46
|
-
}
|
66
|
+
handlers = {k: None for k in _CLAUDE_HANDLER_PARAMS}
|
47
67
|
return self.config.kwargs.get("claude_handlers", handlers)
|
48
68
|
|
69
|
+
@claude_handlers.setter
|
70
|
+
def claude_handlers(self, value: dict):
|
71
|
+
_validate_handlers(value)
|
72
|
+
self.config.kwargs["claude_handlers"] = value
|
73
|
+
|
74
|
+
def update_handlers(self, **kwargs):
|
75
|
+
_validate_handlers(kwargs)
|
76
|
+
handlers = {**self.claude_handlers, **kwargs}
|
77
|
+
self.claude_handlers = handlers
|
78
|
+
|
49
79
|
def create_payload(self, request: dict | BaseModel, **kwargs):
|
50
80
|
req_dict = {**self.config.kwargs, **to_dict(request), **kwargs}
|
51
81
|
messages = req_dict.pop("messages")
|
@@ -75,6 +105,8 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
75
105
|
request, session, **self.claude_handlers, **kwargs
|
76
106
|
):
|
77
107
|
if isinstance(chunk, dict):
|
108
|
+
if chunk.get("type") == "done":
|
109
|
+
break
|
78
110
|
system = chunk
|
79
111
|
responses.append(chunk)
|
80
112
|
|
@@ -92,7 +124,7 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
92
124
|
responses.append(chunk)
|
93
125
|
if isinstance(chunk, ClaudeSession):
|
94
126
|
break
|
95
|
-
|
127
|
+
cc_log.info(
|
96
128
|
f"Session {session.session_id} finished with {len(responses)} chunks"
|
97
129
|
)
|
98
130
|
texts = []
|
@@ -102,4 +134,7 @@ class ClaudeCodeCLIEndpoint(Endpoint):
|
|
102
134
|
|
103
135
|
texts.append(session.result)
|
104
136
|
session.result = "\n".join(texts)
|
137
|
+
if request.cli_include_summary:
|
138
|
+
session.populate_summary()
|
139
|
+
|
105
140
|
return to_dict(session, recursive=True)
|
@@ -11,8 +11,7 @@ from lionagi.service.third_party.exa_models import ExaSearchRequest
|
|
11
11
|
|
12
12
|
__all__ = ("ExaSearchEndpoint",)
|
13
13
|
|
14
|
-
|
15
|
-
ENDPOINT_CONFIG = EndpointConfig(
|
14
|
+
_get_config = lambda: EndpointConfig(
|
16
15
|
name="exa_search",
|
17
16
|
provider="exa",
|
18
17
|
base_url="https://api.exa.ai",
|
@@ -27,7 +26,10 @@ ENDPOINT_CONFIG = EndpointConfig(
|
|
27
26
|
content_type="application/json",
|
28
27
|
)
|
29
28
|
|
29
|
+
ENDPOINT_CONFIG = _get_config() # backward compatibility
|
30
|
+
|
30
31
|
|
31
32
|
class ExaSearchEndpoint(Endpoint):
|
32
|
-
def __init__(self, config=
|
33
|
+
def __init__(self, config: EndpointConfig = None, **kwargs):
|
34
|
+
config = config or _get_config()
|
33
35
|
super().__init__(config=config, **kwargs)
|