lionagi 0.15.9__py3-none-any.whl → 0.15.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +4 -6
- lionagi/adapters/async_postgres_adapter.py +56 -320
- lionagi/libs/file/_utils.py +10 -0
- lionagi/libs/file/process.py +16 -13
- lionagi/libs/unstructured/pdf_to_image.py +2 -2
- lionagi/libs/validate/string_similarity.py +4 -4
- lionagi/ln/__init__.py +28 -0
- lionagi/ln/_async_call.py +1 -0
- lionagi/ln/_extract_json.py +60 -0
- lionagi/ln/_fuzzy_json.py +116 -0
- lionagi/models/field_model.py +8 -6
- lionagi/operations/__init__.py +3 -0
- lionagi/operations/builder.py +10 -0
- lionagi/protocols/generic/element.py +120 -17
- lionagi/protocols/generic/pile.py +56 -1
- lionagi/protocols/generic/progression.py +11 -11
- lionagi/protocols/graph/_utils.py +22 -0
- lionagi/protocols/graph/graph.py +17 -21
- lionagi/protocols/graph/node.py +23 -3
- lionagi/protocols/messages/manager.py +41 -45
- lionagi/protocols/operatives/step.py +2 -19
- lionagi/protocols/types.py +1 -2
- lionagi/tools/file/reader.py +5 -6
- lionagi/utils.py +8 -385
- lionagi/version.py +1 -1
- {lionagi-0.15.9.dist-info → lionagi-0.15.13.dist-info}/METADATA +21 -16
- {lionagi-0.15.9.dist-info → lionagi-0.15.13.dist-info}/RECORD +29 -30
- lionagi/libs/package/__init__.py +0 -3
- lionagi/libs/package/imports.py +0 -21
- lionagi/libs/package/management.py +0 -62
- lionagi/libs/package/params.py +0 -30
- lionagi/libs/package/system.py +0 -22
- {lionagi-0.15.9.dist-info → lionagi-0.15.13.dist-info}/WHEEL +0 -0
- {lionagi-0.15.9.dist-info → lionagi-0.15.13.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,116 @@
|
|
1
|
+
import contextlib
|
2
|
+
import re
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import orjson
|
6
|
+
|
7
|
+
|
8
|
+
def fuzzy_json(str_to_parse: str, /) -> dict[str, Any] | list[dict[str, Any]]:
    """
    Attempt to parse a JSON string, trying a few minimal "fuzzy" fixes if needed.

    Steps:
    1. Parse directly with ``orjson.loads``.
    2. Close unmatched brackets on the untouched input (fix_json_string) and
       try again. This runs before any quote rewriting so that truncated but
       otherwise valid JSON keeps apostrophes inside string values intact.
    3. Replace single quotes with double quotes, normalize spacing, quote bare
       keys (_clean_json_string) and try again.
    4. Close unmatched brackets on the cleaned text and try again.
    5. If all fail, raise ValueError.

    Args:
        str_to_parse: The JSON string to parse

    Returns:
        Parsed JSON (typically a dict or list of dicts; whatever
        ``orjson.loads`` produces for the first attempt that succeeds)

    Raises:
        ValueError: If the string cannot be parsed as valid JSON, or is blank
        TypeError: If the input is not a string
    """
    _check_valid_str(str_to_parse)

    # orjson.JSONDecodeError subclasses ValueError; fix_json_string also
    # signals unrepairable bracket structure with ValueError.

    # 1. Direct attempt
    with contextlib.suppress(ValueError):
        return orjson.loads(str_to_parse)

    # 2. Bracket repair on the original text, e.g. '{"msg": "it's ok"' ->
    #    '{"msg": "it's ok"}'. Quote replacement below would otherwise turn the
    #    apostrophe into a double quote and make the string unrecoverable.
    with contextlib.suppress(ValueError):
        return orjson.loads(fix_json_string(str_to_parse))

    # 3. Try cleaning: replace single quotes with double and normalize
    cleaned = _clean_json_string(str_to_parse.replace("'", '"'))
    with contextlib.suppress(ValueError):
        return orjson.loads(cleaned)

    # 4. Try fixing brackets on the cleaned text
    with contextlib.suppress(ValueError):
        return orjson.loads(fix_json_string(cleaned))

    # If all attempts fail
    raise ValueError("Invalid JSON string")
|
46
|
+
|
47
|
+
|
48
|
+
def _check_valid_str(str_to_parse: str, /):
|
49
|
+
if not isinstance(str_to_parse, str):
|
50
|
+
raise TypeError("Input must be a string")
|
51
|
+
if not str_to_parse.strip():
|
52
|
+
raise ValueError("Input string is empty")
|
53
|
+
|
54
|
+
|
55
|
+
def _clean_json_string(s: str) -> str:
|
56
|
+
"""Basic normalization: replace unescaped single quotes, trim spaces, ensure keys are quoted."""
|
57
|
+
# Replace unescaped single quotes with double quotes
|
58
|
+
# '(?<!\\)'" means a single quote not preceded by a backslash
|
59
|
+
s = re.sub(r"(?<!\\)'", '"', s)
|
60
|
+
# Collapse multiple whitespaces
|
61
|
+
s = re.sub(r"\s+", " ", s)
|
62
|
+
# Ensure keys are quoted
|
63
|
+
# This attempts to find patterns like { key: value } and turn them into {"key": value}
|
64
|
+
s = re.sub(r'([{,])\s*([^"\s]+)\s*:', r'\1"\2":', s)
|
65
|
+
return s.strip()
|
66
|
+
|
67
|
+
|
68
|
+
def fix_json_string(str_to_parse: str, /) -> str:
    """Append any closing ``}``/``]`` needed to balance the brackets in the text.

    Brackets inside double-quoted strings are ignored, and a backslash always
    skips the character after it (both inside and outside strings).

    Raises:
        ValueError: If the input is empty, has a closing bracket with nothing
            open, or closes a bracket of the wrong kind.
    """
    if not str_to_parse:
        raise ValueError("Input string is empty")

    pairs = {"{": "}", "[": "]"}
    closers = set(pairs.values())
    pending: list[str] = []  # closers still owed, innermost last
    in_string = False
    skip_next = False

    for ch in str_to_parse:
        if skip_next:
            skip_next = False
            continue
        if ch == "\\":
            skip_next = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch in pairs:
            pending.append(pairs[ch])
        elif ch in closers:
            if not pending:
                # Better to raise error than guess
                raise ValueError("Extra closing bracket found.")
            if pending[-1] != ch:
                raise ValueError("Mismatched brackets.")
            pending.pop()

    return str_to_parse + "".join(reversed(pending))
|
lionagi/models/field_model.py
CHANGED
@@ -16,7 +16,6 @@ from typing import Annotated, Any
|
|
16
16
|
from typing_extensions import Self, override
|
17
17
|
|
18
18
|
from .._errors import ValidationError
|
19
|
-
from ..utils import UNDEFINED
|
20
19
|
|
21
20
|
# Cache of valid Pydantic Field parameters
|
22
21
|
_PYDANTIC_FIELD_PARAMS: set[str] | None = None
|
@@ -660,13 +659,16 @@ def to_dict(self) -> dict[str, Any]:
|
|
660
659
|
|
661
660
|
# Convert metadata to dictionary
|
662
661
|
for meta in self.metadata:
|
663
|
-
if meta.key not in (
|
662
|
+
if meta.key not in (
|
663
|
+
"nullable",
|
664
|
+
"listable",
|
665
|
+
"validator",
|
666
|
+
"name",
|
667
|
+
"validator_kwargs",
|
668
|
+
"annotation",
|
669
|
+
):
|
664
670
|
result[meta.key] = meta.value
|
665
671
|
|
666
|
-
# Add annotation if available
|
667
|
-
if hasattr(self, "annotation"):
|
668
|
-
result["annotation"] = self.base_type
|
669
|
-
|
670
672
|
return result
|
671
673
|
|
672
674
|
|
lionagi/operations/__init__.py
CHANGED
@@ -8,6 +8,8 @@ from .flow import flow
|
|
8
8
|
from .node import BranchOperations, Operation
|
9
9
|
from .plan.plan import PlanOperation, plan
|
10
10
|
|
11
|
+
Builder = OperationGraphBuilder
|
12
|
+
|
11
13
|
__all__ = (
|
12
14
|
"ExpansionStrategy",
|
13
15
|
"OperationGraphBuilder",
|
@@ -19,4 +21,5 @@ __all__ = (
|
|
19
21
|
"PlanOperation",
|
20
22
|
"brainstorm",
|
21
23
|
"BrainstormOperation",
|
24
|
+
"Builder",
|
22
25
|
)
|
lionagi/operations/builder.py
CHANGED
@@ -453,6 +453,16 @@ def visualize_graph(
|
|
453
453
|
figsize=(14, 10),
|
454
454
|
):
|
455
455
|
"""Visualization with improved layout for complex graphs."""
|
456
|
+
from lionagi.protocols.graph.graph import (
|
457
|
+
_MATPLIB_AVAILABLE,
|
458
|
+
_NETWORKX_AVAILABLE,
|
459
|
+
)
|
460
|
+
|
461
|
+
if _MATPLIB_AVAILABLE is not True:
|
462
|
+
raise _MATPLIB_AVAILABLE
|
463
|
+
if _NETWORKX_AVAILABLE is not True:
|
464
|
+
raise _NETWORKX_AVAILABLE
|
465
|
+
|
456
466
|
import matplotlib.pyplot as plt
|
457
467
|
import networkx as nx
|
458
468
|
import numpy as np
|
@@ -6,7 +6,7 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
import datetime as dt
|
8
8
|
from collections.abc import Mapping, Sequence
|
9
|
-
from typing import Any, Generic, TypeAlias, TypeVar
|
9
|
+
from typing import Any, Generic, Literal, TypeAlias, TypeVar
|
10
10
|
from uuid import UUID, uuid4
|
11
11
|
|
12
12
|
import orjson
|
@@ -22,7 +22,7 @@ from lionagi import ln
|
|
22
22
|
from lionagi._class_registry import get_class
|
23
23
|
from lionagi._errors import IDError
|
24
24
|
from lionagi.settings import Settings
|
25
|
-
from lionagi.utils import time, to_dict
|
25
|
+
from lionagi.utils import import_module, time, to_dict
|
26
26
|
|
27
27
|
from .._concepts import Collective, Observable, Ordering
|
28
28
|
|
@@ -190,11 +190,13 @@ class Element(BaseModel, Observable):
|
|
190
190
|
return val
|
191
191
|
|
192
192
|
@field_validator("created_at", mode="before")
|
193
|
-
def _coerce_created_at(
|
193
|
+
def _coerce_created_at(
|
194
|
+
cls, val: float | dt.datetime | str | None
|
195
|
+
) -> float:
|
194
196
|
"""Coerces `created_at` to a float-based timestamp.
|
195
197
|
|
196
198
|
Args:
|
197
|
-
val (float | datetime | None): The initial creation time value.
|
199
|
+
val (float | datetime | str | None): The initial creation time value.
|
198
200
|
|
199
201
|
Returns:
|
200
202
|
float: A float representing Unix epoch time in seconds.
|
@@ -208,6 +210,27 @@ class Element(BaseModel, Observable):
|
|
208
210
|
return val
|
209
211
|
if isinstance(val, dt.datetime):
|
210
212
|
return val.timestamp()
|
213
|
+
if isinstance(val, str):
|
214
|
+
# Parse datetime string from database
|
215
|
+
try:
|
216
|
+
# Handle datetime strings like "2025-08-30 10:54:59.310329"
|
217
|
+
# Convert space to T for ISO format, but handle timezone properly
|
218
|
+
iso_string = val.replace(" ", "T")
|
219
|
+
parsed_dt = dt.datetime.fromisoformat(iso_string)
|
220
|
+
|
221
|
+
# If parsed as naive datetime (no timezone), treat as UTC to avoid local timezone issues
|
222
|
+
if parsed_dt.tzinfo is None:
|
223
|
+
parsed_dt = parsed_dt.replace(tzinfo=dt.timezone.utc)
|
224
|
+
|
225
|
+
return parsed_dt.timestamp()
|
226
|
+
except ValueError:
|
227
|
+
# Try parsing as float string as fallback
|
228
|
+
try:
|
229
|
+
return float(val)
|
230
|
+
except ValueError:
|
231
|
+
raise ValueError(
|
232
|
+
f"Invalid datetime string: {val}"
|
233
|
+
) from None
|
211
234
|
try:
|
212
235
|
return float(val) # type: ignore
|
213
236
|
except Exception:
|
@@ -245,7 +268,7 @@ class Element(BaseModel, Observable):
|
|
245
268
|
Returns:
|
246
269
|
datetime: The creation time in UTC.
|
247
270
|
"""
|
248
|
-
return dt.datetime.fromtimestamp(self.created_at)
|
271
|
+
return dt.datetime.fromtimestamp(self.created_at, tz=dt.timezone.utc)
|
249
272
|
|
250
273
|
def __eq__(self, other: Any) -> bool:
|
251
274
|
"""Compares two Element instances by their ID."""
|
@@ -274,14 +297,31 @@ class Element(BaseModel, Observable):
|
|
274
297
|
return str(cls).split("'")[1]
|
275
298
|
return cls.__name__
|
276
299
|
|
277
|
-
def
|
278
|
-
"""Converts this Element to a dictionary."""
|
300
|
+
def _to_dict(self) -> dict:
    """Dump the model to a dict, tagging metadata and dropping sentinel values.

    The fully qualified class name is written to ``metadata["lion_class"]`` so
    ``from_dict`` can rebuild the right subclass. Top-level entries for which
    ``ln.not_sentinel`` is false are omitted.
    """
    payload = self.model_dump()
    payload["metadata"]["lion_class"] = self.class_name(full=True)
    return {
        key: value for key, value in payload.items() if ln.not_sentinel(value)
    }
|
282
304
|
|
305
|
+
def to_dict(
    self, mode: Literal["python", "json", "db"] = "python"
) -> dict:
    """Converts this Element to a dictionary.

    Args:
        mode: Output flavour.
            ``"python"`` returns ``self._to_dict()``.
            ``"json"`` round-trips through ``to_json`` so every value is
            JSON-native.
            ``"db"`` is the JSON form with ``metadata`` renamed to
            ``node_metadata``, and ``created_at`` rendered from
            ``created_datetime`` as an ISO-8601 string with a space separator.

    Returns:
        dict: The serialized element.

    Raises:
        ValueError: If ``mode`` is not one of the supported values. Previously
            an unknown mode fell through and silently returned ``None``.
    """
    if mode == "python":
        return self._to_dict()
    if mode in ("json", "db"):
        dict_ = orjson.loads(self.to_json(decode=False))
        if mode == "db":
            dict_["node_metadata"] = dict_.pop("metadata", {})
            dict_["created_at"] = self.created_datetime.isoformat(sep=" ")
        return dict_
    raise ValueError(
        f"Invalid mode {mode!r}; expected 'python', 'json', or 'db'."
    )
|
318
|
+
|
319
|
+
def as_jsonable(self) -> dict:
    """Return a JSON-compatible dict view of this element.

    Shorthand for ``self.to_dict`` in ``"json"`` mode.
    """
    return self.to_dict("json")
|
322
|
+
|
283
323
|
@classmethod
|
284
|
-
def from_dict(cls, data: dict,
|
324
|
+
def from_dict(cls, data: dict, /, mode: str = "python") -> Element:
|
285
325
|
"""Deserializes a dictionary into an Element or subclass of Element.
|
286
326
|
|
287
327
|
If `lion_class` in `metadata` refers to a subclass, this method
|
@@ -289,8 +329,17 @@ class Element(BaseModel, Observable):
|
|
289
329
|
|
290
330
|
Args:
|
291
331
|
data (dict): A dictionary of field data.
|
332
|
+
mode (str): Format mode - "python" for normal dicts, "db" for database format.
|
292
333
|
"""
|
293
|
-
|
334
|
+
# Preprocess database format if needed
|
335
|
+
if mode == "db":
|
336
|
+
data = cls._preprocess_db_data(data.copy())
|
337
|
+
metadata = {}
|
338
|
+
|
339
|
+
if "node_metadata" in data:
|
340
|
+
metadata = data.pop("node_metadata")
|
341
|
+
elif "metadata" in data:
|
342
|
+
metadata = data.pop("metadata")
|
294
343
|
if "lion_class" in metadata:
|
295
344
|
subcls: str = metadata.pop("lion_class")
|
296
345
|
if subcls != Element.class_name(full=True):
|
@@ -308,9 +357,6 @@ class Element(BaseModel, Observable):
|
|
308
357
|
return subcls_type.from_dict(data)
|
309
358
|
|
310
359
|
except Exception:
|
311
|
-
# Fallback attempt: direct import if not in registry
|
312
|
-
from lionagi.libs.package.imports import import_module
|
313
|
-
|
314
360
|
mod, imp = subcls.rsplit(".", 1)
|
315
361
|
subcls_type = import_module(mod, import_name=imp)
|
316
362
|
data["metadata"] = metadata
|
@@ -321,15 +367,72 @@ class Element(BaseModel, Observable):
|
|
321
367
|
data["metadata"] = metadata
|
322
368
|
return cls.model_validate(data)
|
323
369
|
|
324
|
-
|
370
|
+
@classmethod
|
371
|
+
def _preprocess_db_data(cls, data: dict) -> dict:
|
372
|
+
"""Preprocess raw database data for Element compatibility."""
|
373
|
+
import datetime as dt
|
374
|
+
import json
|
375
|
+
|
376
|
+
# Handle created_at field - convert datetime string to timestamp
|
377
|
+
if "created_at" in data and isinstance(data["created_at"], str):
|
378
|
+
try:
|
379
|
+
# Parse datetime string and convert to timestamp
|
380
|
+
dt_obj = dt.datetime.fromisoformat(
|
381
|
+
data["created_at"].replace(" ", "T")
|
382
|
+
)
|
383
|
+
# Treat as UTC if naive
|
384
|
+
if dt_obj.tzinfo is None:
|
385
|
+
dt_obj = dt_obj.replace(tzinfo=dt.timezone.utc)
|
386
|
+
data["created_at"] = dt_obj.timestamp()
|
387
|
+
except (ValueError, TypeError):
|
388
|
+
# Keep as string if parsing fails
|
389
|
+
pass
|
390
|
+
|
391
|
+
# Handle JSON string fields - parse to dict/list
|
392
|
+
json_fields = ["content", "node_metadata", "embedding"]
|
393
|
+
for field in json_fields:
|
394
|
+
if field in data and isinstance(data[field], str):
|
395
|
+
if data[field] in ("null", ""):
|
396
|
+
data[field] = None if field == "embedding" else {}
|
397
|
+
else:
|
398
|
+
try:
|
399
|
+
data[field] = json.loads(data[field])
|
400
|
+
except (json.JSONDecodeError, TypeError):
|
401
|
+
# Keep as empty dict for metadata fields, None for embedding
|
402
|
+
data[field] = {} if field != "embedding" else None
|
403
|
+
|
404
|
+
# Handle node_metadata -> metadata mapping
|
405
|
+
if "node_metadata" in data:
|
406
|
+
if (
|
407
|
+
data["node_metadata"] == "null"
|
408
|
+
or data["node_metadata"] is None
|
409
|
+
):
|
410
|
+
data["metadata"] = {}
|
411
|
+
else:
|
412
|
+
data["metadata"] = (
|
413
|
+
data["node_metadata"] if data["node_metadata"] else {}
|
414
|
+
)
|
415
|
+
# Remove node_metadata to avoid Pydantic validation error
|
416
|
+
data.pop("node_metadata", None)
|
417
|
+
|
418
|
+
return data
|
419
|
+
|
420
|
+
def to_json(self, decode: bool = True) -> str:
|
325
421
|
"""Converts this Element to a JSON string."""
|
326
|
-
dict_ = self.
|
327
|
-
|
422
|
+
dict_ = self._to_dict()
|
423
|
+
if decode:
|
424
|
+
return orjson.dumps(
|
425
|
+
dict_,
|
426
|
+
default=DEFAULT_ELEMENT_SERIALIZER,
|
427
|
+
option=ln.DEFAULT_SERIALIZER_OPTION,
|
428
|
+
).decode()
|
429
|
+
return orjson.dumps(dict_, default=DEFAULT_ELEMENT_SERIALIZER)
|
328
430
|
|
329
|
-
|
431
|
+
@classmethod
|
432
|
+
def from_json(cls, json_str: str, mode: str = "python") -> Element:
|
330
433
|
"""Deserializes a JSON string into an Element or subclass of Element."""
|
331
434
|
data = orjson.loads(json_str)
|
332
|
-
return cls.from_dict(data)
|
435
|
+
return cls.from_dict(data, mode=mode)
|
333
436
|
|
334
437
|
|
335
438
|
DEFAULT_ELEMENT_SERIALIZER = ln.get_orjson_default(
|
@@ -19,7 +19,7 @@ from pathlib import Path
|
|
19
19
|
from typing import Any, ClassVar, Generic, Literal, TypeVar
|
20
20
|
|
21
21
|
import pandas as pd
|
22
|
-
from pydantic import Field, field_serializer
|
22
|
+
from pydantic import Field, field_serializer
|
23
23
|
from pydantic.fields import FieldInfo
|
24
24
|
from pydapter import Adaptable, AsyncAdaptable
|
25
25
|
from typing_extensions import Self, deprecated, override
|
@@ -988,6 +988,12 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
|
|
988
988
|
is_same_dtype(self.collections.values())
|
989
989
|
)
|
990
990
|
|
991
|
+
@classmethod
def list_adapters(cls) -> list[str]:
    """Return the keys of every registered sync and async adapter.

    A key registered in both registries appears once. The result is sorted
    (by ``str``), so its order is deterministic. The previous
    ``list(set(...))`` order depended on string hash randomization and could
    change between interpreter runs.

    Note: reads the registries' private ``_reg`` mappings.
    """
    sync_keys = cls._adapter_registry._reg.keys()
    async_keys = cls._async_registry._reg.keys()
    return sorted(set(sync_keys) | set(async_keys), key=str)
|
996
|
+
|
991
997
|
def adapt_to(self, obj_key: str, many=False, **kw: Any) -> Any:
|
992
998
|
"""Adapt to another format.
|
993
999
|
|
@@ -1132,6 +1138,55 @@ class Pile(Element, Collective[T], Generic[T], Adaptable, AsyncAdaptable):
|
|
1132
1138
|
) -> None:
|
1133
1139
|
return self.dump(fp, obj_key=obj_key, mode=mode, clear=clear, **kw)
|
1134
1140
|
|
1141
|
+
def filter_by_type(
    self,
    item_type: type[T] | list | set,
    strict_type: bool = False,
    as_pile: bool = False,
    reverse: bool = False,
    num_items: int | None = None,
) -> list[T]:
    """Return the items of this pile whose type matches ``item_type``.

    Args:
        item_type: A single type, a union (``A | B`` / ``Union[A, B]``), or a
            list/tuple/set/frozenset of types.
        strict_type: If True, require an exact type match
            (``type(item) in types``); otherwise subclasses match via
            ``isinstance``.
        as_pile: If True, wrap the matches in a new instance of this pile's
            class instead of returning a list.
        reverse: If True, walk the progression from last to first.
        num_items: Stop after this many matches. ``None`` means no limit;
            ``0`` (or a negative value) yields no items.

    Returns:
        Matching items in progression order (reversed if requested), as a
        list, or as a pile of the same class when ``as_pile`` is True.

    Raises:
        TypeError: If ``item_type`` is not a type, a union, or a collection of
            types.
    """
    # Unions are checked after concrete types: ``int | str`` and
    # ``Union[int, str]`` are not ``type`` instances. The former
    # ``isinstance(item_type, type)`` guard therefore never reached the union
    # branch, and unions were rejected with TypeError.
    if isinstance(item_type, list | tuple | set | frozenset):
        wanted = set(item_type)
    elif isinstance(item_type, type):
        wanted = {item_type}
    elif is_union_type(item_type):
        wanted = set(union_members(item_type))
    else:
        raise TypeError("item_type must be a type or a list/set of types")

    if strict_type:

        def matches(item) -> bool:
            return type(item) in wanted

    else:
        classes = tuple(wanted)

        def matches(item) -> bool:
            return isinstance(item, classes)

    ids = list(self.progression)
    if reverse:
        ids.reverse()

    matched: list[T] = []
    for item_id in ids:
        # Check the limit before fetching the next item. The former
        # post-append equality check could never fire for num_items=0 and
        # returned every match.
        if num_items is not None and len(matched) >= num_items:
            break
        item = self.collections[item_id]
        if matches(item):
            matched.append(item)

    if as_pile:
        return self.__class__(
            collections=matched, item_type=wanted, strict_type=strict_type
        )
    return matched
|
1189
|
+
|
1135
1190
|
|
1136
1191
|
def to_list_type(value: Any, /) -> list[Any]:
|
1137
1192
|
"""Convert input to a list format"""
|
@@ -14,7 +14,7 @@ from lionagi._errors import ItemNotFoundError
|
|
14
14
|
from .._concepts import Ordering
|
15
15
|
from .element import ID, Element, IDError, IDType, validate_order
|
16
16
|
|
17
|
-
|
17
|
+
T = TypeVar("T", bound=Element)
|
18
18
|
|
19
19
|
|
20
20
|
__all__ = (
|
@@ -23,7 +23,7 @@ __all__ = (
|
|
23
23
|
)
|
24
24
|
|
25
25
|
|
26
|
-
class Progression(Element, Ordering[
|
26
|
+
class Progression(Element, Ordering[T], Generic[T]):
|
27
27
|
"""Tracks an ordered sequence of item IDs, with optional naming.
|
28
28
|
|
29
29
|
This class extends `Element` and implements `Ordering`, providing
|
@@ -39,7 +39,7 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
39
39
|
An optional human-readable identifier for the progression.
|
40
40
|
"""
|
41
41
|
|
42
|
-
order: list[ID[
|
42
|
+
order: list[ID[T].ID] = Field(
|
43
43
|
default_factory=list,
|
44
44
|
title="Order",
|
45
45
|
description="A sequence of IDs representing the progression.",
|
@@ -358,7 +358,7 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
358
358
|
raise ValueError("Can only extend with another Progression.")
|
359
359
|
self.order.extend(other.order)
|
360
360
|
|
361
|
-
def __add__(self, other: Any) -> Progression[
|
361
|
+
def __add__(self, other: Any) -> Progression[T]:
|
362
362
|
"""Returns a new Progression with IDs from both this and `other`.
|
363
363
|
|
364
364
|
Args:
|
@@ -371,7 +371,7 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
371
371
|
new_refs = validate_order(other)
|
372
372
|
return Progression(order=self.order + new_refs)
|
373
373
|
|
374
|
-
def __radd__(self, other: Any) -> Progression[
|
374
|
+
def __radd__(self, other: Any) -> Progression[T]:
|
375
375
|
"""Returns a new Progression with IDs from `other` + this.
|
376
376
|
|
377
377
|
Args:
|
@@ -396,7 +396,7 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
396
396
|
self.append(other)
|
397
397
|
return self
|
398
398
|
|
399
|
-
def __sub__(self, other: Any) -> Progression[
|
399
|
+
def __sub__(self, other: Any) -> Progression[T]:
|
400
400
|
"""Returns a new Progression excluding specified IDs.
|
401
401
|
|
402
402
|
Args:
|
@@ -435,7 +435,7 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
435
435
|
for i in reversed(item_):
|
436
436
|
self.order.insert(index, ID.get_id(i))
|
437
437
|
|
438
|
-
def
|
438
|
+
def __reversed__(self) -> Progression[T]:
|
439
439
|
"""Returns a new reversed Progression.
|
440
440
|
|
441
441
|
Returns:
|
@@ -456,19 +456,19 @@ class Progression(Element, Ordering[E], Generic[E]):
|
|
456
456
|
return NotImplemented
|
457
457
|
return (self.order == other.order) and (self.name == other.name)
|
458
458
|
|
459
|
-
def __gt__(self, other: Progression[
|
459
|
+
def __gt__(self, other: Progression[T]) -> bool:
|
460
460
|
"""Compares if this progression is "greater" by ID order."""
|
461
461
|
return self.order > other.order
|
462
462
|
|
463
|
-
def __lt__(self, other: Progression[
|
463
|
+
def __lt__(self, other: Progression[T]) -> bool:
|
464
464
|
"""Compares if this progression is "less" by ID order."""
|
465
465
|
return self.order < other.order
|
466
466
|
|
467
|
-
def __ge__(self, other: Progression[
|
467
|
+
def __ge__(self, other: Progression[T]) -> bool:
|
468
468
|
"""Compares if this progression is >= the other by ID order."""
|
469
469
|
return self.order >= other.order
|
470
470
|
|
471
|
-
def __le__(self, other: Progression[
|
471
|
+
def __le__(self, other: Progression[T]) -> bool:
|
472
472
|
"""Compares if this progression is <= the other by ID order."""
|
473
473
|
return self.order <= other.order
|
474
474
|
|
@@ -0,0 +1,22 @@
|
|
1
|
+
def check_networkx_available():
    """Report whether ``networkx`` can be imported.

    Returns:
        ``True`` when ``networkx.DiGraph`` imports successfully. Otherwise an
        *unraised* ``ImportError`` carrying an installation hint, which callers
        raise themselves (``if res is not True: raise res``). The underlying
        import failure is attached as ``__cause__``, so the traceback still
        shows why the import failed.
    """
    try:
        from networkx import DiGraph  # noqa: F401
    except Exception as exc:  # best-effort probe: any import-time failure counts
        err = ImportError(
            "The 'networkx' package is required for this feature. "
            "Please install `networkx` or `'lionagi[graph]'`.",
            name="networkx",
        )
        # A later plain `raise err` keeps this pre-set cause.
        err.__cause__ = exc
        return err
    return True
|
11
|
+
|
12
|
+
|
13
|
+
def check_matplotlib_available():
    """Report whether ``matplotlib.pyplot`` can be imported.

    Returns:
        ``True`` when ``matplotlib.pyplot`` imports successfully. Otherwise an
        *unraised* ``ImportError`` carrying an installation hint, which callers
        raise themselves (``if res is not True: raise res``). The underlying
        import failure is attached as ``__cause__``, so the traceback still
        shows why the import failed.
    """
    try:
        import matplotlib.pyplot  # noqa: F401
    except Exception as exc:  # best-effort probe: any import-time failure counts
        err = ImportError(
            "The 'matplotlib' package is required for this feature. "
            "Please install `matplotlib` or `'lionagi[graph]'`.",
            name="matplotlib",
        )
        # A later plain `raise err` keeps this pre-set cause.
        err.__cause__ = exc
        return err
    return True
|
lionagi/protocols/graph/graph.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
4
4
|
|
5
5
|
from collections import deque
|
6
|
-
from typing import Any, Literal
|
6
|
+
from typing import Any, Generic, Literal, TypeVar
|
7
7
|
|
8
8
|
from pydantic import Field, field_serializer, field_validator, model_validator
|
9
9
|
from typing_extensions import Self
|
@@ -17,11 +17,17 @@ from ..generic.pile import Pile
|
|
17
17
|
from .edge import Edge
|
18
18
|
from .node import Node
|
19
19
|
|
20
|
+
T = TypeVar("T", bound=Node)
|
21
|
+
|
22
|
+
from ._utils import check_matplotlib_available, check_networkx_available
|
23
|
+
|
24
|
+
_NETWORKX_AVAILABLE = check_networkx_available()
|
25
|
+
_MATPLIB_AVAILABLE = check_matplotlib_available()
|
20
26
|
__all__ = ("Graph",)
|
21
27
|
|
22
28
|
|
23
|
-
class Graph(Element, Relational):
|
24
|
-
internal_nodes: Pile[
|
29
|
+
class Graph(Element, Relational, Generic[T]):
|
30
|
+
internal_nodes: Pile[T] = Field(
|
25
31
|
default_factory=lambda: Pile(item_type={Node}, strict_type=False),
|
26
32
|
title="Internal Nodes",
|
27
33
|
description="A collection of nodes in the graph.",
|
@@ -214,13 +220,10 @@ class Graph(Element, Relational):
|
|
214
220
|
|
215
221
|
def to_networkx(self, **kwargs) -> Any:
|
216
222
|
"""Convert the graph to a NetworkX graph object."""
|
217
|
-
|
218
|
-
|
223
|
+
if _NETWORKX_AVAILABLE is not True:
|
224
|
+
raise _NETWORKX_AVAILABLE
|
219
225
|
|
220
|
-
|
221
|
-
from lionagi.libs.package.imports import check_import
|
222
|
-
|
223
|
-
DiGraph = check_import("networkx", import_name="DiGraph")
|
226
|
+
from networkx import DiGraph # type: ignore
|
224
227
|
|
225
228
|
g = DiGraph(**kwargs)
|
226
229
|
for node in self.internal_nodes:
|
@@ -245,20 +248,13 @@ class Graph(Element, Relational):
|
|
245
248
|
**kwargs,
|
246
249
|
):
|
247
250
|
"""Display the graph using NetworkX and Matplotlib."""
|
251
|
+
g = self.to_networkx(**kwargs)
|
252
|
+
if _MATPLIB_AVAILABLE is not True:
|
253
|
+
raise _MATPLIB_AVAILABLE
|
248
254
|
|
249
|
-
|
250
|
-
|
251
|
-
import networkx as nx # type: ignore
|
252
|
-
except ImportError:
|
253
|
-
from lionagi.libs.package.imports import check_import
|
254
|
-
|
255
|
-
check_import("matplotlib")
|
256
|
-
check_import("networkx")
|
257
|
-
|
258
|
-
import matplotlib.pyplot as plt # type: ignore
|
259
|
-
import networkx as nx # type: ignore
|
255
|
+
import matplotlib.pyplot as plt # type: ignore
|
256
|
+
import networkx as nx # type: ignore
|
260
257
|
|
261
|
-
g = self.to_networkx(**kwargs)
|
262
258
|
pos = nx.spring_layout(g)
|
263
259
|
nx.draw(
|
264
260
|
g,
|