ripple-down-rules 0.1.21__py3-none-any.whl → 0.1.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,13 +3,15 @@ from __future__ import annotations
3
3
  import json
4
4
  from abc import ABC, abstractmethod
5
5
 
6
- from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
7
- from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
6
+ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Type, Any
8
7
 
9
- from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
8
+ from .datastructures.case import Case, CaseAttribute
9
+ from .datastructures.callable_expression import CallableExpression
10
+ from .datastructures.enums import PromptFor
11
+ from .datastructures.dataclasses import CaseQuery
10
12
  from .datastructures.case import show_current_and_corner_cases
11
- from .prompt import prompt_user_for_expression, prompt_user_about_case
12
- from .utils import get_all_subclasses, is_iterable
13
+ from .prompt import prompt_user_for_expression, IPythonShell
14
+ from .utils import get_all_subclasses, make_list
13
15
 
14
16
  if TYPE_CHECKING:
15
17
  from .rdr import Rule
@@ -49,28 +51,24 @@ class Expert(ABC):
49
51
  pass
50
52
 
51
53
  @abstractmethod
52
- def ask_for_extra_conclusions(self, x: Case, current_conclusions: List[CaseAttribute]) \
53
- -> Dict[CaseAttribute, CallableExpression]:
54
+ def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
54
55
  """
55
- Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
56
- that category.
56
+ Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
57
57
 
58
- :param x: The case to classify.
59
- :param current_conclusions: The current conclusions for the case.
60
- :return: The extra conclusions for the case.
58
+ :param case_query: The case query containing the case to classify.
59
+ :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
60
+ conclusion and conditions for the rule.
61
61
  """
62
62
  pass
63
63
 
64
64
  @abstractmethod
65
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
66
- targets: Optional[List[CaseAttribute]] = None,
67
- current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
65
+ def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
66
+ current_conclusions: Any) -> bool:
68
67
  """
69
68
  Ask the expert if the conclusion is correct.
70
69
 
71
- :param x: The case to classify.
70
+ :param case_query: The case query about which the expert should answer.
72
71
  :param conclusion: The conclusion to check.
73
- :param targets: The target categories to compare the case with.
74
72
  :param current_conclusions: The current conclusions for the case.
75
73
  """
76
74
  pass
@@ -89,10 +87,9 @@ class Human(Expert):
89
87
  The Human Expert class, an expert that asks the human to provide differentiating features and conclusions.
90
88
  """
91
89
 
92
- def __init__(self, use_loaded_answers: bool = False, session: Optional[Session] = None):
90
+ def __init__(self, use_loaded_answers: bool = False):
93
91
  self.all_expert_answers = []
94
92
  self.use_loaded_answers = use_loaded_answers
95
- self.session = session
96
93
 
97
94
  def save_answers(self, path: str, append: bool = False):
98
95
  """
@@ -129,6 +126,24 @@ class Human(Expert):
129
126
  last_evaluated_rule=last_evaluated_rule)
130
127
  return self._get_conditions(case_query)
131
128
 
129
+ def ask_for_extra_rules(self, case_query: CaseQuery) -> List[Dict[PromptFor, CallableExpression]]:
130
+ """
131
+ Ask the expert to provide extra rules for a case by providing a pair of conclusion and conditions.
132
+
133
+ :param case_query: The case query containing the case to classify.
134
+ :return: The extra rules for the case as a list of dictionaries, where each dictionary contains the
135
+ conclusion and conditions for the rule.
136
+ """
137
+ rules = []
138
+ while True:
139
+ conclusion = self.ask_for_conclusion(case_query)
140
+ if conclusion is None:
141
+ break
142
+ conditions = self._get_conditions(case_query)
143
+ rules.append({PromptFor.Conclusion: conclusion,
144
+ PromptFor.Conditions: conditions})
145
+ return rules
146
+
132
147
  def _get_conditions(self, case_query: CaseQuery) \
133
148
  -> CallableExpression:
134
149
  """
@@ -142,47 +157,30 @@ class Human(Expert):
142
157
  if self.use_loaded_answers:
143
158
  user_input = self.all_expert_answers.pop(0)
144
159
  if user_input:
145
- condition = CallableExpression(user_input, bool, scope=case_query.scope, session=self.session)
160
+ condition = CallableExpression(user_input, bool, scope=case_query.scope)
146
161
  else:
147
- user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions, session=self.session)
162
+ user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions)
148
163
  if not self.use_loaded_answers:
149
164
  self.all_expert_answers.append(user_input)
150
165
  case_query.conditions = condition
151
166
  return condition
152
167
 
153
- def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
154
- -> Dict[CaseAttribute, CallableExpression]:
155
- """
156
- Ask the expert to provide extra conclusions for a case by providing a pair of category and conditions for
157
- that category.
158
-
159
- :param case: The case to classify.
160
- :param current_conclusions: The current conclusions for the case.
161
- :return: The extra conclusions for the case.
162
- """
163
- extra_conclusions = {}
164
- while True:
165
- category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
166
- if not category:
167
- break
168
- extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
169
- return extra_conclusions
170
-
171
- def ask_for_conclusion(self, case_query: CaseQuery) -> CaseQuery:
168
+ def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
172
169
  """
173
170
  Ask the expert to provide a conclusion for the case.
174
171
 
175
172
  :param case_query: The case query containing the case to find a conclusion for.
176
- :return: The case query updated with the conclusion for the case.
173
+ :return: The conclusion for the case as a callable expression.
177
174
  """
175
+ expression: Optional[CallableExpression] = None
178
176
  if self.use_loaded_answers:
179
177
  expert_input = self.all_expert_answers.pop(0)
180
- expression = CallableExpression(expert_input, case_query.attribute_type, session=self.session,
181
- scope=case_query.scope)
178
+ if expert_input is not None:
179
+ expression = CallableExpression(expert_input, case_query.attribute_type,
180
+ scope=case_query.scope)
182
181
  else:
183
182
  show_current_and_corner_cases(case_query.case)
184
- expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
185
- session=self.session)
183
+ expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion)
186
184
  self.all_expert_answers.append(expert_input)
187
185
  case_query.target = expression
188
186
  return expression
@@ -195,7 +193,8 @@ class Human(Expert):
195
193
  :return: The category type.
196
194
  """
197
195
  cat_name = cat_name.lower()
198
- self.known_categories = get_all_subclasses(CaseAttribute) if not self.known_categories else self.known_categories
196
+ self.known_categories = get_all_subclasses(
197
+ CaseAttribute) if not self.known_categories else self.known_categories
199
198
  self.known_categories.update(CaseAttribute.registry)
200
199
  category_type = None
201
200
  if cat_name in self.known_categories:
@@ -209,45 +208,39 @@ class Human(Expert):
209
208
  :param category_name: The name of the category to ask about.
210
209
  """
211
210
  question = f"Can a case have multiple values of the new category {category_name}? (y/n):"
212
- return not self.ask_yes_no_question(question)
211
+ return not self.ask_for_affirmation(question)
213
212
 
214
- def ask_if_conclusion_is_correct(self, x: Case, conclusion: CaseAttribute,
215
- targets: Optional[List[CaseAttribute]] = None,
216
- current_conclusions: Optional[List[CaseAttribute]] = None) -> bool:
213
+ def ask_if_conclusion_is_correct(self, case_query: CaseQuery, conclusion: Any,
214
+ current_conclusions: Any) -> bool:
217
215
  """
218
216
  Ask the expert if the conclusion is correct.
219
217
 
220
- :param x: The case to classify.
218
+ :param case_query: The case query about which the expert should answer.
221
219
  :param conclusion: The conclusion to check.
222
- :param targets: The target categories to compare the case with.
223
220
  :param current_conclusions: The current conclusions for the case.
224
221
  """
225
- question = ""
226
222
  if not self.use_loaded_answers:
227
- targets = targets or []
228
- targets = targets if isinstance(targets, list) else [targets]
229
- x.conclusions = current_conclusions
230
- x.targets = targets
231
- question = f"Is the conclusion {conclusion} correct for the case (y/n):" \
232
- f"\n{str(x)}"
233
- return self.ask_yes_no_question(question)
223
+ print(f"Current conclusions: {current_conclusions}")
224
+ return self.ask_for_affirmation(case_query,
225
+ f"Is the conclusion {conclusion} correct for the case (True/False):")
234
226
 
235
- def ask_yes_no_question(self, question: str) -> bool:
227
+ def ask_for_affirmation(self, case_query: CaseQuery, question: str) -> bool:
236
228
  """
