ripple-down-rules 0.6.0__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +21 -1
- ripple_down_rules/datastructures/callable_expression.py +24 -7
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +135 -14
- ripple_down_rules/datastructures/enums.py +29 -86
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/experts.py +141 -50
- ripple_down_rules/failures.py +4 -0
- ripple_down_rules/helpers.py +75 -8
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +712 -96
- ripple_down_rules/rdr_decorators.py +164 -112
- ripple_down_rules/rules.py +351 -114
- ripple_down_rules/user_interface/gui.py +66 -41
- ripple_down_rules/user_interface/ipython_custom_shell.py +46 -9
- ripple_down_rules/user_interface/prompt.py +80 -60
- ripple_down_rules/user_interface/template_file_creator.py +13 -8
- ripple_down_rules/utils.py +537 -53
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/METADATA +4 -1
- ripple_down_rules-0.6.6.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.0.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.0.dist-info → ripple_down_rules-0.6.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import enum
|
4
|
+
import importlib
|
5
|
+
import inspect
|
6
|
+
import logging
|
7
|
+
import sys
|
8
|
+
import typing
|
9
|
+
from dataclasses import dataclass, Field
|
10
|
+
from datetime import datetime
|
11
|
+
from functools import lru_cache
|
12
|
+
from types import NoneType
|
13
|
+
|
14
|
+
from typing_extensions import Type, get_origin, Optional, get_type_hints, Tuple
|
15
|
+
|
16
|
+
from ripple_down_rules.utils import make_tuple
|
17
|
+
|
18
|
+
|
19
|
+
class ParseError(TypeError):
    """
    Raised when the parser meets a type construct it cannot, or should not, parse.

    Union types are one example of such a construct.
    """
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
|
29
|
+
class FieldInfo:
|
30
|
+
"""
|
31
|
+
A class that wraps a field of dataclass and provides some utility functions.
|
32
|
+
"""
|
33
|
+
|
34
|
+
clazz: Type
|
35
|
+
"""
|
36
|
+
The class that the field is in.
|
37
|
+
"""
|
38
|
+
|
39
|
+
name: str
|
40
|
+
"""
|
41
|
+
The name of the field.
|
42
|
+
"""
|
43
|
+
|
44
|
+
type: Tuple[Type, ...]
|
45
|
+
"""
|
46
|
+
The type of the field or inner type of the container if it is a container.
|
47
|
+
"""
|
48
|
+
|
49
|
+
optional: bool
|
50
|
+
"""
|
51
|
+
True if the field is optional, False otherwise.
|
52
|
+
"""
|
53
|
+
|
54
|
+
container: Optional[Type]
|
55
|
+
"""
|
56
|
+
The type of the container if it is one (list, set, tuple, etc.). If there is no container this is None
|
57
|
+
"""
|
58
|
+
|
59
|
+
is_type_field: bool = False
|
60
|
+
"""
|
61
|
+
The field value is a type.
|
62
|
+
"""
|
63
|
+
|
64
|
+
field: Field = None
|
65
|
+
"""
|
66
|
+
The field instance.
|
67
|
+
"""
|
68
|
+
|
69
|
+
def __init__(self, clazz: Type, f: Field):
|
70
|
+
self.field = f
|
71
|
+
self.name = f.name
|
72
|
+
self.clazz = clazz
|
73
|
+
|
74
|
+
try:
|
75
|
+
type_hints = get_type_hints(clazz)[self.name]
|
76
|
+
except NameError as e:
|
77
|
+
found_clazz = manually_search_for_class_name(e.name)
|
78
|
+
module = importlib.import_module(found_clazz.__module__)
|
79
|
+
locals()[e.name] = getattr(module, e.name)
|
80
|
+
type_hints = get_type_hints(clazz, localns=locals())[self.name]
|
81
|
+
type_args = typing.get_args(type_hints)
|
82
|
+
|
83
|
+
# try to unpack the type if it is a nested type
|
84
|
+
if len(type_args) > 0:
|
85
|
+
self.optional = NoneType in type_args
|
86
|
+
|
87
|
+
if self.optional:
|
88
|
+
self.container = None
|
89
|
+
else:
|
90
|
+
self.container = get_origin(type_hints)
|
91
|
+
|
92
|
+
if not self.optional and type_hints == Type[type_args]:
|
93
|
+
self.is_type_field = True
|
94
|
+
|
95
|
+
self.type = type_args
|
96
|
+
else:
|
97
|
+
self.optional = False
|
98
|
+
self.container = None
|
99
|
+
self.type = make_tuple(type_hints)
|
100
|
+
|
101
|
+
@property
|
102
|
+
def is_builtin_class(self) -> bool:
|
103
|
+
return not self.container and all(t.__module__ == 'builtins' for t in self.type)
|
104
|
+
|
105
|
+
@property
|
106
|
+
def is_container_of_builtin(self) -> bool:
|
107
|
+
return self.container and all(t.__module__ == 'builtins' for t in self.type)
|
108
|
+
|
109
|
+
@property
|
110
|
+
def is_type_type(self) -> bool:
|
111
|
+
return self.is_type_field
|
112
|
+
|
113
|
+
@property
|
114
|
+
def is_enum(self):
|
115
|
+
return len(self.type) == 1 and issubclass(self.type[0], enum.Enum)
|
116
|
+
|
117
|
+
@property
|
118
|
+
def is_datetime(self):
|
119
|
+
return len(self.type) == 1 and self.type[0] == datetime
|
120
|
+
|
121
|
+
|
122
|
+
def is_container(clazz: Type) -> bool:
    """
    Check whether ``clazz`` is a subscripted ``list``, ``set`` or ``tuple`` type hint.

    Only parametrised generics count, e.g. ``List[int]``, ``list[int]`` or ``Tuple[int, ...]``.
    A bare ``list`` has no ``typing`` origin and therefore yields False. Other containers such
    as ``dict`` also yield False.

    :param clazz: The class / type hint to check.
    :return: True if the ``typing`` origin of ``clazz`` is ``list``, ``set`` or ``tuple``, False otherwise.
    """
    origin = get_origin(clazz)
    return any(origin is candidate for candidate in (list, set, tuple))
