algorhino-anemone 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
- algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
- algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
- algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
- algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
- anemone/__init__.py +27 -0
- anemone/basics.py +36 -0
- anemone/factory.py +161 -0
- anemone/indices/__init__.py +0 -0
- anemone/indices/index_manager/__init__.py +12 -0
- anemone/indices/index_manager/factory.py +50 -0
- anemone/indices/index_manager/node_exploration_manager.py +549 -0
- anemone/indices/node_indices/__init__.py +22 -0
- anemone/indices/node_indices/factory.py +121 -0
- anemone/indices/node_indices/index_data.py +166 -0
- anemone/indices/node_indices/index_types.py +20 -0
- anemone/nn/torch_evaluator.py +108 -0
- anemone/node_evaluation/__init__.py +0 -0
- anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
- anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
- anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
- anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
- anemone/node_factory/__init__.py +14 -0
- anemone/node_factory/algorithm_node_factory.py +123 -0
- anemone/node_factory/base.py +76 -0
- anemone/node_selector/__init__.py +32 -0
- anemone/node_selector/branch_explorer.py +89 -0
- anemone/node_selector/factory.py +65 -0
- anemone/node_selector/node_selector.py +44 -0
- anemone/node_selector/node_selector_args.py +22 -0
- anemone/node_selector/node_selector_types.py +15 -0
- anemone/node_selector/notations_and_statics.py +88 -0
- anemone/node_selector/opening_instructions.py +249 -0
- anemone/node_selector/recurzipf/__init__.py +0 -0
- anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
- anemone/node_selector/sequool/__init__.py +19 -0
- anemone/node_selector/sequool/factory.py +102 -0
- anemone/node_selector/sequool/sequool.py +395 -0
- anemone/node_selector/uniform/__init__.py +16 -0
- anemone/node_selector/uniform/uniform.py +113 -0
- anemone/nodes/__init__.py +15 -0
- anemone/nodes/algorithm_node/__init__.py +7 -0
- anemone/nodes/algorithm_node/algorithm_node.py +204 -0
- anemone/nodes/itree_node.py +136 -0
- anemone/nodes/tree_node.py +240 -0
- anemone/nodes/tree_traversal.py +108 -0
- anemone/nodes/utils.py +146 -0
- anemone/progress_monitor/__init__.py +0 -0
- anemone/progress_monitor/progress_monitor.py +375 -0
- anemone/recommender_rule/__init__.py +12 -0
- anemone/recommender_rule/recommender_rule.py +140 -0
- anemone/search_factory/__init__.py +14 -0
- anemone/search_factory/search_factory.py +192 -0
- anemone/state_transition.py +47 -0
- anemone/tree_and_value_branch_selector.py +99 -0
- anemone/tree_exploration.py +274 -0
- anemone/tree_manager/__init__.py +29 -0
- anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
- anemone/tree_manager/factory.py +77 -0
- anemone/tree_manager/tree_expander.py +122 -0
- anemone/tree_manager/tree_manager.py +254 -0
- anemone/trees/__init__.py +14 -0
- anemone/trees/descendants.py +765 -0
- anemone/trees/factory.py +80 -0
- anemone/trees/tree.py +70 -0
- anemone/trees/tree_visualization.py +143 -0
- anemone/updates/__init__.py +33 -0
- anemone/updates/algorithm_node_updater.py +157 -0
- anemone/updates/factory.py +36 -0
- anemone/updates/index_block.py +91 -0
- anemone/updates/index_updater.py +100 -0
- anemone/updates/minmax_evaluation_updater.py +108 -0
- anemone/updates/updates_file.py +248 -0
- anemone/updates/value_block.py +133 -0
- anemone/utils/comparable.py +32 -0
- anemone/utils/dataclass.py +64 -0
- anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
- anemone/utils/logger.py +94 -0
- anemone/utils/my_value_sorted_dict.py +27 -0
- anemone/utils/small_tools.py +103 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defining the AlgorithmNodeTreeManager class
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from valanga import BranchKey
|
|
9
|
+
|
|
10
|
+
from anemone import trees
|
|
11
|
+
from anemone.indices.index_manager import (
|
|
12
|
+
NodeExplorationIndexManager,
|
|
13
|
+
)
|
|
14
|
+
from anemone.indices.index_manager.node_exploration_manager import (
|
|
15
|
+
update_all_indices,
|
|
16
|
+
)
|
|
17
|
+
from anemone.node_evaluation.node_direct_evaluation import (
|
|
18
|
+
EvaluationQueries,
|
|
19
|
+
NodeDirectEvaluator,
|
|
20
|
+
)
|
|
21
|
+
from anemone.node_factory.algorithm_node_factory import AlgorithmNodeFactory
|
|
22
|
+
from anemone.nodes.algorithm_node.algorithm_node import (
|
|
23
|
+
AlgorithmNode,
|
|
24
|
+
)
|
|
25
|
+
from anemone.updates.algorithm_node_updater import AlgorithmNodeUpdater
|
|
26
|
+
from anemone.updates.updates_file import (
|
|
27
|
+
UpdateInstructionsFromOneNode,
|
|
28
|
+
UpdateInstructionsTowardsMultipleNodes,
|
|
29
|
+
UpdateInstructionsTowardsOneParentNode,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
from .tree_expander import TreeExpansion, TreeExpansions, record_tree_expansion
|
|
33
|
+
from .tree_manager import TreeManager
|
|
34
|
+
|
|
35
|
+
# todo should we use a discount? and discounted per round reward?
|
|
36
|
+
# todo maybe convenient to separate this object into opener, updater and displayer
|
|
37
|
+
# todo have the reward with a discount
|
|
38
|
+
# DISCOUNT = 1/.99999
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
from anemone import node_selector as node_sel
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
|
|
44
|
+
class AlgorithmNodeTreeManager[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]]:
|
|
45
|
+
"""
|
|
46
|
+
This class that and manages a tree by opening new nodes and updating the values and indexes on the nodes.
|
|
47
|
+
It wraps around the Tree Manager class as it has a tree_manager as member and adds functionality as this handles
|
|
48
|
+
trees with nodes that are of the class AlgorithmNode (managing the value for instance)
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
tree_manager: TreeManager[NodeT]
|
|
52
|
+
algorithm_tree_node_factory: AlgorithmNodeFactory
|
|
53
|
+
|
|
54
|
+
algorithm_node_updater: AlgorithmNodeUpdater
|
|
55
|
+
evaluation_queries: EvaluationQueries
|
|
56
|
+
node_evaluator: NodeDirectEvaluator | None
|
|
57
|
+
index_manager: NodeExplorationIndexManager
|
|
58
|
+
|
|
59
|
+
def open_tree_expansion_from_branch(
|
|
60
|
+
self,
|
|
61
|
+
tree: trees.Tree[NodeT],
|
|
62
|
+
parent_node: NodeT,
|
|
63
|
+
branch: BranchKey,
|
|
64
|
+
) -> TreeExpansion[NodeT]:
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
tree: the tree to open
|
|
69
|
+
parent_node: the node to open
|
|
70
|
+
move: to move to open with
|
|
71
|
+
|
|
72
|
+
Returns: the tree expansions
|
|
73
|
+
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
tree_expansion: TreeExpansion[NodeT]
|
|
77
|
+
tree_expansion = self.tree_manager.open_tree_expansion_from_branch(
|
|
78
|
+
tree=tree, parent_node=parent_node, branch=branch
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
parent_node.tree_evaluation.branches_not_over.append(
|
|
82
|
+
branch
|
|
83
|
+
) # default action checks for over event are performed later
|
|
84
|
+
|
|
85
|
+
return tree_expansion
|
|
86
|
+
|
|
87
|
+
def open_instructions(
|
|
88
|
+
self,
|
|
89
|
+
tree: trees.Tree[NodeT],
|
|
90
|
+
opening_instructions: "node_sel.OpeningInstructions[NodeT]",
|
|
91
|
+
) -> TreeExpansions[NodeT]:
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
tree: the tree object to open
|
|
96
|
+
opening_instructions: the opening instructions
|
|
97
|
+
|
|
98
|
+
Returns: the expansions that have been performed
|
|
99
|
+
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
# place to store the tree expansion logs generated by the openings
|
|
103
|
+
tree_expansions: TreeExpansions[NodeT] = TreeExpansions()
|
|
104
|
+
|
|
105
|
+
opening_instruction: node_sel.OpeningInstruction[NodeT]
|
|
106
|
+
tree_expansion: TreeExpansion[NodeT]
|
|
107
|
+
for opening_instruction in opening_instructions.values():
|
|
108
|
+
# open
|
|
109
|
+
assert isinstance(opening_instruction.node_to_open, AlgorithmNode)
|
|
110
|
+
tree_expansion = self.open_tree_expansion_from_branch(
|
|
111
|
+
tree=tree,
|
|
112
|
+
parent_node=opening_instruction.node_to_open,
|
|
113
|
+
branch=opening_instruction.branch,
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
print("opened", tree_expansion)
|
|
117
|
+
|
|
118
|
+
record_tree_expansion(
|
|
119
|
+
tree=tree,
|
|
120
|
+
tree_expansions=tree_expansions,
|
|
121
|
+
tree_expansion=tree_expansion,
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
assert self.node_evaluator is not None
|
|
125
|
+
for tree_expansion in tree_expansions.expansions_with_node_creation:
|
|
126
|
+
# TODO give the tree expansion to the function directly
|
|
127
|
+
assert isinstance(tree_expansion.child_node, AlgorithmNode)
|
|
128
|
+
self.node_evaluator.add_evaluation_query(
|
|
129
|
+
node=tree_expansion.child_node,
|
|
130
|
+
evaluation_queries=self.evaluation_queries,
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
self.node_evaluator.evaluate_all_queried_nodes(
|
|
134
|
+
evaluation_queries=self.evaluation_queries
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
return tree_expansions
|
|
138
|
+
|
|
139
|
+
def update_indices(self, tree: trees.Tree[NodeT]) -> None:
|
|
140
|
+
"""
|
|
141
|
+
Updates the indices of the nodes in the given tree.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
tree (ValueTree): The tree whose indices need to be updated.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
None
|
|
148
|
+
"""
|
|
149
|
+
update_all_indices(index_manager=self.index_manager, tree=tree)
|
|
150
|
+
|
|
151
|
+
def update_backward(self, tree_expansions: TreeExpansions[NodeT]) -> None:
|
|
152
|
+
"""
|
|
153
|
+
Updates the algorithm node tree in a backward manner based on the given tree expansions.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
tree_expansions (TreeExpansions): The tree expansions used to update the algorithm node tree.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
None
|
|
160
|
+
"""
|
|
161
|
+
update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[NodeT]
|
|
162
|
+
update_instructions_batch = (
|
|
163
|
+
self.algorithm_node_updater.generate_update_instructions(
|
|
164
|
+
tree_expansions=tree_expansions
|
|
165
|
+
)
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
while update_instructions_batch:
|
|
169
|
+
node_to_update: NodeT
|
|
170
|
+
update_instructions: UpdateInstructionsTowardsOneParentNode
|
|
171
|
+
node_to_update, update_instructions = update_instructions_batch.pop_item()
|
|
172
|
+
extra_update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[
|
|
173
|
+
NodeT
|
|
174
|
+
]
|
|
175
|
+
extra_update_instructions_batch = self.update_node(
|
|
176
|
+
node_to_update=node_to_update, update_instructions=update_instructions
|
|
177
|
+
)
|
|
178
|
+
# merge
|
|
179
|
+
while extra_update_instructions_batch.one_node_instructions:
|
|
180
|
+
parent_node_to_update: NodeT
|
|
181
|
+
update: UpdateInstructionsTowardsOneParentNode
|
|
182
|
+
parent_node_to_update, update = (
|
|
183
|
+
extra_update_instructions_batch.pop_item()
|
|
184
|
+
)
|
|
185
|
+
update_instructions_batch.add_updates_towards_one_parent_node(
|
|
186
|
+
parent_node=parent_node_to_update, update_from_child_node=update
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
def update_node(
|
|
190
|
+
self,
|
|
191
|
+
node_to_update: NodeT,
|
|
192
|
+
update_instructions: UpdateInstructionsTowardsOneParentNode,
|
|
193
|
+
) -> UpdateInstructionsTowardsMultipleNodes[NodeT]:
|
|
194
|
+
"""
|
|
195
|
+
Updates the given node with the provided update instructions.
|
|
196
|
+
|
|
197
|
+
Args:
|
|
198
|
+
node_to_update (AlgorithmNode): The node to be updated.
|
|
199
|
+
update_instructions (UpdateInstructions): The instructions for updating the node.
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
UpdateInstructionsBatch: A batch of update instructions for the parent nodes of the updated node.
|
|
203
|
+
"""
|
|
204
|
+
|
|
205
|
+
# UPDATES
|
|
206
|
+
new_update_instructions: UpdateInstructionsFromOneNode = (
|
|
207
|
+
self.algorithm_node_updater.perform_updates(
|
|
208
|
+
node_to_update=node_to_update, update_instructions=update_instructions
|
|
209
|
+
)
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[NodeT]
|
|
213
|
+
update_instructions_batch = UpdateInstructionsTowardsMultipleNodes()
|
|
214
|
+
parent_node: NodeT
|
|
215
|
+
branch_from_parent: BranchKey
|
|
216
|
+
for parent_node, branch_from_parent in node_to_update.parent_nodes.items():
|
|
217
|
+
# there was a test for emptiness here of new updates instructions remove this comment if no bug appear
|
|
218
|
+
assert parent_node not in update_instructions_batch.one_node_instructions
|
|
219
|
+
update_instructions_batch.add_update_from_one_child_node(
|
|
220
|
+
update_from_child_node=new_update_instructions,
|
|
221
|
+
parent_node=parent_node,
|
|
222
|
+
branch_from_parent=branch_from_parent,
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
return update_instructions_batch
|
|
226
|
+
|
|
227
|
+
def print_some_stats(self, tree: trees.Tree[NodeT]) -> None:
|
|
228
|
+
"""
|
|
229
|
+
Prints statistics about the given tree.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
tree (ValueTree): The tree to print statistics for.
|
|
233
|
+
|
|
234
|
+
Returns:
|
|
235
|
+
None
|
|
236
|
+
"""
|
|
237
|
+
self.tree_manager.print_some_stats(tree=tree)
|
|
238
|
+
|
|
239
|
+
def print_best_line(self, tree: trees.Tree[NodeT]) -> None:
|
|
240
|
+
"""
|
|
241
|
+
Prints the best line of moves based on the tree evaluation of the tree.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
tree (ValueTree): The tree containing the moves and their minmax evaluations.
|
|
245
|
+
"""
|
|
246
|
+
tree.root_node.tree_evaluation.print_best_line()
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides a factory function for creating an AlgorithmNodeTreeManager object.
|
|
3
|
+
|
|
4
|
+
The AlgorithmNodeTreeManager is responsible for managing the tree structure of algorithm nodes,
|
|
5
|
+
performing updates on the nodes, and handling evaluation queries.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from anemone import updates as upda
|
|
12
|
+
from anemone.indices.index_manager import (
|
|
13
|
+
NodeExplorationIndexManager,
|
|
14
|
+
create_exploration_index_manager,
|
|
15
|
+
)
|
|
16
|
+
from anemone.indices.node_indices.index_types import (
|
|
17
|
+
IndexComputationType,
|
|
18
|
+
)
|
|
19
|
+
from anemone.node_evaluation.node_direct_evaluation import (
|
|
20
|
+
EvaluationQueries,
|
|
21
|
+
NodeDirectEvaluator,
|
|
22
|
+
)
|
|
23
|
+
from anemone.node_factory import (
|
|
24
|
+
AlgorithmNodeFactory,
|
|
25
|
+
)
|
|
26
|
+
from anemone.nodes.algorithm_node.algorithm_node import AlgorithmNode
|
|
27
|
+
from anemone.state_transition import ValangaStateTransition
|
|
28
|
+
from anemone.updates.index_updater import IndexUpdater
|
|
29
|
+
|
|
30
|
+
from .algorithm_node_tree_manager import AlgorithmNodeTreeManager
|
|
31
|
+
from .tree_manager import TreeManager
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def create_algorithm_node_tree_manager(
    node_direct_evaluator: NodeDirectEvaluator[Any] | None,
    algorithm_node_factory: AlgorithmNodeFactory[Any],
    index_computation: IndexComputationType | None,
    index_updater: IndexUpdater | None,
) -> AlgorithmNodeTreeManager:
    """
    Assemble an AlgorithmNodeTreeManager together with the collaborators it needs.

    Args:
        node_direct_evaluator: The evaluator stored as the manager's ``node_evaluator`` (may be None).
        algorithm_node_factory: The factory used by the inner TreeManager and stored on the manager.
        index_computation: The type of index computation given to the exploration index manager.
        index_updater: The IndexUpdater given to the algorithm node updater.

    Returns:
        The assembled AlgorithmNodeTreeManager.

    """
    # Keyword arguments are evaluated top to bottom, so the collaborators are
    # built in this order: tree manager, node updater, evaluation queries,
    # exploration index manager.
    return AlgorithmNodeTreeManager(
        tree_manager=TreeManager[AlgorithmNode](
            node_factory=algorithm_node_factory,
            transition=ValangaStateTransition(),
        ),
        algorithm_node_updater=upda.create_algorithm_node_updater(
            index_updater=index_updater
        ),
        evaluation_queries=EvaluationQueries(),
        index_manager=create_exploration_index_manager(
            index_computation=index_computation
        ),
        node_evaluator=node_direct_evaluator,
        algorithm_tree_node_factory=algorithm_node_factory,
    )
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tree expansion representations for managing game trees.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Iterator, TypeVar
|
|
7
|
+
|
|
8
|
+
from valanga import BranchKey, StateModifications
|
|
9
|
+
|
|
10
|
+
from anemone import nodes as node
|
|
11
|
+
from anemone import trees
|
|
12
|
+
|
|
13
|
+
# Type variable used by the module-level function ``record_tree_expansion``; the generic
# classes in this module declare their own ``NodeT`` parameter in their class headers.
NodeT = TypeVar("NodeT", bound=node.ITreeNode[Any])
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(slots=True)
|
|
17
|
+
class TreeExpansion[NodeT: node.ITreeNode[Any] = node.ITreeNode[Any]]:
|
|
18
|
+
"""
|
|
19
|
+
Represents an expansion of a tree in a chess game.
|
|
20
|
+
|
|
21
|
+
Attributes:
|
|
22
|
+
child_node (NodeT): The child node created during the expansion.
|
|
23
|
+
parent_node (node.ITreeNode | None): The parent node of the child node. None if it's the root node.
|
|
24
|
+
board_modifications (board_mod.BoardModification | None): The modifications made to the chess board during the expansion.
|
|
25
|
+
creation_child_node (bool): Indicates whether the child node was created during the expansion.
|
|
26
|
+
move (chess.Move): the move from parent to child node.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
child_node: NodeT
|
|
30
|
+
parent_node: NodeT | None
|
|
31
|
+
state_modifications: StateModifications | None
|
|
32
|
+
creation_child_node: bool
|
|
33
|
+
branch_key: BranchKey | None
|
|
34
|
+
|
|
35
|
+
def __repr__(self) -> str:
|
|
36
|
+
"""Return a debug representation of the tree expansion."""
|
|
37
|
+
return (
|
|
38
|
+
f"child_node{self.child_node.id} | "
|
|
39
|
+
f"parent_node{self.parent_node.id if self.parent_node is not None else None} | "
|
|
40
|
+
f"creation_child_node{self.creation_child_node}"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def record_tree_expansion(
    *,
    tree: trees.Tree[NodeT],
    tree_expansions: "TreeExpansions[NodeT]",
    tree_expansion: TreeExpansion[NodeT],
) -> None:
    """
    Register one TreeExpansion in the tree's bookkeeping and in the expansion log.

    Args:
        tree: The tree whose node count and descendants are updated when a node was created.
        tree_expansions: The log the expansion is added to.
        tree_expansion: The expansion to record.
    """
    if not tree_expansion.creation_child_node:
        # no node was created: only the log needs to know about this expansion
        tree_expansions.add(tree_expansion=tree_expansion)
        return

    # a node was created: count it and register it among the descendants
    tree.nodes_count += 1
    tree.descendants.add_descendant(tree_expansion.child_node)
    tree_expansions.add(tree_expansion=tree_expansion)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _new_expansions_list() -> list[TreeExpansion[Any]]:
    """Build a fresh, empty list meant to hold tree expansions."""
    fresh_list: list[TreeExpansion[Any]] = []
    return fresh_list
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass(slots=True)
|
|
65
|
+
class TreeExpansions[NodeT: node.ITreeNode[Any] = node.ITreeNode[Any]]:
|
|
66
|
+
"""
|
|
67
|
+
Represents a collection of tree expansions in a chess game.
|
|
68
|
+
|
|
69
|
+
Attributes:
|
|
70
|
+
expansions_with_node_creation (List[TreeExpansion]): List of expansions where child nodes were created.
|
|
71
|
+
expansions_without_node_creation (List[TreeExpansion]): List of expansions where child nodes were not created.
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
expansions_with_node_creation: list[TreeExpansion[NodeT]] = field(
|
|
75
|
+
default_factory=_new_expansions_list
|
|
76
|
+
)
|
|
77
|
+
expansions_without_node_creation: list[TreeExpansion[NodeT]] = field(
|
|
78
|
+
default_factory=_new_expansions_list
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
def __iter__(self) -> Iterator[TreeExpansion[NodeT]]:
|
|
82
|
+
"""Iterate over all recorded expansions."""
|
|
83
|
+
return iter(
|
|
84
|
+
self.expansions_with_node_creation + self.expansions_without_node_creation
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
def add(self, tree_expansion: TreeExpansion[NodeT]) -> None:
|
|
88
|
+
"""
|
|
89
|
+
Adds a tree expansion to the collection.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
tree_expansion (TreeExpansion): The tree expansion to add.
|
|
93
|
+
"""
|
|
94
|
+
if tree_expansion.creation_child_node:
|
|
95
|
+
self.add_creation(tree_expansion=tree_expansion)
|
|
96
|
+
else:
|
|
97
|
+
self.add_connection(tree_expansion=tree_expansion)
|
|
98
|
+
|
|
99
|
+
def add_creation(self, tree_expansion: TreeExpansion[NodeT]) -> None:
|
|
100
|
+
"""
|
|
101
|
+
Adds a tree expansion with a created child node to the collection.
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
tree_expansion (TreeExpansion): The tree expansion to add.
|
|
105
|
+
"""
|
|
106
|
+
self.expansions_with_node_creation.append(tree_expansion)
|
|
107
|
+
|
|
108
|
+
def add_connection(self, tree_expansion: TreeExpansion[NodeT]) -> None:
|
|
109
|
+
"""
|
|
110
|
+
Adds a tree expansion without a created child node to the collection.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
tree_expansion (TreeExpansion): The tree expansion to add.
|
|
114
|
+
"""
|
|
115
|
+
self.expansions_without_node_creation.append(tree_expansion)
|
|
116
|
+
|
|
117
|
+
def __str__(self) -> str:
|
|
118
|
+
"""Return a string summary of recorded expansions."""
|
|
119
|
+
return (
|
|
120
|
+
f"expansions_with_node_creation {self.expansions_with_node_creation} \n"
|
|
121
|
+
f"expansions_without_node_creation{self.expansions_without_node_creation}"
|
|
122
|
+
)
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the TreeManager class, which is responsible for managing a tree by opening new nodes and updating the values and indexes on the nodes.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from valanga import BranchKey, State, StateModifications, StateTag
|
|
8
|
+
|
|
9
|
+
from anemone import nodes as node
|
|
10
|
+
from anemone import trees
|
|
11
|
+
from anemone.node_factory.base import (
|
|
12
|
+
NodeFactory,
|
|
13
|
+
)
|
|
14
|
+
from anemone.node_selector.opening_instructions import (
|
|
15
|
+
OpeningInstruction,
|
|
16
|
+
OpeningInstructions,
|
|
17
|
+
)
|
|
18
|
+
from anemone.state_transition import StateTransition
|
|
19
|
+
from anemone.tree_manager.tree_expander import (
|
|
20
|
+
TreeExpansion,
|
|
21
|
+
TreeExpansions,
|
|
22
|
+
record_tree_expansion,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from anemone.basics import TreeDepth
|
|
27
|
+
|
|
28
|
+
# todo should we use a discount? and discounted per round reward?
|
|
29
|
+
# todo maybe convenient to separate this object into opener, updater and displayer
|
|
30
|
+
# todo have the reward with a discount
|
|
31
|
+
# DISCOUNT = 1/.99999
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TreeManager[
|
|
35
|
+
FamilyT: node.ITreeNode[Any] = node.ITreeNode[Any],
|
|
36
|
+
]:
|
|
37
|
+
"""
|
|
38
|
+
This class manages a tree by opening new nodes.This is the core one only responsible for creating core TreeNodes
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
node_factory: NodeFactory[FamilyT]
|
|
42
|
+
transition: StateTransition[State]
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
node_factory: NodeFactory[FamilyT],
|
|
47
|
+
transition: StateTransition[State],
|
|
48
|
+
) -> None:
|
|
49
|
+
"""Initialize the tree manager with a node factory and transition."""
|
|
50
|
+
self.node_factory = node_factory
|
|
51
|
+
self.transition = transition
|
|
52
|
+
|
|
53
|
+
def open_tree_expansion_from_branch(
|
|
54
|
+
self,
|
|
55
|
+
tree: trees.Tree[FamilyT],
|
|
56
|
+
parent_node: FamilyT,
|
|
57
|
+
branch: BranchKey,
|
|
58
|
+
) -> TreeExpansion[FamilyT]:
|
|
59
|
+
"""
|
|
60
|
+
Opening a Node that contains a board following a move.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
tree: The tree object.
|
|
64
|
+
parent_node: The parent node that we want to expand.
|
|
65
|
+
move: The move to play to expand the node.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
The tree expansion object.
|
|
69
|
+
"""
|
|
70
|
+
# The parent board is copied, we only copy the stack (history of previous board) if the depth is smaller than 2
|
|
71
|
+
# Having the stack information allows checking for draw by repetition.
|
|
72
|
+
# To limit computation we limit copying it all the time. The resulting policy will only be aware of immediate
|
|
73
|
+
# risk of draw by repetition
|
|
74
|
+
copy_stack: bool = tree.node_depth(parent_node) < 2
|
|
75
|
+
parent_state: State = parent_node.state
|
|
76
|
+
state: State = self.transition.copy_for_expansion(
|
|
77
|
+
parent_state,
|
|
78
|
+
copy_stack=copy_stack,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
# The move is played. The state is now advanced.
|
|
82
|
+
state, modifications = self.transition.step(state, branch_key=branch)
|
|
83
|
+
|
|
84
|
+
return self.open_tree_expansion_from_state(
|
|
85
|
+
tree=tree,
|
|
86
|
+
parent_node=parent_node,
|
|
87
|
+
state=state,
|
|
88
|
+
modifications=modifications,
|
|
89
|
+
branch=branch,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
def open_tree_expansion_from_state(
|
|
93
|
+
self,
|
|
94
|
+
tree: trees.Tree[FamilyT],
|
|
95
|
+
parent_node: FamilyT,
|
|
96
|
+
state: State,
|
|
97
|
+
modifications: StateModifications | None,
|
|
98
|
+
branch: BranchKey,
|
|
99
|
+
) -> TreeExpansion[FamilyT]:
|
|
100
|
+
"""
|
|
101
|
+
Opening a Node that contains a board given the modifications.
|
|
102
|
+
Checks if the new node needs to be created or if the new_board already existed in the tree
|
|
103
|
+
(was reached from a different serie of move)
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
tree: The tree object.
|
|
107
|
+
parent_node: The parent node that we want to expand.
|
|
108
|
+
board: The board object that is a move forward compared to the board in the parent node
|
|
109
|
+
modifications: The board modifications.
|
|
110
|
+
move: The move to play to expand the node.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
The tree expansion object.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
# Creation of the child node. If the board already exited in another node, that node is returned as child_node.
|
|
117
|
+
tree_depth: int = parent_node.tree_depth + 1
|
|
118
|
+
state_tag: StateTag = state.tag
|
|
119
|
+
|
|
120
|
+
need_creation_child_node: bool = (
|
|
121
|
+
tree.descendants.is_new_generation(tree_depth)
|
|
122
|
+
or state_tag not in tree.descendants.descendants_at_tree_depth[tree_depth]
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
tree_expansion: TreeExpansion[FamilyT]
|
|
126
|
+
|
|
127
|
+
if need_creation_child_node:
|
|
128
|
+
child_node: FamilyT
|
|
129
|
+
child_node = self.node_factory.create(
|
|
130
|
+
state=state,
|
|
131
|
+
tree_depth=tree_depth,
|
|
132
|
+
count=tree.nodes_count,
|
|
133
|
+
branch_from_parent=branch,
|
|
134
|
+
parent_node=parent_node,
|
|
135
|
+
modifications=modifications,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
tree_expansion = TreeExpansion(
|
|
139
|
+
child_node=child_node,
|
|
140
|
+
parent_node=parent_node,
|
|
141
|
+
state_modifications=modifications,
|
|
142
|
+
creation_child_node=need_creation_child_node,
|
|
143
|
+
branch_key=branch,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
else: # the node already exists
|
|
147
|
+
child_node_existing: FamilyT
|
|
148
|
+
child_node_existing = tree.descendants[tree_depth][state_tag]
|
|
149
|
+
child_node_existing.add_parent(
|
|
150
|
+
branch_key=branch, new_parent_node=parent_node
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
tree_expansion = TreeExpansion(
|
|
154
|
+
child_node=child_node_existing,
|
|
155
|
+
parent_node=parent_node,
|
|
156
|
+
state_modifications=modifications,
|
|
157
|
+
creation_child_node=need_creation_child_node,
|
|
158
|
+
branch_key=branch,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
# add it to the list of opened move and out of the non-opened moves
|
|
162
|
+
parent_node.branches_children[branch] = tree_expansion.child_node
|
|
163
|
+
# parent_node.tree_node.non_opened_legal_moves.remove(move)
|
|
164
|
+
tree.move_count += 1 # counting moves
|
|
165
|
+
|
|
166
|
+
return tree_expansion
|
|
167
|
+
|
|
168
|
+
def open_instructions(
|
|
169
|
+
self,
|
|
170
|
+
tree: trees.Tree[FamilyT],
|
|
171
|
+
opening_instructions: OpeningInstructions[FamilyT],
|
|
172
|
+
) -> TreeExpansions[FamilyT]:
|
|
173
|
+
"""
|
|
174
|
+
Opening multiple nodes based on the opening instructions.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
tree: The tree object.
|
|
178
|
+
opening_instructions: The opening instructions.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
The tree expansions that have been performed.
|
|
182
|
+
"""
|
|
183
|
+
|
|
184
|
+
# place to store the tree expansion logs generated by the openings
|
|
185
|
+
tree_expansions: TreeExpansions[FamilyT] = TreeExpansions()
|
|
186
|
+
|
|
187
|
+
opening_instruction: OpeningInstruction[FamilyT]
|
|
188
|
+
for opening_instruction in opening_instructions.values():
|
|
189
|
+
# open
|
|
190
|
+
tree_expansion: TreeExpansion[FamilyT] = (
|
|
191
|
+
self.open_tree_expansion_from_branch(
|
|
192
|
+
tree=tree,
|
|
193
|
+
parent_node=opening_instruction.node_to_open,
|
|
194
|
+
branch=opening_instruction.branch,
|
|
195
|
+
)
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
record_tree_expansion(
|
|
199
|
+
tree=tree,
|
|
200
|
+
tree_expansions=tree_expansions,
|
|
201
|
+
tree_expansion=tree_expansion,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
return tree_expansions
|
|
205
|
+
|
|
206
|
+
def print_some_stats(
|
|
207
|
+
self,
|
|
208
|
+
tree: trees.Tree[FamilyT],
|
|
209
|
+
) -> None:
|
|
210
|
+
"""
|
|
211
|
+
Print some statistics about the tree.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
tree: The tree object.
|
|
215
|
+
"""
|
|
216
|
+
print(
|
|
217
|
+
"Tree stats: move_count",
|
|
218
|
+
tree.move_count,
|
|
219
|
+
" node_count",
|
|
220
|
+
tree.descendants.get_count(),
|
|
221
|
+
)
|
|
222
|
+
sum_ = 0
|
|
223
|
+
tree.descendants.print_stats()
|
|
224
|
+
tree_depth: TreeDepth
|
|
225
|
+
for tree_depth in tree.descendants:
|
|
226
|
+
sum_ += len(tree.descendants[tree_depth])
|
|
227
|
+
print("tree_depth", tree_depth, len(tree.descendants[tree_depth]), sum_)
|
|
228
|
+
|
|
229
|
+
def test_count(
|
|
230
|
+
self,
|
|
231
|
+
tree: trees.Tree[FamilyT],
|
|
232
|
+
) -> None:
|
|
233
|
+
"""
|
|
234
|
+
Test the count of nodes in the tree.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
tree: The tree object.
|
|
238
|
+
"""
|
|
239
|
+
assert tree.descendants.get_count() == tree.nodes_count
|
|
240
|
+
|
|
241
|
+
def print_best_line(
|
|
242
|
+
self,
|
|
243
|
+
tree: trees.Tree[FamilyT],
|
|
244
|
+
) -> None:
|
|
245
|
+
"""
|
|
246
|
+
Print the best line in the tree.
|
|
247
|
+
|
|
248
|
+
Args:
|
|
249
|
+
tree: The tree object.
|
|
250
|
+
"""
|
|
251
|
+
|
|
252
|
+
raise NotImplementedError(
|
|
253
|
+
"print_best_line should not be called; override or modify this behavior"
|
|
254
|
+
)
|