ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,97 @@
1
+ import os.path
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from os.path import dirname
6
+
7
+ from typing_extensions import Type, ClassVar, TYPE_CHECKING, Tuple, Optional, Callable
8
+
9
+ from .datastructures.tracked_object import TrackedObjectMixin, Direction, Relation
10
+
11
+ if TYPE_CHECKING:
12
+ from .rdr_decorators import RDRDecorator
13
+
14
+
15
@dataclass(eq=False)
class Predicate(TrackedObjectMixin, ABC):
    """
    Base class for callable predicates.

    Calling an instance forwards all arguments to the ``evaluate`` classmethod. Equality and hashing are based
    only on the concrete predicate class, so all instances of the same predicate class compare equal and hash
    alike.
    """

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    @classmethod
    @abstractmethod
    def evaluate(cls, *args, **kwargs):
        """
        Evaluate the predicate with the given arguments.
        This method should be implemented by subclasses.
        """
        pass

    @classmethod
    def rdr_decorator(cls, output_types: Tuple[Type, ...], mutually_exclusive: bool,
                      package_name: Optional[str] = None) -> Callable[[Callable], Callable]:
        """
        Returns the RDRDecorator to decorate the predicate evaluate method with.

        :param output_types: The output types passed on to the RDRDecorator.
        :param mutually_exclusive: Passed on to the RDRDecorator as its mutual-exclusivity flag.
        :param package_name: Optional package name passed on to the RDRDecorator.
        :return: The ``decorator`` method of a newly created RDRDecorator instance.
        """
        # The module-level import of RDRDecorator only exists under TYPE_CHECKING (presumably to avoid an import
        # cycle), so the name is undefined at runtime; import it lazily here instead.
        from .rdr_decorators import RDRDecorator

        # NOTE(review): `models_dir` is not defined in this class; presumably it is provided by TrackedObjectMixin
        # or by a subclass — confirm.
        rdr_decorator = RDRDecorator(cls.models_dir, output_types, mutually_exclusive,
                                     package_name=package_name)
        return rdr_decorator.decorator

    def __hash__(self):
        # Consistent with __eq__: equal predicates (same class) have the same class name.
        return hash(self.__class__.__name__)

    def __eq__(self, other):
        if not isinstance(other, Predicate):
            return False
        return self.__class__ == other.__class__
47
+
48
+
49
# eq=False keeps Predicate's class-based __eq__/__hash__; a plain @dataclass (eq=True) would generate its own
# __eq__ and set __hash__ to None, making instances unhashable.
@dataclass(eq=False)
class IsA(Predicate):
    """
    A predicate that checks if an object type is a subclass of another object type.
    """

    @classmethod
    def evaluate(cls, child_type: Type[TrackedObjectMixin], parent_type: Type[TrackedObjectMixin]) -> bool:
        """
        :param child_type: The type to check.
        :param parent_type: The candidate base type.
        :return: True if ``child_type`` is ``parent_type`` or a subclass of it (``issubclass`` semantics).
        """
        return issubclass(child_type, parent_type)


