ripple-down-rules 0.1.64__py3-none-any.whl → 0.1.65__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -133,7 +133,7 @@ class CaseQuery:
133
133
  if value is not None and not isinstance(value, (CallableExpression, str)):
134
134
  raise ValueError("The target must be a CallableExpression or a string.")
135
135
  self._target = value
136
- self._update_target_value()
136
+ self.update_target_value()
137
137
 
138
138
  @property
139
139
  def target_value(self) -> Any:
@@ -141,10 +141,10 @@ class CaseQuery:
141
141
  :return: The target value of the case query.
142
142
  """
143
143
  if self._target_value is None:
144
- self._update_target_value()
144
+ self.update_target_value()
145
145
  return self._target_value
146
146
 
147
- def _update_target_value(self):
147
+ def update_target_value(self):
148
148
  """
149
149
  Update the target value of the case query.
150
150
  """
@@ -32,10 +32,11 @@ class Expert(ABC):
32
32
  """
33
33
  A flag to indicate if the expert should use loaded answers or not.
34
34
  """
35
- known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
36
- """
37
- The known categories (i.e. Column types) to use.
38
- """
35
+
36
+ def __init__(self, use_loaded_answers: bool = False, append: bool = False):
37
+ self.all_expert_answers = []
38
+ self.use_loaded_answers = use_loaded_answers
39
+ self.append = append
39
40
 
40
41
  @abstractmethod
41
42
  def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
@@ -51,28 +52,6 @@ class Expert(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
55
- """
56
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
57
-
58
- :param case_query: The case query containing the case to classify.
59
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
- conclusion and conditions for the rule.
61
- """
62
- pass
63
-
64
- @abstractmethod
65
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
- current_conclusions: Any) -> bool:
67
- """
68
- Ask the expert if the conclusion is correct.
69
-
70
- :param case_query: The case query about which the expert should answer.
71
- :param conclusion: The conclusion to check.
72
- :param current_conclusions: The current conclusions for the case.
73
- """
74
- pass
75
-
76
55
  def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
77
56
  """
78
57
  Ask the expert to provide a relational conclusion for the case.
@@ -87,18 +66,14 @@ class Human(Expert):
87
66
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
88
67
  """
89
68
 
90
- def __init__(self, use_loaded_answers: bool = False):
91
- self.all_expert_answers = []
92
- self.use_loaded_answers = use_loaded_answers
93
-
94
- def save_answers(self, path: str, append: bool = False):
69
+ def save_answers(self, path: str):
95
70
  """
96
71
  Save the expert answers to a file.
97
72
 
98
73
  :param path: The path to save the answers to.
99
74
  :param append: A flag to indicate if the answers should be appended to the file or not.
100
75
  """
101
- if append:
76
+ if self.append:
102
77
  # read the file and append the new answers
103
78
  with open(path + '.json', "r") as f:
104
79
  all_answers = json.load(f)
@@ -126,24 +101,6 @@ class Human(Expert):
126
101
  last_evaluated_rule=last_evaluated_rule)
127
102
  return self._get_conditions(case_query)
128
103
 
129
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
- """
131
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
-
133
- :param case_query: The case query containing the case to classify.
134
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
- conclusion and conditions for the rule.
136
- """
137
- rules = []
138
- while True:
139
- conclusion = self.ask_for_conclusion(case_query)
140
- if conclusion is None:
141
- break
142
- conditions = self._get_conditions(case_query)
143
- rules.append({PromptFor.Conclusion: conclusion,
144
- PromptFor.Conditions: conditions})
145
- return rules
146
-
147
104
  def _get_conditions(self, case_query: CaseQuery) \
148
105
  -> CallableExpression:
149
106
  """
@@ -154,6 +111,8 @@ class Human(Expert):
154
111
  :return: The differentiating features as new rule conditions.
155
112
  """
156
113
  user_input = None
114
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
115
+ self.use_loaded_answers = False
157
116
  if self.use_loaded_answers:
158
117
  user_input = self.all_expert_answers.pop(0)
159
118
  if user_input:
@@ -173,6 +132,8 @@ class Human(Expert):
173
132
  :return: The conclusion for the case as a callable expression.
174
133
  """
