ripple-down-rules 0.0.6__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/PKG-INFO +1 -1
  2. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/pyproject.toml +1 -1
  3. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datasets.py +4 -3
  4. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datastructures/callable_expression.py +2 -3
  5. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datastructures/enums.py +12 -3
  6. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datastructures/table.py +14 -202
  7. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/experts.py +0 -26
  8. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/rdr.py +2 -4
  9. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/rules.py +2 -3
  10. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/utils.py +37 -1
  11. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules.egg-info/PKG-INFO +1 -1
  12. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules.egg-info/SOURCES.txt +0 -3
  13. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_json_serialization.py +4 -10
  14. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_rdr.py +1 -5
  15. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_rdr_alchemy.py +17 -7
  16. ripple_down_rules-0.0.6/src/ripple_down_rules/datastructures/generated/__init__.py +0 -0
  17. ripple_down_rules-0.0.6/src/ripple_down_rules/datastructures/generated/column/__init__.py +0 -0
  18. ripple_down_rules-0.0.6/src/ripple_down_rules/datastructures/generated/row/__init__.py +0 -0
  19. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/LICENSE +0 -0
  20. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/README.md +0 -0
  21. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/setup.cfg +0 -0
  22. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/__init__.py +0 -0
  23. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datastructures/__init__.py +0 -0
  24. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/datastructures/dataclasses.py +0 -0
  25. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/failures.py +0 -0
  26. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules/prompt.py +0 -0
  27. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules.egg-info/dependency_links.txt +0 -0
  28. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/src/ripple_down_rules.egg-info/top_level.txt +0 -0
  29. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_relational_rdr.py +0 -0
  30. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_relational_rdr_alchemy.py +0 -0
  31. {ripple_down_rules-0.0.6 → ripple_down_rules-0.0.8}/test/test_sql_model.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
6
6
 
7
7
  [project]
8
8
  name = "ripple_down_rules"
9
- version = "0.0.6"
9
+ version = "0.0.8"
10
10
  description = "Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning."
11
11
  readme = "README.md"
12
12
  authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }]
@@ -81,7 +81,8 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
81
81
 
82
82
  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
83
83
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
84
- targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
84
+ # targets = [getattr(SpeciesCol, category_id_to_name[i]) for i in y.values.flatten()]
85
+ targets = [Species.from_str(category_id_to_name[i]) for i in y.values.flatten()]
85
86
  return all_cases, targets
86
87
 
87
88
 
@@ -104,8 +105,8 @@ class Habitat(Category):
104
105
  air = "air"
105
106
 
106
107
 
107
- SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
108
- HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
108
+ # SpeciesCol = Column.create_from_enum(Species, mutually_exclusive=True)
109
+ # HabitatCol = Column.create_from_enum(Habitat, mutually_exclusive=False)
109
110
 
110
111
 
111
112
  class Base(sqlalchemy.orm.DeclarativeBase):
@@ -162,9 +162,8 @@ class CallableExpression(SubclassJSONSerializer):
162
162
  prev_e = e
163
163
  return "\n".join(all_binary_ops) if len(all_binary_ops) > 0 else user_input
164
164
 
