lionagi 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (43) hide show
  1. lionagi/core/collections/abc/component.py +26 -29
  2. lionagi/core/collections/abc/concepts.py +0 -6
  3. lionagi/core/collections/flow.py +0 -1
  4. lionagi/core/collections/model.py +3 -2
  5. lionagi/core/collections/pile.py +1 -1
  6. lionagi/core/collections/progression.py +4 -5
  7. lionagi/core/director/models/__init__.py +13 -0
  8. lionagi/core/director/models/action_model.py +61 -0
  9. lionagi/core/director/models/brainstorm_model.py +42 -0
  10. lionagi/core/director/models/plan_model.py +51 -0
  11. lionagi/core/director/models/reason_model.py +63 -0
  12. lionagi/core/director/models/step_model.py +65 -0
  13. lionagi/core/director/operations/__init__.py +0 -0
  14. lionagi/core/director/operations/select.py +93 -0
  15. lionagi/core/director/operations/utils.py +6 -0
  16. lionagi/core/generic/registry/component_registry/__init__.py +0 -0
  17. lionagi/core/operations/__init__.py +0 -0
  18. lionagi/core/operations/chat/__init__.py +0 -0
  19. lionagi/core/operations/direct/__init__.py +0 -0
  20. lionagi/core/operative/__init__.py +0 -0
  21. lionagi/core/unit/unit_mixin.py +3 -3
  22. lionagi/integrations/langchain_/__init__.py +0 -0
  23. lionagi/integrations/llamaindex_/__init__.py +0 -0
  24. lionagi/libs/sys_util.py +196 -32
  25. lionagi/lions/director/__init__.py +0 -0
  26. lionagi/operations/brainstorm/__init__.py +0 -0
  27. lionagi/operations/brainstorm.py +0 -0
  28. lionagi/operations/chat/__init__.py +0 -0
  29. lionagi/operations/models/__init__.py +0 -0
  30. lionagi/operations/plan/__init__.py +0 -0
  31. lionagi/operations/plan/base.py +0 -0
  32. lionagi/operations/query/__init__.py +0 -0
  33. lionagi/operations/rank/__init__.py +0 -0
  34. lionagi/operations/react/__init__.py +0 -0
  35. lionagi/operations/route/__init__.py +0 -0
  36. lionagi/operations/score/__init__.py +0 -0
  37. lionagi/operations/select/__init__.py +0 -0
  38. lionagi/operations/strategize/__init__.py +0 -0
  39. lionagi/version.py +1 -1
  40. {lionagi-0.3.5.dist-info → lionagi-0.3.7.dist-info}/METADATA +3 -2
  41. {lionagi-0.3.5.dist-info → lionagi-0.3.7.dist-info}/RECORD +43 -13
  42. {lionagi-0.3.5.dist-info → lionagi-0.3.7.dist-info}/LICENSE +0 -0
  43. {lionagi-0.3.5.dist-info → lionagi-0.3.7.dist-info}/WHEEL +0 -0
@@ -1,11 +1,11 @@
1
1
  """Component class, base building block in LionAGI."""
2
2
 
3
3
  import contextlib
4
- from abc import ABC
5
- from collections.abc import Sequence
6
4
  from functools import singledispatchmethod
7
- from typing import Any, Type, TypeAlias, TypeVar, Union
5
+ from typing import Any, TypeAlias, TypeVar, Union
8
6
 
7
+ import lionfuncs as ln
8
+ from lionabc import Observable
9
9
  from pandas import DataFrame, Series
10
10
  from pydantic import AliasChoices, BaseModel, Field, ValidationError
11
11
 
@@ -22,7 +22,13 @@ T = TypeVar("T")
22
22
  _init_class = {}
23
23
 
24
24
 
