goshape 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goshape-0.1.0.dist-info/METADATA +66 -0
- goshape-0.1.0.dist-info/RECORD +17 -0
- goshape-0.1.0.dist-info/WHEEL +4 -0
- goshape-0.1.0.dist-info/entry_points.txt +2 -0
- goshape-0.1.0.dist-info/licenses/LICENSE +40 -0
- shape/game_logic.py +265 -0
- shape/katago/analysis.cfg +215 -0
- shape/katago/downloader.py +410 -0
- shape/katago/engine.py +176 -0
- shape/main.py +46 -0
- shape/ui/board_view.py +296 -0
- shape/ui/main_window.py +347 -0
- shape/ui/tab_analysis.py +101 -0
- shape/ui/tab_config.py +68 -0
- shape/ui/tab_main_control.py +315 -0
- shape/ui/ui_utils.py +120 -0
- shape/utils.py +7 -0
@@ -0,0 +1,410 @@
|
|
1
|
+
import asyncio
|
2
|
+
import os
|
3
|
+
import platform
|
4
|
+
import shutil
|
5
|
+
import subprocess
|
6
|
+
import zipfile
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
import httpx
|
11
|
+
from PySide6.QtCore import QThread, Signal
|
12
|
+
from PySide6.QtGui import QFont
|
13
|
+
from PySide6.QtWidgets import (
|
14
|
+
QDialog,
|
15
|
+
QHBoxLayout,
|
16
|
+
QLabel,
|
17
|
+
QProgressBar,
|
18
|
+
QPushButton,
|
19
|
+
QVBoxLayout,
|
20
|
+
QWidget,
|
21
|
+
)
|
22
|
+
|
23
|
+
from shape.utils import setup_logging
|
24
|
+
|
25
|
+
# Module logger obtained from the shared logging helper.
logger = setup_logging()

# Root directory under the user's home where downloaded components are stored.
KATRAIN_DIR = Path.home().joinpath(".katrain")
|
28
|
+
|
29
|
+
|
30
|
+
def _parse_backend_name(line: str) -> str:
    """Extract the backend name from a KataGo ``version`` output line mentioning "backend".

    Well-known backends are normalized ("OpenCL", "CUDA", "CPU"), checked in that order.
    Otherwise the word right before "backend" is returned
    (e.g. "Using TensorRT backend" -> "TensorRT"), falling back to the line's last word.
    """
    for known in ("OpenCL", "CUDA", "CPU"):
        if known in line:
            return known
    words = line.split()
    lowered = [word.lower() for word in words]
    if "backend" in lowered:
        idx = lowered.index("backend")
        if idx > 0:
            return words[idx - 1]
    # The caller only passes lines containing "backend", so ``words`` is non-empty.
    return words[-1]


