ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +12 -4
- ripple_down_rules/datastructures/dataclasses.py +52 -9
- ripple_down_rules/datastructures/enums.py +14 -87
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/helpers.py +37 -2
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +10 -6
- ripple_down_rules/rdr_decorators.py +44 -34
- ripple_down_rules/rules.py +166 -97
- ripple_down_rules/user_interface/ipython_custom_shell.py +9 -1
- ripple_down_rules/user_interface/prompt.py +37 -37
- ripple_down_rules/user_interface/template_file_creator.py +3 -0
- ripple_down_rules/utils.py +32 -5
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/METADATA +3 -1
- ripple_down_rules-0.6.60.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.51.dist-info/RECORD +0 -25
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
import os.path
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from enum import Enum
|
5
|
+
from os.path import dirname
|
6
|
+
|
7
|
+
from typing_extensions import Type, ClassVar, TYPE_CHECKING, Tuple, Optional, Callable
|
8
|
+
|
9
|
+
from .datastructures.tracked_object import TrackedObjectMixin, Direction, Relation
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from .rdr_decorators import RDRDecorator
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass(eq=False)
class Predicate(TrackedObjectMixin, ABC):
    """
    Abstract base class for predicates.

    A predicate instance is callable: calling it forwards all arguments to the class-level
    :meth:`evaluate`. This class declares no fields of its own; hashing and equality are defined
    purely by the concrete predicate class.
    """

    def __call__(self, *args, **kwargs):
        """
        Evaluate the predicate, forwarding all arguments to :meth:`evaluate`.
        """
        return self.evaluate(*args, **kwargs)

    @classmethod
    @abstractmethod
    def evaluate(cls, *args, **kwargs):
        """
        Evaluate the predicate with the given arguments.
        This method should be implemented by subclasses.
        """
        pass

    @classmethod
    def rdr_decorator(cls, output_types: Tuple[Type, ...], mutually_exclusive: bool,
                      package_name: Optional[str] = None) -> Callable[[Callable], Callable]:
        """
        Returns the RDRDecorator to decorate the predicate evaluate method with.

        :param output_types: The output types passed on to the RDRDecorator.
        :param mutually_exclusive: The mutual exclusivity flag passed on to the RDRDecorator.
        :param package_name: Optional package name passed on to the RDRDecorator.
        :return: The ``decorator`` method of a newly created RDRDecorator.
        """
        # The module-level import of RDRDecorator is guarded by TYPE_CHECKING, so the name does not
        # exist at runtime; import it here so this call does not fail with a NameError.
        # (The guard is presumably there to avoid an import cycle - deferring to call time keeps that.)
        from .rdr_decorators import RDRDecorator

        # NOTE(review): `models_dir` is not defined in this class; it is expected to be provided
        # elsewhere (mixin or subclass) - confirm.
        rdr_decorator: RDRDecorator = RDRDecorator(cls.models_dir, output_types, mutually_exclusive,
                                                   package_name=package_name)
        return rdr_decorator.decorator

    def __hash__(self):
        # All instances of the same predicate class hash alike (consistent with __eq__ below).
        return hash(self.__class__.__name__)

    def __eq__(self, other):
        # Two predicates are equal exactly when they are instances of the same concrete class.
        if not isinstance(other, Predicate):
            return False
        return self.__class__ == other.__class__
|
47
|
+
|
48
|
+
|
49
|
+
# eq=False: a plain @dataclass would generate a new __eq__ and set __hash__ to None, which would
# discard the explicit __hash__/__eq__ inherited from Predicate and make instances unhashable.
@dataclass(eq=False)
class IsA(Predicate):
    """
    A predicate that checks if an object type is a subclass of another object type.
    """

    @classmethod
    def evaluate(cls, child_type: Type[TrackedObjectMixin], parent_type: Type[TrackedObjectMixin]) -> bool:
        """
        :param child_type: The type to test.
        :param parent_type: The candidate parent type.
        :return: The result of ``issubclass(child_type, parent_type)``.
        """
        return issubclass(child_type, parent_type)


