ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- __version__ = "0.6.51"
1
+ __version__ = "0.6.60"
2
2
 
3
3
  import logging
4
4
  import sys
@@ -12,6 +12,14 @@ try:
12
12
  except ImportError:
13
13
  pass
14
14
 
15
- from .datastructures.dataclasses import CaseQuery
16
- from .rdr_decorators import RDRDecorator
17
- from .rdr import MultiClassRDR, SingleClassRDR, GeneralRDR
15
+
16
+ # Trigger patch
17
+ try:
18
+ from .predicates import *
19
+ from .datastructures.tracked_object import TrackedObjectMixin
20
+ from .datastructures.dataclasses import CaseQuery
21
+ from .rdr_decorators import RDRDecorator
22
+ from .rdr import MultiClassRDR, SingleClassRDR, GeneralRDR
23
+ import ripple_down_rules_meta._apply_overrides
24
+ except ImportError:
25
+ pass
@@ -12,7 +12,7 @@ from typing_extensions import Any, Optional, Dict, Type, Tuple, Union, List, Set
12
12
  from .callable_expression import CallableExpression
13
13
  from .case import create_case, Case
14
14
  from ..utils import copy_case, make_list, make_set, get_origin_and_args_from_type_hint, render_tree, \
15
- get_function_representation
15
+ get_function_representation, get_method_object_from_pytest_request
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from ..rdr import RippleDownRules
@@ -60,9 +60,14 @@ class CaseQuery:
60
60
  The executable scenario is the root callable that recreates the situation that the case is
61
61
  created in, for example, when the case is created from a test function, this would be the test function itself.
62
62
  """
63
+ this_case_target_value: Optional[Any] = None
64
+ """
65
+ The non relational case query instance target value.
66
+ """
63
67
  _target: Optional[CallableExpression] = None
64
68
  """
65
- The target expression of the attribute.
69
+ The relational target (the evaluatable conclusion of the rule) which is a callable expression that varies with
70
+ the case.
66
71
  """
67
72
  default_value: Optional[Any] = None
68
73
  """
@@ -306,6 +311,12 @@ class CaseFactoryMetaData:
306
311
  factory_idx: Optional[int] = None
307
312
  case_conf: Optional[CaseConf] = None
308
313
  scenario: Optional[Callable] = None
314
+ pytest_request: Optional[Callable] = field(hash=False, compare=False, default=None)
315
+ this_case_target_value: Optional[Any] = None
316
+
317
def __post_init__(self):
    """
    Fill in the scenario from the pytest request when no scenario was given explicitly.

    Nothing happens when a scenario is already set or when there is no pytest request.
    """
    scenario_missing = self.scenario is None
    if scenario_missing and self.pytest_request is not None:
        self.scenario = get_method_object_from_pytest_request(self.pytest_request)
309
320
 
310
321
  @classmethod
311
322
  def from_case_query(cls, case_query: CaseQuery) -> CaseFactoryMetaData:
@@ -322,8 +333,9 @@ class CaseFactoryMetaData:
322
333
  return (f"CaseFactoryMetaData("
323
334
  f"factory_method={factory_method_repr}, "
324
335
  f"factory_idx={self.factory_idx}, "
325
- f"case_conf={self.case_conf},"
326
- f" scenario={scenario_repr})")
336
+ f"case_conf={self.case_conf}, "
337
+ f"scenario={scenario_repr}, "
338
+ f"this_case_target_value={self.this_case_target_value})")
327
339
 
328
340
  def __str__(self):
329
341
  return self.__repr__()
@@ -335,26 +347,57 @@ class RDRConclusion:
335
347
  This dataclass represents a conclusion of a Ripple Down Rule.
336
348
  It contains the conclusion expression, the type of the conclusion, and the scope in which it is evaluated.
337
349
  """
338
- value: Any
350
+ _conclusion: Any
339
351
  """
340
352
  The conclusion value.
341
353
  """
342
- frozen_case: Any
354
+ _frozen_case: Any
343
355
  """
344
356
  The frozen case that the conclusion was made for.
345
357
  """
346
- rule: Rule
358
+ _rule: Rule
347
359
  """
348
360
  The rule that gave this conclusion.
349
361
  """
350
- rdr: RippleDownRules
362
+ _rdr: RippleDownRules
351
363
  """
352
364
  The Ripple Down Rules that classified the case and produced this conclusion.
