ripple-down-rules 0.1.63__tar.gz → 0.1.65__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datastructures/dataclasses.py +3 -3
  4. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/experts.py +11 -110
  5. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/rdr.py +137 -232
  6. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/rules.py +3 -1
  7. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/utils.py +44 -0
  8. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  9. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_rdr.py +8 -23
  10. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_rdr_world.py +39 -33
  11. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/LICENSE +0 -0
  12. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/README.md +0 -0
  13. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/setup.cfg +0 -0
  14. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/__init__.py +0 -0
  15. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datasets.py +0 -0
  16. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  17. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
  18. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datastructures/case.py +0 -0
  19. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/datastructures/enums.py +0 -0
  20. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/failures.py +0 -0
  21. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/helpers.py +0 -0
  22. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/prompt.py +0 -0
  23. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules/rdr_decorators.py +0 -0
  24. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  25. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  26. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  27. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_json_serialization.py +0 -0
  28. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_on_mutagenic.py +0 -0
  29. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_rdr_alchemy.py +0 -0
  30. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_relational_rdr.py +0 -0
  31. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_relational_rdr_alchemy.py +0 -0
  32. {ripple_down_rules-0.1.63 → ripple_down_rules-0.1.65}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.63
3
+ Version: 0.1.65
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.63"
9
+ version = "0.1.65"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -133,7 +133,7 @@ class CaseQuery:
133
133
  if value is not None and not isinstance(value, (CallableExpression, str)):
134
134
  raise ValueError("The target must be a CallableExpression or a string.")
135
135
  self._target = value
136
- self._update_target_value()
136
+ self.update_target_value()
137
137
 
138
138
  @property
139
139
  def target_value(self) -> Any:
@@ -141,10 +141,10 @@ class CaseQuery:
141
141
  :return: The target value of the case query.
142
142
  """
143
143
  if self._target_value is None:
144
- self._update_target_value()
144
+ self.update_target_value()
145
145
  return self._target_value
146
146
 
147
- def _update_target_value(self):
147
+ def update_target_value(self):
148
148
  """
149
149
  Update the target value of the case query.
150
150
  """
@@ -32,10 +32,11 @@ class Expert(ABC):
32
32
  """
33
33
  A flag to indicate if the expert should use loaded answers or not.
34
34
  """
35
- known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
36
- """
37
- The known categories (i.e. Column types) to use.
38
- """
35
+
36
+ def __init__(self, use_loaded_answers: bool = False, append: bool = False):
37
+ self.all_expert_answers = []
38
+ self.use_loaded_answers = use_loaded_answers
39
+ self.append = append
39
40
 
40
41
  @abstractmethod
41
42
  def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
@@ -51,28 +52,6 @@ class Expert(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
55
- """
56
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
57
-
58
- :param case_query: The case query containing the case to classify.
59
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
- conclusion and conditions for the rule.
61
- """
62
- pass
63
-
64
- @abstractmethod
65
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
- current_conclusions: Any) -> bool:
67
- """
68
- Ask the expert if the conclusion is correct.
69
-
70
- :param case_query: The case query about which the expert should answer.
71
- :param conclusion: The conclusion to check.
72
- :param current_conclusions: The current conclusions for the case.
73
- """
74
- pass
75
-
76
55
  def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
77
56
  """
78
57
  Ask the expert to provide a relational conclusion for the case.
@@ -87,18 +66,14 @@ class Human(Expert):
87
66
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
88
67
  """
89
68
 
90
- def __init__(self, use_loaded_answers: bool = False):
91
- self.all_expert_answers = []
92
- self.use_loaded_answers = use_loaded_answers
93
-
94
- def save_answers(self, path: str, append: bool = False):
69
+ def save_answers(self, path: str):
95
70
  """
96
71
  Save the expert answers to a file.
97
72
 
98
73
  :param path: The path to save the answers to.
99
74
  :param append: A flag to indicate if the answers should be appended to the file or not.
100
75
  """
101
- if append:
76
+ if self.append:
102
77
  # read the file and append the new answers
103
78
  with open(path + '.json', "r") as f:
104
79
  all_answers = json.load(f)
@@ -126,24 +101,6 @@ class Human(Expert):
126
101
  last_evaluated_rule=last_evaluated_rule)
127
102
  return self._get_conditions(case_query)
128
103
 
129
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
- """
131
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
-
133
- :param case_query: The case query containing the case to classify.
134
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
- conclusion and conditions for the rule.
136
- """
137
- rules = []
138
- while True:
139
- conclusion = self.ask_for_conclusion(case_query)
140
- if conclusion is None:
141
- break
142
- conditions = self._get_conditions(case_query)
143
- rules.append({PromptFor.Conclusion: conclusion,
144
- PromptFor.Conditions: conditions})
145
- return rules
146
-
147
104
  def _get_conditions(self, case_query: CaseQuery) \
148
105
  -> CallableExpression:
149
106
  """
@@ -154,6 +111,8 @@ class Human(Expert):
154
111
  :return: The differentiating features as new rule conditions.