isA = IsA()
60
+
61
+
62
# eq=False keeps Predicate's class-based __eq__/__hash__; a plain @dataclass (eq=True) would generate its own
# __eq__ and set __hash__ to None, making instances unhashable.
@dataclass(eq=False)
class Has(Predicate):
    """
    A predicate that checks if an object type has a certain member object type.
    """

    @classmethod
    def evaluate(cls, owner_type: Type[TrackedObjectMixin],
                 member_type: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        """
        Check the dependency graph for a ``has`` relation from ``owner_type`` to ``member_type``.

        The check is True if an outbound ``has`` edge of the owner leads to a type that is ``member_type`` or a
        subclass of it, or if a type the owner has an outbound ``isA`` edge to satisfies the check.

        :param owner_type: The owning type; its graph node is looked up via ``_my_graph_idx()``.
        :param member_type: The member type to look for.
        :param recursive: If True, also search through the members of the owner's members (and, transitively,
            their members). The flag is forwarded through ``isA`` edges too, so members of inherited members are
            found as well.
        :return: Whether the owner type has the member type.
        """
        graph = cls._dependency_graph
        neighbors = graph.adj_direction(owner_type._my_graph_idx(), Direction.OUTBOUND.value)

        # Direct members, plus members inherited through isA edges. `recursive` is forwarded so that inherited
        # types are searched as deeply as the owner itself.
        if any(relation == Relation.has and isA(graph.get_node_data(node_idx), member_type)
               or relation == Relation.isA and cls.evaluate(graph.get_node_data(node_idx), member_type,
                                                             recursive=recursive)
               for node_idx, relation in neighbors.items()):
            return True
        if not recursive:
            return False

        # Descend into each member's own members.
        # NOTE(review): cyclic `has` relations in the dependency graph would recurse without bound here.
        return any(relation == Relation.has
                   and cls.evaluate(graph.get_node_data(node_idx), member_type, recursive=True)
                   for node_idx, relation in neighbors.items())


has = Has()
83
+
84
+
85
# eq=False keeps Predicate's class-based __eq__/__hash__; a plain @dataclass (eq=True) would generate its own
# __eq__ and set __hash__ to None, making instances unhashable.
@dataclass(eq=False)
class DependsOn(Predicate):
    """
    A predicate that checks if an object type depends on another object type.
    """

    @classmethod
    def evaluate(cls, dependent: Type[TrackedObjectMixin],
                 dependency: Type[TrackedObjectMixin], recursive: bool = False) -> bool:
        """
        Placeholder implementation: it always raises and is expected to be replaced elsewhere
        (the message says "rdr meta").

        :param dependent: The type whose dependencies are checked.
        :param dependency: The candidate dependency type.
        :param recursive: Whether indirect dependencies should count (unused here).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("Should be overridden in rdr meta")


dependsOn = DependsOn()
ripple_down_rules/rdr.py CHANGED
@@ -113,7 +113,10 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
113
113
  if self.save_dir is not None and self.model_name is not None:
114
114
  model_path = os.path.join(self.save_dir, self.model_name)
115
115
  if os.path.exists(model_path):
116
- self.update_from_python(model_path, update_rule_tree=True)
116
+ try:
117
+ self.update_from_python(model_path, update_rule_tree=True)
118
+ except (FileNotFoundError, ModuleNotFoundError) as e:
119
+ pass
117
120
 
118
121
  def write_rdr_metadata_to_pyton_file(self, file: TextIOWrapper):
119
122
  """
@@ -343,11 +346,11 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
343
346
  rule.reset()
344
347
  if self.start_rule is not None and self.start_rule.parent is None:
345
348
  if self.input_node is None:
346
- self.input_node = type(self.start_rule)(parent=None, uid='0')
349
+ self.input_node = type(self.start_rule)(_parent=None, uid='0')
347
350
  self.input_node.evaluated = False
348
351
  self.input_node.fired = False
349
352
  self.start_rule.parent = self.input_node
350
- self.start_rule.weight = ""
353
+ self.start_rule.weight = RDREdge.Empty
351
354
  if self.input_node is not None:
352
355
  data = case.__dict__ if is_dataclass(case) else case
353
356
  if hasattr(case, "items"):
@@ -631,7 +634,7 @@ class TreeBuilder(ast.NodeVisitor, ABC):
631
634
  rule_uid = condition.split("conditions_")[1]
632
635
 
633
636
  new_rule_type = self.get_new_rule_type(node)
634
- new_node = new_rule_type(conditions=condition, parent=self.current_parent, uid=rule_uid)
637
+ new_node = new_rule_type(conditions=condition, _parent=self.current_parent, uid=rule_uid)
635
638
  if self.current_parent is not None:
636
639
  self.update_current_parent(new_node)
637
640
 
@@ -1009,6 +1012,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
1009
1012
  defs_imports = get_imports_from_types(defs_types, defs_file_name, package_name)
1010
1013
  corner_cases_imports = get_imports_from_types(corner_cases_types, cases_file_name, package_name)
1011
1014
 
1015
+ defs_imports.append(f"from ripple_down_rules import *")
1012
1016
  # Add the imports to the defs file
1013
1017
  with open(defs_file_name, "w") as f:
1014
1018
  f.write('\n'.join(defs_imports) + "\n\n\n")
@@ -1068,7 +1072,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
1068
1072
  main_types.update({Union, Optional})
1069
1073
  defs_types.add(Union)
1070
1074
  main_types.update({Case, create_case})
1071
- main_types = main_types.difference(defs_types)
1075
+ # main_types = main_types.difference(defs_types)
1072
1076
  return main_types, defs_types, cases_types
1073
1077
 
1074
1078
  @property
@@ -1170,7 +1174,7 @@ class SingleClassRDR(RDRWithCodeWriter):
1170
1174
  self.default_conclusion = case_query.default_value
1171
1175
 
1172
1176
  pred = self.evaluate(case_query.case)
1173
- if pred.conclusion(case_query.case) != case_query.target_value:
1177
+ if (not pred.fired and self.default_conclusion is None) or pred.conclusion(case_query.case) != case_query.target_value:
1174
1178
  expert.ask_for_conditions(case_query, pred)
1175
1179
  pred.fit_rule(case_query)
1176
1180
 
@@ -3,6 +3,7 @@ This file contains decorators for the RDR (Ripple Down Rules) framework. Where e
3
3
  that can be used with any python function such that this function can benefit from the incremental knowledge acquisition
4
4
  of the RDRs.
5
5
  """
6
+ import inspect
6
7
  import os.path
7
8
  from dataclasses import dataclass, field
8
9
  from functools import wraps
@@ -11,8 +12,9 @@ from typing import get_origin
11
12
  from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
12
13
 
13
14
  from .datastructures.case import Case
14
- from .datastructures.dataclasses import CaseQuery
15
+ from .datastructures.dataclasses import CaseQuery, CaseFactoryMetaData
15
16
  from .experts import Expert, Human
17
+ from .failures import RDRLoadError
16
18
  from .rdr import GeneralRDR
17
19
  from .utils import get_type_from_type_hint
18
20
 
@@ -21,7 +23,8 @@ try:
21
23
  except ImportError:
22
24
  RDRCaseViewer = None
23
25
  from .utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
24
- get_method_class_if_exists, str_to_snake_case, make_list
26
+ get_method_class_if_exists, make_list
27
+ from .helpers import create_case_from_method
25
28
 
26
29
 
27
30
  @dataclass(unsafe_hash=True)
@@ -83,7 +86,7 @@ class RDRDecorator:
83
86
  The name of the rdr model, this gets auto generated from the function signature and the class/file it is contained
84
87
  in.
85
88
  """
86
- rdr: GeneralRDR = field(init=False)
89
+ rdr: GeneralRDR = field(init=False, default=None)
87
90
  """
88
91
  The ripple down rules instance of the decorator class.
89
92
  """
@@ -109,12 +112,22 @@ class RDRDecorator:
109
112
  This is a flag that indicates that a not None output for the rdr has been inferred, this is used to update the
110
113
  generated dot file if it is set to `True`.
111
114
  """
115
+ case_factory_metadata: CaseFactoryMetaData = field(init=False, default_factory=CaseFactoryMetaData)
116
+ """
117
+ Metadata that contains the case factory method, and the scenario that is being run during the case query.
118
+ """
112
119
 
113
120
  def decorator(self, func: Callable) -> Callable:
114
121
 
115
122
  @wraps(func)
116
123
  def wrapper(*args, **kwargs) -> Optional[Any]:
117
124
 
125
+ original_kwargs = {pname: p for pname, p in inspect.signature(func).parameters.items() if
126
+ p.default != inspect._empty}
127
+ for og_kwarg in original_kwargs:
128
+ if og_kwarg not in kwargs:
129
+ kwargs[og_kwarg] = original_kwargs[og_kwarg].default
130
+
118
131
  if self.model_name is None:
119
132
  self.initialize_rdr_model_name_and_load(func)
120
133
  if self.origin_type is None and not self.mutual_exclusive:
@@ -124,7 +137,7 @@ class RDRDecorator:
124
137
 
125
138
  func_output = {self.output_name: func(*args, **kwargs)}
126
139
 
127
- case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
140
+ case, case_dict = create_case_from_method(func, func_output, *args, **kwargs)
128
141
 
129
142
  @self.fitting_decorator
130
143
  def fit():
@@ -132,11 +145,14 @@ class RDRDecorator:
132
145
  self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
133
146
  if self.expert is None:
134
147
  self.expert = Human(answers_save_path=self.models_dir + f'/{self.model_name}/expert_answers')
135
- case_query = self.create_case_query_from_method(func, func_output,
136
- self.parsed_output_type,
137
- self.mutual_exclusive,
138
- case, case_dict,
139
- *args, **kwargs)
148
+ case_query = self.create_case_query_from_method(
149
+ func, func_output,
150
+ self.parsed_output_type,
151
+ self.mutual_exclusive,
152
+ args, kwargs,
153
+ case=case, case_dict=case_dict,
154
+ scenario=self.case_factory_metadata.scenario,
155
+ this_case_target_value=self.case_factory_metadata.this_case_target_value)
140
156
  output = self.rdr.fit_case(case_query, expert=self.expert,
141
157
  update_existing_rules=self.update_existing_rules)
142
158
  return output
@@ -166,6 +182,8 @@ class RDRDecorator:
166
182
  else:
167
183
  return func_output[self.output_name]
168
184
 
185
+ wrapper._rdr_decorator_instance = self
186
+
169
187
  return wrapper
170
188
 
171
189
  @staticmethod
@@ -173,9 +191,11 @@ class RDRDecorator:
173
191
  func_output: Dict[str, Any],
174
192
  output_type: Sequence[Type],
175
193
  mutual_exclusive: bool,
194
+ func_args: Tuple[Any, ...], func_kwargs: Dict[str, Any],
176
195
  case: Optional[Case] = None,
177
196
  case_dict: Optional[Dict[str, Any]] = None,
178
- *args, **kwargs) -> CaseQuery:
197
+ scenario: Optional[Callable] = None,
198
+ this_case_target_value: Optional[Any] = None,) -> CaseQuery:
179
199
  """
180
200
  Create a CaseQuery from the function and its arguments.
181
201
 
@@ -183,43 +203,28 @@ class RDRDecorator:
183
203
  :param func_output: The output of the function as a dictionary, where the key is the output name.
184
204
  :param output_type: The type of the output as a sequence of types.
185
205
  :param mutual_exclusive: If True, the output types are mutually exclusive.
186
- :param args: The positional arguments of the function.
187
- :param kwargs: The keyword arguments of the function.
206
+ :param func_args: The positional arguments of the function.
207
+ :param func_kwargs: The keyword arguments of the function.
208
+ :param case: The case to create.
209
+ :param case_dict: The dictionary of the case.
210
+ :param scenario: The scenario that produced the given case.
211
+ :param this_case_target_value: The target value for the case.
188
212
  :return: A CaseQuery object representing the case.
189
213
  """
190
214
  output_type = make_set(output_type)
191
215
  if case is None or case_dict is None:
192
- case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
216
+ case, case_dict = create_case_from_method(func, func_output, *func_args, **func_kwargs)
193
217
  scope = func.__globals__
194
218
  scope.update(case_dict)
195
219
  func_args_type_hints = get_type_hints(func)
196
220
  output_name = list(func_output.keys())[0]
197
221
  func_args_type_hints.update({output_name: Union[tuple(output_type)]})
198
222
  return CaseQuery(case, output_name, tuple(output_type),
199
- mutual_exclusive, scope=scope,
223
+ mutual_exclusive, scope=scope, scenario=scenario, this_case_target_value=this_case_target_value,
200
224
  is_function=True, function_args_type_hints=func_args_type_hints)
201
225
 
202
- @staticmethod
203
- def create_case_from_method(func: Callable,
204
- func_output: Dict[str, Any],
205
- *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
206
- """
207
- Create a Case from the function and its arguments.
208
-
209
- :param func: The function to create a case from.
210
- :param func_output: A dictionary containing the output of the function, where the key is the output name.
211
- :param args: The positional arguments of the function.
212
- :param kwargs: The keyword arguments of the function.
213
- :return: A Case object representing the case.
214
- """
215
- case_dict = get_method_args_as_dict(func, *args, **kwargs)
216
- case_dict.update(func_output)
217
- case_name = get_func_rdr_model_name(func)
218
- return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
219
-
220
226
  def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
221
- model_file_name = get_func_rdr_model_name(func, include_file_name=True)
222
- self.model_name = str_to_snake_case(model_file_name)
227
+ self.model_name = get_func_rdr_model_name(func, include_file_name=True)
223
228
  self.load()
224
229
 
225
230
  @staticmethod
@@ -242,3 +247,8 @@ class RDRDecorator:
242
247
  Load the RDR model from the specified directory, otherwise create a new one.
243
248
  """
244
249
  self.rdr = GeneralRDR(save_dir=self.models_dir, model_name=self.model_name)
250
+
251
+
252
def fit_rdr_func(scenario: Callable, rdr_decorated_func: Callable, *func_args, **func_kwargs) -> None:
    """
    Call an RDR-decorated function once with a scenario attached to its decorator's case factory metadata.

    The scenario is only attached for the duration of this call; the decorator's previous metadata is restored
    afterwards, even if the call raises.

    :param scenario: The scenario callable to record in the case factory metadata.
    :param rdr_decorated_func: A wrapper produced by ``RDRDecorator.decorator``; it must carry the
        ``_rdr_decorator_instance`` attribute set by that wrapper.
    :param func_args: Positional arguments forwarded to ``rdr_decorated_func``.
    :param func_kwargs: Keyword arguments forwarded to ``rdr_decorated_func``.
    """
    decorator_instance = rdr_decorated_func._rdr_decorator_instance
    previous_metadata = decorator_instance.case_factory_metadata
    decorator_instance.case_factory_metadata = CaseFactoryMetaData(scenario=scenario)
    try:
        rdr_decorated_func(*func_args, **func_kwargs)
    finally:
        # Without restoring, later ordinary calls of the decorated function would keep passing this (stale)
        # scenario into their case queries.
        decorator_instance.case_factory_metadata = previous_metadata