ripple-down-rules 0.1.63__py3-none-any.whl → 0.1.65__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -133,7 +133,7 @@ class CaseQuery:
133
133
  if value is not None and not isinstance(value, (CallableExpression, str)):
134
134
  raise ValueError("The target must be a CallableExpression or a string.")
135
135
  self._target = value
136
- self._update_target_value()
136
+ self.update_target_value()
137
137
 
138
138
  @property
139
139
  def target_value(self) -> Any:
@@ -141,10 +141,10 @@ class CaseQuery:
141
141
  :return: The target value of the case query.
142
142
  """
143
143
  if self._target_value is None:
144
- self._update_target_value()
144
+ self.update_target_value()
145
145
  return self._target_value
146
146
 
147
- def _update_target_value(self):
147
+ def update_target_value(self):
148
148
  """
149
149
  Update the target value of the case query.
150
150
  """
@@ -32,10 +32,11 @@ class Expert(ABC):
32
32
  """
33
33
  A flag to indicate if the expert should use loaded answers or not.
34
34
  """
35
- known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
36
- """
37
- The known categories (i.e. Column types) to use.
38
- """
35
+
36
+ def __init__(self, use_loaded_answers: bool = False, append: bool = False):
37
+ self.all_expert_answers = []
38
+ self.use_loaded_answers = use_loaded_answers
39
+ self.append = append
39
40
 
40
41
  @abstractmethod
41
42
  def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
@@ -51,28 +52,6 @@ class Expert(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
55
- """
56
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
57
-
58
- :param case_query: The case query containing the case to classify.
59
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
- conclusion and conditions for the rule.
61
- """
62
- pass
63
-
64
- @abstractmethod
65
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
- current_conclusions: Any) -> bool:
67
- """
68
- Ask the expert if the conclusion is correct.
69
-
70
- :param case_query: The case query about which the expert should answer.
71
- :param conclusion: The conclusion to check.
72
- :param current_conclusions: The current conclusions for the case.
73
- """
74
- pass
75
-
76
55
  def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
77
56
  """
78
57
  Ask the expert to provide a relational conclusion for the case.
@@ -87,18 +66,14 @@ class Human(Expert):
87
66
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
88
67
  """
89
68
 
90
- def __init__(self, use_loaded_answers: bool = False):
91
- self.all_expert_answers = []
92
- self.use_loaded_answers = use_loaded_answers
93
-
94
- def save_answers(self, path: str, append: bool = False):
69
+ def save_answers(self, path: str):
95
70
  """
96
71
  Save the expert answers to a file.
97
72
 
98
73
  :param path: The path to save the answers to.
99
74
  :param append: A flag to indicate if the answers should be appended to the file or not.
100
75
  """
101
- if append:
76
+ if self.append:
102
77
  # read the file and append the new answers
103
78
  with open(path + '.json', "r") as f:
104
79
  all_answers = json.load(f)
@@ -126,24 +101,6 @@ class Human(Expert):
126
101
  last_evaluated_rule=last_evaluated_rule)
127
102
  return self._get_conditions(case_query)
128
103
 
129
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
- """
131
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
-
133
- :param case_query: The case query containing the case to classify.
134
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
- conclusion and conditions for the rule.
136
- """
137
- rules = []
138
- while True:
139
- conclusion = self.ask_for_conclusion(case_query)
140
- if conclusion is None:
141
- break
142
- conditions = self._get_conditions(case_query)
143
- rules.append({PromptFor.Conclusion: conclusion,
144
- PromptFor.Conditions: conditions})
145
- return rules
146
-
147
104
  def _get_conditions(self, case_query: CaseQuery) \
148
105
  -> CallableExpression:
149
106
  """
@@ -154,6 +111,8 @@ class Human(Expert):
154
111
  :return: The differentiating features as new rule conditions.
155
112
  """
156
113
  user_input = None
114
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
115
+ self.use_loaded_answers = False
157
116
  if self.use_loaded_answers:
158
117
  user_input = self.all_expert_answers.pop(0)
159
118
  if user_input:
@@ -173,6 +132,8 @@ class Human(Expert):
173
132
  :return: The conclusion for the case as a callable expression.
174
133
  """
