goshape-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goshape-0.1.0.dist-info/METADATA +66 -0
- goshape-0.1.0.dist-info/RECORD +17 -0
- goshape-0.1.0.dist-info/WHEEL +4 -0
- goshape-0.1.0.dist-info/entry_points.txt +2 -0
- goshape-0.1.0.dist-info/licenses/LICENSE +40 -0
- shape/game_logic.py +265 -0
- shape/katago/analysis.cfg +215 -0
- shape/katago/downloader.py +410 -0
- shape/katago/engine.py +176 -0
- shape/main.py +46 -0
- shape/ui/board_view.py +296 -0
- shape/ui/main_window.py +347 -0
- shape/ui/tab_analysis.py +101 -0
- shape/ui/tab_config.py +68 -0
- shape/ui/tab_main_control.py +315 -0
- shape/ui/ui_utils.py +120 -0
- shape/utils.py +7 -0
shape/ui/board_view.py
ADDED
@@ -0,0 +1,296 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from PySide6.QtCore import QPointF, QRectF, QSize, Qt
|
3
|
+
from PySide6.QtGui import (
|
4
|
+
QBrush,
|
5
|
+
QColor,
|
6
|
+
QFont,
|
7
|
+
QPainter,
|
8
|
+
QPen,
|
9
|
+
QRadialGradient,
|
10
|
+
)
|
11
|
+
from PySide6.QtWidgets import QSizePolicy, QWidget
|
12
|
+
|
13
|
+
from shape.game_logic import Move, PolicyData
|
14
|
+
from shape.utils import setup_logging
|
15
|
+
|
16
|
+
# Module-level logger, obtained from shape.utils.setup_logging.
logger = setup_logging()
|
17
|
+
|
18
|
+
|
19
|
+
def interpolate_color(color1, color2, ratio):
    """Blend two colours channel by channel.

    Args:
        color1: colour returned when ``ratio`` is 0.
        color2: colour returned when ``ratio`` is 1.
        ratio: blend factor; intermediate values mix the RGB channels linearly.

    Returns:
        A new QColor built from the truncated (``int``) red/green/blue mix.
        Alpha is not blended; QColor's default is used.
    """
    channel_pairs = (
        (color1.red(), color2.red()),
        (color1.green(), color2.green()),
        (color1.blue(), color2.blue()),
    )
    mixed = [int(start + (end - start) * ratio) for start, end in channel_pairs]
    return QColor(*mixed)
