ripple-down-rules 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- from copy import copy, deepcopy
4
3
  from dataclasses import dataclass
5
4
 
6
5
  from sqlalchemy.orm import DeclarativeBase as SQLTable
7
- from typing_extensions import Any, Optional, Type, Union
6
+ from typing_extensions import Any, Optional, Type
8
7
 
9
8
  from .table import create_row, Case
10
- from ..utils import get_attribute_name, copy_orm_instance_with_relationships, copy_case
9
+ from ..utils import get_attribute_name, copy_case
11
10
 
12
11
 
13
12
  @dataclass
@@ -17,7 +17,7 @@ class Category(str, Enum):
17
17
 
18
18
  @property
19
19
  def as_dict(self):
20
- return {self.__class__.__name__.lower(): self.value}
20
+ return {self.__class__.__name__.lower(): self}
21
21
 
22
22
 
23
23
  class Stop(Category):
@@ -219,10 +219,10 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
219
219
  value.update(make_set(self[name]))
220
220
  super().__setitem__(name, value)
221
221
  else:
222
- super().__setitem__(name, make_set(self[name]))
222
+ super().__setitem__(name, make_set([self[name], value]))
223
223
  else:
224
- setattr(self, name, value)
225
224
  super().__setitem__(name, value)
225
+ setattr(self, name, self[name])
226
226
 
227
227
  def __contains__(self, item):
228
228
  if isinstance(item, (type, Enum)):
@@ -428,9 +428,10 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
428
428
  row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
429
429
  max_recursion_idx, parent_is_iterable, row)
430
430
  attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
431
- row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
432
- attributes_type_hints=attributes_type_hints)
433
- row = row_cls(id_=id(obj), **row)
431
+ if recursion_idx == 0:
432
+ row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
433
+ attributes_type_hints=attributes_type_hints)
434
+ row = row_cls(id_=id(obj), **row)
434
435
  return row
435
436
 
436
437
 
@@ -462,9 +463,6 @@ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, ob
462
463
  row[obj_name] = column
463
464
  else:
464
465
  row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
465
- if row.__class__.__name__ == "Row":
466
- row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
467
- row = row_cls(id_=id(obj), **row)
468
466
  return row
469
467
 
470
468
 
@@ -490,14 +488,11 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
490
488
  raise ValueError(f"Could not determine the range of {name} in {obj}.")
491
489
  attr_row = Row(id_=id(attr_value))
492
490
  column = Column.create(name, range_).from_obj(values, row_obj=obj)
493
- attributes_type_hints = {}
494
491
  for idx, val in enumerate(values):
495
492
  sub_attr_row = create_row(val, recursion_idx=recursion_idx,
496
493
  max_recursion_idx=max_recursion_idx,
497
494
  obj_name=obj_name, parent_is_iterable=True)
498
495
  attr_row.update(sub_attr_row)
499
- # attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
500
- # attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
501
496
  for sub_attr, val in attr_row.items():
502
497
  setattr(column, sub_attr, val)
503
498
  return column, attr_row
@@ -9,7 +9,7 @@ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type,
9
9
  from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
10
10
  from .datastructures.table import show_current_and_corner_cases
11
11
  from .prompt import prompt_user_for_expression, prompt_user_about_case
12
- from .utils import get_all_subclasses
12
+ from .utils import get_all_subclasses, is_iterable
13
13
 
14
14
  if TYPE_CHECKING:
15
15
  from .rdr import Rule
@@ -125,7 +125,7 @@ class Human(Expert):
125
125
  self.all_expert_answers = json.load(f)
126
126
 
127
127
  def ask_for_conditions(self, case: Case,
128
- targets: Union[List[Column], List[Column]],
128
+ targets: Union[List[Column], List[SQLColumn]],
129
129
  last_evaluated_rule: Optional[Rule] = None) \
130
130
  -> CallableExpression:
131
131
  if not self.use_loaded_answers:
ripple_down_rules/rdr.py CHANGED
@@ -12,7 +12,7 @@ from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQue
12
12
  from .experts import Expert, Human
13
13
  from .rules import Rule, SingleClassRule, MultiClassTopRule
14
14
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
15
- get_hint_for_attribute, SubclassJSONSerializer
15
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
16
16
 
17
17
 
18
18
  class RippleDownRules(ABC):
@@ -124,10 +124,12 @@ class RippleDownRules(ABC):
124
124
  :param target: The target category.
125
125
  :return: The precision and recall of the classifier.
126
126
  """
127
- pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
128
- target = target if isinstance(target, list) else [target]
127
+ pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
128
+ target = target if is_iterable(target) else [target]
129
129
  recall = [not yi or (yi in pred_cat) for yi in target]
130
130
  target_types = [type(yi) for yi in target]
131
+ if len(pred_cat) > 1:
132
+ print(pred_cat)
131
133
  precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
132
134
  return precision, recall
133
135
 
@@ -358,7 +360,7 @@ class MultiClassRDR(RippleDownRules):
358
360
  self.add_conclusion(evaluated_rule)
359
361
 
360
362
  if not next_rule:
361
- if target not in self.conclusions:
363
+ if not make_set(target).intersection(make_set(self.conclusions)):
362
364
  # Nothing fired and there is a target that should have been in the conclusions
363
365
  self.add_rule_for_case(case, target, expert)
364
366
  # Have to check all rules again to make sure only this new rule fires
@@ -509,16 +511,16 @@ class MultiClassRDR(RippleDownRules):
509
511
  """
510
512
  conclusion_types = [type(c) for c in self.conclusions]
511
513
  if type(evaluated_rule.conclusion) not in conclusion_types:
512
- self.conclusions.append(evaluated_rule.conclusion)
514
+ self.conclusions.extend(make_list(evaluated_rule.conclusion))
513
515
  else:
514
516
  same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
515
517
  combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
516
518
  else {evaluated_rule.conclusion}
517
- combined_conclusion = deepcopy(combined_conclusion)
519
+ combined_conclusion = copy(combined_conclusion)
518
520
  for c in same_type_conclusions:
519
521
  combined_conclusion.update(c if isinstance(c, set) else make_set(c))
520
522
  self.conclusions.remove(c)
521
- self.conclusions.extend(combined_conclusion)
523
+ self.conclusions.extend(make_list(combined_conclusion))
522
524
 
523
525
  def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
