ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ import json
4
4
  from abc import ABC, abstractmethod
5
5
 
6
6
  from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
7
- from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any
7
+ from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
8
8
 
9
9
  from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
10
10
  from .datastructures.case import show_current_and_corner_cases
@@ -36,14 +36,13 @@ class Expert(ABC):
36
36
  """
37
37
 
38
38
  @abstractmethod
39
- def ask_for_conditions(self, x: Case, targets: List[CaseAttribute], last_evaluated_rule: Optional[Rule] = None) \
39
+ def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
40
40
  -> CallableExpression:
41
41
  """
42
42
  Ask the expert to provide the differentiating features between two cases or unique features for a case
43
43
  that doesn't have a corner case to compare to.
44
44
 
45
- :param x: The case to classify.
46
- :param targets: The target categories to compare the case with.
45
+ :param case_query: The case query containing the case to classify and the required target.
47
46
  :param last_evaluated_rule: The last evaluated rule.
48
47
  :return: The differentiating features as new rule conditions.
49
48
  """
@@ -76,13 +75,11 @@ class Expert(ABC):
76
75
  """
77
76
  pass
78
77
 
79
- def ask_for_conclusion(self, case_query: CaseQuery,
80
- session: Optional[Session] = None) -> Optional[CallableExpression]:
78
+ def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
81
79
  """
82
80
  Ask the expert to provide a relational conclusion for the case.
83
81
 
84
82
  :param case_query: The case query containing the case to find a conclusion for.
85
- :param session: The sqlalchemy orm session to use if the case is a Table.
86
83
  :return: A callable expression that can be called with a new case as an argument.
87
84
  """
88
85
 
@@ -124,37 +121,33 @@ class Human(Expert):
124
121
  with open(path + '.json', "r") as f:
125
122
  self.all_expert_answers = json.load(f)
126
123
 
127
- def ask_for_conditions(self, case: Case,
128
- targets: Union[List[CaseAttribute], List[SQLColumn]],
124
+ def ask_for_conditions(self, case_query: CaseQuery,
129
125
  last_evaluated_rule: Optional[Rule] = None) \
130
126
  -> CallableExpression:
131
127
  if not self.use_loaded_answers:
132
- show_current_and_corner_cases(case, targets, last_evaluated_rule=last_evaluated_rule)
133
- return self._get_conditions(case, targets)
128
+ show_current_and_corner_cases(case_query.case, {case_query.attribute_name: case_query.target},
129
+ last_evaluated_rule=last_evaluated_rule)
130
+ return self._get_conditions(case_query)
134
131
 
135
- def _get_conditions(self, case: Case, targets: List[CaseAttribute]) \
132
+ def _get_conditions(self, case_query: CaseQuery) \
136
133
  -> CallableExpression:
137
134
  """
138
135
  Ask the expert to provide the differentiating features between two cases or unique features for a case
139
136
  that doesn't have a corner case to compare to.
140
137
 
141
- :param case: The case to classify.
142
- :param targets: The target categories to compare the case with.
138
+ :param case_query: The case query containing the case to classify.
143
139
  :return: The differentiating features as new rule conditions.
144
140
  """
145
- targets = targets if isinstance(targets, list) else [targets]
146
- condition = None
147
- for target in targets:
148
- target_name = target.__class__.__name__
149
- user_input = None
150
- if self.use_loaded_answers:
151
- user_input = self.all_expert_answers.pop(0)
152
- if user_input:
153
- condition = CallableExpression(user_input, bool, session=self.session)
154
- else:
155
- user_input, condition = prompt_user_for_expression(case, PromptFor.Conditions, target_name, bool)
156
- if not self.use_loaded_answers:
157
- self.all_expert_answers.append(user_input)
141
+ user_input = None
142
+ if self.use_loaded_answers:
143
+ user_input = self.all_expert_answers.pop(0)
144
+ if user_input:
145
+ condition = CallableExpression(user_input, bool, scope=case_query.scope, session=self.session)
146
+ else:
147
+ user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions, session=self.session)
148
+ if not self.use_loaded_answers:
149
+ self.all_expert_answers.append(user_input)
150
+ case_query.conditions = condition
158
151
  return condition
159
152
 
160
153
  def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
@@ -172,30 +165,26 @@ class Human(Expert):
172
165
  category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
173
166
  if not category:
174
167
  break
175
- extra_conclusions[category] = self._get_conditions(case, category)
168
+ extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
176
169
  return extra_conclusions
177
170
 