155
112
  """
156
113
  user_input = None
114
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
115
+ self.use_loaded_answers = False
157
116
  if self.use_loaded_answers:
158
117
  user_input = self.all_expert_answers.pop(0)
159
118
  if user_input:
@@ -173,6 +132,8 @@ class Human(Expert):
173
132
  :return: The conclusion for the case as a callable expression.
174
133
  """
175
134
  expression: Optional[CallableExpression] = None
135
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
136
+ self.use_loaded_answers = False
176
137
  if self.use_loaded_answers:
177
138
  expert_input = self.all_expert_answers.pop(0)
178
139
  if expert_input is not None:
@@ -184,63 +145,3 @@ class Human(Expert):
184
145
  self.all_expert_answers.append(expert_input)
185
146
  case_query.target = expression
186
147
  return expression
187
-
188
- def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
189
- """
190
- Get the category type from the known categories.
191
-
192
- :param cat_name: The name of the category.
193
- :return: The category type.
194
- """
195
- cat_name = cat_name.lower()
196
- self.known_categories = get_all_subclasses(
197
- CaseAttribute) if not self.known_categories else self.known_categories
198
- self.known_categories.update(CaseAttribute.registry)
199
- category_type = None
200
- if cat_name in self.known_categories:
201
- category_type = self.known_categories[cat_name]
202
- return category_type
203
-
204
- def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
205
- """
206
- Ask the expert if the new category can have multiple values.
207
-
208
- :param category_name: The name of the category to ask about.
209
- """
210
- question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
211
- return not self.ask_for_affirmation(question)
212
-
213
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
- current_conclusions: Any) -> bool:
215
- """
216
- Ask the expert if the conclusion is correct.
217
-
218
- :param case_query: The case query about which the expert should answer.
219
- :param conclusion: The conclusion to check.
220
- :param current_conclusions: The current conclusions for the case.
221
- """
222
- if not self.use_loaded_answers:
223
- print(f"Current conclusions: {current_conclusions}")
224
- return self.ask_for_affirmation(case_query,
225
- f"Is the conclusion {conclusion} correct for the case (True/False):")
226
-
227
- def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
228
- """
229
- Ask the expert a yes or no question.
230
-
231
- :param case_query: The case query about which the expert should answer.
232
- :param question: The question to ask the expert.
233
- :return: The answer to the question.
234
- """
235
- while True:
236
- if self.use_loaded_answers:
237
- answer = self.all_expert_answers.pop(0)
238
- else:
239
- _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
- answer = expression(case_query.case)
241
- if answer:
242
- self.all_expert_answers.append(True)
243
- return True
244
- else:
245
- self.all_expert_answers.append(False)
246
- return False
@@ -22,7 +22,7 @@ from .helpers import is_matching
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, is_conflicting
25
+ get_case_attribute_type, is_conflicting, update_case
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -49,32 +49,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
49
49
  self.start_rule = start_rule
50
50
  self.fig: Optional[plt.Figure] = None
51
51
 
