ripple-down-rules 0.1.64__py3-none-any.whl → 0.1.66__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -133,7 +133,7 @@ class CaseQuery:
133
133
  if value is not None and not isinstance(value, (CallableExpression, str)):
134
134
  raise ValueError("The target must be a CallableExpression or a string.")
135
135
  self._target = value
136
- self._update_target_value()
136
+ self.update_target_value()
137
137
 
138
138
  @property
139
139
  def target_value(self) -> Any:
@@ -141,10 +141,10 @@ class CaseQuery:
141
141
  :return: The target value of the case query.
142
142
  """
143
143
  if self._target_value is None:
144
- self._update_target_value()
144
+ self.update_target_value()
145
145
  return self._target_value
146
146
 
147
- def _update_target_value(self):
147
+ def update_target_value(self):
148
148
  """
149
149
  Update the target value of the case query.
150
150
  """
@@ -32,10 +32,11 @@ class Expert(ABC):
32
32
  """
33
33
  A flag to indicate if the expert should use loaded answers or not.
34
34
  """
35
- known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
36
- """
37
- The known categories (i.e. Column types) to use.
38
- """
35
+
36
+ def __init__(self, use_loaded_answers: bool = False, append: bool = False):
37
+ self.all_expert_answers = []
38
+ self.use_loaded_answers = use_loaded_answers
39
+ self.append = append
39
40
 
40
41
  @abstractmethod
41
42
  def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
@@ -51,28 +52,6 @@ class Expert(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
55
- """
56
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
57
-
58
- :param case_query: The case query containing the case to classify.
59
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
- conclusion and conditions for the rule.
61
- """
62
- pass
63
-
64
- @abstractmethod
65
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
- current_conclusions: Any) -> bool:
67
- """
68
- Ask the expert if the conclusion is correct.
69
-
70
- :param case_query: The case query about which the expert should answer.
71
- :param conclusion: The conclusion to check.
72
- :param current_conclusions: The current conclusions for the case.
73
- """
74
- pass
75
-
76
55
  def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
77
56
  """
78
57
  Ask the expert to provide a relational conclusion for the case.
@@ -87,18 +66,14 @@ class Human(Expert):
87
66
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
88
67
  """
89
68
 
90
- def __init__(self, use_loaded_answers: bool = False):
91
- self.all_expert_answers = []
92
- self.use_loaded_answers = use_loaded_answers
93
-
94
- def save_answers(self, path: str, append: bool = False):
69
+ def save_answers(self, path: str):
95
70
  """
96
71
  Save the expert answers to a file.
97
72
 
98
73
  :param path: The path to save the answers to.
99
74
  :param append: A flag to indicate if the answers should be appended to the file or not.
100
75
  """
101
- if append:
76
+ if self.append:
102
77
  # read the file and append the new answers
103
78
  with open(path + '.json', "r") as f:
104
79
  all_answers = json.load(f)
@@ -126,24 +101,6 @@ class Human(Expert):
126
101
  last_evaluated_rule=last_evaluated_rule)
127
102
  return self._get_conditions(case_query)
128
103
 
129
- def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
- """
131
- Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
-
133
- :param case_query: The case query containing the case to classify.
134
- :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
- conclusion and conditions for the rule.
136
- """
137
- rules = []
138
- while True:
139
- conclusion = self.ask_for_conclusion(case_query)
140
- if conclusion is None:
141
- break
142
- conditions = self._get_conditions(case_query)
143
- rules.append({PromptFor.Conclusion: conclusion,
144
- PromptFor.Conditions: conditions})
145
- return rules
146
-
147
104
  def _get_conditions(self, case_query: CaseQuery) \
148
105
  -> CallableExpression:
149
106
  """
@@ -154,6 +111,8 @@ class Human(Expert):
154
111
  :return: The differentiating features as new rule conditions.
155
112
  """
156
113
  user_input = None
114
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
115
+ self.use_loaded_answers = False
157
116
  if self.use_loaded_answers:
158
117
  user_input = self.all_expert_answers.pop(0)
159
118
  if user_input:
@@ -173,6 +132,8 @@ class Human(Expert):
173
132
  :return: The conclusion for the case as a callable expression.
174
133
  """
175
134
  expression: Optional[CallableExpression] = None
135
+ if self.use_loaded_answers and len(self.all_expert_answers) == 0 and self.append:
136
+ self.use_loaded_answers = False
176
137
  if self.use_loaded_answers:
177
138
  expert_input = self.all_expert_answers.pop(0)
178
139
  if expert_input is not None:
@@ -184,63 +145,3 @@ class Human(Expert):
184
145
  self.all_expert_answers.append(expert_input)
185
146
  case_query.target = expression
186
147
  return expression