def get_katago_version_info(katago_path: Path) -> tuple[str, str]:
    """Get KataGo version and backend info by running ``<katago_path> version``.

    Args:
        katago_path: Path to the KataGo executable.

    Returns:
        ``(version, backend)``, e.g. ``("v1.16.0", "OpenCL")``. An element is
        ``"Unknown"`` when it does not appear in the output. Both are ``"Unknown"``
        when the command exits non-zero, takes longer than 10 s, or cannot be started.
        Never raises.
    """
    try:
        result = subprocess.run([str(katago_path), "version"], capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            logger.warning(f"KataGo version command failed: {result.stderr}")
            return "Unknown", "Unknown"

        version = "Unknown"
        backend = "Unknown"
        for raw_line in result.stdout.splitlines():
            line = raw_line.strip()
            if line.startswith("KataGo v"):
                version = line.split()[1]  # Extract version like "v1.15.3"
            elif "backend" in line.lower():
                backend = _parse_backend_name(line)
        return version, backend
    except Exception as e:  # best effort: binary missing, not executable, timed out, ...
        logger.warning(f"Failed to get KataGo version: {e}")
        return "Unknown", "Unknown"
|
60
|
+
|
61
|
+
|
62
|
+
# Location of the downloaded KataGo binary; the file name carries ".exe" only on Windows.
KATAGO_DIR = KATRAIN_DIR.joinpath("katago")
KATAGO_EXE_NAME = {"Windows": "katago.exe"}.get(platform.system(), "katago")
KATAGO_PATH = KATAGO_DIR.joinpath(KATAGO_EXE_NAME)
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
|
68
|
+
class DownloadableComponent:
|
69
|
+
name: str
|
70
|
+
destination_dir: Path
|
71
|
+
destination_filename: str
|
72
|
+
download_url: str
|
73
|
+
is_zip: bool = False
|
74
|
+
found: bool = field(init=False, default=False)
|
75
|
+
downloading: bool = field(init=False, default=False)
|
76
|
+
error: str | None = field(init=False, default=None)
|
77
|
+
|
78
|
+
def __post_init__(self):
|
79
|
+
self.destination_dir.mkdir(parents=True, exist_ok=True)
|
80
|
+
self.check_if_found()
|
81
|
+
|
82
|
+
@property
|
83
|
+
def destination_path(self) -> Path:
|
84
|
+
return self.destination_dir / self.destination_filename
|
85
|
+
|
86
|
+
def check_if_found(self):
|
87
|
+
# Special case for KataGo: check PATH first
|
88
|
+
if self.name == "KataGo Engine":
|
89
|
+
path_katago = shutil.which("katago")
|
90
|
+
if path_katago:
|
91
|
+
self.found = True
|
92
|
+
self.error = None
|
93
|
+
# Update destination to point to the PATH version
|
94
|
+
self._path_katago = Path(path_katago)
|
95
|
+
return True
|
96
|
+
|
97
|
+
self.found = self.destination_path.exists()
|
98
|
+
if self.found:
|
99
|
+
self.error = None # reset error on found
|
100
|
+
return self.found
|
101
|
+
|
102
|
+
def get_widget(self, download_callback, parent_dialog) -> "ComponentWidget":
|
103
|
+
return ComponentWidget(self, download_callback, parent_dialog)
|
104
|
+
|
105
|
+
|
106
|
+
class ComponentWidget(QWidget):
    """One row of the setup dialog: name, status text, progress bar and download button."""

    def __init__(self, component: DownloadableComponent, download_callback, parent: QDialog):
        super().__init__(parent)
        self.component = component
        self.download_callback = download_callback

        self.name_label = QLabel(f"<b>{component.name}</b>")
        self.status_label = QLabel()
        self.download_button = QPushButton()
        self.download_button.clicked.connect(self._on_download_click)
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)

        row = QHBoxLayout()
        for left_widget in (self.name_label, self.status_label):
            row.addWidget(left_widget)
        row.addStretch()
        for right_widget in (self.progress_bar, self.download_button):
            row.addWidget(right_widget)
        self.setLayout(row)
        self.update_status()

    def _on_download_click(self):
        """Lock the button, show an indeterminate bar, and hand the component to the dialog."""
        button, bar = self.download_button, self.progress_bar
        button.setEnabled(False)
        bar.setVisible(True)
        bar.setRange(0, 0)  # min == max == 0 renders a "busy" bar until real progress arrives
        self.download_callback(self.component)

    def _offer_download(self, status_html: str, button_text: str):
        """Show ``status_html`` next to an enabled button labelled ``button_text``."""
        self.status_label.setText(status_html)
        self.download_button.setText(button_text)
        self.download_button.setVisible(True)
        self.download_button.setEnabled(True)
        self.progress_bar.setVisible(False)

    def _found_status_html(self) -> str:
        """Green status text for a present component: engine version, or model file size."""
        comp = self.component
        on_path = hasattr(comp, "_path_katago")

        if comp.name == "KataGo Engine":
            # Show version info for KataGo
            if on_path:
                exe = comp._path_katago
                where = f"Found in PATH: {exe}"
            else:
                exe = comp.destination_path
                where = f"Found at {exe}"
            version, backend = get_katago_version_info(exe)
            return f"<font color='green'>{where}<br/>Version: {version} ({backend})</font>"

        if on_path:
            return f"<font color='green'>Found in PATH: {comp._path_katago}</font>"

        target = comp.destination_path
        size_mb = target.stat().st_size // (1024 * 1024) if target.exists() else 0
        return f"<font color='green'>Found at {target}<br/>Size: {size_mb} MB</font>"

    def update_status(self):
        """Re-check the component and sync label, button and progress bar with its state."""
        comp = self.component
        comp.check_if_found()

        if comp.downloading:
            self.status_label.setText("Downloading...")
            self.download_button.setVisible(False)
            self.progress_bar.setVisible(True)
            return

        if comp.found:
            self.status_label.setText(self._found_status_html())
            self.download_button.setVisible(False)
            self.progress_bar.setVisible(False)
        elif comp.error:
            self._offer_download(f"<font color='red'>Error: {comp.error}</font>", "Retry")
        else:
            self._offer_download("<font color='orange'>Missing</font>", "Download")

    def update_progress(self, percent):
        """Switch the bar to a determinate 0-100 range and display ``percent``."""
        bar = self.progress_bar
        bar.setRange(0, 100)
        bar.setValue(percent)
