ripple-down-rules 0.6.0__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,177 @@
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ import importlib
5
+ import inspect
6
+ import logging
7
+ import sys
8
+ import typing
9
+ from dataclasses import dataclass, Field
10
+ from datetime import datetime
11
+ from functools import lru_cache
12
+ from types import NoneType
13
+
14
+ from typing_extensions import Type, get_origin, Optional, get_type_hints, Tuple
15
+
16
+ from ripple_down_rules.utils import make_tuple
17
+
18
+
19
class ParseError(TypeError):
    """
    Raised when the parser meets a construct that it cannot, or should not, parse.

    Union types are one example of such a construct.
    """
26
+
27
+
28
+ @dataclass
29
+ class FieldInfo:
30
+ """
31
+ A class that wraps a field of dataclass and provides some utility functions.
32
+ """
33
+
34
+ clazz: Type
35
+ """
36
+ The class that the field is in.
37
+ """
38
+
39
+ name: str
40
+ """
41
+ The name of the field.
42
+ """
43
+
44
+ type: Tuple[Type, ...]
45
+ """
46
+ The type of the field or inner type of the container if it is a container.
47
+ """
48
+
49
+ optional: bool
50
+ """
51
+ True if the field is optional, False otherwise.
52
+ """
53
+
54
+ container: Optional[Type]
55
+ """
56
+ The type of the container if it is one (list, set, tuple, etc.). If there is no container this is None
57
+ """
58
+
59
+ is_type_field: bool = False
60
+ """
61
+ The field value is a type.
62
+ """
63
+
64
+ field: Field = None
65
+ """
66
+ The field instance.
67
+ """
68
+
69
+ def __init__(self, clazz: Type, f: Field):
70
+ self.field = f
71
+ self.name = f.name
72
+ self.clazz = clazz
73
+
74
+ try:
75
+ type_hints = get_type_hints(clazz)[self.name]
76
+ except NameError as e:
77
+ found_clazz = manually_search_for_class_name(e.name)
78
+ module = importlib.import_module(found_clazz.__module__)
79
+ locals()[e.name] = getattr(module, e.name)
80
+ type_hints = get_type_hints(clazz, localns=locals())[self.name]
81
+ type_args = typing.get_args(type_hints)
82
+
83
+ # try to unpack the type if it is a nested type
84
+ if len(type_args) > 0:
85
+ self.optional = NoneType in type_args
86
+
87
+ if self.optional:
88
+ self.container = None
89
+ else:
90
+ self.container = get_origin(type_hints)
91
+
92
+ if not self.optional and type_hints == Type[type_args]:
93
+ self.is_type_field = True
94
+
95
+ self.type = type_args
96
+ else:
97
+ self.optional = False
98
+ self.container = None
99
+ self.type = make_tuple(type_hints)
100
+
101
+ @property
102
+ def is_builtin_class(self) -> bool:
103
+ return not self.container and all(t.__module__ == 'builtins' for t in self.type)
104
+
105
+ @property
106
+ def is_container_of_builtin(self) -> bool:
107
+ return self.container and all(t.__module__ == 'builtins' for t in self.type)
108
+
109
+ @property
110
+ def is_type_type(self) -> bool:
111
+ return self.is_type_field
112
+
113
+ @property
114
+ def is_enum(self):
115
+ return len(self.type) == 1 and issubclass(self.type[0], enum.Enum)
116
+
117
+ @property
118
+ def is_datetime(self):
119
+ return len(self.type) == 1 and self.type[0] == datetime
120
+
121
+
122
def is_container(clazz: Type) -> bool:
    """
    Tell whether a type hint is a parametrised list, set or tuple.

    Only the hint's origin is inspected. ``List[int]``, ``set[str]`` and ``Tuple[int, ...]`` qualify. A bare
    ``list``, which has no origin, does not, and neither do other generics such as ``Dict[str, int]``.

    :param clazz: The type hint to check.
    :return: True if the origin of ``clazz`` is ``list``, ``set`` or ``tuple``, False otherwise.
    """
    origin = get_origin(clazz)
    return origin in (list, set, tuple)
