ripple-down-rules 0.5.62__py3-none-any.whl → 0.5.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +1 -1
- ripple_down_rules/helpers.py +51 -9
- ripple_down_rules/rdr.py +55 -59
- ripple_down_rules/rdr_decorators.py +48 -18
- ripple_down_rules/rules.py +9 -4
- ripple_down_rules/user_interface/gui.py +9 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +24 -24
- ripple_down_rules/utils.py +174 -59
- {ripple_down_rules-0.5.62.dist-info → ripple_down_rules-0.5.64.dist-info}/METADATA +1 -1
- ripple_down_rules-0.5.64.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.62.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.62.dist-info → ripple_down_rules-0.5.64.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.62.dist-info → ripple_down_rules-0.5.64.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.62.dist-info → ripple_down_rules-0.5.64.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -84,7 +84,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
84
84
|
def _to_json(self) -> Dict[str, Any]:
|
85
85
|
serializable = {k: v for k, v in self.items() if not k.startswith("_")}
|
86
86
|
serializable["_id"] = self._id
|
87
|
-
serializable["_obj_type"] = get_full_class_name(self._obj_type)
|
87
|
+
serializable["_obj_type"] = get_full_class_name(self._obj_type) if self._obj_type is not None else None
|
88
88
|
serializable["_name"] = self._name
|
89
89
|
for k, v in serializable.items():
|
90
90
|
if isinstance(v, set):
|
@@ -96,7 +96,7 @@ class Case(UserDict, SubclassJSONSerializer):
|
|
96
96
|
@classmethod
|
97
97
|
def _from_json(cls, data: Dict[str, Any]) -> Case:
|
98
98
|
id_ = data.pop("_id")
|
99
|
-
obj_type = get_type_from_string(data.pop("_obj_type"))
|
99
|
+
obj_type = get_type_from_string(data.pop("_obj_type")) if data["_obj_type"] is not None else None
|
100
100
|
name = data.pop("_name")
|
101
101
|
for k, v in data.items():
|
102
102
|
data[k] = SubclassJSONSerializer.from_json(v)
|
@@ -308,7 +308,10 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
308
308
|
:return: A case attribute that represents the original iterable attribute.
|
309
309
|
"""
|
310
310
|
values = list(attr_value.values()) if isinstance(attr_value, (dict, UserDict)) else attr_value
|
311
|
-
|
311
|
+
try:
|
312
|
+
_type = type(list(values)[0]) if len(values) > 0 else get_value_type_from_type_hint(name, obj)
|
313
|
+
except ValueError:
|
314
|
+
_type = None
|
312
315
|
attr_case = Case(_type, _id=id(attr_value), _name=name, original_object=attr_value)
|
313
316
|
case_attr = CaseAttribute(values)
|
314
317
|
for idx, val in enumerate(values):
|
@@ -317,7 +320,10 @@ def create_case_attribute_from_iterable_attribute(attr_value: Any, name: str, ob
|
|
317
320
|
obj_name=name, parent_is_iterable=True)
|
318
321
|
attr_case.update(sub_attr_case)
|
319
322
|
for sub_attr, val in attr_case.items():
|
320
|
-
|
323
|
+
try:
|
324
|
+
setattr(case_attr, sub_attr, val)
|
325
|
+
except AttributeError:
|
326
|
+
pass
|
321
327
|
return case_attr
|
322
328
|
|
323
329
|
|
@@ -95,7 +95,7 @@ class CaseQuery:
|
|
95
95
|
"""
|
96
96
|
if self._case is not None:
|
97
97
|
return self._case
|
98
|
-
elif not isinstance(self.original_case,
|
98
|
+
elif not isinstance(self.original_case, Case):
|
99
99
|
self._case = create_case(self.original_case, max_recursion_idx=3)
|
100
100
|
else:
|
101
101
|
self._case = self.original_case
|
ripple_down_rules/helpers.py
CHANGED
@@ -1,18 +1,61 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import os
|
4
|
+
from types import ModuleType
|
4
5
|
|
6
|
+
from .datastructures.case import create_case
|
5
7
|
from .datastructures.dataclasses import CaseQuery
|
6
|
-
from
|
7
|
-
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
|
8
|
+
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
8
9
|
|
9
|
-
from .utils import
|
10
|
+
from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
|
10
11
|
from .utils import calculate_precision_and_recall
|
11
12
|
|
12
13
|
if TYPE_CHECKING:
|
13
14
|
from .rdr import RippleDownRules
|
14
15
|
|
15
16
|
|
17
|
+
def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
18
|
+
case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
|
19
|
+
"""
|
20
|
+
Classify a case by going through all classifiers and adding the categories that are classified,
|
21
|
+
and then restarting the classification until no more categories can be added.
|
22
|
+
|
23
|
+
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
24
|
+
:param case: The case to classify.
|
25
|
+
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
26
|
+
:return: The categories that the case belongs to.
|
27
|
+
"""
|
28
|
+
conclusions = {}
|
29
|
+
case = create_case(case)
|
30
|
+
case_cp = copy_case(case) if not modify_original_case else case
|
31
|
+
while True:
|
32
|
+
new_conclusions = {}
|
33
|
+
for attribute_name, rdr in classifiers_dict.items():
|
34
|
+
pred_atts = rdr.classify(case_cp)
|
35
|
+
if pred_atts is None:
|
36
|
+
continue
|
37
|
+
if rdr.mutually_exclusive:
|
38
|
+
if attribute_name not in conclusions or \
|
39
|
+
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
40
|
+
conclusions[attribute_name] = pred_atts
|
41
|
+
new_conclusions[attribute_name] = pred_atts
|
42
|
+
else:
|
43
|
+
pred_atts = make_set(pred_atts)
|
44
|
+
if attribute_name in conclusions:
|
45
|
+
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
46
|
+
if len(pred_atts) > 0:
|
47
|
+
new_conclusions[attribute_name] = pred_atts
|
48
|
+
if attribute_name not in conclusions:
|
49
|
+
conclusions[attribute_name] = set()
|
50
|
+
conclusions[attribute_name].update(pred_atts)
|
51
|
+
if attribute_name in new_conclusions:
|
52
|
+
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
|
53
|
+
update_case(case_query, new_conclusions)
|
54
|
+
if len(new_conclusions) == 0:
|
55
|
+
break
|
56
|
+
return conclusions
|
57
|
+
|
58
|
+
|
16
59
|
def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
|
17
60
|
"""
|
18
61
|
:param classifier: The RDR classifier to check the prediction of.
|
@@ -32,20 +75,19 @@ def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_ca
|
|
32
75
|
|
33
76
|
|
34
77
|
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
35
|
-
|
78
|
+
**rdr_kwargs) -> RippleDownRules:
|
36
79
|
"""
|
37
80
|
Load the RDR model of the function if it exists, otherwise create a new one.
|
38
81
|
|
39
82
|
:param func: The function to load the model for.
|
40
83
|
:param model_dir: The directory where the model is stored.
|
41
84
|
:param rdr_type: The type of the RDR model to load.
|
42
|
-
:param session: The SQLAlchemy session to use.
|
43
85
|
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
44
86
|
"""
|
45
|
-
|
87
|
+
model_name = get_func_rdr_model_name(func)
|
88
|
+
model_path = os.path.join(model_dir, model_name, "rdr_metadata", f"{model_name}.json")
|
46
89
|
if os.path.exists(model_path):
|
47
|
-
rdr = rdr_type.load(
|
48
|
-
rdr.session = session
|
90
|
+
rdr = rdr_type.load(load_dir=model_dir, model_name=model_name)
|
49
91
|
else:
|
50
|
-
rdr = rdr_type(
|
92
|
+
rdr = rdr_type(**rdr_kwargs)
|
51
93
|
return rdr
|
ripple_down_rules/rdr.py
CHANGED
@@ -28,7 +28,7 @@ from .datastructures.case import Case, CaseAttribute, create_case
|
|
28
28
|
from .datastructures.dataclasses import CaseQuery
|
29
29
|
from .datastructures.enums import MCRDRMode
|
30
30
|
from .experts import Expert, Human
|
31
|
-
from .helpers import is_matching
|
31
|
+
from .helpers import is_matching, general_rdr_classify
|
32
32
|
from .rules import Rule, SingleClassRule, MultiClassTopRule, MultiClassStopRule
|
33
33
|
try:
|
34
34
|
from .user_interface.gui import RDRCaseViewer
|
@@ -36,7 +36,7 @@ except ImportError as e:
|
|
36
36
|
RDRCaseViewer = None
|
37
37
|
from .utils import draw_tree, make_set, copy_case, \
|
38
38
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
39
|
-
is_conflicting,
|
39
|
+
is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
40
40
|
is_iterable, str_to_snake_case
|
41
41
|
|
42
42
|
|
@@ -76,16 +76,18 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
76
76
|
"""
|
77
77
|
The name of the model. If None, the model name will be the generated python file name.
|
78
78
|
"""
|
79
|
+
mutually_exclusive: Optional[bool] = None
|
80
|
+
"""
|
81
|
+
Whether the output of the classification of this rdr allows only one possible conclusion or not.
|
82
|
+
"""
|
79
83
|
|
80
84
|
def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
|
81
|
-
save_dir: Optional[str] = None,
|
85
|
+
save_dir: Optional[str] = None, model_name: Optional[str] = None):
|
82
86
|
"""
|
83
87
|
:param start_rule: The starting rule for the classifier.
|
84
88
|
:param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
|
85
89
|
:param save_dir: The directory to save the classifier to.
|
86
|
-
:param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
|
87
90
|
"""
|
88
|
-
self.ask_always: bool = ask_always
|
89
91
|
self.model_name: Optional[str] = model_name
|
90
92
|
self.save_dir = save_dir
|
91
93
|
self.start_rule = start_rule
|
@@ -224,7 +226,10 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
224
226
|
"""
|
225
227
|
pass
|
226
228
|
|
227
|
-
def fit_case(self, case_query: CaseQuery,
|
229
|
+
def fit_case(self, case_query: CaseQuery,
|
230
|
+
expert: Optional[Expert] = None,
|
231
|
+
update_existing_rules: bool = True,
|
232
|
+
**kwargs) \
|
228
233
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
229
234
|
"""
|
230
235
|
Fit the classifier to a case and ask the expert for refinements or alternatives if the classification is
|
@@ -232,6 +237,8 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
232
237
|
|
233
238
|
:param case_query: The query containing the case to classify and the target category to compare the case with.
|
234
239
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
240
|
+
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
241
|
+
some conclusions with the type required by the case query.
|
235
242
|
:return: The category that the case belongs to.
|
236
243
|
"""
|
237
244
|
if case_query is None:
|
@@ -248,11 +255,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
248
255
|
if case_query.target is None:
|
249
256
|
case_query_cp = copy(case_query)
|
250
257
|
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
251
|
-
if
|
252
|
-
or is_iterable(conclusions) and len(conclusions) == 0
|
253
|
-
or (isinstance(conclusions, dict) and (case_query_cp.attribute_name not in conclusions
|
254
|
-
or not any(type(c) in case_query_cp.core_attribute_type
|
255
|
-
for c in make_list(conclusions[case_query_cp.attribute_name]))))):
|
258
|
+
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
|
256
259
|
expert.ask_for_conclusion(case_query_cp)
|
257
260
|
case_query.target = case_query_cp.target
|
258
261
|
if case_query.target is None:
|
@@ -268,6 +271,34 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
268
271
|
|
269
272
|
return fit_case_result
|
270
273
|
|
274
|
+
@staticmethod
|
275
|
+
def should_i_ask_the_expert_for_a_target(conclusions: Union[Any, Dict[str, Any]],
|
276
|
+
case_query: CaseQuery,
|
277
|
+
update_existing: bool) -> bool:
|
278
|
+
"""
|
279
|
+
Determine if the rdr should ask the expert for the target of a given case query.
|
280
|
+
|
281
|
+
:param conclusions: The conclusions of the case.
|
282
|
+
:param case_query: The query containing the case to classify.
|
283
|
+
:param update_existing: Whether to update rules that gave the required type of conclusions.
|
284
|
+
:return: True if the rdr should ask the expert, False otherwise.
|
285
|
+
"""
|
286
|
+
if conclusions is None:
|
287
|
+
return True
|
288
|
+
elif is_iterable(conclusions) and len(conclusions) == 0:
|
289
|
+
return True
|
290
|
+
elif isinstance(conclusions, dict):
|
291
|
+
if case_query.attribute_name not in conclusions:
|
292
|
+
return True
|
293
|
+
conclusions = conclusions[case_query.attribute_name]
|
294
|
+
conclusion_types = map(type, make_list(conclusions))
|
295
|
+
if not any(ct in case_query.core_attribute_type for ct in conclusion_types):
|
296
|
+
return True
|
297
|
+
elif update_existing:
|
298
|
+
return True
|
299
|
+
else:
|
300
|
+
return False
|
301
|
+
|
271
302
|
@abstractmethod
|
272
303
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
273
304
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
@@ -423,11 +454,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
423
454
|
f.write(defs_imports + "\n\n")
|
424
455
|
with open(file_name, "w") as f:
|
425
456
|
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
426
|
-
imports += f"from ripple_down_rules.rdr import {self.__class__.__name__}\n"
|
427
457
|
f.write(imports + "\n\n")
|
428
458
|
f.write(f"attribute_name = '{self.attribute_name}'\n")
|
429
459
|
f.write(f"conclusion_type = ({', '.join([ct.__name__ for ct in self.conclusion_type])},)\n")
|
430
|
-
f.write(f"
|
460
|
+
f.write(f"mutually_exclusive = {self.mutually_exclusive}\n")
|
431
461
|
f.write(f"\n\n{func_def}")
|
432
462
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
433
463
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
@@ -533,6 +563,11 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
533
563
|
|
534
564
|
class SingleClassRDR(RDRWithCodeWriter):
|
535
565
|
|
566
|
+
mutually_exclusive: bool = True
|
567
|
+
"""
|
568
|
+
The output of the classification of this rdr negates all other possible outputs, there can only be one true value.
|
569
|
+
"""
|
570
|
+
|
536
571
|
def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
|
537
572
|
"""
|
538
573
|
:param start_rule: The starting rule for the classifier.
|
@@ -650,6 +685,10 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
650
685
|
"""
|
651
686
|
The conditions of the stopping rule if needed.
|
652
687
|
"""
|
688
|
+
mutually_exclusive: bool = False
|
689
|
+
"""
|
690
|
+
The output of the classification of this rdr allows for more than one true value as conclusion.
|
691
|
+
"""
|
653
692
|
|
654
693
|
def __init__(self, start_rule: Optional[MultiClassTopRule] = None,
|
655
694
|
mode: MCRDRMode = MCRDRMode.StopOnly, **kwargs):
|
@@ -903,50 +942,7 @@ class GeneralRDR(RippleDownRules):
|
|
903
942
|
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
904
943
|
:return: The categories that the case belongs to.
|
905
944
|
"""
|
906
|
-
return
|
907
|
-
|
908
|
-
@staticmethod
|
909
|
-
def _classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
910
|
-
case: Any, modify_original_case: bool = False) -> Dict[str, Any]:
|
911
|
-
"""
|
912
|
-
Classify a case by going through all classifiers and adding the categories that are classified,
|
913
|
-
and then restarting the classification until no more categories can be added.
|
914
|
-
|
915
|
-
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
916
|
-
:param case: The case to classify.
|
917
|
-
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
918
|
-
:return: The categories that the case belongs to.
|
919
|
-
"""
|
920
|
-
conclusions = {}
|
921
|
-
case = case if isinstance(case, (Case, SQLTable)) else create_case(case)
|
922
|
-
case_cp = copy_case(case) if not modify_original_case else case
|
923
|
-
while True:
|
924
|
-
new_conclusions = {}
|
925
|
-
for attribute_name, rdr in classifiers_dict.items():
|
926
|
-
pred_atts = rdr.classify(case_cp)
|
927
|
-
if pred_atts is None:
|
928
|
-
continue
|
929
|
-
if rdr.type_ is SingleClassRDR:
|
930
|
-
if attribute_name not in conclusions or \
|
931
|
-
(attribute_name in conclusions and conclusions[attribute_name] != pred_atts):
|
932
|
-
conclusions[attribute_name] = pred_atts
|
933
|
-
new_conclusions[attribute_name] = pred_atts
|
934
|
-
else:
|
935
|
-
pred_atts = make_set(pred_atts)
|
936
|
-
if attribute_name in conclusions:
|
937
|
-
pred_atts = {p for p in pred_atts if p not in conclusions[attribute_name]}
|
938
|
-
if len(pred_atts) > 0:
|
939
|
-
new_conclusions[attribute_name] = pred_atts
|
940
|
-
if attribute_name not in conclusions:
|
941
|
-
conclusions[attribute_name] = set()
|
942
|
-
conclusions[attribute_name].update(pred_atts)
|
943
|
-
if attribute_name in new_conclusions:
|
944
|
-
mutually_exclusive = True if rdr.type_ is SingleClassRDR else False
|
945
|
-
case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, mutually_exclusive)
|
946
|
-
update_case(case_query, new_conclusions)
|
947
|
-
if len(new_conclusions) == 0:
|
948
|
-
break
|
949
|
-
return conclusions
|
945
|
+
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case)
|
950
946
|
|
951
947
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
952
948
|
-> Dict[str, Any]:
|
@@ -1043,7 +1039,7 @@ class GeneralRDR(RippleDownRules):
|
|
1043
1039
|
f.write(func_def)
|
1044
1040
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
1045
1041
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
1046
|
-
f.write(f"{' ' * 4}return
|
1042
|
+
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case)\n")
|
1047
1043
|
|
1048
1044
|
@property
|
1049
1045
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -1068,7 +1064,7 @@ class GeneralRDR(RippleDownRules):
|
|
1068
1064
|
# add type hints
|
1069
1065
|
imports += f"from typing_extensions import Dict, Any\n"
|
1070
1066
|
# import rdr type
|
1071
|
-
imports += f"from ripple_down_rules.
|
1067
|
+
imports += f"from ripple_down_rules.helpers import general_rdr_classify\n"
|
1072
1068
|
# add case type
|
1073
1069
|
imports += f"from ripple_down_rules.datastructures.case import Case, create_case\n"
|
1074
1070
|
imports += f"from {self.case_type.__module__} import {self.case_type.__name__}\n"
|
@@ -7,13 +7,14 @@ import os.path
|
|
7
7
|
from functools import wraps
|
8
8
|
|
9
9
|
from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
|
10
|
-
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
|
10
|
+
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
11
11
|
|
12
12
|
from ripple_down_rules.datastructures.case import create_case, Case
|
13
13
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
14
14
|
from ripple_down_rules.datastructures.enums import Category
|
15
15
|
from ripple_down_rules.experts import Expert, Human
|
16
16
|
from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
|
17
|
+
from ripple_down_rules.user_interface.gui import RDRCaseViewer
|
17
18
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
18
19
|
get_method_class_if_exists, get_method_name, str_to_snake_case
|
19
20
|
|
@@ -26,7 +27,10 @@ class RDRDecorator:
|
|
26
27
|
mutual_exclusive: bool,
|
27
28
|
output_name: str = "output_",
|
28
29
|
fit: bool = True,
|
29
|
-
expert: Optional[Expert] = None
|
30
|
+
expert: Optional[Expert] = None,
|
31
|
+
ask_always: bool = False,
|
32
|
+
update_existing_rules: bool = True,
|
33
|
+
viewer: Optional[RDRCaseViewer] = None):
|
30
34
|
"""
|
31
35
|
:param models_dir: The directory to save/load the RDR models.
|
32
36
|
:param output_type: The type of the output. This is used to create the RDR model.
|
@@ -38,6 +42,9 @@ class RDRDecorator:
|
|
38
42
|
classification mode. This means that the RDR will classify the function's output based on the RDR model.
|
39
43
|
:param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
|
40
44
|
expert will be used.
|
45
|
+
:param ask_always: If True, the function will ask the user for a target if it doesn't exist.
|
46
|
+
:param update_existing_rules: If True, the function will update the existing RDR rules
|
47
|
+
even if they gave an output.
|
41
48
|
:return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
|
42
49
|
"""
|
43
50
|
self.rdr_models_dir = models_dir
|
@@ -48,6 +55,9 @@ class RDRDecorator:
|
|
48
55
|
self.output_name = output_name
|
49
56
|
self.fit: bool = fit
|
50
57
|
self.expert: Optional[Expert] = expert
|
58
|
+
self.ask_always = ask_always
|
59
|
+
self.update_existing_rules = update_existing_rules
|
60
|
+
self.viewer = viewer
|
51
61
|
self.load()
|
52
62
|
|
53
63
|
def decorator(self, func: Callable) -> Callable:
|
@@ -59,59 +69,77 @@ class RDRDecorator:
|
|
59
69
|
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
60
70
|
if self.model_name is None:
|
61
71
|
self.initialize_rdr_model_name_and_load(func)
|
72
|
+
if self.expert is None:
|
73
|
+
self.expert = Human(viewer=self.viewer,
|
74
|
+
answers_save_path=self.rdr_models_dir + f'/expert_answers')
|
75
|
+
|
76
|
+
func_output = {self.output_name: func(*args, **kwargs)}
|
62
77
|
|
63
78
|
if self.fit:
|
64
|
-
case_query = self.create_case_query_from_method(func,
|
65
|
-
self.
|
79
|
+
case_query = self.create_case_query_from_method(func, func_output,
|
80
|
+
self.parsed_output_type,
|
81
|
+
self.mutual_exclusive,
|
66
82
|
*args, **kwargs)
|
67
|
-
output = self.rdr.fit_case(case_query, expert=self.expert
|
83
|
+
output = self.rdr.fit_case(case_query, expert=self.expert,
|
84
|
+
ask_always_for_target=self.ask_always,
|
85
|
+
update_existing_rules=self.update_existing_rules,
|
86
|
+
viewer=self.viewer)
|
87
|
+
else:
|
88
|
+
case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
|
89
|
+
output = self.rdr.classify(case)
|
90
|
+
|
91
|
+
if self.output_name in output:
|
68
92
|
return output[self.output_name]
|
69
93
|
else:
|
70
|
-
|
71
|
-
return self.rdr.classify(case)[self.output_name]
|
94
|
+
return func_output[self.output_name]
|
72
95
|
|
73
96
|
return wrapper
|
74
97
|
|
75
98
|
@staticmethod
|
76
|
-
def create_case_query_from_method(func: Callable,
|
77
|
-
|
99
|
+
def create_case_query_from_method(func: Callable,
|
100
|
+
func_output: Dict[str, Any],
|
101
|
+
output_type: Sequence[Type],
|
102
|
+
mutual_exclusive: bool,
|
103
|
+
*args, **kwargs) -> CaseQuery:
|
78
104
|
"""
|
79
105
|
Create a CaseQuery from the function and its arguments.
|
80
106
|
|
81
107
|
:param func: The function to create a case from.
|
82
|
-
:param
|
108
|
+
:param func_output: The output of the function as a dictionary, where the key is the output name.
|
109
|
+
:param output_type: The type of the output as a sequence of types.
|
83
110
|
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
84
|
-
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
85
111
|
:param args: The positional arguments of the function.
|
86
112
|
:param kwargs: The keyword arguments of the function.
|
87
113
|
:return: A CaseQuery object representing the case.
|
88
114
|
"""
|
89
115
|
output_type = make_set(output_type)
|
90
|
-
case, case_dict = RDRDecorator.create_case_from_method(func,
|
116
|
+
case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
|
91
117
|
scope = func.__globals__
|
92
118
|
scope.update(case_dict)
|
93
119
|
func_args_type_hints = get_type_hints(func)
|
120
|
+
output_name = list(func_output.keys())[0]
|
94
121
|
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
95
122
|
return CaseQuery(case, output_name, Union[tuple(output_type)],
|
96
123
|
mutual_exclusive, scope=scope,
|
97
124
|
is_function=True, function_args_type_hints=func_args_type_hints)
|
98
125
|
|
99
126
|
@staticmethod
|
100
|
-
def create_case_from_method(func: Callable,
|
127
|
+
def create_case_from_method(func: Callable,
|
128
|
+
func_output: Dict[str, Any],
|
129
|
+
*args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
101
130
|
"""
|
102
131
|
Create a Case from the function and its arguments.
|
103
132
|
|
104
133
|
:param func: The function to create a case from.
|
105
|
-
:param
|
134
|
+
:param func_output: A dictionary containing the output of the function, where the key is the output name.
|
106
135
|
:param args: The positional arguments of the function.
|
107
136
|
:param kwargs: The keyword arguments of the function.
|
108
137
|
:return: A Case object representing the case.
|
109
138
|
"""
|
110
139
|
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
111
|
-
func_output
|
112
|
-
case_dict.update({output_name: func_output})
|
140
|
+
case_dict.update(func_output)
|
113
141
|
case_name = get_func_rdr_model_name(func)
|
114
|
-
return
|
142
|
+
return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
|
115
143
|
|
116
144
|
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
117
145
|
model_file_name = get_func_rdr_model_name(func, include_file_name=True)
|
@@ -148,8 +176,10 @@ class RDRDecorator:
|
|
148
176
|
model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
|
149
177
|
if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
|
150
178
|
self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
|
179
|
+
self.rdr.set_viewer(self.viewer)
|
151
180
|
if self.rdr is None:
|
152
|
-
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name
|
181
|
+
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
|
182
|
+
viewer=self.viewer)
|
153
183
|
|
154
184
|
def update_from_python(self):
|
155
185
|
"""
|
ripple_down_rules/rules.py
CHANGED
@@ -12,7 +12,7 @@ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
|
12
12
|
from .datastructures.callable_expression import CallableExpression
|
13
13
|
from .datastructures.case import Case
|
14
14
|
from .datastructures.enums import RDREdge, Stop
|
15
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json
|
15
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name
|
16
16
|
|
17
17
|
|
18
18
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -150,11 +150,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
150
150
|
pass
|
151
151
|
|
152
152
|
def _to_json(self) -> Dict[str, Any]:
|
153
|
-
|
153
|
+
try:
|
154
|
+
corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
|
155
|
+
except Exception as e:
|
156
|
+
logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
|
157
|
+
corner_case = None
|
158
|
+
json_serialization = {"_type": get_full_class_name(type(self)),
|
159
|
+
"conditions": self.conditions.to_json(),
|
154
160
|
"conclusion": conclusion_to_json(self.conclusion),
|
155
161
|
"parent": self.parent.json_serialization if self.parent else None,
|
156
|
-
"corner_case":
|
157
|
-
if self.corner_case else None,
|
162
|
+
"corner_case": corner_case,
|
158
163
|
"conclusion_name": self.conclusion_name,
|
159
164
|
"weight": self.weight,
|
160
165
|
"uid": self.uid}
|
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
|
|
281
281
|
main_obj: Optional[Dict[str, Any]] = None
|
282
282
|
user_input: Optional[str] = None
|
283
283
|
attributes_widget: Optional[QWidget] = None
|
284
|
-
save_function: Optional[Callable[str], None] = None
|
284
|
+
save_function: Optional[Callable[str, str], None] = None
|
285
285
|
|
286
|
-
|
287
|
-
|
286
|
+
def __init__(self, parent=None,
|
287
|
+
save_dir: Optional[str] = None,
|
288
|
+
save_model_name: Optional[str] = None):
|
288
289
|
super().__init__(parent)
|
289
|
-
self.
|
290
|
+
self.save_dir = save_dir
|
291
|
+
self.save_model_name = save_model_name
|
290
292
|
|
291
293
|
self.setWindowTitle("RDR Case Viewer")
|
292
294
|
|
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
|
|
323
325
|
|
324
326
|
# Add both to main layout
|
325
327
|
main_layout.addWidget(self.attributes_widget, stretch=1)
|
326
|
-
main_layout.addWidget(middle_widget, stretch=
|
328
|
+
main_layout.addWidget(middle_widget, stretch=1)
|
327
329
|
main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
|
328
330
|
|
329
|
-
def set_save_function(self, save_function: Callable[[str], None]) -> None:
|
331
|
+
def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
|
330
332
|
"""
|
331
333
|
Set the function to save the file.
|
332
334
|
|
333
335
|
:param save_function: The function to save the file.
|
334
336
|
"""
|
335
337
|
self.save_function = save_function
|
336
|
-
self.save_btn.clicked.connect(lambda: self.save_function(self.
|
338
|
+
self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
|
337
339
|
|
338
340
|
def print(self, msg):
|
339
341
|
"""
|
@@ -134,7 +134,7 @@ class IPythonShell:
|
|
134
134
|
"""
|
135
135
|
Update the user input from the code lines captured in the shell.
|
136
136
|
"""
|
137
|
-
if
|
137
|
+
if self.shell.all_lines[0].replace('return', '').strip() == '':
|
138
138
|
self.user_input = None
|
139
139
|
else:
|
140
140
|
self.all_code_lines = extract_dependencies(self.shell.all_lines)
|
@@ -1,5 +1,9 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
|
+
from ripple_down_rules.datastructures.case import Case
|
4
|
+
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
5
|
+
from ripple_down_rules.utils import SubclassJSONSerializer
|
6
|
+
|
3
7
|
try:
|
4
8
|
import graphviz
|
5
9
|
except ImportError:
|
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
|
|
77
81
|
for attr in dir(obj):
|
78
82
|
if attr.startswith('_'):
|
79
83
|
continue
|
80
|
-
if attr == 'scope':
|
84
|
+
if isinstance(obj, CaseQuery) and attr == 'scope':
|
85
|
+
continue
|
86
|
+
if isinstance(obj, Case) and attr in ['data']:
|
87
|
+
continue
|
88
|
+
if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
|
81
89
|
continue
|
82
90
|
value = getattr(obj, attr)
|
83
91
|
if callable(value):
|
@@ -8,13 +8,13 @@ from functools import cached_property
|
|
8
8
|
from textwrap import indent, dedent
|
9
9
|
|
10
10
|
from colorama import Fore, Style
|
11
|
-
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
|
11
|
+
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
|
12
12
|
|
13
13
|
from ..datastructures.case import Case
|
14
14
|
from ..datastructures.dataclasses import CaseQuery
|
15
15
|
from ..datastructures.enums import Editor, PromptFor
|
16
|
-
from ..utils import str_to_snake_case, get_imports_from_scope, make_list,
|
17
|
-
get_imports_from_types, extract_function_source, extract_imports
|
16
|
+
from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
|
17
|
+
get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
|
18
18
|
|
19
19
|
|
20
20
|
def detect_available_editor() -> Optional[Editor]:
|
@@ -84,6 +84,7 @@ class TemplateFileCreator:
|
|
84
84
|
self.func_doc: str = self.get_func_doc()
|
85
85
|
self.function_signature: str = self.get_function_signature()
|
86
86
|
self.editor: Optional[Editor] = detect_available_editor()
|
87
|
+
self.editor_cmd: Optional[str] = os.environ.get("RDR_EDITOR_CMD")
|
87
88
|
self.workspace: str = os.environ.get("RDR_EDITOR_WORKSPACE", os.path.dirname(self.case_query.scope['__file__']))
|
88
89
|
self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
|
89
90
|
|
@@ -98,7 +99,7 @@ class TemplateFileCreator:
|
|
98
99
|
return make_list(output_type) if output_type is not None else None
|
99
100
|
|
100
101
|
def edit(self):
|
101
|
-
if self.editor is None:
|
102
|
+
if self.editor is None and self.editor_cmd is None:
|
102
103
|
self.print_func(
|
103
104
|
f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
|
104
105
|
return
|
@@ -112,7 +113,11 @@ class TemplateFileCreator:
|
|
112
113
|
"""
|
113
114
|
Open the file in the available editor.
|
114
115
|
"""
|
115
|
-
if self.
|
116
|
+
if self.editor_cmd is not None:
|
117
|
+
subprocess.Popen([self.editor_cmd, self.temp_file_path],
|
118
|
+
stdout=subprocess.DEVNULL,
|
119
|
+
stderr=subprocess.DEVNULL)
|
120
|
+
elif self.editor == Editor.Pycharm:
|
116
121
|
subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
|
117
122
|
stdout=subprocess.DEVNULL,
|
118
123
|
stderr=subprocess.DEVNULL)
|
@@ -172,7 +177,7 @@ class TemplateFileCreator:
|
|
172
177
|
for k, v in self.case_query.case.items():
|
173
178
|
if (self.case_query.function_args_type_hints is not None
|
174
179
|
and k in self.case_query.function_args_type_hints):
|
175
|
-
func_args[k] =
|
180
|
+
func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
|
176
181
|
else:
|
177
182
|
func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
|
178
183
|
func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
|
@@ -202,30 +207,25 @@ class TemplateFileCreator:
|
|
202
207
|
for k, v in self.case_query.case.items():
|
203
208
|
if (self.case_query.function_args_type_hints is not None
|
204
209
|
and k in self.case_query.function_args_type_hints):
|
205
|
-
|
206
|
-
|
207
|
-
hint_split = hint.split('.')
|
208
|
-
if len(hint_split) > 1:
|
209
|
-
case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
|
210
|
+
types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
|
211
|
+
case_type_imports.extend(list(types_to_import))
|
210
212
|
else:
|
211
|
-
|
212
|
-
case_type_imports.append(f"from {v.__module__} import {v.__name__}")
|
213
|
-
elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
|
214
|
-
case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
|
213
|
+
case_type_imports.append(v)
|
215
214
|
else:
|
216
|
-
case_type_imports.append(
|
215
|
+
case_type_imports.append(self.case_type)
|
217
216
|
if self.output_type is None:
|
218
|
-
output_type_imports = [
|
217
|
+
output_type_imports = [Any]
|
219
218
|
else:
|
220
|
-
output_type_imports =
|
219
|
+
output_type_imports = self.output_type
|
221
220
|
if len(self.output_type) > 1:
|
222
|
-
output_type_imports.append(
|
221
|
+
output_type_imports.append(Union)
|
223
222
|
if list in self.output_type:
|
224
|
-
output_type_imports.append(
|
225
|
-
|
226
|
-
imports = [i for i in imports if ("get_ipython" not in i)]
|
227
|
-
|
228
|
-
|
223
|
+
output_type_imports.append(List)
|
224
|
+
import_types = list(self.case_query.scope.values())
|
225
|
+
# imports = [i for i in imports if ("get_ipython" not in i)]
|
226
|
+
import_types.extend(case_type_imports)
|
227
|
+
import_types.extend(output_type_imports)
|
228
|
+
imports = get_imports_from_types(import_types)
|
229
229
|
imports = set(imports)
|
230
230
|
return '\n'.join(imports)
|
231
231
|
|
ripple_down_rules/utils.py
CHANGED
@@ -10,14 +10,14 @@ import os
|
|
10
10
|
import re
|
11
11
|
import threading
|
12
12
|
import uuid
|
13
|
-
from collections import UserDict
|
13
|
+
from collections import UserDict, defaultdict
|
14
14
|
from copy import deepcopy, copy
|
15
15
|
from dataclasses import is_dataclass, fields
|
16
16
|
from enum import Enum
|
17
17
|
from textwrap import dedent
|
18
18
|
from types import NoneType
|
19
|
-
from typing import List
|
20
19
|
|
20
|
+
from sqlalchemy.exc import NoInspectionAvailable
|
21
21
|
|
22
22
|
try:
|
23
23
|
import matplotlib
|
@@ -42,8 +42,7 @@ from sqlalchemy import MetaData, inspect
|
|
42
42
|
from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
|
43
43
|
from tabulate import tabulate
|
44
44
|
from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
|
45
|
-
get_origin, get_args, Tuple, Optional, List, Union, Self
|
46
|
-
|
45
|
+
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef
|
47
46
|
|
48
47
|
if TYPE_CHECKING:
|
49
48
|
from .datastructures.case import Case
|
@@ -82,7 +81,7 @@ def are_results_subclass_of_types(result_types: List[Any], types_: List[Type]) -
|
|
82
81
|
return True
|
83
82
|
|
84
83
|
|
85
|
-
def
|
84
|
+
def _get_imports_from_types(types: List[Type]) -> List[str]:
|
86
85
|
"""
|
87
86
|
Get the import statements for a list of types.
|
88
87
|
|
@@ -660,56 +659,149 @@ def get_func_rdr_model_name(func: Callable, include_file_name: bool = False) ->
|
|
660
659
|
return model_name
|
661
660
|
|
662
661
|
|
663
|
-
def
|
664
|
-
"""
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
662
|
+
def stringify_hint(tp):
|
663
|
+
"""Recursively convert a type hint to a string."""
|
664
|
+
if isinstance(tp, str):
|
665
|
+
return tp
|
666
|
+
|
667
|
+
# Handle ForwardRef (string annotations not yet evaluated)
|
668
|
+
if isinstance(tp, ForwardRef):
|
669
|
+
return tp.__forward_arg__
|
670
|
+
|
671
|
+
# Handle typing generics like List[int], Dict[str, List[int]], etc.
|
672
|
+
origin = get_origin(tp)
|
673
|
+
args = get_args(tp)
|
674
|
+
|
675
|
+
if origin is not None:
|
676
|
+
origin_str = getattr(origin, '__name__', str(origin)).capitalize()
|
677
|
+
args_str = ", ".join(stringify_hint(arg) for arg in args)
|
678
|
+
return f"{origin_str}[{args_str}]"
|
679
|
+
|
680
|
+
# Handle built-in types like int, str, etc.
|
681
|
+
if isinstance(tp, type):
|
682
|
+
if tp.__module__ == 'builtins':
|
683
|
+
return tp.__name__
|
684
|
+
return f"{tp.__qualname__}"
|
685
|
+
|
686
|
+
return str(tp)
|
687
|
+
|
688
|
+
|
689
|
+
def is_builtin_type(tp):
|
690
|
+
return isinstance(tp, type) and tp.__module__ == "builtins"
|
691
|
+
|
692
|
+
|
693
|
+
def is_typing_type(tp):
|
694
|
+
return tp.__module__ == "typing"
|
695
|
+
|
696
|
+
origin_type_to_hint = {
|
697
|
+
list: List,
|
698
|
+
set: Set,
|
699
|
+
dict: Dict,
|
700
|
+
tuple: Tuple,
|
701
|
+
}
|
702
|
+
|
703
|
+
def extract_types(tp, seen: Set = None) -> Set[type]:
|
704
|
+
"""Recursively extract all base types from a type hint."""
|
705
|
+
if seen is None:
|
706
|
+
seen = set()
|
707
|
+
|
708
|
+
if tp in seen or isinstance(tp, str):
|
709
|
+
return seen
|
710
|
+
|
711
|
+
# seen.add(tp)
|
712
|
+
|
713
|
+
if isinstance(tp, ForwardRef):
|
714
|
+
# Can't resolve until evaluated
|
715
|
+
return seen
|
716
|
+
|
717
|
+
origin = get_origin(tp)
|
718
|
+
args = get_args(tp)
|
719
|
+
|
720
|
+
if origin:
|
721
|
+
if origin in origin_type_to_hint:
|
722
|
+
seen.add(origin_type_to_hint[origin])
|
723
|
+
else:
|
724
|
+
seen.add(origin)
|
725
|
+
for arg in args:
|
726
|
+
extract_types(arg, seen)
|
727
|
+
|
728
|
+
elif isinstance(tp, type):
|
729
|
+
seen.add(tp)
|
730
|
+
|
731
|
+
return seen
|
732
|
+
|
733
|
+
|
734
|
+
def get_types_to_import_from_func_type_hints(func: Callable) -> Set[Type]:
|
735
|
+
"""
|
736
|
+
Extract importable types from a function's annotations.
|
737
|
+
|
738
|
+
:param func: The function to extract type hints from.
|
739
|
+
"""
|
740
|
+
hints = get_type_hints(func)
|
741
|
+
|
742
|
+
sig = inspect.signature(func)
|
743
|
+
all_hints = list(hints.values())
|
744
|
+
if sig.return_annotation != inspect.Signature.empty:
|
745
|
+
all_hints.append(sig.return_annotation)
|
746
|
+
|
747
|
+
for param in sig.parameters.values():
|
748
|
+
if param.annotation != inspect.Parameter.empty:
|
749
|
+
all_hints.append(param.annotation)
|
750
|
+
|
751
|
+
return get_types_to_import_from_type_hints(all_hints)
|
752
|
+
|
753
|
+
|
754
|
+
def get_types_to_import_from_type_hints(hints: List[Type]) -> Set[Type]:
|
755
|
+
"""
|
756
|
+
Extract importable types from a list of type hints.
|
757
|
+
|
758
|
+
:param hints: A list of type hints to extract types from.
|
759
|
+
:return: A set of types that need to be imported.
|
760
|
+
"""
|
761
|
+
seen_types = set()
|
762
|
+
for hint in hints:
|
763
|
+
extract_types(hint, seen_types)
|
764
|
+
|
765
|
+
# Filter out built-in and internal types
|
766
|
+
to_import = set()
|
767
|
+
for tp in seen_types:
|
768
|
+
if isinstance(tp, ForwardRef) or isinstance(tp, str):
|
769
|
+
continue
|
770
|
+
if not is_builtin_type(tp):
|
771
|
+
to_import.add(tp)
|
772
|
+
|
773
|
+
return to_import
|
774
|
+
|
775
|
+
|
776
|
+
def get_imports_from_types(type_objs: List[Type]) -> List[str]:
|
777
|
+
"""
|
778
|
+
Format import lines from type objects.
|
779
|
+
|
780
|
+
:param type_objs: A list of type objects to format.
|
781
|
+
"""
|
782
|
+
|
783
|
+
module_to_types = defaultdict(list)
|
784
|
+
for tp in type_objs:
|
785
|
+
try:
|
786
|
+
if isinstance(tp, type) or is_typing_type(tp):
|
787
|
+
module = tp.__module__
|
788
|
+
name = tp.__qualname__
|
789
|
+
elif hasattr(type(tp), "__module__"):
|
790
|
+
module = type(tp).__module__
|
791
|
+
name = type(tp).__qualname__
|
792
|
+
else:
|
793
|
+
continue
|
794
|
+
if module is None or module == 'builtins' or module.startswith('_'):
|
795
|
+
continue
|
796
|
+
module_to_types[module].append(name)
|
797
|
+
except AttributeError:
|
798
|
+
continue
|
799
|
+
|
800
|
+
lines = []
|
801
|
+
for module, names in module_to_types.items():
|
802
|
+
joined = ", ".join(sorted(set(names)))
|
803
|
+
lines.append(f"from {module} import {joined}")
|
804
|
+
return sorted(lines)
|
713
805
|
|
714
806
|
|
715
807
|
def get_method_args_as_dict(method: Callable, *args, **kwargs) -> Dict[str, Any]:
|
@@ -865,6 +957,8 @@ class SubclassJSONSerializer:
|
|
865
957
|
def to_json_static(obj, seen=None) -> Any:
|
866
958
|
if isinstance(obj, SubclassJSONSerializer):
|
867
959
|
return {"_type": get_full_class_name(obj.__class__), **obj._to_json()}
|
960
|
+
elif isinstance(obj, type):
|
961
|
+
return {"_type": get_full_class_name(obj)}
|
868
962
|
elif is_dataclass(obj):
|
869
963
|
return serialize_dataclass(obj, seen)
|
870
964
|
elif isinstance(obj, list):
|
@@ -1017,13 +1111,20 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
|
|
1017
1111
|
:param instance: The instance to copy.
|
1018
1112
|
:return: The copied instance.
|
1019
1113
|
"""
|
1020
|
-
|
1114
|
+
try:
|
1115
|
+
session: Session = inspect(instance).session
|
1116
|
+
except NoInspectionAvailable:
|
1117
|
+
session = None
|
1021
1118
|
if session is not None:
|
1022
1119
|
session.expunge(instance)
|
1023
1120
|
new_instance = deepcopy(instance)
|
1024
1121
|
session.add(instance)
|
1025
1122
|
else:
|
1026
|
-
|
1123
|
+
try:
|
1124
|
+
new_instance = deepcopy(instance)
|
1125
|
+
except Exception as e:
|
1126
|
+
logging.debug(e)
|
1127
|
+
new_instance = instance
|
1027
1128
|
return new_instance
|
1028
1129
|
|
1029
1130
|
|
@@ -1037,8 +1138,12 @@ def copy_orm_instance_with_relationships(instance: SQLTable) -> SQLTable:
|
|
1037
1138
|
instance_cp = copy_orm_instance(instance)
|
1038
1139
|
for rel in class_mapper(instance.__class__).relationships:
|
1039
1140
|
related_obj = getattr(instance, rel.key)
|
1141
|
+
related_obj_cp = copy_orm_instance(related_obj)
|
1040
1142
|
if related_obj is not None:
|
1041
|
-
|
1143
|
+
try:
|
1144
|
+
setattr(instance_cp, rel.key, related_obj_cp)
|
1145
|
+
except Exception as e:
|
1146
|
+
logging.debug(e)
|
1042
1147
|
return instance_cp
|
1043
1148
|
|
1044
1149
|
|
@@ -1049,7 +1154,17 @@ def get_value_type_from_type_hint(attr_name: str, obj: Any) -> Type:
|
|
1049
1154
|
:param attr_name: The name of the attribute.
|
1050
1155
|
:param obj: The object to get the attributes from.
|
1051
1156
|
"""
|
1052
|
-
|
1157
|
+
# check first if obj is a function object
|
1158
|
+
if hasattr(obj, '__code__'):
|
1159
|
+
func_type_hints = get_type_hints(obj)
|
1160
|
+
if attr_name in func_type_hints:
|
1161
|
+
hint = func_type_hints[attr_name]
|
1162
|
+
origin = get_origin(hint)
|
1163
|
+
args = get_args(hint)
|
1164
|
+
else:
|
1165
|
+
raise ValueError(f"Unknown type hint: {attr_name}")
|
1166
|
+
else:
|
1167
|
+
hint, origin, args = get_hint_for_attribute(attr_name, obj)
|
1053
1168
|
if not origin and not hint:
|
1054
1169
|
if hasattr(obj, attr_name):
|
1055
1170
|
attr_value = getattr(obj, attr_name)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.64
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,24 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=5qEoRYuRIypGXWagthsMrLJNOpvkBokzNKBjmYbRyEo,100
|
2
|
+
ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
|
3
|
+
ripple_down_rules/helpers.py,sha256=fRBjtknjhszsZAlFLWjiz_n-YOWOXa8LHQNSl-FjsQI,4203
|
4
|
+
ripple_down_rules/rdr.py,sha256=jhDL-a_nLSzAR-KQfKxFBoLhMZbazs_lTB_z-ebw-X0,48369
|
5
|
+
ripple_down_rules/rdr_decorators.py,sha256=pEVupcFqtHzPCaxZoxphHWlrSN6vCishdwUQ1hXiWtc,9193
|
6
|
+
ripple_down_rules/rules.py,sha256=2yFpbu6DUgYTKFJsuAod3YSM6MBMyYkiBsgKu4BgxZM,17794
|
7
|
+
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
+
ripple_down_rules/utils.py,sha256=LtZ21VSB3Au27Wf23wpDdSJMv3ABQgE1dRc2kzzQj-c,54605
|
9
|
+
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
|
11
|
+
ripple_down_rules/datastructures/case.py,sha256=1zSaXUljaH6z3SgMGzYPoDyjotNam791KpYgvxuMh90,15463
|
12
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=uIrEcvV0oJeMZM9ewGdSO7lIyZgza-4UA3L7gCy-lQk,8542
|
13
|
+
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
14
|
+
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
+
ripple_down_rules/user_interface/gui.py,sha256=_lgZAUXxxaBUFQJAHjA5TBPp6XEvJ62t-kSN8sPsocE,27379
|
16
|
+
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=Jrf7NxOdlrwGXH0Xyz3vzQprY-PNx9etfePOTpm2Gu8,6479
|
17
|
+
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
|
+
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
19
|
+
ripple_down_rules/user_interface/template_file_creator.py,sha256=VLS9Nxg6gPNa-YYliJ_VNsTvLPlZ003EVkJ2t8zuDgE,13563
|
20
|
+
ripple_down_rules-0.5.64.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
+
ripple_down_rules-0.5.64.dist-info/METADATA,sha256=s4rz2kIRehE7aDcErKUI0uifjq6wbOU6khX09EMXtF4,48189
|
22
|
+
ripple_down_rules-0.5.64.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
ripple_down_rules-0.5.64.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
+
ripple_down_rules-0.5.64.dist-info/RECORD,,
|
@@ -1,24 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=MKeKdIBZpDHNeXgQzJQDmsXIXJUj5GFtP1tTx5ZbuwE,100
|
2
|
-
ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
|
3
|
-
ripple_down_rules/helpers.py,sha256=TvTJU0BA3dPcAyzvZFvAu7jZqsp8Lu0HAAwvuizlGjg,2018
|
4
|
-
ripple_down_rules/rdr.py,sha256=4iobla4XmMwAOQsn_JZaZe2tWU0aMMvqgzP5WavIagI,49280
|
5
|
-
ripple_down_rules/rdr_decorators.py,sha256=bmn4h4a7xujTVxu-ofECe71cM_6iiqZhLVFosEItid4,7602
|
6
|
-
ripple_down_rules/rules.py,sha256=TPNVMqW9T-_46BS4WemrspLg5uG8kP6tsPvWWBAzJxg,17515
|
7
|
-
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
-
ripple_down_rules/utils.py,sha256=N5Rgz7wb9oKrVLZiJG2P-irnsjhy7VR3Vqyggf4Mq7I,51564
|
9
|
-
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
|
11
|
-
ripple_down_rules/datastructures/case.py,sha256=r8kjL9xP_wk84ThXusspgPMrAoed2bGQmKi54fzhmH8,15258
|
12
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=PuD-7zWqWT2p4FnGvnihHvZlZKg9A1ctnFgVYf2cs-8,8554
|
13
|
-
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
14
|
-
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
-
ripple_down_rules/user_interface/gui.py,sha256=jRRyQxgU_RK2e_wgi2gPag_FB8UCYOAXicRTk8_JWgo,27232
|
16
|
-
ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=24MIFwqnAhC6ofObEO6x5xRWRnyQmPpPmTvxbCKBrzM,6514
|
17
|
-
ripple_down_rules/user_interface/object_diagram.py,sha256=tsB6iuLNEbHxp5lR2WjyejjWbnAX_nHF9xS8jNPOQVk,4548
|
18
|
-
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
19
|
-
ripple_down_rules/user_interface/template_file_creator.py,sha256=FGtLfYBfr4310c7Dfa9b2qiOWLNzHk1q3kdhD70Ilg4,13804
|
20
|
-
ripple_down_rules-0.5.62.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
-
ripple_down_rules-0.5.62.dist-info/METADATA,sha256=6s7Y4pSt63IEgGQu_87ZOAX6XcFK9i67rJkIDSTZfoU,48189
|
22
|
-
ripple_down_rules-0.5.62.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
-
ripple_down_rules-0.5.62.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
-
ripple_down_rules-0.5.62.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|