|
24
|
+
|
25
|
+
|
26
|
+
class BoardView(QWidget):
    """Widget that draws the Go board and forwards board/navigation input to the main window.

    Everything drawn (stones, last-move ring, policy heatmap, status text) is read from
    ``main_window.game_logic`` and the main window's panels at paint time. Clicks, arrow keys
    and space are turned into calls on ``main_window``.
    """

    WOOD_COLOR = QColor(220, 179, 92)
    # Heatmap colours, selected by "mean rank" in get_heatmap_color.
    PLAYER_POLICY_COLOR = QColor(20, 200, 20)  # light green
    TARGET_POLICY_COLOR = QColor(0, 100, 0)  # dark green
    AI_POLICY_COLOR = QColor(0, 0, 139)  # dark blue
    OPPONENT_POLICY_COLOR = QColor(139, 0, 0)  # dark red

    # Star-point coordinates as (col, row) for the board sizes that have them.
    STAR_POINTS = {
        19: [(3, 3), (3, 9), (3, 15), (9, 3), (9, 9), (9, 15), (15, 3), (15, 9), (15, 15)],
        13: [(3, 3), (3, 9), (6, 6), (9, 3), (9, 9)],
        9: [(2, 2), (2, 6), (4, 4), (6, 2), (6, 6)],
    }

    def __init__(self, main_window, parent=None):
        super().__init__(parent)
        self.main_window = main_window
        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.setMinimumSize(400, 400)
        # (label, callback, font-size factor) for the navigation row under the board.
        # The callbacks look up main_window lazily, so the list only needs building once.
        self.nav_buttons = [
            ("⏮", lambda: self.main_window.on_prev_move(1000), 1.3),
            ("⏪", lambda: self.main_window.on_prev_move(5), 1.3),
            ("◀", lambda: self.main_window.on_prev_move(), 0.9),
            ("pass", lambda: self.main_window.on_pass_move(), 1.0),
            ("▶", lambda: self.main_window.on_next_move(), 0.9),
            ("⏩", lambda: self.main_window.on_next_move(5), 1.3),
            ("⏭", lambda: self.main_window.on_next_move(1000), 1.3),
        ]

    def sizeHint(self):
        """Preferred size of the board widget."""
        return QSize(600, 600)

    def calculate_dimensions(self, board_size):
        """Compute cell size, margins, stone size, nav-button rects and label font size.

        Horizontally the widget holds (n - 1) grid cells, a 1-cell left margin for the row
        labels and 0.5 cell on the right. Vertically it holds (n - 1) grid cells plus room for
        the 0.5-cell top margin, the column labels and the nav-button row.
        """
        self.board_size = board_size
        cell_size_h = self.width() / (board_size + 0.5)
        cell_size_v = self.height() / (board_size + 1.05)
        self.cell_size = min(cell_size_h, cell_size_v)
        self.margin_left = self.cell_size
        self.margin_top = self.cell_size * 0.5
        self.margin_bottom = self.cell_size
        self.stone_size = self.cell_size * 0.95

        # The nav buttons share the middle half of the widget width, below the column labels.
        button_width = self.width() / (2 * len(self.nav_buttons))
        button_top = (board_size + 0.275) * self.cell_size
        self.nav_rects = [
            QRectF(self.width() * 0.25 + i * button_width, button_top, button_width, self.margin_bottom)
            for i in range(len(self.nav_buttons))
        ]
        self.coord_font_size = max(int(self.cell_size / 3), 8)

    def intersection_coords(self, col, row) -> QPointF:
        """Return the pixel position of intersection (col, row).

        col counts from the left edge and row counts from the bottom edge. Fractional values
        are allowed and are used to position labels between lines.
        """
        x = self.margin_left + col * self.cell_size
        y = self.margin_top + (self.board_size - row - 1) * self.cell_size
        return QPointF(x, y)

    def paintEvent(self, event):
        """Repaint the whole board from the current game state."""
        game_logic = self.main_window.game_logic
        board_state = game_logic.board_state
        self.calculate_dimensions(game_logic.square_board_size)

        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing, True)
            painter.setRenderHint(QPainter.SmoothPixmapTransform, True)

            self.draw_board(painter)
            self.draw_coordinates_and_nav(painter)
            self.draw_star_points(painter)

            heatmap_settings = self.main_window.control_panel.get_heatmap_settings()
            sampling_settings = self.main_window.config_panel.get_sampling_settings()
            self.draw_heatmap(painter, heatmap_settings["policy"], sampling_settings)

            self.draw_stones(board_state, painter)
            self.draw_game_status(painter)
        finally:
            # End explicitly, so an exception mid-paint cannot leave an active painter behind.
            painter.end()

    def draw_board(self, painter):
        """Fill the background, then draw the grid lines and a heavier outer border.

        The background is tinted when auto-play was halted at the current node.
        """
        if self.main_window.game_logic.current_node.autoplay_halted_reason:
            color = QColor(self.WOOD_COLOR.red() + 20, self.WOOD_COLOR.green() - 40, self.WOOD_COLOR.blue() - 20)
        else:
            color = self.WOOD_COLOR
        painter.fillRect(self.rect(), color)

        painter.setPen(QPen(QColor(0, 0, 0, 180), 1))
        last = self.board_size - 1
        for i in range(self.board_size):
            painter.drawLine(self.intersection_coords(0, i), self.intersection_coords(last, i))
            painter.drawLine(self.intersection_coords(i, 0), self.intersection_coords(i, last))

        painter.setPen(QPen(QColor(0, 0, 0), 2))
        grid_size = self.cell_size * last
        painter.drawRect(QRectF(self.margin_left, self.margin_top, grid_size, grid_size))

    def draw_star_points(self, painter):
        """Draw a small filled dot on each star point of the current board size."""
        painter.setBrush(QBrush(Qt.black))
        for col, row in self.get_star_points():
            painter.drawEllipse(self.intersection_coords(col, row), 3, 3)

    def draw_stones(self, board_state, painter):
        """Draw every stone, then a ring on the last move unless it was a pass.

        ``board_state[row][col]`` is None for an empty point, otherwise a colour ("B" is black).
        """
        for row in range(self.board_size):
            for col in range(self.board_size):
                if board_state[row][col] is not None:
                    self.draw_stone(painter, row, col, board_state[row][col])

        last_move = self.main_window.game_logic.move
        if last_move and not last_move.is_pass:
            center = self.intersection_coords(*last_move.coords)
            # Light ring on black stones, dark ring on white stones.
            outline_color = QColor(240, 240, 240, 180) if last_move.player == "B" else QColor(50, 50, 50, 180)
            painter.setPen(QPen(outline_color, 2))
            painter.setBrush(Qt.NoBrush)
            painter.drawEllipse(center, self.stone_size / 4, self.stone_size / 4)

    def draw_stone(self, painter, row, col, color):
        """Draw one stone with a radial gradient ("B" gives black, anything else white)."""
        center = self.intersection_coords(col, row)
        # The gradient is centred up-left of the stone to fake a highlight.
        gradient = QRadialGradient(center.x() - self.stone_size / 4, center.y() - self.stone_size / 4, self.stone_size)
        if color == "B":
            gradient.setColorAt(0, QColor(80, 80, 80))
            gradient.setColorAt(0.5, Qt.black)
            gradient.setColorAt(1, QColor(10, 10, 10))
        else:
            gradient.setColorAt(0, QColor(230, 230, 230))
            gradient.setColorAt(0.5, Qt.white)
            gradient.setColorAt(1, QColor(200, 200, 200))

        painter.setBrush(QBrush(gradient))
        painter.setPen(Qt.NoPen)
        painter.drawEllipse(center, self.stone_size / 2, self.stone_size / 2)

    def draw_coordinates_and_nav(self, painter):
        """Draw the column letters below the board, the row numbers to the left, and the nav buttons."""
        painter.setFont(QFont("Arial", self.coord_font_size, QFont.Bold))
        painter.setPen(QColor(0, 0, 0))
        for i in range(self.board_size):
            bottom_box = self.intersection_coords(i - 0.5, -0.5)
            painter.drawText(
                QRectF(bottom_box.x(), bottom_box.y(), self.cell_size, self.cell_size * 0.5),
                Qt.AlignHCenter | Qt.AlignTop,
                Move.GTP_COORD[i],
            )
            left_box = self.intersection_coords(-1, i + 0.5)
            painter.drawText(
                QRectF(left_box.x(), left_box.y(), self.cell_size * 0.5, self.cell_size),
                Qt.AlignVCenter | Qt.AlignRight,
                str(i + 1),
            )
        for (text, _, size_adj), nav_rect in zip(self.nav_buttons, self.nav_rects, strict=False):
            # QFont expects an integer point size.
            painter.setFont(QFont("Arial", int(self.coord_font_size * size_adj), QFont.Bold))
            painter.drawText(nav_rect, Qt.AlignCenter, text)

    def mousePressEvent(self, event):
        """Run a nav-button callback, or play at the intersection nearest to the click."""
        # Recompute the geometry so the click maps against the *current* board size. Right after
        # a new game the debounced repaint may not have run yet, and before the first paint no
        # geometry exists at all.
        self.calculate_dimensions(self.main_window.game_logic.square_board_size)
        pos = event.position()
        for (_, callback, _), rect in zip(self.nav_buttons, self.nav_rects, strict=False):
            if rect.contains(pos):
                callback()
                return
        col = round((pos.x() - self.margin_left) / self.cell_size)
        row = round((pos.y() - self.margin_top) / self.cell_size)
        if 0 <= col < self.board_size and 0 <= row < self.board_size:
            # Screen rows grow downwards; board rows count from the bottom.
            self.main_window.make_move((col, self.board_size - row - 1))

    def get_star_points(self):
        """Return the (col, row) star points for the current board size (empty if none)."""
        return list(self.STAR_POINTS.get(self.board_size, ()))

    def get_weighted_policy_data(
        self, human_profiles: list[tuple[str, bool]]
    ) -> tuple[np.ndarray | None, np.ndarray]:
        """Average the human-policy distributions of the enabled profiles at the current node.

        Args:
            human_profiles: ordered (profile, enabled) pairs. A profile's position in the list
                is used as its "rank" value.

        Returns:
            (mean_prob, mean_rank). mean_prob is the mean policy over the enabled profiles that
            already have analysis. mean_rank gives, for each point, the probability-weighted mean
            list position of those profiles. mean_prob is None when no enabled profile has
            analysis yet (mean_rank is then an empty array).
        """
        current_node = self.main_window.game_logic.current_node
        prob_sum = None
        rank_sum = np.array([])  # placeholder so the return type is always an ndarray
        num_used = 0

        for rank, (profile, enabled) in enumerate(human_profiles):
            if not enabled:
                continue
            analysis = current_node.get_analysis(profile)
            if not analysis:
                continue
            policy = np.array(analysis.human_policy.data, dtype=float)
            if prob_sum is None:
                prob_sum = np.zeros_like(policy)
                rank_sum = np.zeros_like(policy)
            prob_sum += policy
            rank_sum += policy * rank
            num_used += 1

        if prob_sum is None:
            return None, rank_sum
        # Clip avoids dividing by zero on points that no profile considers.
        mean_rank = rank_sum / prob_sum.clip(min=1e-10)
        return prob_sum / num_used, mean_rank

    def draw_heatmap(self, painter, policy, sampling_settings):
        """Sample the averaged human policy and draw it as coloured squares.

        Args:
            policy: (profile, enabled) pairs forwarded to get_weighted_policy_data.
            sampling_settings: keyword arguments for PolicyData.sample.
        """
        game_logic = self.main_window.game_logic
        heatmap_mean_prob, heatmap_mean_rank = self.get_weighted_policy_data(policy)
        move = game_logic.move
        if move and move.is_pass and game_logic.current_node.parent.parent is None:
            # Passing as the very first move replaces the heatmap with a synthetic pattern:
            # probability varies along one board axis and rank along the other. Presumably this
            # is a colour/size legend.
            prob_gradient = np.tile(np.linspace(0.01, 1, self.board_size), (self.board_size, 1))
            rank_gradient = np.tile(np.linspace(0, 2, self.board_size), (self.board_size, 1)).T
            # The trailing 0 fills the extra (pass) slot of the policy vector.
            heatmap_mean_prob = np.append(prob_gradient.ravel(), 0)
            heatmap_mean_rank = np.append(rank_gradient.ravel(), 0)
            sampling_settings = {"min_p": 0}

        if heatmap_mean_prob is not None:
            top_moves, _ = PolicyData(heatmap_mean_prob).sample(
                secondary_data=PolicyData.grid_from_data(heatmap_mean_rank), **sampling_settings
            )
            self.draw_heatmap_points(painter, top_moves)

    def draw_heatmap_points(self, painter, top_moves, show_text=True):
        """Draw one square per sampled move, sized by relative probability and coloured by rank.

        Args:
            top_moves: (move, prob, rank) triples, as returned by PolicyData.sample. May be
                empty. Pass moves (coords None) are skipped because they have no intersection.
            show_text: when True, print the probability as a whole percentage on each square.
        """
        if not top_moves:
            return
        max_prob = max(prob for _, prob, _ in top_moves)
        if max_prob <= 0:
            return

        font = QFont("Arial", int(self.cell_size / 3.5))
        font.setBold(True)
        for move, prob, rank in top_moves:
            if move.coords is None:
                continue
            rel_prob = prob / max_prob
            # Square side runs from 25% to 97.5% of a cell as rel_prob goes from 0 to 1.
            square_size = self.cell_size * (0.25 + 0.725 * rel_prob)
            center = self.intersection_coords(*move.coords)

            painter.setBrush(QBrush(self.get_heatmap_color(rank)))
            painter.setPen(QPen(Qt.black))
            painter.drawRect(
                QRectF(center.x() - square_size / 2, center.y() - square_size / 2, square_size, square_size)
            )

            if show_text:
                text = "" if rel_prob < 0.01 else f"{prob * 100:.0f}"
                painter.setFont(font)
                painter.setPen(QColor(200, 200, 200))
                painter.drawText(
                    QRectF(
                        center.x() - self.cell_size / 2,
                        center.y() - self.cell_size / 2,
                        self.cell_size,
                        self.cell_size,
                    ),
                    Qt.AlignCenter,
                    text,
                )

    def get_heatmap_color(self, mean_rank):
        """Map a mean rank to a colour.

        Ranks 0 to 1 blend light green into dark green, ranks 1 to 2 blend dark green into
        dark blue, and ranks above 2 are dark red.
        """
        if mean_rank < 1:
            return interpolate_color(self.PLAYER_POLICY_COLOR, self.TARGET_POLICY_COLOR, mean_rank / 1)
        if mean_rank <= 2:
            return interpolate_color(self.TARGET_POLICY_COLOR, self.AI_POLICY_COLOR, min(1, (mean_rank - 1) / 1))
        return self.OPPONENT_POLICY_COLOR

    def draw_game_status(self, painter):
        """Show "Both players passed." or "Pass" in the top margin when applicable."""
        game_logic = self.main_window.game_logic
        if game_logic.game_ended():
            message = "Both players passed."
        elif game_logic.current_node.move and game_logic.current_node.move.is_pass:
            message = "Pass"
        else:
            return

        font = QFont("Arial", int(self.cell_size * 0.4))
        font.setBold(True)
        painter.setFont(font)
        painter.setPen(QColor(0, 0, 0))
        painter.drawText(QRectF(0, 0, self.width(), self.margin_top), Qt.AlignCenter, message)

    def keyPressEvent(self, event):
        """Handle navigation keys.

        Left/Right step one move (five with Shift) and Space passes. Other keys go to QWidget.
        """
        key = event.key()
        step = 5 if event.modifiers() & Qt.ShiftModifier else 1
        if key == Qt.Key_Left:
            if step == 5:
                self.main_window.on_prev_move(5)
            else:
                self.main_window.on_prev_move()
        elif key == Qt.Key_Right:
            if step == 5:
                self.main_window.on_next_move(5)
            else:
                self.main_window.on_next_move()
        elif key == Qt.Key_Space:
            self.main_window.on_pass_move()
        else:
            super().keyPressEvent(event)
