ripple-down-rules 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,7 +81,8 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
81
81
 
82
82
  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
83
83
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
84
- targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
84
+ # targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
85
+ targets = [Species.from_str(category_id_to_name[i]) for i in y.values.flatten()]
85
86
  return all_cases, targets
86
87
 
87
88
 
@@ -104,8 +105,8 @@ class Habitat(Category):
104
105
  air = "air"
105
106
 
106
107
 
107
- SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
108
- HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
108
+ # SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
109
+ # HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
109
110
 
110
111
 
111
112
  class Base(sqlalchemy.orm.DeclarativeBase):
@@ -162,9 +162,8 @@ class CallableExpression(SubclassJSONSerializer):
162
162
  prev_e = e
163
163
  return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
164
164
 
165
- def to_json(self) -> Dict[str, Any]:
166
- return {**SubclassJSONSerializer.to_json(self),
167
- "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
165
+ def _to_json(self) -> Dict[str, Any]:
166
+ return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
168
167
 
169
168
  @classmethod
170
169
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
@@ -1,13 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- from copy import copy, deepcopy
4
3
  from dataclasses import dataclass
5
4
 
6
5
  from sqlalchemy.orm import DeclarativeBase as SQLTable
7
- from typing_extensions import Any, Optional, Type, Union
6
+ from typing_extensions import Any, Optional, Type
8
7
 
9
8
  from .table import create_row, Case
10
- from ..utils import get_attribute_name, copy_orm_instance_with_relationships, copy_case
9
+ from ..utils import get_attribute_name, copy_case
11
10
 
12
11
 
13
12
  @dataclass
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  from enum import auto, Enum
4
4
 
5
- from typing_extensions import List
5
+ from typing_extensions import List, Dict, Any
6
6
 
7
+ from ripple_down_rules.utils import SubclassJSONSerializer
7
8
 
8
- class Category(str, Enum):
9
+
10
+ class Category(str, SubclassJSONSerializer, Enum):
9
11
 
10
12
  @classmethod
11
13
  def from_str(cls, value: str) -> Category:
@@ -19,6 +21,13 @@ class Category(str, Enum):
19
21
  def as_dict(self):
20
22
  return {self.__class__.__name__.lower(): self.value}
21
23
 
24
+ def _to_json(self) -> Dict[str, Any]:
25
+ return self.as_dict
26
+
27
+ @classmethod
28
+ def _from_json(cls, data: Dict[str, Any]) -> Category:
29
+ return cls.from_str(data[cls.__name__.lower()])
30
+
22
31
 
23
32
  class Stop(Category):
24
33
  """
@@ -21,165 +21,7 @@ if TYPE_CHECKING:
21
21
  from .callable_expression import CallableExpression
22
22
 
23
23
 
24
- class SubClassFactory:
25
- """
26
- A custom set class that is used to add other attributes to the set. This is similar to a table where the set is the
27
- table, the attributes are the columns, and the values are the rows.
28
- """
29
- _value_range: set
30
- """
31
- The range of the attribute, this can be a set of possible values or a range of numeric values (int, float).
32
- """
33
- _registry: Dict[(str, type), Type[SubClassFactory]] = {}
34
- """
35
- A dictionary of all dynamically created subclasses of this class.
36
- """
37
- _generated_classes_dir: str = os.path.dirname(os.path.abspath(__file__)) + "/generated"
38
-
39
- @classmethod
40
- def create(cls, name: str, range_: set, class_attributes: Optional[Dict[str, Any]] = None,
41
- default_values: bool = True,
42
- attributes_type_hints: Optional[Dict[str, Type]] = None) -> Type[SubClassFactory]:
43
- """
44
- Create a new subclass.
45
-
46
- :param name: The name of the subclass.
47
- :param range_: The range of the subclass values.
48
- :param class_attributes: The attributes of the new subclass.
49
- :param default_values: Boolean indicating whether to add default values to the subclass attributes or not.
50
- :param attributes_type_hints: The type hints of the subclass attributes.
51
- :return: The new subclass.
52
- """
53
- existing_class = cls._get_and_update_subclass(name, range_)
54
- if existing_class:
55
- return existing_class
56
-
57
- new_attribute_type = cls._create_class_in_new_python_file_and_import_it(name, range_, default_values,
58
- class_attributes, attributes_type_hints)
59
-
60
- cls.register(new_attribute_type)
61
-
62
- return new_attribute_type
63
-
64
- @classmethod
65
- def _create_class_in_new_python_file_and_import_it(cls, name: str, range_: set, default_values: bool = True,
66
- class_attributes: Optional[Dict[str, Any]] = None,
67
- attributes_type_hints: Optional[Dict[str, Type]] = None)\
68
- -> Type[SubClassFactory]:
69
- def get_type_import(value_type: Any) -> Tuple[str, str]:
70
- if value_type is type(None):
71
- return "from types import NoneType\n", "NoneType"
72
- elif value_type.__module__ != "builtins":
73
- value_type_alias = f"{value_type.__name__}_"
74
- return f"from {value_type.__module__} import {value_type.__name__} as {value_type_alias}\n", value_type_alias
75
- else:
76
- return "", value_type.__name__
77
- attributes_type_hints = attributes_type_hints or {}
78
- parent_class_alias = cls.__name__ + "_"
79
- imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
80
- class_code = f"class {name}({parent_class_alias}):\n"
81
- class_attributes = copy(class_attributes) if class_attributes else {}
82
- class_attributes.update({"_value_range": range_})
83
- for key, value in class_attributes.items():
84
- if value is not None:
85
- new_import, value_type_name = get_type_import(type(value))
86
- elif key in attributes_type_hints:
87
- new_import, value_type_name = get_type_import(attributes_type_hints[key])
88
- else:
89
- new_import, value_type_name = "from typing_extensions import Any", "Any"
90
- imports += new_import
91
- if isinstance(value, set):
92
- value_names = []
93
- for v in value:
94
- if isinstance(v, type):
95
- new_import, v_name = get_type_import(v)
96
- imports += new_import
97
- else:
98
- v_name = str(v)
99
- value_names.append(v_name)
100
- value_str = ", ".join(value_names)
101
- new_value = "{" + value_str + "}"
102
- elif isinstance(value, type):
103
- new_import, value_name = get_type_import(value)
104
- new_value = value_name
105
- value_type_name = value_name
106
- else:
107
- new_value = value
108
- if default_values or key == "_value_range":
109
- class_code += f" {key}: {value_type_name} = {new_value}\n"
110
- else:
111
- class_code += f" {key}: {value_type_name}\n"
112
- imports += "\n\n"
113
- if issubclass(cls, Row):
114
- folder_name = "row"
115
- elif issubclass(cls, Column):
116
- folder_name = "column"
117
- else:
118
- raise ValueError(f"Unknown class {cls}.")
119
- # write the code to a file
120
- with open(f"{cls._generated_classes_dir}/{folder_name}/{name.lower()}.py", "w") as f:
121
- f.write(imports + class_code)
122
-
123
- # import the class from the file
124
- import_path = ".".join(cls.__module__.split(".")[:-1] + ["generated", folder_name, name.lower()])
125
- time.sleep(0.3)
126
- return __import__(import_path, fromlist=[name.lower()]).__dict__[name]
127
-
128
- @classmethod
129
- def _get_and_update_subclass(cls, name: str, range_: set) -> Optional[Type[SubClassFactory]]:
130
- """
131
- Get a subclass of the attribute class and update its range if necessary.
132
-
133
- :param name: The name of the column.
134
- :param range_: The range of the column values.
135
- """
136
- key = (name.lower(), cls)
137
- if key in cls._registry:
138
- if not cls._registry[key].is_within_range(range_):
139
- if isinstance(cls._registry[key]._value_range, set):
140
- cls._registry[key]._value_range.update(range_)
141
- else:
142
- raise ValueError(f"Range of {key} is different from {cls._registry[key]._value_range}.")
143
- return cls._registry[key]
144
-
145
- @classmethod
146
- def register(cls, subclass: Type[SubClassFactory]):
147
- """
148
- Register a subclass of the attribute class, this is used to be able to dynamically create Attribute subclasses.
149
-
150
- :param subclass: The subclass to register.
151
- """
152
- if not issubclass(subclass, SubClassFactory):
153
- raise ValueError(f"{subclass} is not a subclass of CustomSet.")
154
- if subclass not in cls._registry:
155
- cls._registry[(subclass.__name__.lower(), cls)] = subclass
156
- else:
157
- raise ValueError(f"{subclass} is already registered.")
158
-
159
- @classmethod
160
- def is_within_range(cls, value: Any) -> bool:
161
- """
162
- Check if a value is within the range of the custom set.
163
-
164
- :param value: The value to check.
165
- :return: Boolean indicating whether the value is within the range or not.
166
- """
167
- if hasattr(value, "__iter__") and not isinstance(value, str):
168
- if all(isinstance(val_range, type) and isinstance(v, val_range)
169
- for v in value for val_range in cls._value_range):
170
- return True
171
- else:
172
- return set(value).issubset(cls._value_range)
173
- elif isinstance(value, str):
174
- return value.lower() in cls._value_range
175
- else:
176
- return value in cls._value_range
177
-
178
- def __instancecheck__(self, instance):
179
- return isinstance(instance, (SubClassFactory, *self._value_range))
180
-
181
-
182
- class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
24
+ class Row(UserDict, SubclassJSONSerializer):
183
25
  """
184
26
  A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
185
27
  the names of the attributes and the values are the attributes. All are stored in lower case.
@@ -193,7 +35,7 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
193
35
  :param kwargs: The attributes of the row.
194
36
  """
195
37
  super().__init__(kwargs)
196
- self.id = id_
38
+ self.id_ = id_ if id_ else id(self)
197
39
 
198
40
  @classmethod
199
41
  def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Row:
@@ -219,10 +61,10 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
219
61
  value.update(make_set(self[name]))
220
62
  super().__setitem__(name, value)
221
63
  else:
222
- super().__setitem__(name, make_set(self[name]))
64
+ super().__setitem__(name, make_set([self[name], value]))
223
65
  else:
224
- setattr(self, name, value)
225
66
  super().__setitem__(name, value)
67
+ setattr(self, name, self[name])
226
68
 
227
69
  def __contains__(self, item):
228
70
  if isinstance(item, (type, Enum)):
@@ -241,16 +83,12 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
241
83
  return super().__eq__(other)
242
84
 
243
85
  def __hash__(self):
244
- return hash(tuple(self.items()))
86
+ return self.id_
245
87
 
246
- def __instancecheck__(self, instance):
247
- return isinstance(instance, (dict, UserDict, Row)) or super().__instancecheck__(instance)
248
-
249
- def to_json(self) -> Dict[str, Any]:
88
+ def _to_json(self) -> Dict[str, Any]:
250
89
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
251
- serializable["_id"] = self.id
252
- return {**SubclassJSONSerializer.to_json(self),
253
- **{k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}}
90
+ serializable["_id"] = self.id_
91
+ return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
254
92
 
255
93
  @classmethod
256
94
  def _from_json(cls, data: Dict[str, Any]) -> Row:
@@ -280,16 +118,15 @@ class ColumnValue(SubclassJSONSerializer):
280
118
  def __hash__(self):
281
119
  return self.id
282
120
 
283
- def to_json(self) -> Dict[str, Any]:
284
- return {**SubclassJSONSerializer.to_json(self),
285
- "id": self.id, "value": self.value}
121
+ def _to_json(self) -> Dict[str, Any]:
122
+ return {"id": self.id, "value": self.value}
286
123
 
287
124
  @classmethod
288
125
  def _from_json(cls, data: Dict[str, Any]) -> ColumnValue:
289
126
  return cls(id=data["id"], value=data["value"])
290
127
 
291
128
 
292
- class Column(set, SubClassFactory, SubclassJSONSerializer):
129
+ class Column(set, SubclassJSONSerializer):
293
130
  nullable: bool = True
294
131
  """
295
132
  A boolean indicating whether the column can be None or not.
@@ -321,21 +158,6 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
321
158
  values = {ColumnValue(id(values), v) for v in values}
322
159
  return values
323
160
 
324
- @classmethod
325
- def create(cls, name: str, range_: set,
326
- nullable: bool = True, mutually_exclusive: bool = False) -> Type[SubClassFactory]:
327
- return super().create(name, range_, {"nullable": nullable, "mutually_exclusive": mutually_exclusive})
328
-
329
- @classmethod
330
- def create_from_enum(cls, category: Type[Enum], nullable: bool = True,
331
- mutually_exclusive: bool = False) -> Type[SubClassFactory]:
332
- new_cls = cls.create(category.__name__.lower(), {category}, nullable=nullable,
333
- mutually_exclusive=mutually_exclusive)
334
- for value in category:
335
- value_column = cls.create(category.__name__.lower(), {value}, mutually_exclusive=mutually_exclusive)(value)
336
- setattr(new_cls, value.name, value_column)
337
- return new_cls
338
-
339
161
  @classmethod
340
162
  def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> Column:
341
163
  id_ = id(row_obj) if row_obj else id(values)
@@ -373,12 +195,9 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
373
195
  return "None"
374
196
  return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
375
197
 
376
- def __instancecheck__(self, instance):
377
- return isinstance(instance, (set, self.__class__)) or super().__instancecheck__(instance)
378
-
379
- def to_json(self) -> Dict[str, Any]:
380
- return {**SubclassJSONSerializer.to_json(self),
381
- **{id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for id_, v in self.id_value_map.items()}}
198
+ def _to_json(self) -> Dict[str, Any]:
199
+ return {id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v
200
+ for id_, v in self.id_value_map.items()}
382
201
 
383
202
  @classmethod
384
203
  def _from_json(cls, data: Dict[str, Any]) -> Column:
@@ -397,8 +216,7 @@ def create_rows_from_dataframe(df: DataFrame, name: Optional[str] = None) -> Lis
397
216
  col_names = list(df.columns)
398
217
  for row_id, row in df.iterrows():
399
218
  row = {col_name: row[col_name].item() for col_name in col_names}
400
- row_cls = Row.create(name or df.__class__.__name__, make_set(type(df)), row, default_values=False)
401
- rows.append(row_cls(id_=row_id, **row))
219
+ rows.append(Row(id_=row_id, **row))
402
220
  return rows
403
221
 
404
222
 
@@ -420,17 +238,12 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
420
238
  or (obj.__class__ in [MetaData, registry])):