130
+
131
+
132
def manually_search_for_class_name(target_class_name: str) -> Type:
    """
    Searches for a class with the specified name in the current module's `globals()` dictionary
    and all loaded modules present in `sys.modules`. This function attempts to find and resolve
    the first class that matches the given name. If multiple distinct classes are found with the same
    name, a warning is logged, and the first one is returned. If no matching class is found,
    an exception is raised.

    A class imported into several namespaces is the same object in all of them, so it is counted once.
    The warning therefore only fires for genuinely different classes that share a name.

    :param target_class_name: Name of the class to search for.
    :return: The resolved class with the matching name.

    :raises ValueError: Raised when no class with the specified name can be found.
    """
    found_classes = []

    def collect(namespace_values) -> None:
        # Compare by identity: the same class re-exported by many modules must be recorded only once.
        for obj in namespace_values:
            if (inspect.isclass(obj) and obj.__name__ == target_class_name
                    and not any(obj is known for known in found_classes)):
                found_classes.append(obj)

    # Search 1: In the current module's globals()
    collect(list(globals().values()))

    # Search 2: In all loaded modules (via sys.modules).
    # Both mappings are snapshotted with list(...). A concurrent or lazy import during the scan could
    # otherwise raise "dictionary changed size during iteration".
    for module in list(sys.modules.values()):
        if module is None or not hasattr(module, '__dict__'):
            continue  # Skip placeholder entries or modules without a __dict__
        collect(list(module.__dict__.values()))

    if not found_classes:
        raise ValueError(f"Could not find any class with name {target_class_name} in globals or sys.modules.")
    if len(found_classes) > 1:
        warn_multiple_classes(target_class_name, tuple(found_classes))
    return found_classes[0]
173
+
174
+
175
@lru_cache(maxsize=None)
def warn_multiple_classes(target_class_name, found_classes):
    """
    Log, through the root logger, that a class name matched several classes.

    The call is memoised with ``lru_cache``, so each distinct ``(target_class_name, found_classes)`` pair
    is logged only once. Both arguments must therefore be hashable; callers pass ``found_classes`` as a
    tuple.
    """
    message = f"Found multiple classes with name {target_class_name}. Found classes: {found_classes} "
    logging.warning(message)
@@ -0,0 +1,208 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import uuid
5
+ from dataclasses import dataclass, field, Field, fields
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+
9
+ import pydot
10
+ import rustworkx as rx
11
+ from typing_extensions import Any, TYPE_CHECKING, Type, final, ClassVar, Dict, List, Optional, Tuple
12
+
13
+ from .field_info import FieldInfo
14
+ from .. import logger
15
+ from ..rules import Rule
16
+ from ..utils import recursive_subclasses
17
+
18
+
19
class Direction(Enum):
    """
    Two-valued direction flag whose members carry boolean values.

    ``OUTBOUND`` maps to ``False`` and ``INBOUND`` maps to ``True``.
    """

    OUTBOUND = False
    INBOUND = True
22
+
23
+
24
class Relation(str, Enum):
    """
    Labels for edges in the class dependency graph. Members are also plain strings.

    ``has`` labels composition edges, ``isA`` labels inheritance edges, and ``dependsOn`` labels
    dependency edges.
    """

    has = "has"
    isA = "isA"
    dependsOn = "dependsOn"