524
526
  """
@@ -587,11 +589,11 @@ class GeneralRDR(RippleDownRules):
587
589
  continue
588
590
  pred_atts = rdr.classify(case_cp)
589
591
  if pred_atts:
590
- pred_atts = pred_atts if isinstance(pred_atts, list) else [pred_atts]
592
+ pred_atts = make_list(pred_atts)
591
593
  pred_atts = [p for p in pred_atts if p not in conclusions]
592
594
  added_attributes = True
593
595
  conclusions.extend(pred_atts)
594
- self.update_case_with_same_type_conclusions(case_cp, pred_atts)
596
+ GeneralRDR.update_case(case_cp, pred_atts)
595
597
  if not added_attributes:
596
598
  break
597
599
  return conclusions
@@ -624,21 +626,27 @@ class GeneralRDR(RippleDownRules):
624
626
  if not target:
625
627
  target = expert.ask_for_conclusion(case_query)
626
628
  case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
627
- if type(target) not in self.start_rules_dict:
629
+ if is_iterable(target) and not isinstance(target, Column):
630
+ target_type = type(make_list(target)[0])
631
+ assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
632
+ " type")
633
+ else:
634
+ target_type = type(target)
635
+ if target_type not in self.start_rules_dict:
628
636
  conclusions = self.classify(case)
629
- self.update_case_with_same_type_conclusions(case_cp, conclusions)
637
+ self.update_case(case_cp, conclusions)
630
638
  new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
631
639
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
632
- self.start_rules_dict[type(target)] = new_rdr
633
- self.update_case_with_same_type_conclusions(case_cp, new_conclusions, type(target))
634
- elif not self.case_has_conclusion(case_cp, type(target)):
640
+ self.start_rules_dict[target_type] = new_rdr
641
+ self.update_case(case_cp, new_conclusions, target_type)
642
+ elif not self.case_has_conclusion(case_cp, target_type):
635
643
  for rdr_type, rdr in self.start_rules_dict.items():
636
- if type(target) is not rdr_type:
644
+ if target_type is not rdr_type:
637
645
  conclusions = rdr.classify(case_cp)
638
646
  else:
639
- conclusions = self.start_rules_dict[type(target)].fit_case(case_query_cp,
647
+ conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
640
648
  expert, **kwargs)
641
- self.update_case_with_same_type_conclusions(case_cp, conclusions, rdr_type)
649
+ self.update_case(case_cp, conclusions, rdr_type)
642
650
 
643
651
  return self.classify(case)
644
652
 
@@ -653,12 +661,14 @@ class GeneralRDR(RippleDownRules):
653
661
  return MultiClassRDR()
654
662
  else:
655
663
  return SingleClassRDR()
656
- else:
664
+ elif isinstance(attribute, Column):
657
665
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
666
+ else:
667
+ return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
658
668
 
659
669
  @staticmethod
660
- def update_case_with_same_type_conclusions(case: Union[Case, SQLTable],
661
- conclusions: List[Any], attribute_type: Optional[Any] = None):
670
+ def update_case(case: Union[Case, SQLTable],
671
+ conclusions: List[Any], attribute_type: Optional[Any] = None):
662
672
  """
663
673
  Update the case with the conclusions.
664
674
 
@@ -668,7 +678,7 @@ class GeneralRDR(RippleDownRules):
668
678
  """
669
679
  if not conclusions:
670
680
  return
671
- conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
681
+ conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
672
682
  if len(conclusions) == 0:
673
683
  return
674
684
  if isinstance(case, SQLTable):
@@ -677,7 +687,8 @@ class GeneralRDR(RippleDownRules):
677
687
  hint, origin, args = get_hint_for_attribute(attr_name, case)
678
688
  if isinstance(attribute, set) or origin == set:
679
689
  attribute = set() if attribute is None else attribute
680
- attribute.update(*[make_set(c) for c in conclusions])
690
+ for c in conclusions:
691
+ attribute.update(make_set(c))
681
692
  elif isinstance(attribute, list) or origin == list:
682
693
  attribute = [] if attribute is None else attribute
683
694
  attribute.extend(conclusions)
@@ -686,7 +697,8 @@ class GeneralRDR(RippleDownRules):
686
697
  else:
687
698
  raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
688
699
  else:
689
- case.update(*[c.as_dict for c in make_set(conclusions)])
700
+ for c in make_set(conclusions):
701
+ case.update(c.as_dict)
690
702
 
691
703
  @property
692
704
  def names_of_all_types(self) -> List[str]:
@@ -24,6 +24,24 @@ if TYPE_CHECKING:
24
24
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
25
25
 
26
26
 
27
+ def make_list(value: Any) -> List:
28
+ """
29
+ Make a list from a value.
30
+
31
+ :param value: The value to make a list from.
32
+ """
33
+ return list(value) if is_iterable(value) else [value]
34
+
35
+
36
+ def is_iterable(obj: Any) -> bool:
37
+ """
38
+ Check if an object is iterable.
39
+
40
+ :param obj: The object to check.
41
+ """
42
+ return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))
43
+
44
+
27
45
  def get_type_from_string(type_path: str):
28
46
  """
29
47
  Get a type from a string describing its path using the format "module_path.ClassName".
@@ -350,20 +368,7 @@ def make_set(value: Any) -> Set:
350
368
 
351
369
  :param value: The value to make a set from.
