ripple-down-rules 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datastructures/dataclasses.py +2 -3
- ripple_down_rules/datastructures/enums.py +1 -1
- ripple_down_rules/datastructures/table.py +6 -11
- ripple_down_rules/experts.py +2 -2
- ripple_down_rules/rdr.py +35 -23
- ripple_down_rules/utils.py +19 -14
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.6.dist-info}/METADATA +1 -1
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.6.dist-info}/RECORD +11 -11
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.5.dist-info → ripple_down_rules-0.0.6.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from copy import copy, deepcopy
|
4
3
|
from dataclasses import dataclass
|
5
4
|
|
6
5
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
7
|
-
from typing_extensions import Any, Optional, Type
|
6
|
+
from typing_extensions import Any, Optional, Type
|
8
7
|
|
9
8
|
from .table import create_row, Case
|
10
|
-
from ..utils import get_attribute_name,
|
9
|
+
from ..utils import get_attribute_name, copy_case
|
11
10
|
|
12
11
|
|
13
12
|
@dataclass
|
@@ -219,10 +219,10 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
|
|
219
219
|
value.update(make_set(self[name]))
|
220
220
|
super().__setitem__(name, value)
|
221
221
|
else:
|
222
|
-
super().__setitem__(name, make_set(self[name]))
|
222
|
+
super().__setitem__(name, make_set([self[name], value]))
|
223
223
|
else:
|
224
|
-
setattr(self, name, value)
|
225
224
|
super().__setitem__(name, value)
|
225
|
+
setattr(self, name, self[name])
|
226
226
|
|
227
227
|
def __contains__(self, item):
|
228
228
|
if isinstance(item, (type, Enum)):
|
@@ -428,9 +428,10 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
|
|
428
428
|
row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
|
429
429
|
max_recursion_idx, parent_is_iterable, row)
|
430
430
|
attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
|
431
|
-
|
432
|
-
|
433
|
-
|
431
|
+
if recursion_idx == 0:
|
432
|
+
row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
|
433
|
+
attributes_type_hints=attributes_type_hints)
|
434
|
+
row = row_cls(id_=id(obj), **row)
|
434
435
|
return row
|
435
436
|
|
436
437
|
|
@@ -462,9 +463,6 @@ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, ob
|
|
462
463
|
row[obj_name] = column
|
463
464
|
else:
|
464
465
|
row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
|
465
|
-
if row.__class__.__name__ == "Row":
|
466
|
-
row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
|
467
|
-
row = row_cls(id_=id(obj), **row)
|
468
466
|
return row
|
469
467
|
|
470
468
|
|
@@ -490,14 +488,11 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
490
488
|
raise ValueError(f"Could not determine the range of {name} in {obj}.")
|
491
489
|
attr_row = Row(id_=id(attr_value))
|
492
490
|
column = Column.create(name, range_).from_obj(values, row_obj=obj)
|
493
|
-
attributes_type_hints = {}
|
494
491
|
for idx, val in enumerate(values):
|
495
492
|
sub_attr_row = create_row(val, recursion_idx=recursion_idx,
|
496
493
|
max_recursion_idx=max_recursion_idx,
|
497
494
|
obj_name=obj_name, parent_is_iterable=True)
|
498
495
|
attr_row.update(sub_attr_row)
|
499
|
-
# attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
|
500
|
-
# attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
|
501
496
|
for sub_attr, val in attr_row.items():
|
502
497
|
setattr(column, sub_attr, val)
|
503
498
|
return column, attr_row
|
ripple_down_rules/experts.py
CHANGED
@@ -9,7 +9,7 @@ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type,
|
|
9
9
|
from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
|
10
10
|
from .datastructures.table import show_current_and_corner_cases
|
11
11
|
from .prompt import prompt_user_for_expression, prompt_user_about_case
|
12
|
-
from .utils import get_all_subclasses
|
12
|
+
from .utils import get_all_subclasses, is_iterable
|
13
13
|
|
14
14
|
if TYPE_CHECKING:
|
15
15
|
from .rdr import Rule
|
@@ -125,7 +125,7 @@ class Human(Expert):
|
|
125
125
|
self.all_expert_answers = json.load(f)
|
126
126
|
|
127
127
|
def ask_for_conditions(self, case: Case,
|
128
|
-
targets: Union[List[Column], List[
|
128
|
+
targets: Union[List[Column], List[SQLColumn]],
|
129
129
|
last_evaluated_rule: Optional[Rule] = None) \
|
130
130
|
-> CallableExpression:
|
131
131
|
if not self.use_loaded_answers:
|
ripple_down_rules/rdr.py
CHANGED
@@ -12,7 +12,7 @@ from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQue
|
|
12
12
|
from .experts import Expert, Human
|
13
13
|
from .rules import Rule, SingleClassRule, MultiClassTopRule
|
14
14
|
from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
|
15
|
-
get_hint_for_attribute, SubclassJSONSerializer
|
15
|
+
get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
|
16
16
|
|
17
17
|
|
18
18
|
class RippleDownRules(ABC):
|
@@ -124,10 +124,12 @@ class RippleDownRules(ABC):
|
|
124
124
|
:param target: The target category.
|
125
125
|
:return: The precision and recall of the classifier.
|
126
126
|
"""
|
127
|
-
pred_cat = pred_cat if
|
128
|
-
target = target if
|
127
|
+
pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
|
128
|
+
target = target if is_iterable(target) else [target]
|
129
129
|
recall = [not yi or (yi in pred_cat) for yi in target]
|
130
130
|
target_types = [type(yi) for yi in target]
|
131
|
+
if len(pred_cat) > 1:
|
132
|
+
print(pred_cat)
|
131
133
|
precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
|
132
134
|
return precision, recall
|
133
135
|
|
@@ -358,7 +360,7 @@ class MultiClassRDR(RippleDownRules):
|
|
358
360
|
self.add_conclusion(evaluated_rule)
|
359
361
|
|
360
362
|
if not next_rule:
|
361
|
-
if
|
363
|
+
if not make_set(target).intersection(make_set(self.conclusions)):
|
362
364
|
# Nothing fired and there is a target that should have been in the conclusions
|
363
365
|
self.add_rule_for_case(case, target, expert)
|
364
366
|
# Have to check all rules again to make sure only this new rule fires
|
@@ -509,16 +511,16 @@ class MultiClassRDR(RippleDownRules):
|
|
509
511
|
"""
|
510
512
|
conclusion_types = [type(c) for c in self.conclusions]
|
511
513
|
if type(evaluated_rule.conclusion) not in conclusion_types:
|
512
|
-
self.conclusions.
|
514
|
+
self.conclusions.extend(make_list(evaluated_rule.conclusion))
|
513
515
|
else:
|
514
516
|
same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
|
515
517
|
combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
|
516
518
|
else {evaluated_rule.conclusion}
|
517
|
-
combined_conclusion =
|
519
|
+
combined_conclusion = copy(combined_conclusion)
|
518
520
|
for c in same_type_conclusions:
|
519
521
|
combined_conclusion.update(c if isinstance(c, set) else make_set(c))
|
520
522
|
self.conclusions.remove(c)
|
521
|
-
self.conclusions.extend(combined_conclusion)
|
523
|
+
self.conclusions.extend(make_list(combined_conclusion))
|
522
524
|
|
523
525
|
def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
|
524
526
|
"""
|
@@ -587,11 +589,11 @@ class GeneralRDR(RippleDownRules):
|
|
587
589
|
continue
|
588
590
|
pred_atts = rdr.classify(case_cp)
|
589
591
|
if pred_atts:
|
590
|
-
pred_atts =
|
592
|
+
pred_atts = make_list(pred_atts)
|
591
593
|
pred_atts = [p for p in pred_atts if p not in conclusions]
|
592
594
|
added_attributes = True
|
593
595
|
conclusions.extend(pred_atts)
|
594
|
-
|
596
|
+
GeneralRDR.update_case(case_cp, pred_atts)
|
595
597
|
if not added_attributes:
|
596
598
|
break
|
597
599
|
return conclusions
|
@@ -624,21 +626,27 @@ class GeneralRDR(RippleDownRules):
|
|
624
626
|
if not target:
|
625
627
|
target = expert.ask_for_conclusion(case_query)
|
626
628
|
case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
|
627
|
-
if
|
629
|
+
if is_iterable(target) and not isinstance(target, Column):
|
630
|
+
target_type = type(make_list(target)[0])
|
631
|
+
assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
|
632
|
+
" type")
|
633
|
+
else:
|
634
|
+
target_type = type(target)
|
635
|
+
if target_type not in self.start_rules_dict:
|
628
636
|
conclusions = self.classify(case)
|
629
|
-
self.
|
637
|
+
self.update_case(case_cp, conclusions)
|
630
638
|
new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
|
631
639
|
new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
|
632
|
-
self.start_rules_dict[
|
633
|
-
self.
|
634
|
-
elif not self.case_has_conclusion(case_cp,
|
640
|
+
self.start_rules_dict[target_type] = new_rdr
|
641
|
+
self.update_case(case_cp, new_conclusions, target_type)
|
642
|
+
elif not self.case_has_conclusion(case_cp, target_type):
|
635
643
|
for rdr_type, rdr in self.start_rules_dict.items():
|
636
|
-
if
|
644
|
+
if target_type is not rdr_type:
|
637
645
|
conclusions = rdr.classify(case_cp)
|
638
646
|
else:
|
639
|
-
conclusions = self.start_rules_dict[
|
647
|
+
conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
|
640
648
|
expert, **kwargs)
|
641
|
-
self.
|
649
|
+
self.update_case(case_cp, conclusions, rdr_type)
|
642
650
|
|
643
651
|
return self.classify(case)
|
644
652
|
|
@@ -653,12 +661,14 @@ class GeneralRDR(RippleDownRules):
|
|
653
661
|
return MultiClassRDR()
|
654
662
|
else:
|
655
663
|
return SingleClassRDR()
|
656
|
-
|
664
|
+
elif isinstance(attribute, Column):
|
657
665
|
return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
|
666
|
+
else:
|
667
|
+
return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
|
658
668
|
|
659
669
|
@staticmethod
|
660
|
-
def
|
661
|
-
|
670
|
+
def update_case(case: Union[Case, SQLTable],
|
671
|
+
conclusions: List[Any], attribute_type: Optional[Any] = None):
|
662
672
|
"""
|
663
673
|
Update the case with the conclusions.
|
664
674
|
|
@@ -668,7 +678,7 @@ class GeneralRDR(RippleDownRules):
|
|
668
678
|
"""
|
669
679
|
if not conclusions:
|
670
680
|
return
|
671
|
-
conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
|
681
|
+
conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
|
672
682
|
if len(conclusions) == 0:
|
673
683
|
return
|
674
684
|
if isinstance(case, SQLTable):
|
@@ -677,7 +687,8 @@ class GeneralRDR(RippleDownRules):
|
|
677
687
|
hint, origin, args = get_hint_for_attribute(attr_name, case)
|
678
688
|
if isinstance(attribute, set) or origin == set:
|
679
689
|
attribute = set() if attribute is None else attribute
|
680
|
-
|
690
|
+
for c in conclusions:
|
691
|
+
attribute.update(make_set(c))
|
681
692
|
elif isinstance(attribute, list) or origin == list:
|
682
693
|
attribute = [] if attribute is None else attribute
|
683
694
|
attribute.extend(conclusions)
|
@@ -686,7 +697,8 @@ class GeneralRDR(RippleDownRules):
|
|
686
697
|
else:
|
687
698
|
raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
|
688
699
|
else:
|
689
|
-
|
700
|
+
for c in make_set(conclusions):
|
701
|
+
case.update(c.as_dict)
|
690
702
|
|
691
703
|
@property
|
692
704
|
def names_of_all_types(self) -> List[str]:
|
ripple_down_rules/utils.py
CHANGED
@@ -24,6 +24,24 @@ if TYPE_CHECKING:
|
|
24
24
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
25
25
|
|
26
26
|
|
27
|
+
def make_list(value: Any) -> List:
|
28
|
+
"""
|
29
|
+
Make a list from a value.
|
30
|
+
|
31
|
+
:param value: The value to make a list from.
|
32
|
+
"""
|
33
|
+
return list(value) if is_iterable(value) else [value]
|
34
|
+
|
35
|
+
|
36
|
+
def is_iterable(obj: Any) -> bool:
|
37
|
+
"""
|
38
|
+
Check if an object is iterable.
|
39
|
+
|
40
|
+
:param obj: The object to check.
|
41
|
+
"""
|
42
|
+
return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))
|
43
|
+
|
44
|
+
|
27
45
|
def get_type_from_string(type_path: str):
|
28
46
|
"""
|
29
47
|
Get a type from a string describing its path using the format "module_path.ClassName".
|
@@ -350,20 +368,7 @@ def make_set(value: Any) -> Set:
|
|
350
368
|
|
351
369
|
:param value: The value to make a set from.
|
352
370
|
"""
|
353
|
-
|
354
|
-
return set(value)
|
355
|
-
return {value}
|
356
|
-
|
357
|
-
|
358
|
-
def make_list(value: Any) -> List:
|
359
|
-
"""
|
360
|
-
Make a list from a value.
|
361
|
-
|
362
|
-
:param value: The value to make a list from.
|
363
|
-
"""
|
364
|
-
if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
|
365
|
-
return list(value)
|
366
|
-
return [value]
|
371
|
+
return set(value) if is_iterable(value) else {value}
|
367
372
|
|
368
373
|
|
369
374
|
def make_value_or_raise_error(value: Any) -> Any:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.5
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -1,21 +1,21 @@
|
|
1
1
|
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
ripple_down_rules/datasets.py,sha256=QRB-1BdFTcUNuhgYEuXYx6qQOYlDu03_iLKDBqrcVrQ,4511
|
3
|
-
ripple_down_rules/experts.py,sha256=
|
3
|
+
ripple_down_rules/experts.py,sha256=F3Xx9G3DM-WSGkBql65G6Fztpt3bXfTXpAs06wuxYHQ,12109
|
4
4
|
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
5
|
ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
|
6
|
-
ripple_down_rules/rdr.py,sha256=
|
6
|
+
ripple_down_rules/rdr.py,sha256=y5LmWYTBkq8r9aGcuYgwACE4yEC6PXRadRwcLBjCJoI,34228
|
7
7
|
ripple_down_rules/rules.py,sha256=TLptqvA6I3QlQaVBTYchbgvXm17XWwFJoTmoN0diHm8,10348
|
8
|
-
ripple_down_rules/utils.py,sha256=
|
8
|
+
ripple_down_rules/utils.py,sha256=X0WOWPX57pejUtLchD4NyApQ9nKb6I313qVHlqTYWLs,17020
|
9
9
|
ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
|
10
10
|
ripple_down_rules/datastructures/callable_expression.py,sha256=yEZ6OWzSiWsRtEz_5UquA0inmodFTOWqXWV_a2gg1cg,9110
|
11
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=
|
12
|
-
ripple_down_rules/datastructures/enums.py,sha256=
|
13
|
-
ripple_down_rules/datastructures/table.py,sha256=
|
11
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
|
12
|
+
ripple_down_rules/datastructures/enums.py,sha256=ftcVTVkH0ttKr-kN9eCwAkQkdq04BLCgaZkpSIonAps,4238
|
13
|
+
ripple_down_rules/datastructures/table.py,sha256=h6ZM_RBo6QPXDQiUvlCvO-I8orN_vaXyK2bYO0ns_4I,22816
|
14
14
|
ripple_down_rules/datastructures/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
ripple_down_rules/datastructures/generated/column/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
ripple_down_rules/datastructures/generated/row/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
ripple_down_rules-0.0.
|
18
|
-
ripple_down_rules-0.0.
|
19
|
-
ripple_down_rules-0.0.
|
20
|
-
ripple_down_rules-0.0.
|
21
|
-
ripple_down_rules-0.0.
|
17
|
+
ripple_down_rules-0.0.6.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
18
|
+
ripple_down_rules-0.0.6.dist-info/METADATA,sha256=qHgViKXWjjvMwi-ZxcILRJoBOipWHwGu2sCuTei06Sw,42518
|
19
|
+
ripple_down_rules-0.0.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
20
|
+
ripple_down_rules-0.0.6.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
21
|
+
ripple_down_rules-0.0.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|