ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,55 @@
1
+ """
2
+ This file contains decorators for the RDR (Ripple Down Rules) framework. Where each type of RDR has a decorator
3
+ that can be used with any python function such that this function can benefit from the incremental knowledge acquisition
4
+ of the RDRs.
5
+ """
6
+ import os.path
7
+ from functools import wraps
8
+ from typing import Callable, Optional, Type
9
+
10
+ from sqlalchemy.orm import Session
11
+ from typing_extensions import Any
12
+
13
+ from ripple_down_rules.datastructures import Case, Category, create_case, CaseQuery
14
+ from ripple_down_rules.experts import Expert, Human
15
+
16
+ from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
17
+ from ripple_down_rules.utils import get_method_args_as_dict, get_method_name, get_method_class_name_if_exists, \
18
+ get_method_file_name, get_func_rdr_model_path
19
+
20
+
21
def single_class_rdr(
        model_dir: str,
        fit: bool = True,
        expert: Optional[Expert] = None,
        session: Optional[Session] = None,
) -> Callable:
    """
    Build a decorator that routes every call of a function through a SingleClassRDR classifier.

    The decorated function itself is still executed. Its call arguments, plus its return value under the key
    ``"_output"`` when that value is not None, are turned into a case. With ``fit=True`` the case is fitted with the
    help of the expert and the model is saved after the call; with ``fit=False`` the case is only classified.

    :param model_dir: Directory that holds the RDR model file of the decorated function.
    :param fit: Whether to fit the RDR on each call (True) or only classify (False).
    :param expert: Expert used while fitting; when falsy, a ``Human`` expert bound to ``session`` is created.
    :param session: Optional SQLAlchemy session handed to the RDR and to the default expert.
    :return: The decorator.
    """
    chosen_expert = expert or Human(session=session)

    def decorator(func: Callable) -> Callable:
        model_path = get_func_rdr_model_path(func, model_dir)
        if not os.path.exists(model_path):
            rdr = SingleClassRDR(session=session)
        else:
            # Reuse the previously saved model, but attach the session given to the decorator.
            rdr = SingleClassRDR.load(model_path)
            rdr.session = session

        @wraps(func)
        def wrapper(*args, **kwargs) -> Category:
            case_attributes = get_method_args_as_dict(func, *args, **kwargs)
            result = func(*args, **kwargs)
            if result is not None:
                case_attributes["_output"] = result
            case = create_case(case_attributes, recursion_idx=3)
            if not fit:
                return rdr.classify(case)
            conclusion = rdr.fit_case(CaseQuery(case), expert=chosen_expert)
            rdr.save(model_path)
            return conclusion

        return wrapper

    return decorator
@@ -21,7 +21,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
21
21
  conclusion: Optional[CallableExpression] = None,
22
22
  parent: Optional[Rule] = None,
23
23
  corner_case: Optional[Union[Case, SQLTable]] = None,
24
- weight: Optional[str] = None):
24
+ weight: Optional[str] = None,
25
+ conclusion_name: Optional[str] = None):
25
26
  """
26
27
  A rule in the ripple down rules classifier.
27
28
 
@@ -30,6 +31,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
30
31
  :param parent: The parent rule of this rule.
31
32
  :param corner_case: The corner case that this rule is based on/created from.
32
33
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
34
+ :param conclusion_name: The name of the conclusion of the rule.
33
35
  """
34
36
  super(Rule, self).__init__()
35
37
  self.conclusion = conclusion
@@ -37,6 +39,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
37
39
  self.parent = parent
38
40
  self.weight: Optional[str] = weight
39
41
  self.conditions = conditions if conditions else None
42
+ self.conclusion_name: Optional[str] = conclusion_name
40
43
  self.json_serialization: Optional[Dict[str, Any]] = None
41
44
 
42
45
  def _post_detach(self, parent):
@@ -107,8 +110,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
107
110
  conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
108
111
  for c in conclusion:
109
112
  conclusions['value'].append(conclusion_to_json(c))
110
- else:
113
+ elif hasattr(conclusion, 'to_json'):
111
114
  conclusions = conclusion.to_json()
115
+ else:
116
+ conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': conclusion}
112
117
  return conclusions
113
118
 
