algorhino-anemone 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
  2. algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
  3. algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
  4. algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
  5. algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
  6. anemone/__init__.py +27 -0
  7. anemone/basics.py +36 -0
  8. anemone/factory.py +161 -0
  9. anemone/indices/__init__.py +0 -0
  10. anemone/indices/index_manager/__init__.py +12 -0
  11. anemone/indices/index_manager/factory.py +50 -0
  12. anemone/indices/index_manager/node_exploration_manager.py +549 -0
  13. anemone/indices/node_indices/__init__.py +22 -0
  14. anemone/indices/node_indices/factory.py +121 -0
  15. anemone/indices/node_indices/index_data.py +166 -0
  16. anemone/indices/node_indices/index_types.py +20 -0
  17. anemone/nn/torch_evaluator.py +108 -0
  18. anemone/node_evaluation/__init__.py +0 -0
  19. anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
  20. anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
  21. anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
  22. anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
  23. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
  24. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
  25. anemone/node_factory/__init__.py +14 -0
  26. anemone/node_factory/algorithm_node_factory.py +123 -0
  27. anemone/node_factory/base.py +76 -0
  28. anemone/node_selector/__init__.py +32 -0
  29. anemone/node_selector/branch_explorer.py +89 -0
  30. anemone/node_selector/factory.py +65 -0
  31. anemone/node_selector/node_selector.py +44 -0
  32. anemone/node_selector/node_selector_args.py +22 -0
  33. anemone/node_selector/node_selector_types.py +15 -0
  34. anemone/node_selector/notations_and_statics.py +88 -0
  35. anemone/node_selector/opening_instructions.py +249 -0
  36. anemone/node_selector/recurzipf/__init__.py +0 -0
  37. anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
  38. anemone/node_selector/sequool/__init__.py +19 -0
  39. anemone/node_selector/sequool/factory.py +102 -0
  40. anemone/node_selector/sequool/sequool.py +395 -0
  41. anemone/node_selector/uniform/__init__.py +16 -0
  42. anemone/node_selector/uniform/uniform.py +113 -0
  43. anemone/nodes/__init__.py +15 -0
  44. anemone/nodes/algorithm_node/__init__.py +7 -0
  45. anemone/nodes/algorithm_node/algorithm_node.py +204 -0
  46. anemone/nodes/itree_node.py +136 -0
  47. anemone/nodes/tree_node.py +240 -0
  48. anemone/nodes/tree_traversal.py +108 -0
  49. anemone/nodes/utils.py +146 -0
  50. anemone/progress_monitor/__init__.py +0 -0
  51. anemone/progress_monitor/progress_monitor.py +375 -0
  52. anemone/recommender_rule/__init__.py +12 -0
  53. anemone/recommender_rule/recommender_rule.py +140 -0
  54. anemone/search_factory/__init__.py +14 -0
  55. anemone/search_factory/search_factory.py +192 -0
  56. anemone/state_transition.py +47 -0
  57. anemone/tree_and_value_branch_selector.py +99 -0
  58. anemone/tree_exploration.py +274 -0
  59. anemone/tree_manager/__init__.py +29 -0
  60. anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
  61. anemone/tree_manager/factory.py +77 -0
  62. anemone/tree_manager/tree_expander.py +122 -0
  63. anemone/tree_manager/tree_manager.py +254 -0
  64. anemone/trees/__init__.py +14 -0
  65. anemone/trees/descendants.py +765 -0
  66. anemone/trees/factory.py +80 -0
  67. anemone/trees/tree.py +70 -0
  68. anemone/trees/tree_visualization.py +143 -0
  69. anemone/updates/__init__.py +33 -0
  70. anemone/updates/algorithm_node_updater.py +157 -0
  71. anemone/updates/factory.py +36 -0
  72. anemone/updates/index_block.py +91 -0
  73. anemone/updates/index_updater.py +100 -0
  74. anemone/updates/minmax_evaluation_updater.py +108 -0
  75. anemone/updates/updates_file.py +248 -0
  76. anemone/updates/value_block.py +133 -0
  77. anemone/utils/comparable.py +32 -0
  78. anemone/utils/dataclass.py +64 -0
  79. anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
  80. anemone/utils/logger.py +94 -0
  81. anemone/utils/my_value_sorted_dict.py +27 -0
  82. anemone/utils/small_tools.py +103 -0
@@ -0,0 +1,204 @@
1
+ """
2
+ This module defines the AlgorithmNode class, which is a generic node used by the tree and value algorithm.
3
+ It wraps tree nodes with values, minimax computation, and exploration tools.
4
+ """
5
+
6
+ from typing import MutableMapping, Self
7
+
8
+ from valanga import (
9
+ BranchKey,
10
+ BranchKeyGeneratorP,
11
+ ContentRepresentation,
12
+ State,
13
+ StateTag,
14
+ )
15
+
16
+ from anemone.indices.node_indices import NodeExplorationData
17
+ from anemone.node_evaluation.node_tree_evaluation.node_tree_evaluation import (
18
+ NodeTreeEvaluation,
19
+ )
20
+ from anemone.nodes.tree_node import TreeNode
21
+
22
+
23
+ class AlgorithmNode[StateT: State = State]:
24
+ """
25
+ The generic Node used by the tree and value algorithm.
26
+ It wraps tree nodes with values, minimax computation and exploration tools
27
+ """
28
+
29
+ tree_node: TreeNode[Self, StateT]
30
+ # the reference to the tree node that is wrapped pointing to other AlgorithmNodes
31
+ tree_evaluation: NodeTreeEvaluation[StateT] # Use Any to break circular dependency
32
+ exploration_index_data: (
33
+ NodeExplorationData[Self, StateT] | None
34
+ ) # the object storing the information to help the algorithm decide the next nodes to explore
35
+ _state_representation: (
36
+ ContentRepresentation | None
37
+ ) # the state representation for evaluation
38
+
39
+ @property
40
+ def state_representation(self) -> ContentRepresentation | None:
41
+ """
42
+ Returns the state representation.
43
+ """
44
+ return self._state_representation
45
+
46
+ def __init__(
47
+ self,
48
+ tree_node: TreeNode[Self, StateT],
49
+ tree_evaluation: NodeTreeEvaluation[StateT],
50
+ exploration_index_data: NodeExplorationData[Self, StateT] | None,
51
+ state_representation: ContentRepresentation | None,
52
+ ) -> None:
53
+ """
54
+ Initializes an AlgorithmNode object.
55
+
56
+ Args:
57
+ tree_node (TreeNode): The tree node that is wrapped.
58
+ tree_evaluation (NodeTreeEvaluation): The object computing the value.
59
+ exploration_index_data (NodeExplorationData | None): The object storing the information to help the algorithm decide the next nodes to explore.
60
+ state_representation (StateRepresentation | None): The board representation.
61
+ """
62
+ self.tree_node = tree_node
63
+ self.tree_evaluation = tree_evaluation
64
+ self.exploration_index_data = exploration_index_data
65
+ self._state_representation = state_representation
66
+
67
+ @property
68
+ def id(self) -> int:
69
+ """
70
+ Returns the ID of the node.
71
+
72
+ Returns:
73
+ int: The ID of the node.
74
+ """
75
+ return self.tree_node.id
76
+
77
+ @property
78
+ def tree_depth(self) -> int:
79
+ """
80
+ Returns the tree depth.
81
+
82
+ Returns:
83
+ int: The tree depth.
84
+ """
85
+ return self.tree_node.tree_depth_
86
+
87
+ @property
88
+ def tag(self) -> StateTag:
89
+ """
90
+ Returns the fast representation of the node.
91
+
92
+ Returns:
93
+ str: The fast representation of the node.
94
+ """
95
+ return self.tree_node.tag
96
+
97
+ @property
98
+ def branches_children(self) -> MutableMapping[BranchKey, Self | None]:
99
+ """
100
+ Returns the bidirectional dictionary of moves and their corresponding child nodes.
101
+
102
+ Returns:
103
+ dict[IMove, ITreeNode | None]: The bidirectional dictionary of moves and their corresponding child nodes.
104
+ """
105
+ return self.tree_node.branches_children
106
+
107
+ @property
108
+ def parent_nodes(self) -> dict[Self, BranchKey]:
109
+ """
110
+ Returns the dictionary of parent nodes of the current tree node with associated move.
111
+
112
+ :return: A dictionary of parent nodes of the current tree node with associated move.
113
+ """
114
+ return self.tree_node.parent_nodes
115
+
116
+ @property
117
+ def state(self) -> StateT:
118
+ """
119
+ Returns the state associated with this tree node.
120
+
121
+ Returns:
122
+ StateWithTag: The state associated with this tree node.
123
+ """
124
+ return self.tree_node.state
125
+
126
+ def is_over(self) -> bool:
127
+ """
128
+ Checks if the game is over.
129
+
130
+ Returns:
131
+ bool: True if the game is over, False otherwise.
132
+ """
133
+ return self.tree_evaluation.is_over()
134
+
135
+ def add_parent(self, branch_key: BranchKey, new_parent_node: Self) -> None:
136
+ """
137
+ Adds a parent node.
138
+
139
+ Args:
140
+ branch_key (BranchKey): The branch key associated with the move that led to the node from the new_parent_node.
141
+ new_parent_node (ITreeNode): The new parent node to add.
142
+ """
143
+ self.tree_node.add_parent(
144
+ branch_key=branch_key, new_parent_node=new_parent_node
145
+ )
146
+
147
+ @property
148
+ def all_branches_keys(self) -> BranchKeyGeneratorP:
149
+ """
150
+ Returns a generator that yields the branch keys for the current board state.
151
+
152
+ Returns:
153
+ BranchKeyGenerator: A generator that yields the branch keys.
154
+ """
155
+ return self.tree_node.state_.branch_keys
156
+
157
+ @property
158
+ def all_branches_generated(self) -> bool:
159
+ """
160
+ Returns True if all branches have been generated, False otherwise.
161
+
162
+ Returns:
163
+ bool: True if all branches have been generated, False otherwise.
164
+ """
165
+ return self.tree_node.all_branches_generated
166
+
167
+ @all_branches_generated.setter
168
+ def all_branches_generated(self, value: bool) -> None:
169
+ """
170
+ Sets the flag indicating if all branches have been generated.
171
+
172
+ Args:
173
+ value (bool): The value to set.
174
+ """
175
+ self.tree_node.all_branches_generated = value
176
+
177
+ @property
178
+ def non_opened_branches(self) -> set[BranchKey]:
179
+ """
180
+ Returns the set of non-opened branches.
181
+
182
+ Returns:
183
+ set[BranchKey]: The set of non-opened branches.
184
+ """
185
+ return self.tree_node.non_opened_branches
186
+
187
+ def dot_description(self) -> str:
188
+ """
189
+ Returns the dot description of the node.
190
+
191
+ Returns:
192
+ str: The dot description of the node.
193
+ """
194
+ exploration_description: str = (
195
+ self.exploration_index_data.dot_description()
196
+ if self.exploration_index_data is not None
197
+ else ""
198
+ )
199
+
200
+ return f"{self.tree_node.dot_description()}\n{self.tree_evaluation.dot_description()}\n{exploration_description}"
201
+
202
+ def __str__(self) -> str:
203
+ """Return a concise string representation of the node."""
204
+ return f"{self.__class__} id :{self.tree_node.id}"
@@ -0,0 +1,136 @@
1
+ """
2
+ This module defines the interface for a tree node in a chess move selector.
3
+
4
+ The `ITreeNode` protocol represents a node in a tree structure used for selecting chess moves.
5
+ It provides properties and methods for accessing information about the node, such as its ID,
6
+ the chess board state, the half move count, the child nodes, and the parent nodes.
7
+
8
+ The `ITreeNode` protocol also defines methods for adding a parent node, generating a dot description
9
+ for visualization, checking if all legal moves have been generated, accessing the legal moves,
10
+ and checking if the game is over.
11
+
12
+ Note: This is an interface and should not be instantiated directly.
13
+ """
14
+
15
+ from typing import MutableMapping, Protocol, Self
16
+
17
+ from valanga import BranchKey, BranchKeyGeneratorP, State, StateTag
18
+
19
+
20
+ class ITreeNode[StateT: State = State](Protocol):
21
+ """
22
+ The `ITreeNode` protocol represents a node in a tree structure used for selecting chess moves.
23
+ """
24
+
25
+ @property
26
+ def id(self) -> int:
27
+ """
28
+ Get the ID of the node.
29
+
30
+ Returns:
31
+ The ID of the node.
32
+ """
33
+ ...
34
+
35
+ @property
36
+ def state(self) -> StateT:
37
+ """
38
+ Get the chess board state of the node.
39
+
40
+ Returns:
41
+ The chess board state of the node.
42
+ """
43
+ ...
44
+
45
+ @property
46
+ def tree_depth(self) -> int:
47
+ """
48
+ Get the tree depth of the node.
49
+
50
+ Returns:
51
+ The tree depth of the node.
52
+ """
53
+ ...
54
+
55
+ @property
56
+ def branches_children(self) -> MutableMapping[BranchKey, Self | None]:
57
+ """
58
+ Get the child nodes of the node.
59
+
60
+ Returns:
61
+ A bidirectional dictionary mapping branches to child nodes.
62
+ """
63
+ ...
64
+
65
+ @property
66
+ def parent_nodes(self) -> dict[Self, BranchKey]:
67
+ """
68
+ Returns the dictionary of parent nodes of the current tree node with associated move.
69
+
70
+ :return: A dictionary of parent nodes of the current tree node with associated move.
71
+ """
72
+ ...
73
+
74
+ def add_parent(self, branch_key: BranchKey, new_parent_node: Self) -> None:
75
+ """
76
+ Add a parent node to the node.
77
+
78
+ Args:
79
+ new_parent_node: The parent node to add.
80
+ move (chess.Move): the move that led to the node from the new_parent_node
81
+
82
+ """
83
+
84
+ def dot_description(self) -> str:
85
+ """
86
+ Generate a dot description for visualization.
87
+
88
+ Returns:
89
+ A string containing the dot description.
90
+ """
91
+ ...
92
+
93
+ @property
94
+ def all_branches_generated(self) -> bool:
95
+ """
96
+ Check if all branches have been generated.
97
+
98
+ Returns:
99
+ True if all branches have been generated, False otherwise.
100
+ """
101
+ ...
102
+
103
+ @all_branches_generated.setter
104
+ def all_branches_generated(self, value: bool) -> None:
105
+ """
106
+ Set the flag indicating that all branches have been generated.
107
+ """
108
+
109
+ @property
110
+ def all_branches_keys(self) -> BranchKeyGeneratorP:
111
+ """
112
+ Get the legal moves of the node.
113
+
114
+ Returns:
115
+ A generator for iterating over the legal moves.
116
+ """
117
+ ...
118
+
119
+ @property
120
+ def tag(self) -> StateTag:
121
+ """
122
+ Get the fast representation of the node.
123
+
124
+ Returns:
125
+ The fast representation of the node as a string.
126
+ """
127
+ ...
128
+
129
+ def is_over(self) -> bool:
130
+ """
131
+ Check if the game is over.
132
+
133
+ Returns:
134
+ True if the game is over, False otherwise.
135
+ """
136
+ ...
@@ -0,0 +1,240 @@
1
+ """
2
+ This module defines the TreeNode class, which represents a node in a tree structure for a chess game.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ from valanga import BranchKey, BranchKeyGeneratorP, State, StateTag
9
+
10
+ from .itree_node import ITreeNode
11
+
12
+ # todo replace the any with a defaut value in ITReenode when availble in python; 3.13?
13
+
14
+
15
+ @dataclass(slots=True)
16
+ class TreeNode[
17
+ FamilyT: ITreeNode[Any] = ITreeNode[Any],
18
+ StateT: State = State,
19
+ ]:
20
+ r"""
21
+ The TreeNode class stores information about a specific board position, including the board representation,
22
+ the player to move, the half-move count, and the parent-child relationships with other nodes.
23
+
24
+ Attributes:
25
+ id\_ (int): The number to identify this node for easier debugging.
26
+ half_move\_ (int): The number of half-moves since the start of the game to reach the board position.
27
+ board\_ (boards.BoardChi): The board representation of the node.
28
+ parent_nodes\_ (set[ITreeNode]): The set of parent nodes to this node.
29
+ all_legal_moves_generated (bool): A boolean indicating whether all moves have been generated.
30
+ non_opened_legal_moves (set[chess.Move]): The set of non-opened legal moves.
31
+ moves_children\_ (dict[chess.Move, ITreeNode | None]): The dictionary mapping moves to child nodes.
32
+ fast_rep (str): The fast representation of the board.
33
+ player_to_move\_ (chess.Color): The color of the player that has to move in the board.
34
+
35
+ Methods:
36
+ __post_init__(): Initializes the TreeNode object after it has been created.
37
+ id(): Returns the id of the node.
38
+ player_to_move(): Returns the color of the player to move.
39
+ board(): Returns the board representation.
40
+ half_move(): Returns the number of half-moves.
41
+ moves_children(): Returns the dictionary mapping moves to child nodes.
42
+ parent_nodes(): Returns the set of parent nodes.
43
+ is_root_node(): Checks if the node is a root node.
44
+ legal_moves(): Returns the legal moves of the board.
45
+ add_parent(new_parent_node: ITreeNode): Adds a parent node to the current node.
46
+ is_over(): Checks if the game is over.
47
+ print_moves_children(): Prints the moves-children links of the node.
48
+ test(): Performs a test on the node.
49
+ dot_description(): Returns the dot description of the node.
50
+ test_all_legal_moves_generated(): Tests if all legal moves have been generated.
51
+ get_descendants(): Returns a dictionary of descendants of the node.
52
+ """
53
+
54
+ # id is a number to identify this node for easier debug
55
+ id_: int
56
+
57
+ # the tree depth of this node
58
+ tree_depth_: int
59
+
60
+ # the node holds a state.
61
+ state_: StateT
62
+
63
+ # the set of parent nodes to this node. Note that a node can have multiple parents!
64
+ parent_nodes_: dict[FamilyT, BranchKey]
65
+
66
+ # all_branches_generated is a boolean saying whether all branches have been generated.
67
+ # If true the branches are either opened in which case the corresponding opened node is stored in
68
+ # the dictionary self.branches_children, otherwise it is stored in self.non_opened_branches
69
+ all_branches_generated: bool = False
70
+
71
+ @staticmethod
72
+ def _empty_non_opened_branches() -> set[BranchKey]:
73
+ """Return a new empty set for non-opened branches."""
74
+ return set()
75
+
76
+ @staticmethod
77
+ def _empty_branches_children() -> dict[BranchKey, FamilyT | None]:
78
+ """Return a new empty mapping for branch children."""
79
+ return {}
80
+
81
+ non_opened_branches: set[BranchKey] = field(
82
+ default_factory=_empty_non_opened_branches
83
+ )
84
+
85
+ # dictionary mapping moves to children nodes. Node is set to None if not created
86
+ branches_children_: dict[BranchKey, FamilyT | None] = field(
87
+ default_factory=_empty_branches_children
88
+ )
89
+
90
+ @property
91
+ def tag(self) -> StateTag:
92
+ """Returns the fast representation of the board.
93
+
94
+ Returns:
95
+ boards.boardKey: The fast representation of the board.
96
+ """
97
+ return self.state_.tag
98
+
99
+ @property
100
+ def id(self) -> int:
101
+ """
102
+ Returns the ID of the tree node.
103
+
104
+ Returns:
105
+ int: The ID of the tree node.
106
+ """
107
+ return self.id_
108
+
109
+ @property
110
+ def state(self) -> StateT:
111
+ """
112
+ Returns the state associated with this tree node.
113
+
114
+ Returns:
115
+ State: The state associated with this tree node.
116
+ """
117
+ return self.state_
118
+
119
+ @property
120
+ def tree_depth(self) -> int:
121
+ """
122
+ Returns the tree depth of this node.
123
+
124
+ Returns:
125
+ int: The tree depth of this node.
126
+ """
127
+ return self.tree_depth_
128
+
129
+ @property
130
+ def branches_children(self) -> dict[BranchKey, FamilyT | None]:
131
+ """
132
+ Returns a bidirectional dictionary containing the children nodes of the current tree node,
133
+ along with the corresponding chess moves that lead to each child node.
134
+
135
+ Returns:
136
+ dict[BranchKey, ITreeNode | None]: A bidirectional dictionary mapping branches to
137
+ the corresponding child nodes. If a branch does not have a corresponding child node, it is
138
+ mapped to None.
139
+ """
140
+ return self.branches_children_
141
+
142
+ @property
143
+ def parent_nodes(self) -> dict[FamilyT, BranchKey]:
144
+ """
145
+ Returns the dictionary of parent nodes of the current tree node with associated move.
146
+
147
+ :return: A dictionary of parent nodes of the current tree node with associated move.
148
+ """
149
+ return self.parent_nodes_
150
+
151
+ def is_root_node(self) -> bool:
152
+ """
153
+ Check if the current node is a root node.
154
+
155
+ Returns:
156
+ bool: True if the node is a root node, False otherwise.
157
+ """
158
+ return not self.parent_nodes
159
+
160
+ @property
161
+ def all_branches_keys(self) -> BranchKeyGeneratorP:
162
+ """
163
+ Returns a generator that yields the branch keys for the current board state.
164
+
165
+ Returns:
166
+ BranchKeyGenerator: A generator that yields the branch keys.
167
+ """
168
+ return self.state_.branch_keys
169
+
170
+ def add_parent(self, branch_key: BranchKey, new_parent_node: FamilyT) -> None:
171
+ """
172
+ Adds a new parent node to the current node.
173
+
174
+ Args:
175
+ branch_key (BranchKey): The branch key associated with the move that led to the node from the new_parent_node.
176
+ new_parent_node (ITreeNode): The new parent node to be added.
177
+
178
+ Raises:
179
+ AssertionError: If the new parent node is already in the parent nodes set.
180
+
181
+ Returns:
182
+ None
183
+ """
184
+ # debug
185
+ assert (
186
+ new_parent_node not in self.parent_nodes
187
+ ) # there cannot be two ways to link the same child-parent
188
+ self.parent_nodes[new_parent_node] = branch_key
189
+
190
+ def is_over(self) -> bool:
191
+ """
192
+ Checks if the game is over.
193
+
194
+ Returns:
195
+ bool: True if the game is over, False otherwise.
196
+ """
197
+ return self.state.is_game_over()
198
+
199
+ def print_branches_children(self) -> None:
200
+ """
201
+ Prints the branches-children link of the node.
202
+
203
+ This method prints the branches-children link of the node, showing the branch and the ID of the child node.
204
+ If a child node is None, it will be displayed as 'None'.
205
+
206
+ Returns:
207
+ None
208
+ """
209
+ print(
210
+ "here are the ",
211
+ len(self.branches_children_),
212
+ " branches-children link of node",
213
+ self.id,
214
+ ": ",
215
+ end=" ",
216
+ )
217
+ for branch, child in self.branches_children_.items():
218
+ if child is None:
219
+ print(branch, child, end=" ")
220
+ else:
221
+ print(branch, child.id, end=" ")
222
+ print(" ")
223
+
224
+ def dot_description(self) -> str:
225
+ """
226
+ Returns a string representation of the node in the DOT format.
227
+
228
+ The string includes the node's ID, half move, and board FEN.
229
+
230
+ Returns:
231
+ A string representation of the node in the DOT format.
232
+ """
233
+ return (
234
+ "id:"
235
+ + str(self.id)
236
+ + " dep: "
237
+ + str(self.tree_depth)
238
+ + "\nfen:"
239
+ + str(self.state.tag)
240
+ )
@@ -0,0 +1,108 @@
1
+ """
2
+ This module provides functions for traversing a tree of nodes.
3
+
4
+ The functions in this module allow you to retrieve descendants of a given node in a tree structure.
5
+ """
6
+
7
+ from typing import Any
8
+
9
+ from .algorithm_node import AlgorithmNode
10
+ from .itree_node import ITreeNode
11
+
12
+
13
+ def get_descendants[NodeT: ITreeNode[Any]](from_tree_node: NodeT) -> dict[NodeT, None]:
14
+ """
15
+ Get all descendants of a given tree node.
16
+
17
+ Args:
18
+ from_tree_node (ITreeNode): The starting tree node.
19
+
20
+ Returns:
21
+ dict[ITreeNode, None]: A dictionary containing all descendants of the starting tree node.
22
+ """
23
+ des: dict[NodeT, None] = {from_tree_node: None} # include itself
24
+ generation: set[NodeT] = {
25
+ node for node in from_tree_node.branches_children.values() if node is not None
26
+ }
27
+
28
+ while generation:
29
+ next_depth_generation: set[NodeT] = set()
30
+ for node in generation:
31
+ assert node is not None
32
+ des[node] = None
33
+ for _, next_generation_child in node.branches_children.items():
34
+ if next_generation_child is not None:
35
+ next_depth_generation.add(next_generation_child)
36
+ generation = next_depth_generation
37
+ return des
38
+
39
+
40
+ def get_descendants_candidate_to_open[NodeT: AlgorithmNode[Any]](
41
+ from_tree_node: NodeT, max_depth: int | None = None
42
+ ) -> list[NodeT]:
43
+ """
44
+ Get descendants of a given tree node that are not over.
45
+
46
+ Args:
47
+ from_tree_node (AlgorithmNode): The starting tree node.
48
+ max_depth (int | None, optional): The maximum depth to traverse. Defaults to None.
49
+
50
+ Returns:
51
+ list[AlgorithmNode]: A list of descendants that are not over.
52
+ """
53
+ if not from_tree_node.all_branches_generated and not from_tree_node.is_over():
54
+ # should use are_all_moves_and_children_opened() but its messy!
55
+ # also using is_over is messy as over_events are defined in a child class!!!
56
+ des = {from_tree_node: None} # include itself maybe
57
+ else:
58
+ des = {}
59
+ generation: set[NodeT] = {
60
+ node for node in from_tree_node.branches_children.values() if node is not None
61
+ }
62
+ depth: int = 1
63
+ assert max_depth is not None
64
+ while generation and depth <= max_depth:
65
+ next_depth_generation: set[NodeT] = set()
66
+ for node in generation:
67
+ if not node.all_branches_generated and not node.is_over():
68
+ des[node] = None
69
+ for _, next_generation_child in node.branches_children.items():
70
+ if next_generation_child is not None:
71
+ next_depth_generation.add(next_generation_child)
72
+ generation = next_depth_generation
73
+ return list(des.keys())
74
+
75
+
76
+ def get_descendants_candidate_not_over[NodeT: AlgorithmNode[Any]](
77
+ from_tree_node: NodeT, max_depth: int | None = None
78
+ ) -> list[NodeT]:
79
+ """
80
+ Get descendants of a given tree node that are not over.
81
+
82
+ Args:
83
+ from_tree_node (ITreeNode): The starting tree node.
84
+ max_depth (int | None, optional): The maximum depth to traverse. Defaults to None.
85
+
86
+ Returns:
87
+ list[ITreeNode]: A list of descendants that are not over.
88
+ """
89
+ assert not from_tree_node.is_over()
90
+ if not from_tree_node.branches_children:
91
+ return [from_tree_node]
92
+ des: dict[NodeT, None] = {}
93
+ generation: set[NodeT] = {
94
+ node for node in from_tree_node.branches_children.values() if node is not None
95
+ }
96
+
97
+ depth: int = 1
98
+ assert max_depth is not None
99
+ while generation and depth <= max_depth:
100
+ next_depth_generation: set[NodeT] = set()
101
+ for node in generation:
102
+ if not node.is_over():
103
+ des[node] = None
104
+ for _, next_generation_child in node.branches_children.items():
105
+ if next_generation_child is not None:
106
+ next_depth_generation.add(next_generation_child)
107
+ generation = next_depth_generation
108
+ return list(des.keys())