175
134
  expression: Optional[CallableExpression] = None
135
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
136
+ self.use_loaded_answers = False
176
137
  if self.use_loaded_answers:
177
138
  expert_input = self.all_expert_answers.pop(0)
178
139
  if expert_input is not None:
@@ -184,63 +145,3 @@ class Human(Expert):
184
145
  self.all_expert_answers.append(expert_input)
185
146
  case_query.target = expression
186
147
  return expression
187
-
188
- def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
189
- """
190
- Get the category type from the known categories.
191
-
192
- :param cat_name: The name of the category.
193
- :return: The category type.
194
- """
195
- cat_name = cat_name.lower()
196
- self.known_categories = get_all_subclasses(
197
- CaseAttribute) if not self.known_categories else self.known_categories
198
- self.known_categories.update(CaseAttribute.registry)
199
- category_type = None
200
- if cat_name in self.known_categories:
201
- category_type = self.known_categories[cat_name]
202
- return category_type
203
-
204
- def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
205
- """
206
- Ask the expert if the new category can have multiple values.
207
-
208
- :param category_name: The name of the category to ask about.
209
- """
210
- question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
211
- return not self.ask_for_affirmation(question)
212
-
213
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
- current_conclusions: Any) -> bool:
215
- """
216
- Ask the expert if the conclusion is correct.
217
-
218
- :param case_query: The case query about which the expert should answer.
219
- :param conclusion: The conclusion to check.
220
- :param current_conclusions: The current conclusions for the case.
221
- """
222
- if not self.use_loaded_answers:
223
- print(f"Current conclusions: {current_conclusions}")
224
- return self.ask_for_affirmation(case_query,
225
- f"Is the conclusion {conclusion} correct for the case (True/False):")
226
-
227
- def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
228
- """
229
- Ask the expert a yes or no question.
230
-
231
- :param case_query: The case query about which the expert should answer.
232
- :param question: The question to ask the expert.
233
- :return: The answer to the question.
234
- """
235
- while True:
236
- if self.use_loaded_answers:
237
- answer = self.all_expert_answers.pop(0)
238
- else:
239
- _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
- answer = expression(case_query.case)
241
- if answer:
242
- self.all_expert_answers.append(True)
243
- return True
244
- else:
245
- self.all_expert_answers.append(False)
246
- return False
ripple_down_rules/rdr.py CHANGED
@@ -22,7 +22,7 @@ from .helpers import is_matching
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, is_conflicting
25
+ get_case_attribute_type, is_conflicting, update_case
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -49,32 +49,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
49
49
  self.start_rule = start_rule
50
50
  self.fig: Optional[plt.Figure] = None
51
51
 