178
- def ask_for_conclusion(self, case_query: CaseQuery,
179
- current_conclusions: Optional[List[Any]] = None)\
180
- -> Optional[CallableExpression]:
171
+ def ask_for_conclusion(self, case_query: CaseQuery) -> CaseQuery:
181
172
  """
182
173
  Ask the expert to provide a conclusion for the case.
183
174
 
184
175
  :param case_query: The case query containing the case to find a conclusion for.
185
- :param current_conclusions: The current conclusions for the case if any.
186
- :return: The conclusion for the case.
176
+ :return: The case query updated with the conclusion for the case.
187
177
  """
188
- case = case_query.case
189
- attribute_name = case_query.attribute_name
190
- attribute_type = case_query.attribute_type
191
178
  if self.use_loaded_answers:
192
179
  expert_input = self.all_expert_answers.pop(0)
193
- expression = CallableExpression(expert_input, conclusion_type=attribute_type, session=self.session)
180
+ expression = CallableExpression(expert_input, case_query.attribute_type, session=self.session,
181
+ scope=case_query.scope)
194
182
  else:
195
- show_current_and_corner_cases(case, current_conclusions=current_conclusions)
196
- expert_input, expression = prompt_user_for_expression(case, PromptFor.Conclusion, attribute_name,
197
- attribute_type)
183
+ show_current_and_corner_cases(case_query.case)
184
+ expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
185
+ session=self.session)
198
186
  self.all_expert_answers.append(expert_input)
187
+ case_query.target = expression
199
188
  return expression
200
189
 
201
190
  def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
@@ -0,0 +1,27 @@
1
+ import os
2
+
3
+ from sqlalchemy.orm import Session
4
+ from typing_extensions import Type, Optional
5
+
6
+ from ripple_down_rules.rdr import RippleDownRules
7
+ from ripple_down_rules.utils import get_func_rdr_model_path
8
+
9
+
10
+ def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
11
+ session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
12
+ """
13
+ Load the RDR model of the function if it exists, otherwise create a new one.
14
+
15
+ :param func: The function to load the model for.
16
+ :param model_dir: The directory where the model is stored.
17
+ :param rdr_type: The type of the RDR model to load.
18
+ :param session: The SQLAlchemy session to use.
19
+ :param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
20
+ """
21
+ model_path = get_func_rdr_model_path(func, model_dir)
22
+ if os.path.exists(model_path):
23
+ rdr = rdr_type.load(model_path)
24
+ rdr.session = session
25
+ else:
26
+ rdr = rdr_type(session=session, **rdr_kwargs)
27
+ return rdr
@@ -2,31 +2,86 @@ import ast
2
2
  import logging
3
3
  from _ast import AST
4
4
 
5
+ from IPython.core.interactiveshell import ExecutionInfo
6
+ from IPython.terminal.embed import InteractiveShellEmbed
7
+ from traitlets.config import Config
5
8
  from prompt_toolkit import PromptSession
6
9
  from prompt_toolkit.completion import WordCompleter
7
10
  from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
8
11
  from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type
9
12
 
10
- from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression
13
+ from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression, CaseQuery
14
+ from .utils import capture_variable_assignment
11
15
 
12
16
 
13
- def prompt_user_for_expression(case: Union[Case, SQLTable], prompt_for: PromptFor, target_name: str,
14
- output_type: Type, session: Optional[Session] = None) -> Tuple[str, CallableExpression]:
17
+ class IpythonShell:
15
18
  """
16
- Prompt the user for an executable python expression.
19
+ Create an embedded Ipython shell that can be used to prompt the user for input.
20
+ """
21
+ def __init__(self, variable_to_capture: str, scope: Optional[Dict] = None, header: Optional[str] = None):
22
+ """
23
+ Initialize the Ipython shell with the given scope and header.
24
+
25
+ :param variable_to_capture: The variable to capture from the user input.
26
+ :param scope: The scope to use for the shell.
27
+ :param header: The header to display when the shell is started.
28
+ """
29
+ self.variable_to_capture: str = variable_to_capture
30
+ self.scope: Dict = scope or {}
31
+ self.header: str = header or ">>> Embedded Ipython Shell"
32
+ self.user_input: Optional[str] = None
33
+ self.shell: InteractiveShellEmbed = self._init_shell()
34
+ self._register_hooks()
35
+
36
+ def _init_shell(self):
37
+ """
38
+ Initialize the Ipython shell with a custom configuration.
39
+ """
40
+ cfg = Config()
41
+ shell = InteractiveShellEmbed(config=cfg, user_ns=self.scope, banner1=self.header)
42
+ return shell
43
+
44
+ def _register_hooks(self):
45
+ """
46
+ Register hooks to capture specific events in the Ipython shell.
47
+ """
48
+ def capture_variable(exec_info: ExecutionInfo):
49
+ code = exec_info.raw_cell
50
+ if self.variable_to_capture not in code:
51
+ return
52
+ # use ast to find if the user is assigning a value to the variable "condition"
53
+ assignment = capture_variable_assignment(code, self.variable_to_capture)
54
+ if assignment:
55
+ # if the user is assigning a value to the variable "condition", update the raw_condition
56
+ self.user_input = assignment
57
+ print(f"[Captured {self.variable_to_capture}]:\n{self.user_input}")
58
+
59
+ self.shell.events.register('pre_run_cell', capture_variable)
60
+
61
+ def run(self):
62
+ """
63
+ Run the embedded shell.
64
+ """
65
+ self.shell()
66
+
67
+
68
+ def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor,
69
+ session: Optional[Session] = None) -> Tuple[str, CallableExpression]:
70
+ """
71
+ Prompt the user for an executable python expression to the given case query.
17
72
 
18
- :param case: The case to classify.
73
+ :param case_query: The case query to prompt the user for.
19
74
  :param prompt_for: The type of information ask user about.
20
- :param target_name: The name of the target attribute to compare the case with.
21
- :param output_type: The type of the output of the given statement from the user.
22
75
  :param session: The sqlalchemy orm session.
23
76
  :return: A callable expression that takes a case and executes user expression on it.
24
77
  """