421
239
  return Row(id_=id(obj), **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
422
240
  row = Row(id_=id(obj))
423
- attributes_type_hints = {}
424
241
  for attr in dir(obj):
425
242
  if attr.startswith("_") or callable(getattr(obj, attr)):
426
243
  continue
427
244
  attr_value = getattr(obj, attr)
428
245
  row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
429
246
  max_recursion_idx, parent_is_iterable, row)
430
- attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
431
- row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
432
- attributes_type_hints=attributes_type_hints)
433
- row = row_cls(id_=id(obj), **row)
434
247
  return row
435
248
 
436
249
 
@@ -462,9 +275,6 @@ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, ob
462
275
  row[obj_name] = column
463
276
  else:
464
277
  row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
465
- if row.__class__.__name__ == "Row":
466
- row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
467
- row = row_cls(id_=id(obj), **row)
468
278
  return row
469
279
 
470
280
 
@@ -489,15 +299,12 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
489
299
  if not range_:
490
300
  raise ValueError(f"Could not determine the range of {name} in {obj}.")
491
301
  attr_row = Row(id_=id(attr_value))
492
- column = Column.create(name, range_).from_obj(values, row_obj=obj)
493
- attributes_type_hints = {}
302
+ column = Column.from_obj(values, row_obj=obj)
494
303
  for idx, val in enumerate(values):