28
+
29
+
30
@dataclass(unsafe_hash=True)
class TrackedObjectMixin:
    """
    A class that is used as a base class to all classes that needs to be tracked for RDR inference, and reasoning.

    All subclasses share one class-level registry: ``_dependency_graph``, ``_class_graph_indices``,
    ``_composition_edges`` and ``_inheritance_edges``. That registry lives on :class:`TrackedObjectMixin`
    itself.
    """
    _rdr_tracked_object_id: int = field(init=False, repr=False, hash=False,
                                        compare=False, default_factory=lambda: uuid.uuid4().int)
    """
    The unique identifier of the conclusion.
    """
    _dependency_graph: ClassVar[rx.PyDAG[Type[TrackedObjectMixin]]] = rx.PyDAG()
    """
    A graph that represents the relationships between all tracked objects.
    """
    _class_graph_indices: ClassVar[Dict[Type[TrackedObjectMixin], int]] = {}
    """
    Maps every registered class to its node index in the dependency graph.
    """
    _composition_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent composition relations between objects (Relation.has).
    """
    _inheritance_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent inheritance relations between objects (Relation.isA).
    """
    _overridden_by: Type[TrackedObjectMixin] = field(init=False, repr=False, hash=False,
                                                     compare=False, default=None)
    """
    Whether the class has been overridden by a subclass.
    This is used to only include the new class in the dependency graph, not the overridden class.
    """

    @classmethod
    def _reset_dependency_graph(cls) -> None:
        """
        Reset the dependency graph and all class indices.

        The fresh containers are bound on TrackedObjectMixin rather than on ``cls``. ``parse_field`` reads
        and writes the registry through ``TrackedObjectMixin`` directly. A reset made through a subclass
        must not leave that subclass with a private registry that shadows the shared one.
        """
        TrackedObjectMixin._dependency_graph = rx.PyDAG()
        TrackedObjectMixin._class_graph_indices = {}
        TrackedObjectMixin._composition_edges = []
        TrackedObjectMixin._inheritance_edges = []

    @classmethod
    def _my_graph_idx(cls):
        """Return the node index of ``cls`` in the dependency graph; raises KeyError if it is not registered."""
        return cls._class_graph_indices[cls]

    @classmethod
    def make_class_dependency_graph(cls, composition: bool = True):
        """
        Create a direct acyclic graph containing the class hierarchy.

        Registers every subclass of ``cls``, plus ``Rule`` and its subclasses. Adds ``Relation.isA`` edges
        from each registered class to its registered bases, then, optionally, ``Relation.has`` composition
        edges derived from the dataclass fields.

        :param composition: If True, the class dependency graph will include composition relations.
        """
        classes_to_track = recursive_subclasses(cls) + [Rule] + recursive_subclasses(Rule)
        for clazz in classes_to_track:
            cls._add_class_to_dependency_graph(clazz)

        for clazz in list(cls._class_graph_indices):
            bases = [base for base in clazz.__bases__ if
                     base.__module__ not in ["builtins"] and base in cls._class_graph_indices]

            for base in bases:
                cls._add_class_to_dependency_graph(base)
                clazz_idx = cls._class_graph_indices[clazz]
                base_idx = cls._class_graph_indices[base]
                # Skip duplicate edges and the self-edge between an overridden class and its override.
                if (clazz_idx, base_idx) in cls._inheritance_edges or base._overridden_by == clazz:
                    continue
                cls._dependency_graph.add_edge(clazz_idx, base_idx, Relation.isA)
                cls._inheritance_edges.append((clazz_idx, base_idx))

        if not composition:
            return
        # parse_fields may register new field classes (growing _class_graph_indices), so iterate over a
        # snapshot; iterating the live dict raises "dictionary changed size during iteration".
        for clazz in list(cls._class_graph_indices):
            if clazz.__module__ == "builtins":
                continue
            TrackedObjectMixin.parse_fields(clazz)

    @staticmethod
    def parse_fields(clazz) -> None:
        """
        Add composition edges for the public fields that ``clazz`` declares itself.

        Fields inherited from TrackedObjectMixin bases are excluded (see ``get_fields``). Fields whose name
        starts with ``_`` are skipped.
        """
        for f in TrackedObjectMixin.get_fields(clazz):

            logger.debug("=" * 80)
            logger.debug(f"Processing Field {clazz.__name__}.{f.name}: {f.type}.")

            # skip private fields
            if f.name.startswith("_"):
                logger.debug(f"Skipping since the field starts with _.")
                continue

            field_info = FieldInfo(clazz, f)
            TrackedObjectMixin.parse_field(field_info)

    @staticmethod
    def get_fields(clazz) -> List[Field]:
        """
        Return the dataclass fields of ``clazz``, excluding fields inherited from any TrackedObjectMixin base.

        ``fields(base)`` already contains everything ``base`` inherits, so subtracting it handles inheritance
        chains of any depth. Fields are compared by identity: dataclasses reuse the same Field objects for
        inherited fields. A field redefined in ``clazz`` is a new object and is therefore kept.
        """
        inherited_field_ids = {id(base_field)
                               for base in clazz.__bases__ if issubclass(base, TrackedObjectMixin)
                               for base_field in fields(base)}
        return [cls_field for cls_field in fields(clazz) if id(cls_field) not in inherited_field_ids]

    @staticmethod
    def parse_field(field_info: FieldInfo):
        """
        Add a ``Relation.has`` edge from the field's owner class to the field's type.

        The edge is only added when that type is a single TrackedObjectMixin subclass. The field class is
        registered in the graph if needed, and an existing identical edge is not added again.
        """
        parent_idx = TrackedObjectMixin._class_graph_indices[field_info.clazz]
        field_cls: Optional[Type[TrackedObjectMixin]] = None
        field_relation = Relation.has
        field_type = field_info.type[0] if len(field_info.type) == 1 else None
        # isclass guard: hints such as Literal[...] or Type[Union[...]] produce non-class args,
        # and issubclass would raise TypeError on those.
        if inspect.isclass(field_type) and issubclass(field_type, TrackedObjectMixin):
            field_cls = field_type
        else:
            # TODO: Create a new TrackedObjectMixin class for new type
            logger.debug(f"Skipping unhandled field type: {field_info.type}")
            # logger.debug(f"Creating new TrackedObject type for builtin type {field_info.type}.")
            # field_cls = cls._create_tracked_object_class_for_field(field_info)

        if field_cls is not None:
            field_cls_idx = TrackedObjectMixin._class_graph_indices.get(field_cls, None)
            if field_cls_idx is not None and (parent_idx, field_cls_idx) in TrackedObjectMixin._composition_edges:
                return
            elif field_cls_idx is None:
                TrackedObjectMixin._add_class_to_dependency_graph(field_cls)
                field_cls_idx = TrackedObjectMixin._class_graph_indices[field_cls]
            TrackedObjectMixin._dependency_graph.add_edge(parent_idx, field_cls_idx, field_relation)
            TrackedObjectMixin._composition_edges.append((parent_idx, field_cls_idx))

    @classmethod
    def _create_tracked_object_class_for_field(cls, field_info: FieldInfo):
        """Not implemented yet; placeholder for wrapping non-tracked field types."""
        raise NotImplementedError

    @classmethod
    def to_dot(cls, filepath: str, format='png') -> None:
        """
        Render the dependency graph with pydot and write it to ``filepath``.

        :param filepath: Output path. ``.<format>`` is appended if the path does not already end with it.
        :param format: Output format passed to pydot's ``write`` (default ``'png'``).
        """
        if not filepath.endswith(f".{format}"):
            filepath += f".{format}"
        dot_str = cls._dependency_graph.to_dot(
            lambda node: dict(
                color='black', fillcolor='lightblue', style='filled', label=node.__name__),
            lambda edge: dict(color='black', style='solid', label=edge))
        dot = pydot.graph_from_dot_data(dot_str)[0]
        dot.write(filepath, format=format)

    @classmethod
    def _add_class_to_dependency_graph(cls, class_to_add: Type[TrackedObjectMixin]) -> None:
        """
        Add a class to the dependency graph, unless it is already registered.

        If the class has ``_overridden_by`` set, the node stores the overriding class, and both classes map
        to the same index. Classes that are not TrackedObjectMixin subclasses get ``_overridden_by = None``
        assigned on them, so that the attribute reads here and in make_class_dependency_graph succeed.
        """
        if class_to_add not in cls._class_graph_indices:
            if not issubclass(class_to_add, TrackedObjectMixin):
                class_to_add._overridden_by = None
            cls_idx = cls._dependency_graph.add_node(class_to_add._overridden_by or class_to_add)
            cls._class_graph_indices[class_to_add] = cls_idx
            if class_to_add._overridden_by:
                cls._class_graph_indices[class_to_add._overridden_by] = cls_idx

    def __getattribute__(self, name: str) -> Any:
        # Currently a pass-through to object.__getattribute__; dependency recording below is disabled.
        # if name not in [f.name for f in fields(TrackedObjectMixin)] + ['has', 'is_a', 'depends_on']\
        #         and not name.startswith("_"):
        #     self._record_dependency(name)
        return object.__getattribute__(self, name)

    def _record_dependency(self, attr_name):
        """
        Walk the call stack and log when a ``CallableExpression.__call__`` frame is found.

        ``attr_name`` is not used yet.
        """
        # Inspect stack to find instance of CallableExpression
        # NOTE(review): the module is compared against the bare name "callable_expression"; a fully
        # qualified package path would not match - confirm this is intended.
        for frame_info in inspect.stack():
            func_name = frame_info.function
            local_self = frame_info.frame.f_locals.get("self", None)
            if (
                    func_name == "__call__" and
                    local_self is not None and
                    type(local_self).__module__ == "callable_expression" and
                    type(local_self).__name__ == "CallableExpression"
            ):
                logger.debug("TrackedObject used inside CallableExpression")
                break
204
+
205
+
206
# Strip the names of TrackedObjectMixin's own dataclass fields (_rdr_tracked_object_id, _overridden_by)
# from the annotations dict returned by TrackedObjectMixin.__annotations__, mutating it in place. The
# fields stay registered as dataclass fields; only their annotation entries are removed.
# NOTE(review): presumably this hides the private bookkeeping fields from annotation-based introspection
# of the class - confirm the intent.
# NOTE(review): this rebinds the module-level name `annotations`, which otherwise refers to the
# `from __future__ import annotations` feature object, and leaves the loop variable `val` bound at
# module level.
annotations = TrackedObjectMixin.__annotations__
for val in [f.name for f in fields(TrackedObjectMixin)]:
    annotations.pop(val, None)