ripple-down-rules 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -1,20 +1,28 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
+ import re
5
+ import sys
4
6
  from abc import ABC, abstractmethod
5
7
  from copy import copy
8
+ from dataclasses import is_dataclass
9
+ from io import TextIOWrapper
6
10
  from types import ModuleType
7
11
 
8
12
  from matplotlib import pyplot as plt
9
13
  from ordered_set import OrderedSet
10
- from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
14
+ from sqlalchemy.orm import DeclarativeBase as SQLTable
11
15
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
12
16
 
13
- from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
17
+ from .datastructures.case import Case, CaseAttribute
18
+ from .datastructures.callable_expression import CallableExpression
19
+ from .datastructures.dataclasses import CaseQuery
20
+ from .datastructures.enums import MCRDRMode
14
21
  from .experts import Expert, Human
15
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
16
23
  from .utils import draw_tree, make_set, copy_case, \
17
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string
24
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
+ get_case_attribute_type, ask_llm
18
26
 
19
27
 
20
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -29,14 +37,16 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
29
37
  """
30
38
  The conclusions that the expert has accepted, such that they are not asked again.
31
39
  """
40
+ _generated_python_file_name: Optional[str] = None
41
+ """
42
+ The name of the generated python file.
43
+ """
32
44
 
33
- def __init__(self, start_rule: Optional[Rule] = None, session: Optional[Session] = None):
45
+ def __init__(self, start_rule: Optional[Rule] = None):
34
46
  """
35
47
  :param start_rule: The starting rule for the classifier.
36
- :param session: The sqlalchemy orm session.
37
48
  """
38
49
  self.start_rule = start_rule
39
- self.session = session
40
50
  self.fig: Optional[plt.Figure] = None
41
51
 
42
52
  def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
@@ -79,31 +89,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
79
89
  :param animate_tree: Whether to draw the tree while fitting the classifier.
80
90
  :param kwargs_for_fit_case: The keyword arguments to pass to the fit_case method.
81
91
  """
82
- cases = [case_query.case for case_query in case_queries]
83
- targets = [{case_query.attribute_name: case_query.target} for case_query in case_queries]
92
+ targets = []
84
93
  if animate_tree:
85
94
  plt.ion()
86
95
  i = 0
87
96
  stop_iterating = False
88
97
  num_rules: int = 0
89
98
  while not stop_iterating:
90
- all_pred = 0
91
- if not targets:
92
- targets = [None] * len(cases)
93
99
  for case_query in case_queries:
94
- target = {case_query.attribute_name: case_query.target}
95
100
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
96
- match = self.is_matching(pred_cat, target)
101
+ if case_query.target is None:
102
+ continue
103
+ target = {case_query.attribute_name: case_query.target(case_query.case)}
104
+ if len(targets) < len(case_queries):
105
+ targets.append(target)
106
+ match = self.is_matching(case_query, pred_cat)
97
107
  if not match:
98
108
  print(f"Predicted: {pred_cat} but expected: {target}")
99
- all_pred += int(match)
100
109
  if animate_tree and self.start_rule.size > num_rules:
101
110
  num_rules = self.start_rule.size
102
111
  self.update_figures()
103
112
  i += 1
104
- all_predictions = [1 if self.is_matching(self.classify(case_query.case), {case_query.attribute_name:
105
- case_query.target}) else 0
106
- for case_query in case_queries]
113
+ all_predictions = [1 if self.is_matching(case_query) else 0 for case_query in case_queries
114
+ if case_query.target is not None]
107
115
  all_pred = sum(all_predictions)
108
116
  print(f"Accuracy: {all_pred}/{len(targets)}")
109
117
  all_predicted = targets and all_pred == len(targets)
@@ -116,48 +124,43 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
116
124
  plt.ioff()
117
125
  plt.show()
118
126
 
127
+ def is_matching(self, case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
128
+ """
129
+ :param case_query: The case query to check.
130
+ :param pred_cat: The predicted category.
131
+ :return: Whether the classifier prediction is matching case_query target or not.
132
+ """
133
+ if case_query.target is None:
134
+ return False
135
+ if pred_cat is None:
136
+ pred_cat = self.classify(case_query.case)
137
+ if not isinstance(pred_cat, dict):
138
+ pred_cat = {case_query.attribute_name: pred_cat}
139
+ target = {case_query.attribute_name: case_query.target_value}
140
+ precision, recall = self.calculate_precision_and_recall(pred_cat, target)
141
+ return all(recall) and all(precision)
142
+
119
143
  @staticmethod
120
- def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[
144
+ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
121
145
  List[bool], List[bool]]:
122
146
  """
123
147
  :param pred_cat: The predicted category.
