ripple-down-rules 0.6.1__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +21 -1
- ripple_down_rules/datastructures/callable_expression.py +24 -7
- ripple_down_rules/datastructures/case.py +12 -11
- ripple_down_rules/datastructures/dataclasses.py +135 -14
- ripple_down_rules/datastructures/enums.py +29 -86
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/experts.py +141 -50
- ripple_down_rules/failures.py +4 -0
- ripple_down_rules/helpers.py +75 -8
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +712 -96
- ripple_down_rules/rdr_decorators.py +164 -112
- ripple_down_rules/rules.py +351 -114
- ripple_down_rules/user_interface/gui.py +66 -41
- ripple_down_rules/user_interface/ipython_custom_shell.py +46 -9
- ripple_down_rules/user_interface/prompt.py +80 -60
- ripple_down_rules/user_interface/template_file_creator.py +13 -8
- ripple_down_rules/utils.py +537 -53
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/METADATA +4 -1
- ripple_down_rules-0.6.6.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.1.dist-info/RECORD +0 -24
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.1.dist-info → ripple_down_rules-0.6.6.dist-info}/top_level.txt +0 -0
ripple_down_rules/utils.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import ast
|
4
3
|
import builtins
|
4
|
+
import codecs
|
5
5
|
import copyreg
|
6
6
|
import importlib
|
7
7
|
import json
|
8
|
-
import logging
|
9
8
|
import os
|
10
9
|
import re
|
10
|
+
import shutil
|
11
11
|
import sys
|
12
12
|
import threading
|
13
13
|
import uuid
|
@@ -17,36 +17,41 @@ from dataclasses import is_dataclass, fields
|
|
17
17
|
from enum import Enum
|
18
18
|
from os.path import dirname
|
19
19
|
from pathlib import Path
|
20
|
+
from subprocess import check_call
|
21
|
+
from tempfile import NamedTemporaryFile
|
20
22
|
from textwrap import dedent
|
21
|
-
from types import NoneType
|
23
|
+
from types import NoneType, ModuleType
|
24
|
+
import inspect
|
22
25
|
|
26
|
+
import six
|
27
|
+
from graphviz import Source
|
23
28
|
from sqlalchemy.exc import NoInspectionAvailable
|
24
|
-
|
29
|
+
from . import logger
|
25
30
|
|
26
31
|
try:
|
27
32
|
import matplotlib
|
28
33
|
from matplotlib import pyplot as plt
|
34
|
+
|
29
35
|
Figure = plt.Figure
|
30
36
|
except ImportError as e:
|
31
37
|
matplotlib = None
|
32
38
|
plt = None
|
33
39
|
Figure = None
|
34
|
-
|
40
|
+
logger.debug(f"{e}: matplotlib is not installed")
|
35
41
|
|
36
42
|
try:
|
37
43
|
import networkx as nx
|
38
44
|
except ImportError as e:
|
39
45
|
nx = None
|
40
|
-
|
46
|
+
logger.debug(f"{e}: networkx is not installed")
|
41
47
|
|
42
48
|
import requests
|
43
|
-
from anytree import Node, RenderTree
|
44
|
-
from
|
45
|
-
from sqlalchemy import MetaData, inspect
|
49
|
+
from anytree import Node, RenderTree, PreOrderIter
|
50
|
+
from sqlalchemy import MetaData, inspect as sql_inspect
|
46
51
|
from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
|
47
52
|
from tabulate import tabulate
|
48
53
|
from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
|
49
|
-
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef,
|
54
|
+
get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Iterable
|
50
55
|
|
51
56
|
if TYPE_CHECKING:
|
52
57
|
from .datastructures.case import Case
|
@@ -55,6 +60,49 @@ if TYPE_CHECKING:
|
|
55
60
|
import ast
|
56
61
|
|
57
62
|
|
63
|
+
def get_and_import_python_modules_in_a_package(file_paths: List[str],
                                                parent_package_name: Optional[str] = None) -> List[Optional[ModuleType]]:
    """
    Import (and then reload) several python files that live in the same directory/package.

    The package import path is derived from the directory of the FIRST file only, so all files are expected
    to be located in that same directory.

    :param file_paths: The paths to the python files to import.
    :param parent_package_name: The name of the parent package to use for relative imports.
    :return: The imported modules, in the same order as ``file_paths``. The entry is None for a path that does
     not exist on disk. An empty ``file_paths`` gives an empty list.
    """
    if not file_paths:
        return []
    package_path = dirname(file_paths[0])
    package_import_path = get_import_path_from_path(package_path)
    # Path.stem drops only the final suffix, whereas str.replace(".py", "") would remove ".py" anywhere in the name.
    file_names = [Path(file_path).stem for file_path in file_paths]
    module_import_paths = [
        f"{package_import_path}.{file_name}" if package_import_path else file_name
        for file_name in file_names
    ]
    modules = [
        importlib.import_module(module_import_path, package=parent_package_name)
        if os.path.exists(file_path) else None
        for file_path, module_import_path in zip(file_paths, module_import_paths)
    ]
    # Reload every imported module so its current on-disk source is executed even if it was imported before.
    for module in modules:
        if module is not None:
            importlib.reload(module)
    return modules
