ripple-down-rules 0.6.1__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- import ast
4
3
  import builtins
4
+ import codecs
5
5
  import copyreg
6
6
  import importlib
7
7
  import json
8
- import logging
9
8
  import os
10
9
  import re
10
+ import shutil
11
11
  import sys
12
12
  import threading
13
13
  import uuid
@@ -17,36 +17,41 @@ from dataclasses import is_dataclass, fields
17
17
  from enum import Enum
18
18
  from os.path import dirname
19
19
  from pathlib import Path
20
+ from subprocess import check_call
21
+ from tempfile import NamedTemporaryFile
20
22
  from textwrap import dedent
21
- from types import NoneType
23
+ from types import NoneType, ModuleType
24
+ import inspect
22
25
 
26
+ import six
27
+ from graphviz import Source
23
28
  from sqlalchemy.exc import NoInspectionAvailable
24
-
29
+ from . import logger
25
30
 
26
31
  try:
27
32
  import matplotlib
28
33
  from matplotlib import pyplot as plt
34
+
29
35
  Figure = plt.Figure
30
36
  except ImportError as e:
31
37
  matplotlib = None
32
38
  plt = None
33
39
  Figure = None
34
- logging.debug(f"{e}: matplotlib is not installed")
40
+ logger.debug(f"{e}: matplotlib is not installed")
35
41
 
36
42
  try:
37
43
  import networkx as nx
38
44
  except ImportError as e:
39
45
  nx = None
40
- logging.debug(f"{e}: networkx is not installed")
46
+ logger.debug(f"{e}: networkx is not installed")
41
47
 
42
48
  import requests
43
- from anytree import Node, RenderTree
44
- from anytree.exporter import DotExporter
45
- from sqlalchemy import MetaData, inspect
49
+ from anytree import Node, RenderTree, PreOrderIter
50
+ from sqlalchemy import MetaData, inspect as sql_inspect
46
51
  from sqlalchemy.orm import Mapped, registry, class_mapper, DeclarativeBase as SQLTable, Session
47
52
  from tabulate import tabulate
48
53
  from typing_extensions import Callable, Set, Any, Type, Dict, TYPE_CHECKING, get_type_hints, \
49
- get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Sequence, Iterable
54
+ get_origin, get_args, Tuple, Optional, List, Union, Self, ForwardRef, Iterable
50
55
 
51
56
  if TYPE_CHECKING:
52
57
  from .datastructures.case import Case
@@ -55,6 +60,49 @@ if TYPE_CHECKING:
55
60
  import ast
56
61
 
57
62
 
63
def get_and_import_python_modules_in_a_package(file_paths: List[str],
                                               parent_package_name: Optional[str] = None
                                               ) -> List[Optional[ModuleType]]:
    """
    Import, then reload, several python modules that live in the same package directory.

    The package import path is derived from the directory of the first file only, so all files are
    expected to be located in that same directory.

    :param file_paths: The paths to the python files to import.
    :param parent_package_name: The name of the parent package to use for relative imports.
    :return: The imported modules in the same order as ``file_paths``. An entry is None when the
        corresponding file does not exist. An empty input gives an empty list.
    """
    if not file_paths:
        # Nothing to import; also avoids an IndexError on file_paths[0] below.
        return []
    package_import_path = get_import_path_from_path(dirname(file_paths[0]))
    modules: List[Optional[ModuleType]] = []
    for file_path in file_paths:
        if not os.path.exists(file_path):
            modules.append(None)
            continue
        # Only strip a trailing ".py"; str.replace would also remove ".py" found elsewhere in the name.
        module_name = Path(file_path).name.removesuffix(".py")
        module_import_path = f"{package_import_path}.{module_name}" if package_import_path else module_name
        modules.append(importlib.import_module(module_import_path, package=parent_package_name))
    # Reload only after every module has been imported, so changes on disk are re-executed.
    for module in modules:
        if module is not None:
            importlib.reload(module)
    return modules
86
+
87
+
88
def get_and_import_python_module(python_file_path: str, package_import_path: Optional[str] = None,
                                 parent_package_name: Optional[str] = None) -> ModuleType:
    """
    Import a single python file as a module and reload it so its current contents are executed.

    :param python_file_path: The path to the python file to import.
    :param package_import_path: The import path of the package that contains the python file. When None,
        it is derived from the directory of ``python_file_path``. An empty string means the module is
        imported by its bare file name.
    :param parent_package_name: The name of the parent package to use for relative imports.
    :return: The imported (and reloaded) module.
    """
    if package_import_path is None:
        package_import_path = get_import_path_from_path(dirname(python_file_path))
    module_name = Path(python_file_path).name.replace(".py", "")
    if package_import_path:
        dotted_path = f"{package_import_path}.{module_name}"
    else:
        dotted_path = module_name
    imported = importlib.import_module(dotted_path, package=parent_package_name)
    importlib.reload(imported)
    return imported
104
+
105
+
58
106
  def str_to_snake_case(snake_str: str) -> str:
59
107
  """
60
108
  Convert a string to snake case.
@@ -120,22 +168,30 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
120
168
  try:
121
169
  scope[asname] = importlib.import_module(module_name, package=package_name)
122
170
  except ImportError as e:
123
- print(f"Could not import {module_name}: {e}")
171
+ logger.warning(f"Could not import {module_name}: {e}")
124
172
  elif isinstance(node, ast.ImportFrom):
125
173
  module_name = node.module
126
174
  for alias in node.names:
127
175
  name = alias.name
128
176
  asname = alias.asname or name
129
177
  try:
130
- if package_name is not None and node.level > 0: # Handle relative imports
178
+ if node.level > 0: # Handle relative imports
179
+ package_name = get_import_path_from_path(Path(os.path.join(file_path, *['..'] * node.level)).resolve())
180
+ if package_name is not None and node.level > 0: # Handle relative imports
131
181
  module_rel_path = Path(os.path.join(file_path, *['..'] * node.level, module_name)).resolve()
132
182
  idx = str(module_rel_path).rfind(package_name)
133
183
  if idx != -1:
134
184
  module_name = str(module_rel_path)[idx:].replace(os.path.sep, '.')
135
- module = importlib.import_module(module_name, package=package_name)
136
- scope[asname] = getattr(module, name)
185
+ try:
186
+ module = importlib.import_module(module_name, package=package_name)
187
+ except ModuleNotFoundError:
188
+ module = importlib.import_module(f"{package_name}.{module_name}")
189
+ if name == "*":
190
+ scope.update(module.__dict__)
191
+ else:
192
+ scope[asname] = getattr(module, name)
137
193
  except (ImportError, AttributeError) as e:
138
- logging.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
194
+ logger.warning(f"Could not import {module_name}: {e} while extracting imports from {file_path}")
139
195
 
140
196
  return scope
141
197
 
@@ -143,7 +199,9 @@ def extract_imports(file_path: Optional[str] = None, tree: Optional[ast.AST] = N
143
199
  def extract_function_source(file_path: str,
144
200
  function_names: List[str], join_lines: bool = True,
145
201
  return_line_numbers: bool = False,
146
- include_signature: bool = True) \
202
+ include_signature: bool = True,
203
+ as_list: bool = False,
204
+ is_class: bool = False) \
147
205
  -> Union[Dict[str, Union[str, List[str]]],
148
206
  Tuple[Dict[str, Union[str, List[str]]], Dict[str, Tuple[int, int]]]]:
149
207
  """