165
- def to_json(self) -> Dict[str, Any]:
166
- return {**SubclassJSONSerializer.to_json(self),
167
- "user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
165
+ def _to_json(self) -> Dict[str, Any]:
166
+ return {"user_input": self.user_input, "conclusion_type": get_full_class_name(self.conclusion_type)}
168
167
 
169
168
  @classmethod
170
169
  def _from_json(cls, data: Dict[str, Any]) -> CallableExpression:
@@ -2,10 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  from enum import auto, Enum
4
4
 
5
- from typing_extensions import List
5
+ from typing_extensions import List, Dict, Any
6
6
 
7
+ from ripple_down_rules.utils import SubclassJSONSerializer
7
8
 
8
- class Category(str, Enum):
9
+
10
+ class Category(str, SubclassJSONSerializer, Enum):
9
11
 
10
12
  @classmethod
11
13
  def from_str(cls, value: str) -> Category:
@@ -17,7 +19,14 @@ class Category(str, Enum):
17
19
 
18
20
  @property
19
21
  def as_dict(self):
20
- return {self.__class__.__name__.lower(): self}
22
+ return {self.__class__.__name__.lower(): self.value}
23
+
24
+ def _to_json(self) -> Dict[str, Any]:
25
+ return self.as_dict
26
+
27
+ @classmethod
28
+ def _from_json(cls, data: Dict[str, Any]) -> Category:
29
+ return cls.from_str(data[cls.__name__.lower()])
21
30
 
22
31
 
23
32
  class Stop(Category):
@@ -21,165 +21,7 @@ if TYPE_CHECKING:
21
21
  from .callable_expression import CallableExpression
22
22
 
23
23
 
24
- class SubClassFactory:
25
- """
26
- A custom set class that is used to add other attributes to the set. This is similar to a table where the set is the
27
- table, the attributes are the columns, and the values are the rows.
28
- """
29
- _value_range: set
30
- """
31
- The range of the attribute, this can be a set of possible values or a range of numeric values (int, float).
32
- """
33
- _registry: Dict[(str, type), Type[SubClassFactory]] = {}
34
- """
35
- A dictionary of all dynamically created subclasses of this class.
36
- """
37
- _generated_classes_dir: str = os.path.dirname(os.path.abspath(__file__)) + "/generated"
38
-
39
- @classmethod
40
- def create(cls, name: str, range_: set, class_attributes: Optional[Dict[str, Any]] = None,
41
- default_values: bool = True,
42
- attributes_type_hints: Optional[Dict[str, Type]] = None) -> Type[SubClassFactory]:
43
- """
44
- Create a new subclass.
45
-
46
- :param name: The name of the subclass.
47
- :param range_: The range of the subclass values.
48
- :param class_attributes: The attributes of the new subclass.
49
- :param default_values: Boolean indicating whether to add default values to the subclass attributes or not.
50
- :param attributes_type_hints: The type hints of the subclass attributes.
51
- :return: The new subclass.
52
- """
53
- existing_class = cls._get_and_update_subclass(name, range_)
54
- if existing_class:
55
- return existing_class
56
-
57
- new_attribute_type = cls._create_class_in_new_python_file_and_import_it(name, range_, default_values,
58
- class_attributes, attributes_type_hints)
59
-
60
- cls.register(new_attribute_type)
61
-
62
- return new_attribute_type
63
-
64
- @classmethod
65
- def _create_class_in_new_python_file_and_import_it(cls, name: str, range_: set, default_values: bool = True,
66
- class_attributes: Optional[Dict[str, Any]] = None,
67
- attributes_type_hints: Optional[Dict[str, Type]] = None)\
68
- -> Type[SubClassFactory]:
69
- def get_type_import(value_type: Any) -> Tuple[str, str]:
70
- if value_type is type(None):
71
- return "from types import NoneType\n", "NoneType"
72
- elif value_type.__module__ != "builtins":
73
- value_type_alias = f"{value_type.__name__}_"
74
- return f"from {value_type.__module__} import {value_type.__name__} as {value_type_alias}\n", value_type_alias
75
- else:
76
- return "", value_type.__name__
77
- attributes_type_hints = attributes_type_hints or {}
78
- parent_class_alias = cls.__name__ + "_"
79
- imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
80
- class_code = f"class {name}({parent_class_alias}):\n"
81
- class_attributes = copy(class_attributes) if class_attributes else {}
82
- class_attributes.update({"_value_range": range_})
83
- for key, value in class_attributes.items():
84
- if value is not None:
85
- new_import, value_type_name = get_type_import(type(value))
86
- elif key in attributes_type_hints:
87
- new_import, value_type_name = get_type_import(attributes_type_hints[key])
88
- else:
89
- new_import, value_type_name = "from typing_extensions import Any", "Any"
90
- imports += new_import
91
- if isinstance(value, set):
92
- value_names = []
93
- for v in value:
94
- if isinstance(v, type):
95
- new_import, v_name = get_type_import(v)
96
- imports += new_import
97
- else:
98
- v_name = str(v)
99
- value_names.append(v_name)
100
- value_str = ", ".join(value_names)
101
- new_value = "{" + value_str + "}"
102
- elif isinstance(value, type):
103
- new_import, value_name = get_type_import(value)
104
- new_value = value_name
105
- value_type_name = value_name
106
- else:
107
- new_value = value
108
- if default_values or key == "_value_range":
109
- class_code += f" {key}: {value_type_name} = {new_value}\n"
110
- else:
111
- class_code += f" {key}: {value_type_name}\n"
112
- imports += "\n\n"
113
- if issubclass(cls, Row):
114
- folder_name = "row"
115
- elif issubclass(cls, Column):
116
- folder_name = "column"
117
- else:
118
- raise ValueError(f"Unknown class {cls}.")
119
- # write the code to a file
120
- with open(f"{cls._generated_classes_dir}/{folder_name}/{name.lower()}.py", "w") as f:
121
- f.write(imports + class_code)
122
-
123
- # import the class from the file
124
- import_path = ".".join(cls.__module__.split(".")[:-1] + ["generated", folder_name, name.lower()])
125
- time.sleep(0.3)
126
- return __import__(import_path, fromlist=[name.lower()]).__dict__[name]
127
-
128
- @classmethod
129
- def _get_and_update_subclass(cls, name: str, range_: set) -> Optional[Type[SubClassFactory]]:
130
- """
131
- Get a subclass of the attribute class and update its range if necessary.
132
-
133
- :param name: The name of the column.
134
- :param range_: The range of the column values.
135
- """
136
- key = (name.lower(), cls)
137
- if key in cls._registry:
138
- if not cls._registry[key].is_within_range(range_):
139
- if isinstance(cls._registry[key]._value_range, set):
140
- cls._registry[key]._value_range.update(range_)
141
- else:
142
- raise ValueError(f"Range of {key} is different from {cls._registry[key]._value_range}.")
143
- return cls._registry[key]
144
-
145
- @classmethod
146
- def register(cls, subclass: Type[SubClassFactory]):
147
- """
148
- Register a subclass of the attribute class, this is used to be able to dynamically create Attribute subclasses.
149
-
150
- :param subclass: The subclass to register.
151
- """
152
- if not issubclass(subclass, SubClassFactory):
153
- raise ValueError(f"{subclass} is not a subclass of CustomSet.")
154
- if subclass not in cls._registry:
155
- cls._registry[(subclass.__name__.lower(), cls)] = subclass
156
- else:
157
- raise ValueError(f"{subclass} is already registered.")
158
-
159
- @classmethod
160
- def is_within_range(cls, value: Any) -> bool:
161
- """
162
- Check if a value is within the range of the custom set.
163
-
164
- :param value: The value to check.
165
- :return: Boolean indicating whether the value is within the range or not.
166
- """
167
- if hasattr(value, "__iter__") and not isinstance(value, str):
168
- if all(isinstance(val_range, type) and isinstance(v, val_range)
169
- for v in value for val_range in cls._value_range):
170
- return True
171
- else:
172
- return set(value).issubset(cls._value_range)
173
- elif isinstance(value, str):
174
- return value.lower() in cls._value_range
175
- else:
176
- return value in cls._value_range
177
-
178
- def __instancecheck__(self, instance):
179
- return isinstance(instance, (SubClassFactory, *self._value_range))
180
-
181
-
182
- class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
24
+ class Row(UserDict, SubclassJSONSerializer):
183
25
  """
184
26
  A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
185
27
  the names of the attributes and the values are the attributes. All are stored in lower case.
@@ -193,7 +35,7 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
193
35
  :param kwargs: The attributes of the row.
194
36
  """
195
37
  super().__init__(kwargs)
196
- self.id = id_
38
+ self.id_ = id_ if id_ else id(self)
197
39
 
198
40
  @classmethod
199
41
  def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Row:
@@ -241,16 +83,12 @@ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
241
83
  return super().__eq__(other)
242
84
 
243
85
  def __hash__(self):
244
- return hash(tuple(self.items()))
86
+ return self.id_
245
87
 
246
- def __instancecheck__(self, instance):
247
- return isinstance(instance, (dict, UserDict, Row)) or super().__instancecheck__(instance)
248
-
249
- def to_json(self) -> Dict[str, Any]:
88
+ def _to_json(self) -> Dict[str, Any]:
250
89
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
251
- serializable["_id"] = self.id
252
- return {**SubclassJSONSerializer.to_json(self),
253
- **{k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}}
90
+ serializable["_id"] = self.id_
91
+ return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}
254
92
 
