ripple-down-rules 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ import os
4
4
  import time
5
5
  from abc import ABC
6
6
  from collections import UserDict
7
- from copy import deepcopy
7
+ from copy import deepcopy, copy
8
8
  from dataclasses import dataclass
9
9
  from enum import Enum
10
10
 
@@ -78,7 +78,7 @@ class SubClassFactory:
78
78
  parent_class_alias = cls.__name__ + "_"
79
79
  imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
80
80
  class_code = f"class {name}({parent_class_alias}):\n"
81
- class_attributes = deepcopy(class_attributes) if class_attributes else {}
81
+ class_attributes = copy(class_attributes) if class_attributes else {}
82
82
  class_attributes.update({"_value_range": range_})
83
83
  for key, value in class_attributes.items():
84
84
  if value is not None:
@@ -97,14 +97,23 @@ class Human(Expert):
97
97
  self.use_loaded_answers = use_loaded_answers
98
98
  self.session = session
99
99
 
100
- def save_answers(self, path: str):
100
+ def save_answers(self, path: str, append: bool = False):
101
101
  """
102
102
  Save the expert answers to a file.
103
103
 
104
104
  :param path: The path to save the answers to.
105
+ :param append: A flag to indicate if the answers should be appended to the file or not.
105
106
  """
106
- with open(path + '.json', "w") as f:
107
- json.dump(self.all_expert_answers, f)
107
+ if append:
108
+ # read the file and append the new answers
109
+ with open(path + '.json', "r") as f:
110
+ all_answers = json.load(f)
111
+ all_answers.extend(self.all_expert_answers)
112
+ with open(path + '.json', "w") as f:
113
+ json.dump(all_answers, f)
114
+ else:
115
+ with open(path + '.json', "w") as f:
116
+ json.dump(self.all_expert_answers, f)
108
117
 
109
118
  def load_answers(self, path: str):
110
119
  """
ripple_down_rules/rdr.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from copy import copy
4
+ from copy import copy, deepcopy
5
5
 
6
6
  from matplotlib import pyplot as plt
7
7
  from ordered_set import OrderedSet
8
8
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
9
- from typing_extensions import List, Optional, Dict, Type, Union, Any, Self
9
+ from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple
10
10
 
11
11
  from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
12
12
  from .experts import Expert, Human
@@ -51,7 +51,7 @@ class RippleDownRules(ABC):
51
51
  pass
52
52
 
53
53
  @abstractmethod
54
- def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs)\
54
+ def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
55
55
  -> Union[Column, CallableExpression]:
56
56
  """
57
57
  Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
@@ -86,8 +86,6 @@ class RippleDownRules(ABC):
86
86
  num_rules: int = 0
87
87
  while not stop_iterating:
88
88
  all_pred = 0
89
- all_recall = []
90
- all_precision = []
91
89
  if not targets:
92
90
  targets = [None] * len(cases)
93
91
  for case_query in case_queries:
@@ -97,14 +95,7 @@ class RippleDownRules(ABC):
97
95
  conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
98
96
  target = expert.ask_for_conclusion(case_query, conclusions)
99
97
  pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
100
- pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
101
- target = target if isinstance(target, list) else [target]
102
- recall = [not yi or (yi in pred_cat) for yi in target]
103
- y_type = [type(yi) for yi in target]
104
- precision = [(pred in target) or (type(pred) not in y_type) for pred in pred_cat]
105
- match = all(recall) and all(precision)
106
- all_recall.extend(recall)
107
- all_precision.extend(precision)
98
+ match = self.is_matching(pred_cat, target)
108
99
  if not match:
109
100
  print(f"Predicted: {pred_cat} but expected: {target}")
110
101
  all_pred += int(match)
@@ -112,21 +103,43 @@ class RippleDownRules(ABC):
112
103
  num_rules = self.start_rule.size
113
104
  self.update_figures()
114
105
  i += 1
115
- all_pred = [1 if p == t else 0
116
- for case, target in zip(cases, targets) for p, t in zip(self.classify(case), target)]
117
- all_predicted = targets and sum(all_pred) == len(targets)
106
+ all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
107
+ for case, target in zip(cases, targets)]
108
+ all_pred = sum(all_predictions)
109
+ print(f"Accuracy: {all_pred}/{len(targets)}")
110
+ all_predicted = targets and all_pred == len(targets)
118
111
  num_iter_reached = n_iter and i >= n_iter
119
112
  stop_iterating = all_predicted or num_iter_reached
120
113
  if stop_iterating:
121
114
  break
122
- print(f"Recall: {sum(all_recall) / len(all_recall)}")
123
- print(f"Precision: {sum(all_precision) / len(all_precision)}")
124
- print(f"Accuracy: {all_pred}/{len(targets)}")
125
115
  print(f"Finished training in {i} iterations")
126
116
  if animate_tree:
127
117
  plt.ioff()
128
118
  plt.show()
129
119
 
