algorhino-anemone 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
  2. algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
  3. algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
  4. algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
  5. algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
  6. anemone/__init__.py +27 -0
  7. anemone/basics.py +36 -0
  8. anemone/factory.py +161 -0
  9. anemone/indices/__init__.py +0 -0
  10. anemone/indices/index_manager/__init__.py +12 -0
  11. anemone/indices/index_manager/factory.py +50 -0
  12. anemone/indices/index_manager/node_exploration_manager.py +549 -0
  13. anemone/indices/node_indices/__init__.py +22 -0
  14. anemone/indices/node_indices/factory.py +121 -0
  15. anemone/indices/node_indices/index_data.py +166 -0
  16. anemone/indices/node_indices/index_types.py +20 -0
  17. anemone/nn/torch_evaluator.py +108 -0
  18. anemone/node_evaluation/__init__.py +0 -0
  19. anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
  20. anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
  21. anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
  22. anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
  23. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
  24. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
  25. anemone/node_factory/__init__.py +14 -0
  26. anemone/node_factory/algorithm_node_factory.py +123 -0
  27. anemone/node_factory/base.py +76 -0
  28. anemone/node_selector/__init__.py +32 -0
  29. anemone/node_selector/branch_explorer.py +89 -0
  30. anemone/node_selector/factory.py +65 -0
  31. anemone/node_selector/node_selector.py +44 -0
  32. anemone/node_selector/node_selector_args.py +22 -0
  33. anemone/node_selector/node_selector_types.py +15 -0
  34. anemone/node_selector/notations_and_statics.py +88 -0
  35. anemone/node_selector/opening_instructions.py +249 -0
  36. anemone/node_selector/recurzipf/__init__.py +0 -0
  37. anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
  38. anemone/node_selector/sequool/__init__.py +19 -0
  39. anemone/node_selector/sequool/factory.py +102 -0
  40. anemone/node_selector/sequool/sequool.py +395 -0
  41. anemone/node_selector/uniform/__init__.py +16 -0
  42. anemone/node_selector/uniform/uniform.py +113 -0
  43. anemone/nodes/__init__.py +15 -0
  44. anemone/nodes/algorithm_node/__init__.py +7 -0
  45. anemone/nodes/algorithm_node/algorithm_node.py +204 -0
  46. anemone/nodes/itree_node.py +136 -0
  47. anemone/nodes/tree_node.py +240 -0
  48. anemone/nodes/tree_traversal.py +108 -0
  49. anemone/nodes/utils.py +146 -0
  50. anemone/progress_monitor/__init__.py +0 -0
  51. anemone/progress_monitor/progress_monitor.py +375 -0
  52. anemone/recommender_rule/__init__.py +12 -0
  53. anemone/recommender_rule/recommender_rule.py +140 -0
  54. anemone/search_factory/__init__.py +14 -0
  55. anemone/search_factory/search_factory.py +192 -0
  56. anemone/state_transition.py +47 -0
  57. anemone/tree_and_value_branch_selector.py +99 -0
  58. anemone/tree_exploration.py +274 -0
  59. anemone/tree_manager/__init__.py +29 -0
  60. anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
  61. anemone/tree_manager/factory.py +77 -0
  62. anemone/tree_manager/tree_expander.py +122 -0
  63. anemone/tree_manager/tree_manager.py +254 -0
  64. anemone/trees/__init__.py +14 -0
  65. anemone/trees/descendants.py +765 -0
  66. anemone/trees/factory.py +80 -0
  67. anemone/trees/tree.py +70 -0
  68. anemone/trees/tree_visualization.py +143 -0
  69. anemone/updates/__init__.py +33 -0
  70. anemone/updates/algorithm_node_updater.py +157 -0
  71. anemone/updates/factory.py +36 -0
  72. anemone/updates/index_block.py +91 -0
  73. anemone/updates/index_updater.py +100 -0
  74. anemone/updates/minmax_evaluation_updater.py +108 -0
  75. anemone/updates/updates_file.py +248 -0
  76. anemone/updates/value_block.py +133 -0
  77. anemone/utils/comparable.py +32 -0
  78. anemone/utils/dataclass.py +64 -0
  79. anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
  80. anemone/utils/logger.py +94 -0
  81. anemone/utils/my_value_sorted_dict.py +27 -0
  82. anemone/utils/small_tools.py +103 -0