# Module-level callable instance: ``isA(child, parent)`` forwards to ``IsA.evaluate``.
isA = IsA()
|
60
|
+
|
61
|
+
|
62
|
+
# eq=False: a plain @dataclass would generate a new __eq__ and set __hash__ to None, which would
# discard the explicit __hash__/__eq__ inherited from Predicate and make instances unhashable.
@dataclass(eq=False)
class Has(Predicate):
    """
    A predicate that checks if an object type has a certain member object type.
    """

    @classmethod
    def evaluate(cls, owner_type: Type[TrackedObjectMixin],
                 member_type: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        """
        Check the outbound edges of ``owner_type`` in the dependency graph for ``member_type``.

        :param owner_type: The type whose outbound neighbors are inspected.
        :param member_type: The member type to look for.
        :param recursive: If True, the search also descends through ``Relation.has`` neighbors.
        :return: True if a matching member is found, otherwise False.
        """
        # Iterated as (node index, edge data) pairs below.
        neighbors = cls._dependency_graph.adj_direction(owner_type._my_graph_idx(), Direction.OUTBOUND.value)

        # Direct check: a `has` neighbor whose node data isA member_type, or an `isA` neighbor that
        # itself (non-recursively) has member_type.
        curr_val = any(
            (e == Relation.has and isA(cls._dependency_graph.get_node_data(n), member_type))
            or (e == Relation.isA and cls.evaluate(cls._dependency_graph.get_node_data(n), member_type))
            for n, e in neighbors.items()
        )
        if not recursive:
            return curr_val

        # Recursive check: descend through `has` neighbors only.
        # NOTE(review): there is no visited set, so this assumes `has` edges form no cycle - confirm.
        return curr_val or any(
            (e == Relation.has
             and cls.evaluate(cls._dependency_graph.get_node_data(n), member_type, recursive=True))
            for n, e in neighbors.items()
        )


# Module-level callable instance: ``has(owner, member)`` forwards to ``Has.evaluate``.
has = Has()
|
83
|
+
|
84
|
+
|
85
|
+
# eq=False: a plain @dataclass would generate a new __eq__ and set __hash__ to None, which would
# discard the explicit __hash__/__eq__ inherited from Predicate and make instances unhashable.
@dataclass(eq=False)
class DependsOn(Predicate):
    """
    A predicate that checks if an object type depends on another object type.
    """

    @classmethod
    def evaluate(cls, dependent: Type[TrackedObjectMixin],
                 dependency: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        """
        Placeholder implementation; as written it always raises.

        :param dependent: The type that may depend on ``dependency``.
        :param dependency: The type that may be depended on.
        :param recursive: Accepted for signature parity; unused by this placeholder.
        :raises NotImplementedError: Always; per the message, the real logic is supplied in rdr meta.
        """
        raise NotImplementedError("Should be overridden in rdr meta")


# Module-level callable instance: ``dependsOn(a, b)`` forwards to ``DependsOn.evaluate``.
dependsOn = DependsOn()
|
ripple_down_rules/rdr.py
CHANGED
@@ -113,7 +113,10 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
113
113
|
if self.save_dir is not None and self.model_name is not None:
|
114
114
|
model_path = os.path.join(self.save_dir, self.model_name)
|
115
115
|
if os.path.exists(model_path):
|
116
|
-
|
116
|
+
try:
|
117
|
+
self.update_from_python(model_path, update_rule_tree=True)
|
118
|
+
except (FileNotFoundError, ModuleNotFoundError) as e:
|
119
|
+
pass
|
117
120
|
|
118
121
|
def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
|
119
122
|
"""
|
@@ -343,11 +346,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
343
346
|
rule.reset()
|
344
347
|
if self.start_rule is not None and self.start_rule.parent is None:
|
345
348
|
if self.input_node is None:
|
346
|
-
self.input_node = type(self.start_rule)(
|
349
|
+
self.input_node = type(self.start_rule)(_parent=None, uid='0')
|
347
350
|
self.input_node.evaluated = False
|
348
351
|
self.input_node.fired = False
|
349
352
|
self.start_rule.parent = self.input_node
|
350
|
-
self.start_rule.weight =
|
353
|
+
self.start_rule.weight = RDREdge.Empty
|
351
354
|
if self.input_node is not None:
|
352
355
|
data = case.__dict__ if is_dataclass(case) else case
|
353
356
|
if hasattr(case, "items"):
|
@@ -631,7 +634,7 @@ class TreeBuilder(ast.NodeVisitor, ABC):
|
|
631
634
|
rule_uid = condition.split("conditions_")[1]
|
632
635
|
|
633
636
|
new_rule_type = self.get_new_rule_type(node)
|
634
|
-
new_node = new_rule_type(conditions=condition,
|
637
|
+
new_node = new_rule_type(conditions=condition, _parent=self.current_parent, uid=rule_uid)
|
635
638
|
if self.current_parent is not None:
|
636
639
|
self.update_current_parent(new_node)
|
637
640
|
|
@@ -1009,6 +1012,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
1009
1012
|
defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
|
1010
1013
|
corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
|
1011
1014
|
|
1015
|
+
defs_imports.append(f"from ripple_down_rules import *")
|
1012
1016
|
# Add the imports to the defs file
|
1013
1017
|
with open(defs_file_name, "w") as f:
|
1014
1018
|
f.write('\n'.join(defs_imports) + "\n\n\n")
|
@@ -1068,7 +1072,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
1068
1072
|
main_types.update({Union, Optional})
|
1069
1073
|
defs_types.add(Union)
|
1070
1074
|
main_types.update({Case, create_case})
|
1071
|
-
main_types = main_types.difference(defs_types)
|
1075
|
+
# main_types = main_types.difference(defs_types)
|
1072
1076
|
return main_types, defs_types, cases_types
|
1073
1077
|
|
1074
1078
|
@property
|
@@ -1170,7 +1174,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
1170
1174
|
self.default_conclusion = case_query.default_value
|
1171
1175
|
|
1172
1176
|
pred = self.evaluate(case_query.case)
|
1173
|
-
if pred.conclusion(case_query.case) != case_query.target_value:
|
1177
|
+
if (not pred.fired and self.default_conclusion is None) or pred.conclusion(case_query.case) != case_query.target_value:
|
1174
1178
|
expert.ask_for_conditions(case_query, pred)
|
1175
1179
|
pred.fit_rule(case_query)
|
1176
1180
|
|
@@ -3,6 +3,7 @@ This file contains decorators for the RDR (Ripple Down Rules) framework. Where e
|
|
3
3
|
that can be used with any python function such that this function can benefit from the incremental knowledge acquisition
|
4
4
|
of the RDRs.
|
5
5
|
"""
|
6
|
+
import inspect
|
6
7
|
import os.path
|
7
8
|
from dataclasses import dataclass, field
|
8
9
|
from functools import wraps
|
@@ -11,8 +12,9 @@ from typing import get_origin
|
|
11
12
|
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
12
13
|
|
13
14
|
from .datastructures.case import Case
|
14
|
-
from .datastructures.dataclasses import CaseQuery
|
15
|
+
from .datastructures.dataclasses import CaseQuery, CaseFactoryMetaData
|
15
16
|
from .experts import Expert, Human
|
17
|
+
from .failures import RDRLoadError
|
16
18
|
from .rdr import GeneralRDR
|
17
19
|
from .utils import get_type_from_type_hint
|
18
20
|
|
@@ -21,7 +23,8 @@ try:
|
|
21
23
|
except ImportError:
|
22
24
|
RDRCaseViewer = None
|
23
25
|
from .utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
24
|
-
get_method_class_if_exists,
|
26
|
+
get_method_class_if_exists, make_list
|
27
|
+
from .helpers import create_case_from_method
|
25
28
|
|
26
29
|
|
27
30
|
@dataclass(unsafe_hash=True)
|
@@ -83,7 +86,7 @@ class RDRDecorator:
|
|
83
86
|
The name of the rdr model, this gets auto generated from the function signature and the class/file it is contained
|
84
87
|
in.
|
85
88
|
"""
|
86
|
-
rdr: GeneralRDR = field(init=False)
|
89
|
+
rdr: GeneralRDR = field(init=False, default=None)
|
87
90
|
"""
|
88
91
|
The ripple down rules instance of the decorator class.
|
89
92
|
"""
|
@@ -109,12 +112,22 @@ class RDRDecorator:
|
|
109
112
|
This is a flag that indicates that a not None output for the rdr has been inferred, this is used to update the
|
110
113
|
generated dot file if it is set to `True`.
|
111
114
|
"""
|
115
|
+
case_factory_metadata: CaseFactoryMetaData = field(init=False, default_factory=CaseFactoryMetaData)
|
116
|
+
"""
|
117
|
+
Metadata that contains the case factory method, and the scenario that is being run during the case query.
|
118
|
+
"""
|
112
119
|
|
113
120
|
def decorator(self, func: Callable) -> Callable:
|
114
121
|
|
115
122
|
@wraps(func)
|
116
123
|
def wrapper(*args, **kwargs) -> Optional[Any]:
|
117
124
|
|
125
|
+
original_kwargs = {pname: p for pname, p in inspect.signature(func).parameters.items() if
|
126
|
+
p.default != inspect._empty}
|
127
|
+
for og_kwarg in original_kwargs:
|
128
|
+
if og_kwarg not in kwargs:
|
129
|
+
kwargs[og_kwarg] = original_kwargs[og_kwarg].default
|
130
|
+
|
118
131
|
if self.model_name is None:
|
119
132
|
self.initialize_rdr_model_name_and_load(func)
|
120
133
|
if self.origin_type is None and not self.mutual_exclusive:
|
@@ -124,7 +137,7 @@ class RDRDecorator:
|
|
124
137
|
|
125
138
|
func_output = {self.output_name: func(*args, **kwargs)}
|
126
139
|
|
127
|
-
case, case_dict =
|
140
|
+
case, case_dict = create_case_from_method(func, func_output, *args, **kwargs)
|
128
141
|
|
129
142
|
@self.fitting_decorator
|
130
143
|
def fit():
|
@@ -132,11 +145,14 @@ class RDRDecorator:
|
|
132
145
|
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
133
146
|
if self.expert is None:
|
134
147
|
self.expert = Human(answers_save_path=self.models_dir + f'/{self.model_name}/expert_answers')
|
135
|
-
case_query = self.create_case_query_from_method(
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
148
|
+
case_query = self.create_case_query_from_method(
|
149
|
+
func, func_output,
|
150
|
+
self.parsed_output_type,
|
151
|
+
self.mutual_exclusive,
|
152
|
+
args, kwargs,
|
153
|
+
case=case, case_dict=case_dict,
|
154
|
+
scenario=self.case_factory_metadata.scenario,
|
155
|
+
this_case_target_value=self.case_factory_metadata.this_case_target_value)
|
140
156
|
output = self.rdr.fit_case(case_query, expert=self.expert,
|
141
157
|
update_existing_rules=self.update_existing_rules)
|
142
158
|
return output
|
@@ -166,6 +182,8 @@ class RDRDecorator:
|
|
166
182
|
else:
|
167
183
|
return func_output[self.output_name]
|
168
184
|
|
185
|
+
wrapper._rdr_decorator_instance = self
|
186
|
+
|
169
187
|
return wrapper
|
170
188
|
|
171
189
|
@staticmethod
|
@@ -173,9 +191,11 @@ class RDRDecorator:
|
|
173
191
|
func_output: Dict[str, Any],
|
174
192
|
output_type: Sequence[Type],
|
175
193
|
mutual_exclusive: bool,
|
194
|
+
func_args: Tuple[Any, ...], func_kwargs: Dict[str, Any],
|
176
195
|
case: Optional[Case] = None,
|
177
196
|
case_dict: Optional[Dict[str, Any]] = None,
|
178
|
-
|
197
|
+
scenario: Optional[Callable] = None,
|
198
|
+
this_case_target_value: Optional[Any] = None,) -> CaseQuery:
|
179
199
|
"""
|
180
200
|
Create a CaseQuery from the function and its arguments.
|
181
201
|
|
@@ -183,43 +203,28 @@ class RDRDecorator:
|
|
183
203
|
:param func_output: The output of the function as a dictionary, where the key is the output name.
|
184
204
|
:param output_type: The type of the output as a sequence of types.
|
185
205
|
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
186
|
-
:param
|
187
|
-
:param
|
206
|
+
:param func_args: The positional arguments of the function.
|
207
|
+
:param func_kwargs: The keyword arguments of the function.
|
208
|
+
:param case: The case to create.
|
209
|
+
:param case_dict: The dictionary of the case.
|
210
|
+
:param scenario: The scenario that produced the given case.
|
211
|
+
:param this_case_target_value: The target value for the case.
|
188
212
|
:return: A CaseQuery object representing the case.
|
189
213
|
"""
|
190
214
|
output_type = make_set(output_type)
|
191
215
|
if case is None or case_dict is None:
|
192
|
-
case, case_dict =
|
216
|
+
case, case_dict = create_case_from_method(func, func_output, *func_args, **func_kwargs)
|
193
217
|
scope = func.__globals__
|
194
218
|
scope.update(case_dict)
|
195
219
|
func_args_type_hints = get_type_hints(func)
|
196
220
|
output_name = list(func_output.keys())[0]
|
197
221
|
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
198
222
|
return CaseQuery(case, output_name, tuple(output_type),
|
199
|
-
mutual_exclusive, scope=scope,
|
223
|
+
mutual_exclusive, scope=scope, scenario=scenario, this_case_target_value=this_case_target_value,
|
200
224
|
is_function=True, function_args_type_hints=func_args_type_hints)
|
201
225
|
|
202
|
-
@staticmethod
|
203
|
-
def create_case_from_method(func: Callable,
|
204
|
-
func_output: Dict[str, Any],
|
205
|
-
*args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
206
|
-
"""
|
207
|
-
Create a Case from the function and its arguments.
|
208
|
-
|
209
|
-
:param func: The function to create a case from.
|
210
|
-
:param func_output: A dictionary containing the output of the function, where the key is the output name.
|
211
|
-
:param args: The positional arguments of the function.
|
212
|
-
:param kwargs: The keyword arguments of the function.
|
213
|
-
:return: A Case object representing the case.
|
214
|
-
"""
|
215
|
-
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
216
|
-
case_dict.update(func_output)
|
217
|
-
case_name = get_func_rdr_model_name(func)
|
218
|
-
return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
|
219
|
-
|
220
226
|
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
221
|
-
|
222
|
-
self.model_name = str_to_snake_case(model_file_name)
|
227
|
+
self.model_name = get_func_rdr_model_name(func, include_file_name=True)
|
223
228
|
self.load()
|
224
229
|
|
225
230
|
@staticmethod
|
@@ -242,3 +247,8 @@ class RDRDecorator:
|
|
242
247
|
Load the RDR model from the specified directory, otherwise create a new one.
|
243
248
|
"""
|
244
249
|
self.rdr = GeneralRDR(save_dir=self.models_dir, model_name=self.model_name)
|
250
|
+
|
251
|
+
|
252
|
+
def fit_rdr_func(scenario: Callable, rdr_decorated_func: Callable, *func_args, **func_kwargs) -> None:
    """
    Call an RDR-decorated function after attaching the given scenario to its decorator.

    The decorator instance is read from the function's ``_rdr_decorator_instance`` attribute and its
    ``case_factory_metadata`` is replaced with a new ``CaseFactoryMetaData`` holding ``scenario``.
    The metadata is not reset afterwards and the function's return value is discarded.

    :param scenario: The scenario callable to record in the case factory metadata.
    :param rdr_decorated_func: A function carrying a ``_rdr_decorator_instance`` attribute.
    :param func_args: Positional arguments forwarded to ``rdr_decorated_func``.
    :param func_kwargs: Keyword arguments forwarded to ``rdr_decorated_func``.
    """
    scenario_metadata = CaseFactoryMetaData(scenario=scenario)
    rdr_decorated_func._rdr_decorator_instance.case_factory_metadata = scenario_metadata
    rdr_decorated_func(*func_args, **func_kwargs)
|