|
130
|
+
|
131
|
+
|
132
|
+
def manually_search_for_class_name(target_class_name: str) -> Type:
    """
    Find a class by name in this module's ``globals()`` and in all modules loaded in ``sys.modules``.

    A class reachable from several namespaces (for example, one imported into many modules) is
    counted once. If several *distinct* classes share the name, ``warn_multiple_classes`` is
    called and the first one found is returned.

    :param target_class_name: Name of the class to search for.
    :return: The resolved class with the matching name.
    :raises ValueError: Raised when no class with the specified name can be found.
    """
    found_classes = []
    seen_ids = set()

    def _collect(namespace) -> None:
        # Snapshot the values: a module namespace may change while we iterate it.
        for obj in list(namespace.values()):
            if inspect.isclass(obj) and obj.__name__ == target_class_name and id(obj) not in seen_ids:
                seen_ids.add(id(obj))
                found_classes.append(obj)

    # Search 1: In the current module's globals()
    _collect(globals())

    # Search 2: In all loaded modules (via sys.modules). Snapshot it, because imports elsewhere
    # can add entries while we iterate.
    for module in list(sys.modules.values()):
        if module is None or not hasattr(module, '__dict__'):
            continue  # Skip built-in modules or modules without a __dict__
        _collect(module.__dict__)

    if len(found_classes) == 0:
        raise ValueError(f"Could not find any class with name {target_class_name} in globals or sys.modules.")
    if len(found_classes) > 1:
        warn_multiple_classes(target_class_name, tuple(found_classes))
    return found_classes[0]
|
173
|
+
|
174
|
+
|
175
|
+
@lru_cache(maxsize=None)
def warn_multiple_classes(target_class_name, found_classes):
    """
    Log a warning that more than one class is named ``target_class_name``.

    The function is memoised with ``lru_cache``, so a given ``(target_class_name, found_classes)``
    pair is logged only once per process. Both arguments must therefore be hashable; callers pass
    ``found_classes`` as a tuple.
    """
    message = f"Found multiple classes with name {target_class_name}. Found classes: {found_classes} "
    logging.warning(message)