@@ -0,0 +1,246 @@
1
+ """
2
+ Defining the AlgorithmNodeTreeManager class
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from valanga import BranchKey
9
+
10
+ from anemone import trees
11
+ from anemone.indices.index_manager import (
12
+ NodeExplorationIndexManager,
13
+ )
14
+ from anemone.indices.index_manager.node_exploration_manager import (
15
+ update_all_indices,
16
+ )
17
+ from anemone.node_evaluation.node_direct_evaluation import (
18
+ EvaluationQueries,
19
+ NodeDirectEvaluator,
20
+ )
21
+ from anemone.node_factory.algorithm_node_factory import AlgorithmNodeFactory
22
+ from anemone.nodes.algorithm_node.algorithm_node import (
23
+ AlgorithmNode,
24
+ )
25
+ from anemone.updates.algorithm_node_updater import AlgorithmNodeUpdater
26
+ from anemone.updates.updates_file import (
27
+ UpdateInstructionsFromOneNode,
28
+ UpdateInstructionsTowardsMultipleNodes,
29
+ UpdateInstructionsTowardsOneParentNode,
30
+ )
31
+
32
+ from .tree_expander import TreeExpansion, TreeExpansions, record_tree_expansion
33
+ from .tree_manager import TreeManager
34
+
35
+ # todo should we use a discount? and discounted per round reward?
36
+ # todo maybe convenient to separate this object into opener, updater and displayer
37
+ # todo have the reward with a discount
38
+ # DISCOUNT = 1/.99999
39
+ if TYPE_CHECKING:
40
+ from anemone import node_selector as node_sel
41
+
42
+
43
+ @dataclass
44
+ class AlgorithmNodeTreeManager[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]]:
45
+ """
46
+ This class that and manages a tree by opening new nodes and updating the values and indexes on the nodes.
47
+ It wraps around the Tree Manager class as it has a tree_manager as member and adds functionality as this handles
48
+ trees with nodes that are of the class AlgorithmNode (managing the value for instance)
49
+ """
50
+
51
+ tree_manager: TreeManager[NodeT]
52
+ algorithm_tree_node_factory: AlgorithmNodeFactory
53
+
54
+ algorithm_node_updater: AlgorithmNodeUpdater
55
+ evaluation_queries: EvaluationQueries
56
+ node_evaluator: NodeDirectEvaluator | None
57
+ index_manager: NodeExplorationIndexManager
58
+
59
+ def open_tree_expansion_from_branch(
60
+ self,
61
+ tree: trees.Tree[NodeT],
62
+ parent_node: NodeT,
63
+ branch: BranchKey,
64
+ ) -> TreeExpansion[NodeT]:
65
+ """
66
+
67
+ Args:
68
+ tree: the tree to open
69
+ parent_node: the node to open
70
+ move: to move to open with
71
+
72
+ Returns: the tree expansions
73
+
74
+ """
75
+
76
+ tree_expansion: TreeExpansion[NodeT]
77
+ tree_expansion = self.tree_manager.open_tree_expansion_from_branch(
78
+ tree=tree, parent_node=parent_node, branch=branch
79
+ )
80
+
81
+ parent_node.tree_evaluation.branches_not_over.append(
82
+ branch
83
+ ) # default action checks for over event are performed later
84
+
85
+ return tree_expansion
86
+
87
+ def open_instructions(
88
+ self,
89
+ tree: trees.Tree[NodeT],
90
+ opening_instructions: "node_sel.OpeningInstructions[NodeT]",
91
+ ) -> TreeExpansions[NodeT]:
92
+ """
93
+
94
+ Args:
95
+ tree: the tree object to open
96
+ opening_instructions: the opening instructions
97
+
98
+ Returns: the expansions that have been performed
99
+
100
+ """
101
+
102
+ # place to store the tree expansion logs generated by the openings
103
+ tree_expansions: TreeExpansions[NodeT] = TreeExpansions()
104
+
105
+ opening_instruction: node_sel.OpeningInstruction[NodeT]
106
+ tree_expansion: TreeExpansion[NodeT]
107
+ for opening_instruction in opening_instructions.values():
108
+ # open
109
+ assert isinstance(opening_instruction.node_to_open, AlgorithmNode)
110
+ tree_expansion = self.open_tree_expansion_from_branch(
111
+ tree=tree,
112
+ parent_node=opening_instruction.node_to_open,
113
+ branch=opening_instruction.branch,
114
+ )
115
+
116
+ print("opened", tree_expansion)
117
+
118
+ record_tree_expansion(
119
+ tree=tree,
120
+ tree_expansions=tree_expansions,
121
+ tree_expansion=tree_expansion,
122
+ )
123
+
124
+ assert self.node_evaluator is not None
125
+ for tree_expansion in tree_expansions.expansions_with_node_creation:
126
+ # TODO give the tree expansion to the function directly
127
+ assert isinstance(tree_expansion.child_node, AlgorithmNode)
128
+ self.node_evaluator.add_evaluation_query(
129
+ node=tree_expansion.child_node,
130
+ evaluation_queries=self.evaluation_queries,
131
+ )
132
+
133
+ self.node_evaluator.evaluate_all_queried_nodes(
134
+ evaluation_queries=self.evaluation_queries
135
+ )
136
+
137
+ return tree_expansions
138
+
139
+ def update_indices(self, tree: trees.Tree[NodeT]) -> None:
140
+ """
141
+ Updates the indices of the nodes in the given tree.
142
+
143
+ Args:
144
+ tree (ValueTree): The tree whose indices need to be updated.
145
+
146
+ Returns:
147
+ None
148
+ """
149
+ update_all_indices(index_manager=self.index_manager, tree=tree)
150
+
151
+ def update_backward(self, tree_expansions: TreeExpansions[NodeT]) -> None:
152
+ """
153
+ Updates the algorithm node tree in a backward manner based on the given tree expansions.
154
+
155
+ Args:
156
+ tree_expansions (TreeExpansions): The tree expansions used to update the algorithm node tree.
157
+
158
+ Returns:
159
+ None
160
+ """
161
+ update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[NodeT]
162
+ update_instructions_batch = (
163
+ self.algorithm_node_updater.generate_update_instructions(
164
+ tree_expansions=tree_expansions
165
+ )
166
+ )
167
+
168
+ while update_instructions_batch:
169
+ node_to_update: NodeT
170
+ update_instructions: UpdateInstructionsTowardsOneParentNode
171
+ node_to_update, update_instructions = update_instructions_batch.pop_item()
172
+ extra_update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[
173
+ NodeT
174
+ ]
175
+ extra_update_instructions_batch = self.update_node(
176
+ node_to_update=node_to_update, update_instructions=update_instructions
177
+ )
178
+ # merge
179
+ while extra_update_instructions_batch.one_node_instructions:
180
+ parent_node_to_update: NodeT
181
+ update: UpdateInstructionsTowardsOneParentNode
182
+ parent_node_to_update, update = (
183
+ extra_update_instructions_batch.pop_item()
184
+ )
185
+ update_instructions_batch.add_updates_towards_one_parent_node(
186
+ parent_node=parent_node_to_update, update_from_child_node=update
187
+ )
188
+
189
+ def update_node(
190
+ self,
191
+ node_to_update: NodeT,
192
+ update_instructions: UpdateInstructionsTowardsOneParentNode,
193
+ ) -> UpdateInstructionsTowardsMultipleNodes[NodeT]:
194
+ """
195
+ Updates the given node with the provided update instructions.
196
+
197
+ Args:
198
+ node_to_update (AlgorithmNode): The node to be updated.
199
+ update_instructions (UpdateInstructions): The instructions for updating the node.
200
+
201
+ Returns:
202
+ UpdateInstructionsBatch: A batch of update instructions for the parent nodes of the updated node.
203
+ """
204
+
205
+ # UPDATES
206
+ new_update_instructions: UpdateInstructionsFromOneNode = (
207
+ self.algorithm_node_updater.perform_updates(
208
+ node_to_update=node_to_update, update_instructions=update_instructions
209
+ )
210
+ )
211
+
212
+ update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[NodeT]
213
+ update_instructions_batch = UpdateInstructionsTowardsMultipleNodes()
214
+ parent_node: NodeT
215
+ branch_from_parent: BranchKey
216
+ for parent_node, branch_from_parent in node_to_update.parent_nodes.items():
217
+ # there was a test for emptiness here of new updates instructions remove this comment if no bug appear
218
+ assert parent_node not in update_instructions_batch.one_node_instructions
219
+ update_instructions_batch.add_update_from_one_child_node(
220
+ update_from_child_node=new_update_instructions,
221
+ parent_node=parent_node,
222
+ branch_from_parent=branch_from_parent,
223
+ )
224
+
225
+ return update_instructions_batch
226
+
227
+ def print_some_stats(self, tree: trees.Tree[NodeT]) -> None:
228
+ """
229
+ Prints statistics about the given tree.
230
+
231
+ Args:
232
+ tree (ValueTree): The tree to print statistics for.
233
+
234
+ Returns:
235
+ None
236
+ """
237
+ self.tree_manager.print_some_stats(tree=tree)
238
+
239
+ def print_best_line(self, tree: trees.Tree[NodeT]) -> None:
240
+ """
241
+ Prints the best line of moves based on the tree evaluation of the tree.
242
+
243
+ Args:
244
+ tree (ValueTree): The tree containing the moves and their minmax evaluations.
245
+ """
246
+ tree.root_node.tree_evaluation.print_best_line()
@@ -0,0 +1,77 @@
1
+ """
2
+ This module provides a factory function for creating an AlgorithmNodeTreeManager object.
3
+
4
+ The AlgorithmNodeTreeManager is responsible for managing the tree structure of algorithm nodes,
5
+ performing updates on the nodes, and handling evaluation queries.
6
+
7
+ """
8
+
9
+ from typing import Any
10
+
11
+ from anemone import updates as upda
12
+ from anemone.indices.index_manager import (
13
+ NodeExplorationIndexManager,
14
+ create_exploration_index_manager,
15
+ )
16
+ from anemone.indices.node_indices.index_types import (
17
+ IndexComputationType,
18
+ )
19
+ from anemone.node_evaluation.node_direct_evaluation import (
20
+ EvaluationQueries,
21
+ NodeDirectEvaluator,
22
+ )
23
+ from anemone.node_factory import (
24
+ AlgorithmNodeFactory,
25
+ )
26
+ from anemone.nodes.algorithm_node.algorithm_node import AlgorithmNode
27
+ from anemone.state_transition import ValangaStateTransition
28
+ from anemone.updates.index_updater import IndexUpdater
29
+
30
+ from .algorithm_node_tree_manager import AlgorithmNodeTreeManager
31
+ from .tree_manager import TreeManager
32
+
33
+
34
def create_algorithm_node_tree_manager(
    node_direct_evaluator: NodeDirectEvaluator[Any] | None,
    algorithm_node_factory: AlgorithmNodeFactory[Any],
    index_computation: IndexComputationType | None,
    index_updater: IndexUpdater | None,
) -> AlgorithmNodeTreeManager:
    """
    Assemble an AlgorithmNodeTreeManager together with its collaborators.

    Args:
        node_direct_evaluator: The direct evaluator stored on the manager as its
            ``node_evaluator``. May be None.
        algorithm_node_factory: The factory used by the inner TreeManager and also
            stored on the manager as ``algorithm_tree_node_factory``.
        index_computation: The type of index computation, forwarded to
            ``create_exploration_index_manager``.
        index_updater: The index updater, forwarded to
            ``create_algorithm_node_updater``.

    Returns:
        The assembled AlgorithmNodeTreeManager.

    """
    # Keyword arguments are evaluated in the order written, so the collaborators are
    # built in this order: tree manager, node updater, evaluation queries, index manager.
    return AlgorithmNodeTreeManager(
        tree_manager=TreeManager[AlgorithmNode](
            node_factory=algorithm_node_factory,
            transition=ValangaStateTransition(),
        ),
        algorithm_node_updater=upda.create_algorithm_node_updater(
            index_updater=index_updater
        ),
        evaluation_queries=EvaluationQueries(),
        index_manager=create_exploration_index_manager(
            index_computation=index_computation
        ),
        node_evaluator=node_direct_evaluator,
        algorithm_tree_node_factory=algorithm_node_factory,
    )
@@ -0,0 +1,122 @@
1
+ """
2
+ Tree expansion representations for managing game trees.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Iterator, TypeVar
7
+
8
+ from valanga import BranchKey, StateModifications
9
+
10
+ from anemone import nodes as node
11
+ from anemone import trees
12
+
13
+ NodeT = TypeVar("NodeT", bound=node.ITreeNode[Any])
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class TreeExpansion[NodeT: node.ITreeNode[Any] = node.ITreeNode[Any]]:
18
+ """
19
+ Represents an expansion of a tree in a chess game.
20
+
21
+ Attributes:
22
+ child_node (NodeT): The child node created during the expansion.
23
+ parent_node (node.ITreeNode | None): The parent node of the child node. None if it's the root node.
24
+ board_modifications (board_mod.BoardModification | None): The modifications made to the chess board during the expansion.
25
+ creation_child_node (bool): Indicates whether the child node was created during the expansion.
26
+ move (chess.Move): the move from parent to child node.
27
+ """
28
+
29
+ child_node: NodeT
30
+ parent_node: NodeT | None
31
+ state_modifications: StateModifications | None
32
+ creation_child_node: bool
33
+ branch_key: BranchKey | None
34
+
35
+ def __repr__(self) -> str:
36
+ """Return a debug representation of the tree expansion."""
37
+ return (
38
+ f"child_node{self.child_node.id} | "
39
+ f"parent_node{self.parent_node.id if self.parent_node is not None else None} | "
40
+ f"creation_child_node{self.creation_child_node}"
41
+ )
42
+
43
+
44
def record_tree_expansion(
    *,
    tree: trees.Tree[NodeT],
    tree_expansions: "TreeExpansions[NodeT]",
    tree_expansion: TreeExpansion[NodeT],
) -> None:
    """
    Apply the bookkeeping of one TreeExpansion to a tree and to its expansion log.

    When the expansion created a node, the tree's node counter is incremented and the
    child node is registered among the tree's descendants. In every case the expansion
    is added to ``tree_expansions``.
    """
    new_node_was_created: bool = tree_expansion.creation_child_node

    if new_node_was_created:
        # Counter first, then registration: same order as the node was accounted for.
        tree.nodes_count += 1
        tree.descendants.add_descendant(tree_expansion.child_node)

    tree_expansions.add(tree_expansion=tree_expansion)