187
-
188
- def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
189
- """
190
- Get the category type from the known categories.
191
-
192
- :param cat_name: The name of the category.
193
- :return: The category type.
194
- """
195
- cat_name = cat_name.lower()
196
- self.known_categories = get_all_subclasses(
197
- CaseAttribute) if not self.known_categories else self.known_categories
198
- self.known_categories.update(CaseAttribute.registry)
199
- category_type = None
200
- if cat_name in self.known_categories:
201
- category_type = self.known_categories[cat_name]
202
- return category_type
203
-
204
- def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
205
- """
206
- Ask the expert if the new category can have multiple values.
207
-
208
- :param category_name: The name of the category to ask about.
209
- """
210
- question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
211
- return not self.ask_for_affirmation(question)
212
-
213
- def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
- current_conclusions: Any) -> bool:
215
- """
216
- Ask the expert if the conclusion is correct.
217
-
218
- :param case_query: The case query about which the expert should answer.
219
- :param conclusion: The conclusion to check.
220
- :param current_conclusions: The current conclusions for the case.
221
- """
222
- if not self.use_loaded_answers:
223
- print(f"Current conclusions: {current_conclusions}")
224
- return self.ask_for_affirmation(case_query,
225
- f"Is the conclusion {conclusion} correct for the case (True/False):")
226
-
227
- def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
228
- """
229
- Ask the expert a yes or no question.
230
-
231
- :param case_query: The case query about which the expert should answer.
232
- :param question: The question to ask the expert.
233
- :return: The answer to the question.
234
- """
235
- while True:
236
- if self.use_loaded_answers:
237
- answer = self.all_expert_answers.pop(0)
238
- else:
239
- _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
- answer = expression(case_query.case)
241
- if answer:
242
- self.all_expert_answers.append(True)
243
- return True
244
- else:
245
- self.all_expert_answers.append(False)
246
- return False
ripple_down_rules/rdr.py CHANGED
@@ -22,7 +22,7 @@ from .helpers import is_matching
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, is_conflicting
25
+ get_case_attribute_type, is_conflicting, update_case
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -49,32 +49,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
49
49
  self.start_rule = start_rule
50
50
  self.fig: Optional[plt.Figure] = None
51
51
 