255
93
  @classmethod
256
94
  def _from_json(cls, data: Dict[str, Any]) -> Row:
@@ -280,16 +118,15 @@ class ColumnValue(SubclassJSONSerializer):
280
118
  def __hash__(self):
281
119
  return self.id
282
120
 
283
- def to_json(self) -> Dict[str, Any]:
284
- return {**SubclassJSONSerializer.to_json(self),
285
- "id": self.id, "value": self.value}
121
+ def _to_json(self) -> Dict[str, Any]:
122
+ return {"id": self.id, "value": self.value}
286
123
 
287
124
  @classmethod
288
125
  def _from_json(cls, data: Dict[str, Any]) -> ColumnValue:
289
126
  return cls(id=data["id"], value=data["value"])
290
127
 
291
128
 
292
- class Column(set, SubClassFactory, SubclassJSONSerializer):
129
+ class Column(set, SubclassJSONSerializer):
293
130
  nullable: bool = True
294
131
  """
295
132
  A boolean indicating whether the column can be None or not.
@@ -321,21 +158,6 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
321
158
  values = {ColumnValue(id(values), v) for v in values}
322
159
  return values
323
160
 
324
- @classmethod
325
- def create(cls, name: str, range_: set,
326
- nullable: bool = True, mutually_exclusive: bool = False) -> Type[SubClassFactory]:
327
- return super().create(name, range_, {"nullable": nullable, "mutually_exclusive": mutually_exclusive})
328
-
329
- @classmethod
330
- def create_from_enum(cls, category: Type[Enum], nullable: bool = True,
331
- mutually_exclusive: bool = False) -> Type[SubClassFactory]:
332
- new_cls = cls.create(category.__name__.lower(), {category}, nullable=nullable,
333
- mutually_exclusive=mutually_exclusive)
334
- for value in category:
335
- value_column = cls.create(category.__name__.lower(), {value}, mutually_exclusive=mutually_exclusive)(value)
336
- setattr(new_cls, value.name, value_column)
337
- return new_cls
338
-
339
161
  @classmethod
340
162
  def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> Column:
341
163
  id_ = id(row_obj) if row_obj else id(values)
@@ -373,12 +195,9 @@ class Column(set, SubClassFactory, SubclassJSONSerializer):
373
195
  return "None"
374
196
  return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
375
197
 
376
- def __instancecheck__(self, instance):
377
- return isinstance(instance, (set, self.__class__)) or super().__instancecheck__(instance)
378
-
379
- def to_json(self) -> Dict[str, Any]:
380
- return {**SubclassJSONSerializer.to_json(self),
381
- **{id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for id_, v in self.id_value_map.items()}}
198
+ def _to_json(self) -> Dict[str, Any]:
199
+ return {id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v
200
+ for id_, v in self.id_value_map.items()}
382
201
 
383
202
  @classmethod
384
203
  def _from_json(cls, data: Dict[str, Any]) -> Column:
@@ -397,8 +216,7 @@ def create_rows_from_dataframe(df: DataFrame, name: Optional[str] = None) -> Lis
397
216
  col_names = list(df.columns)
398
217
  for row_id, row in df.iterrows():
399
218
  row = {col_name: row[col_name].item() for col_name in col_names}
400
- row_cls = Row.create(name or df.__class__.__name__, make_set(type(df)), row, default_values=False)
401
- rows.append(row_cls(id_=row_id, **row))
219
+ rows.append(Row(id_=row_id, **row))
402
220
  return rows
403
221
 
404
222
 
@@ -420,18 +238,12 @@ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
420
238
  or (obj.__class__ in [MetaData, registry])):
421
239
  return Row(id_=id(obj), **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
422
240
  row = Row(id_=id(obj))
423
- attributes_type_hints = {}
424
241
  for attr in dir(obj):
425
242
  if attr.startswith("_") or callable(getattr(obj, attr)):
426
243
  continue
427
244
  attr_value = getattr(obj, attr)
428
245
  row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
429
246
  max_recursion_idx, parent_is_iterable, row)
430
- attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
431
- if recursion_idx == 0:
432
- row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
433
- attributes_type_hints=attributes_type_hints)
434
- row = row_cls(id_=id(obj), **row)
435
247
  return row
436
248
 
437
249
 
@@ -487,7 +299,7 @@ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, ob
487
299
  if not range_:
488
300
  raise ValueError(f"Could not determine the range of {name} in {obj}.")
489
301
  attr_row = Row(id_=id(attr_value))
490
- column = Column.create(name, range_).from_obj(values, row_obj=obj)
302
+ column = Column.from_obj(values, row_obj=obj)
491
303
  for idx, val in enumerate(values):
492
304
  sub_attr_row = create_row(val, recursion_idx=recursion_idx,
493
305
  max_recursion_idx=max_recursion_idx,
@@ -198,19 +198,6 @@ class Human(Expert):
198
198
  self.all_expert_answers.append(expert_input)
199
199
  return expression
200
200
 
201
- def create_category_instance(self, cat_name: str, cat_value: Union[str, int, float, set]) -> Column:
202
- """
203
- Create a new category instance.
204
-
205
- :param cat_name: The name of the category.
206
- :param cat_value: The value of the category.
207
- :return: A new instance of the category.
208
- """
209
- category_type = self.get_category_type(cat_name)
210
- if not category_type:
211
- category_type = self.create_new_category_type(cat_name)
212
- return category_type(cat_value)
213
-
214
201
  def get_category_type(self, cat_name: str) -> Optional[Type[Column]]:
215
202
  """
