ripple-down-rules 0.1.61__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,7 +24,8 @@ class Case(UserDict, SubclassJSONSerializer):
24
24
  the names of the attributes and the values are the attributes. All are stored in lower case.
25
25
  """
26
26
 
27
- def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None, _name: Optional[str] = None, **kwargs):
27
+ def __init__(self, _obj_type: Type, _id: Optional[Hashable] = None,
28
+ _name: Optional[str] = None, original_object: Optional[Any] = None, **kwargs):
28
29
  """
29
30
  Create a new row.
30
31
 
@@ -34,6 +35,7 @@ class Case(UserDict, SubclassJSONSerializer):
34
35
  :param kwargs: The attributes of the row.
35
36
  """
36
37
  super().__init__(kwargs)
38
+ self._original_object = original_object
37
39
  self._obj_type: Type = _obj_type
38
40
  self._id: Hashable = _id if _id is not None else id(self)
39
41
  self._name: str = _name if _name is not None else self._obj_type.__name__
@@ -221,9 +223,9 @@ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
221
223
  return obj
222
224
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
223
225
  or (obj.__class__ in [MetaData, registry])):
224
- return Case(type(obj), _id=id(obj), _name=obj_name,
226
+ return Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj,
225
227
  **{obj_name or obj.__class__.__name__: make_list(obj) if parent_is_iterable else obj})
226
- case = Case(type(obj), _id=id(obj), _name=obj_name)
228
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
227
229
  for attr in dir(obj):
228
230
  if attr.startswith("_") or callable(getattr(obj, attr)):
229
231
  continue
@@ -251,7 +253,7 @@ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, o
251
253
  :return: The updated/created case.
252
254
  """
253
255
  if case is None:
254
- case = Case(type(obj), _id=id(obj), _name=obj_name)
256
+ case = Case(type(obj), _id=id(obj), _name=obj_name, original_object=obj)
255
257
  if isinstance(attr_value, (dict, UserDict)):
256
258
  case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
257
259
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
@@ -280,7 +282,7 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
280
282
  """
281
283
  values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
282
284
  _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
283
- attr_case = Case(_type, _id=id(attr_value), _name=name)
285
+ attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
284
286
  case_attr = CaseAttribute(values)
285
287
  for idx, val in enumerate(values):
286
288
  sub_attr_case = create_case(val, recursion_idx=recursion_idx,
@@ -8,7 +8,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
8
 
9
9
  from .callable_expression import CallableExpression
10
10
  from .case import create_case, Case
11
- from ..utils import copy_case, make_list
11
+ from ..utils import copy_case, make_list, make_set
12
12
 
13
13
 
14
14
  @dataclass
@@ -88,9 +88,10 @@ class CaseQuery:
88
88
  """
89
89
  :return: The type of the attribute.
90
90
  """
91
- self._attribute_types = tuple(make_list(self._attribute_types))
92
- if not self.mutually_exclusive and (set not in self._attribute_types):
93
- self._attribute_types = tuple(list(self._attribute_types) + [set])
91
+ if not self.mutually_exclusive and (set not in make_list(self._attribute_types)):
92
+ self._attribute_types = tuple(set(make_list(self._attribute_types) + [set, list]))
93
+ elif not isinstance(self._attribute_types, tuple):
94
+ self._attribute_types = tuple(make_list(self._attribute_types))
94
95
  return self._attribute_types
95
96
 
96
97
  @attribute_type.setter
@@ -53,7 +53,7 @@ class ExpressionParser(Enum):
53
53
 
54
54
  class PromptFor(Enum):
55
55
  """
56
- The reason of the prompt. (e.g. get conditions, or conclusions).
56
+ The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
57
57
  """
58
58
  Conditions: str = "conditions"
59
59
  """
@@ -63,6 +63,10 @@ class PromptFor(Enum):
63
63
  """
64
64
  Prompt for rule conclusion about a case.
65
65
  """
66
+ Affirmation: str = "affirmation"
67
+ """
68
+ Prompt for rule conclusion affirmation about a case.
69
+ """
66
70
 
67
71
  def __str__(self):
68
72
  return self.name