|
187
|
+
|
188
|
+
|
189
|
+
class DownloadThread(QThread):
    """Background thread that downloads several components concurrently with httpx.

    Signals:
        progress_signal(component, percent): integer download percentage. It is emitted
            only when the server sends Content-Length, and only when the value changes.
        finished_signal(component, error): emitted once per component. ``error`` is ""
            on success and a non-empty message on failure.
    """

    progress_signal = Signal(DownloadableComponent, int)
    finished_signal = Signal(DownloadableComponent, str)  # Component, error_message (empty string for success)

    def __init__(self, components_to_download: list[DownloadableComponent]):
        super().__init__()
        self.components = components_to_download
        # ids of components whose finished_signal was already emitted, so the thread-level
        # error handler never reports a component twice. The ``downloading`` flag is not
        # reliable for this: it is reset by the dialog on the GUI thread.
        self._reported: set[int] = set()
        for c in self.components:
            c.downloading = True

    @staticmethod
    def _error_text(exc: BaseException) -> str:
        """Return a non-empty description of ``exc``; "" would be read as success by listeners."""
        return str(exc) or type(exc).__name__

    def _report(self, component: DownloadableComponent, error: str):
        """Emit ``finished_signal`` for ``component`` and remember that it was reported."""
        self._reported.add(id(component))
        self.finished_signal.emit(component, error)

    def run(self):
        """Thread entry point: run the async download loop to completion."""
        try:
            asyncio.run(self._download_files_async())
        except Exception as e:
            logger.error(f"Downloader thread failed: {e}")
            for component in self.components:  # fail every component not reported yet
                if id(component) not in self._reported:
                    self._report(component, self._error_text(e))

    async def _download_files_async(self):
        """Download all components concurrently, then report each outcome."""
        async with httpx.AsyncClient(timeout=300.0) as client:
            tasks = [self._download_file(client, c) for c in self.components]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for component, result in zip(self.components, results, strict=False):
                # BaseException: a cancelled task (CancelledError) must count as a failure too.
                if isinstance(result, BaseException):
                    self._report(component, self._error_text(result))
                else:
                    self._report(component, "")

    async def _download_file(self, client: httpx.AsyncClient, component: DownloadableComponent):
        """Stream one component to a temporary ``*.download`` file, then install it.

        Plain files are moved onto ``destination_path``. Zip archives have the KataGo
        executable extracted instead (see ``_install_from_zip``). The temporary download
        file is always removed.

        Raises:
            httpx.HTTPError: on network or HTTP status failures.
            zipfile.BadZipFile: if a zip download is corrupt.
            Exception: if the archive contains no KataGo executable.
        """
        download_path = (
            component.destination_path.with_suffix(".zip.download")
            if component.is_zip
            else component.destination_path.with_suffix(".download")
        )
        try:
            async with client.stream("GET", component.download_url, follow_redirects=True) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                downloaded = 0
                last_percent = -1
                with open(download_path, "wb") as f:
                    async for chunk in response.aiter_bytes(chunk_size=8192):
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0:
                            percent = min(100, (downloaded * 100) // total_size)
                            # Signal only on change. Otherwise a large model queues tens of
                            # thousands of cross-thread signals for the GUI thread.
                            if percent != last_percent:
                                last_percent = percent
                                self.progress_signal.emit(component, percent)

            if component.is_zip:
                self._install_from_zip(download_path, component)
            else:
                shutil.move(download_path, component.destination_path)
        finally:
            if download_path.exists():
                download_path.unlink()

    @staticmethod
    def _extract_member(zip_ref: zipfile.ZipFile, member: str, target: Path, mode: int | None = None):
        """Extract archive ``member`` to ``target`` via ``<target>.part`` and an atomic rename.

        An interrupted extraction therefore never leaves a truncated file at ``target``.
        ``mode``, if given, is applied before the rename.
        """
        tmp_target = target.with_name(target.name + ".part")
        try:
            with zip_ref.open(member) as source, open(tmp_target, "wb") as out:
                shutil.copyfileobj(source, out)
            if mode is not None:
                os.chmod(tmp_target, mode)
            os.replace(tmp_target, target)
        finally:
            if tmp_target.exists():
                tmp_target.unlink()

    def _install_from_zip(self, zip_path: Path, component: DownloadableComponent):
        """Install the KataGo executable contained in ``zip_path`` at ``component.destination_path``."""
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            names = zip_ref.namelist()
            # Compare each entry's base name exactly (zip paths always use "/"). This way
            # names that merely end in "katago", and macOS "__MACOSX" metadata, are ignored.
            exe_files = [
                f for f in names if f.rsplit("/", 1)[-1] == KATAGO_EXE_NAME and not f.startswith("__MACOSX")
            ]
            if not exe_files:
                raise Exception(f"Katago executable not found in zip {zip_path}")
            internal_exe_path = exe_files[0]
            exe_folder = internal_exe_path.rpartition("/")[0]

            # NOTE(review): assumes the Windows release archive keeps the DLLs katago.exe needs
            # in the same folder; copying them next to the executable lets Windows load them.
            # Archives without .dll entries are unaffected. Confirm against the release layout.
            for entry in names:
                folder, _, base = entry.rpartition("/")
                is_plain_name = bool(base) and Path(base).name == base  # rejects "..\\x.dll" tricks
                if (
                    folder == exe_folder
                    and is_plain_name
                    and base.lower().endswith(".dll")
                    and not entry.startswith("__MACOSX")
                ):
                    self._extract_member(zip_ref, entry, component.destination_dir / base)

            self._extract_member(zip_ref, internal_exe_path, component.destination_path, mode=0o755)  # make executable
|
256
|
+
|
257
|
+
|
258
|
+
class ComponentsDownloaderDialog(QDialog):
    """Modal setup dialog that checks for the KataGo engine and models and downloads missing ones.

    ``self.components`` has a fixed order that ``get_paths`` relies on: engine, main
    model, human model. ``get_paths()`` returns their locations once all are present.
    """

    # KataGo release that the engine download is pinned to.
    KATAGO_RELEASE = "v1.16.0"

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("SHAPE Setup: KataGo Components")
        self.setModal(True)
        self.setMinimumWidth(600)

        self.components = self._define_components()
        self.component_widgets: dict[str, ComponentWidget] = {}
        # Download threads that have not finished yet. Holding a reference keeps each
        # QThread object alive while it runs. Qt treats destroying a running QThread as a
        # fatal error, which could otherwise happen when a second download rebinds
        # ``self.download_thread`` while the first is still running.
        self._active_threads: list[DownloadThread] = []
        self.setup_ui()
        self.check_all_found()

    def _get_katago_url(self):
        """Return the URL of the KataGo OpenCL release archive for this OS.

        Raises:
            RuntimeError: if the OS is not Linux, Windows or macOS.
        """
        system = platform.system()
        release = self.KATAGO_RELEASE
        base_url = f"https://github.com/lightvector/KataGo/releases/download/{release}/"
        asset_platform = {"Linux": "linux-x64", "Windows": "windows-x64", "Darwin": "macos-x64"}.get(system)
        if asset_platform is None:
            raise RuntimeError(f"Unsupported OS for KataGo download: {system}")
        return f"{base_url}katago-{release}-opencl-{asset_platform}.zip"

    def _define_components(self) -> list[DownloadableComponent]:
        """Build the component list: engine first, then the main and the human model."""
        return [
            DownloadableComponent(
                name="KataGo Engine",
                destination_dir=KATAGO_DIR,
                destination_filename=KATAGO_EXE_NAME,
                download_url=self._get_katago_url(),
                is_zip=True,
            ),
            DownloadableComponent(
                name="KataGo Model (28b)",
                destination_dir=KATRAIN_DIR,
                destination_filename="katago-28b.bin.gz",
                download_url="https://media.katagotraining.org/uploaded/networks/models/kata1/kata1-b28c512nbt-s7709128960-d4462231357.bin.gz",
            ),
            DownloadableComponent(
                name="KataGo Model (Human)",
                destination_dir=KATRAIN_DIR,
                destination_filename="katago-human.bin.gz",
                download_url="https://github.com/lightvector/KataGo/releases/download/v1.15.0/b18c384nbt-humanv0.bin.gz",
            ),
        ]

    def setup_ui(self):
        """Create the title, one row per component, and the action buttons."""
        layout = QVBoxLayout()
        self.title_label = QLabel("Checking for required components...")
        font = QFont()
        font.setPointSize(16)
        font.setBold(True)
        self.title_label.setFont(font)
        layout.addWidget(self.title_label)

        for component in self.components:
            widget = component.get_widget(self.download_one, self)
            self.component_widgets[component.name] = widget
            layout.addWidget(widget)

        self.download_all_button = QPushButton("Download All Missing")
        self.download_all_button.clicked.connect(self.download_all)
        layout.addWidget(self.download_all_button)

        self.close_button = QPushButton("Close")
        self.close_button.clicked.connect(self.accept)
        layout.addWidget(self.close_button)

        self.setLayout(layout)

    def download_one(self, component: DownloadableComponent):
        """Download a single component (wired to each row's button)."""
        self.download([component])

    def download_all(self):
        """Download every component that is neither present nor already downloading."""
        missing = [c for c in self.components if not c.found and not c.downloading]
        self.download(missing)

    def download(self, components: list[DownloadableComponent]):
        """Start a background download of ``components`` and connect its signals to the UI."""
        # Never start a second download of a component: both would write the same files.
        components = [c for c in components if not c.downloading]
        if not components:
            return
        self.download_all_button.setEnabled(False)
        self.close_button.setEnabled(False)

        thread = DownloadThread(components)  # marks the components as downloading
        self.download_thread = thread  # most recently started thread (kept for compatibility)
        self._active_threads.append(thread)
        for c in components:
            self.component_widgets[c.name].update_status()
        thread.progress_signal.connect(self._on_progress)
        thread.finished_signal.connect(self._on_finished)
        thread.finished.connect(lambda t=thread: self._forget_thread(t))
        thread.start()

    def _forget_thread(self, thread: DownloadThread):
        """Release a download thread once Qt reports that it has finished."""
        thread.wait()  # returns promptly: run() has already returned
        if thread in self._active_threads:
            self._active_threads.remove(thread)

    def _on_progress(self, component: DownloadableComponent, percent: int):
        self.component_widgets[component.name].update_progress(percent)

    def _on_finished(self, component: DownloadableComponent, error: str):
        """Record a component's download outcome (``error`` is "" on success) and refresh the UI."""
        component.downloading = False
        component.error = error if error else None
        component.check_if_found()
        # Force progress bar to be hidden and reset
        widget = self.component_widgets[component.name]
        widget.progress_bar.setVisible(False)
        widget.progress_bar.setRange(0, 100)
        widget.progress_bar.setValue(0)
        widget.update_status()
        self.check_all_found()

    def check_all_found(self):
        """Refresh every component and update the title and button states."""
        all_found = all(c.check_if_found() for c in self.components)
        downloading = any(c.downloading for c in self.components)

        if downloading:
            self.title_label.setText("Downloading components...")
        elif all_found:
            self.title_label.setText("All components ready!")
        else:
            missing_count = sum(1 for c in self.components if not c.found)
            self.title_label.setText(f"Missing {missing_count} component{'s' if missing_count != 1 else ''}")

        self.download_all_button.setEnabled(not all_found and not downloading)
        self.close_button.setEnabled(all_found and not downloading)
        self.close_button.setText("Continue" if all_found else "Close")
        if all_found:
            self.download_all_button.setVisible(False)

    def _katago_executable(self) -> Path:
        """KataGo executable to use: the PATH copy if one was found, else the downloaded one."""
        katago_component = self.components[0]  # KataGo Engine is first
        if hasattr(katago_component, "_path_katago"):
            return katago_component._path_katago
        return KATAGO_PATH

    def get_paths(self) -> dict[str, Path] | None:
        """Return the engine and model paths, or None while any component is missing.

        Keys are "katago_path", "model_path" and "human_model_path".
        """
        if not all(c.found for c in self.components):
            return None
        return {
            "katago_path": self._katago_executable(),
            "model_path": self.components[1].destination_path,
            "human_model_path": self.components[2].destination_path,
        }

    def get_katago_version_info(self) -> tuple[str, str]:
        """Get KataGo version and backend info for the main window title.

        Returns ``("Unknown", "Unknown")`` when the engine has not been found.
        """
        if not self.components[0].found:  # KataGo Engine is first
            return "Unknown", "Unknown"
        # Inside a method this name resolves to the module-level function, not this method.
        return get_katago_version_info(self._katago_executable())
|
shape/katago/engine.py
ADDED
@@ -0,0 +1,176 @@
|
|
1
|
+
import copy
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import queue
|
5
|
+
import subprocess
|
6
|
+
import threading
|
7
|
+
import traceback
|
8
|
+
from collections.abc import Callable
|
9
|
+
|
10
|
+
from PySide6.QtWidgets import QApplication, QDialog
|
11
|
+
|
12
|
+
from shape.game_logic import GameNode
|
13
|
+
from shape.katago.downloader import ComponentsDownloaderDialog
|
14
|
+
from shape.utils import setup_logging
|
15
|
+
|
16
|
+
# Module logger obtained from the shared logging helper.
logger = setup_logging()
|
17
|
+
|
18
|
+
|
19
|
+
class KataGoEngine:
    """Client for a KataGo ``analysis`` subprocess that exchanges JSON lines over stdin/stdout.

    ``analyze_position`` queues queries. A writer thread sends them and registers each
    callback under the query ``id``. A reader thread dispatches every response to its
    callback, and a third thread forwards KataGo's stderr to the log.
    """

    RULESETS_ABBR = {
        "jp": "japanese",
        "cn": "chinese",
        "ko": "korean",
        "aga": "aga",
        "tt": "tromp-taylor",
        "nz": "new zealand",
        "stone_scoring": "stone_scoring",
    }

    def __init__(self, model_folder=None):
        """Resolve component paths (running the setup dialog if needed) and start KataGo.

        Args:
            model_folder: Not used by this constructor.

        Raises:
            RuntimeError: if one of these happens:
                - the bundled analysis.cfg is missing;
                - the setup dialog is not accepted;
                - the component paths cannot be resolved;
                - KataGo has already exited right after launch.
        """
        # analysis.cfg is now stored in the package
        config_path = os.path.join(os.path.dirname(__file__), "analysis.cfg")
        if not os.path.exists(config_path):
            raise RuntimeError(f"Analysis config not found at {config_path}")

        # The setup dialog needs a QApplication; create one if the caller has not.
        app = QApplication.instance()
        if app is None:
            app = QApplication([])

        dialog = ComponentsDownloaderDialog()
        paths = dialog.get_paths()
        if not paths:
            result = dialog.exec()
            if result != QDialog.DialogCode.Accepted:
                raise RuntimeError("KataGo components are required but download was cancelled or failed.")
            paths = dialog.get_paths()
            if not paths:
                raise RuntimeError("Could not retrieve component paths even after download dialog.")

        # Store version info for the main window title
        self.katago_version, self.katago_backend = dialog.get_katago_version_info()

        command = [
            os.path.abspath(paths["katago_path"]),
            "analysis",
            "-config",
            config_path,
            "-model",
            str(paths["model_path"]),
            "-human-model",
            str(paths["human_model_path"]),
        ]
        self.query_queue = queue.Queue()
        self.response_callbacks = {}
        self.query_counter = 0  # counter for unique query IDs
        self.process = self._start_process(command)
        if self.process.poll() is not None:
            stderr_output = self.process.stderr.read() if self.process.stderr else "No stderr available"
            raise RuntimeError(f"KataGo process exited unexpectedly on startup: {stderr_output}")

        threads = [
            threading.Thread(target=self._log_stderr, daemon=True),
            threading.Thread(target=self._process_responses, daemon=True),
            threading.Thread(target=self._process_query_queue, daemon=True),
        ]
        for thread in threads:
            thread.start()

    def _start_process(self, command):
        """Launch ``command`` with line-buffered text pipes for stdin, stdout and stderr."""
        try:
            return subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
                universal_newlines=True,
            )
        except Exception as e:
            logger.error(f"Failed to start KataGo process: {e}")
            raise

    def close(self):
        """Stop the writer thread and shut down the KataGo process.

        KataGo is asked to terminate. It is killed if it has not exited within 5 seconds,
        so shutdown cannot hang.
        """
        self.query_queue.put((None, None))
        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                logger.warning("KataGo did not terminate in time; killing it")
                self.process.kill()
                self.process.wait()

    def _process_responses(self):
        """Reader thread: parse KataGo's stdout and dispatch each response to its callback.

        A response that only carries a ``warning`` is logged and leaves the callback
        registered. Per KataGo's analysis-engine documentation, a warning does not end
        the query; its actual result follows.
        """
        if not self.process.stdout:
            return
        for line in self.process.stdout:
            try:
                response = json.loads(line)
            except json.JSONDecodeError:
                logger.error(f"Failed to parse KataGo response: {line.strip()}")
                continue
            if not isinstance(response, dict):
                # Valid JSON but not a response object; don't let it kill this thread.
                logger.error(f"Failed to parse KataGo response: {line.strip()}")
                continue

            query_id = response.get("id")
            if "warning" in response and "error" not in response and "moveInfos" not in response:
                logger.warning(f"KataGo warning for query_id {query_id}: {response.get('warning')}")
                continue

            if query_id and query_id in self.response_callbacks:
                callback = self.response_callbacks.pop(query_id)
                self._log_response(response)
                logger.debug(f"Calling callback for query_id: {query_id}")
                try:
                    callback(response)
                except Exception as e:
                    logger.error(f"Error calling callback for query_id: {query_id}, error: {e}")
                    traceback.print_exc()
                logger.debug(f"Callback called for query_id: {query_id}")
            else:
                logger.error(f"Received response with unknown id: {query_id}")

    def analyze_position(self, node: GameNode, callback: Callable, human_profile_settings: dict, max_visits: int = 100):
        """Queue an analysis query for the position at ``node``.

        All moves from the root to ``node`` are sent as ``moves``, and all placements
        as ``initialStones``.

        Args:
            node: Position to analyze.
            callback: Receives KataGo's response dict, or ``{"error": ...}`` if the query
                could not be sent. It is invoked on a background thread.
            human_profile_settings: Sent as ``overrideSettings``. Its ``humanSLProfile``
                entry (default "ai") is embedded in the query id.
            max_visits: Visit limit for the search.
        """
        path_nodes = node.nodes_from_root
        moves = [m for n in path_nodes for m in n.moves]
        self.query_counter += 1
        last_move = (moves or ["root"])[-1]
        profile = human_profile_settings.get("humanSLProfile", "ai")
        query_id = f"{len(path_nodes)}_{last_move}_{profile}_{max_visits}v_{self.query_counter}"
        query = {
            "id": query_id,
            "rules": self.RULESETS_ABBR.get(node.ruleset.lower(), node.ruleset.lower()),
            "boardXSize": node.board_size[0],
            "boardYSize": node.board_size[1],
            "moves": [[m.player, m.gtp()] for m in moves],
            "includePolicy": True,
            "initialStones": [[m.player, m.gtp()] for n in path_nodes for m in n.placements],
            "includeOwnership": False,
            "maxVisits": max_visits,
            "overrideSettings": human_profile_settings,
        }
        self.query_queue.put((query, callback))

    def _process_query_queue(self):
        """Writer thread: send queued queries to KataGo until the ``(None, None)`` sentinel."""
        while True:
            query, callback = self.query_queue.get()
            if query is None:
                self.query_queue.task_done()
                break
            query_id = query["id"]
            try:
                if self.process.stdin:
                    # Register before writing: KataGo may answer (e.g. with an error) before
                    # this thread runs again, and a response for an unknown id is dropped.
                    self.response_callbacks[query_id] = callback
                    logger.debug(f"Sending query: {json.dumps(query, indent=2)}")
                    self.process.stdin.write(json.dumps(query) + "\n")
                    self.process.stdin.flush()
                    logger.debug(f"Sent query id {query_id}")
            except Exception as e:
                logger.error(f"Error sending query: {e}")
                self.response_callbacks.pop(query_id, None)
                callback({"error": str(e)})
            self.query_queue.task_done()

    def num_outstanding_queries(self):
        """Number of sent queries whose response has not been dispatched yet."""
        return len(self.response_callbacks)

    def _log_stderr(self):
        """Stderr thread: forward each KataGo stderr line to the log at INFO level."""
        if self.process.stderr:
            for line in self.process.stderr:
                logger.info(f"[KataGo] {line.strip()}")

    def _log_response(self, response):
        """Debug-log ``response`` with policy arrays and long move lists abbreviated."""
        # A shallow copy suffices: only top-level keys are replaced below.
        response = copy.copy(response)
        for k in ["policy", "humanPolicy"]:
            if k in response:
                response[k] = f"[{len(response[k])} floats]"
        moves = [
            {k: v for k, v in move.items() if k in ["move", "visits", "winrate"]}
            for move in response.get("moveInfos", [])
        ]
        response["moveInfos"] = moves[:5] + [f"{len(moves) - 5} more..."] if len(moves) > 5 else moves
        logger.debug(f"Received response: {json.dumps(response, indent=2)}")
