algorhino-anemone 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. algorhino_anemone-0.1.1.dist-info/METADATA +151 -0
  2. algorhino_anemone-0.1.1.dist-info/RECORD +82 -0
  3. algorhino_anemone-0.1.1.dist-info/WHEEL +5 -0
  4. algorhino_anemone-0.1.1.dist-info/licenses/LICENSE +674 -0
  5. algorhino_anemone-0.1.1.dist-info/top_level.txt +1 -0
  6. anemone/__init__.py +27 -0
  7. anemone/basics.py +36 -0
  8. anemone/factory.py +161 -0
  9. anemone/indices/__init__.py +0 -0
  10. anemone/indices/index_manager/__init__.py +12 -0
  11. anemone/indices/index_manager/factory.py +50 -0
  12. anemone/indices/index_manager/node_exploration_manager.py +549 -0
  13. anemone/indices/node_indices/__init__.py +22 -0
  14. anemone/indices/node_indices/factory.py +121 -0
  15. anemone/indices/node_indices/index_data.py +166 -0
  16. anemone/indices/node_indices/index_types.py +20 -0
  17. anemone/nn/torch_evaluator.py +108 -0
  18. anemone/node_evaluation/__init__.py +0 -0
  19. anemone/node_evaluation/node_direct_evaluation/__init__.py +22 -0
  20. anemone/node_evaluation/node_direct_evaluation/factory.py +12 -0
  21. anemone/node_evaluation/node_direct_evaluation/node_direct_evaluator.py +192 -0
  22. anemone/node_evaluation/node_tree_evaluation/node_minmax_evaluation.py +885 -0
  23. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation.py +137 -0
  24. anemone/node_evaluation/node_tree_evaluation/node_tree_evaluation_factory.py +43 -0
  25. anemone/node_factory/__init__.py +14 -0
  26. anemone/node_factory/algorithm_node_factory.py +123 -0
  27. anemone/node_factory/base.py +76 -0
  28. anemone/node_selector/__init__.py +32 -0
  29. anemone/node_selector/branch_explorer.py +89 -0
  30. anemone/node_selector/factory.py +65 -0
  31. anemone/node_selector/node_selector.py +44 -0
  32. anemone/node_selector/node_selector_args.py +22 -0
  33. anemone/node_selector/node_selector_types.py +15 -0
  34. anemone/node_selector/notations_and_statics.py +88 -0
  35. anemone/node_selector/opening_instructions.py +249 -0
  36. anemone/node_selector/recurzipf/__init__.py +0 -0
  37. anemone/node_selector/recurzipf/recur_zipf_base.py +141 -0
  38. anemone/node_selector/sequool/__init__.py +19 -0
  39. anemone/node_selector/sequool/factory.py +102 -0
  40. anemone/node_selector/sequool/sequool.py +395 -0
  41. anemone/node_selector/uniform/__init__.py +16 -0
  42. anemone/node_selector/uniform/uniform.py +113 -0
  43. anemone/nodes/__init__.py +15 -0
  44. anemone/nodes/algorithm_node/__init__.py +7 -0
  45. anemone/nodes/algorithm_node/algorithm_node.py +204 -0
  46. anemone/nodes/itree_node.py +136 -0
  47. anemone/nodes/tree_node.py +240 -0
  48. anemone/nodes/tree_traversal.py +108 -0
  49. anemone/nodes/utils.py +146 -0
  50. anemone/progress_monitor/__init__.py +0 -0
  51. anemone/progress_monitor/progress_monitor.py +375 -0
  52. anemone/recommender_rule/__init__.py +12 -0
  53. anemone/recommender_rule/recommender_rule.py +140 -0
  54. anemone/search_factory/__init__.py +14 -0
  55. anemone/search_factory/search_factory.py +192 -0
  56. anemone/state_transition.py +47 -0
  57. anemone/tree_and_value_branch_selector.py +99 -0
  58. anemone/tree_exploration.py +274 -0
  59. anemone/tree_manager/__init__.py +29 -0
  60. anemone/tree_manager/algorithm_node_tree_manager.py +246 -0
  61. anemone/tree_manager/factory.py +77 -0
  62. anemone/tree_manager/tree_expander.py +122 -0
  63. anemone/tree_manager/tree_manager.py +254 -0
  64. anemone/trees/__init__.py +14 -0
  65. anemone/trees/descendants.py +765 -0
  66. anemone/trees/factory.py +80 -0
  67. anemone/trees/tree.py +70 -0
  68. anemone/trees/tree_visualization.py +143 -0
  69. anemone/updates/__init__.py +33 -0
  70. anemone/updates/algorithm_node_updater.py +157 -0
  71. anemone/updates/factory.py +36 -0
  72. anemone/updates/index_block.py +91 -0
  73. anemone/updates/index_updater.py +100 -0
  74. anemone/updates/minmax_evaluation_updater.py +108 -0
  75. anemone/updates/updates_file.py +248 -0
  76. anemone/updates/value_block.py +133 -0
  77. anemone/utils/comparable.py +32 -0
  78. anemone/utils/dataclass.py +64 -0
  79. anemone/utils/dict_of_numbered_dict_with_pointer_on_max.py +128 -0
  80. anemone/utils/logger.py +94 -0
  81. anemone/utils/my_value_sorted_dict.py +27 -0
  82. anemone/utils/small_tools.py +103 -0