52
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
53
- return self.classify(case)
54
-
55
- @abstractmethod
56
- def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
57
- """
58
- Classify a case.
59
-
60
- :param case: The case to classify.
61
- :return: The category that the case belongs to.
62
- """
63
- pass
64
-
65
- @abstractmethod
66
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
67
- -> Union[CaseAttribute, CallableExpression]:
68
- """
69
- Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
70
- comparing the case with the target category.
71
-
72
- :param case_query: The query containing the case to classify and the target category to compare the case with.
73
- :param expert: The expert to ask for differentiating features as new rule conditions.
74
- :return: The category that the case belongs to.
75
- """
76
- pass
77
-
78
52
  def fit(self, case_queries: List[CaseQuery],
79
53
  expert: Optional[Expert] = None,
80
54
  n_iter: int = None,
@@ -124,6 +98,63 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
98
  plt.ioff()
125
99
  plt.show()
126
100
 
101
+ def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
102
+ return self.classify(case)
103
+
104
+ @abstractmethod
105
+ def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
106
+ """
107
+ Classify a case.
108
+
109
+ :param case: The case to classify.
110
+ :return: The category that the case belongs to.
111
+ """
112
+ pass
113
+
114
+ def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
115
+ -> Union[CaseAttribute, CallableExpression]:
116
+ """
117
+ Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
118
+ incorrect by comparing the case with the target category.
119
+
120
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
121
+ :param expert: The expert to ask for differentiating features as new rule conditions.
122
+ :return: The category that the case belongs to.
123
+ """
124
+ if case_query is None:
125
+ raise ValueError("The case query cannot be None.")
126
+ if case_query.target is None:
127
+ expert.ask_for_conclusion(case_query)
128
+ if case_query.target is None:
129
+ return self.classify(case_query.case)
130
+
131
+ self.update_start_rule(case_query, expert)
132
+
133
+ return self._fit_case(case_query, expert=expert, **kwargs)
134
+
135
+ @abstractmethod
136
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
137
+ -> Union[CaseAttribute, CallableExpression]:
138
+ """
139
+ Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
140
+ comparing the case with the target category.
141
+
142
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
143
+ :param expert: The expert to ask for differentiating features as new rule conditions.
144
+ :return: The category that the case belongs to.
145
+ """
146
+ pass
147
+
148
+ @abstractmethod
149
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
150
+ """
151
+ Update the starting rule of the classifier.
152
+
153
+ :param case_query: The case query to update the starting rule with.
154
+ :param expert: The expert to ask for differentiating features as new rule conditions.
155
+ """
156
+ pass
157
+
127
158
  def update_figures(self):
128
159
  """
129
160
  Update the figures of the classifier.
@@ -217,13 +248,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
217
248
  for rule in [self.start_rule] + list(self.start_rule.descendants):
218
249
  if not rule.conditions:
219
250
  continue
220
- if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
221
- continue
222
- for k, v in rule.conditions.scope.items():
223
- new_imports = f"from {v.__module__} import {v.__name__}\n"
224
- if new_imports in imports:
251
+ for scope in [rule.conditions.scope, rule.conclusion.scope]:
252
+ if scope is None:
225
253
  continue
226
- imports += new_imports
254
+ for k, v in scope.items():
255
+ if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
256
+ continue
257
+ new_imports = f"from {v.__module__} import {v.__name__}\n"
258
+ if new_imports in imports:
259
+ continue
260
+ imports += new_imports
227
261
  return imports
228
262
 
229
263
  def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
@@ -318,7 +352,7 @@ class SingleClassRDR(RDRWithCodeWriter):
318
352
  super(SingleClassRDR, self).__init__(start_rule)
319
353
  self.default_conclusion: Optional[Any] = default_conclusion
320
354
 
321
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
355
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
322
356
  -> Union[CaseAttribute, CallableExpression, None]:
323
357
  """
324
358
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
@@ -328,38 +362,44 @@ class SingleClassRDR(RDRWithCodeWriter):
328
362
  :param expert: The expert to ask for differentiating features as new rule conditions.
329
363
  :return: The category that the case belongs to.
330
364
  """
331
- expert = expert if expert else Human()
332
365
  if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
333
366
  self.default_conclusion = case_query.default_value
334
- case = case_query.case
335
- target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
336
- if target is None:
337
- return self.classify(case)
338
- if not self.start_rule:
339
- conditions = expert.ask_for_conditions(case_query)
340
- self.start_rule = SingleClassRule(conditions, target, corner_case=case,
341
- conclusion_name=case_query.attribute_name)
342
367
 
343
368
  pred = self.evaluate(case_query.case)
344
- if pred.conclusion(case) != target(case):
345
- conditions = expert.ask_for_conditions(case_query, pred)
346
- pred.fit_rule(case_query.case, target, conditions=conditions)
369
+ if pred.conclusion(case_query.case) != case_query.target_value:
370
+ expert.ask_for_conditions(case_query, pred)
371
+ pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
347
372
 
348
373
  return self.classify(case_query.case)
349
374
 
350
- def classify(self, case: Case) -> Optional[Any]:
375
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
376
+ """
377
+ Update the starting rule of the classifier.
378
+
379
+ :param case_query: The case query to update the starting rule with.
380
+ :param expert: The expert to ask for differentiating features as new rule conditions.
381
+ """
382
+ if not self.start_rule:
383
+ expert.ask_for_conditions(case_query)
384
+ self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
385
+ conclusion_name=case_query.attribute_name)
386
+
387
+ def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
351
388
  """
352
389
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
390
+
391
+ :param case: The case to classify.
392
+ :param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
353
393
  """
354
394
  pred = self.evaluate(case)
355
- return pred.conclusion(case) if pred.fired else self.default_conclusion
395
+ return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
356
396
 
357
397
  def evaluate(self, case: Case) -> SingleClassRule:
358
398
  """
359
399
  Evaluate the starting rule on a case.
360
400
  """
361
- matched_rule = self.start_rule(case)
362
- return matched_rule if matched_rule else self.start_rule
401
+ matched_rule = self.start_rule(case) if self.start_rule is not None else None
402
+ return matched_rule if matched_rule is not None else self.start_rule
363
403
 
364
404
  def write_to_python_file(self, file_path: str, postfix: str = ""):
365
405
  super().write_to_python_file(file_path, postfix)
@@ -388,6 +428,12 @@ class SingleClassRDR(RDRWithCodeWriter):
388
428
  def conclusion_type_hint(self) -> str:
389
429
  return self.conclusion_type[0].__name__
390
430
 
431
+ @property
432
+ def conclusion_type(self) -> Tuple[Type]:
433
+ if self.default_conclusion is not None:
434
+ return (type(self.default_conclusion),)
435
+ return super().conclusion_type
436
+
391
437
  def _to_json(self) -> Dict[str, Any]:
392
438
  return {"start_rule": self.start_rule.to_json()}
393
439
 
@@ -419,13 +465,12 @@ class MultiClassRDR(RDRWithCodeWriter):
419
465
  The conditions of the stopping rule if needed.
420
466
  """
421
467
 
422
- def __init__(self, start_rule: Optional[Rule] = None,
468
+ def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
423
469
  mode: MCRDRMode = MCRDRMode.StopOnly):
424
470
  """
425
471
  :param start_rule: The starting rules for the classifier.
426
472
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
427
473
  """
428
- start_rule = MultiClassTopRule() if not start_rule else start_rule
429
474
  super(MultiClassRDR, self).__init__(start_rule)
430
475
  self.mode: MCRDRMode = mode
431
476
 
@@ -439,40 +484,28 @@ class MultiClassRDR(RDRWithCodeWriter):
439
484
  evaluated_rule = next_rule
440
485
  return make_set(self.conclusions)
441
486
 
442
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
443
- add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
487
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
488
+ , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
444
489
  """
445
490
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
446
491
  or missing by comparing the case with the target category if provided.
447
492
 
448
493
  :param case_query: The query containing the case to classify and the target category to compare the case with.
449
494
  :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
450
- :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
451
495
  :return: The conclusions that the case belongs to.
452
496
  """
453
- expert = expert if expert else Human()
454
- if case_query.target is None:
455
- expert.ask_for_conclusion(case_query)
456
- if case_query.target is None:
457
- return self.classify(case_query.case)
458
- self.update_start_rule(case_query, expert)
459
- self.expert_accepted_conclusions = []
460
- user_conclusions = []
461
497
  self.conclusions = []
462
498
  self.stop_rule_conditions = None
463
499
  evaluated_rule = self.start_rule
464
- target = case_query.target(case_query.case)
500
+ target = make_set(case_query.target_value)
465
501
  while evaluated_rule:
466
502
  next_rule = evaluated_rule(case_query.case)
467
503
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
468
- good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
469
- good_conclusions = make_set(good_conclusions)
470
504
 
471
505
  if evaluated_rule.fired:
472
- if target and not make_set(rule_conclusion).issubset(good_conclusions):
506
+ if not make_set(rule_conclusion).issubset(target):
473
507
  # Rule fired and conclusion is different from target
474
- self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
475
- len(user_conclusions) > 0)
508
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
476
509
  else:
477
510
  # Rule fired and target is correct or there is no target to compare
478
511
  self.add_conclusion(evaluated_rule, case_query.case)
@@ -483,14 +516,6 @@ class MultiClassRDR(RDRWithCodeWriter):
483
516
  self.add_rule_for_case(case_query, expert)
484
517
  # Have to check all rules again to make sure only this new rule fires
485
518
  next_rule = self.start_rule
486
- elif add_extra_conclusions:
487
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
488
- new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
489
- user_conclusions.extend(new_user_conclusions)
490
- if len(new_user_conclusions) > 0:
491
- next_rule = self.last_top_rule
492
- else:
493
- add_extra_conclusions = False
494
519
  evaluated_rule = next_rule
495
520
  return self.conclusions
496
521
 
@@ -529,12 +554,10 @@ class MultiClassRDR(RDRWithCodeWriter):
529
554
  :param case_query: The case query to update the starting rule with.
530
555
  :param expert: The expert to ask for differentiating features as new rule conditions.
531
556
  """
532
- if not self.start_rule.conditions:
557
+ if not self.start_rule:
533
558
  conditions = expert.ask_for_conditions(case_query)
534
- self.start_rule.conditions = conditions
535
- self.start_rule.conclusion = case_query.target
536
- self.start_rule.corner_case = case_query.case
537
- self.start_rule.conclusion_name = case_query.attribute_name
559
+ self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
560
+ conclusion_name=case_query.attribute_name)
538
561
 
539
562
  @property
540
563
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -547,17 +570,13 @@ class MultiClassRDR(RDRWithCodeWriter):
547
570
  return self.start_rule.furthest_alternative[-1]
548
571
 
549
572
  def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
550
- evaluated_rule: MultiClassTopRule,
551
- add_extra_conclusions: bool):
573
+ evaluated_rule: MultiClassTopRule):
552
574
  """
553
575
  Stop a wrong conclusion by adding a stopping rule.
554
576
  """
555
577
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
556
578
  if is_conflicting(rule_conclusion, case_query.target_value):
557
- if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
558
- return
559
- else:
560
- self.stop_conclusion(case_query, expert, evaluated_rule)
579
+ self.stop_conclusion(case_query, expert, evaluated_rule)
561
580
 
562
581
  def stop_conclusion(self, case_query: CaseQuery,
563
582
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -576,27 +595,6 @@ class MultiClassRDR(RDRWithCodeWriter):
576
595
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
577
596
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
578
597
 
579
- def conclusion_is_correct(self, case_query: CaseQuery,
580
- expert: Expert, evaluated_rule: Rule,
581
- add_extra_conclusions: bool) -> bool:
582
- """
583
- Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
584
-
585
- :param case_query: The case query to ask the expert about.
586
- :param expert: The expert to ask for differentiating features as new rule conditions.
587
- :param evaluated_rule: The evaluated rule to ask the expert about.
588
- :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
589
- :return: Whether the conclusion is correct or not.
590
- """
591
- conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
592
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
593
- evaluated_rule.conclusion(case_query.case),
594
- current_conclusions=conclusions)):
595
- self.add_conclusion(evaluated_rule, case_query.case)
596
- self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
597
- return True
598
- return False
599
-
600
598
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
601
599
  """
602
600
  Add a rule for a case that has not been classified with any conclusion.
@@ -611,24 +609,6 @@ class MultiClassRDR(RDRWithCodeWriter):
611
609
  conditions = expert.ask_for_conditions(case_query)
612
610
  self.add_top_rule(conditions, case_query.target, case_query.case)
613
611
 
614
- def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
615
- """
616
- Ask the expert for extra rules when no more conclusions can be made for a case.
617
-
618
- :param expert: The expert to ask for extra conclusions.
619
- :param case_query: The case query to ask the expert about.
620
- :return: The extra conclusions for the rules that the expert has provided.
621
- """
622
- extra_conclusions = []
623
- conclusions = list(OrderedSet(self.conclusions))
624
- if not expert.use_loaded_answers:
625
- print("current conclusions:", conclusions)
626
- extra_rules = expert.ask_for_extra_rules(case_query)
627
- for rule in extra_rules:
628
- self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
629
- extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
630
- return extra_conclusions
631
-
632
612
  def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
633
613
  """
634
614
  Add the conclusion of the evaluated rule to the list of conclusions.
@@ -722,30 +702,32 @@ class GeneralRDR(RippleDownRules):
722
702
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
723
703
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
724
704
 
725
- def classify(self, case: Any) -> Optional[Dict[str, Any]]:
705
+ def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
726
706
  """
727
707
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
728
708
  the classification until no more categories can be added.
729
709
 
730
710
  :param case: The case to classify.
711
+ :param modify_case: Whether to modify the original case or create a copy and modify it.
731
712
  :return: The categories that the case belongs to.
732
713
  """
733
- return self._classify(self.start_rules_dict, case)
714
+ return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
734
715
 
735
716
  @staticmethod
736
717
  def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
737
- case: Any) -> Dict[str, Any]:
718
+ case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
738
719
  """
739
720
  Classify a case by going through all classifiers and adding the categories that are classified,
740
721
  and then restarting the classification until no more categories can be added.
741
722
 
742
723
  :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
743
724
  :param case: The case to classify.
725
+ :param modify_original_case: Whether to modify the original case or create a copy and modify it.
744
726
  :return: The categories that the case belongs to.
745
727
  """
746
728
  conclusions = {}
747
729
  case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
748
- case_cp = copy_case(case)
730
+ case_cp = copy_case(case) if not modify_original_case else case
749
731
  while True:
750
732
  new_conclusions = {}
751
733
  for attribute_name, rdr in classifiers_dict.items():
@@ -768,13 +750,13 @@ class GeneralRDR(RippleDownRules):
768
750
  conclusions[attribute_name].update(pred_atts)
769
751
  if attribute_name in new_conclusions:
770
752
  mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
771
- GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
772
- new_conclusions)
753
+ case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
754
+ update_case(case_query, new_conclusions)
773
755
  if len(new_conclusions) == 0:
774
756
  break
775
757
  return conclusions
776
758
 
777
- def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
759
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
778
760
  -> Dict[str, Any]:
779
761
  """
780
762
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
@@ -784,115 +766,38 @@ class GeneralRDR(RippleDownRules):
784
766
  they are accepted by the expert, and the attribute of that category is represented in the case as a set of
785
767
  values.
786
768
 
787
- :param case_queries: The queries containing the case to classify and the target categories to compare the case
769
+ :param case_query: The query containing the case to classify and the target category to compare the case
788
770
  with.
789
771
  :param expert: The expert to ask for differentiating features as new rule conditions.
790
772
  :return: The categories that the case belongs to.
791
773
  """
792
- expert = expert if expert else Human()
793
- case_queries = make_list(case_queries)
794
- assert len(case_queries) > 0, "No case queries provided"
795
- case = case_queries[0].case
796
- assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
797
- " for multiple cases use fit instead")
798
- original_case_query_cp = copy(case_queries[0])
799
- for case_query in case_queries:
800
- case_query_cp = copy(case_query)
801
- case_query_cp.case = original_case_query_cp.case
802
- if case_query_cp.target is None:
803
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
804
- self.update_case(case_query_cp, conclusions)
805
- expert.ask_for_conclusion(case_query_cp)
806
- if case_query_cp.target is None:
807
- continue
808
- case_query.target = case_query_cp.target
809
-
810
- if case_query.attribute_name not in self.start_rules_dict:
811
- conclusions = self.classify(case)
812
- self.update_case(case_query_cp, conclusions)
813
-
814
- new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
815
- self.add_rdr(new_rdr, case_query.attribute_name)
816
-
817
- new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
818
- self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
819
- else:
820
- for rdr_attribute_name, rdr in self.start_rules_dict.items():
821
- if case_query.attribute_name != rdr_attribute_name:
822
- conclusions = rdr.classify(case_query_cp.case)
823
- else:
824
- conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
825
- **kwargs)
826
- if conclusions is not None:
827
- if (not is_iterable(conclusions)) or len(conclusions) > 0:
828
- conclusions = {rdr_attribute_name: conclusions}
829
- case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
830
- self.update_case(case_query_cp, conclusions)
831
- case_query.conditions = case_query_cp.conditions
832
774
 
833
- return self.classify(case)
775
+ case_query_cp = copy(case_query)
776
+ self.classify(case_query_cp.case, modify_case=True)
777
+ case_query_cp.update_target_value()
778
+
779
+ self.start_rules_dict[case_query_cp.attribute_name].fit_case(case_query_cp, expert, **kwargs)
780
+
781
+ return self.classify(case_query.case)
782
+
783
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
784
+ """
785
+ Update the starting rule of the classifier.
786
+
787
+ :param case_query: The case query to update the starting rule with.
788
+ :param expert: The expert to ask for differentiating features as new rule conditions.
789
+ """
790
+ if case_query.attribute_name not in self.start_rules_dict:
791
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query)
792
+ self.add_rdr(new_rdr, case_query.attribute_name)
834
793
 
835
794
  @staticmethod
836
795
  def initialize_new_rdr_for_attribute(case_query: CaseQuery):
837
796
  """