25
78
  while True:
26
- user_input, expression_tree = prompt_user_about_case(case, prompt_for, target_name)
27
- callable_expression = CallableExpression(user_input, output_type, expression_tree=expression_tree, session=session)
79
+ user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
80
+ conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
81
+ callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
82
+ scope=case_query.scope, session=session)
28
83
  try:
29
- callable_expression(case)
84
+ callable_expression(case_query.case)
30
85
  break
31
86
  except Exception as e:
32
87
  logging.error(e)
@@ -34,19 +89,18 @@ def prompt_user_for_expression(case: Union[Case, SQLTable], prompt_for: PromptFo
34
89
  return user_input, callable_expression
35
90
 
36
91
 
37
- def prompt_user_about_case(case: Union[Case, SQLTable], prompt_for: PromptFor, target_name: str) \
38
- -> Tuple[str, AST]:
92
+ def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor) -> Tuple[str, AST]:
39
93
  """
40
94
  Prompt the user for input.
41
95
 
42
- :param case: The case to prompt the user on.
96
+ :param case_query: The case query to prompt the user for.
43
97
  :param prompt_for: The type of information the user should provide for the given case.
44
- :param target_name: The name of the target property of the case that is queried.
45
98
  :return: The user input, and the executable expression that was parsed from the user input.
46
99
  """
47
- prompt_str = f"Give {prompt_for} for {case.__class__.__name__}.{target_name}"
48
- session = get_prompt_session_for_obj(case)
49
- user_input, expression_tree = prompt_user_input_and_parse_to_expression(prompt_str, session)
100
+ prompt_str = f"Give {prompt_for} for {case_query.name}"
101
+ scope = {'case': case_query.case, **case_query.scope}
102
+ shell = IpythonShell(prompt_for.value, scope=scope, header=prompt_str)
103
+ user_input, expression_tree = prompt_user_input_and_parse_to_expression(shell=shell)
50
104
  return user_input, expression_tree
51
105
 
52
106
 
@@ -64,21 +118,20 @@ def get_completions(obj: Any) -> List[str]:
64
118
  return completions
65
119
 
66
120
 
67
- def prompt_user_input_and_parse_to_expression(prompt: Optional[str] = None, session: Optional[PromptSession] = None,
121
+ def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = None,
68
122
  user_input: Optional[str] = None) -> Tuple[str, ast.AST]:
69
123
  """
70
124
  Prompt the user for input.
71
125
 
72
- :param prompt: The prompt to display to the user.
73
- :param session: The prompt session to use.
126
+ :param shell: The Ipython shell to use for prompting the user.
74
127
  :param user_input: The user input to use. If given, the user input will be used instead of prompting the user.
75
128
  :return: The user input and the AST tree.
76
129
  """
77
130
  while True:
78
- if not user_input:
79
- user_input = session.prompt(f"\n{prompt} >>> ")
80
- if user_input.lower() in ['exit', 'quit', '']:
81
- break
131
+ if user_input is None:
132
+ shell = IpythonShell() if shell is None else shell
133
+ shell.run()
134
+ user_input = shell.user_input
82
135
  try:
83
136
  return user_input, parse_string_to_expression(user_input)
84
137
  except Exception as e: