ripple-down-rules 0.1.5__tar.gz → 0.1.6__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (32):
  1. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/dataclasses.py +3 -4
  4. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr.py +9 -40
  5. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/utils.py +39 -0
  6. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  7. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_rdr_world.py +17 -7
  8. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/LICENSE +0 -0
  9. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/README.md +0 -0
  10. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/setup.cfg +0 -0
  11. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/__init__.py +0 -0
  12. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datasets.py +0 -0
  13. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  14. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
  15. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/case.py +0 -0
  16. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/datastructures/enums.py +0 -0
  17. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/experts.py +0 -0
  18. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/failures.py +0 -0
  19. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/helpers.py +0 -0
  20. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/prompt.py +0 -0
  21. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rdr_decorators.py +0 -0
  22. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules/rules.py +0 -0
  23. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
  24. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  25. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  26. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_json_serialization.py +0 -0
  27. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_on_mutagenic.py +0 -0
  28. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_rdr.py +0 -0
  29. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_rdr_alchemy.py +0 -0
  30. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_relational_rdr.py +0 -0
  31. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_relational_rdr_alchemy.py +0 -0
  32. {ripple_down_rules-0.1.5 → ripple_down_rules-0.1.6}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.1.5"
9
+ version = "0.1.6"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -8,7 +8,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union
8
8
 
9
9
  from .callable_expression import CallableExpression
10
10
  from .case import create_case, Case
11
- from ..utils import copy_case
11
+ from ..utils import copy_case, make_list
12
12
 
13
13
 
14
14
  @dataclass
@@ -88,6 +88,7 @@ class CaseQuery:
88
88
  """
89
89
  :return: The type of the attribute.
90
90
  """
91
+ self._attribute_types = tuple(make_list(self._attribute_types))
91
92
  if not self.mutually_exclusive and (set not in self._attribute_types):
92
93
  self._attribute_types = tuple(list(self._attribute_types) + [set])
93
94
  return self._attribute_types
@@ -97,9 +98,7 @@ class CaseQuery:
97
98
  """
98
99
  Set the type of the attribute.
99
100
  """
100
- if not isinstance(value, tuple):
101
- value = (value,)
102
- self._attribute_types = value
101
+ self._attribute_types = tuple(make_list(value))
103
102
 
104
103
  @property
105
104
  def name(self):
@@ -22,7 +22,7 @@ from .experts import Expert, Human
22
22
  from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
23
23
  from .utils import draw_tree, make_set, copy_case, \
24
24
  get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list, get_type_from_string, \
25
- get_case_attribute_type, ask_llm
25
+ get_case_attribute_type, ask_llm, is_matching
26
26
 
27
27
 
28
28
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -103,14 +103,14 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
103
103
  target = {case_query.attribute_name: case_query.target(case_query.case)}
104
104
  if len(targets) < len(case_queries):
105
105
  targets.append(target)
106
- match = self.is_matching(case_query, pred_cat)
106
+ match = is_matching(self.classify, case_query, pred_cat)
107
107
  if not match:
108
108
  print(f"Predicted: {pred_cat} but expected: {target}")
109
109
  if animate_tree and self.start_rule.size > num_rules:
110
110
  num_rules = self.start_rule.size
111
111
  self.update_figures()
112
112
  i += 1