216
203
  Get the category type from the known categories.
@@ -226,19 +213,6 @@ class Human(Expert):
226
213
  category_type = self.known_categories[cat_name]
227
214
  return category_type
228
215
 
229
- def create_new_category_type(self, cat_name: str) -> Type[Column]:
230
- """
231
- Create a new category type.
232
-
233
- :param cat_name: The name of the category.
234
- :return: A new category type.
235
- """
236
- if self.ask_if_category_is_mutually_exclusive(cat_name):
237
- category_type: Type[Column] = Column.create(cat_name, set(), mutually_exclusive=True)
238
- else:
239
- category_type: Type[Column] = Column.create(cat_name, set())
240
- return category_type
241
-
242
216
  def ask_if_category_is_mutually_exclusive(self, category_name: str) -> bool:
243
217
  """
244
218
  Ask the expert if the new category can have multiple values.
@@ -128,8 +128,6 @@ class RippleDownRules(ABC):
128
128
  target = target if is_iterable(target) else [target]
129
129
  recall = [not yi or (yi in pred_cat) for yi in target]
130
130
  target_types = [type(yi) for yi in target]
131
- if len(pred_cat) > 1:
132
- print(pred_cat)
133
131
  precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
134
132
  return precision, recall
135
133
 
@@ -267,8 +265,8 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
267
265
  if rule.alternative:
268
266
  self.write_rules_as_source_code_to_file(rule.alternative, file, parent_indent)
269
267
 
270
- def to_json(self) -> Dict[str, Any]:
271
- return {**SubclassJSONSerializer.to_json(self), "start_rule": self.start_rule.to_json()}
268
+ def _to_json(self) -> Dict[str, Any]:
269
+ return {"start_rule": self.start_rule.to_json()}
272
270
 
273
271
  @classmethod
274
272
  def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -195,9 +195,8 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
195
195
  if_clause = "elif" if self.weight == RDREdge.Alternative.value else "if"
196
196
  return f"{parent_indent}{if_clause} {self.conditions.parsed_user_input}:\n"
197
197
 
198
- def to_json(self) -> Dict[str, Any]:
199
- self.json_serialization = {**SubclassJSONSerializer.to_json(self),
200
- "conditions": self.conditions.to_json(),
198
+ def _to_json(self) -> Dict[str, Any]:
199
+ self.json_serialization = {"conditions": self.conditions.to_json(),
201
200
  "conclusion": self.conclusion.to_json(),
202
201
  "parent": self.parent.json_serialization if self.parent else None,
203
202
  "corner_case": self.corner_case.to_json() if self.corner_case else None,
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import importlib
4
+ import json
4
5
  import logging
5
6
  from abc import abstractmethod
6
7
  from collections import UserDict
@@ -81,8 +82,27 @@ class SubclassJSONSerializer:
81
82
  'from_json' method.
82
83
  """