@@ -10,7 +10,7 @@ from .datastructures.callable_expression import CallableExpression
10
10
  from .datastructures.enums import PromptFor
11
11
  from .datastructures.dataclasses import CaseQuery
12
12
  from .datastructures.case import show_current_and_corner_cases
13
- from .prompt import prompt_user_for_expression
13
+ from .prompt import prompt_user_for_expression, IPythonShell
14
14
  from .utils import get_all_subclasses, make_list
15
15
 
16
16
  if TYPE_CHECKING:
@@ -51,28 +51,24 @@ class Expert(ABC):
51
51
  pass
52
52
 
53
53
  @abstractmethod
54
- def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[CaseAttribute]) \
55
- -> Dict[CaseAttribute, CallableExpression]:
54
+ def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
56
55
  """
57
- Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
58
- that category.
56
+ Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
59
57
 
60
- :param x: The case to classify.
61
- :param current_conclusions: The current conclusions for the case.
62
- :return: The extra conclusions for the case.
58
+ :param case_query: The case query containing the case to classify.
59
+ :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
+ conclusion and conditions for the rule.
63
61
  """
64
62
  pass
65
63
 
66
64
  @abstractmethod
67
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
68
- targets: Optional[List[CaseAttribute]] = None,
69
- current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
65
+ def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
+ current_conclusions: Any) -> bool:
70
67
  """
71
68
  Ask the expert if the conclusion is correct.
72
69
 
73
- :param x: The case to classify.
70
+ :param case_query: The case query about which the expert should answer.
74
71
  :param conclusion: The conclusion to check.
75
- :param targets: The target categories to compare the case with.
76
72
  :param current_conclusions: The current conclusions for the case.
77
73
  """
78
74
  pass
@@ -130,6 +126,24 @@ class Human(Expert):
130
126
  last_evaluated_rule=last_evaluated_rule)
131
127
  return self._get_conditions(case_query)
132
128
 
129
+ def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
+ """
131
+ Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
+
133
+ :param case_query: The case query containing the case to classify.
134
+ :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
+ conclusion and conditions for the rule.
136
+ """
137
+ rules = []
138
+ while True:
139
+ conclusion = self.ask_for_conclusion(case_query)
140
+ if conclusion is None:
141
+ break
142
+ conditions = self._get_conditions(case_query)
143
+ rules.append({PromptFor.Conclusion: conclusion,
144
+ PromptFor.Conditions: conditions})
145
+ return rules
146
+
133
147
  def _get_conditions(self, case_query: CaseQuery) \
134
148
  -> CallableExpression:
135
149
  """
@@ -151,24 +165,6 @@ class Human(Expert):
151
165
  case_query.conditions = condition
152
166
  return condition
153
167
 
154
- def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
155
- -> Dict[CaseAttribute, CallableExpression]:
156
- """
157
- Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
158
- that category.
159
-
160
- :param case: The case to classify.
161
- :param current_conclusions: The current conclusions for the case.
162
- :return: The extra conclusions for the case.
163
- """
164
- extra_conclusions = {}
165
- while True:
166
- category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
167
- if not category:
168
- break
169
- extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
170
- return extra_conclusions
171
-
172
168
  def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
173
169
  """
174
170
  Ask the expert to provide a conclusion for the case.
@@ -212,44 +208,39 @@ class Human(Expert):
212
208
  :param category_name: The name of the category to ask about.
213
209
  """
214
210
  question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
215
- return not self.ask_yes_no_question(question)
211
+ return not self.ask_for_affirmation(question)
216
212
 
217
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: Any,
218
- targets: Optional[List[Any]] = None,
219
- current_conclusions: Optional[List[Any]] = None) -> bool:
213
+ def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
+ current_conclusions: Any) -> bool:
220
215
  """
221
216
  Ask the expert if the conclusion is correct.
222
217
 
223
- :param x: The case to classify.
218
+ :param case_query: The case query about which the expert should answer.
224
219
  :param conclusion: The conclusion to check.
225
- :param targets: The target categories to compare the case with.
226
220
  :param current_conclusions: The current conclusions for the case.