52
- def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
53
- return self.classify(case)
54
-
55
- @abstractmethod
56
- def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
57
- """
58
- Classify a case.
59
-
60
- :param case: The case to classify.
61
- :return: The category that the case belongs to.
62
- """
63
- pass
64
-
65
- @abstractmethod
66
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
67
- -> Union[CaseAttribute, CallableExpression]:
68
- """
69
- Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
70
- comparing the case with the target category.
71
-
72
- :param case_query: The query containing the case to classify and the target category to compare the case with.
73
- :param expert: The expert to ask for differentiating features as new rule conditions.
74
- :return: The category that the case belongs to.
75
- """
76
- pass
77
-
78
52
  def fit(self, case_queries: List[CaseQuery],
79
53
  expert: Optional[Expert] = None,
80
54
  n_iter: int = None,
@@ -124,6 +98,63 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
98
  plt.ioff()
125
99
  plt.show()
126
100
 
101
+ def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
102
+ return self.classify(case)
103
+
104
+ @abstractmethod
105
+ def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
106
+ """
107
+ Classify a case.
108
+
109
+ :param case: The case to classify.
110
+ :return: The category that the case belongs to.
111
+ """
112
+ pass
113
+
114
+ def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
115
+ -> Union[CaseAttribute, CallableExpression]:
116
+ """
117
+ Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
118
+ incorrect by comparing the case with the target category.
119
+
120
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
121
+ :param expert: The expert to ask for differentiating features as new rule conditions.
122
+ :return: The category that the case belongs to.
123
+ """
124
+ if case_query is None:
125
+ raise ValueError("The case query cannot be None.")
126
+ if case_query.target is None:
127
+ expert.ask_for_conclusion(case_query)
128
+ if case_query.target is None:
129
+ return self.classify(case_query.case)
130
+
131
+ self.update_start_rule(case_query, expert)
132
+
133
+ return self._fit_case(case_query, expert=expert, **kwargs)
134
+
135
+ @abstractmethod
136
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
137
+ -> Union[CaseAttribute, CallableExpression]:
138
+ """
139
+ Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
140
+ comparing the case with the target category.
141
+
142
+ :param case_query: The query containing the case to classify and the target category to compare the case with.
143
+ :param expert: The expert to ask for differentiating features as new rule conditions.
144
+ :return: The category that the case belongs to.
145
+ """
146
+ pass
147
+
148
+ @abstractmethod
149
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
150
+ """
151
+ Update the starting rule of the classifier.
152
+
153
+ :param case_query: The case query to update the starting rule with.
154
+ :param expert: The expert to ask for differentiating features as new rule conditions.
155
+ """
156
+ pass
157
+
127
158
  def update_figures(self):
128
159
  """
129
160
  Update the figures of the classifier.
@@ -295,13 +326,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
295
326
  """
296
327
  :return: The type of the conclusion of the RDR classifier.
297
328
  """
298
- if isinstance(self.start_rule.conclusion, CallableExpression):
299
- return self.start_rule.conclusion.conclusion_type
300
- else:
301
- conclusion = self.start_rule.conclusion
302
- if isinstance(conclusion, set):
303
- return type(list(conclusion)[0]), set
304
- return (type(conclusion),)
329
+ all_types = []
330
+ for rule in [self.start_rule] + list(self.start_rule.descendants):
331
+ all_types.extend(list(rule.conclusion.conclusion_type))
332
+ return tuple(set(all_types))
305
333
 
306
334
  @property
307
335
  def attribute_name(self) -> str:
@@ -321,7 +349,7 @@ class SingleClassRDR(RDRWithCodeWriter):
321
349
  super(SingleClassRDR, self).__init__(start_rule)
322
350
  self.default_conclusion: Optional[Any] = default_conclusion
323
351
 
324
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
352
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
325
353
  -> Union[CaseAttribute, CallableExpression, None]:
326
354
  """
327
355
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
@@ -331,38 +359,44 @@ class SingleClassRDR(RDRWithCodeWriter):
331
359
  :param expert: The expert to ask for differentiating features as new rule conditions.
332
360
  :return: The category that the case belongs to.
333
361
  """
334
- expert = expert if expert else Human()
335
362
  if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
336
363
  self.default_conclusion = case_query.default_value
337
- case = case_query.case
338
- target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
339
- if target is None:
340
- return self.classify(case)
341
- if not self.start_rule:
342
- conditions = expert.ask_for_conditions(case_query)
343
- self.start_rule = SingleClassRule(conditions, target, corner_case=case,
344
- conclusion_name=case_query.attribute_name)
345
364
 
346
365
  pred = self.evaluate(case_query.case)
347
- if pred.conclusion(case) != target(case):
348
- conditions = expert.ask_for_conditions(case_query, pred)
349
- pred.fit_rule(case_query.case, target, conditions=conditions)
366
+ if pred.conclusion(case_query.case) != case_query.target_value:
367
+ expert.ask_for_conditions(case_query, pred)
368
+ pred.fit_rule(case_query.case, case_query.target, conditions=case_query.conditions)
350
369
 
351
370
  return self.classify(case_query.case)
352
371
 
353
- def classify(self, case: Case) -> Optional[Any]:
372
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
373
+ """
374
+ Update the starting rule of the classifier.
375
+
376
+ :param case_query: The case query to update the starting rule with.
377
+ :param expert: The expert to ask for differentiating features as new rule conditions.
378
+ """
379
+ if not self.start_rule:
380
+ expert.ask_for_conditions(case_query)
381
+ self.start_rule = SingleClassRule(case_query.conditions, case_query.target, corner_case=case_query.case,
382
+ conclusion_name=case_query.attribute_name)
383
+
384
+ def classify(self, case: Case, modify_original_case: bool = False) -> Optional[Any]:
354
385
  """
355
386
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
387
+
388
+ :param case: The case to classify.
389
+ :param modify_original_case: Whether to modify the original case attributes with the conclusion or not.
356
390
  """
357
391
  pred = self.evaluate(case)
358
- return pred.conclusion(case) if pred.fired else self.default_conclusion
392
+ return pred.conclusion(case) if pred is not None and pred.fired else self.default_conclusion
359
393
 
360
394
  def evaluate(self, case: Case) -> SingleClassRule:
361
395
  """
362
396
  Evaluate the starting rule on a case.
363
397
  """
364
- matched_rule = self.start_rule(case)
365
- return matched_rule if matched_rule else self.start_rule
398
+ matched_rule = self.start_rule(case) if self.start_rule is not None else None
399
+ return matched_rule if matched_rule is not None else self.start_rule
366
400
 
367
401
  def write_to_python_file(self, file_path: str, postfix: str = ""):
368
402
  super().write_to_python_file(file_path, postfix)
@@ -382,7 +416,8 @@ class SingleClassRDR(RDRWithCodeWriter):
382
416
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
383
417
  defs_file=defs_file)
384
418
 
385
- file.write(rule.write_conclusion_as_source_code(parent_indent))
419
+ conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
420
+ file.write(conclusion_call)
386
421
 
387
422
  if rule.alternative:
388
423
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
@@ -391,6 +426,12 @@ class SingleClassRDR(RDRWithCodeWriter):
391
426
  def conclusion_type_hint(self) -> str:
392
427
  return self.conclusion_type[0].__name__
393
428
 
429
+ @property
430
+ def conclusion_type(self) -> Tuple[Type]:
431
+ if self.default_conclusion is not None:
432
+ return (type(self.default_conclusion),)
433
+ return super().conclusion_type
434
+
394
435
  def _to_json(self) -> Dict[str, Any]:
395
436
  return {"start_rule": self.start_rule.to_json()}
396
437
 
@@ -422,13 +463,12 @@ class MultiClassRDR(RDRWithCodeWriter):
422
463
  The conditions of the stopping rule if needed.
423
464
  """
424
465
 
425
- def __init__(self, start_rule: Optional[Rule] = None,
466
+ def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
426
467
  mode: MCRDRMode = MCRDRMode.StopOnly):
427
468
  """
428
469
  :param start_rule: The starting rules for the classifier.
429
470
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
430
471
  """
431
- start_rule = MultiClassTopRule() if not start_rule else start_rule
432
472
  super(MultiClassRDR, self).__init__(start_rule)
433
473
  self.mode: MCRDRMode = mode
434
474
 
@@ -442,40 +482,28 @@ class MultiClassRDR(RDRWithCodeWriter):
442
482
  evaluated_rule = next_rule
443
483
  return make_set(self.conclusions)
444
484
 
445
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
446
- add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
485
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None
486
+ , **kwargs) -> Set[Union[CaseAttribute, CallableExpression, None]]:
447
487
  """
448
488
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
449
489
  or missing by comparing the case with the target category if provided.
450
490
 
451
491
  :param case_query: The query containing the case to classify and the target category to compare the case with.
452
492
  :param expert: The expert to ask for differentiating features as new rule conditions or for extra conclusions.
453
- :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
454
493
  :return: The conclusions that the case belongs to.
455
494
  """