83
84
 
85
+ def to_json_file(self, filename: str):
86
+ """
87
+ Save the object to a json file.
88
+ """
89
+ data = self.to_json()
90
+ # save the json to a file
91
+ if not filename.endswith(".json"):
92
+ filename += ".json"
93
+ with open(filename, "w") as f:
94
+ json.dump(data, f, indent=4)
95
+ return data
96
+
84
97
  def to_json(self) -> Dict[str, Any]:
85
- return {"_type": get_full_class_name(self.__class__)}
98
+ return {"_type": get_full_class_name(self.__class__), **self._to_json()}
99
+
100
+ @abstractmethod
101
+ def _to_json(self) -> Dict[str, Any]:
102
+ """
103
+ Create a json dict from the object.
104
+ """
105
+ pass
86
106
 
87
107
  @classmethod
88
108
  @abstractmethod
@@ -97,6 +117,19 @@ class SubclassJSONSerializer:
97
117
  """
98
118
  raise NotImplementedError()
99
119
 
120
+ @classmethod
121
+ def from_json_file(cls, filename: str):
122
+ """
123
+ Create an instance of the subclass from the data in the given json file.
124
+
125
+ :param filename: The filename of the json file.
126
+ """
127
+ if not filename.endswith(".json"):
128
+ filename += ".json"
129
+ with open(filename, "r") as f:
130
+ scrdr_json = json.load(f)
131
+ return cls.from_json(scrdr_json)
132
+
100
133
  @classmethod
101
134
  def from_json(cls, data: Dict[str, Any]) -> Self:
102
135
  """