|
@@ -0,0 +1,208 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import uuid
|
5
|
+
from dataclasses import dataclass, field, Field, fields
|
6
|
+
from enum import Enum
|
7
|
+
from functools import lru_cache
|
8
|
+
|
9
|
+
import pydot
|
10
|
+
import rustworkx as rx
|
11
|
+
from typing_extensions import Any, TYPE_CHECKING, Type, final, ClassVar, Dict, List, Optional, Tuple
|
12
|
+
|
13
|
+
from .field_info import FieldInfo
|
14
|
+
from .. import logger
|
15
|
+
from ..rules import Rule
|
16
|
+
from ..utils import recursive_subclasses
|
17
|
+
|
18
|
+
|
19
|
+
class Direction(Enum):
    """
    Two-valued direction flag whose member values are booleans.

    NOTE(review): presumably the direction of an edge relative to a node in the dependency graph.
    It is not used in the code visible here, so confirm against callers.
    """

    OUTBOUND = False  # value: False
    INBOUND = True  # value: True
|
22
|
+
|
23
|
+
|
24
|
+
class Relation(str, Enum):
    """
    String-valued labels for edges of the class dependency graph.

    ``has`` labels composition edges (class -> type of one of its fields). ``isA`` labels
    inheritance edges (class -> base class). ``dependsOn`` is not used by the code visible here.
    """

    has = "has"  # composition
    isA = "isA"  # inheritance
    dependsOn = "dependsOn"
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass(unsafe_hash=True)
class TrackedObjectMixin:
    """
    A class that is used as a base class to all classes that need to be tracked for RDR inference and reasoning.

    The graph bookkeeping (``_dependency_graph``, ``_class_graph_indices``, ``_composition_edges``,
    ``_inheritance_edges``) is class-level state shared by all subclasses. Every method therefore
    reads and rebinds it on ``TrackedObjectMixin`` itself, never on ``cls``. Rebinding on a
    subclass would shadow the shared state and desynchronise it from the static helpers.
    """
    _rdr_tracked_object_id: int = field(init=False, repr=False, hash=False,
                                        compare=False, default_factory=lambda: uuid.uuid4().int)
    """
    A unique identifier for this tracked instance (a random UUID4 as an int).
    """
    _dependency_graph: ClassVar[rx.PyDAG[Type[TrackedObjectMixin]]] = rx.PyDAG()
    """
    A graph that represents the relationships between all tracked objects.
    """
    _class_graph_indices: ClassVar[Dict[Type[TrackedObjectMixin], int]] = {}
    """
    The index of each tracked class in the dependency graph.
    """
    _composition_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent composition relations between objects (Relation.has).
    """
    _inheritance_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent inheritance relations between objects (Relation.isA).
    """
    _overridden_by: Type[TrackedObjectMixin] = field(init=False, repr=False, hash=False,
                                                     compare=False, default=None)
    """
    Whether the class has been overridden by a subclass.
    This is used to only include the new class in the dependency graph, not the overridden class.
    """

    @classmethod
    def _reset_dependency_graph(cls) -> None:
        """
        Reset the dependency graph and all class indices.

        The state is rebound on ``TrackedObjectMixin`` even when called on a subclass, so that all
        methods keep seeing the same graph.
        """
        TrackedObjectMixin._dependency_graph = rx.PyDAG()
        TrackedObjectMixin._class_graph_indices = {}
        TrackedObjectMixin._composition_edges = []
        TrackedObjectMixin._inheritance_edges = []

    @classmethod
    def _my_graph_idx(cls):
        """Return the dependency-graph node index of ``cls``; raises KeyError if it was never added."""
        return TrackedObjectMixin._class_graph_indices[cls]

    @classmethod
    def make_class_dependency_graph(cls, composition: bool = True):
        """
        Create a directed acyclic graph containing the class hierarchy.

        Nodes are the recursive subclasses of ``cls`` plus ``Rule`` and its recursive subclasses.
        An ``isA`` edge goes from each tracked class to each of its direct, tracked, non-builtin bases.

        :param composition: If True, the class dependency graph will include composition relations
            (``has`` edges from a class to the tracked type of each of its public fields).
        """
        classes_to_track = recursive_subclasses(cls) + [Rule] + recursive_subclasses(Rule)
        for clazz in classes_to_track:
            cls._add_class_to_dependency_graph(clazz)

        indices = TrackedObjectMixin._class_graph_indices
        for clazz in list(indices):
            for base in clazz.__bases__:
                if base.__module__ == "builtins" or base not in indices:
                    continue
                clazz_idx = indices[clazz]
                base_idx = indices[base]
                if (clazz_idx, base_idx) in TrackedObjectMixin._inheritance_edges or base._overridden_by == clazz:
                    continue
                TrackedObjectMixin._dependency_graph.add_edge(clazz_idx, base_idx, Relation.isA)
                TrackedObjectMixin._inheritance_edges.append((clazz_idx, base_idx))

        if not composition:
            return
        # Iterate over a snapshot: parse_fields may add newly discovered field classes to the index
        # dict, and mutating a dict while iterating it raises RuntimeError.
        for clazz in list(TrackedObjectMixin._class_graph_indices):
            if clazz.__module__ == "builtins":
                continue
            TrackedObjectMixin.parse_fields(clazz)

    @staticmethod
    def parse_fields(clazz) -> None:
        """
        Add composition edges for the public fields that ``clazz`` declares itself.

        Fields whose name starts with ``_`` are skipped.
        """
        for f in TrackedObjectMixin.get_fields(clazz):

            logger.debug("=" * 80)
            logger.debug(f"Processing Field {clazz.__name__}.{f.name}: {f.type}.")

            # skip private fields
            if f.name.startswith("_"):
                logger.debug("Skipping since the field starts with _.")
                continue

            TrackedObjectMixin.parse_field(FieldInfo(clazz, f))

    @staticmethod
    def get_fields(clazz) -> List[Field]:
        """
        Return the dataclass fields of ``clazz`` that it does not inherit from a ``TrackedObjectMixin`` base.

        All fields of every direct tracked base are excluded, not only the fields that base itself
        introduced. Otherwise, grandparent fields would re-appear at alternating inheritance levels.
        Fields are compared by identity: ``dataclasses.Field`` defines no ``__eq__``, and inherited
        fields are shared, not copied.
        """
        inherited_ids = {id(base_field)
                         for base in clazz.__bases__ if issubclass(base, TrackedObjectMixin)
                         for base_field in fields(base)}
        return [cls_field for cls_field in fields(clazz) if id(cls_field) not in inherited_ids]

    @staticmethod
    def parse_field(field_info: FieldInfo):
        """
        Add a ``has`` edge from the field's owning class to the field's type.

        The edge is added only when the field has exactly one type and that type is a
        ``TrackedObjectMixin`` subclass. An edge that already exists is not added twice.
        """
        parent_idx = TrackedObjectMixin._class_graph_indices[field_info.clazz]
        field_types = field_info.type
        if not (len(field_types) == 1 and TrackedObjectMixin._is_tracked_class(field_types[0])):
            # TODO: Create a new TrackedObjectMixin class for new type
            logger.debug(f"Skipping unhandled field type: {field_info.type}")
            return

        field_cls = field_types[0]
        field_cls_idx = TrackedObjectMixin._class_graph_indices.get(field_cls, None)
        if field_cls_idx is None:
            TrackedObjectMixin._add_class_to_dependency_graph(field_cls)
            field_cls_idx = TrackedObjectMixin._class_graph_indices[field_cls]
        elif (parent_idx, field_cls_idx) in TrackedObjectMixin._composition_edges:
            return
        TrackedObjectMixin._dependency_graph.add_edge(parent_idx, field_cls_idx, Relation.has)
        TrackedObjectMixin._composition_edges.append((parent_idx, field_cls_idx))

    @staticmethod
    def _is_tracked_class(candidate) -> bool:
        """
        Return True if ``candidate`` is a ``TrackedObjectMixin`` subclass.

        Non-classes (e.g. ``List[int]``, ``list[int]`` or Literal values) make ``issubclass``
        raise TypeError; they are treated as not tracked.
        """
        try:
            return issubclass(candidate, TrackedObjectMixin)
        except TypeError:
            return False

    @classmethod
    def _create_tracked_object_class_for_field(cls, field_info: FieldInfo):
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    def to_dot(cls, filepath: str, format='png') -> None:
        """
        Render the shared dependency graph to a file using ``pydot``.

        :param filepath: Output path. ``.{format}`` is appended if the path does not already end with it.
        :param format: Output format, passed to ``pydot``'s ``write``.
        """
        if not filepath.endswith(f".{format}"):
            filepath += f".{format}"
        dot_str = TrackedObjectMixin._dependency_graph.to_dot(
            lambda node: dict(
                color='black', fillcolor='lightblue', style='filled', label=node.__name__),
            lambda edge: dict(color='black', style='solid', label=edge))
        dot = pydot.graph_from_dot_data(dot_str)[0]
        dot.write(filepath, format=format)

    @classmethod
    def _add_class_to_dependency_graph(cls, class_to_add: Type[TrackedObjectMixin]) -> None:
        """
        Add a class to the dependency graph, unless it is already indexed.

        If ``class_to_add`` is overridden (``_overridden_by`` is set), the overriding class is
        stored as the node. Both classes map to that node. An already indexed overriding class is
        reused rather than given a duplicate node.
        """
        indices = TrackedObjectMixin._class_graph_indices
        if class_to_add in indices:
            return
        if not issubclass(class_to_add, TrackedObjectMixin):
            # Foreign classes (e.g. Rule) get the attribute so `_overridden_by` reads work uniformly.
            class_to_add._overridden_by = None
        node_class = class_to_add._overridden_by or class_to_add
        node_idx = indices.get(node_class)
        if node_idx is None:
            node_idx = TrackedObjectMixin._dependency_graph.add_node(node_class)
        indices[class_to_add] = node_idx
        if class_to_add._overridden_by:
            indices[class_to_add._overridden_by] = node_idx

    def __getattribute__(self, name: str) -> Any:
        # NOTE: currently a pure pass-through. It is kept as the hook point for dependency
        # recording, which is disabled below.
        # if name not in [f.name for f in fields(TrackedObjectMixin)] + ['has', 'is_a', 'depends_on']\
        #         and not name.startswith("_"):
        #     self._record_dependency(name)
        return object.__getattribute__(self, name)

    def _record_dependency(self, attr_name):
        """
        Debug-log when this object is accessed inside a ``CallableExpression.__call__``.

        The call site in ``__getattribute__`` is currently commented out. ``attr_name`` is unused.
        """
        # context=0 skips reading source lines for every frame, which inspect.stack() does by default.
        for frame_info in inspect.stack(context=0):
            local_self = frame_info.frame.f_locals.get("self", None)
            if frame_info.function != "__call__" or local_self is None:
                continue
            owner = type(local_self)
            # Compare only the last dotted component, so that both a top-level
            # "callable_expression" and the packaged
            # "ripple_down_rules.datastructures.callable_expression" match.
            if (owner.__module__.rsplit(".", 1)[-1] == "callable_expression"
                    and owner.__name__ == "CallableExpression"):
                logger.debug("TrackedObject used inside CallableExpression")
                break
|
204
|
+
|
205
|
+
|
206
|
+
def _strip_mixin_field_annotations() -> None:
    """
    Remove the mixin's own dataclass-field annotations from ``TrackedObjectMixin.__annotations__``.

    The affected names are ``_rdr_tracked_object_id`` and ``_overridden_by``. ``ClassVar``
    annotations are not dataclass fields and stay. The dataclass field registry
    (``__dataclass_fields__``) is left untouched. The work is done inside a function so that
    temporaries do not leak into the module namespace: the old module-level code rebound
    ``annotations``, the name bound by ``from __future__ import annotations``.

    NOTE(review): presumably done so annotation-based introspection of subclasses (e.g.
    ``get_type_hints``) does not see these bookkeeping fields. Confirm the intent.
    """
    annotations_dict = TrackedObjectMixin.__annotations__
    for mixin_field in fields(TrackedObjectMixin):
        annotations_dict.pop(mixin_field.name, None)


_strip_mixin_field_annotations()
|