goshape 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goshape-0.1.0.dist-info/METADATA +66 -0
- goshape-0.1.0.dist-info/RECORD +17 -0
- goshape-0.1.0.dist-info/WHEEL +4 -0
- goshape-0.1.0.dist-info/entry_points.txt +2 -0
- goshape-0.1.0.dist-info/licenses/LICENSE +40 -0
- shape/game_logic.py +265 -0
- shape/katago/analysis.cfg +215 -0
- shape/katago/downloader.py +410 -0
- shape/katago/engine.py +176 -0
- shape/main.py +46 -0
- shape/ui/board_view.py +296 -0
- shape/ui/main_window.py +347 -0
- shape/ui/tab_analysis.py +101 -0
- shape/ui/tab_config.py +68 -0
- shape/ui/tab_main_control.py +315 -0
- shape/ui/ui_utils.py +120 -0
- shape/utils.py +7 -0
shape/ui/tab_analysis.py
ADDED
@@ -0,0 +1,101 @@
|
|
1
|
+
import pyqtgraph as pg
|
2
|
+
from PySide6.QtCore import Qt
|
3
|
+
from PySide6.QtWidgets import (
|
4
|
+
QHeaderView,
|
5
|
+
QLabel,
|
6
|
+
QPushButton,
|
7
|
+
QTableWidget,
|
8
|
+
QTableWidgetItem,
|
9
|
+
)
|
10
|
+
|
11
|
+
from shape.ui.ui_utils import SettingsTab, create_label_info_section
|
12
|
+
from shape.utils import setup_logging
|
13
|
+
|
14
|
+
# Module-level logging handle returned by the project's setup_logging() helper
# (presumably a logging.Logger; the helper is defined in shape.utils, not shown here).
logger = setup_logging()
|
15
|
+
|
16
|
+
|
17
|
+
class AnalysisPanel(SettingsTab):
|
18
|
+
WARNING_TEXT = "SHAPE is not optimized for AI analysis, and this tab is provided mainly for debugging and a quick look at the score graph after a game."
|
19
|
+
|
20
|
+
def create_widgets(self):
    """Build the analysis tab: summary labels, top-moves table, score graph and controls.

    Widgets are added to this layout from top to bottom in the order below.
    """
    # Caption for each summary row, keyed by the name used in ``self.info_widgets``.
    summary_rows = {
        "win_rate": "Black Win Rate:",
        "score": "Score:",
        "mistake_size": "Mistake Size (Score Lead):",
        "top_moves": "Top Moves:",
        "total_visits": "Total Visits:",
    }
    summary_box, self.info_widgets = create_label_info_section(summary_rows)
    self.addWidget(summary_box)

    self.top_moves_table = self.create_top_moves_table()
    self.addWidget(self.top_moves_table)

    # Score graph: white background, labelled Y axis, faint grid.
    plot = pg.PlotWidget()
    plot.setBackground("w")
    plot.setLabel("left", "Score")
    plot.showGrid(x=True, y=True, alpha=0.3)
    self.graph_widget = plot
    self.addWidget(plot)
    self.addStretch(1)

    deeper_button = QPushButton("Deeper AI Analysis")
    deeper_button.clicked.connect(self.on_extra_visits)
    self.addWidget(deeper_button)
    warning_label = QLabel(self.WARNING_TEXT, wordWrap=True)
    self.addWidget(warning_label)
|
46
|
+
|
47
|
+
def on_extra_visits(self):
    """Request a deeper AI analysis of the current node.

    The requested total is twice the visits already recorded for the node's
    AI analysis (profile ``None``), but never fewer than 500. The requested
    number is also shown in the status bar.
    """
    window = self.main_window
    node = window.game_logic.current_node
    existing = node.get_analysis(None)
    visits_so_far = existing.visit_count() if existing else 0
    target_visits = max(500, int(visits_so_far * 2))
    window.request_analysis(node, human_profile=None, force_visits=target_visits)
    window.update_status_bar(f"Requested {target_visits} total visits")
|
54
|
+
|
55
|
+
def create_top_moves_table(self):
    """Create the read-only, non-focusable table that lists candidate moves.

    Returns:
        A ``QTableWidget`` that starts with 5 rows and one column per header.
    """
    headers = ["Move", "B Win Rate", "B Score", "Visits"]
    moves_table = QTableWidget(5, len(headers))
    moves_table.setHorizontalHeaderLabels(headers)
    # Columns share the available width equally.
    moves_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
    # Display only: no cell editing and no keyboard focus.
    moves_table.setEditTriggers(QTableWidget.NoEditTriggers)
    moves_table.setFocusPolicy(Qt.NoFocus)
    moves_table.setMinimumHeight(300)
    return moves_table