353
365
  """
354
- id: int = field(default_factory=lambda: uuid.uuid4().int)
366
+ _id: int = field(default_factory=lambda: uuid.uuid4().int)
355
367
  """
356
368
  The unique identifier of the conclusion.
357
369
  """
370
def __getattribute__(self, name: str) -> Any:
    """
    Resolve private names on this object and forward every other name to the wrapped conclusion.

    :param name: The name of the attribute being read.
    :return: The attribute of this object when ``name`` starts with an underscore, otherwise the
     attribute of the same name on the ``_conclusion`` value.
    """
    if not name.startswith('_'):
        wrapped = object.__getattribute__(self, "_conclusion")
        result = getattr(wrapped, name)
        # Only reached when the wrapped lookup succeeded.
        self._record_dependency(name)
        return result
    return object.__getattribute__(self, name)
381
+
382
def __setattr__(self, name, value):
    """
    Store private attributes on this object and forward every other attribute to the wrapped conclusion.

    :param name: The name of the attribute being assigned.
    :param value: The value to assign.
    """
    if name.startswith('_'):
        object.__setattr__(self, name, value)
    else:
        # The wrapped value is stored in `_conclusion` (the same attribute __getattribute__ reads from);
        # there is no `_wrapped` attribute, so forwarding to it raised AttributeError on every public set.
        setattr(self._conclusion, name, value)
387
+
388
def _record_dependency(self, attr_name):
    """
    Flag this conclusion as used when it is read from inside a CallableExpression call.

    The call stack is searched for a frame of a function named ``__call__`` whose local ``self`` is
    exactly a :class:`CallableExpression` instance. On the first such frame ``_used_in_tracker`` is set
    to True and the search stops.

    :param attr_name: The name of the attribute that was read (not used by the search itself).
    """
    import logging

    # Walk the raw frame chain instead of inspect.stack(): stack() builds a FrameInfo with source
    # context for every frame, which is far too expensive to do on every attribute read.
    frame = inspect.currentframe()
    try:
        caller = frame.f_back if frame is not None else None
        while caller is not None:
            if caller.f_code.co_name == "__call__":
                local_self = caller.f_locals.get("self", None)
                if local_self is not None and type(local_self) is CallableExpression:
                    self._used_in_tracker = True
                    # Debug-level log instead of an unconditional print to stdout.
                    logging.getLogger(__name__).debug("RDRConclusion used inside CallableExpression")
                    break
            caller = caller.f_back
    finally:
        # Drop the reference to our own frame to avoid a frame <-> local reference cycle.
        del frame
358
401
 
359
402
def __hash__(self):
    """
    Hash the conclusion by its unique identifier.

    :return: The hash of the ``_id`` field.
    """
    # Read the private field: public names are forwarded to the wrapped conclusion by
    # __getattribute__, so `self.id` would look up `id` on the conclusion value instead.
    return hash(self._id)
@@ -93,20 +93,6 @@ class Stop(Category):
93
93
  stop = "stop"
94
94
 
95
95
 
96
- class ExpressionParser(Enum):
97
- """
98
- Parsers for expressions to evaluate and encapsulate the expression into a callable function.
99
- """
100
- ASTVisitor: int = auto()
101
- """
102
- Generic python Abstract Syntax Tree that detects variables, attributes, binary/boolean expressions , ...etc.
103
- """
104
- SQLAlchemy: int = auto()
105
- """
106
- Specific for SQLAlchemy expressions on ORM Tables.
107
- """
108
-
109
-
110
96
  class PromptFor(Enum):
111
97
  """
112
98
  The reason of the prompt. (e.g. get conditions, conclusions, or affirmation).
@@ -131,51 +117,6 @@ class PromptFor(Enum):
131
117
  return self.__str__()
132
118
 
133
119
 
134
- class CategoricalValue(Enum):
135
- """
136
- A categorical value is a value that is a category.
137
- """
138
-
139
- def __eq__(self, other):
140
- if isinstance(other, CategoricalValue):
141
- return self.name == other.name
142
- elif isinstance(other, str):
143
- return self.name == other
144
- return self.name == other
145
-
146
- def __hash__(self):
147
- return hash(self.name)
148
-
149
- @classmethod
150
- def to_list(cls):
151
- return list(cls._value2member_map_.keys())
152
-
153
- @classmethod
154
- def from_str(cls, category: str):
155
- return cls[category.lower()]
156
-
157
- @classmethod
158
- def from_strs(cls, categories: List[str]):
159
- return [cls.from_str(c) for c in categories]
160
-
161
- def __str__(self):
162
- return self.name
163
-
164
- def __repr__(self):
165
- return self.__str__()
166
-
167
-
168
- class RDRMode(Enum):
169
- Propositional = auto()
170
- """
171
- Propositional mode, the mode where the rules are propositional.
172
- """
173
- Relational = auto()
174
- """
175
- Relational mode, the mode where the rules are relational.
176
- """
177
-
178
-
179
120
  class MCRDRMode(Enum):