237
229
  Ask the expert a yes or no question.
238
230
 
239
- :param question: The question to ask.
231
+ :param case_query: The case query about which the expert should answer.
232
+ :param question: The question to ask the expert.
240
233
  :return: The answer to the question.
241
234
  """
242
- if not self.use_loaded_answers:
243
- print(question)
244
235
  while True:
245
236
  if self.use_loaded_answers:
246
237
  answer = self.all_expert_answers.pop(0)
247
238
  else:
248
- answer = input()
249
- self.all_expert_answers.append(answer)
250
- if answer.lower() == "y":
239
+ _, expression = prompt_user_for_expression(case_query, PromptFor.Affirmation, question)
240
+ answer = expression(case_query.case)
241
+ if answer:
242
+ self.all_expert_answers.append(True)
251
243
  return True
252
- elif answer.lower() == "n":
244
+ else:
245
+ self.all_expert_answers.append(False)
253
246
  return False
@@ -1,10 +1,34 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
 
5
+ from .datastructures.dataclasses import CaseQuery
3
6
  from sqlalchemy.orm import Session
4
- from typing_extensions import Type, Optional
7
+ from typing_extensions import Type, Optional, Callable, Any, Dict, TYPE_CHECKING
8
+
9
+ from .utils import get_func_rdr_model_path
10
+ from .utils import calculate_precision_and_recall
11
+
12
+ if TYPE_CHECKING:
13
+ from .rdr import RippleDownRules
5
14
 
6
- from ripple_down_rules.rdr import RippleDownRules
7
- from ripple_down_rules.utils import get_func_rdr_model_path
15
+
16
+ def is_matching(classifier: Callable[[Any], Any], case_query: CaseQuery, pred_cat: Optional[Dict[str, Any]] = None) -> bool:
17
+ """
18
+ :param classifier: The RDR classifier to check the prediction of.
19
+ :param case_query: The case query to check.
20
+ :param pred_cat: The predicted category.
21
+ :return: Whether the classifier prediction is matching case_query target or not.
22
+ """
23
+ if case_query.target is None:
24
+ return False
25
+ if pred_cat is None:
26
+ pred_cat = classifier(case_query.case)
27
+ if not isinstance(pred_cat, dict):
28
+ pred_cat = {case_query.attribute_name: pred_cat}
29
+ target = {case_query.attribute_name: case_query.target_value}
30
+ precision, recall = calculate_precision_and_recall(pred_cat, target)
31
+ return all(recall) and all(precision)
8
32
 
9
33
 
10
34
  def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
@@ -2,84 +2,118 @@ import ast
2
2
  import logging
3
3
  from _ast import AST
4
4
 
5
- from IPython.core.interactiveshell import ExecutionInfo
6
5
  from IPython.terminal.embed import InteractiveShellEmbed
7
6
  from traitlets.config import Config
8
- from prompt_toolkit import PromptSession
9
- from prompt_toolkit.completion import WordCompleter
10
- from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
11
- from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type
7
+ from typing_extensions import List, Optional, Tuple, Dict
12
8
 
13
- from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression, CaseQuery
14
- from .utils import capture_variable_assignment
9
+ from .datastructures.enums import PromptFor
10
+ from .datastructures.callable_expression import CallableExpression, parse_string_to_expression
11
+ from .datastructures.dataclasses import CaseQuery
12
+ from .utils import extract_dependencies, contains_return_statement, make_set
15
13
 
16
14
 
17
- class IpythonShell:
15
+ class CustomInteractiveShell(InteractiveShellEmbed):
16
+ def __init__(self, **kwargs):
17
+ super().__init__(**kwargs)
18
+ self.all_lines = []
19
+
20
+ def run_cell(self, raw_cell: str, **kwargs):
21
+ """
22
+ Override the run_cell method to capture return statements.
23
+ """
24
+ if contains_return_statement(raw_cell):
25
+ self.all_lines.append(raw_cell)
26
+ print("Exiting shell on `return` statement.")
27
+ self.history_manager.store_inputs(line_num=self.execution_count, source=raw_cell)
28
+ self.ask_exit()
29
+ return None
30
+ result = super().run_cell(raw_cell, **kwargs)
31
+ if result.error_in_exec is None and result.error_before_exec is None:
32
+ self.all_lines.append(raw_cell)
33
+ return result
34
+
35
+
36
+ class IPythonShell:
18
37
  """