838
797
  Initialize the appropriate RDR type for the target.
839
798
  """
840
- if case_query.mutually_exclusive is not None:
841
- return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
842
- else MultiClassRDR()
843
- if case_query.attribute_type in [list, set]:
844
- return MultiClassRDR()
845
- attribute = getattr(case_query.case, case_query.attribute_name) \
846
- if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
847
- if isinstance(attribute, CaseAttribute):
848
- return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
849
- else MultiClassRDR()
850
- else:
851
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
852
- else SingleClassRDR(default_conclusion=case_query.default_value)
853
-
854
- @staticmethod
855
- def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
856
- """
857
- Update the case with the conclusions.
858
-
859
- :param case_query: The case query that contains the case to update.
860
- :param conclusions: The conclusions to update the case with.
861
- """
862
- if not conclusions:
863
- return
864
- if len(conclusions) == 0:
865
- return
866
- if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
867
- for conclusion_name, conclusion in conclusions.items():
868
- attribute = getattr(case_query.case, conclusion_name)
869
- if conclusion_name == case_query.attribute_name:
870
- attribute_type = case_query.attribute_type
871
- else:
872
- attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
873
- if isinstance(attribute, set):
874
- for c in conclusion:
875
- attribute.update(make_set(c))
876
- elif isinstance(attribute, list):
877
- attribute.extend(conclusion)
878
- elif any(at in {List, list} for at in attribute_type):
879
- attribute = [] if attribute is None else attribute
880
- attribute.extend(conclusion)
881
- elif any(at in {Set, set} for at in attribute_type):
882
- attribute = set() if attribute is None else attribute
883
- for c in conclusion:
884
- attribute.update(make_set(c))
885
- elif is_iterable(conclusion) and len(conclusion) == 1 \
886
- and any(at is type(list(conclusion)[0]) for at in attribute_type):
887
- setattr(case_query.case, conclusion_name, list(conclusion)[0])
888
- elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
889
- setattr(case_query.case, conclusion_name, conclusion)
890
- else:
891
- raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
892
- f"{case_query.attribute_type} with conclusion "
893
- f"{conclusion} of type {type(conclusion)}")
894
- else:
895
- case_query.case.update(conclusions)
799
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
800
+ else MultiClassRDR()
896
801
 
897
802
  def _to_json(self) -> Dict[str, Any]:
898
803
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -130,6 +130,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
130
130
  "conclusion": conclusion_to_json(self.conclusion),
131
131
  "parent": self.parent.json_serialization if self.parent else None,
132
132
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
133
+ "conclusion_name": self.conclusion_name,
133
134
  "weight": self.weight}
134
135
  return json_serialization
135
136
 
@@ -139,6 +140,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
139
140
  conclusion=CallableExpression.from_json(data["conclusion"]),
140
141
  parent=cls.from_json(data["parent"]),
141
142
  corner_case=Case.from_json(data["corner_case"]),
143
+ conclusion_name=data["conclusion_name"],
142
144
  weight=data["weight"])
143
145
  return loaded_rule
144
146
 
@@ -288,7 +290,7 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
288
290
 
289
291
  def __init__(self, *args, **kwargs):
290
292
  super(MultiClassStopRule, self).__init__(*args, **kwargs)
291
- self.conclusion = Stop.stop
293
+ self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
292
294
 
293
295
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
294
296
  if self.fired:
@@ -26,6 +26,7 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from .datastructures.case import Case
29
+ from .datastructures.dataclasses import CaseQuery
29
30
  from .rules import Rule
30
31
 
31
32
  import ast
@@ -33,6 +34,49 @@ import ast
33
34
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
35
 
35
36
 
37
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
38
+ """
39
+ Update the case with the conclusions.
40
+
41
+ :param case_query: The case query that contains the case to update.
42
+ :param conclusions: The conclusions to update the case with.
43
+ """
44
+ if not conclusions:
45
+ return
46
+ if len(conclusions) == 0:
47
+ return
48
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
49
+ for conclusion_name, conclusion in conclusions.items():
50
+ attribute = getattr(case_query.case, conclusion_name)
51
+ if conclusion_name == case_query.attribute_name:
52
+ attribute_type = case_query.attribute_type
53
+ else:
54
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
55
+ if isinstance(attribute, set):
56
+ for c in conclusion:
57
+ attribute.update(make_set(c))
58
+ elif isinstance(attribute, list):
59
+ attribute.extend(conclusion)
60
+ elif any(at in {List, list} for at in attribute_type):
61
+ attribute = [] if attribute is None else attribute
62
+ attribute.extend(conclusion)
63
+ elif any(at in {Set, set} for at in attribute_type):
64
+ attribute = set() if attribute is None else attribute
65
+ for c in conclusion:
66
+ attribute.update(make_set(c))
67
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
68
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
69
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
70
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
71
+ setattr(case_query.case, conclusion_name, conclusion)
72
+ else:
73
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
74
+ f"{case_query.attribute_type} with conclusion "
75
+ f"{conclusion} of type {type(conclusion)}")
76
+ else:
77
+ case_query.case.update(conclusions)
78
+
79
+
36
80
  def is_conflicting(conclusion: Any, target: Any) -> bool:
37
81
  """
