ripple-down-rules: ripple_down_rules-0.0.3-py3-none-any.whl → ripple_down_rules-0.0.5-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- ripple_down_rules/datastructures/table.py +2 -2
- ripple_down_rules/experts.py +12 -3
- ripple_down_rules/rdr.py +34 -20
- ripple_down_rules-0.0.5.dist-info/METADATA +730 -0
- {ripple_down_rules-0.0.3.dist-info → ripple_down_rules-0.0.5.dist-info}/RECORD +8 -7
- ripple_down_rules-0.0.5.dist-info/licenses/LICENSE +674 -0
- ripple_down_rules-0.0.3.dist-info/METADATA +0 -54
- {ripple_down_rules-0.0.3.dist-info → ripple_down_rules-0.0.5.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.0.3.dist-info → ripple_down_rules-0.0.5.dist-info}/top_level.txt +0 -0
ripple_down_rules/datastructures/table.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import time
|
5
5
|
from abc import ABC
|
6
6
|
from collections import UserDict
|
7
|
-
from copy import deepcopy
|
7
|
+
from copy import deepcopy, copy
|
8
8
|
from dataclasses import dataclass
|
9
9
|
from enum import Enum
|
10
10
|
|
@@ -78,7 +78,7 @@ class SubClassFactory:
|
|
78
78
|
parent_class_alias = cls.__name__ + "_"
|
79
79
|
imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
|
80
80
|
class_code = f"class {name}({parent_class_alias}):\n"
|
81
|
-
class_attributes =
|
81
|
+
class_attributes = copy(class_attributes) if class_attributes else {}
|
82
82
|
class_attributes.update({"_value_range": range_})
|
83
83
|
for key, value in class_attributes.items():
|
84
84
|
if value is not None:
|
ripple_down_rules/experts.py
CHANGED
@@ -97,14 +97,23 @@ class Human(Expert):
|
|
97
97
|
self.use_loaded_answers = use_loaded_answers
|
98
98
|
self.session = session
|
99
99
|
|
100
|
-
def save_answers(self, path: str):
|
100
|
+
def save_answers(self, path: str, append: bool = False):
|
101
101
|
"""
|
102
102
|
Save the expert answers to a file.
|
103
103
|
|
104
104
|
:param path: The path to save the answers to.
|
105
|
+
:param append: A flag to indicate if the answers should be appended to the file or not.
|
105
106
|
"""
|
106
|
-
|
107
|
-
|
107
|
+
if append:
|
108
|
+
# read the file and append the new answers
|
109
|
+
with open(path + '.json', "r") as f:
|
110
|
+
all_answers = json.load(f)
|
111
|
+
all_answers.extend(self.all_expert_answers)
|
112
|
+
with open(path + '.json', "w") as f:
|
113
|
+
json.dump(all_answers, f)
|
114
|
+
else:
|
115
|
+
with open(path + '.json', "w") as f:
|
116
|
+
json.dump(self.all_expert_answers, f)
|
108
117
|
|
109
118
|
def load_answers(self, path: str):
|
110
119
|
"""
|
ripple_down_rules/rdr.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from copy import copy
|
4
|
+
from copy import copy, deepcopy
|
5
5
|
|
6
6
|
from matplotlib import pyplot as plt
|
7
7
|
from ordered_set import OrderedSet
|
8
8
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
9
|
-
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self
|
9
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple
|
10
10
|
|
11
11
|
from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
|
12
12
|
from .experts import Expert, Human
|
@@ -51,7 +51,7 @@ class RippleDownRules(ABC):
|
|
51
51
|
pass
|
52
52
|
|
53
53
|
@abstractmethod
|
54
|
-
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs)\
|
54
|
+
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
55
55
|
-> Union[Column, CallableExpression]:
|
56
56
|
"""
|
57
57
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
@@ -86,8 +86,6 @@ class RippleDownRules(ABC):
|
|
86
86
|
num_rules: int = 0
|
87
87
|
while not stop_iterating:
|
88
88
|
all_pred = 0
|
89
|
-
all_recall = []
|
90
|
-
all_precision = []
|
91
89
|
if not targets:
|
92
90
|
targets = [None] * len(cases)
|
93
91
|
for case_query in case_queries:
|
@@ -97,14 +95,7 @@ class RippleDownRules(ABC):
|
|
97
95
|
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
98
96
|
target = expert.ask_for_conclusion(case_query, conclusions)
|
99
97
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
100
|
-
|
101
|
-
target = target if isinstance(target, list) else [target]
|
102
|
-
recall = [not yi or (yi in pred_cat) for yi in target]
|
103
|
-
y_type = [type(yi) for yi in target]
|
104
|
-
precision = [(pred in target) or (type(pred) not in y_type) for pred in pred_cat]
|
105
|
-
match = all(recall) and all(precision)
|
106
|
-
all_recall.extend(recall)
|
107
|
-
all_precision.extend(precision)
|
98
|
+
match = self.is_matching(pred_cat, target)
|
108
99
|
if not match:
|
109
100
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
110
101
|
all_pred += int(match)
|
@@ -112,21 +103,43 @@ class RippleDownRules(ABC):
|
|
112
103
|
num_rules = self.start_rule.size
|
113
104
|
self.update_figures()
|
114
105
|
i += 1
|
115
|
-
|
116
|
-
|
117
|
-
|
106
|
+
all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
|
107
|
+
for case, target in zip(cases, targets)]
|
108
|
+
all_pred = sum(all_predictions)
|
109
|
+
print(f"Accuracy: {all_pred}/{len(targets)}")
|
110
|
+
all_predicted = targets and all_pred == len(targets)
|
118
111
|
num_iter_reached = n_iter and i >= n_iter
|
119
112
|
stop_iterating = all_predicted or num_iter_reached
|
120
113
|
if stop_iterating:
|
121
114
|
break
|
122
|
-
print(f"Recall: {sum(all_recall) / len(all_recall)}")
|
123
|
-
print(f"Precision: {sum(all_precision) / len(all_precision)}")
|
124
|
-
print(f"Accuracy: {all_pred}/{len(targets)}")
|
125
115
|
print(f"Finished training in {i} iterations")
|
126
116
|
if animate_tree:
|
127
117
|
plt.ioff()
|
128
118
|
plt.show()
|
129
119
|
|
120
|
+
@staticmethod
|
121
|
+
def calculate_precision_and_recall(pred_cat: List[Column], target: List[Column]) -> Tuple[List[bool], List[bool]]:
|
122
|
+
"""
|
123
|
+
:param pred_cat: The predicted category.
|
124
|
+
:param target: The target category.
|
125
|
+
:return: The precision and recall of the classifier.
|
126
|
+
"""
|
127
|
+
pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
|
128
|
+
target = target if isinstance(target, list) else [target]
|
129
|
+
recall = [not yi or (yi in pred_cat) for yi in target]
|
130
|
+
target_types = [type(yi) for yi in target]
|
131
|
+
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
132
|
+
return precision, recall
|
133
|
+
|
134
|
+
def is_matching(self, pred_cat: List[Column], target: List[Column]) -> bool:
|
135
|
+
"""
|
136
|
+
:param pred_cat: The predicted category.
|
137
|
+
:param target: The target category.
|
138
|
+
:return: Whether the classifier is matching or not.
|
139
|
+
"""
|
140
|
+
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
141
|
+
return all(recall) and all(precision)
|
142
|
+
|
130
143
|
def update_figures(self):
|
131
144
|
"""
|
132
145
|
Update the figures of the classifier.
|
@@ -501,6 +514,7 @@ class MultiClassRDR(RippleDownRules):
|
|
501
514
|
same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
|
502
515
|
combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
|
503
516
|
else {evaluated_rule.conclusion}
|
517
|
+
combined_conclusion = deepcopy(combined_conclusion)
|
504
518
|
for c in same_type_conclusions:
|
505
519
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
506
520
|
self.conclusions.remove(c)
|
@@ -582,7 +596,7 @@ class GeneralRDR(RippleDownRules):
|
|
582
596
|
break
|
583
597
|
return conclusions
|
584
598
|
|
585
|
-
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs)\
|
599
|
+
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
|
586
600
|
-> List[Union[Column, CallableExpression]]:
|
587
601
|
"""
|
588
602
|
Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
|