227
221
  """
228
- question = ""
229
222
  if not self.use_loaded_answers:
230
- targets = targets or []
231
- x.conclusions = make_list(current_conclusions)
232
- x.targets = make_list(targets)
233
- question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
234
- f"\n{str(x)}"
235
- return self.ask_yes_no_question(question)
223
+ print(f"Current conclusions: {current_conclusions}")
224
+ return self.ask_for_affirmation(case_query,
225
+ f"Is the conclusion {conclusion} correct for the case (True/False):")
236
226
 
237
- def ask_yes_no_question(self, question: str) -> bool:
227
+ def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
238
228
  """
239
229
  Ask the expert a yes or no question.
240
230
 
241
- :param question: The question to ask.
231
+ :param case_query: The case query about which the expert should answer.
232
+ :param question: The question to ask the expert.
242
233
  :return: The answer to the question.
243
234
  """
244
- if not self.use_loaded_answers:
245
- print(question)
246
235
  while True:
247
236
  if self.use_loaded_answers:
248
237
  answer = self.all_expert_answers.pop(0)
249
238
  else:
250
- answer = input()
251
- self.all_expert_answers.append(answer)
252
- if answer.lower() == "y":
239
+ _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
+ answer = expression(case_query.case)
241
+ if answer:
242
+ self.all_expert_answers.append(True)
253
243
  return True
254
- elif answer.lower() == "n":
244
+ else:
245
+ self.all_expert_answers.append(False)
255
246
  return False
@@ -1,10 +1,34 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
 
5
+ from .datastructures.dataclasses import CaseQuery
3
6
  from sqlalchemy.orm import Session
4
- from typing_extensions import Type, Optional
7
+ from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
8
+
9
+ from .utils import get_func_rdr_model_path
10
+ from .utils import calculate_precision_and_recall
11
+
12
+ if TYPE_CHECKING:
13
+ from .rdr import RippleDownRules
5
14
 
6
- from ripple_down_rules.rdr import RippleDownRules
7
- from ripple_down_rules.utils import get_func_rdr_model_path
15
+
16
+ def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
17
+ """
18
+ :param classifier: The RDR classifier to check the prediction of.
19
+ :param case_query: The case query to check.
20
+ :param pred_cat: The predicted category.
21
+ :return: Whether the classifier prediction is matching case_query target or not.
22
+ """
23
+ if case_query.target is None:
24
+ return False
25
+ if pred_cat is None:
26
+ pred_cat = classifier(case_query.case)
27
+ if not isinstance(pred_cat, dict):
28
+ pred_cat = {case_query.attribute_name: pred_cat}
29
+ target = {case_query.attribute_name: case_query.target_value}
30
+ precision, recall = calculate_precision_and_recall(pred_cat, target)
31
+ return all(recall) and all(precision)
8
32
 
9
33
 
10
34
  def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
@@ -9,7 +9,7 @@ from typing_extensions import List, Optional, Tuple, Dict
9
9
  from .datastructures.enums import PromptFor
10
10
  from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
11
11
  from .datastructures.dataclasses import CaseQuery
12
- from .utils import extract_dependencies, contains_return_statement
12
+ from .utils import extract_dependencies, contains_return_statement, make_set
13
13
 
14
14
 
15
15
  class CustomInteractiveShell(InteractiveShellEmbed):
@@ -28,12 +28,12 @@ class CustomInteractiveShell(InteractiveShellEmbed):
28
28
  self.ask_exit()
29
29
  return None
30
30
  result = super().run_cell(raw_cell, **kwargs)
31
- if not result.error_in_exec:
31
+ if result.error_in_exec is None and result.error_before_exec is None:
32
32
  self.all_lines.append(raw_cell)
33
33
  return result
34
34
 
35
35
 
36
- class IpythonShell:
36
+ class IPythonShell:
37
37
  """
38
38
  Create an embedded Ipython shell that can be used to prompt the user for input.
39
39
  """
@@ -63,8 +63,14 @@ class IpythonShell:
63
63
  """
64
64
  Run the embedded shell.