@@ -117,6 +150,9 @@ class SubclassJSONSerializer:
117
150
 
118
151
  raise ValueError("Unknown type {}".format(data["_type"]))
119
152
 
153
+ save = to_json_file
154
+ load = from_json_file
155
+
120
156
 
121
157
  def copy_case(case: Union[Case, SQLTable]) -> Union[Case, SQLTable]:
122
158
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -18,9 +18,6 @@ src/ripple_down_rules/datastructures/callable_expression.py
18
18
  src/ripple_down_rules/datastructures/dataclasses.py
19
19
  src/ripple_down_rules/datastructures/enums.py
20
20
  src/ripple_down_rules/datastructures/table.py
21
- src/ripple_down_rules/datastructures/generated/__init__.py
22
- src/ripple_down_rules/datastructures/generated/column/__init__.py
23
- src/ripple_down_rules/datastructures/generated/row/__init__.py
24
21
  test/test_json_serialization.py
25
22
  test/test_rdr.py
26
23
  test/test_rdr_alchemy.py
@@ -19,17 +19,11 @@ class TestJSONSerialization(TestCase):
19
19
  def setUpClass(cls):
20
20
  cls.all_cases, cls.targets = load_zoo_dataset(cls.cache_dir + "/zoo_dataset.pkl")
21
21
 