495
304
  sub_attr_row = create_row(val, recursion_idx=recursion_idx,
496
305
  max_recursion_idx=max_recursion_idx,
497
306
  obj_name=obj_name, parent_is_iterable=True)
498
307
  attr_row.update(sub_attr_row)
499
- # attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
500
- # attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
501
308
  for sub_attr, val in attr_row.items():
502
309
  setattr(column, sub_attr, val)
503
310
  return column, attr_row
@@ -9,7 +9,7 @@ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type,
9
9
  from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
10
10
  from .datastructures.table import show_current_and_corner_cases
11
11
  from .prompt import prompt_user_for_expression, prompt_user_about_case
12
- from .utils import get_all_subclasses
12
+ from .utils import get_all_subclasses, is_iterable
13
13
 
14
14
  if TYPE_CHECKING:
15
15
  from .rdr import Rule
@@ -125,7 +125,7 @@ class Human(Expert):
125
125
  self.all_expert_answers = json.load(f)
126
126
 
127
127
  def ask_for_conditions(self, case: Case,
128
- targets: Union[List[Column], List[Column]],
128
+ targets: Union[List[Column], List[SQLColumn]],
129
129
  last_evaluated_rule: Optional[Rule] = None) \
130
130
  -> CallableExpression:
131
131
  if not self.use_loaded_answers:
@@ -198,19 +198,6 @@ class Human(Expert):
198
198
  self.all_expert_answers.append(expert_input)
199
199
  return expression
200
200
 
201
- def create_category_instance(self, cat_name: str, cat_value: Union[str, int, float, set]) -> Column:
202
- """
203
- Create a new category instance.
204
-
205
- :param cat_name: The name of the category.
206
- :param cat_value: The value of the category.
207
- :return: A new instance of the category.
208
- """
209
- category_type = self.get_category_type(cat_name)
210
- if not category_type:
211
- category_type = self.create_new_category_type(cat_name)
212
- return category_type(cat_value)
213
-
214
201
  def get_category_type(self, cat_name: str) -> Optional[Type[Column]]:
215
202
  """
216
203
  Get the category type from the known categories.
@@ -226,19 +213,6 @@ class Human(Expert):
226
213
  category_type = self.known_categories[cat_name]
227
214
  return category_type
228
215
 
229
- def create_new_category_type(self, cat_name: str) -> Type[Column]:
230
- """
231
- Create a new category type.
232
-
233
- :param cat_name: The name of the category.
234
- :return: A new category type.
235
- """
236
- if self.ask_if_category_is_mutually_exclusive(cat_name):
237
- category_type: Type[Column] = Column.create(cat_name, set(), mutually_exclusive=True)
238
- else:
239
- category_type: Type[Column] = Column.create(cat_name, set())
240
- return category_type
241
-
242
216
  def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
243
217
  """
