ripple-down-rules 0.6.29__py3-none-any.whl → 0.6.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +87 -9
- ripple_down_rules/experts.py +98 -20
- ripple_down_rules/rdr.py +37 -26
- ripple_down_rules/rules.py +15 -1
- ripple_down_rules/user_interface/gui.py +59 -40
- ripple_down_rules/user_interface/ipython_custom_shell.py +36 -7
- ripple_down_rules/user_interface/prompt.py +41 -26
- ripple_down_rules/user_interface/template_file_creator.py +10 -8
- ripple_down_rules/utils.py +57 -8
- {ripple_down_rules-0.6.29.dist-info → ripple_down_rules-0.6.30.dist-info}/METADATA +1 -1
- ripple_down_rules-0.6.30.dist-info/RECORD +24 -0
- ripple_down_rules-0.6.29.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.29.dist-info → ripple_down_rules-0.6.30.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.29.dist-info → ripple_down_rules-0.6.30.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.29.dist-info → ripple_down_rules-0.6.30.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -20,19 +20,20 @@ if TYPE_CHECKING:
|
|
20
20
|
|
21
21
|
class Case(UserDict, SubclassJSONSerializer):
|
22
22
|
"""
|
23
|
-
A collection of attributes that represents a set of
|
24
|
-
the names of the attributes and the values are the attributes. All are stored in lower case
|
23
|
+
A collection of attributes that represents a set of attributes of a case. This is a dictionary where the keys are
|
24
|
+
the names of the attributes and the values are the attributes. All are stored in lower case, and can be accessed
|
25
|
+
using the dot notation as well as the dictionary access notation.
|
25
26
|
"""
|
26
27
|
|
27
28
|
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
28
29
|
_name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
|
29
30
|
"""
|
30
|
-
Create a new
|
31
|
+
Create a new case.
|
31
32
|
|
32
|
-
:param _obj_type: The type of the object that the
|
33
|
-
:param _id: The id of the
|
34
|
-
:param _name: The semantic name that describes the
|
35
|
-
:param kwargs: The attributes of the
|
33
|
+
:param _obj_type: The original type of the object that the case represents.
|
34
|
+
:param _id: The id of the case.
|
35
|
+
:param _name: The semantic name that describes the case.
|
36
|
+
:param kwargs: The attributes of the case.
|
36
37
|
"""
|
37
38
|
super().__init__(kwargs)
|
38
39
|
self._original_object = original_object
|
@@ -43,12 +44,12 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
43
44
|
@classmethod
|
44
45
|
def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
|
45
46
|
"""
|
46
|
-
Create a
|
47
|
+
Create a case from an object.
|
47
48
|
|
48
|
-
:param obj: The object to create a
|
49
|
+
:param obj: The object to create a case from.
|
49
50
|
:param max_recursion_idx: The maximum recursion index to prevent infinite recursion.
|
50
51
|
:param obj_name: The name of the object.
|
51
|
-
:return: The
|
52
|
+
:return: The case that represents the object.
|
52
53
|
"""
|
53
54
|
return create_case(obj, max_recursion_idx=max_recursion_idx, obj_name=obj_name)
|
54
55
|
|
@@ -129,7 +130,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
129
130
|
@dataclass
|
130
131
|
class CaseAttributeValue(SubclassJSONSerializer):
|
131
132
|
"""
|
132
|
-
|
133
|
+
Encapsulates a single value of a case attribute, it adds an id to the value.
|
133
134
|
"""
|
134
135
|
id: Hashable
|
135
136
|
"""
|
@@ -1,19 +1,22 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import inspect
|
4
|
-
import
|
4
|
+
import uuid
|
5
5
|
from dataclasses import dataclass, field
|
6
6
|
|
7
|
-
import
|
7
|
+
from colorama import Fore, Style
|
8
8
|
from omegaconf import MISSING
|
9
9
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
10
|
-
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List,
|
10
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, Set, Callable, TYPE_CHECKING
|
11
11
|
|
12
|
-
from ..utils import get_method_name, get_function_import_data, get_function_representation
|
13
12
|
from .callable_expression import CallableExpression
|
14
13
|
from .case import create_case, Case
|
15
|
-
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint,
|
16
|
-
|
14
|
+
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, render_tree, \
|
15
|
+
get_function_representation
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from ..rdr import RippleDownRules
|
19
|
+
from ..rules import Rule
|
17
20
|
|
18
21
|
|
19
22
|
@dataclass
|
@@ -92,6 +95,38 @@ class CaseQuery:
|
|
92
95
|
"""
|
93
96
|
The type hints of the function arguments. This is used to recreate the function signature.
|
94
97
|
"""
|
98
|
+
rdr: Optional[RippleDownRules] = None
|
99
|
+
"""
|
100
|
+
The Ripple Down Rules that was used to answer the case query.
|
101
|
+
"""
|
102
|
+
|
103
|
+
def render_rule_tree(self, filepath: Optional[str] = None, view: bool = False):
|
104
|
+
if self.rdr is None:
|
105
|
+
return
|
106
|
+
render_tree(self.rdr.start_rule, use_dot_exporter=True, filename=filepath, view=view)
|
107
|
+
|
108
|
+
@property
|
109
|
+
def current_value_str(self):
|
110
|
+
return (f"{Fore.MAGENTA}Current value of {Fore.CYAN}{self.name}{Fore.MAGENTA} of type(s) "
|
111
|
+
f"{Fore.CYAN}({self.core_attribute_type_str}){Fore.MAGENTA}: "
|
112
|
+
f"{Fore.WHITE}{self.current_value}{Style.RESET_ALL}")
|
113
|
+
|
114
|
+
@property
|
115
|
+
def current_value(self) -> Any:
|
116
|
+
"""
|
117
|
+
:return: The current value of the attribute.
|
118
|
+
"""
|
119
|
+
if not hasattr(self.case, self.attribute_name):
|
120
|
+
return None
|
121
|
+
|
122
|
+
attr_value = getattr(self.case, self.attribute_name)
|
123
|
+
|
124
|
+
if attr_value is None:
|
125
|
+
return attr_value
|
126
|
+
elif self.mutually_exclusive:
|
127
|
+
return attr_value
|
128
|
+
else:
|
129
|
+
return list({v for v in make_list(attr_value) if isinstance(v, self.core_attribute_type)})
|
95
130
|
|
96
131
|
@property
|
97
132
|
def case_type(self) -> Type:
|
@@ -145,7 +180,14 @@ class CaseQuery:
|
|
145
180
|
return attribute_types_str
|
146
181
|
|
147
182
|
@property
|
148
|
-
def
|
183
|
+
def core_attribute_type_str(self) -> str:
|
184
|
+
"""
|
185
|
+
:return: The names of the core types of the attribute.
|
186
|
+
"""
|
187
|
+
return ','.join([t.__name__ for t in self.core_attribute_type])
|
188
|
+
|
189
|
+
@property
|
190
|
+
def core_attribute_type(self) -> Tuple[Type, ...]:
|
149
191
|
"""
|
150
192
|
:return: The core type of the attribute.
|
151
193
|
"""
|
@@ -247,7 +289,7 @@ class CaseQuery:
|
|
247
289
|
conditions=self.conditions, is_function=self.is_function,
|
248
290
|
function_args_type_hints=self.function_args_type_hints,
|
249
291
|
case_factory=self.case_factory, case_factory_idx=self.case_factory_idx,
|
250
|
-
case_conf=self.case_conf, scenario=self.scenario)
|
292
|
+
case_conf=self.case_conf, scenario=self.scenario, rdr=self.rdr)
|
251
293
|
|
252
294
|
|
253
295
|
@dataclass
|
@@ -284,4 +326,40 @@ class CaseFactoryMetaData:
|
|
284
326
|
f" scenario={scenario_repr})")
|
285
327
|
|
286
328
|
def __str__(self):
|
287
|
-
return self.__repr__()
|
329
|
+
return self.__repr__()
|
330
|
+
|
331
|
+
|
332
|
+
@dataclass
|
333
|
+
class RDRConclusion:
|
334
|
+
"""
|
335
|
+
This dataclass represents a conclusion of a Ripple Down Rule.
|
336
|
+
It contains the conclusion expression, the type of the conclusion, and the scope in which it is evaluated.
|
337
|
+
"""
|
338
|
+
value: Any
|
339
|
+
"""
|
340
|
+
The conclusion value.
|
341
|
+
"""
|
342
|
+
frozen_case: Any
|
343
|
+
"""
|
344
|
+
The frozen case that the conclusion was made for.
|
345
|
+
"""
|
346
|
+
rule: Rule
|
347
|
+
"""
|
348
|
+
The rule that gave this conclusion.
|
349
|
+
"""
|
350
|
+
rdr: RippleDownRules
|
351
|
+
"""
|
352
|
+
The Ripple Down Rules that classified the case and produced this conclusion.
|
353
|
+
"""
|
354
|
+
id: int = field(default_factory=lambda: uuid.uuid4().int)
|
355
|
+
"""
|
356
|
+
The unique identifier of the conclusion.
|
357
|
+
"""
|
358
|
+
|
359
|
+
def __hash__(self):
|
360
|
+
return hash(self.id)
|
361
|
+
|
362
|
+
def __eq__(self, other):
|
363
|
+
if not isinstance(other, RDRConclusion):
|
364
|
+
return False
|
365
|
+
return self.id == other.id
|
ripple_down_rules/experts.py
CHANGED
@@ -1,22 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import ast
|
4
|
+
import importlib
|
4
5
|
import json
|
5
|
-
import logging
|
6
6
|
import os
|
7
|
-
import
|
7
|
+
import sys
|
8
8
|
from abc import ABC, abstractmethod
|
9
|
+
from dataclasses import is_dataclass
|
9
10
|
from textwrap import dedent, indent
|
10
11
|
from typing import Tuple, Dict
|
11
12
|
|
12
13
|
from typing_extensions import Optional, TYPE_CHECKING, List
|
13
14
|
|
14
15
|
from .datastructures.callable_expression import CallableExpression
|
15
|
-
from .datastructures.enums import PromptFor
|
16
|
-
from .datastructures.dataclasses import CaseQuery
|
17
16
|
from .datastructures.case import show_current_and_corner_cases
|
17
|
+
from .datastructures.dataclasses import CaseQuery
|
18
|
+
from .datastructures.enums import PromptFor
|
18
19
|
from .user_interface.template_file_creator import TemplateFileCreator
|
19
|
-
from .utils import extract_imports, extract_function_source, get_imports_from_scope,
|
20
|
+
from .utils import extract_imports, extract_function_source, get_imports_from_scope, get_class_file_path
|
20
21
|
|
21
22
|
try:
|
22
23
|
from .user_interface.gui import RDRCaseViewer
|
@@ -49,12 +50,13 @@ class Expert(ABC):
|
|
49
50
|
answers_save_path: Optional[str] = None):
|
50
51
|
self.all_expert_answers = []
|
51
52
|
self.use_loaded_answers = use_loaded_answers
|
52
|
-
self.append = True
|
53
53
|
self.answers_save_path = answers_save_path
|
54
54
|
if answers_save_path is not None and os.path.exists(answers_save_path + '.py'):
|
55
55
|
if use_loaded_answers:
|
56
56
|
self.load_answers(answers_save_path)
|
57
|
-
|
57
|
+
if not append:
|
58
|
+
os.remove(answers_save_path + '.py')
|
59
|
+
self.append = True
|
58
60
|
|
59
61
|
@abstractmethod
|
60
62
|
def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
|
@@ -87,7 +89,7 @@ class Expert(ABC):
|
|
87
89
|
"""
|
88
90
|
if path is None and self.answers_save_path is None:
|
89
91
|
raise ValueError("No path provided to clear expert answers, either provide a path or set the "
|
90
|
-
|
92
|
+
"answers_save_path attribute.")
|
91
93
|
if path is None:
|
92
94
|
path = self.answers_save_path
|
93
95
|
if os.path.exists(path + '.py'):
|
@@ -106,7 +108,7 @@ class Expert(ABC):
|
|
106
108
|
return
|
107
109
|
if path is None and self.answers_save_path is None:
|
108
110
|
raise ValueError("No path provided to save expert answers, either provide a path or set the "
|
109
|
-
|
111
|
+
"answers_save_path attribute.")
|
110
112
|
if path is None:
|
111
113
|
path = self.answers_save_path
|
112
114
|
self._save_to_python(path, expert_answers=expert_answers)
|
@@ -163,7 +165,7 @@ class Expert(ABC):
|
|
163
165
|
"""
|
164
166
|
if path is None and self.answers_save_path is None:
|
165
167
|
raise ValueError("No path provided to load expert answers from, either provide a path or set the "
|
166
|
-
|
168
|
+
"answers_save_path attribute.")
|
167
169
|
if path is None:
|
168
170
|
path = self.answers_save_path
|
169
171
|
if os.path.exists(path + '.py'):
|
@@ -203,6 +205,81 @@ class Expert(ABC):
|
|
203
205
|
self.all_expert_answers.append((scope, function_source))
|
204
206
|
|
205
207
|
|
208
|
+
class AI(Expert):
|
209
|
+
"""
|
210
|
+
The AI Expert class, an expert that uses AI to provide differentiating features and conclusions.
|
211
|
+
"""
|
212
|
+
|
213
|
+
def __init__(self, **kwargs):
|
214
|
+
"""
|
215
|
+
Initialize the AI expert.
|
216
|
+
"""
|
217
|
+
super().__init__(**kwargs)
|
218
|
+
self.user_prompt = UserPrompt()
|
219
|
+
|
220
|
+
def ask_for_conditions(self, case_query: CaseQuery,
|
221
|
+
last_evaluated_rule: Optional[Rule] = None) \
|
222
|
+
-> CallableExpression:
|
223
|
+
prompt_str = self.get_prompt_for_ai(case_query, PromptFor.Conditions)
|
224
|
+
print(prompt_str)
|
225
|
+
sys.exit()
|
226
|
+
|
227
|
+
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
228
|
+
prompt_str = self.get_prompt_for_ai(case_query, PromptFor.Conclusion)
|
229
|
+
output_type_source = self.get_output_type_class_source(case_query)
|
230
|
+
prompt_str = f"\n\n\nOutput type(s) class source:\n{output_type_source}\n\n" + prompt_str
|
231
|
+
print(prompt_str)
|
232
|
+
sys.exit()
|
233
|
+
|
234
|
+
def get_output_type_class_source(self, case_query: CaseQuery) -> str:
|
235
|
+
"""
|
236
|
+
Get the output type class source for the AI expert.
|
237
|
+
|
238
|
+
:param case_query: The case query containing the case to classify.
|
239
|
+
:return: The output type class source.
|
240
|
+
"""
|
241
|
+
output_types = case_query.core_attribute_type
|
242
|
+
|
243
|
+
def get_class_source(cls):
|
244
|
+
cls_source_file = get_class_file_path(cls)
|
245
|
+
found_class_source = extract_function_source(cls_source_file, function_names=[cls.__name__],
|
246
|
+
is_class=True,
|
247
|
+
as_list=True)[0]
|
248
|
+
class_signature = found_class_source.split('\n')[0]
|
249
|
+
if '(' in class_signature:
|
250
|
+
parent_class_names = list(map(lambda x: x.strip(),
|
251
|
+
class_signature.split('(')[1].split(')')[0].split(',')))
|
252
|
+
parent_classes = [importlib.import_module(cls.__module__).__dict__.get(cls_name.strip())
|
253
|
+
for cls_name in parent_class_names]
|
254
|
+
else:
|
255
|
+
parent_classes = []
|
256
|
+
if is_dataclass(cls):
|
257
|
+
found_class_source = f"@dataclass\n{found_class_source}"
|
258
|
+
return '\n'.join([get_class_source(pcls) for pcls in parent_classes] + [found_class_source])
|
259
|
+
|
260
|
+
found_class_sources = []
|
261
|
+
for output_type in output_types:
|
262
|
+
found_class_sources.append(get_class_source(output_type))
|
263
|
+
found_class_sources = '\n\n\n'.join(found_class_sources)
|
264
|
+
return found_class_sources
|
265
|
+
|
266
|
+
def get_prompt_for_ai(self, case_query: CaseQuery, prompt_for: PromptFor) -> str:
|
267
|
+
"""
|
268
|
+
Get the prompt for the AI expert.
|
269
|
+
|
270
|
+
:param case_query: The case query containing the case to classify.
|
271
|
+
:param prompt_for: The type of prompt to get.
|
272
|
+
:return: The prompt for the AI expert.
|
273
|
+
"""
|
274
|
+
# data_to_show = show_current_and_corner_cases(case_query.case)
|
275
|
+
data_to_show = f"\nCase ({case_query.case_name}):\n {case_query.case.__dict__}"
|
276
|
+
template_file_creator = TemplateFileCreator(case_query, prompt_for=prompt_for)
|
277
|
+
boilerplate_code = template_file_creator.build_boilerplate_code()
|
278
|
+
initial_prompt_str = data_to_show + "\n\n" + boilerplate_code + "\n\n"
|
279
|
+
return self.user_prompt.build_prompt_str_for_ai(case_query, prompt_for=prompt_for,
|
280
|
+
initial_prompt_str=initial_prompt_str)
|
281
|
+
|
282
|
+
|
206
283
|
class Human(Expert):
|
207
284
|
"""
|
208
285
|
The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
|
@@ -220,8 +297,9 @@ class Human(Expert):
|
|
220
297
|
-> CallableExpression:
|
221
298
|
data_to_show = None
|
222
299
|
if (not self.use_loaded_answers or len(self.all_expert_answers) == 0) and self.user_prompt.viewer is None:
|
223
|
-
data_to_show = show_current_and_corner_cases(case_query.case,
|
224
|
-
|
300
|
+
data_to_show = show_current_and_corner_cases(case_query.case,
|
301
|
+
{case_query.attribute_name: case_query.target_value},
|
302
|
+
last_evaluated_rule=last_evaluated_rule)
|
225
303
|
return self._get_conditions(case_query, data_to_show)
|
226
304
|
|
227
305
|
def _get_conditions(self, case_query: CaseQuery, data_to_show: Optional[str] = None) \
|
@@ -247,9 +325,10 @@ class Human(Expert):
|
|
247
325
|
if self.answers_save_path is not None and not any(loaded_scope):
|
248
326
|
self.convert_json_answer_to_python_answer(case_query, user_input, condition, PromptFor.Conditions)
|
249
327
|
else:
|
250
|
-
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions,
|
251
|
-
|
252
|
-
|
328
|
+
user_input, condition = self.user_prompt.prompt_user_for_expression(case_query, PromptFor.Conditions,
|
329
|
+
prompt_str=data_to_show)
|
330
|
+
if user_input in ['exit', 'quit']:
|
331
|
+
sys.exit()
|
253
332
|
if not self.use_loaded_answers:
|
254
333
|
self.all_expert_answers.append((condition.scope, user_input))
|
255
334
|
if self.answers_save_path is not None:
|
@@ -260,7 +339,6 @@ class Human(Expert):
|
|
260
339
|
def convert_json_answer_to_python_answer(self, case_query: CaseQuery, user_input: str,
|
261
340
|
callable_expression: CallableExpression,
|
262
341
|
prompt_for: PromptFor):
|
263
|
-
case_query.scope['case'] = case_query.case
|
264
342
|
tfc = TemplateFileCreator(case_query, prompt_for=prompt_for)
|
265
343
|
code = tfc.build_boilerplate_code()
|
266
344
|
if user_input.startswith('def'):
|
@@ -303,11 +381,11 @@ class Human(Expert):
|
|
303
381
|
prompt_str=data_to_show)
|
304
382
|
if expert_input is None:
|
305
383
|
self.all_expert_answers.append(({}, None))
|
306
|
-
elif expert_input
|
384
|
+
elif expert_input not in ['exit', 'quit']:
|
307
385
|
self.all_expert_answers.append((expression.scope, expert_input))
|
308
|
-
if self.answers_save_path is not None and expert_input
|
386
|
+
if self.answers_save_path is not None and expert_input not in ['exit', 'quit']:
|
309
387
|
self.save_answers()
|
310
|
-
if expert_input
|
311
|
-
exit()
|
388
|
+
if expert_input in ['exit', 'quit']:
|
389
|
+
sys.exit()
|
312
390
|
case_query.target = expression
|
313
391
|
return expression
|
ripple_down_rules/rdr.py
CHANGED
@@ -13,7 +13,6 @@ from . import logger
|
|
13
13
|
|
14
14
|
try:
|
15
15
|
from matplotlib import pyplot as plt
|
16
|
-
|
17
16
|
Figure = plt.Figure
|
18
17
|
except ImportError as e:
|
19
18
|
logger.debug(f"{e}: matplotlib is not installed")
|
@@ -90,23 +89,13 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
90
89
|
:param save_dir: The directory to save the classifier to.
|
91
90
|
"""
|
92
91
|
self.model_name: Optional[str] = model_name
|
93
|
-
self.save_dir = save_dir
|
94
|
-
self.start_rule = start_rule
|
92
|
+
self.save_dir: Optional[str] = save_dir
|
93
|
+
self.start_rule: Optional[Rule] = start_rule
|
95
94
|
self.fig: Optional[Figure] = None
|
96
95
|
self.viewer: Optional[RDRCaseViewer] = RDRCaseViewer.instances[0]\
|
97
96
|
if RDRCaseViewer and any(RDRCaseViewer.instances) else None
|
98
97
|
self.input_node: Optional[Rule] = None
|
99
98
|
|
100
|
-
@property
|
101
|
-
def viewer(self):
|
102
|
-
return self._viewer
|
103
|
-
|
104
|
-
@viewer.setter
|
105
|
-
def viewer(self, viewer):
|
106
|
-
self._viewer = viewer
|
107
|
-
if viewer:
|
108
|
-
viewer.set_save_function(self.save)
|
109
|
-
|
110
99
|
def render_evaluated_rule_tree(self, filename: str, show_full_tree: bool = False) -> None:
|
111
100
|
if show_full_tree:
|
112
101
|
start_rule = self.start_rule if self.input_node is None else self.input_node
|
@@ -117,6 +106,26 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
117
106
|
render_tree(evaluated_rules[0], use_dot_exporter=True, filename=filename,
|
118
107
|
only_nodes=evaluated_rules)
|
119
108
|
|
109
|
+
def get_contributing_rules(self) -> Optional[List[Rule]]:
|
110
|
+
"""
|
111
|
+
Get the contributing rules of the classifier.
|
112
|
+
|
113
|
+
:return: The contributing rules.
|
114
|
+
"""
|
115
|
+
if self.start_rule is None:
|
116
|
+
return None
|
117
|
+
return [r for r in self.get_fired_rule_tree() if r.contributed]
|
118
|
+
|
119
|
+
def get_fired_rule_tree(self) -> Optional[List[Rule]]:
|
120
|
+
"""
|
121
|
+
Get the fired rule tree of the classifier.
|
122
|
+
|
123
|
+
:return: The fired rule tree.
|
124
|
+
"""
|
125
|
+
if self.start_rule is None:
|
126
|
+
return None
|
127
|
+
return [r for r in self.get_evaluated_rule_tree() if r.fired]
|
128
|
+
|
120
129
|
def get_evaluated_rule_tree(self) -> Optional[List[Rule]]:
|
121
130
|
"""
|
122
131
|
Get the evaluated rule tree of the classifier.
|
@@ -196,16 +205,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
196
205
|
"""
|
197
206
|
pass
|
198
207
|
|
199
|
-
def set_viewer(self, viewer: RDRCaseViewer):
|
200
|
-
"""
|
201
|
-
Set the viewer for the classifier.
|
202
|
-
|
203
|
-
:param viewer: The viewer to set.
|
204
|
-
"""
|
205
|
-
self.viewer = viewer
|
206
|
-
if self.viewer is not None:
|
207
|
-
self.viewer.set_save_function(self.save)
|
208
|
-
|
209
208
|
def fit(self, case_queries: List[CaseQuery],
|
210
209
|
expert: Optional[Expert] = None,
|
211
210
|
n_iter: int = None,
|
@@ -273,8 +272,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
273
272
|
"""
|
274
273
|
if self.start_rule is not None:
|
275
274
|
for rule in [self.start_rule] + list(self.start_rule.descendants):
|
276
|
-
rule.
|
277
|
-
rule.fired = False
|
275
|
+
rule.reset()
|
278
276
|
if self.start_rule is not None and self.start_rule.parent is None:
|
279
277
|
if self.input_node is None:
|
280
278
|
self.input_node = type(self.start_rule)(parent=None, uid='0')
|
@@ -332,6 +330,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
332
330
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
333
331
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
334
332
|
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
333
|
+
case_query.rdr = self
|
335
334
|
|
336
335
|
expert = expert or Human(answers_save_path=self.save_dir + '/expert_answers'
|
337
336
|
if self.save_dir else None)
|
@@ -771,6 +770,11 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
771
770
|
"""
|
772
771
|
pred = self.evaluate(case)
|
773
772
|
conclusion = pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
|
773
|
+
if pred is not None and pred.fired:
|
774
|
+
pred.contributed = True
|
775
|
+
pred.last_conclusion = conclusion
|
776
|
+
if case_query is not None:
|
777
|
+
pred.contributed_to_case_query = True
|
774
778
|
if pred is not None and pred.fired and case_query is not None:
|
775
779
|
if pred.corner_case_metadata is None and conclusion is not None \
|
776
780
|
and type(conclusion) in case_query.core_attribute_type:
|
@@ -888,6 +892,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
888
892
|
and any(
|
889
893
|
ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
890
894
|
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
895
|
+
if rule_conclusion is not None and any(make_list(rule_conclusion)):
|
896
|
+
evaluated_rule.contributed = True
|
897
|
+
evaluated_rule.last_conclusion = rule_conclusion
|
898
|
+
if case_query is not None:
|
899
|
+
rule_conclusion_types = set(map(type, make_list(rule_conclusion)))
|
900
|
+
if any(rule_conclusion_types.intersection(set(case_query.core_attribute_type))):
|
901
|
+
evaluated_rule.contributed_to_case_query = True
|
891
902
|
self.add_conclusion(rule_conclusion)
|
892
903
|
evaluated_rule = next_rule
|
893
904
|
return make_set(self.conclusions)
|
@@ -982,7 +993,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
982
993
|
"""
|
983
994
|
if not self.start_rule:
|
984
995
|
conditions = expert.ask_for_conditions(case_query)
|
985
|
-
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
996
|
+
self.start_rule: MultiClassTopRule = MultiClassTopRule.from_case_query(case_query)
|
986
997
|
|
987
998
|
@property
|
988
999
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
ripple_down_rules/rules.py
CHANGED
@@ -63,6 +63,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
63
63
|
self.uid: str = uid if uid else str(uuid4().int)
|
64
64
|
self.evaluated: bool = False
|
65
65
|
self._user_defined_name: Optional[str] = None
|
66
|
+
self.last_conclusion: Optional[Any] = None
|
67
|
+
self.contributed: bool = False
|
68
|
+
self.contributed_to_case_query: bool = False
|
66
69
|
|
67
70
|
def get_an_updated_case_copy(self, case: Case) -> Case:
|
68
71
|
"""
|
@@ -72,11 +75,22 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
72
75
|
return get_an_updated_case_copy(case, self.conclusion, self.conclusion_name, self.conclusion.conclusion_type,
|
73
76
|
self.mutually_exclusive)
|
74
77
|
|
78
|
+
def reset(self):
|
79
|
+
self.evaluated = False
|
80
|
+
self.fired = False
|
81
|
+
self.contributed = False
|
82
|
+
self.contributed_to_case_query = False
|
83
|
+
self.last_conclusion = None
|
84
|
+
|
75
85
|
@property
|
76
86
|
def color(self) -> str:
|
77
87
|
if self.evaluated:
|
78
|
-
if self.
|
88
|
+
if self.contributed_to_case_query:
|
79
89
|
return "green"
|
90
|
+
elif self.contributed:
|
91
|
+
return "yellow"
|
92
|
+
elif self.fired:
|
93
|
+
return "orange"
|
80
94
|
else:
|
81
95
|
return "red"
|
82
96
|
else:
|