ripple-down-rules 0.4.7__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +5 -0
- ripple_down_rules/datastructures/callable_expression.py +20 -1
- ripple_down_rules/datastructures/case.py +8 -6
- ripple_down_rules/datastructures/dataclasses.py +9 -1
- ripple_down_rules/experts.py +194 -33
- ripple_down_rules/rdr.py +196 -114
- ripple_down_rules/rdr_decorators.py +73 -52
- ripple_down_rules/rules.py +7 -6
- ripple_down_rules/start-code-server.sh +27 -0
- ripple_down_rules/user_interface/gui.py +21 -35
- ripple_down_rules/user_interface/ipython_custom_shell.py +6 -4
- ripple_down_rules/user_interface/object_diagram.py +7 -1
- ripple_down_rules/user_interface/prompt.py +9 -4
- ripple_down_rules/user_interface/template_file_creator.py +55 -32
- ripple_down_rules/utils.py +66 -26
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/METADATA +10 -8
- ripple_down_rules-0.4.9.dist-info/RECORD +26 -0
- ripple_down_rules-0.4.7.dist-info/RECORD +0 -25
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.4.7.dist-info → ripple_down_rules-0.4.9.dist-info}/top_level.txt +0 -0
@@ -5,15 +5,17 @@ of the RDRs.
|
|
5
5
|
"""
|
6
6
|
import os.path
|
7
7
|
from functools import wraps
|
8
|
+
|
9
|
+
from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
|
8
10
|
from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
|
9
11
|
|
10
|
-
from ripple_down_rules.datastructures.case import create_case
|
12
|
+
from ripple_down_rules.datastructures.case import create_case, Case
|
11
13
|
from ripple_down_rules.datastructures.dataclasses import CaseQuery
|
12
14
|
from ripple_down_rules.datastructures.enums import Category
|
13
15
|
from ripple_down_rules.experts import Expert, Human
|
14
16
|
from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
|
15
17
|
from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
|
16
|
-
get_method_class_if_exists
|
18
|
+
get_method_class_if_exists, get_method_name, str_to_snake_case
|
17
19
|
|
18
20
|
|
19
21
|
class RDRDecorator:
|
@@ -41,99 +43,118 @@ class RDRDecorator:
|
|
41
43
|
:return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
|
42
44
|
"""
|
43
45
|
self.rdr_models_dir = models_dir
|
46
|
+
self.model_name: Optional[str] = None
|
44
47
|
self.output_type = output_type
|
45
48
|
self.parsed_output_type: List[Type] = []
|
46
49
|
self.mutual_exclusive = mutual_exclusive
|
47
50
|
self.rdr_python_path: Optional[str] = python_dir
|
48
51
|
self.output_name = output_name
|
49
52
|
self.fit: bool = fit
|
50
|
-
self.expert = expert
|
51
|
-
self.rdr_model_path: Optional[str] = None
|
53
|
+
self.expert: Optional[Expert] = expert
|
52
54
|
self.load()
|
53
55
|
|
54
56
|
def decorator(self, func: Callable) -> Callable:
|
55
57
|
|
56
58
|
@wraps(func)
|
57
59
|
def wrapper(*args, **kwargs) -> Optional[Any]:
|
60
|
+
|
58
61
|
if len(self.parsed_output_type) == 0:
|
59
|
-
self.parse_output_type(func, *args)
|
60
|
-
if self.
|
61
|
-
self.
|
62
|
-
|
63
|
-
func_output = func(*args, **kwargs)
|
64
|
-
case_dict.update({self.output_name: func_output})
|
65
|
-
case = create_case(case_dict, obj_name=get_func_rdr_model_name(func), max_recursion_idx=3)
|
62
|
+
self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
|
63
|
+
if self.model_name is None:
|
64
|
+
self.initialize_rdr_model_name_and_load(func)
|
65
|
+
|
66
66
|
if self.fit:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
self.mutual_exclusive,
|
73
|
-
scope=scope, is_function=True, function_args_type_hints=func_args_type_hints)
|
67
|
+
expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
|
68
|
+
self.expert = self.expert or Human(answers_save_path=expert_answers_path)
|
69
|
+
case_query = self.create_case_query_from_method(func, self.parsed_output_type,
|
70
|
+
self.mutual_exclusive, self.output_name,
|
71
|
+
*args, **kwargs)
|
74
72
|
output = self.rdr.fit_case(case_query, expert=self.expert)
|
75
73
|
return output[self.output_name]
|
76
74
|
else:
|
75
|
+
case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
|
77
76
|
return self.rdr.classify(case)[self.output_name]
|
78
77
|
|
79
78
|
return wrapper
|
80
79
|
|
81
|
-
|
80
|
+
@staticmethod
|
81
|
+
def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
|
82
|
+
output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
|
83
|
+
"""
|
84
|
+
Create a CaseQuery from the function and its arguments.
|
85
|
+
|
86
|
+
:param func: The function to create a case from.
|
87
|
+
:param output_type: The type of the output.
|
88
|
+
:param mutual_exclusive: If True, the output types are mutually exclusive.
|
89
|
+
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
90
|
+
:param args: The positional arguments of the function.
|
91
|
+
:param kwargs: The keyword arguments of the function.
|
92
|
+
:return: A CaseQuery object representing the case.
|
93
|
+
"""
|
94
|
+
output_type = make_set(output_type)
|
95
|
+
case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
|
96
|
+
scope = func.__globals__
|
97
|
+
scope.update(case_dict)
|
98
|
+
func_args_type_hints = get_type_hints(func)
|
99
|
+
func_args_type_hints.update({output_name: Union[tuple(output_type)]})
|
100
|
+
return CaseQuery(case, output_name, Union[tuple(output_type)],
|
101
|
+
mutual_exclusive, scope=scope,
|
102
|
+
is_function=True, function_args_type_hints=func_args_type_hints)
|
103
|
+
|
104
|
+
@staticmethod
|
105
|
+
def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
|
106
|
+
"""
|
107
|
+
Create a Case from the function and its arguments.
|
108
|
+
|
109
|
+
:param func: The function to create a case from.
|
110
|
+
:param output_name: The name of the output in the case. Defaults to 'output_'.
|
111
|
+
:param args: The positional arguments of the function.
|
112
|
+
:param kwargs: The keyword arguments of the function.
|
113
|
+
:return: A Case object representing the case.
|
114
|
+
"""
|
115
|
+
case_dict = get_method_args_as_dict(func, *args, **kwargs)
|
116
|
+
func_output = func(*args, **kwargs)
|
117
|
+
case_dict.update({output_name: func_output})
|
118
|
+
case_name = get_func_rdr_model_name(func)
|
119
|
+
return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
|
120
|
+
|
121
|
+
def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
|
82
122
|
model_file_name = get_func_rdr_model_name(func, include_file_name=True)
|
83
|
-
|
84
|
-
.replace('__', '_') + ".json")
|
85
|
-
self.rdr_model_path = os.path.join(self.rdr_models_dir, model_file_name)
|
123
|
+
self.model_name = str_to_snake_case(model_file_name)
|
86
124
|
self.load()
|
87
125
|
|
88
|
-
|
89
|
-
|
126
|
+
@staticmethod
|
127
|
+
def parse_output_type(func: Callable, output_type: Any, *args) -> List[Type]:
|
128
|
+
parsed_output_type = []
|
129
|
+
for ot in make_set(output_type):
|
90
130
|
if ot is Self:
|
91
131
|
func_class = get_method_class_if_exists(func, *args)
|
92
132
|
if func_class is not None:
|
93
|
-
|
133
|
+
parsed_output_type.append(func_class)
|
94
134
|
else:
|
95
135
|
raise ValueError(f"The function {func} is not a method of a class,"
|
96
136
|
f" and the output type is {Self}.")
|
97
137
|
else:
|
98
|
-
|
138
|
+
parsed_output_type.append(ot)
|
139
|
+
return parsed_output_type
|
99
140
|
|
100
141
|
def save(self):
|
101
142
|
"""
|
102
143
|
Save the RDR model to the specified directory.
|
103
144
|
"""
|
104
|
-
self.rdr.save(self.
|
105
|
-
|
106
|
-
if self.rdr_python_path is not None:
|
107
|
-
if not os.path.exists(self.rdr_python_path):
|
108
|
-
os.makedirs(self.rdr_python_path)
|
109
|
-
if not os.path.exists(os.path.join(self.rdr_python_path, "__init__.py")):
|
110
|
-
# add __init__.py file to the directory
|
111
|
-
with open(os.path.join(self.rdr_python_path, "__init__.py"), "w") as f:
|
112
|
-
f.write("# This is an empty __init__.py file to make the directory a package.")
|
113
|
-
# write the RDR model to a python file
|
114
|
-
self.rdr.write_to_python_file(self.rdr_python_path)
|
145
|
+
self.rdr.save(self.rdr_models_dir)
|
115
146
|
|
116
147
|
def load(self):
|
117
148
|
"""
|
118
149
|
Load the RDR model from the specified directory.
|
119
150
|
"""
|
120
|
-
if self.
|
121
|
-
self.rdr = GeneralRDR.load(self.
|
151
|
+
if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
|
152
|
+
self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
|
122
153
|
else:
|
123
|
-
self.rdr = GeneralRDR()
|
124
|
-
|
125
|
-
def write_to_python_file(self, package_dir: str, file_name_postfix: str = ""):
|
126
|
-
"""
|
127
|
-
Write the RDR model to a python file.
|
128
|
-
|
129
|
-
:param package_dir: The path to the directory to write the python file.
|
130
|
-
"""
|
131
|
-
self.rdr.write_to_python_file(package_dir, postfix=file_name_postfix)
|
154
|
+
self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
|
132
155
|
|
133
|
-
def
|
156
|
+
def update_from_python(self):
|
134
157
|
"""
|
135
158
|
Update the RDR model from a python file.
|
136
|
-
|
137
|
-
:param package_dir: The directory of the package that contains the generated python file.
|
138
159
|
"""
|
139
|
-
self.rdr.
|
160
|
+
self.rdr.update_from_python(self.rdr_models_dir, self.model_name)
|
ripple_down_rules/rules.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import logging
|
3
4
|
import re
|
4
5
|
from abc import ABC, abstractmethod
|
5
6
|
from uuid import uuid4
|
6
7
|
|
7
8
|
from anytree import NodeMixin
|
8
|
-
from rospy import logwarn, logdebug
|
9
9
|
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
10
10
|
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple
|
11
11
|
|
@@ -118,7 +118,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
118
118
|
func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
|
119
119
|
return "\n".join(conclusion_lines).strip(' '), func_call
|
120
120
|
else:
|
121
|
-
raise ValueError(f"Conclusion
|
121
|
+
raise ValueError(f"Conclusion format is not valid, it should contain a function definition."
|
122
122
|
f" Instead got:\n{conclusion}\n")
|
123
123
|
|
124
124
|
def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
|
@@ -129,9 +129,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
129
129
|
:param defs_file: The file to write the conditions to if they are a definition.
|
130
130
|
"""
|
131
131
|
if_clause = self._if_statement_source_code_clause()
|
132
|
-
if
|
133
|
-
return f"{parent_indent}{if_clause} {self.conditions.user_input}:\n"
|
134
|
-
elif "def " in self.conditions.user_input:
|
132
|
+
if "def " in self.conditions.user_input:
|
135
133
|
if defs_file is None:
|
136
134
|
raise ValueError("Cannot write conditions to source code as definitions python file was not given.")
|
137
135
|
# This means the conditions are a definition that should be written and then called
|
@@ -143,6 +141,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
143
141
|
with open(defs_file, 'a') as f:
|
144
142
|
f.write(def_code.strip() + "\n\n\n")
|
145
143
|
return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
|
144
|
+
else:
|
145
|
+
raise ValueError(f"Conditions format is not valid, it should contain a function definition"
|
146
|
+
f" Instead got:\n{self.conditions.user_input}\n")
|
146
147
|
|
147
148
|
@abstractmethod
|
148
149
|
def _if_statement_source_code_clause(self) -> str:
|
@@ -164,7 +165,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
164
165
|
try:
|
165
166
|
corner_case = Case.from_json(data["corner_case"])
|
166
167
|
except Exception as e:
|
167
|
-
|
168
|
+
logging.debug("Failed to load corner case from json, setting it to None.")
|
168
169
|
corner_case = None
|
169
170
|
loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
|
170
171
|
conclusion=CallableExpression.from_json(data["conclusion"]),
|
@@ -0,0 +1,27 @@
|
|
1
|
+
#!/bin/bash
|
2
|
+
set -e
|
3
|
+
if [ -z "$RDR_EDITOR_PORT" ]; then
|
4
|
+
echo "RDR_EDITOR_PORT is not set. Using default port 8080."
|
5
|
+
RDR_EDITOR_PORT=8080
|
6
|
+
fi
|
7
|
+
ADDR="0.0.0.0:$RDR_EDITOR_PORT"
|
8
|
+
# DATA_DIR="/root/.local/share/code-server"
|
9
|
+
echo "🚀 Starting code-server on $ADDR"
|
10
|
+
# Activate your Python virtual environment if exists else ignore
|
11
|
+
if [ -z "$RDR_VENV_PATH" ]; then
|
12
|
+
echo "No virtual environment found. Skipping activation."
|
13
|
+
else
|
14
|
+
source "$RDR_VENV_PATH/bin/activate"
|
15
|
+
# Set the default Python interpreter for VS Code
|
16
|
+
export DEFAULT_PYTHON_PATH=$(which python)
|
17
|
+
fi
|
18
|
+
|
19
|
+
# Start code-server.
|
20
|
+
echo "🚀 Starting code-server on $ADDR"
|
21
|
+
if [ -z "$CODE_SERVER_USER_DATA_DIR" ]; then
|
22
|
+
echo "No user data directory found. Using default"
|
23
|
+
code-server --bind-addr $ADDR --auth none "$@"
|
24
|
+
else
|
25
|
+
echo "Using user data directory: $CODE_SERVER_USER_DATA_DIR"
|
26
|
+
code-server --bind-addr $ADDR --user-data-dir $CODE_SERVER_USER_DATA_DIR --auth none "$@"
|
27
|
+
fi
|
@@ -1,23 +1,29 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import inspect
|
4
|
-
|
4
|
+
import logging
|
5
5
|
from types import MethodType
|
6
6
|
|
7
|
-
|
8
|
-
from PyQt6.
|
9
|
-
from PyQt6.
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
from qtconsole.
|
7
|
+
try:
|
8
|
+
from PyQt6.QtCore import Qt
|
9
|
+
from PyQt6.QtGui import QPixmap, QPainter, QPalette
|
10
|
+
from PyQt6.QtWidgets import (
|
11
|
+
QWidget, QVBoxLayout, QLabel, QScrollArea,
|
12
|
+
QSizePolicy, QToolButton, QHBoxLayout, QPushButton, QMainWindow, QGraphicsView, QGraphicsScene, QGraphicsPixmapItem
|
13
|
+
)
|
14
|
+
from qtconsole.inprocess import QtInProcessKernelManager
|
15
|
+
from qtconsole.rich_jupyter_widget import RichJupyterWidget
|
16
|
+
except ImportError as e:
|
17
|
+
logging.debug("RDRCaseViewer is not available. GUI features will not work. "
|
18
|
+
"Make sure you have PyQt6 installed if you want to use the GUI features.")
|
19
|
+
raise ImportError("PyQt6 is required for the GUI features. Please install it using 'pip install PyQt6'") from e
|
20
|
+
|
15
21
|
from typing_extensions import Optional, Any, List, Dict, Callable
|
16
22
|
|
17
23
|
from ..datastructures.dataclasses import CaseQuery
|
18
24
|
from ..datastructures.enums import PromptFor
|
19
25
|
from .template_file_creator import TemplateFileCreator
|
20
|
-
from ..utils import is_iterable, contains_return_statement,
|
26
|
+
from ..utils import is_iterable, contains_return_statement, encapsulate_code_lines_into_a_function
|
21
27
|
from .object_diagram import generate_object_graph
|
22
28
|
|
23
29
|
|
@@ -467,15 +473,17 @@ class RDRCaseViewer(QMainWindow):
|
|
467
473
|
self.close()
|
468
474
|
|
469
475
|
def _edit(self):
|
470
|
-
self.template_file_creator = TemplateFileCreator(self.
|
471
|
-
self.case_query, self.prompt_for, self.code_to_modify,
|
476
|
+
self.template_file_creator = TemplateFileCreator(self.case_query, self.prompt_for, self.code_to_modify,
|
472
477
|
self.print)
|
473
478
|
self.template_file_creator.edit()
|
474
479
|
|
475
480
|
def _load(self):
|
476
481
|
if not self.template_file_creator:
|
477
482
|
return
|
478
|
-
self.code_lines = self.template_file_creator.load(
|
483
|
+
self.code_lines, updates = self.template_file_creator.load(self.template_file_creator.temp_file_path,
|
484
|
+
self.template_file_creator.func_name,
|
485
|
+
self.template_file_creator.print_func)
|
486
|
+
self.ipython_console.kernel.shell.user_ns.update(updates)
|
479
487
|
if self.code_lines is not None:
|
480
488
|
self.user_input = encapsulate_code_lines_into_a_function(
|
481
489
|
self.code_lines, self.template_file_creator.func_name,
|
@@ -558,28 +566,6 @@ class RDRCaseViewer(QMainWindow):
|
|
558
566
|
layout.addWidget(item_label)
|
559
567
|
|
560
568
|
|
561
|
-
def encapsulate_code_lines_into_a_function(code_lines: List[str], function_name: str, function_signature: str,
|
562
|
-
func_doc: str, case_query: CaseQuery) -> str:
|
563
|
-
"""
|
564
|
-
Encapsulate the given code lines into a function with the specified name, signature, and docstring.
|
565
|
-
|
566
|
-
:param code_lines: The lines of code to include in the user input.
|
567
|
-
:param function_name: The name of the function to include in the user input.
|
568
|
-
:param function_signature: The function signature to include in the user input.
|
569
|
-
:param func_doc: The function docstring to include in the user input.
|
570
|
-
:param case_query: The case query object.
|
571
|
-
"""
|
572
|
-
code = '\n'.join(code_lines)
|
573
|
-
code = encapsulate_user_input(code, function_signature, func_doc)
|
574
|
-
if case_query.is_function:
|
575
|
-
args = "**case"
|
576
|
-
else:
|
577
|
-
args = "case"
|
578
|
-
if f"return {function_name}({args})" not in code:
|
579
|
-
code = code.strip() + f"\nreturn {function_name}({args})"
|
580
|
-
return code
|
581
|
-
|
582
|
-
|
583
569
|
class IPythonConsole(RichJupyterWidget):
|
584
570
|
def __init__(self, namespace=None, parent=None):
|
585
571
|
super(IPythonConsole, self).__init__(parent)
|
@@ -8,9 +8,8 @@ from traitlets.config import Config
|
|
8
8
|
|
9
9
|
from ..datastructures.dataclasses import CaseQuery
|
10
10
|
from ..datastructures.enums import PromptFor
|
11
|
-
from .gui import encapsulate_code_lines_into_a_function
|
12
11
|
from .template_file_creator import TemplateFileCreator
|
13
|
-
from ..utils import contains_return_statement, extract_dependencies
|
12
|
+
from ..utils import contains_return_statement, extract_dependencies, encapsulate_code_lines_into_a_function
|
14
13
|
|
15
14
|
|
16
15
|
@magics_class
|
@@ -21,7 +20,7 @@ class MyMagics(Magics):
|
|
21
20
|
prompt_for: Optional[PromptFor] = None,
|
22
21
|
case_query: Optional[CaseQuery] = None):
|
23
22
|
super().__init__(shell)
|
24
|
-
self.rule_editor = TemplateFileCreator(
|
23
|
+
self.rule_editor = TemplateFileCreator(case_query, prompt_for=prompt_for, code_to_modify=code_to_modify)
|
25
24
|
self.all_code_lines: Optional[List[str]] = None
|
26
25
|
|
27
26
|
@line_magic
|
@@ -30,7 +29,10 @@ class MyMagics(Magics):
|
|
30
29
|
|
31
30
|
@line_magic
|
32
31
|
def load(self, line):
|
33
|
-
self.all_code_lines = self.rule_editor.load(
|
32
|
+
self.all_code_lines, updates = self.rule_editor.load(self.rule_editor.temp_file_path,
|
33
|
+
self.rule_editor.func_name,
|
34
|
+
self.rule_editor.print_func)
|
35
|
+
self.shell.user_ns.update(updates)
|
34
36
|
|
35
37
|
@line_magic
|
36
38
|
def help(self, line):
|
@@ -2,7 +2,13 @@ import ast
|
|
2
2
|
import logging
|
3
3
|
from _ast import AST
|
4
4
|
|
5
|
-
|
5
|
+
try:
|
6
|
+
from PyQt6.QtWidgets import QApplication
|
7
|
+
from .gui import RDRCaseViewer
|
8
|
+
except ImportError:
|
9
|
+
QApplication = None
|
10
|
+
RDRCaseViewer = None
|
11
|
+
|
6
12
|
from colorama import Fore, Style
|
7
13
|
from pygments import highlight
|
8
14
|
from pygments.formatters.terminal import TerminalFormatter
|
@@ -11,8 +17,7 @@ from typing_extensions import Optional, Tuple
|
|
11
17
|
|
12
18
|
from ..datastructures.callable_expression import CallableExpression, parse_string_to_expression
|
13
19
|
from ..datastructures.dataclasses import CaseQuery
|
14
|
-
from ..datastructures.enums import PromptFor
|
15
|
-
from .gui import RDRCaseViewer
|
20
|
+
from ..datastructures.enums import PromptFor
|
16
21
|
from .ipython_custom_shell import IPythonShell
|
17
22
|
from ..utils import make_list
|
18
23
|
|
@@ -90,7 +95,7 @@ class UserPrompt:
|
|
90
95
|
prompt_str = f"Give conditions on when can the rule be evaluated for:"
|
91
96
|
case_query.scope.update({'case': case_query.case})
|
92
97
|
shell = None
|
93
|
-
if
|
98
|
+
if self.viewer is None:
|
94
99
|
prompt_str = self.construct_prompt_str_for_shell(case_query, prompt_for, prompt_str)
|
95
100
|
shell = IPythonShell(header=prompt_str, prompt_for=prompt_for, case_query=case_query,
|
96
101
|
code_to_modify=code_to_modify)
|
@@ -8,14 +8,13 @@ from functools import cached_property
|
|
8
8
|
from textwrap import indent, dedent
|
9
9
|
|
10
10
|
from colorama import Fore, Style
|
11
|
-
from
|
12
|
-
from typing_extensions import Optional, Type, List, Callable
|
11
|
+
from typing_extensions import Optional, Type, List, Callable, Tuple, Dict
|
13
12
|
|
14
13
|
from ..datastructures.case import Case
|
15
14
|
from ..datastructures.dataclasses import CaseQuery
|
16
15
|
from ..datastructures.enums import Editor, PromptFor
|
17
16
|
from ..utils import str_to_snake_case, get_imports_from_scope, make_list, typing_hint_to_str, \
|
18
|
-
get_imports_from_types, extract_function_source
|
17
|
+
get_imports_from_types, extract_function_source, extract_imports
|
19
18
|
|
20
19
|
|
21
20
|
def detect_available_editor() -> Optional[Editor]:
|
@@ -49,6 +48,9 @@ def start_code_server(workspace):
|
|
49
48
|
stderr=subprocess.PIPE, text=True)
|
50
49
|
|
51
50
|
|
51
|
+
FunctionData = Tuple[Optional[List[str]], Optional[Dict[str, Callable]]]
|
52
|
+
|
53
|
+
|
52
54
|
class TemplateFileCreator:
|
53
55
|
"""
|
54
56
|
A class to create a rule template file for a given case and prompt for the user to edit it.
|
@@ -70,10 +72,9 @@ class TemplateFileCreator:
|
|
70
72
|
The list of all code lines in the function in the temporary file.
|
71
73
|
"""
|
72
74
|
|
73
|
-
def __init__(self,
|
74
|
-
code_to_modify: Optional[str] = None, print_func:
|
75
|
-
self.print_func = print_func
|
76
|
-
self.shell = shell
|
75
|
+
def __init__(self, case_query: CaseQuery, prompt_for: PromptFor,
|
76
|
+
code_to_modify: Optional[str] = None, print_func: Callable[[str], None] = print):
|
77
|
+
self.print_func = print_func
|
77
78
|
self.code_to_modify = code_to_modify
|
78
79
|
self.prompt_for = prompt_for
|
79
80
|
self.case_query = case_query
|
@@ -228,15 +229,27 @@ class TemplateFileCreator:
|
|
228
229
|
imports = set(imports)
|
229
230
|
return '\n'.join(imports)
|
230
231
|
|
232
|
+
@staticmethod
|
233
|
+
def get_core_attribute_types(case_query: CaseQuery) -> List[Type]:
|
234
|
+
"""
|
235
|
+
Get the core attribute types of the case query.
|
236
|
+
|
237
|
+
:return: A list of core attribute types.
|
238
|
+
"""
|
239
|
+
attr_types = [t for t in case_query.core_attribute_type if t.__module__ != "builtins" and t is not None
|
240
|
+
and t is not type(None)]
|
241
|
+
return attr_types
|
242
|
+
|
231
243
|
def get_func_doc(self) -> Optional[str]:
|
232
244
|
"""
|
233
245
|
:return: A string containing the function docstring.
|
234
246
|
"""
|
247
|
+
type_data = f" of type {' or '.join(map(lambda c: c.__name__, self.get_core_attribute_types(self.case_query)))}"
|
235
248
|
if self.prompt_for == PromptFor.Conditions:
|
236
249
|
return (f"Get conditions on whether it's possible to conclude a value"
|
237
|
-
f" for {self.case_query.name}")
|
250
|
+
f" for {self.case_query.name} {type_data}.")
|
238
251
|
else:
|
239
|
-
return f"Get possible value(s) for {self.case_query.name}"
|
252
|
+
return f"Get possible value(s) for {self.case_query.name} {type_data}."
|
240
253
|
|
241
254
|
@staticmethod
|
242
255
|
def get_func_name(prompt_for, case_query) -> Optional[str]:
|
@@ -245,12 +258,12 @@ class TemplateFileCreator:
|
|
245
258
|
func_name = f"{prompt_for.value.lower()}_for_"
|
246
259
|
case_name = case_query.name.replace(".", "_")
|
247
260
|
if case_query.is_function:
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
261
|
+
func_name += case_name.replace(f"_{case_query.attribute_name}", "")
|
262
|
+
else:
|
263
|
+
func_name += case_name
|
264
|
+
attribute_types = TemplateFileCreator.get_core_attribute_types(case_query)
|
265
|
+
attribute_type_names = [t.__name__ for t in attribute_types]
|
266
|
+
func_name += f"_of_type_{'_or_'.join(attribute_type_names)}"
|
254
267
|
return str_to_snake_case(func_name)
|
255
268
|
|
256
269
|
@cached_property
|
@@ -263,33 +276,43 @@ class TemplateFileCreator:
|
|
263
276
|
case = self.case_query.scope['case']
|
264
277
|
return case._obj_type if isinstance(case, Case) else type(case)
|
265
278
|
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
279
|
+
@staticmethod
|
280
|
+
def load(file_path: str, func_name: str, print_func: Callable = print) -> FunctionData:
|
281
|
+
"""
|
282
|
+
Load the function from the given file path.
|
283
|
+
|
284
|
+
:param file_path: The path to the file to load.
|
285
|
+
:param func_name: The name of the function to load.
|
286
|
+
:param print_func: The function to use for printing messages.
|
287
|
+
:return: A tuple containing the function source code and the function object as a dictionary
|
288
|
+
with the function name as the key and the function object as the value.
|
289
|
+
"""
|
290
|
+
if not file_path:
|
291
|
+
print_func(f"{Fore.RED}ERROR:: No file to load. Run %edit first.{Style.RESET_ALL}")
|
292
|
+
return None, None
|
270
293
|
|
271
|
-
with open(
|
294
|
+
with open(file_path, 'r') as f:
|
272
295
|
source = f.read()
|
273
296
|
|
274
297
|
tree = ast.parse(source)
|
275
298
|
updates = {}
|
276
299
|
for node in tree.body:
|
277
|
-
if isinstance(node, ast.FunctionDef) and node.name ==
|
300
|
+
if isinstance(node, ast.FunctionDef) and node.name == func_name:
|
278
301
|
exec_globals = {}
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
302
|
+
scope = extract_imports(tree=tree)
|
303
|
+
exec(source, scope, exec_globals)
|
304
|
+
user_function = exec_globals[func_name]
|
305
|
+
updates[func_name] = user_function
|
306
|
+
print_func(f"{Fore.BLUE}Loaded `{func_name}` function into user namespace.{Style.RESET_ALL}")
|
283
307
|
break
|
284
308
|
if updates:
|
285
|
-
|
286
|
-
|
287
|
-
[
|
288
|
-
|
289
|
-
return self.all_code_lines
|
309
|
+
all_code_lines = extract_function_source(file_path,
|
310
|
+
[func_name],
|
311
|
+
join_lines=False)[func_name]
|
312
|
+
return all_code_lines, updates
|
290
313
|
else:
|
291
|
-
|
292
|
-
return None
|
314
|
+
print_func(f"{Fore.RED}ERROR:: Function `{func_name}` not found.{Style.RESET_ALL}")
|
315
|
+
return None, None
|
293
316
|
|
294
317
|
def __del__(self):
|
295
318
|
if hasattr(self, 'process') and self.process is not None and self.process.poll() is None:
|