anemone/nodes/utils.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ This module contains utility functions for working with tree nodes in the move selector.
3
+
4
+ Functions:
5
+ - are_all_moves_and_children_opened(tree_node: TreeNode) -> bool: Checks if all moves and children of a tree node are opened.
6
+ - a_move_key_sequence_from_root(tree_node: ITreeNode) -> list[str]: Returns the branch keys (as strings) along one path from the root node to a given tree node.
7
+ - print_a_move_sequence_from_root(tree_node: ITreeNode) -> None: Prints the branch-key sequence from the root node to a given tree node.
8
+ - is_winning(node_tree_evaluation: NodeTreeEvaluation, color: Color) -> bool: Checks if the given color is winning according to the node's minmax value.
9
+ """
10
+
11
+ from valanga import BranchKey, Color, State
12
+
13
+ from anemone.node_evaluation.node_tree_evaluation.node_tree_evaluation import (
14
+ NodeTreeEvaluation,
15
+ )
16
+ from anemone.nodes.algorithm_node.algorithm_node import (
17
+ AlgorithmNode,
18
+ )
19
+
20
+ from .itree_node import ITreeNode
21
+ from .tree_node import TreeNode
22
+
23
+
24
def are_all_moves_and_children_opened(tree_node: TreeNode) -> bool:
    """
    Tell whether a node is fully expanded.

    A node is fully expanded when all of its branches have been generated and
    none of them is still waiting to be opened.

    Args:
        tree_node (TreeNode): The node to inspect.

    Returns:
        bool: True when ``all_branches_generated`` holds and ``non_opened_branches``
        equals the empty set, False otherwise.
    """
    if not tree_node.all_branches_generated:
        # Mirror the short-circuit of ``a and b``: return the falsy flag itself.
        return tree_node.all_branches_generated
    return tree_node.non_opened_branches == set()
35
+
36
+
37
+ def a_move_key_sequence_from_root[StateT: State](
38
+ tree_node: ITreeNode[StateT],
39
+ ) -> list[str]:
40
+ """
41
+ Returns a list of move sequences from the root node to a given tree node.
42
+
43
+ Args:
44
+ tree_node (ITreeNode): The tree node to get the move sequence for.
45
+
46
+ Returns:
47
+ list[str]: A list of move sequences from the root node to the given tree node.
48
+ """
49
+ move_sequence_from_root: list[BranchKey] = []
50
+ child: ITreeNode[StateT] = tree_node
51
+ while child.parent_nodes:
52
+ parent: ITreeNode[StateT] = next(iter(child.parent_nodes))
53
+ move: BranchKey = child.parent_nodes[parent]
54
+ move_sequence_from_root.append(move)
55
+ child = parent
56
+ move_sequence_from_root.reverse()
57
+ return [str(i) for i in move_sequence_from_root]
58
+
59
+
60
+ def a_branch_str_sequence_from_root[StateT: State](
61
+ tree_node: ITreeNode[StateT],
62
+ ) -> list[str]:
63
+ """
64
+ Returns a list of move sequences from the root node to a given tree node.
65
+
66
+ Args:
67
+ tree_node (ITreeNode): The tree node to get the move sequence for.
68
+
69
+ Returns:
70
+ list[str]: A list of move sequences from the root node to the given tree node.
71
+ """
72
+ move_sequence_from_root: list[str] = []
73
+ child: ITreeNode[StateT] = tree_node
74
+ while child.parent_nodes:
75
+ parent: ITreeNode[StateT] = next(iter(child.parent_nodes))
76
+ branch_key: BranchKey = child.parent_nodes[parent]
77
+ branch_str: str = parent.state.branch_name_from_key(branch_key)
78
+ move_sequence_from_root.append(branch_str)
79
+ child = parent
80
+ move_sequence_from_root.reverse()
81
+ return [str(i) for i in move_sequence_from_root]
82
+
83
+
84
+ def best_node_sequence_from_node[StateT: State](
85
+ tree_node: AlgorithmNode[StateT],
86
+ ) -> list[AlgorithmNode[StateT]]:
87
+ """
88
+ Returns the best node sequence from the given tree node following the best moves.
89
+ Args:
90
+ tree_node (AlgorithmNode): The tree node to start from.
91
+ Returns:
92
+ list[AlgorithmNode]: A list of tree nodes representing the best node sequence.
93
+ """
94
+
95
+ best_move_seq: list[BranchKey] = tree_node.tree_evaluation.best_branch_sequence
96
+ index = 0
97
+ move_sequence: list[AlgorithmNode[StateT]] = [tree_node]
98
+ child: AlgorithmNode[StateT] = tree_node
99
+ while child.branches_children:
100
+ move: BranchKey = best_move_seq[index]
101
+ child_ = child.branches_children[move]
102
+ assert child_ is not None
103
+ child = child_
104
+ move_sequence.append(child)
105
+ index = index + 1
106
+ return move_sequence
107
+
108
+
109
+ def print_a_move_sequence_from_root[StateT: State](
110
+ tree_node: ITreeNode[StateT],
111
+ ) -> None:
112
+ """
113
+ Prints the move sequence from the root node to a given tree node.
114
+
115
+ Args:
116
+ tree_node (TreeNode): The tree node to print the move sequence for.
117
+
118
+ Returns:
119
+ None
120
+ """
121
+ move_sequence_from_root: list[str] = a_move_key_sequence_from_root(
122
+ tree_node=tree_node
123
+ )
124
+ print(f"a_move_sequence_from_root{move_sequence_from_root}")
125
+
126
+
127
def is_winning(
    node_tree_evaluation: NodeTreeEvaluation,
    color: Color,
    *,
    winning_threshold: float = 0.98,
) -> bool:
    """
    Check whether ``color`` is winning according to the node's minmax value.

    ``value_white_minmax`` is read from White's point of view:
    - White counts as winning strictly above ``winning_threshold``;
    - Black counts as winning strictly below ``-winning_threshold``.

    Args:
        node_tree_evaluation (NodeTreeEvaluation): The evaluation of the node.
        color (Color): The color to check (compared by identity with
            ``Color.WHITE`` / ``Color.BLACK``).
        winning_threshold (float): Magnitude the minmax value must exceed to count
            as a win. Defaults to 0.98, the previously hard-coded value.

    Returns:
        bool: True if ``color`` is winning, False otherwise (including for a color
        that is neither white nor black).

    Raises:
        AssertionError: If the evaluation has no minmax value yet.
    """
    value_white = node_tree_evaluation.value_white_minmax
    assert value_white is not None
    winning_if_color_white: bool = value_white > winning_threshold and color is Color.WHITE
    winning_if_color_black: bool = value_white < -winning_threshold and color is Color.BLACK
    return winning_if_color_white or winning_if_color_black
File without changes
@@ -0,0 +1,375 @@
1
+ """
2
+ This module defines stopping criteria for a move selector in a game tree.
3
+
4
+ The stopping criteria determine when the move selector should stop exploring the game tree and make a decision.
5
+
6
+ The module includes the following classes:
7
+
8
+ - ProgressMonitor: The general stopping criterion base class (ProgressMonitorP is its protocol counterpart).
9
+ - TreeMoveLimit: A stopping criterion based on a tree move limit.
10
+ - DepthLimit: A stopping criterion based on a depth limit.
11
+
12
+ It also includes helper classes and functions for creating and managing stopping criteria.
13
+ """
14
+
15
+ from abc import abstractmethod
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Any, Callable, Literal, Protocol, runtime_checkable
19
+
20
+ from anemone import node_selector as node_sel
21
+ from anemone import trees
22
+ from anemone.nodes.algorithm_node.algorithm_node import AlgorithmNode
23
+
24
+
25
@runtime_checkable
class DepthToExpendP(Protocol):
    """
    Structural interface for objects that can report the depth currently being expanded.

    The protocol is ``runtime_checkable``, so it works with ``isinstance``.
    ``create_stopping_criterion`` relies on that check before building a
    ``DepthLimit``.

    Example:
        >>> class FixedDepth:
        ...     def get_current_depth_to_expand(self) -> int:
        ...         return 5
        ...
        >>> isinstance(FixedDepth(), DepthToExpendP)
        True
    """

    def get_current_depth_to_expand(self) -> int:
        """Return the depth (as an int) that is currently being expanded."""
        ...
60
+
61
+
62
class StoppingCriterionTypes(str, Enum):
    """
    Identifiers of the stopping criteria that ``create_stopping_criterion`` can build.

    The members subclass ``str``, so each compares equal to its raw string value.
    """

    DEPTH_LIMIT = "depth_limit"  # built as DepthLimit
    TREE_MOVE_LIMIT = "tree_move_limit"  # built as TreeMoveLimit
69
+
70
+
71
@dataclass
class StoppingCriterionArgs:
    """
    Arguments that identify a stopping criterion by its type.

    Attributes:
        type (StoppingCriterionTypes): Which stopping criterion is configured.
    """

    type: StoppingCriterionTypes
81
+
82
+
83
+ class ProgressMonitorP[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]](Protocol):
84
+ """
85
+ The general stopping criterion Protocol
86
+ """
87
+
88
+ def should_we_continue(self, tree: trees.Tree[NodeT]) -> bool:
89
+ """
90
+ Asking should we continue
91
+
92
+ Returns:
93
+ boolean of should we continue
94
+ """
95
+ ...
96
+
97
+ def respectful_opening_instructions(
98
+ self,
99
+ opening_instructions: node_sel.OpeningInstructions[NodeT],
100
+ tree: trees.Tree[NodeT],
101
+ ) -> node_sel.OpeningInstructions[NodeT]:
102
+ """
103
+ Ensures the opening request do not exceed the stopping criterion
104
+
105
+
106
+ """
107
+ ...
108
+
109
+ def get_string_of_progress(self, tree: trees.Tree[NodeT]) -> str:
110
+ """
111
+ Returns a string representation of the progress made by the stopping criterion.
112
+
113
+ Args:
114
+ tree (ValueTree): The move and value tree.
115
+
116
+ Returns:
117
+ str: A string representation of the progress.
118
+ """
119
+ ...
120
+
121
+ def get_percent_of_progress(
122
+ self,
123
+ tree: trees.Tree[NodeT],
124
+ ) -> str:
125
+ """Return a human-readable percent progress string."""
126
+ ...
127
+
128
+
129
+ class ProgressMonitor[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]]:
130
+ """
131
+ The general stopping criterion base class
132
+ """
133
+
134
+ def should_we_continue(self, tree: trees.Tree[NodeT]) -> bool:
135
+ """
136
+ Asking should we continue
137
+
138
+ Returns:
139
+ boolean of should we continue
140
+ """
141
+ if tree.root_node.is_over():
142
+ return False
143
+ return True
144
+
145
+ def respectful_opening_instructions(
146
+ self,
147
+ opening_instructions: node_sel.OpeningInstructions[NodeT],
148
+ tree: trees.Tree[NodeT],
149
+ ) -> node_sel.OpeningInstructions[NodeT]:
150
+ """
151
+ Ensures the opening request do not exceed the stopping criterion
152
+
153
+
154
+ """
155
+ _ = tree
156
+ return opening_instructions
157
+
158
+ def get_string_of_progress(self, _tree: trees.Tree[NodeT]) -> str:
159
+ """
160
+ Returns a string representation of the progress made by the stopping criterion.
161
+
162
+ Args:
163
+ tree (ValueTree): The move and value tree.
164
+
165
+ Returns:
166
+ str: A string representation of the progress.
167
+ """
168
+ return ""
169
+
170
+ @abstractmethod
171
+ def get_percent_of_progress(self, tree: trees.Tree[NodeT]) -> int:
172
+ """Return a numeric progress percentage for this monitor."""
173
+ ...
174
+
175
+ def notify_percent_progress(
176
+ self,
177
+ tree: trees.Tree[NodeT],
178
+ notify_percent_function: Callable[[int], None] | None,
179
+ ) -> None:
180
+ """Notify a callback with the current progress percentage."""
181
+ percent_progress: int = self.get_percent_of_progress(tree=tree)
182
+
183
+ if notify_percent_function is not None:
184
+ notify_percent_function(percent_progress)
185
+
186
+
187
@dataclass
class TreeMoveLimitArgs:
    """
    Configuration of the ``TreeMoveLimit`` stopping criterion.

    Attributes:
        type: Always ``StoppingCriterionTypes.TREE_MOVE_LIMIT``.
        tree_move_limit: Budget compared against ``tree.move_count``.
    """

    type: Literal[StoppingCriterionTypes.TREE_MOVE_LIMIT]
    tree_move_limit: int
193
+
194
+
195
+ class TreeMoveLimit[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]](
196
+ ProgressMonitor[NodeT]
197
+ ):
198
+ """
199
+ The stopping criterion based on a tree move limit
200
+ """
201
+
202
+ tree_move_limit: int
203
+
204
+ def __init__(self, tree_move_limit: int) -> None:
205
+ """Initialize the monitor with a move-count limit."""
206
+ self.tree_move_limit = tree_move_limit
207
+
208
+ def should_we_continue(self, tree: trees.Tree[NodeT]) -> bool:
209
+ """Return True while within the move-count budget."""
210
+ continue_base: bool = super().should_we_continue(tree=tree)
211
+
212
+ should_we: bool
213
+ if not continue_base:
214
+ should_we = continue_base
215
+ else:
216
+ should_we = tree.move_count < self.tree_move_limit
217
+ return should_we
218
+
219
+ def respectful_opening_instructions(
220
+ self,
221
+ opening_instructions: node_sel.OpeningInstructions[NodeT],
222
+ tree: trees.Tree[NodeT],
223
+ ) -> node_sel.OpeningInstructions[NodeT]:
224
+ """
225
+ Ensures the opening request do not exceed the stopping criterion
226
+
227
+
228
+ """
229
+ opening_instructions_subset: node_sel.OpeningInstructions[NodeT] = (
230
+ node_sel.OpeningInstructions()
231
+ )
232
+ opening_instructions.pop_items(
233
+ popped=opening_instructions_subset,
234
+ how_many=self.tree_move_limit - tree.move_count,
235
+ )
236
+ return opening_instructions_subset
237
+
238
+ def get_string_of_progress(self, tree: trees.Tree[NodeT]) -> str:
239
+ """
240
+ compute the string that display the progress in the terminal
241
+
242
+ Returns:
243
+ a string that display the progress in the terminal
244
+ """
245
+ return (
246
+ f"========= tree move counting: {tree.move_count} out of {self.tree_move_limit}"
247
+ f" | {tree.move_count / self.tree_move_limit:.0%}"
248
+ )
249
+
250
+ def get_percent_of_progress(
251
+ self,
252
+ tree: trees.Tree[NodeT],
253
+ ) -> int:
254
+ """Return progress percentage based on move count."""
255
+ percent: int = int(tree.move_count / self.tree_move_limit * 100)
256
+ return percent
257
+
258
+
259
@dataclass
class DepthLimitArgs:
    """
    Configuration of the ``DepthLimit`` stopping criterion.

    Attributes:
        type: Always ``StoppingCriterionTypes.DEPTH_LIMIT``.
        depth_limit: The search stops once the node selector's current depth to
            expand reaches this value.
    """

    type: Literal[StoppingCriterionTypes.DEPTH_LIMIT]
    depth_limit: int
270
+
271
+
272
+ class DepthLimit[NodeT: AlgorithmNode[Any] = AlgorithmNode[Any]](
273
+ ProgressMonitor[NodeT]
274
+ ):
275
+ """
276
+ The stopping criterion based on a depth limit
277
+ """
278
+
279
+ depth_limit: int
280
+ node_selector: DepthToExpendP
281
+
282
+ def __init__(self, depth_limit: int, node_selector: DepthToExpendP) -> None:
283
+ """
284
+ Initializes a StoppingCriterion object.
285
+
286
+ Args:
287
+ depth_limit (int): The maximum depth to search in the tree.
288
+ node_selector (DepthToExpendP): The node selector used to determine which nodes to expand.
289
+
290
+ Returns:
291
+ None
292
+ """
293
+ self.depth_limit = depth_limit
294
+ self.node_selector = node_selector
295
+
296
+ def should_we_continue(self, tree: trees.Tree[NodeT]) -> bool:
297
+ """
298
+ Determines whether the search should continue expanding nodes in the tree.
299
+
300
+ Args:
301
+ tree (ValueTree): The tree containing the moves and their corresponding values.
302
+
303
+ Returns:
304
+ bool: True if the search should continue, False otherwise.
305
+ """
306
+ continue_base = super().should_we_continue(tree=tree)
307
+ if not continue_base:
308
+ return continue_base
309
+ return self.node_selector.get_current_depth_to_expand() < self.depth_limit
310
+
311
+ def get_string_of_progress(self, tree: trees.Tree[NodeT]) -> str:
312
+ """
313
+ compute the string that display the progress in the terminal
314
+
315
+ Returns:
316
+ a string that display the progress in the terminal
317
+ """
318
+ return (
319
+ "========= tree move counting: "
320
+ + str(tree.move_count)
321
+ + " | Depth: "
322
+ + str(self.node_selector.get_current_depth_to_expand())
323
+ + " out of "
324
+ + str(self.depth_limit)
325
+ )
326
+
327
+ def get_percent_of_progress(
328
+ self,
329
+ tree: trees.Tree[NodeT],
330
+ ) -> int:
331
+ """Return progress percentage based on current depth."""
332
+ # todo this percent is not precise
333
+ percent: int = int(
334
+ self.node_selector.get_current_depth_to_expand() / self.depth_limit * 100
335
+ )
336
+ return percent
337
+
338
+
339
# Union of the argument dataclasses accepted by ``create_stopping_criterion``.
AllStoppingCriterionArgs = (
    TreeMoveLimitArgs | DepthLimitArgs
)
340
+
341
+
342
+ def create_stopping_criterion[NodeT: AlgorithmNode[Any]](
343
+ args: AllStoppingCriterionArgs,
344
+ node_selector: node_sel.NodeSelector[NodeT],
345
+ ) -> ProgressMonitor[NodeT]:
346
+ """
347
+ creating the stopping criterion
348
+
349
+ Args:
350
+ args:
351
+ node_selector:
352
+
353
+ Returns:
354
+ A stopping criterion
355
+
356
+ """
357
+ stopping_criterion: ProgressMonitor[NodeT]
358
+
359
+ match args.type:
360
+ case StoppingCriterionTypes.DEPTH_LIMIT:
361
+ assert isinstance(node_selector, DepthToExpendP)
362
+ assert isinstance(args, DepthLimitArgs)
363
+ stopping_criterion = DepthLimit(
364
+ depth_limit=args.depth_limit, node_selector=node_selector
365
+ )
366
+ case StoppingCriterionTypes.TREE_MOVE_LIMIT:
367
+ assert isinstance(args, TreeMoveLimitArgs)
368
+
369
+ stopping_criterion = TreeMoveLimit(tree_move_limit=args.tree_move_limit)
370
+ case _:
371
+ raise ValueError(
372
+ f"stopping criterion builder: can not find {args.type} in file {__name__}"
373
+ )
374
+
375
+ return stopping_criterion
@@ -0,0 +1,12 @@
1
+ """
2
+ This module provides classes for defining recommender rules.
3
+
4
+ Classes:
5
+ - RecommenderRule: Represents a recommender rule.
6
+ - AllRecommendFunctionsArgs: Represents the arguments for all recommend functions.
7
+
8
+ """
9
+
10
+ from .recommender_rule import AllRecommendFunctionsArgs, RecommenderRule
11
+
12
# Public API re-exported by the recommender_rule package.
__all__ = [
    "AllRecommendFunctionsArgs",
    "RecommenderRule",
]
@@ -0,0 +1,140 @@
1
+ """
2
+ This module defines recommender rules for selecting moves in a tree-based move selector.
3
+
4
+ The recommender rules are implemented as data classes that define a `policy` method, which takes the root `AlgorithmNode`
5
+ and returns a `BranchPolicy` (a probability distribution over branches), and a `sample` method, which draws a branch key from that policy using a `Random` generator.
6
+
7
+ The available recommender rule types are defined in the `RecommenderRuleTypes` enum.
8
+
9
+ The module also defines a `RecommenderRule` protocol that all recommender rule classes must implement.
10
+
11
+ Example usage:
12
+ rule = AlmostEqualLogistic(type=RecommenderRuleTypes.ALMOST_EQUAL_LOGISTIC, temperature=0.5)
13
+ branch_key = rule.sample(rule.policy(root_node), random_generator)
14
+ """
15
+
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from random import Random
19
+ from typing import Literal, Mapping, Protocol
20
+
21
+ from valanga import BranchKey, State
22
+
23
+ from anemone.nodes.algorithm_node.algorithm_node import (
24
+ AlgorithmNode,
25
+ )
26
+ from anemone.utils.small_tools import softmax
27
+
28
+
29
@dataclass(frozen=True, slots=True)
class BranchPolicy:
    """
    Immutable probability distribution over the branches of a node.

    Attributes:
        probs: Maps each branch key to its probability. The values are expected to
            sum to roughly 1.0.
    """

    probs: Mapping[BranchKey, float]
36
+
37
+
38
def sample_from_policy(policy: BranchPolicy, rng: Random) -> BranchKey:
    """
    Sample a branch key from a probability policy using a RNG.

    Args:
        policy: Distribution to sample from. Its probabilities are passed to
            ``rng.choices`` as relative weights.
        rng: Random generator used for the draw.

    Returns:
        The sampled branch key.

    Raises:
        ValueError: If the policy is empty. The recommender rules return
            ``BranchPolicy(probs={})`` when the root has no usable children, and
            ``rng.choices`` would otherwise fail with an uninformative
            ``IndexError``. ``rng.choices`` itself also raises ``ValueError`` when
            the weights total zero.
    """
    if not policy.probs:
        raise ValueError("cannot sample a branch from an empty policy")
    branches = list(policy.probs.keys())
    weights = list(policy.probs.values())
    return rng.choices(branches, weights=weights, k=1)[0]
43
+
44
+
45
+ class RecommenderRule(Protocol):
46
+ """
47
+ Protocol for recommender rules.
48
+ """
49
+
50
+ type: str
51
+
52
+ def policy[StateT: State](self, root_node: AlgorithmNode[StateT]) -> BranchPolicy:
53
+ """Return the policy distribution for the root node."""
54
+ ...
55
+
56
+ def sample(self, policy: BranchPolicy, rng: Random) -> BranchKey:
57
+ """Sample a branch key using the provided RNG."""
58
+ ...
59
+
60
+
61
class RecommenderRuleTypes(str, Enum):
    """
    Identifiers of the available recommender rules.

    The members subclass ``str``, so each compares equal to the ``type`` literal of
    the matching rule dataclass.
    """

    ALMOST_EQUAL_LOGISTIC = "almost_equal_logistic"  # AlmostEqualLogistic
    SOFTMAX = "softmax"  # SoftmaxRule
68
+
69
+
70
+ # These are functions, but I still use dataclasses instead
71
+ # of functools.partial so that they can easily be constructed from YAML files using dacite.
72
+
73
+
74
+ @dataclass(slots=True)
75
+ class AlmostEqualLogistic:
76
+ """
77
+ Almost Equal Logistic recommender rule that selects moves with nearly equal evaluations.
78
+ """
79
+
80
+ type: Literal["almost_equal_logistic"]
81
+ temperature: float # kept for config compatibility; rule uses minmax method
82
+
83
+ def policy[StateT: State](self, root_node: AlgorithmNode[StateT]) -> BranchPolicy:
84
+ """Compute a policy based on near-equal best branches."""
85
+ best: list[BranchKey] = root_node.tree_evaluation.get_all_of_the_best_branches(
86
+ how_equal="almost_equal_logistic"
87
+ )
88
+
89
+ # Fallback: if empty, uniform over all existing children
90
+ if not best:
91
+ best = [
92
+ bk for bk, ch in root_node.branches_children.items() if ch is not None
93
+ ]
94
+
95
+ # If still empty, something is wrong (no legal moves / not expanded)
96
+ if not best:
97
+ return BranchPolicy(probs={})
98
+
99
+ p = 1.0 / len(best)
100
+ return BranchPolicy(probs={bk: p for bk in best})
101
+
102
+ def sample(self, policy: BranchPolicy, rng: Random) -> BranchKey:
103
+ """Sample a branch from the policy using the provided RNG."""
104
+ return sample_from_policy(policy, rng)
105
+
106
+
107
+ @dataclass(slots=True)
108
+ class SoftmaxRule:
109
+ """
110
+ Softmax recommender rule that computes a softmax distribution over child evaluations.
111
+ """
112
+
113
+ type: Literal["softmax"]
114
+ temperature: float
115
+
116
+ def policy[StateT: State](self, root_node: AlgorithmNode[StateT]) -> BranchPolicy:
117
+ """Compute a softmax policy over child evaluations."""
118
+ branches: list[BranchKey] = []
119
+ scores: list[float] = []
120
+
121
+ for bk, child in root_node.branches_children.items():
122
+ if child is None:
123
+ continue
124
+ branches.append(bk)
125
+ score = root_node.tree_evaluation.subjective_value_of(child.tree_evaluation)
126
+ scores.append(float(score))
127
+
128
+ if not branches:
129
+ return BranchPolicy(probs={})
130
+
131
+ probs_list = softmax(scores, self.temperature) # list[float] or Sequence[float]
132
+ probs = {bk: float(p) for bk, p in zip(branches, probs_list, strict=True)}
133
+ return BranchPolicy(probs=probs)
134
+
135
+ def sample(self, policy: BranchPolicy, rng: Random) -> BranchKey:
136
+ """Sample a branch from the policy using the provided RNG."""
137
+ return sample_from_policy(policy, rng)
138
+
139
+
140
# Union of every recommender-rule configuration dataclass defined in this module.
AllRecommendFunctionsArgs = (
    AlmostEqualLogistic | SoftmaxRule
)
@@ -0,0 +1,14 @@
1
+ """
2
+ This module provides factories for creating search objects and node selectors.
3
+
4
+ The factories included in this module are:
5
+ - SearchFactoryP: A factory for creating search objects with parallel execution.
6
+ - SearchFactory: A factory for creating search objects with sequential execution.
7
+ - NodeSelectorFactory: A factory for creating node selectors.
8
+
9
+ To use this module, import the desired factory class from this module and use it to create the desired objects.
10
+ """
11
+
12
+ from .search_factory import NodeSelectorFactory, SearchFactory, SearchFactoryP
13
+
14
# Public API re-exported by the search_factory package.
__all__ = [
    "SearchFactoryP",
    "SearchFactory",
    "NodeSelectorFactory",
]