114
119
  json_serialization = {"conditions": self.conditions.to_json(),
@@ -1,12 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import importlib
4
5
  import json
5
6
  import logging
7
+ import os
6
8
  from abc import abstractmethod
7
9
  from collections import UserDict
8
10
  from copy import deepcopy
9
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, is_dataclass, fields
10
12
 
11
13
  import matplotlib
12
14
  import networkx as nx
@@ -25,6 +27,162 @@ if TYPE_CHECKING:
25
27
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
26
28
 
27
29
 
30
def serialize_dataclass(obj: Any) -> Union[Dict, Any]:
    """
    Recursively serialize a dataclass instance to a dictionary.

    Nested dataclass instances, and those found inside lists and dictionaries, are serialized as well. Each
    serialized instance becomes ``{"__dataclass__": "<module>.<qualname>", "fields": {...}}``. Any other object,
    including a dataclass *class* (as opposed to an instance), is returned as is.

    :param obj: The object to serialize.
    :return: The serialized dataclass as a dictionary, or the object itself if it is not a dataclass instance.
    """

    def recursive_convert(value):
        # is_dataclass() is also True for dataclass classes; only instances carry field values to serialize.
        if is_dataclass(value) and not isinstance(value, type):
            return {
                "__dataclass__": f"{value.__class__.__module__}.{value.__class__.__qualname__}",
                "fields": {f.name: recursive_convert(getattr(value, f.name)) for f in fields(value)},
            }
        if isinstance(value, list):
            return [recursive_convert(item) for item in value]
        if isinstance(value, dict):
            return {k: recursive_convert(v) for k, v in value.items()}
        return value

    return recursive_convert(obj)
53
+
54
+
55
def deserialize_dataclass(data: dict) -> Any:
    """
    Recursively deserialize dataclasses from a structure produced by ``serialize_dataclass``.

    Any dictionary containing the key ``"__dataclass__"`` is rebuilt into an instance of the referenced class, using
    its ``"fields"`` entry. Lists and other dictionaries are traversed recursively; every other value is returned as
    is.

    :param data: The (possibly nested) structure to deserialize.
    :return: The deserialized structure, with every serialized dataclass replaced by an instance.
    """

    def resolve_type(path: str) -> Type:
        # The path is "<module>.<qualname>", and the qualname may itself contain dots (nested classes). Try every
        # split point, starting with the longest module prefix (the only one the old code tried).
        parts = path.split(".")
        for split_idx in range(len(parts) - 1, 0, -1):
            try:
                target = importlib.import_module(".".join(parts[:split_idx]))
            except ImportError:
                continue
            try:
                for attribute in parts[split_idx:]:
                    target = getattr(target, attribute)
            except AttributeError:
                continue
            return target
        # Nothing resolved: repeat the plain single split so the caller sees the same error as before.
        module_name, class_name = path.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), class_name)

    def build(cls: Type, field_values: Dict[str, Any]) -> Any:
        if not is_dataclass(cls):
            return cls(**field_values)
        init_flags = {f.name: f.init for f in fields(cls)}
        # Fields declared with init=False are serialized too, but the constructor does not accept them.
        # Unknown keys still go to the constructor, which raises TypeError as before.
        instance = cls(**{k: v for k, v in field_values.items() if init_flags.get(k, True)})
        for name, value in field_values.items():
            if not init_flags.get(name, True):
                # object.__setattr__ also works for frozen dataclasses.
                object.__setattr__(instance, name, value)
        return instance

    def recursive_load(obj):
        if isinstance(obj, dict) and "__dataclass__" in obj:
            cls = resolve_type(obj["__dataclass__"])
            field_values = {k: recursive_load(v) for k, v in obj["fields"].items()}
            return build(cls, field_values)
        if isinstance(obj, list):
            return [recursive_load(item) for item in obj]
        if isinstance(obj, dict):
            return {k: recursive_load(v) for k, v in obj.items()}
        return obj

    return recursive_load(data)
82
+
83
+
84
def typing_to_python_type(typing_hint: Type) -> Type:
    """
    Map an unparameterized container hint to its builtin counterpart.

    ``list``/``List``, ``tuple``/``Tuple``, ``set``/``Set`` and ``dict``/``Dict`` become the corresponding builtin
    type. Any other hint, including parameterized ones such as ``List[int]``, is returned unchanged.

    :param typing_hint: The typing hint to convert.
    :return: The builtin type, or the hint itself when no mapping applies.
    """
    builtin_equivalents = (
        (list, List),
        (tuple, Tuple),
        (set, Set),
        (dict, Dict),
    )
    for builtin_type, typing_alias in builtin_equivalents:
        if typing_hint in (builtin_type, typing_alias):
            return builtin_type
    return typing_hint
101
+
102
+
103
def capture_variable_assignment(code: str, variable_name: str) -> Optional[str]:
    """
    Capture the source text assigned to a variable in the given code.

    Both plain assignments (``x = ...``, including chained ``a = x = ...``) and annotated assignments that carry a
    value (``x: int = ...``) are recognised. The first match encountered by ``ast.walk`` is returned.

    :param code: The code to analyze.
    :param variable_name: The name of the variable to capture.
    :return: The source text of the right-hand side of the assignment, or None if not found.
    :raises SyntaxError: If ``code`` cannot be parsed.
    """
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # An annotated assignment has a single target; a bare annotation (``x: int``) assigns nothing.
            targets = [node.target]
        else:
            continue
        if any(isinstance(target, ast.Name) and target.id == variable_name for target in targets):
            # Extract the right-hand side exactly as written in the source.
            return ast.get_source_segment(code, node.value)
    return None