456
- expert = expert if expert else Human()
457
- if case_query.target is None:
458
- expert.ask_for_conclusion(case_query)
459
- if case_query.target is None:
460
- return self.classify(case_query.case)
461
- self.update_start_rule(case_query, expert)
462
- self.expert_accepted_conclusions = []
463
- user_conclusions = []
464
495
  self.conclusions = []
465
496
  self.stop_rule_conditions = None
466
497
  evaluated_rule = self.start_rule
467
- target = case_query.target(case_query.case)
498
+ target = make_set(case_query.target_value)
468
499
  while evaluated_rule:
469
500
  next_rule = evaluated_rule(case_query.case)
470
501
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
471
- good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
472
- good_conclusions = make_set(good_conclusions)
473
502
 
474
503
  if evaluated_rule.fired:
475
- if target and not make_set(rule_conclusion).issubset(good_conclusions):
504
+ if not make_set(rule_conclusion).issubset(target):
476
505
  # Rule fired and conclusion is different from target
477
- self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
478
- len(user_conclusions) > 0)
506
+ self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
479
507
  else:
480
508
  # Rule fired and target is correct or there is no target to compare
481
509
  self.add_conclusion(evaluated_rule, case_query.case)
@@ -486,14 +514,6 @@ class MultiClassRDR(RDRWithCodeWriter):
486
514
  self.add_rule_for_case(case_query, expert)
487
515
  # Have to check all rules again to make sure only this new rule fires
488
516
  next_rule = self.start_rule
489
- elif add_extra_conclusions:
490
- # No more conclusions can be made, ask the expert for extra conclusions if needed.
491
- new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
492
- user_conclusions.extend(new_user_conclusions)
493
- if len(new_user_conclusions) > 0:
494
- next_rule = self.last_top_rule
495
- else:
496
- add_extra_conclusions = False
497
517
  evaluated_rule = next_rule
498
518
  return self.conclusions
499
519
 
@@ -510,18 +530,20 @@ class MultiClassRDR(RDRWithCodeWriter):
510
530
  defs_file=defs_file)
511
531
  conclusion_indent = parent_indent + " " * 4
512
532
  file.write(f"{conclusion_indent}else:\n")
513
- file.write(rule.write_conclusion_as_source_code(conclusion_indent))
533
+
534
+ conclusion_call = rule.write_conclusion_as_source_code(conclusion_indent, defs_file)
535
+ file.write(conclusion_call)
514
536
 
515
537
  if rule.alternative:
516
538
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
517
539
 
518
540
  @property
519
541
  def conclusion_type_hint(self) -> str:
520
- return f"Set[{self.conclusion_type[0].__name__}]"
542
+ return f"Set[Union[{', '.join([ct.__name__ for ct in self.conclusion_type if ct not in [list, set]])}]]"
521
543
 
522
544
  def _get_imports(self) -> str:
523
545
  imports = super()._get_imports()
524
- imports += "from typing_extensions import Set\n"
546
+ imports += "from typing_extensions import Set, Union\n"
525
547
  imports += "from ripple_down_rules.utils import make_set\n"
526
548
  return imports
527
549
 
@@ -532,12 +554,10 @@ class MultiClassRDR(RDRWithCodeWriter):
532
554
  :param case_query: The case query to update the starting rule with.
533
555
  :param expert: The expert to ask for differentiating features as new rule conditions.
534
556
  """
535
- if not self.start_rule.conditions:
557
+ if not self.start_rule:
536
558
  conditions = expert.ask_for_conditions(case_query)
537
- self.start_rule.conditions = conditions
538
- self.start_rule.conclusion = case_query.target
539
- self.start_rule.corner_case = case_query.case
540
- self.start_rule.conclusion_name = case_query.attribute_name
559
+ self.start_rule = MultiClassTopRule(conditions, case_query.target, corner_case=case_query.case,
560
+ conclusion_name=case_query.attribute_name)
541
561
 
542
562
  @property
543
563
  def last_top_rule(self) -> Optional[MultiClassTopRule]:
@@ -550,17 +570,13 @@ class MultiClassRDR(RDRWithCodeWriter):
550
570
  return self.start_rule.furthest_alternative[-1]
551
571
 
552
572
  def stop_wrong_conclusion_else_add_it(self, case_query: CaseQuery, expert: Expert,
553
- evaluated_rule: MultiClassTopRule,
554
- add_extra_conclusions: bool):
573
+ evaluated_rule: MultiClassTopRule):
555
574
  """
556
575
  Stop a wrong conclusion by adding a stopping rule.
557
576
  """
558
577
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
559
578
  if is_conflicting(rule_conclusion, case_query.target_value):
560
- if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
561
- return
562
- else:
563
- self.stop_conclusion(case_query, expert, evaluated_rule)
579
+ self.stop_conclusion(case_query, expert, evaluated_rule)
564
580
 
565
581
  def stop_conclusion(self, case_query: CaseQuery,
566
582
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -579,27 +595,6 @@ class MultiClassRDR(RDRWithCodeWriter):
579
595
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
580
596
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
581
597
 
582
- def conclusion_is_correct(self, case_query: CaseQuery,
583
- expert: Expert, evaluated_rule: Rule,
584
- add_extra_conclusions: bool) -> bool:
585
- """
586
- Ask the expert if the conclusion is correct, and add it to the conclusions if it is.
587
-
588
- :param case_query: The case query to ask the expert about.
589
- :param expert: The expert to ask for differentiating features as new rule conditions.
590
- :param evaluated_rule: The evaluated rule to ask the expert about.
591
- :param add_extra_conclusions: Whether adding extra conclusions after classification is allowed.
592
- :return: Whether the conclusion is correct or not.
593
- """
594
- conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
595
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
596
- evaluated_rule.conclusion(case_query.case),
597
- current_conclusions=conclusions)):
598
- self.add_conclusion(evaluated_rule, case_query.case)
599
- self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
600
- return True
601
- return False
602
-
603
598
  def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
604
599
  """
