lionagi 0.17.5__py3-none-any.whl → 0.17.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +5 -2
- lionagi/config.py +26 -0
- lionagi/fields/action.py +5 -3
- lionagi/libs/file/chunk.py +3 -14
- lionagi/libs/file/process.py +10 -92
- lionagi/libs/schema/breakdown_pydantic_annotation.py +45 -0
- lionagi/ln/_async_call.py +6 -6
- lionagi/ln/fuzzy/_fuzzy_match.py +3 -6
- lionagi/ln/fuzzy/_fuzzy_validate.py +3 -4
- lionagi/ln/fuzzy/_string_similarity.py +11 -5
- lionagi/ln/fuzzy/_to_dict.py +19 -19
- lionagi/ln/types.py +15 -0
- lionagi/operations/operate/operate.py +7 -11
- lionagi/operations/parse/parse.py +5 -3
- lionagi/protocols/generic/element.py +3 -6
- lionagi/protocols/generic/event.py +1 -1
- lionagi/protocols/mail/package.py +2 -2
- lionagi/protocols/messages/instruction.py +9 -1
- lionagi/protocols/operatives/operative.py +4 -3
- lionagi/service/__init__.py +67 -8
- lionagi/service/broadcaster.py +61 -0
- lionagi/service/connections/api_calling.py +21 -140
- lionagi/service/hooks/__init__.py +2 -10
- lionagi/service/hooks/_types.py +5 -4
- lionagi/service/hooks/hook_registry.py +11 -11
- lionagi/service/hooks/hooked_event.py +142 -0
- lionagi/service/imodel.py +11 -6
- lionagi/session/branch.py +46 -169
- lionagi/session/session.py +1 -44
- lionagi/tools/file/reader.py +6 -4
- lionagi/utils.py +3 -334
- lionagi/version.py +1 -1
- {lionagi-0.17.5.dist-info → lionagi-0.17.7.dist-info}/METADATA +2 -2
- {lionagi-0.17.5.dist-info → lionagi-0.17.7.dist-info}/RECORD +36 -40
- lionagi/libs/file/_utils.py +0 -10
- lionagi/libs/file/concat.py +0 -121
- lionagi/libs/file/concat_files.py +0 -85
- lionagi/libs/file/file_ops.py +0 -118
- lionagi/libs/file/save.py +0 -103
- lionagi/ln/concurrency/throttle.py +0 -83
- lionagi/settings.py +0 -71
- {lionagi-0.17.5.dist-info → lionagi-0.17.7.dist-info}/WHEEL +0 -0
- {lionagi-0.17.5.dist-info → lionagi-0.17.7.dist-info}/licenses/LICENSE +0 -0
@@ -12,6 +12,7 @@ from lionagi.models import FieldModel, ModelParams
|
|
12
12
|
from lionagi.protocols.operatives.step import Operative, Step
|
13
13
|
from lionagi.protocols.types import Instruction, Progression, SenderRecipient
|
14
14
|
from lionagi.service.imodel import iModel
|
15
|
+
from lionagi.session.branch import AlcallParams
|
15
16
|
|
16
17
|
if TYPE_CHECKING:
|
17
18
|
from lionagi.session.branch import Branch, ToolRef
|
@@ -64,10 +65,8 @@ async def operate(
|
|
64
65
|
return_operative: bool = False,
|
65
66
|
actions: bool = False,
|
66
67
|
reason: bool = False,
|
67
|
-
|
68
|
-
action_strategy: Literal[
|
69
|
-
"sequential", "concurrent", "batch"
|
70
|
-
] = "concurrent",
|
68
|
+
call_params: AlcallParams = None,
|
69
|
+
action_strategy: Literal["sequential", "concurrent"] = "concurrent",
|
71
70
|
verbose_action: bool = False,
|
72
71
|
field_models: list[FieldModel] = None,
|
73
72
|
exclude_fields: list | dict | None = None,
|
@@ -191,17 +190,14 @@ async def operate(
|
|
191
190
|
getattr(response_model, "action_required", None) is True
|
192
191
|
and getattr(response_model, "action_requests", None) is not None
|
193
192
|
):
|
194
|
-
|
195
|
-
|
196
|
-
instruct.action_strategy
|
197
|
-
if instruct.action_strategy
|
198
|
-
else action_kwargs.get("strategy", "concurrent")
|
193
|
+
action_strategy = (
|
194
|
+
action_strategy or instruct.action_strategy or "concurrent"
|
199
195
|
)
|
200
|
-
|
201
196
|
action_response_models = await branch.act(
|
202
197
|
response_model.action_requests,
|
198
|
+
strategy=action_strategy,
|
203
199
|
verbose_action=verbose_action,
|
204
|
-
|
200
|
+
call_params=call_params,
|
205
201
|
)
|
206
202
|
# Possibly refine the operative with the tool outputs
|
207
203
|
operative = Step.respond_operative(
|
@@ -6,9 +6,6 @@ from typing import TYPE_CHECKING, Any, Literal
|
|
6
6
|
|
7
7
|
from pydantic import BaseModel
|
8
8
|
|
9
|
-
from lionagi.ln.fuzzy._fuzzy_validate import fuzzy_validate_mapping
|
10
|
-
from lionagi.utils import breakdown_pydantic_annotation
|
11
|
-
|
12
9
|
if TYPE_CHECKING:
|
13
10
|
from lionagi.session.branch import Branch
|
14
11
|
|
@@ -34,6 +31,11 @@ async def parse(
|
|
34
31
|
suppress_conversion_errors: bool = False,
|
35
32
|
response_format=None,
|
36
33
|
):
|
34
|
+
from lionagi.libs.schema.breakdown_pydantic_annotation import (
|
35
|
+
breakdown_pydantic_annotation,
|
36
|
+
)
|
37
|
+
from lionagi.ln.fuzzy._fuzzy_validate import fuzzy_validate_mapping
|
38
|
+
|
37
39
|
if operative is not None:
|
38
40
|
max_retries = operative.max_retries
|
39
41
|
response_format = operative.request_type or response_format
|
@@ -21,8 +21,7 @@ from pydantic import (
|
|
21
21
|
from lionagi import ln
|
22
22
|
from lionagi._class_registry import get_class
|
23
23
|
from lionagi._errors import IDError
|
24
|
-
from lionagi.
|
25
|
-
from lionagi.utils import import_module, time, to_dict
|
24
|
+
from lionagi.utils import import_module, to_dict
|
26
25
|
|
27
26
|
from .._concepts import Collective, Observable, Ordering
|
28
27
|
|
@@ -156,9 +155,7 @@ class Element(BaseModel, Observable):
|
|
156
155
|
frozen=True,
|
157
156
|
)
|
158
157
|
created_at: float = Field(
|
159
|
-
default_factory=lambda:
|
160
|
-
tz=Settings.Config.TIMEZONE, type_="timestamp"
|
161
|
-
),
|
158
|
+
default_factory=lambda: ln.now_utc().timestamp(),
|
162
159
|
title="Creation Timestamp",
|
163
160
|
description="Timestamp of element creation.",
|
164
161
|
frozen=True,
|
@@ -205,7 +202,7 @@ class Element(BaseModel, Observable):
|
|
205
202
|
ValueError: If `val` cannot be converted to a float timestamp.
|
206
203
|
"""
|
207
204
|
if val is None:
|
208
|
-
return
|
205
|
+
return ln.now_utc().timestamp()
|
209
206
|
if isinstance(val, float):
|
210
207
|
return val
|
211
208
|
if isinstance(val, dt.datetime):
|
@@ -138,7 +138,7 @@ class Event(Element):
|
|
138
138
|
"""
|
139
139
|
|
140
140
|
execution: Execution = Field(default_factory=Execution)
|
141
|
-
streaming: bool = False
|
141
|
+
streaming: bool = Field(False, exclude=True)
|
142
142
|
|
143
143
|
@field_serializer("execution")
|
144
144
|
def _serialize_execution(self, val: Execution) -> dict:
|
@@ -5,8 +5,8 @@
|
|
5
5
|
from enum import Enum
|
6
6
|
from typing import Any
|
7
7
|
|
8
|
+
from lionagi.ln import now_utc
|
8
9
|
from lionagi.protocols.generic.element import ID, IDType
|
9
|
-
from lionagi.utils import time
|
10
10
|
|
11
11
|
from .._concepts import Communicatable, Observable
|
12
12
|
|
@@ -93,7 +93,7 @@ class Package(Observable):
|
|
93
93
|
):
|
94
94
|
super().__init__()
|
95
95
|
self.id = IDType.create()
|
96
|
-
self.created_at =
|
96
|
+
self.created_at = now_utc().timestamp()
|
97
97
|
self.category = validate_category(category)
|
98
98
|
self.item = item
|
99
99
|
self.request_source = request_source
|
@@ -7,7 +7,7 @@ from typing import Any, Literal
|
|
7
7
|
from pydantic import BaseModel, JsonValue, field_serializer
|
8
8
|
from typing_extensions import override
|
9
9
|
|
10
|
-
from lionagi.utils import UNDEFINED,
|
10
|
+
from lionagi.utils import UNDEFINED, copy
|
11
11
|
|
12
12
|
from .base import MessageRole
|
13
13
|
from .message import RoledMessage, SenderRecipient
|
@@ -256,6 +256,10 @@ def prepare_instruction_content(
|
|
256
256
|
Raises:
|
257
257
|
ValueError: If request_fields and request_model are both given.
|
258
258
|
"""
|
259
|
+
from lionagi.libs.schema.breakdown_pydantic_annotation import (
|
260
|
+
breakdown_pydantic_annotation,
|
261
|
+
)
|
262
|
+
|
259
263
|
if request_fields and request_model:
|
260
264
|
raise ValueError(
|
261
265
|
"only one of request_fields or request_model can be provided"
|
@@ -476,6 +480,10 @@ class Instruction(RoledMessage):
|
|
476
480
|
|
477
481
|
@response_format.setter
|
478
482
|
def response_format(self, model: type[BaseModel]) -> None:
|
483
|
+
from lionagi.libs.schema.breakdown_pydantic_annotation import (
|
484
|
+
breakdown_pydantic_annotation,
|
485
|
+
)
|
486
|
+
|
479
487
|
if isinstance(model, BaseModel):
|
480
488
|
self.content["request_model"] = type(model)
|
481
489
|
else:
|
@@ -7,9 +7,10 @@ from typing import Any
|
|
7
7
|
from pydantic import BaseModel
|
8
8
|
from pydantic.fields import FieldInfo
|
9
9
|
|
10
|
+
from lionagi.ln import extract_json
|
10
11
|
from lionagi.ln.fuzzy._fuzzy_match import fuzzy_match_keys
|
11
12
|
from lionagi.models import FieldModel, ModelParams, OperableModel
|
12
|
-
from lionagi.utils import UNDEFINED
|
13
|
+
from lionagi.utils import UNDEFINED
|
13
14
|
|
14
15
|
|
15
16
|
class Operative:
|
@@ -145,7 +146,7 @@ class Operative:
|
|
145
146
|
Raises:
|
146
147
|
Exception: If the validation fails.
|
147
148
|
"""
|
148
|
-
d_ =
|
149
|
+
d_ = extract_json(text, fuzzy_parse=True)
|
149
150
|
if isinstance(d_, list | tuple) and len(d_) == 1:
|
150
151
|
d_ = d_[0]
|
151
152
|
try:
|
@@ -167,7 +168,7 @@ class Operative:
|
|
167
168
|
"""
|
168
169
|
d_ = text
|
169
170
|
try:
|
170
|
-
d_ =
|
171
|
+
d_ = extract_json(text, fuzzy_parse=True)
|
171
172
|
if isinstance(d_, list | tuple) and len(d_) == 1:
|
172
173
|
d_ = d_[0]
|
173
174
|
d_ = fuzzy_match_keys(
|
lionagi/service/__init__.py
CHANGED
@@ -1,12 +1,26 @@
|
|
1
1
|
# Eager imports for core functionality
|
2
|
-
from
|
3
|
-
|
4
|
-
from .hooks import
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
2
|
+
from typing import TYPE_CHECKING
|
3
|
+
|
4
|
+
from .hooks import (
|
5
|
+
AssosiatedEventInfo,
|
6
|
+
HookDict,
|
7
|
+
HookedEvent,
|
8
|
+
HookEvent,
|
9
|
+
HookEventTypes,
|
10
|
+
HookRegistry,
|
11
|
+
global_hook_logger,
|
12
|
+
)
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from .broadcaster import Broadcaster
|
16
|
+
from .connections.api_calling import APICalling
|
17
|
+
from .connections.endpoint import Endpoint, EndpointConfig
|
18
|
+
from .imodel import iModel
|
19
|
+
from .manager import iModelManager
|
20
|
+
from .rate_limited_processor import RateLimitedAPIExecutor
|
21
|
+
from .token_calculator import TokenCalculator
|
22
|
+
|
23
|
+
|
10
24
|
_lazy_imports = {}
|
11
25
|
|
12
26
|
|
@@ -15,12 +29,49 @@ def __getattr__(name: str):
|
|
15
29
|
if name in _lazy_imports:
|
16
30
|
return _lazy_imports[name]
|
17
31
|
|
32
|
+
if name == "RateLimitedAPIExecutor":
|
33
|
+
from .rate_limited_processor import RateLimitedAPIExecutor
|
34
|
+
|
35
|
+
_lazy_imports["RateLimitedAPIExecutor"] = RateLimitedAPIExecutor
|
36
|
+
return RateLimitedAPIExecutor
|
37
|
+
|
38
|
+
if name in ("Endpoint", "EndpointConfig"):
|
39
|
+
from .connections.endpoint import Endpoint, EndpointConfig
|
40
|
+
|
41
|
+
_lazy_imports["Endpoint"] = Endpoint
|
42
|
+
_lazy_imports["EndpointConfig"] = EndpointConfig
|
43
|
+
return Endpoint if name == "Endpoint" else EndpointConfig
|
44
|
+
|
45
|
+
if name == "iModelManager":
|
46
|
+
from .manager import iModelManager
|
47
|
+
|
48
|
+
_lazy_imports["iModelManager"] = iModelManager
|
49
|
+
return iModelManager
|
50
|
+
|
51
|
+
if name == "iModel":
|
52
|
+
from .imodel import iModel
|
53
|
+
|
54
|
+
_lazy_imports["iModel"] = iModel
|
55
|
+
return iModel
|
56
|
+
|
57
|
+
if name == "APICalling":
|
58
|
+
from .connections.api_calling import APICalling
|
59
|
+
|
60
|
+
_lazy_imports["APICalling"] = APICalling
|
61
|
+
return APICalling
|
62
|
+
|
18
63
|
if name == "TokenCalculator":
|
19
64
|
from .token_calculator import TokenCalculator
|
20
65
|
|
21
66
|
_lazy_imports["TokenCalculator"] = TokenCalculator
|
22
67
|
return TokenCalculator
|
23
68
|
|
69
|
+
if name == "Broadcaster":
|
70
|
+
from .broadcaster import Broadcaster
|
71
|
+
|
72
|
+
_lazy_imports["Broadcaster"] = Broadcaster
|
73
|
+
return Broadcaster
|
74
|
+
|
24
75
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
25
76
|
|
26
77
|
|
@@ -37,4 +88,12 @@ __all__ = (
|
|
37
88
|
"AssosiatedEventInfo",
|
38
89
|
"HookEvent",
|
39
90
|
"HookRegistry",
|
91
|
+
"Broadcaster",
|
92
|
+
"HookEventTypes",
|
93
|
+
"HookDict",
|
94
|
+
"AssosiatedEventInfo",
|
95
|
+
"HookEvent",
|
96
|
+
"HookRegistry",
|
97
|
+
"global_hook_logger",
|
98
|
+
"HookedEvent",
|
40
99
|
)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from collections.abc import Callable
|
5
|
+
from typing import Any, ClassVar
|
6
|
+
|
7
|
+
from lionagi.ln.concurrency.utils import is_coro_func
|
8
|
+
from lionagi.protocols.generic.event import Event
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
__all__ = ("Broadcaster",)
|
13
|
+
|
14
|
+
|
15
|
+
class Broadcaster:
|
16
|
+
"""Real-time event broadcasting system for hook events. Should subclass to implement specific event types."""
|
17
|
+
|
18
|
+
_instance: ClassVar[Broadcaster | None] = None
|
19
|
+
_subscribers: ClassVar[list[Callable[[Any], None]]] = []
|
20
|
+
_event_type: ClassVar[type[Event]]
|
21
|
+
|
22
|
+
def __new__(cls):
|
23
|
+
if cls._instance is None:
|
24
|
+
cls._instance = super().__new__(cls)
|
25
|
+
return cls._instance
|
26
|
+
|
27
|
+
@classmethod
|
28
|
+
def subscribe(cls, callback: Callable[[Any], None]) -> None:
|
29
|
+
"""Subscribe to hook events with sync callback."""
|
30
|
+
if callback not in cls._subscribers:
|
31
|
+
cls._subscribers.append(callback)
|
32
|
+
|
33
|
+
@classmethod
|
34
|
+
def unsubscribe(cls, callback: Callable[[Any], None]) -> None:
|
35
|
+
"""Unsubscribe from hook events."""
|
36
|
+
if callback in cls._subscribers:
|
37
|
+
cls._subscribers.remove(callback)
|
38
|
+
|
39
|
+
@classmethod
|
40
|
+
async def broadcast(cls, event) -> None:
|
41
|
+
"""Broadcast event to all subscribers."""
|
42
|
+
if not isinstance(event, cls._event_type):
|
43
|
+
raise ValueError(
|
44
|
+
f"Event must be of type {cls._event_type.__name__}"
|
45
|
+
)
|
46
|
+
|
47
|
+
for callback in cls._subscribers:
|
48
|
+
try:
|
49
|
+
if is_coro_func(callback):
|
50
|
+
await callback(event)
|
51
|
+
else:
|
52
|
+
callback(event)
|
53
|
+
except Exception as e:
|
54
|
+
logger.error(
|
55
|
+
f"Error in subscriber callback: {e}", exc_info=True
|
56
|
+
)
|
57
|
+
|
58
|
+
@classmethod
|
59
|
+
def get_subscriber_count(cls) -> int:
|
60
|
+
"""Get total number of subscribers."""
|
61
|
+
return len(cls._subscribers)
|
@@ -2,31 +2,18 @@
|
|
2
2
|
#
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
4
4
|
|
5
|
-
import asyncio
|
6
5
|
import logging
|
7
6
|
|
8
|
-
from
|
9
|
-
from pydantic import Field, PrivateAttr, model_validator
|
7
|
+
from pydantic import Field, model_validator
|
10
8
|
from typing_extensions import Self
|
11
9
|
|
12
|
-
from
|
13
|
-
from lionagi.protocols.types import Log
|
14
|
-
from lionagi.service.hooks import HookEvent, HookEventTypes, global_hook_logger
|
15
|
-
|
10
|
+
from ..hooks.hooked_event import HookedEvent
|
16
11
|
from .endpoint import Endpoint
|
17
12
|
|
18
|
-
|
19
|
-
# Lazy import for TokenCalculator
|
20
|
-
def _get_token_calculator():
|
21
|
-
from lionagi.service.token_calculator import TokenCalculator
|
22
|
-
|
23
|
-
return TokenCalculator
|
24
|
-
|
25
|
-
|
26
13
|
logger = logging.getLogger(__name__)
|
27
14
|
|
28
15
|
|
29
|
-
class APICalling(
|
16
|
+
class APICalling(HookedEvent):
|
30
17
|
"""Handles asynchronous API calls with automatic token usage tracking.
|
31
18
|
|
32
19
|
This class manages API calls through endpoints, handling both regular
|
@@ -61,9 +48,6 @@ class APICalling(Event):
|
|
61
48
|
exclude=True,
|
62
49
|
)
|
63
50
|
|
64
|
-
_pre_invoke_hook_event: HookEvent = PrivateAttr(None)
|
65
|
-
_post_invoke_hook_event: HookEvent = PrivateAttr(None)
|
66
|
-
|
67
51
|
@model_validator(mode="after")
|
68
52
|
def _validate_streaming(self) -> Self:
|
69
53
|
"""Validate streaming configuration and add token usage if requested."""
|
@@ -127,12 +111,14 @@ class APICalling(Event):
|
|
127
111
|
@property
|
128
112
|
def required_tokens(self) -> int | None:
|
129
113
|
"""Calculate the number of tokens required for this request."""
|
114
|
+
from lionagi.service.token_calculator import TokenCalculator
|
115
|
+
|
130
116
|
if not self.endpoint.config.requires_tokens:
|
131
117
|
return None
|
132
118
|
|
133
119
|
# Handle chat completions format
|
134
120
|
if "messages" in self.payload:
|
135
|
-
return
|
121
|
+
return TokenCalculator.calculate_message_tokens(
|
136
122
|
self.payload["messages"], **self.payload
|
137
123
|
)
|
138
124
|
# Handle responses API format
|
@@ -153,95 +139,29 @@ class APICalling(Event):
|
|
153
139
|
messages.append(item)
|
154
140
|
else:
|
155
141
|
return None
|
156
|
-
return
|
142
|
+
return TokenCalculator.calculate_message_tokens(
|
157
143
|
messages, **self.payload
|
158
144
|
)
|
159
145
|
# Handle embeddings endpoint
|
160
146
|
elif "embed" in self.endpoint.config.endpoint:
|
161
|
-
return
|
162
|
-
**self.payload
|
163
|
-
)
|
147
|
+
return TokenCalculator.calculate_embed_token(**self.payload)
|
164
148
|
|
165
149
|
return None
|
166
150
|
|
167
|
-
async def
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
try:
|
175
|
-
self.execution.status = EventStatus.PROCESSING
|
176
|
-
if h_ev := self._pre_invoke_hook_event:
|
177
|
-
await h_ev.invoke()
|
178
|
-
if h_ev._should_exit:
|
179
|
-
raise h_ev._exit_cause or RuntimeError(
|
180
|
-
"Pre-invocation hook requested exit without a cause"
|
181
|
-
)
|
182
|
-
await global_hook_logger.alog(Log.create(h_ev))
|
183
|
-
|
184
|
-
# Make the API call with skip_payload_creation=True since payload is already prepared
|
185
|
-
response = await self.endpoint.call(
|
186
|
-
request=self.payload,
|
187
|
-
cache_control=self.cache_control,
|
188
|
-
skip_payload_creation=True,
|
189
|
-
extra_headers=self.headers if self.headers else None,
|
190
|
-
)
|
191
|
-
|
192
|
-
if h_ev := self._post_invoke_hook_event:
|
193
|
-
await h_ev.invoke()
|
194
|
-
if h_ev._should_exit:
|
195
|
-
raise h_ev._exit_cause or RuntimeError(
|
196
|
-
"Post-invocation hook requested exit without a cause"
|
197
|
-
)
|
198
|
-
await global_hook_logger.alog(Log.create(h_ev))
|
199
|
-
|
200
|
-
self.execution.response = response
|
201
|
-
self.execution.status = EventStatus.COMPLETED
|
202
|
-
|
203
|
-
except get_cancelled_exc_class():
|
204
|
-
self.execution.error = "API call cancelled"
|
205
|
-
self.execution.status = EventStatus.CANCELLED
|
206
|
-
raise
|
207
|
-
|
208
|
-
except Exception as e:
|
209
|
-
self.execution.error = str(e)
|
210
|
-
self.execution.status = EventStatus.FAILED
|
211
|
-
logger.error(f"API call failed: {e}")
|
212
|
-
|
213
|
-
finally:
|
214
|
-
self.execution.duration = asyncio.get_event_loop().time() - start
|
215
|
-
|
216
|
-
async def stream(self):
|
217
|
-
"""Stream the API response through the endpoint.
|
218
|
-
|
219
|
-
Yields:
|
220
|
-
Streaming chunks from the API.
|
221
|
-
"""
|
222
|
-
start = asyncio.get_event_loop().time()
|
223
|
-
response = []
|
224
|
-
|
225
|
-
try:
|
226
|
-
self.execution.status = EventStatus.PROCESSING
|
227
|
-
|
228
|
-
async for chunk in self.endpoint.stream(
|
229
|
-
request=self.payload,
|
230
|
-
extra_headers=self.headers if self.headers else None,
|
231
|
-
):
|
232
|
-
response.append(chunk)
|
233
|
-
yield chunk
|
234
|
-
|
235
|
-
self.execution.response = response
|
236
|
-
self.execution.status = EventStatus.COMPLETED
|
237
|
-
|
238
|
-
except Exception as e:
|
239
|
-
self.execution.error = str(e)
|
240
|
-
self.execution.status = EventStatus.FAILED
|
241
|
-
logger.error(f"Streaming failed: {e}")
|
151
|
+
async def _invoke(self):
|
152
|
+
return await self.endpoint.call(
|
153
|
+
request=self.payload,
|
154
|
+
cache_control=self.cache_control,
|
155
|
+
skip_payload_creation=True,
|
156
|
+
extra_headers=self.headers if self.headers else None,
|
157
|
+
)
|
242
158
|
|
243
|
-
|
244
|
-
|
159
|
+
async def _stream(self):
|
160
|
+
async for i in self.endpoint.stream(
|
161
|
+
request=self.payload,
|
162
|
+
extra_headers=self.headers if self.headers else None,
|
163
|
+
):
|
164
|
+
yield i
|
245
165
|
|
246
166
|
@property
|
247
167
|
def request(self) -> dict:
|
@@ -249,42 +169,3 @@ class APICalling(Event):
|
|
249
169
|
return {
|
250
170
|
"required_tokens": self.required_tokens,
|
251
171
|
}
|
252
|
-
|
253
|
-
@property
|
254
|
-
def response(self):
|
255
|
-
"""Get the response from the execution."""
|
256
|
-
return self.execution.response if self.execution else None
|
257
|
-
|
258
|
-
def create_pre_invoke_hook(
|
259
|
-
self,
|
260
|
-
hook_registry,
|
261
|
-
exit_hook: bool = None,
|
262
|
-
hook_timeout: float = 30.0,
|
263
|
-
hook_params: dict = None,
|
264
|
-
):
|
265
|
-
h_ev = HookEvent(
|
266
|
-
hook_type=HookEventTypes.PreInvokation,
|
267
|
-
event_like=self,
|
268
|
-
registry=hook_registry,
|
269
|
-
exit=exit_hook,
|
270
|
-
timeout=hook_timeout,
|
271
|
-
params=hook_params or {},
|
272
|
-
)
|
273
|
-
self._pre_invoke_hook_event = h_ev
|
274
|
-
|
275
|
-
def create_post_invoke_hook(
|
276
|
-
self,
|
277
|
-
hook_registry,
|
278
|
-
exit_hook: bool = None,
|
279
|
-
hook_timeout: float = 30.0,
|
280
|
-
hook_params: dict = None,
|
281
|
-
):
|
282
|
-
h_ev = HookEvent(
|
283
|
-
hook_type=HookEventTypes.PostInvokation,
|
284
|
-
event_like=self,
|
285
|
-
registry=hook_registry,
|
286
|
-
exit=exit_hook,
|
287
|
-
timeout=hook_timeout,
|
288
|
-
params=hook_params or {},
|
289
|
-
)
|
290
|
-
self._post_invoke_hook_event = h_ev
|
@@ -1,19 +1,10 @@
|
|
1
1
|
# Copyright (c) 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
3
3
|
|
4
|
-
from lionagi.protocols.types import DataLogger
|
5
|
-
|
6
4
|
from ._types import AssosiatedEventInfo, HookDict, HookEventTypes
|
7
5
|
from .hook_event import HookEvent
|
8
6
|
from .hook_registry import HookRegistry
|
9
|
-
|
10
|
-
global_hook_logger = DataLogger(
|
11
|
-
persist_dir="./data/logs",
|
12
|
-
subfolder="hooks",
|
13
|
-
file_prefix="hook",
|
14
|
-
capacity=1000,
|
15
|
-
)
|
16
|
-
|
7
|
+
from .hooked_event import HookedEvent, global_hook_logger
|
17
8
|
|
18
9
|
__all__ = (
|
19
10
|
"HookEventTypes",
|
@@ -22,4 +13,5 @@ __all__ = (
|
|
22
13
|
"HookEvent",
|
23
14
|
"HookRegistry",
|
24
15
|
"global_hook_logger",
|
16
|
+
"HookedEvent",
|
25
17
|
)
|
lionagi/service/hooks/_types.py
CHANGED
@@ -23,8 +23,8 @@ __all__ = (
|
|
23
23
|
|
24
24
|
class HookEventTypes(str, Enum):
|
25
25
|
PreEventCreate = "pre_event_create"
|
26
|
-
|
27
|
-
|
26
|
+
PreInvocation = "pre_invocation"
|
27
|
+
PostInvocation = "post_invocation"
|
28
28
|
|
29
29
|
|
30
30
|
ALLOWED_HOOKS_TYPES = HookEventTypes.allowed()
|
@@ -32,11 +32,12 @@ ALLOWED_HOOKS_TYPES = HookEventTypes.allowed()
|
|
32
32
|
|
33
33
|
class HookDict(TypedDict):
|
34
34
|
pre_event_create: Callable | None
|
35
|
-
|
36
|
-
|
35
|
+
pre_invocation: Callable | None
|
36
|
+
post_invocation: Callable | None
|
37
37
|
|
38
38
|
|
39
39
|
StreamHandlers = dict[str, Callable[[SC], Awaitable[None]]]
|
40
|
+
"""Mapping of chunk type names to their respective asynchronous handler functions."""
|
40
41
|
|
41
42
|
|
42
43
|
class AssosiatedEventInfo(TypedDict, total=False):
|
@@ -102,7 +102,7 @@ class HookRegistry:
|
|
102
102
|
except Exception as e:
|
103
103
|
return (e, exit, EventStatus.CANCELLED)
|
104
104
|
|
105
|
-
async def
|
105
|
+
async def pre_invocation(
|
106
106
|
self, event: E, /, exit: bool = False, **kw
|
107
107
|
) -> tuple[Any, bool, EventStatus]:
|
108
108
|
"""Hook to be called when an event is dequeued and right before it is invoked.
|
@@ -110,12 +110,12 @@ class HookRegistry:
|
|
110
110
|
Typically used to check permissions.
|
111
111
|
|
112
112
|
The hook function takes the content of the event as a dictionary.
|
113
|
-
It can either raise an exception to abort the event
|
113
|
+
It can either raise an exception to abort the event invocation or pass to continue (status: cancelled).
|
114
114
|
It cannot modify the event itself, and won't be able to access the event instance.
|
115
115
|
"""
|
116
116
|
try:
|
117
117
|
res = await self._call(
|
118
|
-
HookEventTypes.
|
118
|
+
HookEventTypes.PreInvocation,
|
119
119
|
None,
|
120
120
|
None,
|
121
121
|
event,
|
@@ -127,16 +127,16 @@ class HookRegistry:
|
|
127
127
|
except Exception as e:
|
128
128
|
return (e, exit, EventStatus.CANCELLED)
|
129
129
|
|
130
|
-
async def
|
130
|
+
async def post_invocation(
|
131
131
|
self, event: E, /, exit: bool = False, **kw
|
132
132
|
) -> tuple[None | Exception, bool, EventStatus, EventStatus]:
|
133
133
|
"""Hook to be called right after event finished its execution.
|
134
|
-
It can either raise an exception to abort the event
|
134
|
+
It can either raise an exception to abort the event invocation or pass to continue (status: aborted).
|
135
135
|
It cannot modify the event itself, and won't be able to access the event instance.
|
136
136
|
"""
|
137
137
|
try:
|
138
138
|
res = await self._call(
|
139
|
-
HookEventTypes.
|
139
|
+
HookEventTypes.PostInvocation,
|
140
140
|
None,
|
141
141
|
None,
|
142
142
|
event,
|
@@ -156,7 +156,7 @@ class HookRegistry:
|
|
156
156
|
Typically used for logging or stream event abortion.
|
157
157
|
|
158
158
|
The handler function signature should be: `async def handler(chunk: Any) -> None`
|
159
|
-
It can either raise an exception to mark the event
|
159
|
+
It can either raise an exception to mark the event invocation as "failed" or pass to continue (status: aborted).
|
160
160
|
"""
|
161
161
|
try:
|
162
162
|
res = await self._call_stream_handler(
|
@@ -196,14 +196,14 @@ class HookRegistry:
|
|
196
196
|
match hook_type:
|
197
197
|
case HookEventTypes.PreEventCreate:
|
198
198
|
return await self.pre_event_create(event_like, **kw), meta
|
199
|
-
case HookEventTypes.
|
199
|
+
case HookEventTypes.PreInvocation:
|
200
200
|
meta["event_id"] = str(event_like.id)
|
201
201
|
meta["event_created_at"] = event_like.created_at
|
202
|
-
return await self.
|
203
|
-
case HookEventTypes.
|
202
|
+
return await self.pre_invocation(event_like, **kw), meta
|
203
|
+
case HookEventTypes.PostInvocation:
|
204
204
|
meta["event_id"] = str(event_like.id)
|
205
205
|
meta["event_created_at"] = event_like.created_at
|
206
|
-
return await self.
|
206
|
+
return await self.post_invocation(**kw), meta
|
207
207
|
return await self.handle_streaming_chunk(chunk_type, chunk, exit, **kw)
|
208
208
|
|
209
209
|
def _can_handle(
|