ripple-down-rules 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
ripple_down_rules/datasets.py CHANGED
@@ -9,7 +9,7 @@ from sqlalchemy.orm import MappedAsDataclass, Mapped, mapped_column, relationshi
  from typing_extensions import Tuple, List, Set, Optional
  from ucimlrepo import fetch_ucirepo

- from .datastructures import Case, create_rows_from_dataframe, Category, Column
+ from .datastructures import Case, create_cases_from_dataframe, Category, CaseAttribute


  def load_cached_dataset(cache_file):
@@ -77,7 +77,7 @@ def load_zoo_dataset(cache_file: Optional[str] = None) -> Tuple[List[Case], List
  y = zoo['targets']
  # get ids as list of strings
  ids = zoo['ids'].values.flatten()
- all_cases = create_rows_from_dataframe(X, "Animal")
+ all_cases = create_cases_from_dataframe(X)

  category_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "molusc"]
  category_id_to_name = {i + 1: name for i, name in enumerate(category_names)}
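For orientation, here is a minimal usage sketch of the renamed helper; the toy DataFrame and variable names are illustrative and not taken from the package:

    import pandas as pd
    from ripple_down_rules.datastructures import create_cases_from_dataframe

    # 0.0.7 took an extra name argument (create_rows_from_dataframe(X, "Animal"));
    # 0.0.9 builds one Case per DataFrame row from the frame alone.
    df = pd.DataFrame({"hair": [1, 0], "legs": [4, 2]})
    cases = create_cases_from_dataframe(df)
    print(cases[0]["hair"])  # case attributes are keyed by column name -> 1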
ripple_down_rules/datastructures/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from .enums import *
  from .dataclasses import *
  from .callable_expression import *
- from .table import *
+ from .case import *
ripple_down_rules/datastructures/callable_expression.py CHANGED
@@ -7,7 +7,7 @@ from _ast import AST
  from sqlalchemy.orm import Session
  from typing_extensions import Type, Optional, Any, List, Union, Tuple, Dict, Set

- from .table import create_row, Row
+ from .case import create_case, Case
  from ..utils import SubclassJSONSerializer, get_full_class_name, get_type_from_string


@@ -128,8 +128,8 @@ class CallableExpression(SubclassJSONSerializer):

  def __call__(self, case: Any, **kwargs) -> Any:
  try:
- if not isinstance(case, Row):
- case = create_row(case, max_recursion_idx=3)
+ if not isinstance(case, Case):
+ case = create_case(case, max_recursion_idx=3)
  output = eval(self.code)
  if self.conclusion_type:
  assert isinstance(output, self.conclusion_type), (f"Expected output type {self.conclusion_type},"
ripple_down_rules/datastructures/table.py → ripple_down_rules/datastructures/case.py RENAMED
@@ -1,44 +1,40 @@
  from __future__ import annotations

- import os
- import time
- from abc import ABC
  from collections import UserDict
- from copy import deepcopy, copy
  from dataclasses import dataclass
  from enum import Enum

  from pandas import DataFrame
  from sqlalchemy import MetaData
  from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
- from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING, Tuple, Self
+ from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING

- from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, make_list, \
- SubclassJSONSerializer
+ from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, SubclassJSONSerializer

  if TYPE_CHECKING:
  from ripple_down_rules.rules import Rule
  from .callable_expression import CallableExpression


- class Row(UserDict, SubclassJSONSerializer):
+ class Case(UserDict, SubclassJSONSerializer):
  """
  A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
  the names of the attributes and the values are the attributes. All are stored in lower case.
  """

- def __init__(self, id_: Optional[Hashable] = None, **kwargs):
+ def __init__(self, _id: Optional[Hashable] = None, _type: Optional[Type] = None, **kwargs):
  """
  Create a new row.

- :param id_: The id of the row.
+ :param _id: The id of the row.
  :param kwargs: The attributes of the row.
  """
  super().__init__(kwargs)
- self.id_ = id_ if id_ else id(self)
+ self._id = _id if _id else id(self)
+ self._type = _type

  @classmethod
- def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Row:
+ def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Case:
  """
  Create a row from an object.

@@ -47,7 +43,7 @@ class Row(UserDict, SubclassJSONSerializer):
  :param obj_name: The name of the object.
  :return: The row of the object.
  """
- return create_row(obj, max_recursion_idx=max_recursion_idx, obj_name=obj_name)
+ return create_case(obj, max_recursion_idx=max_recursion_idx, obj_name=obj_name)

  def __getitem__(self, item: str) -> Any:
  return super().__getitem__(item.lower())
@@ -74,30 +70,22 @@ class Row(UserDict, SubclassJSONSerializer):
  def __delitem__(self, key):
  super().__delitem__(key.lower())

- def __eq__(self, other):
- if not isinstance(other, (Row, dict, UserDict)):
- return False
- elif isinstance(other, (dict, UserDict)):
- return super().__eq__(Row(other))
- else:
- return super().__eq__(other)
-
  def __hash__(self):
- return self.id_
+ return self._id

  def _to_json(self) -> Dict[str, Any]:
  serializable = {k: v for k, v in self.items() if not k.startswith("_")}
- serializable["_id"] = self.id_
+ serializable["_id"] = self._id
  return {k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in serializable.items()}

  @classmethod
- def _from_json(cls, data: Dict[str, Any]) -> Row:
+ def _from_json(cls, data: Dict[str, Any]) -> Case:
  id_ = data.pop("_id")
- return cls(id_=id_, **data)
+ return cls(_id=id_, **data)


  @dataclass
- class ColumnValue(SubclassJSONSerializer):
+ class CaseAttributeValue(SubclassJSONSerializer):
  """
  A column value is a value in a column.
  """
@@ -111,7 +99,7 @@ class ColumnValue(SubclassJSONSerializer):
  """

  def __eq__(self, other):
- if not isinstance(other, ColumnValue):
+ if not isinstance(other, CaseAttributeValue):
  return False
  return self.value == other.value

@@ -122,58 +110,58 @@ class ColumnValue(SubclassJSONSerializer):
  return {"id": self.id, "value": self.value}

  @classmethod
- def _from_json(cls, data: Dict[str, Any]) -> ColumnValue:
+ def _from_json(cls, data: Dict[str, Any]) -> CaseAttributeValue:
  return cls(id=data["id"], value=data["value"])


- class Column(set, SubclassJSONSerializer):
+ class CaseAttribute(set, SubclassJSONSerializer):
  nullable: bool = True
  """
- A boolean indicating whether the column can be None or not.
+ A boolean indicating whether the case attribute can be None or not.
  """
  mutually_exclusive: bool = False
  """
- A boolean indicating whether the column is mutually exclusive or not. (i.e. can only have one value)
+ A boolean indicating whether the case attribute is mutually exclusive or not. (i.e. can only have one value)
  """

- def __init__(self, values: Set[ColumnValue]):
+ def __init__(self, values: Set[CaseAttributeValue]):
  """
- Create a new column.
+ Create a new case attribute.

- :param values: The values of the column.
+ :param values: The values of the case attribute.
  """
- values = self._type_cast_values_to_set_of_column_values(values)
- self.id_value_map: Dict[Hashable, Union[ColumnValue, Set[ColumnValue]]] = {id(v): v for v in values}
+ values = self._type_cast_values_to_set_of_case_attribute_values(values)
+ self.id_value_map: Dict[Hashable, Union[CaseAttributeValue, Set[CaseAttributeValue]]] = {id(v): v for v in values}
  super().__init__([v.value for v in values])

  @staticmethod
- def _type_cast_values_to_set_of_column_values(values: Set[Any]) -> Set[ColumnValue]:
+ def _type_cast_values_to_set_of_case_attribute_values(values: Set[Any]) -> Set[CaseAttributeValue]:
  """
- Type cast values to a set of column values.
+ Type cast values to a set of case attribute values.

  :param values: The values to type cast.
  """
  values = make_set(values)
- if len(values) > 0 and not isinstance(next(iter(values)), ColumnValue):
- values = {ColumnValue(id(values), v) for v in values}
+ if len(values) > 0 and not isinstance(next(iter(values)), CaseAttributeValue):
+ values = {CaseAttributeValue(id(values), v) for v in values}
  return values

  @classmethod
- def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> Column:
+ def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> CaseAttribute:
  id_ = id(row_obj) if row_obj else id(values)
  values = make_set(values)
- return cls({ColumnValue(id_, v) for v in values})
+ return cls({CaseAttributeValue(id_, v) for v in values})

  @property
  def as_dict(self) -> Dict[str, Any]:
  """
- Get the column as a dictionary.
+ Get the case attribute as a dictionary.

- :return: The column as a dictionary.
+ :return: The case attribute as a dictionary.
  """
  return {self.__class__.__name__: self}

- def filter_by(self, condition: CallableExpression) -> Column:
+ def filter_by(self, condition: CallableExpression) -> CaseAttribute:
  """
  Filter the column by a condition.

@@ -200,118 +188,114 @@ class Column(set, SubclassJSONSerializer):
  for id_, v in self.id_value_map.items()}

  @classmethod
- def _from_json(cls, data: Dict[str, Any]) -> Column:
- return cls({ColumnValue.from_json(v) for id_, v in data.items()})
+ def _from_json(cls, data: Dict[str, Any]) -> CaseAttribute:
+ return cls({CaseAttributeValue.from_json(v) for id_, v in data.items()})


- def create_rows_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Row]:
+ def create_cases_from_dataframe(df: DataFrame) -> List[Case]:
  """
- Create a row from a pandas DataFrame.
+ Create cases from a pandas DataFrame.

- :param df: The DataFrame to create a row from.
- :param name: The name of the DataFrame.
- :return: The row of the DataFrame.
+ :param df: The DataFrame to create cases from.
+ :return: The cases of the DataFrame.
  """
- rows = []
- col_names = list(df.columns)
- for row_id, row in df.iterrows():
- row = {col_name: row[col_name].item() for col_name in col_names}
- rows.append(Row(id_=row_id, **row))
- return rows
+ cases = []
+ attribute_names = list(df.columns)
+ for row_id, case in df.iterrows():
+ case = {col_name: case[col_name].item() for col_name in attribute_names}
+ cases.append(Case(_id=row_id, _type=DataFrame, **case))
+ return cases


- def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
- obj_name: Optional[str] = None, parent_is_iterable: bool = False) -> Row:
+ def create_case(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
+ obj_name: Optional[str] = None, parent_is_iterable: bool = False) -> Case:
  """
- Create a table from an object.
+ Create a case from an object.

- :param obj: The object to create a table from.
+ :param obj: The object to create a case from.
  :param recursion_idx: The current recursion index.
  :param max_recursion_idx: The maximum recursion index to prevent infinite recursion.
  :param obj_name: The name of the object.
  :param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
- :return: The table of the object.
+ :return: The case that represents the object.
  """
- if isinstance(obj, Row):
+ if isinstance(obj, Case):
  return obj
  if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
  or (obj.__class__ in [MetaData, registry])):
- return Row(id_=id(obj), **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
- row = Row(id_=id(obj))
+ return Case(_id=id(obj), _type=obj.__class__,
+ **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
+ case = Case(_id=id(obj), _type=obj.__class__)
  for attr in dir(obj):
  if attr.startswith("_") or callable(getattr(obj, attr)):
  continue
  attr_value = getattr(obj, attr)
- row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
- max_recursion_idx, parent_is_iterable, row)
- return row
+ case = create_or_update_case_from_attribute(attr_value, attr, obj, attr, recursion_idx,
+ max_recursion_idx, parent_is_iterable, case)
+ return case


- def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
- recursion_idx: int = 0, max_recursion_idx: int = 1,
- parent_is_iterable: bool = False,
- row: Optional[Row] = None) -> Row:
+ def create_or_update_case_from_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
+ recursion_idx: int = 0, max_recursion_idx: int = 1,
+ parent_is_iterable: bool = False,
+ case: Optional[Case] = None) -> Case:
  """
- Get a reference column and its table.
+ Create or update a case from an attribute of the object that the case represents.

- :param attr_value: The attribute value to get the column and table from.
+ :param attr_value: The attribute value.
  :param name: The name of the attribute.
  :param obj: The parent object of the attribute.
  :param obj_name: The parent object name.
  :param recursion_idx: The recursion index to prevent infinite recursion.
  :param max_recursion_idx: The maximum recursion index.
  :param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
- :param row: The row to update.
- :return: A reference column and its table.
+ :param case: The case to update.
+ :return: The updated/created case.
  """
- if row is None:
- row = Row(id_=id(obj))
+ if case is None:
+ case = Case(_id=id(obj), _type=obj.__class__)
  if isinstance(attr_value, (dict, UserDict)):
- row.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
+ case.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
  if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
- column, attr_row = create_column_and_row_from_iterable_attribute(attr_value, name, obj, obj_name,
- recursion_idx=recursion_idx + 1,
- max_recursion_idx=max_recursion_idx)
- row[obj_name] = column
+ column = create_case_attribute_from_iterable_attribute(attr_value, name, obj, obj_name,
+ recursion_idx=recursion_idx + 1,
+ max_recursion_idx=max_recursion_idx)
+ case[obj_name] = column
  else:
- row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
- return row
+ case[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
+ return case


- def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
+ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
  recursion_idx: int = 0,
- max_recursion_idx: int = 1) -> Tuple[Column, Row]:
+ max_recursion_idx: int = 1) -> CaseAttribute:
  """
- Get a table from an iterable.
+ Get a case attribute from an iterable attribute.

- :param attr_value: The iterable attribute to get the table from.
- :param name: The name of the table.
+ :param attr_value: The iterable attribute to get the case from.
+ :param name: The name of the case.
  :param obj: The parent object of the iterable.
  :param obj_name: The parent object name.
  :param recursion_idx: The recursion index to prevent infinite recursion.
  :param max_recursion_idx: The maximum recursion index.
- :return: A table of the iterable.
+ :return: A case attribute that represents the original iterable attribute.
  """
  values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
- range_ = {type(list(values)[0])} if len(values) > 0 else set()
- if len(range_) == 0:
- range_ = make_set(get_value_type_from_type_hint(name, obj))
- if not range_:
- raise ValueError(f"Could not determine the range of {name} in {obj}.")
- attr_row = Row(id_=id(attr_value))
- column = Column.from_obj(values, row_obj=obj)
+ _type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
+ attr_case = Case(_id=id(attr_value), _type=_type)
+ case_attr = CaseAttribute.from_obj(values, row_obj=obj)
  for idx, val in enumerate(values):
- sub_attr_row = create_row(val, recursion_idx=recursion_idx,
- max_recursion_idx=max_recursion_idx,
- obj_name=obj_name, parent_is_iterable=True)
- attr_row.update(sub_attr_row)
- for sub_attr, val in attr_row.items():
- setattr(column, sub_attr, val)
- return column, attr_row
+ sub_attr_case = create_case(val, recursion_idx=recursion_idx,
+ max_recursion_idx=max_recursion_idx,
+ obj_name=obj_name, parent_is_iterable=True)
+ attr_case.update(sub_attr_case)
+ for sub_attr, val in attr_case.items():
+ setattr(case_attr, sub_attr, val)
+ return case_attr


- def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[Column], List[SQLColumn]]] = None,
- current_conclusions: Optional[Union[List[Column], List[SQLColumn]]] = None,
+ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
+ current_conclusions: Optional[Union[List[CaseAttribute], List[SQLColumn]]] = None,
  last_evaluated_rule: Optional[Rule] = None) -> None:
  """
  Show the data of the new case and if last evaluated rule exists also show that of the corner case.
@@ -351,6 +335,3 @@ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[Column
  case_dict.update(targets)
  case_dict.update(current_conclusions)
  print(table_rows_as_str(case_dict))
-
-
- Case = Row
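To illustrate the Row → Case rename in this module, a hedged usage sketch; the attribute names and the Animal class are invented for the example:

    from ripple_down_rules.datastructures import Case, create_case

    # Case replaces Row; _id and _type replace the old id_ parameter.
    case = Case(_id=1, _type=dict, species="cat", legs=4)
    print(case["Species"])  # keys are stored and looked up lower-cased -> "cat"


    class Animal:
        def __init__(self):
            self.legs = 4


    # Arbitrary objects can still be wrapped, as CallableExpression.__call__ does:
    wrapped = create_case(Animal(), max_recursion_idx=3)
    print(wrapped["legs"])  # -> 4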
ripple_down_rules/datastructures/dataclasses.py CHANGED
@@ -5,7 +5,7 @@ from dataclasses import dataclass
  from sqlalchemy.orm import DeclarativeBase as SQLTable
  from typing_extensions import Any, Optional, Type

- from .table import create_row, Case
+ from .case import create_case, Case
  from ..utils import get_attribute_name, copy_case


@@ -50,7 +50,7 @@ class CaseQuery:
  self.attribute_name = attribute_name

  if not isinstance(case, (Case, SQLTable)):
- case = create_row(case, max_recursion_idx=3)
+ case = create_case(case, max_recursion_idx=3)
  self.case = case

  self.attribute = getattr(self.case, self.attribute_name) if self.attribute_name else None
ripple_down_rules/experts.py CHANGED
@@ -6,8 +6,8 @@ from abc import ABC, abstractmethod
  from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
  from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any

- from .datastructures import (Case, PromptFor, CallableExpression, Column, CaseQuery)
- from .datastructures.table import show_current_and_corner_cases
+ from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
+ from .datastructures.case import show_current_and_corner_cases
  from .prompt import prompt_user_for_expression, prompt_user_about_case
  from .utils import get_all_subclasses, is_iterable

@@ -30,13 +30,13 @@ class Expert(ABC):
  """
  A flag to indicate if the expert should use loaded answers or not.
  """
- known_categories: Optional[Dict[str, Type[Column]]] = None
+ known_categories: Optional[Dict[str, Type[CaseAttribute]]] = None
  """
  The known categories (i.e. Column types) to use.
  """

  @abstractmethod
- def ask_for_conditions(self, x: Case, targets: List[Column], last_evaluated_rule: Optional[Rule] = None) \
+ def ask_for_conditions(self, x: Case, targets: List[CaseAttribute], last_evaluated_rule: Optional[Rule] = None) \
  -> CallableExpression:
  """
  Ask the expert to provide the differentiating features between two cases or unique features for a case
@@ -50,8 +50,8 @@ class Expert(ABC):
  pass

  @abstractmethod
- def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[Column]) \
- -> Dict[Column, CallableExpression]:
+ def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[CaseAttribute]) \
+ -> Dict[CaseAttribute, CallableExpression]:
  """
  Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
  that category.
@@ -63,9 +63,9 @@ class Expert(ABC):
  pass

  @abstractmethod
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: Column,
- targets: Optional[List[Column]] = None,
- current_conclusions: Optional[List[Column]] = None) -> bool:
+ def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
+ targets: Optional[List[CaseAttribute]] = None,
+ current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
  """
  Ask the expert if the conclusion is correct.

@@ -125,14 +125,14 @@ class Human(Expert):
  self.all_expert_answers = json.load(f)

  def ask_for_conditions(self, case: Case,
- targets: Union[List[Column], List[SQLColumn]],
+ targets: Union[List[CaseAttribute], List[SQLColumn]],
  last_evaluated_rule: Optional[Rule] = None) \
  -> CallableExpression:
  if not self.use_loaded_answers:
  show_current_and_corner_cases(case, targets, last_evaluated_rule=last_evaluated_rule)
  return self._get_conditions(case, targets)

- def _get_conditions(self, case: Case, targets: List[Column]) \
+ def _get_conditions(self, case: Case, targets: List[CaseAttribute]) \
  -> CallableExpression:
  """
  Ask the expert to provide the differentiating features between two cases or unique features for a case
@@ -157,8 +157,8 @@ class Human(Expert):
  self.all_expert_answers.append(user_input)
  return condition

- def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[Column]) \
- -> Dict[Column, CallableExpression]:
+ def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
+ -> Dict[CaseAttribute, CallableExpression]:
  """
  Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
  that category.
@@ -198,7 +198,7 @@ class Human(Expert):
  self.all_expert_answers.append(expert_input)
  return expression

- def get_category_type(self, cat_name: str) -> Optional[Type[Column]]:
+ def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
  """
  Get the category type from the known categories.

@@ -206,8 +206,8 @@ class Human(Expert):
  :return: The category type.
  """
  cat_name = cat_name.lower()
- self.known_categories = get_all_subclasses(Column) if not self.known_categories else self.known_categories
- self.known_categories.update(Column.registry)
+ self.known_categories = get_all_subclasses(CaseAttribute) if not self.known_categories else self.known_categories
+ self.known_categories.update(CaseAttribute.registry)
  category_type = None
  if cat_name in self.known_categories:
  category_type = self.known_categories[cat_name]
@@ -222,9 +222,9 @@ class Human(Expert):
  question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
  return not self.ask_yes_no_question(question)

- def ask_if_conclusion_is_correct(self, x: Case, conclusion: Column,
- targets: Optional[List[Column]] = None,
- current_conclusions: Optional[List[Column]] = None) -> bool:
+ def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
+ targets: Optional[List[CaseAttribute]] = None,
+ current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
  """
  Ask the expert if the conclusion is correct.

ripple_down_rules/prompt.py CHANGED
@@ -7,7 +7,7 @@ from prompt_toolkit.completion import WordCompleter
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
  from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type

- from .datastructures import Case, PromptFor, CallableExpression, create_row, parse_string_to_expression
+ from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression


  def prompt_user_for_expression(case: Union[Case, SQLTable], prompt_for: PromptFor, target_name: str,
@@ -60,7 +60,7 @@ def get_completions(obj: Any) -> List[str]:
  # Define completer with all object attributes and comparison operators
  completions = ['==', '!=', '>', '<', '>=', '<=', 'in', 'not', 'and', 'or', 'is']
  completions += ["isinstance(", "issubclass(", "type(", "len(", "hasattr(", "getattr(", "setattr(", "delattr("]
- completions += list(create_row(obj).keys())
+ completions += list(create_case(obj).keys())
  return completions

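A small hedged sketch of the effect on prompt completions; the Robot class is invented for illustration:

    from ripple_down_rules.prompt import get_completions


    class Robot:
        def __init__(self):
            self.battery_level = 0.9


    # get_completions now wraps the object with create_case instead of create_row,
    # so the discovered attribute names are offered alongside the operator tokens.
    print(get_completions(Robot()))  # includes 'battery_level'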
 
ripple_down_rules/rdr.py CHANGED
@@ -8,7 +8,7 @@ from ordered_set import OrderedSet
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
  from typing_extensions import List, Optional, Dict, Type, Union, Any, Self, Tuple

- from .datastructures import Case, MCRDRMode, CallableExpression, Column, CaseQuery
+ from .datastructures import Case, MCRDRMode, CallableExpression, CaseAttribute, CaseQuery
  from .experts import Expert, Human
  from .rules import Rule, SingleClassRule, MultiClassTopRule
  from .utils import draw_tree, make_set, get_attribute_by_type, copy_case, \
@@ -23,7 +23,7 @@ class RippleDownRules(ABC):
  """
  The figure to draw the tree on.
  """
- expert_accepted_conclusions: Optional[List[Column]] = None
+ expert_accepted_conclusions: Optional[List[CaseAttribute]] = None
  """
  The conclusions that the expert has accepted, such that they are not asked again.
  """
@@ -37,11 +37,11 @@ class RippleDownRules(ABC):
  self.session = session
  self.fig: Optional[plt.Figure] = None

- def __call__(self, case: Union[Case, SQLTable]) -> Column:
+ def __call__(self, case: Union[Case, SQLTable]) -> CaseAttribute:
  return self.classify(case)

  @abstractmethod
- def classify(self, case: Union[Case, SQLTable]) -> Optional[Column]:
+ def classify(self, case: Union[Case, SQLTable]) -> Optional[CaseAttribute]:
  """
  Classify a case.

@@ -52,7 +52,7 @@ class RippleDownRules(ABC):

  @abstractmethod
  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
- -> Union[Column, CallableExpression]:
+ -> Union[CaseAttribute, CallableExpression]:
  """
  Fit the RDR on a case, and ask the expert for refinements or alternatives if the classification is incorrect by
  comparing the case with the target category.
@@ -118,7 +118,7 @@ class RippleDownRules(ABC):
  plt.show()

  @staticmethod
- def calculate_precision_and_recall(pred_cat: List[Column], target: List[Column]) -> Tuple[List[bool], List[bool]]:
+ def calculate_precision_and_recall(pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> Tuple[List[bool], List[bool]]:
  """
  :param pred_cat: The predicted category.
  :param target: The target category.
@@ -128,12 +128,10 @@ class RippleDownRules(ABC):
  target = target if is_iterable(target) else [target]
  recall = [not yi or (yi in pred_cat) for yi in target]
  target_types = [type(yi) for yi in target]
- if len(pred_cat) > 1:
- print(pred_cat)
  precision = [(pred in target) or (type(pred) not in target_types) for pred in pred_cat]
  return precision, recall

- def is_matching(self, pred_cat: List[Column], target: List[Column]) -> bool:
+ def is_matching(self, pred_cat: List[CaseAttribute], target: List[CaseAttribute]) -> bool:
  """
  :param pred_cat: The predicted category.
  :param target: The target category.
@@ -181,7 +179,7 @@ RDR = RippleDownRules
  class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):

  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
- -> Union[Column, CallableExpression]:
+ -> Union[CaseAttribute, CallableExpression]:
  """
  Classify a case, and ask the user for refinements or alternatives if the classification is incorrect by
  comparing the case with the target category if provided.
@@ -209,7 +207,7 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):

  return self.classify(case)

- def classify(self, case: Case) -> Optional[Column]:
+ def classify(self, case: Case) -> Optional[CaseAttribute]:
  """
  Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
  """
@@ -227,18 +225,22 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
  """
  Write the tree of rules as source code to a file.
  """
- case_type = self.start_rule.corner_case.__class__.__name__
- case_module = self.start_rule.corner_case.__class__.__module__
+ case = self.start_rule.corner_case
+ if isinstance(case, Case):
+ case_type = case._type
+ else:
+ case_type = type(case)
+ case_module = case_type.__module__
  conclusion = self.start_rule.conclusion
  if isinstance(conclusion, CallableExpression):
  conclusion_types = [conclusion.conclusion_type]
- elif isinstance(conclusion, Column):
+ elif isinstance(conclusion, CaseAttribute):
  conclusion_types = list(conclusion._value_range)
  else:
  conclusion_types = [type(conclusion)]
  imports = ""
  if case_module != "builtins":
- imports += f"from {case_module} import {case_type}\n"
+ imports += f"from {case_module} import {case_type.__name__}\n"
  if len(conclusion_types) > 1:
  conclusion_name = "Union[" + ", ".join([c.__name__ for c in conclusion_types]) + "]"
  else:
@@ -246,11 +248,14 @@ class SingleClassRDR(RippleDownRules, SubclassJSONSerializer):
  for conclusion_type in conclusion_types:
  if conclusion_type.__module__ != "builtins":
  imports += f"from {conclusion_type.__module__} import {conclusion_name}\n"
+ imports += "from ripple_down_rules.datastructures import Case, create_case\n"
  imports += "\n\n"
- func_def = f"def classify_{conclusion_name.lower()}(case: {case_type}) -> {conclusion_name}:\n"
+ func_def = f"def classify_{conclusion_name.lower()}(case: {case_type.__name__}) -> {conclusion_name}:\n"
  with open(filename, "w") as f:
  f.write(imports)
  f.write(func_def)
+ f.write(f"{' '*4}if not isinstance(case, Case):\n"
+ f"{' '*4} case = create_case(case, recursion_idx=3)\n""")
  self.write_rules_as_source_code_to_file(self.start_rule, f, " " * 4)

  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file, parent_indent: str = ""):
@@ -289,7 +294,7 @@ class MultiClassRDR(RippleDownRules):
  """
  The evaluated rules in the classifier for one case.
  """
- conclusions: Optional[List[Column]] = None
+ conclusions: Optional[List[CaseAttribute]] = None
  """
  The conclusions that the case belongs to.
  """
@@ -322,7 +327,7 @@ class MultiClassRDR(RippleDownRules):
  return self.conclusions

  def fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None,
- add_extra_conclusions: bool = False) -> List[Union[Column, CallableExpression]]:
+ add_extra_conclusions: bool = False) -> List[Union[CaseAttribute, CallableExpression]]:
  """
  Classify a case, and ask the user for stopping rules or classifying rules if the classification is incorrect
  or missing by comparing the case with the target category if provided.
@@ -450,7 +455,7 @@ class MultiClassRDR(RippleDownRules):
  :param target: The target category to compare the conclusion with.
  :return: Whether the conclusion is of the same class as the target category but has a different value.
  """
- return conclusion.__class__ == target.__class__ and target.__class__ != Column
+ return conclusion.__class__ == target.__class__ and target.__class__ != CaseAttribute

  def conclusion_is_correct(self, case: Union[Case, SQLTable], target: Any, expert: Expert, evaluated_rule: Rule,
  add_extra_conclusions: bool) -> bool:
@@ -599,7 +604,7 @@ class GeneralRDR(RippleDownRules):
  return conclusions

  def fit_case(self, case_queries: List[CaseQuery], expert: Optional[Expert] = None, **kwargs) \
- -> List[Union[Column, CallableExpression]]:
+ -> List[Union[CaseAttribute, CallableExpression]]:
  """
  Fit the GRDR on a case, if the target is a new type of category, a new RDR is created for it,
  else the existing RDR of that type will be fitted on the case, and then classification is done and all
@@ -626,7 +631,7 @@ class GeneralRDR(RippleDownRules):
  if not target:
  target = expert.ask_for_conclusion(case_query)
  case_query_cp = CaseQuery(case_cp, attribute_name=case_query.attribute_name, target=target)
- if is_iterable(target) and not isinstance(target, Column):
+ if is_iterable(target) and not isinstance(target, CaseAttribute):
  target_type = type(make_list(target)[0])
  assert all([type(t) is target_type for t in target]), ("All targets of a case query must be of the same"
  " type")
@@ -661,7 +666,7 @@ class GeneralRDR(RippleDownRules):
  return MultiClassRDR()
  else:
  return SingleClassRDR()
- elif isinstance(attribute, Column):
+ elif isinstance(attribute, CaseAttribute):
  return SingleClassRDR() if attribute.mutually_exclusive else MultiClassRDR()
  else:
  return MultiClassRDR() if is_iterable(attribute) else SingleClassRDR()
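As a rough illustration of the updated code generation (this is not output actually produced by the package; the Animal type and the rule body are invented placeholders), the emitted classifier module now guards its input with Case/create_case before the generated rule conditions:

    from ripple_down_rules.datastructures import Case, create_case


    class Animal:  # stands in for the corner-case type the real generator would import
        hair: bool = True


    def classify_str(case: Animal) -> str:
        # mirrors the guard string the updated generator writes
        if not isinstance(case, Case):
            case = create_case(case, recursion_idx=3)
        # generated rule conditions would follow here, for example:
        if case["hair"]:
            return "mammal"


    print(classify_str(Case(_id=1, hair=True)))  # -> mammal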
ripple_down_rules/utils.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations

  import importlib
+ import json
  import logging
  from abc import abstractmethod
  from collections import UserDict
@@ -81,6 +82,18 @@ class SubclassJSONSerializer:
  'from_json' method.
  """

+ def to_json_file(self, filename: str):
+ """
+ Save the object to a json file.
+ """
+ data = self.to_json()
+ # save the json to a file
+ if not filename.endswith(".json"):
+ filename += ".json"
+ with open(filename, "w") as f:
+ json.dump(data, f, indent=4)
+ return data
+
  def to_json(self) -> Dict[str, Any]:
  return {"_type": get_full_class_name(self.__class__), **self._to_json()}

@@ -104,6 +117,19 @@ class SubclassJSONSerializer:
  """
  raise NotImplementedError()

+ @classmethod
+ def from_json_file(cls, filename: str):
+ """
+ Create an instance of the subclass from the data in the given json file.
+
+ :param filename: The filename of the json file.
+ """
+ if not filename.endswith(".json"):
+ filename += ".json"
+ with open(filename, "r") as f:
+ scrdr_json = json.load(f)
+ return cls.from_json(scrdr_json)
+
  @classmethod
  def from_json(cls, data: Dict[str, Any]) -> Self:
  """
@@ -124,6 +150,9 @@ class SubclassJSONSerializer:

  raise ValueError("Unknown type {}".format(data["_type"]))

+ save = to_json_file
+ load = from_json_file
+

  def copy_case(case: Union[Case, SQLTable]) -> Union[Case, SQLTable]:
  """
ripple_down_rules-0.0.7.dist-info/METADATA → ripple_down_rules-0.0.9.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ripple_down_rules
- Version: 0.0.7
+ Version: 0.0.9
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
  License: GNU GENERAL PUBLIC LICENSE
ripple_down_rules-0.0.9.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ripple_down_rules/datasets.py,sha256=qYX7IF7ACm0VRbaKfEgQ32j0YbUUyt2GfGU5Lo42CqI,4601
+ ripple_down_rules/experts.py,sha256=wg1uY0ox9dUMR4s1RdGjzpX1_WUqnCa060r1U9lrKYI,11214
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
+ ripple_down_rules/prompt.py,sha256=QAmxg4ssrGUAlK7lbyKw2nuRczTZColpjc9uMC1ts3I,4210
+ ripple_down_rules/rdr.py,sha256=IZNoZm6w_WPTlLkHEQmjWcDLZnyczgSK0zX5OFYl0Ys,34572
+ ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
+ ripple_down_rules/utils.py,sha256=pmjIPjF9wq6dmCHCoUKFRsJmSaI0b-hRLn375WHgWgQ,18010
+ ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
+ ripple_down_rules/datastructures/callable_expression.py,sha256=TN6bi4VYjyLlSLTEA3dRo5ENfEdQYc8Fjj5nbnsz-C0,9058
+ ripple_down_rules/datastructures/case.py,sha256=mcBVu1IGNpLGyXb152oLrsv8UbNncIhPtjmSTUgJ7uc,13593
+ ripple_down_rules/datastructures/dataclasses.py,sha256=EVQ1jBKW7K7q7_JNgikHX9fm3EmQQKA74sNjEQ4rXn8,2449
+ ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
+ ripple_down_rules-0.0.9.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
+ ripple_down_rules-0.0.9.dist-info/METADATA,sha256=7ACb5X17EumZ1W19aIvx12FcjuovaQvmW4BE4Iccz7o,42518
+ ripple_down_rules-0.0.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ ripple_down_rules-0.0.9.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
+ ripple_down_rules-0.0.9.dist-info/RECORD,,
ripple_down_rules-0.0.7.dist-info/RECORD DELETED
@@ -1,18 +0,0 @@
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ripple_down_rules/datasets.py,sha256=p7WmLiAKcYtzLq87RJhzU_BHCdM-zKK71yldjyYNpUE,4602
- ripple_down_rules/experts.py,sha256=ajH46NXM6IREtWr2a-ZBxdKl4CFzSx5S7yktnyXgwnc,11089
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
- ripple_down_rules/prompt.py,sha256=lmREZRyleBTHrVtcf2j_48oc0v3VlxXYGhl6w1mk8qI,4208
- ripple_down_rules/rdr.py,sha256=dS_OqaGtP1PAKEe0h-dMy7JI9ycNuH_EIbHlh-Nz4Sg,34189
- ripple_down_rules/rules.py,sha256=R9wIyalXBc6fGogzy92OxsyKAtJdQwdso_Vs1evWZGU,10274
- ripple_down_rules/utils.py,sha256=VjpdJ5W-xK-TQttfpNIHRKpYK91vYjtQ22kgccQgVTQ,17183
- ripple_down_rules/datastructures/__init__.py,sha256=wY9WqXavuE3wQ1YP65cs_SZyr7CEMB9tol-4oxgK9CM,104
- ripple_down_rules/datastructures/callable_expression.py,sha256=lF8ag3-8oYuETAUohtRfviOpypSMQR-JEJGgR0Zhv3E,9055
- ripple_down_rules/datastructures/dataclasses.py,sha256=FW3MMhGXTPa0XwIyLzGalrPwiltNeUqbGIawCAKDHGk,2448
- ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
- ripple_down_rules/datastructures/table.py,sha256=G2ShfPmyiwfDXS-JY_jFD6nOx7OLuMr-GP--OCDGKYc,13733
- ripple_down_rules-0.0.7.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
- ripple_down_rules-0.0.7.dist-info/METADATA,sha256=fBb_rMvpmfAueLPCXvNs3IXLNX-COunhRnEq7iYkV5Y,42518
- ripple_down_rules-0.0.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- ripple_down_rules-0.0.7.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
- ripple_down_rules-0.0.7.dist-info/RECORD,,