605
600
  Add a rule for a case that has not been classified with any conclusion.
@@ -614,24 +609,6 @@ class MultiClassRDR(RDRWithCodeWriter):
614
609
  conditions = expert.ask_for_conditions(case_query)
615
610
  self.add_top_rule(conditions, case_query.target, case_query.case)
616
611
 
617
- def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
618
- """
619
- Ask the expert for extra rules when no more conclusions can be made for a case.
620
-
621
- :param expert: The expert to ask for extra conclusions.
622
- :param case_query: The case query to ask the expert about.
623
- :return: The extra conclusions for the rules that the expert has provided.
624
- """
625
- extra_conclusions = []
626
- conclusions = list(OrderedSet(self.conclusions))
627
- if not expert.use_loaded_answers:
628
- print("current conclusions:", conclusions)
629
- extra_rules = expert.ask_for_extra_rules(case_query)
630
- for rule in extra_rules:
631
- self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
632
- extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
633
- return extra_conclusions
634
-
635
612
  def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
636
613
  """
637
614
  Add the conclusion of the evaluated rule to the list of conclusions.
@@ -725,30 +702,32 @@ class GeneralRDR(RippleDownRules):
725
702
  def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
726
703
  return [rdr.start_rule for rdr in self.start_rules_dict.values()]
727
704
 
728
- def classify(self, case: Any) -> Optional[Dict[str, Any]]:
705
+ def classify(self, case: Any, modify_case: bool = False) -> Optional[Dict[str, Any]]:
729
706
  """
730
707
  Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
731
708
  the classification until no more categories can be added.
732
709
 
733
710
  :param case: The case to classify.
711
+ :param modify_case: Whether to modify the original case or create a copy and modify it.
734
712
  :return: The categories that the case belongs to.
735
713
  """
736
- return self._classify(self.start_rules_dict, case)
714
+ return self._classify(self.start_rules_dict, case, modify_original_case=modify_case)
737
715
 
738
716
  @staticmethod
739
717
  def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
740
- case: Any) -> Dict[str, Any]:
718
+ case: Any, modify_original_case: bool = True) -> Dict[str, Any]:
741
719
  """
742
720
  Classify a case by going through all classifiers and adding the categories that are classified,
743
721
  and then restarting the classification until no more categories can be added.
744
722
 
745
723
  :param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
746
724
  :param case: The case to classify.
725
+ :param modify_original_case: Whether to modify the original case or create a copy and modify it.
747
726
  :return: The categories that the case belongs to.
748
727
  """
749
728
  conclusions = {}
750
729
  case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
751
- case_cp = copy_case(case)
730
+ case_cp = copy_case(case) if not modify_original_case else case
752
731
  while True:
753
732
  new_conclusions = {}
754
733
  for attribute_name, rdr in classifiers_dict.items():
@@ -771,13 +750,13 @@ class GeneralRDR(RippleDownRules):
771
750
  conclusions[attribute_name].update(pred_atts)
772
751
  if attribute_name in new_conclusions:
773
752
  mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
774
- GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
775
- new_conclusions)
753
+ case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
754
+ update_case(case_query, new_conclusions)
776
755
  if len(new_conclusions) == 0:
777
756
  break
778
757
  return conclusions
779
758
 
780
- def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
759
+ def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
781
760
  -> Dict[str, Any]:
782
761
  """
783
762
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
@@ -787,115 +766,38 @@ class GeneralRDR(RippleDownRules):
787
766
  they are accepted by the expert, and the attribute of that category is represented in the case as a set of
788
767
  values.
789
768
 
790
- :param case_queries: The queries containing the case to classify and the target categories to compare the case
769
+ :param case_query: The query containing the case to classify and the target category to compare the case
791
770
  with.
792
771
  :param expert: The expert to ask for differentiating features as new rule conditions.
793
772
  :return: The categories that the case belongs to.