113
- all_predictions = [1 if self.is_matching(case_query) else 0 for case_query in case_queries
113
+ all_predictions = [1 if is_matching(self.classify, case_query) else 0 for case_query in case_queries
114
114
  if case_query.target is not None]
115
115
  all_pred = sum(all_predictions)
116
116
  print(f"Accuracy: {all_pred}/{len(targets)}")
@@ -124,43 +124,6 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
124
124
  plt.ioff()
125
125
  plt.show()
126
126
 
127
- def is_matching(self, case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
128
- """
129
- :param case_query: The case query to check.
130
- :param pred_cat: The predicted category.
131
- :return: Whether the classifier prediction is matching case_query target or not.
132
- """
133
- if case_query.target is None:
134
- return False
135
- if pred_cat is None:
136
- pred_cat = self.classify(case_query.case)
137
- if not isinstance(pred_cat, dict):
138
- pred_cat = {case_query.attribute_name: pred_cat}
139
- target = {case_query.attribute_name: case_query.target_value}
140
- precision, recall = self.calculate_precision_and_recall(pred_cat, target)
141
- return all(recall) and all(precision)
142
-
143
- @staticmethod
144
- def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
145
- List[bool], List[bool]]:
146
- """
147
- :param pred_cat: The predicted category.
148
- :param target: The target category.
149
- :return: The precision and recall of the classifier.
150
- """
151
- recall = []
152
- precision = []
153
- for pred_key, pred_value in pred_cat.items():
154
- if pred_key not in target:
155
- continue
156
- precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
157
- for target_key, target_value in target.items():
158
- if target_key not in pred_cat:
159
- recall.append(False)
160
- continue
161
- recall.extend([v in make_set(pred_cat[target_key]) for v in make_set(target_value)])
162
- return precision, recall
163
-
164
127
  def update_figures(self):
165
128
  """
166
129
  Update the figures of the classifier.
@@ -399,6 +362,12 @@ class SingleClassRDR(RDRWithCodeWriter):
399
362
  matched_rule = self.start_rule(case)
400
363
  return matched_rule if matched_rule else self.start_rule
401
364
 
365
+ def write_to_python_file(self, file_path: str, postfix: str = ""):
366
+ super().write_to_python_file(file_path, postfix)
367
+ if self.default_conclusion is not None:
368
+ with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
369
+ f.write(f"{' '*4}else:\n{' '*8}return {self.default_conclusion}\n")
370
+
402
371
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
403
372
  defs_file: Optional[str] = None):
404
373
  """
@@ -33,6 +33,45 @@ import ast
33
33
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
34
34
 
35
35
 
36
+ def is_matching(rdr_classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
37
+ """
38
+ :param rdr_classifier: The RDR classifier to check the prediction of.
39
+ :param case_query: The case query to check.
40
+ :param pred_cat: The predicted category.
41
+ :return: Whether the classifier prediction is matching case_query target or not.
42
+ """
43
+ if case_query.target is None:
44
+ return False
45
+ if pred_cat is None:
46
+ pred_cat = rdr_classifier(case_query.case)
47
+ if not isinstance(pred_cat, dict):
48
+ pred_cat = {case_query.attribute_name: pred_cat}
49
+ target = {case_query.attribute_name: case_query.target_value}
50
+ precision, recall = calculate_precision_and_recall(pred_cat, target)
51
+ return all(recall) and all(precision)
52
+
53
+
54
+ def calculate_precision_and_recall(pred_cat: Dict[str, Any], target: Dict[str, Any]) -> Tuple[
55
+ List[bool], List[bool]]:
56
+ """
57
+ :param pred_cat: The predicted category.
58
+ :param target: The target category.
59
+ :return: The precision and recall of the classifier.
60
+ """
61
+ recall = []
62
+ precision = []
63
+ for pred_key, pred_value in pred_cat.items():
64
+ if pred_key not in target:
65
+ continue
66
+ precision.extend([v in make_set(target[pred_key]) for v in make_set(pred_value)])
67
+ for target_key, target_value in target.items():
68
+ if target_key not in pred_cat:
69
+ recall.append(False)
70
+ continue
71
+ recall.extend([v in make_set(pred_cat[target_key]) for v in make_set(target_value)])
72
+ return precision, recall
73
+
74
+
36
75
  def get_rule_conclusion_as_source_code(rule: Rule, conclusion: str, parent_indent: str = "") -> str:
37
76
  """
38
77
  Convert the conclusion of a rule to source code.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -8,6 +8,7 @@ from unittest import TestCase
8
8
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
9
9
  from ripple_down_rules.experts import Human
10
10
  from ripple_down_rules.rdr import SingleClassRDR, GeneralRDR
11
+ from ripple_down_rules.utils import is_matching
11
12
 
12
13
 
13
14
  @dataclass
@@ -93,17 +94,26 @@ class TestRDRWorld(TestCase):
93
94
  print(all_views)
94
95
  cls.drawer_case_queries = [CaseQuery(view, "correct", bool, True, default_value=False) for view in all_views]
95
96
 
96
- def test_drawer_scrdr(self):
97
- use_loaded_answers = True
98
- save_answers = False
97
+ def test_drawer_rdr(self):
98
+ self.get_drawer_rdr(use_loaded_answers=True, save_answers=False)
99
+
100
+ def test_write_drawer_rdr_to_python_file(self):
101
+ rdrs_dir = "./test_generated_rdrs"
102
+ drawer_rdr = self.get_drawer_rdr()
103
+ drawer_rdr.write_to_python_file(rdrs_dir)
104
+ loaded_rdr_classifier = drawer_rdr.get_rdr_classifier_from_python_file(rdrs_dir)
105
+ for case_query in self.drawer_case_queries:
106
+ self.assertTrue(is_matching(loaded_rdr_classifier, case_query))
107
+
108
+ def get_drawer_rdr(self, use_loaded_answers: bool = True, save_answers: bool = False):
99
109
  expert = Human(use_loaded_answers=use_loaded_answers)
100
- filename = os.path.join(os.getcwd(), "test_expert_answers/scrdr_world_expert_answers_fit")
110
+ filename = os.path.join(os.getcwd(), "test_expert_answers/correct_drawer_rdr_expert_answers_fit")
101
111
  if use_loaded_answers:
102
112
  expert.load_answers(filename)
103
- rdr = SingleClassRDR()
113
+ rdr = GeneralRDR()
104
114
  rdr.fit(self.drawer_case_queries, expert=expert, animate_tree=False)
105
115
  if save_answers:
106
116
  expert.save_answers(filename)
107
117
  for case_query in self.drawer_case_queries:
108
- self.assertEqual(rdr.classify(case_query.case), case_query.target_value)
109
- # print(f"Case: {case_query}, Classification: {rdr.classify(case_query.case)}")
118
+ self.assertTrue(is_matching(rdr.classify, case_query))
119
+ return rdr