ripple-down-rules 0.5.63__py3-none-any.whl → 0.5.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +141 -101
- ripple_down_rules/rdr_decorators.py +54 -23
- ripple_down_rules/rules.py +63 -13
- ripple_down_rules/user_interface/gui.py +9 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +17 -22
- ripple_down_rules/utils.py +235 -62
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/METADATA +2 -1
- ripple_down_rules-0.5.71.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.63.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.63.dist-info → ripple_down_rules-0.5.71.dist-info}/top_level.txt +0 -0
ripple_down_rules/datastructures/case.py
CHANGED
@@ -84,7 +84,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
84
84
|
def _to_json(self) -> Dict[str, Any]:
|
85
85
|
serializable = {k: v for k, v in self.items() if not k.startswith("_")}
|
86
86
|
serializable["_id"] = self._id
|
87
|
-
serializable["_obj_type"] = get_full_class_name(self._obj_type)
|
87
|
+
serializable["_obj_type"] = get_full_class_name(self._obj_type) if self._obj_type is not None else None
|
88
88
|
serializable["_name"] = self._name
|
89
89
|
for k, v in serializable.items():
|
90
90
|
if isinstance(v, set):
|
@@ -96,7 +96,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
96
96
|
@classmethod
|
97
97
|
def _from_json(cls, data: Dict[str, Any]) -> Case:
|
98
98
|
id_ = data.pop("_id")
|
99
|
-
obj_type = get_type_from_string(data.pop("_obj_type"))
|
99
|
+
obj_type = get_type_from_string(data.pop("_obj_type")) if data["_obj_type"] is not None else None
|
100
100
|
name = data.pop("_name")
|
101
101
|
for k, v in data.items():
|
102
102
|
data[k] = SubclassJSONSerializer.from_json(v)
|
@@ -308,7 +308,10 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
308
308
|
:return: A case attribute that represents the original iterable attribute.
|
309
309
|
"""
|
310
310
|
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
311
|
-
|
311
|
+
try:
|
312
|
+
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
313
|
+
except ValueError:
|
314
|
+
_type = None
|
312
315
|
attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
|
313
316
|
case_attr = CaseAttribute(values)
|
314
317
|
for idx, val in enumerate(values):
|
@@ -317,7 +320,10 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
317
320
|
obj_name=name, parent_is_iterable=True)
|
318
321
|
attr_case.update(sub_attr_case)
|
319
322
|
for sub_attr, val in attr_case.items():
|
320
|
-
|
323
|
+
try:
|
324
|
+
setattr(case_attr, sub_attr, val)
|
325
|
+
except AttributeError:
|
326
|
+
pass
|
321
327
|
return case_attr
|
322
328
|
|
323
329
|
|
ripple_down_rules/datastructures/dataclasses.py
CHANGED
@@ -5,9 +5,11 @@ import typing
|
|
5
5
|
from dataclasses import dataclass, field
|
6
6
|
|
7
7
|
import typing_extensions
|
8
|
+
from omegaconf import MISSING
|
8
9
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
9
|
-
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set
|
10
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set, Callable
|
10
11
|
|
12
|
+
from ..utils import get_method_name, get_function_import_path_and_representation
|
11
13
|
from .callable_expression import CallableExpression
|
12
14
|
from .case import create_case, Case
|
13
15
|
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, get_value_type_from_type_hint, \
|
@@ -37,6 +39,24 @@ class CaseQuery:
|
|
37
39
|
"""
|
38
40
|
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
39
41
|
"""
|
42
|
+
case_factory: Optional[Callable[[], Any]] = None
|
43
|
+
"""
|
44
|
+
The factory method that can be used to recreate the original case.
|
45
|
+
"""
|
46
|
+
case_factory_idx: Optional[int] = None
|
47
|
+
"""
|
48
|
+
This is used when the case factory is a list of cases, this index is used to select the case from the list.
|
49
|
+
"""
|
50
|
+
case_conf: Optional[CaseConf] = None
|
51
|
+
"""
|
52
|
+
The case configuration that is used to (re)create the original case, recommended to be used when you want to
|
53
|
+
the case to persist in the rule base, this would allow it to be used for merging with other similar conclusion RDRs.
|
54
|
+
"""
|
55
|
+
scenario: Optional[Callable] = None
|
56
|
+
"""
|
57
|
+
The executable scenario is the root callable that recreates the situation that the case is
|
58
|
+
created in, for example, when the case is created from a test function, this would be the test function itself.
|
59
|
+
"""
|
40
60
|
_target: Optional[CallableExpression] = None
|
41
61
|
"""
|
42
62
|
The target expression of the attribute.
|
@@ -95,7 +115,7 @@ class CaseQuery:
|
|
95
115
|
"""
|
96
116
|
if self._case is not None:
|
97
117
|
return self._case
|
98
|
-
elif not isinstance(self.original_case,
|
118
|
+
elif not isinstance(self.original_case, Case):
|
99
119
|
self._case = create_case(self.original_case, max_recursion_idx=3)
|
100
120
|
else:
|
101
121
|
self._case = self.original_case
|
@@ -225,4 +245,43 @@ class CaseQuery:
|
|
225
245
|
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
226
246
|
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
227
247
|
conditions=self.conditions, is_function=self.is_function,
|
228
|
-
function_args_type_hints=self.function_args_type_hints
|
248
|
+
function_args_type_hints=self.function_args_type_hints,
|
249
|
+
case_factory=self.case_factory, case_factory_idx=self.case_factory_idx,
|
250
|
+
case_conf=self.case_conf, scenario=self.scenario)
|
251
|
+
|
252
|
+
|
253
|
+
@dataclass
|
254
|
+
class CaseConf:
|
255
|
+
factory_method: Callable[[Any], Any] = MISSING
|
256
|
+
|
257
|
+
def create(self) -> Any:
|
258
|
+
return self.factory_method()
|
259
|
+
|
260
|
+
|
261
|
+
@dataclass
|
262
|
+
class CaseFactoryMetaData:
|
263
|
+
factory_method: Optional[Callable[[Optional[CaseConf]], Any]] = None
|
264
|
+
factory_idx: Optional[int] = None
|
265
|
+
case_conf: Optional[CaseConf] = None
|
266
|
+
scenario: Optional[Callable] = None
|
267
|
+
|
268
|
+
@classmethod
|
269
|
+
def from_case_query(cls, case_query: CaseQuery) -> CaseFactoryMetaData:
|
270
|
+
return cls(factory_method=case_query.case_factory, factory_idx=case_query.case_factory_idx,
|
271
|
+
case_conf=case_query.case_conf, scenario=case_query.scenario)
|
272
|
+
|
273
|
+
def __repr__(self):
|
274
|
+
factory_method_repr = None
|
275
|
+
scenario_repr = None
|
276
|
+
if self.factory_method is not None:
|
277
|
+
_, factory_method_repr = get_function_import_path_and_representation(self.factory_method)
|
278
|
+
if self.scenario is not None:
|
279
|
+
_, scenario_repr = get_function_import_path_and_representation(self.scenario)
|
280
|
+
return (f"CaseFactoryMetaData("
|
281
|
+
f"factory_method={factory_method_repr}, "
|
282
|
+
f"factory_idx={self.factory_idx}, "
|
283
|
+
f"case_conf={self.case_conf},"
|
284
|
+
f" scenario={scenario_repr})")
|
285
|
+
|
286
|
+
def __str__(self):
|
287
|
+
return self.__repr__()
|
ripple_down_rules/helpers.py
CHANGED
@@ -1,18 +1,65 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import os
|
4
|
+
from types import ModuleType
|
4
5
|
|
6
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
7
|
+
|
8
|
+
from .datastructures.case import create_case
|
5
9
|
from .datastructures.dataclasses import CaseQuery
|
6
|
-
from
|
7
|
-
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
|
10
|
+
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
8
11
|
|
9
|
-
from .utils import
|
12
|
+
from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
|
10
13
|
from .utils import calculate_precision_and_recall
|
11
14
|
|
12
15
|
if TYPE_CHECKING:
|
13
16
|
from .rdr import RippleDownRules
|
14
17
|
|
15
18
|
|
19
|
+
def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
20
|
+
case: Any, modify_original_case: bool = False,
|
21
|
+
case_query: Optional[CaseQuery] = None) -> Dict[str, Any]:
|
22
|
+
"""
|
23
|
+
Classify a case by going through all classifiers and adding the categories that are classified,
|
24
|
+
and then restarting the classification until no more categories can be added.
|
25
|
+
|
26
|
+
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
27
|
+
:param case: The case to classify.
|
28
|
+
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
29
|
+
:param case_query: The case query to extract metadata from if needed.
|
30
|
+
:return: The categories that the case belongs to.
|
31
|
+
"""
|
32
|
+
conclusions = {}
|
33
|
+
case = create_case(case)
|
34
|
+
case_cp = copy_case(case) if not modify_original_case else case
|
35
|
+
while True:
|
36
|
+
new_conclusions = {}
|
37
|
+
for attribute_name, rdr in classifiers_dict.items():
|
38
|
+
pred_atts = rdr.classify(case_cp, case_query=case_query)
|
39
|
+
if pred_atts is None:
|
40
|
+
continue
|
41
|
+
if rdr.mutually_exclusive:
|
42
|
+
if attribute_name not in conclusions or \
|
43
|
+
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
44
|
+
conclusions[attribute_name] = pred_atts
|
45
|
+
new_conclusions[attribute_name] = pred_atts
|
46
|
+
else:
|
47
|
+
pred_atts = make_set(pred_atts)
|
48
|
+
if attribute_name in conclusions:
|
49
|
+
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
50
|
+
if len(pred_atts) > 0:
|
51
|
+
new_conclusions[attribute_name] = pred_atts
|
52
|
+
if attribute_name not in conclusions:
|
53
|
+
conclusions[attribute_name] = set()
|
54
|
+
conclusions[attribute_name].update(pred_atts)
|
55
|
+
if attribute_name in new_conclusions:
|
56
|
+
temp_case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
|
57
|
+
update_case(temp_case_query, new_conclusions)
|
58
|
+
if len(new_conclusions) == 0:
|
59
|
+
break
|
60
|
+
return conclusions
|
61
|
+
|
62
|
+
|
16
63
|
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
17
64
|
"""
|
18
65
|
:param classifier: The RDR classifier to check the prediction of.
|
@@ -32,20 +79,19 @@ def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_ca
|
|
32
79
|
|
33
80
|
|
34
81
|
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
35
|
-
|
82
|
+
**rdr_kwargs) -> RippleDownRules:
|
36
83
|
"""
|
37
84
|
Load the RDR model of the function if it exists, otherwise create a new one.
|
38
85
|
|
39
86
|
:param func: The function to load the model for.
|
40
87
|
:param model_dir: The directory where the model is stored.
|
41
88
|
:param rdr_type: The type of the RDR model to load.
|
42
|
-
:param session: The SQLAlchemy session to use.
|
43
89
|
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
44
90
|
"""
|
45
|
-
|
91
|
+
model_name = get_func_rdr_model_name(func)
|
92
|
+
model_path = os.path.join(model_dir, model_name, "rdr_metadata", f"{model_name}.json")
|
46
93
|
if os.path.exists(model_path):
|
47
|
-
rdr = rdr_type.load(
|
48
|
-
rdr.session = session
|
94
|
+
rdr = rdr_type.load(load_dir=model_dir, model_name=model_name)
|
49
95
|
else:
|
50
|
-
rdr = rdr_type(
|
96
|
+
rdr = rdr_type(**rdr_kwargs)
|
51
97
|
return rdr
|