794
773
  """
795
- expert = expert if expert else Human()
796
- case_queries = make_list(case_queries)
797
- assert len(case_queries) > 0, "No case queries provided"
798
- case = case_queries[0].case
799
- assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
800
- " for multiple cases use fit instead")
801
- original_case_query_cp = copy(case_queries[0])
802
- for case_query in case_queries:
803
- case_query_cp = copy(case_query)
804
- case_query_cp.case = original_case_query_cp.case
805
- if case_query_cp.target is None:
806
- conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
807
- self.update_case(case_query_cp, conclusions)
808
- expert.ask_for_conclusion(case_query_cp)
809
- if case_query_cp.target is None:
810
- continue
811
- case_query.target = case_query_cp.target
812
-
813
- if case_query.attribute_name not in self.start_rules_dict:
814
- conclusions = self.classify(case)
815
- self.update_case(case_query_cp, conclusions)
816
-
817
- new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
818
- self.add_rdr(new_rdr, case_query.attribute_name)
819
-
820
- new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
821
- self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
822
- else:
823
- for rdr_attribute_name, rdr in self.start_rules_dict.items():
824
- if case_query.attribute_name != rdr_attribute_name:
825
- conclusions = rdr.classify(case_query_cp.case)
826
- else:
827
- conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
828
- **kwargs)
829
- if conclusions is not None:
830
- if (not is_iterable(conclusions)) or len(conclusions) > 0:
831
- conclusions = {rdr_attribute_name: conclusions}
832
- case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
833
- self.update_case(case_query_cp, conclusions)
834
- case_query.conditions = case_query_cp.conditions
835
774
 
836
- return self.classify(case)
775
+ case_query_cp = copy(case_query)
776
+ self.classify(case_query_cp.case, modify_case=True)
777
+ case_query_cp.update_target_value()
778
+
779
+ self.start_rules_dict[case_query_cp.attribute_name].fit_case(case_query_cp, expert, **kwargs)
780
+
781
+ return self.classify(case_query.case)
782
+
783
+ def update_start_rule(self, case_query: CaseQuery, expert: Expert):
784
+ """
785
+ Update the starting rule of the classifier.
786
+
787
+ :param case_query: The case query to update the starting rule with.
788
+ :param expert: The expert to ask for differentiating features as new rule conditions.
789
+ """
790
+ if case_query.attribute_name not in self.start_rules_dict:
791
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query)
792
+ self.add_rdr(new_rdr, case_query.attribute_name)
837
793
 
838
794
  @staticmethod
839
795
  def initialize_new_rdr_for_attribute(case_query: CaseQuery):
840
796
  """
841
797
  Initialize the appropriate RDR type for the target.
842
798
  """
843
- if case_query.mutually_exclusive is not None:
844
- return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
845
- else MultiClassRDR()
846
- if case_query.attribute_type in [list, set]:
847
- return MultiClassRDR()
848
- attribute = getattr(case_query.case, case_query.attribute_name) \
849
- if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
850
- if isinstance(attribute, CaseAttribute):
851
- return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
852
- else MultiClassRDR()
853
- else:
854
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
855
- else SingleClassRDR(default_conclusion=case_query.default_value)
856
-
857
- @staticmethod
858
- def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
859
- """
860
- Update the case with the conclusions.
861
-
862
- :param case_query: The case query that contains the case to update.
863
- :param conclusions: The conclusions to update the case with.
864
- """
865
- if not conclusions:
866
- return
867
- if len(conclusions) == 0:
868
- return
869
- if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
870
- for conclusion_name, conclusion in conclusions.items():
871
- attribute = getattr(case_query.case, conclusion_name)
872
- if conclusion_name == case_query.attribute_name:
873
- attribute_type = case_query.attribute_type
874
- else:
875
- attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
876
- if isinstance(attribute, set):
877
- for c in conclusion:
878
- attribute.update(make_set(c))
879
- elif isinstance(attribute, list):
880
- attribute.extend(conclusion)
881
- elif any(at in {List, list} for at in attribute_type):
882
- attribute = [] if attribute is None else attribute
883
- attribute.extend(conclusion)
884
- elif any(at in {Set, set} for at in attribute_type):
885
- attribute = set() if attribute is None else attribute
886
- for c in conclusion:
887
- attribute.update(make_set(c))
888
- elif is_iterable(conclusion) and len(conclusion) == 1 \
889
- and any(at is type(list(conclusion)[0]) for at in attribute_type):
890
- setattr(case_query.case, conclusion_name, list(conclusion)[0])
891
- elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
892
- setattr(case_query.case, conclusion_name, conclusion)
893
- else:
894
- raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
895
- f"{case_query.attribute_type} with conclusion "
896
- f"{conclusion} of type {type(conclusion)}")
897
- else:
898
- case_query.case.update(conclusions)
799
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
800
+ else MultiClassRDR()
899
801
 
900
802
  def _to_json(self) -> Dict[str, Any]:
901
803
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -969,7 +871,7 @@ class GeneralRDR(RippleDownRules):
969
871
 
970
872
  @property
971
873
  def conclusion_type_hint(self) -> str:
972
- return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
874
+ return f"Dict[str, Any]"
973
875
 
974
876
  def _get_imports(self, file_path: str) -> str:
975
877
  """
@@ -980,7 +882,7 @@ class GeneralRDR(RippleDownRules):
980
882
  """
981
883
  imports = ""
982
884
  # add type hints
983
- imports += f"from typing_extensions import List, Union, Set\n"
885
+ imports += f"from typing_extensions import Dict, Any, Union, Set\n"
984
886
  # import rdr type
985
887
  imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
986
888
  # add case type
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
5
5
  from enum import Enum
6
6
 
7
7
  from anytree import NodeMixin
8
- from typing_extensions import List, Optional, Self, Union, Dict, Any
8
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
9
9
 
10
10
  from .datastructures.callable_expression import CallableExpression
11
11
  from .datastructures.case import Case
@@ -78,24 +78,26 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
78
78
  """
79
79
  pass
80
80
 
81
- def write_conclusion_as_source_code(self, parent_indent: str = "") -> str:
81
+ def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
82
82
  """
83
83
  Get the source code representation of the conclusion of the rule.
84
84
 
85
85
  :param parent_indent: The indentation of the parent rule.
86
+ :param defs_file: The file to write the conclusion to if it is a definition.
87
+ :return: The source code representation of the conclusion of the rule.
86
88
  """
