ripple-down-rules 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/PKG-INFO +1 -1
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/pyproject.toml +1 -1
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/dataclasses.py +2 -3
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/enums.py +1 -1
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/table.py +8 -13
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/experts.py +14 -5
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/rdr.py +66 -40
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/utils.py +19 -14
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_rdr.py +31 -13
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_rdr_alchemy.py +1 -1
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/LICENSE +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/README.md +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/setup.cfg +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/__init__.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datasets.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/__init__.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/callable_expression.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/generated/__init__.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/failures.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/prompt.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/rules.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_json_serialization.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_relational_rdr.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_relational_rdr_alchemy.py +0 -0
- {ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
|
|
6
6
|
|
7
7
|
[project]
|
8
8
|
name = "ripple_down_rules"
|
9
|
-
version = "0.0.
|
9
|
+
version = "0.0.6"
|
10
10
|
description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
|
11
11
|
readme = "README.md"
|
12
12
|
authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
|
@@ -1,13 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from copy import copy, deepcopy
|
4
3
|
from dataclasses import dataclass
|
5
4
|
|
6
5
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
-
from typing_extensions import Any, Optional, Type
|
6
|
+
from typing_extensions import Any, Optional, Type
|
8
7
|
|
9
8
|
from .table import create_row, Case
|
10
|
-
from ..utils import get_attribute_name,
|
9
|
+
from ..utils import get_attribute_name, copy_case
|
11
10
|
|
12
11
|
|
13
12
|
@dataclass
|
{ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/table.py
RENAMED
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import time
|
5
5
|
from abc import ABC
|
6
6
|
from collections import UserDict
|
7
|
-
from copy import deepcopy
|
7
|
+
from copy import deepcopy, copy
|
8
8
|
from dataclasses import dataclass
|
9
9
|
from enum import Enum
|
10
10
|
|
@@ -78,7 +78,7 @@ class SubClassFactory:
|
|
78
78
|
parent_class_alias = cls.__name__ + "_"
|
79
79
|
imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
|
80
80
|
class_code = f"class {name}({parent_class_alias}):\n"
|
81
|
-
class_attributes =
|
81
|
+
class_attributes = copy(class_attributes) if class_attributes else {}
|
82
82
|
class_attributes.update({"_value_range": range_})
|
83
83
|
for key, value in class_attributes.items():
|
84
84
|
if value is not None:
|
@@ -219,10 +219,10 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
|
|
219
219
|
value.update(make_set(self[name]))
|
220
220
|
super().__setitem__(name, value)
|
221
221
|
else:
|
222
|
-
super().__setitem__(name, make_set(self[name]))
|
222
|
+
super().__setitem__(name, make_set([self[name], value]))
|
223
223
|
else:
|
224
|
-
setattr(self, name, value)
|
225
224
|
super().__setitem__(name, value)
|
225
|
+
setattr(self, name, self[name])
|
226
226
|
|
227
227
|
def __contains__(self, item):
|
228
228
|
if isinstance(item, (type, Enum)):
|
@@ -428,9 +428,10 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
428
428
|
row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
|
429
429
|
max_recursion_idx, parent_is_iterable, row)
|
430
430
|
attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
|
431
|
-
|
432
|
-
|
433
|
-
|
431
|
+
if recursion_idx == 0:
|
432
|
+
row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
|
433
|
+
attributes_type_hints=attributes_type_hints)
|
434
|
+
row = row_cls(id_=id(obj), **row)
|
434
435
|
return row
|
435
436
|
|
436
437
|
|
@@ -462,9 +463,6 @@ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, ob
|
|
462
463
|
row[obj_name] = column
|
463
464
|
else:
|
464
465
|
row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
|
465
|
-
if row.__class__.__name__ == "Row":
|
466
|
-
row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
|
467
|
-
row = row_cls(id_=id(obj), **row)
|
468
466
|
return row
|
469
467
|
|
470
468
|
|
@@ -490,14 +488,11 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
490
488
|
raise ValueError(f"Could not determine the range of {name} in {obj}.")
|
491
489
|
attr_row = Row(id_=id(attr_value))
|
492
490
|
column = Column.create(name, range_).from_obj(values, row_obj=obj)
|
493
|
-
attributes_type_hints = {}
|
494
491
|
for idx, val in enumerate(values):
|
495
492
|
sub_attr_row = create_row(val, recursion_idx=recursion_idx,
|
496
493
|
max_recursion_idx=max_recursion_idx,
|
497
494
|
obj_name=obj_name, parent_is_iterable=True)
|
498
495
|
attr_row.update(sub_attr_row)
|
499
|
-
# attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
|
500
|
-
# attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
|
501
496
|
for sub_attr, val in attr_row.items():
|
502
497
|
setattr(column, sub_attr, val)
|
503
498
|
return column, attr_row
|
@@ -9,7 +9,7 @@ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type,
|
|
9
9
|
from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
|
10
10
|
from .datastructures.table import show_current_and_corner_cases
|
11
11
|
from .prompt import prompt_user_for_expression, prompt_user_about_case
|
12
|
-
from .utils import get_all_subclasses
|
12
|
+
from .utils import get_all_subclasses, is_iterable
|
13
13
|
|
14
14
|
if TYPE_CHECKING:
|
15
15
|
from .rdr import Rule
|
@@ -97,14 +97,23 @@ class Human(Expert):
|
|
97
97
|
self.use_loaded_answers = use_loaded_answers
|
98
98
|
self.session = session
|
99
99
|
|
100
|
-
def save_answers(self, path: str):
|
100
|
+
def save_answers(self, path: str, append: bool = False):
|
101
101
|
"""
|
102
102
|
Save the expert answers to a file.
|
103
103
|
|
104
104
|
:param path: The path to save the answers to.
|
105
|
+
:param append: A flag to indicate if the answers should be appended to the file or not.
|
105
106
|
"""
|
106
|
-
|
107
|
-
|
107
|
+
if append:
|
108
|
+
# read the file and append the new answers
|
109
|
+
with open(path + '.json', "r") as f:
|
110
|
+
all_answers = json.load(f)
|
111
|
+
all_answers.extend(self.all_expert_answers)
|
112
|
+
with open(path + '.json', "w") as f:
|
113
|
+
json.dump(all_answers, f)
|
114
|
+
else:
|
115
|
+
with open(path + '.json', "w") as f:
|
116
|
+
json.dump(self.all_expert_answers, f)
|
108
117
|
|
109
118
|
def load_answers(self, path: str):
|
110
119
|
"""
|
@@ -116,7 +125,7 @@ class Human(Expert):
|
|
116
125
|
self.all_expert_answers = json.load(f)
|
117
126
|
|
118
127
|
def ask_for_conditions(self, case: Case,
|
119
|
-
targets: Union[List[Column], List[
|
128
|
+
targets: Union[List[Column], List[SQLColumn]],
|
120
129
|
last_evaluated_rule: Optional[Rule] = None) \
|
121
130
|
-> CallableExpression:
|
122
131
|
if not self.use_loaded_answers:
|
@@ -1,18 +1,18 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from copy import copy
|
4
|
+
from copy import copy, deepcopy
|
5
5
|
|
6
6
|
from matplotlib import pyplot as plt
|
7
7
|
from ordered_set import OrderedSet
|
8
8
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
9
|
-
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self
|
9
|
+
from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple
|
10
10
|
|
11
11
|
from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
|
12
12
|
from .experts import Expert, Human
|
13
13
|
from .rules import Rule, SingleClassRule, MultiClassTopRule
|
14
14
|
from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
|
15
|
-
get_hint_for_attribute, SubclassJSONSerializer
|
15
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
|
16
16
|
|
17
17
|
|
18
18
|
class RippleDownRules(ABC):
|
@@ -51,7 +51,7 @@ class RippleDownRules(ABC):
|
|
51
51
|
pass
|
52
52
|
|
53
53
|
@abstractmethod
|
54
|
-
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs)\
|
54
|
+
def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
55
55
|
-> Union[Column, CallableExpression]:
|
56
56
|
"""
|
57
57
|
Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
|
@@ -86,8 +86,6 @@ class RippleDownRules(ABC):
|
|
86
86
|
num_rules: int = 0
|
87
87
|
while not stop_iterating:
|
88
88
|
all_pred = 0
|
89
|
-
all_recall = []
|
90
|
-
all_precision = []
|
91
89
|
if not targets:
|
92
90
|
targets = [None] * len(cases)
|
93
91
|
for case_query in case_queries:
|
@@ -97,14 +95,7 @@ class RippleDownRules(ABC):
|
|
97
95
|
conclusions = self.classify(case) if self.start_rule and self.start_rule.conditions else []
|
98
96
|
target = expert.ask_for_conclusion(case_query, conclusions)
|
99
97
|
pred_cat = self.fit_case(case_query, expert=expert, **kwargs_for_fit_case)
|
100
|
-
|
101
|
-
target = target if isinstance(target, list) else [target]
|
102
|
-
recall = [not yi or (yi in pred_cat) for yi in target]
|
103
|
-
y_type = [type(yi) for yi in target]
|
104
|
-
precision = [(pred in target) or (type(pred) not in y_type) for pred in pred_cat]
|
105
|
-
match = all(recall) and all(precision)
|
106
|
-
all_recall.extend(recall)
|
107
|
-
all_precision.extend(precision)
|
98
|
+
match = self.is_matching(pred_cat, target)
|
108
99
|
if not match:
|
109
100
|
print(f"Predicted: {pred_cat} but expected: {target}")
|
110
101
|
all_pred += int(match)
|
@@ -112,21 +103,45 @@ class RippleDownRules(ABC):
|
|
112
103
|
num_rules = self.start_rule.size
|
113
104
|
self.update_figures()
|
114
105
|
i += 1
|
115
|
-
|
116
|
-
|
117
|
-
|
106
|
+
all_predictions = [1 if self.is_matching(self.classify(case), target) else 0
|
107
|
+
for case, target in zip(cases, targets)]
|
108
|
+
all_pred = sum(all_predictions)
|
109
|
+
print(f"Accuracy: {all_pred}/{len(targets)}")
|
110
|
+
all_predicted = targets and all_pred == len(targets)
|
118
111
|
num_iter_reached = n_iter and i >= n_iter
|
119
112
|
stop_iterating = all_predicted or num_iter_reached
|
120
113
|
if stop_iterating:
|
121
114
|
break
|
122
|
-
print(f"Recall: {sum(all_recall) / len(all_recall)}")
|
123
|
-
print(f"Precision: {sum(all_precision) / len(all_precision)}")
|
124
|
-
print(f"Accuracy: {all_pred}/{len(targets)}")
|
125
115
|
print(f"Finished training in {i} iterations")
|
126
116
|
if animate_tree:
|
127
117
|
plt.ioff()
|
128
118
|
plt.show()
|
129
119
|
|
120
|
+
@staticmethod
|
121
|
+
def calculate_precision_and_recall(pred_cat: List[Column], target: List[Column]) -> Tuple[List[bool], List[bool]]:
|
122
|
+
"""
|
123
|
+
:param pred_cat: The predicted category.
|
124
|
+
:param target: The target category.
|
125
|
+
:return: The precision and recall of the classifier.
|
126
|
+
"""
|
127
|
+
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
128
|
+
target = target if is_iterable(target) else [target]
|
129
|
+
recall = [not yi or (yi in pred_cat) for yi in target]
|
130
|
+
target_types = [type(yi) for yi in target]
|
131
|
+
if len(pred_cat) > 1:
|
132
|
+
print(pred_cat)
|
133
|
+
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
134
|
+
return precision, recall
|
135
|
+
|
136
|
+
def is_matching(self, pred_cat: List[Column], target: List[Column]) -> bool:
|
137
|
+
"""
|
138
|
+
:param pred_cat: The predicted category.
|
139
|
+
:param target: The target category.
|
140
|
+
:return: Whether the classifier is matching or not.
|
141
|
+
"""
|
142
|
+
precision, recall = self.calculate_precision_and_recall(pred_cat, target)
|
143
|
+
return all(recall) and all(precision)
|
144
|
+
|
130
145
|
def update_figures(self):
|
131
146
|
"""
|
132
147
|
Update the figures of the classifier.
|
@@ -345,7 +360,7 @@ class MultiClassRDR(RippleDownRules):
|
|
345
360
|
self.add_conclusion(evaluated_rule)
|
346
361
|
|
347
362
|
if not next_rule:
|
348
|
-
if
|
363
|
+
if not make_set(target).intersection(make_set(self.conclusions)):
|
349
364
|
# Nothing fired and there is a target that should have been in the conclusions
|
350
365
|
self.add_rule_for_case(case, target, expert)
|
351
366
|
# Have to check all rules again to make sure only this new rule fires
|
@@ -496,15 +511,16 @@ class MultiClassRDR(RippleDownRules):
|
|
496
511
|
"""
|
497
512
|
conclusion_types = [type(c) for c in self.conclusions]
|
498
513
|
if type(evaluated_rule.conclusion) not in conclusion_types:
|
499
|
-
self.conclusions.
|
514
|
+
self.conclusions.extend(make_list(evaluated_rule.conclusion))
|
500
515
|
else:
|
501
516
|
same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
|
502
517
|
combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
|
503
518
|
else {evaluated_rule.conclusion}
|
519
|
+
combined_conclusion = copy(combined_conclusion)
|
504
520
|
for c in same_type_conclusions:
|
505
521
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
506
522
|
self.conclusions.remove(c)
|
507
|
-
self.conclusions.extend(combined_conclusion)
|
523
|
+
self.conclusions.extend(make_list(combined_conclusion))
|
508
524
|
|
509
525
|
def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
|
510
526
|
"""
|
@@ -573,16 +589,16 @@ class GeneralRDR(RippleDownRules):
|
|
573
589
|
continue
|
574
590
|
pred_atts = rdr.classify(case_cp)
|
575
591
|
if pred_atts:
|
576
|
-
pred_atts =
|
592
|
+
pred_atts = make_list(pred_atts)
|
577
593
|
pred_atts = [p for p in pred_atts if p not in conclusions]
|
578
594
|
added_attributes = True
|
579
595
|
conclusions.extend(pred_atts)
|
580
|
-
|
596
|
+
GeneralRDR.update_case(case_cp, pred_atts)
|
581
597
|
if not added_attributes:
|
582
598
|
break
|
583
599
|
return conclusions
|
584
600
|
|
585
|
-
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs)\
|
601
|
+
def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
|
586
602
|
-> List[Union[Column, CallableExpression]]:
|
587
603
|
"""
|
588
604
|
Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
|
@@ -610,21 +626,27 @@ class GeneralRDR(RippleDownRules):
|
|
610
626
|
if not target:
|
611
627
|
target = expert.ask_for_conclusion(case_query)
|
612
628
|
case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
|
613
|
-
if
|
629
|
+
if is_iterable(target) and not isinstance(target, Column):
|
630
|
+
target_type = type(make_list(target)[0])
|
631
|
+
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
632
|
+
" type")
|
633
|
+
else:
|
634
|
+
target_type = type(target)
|
635
|
+
if target_type not in self.start_rules_dict:
|
614
636
|
conclusions = self.classify(case)
|
615
|
-
self.
|
637
|
+
self.update_case(case_cp, conclusions)
|
616
638
|
new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
|
617
639
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
618
|
-
self.start_rules_dict[
|
619
|
-
self.
|
620
|
-
elif not self.case_has_conclusion(case_cp,
|
640
|
+
self.start_rules_dict[target_type] = new_rdr
|
641
|
+
self.update_case(case_cp, new_conclusions, target_type)
|
642
|
+
elif not self.case_has_conclusion(case_cp, target_type):
|
621
643
|
for rdr_type, rdr in self.start_rules_dict.items():
|
622
|
-
if
|
644
|
+
if target_type is not rdr_type:
|
623
645
|
conclusions = rdr.classify(case_cp)
|
624
646
|
else:
|
625
|
-
conclusions = self.start_rules_dict[
|
647
|
+
conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
|
626
648
|
expert, **kwargs)
|
627
|
-
self.
|
649
|
+
self.update_case(case_cp, conclusions, rdr_type)
|
628
650
|
|
629
651
|
return self.classify(case)
|
630
652
|
|
@@ -639,12 +661,14 @@ class GeneralRDR(RippleDownRules):
|
|
639
661
|
return MultiClassRDR()
|
640
662
|
else:
|
641
663
|
return SingleClassRDR()
|
642
|
-
|
664
|
+
elif isinstance(attribute, Column):
|
643
665
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
666
|
+
else:
|
667
|
+
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
644
668
|
|
645
669
|
@staticmethod
|
646
|
-
def
|
647
|
-
|
670
|
+
def update_case(case: Union[Case, SQLTable],
|
671
|
+
conclusions: List[Any], attribute_type: Optional[Any] = None):
|
648
672
|
"""
|
649
673
|
Update the case with the conclusions.
|
650
674
|
|
@@ -654,7 +678,7 @@ class GeneralRDR(RippleDownRules):
|
|
654
678
|
"""
|
655
679
|
if not conclusions:
|
656
680
|
return
|
657
|
-
conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
|
681
|
+
conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
|
658
682
|
if len(conclusions) == 0:
|
659
683
|
return
|
660
684
|
if isinstance(case, SQLTable):
|
@@ -663,7 +687,8 @@ class GeneralRDR(RippleDownRules):
|
|
663
687
|
hint, origin, args = get_hint_for_attribute(attr_name, case)
|
664
688
|
if isinstance(attribute, set) or origin == set:
|
665
689
|
attribute = set() if attribute is None else attribute
|
666
|
-
|
690
|
+
for c in conclusions:
|
691
|
+
attribute.update(make_set(c))
|
667
692
|
elif isinstance(attribute, list) or origin == list:
|
668
693
|
attribute = [] if attribute is None else attribute
|
669
694
|
attribute.extend(conclusions)
|
@@ -672,7 +697,8 @@ class GeneralRDR(RippleDownRules):
|
|
672
697
|
else:
|
673
698
|
raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
|
674
699
|
else:
|
675
|
-
|
700
|
+
for c in make_set(conclusions):
|
701
|
+
case.update(c.as_dict)
|
676
702
|
|
677
703
|
@property
|
678
704
|
def names_of_all_types(self) -> List[str]:
|
@@ -24,6 +24,24 @@ if TYPE_CHECKING:
|
|
24
24
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
25
25
|
|
26
26
|
|
27
|
+
def make_list(value: Any) -> List:
|
28
|
+
"""
|
29
|
+
Make a list from a value.
|
30
|
+
|
31
|
+
:param value: The value to make a list from.
|
32
|
+
"""
|
33
|
+
return list(value) if is_iterable(value) else [value]
|
34
|
+
|
35
|
+
|
36
|
+
def is_iterable(obj: Any) -> bool:
|
37
|
+
"""
|
38
|
+
Check if an object is iterable.
|
39
|
+
|
40
|
+
:param obj: The object to check.
|
41
|
+
"""
|
42
|
+
return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))
|
43
|
+
|
44
|
+
|
27
45
|
def get_type_from_string(type_path: str):
|
28
46
|
"""
|
29
47
|
Get a type from a string describing its path using the format "module_path.ClassName".
|
@@ -350,20 +368,7 @@ def make_set(value: Any) -> Set:
|
|
350
368
|
|
351
369
|
:param value: The value to make a set from.
|
352
370
|
"""
|
353
|
-
|
354
|
-
return set(value)
|
355
|
-
return {value}
|
356
|
-
|
357
|
-
|
358
|
-
def make_list(value: Any) -> List:
|
359
|
-
"""
|
360
|
-
Make a list from a value.
|
361
|
-
|
362
|
-
:param value: The value to make a list from.
|
363
|
-
"""
|
364
|
-
if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
|
365
|
-
return list(value)
|
366
|
-
return [value]
|
371
|
+
return set(value) if is_iterable(value) else {value}
|
367
372
|
|
368
373
|
|
369
374
|
def make_value_or_raise_error(value: Any) -> Any:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -3,7 +3,7 @@ from unittest import TestCase, skip
|
|
3
3
|
|
4
4
|
from typing_extensions import List
|
5
5
|
|
6
|
-
from ripple_down_rules.datasets import
|
6
|
+
from ripple_down_rules.datasets import Habitat, SpeciesCol as Species
|
7
7
|
from ripple_down_rules.datasets import load_zoo_dataset
|
8
8
|
from ripple_down_rules.datastructures import Case, MCRDRMode, \
|
9
9
|
Row, Column, Category, CaseQuery
|
@@ -102,7 +102,7 @@ class TestRDR(TestCase):
|
|
102
102
|
expert.load_answers(filename)
|
103
103
|
mcrdr = MultiClassRDR()
|
104
104
|
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
105
|
-
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
105
|
+
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree, n_iter=1)
|
106
106
|
render_tree(mcrdr.start_rule, use_dot_exporter=True,
|
107
107
|
filename=self.test_results_dir + f"/mcrdr_stop_only")
|
108
108
|
cats = mcrdr.classify(self.all_cases[50])
|
@@ -117,13 +117,22 @@ class TestRDR(TestCase):
|
|
117
117
|
use_loaded_answers = True
|
118
118
|
draw_tree = False
|
119
119
|
save_answers = False
|
120
|
+
append = False
|
120
121
|
filename = self.expert_answers_dir + "/mcrdr_stop_plus_rule_expert_answers_fit"
|
121
122
|
expert = Human(use_loaded_answers=use_loaded_answers)
|
122
123
|
if use_loaded_answers:
|
123
124
|
expert.load_answers(filename)
|
124
125
|
mcrdr = MultiClassRDR(mode=MCRDRMode.StopPlusRule)
|
125
126
|
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
126
|
-
|
127
|
+
try:
|
128
|
+
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
129
|
+
# catch pop from empty list error
|
130
|
+
except IndexError as e:
|
131
|
+
if append:
|
132
|
+
expert.use_loaded_answers = False
|
133
|
+
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
134
|
+
else:
|
135
|
+
raise e
|
127
136
|
render_tree(mcrdr.start_rule, use_dot_exporter=True,
|
128
137
|
filename=self.test_results_dir + f"/mcrdr_stop_plus_rule")
|
129
138
|
cats = mcrdr.classify(self.all_cases[50])
|
@@ -132,19 +141,28 @@ class TestRDR(TestCase):
|
|
132
141
|
if save_answers:
|
133
142
|
cwd = os.getcwd()
|
134
143
|
file = os.path.join(cwd, filename)
|
135
|
-
expert.save_answers(file)
|
144
|
+
expert.save_answers(file, append=append)
|
136
145
|
|
137
146
|
def test_fit_mcrdr_stop_plus_rule_combined(self):
|
138
147
|
use_loaded_answers = True
|
139
148
|
save_answers = False
|
140
149
|
draw_tree = False
|
150
|
+
append = False
|
141
151
|
filename = self.expert_answers_dir + "/mcrdr_stop_plus_rule_combined_expert_answers_fit"
|
142
152
|
expert = Human(use_loaded_answers=use_loaded_answers)
|
143
153
|
if use_loaded_answers:
|
144
154
|
expert.load_answers(filename)
|
145
155
|
mcrdr = MultiClassRDR(mode=MCRDRMode.StopPlusRuleCombined)
|
146
156
|
case_queries = [CaseQuery(case, target=target) for case, target in zip(self.all_cases, self.targets)]
|
147
|
-
|
157
|
+
try:
|
158
|
+
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
159
|
+
# catch pop from empty list error
|
160
|
+
except IndexError as e:
|
161
|
+
if append:
|
162
|
+
expert.use_loaded_answers = False
|
163
|
+
mcrdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
164
|
+
else:
|
165
|
+
raise e
|
148
166
|
render_tree(mcrdr.start_rule, use_dot_exporter=True,
|
149
167
|
filename=self.test_results_dir + f"/mcrdr_stop_plus_rule_combined")
|
150
168
|
cats = mcrdr.classify(self.all_cases[50])
|
@@ -153,7 +171,7 @@ class TestRDR(TestCase):
|
|
153
171
|
if save_answers:
|
154
172
|
cwd = os.getcwd()
|
155
173
|
file = os.path.join(cwd, filename)
|
156
|
-
expert.save_answers(file)
|
174
|
+
expert.save_answers(file, append=append)
|
157
175
|
|
158
176
|
@skip("Extra conclusions loaded answers are not working with new prompt interface")
|
159
177
|
def test_classify_mcrdr_with_extra_conclusions(self):
|
@@ -235,19 +253,19 @@ class TestRDR(TestCase):
|
|
235
253
|
def get_habitat(x: Row, t: Category):
|
236
254
|
all_habs = []
|
237
255
|
if t == Species.mammal and x["aquatic"] == 0:
|
238
|
-
all_habs.append(Habitat.land)
|
256
|
+
all_habs.append({Habitat.land})
|
239
257
|
elif t == Species.bird:
|
240
|
-
all_habs.append(Habitat.land)
|
258
|
+
all_habs.append({Habitat.land})
|
241
259
|
if x["airborne"] == 1:
|
242
|
-
all_habs[-1].update(Habitat.air)
|
260
|
+
all_habs[-1].update({Habitat.air})
|
243
261
|
if x["aquatic"] == 1:
|
244
|
-
all_habs[-1].update(Habitat.water)
|
262
|
+
all_habs[-1].update({Habitat.water})
|
245
263
|
elif t == Species.fish:
|
246
|
-
all_habs.append(Habitat.water)
|
264
|
+
all_habs.append({Habitat.water})
|
247
265
|
elif t == Species.molusc:
|
248
|
-
all_habs.append(Habitat.land)
|
266
|
+
all_habs.append({Habitat.land})
|
249
267
|
if x["aquatic"] == 1:
|
250
|
-
all_habs[-1].update(Habitat.water)
|
268
|
+
all_habs[-1].update({Habitat.water})
|
251
269
|
return all_habs + [t]
|
252
270
|
|
253
271
|
n = 20
|
@@ -116,7 +116,7 @@ class TestAlchemyRDR(TestCase):
|
|
116
116
|
for case, attributes, targets in zip(self.all_cases[:n], all_attributes, habitat_targets):
|
117
117
|
for attr, target in zip(attributes, targets):
|
118
118
|
case_queries.append(CaseQuery(case, attr, target=target))
|
119
|
-
grdr.fit(case_queries, expert=expert, animate_tree=draw_tree
|
119
|
+
grdr.fit(case_queries, expert=expert, animate_tree=draw_tree)
|
120
120
|
for rule in grdr.start_rules:
|
121
121
|
render_tree(rule, use_dot_exporter=True,
|
122
122
|
filename=self.test_results_dir + f"/grdr_{type(rule.conclusion).__name__}")
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules/datastructures/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/SOURCES.txt
RENAMED
File without changes
|
File without changes
|
{ripple_down_rules-0.0.4 → ripple_down_rules-0.0.6}/src/ripple_down_rules.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|