|
shape/ui/main_window.py
ADDED
@@ -0,0 +1,347 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from PySide6.QtCore import QEvent, Qt, QTimer, Signal
|
5
|
+
from PySide6.QtGui import QAction
|
6
|
+
from PySide6.QtWidgets import (
|
7
|
+
QApplication,
|
8
|
+
QFileDialog,
|
9
|
+
QHBoxLayout,
|
10
|
+
QLabel,
|
11
|
+
QMainWindow,
|
12
|
+
QMenu,
|
13
|
+
QMenuBar,
|
14
|
+
QStatusBar,
|
15
|
+
QTabWidget,
|
16
|
+
QVBoxLayout,
|
17
|
+
QWidget,
|
18
|
+
)
|
19
|
+
|
20
|
+
from shape.game_logic import GameLogic, Move
|
21
|
+
from shape.ui.board_view import BoardView
|
22
|
+
from shape.ui.tab_analysis import AnalysisPanel
|
23
|
+
from shape.ui.tab_config import ConfigPanel
|
24
|
+
from shape.ui.tab_main_control import ControlPanel
|
25
|
+
from shape.ui.ui_utils import MAIN_STYLESHEET
|
26
|
+
from shape.utils import setup_logging
|
27
|
+
|
28
|
+
# Module-level logger, obtained from shape.utils.setup_logging. Its level is changed at
# runtime by MainWindow.set_logging_level.
logger = setup_logging()
|
29
|
+
|
30
|
+
|
31
|
+
class MainWindow(QMainWindow):
    """Main application window.

    Hosts the BoardView on the left and the Play / AI Analysis / Settings tabs on the right.
    It owns the GameLogic, sends analysis queries to the KataGo engine and drives auto-play.
    UI refreshes go through ``update_state``, which is debounced by a single-shot timer.
    """

    # Engine-thread -> GUI-thread bridges. A signal emitted from the engine thread is delivered
    # to its slot on the GUI thread, the only thread allowed to touch widgets.
    update_state_main_thread = Signal()
    status_message_main_thread = Signal(str)

    def __init__(self):
        super().__init__()
        self.setStyleSheet(MAIN_STYLESHEET)
        self.katago_engine = None
        self.game_logic = GameLogic()
        self.setWindowTitle("SHAPE - Play Go with AI Feedback")
        self.setFocusPolicy(Qt.StrongFocus)

        # Create the debounce timer before building the UI and connecting signals, so any early
        # update_state() call finds it.
        self.update_state_timer = QTimer(self)
        self.update_state_timer.setSingleShot(True)
        self.update_state_timer.timeout.connect(self._update_state)

        self.setup_ui()
        self.connect_signals()

    def set_engine(self, katago_engine):
        """Attach the KataGo engine, show version info in the window title, and refresh."""
        self.katago_engine = katago_engine
        try:
            from importlib.metadata import version

            shape_version = version("goshape")
        except ImportError:
            # Also covers importlib.metadata.PackageNotFoundError (an ImportError subclass),
            # which is raised when running from a source checkout.
            shape_version = "dev"
        katago_version = getattr(katago_engine, "katago_version", "Unknown")
        katago_backend = getattr(katago_engine, "katago_backend", "Unknown")
        self.setWindowTitle(f"SHAPE v{shape_version} running KataGo {katago_version} ({katago_backend})")
        self.update_state()

    def set_logging_level(self, level):
        """Set both this module's logger and the root logger to ``level`` (e.g. "INFO")."""
        logger.setLevel(level)
        logging.getLogger().setLevel(level)

    def connect_signals(self):
        """Wire panel signals and the engine-thread bridges to their slots."""
        self.control_panel.ai_move_button.clicked.connect(self.request_ai_move)
        self.config_panel.settings_updated.connect(self.update_state)
        self.control_panel.settings_updated.connect(self.update_state)
        self.update_state_main_thread.connect(self.update_state)
        self.status_message_main_thread.connect(self.update_status_bar)

    def update_state(self):
        """Schedule a UI/state refresh. Calls within 100 ms are coalesced into one."""
        self.update_state_timer.start(100)

    def _update_state(self):
        """Request missing analyses, halt or advance auto-play, then refresh every panel."""
        current_node = self.game_logic.current_node
        human_profiles, current_analysis = self.ensure_analysis_requested(current_node)
        next_player_human = self.control_panel.get_player_color() == self.game_logic.next_player

        # Only act once every analysis required for this position has arrived.
        if not self.game_logic.game_ended() and all(current_analysis.values()):
            # Halt auto-play when the human's last move is judged a mistake. An explicit
            # "AI move" request is not auto-play, so it bypasses the halt check instead of
            # being silently swallowed.
            should_halt_reason = None
            if not next_player_human and not current_node.ai_move_requested:
                should_halt_reason = self.config_panel.should_halt_on_mistake(
                    self.control_panel.get_move_stats(current_node)
                )
            if should_halt_reason:
                current_node.autoplay_halted_reason = should_halt_reason
                logger.info(f"Halting auto-play due to {should_halt_reason}.")
            else:
                self.maybe_make_ai_move(current_node, human_profiles, current_analysis, next_player_human)

        for tab in [self.control_panel, self.analysis_panel, self.config_panel]:
            tab.update_ui()
        self.board_view.update()

    def maybe_make_ai_move(self, current_node, human_profiles, current_analysis, next_player_human):
        """Play an AI move if auto-play applies at this leaf node or one was explicitly requested.

        The move is sampled from the opponent profile's human policy. The AI passes when that
        sample is empty, or when KataGo's top move is a pass.
        """
        auto_play_due = (
            not current_node.children and self.control_panel.is_auto_play_enabled() and not next_player_human
        )
        if not (auto_play_due or current_node.ai_move_requested):
            return
        current_node.ai_move_requested = False

        policy_moves, reason = current_analysis[human_profiles["opponent"]].human_policy.sample(
            **self.config_panel.get_sampling_settings()
        )
        if not policy_moves:
            logger.info("No valid moves available, passing")
            self.make_move(None)
            return

        ai_moves = current_analysis[None].ai_moves()
        if ai_moves and ai_moves[0]["move"] == "pass":
            logger.info("Passing because it is the best AI move")
            self.make_move(None)
            return

        moves, probs, _ = zip(*policy_moves, strict=False)
        weights = np.array(probs) / sum(probs)
        # Sample an index rather than the Move objects, so numpy never has to turn Move
        # instances into an array. The distribution is the same.
        move = moves[np.random.choice(len(moves), p=weights)]
        logger.info(f"Making sampled move: {move} from {len(policy_moves)} cuttoff due to {reason}")
        self.make_move(move.coords)

    # actions
    def make_move(self, coords: tuple[int, int] | None):
        """Play ``coords`` (None means pass) for the side to move; refresh if the move was accepted."""
        if self.game_logic.make_move(Move(coords=coords, player=self.game_logic.next_player)):
            self.update_state()

    def on_prev_move(self, n=1):
        """Step back ``n`` moves in the game tree."""
        self.game_logic.undo_move(n)
        self.update_state()

    def on_next_move(self, n=1):
        """Step forward ``n`` moves in the game tree."""
        self.game_logic.redo_move(n)
        self.update_state()

    def request_ai_move(self):
        """Ask for an AI move at the current node. It is played once analysis is available."""
        self.game_logic.current_node.ai_move_requested = True
        self.update_status_bar("AI move requested")
        self.update_state()

    def new_game(self, size):
        """Start a new ``size`` x ``size`` game."""
        logger.info(f"New game requested with size: {size}")
        self.game_logic.new_game(size)
        self.update_state()

    def copy_sgf_to_clipboard(self):
        """Copy the current game as SGF text to the clipboard."""
        self.save_as_sgf(to_clipboard=True)

    def save_as_sgf(self, to_clipboard: bool = False):
        """Export the game as SGF, either to the clipboard or to a user-chosen file.

        Player names are "Human" for the human's colour. The other side is "AI (<profile>)",
        or "KataGo" when no opponent profile is set.
        """

        def get_player_name(color):
            if self.control_panel.get_player_color() == color:
                return "Human"
            profile = self.control_panel.get_human_profiles()["opponent"]
            return f"AI ({profile})" if profile else "KataGo"

        player_names = {bw: get_player_name(bw) for bw in "BW"}
        sgf_data = self.game_logic.export_sgf(player_names)

        if to_clipboard:
            QApplication.clipboard().setText(sgf_data)
            self.update_status_bar(f"SGF of length {len(sgf_data)} with {player_names} copied to clipboard.")
            return

        file_path, _ = QFileDialog.getSaveFileName(self, "Save SGF File", "", "SGF Files (*.sgf)")
        if not file_path:
            return
        if not file_path.lower().endswith(".sgf"):
            file_path += ".sgf"
        try:
            # Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail on
            # non-ASCII text.
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(sgf_data)
        except OSError as e:
            logger.error(f"Failed to save SGF to {file_path}: {e}")
            self.update_status_bar(f"Failed to save SGF: {e}")
            return
        self.update_status_bar(f"SGF saved to {file_path}.")

    def paste_sgf_from_clipboard(self):
        """Import an SGF game from the clipboard and request analysis for every node on its path."""
        sgf_data = QApplication.clipboard().text()
        if self.game_logic.import_sgf(sgf_data):
            self.update_state()
            self.update_status_bar("SGF imported successfully.")
            for node in self.game_logic.current_node.node_history:
                self.ensure_analysis_requested(node)
        else:
            self.update_status_bar("Failed to import SGF.")

    # analysis
    def ensure_analysis_requested(self, node):
        """Collect the analyses ``node`` needs and request any that are missing and not pending.

        Returns:
            (human_profiles, current_analysis). current_analysis maps None (plain KataGo) and the
            player/opponent/target profiles to the stored analysis, or a falsy value if absent.
        """
        human_profiles = self.control_panel.get_human_profiles()
        current_analysis = {
            k: node.get_analysis(k)
            for k in [None, human_profiles["player"], human_profiles["opponent"], human_profiles["target"]]
        }
        for k, v in current_analysis.items():
            if not v and not node.analysis_requested(k):
                self.request_analysis(node, human_profile=k)
        return human_profiles, current_analysis

    def request_analysis(self, node, human_profile, force_visits=None):
        """Send an analysis query for ``node`` to KataGo unless one is already pending.

        Human-profile queries use a single visit and sample 8 root symmetries (the comment in
        the settings says this is for policy quality). The plain AI query (human_profile falsy)
        uses ``force_visits``, or the configured AI strength. ``force_visits`` also re-sends a
        query that was already requested.
        """
        if node.analysis_requested(human_profile) and not force_visits:
            return

        logger.debug(f"Requesting analysis for {human_profile=} for {node=}")

        if human_profile:
            human_profile_settings = {
                "humanSLProfile": human_profile,
                "ignorePreRootHistory": False,
                "rootNumSymmetriesToSample": 8,  # max quality policy
            }
            max_visits = 1
        else:
            human_profile_settings = {}
            max_visits = force_visits or self.config_panel.get_ai_strength()

        if self.katago_engine:
            node.mark_analysis_requested(human_profile)
            self.katago_engine.analyze_position(
                node=node,
                callback=lambda resp: self.on_analysis_complete(node, resp, human_profile),
                human_profile_settings=human_profile_settings,
                max_visits=max_visits,
            )

    # This is called from the engine thread. It must not touch widgets directly; the UI is
    # updated through the *_main_thread signals, which Qt delivers on the GUI thread.
    def on_analysis_complete(self, node, analysis, human_profile):
        """Store a finished analysis (or handle its error) and notify the GUI thread."""
        if "error" in analysis:
            logger.error(f"Analysis error: {analysis['error']}")
            self.status_message_main_thread.emit(f"Analysis error: {analysis['error']}")
            # NOTE(review): game_logic is still mutated from the engine thread here, as before;
            # moving this onto the GUI thread would remove the race with user input.
            if self.game_logic.current_node is node:
                self.game_logic.undo_move()
            if node.parent is not None:
                logger.info(f"Deleting child node {node} because of analysis error => {node.parent.delete_child(node)}")
            # Refresh so an undone move disappears from the board.
            self.update_state_main_thread.emit()
            return

        if human_profile is not None and "humanPolicy" not in analysis:
            logger.error(f"No human policy found in analysis: {analysis}")
        node.store_analysis(analysis, human_profile)

        num_queries = self.katago_engine.num_outstanding_queries()
        if num_queries == 0:
            message = "Ready"
        else:
            where = node.move.gtp() if node.move else "root"
            message = f"{human_profile or 'AI'} analysis for {where} received, still working on {num_queries} queries"
        self.status_message_main_thread.emit(message)

        if node is self.game_logic.current_node:
            self.update_state_main_thread.emit()

    # UI setup

    def setup_ui(self):
        """Build the menu bar, status bar, board view (left) and tab panel (right)."""
        self.create_menu_bar()
        self.create_status_bar()

        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        main_layout = QHBoxLayout(central_widget)
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.setSpacing(0)

        self.board_view = BoardView(self)
        self.board_view.setFocusPolicy(Qt.StrongFocus)
        # Wheel events on the board are intercepted in eventFilter for move navigation.
        self.board_view.installEventFilter(self)
        main_layout.addWidget(self.board_view, 5)

        # Container for the right panel, with a small top margin.
        right_panel_container = QWidget()
        right_panel_layout = QVBoxLayout(right_panel_container)
        right_panel_layout.setContentsMargins(0, 12, 0, 0)
        right_panel_layout.addWidget(self.create_right_panel_tabs())

        main_layout.addWidget(right_panel_container, 3)

        self.setMinimumSize(1200, 800)

    def create_right_panel_tabs(self):
        """Create the Play, AI Analysis and Settings tabs and keep references to their panels."""
        tab_widget = QTabWidget()

        play_tab = QWidget()
        self.control_panel = ControlPanel(self)
        play_tab.setLayout(self.control_panel)
        tab_widget.addTab(play_tab, "Play")

        ai_analysis_tab = QWidget()
        self.analysis_panel = AnalysisPanel(self)
        ai_analysis_tab.setLayout(self.analysis_panel)
        tab_widget.addTab(ai_analysis_tab, "AI Analysis")

        settings_tab = QWidget()
        self.config_panel = ConfigPanel(self)
        settings_tab.setLayout(self.config_panel)
        tab_widget.addTab(settings_tab, "Settings")
        return tab_widget

    def create_status_bar(self):
        """Create the status bar with a permanent message label."""
        status_bar = QStatusBar(self)
        self.setStatusBar(status_bar)
        self.status_label = QLabel("Ready")
        status_bar.addPermanentWidget(self.status_label)

    def update_status_bar(self, message):
        """Show ``message`` in the status bar. GUI thread only; other threads emit status_message_main_thread."""
        self.status_label.setText(message)

    def create_menu_bar(self):
        """Build the File, New Game and Logging menus."""
        menu_bar = QMenuBar(self)
        self.setMenuBar(menu_bar)

        file_menu = QMenu("File", self)
        menu_bar.addMenu(file_menu)

        # QAction.triggered passes a `checked` bool. Absorb it explicitly so it can never land
        # in a real parameter.
        save_sgf_action = file_menu.addAction("Save as SGF")
        save_sgf_action.triggered.connect(lambda _checked=False: self.save_as_sgf())
        save_sgf_action.setShortcut("Ctrl+S")

        save_sgf_to_clipboard_action = file_menu.addAction("SGF to Clipboard")
        save_sgf_to_clipboard_action.triggered.connect(self.copy_sgf_to_clipboard)
        save_sgf_to_clipboard_action.setShortcut("Ctrl+C")

        paste_sgf_action = file_menu.addAction("Paste SGF from Clipboard")
        paste_sgf_action.triggered.connect(self.paste_sgf_from_clipboard)
        paste_sgf_action.setShortcut("Ctrl+V")

        exit_action = file_menu.addAction("Exit")
        exit_action.triggered.connect(self.close)
        exit_action.setShortcut("Ctrl+Q")

        new_game_menu = QMenu("New Game", self)
        menu_bar.addMenu(new_game_menu)

        for size in [5, 9, 13, 19]:
            new_game_action = QAction(f"New Game ({size}x{size})", self)
            new_game_action.triggered.connect(lambda _checked, s=size: self.new_game(s))
            new_game_menu.addAction(new_game_action)

        logging_menu = QMenu("Logging", self)
        menu_bar.addMenu(logging_menu)
        for level in ["DEBUG", "INFO", "WARNING", "ERROR"]:
            logging_action = QAction(level.capitalize(), self)
            # `checked` must be absorbed by its own parameter. Otherwise it overwrites the bound
            # level, and setLevel(False) means NOTSET.
            logging_action.triggered.connect(lambda _checked=False, lvl=level: self.set_logging_level(lvl))
            logging_menu.addAction(logging_action)

    def on_pass_move(self):
        """Pass for the side to move."""
        self.make_move(None)

    def eventFilter(self, obj, event):
        """Use the mouse wheel over the board for navigation: up goes back, down goes forward."""
        if obj == self.board_view and event.type() == QEvent.Wheel:
            delta = event.angleDelta().y()
            if delta > 0:
                self.on_prev_move()
            elif delta < 0:
                self.on_next_move()
            return True
        return super().eventFilter(obj, event)
|