87
- conclusion = self.conclusion
88
- if isinstance(conclusion, CallableExpression):
89
- if self.conclusion.user_input is not None:
90
- conclusion = self.conclusion.user_input
91
- else:
92
- conclusion = self.conclusion.conclusion
93
- if isinstance(conclusion, Enum):
94
- conclusion = str(conclusion)
95
- return self._conclusion_source_code(conclusion, parent_indent=parent_indent)
89
+ if self.conclusion.user_input is not None:
90
+ conclusion = self.conclusion.user_input
91
+ else:
92
+ conclusion = self.conclusion.conclusion
93
+ conclusion_func, conclusion_func_call = self._conclusion_source_code(conclusion, parent_indent=parent_indent)
94
+ if conclusion_func is not None:
95
+ with open(defs_file, 'a') as f:
96
+ f.write(conclusion_func + "\n\n")
97
+ return conclusion_func_call
96
98
 
97
99
  @abstractmethod
98
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
100
+ def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
99
101
  pass
100
102
 
101
103
  def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
@@ -118,7 +120,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
118
120
  conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
119
121
  def_code = "\n".join(conditions_lines)
120
122
  with open(defs_file, 'a') as f:
121
- f.write(def_code + "\n")
123
+ f.write(def_code + "\n\n")
122
124
  return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
123
125
 
124
126
  @abstractmethod
@@ -130,6 +132,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
130
132
  "conclusion": conclusion_to_json(self.conclusion),
131
133
  "parent": self.parent.json_serialization if self.parent else None,
132
134
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
135
+ "conclusion_name": self.conclusion_name,
133
136
  "weight": self.weight}
134
137
  return json_serialization
135
138
 
@@ -139,6 +142,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
139
142
  conclusion=CallableExpression.from_json(data["conclusion"]),
140
143
  parent=cls.from_json(data["parent"]),
141
144
  corner_case=Case.from_json(data["corner_case"]),
145
+ conclusion_name=data["conclusion_name"],
142
146
  weight=data["weight"])
143
147
  return loaded_rule
144
148
 
@@ -264,11 +268,11 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
264
268
  loaded_rule.alternative = SingleClassRule.from_json(data["alternative"])
265
269
  return loaded_rule
266
270
 
267
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
271
+ def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
268
272
  conclusion = str(conclusion)
269
273
  indent = parent_indent + " " * 4
270
274
  if '\n' not in conclusion:
271
- return f"{indent}return {conclusion}\n"
275
+ return None, f"{indent}return {conclusion}\n"
272
276
  else:
273
277
  return get_rule_conclusion_as_source_code(self, conclusion, parent_indent=parent_indent)
274
278
 
@@ -315,8 +319,8 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
315
319
  loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
316
320
  return loaded_rule
317
321
 
318
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
319
- return f"{parent_indent}{' ' * 4}pass\n"
322
+ def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
323
+ return None, f"{parent_indent}{' ' * 4}pass\n"
320
324
 
321
325
  def _if_statement_source_code_clause(self) -> str:
322
326
  return "elif" if self.weight == RDREdge.Alternative.value else "if"
@@ -360,25 +364,23 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
360
364
  loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
361
365
  return loaded_rule
362
366
 
363
- def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> str:
367
+ def _conclusion_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[str, str]:
364
368
  conclusion_str = str(conclusion)
365
369
  indent = parent_indent + " " * 4
366
- statement = ""
367
370
  if '\n' not in conclusion_str:
371
+ func = None
368
372
  if is_iterable(conclusion):
369
373
  conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
370
374
  else:
371
375
  conclusion_str = "{" + str(conclusion) + "}"
372
376
  else:
373
- conclusion_str = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
374
- lines = conclusion_str.split("\n")
375
- conclusion_str = lines[-2].replace("return ", "").strip()
376
- statement += "\n".join(lines[:-2]) + "\n"
377
+ func, func_call = get_rule_conclusion_as_source_code(self, conclusion_str, parent_indent=parent_indent)
378
+ conclusion_str = func_call.replace("return ", "").strip()
377
379
 
378
- statement += f"{indent}conclusions.update(make_set({conclusion_str}))\n"
380
+ statement = f"{indent}conclusions.update(make_set({conclusion_str}))\n"
379
381
  if self.alternative is None:
380
382
  statement += f"{parent_indent}return conclusions\n"
381
- return statement
383
+ return func, statement
382
384
 
383
385
  def _if_statement_source_code_clause(self) -> str:
384
386
  return "if"
@@ -26,6 +26,7 @@ from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from .datastructures.case import Case
29
+ from .datastructures.dataclasses import CaseQuery
29
30
  from .rules import Rule
30
31
 
31
32
  import ast
@@ -33,6 +34,49 @@ import ast
33
34
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
35
 
35
36
 