180
121
  """
181
122
  The modes of the MultiClassRDR.
@@ -213,33 +154,19 @@ class RDREdge(Enum):
213
154
  """
214
155
  Filter edge, the edge that represents the filter condition.
215
156
  """
216
-
217
- class ValueType(Enum):
218
- Unary = auto()
219
- """
220
- Unary value type (eg. null).
221
- """
222
- Binary = auto()
223
- """
224
- Binary value type (eg. True, False).
225
- """
226
- Discrete = auto()
227
- """
228
- Discrete value type (eg. 1, 2, 3).
157
+ Empty = ""
229
158
  """
230
- Continuous = auto()
231
- """
232
- Continuous value type (eg. 1.0, 2.5, 3.4).
233
- """
234
- Nominal = auto()
235
- """
236
- Nominal value type (eg. red, blue, green), categories where the values have no natural order.
237
- """
238
- Ordinal = auto()
239
- """
240
- Ordinal value type (eg. low, medium, high), categories where the values have a natural order.
241
- """
242
- Iterable = auto()
243
- """
244
- Iterable value type (eg. [1, 2, 3]).
159
+ Empty edge, used for example for the root/input node of the tree.
245
160
  """
161
+
162
@classmethod
def from_value(cls, value: str) -> RDREdge:
    """
    Look up the enum member whose value equals the given string.

    :param value: The string that represents the edge type.
    :return: The member that has this value.
    :raises ValueError: If no member has this value.
    """
    members_by_value = cls._value2member_map_
    if value in members_by_value:
        return members_by_value[value]
    raise ValueError(f"RDREdge {value} is not supported.")
@@ -0,0 +1,177 @@
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ import importlib
5
+ import inspect
6
+ import logging
7
+ import sys
8
+ import typing
9
+ from dataclasses import dataclass, Field
10
+ from datetime import datetime
11
+ from functools import lru_cache
12
+ from types import NoneType
13
+
14
+ from typing_extensions import Type, get_origin, Optional, get_type_hints, Tuple
15
+
16
+ from ripple_down_rules.utils import make_tuple
17
+
18
+
19
class ParseError(TypeError):
    """
    Raised when the parser meets something that cannot or should not be parsed.

    Union types are one example.
    """
26
+
27
+
28
+ @dataclass
29
+ class FieldInfo:
30
+ """
31
+ A class that wraps a field of dataclass and provides some utility functions.
32
+ """
33
+
34
+ clazz: Type
35
+ """
36
+ The class that the field is in.
37
+ """
38
+
39
+ name: str
40
+ """
41
+ The name of the field.
42
+ """
43
+
44
+ type: Tuple[Type, ...]
45
+ """
46
+ The type of the field or inner type of the container if it is a container.
47
+ """
48
+
49
+ optional: bool
50
+ """
51
+ True if the field is optional, False otherwise.
52
+ """
53
+
54
+ container: Optional[Type]
55
+ """
56
+ The type of the container if it is one (list, set, tuple, etc.). If there is no container this is None
57
+ """
58
+
59
+ is_type_field: bool = False
60
+ """
61
+ The field value is a type.
62
+ """
63
+
64
+ field: Field = None
65
+ """
66
+ The field instance.
67
+ """
68
+
69
+ def __init__(self, clazz: Type, f: Field):
70
+ self.field = f
71
+ self.name = f.name
72
+ self.clazz = clazz
73
+
74
+ try:
75
+ type_hints = get_type_hints(clazz)[self.name]
76
+ except NameError as e:
77
+ found_clazz = manually_search_for_class_name(e.name)
78
+ module = importlib.import_module(found_clazz.__module__)
79
+ locals()[e.name] = getattr(module, e.name)
80
+ type_hints = get_type_hints(clazz, localns=locals())[self.name]
81
+ type_args = typing.get_args(type_hints)
82
+
83
+ # try to unpack the type if it is a nested type
84
+ if len(type_args) > 0:
85
+ self.optional = NoneType in type_args
86
+
87
+ if self.optional:
88
+ self.container = None
89
+ else:
90
+ self.container = get_origin(type_hints)
91
+
92
+ if not self.optional and type_hints == Type[type_args]:
93
+ self.is_type_field = True
94
+
95
+ self.type = type_args
96
+ else:
97
+ self.optional = False
98
+ self.container = None
99
+ self.type = make_tuple(type_hints)
100
+
101
+ @property
102
+ def is_builtin_class(self) -> bool:
103
+ return not self.container and all(t.__module__ == 'builtins' for t in self.type)
104
+
105
+ @property
106
+ def is_container_of_builtin(self) -> bool:
107
+ return self.container and all(t.__module__ == 'builtins' for t in self.type)
108
+
109
+ @property
110
+ def is_type_type(self) -> bool:
111
+ return self.is_type_field
112
+
113
+ @property
114
+ def is_enum(self):
115
+ return len(self.type) == 1 and issubclass(self.type[0], enum.Enum)
116
+
117
+ @property
118
+ def is_datetime(self):
119
+ return len(self.type) == 1 and self.type[0] == datetime
120
+
121
+
122
def is_container(clazz: Type) -> bool:
    """
    Tell whether a type hint is a parameterised list, set or tuple.

    :param clazz: The class (type hint) to check.
    :return: True if the origin of the hint is list, set or tuple, False otherwise.
    """
    origin = get_origin(clazz)
    return origin in (list, set, tuple)
130
+
131
+
132
def manually_search_for_class_name(target_class_name: str) -> Type:
    """
    Search for a class with the given name in this module's ``globals()`` and in every module
    loaded in ``sys.modules``.

    Each distinct class object is counted once, no matter how many namespaces it was imported into.
    If several distinct classes share the name, a warning is logged and the first one found is returned.

    :param target_class_name: Name of the class to search for.
    :return: The resolved class with the matching name.

    :raises ValueError: Raised when no class with the specified name can be found.
    """
    found_classes = []
    seen_ids = set()

    def _collect(namespace) -> None:
        # Iterate over a snapshot so a namespace that changes during the scan cannot break the loop.
        for obj in list(namespace.values()):
            if inspect.isclass(obj) and obj.__name__ == target_class_name and id(obj) not in seen_ids:
                # De-duplicate by identity: the same class is usually visible from many modules.
                seen_ids.add(id(obj))
                found_classes.append(obj)

    # Search 1: In the current module's globals()
    _collect(globals())

    # Search 2: In all loaded modules (via sys.modules); snapshot because imports may add entries.
    for module in list(sys.modules.values()):
        if module is None or not hasattr(module, '__dict__'):
            continue  # Skip placeholder entries or modules without a __dict__
        _collect(module.__dict__)

    if not found_classes:
        raise ValueError(f"Could not find any class with name {target_class_name} in globals or sys.modules.")
    if len(found_classes) > 1:
        warn_multiple_classes(target_class_name, tuple(found_classes))

    return found_classes[0]
173
+
174
+
175
@lru_cache(maxsize=None)
def warn_multiple_classes(target_class_name, found_classes):
    """
    Log a warning that several classes share the requested name.

    The call is memoised, so each distinct (name, classes) pair is only logged once.

    :param target_class_name: The class name that was searched for.
    :param found_classes: The classes that were found (must be hashable, e.g. a tuple).
    """
    message = f"Found multiple classes with name {target_class_name}. Found classes: {found_classes} "
    logging.warning(message)
@@ -0,0 +1,208 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import uuid
5
+ from dataclasses import dataclass, field, Field, fields
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+
9
+ import pydot
10
+ import rustworkx as rx
11
+ from typing_extensions import Any, TYPE_CHECKING, Type, final, ClassVar, Dict, List, Optional, Tuple
12
+
13
+ from .field_info import FieldInfo
14
+ from .. import logger
15
+ from ..rules import Rule
16
+ from ..utils import recursive_subclasses
17
+
18
+
19
class Direction(Enum):
    # Two-valued direction flag backed by booleans.
    # NOTE(review): presumably the direction of an edge relative to a node in the dependency
    # graph; it is not referenced elsewhere in this module - confirm against callers.
    OUTBOUND = False
    INBOUND = True
22
+
23
+
24
class Relation(str, Enum):
    # Labels of the edges in the class dependency graph; members are also plain strings.
    # Composition edge (a class has a field of another tracked class).
    has = "has"
    # Inheritance edge (a class derives from a tracked base class).
    isA = "isA"
    # NOTE(review): not used in this module; presumably a usage/dependency edge - confirm.
    dependsOn = "dependsOn"
28
+
29
+
30
@dataclass(unsafe_hash=True)
class TrackedObjectMixin:
    """
    A class that is used as a base class to all classes that needs to be tracked for RDR inference, and reasoning.
    """
    _rdr_tracked_object_id: int = field(init=False, repr=False, hash=False,
                                        compare=False, default_factory=lambda: uuid.uuid4().int)
    """
    The unique identifier of the tracked object.
    """
    _dependency_graph: ClassVar[rx.PyDAG[Type[TrackedObjectMixin]]] = rx.PyDAG()
    """
    A graph that represents the relationships between all tracked objects.
    """
    _class_graph_indices: ClassVar[Dict[Type[TrackedObjectMixin], int]] = {}
    """
    Maps every class that was added to the dependency graph to its node index in the graph.
    """
    _composition_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent composition relations between objects (Relation.has).
    """
    _inheritance_edges: ClassVar[List[Tuple[int, int]]] = []
    """
    The edges that represent inheritance relations between objects (Relation.isA).
    """
    _overridden_by: Type[TrackedObjectMixin] = field(init=False, repr=False, hash=False,
                                                     compare=False, default=None)
    """
    Whether the class has been overridden by a subclass.
    This is used to only include the new class in the dependency graph, not the overridden class.
    """

    @classmethod
    def _reset_dependency_graph(cls) -> None:
        """
        Reset the dependency graph and all class indices.
        """
        cls._dependency_graph = rx.PyDAG()
        cls._class_graph_indices = {}
        cls._composition_edges = []
        cls._inheritance_edges = []

    @classmethod
    def _my_graph_idx(cls):
        """
        :return: The node index of this class in the dependency graph.
        :raises KeyError: If this class was not added to the graph.
        """
        return cls._class_graph_indices[cls]

    @classmethod
    def make_class_dependency_graph(cls, composition: bool = True):
        """
        Create a direct acyclic graph containing the class hierarchy.

        :param composition: If True, the class dependency graph will include composition relations.
        """
        classes_to_track = recursive_subclasses(cls) + [Rule] + recursive_subclasses(Rule)
        for clazz in classes_to_track:
            cls._add_class_to_dependency_graph(clazz)

        # Iterate over a snapshot of the keys so the registry is never walked while it may change.
        for clazz in list(cls._class_graph_indices):
            bases = [base for base in clazz.__bases__ if
                     base.__module__ not in ["builtins"] and base in cls._class_graph_indices]

            for base in bases:
                cls._add_class_to_dependency_graph(base)
                clazz_idx = cls._class_graph_indices[clazz]
                base_idx = cls._class_graph_indices[base]
                if (clazz_idx, base_idx) in cls._inheritance_edges or base._overridden_by == clazz:
                    continue
                cls._dependency_graph.add_edge(clazz_idx, base_idx, Relation.isA)
                cls._inheritance_edges.append((clazz_idx, base_idx))

        if not composition:
            return
        # Snapshot again: parsing a field may register a class that is not in the registry yet,
        # and adding to a dict while iterating it raises RuntimeError.
        for clazz in list(cls._class_graph_indices):
            if clazz.__module__ == "builtins":
                continue
            TrackedObjectMixin.parse_fields(clazz)

    @staticmethod
    def parse_fields(clazz) -> None:
        """
        Add composition edges for all non-private fields returned by get_fields for the class.

        :param clazz: The class whose fields are parsed.
        """
        for f in TrackedObjectMixin.get_fields(clazz):

            logger.debug("=" * 80)
            logger.debug(f"Processing Field {clazz.__name__}.{f.name}: {f.type}.")

            # skip private fields
            if f.name.startswith("_"):
                logger.debug(f"Skipping since the field starts with _.")
                continue

            field_info = FieldInfo(clazz, f)
            TrackedObjectMixin.parse_field(field_info)

    @staticmethod
    def get_fields(clazz) -> List[Field]:
        """
        Get the dataclass fields of a class without the fields it inherits from tracked base classes.

        :param clazz: The dataclass to get the fields of.
        :return: The fields of the class that do not come from a base that is a TrackedObjectMixin.
        """
        # Skip everything a tracked base exposes (its own AND its inherited fields); skipping only the
        # base's own fields would let fields from further up the hierarchy show up again.
        inherited_ids = set()
        for base in clazz.__bases__:
            if issubclass(base, TrackedObjectMixin):
                inherited_ids.update(id(base_field) for base_field in fields(base))

        return [cls_field for cls_field in fields(clazz) if id(cls_field) not in inherited_ids]

    @staticmethod
    def parse_field(field_info: FieldInfo):
        """
        Add a composition edge (Relation.has) from the class of the field to the class of its type,
        when the field has a single type that is a TrackedObjectMixin subclass.

        :param field_info: The description of the field to parse.
        """
        parent_idx = TrackedObjectMixin._class_graph_indices[field_info.clazz]
        field_cls: Optional[Type[TrackedObjectMixin]] = None
        field_relation = Relation.has
        # isinstance guard: issubclass raises TypeError for non-class arguments (e.g. List[int]).
        if (len(field_info.type) == 1
                and isinstance(field_info.type[0], type)
                and issubclass(field_info.type[0], TrackedObjectMixin)):
            field_cls = field_info.type[0]
        else:
            # TODO: Create a new TrackedObjectMixin class for new type
            logger.debug(f"Skipping unhandled field type: {field_info.type}")

        if field_cls is not None:
            field_cls_idx = TrackedObjectMixin._class_graph_indices.get(field_cls, None)
            if field_cls_idx is not None and (parent_idx, field_cls_idx) in TrackedObjectMixin._composition_edges:
                return
            elif field_cls_idx is None:
                TrackedObjectMixin._add_class_to_dependency_graph(field_cls)
                field_cls_idx = TrackedObjectMixin._class_graph_indices[field_cls]
            TrackedObjectMixin._dependency_graph.add_edge(parent_idx, field_cls_idx, field_relation)
            TrackedObjectMixin._composition_edges.append((parent_idx, field_cls_idx))

    @classmethod
    def _create_tracked_object_class_for_field(cls, field_info: FieldInfo):
        """
        Not implemented yet.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def to_dot(cls, filepath: str, format='png') -> None:
        """
        Render the dependency graph to a file.

        :param filepath: The path of the output file; the format is appended as extension if missing.
        :param format: The output format that is passed on to pydot.
        """
        if not filepath.endswith(f".{format}"):
            filepath += f".{format}"
        dot_str = cls._dependency_graph.to_dot(
            lambda node: dict(
                color='black', fillcolor='lightblue', style='filled', label=node.__name__),
            lambda edge: dict(color='black', style='solid', label=edge))
        dot = pydot.graph_from_dot_data(dot_str)[0]
        dot.write(filepath, format=format)

    @classmethod
    def _add_class_to_dependency_graph(cls, class_to_add: Type[TrackedObjectMixin]) -> None:
        """
        Add a class to the dependency graph.

        A class that is already registered is left untouched. If the class has an overriding class,
        the overriding class becomes the node and both classes map to the same node index.

        :param class_to_add: The class to add.
        """
        if class_to_add not in cls._class_graph_indices:
            if not issubclass(class_to_add, TrackedObjectMixin):
                # Classes outside the mixin hierarchy (e.g. Rule) get the attribute set explicitly.
                class_to_add._overridden_by = None
            cls_idx = cls._dependency_graph.add_node(class_to_add._overridden_by or class_to_add)
            cls._class_graph_indices[class_to_add] = cls_idx
            if class_to_add._overridden_by:
                cls._class_graph_indices[class_to_add._overridden_by] = cls_idx

    def __getattribute__(self, name: str) -> Any:
        """
        Plain attribute lookup; dependency recording on attribute access is currently disabled.
        """
        return object.__getattribute__(self, name)

    def _record_dependency(self, attr_name):
        """
        Log when this object is used from inside the ``__call__`` of a CallableExpression.

        :param attr_name: The name of the attribute that was accessed (not used by the search itself).
        """
        # Walk the raw frame chain instead of inspect.stack(), which builds a FrameInfo with source
        # context for every frame.
        frame = inspect.currentframe()
        try:
            caller = frame.f_back if frame is not None else None
            while caller is not None:
                if caller.f_code.co_name == "__call__":
                    local_self = caller.f_locals.get("self", None)
                    if local_self is not None:
                        owner = type(local_self)
                        # Match on the last component of the module path: __module__ holds the full
                        # dotted path, so comparing it to the bare file name never matched.
                        if (owner.__name__ == "CallableExpression"
                                and owner.__module__.rpartition(".")[2] == "callable_expression"):
                            logger.debug("TrackedObject used inside CallableExpression")
                            break
                caller = caller.f_back
        finally:
            # Drop the reference to our own frame to avoid a frame <-> local reference cycle.
            del frame
204
+
205
+
206
# Remove the names of the dataclass fields declared on TrackedObjectMixin from its
# __annotations__ mapping (the mapping is modified in place).
# NOTE(review): presumably so that the bookkeeping fields are not picked up as type hints by
# code that inspects annotations - confirm the intent.
annotations = TrackedObjectMixin.__annotations__
for val in tuple(f.name for f in fields(TrackedObjectMixin)):
    annotations.pop(val, None)
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  import os
4
5
  import sys
6
+ from functools import wraps
5
7
  from types import ModuleType
6
- from typing import Tuple
8
+ from typing import Tuple, Callable, Dict, Any, Optional
7
9
 
8
10
  from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING, Union
9
11
 
10
12
  from .datastructures.case import create_case, Case
11
13
  from .datastructures.dataclasses import CaseQuery
12
- from .utils import calculate_precision_and_recall
14
+ from .utils import calculate_precision_and_recall, get_method_args_as_dict, get_func_rdr_model_name
13
15
  from .utils import get_func_rdr_model_name, copy_case, make_set, update_case
14
16
 
15
17
  if TYPE_CHECKING:
@@ -127,3 +129,36 @@ def enable_gui():
127
129
  viewer = RDRCaseViewer()
128
130
  except ImportError:
129
131
  pass
132
+
133
+
134
def create_case_from_method(func: Callable,
                            func_output: Dict[str, Any],
                            *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
    """
    Build a Case out of a function call: its arguments merged with its output.

    :param func: The function to create a case from.
    :param func_output: A dictionary containing the output of the function, where the key is the output name.
    :param args: The positional arguments of the function.
    :param kwargs: The keyword arguments of the function.
    :return: The created Case and the dictionary of attributes it was created from.
    """
    attributes = get_method_args_as_dict(func, *args, **kwargs)
    # Output entries are merged last, so they win over arguments that share a name.
    attributes.update(func_output)
    name = get_func_rdr_model_name(func)
    case = Case(dict, id(attributes), name, attributes, **attributes)
    return case, attributes
150
+
151
+
152
class MockRDRDecorator:
    """
    A decorator that replaces the result of a function with the classification of a saved RDR model.

    The model module is expected at ``<models_dir>/<model_dir>/<model_name>_rdr.py`` where both names
    are derived from the decorated function with get_func_rdr_model_name.
    """

    def __init__(self, models_dir: str):
        """
        :param models_dir: The directory that contains the saved model directories.
        """
        self.models_dir = models_dir

    def decorator(self, func: Callable) -> Callable:
        """
        Wrap a function so that calling it returns what the saved RDR model concludes for the call.

        :param func: The function to wrap.
        :return: The wrapped function.
        """
        @wraps(func)
        def wrapper(*args, **kwargs) -> Optional[Any]:
            model_dir = get_func_rdr_model_name(func, include_file_name=True)
            model_name = get_func_rdr_model_name(func, include_file_name=False)
            rdr = self._import_rdr_module(model_dir, model_name)
            func_output = {"output_": func(*args, **kwargs)}
            case, _ = create_case_from_method(func, func_output, *args, **kwargs)
            return rdr.classify(case)
        return wrapper

    def _import_rdr_module(self, model_dir: str, model_name: str) -> ModuleType:
        """
        Import ``<models_dir>/<model_dir>/<model_name>_rdr.py`` as the module ``<model_dir>.<model_name>_rdr``.

        :param model_dir: The name of the directory of the model inside the models directory.
        :param model_name: The name of the model.
        :return: The imported module.
        """
        # importlib.import_module expects a dotted module name, not a file path ending in ".py".
        # Make the models directory importable and import the model as a submodule of its directory,
        # so that relative imports inside the model module keep working.
        models_root = os.path.abspath(self.models_dir)
        if models_root not in sys.path:
            sys.path.insert(0, models_root)
        return importlib.import_module(f"{model_dir}.{model_name}_rdr")