ripple-down-rules 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,544 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from abc import ABC
6
+ from collections import UserDict
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+
10
+ from pandas import DataFrame
11
+ from sqlalchemy import MetaData
12
+ from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, registry
13
+ from typing_extensions import Any, Optional, Dict, Type, Set, Hashable, Union, List, TYPE_CHECKING, Tuple, Self
14
+
15
+ from ..utils import make_set, row_to_dict, table_rows_as_str, get_value_type_from_type_hint, make_list, \
16
+ SubclassJSONSerializer
17
+
18
+ if TYPE_CHECKING:
19
+ from ripple_down_rules.rules import Rule
20
+ from .callable_expression import CallableExpression
21
+
22
+
23
+ class SubClassFactory:
24
+ """
25
+ A factory mixin for dynamically created subclasses. Each generated subclass behaves like a table: the class is the
26
+ table, its attributes are the columns, and its values are the rows.
27
+ """
28
+ _value_range: set
29
+ """
30
+ The range of the attribute; this can be a set of possible values or a set of allowed value types (e.g. int, float).
31
+ """
32
+ _registry: Dict[Tuple[str, type], Type[SubClassFactory]] = {}
33
+ """
34
+ A dictionary of all dynamically created subclasses of this class.
35
+ """
36
+ _generated_classes_dir: str = os.path.dirname(os.path.abspath(__file__)) + "/generated"
37
+
38
+ @classmethod
39
+ def create(cls, name: str, range_: set, class_attributes: Optional[Dict[str, Any]] = None,
40
+ default_values: bool = True,
41
+ attributes_type_hints: Optional[Dict[str, Type]] = None) -> Type[SubClassFactory]:
42
+ """
43
+ Create a new subclass.
44
+
45
+ :param name: The name of the subclass.
46
+ :param range_: The range of the subclass values.
47
+ :param class_attributes: The attributes of the new subclass.
48
+ :param default_values: Boolean indicating whether to add default values to the subclass attributes or not.
49
+ :param attributes_type_hints: The type hints of the subclass attributes.
50
+ :return: The new subclass.
51
+ """
52
+ existing_class = cls._get_and_update_subclass(name, range_)
53
+ if existing_class:
54
+ return existing_class
55
+
56
+ new_attribute_type = cls._create_class_in_new_python_file_and_import_it(name, range_, default_values,
57
+ class_attributes, attributes_type_hints)
58
+
59
+ cls.register(new_attribute_type)
60
+
61
+ return new_attribute_type
62
+
63
+ @classmethod
64
+ def _create_class_in_new_python_file_and_import_it(cls, name: str, range_: set, default_values: bool = True,
65
+ class_attributes: Optional[Dict[str, Any]] = None,
66
+ attributes_type_hints: Optional[Dict[str, Type]] = None)\
67
+ -> Type[SubClassFactory]:
68
+ def get_type_import(value_type: Any) -> Tuple[str, str]:
69
+ if value_type is type(None):
70
+ return "from types import NoneType\n", "NoneType"
71
+ elif value_type.__module__ != "builtins":
72
+ value_type_alias = f"{value_type.__name__}_"
73
+ return f"from {value_type.__module__} import {value_type.__name__} as {value_type_alias}\n", value_type_alias
74
+ else:
75
+ return "", value_type.__name__
76
+ attributes_type_hints = attributes_type_hints or {}
77
+ parent_class_alias = cls.__name__ + "_"
78
+ imports = f"from {cls.__module__} import {cls.__name__} as {parent_class_alias}\n"
79
+ class_code = f"class {name}({parent_class_alias}):\n"
80
+ class_attributes = {**(class_attributes or {}), "_value_range": range_}
81
+ for key, value in class_attributes.items():
82
+ if value is not None:
83
+ new_import, value_type_name = get_type_import(type(value))
84
+ elif key in attributes_type_hints:
85
+ new_import, value_type_name = get_type_import(attributes_type_hints[key])
86
+ else:
87
+ new_import, value_type_name = "from typing_extensions import Any", "Any"
88
+ imports += new_import
89
+ if isinstance(value, set):
90
+ value_names = []
91
+ for v in value:
92
+ if isinstance(v, type):
93
+ new_import, v_name = get_type_import(v)
94
+ imports += new_import
95
+ else:
96
+ v_name = str(v)
97
+ value_names.append(v_name)
98
+ value_str = ", ".join(value_names)
99
+ new_value = "{" + value_str + "}"
100
+ elif isinstance(value, type):
101
+ new_import, value_name = get_type_import(value)
102
+ imports += new_import
103
+ new_value = value_type_name = value_name
104
+ else:
105
+ new_value = value
106
+ if default_values or key == "_value_range":
107
+ class_code += f" {key}: {value_type_name} = {new_value}\n"
108
+ else:
109
+ class_code += f" {key}: {value_type_name}\n"
110
+ imports += "\n\n"
111
+ if issubclass(cls, Row):
112
+ folder_name = "row"
113
+ elif issubclass(cls, Column):
114
+ folder_name = "column"
115
+ else:
116
+ raise ValueError(f"Unknown class {cls}.")
117
+ # write the code to a file
118
+ with open(f"{cls._generated_classes_dir}/{folder_name}/{name.lower()}.py", "w") as f:
119
+ f.write(imports + class_code)
120
+
121
+ # import the class from the file
122
+ import_path = ".".join(cls.__module__.split(".")[:-1] + ["generated", folder_name, name.lower()])
123
+ time.sleep(0.3)
124
+ return __import__(import_path, fromlist=[name.lower()]).__dict__[name]
125
+
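For orientation, this is roughly the module the method above would write for a simple call such as Column.create("Speed", {int, float}) (Column is defined further down). Only builtin types are involved here, so no extra value imports are emitted; the parent import path comes from Column.__module__ and the concrete path shown below is an assumption, not taken from the package.

    # hypothetical contents of generated/column/speed.py
    from ripple_down_rules.datastructures.case import Column as Column_  # actual path is Column.__module__; this one is assumed


    class Speed(Column_):
        nullable: bool = True
        mutually_exclusive: bool = False
        _value_range: set = {int, float}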
126
+ @classmethod
127
+ def _get_and_update_subclass(cls, name: str, range_: set) -> Optional[Type[SubClassFactory]]:
128
+ """
129
+ Get a subclass of the attribute class and update its range if necessary.
130
+
131
+ :param name: The name of the column.
132
+ :param range_: The range of the column values.
133
+ """
134
+ key = (name.lower(), cls)
135
+ if key in cls._registry:
136
+ if not cls._registry[key].is_within_range(range_):
137
+ if isinstance(cls._registry[key]._value_range, set):
138
+ cls._registry[key]._value_range.update(range_)
139
+ else:
140
+ raise ValueError(f"Range of {key} is different from {cls._registry[key]._value_range}.")
141
+ return cls._registry[key]
142
+
143
+ @classmethod
144
+ def register(cls, subclass: Type[SubClassFactory]):
145
+ """
146
+ Register a subclass of this class; this is used to keep track of dynamically created subclasses.
147
+
148
+ :param subclass: The subclass to register.
149
+ """
150
+ if not issubclass(subclass, SubClassFactory):
151
+ raise ValueError(f"{subclass} is not a subclass of CustomSet.")
152
+ if (subclass.__name__.lower(), cls) not in cls._registry:
153
+ cls._registry[(subclass.__name__.lower(), cls)] = subclass
154
+ else:
155
+ raise ValueError(f"{subclass} is already registered.")
156
+
157
+ @classmethod
158
+ def is_within_range(cls, value: Any) -> bool:
159
+ """
160
+ Check if a value is within the range of the custom set.
161
+
162
+ :param value: The value to check.
163
+ :return: Boolean indicating whether the value is within the range or not.
164
+ """
165
+ if hasattr(value, "__iter__") and not isinstance(value, str):
166
+ if all(any(isinstance(val_range, type) and isinstance(v, val_range)
167
+ for val_range in cls._value_range) for v in value):
168
+ return True
169
+ else:
170
+ return set(value).issubset(cls._value_range)
171
+ elif isinstance(value, str):
172
+ return value.lower() in cls._value_range
173
+ else:
174
+ return value in cls._value_range
175
+
176
+ def __instancecheck__(self, instance):
177
+ return isinstance(instance, (SubClassFactory, *self._value_range))
178
+
179
+
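For illustration, a hand-written subclass (bypassing the file-generating create path) shows how is_within_range treats string values and collections. This is a sketch only and assumes the classes above are importable in the current scope:

    # sketch: a manually defined subclass with an explicit value range
    class Gender(SubClassFactory):
        _value_range = {"male", "female"}

    print(Gender.is_within_range("Male"))             # True  -- string values are compared in lower case
    print(Gender.is_within_range({"male", "other"}))  # False -- the collection is not a subset of the range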
180
+ class Row(UserDict, SubClassFactory, SubclassJSONSerializer):
181
+ """
182
+ A collection of attributes that represents a set of constraints on a case. This is a dictionary where the keys are
183
+ the names of the attributes and the values are the attribute values. Keys are stored in lower case.
184
+ """
185
+
186
+ def __init__(self, id_: Optional[Hashable] = None, **kwargs):
187
+ """
188
+ Create a new row.
189
+
190
+ :param id_: The id of the row.
191
+ :param kwargs: The attributes of the row.
192
+ """
193
+ super().__init__(kwargs)
194
+ self.id = id_
195
+
196
+ @classmethod
197
+ def from_obj(cls, obj: Any, obj_name: Optional[str] = None, max_recursion_idx: int = 3) -> Row:
198
+ """
199
+ Create a row from an object.
200
+
201
+ :param obj: The object to create a row from.
202
+ :param max_recursion_idx: The maximum recursion index to prevent infinite recursion.
203
+ :param obj_name: The name of the object.
204
+ :return: The row of the object.
205
+ """
206
+ return create_row(obj, max_recursion_idx=max_recursion_idx, obj_name=obj_name)
207
+
208
+ def __getitem__(self, item: str) -> Any:
209
+ return super().__getitem__(item.lower())
210
+
211
+ def __setitem__(self, name: str, value: Any):
212
+ name = name.lower()
213
+ if name in self:
214
+ if isinstance(self[name], set):
215
+ self[name].update(make_set(value))
216
+ elif isinstance(value, set):
217
+ value.update(make_set(self[name]))
218
+ super().__setitem__(name, value)
219
+ else:
220
+ super().__setitem__(name, make_set(self[name]).union(make_set(value)))
221
+ else:
222
+ setattr(self, name, value)
223
+ super().__setitem__(name, value)
224
+
225
+ def __contains__(self, item):
226
+ if isinstance(item, (type, Enum)):
227
+ item = item.__name__
228
+ return super().__contains__(item.lower())
229
+
230
+ def __delitem__(self, key):
231
+ super().__delitem__(key.lower())
232
+
233
+ def __eq__(self, other):
234
+ if not isinstance(other, (Row, dict, UserDict)):
235
+ return False
236
+ elif isinstance(other, (dict, UserDict)):
237
+ return super().__eq__(Row(**other))
238
+ else:
239
+ return super().__eq__(other)
240
+
241
+ def __hash__(self):
242
+ return hash(tuple(self.items()))
243
+
244
+ def __instancecheck__(self, instance):
245
+ return isinstance(instance, (dict, UserDict, Row)) or super().__instancecheck__(instance)
246
+
247
+ def to_json(self) -> Dict[str, Any]:
248
+ return {**SubclassJSONSerializer.to_json(self),
249
+ **{k: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for k, v in self.items()}}
250
+
251
+ @classmethod
252
+ def _from_json(cls, data: Dict[str, Any]) -> Row:
253
+ return cls(id_=data["id"], **data)
254
+
255
+
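A minimal sketch of Row behaviour, assuming make_set wraps a scalar into a single-element set: keys are looked up case-insensitively, and assigning to a key that already exists merges the old and new values into a set:

    case = Row(id_=1, species="cat", legs=4)
    print(case["LEGS"])   # 4 -- keys are stored and looked up in lower case
    case["legs"] = 3      # assigning to an existing key merges old and new values
    print(case["legs"])   # expected to be {3, 4}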
256
+ @dataclass
257
+ class ColumnValue(SubclassJSONSerializer):
258
+ """
259
+ A single value in a column, tagged with the id of the row it belongs to.
260
+ """
261
+ id: Hashable
262
+ """
263
+ The row id of the column value.
264
+ """
265
+ value: Any
266
+ """
267
+ The value of the column.
268
+ """
269
+
270
+ def __eq__(self, other):
271
+ if not isinstance(other, ColumnValue):
272
+ return False
273
+ return self.value == other.value
274
+
275
+ def __hash__(self):
276
+ return hash(self.id)
277
+
278
+ def to_json(self) -> Dict[str, Any]:
279
+ return {**SubclassJSONSerializer.to_json(self),
280
+ "id": self.id, "value": self.value}
281
+
282
+ @classmethod
283
+ def _from_json(cls, data: Dict[str, Any]) -> ColumnValue:
284
+ return cls(id=data["id"], value=data["value"])
285
+
286
+
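As a small sketch: ColumnValue equality is by value only, while the hash is derived from the owning row id, so the same value coming from two different rows still compares equal:

    a = ColumnValue(id=1, value="red")
    b = ColumnValue(id=2, value="red")
    print(a == b)   # True -- equality ignores the row id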
287
+ class Column(set, SubClassFactory, SubclassJSONSerializer):
288
+ nullable: bool = True
289
+ """
290
+ A boolean indicating whether the column can be None or not.
291
+ """
292
+ mutually_exclusive: bool = False
293
+ """
294
+ A boolean indicating whether the column is mutually exclusive (i.e. can only have one value) or not.
295
+ """
296
+
297
+ def __init__(self, values: Set[ColumnValue]):
298
+ """
299
+ Create a new column.
300
+
301
+ :param values: The values of the column.
302
+ """
303
+ values = self._type_cast_values_to_set_of_column_values(values)
304
+ self.id_value_map: Dict[Hashable, Union[ColumnValue, Set[ColumnValue]]] = {id(v): v for v in values}
305
+ super().__init__([v.value for v in values])
306
+
307
+ @staticmethod
308
+ def _type_cast_values_to_set_of_column_values(values: Set[Any]) -> Set[ColumnValue]:
309
+ """
310
+ Type cast values to a set of column values.
311
+
312
+ :param values: The values to type cast.
313
+ """
314
+ values = make_set(values)
315
+ if len(values) > 0 and not isinstance(next(iter(values)), ColumnValue):
316
+ values = {ColumnValue(id(values), v) for v in values}
317
+ return values
318
+
319
+ @classmethod
320
+ def create(cls, name: str, range_: set,
321
+ nullable: bool = True, mutually_exclusive: bool = False) -> Type[SubClassFactory]:
322
+ return super().create(name, range_, {"nullable": nullable, "mutually_exclusive": mutually_exclusive})
323
+
324
+ @classmethod
325
+ def create_from_enum(cls, category: Type[Enum], nullable: bool = True,
326
+ mutually_exclusive: bool = False) -> Type[SubClassFactory]:
327
+ new_cls = cls.create(category.__name__.lower(), {category}, nullable=nullable,
328
+ mutually_exclusive=mutually_exclusive)
329
+ for value in category:
330
+ value_column = cls.create(category.__name__.lower(), {value}, mutually_exclusive=mutually_exclusive)(value)
331
+ setattr(new_cls, value.name, value_column)
332
+ return new_cls
333
+
334
+ @classmethod
335
+ def from_obj(cls, values: Set[Any], row_obj: Optional[Any] = None) -> Column:
336
+ id_ = id(row_obj) if row_obj else id(values)
337
+ values = make_set(values)
338
+ return cls({ColumnValue(id_, v) for v in values})
339
+
340
+ @property
341
+ def as_dict(self) -> Dict[str, Any]:
342
+ """
343
+ Get the column as a dictionary.
344
+
345
+ :return: The column as a dictionary.
346
+ """
347
+ return {self.__class__.__name__: self}
348
+
349
+ def filter_by(self, condition: CallableExpression) -> Column:
350
+ """
351
+ Filter the column by a condition.
352
+
353
+ :param condition: The condition to filter by.
354
+ :return: The filtered column.
355
+ """
356
+ return self.__class__({v for v in self if condition(v)})
357
+
358
+ def __eq__(self, other):
359
+ if not isinstance(other, set):
360
+ return super().__eq__(make_set(other))
361
+ return super().__eq__(other)
362
+
363
+ def __hash__(self):
364
+ return hash(tuple(self.id_value_map.values()))
365
+
366
+ def __str__(self):
367
+ if len(self) == 0:
368
+ return "None"
369
+ return str({v for v in self}) if len(self) > 1 else str(next(iter(self)))
370
+
371
+ def __instancecheck__(self, instance):
372
+ return isinstance(instance, (set, self.__class__)) or super().__instancecheck__(instance)
373
+
374
+ def to_json(self) -> Dict[str, Any]:
375
+ return {**SubclassJSONSerializer.to_json(self),
376
+ **{id_: v.to_json() if isinstance(v, SubclassJSONSerializer) else v for id_, v in self.id_value_map.items()}}
377
+
378
+ @classmethod
379
+ def _from_json(cls, data: Dict[str, Any]) -> Column:
380
+ return cls({ColumnValue(int(id_), v) for id_, v in data.items()})
381
+
382
+
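A usage sketch for Column, assuming the classes above are in scope. from_obj builds a plain Column directly from raw values without touching the filesystem, whereas create and create_from_enum additionally generate and import a subclass module under generated/column/ (a side effect worth keeping in mind):

    colors = Column.from_obj({"red", "green"})
    print(sorted(colors))            # ['green', 'red'] -- the column behaves like a set of the raw values
    print(len(colors.id_value_map))  # 2 -- each value keeps its ColumnValue with the owning row id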
383
+ def create_rows_from_dataframe(df: DataFrame, name: Optional[str] = None) -> List[Row]:
384
+ """
385
+ Create a row from a pandas DataFrame.
386
+
387
+ :param df: The DataFrame to create a row from.
388
+ :param name: The name of the DataFrame.
389
+ :return: The row of the DataFrame.
390
+ """
391
+ rows = []
392
+ col_names = list(df.columns)
393
+ for row_id, row in df.iterrows():
394
+ row = {col_name: row[col_name].item() for col_name in col_names}
395
+ row_cls = Row.create(name or df.__class__.__name__, make_set(type(df)), row, default_values=False)
396
+ rows.append(row_cls(id_=row_id, **row))
397
+ return rows
398
+
399
+
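A hedged usage sketch: each DataFrame row becomes an instance of a dynamically generated Row subclass. Row.create writes a module under generated/row/ and imports it, so this snippet has filesystem side effects and assumes that directory exists:

    import pandas as pd

    df = pd.DataFrame({"height": [170, 182], "weight": [65, 80]})
    cases = create_rows_from_dataframe(df, name="Person")
    print(cases[0]["height"])   # expected to be 170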
400
+ def create_row(obj: Any, recursion_idx: int = 0, max_recursion_idx: int = 0,
401
+ obj_name: Optional[str] = None, parent_is_iterable: bool = False) -> Row:
402
+ """
403
+ Create a row from an object.
404
+
405
+ :param obj: The object to create a row from.
406
+ :param recursion_idx: The current recursion index.
407
+ :param max_recursion_idx: The maximum recursion index to prevent infinite recursion.
408
+ :param obj_name: The name of the object.
409
+ :param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
410
+ :return: The row of the object.
411
+ """
412
+ if isinstance(obj, Row):
413
+ return obj
414
+ if ((recursion_idx > max_recursion_idx) or (obj.__class__.__module__ == "builtins")
415
+ or (obj.__class__ in [MetaData, registry])):
416
+ return Row(id_=id(obj), **{obj_name or obj.__class__.__name__: make_set(obj) if parent_is_iterable else obj})
417
+ row = Row(id_=id(obj))
418
+ attributes_type_hints = {}
419
+ for attr in dir(obj):
420
+ if attr.startswith("_") or callable(getattr(obj, attr)):
421
+ continue
422
+ attr_value = getattr(obj, attr)
423
+ row = create_or_update_row_from_attribute(attr_value, attr, obj, attr, recursion_idx,
424
+ max_recursion_idx, parent_is_iterable, row)
425
+ attributes_type_hints[attr] = get_value_type_from_type_hint(attr, obj)
426
+ row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False,
427
+ attributes_type_hints=attributes_type_hints)
428
+ row = row_cls(id_=id(obj), **row)
429
+ return row
430
+
431
+
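A sketch of the intended call pattern for create_row (or the Row.from_obj wrapper above). Attribute discovery walks dir(obj) and skips private and callable members; like the DataFrame helper, it goes through the dynamic class generation path, so it writes generated modules to disk. The Robot class here is a hypothetical example object, and the printed value is the expected outcome rather than a verified one:

    class Robot:
        name: str
        payload: float

        def __init__(self):
            self.name = "pr2"
            self.payload = 5.0

    case = Row.from_obj(Robot(), obj_name="Robot")
    print(case["payload"])   # expected to be 5.0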
432
+ def create_or_update_row_from_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
433
+ recursion_idx: int = 0, max_recursion_idx: int = 1,
434
+ parent_is_iterable: bool = False,
435
+ row: Optional[Row] = None) -> Row:
436
+ """
437
+ Create or update a row from an attribute of an object.
438
+
439
+ :param attr_value: The attribute value to get the column and table from.
440
+ :param name: The name of the attribute.
441
+ :param obj: The parent object of the attribute.
442
+ :param obj_name: The parent object name.
443
+ :param recursion_idx: The recursion index to prevent infinite recursion.
444
+ :param max_recursion_idx: The maximum recursion index.
445
+ :param parent_is_iterable: Boolean indicating whether the parent object is iterable or not.
446
+ :param row: The row to update.
447
+ :return: The created or updated row.
448
+ """
449
+ if row is None:
450
+ row = Row(id_=id(obj))
451
+ if isinstance(attr_value, (dict, UserDict)):
452
+ row.update({f"{obj_name}.{k}": v for k, v in attr_value.items()})
453
+ if hasattr(attr_value, "__iter__") and not isinstance(attr_value, str):
454
+ column, attr_row = create_column_and_row_from_iterable_attribute(attr_value, name, obj, obj_name,
455
+ recursion_idx=recursion_idx + 1,
456
+ max_recursion_idx=max_recursion_idx)
457
+ row[obj_name] = column
458
+ else:
459
+ row[obj_name] = make_set(attr_value) if parent_is_iterable else attr_value
460
+ if row.__class__.__name__ == "Row":
461
+ row_cls = Row.create(obj_name or obj.__class__.__name__, make_set(type(obj)), row, default_values=False)
462
+ row = row_cls(id_=id(obj), **row)
463
+ return row
464
+
465
+
466
+ def create_column_and_row_from_iterable_attribute(attr_value: Any, name: str, obj: Any, obj_name: Optional[str] = None,
467
+ recursion_idx: int = 0,
468
+ max_recursion_idx: int = 1) -> Tuple[Column, Row]:
469
+ """
470
+ Create a column and a row from an iterable attribute.
471
+
472
+ :param attr_value: The iterable attribute to create the column and row from.
473
+ :param name: The name of the column.
474
+ :param obj: The parent object of the iterable.
475
+ :param obj_name: The parent object name.
476
+ :param recursion_idx: The recursion index to prevent infinite recursion.
477
+ :param max_recursion_idx: The maximum recursion index.
478
+ :return: The column and the row created from the iterable.
479
+ """
480
+ values = attr_value.values() if isinstance(attr_value, (dict, UserDict)) else attr_value
481
+ range_ = {type(list(values)[0])} if len(values) > 0 else set()
482
+ if len(range_) == 0:
483
+ range_ = make_set(get_value_type_from_type_hint(name, obj))
484
+ if not range_:
485
+ raise ValueError(f"Could not determine the range of {name} in {obj}.")
486
+ attr_row = Row(id_=id(attr_value))
487
+ column = Column.create(name, range_).from_obj(values, row_obj=obj)
488
+ attributes_type_hints = {}
489
+ for idx, val in enumerate(values):
490
+ sub_attr_row = create_row(val, recursion_idx=recursion_idx,
491
+ max_recursion_idx=max_recursion_idx,
492
+ obj_name=obj_name, parent_is_iterable=True)
493
+ attr_row.update(sub_attr_row)
494
+ # attr_row_cls = Row.create(name or list(range_)[0].__name__, range_, attr_row, default_values=False)
495
+ # attr_row = attr_row_cls(id_=id(attr_value), **attr_row)
496
+ for sub_attr, val in attr_row.items():
497
+ setattr(column, sub_attr, val)
498
+ return column, attr_row
499
+
500
+
501
+ def show_current_and_corner_cases(case: Any, targets: Optional[Union[List[Column], List[SQLColumn]]] = None,
502
+ current_conclusions: Optional[Union[List[Column], List[SQLColumn]]] = None,
503
+ last_evaluated_rule: Optional[Rule] = None) -> None:
504
+ """
505
+ Show the data of the new case and, if a last evaluated rule exists, also show the data of its corner case.
506
+
507
+ :param case: The new case.
508
+ :param targets: The target attribute of the case.
509
+ :param current_conclusions: The current conclusions of the case.
510
+ :param last_evaluated_rule: The last evaluated rule in the RDR.
511
+ """
512
+ corner_case = None
513
+ if targets:
514
+ targets = targets if isinstance(targets, list) else [targets]
515
+ if current_conclusions:
516
+ current_conclusions = current_conclusions if isinstance(current_conclusions, list) else [current_conclusions]
517
+ targets = {f"target_{t.__class__.__name__}": t for t in targets} if targets else {}
518
+ current_conclusions = {c.__class__.__name__: c for c in current_conclusions} if current_conclusions else {}
519
+ if last_evaluated_rule:
520
+ action = "Refinement" if last_evaluated_rule.fired else "Alternative"
521
+ print(f"{action} needed for rule: {last_evaluated_rule}\n")
522
+ corner_case = last_evaluated_rule.corner_case
523
+
524
+ corner_row_dict = None
525
+ if isinstance(case, SQLTable):
526
+ case_dict = row_to_dict(case)
527
+ if last_evaluated_rule and last_evaluated_rule.fired:
528
+ corner_row_dict = row_to_dict(last_evaluated_rule.corner_case)
529
+ else:
530
+ case_dict = case
531
+ if last_evaluated_rule and last_evaluated_rule.fired:
532
+ corner_row_dict = corner_case
533
+
534
+ if corner_row_dict:
535
+ corner_conclusion = last_evaluated_rule.conclusion
536
+ corner_row_dict.update({corner_conclusion.__class__.__name__: corner_conclusion})
537
+ print(table_rows_as_str(corner_row_dict))
538
+ print("=" * 50)
539
+ case_dict.update(targets)
540
+ case_dict.update(current_conclusions)
541
+ print(table_rows_as_str(case_dict))
542
+
543
+
544
+ Case = Row
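Since Case is just an alias of Row, user-facing code can construct cases directly with keyword attributes. A minimal sketch (attribute names here are illustrative):

    case = Case(id_=0, temperature=38.5, conscious=True)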