lionagi 0.14.10__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/libs/concurrency.py +1 -0
- lionagi/libs/token_transform/perplexity.py +2 -1
- lionagi/libs/token_transform/symbolic_compress_context.py +8 -7
- lionagi/ln/__init__.py +49 -0
- lionagi/ln/_async_call.py +293 -0
- lionagi/ln/_list_call.py +129 -0
- lionagi/ln/_models.py +126 -0
- lionagi/ln/_to_list.py +175 -0
- lionagi/ln/_types.py +146 -0
- lionagi/{libs → ln}/concurrency/__init__.py +4 -2
- lionagi/{libs → ln}/concurrency/throttle.py +2 -1
- lionagi/ln/concurrency/utils.py +14 -0
- lionagi/models/hashable_model.py +1 -2
- lionagi/operations/brainstorm/brainstorm.py +2 -1
- lionagi/operations/flow.py +3 -3
- lionagi/operations/node.py +3 -1
- lionagi/operations/plan/plan.py +3 -3
- lionagi/protocols/generic/pile.py +1 -1
- lionagi/protocols/graph/graph.py +13 -1
- lionagi/service/hooks/_types.py +2 -2
- lionagi/session/branch.py +4 -2
- lionagi/utils.py +90 -510
- lionagi/version.py +1 -1
- {lionagi-0.14.10.dist-info → lionagi-0.15.0.dist-info}/METADATA +4 -4
- {lionagi-0.14.10.dist-info → lionagi-0.15.0.dist-info}/RECORD +34 -28
- lionagi/libs/hash/__init__.py +0 -3
- lionagi/libs/hash/manager.py +0 -26
- /lionagi/{libs/hash/hash_dict.py → ln/_hash.py} +0 -0
- /lionagi/{libs → ln}/concurrency/cancel.py +0 -0
- /lionagi/{libs → ln}/concurrency/errors.py +0 -0
- /lionagi/{libs → ln}/concurrency/patterns.py +0 -0
- /lionagi/{libs → ln}/concurrency/primitives.py +0 -0
- /lionagi/{libs → ln}/concurrency/resource_tracker.py +0 -0
- /lionagi/{libs → ln}/concurrency/task.py +0 -0
- {lionagi-0.14.10.dist-info → lionagi-0.15.0.dist-info}/WHEEL +0 -0
- {lionagi-0.14.10.dist-info → lionagi-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1 @@
|
|
1
|
+
from ..ln.concurrency import * # backward compatibility
|
@@ -5,11 +5,12 @@ from timeit import default_timer as timer
|
|
5
5
|
import numpy as np
|
6
6
|
from pydantic import BaseModel
|
7
7
|
|
8
|
+
from lionagi.ln import alcall, lcall
|
8
9
|
from lionagi.protocols.generic.event import EventStatus
|
9
10
|
from lionagi.protocols.generic.log import Log
|
10
11
|
from lionagi.service.connections.api_calling import APICalling
|
11
12
|
from lionagi.service.imodel import iModel
|
12
|
-
from lionagi.utils import
|
13
|
+
from lionagi.utils import to_dict, to_list
|
13
14
|
|
14
15
|
|
15
16
|
@dataclass
|
@@ -2,9 +2,10 @@ from collections.abc import Callable
|
|
2
2
|
from pathlib import Path
|
3
3
|
from typing import Literal
|
4
4
|
|
5
|
+
from lionagi.ln import alcall
|
5
6
|
from lionagi.service.imodel import iModel
|
6
7
|
from lionagi.session.branch import Branch
|
7
|
-
from lionagi.utils import
|
8
|
+
from lionagi.utils import get_bins
|
8
9
|
|
9
10
|
from .base import TokenMapping, TokenMappingTemplate
|
10
11
|
from .synthlang_.base import SynthlangFramework, SynthlangTemplate
|
@@ -130,13 +131,13 @@ async def symbolic_compress_context(
|
|
130
131
|
_inner,
|
131
132
|
max_concurrent=max_concurrent,
|
132
133
|
retry_default=None,
|
133
|
-
|
134
|
-
|
134
|
+
retry_attempts=3,
|
135
|
+
retry_backoff=2,
|
135
136
|
retry_delay=1,
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
137
|
+
throttle_period=throttle_period,
|
138
|
+
output_flatten=True,
|
139
|
+
output_dropna=True,
|
140
|
+
output_unique=True,
|
140
141
|
)
|
141
142
|
text = "\n".join(results)
|
142
143
|
text = DEFAULT_INVOKATION_PROMPT + text
|
lionagi/ln/__init__.py
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
from ._async_call import AlcallParams, BcallParams, alcall, bcall
|
2
|
+
from ._hash import hash_dict
|
3
|
+
from ._list_call import LcallParams, lcall
|
4
|
+
from ._models import DataClass, Params
|
5
|
+
from ._to_list import ToListParams, to_list
|
6
|
+
from ._types import (
|
7
|
+
Enum,
|
8
|
+
KeysDict,
|
9
|
+
MaybeSentinel,
|
10
|
+
MaybeUndefined,
|
11
|
+
MaybeUnset,
|
12
|
+
SingletonType,
|
13
|
+
T,
|
14
|
+
Undefined,
|
15
|
+
UndefinedType,
|
16
|
+
Unset,
|
17
|
+
UnsetType,
|
18
|
+
is_sentinel,
|
19
|
+
not_sentinel,
|
20
|
+
)
|
21
|
+
from .concurrency import *
|
22
|
+
|
23
|
+
# Public API of lionagi.ln (the duplicate "Enum" entry was removed).
__all__ = (
    "Undefined",
    "Unset",
    "MaybeUndefined",
    "MaybeUnset",
    "MaybeSentinel",
    "SingletonType",
    "UndefinedType",
    "UnsetType",
    "KeysDict",
    "T",
    "Enum",
    "is_sentinel",
    "not_sentinel",
    "Params",
    "DataClass",
    "hash_dict",
    "to_list",
    "ToListParams",
    "lcall",
    "LcallParams",
    "alcall",
    "bcall",
    "AlcallParams",
    "BcallParams",
)
|
@@ -0,0 +1,293 @@
|
|
1
|
+
import asyncio
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import Any, AsyncGenerator, Callable, ClassVar
|
4
|
+
|
5
|
+
import anyio
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
from ._models import Params
|
9
|
+
from ._to_list import to_list
|
10
|
+
from ._types import T, Unset, not_sentinel
|
11
|
+
from .concurrency import Lock as ConcurrencyLock
|
12
|
+
from .concurrency import Semaphore, create_task_group, is_coro_func
|
13
|
+
|
14
|
+
__all__ = (
|
15
|
+
"alcall",
|
16
|
+
"bcall",
|
17
|
+
"AlcallParams",
|
18
|
+
"BcallParams",
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
async def alcall(
    input_: list[Any],
    func: Callable[..., T],
    /,
    *,
    input_flatten: bool = False,
    input_dropna: bool = False,
    input_unique: bool = False,
    input_flatten_tuple_set: bool = False,
    output_flatten: bool = False,
    output_dropna: bool = False,
    output_unique: bool = False,
    output_flatten_tuple_set: bool = False,
    delay_before_start: float = 0,
    retry_initial_deplay: float = 0,
    retry_backoff: float = 1,
    retry_default: Any = Unset,
    retry_timeout: float = 0,
    retry_attempts: int = 0,
    max_concurrent: int | None = None,
    throttle_period: float | None = None,
    **kwargs: Any,
) -> list[T]:
    """Asynchronously apply ``func`` to each element of ``input_``.

    Args:
        input_: Items to process; non-list inputs are coerced to a list.
        func: A callable (or an iterable containing exactly one callable)
            applied to each item. Sync callables run in a worker thread.
        input_flatten/input_dropna/input_unique/input_flatten_tuple_set:
            ``to_list`` sanitization applied to the input first.
        output_flatten/output_dropna/output_unique/output_flatten_tuple_set:
            ``to_list`` post-processing applied to the results.
        delay_before_start: Seconds to sleep before any work starts.
        retry_initial_deplay: Delay before the first retry. ([sic] the
            public keyword name is misspelled; kept for compatibility.)
        retry_backoff: Multiplier applied to the retry delay each attempt.
        retry_default: Value used for an item whose retries are exhausted;
            when left as the ``Unset`` sentinel the exception propagates.
        retry_timeout: Per-call timeout in seconds; 0 or None disables it.
        retry_attempts: Number of retries after the initial attempt.
        max_concurrent: Cap on concurrently running calls.
        throttle_period: Seconds to wait between starting successive tasks.
        **kwargs: Extra keyword arguments forwarded to every ``func`` call.

    Returns:
        Results ordered by the items' original positions.

    Raises:
        ValueError: If ``func`` is not exactly one callable.
        asyncio.TimeoutError: If a call exceeds ``retry_timeout`` and no
            ``retry_default`` absorbs it.
    """
    from functools import partial  # local: binds kwargs for thread calls

    # Validate func: accept a bare callable or an iterable holding exactly one.
    if not callable(func):
        try:
            func_list = list(func)  # Convert iterable to list
        except TypeError:
            raise ValueError(
                "func must be callable or an iterable containing one callable."
            )
        if len(func_list) != 1 or not callable(func_list[0]):
            raise ValueError("Only one callable function is allowed.")
        func = func_list[0]

    # Sanitize input if requested; otherwise just coerce to a list.
    if input_flatten or input_dropna:
        input_ = to_list(
            input_,
            flatten=input_flatten,
            dropna=input_dropna,
            unique=input_unique,
            flatten_tuple_set=input_flatten_tuple_set,
        )
    elif not isinstance(input_, list):
        if isinstance(input_, BaseModel):
            # A single pydantic model is one item, not an iterable of fields.
            input_ = [input_]
        else:
            try:
                input_ = list(iter(input_))
            except TypeError:
                # Not iterable; wrap the single value.
                input_ = [input_]

    # Optional initial delay before processing.
    if delay_before_start:
        await anyio.sleep(delay_before_start)

    semaphore = Semaphore(max_concurrent) if max_concurrent else None
    throttle_delay = throttle_period or 0
    coro_func = is_coro_func(func)

    async def invoke(item: Any) -> T:
        """Run one call; sync callables are dispatched to a worker thread."""
        if coro_func:
            return await func(item, **kwargs)
        # FIX: anyio.to_thread.run_sync does not forward **kwargs to the
        # callable (they are options of run_sync itself); bind with partial.
        return await anyio.to_thread.run_sync(partial(func, item, **kwargs))

    async def call_func(item: Any) -> T:
        """Invoke one item with the optional per-call timeout."""
        # FIX: 0 (the default) or None means "no timeout". The original
        # compared against None only, so the default of 0 made every call
        # time out immediately via move_on_after(0).
        if not retry_timeout:
            return await invoke(item)
        with anyio.move_on_after(retry_timeout) as cancel_scope:
            result = await invoke(item)
        if cancel_scope.cancelled_caught:
            raise asyncio.TimeoutError(
                f"Function call timed out after {retry_timeout}s"
            )
        return result

    async def execute_task(i: Any, index: int) -> Any:
        """Retry loop for one item; returns (original_index, result)."""
        attempts = 0
        current_delay = retry_initial_deplay
        while True:
            try:
                return index, await call_func(i)
            # Never swallow cancellation.
            except anyio.get_cancelled_exc_class():
                raise
            except Exception:
                attempts += 1
                if attempts <= retry_attempts:
                    if current_delay:
                        await anyio.sleep(current_delay)
                        current_delay *= retry_backoff
                    continue  # retry
                # Retries exhausted: fall back to the default, if any.
                if not_sentinel(retry_default):
                    return index, retry_default
                raise

    async def task_wrapper(item: Any, idx: int) -> Any:
        if semaphore:
            async with semaphore:
                return await execute_task(item, idx)
        return await execute_task(item, idx)

    results: list[tuple[int, Any]] = []
    results_lock = ConcurrencyLock()  # guards `results` across tasks

    async def run_and_store(item: Any, idx: int):
        result = await task_wrapper(item, idx)
        async with results_lock:
            results.append(result)

    # Structured concurrency: all tasks finish (or fail) before we proceed.
    async with create_task_group() as tg:
        for idx, item in enumerate(input_):
            # NOTE(review): this project's task-group wrapper exposes an
            # awaitable start_soon; plain anyio's is synchronous — confirm.
            await tg.start_soon(run_and_store, item, idx)
            # Throttle between starting tasks (not after the last one).
            if throttle_delay and idx < len(input_) - 1:
                await anyio.sleep(throttle_delay)

    # Restore input order, then strip the indices.
    results.sort(key=lambda pair: pair[0])
    output_list = [value for _, value in results]
    return to_list(
        output_list,
        flatten=output_flatten,
        dropna=output_dropna,
        unique=output_unique,
        flatten_tuple_set=output_flatten_tuple_set,
    )
|
191
|
+
|
192
|
+
|
193
|
+
async def bcall(
    input_: list[Any],
    func: Callable[..., T],
    /,
    batch_size: int,
    *,
    input_flatten: bool = False,
    input_dropna: bool = False,
    input_unique: bool = False,
    input_flatten_tuple_set: bool = False,
    output_flatten: bool = False,
    output_dropna: bool = False,
    output_unique: bool = False,
    output_flatten_tuple_set: bool = False,
    delay_before_start: float = 0,
    retry_initial_deplay: float = 0,
    retry_backoff: float = 1,
    retry_default: Any = Unset,
    retry_timeout: float = 0,
    retry_attempts: int = 0,
    max_concurrent: int | None = None,
    throttle_period: float | None = None,
    **kwargs: Any,
) -> AsyncGenerator[list[T | tuple[T, float]], None]:
    """Process ``input_`` in batches of ``batch_size``, yielding one
    :func:`alcall` result list per batch.

    All keyword options are forwarded unchanged to ``alcall``.
    """
    # Normalize to a flat, None-free list before slicing into batches.
    items = to_list(input_, flatten=True, dropna=True)

    # Collect every alcall tuning knob once so the loop body stays small.
    options: dict[str, Any] = dict(
        input_flatten=input_flatten,
        input_dropna=input_dropna,
        input_unique=input_unique,
        input_flatten_tuple_set=input_flatten_tuple_set,
        output_flatten=output_flatten,
        output_dropna=output_dropna,
        output_unique=output_unique,
        output_flatten_tuple_set=output_flatten_tuple_set,
        delay_before_start=delay_before_start,
        retry_initial_deplay=retry_initial_deplay,
        retry_backoff=retry_backoff,
        retry_default=retry_default,
        retry_timeout=retry_timeout,
        retry_attempts=retry_attempts,
        max_concurrent=max_concurrent,
        throttle_period=throttle_period,
        **kwargs,
    )

    start = 0
    total = len(items)
    while start < total:
        batch = items[start : start + batch_size]  # noqa: E203
        start += batch_size
        yield await alcall(batch, func, **options)
|
242
|
+
|
243
|
+
|
244
|
+
@dataclass(slots=True, init=False, frozen=True)
class AlcallParams(Params):
    """Frozen parameter bundle for :func:`alcall`.

    ``None`` is treated as a sentinel here, so only explicitly provided
    options are bound by :meth:`as_partial`.
    """

    # ClassVar configuration
    _none_as_sentinel: ClassVar[bool] = True
    _func: ClassVar[Any] = alcall

    # --- input sanitization ---
    input_flatten: bool
    input_dropna: bool
    input_unique: bool
    input_flatten_tuple_set: bool

    # --- output post-processing ---
    output_flatten: bool
    output_dropna: bool
    output_unique: bool
    output_flatten_tuple_set: bool

    # --- retry / timeout behavior ---
    delay_before_start: float
    retry_initial_deplay: float  # [sic] name matches alcall's public keyword
    retry_backoff: float
    retry_default: Any
    retry_timeout: float
    retry_attempts: int

    # --- concurrency / throttling ---
    max_concurrent: int
    throttle_period: float

    # Extra keyword arguments forwarded to every func call.
    kw: dict[str, Any] = Unset

    async def __call__(
        self, input_: list[Any], func: Callable[..., T], **kw
    ) -> list[T]:
        """Run :func:`alcall` over ``input_`` with the stored options."""
        bound = self.as_partial()
        return await bound(input_, func, **kw)
|
281
|
+
|
282
|
+
|
283
|
+
@dataclass(slots=True, init=False, frozen=True)
class BcallParams(AlcallParams):
    """Frozen parameter bundle for :func:`bcall`; adds a batch size."""

    _func: ClassVar[Any] = bcall

    batch_size: int  # number of items handed to alcall per batch

    async def __call__(
        self, input_: list[Any], func: Callable[..., T], **kw
    ) -> list[T]:
        """Run :func:`bcall` over ``input_`` and collect all batch results.

        Two fixes versus the original:
        - ``as_partial()`` already binds ``batch_size`` (it is a public
          field included in ``to_dict()``), so it must not also be passed
          positionally — that raised "got multiple values for argument
          'batch_size'".
        - ``bcall`` is an async *generator*; awaiting it raised TypeError.
          Drain it instead, returning the list of per-batch results to
          match the declared ``list[T]`` return type.
        """
        bound = self.as_partial()
        out: list[T] = []
        async for batch_result in bound(input_, func, **kw):
            out.append(batch_result)
        return out
|
lionagi/ln/_list_call.py
ADDED
@@ -0,0 +1,129 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Any, Callable, ClassVar, Iterable, TypeVar
|
3
|
+
|
4
|
+
from ._models import Params
|
5
|
+
from ._to_list import to_list
|
6
|
+
|
7
|
+
R = TypeVar("R")
|
8
|
+
T = TypeVar("T")
|
9
|
+
|
10
|
+
__all__ = ("lcall", "LcallParams")
|
11
|
+
|
12
|
+
|
13
|
+
def lcall(
|
14
|
+
input_: Iterable[T] | T,
|
15
|
+
func: Callable[[T], R] | Iterable[Callable[[T], R]],
|
16
|
+
/,
|
17
|
+
*args: Any,
|
18
|
+
input_flatten: bool = False,
|
19
|
+
input_dropna: bool = False,
|
20
|
+
input_unique: bool = False,
|
21
|
+
input_use_values: bool = False,
|
22
|
+
input_flatten_tuple_set: bool = False,
|
23
|
+
output_flatten: bool = False,
|
24
|
+
output_dropna: bool = False,
|
25
|
+
output_unique: bool = False,
|
26
|
+
output_flatten_tuple_set: bool = False,
|
27
|
+
**kwargs: Any,
|
28
|
+
) -> list[R]:
|
29
|
+
"""Apply function to each element in input list with optional processing.
|
30
|
+
|
31
|
+
Maps a function over input elements and processes results. Can sanitize input
|
32
|
+
and output using various filtering options.
|
33
|
+
|
34
|
+
Raises:
|
35
|
+
ValueError: If func is not callable or unique_output used incorrectly.
|
36
|
+
TypeError: If func or input processing fails.
|
37
|
+
"""
|
38
|
+
# Validate and extract callable function
|
39
|
+
if not callable(func):
|
40
|
+
try:
|
41
|
+
func_list = list(func)
|
42
|
+
if len(func_list) != 1 or not callable(func_list[0]):
|
43
|
+
raise ValueError(
|
44
|
+
"func must contain exactly one callable function."
|
45
|
+
)
|
46
|
+
func = func_list[0]
|
47
|
+
except TypeError as e:
|
48
|
+
raise ValueError(
|
49
|
+
"func must be callable or iterable with one callable."
|
50
|
+
) from e
|
51
|
+
|
52
|
+
# Validate output processing options
|
53
|
+
if output_unique and not (output_flatten or output_dropna):
|
54
|
+
raise ValueError(
|
55
|
+
"unique_output requires flatten or dropna for post-processing."
|
56
|
+
)
|
57
|
+
|
58
|
+
# Process input based on sanitization flag
|
59
|
+
if input_flatten or input_dropna:
|
60
|
+
input_ = to_list(
|
61
|
+
input_,
|
62
|
+
flatten=input_flatten,
|
63
|
+
dropna=input_dropna,
|
64
|
+
unique=input_unique,
|
65
|
+
flatten_tuple_set=input_flatten_tuple_set,
|
66
|
+
use_values=input_use_values,
|
67
|
+
)
|
68
|
+
else:
|
69
|
+
if not isinstance(input_, list):
|
70
|
+
try:
|
71
|
+
input_ = list(input_)
|
72
|
+
except TypeError:
|
73
|
+
input_ = [input_]
|
74
|
+
|
75
|
+
# Process elements and collect results
|
76
|
+
out = []
|
77
|
+
append = out.append
|
78
|
+
|
79
|
+
for item in input_:
|
80
|
+
try:
|
81
|
+
result = func(item, *args, **kwargs)
|
82
|
+
append(result)
|
83
|
+
except InterruptedError:
|
84
|
+
return out
|
85
|
+
except Exception:
|
86
|
+
raise
|
87
|
+
|
88
|
+
# Apply output processing if requested
|
89
|
+
if output_flatten or output_dropna:
|
90
|
+
out = to_list(
|
91
|
+
out,
|
92
|
+
flatten=output_flatten,
|
93
|
+
dropna=output_dropna,
|
94
|
+
unique=output_unique,
|
95
|
+
flatten_tuple_set=output_flatten_tuple_set,
|
96
|
+
)
|
97
|
+
|
98
|
+
return out
|
99
|
+
|
100
|
+
|
101
|
+
@dataclass(slots=True, frozen=True, init=False)
class LcallParams(Params):
    """Frozen parameter bundle for :func:`lcall`."""

    _func: ClassVar[Any] = lcall

    # --- input sanitization options ---
    input_flatten: bool  # recursively flatten input to a flat list
    input_dropna: bool  # drop None/undefined values from input
    input_unique: bool
    input_use_values: bool
    input_flatten_tuple_set: bool

    # --- output post-processing options ---
    output_flatten: bool  # recursively flatten the results
    output_dropna: bool  # drop None/undefined values from the results
    output_unique: bool  # remove duplicates (needs flatten or dropna)
    output_use_values: bool
    # NOTE(review): lcall() accepts no ``output_use_values`` keyword; if
    # this field is ever set, as_partial() forwards it and lcall raises
    # TypeError — confirm whether the field or the parameter is missing.
    output_flatten_tuple_set: bool  # include tuples/sets when flattening

    def __call__(self, input_: Any, *args, **kw) -> list:
        """Apply :func:`lcall` with the stored options to ``input_``."""
        bound = self.as_partial()
        return bound(input_, *args, **kw)
|
lionagi/ln/_models.py
ADDED
@@ -0,0 +1,126 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from functools import partial
|
3
|
+
from typing import Any, ClassVar
|
4
|
+
|
5
|
+
from typing_extensions import override
|
6
|
+
|
7
|
+
from ._types import Undefined, Unset, is_sentinel
|
8
|
+
|
9
|
+
__all__ = ("Params", "DataClass")
|
10
|
+
|
11
|
+
|
12
|
+
class _SentinelAware:
    """Mixin for dataclasses whose fields may hold sentinel values.

    (The original docstring called this a metaclass; it is an ordinary
    base class mixed into ``@dataclass`` subclasses, which is why
    :meth:`allowed` reads ``cls.__dataclass_fields__``.)
    """

    # If True, None is treated as a sentinel value.
    _none_as_sentinel: ClassVar[bool] = False

    # No sentinels allowed if strict is True.
    _strict: ClassVar[bool] = False

    # If True, unset fields are prefilled with Unset.
    _prefill_unset: ClassVar[bool] = True

    # Per-class cache of allowed (public) field names.
    # FIX: the original assigned ``dataclasses.field(default=set(), ...)``
    # here; outside a dataclass field definition that leaves a truthy
    # ``Field`` object as the class attribute, so ``allowed()`` returned
    # the Field instead of ever computing the key set.
    _allowed_keys: ClassVar[set[str]] = set()

    @classmethod
    def allowed(cls) -> set[str]:
        """Return the public (non-underscore) dataclass field names of *cls*."""
        # Cache per class: reading cls.__dict__ (not plain attribute
        # lookup) avoids inheriting a parent's cached, different key set.
        cached = cls.__dict__.get("_allowed_keys")
        if cached:
            return cached
        keys = {
            name
            for name in cls.__dataclass_fields__
            if not name.startswith("_")
        }
        cls._allowed_keys = keys
        return keys

    @classmethod
    def _is_sentinel(cls, value: Any) -> bool:
        """True if *value* is Undefined/Unset (or None when configured)."""
        if value is None and cls._none_as_sentinel:
            return True
        return is_sentinel(value)

    def __post_init__(self):
        """Run validation from the dataclass-generated __init__ of subclasses."""
        self._validate()

    def _validate(self) -> None:
        """Hook for subclasses; the base class performs no validation."""

    def to_dict(self) -> dict[str, Any]:
        """Return the non-sentinel public fields as a plain dict.

        (Return annotation fixed from the original's ``dict[str, str]`` —
        values are arbitrary field values, not strings.)
        """
        data: dict[str, Any] = {}
        for key in self.allowed():
            value = getattr(self, key)
            if not self._is_sentinel(value):
                data[key] = value
        return data
|
59
|
+
|
60
|
+
|
61
|
+
@dataclass(slots=True, frozen=True, init=False)
class Params(_SentinelAware):
    """Frozen base class for reusable function-parameter bundles."""

    # The target callable bound by as_partial().
    _func: ClassVar[Any] = Unset
    # [sic] misspelled name kept for backward compatibility: an optional
    # class-level pre-bound callable honored by as_partial().
    _particial_func: ClassVar[Any] = Unset

    @override
    def _validate(self) -> None:
        """Enforce strictness and prefill missing fields with ``Unset``."""
        for key in self.allowed():
            if self._strict and self._is_sentinel(getattr(self, key, Unset)):
                raise ValueError(f"Missing required parameter: {key}")
            if (
                self._prefill_unset
                and getattr(self, key, Undefined) is Undefined
            ):
                # Frozen dataclass: bypass the generated __setattr__.
                object.__setattr__(self, key, Unset)

    def as_partial(self) -> Any:
        """Return ``_func`` with all non-sentinel fields bound as kwargs.

        Raises:
            ValueError: If no ``_func`` is defined.
            TypeError: If ``_func`` is not callable.
        """
        # Honor a class-level pre-bound callable, if a subclass set one.
        if self._particial_func is not Unset:
            return self._particial_func

        # Validate that there is a function to apply.
        if self._func is Unset:
            raise ValueError("No function defined for partial application.")
        if not callable(self._func):
            raise TypeError(
                f"Expected a callable, got {type(self._func).__name__}."
            )

        data = self.to_dict()
        if not data:
            return self._func

        # Merge pass-through keyword buckets ('kwargs' and 'kw') into the
        # bound keyword arguments.
        extra: dict[str, Any] = {}
        extra.update(data.pop("kwargs", {}))
        extra.update(data.pop("kw", {}))
        data.update(extra)
        # FIX: no per-instance caching. The original wrote
        # ``self._particial_func = ...``, which raises FrozenInstanceError
        # on a frozen dataclass (and slots leave no place to store it), so
        # the first call crashed instead of caching.
        return partial(self._func, **data)
|
108
|
+
|
109
|
+
|
110
|
+
@dataclass(slots=True)
class DataClass(_SentinelAware):
    """A base class for data classes with strict parameter handling."""

    @override
    def _validate(self) -> None:
        """Enforce strictness and prefill missing fields with ``Unset``."""
        for name in self.allowed():
            # Strict mode rejects any sentinel-valued (or absent) field.
            if self._strict and self._is_sentinel(getattr(self, name, Unset)):
                raise ValueError(f"Missing required parameter: {name}")
            # Replace truly-undefined fields with the Unset sentinel.
            if (
                self._prefill_unset
                and getattr(self, name, Undefined) is Undefined
            ):
                setattr(self, name, Unset)
|