|
63
|
+
|
64
|
+
def update_ui(self):
    """Refresh labels, top-moves table and score graph for the current node.

    When the node has no AI analysis the labels and table are cleared.
    The graph and the mistake-size label are refreshed in both cases.
    """
    game = self.main_window.game_logic
    node = game.current_node
    analysis = node.get_analysis(None)
    if analysis:
        labels = self.info_widgets
        labels["win_rate"].setText(f"{analysis.win_rate() * 100:.1f}%")
        lead = analysis.ai_score()
        leader = "B" if lead >= 0 else "W"
        labels["score"].setText(f"{leader}+{abs(lead):.1f}")
        labels["total_visits"].setText(f"{analysis.visit_count()}")
        # NOTE(review): the "top_moves" label is never updated here, so it keeps
        # whatever clear() last wrote ("N/A") -- confirm that is intended.

        candidates = analysis.ai_moves()
        table = self.top_moves_table
        table.setRowCount(len(candidates))
        for row, candidate in enumerate(candidates):
            cells = (
                candidate["move"],
                f"{candidate['winrate'] * 100:.1f}%",
                f"{candidate['scoreLead']:.1f}",
                f"{candidate['visits']}",
            )
            for column, text in enumerate(cells):
                table.setItem(row, column, QTableWidgetItem(text))
    else:
        self.clear()

    # An empty history is drawn as the single placeholder point (0, 0).
    self.update_graph(game.get_score_history() or [(0, 0)])

    mistake = node.mistake_size()
    mistake_text = "N/A" if mistake is None else f"{mistake:.2f}"
    self.info_widgets["mistake_size"].setText(mistake_text)
|
91
|
+
|
92
|
+
def clear(self):
    """Reset every summary label to "N/A" and empty the cells of the top-moves table."""
    placeholder = "N/A"
    for label in self.info_widgets.values():
        label.setText(placeholder)
    # clearContents() removes the items but keeps the headers and dimensions.
    self.top_moves_table.clearContents()
|
96
|
+
|
97
|
+
def update_graph(self, scores: list[tuple[int, float]]):
    """Redraw the score graph from ``(move, score)`` pairs.

    Args:
        scores: Sequence of ``(move, score)`` pairs. An empty sequence is
            drawn as the single point ``(0, 0)`` instead of raising.
    """
    # zip(*[]) yields nothing to unpack, so an empty history used to raise
    # ValueError; fall back to the placeholder point (0, 0) instead.
    moves, values = zip(*(scores or [(0, 0)]))
    self.graph_widget.plot(moves, values, pen=pg.mkPen(color="b", width=2), clear=True)
    # Pad the Y range by 0.1 on both sides of the observed extremes.
    self.graph_widget.setYRange(min(values) - 0.1, max(values) + 0.1)
    # Keep the X range at least one unit wide even for a single point.
    self.graph_widget.setXRange(0, max(1, len(moves) - 1))
|
shape/ui/tab_config.py
ADDED
@@ -0,0 +1,68 @@
|
|
1
|
+
from shape.ui.ui_utils import (
|
2
|
+
SettingsTab,
|
3
|
+
create_config_section,
|
4
|
+
create_double_spin_box,
|
5
|
+
create_spin_box,
|
6
|
+
)
|
7
|
+
|
8
|
+
|
9
|
+
class ConfigPanel(SettingsTab):
|
10
|
+
def create_widgets(self):
    """Build the three configuration sections: policy sampling, analysis and mistake feedback."""
    # Policy sampling controls.
    self.top_k = create_spin_box(1, 100, 50)
    self.top_p = create_double_spin_box(0.1, 1.0, 1.0, 0.05)
    self.min_p = create_double_spin_box(0.0, 1.0, 0.05, 0.01)
    sampling_fields = {
        "Top K:": self.top_k,
        "Top P:": self.top_p,
        "Min P:": self.min_p,
    }
    sampling_section = create_config_section(
        "Policy Sampling Settings",
        sampling_fields,
        note="These settings affect both opponent move selection and the heatmap visualization.",
    )
    self.addWidget(sampling_section)

    # Visit count used for analysis (see get_ai_strength).
    self.visits = create_spin_box(8, 1024, 24)
    analysis_section = create_config_section("Analysis Settings", {"Visits:": self.visits})
    self.addWidget(analysis_section)

    # Thresholds read by should_halt_on_mistake.
    self.mistake_size_spinbox = create_spin_box(0, 100, 1)
    self.target_rank_spinbox = create_spin_box(0, 100, 20)
    self.max_probability_spinbox = create_spin_box(0, 5, 1)
    halt_fields = {
        "Mistake size > (points)": self.mistake_size_spinbox,
        "Either Target rank probability < (%)": self.target_rank_spinbox,
        "Or Max policy probability < (%)": self.max_probability_spinbox,
    }
    self.addWidget(create_config_section("Mistake Feedback Settings. Halt if:", halt_fields))
    self.addStretch(1)