175
134
  expression: Optional[CallableExpression] = None
135
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
136
+ self.use_loaded_answers = False
176
137
  if self.use_loaded_answers:
177
138
  expert_input = self.all_expert_answers.pop(0)
178
139
  if expert_input is not None:
@@ -184,63 +145,3 @@ class Human(Expert):
184
145
  self.all_expert_answers.append(expert_input)
185
146
  case_query.target = expression
186
147
  return expression
187
-
188
- def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
189
- """
190
- Get the category type from the known categories.
191
-
192
- :param cat_name: The name of the category.
193
- :return: The category type.
194
- """
195
- cat_name = cat_name.lower()
196
- self.known_categories = get_all_subclasses(
197
- CaseAttribute) if not self.known_categories else self.known_categories
198
- self.known_categories.update(CaseAttribute.registry)
199
- category_type = None
200
- if cat_name in self.known_categories:
201
- category_type = self.known_categories[cat_name]
202
- return category_type
203
-
204
- def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
205
- """
206
- Ask the expert if the new category can have multiple values.
207
-
208
- :param category_name: The name of the category to ask about.
209
- """
210
- question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
211
- return not self.ask_for_affirmation(question)
212
-
213
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
- current_conclusions: Any) -> bool:
215
- """
216
- Ask the expert if the conclusion is correct.
217
-
218
- :param case_query: The case query about which the expert should answer.
219
- :param conclusion: The conclusion to check.
220
- :param current_conclusions: The current conclusions for the case.
221
- """
222
- if not self.use_loaded_answers:
223
- print(f"Current conclusions: {current_conclusions}")
224
- return self.ask_for_affirmation(case_query,
225
- f"Is the conclusion {conclusion} correct for the case (True/False):")
226
-
227
- def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
228
- """
229
- Ask the expert a yes or no question.
230
-
231
- :param case_query: The case query about which the expert should answer.
232
- :param question: The question to ask the expert.
233
- :return: The answer to the question.
234
- """
235
- while True:
236
- if self.use_loaded_answers:
237
- answer = self.all_expert_answers.pop(0)
238
- else:
239
- _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
- answer = expression(case_query.case)
241
- if answer:
242
- self.all_expert_answers.append(True)
243
- return True
244
- else:
245
- self.all_expert_answers.append(False)
246
- return False
ripple_down_rules/rdr.py CHANGED
@@ -22,7 +22,7 @@ from .helpers import is_matching
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, is_conflicting
25
+ get_case_attribute_type, is_conflicting, update_case
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -49,32 +49,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
49
49
  self.start_rule = start_rule
50
50
  self.fig: Optional[plt.Figure] = None
51
51
 