123
+
124
+
125
def get_func_rdr_model_path(func: Callable, model_dir: str) -> str:
    """
    Build the path of the JSON file that stores the RDR model attached to a function.

    The file is named ``<source-file-stem>[_<ClassName>]_<function_name>.json`` and is always placed inside
    ``model_dir``.

    :param func: The function to get the model path for.
    :param model_dir: The directory to save the model to.
    :return: The path to the model file.
    """
    func_name = get_method_name(func)
    func_class_name = get_method_class_name_if_exists(func)
    # get_method_file_name() returns the code object's co_filename, which is typically an absolute path. Joining an
    # absolute component onto model_dir makes os.path.join discard model_dir, so keep only the file's stem.
    func_file_name = os.path.splitext(os.path.basename(get_method_file_name(func)))[0]
    name_parts = [func_file_name]
    if func_class_name:
        name_parts.append(func_class_name)
    name_parts.append(func_name)
    return os.path.join(model_dir, "_".join(name_parts) + ".json")
138
+
139
+
140
def get_method_args_as_dict(method: Callable, *args, **kwargs) -> Dict[str, Any]:
    """
    Map the arguments of a call to ``method`` onto its parameter names.

    Only parameters actually supplied in the call are included; defaults are not filled in. A ``*args`` parameter
    maps to the tuple of extra positional values, and a ``**kwargs`` parameter maps to the dict of extra keyword
    values, as reported by ``inspect.Signature.bind``.

    :param method: The method to get the arguments from.
    :param args: The positional arguments.
    :param kwargs: The keyword arguments.
    :return: A dictionary of the arguments.
    :raises TypeError: If the arguments do not match the signature of ``method``.
    """
    import inspect

    # Binding against the signature, rather than zipping co_varnames with the values, has two benefits:
    # keyword arguments land on the right parameters whatever their order, and local variable names are never
    # picked up.
    bound_arguments = inspect.signature(method).bind(*args, **kwargs)
    return dict(bound_arguments.arguments)
152
+
153
+
154
def get_method_name(method: Callable) -> str:
    """
    Return a readable name for a callable.

    :param method: The callable whose name is wanted.
    :return: Its ``__name__`` when it has one, otherwise ``str(method)``.
    """
    try:
        return method.__name__
    except AttributeError:
        # Objects such as functools.partial or callable instances have no __name__.
        return str(method)
162
+
163
+
164
def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
    """
    Get the name of the class a bound method belongs to, if any.

    :param method: The method to get the class name of.
    :return: The class name, or None. This is the instance's class name for a method bound to an instance, and the
        class's own name for a classmethod bound to a class. Plain functions, static methods and builtins whose
        ``__self__`` is a module (or None) give None.
    """
    import inspect

    owner = getattr(method, "__self__", None)
    # Builtin functions such as len expose their module as __self__; that is not an owning class.
    if owner is None or inspect.ismodule(owner):
        return None
    # A classmethod is bound to the class itself; report that class rather than its metaclass ("type").
    if isinstance(owner, type):
        return owner.__name__
    return owner.__class__.__name__
174
+
175
+
176
def get_method_file_name(method: Callable) -> str:
    """
    Return the source file recorded for a function.

    This is the ``co_filename`` of the function's code object, i.e. the path as recorded by the interpreter rather
    than only the base name of the file.

    :param method: A Python function or bound method (it must expose ``__code__``).
    :return: The file path stored in the function's code object.
    """
    code_object = method.__code__
    file_path = code_object.co_filename
    return file_path
184
+
185
+
28
186
  def flatten_list(a: List):
29
187
  a_flattened = []
30
188
  for c in a:
@@ -86,7 +244,7 @@ def recursive_subclasses(cls):
86
244
 
87
245
  class SubclassJSONSerializer:
88
246
  """
89
- Copied from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
247
+ Originally from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
90
248
  Class for automatic (de)serialization of subclasses.
91
249
  Classes that inherit from this class can be serialized and deserialized automatically by calling this classes
92
250
  'from_json' method.
@@ -150,8 +308,14 @@ class SubclassJSONSerializer:
150
308
  return None
151
309
  if not isinstance(data, dict) or ('_type' not in data):
152
310
  return data
311
+ if '__dataclass__' in data:
312
+ # if the data is a dataclass, deserialize it
313
+ return deserialize_dataclass(data)
314
+
153
315
  # check if type module is builtins
154
316
  data_type = get_type_from_string(data["_type"])
317
+ if len(data) == 1:
318
+ return data_type
155
319
  if data_type.__module__ == 'builtins':
156
320
  if is_iterable(data['value']) and not isinstance(data['value'], dict):
157
321
  return data_type([cls.from_json(d) for d in data['value']])
@@ -279,7 +443,7 @@ def table_rows_as_str(row_dict: Dict[str, Any], columns_per_row: int = 9):
279
443
  values = [list(map(lambda i: i[1], row)) for row in all_items]
280
444
  all_table_rows = []
281
445
  for row_keys, row_values in zip(keys, values):
282
- table = tabulate([row_values], headers=row_keys, tablefmt='plain')
446
+ table = tabulate([row_values], headers=row_keys, tablefmt='plain', maxcolwidths=[20] * len(row_keys))
283
447
  all_table_rows.append(table)
284
448
  return "\n".join(all_table_rows)
285
449
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.15
3
+ Version: 0.1.1
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,20 @@
1
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ripple_down_rules/datasets.py,sha256=AzPtqUXuR1qLQNtRsWLsJ3gX2oIf8nIkFvmsmz7fHlw,4601
3
+ ripple_down_rules/experts.py,sha256=Xz1U1Tdq7jrFlcVuSusaMB241AG9TEs7q101i59Xijs,10683
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
6
+ ripple_down_rules/prompt.py,sha256=z6KddZOsNiStptgCRNh2OVHHuH6Ooa2f-nsrgJH1qJ8,6311
7
+ ripple_down_rules/rdr.py,sha256=vuRGyqB2wasAzSUhj9HenBp0nC-C1zCGUj3e_th2X7A,43511
8
+ ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
+ ripple_down_rules/rules.py,sha256=aM3Im4ePuFDlkuD2EKRtiVmYgoQ_sxlwcbzrDKqXAfs,14578
10
+ ripple_down_rules/utils.py,sha256=9gPnRWlLye7FettI2QRWJx8oU9z3ckwdO5jopXK8b-8,24290
11
+ ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
12
+ ripple_down_rules/datastructures/callable_expression.py,sha256=ac2TaMr0hiRX928GMcr3oTQic8KXXO4syLw4KV-Iehs,10515
13
+ ripple_down_rules/datastructures/case.py,sha256=3Pl07jmYn94wdCVTaRZDmBPgyAsN1TjebvrE6-68MVU,13606
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=AI-wqNy8y9QPg6lov0P-c5b8JXemuM4X62tIRhW-Gqs,4231
15
+ ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
16
+ ripple_down_rules-0.1.1.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.1.dist-info/METADATA,sha256=rzAQLYJ7yFFBy9Nthw1qwhhwbOZtv1aTNnNSVyf9BQE,42518
18
+ ripple_down_rules-0.1.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
19
+ ripple_down_rules-0.1.1.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (79.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,18 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ripple_down_rules/datasets.py,sha256=qYX7IF7ACm0VRbaKfEgQ32j0YbUUyt2GfGU5Lo42CqI,4601
3
- ripple_down_rules/experts.py,sha256=wg1uY0ox9dUMR4s1RdGjzpX1_WUqnCa060r1U9lrKYI,11214
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/prompt.py,sha256=QAmxg4ssrGUAlK7lbyKw2nuRczTZColpjc9uMC1ts3I,4210
6
- ripple_down_rules/rdr.py,sha256=am57_cmN2ge_eKGQLIAyd5iH4iSC15QrFHmdmwwM_uY,41947
7
- ripple_down_rules/rules.py,sha256=MUZv42WPOB73rzw6Um2F_q0woRFG4db4yYB6R-qMpKA,14239
8
- ripple_down_rules/utils.py,sha256=32gU9NHIcxQpfS4zPSAtn63np5SdVvS9VuyoYiyKZbc,18664
9
- ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=TN6bi4VYjyLlSLTEA3dRo5ENfEdQYc8Fjj5nbnsz-C0,9058
11
- ripple_down_rules/datastructures/case.py,sha256=0EV69VBzsBTHQ3Ots7fkP09g1tyJ20xW9kSm0nZGy40,14071
12
- ripple_down_rules/datastructures/dataclasses.py,sha256=EVQ1jBKW7K7q7_JNgikHX9fm3EmQQKA74sNjEQ4rXn8,2449
13
- ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
14
- ripple_down_rules-0.0.15.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
15
- ripple_down_rules-0.0.15.dist-info/METADATA,sha256=JOCvhnY9h7adkNGnR9uX0j_d7oQcIIBTY8kCpZ2VsD4,42519
16
- ripple_down_rules-0.0.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
17
- ripple_down_rules-0.0.15.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
18
- ripple_down_rules-0.0.15.dist-info/RECORD,,