|
41
|
+
|
42
|
+
def connect_signals(self):
    """Route value changes of the sampling and visits controls to ``on_settings_changed``.

    The three mistake-feedback spin boxes are not connected here.
    """
    notify = self.on_settings_changed
    self.top_k.valueChanged.connect(notify)
    self.top_p.valueChanged.connect(notify)
    self.min_p.valueChanged.connect(notify)
    self.visits.valueChanged.connect(notify)
|
45
|
+
|
46
|
+
def get_ai_strength(self):
    """Return the current value of the "Visits" spin box."""
    configured_visits = self.visits.value()
    return configured_visits
|
48
|
+
|
49
|
+
def get_sampling_settings(self):
    """Return the policy sampling settings as a dict with keys ``top_k``, ``top_p`` and ``min_p``."""
    setting_names = ("top_k", "top_p", "min_p")
    # Each setting name matches the attribute holding its spin box.
    return {name: getattr(self, name).value() for name in setting_names}
|
55
|
+
|
56
|
+
def should_halt_on_mistake(self, move_stats) -> str | None:
|
57
|
+
if move_stats:
|
58
|
+
max_prob = max(move_stats[f"{k}_prob"] for k in ["player", "target", "ai"])
|
59
|
+
mistake_size = move_stats["mistake_size"]
|
60
|
+
target_rank_prob = move_stats["move_like_target"]
|
61
|
+
|
62
|
+
if mistake_size > self.mistake_size_spinbox.value():
|
63
|
+
mistake_size_msg = f"Mistake size ({mistake_size:.2f}) > {self.mistake_size_spinbox.value()} points"
|
64
|
+
if max_prob < self.max_probability_spinbox.value() / 100:
|
65
|
+
return f"Max policy probability ({max_prob:.1%}) < {self.max_probability_spinbox.value()}% and {mistake_size_msg}"
|
66
|
+
if target_rank_prob < self.target_rank_spinbox.value() / 100:
|
67
|
+
return f"Target rank probability ({target_rank_prob:.1%}) < {self.target_rank_spinbox.value()}% and {mistake_size_msg}"
|
68
|
+
return None
|
shape/ui/tab_main_control.py
ADDED
@@ -0,0 +1,315 @@
|
|
1
|
+
from functools import cache
|
2
|
+
|
3
|
+
from PySide6.QtCore import Qt
|
4
|
+
from PySide6.QtGui import QKeySequence, QShortcut
|
5
|
+
from PySide6.QtWidgets import (
|
6
|
+
QButtonGroup,
|
7
|
+
QCheckBox,
|
8
|
+
QComboBox,
|
9
|
+
QGridLayout,
|
10
|
+
QGroupBox,
|
11
|
+
QHBoxLayout,
|
12
|
+
QLabel,
|
13
|
+
QProgressBar,
|
14
|
+
QPushButton,
|
15
|
+
QSizePolicy,
|
16
|
+
)
|
17
|
+
|
18
|
+
from shape.ui.ui_utils import SettingsTab
|
19
|
+
from shape.utils import setup_logging
|
20
|
+
|
21
|
+
# Module-level logging handle returned by the project's setup_logging() helper
# (presumably a logging.Logger; the helper is defined in shape.utils, not shown here).
logger = setup_logging()
|
22
|
+
|
23
|
+
# Half-open (start, stop) range of rank ids offered in the rank dropdowns:
# range(*RANK_RANGE) yields -20..8, which get_rank_from_id renders as "20k".."9d".
# get_human_profile_from_id maps ids at or above RANK_RANGE[1] to pro/AI profiles.
RANK_RANGE = (-20, 9)
|
24
|
+
|
25
|
+
|
26
|
+
@cache
def get_rank_from_id(id: int) -> str:
    """Convert a numeric rank id to its display text.

    Negative ids are kyu ranks (``-3`` -> ``"3k"``); ids from zero upward
    are dan ranks, offset by one (``0`` -> ``"1d"``).
    """
    return f"{-id}k" if id < 0 else f"{id + 1}d"