244
218
  Ask the expert if the new category can have multiple values.
ripple_down_rules/rdr.py CHANGED
@@ -12,7 +12,7 @@ from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQue
12
12
  from .experts import Expert, Human
13
13
  from .rules import Rule, SingleClassRule, MultiClassTopRule
14
14
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
15
- get_hint_for_attribute, SubclassJSONSerializer
15
+ get_hint_for_attribute, SubclassJSONSerializer, is_iterable, make_list
16
16
 
17
17
 
18
18
  class RippleDownRules(ABC):
@@ -124,10 +124,12 @@ class RippleDownRules(ABC):
124
124
  :param target: The target category.
125
125
  :return: The precision and recall of the classifier.
126
126
  """
127
- pred_cat = pred_cat if isinstance(pred_cat, list) else [pred_cat]
128
- target = target if isinstance(target, list) else [target]
127
+ pred_cat = pred_cat if is_iterable(pred_cat) else [pred_cat]
128
+ target = target if is_iterable(target) else [target]
129
129
  recall = [not yi or (yi in pred_cat) for yi in target]
130
130
  target_types = [type(yi) for yi in target]
131
+ if len(pred_cat) > 1:
132
+ print(pred_cat)
131
133
  precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
132
134
  return precision, recall
133
135
 
@@ -265,8 +267,8 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
265
267
  if rule.alternative:
266
268
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
267
269
 
268
- def to_json(self) -> Dict[str, Any]:
269
- return {**SubclassJSONSerializer.to_json(self), "start_rule": self.start_rule.to_json()}
270
+ def _to_json(self) -> Dict[str, Any]:
271
+ return {"start_rule": self.start_rule.to_json()}
270
272
 
271
273
  @classmethod
272
274
  def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -358,7 +360,7 @@ class MultiClassRDR(RippleDownRules):
358
360
  self.add_conclusion(evaluated_rule)
359
361
 
360
362
  if not next_rule:
361
- if target not in self.conclusions:
363
+ if not make_set(target).intersection(make_set(self.conclusions)):
362
364
  # Nothing fired and there is a target that should have been in the conclusions
363
365
  self.add_rule_for_case(case, target, expert)
364
366
  # Have to check all rules again to make sure only this new rule fires
@@ -509,16 +511,16 @@ class MultiClassRDR(RippleDownRules):
509
511
  """