38
82
  :param conclusion: The conclusion to check.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.63
3
+ Version: 0.1.65
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -245,20 +245,12 @@ class TestRDR(TestCase):
245
245
  save_answers = False
246
246
  append = False
247
247
  filename = self.expert_answers_dir + "/mcrdr_stop_plus_rule_expert_answers_fit"
248
- expert = Human(use_loaded_answers=use_loaded_answers)
248
+ expert = Human(use_loaded_answers=use_loaded_answers, append=append)
249
249
  if use_loaded_answers:
250
250
  expert.load_answers(filename)
251
251
  mcrdr = MultiClassRDR(mode=MCRDRMode.StopPlusRule)
252
252
  case_queries = self.case_queries
253
- try:
254
- mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
255
- # catch pop from empty list error
256
- except IndexError as e:
257
- if append:
258
- expert.use_loaded_answers = False
259
- mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
260
- else:
261
- raise e
253
+ mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
262
254
  render_tree(mcrdr.start_rule, use_dot_exporter=True,
263
255
  filename=self.test_results_dir + f"/mcrdr_stop_plus_rule")
264
256
  for case_query in case_queries:
@@ -267,7 +259,7 @@ class TestRDR(TestCase):
267
259
  if save_answers:
268
260
  cwd = os.getcwd()
269
261
  file = os.path.join(cwd, filename)
270
- expert.save_answers(file, append=append)
262
+ expert.save_answers(file)
271
263
 
272
264
  def test_fit_mcrdr_stop_plus_rule_combined(self):
273
265
  use_loaded_answers = True
@@ -275,20 +267,12 @@ class TestRDR(TestCase):
275
267
  draw_tree = False
276
268
  append = False
277
269
  filename = self.expert_answers_dir + "/mcrdr_stop_plus_rule_combined_expert_answers_fit"
278
- expert = Human(use_loaded_answers=use_loaded_answers)
270
+ expert = Human(use_loaded_answers=use_loaded_answers, append=append)
279
271
  if use_loaded_answers:
280
272
  expert.load_answers(filename)
281
273
  mcrdr = MultiClassRDR(mode=MCRDRMode.StopPlusRuleCombined)
282
274
  case_queries = self.case_queries
283
- try:
284
- mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
285
- # catch pop from empty list error
286
- except IndexError as e:
287
- if append:
288
- expert.use_loaded_answers = False
289
- mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
290
- else:
291
- raise e
275
+ mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
292
276
  render_tree(mcrdr.start_rule, use_dot_exporter=True,
293
277
  filename=self.test_results_dir + f"/mcrdr_stop_plus_rule_combined")
294
278
  for case_query in case_queries:
@@ -297,7 +281,7 @@ class TestRDR(TestCase):
297
281
  if save_answers:
298
282
  cwd = os.getcwd()
299
283
  file = os.path.join(cwd, filename)
300
- expert.save_answers(file, append=append)
284
+ expert.save_answers(file)
301
285
 
302
286
  def test_classify_grdr(self):
303
287
  use_loaded_answers = True
@@ -314,7 +298,8 @@ class TestRDR(TestCase):
314
298
  targets = dict(zip(attribute_names, targets))
315
299
  case_queries = [CaseQuery(self.all_cases[0], a, (type(t),), True if a == 'species' else False,
316
300
  _target=t) for a, t in targets.items()]
317
- cats = grdr.fit_case(case_queries, expert=expert)
301
+ grdr.fit(case_queries, expert=expert)
302
+ cats = grdr.classify(self.all_cases[0])
318
303
  for cat_name, value in cats.items():
319
304
  self.assertEqual(make_set(value), make_set(targets[cat_name]))
320
305
 
@@ -13,66 +13,64 @@ from ripple_down_rules.rdr import GeneralRDR
13
13
 
14
14
  @dataclass
15
15
  class WorldEntity:
16
- world: World = field(kw_only=True, repr=False)
16
+ world: Optional[World] = field(default=None, kw_only=True, repr=False, hash=False)
17
17
 
18
18
 
19
- @dataclass
19
+ @dataclass(unsafe_hash=True)
20
20
  class Body(WorldEntity):
21
21
  name: str
22
22
 
23
- def __hash__(self):
24
- return hash(self.name)
25
-
26
23
 
