ripple-down-rules 0.1.61__tar.gz → 0.1.62__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/PKG-INFO +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/pyproject.toml +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/case.py +7 -5
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/dataclasses.py +5 -4
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/enums.py +5 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/experts.py +43 -52
- ripple_down_rules-0.1.62/src/ripple_down_rules/helpers.py +51 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/prompt.py +26 -13
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rdr.py +38 -59
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/utils.py +19 -15
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_rdr_world.py +52 -10
- ripple_down_rules-0.1.61/src/ripple_down_rules/helpers.py +0 -27
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/LICENSE +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/README.md +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/setup.cfg +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datasets.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rdr_decorators.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rules.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_json_serialization.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_on_mutagenic.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_rdr.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_relational_rdr.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_relational_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.61
|
3
|
+
Version: 0.1.62
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.1.61"
|
9
|
+
version = "0.1.62"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/case.py
RENAMED
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
24
24
|
the names of the attributes and the values are the attributes. All are stored in lower case.
|
25
25
|
"""
|
26
26
|
|
27
|
-
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
27
|
+
def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
|
28
|
+
_name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
|
28
29
|
"""
|
29
30
|
Create a new row.
|
30
31
|
|
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
34
35
|
:param kwargs: The attributes of the row.
|
35
36
|
"""
|
36
37
|
super().__init__(kwargs)
|
38
|
+
self._original_object = original_object
|
37
39
|
self._obj_type: Type = _obj_type
|
38
40
|
self._id: Hashable = _id if _id is not None else id(self)
|
39
41
|
self._name: str = _name if _name is not None else self._obj_type.__name__
|
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
221
223
|
return obj
|
222
224
|
if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
|
223
225
|
or (obj.__class__ in [MetaData, registry])):
|
224
|
-
return Case(type(obj), _id=id(obj), _name=obj_name,
|
226
|
+
return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
|
225
227
|
**{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
|
226
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
228
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
227
229
|
for attr in dir(obj):
|
228
230
|
if attr.startswith("_") or callable(getattr(obj, attr)):
|
229
231
|
continue
|
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
|
|
251
253
|
:return: The updated/created case.
|
252
254
|
"""
|
253
255
|
if case is None:
|
254
|
-
case = Case(type(obj), _id=id(obj), _name=obj_name)
|
256
|
+
case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
|
255
257
|
if isinstance(attr_value, (dict, UserDict)):
|
256
258
|
case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
|
257
259
|
if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
|
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
280
282
|
"""
|
281
283
|
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
282
284
|
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
283
|
-
attr_case = Case(_type, _id=id(attr_value), _name=name)
|
285
|
+
attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
|
284
286
|
case_attr = CaseAttribute(values)
|
285
287
|
for idx, val in enumerate(values):
|
286
288
|
sub_attr_case = create_case(val, recursion_idx=recursion_idx,
|
@@ -8,7 +8,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
|
|
8
8
|
|
9
9
|
from .callable_expression import CallableExpression
|
10
10
|
from .case import create_case, Case
|
11
|
-
from ..utils import copy_case, make_list
|
11
|
+
from ..utils import copy_case, make_list, make_set
|
12
12
|
|
13
13
|
|
14
14
|
@dataclass
|
@@ -88,9 +88,10 @@ class CaseQuery:
|
|
88
88
|
"""
|
89
89
|
:return: The type of the attribute.
|
90
90
|
"""
|
91
|
-
self.
|
92
|
-
|
93
|
-
|
91
|
+
if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
|
92
|
+
self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
|
93
|
+
elif not isinstance(self._attribute_types, tuple):
|
94
|
+
self._attribute_types = tuple(make_list(self._attribute_types))
|
94
95
|
return self._attribute_types
|
95
96
|
|
96
97
|
@attribute_type.setter
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/datastructures/enums.py
RENAMED
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
|
|
53
53
|
|
54
54
|
class PromptFor(Enum):
|
55
55
|
"""
|
56
|
-
The reason of the prompt. (e.g. get conditions, or
|
56
|
+
The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
|
57
57
|
"""
|
58
58
|
Conditions: str = "conditions"
|
59
59
|
"""
|
@@ -63,6 +63,10 @@ class PromptFor(Enum):
|
|
63
63
|
"""
|
64
64
|
Prompt for rule conclusion about a case.
|
65
65
|
"""
|
66
|
+
Affirmation: str = "affirmation"
|
67
|
+
"""
|
68
|
+
Prompt for rule conclusion affirmation about a case.
|
69
|
+
"""
|
66
70
|
|
67
71
|
def __str__(self):
|
68
72
|
return self.name
|
@@ -10,7 +10,7 @@ from .datastructures.callable_expression import CallableExpression
|
|
10
10
|
from .datastructures.enums import PromptFor
|
11
11
|
from .datastructures.dataclasses import CaseQuery
|
12
12
|
from .datastructures.case import show_current_and_corner_cases
|
13
|
-
from .prompt import prompt_user_for_expression
|
13
|
+
from .prompt import prompt_user_for_expression, IPythonShell
|
14
14
|
from .utils import get_all_subclasses, make_list
|
15
15
|
|
16
16
|
if TYPE_CHECKING:
|
@@ -51,28 +51,24 @@ class Expert(ABC):
|
|
51
51
|
pass
|
52
52
|
|
53
53
|
@abstractmethod
|
54
|
-
def
|
55
|
-
-> Dict[CaseAttribute, CallableExpression]:
|
54
|
+
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
56
55
|
"""
|
57
|
-
Ask the expert to provide extra
|
58
|
-
that category.
|
56
|
+
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
59
57
|
|
60
|
-
:param
|
61
|
-
:
|
62
|
-
|
58
|
+
:param case_query: The case query containing the case to classify.
|
59
|
+
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
60
|
+
conclusion and conditions for the rule.
|
63
61
|
"""
|
64
62
|
pass
|
65
63
|
|
66
64
|
@abstractmethod
|
67
|
-
def ask_if_conclusion_is_correct(self,
|
68
|
-
|
69
|
-
current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
|
65
|
+
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
66
|
+
current_conclusions: Any) -> bool:
|
70
67
|
"""
|
71
68
|
Ask the expert if the conclusion is correct.
|
72
69
|
|
73
|
-
:param
|
70
|
+
:param case_query: The case query about which the expert should answer.
|
74
71
|
:param conclusion: The conclusion to check.
|
75
|
-
:param targets: The target categories to compare the case with.
|
76
72
|
:param current_conclusions: The current conclusions for the case.
|
77
73
|
"""
|
78
74
|
pass
|
@@ -130,6 +126,24 @@ class Human(Expert):
|
|
130
126
|
last_evaluated_rule=last_evaluated_rule)
|
131
127
|
return self._get_conditions(case_query)
|
132
128
|
|
129
|
+
def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
|
130
|
+
"""
|
131
|
+
Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
|
132
|
+
|
133
|
+
:param case_query: The case query containing the case to classify.
|
134
|
+
:return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
|
135
|
+
conclusion and conditions for the rule.
|
136
|
+
"""
|
137
|
+
rules = []
|
138
|
+
while True:
|
139
|
+
conclusion = self.ask_for_conclusion(case_query)
|
140
|
+
if conclusion is None:
|
141
|
+
break
|
142
|
+
conditions = self._get_conditions(case_query)
|
143
|
+
rules.append({PromptFor.Conclusion: conclusion,
|
144
|
+
PromptFor.Conditions: conditions})
|
145
|
+
return rules
|
146
|
+
|
133
147
|
def _get_conditions(self, case_query: CaseQuery) \
|
134
148
|
-> CallableExpression:
|
135
149
|
"""
|
@@ -151,24 +165,6 @@ class Human(Expert):
|
|
151
165
|
case_query.conditions = condition
|
152
166
|
return condition
|
153
167
|
|
154
|
-
def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
|
155
|
-
-> Dict[CaseAttribute, CallableExpression]:
|
156
|
-
"""
|
157
|
-
Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
|
158
|
-
that category.
|
159
|
-
|
160
|
-
:param case: The case to classify.
|
161
|
-
:param current_conclusions: The current conclusions for the case.
|
162
|
-
:return: The extra conclusions for the case.
|
163
|
-
"""
|
164
|
-
extra_conclusions = {}
|
165
|
-
while True:
|
166
|
-
category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
|
167
|
-
if not category:
|
168
|
-
break
|
169
|
-
extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
|
170
|
-
return extra_conclusions
|
171
|
-
|
172
168
|
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
173
169
|
"""
|
174
170
|
Ask the expert to provide a conclusion for the case.
|
@@ -212,44 +208,39 @@ class Human(Expert):
|
|
212
208
|
:param category_name: The name of the category to ask about.
|
213
209
|
"""
|
214
210
|
question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
|
215
|
-
return not self.
|
211
|
+
return not self.ask_for_affirmation(question)
|
216
212
|
|
217
|
-
def ask_if_conclusion_is_correct(self,
|
218
|
-
|
219
|
-
current_conclusions: Optional[List[Any]] = None) -> bool:
|
213
|
+
def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
|
214
|
+
current_conclusions: Any) -> bool:
|
220
215
|
"""
|
221
216
|
Ask the expert if the conclusion is correct.
|
222
217
|
|
223
|
-
:param
|
218
|
+
:param case_query: The case query about which the expert should answer.
|
224
219
|
:param conclusion: The conclusion to check.
|
225
|
-
:param targets: The target categories to compare the case with.
|
226
220
|
:param current_conclusions: The current conclusions for the case.
|
227
221
|
"""
|
228
|
-
question = ""
|
229
222
|
if not self.use_loaded_answers:
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
|
234
|
-
f"\n{str(x)}"
|
235
|
-
return self.ask_yes_no_question(question)
|
223
|
+
print(f"Current conclusions: {current_conclusions}")
|
224
|
+
return self.ask_for_affirmation(case_query,
|
225
|
+
f"Is the conclusion {conclusion} correct for the case (True/False):")
|
236
226
|
|
237
|
-
def
|
227
|
+
def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
|
238
228
|
"""
|
239
229
|
Ask the expert a yes or no question.
|
240
230
|
|
241
|
-
:param
|
231
|
+
:param case_query: The case query about which the expert should answer.
|
232
|
+
:param question: The question to ask the expert.
|
242
233
|
:return: The answer to the question.
|
243
234
|
"""
|
244
|
-
if not self.use_loaded_answers:
|
245
|
-
print(question)
|
246
235
|
while True:
|
247
236
|
if self.use_loaded_answers:
|
248
237
|
answer = self.all_expert_answers.pop(0)
|
249
238
|
else:
|
250
|
-
|
251
|
-
|
252
|
-
if answer
|
239
|
+
_, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
|
240
|
+
answer = expression(case_query.case)
|
241
|
+
if answer:
|
242
|
+
self.all_expert_answers.append(True)
|
253
243
|
return True
|
254
|
-
|
244
|
+
else:
|
245
|
+
self.all_expert_answers.append(False)
|
255
246
|
return False
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
|
5
|
+
from .datastructures.dataclasses import CaseQuery
|
6
|
+
from sqlalchemy.orm import Session
|
7
|
+
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
|
8
|
+
|
9
|
+
from .utils import get_func_rdr_model_path
|
10
|
+
from .utils import calculate_precision_and_recall
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from .rdr import RippleDownRules
|
14
|
+
|
15
|
+
|
16
|
+
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
17
|
+
"""
|
18
|
+
:param classifier: The RDR classifier to check the prediction of.
|
19
|
+
:param case_query: The case query to check.
|
20
|
+
:param pred_cat: The predicted category.
|
21
|
+
:return: Whether the classifier prediction is matching case_query target or not.
|
22
|
+
"""
|
23
|
+
if case_query.target is None:
|
24
|
+
return False
|
25
|
+
if pred_cat is None:
|
26
|
+
pred_cat = classifier(case_query.case)
|
27
|
+
if not isinstance(pred_cat, dict):
|
28
|
+
pred_cat = {case_query.attribute_name: pred_cat}
|
29
|
+
target = {case_query.attribute_name: case_query.target_value}
|
30
|
+
precision, recall = calculate_precision_and_recall(pred_cat, target)
|
31
|
+
return all(recall) and all(precision)
|
32
|
+
|
33
|
+
|
34
|
+
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
35
|
+
session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
|
36
|
+
"""
|
37
|
+
Load the RDR model of the function if it exists, otherwise create a new one.
|
38
|
+
|
39
|
+
:param func: The function to load the model for.
|
40
|
+
:param model_dir: The directory where the model is stored.
|
41
|
+
:param rdr_type: The type of the RDR model to load.
|
42
|
+
:param session: The SQLAlchemy session to use.
|
43
|
+
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
44
|
+
"""
|
45
|
+
model_path = get_func_rdr_model_path(func, model_dir)
|
46
|
+
if os.path.exists(model_path):
|
47
|
+
rdr = rdr_type.load(model_path)
|
48
|
+
rdr.session = session
|
49
|
+
else:
|
50
|
+
rdr = rdr_type(session=session, **rdr_kwargs)
|
51
|
+
return rdr
|
@@ -9,7 +9,7 @@ from typing_extensions import List, Optional, Tuple, Dict
|
|
9
9
|
from .datastructures.enums import PromptFor
|
10
10
|
from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
11
11
|
from .datastructures.dataclasses import CaseQuery
|
12
|
-
from .utils import extract_dependencies, contains_return_statement
|
12
|
+
from .utils import extract_dependencies, contains_return_statement, make_set
|
13
13
|
|
14
14
|
|
15
15
|
class CustomInteractiveShell(InteractiveShellEmbed):
|
@@ -28,12 +28,12 @@ class CustomInteractiveShell(InteractiveShellEmbed):
|
|
28
28
|
self.ask_exit()
|
29
29
|
return None
|
30
30
|
result = super().run_cell(raw_cell, **kwargs)
|
31
|
-
if
|
31
|
+
if result.error_in_exec is None and result.error_before_exec is None:
|
32
32
|
self.all_lines.append(raw_cell)
|
33
33
|
return result
|
34
34
|
|
35
35
|
|
36
|
-
class
|
36
|
+
class IPythonShell:
|
37
37
|
"""
|
38
38
|
Create an embedded Ipython shell that can be used to prompt the user for input.
|
39
39
|
"""
|
@@ -63,8 +63,14 @@ class IpythonShell:
|
|
63
63
|
"""
|
64
64
|
Run the embedded shell.
|
65
65
|
"""
|
66
|
-
|
67
|
-
|
66
|
+
while True:
|
67
|
+
try:
|
68
|
+
self.shell()
|
69
|
+
self.update_user_input_from_code_lines()
|
70
|
+
break
|
71
|
+
except Exception as e:
|
72
|
+
logging.error(e)
|
73
|
+
print(e)
|
68
74
|
|
69
75
|
def update_user_input_from_code_lines(self):
|
70
76
|
"""
|
@@ -81,20 +87,23 @@ class IpythonShell:
|
|
81
87
|
self.user_input = self.all_code_lines[0].replace('return', '').strip()
|
82
88
|
else:
|
83
89
|
self.user_input = f"def _get_value(case):\n "
|
84
|
-
|
90
|
+
for cl in self.all_code_lines:
|
91
|
+
sub_code_lines = cl.split('\n')
|
92
|
+
self.user_input += '\n '.join(sub_code_lines) + '\n '
|
85
93
|
|
86
94
|
|
87
|
-
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
|
95
|
+
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
|
88
96
|
-> Tuple[Optional[str], Optional[CallableExpression]]:
|
89
97
|
"""
|
90
98
|
Prompt the user for an executable python expression to the given case query.
|
91
99
|
|
92
100
|
:param case_query: The case query to prompt the user for.
|
93
101
|
:param prompt_for: The type of information ask user about.
|
102
|
+
:param prompt_str: The prompt string to display to the user.
|
94
103
|
:return: A callable expression that takes a case and executes user expression on it.
|
95
104
|
"""
|
96
105
|
while True:
|
97
|
-
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
|
106
|
+
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str)
|
98
107
|
if user_input is None:
|
99
108
|
if prompt_for == PromptFor.Conclusion:
|
100
109
|
print("No conclusion provided. Exiting.")
|
@@ -114,21 +123,24 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
|
|
114
123
|
return user_input, callable_expression
|
115
124
|
|
116
125
|
|
117
|
-
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor
|
126
|
+
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
|
127
|
+
prompt_str: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
|
118
128
|
"""
|
119
129
|
Prompt the user for input.
|
120
130
|
|
121
131
|
:param case_query: The case query to prompt the user for.
|
122
132
|
:param prompt_for: The type of information the user should provide for the given case.
|
133
|
+
:param prompt_str: The prompt string to display to the user.
|
123
134
|
:return: The user input, and the executable expression that was parsed from the user input.
|
124
135
|
"""
|
125
|
-
prompt_str
|
136
|
+
if prompt_str is None:
|
137
|
+
prompt_str = f"Give {prompt_for} for {case_query.name}"
|
126
138
|
scope = {'case': case_query.case, **case_query.scope}
|
127
|
-
shell =
|
139
|
+
shell = IPythonShell(scope=scope, header=prompt_str)
|
128
140
|
return prompt_user_input_and_parse_to_expression(shell=shell)
|
129
141
|
|
130
142
|
|
131
|
-
def prompt_user_input_and_parse_to_expression(shell: Optional[
|
143
|
+
def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
|
132
144
|
user_input: Optional[str] = None)\
|
133
145
|
-> Tuple[Optional[str], Optional[ast.AST]]:
|
134
146
|
"""
|
@@ -140,7 +152,7 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
|
|
140
152
|
"""
|
141
153
|
while True:
|
142
154
|
if user_input is None:
|
143
|
-
shell =
|
155
|
+
shell = IPythonShell() if shell is None else shell
|
144
156
|
shell.run()
|
145
157
|
user_input = shell.user_input
|
146
158
|
if user_input is None:
|
@@ -151,4 +163,5 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
|
|
151
163
|
except Exception as e:
|
152
164
|
msg = f"Error parsing expression: {e}"
|
153
165
|
logging.error(msg)
|
166
|
+
print(msg)
|
154
167
|
user_input = None
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import importlib
|
4
|
-
import re
|
5
4
|
import sys
|
6
5
|
from abc import ABC, abstractmethod
|
7
6
|
from copy import copy
|
@@ -14,15 +13,16 @@ from ordered_set import OrderedSet
|
|
14
13
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
15
14
|
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
|
16
15
|
|
17
|
-
from .datastructures.case import Case, CaseAttribute
|
18
16
|
from .datastructures.callable_expression import CallableExpression
|
17
|
+
from .datastructures.case import Case, CaseAttribute, create_case
|
19
18
|
from .datastructures.dataclasses import CaseQuery
|
20
|
-
from .datastructures.enums import MCRDRMode
|
19
|
+
from .datastructures.enums import MCRDRMode, PromptFor
|
21
20
|
from .experts import Expert, Human
|
21
|
+
from .helpers import is_matching
|
22
22
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
23
23
|
from .utils import draw_tree, make_set, copy_case, \
|
24
|
-
|
25
|
-
get_case_attribute_type,
|
24
|
+
SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
|
25
|
+
get_case_attribute_type, is_conflicting
|
26
26
|
|
27
27
|
|
28
28
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -277,7 +277,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
277
277
|
else:
|
278
278
|
return "SCRDR"
|
279
279
|
|
280
|
-
|
281
280
|
@property
|
282
281
|
def case_type(self) -> Type:
|
283
282
|
"""
|
@@ -366,7 +365,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
366
365
|
super().write_to_python_file(file_path, postfix)
|
367
366
|
if self.default_conclusion is not None:
|
368
367
|
with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
|
369
|
-
f.write(f"{' '*4}else:\n{' '*8}return {self.default_conclusion}\n")
|
368
|
+
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
370
369
|
|
371
370
|
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
|
372
371
|
defs_file: Optional[str] = None):
|
@@ -377,7 +376,6 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
377
376
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
378
377
|
file.write(if_clause)
|
379
378
|
if rule.refinement:
|
380
|
-
|
381
379
|
self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
|
382
380
|
defs_file=defs_file)
|
383
381
|
|
@@ -474,7 +472,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
474
472
|
if target and not make_set(rule_conclusion).issubset(good_conclusions):
|
475
473
|
# Rule fired and conclusion is different from target
|
476
474
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
|
477
|
-
|
475
|
+
len(user_conclusions) > 0)
|
478
476
|
else:
|
479
477
|
# Rule fired and target is correct or there is no target to compare
|
480
478
|
self.add_conclusion(evaluated_rule, case_query.case)
|
@@ -485,11 +483,14 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
485
483
|
self.add_rule_for_case(case_query, expert)
|
486
484
|
# Have to check all rules again to make sure only this new rule fires
|
487
485
|
next_rule = self.start_rule
|
488
|
-
elif add_extra_conclusions
|
486
|
+
elif add_extra_conclusions:
|
489
487
|
# No more conclusions can be made, ask the expert for extra conclusions if needed.
|
490
|
-
|
491
|
-
|
488
|
+
new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
|
489
|
+
user_conclusions.extend(new_user_conclusions)
|
490
|
+
if len(new_user_conclusions) > 0:
|
492
491
|
next_rule = self.last_top_rule
|
492
|
+
else:
|
493
|
+
add_extra_conclusions = False
|
493
494
|
evaluated_rule = next_rule
|
494
495
|
return self.conclusions
|
495
496
|
|
@@ -551,13 +552,12 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
551
552
|
"""
|
552
553
|
Stop a wrong conclusion by adding a stopping rule.
|
553
554
|
"""
|
554
|
-
target = case_query.target(case_query.case)
|
555
555
|
rule_conclusion = evaluated_rule.conclusion(case_query.case)
|
556
|
-
if
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
556
|
+
if is_conflicting(rule_conclusion, case_query.target_value):
|
557
|
+
if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
|
558
|
+
return
|
559
|
+
else:
|
560
|
+
self.stop_conclusion(case_query, expert, evaluated_rule)
|
561
561
|
|
562
562
|
def stop_conclusion(self, case_query: CaseQuery,
|
563
563
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -576,31 +576,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
576
576
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
577
577
|
self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
|
578
578
|
|
579
|
-
@staticmethod
|
580
|
-
def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
|
581
|
-
"""
|
582
|
-
Check if the conclusion is conflicting with the target category.
|
583
|
-
|
584
|
-
:param conclusion: The conclusion to check.
|
585
|
-
:param target: The target category to compare the conclusion with.
|
586
|
-
:return: Whether the conclusion is conflicting with the target category.
|
587
|
-
"""
|
588
|
-
if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
|
589
|
-
return True
|
590
|
-
else:
|
591
|
-
return not make_set(conclusion).issubset(make_set(target))
|
592
|
-
|
593
|
-
@staticmethod
|
594
|
-
def is_same_category_type(conclusion: Any, target: Any) -> bool:
|
595
|
-
"""
|
596
|
-
Check if the conclusion is of the same class as the target category.
|
597
|
-
|
598
|
-
:param conclusion: The conclusion to check.
|
599
|
-
:param target: The target category to compare the conclusion with.
|
600
|
-
:return: Whether the conclusion is of the same class as the target category but has a different value.
|
601
|
-
"""
|
602
|
-
return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
|
603
|
-
|
604
579
|
def conclusion_is_correct(self, case_query: CaseQuery,
|
605
580
|
expert: Expert, evaluated_rule: Rule,
|
606
581
|
add_extra_conclusions: bool) -> bool:
|
@@ -616,7 +591,6 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
616
591
|
conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
|
617
592
|
if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
|
618
593
|
evaluated_rule.conclusion(case_query.case),
|
619
|
-
targets=case_query.target(case_query.case),
|
620
594
|
current_conclusions=conclusions)):
|
621
595
|
self.add_conclusion(evaluated_rule, case_query.case)
|
622
596
|
self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
|
@@ -637,23 +611,22 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
637
611
|
conditions = expert.ask_for_conditions(case_query)
|
638
612
|
self.add_top_rule(conditions, case_query.target, case_query.case)
|
639
613
|
|
640
|
-
def
|
614
|
+
def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
|
641
615
|
"""
|
642
|
-
Ask the expert for extra
|
616
|
+
Ask the expert for extra rules when no more conclusions can be made for a case.
|
643
617
|
|
644
618
|
:param expert: The expert to ask for extra conclusions.
|
645
|
-
:param
|
646
|
-
:return: The extra conclusions that the expert has provided.
|
619
|
+
:param case_query: The case query to ask the expert about.
|
620
|
+
:return: The extra conclusions for the rules that the expert has provided.
|
647
621
|
"""
|
648
622
|
extra_conclusions = []
|
649
623
|
conclusions = list(OrderedSet(self.conclusions))
|
650
624
|
if not expert.use_loaded_answers:
|
651
625
|
print("current conclusions:", conclusions)
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
extra_conclusions.append(conclusion)
|
626
|
+
extra_rules = expert.ask_for_extra_rules(case_query)
|
627
|
+
for rule in extra_rules:
|
628
|
+
self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
|
629
|
+
extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
|
657
630
|
return extra_conclusions
|
658
631
|
|
659
632
|
def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
|
@@ -771,6 +744,7 @@ class GeneralRDR(RippleDownRules):
|
|
771
744
|
:return: The categories that the case belongs to.
|
772
745
|
"""
|
773
746
|
conclusions = {}
|
747
|
+
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
774
748
|
case_cp = copy_case(case)
|
775
749
|
while True:
|
776
750
|
new_conclusions = {}
|
@@ -863,17 +837,17 @@ class GeneralRDR(RippleDownRules):
|
|
863
837
|
Initialize the appropriate RDR type for the target.
|
864
838
|
"""
|
865
839
|
if case_query.mutually_exclusive is not None:
|
866
|
-
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
|
840
|
+
return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
|
867
841
|
else MultiClassRDR()
|
868
842
|
if case_query.attribute_type in [list, set]:
|
869
843
|
return MultiClassRDR()
|
870
|
-
attribute = getattr(case_query.case, case_query.attribute_name)\
|
844
|
+
attribute = getattr(case_query.case, case_query.attribute_name) \
|
871
845
|
if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
|
872
846
|
if isinstance(attribute, CaseAttribute):
|
873
847
|
return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
|
874
848
|
else MultiClassRDR()
|
875
849
|
else:
|
876
|
-
return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
|
850
|
+
return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
|
877
851
|
else SingleClassRDR(default_conclusion=case_query.default_value)
|
878
852
|
|
879
853
|
@staticmethod
|
@@ -895,13 +869,18 @@ class GeneralRDR(RippleDownRules):
|
|
895
869
|
attribute_type = case_query.attribute_type
|
896
870
|
else:
|
897
871
|
attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
|
898
|
-
if isinstance(attribute, set)
|
899
|
-
attribute = set() if attribute is None else attribute
|
872
|
+
if isinstance(attribute, set):
|
900
873
|
for c in conclusion:
|
901
874
|
attribute.update(make_set(c))
|
902
|
-
elif isinstance(attribute, list)
|
875
|
+
elif isinstance(attribute, list):
|
876
|
+
attribute.extend(conclusion)
|
877
|
+
elif any(at in {List, list} for at in attribute_type):
|
903
878
|
attribute = [] if attribute is None else attribute
|
904
879
|
attribute.extend(conclusion)
|
880
|
+
elif any(at in {Set, set} for at in attribute_type):
|
881
|
+
attribute = set() if attribute is None else attribute
|
882
|
+
for c in conclusion:
|
883
|
+
attribute.update(make_set(c))
|
905
884
|
elif is_iterable(conclusion) and len(conclusion) == 1 \
|
906
885
|
and any(at is type(list(conclusion)[0]) for at in attribute_type):
|
907
886
|
setattr(case_query.case, conclusion_name, list(conclusion)[0])
|
@@ -33,22 +33,25 @@ import ast
|
|
33
33
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
34
34
|
|
35
35
|
|
36
|
-
def
|
36
|
+
def is_conflicting(conclusion: Any, target: Any) -> bool:
|
37
37
|
"""
|
38
|
-
:param
|
39
|
-
:param
|
40
|
-
:
|
41
|
-
:return: Whether the classifier prediction is matching case_query target or not.
|
38
|
+
:param conclusion: The conclusion to check.
|
39
|
+
:param target: The target to compare the conclusion with.
|
40
|
+
:return: Whether the conclusion is conflicting with the target by have different values for same type categories.
|
42
41
|
"""
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
target
|
50
|
-
|
51
|
-
|
42
|
+
return have_common_types(conclusion, target) and not make_set(conclusion).issubset(make_set(target))
|
43
|
+
|
44
|
+
|
45
|
+
def have_common_types(conclusion: Any, target: Any) -> bool:
|
46
|
+
"""
|
47
|
+
:param conclusion: The conclusion to check.
|
48
|
+
:param target: The target to compare the conclusion with.
|
49
|
+
:return: Whether the conclusion shares some types with the target.
|
50
|
+
"""
|
51
|
+
target_types = {type(t) for t in make_set(target)}
|
52
|
+
conclusion_types = {type(c) for c in make_set(conclusion)}
|
53
|
+
common_types = conclusion_types.intersection(target_types)
|
54
|
+
return len(common_types) > 0
|
52
55
|
|
53
56
|
|
54
57
|
def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
|
@@ -177,7 +180,8 @@ def extract_dependencies(code_lines):
|
|
177
180
|
|
178
181
|
if not isinstance(final_stmt, ast.Return):
|
179
182
|
raise ValueError("Last line is not a return statement")
|
180
|
-
|
183
|
+
if final_stmt.value is None:
|
184
|
+
raise ValueError("Return statement has no value")
|
181
185
|
needed = get_names_used(final_stmt.value)
|
182
186
|
required_lines = []
|
183
187
|
line_map = {id(node): i for i, node in enumerate(tree.body)}
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.62
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -8,18 +8,21 @@ from unittest import TestCase
|
|
8
8
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
9
9
|
from ripple_down_rules.experts import Human
|
10
10
|
from ripple_down_rules.rdr import SingleClassRDR, GeneralRDR
|
11
|
-
from ripple_down_rules.
|
11
|
+
from ripple_down_rules.helpers import is_matching
|
12
12
|
|
13
13
|
|
14
14
|
@dataclass
|
15
15
|
class WorldEntity:
|
16
|
-
world:
|
16
|
+
world: World = field(kw_only=True, repr=False)
|
17
17
|
|
18
18
|
|
19
19
|
@dataclass
|
20
20
|
class Body(WorldEntity):
|
21
21
|
name: str
|
22
22
|
|
23
|
+
def __hash__(self):
|
24
|
+
return hash(self.name)
|
25
|
+
|
23
26
|
|
24
27
|
@dataclass
|
25
28
|
class Handle(Body):
|
@@ -51,11 +54,14 @@ class PrismaticConnection(Connection):
|
|
51
54
|
class World:
|
52
55
|
bodies: List[Body] = field(default_factory=list)
|
53
56
|
connections: List[Connection] = field(default_factory=list)
|
57
|
+
views: List[View] = field(default_factory=list, repr=False)
|
54
58
|
|
55
59
|
|
56
60
|
@dataclass
|
57
61
|
class View(WorldEntity):
|
58
|
-
|
62
|
+
def __init__(self, *args, **kwargs):
|
63
|
+
super().__init__(*args, **kwargs)
|
64
|
+
self.world.views.append(self)
|
59
65
|
|
60
66
|
|
61
67
|
@dataclass
|
@@ -64,13 +70,27 @@ class Drawer(View):
|
|
64
70
|
container: Container
|
65
71
|
correct: Optional[bool] = None
|
66
72
|
|
73
|
+
def __hash__(self):
|
74
|
+
return hash((self.handle.name, self.container.name))
|
75
|
+
|
76
|
+
|
77
|
+
@dataclass
|
78
|
+
class Cabinet(View):
|
79
|
+
container: Container
|
80
|
+
drawers: List[Drawer] = field(default_factory=list)
|
81
|
+
|
82
|
+
def __hash__(self):
|
83
|
+
return hash(tuple([self.container.name] + [hash(drawer) for drawer in self.drawers]))
|
84
|
+
|
67
85
|
|
68
86
|
class TestRDRWorld(TestCase):
|
69
87
|
drawer_case_queries: List[CaseQuery]
|
88
|
+
world: World
|
70
89
|
|
71
90
|
@classmethod
|
72
91
|
def setUpClass(cls):
|
73
92
|
world = World()
|
93
|
+
cls.world = world
|
74
94
|
|
75
95
|
handle = Handle('h1', world=world)
|
76
96
|
handle_2 = Handle('h2', world=world)
|
@@ -82,17 +102,39 @@ class TestRDRWorld(TestCase):
|
|
82
102
|
world.bodies = [handle, container_1, container_2, handle_2]
|
83
103
|
world.connections = [connection_1, connection_2]
|
84
104
|
|
85
|
-
|
86
|
-
|
87
|
-
i = 1
|
105
|
+
all_possible_drawers = []
|
88
106
|
for handle in [body for body in world.bodies if isinstance(body, Handle)]:
|
89
107
|
for container in [body for body in world.bodies if isinstance(body, Container)]:
|
90
108
|
view = Drawer(handle, container, world=world)
|
91
|
-
|
92
|
-
|
109
|
+
all_possible_drawers.append(view)
|
110
|
+
|
111
|
+
print(all_possible_drawers)
|
112
|
+
cls.drawer_case_queries = [CaseQuery(possible_drawer, "correct", bool, True, default_value=False)
|
113
|
+
for possible_drawer in all_possible_drawers]
|
93
114
|
|
94
|
-
|
95
|
-
|
115
|
+
def test_view_rdr(self):
|
116
|
+
self.get_view_rdr(use_loaded_answers=True, save_answers=False, append=False)
|
117
|
+
|
118
|
+
def get_view_rdr(self, use_loaded_answers: bool = True, save_answers: bool = False,
|
119
|
+
append: bool = False):
|
120
|
+
expert = Human(use_loaded_answers=use_loaded_answers)
|
121
|
+
filename = os.path.join(os.getcwd(), "test_expert_answers/view_rdr_expert_answers_fit")
|
122
|
+
if use_loaded_answers:
|
123
|
+
expert.load_answers(filename)
|
124
|
+
rdr = GeneralRDR()
|
125
|
+
try:
|
126
|
+
rdr.fit_case(CaseQuery(self.world, "views", View, False), expert=expert,
|
127
|
+
add_extra_conclusions=True)
|
128
|
+
except Exception as e:
|
129
|
+
if append:
|
130
|
+
expert.use_loaded_answers = False
|
131
|
+
rdr.fit_case(CaseQuery(self.world, "views", View, False), expert=expert,
|
132
|
+
add_extra_conclusions=True)
|
133
|
+
else:
|
134
|
+
raise e
|
135
|
+
if save_answers:
|
136
|
+
expert.save_answers(filename, append=append)
|
137
|
+
print(rdr.classify(self.world))
|
96
138
|
|
97
139
|
def test_drawer_rdr(self):
|
98
140
|
self.get_drawer_rdr(use_loaded_answers=True, save_answers=False)
|
@@ -1,27 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
|
3
|
-
from sqlalchemy.orm import Session
|
4
|
-
from typing_extensions import Type, Optional
|
5
|
-
|
6
|
-
from ripple_down_rules.rdr import RippleDownRules
|
7
|
-
from ripple_down_rules.utils import get_func_rdr_model_path
|
8
|
-
|
9
|
-
|
10
|
-
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
11
|
-
session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
|
12
|
-
"""
|
13
|
-
Load the RDR model of the function if it exists, otherwise create a new one.
|
14
|
-
|
15
|
-
:param func: The function to load the model for.
|
16
|
-
:param model_dir: The directory where the model is stored.
|
17
|
-
:param rdr_type: The type of the RDR model to load.
|
18
|
-
:param session: The SQLAlchemy session to use.
|
19
|
-
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
20
|
-
"""
|
21
|
-
model_path = get_func_rdr_model_path(func, model_dir)
|
22
|
-
if os.path.exists(model_path):
|
23
|
-
rdr = rdr_type.load(model_path)
|
24
|
-
rdr.session = session
|
25
|
-
else:
|
26
|
-
rdr = rdr_type(session=session, **rdr_kwargs)
|
27
|
-
return rdr
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules/rdr_decorators.py
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.1.61 → ripple_down_rules-0.1.62}/src/ripple_down_rules.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|