25
- class Element(BaseModel, ABC):
25
+ def change_dict_key(dict_: dict, old_key: str, new_key: str) -> None:
26
+ """Change a key in a dictionary."""
27
+ if old_key in dict_:
28
+ dict_[new_key] = dict_.pop(old_key)
29
+
30
+
31
+ class Element(BaseModel, Observable):
26
32
  """Base class for elements within the LionAGI system.
27
33
 
28
34
  Attributes:
@@ -31,15 +37,14 @@ class Element(BaseModel, ABC):
31
37
  """
32
38
 
33
39
  ln_id: str = Field(
34
- default_factory=SysUtil.create_id,
40
+ default_factory=SysUtil.id,
35
41
  title="ID",
36
- description="A 32-char unique hash identifier.",
37
42
  frozen=True,
38
43
  validation_alias=AliasChoices("node_id", "ID", "id"),
39
44
  )
40
45
 
41
46
  timestamp: str = Field(
42
- default_factory=lambda: SysUtil.get_timestamp(sep=None)[:-6],
47
+ default_factory=lambda: ln.time(type_="iso"),
43
48
  title="Creation Timestamp",
44
49
  description="The UTC timestamp of creation",
45
50
  frozen=True,
@@ -57,7 +62,7 @@ class Element(BaseModel, ABC):
57
62
  return True
58
63
 
59
64
 
60
- class Component(Element, ABC):
65
+ class Component(Element):
61
66
  """
62
67
  Represents a distinguishable, temporal entity in LionAGI.
63
68
 
@@ -195,17 +200,15 @@ class Component(Element, ABC):
195
200
  """Create a Component instance from a LlamaIndex object."""
196
201
  dict_ = obj.to_dict()
197
202
 
198
- SysUtil.change_dict_key(dict_, "text", "content")
203
+ change_dict_key(dict_, "text", "content")
199
204
  metadata = dict_.pop("metadata", {})
200
205
 
201
206
  for field in llama_meta_fields:
202
207
  metadata[field] = dict_.pop(field, None)
203
208
 
204
- SysUtil.change_dict_key(metadata, "class_name", "llama_index_class")
205
- SysUtil.change_dict_key(metadata, "id_", "llama_index_id")
206
- SysUtil.change_dict_key(
207
- metadata, "relationships", "llama_index_relationships"
208
- )
209
+ change_dict_key(metadata, "class_name", "llama_index_class")
210
+ change_dict_key(metadata, "id_", "llama_index_id")
211
+ change_dict_key(metadata, "relationships", "llama_index_relationships")
209
212
 
210
213
  dict_["metadata"] = metadata
211
214
  return cls.from_obj(dict_)
@@ -244,7 +247,7 @@ class Component(Element, ABC):
244
247
  @classmethod
245
248
  def _process_langchain_dict(cls, dict_: dict) -> dict:
246
249
  """Process a dictionary containing Langchain-specific data."""
247
- SysUtil.change_dict_key(dict_, "page_content", "content")
250
+ change_dict_key(dict_, "page_content", "content")
248
251
 
249
252
  metadata = dict_.pop("metadata", {})
250
253
  metadata.update(dict_.pop("kwargs", {}))
@@ -264,9 +267,9 @@ class Component(Element, ABC):
264
267
  if field in dict_:
265
268
  metadata[field] = dict_.pop(field)
266
269
 
267
- SysUtil.change_dict_key(metadata, "lc", "langchain")
268
- SysUtil.change_dict_key(metadata, "type", "lc_type")
269
- SysUtil.change_dict_key(metadata, "id", "lc_id")
270
+ change_dict_key(metadata, "lc", "langchain")
271
+ change_dict_key(metadata, "type", "lc_type")
272
+ change_dict_key(metadata, "id", "lc_id")
270
273
 
271
274
  extra_fields = {
272
275
  k: v for k, v in metadata.items() if k not in lc_meta_fields
@@ -298,9 +301,9 @@ class Component(Element, ABC):
298
301
  dict_["metadata"] = meta_
299
302
 
300
303
  if "ln_id" not in dict_:
301
- dict_["ln_id"] = meta_.pop("ln_id", SysUtil.create_id())
304
+ dict_["ln_id"] = meta_.pop("ln_id", SysUtil.id())
302
305
  if "timestamp" not in dict_:
303
- dict_["timestamp"] = SysUtil.get_timestamp(sep=None)[:-6]
306
+ dict_["timestamp"] = ln.time(type_="iso")
304
307
  if "metadata" not in dict_:
305
308
  dict_["metadata"] = {}
306
309
  if "extra_fields" not in dict_:
@@ -453,13 +456,13 @@ class Component(Element, ABC):
453
456
  ninsert(
454
457
  self.metadata,
455
458
  ["last_updated", name],
456
- SysUtil.get_timestamp(sep=None)[:-6],
459
+ ln.time(type_="iso")[:-6],
457
460
  )
458
461
  elif isinstance(a, tuple) and isinstance(a[0], int):
459
462
  nset(
460
463
  self.metadata,
461
464
  ["last_updated", name],
462
- SysUtil.get_timestamp(sep=None)[:-6],
465
+ ln.time(type_="iso")[:-6],
463
466
  )
464
467
 
465
468
  def _meta_pop(self, indices, default=...):
@@ -614,10 +617,4 @@ LionIDable: TypeAlias = Union[str, Element]
614
617
 
615
618
  def get_lion_id(item: LionIDable) -> str:
616
619
  """Get the Lion ID of an item."""
617
- if isinstance(item, Sequence) and len(item) == 1:
618
- item = item[0]
619
- if isinstance(item, str) and len(item) == 32:
620
- return item
621
- if getattr(item, "ln_id", None) is not None:
622
- return item.ln_id
623
- raise LionTypeError("Item must be a single LionIDable object.")
620
+ return SysUtil.get_id(item)
@@ -227,12 +227,6 @@ class Sendable(BaseModel, ABC):
227
227
  return value
228
228
 
229
229
  a = get_lion_id(value)
230
- if not isinstance(a, str) or len(a) != 32:
231
- raise LionTypeError(
232
- "Invalid sender or recipient value. "
233
- "Expected a valid node id or one of "
234
- "'system' or 'user'."
235
- )
236
230
  return a
237
231
 
238
232
 
@@ -1,7 +1,6 @@
1
1
  import contextlib
2
2
  from collections import deque
3
3
  from collections.abc import Mapping
4
- from typing import Tuple
5
4
 
6
5
  from pydantic import Field
7
6
 
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  import os
3
3
 
4
+ import lionfuncs as ln
4
5
  import numpy as np
5
6
  from dotenv import load_dotenv
6
7
 
@@ -91,8 +92,8 @@ class iModel:
91
92
  service (BaseService, optional): An instance of BaseService.
92
93
  **kwargs: Additional parameters for the model.
93
94
  """
94
- self.ln_id: str = SysUtil.create_id()
95
- self.timestamp: str = SysUtil.get_timestamp(sep=None)[:-6]
95
+ self.ln_id: str = SysUtil.id()
96
+ self.timestamp: str = ln.time(type_="iso")
96
97
  self.endpoint = endpoint
97
98
  self.allowed_parameters = allowed_parameters
98
99
  if isinstance(provider, type):
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  from collections.abc import AsyncIterator, Callable, Iterable
5
5
  from functools import wraps
6
- from typing import Any, Generic, Type, TypeVar
6
+ from typing import Any, Generic, TypeVar
7
7
 
8
8
  from pydantic import Field, field_validator
9
9
 
@@ -1,9 +1,8 @@
1
1
  import contextlib
2
2
 
3
+ import lionfuncs as ln
3
4
  from pydantic import Field, field_validator
4
5
 
5
- from lionagi.libs import SysUtil
6
-
7
6
  from .abc import Element, ItemNotFoundError, LionIDable, Ordering, get_lion_id
8
7
  from .util import _validate_order
9
8
 
@@ -90,7 +89,7 @@ class Progression(Element, Ordering):
90
89
  """Remove the next occurrence of an item from the progression."""
91
90
  if item in self:
92
91
  item = self._validate_order(item)
93
- l_ = SysUtil.create_copy(self.order)
92
+ l_ = ln.copy(self.order)
94
93
 
95
94
  with contextlib.suppress(Exception):
96
95
  for i in item:
@@ -143,7 +142,7 @@ class Progression(Element, Ordering):
143
142
  def __radd__(self, other):
144
143
  if not isinstance(other, Progression):
145
144
  _copy = self.copy()
146
- l_ = SysUtil.create_copy(_copy.order)
145
+ l_ = ln.copy(_copy.order)
147
146
  l_.insert(0, get_lion_id(other))
148
147
  _copy.order = l_
149
148
  return _copy
@@ -190,7 +189,7 @@ class Progression(Element, Ordering):
190
189
 
191
190
  def __list__(self):
192
191
  """Return a list representation of the progression."""
193
- return SysUtil.create_copy(self.order)
192
+ return ln.copy(self.order)
194
193
 
195
194
  def __reversed__(self):
196
195
  """Return a reversed progression."""
@@ -0,0 +1,13 @@
1
+ from .action_model import ActionModel
2
+ from .brainstorm_model import BrainstormModel
3
+ from .plan_model import PlanModel
4
+ from .reason_model import ReasonModel
5
+ from .step_model import StepModel
6
+
7
+ __all__ = [
8
+ "ReasonModel",
9
+ "StepModel",
10
+ "BrainstormModel",
11
+ "ActionModel",
12
+ "PlanModel",
13
+ ]
@@ -0,0 +1,61 @@
1
+ from typing import Any
2
+
3
+ from lionfuncs import to_dict, validate_str
4
+ from pydantic import BaseModel, Field, field_validator
5
+
6
+
7
+ class ActionModel(BaseModel):
8
+
9
+ title: str = Field(
10
+ ...,
11
+ title="Title",
12
+ description="Provide a concise title summarizing the action.",
13
+ )
14
+ content: str = Field(
15
+ ...,
16
+ title="Content",
17
+ description="Provide a brief description of the action to be performed.",
18
+ )
19
+ function: str = Field(
20
+ ...,
21
+ title="Function",
22
+ description=(
23
+ "Specify the name of the function to execute. **Choose from the provided "
24
+ "`tool_schema`; do not invent function names.**"
25
+ ),
26
+ examples=["print", "add", "len"],
27
+ )
28
+ arguments: dict[str, Any] = Field(
29
+ {},
30
+ title="Arguments",
31
+ description=(
32
+ "Provide the arguments to pass to the function as a dictionary. **Use "
33
+ "argument names and types as specified in the `tool_schema`; do not "
34
+ "invent argument names.**"
35
+ ),
36
+ examples=[{"num1": 1, "num2": 2}, {"x": "hello", "y": "world"}],
37
+ )
38
+
39
+ @field_validator("title", mode="before")
40
+ def validate_title(cls, value: Any) -> str:
41
+ return validate_str(value, "title")
42
+
43
+ @field_validator("content", mode="before")
44
+ def validate_content(cls, value: Any) -> str:
45
+ return validate_str(value, "content")
46
+
47
+ @field_validator("function", mode="before")
48
+ def validate_function(cls, value: Any) -> str:
49
+ return validate_str(value, "function")
50
+
51
+ @field_validator("arguments", mode="before")
52
+ def validate_arguments(cls, value: Any) -> dict[str, Any]:
53
+ return to_dict(
54
+ value,
55
+ fuzzy_parse=True,
56
+ suppress=True,
57
+ recursive=True,
58
+ )
59
+
60
+
61
+ __all__ = ["ActionModel"]
@@ -0,0 +1,42 @@
1
+ from typing import Any
2
+
3
+ from lionfuncs import validate_str
4
+ from pydantic import BaseModel, Field, field_validator
5
+
6
+ from .reason_model import ReasonModel
7
+ from .step_model import StepModel
8
+
9
+
10
+ class BrainstormModel(BaseModel):
11
+
12
+ title: str = Field(
13
+ ...,
14
+ title="Title",
15
+ description="Provide a concise title summarizing the brainstorming session.",
16
+ )
17
+ content: str = Field(
18
+ ...,
19
+ title="Content",
20
+ description="Describe the context or focus of the brainstorming session.",
21
+ )
22
+ ideas: list[StepModel] = Field(
23
+ ...,
24
+ title="Ideas",
25
+ description="A list of ideas for the next step, generated during brainstorming.",
26
+ )
27
+ reason: ReasonModel = Field(
28
+ ...,
29
+ title="Reason",
30
+ description="Provide the high level reasoning behind the brainstorming session.",
31
+ )
32
+
33
+ @field_validator("title", mode="before")
34
+ def validate_title(cls, value: Any) -> str:
35
+ return validate_str(value, "title")
36
+
37
+ @field_validator("content", mode="before")
38
+ def validate_content(cls, value: Any) -> str:
39
+ return validate_str(value, "content")
40
+
41
+
42
+ __all__ = ["BrainstormModel"]
@@ -0,0 +1,51 @@
1
+ from typing import Any, List
2
+
3
+ from lionfuncs import validate_str
4
+ from pydantic import BaseModel, Field, field_validator
5
+
6
+ from .reason_model import ReasonModel
7
+ from .step_model import StepModel
8
+
9
+
10
+ class PlanModel(BaseModel):
11
+ """
12
+ Represents a plan consisting of multiple steps, with an overall reason.
13
+
14
+ Attributes:
15
+ title (str): A concise title summarizing the plan.
16
+ content (str): A detailed description of the plan.
17
+ reason (ReasonModel): The overall reasoning behind the plan.
18
+ steps (List[StepModel]): A list of steps to execute the plan.
19
+ """
20
+
21
+ title: str = Field(
22
+ ...,
23
+ title="Title",
24
+ description="Provide a concise title summarizing the plan.",
25
+ )
26
+ content: str = Field(
27
+ ...,
28
+ title="Content",
29
+ description="Provide a detailed description of the plan.",
30
+ )
31
+ reason: ReasonModel = Field(
32
+ ...,
33
+ title="Reason",
34
+ description="Provide the reasoning behind the entire plan.",
35
+ )
36
+ steps: list[StepModel] = Field(
37
+ ...,
38
+ title="Steps",
39
+ description="A list of steps to execute the plan.",
40
+ )
41
+
42
+ @field_validator("title", mode="before")
43
+ def validate_title(cls, value: Any) -> str:
44
+ return validate_str(value, "title")
45
+
46
+ @field_validator("content", mode="before")
47
+ def validate_content(cls, value: Any) -> str:
48
+ return validate_str(value, "content")
49
+
50
+
51
+ __all__ = ["PlanModel"]
@@ -0,0 +1,63 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ from lionfuncs import to_num, validate_str
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+
8
+ class ReasonModel(BaseModel):
9
+ title: str = Field(
10
+ ...,
11
+ title="Title",
12
+ description="Provide a concise title summarizing the reason.",
13
+ )
14
+ content: str = Field(
15
+ ...,
16
+ title="Content",
17
+ description=(
18
+ "Provide a detailed explanation supporting the reason, including relevant "
19
+ "information or context."
20
+ ),
21
+ )
22
+ confidence_score: float | None = Field(
23
+ None,
24
+ description=(
25
+ "Provide an objective numeric confidence score between 0 and 1 (with 3 "
26
+ "decimal places) indicating how likely you successfully achieved the task "
27
+ "according to user expectation. Interpret the score as:\n"
28
+ "- **1**: Very confident in a good job.\n"
29
+ "- **0**: Not confident at all.\n"
30
+ "- **[0.8, 1]**: You can continue the path of reasoning if needed.\n"
31
+ "- **[0.5, 0.8)**: Recheck your reasoning and consider reverting to a "
32
+ "previous, more confident reasoning path.\n"
33
+ "- **[0, 0.5)**: Stop because the reasoning is starting to be off track."
34
+ ),
35
+ examples=[0.821, 0.257, 0.923, 0.439],
36
+ ge=0,
37
+ le=1,
38
+ )
39
+
40
+ @field_validator("title", mode="before")
41
+ def validate_title(cls, value: Any) -> str:
42
+ return validate_str(value, "title")
43
+
44
+ @field_validator("content", mode="before")
45
+ def validate_content(cls, value: Any) -> str:
46
+ return validate_str(value, "content")
47
+
48
+ @field_validator("confidence_score", mode="before")
49
+ def validate_confidence_score(cls, value: Any) -> float:
50
+ try:
51
+ return to_num(
52
+ value,
53
+ upper_bound=1,
54
+ lower_bound=0,
55
+ num_type=float,
56
+ precision=3,
57
+ )
58
+ except Exception as e:
59
+ logging.error(f"Failed to convert {value} to a number. Error: {e}")
60
+ return 0.0
61
+
62
+
63
+ __all__ = ["ReasonModel"]
@@ -0,0 +1,65 @@
1
+ import logging
2
+ from typing import Any
3
+
4
+ from lionfuncs import validate_boolean, validate_str
5
+ from pydantic import BaseModel, Field, field_validator
6
+
7
+ from .action_model import ActionModel
8
+ from .reason_model import ReasonModel
9
+
10
+
11
+ class StepModel(BaseModel):
12
+ title: str = Field(
13
+ ...,
14
+ title="Title",
15
+ description="Provide a concise title summarizing the step.",
16
+ )
17
+ content: str = Field(
18
+ ...,
19
+ title="Content",
20
+ description="Describe the content of the step in detail.",
21
+ )
22
+ reason: ReasonModel = Field(
23
+ ...,
24
+ title="Reason",
25
+ description="Provide the reasoning behind this step, including supporting details.",
26
+ )
27
+ action_required: bool = Field(
28
+ False,
29
+ title="Action Required",
30
+ description=(
31
+ "Indicate whether this step requires an action. Set to **True** if an "
32
+ "action is required; otherwise, set to **False**."
33
+ ),
34
+ )
35
+ actions: list[ActionModel] = Field(
36
+ [],
37
+ title="Actions",
38
+ description=(
39
+ "List of actions to be performed if `action_required` is **True**. Leave "
40
+ "empty if no action is required. **When providing actions, you must "
41
+ "choose from the provided `tool_schema`. Do not invent function or "
42
+ "argument names.**"
43
+ ),
44
+ )
45
+
46
+ @field_validator("title", mode="before")
47
+ def validate_title(cls, value: Any) -> str:
48
+ return validate_str(value, "title")
49
+
50
+ @field_validator("content", mode="before")
51
+ def validate_content(cls, value: Any) -> str:
52
+ return validate_str(value, "content")
53
+
54
+ @field_validator("action_required", mode="before")
55
+ def validate_action_required(cls, value: Any) -> bool:
56
+ try:
57
+ return validate_boolean(value)
58
+ except Exception as e:
59
+ logging.error(
60
+ f"Failed to convert {value} to a boolean. Error: {e}"
61
+ )
62
+ return False
63
+
64
+
65
+ __all__ = ["StepModel"]
File without changes
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from enum import Enum
5
+
6
+ from lionfuncs import choose_most_similar
7
+ from pydantic import BaseModel
8
+
9
+ from lionagi.core.director.models import ReasonModel
10
+ from lionagi.core.session.branch import Branch
11
+
12
+ from .utils import is_enum
13
+
14
+ PROMPT = "Please select up to {max_num_selections} items from the following list {choices}. Provide the selection(s), and no comments from you"
15
+
16
+
17
+ class SelectionModel(BaseModel):
18
+ selected: list[str | Enum]
19
+
20
+
21
+ class ReasonSelectionModel(BaseModel):
22
+ selected: list[str | Enum]
23
+ reason: ReasonModel
24
+
25
+
26
+ async def select(
27
+ choices: list[str] | type[Enum],
28
+ max_num_selections: int = 1,
29
+ instruction=None,
30
+ context=None,
31
+ system=None,
32
+ sender=None,
33
+ recipient=None,
34
+ reason: bool = False,
35
+ return_enum: bool = False,
36
+ enum_parser: Callable = None, # parse the model string response to appropriate type
37
+ branch: Branch = None,
38
+ return_pydantic_model=False,
39
+ **kwargs, # additional chat arguments
40
+ ):
41
+ selections = []
42
+ if return_enum and not is_enum(choices):
43
+ raise ValueError("return_enum can only be True if choices is an Enum")
44
+
45
+ if is_enum(choices):
46
+ selections = [selection.value for selection in choices]
47
+ else:
48
+ selections = choices
49
+
50
+ prompt = PROMPT.format(
51
+ max_num_selections=max_num_selections, choices=selections
52
+ )
53
+
54
+ if instruction:
55
+ prompt = f"{instruction}\n\n{prompt} \n\n "
56
+
57
+ branch = branch or Branch()
58
+ response: SelectionModel | ReasonSelectionModel | str = await branch.chat(
59
+ instruction=prompt,
60
+ context=context,
61
+ system=system,
62
+ sender=sender,
63
+ recipient=recipient,
64
+ pydantic_model=SelectionModel if not reason else ReasonSelectionModel,
65
+ return_pydantic_model=True,
66
+ **kwargs,
67
+ )
68
+
69
+ selected = response
70
+ if isinstance(response, SelectionModel | ReasonSelectionModel):
71
+ selected = response.selected
72
+ selected = [selected] if not isinstance(selected, list) else selected
73
+ corrected_selections = [
74
+ choose_most_similar(selection, selections) for selection in selected
75
+ ]
76
+
77
+ if return_enum:
78
+ out = []
79
+ if not enum_parser:
80
+ enum_parser = lambda x: x
81
+ for selection in corrected_selections:
82
+ selection = enum_parser(selection)
83
+ for member in choices.__members__.values():
84
+ if member.value == selection:
85
+ out.append(member)
86
+ corrected_selections = out
87
+
88
+ if return_pydantic_model:
89
+ if not isinstance(response, SelectionModel | ReasonSelectionModel):
90
+ return SelectionModel(selected=corrected_selections)
91
+ response.selected = corrected_selections
92
+ return response
93
+ return corrected_selections
@@ -0,0 +1,6 @@
1
+ from enum import Enum
2
+ from inspect import isclass
3
+
4
+
5
+ def is_enum(choices):
6
+ return isclass(choices) and issubclass(choices, Enum)
File without changes
File without changes
File without changes
File without changes
@@ -2,9 +2,9 @@ import asyncio
2
2
  import contextlib
3
3
  import re
4
4
  from abc import ABC
5
- from typing import Any, Optional
5
+ from typing import Any
6
6
 
7
- from lionfuncs import extract_json_block, to_dict, validate_mapping
7
+ from lionfuncs import extract_block, to_dict, validate_mapping
8
8
 
9
9
  from lionagi.core.collections.abc import ActionError
10
10
  from lionagi.core.message import ActionRequest, ActionResponse, Instruction
@@ -1156,7 +1156,7 @@ class DirectiveMixin(ABC):
1156
1156
  return to_dict(out_, fuzzy_parse=True)
1157
1157
 
1158
1158
  with contextlib.suppress(Exception):
1159
- return extract_json_block(out_)
1159
+ return extract_block(out_)
1160
1160
 
1161
1161
  with contextlib.suppress(Exception):
1162
1162
  match = re.search(r"```json\n({.*?})\n```", out_, re.DOTALL)
File without changes
File without changes