ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +54 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +218 -200
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +7 -2
- ripple_down_rules/utils.py +167 -3
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.1.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.15.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,55 @@
|
|
1
|
+
"""
|
2
|
+
This file contains decorators for the RDR (Ripple Down Rules) framework. Where each type of RDR has a decorator
|
3
|
+
that can be used with any python function such that this function can benefit from the incremental knowledge acquisition
|
4
|
+
of the RDRs.
|
5
|
+
"""
|
6
|
+
import os.path
|
7
|
+
from functools import wraps
|
8
|
+
from typing import Callable, Optional, Type
|
9
|
+
|
10
|
+
from sqlalchemy.orm import Session
|
11
|
+
from typing_extensions import Any
|
12
|
+
|
13
|
+
from ripple_down_rules.datastructures import Case, Category, create_case, CaseQuery
|
14
|
+
from ripple_down_rules.experts import Expert, Human
|
15
|
+
|
16
|
+
from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
|
17
|
+
from ripple_down_rules.utils import get_method_args_as_dict, get_method_name, get_method_class_name_if_exists, \
|
18
|
+
get_method_file_name, get_func_rdr_model_path
|
19
|
+
|
20
|
+
|
21
|
+
def single_class_rdr(
|
22
|
+
model_dir: str,
|
23
|
+
fit: bool = True,
|
24
|
+
expert: Optional[Expert] = None,
|
25
|
+
session: Optional[Session] = None,
|
26
|
+
) -> Callable:
|
27
|
+
"""
|
28
|
+
Decorator to use a SingleClassRDR as a classifier.
|
29
|
+
"""
|
30
|
+
expert = expert if expert else Human(session=session)
|
31
|
+
|
32
|
+
def decorator(func: Callable) -> Callable:
|
33
|
+
scrdr_model_path = get_func_rdr_model_path(func, model_dir)
|
34
|
+
if os.path.exists(scrdr_model_path):
|
35
|
+
scrdr = SingleClassRDR.load(scrdr_model_path)
|
36
|
+
scrdr.session = session
|
37
|
+
else:
|
38
|
+
scrdr = SingleClassRDR(session=session)
|
39
|
+
|
40
|
+
@wraps(func)
|
41
|
+
def wrapper(*args, **kwargs) -> Category:
|
42
|
+
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
43
|
+
func_output = func(*args, **kwargs)
|
44
|
+
if func_output is not None:
|
45
|
+
case_dict.update({"_output": func_output})
|
46
|
+
case = create_case(case_dict, recursion_idx=3)
|
47
|
+
if fit:
|
48
|
+
output = scrdr.fit_case(CaseQuery(case), expert=expert)
|
49
|
+
scrdr.save(scrdr_model_path)
|
50
|
+
return output
|
51
|
+
else:
|
52
|
+
return scrdr.classify(case)
|
53
|
+
return wrapper
|
54
|
+
|
55
|
+
return decorator
|
ripple_down_rules/rules.py
CHANGED
@@ -21,7 +21,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
21
21
|
conclusion: Optional[CallableExpression] = None,
|
22
22
|
parent: Optional[Rule] = None,
|
23
23
|
corner_case: Optional[Union[Case, SQLTable]] = None,
|
24
|
-
weight: Optional[str] = None
|
24
|
+
weight: Optional[str] = None,
|
25
|
+
conclusion_name: Optional[str] = None):
|
25
26
|
"""
|
26
27
|
A rule in the ripple down rules classifier.
|
27
28
|
|
@@ -30,6 +31,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
30
31
|
:param parent: The parent rule of this rule.
|
31
32
|
:param corner_case: The corner case that this rule is based on/created from.
|
32
33
|
:param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
|
34
|
+
:param conclusion_name: The name of the conclusion of the rule.
|
33
35
|
"""
|
34
36
|
super(Rule, self).__init__()
|
35
37
|
self.conclusion = conclusion
|
@@ -37,6 +39,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
37
39
|
self.parent = parent
|
38
40
|
self.weight: Optional[str] = weight
|
39
41
|
self.conditions = conditions if conditions else None
|
42
|
+
self.conclusion_name: Optional[str] = conclusion_name
|
40
43
|
self.json_serialization: Optional[Dict[str, Any]] = None
|
41
44
|
|
42
45
|
def _post_detach(self, parent):
|
@@ -107,8 +110,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
107
110
|
conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
|
108
111
|
for c in conclusion:
|
109
112
|
conclusions['value'].append(conclusion_to_json(c))
|
110
|
-
|
113
|
+
elif hasattr(conclusion, 'to_json'):
|
111
114
|
conclusions = conclusion.to_json()
|
115
|
+
else:
|
116
|
+
conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': conclusion}
|
112
117
|
return conclusions
|
113
118
|
|
114
119
|
json_serialization = {"conditions": self.conditions.to_json(),
|
ripple_down_rules/utils.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import ast
|
3
4
|
import importlib
|
4
5
|
import json
|
5
6
|
import logging
|
7
|
+
import os
|
6
8
|
from abc import abstractmethod
|
7
9
|
from collections import UserDict
|
8
10
|
from copy import deepcopy
|
9
|
-
from dataclasses import dataclass
|
11
|
+
from dataclasses import dataclass, is_dataclass, fields
|
10
12
|
|
11
13
|
import matplotlib
|
12
14
|
import networkx as nx
|
@@ -25,6 +27,162 @@ if TYPE_CHECKING:
|
|
25
27
|
matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
|
26
28
|
|
27
29
|
|
30
|
+
def serialize_dataclass(obj: Any) -> Union[Dict, Any]:
|
31
|
+
"""
|
32
|
+
Recursively serialize a dataclass to a dictionary. If the dataclass contains any nested dataclasses, they will be
|
33
|
+
serialized as well. If the object is not a dataclass, it will be returned as is.
|
34
|
+
|
35
|
+
:param obj: The dataclass to serialize.
|
36
|
+
:return: The serialized dataclass as a dictionary or the object itself if it is not a dataclass.
|
37
|
+
"""
|
38
|
+
|
39
|
+
def recursive_convert(obj):
|
40
|
+
if is_dataclass(obj):
|
41
|
+
return {
|
42
|
+
"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
|
43
|
+
"fields": {f.name: recursive_convert(getattr(obj, f.name)) for f in fields(obj)}
|
44
|
+
}
|
45
|
+
elif isinstance(obj, list):
|
46
|
+
return [recursive_convert(item) for item in obj]
|
47
|
+
elif isinstance(obj, dict):
|
48
|
+
return {k: recursive_convert(v) for k, v in obj.items()}
|
49
|
+
else:
|
50
|
+
return obj
|
51
|
+
|
52
|
+
return recursive_convert(obj)
|
53
|
+
|
54
|
+
|
55
|
+
def deserialize_dataclass(data: dict) -> Any:
|
56
|
+
"""
|
57
|
+
Recursively deserialize a dataclass from a dictionary, if the dictionary contains a key "__dataclass__" (Most likely
|
58
|
+
created by the serialize_dataclass function), it will be treated as a dataclass and deserialized accordingly,
|
59
|
+
otherwise it will be returned as is.
|
60
|
+
|
61
|
+
:param data: The dictionary to deserialize.
|
62
|
+
:return: The deserialized dataclass.
|
63
|
+
"""
|
64
|
+
def recursive_load(obj):
|
65
|
+
if isinstance(obj, dict) and "__dataclass__" in obj:
|
66
|
+
module_name, class_name = obj["__dataclass__"].rsplit(".", 1)
|
67
|
+
module = importlib.import_module(module_name)
|
68
|
+
cls: Type = getattr(module, class_name)
|
69
|
+
field_values = {
|
70
|
+
k: recursive_load(v)
|
71
|
+
for k, v in obj["fields"].items()
|
72
|
+
}
|
73
|
+
return cls(**field_values)
|
74
|
+
elif isinstance(obj, list):
|
75
|
+
return [recursive_load(item) for item in obj]
|
76
|
+
elif isinstance(obj, dict):
|
77
|
+
return {k: recursive_load(v) for k, v in obj.items()}
|
78
|
+
else:
|
79
|
+
return obj
|
80
|
+
|
81
|
+
return recursive_load(data)
|
82
|
+
|
83
|
+
|
84
|
+
def typing_to_python_type(typing_hint: Type) -> Type:
|
85
|
+
"""
|
86
|
+
Convert a typing hint to a python type.
|
87
|
+
|
88
|
+
:param typing_hint: The typing hint to convert.
|
89
|
+
:return: The python type.
|
90
|
+
"""
|
91
|
+
if typing_hint in [list, List]:
|
92
|
+
return list
|
93
|
+
elif typing_hint in [tuple, Tuple]:
|
94
|
+
return tuple
|
95
|
+
elif typing_hint in [set, Set]:
|
96
|
+
return set
|
97
|
+
elif typing_hint in [dict, Dict]:
|
98
|
+
return dict
|
99
|
+
else:
|
100
|
+
return typing_hint
|
101
|
+
|
102
|
+
|
103
|
+
def capture_variable_assignment(code: str, variable_name: str) -> Optional[str]:
|
104
|
+
"""
|
105
|
+
Capture the assignment of a variable in the given code.
|
106
|
+
|
107
|
+
:param code: The code to analyze.
|
108
|
+
:param variable_name: The name of the variable to capture.
|
109
|
+
:return: The assignment statement or None if not found.
|
110
|
+
"""
|
111
|
+
tree = ast.parse(code)
|
112
|
+
assignment = None
|
113
|
+
for node in ast.walk(tree):
|
114
|
+
if isinstance(node, ast.Assign):
|
115
|
+
for target in node.targets:
|
116
|
+
if isinstance(target, ast.Name) and target.id == variable_name:
|
117
|
+
# now extract the right side of the assignment
|
118
|
+
assignment = ast.get_source_segment(code, node.value)
|
119
|
+
break
|
120
|
+
if assignment is not None:
|
121
|
+
break
|
122
|
+
return assignment
|
123
|
+
|
124
|
+
|
125
|
+
def get_func_rdr_model_path(func: Callable, model_dir: str) -> str:
|
126
|
+
"""
|
127
|
+
:param func: The function to get the model path for.
|
128
|
+
:param model_dir: The directory to save the model to.
|
129
|
+
:return: The path to the model file.
|
130
|
+
"""
|
131
|
+
func_name = get_method_name(func)
|
132
|
+
func_class_name = get_method_class_name_if_exists(func)
|
133
|
+
func_file_name = get_method_file_name(func)
|
134
|
+
model_name = func_file_name
|
135
|
+
model_name += f"_{func_class_name}" if func_class_name else ""
|
136
|
+
model_name += f"_{func_name}"
|
137
|
+
return os.path.join(model_dir, f"{model_name}.json")
|
138
|
+
|
139
|
+
|
140
|
+
def get_method_args_as_dict(method: Callable, *args, **kwargs) -> Dict[str, Any]:
|
141
|
+
"""
|
142
|
+
Get the arguments of a method as a dictionary.
|
143
|
+
|
144
|
+
:param method: The method to get the arguments from.
|
145
|
+
:param args: The positional arguments.
|
146
|
+
:param kwargs: The keyword arguments.
|
147
|
+
:return: A dictionary of the arguments.
|
148
|
+
"""
|
149
|
+
func_arg_names = method.__code__.co_varnames
|
150
|
+
func_arg_values = args + tuple(kwargs.values())
|
151
|
+
return dict(zip(func_arg_names, func_arg_values))
|
152
|
+
|
153
|
+
|
154
|
+
def get_method_name(method: Callable) -> str:
|
155
|
+
"""
|
156
|
+
Get the name of a method.
|
157
|
+
|
158
|
+
:param method: The method to get the name of.
|
159
|
+
:return: The name of the method.
|
160
|
+
"""
|
161
|
+
return method.__name__ if hasattr(method, "__name__") else str(method)
|
162
|
+
|
163
|
+
|
164
|
+
def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
|
165
|
+
"""
|
166
|
+
Get the class name of a method if it has one.
|
167
|
+
|
168
|
+
:param method: The method to get the class name of.
|
169
|
+
:return: The class name of the method.
|
170
|
+
"""
|
171
|
+
if hasattr(method, "__self__") and hasattr(method.__self__, "__class__"):
|
172
|
+
return method.__self__.__class__.__name__
|
173
|
+
return None
|
174
|
+
|
175
|
+
|
176
|
+
def get_method_file_name(method: Callable) -> str:
|
177
|
+
"""
|
178
|
+
Get the file name of a method.
|
179
|
+
|
180
|
+
:param method: The method to get the file name of.
|
181
|
+
:return: The file name of the method.
|
182
|
+
"""
|
183
|
+
return method.__code__.co_filename
|
184
|
+
|
185
|
+
|
28
186
|
def flatten_list(a: List):
|
29
187
|
a_flattened = []
|
30
188
|
for c in a:
|
@@ -86,7 +244,7 @@ def recursive_subclasses(cls):
|
|
86
244
|
|
87
245
|
class SubclassJSONSerializer:
|
88
246
|
"""
|
89
|
-
|
247
|
+
Originally from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
|
90
248
|
Class for automatic (de)serialization of subclasses.
|
91
249
|
Classes that inherit from this class can be serialized and deserialized automatically by calling this classes
|
92
250
|
'from_json' method.
|
@@ -150,8 +308,14 @@ class SubclassJSONSerializer:
|
|
150
308
|
return None
|
151
309
|
if not isinstance(data, dict) or ('_type' not in data):
|
152
310
|
return data
|
311
|
+
if '__dataclass__' in data:
|
312
|
+
# if the data is a dataclass, deserialize it
|
313
|
+
return deserialize_dataclass(data)
|
314
|
+
|
153
315
|
# check if type module is builtins
|
154
316
|
data_type = get_type_from_string(data["_type"])
|
317
|
+
if len(data) == 1:
|
318
|
+
return data_type
|
155
319
|
if data_type.__module__ == 'builtins':
|
156
320
|
if is_iterable(data['value']) and not isinstance(data['value'], dict):
|
157
321
|
return data_type([cls.from_json(d) for d in data['value']])
|
@@ -279,7 +443,7 @@ def table_rows_as_str(row_dict: Dict[str, Any], columns_per_row: int = 9):
|
|
279
443
|
values = [list(map(lambda i: i[1], row)) for row in all_items]
|
280
444
|
all_table_rows = []
|
281
445
|
for row_keys, row_values in zip(keys, values):
|
282
|
-
table = tabulate([row_values], headers=row_keys, tablefmt='plain')
|
446
|
+
table = tabulate([row_values], headers=row_keys, tablefmt='plain', maxcolwidths=[20] * len(row_keys))
|
283
447
|
all_table_rows.append(table)
|
284
448
|
return "\n".join(all_table_rows)
|
285
449
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ripple_down_rules
|
3
|
-
Version: 0.0.15
|
3
|
+
Version: 0.1.1
|
4
4
|
Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
|
5
5
|
Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
@@ -0,0 +1,20 @@
|
|
1
|
+
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
ripple_down_rules/datasets.py,sha256=AzPtqUXuR1qLQNtRsWLsJ3gX2oIf8nIkFvmsmz7fHlw,4601
|
3
|
+
ripple_down_rules/experts.py,sha256=Xz1U1Tdq7jrFlcVuSusaMB241AG9TEs7q101i59Xijs,10683
|
4
|
+
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
+
ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
|
6
|
+
ripple_down_rules/prompt.py,sha256=z6KddZOsNiStptgCRNh2OVHHuH6Ooa2f-nsrgJH1qJ8,6311
|
7
|
+
ripple_down_rules/rdr.py,sha256=vuRGyqB2wasAzSUhj9HenBp0nC-C1zCGUj3e_th2X7A,43511
|
8
|
+
ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
|
9
|
+
ripple_down_rules/rules.py,sha256=aM3Im4ePuFDlkuD2EKRtiVmYgoQ_sxlwcbzrDKqXAfs,14578
|
10
|
+
ripple_down_rules/utils.py,sha256=9gPnRWlLye7FettI2QRWJx8oU9z3ckwdO5jopXK8b-8,24290
|
11
|
+
ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
|
12
|
+
ripple_down_rules/datastructures/callable_expression.py,sha256=ac2TaMr0hiRX928GMcr3oTQic8KXXO4syLw4KV-Iehs,10515
|
13
|
+
ripple_down_rules/datastructures/case.py,sha256=3Pl07jmYn94wdCVTaRZDmBPgyAsN1TjebvrE6-68MVU,13606
|
14
|
+
ripple_down_rules/datastructures/dataclasses.py,sha256=AI-wqNy8y9QPg6lov0P-c5b8JXemuM4X62tIRhW-Gqs,4231
|
15
|
+
ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
|
16
|
+
ripple_down_rules-0.1.1.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
17
|
+
ripple_down_rules-0.1.1.dist-info/METADATA,sha256=rzAQLYJ7yFFBy9Nthw1qwhhwbOZtv1aTNnNSVyf9BQE,42518
|
18
|
+
ripple_down_rules-0.1.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
19
|
+
ripple_down_rules-0.1.1.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
20
|
+
ripple_down_rules-0.1.1.dist-info/RECORD,,
|
@@ -1,18 +0,0 @@
|
|
1
|
-
ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
ripple_down_rules/datasets.py,sha256=qYX7IF7ACm0VRbaKfEgQ32j0YbUUyt2GfGU5Lo42CqI,4601
|
3
|
-
ripple_down_rules/experts.py,sha256=wg1uY0ox9dUMR4s1RdGjzpX1_WUqnCa060r1U9lrKYI,11214
|
4
|
-
ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
|
5
|
-
ripple_down_rules/prompt.py,sha256=QAmxg4ssrGUAlK7lbyKw2nuRczTZColpjc9uMC1ts3I,4210
|
6
|
-
ripple_down_rules/rdr.py,sha256=am57_cmN2ge_eKGQLIAyd5iH4iSC15QrFHmdmwwM_uY,41947
|
7
|
-
ripple_down_rules/rules.py,sha256=MUZv42WPOB73rzw6Um2F_q0woRFG4db4yYB6R-qMpKA,14239
|
8
|
-
ripple_down_rules/utils.py,sha256=32gU9NHIcxQpfS4zPSAtn63np5SdVvS9VuyoYiyKZbc,18664
|
9
|
-
ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
|
10
|
-
ripple_down_rules/datastructures/callable_expression.py,sha256=TN6bi4VYjyLlSLTEA3dRo5ENfEdQYc8Fjj5nbnsz-C0,9058
|
11
|
-
ripple_down_rules/datastructures/case.py,sha256=0EV69VBzsBTHQ3Ots7fkP09g1tyJ20xW9kSm0nZGy40,14071
|
12
|
-
ripple_down_rules/datastructures/dataclasses.py,sha256=EVQ1jBKW7K7q7_JNgikHX9fm3EmQQKA74sNjEQ4rXn8,2449
|
13
|
-
ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
|
14
|
-
ripple_down_rules-0.0.15.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
|
15
|
-
ripple_down_rules-0.0.15.dist-info/METADATA,sha256=JOCvhnY9h7adkNGnR9uX0j_d7oQcIIBTY8kCpZ2VsD4,42519
|
16
|
-
ripple_down_rules-0.0.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
17
|
-
ripple_down_rules-0.0.15.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
|
18
|
-
ripple_down_rules-0.0.15.dist-info/RECORD,,
|
File without changes
|
File without changes
|