52
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
53
- return self.classify(case)
54
-
55
- @abstractmethod
56
- def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
57
- """
58
- Classify a case.
59
-
60
- :param case: The case to classify.
61
- :return: The category that the case belongs to.
62
- """
63
- pass
64
-
65
- @abstractmethod
66
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
67
- -> Union[CaseAttribute, CallableExpression]:
68
- """
69
- Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
70
- comparing the case with the target category.
71
-
72
- :param case_query: The query containing the case to classify and the target category to compare the case with.
73
- :param expert: The expert to ask for differentiating features as new rule conditions.
74
- :return: The category that the case belongs to.
75
- """
76
- pass
77
-
78
52
  def fit(self, case_queries: List[CaseQuery],
79
53
  expert: Optional[Expert] = None,
80
54
  n_iter: int = None,
@@ -124,6 +98,63 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
98
  plt.ioff()
125
99
  plt.show()
126
100
 
101
+ def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
102
+ return self.classify(case)
103
+
104
+ @abstractmethod
105
+ def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
106
+ """
107
+ Classify a case.
108
+
109
+ :param case: The case to classify.
110
+ :return: The category that the case belongs to.
111
+ """
112
+ pass
113
+
114
+ def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
115
+ -> Union[CaseAttribute, CallableExpression]:
116
+ """
117
+ Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
118
+ incorrect by comparing the case with the target category.
119
+
120
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
121
+ :param expert: The expert to ask for differentiating features as new rule conditions.
122
+ :return: The category that the case belongs to.
123
+ """
124
+ if case_query is None:
125
+ raise ValueError("The case query cannot be None.")
126
+ if case_query.target is None:
127
+ expert.ask_for_conclusion(case_query)
128
+ if case_query.target is None:
129
+ return self.classify(case_query.case)
130
+
131
+ self.update_start_rule(case_query, expert)
132
+
133
+ return self._fit_case(case_query, expert=expert, **kwargs)
134
+
135
+ @abstractmethod
136
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
137
+ -> Union[CaseAttribute, CallableExpression]:
138
+ """
139
+ Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
140
+ comparing the case with the target category.
141
+
142
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
143
+ :param expert: The expert to ask for differentiating features as new rule conditions.
144
+ :return: The category that the case belongs to.
145
+ """
146
+ pass
147
+
148
+ @abstractmethod
149
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
150
+ """
151
+ Update the starting rule of the classifier.
152
+
153
+ :param case_query: The case query to update the starting rule with.
154
+ :param expert: The expert to ask for differentiating features as new rule conditions.
155
+ """
156
+ pass
157
+
127
158
  def update_figures(self):
128
159
  """
129
160
  Update the figures of the classifier.
@@ -217,13 +248,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
217
248
  for rule in [self.start_rule] + list(self.start_rule.descendants):
218
249
  if not rule.conditions:
219
250
  continue
220
- if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
221
- continue
222
- for k, v in rule.conditions.scope.items():
223
- new_imports = f"from {v.__module__} import {v.__name__}\n"
224
- if new_imports in imports:
251
+ for scope in [rule.conditions.scope, rule.conclusion.scope]:
252
+ if scope is None:
225
253
  continue
226
- imports += new_imports
254
+ for k, v in scope.items():
255
+ if not hasattr(v, "__module__") or not hasattr(v, "__name__"):
256
+ continue
257
+ new_imports = f"from {v.__module__} import {v.__name__}\n"
258
+ if new_imports in imports:
259
+ continue
260
+ imports += new_imports
227
261
  return imports
228
262
 
229
263
  def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
@@ -318,7 +352,7 @@ class SingleClassRDR(RDRWithCodeWriter):
318
352
  super(SingleClassRDR, self).__init__(start_rule)
319
353
  self.default_conclusion: Optional[Any] = default_conclusion
320
354
 
321
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
355
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
322
356
  -> Union[CaseAttribute, CallableExpression, None]:
323
357
  """
324
358
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
@@ -328,38 +362,44 @@ class SingleClassRDR(RDRWithCodeWriter):
328
362
  :param expert: The expert to ask for differentiating features as new rule conditions.
329
363
  :return: The category that the case belongs to.
330
364
  """
331
- expert = expert if expert else Human()
332
365
  if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
333
366
  self.default_conclusion = case_query.default_value
334
- case = case_query.case
335
- target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
336
- if target is None:
337
- return self.classify(case)
338
- if not self.start_rule:
339
- conditions = expert.ask_for_conditions(case_query)
340
- self.start_rule = SingleClassRule(conditions, target, corner_case=case,
341
- conclusion_name=case_query.attribute_name)
342
367
 
343
368
  pred = self.evaluate(case_query.case)
344
- if pred.conclusion(case) != target(case):
345
- conditions = expert.ask_for_conditions(case_query, pred)
346
- pred.fit_rule(case_query.case, target, conditions=conditions)
369
+ if pred.conclusion(case_query.case) != case_query.target_value:
370
+ expert.ask_for_conditions(case_query, pred)
371
+ pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
347
372
 
348
373
  return self.classify(case_query.case)
349
374
 
350
- def classify(self, case: Case) -> Optional[Any]:
375
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
376
+ """
377
+ Update the starting rule of the classifier.
378
+
379
+ :param case_query: The case query to update the starting rule with.
380
+ :param expert: The expert to ask for differentiating features as new rule conditions.
381
+ """
382
+ if not self.start_rule:
383
+ expert.ask_for_conditions(case_query)
384
+ self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
385
+ conclusion_name=case_query.attribute_name)
386
+
387
+ def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
351
388
  """
352
389
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
390
+
391
+ :param case: The case to classify.
392
+ :param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
353
393
  """
354
394
  pred = self.evaluate(case)
355
- return pred.conclusion(case) if pred.fired else self.default_conclusion
395
+ return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
356
396
 
357
397
  def evaluate(self, case: Case) -> SingleClassRule:
358
398
  """
359
399
  Evaluate the starting rule on a case.
360
400
  """
361
- matched_rule = self.start_rule(case)
362
- return matched_rule if matched_rule else self.start_rule
401
+ matched_rule = self.start_rule(case) if self.start_rule is not None else None
402
+ return matched_rule if matched_rule is not None else self.start_rule
363
403
 
364
404
  def write_to_python_file(self, file_path: str, postfix: str = ""):
365
405
  super().write_to_python_file(file_path, postfix)
@@ -388,6 +428,12 @@ class SingleClassRDR(RDRWithCodeWriter):
388
428
  def conclusion_type_hint(self) -> str:
389
429
  return self.conclusion_type[0].__name__
390
430
 
431
+ @property
432
+ def conclusion_type(self) -> Tuple[Type]:
433
+ if self.default_conclusion is not None:
434
+ return (type(self.default_conclusion),)
435
+ return super().conclusion_type
436
+
391
437
  def _to_json(self) -> Dict[str, Any]:
392
438
  return {"start_rule": self.start_rule.to_json()}
393
439
 
@@ -419,13 +465,12 @@ class MultiClassRDR(RDRWithCodeWriter):
419
465
  The conditions of the stopping rule if needed.
420
466
  """
421
467
 
422
- def __init__(self, start_rule: Optional[Rule] = None,
468
+ def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
423
469
  mode: MCRDRMode = MCRDRMode.StopOnly):
424
470
  """
425
471
  :param start_rule: The starting rules for the classifier.
426
472
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
427
473
  """
428
- start_rule = MultiClassTopRule() if not start_rule else start_rule
429
474
  super(MultiClassRDR, self).__init__(start_rule)
430
475
  self.mode: MCRDRMode = mode
431
476
 
@@ -439,40 +484,28 @@ class MultiClassRDR(RDRWithCodeWriter):
439
484
  evaluated_rule = next_rule
440
485
  return make_set(self.conclusions)
441
486
 
442
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
443
- add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
487
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
488
+ , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
444
489
  """
445
490
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
446
491
  or missing by comparing the case with the target category if provided.
447
492
 
448
493
  :param case_query: The query containing the case to classify and the target category to compare the case with.
449
494
  :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
450
- :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
451
495
  :return: The conclusions that the case belongs to.
452
496
  """
453
- expert = expert if expert else Human()
454
- if case_query.target is None:
455
- expert.ask_for_conclusion(case_query)
456
- if case_query.target is None:
457
- return self.classify(case_query.case)
458
- self.update_start_rule(case_query, expert)
459
- self.expert_accepted_conclusions = []
460
- user_conclusions = []
461
497
  self.conclusions = []
462
498
  self.stop_rule_conditions = None
463
499
  evaluated_rule = self.start_rule
464
- target = case_query.target(case_query.case)
500
+ target = make_set(case_query.target_value)
465
501
  while evaluated_rule:
466
502
  next_rule = evaluated_rule(case_query.case)
467
503
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
468
- good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
469
- good_conclusions = make_set(good_conclusions)
470
504
 
471
505
  if evaluated_rule.fired:
472
- if target and not make_set(rule_conclusion).issubset(good_conclusions):
506
+ if not make_set(rule_conclusion).issubset(target):
473
507
  # Rule fired and conclusion is different from target
474
- self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
475
- len(user_conclusions) > 0)
508
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
476
509
  else:
477
510
  # Rule fired and target is correct or there is no target to compare
478
511
  self.add_conclusion(evaluated_rule, case_query.case)
@@ -483,14 +516,6 @@ class MultiClassRDR(RDRWithCodeWriter):
483
516
  self.add_rule_for_case(case_query, expert)
484
517
  # Have to check all rules again to make sure only this new rule fires
485
518
  next_rule = self.start_rule
486
- elif add_extra_conclusions:
487
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
488
- new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
489
- user_conclusions.extend(new_user_conclusions)
490
- if len(new_user_conclusions) > 0:
491
- next_rule = self.last_top_rule
492
- else:
493
- add_extra_conclusions = False
494
519
  evaluated_rule = next_rule
495
520
  return self.conclusions
496
521
 
@@ -529,12 +554,10 @@ class MultiClassRDR(RDRWithCodeWriter):
529
554
  :param case_query: The case query to update the starting rule with.
530
555
  :param expert: The expert to ask for differentiating features as new rule conditions.
531
556
  """
532
- if not self.start_rule.conditions:
557
+ if not self.start_rule:
533
558
  conditions = expert.ask_for_conditions(case_query)
534
- self.start_rule.conditions = conditions
535
- self.start_rule.conclusion = case_query.target
536
- self.start_rule.corner_case = case_query.case
537
- self.start_rule.conclusion_name = case_query.attribute_name
559
+ self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
560
+ conclusion_name=case_query.attribute_name)
538
561
 
539
562
  @property
540
563
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -547,17 +570,13 @@ class MultiClassRDR(RDRWithCodeWriter):
547
570
  return self.start_rule.furthest_alternative[-1]
548
571
 
549
572
  def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
550
- evaluated_rule: MultiClassTopRule,
551
- add_extra_conclusions: bool):
573
+ evaluated_rule: MultiClassTopRule):
552
574
  """