65
65
  """
66
- self.shell()
67
- self.update_user_input_from_code_lines()
66
+ while True:
67
+ try:
68
+ self.shell()
69
+ self.update_user_input_from_code_lines()
70
+ break
71
+ except Exception as e:
72
+ logging.error(e)
73
+ print(e)
68
74
 
69
75
  def update_user_input_from_code_lines(self):
70
76
  """
@@ -81,20 +87,23 @@ class IpythonShell:
81
87
  self.user_input = self.all_code_lines[0].replace('return', '').strip()
82
88
  else:
83
89
  self.user_input = f"def _get_value(case):\n "
84
- self.user_input += '\n '.join(self.all_code_lines)
90
+ for cl in self.all_code_lines:
91
+ sub_code_lines = cl.split('\n')
92
+ self.user_input += '\n '.join(sub_code_lines) + '\n '
85
93
 
86
94
 
87
- def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
95
+ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
88
96
  -> Tuple[Optional[str], Optional[CallableExpression]]:
89
97
  """
90
98
  Prompt the user for an executable python expression to the given case query.
91
99
 
92
100
  :param case_query: The case query to prompt the user for.
93
101
  :param prompt_for: The type of information ask user about.
102
+ :param prompt_str: The prompt string to display to the user.
94
103
  :return: A callable expression that takes a case and executes user expression on it.
95
104
  """
96
105
  while True:
97
- user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
106
+ user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str)
98
107
  if user_input is None:
99
108
  if prompt_for == PromptFor.Conclusion:
100
109
  print("No conclusion provided. Exiting.")
@@ -114,21 +123,24 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor)\
114
123
  return user_input, callable_expression
115
124
 
116
125
 
117
- def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor) -> Tuple[Optional[str], Optional[AST]]:
126
+ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
127
+ prompt_str: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
118
128
  """
119
129
  Prompt the user for input.
120
130
 
121
131
  :param case_query: The case query to prompt the user for.
122
132
  :param prompt_for: The type of information the user should provide for the given case.
133
+ :param prompt_str: The prompt string to display to the user.
123
134
  :return: The user input, and the executable expression that was parsed from the user input.
124
135
  """
125
- prompt_str = f"Give {prompt_for} for {case_query.name}"
136
+ if prompt_str is None:
137
+ prompt_str = f"Give {prompt_for} for {case_query.name}"
126
138
  scope = {'case': case_query.case, **case_query.scope}
127
- shell = IpythonShell(scope=scope, header=prompt_str)
139
+ shell = IPythonShell(scope=scope, header=prompt_str)
128
140
  return prompt_user_input_and_parse_to_expression(shell=shell)
129
141
 
130
142
 
131
- def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = None,
143
+ def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
132
144
  user_input: Optional[str] = None)\
133
145
  -> Tuple[Optional[str], Optional[ast.AST]]:
134
146
  """
@@ -140,7 +152,7 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
140
152
  """
141
153
  while True:
142
154
  if user_input is None:
143
- shell = IpythonShell() if shell is None else shell
155
+ shell = IPythonShell() if shell is None else shell
144
156
  shell.run()
145
157
  user_input = shell.user_input
146
158
  if user_input is None:
@@ -151,4 +163,5 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
151
163
  except Exception as e:
152
164
  msg = f"Error parsing expression: {e}"
153
165
  logging.error(msg)
166
+ print(msg)
154
167
  user_input = None
ripple_down_rules/rdr.py CHANGED
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
- import re
5
4
  import sys
6
5
  from abc import ABC, abstractmethod
7
6
  from copy import copy
@@ -14,15 +13,16 @@ from ordered_set import OrderedSet
14
13
  from sqlalchemy.orm import DeclarativeBase as SQLTable
15
14
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple, Callable, Set
16
15
 
17
- from .datastructures.case import Case, CaseAttribute
18
16
  from .datastructures.callable_expression import CallableExpression
17
+ from .datastructures.case import Case, CaseAttribute, create_case
19
18
  from .datastructures.dataclasses import CaseQuery
20
- from .datastructures.enums import MCRDRMode
19
+ from .datastructures.enums import MCRDRMode, PromptFor
21
20
  from .experts import Expert, Human