57
+
58
+
59
def _new_expansions_list() -> list[TreeExpansion[Any]]:
    """Build the fresh, empty list used as default value of the expansion logs."""
    fresh_log: list[TreeExpansion[Any]] = []
    return fresh_log
62
+
63
+
64
+ @dataclass(slots=True)
65
+ class TreeExpansions[NodeT: node.ITreeNode[Any] = node.ITreeNode[Any]]:
66
+ """
67
+ Represents a collection of tree expansions in a chess game.
68
+
69
+ Attributes:
70
+ expansions_with_node_creation (List[TreeExpansion]): List of expansions where child nodes were created.
71
+ expansions_without_node_creation (List[TreeExpansion]): List of expansions where child nodes were not created.
72
+ """
73
+
74
+ expansions_with_node_creation: list[TreeExpansion[NodeT]] = field(
75
+ default_factory=_new_expansions_list
76
+ )
77
+ expansions_without_node_creation: list[TreeExpansion[NodeT]] = field(
78
+ default_factory=_new_expansions_list
79
+ )
80
+
81
+ def __iter__(self) -> Iterator[TreeExpansion[NodeT]]:
82
+ """Iterate over all recorded expansions."""
83
+ return iter(
84
+ self.expansions_with_node_creation + self.expansions_without_node_creation
85
+ )
86
+
87
+ def add(self, tree_expansion: TreeExpansion[NodeT]) -> None:
88
+ """
89
+ Adds a tree expansion to the collection.
90
+
91
+ Args:
92
+ tree_expansion (TreeExpansion): The tree expansion to add.
93
+ """
94
+ if tree_expansion.creation_child_node:
95
+ self.add_creation(tree_expansion=tree_expansion)
96
+ else:
97
+ self.add_connection(tree_expansion=tree_expansion)
98
+
99
+ def add_creation(self, tree_expansion: TreeExpansion[NodeT]) -> None:
100
+ """
101
+ Adds a tree expansion with a created child node to the collection.
102
+
103
+ Args:
104
+ tree_expansion (TreeExpansion): The tree expansion to add.
105
+ """
106
+ self.expansions_with_node_creation.append(tree_expansion)
107
+
108
+ def add_connection(self, tree_expansion: TreeExpansion[NodeT]) -> None:
109
+ """
110
+ Adds a tree expansion without a created child node to the collection.
111
+
112
+ Args:
113
+ tree_expansion (TreeExpansion): The tree expansion to add.
114
+ """
115
+ self.expansions_without_node_creation.append(tree_expansion)
116
+
117
+ def __str__(self) -> str:
118
+ """Return a string summary of recorded expansions."""
119
+ return (
120
+ f"expansions_with_node_creation {self.expansions_with_node_creation} \n"
121
+ f"expansions_without_node_creation{self.expansions_without_node_creation}"
122
+ )
@@ -0,0 +1,254 @@
1
+ """
2
+ This module contains the TreeManager class, which is responsible for managing a tree by opening new nodes and updating the values and indexes on the nodes.
3
+ """
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from valanga import BranchKey, State, StateModifications, StateTag
8
+
9
+ from anemone import nodes as node
10
+ from anemone import trees
11
+ from anemone.node_factory.base import (
12
+ NodeFactory,
13
+ )
14
+ from anemone.node_selector.opening_instructions import (
15
+ OpeningInstruction,
16
+ OpeningInstructions,
17
+ )
18
+ from anemone.state_transition import StateTransition
19
+ from anemone.tree_manager.tree_expander import (
20
+ TreeExpansion,
21
+ TreeExpansions,
22
+ record_tree_expansion,
23
+ )
24
+
25
+ if TYPE_CHECKING:
26
+ from anemone.basics import TreeDepth
27
+
28
+ # todo should we use a discount? and discounted per round reward?
29
+ # todo maybe convenient to separate this object into opener, updater and displayer
30
+ # todo have the reward with a discount
31
+ # DISCOUNT = 1/.99999
32
+
33
+
34
+ class TreeManager[
35
+ FamilyT: node.ITreeNode[Any] = node.ITreeNode[Any],
36
+ ]:
37
+ """
38
+ This class manages a tree by opening new nodes.This is the core one only responsible for creating core TreeNodes
39
+ """
40
+
41
+ node_factory: NodeFactory[FamilyT]
42
+ transition: StateTransition[State]
43
+
44
+ def __init__(
45
+ self,
46
+ node_factory: NodeFactory[FamilyT],
47
+ transition: StateTransition[State],
48
+ ) -> None:
49
+ """Initialize the tree manager with a node factory and transition."""
50
+ self.node_factory = node_factory
51
+ self.transition = transition
52
+
53
+ def open_tree_expansion_from_branch(
54
+ self,
55
+ tree: trees.Tree[FamilyT],
56
+ parent_node: FamilyT,
57
+ branch: BranchKey,
58
+ ) -> TreeExpansion[FamilyT]:
59
+ """
60
+ Opening a Node that contains a board following a move.
61
+
62
+ Args:
63
+ tree: The tree object.
64
+ parent_node: The parent node that we want to expand.
65
+ move: The move to play to expand the node.
66
+
67
+ Returns:
68
+ The tree expansion object.
69
+ """
70
+ # The parent board is copied, we only copy the stack (history of previous board) if the depth is smaller than 2
71
+ # Having the stack information allows checking for draw by repetition.
72
+ # To limit computation we limit copying it all the time. The resulting policy will only be aware of immediate
73
+ # risk of draw by repetition
74
+ copy_stack: bool = tree.node_depth(parent_node) < 2
75
+ parent_state: State = parent_node.state
76
+ state: State = self.transition.copy_for_expansion(
77
+ parent_state,
78
+ copy_stack=copy_stack,
79
+ )
80
+
81
+ # The move is played. The state is now advanced.
82
+ state, modifications = self.transition.step(state, branch_key=branch)
83
+
84
+ return self.open_tree_expansion_from_state(
85
+ tree=tree,
86
+ parent_node=parent_node,
87
+ state=state,
88
+ modifications=modifications,
89
+ branch=branch,
90
+ )
91
+
92
+ def open_tree_expansion_from_state(
93
+ self,
94
+ tree: trees.Tree[FamilyT],
95
+ parent_node: FamilyT,
96
+ state: State,
97
+ modifications: StateModifications | None,
98
+ branch: BranchKey,
99
+ ) -> TreeExpansion[FamilyT]:
100
+ """
101
+ Opening a Node that contains a board given the modifications.
102
+ Checks if the new node needs to be created or if the new_board already existed in the tree
103
+ (was reached from a different serie of move)
104
+
105
+ Args:
106
+ tree: The tree object.
107
+ parent_node: The parent node that we want to expand.
108
+ board: The board object that is a move forward compared to the board in the parent node
109
+ modifications: The board modifications.
110
+ move: The move to play to expand the node.
111
+
112
+ Returns:
113
+ The tree expansion object.
114
+ """
115
+
116
+ # Creation of the child node. If the board already exited in another node, that node is returned as child_node.
117
+ tree_depth: int = parent_node.tree_depth + 1
118
+ state_tag: StateTag = state.tag
119
+
120
+ need_creation_child_node: bool = (
121
+ tree.descendants.is_new_generation(tree_depth)
122
+ or state_tag not in tree.descendants.descendants_at_tree_depth[tree_depth]
123
+ )
124
+
125
+ tree_expansion: TreeExpansion[FamilyT]
126
+
127
+ if need_creation_child_node:
128
+ child_node: FamilyT
129
+ child_node = self.node_factory.create(
130
+ state=state,
131
+ tree_depth=tree_depth,
132
+ count=tree.nodes_count,
133
+ branch_from_parent=branch,
134
+ parent_node=parent_node,
135
+ modifications=modifications,
136
+ )
137
+
138
+ tree_expansion = TreeExpansion(
139
+ child_node=child_node,
140
+ parent_node=parent_node,
141
+ state_modifications=modifications,
142
+ creation_child_node=need_creation_child_node,
143
+ branch_key=branch,
144
+ )
145
+
146
+ else: # the node already exists
147
+ child_node_existing: FamilyT
148
+ child_node_existing = tree.descendants[tree_depth][state_tag]
149
+ child_node_existing.add_parent(
150
+ branch_key=branch, new_parent_node=parent_node
151
+ )
152
+
153
+ tree_expansion = TreeExpansion(
154
+ child_node=child_node_existing,
155
+ parent_node=parent_node,
156
+ state_modifications=modifications,
157
+ creation_child_node=need_creation_child_node,
158
+ branch_key=branch,
159
+ )
160
+
161
+ # add it to the list of opened move and out of the non-opened moves
162
+ parent_node.branches_children[branch] = tree_expansion.child_node
163
+ # parent_node.tree_node.non_opened_legal_moves.remove(move)
164
+ tree.move_count += 1 # counting moves
165
+
166
+ return tree_expansion
167
+
168
+ def open_instructions(
169
+ self,
170
+ tree: trees.Tree[FamilyT],
171
+ opening_instructions: OpeningInstructions[FamilyT],
172
+ ) -> TreeExpansions[FamilyT]:
173
+ """
174
+ Opening multiple nodes based on the opening instructions.
175
+
176
+ Args:
177
+ tree: The tree object.
178
+ opening_instructions: The opening instructions.
179
+
180
+ Returns:
181
+ The tree expansions that have been performed.
182
+ """
183
+
184
+ # place to store the tree expansion logs generated by the openings
185
+ tree_expansions: TreeExpansions[FamilyT] = TreeExpansions()
186
+
187
+ opening_instruction: OpeningInstruction[FamilyT]
188
+ for opening_instruction in opening_instructions.values():
189
+ # open
190
+ tree_expansion: TreeExpansion[FamilyT] = (
191
+ self.open_tree_expansion_from_branch(
192
+ tree=tree,
193
+ parent_node=opening_instruction.node_to_open,
194
+ branch=opening_instruction.branch,
195
+ )
196
+ )
197
+
198
+ record_tree_expansion(
199
+ tree=tree,
200
+ tree_expansions=tree_expansions,
201
+ tree_expansion=tree_expansion,
202
+ )
203
+
204
+ return tree_expansions
205
+
206
+ def print_some_stats(
207
+ self,
208
+ tree: trees.Tree[FamilyT],
209
+ ) -> None:
210
+ """
211
+ Print some statistics about the tree.
212
+
213
+ Args:
214
+ tree: The tree object.
215
+ """
216
+ print(
217
+ "Tree stats: move_count",
218
+ tree.move_count,
219
+ " node_count",
220
+ tree.descendants.get_count(),
221
+ )
222
+ sum_ = 0
223
+ tree.descendants.print_stats()
224
+ tree_depth: TreeDepth
225
+ for tree_depth in tree.descendants:
226
+ sum_ += len(tree.descendants[tree_depth])
227
+ print("tree_depth", tree_depth, len(tree.descendants[tree_depth]), sum_)
228
+
229
+ def test_count(
230
+ self,
231
+ tree: trees.Tree[FamilyT],
232
+ ) -> None:
233
+ """
234
+ Test the count of nodes in the tree.
235
+
236
+ Args:
237
+ tree: The tree object.
238
+ """
239
+ assert tree.descendants.get_count() == tree.nodes_count
240
+
241
+ def print_best_line(
242
+ self,
243
+ tree: trees.Tree[FamilyT],
244
+ ) -> None:
245
+ """
246
+ Print the best line in the tree.
247
+
248
+ Args:
249
+ tree: The tree object.
250
+ """
251
+
252
+ raise NotImplementedError(
253
+ "print_best_line should not be called; override or modify this behavior"
254
+ )