27
- @dataclass
24
+ @dataclass(unsafe_hash=True)
28
25
  class Handle(Body):
29
26
  ...
30
27
 
31
28
 
32
- @dataclass
29
+ @dataclass(unsafe_hash=True)
33
30
  class Container(Body):
34
31
  ...
35
32
 
36
33
 
37
- @dataclass
34
+ @dataclass(unsafe_hash=True)
38
35
  class Connection(WorldEntity):
39
36
  parent: Body
40
37
  child: Body
41
38
 
42
39
 
43
- @dataclass
40
+ @dataclass(unsafe_hash=True)
44
41
  class FixedConnection(Connection):
45
42
  ...
46
43
 
47
44
 
48
- @dataclass
45
+ @dataclass(unsafe_hash=True)
49
46
  class PrismaticConnection(Connection):
50
47
  ...
51
48
 
52
49
 
53
50
  @dataclass
54
51
  class World:
52
+ id: int = 0
55
53
  bodies: List[Body] = field(default_factory=list)
56
54
  connections: List[Connection] = field(default_factory=list)
57
55
  views: List[View] = field(default_factory=list, repr=False)
58
56
 
57
+ def __eq__(self, other):
58
+ if not isinstance(other, World):
59
+ return False
60
+ return self.id == other.id
59
61
 
60
- @dataclass
62
+
63
+ @dataclass(unsafe_hash=True)
61
64
  class View(WorldEntity):
62
- def __init__(self, *args, **kwargs):
63
- super().__init__(*args, **kwargs)
64
- self.world.views.append(self)
65
+ ...
65
66
 
66
67
 
67
- @dataclass
68
+ @dataclass(unsafe_hash=True)
68
69
  class Drawer(View):
69
70
  handle: Handle
70
71
  container: Container
71
72
  correct: Optional[bool] = None
72
73
 
73
- def __hash__(self):
74
- return hash((self.handle.name, self.container.name))
75
-
76
74
 
77
75
  @dataclass
78
76
  class Cabinet(View):
@@ -80,7 +78,7 @@ class Cabinet(View):
80
78
  drawers: List[Drawer] = field(default_factory=list)
81
79
 
82
80
  def __hash__(self):
83
- return hash(tuple([self.container.name] + [hash(drawer) for drawer in self.drawers]))
81
+ return hash((self.__class__.__name__, self.container))
84
82
 
85
83
 
86
84
  class TestRDRWorld(TestCase):
@@ -115,26 +113,34 @@ class TestRDRWorld(TestCase):
115
113
  def test_view_rdr(self):
116
114
  self.get_view_rdr(use_loaded_answers=True, save_answers=False, append=False)
117
115
 
118
- def get_view_rdr(self, use_loaded_answers: bool = True, save_answers: bool = False,
116
+ def test_write_view_rdr_to_python_file(self):
117
+ rdrs_dir = "./test_generated_rdrs"
118
+ view_rdr = self.get_view_rdr()
119
+ view_rdr.write_to_python_file(rdrs_dir)
120
+ loaded_rdr_classifier = view_rdr.get_rdr_classifier_from_python_file(rdrs_dir)
121
+ found_views = loaded_rdr_classifier(self.world)
122
+ self.assertTrue(len([v for v in found_views["views"] if isinstance(v, Drawer)]) == 1)
123
+ self.assertTrue(len([v for v in found_views["views"] if isinstance(v, Cabinet)]) == 1)
124
+ self.assertTrue(len(found_views["views"]) == 2)
125
+
126
+ def get_view_rdr(self, views=(Drawer, Cabinet), use_loaded_answers: bool = True, save_answers: bool = False,
119
127
  append: bool = False):
120
- expert = Human(use_loaded_answers=use_loaded_answers)
128
+ expert = Human(use_loaded_answers=use_loaded_answers, append=append)
121
129
  filename = os.path.join(os.getcwd(), "test_expert_answers/view_rdr_expert_answers_fit")
122
130
  if use_loaded_answers:
123
131
  expert.load_answers(filename)
124
132
  rdr = GeneralRDR()
125
- try:
126
- rdr.fit_case([CaseQuery(self.world, "views", (View,), False)], expert=expert,
127
- add_extra_conclusions=True)
128
- except Exception as e:
129
- if append:
130
- expert.use_loaded_answers = False
131
- rdr.fit_case([CaseQuery(self.world, "views", (View,), False)], expert=expert,
132
- add_extra_conclusions=True)
133
- else:
134
- raise e
133
+ for view in views:
134
+ rdr.fit_case(CaseQuery(self.world, "views", (view,), False), expert=expert)
135
135
  if save_answers:
136
- expert.save_answers(filename, append=append)
137
- print(rdr.classify(self.world))
136
+ expert.save_answers(filename)
137
+
138
+ found_views = rdr.classify(self.world)
139
+ print(found_views)
140
+ for view in views:
141
+ self.assertTrue(len([v for v in found_views["views"] if isinstance(v, view)]) > 0)
142
+
143
+ return rdr
138
144
 
139
145
  def test_drawer_rdr(self):
140
146
  self.get_drawer_rdr(use_loaded_answers=True, save_answers=False)