352
370
  """
353
- if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
354
- return set(value)
355
- return {value}
356
-
357
-
358
- def make_list(value: Any) -> List:
359
- """
360
- Make a list from a value.
361
-
362
- :param value: The value to make a list from.
363
- """
364
- if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
365
- return list(value)
366
- return [value]
371
+ return set(value) if is_iterable(value) else {value}
367
372
 
368
373
 
369
374
  def make_value_or_raise_error(value: Any) -> Any:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,21 +1,21 @@
1
1
  ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ripple_down_rules/datasets.py,sha256=QRB-1BdFTcUNuhgYEuXYx6qQOYlDu03_iLKDBqrcVrQ,4511
3
- ripple_down_rules/experts.py,sha256=DMTC-E2g1Fs43oyr30MGeGi5-VKBb3RojzzPa8DvCSA,12093
3
+ ripple_down_rules/experts.py,sha256=F3Xx9G3DM-WSGkBql65G6Fztpt3bXfTXpAs06wuxYHQ,12109
4
4
  ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
5
  ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
6
- ripple_down_rules/rdr.py,sha256=nRhrHdQYmvXgfTFLyO_xxwe7GDYuUhaOqZlQtqSUrnE,33751
6
+ ripple_down_rules/rdr.py,sha256=y5LmWYTBkq8r9aGcuYgwACE4yEC6PXRadRwcLBjCJoI,34228
7
7
  ripple_down_rules/rules.py,sha256=TLptqvA6I3QlQaVBTYchbgvXm17XWwFJoTmoN0diHm8,10348
8
- ripple_down_rules/utils.py,sha256=WlUXTf-B45SPEKpDBVb9QPQWS54250MGAY5xl9MIhR4,16944
8
+ ripple_down_rules/utils.py,sha256=X0WOWPX57pejUtLchD4NyApQ9nKb6I313qVHlqTYWLs,17020
9
9
  ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
10
10
  ripple_down_rules/datastructures/callable_expression.py,sha256=yEZ6OWzSiWsRtEz_5UquA0inmodFTOWqXWV_a2gg1cg,9110
11
- ripple_down_rules/datastructures/dataclasses.py,sha256=z_9B7Nj_MIf2Iyrs5VeUhXhYxwaqnuKVjgwxhZZTygY,2525
12
- ripple_down_rules/datastructures/enums.py,sha256=6Mh55_8QRuXyYZXtonWr01VBgLP-jYp91K_8hIgh8u8,4244
13
- ripple_down_rules/datastructures/table.py,sha256=Tq7savAaFoB7n7x6-u2Vz6NfvjLBzPxsqblR4cHujzM,23161
11
+ ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
12
+ ripple_down_rules/datastructures/enums.py,sha256=ftcVTVkH0ttKr-kN9eCwAkQkdq04BLCgaZkpSIonAps,4238
13
+ ripple_down_rules/datastructures/table.py,sha256=h6ZM_RBo6QPXDQiUvlCvO-I8orN_vaXyK2bYO0ns_4I,22816
14
14
  ripple_down_rules/datastructures/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  ripple_down_rules/datastructures/generated/column/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  ripple_down_rules/datastructures/generated/row/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- ripple_down_rules-0.0.5.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
18
- ripple_down_rules-0.0.5.dist-info/METADATA,sha256=qwUDSDi3YlBsg1RlYOEtFn5YQ3x7yNizCAKm72_8YPU,42518
19
- ripple_down_rules-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
20
- ripple_down_rules-0.0.5.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
21
- ripple_down_rules-0.0.5.dist-info/RECORD,,
17
+ ripple_down_rules-0.0.6.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
18
+ ripple_down_rules-0.0.6.dist-info/METADATA,sha256=qHgViKXWjjvMwi-ZxcILRJoBOipWHwGu2sCuTei06Sw,42518
19
+ ripple_down_rules-0.0.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
20
+ ripple_down_rules-0.0.6.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
21
+ ripple_down_rules-0.0.6.dist-info/RECORD,,