510
512
  conclusion_types = [type(c) for c in self.conclusions]
511
513
  if type(evaluated_rule.conclusion) not in conclusion_types:
512
- self.conclusions.append(evaluated_rule.conclusion)
514
+ self.conclusions.extend(make_list(evaluated_rule.conclusion))
513
515
  else:
514
516
  same_type_conclusions = [c for c in self.conclusions if type(c) == type(evaluated_rule.conclusion)]
515
517
  combined_conclusion = evaluated_rule.conclusion if isinstance(evaluated_rule.conclusion, set) \
516
518
  else {evaluated_rule.conclusion}
517
- combined_conclusion = deepcopy(combined_conclusion)
519
+ combined_conclusion = copy(combined_conclusion)
518
520
  for c in same_type_conclusions:
519
521
  combined_conclusion.update(c if isinstance(c, set) else make_set(c))
520
522
  self.conclusions.remove(c)
521
- self.conclusions.extend(combined_conclusion)
523
+ self.conclusions.extend(make_list(combined_conclusion))
522
524
 
523
525
  def add_top_rule(self, conditions: CallableExpression, conclusion: Any, corner_case: Union[Case, SQLTable]):
524
526
  """
@@ -587,11 +589,11 @@ class GeneralRDR(RippleDownRules):
587
589
  continue
588
590
  pred_atts = rdr.classify(case_cp)
589
591
  if pred_atts:
590
- pred_atts = pred_atts if isinstance(pred_atts, list) else [pred_atts]
592
+ pred_atts = make_list(pred_atts)
591
593
  pred_atts = [p for p in pred_atts if p not in conclusions]
592
594
  added_attributes = True
593
595
  conclusions.extend(pred_atts)
594
- self.update_case_with_same_type_conclusions(case_cp, pred_atts)
596
+ GeneralRDR.update_case(case_cp, pred_atts)
595
597
  if not added_attributes:
596
598
  break
597
599
  return conclusions
@@ -624,21 +626,27 @@ class GeneralRDR(RippleDownRules):
624
626
  if not target:
625
627
  target = expert.ask_for_conclusion(case_query)
626
628
  case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
627
- if type(target) not in self.start_rules_dict:
629
+ if is_iterable(target) and not isinstance(target, Column):
630
+ target_type = type(make_list(target)[0])
631
+ assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
632
+ " type")
633
+ else:
634
+ target_type = type(target)
635
+ if target_type not in self.start_rules_dict:
628
636
  conclusions = self.classify(case)
629
- self.update_case_with_same_type_conclusions(case_cp, conclusions)
637
+ self.update_case(case_cp, conclusions)
630
638
  new_rdr = self.initialize_new_rdr_for_attribute(target, case_cp)
631
639
  new_conclusions = new_rdr.fit_case(case_query_cp, expert, **kwargs)
632
- self.start_rules_dict[type(target)] = new_rdr
633
- self.update_case_with_same_type_conclusions(case_cp, new_conclusions, type(target))
634
- elif not self.case_has_conclusion(case_cp, type(target)):
640
+ self.start_rules_dict[target_type] = new_rdr
641
+ self.update_case(case_cp, new_conclusions, target_type)
642
+ elif not self.case_has_conclusion(case_cp, target_type):
635
643
  for rdr_type, rdr in self.start_rules_dict.items():
636
- if type(target) is not rdr_type:
644
+ if target_type is not rdr_type:
637
645
  conclusions = rdr.classify(case_cp)
638
646
  else:
639
- conclusions = self.start_rules_dict[type(target)].fit_case(case_query_cp,
647
+ conclusions = self.start_rules_dict[target_type].fit_case(case_query_cp,
640
648
  expert, **kwargs)
641
- self.update_case_with_same_type_conclusions(case_cp, conclusions, rdr_type)
649
+ self.update_case(case_cp, conclusions, rdr_type)
642
650
 
643
651
  return self.classify(case)
644
652
 
@@ -653,12 +661,14 @@ class GeneralRDR(RippleDownRules):
653
661
  return MultiClassRDR()
654
662
  else:
655
663
  return SingleClassRDR()
656
- else:
664
+ elif isinstance(attribute, Column):
657
665
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
666
+ else:
667
+ return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
658
668
 
659
669
  @staticmethod
660
- def update_case_with_same_type_conclusions(case: Union[Case, SQLTable],
661
- conclusions: List[Any], attribute_type: Optional[Any] = None):
670
+ def update_case(case: Union[Case, SQLTable],
671
+ conclusions: List[Any], attribute_type: Optional[Any] = None):
662
672
  """