19
38
  Create an embedded Ipython shell that can be used to prompt the user for input.
20
39
  """
21
- def __init__(self, variable_to_capture: str, scope: Optional[Dict] = None, header: Optional[str] = None):
40
+
41
+ def __init__(self, scope: Optional[Dict] = None, header: Optional[str] = None):
22
42
  """
23
43
  Initialize the Ipython shell with the given scope and header.
24
44
 
25
- :param variable_to_capture: The variable to capture from the user input.
26
45
  :param scope: The scope to use for the shell.
27
46
  :param header: The header to display when the shell is started.
28
47
  """
29
- self.variable_to_capture: str = variable_to_capture
30
48
  self.scope: Dict = scope or {}
31
49
  self.header: str = header or ">>> Embedded Ipython Shell"
32
50
  self.user_input: Optional[str] = None
33
- self.shell: InteractiveShellEmbed = self._init_shell()
34
- self._register_hooks()
51
+ self.shell: CustomInteractiveShell = self._init_shell()
52
+ self.all_code_lines: List[str] = []
35
53
 
36
54
  def _init_shell(self):
37
55
  """
38
56
  Initialize the Ipython shell with a custom configuration.
39
57
  """
40
58
  cfg = Config()
41
- shell = InteractiveShellEmbed(config=cfg, user_ns=self.scope, banner1=self.header)
59
+ shell = CustomInteractiveShell(config=cfg, user_ns=self.scope, banner1=self.header)
42
60
  return shell
43
61
 
44
- def _register_hooks(self):
45
- """
46
- Register hooks to capture specific events in the Ipython shell.
47
- """
48
- def capture_variable(exec_info: ExecutionInfo):
49
- code = exec_info.raw_cell
50
- if self.variable_to_capture not in code:
51
- return
52
- # use ast to find if the user is assigning a value to the variable "condition"
53
- assignment = capture_variable_assignment(code, self.variable_to_capture)
54
- if assignment:
55
- # if the user is assigning a value to the variable "condition", update the raw_condition
56
- self.user_input = assignment
57
- print(f"[Captured {self.variable_to_capture}]:\n{self.user_input}")
58
-
59
- self.shell.events.register('pre_run_cell', capture_variable)
60
-
61
62
  def run(self):
62
63
  """
63
64
  Run the embedded shell.
64
65
  """
65
- self.shell()
66
-
67
-
68
- def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor,
69
- session: Optional[Session] = None) -> Tuple[str, CallableExpression]:
66
+ while True:
67
+ try:
68
+ self.shell()
69
+ self.update_user_input_from_code_lines()
70
+ break
71
+ except Exception as e:
72
+ logging.error(e)
73
+ print(e)
74
+
75
+ def update_user_input_from_code_lines(self):
76
+ """
77
+ Update the user input from the code lines captured in the shell.
78
+ """
79
+ if len(self.shell.all_lines) == 1 and self.shell.all_lines[0].replace('return', '').strip() == '':
80
+ self.user_input = None
81
+ else:
82
+ self.all_code_lines = extract_dependencies(self.shell.all_lines)
83
+ if len(self.all_code_lines) == 1:
84
+ if self.all_code_lines[0].strip() == '':
85
+ self.user_input = None
86
+ else:
87
+ self.user_input = self.all_code_lines[0].replace('return', '').strip()
88
+ else:
89
+ self.user_input = f"def _get_value(case):\n "
90
+ for cl in self.all_code_lines:
91
+ sub_code_lines = cl.split('\n')
92
+ self.user_input += '\n '.join(sub_code_lines) + '\n '
93
+
94
+
95
+ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor, prompt_str: Optional[str] = None)\
96
+ -> Tuple[Optional[str], Optional[CallableExpression]]:
70
97
  """
71
98
  Prompt the user for an executable python expression to the given case query.
72
99
 
73
100
  :param case_query: The case query to prompt the user for.