|
31
|
+
|
32
|
+
|
33
|
+
def get_human_profile_from_id(id: int, preaz: bool = False) -> str | None:
    """Map a rank id to a human-profile name.

    Args:
        id: Rank id as stored in the rank dropdowns.
        preaz: Use the ``preaz_`` prefix instead of ``rank_`` for ordinary ranks.

    Returns:
        ``None`` (meaning the AI) for ids at least 10 above the top of
        ``RANK_RANGE``, ``"proyear_2023"`` for other ids at or above the top,
        otherwise the prefixed rank text such as ``"rank_3k"``.
    """
    first_pro_id = RANK_RANGE[1]
    if id >= first_pro_id + 10:
        return None  # AI
    if id >= first_pro_id:
        return "proyear_2023"
    prefix = "preaz_" if preaz else "rank_"
    return f"{prefix}{get_rank_from_id(id)}"
|
39
|
+
|
40
|
+
|
41
|
+
class ControlPanel(SettingsTab):
|
42
|
+
def create_widgets(self):
    """Stack the game-control, player-settings and info sections, followed by a stretch."""
    self.setSpacing(5)
    section_builders = (
        self.create_game_control_box,
        self.create_player_settings_group,
        self.create_collapsible_info_panel,
    )
    # Build each section right before adding it, preserving top-to-bottom order.
    for build_section in section_builders:
        self.addWidget(build_section())
    self.addStretch(1)
|
48
|
+
|
49
|
+
def create_game_control_box(self):
    """Build the "Game Control" group: player colour buttons and opponent move controls.

    Returns:
        The ``QGroupBox`` containing the controls.
    """
    grid = QGridLayout()
    grid.setVerticalSpacing(5)
    grid.setHorizontalSpacing(5)

    # Row 0: checkable colour buttons in one button group; the first starts checked.
    grid.addWidget(QLabel("Play as:"), 0, 0)
    self.player_color = QButtonGroup(self)
    for column, color_name in enumerate(("Black", "White"), start=1):
        color_button = QPushButton(color_name)
        color_button.setCheckable(True)
        self.player_color.addButton(color_button)
        grid.addWidget(color_button, 0, column)
    first_button = self.player_color.buttons()[0]
    first_button.setChecked(True)

    # Row 1: force an opponent move, or let it play automatically.
    self.auto_play_checkbox = QCheckBox("Auto-play", checked=True)
    self.ai_move_button = QPushButton("Force Move (Ctrl+Enter)")
    # NOTE(review): in Qt key names "Enter" is the keypad key and "Return" the
    # main one -- confirm Ctrl+Return should not trigger this button as well.
    self.ai_move_button.setShortcut("Ctrl+Enter")
    grid.addWidget(QLabel("Opponent:"), 1, 0)
    grid.addWidget(self.ai_move_button, 1, 1)
    grid.addWidget(self.auto_play_checkbox, 1, 2)

    box = QGroupBox("Game Control")
    box.setLayout(grid)
    return box
|
73
|
+
|
74
|
+
def create_player_settings_group(self):
    """Build the "Player Settings" group: rank dropdowns, opponent selection and heatmap toggles.

    Returns:
        The ``QGroupBox`` containing the controls.
    """
    grid = QGridLayout()
    grid.setVerticalSpacing(5)
    grid.setHorizontalSpacing(5)

    # Row 0: current rank and target rank dropdowns, stored in self.rank_dropdowns.
    self.rank_dropdowns = {}
    rank_fields = (
        ("current", "Current Rank:", "3k", 0),
        ("target", "Target Rank:", "2d", 2),
    )
    for key, caption, default_rank, column in rank_fields:
        grid.addWidget(QLabel(caption), 0, column)
        dropdown = QComboBox()
        self.rank_dropdowns[key] = dropdown
        self.populate_rank_combo(dropdown, default_rank)
        grid.addWidget(dropdown, 0, column + 1)

    # Row 1: opponent type, plus one selector per type. The three selectors
    # share the same grid cells; on_opponent_type_changed toggles their visibility.
    grid.addWidget(QLabel("Opponent:"), 1, 0)
    self.opponent_type_combo = QComboBox()
    self.opponent_type_combo.addItems(["Rank", "Pre-AZ", "Pro"])
    grid.addWidget(self.opponent_type_combo, 1, 1)

    self.opponent_pro_combo = QComboBox()
    self.populate_pro_combo(self.opponent_pro_combo)
    grid.addWidget(self.opponent_pro_combo, 1, 2, 1, 2)

    self.opponent_rank_combo = QComboBox()
    self.populate_rank_combo(self.opponent_rank_combo, "1k")
    grid.addWidget(self.opponent_rank_combo, 1, 2, 1, 2)

    self.opponent_rank_preaz_combo = QComboBox()
    self.populate_rank_combo(self.opponent_rank_preaz_combo, "1k")
    grid.addWidget(self.opponent_rank_preaz_combo, 1, 2, 1, 2)

    # Row 3: checkable heatmap buttons, each with a shortcut and its own checked colour.
    grid.addWidget(QLabel("Heatmap:"), 3, 0)
    heatmap_row = QHBoxLayout()
    heatmap_row.setSpacing(2)  # Reduce spacing between heatmap buttons

    board_view = self.main_window.board_view
    heatmap_specs = (
        ("Current", "Ctrl+1", board_view.PLAYER_POLICY_COLOR),
        ("Target", "Ctrl+2", board_view.TARGET_POLICY_COLOR),
        ("AI", "Ctrl+3", board_view.AI_POLICY_COLOR),
        ("Opponent", "Ctrl+9", board_view.OPPONENT_POLICY_COLOR),
    )
    self.heatmap_buttons = {}
    for caption, shortcut, color in heatmap_specs:
        toggle = QPushButton(f"{caption} ({shortcut})")
        toggle.setCheckable(True)
        toggle.setShortcut(shortcut)
        toggle.setStyleSheet(
            f"""
QPushButton:checked {{
background-color: {color.name()};
border: 2px solid black;
color: white;
}}
"""
        )
        # Buttons are looked up elsewhere by the lower-cased caption.
        self.heatmap_buttons[caption.lower()] = toggle
        heatmap_row.addWidget(toggle)

    grid.addLayout(heatmap_row, 3, 1, 1, 3)

    box = QGroupBox("Player Settings")
    box.setLayout(grid)
    return box