21
+ from .helpers import is_matching
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
- get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, ask_llm, is_matching
24
+ SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
+ get_case_attribute_type, is_conflicting
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -277,7 +277,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
277
277
  else:
278
278
  return "SCRDR"
279
279
 
280
-
281
280
  @property
282
281
  def case_type(self) -> Type:
283
282
  """
@@ -366,7 +365,7 @@ class SingleClassRDR(RDRWithCodeWriter):
366
365
  super().write_to_python_file(file_path, postfix)
367
366
  if self.default_conclusion is not None:
368
367
  with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
369
- f.write(f"{' '*4}else:\n{' '*8}return {self.default_conclusion}\n")
368
+ f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
370
369
 
371
370
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
372
371
  defs_file: Optional[str] = None):
@@ -377,7 +376,6 @@ class SingleClassRDR(RDRWithCodeWriter):
377
376
  if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
378
377
  file.write(if_clause)
379
378
  if rule.refinement:
380
-
381
379
  self.write_rules_as_source_code_to_file(rule.refinement, file, parent_indent + " ",
382
380
  defs_file=defs_file)
383
381
 
@@ -474,7 +472,7 @@ class MultiClassRDR(RDRWithCodeWriter):
474
472
  if target and not make_set(rule_conclusion).issubset(good_conclusions):
475
473
  # Rule fired and conclusion is different from target
476
474
  self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule,
477
- add_extra_conclusions)
475
+ len(user_conclusions) > 0)
478
476
  else:
479
477
  # Rule fired and target is correct or there is no target to compare
480
478
  self.add_conclusion(evaluated_rule, case_query.case)
@@ -485,11 +483,14 @@ class MultiClassRDR(RDRWithCodeWriter):
485
483
  self.add_rule_for_case(case_query, expert)
486
484
  # Have to check all rules again to make sure only this new rule fires
487
485
  next_rule = self.start_rule
488
- elif add_extra_conclusions and not user_conclusions:
486
+ elif add_extra_conclusions:
489
487
  # No more conclusions can be made, ask the expert for extra conclusions if needed.
490
- user_conclusions.extend(self.ask_expert_for_extra_conclusions(expert, case_query.case))
491
- if user_conclusions:
488
+ new_user_conclusions = self.ask_expert_for_extra_rules(expert, case_query)
489
+ user_conclusions.extend(new_user_conclusions)
490
+ if len(new_user_conclusions) > 0:
492
491
  next_rule = self.last_top_rule
492
+ else:
493
+ add_extra_conclusions = False
493
494
  evaluated_rule = next_rule
494
495
  return self.conclusions
495
496
 
@@ -551,13 +552,12 @@ class MultiClassRDR(RDRWithCodeWriter):
551
552
  """
552
553
  Stop a wrong conclusion by adding a stopping rule.