663
673
  Update the case with the conclusions.
664
674
 
@@ -668,7 +678,7 @@ class GeneralRDR(RippleDownRules):
668
678
  """
669
679
  if not conclusions:
670
680
  return
671
- conclusions = [conclusions] if not isinstance(conclusions, list) else conclusions
681
+ conclusions = [conclusions] if not isinstance(conclusions, list) else list(conclusions)
672
682
  if len(conclusions) == 0:
673
683
  return
674
684
  if isinstance(case, SQLTable):
@@ -677,7 +687,8 @@ class GeneralRDR(RippleDownRules):
677
687
  hint, origin, args = get_hint_for_attribute(attr_name, case)
678
688
  if isinstance(attribute, set) or origin == set:
679
689
  attribute = set() if attribute is None else attribute
680
- attribute.update(*[make_set(c) for c in conclusions])
690
+ for c in conclusions:
691
+ attribute.update(make_set(c))
681
692
  elif isinstance(attribute, list) or origin == list:
682
693
  attribute = [] if attribute is None else attribute
683
694
  attribute.extend(conclusions)
@@ -686,7 +697,8 @@ class GeneralRDR(RippleDownRules):
686
697
  else:
687
698
  raise ValueError(f"Cannot add multiple conclusions to attribute {attr_name}")
688
699
  else:
689
- case.update(*[c.as_dict for c in make_set(conclusions)])
700
+ for c in make_set(conclusions):
701
+ case.update(c.as_dict)
690
702
 
691
703
  @property
692
704
  def names_of_all_types(self) -> List[str]:
@@ -195,9 +195,8 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
195
195
  if_clause = "elif" if self.weight == RDREdge.Alternative.value else "if"
196
196
  return f"{parent_indent}{if_clause} {self.conditions.parsed_user_input}:\n"
197
197
 
198
- def to_json(self) -> Dict[str, Any]:
199
- self.json_serialization = {**SubclassJSONSerializer.to_json(self),
200
- "conditions": self.conditions.to_json(),
198
+ def _to_json(self) -> Dict[str, Any]:
199
+ self.json_serialization = {"conditions": self.conditions.to_json(),
201
200
  "conclusion": self.conclusion.to_json(),
202
201
  "parent": self.parent.json_serialization if self.parent else None,
203
202
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
@@ -24,6 +24,24 @@ if TYPE_CHECKING:
24
24
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
25
25
 
26
26
 
27
+ def make_list(value: Any) -> List:
28
+ """
29
+ Make a list from a value.
30
+
31
+ :param value: The value to make a list from.
32
+ """
33
+ return list(value) if is_iterable(value) else [value]
34
+
35
+
36
+ def is_iterable(obj: Any) -> bool:
37
+ """
38
+ Check if an object is iterable.
39
+
40
+ :param obj: The object to check.
41
+ """
42
+ return hasattr(obj, "__iter__") and not isinstance(obj, (str, type))
43
+
44
+
27
45
  def get_type_from_string(type_path: str):
28
46
  """
29
47
  Get a type from a string describing its path using the format "module_path.ClassName".
@@ -64,7 +82,14 @@ class SubclassJSONSerializer:
64
82
  """
65
83
 
66
84
  def to_json(self) -> Dict[str, Any]:
67
- return {"_type": get_full_class_name(self.__class__)}
85
+ return {"_type": get_full_class_name(self.__class__), **self._to_json()}
86
+
87
+ @abstractmethod
88
+ def _to_json(self) -> Dict[str, Any]:
89
+ """
90
+ Create a json dict from the object.
91
+ """
92
+ pass
68
93
 
69
94
  @classmethod
70
95
  @abstractmethod
@@ -350,20 +375,7 @@ def make_set(value: Any) -> Set:
350
375
 
351
376
  :param value: The value to make a set from.
352
377
  """