|
145
|
+
|
146
|
+
def create_collapsible_info_panel(self):
    """Build the checkable "Info" group with last-move details and probability bars.

    The group starts unchecked. A ``Ctrl+0`` shortcut, parented to the main
    window, toggles its checked state.

    Returns:
        The ``QGroupBox`` stored as ``self.info_group``.
    """
    self.info_group = QGroupBox("Info (Ctrl+0)")
    self.info_group.setCheckable(True)
    self.info_group.setChecked(False)
    toggle_shortcut = QShortcut(QKeySequence("Ctrl+0"), self.main_window)
    toggle_shortcut.activated.connect(lambda: self.info_group.setChecked(not self.info_group.isChecked()))

    self.last_move_label = QLabel("Last move: N/A")
    self.current_policy_widget = ProbabilityWidget()
    self.target_policy_widget = ProbabilityWidget()
    self.ai_policy_widget = ProbabilityWidget()
    self.bayesian_prob_widget = ProbabilityWidget()

    # One caption/value pair per grid row, in display order.
    # NOTE(review): the first row pairs a "Last move:" caption with a value that
    # itself starts with "Last move:" -- confirm the duplication is intended.
    rows = (
        ("Last move:", self.last_move_label),
        ("Current rank policy:", self.current_policy_widget),
        ("Target rank policy:", self.target_policy_widget),
        ("AI policy:", self.ai_policy_widget),
        ("P(target | move):", self.bayesian_prob_widget),
    )
    layout = QGridLayout()
    for row, (caption, value_widget) in enumerate(rows):
        layout.addWidget(QLabel(caption), row, 0)
        layout.addWidget(value_widget, row, 1)

    # Halt reason spans both columns below the rows above. Word wrap and size
    # policy are set once; the original repeated both calls after addWidget.
    self.halted_reason_label = QLabel("")
    self.halted_reason_label.setWordWrap(True)
    self.halted_reason_label.setAlignment(Qt.AlignLeft | Qt.AlignTop)
    self.halted_reason_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
    layout.addWidget(self.halted_reason_label, len(rows), 0, 1, 2)

    self.info_group.setLayout(layout)
    return self.info_group
|
180
|
+
|
181
|
+
def populate_rank_combo(self, combo, default_rank: str):
    """Fill ``combo`` with every rank in ``RANK_RANGE`` and select ``default_rank``.

    Each entry shows the rank text and carries the numeric rank id as its item data.
    """
    rank_ids = range(*RANK_RANGE)
    for rank_id in rank_ids:
        rank_text = get_rank_from_id(rank_id)
        combo.addItem(rank_text, rank_id)
    combo.setCurrentText(default_rank)
|
185
|
+
|
186
|
+
def populate_pro_combo(self, combo):
    """Fill ``combo`` with ``proyear_<year>`` profiles and select a default.

    Years run from 1803 to 2023 in steps of five. The default is the listed
    year closest to 1985.
    """
    years = range(1803, 2024, 5)
    combo.addItems([f"proyear_{year}" for year in years])
    # 1985 is not on the five-year grid, and setCurrentText() ignores text that
    # is not in the list, so the selection silently stayed on the first entry.
    # Select the nearest year that actually exists (1983).
    default_year = min(years, key=lambda year: abs(year - 1985))
    combo.setCurrentText(f"proyear_{default_year}")
