ripple-down-rules 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
8
 
9
9
  from .callable_expression import CallableExpression
10
10
  from .case import create_case, Case
11
- from ..utils import copy_case
11
+ from ..utils import copy_case, make_list
12
12
 
13
13
 
14
14
  @dataclass
@@ -88,6 +88,7 @@ class CaseQuery:
88
88
  """
89
89
  :return: The type of the attribute.
90
90
  """
91
+ self._attribute_types = tuple(make_list(self._attribute_types))
91
92
  if not self.mutually_exclusive and (set not in self._attribute_types):
92
93
  self._attribute_types = tuple(list(self._attribute_types) + [set])
93
94
  return self._attribute_types
@@ -97,9 +98,7 @@ class CaseQuery:
97
98
  """
98
99
  Set the type of the attribute.
99
100
  """
100
- if not isinstance(value, tuple):
101
- value = (value,)
102
- self._attribute_types = value
101
+ self._attribute_types = tuple(make_list(value))
103
102
 
104
103
  @property
105
104
  def name(self):
ripple_down_rules/rdr.py CHANGED
@@ -22,7 +22,7 @@ from .experts import Expert, Human
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, ask_llm
25
+ get_case_attribute_type, ask_llm, is_matching
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -103,14 +103,14 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
103
103
  target = {case_query.attribute_name: case_query.target(case_query.case)}
104
104
  if len(targets) < len(case_queries):
105
105
  targets.append(target)
106
- match = self.is_matching(case_query, pred_cat)
106
+ match = is_matching(self.classify, case_query, pred_cat)
107
107
  if not match:
108
108
  print(f"Predicted: {pred_cat} but expected: {target}")
109
109
  if animate_tree and self.start_rule.size > num_rules:
110
110
  num_rules = self.start_rule.size
111
111
  self.update_figures()
112
112
  i += 1
113
- all_predictions = [1 if self.is_matching(case_query) else 0 for case_query in case_queries
113
+ all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
114
114
  if case_query.target is not None]
115
115
  all_pred = sum(all_predictions)
116
116
  print(f"Accuracy: {all_pred}/{len(targets)}")
@@ -124,43 +124,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
124
  plt.ioff()
125
125
  plt.show()
126
126
 
127
- def is_matching(self, case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
128
- """
129
- :param case_query: The case query to check.
130
- :param pred_cat: The predicted category.
131
- :return: Whether the classifier prediction is matching case_query target or not.
132
- """
133
- if case_query.target is None:
134
- return False
135
- if pred_cat is None:
136
- pred_cat = self.classify(case_query.case)
137
- if not isinstance(pred_cat, dict):
138
- pred_cat = {case_query.attribute_name: pred_cat}
139
- target = {case_query.attribute_name: case_query.target_value}
140
- precision, recall = self.calculate_precision_and_recall(pred_cat, target)
141
- return all(recall) and all(precision)
142
-
143
- @staticmethod
144
- def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
145
- List[bool], List[bool]]:
146
- """
147
- :param pred_cat: The predicted category.
148
- :param target: The target category.
149
- :return: The precision and recall of the classifier.
150
- """
151
- recall = []
152
- precision = []
153
- for pred_key, pred_value in pred_cat.items():
154
- if pred_key not in target:
155
- continue
156
- precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
157
- for target_key, target_value in target.items():
158
- if target_key not in pred_cat:
159
- recall.append(False)
160
- continue
161
- recall.extend([v in make_set(pred_cat[target_key]) for v in make_set(target_value)])
162
- return precision, recall
163
-
164
127
  def update_figures(self):
165
128
  """
166
129
  Update the figures of the classifier.
@@ -399,6 +362,12 @@ class SingleClassRDR(RDRWithCodeWriter):
399
362
  matched_rule = self.start_rule(case)
400
363
  return matched_rule if matched_rule else self.start_rule
401
364
 
365
def write_to_python_file(self, file_path: str, postfix: str = ""):
    """
    Write the generated classifier module, then append a fallback branch for the default conclusion.

    The parent implementation writes the module first. When ``default_conclusion`` is set, an
    ``else:`` branch returning it is appended to ``<file_path>/<generated_python_file_name>.py``.

    :param file_path: Directory containing the generated module (the file name is appended with "/").
    :param postfix: Forwarded unchanged to the parent ``write_to_python_file``.
    """
    super().write_to_python_file(file_path, postfix)
    if self.default_conclusion is None:
        # No fallback configured: the parent's output is left as-is.
        return
    generated_module = file_path + f"/{self.generated_python_file_name}.py"
    indent = " " * 4
    with open(generated_module, "a") as generated:
        # NOTE(review): this assumes the parent output ends with an if/elif chain at one indent
        # level, so the appended "else" closes that chain -- confirm against the parent writer.
        # The conclusion is interpolated with str(), not repr(); a plain-string conclusion would
        # therefore be emitted unquoted -- confirm this is intended.
        generated.write(f"{indent}else:\n{indent * 2}return {self.default_conclusion}\n")
370
+
402
371
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
403
372
  defs_file: Optional[str] = None):
404
373
  """
@@ -33,6 +33,45 @@ import ast
33
33
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
34
 
35
35
 
