ripple-down-rules 0.1.61__tar.gz → 0.1.63__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/PKG-INFO +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/pyproject.toml +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/case.py +9 -7
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/dataclasses.py +5 -4
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/enums.py +5 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/experts.py +44 -53
- ripple_down_rules-0.1.63/src/ripple_down_rules/helpers.py +51 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/prompt.py +26 -13
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/rdr.py +47 -67
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/utils.py +19 -15
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_rdr_world.py +54 -12
- ripple_down_rules-0.1.61/src/ripple_down_rules/helpers.py +0 -27
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/LICENSE +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/README.md +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/setup.cfg +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datasets.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/rdr_decorators.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/rules.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_json_serialization.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_on_mutagenic.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_rdr.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_relational_rdr.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_relational_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.63
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.1.
|
9
|
+
version = "0.1.63"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
24
24
|
the names of the attributes and the values are the attributes. All are stored in lower case.
|
25
25
|
"""
|
26
26
|
|
27
|
-
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
27
|
+
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
28
|
+
_name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
|
28
29
|
"""
|
29
30
|
Create a new row.
|
30
31
|
|
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
34
35
|
:param kwargs: The attributes of the row.
|
35
36
|
"""
|
36
37
|
super().__init__(kwargs)
|
38
|
+
self._original_object = original_object
|
37
39
|
self._obj_type: Type = _obj_type
|
38
40
|
self._id: Hashable = _id if _id is not None else id(self)
|
39
41
|
self._name: str = _name if _name is not None else self._obj_type.__name__
|
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
221
223
|
return obj
|
222
224
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
223
225
|
or (obj.__class__ in [MetaData, registry])):
|
224
|
-
return Case(type(obj), _id=id(obj), _name=obj_name,
|
226
|
+
return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
|
225
227
|
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
226
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
228
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
227
229
|
for attr in dir(obj):
|
228
230
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
229
231
|
continue
|
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
|
|
251
253
|
:return: The updated/created case.
|
252
254
|
"""
|
253
255
|
if case is None:
|
254
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
256
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
255
257
|
if isinstance(attr_value, (dict, UserDict)):
|
256
258
|
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
257
259
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
280
282
|
"""
|
281
283
|
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
282
284
|
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
283
|
-
attr_case = Case(_type, _id=id(attr_value), _name=name)
|
285
|
+
attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
|
284
286
|
case_attr = CaseAttribute(values)
|
285
287
|
for idx, val in enumerate(values):
|
286
288
|
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
@@ -305,7 +307,7 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
|
|
305
307
|
"""
|
306
308
|
corner_case = None
|
307
309
|
targets = {f"target_{name}": value for name, value in targets.items()} if targets else {}
|
308
|
-
current_conclusions = {name: value for name, value in current_conclusions.items} if current_conclusions else {}
|
310
|
+
current_conclusions = {name: value for name, value in current_conclusions.items()} if current_conclusions else {}
|
309
311
|
if last_evaluated_rule:
|
310
312
|
action = "Refinement" if last_evaluated_rule.fired else "Alternative"
|
311
313
|
print(f"{action} needed for rule: {last_evaluated_rule}\n")
|
@@ -322,7 +324,7 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Dict[str, Any]] =
|
|
322
324
|
corner_row_dict = corner_case
|
323
325
|
|
324
326
|
if corner_row_dict:
|
325
|
-
corner_conclusion = last_evaluated_rule.conclusion
|
327
|
+
corner_conclusion = last_evaluated_rule.conclusion(case)
|
326
328
|
corner_row_dict.update({corner_conclusion.__class__.__name__: corner_conclusion})
|
327
329
|
print(table_rows_as_str(corner_row_dict))
|
328
330
|
print("=" * 50)
|
@@ -8,7 +8,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
|
|
8
8
|
|
9
9
|
from .callable_expression import CallableExpression
|
10
10
|
from .case import create_case, Case
|
11
|
-
from ..utils import copy_case, make_list
|
11
|
+
from ..utils import copy_case, make_list, make_set
|
12
12
|
|
13
13
|
|
14
14
|
@dataclass
|
@@ -88,9 +88,10 @@ class CaseQuery:
|
|
88
88
|
"""
|
89
89
|
:return: The type of the attribute.
|
90
90
|
"""
|
91
|
-
self.
|
92
|
-
|
93
|
-
|
91
|
+
if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
|
92
|
+
self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
|
93
|
+
elif not isinstance(self._attribute_types, tuple):
|
94
|
+
self._attribute_types = tuple(make_list(self._attribute_types))
|
94
95
|
return self._attribute_types
|
95
96
|
|
96
97
|
@attribute_type.setter
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/datastructures/enums.py
RENAMED
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
|
|
53
53
|
|
54
54
|
class PromptFor(Enum):
|
55
55
|
"""
|
56
|
-
The reason of the prompt. (e.g. get conditions, or
|
56
|
+
The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
|
57
57
|
"""
|
58
58
|
Conditions: str = "conditions"
|
59
59
|
"""
|
@@ -63,6 +63,10 @@ class PromptFor(Enum):
|
|
63
63
|
"""
|
64
64
|
Prompt for rule conclusion about a case.
|
65
65
|
"""
|
66
|
+
Affirmation: str = "affirmation"
|
67
|
+
"""
|
68
|
+
Prompt for rule conclusion affirmation about a case.
|
69
|
+
"""
|
66
70
|
|
67
71
|
def __str__(self):
|
68
72
|
return self.name
|
@@ -10,7 +10,7 @@ from .datastructures.callable_expression import CallableExpression
|
|
10
10
|
from .datastructures.enums import PromptFor
|
11
11
|
from .datastructures.dataclasses import CaseQuery
|
12
12
|
from .datastructures.case import show_current_and_corner_cases
|
13
|
-
from .prompt import prompt_user_for_expression
|
13
|
+
from .prompt import prompt_user_for_expression, IPythonShell
|
14
14
|
from .utils import get_all_subclasses, make_list
|
15
15
|
|
16
16
|
if TYPE_CHECKING:
|
@@ -51,28 +51,24 @@ class Expert(ABC):
|
|
51
51
|
pass
|
52
52
|
|
53
53
|
@abstractmethod
|
54
|
-
def
|
55
|
-
-> Dict[CaseAttribute, CallableExpression]:
|
54
|
+
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
56
55
|
"""
|
57
|
-
Ask the expert to provide extra
|
58
|
-
that category.
|
56
|
+
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
59
57
|
|
60
|
-
:param
|
61
|
-
:
|
62
|
-
|
58
|
+
:param case_query: The case query containing the case to classify.
|
59
|
+
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
60
|
+
conclusion and conditions for the rule.
|
63
61
|
"""
|
64
62
|
pass
|
65
63
|
|
66
64
|
@abstractmethod
|
67
|
-
def ask_if_conclusion_is_correct(self,
|
68
|
-
|
69
|
-
current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
|
65
|
+
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
66
|
+
current_conclusions: Any) -> bool:
|
70
67
|
"""
|
71
68
|
Ask the expert if the conclusion is correct.
|
72
69
|
|
73
|
-
:param
|
70
|
+
:param case_query: The case query about which the expert should answer.
|
74
71
|
:param conclusion: The conclusion to check.
|
75
|
-
:param targets: The target categories to compare the case with.
|
76
72
|
:param current_conclusions: The current conclusions for the case.
|
77
73
|
"""
|
78
74
|
pass
|
@@ -126,10 +122,28 @@ class Human(Expert):
|
|
126
122
|
last_evaluated_rule: Optional[Rule] = None) \
|
127
123
|
-> CallableExpression:
|
128
124
|
if not self.use_loaded_answers:
|
129
|
-
show_current_and_corner_cases(case_query.case, {case_query.attribute_name: case_query.
|
125
|
+
show_current_and_corner_cases(case_query.case, {case_query.attribute_name: case_query.target_value},
|
130
126
|
last_evaluated_rule=last_evaluated_rule)
|
131
127
|
return self._get_conditions(case_query)
|
132
128
|
|
129
|
+
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
130
|
+
"""
|
131
|
+
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
132
|
+
|
133
|
+
:param case_query: The case query containing the case to classify.
|
134
|
+
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
135
|
+
conclusion and conditions for the rule.
|
136
|
+
"""
|
137
|
+
rules = []
|
138
|
+
while True:
|
139
|
+
conclusion = self.ask_for_conclusion(case_query)
|
140
|
+
if conclusion is None:
|
141
|
+
break
|
142
|
+
conditions = self._get_conditions(case_query)
|
143
|
+
rules.append({PromptFor.Conclusion: conclusion,
|
144
|
+
PromptFor.Conditions: conditions})
|
145
|
+
return rules
|
146
|
+
|
133
147
|
def _get_conditions(self, case_query: CaseQuery) \
|
134
148
|
-> CallableExpression:
|
135
149
|
"""
|
@@ -151,24 +165,6 @@ class Human(Expert):
|
|
151
165
|
case_query.conditions = condition
|
152
166
|
return condition
|
153
167
|
|
154
|
-
def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
|
155
|
-
-> Dict[CaseAttribute, CallableExpression]:
|
156
|
-
"""
|
157
|
-
Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
|
158
|
-
that category.
|
159
|
-
|
160
|
-
:param case: The case to classify.
|
161
|
-
:param current_conclusions: The current conclusions for the case.
|
162
|
-
:return: The extra conclusions for the case.
|
163
|
-
"""
|
164
|
-
extra_conclusions = {}
|
165
|
-
while True:
|
166
|
-
category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
|
167
|
-
if not category:
|
168
|
-
break
|
169
|
-
extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
|
170
|
-
return extra_conclusions
|
171
|
-
|
172
168
|
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
173
169
|
"""
|
174
170
|
Ask the expert to provide a conclusion for the case.
|
@@ -212,44 +208,39 @@ class Human(Expert):
|
|
212
208
|
:param category_name: The name of the category to ask about.
|
213
209
|
"""
|
214
210
|
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
215
|
-
return not self.
|
211
|
+
return not self.ask_for_affirmation(question)
|
216
212
|
|
217
|
-
def ask_if_conclusion_is_correct(self,
|
218
|
-
|
219
|
-
current_conclusions: Optional[List[Any]] = None) -> bool:
|
213
|
+
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
214
|
+
current_conclusions: Any) -> bool:
|
220
215
|
"""
|
221
216
|
Ask the expert if the conclusion is correct.
|
222
217
|
|
223
|
-
:param
|
218
|
+
:param case_query: The case query about which the expert should answer.
|
224
219
|
:param conclusion: The conclusion to check.
|
225
|
-
:param targets: The target categories to compare the case with.
|
226
220
|
:param current_conclusions: The current conclusions for the case.
|
227
221
|
"""
|
228
|
-
question = ""
|
229
222
|
if not self.use_loaded_answers:
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
|
234
|
-
f"\n{str(x)}"
|
235
|
-
return self.ask_yes_no_question(question)
|
223
|
+
print(f"Current conclusions: {current_conclusions}")
|
224
|
+
return self.ask_for_affirmation(case_query,
|
225
|
+
f"Is the conclusion {conclusion} correct for the case (True/False):")
|
236
226
|
|
237
|
-
def
|
227
|
+
def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
|
238
228
|
"""
|
239
229
|
Ask the expert a yes or no question.
|
240
230
|
|
241
|
-
:param
|
231
|
+
:param case_query: The case query about which the expert should answer.
|
232
|
+
:param question: The question to ask the expert.
|
242
233
|
:return: The answer to the question.
|
243
234
|
"""
|
244
|
-
if not self.use_loaded_answers:
|
245
|
-
print(question)
|
246
235
|
while True:
|
247
236
|
if self.use_loaded_answers:
|
248
237
|
answer = self.all_expert_answers.pop(0)
|
249
238
|
else:
|
250
|
-
|
251
|
-
|
252
|
-
if answer
|
239
|
+
_, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
|
240
|
+
answer = expression(case_query.case)
|
241
|
+
if answer:
|
242
|
+
self.all_expert_answers.append(True)
|
253
243
|
return True
|
254
|
-
|
244
|
+
else:
|
245
|
+
self.all_expert_answers.append(False)
|
255
246
|
return False
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
|
5
|
+
from .datastructures.dataclasses import CaseQuery
|
6
|
+
from sqlalchemy.orm import Session
|
7
|
+
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
|
8
|
+
|
9
|
+
from .utils import get_func_rdr_model_path
|
10
|
+
from .utils import calculate_precision_and_recall
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from .rdr import RippleDownRules
|
14
|
+
|
15
|
+
|
16
|
+
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
17
|
+
"""
|
18
|
+
:param classifier: The RDR classifier to check the prediction of.
|
19
|
+
:param case_query: The case query to check.
|
20
|
+
:param pred_cat: The predicted category.
|
21
|
+
:return: Whether the classifier prediction is matching case_query target or not.
|
22
|
+
"""
|
23
|
+
if case_query.target is None:
|
24
|
+
return False
|
25
|
+
if pred_cat is None:
|
26
|
+
pred_cat = classifier(case_query.case)
|
27
|
+
if not isinstance(pred_cat, dict):
|
28
|
+
pred_cat = {case_query.attribute_name: pred_cat}
|
29
|
+
target = {case_query.attribute_name: case_query.target_value}
|
30
|
+
precision, recall = calculate_precision_and_recall(pred_cat, target)
|
31
|
+
return all(recall) and all(precision)
|
32
|
+
|
33
|
+
|
34
|
+
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
35
|
+
session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
|
36
|
+
"""
|
37
|
+
Load the RDR model of the function if it exists, otherwise create a new one.
|
38
|
+
|
39
|
+
:param func: The function to load the model for.
|
40
|
+
:param model_dir: The directory where the model is stored.
|
41
|
+
:param rdr_type: The type of the RDR model to load.
|
42
|
+
:param session: The SQLAlchemy session to use.
|
43
|
+
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
44
|
+
"""
|
45
|
+
model_path = get_func_rdr_model_path(func, model_dir)
|
46
|
+
if os.path.exists(model_path):
|
47
|
+
rdr = rdr_type.load(model_path)
|
48
|
+
rdr.session = session
|
49
|
+
else:
|
50
|
+
rdr = rdr_type(session=session, **rdr_kwargs)
|
51
|
+
return rdr
|
@@ -9,7 +9,7 @@ from typing_extensions import List, Optional, Tuple, Dict
|
|
9
9
|
from .datastructures.enums import PromptFor
|
10
10
|
from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
11
11
|
from .datastructures.dataclasses import CaseQuery
|
12
|
-
from .utils import extract_dependencies, contains_return_statement
|
12
|
+
from .utils import extract_dependencies, contains_return_statement, make_set
|
13
13
|
|
14
14
|
|
15
15
|
class CustomInteractiveShell(InteractiveShellEmbed):
|
@@ -28,12 +28,12 @@ class CustomInteractiveShell(InteractiveShellEmbed):
|
|
28
28
|
self.ask_exit()
|
29
29
|
return None
|
30
30
|
result = super().run_cell(raw_cell, **kwargs)
|
31
|
-
if
|
31
|
+
if result.error_in_exec is None and result.error_before_exec is None:
|
32
32
|
self.all_lines.append(raw_cell)
|
33
33
|
return result
|
34
34
|
|
35
35
|
|
36
|
-
class
|
36
|
+
class IPythonShell:
|
37
37
|
"""
|
38
38
|
Create an embedded Ipython shell that can be used to prompt the user for input.
|
39
39
|
"""
|
@@ -63,8 +63,14 @@ class IpythonShell:
|
|
63
63
|
"""
|
64
64
|
Run the embedded shell.
|
65
65
|
"""
|
66
|
-
|
67
|
-
|
66
|
+
while True:
|
67
|
+
try:
|
68
|
+
self.shell()
|
69
|
+
self.update_user_input_from_code_lines()
|
70
|
+
break
|
71
|
+
except Exception as e:
|
72
|
+
logging.error(e)
|
73
|
+
print(e)
|
68
74
|
|
69
75
|
def update_user_input_from_code_lines(self):
|
70
76
|
"""
|
@@ -81,20 +87,23 @@ class IpythonShell:
|
|
81
87
|
self.user_input = self.all_code_lines[0].replace('return', '').strip()
|
82
88
|
else:
|
83
89
|
self.user_input = f"def _get_value(case):\n "
|
84
|
-
|
90
|
+
for cl in self.all_code_lines:
|
91
|
+
sub_code_lines = cl.split('\n')
|
92
|
+
self.user_input += '\n '.join(sub_code_lines) + '\n '
|
85
93
|
|
86
94
|
|
87
|
-
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
|
95
|
+
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
|
88
96
|
-> Tuple[Optional[str], Optional[CallableExpression]]:
|
89
97
|
"""
|
90
98
|
Prompt the user for an executable python expression to the given case query.
|
91
99
|
|
92
100
|
:param case_query: The case query to prompt the user for.
|
93
101
|
:param prompt_for: The type of information ask user about.
|
102
|
+
:param prompt_str: The prompt string to display to the user.
|
94
103
|
:return: A callable expression that takes a case and executes user expression on it.
|
95
104
|
"""
|
96
105
|
while True:
|
97
|
-
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
|
106
|
+
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str)
|
98
107
|
if user_input is None:
|
99
108
|
if prompt_for == PromptFor.Conclusion:
|
100
109
|
print("No conclusion provided. Exiting.")
|
@@ -114,21 +123,24 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
|
|
114
123
|
return user_input, callable_expression
|
115
124
|
|
116
125
|
|
117
|
-
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor
|
126
|
+
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
|
127
|
+
prompt_str: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
|
118
128
|
"""
|
119
129
|
Prompt the user for input.
|
120
130
|
|
121
131
|
:param case_query: The case query to prompt the user for.
|
122
132
|
:param prompt_for: The type of information the user should provide for the given case.
|
133
|
+
:param prompt_str: The prompt string to display to the user.
|
123
134
|
:return: The user input, and the executable expression that was parsed from the user input.
|
124
135
|
"""
|
125
|
-
prompt_str
|
136
|
+
if prompt_str is None:
|
137
|
+
prompt_str = f"Give {prompt_for} for {case_query.name}"
|
126
138
|
scope = {'case': case_query.case, **case_query.scope}
|
127
|
-
shell =
|
139
|
+
shell = IPythonShell(scope=scope, header=prompt_str)
|
128
140
|
return prompt_user_input_and_parse_to_expression(shell=shell)
|
129
141
|
|
130
142
|
|
131
|
-
def prompt_user_input_and_parse_to_expression(shell: Optional[
|
143
|
+
def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
|
132
144
|
user_input: Optional[str] = None)\
|
133
145
|
-> Tuple[Optional[str], Optional[ast.AST]]:
|
134
146
|
"""
|
@@ -140,7 +152,7 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
|
|
140
152
|
"""
|
141
153
|
while True:
|
142
154
|
if user_input is None:
|
143
|
-
shell =
|
155
|
+
shell = IPythonShell() if shell is None else shell
|
144
156
|
shell.run()
|
145
157
|
user_input = shell.user_input
|
146
158
|
if user_input is None:
|
@@ -151,4 +163,5 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
|
|
151
163
|
except Exception as e:
|
152
164
|
msg = f"Error parsing expression: {e}"
|
153
165
|
logging.error(msg)
|
166
|
+
print(msg)
|
154
167
|
user_input = None
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
-
import re
|
5
4
|
import sys
|
6
5
|
from abc import ABC, abstractmethod
|
7
6
|
from copy import copy
|
@@ -14,15 +13,16 @@ from ordered_set import OrderedSet
|
|
14
13
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
15
14
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
16
15
|
|
17
|
-
from .datastructures.case import Case, CaseAttribute
|
18
16
|
from .datastructures.callable_expression import CallableExpression
|
17
|
+
from .datastructures.case import Case, CaseAttribute, create_case
|
19
18
|
from .datastructures.dataclasses import CaseQuery
|
20
|
-
from .datastructures.enums import MCRDRMode
|
19
|
+
from .datastructures.enums import MCRDRMode, PromptFor
|
21
20
|
from .experts import Expert, Human
|
21
|
+
from .helpers import is_matching
|
22
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
23
23
|
from .utils import draw_tree, make_set, copy_case, \
|
24
|
-
|
25
|
-
get_case_attribute_type,
|
24
|
+
SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
+
get_case_attribute_type, is_conflicting
|
26
26
|
|
27
27
|
|
28
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -277,7 +277,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
277
277
|
else:
|
278
278
|
return "SCRDR"
|
279
279
|
|
280
|
-
|
281
280
|
@property
|
282
281
|
def case_type(self) -> Type:
|
283
282
|
"""
|
@@ -366,7 +365,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
366
365
|
super().write_to_python_file(file_path, postfix)
|
367
366
|
if self.default_conclusion is not None:
|
368
367
|
with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
|
369
|
-
f.write(f"{' '*4}else:\n{' '*8}return {self.default_conclusion}\n")
|
368
|
+
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
370
369
|
|
371
370
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
372
371
|
defs_file: Optional[str] = None):
|
@@ -377,7 +376,6 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
377
376
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
378
377
|
file.write(if_clause)
|
379
378
|
if rule.refinement:
|
380
|
-
|
381
379
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
382
380
|
defs_file=defs_file)
|
383
381
|
|
@@ -474,7 +472,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
474
472
|
if target and not make_set(rule_conclusion).issubset(good_conclusions):
|
475
473
|
# Rule fired and conclusion is different from target
|
476
474
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
477
|
-
|
475
|
+
len(user_conclusions) > 0)
|
478
476
|
else:
|
479
477
|
# Rule fired and target is correct or there is no target to compare
|
480
478
|
self.add_conclusion(evaluated_rule, case_query.case)
|
@@ -485,11 +483,14 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
485
483
|
self.add_rule_for_case(case_query, expert)
|
486
484
|
# Have to check all rules again to make sure only this new rule fires
|
487
485
|
next_rule = self.start_rule
|
488
|
-
elif add_extra_conclusions
|
486
|
+
elif add_extra_conclusions:
|
489
487
|
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
490
|
-
|
491
|
-
|
488
|
+
new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
|
489
|
+
user_conclusions.extend(new_user_conclusions)
|
490
|
+
if len(new_user_conclusions) > 0:
|
492
491
|
next_rule = self.last_top_rule
|
492
|
+
else:
|
493
|
+
add_extra_conclusions = False
|
493
494
|
evaluated_rule = next_rule
|
494
495
|
return self.conclusions
|
495
496
|
|
@@ -551,13 +552,12 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
551
552
|
"""
|
552
553
|
Stop a wrong conclusion by adding a stopping rule.
|
553
554
|
"""
|
554
|
-
target = case_query.target(case_query.case)
|
555
555
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
556
|
-
if
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
556
|
+
if is_conflicting(rule_conclusion, case_query.target_value):
|
557
|
+
if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
558
|
+
return
|
559
|
+
else:
|
560
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
561
561
|
|
562
562
|
def stop_conclusion(self, case_query: CaseQuery,
|
563
563
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -576,31 +576,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
576
576
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
577
577
|
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
578
578
|
|
579
|
-
@staticmethod
|
580
|
-
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
581
|
-
"""
|
582
|
-
Check if the conclusion is conflicting with the target category.
|
583
|
-
|
584
|
-
:param conclusion: The conclusion to check.
|
585
|
-
:param target: The target category to compare the conclusion with.
|
586
|
-
:return: Whether the conclusion is conflicting with the target category.
|
587
|
-
"""
|
588
|
-
if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
|
589
|
-
return True
|
590
|
-
else:
|
591
|
-
return not make_set(conclusion).issubset(make_set(target))
|
592
|
-
|
593
|
-
@staticmethod
|
594
|
-
def is_same_category_type(conclusion: Any, target: Any) -> bool:
|
595
|
-
"""
|
596
|
-
Check if the conclusion is of the same class as the target category.
|
597
|
-
|
598
|
-
:param conclusion: The conclusion to check.
|
599
|
-
:param target: The target category to compare the conclusion with.
|
600
|
-
:return: Whether the conclusion is of the same class as the target category but has a different value.
|
601
|
-
"""
|
602
|
-
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
603
|
-
|
604
579
|
def conclusion_is_correct(self, case_query: CaseQuery,
|
605
580
|
expert: Expert, evaluated_rule: Rule,
|
606
581
|
add_extra_conclusions: bool) -> bool:
|
@@ -616,7 +591,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
616
591
|
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
617
592
|
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
618
593
|
evaluated_rule.conclusion(case_query.case),
|
619
|
-
targets=case_query.target(case_query.case),
|
620
594
|
current_conclusions=conclusions)):
|
621
595
|
self.add_conclusion(evaluated_rule, case_query.case)
|
622
596
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
@@ -637,23 +611,22 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
637
611
|
conditions = expert.ask_for_conditions(case_query)
|
638
612
|
self.add_top_rule(conditions, case_query.target, case_query.case)
|
639
613
|
|
640
|
-
def
|
614
|
+
def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
|
641
615
|
"""
|
642
|
-
Ask the expert for extra
|
616
|
+
Ask the expert for extra rules when no more conclusions can be made for a case.
|
643
617
|
|
644
618
|
:param expert: The expert to ask for extra conclusions.
|
645
|
-
:param
|
646
|
-
:return: The extra conclusions that the expert has provided.
|
619
|
+
:param case_query: The case query to ask the expert about.
|
620
|
+
:return: The extra conclusions for the rules that the expert has provided.
|
647
621
|
"""
|
648
622
|
extra_conclusions = []
|
649
623
|
conclusions = list(OrderedSet(self.conclusions))
|
650
624
|
if not expert.use_loaded_answers:
|
651
625
|
print("current conclusions:", conclusions)
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
extra_conclusions.append(conclusion)
|
626
|
+
extra_rules = expert.ask_for_extra_rules(case_query)
|
627
|
+
for rule in extra_rules:
|
628
|
+
self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
|
629
|
+
extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
|
657
630
|
return extra_conclusions
|
658
631
|
|
659
632
|
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
@@ -749,7 +722,7 @@ class GeneralRDR(RippleDownRules):
|
|
749
722
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
750
723
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
751
724
|
|
752
|
-
def classify(self, case:
|
725
|
+
def classify(self, case: Any) -> Optional[Dict[str, Any]]:
|
753
726
|
"""
|
754
727
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
755
728
|
the classification until no more categories can be added.
|
@@ -761,7 +734,7 @@ class GeneralRDR(RippleDownRules):
|
|
761
734
|
|
762
735
|
@staticmethod
|
763
736
|
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
764
|
-
case:
|
737
|
+
case: Any) -> Dict[str, Any]:
|
765
738
|
"""
|
766
739
|
Classify a case by going through all classifiers and adding the categories that are classified,
|
767
740
|
and then restarting the classification until no more categories can be added.
|
@@ -771,6 +744,7 @@ class GeneralRDR(RippleDownRules):
|
|
771
744
|
:return: The categories that the case belongs to.
|
772
745
|
"""
|
773
746
|
conclusions = {}
|
747
|
+
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
774
748
|
case_cp = copy_case(case)
|
775
749
|
while True:
|
776
750
|
new_conclusions = {}
|
@@ -801,7 +775,7 @@ class GeneralRDR(RippleDownRules):
|
|
801
775
|
return conclusions
|
802
776
|
|
803
777
|
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
|
804
|
-
->
|
778
|
+
-> Dict[str, Any]:
|
805
779
|
"""
|
806
780
|
Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
|
807
781
|
else the existing RDR of that type will be fitted on the case, and then classification is done and all
|
@@ -816,7 +790,7 @@ class GeneralRDR(RippleDownRules):
|
|
816
790
|
:return: The categories that the case belongs to.
|
817
791
|
"""
|
818
792
|
expert = expert if expert else Human()
|
819
|
-
case_queries =
|
793
|
+
case_queries = make_list(case_queries)
|
820
794
|
assert len(case_queries) > 0, "No case queries provided"
|
821
795
|
case = case_queries[0].case
|
822
796
|
assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
|
@@ -849,10 +823,11 @@ class GeneralRDR(RippleDownRules):
|
|
849
823
|
else:
|
850
824
|
conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
|
851
825
|
**kwargs)
|
852
|
-
if conclusions is not None
|
853
|
-
conclusions
|
854
|
-
|
855
|
-
|
826
|
+
if conclusions is not None:
|
827
|
+
if (not is_iterable(conclusions)) or len(conclusions) > 0:
|
828
|
+
conclusions = {rdr_attribute_name: conclusions}
|
829
|
+
case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
|
830
|
+
self.update_case(case_query_cp, conclusions)
|
856
831
|
case_query.conditions = case_query_cp.conditions
|
857
832
|
|
858
833
|
return self.classify(case)
|
@@ -863,17 +838,17 @@ class GeneralRDR(RippleDownRules):
|
|
863
838
|
Initialize the appropriate RDR type for the target.
|
864
839
|
"""
|
865
840
|
if case_query.mutually_exclusive is not None:
|
866
|
-
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
|
841
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
|
867
842
|
else MultiClassRDR()
|
868
843
|
if case_query.attribute_type in [list, set]:
|
869
844
|
return MultiClassRDR()
|
870
|
-
attribute = getattr(case_query.case, case_query.attribute_name)\
|
845
|
+
attribute = getattr(case_query.case, case_query.attribute_name) \
|
871
846
|
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
872
847
|
if isinstance(attribute, CaseAttribute):
|
873
848
|
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
874
849
|
else MultiClassRDR()
|
875
850
|
else:
|
876
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
|
851
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
|
877
852
|
else SingleClassRDR(default_conclusion=case_query.default_value)
|
878
853
|
|
879
854
|
@staticmethod
|
@@ -895,13 +870,18 @@ class GeneralRDR(RippleDownRules):
|
|
895
870
|
attribute_type = case_query.attribute_type
|
896
871
|
else:
|
897
872
|
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
898
|
-
if isinstance(attribute, set)
|
899
|
-
attribute = set() if attribute is None else attribute
|
873
|
+
if isinstance(attribute, set):
|
900
874
|
for c in conclusion:
|
901
875
|
attribute.update(make_set(c))
|
902
|
-
elif isinstance(attribute, list)
|
876
|
+
elif isinstance(attribute, list):
|
877
|
+
attribute.extend(conclusion)
|
878
|
+
elif any(at in {List, list} for at in attribute_type):
|
903
879
|
attribute = [] if attribute is None else attribute
|
904
880
|
attribute.extend(conclusion)
|
881
|
+
elif any(at in {Set, set} for at in attribute_type):
|
882
|
+
attribute = set() if attribute is None else attribute
|
883
|
+
for c in conclusion:
|
884
|
+
attribute.update(make_set(c))
|
905
885
|
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
906
886
|
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
907
887
|
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
@@ -33,22 +33,25 @@ import ast
|
|
33
33
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
34
34
|
|
35
35
|
|
36
|
-
def
|
36
|
+
def is_conflicting(conclusion: Any, target: Any) -> bool:
|
37
37
|
"""
|
38
|
-
:param
|
39
|
-
:param
|
40
|
-
:
|
41
|
-
:return: Whether the classifier prediction is matching case_query target or not.
|
38
|
+
:param conclusion: The conclusion to check.
|
39
|
+
:param target: The target to compare the conclusion with.
|
40
|
+
:return: Whether the conclusion is conflicting with the target by have different values for same type categories.
|
42
41
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
target
|
50
|
-
|
51
|
-
|
42
|
+
return have_common_types(conclusion, target) and not make_set(conclusion).issubset(make_set(target))
|
43
|
+
|
44
|
+
|
45
|
+
def have_common_types(conclusion: Any, target: Any) -> bool:
|
46
|
+
"""
|
47
|
+
:param conclusion: The conclusion to check.
|
48
|
+
:param target: The target to compare the conclusion with.
|
49
|
+
:return: Whether the conclusion shares some types with the target.
|
50
|
+
"""
|
51
|
+
target_types = {type(t) for t in make_set(target)}
|
52
|
+
conclusion_types = {type(c) for c in make_set(conclusion)}
|
53
|
+
common_types = conclusion_types.intersection(target_types)
|
54
|
+
return len(common_types) > 0
|
52
55
|
|
53
56
|
|
54
57
|
def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
|
@@ -177,7 +180,8 @@ def extract_dependencies(code_lines):
|
|
177
180
|
|
178
181
|
if not isinstance(final_stmt, ast.Return):
|
179
182
|
raise ValueError("Last line is not a return statement")
|
180
|
-
|
183
|
+
if final_stmt.value is None:
|
184
|
+
raise ValueError("Return statement has no value")
|
181
185
|
needed = get_names_used(final_stmt.value)
|
182
186
|
required_lines = []
|
183
187
|
line_map = {id(node): i for i, node in enumerate(tree.body)}
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.63
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -7,19 +7,22 @@ from unittest import TestCase
|
|
7
7
|
|
8
8
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
9
9
|
from ripple_down_rules.experts import Human
|
10
|
-
from ripple_down_rules.
|
11
|
-
from ripple_down_rules.
|
10
|
+
from ripple_down_rules.helpers import is_matching
|
11
|
+
from ripple_down_rules.rdr import GeneralRDR
|
12
12
|
|
13
13
|
|
14
14
|
@dataclass
|
15
15
|
class WorldEntity:
|
16
|
-
world:
|
16
|
+
world: World = field(kw_only=True, repr=False)
|
17
17
|
|
18
18
|
|
19
19
|
@dataclass
|
20
20
|
class Body(WorldEntity):
|
21
21
|
name: str
|
22
22
|
|
23
|
+
def __hash__(self):
|
24
|
+
return hash(self.name)
|
25
|
+
|
23
26
|
|
24
27
|
@dataclass
|
25
28
|
class Handle(Body):
|
@@ -51,11 +54,14 @@ class PrismaticConnection(Connection):
|
|
51
54
|
class World:
|
52
55
|
bodies: List[Body] = field(default_factory=list)
|
53
56
|
connections: List[Connection] = field(default_factory=list)
|
57
|
+
views: List[View] = field(default_factory=list, repr=False)
|
54
58
|
|
55
59
|
|
56
60
|
@dataclass
|
57
61
|
class View(WorldEntity):
|
58
|
-
|
62
|
+
def __init__(self, *args, **kwargs):
|
63
|
+
super().__init__(*args, **kwargs)
|
64
|
+
self.world.views.append(self)
|
59
65
|
|
60
66
|
|
61
67
|
@dataclass
|
@@ -64,17 +70,31 @@ class Drawer(View):
|
|
64
70
|
container: Container
|
65
71
|
correct: Optional[bool] = None
|
66
72
|
|
73
|
+
def __hash__(self):
|
74
|
+
return hash((self.handle.name, self.container.name))
|
75
|
+
|
76
|
+
|
77
|
+
@dataclass
|
78
|
+
class Cabinet(View):
|
79
|
+
container: Container
|
80
|
+
drawers: List[Drawer] = field(default_factory=list)
|
81
|
+
|
82
|
+
def __hash__(self):
|
83
|
+
return hash(tuple([self.container.name] + [hash(drawer) for drawer in self.drawers]))
|
84
|
+
|
67
85
|
|
68
86
|
class TestRDRWorld(TestCase):
|
69
87
|
drawer_case_queries: List[CaseQuery]
|
88
|
+
world: World
|
70
89
|
|
71
90
|
@classmethod
|
72
91
|
def setUpClass(cls):
|
73
92
|
world = World()
|
93
|
+
cls.world = world
|
74
94
|
|
75
95
|
handle = Handle('h1', world=world)
|
76
96
|
handle_2 = Handle('h2', world=world)
|
77
|
-
container_1 = Container('c1',world=world)
|
97
|
+
container_1 = Container('c1', world=world)
|
78
98
|
container_2 = Container('c2', world=world)
|
79
99
|
connection_1 = FixedConnection(container_1, handle, world=world)
|
80
100
|
connection_2 = PrismaticConnection(container_2, container_1, world=world)
|
@@ -82,17 +102,39 @@ class TestRDRWorld(TestCase):
|
|
82
102
|
world.bodies = [handle, container_1, container_2, handle_2]
|
83
103
|
world.connections = [connection_1, connection_2]
|
84
104
|
|
85
|
-
|
86
|
-
|
87
|
-
i = 1
|
105
|
+
all_possible_drawers = []
|
88
106
|
for handle in [body for body in world.bodies if isinstance(body, Handle)]:
|
89
107
|
for container in [body for body in world.bodies if isinstance(body, Container)]:
|
90
108
|
view = Drawer(handle, container, world=world)
|
91
|
-
|
92
|
-
|
109
|
+
all_possible_drawers.append(view)
|
110
|
+
|
111
|
+
print(all_possible_drawers)
|
112
|
+
cls.drawer_case_queries = [CaseQuery(possible_drawer, "correct", (bool,), True, default_value=False)
|
113
|
+
for possible_drawer in all_possible_drawers]
|
93
114
|
|
94
|
-
|
95
|
-
|
115
|
+
def test_view_rdr(self):
|
116
|
+
self.get_view_rdr(use_loaded_answers=True, save_answers=False, append=False)
|
117
|
+
|
118
|
+
def get_view_rdr(self, use_loaded_answers: bool = True, save_answers: bool = False,
|
119
|
+
append: bool = False):
|
120
|
+
expert = Human(use_loaded_answers=use_loaded_answers)
|
121
|
+
filename = os.path.join(os.getcwd(), "test_expert_answers/view_rdr_expert_answers_fit")
|
122
|
+
if use_loaded_answers:
|
123
|
+
expert.load_answers(filename)
|
124
|
+
rdr = GeneralRDR()
|
125
|
+
try:
|
126
|
+
rdr.fit_case([CaseQuery(self.world, "views", (View,), False)], expert=expert,
|
127
|
+
add_extra_conclusions=True)
|
128
|
+
except Exception as e:
|
129
|
+
if append:
|
130
|
+
expert.use_loaded_answers = False
|
131
|
+
rdr.fit_case([CaseQuery(self.world, "views", (View,), False)], expert=expert,
|
132
|
+
add_extra_conclusions=True)
|
133
|
+
else:
|
134
|
+
raise e
|
135
|
+
if save_answers:
|
136
|
+
expert.save_answers(filename, append=append)
|
137
|
+
print(rdr.classify(self.world))
|
96
138
|
|
97
139
|
def test_drawer_rdr(self):
|
98
140
|
self.get_drawer_rdr(use_loaded_answers=True, save_answers=False)
|
@@ -1,27 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
|
3
|
-
from sqlalchemy.orm import Session
|
4
|
-
from typing_extensions import Type, Optional
|
5
|
-
|
6
|
-
from ripple_down_rules.rdr import RippleDownRules
|
7
|
-
from ripple_down_rules.utils import get_func_rdr_model_path
|
8
|
-
|
9
|
-
|
10
|
-
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
11
|
-
session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
|
12
|
-
"""
|
13
|
-
Load the RDR model of the function if it exists, otherwise create a new one.
|
14
|
-
|
15
|
-
:param func: The function to load the model for.
|
16
|
-
:param model_dir: The directory where the model is stored.
|
17
|
-
:param rdr_type: The type of the RDR model to load.
|
18
|
-
:param session: The SQLAlchemy session to use.
|
19
|
-
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
20
|
-
"""
|
21
|
-
model_path = get_func_rdr_model_path(func, model_dir)
|
22
|
-
if os.path.exists(model_path):
|
23
|
-
rdr = rdr_type.load(model_path)
|
24
|
-
rdr.session = session
|
25
|
-
else:
|
26
|
-
rdr = rdr_type(session=session, **rdr_kwargs)
|
27
|
-
return rdr
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules/rdr_decorators.py
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.63}/src/ripple_down_rules.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|