553
575
  Stop a wrong conclusion by adding a stopping rule.
554
576
  """
555
577
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
556
578
  if is_conflicting(rule_conclusion, case_query.target_value):
557
- if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
558
- return
559
- else:
560
- self.stop_conclusion(case_query, expert, evaluated_rule)
579
+ self.stop_conclusion(case_query, expert, evaluated_rule)
561
580
 
562
581
  def stop_conclusion(self, case_query: CaseQuery,
563
582
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -576,27 +595,6 @@ class MultiClassRDR(RDRWithCodeWriter):
576
595
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
577
596
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
578
597
 
579
- def conclusion_is_correct(self, case_query: CaseQuery,
580
- expert: Expert, evaluated_rule: Rule,
581
- add_extra_conclusions: bool) -> bool:
582
- """
583
- Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
584
-
585
- :param case_query: The case query to ask the expert about.
586
- :param expert: The expert to ask for differentiating features as new rule conditions.
587
- :param evaluated_rule: The evaluated rule to ask the expert about.
588
- :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
589
- :return: Whether the conclusion is correct or not.
590
- """
591
- conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
592
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
593
- evaluated_rule.conclusion(case_query.case),
594
- current_conclusions=conclusions)):
595
- self.add_conclusion(evaluated_rule, case_query.case)
596
- self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
597
- return True
598
- return False
599
-
600
598
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
601
599
  """
602
600
  Add a rule for a case that has not been classified with any conclusion.
@@ -611,24 +609,6 @@ class MultiClassRDR(RDRWithCodeWriter):
611
609
  conditions = expert.ask_for_conditions(case_query)
612
610
  self.add_top_rule(conditions, case_query.target, case_query.case)
613
611
 
614
- def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
615
- """
616
- Ask the expert for extra rules when no more conclusions can be made for a case.
617
-
618
- :param expert: The expert to ask for extra conclusions.
619
- :param case_query: The case query to ask the expert about.
620
- :return: The extra conclusions for the rules that the expert has provided.
621
- """
622
- extra_conclusions = []
623
- conclusions = list(OrderedSet(self.conclusions))
624
- if not expert.use_loaded_answers:
625
- print("current conclusions:", conclusions)
626
- extra_rules = expert.ask_for_extra_rules(case_query)
627
- for rule in extra_rules:
628
- self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
629
- extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
630
- return extra_conclusions
631
-
632
612
  def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
633
613
  """
634
614
  Add the conclusion of the evaluated rule to the list of conclusions.
@@ -722,30 +702,32 @@ class GeneralRDR(RippleDownRules):
722
702
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
723
703
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
724
704
 
725
- def classify(self, case: Any) -> Optional[Dict[str, Any]]:
705
+ def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
726
706
  """
727
707
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
728
708
  the classification until no more categories can be added.
729
709
 