52
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
53
- return self.classify(case)
54
-
55
- @abstractmethod
56
- def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
57
- """
58
- Classify a case.
59
-
60
- :param case: The case to classify.
61
- :return: The category that the case belongs to.
62
- """
63
- pass
64
-
65
- @abstractmethod
66
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
67
- -> Union[CaseAttribute, CallableExpression]:
68
- """
69
- Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
70
- comparing the case with the target category.
71
-
72
- :param case_query: The query containing the case to classify and the target category to compare the case with.
73
- :param expert: The expert to ask for differentiating features as new rule conditions.
74
- :return: The category that the case belongs to.
75
- """
76
- pass
77
-
78
52
  def fit(self, case_queries: List[CaseQuery],
79
53
  expert: Optional[Expert] = None,
80
54
  n_iter: int = None,
@@ -124,6 +98,63 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
98
  plt.ioff()
125
99
  plt.show()
126
100
 
101
+ def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
102
+ return self.classify(case)
103
+
104
+ @abstractmethod
105
+ def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
106
+ """
107
+ Classify a case.
108
+
109
+ :param case: The case to classify.
110
+ :return: The category that the case belongs to.
111
+ """
112
+ pass
113
+
114
+ def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
115
+ -> Union[CaseAttribute, CallableExpression]:
116
+ """
117
+ Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
118
+ incorrect by comparing the case with the target category.
119
+
120
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
121
+ :param expert: The expert to ask for differentiating features as new rule conditions.
122
+ :return: The category that the case belongs to.
123
+ """
124
+ if case_query is None:
125
+ raise ValueError("The case query cannot be None.")
126
+ if case_query.target is None:
127
+ expert.ask_for_conclusion(case_query)
128
+ if case_query.target is None:
129
+ return self.classify(case_query.case)
130
+
131
+ self.update_start_rule(case_query, expert)
132
+
133
+ return self._fit_case(case_query, expert=expert, **kwargs)
134
+
135
+ @abstractmethod
136
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
137
+ -> Union[CaseAttribute, CallableExpression]:
138
+ """
139
+ Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
140
+ comparing the case with the target category.
141
+
142
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
143
+ :param expert: The expert to ask for differentiating features as new rule conditions.
144
+ :return: The category that the case belongs to.
145
+ """
146
+ pass
147
+
148
+ @abstractmethod
149
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
150
+ """
151
+ Update the starting rule of the classifier.
152
+
153
+ :param case_query: The case query to update the starting rule with.
154
+ :param expert: The expert to ask for differentiating features as new rule conditions.
155
+ """
156
+ pass
157
+
127
158
  def update_figures(self):
128
159
  """
129
160
  Update the figures of the classifier.
@@ -321,7 +352,7 @@ class SingleClassRDR(RDRWithCodeWriter):
321
352
  super(SingleClassRDR, self).__init__(start_rule)
322
353
  self.default_conclusion: Optional[Any] = default_conclusion
323
354
 
324
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
355
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
325
356
  -> Union[CaseAttribute, CallableExpression, None]:
326
357
  """
327
358
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
@@ -331,38 +362,44 @@ class SingleClassRDR(RDRWithCodeWriter):
331
362
  :param expert: The expert to ask for differentiating features as new rule conditions.
332
363
  :return: The category that the case belongs to.
333
364
  """
334
- expert = expert if expert else Human()
335
365
  if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
336
366
  self.default_conclusion = case_query.default_value
337
- case = case_query.case
338
- target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
339
- if target is None:
340
- return self.classify(case)
341
- if not self.start_rule:
342
- conditions = expert.ask_for_conditions(case_query)
343
- self.start_rule = SingleClassRule(conditions, target, corner_case=case,
344
- conclusion_name=case_query.attribute_name)
345
367
 
346
368
  pred = self.evaluate(case_query.case)
347
- if pred.conclusion(case) != target(case):
348
- conditions = expert.ask_for_conditions(case_query, pred)
349
- pred.fit_rule(case_query.case, target, conditions=conditions)
369
+ if pred.conclusion(case_query.case) != case_query.target_value:
370
+ expert.ask_for_conditions(case_query, pred)
371
+ pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
350
372
 
351
373
  return self.classify(case_query.case)
352
374
 
353
- def classify(self, case: Case) -> Optional[Any]:
375
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
376
+ """
377
+ Update the starting rule of the classifier.
378
+
379
+ :param case_query: The case query to update the starting rule with.
380
+ :param expert: The expert to ask for differentiating features as new rule conditions.
381
+ """
382
+ if not self.start_rule:
383
+ expert.ask_for_conditions(case_query)
384
+ self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
385
+ conclusion_name=case_query.attribute_name)
386
+
387
+ def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
354
388
  """
355
389
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
390
+
391
+ :param case: The case to classify.
392
+ :param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
356
393
  """
357
394
  pred = self.evaluate(case)
358
- return pred.conclusion(case) if pred.fired else self.default_conclusion
395
+ return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
359
396
 
360
397
  def evaluate(self, case: Case) -> SingleClassRule:
361
398
  """
362
399
  Evaluate the starting rule on a case.
363
400
  """
364
- matched_rule = self.start_rule(case)
365
- return matched_rule if matched_rule else self.start_rule
401
+ matched_rule = self.start_rule(case) if self.start_rule is not None else None
402
+ return matched_rule if matched_rule is not None else self.start_rule
366
403
 
367
404
  def write_to_python_file(self, file_path: str, postfix: str = ""):
368
405
  super().write_to_python_file(file_path, postfix)
@@ -391,6 +428,12 @@ class SingleClassRDR(RDRWithCodeWriter):
391
428
  def conclusion_type_hint(self) -> str:
392
429
  return self.conclusion_type[0].__name__
393
430
 
431
+ @property
432
+ def conclusion_type(self) -> Tuple[Type]:
433
+ if self.default_conclusion is not None:
434
+ return (type(self.default_conclusion),)
435
+ return super().conclusion_type
436
+
394
437
  def _to_json(self) -> Dict[str, Any]:
395
438
  return {"start_rule": self.start_rule.to_json()}
396
439
 
@@ -422,13 +465,12 @@ class MultiClassRDR(RDRWithCodeWriter):
422
465
  The conditions of the stopping rule if needed.
423
466
  """
424
467
 
425
- def __init__(self, start_rule: Optional[Rule] = None,
468
+ def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
426
469
  mode: MCRDRMode = MCRDRMode.StopOnly):
427
470
  """
428
471
  :param start_rule: The starting rules for the classifier.
429
472
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
430
473
  """
431
- start_rule = MultiClassTopRule() if not start_rule else start_rule
432
474
  super(MultiClassRDR, self).__init__(start_rule)
433
475
  self.mode: MCRDRMode = mode
434
476
 
@@ -442,40 +484,28 @@ class MultiClassRDR(RDRWithCodeWriter):
442
484
  evaluated_rule = next_rule
443
485
  return make_set(self.conclusions)
444
486
 
445
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
446
- add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
487
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
488
+ , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
447
489
  """
448
490
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
449
491
  or missing by comparing the case with the target category if provided.
450
492
 
451
493
  :param case_query: The query containing the case to classify and the target category to compare the case with.
452
494
  :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
453
- :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
454
495
  :return: The conclusions that the case belongs to.
455
496
  """
456
- expert = expert if expert else Human()
457
- if case_query.target is None:
458
- expert.ask_for_conclusion(case_query)
459
- if case_query.target is None:
460
- return self.classify(case_query.case)
461
- self.update_start_rule(case_query, expert)
462
- self.expert_accepted_conclusions = []
463
- user_conclusions = []
464
497
  self.conclusions = []
465
498
  self.stop_rule_conditions = None
466
499
  evaluated_rule = self.start_rule
467
- target = case_query.target(case_query.case)
500
+ target = make_set(case_query.target_value)
468
501
  while evaluated_rule:
469
502
  next_rule = evaluated_rule(case_query.case)
470
503
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
471
- good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
472
- good_conclusions = make_set(good_conclusions)
473
504
 
474
505
  if evaluated_rule.fired:
475
- if target and not make_set(rule_conclusion).issubset(good_conclusions):
506
+ if not make_set(rule_conclusion).issubset(target):
476
507
  # Rule fired and conclusion is different from target
477
- self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
478
- len(user_conclusions) > 0)
508
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
479
509
  else:
480
510
  # Rule fired and target is correct or there is no target to compare
481
511
  self.add_conclusion(evaluated_rule, case_query.case)
@@ -486,14 +516,6 @@ class MultiClassRDR(RDRWithCodeWriter):
486
516
  self.add_rule_for_case(case_query, expert)
487
517
  # Have to check all rules again to make sure only this new rule fires
488
518
  next_rule = self.start_rule
489
- elif add_extra_conclusions:
490
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
491
- new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
492
- user_conclusions.extend(new_user_conclusions)
493
- if len(new_user_conclusions) > 0:
494
- next_rule = self.last_top_rule
495
- else:
496
- add_extra_conclusions = False
497
519
  evaluated_rule = next_rule
498
520
  return self.conclusions
499
521
 
@@ -532,12 +554,10 @@ class MultiClassRDR(RDRWithCodeWriter):
532
554
  :param case_query: The case query to update the starting rule with.
533
555
  :param expert: The expert to ask for differentiating features as new rule conditions.
534
556
  """
535
- if not self.start_rule.conditions:
557
+ if not self.start_rule:
536
558
  conditions = expert.ask_for_conditions(case_query)
537
- self.start_rule.conditions = conditions
538
- self.start_rule.conclusion = case_query.target
539
- self.start_rule.corner_case = case_query.case
540
- self.start_rule.conclusion_name = case_query.attribute_name
559
+ self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
560
+ conclusion_name=case_query.attribute_name)
541
561
 
542
562
  @property
543
563
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -550,17 +570,13 @@ class MultiClassRDR(RDRWithCodeWriter):
550
570
  return self.start_rule.furthest_alternative[-1]
551
571
 
552
572
  def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
553
- evaluated_rule: MultiClassTopRule,
554
- add_extra_conclusions: bool):
573
+ evaluated_rule: MultiClassTopRule):
555
574
  """
