ripple-down-rules 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +1 -1
- ripple_down_rules/datastructures/callable_expression.py +16 -9
- ripple_down_rules/datastructures/case.py +10 -4
- ripple_down_rules/datastructures/dataclasses.py +62 -3
- ripple_down_rules/experts.py +12 -2
- ripple_down_rules/helpers.py +55 -9
- ripple_down_rules/rdr.py +148 -101
- ripple_down_rules/rdr_decorators.py +54 -26
- ripple_down_rules/rules.py +63 -13
- ripple_down_rules/user_interface/gui.py +10 -7
- ripple_down_rules/user_interface/ipython_custom_shell.py +1 -1
- ripple_down_rules/user_interface/object_diagram.py +9 -1
- ripple_down_rules/user_interface/template_file_creator.py +24 -24
- ripple_down_rules/utils.py +258 -76
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.7.dist-info}/METADATA +2 -1
- ripple_down_rules-0.5.7.dist-info/RECORD +24 -0
- ripple_down_rules-0.5.5.dist-info/RECORD +0 -24
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.7.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.7.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.5.5.dist-info → ripple_down_rules-0.5.7.dist-info}/top_level.txt +0 -0
@@ -7,13 +7,14 @@ import os.path
|
|
7
7
|
from functools import wraps
|
8
8
|
|
9
9
|
from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
|
10
|
-
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
|
10
|
+
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union, Sequence
|
11
11
|
|
12
12
|
from ripple_down_rules.datastructures.case import create_case, Case
|
13
13
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
14
14
|
from ripple_down_rules.datastructures.enums import Category
|
15
15
|
from ripple_down_rules.experts import Expert, Human
|
16
16
|
from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
|
17
|
+
from ripple_down_rules.user_interface.gui import RDRCaseViewer
|
17
18
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
18
19
|
get_method_class_if_exists, get_method_name, str_to_snake_case
|
19
20
|
|
@@ -24,15 +25,16 @@ class RDRDecorator:
|
|
24
25
|
def __init__(self, models_dir: str,
|
25
26
|
output_type: Tuple[Type],
|
26
27
|
mutual_exclusive: bool,
|
27
|
-
python_dir: Optional[str] = None,
|
28
28
|
output_name: str = "output_",
|
29
29
|
fit: bool = True,
|
30
|
-
expert: Optional[Expert] = None
|
30
|
+
expert: Optional[Expert] = None,
|
31
|
+
ask_always: bool = False,
|
32
|
+
update_existing_rules: bool = True,
|
33
|
+
viewer: Optional[RDRCaseViewer] = None):
|
31
34
|
"""
|
32
35
|
:param models_dir: The directory to save/load the RDR models.
|
33
36
|
:param output_type: The type of the output. This is used to create the RDR model.
|
34
37
|
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
35
|
-
:param python_dir: The directory to save the RDR model as a python file.
|
36
38
|
If None, the RDR model will not be saved as a python file.
|
37
39
|
:param output_name: The name of the output. This is used to create the RDR model.
|
38
40
|
:param fit: If True, the function will be in fit mode. This means that the RDR will prompt the user for the
|
@@ -40,6 +42,9 @@ class RDRDecorator:
|
|
40
42
|
classification mode. This means that the RDR will classify the function's output based on the RDR model.
|
41
43
|
:param expert: The expert that will be used to prompt the user for the correct output. If None, a Human
|
42
44
|
expert will be used.
|
45
|
+
:param ask_always: If True, the function will ask the user for a target if it doesn't exist.
|
46
|
+
:param update_existing_rules: If True, the function will update the existing RDR rules
|
47
|
+
even if they gave an output.
|
43
48
|
:return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
|
44
49
|
"""
|
45
50
|
self.rdr_models_dir = models_dir
|
@@ -47,10 +52,12 @@ class RDRDecorator:
|
|
47
52
|
self.output_type = output_type
|
48
53
|
self.parsed_output_type: List[Type] = []
|
49
54
|
self.mutual_exclusive = mutual_exclusive
|
50
|
-
self.rdr_python_path: Optional[str] = python_dir
|
51
55
|
self.output_name = output_name
|
52
56
|
self.fit: bool = fit
|
53
57
|
self.expert: Optional[Expert] = expert
|
58
|
+
self.ask_always = ask_always
|
59
|
+
self.update_existing_rules = update_existing_rules
|
60
|
+
self.viewer = viewer
|
54
61
|
self.load()
|
55
62
|
|
56
63
|
def decorator(self, func: Callable) -> Callable:
|
@@ -62,61 +69,77 @@ class RDRDecorator:
|
|
62
69
|
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
63
70
|
if self.model_name is None:
|
64
71
|
self.initialize_rdr_model_name_and_load(func)
|
72
|
+
if self.expert is None:
|
73
|
+
self.expert = Human(viewer=self.viewer,
|
74
|
+
answers_save_path=self.rdr_models_dir + f'/expert_answers')
|
75
|
+
|
76
|
+
func_output = {self.output_name: func(*args, **kwargs)}
|
65
77
|
|
66
78
|
if self.fit:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
self.mutual_exclusive, self.output_name,
|
79
|
+
case_query = self.create_case_query_from_method(func, func_output,
|
80
|
+
self.parsed_output_type,
|
81
|
+
self.mutual_exclusive,
|
71
82
|
*args, **kwargs)
|
72
|
-
output = self.rdr.fit_case(case_query, expert=self.expert
|
83
|
+
output = self.rdr.fit_case(case_query, expert=self.expert,
|
84
|
+
ask_always_for_target=self.ask_always,
|
85
|
+
update_existing_rules=self.update_existing_rules,
|
86
|
+
viewer=self.viewer)
|
87
|
+
else:
|
88
|
+
case, case_dict = self.create_case_from_method(func, func_output, *args, **kwargs)
|
89
|
+
output = self.rdr.classify(case)
|
90
|
+
|
91
|
+
if self.output_name in output:
|
73
92
|
return output[self.output_name]
|
74
93
|
else:
|
75
|
-
|
76
|
-
return self.rdr.classify(case)[self.output_name]
|
94
|
+
return func_output[self.output_name]
|
77
95
|
|
78
96
|
return wrapper
|
79
97
|
|
80
98
|
@staticmethod
|
81
|
-
def create_case_query_from_method(func: Callable,
|
82
|
-
|
99
|
+
def create_case_query_from_method(func: Callable,
|
100
|
+
func_output: Dict[str, Any],
|
101
|
+
output_type: Sequence[Type],
|
102
|
+
mutual_exclusive: bool,
|
103
|
+
*args, **kwargs) -> CaseQuery:
|
83
104
|
"""
|
84
105
|
Create a CaseQuery from the function and its arguments.
|
85
106
|
|
86
107
|
:param func: The function to create a case from.
|
87
|
-
:param
|
108
|
+
:param func_output: The output of the function as a dictionary, where the key is the output name.
|
109
|
+
:param output_type: The type of the output as a sequence of types.
|
88
110
|
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
89
|
-
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
90
111
|
:param args: The positional arguments of the function.
|
91
112
|
:param kwargs: The keyword arguments of the function.
|
92
113
|
:return: A CaseQuery object representing the case.
|
93
114
|
"""
|
94
115
|
output_type = make_set(output_type)
|
95
|
-
case, case_dict = RDRDecorator.create_case_from_method(func,
|
116
|
+
case, case_dict = RDRDecorator.create_case_from_method(func, func_output, *args, **kwargs)
|
96
117
|
scope = func.__globals__
|
97
118
|
scope.update(case_dict)
|
98
119
|
func_args_type_hints = get_type_hints(func)
|
120
|
+
output_name = list(func_output.keys())[0]
|
99
121
|
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
100
122
|
return CaseQuery(case, output_name, Union[tuple(output_type)],
|
101
123
|
mutual_exclusive, scope=scope,
|
102
124
|
is_function=True, function_args_type_hints=func_args_type_hints)
|
103
125
|
|
104
126
|
@staticmethod
|
105
|
-
def create_case_from_method(func: Callable,
|
127
|
+
def create_case_from_method(func: Callable,
|
128
|
+
func_output: Dict[str, Any],
|
129
|
+
*args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
106
130
|
"""
|
107
131
|
Create a Case from the function and its arguments.
|
108
132
|
|
109
133
|
:param func: The function to create a case from.
|
110
|
-
:param
|
134
|
+
:param func_output: A dictionary containing the output of the function, where the key is the output name.
|
111
135
|
:param args: The positional arguments of the function.
|
112
136
|
:param kwargs: The keyword arguments of the function.
|
113
137
|
:return: A Case object representing the case.
|
114
138
|
"""
|
115
139
|
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
116
|
-
func_output
|
117
|
-
case_dict.update({output_name: func_output})
|
140
|
+
case_dict.update(func_output)
|
118
141
|
case_name = get_func_rdr_model_name(func)
|
119
|
-
return
|
142
|
+
return Case(dict, id(case_dict), case_name, case_dict, **case_dict), case_dict
|
120
143
|
|
121
144
|
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
122
145
|
model_file_name = get_func_rdr_model_name(func, include_file_name=True)
|
@@ -148,10 +171,15 @@ class RDRDecorator:
|
|
148
171
|
"""
|
149
172
|
Load the RDR model from the specified directory.
|
150
173
|
"""
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
174
|
+
self.rdr = None
|
175
|
+
if self.model_name is not None:
|
176
|
+
model_path = os.path.join(self.rdr_models_dir, self.model_name + f"/rdr_metadata/{self.model_name}.json")
|
177
|
+
if os.path.exists(os.path.join(self.rdr_models_dir, model_path)):
|
178
|
+
self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
|
179
|
+
self.rdr.set_viewer(self.viewer)
|
180
|
+
if self.rdr is None:
|
181
|
+
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name,
|
182
|
+
viewer=self.viewer)
|
155
183
|
|
156
184
|
def update_from_python(self):
|
157
185
|
"""
|
ripple_down_rules/rules.py
CHANGED
@@ -3,16 +3,18 @@ from __future__ import annotations
|
|
3
3
|
import logging
|
4
4
|
import re
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
+
from pathlib import Path
|
6
7
|
from uuid import uuid4
|
7
8
|
|
8
9
|
from anytree import NodeMixin
|
9
10
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
10
|
-
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
11
|
+
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Callable
|
11
12
|
|
12
13
|
from .datastructures.callable_expression import CallableExpression
|
13
14
|
from .datastructures.case import Case
|
15
|
+
from .datastructures.dataclasses import CaseFactoryMetaData, CaseConf, CaseQuery
|
14
16
|
from .datastructures.enums import RDREdge, Stop
|
15
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
|
16
18
|
|
17
19
|
|
18
20
|
class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
@@ -27,7 +29,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
27
29
|
corner_case: Optional[Union[Case, SQLTable]] = None,
|
28
30
|
weight: Optional[str] = None,
|
29
31
|
conclusion_name: Optional[str] = None,
|
30
|
-
uid: Optional[str] = None
|
32
|
+
uid: Optional[str] = None,
|
33
|
+
corner_case_metadata: Optional[CaseFactoryMetaData] = None):
|
31
34
|
"""
|
32
35
|
A rule in the ripple down rules classifier.
|
33
36
|
|
@@ -38,10 +41,13 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
38
41
|
:param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
|
39
42
|
:param conclusion_name: The name of the conclusion of the rule.
|
40
43
|
:param uid: The unique id of the rule.
|
44
|
+
:param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
|
45
|
+
scenario it is based on.
|
41
46
|
"""
|
42
47
|
super(Rule, self).__init__()
|
43
48
|
self.conclusion = conclusion
|
44
49
|
self.corner_case = corner_case
|
50
|
+
self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
|
45
51
|
self.parent = parent
|
46
52
|
self.weight: Optional[str] = weight
|
47
53
|
self.conditions = conditions if conditions else None
|
@@ -51,6 +57,20 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
51
57
|
# generate a unique id for the rule using uuid4
|
52
58
|
self.uid: str = uid if uid else str(uuid4().int)
|
53
59
|
|
60
|
+
@classmethod
|
61
|
+
def from_case_query(cls, case_query: CaseQuery) -> Rule:
|
62
|
+
"""
|
63
|
+
Create a SingleClassRule from a CaseQuery.
|
64
|
+
|
65
|
+
:param case_query: The CaseQuery to create the rule from.
|
66
|
+
:return: A SingleClassRule instance.
|
67
|
+
"""
|
68
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
69
|
+
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
70
|
+
corner_case=case_query.case, parent=None,
|
71
|
+
corner_case_metadata=corner_case_metadata,
|
72
|
+
conclusion_name=case_query.attribute_name)
|
73
|
+
|
54
74
|
def _post_detach(self, parent):
|
55
75
|
"""
|
56
76
|
Called after this node is detached from the tree, useful when drawing the tree.
|
@@ -82,6 +102,27 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
82
102
|
"""
|
83
103
|
pass
|
84
104
|
|
105
|
+
def write_corner_case_as_source_code(self, cases_file: Path) -> None:
|
106
|
+
"""
|
107
|
+
Write the source code representation of the corner case of the rule to a file.
|
108
|
+
|
109
|
+
:param cases_file: The file to write the corner case to if it is a definition.
|
110
|
+
"""
|
111
|
+
if self.corner_case_metadata is None:
|
112
|
+
return
|
113
|
+
types_to_import = set()
|
114
|
+
if self.corner_case_metadata.factory_method is not None:
|
115
|
+
types_to_import.add(self.corner_case_metadata.factory_method)
|
116
|
+
if self.corner_case_metadata.scenario is not None:
|
117
|
+
types_to_import.add(self.corner_case_metadata.scenario)
|
118
|
+
if self.corner_case_metadata.case_conf is not None:
|
119
|
+
types_to_import.add(self.corner_case_metadata.case_conf)
|
120
|
+
types_to_import.add(CaseFactoryMetaData)
|
121
|
+
imports = get_imports_from_types(list(types_to_import))
|
122
|
+
with open(cases_file, 'a') as f:
|
123
|
+
f.write("\n".join(imports) + "\n\n\n")
|
124
|
+
f.write(f"corner_case_{self.uid} = {self.corner_case_metadata}" + "\n\n\n")
|
125
|
+
|
85
126
|
def write_conclusion_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
86
127
|
"""
|
87
128
|
Get the source code representation of the conclusion of the rule.
|
@@ -150,11 +191,16 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
150
191
|
pass
|
151
192
|
|
152
193
|
def _to_json(self) -> Dict[str, Any]:
|
153
|
-
|
194
|
+
try:
|
195
|
+
corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
|
196
|
+
except Exception as e:
|
197
|
+
logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
|
198
|
+
corner_case = None
|
199
|
+
json_serialization = {"_type": get_full_class_name(type(self)),
|
200
|
+
"conditions": self.conditions.to_json(),
|
154
201
|
"conclusion": conclusion_to_json(self.conclusion),
|
155
202
|
"parent": self.parent.json_serialization if self.parent else None,
|
156
|
-
"corner_case":
|
157
|
-
if self.corner_case else None,
|
203
|
+
"corner_case": corner_case,
|
158
204
|
"conclusion_name": self.conclusion_name,
|
159
205
|
"weight": self.weight,
|
160
206
|
"uid": self.uid}
|
@@ -277,9 +323,12 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
277
323
|
returned_rule = self.alternative(x) if self.alternative else self
|
278
324
|
return returned_rule if returned_rule.fired else self
|
279
325
|
|
280
|
-
def fit_rule(self,
|
281
|
-
|
282
|
-
|
326
|
+
def fit_rule(self, case_query: CaseQuery):
|
327
|
+
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
328
|
+
new_rule = SingleClassRule(case_query.conditions, case_query.target,
|
329
|
+
corner_case=case_query.case, parent=self,
|
330
|
+
corner_case_metadata=corner_case_metadata,
|
331
|
+
)
|
283
332
|
if self.fired:
|
284
333
|
self.refinement = new_rule
|
285
334
|
else:
|
@@ -363,11 +412,12 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
363
412
|
elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
|
364
413
|
return self.alternative
|
365
414
|
|
366
|
-
def fit_rule(self,
|
367
|
-
if self.fired and target != self.conclusion:
|
368
|
-
self.refinement = MultiClassStopRule(conditions, corner_case=
|
415
|
+
def fit_rule(self, case_query: CaseQuery):
|
416
|
+
if self.fired and case_query.target != self.conclusion:
|
417
|
+
self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
|
369
418
|
elif not self.fired:
|
370
|
-
self.alternative = MultiClassTopRule(conditions, target,
|
419
|
+
self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
|
420
|
+
corner_case=case_query.case, parent=self)
|
371
421
|
|
372
422
|
def _to_json(self) -> Dict[str, Any]:
|
373
423
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -281,12 +281,14 @@ class RDRCaseViewer(QMainWindow):
|
|
281
281
|
main_obj: Optional[Dict[str, Any]] = None
|
282
282
|
user_input: Optional[str] = None
|
283
283
|
attributes_widget: Optional[QWidget] = None
|
284
|
-
save_function: Optional[Callable[str], None] = None
|
284
|
+
save_function: Optional[Callable[str, str], None] = None
|
285
285
|
|
286
|
-
|
287
|
-
|
286
|
+
def __init__(self, parent=None,
|
287
|
+
save_dir: Optional[str] = None,
|
288
|
+
save_model_name: Optional[str] = None):
|
288
289
|
super().__init__(parent)
|
289
|
-
self.
|
290
|
+
self.save_dir = save_dir
|
291
|
+
self.save_model_name = save_model_name
|
290
292
|
|
291
293
|
self.setWindowTitle("RDR Case Viewer")
|
292
294
|
|
@@ -323,17 +325,17 @@ class RDRCaseViewer(QMainWindow):
|
|
323
325
|
|
324
326
|
# Add both to main layout
|
325
327
|
main_layout.addWidget(self.attributes_widget, stretch=1)
|
326
|
-
main_layout.addWidget(middle_widget, stretch=
|
328
|
+
main_layout.addWidget(middle_widget, stretch=1)
|
327
329
|
main_layout.addWidget(self.obj_diagram_viewer, stretch=2)
|
328
330
|
|
329
|
-
def set_save_function(self, save_function: Callable[[str], None]) -> None:
|
331
|
+
def set_save_function(self, save_function: Callable[[str, str], None]) -> None:
|
330
332
|
"""
|
331
333
|
Set the function to save the file.
|
332
334
|
|
333
335
|
:param save_function: The function to save the file.
|
334
336
|
"""
|
335
337
|
self.save_function = save_function
|
336
|
-
self.save_btn.clicked.connect(lambda: self.save_function(self.
|
338
|
+
self.save_btn.clicked.connect(lambda: self.save_function(self.save_dir, self.save_model_name))
|
337
339
|
|
338
340
|
def print(self, msg):
|
339
341
|
"""
|
@@ -489,6 +491,7 @@ class RDRCaseViewer(QMainWindow):
|
|
489
491
|
self.code_lines, self.template_file_creator.func_name,
|
490
492
|
self.template_file_creator.function_signature,
|
491
493
|
self.template_file_creator.func_doc, self.case_query)
|
494
|
+
self.case_query.scope.update(updates)
|
492
495
|
self.template_file_creator = None
|
493
496
|
|
494
497
|
def update_attribute_layout(self, obj, name: str):
|
@@ -134,7 +134,7 @@ class IPythonShell:
|
|
134
134
|
"""
|
135
135
|
Update the user input from the code lines captured in the shell.
|
136
136
|
"""
|
137
|
-
if
|
137
|
+
if self.shell.all_lines[0].replace('return', '').strip() == '':
|
138
138
|
self.user_input = None
|
139
139
|
else:
|
140
140
|
self.all_code_lines = extract_dependencies(self.shell.all_lines)
|
@@ -1,5 +1,9 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
|
+
from ripple_down_rules.datastructures.case import Case
|
4
|
+
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
5
|
+
from ripple_down_rules.utils import SubclassJSONSerializer
|
6
|
+
|
3
7
|
try:
|
4
8
|
import graphviz
|
5
9
|
except ImportError:
|
@@ -77,7 +81,11 @@ def generate_object_graph(obj, name='root', seen=None, graph=None, current_depth
|
|
77
81
|
for attr in dir(obj):
|
78
82
|
if attr.startswith('_'):
|
79
83
|
continue
|
80
|
-
if attr == 'scope':
|
84
|
+
if isinstance(obj, CaseQuery) and attr == 'scope':
|
85
|
+
continue
|
86
|
+
if isinstance(obj, Case) and attr in ['data']:
|
87
|
+
continue
|
88
|
+
if isinstance(obj, SubclassJSONSerializer) and attr == 'data_class_refs':
|
81
89
|
continue
|
82
90
|
value = getattr(obj, attr)
|
83
91
|
if callable(value):
|
@@ -8,13 +8,13 @@ from functools import cached_property
|
|
8
8
|
from textwrap import indent, dedent
|
9
9
|
|
10
10
|
from colorama import Fore, Style
|
11
|
-
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
|
11
|
+
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict, Any, Union
|
12
12
|
|
13
13
|
from ..datastructures.case import Case
|
14
14
|
from ..datastructures.dataclasses import CaseQuery
|
15
15
|
from ..datastructures.enums import Editor, PromptFor
|
16
|
-
from ..utils import str_to_snake_case, get_imports_from_scope, make_list,
|
17
|
-
get_imports_from_types, extract_function_source, extract_imports
|
16
|
+
from ..utils import str_to_snake_case, get_imports_from_scope, make_list, stringify_hint, \
|
17
|
+
get_imports_from_types, extract_function_source, extract_imports, get_types_to_import_from_type_hints
|
18
18
|
|
19
19
|
|
20
20
|
def detect_available_editor() -> Optional[Editor]:
|
@@ -84,6 +84,7 @@ class TemplateFileCreator:
|
|
84
84
|
self.func_doc: str = self.get_func_doc()
|
85
85
|
self.function_signature: str = self.get_function_signature()
|
86
86
|
self.editor: Optional[Editor] = detect_available_editor()
|
87
|
+
self.editor_cmd: Optional[str] = os.environ.get("RDR_EDITOR_CMD")
|
87
88
|
self.workspace: str = os.environ.get("RDR_EDITOR_WORKSPACE", os.path.dirname(self.case_query.scope['__file__']))
|
88
89
|
self.temp_file_path: str = os.path.join(self.workspace, "edit_code_here.py")
|
89
90
|
|
@@ -98,7 +99,7 @@ class TemplateFileCreator:
|
|
98
99
|
return make_list(output_type) if output_type is not None else None
|
99
100
|
|
100
101
|
def edit(self):
|
101
|
-
if self.editor is None:
|
102
|
+
if self.editor is None and self.editor_cmd is None:
|
102
103
|
self.print_func(
|
103
104
|
f"{Fore.RED}ERROR:: No editor found. Please install PyCharm, VSCode or code-server.{Style.RESET_ALL}")
|
104
105
|
return
|
@@ -112,7 +113,11 @@ class TemplateFileCreator:
|
|
112
113
|
"""
|
113
114
|
Open the file in the available editor.
|
114
115
|
"""
|
115
|
-
if self.
|
116
|
+
if self.editor_cmd is not None:
|
117
|
+
subprocess.Popen([self.editor_cmd, self.temp_file_path],
|
118
|
+
stdout=subprocess.DEVNULL,
|
119
|
+
stderr=subprocess.DEVNULL)
|
120
|
+
elif self.editor == Editor.Pycharm:
|
116
121
|
subprocess.Popen(["pycharm", "--line", str(self.user_edit_line), self.temp_file_path],
|
117
122
|
stdout=subprocess.DEVNULL,
|
118
123
|
stderr=subprocess.DEVNULL)
|
@@ -172,7 +177,7 @@ class TemplateFileCreator:
|
|
172
177
|
for k, v in self.case_query.case.items():
|
173
178
|
if (self.case_query.function_args_type_hints is not None
|
174
179
|
and k in self.case_query.function_args_type_hints):
|
175
|
-
func_args[k] =
|
180
|
+
func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
|
176
181
|
else:
|
177
182
|
func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
|
178
183
|
func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
|
@@ -202,30 +207,25 @@ class TemplateFileCreator:
|
|
202
207
|
for k, v in self.case_query.case.items():
|
203
208
|
if (self.case_query.function_args_type_hints is not None
|
204
209
|
and k in self.case_query.function_args_type_hints):
|
205
|
-
|
206
|
-
|
207
|
-
hint_split = hint.split('.')
|
208
|
-
if len(hint_split) > 1:
|
209
|
-
case_type_imports.append(f"from {'.'.join(hint_split[:-1])} import {hint_split[-1]}")
|
210
|
+
types_to_import = get_types_to_import_from_type_hints([self.case_query.function_args_type_hints[k]])
|
211
|
+
case_type_imports.extend(list(types_to_import))
|
210
212
|
else:
|
211
|
-
|
212
|
-
case_type_imports.append(f"from {v.__module__} import {v.__name__}")
|
213
|
-
elif hasattr(v, "__module__") and not v.__module__.startswith("__"):
|
214
|
-
case_type_imports.append(f"\nfrom {type(v).__module__} import {type(v).__name__}")
|
213
|
+
case_type_imports.append(v)
|
215
214
|
else:
|
216
|
-
case_type_imports.append(
|
215
|
+
case_type_imports.append(self.case_type)
|
217
216
|
if self.output_type is None:
|
218
|
-
output_type_imports = [
|
217
|
+
output_type_imports = [Any]
|
219
218
|
else:
|
220
|
-
output_type_imports =
|
219
|
+
output_type_imports = self.output_type
|
221
220
|
if len(self.output_type) > 1:
|
222
|
-
output_type_imports.append(
|
221
|
+
output_type_imports.append(Union)
|
223
222
|
if list in self.output_type:
|
224
|
-
output_type_imports.append(
|
225
|
-
|
226
|
-
imports = [i for i in imports if ("get_ipython" not in i)]
|
227
|
-
|
228
|
-
|
223
|
+
output_type_imports.append(List)
|
224
|
+
import_types = list(self.case_query.scope.values())
|
225
|
+
# imports = [i for i in imports if ("get_ipython" not in i)]
|
226
|
+
import_types.extend(case_type_imports)
|
227
|
+
import_types.extend(output_type_imports)
|
228
|
+
imports = get_imports_from_types(import_types)
|
229
229
|
imports = set(imports)
|
230
230
|
return '\n'.join(imports)
|
231
231
|
|