ripple-down-rules 0.5.64__py3-none-any.whl → 0.5.75__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/dataclasses.py +61 -2
- ripple_down_rules/helpers.py +8 -4
- ripple_down_rules/rdr.py +110 -59
- ripple_down_rules/rdr_decorators.py +7 -6
- ripple_down_rules/rules.py +53 -9
- ripple_down_rules/utils.py +71 -4
- {ripple_down_rules-0.5.64.dist-info → ripple_down_rules-0.5.75.dist-info}/METADATA +2 -1
- {ripple_down_rules-0.5.64.dist-info → ripple_down_rules-0.5.75.dist-info}/RECORD +12 -12
- {ripple_down_rules-0.5.64.dist-info → ripple_down_rules-0.5.75.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.64.dist-info → ripple_down_rules-0.5.75.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.64.dist-info → ripple_down_rules-0.5.75.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -5,9 +5,11 @@ import typing
|
|
5
5
|
from dataclasses import dataclass, field
|
6
6
|
|
7
7
|
import typing_extensions
|
8
|
+
from omegaconf import MISSING
|
8
9
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
9
|
-
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set
|
10
|
+
from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, get_origin, Set, Callable
|
10
11
|
|
12
|
+
from ..utils import get_method_name, get_function_import_data, get_function_representation
|
11
13
|
from .callable_expression import CallableExpression
|
12
14
|
from .case import create_case, Case
|
13
15
|
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, get_value_type_from_type_hint, \
|
@@ -37,6 +39,24 @@ class CaseQuery:
|
|
37
39
|
"""
|
38
40
|
Whether the attribute can only take one value (i.e. True) or multiple values (i.e. False).
|
39
41
|
"""
|
42
|
+
case_factory: Optional[Callable[[], Any]] = None
|
43
|
+
"""
|
44
|
+
The factory method that can be used to recreate the original case.
|
45
|
+
"""
|
46
|
+
case_factory_idx: Optional[int] = None
|
47
|
+
"""
|
48
|
+
This is used when the case factory is a list of cases, this index is used to select the case from the list.
|
49
|
+
"""
|
50
|
+
case_conf: Optional[CaseConf] = None
|
51
|
+
"""
|
52
|
+
The case configuration that is used to (re)create the original case, recommended to be used when you want to
|
53
|
+
the case to persist in the rule base, this would allow it to be used for merging with other similar conclusion RDRs.
|
54
|
+
"""
|
55
|
+
scenario: Optional[Callable] = None
|
56
|
+
"""
|
57
|
+
The executable scenario is the root callable that recreates the situation that the case is
|
58
|
+
created in, for example, when the case is created from a test function, this would be the test function itself.
|
59
|
+
"""
|
40
60
|
_target: Optional[CallableExpression] = None
|
41
61
|
"""
|
42
62
|
The target expression of the attribute.
|
@@ -225,4 +245,43 @@ class CaseQuery:
|
|
225
245
|
self.mutually_exclusive, _target=self.target, default_value=self.default_value,
|
226
246
|
scope=self.scope, _case=copy_case(self.case), _target_value=self.target_value,
|
227
247
|
conditions=self.conditions, is_function=self.is_function,
|
228
|
-
function_args_type_hints=self.function_args_type_hints
|
248
|
+
function_args_type_hints=self.function_args_type_hints,
|
249
|
+
case_factory=self.case_factory, case_factory_idx=self.case_factory_idx,
|
250
|
+
case_conf=self.case_conf, scenario=self.scenario)
|
251
|
+
|
252
|
+
|
253
|
+
@dataclass
|
254
|
+
class CaseConf:
|
255
|
+
factory_method: Callable[[Any], Any] = MISSING
|
256
|
+
|
257
|
+
def create(self) -> Any:
|
258
|
+
return self.factory_method()
|
259
|
+
|
260
|
+
|
261
|
+
@dataclass
|
262
|
+
class CaseFactoryMetaData:
|
263
|
+
factory_method: Optional[Callable[[Optional[CaseConf]], Any]] = None
|
264
|
+
factory_idx: Optional[int] = None
|
265
|
+
case_conf: Optional[CaseConf] = None
|
266
|
+
scenario: Optional[Callable] = None
|
267
|
+
|
268
|
+
@classmethod
|
269
|
+
def from_case_query(cls, case_query: CaseQuery) -> CaseFactoryMetaData:
|
270
|
+
return cls(factory_method=case_query.case_factory, factory_idx=case_query.case_factory_idx,
|
271
|
+
case_conf=case_query.case_conf, scenario=case_query.scenario)
|
272
|
+
|
273
|
+
def __repr__(self):
|
274
|
+
factory_method_repr = None
|
275
|
+
scenario_repr = None
|
276
|
+
if self.factory_method is not None:
|
277
|
+
factory_method_repr = get_function_representation(self.factory_method)
|
278
|
+
if self.scenario is not None:
|
279
|
+
scenario_repr = get_function_representation(self.scenario)
|
280
|
+
return (f"CaseFactoryMetaData("
|
281
|
+
f"factory_method={factory_method_repr}, "
|
282
|
+
f"factory_idx={self.factory_idx}, "
|
283
|
+
f"case_conf={self.case_conf},"
|
284
|
+
f" scenario={scenario_repr})")
|
285
|
+
|
286
|
+
def __str__(self):
|
287
|
+
return self.__repr__()
|
ripple_down_rules/helpers.py
CHANGED
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|
3
3
|
import os
|
4
4
|
from types import ModuleType
|
5
5
|
|
6
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
7
|
+
|
6
8
|
from .datastructures.case import create_case
|
7
9
|
from .datastructures.dataclasses import CaseQuery
|
8
10
|
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
@@ -15,7 +17,8 @@ if TYPE_CHECKING:
|
|
15
17
|
|
16
18
|
|
17
19
|
def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDownRules]],
|
18
|
-
case: Any, modify_original_case: bool = False
|
20
|
+
case: Any, modify_original_case: bool = False,
|
21
|
+
case_query: Optional[CaseQuery] = None) -> Dict[str, Any]:
|
19
22
|
"""
|
20
23
|
Classify a case by going through all classifiers and adding the categories that are classified,
|
21
24
|
and then restarting the classification until no more categories can be added.
|
@@ -23,6 +26,7 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
|
|
23
26
|
:param classifiers_dict: A dictionary mapping conclusion types to the classifiers that produce them.
|
24
27
|
:param case: The case to classify.
|
25
28
|
:param modify_original_case: Whether to modify the original case or create a copy and modify it.
|
29
|
+
:param case_query: The case query to extract metadata from if needed.
|
26
30
|
:return: The categories that the case belongs to.
|
27
31
|
"""
|
28
32
|
conclusions = {}
|
@@ -31,7 +35,7 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
|
|
31
35
|
while True:
|
32
36
|
new_conclusions = {}
|
33
37
|
for attribute_name, rdr in classifiers_dict.items():
|
34
|
-
pred_atts = rdr.classify(case_cp)
|
38
|
+
pred_atts = rdr.classify(case_cp, case_query=case_query)
|
35
39
|
if pred_atts is None:
|
36
40
|
continue
|
37
41
|
if rdr.mutually_exclusive:
|
@@ -49,8 +53,8 @@ def general_rdr_classify(classifiers_dict: Dict[str, Union[ModuleType, RippleDow
|
|
49
53
|
conclusions[attribute_name] = set()
|
50
54
|
conclusions[attribute_name].update(pred_atts)
|
51
55
|
if attribute_name in new_conclusions:
|
52
|
-
|
53
|
-
update_case(
|
56
|
+
temp_case_query = CaseQuery(case_cp, attribute_name, rdr.conclusion_type, rdr.mutually_exclusive)
|
57
|
+
update_case(temp_case_query, new_conclusions)
|
54
58
|
if len(new_conclusions) == 0:
|
55
59
|
break
|
56
60
|
return conclusions
|
ripple_down_rules/rdr.py
CHANGED
@@ -4,6 +4,8 @@ import copyreg
|
|
4
4
|
import importlib
|
5
5
|
import os
|
6
6
|
|
7
|
+
from ripple_down_rules.datastructures.dataclasses import CaseFactoryMetaData
|
8
|
+
|
7
9
|
from . import logger
|
8
10
|
import sys
|
9
11
|
from abc import ABC, abstractmethod
|
@@ -37,7 +39,7 @@ except ImportError as e:
|
|
37
39
|
from .utils import draw_tree, make_set, copy_case, \
|
38
40
|
SubclassJSONSerializer, make_list, get_type_from_string, \
|
39
41
|
is_conflicting, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
|
40
|
-
is_iterable, str_to_snake_case
|
42
|
+
is_iterable, str_to_snake_case, get_import_path_from_path, get_imports_from_types
|
41
43
|
|
42
44
|
|
43
45
|
class RippleDownRules(SubclassJSONSerializer, ABC):
|
@@ -112,7 +114,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
112
114
|
if not os.path.exists(save_dir + '/__init__.py'):
|
113
115
|
os.makedirs(save_dir, exist_ok=True)
|
114
116
|
with open(save_dir + '/__init__.py', 'w') as f:
|
115
|
-
f.write("
|
117
|
+
f.write("from . import *\n")
|
116
118
|
if model_name is not None:
|
117
119
|
self.model_name = model_name
|
118
120
|
elif self.model_name is None:
|
@@ -136,7 +138,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
136
138
|
model_dir = os.path.join(load_dir, model_name)
|
137
139
|
json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
|
138
140
|
rdr = cls.from_json_file(json_file)
|
139
|
-
|
141
|
+
try:
|
142
|
+
rdr.update_from_python(model_dir)
|
143
|
+
except (FileNotFoundError, ValueError) as e:
|
144
|
+
logger.warning(f"Could not load the python file for the model {model_name} from {model_dir}. "
|
145
|
+
f"Make sure the file exists and is valid.")
|
140
146
|
rdr.save_dir = load_dir
|
141
147
|
rdr.model_name = model_name
|
142
148
|
return rdr
|
@@ -215,13 +221,15 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
215
221
|
return self.classify(case)
|
216
222
|
|
217
223
|
@abstractmethod
|
218
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
224
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
225
|
+
case_query: Optional[CaseQuery] = None) \
|
219
226
|
-> Optional[Union[CallableExpression, Dict[str, CallableExpression]]]:
|
220
227
|
"""
|
221
228
|
Classify a case.
|
222
229
|
|
223
230
|
:param case: The case to classify.
|
224
231
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
232
|
+
:param case_query: The case query containing the case to classify and the target category to compare the case with.
|
225
233
|
:return: The category that the case belongs to.
|
226
234
|
"""
|
227
235
|
pass
|
@@ -229,6 +237,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
229
237
|
def fit_case(self, case_query: CaseQuery,
|
230
238
|
expert: Optional[Expert] = None,
|
231
239
|
update_existing_rules: bool = True,
|
240
|
+
scenario: Optional[Callable] = None,
|
232
241
|
**kwargs) \
|
233
242
|
-> Union[CallableExpression, Dict[str, CallableExpression]]:
|
234
243
|
"""
|
@@ -239,6 +248,7 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
239
248
|
:param expert: The expert to ask for differentiating features as new rule conditions.
|
240
249
|
:param update_existing_rules: Whether to update the existing same conclusion type rules that already gave
|
241
250
|
some conclusions with the type required by the case query.
|
251
|
+
:param scenario: The scenario at which the case was created, this is used to recreate the case if needed.
|
242
252
|
:return: The category that the case belongs to.
|
243
253
|
"""
|
244
254
|
if case_query is None:
|
@@ -247,14 +257,14 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
247
257
|
self.name = case_query.attribute_name if self.name is None else self.name
|
248
258
|
self.case_type = case_query.case_type if self.case_type is None else self.case_type
|
249
259
|
self.case_name = case_query.case_name if self.case_name is None else self.case_name
|
260
|
+
case_query.scenario = scenario if case_query.scenario is None else case_query.scenario
|
250
261
|
|
251
262
|
expert = expert or Human(viewer=self.viewer,
|
252
263
|
answers_save_path=self.save_dir + '/expert_answers'
|
253
264
|
if self.save_dir else None)
|
254
|
-
|
255
265
|
if case_query.target is None:
|
256
266
|
case_query_cp = copy(case_query)
|
257
|
-
conclusions = self.classify(case_query_cp.case, modify_case=True)
|
267
|
+
conclusions = self.classify(case_query_cp.case, modify_case=True, case_query=case_query_cp)
|
258
268
|
if self.should_i_ask_the_expert_for_a_target(conclusions, case_query_cp, update_existing_rules):
|
259
269
|
expert.ask_for_conclusion(case_query_cp)
|
260
270
|
case_query.target = case_query_cp.target
|
@@ -389,13 +399,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
|
|
389
399
|
:return: The module that contains the rdr classifier function.
|
390
400
|
"""
|
391
401
|
# remove from imports if exists first
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
pass
|
398
|
-
return importlib.import_module(name).classify
|
402
|
+
package_name = get_import_path_from_path(package_name)
|
403
|
+
name = f"{package_name}.{self.generated_python_file_name}" if package_name else self.generated_python_file_name
|
404
|
+
module = importlib.import_module(name)
|
405
|
+
importlib.reload(module)
|
406
|
+
return module.classify
|
399
407
|
|
400
408
|
|
401
409
|
class RDRWithCodeWriter(RippleDownRules, ABC):
|
@@ -411,6 +419,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
411
419
|
conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
|
412
420
|
all_func_names = condition_func_names + conclusion_func_names
|
413
421
|
filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
|
422
|
+
cases_path = f"{model_dir}/{self.generated_python_cases_file_name}.py"
|
423
|
+
cases_import_path = get_import_path_from_path(model_dir)
|
424
|
+
cases_import_path = f"{cases_import_path}.{self.generated_python_cases_file_name}" if cases_import_path\
|
425
|
+
else self.generated_python_cases_file_name
|
414
426
|
functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
|
415
427
|
# get the scope from the imports in the file
|
416
428
|
scope = extract_imports(filepath)
|
@@ -418,13 +430,17 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
418
430
|
if rule.conditions is not None:
|
419
431
|
rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
|
420
432
|
rule.conditions.scope = scope
|
433
|
+
if os.path.exists(cases_path):
|
434
|
+
module = importlib.import_module(cases_import_path)
|
435
|
+
importlib.reload(module)
|
436
|
+
rule.corner_case_metadata = module.__dict__.get(f"corner_case_{rule.uid}", None)
|
421
437
|
if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
|
422
438
|
rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
|
423
439
|
rule.conclusion.scope = scope
|
424
440
|
|
425
441
|
@abstractmethod
|
426
442
|
def write_rules_as_source_code_to_file(self, rule: Rule, file, parent_indent: str = "",
|
427
|
-
defs_file: Optional[str] = None):
|
443
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
428
444
|
"""
|
429
445
|
Write the rules as source code to a file.
|
430
446
|
|
@@ -432,6 +448,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
432
448
|
:param file: The file to write the source code to.
|
433
449
|
:param parent_indent: The indentation of the parent rule.
|
434
450
|
:param defs_file: The file to write the definitions to.
|
451
|
+
:param cases_file: The file to write the cases to.
|
435
452
|
"""
|
436
453
|
pass
|
437
454
|
|
@@ -444,14 +461,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
444
461
|
os.makedirs(model_dir, exist_ok=True)
|
445
462
|
if not os.path.exists(model_dir + '/__init__.py'):
|
446
463
|
with open(model_dir + '/__init__.py', 'w') as f:
|
447
|
-
f.write("
|
448
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
464
|
+
f.write("from . import *\n")
|
465
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
449
466
|
file_name = model_dir + f"/{self.generated_python_file_name}.py"
|
450
467
|
defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
|
468
|
+
cases_file_name = model_dir + f"/{self.generated_python_cases_file_name}.py"
|
451
469
|
imports, defs_imports = self._get_imports()
|
452
470
|
# clear the files first
|
453
471
|
with open(defs_file_name, "w") as f:
|
454
472
|
f.write(defs_imports + "\n\n")
|
473
|
+
case_factory_import = get_imports_from_types([CaseFactoryMetaData])
|
474
|
+
with open(cases_file_name, "w") as cases_f:
|
475
|
+
cases_f.write("# This file contains the corner cases for the rules.\n")
|
476
|
+
cases_f.write('\n'.join(case_factory_import) + "\n\n\n")
|
455
477
|
with open(file_name, "w") as f:
|
456
478
|
imports += f"from .{self.generated_python_defs_file_name} import *\n"
|
457
479
|
f.write(imports + "\n\n")
|
@@ -461,7 +483,8 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
461
483
|
f.write(f"\n\n{func_def}")
|
462
484
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
463
485
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
464
|
-
|
486
|
+
self.write_rules_as_source_code_to_file(self.start_rule, file_name, " " * 4, defs_file=defs_file_name,
|
487
|
+
cases_file=cases_file_name)
|
465
488
|
|
466
489
|
@property
|
467
490
|
@abstractmethod
|
@@ -510,6 +533,10 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
|
|
510
533
|
def generated_python_defs_file_name(self) -> str:
|
511
534
|
return f"{self.generated_python_file_name}_defs"
|
512
535
|
|
536
|
+
@property
|
537
|
+
def generated_python_cases_file_name(self) -> str:
|
538
|
+
return f"{self.generated_python_file_name}_cases"
|
539
|
+
|
513
540
|
|
514
541
|
@property
|
515
542
|
def conclusion_type(self) -> Tuple[Type]:
|
@@ -592,7 +619,7 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
592
619
|
pred = self.evaluate(case_query.case)
|
593
620
|
if pred.conclusion(case_query.case) != case_query.target_value:
|
594
621
|
expert.ask_for_conditions(case_query, pred)
|
595
|
-
pred.fit_rule(case_query
|
622
|
+
pred.fit_rule(case_query)
|
596
623
|
|
597
624
|
return self.classify(case_query.case)
|
598
625
|
|
@@ -605,18 +632,24 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
605
632
|
"""
|
606
633
|
if not self.start_rule:
|
607
634
|
expert.ask_for_conditions(case_query)
|
608
|
-
self.start_rule = SingleClassRule(case_query
|
609
|
-
conclusion_name=case_query.attribute_name)
|
635
|
+
self.start_rule = SingleClassRule.from_case_query(case_query)
|
610
636
|
|
611
|
-
def classify(self, case: Case, modify_case: bool = False
|
637
|
+
def classify(self, case: Case, modify_case: bool = False,
|
638
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Any]:
|
612
639
|
"""
|
613
640
|
Classify a case by recursively evaluating the rules until a rule fires or the last rule is reached.
|
614
641
|
|
615
642
|
:param case: The case to classify.
|
616
643
|
:param modify_case: Whether to modify the original case attributes with the conclusion or not.
|
644
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
617
645
|
"""
|
618
646
|
pred = self.evaluate(case)
|
619
|
-
|
647
|
+
conclusion = pred.conclusion(case) if pred is not None else None
|
648
|
+
if pred is not None and pred.fired and case_query is not None:
|
649
|
+
if pred.corner_case_metadata is None and conclusion is not None\
|
650
|
+
and type(conclusion) in case_query.core_attribute_type:
|
651
|
+
pred.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
652
|
+
return conclusion if pred is not None and pred.fired else self.default_conclusion
|
620
653
|
|
621
654
|
def evaluate(self, case: Case) -> SingleClassRule:
|
622
655
|
"""
|
@@ -631,23 +664,27 @@ class SingleClassRDR(RDRWithCodeWriter):
|
|
631
664
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
|
632
665
|
f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
|
633
666
|
|
634
|
-
def write_rules_as_source_code_to_file(self, rule: SingleClassRule,
|
635
|
-
defs_file: Optional[str] = None):
|
667
|
+
def write_rules_as_source_code_to_file(self, rule: SingleClassRule, filename: str, parent_indent: str = "",
|
668
|
+
defs_file: Optional[str] = None, cases_file: Optional[str] = None):
|
636
669
|
"""
|
637
670
|
Write the rules as source code to a file.
|
638
671
|
"""
|
639
672
|
if rule.conditions:
|
673
|
+
rule.write_corner_case_as_source_code(cases_file)
|
640
674
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
641
|
-
|
675
|
+
with open(filename, "a") as file:
|
676
|
+
file.write(if_clause)
|
642
677
|
if rule.refinement:
|
643
|
-
self.write_rules_as_source_code_to_file(rule.refinement,
|
644
|
-
defs_file=defs_file)
|
678
|
+
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
679
|
+
defs_file=defs_file, cases_file=cases_file)
|
645
680
|
|
646
681
|
conclusion_call = rule.write_conclusion_as_source_code(parent_indent, defs_file)
|
647
|
-
|
682
|
+
with open(filename, "a") as file:
|
683
|
+
file.write(conclusion_call)
|
648
684
|
|
649
685
|
if rule.alternative:
|
650
|
-
self.write_rules_as_source_code_to_file(rule.alternative,
|
686
|
+
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
687
|
+
cases_file=cases_file)
|
651
688
|
|
652
689
|
@property
|
653
690
|
def conclusion_type_hint(self) -> str:
|
@@ -699,13 +736,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
699
736
|
super(MultiClassRDR, self).__init__(start_rule, **kwargs)
|
700
737
|
self.mode: MCRDRMode = mode
|
701
738
|
|
702
|
-
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False
|
739
|
+
def classify(self, case: Union[Case, SQLTable], modify_case: bool = False,
|
740
|
+
case_query: Optional[CaseQuery] = None) -> Set[Any]:
|
703
741
|
evaluated_rule = self.start_rule
|
704
742
|
self.conclusions = []
|
705
743
|
while evaluated_rule:
|
706
744
|
next_rule = evaluated_rule(case)
|
707
745
|
if evaluated_rule.fired:
|
708
|
-
|
746
|
+
rule_conclusion = evaluated_rule.conclusion(case)
|
747
|
+
if evaluated_rule.corner_case_metadata is None and case_query is not None:
|
748
|
+
if rule_conclusion is not None and len(make_list(rule_conclusion)) > 0\
|
749
|
+
and any(ct in case_query.core_attribute_type for ct in map(type, make_list(rule_conclusion))):
|
750
|
+
evaluated_rule.corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
751
|
+
self.add_conclusion(rule_conclusion)
|
709
752
|
evaluated_rule = next_rule
|
710
753
|
return make_set(self.conclusions)
|
711
754
|
|
@@ -733,7 +776,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
733
776
|
self.stop_wrong_conclusion_else_add_it(case_query, expert, evaluated_rule)
|
734
777
|
else:
|
735
778
|
# Rule fired and target is correct or there is no target to compare
|
736
|
-
self.add_conclusion(
|
779
|
+
self.add_conclusion(rule_conclusion)
|
737
780
|
|
738
781
|
if not next_rule:
|
739
782
|
if not make_set(target_value).issubset(make_set(self.conclusions)):
|
@@ -745,24 +788,31 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
745
788
|
return self.conclusions
|
746
789
|
|
747
790
|
def write_rules_as_source_code_to_file(self, rule: Union[MultiClassTopRule, MultiClassStopRule],
|
748
|
-
|
791
|
+
filename: str, parent_indent: str = "", defs_file: Optional[str] = None,
|
792
|
+
cases_file: Optional[str] = None):
|
749
793
|
if rule == self.start_rule:
|
750
|
-
|
794
|
+
with open(filename, "a") as file:
|
795
|
+
file.write(f"{parent_indent}conclusions = set()\n")
|
751
796
|
if rule.conditions:
|
797
|
+
rule.write_corner_case_as_source_code(cases_file)
|
752
798
|
if_clause = rule.write_condition_as_source_code(parent_indent, defs_file)
|
753
|
-
|
799
|
+
with open(filename, "a") as file:
|
800
|
+
file.write(if_clause)
|
754
801
|
conclusion_indent = parent_indent
|
755
802
|
if hasattr(rule, "refinement") and rule.refinement:
|
756
|
-
self.write_rules_as_source_code_to_file(rule.refinement,
|
757
|
-
defs_file=defs_file)
|
803
|
+
self.write_rules_as_source_code_to_file(rule.refinement, filename, parent_indent + " ",
|
804
|
+
defs_file=defs_file, cases_file=cases_file)
|
758
805
|
conclusion_indent = parent_indent + " " * 4
|
759
|
-
|
806
|
+
with open(filename, "a") as file:
|
807
|
+
file.write(f"{conclusion_indent}else:\n")
|
760
808
|
|
761
809
|
conclusion_call = rule.write_conclusion_as_source_code(conclusion_indent, defs_file)
|
762
|
-
|
810
|
+
with open(filename, "a") as file:
|
811
|
+
file.write(conclusion_call)
|
763
812
|
|
764
813
|
if rule.alternative:
|
765
|
-
self.write_rules_as_source_code_to_file(rule.alternative,
|
814
|
+
self.write_rules_as_source_code_to_file(rule.alternative, filename, parent_indent, defs_file=defs_file,
|
815
|
+
cases_file=cases_file)
|
766
816
|
|
767
817
|
@property
|
768
818
|
def conclusion_type_hint(self) -> str:
|
@@ -788,8 +838,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
788
838
|
"""
|
789
839
|
if not self.start_rule:
|
790
840
|
conditions = expert.ask_for_conditions(case_query)
|
791
|
-
self.start_rule = MultiClassTopRule(
|
792
|
-
conclusion_name=case_query.attribute_name)
|
841
|
+
self.start_rule = MultiClassTopRule.from_case_query(case_query)
|
793
842
|
|
794
843
|
@property
|
795
844
|
def last_top_rule(self) -> Optional[MultiClassTopRule]:
|
@@ -810,7 +859,7 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
810
859
|
if is_conflicting(rule_conclusion, case_query.target_value):
|
811
860
|
self.stop_conclusion(case_query, expert, evaluated_rule)
|
812
861
|
else:
|
813
|
-
self.add_conclusion(
|
862
|
+
self.add_conclusion(rule_conclusion)
|
814
863
|
|
815
864
|
def stop_conclusion(self, case_query: CaseQuery,
|
816
865
|
expert: Expert, evaluated_rule: MultiClassTopRule):
|
@@ -822,12 +871,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
822
871
|
:param evaluated_rule: The evaluated rule to ask the expert about.
|
823
872
|
"""
|
824
873
|
conditions = expert.ask_for_conditions(case_query, evaluated_rule)
|
825
|
-
evaluated_rule.fit_rule(case_query
|
874
|
+
evaluated_rule.fit_rule(case_query)
|
826
875
|
if self.mode == MCRDRMode.StopPlusRule:
|
827
876
|
self.stop_rule_conditions = conditions
|
828
877
|
if self.mode == MCRDRMode.StopPlusRuleCombined:
|
829
878
|
new_top_rule_conditions = conditions.combine_with(evaluated_rule.conditions)
|
830
|
-
|
879
|
+
case_query.conditions = new_top_rule_conditions
|
880
|
+
self.add_top_rule(case_query)
|
831
881
|
|
832
882
|
def add_rule_for_case(self, case_query: CaseQuery, expert: Expert):
|
833
883
|
"""
|
@@ -839,19 +889,19 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
839
889
|
if self.stop_rule_conditions and self.mode == MCRDRMode.StopPlusRule:
|
840
890
|
conditions = self.stop_rule_conditions
|
841
891
|
self.stop_rule_conditions = None
|
892
|
+
case_query.conditions = conditions
|
842
893
|
else:
|
843
894
|
conditions = expert.ask_for_conditions(case_query)
|
844
|
-
self.add_top_rule(
|
895
|
+
self.add_top_rule(case_query)
|
845
896
|
|
846
|
-
def add_conclusion(self,
|
897
|
+
def add_conclusion(self, rule_conclusion: List[Any]) -> None:
|
847
898
|
"""
|
848
899
|
Add the conclusion of the evaluated rule to the list of conclusions.
|
849
900
|
|
850
|
-
:param
|
851
|
-
|
901
|
+
:param rule_conclusion: The conclusion of the evaluated rule, which can be a single conclusion
|
902
|
+
or a set of conclusions.
|
852
903
|
"""
|
853
904
|
conclusion_types = [type(c) for c in self.conclusions]
|
854
|
-
rule_conclusion = evaluated_rule.conclusion(case)
|
855
905
|
if type(rule_conclusion) not in conclusion_types:
|
856
906
|
self.conclusions.extend(make_list(rule_conclusion))
|
857
907
|
else:
|
@@ -864,15 +914,13 @@ class MultiClassRDR(RDRWithCodeWriter):
|
|
864
914
|
self.conclusions.remove(c)
|
865
915
|
self.conclusions.extend(make_list(combined_conclusion))
|
866
916
|
|
867
|
-
def add_top_rule(self,
|
917
|
+
def add_top_rule(self, case_query: CaseQuery):
|
868
918
|
"""
|
869
919
|
Add a top rule to the classifier, which is a rule that is always checked and is part of the start_rules list.
|
870
920
|
|
871
|
-
:param
|
872
|
-
:param conclusion: The conclusion of the rule.
|
873
|
-
:param corner_case: The corner case of the rule.
|
921
|
+
:param case_query: The case query to add the top rule for.
|
874
922
|
"""
|
875
|
-
self.start_rule.alternative = MultiClassTopRule(
|
923
|
+
self.start_rule.alternative = MultiClassTopRule.from_case_query(case_query)
|
876
924
|
|
877
925
|
@staticmethod
|
878
926
|
def start_rule_type() -> Type[Rule]:
|
@@ -933,16 +981,19 @@ class GeneralRDR(RippleDownRules):
|
|
933
981
|
def start_rules(self) -> List[Union[SingleClassRule, MultiClassTopRule]]:
|
934
982
|
return [rdr.start_rule for rdr in self.start_rules_dict.values()]
|
935
983
|
|
936
|
-
def classify(self, case: Any, modify_case: bool = False
|
984
|
+
def classify(self, case: Any, modify_case: bool = False,
|
985
|
+
case_query: Optional[CaseQuery] = None) -> Optional[Dict[str, Any]]:
|
937
986
|
"""
|
938
987
|
Classify a case by going through all RDRs and adding the categories that are classified, and then restarting
|
939
988
|
the classification until no more categories can be added.
|
940
989
|
|
941
990
|
:param case: The case to classify.
|
942
991
|
:param modify_case: Whether to modify the original case or create a copy and modify it.
|
992
|
+
:param case_query: The case query containing the case and the target category to compare the case with.
|
943
993
|
:return: The categories that the case belongs to.
|
944
994
|
"""
|
945
|
-
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case
|
995
|
+
return general_rdr_classify(self.start_rules_dict, case, modify_original_case=modify_case,
|
996
|
+
case_query=case_query)
|
946
997
|
|
947
998
|
def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
|
948
999
|
-> Dict[str, Any]:
|
@@ -1029,7 +1080,7 @@ class GeneralRDR(RippleDownRules):
|
|
1029
1080
|
"""
|
1030
1081
|
for rdr in self.start_rules_dict.values():
|
1031
1082
|
rdr._write_to_python(model_dir)
|
1032
|
-
func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
|
1083
|
+
func_def = f"def classify(case: {self.case_type.__name__}, **kwargs) -> {self.conclusion_type_hint}:\n"
|
1033
1084
|
with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
|
1034
1085
|
f.write(self._get_imports() + "\n\n")
|
1035
1086
|
f.write("classifiers_dict = dict()\n")
|
@@ -1039,7 +1090,7 @@ class GeneralRDR(RippleDownRules):
|
|
1039
1090
|
f.write(func_def)
|
1040
1091
|
f.write(f"{' ' * 4}if not isinstance(case, Case):\n"
|
1041
1092
|
f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
|
1042
|
-
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case)\n")
|
1093
|
+
f.write(f"{' ' * 4}return general_rdr_classify(classifiers_dict, case, **kwargs)\n")
|
1043
1094
|
|
1044
1095
|
@property
|
1045
1096
|
def _default_generated_python_file_name(self) -> Optional[str]:
|
@@ -6,17 +6,18 @@ of the RDRs.
|
|
6
6
|
import os.path
|
7
7
|
from functools import wraps
|
8
8
|
|
9
|
-
from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
|
10
9
|
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
11
10
|
|
12
|
-
from ripple_down_rules.datastructures.case import
|
11
|
+
from ripple_down_rules.datastructures.case import Case
|
13
12
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
14
|
-
from ripple_down_rules.datastructures.enums import Category
|
15
13
|
from ripple_down_rules.experts import Expert, Human
|
16
|
-
from ripple_down_rules.rdr import GeneralRDR
|
17
|
-
|
14
|
+
from ripple_down_rules.rdr import GeneralRDR
|
15
|
+
try:
|
16
|
+
from ripple_down_rules.user_interface.gui import RDRCaseViewer
|
17
|
+
except ImportError:
|
18
|
+
RDRCaseViewer = None
|
18
19
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
19
|
-
get_method_class_if_exists,
|
20
|
+
get_method_class_if_exists, str_to_snake_case
|
20
21
|
|
21
22
|
|
22
23
|
class RDRDecorator:
|
ripple_down_rules/rules.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
import re
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
+
from pathlib import Path
|
6
7
|
from uuid import uuid4
|
7
8
|
|
8
9
|
from anytree import NodeMixin
|
@@ -11,8 +12,9 @@ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
|
11
12
|
|
12
13
|
from .datastructures.callable_expression import CallableExpression
|
13
14
|
from .datastructures.case import Case
|
15
|
+
from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
|
14
16
|
from .datastructures.enums import RDREdge, Stop
|
15
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
|
16
18
|
|
17
19
|
|
18
20
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
27
29
|
corner_case: Optional[Union[Case, SQLTable]] = None,
|
28
30
|
weight: Optional[str] = None,
|
29
31
|
conclusion_name: Optional[str] = None,
|
30
|
-
uid: Optional[str] = None
|
32
|
+
uid: Optional[str] = None,
|
33
|
+
corner_case_metadata: Optional[CaseFactoryMetaData] = None):
|
31
34
|
"""
|
32
35
|
A rule in the ripple down rules classifier.
|
33
36
|
|
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
38
41
|
:param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
|
39
42
|
:param conclusion_name: The name of the conclusion of the rule.
|
40
43
|
:param uid: The unique id of the rule.
|
44
|
+
:param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
|
45
|
+
scenario it is based on.
|
41
46
|
"""
|
42
47
|
super(Rule, self).__init__()
|
43
48
|
self.conclusion = conclusion
|
44
49
|
self.corner_case = corner_case
|
50
|
+
self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
|
45
51
|
self.parent = parent
|
46
52
|
self.weight: Optional[str] = weight
|
47
53
|
self.conditions = conditions if conditions else None
|
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
51
57
|
# generate a unique id for the rule using uuid4
|
52
58
|
self.uid: str = uid if uid else str(uuid4().int)
|
53
59
|
|
60
|
+
@classmethod
|
61
|
+
def from_case_query(cls, case_query: CaseQuery) -> Rule:
|
62
|
+
"""
|
63
|
+
Create a SingleClassRule from a CaseQuery.
|
64
|
+
|
65
|
+
:param case_query: The CaseQuery to create the rule from.
|
66
|
+
:return: A SingleClassRule instance.
|
67
|
+
"""
|
68
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
69
|
+
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
70
|
+
corner_case=case_query.case, parent=None,
|
71
|
+
corner_case_metadata=corner_case_metadata,
|
72
|
+
conclusion_name=case_query.attribute_name)
|
73
|
+
|
54
74
|
def _post_detach(self, parent):
|
55
75
|
"""
|
56
76
|
Called after this node is detached from the tree, useful when drawing the tree.
|
@@ -82,6 +102,26 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
82
102
|
"""
|
83
103
|
pass
|
84
104
|
|
105
|
+
def write_corner_case_as_source_code(self, cases_file: Path) -> None:
|
106
|
+
"""
|
107
|
+
Write the source code representation of the corner case of the rule to a file.
|
108
|
+
|
109
|
+
:param cases_file: The file to write the corner case to if it is a definition.
|
110
|
+
"""
|
111
|
+
if self.corner_case_metadata is None:
|
112
|
+
return
|
113
|
+
types_to_import = set()
|
114
|
+
if self.corner_case_metadata.factory_method is not None:
|
115
|
+
types_to_import.add(self.corner_case_metadata.factory_method)
|
116
|
+
if self.corner_case_metadata.scenario is not None:
|
117
|
+
types_to_import.add(self.corner_case_metadata.scenario)
|
118
|
+
if self.corner_case_metadata.case_conf is not None:
|
119
|
+
types_to_import.add(self.corner_case_metadata.case_conf)
|
120
|
+
imports = get_imports_from_types(list(types_to_import))
|
121
|
+
with open(cases_file, 'a') as f:
|
122
|
+
f.write("\n".join(imports) + "\n\n\n")
|
123
|
+
f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
|
124
|
+
|
85
125
|
def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
86
126
|
"""
|
87
127
|
Get the source code representation of the conclusion of the rule.
|
@@ -282,9 +322,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
282
322
|
returned_rule = self.alternative(x) if self.alternative else self
|
283
323
|
return returned_rule if returned_rule.fired else self
|
284
324
|
|
285
|
-
def fit_rule(self,
|
286
|
-
|
287
|
-
|
325
|
+
def fit_rule(self, case_query: CaseQuery):
|
326
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
327
|
+
new_rule = SingleClassRule(case_query.conditions, case_query.target,
|
328
|
+
corner_case=case_query.case, parent=self,
|
329
|
+
corner_case_metadata=corner_case_metadata,
|
330
|
+
)
|
288
331
|
if self.fired:
|
289
332
|
self.refinement = new_rule
|
290
333
|
else:
|
@@ -368,11 +411,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
368
411
|
elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
|
369
412
|
return self.alternative
|
370
413
|
|
371
|
-
def fit_rule(self,
|
372
|
-
if self.fired and target != self.conclusion:
|
373
|
-
self.refinement = MultiClassStopRule(conditions, corner_case=
|
414
|
+
def fit_rule(self, case_query: CaseQuery):
|
415
|
+
if self.fired and case_query.target != self.conclusion:
|
416
|
+
self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
|
374
417
|
elif not self.fired:
|
375
|
-
self.alternative = MultiClassTopRule(conditions, target,
|
418
|
+
self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
|
419
|
+
corner_case=case_query.case, parent=self)
|
376
420
|
|
377
421
|
def _to_json(self) -> Dict[str, Any]:
|
378
422
|
self.json_serialization = {**Rule._to_json(self),
|
ripple_down_rules/utils.py
CHANGED
@@ -14,11 +14,13 @@ from collections import UserDict, defaultdict
|
|
14
14
|
from copy import deepcopy, copy
|
15
15
|
from dataclasses import is_dataclass, fields
|
16
16
|
from enum import Enum
|
17
|
+
from os.path import dirname
|
17
18
|
from textwrap import dedent
|
18
19
|
from types import NoneType
|
19
20
|
|
20
21
|
from sqlalchemy.exc import NoInspectionAvailable
|
21
22
|
|
23
|
+
|
22
24
|
try:
|
23
25
|
import matplotlib
|
24
26
|
from matplotlib import pyplot as plt
|
@@ -157,7 +159,7 @@ def extract_function_source(file_path: str,
|
|
157
159
|
return_line_numbers: bool = False,
|
158
160
|
include_signature: bool = True) \
|
159
161
|
-> Union[Dict[str, Union[str, List[str]]],
|
160
|
-
Tuple[Dict[str, Union[str, List[str]]],
|
162
|
+
Tuple[Dict[str, Union[str, List[str]]], Dict[str, Tuple[int, int]]]]:
|
161
163
|
"""
|
162
164
|
Extract the source code of a function from a file.
|
163
165
|
|
@@ -176,7 +178,7 @@ def extract_function_source(file_path: str,
|
|
176
178
|
tree = ast.parse(source)
|
177
179
|
function_names = make_list(function_names)
|
178
180
|
functions_source: Dict[str, Union[str, List[str]]] = {}
|
179
|
-
line_numbers
|
181
|
+
line_numbers: Dict[str, Tuple[int, int]] = {}
|
180
182
|
for node in tree.body:
|
181
183
|
if isinstance(node, ast.FunctionDef) and (node.name in function_names or len(function_names) == 0):
|
182
184
|
# Get the line numbers of the function
|
@@ -184,7 +186,7 @@ def extract_function_source(file_path: str,
|
|
184
186
|
func_lines = lines[node.lineno - 1:node.end_lineno]
|
185
187
|
if not include_signature:
|
186
188
|
func_lines = func_lines[1:]
|
187
|
-
line_numbers.
|
189
|
+
line_numbers[node.name] = (node.lineno, node.end_lineno)
|
188
190
|
functions_source[node.name] = dedent("\n".join(func_lines)) if join_lines else func_lines
|
189
191
|
if (len(functions_source) >= len(function_names)) and (not len(function_names) == 0):
|
190
192
|
break
|
@@ -773,6 +775,64 @@ def get_types_to_import_from_type_hints(hints: List[Type]) -> Set[Type]:
|
|
773
775
|
return to_import
|
774
776
|
|
775
777
|
|
778
|
+
def get_import_path_from_path(path: str) -> Optional[str]:
|
779
|
+
"""
|
780
|
+
Convert a file system path to a Python import path.
|
781
|
+
|
782
|
+
:param path: The file system path to convert.
|
783
|
+
:return: The Python import path.
|
784
|
+
"""
|
785
|
+
package_name = os.path.abspath(path)
|
786
|
+
formated_package_name = package_name.strip('./').replace('/', '.')
|
787
|
+
parent_package_idx = 0
|
788
|
+
packages = formated_package_name.split('.')
|
789
|
+
for i, possible_pacakge in enumerate(reversed(packages)):
|
790
|
+
if i == 0:
|
791
|
+
current_path = package_name
|
792
|
+
else:
|
793
|
+
current_path = '/' + '/'.join(packages[:-i])
|
794
|
+
if os.path.exists(os.path.join(current_path, '__init__.py')):
|
795
|
+
parent_package_idx -= 1
|
796
|
+
else:
|
797
|
+
break
|
798
|
+
package_name = '.'.join(packages[parent_package_idx:]) if parent_package_idx < 0 else None
|
799
|
+
return package_name
|
800
|
+
|
801
|
+
|
802
|
+
def get_function_import_data(func: Callable) -> Tuple[str, str]:
|
803
|
+
"""
|
804
|
+
Get the import path of a function.
|
805
|
+
|
806
|
+
:param func: The function to get the import path for.
|
807
|
+
:return: The import path of the function.
|
808
|
+
"""
|
809
|
+
func_name = get_method_name(func)
|
810
|
+
func_class_name = get_method_class_name_if_exists(func)
|
811
|
+
func_file_path = get_method_file_name(func)
|
812
|
+
func_file_name = func_file_path.split('/')[-1].split('.')[0] # Get the file name without extension
|
813
|
+
func_import_path = get_import_path_from_path(dirname(func_file_path))
|
814
|
+
func_import_path = f"{func_import_path}.{func_file_name}" if func_import_path else func_file_name
|
815
|
+
if func_class_name and func_class_name != func_name:
|
816
|
+
func_import_name = func_class_name
|
817
|
+
else:
|
818
|
+
func_import_name = func_name
|
819
|
+
return func_import_path, func_import_name
|
820
|
+
|
821
|
+
|
822
|
+
def get_function_representation(func: Callable) -> str:
|
823
|
+
"""
|
824
|
+
Get a string representation of a function, including its module and class if applicable.
|
825
|
+
|
826
|
+
:param func: The function to represent.
|
827
|
+
:return: A string representation of the function.
|
828
|
+
"""
|
829
|
+
func_name = get_method_name(func)
|
830
|
+
func_class_name = get_method_class_name_if_exists(func)
|
831
|
+
if func_class_name and func_class_name != func_name:
|
832
|
+
return f"{func_class_name}.{func_name}"
|
833
|
+
return func_name
|
834
|
+
|
835
|
+
|
776
836
|
def get_imports_from_types(type_objs: List[Type]) -> List[str]:
|
777
837
|
"""
|
778
838
|
Format import lines from type objects.
|
@@ -781,11 +841,14 @@ def get_imports_from_types(type_objs: List[Type]) -> List[str]:
|
|
781
841
|
"""
|
782
842
|
|
783
843
|
module_to_types = defaultdict(list)
|
844
|
+
other_imports = []
|
784
845
|
for tp in type_objs:
|
785
846
|
try:
|
786
847
|
if isinstance(tp, type) or is_typing_type(tp):
|
787
848
|
module = tp.__module__
|
788
849
|
name = tp.__qualname__
|
850
|
+
elif callable(tp):
|
851
|
+
module, name = get_function_import_data(tp)
|
789
852
|
elif hasattr(type(tp), "__module__"):
|
790
853
|
module = type(tp).__module__
|
791
854
|
name = type(tp).__qualname__
|
@@ -801,6 +864,8 @@ def get_imports_from_types(type_objs: List[Type]) -> List[str]:
|
|
801
864
|
for module, names in module_to_types.items():
|
802
865
|
joined = ", ".join(sorted(set(names)))
|
803
866
|
lines.append(f"from {module} import {joined}")
|
867
|
+
if other_imports:
|
868
|
+
lines.extend(other_imports)
|
804
869
|
return sorted(lines)
|
805
870
|
|
806
871
|
|
@@ -838,7 +903,9 @@ def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
|
|
838
903
|
:return: The class name of the method.
|
839
904
|
"""
|
840
905
|
if hasattr(method, "__self__"):
|
841
|
-
if hasattr(method.__self__, "
|
906
|
+
if hasattr(method.__self__, "__name__"):
|
907
|
+
return method.__self__.__name__
|
908
|
+
elif hasattr(method.__self__, "__class__"):
|
842
909
|
return method.__self__.__class__.__name__
|
843
910
|
return method.__qualname__.split('.')[0] if hasattr(method, "__qualname__") else None
|
844
911
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.75
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -694,6 +694,7 @@ Requires-Dist: pygments
|
|
694
694
|
Requires-Dist: sqlalchemy
|
695
695
|
Requires-Dist: pandas
|
696
696
|
Requires-Dist: pyparsing>=3.2.3
|
697
|
+
Requires-Dist: omegaconf
|
697
698
|
Provides-Extra: viz
|
698
699
|
Requires-Dist: networkx>=3.1; extra == "viz"
|
699
700
|
Requires-Dist: matplotlib>=3.7.5; extra == "viz"
|
@@ -1,15 +1,15 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=
|
1
|
+
ripple_down_rules/__init__.py,sha256=6Ze00N3Py1dmFEMGBz3jz63qsUVLz8WzFn7qx3lJnfM,100
|
2
2
|
ripple_down_rules/experts.py,sha256=bwozulI1rv0uyaMZQqEgapDO-s8wvW0D6Jqxmvu5fik,12610
|
3
|
-
ripple_down_rules/helpers.py,sha256=
|
4
|
-
ripple_down_rules/rdr.py,sha256=
|
5
|
-
ripple_down_rules/rdr_decorators.py,sha256=
|
6
|
-
ripple_down_rules/rules.py,sha256=
|
3
|
+
ripple_down_rules/helpers.py,sha256=v4oE7C5PfQUVJfSUs1FfLHEwrJXEHJLn4vJhJMvyCR8,4453
|
4
|
+
ripple_down_rules/rdr.py,sha256=Mqh7lDjQu6wZUcJiJ57CZ3P0-hM4WfhFuV4s1jZnRv8,51833
|
5
|
+
ripple_down_rules/rdr_decorators.py,sha256=0sk7izDB53lTKSB9fm33vQahmY_05FyCOWljyQOMB0U,9072
|
6
|
+
ripple_down_rules/rules.py,sha256=ctf9yREG5l99HPFcYosjppKXTOwplZmzQbm4R1DMVaA,20107
|
7
7
|
ripple_down_rules/start-code-server.sh,sha256=otClk7VmDgBOX2TS_cjws6K0UwvgAUJhoA0ugkPCLqQ,949
|
8
|
-
ripple_down_rules/utils.py,sha256=
|
8
|
+
ripple_down_rules/utils.py,sha256=iwfpTlsxUqLHWpYqSKwrDnEEa_FYFHYb2LugEVDH_kk,57132
|
9
9
|
ripple_down_rules/datastructures/__init__.py,sha256=V2aNgf5C96Y5-IGghra3n9uiefpoIm_QdT7cc_C8cxQ,111
|
10
10
|
ripple_down_rules/datastructures/callable_expression.py,sha256=f3wUPTrLa1INO-1qfgVz87ryrCABronfyq0_JKWoZCs,12800
|
11
11
|
ripple_down_rules/datastructures/case.py,sha256=1zSaXUljaH6z3SgMGzYPoDyjotNam791KpYgvxuMh90,15463
|
12
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=
|
12
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=qoTFHV8Hi-X8VtfC9VdvH4tif73YjF3dUe8dyHXTYts,10993
|
13
13
|
ripple_down_rules/datastructures/enums.py,sha256=ce7tqS0otfSTNAOwsnXlhsvIn4iW_Y_N3TNebF3YoZs,5700
|
14
14
|
ripple_down_rules/user_interface/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
ripple_down_rules/user_interface/gui.py,sha256=_lgZAUXxxaBUFQJAHjA5TBPp6XEvJ62t-kSN8sPsocE,27379
|
@@ -17,8 +17,8 @@ ripple_down_rules/user_interface/ipython_custom_shell.py,sha256=Jrf7NxOdlrwGXH0X
|
|
17
17
|
ripple_down_rules/user_interface/object_diagram.py,sha256=FEa2HaYR9QmTE6NsOwBvZ0jqmu3DKyg6mig2VE5ZP4Y,4956
|
18
18
|
ripple_down_rules/user_interface/prompt.py,sha256=AkkltdDIaioN43lkRKDPKSjJcmdSSGZDMYz7AL7X9lE,8082
|
19
19
|
ripple_down_rules/user_interface/template_file_creator.py,sha256=VLS9Nxg6gPNa-YYliJ_VNsTvLPlZ003EVkJ2t8zuDgE,13563
|
20
|
-
ripple_down_rules-0.5.
|
21
|
-
ripple_down_rules-0.5.
|
22
|
-
ripple_down_rules-0.5.
|
23
|
-
ripple_down_rules-0.5.
|
24
|
-
ripple_down_rules-0.5.
|
20
|
+
ripple_down_rules-0.5.75.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
21
|
+
ripple_down_rules-0.5.75.dist-info/METADATA,sha256=M1N-k7Zp8qzOsGlF8K4n889agx_bLu_rWO9_c-cEViQ,48214
|
22
|
+
ripple_down_rules-0.5.75.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
ripple_down_rules-0.5.75.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
24
|
+
ripple_down_rules-0.5.75.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|