556
575
  Stop a wrong conclusion by adding a stopping rule.
557
576
  """
558
577
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
559
578
  if is_conflicting(rule_conclusion, case_query.target_value):
560
- if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
561
- return
562
- else:
563
- self.stop_conclusion(case_query, expert, evaluated_rule)
579
+ self.stop_conclusion(case_query, expert, evaluated_rule)
564
580
 
565
581
  def stop_conclusion(self, case_query: CaseQuery,
566
582
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -579,27 +595,6 @@ class MultiClassRDR(RDRWithCodeWriter):
579
595
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
580
596
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
581
597
 
582
- def conclusion_is_correct(self, case_query: CaseQuery,
583
- expert: Expert, evaluated_rule: Rule,
584
- add_extra_conclusions: bool) -> bool:
585
- """
586
- Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
587
-
588
- :param case_query: The case query to ask the expert about.
589
- :param expert: The expert to ask for differentiating features as new rule conditions.
590
- :param evaluated_rule: The evaluated rule to ask the expert about.
591
- :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
592
- :return: Whether the conclusion is correct or not.
593
- """
594
- conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
595
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
596
- evaluated_rule.conclusion(case_query.case),
597
- current_conclusions=conclusions)):
598
- self.add_conclusion(evaluated_rule, case_query.case)
599
- self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
600
- return True
601
- return False
602
-
603
598
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
604
599
  """
605
600
  Add a rule for a case that has not been classified with any conclusion.
@@ -614,24 +609,6 @@ class MultiClassRDR(RDRWithCodeWriter):
614
609
  conditions = expert.ask_for_conditions(case_query)
615
610
  self.add_top_rule(conditions, case_query.target, case_query.case)
616
611
 
617
- def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
618
- """
619
- Ask the expert for extra rules when no more conclusions can be made for a case.
620
-
621
- :param expert: The expert to ask for extra conclusions.
622
- :param case_query: The case query to ask the expert about.
623
- :return: The extra conclusions for the rules that the expert has provided.
624
- """
625
- extra_conclusions = []
626
- conclusions = list(OrderedSet(self.conclusions))
627
- if not expert.use_loaded_answers:
628
- print("current conclusions:", conclusions)
629
- extra_rules = expert.ask_for_extra_rules(case_query)
630
- for rule in extra_rules:
631
- self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
632
- extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
633
- return extra_conclusions
634
-
635
612
  def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
636
613
  """
637
614
  Add the conclusion of the evaluated rule to the list of conclusions.
@@ -725,30 +702,32 @@ class GeneralRDR(RippleDownRules):
725
702
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
726
703
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
727
704
 
728
- def classify(self, case: Any) -> Optional[Dict[str, Any]]:
705
+ def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
729
706
  """
730
707
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
731
708
  the classification until no more categories can be added.
732
709
 
733
710
  :param case: The case to classify.
711
+ :param modify_case: Whether to modify the original case or create a copy and modify it.
734
712
  :return: The categories that the case belongs to.