|
190
|
+
|
191
|
+
def connect_signals(self):
    """Wire the panel's controls to their handlers.

    Player colour, heatmap toggles, auto-play, the rank dropdowns and the
    info group report to ``on_settings_changed``. The opponent type combo
    reports to ``on_opponent_type_changed``.
    """
    notify = self.on_settings_changed
    self.player_color.buttonClicked.connect(notify)
    for toggle in self.heatmap_buttons.values():
        toggle.toggled.connect(notify)
    self.auto_play_checkbox.stateChanged.connect(notify)
    for dropdown in self.rank_dropdowns.values():
        dropdown.currentIndexChanged.connect(notify)
    self.info_group.toggled.connect(notify)
    # NOTE(review): the opponent selectors themselves are not connected to
    # on_settings_changed -- confirm their values are only read on demand.
    self.opponent_type_combo.currentIndexChanged.connect(self.on_opponent_type_changed)
|
200
|
+
|
201
|
+
def on_opponent_type_changed(self):
    """Show only the opponent selector that matches the chosen opponent type."""
    selected_type = self.opponent_type_combo.currentText()
    selector_for_type = (
        ("Pro", self.opponent_pro_combo),
        ("Rank", self.opponent_rank_combo),
        ("Pre-AZ", self.opponent_rank_preaz_combo),
    )
    for type_name, selector in selector_for_type:
        selector.setVisible(selected_type == type_name)
|
206
|
+
|
207
|
+
def get_human_profiles(self):
    """Return the profile names for the player, the opponent and the target rank.

    Returns:
        Dict with keys ``"player"``, ``"opponent"`` and ``"target"``. The
        opponent entry comes from whichever selector matches the opponent
        type: the pro-year text, a rank profile, or a pre-AZ rank profile.
    """
    kind = self.opponent_type_combo.currentText()
    if kind == "Pro":
        opponent = self.opponent_pro_combo.currentText()
    elif kind == "Rank":
        opponent = get_human_profile_from_id(self.opponent_rank_combo.currentData())
    else:  # Rank (pre-AZ)
        opponent = get_human_profile_from_id(self.opponent_rank_preaz_combo.currentData(), preaz=True)

    ranks = self.rank_dropdowns
    player = get_human_profile_from_id(ranks["current"].currentData())
    target = get_human_profile_from_id(ranks["target"].currentData())
    return {"player": player, "opponent": opponent, "target": target}
|
222
|
+
|
223
|
+
def get_player_color(self):
    """Return the first letter of the checked colour button's text ("B" or "W")."""
    checked_button = self.player_color.checkedButton()
    color_name = checked_button.text()
    return color_name[0]
|
225
|
+
|
226
|
+
def is_auto_play_enabled(self):
    """Return whether the "Auto-play" check box is ticked."""
    auto_play_on = self.auto_play_checkbox.isChecked()
    return auto_play_on
|
228
|
+
|
229
|
+
def get_move_stats(self, node):
    """Collect policy probabilities and the mistake size for the move at ``node``.

    Args:
        node: Game node. A node without a move (the root) has no stats.

    Returns:
        ``None`` for the root or when any of the three analyses (player
        profile, target profile, AI) is missing. Otherwise a dict with the
        absolute and relative probability of the move under each policy,
        ``move_like_target`` (target probability divided by the sum of player
        and target probabilities) and ``mistake_size``.
    """
    played = node.move
    if not played:  # root
        return None

    profiles = self.get_human_profiles()
    # parent=True is passed for all three lookups; presumably it selects the
    # analysis of the position the move was played from -- confirm in get_analysis.
    player_analysis = node.get_analysis(profiles["player"], parent=True)
    target_analysis = node.get_analysis(profiles["target"], parent=True)
    ai_analysis = node.get_analysis(None, parent=True)
    if not (player_analysis and target_analysis and ai_analysis):
        return None

    player_prob, player_relative = player_analysis.human_policy.at(played)
    target_prob, target_relative = target_analysis.human_policy.at(played)
    ai_prob, ai_relative = ai_analysis.ai_policy.at(played)
    # The floor on the denominator avoids division by zero when both are 0.
    like_target = target_prob / max(player_prob + target_prob, 1e-10)
    return {
        "player_prob": player_prob,
        "target_prob": target_prob,
        "ai_prob": ai_prob,
        "player_relative_prob": player_relative,
        "target_relative_prob": target_relative,
        "ai_relative_prob": ai_relative,
        "move_like_target": like_target,
        "mistake_size": node.mistake_size(),
    }