124
148
  :param target: The target category.
125
149
  :return: The precision and recall of the classifier.
126
150
  """
127
- pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
128
- target = target if is_iterable(target) else [target]
129
151
  recall = []
130
152
  precision = []
131
- if isinstance(pred_cat, dict):
132
- for pred_key, pred_value in pred_cat.items():
133
- if pred_key not in target:
134
- continue
135
- precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
136
- for target_key, target_value in target.items():
137
- if target_key not in pred_cat:
138
- recall.append(False)
139
- continue
140
- if is_iterable(target_value):
141
- recall.extend([v in pred_cat[target_key] for v in target_value])
142
- else:
143
- recall.append(target_value == pred_cat[target_key])
144
- else:
145
- if isinstance(target, dict):
146
- target = list(target.values())
147
- recall = [not yi or (yi in pred_cat) for yi in target]
148
- target_types = [type(yi) for yi in target]
149
- precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
153
+ for pred_key, pred_value in pred_cat.items():
154
+ if pred_key not in target:
155
+ continue
156
+ precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
157
+ for target_key, target_value in target.items():
158
+ if target_key not in pred_cat:
159
+ recall.append(False)
160
+ continue
161
+ recall.extend([v in make_set(pred_cat[target_key]) for v in make_set(target_value)])
150
162
  return precision, recall
151
163
 
152
- def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
153
- """
154
- :param pred_cat: The predicted category.
155
- :param target: The target category.
156
- :return: Whether the classifier is matching or not.
157
- """
158
- precision, recall = self.calculate_precision_and_recall(pred_cat, target)
159
- return all(recall) and all(precision)
160
-
161
164
  def update_figures(self):
162
165
  """
163
166
  Update the figures of the classifier.
@@ -183,33 +186,51 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
183
186
  """
184
187
  return hasattr(case, conclusion_name) and getattr(case, conclusion_name) is not None
185
188
 
189
+ @property
190
+ def type_(self):
191
+ return self.__class__
192
+
186
193
 
187
194
  class RDRWithCodeWriter(RippleDownRules, ABC):
188
195
 
189
196
  @abstractmethod
190
- def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = ""):
197
+ def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
198
+ defs_file: Optional[str] = None):
191
199
  """
192
200
  Write the rules as source code to a file.
193
201
 
194
202
  :param rule: The rule to write as source code.
195
203
  :param file: The file to write the source code to.
196
204
  :param parent_indent: The indentation of the parent rule.
205
+ :param defs_file: The file to write the definitions to.
197
206
  """
198
207
  pass
199
208
 
200
- def write_to_python_file(self, file_path: str):
209
+ def write_to_python_file(self, file_path: str, postfix: str = ""):
201
210
  """
202
211
  Write the tree of rules as source code to a file.
203
212
 
204
213
  :param file_path: The path to the file to write the source code to.
214
+ :param postfix: The postfix to add to the file name.
205
215
  """
216
+ self.generated_python_file_name = self._default_generated_python_file_name + postfix
206
217
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
207
- with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
208
- f.write(self._get_imports() + "\n\n")
218
+ file_name = file_path + f"/{self.generated_python_file_name}.py"
219
+ defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
220
+ imports = self._get_imports()
221
+ # clear the files first
222
+ with open(defs_file_name, "w") as f:
223
+ f.write(imports + "\n\n")
224
+ with open(file_name, "w") as f:
225
+ imports += f"from .{self.generated_python_defs_file_name} import *\n"
226
+ imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
227
+ f.write(imports + "\n\n")
228
+ f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n\n")
229
+ f.write(f"type_ = {self.__class__.__name__}\n\n")
209
230
  f.write(func_def)
210
231
  f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
211
232
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
212
- self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)
233
+ self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4, defs_file=defs_file_name)
213
234
 
214
235
  @property
215
236
  @abstractmethod
@@ -226,26 +247,73 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
226
247
  imports = ""
227
248
  if self.case_type.__module__ != "builtins":
228
249
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
229
- if self.conclusion_type.__module__ != "builtins":
230
- imports += f"from {self.conclusion_type.__module__} import {self.conclusion_type.__name__}\n"
231
- imports += "from ripple_down_rules.datastructures import Case, create_case\n"
250
+ for conclusion_type in self.conclusion_type:
251
+ if conclusion_type.__module__ != "builtins":
252
+ imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
253
+ imports += "from ripple_down_rules.datastructures.case import Case, create_case\n"
232
254
  for rule in [self.start_rule] + list(self.start_rule.descendants):
233
- if rule.conditions:
234
- if rule.conditions.scope is not None and len(rule.conditions.scope) > 0:
235
- for k, v in rule.conditions.scope.items():
236
- imports += f"from {v.__module__} import {v.__name__}\n"
255
+ if not rule.conditions:
256
+ continue
257
+ if rule.conditions.scope is None or len(rule.conditions.scope) == 0:
258
+ continue
259
+ for k, v in rule.conditions.scope.items():
260
+ new_imports = f"from {v.__module__} import {v.__name__}\n"
261
+ if new_imports in imports:
262
+ continue
263
+ imports += new_imports
237
264
  return imports
238
265
 
239
- def get_rdr_classifier_from_python_file(self, package_name) -> Callable[[Any], Any]:
266
+ def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
240
267
  """
