ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +12 -4
- ripple_down_rules/datastructures/dataclasses.py +52 -9
- ripple_down_rules/datastructures/enums.py +14 -87
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/helpers.py +37 -2
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +10 -6
- ripple_down_rules/rdr_decorators.py +44 -34
- ripple_down_rules/rules.py +166 -97
- ripple_down_rules/user_interface/ipython_custom_shell.py +9 -1
- ripple_down_rules/user_interface/prompt.py +37 -37
- ripple_down_rules/user_interface/template_file_creator.py +3 -0
- ripple_down_rules/utils.py +32 -5
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/METADATA +3 -1
- ripple_down_rules-0.6.60.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.51.dist-info/RECORD +0 -25
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/top_level.txt +0 -0
ripple_down_rules/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "0.6.
|
1
|
+
__version__ = "0.6.60"
|
2
2
|
|
3
3
|
import logging
|
4
4
|
import sys
|
@@ -12,6 +12,14 @@ try:
|
|
12
12
|
except ImportError:
|
13
13
|
pass
|
14
14
|
|
15
|
-
|
16
|
-
|
17
|
-
|
15
|
+
|
16
|
+
# Trigger patch
# Eagerly expose the main public API at package level. An ImportError here (for example a missing
# optional dependency) must not make ``import ripple_down_rules`` itself fail, so this stays
# best-effort -- but the failure is now logged instead of being silently discarded.
try:
    from .predicates import *
    from .datastructures.tracked_object import TrackedObjectMixin
    from .datastructures.dataclasses import CaseQuery
    from .rdr_decorators import RDRDecorator
    from .rdr import MultiClassRDR, SingleClassRDR, GeneralRDR
except ImportError as _import_error:
    logging.getLogger(__name__).debug("Could not eagerly import the ripple_down_rules public API: %s",
                                      _import_error)

# Optional overrides provided by the ``ripple_down_rules_meta`` package; it is treated as optional,
# so its absence is ignored.
try:
    import ripple_down_rules_meta._apply_overrides
except ImportError:
    pass
|
@@ -12,7 +12,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, Set
|
|
12
12
|
from .callable_expression import CallableExpression
|
13
13
|
from .case import create_case, Case
|
14
14
|
from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, render_tree, \
|
15
|
-
get_function_representation
|
15
|
+
get_function_representation, get_method_object_from_pytest_request
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
from ..rdr import RippleDownRules
|
@@ -60,9 +60,14 @@ class CaseQuery:
|
|
60
60
|
The executable scenario is the root callable that recreates the situation that the case is
|
61
61
|
created in, for example, when the case is created from a test function, this would be the test function itself.
|
62
62
|
"""
|
63
|
+
this_case_target_value: Optional[Any] = None
|
64
|
+
"""
|
65
|
+
The non relational case query instance target value.
|
66
|
+
"""
|
63
67
|
_target: Optional[CallableExpression] = None
|
64
68
|
"""
|
65
|
-
The target
|
69
|
+
The relational target (the evaluatable conclusion of the rule) which is a callable expression that varies with
|
70
|
+
the case.
|
66
71
|
"""
|
67
72
|
default_value: Optional[Any] = None
|
68
73
|
"""
|
@@ -306,6 +311,12 @@ class CaseFactoryMetaData:
|
|
306
311
|
factory_idx: Optional[int] = None
|
307
312
|
case_conf: Optional[CaseConf] = None
|
308
313
|
scenario: Optional[Callable] = None
|
314
|
+
pytest_request: Optional[Callable] = field(hash=False, compare=False, default=None)
|
315
|
+
this_case_target_value: Optional[Any] = None
|
316
|
+
|
317
|
+
def __post_init__(self):
    """
    Fill in ``scenario`` from ``pytest_request`` when a request was given but no scenario.
    """
    scenario_missing = self.scenario is None
    if scenario_missing and self.pytest_request is not None:
        self.scenario = get_method_object_from_pytest_request(self.pytest_request)
