ripple-down-rules 0.0.14__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,55 @@
1
+ """
2
+ This file contains decorators for the RDR (Ripple Down Rules) framework. Where each type of RDR has a decorator
3
+ that can be used with any python function such that this function can benefit from the incremental knowledge acquisition
4
+ of the RDRs.
5
+ """
6
+ import os.path
7
+ from functools import wraps
8
+ from typing import Callable, Optional, Type
9
+
10
+ from sqlalchemy.orm import Session
11
+ from typing_extensions import Any
12
+
13
+ from ripple_down_rules.datastructures import Case, Category, create_case, CaseQuery
14
+ from ripple_down_rules.experts import Expert, Human
15
+
16
+ from ripple_down_rules.rdr import SingleClassRDR, MultiClassRDR, GeneralRDR
17
+ from ripple_down_rules.utils import get_method_args_as_dict, get_method_name, get_method_class_name_if_exists, \
18
+ get_method_file_name, get_func_rdr_model_path
19
+
20
+
21
def single_class_rdr(
        model_dir: str,
        fit: bool = True,
        expert: Optional[Expert] = None,
        session: Optional[Session] = None,
) -> Callable:
    """
    Decorator to use a SingleClassRDR as a classifier.

    The decorated function is still executed on every call. Its arguments, plus its return value (stored under
    ``"_output"`` when it is not None), are turned into a Case that is either used to fit the RDR (and the model is
    saved) or just classified by it.

    :param model_dir: Directory in which the RDR model JSON file is stored; the file name is derived from the
        decorated function via get_func_rdr_model_path.
    :param fit: If True, every call fits the RDR on the case (querying the expert) and saves the model; if False, the
        RDR only classifies the case.
    :param expert: The expert used for fitting; a Human expert with the given session is created when None.
    :param session: Optional SQLAlchemy session passed to the RDR (and to the default Human expert).
    :return: The decorator.
    """
    # `is None` rather than truthiness so that a user-supplied expert is never silently replaced.
    if expert is None:
        expert = Human(session=session)

    def decorator(func: Callable) -> Callable:
        scrdr_model_path = get_func_rdr_model_path(func, model_dir)
        # The model is loaded (or created) once, at decoration time, and shared by all calls of the wrapper.
        if os.path.exists(scrdr_model_path):
            scrdr = SingleClassRDR.load(scrdr_model_path)
            scrdr.session = session
        else:
            scrdr = SingleClassRDR(session=session)

        @wraps(func)
        def wrapper(*args, **kwargs) -> Category:
            case_dict = get_method_args_as_dict(func, *args, **kwargs)
            func_output = func(*args, **kwargs)
            if func_output is not None:
                case_dict.update({"_output": func_output})
            case = create_case(case_dict, recursion_idx=3)
            if fit:
                output = scrdr.fit_case(CaseQuery(case), expert=expert)
                # Make sure the target directory exists before saving, otherwise the rules the expert just gave
                # could be lost on the first save into a fresh model_dir.
                model_parent_dir = os.path.dirname(scrdr_model_path)
                if model_parent_dir:
                    os.makedirs(model_parent_dir, exist_ok=True)
                scrdr.save(scrdr_model_path)
                return output
            return scrdr.classify(case)

        return wrapper

    return decorator
@@ -21,7 +21,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
21
21
  conclusion: Optional[CallableExpression] = None,
22
22
  parent: Optional[Rule] = None,
23
23
  corner_case: Optional[Union[Case, SQLTable]] = None,
24
- weight: Optional[str] = None):
24
+ weight: Optional[str] = None,
25
+ conclusion_name: Optional[str] = None):
25
26
  """
26
27
  A rule in the ripple down rules classifier.
27
28
 
@@ -30,6 +31,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
30
31
  :param parent: The parent rule of this rule.
31
32
  :param corner_case: The corner case that this rule is based on/created from.
32
33
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
34
+ :param conclusion_name: The name of the conclusion of the rule.
33
35
  """
34
36
  super(Rule, self).__init__()
35
37
  self.conclusion = conclusion
@@ -37,6 +39,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
37
39
  self.parent = parent
38
40
  self.weight: Optional[str] = weight
39
41
  self.conditions = conditions if conditions else None
42
+ self.conclusion_name: Optional[str] = conclusion_name
40
43
  self.json_serialization: Optional[Dict[str, Any]] = None
41
44
 
42
45
  def _post_detach(self, parent):
@@ -107,8 +110,10 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
107
110
  conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': []}
108
111
  for c in conclusion:
109
112
  conclusions['value'].append(conclusion_to_json(c))
110
- else:
113
+ elif hasattr(conclusion, 'to_json'):
111
114
  conclusions = conclusion.to_json()
115
+ else:
116
+ conclusions = {'_type': get_full_class_name(type(conclusion)), 'value': conclusion}
112
117
  return conclusions
113
118
 
114
119
  json_serialization = {"conditions": self.conditions.to_json(),
@@ -334,7 +339,11 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
334
339
  return loaded_rule
335
340
 
336
341
  def _conclusion_source_code_clause(self, conclusion: Any, parent_indent: str = "") -> str:
337
- statement = f"{parent_indent}{' ' * 4}conclusions.add({conclusion})\n"
342
+ if is_iterable(conclusion):
343
+ conclusion_str = "{" + ", ".join([str(c) for c in conclusion]) + "}"
344
+ else:
345
+ conclusion_str = "{" + str(conclusion) + "}"
346
+ statement = f"{parent_indent}{' ' * 4}conclusions.update({conclusion_str})\n"
338
347
  if self.alternative is None:
339
348
  statement += f"{parent_indent}return conclusions\n"
340
349
  return statement
@@ -1,12 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
3
4
  import importlib
4
5
  import json
5
6
  import logging
7
+ import os
6
8
  from abc import abstractmethod
7
9
  from collections import UserDict
8
10
  from copy import deepcopy
9
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, is_dataclass, fields
10
12
 
11
13
  import matplotlib
12
14
  import networkx as nx
@@ -25,6 +27,149 @@ if TYPE_CHECKING:
25
27
  matplotlib.use("Qt5Agg") # or "Qt5Agg", depending on availability
26
28
 
27
29
 
30
def serialize_dataclass(obj: Any) -> Any:
    """
    Recursively convert dataclass instances (and lists/dicts containing them) into JSON-friendly structures.

    - Dataclass instances become ``{"__dataclass__": "<module>.<qualname>", "fields": {...}}`` with each field value
      converted recursively.
    - Lists and dicts are converted element-wise.
    - Objects having both ``__module__`` and ``__name__`` (classes, functions) become ``{'_type': <full name>}`` where
      the name is that of the object itself.
    - Everything else is returned unchanged.

    :param obj: The object to serialize.
    :return: The serialized structure (a dict, list, or the unchanged object).
    """

    def recursive_convert(obj):
        # is_dataclass() is also True for dataclass *types*; only instances carry field values to serialize.
        if is_dataclass(obj) and not isinstance(obj, type):
            return {
                "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
                "fields": {f.name: recursive_convert(getattr(obj, f.name)) for f in fields(obj)}
            }
        elif isinstance(obj, list):
            return [recursive_convert(item) for item in obj]
        elif isinstance(obj, dict):
            return {k: recursive_convert(v) for k, v in obj.items()}
        elif hasattr(obj, "__module__") and hasattr(obj, "__name__"):
            # Record the class/function itself, not obj.__class__ (which is the metaclass, e.g. `type`), so that a
            # single-key {'_type': ...} payload deserializes back to this object.
            # NOTE(review): assumes get_full_class_name accepts a class or function directly — confirm.
            return {'_type': get_full_class_name(obj)}
        return obj

    return recursive_convert(obj)
48
+
49
+
50
def _import_qualified_name(path: str) -> Any:
    """
    Resolve a dotted ``<module>.<qualname>`` path (as written by serialize_dataclass) to the object it names.

    The boundary between module path and qualified name is not stored, so the longest module prefix is tried first
    (the same split as ``rsplit(".", 1)``), then progressively shorter prefixes, looking up the remaining components
    as attributes. This also resolves nested classes whose ``__qualname__`` contains dots (``pkg.mod.Outer.Inner``).

    :param path: The dotted path.
    :return: The resolved object.
    :raises ValueError: If the path contains no dot.
    :raises ModuleNotFoundError, AttributeError: The error from the first (longest-prefix) attempt if no split works.
    """
    parts = path.split(".")
    first_error = None
    for split_idx in range(len(parts) - 1, 0, -1):
        try:
            resolved = importlib.import_module(".".join(parts[:split_idx]))
            for attr_name in parts[split_idx:]:
                resolved = getattr(resolved, attr_name)
            return resolved
        except (ModuleNotFoundError, AttributeError) as error:
            if first_error is None:
                first_error = error
    if first_error is None:
        raise ValueError(f"Cannot resolve '{path}': expected a dotted '<module>.<name>' path")
    raise first_error


def deserialize_dataclass(data: dict) -> Any:
    """
    Rebuild objects from the structure produced by serialize_dataclass.

    Dicts carrying a ``"__dataclass__"`` key are turned back into instances by calling the named class with the
    (recursively deserialized) ``"fields"`` as keyword arguments; other lists and dicts are processed element-wise and
    any other value is returned unchanged.

    :param data: The serialized structure.
    :return: The deserialized object.
    """

    def recursive_load(obj):
        if isinstance(obj, dict) and "__dataclass__" in obj:
            cls = _import_qualified_name(obj["__dataclass__"])
            field_values = {k: recursive_load(v) for k, v in obj["fields"].items()}
            return cls(**field_values)
        elif isinstance(obj, list):
            return [recursive_load(item) for item in obj]
        elif isinstance(obj, dict):
            return {k: recursive_load(v) for k, v in obj.items()}
        return obj

    return recursive_load(data)
69
+
70
+
71
def typing_to_python_type(typing_hint: Type) -> Type:
    """
    Convert a container typing hint to the corresponding built-in python type.

    ``list``, ``tuple``, ``set`` and ``dict``, their ``typing`` aliases (``List``, ``Tuple``, ``Set``, ``Dict``) and
    their parameterized forms (e.g. ``List[int]``, ``dict[str, int]``) are mapped to the built-in type. Any other hint
    is returned unchanged.

    :param typing_hint: The typing hint to convert.
    :return: The python type.
    """
    from typing import get_origin

    builtin_containers = (list, tuple, set, dict)
    if typing_hint in builtin_containers:
        return typing_hint
    # get_origin maps typing aliases (List, List[int]) and builtin generics (list[int]) to their runtime class;
    # it returns None (or e.g. Union) for everything else, which is left untouched.
    origin = get_origin(typing_hint)
    if origin in builtin_containers:
        return origin
    return typing_hint
88
+
89
+
90
def capture_variable_assignment(code: str, variable_name: str) -> Optional[str]:
    """
    Capture the source of the value assigned to a variable in the given code.

    Plain assignments (``x = ...``, including chained ``x = y = ...``) and annotated assignments with a value
    (``x: int = ...``) are recognised. Nodes are visited in ``ast.walk`` (breadth-first) order and the first match is
    returned.

    :param code: The code to analyze.
    :param variable_name: The name of the variable to capture.
    :return: The source text of the assigned value, or None if no assignment to the variable is found.
    """
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # `x: int` without a value is only a declaration, not an assignment.
            targets = [node.target]
        else:
            continue
        if any(isinstance(target, ast.Name) and target.id == variable_name for target in targets):
            # extract the right-hand side of the assignment
            value_source = ast.get_source_segment(code, node.value)
            if value_source is not None:
                return value_source
    return None
110
+
111
+
112
def get_func_rdr_model_path(func: Callable, model_dir: str) -> str:
    """
    Build the path of the RDR model file associated with a function.

    The file is ``<model_dir>/<source file stem>[_<class name>]_<function name>.json``.

    :param func: The function to get the model path for.
    :param model_dir: The directory to save the model to.
    :return: The path to the model file.
    """
    func_name = get_method_name(func)
    func_class_name = get_method_class_name_if_exists(func)
    # get_method_file_name returns the code object's co_filename, typically a full (absolute) path. Joining an
    # absolute component makes os.path.join discard model_dir entirely, so only the file's stem is used here.
    func_file_name = os.path.splitext(os.path.basename(get_method_file_name(func)))[0]
    model_name = func_file_name
    if func_class_name:
        model_name += f"_{func_class_name}"
    model_name += f"_{func_name}"
    return os.path.join(model_dir, f"{model_name}.json")
125
+
126
+
127
def get_method_args_as_dict(method: Callable, *args, **kwargs) -> Dict[str, Any]:
    """
    Get the arguments of a call to a method as a dictionary keyed by parameter name.

    Only the arguments actually supplied are included; defaults are not filled in. Keyword arguments are matched by
    name, so their order does not matter. A ``*args`` / ``**kwargs`` parameter maps to the tuple / dict of extra
    arguments it would receive.

    :param method: The method to get the arguments from.
    :param args: The positional arguments.
    :param kwargs: The keyword arguments.
    :return: A dictionary of the arguments.
    """
    import inspect

    try:
        bound = inspect.signature(method).bind(*args, **kwargs)
    except (TypeError, ValueError):
        # No introspectable signature, or the arguments do not fit it: fall back to best-effort positional pairing
        # with the code object's variable names (the previous behaviour).
        func_arg_names = method.__code__.co_varnames
        func_arg_values = args + tuple(kwargs.values())
        return dict(zip(func_arg_names, func_arg_values))
    return dict(bound.arguments)
139
+
140
+
141
def get_method_name(method: Callable) -> str:
    """
    Return the ``__name__`` of a method, falling back to its ``str()`` form when it has no ``__name__``.

    :param method: The method (or any object) to name.
    :return: The name of the method.
    """
    try:
        return method.__name__
    except AttributeError:
        return str(method)
149
+
150
+
151
def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
    """
    Get the name of the class a method is bound to, if it is bound to one.

    - Bound instance methods yield the instance's class name.
    - Classmethods yield the name of the class they are bound to.
    - Unbound functions yield None. This includes methods looked up on the class, or functions being decorated
      inside a class body: they have no ``__self__``.
    - Module-level builtins (whose ``__self__`` is their module) yield None.

    :param method: The method to get the class name of.
    :return: The class name of the method, or None.
    """
    import types

    owner = getattr(method, "__self__", None)
    if owner is None or isinstance(owner, types.ModuleType):
        return None
    if isinstance(owner, type):
        # A classmethod is bound to the class itself; owner.__class__ would be its metaclass (usually `type`).
        return owner.__name__
    return owner.__class__.__name__
161
+
162
+
163
def get_method_file_name(method: Callable) -> str:
    """
    Return the file recorded in a method's code object (``__code__.co_filename``).

    This is the path the code was compiled from, typically a full path rather than a bare file name.

    :param method: The method to get the file name of.
    :return: The file name of the method.
    """
    code_object = method.__code__
    return code_object.co_filename
171
+
172
+
28
173
  def flatten_list(a: List):
29
174
  a_flattened = []
30
175
  for c in a:
@@ -86,7 +231,7 @@ def recursive_subclasses(cls):
86
231
 
87
232
  class SubclassJSONSerializer:
88
233
  """
89
- Copied from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
234
+ Originally from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
90
235
  Class for automatic (de)serialization of subclasses.
91
236
  Classes that inherit from this class can be serialized and deserialized automatically by calling this classes
92
237
  'from_json' method.
@@ -150,8 +295,14 @@ class SubclassJSONSerializer:
150
295
  return None
151
296
  if not isinstance(data, dict) or ('_type' not in data):
152
297
  return data
298
+ if '__dataclass__' in data:
299
+ # if the data is a dataclass, deserialize it
300
+ return deserialize_dataclass(data)
301
+
153
302
  # check if type module is builtins
154
303
  data_type = get_type_from_string(data["_type"])
304
+ if len(data) == 1:
305
+ return data_type
155
306
  if data_type.__module__ == 'builtins':
156
307
  if is_iterable(data['value']) and not isinstance(data['value'], dict):
157
308
  return data_type([cls.from_json(d) for d in data['value']])
@@ -279,7 +430,7 @@ def table_rows_as_str(row_dict: Dict[str, Any], columns_per_row: int = 9):
279
430
  values = [list(map(lambda i: i[1], row)) for row in all_items]
280
431
  all_table_rows = []
281
432
  for row_keys, row_values in zip(keys, values):
282
- table = tabulate([row_values], headers=row_keys, tablefmt='plain')
433
+ table = tabulate([row_values], headers=row_keys, tablefmt='plain', maxcolwidths=[20] * len(row_keys))
283
434
  all_table_rows.append(table)
284
435
  return "\n".join(all_table_rows)
285
436
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ripple_down_rules
3
- Version: 0.0.14
3
+ Version: 0.1.0
4
4
  Summary: Implements the various versions of Ripple Down Rules (RDR) for knowledge representation and reasoning.
5
5
  Author-email: Abdelrhman Bassiouny <abassiou@uni-bremen.de>
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,20 @@
1
+ ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ripple_down_rules/datasets.py,sha256=AzPtqUXuR1qLQNtRsWLsJ3gX2oIf8nIkFvmsmz7fHlw,4601
3
+ ripple_down_rules/experts.py,sha256=Xz1U1Tdq7jrFlcVuSusaMB241AG9TEs7q101i59Xijs,10683
4
+ ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
+ ripple_down_rules/helpers.py,sha256=AhqerAQoCdSovJ7SdQrNtAI_hYagKpLsy2nJQGA0bl0,1062
6
+ ripple_down_rules/prompt.py,sha256=z6KddZOsNiStptgCRNh2OVHHuH6Ooa2f-nsrgJH1qJ8,6311
7
+ ripple_down_rules/rdr.py,sha256=b2_tfTzWQxSOLrWTjhdnqfRSjRmDdykjCY_2NxNhX2Y,43560
8
+ ripple_down_rules/rdr_decorators.py,sha256=8SclpceI3EtrsbuukWJu8HGLh7Q1ZCgYGLX-RPlG-w0,2018
9
+ ripple_down_rules/rules.py,sha256=aM3Im4ePuFDlkuD2EKRtiVmYgoQ_sxlwcbzrDKqXAfs,14578
10
+ ripple_down_rules/utils.py,sha256=O5wFIzS1GzUaIlT-JlxZcQFhu5nEscqKsI7CPDOjn4U,23666
11
+ ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
12
+ ripple_down_rules/datastructures/callable_expression.py,sha256=ac2TaMr0hiRX928GMcr3oTQic8KXXO4syLw4KV-Iehs,10515
13
+ ripple_down_rules/datastructures/case.py,sha256=YI14X8_BdxXfN2-FG1m3uSu6SPUAtZ8NxD6cMIs-z48,13556
14
+ ripple_down_rules/datastructures/dataclasses.py,sha256=AI-wqNy8y9QPg6lov0P-c5b8JXemuM4X62tIRhW-Gqs,4231
15
+ ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
16
+ ripple_down_rules-0.1.0.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
17
+ ripple_down_rules-0.1.0.dist-info/METADATA,sha256=1Pcqch8RAd1REIaCGFXn5uJFTI3VmcmqgzavZd7G4f4,42518
18
+ ripple_down_rules-0.1.0.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
19
+ ripple_down_rules-0.1.0.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
20
+ ripple_down_rules-0.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (79.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,18 +0,0 @@
1
- ripple_down_rules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ripple_down_rules/datasets.py,sha256=qYX7IF7ACm0VRbaKfEgQ32j0YbUUyt2GfGU5Lo42CqI,4601
3
- ripple_down_rules/experts.py,sha256=wg1uY0ox9dUMR4s1RdGjzpX1_WUqnCa060r1U9lrKYI,11214
4
- ripple_down_rules/failures.py,sha256=E6ajDUsw3Blom8eVLbA7d_Qnov2conhtZ0UmpQ9ZtSE,302
5
- ripple_down_rules/prompt.py,sha256=QAmxg4ssrGUAlK7lbyKw2nuRczTZColpjc9uMC1ts3I,4210
6
- ripple_down_rules/rdr.py,sha256=Z-Zu8dH1kdlwOzTVT3DrDdqkMXk6xFQze-Mtz1YZ320,38142
7
- ripple_down_rules/rules.py,sha256=Gi_GRYDvnwxPIHEh_aP1NmN9cL2pHZbwbaiO6M7YdU0,14044
8
- ripple_down_rules/utils.py,sha256=32gU9NHIcxQpfS4zPSAtn63np5SdVvS9VuyoYiyKZbc,18664
9
- ripple_down_rules/datastructures/__init__.py,sha256=zpmiYm4WkwNHaGdTIfacS7llN5d2xyU6U-saH_TpydI,103
10
- ripple_down_rules/datastructures/callable_expression.py,sha256=TN6bi4VYjyLlSLTEA3dRo5ENfEdQYc8Fjj5nbnsz-C0,9058
11
- ripple_down_rules/datastructures/case.py,sha256=0EV69VBzsBTHQ3Ots7fkP09g1tyJ20xW9kSm0nZGy40,14071
12
- ripple_down_rules/datastructures/dataclasses.py,sha256=EVQ1jBKW7K7q7_JNgikHX9fm3EmQQKA74sNjEQ4rXn8,2449
13
- ripple_down_rules/datastructures/enums.py,sha256=l0Eu-TeJ6qB2XHoJycXmUgLw-3yUebQ8SsEbW8bBZdM,4543
14
- ripple_down_rules-0.0.14.dist-info/licenses/LICENSE,sha256=ixuiBLtpoK3iv89l7ylKkg9rs2GzF9ukPH7ynZYzK5s,35148
15
- ripple_down_rules-0.0.14.dist-info/METADATA,sha256=VHKBkYFWhQvidy3vV2Td6kyLLjrwkYMv18ROs_dbb2o,42519
16
- ripple_down_rules-0.0.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
17
- ripple_down_rules-0.0.14.dist-info/top_level.txt,sha256=VeoLhEhyK46M1OHwoPbCQLI1EifLjChqGzhQ6WEUqeM,18
18
- ripple_down_rules-0.0.14.dist-info/RECORD,,