algorhino-anemone 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
- algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
- algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
- algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
- algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
- anemone/__init__.py +27 -0
- anemone/basics.py +36 -0
- anemone/factory.py +161 -0
- anemone/indices/__init__.py +0 -0
- anemone/indices/index_manager/__init__.py +12 -0
- anemone/indices/index_manager/factory.py +50 -0
- anemone/indices/index_manager/node_exploration_manager.py +549 -0
- anemone/indices/node_indices/__init__.py +22 -0
- anemone/indices/node_indices/factory.py +121 -0
- anemone/indices/node_indices/index_data.py +166 -0
- anemone/indices/node_indices/index_types.py +20 -0
- anemone/nn/torch_evaluator.py +108 -0
- anemone/node_evaluation/__init__.py +0 -0
- anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
- anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
- anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
- anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
- anemone/node_factory/__init__.py +14 -0
- anemone/node_factory/algorithm_node_factory.py +123 -0
- anemone/node_factory/base.py +76 -0
- anemone/node_selector/__init__.py +32 -0
- anemone/node_selector/branch_explorer.py +89 -0
- anemone/node_selector/factory.py +65 -0
- anemone/node_selector/node_selector.py +44 -0
- anemone/node_selector/node_selector_args.py +22 -0
- anemone/node_selector/node_selector_types.py +15 -0
- anemone/node_selector/notations_and_statics.py +88 -0
- anemone/node_selector/opening_instructions.py +249 -0
- anemone/node_selector/recurzipf/__init__.py +0 -0
- anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
- anemone/node_selector/sequool/__init__.py +19 -0
- anemone/node_selector/sequool/factory.py +102 -0
- anemone/node_selector/sequool/sequool.py +395 -0
- anemone/node_selector/uniform/__init__.py +16 -0
- anemone/node_selector/uniform/uniform.py +113 -0
- anemone/nodes/__init__.py +15 -0
- anemone/nodes/algorithm_node/__init__.py +7 -0
- anemone/nodes/algorithm_node/algorithm_node.py +204 -0
- anemone/nodes/itree_node.py +136 -0
- anemone/nodes/tree_node.py +240 -0
- anemone/nodes/tree_traversal.py +108 -0
- anemone/nodes/utils.py +146 -0
- anemone/progress_monitor/__init__.py +0 -0
- anemone/progress_monitor/progress_monitor.py +375 -0
- anemone/recommender_rule/__init__.py +12 -0
- anemone/recommender_rule/recommender_rule.py +140 -0
- anemone/search_factory/__init__.py +14 -0
- anemone/search_factory/search_factory.py +192 -0
- anemone/state_transition.py +47 -0
- anemone/tree_and_value_branch_selector.py +99 -0
- anemone/tree_exploration.py +274 -0
- anemone/tree_manager/__init__.py +29 -0
- anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
- anemone/tree_manager/factory.py +77 -0
- anemone/tree_manager/tree_expander.py +122 -0
- anemone/tree_manager/tree_manager.py +254 -0
- anemone/trees/__init__.py +14 -0
- anemone/trees/descendants.py +765 -0
- anemone/trees/factory.py +80 -0
- anemone/trees/tree.py +70 -0
- anemone/trees/tree_visualization.py +143 -0
- anemone/updates/__init__.py +33 -0
- anemone/updates/algorithm_node_updater.py +157 -0
- anemone/updates/factory.py +36 -0
- anemone/updates/index_block.py +91 -0
- anemone/updates/index_updater.py +100 -0
- anemone/updates/minmax_evaluation_updater.py +108 -0
- anemone/updates/updates_file.py +248 -0
- anemone/updates/value_block.py +133 -0
- anemone/utils/comparable.py +32 -0
- anemone/utils/dataclass.py +64 -0
- anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
- anemone/utils/logger.py +94 -0
- anemone/utils/my_value_sorted_dict.py +27 -0
- anemone/utils/small_tools.py +103 -0
anemone/trees/factory.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ValueTreeFactory
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from valanga import State
|
|
6
|
+
|
|
7
|
+
from anemone import node_factory as nod_fac
|
|
8
|
+
from anemone.node_evaluation.node_direct_evaluation.node_direct_evaluator import (
|
|
9
|
+
EvaluationQueries,
|
|
10
|
+
NodeDirectEvaluator,
|
|
11
|
+
)
|
|
12
|
+
from anemone.nodes.algorithm_node.algorithm_node import AlgorithmNode
|
|
13
|
+
from anemone.trees.tree import Tree
|
|
14
|
+
|
|
15
|
+
from .descendants import RangedDescendants
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ValueTreeFactory[StateT: State = State]:
|
|
19
|
+
"""
|
|
20
|
+
ValueTreeFactory
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
node_factory: nod_fac.AlgorithmNodeFactory[StateT]
|
|
24
|
+
node_direct_evaluator: NodeDirectEvaluator[StateT]
|
|
25
|
+
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
node_factory: nod_fac.AlgorithmNodeFactory[StateT],
|
|
29
|
+
node_direct_evaluator: NodeDirectEvaluator[StateT],
|
|
30
|
+
) -> None:
|
|
31
|
+
"""
|
|
32
|
+
creates the tree factory
|
|
33
|
+
Args:
|
|
34
|
+
node_factory:
|
|
35
|
+
node_evaluator:
|
|
36
|
+
"""
|
|
37
|
+
self.node_factory = node_factory
|
|
38
|
+
self.node_direct_evaluator = node_direct_evaluator
|
|
39
|
+
|
|
40
|
+
def create(self, starting_state: StateT) -> Tree[AlgorithmNode[StateT]]:
|
|
41
|
+
"""
|
|
42
|
+
creates the tree
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
starting_state: the starting position
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
root_node: AlgorithmNode[StateT] = self.node_factory.create(
|
|
52
|
+
state=starting_state,
|
|
53
|
+
tree_depth=0, # by default
|
|
54
|
+
count=0,
|
|
55
|
+
parent_node=None,
|
|
56
|
+
modifications=None,
|
|
57
|
+
branch_from_parent=None,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
evaluation_queries: EvaluationQueries[StateT] = EvaluationQueries()
|
|
61
|
+
|
|
62
|
+
self.node_direct_evaluator.add_evaluation_query(
|
|
63
|
+
node=root_node, evaluation_queries=evaluation_queries
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
self.node_direct_evaluator.evaluate_all_queried_nodes(
|
|
67
|
+
evaluation_queries=evaluation_queries
|
|
68
|
+
)
|
|
69
|
+
# is this needed? used outside?
|
|
70
|
+
|
|
71
|
+
descendants: RangedDescendants[AlgorithmNode[StateT]] = RangedDescendants[
|
|
72
|
+
AlgorithmNode[StateT]
|
|
73
|
+
]()
|
|
74
|
+
descendants.add_descendant(root_node)
|
|
75
|
+
|
|
76
|
+
value_tree: Tree[AlgorithmNode[StateT]] = Tree[AlgorithmNode[StateT]](
|
|
77
|
+
root_node=root_node, descendants=descendants
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
return value_tree
|
anemone/trees/tree.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tree
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from anemone.basics import TreeDepth
|
|
8
|
+
from anemone.nodes.itree_node import ITreeNode
|
|
9
|
+
|
|
10
|
+
from .descendants import RangedDescendants
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Tree[NodeT: ITreeNode[Any]]:
|
|
14
|
+
"""
|
|
15
|
+
This class defines the Tree that is built out of all the combinations of moves given a starting board position.
|
|
16
|
+
The root node contains the starting board.
|
|
17
|
+
Each node contains a board and has as many children node as there are legal move in the board.
|
|
18
|
+
A children node then contains the board that is obtained by playing a particular moves in the board of the parent
|
|
19
|
+
node.
|
|
20
|
+
|
|
21
|
+
It is a pointer to the root node with some counters and keeping track of descendants.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
_root_node: NodeT
|
|
25
|
+
descendants: RangedDescendants[NodeT]
|
|
26
|
+
tree_root_tree_depth: TreeDepth
|
|
27
|
+
|
|
28
|
+
def __init__(self, root_node: NodeT, descendants: RangedDescendants[NodeT]) -> None:
|
|
29
|
+
"""
|
|
30
|
+
Initialize the Tree with a root node and descendants.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
root_node: The root node of the tree.
|
|
34
|
+
descendants: The descendants collection for the tree.
|
|
35
|
+
"""
|
|
36
|
+
self.tree_root_tree_depth = root_node.tree_depth
|
|
37
|
+
|
|
38
|
+
# number of nodes in the tree (already one as we have the root node provided)
|
|
39
|
+
self.nodes_count = 1
|
|
40
|
+
|
|
41
|
+
# integer counting the number of moves in the tree.
|
|
42
|
+
# the interest of self.move_count over the number of nodes in the descendants
|
|
43
|
+
# is that is always increasing at each opening,
|
|
44
|
+
# while self.node_count can stay the same if the nodes already existed.
|
|
45
|
+
self.move_count = 0
|
|
46
|
+
|
|
47
|
+
self._root_node = root_node
|
|
48
|
+
self.descendants = descendants
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def root_node(self) -> NodeT:
|
|
52
|
+
"""
|
|
53
|
+
Returns the root node of the tree.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
NodeT: The root node of the tree.
|
|
57
|
+
"""
|
|
58
|
+
return self._root_node
|
|
59
|
+
|
|
60
|
+
def node_depth(self, node: NodeT) -> int:
|
|
61
|
+
"""
|
|
62
|
+
Calculates the depth of a given node in the tree.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
node: The node for which to calculate the depth.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
int: The depth of the node.
|
|
69
|
+
"""
|
|
70
|
+
return node.tree_depth - self.tree_root_tree_depth
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides functions for visualizing and saving tree structures.
|
|
3
|
+
|
|
4
|
+
The functions in this module allow for the visualization of tree structures using the Graphviz library.
|
|
5
|
+
It provides a way to display the tree structure as a graph and save it as a PDF file.
|
|
6
|
+
Additionally, it provides a function to save the raw data of the tree structure to a file using pickle.
|
|
7
|
+
|
|
8
|
+
Functions:
|
|
9
|
+
- add_dot(dot: Digraph, treenode: ITreeNode) -> None: Adds nodes and edges to the graph representation of the tree.
|
|
10
|
+
- display_special(node: AlgorithmNode, format_str: str, index: dict[BranchKey, str]) -> Digraph: Displays a special
|
|
11
|
+
representation of the tree with additional information.
|
|
12
|
+
- display(tree: Tree, format_str: str) -> Digraph: Displays the tree structure as a graph.
|
|
13
|
+
- save_pdf_to_file(tree: Tree) -> None: Saves the tree structure as a PDF file.
|
|
14
|
+
- save_raw_data_to_file(tree: Tree, count: str = '#') -> None: Saves the raw data of the tree
|
|
15
|
+
structure to a file.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
|
19
|
+
|
|
20
|
+
import sys
|
|
21
|
+
from pickle import dump
|
|
22
|
+
|
|
23
|
+
from graphviz import Digraph
|
|
24
|
+
from valanga import BranchKey, State
|
|
25
|
+
|
|
26
|
+
from anemone.nodes import ITreeNode
|
|
27
|
+
from anemone.nodes.algorithm_node.algorithm_node import (
|
|
28
|
+
AlgorithmNode,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from .tree import Tree
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def add_dot[StateT: State](dot: Digraph, treenode: ITreeNode[StateT]) -> None:
|
|
35
|
+
"""
|
|
36
|
+
Adds a node and edges to the given Dot graph based on the provided tree node.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
dot (Digraph): The Dot graph to add the node and edges to.
|
|
40
|
+
treenode (ITreeNode): The tree node to visualize.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
None
|
|
44
|
+
"""
|
|
45
|
+
nd = treenode.dot_description()
|
|
46
|
+
dot.node(str(treenode.id), nd)
|
|
47
|
+
branch: BranchKey
|
|
48
|
+
for _, branch in enumerate(treenode.branches_children):
|
|
49
|
+
if treenode.branches_children[branch] is not None:
|
|
50
|
+
child = treenode.branches_children[branch]
|
|
51
|
+
if child is not None:
|
|
52
|
+
cdd = str(child.id)
|
|
53
|
+
dot.edge(
|
|
54
|
+
str(treenode.id),
|
|
55
|
+
cdd,
|
|
56
|
+
str(treenode.state.branch_name_from_key(key=branch)),
|
|
57
|
+
)
|
|
58
|
+
add_dot(dot, child)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def display_special[StateT: State](
|
|
62
|
+
node: AlgorithmNode[StateT], # or AlgorithmNode if you prefer
|
|
63
|
+
format_str: str,
|
|
64
|
+
index: dict[BranchKey, str],
|
|
65
|
+
) -> Digraph:
|
|
66
|
+
"""Display a tree with custom edge labels for the given node."""
|
|
67
|
+
dot = Digraph(format=format_str)
|
|
68
|
+
|
|
69
|
+
nd = node.dot_description()
|
|
70
|
+
dot.node(str(node.id), nd)
|
|
71
|
+
|
|
72
|
+
sorted_branches: list[BranchKey] = sorted(node.branches_children.keys(), key=str)
|
|
73
|
+
|
|
74
|
+
for branch_key in sorted_branches:
|
|
75
|
+
child = node.branches_children[branch_key]
|
|
76
|
+
if child is None:
|
|
77
|
+
continue
|
|
78
|
+
|
|
79
|
+
edge_description: str = (
|
|
80
|
+
index[branch_key]
|
|
81
|
+
+ "|"
|
|
82
|
+
+ str(node.state.branch_name_from_key(key=branch_key))
|
|
83
|
+
+ "|"
|
|
84
|
+
+ node.tree_evaluation.description_tree_visualizer_branch(child)
|
|
85
|
+
)
|
|
86
|
+
dot.edge(str(node.id), str(child.id), edge_description)
|
|
87
|
+
dot.node(str(child.id), child.dot_description())
|
|
88
|
+
print("--move:", edge_description)
|
|
89
|
+
print("--child:", child.dot_description())
|
|
90
|
+
|
|
91
|
+
return dot
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def display[StateT: State](
|
|
95
|
+
tree: Tree[AlgorithmNode[StateT]], format_str: str
|
|
96
|
+
) -> Digraph:
|
|
97
|
+
"""
|
|
98
|
+
Display the move and value tree using graph visualization.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
tree (Tree): The move and value tree to be displayed.
|
|
102
|
+
format_str (str): The format of the output graph (e.g., 'png', 'pdf', 'svg').
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
Digraph: The graph representation of the move and value tree.
|
|
106
|
+
"""
|
|
107
|
+
dot = Digraph(format=format_str)
|
|
108
|
+
add_dot(dot, tree.root_node)
|
|
109
|
+
return dot
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def save_pdf_to_file[StateT: State](tree: Tree[AlgorithmNode[StateT]]) -> None:
|
|
113
|
+
"""
|
|
114
|
+
Saves the visualization of a tree as a PDF file.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
tree (Tree): The tree to be visualized and saved.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
None
|
|
121
|
+
"""
|
|
122
|
+
dot = display(tree=tree, format_str="pdf")
|
|
123
|
+
tag_ = tree.root_node.state.tag
|
|
124
|
+
dot.render("chipiron/runs/treedisplays/TreeVisual_" + str(tag_) + ".pdf")
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def save_raw_data_to_file(tree: Tree[AlgorithmNode], count: str = "#") -> None:
|
|
128
|
+
"""
|
|
129
|
+
Save raw data of a ValueTree to a file.
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
tree (Tree): The Tree object to save.
|
|
133
|
+
count (str, optional): A string to append to the filename. Defaults to '#'.
|
|
134
|
+
|
|
135
|
+
Returns:
|
|
136
|
+
None
|
|
137
|
+
"""
|
|
138
|
+
tag_ = tree.root_node.state.tag
|
|
139
|
+
filename = "chipiron/debugTreeData_" + str(tag_) + "-" + str(count) + ".td"
|
|
140
|
+
|
|
141
|
+
sys.setrecursionlimit(100000)
|
|
142
|
+
with open(filename, "wb") as f:
|
|
143
|
+
dump([tree.descendants, tree.root_node], f)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides classes and functions for updating tree values in the move selector.
|
|
3
|
+
|
|
4
|
+
Classes:
|
|
5
|
+
- AlgorithmNodeUpdater: A class for updating algorithm nodes in the tree.
|
|
6
|
+
- MinMaxEvaluationUpdater: A class for updating min-max evaluation values in the tree.
|
|
7
|
+
|
|
8
|
+
Functions:
|
|
9
|
+
- create_algorithm_node_updater: A function for creating an instance of AlgorithmNodeUpdater.
|
|
10
|
+
|
|
11
|
+
Other:
|
|
12
|
+
- UpdateInstructionsFromOneNode: A class representing the update instructions sent by a single node.
|
|
13
|
+
- UpdateInstructionsTowardsOneParentNode / UpdateInstructionsTowardsMultipleNodes: Classes gathering update instructions towards one parent node or towards several nodes.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from .algorithm_node_updater import AlgorithmNodeUpdater
|
|
18
|
+
from .factory import create_algorithm_node_updater
|
|
19
|
+
from .minmax_evaluation_updater import MinMaxEvaluationUpdater
|
|
20
|
+
from .updates_file import (
|
|
21
|
+
UpdateInstructionsFromOneNode,
|
|
22
|
+
UpdateInstructionsTowardsMultipleNodes,
|
|
23
|
+
UpdateInstructionsTowardsOneParentNode,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Public API of this package: every name below is re-exported from one of
# the submodules imported at the top of this file.
__all__ = [
    "create_algorithm_node_updater",
    "AlgorithmNodeUpdater",
    "UpdateInstructionsFromOneNode",
    "UpdateInstructionsTowardsOneParentNode",
    "MinMaxEvaluationUpdater",
    "UpdateInstructionsTowardsMultipleNodes",
]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the AlgorithmNodeUpdater class, which is responsible for updating AlgorithmNode objects in a
|
|
3
|
+
tree structure.
|
|
4
|
+
|
|
5
|
+
The AlgorithmNodeUpdater class provides methods for creating update instructions after a node is added to the
|
|
6
|
+
tree, generating update instructions for a batch of tree expansions, and performing updates on a specific node
|
|
7
|
+
based on the given update instructions.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
from anemone.nodes.algorithm_node.algorithm_node import (
|
|
14
|
+
AlgorithmNode,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from .index_updater import IndexUpdater
|
|
18
|
+
from .minmax_evaluation_updater import MinMaxEvaluationUpdater
|
|
19
|
+
from .updates_file import (
|
|
20
|
+
UpdateInstructionsFromOneNode,
|
|
21
|
+
UpdateInstructionsTowardsMultipleNodes,
|
|
22
|
+
UpdateInstructionsTowardsOneParentNode,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from anemone.tree_manager.tree_expander import TreeExpansion, TreeExpansions
|
|
27
|
+
|
|
28
|
+
from .index_block import IndexUpdateInstructionsFromOneNode
|
|
29
|
+
from .value_block import ValueUpdateInstructionsFromOneNode
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
|
|
33
|
+
class AlgorithmNodeUpdater:
|
|
34
|
+
"""
|
|
35
|
+
The AlgorithmNodeUpdater class is responsible for updating AlgorithmNode objects in a tree.
|
|
36
|
+
|
|
37
|
+
Attributes:
|
|
38
|
+
minmax_evaluation_updater (MinMaxEvaluationUpdater): The updater for min-max evaluation values.
|
|
39
|
+
index_updater (IndexUpdater | None): The updater for node indices, if available.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
minmax_evaluation_updater: MinMaxEvaluationUpdater
|
|
43
|
+
index_updater: IndexUpdater | None = None
|
|
44
|
+
|
|
45
|
+
def create_update_instructions_after_node_birth(
|
|
46
|
+
self, new_node: AlgorithmNode
|
|
47
|
+
) -> UpdateInstructionsFromOneNode:
|
|
48
|
+
"""
|
|
49
|
+
Creates update instructions after a new node is added to the tree.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
new_node (AlgorithmNode): The newly added AlgorithmNode.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
UpdateInstructions: The update instructions for the new node.
|
|
56
|
+
"""
|
|
57
|
+
value_update_instructions: ValueUpdateInstructionsFromOneNode
|
|
58
|
+
value_update_instructions = (
|
|
59
|
+
self.minmax_evaluation_updater.create_update_instructions_after_node_birth(
|
|
60
|
+
new_node=new_node
|
|
61
|
+
)
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
index_update_instructions: IndexUpdateInstructionsFromOneNode | None
|
|
65
|
+
if self.index_updater is not None:
|
|
66
|
+
index_update_instructions = (
|
|
67
|
+
self.index_updater.create_update_instructions_after_node_birth(
|
|
68
|
+
new_node=new_node
|
|
69
|
+
)
|
|
70
|
+
)
|
|
71
|
+
else:
|
|
72
|
+
index_update_instructions = None
|
|
73
|
+
|
|
74
|
+
update_instructions: UpdateInstructionsFromOneNode = (
|
|
75
|
+
UpdateInstructionsFromOneNode(
|
|
76
|
+
value_block=value_update_instructions,
|
|
77
|
+
index_block=index_update_instructions,
|
|
78
|
+
)
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
return update_instructions
|
|
82
|
+
|
|
83
|
+
def generate_update_instructions[NodeT: AlgorithmNode](
|
|
84
|
+
self, tree_expansions: "TreeExpansions[NodeT]"
|
|
85
|
+
) -> "UpdateInstructionsTowardsMultipleNodes[NodeT]":
|
|
86
|
+
"""
|
|
87
|
+
Generates update instructions for a batch of tree expansions.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
tree_expansions (tree_man.TreeExpansions): The batch of tree expansions.
|
|
91
|
+
|
|
92
|
+
Returns:
|
|
93
|
+
UpdateInstructionsBatch: The update instructions for the batch of tree expansions.
|
|
94
|
+
"""
|
|
95
|
+
# TODO is the way of merging now overkill?
|
|
96
|
+
|
|
97
|
+
update_instructions_batch: UpdateInstructionsTowardsMultipleNodes[NodeT]
|
|
98
|
+
update_instructions_batch = UpdateInstructionsTowardsMultipleNodes()
|
|
99
|
+
|
|
100
|
+
tree_expansion: "TreeExpansion[NodeT]"
|
|
101
|
+
for tree_expansion in tree_expansions:
|
|
102
|
+
update_instructions: UpdateInstructionsFromOneNode = (
|
|
103
|
+
self.create_update_instructions_after_node_birth(
|
|
104
|
+
new_node=tree_expansion.child_node
|
|
105
|
+
)
|
|
106
|
+
)
|
|
107
|
+
# update_instructions_batch is key sorted dict, sorted by depth to ensure proper backprop from the back
|
|
108
|
+
|
|
109
|
+
assert tree_expansion.parent_node is not None
|
|
110
|
+
# looks like we should not update from the root node backward!
|
|
111
|
+
|
|
112
|
+
assert tree_expansion.branch_key is not None
|
|
113
|
+
update_instructions_batch.add_update_from_one_child_node(
|
|
114
|
+
update_from_child_node=update_instructions,
|
|
115
|
+
parent_node=tree_expansion.parent_node,
|
|
116
|
+
branch_from_parent=tree_expansion.branch_key,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
return update_instructions_batch
|
|
120
|
+
|
|
121
|
+
def perform_updates(
|
|
122
|
+
self,
|
|
123
|
+
node_to_update: AlgorithmNode,
|
|
124
|
+
update_instructions: UpdateInstructionsTowardsOneParentNode,
|
|
125
|
+
) -> UpdateInstructionsFromOneNode:
|
|
126
|
+
"""
|
|
127
|
+
Performs updates on a specific node based on the given update instructions.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
node_to_update (AlgorithmNode): The node to update.
|
|
131
|
+
update_instructions (UpdateInstructions): The update instructions for the node.
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
UpdateInstructions: The new update instructions after performing the updates.
|
|
135
|
+
"""
|
|
136
|
+
value_update_instructions: ValueUpdateInstructionsFromOneNode = (
|
|
137
|
+
self.minmax_evaluation_updater.perform_updates(
|
|
138
|
+
node_to_update, updates_instructions=update_instructions
|
|
139
|
+
)
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
index_update_instructions: IndexUpdateInstructionsFromOneNode | None
|
|
143
|
+
if self.index_updater is not None:
|
|
144
|
+
index_update_instructions = self.index_updater.perform_updates(
|
|
145
|
+
node_to_update, updates_instructions=update_instructions
|
|
146
|
+
)
|
|
147
|
+
else:
|
|
148
|
+
index_update_instructions = None
|
|
149
|
+
|
|
150
|
+
new_update_instructions: UpdateInstructionsFromOneNode = (
|
|
151
|
+
UpdateInstructionsFromOneNode(
|
|
152
|
+
value_block=value_update_instructions,
|
|
153
|
+
index_block=index_update_instructions,
|
|
154
|
+
)
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
return new_update_instructions
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides a factory function to create an instance of AlgorithmNodeUpdater.
|
|
3
|
+
|
|
4
|
+
The AlgorithmNodeUpdater is responsible for updating the algorithm node in a tree structure.
|
|
5
|
+
|
|
6
|
+
The factory function `create_algorithm_node_updater` takes an `index_updater` argument (an IndexUpdater, or None to skip index updates) and returns an instance of AlgorithmNodeUpdater.
|
|
7
|
+
|
|
8
|
+
Example usage:
|
|
9
|
+
algorithm_node_updater = create_algorithm_node_updater(index_updater)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .algorithm_node_updater import AlgorithmNodeUpdater
|
|
13
|
+
from .index_updater import IndexUpdater
|
|
14
|
+
from .minmax_evaluation_updater import MinMaxEvaluationUpdater
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_algorithm_node_updater(
    index_updater: IndexUpdater | None,
) -> AlgorithmNodeUpdater:
    """
    Assemble an AlgorithmNodeUpdater around a fresh MinMaxEvaluationUpdater.

    Args:
        index_updater (IndexUpdater | None): Updater for node indices, or None
            to skip index updates.

    Returns:
        AlgorithmNodeUpdater: The assembled updater.
    """
    return AlgorithmNodeUpdater(
        minmax_evaluation_updater=MinMaxEvaluationUpdater(),
        index_updater=index_updater,
    )
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the IndexUpdateInstructionsFromOneNode and IndexUpdateInstructionsTowardsOneParentNode classes, which represent blocks of update instructions for
|
|
3
|
+
index values in a tree structure.
|
|
4
|
+
|
|
5
|
+
IndexUpdateInstructionsTowardsOneParentNode is a dataclass that contains a set of branch keys representing
|
|
6
|
+
children with updated index values. It provides methods for merging update instructions and printing information
|
|
7
|
+
about the block.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Self
|
|
12
|
+
|
|
13
|
+
from valanga import BranchKey
|
|
14
|
+
|
|
15
|
+
from anemone.nodes.algorithm_node.algorithm_node import (
|
|
16
|
+
AlgorithmNode,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(slots=True)
class IndexUpdateInstructionsFromOneNode:
    """
    Represents a block of index update instructions sent by a single node.

    Attributes:
        node_sending_update (AlgorithmNode): The node sending the update.
        updated_index (bool): Indicates whether the index has been updated.
            When False, IndexUpdateInstructionsTowardsOneParentNode
            .add_update_from_one_child_node does not record the branch.
    """

    # node that produced these instructions
    node_sending_update: AlgorithmNode
    # True when the sending node's index changed
    updated_index: bool
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _new_branches_with_updated_index() -> set[BranchKey]:
    """Build an empty set of branch keys (used as a dataclass default_factory)."""
    fresh_branches: set[BranchKey] = set()
    return fresh_branches
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(slots=True)
class IndexUpdateInstructionsTowardsOneParentNode:
    """
    Represents a block of index update instructions intended for a specific node in the algorithm tree.

    It stores the branch keys, taken from the parent, that lead to children
    whose index value was updated.

    Attributes:
        branches_with_updated_index (set[BranchKey]): Branches from the parent
            node whose child has an updated index.
    """

    branches_with_updated_index: set[BranchKey] = field(
        default_factory=_new_branches_with_updated_index
    )

    def add_update_from_one_child_node(
        self,
        update_from_one_child_node: IndexUpdateInstructionsFromOneNode,
        branch_from_parent_to_child: BranchKey,
    ) -> None:
        """Adds an update from a child node to the parent node.

        The branch is recorded only when the child reports an updated index.

        Args:
            update_from_one_child_node (IndexUpdateInstructionsFromOneNode): The update instructions from the child node.
            branch_from_parent_to_child (BranchKey): The branch leading from the parent to that child.
        """
        if update_from_one_child_node.updated_index:
            self.branches_with_updated_index.add(branch_from_parent_to_child)

    def add_update_toward_one_parent_node(self, another_update: Self) -> None:
        """Merges another block aimed at the same parent node into this one.

        Args:
            another_update (Self): The other block; its set is not modified.
        """
        # in-place union: avoids allocating a brand-new set on every merge
        self.branches_with_updated_index |= another_update.branches_with_updated_index

    def empty(self) -> bool:
        """
        Check if the block is empty.

        Returns:
            bool: True if no branch with an updated index has been recorded, False otherwise.
        """
        return not self.branches_with_updated_index

    def print_info(self) -> None:
        """Prints the set of branches whose child has an updated index."""
        print(self.branches_with_updated_index)
|