ripple-down-rules 0.0.15__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ripple_down_rules/datasets.py +2 -2
- ripple_down_rules/datastructures/callable_expression.py +52 -10
- ripple_down_rules/datastructures/case.py +54 -70
- ripple_down_rules/datastructures/dataclasses.py +69 -29
- ripple_down_rules/experts.py +29 -40
- ripple_down_rules/helpers.py +27 -0
- ripple_down_rules/prompt.py +77 -24
- ripple_down_rules/rdr.py +218 -200
- ripple_down_rules/rdr_decorators.py +55 -0
- ripple_down_rules/rules.py +7 -2
- ripple_down_rules/utils.py +167 -3
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/METADATA +1 -1
- ripple_down_rules-0.1.1.dist-info/RECORD +20 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/WHEEL +1 -1
- ripple_down_rules-0.0.15.dist-info/RECORD +0 -18
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.0.15.dist-info → ripple_down_rules-0.1.1.dist-info}/top_level.txt +0 -0
ripple_down_rules/experts.py
CHANGED
@@ -4,7 +4,7 @@ import json
|
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
|
6
6
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, MappedColumn as SQLColumn, Session
|
7
|
-
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any
|
7
|
+
from typing_extensions import Optional, Dict, TYPE_CHECKING, List, Tuple, Type, Union, Any, get_type_hints
|
8
8
|
|
9
9
|
from .datastructures import (Case, PromptFor, CallableExpression, CaseAttribute, CaseQuery)
|
10
10
|
from .datastructures.case import show_current_and_corner_cases
|
@@ -36,14 +36,13 @@ class Expert(ABC):
|
|
36
36
|
"""
|
37
37
|
|
38
38
|
@abstractmethod
|
39
|
-
def ask_for_conditions(self,
|
39
|
+
def ask_for_conditions(self, case_query: CaseQuery, last_evaluated_rule: Optional[Rule] = None) \
|
40
40
|
-> CallableExpression:
|
41
41
|
"""
|
42
42
|
Ask the expert to provide the differentiating features between two cases or unique features for a case
|
43
43
|
that doesn't have a corner case to compare to.
|
44
44
|
|
45
|
-
:param
|
46
|
-
:param targets: The target categories to compare the case with.
|
45
|
+
:param case_query: The case query containing the case to classify and the required target.
|
47
46
|
:param last_evaluated_rule: The last evaluated rule.
|
48
47
|
:return: The differentiating features as new rule conditions.
|
49
48
|
"""
|
@@ -76,13 +75,11 @@ class Expert(ABC):
|
|
76
75
|
"""
|
77
76
|
pass
|
78
77
|
|
79
|
-
def ask_for_conclusion(self, case_query: CaseQuery
|
80
|
-
session: Optional[Session] = None) -> Optional[CallableExpression]:
|
78
|
+
def ask_for_conclusion(self, case_query: CaseQuery) -> Optional[CallableExpression]:
|
81
79
|
"""
|
82
80
|
Ask the expert to provide a relational conclusion for the case.
|
83
81
|
|
84
82
|
:param case_query: The case query containing the case to find a conclusion for.
|
85
|
-
:param session: The sqlalchemy orm session to use if the case is a Table.
|
86
83
|
:return: A callable expression that can be called with a new case as an argument.
|
87
84
|
"""
|
88
85
|
|
@@ -124,37 +121,33 @@ class Human(Expert):
|
|
124
121
|
with open(path + '.json', "r") as f:
|
125
122
|
self.all_expert_answers = json.load(f)
|
126
123
|
|
127
|
-
def ask_for_conditions(self,
|
128
|
-
targets: Union[List[CaseAttribute], List[SQLColumn]],
|
124
|
+
def ask_for_conditions(self, case_query: CaseQuery,
|
129
125
|
last_evaluated_rule: Optional[Rule] = None) \
|
130
126
|
-> CallableExpression:
|
131
127
|
if not self.use_loaded_answers:
|
132
|
-
show_current_and_corner_cases(case,
|
133
|
-
|
128
|
+
show_current_and_corner_cases(case_query.case, {case_query.attribute_name: case_query.target},
|
129
|
+
last_evaluated_rule=last_evaluated_rule)
|
130
|
+
return self._get_conditions(case_query)
|
134
131
|
|
135
|
-
def _get_conditions(self,
|
132
|
+
def _get_conditions(self, case_query: CaseQuery) \
|
136
133
|
-> CallableExpression:
|
137
134
|
"""
|
138
135
|
Ask the expert to provide the differentiating features between two cases or unique features for a case
|
139
136
|
that doesn't have a corner case to compare to.
|
140
137
|
|
141
|
-
:param
|
142
|
-
:param targets: The target categories to compare the case with.
|
138
|
+
:param case_query: The case query containing the case to classify.
|
143
139
|
:return: The differentiating features as new rule conditions.
|
144
140
|
"""
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
user_input =
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
user_input, condition = prompt_user_for_expression(case, PromptFor.Conditions, target_name, bool)
|
156
|
-
if not self.use_loaded_answers:
|
157
|
-
self.all_expert_answers.append(user_input)
|
141
|
+
user_input = None
|
142
|
+
if self.use_loaded_answers:
|
143
|
+
user_input = self.all_expert_answers.pop(0)
|
144
|
+
if user_input:
|
145
|
+
condition = CallableExpression(user_input, bool, scope=case_query.scope, session=self.session)
|
146
|
+
else:
|
147
|
+
user_input, condition = prompt_user_for_expression(case_query, PromptFor.Conditions, session=self.session)
|
148
|
+
if not self.use_loaded_answers:
|
149
|
+
self.all_expert_answers.append(user_input)
|
150
|
+
case_query.conditions = condition
|
158
151
|
return condition
|
159
152
|
|
160
153
|
def ask_for_extra_conclusions(self, case: Case, current_conclusions: List[CaseAttribute]) \
|
@@ -172,30 +165,26 @@ class Human(Expert):
|
|
172
165
|
category = self.ask_for_conclusion(CaseQuery(case), current_conclusions)
|
173
166
|
if not category:
|
174
167
|
break
|
175
|
-
extra_conclusions[category] = self._get_conditions(case, category)
|
168
|
+
extra_conclusions[category] = self._get_conditions(case, {category.__class__.__name__: category})
|
176
169
|
return extra_conclusions
|
177
170
|
|
178
|
-
def ask_for_conclusion(self, case_query: CaseQuery
|
179
|
-
current_conclusions: Optional[List[Any]] = None)\
|
180
|
-
-> Optional[CallableExpression]:
|
171
|
+
def ask_for_conclusion(self, case_query: CaseQuery) -> CaseQuery:
|
181
172
|
"""
|
182
173
|
Ask the expert to provide a conclusion for the case.
|
183
174
|
|
184
175
|
:param case_query: The case query containing the case to find a conclusion for.
|
185
|
-
:
|
186
|
-
:return: The conclusion for the case.
|
176
|
+
:return: The case query updated with the conclusion for the case.
|
187
177
|
"""
|
188
|
-
case = case_query.case
|
189
|
-
attribute_name = case_query.attribute_name
|
190
|
-
attribute_type = case_query.attribute_type
|
191
178
|
if self.use_loaded_answers:
|
192
179
|
expert_input = self.all_expert_answers.pop(0)
|
193
|
-
expression = CallableExpression(expert_input,
|
180
|
+
expression = CallableExpression(expert_input, case_query.attribute_type, session=self.session,
|
181
|
+
scope=case_query.scope)
|
194
182
|
else:
|
195
|
-
show_current_and_corner_cases(case
|
196
|
-
expert_input, expression = prompt_user_for_expression(
|
197
|
-
|
183
|
+
show_current_and_corner_cases(case_query.case)
|
184
|
+
expert_input, expression = prompt_user_for_expression(case_query, PromptFor.Conclusion,
|
185
|
+
session=self.session)
|
198
186
|
self.all_expert_answers.append(expert_input)
|
187
|
+
case_query.target = expression
|
199
188
|
return expression
|
200
189
|
|
201
190
|
def get_category_type(self, cat_name: str) -> Optional[Type[CaseAttribute]]:
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
from sqlalchemy.orm import Session
|
4
|
+
from typing_extensions import Type, Optional
|
5
|
+
|
6
|
+
from ripple_down_rules.rdr import RippleDownRules
|
7
|
+
from ripple_down_rules.utils import get_func_rdr_model_path
|
8
|
+
|
9
|
+
|
10
|
+
def load_or_create_func_rdr_model(func, model_dir: str, rdr_type: Type[RippleDownRules],
|
11
|
+
session: Optional[Session] = None, **rdr_kwargs) -> RippleDownRules:
|
12
|
+
"""
|
13
|
+
Load the RDR model of the function if it exists, otherwise create a new one.
|
14
|
+
|
15
|
+
:param func: The function to load the model for.
|
16
|
+
:param model_dir: The directory where the model is stored.
|
17
|
+
:param rdr_type: The type of the RDR model to load.
|
18
|
+
:param session: The SQLAlchemy session to use.
|
19
|
+
:param rdr_kwargs: Additional arguments to pass to the RDR constructor in the case of a new model.
|
20
|
+
"""
|
21
|
+
model_path = get_func_rdr_model_path(func, model_dir)
|
22
|
+
if os.path.exists(model_path):
|
23
|
+
rdr = rdr_type.load(model_path)
|
24
|
+
rdr.session = session
|
25
|
+
else:
|
26
|
+
rdr = rdr_type(session=session, **rdr_kwargs)
|
27
|
+
return rdr
|
ripple_down_rules/prompt.py
CHANGED
@@ -2,31 +2,86 @@ import ast
|
|
2
2
|
import logging
|
3
3
|
from _ast import AST
|
4
4
|
|
5
|
+
from IPython.core.interactiveshell import ExecutionInfo
|
6
|
+
from IPython.terminal.embed import InteractiveShellEmbed
|
7
|
+
from traitlets.config import Config
|
5
8
|
from prompt_toolkit import PromptSession
|
6
9
|
from prompt_toolkit.completion import WordCompleter
|
7
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable, Session
|
8
11
|
from typing_extensions import Any, List, Optional, Tuple, Dict, Union, Type
|
9
12
|
|
10
|
-
from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression
|
13
|
+
from .datastructures import Case, PromptFor, CallableExpression, create_case, parse_string_to_expression, CaseQuery
|
14
|
+
from .utils import capture_variable_assignment
|
11
15
|
|
12
16
|
|
13
|
-
|
14
|
-
output_type: Type, session: Optional[Session] = None) -> Tuple[str, CallableExpression]:
|
17
|
+
class IpythonShell:
|
15
18
|
"""
|
16
|
-
|
19
|
+
Create an embedded Ipython shell that can be used to prompt the user for input.
|
20
|
+
"""
|
21
|
+
def __init__(self, variable_to_capture: str, scope: Optional[Dict] = None, header: Optional[str] = None):
|
22
|
+
"""
|
23
|
+
Initialize the Ipython shell with the given scope and header.
|
24
|
+
|
25
|
+
:param variable_to_capture: The variable to capture from the user input.
|
26
|
+
:param scope: The scope to use for the shell.
|
27
|
+
:param header: The header to display when the shell is started.
|
28
|
+
"""
|
29
|
+
self.variable_to_capture: str = variable_to_capture
|
30
|
+
self.scope: Dict = scope or {}
|
31
|
+
self.header: str = header or ">>> Embedded Ipython Shell"
|
32
|
+
self.user_input: Optional[str] = None
|
33
|
+
self.shell: InteractiveShellEmbed = self._init_shell()
|
34
|
+
self._register_hooks()
|
35
|
+
|
36
|
+
def _init_shell(self):
|
37
|
+
"""
|
38
|
+
Initialize the Ipython shell with a custom configuration.
|
39
|
+
"""
|
40
|
+
cfg = Config()
|
41
|
+
shell = InteractiveShellEmbed(config=cfg, user_ns=self.scope, banner1=self.header)
|
42
|
+
return shell
|
43
|
+
|
44
|
+
def _register_hooks(self):
|
45
|
+
"""
|
46
|
+
Register hooks to capture specific events in the Ipython shell.
|
47
|
+
"""
|
48
|
+
def capture_variable(exec_info: ExecutionInfo):
|
49
|
+
code = exec_info.raw_cell
|
50
|
+
if self.variable_to_capture not in code:
|
51
|
+
return
|
52
|
+
# use ast to find if the user is assigning a value to the variable "condition"
|
53
|
+
assignment = capture_variable_assignment(code, self.variable_to_capture)
|
54
|
+
if assignment:
|
55
|
+
# if the user is assigning a value to the variable "condition", update the raw_condition
|
56
|
+
self.user_input = assignment
|
57
|
+
print(f"[Captured {self.variable_to_capture}]:\n{self.user_input}")
|
58
|
+
|
59
|
+
self.shell.events.register('pre_run_cell', capture_variable)
|
60
|
+
|
61
|
+
def run(self):
|
62
|
+
"""
|
63
|
+
Run the embedded shell.
|
64
|
+
"""
|
65
|
+
self.shell()
|
66
|
+
|
67
|
+
|
68
|
+
def prompt_user_for_expression(case_query: CaseQuery, prompt_for: PromptFor,
|
69
|
+
session: Optional[Session] = None) -> Tuple[str, CallableExpression]:
|
70
|
+
"""
|
71
|
+
Prompt the user for an executable python expression to the given case query.
|
17
72
|
|
18
|
-
:param
|
73
|
+
:param case_query: The case query to prompt the user for.
|
19
74
|
:param prompt_for: The type of information ask user about.
|
20
|
-
:param target_name: The name of the target attribute to compare the case with.
|
21
|
-
:param output_type: The type of the output of the given statement from the user.
|
22
75
|
:param session: The sqlalchemy orm session.
|
23
76
|
:return: A callable expression that takes a case and executes user expression on it.
|
24
77
|
"""
|
25
78
|
while True:
|
26
|
-
user_input, expression_tree = prompt_user_about_case(
|
27
|
-
|
79
|
+
user_input, expression_tree = prompt_user_about_case(case_query, prompt_for)
|
80
|
+
conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
|
81
|
+
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
82
|
+
scope=case_query.scope, session=session)
|
28
83
|
try:
|
29
|
-
callable_expression(case)
|
84
|
+
callable_expression(case_query.case)
|
30
85
|
break
|
31
86
|
except Exception as e:
|
32
87
|
logging.error(e)
|
@@ -34,19 +89,18 @@ def prompt_user_for_expression(case: Union[Case, SQLTable], prompt_for: PromptFo
|
|
34
89
|
return user_input, callable_expression
|
35
90
|
|
36
91
|
|
37
|
-
def prompt_user_about_case(
|
38
|
-
-> Tuple[str, AST]:
|
92
|
+
def prompt_user_about_case(case_query: CaseQuery, prompt_for: PromptFor) -> Tuple[str, AST]:
|
39
93
|
"""
|
40
94
|
Prompt the user for input.
|
41
95
|
|
42
|
-
:param
|
96
|
+
:param case_query: The case query to prompt the user for.
|
43
97
|
:param prompt_for: The type of information the user should provide for the given case.
|
44
|
-
:param target_name: The name of the target property of the case that is queried.
|
45
98
|
:return: The user input, and the executable expression that was parsed from the user input.
|
46
99
|
"""
|
47
|
-
prompt_str = f"Give {prompt_for} for {
|
48
|
-
|
49
|
-
|
100
|
+
prompt_str = f"Give {prompt_for} for {case_query.name}"
|
101
|
+
scope = {'case': case_query.case, **case_query.scope}
|
102
|
+
shell = IpythonShell(prompt_for.value, scope=scope, header=prompt_str)
|
103
|
+
user_input, expression_tree = prompt_user_input_and_parse_to_expression(shell=shell)
|
50
104
|
return user_input, expression_tree
|
51
105
|
|
52
106
|
|
@@ -64,21 +118,20 @@ def get_completions(obj: Any) -> List[str]:
|
|
64
118
|
return completions
|
65
119
|
|
66
120
|
|
67
|
-
def prompt_user_input_and_parse_to_expression(
|
121
|
+
def prompt_user_input_and_parse_to_expression(shell: Optional[IpythonShell] = None,
|
68
122
|
user_input: Optional[str] = None) -> Tuple[str, ast.AST]:
|
69
123
|
"""
|
70
124
|
Prompt the user for input.
|
71
125
|
|
72
|
-
:param
|
73
|
-
:param session: The prompt session to use.
|
126
|
+
:param shell: The Ipython shell to use for prompting the user.
|
74
127
|
:param user_input: The user input to use. If given, the user input will be used instead of prompting the user.
|
75
128
|
:return: The user input and the AST tree.
|
76
129
|
"""
|
77
130
|
while True:
|
78
|
-
if
|
79
|
-
|
80
|
-
|
81
|
-
|
131
|
+
if user_input is None:
|
132
|
+
shell = IpythonShell() if shell is None else shell
|
133
|
+
shell.run()
|
134
|
+
user_input = shell.user_input
|
82
135
|
try:
|
83
136
|
return user_input, parse_string_to_expression(user_input)
|
84
137
|
except Exception as e:
|