241
268
  :param package_name: The name of the package that contains the RDR classifier function.
242
269
  :return: The module that contains the rdr classifier function.
243
270
  """
244
- return importlib.import_module(f"{package_name.strip('./')}.{self.generated_python_file_name}").classify
271
+ # remove from imports if exists first
272
+ name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
273
+ try:
274
+ module = importlib.import_module(name)
275
+ del sys.modules[name]
276
+ except ModuleNotFoundError:
277
+ pass
278
+ return importlib.import_module(name).classify
245
279
 
246
280
  @property
247
281
  def generated_python_file_name(self) -> str:
248
- return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_rdr"
282
+ if self._generated_python_file_name is None:
283
+ self._generated_python_file_name = self._default_generated_python_file_name
284
+ return self._generated_python_file_name
285
+
286
+ @generated_python_file_name.setter
287
+ def generated_python_file_name(self, value: str):
288
+ """
289
+ Set the generated python file name.
290
+ :param value: The new value for the generated python file name.
291
+ """
292
+ self._generated_python_file_name = value
293
+
294
+ @property
295
+ def _default_generated_python_file_name(self) -> str:
296
+ """
297
+ :return: The default generated python file name.
298
+ """
299
+ return f"{self.start_rule.corner_case._name.lower()}_{self.attribute_name}_{self.acronym.lower()}"
300
+
301
+ @property
302
+ def generated_python_defs_file_name(self) -> str:
303
+ return f"{self.generated_python_file_name}_defs"
304
+
305
+ @property
306
+ def acronym(self) -> str:
307
+ """
308
+ :return: The acronym of the classifier.
309
+ """
310
+ if self.__class__.__name__ == "GeneralRDR":
311
+ return "GRDR"
312
+ elif self.__class__.__name__ == "MultiClassRDR":
313
+ return "MCRDR"
314
+ else:
315
+ return "SCRDR"
316
+
249
317
 
250
318
  @property
251
319
  def case_type(self) -> Type:
@@ -258,16 +326,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
258
326
  return type(self.start_rule.corner_case)
259
327
 
260
328
  @property
261
- def conclusion_type(self) -> Type:
329
+ def conclusion_type(self) -> Tuple[Type]:
262
330
  """
263
331
  :return: The type of the conclusion of the RDR classifier.
264
332
  """
265
333
  if isinstance(self.start_rule.conclusion, CallableExpression):
266
334
  return self.start_rule.conclusion.conclusion_type
267
335
  else:
268
- if isinstance(self.start_rule.conclusion, set):
269
- return type(list(self.start_rule.conclusion)[0])
270
- return type(self.start_rule.conclusion)
336
+ conclusion = self.start_rule.conclusion
337
+ if isinstance(conclusion, set):
338
+ return type(list(conclusion)[0]), set
339
+ return (type(conclusion),)
271
340
 
272
341
  @property
273
342
  def attribute_name(self) -> str:
@@ -279,8 +348,16 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
279
348
 
280
349
  class SingleClassRDR(RDRWithCodeWriter):
281
350
 
351
+ def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
352
+ """
353
+ :param start_rule: The starting rule for the classifier.
354
+ :param default_conclusion: The default conclusion for the classifier if no rules fire.
355
+ """
356
+ super(SingleClassRDR, self).__init__(start_rule)
357
+ self.default_conclusion: Optional[Any] = default_conclusion
358
+
282
359
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
283
- -> Union[CaseAttribute, CallableExpression]:
360
+ -> Union[CaseAttribute, CallableExpression, None]:
284
361
  """
285
362
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
286
363
  comparing the case with the target category if provided.
@@ -289,28 +366,31 @@ class SingleClassRDR(RDRWithCodeWriter):
289
366
  :param expert: The expert to ask for differentiating features as new rule conditions.
290
367
  :return: The category that the case belongs to.
291
368
  """
292
- expert = expert if expert else Human(session=self.session)
293
- if case_query.target is None:
294
- target = expert.ask_for_conclusion(case_query)
369
+ expert = expert if expert else Human()
370
+ if case_query.default_value is not None and self.default_conclusion != case_query.default_value:
371
+ self.default_conclusion = case_query.default_value
372
+ case = case_query.case
373
+ target = expert.ask_for_conclusion(case_query) if case_query.target is None else case_query.target
374
+ if target is None:
375
+ return self.classify(case)
295
376
  if not self.start_rule:
296
377
  conditions = expert.ask_for_conditions(case_query)
297
- self.start_rule = SingleClassRule(conditions, case_query.target, corner_case=case_query.case,
378
+ self.start_rule = SingleClassRule(conditions, target, corner_case=case,
298
379
  conclusion_name=case_query.attribute_name)
299
380
 
300
381
  pred = self.evaluate(case_query.case)
301
-
302
- if pred.conclusion != case_query.target:
382
+ if pred.conclusion(case) != target(case):
303
383
  conditions = expert.ask_for_conditions(case_query, pred)
304
- pred.fit_rule(case_query.case, case_query.target, conditions=conditions)
384
+ pred.fit_rule(case_query.case, target, conditions=conditions)
305
385
 
306
386
  return self.classify(case_query.case)
307
387
 
308
- def classify(self, case: Case) -> Optional[CaseAttribute]:
388
+ def classify(self, case: Case) -> Optional[Any]:
309
389
  """
310
390
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
311
391
  """
312
392
  pred = self.evaluate(case)
313
- return pred.conclusion if pred.fired else None
393
+ return pred.conclusion(case) if pred.fired else self.default_conclusion
314
394
 
315
395
  def evaluate(self, case: Case) -> SingleClassRule:
316
396
  """
@@ -319,23 +399,27 @@ class SingleClassRDR(RDRWithCodeWriter):
319
399
  matched_rule = self.start_rule(case)
320
400
  return matched_rule if matched_rule else self.start_rule
321
401
 
322
- def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
402
+ def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
403
+ defs_file: Optional[str] = None):
323
404
  """
324
405
  Write the rules as source code to a file.
325
406
  """
326
407
  if rule.conditions:
327
- file.write(rule.write_condition_as_source_code(parent_indent))
408
+ if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
409
+ file.write(if_clause)
328
410
  if rule.refinement:
329
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
411
+
412
+ self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
413
+ defs_file=defs_file)
330
414
 
331
415
  file.write(rule.write_conclusion_as_source_code(parent_indent))
332
416
 
333
417
  if rule.alternative:
334
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
418
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
335
419
 
336
420
  @property
337
421
  def conclusion_type_hint(self) -> str:
338
- return self.conclusion_type.__name__
422
+ return self.conclusion_type[0].__name__
339
423
 
340
424
  def _to_json(self) -> Dict[str, Any]:
341
425
  return {"start_rule": self.start_rule.to_json()}
@@ -369,28 +453,27 @@ class MultiClassRDR(RDRWithCodeWriter):
369
453
  """
370
454
 
371
455
  def __init__(self, start_rule: Optional[Rule] = None,
372
- mode: MCRDRMode = MCRDRMode.StopOnly, session: Optional[Session] = None):
456
+ mode: MCRDRMode = MCRDRMode.StopOnly):
373
457
  """
374
458
  :param start_rule: The starting rules for the classifier.
375
459
  :param mode: The mode of the classifier, either StopOnly or StopPlusRule, or StopPlusRuleCombined.
376
- :param session: The sqlalchemy orm session.
377
460
  """
378
461
  start_rule = MultiClassTopRule() if not start_rule else start_rule
379
- super(MultiClassRDR, self).__init__(start_rule, session=session)
462
+ super(MultiClassRDR, self).__init__(start_rule)
380
463
  self.mode: MCRDRMode = mode
381
464
 
382
- def classify(self, case: Union[Case, SQLTable]) -> List[Any]:
465
+ def classify(self, case: Union[Case, SQLTable]) -> Set[Any]:
383
466
  evaluated_rule = self.start_rule
384
467
  self.conclusions = []
385
468
  while evaluated_rule:
386
469
  next_rule = evaluated_rule(case)
387
470
  if evaluated_rule.fired:
388
- self.add_conclusion(evaluated_rule)
471
+ self.add_conclusion(evaluated_rule, case)
389
472
  evaluated_rule = next_rule
390
- return self.conclusions
473
+ return make_set(self.conclusions)
391
474
 
392
475
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
393
- add_extra_conclusions: bool = False) -> List[Union[CaseAttribute, CallableExpression]]:
476
+ add_extra_conclusions: bool = False) -> Set[Union[CaseAttribute, CallableExpression, None]]:
394
477
  """
395
478
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
396
479
  or missing by comparing the case with the target category if provided.
@@ -400,32 +483,35 @@ class MultiClassRDR(RDRWithCodeWriter):
400
483
  :param add_extra_conclusions: Whether to add extra conclusions after classification is done.
401
484
  :return: The conclusions that the case belongs to.
402
485
  """
403
- expert = expert if expert else Human(session=self.session)
486
+ expert = expert if expert else Human()
487
+ if case_query.target is None:
488
+ expert.ask_for_conclusion(case_query)
404
489
  if case_query.target is None:
405
- targets = expert.ask_for_conclusion(case_query)
490
+ return self.classify(case_query.case)
491
+ self.update_start_rule(case_query, expert)
406
492
  self.expert_accepted_conclusions = []
407
493
  user_conclusions = []
408
- self.update_start_rule(case_query, expert)
409
494
  self.conclusions = []
410
495
  self.stop_rule_conditions = None
411
496
  evaluated_rule = self.start_rule
497
+ target = case_query.target(case_query.case)
412
498
  while evaluated_rule:
413
499
  next_rule = evaluated_rule(case_query.case)
414
- good_conclusions = make_list(case_query.target) + user_conclusions + self.expert_accepted_conclusions
500
+ rule_conclusion = evaluated_rule.conclusion(case_query.case)
501
+ good_conclusions = make_list(target) + user_conclusions + self.expert_accepted_conclusions
415
502
  good_conclusions = make_set(good_conclusions)
416
503
 
417
504
  if evaluated_rule.fired:
418
- if case_query.target and not make_set(evaluated_rule.conclusion).issubset(good_conclusions):
419
- # if self.case_has_conclusion(case, evaluated_rule.conclusion):
505
+ if target and not make_set(rule_conclusion).issubset(good_conclusions):
420
506
  # Rule fired and conclusion is different from target
421
507
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
422
508
  add_extra_conclusions)
423
509
  else:
424
510
  # Rule fired and target is correct or there is no target to compare
425
- self.add_conclusion(evaluated_rule)
511
+ self.add_conclusion(evaluated_rule, case_query.case)
426
512
 
427
513
  if not next_rule:
428
- if not make_set(case_query.target).intersection(make_set(self.conclusions)):
514
+ if not make_set(target).issubset(make_set(self.conclusions)):
429
515
  # Nothing fired and there is a target that should have been in the conclusions
430
516
  self.add_rule_for_case(case_query, expert)
431
517
  # Have to check all rules again to make sure only this new rule fires
@@ -439,33 +525,31 @@ class MultiClassRDR(RDRWithCodeWriter):
439
525
  return self.conclusions
440
526
 
441
527
  def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
442
- file, parent_indent: str = ""):
443
- """
444
- Write the rules as source code to a file.
445
-
446
- :
447
- """
528
+ file, parent_indent: str = "", defs_file: Optional[str] = None):
448
529
  if rule == self.start_rule:
449
530
  file.write(f"{parent_indent}conclusions = set()\n")
450
531
  if rule.conditions:
451
- file.write(rule.write_condition_as_source_code(parent_indent))
532
+ if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
533
+ file.write(if_clause)
452
534
  conclusion_indent = parent_indent
453
535
  if hasattr(rule, "refinement") and rule.refinement:
454
- self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ")
536
+ self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
537
+ defs_file=defs_file)
455
538
  conclusion_indent = parent_indent + " " * 4
456
539
  file.write(f"{conclusion_indent}else:\n")
457
540
  file.write(rule.write_conclusion_as_source_code(conclusion_indent))
458
541
 
459
542
  if rule.alternative:
460
- self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
543
+ self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent, defs_file=defs_file)
461
544
 
462
545
  @property
463
546
  def conclusion_type_hint(self) -> str:
464
- return f"Set[{self.conclusion_type.__name__}]"
547
+ return f"Set[{self.conclusion_type[0].__name__}]"
465
548
 
466
549
  def _get_imports(self) -> str:
467
550
  imports = super()._get_imports()
468
551
  imports += "from typing_extensions import Set\n"
552
+ imports += "from ripple_down_rules.utils import make_set\n"
469
553
  return imports
470
554
 
471
555
  def update_start_rule(self, case_query: CaseQuery, expert: Expert):
@@ -498,8 +582,10 @@ class MultiClassRDR(RDRWithCodeWriter):
498
582
  """
499
583
  Stop a wrong conclusion by adding a stopping rule.
500
584
  """
501
- if self.is_same_category_type(evaluated_rule.conclusion, case_query.target) \
502
- and self.is_conflicting_with_target(evaluated_rule.conclusion, case_query.target):
585
+ target = case_query.target(case_query.case)
586
+ rule_conclusion = evaluated_rule.conclusion(case_query.case)
587
+ if self.is_same_category_type(rule_conclusion, target) \
588
+ and self.is_conflicting_with_target(rule_conclusion, target):
503
589
  self.stop_conclusion(case_query, expert, evaluated_rule)
504
590
  elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
505
591
  self.stop_conclusion(case_query, expert, evaluated_rule)
@@ -559,10 +645,11 @@ class MultiClassRDR(RDRWithCodeWriter):
559
645
  :return: Whether the conclusion is correct or not.
560
646
  """
561
647
  conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
562
- if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case, evaluated_rule.conclusion,
563
- targets=case_query.target,
648
+ if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
649
+ evaluated_rule.conclusion(case_query.case),
650
+ targets=case_query.target(case_query.case),
564
651
  current_conclusions=conclusions)):
565
- self.add_conclusion(evaluated_rule)
652
+ self.add_conclusion(evaluated_rule, case_query.case)
566
653
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
567
654
  return True
568
655
  return False
@@ -600,19 +687,21 @@ class MultiClassRDR(RDRWithCodeWriter):
600
687
  extra_conclusions.append(conclusion)
601
688
  return extra_conclusions
602
689
 
603
- def add_conclusion(self, evaluated_rule: Rule) -> None:
690
+ def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
604
691
  """
605
692
  Add the conclusion of the evaluated rule to the list of conclusions.
606
693
 
607
694
  :param evaluated_rule: The evaluated rule to add the conclusion of.
695
+ :param case: The case to add the conclusion for.
608
696
  """
609
697
  conclusion_types = [type(c) for c in self.conclusions]
610
- if type(evaluated_rule.conclusion) not in conclusion_types:
611
- self.conclusions.extend(make_list(evaluated_rule.conclusion))
698
+ rule_conclusion = evaluated_rule.conclusion(case)
699
+ if type(rule_conclusion) not in conclusion_types:
700
+ self.conclusions.extend(make_list(rule_conclusion))
612
701
  else:
613
- same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
614
- combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
615
- else {evaluated_rule.conclusion}
702
+ same_type_conclusions = [c for c in self.conclusions if type(c) == type(rule_conclusion)]
703
+ combined_conclusion = rule_conclusion if isinstance(rule_conclusion, set) \
704
+ else {rule_conclusion}
616
705
  combined_conclusion = copy(combined_conclusion)
617
706
  for c in same_type_conclusions:
618
707
  combined_conclusion.update(c if isinstance(c, set) else make_set(c))
@@ -720,22 +809,24 @@ class GeneralRDR(RippleDownRules):
720
809
  pred_atts = rdr.classify(case_cp)
721
810
  if pred_atts is None:
722
811
  continue
723
- if isinstance(rdr, SingleClassRDR):
812
+ if rdr.type_ is SingleClassRDR:
724
813
  if attribute_name not in conclusions or \
725
814
  (attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
726
815
  conclusions[attribute_name] = pred_atts
727
816
  new_conclusions[attribute_name] = pred_atts
728
817
  else:
729
- pred_atts = make_list(pred_atts)
818
+ pred_atts = make_set(pred_atts)
730
819
  if attribute_name in conclusions:
731
- pred_atts = [p for p in pred_atts if p not in conclusions[attribute_name]]
820
+ pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
732
821
  if len(pred_atts) > 0:
733
822
  new_conclusions[attribute_name] = pred_atts
734
823
  if attribute_name not in conclusions:
735
- conclusions[attribute_name] = []
736
- conclusions[attribute_name].extend(pred_atts)
824
+ conclusions[attribute_name] = set()
825
+ conclusions[attribute_name].update(pred_atts)
737
826
  if attribute_name in new_conclusions:
738
- GeneralRDR.update_case(case_cp, new_conclusions)
827
+ mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
828
+ GeneralRDR.update_case(CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive),
829
+ new_conclusions)
739
830
  if len(new_conclusions) == 0:
740
831
  break
741
832
  return conclusions
@@ -761,76 +852,98 @@ class GeneralRDR(RippleDownRules):
761
852
  case = case_queries[0].case
762
853
  assert all([case is case_query.case for case_query in case_queries]), ("fit_case requires only one case,"
763
854
  " for multiple cases use fit instead")
764
- case_cp = copy(case_queries[0]).case
855
+ original_case_query_cp = copy(case_queries[0])
765
856
  for case_query in case_queries:
766
857
  case_query_cp = copy(case_query)
767
- case_query_cp.case = case_cp
768
- if case_query.target is None:
858
+ case_query_cp.case = original_case_query_cp.case
859
+ if case_query_cp.target is None:
769
860
  conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
770
- target = expert.ask_for_conclusion(case_query)
861
+ self.update_case(case_query_cp, conclusions)
862
+ expert.ask_for_conclusion(case_query_cp)
863
+ if case_query_cp.target is None:
864
+ continue
865
+ case_query.target = case_query_cp.target
771
866
 
772
867
  if case_query.attribute_name not in self.start_rules_dict:
773
868
  conclusions = self.classify(case)
774
- self.update_case(case_cp, conclusions)
869
+ self.update_case(case_query_cp, conclusions)
775
870
 
776
- new_rdr = self.initialize_new_rdr_for_attribute(case_query.attribute_name, case_cp, case_query.target)
871
+ new_rdr = self.initialize_new_rdr_for_attribute(case_query_cp)
777
872
  self.add_rdr(new_rdr, case_query.attribute_name)
778
873
 
779
874
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
780
- self.update_case(case_cp, {case_query.attribute_name: new_conclusions})
875
+ self.update_case(case_query_cp, {case_query.attribute_name: new_conclusions})
781
876
  else:
782
877
  for rdr_attribute_name, rdr in self.start_rules_dict.items():
783
878
  if case_query.attribute_name != rdr_attribute_name:
784
- conclusions = rdr.classify(case_cp)
879
+ conclusions = rdr.classify(case_query_cp.case)
785
880
  else:
786
881
  conclusions = self.start_rules_dict[rdr_attribute_name].fit_case(case_query_cp, expert,
787
882
  **kwargs)
788
883
  if conclusions is not None or (is_iterable(conclusions) and len(conclusions) > 0):
789
884
  conclusions = {rdr_attribute_name: conclusions}
790
- self.update_case(case_cp, conclusions)
885
+ case_query_cp.mutually_exclusive = True if isinstance(rdr, SingleClassRDR) else False
886
+ self.update_case(case_query_cp, conclusions)
887
+ case_query.conditions = case_query_cp.conditions
791
888
 
792
889
  return self.classify(case)
793
890
 
794
891
  @staticmethod
795
- def initialize_new_rdr_for_attribute(attribute_name: str, case: Union[Case, SQLTable], target: Any):
892
+ def initialize_new_rdr_for_attribute(case_query: CaseQuery):
796
893
  """
797
894
  Initialize the appropriate RDR type for the target.
798
895
  """
799
- attribute = getattr(case, attribute_name) if hasattr(case, attribute_name) else target
896
+ if case_query.mutually_exclusive is not None:
897
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
898
+ else MultiClassRDR()
899
+ if case_query.attribute_type in [list, set]:
900
+ return MultiClassRDR()
901
+ attribute = getattr(case_query.case, case_query.attribute_name)\
902
+ if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
800
903
  if isinstance(attribute, CaseAttribute):
801
- return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
904
+ return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
905
+ else MultiClassRDR()
802
906
  else:
803
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None) else SingleClassRDR()
907
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
908
+ else SingleClassRDR(default_conclusion=case_query.default_value)
804
909
 
805
910
  @staticmethod
806
- def update_case(case: Union[Case, SQLTable], conclusions: Dict[str, Any]):
911
+ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
807
912
  """
808
913
  Update the case with the conclusions.
809
914
 
810
- :param case: The case to update.
915
+ :param case_query: The case query that contains the case to update.
811
916
  :param conclusions: The conclusions to update the case with.
812
917
  """
813
918
  if not conclusions:
814
919
  return
815
920
  if len(conclusions) == 0:
816
921
  return
817
- if isinstance(case, SQLTable):
922
+ if isinstance(case_query.original_case, SQLTable) or is_dataclass(case_query.original_case):
818
923
  for conclusion_name, conclusion in conclusions.items():
819
- hint, origin, args = get_hint_for_attribute(conclusion_name, case)
820
- attribute = getattr(case, conclusion_name)
821
- if isinstance(attribute, set) or origin in {Set, set}:
924
+ attribute = getattr(case_query.case, conclusion_name)
925
+ if conclusion_name == case_query.attribute_name:
926
+ attribute_type = case_query.attribute_type
927
+ else:
928
+ attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
929
+ if isinstance(attribute, set) or any(at in {Set, set} for at in attribute_type):
822
930
  attribute = set() if attribute is None else attribute
823
931
  for c in conclusion:
824
932
  attribute.update(make_set(c))
825
- elif isinstance(attribute, list) or origin in {list, List}:
933
+ elif isinstance(attribute, list) or any(at in {List, list} for at in attribute_type):
826
934
  attribute = [] if attribute is None else attribute
827
935
  attribute.extend(conclusion)
828
- elif (not is_iterable(conclusion) or (len(conclusion) == 1)) and hint == type(conclusion):
829
- setattr(case, conclusion_name, conclusion)
936
+ elif is_iterable(conclusion) and len(conclusion) == 1 \
937
+ and any(at is type(list(conclusion)[0]) for at in attribute_type):
938
+ setattr(case_query.case, conclusion_name, list(conclusion)[0])
939
+ elif not is_iterable(conclusion) and any(at is type(conclusion) for at in attribute_type):
940
+ setattr(case_query.case, conclusion_name, conclusion)
830
941
  else:
831
- raise ValueError(f"Cannot add multiple conclusions to attribute {conclusion_name}")
942
+ raise ValueError(f"Unknown type or type mismatch for attribute {conclusion_name} with type "
943
+ f"{case_query.attribute_type} with conclusion "
944
+ f"{conclusion} of type {type(conclusion)}")
832
945
  else:
833
- case.update(conclusions)
946
+ case_query.case.update(conclusions)
834
947
 
835
948
  def _to_json(self) -> Dict[str, Any]:
836
949
  return {"start_rules": {t: rdr.to_json() for t, rdr in self.start_rules_dict.items()}}
@@ -845,14 +958,16 @@ class GeneralRDR(RippleDownRules):
845
958
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
846
959
  return cls(start_rules_dict)
847
960
 
848
- def write_to_python_file(self, file_path: str):
961
+ def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
849
962
  """
850
963
  Write the tree of rules as source code to a file.
851
964
 
852
965
  :param file_path: The path to the file to write the source code to.
966
+ :param postfix: The postfix to add to the file name.
853
967
  """
968
+ self.generated_python_file_name = self._default_generated_python_file_name + postfix
854
969
  for rdr in self.start_rules_dict.values():
855
- rdr.write_to_python_file(file_path)
970
+ rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
856
971
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
857
972
  with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
858
973
  f.write(self._get_imports(file_path) + "\n\n")
@@ -875,15 +990,29 @@ class GeneralRDR(RippleDownRules):
875
990
  else:
876
991
  return type(self.start_rule.corner_case)
877
992
 
878
- def get_rdr_classifier_from_python_file(self, file_path: str):
993
+ def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
879
994
  """
880
995
  :param file_path: The path to the file that contains the RDR classifier function.
996
+ :param postfix: The postfix to add to the file name.
881
997
  :return: The module that contains the rdr classifier function.
882
998
  """
883
999
  return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
884
1000
 
885
1001
  @property
886
1002
  def generated_python_file_name(self) -> str:
1003
+ if self._generated_python_file_name is None:
1004
+ self._generated_python_file_name = self._default_generated_python_file_name
1005
+ return self._generated_python_file_name
1006
+
1007
+ @generated_python_file_name.setter
1008
+ def generated_python_file_name(self, value: str):
1009
+ self._generated_python_file_name = value
1010
+
1011
+ @property
1012
+ def _default_generated_python_file_name(self) -> str:
1013
+ """
1014
+ :return: The default generated python file name.
1015
+ """
887
1016
  return f"{self.start_rule.corner_case._name.lower()}_rdr"
888
1017
 
889
1018
  @property
@@ -891,17 +1020,25 @@ class GeneralRDR(RippleDownRules):
891
1020
  return f"List[Union[{', '.join([rdr.conclusion_type_hint for rdr in self.start_rules_dict.values()])}]]"
892
1021
 
893
1022
  def _get_imports(self, file_path: str) -> str:
1023
+ """
1024
+ Get the imports needed for the generated python file.
1025
+
1026
+ :param file_path: The path to the file that contains the RDR classifier function.
1027
+ :return: The imports needed for the generated python file.
1028
+ """
894
1029
  imports = ""
895
1030
  # add type hints
896
1031
  imports += f"from typing_extensions import List, Union, Set\n"
897
1032
  # import rdr type
898
1033
  imports += f"from ripple_down_rules.rdr import GeneralRDR\n"
899
1034
  # add case type
900
- imports += f"from ripple_down_rules.datastructures import Case, create_case\n"
1035
+ imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
901
1036
  imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
902
1037
  # add conclusion type imports
903
1038
  for rdr in self.start_rules_dict.values():
904
- imports += f"from {rdr.conclusion_type.__module__} import {rdr.conclusion_type.__name__}\n"
1039
+ for conclusion_type in rdr.conclusion_type:
1040
+ if conclusion_type.__module__ != "builtins":
1041
+ imports += f"from {conclusion_type.__module__} import {conclusion_type.__name__}\n"
905
1042
  # add rdr python generated functions.
906
1043
  for rdr_key, rdr in self.start_rules_dict.items():
907
1044
  imports += (f"from {file_path.strip('./')}"