22
- def test_json_serialization(self):
22
+ def test_scrdr_json_serialization(self):
23
23
  scrdr = self.get_fit_scrdr()
24
- scrdr_json = scrdr.to_json()
25
- # save the json to a file
26
- with open(f"{self.cache_dir}/scrdr.json", "w") as f:
27
- json.dump(scrdr_json, f, indent=4)
28
-
29
- # load the json from the file
30
- with open(f"{self.cache_dir}/scrdr.json", "r") as f:
31
- scrdr_json = json.load(f)
32
- scrdr = SingleClassRDR.from_json(scrdr_json)
24
+ filename = f"{self.cache_dir}/scrdr.json"
25
+ scrdr.save(filename)
26
+ scrdr = SingleClassRDR.load(filename)
33
27
  for case, target in zip(self.all_cases, self.targets):
34
28
  cat = scrdr.classify(case)
35
29
  self.assertEqual(cat, target)
@@ -3,7 +3,7 @@ from unittest import TestCase, skip
3
3
 
4
4
  from typing_extensions import List
5
5
 
6
- from ripple_down_rules.datasets import Habitat, SpeciesCol as Species
6
+ from ripple_down_rules.datasets import Habitat, Species
7
7
  from ripple_down_rules.datasets import load_zoo_dataset
8
8
  from ripple_down_rules.datastructures import Case, MCRDRMode, \
9
9
  Row, Column, Category, CaseQuery
@@ -28,10 +28,6 @@ class TestRDR(TestCase):
28
28
  if not os.path.exists(test_dir):
29
29
  os.makedirs(test_dir)
30
30
 
31
- def tearDown(self):
32
- Row._registry = {}
33
- Column._registry = {}
34
-
35
31
  def test_classify_scrdr(self):
36
32
  use_loaded_answers = True
37
33
  save_answers = False
@@ -5,6 +5,7 @@ import sqlalchemy.orm
5
5
  from sqlalchemy import select
6
6
  from sqlalchemy.orm import MappedColumn as Column
7
7
  from typing_extensions import List, Sequence
8
+ import pandas as pd
8
9
 
9
10
  from ripple_down_rules.datasets import Base, Animal, Species, get_dataset, Habitat, HabitatTable
10
11
  from ripple_down_rules.datastructures import CaseQuery
@@ -23,26 +24,35 @@ class TestAlchemyRDR(TestCase):
23
24
 
24
25
  @classmethod
25
26
  def setUpClass(cls):
27
+ # load dataset
26
28
  zoo = get_dataset(111, cls.cache_file)
27
29
 
28
- # data (as pandas dataframes)
30
+ # get data and targets (as pandas dataframes)
29
31
  X = zoo['features']
30
32
  y = zoo['targets']
31
33
  names = zoo['ids'].values.flatten()
32
34
  X.loc[:, "name"] = names
33
35
 
34
- engine = sqlalchemy.create_engine("sqlite:///:memory:")
35
- Base.metadata.create_all(engine)
36
- session = sqlalchemy.orm.Session(engine)
37
- session.bulk_insert_mappings(Animal, X.to_dict(orient="records"))
38
- session.commit()
39
- cls.session = session
36
+ cls._init_session_and_insert_data(X)
37
+
38
+ # init cases
40
39
  query = select(Animal)
41
40
  cls.all_cases = cls.session.scalars(query).all()
41
+
42
+ # init targets
42
43
  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
43
44
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
44
45
  cls.targets = [Species(category_id_to_name[i]) for i in y.values.flatten()]
45
46
 
47
+ @classmethod
48
+ def _init_session_and_insert_data(cls, data: pd.DataFrame):
49
+ engine = sqlalchemy.create_engine("sqlite:///:memory:")
50
+ Base.metadata.create_all(engine)
51
+ session = sqlalchemy.orm.Session(engine)
52
+ session.bulk_insert_mappings(Animal, data.to_dict(orient="records"))
53
+ session.commit()
54
+ cls.session = session
55
+
46
56
  def test_fit_scrdr(self):
47
57
  use_loaded_answers = True
48
58
  draw_tree = False