553
554
  """
554
- target = case_query.target(case_query.case)
555
555
  rule_conclusion = evaluated_rule.conclusion(case_query.case)
556
- if self.is_same_category_type(rule_conclusion, target) \
557
- and self.is_conflicting_with_target(rule_conclusion, target):
558
- self.stop_conclusion(case_query, expert, evaluated_rule)
559
- elif not self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
560
- self.stop_conclusion(case_query, expert, evaluated_rule)
556
+ if is_conflicting(rule_conclusion, case_query.target_value):
557
+ if self.conclusion_is_correct(case_query, expert, evaluated_rule, add_extra_conclusions):
558
+ return
559
+ else:
560
+ self.stop_conclusion(case_query, expert, evaluated_rule)
561
561
 
562
562
  def stop_conclusion(self, case_query: CaseQuery,
563
563
  expert: Expert, evaluated_rule: MultiClassTopRule):
@@ -576,31 +576,6 @@ class MultiClassRDR(RDRWithCodeWriter):
576
576
  new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
577
577
  self.add_top_rule(new_top_rule_conditions, case_query.target, case_query.case)
578
578
 
579
- @staticmethod
580
- def is_conflicting_with_target(conclusion: Any, target: Any) -> bool:
581
- """
582
- Check if the conclusion is conflicting with the target category.
583
-
584
- :param conclusion: The conclusion to check.
585
- :param target: The target category to compare the conclusion with.
586
- :return: Whether the conclusion is conflicting with the target category.
587
- """
588
- if hasattr(conclusion, "mutually_exclusive") and conclusion.mutually_exclusive:
589
- return True
590
- else:
591
- return not make_set(conclusion).issubset(make_set(target))
592
-
593
- @staticmethod
594
- def is_same_category_type(conclusion: Any, target: Any) -> bool:
595
- """
596
- Check if the conclusion is of the same class as the target category.
597
-
598
- :param conclusion: The conclusion to check.
599
- :param target: The target category to compare the conclusion with.
600
- :return: Whether the conclusion is of the same class as the target category but has a different value.
601
- """
602
- return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute
603
-
604
579
  def conclusion_is_correct(self, case_query: CaseQuery,
605
580
  expert: Expert, evaluated_rule: Rule,
606
581
  add_extra_conclusions: bool) -> bool:
@@ -616,7 +591,6 @@ class MultiClassRDR(RDRWithCodeWriter):
616
591
  conclusions = {case_query.attribute_name: c for c in OrderedSet(self.conclusions)}
617
592
  if (add_extra_conclusions and expert.ask_if_conclusion_is_correct(case_query.case,
618
593
  evaluated_rule.conclusion(case_query.case),
619
- targets=case_query.target(case_query.case),
620
594
  current_conclusions=conclusions)):
621
595
  self.add_conclusion(evaluated_rule, case_query.case)
622
596
  self.expert_accepted_conclusions.append(evaluated_rule.conclusion)
@@ -637,23 +611,22 @@ class MultiClassRDR(RDRWithCodeWriter):
637
611
  conditions = expert.ask_for_conditions(case_query)
638
612
  self.add_top_rule(conditions, case_query.target, case_query.case)
639
613
 
640
- def ask_expert_for_extra_conclusions(self, expert: Expert, case: Union[Case, SQLTable]) -> List[Any]:
614
+ def ask_expert_for_extra_rules(self, expert: Expert, case_query: CaseQuery) -> List[Any]:
641
615
  """
642
- Ask the expert for extra conclusions when no more conclusions can be made.
616
+ Ask the expert for extra rules when no more conclusions can be made for a case.
643
617
 
644
618
  :param expert: The expert to ask for extra conclusions.
645
- :param case: The case to ask extra conclusions for.
646
- :return: The extra conclusions that the expert has provided.
619
+ :param case_query: The case query to ask the expert about.
620
+ :return: The extra conclusions for the rules that the expert has provided.
647
621
  """
648
622
  extra_conclusions = []
649
623
  conclusions = list(OrderedSet(self.conclusions))
650
624
  if not expert.use_loaded_answers:
651
625
  print("current conclusions:", conclusions)
652
- extra_conclusions_dict = expert.ask_for_extra_conclusions(case, conclusions)
653
- if extra_conclusions_dict:
654
- for conclusion, conditions in extra_conclusions_dict.items():
655
- self.add_top_rule(conditions, conclusion, case)
656
- extra_conclusions.append(conclusion)
626
+ extra_rules = expert.ask_for_extra_rules(case_query)
627
+ for rule in extra_rules:
628
+ self.add_top_rule(rule[PromptFor.Conditions], rule[PromptFor.Conclusion], case_query.case)
629
+ extra_conclusions.extend(rule[PromptFor.Conclusion](case_query.case))
657
630
  return extra_conclusions
658
631
 
659
632
  def add_conclusion(self, evaluated_rule: Rule, case: Case) -> None:
@@ -771,6 +744,7 @@ class GeneralRDR(RippleDownRules):
771
744
  :return: The categories that the case belongs to.
772
745
  """
773
746
  conclusions = {}
747
+ case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
774
748
  case_cp = copy_case(case)
775
749
  while True:
776
750
  new_conclusions = {}
@@ -863,17 +837,17 @@ class GeneralRDR(RippleDownRules):
863
837
  Initialize the appropriate RDR type for the target.