730
710
  :param case: The case to classify.
711
+ :param modify_case: Whether to modify the original case or create a copy and modify it.
731
712
  :return: The categories that the case belongs to.
732
713
  """
733
- return self._classify(self.start_rules_dict, case)
714
+ return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
734
715
 
735
716
  @staticmethod
736
717
  def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
737
- case: Any) -> Dict[str, Any]:
718
+ case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
738
719
  """
739
720
  Classify a case by going through all classifiers and adding the categories that are classified,
740
721
  and then restarting the classification until no more categories can be added.
741
722
 
742
723
  :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
743
724
  :param case: The case to classify.
725
+ :param modify_original_case: Whether to modify the original case or create a copy and modify it.
744
726
  :return: The categories that the case belongs to.
745
727
  """
746
728
  conclusions = {}
747
729
  case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
748
- case_cp = copy_case(case)
730
+ case_cp = copy_case(case) if not modify_original_case else case
749
731
  while True:
750
732
  new_conclusions = {}
751
733
  for attribute_name, rdr in classifiers_dict.items():
@@ -768,13 +750,13 @@ class GeneralRDR(RippleDownRules):
768
750
  conclusions[attribute_name].update(pred_atts)
769
751
  if attribute_name in new_conclusions:
770
752
  mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
771
- GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
772
- new_conclusions)
753
+ case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
754
+ update_case(case_query, new_conclusions)
773
755
  if len(new_conclusions) == 0:
774
756
  break
775
757
  return conclusions
776
758
 
777
- def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
759
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
778
760
  -> Dict[str, Any]:
779
761
  """
780
762
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
@@ -784,115 +766,38 @@ class GeneralRDR(RippleDownRules):
784
766
  they are accepted by the expert, and the attribute of that category is represented in the case as a set of
785
767
  values.
786
768
 
787
- :param case_queries: The queries containing the case to classify and the target categories to compare the case
769
+ :param case_query: The query containing the case to classify and the target category to compare the case
788
770
  with.
789
771
  :param expert: The expert to ask for differentiating features as new rule conditions.
790
772
  :return: The categories that the case belongs to.
791
773
  """
792
- expert = expert if expert else Human()
793
- case_queries = make_list(case_queries)
794
- assert len(case_queries) > 0, "No case queries provided"
795
- case = case_queries[0].case
796
- assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
797
- " for multiple cases use fit instead")
798
- original_case_query_cp = copy(case_queries[0])
799
- for case_query in case_queries:
800
- case_query_cp = copy(case_query)
801
- case_query_cp.case = original_case_query_cp.case
802
- if case_query_cp.target is None:
803
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
804
- self.update_case(case_query_cp, conclusions)
805
- expert.ask_for_conclusion(case_query_cp)
806
- if case_query_cp.target is None:
807
- continue
808
- case_query.target = case_query_cp.target
809
-
810
- if case_query.attribute_name not in self.start_rules_dict:
811
- conclusions = self.classify(case)
812
- self.update_case(case_query_cp, conclusions)
813
-
814
- new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
815
- self.add_rdr(new_rdr, case_query.attribute_name)
816
-
817
- new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
818
- self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
819
- else:
820
- for rdr_attribute_name, rdr in self.start_rules_dict.items():
821
- if case_query.attribute_name != rdr_attribute_name:
822
- conclusions = rdr.classify(case_query_cp.case)
823
- else:
824
- conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
825
- **kwargs)
826
- if conclusions is not None:
827
- if (not is_iterable(conclusions)) or len(conclusions) > 0:
828
- conclusions = {rdr_attribute_name: conclusions}
829
- case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
830
- self.update_case(case_query_cp, conclusions)
831
- case_query.conditions = case_query_cp.conditions
832
774
 
833
- return self.classify(case)
775
+ case_query_cp = copy(case_query)
776
+ self.classify(case_query_cp.case, modify_case=True)
777
+ case_query_cp.update_target_value()
778
+
779
+ self.start_rules_dict[case_query_cp.attribute_name].fit_case(case_query_cp, expert, **kwargs)
780
+
781
+ return self.classify(case_query.case)
782
+
783
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
784
+ """
785
+ Update the starting rule of the classifier.
786
+
787
+ :param case_query: The case query to update the starting rule with.
788
+ :param expert: The expert to ask for differentiating features as new rule conditions.
789
+ """
790
+ if case_query.attribute_name not in self.start_rules_dict:
791
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query)
792
+ self.add_rdr(new_rdr, case_query.attribute_name)
834
793
 
835
794
  @staticmethod
836
795
  def initialize_new_rdr_for_attribute(case_query: CaseQuery):
837
796
  """