735
713
  """
736
- return self._classify(self.start_rules_dict, case)
714
+ return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
737
715
 
738
716
  @staticmethod
739
717
  def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
740
- case: Any) -> Dict[str, Any]:
718
+ case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
741
719
  """
742
720
  Classify a case by going through all classifiers and adding the categories that are classified,
743
721
  and then restarting the classification until no more categories can be added.
744
722
 
745
723
  :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
746
724
  :param case: The case to classify.
725
+ :param modify_original_case: Whether to modify the original case or create a copy and modify it.
747
726
  :return: The categories that the case belongs to.
748
727
  """
749
728
  conclusions = {}
750
729
  case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
751
- case_cp = copy_case(case)
730
+ case_cp = copy_case(case) if not modify_original_case else case
752
731
  while True:
753
732
  new_conclusions = {}
754
733
  for attribute_name, rdr in classifiers_dict.items():
@@ -771,13 +750,13 @@ class GeneralRDR(RippleDownRules):
771
750
  conclusions[attribute_name].update(pred_atts)
772
751
  if attribute_name in new_conclusions:
773
752
  mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
774
- GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
775
- new_conclusions)
753
+ case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
754
+ update_case(case_query, new_conclusions)
776
755
  if len(new_conclusions) == 0:
777
756
  break
778
757
  return conclusions
779
758
 
780
- def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
759
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
781
760
  -> Dict[str, Any]:
782
761
  """
783
762
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
@@ -787,115 +766,38 @@ class GeneralRDR(RippleDownRules):
787
766
  they are accepted by the expert, and the attribute of that category is represented in the case as a set of
788
767
  values.
789
768
 
790
- :param case_queries: The queries containing the case to classify and the target categories to compare the case
769
+ :param case_query: The query containing the case to classify and the target category to compare the case
791
770
  with.
792
771
  :param expert: The expert to ask for differentiating features as new rule conditions.
793
772
  :return: The categories that the case belongs to.
794
773
  """
795
- expert = expert if expert else Human()
796
- case_queries = make_list(case_queries)
797
- assert len(case_queries) > 0, "No case queries provided"
798
- case = case_queries[0].case
799
- assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
800
- " for multiple cases use fit instead")
801
- original_case_query_cp = copy(case_queries[0])
802
- for case_query in case_queries:
803
- case_query_cp = copy(case_query)
804
- case_query_cp.case = original_case_query_cp.case
805
- if case_query_cp.target is None:
806
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
807
- self.update_case(case_query_cp, conclusions)
808
- expert.ask_for_conclusion(case_query_cp)
809
- if case_query_cp.target is None:
810
- continue
811
- case_query.target = case_query_cp.target
812
-
813
- if case_query.attribute_name not in self.start_rules_dict:
814
- conclusions = self.classify(case)
815
- self.update_case(case_query_cp, conclusions)
816
-
817
- new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
818
- self.add_rdr(new_rdr, case_query.attribute_name)
819
-
820
- new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
821
- self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
822
- else:
823
- for rdr_attribute_name, rdr in self.start_rules_dict.items():
824
- if case_query.attribute_name != rdr_attribute_name:
825
- conclusions = rdr.classify(case_query_cp.case)
826
- else:
827
- conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
828
- **kwargs)
829
- if conclusions is not None:
830
- if (not is_iterable(conclusions)) or len(conclusions) > 0:
831
- conclusions = {rdr_attribute_name: conclusions}
832
- case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
833
- self.update_case(case_query_cp, conclusions)
834
- case_query.conditions = case_query_cp.conditions
835
774
 
836
- return self.classify(case)
775
+ case_query_cp = copy(case_query)
776
+ self.classify(case_query_cp.case, modify_case=True)
777
+ case_query_cp.update_target_value()
778
+
779
+ self.start_rules_dict[case_query_cp.attribute_name].fit_case(case_query_cp, expert, **kwargs)
780
+
781
+ return self.classify(case_query.case)
782
+
783
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
784
+ """
785
+ Update the starting rule of the classifier.
786
+
787
+ :param case_query: The case query to update the starting rule with.
788
+ :param expert: The expert to ask for differentiating features as new rule conditions.
789
+ """
790
+ if case_query.attribute_name not in self.start_rules_dict:
791
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query)
792
+ self.add_rdr(new_rdr, case_query.attribute_name)
837
793
 
838
794
  @staticmethod
839
795
  def initialize_new_rdr_for_attribute(case_query: CaseQuery):
840
796
  """