|
shape/main.py
ADDED
@@ -0,0 +1,46 @@
|
|
1
|
+
import argparse
|
2
|
+
import signal
|
3
|
+
import sys
|
4
|
+
|
5
|
+
from PySide6.QtWidgets import QApplication
|
6
|
+
|
7
|
+
from shape.katago.engine import KataGoEngine
|
8
|
+
from shape.ui.main_window import MainWindow
|
9
|
+
from shape.utils import setup_logging
|
10
|
+
|
11
|
+
# Module logger obtained from the shared logging helper.
logger = setup_logging()

# Reinstate the OS default SIGINT action so Ctrl+C terminates the process outright
# instead of raising KeyboardInterrupt inside Python.
signal.signal(signal.SIGINT, signal.SIG_DFL)  # hard exit on SIGINT
|
14
|
+
|
15
|
+
|
16
|
+
class SHAPEApp:
    """Application bootstrap: creates the Qt application, the main window and the KataGo engine."""

    def __init__(self):
        """Build the UI and start KataGo; exits the process with status 1 if the engine fails."""
        self.app = QApplication(sys.argv)
        self.main_window = MainWindow()

        try:
            self.katago = KataGoEngine()
        except Exception as e:
            self.show_error(f"Failed to initialize KataGo engine: {e}")
            sys.exit(1)

        self.main_window.set_engine(self.katago)

    def run(self):
        """Show the main window and run the Qt event loop, returning its exit code.

        The KataGo subprocess is shut down when the event loop ends, even on error.
        """
        self.main_window.show()
        try:
            return self.app.exec()
        finally:
            self.katago.close()

    def show_error(self, message):
        """Log ``message`` and also show it in a modal error box when possible.

        Logging alone is not enough: users who launch the GUI may never see the log.
        """
        logger.error(message)
        try:
            from PySide6.QtWidgets import QMessageBox

            QMessageBox.critical(None, "SHAPE", message)
        except Exception as e:  # error reporting must never mask the original failure
            logger.error(f"Could not display error dialog: {e}")
|
35
|
+
|
36
|
+
|
37
|
+
def main():
    """Console entry point: parse the command line, then run the SHAPE GUI until it exits."""
    arg_parser = argparse.ArgumentParser(description="SHAPE: Shape Habits Analysis and Personalized Evaluation")
    arg_parser.parse_args()  # no options yet, but still provides --help and rejects unknown args

    application = SHAPEApp()
    exit_code = application.run()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
|