838
797
  Initialize the appropriate RDR type for the target.
839
798
  """
840
- if case_query.mutually_exclusive is not None:
841
- return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
842
- else MultiClassRDR()
843
- if case_query.attribute_type in [list, set]:
844
- return MultiClassRDR()
845
- attribute = getattr(case_query.case, case_query.attribute_name) \
846
- if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
847
- if isinstance(attribute, CaseAttribute):
848
- return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
849
- else MultiClassRDR()
850
- else:
851
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
852
- else SingleClassRDR(default_conclusion=case_query.default_value)
853
-
854
- @staticmethod
855
- def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
856
- """
857
- Update the case with the conclusions.
858
-
859
- :param case_query: The case query that contains the case to update.
860
- :param conclusions: The conclusions to update the case with.
861
- """
862
- if not conclusions:
863
- return
864
- if len(conclusions) == 0:
865
- return
866
- if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
867
- for conclusion_name, conclusion in conclusions.items():
868
- attribute = getattr(case_query.case, conclusion_name)
869
- if conclusion_name == case_query.attribute_name:
870
- attribute_type = case_query.attribute_type
871
- else:
872
- attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
873
- if isinstance(attribute, set):
874
- for c in conclusion:
875
- attribute.update(make_set(c))
876
- elif isinstance(attribute, list):
877
- attribute.extend(conclusion)
878
- elif any(at in {List, list} for at in attribute_type):
879
- attribute = [] if attribute is None else attribute
880
- attribute.extend(conclusion)
881
- elif any(at in {Set, set} for at in attribute_type):
882
- attribute = set() if attribute is None else attribute
883
- for c in conclusion:
884
- attribute.update(make_set(c))
885
- elif is_iterable(conclusion) and len(conclusion) == 1 \
886
- and any(at is type(list(conclusion)[0]) for at in attribute_type):
887
- setattr(case_query.case, conclusion_name, list(conclusion)[0])
888
- elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
889
- setattr(case_query.case, conclusion_name, conclusion)
890
- else:
891
- raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
892
- f"{case_query.attribute_type} with conclusion "
893
- f"{conclusion} of type {type(conclusion)}")
894
- else:
895
- case_query.case.update(conclusions)
799
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
800
+ else MultiClassRDR()
896
801
 
897
802
  def _to_json(self) -> Dict[str, Any]:
898
803
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -130,6 +130,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
130
130
  "conclusion": conclusion_to_json(self.conclusion),
131
131
  "parent": self.parent.json_serialization if self.parent else None,
132
132
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
133
+ "conclusion_name": self.conclusion_name,
133
134
  "weight": self.weight}
134
135
  return json_serialization
135
136
 
@@ -139,6 +140,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
139
140
  conclusion=CallableExpression.from_json(data["conclusion"]),
140
141
  parent=cls.from_json(data["parent"]),
141
142
  corner_case=Case.from_json(data["corner_case"]),
143
+ conclusion_name=data["conclusion_name"],
142
144
  weight=data["weight"])
143
145
  return loaded_rule
144
146
 
@@ -288,7 +290,7 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
288
290
 
289
291
  def __init__(self, *args, **kwargs):
290
292
  super(MultiClassStopRule, self).__init__(*args, **kwargs)
291
- self.conclusion = Stop.stop
293
+ self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
292
294
 
293
295
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
294
296
  if self.fired:
@@ -26,6 +26,7 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from .datastructures.case import Case
29
+ from .datastructures.dataclasses import CaseQuery
29
30
  from .rules import Rule
30
31
 
31
32
  import ast
@@ -33,6 +34,49 @@ import ast
33
34
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
35
 
35
36
 
37
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
38
+ """
39
+ Update the case with the conclusions.
40
+
41
+ :param case_query: The case query that contains the case to update.
42
+ :param conclusions: The conclusions to update the case with.
43
+ """
44
+ if not conclusions:
45
+ return
46
+ if len(conclusions) == 0:
47
+ return
48
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
49
+ for conclusion_name, conclusion in conclusions.items():
50
+ attribute = getattr(case_query.case, conclusion_name)
51
+ if conclusion_name == case_query.attribute_name:
52
+ attribute_type = case_query.attribute_type
53
+ else:
54
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
55
+ if isinstance(attribute, set):
56
+ for c in conclusion:
57
+ attribute.update(make_set(c))
58
+ elif isinstance(attribute, list):
59
+ attribute.extend(conclusion)
60
+ elif any(at in {List, list} for at in attribute_type):
61
+ attribute = [] if attribute is None else attribute
62
+ attribute.extend(conclusion)
63
+ elif any(at in {Set, set} for at in attribute_type):
64
+ attribute = set() if attribute is None else attribute
65
+ for c in conclusion:
66
+ attribute.update(make_set(c))
67
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
68
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
69
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
70
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
71
+ setattr(case_query.case, conclusion_name, conclusion)
72
+ else:
73
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
74
+ f"{case_query.attribute_type} with conclusion "
75
+ f"{conclusion} of type {type(conclusion)}")
76
+ else:
77
+ case_query.case.update(conclusions)
78
+
79
+
36
80
  def is_conflicting(conclusion: Any, target: Any) -> bool:
37
81
  """