74
101
  :param prompt_for: The type of information ask user about.
75
- :param session: The sqlalchemy orm session.
102
+ :param prompt_str: The prompt string to display to the user.
76
103
  :return: A callable expression that takes a case and executes user expression on it.
77
104
  """
78
105
  while True:
79
- user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
106
+ user_input, expression_tree = prompt_user_about_case(case_query, prompt_for, prompt_str)
107
+ if user_input is None:
108
+ if prompt_for == PromptFor.Conclusion:
109
+ print("No conclusion provided. Exiting.")
110
+ return None, None
111
+ else:
112
+ print("Conditions must be provided. Please try again.")
113
+ continue
80
114
  conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
81
115
  callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
82
- scope=case_query.scope, session=session)
116
+ scope=case_query.scope)
83
117
  try:
84
118
  callable_expression(case_query.case)
85
119
  break
@@ -89,37 +123,26 @@ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor,
89
123
  return user_input, callable_expression
90
124
 
91
125
 
92
- def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor) -> Tuple[str, AST]:
126
+ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor,
127
+ prompt_str: Optional[str] = None) -> Tuple[Optional[str], Optional[AST]]:
93
128
  """
94
129
  Prompt the user for input.
95
130
 
96
131
  :param case_query: The case query to prompt the user for.
97
132
  :param prompt_for: The type of information the user should provide for the given case.
133
+ :param prompt_str: The prompt string to display to the user.
98
134
  :return: The user input, and the executable expression that was parsed from the user input.
99
135
  """
100
- prompt_str = f"Give {prompt_for} for {case_query.name}"
136
+ if prompt_str is None:
137
+ prompt_str = f"Give {prompt_for} for {case_query.name}"
101
138
  scope = {'case': case_query.case, **case_query.scope}
102
- shell = IpythonShell(prompt_for.value, scope=scope, header=prompt_str)
103
- user_input, expression_tree = prompt_user_input_and_parse_to_expression(shell=shell)
104
- return user_input, expression_tree
105
-
106
-
107
- def get_completions(obj: Any) -> List[str]:
108
- """
109
- Get all completions for the object. This is used in the python prompt shell to provide completions for the user.
110
-
111
- :param obj: The object to get completions for.
112
- :return: A list of completions.
113
- """
114
- # Define completer with all object attributes and comparison operators
115
- completions = ['==', '!=', '>', '<', '>=', '<=', 'in', 'not', 'and', 'or', 'is']
116
- completions += ["isinstance(", "issubclass(", "type(", "len(", "hasattr(", "getattr(", "setattr(", "delattr("]
117
- completions += list(create_case(obj).keys())
118
- return completions
139
+ shell = IPythonShell(scope=scope, header=prompt_str)
140
+ return prompt_user_input_and_parse_to_expression(shell=shell)
119
141
 
120
142
 
121
- def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = None,
122
- user_input: Optional[str] = None) -> Tuple[str, ast.AST]:
143
+ def prompt_user_input_and_parse_to_expression(shell: Optional[IPythonShell] = None,
144
+ user_input: Optional[str] = None)\
145
+ -> Tuple[Optional[str], Optional[ast.AST]]:
123
146
  """
124
147
  Prompt the user for input.
125
148
 
@@ -129,9 +152,12 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
129
152
  """
130
153
  while True:
131
154
  if user_input is None:
132
- shell = IpythonShell() if shell is None else shell
155
+ shell = IPythonShell() if shell is None else shell
133
156
  shell.run()
134
157
  user_input = shell.user_input
158
+ if user_input is None:
159
+ return None, None
160
+ print(user_input)
135
161
  try:
136
162
  return user_input, parse_string_to_expression(user_input)
137
163
  except Exception as e:
@@ -139,16 +165,3 @@ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = No
139
165
  logging.error(msg)
140
166
  print(msg)
141
167
  user_input = None
142
-
143
-
144
- def get_prompt_session_for_obj(obj: Any) -> PromptSession:
145
- """
146
- Get a prompt session for an object.
147
-
148
- :param obj: The object to get the prompt session for.
149
- :return: The prompt session.
150
- """
151
- completions = get_completions(obj)
152
- completer = WordCompleter(completions)
153
- session = PromptSession(completer=completer)
154
- return session