864
838
  """
865
839
  if case_query.mutually_exclusive is not None:
866
- return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive\
840
+ return SingleClassRDR(default_conclusion=case_query.default_value) if case_query.mutually_exclusive \
867
841
  else MultiClassRDR()
868
842
  if case_query.attribute_type in [list, set]:
869
843
  return MultiClassRDR()
870
- attribute = getattr(case_query.case, case_query.attribute_name)\
844
+ attribute = getattr(case_query.case, case_query.attribute_name) \
871
845
  if hasattr(case_query.case, case_query.attribute_name) else case_query.target(case_query.case)
872
846
  if isinstance(attribute, CaseAttribute):
873
847
  return SingleClassRDR(default_conclusion=case_query.default_value) if attribute.mutually_exclusive \
874
848
  else MultiClassRDR()
875
849
  else:
876
- return MultiClassRDR() if is_iterable(attribute) or (attribute is None)\
850
+ return MultiClassRDR() if is_iterable(attribute) or (attribute is None) \
877
851
  else SingleClassRDR(default_conclusion=case_query.default_value)
878
852
 
879
853
  @staticmethod
@@ -895,13 +869,18 @@ class GeneralRDR(RippleDownRules):
895
869
  attribute_type = case_query.attribute_type
896
870
  else:
897
871
  attribute_type = (get_case_attribute_type(case_query.original_case, conclusion_name, attribute),)
898
- if isinstance(attribute, set) or any(at in {Set, set} for at in attribute_type):
899
- attribute = set() if attribute is None else attribute
872
+ if isinstance(attribute, set):
900
873
  for c in conclusion:
901
874
  attribute.update(make_set(c))
902
- elif isinstance(attribute, list) or any(at in {List, list} for at in attribute_type):
875
+ elif isinstance(attribute, list):
876
+ attribute.extend(conclusion)
877
+ elif any(at in {List, list} for at in attribute_type):
903
878
  attribute = [] if attribute is None else attribute
904
879
  attribute.extend(conclusion)
880
+ elif any(at in {Set, set} for at in attribute_type):
881
+ attribute = set() if attribute is None else attribute
882
+ for c in conclusion:
883
+ attribute.update(make_set(c))
905
884
  elif is_iterable(conclusion) and len(conclusion) == 1 \
906
885
  and any(at is type(list(conclusion)[0]) for at in attribute_type):
907
886
  setattr(case_query.case, conclusion_name, list(conclusion)[0])
@@ -33,22 +33,25 @@ import ast
33
33
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
34
 
35
35
 
36
- def is_matching(rdr_classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
36
+ def is_conflicting(conclusion: Any, target: Any) -> bool:
37
37
  """