38
82
  :param conclusion: The conclusion to check.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.63
3
+ Version: 0.1.65
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,20 +1,20 @@
1
1
  ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
3
- ripple_down_rules/experts.py,sha256=TretkdR1IY2RjcSh4WJUFJtYEbItsand7SYFmzouE_Y,10348
3
+ ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
6
  ripple_down_rules/prompt.py,sha256=cHqhMJqubGhfGpOOY_uXv5L7PBNb64O0IBWSfiY0ui0,6682
7
- ripple_down_rules/rdr.py,sha256=Ts47r1Tf00tBDRKZu_0s-8o1kYn97Pc8M6uMIRPa-6s,47713
7
+ ripple_down_rules/rdr.py,sha256=es0ml_UVLgWNb7cTxAo3ZsmNMlpifmoOOnR8K9k4vbQ,41566
8
8
  ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
- ripple_down_rules/rules.py,sha256=KTB7kPnyyU9GuZhVe9ba25-3ICdzl46r9MFduckk-_Y,16147
10
- ripple_down_rules/utils.py,sha256=JIF99Knqzqjgny7unvEnib3sCmExqU-w9xYOSGIT86Q,32276
9
+ ripple_down_rules/rules.py,sha256=Y36at3etrxQpDVUJUpx-xZll7jYgPObzc6qBy8ZpKP0,16341
10
+ ripple_down_rules/utils.py,sha256=DUPjrSWMfeeNf9OllGDWJ5kf3Hazz1sEW5DTqKQ4h6E,34528
11
11
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
12
  ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
13
  ripple_down_rules/datastructures/case.py,sha256=uM5YkJOfYfARrZZ3oAxKMWE5QBSnvHLLZa9Atoxb7eY,13800
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=_aabVXsgdVUeAmgGA9K_LZpO2U5a6-htrg2Tka7qc30,5960
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=V757VwxROlevXh5ZVFLVuzwY4JIJKG8ARlCfjhubfy8,5957
15
15
  ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
16
- ripple_down_rules-0.1.63.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.63.dist-info/METADATA,sha256=jXq2PSfcC2cqtqrpIwt7zyyWLtKf3aUGD5EZH9_bYuA,42576
18
- ripple_down_rules-0.1.63.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
- ripple_down_rules-0.1.63.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.63.dist-info/RECORD,,
16
+ ripple_down_rules-0.1.65.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.65.dist-info/METADATA,sha256=WtL6_RMdFxEhSWnK9CMqbp_94caIwcGw4s28KqJoKm4,42576
18
+ ripple_down_rules-0.1.65.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
19
+ ripple_down_rules-0.1.65.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.65.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.0)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5