lionagi 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- lionagi/core/execute/structure_executor.py +21 -1
- lionagi/core/flow/monoflow/ReAct.py +3 -1
- lionagi/core/flow/monoflow/followup.py +3 -1
- lionagi/core/generic/component.py +197 -120
- lionagi/core/generic/condition.py +2 -0
- lionagi/core/generic/edge.py +33 -33
- lionagi/core/graph/graph.py +1 -1
- lionagi/core/tool/tool_manager.py +10 -9
- lionagi/experimental/report/form.py +64 -0
- lionagi/experimental/report/report.py +138 -0
- lionagi/experimental/report/util.py +47 -0
- lionagi/experimental/tool/schema.py +3 -3
- lionagi/experimental/tool/tool_manager.py +1 -1
- lionagi/experimental/validator/rule.py +139 -0
- lionagi/experimental/validator/validator.py +56 -0
- lionagi/experimental/work/__init__.py +10 -0
- lionagi/experimental/work/async_queue.py +54 -0
- lionagi/experimental/work/schema.py +60 -17
- lionagi/experimental/work/work_function.py +55 -77
- lionagi/experimental/work/worker.py +56 -12
- lionagi/experimental/work2/__init__.py +0 -0
- lionagi/experimental/work2/form.py +371 -0
- lionagi/experimental/work2/report.py +289 -0
- lionagi/experimental/work2/schema.py +30 -0
- lionagi/experimental/{work → work2}/tests.py +1 -1
- lionagi/experimental/work2/util.py +0 -0
- lionagi/experimental/work2/work.py +0 -0
- lionagi/experimental/work2/work_function.py +89 -0
- lionagi/experimental/work2/worker.py +12 -0
- lionagi/integrations/storage/storage_util.py +4 -4
- lionagi/integrations/storage/structure_excel.py +268 -0
- lionagi/integrations/storage/to_excel.py +18 -9
- lionagi/libs/__init__.py +4 -0
- lionagi/tests/test_core/generic/__init__.py +0 -0
- lionagi/tests/test_core/generic/test_component.py +89 -0
- lionagi/version.py +1 -1
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/METADATA +1 -1
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/RECORD +43 -27
- lionagi/experimental/work/_logger.py +0 -25
- /lionagi/experimental/{work/exchange.py → report/__init__.py} +0 -0
- /lionagi/experimental/{work/util.py → validator/__init__.py} +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/LICENSE +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/WHEEL +0 -0
- {lionagi-0.1.1.dist-info → lionagi-0.1.2.dist-info}/top_level.txt +0 -0
lionagi/core/generic/edge.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
"""
|
2
|
-
Module for representing conditions and edges between nodes in a graph
|
2
|
+
Module for representing conditions and edges between nodes in a graph.
|
3
3
|
|
4
|
-
This module provides the base for creating and managing edges that connect
|
5
|
-
within a graph. It includes support for conditional edges, allowing
|
6
|
-
evaluation of connections based on custom logic.
|
4
|
+
This module provides the base for creating and managing edges that connect
|
5
|
+
nodes within a graph. It includes support for conditional edges, allowing
|
6
|
+
the dynamic evaluation of connections based on custom logic.
|
7
7
|
"""
|
8
8
|
|
9
9
|
from typing import Any
|
10
10
|
from pydantic import Field, field_validator
|
11
|
-
from
|
12
|
-
from
|
11
|
+
from .component import BaseComponent
|
12
|
+
from .condition import Condition
|
13
13
|
|
14
14
|
|
15
15
|
class Edge(BaseComponent):
|
@@ -19,12 +19,14 @@ class Edge(BaseComponent):
|
|
19
19
|
Attributes:
|
20
20
|
head (str): The identifier of the head node of the edge.
|
21
21
|
tail (str): The identifier of the tail node of the edge.
|
22
|
-
condition (
|
22
|
+
condition (Condition | None): Optional condition that must be met
|
23
23
|
for the edge to be considered active.
|
24
|
-
label (
|
24
|
+
label (str | None): An optional label for the edge.
|
25
|
+
bundle (bool): A flag indicating if the edge is bundled.
|
25
26
|
|
26
27
|
Methods:
|
27
|
-
check_condition: Evaluates if the condition
|
28
|
+
check_condition: Evaluates if the condition is met.
|
29
|
+
string_condition: Retrieves the condition class source code.
|
28
30
|
"""
|
29
31
|
|
30
32
|
head: str = Field(
|
@@ -37,7 +39,8 @@ class Edge(BaseComponent):
|
|
37
39
|
)
|
38
40
|
condition: Condition | None = Field(
|
39
41
|
default=None,
|
40
|
-
description="
|
42
|
+
description="Optional condition that must be met for the edge "
|
43
|
+
"to be considered active.",
|
41
44
|
)
|
42
45
|
label: str | None = Field(
|
43
46
|
default=None,
|
@@ -51,20 +54,19 @@ class Edge(BaseComponent):
|
|
51
54
|
@field_validator("head", "tail", mode="before")
|
52
55
|
def _validate_head_tail(cls, value):
|
53
56
|
"""
|
54
|
-
Validates
|
57
|
+
Validates head and tail fields to ensure valid node identifiers.
|
55
58
|
|
56
59
|
Args:
|
57
|
-
value: The value of the field being validated.
|
58
|
-
values: A dictionary of all other values on the model.
|
59
|
-
field: The model field being validated.
|
60
|
+
value (Any): The value of the field being validated.
|
60
61
|
|
61
62
|
Returns:
|
62
|
-
The validated value, ensuring it is a valid identifier.
|
63
|
+
str: The validated value, ensuring it is a valid identifier.
|
63
64
|
|
64
65
|
Raises:
|
65
66
|
ValueError: If the validation fails.
|
66
67
|
"""
|
67
|
-
|
68
|
+
|
69
|
+
if isinstance(value, BaseComponent):
|
68
70
|
return value.id_
|
69
71
|
return value
|
70
72
|
|
@@ -73,7 +75,7 @@ class Edge(BaseComponent):
|
|
73
75
|
Evaluates if the condition associated with the edge is met.
|
74
76
|
|
75
77
|
Args:
|
76
|
-
obj (dict[str, Any]):
|
78
|
+
obj (dict[str, Any]): Context for condition evaluation.
|
77
79
|
|
78
80
|
Returns:
|
79
81
|
bool: True if the condition is met, False otherwise.
|
@@ -87,19 +89,23 @@ class Edge(BaseComponent):
|
|
87
89
|
|
88
90
|
def string_condition(self):
|
89
91
|
"""
|
90
|
-
Retrieves the
|
92
|
+
Retrieves the condition class source code as a string.
|
91
93
|
|
92
|
-
This method is useful for serialization and debugging, allowing
|
93
|
-
|
94
|
-
|
94
|
+
This method is useful for serialization and debugging, allowing
|
95
|
+
the condition logic to be inspected or stored in a human-readable
|
96
|
+
format. It employs advanced introspection techniques to locate and
|
97
|
+
extract the exact class definition, handling edge cases like
|
98
|
+
dynamically defined classes or classes defined interactively.
|
95
99
|
|
96
100
|
Returns:
|
97
|
-
str: The source code
|
98
|
-
|
101
|
+
str | None: The condition class source code if available.
|
102
|
+
If the condition is None or the source code cannot be
|
103
|
+
located, this method returns None.
|
99
104
|
|
100
105
|
Raises:
|
101
|
-
TypeError: If the source code
|
102
|
-
|
106
|
+
TypeError: If the condition class source code cannot be found
|
107
|
+
due to the class being defined in a non-standard manner or
|
108
|
+
in the interactive interpreter (__main__ context).
|
103
109
|
"""
|
104
110
|
if self.condition is None:
|
105
111
|
return
|
@@ -139,9 +145,8 @@ class Edge(BaseComponent):
|
|
139
145
|
|
140
146
|
def __str__(self) -> str:
|
141
147
|
"""
|
142
|
-
Returns a simple string representation of the
|
148
|
+
Returns a simple string representation of the Edge.
|
143
149
|
"""
|
144
|
-
|
145
150
|
return (
|
146
151
|
f"Edge (id_={self.id_}, from={self.head}, to={self.tail}, "
|
147
152
|
f"label={self.label})"
|
@@ -149,12 +154,7 @@ class Edge(BaseComponent):
|
|
149
154
|
|
150
155
|
def __repr__(self) -> str:
|
151
156
|
"""
|
152
|
-
Returns a detailed string representation of the
|
153
|
-
|
154
|
-
Examples:
|
155
|
-
>>> edge = Relationship(source_node_id="node1", target_node_id="node2")
|
156
|
-
>>> repr(edge)
|
157
|
-
'Relationship(id_=None, from=node1, to=node2, content=None, metadata=None, label=None)'
|
157
|
+
Returns a detailed string representation of the Edge.
|
158
158
|
"""
|
159
159
|
return (
|
160
160
|
f"Edge(id_={self.id_}, from={self.head}, to={self.tail}, "
|
lionagi/core/graph/graph.py
CHANGED
@@ -92,7 +92,7 @@ class Graph(BaseStructure):
|
|
92
92
|
for node_id, node in self.internal_nodes.items():
|
93
93
|
node_info = node.to_dict()
|
94
94
|
node_info.pop("id_")
|
95
|
-
node_info.update({"class_name": node.class_name
|
95
|
+
node_info.update({"class_name": node.class_name})
|
96
96
|
g.add_node(node_id, **node_info)
|
97
97
|
|
98
98
|
for _edge in list(self.internal_edges.values()):
|
@@ -159,15 +159,16 @@ class ToolManager:
|
|
159
159
|
else:
|
160
160
|
raise ValueError(f"Function {tool} is not registered.")
|
161
161
|
|
162
|
-
if
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
162
|
+
if tools:
|
163
|
+
if isinstance(tools, bool):
|
164
|
+
tool_kwarg = {"tools": self.to_tool_schema_list()}
|
165
|
+
kwargs = tool_kwarg | kwargs
|
166
|
+
|
167
|
+
else:
|
168
|
+
if not isinstance(tools, list):
|
169
|
+
tools = [tools]
|
170
|
+
tool_kwarg = {"tools": func_call.lcall(tools, tool_check)}
|
171
|
+
kwargs = tool_kwarg | kwargs
|
171
172
|
|
172
173
|
return kwargs
|
173
174
|
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
|
3
|
+
# from lionagi import logging as _logging
|
4
|
+
from lionagi.core.generic import BaseComponent
|
5
|
+
from lionagi.experimental.report.util import get_input_output_fields, system_fields
|
6
|
+
|
7
|
+
|
8
|
+
class Form(BaseComponent):
    """A unit of work described by an ``"inputs -> outputs"`` assignment.

    The ``assignment`` string (e.g. ``"input1, input2 -> output"``) is parsed
    at construction into ``input_fields`` and ``output_fields``. Any of those
    names that is not already a model field is added dynamically with a value
    of ``None``.
    """

    assignment: str = Field(..., examples=["input1, input2 -> output"])

    input_fields: list[str] = Field(default_factory=list)
    output_fields: list[str] = Field(default_factory=list)

    def __init__(self, **kwargs):
        """Build the form and declare the fields named in its assignment.

        Each input/output name that is not already a model field is added with
        value ``None``. Fields do not have to be filled, or even declared, at
        initialization.
        """
        super().__init__(**kwargs)
        self.input_fields, self.output_fields = get_input_output_fields(self.assignment)
        for name in self.input_fields + self.output_fields:
            if name not in self.model_fields:
                self._add_field(name, value=None)

    @property
    def workable(self):
        """Whether the form can be worked on: not yet filled, all inputs present.

        NOTE: inputs are tested for truthiness, so falsy values such as ``0``
        or ``""`` count as missing here, whereas ``filled`` only checks for
        ``None``.
        """
        if self.filled:
            return False
        return all(getattr(self, name, None) for name in self.input_fields)

    @property
    def work_fields(self):
        """Return ``{name: value}`` for this form's input and output fields.

        Names listed in ``system_fields`` are excluded.
        """
        # Build the membership set once instead of concatenating the two
        # lists for every key.
        declared = set(self.input_fields + self.output_fields)
        return {
            k: v
            for k, v in self.to_dict().items()
            if k not in system_fields and k in declared
        }

    @property
    def filled(self):
        """True when every work field has a non-None value."""
        return all(value is not None for value in self.work_fields.values())

    def fill(self, form: "Form" = None, **kwargs):
        """Fill this form's work fields from another form and/or keyword values.

        The work fields of ``form`` (if given) are merged with ``kwargs``;
        explicit ``kwargs`` take precedence. Only this form's work fields may
        be filled. ``None`` values are skipped, so an unfilled field on the
        source form cannot overwrite a value already filled here.

        Args:
            form: Optional form whose work fields are copied.
            **kwargs: Field values to set.

        Raises:
            ValueError: If this form is already filled, or if any key is not
                one of this form's work fields. Keys are all checked before
                anything is set, so a bad key leaves the form unchanged.
        """
        if self.filled:
            raise ValueError("Form is already filled")

        fields = form.work_fields if form else {}
        kwargs = {**fields, **kwargs}

        # The set of work-field names does not change while values are set,
        # so compute it once (work_fields rebuilds to_dict() on every access).
        allowed = set(self.work_fields)
        for k in kwargs:
            if k not in allowed:
                raise ValueError(f"Field {k} is not a valid work field")

        for k, v in kwargs.items():
            # Do not let a None (e.g. an unfilled field of the source form)
            # clobber a value that has already been filled.
            if v is not None:
                setattr(self, k, v)
|
@@ -0,0 +1,138 @@
|
|
1
|
+
from typing import Any, Type
|
2
|
+
from pydantic import Field
|
3
|
+
|
4
|
+
# from lionagi import logging as _logging
|
5
|
+
from lionagi.core.generic import BaseComponent
|
6
|
+
from lionagi.experimental.report.form import Form
|
7
|
+
from lionagi.experimental.report.util import get_input_output_fields
|
8
|
+
|
9
|
+
|
10
|
+
class Report(BaseComponent):
    """A collection of forms that together fulfil one overall assignment.

    ``assignment`` describes the report's own inputs and outputs
    (``"inputs -> outputs"``). ``form_assignments`` lists the assignments of
    the individual forms. One form per assignment is created from
    ``form_template`` and stored in ``forms``, keyed by its assignment string.
    """

    assignment: str = Field(..., examples=["input1, input2 -> output"])

    forms: dict[str, Form] = Field(
        default_factory=dict,
        description="A dictionary of forms related to the report, in {assignment: Form} format.",
    )

    form_assignments: list = Field(
        default_factory=list,
        description="assignment for the report",
        examples=[["a, b -> c", "a -> e", "b -> f", "c -> g", "e, f, g -> h"]],
    )

    form_template: Type[Form] = Field(
        Form, description="The template for the forms in the report."
    )

    input_fields: list[str] = Field(default_factory=list)
    output_fields: list[str] = Field(default_factory=list)

    def __init__(self, **kwargs):
        """Build the report, its forms, and the report-level work fields.

        Every work field of every form that is not already a report field is
        declared on the report with value ``None``. Values already set on the
        report are then pushed down into the forms that use them.
        """
        super().__init__(**kwargs)
        self.input_fields, self.output_fields = get_input_output_fields(self.assignment)

        # Without explicit form assignments, the report is a single form with
        # the report's own assignment.
        if not self.form_assignments:
            self.form_assignments.append(self.assignment)

        # Create one form per assignment. A form already supplied for the same
        # assignment is kept as-is.
        for form_assignment in self.form_assignments:
            if form_assignment not in self.forms:
                self.forms[form_assignment] = self.form_template(assignment=form_assignment)

        # Declare every form work field on the report (copying the form's
        # field definition), with value None.
        for form in self.forms.values():
            for name in form.work_fields:
                if name not in self.model_fields:
                    self._add_field(name, value=None, field=form.model_fields[name])

        # Push report values down into the forms that use them. Forms that are
        # already filled are skipped, because Form.fill raises on a filled form.
        for name in list(self.model_fields):
            value = getattr(self, name, None)
            if value is None:
                continue
            for form in self.forms.values():
                if name in form.work_fields and not form.filled:
                    form.fill(**{name: value})

    @property
    def work_fields(self) -> dict[str, Any]:
        """All work fields across all forms, including intermediate outputs.

        When a field appears in several forms, a non-None value is preferred
        over ``None``. Otherwise the first form's value is kept.
        """
        all_fields = {}
        for form in self.forms.values():
            for k, v in form.work_fields.items():
                # Covers both "not seen yet" and "seen only as None".
                if all_fields.get(k) is None:
                    all_fields[k] = v
        return all_fields

    def fill(self, **kwargs):
        """Fill values into both the report and its unfilled forms.

        Report attributes are set only where currently ``None``. Forms receive
        only their own work fields, and never ``None`` values, so that a
        field already filled on a form is not overwritten with ``None``.
        """
        # Report work fields derive from the forms only, so setting report
        # attributes below does not change them; compute once.
        work_fields = self.work_fields
        kwargs = {**work_fields, **kwargs}
        for k, v in kwargs.items():
            if k in work_fields and getattr(self, k, None) is None:
                setattr(self, k, v)

        for form in self.forms.values():
            if not form.filled:
                form_fields = form.work_fields
                _kwargs = {
                    k: v
                    for k, v in kwargs.items()
                    if k in form_fields and v is not None
                }
                form.fill(**_kwargs)

    @property
    def filled(self):
        """True when every work field across all forms has a non-None value."""
        return all(value is not None for value in self.work_fields.values())

    @property
    def workable(self) -> bool:
        """Whether the report can be worked on.

        Returns False if any of the following holds:
        - the report is already filled;
        - any of the report's own input fields is missing (falsy);
        - any of its input/output fields is not a work field of some form;
        - an output field is declared by more than one form.
        Otherwise returns True.
        """
        if self.filled:
            return False

        # NOTE: truthiness test, so falsy inputs such as 0 or "" count as missing.
        if not all(getattr(self, name, None) for name in self.input_fields):
            return False

        # Build a new list. Extending self.input_fields in place (via an alias)
        # would permanently append the outputs to the report's inputs.
        required = [*self.input_fields, *self.output_fields]
        work_fields = self.work_fields
        if any(name not in work_fields for name in required):
            return False

        # No output field may be produced by more than one form.
        outputs = [name for form in self.forms.values() for name in form.output_fields]
        return len(outputs) == len(set(outputs))

    @property
    def next_forms(self) -> list[Form] | None:
        """Forms that are currently workable, or None if there are none."""
        ready = [form for form in self.forms.values() if form.workable]
        return ready or None
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from lionagi.libs import convert
|
2
|
+
|
3
|
+
# Names that are never treated as work fields: Form.work_fields drops any key
# found in this list. Kept as a single flat list; the groups below are only
# for readability and are concatenated in order.
system_fields = (
    ["id_", "node_id", "meta", "metadata", "timestamp", "content"]
    + ["assignment", "assignments", "task", "template_name", "version", "description"]
    + ["in_validation_kwargs", "out_validation_kwargs", "fix_input", "fix_output"]
    + ["input_fields", "output_fields", "choices"]
    + [
        "prompt_fields",
        "prompt_fields_annotation",
        "instruction_context",
        "instruction",
        "instruction_output_fields",
    ]
    + ["inputs", "outputs", "process"]
    # Underscore-prefixed names.
    + [
        "_validate_field",
        "_process_input",
        "_process_response",
        "_validate_field_choices",
        "_validate_input_choices",
        "_validate_output_choices",
    ]
)
|
38
|
+
|
39
|
+
|
40
|
+
def get_input_output_fields(str_: str) -> tuple[list[str], list[str]]:
    """Split an ``"inputs -> outputs"`` assignment into two lists of names.

    Each side is split on commas and every name is normalized with
    ``convert.strip_lower``. Empty names (e.g. from a trailing comma) are
    dropped.

    Args:
        str_: Assignment string such as ``"input1, input2 -> output"``.

    Returns:
        tuple[list[str], list[str]]: ``(input_fields, output_fields)``.

    Raises:
        ValueError: If ``str_`` does not contain exactly one ``"->"``.
    """
    parts = str_.split("->")
    if len(parts) != 2:
        raise ValueError(
            f"Invalid assignment {str_!r}: expected exactly one '->' "
            "separating input fields from output fields."
        )
    inputs, outputs = parts

    def _names(side: str) -> list[str]:
        # Normalize each comma-separated name and drop empty ones.
        normalized = (convert.strip_lower(name) for name in side.split(","))
        return [name for name in normalized if name]

    return _names(inputs), _names(outputs)
|
@@ -34,16 +34,16 @@ class Tool(Node):
|
|
34
34
|
out = None
|
35
35
|
|
36
36
|
if self.pre_processor:
|
37
|
-
kwargs = await func_call.
|
37
|
+
kwargs = await func_call.call_handler(self.pre_processor, kwargs)
|
38
38
|
try:
|
39
|
-
out = await func_call.
|
39
|
+
out = await func_call.call_handler(self.func, **kwargs)
|
40
40
|
|
41
41
|
except Exception as e:
|
42
42
|
_logging.error(f"Error invoking function {self.func_name}: {e}")
|
43
43
|
return None
|
44
44
|
|
45
45
|
if self.post_processor:
|
46
|
-
return await func_call.
|
46
|
+
return await func_call.call_handler(self.post_processor, out)
|
47
47
|
|
48
48
|
return out
|
49
49
|
|
@@ -0,0 +1,139 @@
|
|
1
|
+
from lionagi.libs import validation_funcs
|
2
|
+
from abc import abstractmethod
|
3
|
+
|
4
|
+
|
5
|
+
class Rule:
    """Base class for a single validation rule.

    Keyword arguments given at construction are kept in ``validation_kwargs``,
    which subclasses forward to the underlying validation function. The
    ``fix`` flag is extracted into ``self.fix`` and removed from
    ``validation_kwargs``. Subclasses already pass it explicitly as
    ``fix_=self.fix``, so leaving it in would also forward a stray ``fix=``
    keyword.

    NOTE: ``abstractmethod`` is not enforced here because ``Rule`` does not
    use ``ABCMeta``; subclasses are expected to override both methods.
    """

    def __init__(self, **kwargs):
        self.fix = kwargs.pop("fix", False)
        self.validation_kwargs = kwargs

    @abstractmethod
    def condition(self, **kwargs):
        """Return True when this rule applies to the given context."""

    @abstractmethod
    async def validate(self, value, **kwargs):
        """Validate (and optionally fix) ``value``.

        Subclasses return None when the rule does not apply.
        """
|
18
|
+
|
19
|
+
|
20
|
+
class ChoiceRule(Rule):
    """Accept a value only if it is one of the allowed ``choices``."""

    def condition(self, choices=None):
        """Apply only when a ``choices`` collection is supplied."""
        return choices is not None

    def check(self, choices=None):
        """Normalize ``choices`` to a list.

        A truthy non-list iterable (e.g. an ``Enum`` class) is converted to the
        list of its members' ``.value``.

        Raises:
            ValueError: If the members have no ``.value``.
        """
        if choices and not isinstance(choices, list):
            try:
                choices = [i.value for i in choices]
            except Exception as e:
                raise ValueError(f"failed to get choices") from e
        return choices

    def fix(self, value, choices=None, **kwargs):
        """Coerce ``value`` toward ``choices`` with the ``enum`` validation function."""
        v_ = validation_funcs["enum"](value, choices=choices, fix_=True, **kwargs)
        return v_

    async def validate(self, value, choices=None, **kwargs):
        """Return ``value`` if allowed, a fixed value if ``self.fix``, else raise.

        Returns None when no ``choices`` are given (rule not applicable).

        Raises:
            ValueError: If ``value`` is not among ``choices`` and fixing is off.
        """
        if not self.condition(choices):
            return None
        if value in self.check(choices):
            return value
        if self.fix:
            merged = {**self.validation_kwargs, **kwargs}
            # fix_=True is passed explicitly by fix(); a leftover "fix" key
            # would be forwarded as an extra keyword.
            merged.pop("fix", None)
            # Rule.__init__ stores the boolean flag as self.fix, which shadows
            # the fix *method* on instances; look the method up on the class.
            return type(self).fix(self, value, choices, **merged)
        raise ValueError(f"{value} is not in choices {choices}")
|
45
|
+
|
46
|
+
|
47
|
+
class ActionRequestRule(Rule):
    """Validate action-request values when the annotation mentions one."""

    def condition(self, annotation=None):
        """True if any annotation entry contains ``"actionrequest"``.

        ``annotation`` is presumably an iterable of type-name strings (TODO
        confirm). Each entry is substring-tested. With no annotation the rule
        does not apply; iterating ``None`` would raise TypeError.
        """
        if not annotation:
            return False
        return any("actionrequest" in i for i in annotation)

    async def validate(self, value, annotation=None):
        """Validate ``value`` with the ``action`` validator; None if not applicable.

        Raises:
            ValueError: If the underlying validation fails.
        """
        if self.condition(annotation):
            try:
                return validation_funcs["action"](value)
            except Exception as e:
                raise ValueError(f"failed to validate field") from e
|
58
|
+
|
59
|
+
|
60
|
+
class BooleanRule(Rule):
    """Validate booleans when the annotation mentions ``bool`` but not ``str``."""

    def condition(self, annotation=None):
        """True if ``"bool"`` is in ``annotation`` and ``"str"`` is not.

        With no annotation the rule does not apply; ``"bool" in None`` would
        raise TypeError.
        """
        if not annotation:
            return False
        return "bool" in annotation and "str" not in annotation

    async def validate(self, value, annotation=None):
        """Validate ``value`` with the ``bool`` validator; None if not applicable.

        Raises:
            ValueError: If the underlying validation fails.
        """
        if self.condition(annotation):
            try:
                return validation_funcs["bool"](
                    value, fix_=self.fix, **self.validation_kwargs
                )
            except Exception as e:
                raise ValueError(f"failed to validate field") from e
|
73
|
+
|
74
|
+
|
75
|
+
class NumberRule(Rule):
    """Validate numbers when the annotation mentions int/float/number but not str."""

    def condition(self, annotation=None):
        """True if the annotation names a numeric type and not ``str``.

        With no annotation the rule does not apply.
        """
        if not annotation:
            return False
        return (
            any(i in annotation for i in ["int", "float", "number"])
            and "str" not in annotation
        )

    async def validate(self, value, annotation=None):
        """Validate ``value`` with the ``number`` validator; None if not applicable.

        For float annotations, ``num_type=float`` is passed and ``precision``
        defaults to 32 unless configured.

        Raises:
            ValueError: If the underlying validation fails.
        """
        if not self.condition(annotation):
            return None

        # Work on a per-call copy: writing into self.validation_kwargs would
        # make the float-specific settings stick for later (e.g. int) calls.
        kwargs = dict(self.validation_kwargs)
        if "float" in annotation:
            kwargs["num_type"] = float
            kwargs.setdefault("precision", 32)

        try:
            return validation_funcs["number"](value, fix_=self.fix, **kwargs)
        except Exception as e:
            raise ValueError(f"failed to validate field") from e
|
96
|
+
|
97
|
+
|
98
|
+
class DictRule(Rule):
    """Validate dictionaries when the annotation mentions ``dict``."""

    def condition(self, annotation=None):
        """True if ``"dict"`` is in ``annotation``; False when no annotation is given."""
        if not annotation:
            return False
        return "dict" in annotation

    async def validate(self, value, annotation=None, keys=None):
        """Validate ``value`` with the ``dict`` validator; None if not applicable.

        If the annotation also mentions ``str`` and no ``keys`` are given, the
        rule raises instead of validating.

        Raises:
            ValueError: If validation fails, or in the str-without-keys case.
        """
        if not self.condition(annotation):
            return None
        # NOTE(review): raising here (rather than returning None) means a
        # dict|str field without keys never reaches StringRule — confirm intent.
        if "str" in annotation and not keys:
            raise ValueError(f"failed to validate field")
        try:
            return validation_funcs["dict"](
                value, keys=keys, fix_=self.fix, **self.validation_kwargs
            )
        except Exception as e:
            raise ValueError(f"failed to validate field") from e
|
113
|
+
|
114
|
+
|
115
|
+
class StringRule(Rule):
    """Validate strings when the annotation mentions ``str``."""

    def condition(self, annotation=None):
        """True if ``"str"`` is in ``annotation``; False when no annotation is given."""
        if not annotation:
            return False
        return "str" in annotation

    async def validate(self, value, annotation=None):
        """Validate ``value`` with the ``str`` validator; None if not applicable.

        Raises:
            ValueError: If the underlying validation fails.
        """
        if self.condition(annotation):
            try:
                return validation_funcs["str"](
                    value, fix_=self.fix, **self.validation_kwargs
                )
            except Exception as e:
                raise ValueError(f"failed to validate field") from e
|
128
|
+
|
129
|
+
|
130
|
+
from enum import Enum
|
131
|
+
|
132
|
+
|
133
|
+
# Built-in rule registry. Each member's ``.value`` is a rule *class* (not an
# instance); members are listed in declaration order.
DEFAULT_RULES = Enum(
    "DEFAULT_RULES",
    [
        ("CHOICE", ChoiceRule),
        ("ACTION_REQUEST", ActionRequestRule),
        ("BOOL", BooleanRule),
        ("NUMBER", NumberRule),
        ("DICT", DictRule),
        ("STR", StringRule),
    ],
    module=__name__,
)
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
from .rule import DEFAULT_RULES, Rule
|
3
|
+
|
4
|
+
|
5
|
+
# Rule names in the sequence the Validator tries them by default. A rule that
# is registered in ``rules_`` but missing from the order is never applied.
order_ = [
    "choice",
    "actionrequest",
    "bool",
    "number",
    "dict",
    "str",
]

# Name -> rule class (the ``.value`` of each DEFAULT_RULES member), paired
# positionally with the names in ``order_``.
rules_ = dict(
    zip(
        order_,
        (
            DEFAULT_RULES.CHOICE.value,
            DEFAULT_RULES.ACTION_REQUEST.value,
            DEFAULT_RULES.BOOL.value,
            DEFAULT_RULES.NUMBER.value,
            DEFAULT_RULES.DICT.value,
            DEFAULT_RULES.STR.value,
        ),
    )
)
|
22
|
+
|
23
|
+
|
24
|
+
class Validator(BaseModel):
    """Apply an ordered sequence of validation rules to a value.

    ``rules`` maps a rule name to a rule instance; ``order`` lists the names
    to try, in sequence. A rule whose name is not in ``order`` is never
    applied. The first rule that returns a non-None result determines the
    validated value.
    """

    # Rule is a plain class, not a pydantic model; pydantic needs this to
    # build a schema for ``dict[str, Rule]``.
    model_config = {"arbitrary_types_allowed": True}

    rules: dict[str, Rule] = Field(
        # rules_ maps names to rule *classes*; build fresh instances for each
        # validator so validate() is called on bound methods.
        default_factory=lambda: {name: rule_cls() for name, rule_cls in rules_.items()},
        description="The rules to be used for validation.",
    )

    order: list[str] = Field(
        default_factory=lambda: list(order_),
        description="The order in which the rules should be applied.",
    )

    async def validate(self, value, *args, strict=False, **kwargs):
        """Validate ``value`` with the first applicable rule.

        Args:
            value: The value to validate.
            *args: Positional arguments forwarded to every rule's ``validate``.
            strict: If True, raise when no rule applies instead of returning
                ``value`` unchanged.
            **kwargs: Keyword arguments for the rules. Each rule receives only
                the keywords its ``validate`` signature accepts, because rules
                take different keywords (``choices``, ``annotation``, ``keys``).

        Returns:
            The first non-None rule result, or ``value`` if no rule applied and
            ``strict`` is False.

        Raises:
            ValueError: If a rule raises, or if ``strict`` and no rule applied.
        """
        import inspect

        for name in self.order:
            if name not in self.rules:
                continue
            rule = self.rules[name]
            if isinstance(rule, type):
                # Tolerate a rule class stored in place of an instance.
                rule = rule()

            params = list(inspect.signature(rule.validate).parameters.values())
            if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                rule_kwargs = kwargs
            else:
                accepted = {p.name for p in params}
                rule_kwargs = {k: v for k, v in kwargs.items() if k in accepted}

            try:
                result = await rule.validate(value, *args, **rule_kwargs)
            except Exception as e:
                raise ValueError(f"failed to validate field") from e
            # Return the rule's actual result, not a boolean None-check.
            if result is not None:
                return result

        if strict:
            raise ValueError(f"failed to validate field")

        return value
|