36
def is_matching(rdr_classifier: Callable[[Any], Any], case_query: CaseQuery,
                pred_cat: Optional[Dict[str, Any]] = None) -> bool:
    """
    Decide whether a classifier's prediction agrees with the target of a case query.

    :param rdr_classifier: The RDR classifier whose prediction is checked. It is only called
        (with ``case_query.case``) when ``pred_cat`` is not supplied.
    :param case_query: The case query to check.
    :param pred_cat: The predicted category. A non-dict prediction is treated as the value of
        ``case_query.attribute_name``.
    :return: False when the case query has no target. Otherwise True only if every recall and
        every precision check from ``calculate_precision_and_recall`` holds.
    """
    if case_query.target is None:
        # Without a target there is nothing to match against.
        return False
    prediction = rdr_classifier(case_query.case) if pred_cat is None else pred_cat
    attribute = case_query.attribute_name
    if not isinstance(prediction, dict):
        # Normalise a bare prediction into the {attribute_name: value} form used for the target.
        prediction = {attribute: prediction}
    expected = {attribute: case_query.target_value}
    precision, recall = calculate_precision_and_recall(prediction, expected)
    return all(recall) and all(precision)
52
+
53
+
54
def calculate_precision_and_recall(pred_cat: Dict[str, Any],
                                   target: Dict[str, Any]) -> Tuple[List[bool], List[bool]]:
    """
    Compare a predicted category with a target category, attribute by attribute.

    :param pred_cat: The predicted category, mapping attribute names to predicted value(s).
    :param target: The target category, mapping attribute names to expected value(s).
    :return: A ``(precision, recall)`` pair of boolean lists.
        ``precision`` has one entry per predicted value of every attribute that also appears in
        ``target``: whether that value is among the target's values for the attribute.
        ``recall`` has one entry per target value: whether that value was predicted. A target
        attribute missing from ``pred_cat`` contributes a single False.
    """
    precision: List[bool] = []
    recall: List[bool] = []
    for pred_key, pred_value in pred_cat.items():
        if pred_key not in target:
            # Predictions for attributes that were not asked about do not affect precision.
            continue
        predicted_values = make_set(pred_value)
        # Build the target set once per attribute rather than once per predicted value.
        target_values = make_set(target[pred_key])
        precision.extend(v in target_values for v in predicted_values)
    for target_key, target_value in target.items():
        if target_key not in pred_cat:
            recall.append(False)
            continue
        expected_values = make_set(target_value)
        # Same hoisting: the predicted set is invariant across this attribute's target values.
        predicted_values = make_set(pred_cat[target_key])
        recall.extend(v in predicted_values for v in expected_values)
    return precision, recall
73
+
74
+
36
75
  def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> str:
37
76
  """
38
77
  Convert the conclusion of a rule to source code.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,17 +4,17 @@ ripple_down_rules/experts.py,sha256=sA9Cmx9BlwlCFYRDDLz3VG6e5njujAFZEItSnnzrG5E,
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
6
6
  ripple_down_rules/prompt.py,sha256=6g-WqMiOFp9QyAZDmiNbHbPjAeeJHb6ItLGdQAVxGKk,6063
7
- ripple_down_rules/rdr.py,sha256=HevACyk22k2m7sTKDTxbFiRp4MOQNK7XSOvJyVBg20Q,50047
7
+ ripple_down_rules/rdr.py,sha256=VT7AWTDlLOyk2FILa4mHixdno2kXtk82m_pSY1CoEiE,48789
8
8
  ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
9
  ripple_down_rules/rules.py,sha256=KTB7kPnyyU9GuZhVe9ba25-3ICdzl46r9MFduckk-_Y,16147
10
- ripple_down_rules/utils.py,sha256=0yyLpvt-GEamV4Z3515ip200IfzpqOhNcrXhGzZtEPk,30521
10
+ ripple_down_rules/utils.py,sha256=ppKTt3_O6JgmTqCdkjVBfuVaI6P7b4oRCSOmnBaqaVM,32110
11
11
  ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
12
12
  ripple_down_rules/datastructures/callable_expression.py,sha256=TW_u6CJfelW2CiJj9pWFpdOBNIxeEuhhsQEz_pLpFVE,9092
13
13
  ripple_down_rules/datastructures/case.py,sha256=A7qkl5W48zldTtA4m-NJRYEwlMBpo7uGugnriNwcY0E,13597
14
- ripple_down_rules/datastructures/dataclasses.py,sha256=2HISRjO_rfsOVCD19bmWkR5tRK9kWyFGTn3QHdMfLSw,5829
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=inhTE4tlMrwVRcYDtqAaR0JlxlyD87JIUvXIu5H9Ioo,5860
15
15
  ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
16
- ripple_down_rules-0.1.5.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
- ripple_down_rules-0.1.5.dist-info/METADATA,sha256=zjPgcX0Z3DMcPnU5YjbVRPu8w2MhdQ4gimpdC9JabJk,42518
18
- ripple_down_rules-0.1.5.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
- ripple_down_rules-0.1.5.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
- ripple_down_rules-0.1.5.dist-info/RECORD,,
16
+ ripple_down_rules-0.1.6.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.6.dist-info/METADATA,sha256=aoGSheQJpGvz7fNTP9jyKz0V1M416oOQfzgsT_8Sd5s,42518
18
+ ripple_down_rules-0.1.6.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
+ ripple_down_rules-0.1.6.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.6.dist-info/RECORD,,