ripple-down-rules 0.6.0__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +21 -1
- ripple_down_rules/datastructures/callable_expression.py +24 -7
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +135 -14
- ripple_down_rules/datastructures/enums.py +29 -86
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/experts.py +141 -50
- ripple_down_rules/failures.py +4 -0
- ripple_down_rules/helpers.py +75 -8
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +712 -96
- ripple_down_rules/rdr_decorators.py +164 -112
- ripple_down_rules/rules.py +351 -114
- ripple_down_rules/user_interface/gui.py +66 -41
- ripple_down_rules/user_interface/ipython_custom_shell.py +46 -9
- ripple_down_rules/user_interface/prompt.py +80 -60
- ripple_down_rules/user_interface/template_file_creator.py +13 -8
- ripple_down_rules/utils.py +537 -53
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/METADATA +4 -1
- ripple_down_rules-0.6.6.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.0.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/top_level.txt +0 -0
ripple_down_rules/rules.py
CHANGED
@@ -1,74 +1,185 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import logging
|
4
3
|
import re
|
5
4
|
from abc import ABC, abstractmethod
|
6
|
-
from
|
5
|
+
from dataclasses import dataclass, field
|
7
6
|
from types import NoneType
|
8
7
|
from uuid import uuid4
|
9
8
|
|
10
|
-
from anytree import
|
11
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
9
|
+
from anytree import Node
|
12
10
|
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
|
13
11
|
|
14
12
|
from .datastructures.callable_expression import CallableExpression
|
15
13
|
from .datastructures.case import Case
|
16
14
|
from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
|
17
15
|
from .datastructures.enums import RDREdge, Stop
|
18
|
-
from .
|
16
|
+
from .helpers import get_an_updated_case_copy
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
|
19
18
|
|
20
19
|
|
21
|
-
|
22
|
-
|
20
|
+
@dataclass
|
21
|
+
class Rule(SubclassJSONSerializer, ABC):
|
22
|
+
conditions: Optional[CallableExpression] = field(default=None)
|
23
|
+
"""
|
24
|
+
The conditions of the rule, which is a callable expression that takes a case and returns a boolean.
|
25
|
+
"""
|
26
|
+
conclusion: Optional[CallableExpression] = field(default=None)
|
27
|
+
"""
|
28
|
+
The conclusion of the rule, which is a callable expression that takes a case and returns a value.
|
29
|
+
"""
|
30
|
+
_parent: Optional[Rule] = field(default=None)
|
31
|
+
"""
|
32
|
+
The parent rule of this rule in the ripple down rules tree.
|
33
|
+
"""
|
34
|
+
corner_case: Optional[Any] = field(default=None)
|
35
|
+
"""
|
36
|
+
The corner case for which this rule was created.
|
37
|
+
"""
|
38
|
+
_weight: Optional[RDREdge] = field(default=None)
|
39
|
+
"""
|
40
|
+
The weight of the rule, which is the type of edge connecting the rule to its parent.
|
41
|
+
"""
|
42
|
+
conclusion_name: Optional[str] = field(default=None)
|
43
|
+
"""
|
44
|
+
The name of the conclusion of the rule, which is used to identify the conclusion
|
45
|
+
"""
|
46
|
+
uid: str = field(default_factory=lambda: str(uuid4().int))
|
47
|
+
"""
|
48
|
+
A unique id for the rule using uuid4
|
49
|
+
"""
|
50
|
+
corner_case_metadata: Optional[CaseFactoryMetaData] = field(default=None)
|
51
|
+
"""
|
52
|
+
Metadata about the corner case, such as the factory that created it or the scenario it is based on.
|
53
|
+
"""
|
54
|
+
json_serialization: Optional[Dict[str, Any]] = field(init=False, default=None)
|
55
|
+
"""
|
56
|
+
The JSON serialization of the rule, which is used to serialize the rule to JSON.
|
57
|
+
"""
|
58
|
+
_name: Optional[str] = field(init=False, default=None)
|
59
|
+
"""
|
60
|
+
The name of the rule, which is the names of the conditions and the conclusion
|
61
|
+
"""
|
62
|
+
evaluated: bool = field(init=False, default=False)
|
63
|
+
"""
|
64
|
+
Whether the rule has been evaluated or not (i.e. whether the rule has been reached during evaluation of the ripple
|
65
|
+
down rules tree and the conditions have been checked).
|
66
|
+
"""
|
67
|
+
last_conclusion: Optional[Any] = field(init=False, default=None)
|
68
|
+
"""
|
69
|
+
The last conclusion of the rule, which is the conclusion of the rule when it was last evaluated.
|
70
|
+
"""
|
71
|
+
contributed: bool = field(init=False, default=False)
|
72
|
+
"""
|
73
|
+
Whether the rule has contributed by a value, meaning that it has fired and the conclusion has been added to the case.
|
74
|
+
"""
|
75
|
+
contributed_to_case_query: bool = field(init=False, default=False)
|
76
|
+
"""
|
77
|
+
Whether the rule has contributed to the case query, meaning that it has fired and the conclusion is relevant to the
|
78
|
+
case query.
|
79
|
+
"""
|
80
|
+
fired: Optional[bool] = field(init=False, default=None)
|
23
81
|
"""
|
24
82
|
Whether the rule has fired or not.
|
25
83
|
"""
|
84
|
+
mutually_exclusive: bool = field(init=False, default=True)
|
85
|
+
"""
|
86
|
+
Whether the rule is mutually exclusive with other rules.
|
87
|
+
"""
|
88
|
+
node: Optional[Node] = field(init=False, default=None)
|
26
89
|
|
27
|
-
def
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
corner_case_metadata: Optional[CaseFactoryMetaData] = None):
|
90
|
+
def __post_init__(self):
|
91
|
+
self.node = Node(self.name, parent=self.parent.node if self.parent else None)
|
92
|
+
self.node.weight = self.weight.value if self.weight else None
|
93
|
+
self.node._rdr_rule = self
|
94
|
+
|
95
|
+
@property
|
96
|
+
def descendants(self) -> List[Rule]:
|
35
97
|
"""
|
36
|
-
|
98
|
+
:return: the descendants of this rule, which are the rules that are children of this rule in the ripple down
|
99
|
+
rules tree.
|
100
|
+
"""
|
101
|
+
return [child._rdr_rule for child in self.node.descendants]
|
37
102
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
:
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
self.
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
self.
|
103
|
+
@property
|
104
|
+
def children(self) -> List[Rule]:
|
105
|
+
"""
|
106
|
+
:return: the children of this rule, which are the rules that are direct children of this rule in the ripple down
|
107
|
+
rules tree.
|
108
|
+
"""
|
109
|
+
return [child._rdr_rule for child in self.node.children]
|
110
|
+
|
111
|
+
@property
|
112
|
+
def parent(self):
|
113
|
+
"""
|
114
|
+
:return: The parent rule of this rule.
|
115
|
+
"""
|
116
|
+
return self._parent
|
117
|
+
|
118
|
+
@parent.setter
|
119
|
+
def parent(self, new_parent: Optional[Rule]):
|
120
|
+
"""
|
121
|
+
Set the parent rule of this rule.
|
122
|
+
:param new_parent: The new parent rule to set.
|
123
|
+
"""
|
124
|
+
self._parent = new_parent
|
125
|
+
if self.node:
|
126
|
+
self.node.parent = new_parent.node
|
127
|
+
|
128
|
+
@property
|
129
|
+
def weight(self) -> RDREdge:
|
130
|
+
return self._weight
|
131
|
+
|
132
|
+
@weight.setter
|
133
|
+
def weight(self, new_weight: RDREdge):
|
134
|
+
"""
|
135
|
+
Set the weight of the rule, which is the type of edge connecting the rule to its parent.
|
136
|
+
:param new_weight: The new weight to set.
|
137
|
+
"""
|
138
|
+
self._weight = new_weight
|
139
|
+
if self.node:
|
140
|
+
self.node.weight = new_weight.value
|
141
|
+
|
142
|
+
def get_an_updated_case_copy(self, case: Case) -> Case:
|
143
|
+
"""
|
144
|
+
:param case: The case to copy and update.
|
145
|
+
:return: A copy of the case updated with this rule conclusion.
|
146
|
+
"""
|
147
|
+
return get_an_updated_case_copy(case, self.conclusion, self.conclusion_name, self.conclusion.conclusion_type,
|
148
|
+
self.mutually_exclusive)
|
149
|
+
|
150
|
+
def reset(self):
|
151
|
+
self.evaluated = False
|
152
|
+
self.fired = False
|
153
|
+
self.contributed = False
|
154
|
+
self.contributed_to_case_query = False
|
155
|
+
self.last_conclusion = None
|
156
|
+
|
157
|
+
@property
|
158
|
+
def color(self) -> str:
|
159
|
+
if self.evaluated:
|
160
|
+
if self.contributed_to_case_query:
|
161
|
+
return "green"
|
162
|
+
elif self.contributed:
|
163
|
+
return "yellow"
|
164
|
+
elif self.fired:
|
165
|
+
return "orange"
|
166
|
+
else:
|
167
|
+
return "red"
|
168
|
+
else:
|
169
|
+
return "white"
|
60
170
|
|
61
171
|
@classmethod
|
62
|
-
def from_case_query(cls, case_query: CaseQuery) -> Rule:
|
172
|
+
def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
|
63
173
|
"""
|
64
174
|
Create a SingleClassRule from a CaseQuery.
|
65
175
|
|
66
176
|
:param case_query: The CaseQuery to create the rule from.
|
177
|
+
:param parent: The parent rule of this rule.
|
67
178
|
:return: A SingleClassRule instance.
|
68
179
|
"""
|
69
180
|
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
70
181
|
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
71
|
-
corner_case=case_query.case, parent
|
182
|
+
corner_case=case_query.case, _parent=parent,
|
72
183
|
corner_case_metadata=corner_case_metadata,
|
73
184
|
conclusion_name=case_query.attribute_name)
|
74
185
|
|
@@ -91,6 +202,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
91
202
|
:param x: The case to evaluate the rule on.
|
92
203
|
:return: The rule that fired or the last evaluated rule if no rule fired.
|
93
204
|
"""
|
205
|
+
self.evaluated = True
|
94
206
|
if not self.conditions:
|
95
207
|
raise ValueError("Rule has no conditions")
|
96
208
|
self.fired = self.conditions(x)
|
@@ -149,6 +261,18 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
149
261
|
f.write(conclusion_func.strip() + "\n\n\n")
|
150
262
|
return conclusion_func_call
|
151
263
|
|
264
|
+
@property
|
265
|
+
def generated_conclusion_function_name(self) -> str:
|
266
|
+
return f"conclusion_{self.uid}"
|
267
|
+
|
268
|
+
@property
|
269
|
+
def generated_conditions_function_name(self) -> str:
|
270
|
+
return f"conditions_{self.uid}"
|
271
|
+
|
272
|
+
@property
|
273
|
+
def generated_corner_case_object_name(self) -> str:
|
274
|
+
return f"corner_case_{self.uid}"
|
275
|
+
|
152
276
|
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
|
153
277
|
"""
|
154
278
|
Convert the conclusion of a rule to source code.
|
@@ -161,23 +285,24 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
161
285
|
# This means the conclusion is a definition that should be written and then called
|
162
286
|
conclusion_lines = conclusion.split('\n')
|
163
287
|
# use regex to replace the function name
|
164
|
-
new_function_name = f"def
|
288
|
+
new_function_name = f"def {self.generated_conclusion_function_name}"
|
165
289
|
conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
|
166
290
|
# add type hint
|
167
|
-
if
|
168
|
-
|
291
|
+
if not self.conclusion.mutually_exclusive:
|
292
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
|
293
|
+
if len(type_names) == 1:
|
294
|
+
hint = f"List[{type_names[0]}]"
|
295
|
+
else:
|
296
|
+
hint = f"List[Union[{', '.join(type_names)}]]"
|
169
297
|
else:
|
170
|
-
if
|
171
|
-
|
172
|
-
|
173
|
-
|
298
|
+
if NoneType in self.conclusion.conclusion_type:
|
299
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
|
300
|
+
hint = f"Optional[{', '.join(type_names)}]"
|
301
|
+
elif len(self.conclusion.conclusion_type) == 1:
|
302
|
+
hint = self.conclusion.conclusion_type[0].__name__
|
174
303
|
else:
|
175
|
-
|
176
|
-
|
177
|
-
hint = f"Optional[{', '.join(type_names)}]"
|
178
|
-
else:
|
179
|
-
type_names = [t.__name__ for t in self.conclusion.conclusion_type]
|
180
|
-
hint = f"Union[{', '.join(type_names)}]"
|
304
|
+
type_names = [t.__name__ for t in self.conclusion.conclusion_type]
|
305
|
+
hint = f"Union[{', '.join(type_names)}]"
|
181
306
|
conclusion_lines[0] = conclusion_lines[0].replace("):", f") -> {hint}:")
|
182
307
|
func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
|
183
308
|
return "\n".join(conclusion_lines).strip(' '), func_call
|
@@ -199,7 +324,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
199
324
|
# This means the conditions are a definition that should be written and then called
|
200
325
|
conditions_lines = self.conditions.user_input.split('\n')
|
201
326
|
# use regex to replace the function name
|
202
|
-
new_function_name = f"def
|
327
|
+
new_function_name = f"def {self.generated_conditions_function_name}"
|
203
328
|
conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
|
204
329
|
# add type hint
|
205
330
|
conditions_lines[0] = conditions_lines[0].replace('):', ') -> bool:')
|
@@ -216,34 +341,22 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
216
341
|
pass
|
217
342
|
|
218
343
|
def _to_json(self) -> Dict[str, Any]:
|
219
|
-
try:
|
220
|
-
corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
|
221
|
-
except Exception as e:
|
222
|
-
logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
|
223
|
-
corner_case = None
|
224
344
|
json_serialization = {"_type": get_full_class_name(type(self)),
|
225
345
|
"conditions": self.conditions.to_json(),
|
226
346
|
"conclusion": conclusion_to_json(self.conclusion),
|
227
347
|
"parent": self.parent.json_serialization if self.parent else None,
|
228
|
-
"corner_case": corner_case,
|
229
348
|
"conclusion_name": self.conclusion_name,
|
230
|
-
"weight": self.weight,
|
349
|
+
"weight": self.weight.value,
|
231
350
|
"uid": self.uid}
|
232
351
|
return json_serialization
|
233
352
|
|
234
353
|
@classmethod
|
235
354
|
def _from_json(cls, data: Dict[str, Any]) -> Rule:
|
236
|
-
try:
|
237
|
-
corner_case = Case.from_json(data["corner_case"])
|
238
|
-
except Exception as e:
|
239
|
-
logging.debug("Failed to load corner case from json, setting it to None.")
|
240
|
-
corner_case = None
|
241
355
|
loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
|
242
356
|
conclusion=CallableExpression.from_json(data["conclusion"]),
|
243
|
-
|
244
|
-
corner_case=corner_case,
|
357
|
+
_parent=cls.from_json(data["parent"]),
|
245
358
|
conclusion_name=data["conclusion_name"],
|
246
|
-
|
359
|
+
_weight=RDREdge.from_value(data["weight"]),
|
247
360
|
uid=data["uid"])
|
248
361
|
return loaded_rule
|
249
362
|
|
@@ -260,30 +373,75 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
260
373
|
Set the name of the rule.
|
261
374
|
"""
|
262
375
|
self._name = new_name
|
376
|
+
self.node.name = new_name
|
377
|
+
|
378
|
+
@property
|
379
|
+
def semantic_condition_name(self) -> Optional[str]:
|
380
|
+
"""
|
381
|
+
Get the name of the conditions of the rule, which is the user input of the conditions.
|
382
|
+
"""
|
383
|
+
if isinstance(self.conditions, CallableExpression):
|
384
|
+
return self.expression_name(self.conditions)
|
385
|
+
return None
|
386
|
+
|
387
|
+
@property
|
388
|
+
def semantic_conclusion_name(self) -> Optional[str]:
|
389
|
+
"""
|
390
|
+
Get the name of the conclusion of the rule, which is the user input of the conclusion.
|
391
|
+
"""
|
392
|
+
if isinstance(self.conclusion, CallableExpression):
|
393
|
+
return self.expression_name(self.conclusion)
|
394
|
+
return None
|
395
|
+
|
396
|
+
@staticmethod
|
397
|
+
def expression_name(expression: CallableExpression) -> str:
|
398
|
+
"""
|
399
|
+
Get the name of the expression, which is the user input of the expression if it exists,
|
400
|
+
otherwise it is the conclusion or conditions of the rule.
|
401
|
+
"""
|
402
|
+
if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
|
403
|
+
return expression.user_defined_name.strip()
|
404
|
+
func_name = expression.user_input.split('(')[0].replace('def ',
|
405
|
+
'').strip() if "def " in expression.user_input else None
|
406
|
+
if func_name is not None and func_name != expression.encapsulating_function_name:
|
407
|
+
return func_name
|
408
|
+
elif expression.user_input:
|
409
|
+
return expression.user_input.strip()
|
410
|
+
else:
|
411
|
+
return str(expression)
|
263
412
|
|
264
413
|
def __str__(self, sep="\n"):
|
265
414
|
"""
|
266
415
|
Get the string representation of the rule, which is the conditions and the conclusion.
|
267
416
|
"""
|
268
|
-
return f"{self.
|
417
|
+
return f"{self.semantic_condition_name}{sep}=> {self.semantic_conclusion_name}"
|
269
418
|
|
270
419
|
def __repr__(self):
|
271
420
|
return self.__str__()
|
272
421
|
|
422
|
+
def __eq__(self, other):
|
423
|
+
if not isinstance(other, Rule):
|
424
|
+
return False
|
425
|
+
return other.uid == self.uid
|
426
|
+
|
427
|
+
def __hash__(self):
|
428
|
+
return hash(self.uid)
|
273
429
|
|
430
|
+
|
431
|
+
@dataclass
|
274
432
|
class HasAlternativeRule:
|
275
433
|
"""
|
276
434
|
A mixin class for rules that have an alternative rule.
|
277
435
|
"""
|
278
|
-
_alternative: Optional[Rule] = None
|
436
|
+
_alternative: Optional[Rule] = field(init=False, default=None)
|
279
437
|
"""
|
280
438
|
The alternative rule of the rule, which is evaluated when the rule doesn't fire.
|
281
439
|
"""
|
282
|
-
furthest_alternative: Optional[List[Rule]] = None
|
440
|
+
furthest_alternative: Optional[List[Rule]] = field(init=False, default=None)
|
283
441
|
"""
|
284
442
|
The furthest alternative rule of the rule, which is the last alternative rule in the chain of alternative rules.
|
285
443
|
"""
|
286
|
-
all_alternatives: Optional[List[Rule]] = None
|
444
|
+
all_alternatives: Optional[List[Rule]] = field(init=False, default=None)
|
287
445
|
"""
|
288
446
|
All alternative rules of the rule, which is all the alternative rules in the chain of alternative rules.
|
289
447
|
"""
|
@@ -304,13 +462,14 @@ class HasAlternativeRule:
|
|
304
462
|
self.furthest_alternative[-1].alternative = new_rule
|
305
463
|
else:
|
306
464
|
new_rule.parent = self
|
307
|
-
new_rule.weight = RDREdge.Alternative
|
465
|
+
new_rule.weight = RDREdge.Alternative if not new_rule.weight else new_rule.weight
|
308
466
|
self._alternative = new_rule
|
309
467
|
self.furthest_alternative = [new_rule]
|
310
468
|
|
311
469
|
|
470
|
+
@dataclass
|
312
471
|
class HasRefinementRule:
|
313
|
-
_refinement: Optional[HasAlternativeRule] = None
|
472
|
+
_refinement: Optional[HasAlternativeRule] = field(init=False, default=None)
|
314
473
|
"""
|
315
474
|
The refinement rule of the rule, which is evaluated when the rule fires.
|
316
475
|
"""
|
@@ -327,20 +486,23 @@ class HasRefinementRule:
|
|
327
486
|
"""
|
328
487
|
if new_rule is None:
|
329
488
|
return
|
330
|
-
new_rule.top_rule = self
|
331
489
|
if self.refinement:
|
332
490
|
self.refinement.alternative = new_rule
|
333
491
|
else:
|
334
492
|
new_rule.parent = self
|
335
|
-
new_rule.weight = RDREdge.Refinement
|
493
|
+
new_rule.weight = RDREdge.Refinement if not isinstance(new_rule,
|
494
|
+
MultiClassFilterRule) else new_rule.weight
|
336
495
|
self._refinement = new_rule
|
337
496
|
|
338
497
|
|
498
|
+
@dataclass(eq=False)
|
339
499
|
class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
340
500
|
"""
|
341
501
|
A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
|
342
502
|
"""
|
343
503
|
|
504
|
+
mutually_exclusive: bool = field(init=False, default=True)
|
505
|
+
|
344
506
|
def evaluate_next_rule(self, x: Case) -> SingleClassRule:
|
345
507
|
if self.fired:
|
346
508
|
returned_rule = self.refinement(x) if self.refinement else self
|
@@ -351,7 +513,7 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
351
513
|
def fit_rule(self, case_query: CaseQuery):
|
352
514
|
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
353
515
|
new_rule = SingleClassRule(case_query.conditions, case_query.target,
|
354
|
-
corner_case=case_query.case,
|
516
|
+
corner_case=case_query.case, _parent=self,
|
355
517
|
corner_case_metadata=corner_case_metadata,
|
356
518
|
)
|
357
519
|
if self.fired:
|
@@ -373,31 +535,19 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
373
535
|
return loaded_rule
|
374
536
|
|
375
537
|
def _if_statement_source_code_clause(self) -> str:
|
376
|
-
return "elif" if self.weight == RDREdge.Alternative
|
538
|
+
return "elif" if self.weight == RDREdge.Alternative else "if"
|
377
539
|
|
378
540
|
|
379
|
-
|
541
|
+
@dataclass(eq=False)
|
542
|
+
class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
|
380
543
|
"""
|
381
|
-
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule
|
382
|
-
the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
|
544
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
|
383
545
|
"""
|
384
|
-
top_rule: Optional[MultiClassTopRule] = None
|
546
|
+
top_rule: Optional[MultiClassTopRule] = field(init=False, default=None)
|
385
547
|
"""
|
386
548
|
The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
|
387
549
|
"""
|
388
|
-
|
389
|
-
def __init__(self, *args, **kwargs):
|
390
|
-
super(MultiClassStopRule, self).__init__(*args, **kwargs)
|
391
|
-
self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
|
392
|
-
|
393
|
-
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
|
394
|
-
if self.fired:
|
395
|
-
self.top_rule.fired = False
|
396
|
-
return self.top_rule.alternative
|
397
|
-
elif self.alternative:
|
398
|
-
return self.alternative(x)
|
399
|
-
else:
|
400
|
-
return self.top_rule.alternative
|
550
|
+
mutually_exclusive: bool = field(init=False, default=False)
|
401
551
|
|
402
552
|
def _to_json(self) -> Dict[str, Any]:
|
403
553
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -405,47 +555,135 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
|
|
405
555
|
return self.json_serialization
|
406
556
|
|
407
557
|
@classmethod
|
408
|
-
def _from_json(cls, data: Dict[str, Any]) ->
|
409
|
-
loaded_rule = super(
|
558
|
+
def _from_json(cls, data: Dict[str, Any]) -> MultiClassRefinementRule:
|
559
|
+
loaded_rule = super(MultiClassRefinementRule, cls)._from_json(data)
|
410
560
|
# The following is done to prevent re-initialization of the top rule,
|
411
561
|
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
412
562
|
loaded_rule.top_rule = data['top_rule']
|
413
563
|
if data['alternative'] is not None:
|
414
564
|
data['alternative']['top_rule'] = data['top_rule']
|
415
|
-
|
565
|
+
loaded_rule.alternative = SubclassJSONSerializer.from_json(data["alternative"])
|
416
566
|
return loaded_rule
|
417
567
|
|
568
|
+
def _if_statement_source_code_clause(self) -> str:
|
569
|
+
return "elif" if self.weight == RDREdge.Alternative else "if"
|
570
|
+
|
571
|
+
|
572
|
+
@dataclass(eq=False)
|
573
|
+
class MultiClassStopRule(MultiClassRefinementRule):
|
574
|
+
"""
|
575
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
576
|
+
the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
|
577
|
+
"""
|
578
|
+
conclusion: CallableExpression = field(default_factory=lambda: CallableExpression(conclusion_type=(Stop,),
|
579
|
+
conclusion=Stop.stop))
|
580
|
+
"""
|
581
|
+
The conclusion of the rule, which is a CallableExpression that returns a Stop category.
|
582
|
+
"""
|
583
|
+
|
584
|
+
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
585
|
+
if self.fired:
|
586
|
+
self.top_rule.fired = False
|
587
|
+
return self.top_rule.alternative
|
588
|
+
elif self.alternative:
|
589
|
+
return self.alternative(x)
|
590
|
+
else:
|
591
|
+
return self.top_rule.alternative
|
592
|
+
|
418
593
|
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
419
594
|
return None, f"{parent_indent}{' ' * 4}pass\n"
|
420
595
|
|
421
|
-
def _if_statement_source_code_clause(self) -> str:
|
422
|
-
return "elif" if self.weight == RDREdge.Alternative.value else "if"
|
423
596
|
|
597
|
+
@dataclass(eq=False)
|
598
|
+
class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
|
599
|
+
"""
|
600
|
+
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
601
|
+
the conclusion of the rule is a Filter category meant to filter the parent conclusion.
|
602
|
+
"""
|
603
|
+
weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Filter)
|
604
|
+
|
605
|
+
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
606
|
+
if self.fired:
|
607
|
+
if self.refinement:
|
608
|
+
case_cp = x
|
609
|
+
if isinstance(self.refinement, MultiClassFilterRule):
|
610
|
+
case_cp = self.get_an_updated_case_copy(case_cp)
|
611
|
+
return self.refinement(case_cp)
|
612
|
+
else:
|
613
|
+
return self.top_rule.alternative
|
614
|
+
elif self.alternative:
|
615
|
+
return self.alternative(x)
|
616
|
+
else:
|
617
|
+
return self.top_rule.alternative
|
618
|
+
|
619
|
+
def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
|
620
|
+
func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
|
621
|
+
conclusion_str = func_call.replace("return ", "").strip()
|
622
|
+
conclusion_str = conclusion_str.replace("(case)", "(case_cp)")
|
424
623
|
|
624
|
+
parent_to_filter = self.get_parent_to_filter()
|
625
|
+
statement = (
|
626
|
+
f"\n{parent_indent} case_cp = get_an_updated_case_copy(case, {parent_to_filter.generated_conclusion_function_name},"
|
627
|
+
f" attribute_name, conclusion_type, mutually_exclusive)")
|
628
|
+
statement += f"\n{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
|
629
|
+
return func, statement
|
630
|
+
|
631
|
+
def get_parent_to_filter(self, parent: Union[None, MultiClassRefinementRule, MultiClassTopRule] = None) \
|
632
|
+
-> Union[MultiClassFilterRule, MultiClassTopRule]:
|
633
|
+
parent = self.parent if parent is None else parent
|
634
|
+
if isinstance(parent, (MultiClassFilterRule, MultiClassTopRule)) and parent.fired:
|
635
|
+
return parent
|
636
|
+
else:
|
637
|
+
return parent.parent
|
638
|
+
|
639
|
+
def _to_json(self) -> Dict[str, Any]:
|
640
|
+
self.json_serialization = super(MultiClassFilterRule, self)._to_json()
|
641
|
+
self.json_serialization['refinement'] = self.refinement.to_json() if self.refinement is not None else None
|
642
|
+
return self.json_serialization
|
643
|
+
|
644
|
+
@classmethod
|
645
|
+
def _from_json(cls, data: Dict[str, Any]) -> MultiClassFilterRule:
|
646
|
+
loaded_rule = super(MultiClassFilterRule, cls)._from_json(data)
|
647
|
+
if data['refinement'] is not None:
|
648
|
+
data['refinement']['top_rule'] = data['top_rule']
|
649
|
+
loaded_rule.refinement = cls.from_json(data["refinement"]) if data["refinement"] is not None else None
|
650
|
+
return loaded_rule
|
651
|
+
|
652
|
+
|
653
|
+
@dataclass(eq=False)
|
425
654
|
class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
426
655
|
"""
|
427
656
|
A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
|
428
657
|
"""
|
429
|
-
|
430
|
-
|
431
|
-
super(MultiClassTopRule, self).__init__(*args, **kwargs)
|
432
|
-
self.weight = RDREdge.Next.value
|
658
|
+
mutually_exclusive: bool = field(init=False, default=False)
|
659
|
+
weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Next)
|
433
660
|
|
434
661
|
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
|
435
662
|
if self.fired and self.refinement:
|
436
|
-
|
663
|
+
case_cp = x
|
664
|
+
if isinstance(self.refinement, MultiClassFilterRule):
|
665
|
+
case_cp = self.get_an_updated_case_copy(case_cp)
|
666
|
+
return self.refinement(case_cp)
|
437
667
|
elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
|
438
668
|
return self.alternative
|
669
|
+
return None
|
439
670
|
|
440
|
-
def fit_rule(self, case_query: CaseQuery):
|
671
|
+
def fit_rule(self, case_query: CaseQuery, refinement_type: Optional[Type[MultiClassRefinementRule]] = None):
|
441
672
|
if self.fired and case_query.target != self.conclusion:
|
442
|
-
|
673
|
+
if refinement_type in [None, MultiClassStopRule]:
|
674
|
+
new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
|
675
|
+
_parent=self)
|
676
|
+
elif refinement_type is MultiClassFilterRule:
|
677
|
+
new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
|
678
|
+
else:
|
679
|
+
raise ValueError(f"Unknown refinement type {refinement_type}")
|
680
|
+
new_rule.top_rule = self
|
681
|
+
self.refinement = new_rule
|
443
682
|
elif not self.fired:
|
444
|
-
self.alternative = MultiClassTopRule(case_query
|
445
|
-
corner_case=case_query.case, parent=self)
|
683
|
+
self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
|
446
684
|
|
447
685
|
def _to_json(self) -> Dict[str, Any]:
|
448
|
-
self.json_serialization = {**
|
686
|
+
self.json_serialization = {**super()._to_json(),
|
449
687
|
"refinement": self.refinement.to_json() if self.refinement is not None else None,
|
450
688
|
"alternative": self.alternative.to_json() if self.alternative is not None else None}
|
451
689
|
return self.json_serialization
|
@@ -457,7 +695,8 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
457
695
|
# so the top rule that is already initialized is passed in the data instead of its json serialization.
|
458
696
|
if data['refinement'] is not None:
|
459
697
|
data['refinement']['top_rule'] = loaded_rule
|
460
|
-
|
698
|
+
data_type = get_type_from_string(data["refinement"]["_type"])
|
699
|
+
loaded_rule.refinement = data_type.from_json(data["refinement"])
|
461
700
|
loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
|
462
701
|
return loaded_rule
|
463
702
|
|
@@ -466,8 +705,6 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
466
705
|
conclusion_str = func_call.replace("return ", "").strip()
|
467
706
|
|
468
707
|
statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
|
469
|
-
if self.alternative is None:
|
470
|
-
statement += f"{parent_indent}return conclusions\n"
|
471
708
|
return func, statement
|
472
709
|
|
473
710
|
def _if_statement_source_code_clause(self) -> str:
|