353
- if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
354
- return set(value)
355
- return {value}
356
-
357
-
358
- def make_list(value: Any) -> List:
359
- """
360
- Make a list from a value.
361
-
362
- :param value: The value to make a list from.
363
- """
364
- if hasattr(value, "__iter__") and not isinstance(value, (str, type)):
365
- return list(value)
366
- return [value]
378
+ return set(value) if is_iterable(value) else {value}
367
379
 
368
380
 
369
381
  def make_value_or_raise_error(value: Any) -> Any:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,18 @@
1
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ripple_down_rules/datasets.py,sha256=p7WmLiAKcYtzLq87RJhzU_BHCdM-zKK71yldjyYNpUE,4602
3
+ ripple_down_rules/experts.py,sha256=ajH46NXM6IREtWr2a-ZBxdKl4CFzSx5S7yktnyXgwnc,11089
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
6
+ ripple_down_rules/rdr.py,sha256=dS_OqaGtP1PAKEe0h-dMy7JI9ycNuH_EIbHlh-Nz4Sg,34189
7
+ ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
8
+ ripple_down_rules/utils.py,sha256=VjpdJ5W-xK-TQttfpNIHRKpYK91vYjtQ22kgccQgVTQ,17183
9
+ ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
10
+ ripple_down_rules/datastructures/callable_expression.py,sha256=lF8ag3-8oYuETAUohtRfviOpypSMQR-JEJGgR0Zhv3E,9055
11
+ ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
12
+ ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
13
+ ripple_down_rules/datastructures/table.py,sha256=G2ShfPmyiwfDXS-JY_jFD6nOx7OLuMr-GP--OCDGKYc,13733
14
+ ripple_down_rules-0.0.7.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
15
+ ripple_down_rules-0.0.7.dist-info/METADATA,sha256=fBb_rMvpmfAueLPCXvNs3IXLNX-COunhRnEq7iYkV5Y,42518
16
+ ripple_down_rules-0.0.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
17
+ ripple_down_rules-0.0.7.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
18
+ ripple_down_rules-0.0.7.dist-info/RECORD,,
File without changes
@@ -1,21 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ripple_down_rules/datasets.py,sha256=QRB-1BdFTcUNuhgYEuXYx6qQOYlDu03_iLKDBqrcVrQ,4511
3
- ripple_down_rules/experts.py,sha256=DMTC-E2g1Fs43oyr30MGeGi5-VKBb3RojzzPa8DvCSA,12093
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
6
- ripple_down_rules/rdr.py,sha256=nRhrHdQYmvXgfTFLyO_xxwe7GDYuUhaOqZlQtqSUrnE,33751
7
- ripple_down_rules/rules.py,sha256=TLptqvA6I3QlQaVBTYchbgvXm17XWwFJoTmoN0diHm8,10348
8
- ripple_down_rules/utils.py,sha256=WlUXTf-B45SPEKpDBVb9QPQWS54250MGAY5xl9MIhR4,16944
9
- ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=yEZ6OWzSiWsRtEz_5UquA0inmodFTOWqXWV_a2gg1cg,9110
11
- ripple_down_rules/datastructures/dataclasses.py,sha256=z_9B7Nj_MIf2Iyrs5VeUhXhYxwaqnuKVjgwxhZZTygY,2525
12
- ripple_down_rules/datastructures/enums.py,sha256=6Mh55_8QRuXyYZXtonWr01VBgLP-jYp91K_8hIgh8u8,4244
13
- ripple_down_rules/datastructures/table.py,sha256=Tq7savAaFoB7n7x6-u2Vz6NfvjLBzPxsqblR4cHujzM,23161
14
- ripple_down_rules/datastructures/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- ripple_down_rules/datastructures/generated/column/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- ripple_down_rules/datastructures/generated/row/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- ripple_down_rules-0.0.5.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
18
- ripple_down_rules-0.0.5.dist-info/METADATA,sha256=qwUDSDi3YlBsg1RlYOEtFn5YQ3x7yNizCAKm72_8YPU,42518
19
- ripple_down_rules-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
20
- ripple_down_rules-0.0.5.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
21
- ripple_down_rules-0.0.5.dist-info/RECORD,,