lionagi 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/core/execute/structure_executor.py +21 -1
- lionagi/core/flow/monoflow/ReAct.py +3 -1
- lionagi/core/flow/monoflow/followup.py +3 -1
- lionagi/core/generic/component.py +197 -120
- lionagi/core/generic/condition.py +2 -0
- lionagi/core/generic/edge.py +33 -33
- lionagi/core/graph/graph.py +1 -1
- lionagi/core/tool/tool_manager.py +10 -9
- lionagi/experimental/report/form.py +64 -0
- lionagi/experimental/report/report.py +138 -0
- lionagi/experimental/report/util.py +47 -0
- lionagi/experimental/tool/schema.py +3 -3
- lionagi/experimental/tool/tool_manager.py +1 -1
- lionagi/experimental/validator/rule.py +139 -0
- lionagi/experimental/validator/validator.py +56 -0
- lionagi/experimental/work/__init__.py +10 -0
- lionagi/experimental/work/async_queue.py +54 -0
- lionagi/experimental/work/schema.py +60 -17
- lionagi/experimental/work/work_function.py +55 -77
- lionagi/experimental/work/worker.py +56 -12
- lionagi/experimental/work2/__init__.py +0 -0
- lionagi/experimental/work2/form.py +371 -0
- lionagi/experimental/work2/report.py +289 -0
- lionagi/experimental/work2/schema.py +30 -0
- lionagi/experimental/{work → work2}/tests.py +1 -1
- lionagi/experimental/work2/util.py +0 -0
- lionagi/experimental/work2/work.py +0 -0
- lionagi/experimental/work2/work_function.py +89 -0
- lionagi/experimental/work2/worker.py +12 -0
- lionagi/integrations/storage/storage_util.py +4 -4
- lionagi/integrations/storage/structure_excel.py +268 -0
- lionagi/integrations/storage/to_excel.py +18 -9
- lionagi/libs/__init__.py +4 -0
- lionagi/tests/test_core/generic/__init__.py +0 -0
- lionagi/tests/test_core/generic/test_component.py +89 -0
- lionagi/version.py +1 -1
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/METADATA +1 -1
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/RECORD +43 -27
- lionagi/experimental/work/_logger.py +0 -25
- /lionagi/experimental/{work/exchange.py → report/__init__.py} +0 -0
- /lionagi/experimental/{work/util.py → validator/__init__.py} +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/LICENSE +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/WHEEL +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/top_level.txt +0 -0
lionagi/core/generic/edge.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
"""
|
2
|
-
Module for representing conditions and edges between nodes in a graph
|
2
|
+
Module for representing conditions and edges between nodes in a graph.
|
3
3
|
|
4
|
-
This module provides the base for creating and managing edges that connect
|
5
|
-
within a graph. It includes support for conditional edges, allowing
|
6
|
-
evaluation of connections based on custom logic.
|
4
|
+
This module provides the base for creating and managing edges that connect
|
5
|
+
nodes within a graph. It includes support for conditional edges, allowing
|
6
|
+
the dynamic evaluation of connections based on custom logic.
|
7
7
|
"""
|
8
8
|
|
9
9
|
from typing import Any
|
10
10
|
from pydantic import Field, field_validator
|
11
|
-
from
|
12
|
-
from
|
11
|
+
from .component import BaseComponent
|
12
|
+
from .condition import Condition
|
13
13
|
|
14
14
|
|
15
15
|
class Edge(BaseComponent):
|
@@ -19,12 +19,14 @@ class Edge(BaseComponent):
|
|
19
19
|
Attributes:
|
20
20
|
head (str): The identifier of the head node of the edge.
|
21
21
|
tail (str): The identifier of the tail node of the edge.
|
22
|
-
condition (
|
22
|
+
condition (Condition | None): Optional condition that must be met
|
23
23
|
for the edge to be considered active.
|
24
|
-
label (
|
24
|
+
label (str | None): An optional label for the edge.
|
25
|
+
bundle (bool): A flag indicating if the edge is bundled.
|
25
26
|
|
26
27
|
Methods:
|
27
|
-
check_condition: Evaluates if the condition
|
28
|
+
check_condition: Evaluates if the condition is met.
|
29
|
+
string_condition: Retrieves the condition class source code.
|
28
30
|
"""
|
29
31
|
|
30
32
|
head: str = Field(
|
@@ -37,7 +39,8 @@ class Edge(BaseComponent):
|
|
37
39
|
)
|
38
40
|
condition: Condition | None = Field(
|
39
41
|
default=None,
|
40
|
-
description="
|
42
|
+
description="Optional condition that must be met for the edge "
|
43
|
+
"to be considered active.",
|
41
44
|
)
|
42
45
|
label: str | None = Field(
|
43
46
|
default=None,
|
@@ -51,20 +54,19 @@ class Edge(BaseComponent):
|
|
51
54
|
@field_validator("head", "tail", mode="before")
|
52
55
|
def _validate_head_tail(cls, value):
|
53
56
|
"""
|
54
|
-
Validates
|
57
|
+
Validates head and tail fields to ensure valid node identifiers.
|
55
58
|
|
56
59
|
Args:
|
57
|
-
value: The value of the field being validated.
|
58
|
-
values: A dictionary of all other values on the model.
|
59
|
-
field: The model field being validated.
|
60
|
+
value (Any): The value of the field being validated.
|
60
61
|
|
61
62
|
Returns:
|
62
|
-
The validated value, ensuring it is a valid identifier.
|
63
|
+
str: The validated value, ensuring it is a valid identifier.
|
63
64
|
|
64
65
|
Raises:
|
65
66
|
ValueError: If the validation fails.
|
66
67
|
"""
|
67
|
-
|
68
|
+
|
69
|
+
if isinstance(value, BaseComponent):
|
68
70
|
return value.id_
|
69
71
|
return value
|
70
72
|
|
@@ -73,7 +75,7 @@ class Edge(BaseComponent):
|
|
73
75
|
Evaluates if the condition associated with the edge is met.
|
74
76
|
|
75
77
|
Args:
|
76
|
-
obj (dict[str, Any]):
|
78
|
+
obj (dict[str, Any]): Context for condition evaluation.
|
77
79
|
|
78
80
|
Returns:
|
79
81
|
bool: True if the condition is met, False otherwise.
|
@@ -87,19 +89,23 @@ class Edge(BaseComponent):
|
|
87
89
|
|
88
90
|
def string_condition(self):
|
89
91
|
"""
|
90
|
-
Retrieves the
|
92
|
+
Retrieves the condition class source code as a string.
|
91
93
|
|
92
|
-
This method is useful for serialization and debugging, allowing
|
93
|
-
|
94
|
-
|
94
|
+
This method is useful for serialization and debugging, allowing
|
95
|
+
the condition logic to be inspected or stored in a human-readable
|
96
|
+
format. It employs advanced introspection techniques to locate and
|
97
|
+
extract the exact class definition, handling edge cases like
|
98
|
+
dynamically defined classes or classes defined interactively.
|
95
99
|
|
96
100
|
Returns:
|
97
|
-
str: The source code
|
98
|
-
|
101
|
+
str | None: The condition class source code if available.
|
102
|
+
If the condition is None or the source code cannot be
|
103
|
+
located, this method returns None.
|
99
104
|
|
100
105
|
Raises:
|
101
|
-
TypeError: If the source code
|
102
|
-
|
106
|
+
TypeError: If the condition class source code cannot be found
|
107
|
+
due to the class being defined in a non-standard manner or
|
108
|
+
in the interactive interpreter (__main__ context).
|
103
109
|
"""
|
104
110
|
if self.condition is None:
|
105
111
|
return
|
@@ -139,9 +145,8 @@ class Edge(BaseComponent):
|
|
139
145
|
|
140
146
|
def __str__(self) -> str:
|
141
147
|
"""
|
142
|
-
Returns a simple string representation of the
|
148
|
+
Returns a simple string representation of the Edge.
|
143
149
|
"""
|
144
|
-
|
145
150
|
return (
|
146
151
|
f"Edge (id_={self.id_}, from={self.head}, to={self.tail}, "
|
147
152
|
f"label={self.label})"
|
@@ -149,12 +154,7 @@ class Edge(BaseComponent):
|
|
149
154
|
|
150
155
|
def __repr__(self) -> str:
|
151
156
|
"""
|
152
|
-
Returns a detailed string representation of the
|
153
|
-
|
154
|
-
Examples:
|
155
|
-
>>> edge = Relationship(source_node_id="node1", target_node_id="node2")
|
156
|
-
>>> repr(edge)
|
157
|
-
'Relationship(id_=None, from=node1, to=node2, content=None, metadata=None, label=None)'
|
157
|
+
Returns a detailed string representation of the Edge.
|
158
158
|
"""
|
159
159
|
return (
|
160
160
|
f"Edge(id_={self.id_}, from={self.head}, to={self.tail}, "
|
lionagi/core/graph/graph.py
CHANGED
@@ -92,7 +92,7 @@ class Graph(BaseStructure):
|
|
92
92
|
for node_id, node in self.internal_nodes.items():
|
93
93
|
node_info = node.to_dict()
|
94
94
|
node_info.pop("id_")
|
95
|
-
node_info.update({"class_name": node.class_name
|
95
|
+
node_info.update({"class_name": node.class_name})
|
96
96
|
g.add_node(node_id, **node_info)
|
97
97
|
|
98
98
|
for _edge in list(self.internal_edges.values()):
|
@@ -159,15 +159,16 @@ class ToolManager:
|
|
159
159
|
else:
|
160
160
|
raise ValueError(f"Function {tool} is not registered.")
|
161
161
|
|
162
|
-
if
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
162
|
+
if tools:
|
163
|
+
if isinstance(tools, bool):
|
164
|
+
tool_kwarg = {"tools": self.to_tool_schema_list()}
|
165
|
+
kwargs = tool_kwarg | kwargs
|
166
|
+
|
167
|
+
else:
|
168
|
+
if not isinstance(tools, list):
|
169
|
+
tools = [tools]
|
170
|
+
tool_kwarg = {"tools": func_call.lcall(tools, tool_check)}
|
171
|
+
kwargs = tool_kwarg | kwargs
|
171
172
|
|
172
173
|
return kwargs
|
173
174
|
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
|
3
|
+
# from lionagi import logging as _logging
|
4
|
+
from lionagi.core.generic import BaseComponent
|
5
|
+
from lionagi.experimental.report.util import get_input_output_fields, system_fields
|
6
|
+
|
7
|
+
|
8
|
+
class Form(BaseComponent):
    """A single unit of work described by an ``assignment`` string.

    The assignment has the shape ``"input1, input2 -> output"``. On
    construction it is parsed with ``get_input_output_fields`` into
    ``input_fields`` and ``output_fields``, and every named field that is not
    already declared on the model is added with value ``None``.
    """

    # e.g. "input1, input2 -> output"; parsed in __init__.
    assignment: str = Field(..., examples=["input1, input2 -> output"])

    # Overwritten in __init__ from the parsed ``assignment``.
    input_fields: list[str] = Field(default_factory=list)
    output_fields: list[str] = Field(default_factory=list)

    def __init__(self, **kwargs):
        """Build the form and declare every field named in the assignment.

        Fields named in the assignment that are not already declared on the
        model are added with value ``None``; not every field has to be
        provided (or even declared) at construction time.
        """
        super().__init__(**kwargs)
        self.input_fields, self.output_fields = get_input_output_fields(
            self.assignment
        )
        for field_name in self.input_fields + self.output_fields:
            if field_name not in self.model_fields:
                self._add_field(field_name, value=None)

    @property
    def workable(self):
        """True when the form is not yet filled and every input is present.

        An input counts as present when it is not ``None``. Falsy values such
        as ``0``, ``False`` or ``""`` are valid inputs, matching how
        ``filled`` treats values.
        """
        if self.filled:
            return False
        return all(
            getattr(self, field_name, None) is not None
            for field_name in self.input_fields
        )

    @property
    def work_fields(self):
        """Return ``{name: value}`` for this form's assignment fields.

        Names listed in ``system_fields`` are excluded even if they appear in
        the assignment.
        """
        declared = set(self.input_fields) | set(self.output_fields)
        return {
            k: v
            for k, v in self.to_dict().items()
            if k in declared and k not in system_fields
        }

    @property
    def filled(self):
        """True when every work field holds a non-``None`` value."""
        return all(value is not None for value in self.work_fields.values())

    def fill(self, form: "Form" = None, **kwargs):
        """Set work-field values from another form and/or keyword arguments.

        Values from ``form.work_fields`` are applied first and are overridden
        by explicit ``kwargs``. Only this form's work fields may be set.

        NOTE(review): an earlier docstring said "a field can only be filled
        once", but only refilling a completely filled form is rejected; an
        individual field that already holds a value is overwritten.

        Raises:
            ValueError: If the form is already filled, or if any key is not a
                work field of this form. All keys are checked before anything
                is assigned, so a bad key leaves the form unchanged.
        """
        if self.filled:
            raise ValueError("Form is already filled")

        incoming = form.work_fields if form is not None else {}
        updates = {**incoming, **kwargs}

        # Compute once: to_dict() is not free and the key set cannot change here.
        allowed = self.work_fields
        for k in updates:
            if k not in allowed:
                raise ValueError(f"Field {k} is not a valid work field")

        for k, v in updates.items():
            setattr(self, k, v)
|
@@ -0,0 +1,138 @@
|
|
1
|
+
from typing import Any, Type
|
2
|
+
from pydantic import Field
|
3
|
+
|
4
|
+
# from lionagi import logging as _logging
|
5
|
+
from lionagi.core.generic import BaseComponent
|
6
|
+
from lionagi.experimental.report.form import Form
|
7
|
+
from lionagi.experimental.report.util import get_input_output_fields
|
8
|
+
|
9
|
+
|
10
|
+
class Report(BaseComponent):
    """A collection of forms that together fulfil a top-level ``assignment``.

    ``form_assignments`` lists the assignment of each sub-form and defaults to
    the report's own assignment. One form is created per assignment from
    ``form_template``, every work field of every form is mirrored onto the
    report, and report values that are already set are pushed into the forms.
    """

    assignment: str = Field(..., examples=["input1, input2 -> output"])

    forms: dict[str, Form] = Field(
        default_factory=dict,
        description="A dictionary of forms related to the report, in {assignment: Form} format.",
    )

    form_assignments: list = Field(
        default_factory=list,
        description="assignment for the report",
        examples=[["a, b -> c", "a -> e", "b -> f", "c -> g", "e, f, g -> h"]],
    )

    form_template: Type[Form] = Field(
        Form, description="The template for the forms in the report."
    )

    # Overwritten in __init__ from the parsed ``assignment``.
    input_fields: list[str] = Field(default_factory=list)
    output_fields: list[str] = Field(default_factory=list)

    def __init__(self, **kwargs):
        """Build the report, its forms, and the mirrored report fields.

        Every work field found in the forms that is not already declared on
        the report is added with value ``None``. Report fields that already
        hold a value are then filled into every not-yet-filled form that uses
        them.
        """
        super().__init__(**kwargs)
        self.input_fields, self.output_fields = get_input_output_fields(
            self.assignment
        )

        # Without explicit sub-assignments, the report is a single form.
        if not self.form_assignments:
            self.form_assignments.append(self.assignment)

        # Create forms only for assignments not already present in ``forms``.
        for form_assignment in self.form_assignments:
            if form_assignment not in self.forms:
                self.forms[form_assignment] = self.form_template(
                    assignment=form_assignment
                )

        # Declare on the report any form work field it does not have yet.
        for form in self.forms.values():
            for name in list(form.work_fields):
                if name not in self.model_fields:
                    self._add_field(
                        name, value=None, field=form.model_fields[name]
                    )

        # Push report values that are already set into the forms. Filled
        # forms are skipped because Form.fill raises on them.
        for name in list(self.model_fields):
            value = getattr(self, name, None)
            if value is None:
                continue
            for form in self.forms.values():
                if name in form.work_fields and not form.filled:
                    form.fill(**{name: value})

    @property
    def work_fields(self) -> dict[str, Any]:
        """All work fields across all forms, including intermediate outputs.

        When several forms share a field, the first non-``None`` value wins;
        the field is ``None`` only if no form has a value for it.
        """
        merged: dict[str, Any] = {}
        for form in self.forms.values():
            for k, v in form.work_fields.items():
                if merged.get(k) is None:
                    merged[k] = v
        return merged

    def fill(self, **kwargs):
        """Fill values into both the report and its forms.

        The current work-field values are merged with ``kwargs``, and
        ``kwargs`` take precedence. A report attribute is only set while it
        is still ``None``. Each form that is not yet filled then receives the
        merged values for its own work fields. ``None`` values are not
        forwarded, so they cannot erase a value a form already holds.
        """
        current = self.work_fields
        merged = {**current, **kwargs}

        for k, v in merged.items():
            if k in current and getattr(self, k, None) is None:
                setattr(self, k, v)

        for form in self.forms.values():
            if form.filled:
                continue
            form_fields = form.work_fields
            form.fill(
                **{
                    k: v
                    for k, v in merged.items()
                    if k in form_fields and v is not None
                }
            )

    @property
    def filled(self):
        """True when every work field (see ``work_fields``) is non-``None``."""
        return all(value is not None for value in self.work_fields.values())

    @property
    def workable(self) -> bool:
        """Whether the report can be worked on.

        A report is workable when all of the following hold:

        * it is not yet filled;
        * every input of its own assignment has a non-``None`` value;
        * every field of its own assignment appears among the forms' work
          fields;
        * no output field is produced by more than one form.
        """
        if self.filled:
            return False

        if any(getattr(self, name, None) is None for name in self.input_fields):
            return False

        # Build a fresh list: extending self.input_fields in place would
        # corrupt the report's own input list on every call.
        required = [*self.input_fields, *self.output_fields]
        available = self.work_fields
        if any(name not in available for name in required):
            return False

        outputs = [
            name for form in self.forms.values() for name in form.output_fields
        ]
        return len(outputs) == len(set(outputs))

    @property
    def next_forms(self) -> list[Form] | None:
        """Forms that are currently workable, or ``None`` if there are none."""
        ready = [form for form in self.forms.values() if form.workable]
        return ready or None
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from lionagi.libs import convert
|
2
|
+
|
3
|
+
# Attribute names that Form.work_fields always excludes, even when they appear
# in an assignment string. Kept as one flat list; order is preserved.
system_fields = (
    ["id_", "node_id", "meta", "metadata", "timestamp", "content"]
    + ["assignment", "assignments", "task", "template_name", "version"]
    + ["description", "in_validation_kwargs", "out_validation_kwargs"]
    + ["fix_input", "fix_output", "input_fields", "output_fields", "choices"]
    + ["prompt_fields", "prompt_fields_annotation", "instruction_context"]
    + ["instruction", "instruction_output_fields", "inputs", "outputs"]
    + ["process", "_validate_field", "_process_input", "_process_response"]
    + [
        "_validate_field_choices",
        "_validate_input_choices",
        "_validate_output_choices",
    ]
)
|
38
|
+
|
39
|
+
|
40
|
+
def get_input_output_fields(str_: str) -> tuple[list[str], list[str]]:
    """Parse an assignment string into its input and output field names.

    Each side of ``"->"`` is split on commas, and every name is normalized
    with ``convert.strip_lower``. Empty names, for example those left by a
    trailing comma, are dropped. ``"a, b -> c"`` yields ``(["a", "b"], ["c"])``.

    Args:
        str_: Assignment of the form ``"input1, input2 -> output"``.

    Returns:
        A ``(input_fields, output_fields)`` tuple of lists.

    Raises:
        ValueError: If ``str_`` does not contain exactly one ``"->"``.
    """
    parts = str_.split("->")
    if len(parts) != 2:
        raise ValueError(
            f"Invalid assignment {str_!r}: expected exactly one '->' "
            "separating input and output fields"
        )
    inputs, outputs = parts

    def _names(side: str) -> list[str]:
        # Normalize each comma-separated name and drop blanks, which would
        # otherwise become bogus field names.
        normalized = (convert.strip_lower(item) for item in side.split(","))
        return [name for name in normalized if name]

    return _names(inputs), _names(outputs)
|
@@ -34,16 +34,16 @@ class Tool(Node):
|
|
34
34
|
out = None
|
35
35
|
|
36
36
|
if self.pre_processor:
|
37
|
-
kwargs = await func_call.
|
37
|
+
kwargs = await func_call.call_handler(self.pre_processor, kwargs)
|
38
38
|
try:
|
39
|
-
out = await func_call.
|
39
|
+
out = await func_call.call_handler(self.func, **kwargs)
|
40
40
|
|
41
41
|
except Exception as e:
|
42
42
|
_logging.error(f"Error invoking function {self.func_name}: {e}")
|
43
43
|
return None
|
44
44
|
|
45
45
|
if self.post_processor:
|
46
|
-
return await func_call.
|
46
|
+
return await func_call.call_handler(self.post_processor, out)
|
47
47
|
|
48
48
|
return out
|
49
49
|
|
@@ -0,0 +1,139 @@
|
|
1
|
+
from lionagi.libs import validation_funcs
|
2
|
+
from abc import abstractmethod
|
3
|
+
|
4
|
+
|
5
|
+
from enum import Enum


class Rule:
    """Base class for a single validation rule.

    Keyword arguments given at construction are stored in
    ``validation_kwargs`` and forwarded to the underlying validation
    function. The special ``fix`` keyword is split out into the boolean flag
    ``self.fix``, which subclasses pass separately as ``fix_=``.

    NOTE: ``Rule`` does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced; subclasses are expected to override both methods.
    """

    def __init__(self, **kwargs):
        # Pop "fix" so it is not forwarded a second time next to ``fix_=``.
        self.fix = kwargs.pop("fix", False)
        self.validation_kwargs = kwargs

    @abstractmethod
    def condition(self, **kwargs):
        """Return True when this rule applies to the given context."""

    @abstractmethod
    async def validate(self, value, **kwargs):
        """Validate ``value``; return None when the rule does not apply."""


class ChoiceRule(Rule):
    """Accept values that are members of ``choices``."""

    def condition(self, choices=None):
        return choices is not None

    def check(self, choices=None):
        """Normalize ``choices`` into a list of plain values.

        Lists are returned unchanged. Any other iterable (e.g. an ``Enum``
        class or a tuple) is converted to a list, unwrapping items that have a
        ``.value`` attribute.

        Raises:
            ValueError: If ``choices`` cannot be iterated.
        """
        if choices and not isinstance(choices, list):
            try:
                choices = [getattr(i, "value", i) for i in choices]
            except Exception as e:
                raise ValueError("failed to get choices") from e
        return choices

    def fix_value(self, value, choices=None, **kwargs):
        """Coerce ``value`` onto one of ``choices`` with the ``"enum"`` validator.

        Named ``fix_value`` rather than ``fix`` because ``Rule.__init__``
        stores the boolean ``fix`` flag on the instance, which would shadow a
        method of that name.
        """
        return validation_funcs["enum"](value, choices=choices, fix_=True, **kwargs)

    async def validate(self, value, choices=None, **kwargs):
        """Return ``value`` if it is a valid choice, fixing it when enabled.

        Returns None when no ``choices`` are given (rule not applicable).

        Raises:
            ValueError: If ``value`` is not a valid choice and fixing is off.
        """
        if not self.condition(choices):
            return None
        if value in self.check(choices):
            return value
        if self.fix:
            merged = {**self.validation_kwargs, **kwargs}
            return self.fix_value(value, choices, **merged)
        raise ValueError(f"{value} is not in choices {choices}")


class ActionRequestRule(Rule):
    """Validate action requests when an annotation entry mentions one."""

    def condition(self, annotation=None):
        if annotation is None:
            return False
        return any("actionrequest" in i for i in annotation)

    async def validate(self, value, annotation=None, **kwargs):
        # Extra keywords (e.g. ``choices``/``keys`` forwarded by a Validator)
        # are accepted and ignored.
        if not self.condition(annotation):
            return None
        try:
            return validation_funcs["action"](value)
        except Exception as e:
            raise ValueError("failed to validate field") from e


class BooleanRule(Rule):
    """Validate booleans for annotations that allow bool but not str."""

    def condition(self, annotation=None):
        if annotation is None:
            return False
        return "bool" in annotation and "str" not in annotation

    async def validate(self, value, annotation=None, **kwargs):
        if not self.condition(annotation):
            return None
        try:
            return validation_funcs["bool"](
                value, fix_=self.fix, **self.validation_kwargs
            )
        except Exception as e:
            raise ValueError("failed to validate field") from e


class NumberRule(Rule):
    """Validate numbers for int/float/number annotations that exclude str."""

    def condition(self, annotation=None):
        if annotation is None:
            return False
        return (
            any(i in annotation for i in ["int", "float", "number"])
            and "str" not in annotation
        )

    async def validate(self, value, annotation=None, **kwargs):
        if not self.condition(annotation):
            return None
        # Per-call copy: float-specific options must not leak into later
        # validations that use the same rule instance.
        options = dict(self.validation_kwargs)
        if "float" in annotation:
            options["num_type"] = float
            options.setdefault("precision", 32)
        try:
            return validation_funcs["number"](value, fix_=self.fix, **options)
        except Exception as e:
            raise ValueError("failed to validate field") from e


class DictRule(Rule):
    """Validate dicts for annotations mentioning dict."""

    def condition(self, annotation=None):
        if annotation is None:
            return False
        return "dict" in annotation

    async def validate(self, value, annotation=None, keys=None, **kwargs):
        """Validate ``value`` as a dict, optionally against expected ``keys``.

        Raises:
            ValueError: If validation fails, or if the annotation also allows
                ``str`` and no ``keys`` were given.
        """
        if not self.condition(annotation):
            return None
        if "str" not in annotation or keys:
            try:
                return validation_funcs["dict"](
                    value, keys=keys, fix_=self.fix, **self.validation_kwargs
                )
            except Exception as e:
                raise ValueError("failed to validate field") from e
        raise ValueError("failed to validate field")


class StringRule(Rule):
    """Validate strings for annotations mentioning str."""

    def condition(self, annotation=None):
        if annotation is None:
            return False
        return "str" in annotation

    async def validate(self, value, annotation=None, **kwargs):
        if not self.condition(annotation):
            return None
        try:
            return validation_funcs["str"](
                value, fix_=self.fix, **self.validation_kwargs
            )
        except Exception as e:
            raise ValueError("failed to validate field") from e


class DEFAULT_RULES(Enum):
    # Member values are the rule classes themselves, not instances.
    CHOICE = ChoiceRule
    ACTION_REQUEST = ActionRequestRule
    BOOL = BooleanRule
    NUMBER = NumberRule
    DICT = DictRule
    STR = StringRule
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
from .rule import DEFAULT_RULES, Rule
|
3
|
+
|
4
|
+
|
5
|
+
# Default rule classes keyed by rule name.
rules_ = {
    "choice": DEFAULT_RULES.CHOICE.value,
    "actionrequest": DEFAULT_RULES.ACTION_REQUEST.value,
    "bool": DEFAULT_RULES.BOOL.value,
    "number": DEFAULT_RULES.NUMBER.value,
    "dict": DEFAULT_RULES.DICT.value,
    "str": DEFAULT_RULES.STR.value,
}

# Default order in which rules are tried.
order_ = [
    "choice",
    "actionrequest",
    "bool",
    "number",
    "dict",
    "str",
]


class Validator(BaseModel):
    """Apply a sequence of rules to a value.

    ``rules`` contains every rule this validator can apply to data, and
    ``order`` determines the sequence in which they are tried. A rule that is
    not named in ``order`` is never applied. The first rule that returns a
    non-``None`` result wins.
    """

    # Rule objects are plain Python classes, not pydantic models; pydantic
    # refuses to build a schema for them without this setting.
    model_config = {"arbitrary_types_allowed": True}

    rules: dict[str, Rule] = Field(
        # Instantiate the defaults: the registry holds classes, and calling
        # ``validate`` on a class would bind the value as ``self``.
        default_factory=lambda: {name: rule_cls() for name, rule_cls in rules_.items()},
        description="The rules to be used for validation.",
    )

    order: list[str] = Field(
        default_factory=lambda: list(order_),
        description="The order in which the rules should be applied.",
    )

    async def validate(self, value, *args, strict=False, **kwargs):
        """Run the rules in ``order`` and return the first non-``None`` result.

        Args:
            value: The value to validate.
            *args, **kwargs: Forwarded to every rule's ``validate``.
            strict: If True, raise when no rule produced a result.

        Returns:
            The first non-``None`` rule result, or ``value`` unchanged when no
            rule applied and ``strict`` is False.

        Raises:
            ValueError: If a rule raises (chained from the original error), or
                if ``strict`` is True and no rule produced a result.
        """
        for name in self.order:
            if name not in self.rules:
                continue
            try:
                result = await self.rules[name].validate(value, *args, **kwargs)
            except Exception as e:
                raise ValueError("failed to validate field") from e
            if result is not None:
                return result

        if strict:
            raise ValueError("failed to validate field")
        return value
|