@@ -154,6 +212,9 @@ def extract_function_source(file_path: str,
154
212
  :param join_lines: Whether to join the lines of the function.
155
213
  :param return_line_numbers: Whether to return the line numbers of the function.
156
214
  :param include_signature: Whether to include the function signature in the source code.
215
+ :param as_list: Whether to return a list of function sources instead of dict (useful when there is multiple
216
+ functions with same name).
217
+ :param is_class: Whether to also look for class definitions
157
218
  :return: A dictionary mapping function names to their source code as a string if join_lines is True,
158
219
  otherwise as a list of strings.
159
220
  """
@@ -164,24 +225,39 @@ def extract_function_source(file_path: str,
164
225
  tree = ast.parse(source)
165
226
  function_names = make_list(function_names)
166
227
  functions_source: Dict[str, Union[str, List[str]]] = {}
228
+ functions_source_list: List[Union[str, List[str]]] = []
167
229
  line_numbers: Dict[str, Tuple[int, int]] = {}
230
+ line_numbers_list: List[Tuple[int, int]] = []
231
+ if is_class:
232
+ look_for_type = ast.ClassDef
233
+ else:
234
+ look_for_type = ast.FunctionDef
235
+
168
236
  for node in tree.body:
169
- if isinstance(node, ast.FunctionDef) and (node.name in function_names or len(function_names) == 0):
237
+ if isinstance(node, look_for_type) and (node.name in function_names or len(function_names) == 0):
170
238
  # Get the line numbers of the function
171
239
  lines = source.splitlines()
172
240
  func_lines = lines[node.lineno - 1:node.end_lineno]
173
241
  if not include_signature:
174
242
  func_lines = func_lines[1:]
175
- line_numbers[node.name] = (node.lineno, node.end_lineno)
176
- functions_source[node.name] = dedent("\n".join(func_lines)) if join_lines else func_lines
177
- if (len(functions_source) >= len(function_names)) and (not len(function_names) == 0):
178
- break
179
- if len(functions_source) < len(function_names):
180
- raise ValueError(f"Could not find all functions in {file_path}: {function_names} not found,"
181
- f"functions not found: {set(function_names) - set(functions_source.keys())}")
243
+ if as_list:
244
+ line_numbers_list.append((node.lineno, node.end_lineno))
245
+ else:
246
+ line_numbers[node.name] = (node.lineno, node.end_lineno)
247
+ parsed_function = dedent("\n".join(func_lines)) if join_lines else func_lines
248
+ if as_list:
249
+ functions_source_list.append(parsed_function)
250
+ else:
251
+ functions_source[node.name] = parsed_function
252
+ if len(function_names) > 0:
253
+ if len(functions_source) >= len(function_names) or len(functions_source_list) >= len(function_names):
254
+ break
255
+ if len(functions_source) < len(function_names) and len(functions_source_list) < len(function_names):
256
+ logger.warning(f"Could not find all functions in {file_path}: {function_names} not found, "
257
+ f"functions not found: {set(function_names) - set(functions_source.keys())}")
182
258
  if return_line_numbers:
183
- return functions_source, line_numbers
184
- return functions_source
259
+ return functions_source if not as_list else functions_source_list, line_numbers if not as_list else line_numbers_list
260
+ return functions_source if not as_list else functions_source_list
185
261
 
186
262
 
187
263
  def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optional[str] = None) -> str:
@@ -193,7 +269,8 @@ def encapsulate_user_input(user_input: str, func_signature: str, func_doc: Optio
193
269
  :param func_doc: The function docstring to use for encapsulation.
194
270
  :return: The encapsulated user input string.
195
271
  """
196
- if func_signature not in user_input:
272
+ func_name = func_signature.split('(')[0].strip()
273
+ if func_name not in user_input:
197
274
  new_user_input = func_signature + "\n "
198
275
  if func_doc is not None:
199
276
  new_user_input += f"\"\"\"{func_doc}\"\"\"" + "\n "
@@ -282,7 +359,7 @@ def update_case(case_query: CaseQuery, conclusions: Dict[str, Any]):
282
359
  case_query.case.update(conclusions)
283
360
 
284
361
 
285
- def is_conflicting(conclusion: Any, target: Any) -> bool:
362
+ def is_value_conflicting(conclusion: Any, target: Any) -> bool:
286
363
  """
287
364
  :param conclusion: The conclusion to check.
288
365
  :param target: The target to compare the conclusion with.
@@ -621,18 +698,21 @@ def capture_variable_assignment(code: str, variable_name: str) -> Optional[str]:
621
698
  return assignment
622
699
 
623
700
 
624
def get_func_rdr_model_path(func: Callable, model_dir: str, include_file_name: bool = False) -> str:
    """
    Build the path of the generated RDR python file for ``func``.

    The file is placed in a sub-directory of ``model_dir`` named after the function's model name
    (optionally including the file name). The file itself is named ``<model name>_rdr.py``.

    :param func: The function to get the model path for.
    :param model_dir: The directory to save the model to.
    :param include_file_name: Whether to include the file name in the model (sub-directory) name.
    :return: The path to the model file.
    """
    sub_directory = get_func_rdr_model_name(func, include_file_name=include_file_name)
    model_file = f"{get_func_rdr_model_name(func)}_rdr.py"
    return os.path.join(model_dir, sub_directory, model_file)
631
710
 
632
711
 
633
712
  def get_func_rdr_model_name(func: Callable, include_file_name: bool = False) -> str:
634
713
  """
635
714
  :param func: The function to get the model name for.
715
+ :param include_file_name: Whether to include the file name in the model name.
636
716
  :return: The name of the model.
637
717
  """
638
718
  func_name = get_method_name(func)
@@ -644,7 +724,7 @@ def get_func_rdr_model_name(func: Callable, include_file_name: bool = False) ->
644
724
  model_name = ''
645
725
  model_name += f"{func_class_name}_" if func_class_name else ""
646
726
  model_name += f"{func_name}"
647
- return model_name
727
+ return str_to_snake_case(model_name)
648
728
 
649
729
 
650
730
  def stringify_hint(tp):
@@ -681,6 +761,7 @@ def is_builtin_type(tp):
681
761
def is_typing_type(tp):
    """
    Check whether ``tp`` is defined in the standard ``typing`` module.

    :param tp: Any object exposing a ``__module__`` attribute (e.g. a class or a typing construct).
    :return: True if ``tp.__module__`` is exactly ``"typing"``, False otherwise.
    """
    module_name = tp.__module__
    return module_name == "typing"
683
763
 
764
+
684
765
  origin_type_to_hint = {
685
766
  list: List,
686
767
  set: Set,
@@ -688,6 +769,47 @@ origin_type_to_hint = {
688
769
  tuple: Tuple,
689
770
  }
690
771
 
772
+
773
def get_file_that_ends_with(directory_path: str, suffix: str) -> Optional[str]:
    """
    Get the name of an entry in ``directory_path`` whose name ends with ``suffix``.

    Only the entry *name* (as listed by :func:`os.listdir`) is returned, not a path joined with
    ``directory_path``. When several entries match, the alphabetically first one is returned. This makes
    the result independent of the arbitrary order of :func:`os.listdir`.

    :param directory_path: The directory to search (non-recursively).
    :param suffix: The suffix to search for.
    :return: The matching entry name, or None if no entry matches.
    :raises OSError: Errors from :func:`os.listdir` (e.g. a missing directory) propagate.
    """
    return min((name for name in os.listdir(directory_path) if name.endswith(suffix)), default=None)
785
+
786
def get_function_return_type(func: Callable) -> Union[Type, None, Tuple[Type, ...]]:
    """
    Get the return type of a function from its return annotation.

    If the annotation is a string (e.g. the function's module uses ``from __future__ import
    annotations``), it is first resolved with ``get_type_hints``. If resolution fails, the raw annotation
    is used unchanged.

    :param func: The function to get the return type for.
    :return: None if the function has no return annotation. Otherwise, the result of
        ``get_type_from_type_hint`` for the (resolved) annotation.
    """
    annotation = inspect.signature(func).return_annotation
    if annotation is inspect.Signature.empty:
        return None
    if isinstance(annotation, str):
        # get_origin/get_args cannot analyse string annotations, so try to evaluate them first.
        try:
            annotation = get_type_hints(func).get("return", annotation)
        except (NameError, TypeError, AttributeError, SyntaxError):
            # Unresolvable forward reference: keep the raw annotation, as before.
            pass
    return get_type_from_type_hint(annotation)
798
+
799
+
800
+ def get_type_from_type_hint(type_hint: Type) -> Union[Type, Tuple[Type, ...]]:
801
+ origin = get_origin(type_hint)
802
+ args = get_args(type_hint)
803
+ if origin not in [list, set, None, Union]:
804
+ raise TypeError(f"{origin} is not a handled return type for type hint {type_hint}")
805
+ if origin is None:
806
+ return typing_to_python_type(type_hint)
807
+ if args is None or len(args) == 0:
808
+ return typing_to_python_type(type_hint)
809
+ return args
810
+
811
+
812
+
691
813
  def extract_types(tp, seen: Set = None) -> Set[type]:
692
814
  """Recursively extract all base types from a type hint."""
693
815
  if seen is None:
@@ -786,6 +908,13 @@ def get_import_path_from_path(path: str) -> Optional[str]:
786
908
  return package_name
787
909
 
788
910
 
911
def get_class_file_path(cls):
    """
    Get the absolute path of the source file in which a class is defined.

    :param cls: The class to locate.
    :return: The file reported by ``inspect.getfile``, made absolute with ``os.path.abspath``.
    :raises TypeError: If ``cls`` is a built-in class without a source file (raised by ``inspect.getfile``).
    """
    defining_file = inspect.getfile(cls)
    return os.path.abspath(defining_file)
916
+
917
+
789
918
  def get_function_import_data(func: Callable) -> Tuple[str, str]:
790
919
  """
791
920
  Get the import path of a function.
@@ -840,10 +969,12 @@ def get_relative_import(target_file_path, imported_module_path: Optional[str] =
840
969
  imported_file_name = Path(imported_module_path).name
841
970
  target_file_name = Path(target_file_path).name
842
971
  if package_name is not None:
843
- target_path = Path(get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
972
+ target_path = Path(
973
+ get_path_starting_from_latest_encounter_of(str(target_path), package_name, [target_file_name]))
844
974
  imported_path = Path(imported_module_path).resolve()
845
975
  if package_name is not None:
846
- imported_path = Path(get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
976
+ imported_path = Path(
977
+ get_path_starting_from_latest_encounter_of(str(imported_path), package_name, [imported_file_name]))
847
978
 
848
979
  # Compute relative path from target to imported module
849
980
  rel_path = os.path.relpath(imported_path.parent, target_path.parent)
@@ -920,8 +1051,8 @@ def get_imports_from_types(type_objs: Iterable[Type],
920
1051
  continue
921
1052
  if name == "NoneType":
922
1053
  module = "types"
923
- if module is None or module == 'builtins' or module.startswith('_')\
924
- or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
1054
+ if module is None or module == 'builtins' or module.startswith('_') \
1055
+ or module in sys.builtin_module_names or module in excluded_modules or "<" in module \
925
1056
  or name in exclueded_names:
926
1057
  continue
927
1058
  if module == "typing":
@@ -989,7 +1120,8 @@ def get_method_class_name_if_exists(method: Callable) -> Optional[str]:
989
1120
  return method.__self__.__name__
990
1121
  elif hasattr(method.__self__, "__class__"):
991
1122
  return method.__self__.__class__.__name__
992
- return method.__qualname__.split('.')[0] if hasattr(method, "__qualname__") else None
1123
+ return (method.__qualname__.split('.')[0]
1124
+ if hasattr(method, "__qualname__") and '.' in method.__qualname__ else None)
993
1125
 
994
1126
 
995
1127
  def get_method_class_if_exists(method: Callable, *args) -> Optional[Type]:
@@ -1087,6 +1219,7 @@ def get_full_class_name(cls):
1087
1219
  def recursive_subclasses(cls):
1088
1220
  """
1089
1221
  Copied from: https://github.com/tomsch420/random-events/blob/master/src/random_events/utils.py#L6C1-L21C101
1222
+
1090
1223
  :param cls: The class.
1091
1224
  :return: A list of the classes subclasses.
1092
1225
  """
@@ -1211,7 +1344,8 @@ class SubclassJSONSerializer:
1211
1344
  return cls._from_json(data)
1212
1345
  for subclass in recursive_subclasses(SubclassJSONSerializer):
1213
1346
  if get_full_class_name(subclass) == data["_type"]:
1214
- subclass_data = deepcopy(data)
1347
+ # subclass_data = deepcopy(data)
1348
+ subclass_data = data
1215
1349
  subclass_data.pop("_type")
1216
1350
  return subclass._from_json(subclass_data)
1217
1351
 
@@ -1273,7 +1407,7 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
1273
1407
  :return: The copied instance.
1274
1408
  """
1275
1409
  try:
1276
- session: Session = inspect(instance).session
1410
+ session: Session = sql_inspect(instance).session
1277
1411
  except NoInspectionAvailable:
1278
1412
  session = None
1279
1413
  if session is not None:
@@ -1284,7 +1418,7 @@ def copy_orm_instance(instance: SQLTable) -> SQLTable:
1284
1418
  try:
1285
1419
  new_instance = deepcopy(instance)
1286
1420
  except Exception as e:
1287
- logging.debug(e)
1421
+ logger.debug(e)
1288
1422
  new_instance = instance
1289
1423
  return new_instance
1290
1424
 
@@ -1304,7 +1438,7 @@ def copy_orm_instance_with_relationships(instance: SQLTable) -> SQLTable:
1304
1438
  try:
1305
1439
  setattr(instance_cp, rel.key, related_obj_cp)
1306
1440
  except Exception as e:
1307
- logging.debug(e)
1441
+ logger.debug(e)
1308
1442
  return instance_cp
1309
1443
 
1310
1444
 
@@ -1400,9 +1534,14 @@ def table_rows_as_str(row_dicts: List[Dict[str, Any]], columns_per_row: int = 20
1400
1534
  keys_values = [list(r[0]) + list(r[1]) if len(r) > 1 else r[0] for r in keys_values]
1401
1535
  all_table_rows = []
1402
1536
  row_values = [list(map(lambda v: str(v) if v is not None else "", row)) for row in keys_values]
1403
- row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in row_values]
1537
+ row_values = [list(map(lambda v: v[:max_line_sze] + '...' if len(v) > max_line_sze else v, row)) for row in
1538
+ row_values]
1404
1539
  row_values = [list(map(lambda v: v.lower() if v in ["True", "False"] else v, row)) for row in row_values]
1405
- table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=[max_line_sze] * 2)
1540
+ # Step 1: Get terminal size
1541
+ terminal_width = shutil.get_terminal_size((80, 20)).columns
1542
+ # Step 2: Dynamically calculate max width per column (simple approximation)
1543
+ max_col_width = terminal_width // len(row_values[0])
1544
+ table = tabulate(row_values, tablefmt='simple_grid', maxcolwidths=max_col_width) # [max_line_sze] * 2)
1406
1545
  all_table_rows.append(table)
1407
1546
  return "\n".join(all_table_rows)
1408
1547
 
@@ -1551,6 +1690,13 @@ def get_all_subclasses(cls: Type) -> Dict[str, Type]:
1551
1690
  return all_subclasses
1552
1691
 
1553
1692
 
1693
def make_tuple(value: Any) -> Any:
    """
    Wrap ``value`` in a tuple unless it is already iterable.

    :param value: The value to convert.
    :return: ``tuple(value)`` when ``is_iterable(value)`` is true, otherwise the one-element tuple
        ``(value,)``.
    """
    if not is_iterable(value):
        return (value,)
    return tuple(value)
1698
+
1699
+
1554
1700
  def make_set(value: Any) -> Set:
1555
1701
  """
1556
1702
  Make a set from a value.
@@ -1622,34 +1768,358 @@ def edge_attr_setter(parent, child):
1622
1768
  """
1623
1769
  Set the edge attributes for the dot exporter.
1624
1770
  """
1625
- if child and hasattr(child, "weight") and child.weight:
1771
+ if child and hasattr(child, "weight") and child.weight is not None:
1626
1772
  return f'style="bold", label=" {child.weight}"'
1627
1773
  return ""
1628
1774
 
1629
1775
 
1776
_RE_ESC = re.compile(r'["\\]')


class FilteredDotExporter(object):
    """
    DOT-language exporter for anytree trees.

    It can restrict the output to a subset of nodes, and it appends a colour-legend subgraph to the
    generated graph. Adapted from ``anytree.exporter.DotExporter``; the node/edge callbacks have the same
    meaning as there.
    """

    def __init__(self, node, include_nodes=None, graph="digraph", name="tree", options=None,
                 indent=4, nodenamefunc=None, nodeattrfunc=None,
                 edgeattrfunc=None, edgetypefunc=None, maxlevel=None):
        """
        :param node: The start (root) node.
        :param include_nodes: Optional collection of nodes to export. When given and non-empty, only these
            nodes are emitted, along with edges whose both ends are among them. None or empty means no
            filtering.
        :param graph: DOT graph type (e.g. ``"digraph"`` or ``"graph"``).
        :param name: DOT graph name; also used as the file name by :meth:`to_source`.
        :param options: List of option lines added to the graph.
        :param indent: Number of spaces used to indent option, node and edge lines.
        :param nodenamefunc: ``f(node) -> str`` giving the node name. Defaults to ``node.name``.
        :param nodeattrfunc: ``f(node) -> str | None`` giving the node attributes.
        :param edgeattrfunc: ``f(parent, child) -> str | None`` giving the edge attributes.
        :param edgetypefunc: ``f(parent, child) -> str`` giving the edge operator. Defaults to ``"->"``.
        :param maxlevel: Limit the export to this number of levels.

        Example (exporting only part of a tree)::

            root = Node("root"); a = Node("a", parent=root); b = Node("b", parent=a)
            for line in FilteredDotExporter(root, include_nodes=[root, a]):
                print(line)  # nodes "root" and "a", the edge "root" -> "a", then the legend
        """
        self.node = node
        self.graph = graph
        self.name = name
        self.options = options
        self.indent = indent
        self.nodenamefunc = nodenamefunc
        self.nodeattrfunc = nodeattrfunc
        self.edgeattrfunc = edgeattrfunc
        self.edgetypefunc = edgetypefunc
        self.maxlevel = maxlevel
        self.include_nodes = include_nodes
        # Kept for backward compatibility. Filtering itself uses the effective nodenamefunc (see __iter__).
        node_name_func = get_unique_node_names_func(node)
        self.include_node_names = [node_name_func(n) for n in self.include_nodes] if include_nodes else None

    def __iter__(self):
        """Yield the lines of the DOT source, one statement per item."""
        indent = " " * self.indent
        nodenamefunc = self.nodenamefunc or self._default_nodenamefunc
        nodeattrfunc = self.nodeattrfunc or self._default_nodeattrfunc
        edgeattrfunc = self.edgeattrfunc or self._default_edgeattrfunc
        edgetypefunc = self.edgetypefunc or self._default_edgetypefunc
        # Compute the allowed names with the same function used for the output, so filtering stays
        # consistent with a custom nodenamefunc. A set gives O(1) membership tests. An empty include_nodes
        # means "no filtering" instead of failing on `in None`.
        included_names = {nodenamefunc(n) for n in self.include_nodes} if self.include_nodes else None
        return self.__iter(indent, nodenamefunc, nodeattrfunc, edgeattrfunc, edgetypefunc, included_names)

    @staticmethod
    def _default_nodenamefunc(node):
        return node.name

    @staticmethod
    def _default_nodeattrfunc(node):
        return None

    @staticmethod
    def _default_edgeattrfunc(node, child):
        return None

    @staticmethod
    def _default_edgetypefunc(node, child):
        return "->"

    def __iter(self, indent, nodenamefunc, nodeattrfunc, edgeattrfunc, edgetypefunc, included_names=None):
        yield "{self.graph} {self.name} {{".format(self=self)
        yield from self.__iter_options(indent)
        yield from self.__iter_nodes(indent, nodenamefunc, nodeattrfunc, included_names)
        yield from self.__iter_edges(indent, nodenamefunc, edgeattrfunc, edgetypefunc, included_names)
        # Every legend line is emitted stripped, so the indentation inside this literal is irrelevant.
        legend_dot_graph = """
        // Color legend as a subgraph
        subgraph cluster_legend {
            label = "Legend";
            style = dashed;
            color = gray;

            legend_green [label="Fired->Query Related Value", shape=box, style=filled, fillcolor=green, fontcolor=black, size=0.5];
            legend_yellow [label="Fired->Some Value", shape=box, style=filled, fillcolor=yellow, fontcolor=black, size=0.5];
            legend_orange [label="Fired->Empty Value", shape=box, style=filled, fillcolor=orange, fontcolor=black, size=0.5];
            legend_red [label="Evaluated->Not Fired", shape=box, style=filled, fillcolor=red, fontcolor=black, size=0.5];
            legend_white [label="Not Evaluated", shape=box, style=filled, fillcolor=white, fontcolor=black, size=0.5];

            // Invisible edges to arrange legend vertically
            legend_white -> legend_red -> legend_orange -> legend_yellow -> legend_green [style=invis];
        }"""
        for line in legend_dot_graph.splitlines():
            yield line.strip()
        yield "}"

    def __iter_options(self, indent):
        options = self.options
        if options:
            for option in options:
                yield "%s%s" % (indent, option)

    def __iter_nodes(self, indent, nodenamefunc, nodeattrfunc, included_names=None):
        for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
            nodename = nodenamefunc(node)
            if included_names is not None and nodename not in included_names:
                continue
            nodeattr = nodeattrfunc(node)
            nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
            yield '%s"%s"%s;' % (indent, FilteredDotExporter.esc(nodename), nodeattr)

    def __iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, included_names=None):
        # Edges start at parents, so parents are visited one level less deep than nodes.
        maxlevel = self.maxlevel - 1 if self.maxlevel else None
        for node in PreOrderIter(self.node, maxlevel=maxlevel):
            nodename = nodenamefunc(node)
            if included_names is not None and nodename not in included_names:
                continue
            for child in node.children:
                childname = nodenamefunc(child)
                if included_names is not None and childname not in included_names:
                    continue
                edgeattr = edgeattrfunc(node, child)
                edgetype = edgetypefunc(node, child)
                edgeattr = " [%s]" % edgeattr if edgeattr is not None else ""
                yield '%s"%s" %s "%s"%s;' % (indent, FilteredDotExporter.esc(nodename), edgetype,
                                             FilteredDotExporter.esc(childname), edgeattr)

    def to_dotfile(self, filename):
        """
        Write the DOT source of the graph to ``filename`` (UTF-8, one statement per line).

        The file can be rendered with the graphviz ``dot`` tool, e.g. ``dot tree.dot -T png -o tree.png``.
        """
        # newline="" writes "\n" untranslated, matching the bytes the previous codecs.open() produced.
        with open(filename, "w", encoding="utf-8", newline="") as file:
            file.writelines("%s\n" % line for line in self)

    def to_picture(self, filename):
        """
        Write the graph to a temporary DOT file and invoke the graphviz ``dot`` executable on it.

        The output format is taken from the suffix of ``filename`` (e.g. ``.svg`` or ``.png``).
        *graphviz needs to be installed for this method to work.*

        :param filename: Path of the picture to create.
        :raises FileNotFoundError: If the ``dot`` executable cannot be found.
        :raises subprocess.CalledProcessError: If ``dot`` exits with a non-zero status.
        """
        fileformat = os.path.splitext(filename)[1][1:]
        # Close the temporary file before running `dot`: an open NamedTemporaryFile cannot be reopened by
        # another process on Windows.
        with NamedTemporaryFile("wb", delete=False) as dotfile:
            dotfilename = dotfile.name
            for line in self:
                dotfile.write(("%s\n" % line).encode("utf-8"))
        try:
            check_call(["dot", dotfilename, "-T", fileformat, "-o", filename])
        finally:
            # Always clean up, even when `dot` fails.
            try:
                os.remove(dotfilename)
            except OSError:  # pragma: no cover
                msg = 'Could not remove temporary file %s' % dotfilename
                logger.warning(msg)

    def to_source(self) -> Source:
        """
        Return the DOT source of the graph as a graphviz ``Source`` object named after the graph.
        """
        return Source("\n".join(self), filename=self.name)

    @staticmethod
    def esc(value):
        """Escape double quotes and backslashes so ``value`` can be embedded in a quoted DOT identifier."""
        return _RE_ESC.sub(lambda m: r"\%s" % m.group(0), str(value))
2079
+
2080
+
1630
2081
def render_tree(root: Node, use_dot_exporter: bool = False,
                filename: str = "scrdr", only_nodes: Optional[List[Node]] = None, show_in_console: bool = False,
                color_map: Optional[Callable[[Node], str]] = None,
                view: bool = False) -> None:
    """
    Render the tree in the console and/or export it with graphviz.

    :param root: The root node of the tree. When falsy, a warning is logged and nothing is rendered.
    :param use_dot_exporter: Whether to export the tree with :class:`FilteredDotExporter`.
    :param filename: Base name (without extension) of the exported ``.dot`` / ``.svg`` files. Falls back to
        ``"rule_tree"`` when empty.
    :param only_nodes: If given, only these nodes are printed and exported.
    :param show_in_console: Whether to print the tree to the console.
    :param color_map: Function returning a fill colour for a node. Without it, the node's ``color``
        attribute is used when present. Nodes with no colour (or a None colour) are drawn unfilled.
    :param view: If True, open the rendered graph in a viewer instead of writing
        ``<filename>.dot`` / ``<filename>.svg``.
    """
    if not root:
        logger.warning("No rules to render")
        return
    if show_in_console:
        for pre, _, node in RenderTree(root):
            if only_nodes is not None and node not in only_nodes:
                continue
            print(f"{pre}{node.weight if hasattr(node, 'weight') and node.weight else ''} {node.__str__()}")
    if not use_dot_exporter:
        return

    def node_attributes(node: Node) -> Optional[str]:
        # Nodes lacking a `color` attribute no longer raise AttributeError when no color_map is given.
        fill_color = color_map(node) if color_map else getattr(node, "color", None)
        if fill_color is None:
            return None
        return f'style=filled, fillcolor={fill_color}'

    unique_node_names = get_unique_node_names_func(root)
    de = FilteredDotExporter(root,
                             include_nodes=only_nodes,
                             nodenamefunc=unique_node_names,
                             edgeattrfunc=edge_attr_setter,
                             nodeattrfunc=node_attributes,
                             )
    if view:
        de.to_source().view()
        return
    filename = filename or "rule_tree"
    de.to_dotfile(f"{filename}.dot")
    try:
        de.to_picture(f"{filename}.svg")
    except FileNotFoundError as e:
        # e.g. the graphviz `dot` executable is not installed; the .dot file has still been written.
        logger.warning(f"{e}")
1653
2123
 
1654
2124
 
1655
2125
  def draw_tree(root: Node, fig: Figure):
@@ -1698,3 +2168,17 @@ def encapsulate_code_lines_into_a_function(code_lines: List[str], function_name:
1698
2168
  if f"return {function_name}({args})" not in code:
1699
2169
  code = code.strip() + f"\nreturn {function_name}({args})"
1700
2170
  return code
2171
+
2172
+
2173
def get_method_object_from_pytest_request(request) -> Callable:
    """
    Get the test function object that a pytest ``request`` fixture belongs to.

    The function is looked up by name on the test module, or on the test class when the test is defined
    inside one. For parametrized tests pytest appends the parameter id to the node name
    (``test_x[param]``), so the undecorated name is used for the lookup instead.

    :param request: The pytest ``request`` fixture.
    :return: The test function (accessed through its class for class-based tests).
    """
    module = importlib.import_module(request.module.__name__)
    # `originalname` is the name without the parametrization suffix; fall back to stripping it by hand.
    test_name = getattr(request.node, "originalname", None) or request.node.name.split("[", 1)[0]
    owner = getattr(module, request.cls.__name__) if request.cls else module
    return getattr(owner, test_name)
2183
+
2184
+