algorhino-anemone 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
- algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
- algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
- algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
- algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
- anemone/__init__.py +27 -0
- anemone/basics.py +36 -0
- anemone/factory.py +161 -0
- anemone/indices/__init__.py +0 -0
- anemone/indices/index_manager/__init__.py +12 -0
- anemone/indices/index_manager/factory.py +50 -0
- anemone/indices/index_manager/node_exploration_manager.py +549 -0
- anemone/indices/node_indices/__init__.py +22 -0
- anemone/indices/node_indices/factory.py +121 -0
- anemone/indices/node_indices/index_data.py +166 -0
- anemone/indices/node_indices/index_types.py +20 -0
- anemone/nn/torch_evaluator.py +108 -0
- anemone/node_evaluation/__init__.py +0 -0
- anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
- anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
- anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
- anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
- anemone/node_factory/__init__.py +14 -0
- anemone/node_factory/algorithm_node_factory.py +123 -0
- anemone/node_factory/base.py +76 -0
- anemone/node_selector/__init__.py +32 -0
- anemone/node_selector/branch_explorer.py +89 -0
- anemone/node_selector/factory.py +65 -0
- anemone/node_selector/node_selector.py +44 -0
- anemone/node_selector/node_selector_args.py +22 -0
- anemone/node_selector/node_selector_types.py +15 -0
- anemone/node_selector/notations_and_statics.py +88 -0
- anemone/node_selector/opening_instructions.py +249 -0
- anemone/node_selector/recurzipf/__init__.py +0 -0
- anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
- anemone/node_selector/sequool/__init__.py +19 -0
- anemone/node_selector/sequool/factory.py +102 -0
- anemone/node_selector/sequool/sequool.py +395 -0
- anemone/node_selector/uniform/__init__.py +16 -0
- anemone/node_selector/uniform/uniform.py +113 -0
- anemone/nodes/__init__.py +15 -0
- anemone/nodes/algorithm_node/__init__.py +7 -0
- anemone/nodes/algorithm_node/algorithm_node.py +204 -0
- anemone/nodes/itree_node.py +136 -0
- anemone/nodes/tree_node.py +240 -0
- anemone/nodes/tree_traversal.py +108 -0
- anemone/nodes/utils.py +146 -0
- anemone/progress_monitor/__init__.py +0 -0
- anemone/progress_monitor/progress_monitor.py +375 -0
- anemone/recommender_rule/__init__.py +12 -0
- anemone/recommender_rule/recommender_rule.py +140 -0
- anemone/search_factory/__init__.py +14 -0
- anemone/search_factory/search_factory.py +192 -0
- anemone/state_transition.py +47 -0
- anemone/tree_and_value_branch_selector.py +99 -0
- anemone/tree_exploration.py +274 -0
- anemone/tree_manager/__init__.py +29 -0
- anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
- anemone/tree_manager/factory.py +77 -0
- anemone/tree_manager/tree_expander.py +122 -0
- anemone/tree_manager/tree_manager.py +254 -0
- anemone/trees/__init__.py +14 -0
- anemone/trees/descendants.py +765 -0
- anemone/trees/factory.py +80 -0
- anemone/trees/tree.py +70 -0
- anemone/trees/tree_visualization.py +143 -0
- anemone/updates/__init__.py +33 -0
- anemone/updates/algorithm_node_updater.py +157 -0
- anemone/updates/factory.py +36 -0
- anemone/updates/index_block.py +91 -0
- anemone/updates/index_updater.py +100 -0
- anemone/updates/minmax_evaluation_updater.py +108 -0
- anemone/updates/updates_file.py +248 -0
- anemone/updates/value_block.py +133 -0
- anemone/utils/comparable.py +32 -0
- anemone/utils/dataclass.py +64 -0
- anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
- anemone/utils/logger.py +94 -0
- anemone/utils/my_value_sorted_dict.py +27 -0
- anemone/utils/small_tools.py +103 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Protocol, Self
|
|
2
|
+
|
|
3
|
+
from valanga import (
|
|
4
|
+
BoardEvaluation,
|
|
5
|
+
BranchKey,
|
|
6
|
+
OverEvent,
|
|
7
|
+
State,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Sort key stored for each branch in NodeTreeEvaluation.branches_sorted_by_value_.
# NOTE(review): the meaning of the three components (a float value followed by
# two integers, presumably tie-breakers) is defined by the concrete evaluation
# implementation — confirm there before relying on it.
type BranchSortValue = tuple[float, int, int]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from anemone.nodes.itree_node import ITreeNode
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class NodeTreeEvaluation[StateT: State = State](Protocol):
|
|
18
|
+
"""
|
|
19
|
+
Interface for Node Tree Evaluation
|
|
20
|
+
This is the evaluation of a node that is based both on a direct evaluation of the state within and of the NodeTreeEvaluation
|
|
21
|
+
and its children.
|
|
22
|
+
The direct evaluation is used to evaluate leaf nodes, while the children evaluations are used to propagate values up the tree.
|
|
23
|
+
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
# absolute value wrt to white player as estimated by a state evaluator
|
|
27
|
+
value_white_direct_evaluation: float | None = None
|
|
28
|
+
|
|
29
|
+
# creating a base Over event that is set to None
|
|
30
|
+
over_event: OverEvent
|
|
31
|
+
|
|
32
|
+
# the list of branches that have not yet be found to be over
|
|
33
|
+
# using atm a list instead of set as atm python set are not insertion ordered which adds randomness
|
|
34
|
+
# and makes debug harder
|
|
35
|
+
branches_not_over: list[BranchKey]
|
|
36
|
+
|
|
37
|
+
branches_sorted_by_value_: dict[BranchKey, BranchSortValue]
|
|
38
|
+
|
|
39
|
+
best_branch_sequence: list[BranchKey]
|
|
40
|
+
|
|
41
|
+
# absolute value wrt to white player as computed from the value_white_* of the descendants
|
|
42
|
+
# of this node (self) by a minmax procedure.
|
|
43
|
+
value_white_minmax: float | None = None
|
|
44
|
+
|
|
45
|
+
def set_evaluation(self, evaluation: float) -> None:
|
|
46
|
+
"""sets the evaluation from the board evaluator
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
evaluation (float): The evaluation value to be set.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
None
|
|
53
|
+
"""
|
|
54
|
+
...
|
|
55
|
+
|
|
56
|
+
def is_over(self) -> bool:
|
|
57
|
+
"""
|
|
58
|
+
Checks if the game is over.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
bool: True if the game is over, False otherwise.
|
|
62
|
+
"""
|
|
63
|
+
...
|
|
64
|
+
|
|
65
|
+
def dot_description(self) -> str:
|
|
66
|
+
"""
|
|
67
|
+
Returns a string representation of the node's description in DOT format.
|
|
68
|
+
|
|
69
|
+
The description includes the values of `value_white_minmax` and `value_white_evaluator`,
|
|
70
|
+
as well as the best branch sequence and the over event tag.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
A string representation of the node's description in DOT format.
|
|
74
|
+
"""
|
|
75
|
+
...
|
|
76
|
+
|
|
77
|
+
def update_best_branch_sequence(
|
|
78
|
+
self, branches_with_updated_best_branch_seq: set[BranchKey]
|
|
79
|
+
) -> bool:
|
|
80
|
+
"""Update the best branch sequence from updated branches."""
|
|
81
|
+
...
|
|
82
|
+
|
|
83
|
+
def minmax_value_update_from_children(
|
|
84
|
+
self, branches_with_updated_value: set[BranchKey]
|
|
85
|
+
) -> tuple[bool, bool]:
|
|
86
|
+
"""Update minmax value from children and return update flags."""
|
|
87
|
+
...
|
|
88
|
+
|
|
89
|
+
def update_over(self, branches_with_updated_over: set[BranchKey]) -> bool:
|
|
90
|
+
"""Update terminal state based on updated branches."""
|
|
91
|
+
...
|
|
92
|
+
|
|
93
|
+
def evaluate(self) -> BoardEvaluation:
|
|
94
|
+
"""Return a board evaluation for this node."""
|
|
95
|
+
...
|
|
96
|
+
|
|
97
|
+
def description_tree_visualizer_branch(self, child: "ITreeNode[StateT]") -> str:
|
|
98
|
+
"""Return a visualization label for a child branch."""
|
|
99
|
+
...
|
|
100
|
+
|
|
101
|
+
def print_best_line(self) -> None:
|
|
102
|
+
"""Print the current best line."""
|
|
103
|
+
...
|
|
104
|
+
|
|
105
|
+
def get_value_white(self) -> float:
|
|
106
|
+
"""Return the current white evaluation value."""
|
|
107
|
+
...
|
|
108
|
+
|
|
109
|
+
def best_branch(self) -> BranchKey | None:
|
|
110
|
+
"""Return the current best branch key."""
|
|
111
|
+
...
|
|
112
|
+
|
|
113
|
+
def second_best_branch(self) -> BranchKey:
|
|
114
|
+
"""Return the second-best branch key."""
|
|
115
|
+
...
|
|
116
|
+
|
|
117
|
+
def print_branches_sorted_by_value(self) -> None:
|
|
118
|
+
"""Print branches sorted by value."""
|
|
119
|
+
...
|
|
120
|
+
|
|
121
|
+
def print_branches_sorted_by_value_and_exploration(self) -> None:
|
|
122
|
+
"""Print branches sorted by value and exploration metrics."""
|
|
123
|
+
...
|
|
124
|
+
|
|
125
|
+
def get_all_of_the_best_branches(
|
|
126
|
+
self, how_equal: str | None = None
|
|
127
|
+
) -> list[BranchKey]:
|
|
128
|
+
"""Return all best branches according to an equality rule."""
|
|
129
|
+
...
|
|
130
|
+
|
|
131
|
+
def subjective_value_of(self, another_node_eval: Self) -> float:
|
|
132
|
+
"""Return this node's value relative to another evaluation."""
|
|
133
|
+
...
|
|
134
|
+
|
|
135
|
+
def sort_branches_not_over(self) -> list[BranchKey]:
|
|
136
|
+
"""Return branches not over, sorted by evaluation."""
|
|
137
|
+
...
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Any, Protocol
|
|
2
|
+
|
|
3
|
+
from valanga import State, TurnState
|
|
4
|
+
|
|
5
|
+
from anemone.node_evaluation.node_tree_evaluation.node_minmax_evaluation import (
|
|
6
|
+
NodeMinmaxEvaluation,
|
|
7
|
+
)
|
|
8
|
+
from anemone.node_evaluation.node_tree_evaluation.node_tree_evaluation import (
|
|
9
|
+
NodeTreeEvaluation,
|
|
10
|
+
)
|
|
11
|
+
from anemone.nodes.tree_node import TreeNode
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NodeTreeMinmaxEvaluationFactory[StateT: TurnState]:
|
|
15
|
+
"""
|
|
16
|
+
The class creating Node Evaluations including children
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def create(
|
|
20
|
+
self,
|
|
21
|
+
tree_node: TreeNode[Any, StateT],
|
|
22
|
+
) -> NodeMinmaxEvaluation[Any, StateT]:
|
|
23
|
+
"""
|
|
24
|
+
Creates a new NodeEvaluationIncludingChildren object.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
tree_node (TreeNode): The tree node for which the evaluation is created.
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
NodeEvaluationIncludingChildren: The newly created NodeEvaluationIncludingChildren object.
|
|
31
|
+
"""
|
|
32
|
+
return NodeMinmaxEvaluation(tree_node=tree_node)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class NodeTreeEvaluationFactory[StateT2: State = State](Protocol):
|
|
36
|
+
"""The class creating Node Evaluations including children."""
|
|
37
|
+
|
|
38
|
+
def create(
|
|
39
|
+
self,
|
|
40
|
+
tree_node: TreeNode[Any, StateT2],
|
|
41
|
+
) -> NodeTreeEvaluation[StateT2]:
|
|
42
|
+
"""Create a NodeTreeEvaluation instance for the given node."""
|
|
43
|
+
...
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides the node factory classes for creating tree nodes in the move selector algorithm.
|
|
3
|
+
|
|
4
|
+
The available classes in this module are:
|
|
5
|
+
- TreeNodeFactory: A base class for creating tree nodes.
- AlgorithmNodeFactory: A node factory class for the move selector algorithm.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from .algorithm_node_factory import AlgorithmNodeFactory
|
|
12
|
+
from .base import TreeNodeFactory
|
|
13
|
+
|
|
14
|
+
# Public API re-exported by the node_factory package.
__all__ = [
    "TreeNodeFactory",
    "AlgorithmNodeFactory",
]
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
""" "
|
|
2
|
+
AlgorithmNodeFactory
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from valanga import (
|
|
9
|
+
BranchKey,
|
|
10
|
+
ContentRepresentation,
|
|
11
|
+
RepresentationFactory,
|
|
12
|
+
State,
|
|
13
|
+
StateModifications,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from anemone.basics import TreeDepth
|
|
17
|
+
from anemone.indices import node_indices
|
|
18
|
+
from anemone.node_evaluation.node_tree_evaluation.node_tree_evaluation_factory import (
|
|
19
|
+
NodeTreeEvaluationFactory,
|
|
20
|
+
)
|
|
21
|
+
from anemone.node_factory.base import TreeNodeFactory
|
|
22
|
+
from anemone.nodes.algorithm_node.algorithm_node import (
|
|
23
|
+
AlgorithmNode,
|
|
24
|
+
)
|
|
25
|
+
from anemone.nodes.tree_node import TreeNode
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from anemone.node_evaluation.node_tree_evaluation.node_tree_evaluation import (
|
|
29
|
+
NodeTreeEvaluation,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class AlgorithmNodeFactory[StateT: State = State]:
|
|
35
|
+
"""
|
|
36
|
+
The classe creating Algorithm Nodes
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
tree_node_factory: TreeNodeFactory[AlgorithmNode[StateT], StateT]
|
|
40
|
+
state_representation_factory: RepresentationFactory | None
|
|
41
|
+
node_tree_evaluation_factory: NodeTreeEvaluationFactory[StateT]
|
|
42
|
+
exploration_index_data_create: node_indices.ExplorationIndexDataFactory[
|
|
43
|
+
AlgorithmNode[StateT], StateT
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
def create_from_tree_node(
|
|
47
|
+
self,
|
|
48
|
+
tree_node: TreeNode[AlgorithmNode[StateT], StateT],
|
|
49
|
+
parent_node: AlgorithmNode[StateT] | None,
|
|
50
|
+
modifications: StateModifications | None,
|
|
51
|
+
) -> AlgorithmNode[StateT]:
|
|
52
|
+
"""Build an AlgorithmNode from an existing TreeNode."""
|
|
53
|
+
tree_evaluation: NodeTreeEvaluation[StateT] = (
|
|
54
|
+
self.node_tree_evaluation_factory.create(
|
|
55
|
+
tree_node=tree_node,
|
|
56
|
+
)
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
exploration_index_data: (
|
|
60
|
+
node_indices.NodeExplorationData[AlgorithmNode[StateT], StateT] | None
|
|
61
|
+
) = self.exploration_index_data_create(tree_node)
|
|
62
|
+
|
|
63
|
+
state_representation: ContentRepresentation | None = None
|
|
64
|
+
if self.state_representation_factory is not None:
|
|
65
|
+
if parent_node is not None:
|
|
66
|
+
parent_node_representation = parent_node.state_representation
|
|
67
|
+
else:
|
|
68
|
+
parent_node_representation = None
|
|
69
|
+
|
|
70
|
+
state_representation = (
|
|
71
|
+
self.state_representation_factory.create_from_transition(
|
|
72
|
+
state=tree_node.state,
|
|
73
|
+
previous_state_representation=parent_node_representation,
|
|
74
|
+
modifications=modifications,
|
|
75
|
+
)
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
return AlgorithmNode(
|
|
79
|
+
tree_node=tree_node,
|
|
80
|
+
tree_evaluation=tree_evaluation,
|
|
81
|
+
exploration_index_data=exploration_index_data,
|
|
82
|
+
state_representation=state_representation,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
def create(
|
|
86
|
+
self,
|
|
87
|
+
state: StateT,
|
|
88
|
+
tree_depth: TreeDepth,
|
|
89
|
+
count: int,
|
|
90
|
+
parent_node: AlgorithmNode[StateT] | None,
|
|
91
|
+
branch_from_parent: BranchKey | None,
|
|
92
|
+
modifications: StateModifications | None,
|
|
93
|
+
) -> AlgorithmNode[StateT]:
|
|
94
|
+
"""
|
|
95
|
+
Creates an AlgorithmNode object.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
branch_from_parent (BranchKey | None): the move that led to the node from the parent node
|
|
99
|
+
state: The state object.
|
|
100
|
+
tree_depth: The tree depth.
|
|
101
|
+
count: The count.
|
|
102
|
+
parent_node: The parent node object.
|
|
103
|
+
modifications: The board modifications object.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
An AlgorithmNode object.
|
|
107
|
+
|
|
108
|
+
"""
|
|
109
|
+
tree_node: TreeNode[AlgorithmNode[StateT], StateT] = (
|
|
110
|
+
self.tree_node_factory.create(
|
|
111
|
+
state=state,
|
|
112
|
+
tree_depth=tree_depth,
|
|
113
|
+
count=count,
|
|
114
|
+
branch_from_parent=branch_from_parent,
|
|
115
|
+
parent_node=parent_node,
|
|
116
|
+
)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
return self.create_from_tree_node(
|
|
120
|
+
tree_node=tree_node,
|
|
121
|
+
parent_node=parent_node,
|
|
122
|
+
modifications=modifications,
|
|
123
|
+
)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Basic class for Creating Tree nodes
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any, Protocol
|
|
6
|
+
|
|
7
|
+
from valanga import BranchKey, State, StateModifications
|
|
8
|
+
|
|
9
|
+
from anemone.basics import TreeDepth
|
|
10
|
+
from anemone.nodes.itree_node import ITreeNode
|
|
11
|
+
from anemone.nodes.tree_node import TreeNode
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NodeFactory[NodeT: ITreeNode[Any] = ITreeNode[Any]](Protocol):
|
|
15
|
+
"""
|
|
16
|
+
Node Factory
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def create(
|
|
20
|
+
self,
|
|
21
|
+
state: State,
|
|
22
|
+
tree_depth: TreeDepth,
|
|
23
|
+
count: int,
|
|
24
|
+
parent_node: NodeT | None,
|
|
25
|
+
branch_from_parent: BranchKey | None,
|
|
26
|
+
modifications: StateModifications | None,
|
|
27
|
+
) -> NodeT:
|
|
28
|
+
"""Create a node from state and tree metadata."""
|
|
29
|
+
...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TreeNodeFactory[T: ITreeNode[Any] = ITreeNode[Any], StateT: State = State]:
|
|
33
|
+
"""
|
|
34
|
+
Basic class for Creating Tree nodes
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def create(
|
|
38
|
+
self,
|
|
39
|
+
state: StateT,
|
|
40
|
+
tree_depth: TreeDepth,
|
|
41
|
+
count: int,
|
|
42
|
+
parent_node: T | None,
|
|
43
|
+
branch_from_parent: BranchKey | None,
|
|
44
|
+
modifications: StateModifications | None = None,
|
|
45
|
+
) -> TreeNode[T, StateT]:
|
|
46
|
+
"""
|
|
47
|
+
Creates a new TreeNode object.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
board (boards.BoardChi): The current board state.
|
|
51
|
+
half_move (int): The half-move count.
|
|
52
|
+
count (int): The ID of the new node.
|
|
53
|
+
parent_node (ITreeNode | None): The parent node of the new node.
|
|
54
|
+
move_from_parent (chess.Move | None): The move that leads to the new node.
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
TreeNode: The newly created TreeNode object.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
# TreeNode doesn't use modifications (it's a pure data container).
|
|
61
|
+
_ = modifications
|
|
62
|
+
|
|
63
|
+
parent_nodes: dict[T, BranchKey]
|
|
64
|
+
if parent_node is None:
|
|
65
|
+
parent_nodes = {}
|
|
66
|
+
else:
|
|
67
|
+
assert branch_from_parent is not None
|
|
68
|
+
parent_nodes = {parent_node: branch_from_parent}
|
|
69
|
+
|
|
70
|
+
tree_node: TreeNode[T, StateT] = TreeNode[T, StateT](
|
|
71
|
+
state_=state,
|
|
72
|
+
tree_depth_=tree_depth,
|
|
73
|
+
id_=count,
|
|
74
|
+
parent_nodes_=parent_nodes,
|
|
75
|
+
)
|
|
76
|
+
return tree_node
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides classes and functions for selecting nodes in a tree structure.
|
|
3
|
+
|
|
4
|
+
The module includes the following components:
|
|
5
|
+
- `create`: A factory function for creating node selectors.
|
|
6
|
+
- `AllNodeSelectorArgs`: A class that represents all possible arguments for node selectors.
|
|
7
|
+
- `NodeSelector`: A class that represents a node selector.
|
|
8
|
+
- `NodeSelectorArgs`: A class that represents the arguments for a node selector.
|
|
9
|
+
- `NodeSelectorType`: An enumeration of different types of node selectors.
|
|
10
|
+
- `OpeningInstructions`: A class that represents opening instructions for node selectors.
|
|
11
|
+
- `OpeningInstruction`: A class that represents an opening instruction for node selectors.
|
|
12
|
+
- `OpeningType`: An enumeration of different types of opening instructions.
|
|
13
|
+
|
|
14
|
+
To use this module, import it and use the provided classes and functions as needed.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from .factory import AllNodeSelectorArgs, create
|
|
18
|
+
from .node_selector import NodeSelector
|
|
19
|
+
from .node_selector_args import NodeSelectorArgs
|
|
20
|
+
from .node_selector_types import NodeSelectorType
|
|
21
|
+
from .opening_instructions import OpeningInstruction, OpeningInstructions, OpeningType
|
|
22
|
+
|
|
23
|
+
# Public API re-exported by the node_selector package (order preserved).
__all__ = [
    "OpeningInstructions", "OpeningInstruction", "AllNodeSelectorArgs",
    "OpeningType", "NodeSelector", "create",
    "NodeSelectorArgs", "NodeSelectorType",
]
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the BranchExplorer class and its subclasses.
|
|
3
|
+
BranchExplorer is responsible for exploring branches in a game tree.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from random import Random
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from valanga import BranchKey
|
|
11
|
+
|
|
12
|
+
from anemone.node_selector.notations_and_statics import (
|
|
13
|
+
zipf_picks_random,
|
|
14
|
+
)
|
|
15
|
+
from anemone.nodes.algorithm_node import AlgorithmNode
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SamplingPriorities(str, Enum):
    """
    Sampling priorities available to branch explorers.

    Members (string-valued):
        NO_PRIORITY: branches are sampled without any priority.
        PRIORITY_BEST: the best branch is given priority.
        PRIORITY_TWO_BEST: the two best branches are given priority.
    """

    NO_PRIORITY = "no_priority"
    PRIORITY_BEST = "priority_best"
    PRIORITY_TWO_BEST = "priority_two_best"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BranchExplorer:
    """
    Base class for strategies that pick a branch of a game-tree node to explore.

    Attributes:
        priority_sampling: The sampling-priority strategy this explorer was
            configured with.
    """

    priority_sampling: SamplingPriorities

    def __init__(self, priority_sampling: SamplingPriorities):
        """
        Store the sampling-priority strategy.

        Args:
            priority_sampling (SamplingPriorities): The priority sampling strategy to use.
        """
        self.priority_sampling = priority_sampling
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ZipfBranchExplorer(BranchExplorer):
    """
    BranchExplorer that uses the Zipf distribution for sampling.

    The branches of a node that are not over are sorted by the node's tree
    evaluation, and one of them is picked with ``zipf_picks_random``.
    """

    def __init__(
        self, priority_sampling: SamplingPriorities, random_generator: Random
    ) -> None:
        """
        Initialize a ZipfBranchExplorer instance.

        Args:
            priority_sampling (SamplingPriorities): The priority sampling strategy to use.
            random_generator (Random): The random number generator used for sampling.
        """
        super().__init__(priority_sampling)
        self.random_generator = random_generator

    def sample_branch_to_explore(
        self, tree_node_to_sample_from: AlgorithmNode[Any]
    ) -> BranchKey:
        """
        Sample a branch to explore from the given tree node.

        Args:
            tree_node_to_sample_from (AlgorithmNode): The tree node to sample from.

        Returns:
            BranchKey: The key of the sampled branch (not a node).
        """
        # Order matters: zipf_picks_random receives the branches that are not
        # over, in the order given by the node's evaluation.
        sorted_not_over_branches: list[BranchKey] = (
            tree_node_to_sample_from.tree_evaluation.sort_branches_not_over()
        )
        return zipf_picks_random(
            ordered_list_elements=sorted_not_over_branches,
            random_generator=self.random_generator,
        )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Factory to build node selectors
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from random import Random
|
|
7
|
+
from typing import Literal, TypeAlias
|
|
8
|
+
|
|
9
|
+
from .node_selector import NodeSelector
|
|
10
|
+
from .node_selector_types import NodeSelectorType
|
|
11
|
+
from .opening_instructions import OpeningInstructor
|
|
12
|
+
from .recurzipf.recur_zipf_base import RecurZipfBase, RecurZipfBaseArgs
|
|
13
|
+
from .sequool import SequoolArgs, create_sequool
|
|
14
|
+
from .uniform import Uniform
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class UniformArgs:
    """Arguments selecting the Uniform node selector; only the type tag is needed."""

    type: Literal[NodeSelectorType.UNIFORM]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Union of every argument dataclass accepted by `create`; the `type` field of
# the concrete instance selects which node selector gets built.
AllNodeSelectorArgs: TypeAlias = RecurZipfBaseArgs | SequoolArgs | UniformArgs
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create(
|
|
31
|
+
args: AllNodeSelectorArgs,
|
|
32
|
+
opening_instructor: OpeningInstructor,
|
|
33
|
+
random_generator: Random,
|
|
34
|
+
) -> NodeSelector:
|
|
35
|
+
"""
|
|
36
|
+
Creation of a node selector
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
node_move_opening_selector: NodeSelector
|
|
40
|
+
|
|
41
|
+
match args.type:
|
|
42
|
+
case NodeSelectorType.UNIFORM:
|
|
43
|
+
node_move_opening_selector = Uniform(opening_instructor=opening_instructor)
|
|
44
|
+
case NodeSelectorType.RECUR_ZIPF_BASE:
|
|
45
|
+
assert isinstance(args, RecurZipfBaseArgs)
|
|
46
|
+
node_move_opening_selector = RecurZipfBase(
|
|
47
|
+
args=args,
|
|
48
|
+
random_generator=random_generator,
|
|
49
|
+
opening_instructor=opening_instructor,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
case NodeSelectorType.SEQUOOL:
|
|
53
|
+
assert isinstance(args, SequoolArgs)
|
|
54
|
+
node_move_opening_selector = create_sequool(
|
|
55
|
+
opening_instructor=opening_instructor,
|
|
56
|
+
random_generator=random_generator,
|
|
57
|
+
args=args,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
case _:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"node selector construction: can not find {args.type} {args} in file {__name__}"
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
return node_move_opening_selector
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the definition of the NodeSelector class and related types.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
|
7
|
+
|
|
8
|
+
from anemone import trees
|
|
9
|
+
from anemone.nodes.algorithm_node.algorithm_node import AlgorithmNode
|
|
10
|
+
|
|
11
|
+
from .opening_instructions import OpeningInstructions
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from anemone import tree_manager as tree_man
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class NodeSelectorState:
    """Node Selector State: an (currently field-less) placeholder dataclass."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NodeSelector[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]](Protocol):
|
|
25
|
+
"""
|
|
26
|
+
Protocol for Node Selectors.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def choose_node_and_branch_to_open(
|
|
30
|
+
self,
|
|
31
|
+
tree: trees.Tree[NodeT],
|
|
32
|
+
latest_tree_expansions: "tree_man.TreeExpansions[NodeT]",
|
|
33
|
+
) -> OpeningInstructions[NodeT]:
|
|
34
|
+
"""
|
|
35
|
+
Selects a node from the given tree and returns the instructions to move to an open position.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
tree: The tree containing the nodes.
|
|
39
|
+
latest_tree_expansions: The latest expansions of the tree.
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
OpeningInstructions: The instructions to move to an open position.
|
|
43
|
+
"""
|
|
44
|
+
raise NotImplementedError()
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the NodeSelectorArgs class, which represents the arguments for a node selector.
|
|
3
|
+
|
|
4
|
+
The NodeSelectorArgs class is a dataclass that contains a single attribute:
|
|
5
|
+
- type: The type of the node selector, represented by the NodeSelectorType enum.
|
|
6
|
+
|
|
7
|
+
Example usage:
|
|
8
|
+
args = NodeSelectorArgs(type=NodeSelectorType.UNIFORM)
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
from .node_selector_types import NodeSelectorType
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class NodeSelectorArgs:
    """Arguments for a node selector, holding the selector ``type``."""

    type: NodeSelectorType
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the NodeSelectorType enumeration, which represents the types of node selectors.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NodeSelectorType(str, Enum):
    """
    The node-selector kinds that can be built; values are the selector names.
    """

    RECUR_ZIPF_BASE = "RecurZipfBase"
    SEQUOOL = "Sequool"
    UNIFORM = "Uniform"
|