120
+ @staticmethod
121
+ def calculate_precision_and_recall(pred_cat: List[Column], target: List[Column]) -> Tuple[List[bool], List[bool]]:
122
+ """
123
+ :param pred_cat: The predicted category.
124
+ :param target: The target category.
125
+ :return: The precision and recall of the classifier.
126
+ """
127
+ pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
128
+ target = target if isinstance(target, list) else [target]
129
+ recall = [not yi or (yi in pred_cat) for yi in target]
130
+ target_types = [type(yi) for yi in target]
131
+ precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
132
+ return precision, recall
133
+
134
+ def is_matching(self, pred_cat: List[Column], target: List[Column]) -> bool:
135
+ """
136
+ :param pred_cat: The predicted category.
137
+ :param target: The target category.
138
+ :return: Whether the classifier is matching or not.
139
+ """
140
+ precision, recall = self.calculate_precision_and_recall(pred_cat, target)
141
+ return all(recall) and all(precision)
142
+
130
143
  def update_figures(self):
131
144
  """
132
145
  Update the figures of the classifier.
@@ -501,6 +514,7 @@ class MultiClassRDR(RippleDownRules):
501
514
  same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
502
515
  combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
503
516
  else {evaluated_rule.conclusion}
517
+ combined_conclusion = deepcopy(combined_conclusion)
504
518
  for c in same_type_conclusions:
505
519
  combined_conclusion.update(c if isinstance(c, set) else make_set(c))
506
520
  self.conclusions.remove(c)
@@ -582,7 +596,7 @@ class GeneralRDR(RippleDownRules):
582
596
  break
583
597
  return conclusions
584
598
 
585
- def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs)\
599
+ def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
586
600
  -> List[Union[Column, CallableExpression]]:
587
601
  """
588
602
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,21 +1,21 @@
1
1
  ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ripple_down_rules/datasets.py,sha256=QRB-1BdFTcUNuhgYEuXYx6qQOYlDu03_iLKDBqrcVrQ,4511
3
- ripple_down_rules/experts.py,sha256=wss-rMjmVRNoNpPNo_EygNsh57iebNi-g2j7_SkgWIU,11636
3
+ ripple_down_rules/experts.py,sha256=DMTC-E2g1Fs43oyr30MGeGi5-VKBb3RojzzPa8DvCSA,12093
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
6
- ripple_down_rules/rdr.py,sha256=i-SFrnrkii_ARc-hLnSJHCwfOim09ynrqCnnOw_y-BI,33199
6
+ ripple_down_rules/rdr.py,sha256=nRhrHdQYmvXgfTFLyO_xxwe7GDYuUhaOqZlQtqSUrnE,33751
7
7
  ripple_down_rules/rules.py,sha256=TLptqvA6I3QlQaVBTYchbgvXm17XWwFJoTmoN0diHm8,10348
8
8
  ripple_down_rules/utils.py,sha256=WlUXTf-B45SPEKpDBVb9QPQWS54250MGAY5xl9MIhR4,16944
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
10
10
  ripple_down_rules/datastructures/callable_expression.py,sha256=yEZ6OWzSiWsRtEz_5UquA0inmodFTOWqXWV_a2gg1cg,9110
11
11
  ripple_down_rules/datastructures/dataclasses.py,sha256=z_9B7Nj_MIf2Iyrs5VeUhXhYxwaqnuKVjgwxhZZTygY,2525
12
12
  ripple_down_rules/datastructures/enums.py,sha256=6Mh55_8QRuXyYZXtonWr01VBgLP-jYp91K_8hIgh8u8,4244
13
- ripple_down_rules/datastructures/table.py,sha256=q3RIZ5BiqaFNTKcVRc28xQSvso0X_eGf1l_kVCw4mbA,23159
13
+ ripple_down_rules/datastructures/table.py,sha256=Tq7savAaFoB7n7x6-u2Vz6NfvjLBzPxsqblR4cHujzM,23161
14
14
  ripple_down_rules/datastructures/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  ripple_down_rules/datastructures/generated/column/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  ripple_down_rules/datastructures/generated/row/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- ripple_down_rules-0.0.4.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
18
- ripple_down_rules-0.0.4.dist-info/METADATA,sha256=MUirOiWljZwLdaaEqnIjhk39fbeILkxFZcoYQI_Zw5g,42518
19
- ripple_down_rules-0.0.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
20
- ripple_down_rules-0.0.4.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
21
- ripple_down_rules-0.0.4.dist-info/RECORD,,
17
+ ripple_down_rules-0.0.5.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
18
+ ripple_down_rules-0.0.5.dist-info/METADATA,sha256=qwUDSDi3YlBsg1RlYOEtFn5YQ3x7yNizCAKm72_8YPU,42518
19
+ ripple_down_rules-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
20
+ ripple_down_rules-0.0.5.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
21
+ ripple_down_rules-0.0.5.dist-info/RECORD,,