|
251
|
+
|
252
|
+
def get_heatmap_settings(self):
    """Describe which policy heatmaps should currently be drawn.

    Returns:
        Dict with ``"policy"``, a list of ``(profile, enabled)`` pairs, and
        ``"enabled"``, true when any pair is enabled. While the opponent
        toggle is checked the list is three disabled placeholders followed
        by the opponent profile. Otherwise it holds the player, target and
        AI (``None``) entries with their toggle states.
    """
    profiles = self.get_human_profiles()
    toggles = self.heatmap_buttons
    if toggles["opponent"].isChecked():
        policies = [("", False)] * 3 + [(profiles["opponent"], True)]
    else:
        policies = [
            (profiles["player"], toggles["current"].isChecked()),
            (profiles["target"], toggles["target"].isChecked()),
            (None, toggles["ai"].isChecked()),
        ]
    return {
        "policy": policies,
        "enabled": any(is_on for _, is_on in policies),
    }
|
265
|
+
|
266
|
+
def update_ui(self):
    """Refresh the opponent heatmap button, last-move text, halt reason and probability bars."""
    # The opponent heatmap button is collapsed to zero size unless it is checked.
    opponent_toggle = self.heatmap_buttons["opponent"]
    if opponent_toggle.isChecked():
        opponent_toggle.setFixedSize(opponent_toggle.sizeHint())
    else:
        opponent_toggle.setFixedSize(0, 0)

    game = self.main_window.game_logic
    player_color = self.get_player_color()
    # Use the current node when game.player matches the chosen colour,
    # otherwise its parent (which may be None).
    if game.player == player_color:
        node = game.current_node
    else:
        node = game.current_node.parent

    last_move = node.move if node else None
    move_text = last_move.gtp() if last_move else "N/A"
    self.last_move_label.setText(f"Last move: {move_text}")

    halted_reason = node.autoplay_halted_reason if node else None
    self.halted_reason_label.setText(f"Critical mistake: {halted_reason}" if halted_reason else "")

    # Stats are computed only while the info group is checked.
    move_stats = None
    if self.info_group.isChecked() and node:
        move_stats = self.get_move_stats(node)

    if move_stats:
        self.current_policy_widget.update_probability(move_stats["player_prob"], move_stats["player_relative_prob"])
        self.target_policy_widget.update_probability(move_stats["target_prob"], move_stats["target_relative_prob"])
        self.ai_policy_widget.update_probability(move_stats["ai_prob"], move_stats["ai_relative_prob"])
        self.bayesian_prob_widget.update_probability(move_stats["move_like_target"])
    else:
        bars = (
            self.current_policy_widget,
            self.target_policy_widget,
            self.ai_policy_widget,
            self.bayesian_prob_widget,
        )
        for bar in bars:
            bar.set_na()
|
296
|
+
|
297
|
+
|
298
|
+
class ProbabilityWidget(QProgressBar):
|
299
|
+
def __init__(self, probability: float = 0):
    """Create the bar showing ``probability`` (a fraction) scaled by 100.

    Args:
        probability: Initial fraction to display; defaults to 0.
    """
    super().__init__()
    self.setValue(probability * 100)
    self.setTextVisible(True)
    # Green fill on a grey track with bold, centred, light-coloured text.
    chunk_rule = "QProgressBar::chunk { background-color: green; }"
    bar_rule = "QProgressBar { background-color: #aaa; border: 1px solid #cccccc; font-size: 14px; font-weight: bold; color: #eee; text-align: center; }"
    self.setStyleSheet(chunk_rule + bar_rule)
|
307
|
+
|
308
|
+
def update_probability(self, label_probability: float, fill_percentage: float = None):
|
309
|
+
fill_percentage = fill_percentage or label_probability
|
310
|
+
self.setFormat(f"{label_probability:.2%}")
|
311
|
+
self.setValue(fill_percentage * 100)
|
312
|
+
|
313
|
+
def set_na(self):
    """Empty the bar and show "N/A" as its text."""
    empty_value = 0
    self.setValue(empty_value)
    self.setFormat("N/A")