|
86
|
+
|
87
|
+
|
88
|
+
def get_and_import_python_module(python_file_path: str, package_import_path: Optional[str] = None,
                                 parent_package_name: Optional[str] = None) -> ModuleType:
    """
    Import (and then reload) a single python file as a module.

    :param python_file_path: The path to the python file to import.
    :param package_import_path: The import path of the package that contains the python file. When None, it is
     derived from the file's directory with ``get_import_path_from_path``. A falsy value (e.g. an empty
     string) makes the file be imported as a top-level module.
    :param parent_package_name: The name of the parent package to use for relative imports.
    :return: The imported (and reloaded) module.
    """
    if package_import_path is None:
        package_import_path = get_import_path_from_path(dirname(python_file_path))
    # Path.stem drops only the final suffix, whereas str.replace(".py", "") would remove ".py" anywhere in the name.
    file_name = Path(python_file_path).stem
    module_import_path = f"{package_import_path}.{file_name}" if package_import_path else file_name
    module = importlib.import_module(module_import_path, package=parent_package_name)
    # Reload so the current on-disk source is executed even if the module was imported before.
    importlib.reload(module)
    return module
|
104
|
+
|
105
|
+
|
58
106
|
def str_to_snake_case(snake_str: str) -> str:
|
59
107
|
"""
|
60
108
|
Convert a string to snake case.
|
@@ -120,22 +168,30 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
|
|
120
168
|
try:
|
121
169
|
scope[asname] = importlib.import_module(module_name, package=package_name)
|
122
170
|
except ImportError as e:
|
123
|
-
|
171
|
+
logger.warning(f"Could not import {module_name}: {e}")
|
124
172
|
elif isinstance(node, ast.ImportFrom):
|
125
173
|
module_name = node.module
|
126
174
|
for alias in node.names:
|
127
175
|
name = alias.name
|
128
176
|
asname = alias.asname or name
|
129
177
|
try:
|
130
|
-
if
|
178
|
+
if node.level > 0: # Handle relative imports
|
179
|
+
package_name = get_import_path_from_path(Path(os.path.join(file_path, *['..'] * node.level)).resolve())
|
180
|
+
if package_name is not None and node.level > 0: # Handle relative imports
|
131
181
|
module_rel_path = Path(os.path.join(file_path, *['..'] * node.level, module_name)).resolve()
|
132
182
|
idx = str(module_rel_path).rfind(package_name)
|
133
183
|
if idx != -1:
|
134
184
|
module_name = str(module_rel_path)[idx:].replace(os.path.sep, '.')
|
135
|
-
|
136
|
-
|
185
|
+
try:
|
186
|
+
module = importlib.import_module(module_name, package=package_name)
|
187
|
+
except ModuleNotFoundError:
|
188
|
+
module = importlib.import_module(f"{package_name}.{module_name}")
|
189
|
+
if name == "*":
|
190
|
+
scope.update(module.__dict__)
|
191
|
+
else:
|
192
|
+
scope[asname] = getattr(module, name)
|
137
193
|
except (ImportError, AttributeError) as e:
|
138
|
-
|
194
|
+
logger.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
|
139
195
|
|
140
196
|
return scope
|
141
197
|
|
@@ -143,7 +199,9 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
|
|
143
199
|
def extract_function_source(file_path: str,
|
144
200
|
function_names: List[str], join_lines: bool = True,
|
145
201
|
return_line_numbers: bool = False,
|
146
|
-
include_signature: bool = True
|
202
|
+
include_signature: bool = True,
|
203
|
+
as_list: bool = False,
|
204
|
+
is_class: bool = False) \
|
147
205
|
-> Union[Dict[str, Union[str, List[str]]],
|
148
206
|
Tuple[Dict[str, Union[str, List[str]]], Dict[str, Tuple[int, int]]]]:
|
149
207
|
"""
|
@@ -154,6 +212,9 @@ def extract_function_source(file_path: str,
|
|
154
212
|
:param join_lines: Whether to join the lines of the function.
|
155
213
|
:param return_line_numbers: Whether to return the line numbers of the function.
|
156
214
|
:param include_signature: Whether to include the function signature in the source code.
|
215
|
+
:param as_list: Whether to return a list of function sources instead of dict (useful when there is multiple
|
216
|
+
functions with same name).
|
217
|
+
:param is_class: Whether to also look for class definitions
|
157
218
|
:return: A dictionary mapping function names to their source code as a string if join_lines is True,
|
158
219
|
otherwise as a list of strings.
|
159
220
|
"""
|
@@ -164,24 +225,39 @@ def extract_function_source(file_path: str,
|
|
164
225
|
tree = ast.parse(source)
|
165
226
|
function_names = make_list(function_names)
|
166
227
|
functions_source: Dict[str, Union[str, List[str]]] = {}
|
228
|
+
functions_source_list: List[Union[str, List[str]]] = []
|
167
229
|
line_numbers: Dict[str, Tuple[int, int]] = {}
|
230
|
+
line_numbers_list: List[Tuple[int, int]] = []
|
231
|
+
if is_class:
|
232
|
+
look_for_type = ast.ClassDef
|
233
|
+
else:
|
234
|
+
look_for_type = ast.FunctionDef
|
235
|
+
|
168
236
|
for node in tree.body:
|
169
|
-
if isinstance(node,
|
237
|
+
if isinstance(node, look_for_type) and (node.name in function_names or len(function_names) == 0):
|
170
238
|
# Get the line numbers of the function
|
171
239
|
lines = source.splitlines()
|
172
240
|
func_lines = lines[node.lineno - 1:node.end_lineno]
|
173
241
|
if not include_signature:
|
174
242
|
func_lines = func_lines[1:]
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
243
|
+
if as_list:
|
244
|
+
line_numbers_list.append((node.lineno, node.end_lineno))
|
245
|
+
else:
|
246
|
+
line_numbers[node.name] = (node.lineno, node.end_lineno)
|
247
|
+
parsed_function = dedent("\n".join(func_lines)) if join_lines else func_lines
|
248
|
+
if as_list:
|
249
|
+
functions_source_list.append(parsed_function)
|
250
|
+
else:
|
251
|
+
functions_source[node.name] = parsed_function
|
252
|
+
if len(function_names) > 0:
|
253
|
+
if len(functions_source) >= len(function_names) or len(functions_source_list) >= len(function_names):
|
254
|
+
break
|
255
|
+
if len(functions_source) < len(function_names) and len(functions_source_list) < len(function_names):
|
256
|
+
logger.warning(f"Could not find all functions in {file_path}: {function_names} not found, "
|
257
|
+
f"functions not found: {set(function_names) - set(functions_source.keys())}")
|
182
258
|
if return_line_numbers:
|
183
|
-
return functions_source, line_numbers
|
184
|
-
return functions_source
|
259
|
+
return functions_source if not as_list else functions_source_list, line_numbers if not as_list else line_numbers_list
|
260
|
+
return functions_source if not as_list else functions_source_list
|
185
261
|
|
186
262
|
|
187
263
|
def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optional[str] = None) -> str:
|
@@ -193,7 +269,8 @@ def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optio
|
|
193
269
|
:param func_doc: The function docstring to use for encapsulation.
|
194
270
|
:return: The encapsulated user input string.
|
195
271
|
"""
|
196
|
-
|
272
|
+
func_name = func_signature.split('(')[0].strip()
|
273
|
+
if func_name not in user_input:
|
197
274
|
new_user_input = func_signature + "\n "
|
198
275
|
if func_doc is not None:
|
199
276
|
new_user_input += f"\"\"\"{func_doc}\"\"\"" + "\n "
|
@@ -282,7 +359,7 @@ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
|
|
282
359
|
case_query.case.update(conclusions)
|
283
360
|
|
284
361
|
|
285
|
-
def
|
362
|
+
def is_value_conflicting(conclusion: Any, target: Any) -> bool:
|
286
363
|
"""
|
287
364
|
:param conclusion: The conclusion to check.
|
288
365
|
:param target: The target to compare the conclusion with.
|
@@ -621,18 +698,21 @@ def capture_variable_assignment(code: str, variable_name: str) -> Optional[str]:
|
|
621
698
|
return assignment
|
622
699
|
|
623
700
|
|
624
|
-
def get_func_rdr_model_path(func: Callable, model_dir: str) -> str:
|
701
|
+
def get_func_rdr_model_path(func: Callable, model_dir: str, include_file_name: bool = False) -> str:
    """
    Build the location of the RDR model file that belongs to ``func``.

    The path is ``<model_dir>/<model name>/<model name without file name>_rdr.py``. Both names come from
    ``get_func_rdr_model_name``; only the directory component honours ``include_file_name``.

    :param func: The function to get the model path for.
    :param model_dir: The directory to save the model to.
    :param include_file_name: Whether to include the file name in the model (directory) name.
    :return: The path to the model file.
    """
    model_sub_directory = get_func_rdr_model_name(func, include_file_name=include_file_name)
    model_file_name = f"{get_func_rdr_model_name(func)}_rdr.py"
    return os.path.join(model_dir, model_sub_directory, model_file_name)
|
631
710
|
|
632
711
|
|
633
712
|
def get_func_rdr_model_name(func: Callable, include_file_name: bool = False) -> str:
|
634
713
|
"""
|
635
714
|
:param func: The function to get the model name for.
|
715
|
+
:param include_file_name: Whether to include the file name in the model name.
|
636
716
|
:return: The name of the model.
|
637
717
|
"""
|
638
718
|
func_name = get_method_name(func)
|
@@ -644,7 +724,7 @@ def get_func_rdr_model_name(func: Callable, include_file_name: bool = False) ->
|
|
644
724
|
model_name = ''
|
645
725
|
model_name += f"{func_class_name}_" if func_class_name else ""
|
646
726
|
model_name += f"{func_name}"
|
647
|
-
return model_name
|
727
|
+
return str_to_snake_case(model_name)
|
648
728
|
|
649
729
|
|
650
730
|
def stringify_hint(tp):
|
@@ -681,6 +761,7 @@ def is_builtin_type(tp):
|
|
681
761
|
def is_typing_type(tp):
    """
    Check whether ``tp`` is defined in the standard ``typing`` module (e.g. ``typing.List``).

    :param tp: Any object, typically a type or a type hint.
    :return: True if ``tp.__module__`` is ``"typing"``. False otherwise, including for objects that have no
     ``__module__`` attribute at all (such as plain values like ``1``).
    """
    return getattr(tp, "__module__", None) == "typing"
|
683
763
|
|
764
|
+
|
684
765
|
origin_type_to_hint = {
|
685
766
|
list: List,
|
686
767
|
set: Set,
|
@@ -688,6 +769,47 @@ origin_type_to_hint = {
|
|
688
769
|
tuple: Tuple,
|
689
770
|
}
|
690
771
|
|
772
|
+
|
773
|
+
def get_file_that_ends_with(directory_path: str, suffix: str) -> Optional[str]:
    """
    Get the name of an entry in ``directory_path`` whose name ends with the given suffix.

    Entries are scanned in sorted order so that the result is deterministic when several entries match
    (``os.listdir`` returns entries in arbitrary order).

    :param directory_path: The path to the directory where the file is located.
    :param suffix: The suffix to search for.
    :return: The entry NAME (not the full path) of the first match in sorted order, or None if nothing matches.
    """
    return next((entry for entry in sorted(os.listdir(directory_path)) if entry.endswith(suffix)), None)
|
785
|
+
|
786
|
+
def get_function_return_type(func: Callable) -> Union[Type, None, Tuple[Type, ...]]:
    """
    Get the return type of a function.

    String annotations, such as those produced when the defining module uses
    ``from __future__ import annotations``, are evaluated into real objects first. If that evaluation fails,
    for example because a name is only imported under ``TYPE_CHECKING``, the raw annotations are used instead.

    :param func: The function to get the return type for.
    :return: The return type of the function as converted by ``get_type_from_type_hint``, or None if the
     function has no return annotation.
    """
    try:
        signature = inspect.signature(func, eval_str=True)
    except Exception:
        # Best-effort resolution: fall back to the unevaluated annotations. A non-annotation problem
        # (e.g. no signature available) is re-raised by this second call exactly as before.
        signature = inspect.signature(func)
    return_annotation = signature.return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    return get_type_from_type_hint(return_annotation)
|
798
|
+
|
799
|
+
|
800
|
+
def get_type_from_type_hint(type_hint: Type) -> Union[Type, Tuple[Type, ...]]:
|
801
|
+
origin = get_origin(type_hint)
|
802
|
+
args = get_args(type_hint)
|
803
|
+
if origin not in [list, set, None, Union]:
|
804
|
+
raise TypeError(f"{origin} is not a handled return type for type hint {type_hint}")
|
805
|
+
if origin is None:
|
806
|
+
return typing_to_python_type(type_hint)
|
807
|
+
if args is None or len(args) == 0:
|
808
|
+
return typing_to_python_type(type_hint)
|
809
|
+
return args
|
810
|
+
|
811
|
+
|
812
|
+
|
691
813
|
def extract_types(tp, seen: Set = None) -> Set[type]:
|
692
814
|
"""Recursively extract all base types from a type hint."""
|
693
815
|
if seen is None:
|
@@ -786,6 +908,13 @@ def get_import_path_from_path(path: str) -> Optional[str]:
|
|
786
908
|
return package_name
|
787
909
|
|
788
910
|
|
911
|
+
def get_class_file_path(cls):
    """
    Return the absolute path of the source file in which ``cls`` is defined.

    :param cls: The class to locate.
    :return: ``inspect.getfile(cls)`` made absolute and normalized with ``os.path.abspath``.
    """
    defining_file = inspect.getfile(cls)
    absolute_file = os.path.abspath(defining_file)
    return absolute_file
|
916
|
+
|
917
|
+
|
789
918
|
def get_function_import_data(func: Callable) -> Tuple[str, str]:
|
790
919
|
"""
|
791
920
|
Get the import path of a function.
|
@@ -840,10 +969,12 @@ def get_relative_import(target_file_path, imported_module_path: Optional[str] =
|
|
840
969
|
imported_file_name = Path(imported_module_path).name
|
841
970
|
target_file_name = Path(target_file_path).name
|
842
971
|
if package_name is not None:
|
843
|
-
target_path = Path(
|
972
|
+
target_path = Path(
|
973
|
+
get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
|
844
974
|
imported_path = Path(imported_module_path).resolve()
|
845
975
|
if package_name is not None:
|
846
|
-
imported_path = Path(
|
976
|
+
imported_path = Path(
|
977
|
+
get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
|
847
978
|
|
848
979
|
# Compute relative path from target to imported module
|
849
980
|
rel_path = os.path.relpath(imported_path.parent, target_path.parent)
|
@@ -920,8 +1051,8 @@ def get_imports_from_types(type_objs: Iterable[Type],
|
|
920
1051
|
continue
|
921
1052
|
if name == "NoneType":
|
922
1053
|
module = "types"
|
923
|
-
if module is None or module == 'builtins' or module.startswith('_')\
|
924
|
-
|
1054
|
+
if module is None or module == 'builtins' or module.startswith('_') \
|
1055
|
+
or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
|
925
1056
|
or name in exclueded_names:
|
926
1057
|
continue
|
927
1058
|
if module == "typing":
|
@@ -989,7 +1120,8 @@ def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
|
|
989
1120
|
return method.__self__.__name__
|
990
1121
|
elif hasattr(method.__self__, "__class__"):
|
991
1122
|
return method.__self__.__class__.__name__
|
992
|
-
return method.__qualname__.split('.')[0]
|
1123
|
+
return (method.__qualname__.split('.')[0]
|
1124
|
+
if hasattr(method, "__qualname__") and '.' in method.__qualname__ else None)
|
993
1125
|
|
994
1126
|
|
995
1127
|
def get_method_class_if_exists(method: Callable, *args) -> Optional[Type]:
|
@@ -1087,6 +1219,7 @@ def get_full_class_name(cls):
|
|
1087
1219
|
def recursive_subclasses(cls):
|
1088
1220
|
"""
|
1089
1221
|
Copied from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
|
1222
|
+
|
1090
1223
|
:param cls: The class.
|
1091
1224
|
:return: A list of the classes subclasses.
|
1092
1225
|
"""
|
@@ -1211,7 +1344,8 @@ class SubclassJSONSerializer:
|
|
1211
1344
|
return cls._from_json(data)
|
1212
1345
|
for subclass in recursive_subclasses(SubclassJSONSerializer):
|
1213
1346
|
if get_full_class_name(subclass) == data["_type"]:
|
1214
|
-
subclass_data = deepcopy(data)
|
1347
|
+
# subclass_data = deepcopy(data)
|
1348
|
+
subclass_data = data
|
1215
1349
|
subclass_data.pop("_type")
|
1216
1350
|
return subclass._from_json(subclass_data)
|
1217
1351
|
|
@@ -1273,7 +1407,7 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
|
|
1273
1407
|
:return: The copied instance.
|
1274
1408
|
"""
|
1275
1409
|
try:
|
1276
|
-
session: Session =
|
1410
|
+
session: Session = sql_inspect(instance).session
|
1277
1411
|
except NoInspectionAvailable:
|
1278
1412
|
session = None
|
1279
1413
|
if session is not None:
|
@@ -1284,7 +1418,7 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
|
|
1284
1418
|
try:
|
1285
1419
|
new_instance = deepcopy(instance)
|
1286
1420
|
except Exception as e:
|
1287
|
-
|
1421
|
+
logger.debug(e)
|
1288
1422
|
new_instance = instance
|
1289
1423
|
return new_instance
|
1290
1424
|
|
@@ -1304,7 +1438,7 @@ def copy_orm_instance_with_relationships(instance: SQLTable) -> SQLTable:
|
|
1304
1438
|
try:
|
1305
1439
|
setattr(instance_cp, rel.key, related_obj_cp)
|
1306
1440
|
except Exception as e:
|
1307
|
-
|
1441
|
+
logger.debug(e)
|
1308
1442
|
return instance_cp
|
1309
1443
|
|
1310
1444
|
|
@@ -1400,9 +1534,14 @@ def table_rows_as_str(row_dicts: List[Dict[str, Any]], columns_per_row: int = 20
|
|
1400
1534
|
keys_values = [list(r[0]) + list(r[1]) if len(r) > 1 else r[0] for r in keys_values]
|
1401
1535
|
all_table_rows = []
|
1402
1536
|
row_values = [list(map(lambda v: str(v) if v is not None else "", row)) for row in keys_values]
|
1403
|
-
row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in
|
1537
|
+
row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in
|
1538
|
+
row_values]
|
1404
1539
|
row_values = [list(map(lambda v: v.lower() if v in ["True", "False"] else v, row)) for row in row_values]
|
1405
|
-
|
1540
|
+
# Step 1: Get terminal size
|
1541
|
+
terminal_width = shutil.get_terminal_size((80, 20)).columns
|
1542
|
+
# Step 2: Dynamically calculate max width per column (simple approximation)
|
1543
|
+
max_col_width = terminal_width // len(row_values[0])
|
1544
|
+
table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=max_col_width) # [max_line_sze] * 2)
|
1406
1545
|
all_table_rows.append(table)
|
1407
1546
|
return "\n".join(all_table_rows)
|
1408
1547
|
|
@@ -1551,6 +1690,13 @@ def get_all_subclasses(cls: Type) -> Dict[str, Type]:
|
|
1551
1690
|
return all_subclasses
|
1552
1691
|
|
1553
1692
|
|
1693
|
+
def make_tuple(value: Any) -> Any:
    """
    Wrap ``value`` into a tuple.

    :param value: Either something ``is_iterable`` accepts, which is converted element-wise, or a single value.
    :return: ``tuple(value)`` when ``is_iterable(value)`` is true, otherwise the one-element tuple ``(value,)``.
    """
    if not is_iterable(value):
        return (value,)
    return tuple(value)
|
1698
|
+
|
1699
|
+
|
1554
1700
|
def make_set(value: Any) -> Set:
|
1555
1701
|
"""
|
1556
1702
|
Make a set from a value.
|
@@ -1622,34 +1768,358 @@ def edge_attr_setter(parent, child):
|
|
1622
1768
|
"""
|
1623
1769
|
Set the edge attributes for the dot exporter.
|
1624
1770
|
"""
|
1625
|
-
if child and hasattr(child, "weight") and child.weight:
|
1771
|
+
if child and hasattr(child, "weight") and child.weight is not None:
|
1626
1772
|
return f'style="bold", label=" {child.weight}"'
|
1627
1773
|
return ""
|
1628
1774
|
|
1629
1775
|
|
1776
|
+
# Matches a single double quote or backslash character.
# NOTE(review): presumably used by FilteredDotExporter.esc (defined outside this view) to escape node names
# inside double-quoted DOT identifiers — confirm.
_RE_ESC = re.compile(r'["\\]')
|
1777
|
+
|
1778
|
+
|
1779
|
+
class FilteredDotExporter(object):
|
1780
|
+
|
1781
|
+
def __init__(self, node, include_nodes=None, graph="digraph", name="tree", options=None,
             indent=4, nodenamefunc=None, nodeattrfunc=None,
             edgeattrfunc=None, edgetypefunc=None, maxlevel=None):
    """
    Dot language exporter that can restrict the exported tree to a subset of its nodes.

    The keyword arguments and line format follow ``anytree.exporter.DotExporter``. This class adds an
    ``include_nodes`` filter and appends a colour-legend sub-graph before the closing brace.

    Args:
        node (Node): start node.

    Keyword Args:
        include_nodes: Optional iterable of nodes to export. When not None, a node is emitted only if its name is
            among the names of ``include_nodes``. An edge is emitted only if both endpoints are. An empty
            iterable therefore exports no nodes and no edges. The names of ``include_nodes`` are computed with
            ``get_unique_node_names_func(node)``.
            NOTE(review): during export, names are compared using ``nodenamefunc``, so ``nodenamefunc`` should
            produce the same names — confirm at call sites.

        graph: DOT graph type.

        name: DOT graph name.

        options: list of options added to the graph.

        indent (int): number of spaces for indent.

        nodenamefunc: Function to extract node name from `node` object.
                      The function shall accept one `node` object as
                      argument and return the name of it.

        nodeattrfunc: Function to decorate a node with attributes.
                      The function shall accept one `node` object as
                      argument and return the attributes.

        edgeattrfunc: Function to decorate an edge with attributes.
                      The function shall accept two `node` objects as
                      argument. The first the node and the second the child
                      and return the attributes.

        edgetypefunc: Function which gives the edge type.
                      The function shall accept two `node` objects as
                      argument. The first the node and the second the child
                      and return the edge (i.e. '->').

        maxlevel (int): Limit export to this number of levels.

    See the ``anytree.exporter.DotExporter`` documentation for examples of the callbacks above.
    """
    self.node = node
    self.graph = graph
    self.name = name
    self.options = options
    self.indent = indent
    self.nodenamefunc = nodenamefunc
    self.nodeattrfunc = nodeattrfunc
    self.edgeattrfunc = edgeattrfunc
    self.edgetypefunc = edgetypefunc
    self.maxlevel = maxlevel
    self.include_nodes = include_nodes
    node_name_func = get_unique_node_names_func(node)
    # Test "is not None" (not truthiness) so the name list is always defined whenever the filters consult it.
    # An empty include_nodes previously left it as None and made the membership tests raise TypeError.
    self.include_node_names = ([node_name_func(n) for n in self.include_nodes]
                               if include_nodes is not None else None)
|
1934
|
+
|
1935
|
+
def __iter__(self):
    """
    Iterate over the lines of the DOT source.

    A built-in default is used for each callback that was not supplied.
    """
    indentation = " " * self.indent
    name_callback = self.nodenamefunc if self.nodenamefunc else self._default_nodenamefunc
    node_attr_callback = self.nodeattrfunc if self.nodeattrfunc else self._default_nodeattrfunc
    edge_attr_callback = self.edgeattrfunc if self.edgeattrfunc else self._default_edgeattrfunc
    edge_type_callback = self.edgetypefunc if self.edgetypefunc else self._default_edgetypefunc
    return self.__iter(indentation, name_callback, node_attr_callback, edge_attr_callback, edge_type_callback)
|
1944
|
+
|
1945
|
+
@staticmethod
def _default_nodenamefunc(node):
    """Default node-name callback: the node's ``name`` attribute."""
    node_name = node.name
    return node_name

@staticmethod
def _default_nodeattrfunc(node):
    """Default node-attribute callback: no attributes (returns None)."""

@staticmethod
def _default_edgeattrfunc(node, child):
    """Default edge-attribute callback: no attributes (returns None)."""

@staticmethod
def _default_edgetypefunc(node, child):
    """Default edge-type callback: a directed edge."""
    directed_edge = "->"
    return directed_edge
|
1960
|
+
|
1961
|
+
def __iter(self, indent, nodenamefunc, nodeattrfunc, edgeattrfunc, edgetypefunc):
    """
    Yield the full DOT source in order.

    The order is: the graph header, the options, the node statements, the edge statements, a colour-legend
    sub-graph (each line stripped of surrounding whitespace), and the closing brace.
    """
    yield "{self.graph} {self.name} {{".format(self=self)
    yield from self.__iter_options(indent)
    yield from self.__iter_nodes(indent, nodenamefunc, nodeattrfunc)
    yield from self.__iter_edges(indent, nodenamefunc, edgeattrfunc, edgetypefunc)
    legend_dot_graph = """
// Color legend as a subgraph
subgraph cluster_legend {
label = "Legend";
style = dashed;
color = gray;

legend_green [label="Fired->Query Related Value", shape=box, style=filled, fillcolor=green, fontcolor=black, size=0.5];
legend_yellow [label="Fired->Some Value", shape=box, style=filled, fillcolor=yellow, fontcolor=black, size=0.5];
legend_orange [label="Fired->Empty Value", shape=box, style=filled, fillcolor=orange, fontcolor=black, size=0.5];
legend_red [label="Evaluated->Not Fired", shape=box, style=filled, fillcolor=red, fontcolor=black, size=0.5];
legend_white [label="Not Evaluated", shape=box, style=filled, fillcolor=white, fontcolor=black, size=0.5];

// Invisible edges to arrange legend vertically
legend_white -> legend_red -> legend_orange -> legend_yellow -> legend_green [style=invis];
}"""
    for legend_line in legend_dot_graph.splitlines():
        yield legend_line.strip()
    yield "}"
|
1988
|
+
|
1989
|
+
def __iter_options(self, indent):
    """Yield each configured graph option prefixed with the indent (nothing when options are unset/empty)."""
    for graph_option in self.options or ():
        yield f"{indent}{graph_option}"
|
1994
|
+
|
1995
|
+
def __iter_nodes(self, indent, nodenamefunc, nodeattrfunc):
    """
    Yield one DOT node statement per node in pre-order, up to ``maxlevel``.

    When ``include_nodes`` is set, nodes whose names are not in ``include_node_names`` are skipped.
    """
    is_filtered = self.include_nodes is not None
    for current in PreOrderIter(self.node, maxlevel=self.maxlevel):
        current_name = nodenamefunc(current)
        if is_filtered and current_name not in self.include_node_names:
            continue
        attributes = nodeattrfunc(current)
        attribute_text = "" if attributes is None else " [%s]" % attributes
        yield '%s"%s"%s;' % (indent, FilteredDotExporter.esc(current_name), attribute_text)
|
2003
|
+
|
2004
|
+
def __iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc):
    """
    Yield one DOT edge statement per parent/child pair visited in pre-order.

    Parents are visited from ``self.node`` down to level ``self.maxlevel - 1``
    (unbounded when ``self.maxlevel`` is falsy), so the child end of every edge
    stays within ``self.maxlevel``. When ``self.include_nodes`` is set, an edge
    is skipped if either endpoint's name is not in ``self.include_node_names``.
    Each statement has the form ``<indent>"<parent>" <edgetype> "<child>" [<attrs>];``;
    the bracketed attribute list is omitted when ``edgeattrfunc`` returns ``None``.
    """
    def excluded(name):
        # Filtering only applies when an explicit inclusion list was supplied.
        return self.include_nodes is not None and name not in self.include_node_names

    if self.maxlevel:
        parent_maxlevel = self.maxlevel - 1
    else:
        parent_maxlevel = None

    for parent in PreOrderIter(self.node, maxlevel=parent_maxlevel):
        parent_name = nodenamefunc(parent)
        if excluded(parent_name):
            continue
        for child in parent.children:
            child_name = nodenamefunc(child)
            if excluded(child_name):
                continue
            attrs = edgeattrfunc(parent, child)
            arrow = edgetypefunc(parent, child)
            attr_suffix = "" if attrs is None else " [%s]" % attrs
            yield '%s"%s" %s "%s"%s;' % (
                indent,
                FilteredDotExporter.esc(parent_name),
                arrow,
                FilteredDotExporter.esc(child_name),
                attr_suffix,
            )
|
2019
|
+
|
2020
|
+
def to_dotfile(self, filename):
    """
    Write graph to `filename`.

    >>> from anytree import Node
    >>> root = Node("root")
    >>> s0 = Node("sub0", parent=root)
    >>> s0b = Node("sub0B", parent=s0)
    >>> s0a = Node("sub0A", parent=s0)
    >>> s1 = Node("sub1", parent=root)
    >>> s1a = Node("sub1A", parent=s1)
    >>> s1b = Node("sub1B", parent=s1)
    >>> s1c = Node("sub1C", parent=s1)
    >>> s1ca = Node("sub1Ca", parent=s1c)

    >>> from anytree.exporter import DotExporter
    >>> DotExporter(root).to_dotfile("tree.dot")

    The generated file should be handed over to the `dot` tool from the
    http://www.graphviz.org/ package::

        $ dot tree.dot -T png -o tree.png

    The file is overwritten and encoded as UTF-8. It contains one DOT line per
    line yielded by iterating this exporter, with LF line endings on every
    platform.
    """
    # newline="" disables newline translation, reproducing the byte-exact output
    # of the previous codecs.open() call (deprecated since Python 3.14).
    with open(filename, "w", encoding="utf-8", newline="") as file:
        file.writelines("%s\n" % line for line in self)
|
2046
|
+
|
2047
|
+
def to_picture(self, filename):
    """
    Write graph to a temporary file and invoke `dot`.

    The output file type is automatically detected from the file suffix.

    *`graphviz` needs to be installed, before usage of this method.*

    The temporary DOT file is always removed, even when `dot` is missing
    (``FileNotFoundError``) or exits with an error (``CalledProcessError``).
    Those exceptions still propagate to the caller.
    """
    fileformat = os.path.splitext(filename)[1][1:]
    dotfile = NamedTemporaryFile("wb", delete=False)
    dotfilename = dotfile.name
    try:
        # Close the file before running `dot` so its contents are flushed and
        # readable by the child process on every platform.
        with dotfile:
            for line in self:
                dotfile.write(("%s\n" % line).encode("utf-8"))
        cmd = ["dot", dotfilename, "-T", fileformat, "-o", filename]
        check_call(cmd)
    finally:
        # Runs on success and on failure, so a missing graphviz install no
        # longer leaks a temp file per call.
        try:
            os.remove(dotfilename)
        except OSError:  # pragma: no cover
            logger.warning('Could not remove temporary file %s', dotfilename)
|
2068
|
+
|
2069
|
+
def to_source(self) -> Source:
    """
    Wrap the graph's DOT text in a graphviz ``Source`` object.

    The lines produced by iterating this exporter are joined with newlines,
    and ``self.name`` is passed as the ``Source`` filename.
    """
    dot_lines = list(self)
    dot_text = "\n".join(dot_lines)
    return Source(dot_text, filename=self.name)
|
2074
|
+
|
2075
|
+
@staticmethod
def esc(value):
    """Return ``value`` as text with a backslash inserted before every ``_RE_ESC`` match."""
    def _prefix_backslash(match):
        return "\\" + match.group(0)

    # six.text_type is str on Python 3, which this module requires.
    return _RE_ESC.sub(_prefix_backslash, str(value))
|
2079
|
+
|
2080
|
+
|
1630
2081
|
def render_tree(root: Node, use_dot_exporter: bool = False,
                filename: str = "scrdr", only_nodes: Optional[List[Node]] = None,
                show_in_console: bool = False,
                color_map: Optional[Callable[[Node], str]] = None,
                view: bool = False) -> None:
    """
    Render the tree, optionally printing it to the console and/or exporting it with graphviz.

    :param root: The root node of the tree. If falsy, a warning is logged and nothing is rendered.
    :param use_dot_exporter: Whether to export the tree with a FilteredDotExporter.
    :param filename: Base name (without extension) of the exported ``.dot`` and ``.svg`` files;
        ``"rule_tree"`` is used if it is empty. Not used when ``view`` is True.
    :param only_nodes: If given, only these nodes are printed to the console and included in the export.
    :param show_in_console: Whether to print the tree to the console.
    :param color_map: A function that returns a fill color for a node; when omitted, ``node.color`` is used.
    :param view: Whether to open the rendered graph in a viewer instead of writing files.
    """
    if not root:
        logger.warning("No rules to render")
        return
    if show_in_console:
        for pre, _, node in RenderTree(root):
            if only_nodes is not None and node not in only_nodes:
                continue
            weight = node.weight if hasattr(node, 'weight') and node.weight else ''
            print(f"{pre}{weight} {node.__str__()}")
    if use_dot_exporter:
        unique_node_names = get_unique_node_names_func(root)

        def _node_attrs(node) -> str:
            # Every node is filled; the color comes from color_map when given, else node.color.
            fill = color_map(node) if color_map else node.color
            return f'style=filled, fillcolor={fill}'

        de = FilteredDotExporter(root,
                                 include_nodes=only_nodes,
                                 nodenamefunc=unique_node_names,
                                 edgeattrfunc=edge_attr_setter,
                                 nodeattrfunc=_node_attrs,
                                 )
        if view:
            de.to_source().view()
        else:
            filename = filename or "rule_tree"
            de.to_dotfile(f"{filename}.dot")
            try:
                de.to_picture(f"{filename}.svg")
            except FileNotFoundError as e:
                # Most likely the graphviz `dot` executable is not on PATH; the .dot file
                # has already been written, so only warn.
                logger.warning(f"{e}")
|
1653
2123
|
|
1654
2124
|
|
1655
2125
|
def draw_tree(root: Node, fig: Figure):
|
@@ -1698,3 +2168,17 @@ def encapsulate_code_lines_into_a_function(code_lines: List[str], function_name:
|
|
1698
2168
|
if f"return {function_name}({args})" not in code:
|
1699
2169
|
code = code.strip() + f"\nreturn {function_name}({args})"
|
1700
2170
|
return code
|
2171
|
+
|
2172
|
+
|
2173
|
+
def get_method_object_from_pytest_request(request) -> Callable:
    """
    Return the test function object that a pytest ``request`` fixture belongs to.

    :param request: The pytest ``request`` fixture.
    :return: The module-level test function, or, for a test defined in a class, the
        function looked up on that class (not bound to an instance).
    """
    # Parametrized tests have node names like "test_foo[case-1]", but the attribute on
    # the module/class is the bare function name (pytest exposes it as ``originalname``).
    test_name = getattr(request.node, "originalname", None) or request.node.name.split("[", 1)[0]
    # Use the already-collected class/module objects instead of looking the class up by
    # ``__name__`` on a re-imported module, which fails for nested test classes.
    owner = request.cls if request.cls is not None else request.module
    return getattr(owner, test_name)
|
2183
|
+
|
2184
|
+
|