37
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
38
+ """
39
+ Update the case with the conclusions.
40
+
41
+ :param case_query: The case query that contains the case to update.
42
+ :param conclusions: The conclusions to update the case with.
43
+ """
44
+ if not conclusions:
45
+ return
46
+ if len(conclusions) == 0:
47
+ return
48
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
49
+ for conclusion_name, conclusion in conclusions.items():
50
+ attribute = getattr(case_query.case, conclusion_name)
51
+ if conclusion_name == case_query.attribute_name:
52
+ attribute_type = case_query.attribute_type
53
+ else:
54
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
55
+ if isinstance(attribute, set):
56
+ for c in conclusion:
57
+ attribute.update(make_set(c))
58
+ elif isinstance(attribute, list):
59
+ attribute.extend(conclusion)
60
+ elif any(at in {List, list} for at in attribute_type):
61
+ attribute = [] if attribute is None else attribute
62
+ attribute.extend(conclusion)
63
+ elif any(at in {Set, set} for at in attribute_type):
64
+ attribute = set() if attribute is None else attribute
65
+ for c in conclusion:
66
+ attribute.update(make_set(c))
67
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
68
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
69
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
70
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
71
+ setattr(case_query.case, conclusion_name, conclusion)
72
+ else:
73
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
74
+ f"{case_query.attribute_type} with conclusion "
75
+ f"{conclusion} of type {type(conclusion)}")
76
+ else:
77
+ case_query.case.update(conclusions)
78
+
79
+
36
80
  def is_conflicting(conclusion: Any, target: Any) -> bool:
37
81
  """
38
82
  :param conclusion: The conclusion to check.
@@ -75,14 +119,14 @@ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, A
75
119
  return precision, recall
76
120
 
77
121
 
78
- def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> str:
122
+ def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> Tuple[str, str]:
79
123
  """
80
124
  Convert the conclusion of a rule to source code.
81
125
 
82
126
  :param rule: The rule to get the conclusion from.
83
127
  :param conclusion: The conclusion to convert to source code.
84
128
  :param parent_indent: The indentation to use for the source code.
85
- :return: The source code of the conclusion.
129
+ :return: The source code of the conclusion as a tuple of strings, one for the function and one for the call.
86
130
  """
87
131
  indent = f"{parent_indent}{' ' * 4}"
88
132
  if "def " in conclusion:
@@ -91,9 +135,8 @@ def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_inden
91
135
  # use regex to replace the function name
92
136
  new_function_name = f"def conclusion_{id(rule)}"
93
137
  conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
94
- conclusion_lines = [f"{indent}{line}" for line in conclusion_lines]
95
- conclusion_lines.append(f"{indent}return {new_function_name.replace('def ', '')}(case)\n")
96
- return "\n".join(conclusion_lines)
138
+ func_call = f"{indent}return {new_function_name.replace('def ', '')}(case)\n"
139
+ return "\n".join(conclusion_lines).strip(' '), func_call
97
140
  else:
98
141
  raise ValueError(f"Conclusion is format is not valid, it should be a one line string or "
99
142
  f"contain a function definition. Instead got:\n{conclusion}\n")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.64
3
+ Version: 0.1.66
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,20 +1,20 @@
1
1
  ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
3
- ripple_down_rules/experts.py,sha256=TretkdR1IY2RjcSh4WJUFJtYEbItsand7SYFmzouE_Y,10348
3
+ ripple_down_rules/experts.py,sha256=JGVvSNiWhm4FpRpg76f98tl8Ii_C7x_aWD9FxD-JDLQ,6130
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
6
  ripple_down_rules/prompt.py,sha256=cHqhMJqubGhfGpOOY_uXv5L7PBNb64O0IBWSfiY0ui0,6682
7
- ripple_down_rules/rdr.py,sha256=tJDbPF3D2qJ5cw2WPmZrs_C2Jj6b7cCxEJDp2SK9GCI,47863
7
+ ripple_down_rules/rdr.py,sha256=iRoQcfXlCtAP9e_TxWY26ZZOv-Ki4MVRkmfPjlJ2vYY,41535
8
8
  ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
- ripple_down_rules/rules.py,sha256=4oQSb4p36A-YzW0hJ8HH2FhmIvokjXKK1rIXEn-8WcE,16203
10
- ripple_down_rules/utils.py,sha256=JIF99Knqzqjgny7unvEnib3sCmExqU-w9xYOSGIT86Q,32276
9
+ ripple_down_rules/rules.py,sha256=Ms9uDDM7QaFrzzPB-nsPOWbIrT-QVqvhmHA4z8HQOgI,16547
10
+ ripple_down_rules/utils.py,sha256=A4ArFvF4lTsjicfN2fV2oIIo2HSapA9LXb92sY6n-Rs,34538
11
11
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
12
  ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
13
  ripple_down_rules/datastructures/case.py,sha256=uM5YkJOfYfARrZZ3oAxKMWE5QBSnvHLLZa9Atoxb7eY,13800
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=_aabVXsgdVUeAmgGA9K_LZpO2U5a6-htrg2Tka7qc30,5960
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=V757VwxROlevXh5ZVFLVuzwY4JIJKG8ARlCfjhubfy8,5957
15
15
  ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
16
- ripple_down_rules-0.1.64.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.64.dist-info/METADATA,sha256=PChgWG-I8g-_yaKMDwrxJPuIHOe0fIrThGu_vOqrQkE,42576
18
- ripple_down_rules-0.1.64.dist-info/WHEEL,sha256=ooBFpIzZCPdw3uqIQsOo4qqbA4ZRPxHnOH7peeONza0,91
19
- ripple_down_rules-0.1.64.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.64.dist-info/RECORD,,
16
+ ripple_down_rules-0.1.66.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.66.dist-info/METADATA,sha256=PFdnyrcmNvbrVchvwNW0zl19IkCPhmsCAcpJ3VYO49Y,42576
18
+ ripple_down_rules-0.1.66.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
19
+ ripple_down_rules-0.1.66.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.66.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.1)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5