algorhino-anemone 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
- algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
- algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
- algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
- algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
- anemone/__init__.py +27 -0
- anemone/basics.py +36 -0
- anemone/factory.py +161 -0
- anemone/indices/__init__.py +0 -0
- anemone/indices/index_manager/__init__.py +12 -0
- anemone/indices/index_manager/factory.py +50 -0
- anemone/indices/index_manager/node_exploration_manager.py +549 -0
- anemone/indices/node_indices/__init__.py +22 -0
- anemone/indices/node_indices/factory.py +121 -0
- anemone/indices/node_indices/index_data.py +166 -0
- anemone/indices/node_indices/index_types.py +20 -0
- anemone/nn/torch_evaluator.py +108 -0
- anemone/node_evaluation/__init__.py +0 -0
- anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
- anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
- anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
- anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
- anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
- anemone/node_factory/__init__.py +14 -0
- anemone/node_factory/algorithm_node_factory.py +123 -0
- anemone/node_factory/base.py +76 -0
- anemone/node_selector/__init__.py +32 -0
- anemone/node_selector/branch_explorer.py +89 -0
- anemone/node_selector/factory.py +65 -0
- anemone/node_selector/node_selector.py +44 -0
- anemone/node_selector/node_selector_args.py +22 -0
- anemone/node_selector/node_selector_types.py +15 -0
- anemone/node_selector/notations_and_statics.py +88 -0
- anemone/node_selector/opening_instructions.py +249 -0
- anemone/node_selector/recurzipf/__init__.py +0 -0
- anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
- anemone/node_selector/sequool/__init__.py +19 -0
- anemone/node_selector/sequool/factory.py +102 -0
- anemone/node_selector/sequool/sequool.py +395 -0
- anemone/node_selector/uniform/__init__.py +16 -0
- anemone/node_selector/uniform/uniform.py +113 -0
- anemone/nodes/__init__.py +15 -0
- anemone/nodes/algorithm_node/__init__.py +7 -0
- anemone/nodes/algorithm_node/algorithm_node.py +204 -0
- anemone/nodes/itree_node.py +136 -0
- anemone/nodes/tree_node.py +240 -0
- anemone/nodes/tree_traversal.py +108 -0
- anemone/nodes/utils.py +146 -0
- anemone/progress_monitor/__init__.py +0 -0
- anemone/progress_monitor/progress_monitor.py +375 -0
- anemone/recommender_rule/__init__.py +12 -0
- anemone/recommender_rule/recommender_rule.py +140 -0
- anemone/search_factory/__init__.py +14 -0
- anemone/search_factory/search_factory.py +192 -0
- anemone/state_transition.py +47 -0
- anemone/tree_and_value_branch_selector.py +99 -0
- anemone/tree_exploration.py +274 -0
- anemone/tree_manager/__init__.py +29 -0
- anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
- anemone/tree_manager/factory.py +77 -0
- anemone/tree_manager/tree_expander.py +122 -0
- anemone/tree_manager/tree_manager.py +254 -0
- anemone/trees/__init__.py +14 -0
- anemone/trees/descendants.py +765 -0
- anemone/trees/factory.py +80 -0
- anemone/trees/tree.py +70 -0
- anemone/trees/tree_visualization.py +143 -0
- anemone/updates/__init__.py +33 -0
- anemone/updates/algorithm_node_updater.py +157 -0
- anemone/updates/factory.py +36 -0
- anemone/updates/index_block.py +91 -0
- anemone/updates/index_updater.py +100 -0
- anemone/updates/minmax_evaluation_updater.py +108 -0
- anemone/updates/updates_file.py +248 -0
- anemone/updates/value_block.py +133 -0
- anemone/utils/comparable.py +32 -0
- anemone/utils/dataclass.py +64 -0
- anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
- anemone/utils/logger.py +94 -0
- anemone/utils/my_value_sorted_dict.py +27 -0
- anemone/utils/small_tools.py +103 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module that contains the classes for the exploration data of a tree node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from valanga import State
|
|
9
|
+
|
|
10
|
+
from anemone.nodes.itree_node import ITreeNode
|
|
11
|
+
from anemone.nodes.tree_node import TreeNode
|
|
12
|
+
from anemone.utils.small_tools import Interval
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class NodeExplorationData[
|
|
17
|
+
Node: ITreeNode[Any] = ITreeNode[Any],
|
|
18
|
+
StateT: State = State,
|
|
19
|
+
]:
|
|
20
|
+
"""
|
|
21
|
+
Represents the exploration data for a tree node.
|
|
22
|
+
|
|
23
|
+
Attributes:
|
|
24
|
+
tree_node (TreeNode): The tree node associated with the exploration data.
|
|
25
|
+
index (float | None): The index value associated with the node. Defaults to None.
|
|
26
|
+
|
|
27
|
+
Methods:
|
|
28
|
+
dot_description(): Returns a string representation of the exploration data for dot visualization.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
tree_node: TreeNode[Node, StateT]
|
|
32
|
+
index: float | None = None
|
|
33
|
+
|
|
34
|
+
def dot_description(self) -> str:
|
|
35
|
+
"""
|
|
36
|
+
Returns a string representation of the dot description for the index.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
str: The dot description of the index.
|
|
40
|
+
"""
|
|
41
|
+
return f"index:{self.index}"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class RecurZipfQuoolExplorationData[
|
|
46
|
+
Node: ITreeNode[Any] = ITreeNode[Any],
|
|
47
|
+
StateT: State = State,
|
|
48
|
+
](NodeExplorationData[Node, StateT]):
|
|
49
|
+
"""
|
|
50
|
+
Represents the exploration data for a tree node with recursive zipf-quool factor.
|
|
51
|
+
|
|
52
|
+
Attributes:
|
|
53
|
+
zipf_factored_proba (float | None): The probability associated with the node, factored by zipf-quool factor.
|
|
54
|
+
Defaults to None.
|
|
55
|
+
|
|
56
|
+
Methods:
|
|
57
|
+
dot_description(): Returns a string representation of the exploration data for dot visualization.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
# the 'proba' associated by recursively multiplying 1/rank of the node with the max zipf_factor of the parents
|
|
61
|
+
zipf_factored_proba: float | None = None
|
|
62
|
+
|
|
63
|
+
def dot_description(self) -> str:
|
|
64
|
+
"""
|
|
65
|
+
Returns a string representation of the index and zipf_factored_proba values.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
str: A string representation of the index and zipf_factored_proba values.
|
|
69
|
+
"""
|
|
70
|
+
return f"index:{self.index} zipf_factored_proba:{self.zipf_factored_proba}"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
|
|
74
|
+
class MinMaxPathValue[
|
|
75
|
+
Node: ITreeNode[Any] = ITreeNode[Any],
|
|
76
|
+
StateT: State = State,
|
|
77
|
+
](NodeExplorationData[Node, StateT]):
|
|
78
|
+
"""
|
|
79
|
+
Represents the exploration data for a tree node with minimum and maximum path values.
|
|
80
|
+
|
|
81
|
+
Attributes:
|
|
82
|
+
min_path_value (float | None): The minimum path value associated with the node. Defaults to None.
|
|
83
|
+
max_path_value (float | None): The maximum path value associated with the node. Defaults to None.
|
|
84
|
+
|
|
85
|
+
Methods:
|
|
86
|
+
dot_description(): Returns a string representation of the exploration data for dot visualization.
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
min_path_value: float | None = None
|
|
90
|
+
max_path_value: float | None = None
|
|
91
|
+
|
|
92
|
+
def dot_description(self) -> str:
|
|
93
|
+
"""Return a string representation of min/max path values."""
|
|
94
|
+
return f"min_path_value: {self.min_path_value}, max_path_value: {self.max_path_value}"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
|
|
98
|
+
class IntervalExplo[
|
|
99
|
+
Node: ITreeNode[Any] = ITreeNode[Any],
|
|
100
|
+
StateT: State = State,
|
|
101
|
+
](NodeExplorationData[Node, StateT]):
|
|
102
|
+
"""
|
|
103
|
+
Represents the exploration data for a tree node with an interval.
|
|
104
|
+
|
|
105
|
+
Attributes:
|
|
106
|
+
interval (Interval | None): The interval associated with the node. Defaults to None.
|
|
107
|
+
|
|
108
|
+
Methods:
|
|
109
|
+
dot_description(): Returns a string representation of the exploration data for dot visualization.
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
interval: Interval | None = field(default_factory=Interval)
|
|
113
|
+
|
|
114
|
+
def dot_description(self) -> str:
|
|
115
|
+
"""
|
|
116
|
+
Returns a string representation of the interval values.
|
|
117
|
+
|
|
118
|
+
If the interval is None, returns 'None'.
|
|
119
|
+
Otherwise, returns a string in the format 'min_interval_value: {min_value}, max_interval_value: {max_value}'.
|
|
120
|
+
|
|
121
|
+
Returns:
|
|
122
|
+
str: A string representation of the interval values.
|
|
123
|
+
"""
|
|
124
|
+
if self.interval is None:
|
|
125
|
+
return "None"
|
|
126
|
+
return f"min_interval_value: {self.interval.min_value}, max_interval_value: {self.interval.max_value}"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@dataclass
|
|
130
|
+
class MaxDepthDescendants[
|
|
131
|
+
Node: ITreeNode[Any] = ITreeNode[Any],
|
|
132
|
+
StateT: State = State,
|
|
133
|
+
](NodeExplorationData[Node, StateT]):
|
|
134
|
+
"""
|
|
135
|
+
Represents the exploration data for a tree node with maximum depth of descendants.
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
max_depth_descendants: int = 0
|
|
139
|
+
|
|
140
|
+
def update_from_child(self, child_max_depth_descendants: int) -> bool:
|
|
141
|
+
"""
|
|
142
|
+
Updates the max_depth_descendants value based on the child's max_depth_descendants.
|
|
143
|
+
|
|
144
|
+
Args:
|
|
145
|
+
child_max_depth_descendants (int): The max_depth_descendants value of the child node.
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
bool: True if the max_depth_descendants value has changed, False otherwise.
|
|
149
|
+
"""
|
|
150
|
+
previous_index = self.max_depth_descendants
|
|
151
|
+
new_index: int = max(
|
|
152
|
+
self.max_depth_descendants, child_max_depth_descendants + 1
|
|
153
|
+
)
|
|
154
|
+
self.max_depth_descendants = new_index
|
|
155
|
+
has_index_changed: bool = new_index != previous_index
|
|
156
|
+
|
|
157
|
+
return has_index_changed
|
|
158
|
+
|
|
159
|
+
def dot_description(self) -> str:
|
|
160
|
+
"""
|
|
161
|
+
Returns a string representation of the dot description for the node indices.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
str: The dot description for the node indices.
|
|
165
|
+
"""
|
|
166
|
+
return f"max_depth_descendants: {self.max_depth_descendants}"
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the enumeration for index computation types.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class IndexComputationType(str, Enum):
    """
    Enumeration for index computation types.

    Attributes:
        MIN_GLOBAL_CHANGE (str): The minimum global change computation type ("min_global_change").
        MIN_LOCAL_CHANGE (str): The minimum local change computation type ("min_local_change").
        RECUR_ZIPF (str): The recurzipf computation type ("recurzipf").
    """

    MIN_GLOBAL_CHANGE = "min_global_change"
    MIN_LOCAL_CHANGE = "min_local_change"
    RECUR_ZIPF = "recurzipf"
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Torch-based MasterStateEvaluator for efficient batch evaluations.
|
|
3
|
+
"""
|
|
4
|
+
# pyright: reportMissingImports=false
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING, Sequence
|
|
8
|
+
|
|
9
|
+
from valanga import State
|
|
10
|
+
from valanga.evaluations import EvalItem
|
|
11
|
+
|
|
12
|
+
from anemone.node_evaluation.node_direct_evaluation.node_direct_evaluator import (
|
|
13
|
+
MasterStateEvaluator,
|
|
14
|
+
OverEventDetector,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from coral.neural_networks.nn_content_evaluator import NNContentEvaluator
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(slots=True)
|
|
23
|
+
class TorchMasterNNStateEvaluator(MasterStateEvaluator):
|
|
24
|
+
"""
|
|
25
|
+
Torch-backed MasterStateEvaluator that supports efficient batch evaluation.
|
|
26
|
+
This lives in an optional module so anemone core has no torch dependency.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
over: OverEventDetector
|
|
30
|
+
state_evaluator: "NNContentEvaluator"
|
|
31
|
+
device: str = "cpu"
|
|
32
|
+
script: bool = True
|
|
33
|
+
|
|
34
|
+
def __post_init__(self) -> None:
|
|
35
|
+
"""Initialize the torch model and validate dependencies."""
|
|
36
|
+
try:
|
|
37
|
+
import torch
|
|
38
|
+
except ModuleNotFoundError as e:
|
|
39
|
+
raise ModuleNotFoundError(
|
|
40
|
+
"TorchMasterNNStateEvaluator requires 'torch'. "
|
|
41
|
+
"Install the optional torch dependencies."
|
|
42
|
+
) from e
|
|
43
|
+
|
|
44
|
+
self._torch = torch
|
|
45
|
+
|
|
46
|
+
model = self.state_evaluator.net
|
|
47
|
+
self._model = torch.jit.script(model) if self.script else model
|
|
48
|
+
self._model.eval()
|
|
49
|
+
|
|
50
|
+
def value_white(self, state: State) -> float:
|
|
51
|
+
"""Evaluate a single state by delegating to the batch path."""
|
|
52
|
+
# Slow path: evaluate a single state by wrapping it as an EvalItem.
|
|
53
|
+
return self.value_white_batch_items([_SingleEvalItem(state)])[0]
|
|
54
|
+
|
|
55
|
+
def value_white_batch_items[ItemStateT: State](
|
|
56
|
+
self, items: Sequence[EvalItem[ItemStateT]]
|
|
57
|
+
) -> list[float]:
|
|
58
|
+
"""Evaluate a batch of items with torch and return white values."""
|
|
59
|
+
torch = self._torch
|
|
60
|
+
|
|
61
|
+
xs: list["Tensor"] = []
|
|
62
|
+
states: list[ItemStateT] = []
|
|
63
|
+
|
|
64
|
+
for it in items:
|
|
65
|
+
st = it.state
|
|
66
|
+
states.append(st)
|
|
67
|
+
|
|
68
|
+
# Prefer precomputed representation when available
|
|
69
|
+
if it.state_representation is not None:
|
|
70
|
+
raw = it.state_representation.get_evaluator_input(state=st)
|
|
71
|
+
else:
|
|
72
|
+
raw = self.state_evaluator.content_to_input_convert(st)
|
|
73
|
+
|
|
74
|
+
xs.append(torch.as_tensor(raw).to(self.device))
|
|
75
|
+
|
|
76
|
+
x_batch = torch.stack(xs, dim=0)
|
|
77
|
+
|
|
78
|
+
with torch.no_grad():
|
|
79
|
+
out = self._model(x_batch)
|
|
80
|
+
|
|
81
|
+
converter = self.state_evaluator.output_and_value_converter
|
|
82
|
+
|
|
83
|
+
values: list[float] = []
|
|
84
|
+
for i, st in enumerate(states):
|
|
85
|
+
state_eval = converter.to_content_evaluation(output_nn=out[i], state=st)
|
|
86
|
+
vw = state_eval.value_white
|
|
87
|
+
assert vw is not None
|
|
88
|
+
values.append(float(vw))
|
|
89
|
+
|
|
90
|
+
return values
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class _SingleEvalItem:
    """
    Minimal EvalItem-like wrapper so value_white can reuse the batch method.

    Exposes only ``state`` and a ``state_representation`` that is always None.
    """

    def __init__(self, state: State) -> None:
        """Wrap ``state`` for evaluation."""
        self._state = state

    @property
    def state(self) -> State:
        """The wrapped state."""
        wrapped = self._state
        return wrapped

    @property
    def state_representation(self) -> None:
        """Always None: there is no precomputed representation."""
        return None
|
|
File without changes
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides functionality for evaluating nodes in a tree structure.
|
|
3
|
+
|
|
4
|
+
The module includes a factory function for creating node evaluators, as well as classes for representing
node evaluators and evaluation queries.

Available objects:
- NodeDirectEvaluator: A class that evaluates nodes directly using a master state evaluator.
- create_node_evaluator: A factory function for creating a NodeDirectEvaluator.
- EvaluationQueries: A class that collects the nodes queued for evaluation, split into over and not-over nodes.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from .factory import create_node_evaluator
|
|
16
|
+
from .node_direct_evaluator import EvaluationQueries, NodeDirectEvaluator
|
|
17
|
+
|
|
18
|
+
# Public names re-exported by the node_direct_evaluation package.
__all__ = ["NodeDirectEvaluator", "create_node_evaluator", "EvaluationQueries"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""This module provides a factory function for creating node evaluators."""
|
|
2
|
+
|
|
3
|
+
from valanga import State
|
|
4
|
+
|
|
5
|
+
from .node_direct_evaluator import MasterStateEvaluator, NodeDirectEvaluator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def create_node_evaluator[StateT: State = State](
|
|
9
|
+
master_state_evaluator: MasterStateEvaluator,
|
|
10
|
+
) -> NodeDirectEvaluator[StateT]:
|
|
11
|
+
"""Create a NodeDirectEvaluator backed by a master state evaluator."""
|
|
12
|
+
return NodeDirectEvaluator(master_state_evaluator=master_state_evaluator)
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the implementation of the NodeDirectEvaluator class, which is responsible for evaluating the
value of nodes in a tree-based search.

The NodeDirectEvaluator class wraps a MasterStateEvaluator. It uses the evaluator's over-event detector to recognise
and score obviously finished (over) states, and evaluates all other queried nodes in a single batch.
|
|
7
|
+
|
|
8
|
+
Classes:
|
|
9
|
+
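- NodeDirectEvaluator: Evaluates nodes directly, handling obvious over events.
- EvaluationQueries: Collects the nodes waiting to be evaluated, split into over and not-over nodes.
- OverEventDetector, MasterStateEvaluator, NodeBatchValueEvaluator: Protocols describing the evaluator dependencies.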
|
|
10
|
+
|
|
11
|
+
Enums:
|
|
12
|
+
- NodeEvaluatorTypes: Types of node evaluators.
|
|
13
|
+
|
|
14
|
+
Constants:
|
|
15
|
+
- DISCOUNT: Discount factor used in the evaluation.
|
|
16
|
+
|
|
17
|
+
Functions:
|
|
18
|
+
- None
|
|
19
|
+
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from enum import Enum
|
|
23
|
+
from typing import Protocol, Sequence
|
|
24
|
+
|
|
25
|
+
from valanga import OverEvent, State
|
|
26
|
+
from valanga.evaluations import EvalItem
|
|
27
|
+
|
|
28
|
+
from anemone.nodes.algorithm_node import AlgorithmNode
|
|
29
|
+
|
|
30
|
+
# Looks like, at the moment, the use is to break ties in the evaluation (not sure if it is needed or helpful now).
# Non-over evaluations are scaled by (1 / DISCOUNT) ** tree_depth in
# NodeDirectEvaluator.process_evalution_not_over.
DISCOUNT = 0.99999999
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class NodeEvaluatorTypes(str, Enum):
    """String-valued enumeration of the available node evaluator types."""

    NEURAL_NETWORK = "neural_network"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class NodeBatchValueEvaluator(Protocol):
    """
    Structural interface for objects that score a batch of nodes at once.

    Implementations return one value_white per node and may use
    ``node.state_representation`` for speed.
    """

    def value_white_batch_from_nodes(
        self, nodes: Sequence[AlgorithmNode]
    ) -> list[float]:
        """Return the value_white evaluation of each node in ``nodes``."""
        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class EvaluationQueries[StateT: State = State]:
|
|
52
|
+
"""
|
|
53
|
+
A class that represents evaluation queries for algorithm nodes.
|
|
54
|
+
|
|
55
|
+
Attributes:
|
|
56
|
+
over_nodes (list[AlgorithmNode]): A list of algorithm nodes that are considered "over".
|
|
57
|
+
not_over_nodes (list[AlgorithmNode]): A list of algorithm nodes that are not considered "over".
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
over_nodes: list[AlgorithmNode[StateT]]
|
|
61
|
+
not_over_nodes: list[AlgorithmNode[StateT]]
|
|
62
|
+
|
|
63
|
+
def __init__(self) -> None:
|
|
64
|
+
"""
|
|
65
|
+
Initializes a new instance of the NodeEvaluator class.
|
|
66
|
+
"""
|
|
67
|
+
self.over_nodes = []
|
|
68
|
+
self.not_over_nodes = []
|
|
69
|
+
|
|
70
|
+
def clear_queries(self) -> None:
|
|
71
|
+
"""
|
|
72
|
+
Clears the evaluation queries by resetting the over_nodes and not_over_nodes lists.
|
|
73
|
+
"""
|
|
74
|
+
self.over_nodes = []
|
|
75
|
+
self.not_over_nodes = []
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class OverEventDetector(Protocol):
    """Structural interface for objects that recognise obviously finished ("over") states."""

    def check_obvious_over_events(
        self, state: State
    ) -> tuple[OverEvent | None, float | None]:
        """
        Inspect ``state`` for an obvious over event.

        Returns:
            A pair ``(over_event, evaluation)``. NodeDirectEvaluator in this module
            asserts that the evaluation is not None whenever the over event is not None.
        """
        ...
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class MasterStateEvaluator(Protocol):
|
|
89
|
+
"""Protocol for evaluating the value of a state."""
|
|
90
|
+
|
|
91
|
+
over: OverEventDetector
|
|
92
|
+
|
|
93
|
+
def value_white(self, state: State) -> float:
|
|
94
|
+
"""Evaluate a single state from white's perspective."""
|
|
95
|
+
...
|
|
96
|
+
|
|
97
|
+
# the one method NodeEvaluator uses
|
|
98
|
+
def value_white_batch_items[ItemStateT: State](
|
|
99
|
+
self, items: Sequence[EvalItem[ItemStateT]]
|
|
100
|
+
) -> list[float]:
|
|
101
|
+
"""Evaluate a batch of items, defaulting to single-state calls."""
|
|
102
|
+
# default fallback: single loop, state-only
|
|
103
|
+
return [self.value_white(it.state) for it in items]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class NodeDirectEvaluator[StateT: State = State]:
|
|
107
|
+
"""
|
|
108
|
+
The NodeEvaluator class is responsible for evaluating the value of nodes in a tree structure.
|
|
109
|
+
It uses a board evaluator and a syzygy evaluator to calculate the value of the nodes.
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
master_state_evaluator: MasterStateEvaluator
|
|
113
|
+
|
|
114
|
+
def __init__(
|
|
115
|
+
self,
|
|
116
|
+
master_state_evaluator: MasterStateEvaluator,
|
|
117
|
+
) -> None:
|
|
118
|
+
"""
|
|
119
|
+
Initializes a NodeEvaluator object.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
state_evaluator (MasterStateEvaluator): The state evaluator used to evaluate the chess state.
|
|
123
|
+
"""
|
|
124
|
+
self.master_state_evaluator = master_state_evaluator
|
|
125
|
+
|
|
126
|
+
def check_obvious_over_events(self, node: AlgorithmNode[StateT]) -> None:
|
|
127
|
+
"""
|
|
128
|
+
Updates the node.over object if the game is obviously over.
|
|
129
|
+
"""
|
|
130
|
+
over_event: OverEvent | None
|
|
131
|
+
evaluation: float | None
|
|
132
|
+
over_event, evaluation = (
|
|
133
|
+
self.master_state_evaluator.over.check_obvious_over_events(node.state)
|
|
134
|
+
)
|
|
135
|
+
if over_event is not None:
|
|
136
|
+
node.tree_evaluation.over_event.becomes_over(
|
|
137
|
+
how_over=over_event.how_over,
|
|
138
|
+
who_is_winner=over_event.who_is_winner,
|
|
139
|
+
termination=over_event.termination,
|
|
140
|
+
)
|
|
141
|
+
assert evaluation is not None, (
|
|
142
|
+
"Evaluation should not be None for over nodes"
|
|
143
|
+
)
|
|
144
|
+
node.tree_evaluation.set_evaluation(evaluation=evaluation)
|
|
145
|
+
|
|
146
|
+
def evaluate_all_queried_nodes(
|
|
147
|
+
self, evaluation_queries: EvaluationQueries[StateT]
|
|
148
|
+
) -> None:
|
|
149
|
+
"""
|
|
150
|
+
Evaluates all the queried nodes.
|
|
151
|
+
"""
|
|
152
|
+
# node_over: AlgorithmNode
|
|
153
|
+
# for node_over in evaluation_queries.over_nodes:
|
|
154
|
+
# assert isinstance(node_over, AlgorithmNode)
|
|
155
|
+
# self.evaluate_over(node_over)
|
|
156
|
+
|
|
157
|
+
if evaluation_queries.not_over_nodes:
|
|
158
|
+
self.evaluate_all_not_over(evaluation_queries.not_over_nodes)
|
|
159
|
+
|
|
160
|
+
evaluation_queries.clear_queries()
|
|
161
|
+
|
|
162
|
+
def add_evaluation_query(
|
|
163
|
+
self, node: AlgorithmNode[StateT], evaluation_queries: EvaluationQueries[StateT]
|
|
164
|
+
) -> None:
|
|
165
|
+
"""
|
|
166
|
+
Adds an evaluation query for a node.
|
|
167
|
+
"""
|
|
168
|
+
assert node.tree_evaluation.value_white_direct_evaluation is None
|
|
169
|
+
self.check_obvious_over_events(node)
|
|
170
|
+
if node.is_over():
|
|
171
|
+
evaluation_queries.over_nodes.append(node)
|
|
172
|
+
else:
|
|
173
|
+
evaluation_queries.not_over_nodes.append(node)
|
|
174
|
+
|
|
175
|
+
def evaluate_all_not_over(
|
|
176
|
+
self, not_over_nodes: list[AlgorithmNode[StateT]]
|
|
177
|
+
) -> None:
|
|
178
|
+
"""Evaluate all non-terminal nodes and store their evaluations."""
|
|
179
|
+
values = self.master_state_evaluator.value_white_batch_items(not_over_nodes)
|
|
180
|
+
for node, v in zip(not_over_nodes, values, strict=True):
|
|
181
|
+
node.tree_evaluation.set_evaluation(
|
|
182
|
+
self.process_evalution_not_over(v, node)
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
def process_evalution_not_over(
|
|
186
|
+
self, evaluation: float, node: AlgorithmNode[StateT]
|
|
187
|
+
) -> float:
|
|
188
|
+
"""
|
|
189
|
+
Processes the evaluation for a node that is not over.
|
|
190
|
+
"""
|
|
191
|
+
processed_evaluation = (1 / DISCOUNT) ** node.tree_depth * evaluation
|
|
192
|
+
return processed_evaluation
|