|
shape/ui/ui_utils.py
ADDED
@@ -0,0 +1,120 @@
|
|
1
|
+
from PySide6.QtCore import Qt, Signal
|
2
|
+
from PySide6.QtWidgets import (
|
3
|
+
QDoubleSpinBox,
|
4
|
+
QFormLayout,
|
5
|
+
QGroupBox,
|
6
|
+
QLabel,
|
7
|
+
QSpinBox,
|
8
|
+
QVBoxLayout,
|
9
|
+
QWidget,
|
10
|
+
)
|
11
|
+
|
12
|
+
# Stylesheets
# Qt style sheet text with rules for the widget types used by the UI: a base
# font size, group boxes and their titles, push buttons (normal, hover,
# checked), combo/spin boxes, labels and table widgets.
# NOTE(review): where this sheet is applied is not visible in this file --
# presumably on the main window or application.
MAIN_STYLESHEET = """
QWidget {
font-size: 11px;
}
QGroupBox {
font-weight: bold;
border: 1px solid #cccccc;
border-radius: 6px;
margin-top: 6px;
padding-top: 6px;
}
QGroupBox::title {
subcontrol-origin: margin;
left: 7px;
padding: 0px 5px 0px 5px;
}
QPushButton {
background-color: #f0f0f0;
border: 1px solid #cccccc;
border-radius: 4px;
padding: 3px;
min-width: 30px;
}
QPushButton:hover {
background-color: #e0e0e0;
}
QPushButton:checked {
background-color: #c0c0c0;
border: 2px solid #808080;
}
QComboBox, QSpinBox {
border: 1px solid #cccccc;
border-radius: 4px;
padding: 1px;
min-width: 30px;
}
QLabel {
padding-right: 3px;
}
QTableWidget {
border: 1px solid #cccccc;
border-radius: 5px;
}
"""
|
57
|
+
|
58
|
+
|
59
|
+
# Helper functions
|
60
|
+
def create_spin_box(min_value, max_value, default_value):
    """Create an integer spin box limited to ``[min_value, max_value]`` and set to ``default_value``."""
    box = QSpinBox()
    # The range is set before the value so the default is checked against the new limits.
    box.setRange(min_value, max_value)
    box.setValue(default_value)
    return box
|
65
|
+
|
66
|
+
|
67
|
+
def create_double_spin_box(min_value, max_value, default_value, step):
    """Create a floating-point spin box.

    Args:
        min_value: Lower bound of the allowed range.
        max_value: Upper bound of the allowed range.
        default_value: Initial value.
        step: Increment applied by a single step.
    """
    box = QDoubleSpinBox()
    box.setRange(min_value, max_value)
    box.setSingleStep(step)
    box.setValue(default_value)
    return box
|
73
|
+
|
74
|
+
|
75
|
+
def create_config_section(title: str, widgets: dict[str, QWidget], note: str = None):
    """Build a titled group box holding a form of labelled widgets.

    Args:
        title: Title of the group box.
        widgets: Maps each row caption to the widget shown next to it; rows
            are added in the dict's iteration order.
        note: Optional word-wrapped text placed in its own row above the widgets.

    Returns:
        The ``QGroupBox`` containing the form.
    """
    form = QFormLayout()
    form.setLabelAlignment(Qt.AlignRight)
    form.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow)
    if note:
        note_label = QLabel(note, wordWrap=True)
        form.addRow(note_label)
    for caption, field in widgets.items():
        form.addRow(caption, field)

    section = QGroupBox(title)
    section.setLayout(form)
    return section
|
86
|
+
|
87
|
+
|
88
|
+
def create_label_info_section(labels: dict[str, str]):
    """Build a form of caption/value label rows.

    Args:
        labels: Maps a lookup key to the caption shown for that row.

    Returns:
        ``(container, value_labels)``: the widget holding the form, and a
        dict mapping each key to the (initially empty) ``QLabel`` of its row.
    """
    form = QFormLayout()
    form.setLabelAlignment(Qt.AlignLeft)
    value_labels = {}
    for key, caption in labels.items():
        value_labels[key] = QLabel()
        form.addRow(caption, value_labels[key])

    container = QWidget()
    container.setLayout(form)
    return container, value_labels
|
97
|
+
|
98
|
+
|
99
|
+
class SettingsTab(QVBoxLayout):
    """Base vertical layout for a tab that keeps a reference to the main window.

    Subclasses override ``create_widgets``, ``connect_signals`` and
    ``update_ui``. The constructor calls the first two, in that order.
    """

    # Emitted by on_settings_changed(); subclasses connect widget signals to that method.
    settings_updated = Signal()

    def __init__(self, main_window):
        """Store ``main_window``, apply spacing and margins, then build and wire the tab."""
        super().__init__()
        self.main_window = main_window
        padding = 10
        self.setSpacing(padding)
        self.setContentsMargins(padding, padding, padding, padding)
        self.create_widgets()
        self.connect_signals()

    def create_widgets(self):
        """Hook for subclasses to create and add their widgets; does nothing here."""
        return None

    def connect_signals(self):
        """Hook for subclasses to connect widget signals; does nothing here."""
        return None

    def on_settings_changed(self):
        """Emit ``settings_updated``; intended as the slot for widget change signals."""
        self.settings_updated.emit()

    def update_ui(self):
        """Hook for subclasses to refresh their widgets; does nothing here."""
        return None
|