|
309
320
|
|
310
321
|
@classmethod
|
311
322
|
def from_case_query(cls, case_query: CaseQuery) -> CaseFactoryMetaData:
|
@@ -322,8 +333,9 @@ class CaseFactoryMetaData:
|
|
322
333
|
return (f"CaseFactoryMetaData("
|
323
334
|
f"factory_method={factory_method_repr}, "
|
324
335
|
f"factory_idx={self.factory_idx}, "
|
325
|
-
f"case_conf={self.case_conf},"
|
326
|
-
f"
|
336
|
+
f"case_conf={self.case_conf}, "
|
337
|
+
f"scenario={scenario_repr}, "
|
338
|
+
f"this_case_target_value={self.this_case_target_value})")
|
327
339
|
|
328
340
|
def __str__(self):
|
329
341
|
return self.__repr__()
|
@@ -335,26 +347,57 @@ class RDRConclusion:
|
|
335
347
|
This dataclass represents a conclusion of a Ripple Down Rule.
|
336
348
|
It contains the conclusion expression, the type of the conclusion, and the scope in which it is evaluated.
|
337
349
|
"""
|
338
|
-
|
350
|
+
_conclusion: Any
|
339
351
|
"""
|
340
352
|
The conclusion value.
|
341
353
|
"""
|
342
|
-
|
354
|
+
_frozen_case: Any
|
343
355
|
"""
|
344
356
|
The frozen case that the conclusion was made for.
|
345
357
|
"""
|
346
|
-
|
358
|
+
_rule: Rule
|
347
359
|
"""
|
348
360
|
The rule that gave this conclusion.
|
349
361
|
"""
|
350
|
-
|
362
|
+
_rdr: RippleDownRules
|
351
363
|
"""
|
352
364
|
The Ripple Down Rules that classified the case and produced this conclusion.
|
353
365
|
"""
|
354
|
-
|
366
|
+
_id: int = field(default_factory=lambda: uuid.uuid4().int)
|
355
367
|
"""
|
356
368
|
The unique identifier of the conclusion.
|
357
369
|
"""
|
370
|
+
def __getattribute__(self, name: str) -> Any:
    """
    Proxy public attribute reads to the wrapped conclusion.

    Names starting with an underscore (the wrapper's own fields and helpers) are resolved on the
    wrapper itself. Any other name is read from ``_conclusion`` and then reported to
    ``_record_dependency`` before the value is returned.
    """
    if not name.startswith('_'):
        wrapped_conclusion = object.__getattribute__(self, "_conclusion")
        attribute_value = getattr(wrapped_conclusion, name)
        self._record_dependency(name)
        return attribute_value
    return object.__getattribute__(self, name)
|
381
|
+
|
382
|
+
def __setattr__(self, name: str, value: Any) -> None:
    """
    Set private attributes on the wrapper and forward public ones to the wrapped conclusion.

    :param name: The attribute name. Names starting with ``_`` belong to the wrapper itself.
    :param value: The value to assign.
    """
    if name.startswith('_'):
        object.__setattr__(self, name, value)
    else:
        # The wrapped object is stored in ``_conclusion``; there is no ``_wrapped`` attribute, so the
        # previous ``setattr(self._wrapped, ...)`` always raised AttributeError.
        setattr(object.__getattribute__(self, "_conclusion"), name, value)
|
387
|
+
|
388
|
+
def _record_dependency(self, attr_name):
    """
    Mark this conclusion as used if it is being accessed from inside ``CallableExpression.__call__``.

    Walks the call stack from the current frame outwards. When a ``__call__`` frame whose ``self`` is
    exactly a ``CallableExpression`` is found, sets ``_used_in_tracker`` to True.

    :param attr_name: The accessed attribute name (currently not used by the check).
    """
    import inspect
    import logging

    # Walk frame objects directly: inspect.stack() also reads source-code context for every frame,
    # which is very expensive for something executed on every public attribute access.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            local_self = frame.f_locals.get("self", None)
            if (
                frame.f_code.co_name == "__call__" and
                local_self is not None and
                type(local_self) is CallableExpression
            ):
                self._used_in_tracker = True
                # Library code should not print to stdout; report through logging instead.
                logging.getLogger(__name__).debug("RDRConclusion used inside CallableExpression")
                break
            frame = frame.f_back
    finally:
        # Drop the frame reference to avoid keeping a reference cycle alive.
        del frame
|
358
401
|
|
359
402
|
def __hash__(self):
    """
    Hash the conclusion wrapper by its own unique identifier.
    """
    # The identifier field is now ``_id``. Reading ``self.id`` goes through __getattribute__, which
    # forwards every non-underscore name to the wrapped conclusion. That either raises AttributeError
    # or hashes the conclusion's attribute instead of this wrapper's identity.
    # NOTE(review): keep any __eq__ defined elsewhere on this class consistent with this hash.
    return hash(object.__getattribute__(self, "_id"))
|
@@ -93,20 +93,6 @@ class Stop(Category):
|
|
93
93
|
stop = "stop"
|
94
94
|
|
95
95
|
|
96
|
-
class ExpressionParser(Enum):
|
97
|
-
"""
|
98
|
-
Parsers for expressions to evaluate and encapsulate the expression into a callable function.
|
99
|
-
"""
|
100
|
-
ASTVisitor: int = auto()
|
101
|
-
"""
|
102
|
-
Generic python Abstract Syntax Tree that detects variables, attributes, binary/boolean expressions , ...etc.
|
103
|
-
"""
|
104
|
-
SQLAlchemy: int = auto()
|
105
|
-
"""
|
106
|
-
Specific for SQLAlchemy expressions on ORM Tables.
|
107
|
-
"""
|
108
|
-
|
109
|
-
|
110
96
|
class PromptFor(Enum):
|
111
97
|
"""
|
112
98
|
The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
|
@@ -131,51 +117,6 @@ class PromptFor(Enum):
|
|
131
117
|
return self.__str__()
|
132
118
|
|
133
119
|
|
134
|
-
class CategoricalValue(Enum):
|
135
|
-
"""
|
136
|
-
A categorical value is a value that is a category.
|
137
|
-
"""
|
138
|
-
|
139
|
-
def __eq__(self, other):
|
140
|
-
if isinstance(other, CategoricalValue):
|
141
|
-
return self.name == other.name
|
142
|
-
elif isinstance(other, str):
|
143
|
-
return self.name == other
|
144
|
-
return self.name == other
|
145
|
-
|
146
|
-
def __hash__(self):
|
147
|
-
return hash(self.name)
|
148
|
-
|
149
|
-
@classmethod
|
150
|
-
def to_list(cls):
|
151
|
-
return list(cls._value2member_map_.keys())
|
152
|
-
|
153
|
-
@classmethod
|
154
|
-
def from_str(cls, category: str):
|
155
|
-
return cls[category.lower()]
|
156
|
-
|
157
|
-
@classmethod
|
158
|
-
def from_strs(cls, categories: List[str]):
|
159
|
-
return [cls.from_str(c) for c in categories]
|
160
|
-
|
161
|
-
def __str__(self):
|
162
|
-
return self.name
|
163
|
-
|
164
|
-
def __repr__(self):
|
165
|
-
return self.__str__()
|
166
|
-
|
167
|
-
|
168
|
-
class RDRMode(Enum):
|
169
|
-
Propositional = auto()
|
170
|
-
"""
|
171
|
-
Propositional mode, the mode where the rules are propositional.
|
172
|
-
"""
|
173
|
-
Relational = auto()
|
174
|
-
"""
|
175
|
-
Relational mode, the mode where the rules are relational.
|
176
|
-
"""
|
177
|
-
|
178
|
-
|
179
120
|
class MCRDRMode(Enum):
|
180
121
|
"""
|
181
122
|
The modes of the MultiClassRDR.
|
@@ -213,33 +154,19 @@ class RDREdge(Enum):
|
|
213
154
|
"""
|
214
155
|
Filter edge, the edge that represents the filter condition.
|
215
156
|
"""
|
216
|
-
|
217
|
-
class ValueType(Enum):
|
218
|
-
Unary = auto()
|
219
|
-
"""
|
220
|
-
Unary value type (eg. null).
|
221
|
-
"""
|
222
|
-
Binary = auto()
|
223
|
-
"""
|
224
|
-
Binary value type (eg. True, False).
|
225
|
-
"""
|
226
|
-
Discrete = auto()
|
227
|
-
"""
|
228
|
-
Discrete value type (eg. 1, 2, 3).
|
157
|
+
Empty = ""
|
229
158
|
"""
|
230
|
-
|
231
|
-
"""
|
232
|
-
Continuous value type (eg. 1.0, 2.5, 3.4).
|
233
|
-
"""
|
234
|
-
Nominal = auto()
|
235
|
-
"""
|
236
|
-
Nominal value type (eg. red, blue, green), categories where the values have no natural order.
|
237
|
-
"""
|
238
|
-
Ordinal = auto()
|
239
|
-
"""
|
240
|
-
Ordinal value type (eg. low, medium, high), categories where the values have a natural order.
|
241
|
-
"""
|
242
|
-
Iterable = auto()
|
243
|
-
"""
|
244
|
-
Iterable value type (eg. [1, 2, 3]).
|
159
|
+
Empty edge, used for example for the root/input node of the tree.
|
245
160
|
"""
|
161
|
+
|
162
|
+
@classmethod
def from_value(cls, value: str) -> RDREdge:
    """
    Look up the RDREdge member whose value equals ``value``.

    :param value: The string that represents the edge type.
    :return: The matching RDREdge member.
    :raises ValueError: If no member has this value.
    """
    members_by_value = cls._value2member_map_
    try:
        return members_by_value[value]
    except KeyError:
        raise ValueError(f"RDREdge {value} is not supported.") from None
|
@@ -0,0 +1,177 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import enum
|
4
|
+
import importlib
|
5
|
+
import inspect
|
6
|
+
import logging
|
7
|
+
import sys
|
8
|
+
import typing
|
9
|
+
from dataclasses import dataclass, Field
|
10
|
+
from datetime import datetime
|
11
|
+
from functools import lru_cache
|
12
|
+
from types import NoneType
|
13
|
+
|
14
|
+
from typing_extensions import Type, get_origin, Optional, get_type_hints, Tuple
|
15
|
+
|
16
|
+
from ripple_down_rules.utils import make_tuple
|
17
|
+
|
18
|
+
|
19
|
+
class ParseError(TypeError):
    """
    Raised when the parser meets something it can or should not parse (Union types, for example).
    """
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
|
29
|
+
class FieldInfo:
|
30
|
+
"""
|
31
|
+
A class that wraps a field of dataclass and provides some utility functions.
|
32
|
+
"""
|
33
|
+
|
34
|
+
clazz: Type
|
35
|
+
"""
|
36
|
+
The class that the field is in.
|
37
|
+
"""
|
38
|
+
|
39
|
+
name: str
|
40
|
+
"""
|
41
|
+
The name of the field.
|
42
|
+
"""
|
43
|
+
|
44
|
+
type: Tuple[Type, ...]
|
45
|
+
"""
|
46
|
+
The type of the field or inner type of the container if it is a container.
|
47
|
+
"""
|
48
|
+
|
49
|
+
optional: bool
|
50
|
+
"""
|
51
|
+
True if the field is optional, False otherwise.
|
52
|
+
"""
|
53
|
+
|
54
|
+
container: Optional[Type]
|
55
|
+
"""
|
56
|
+
The type of the container if it is one (list, set, tuple, etc.). If there is no container this is None
|
57
|
+
"""
|
58
|
+
|
59
|
+
is_type_field: bool = False
|
60
|
+
"""
|
61
|
+
The field value is a type.
|
62
|
+
"""
|
63
|
+
|
64
|
+
field: Field = None
|
65
|
+
"""
|
66
|
+
The field instance.
|
67
|
+
"""
|
68
|
+
|
69
|
+
def __init__(self, clazz: Type, f: Field):
|
70
|
+
self.field = f
|
71
|
+
self.name = f.name
|
72
|
+
self.clazz = clazz
|
73
|
+
|
74
|
+
try:
|
75
|
+
type_hints = get_type_hints(clazz)[self.name]
|
76
|
+
except NameError as e:
|
77
|
+
found_clazz = manually_search_for_class_name(e.name)
|
78
|
+
module = importlib.import_module(found_clazz.__module__)
|
79
|
+
locals()[e.name] = getattr(module, e.name)
|
80
|
+
type_hints = get_type_hints(clazz, localns=locals())[self.name]
|
81
|
+
type_args = typing.get_args(type_hints)
|
82
|
+
|
83
|
+
# try to unpack the type if it is a nested type
|
84
|
+
if len(type_args) > 0:
|
85
|
+
self.optional = NoneType in type_args
|
86
|
+
|
87
|
+
if self.optional:
|
88
|
+
self.container = None
|
89
|
+
else:
|
90
|
+
self.container = get_origin(type_hints)
|
91
|
+
|
92
|
+
if not self.optional and type_hints == Type[type_args]:
|
93
|
+
self.is_type_field = True
|
94
|
+
|
95
|
+
self.type = type_args
|
96
|
+
else:
|
97
|
+
self.optional = False
|
98
|
+
self.container = None
|
99
|
+
self.type = make_tuple(type_hints)
|
100
|
+
|
101
|
+
@property
|
102
|
+
def is_builtin_class(self) -> bool:
|
103
|
+
return not self.container and all(t.__module__ == 'builtins' for t in self.type)
|
104
|
+
|
105
|
+
@property
|
106
|
+
def is_container_of_builtin(self) -> bool:
|
107
|
+
return self.container and all(t.__module__ == 'builtins' for t in self.type)
|
108
|
+
|
109
|
+
@property
|
110
|
+
def is_type_type(self) -> bool:
|
111
|
+
return self.is_type_field
|
112
|
+
|
113
|
+
@property
|
114
|
+
def is_enum(self):
|
115
|
+
return len(self.type) == 1 and issubclass(self.type[0], enum.Enum)
|
116
|
+
|
117
|
+
@property
|
118
|
+
def is_datetime(self):
|
119
|
+
return len(self.type) == 1 and self.type[0] == datetime
|
120
|
+
|
121
|
+
|
122
|
+
def is_container(clazz: Type) -> bool:
    """
    Tell whether a type hint is a parametrized list, set or tuple.

    :param clazz: The type hint to inspect.
    :return: True when the hint's origin is ``list``, ``set`` or ``tuple``, False otherwise.
    """
    container_origins = (list, set, tuple)
    origin = get_origin(clazz)
    return origin in container_origins
|
130
|
+
|
131
|
+
|
132
|
+
def manually_search_for_class_name(target_class_name: str) -> Type:
|
133
|
+
"""
|
134
|
+
Searches for a class with the specified name in the current module's `globals()` dictionary
|
135
|
+
and all loaded modules present in `sys.modules`. This function attempts to find and resolve
|
136
|
+
the first class that matches the given name. If multiple classes are found with the same
|
137
|
+
name, a warning is logged, and the first one is returned. If no matching class is found,
|
138
|
+
an exception is raised.
|
139
|
+
|
140
|
+
:param target_class_name: Name of the class to search for.
|
141
|
+
:return: The resolved class with the matching name.
|
142
|
+
|
143
|
+
:raises ValueError: Raised when no class with the specified name can be found.
|
144
|
+
"""
|
145
|
+
found_classes = []
|
146
|
+
|
147
|
+
# Search 1: In the current module's globals()
|
148
|
+
for name, obj in globals().items():
|
149
|
+
if inspect.isclass(obj) and obj.__name__ == target_class_name:
|
150
|
+
found_classes.append(obj)
|
151
|
+
|
152
|
+
# Search 2: In all loaded modules (via sys.modules)
|
153
|
+
for module_name, module in sys.modules.items():
|
154
|
+
if module is None or not hasattr(module, '__dict__'):
|
155
|
+
continue # Skip built-in modules or modules without a __dict__
|
156
|
+
|
157
|
+
for name, obj in module.__dict__.items():
|
158
|
+
if inspect.isclass(obj) and obj.__name__ == target_class_name:
|
159
|
+
# Avoid duplicates if a class is imported into multiple namespaces
|
160
|
+
if (obj, f"from module '{module_name}'") not in found_classes:
|
161
|
+
found_classes.append(obj)
|
162
|
+
|
163
|
+
# If you wanted to "resolve" the forward ref based on this
|
164
|
+
if len(found_classes) == 0:
|
165
|
+
raise ValueError(f"Could not find any class with name {target_class_name} in globals or sys.modules.")
|
166
|
+
elif len(found_classes) == 1:
|
167
|
+
resolved_class = found_classes[0]
|
168
|
+
else:
|
169
|
+
warn_multiple_classes(target_class_name, tuple(found_classes))
|
170
|
+
resolved_class = found_classes[0]
|
171
|
+
|
172
|
+
return resolved_class
|
173
|
+
|
174
|
+
|
175
|
+
@lru_cache(maxsize=None)
def warn_multiple_classes(target_class_name, found_classes):
    """
    Log that a class name is ambiguous.

    Thanks to the cache, the warning is emitted only once per distinct (name, classes) pair.
    """
    message = f"Found multiple classes with name {target_class_name}. Found classes: {found_classes} "
    logging.warning(message)
|
@@ -0,0 +1,208 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import uuid
|
5
|
+
from dataclasses import dataclass, field, Field, fields
|
6
|
+
from enum import Enum
|
7
|
+
from functools import lru_cache
|
8
|
+
|
9
|
+
import pydot
|
10
|
+
import rustworkx as rx
|
11
|
+
from typing_extensions import Any, TYPE_CHECKING, Type, final, ClassVar, Dict, List, Optional, Tuple
|
12
|
+
|
13
|
+
from .field_info import FieldInfo
|
14
|
+
from .. import logger
|
15
|
+
from ..rules import Rule
|
16
|
+
from ..utils import recursive_subclasses
|
17
|
+
|
18
|
+
|
19
|
+
class Direction(Enum):
    """
    Two-valued direction flag: OUTBOUND maps to False, INBOUND maps to True.
    """
    OUTBOUND = False  # value False
    INBOUND = True  # value True
|
22
|
+
|
23
|
+
|
24
|
+
class Relation(str, Enum):
    """
    String-valued labels attached to edges of the class dependency graph.
    """
    has = "has"  # composition
    isA = "isA"  # inheritance
    dependsOn = "dependsOn"
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass(unsafe_hash=True)
class TrackedObjectMixin:
    """
    A class that is used as a base class to all classes that needs to be tracked for RDR inference, and reasoning.

    The dependency-graph bookkeeping below is stored in class-level attributes of ``TrackedObjectMixin`` and is
    shared by every subclass. Graph nodes are classes; edges are labelled with ``Relation.isA`` (inheritance)
    or ``Relation.has`` (composition).
    """
    _rdr_tracked_object_id: int = field(init=False, repr=False, hash=False,
                                        compare=False, default_factory=lambda: uuid.uuid4().int)
    """
    The unique identifier of the tracked object.
    """
    _dependency_graph: ClassVar[rx.PyDAG[Type[TrackedObjectMixin]]] = rx.PyDAG()
    """
    A graph that represents the relationships between all tracked objects.
    """
    _class_graph_indices: ClassVar[Dict[Type[TrackedObjectMixin], int]] = {}
    """
    The index of each tracked class in the dependency graph.
    """
    _composition_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent composition relations between objects (Relation.has).
    """
    _inheritance_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent inheritance relations between objects (Relation.isA).
    """
    _overridden_by: Type[TrackedObjectMixin] = field(init=False, repr=False, hash=False,
                                                     compare=False, default=None)
    """
    The subclass that overrides this class, if any.
    This is used to only include the new class in the dependency graph, not the overridden class.
    """

    @classmethod
    def _reset_dependency_graph(cls) -> None:
        """
        Reset the dependency graph and all class indices.
        """
        # Always reset the shared state on TrackedObjectMixin itself. Assigning on ``cls`` would create
        # shadowing attributes on a subclass, while parse_field() reads TrackedObjectMixin's state,
        # which led to KeyErrors.
        TrackedObjectMixin._dependency_graph = rx.PyDAG()
        TrackedObjectMixin._class_graph_indices = {}
        TrackedObjectMixin._composition_edges = []
        TrackedObjectMixin._inheritance_edges = []

    @classmethod
    def _my_graph_idx(cls):
        """Return the dependency-graph node index of ``cls``."""
        return cls._class_graph_indices[cls]

    @classmethod
    def make_class_dependency_graph(cls, composition: bool = True):
        """
        Create a direct acyclic graph containing the class hierarchy.

        :param composition: If True, the class dependency graph will include composition relations.
        """
        classes_to_track = recursive_subclasses(cls) + [Rule] + recursive_subclasses(Rule)
        for clazz in classes_to_track:
            cls._add_class_to_dependency_graph(clazz)

        for clazz in list(cls._class_graph_indices):
            bases = [base for base in clazz.__bases__ if
                     base.__module__ not in ["builtins"] and base in cls._class_graph_indices]

            for base in bases:
                cls._add_class_to_dependency_graph(base)
                clazz_idx = cls._class_graph_indices[clazz]
                base_idx = cls._class_graph_indices[base]
                if (clazz_idx, base_idx) in cls._inheritance_edges or base._overridden_by == clazz:
                    continue
                cls._dependency_graph.add_edge(clazz_idx, base_idx, Relation.isA)
                cls._inheritance_edges.append((clazz_idx, base_idx))

        if not composition:
            return

        # parse_fields() may register newly discovered field classes in _class_graph_indices. Mutating
        # that dict while iterating it raised RuntimeError, so work on snapshots until no unparsed class
        # is left (newly added classes get their fields parsed too).
        parsed_classes = set()
        while True:
            to_parse = [clazz for clazz in cls._class_graph_indices
                        if clazz not in parsed_classes and clazz.__module__ != "builtins"]
            if not to_parse:
                break
            for clazz in to_parse:
                parsed_classes.add(clazz)
                TrackedObjectMixin.parse_fields(clazz)

    @staticmethod
    def parse_fields(clazz) -> None:
        """Add composition edges for the public fields declared directly on ``clazz``."""
        for f in TrackedObjectMixin.get_fields(clazz):

            logger.debug("=" * 80)
            logger.debug(f"Processing Field {clazz.__name__}.{f.name}: {f.type}.")

            # skip private fields
            if f.name.startswith("_"):
                logger.debug(f"Skipping since the field starts with _.")
                continue

            field_info = FieldInfo(clazz, f)
            TrackedObjectMixin.parse_field(field_info)

    @staticmethod
    def get_fields(clazz) -> List[Field]:
        """
        Return the dataclass fields of ``clazz`` that are not inherited from a TrackedObjectMixin base.

        Fields redeclared on ``clazz`` are new Field objects and are therefore kept.
        """
        # Skip *all* fields of each tracked base (including the ones it inherited itself). The previous
        # recursion only skipped each base's own fields, so grand-base fields reappeared on deeper
        # subclasses.
        inherited_field_ids = {id(base_field)
                               for base in clazz.__bases__ if issubclass(base, TrackedObjectMixin)
                               for base_field in fields(base)}
        return [cls_field for cls_field in fields(clazz) if id(cls_field) not in inherited_field_ids]

    @staticmethod
    def parse_field(field_info: FieldInfo):
        """Add a ``Relation.has`` edge from the field's owner class to the field's tracked class, if any."""
        parent_idx = TrackedObjectMixin._class_graph_indices[field_info.clazz]
        field_cls: Optional[Type[TrackedObjectMixin]] = None
        field_relation = Relation.has
        field_types = field_info.type
        # Guard with isclass(): type arguments may be typing constructs (e.g. List[int]) for which
        # issubclass() raises TypeError.
        if (len(field_types) == 1 and inspect.isclass(field_types[0])
                and issubclass(field_types[0], TrackedObjectMixin)):
            field_cls = field_types[0]
        else:
            # TODO: Create a new TrackedObjectMixin class for new type
            logger.debug(f"Skipping unhandled field type: {field_info.type}")

        if field_cls is None:
            return
        field_cls_idx = TrackedObjectMixin._class_graph_indices.get(field_cls, None)
        if field_cls_idx is not None and (parent_idx, field_cls_idx) in TrackedObjectMixin._composition_edges:
            return
        elif field_cls_idx is None:
            TrackedObjectMixin._add_class_to_dependency_graph(field_cls)
            field_cls_idx = TrackedObjectMixin._class_graph_indices[field_cls]
        TrackedObjectMixin._dependency_graph.add_edge(parent_idx, field_cls_idx, field_relation)
        TrackedObjectMixin._composition_edges.append((parent_idx, field_cls_idx))

    @classmethod
    def _create_tracked_object_class_for_field(cls, field_info: FieldInfo):
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    def to_dot(cls, filepath: str, format='png') -> None:
        """
        Render the dependency graph with pydot and write it to ``filepath``.

        :param filepath: Output path; ``.<format>`` is appended when missing.
        :param format: The output format passed to pydot.
        """
        if not filepath.endswith(f".{format}"):
            filepath += f".{format}"
        dot_str = cls._dependency_graph.to_dot(
            lambda node: dict(
                color='black', fillcolor='lightblue', style='filled', label=node.__name__),
            lambda edge: dict(color='black', style='solid', label=edge))
        dot = pydot.graph_from_dot_data(dot_str)[0]
        dot.write(filepath, format=format)

    @classmethod
    def _add_class_to_dependency_graph(cls, class_to_add: Type[TrackedObjectMixin]) -> None:
        """
        Add a class to the dependency graph.

        If the class is overridden, the overriding class is stored as the node and both classes map to
        the same index. Non-tracked classes (e.g. Rule) get ``_overridden_by = None`` set on them.
        """
        if class_to_add not in cls._class_graph_indices:
            if not issubclass(class_to_add, TrackedObjectMixin):
                class_to_add._overridden_by = None
            cls_idx = cls._dependency_graph.add_node(class_to_add._overridden_by or class_to_add)
            cls._class_graph_indices[class_to_add] = cls_idx
            if class_to_add._overridden_by:
                cls._class_graph_indices[class_to_add._overridden_by] = cls_idx

    def __getattribute__(self, name: str) -> Any:
        # NOTE: hook point for dependency recording; calling self._record_dependency(name) for public,
        # non-mixin attributes is currently disabled, so this just performs the default lookup.
        return object.__getattribute__(self, name)

    def _record_dependency(self, attr_name):
        """
        Detect whether this object is accessed from inside ``CallableExpression.__call__``.

        Currently this only emits a debug log when such a frame is found; ``attr_name`` is not used yet.
        """
        # Walk frame objects directly instead of inspect.stack(), which also loads source context
        # for every frame.
        frame = inspect.currentframe()
        try:
            while frame is not None:
                local_self = frame.f_locals.get("self", None)
                owner_type = type(local_self)
                # Compare only the last module component: the class is defined in a package
                # sub-module, so its __module__ is package-qualified and an exact comparison with
                # "callable_expression" never matched.
                if (
                    frame.f_code.co_name == "__call__" and
                    local_self is not None and
                    owner_type.__module__.rpartition(".")[2] == "callable_expression" and
                    owner_type.__name__ == "CallableExpression"
                ):
                    logger.debug("TrackedObject used inside CallableExpression")
                    break
                frame = frame.f_back
        finally:
            # Drop the frame reference to avoid keeping a reference cycle alive.
            del frame
|
204
|
+
|
205
|
+
|
206
|
+
def _strip_own_field_annotations(clazz) -> None:
    """
    Remove the names of ``clazz``'s dataclass fields from ``clazz.__annotations__`` (in place).

    Done in a function so the module does not rebind the global name ``annotations`` (which shadowed
    the ``from __future__ import annotations`` binding) or leak the loop variable.
    """
    own_annotations = clazz.__annotations__
    for field_name in [f.name for f in fields(clazz)]:
        own_annotations.pop(field_name, None)


_strip_own_field_annotations(TrackedObjectMixin)
|
ripple_down_rules/helpers.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import importlib
|
3
4
|
import os
|
4
5
|
import sys
|
6
|
+
from functools import wraps
|
5
7
|
from types import ModuleType
|
6
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple, Callable, Dict, Any, Optional
|
7
9
|
|
8
10
|
from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
|
9
11
|
|
10
12
|
from .datastructures.case import create_case, Case
|
11
13
|
from .datastructures.dataclasses import CaseQuery
|
12
|
-
from .utils import calculate_precision_and_recall
|
14
|
+
from .utils import calculate_precision_and_recall, get_method_args_as_dict, get_func_rdr_model_name
|
13
15
|
from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
|
14
16
|
|
15
17
|
if TYPE_CHECKING:
|
@@ -127,3 +129,36 @@ def enable_gui():
|
|
127
129
|
viewer = RDRCaseViewer()
|
128
130
|
except ImportError:
|
129
131
|
pass
|
132
|
+
|
133
|
+
|
134
|
+
def create_case_from_method(func: Callable,
                            func_output: Dict[str, Any],
                            *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
    """
    Build a Case describing one call of ``func``.

    The case attributes are the call's arguments (as returned by ``get_method_args_as_dict``) merged
    with the entries of ``func_output``.

    :param func: The function that was called.
    :param func_output: Mapping from output name to output value; merged over the arguments.
    :param args: Positional arguments of the call.
    :param kwargs: Keyword arguments of the call.
    :return: The created Case and the attribute dictionary it was built from.
    """
    attributes = get_method_args_as_dict(func, *args, **kwargs)
    attributes.update(func_output)
    case_name = get_func_rdr_model_name(func)
    case = Case(dict, id(attributes), case_name, attributes, **attributes)
    return case, attributes
|
150
|
+
|
151
|
+
|
152
|
+
class MockRDRDecorator:
    """
    Lightweight stand-in for an RDR decorator.

    It runs the decorated function, builds a case from the call, and classifies that case with the
    ``classify`` function of a generated ``<model_name>_rdr.py`` file found under ``models_dir``.
    """

    def __init__(self, models_dir: str):
        """
        :param models_dir: Directory containing the generated RDR model directories.
        """
        self.models_dir = models_dir

    def decorator(self, func: Callable) -> Callable:
        """
        Wrap ``func`` so that each call returns the RDR classification of the call instead of its output.

        The generated model module is loaded on the first call and reused afterwards.
        """
        loaded_modules = {}

        @wraps(func)
        def wrapper(*args, **kwargs) -> Optional[Any]:
            if "rdr" not in loaded_modules:
                loaded_modules["rdr"] = self._load_rdr_module(func)
            rdr = loaded_modules["rdr"]
            func_output = {"output_": func(*args, **kwargs)}
            case, _ = create_case_from_method(func, func_output, *args, **kwargs)
            return rdr.classify(case)

        return wrapper

    def _load_rdr_module(self, func: Callable):
        """
        Load the generated ``<model_name>_rdr.py`` module of ``func`` from its file path.

        ``importlib.import_module`` expects a dotted module name, not a filesystem path ending in ``.py``,
        so the previous call could never succeed; the module is loaded by file location instead.
        NOTE(review): relative imports inside the generated file are not supported this way -- confirm
        the generated module is self-contained.
        """
        import importlib.util

        model_dir = get_func_rdr_model_name(func, include_file_name=True)
        model_name = get_func_rdr_model_name(func, include_file_name=False)
        module_name = f"{model_name}_rdr"
        file_path = os.path.join(self.models_dir, model_dir, f"{module_name}.py")
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load RDR model module from {file_path}")
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as in the importlib "import a source file directly" recipe.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
|