841
797
  Initialize the appropriate RDR type for the target.
842
798
  """
843
- if case_query.mutually_exclusive is not None:
844
- return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
845
- else MultiClassRDR()
846
- if case_query.attribute_type in [list, set]:
847
- return MultiClassRDR()
848
- attribute = getattr(case_query.case, case_query.attribute_name) \
849
- if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
850
- if isinstance(attribute, CaseAttribute):
851
- return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
852
- else MultiClassRDR()
853
- else:
854
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
855
- else SingleClassRDR(default_conclusion=case_query.default_value)
856
-
857
- @staticmethod
858
- def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
859
- """
860
- Update the case with the conclusions.
861
-
862
- :param case_query: The case query that contains the case to update.
863
- :param conclusions: The conclusions to update the case with.
864
- """
865
- if not conclusions:
866
- return
867
- if len(conclusions) == 0:
868
- return
869
- if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
870
- for conclusion_name, conclusion in conclusions.items():
871
- attribute = getattr(case_query.case, conclusion_name)
872
- if conclusion_name == case_query.attribute_name:
873
- attribute_type = case_query.attribute_type
874
- else:
875
- attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
876
- if isinstance(attribute, set):
877
- for c in conclusion:
878
- attribute.update(make_set(c))
879
- elif isinstance(attribute, list):
880
- attribute.extend(conclusion)
881
- elif any(at in {List, list} for at in attribute_type):
882
- attribute = [] if attribute is None else attribute
883
- attribute.extend(conclusion)
884
- elif any(at in {Set, set} for at in attribute_type):
885
- attribute = set() if attribute is None else attribute
886
- for c in conclusion:
887
- attribute.update(make_set(c))
888
- elif is_iterable(conclusion) and len(conclusion) == 1 \
889
- and any(at is type(list(conclusion)[0]) for at in attribute_type):
890
- setattr(case_query.case, conclusion_name, list(conclusion)[0])
891
- elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
892
- setattr(case_query.case, conclusion_name, conclusion)
893
- else:
894
- raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
895
- f"{case_query.attribute_type} with conclusion "
896
- f"{conclusion} of type {type(conclusion)}")
897
- else:
898
- case_query.case.update(conclusions)
799
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
800
+ else MultiClassRDR()
899
801
 
900
802
  def _to_json(self) -> Dict[str, Any]:
901
803
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -130,6 +130,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
130
130
  "conclusion": conclusion_to_json(self.conclusion),
131
131
  "parent": self.parent.json_serialization if self.parent else None,
132
132
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
133
+ "conclusion_name": self.conclusion_name,
133
134
  "weight": self.weight}
134
135
  return json_serialization
135
136
 
@@ -139,6 +140,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
139
140
  conclusion=CallableExpression.from_json(data["conclusion"]),
140
141
  parent=cls.from_json(data["parent"]),
141
142
  corner_case=Case.from_json(data["corner_case"]),
143
+ conclusion_name=data["conclusion_name"],
142
144
  weight=data["weight"])
143
145
  return loaded_rule
144
146
 
@@ -26,6 +26,7 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from .datastructures.case import Case
29
+ from .datastructures.dataclasses import CaseQuery
29
30
  from .rules import Rule
30
31
 
31
32
  import ast
@@ -33,6 +34,49 @@ import ast
33
34
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
35
 
35
36
 
37
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
38
+ """
39
+ Update the case with the conclusions.
40
+
41
+ :param case_query: The case query that contains the case to update.
42
+ :param conclusions: The conclusions to update the case with.
43
+ """
44
+ if not conclusions:
45
+ return
46
+ if len(conclusions) == 0:
47
+ return
48
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
49
+ for conclusion_name, conclusion in conclusions.items():
50
+ attribute = getattr(case_query.case, conclusion_name)
51
+ if conclusion_name == case_query.attribute_name:
52
+ attribute_type = case_query.attribute_type
53
+ else:
54
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
55
+ if isinstance(attribute, set):
56
+ for c in conclusion:
57
+ attribute.update(make_set(c))
58
+ elif isinstance(attribute, list):
59
+ attribute.extend(conclusion)
60
+ elif any(at in {List, list} for at in attribute_type):
61
+ attribute = [] if attribute is None else attribute
62
+ attribute.extend(conclusion)
63
+ elif any(at in {Set, set} for at in attribute_type):
64
+ attribute = set() if attribute is None else attribute
65
+ for c in conclusion:
66
+ attribute.update(make_set(c))
67
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
68
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
69
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
70
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
71
+ setattr(case_query.case, conclusion_name, conclusion)
72
+ else:
73
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
74
+ f"{case_query.attribute_type} with conclusion "
75
+ f"{conclusion} of type {type(conclusion)}")
76
+ else:
77
+ case_query.case.update(conclusions)
78
+
79
+
36
80
  def is_conflicting(conclusion: Any, target: Any) -> bool:
37
81
  """
38
82
  :param conclusion: The conclusion to check.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.64
3
+ Version: 0.1.65
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,20 +1,20 @@
1
1
  ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
3
- ripple_down_rules/experts.py,sha256=TretkdR1IY2RjcSh4WJUFJtYEbItsand7SYFmzouE_Y,10348
3
+ ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
6
  ripple_down_rules/prompt.py,sha256=cHqhMJqubGhfGpOOY_uXv5L7PBNb64O0IBWSfiY0ui0,6682
7
- ripple_down_rules/rdr.py,sha256=tJDbPF3D2qJ5cw2WPmZrs_C2Jj6b7cCxEJDp2SK9GCI,47863
7
+ ripple_down_rules/rdr.py,sha256=es0ml_UVLgWNb7cTxAo3ZsmNMlpifmoOOnR8K9k4vbQ,41566
8
8
  ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
- ripple_down_rules/rules.py,sha256=4oQSb4p36A-YzW0hJ8HH2FhmIvokjXKK1rIXEn-8WcE,16203
10
- ripple_down_rules/utils.py,sha256=JIF99Knqzqjgny7unvEnib3sCmExqU-w9xYOSGIT86Q,32276
9
+ ripple_down_rules/rules.py,sha256=Y36at3etrxQpDVUJUpx-xZll7jYgPObzc6qBy8ZpKP0,16341
10
+ ripple_down_rules/utils.py,sha256=DUPjrSWMfeeNf9OllGDWJ5kf3Hazz1sEW5DTqKQ4h6E,34528
11
11
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
12
  ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
13
  ripple_down_rules/datastructures/case.py,sha256=uM5YkJOfYfARrZZ3oAxKMWE5QBSnvHLLZa9Atoxb7eY,13800
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=_aabVXsgdVUeAmgGA9K_LZpO2U5a6-htrg2Tka7qc30,5960
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=V757VwxROlevXh5ZVFLVuzwY4JIJKG8ARlCfjhubfy8,5957
15
15
  ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
16
- ripple_down_rules-0.1.64.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.64.dist-info/METADATA,sha256=PChgWG-I8g-_yaKMDwrxJPuIHOe0fIrThGu_vOqrQkE,42576
18
- ripple_down_rules-0.1.64.dist-info/WHEEL,sha256=ooBFpIzZCPdw3uqIQsOo4qqbA4ZRPxHnOH7peeONza0,91
19
- ripple_down_rules-0.1.64.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.64.dist-info/RECORD,,
16
+ ripple_down_rules-0.1.65.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.65.dist-info/METADATA,sha256=WtL6_RMdFxEhSWnK9CMqbp_94caIwcGw4s28KqJoKm4,42576
18
+ ripple_down_rules-0.1.65.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
19
+ ripple_down_rules-0.1.65.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.65.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.1)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5