ripple-down-rules 0.5.63__py3-none-any.whl → 0.5.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,16 +6,18 @@ of the RDRs.
6
6
  import os.path
7
7
  from functools import wraps
8
8
 
9
- from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
10
- from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
9
+ from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
11
10
 
12
- from ripple_down_rules.datastructures.case import create_case, Case
11
+ from ripple_down_rules.datastructures.case import Case
13
12
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
14
- from ripple_down_rules.datastructures.enums import Category
15
13
  from ripple_down_rules.experts import Expert, Human
16
- from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
14
+ from ripple_down_rules.rdr import GeneralRDR
15
+ try:
16
+ from ripple_down_rules.user_interface.gui import RDRCaseViewer
17
+ except ImportError:
18
+ RDRCaseViewer = None
17
19
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
18
- get_method_class_if_exists, get_method_name, str_to_snake_case
20
+ get_method_class_if_exists, str_to_snake_case
19
21
 
20
22
 
21
23
  class RDRDecorator:
@@ -26,7 +28,10 @@ class RDRDecorator:
26
28
  mutual_exclusive: bool,
27
29
  output_name: str = "output_",
28
30
  fit: bool = True,
29
- expert: Optional[Expert] = None):
31
+ expert: Optional[Expert] = None,
32
+ ask_always: bool = False,
33
+ update_existing_rules: bool = True,
34
+ viewer: Optional[RDRCaseViewer] = None):
30
35
  """
31
36
  :param models_dir: The directory to save/load the RDR models.
32
37
  :param output_type: The type of the output. This is used to create the RDR model.
@@ -38,6 +43,9 @@ class RDRDecorator:
38
43
  classification mode. This means that the RDR will classify the function's output based on the RDR model.
39
44
  :param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
40
45
  expert will be used.
46
+ :param ask_always: If True, the function will ask the user for a target if it doesn't exist.
47
+ :param update_existing_rules: If True, the function will update the existing RDR rules
48
+ even if they gave an output.
41
49
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
42
50
  """
43
51
  self.rdr_models_dir = models_dir
@@ -48,6 +56,9 @@ class RDRDecorator:
48
56
  self.output_name = output_name
49
57
  self.fit: bool = fit
50
58
  self.expert: Optional[Expert] = expert
59
+ self.ask_always = ask_always
60
+ self.update_existing_rules = update_existing_rules
61
+ self.viewer = viewer
51
62
  self.load()
52
63
 
53
64
  def decorator(self, func: Callable) -> Callable:
@@ -59,59 +70,77 @@ class RDRDecorator:
59
70
  self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
60
71
  if self.model_name is None:
61
72
  self.initialize_rdr_model_name_and_load(func)
73
+ if self.expert is None:
74
+ self.expert = Human(viewer=self.viewer,
75
+ answers_save_path=self.rdr_models_dir + f'/expert_answers')
76
+
77
+ func_output = {self.output_name: func(*args, **kwargs)}
62
78
 
63
79
  if self.fit:
64
- case_query = self.create_case_query_from_method(func, self.parsed_output_type,
65
- self.mutual_exclusive, self.output_name,
80
+ case_query = self.create_case_query_from_method(func, func_output,
81
+ self.parsed_output_type,
82
+ self.mutual_exclusive,
66
83
  *args, **kwargs)
67
- output = self.rdr.fit_case(case_query, expert=self.expert)
84
+ output = self.rdr.fit_case(case_query, expert=self.expert,
85
+ ask_always_for_target=self.ask_always,
86
+ update_existing_rules=self.update_existing_rules,
87
+ viewer=self.viewer)
88
+ else:
89
+ case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
90
+ output = self.rdr.classify(case)
91
+
92
+ if self.output_name in output:
68
93
  return output[self.output_name]
69
94
  else:
70
- case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
71
- return self.rdr.classify(case)[self.output_name]
95
+ return func_output[self.output_name]
72
96
 
73
97
  return wrapper
74
98
 
75
99
  @staticmethod
76
- def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
77
- output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
100
+ def create_case_query_from_method(func: Callable,
101
+ func_output: Dict[str, Any],
102
+ output_type: Sequence[Type],
103
+ mutual_exclusive: bool,
104
+ *args, **kwargs) -> CaseQuery:
78
105
  """
79
106
  Create a CaseQuery from the function and its arguments.
80
107
 
81
108
  :param func: The function to create a case from.
82
- :param output_type: The type of the output.
109
+ :param func_output: The output of the function as a dictionary, where the key is the output name.
110
+ :param output_type: The type of the output as a sequence of types.
83
111
  :param mutual_exclusive: If True, the output types are mutually exclusive.
84
- :param output_name: The name of the output in the case. Defaults to 'output_'.
85
112
  :param args: The positional arguments of the function.
86
113
  :param kwargs: The keyword arguments of the function.
87
114
  :return: A CaseQuery object representing the case.
88
115
  """
89
116
  output_type = make_set(output_type)
90
- case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
117
+ case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
91
118
  scope = func.__globals__
92
119
  scope.update(case_dict)
93
120
  func_args_type_hints = get_type_hints(func)
121
+ output_name = list(func_output.keys())[0]
94
122
  func_args_type_hints.update({output_name: Union[tuple(output_type)]})
95
123
  return CaseQuery(case, output_name, Union[tuple(output_type)],
96
124
  mutual_exclusive, scope=scope,
97
125
  is_function=True, function_args_type_hints=func_args_type_hints)
98
126
 
99
127
  @staticmethod
100
- def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
128
+ def create_case_from_method(func: Callable,
129
+ func_output: Dict[str, Any],
130
+ *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
101
131
  """
102
132
  Create a Case from the function and its arguments.
103
133
 
104
134
  :param func: The function to create a case from.
105
- :param output_name: The name of the output in the case. Defaults to 'output_'.
135
+ :param func_output: A dictionary containing the output of the function, where the key is the output name.
106
136
  :param args: The positional arguments of the function.
107
137
  :param kwargs: The keyword arguments of the function.
108
138
  :return: A Case object representing the case.
109
139
  """
110
140
  case_dict = get_method_args_as_dict(func, *args, **kwargs)
111
- func_output = func(*args, **kwargs)
112
- case_dict.update({output_name: func_output})
141
+ case_dict.update(func_output)
113
142
  case_name = get_func_rdr_model_name(func)
114
- return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
143
+ return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
115
144
 
116
145
  def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
117
146
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
@@ -148,8 +177,10 @@ class RDRDecorator:
148
177
  model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
149
178
  if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
150
179
  self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
180
+ self.rdr.set_viewer(self.viewer)
151
181
  if self.rdr is None:
152
- self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
182
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
183
+ viewer=self.viewer)
153
184
 
154
185
  def update_from_python(self):
155
186
  """
@@ -3,16 +3,18 @@ from __future__ import annotations
3
3
  import logging
4
4
  import re
5
5
  from abc import ABC, abstractmethod
6
+ from pathlib import Path
6
7
  from uuid import uuid4
7
8
 
8
9
  from anytree import NodeMixin
9
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable
10
- from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
11
+ from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Callable
11
12
 
12
13
  from .datastructures.callable_expression import CallableExpression
13
14
  from .datastructures.case import Case
15
+ from .datastructures.dataclasses import CaseFactoryMetaData, CaseConf, CaseQuery
14
16
  from .datastructures.enums import RDREdge, Stop
15
- from .utils import SubclassJSONSerializer, conclusion_to_json
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
16
18
 
17
19
 
18
20
  class Rule(NodeMixin, SubclassJSONSerializer, ABC):
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
27
29
  corner_case: Optional[Union[Case, SQLTable]] = None,
28
30
  weight: Optional[str] = None,
29
31
  conclusion_name: Optional[str] = None,
30
- uid: Optional[str] = None):
32
+ uid: Optional[str] = None,
33
+ corner_case_metadata: Optional[CaseFactoryMetaData] = None):
31
34
  """
32
35
  A rule in the ripple down rules classifier.
33
36
 
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
38
41
  :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
39
42
  :param conclusion_name: The name of the conclusion of the rule.
40
43
  :param uid: The unique id of the rule.
44
+ :param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
45
+ scenario it is based on.
41
46
  """
42
47
  super(Rule, self).__init__()
43
48
  self.conclusion = conclusion
44
49
  self.corner_case = corner_case
50
+ self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
45
51
  self.parent = parent
46
52
  self.weight: Optional[str] = weight
47
53
  self.conditions = conditions if conditions else None
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
51
57
  # generate a unique id for the rule using uuid4
52
58
  self.uid: str = uid if uid else str(uuid4().int)
53
59
 
60
+ @classmethod
61
+ def from_case_query(cls, case_query: CaseQuery) -> Rule:
62
+ """
63
+ Create a SingleClassRule from a CaseQuery.
64
+
65
+ :param case_query: The CaseQuery to create the rule from.
66
+ :return: A SingleClassRule instance.
67
+ """
68
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
69
+ return cls(conditions=case_query.conditions, conclusion=case_query.target,
70
+ corner_case=case_query.case, parent=None,
71
+ corner_case_metadata=corner_case_metadata,
72
+ conclusion_name=case_query.attribute_name)
73
+
54
74
  def _post_detach(self, parent):
55
75
  """
56
76
  Called after this node is detached from the tree, useful when drawing the tree.
@@ -82,6 +102,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
82
102
  """
83
103
  pass
84
104
 
105
+ def write_corner_case_as_source_code(self, cases_file: Path) -> None:
106
+ """
107
+ Write the source code representation of the corner case of the rule to a file.
108
+
109
+ :param cases_file: The file to write the corner case to if it is a definition.
110
+ """
111
+ if self.corner_case_metadata is None:
112
+ return
113
+ types_to_import = set()
114
+ if self.corner_case_metadata.factory_method is not None:
115
+ types_to_import.add(self.corner_case_metadata.factory_method)
116
+ if self.corner_case_metadata.scenario is not None:
117
+ types_to_import.add(self.corner_case_metadata.scenario)
118
+ if self.corner_case_metadata.case_conf is not None:
119
+ types_to_import.add(self.corner_case_metadata.case_conf)
120
+ types_to_import.add(CaseFactoryMetaData)
121
+ imports = get_imports_from_types(list(types_to_import))
122
+ with open(cases_file, 'a') as f:
123
+ f.write("\n".join(imports) + "\n\n\n")
124
+ f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
125
+
85
126
  def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
86
127
  """
87
128
  Get the source code representation of the conclusion of the rule.
@@ -150,11 +191,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
150
191
  pass
151
192
 
152
193
  def _to_json(self) -> Dict[str, Any]:
153
- json_serialization = {"conditions": self.conditions.to_json(),
194
+ try:
195
+ corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
196
+ except Exception as e:
197
+ logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
198
+ corner_case = None
199
+ json_serialization = {"_type": get_full_class_name(type(self)),
200
+ "conditions": self.conditions.to_json(),
154
201
  "conclusion": conclusion_to_json(self.conclusion),
155
202
  "parent": self.parent.json_serialization if self.parent else None,
156
- "corner_case": SubclassJSONSerializer.to_json_static(self.corner_case)
157
- if self.corner_case else None,
203
+ "corner_case": corner_case,
158
204
  "conclusion_name": self.conclusion_name,
159
205
  "weight": self.weight,
160
206
  "uid": self.uid}
@@ -277,9 +323,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
277
323
  returned_rule = self.alternative(x) if self.alternative else self
278
324
  return returned_rule if returned_rule.fired else self
279
325
 
280
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
281
- new_rule = SingleClassRule(conditions, target,
282
- corner_case=x, parent=self)
326
+ def fit_rule(self, case_query: CaseQuery):
327
+ corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
328
+ new_rule = SingleClassRule(case_query.conditions, case_query.target,
329
+ corner_case=case_query.case, parent=self,
330
+ corner_case_metadata=corner_case_metadata,
331
+ )
283
332
  if self.fired:
284
333
  self.refinement = new_rule
285
334
  else:
@@ -363,11 +412,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
363
412
  elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
364
413
  return self.alternative
365
414
 
366
- def fit_rule(self, x: Case, target: CallableExpression, conditions: CallableExpression):
367
- if self.fired and target != self.conclusion:
368
- self.refinement = MultiClassStopRule(conditions, corner_case=x, parent=self)
415
+ def fit_rule(self, case_query: CaseQuery):
416
+ if self.fired and case_query.target != self.conclusion:
417
+ self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
369
418
  elif not self.fired:
370
- self.alternative = MultiClassTopRule(conditions, target, corner_case=x, parent=self)
419
+ self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
420
+ corner_case=case_query.case, parent=self)
371
421
 
372
422
  def _to_json(self) -> Dict[str, Any]:
373
423
  self.json_serialization = {**Rule._to_json(self),
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
281
281
  main_obj: Optional[Dict[str, Any]] = None
282
282
  user_input: Optional[str] = None
283
283
  attributes_widget: Optional[QWidget] = None
284
- save_function: Optional[Callable[str], None] = None
284
+ save_function: Optional[Callable[str, str], None] = None
285
285
 
286
-
287
- def __init__(self, parent=None, save_file: Optional[str] = None):
286
+ def __init__(self, parent=None,
287
+ save_dir: Optional[str] = None,
288
+ save_model_name: Optional[str] = None):
288
289
  super().__init__(parent)
289
- self.save_file = save_file
290
+ self.save_dir = save_dir
291
+ self.save_model_name = save_model_name
290
292
 
291
293
  self.setWindowTitle("RDR Case Viewer")
292
294
 
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
323
325
 
324
326
  # Add both to main layout
325
327
  main_layout.addWidget(self.attributes_widget, stretch=1)
326
- main_layout.addWidget(middle_widget, stretch=2)
328
+ main_layout.addWidget(middle_widget, stretch=1)
327
329
  main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
328
330
 
329
- def set_save_function(self, save_function: Callable[[str], None]) -> None:
331
+ def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
330
332
  """
331
333
  Set the function to save the file.
332
334
 
333
335
  :param save_function: The function to save the file.
334
336
  """
335
337
  self.save_function = save_function
336
- self.save_btn.clicked.connect(lambda: self.save_function(self.save_file))
338
+ self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
337
339
 
338
340
  def print(self, msg):
339
341
  """
@@ -134,7 +134,7 @@ class IPythonShell:
134
134
  """
135
135
  Update the user input from the code lines captured in the shell.
136
136
  """
137
- if len(self.shell.all_lines) == 1 and self.shell.all_lines[0].replace('return', '').strip() == '':
137
+ if self.shell.all_lines[0].replace('return', '').strip() == '':
138
138
  self.user_input = None
139
139
  else:
140
140
  self.all_code_lines = extract_dependencies(self.shell.all_lines)
@@ -1,5 +1,9 @@
1
1
  import logging
2
2
 
3
+ from ripple_down_rules.datastructures.case import Case
4
+ from ripple_down_rules.datastructures.dataclasses import CaseQuery
5
+ from ripple_down_rules.utils import SubclassJSONSerializer
6
+
3
7
  try:
4
8
  import graphviz
5
9
  except ImportError:
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
77
81
  for attr in dir(obj):
78
82
  if attr.startswith('_'):
79
83
  continue
80
- if attr == 'scope':
84
+ if isinstance(obj, CaseQuery) and attr == 'scope':
85
+ continue
86
+ if isinstance(obj, Case) and attr in ['data']:
87
+ continue
88
+ if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
81
89
  continue
82
90
  value = getattr(obj, attr)
83
91
  if callable(value):
@@ -8,13 +8,13 @@ from functools import cached_property
8
8
  from textwrap import indent, dedent
9
9
 
10
10
  from colorama import Fore, Style
11
- from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
11
+ from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
12
12
 
13
13
  from ..datastructures.case import Case
14
14
  from ..datastructures.dataclasses import CaseQuery
15
15
  from ..datastructures.enums import Editor, PromptFor
16
- from ..utils import str_to_snake_case, get_imports_from_scope, make_list, typing_hint_to_str, \
17
- get_imports_from_types, extract_function_source, extract_imports
16
+ from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
17
+ get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
18
18
 
19
19
 
20
20
  def detect_available_editor() -> Optional[Editor]:
@@ -177,7 +177,7 @@ class TemplateFileCreator:
177
177
  for k, v in self.case_query.case.items():
178
178
  if (self.case_query.function_args_type_hints is not None
179
179
  and k in self.case_query.function_args_type_hints):
180
- func_args[k] = typing_hint_to_str(self.case_query.function_args_type_hints[k])[0]
180
+ func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
181
181
  else:
182
182
  func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
183
183
  func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
@@ -207,30 +207,25 @@ class TemplateFileCreator:
207
207
  for k, v in self.case_query.case.items():
208
208
  if (self.case_query.function_args_type_hints is not None
209
209
  and k in self.case_query.function_args_type_hints):
210
- hint_list = typing_hint_to_str(self.case_query.function_args_type_hints[k])[1]
211
- for hint in hint_list:
212
- hint_split = hint.split('.')
213
- if len(hint_split) > 1:
214
- case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
210
+ types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
211
+ case_type_imports.extend(list(types_to_import))
215
212
  else:
216
- if isinstance(v, type):
217
- case_type_imports.append(f"from {v.__module__} import {v.__name__}")
218
- elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
219
- case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
213
+ case_type_imports.append(v)
220
214
  else:
221
- case_type_imports.append(f"from {self.case_type.__module__} import {self.case_type.__name__}")
215
+ case_type_imports.append(self.case_type)
222
216
  if self.output_type is None:
223
- output_type_imports = [f"from typing_extensions import Any"]
217
+ output_type_imports = [Any]
224
218
  else:
225
- output_type_imports = get_imports_from_types(self.output_type)
219
+ output_type_imports = self.output_type
226
220
  if len(self.output_type) > 1:
227
- output_type_imports.append("from typing_extensions import Union")
221
+ output_type_imports.append(Union)
228
222
  if list in self.output_type:
229
- output_type_imports.append("from typing_extensions import List")
230
- imports = get_imports_from_scope(self.case_query.scope)
231
- imports = [i for i in imports if ("get_ipython" not in i)]
232
- imports.extend(case_type_imports)
233
- imports.extend([oti for oti in output_type_imports if oti not in imports])
223
+ output_type_imports.append(List)
224
+ import_types = list(self.case_query.scope.values())
225
+ # imports = [i for i in imports if ("get_ipython" not in i)]
226
+ import_types.extend(case_type_imports)
227
+ import_types.extend(output_type_imports)
228
+ imports = get_imports_from_types(import_types)
234
229
  imports = set(imports)
235
230
  return '\n'.join(imports)
236
231