38
- :param rdr_classifier: The RDR classifier to check the prediction of.
39
- :param case_query: The case query to check.
40
- :param pred_cat: The predicted category.
41
- :return: Whether the classifier prediction is matching case_query target or not.
38
+ :param conclusion: The conclusion to check.
39
+ :param target: The target to compare the conclusion with.
40
+ :return: Whether the conclusion is conflicting with the target by have different values for same type categories.
42
41
  """
43
- if case_query.target is None:
44
- return False
45
- if pred_cat is None:
46
- pred_cat = rdr_classifier(case_query.case)
47
- if not isinstance(pred_cat, dict):
48
- pred_cat = {case_query.attribute_name: pred_cat}
49
- target = {case_query.attribute_name: case_query.target_value}
50
- precision, recall = calculate_precision_and_recall(pred_cat, target)
51
- return all(recall) and all(precision)
42
+ return have_common_types(conclusion, target) and not make_set(conclusion).issubset(make_set(target))
43
+
44
+
45
+ def have_common_types(conclusion: Any, target: Any) -> bool:
46
+ """
47
+ :param conclusion: The conclusion to check.
48
+ :param target: The target to compare the conclusion with.
49
+ :return: Whether the conclusion shares some types with the target.
50
+ """
51
+ target_types = {type(t) for t in make_set(target)}
52
+ conclusion_types = {type(c) for c in make_set(conclusion)}
53
+ common_types = conclusion_types.intersection(target_types)
54
+ return len(common_types) > 0
52
55
 
53
56
 
54
57
  def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
@@ -177,7 +180,8 @@ def extract_dependencies(code_lines):
177
180
 
178
181
  if not isinstance(final_stmt, ast.Return):
179
182
  raise ValueError("Last line is not a return statement")
180
-
183
+ if final_stmt.value is None:
184
+ raise ValueError("Return statement has no value")
181
185
  needed = get_names_used(final_stmt.value)
182
186
  required_lines = []
183
187
  line_map = {id(node): i for i, node in enumerate(tree.body)}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.61
3
+ Version: 0.1.62
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,20 @@
1
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
3
+ ripple_down_rules/experts.py,sha256=TXU-VKhWlSE3aYrCEcXf9t9N4TBDW_T5tE4T6MdCibE,10342
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
6
+ ripple_down_rules/prompt.py,sha256=cHqhMJqubGhfGpOOY_uXv5L7PBNb64O0IBWSfiY0ui0,6682
7
+ ripple_down_rules/rdr.py,sha256=K-b1I1pBEN6rn3AZdTAxgOs6AFXXvqkGI4x8dV9nrWw,47793
8
+ ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
+ ripple_down_rules/rules.py,sha256=KTB7kPnyyU9GuZhVe9ba25-3ICdzl46r9MFduckk-_Y,16147
10
+ ripple_down_rules/utils.py,sha256=JIF99Knqzqjgny7unvEnib3sCmExqU-w9xYOSGIT86Q,32276
11
+ ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
+ ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
+ ripple_down_rules/datastructures/case.py,sha256=XJC6Sb67gzpEYYYYjvECJlJBRVphMScWhWMTc2kTtbc,13792
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=_aabVXsgdVUeAmgGA9K_LZpO2U5a6-htrg2Tka7qc30,5960
15
+ ripple_down_rules/datastructures/enums.py,sha256=RdyPUp9Ls1QuLmkcMMkBbCWrmXIZI4xWuM-cLPYZhR0,4666
16
+ ripple_down_rules-0.1.62.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.62.dist-info/METADATA,sha256=jPKNb7YM2Wcsy9p1otEVqyxv5apcfADU-sPin5Wu0qs,42576
18
+ ripple_down_rules-0.1.62.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
+ ripple_down_rules-0.1.62.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.62.dist-info/RECORD,,
@@ -1,20 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ripple_down_rules/datasets.py,sha256=rCSpeFeu1gTuKESwjHUdQkPPvomI5OMRNGpbdKmHwMg,4639
3
- ripple_down_rules/experts.py,sha256=sA9Cmx9BlwlCFYRDDLz3VG6e5njujAFZEItSnnzrG5E,10490
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
6
- ripple_down_rules/prompt.py,sha256=6g-WqMiOFp9QyAZDmiNbHbPjAeeJHb6ItLGdQAVxGKk,6063
7
- ripple_down_rules/rdr.py,sha256=VT7AWTDlLOyk2FILa4mHixdno2kXtk82m_pSY1CoEiE,48789
8
- ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
- ripple_down_rules/rules.py,sha256=KTB7kPnyyU9GuZhVe9ba25-3ICdzl46r9MFduckk-_Y,16147
10
- ripple_down_rules/utils.py,sha256=ppKTt3_O6JgmTqCdkjVBfuVaI6P7b4oRCSOmnBaqaVM,32110
11
- ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
- ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
- ripple_down_rules/datastructures/case.py,sha256=A7qkl5W48zldTtA4m-NJRYEwlMBpo7uGugnriNwcY0E,13597
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=inhTE4tlMrwVRcYDtqAaR0JlxlyD87JIUvXIu5H9Ioo,5860
15
- ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
16
- ripple_down_rules-0.1.61.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.61.dist-info/METADATA,sha256=fstYXWm2KIN0bmiUkJTD6UIDiGQGtjE-pt263bDgJqk,42576
18
- ripple_down_rules-0.1.61.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
- ripple_down_rules-0.1.61.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.61.dist-info/RECORD,,