fow-rl 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,403 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "addr2line"
7
+ version = "0.25.1"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b"
10
+ dependencies = [
11
+ "gimli",
12
+ ]
13
+
14
+ [[package]]
15
+ name = "adler2"
16
+ version = "2.0.1"
17
+ source = "registry+https://github.com/rust-lang/crates.io-index"
18
+ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
19
+
20
+ [[package]]
21
+ name = "arrayvec"
22
+ version = "0.5.2"
23
+ source = "registry+https://github.com/rust-lang/crates.io-index"
24
+ checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
25
+
26
+ [[package]]
27
+ name = "autocfg"
28
+ version = "1.5.0"
29
+ source = "registry+https://github.com/rust-lang/crates.io-index"
30
+ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
31
+
32
+ [[package]]
33
+ name = "backtrace"
34
+ version = "0.3.76"
35
+ source = "registry+https://github.com/rust-lang/crates.io-index"
36
+ checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6"
37
+ dependencies = [
38
+ "addr2line",
39
+ "cfg-if",
40
+ "libc",
41
+ "miniz_oxide",
42
+ "object",
43
+ "rustc-demangle",
44
+ "windows-link",
45
+ ]
46
+
47
+ [[package]]
48
+ name = "cfg-if"
49
+ version = "1.0.4"
50
+ source = "registry+https://github.com/rust-lang/crates.io-index"
51
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
52
+
53
+ [[package]]
54
+ name = "chess"
55
+ version = "3.2.0"
56
+ source = "registry+https://github.com/rust-lang/crates.io-index"
57
+ checksum = "2ed299b171ec34f372945ad6726f7bc1d2afd5f59fb8380f64f48e2bab2f0ec8"
58
+ dependencies = [
59
+ "arrayvec",
60
+ "failure",
61
+ "nodrop",
62
+ "rand",
63
+ ]
64
+
65
+ [[package]]
66
+ name = "failure"
67
+ version = "0.1.8"
68
+ source = "registry+https://github.com/rust-lang/crates.io-index"
69
+ checksum = "d32e9bd16cc02eae7db7ef620b392808b89f6a5e16bb3497d159c6b92a0f4f86"
70
+ dependencies = [
71
+ "backtrace",
72
+ "failure_derive",
73
+ ]
74
+
75
+ [[package]]
76
+ name = "failure_derive"
77
+ version = "0.1.8"
78
+ source = "registry+https://github.com/rust-lang/crates.io-index"
79
+ checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4"
80
+ dependencies = [
81
+ "proc-macro2",
82
+ "quote",
83
+ "syn 1.0.109",
84
+ "synstructure",
85
+ ]
86
+
87
+ [[package]]
88
+ name = "fow_rl"
89
+ version = "0.1.0"
90
+ dependencies = [
91
+ "chess",
92
+ "pyo3",
93
+ ]
94
+
95
+ [[package]]
96
+ name = "gimli"
97
+ version = "0.32.3"
98
+ source = "registry+https://github.com/rust-lang/crates.io-index"
99
+ checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7"
100
+
101
+ [[package]]
102
+ name = "heck"
103
+ version = "0.5.0"
104
+ source = "registry+https://github.com/rust-lang/crates.io-index"
105
+ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
106
+
107
+ [[package]]
108
+ name = "indoc"
109
+ version = "2.0.7"
110
+ source = "registry+https://github.com/rust-lang/crates.io-index"
111
+ checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706"
112
+ dependencies = [
113
+ "rustversion",
114
+ ]
115
+
116
+ [[package]]
117
+ name = "libc"
118
+ version = "0.2.186"
119
+ source = "registry+https://github.com/rust-lang/crates.io-index"
120
+ checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66"
121
+
122
+ [[package]]
123
+ name = "memchr"
124
+ version = "2.8.0"
125
+ source = "registry+https://github.com/rust-lang/crates.io-index"
126
+ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
127
+
128
+ [[package]]
129
+ name = "memoffset"
130
+ version = "0.9.1"
131
+ source = "registry+https://github.com/rust-lang/crates.io-index"
132
+ checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
133
+ dependencies = [
134
+ "autocfg",
135
+ ]
136
+
137
+ [[package]]
138
+ name = "miniz_oxide"
139
+ version = "0.8.9"
140
+ source = "registry+https://github.com/rust-lang/crates.io-index"
141
+ checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
142
+ dependencies = [
143
+ "adler2",
144
+ ]
145
+
146
+ [[package]]
147
+ name = "nodrop"
148
+ version = "0.1.14"
149
+ source = "registry+https://github.com/rust-lang/crates.io-index"
150
+ checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
151
+
152
+ [[package]]
153
+ name = "object"
154
+ version = "0.37.3"
155
+ source = "registry+https://github.com/rust-lang/crates.io-index"
156
+ checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe"
157
+ dependencies = [
158
+ "memchr",
159
+ ]
160
+
161
+ [[package]]
162
+ name = "once_cell"
163
+ version = "1.21.4"
164
+ source = "registry+https://github.com/rust-lang/crates.io-index"
165
+ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
166
+
167
+ [[package]]
168
+ name = "portable-atomic"
169
+ version = "1.13.1"
170
+ source = "registry+https://github.com/rust-lang/crates.io-index"
171
+ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
172
+
173
+ [[package]]
174
+ name = "ppv-lite86"
175
+ version = "0.2.21"
176
+ source = "registry+https://github.com/rust-lang/crates.io-index"
177
+ checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
178
+ dependencies = [
179
+ "zerocopy",
180
+ ]
181
+
182
+ [[package]]
183
+ name = "proc-macro2"
184
+ version = "1.0.106"
185
+ source = "registry+https://github.com/rust-lang/crates.io-index"
186
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
187
+ dependencies = [
188
+ "unicode-ident",
189
+ ]
190
+
191
+ [[package]]
192
+ name = "pyo3"
193
+ version = "0.23.5"
194
+ source = "registry+https://github.com/rust-lang/crates.io-index"
195
+ checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
196
+ dependencies = [
197
+ "cfg-if",
198
+ "indoc",
199
+ "libc",
200
+ "memoffset",
201
+ "once_cell",
202
+ "portable-atomic",
203
+ "pyo3-build-config",
204
+ "pyo3-ffi",
205
+ "pyo3-macros",
206
+ "unindent",
207
+ ]
208
+
209
+ [[package]]
210
+ name = "pyo3-build-config"
211
+ version = "0.23.5"
212
+ source = "registry+https://github.com/rust-lang/crates.io-index"
213
+ checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
214
+ dependencies = [
215
+ "once_cell",
216
+ "target-lexicon",
217
+ ]
218
+
219
+ [[package]]
220
+ name = "pyo3-ffi"
221
+ version = "0.23.5"
222
+ source = "registry+https://github.com/rust-lang/crates.io-index"
223
+ checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
224
+ dependencies = [
225
+ "libc",
226
+ "pyo3-build-config",
227
+ ]
228
+
229
+ [[package]]
230
+ name = "pyo3-macros"
231
+ version = "0.23.5"
232
+ source = "registry+https://github.com/rust-lang/crates.io-index"
233
+ checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
234
+ dependencies = [
235
+ "proc-macro2",
236
+ "pyo3-macros-backend",
237
+ "quote",
238
+ "syn 2.0.117",
239
+ ]
240
+
241
+ [[package]]
242
+ name = "pyo3-macros-backend"
243
+ version = "0.23.5"
244
+ source = "registry+https://github.com/rust-lang/crates.io-index"
245
+ checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
246
+ dependencies = [
247
+ "heck",
248
+ "proc-macro2",
249
+ "pyo3-build-config",
250
+ "quote",
251
+ "syn 2.0.117",
252
+ ]
253
+
254
+ [[package]]
255
+ name = "quote"
256
+ version = "1.0.45"
257
+ source = "registry+https://github.com/rust-lang/crates.io-index"
258
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
259
+ dependencies = [
260
+ "proc-macro2",
261
+ ]
262
+
263
+ [[package]]
264
+ name = "rand"
265
+ version = "0.7.3"
266
+ source = "registry+https://github.com/rust-lang/crates.io-index"
267
+ checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
268
+ dependencies = [
269
+ "rand_chacha",
270
+ "rand_core",
271
+ "rand_hc",
272
+ "rand_pcg",
273
+ ]
274
+
275
+ [[package]]
276
+ name = "rand_chacha"
277
+ version = "0.2.2"
278
+ source = "registry+https://github.com/rust-lang/crates.io-index"
279
+ checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
280
+ dependencies = [
281
+ "ppv-lite86",
282
+ "rand_core",
283
+ ]
284
+
285
+ [[package]]
286
+ name = "rand_core"
287
+ version = "0.5.1"
288
+ source = "registry+https://github.com/rust-lang/crates.io-index"
289
+ checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
290
+
291
+ [[package]]
292
+ name = "rand_hc"
293
+ version = "0.2.0"
294
+ source = "registry+https://github.com/rust-lang/crates.io-index"
295
+ checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
296
+ dependencies = [
297
+ "rand_core",
298
+ ]
299
+
300
+ [[package]]
301
+ name = "rand_pcg"
302
+ version = "0.2.1"
303
+ source = "registry+https://github.com/rust-lang/crates.io-index"
304
+ checksum = "16abd0c1b639e9eb4d7c50c0b8100b0d0f849be2349829c740fe8e6eb4816429"
305
+ dependencies = [
306
+ "rand_core",
307
+ ]
308
+
309
+ [[package]]
310
+ name = "rustc-demangle"
311
+ version = "0.1.27"
312
+ source = "registry+https://github.com/rust-lang/crates.io-index"
313
+ checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d"
314
+
315
+ [[package]]
316
+ name = "rustversion"
317
+ version = "1.0.22"
318
+ source = "registry+https://github.com/rust-lang/crates.io-index"
319
+ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
320
+
321
+ [[package]]
322
+ name = "syn"
323
+ version = "1.0.109"
324
+ source = "registry+https://github.com/rust-lang/crates.io-index"
325
+ checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
326
+ dependencies = [
327
+ "proc-macro2",
328
+ "quote",
329
+ "unicode-ident",
330
+ ]
331
+
332
+ [[package]]
333
+ name = "syn"
334
+ version = "2.0.117"
335
+ source = "registry+https://github.com/rust-lang/crates.io-index"
336
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
337
+ dependencies = [
338
+ "proc-macro2",
339
+ "quote",
340
+ "unicode-ident",
341
+ ]
342
+
343
+ [[package]]
344
+ name = "synstructure"
345
+ version = "0.12.6"
346
+ source = "registry+https://github.com/rust-lang/crates.io-index"
347
+ checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
348
+ dependencies = [
349
+ "proc-macro2",
350
+ "quote",
351
+ "syn 1.0.109",
352
+ "unicode-xid",
353
+ ]
354
+
355
+ [[package]]
356
+ name = "target-lexicon"
357
+ version = "0.12.16"
358
+ source = "registry+https://github.com/rust-lang/crates.io-index"
359
+ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
360
+
361
+ [[package]]
362
+ name = "unicode-ident"
363
+ version = "1.0.24"
364
+ source = "registry+https://github.com/rust-lang/crates.io-index"
365
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
366
+
367
+ [[package]]
368
+ name = "unicode-xid"
369
+ version = "0.2.6"
370
+ source = "registry+https://github.com/rust-lang/crates.io-index"
371
+ checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
372
+
373
+ [[package]]
374
+ name = "unindent"
375
+ version = "0.2.4"
376
+ source = "registry+https://github.com/rust-lang/crates.io-index"
377
+ checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3"
378
+
379
+ [[package]]
380
+ name = "windows-link"
381
+ version = "0.2.1"
382
+ source = "registry+https://github.com/rust-lang/crates.io-index"
383
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
384
+
385
+ [[package]]
386
+ name = "zerocopy"
387
+ version = "0.8.48"
388
+ source = "registry+https://github.com/rust-lang/crates.io-index"
389
+ checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
390
+ dependencies = [
391
+ "zerocopy-derive",
392
+ ]
393
+
394
+ [[package]]
395
+ name = "zerocopy-derive"
396
+ version = "0.8.48"
397
+ source = "registry+https://github.com/rust-lang/crates.io-index"
398
+ checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
399
+ dependencies = [
400
+ "proc-macro2",
401
+ "quote",
402
+ "syn 2.0.117",
403
+ ]
@@ -0,0 +1,17 @@
1
+ [package]
2
+ name = "fow_rl"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+
6
+ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7
+ [lib]
8
+ name = "fow_rl"
9
+ crate-type = ["rlib"]
10
+
11
+ [dependencies]
12
+ pyo3 = "0.23.3"
13
+ chess = "3.2"
14
+
15
+ [features]
16
+ extension-module = ["pyo3/extension-module"]
17
+ fuzzing = []
fow_rl-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.4
2
+ Name: fow_rl
3
+ Version: 0.1.0
4
+ Classifier: Programming Language :: Rust
5
+ Classifier: Programming Language :: Python :: Implementation :: CPython
6
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
7
+ Requires-Python: >=3.8
@@ -0,0 +1,15 @@
1
+ [build-system]
2
+ requires = ["maturin>=1.8,<2.0"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "fow_rl"
7
+ requires-python = ">=3.8"
8
+ classifiers = [
9
+ "Programming Language :: Rust",
10
+ "Programming Language :: Python :: Implementation :: CPython",
11
+ "Programming Language :: Python :: Implementation :: PyPy",
12
+ ]
13
+ dynamic = ["version"]
14
+ [tool.maturin]
15
+ features = ["extension-module"]
@@ -0,0 +1,903 @@
1
+ use std::collections::{HashSet, VecDeque};
2
+ use std::str::FromStr;
3
+
4
+ use chess::{
5
+ get_bishop_moves, get_king_moves, get_knight_moves, get_pawn_attacks, get_rook_moves, Board,
6
+ BitBoard, ChessMove, Color, File, MoveGen, Piece, Rank, Square, EMPTY,
7
+ };
8
+ use pyo3::exceptions::PyValueError;
9
+ use pyo3::prelude::*;
10
+ use pyo3::types::PyDict;
11
+
12
+ const HISTORY_LEN: usize = 8;
13
+ const SLICE_CHANNELS: usize = 13;
14
+ const CASTLING_CHANNELS: usize = 2;
15
+ const BOARD_SQUARES: usize = 64;
16
+ const PROMOTION_BASE: usize = 4096;
17
+
18
+ const FILES: [File; 8] = [
19
+ File::A,
20
+ File::B,
21
+ File::C,
22
+ File::D,
23
+ File::E,
24
+ File::F,
25
+ File::G,
26
+ File::H,
27
+ ];
28
+
29
+ const RANKS: [Rank; 8] = [
30
+ Rank::First,
31
+ Rank::Second,
32
+ Rank::Third,
33
+ Rank::Fourth,
34
+ Rank::Fifth,
35
+ Rank::Sixth,
36
+ Rank::Seventh,
37
+ Rank::Eighth,
38
+ ];
39
+
40
+ fn square_from_index(index: usize) -> Square {
41
+ let file = FILES[index % 8];
42
+ let rank = RANKS[index / 8];
43
+ Square::make_square(rank, file)
44
+ }
45
+
46
+ fn square_index(square: Square) -> usize {
47
+ square.to_index() as usize
48
+ }
49
+
50
/// Flips a 0..64 board index vertically (rank r <-> rank 7 - r) while keeping the file.
///
/// The rank lives in bits 3..=5 of the index, so toggling those bits is the same as
/// computing `(7 - rank) * 8 + file`.
fn mirror_index(index: usize) -> usize {
    const RANK_BITS: usize = 0b111_000;
    index ^ RANK_BITS
}
53
+
54
+ fn mirror_square(square: Square) -> Square {
55
+ square_from_index(mirror_index(square_index(square)))
56
+ }
57
+
58
+ fn piece_base_index(piece: Piece) -> usize {
59
+ match piece {
60
+ Piece::Pawn => 0,
61
+ Piece::Knight => 1,
62
+ Piece::Bishop => 2,
63
+ Piece::Rook => 3,
64
+ Piece::Queen => 4,
65
+ Piece::King => 5,
66
+ }
67
+ }
68
+
69
+ fn piece_value(piece: Piece) -> i8 {
70
+ match piece {
71
+ Piece::Pawn => 1,
72
+ Piece::Knight => 2,
73
+ Piece::Bishop => 3,
74
+ Piece::Rook => 4,
75
+ Piece::Queen => 5,
76
+ Piece::King => 6,
77
+ }
78
+ }
79
+
80
+ fn promotion_index(piece: Piece) -> Option<usize> {
81
+ match piece {
82
+ Piece::Queen => Some(0),
83
+ Piece::Rook => Some(1),
84
+ Piece::Bishop => Some(2),
85
+ Piece::Knight => Some(3),
86
+ _ => None,
87
+ }
88
+ }
89
+
90
/// Splits `data` into consecutive rows of `cols` elements.
///
/// The final row is shorter when `data.len()` is not a multiple of `cols`.
/// Empty input yields no rows.
///
/// Panics if `cols` is zero.
fn reshape_rows(data: Vec<f32>, cols: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::new();
    for row in data.chunks(cols) {
        rows.push(row.to_vec());
    }
    rows
}
93
+
94
+ fn fen_from_py(board_obj: &Bound<'_, PyAny>) -> PyResult<String> {
95
+ if let Ok(fen) = board_obj.extract::<String>() {
96
+ return Ok(fen);
97
+ }
98
+
99
+ if let Ok(fen_attr) = board_obj.getattr("fen") {
100
+ if fen_attr.is_callable() {
101
+ return fen_attr.call0()?.extract::<String>();
102
+ }
103
+ if let Ok(fen) = fen_attr.extract::<String>() {
104
+ return Ok(fen);
105
+ }
106
+ }
107
+
108
+ Err(PyValueError::new_err(
109
+ "board must be a python-chess Board or FEN string",
110
+ ))
111
+ }
112
+
113
+ fn flatten_matrix(matrix: Vec<Vec<f32>>) -> PyResult<Vec<f32>> {
114
+ if matrix.len() != BOARD_SQUARES {
115
+ return Err(PyValueError::new_err(format!(
116
+ "history_tensors entries must have {} rows",
117
+ BOARD_SQUARES
118
+ )));
119
+ }
120
+
121
+ let mut flat = Vec::with_capacity(BOARD_SQUARES * SLICE_CHANNELS);
122
+ for row in matrix {
123
+ if row.len() != SLICE_CHANNELS {
124
+ return Err(PyValueError::new_err(format!(
125
+ "history_tensors entries must have {} columns",
126
+ SLICE_CHANNELS
127
+ )));
128
+ }
129
+ flat.extend_from_slice(&row);
130
+ }
131
+
132
+ Ok(flat)
133
+ }
134
+
135
+ fn validate_flat(flat: Vec<f32>) -> PyResult<Vec<f32>> {
136
+ if flat.len() != BOARD_SQUARES * SLICE_CHANNELS {
137
+ return Err(PyValueError::new_err(format!(
138
+ "history_tensors entries must have length {}",
139
+ BOARD_SQUARES * SLICE_CHANNELS
140
+ )));
141
+ }
142
+ Ok(flat)
143
+ }
144
+
145
+ fn tensor_to_flat(tensor_obj: &Bound<'_, PyAny>) -> PyResult<Vec<f32>> {
146
+ if let Ok(matrix) = tensor_obj.extract::<Vec<Vec<f32>>>() {
147
+ return flatten_matrix(matrix);
148
+ }
149
+ if let Ok(flat) = tensor_obj.extract::<Vec<f32>>() {
150
+ return validate_flat(flat);
151
+ }
152
+
153
+ if let Ok(tolist) = tensor_obj.getattr("tolist") {
154
+ if tolist.is_callable() {
155
+ let list_obj = tolist.call0()?;
156
+ if let Ok(matrix) = list_obj.extract::<Vec<Vec<f32>>>() {
157
+ return flatten_matrix(matrix);
158
+ }
159
+ if let Ok(flat) = list_obj.extract::<Vec<f32>>() {
160
+ return validate_flat(flat);
161
+ }
162
+ }
163
+ }
164
+
165
+ Err(PyValueError::new_err(
166
+ "history_tensors entries must be 64x13 tensors or lists",
167
+ ))
168
+ }
169
+
170
+ fn history_from_py(history_obj: &Bound<'_, PyAny>) -> PyResult<VecDeque<Vec<f32>>> {
171
+ let mut deque = VecDeque::new();
172
+
173
+ for item in history_obj.try_iter()? {
174
+ let item = item?;
175
+ let flat = tensor_to_flat(&item)?;
176
+ deque.push_back(flat);
177
+ }
178
+
179
+ while deque.len() > HISTORY_LEN {
180
+ deque.pop_front();
181
+ }
182
+
183
+ Ok(deque)
184
+ }
185
+
186
+ fn uci_from_py(action_obj: &Bound<'_, PyAny>) -> PyResult<String> {
187
+ if let Ok(uci) = action_obj.extract::<String>() {
188
+ return Ok(uci);
189
+ }
190
+
191
+ if let Ok(uci_attr) = action_obj.getattr("uci") {
192
+ if uci_attr.is_callable() {
193
+ return uci_attr.call0()?.extract::<String>();
194
+ }
195
+ if let Ok(uci) = uci_attr.extract::<String>() {
196
+ return Ok(uci);
197
+ }
198
+ }
199
+
200
+ Err(PyValueError::new_err(
201
+ "action must be a python-chess Move or UCI string",
202
+ ))
203
+ }
204
+
205
+ #[pyclass]
206
+ #[derive(Clone)]
207
+ pub struct SelfPlayGame {
208
+ board: Board,
209
+ history_tensors: VecDeque<Vec<f32>>,
210
+ }
211
+
212
+ impl SelfPlayGame {
213
+ fn new_internal(board: Board, history_tensors: Option<VecDeque<Vec<f32>>>) -> Self {
214
+ let mut game = SelfPlayGame {
215
+ board,
216
+ history_tensors: history_tensors.unwrap_or_default(),
217
+ };
218
+
219
+ if game.history_tensors.is_empty() {
220
+ game.save_history_state();
221
+ }
222
+
223
+ game
224
+ }
225
+
226
+ fn side_to_move(&self) -> Color {
227
+ self.board.side_to_move()
228
+ }
229
+
230
+ fn push_history(&mut self, slice: Vec<f32>) {
231
+ if self.history_tensors.len() == HISTORY_LEN {
232
+ self.history_tensors.pop_front();
233
+ }
234
+ self.history_tensors.push_back(slice);
235
+ }
236
+
237
+ fn save_history_state(&mut self) {
238
+ let current_player = self.side_to_move();
239
+ let visible_squares = self.compute_visibility(current_player);
240
+
241
+ let mut slice_tensor = vec![0.0; BOARD_SQUARES * SLICE_CHANNELS];
242
+
243
+ for index in 0..BOARD_SQUARES {
244
+ let view_index = if current_player == Color::White {
245
+ index
246
+ } else {
247
+ mirror_index(index)
248
+ };
249
+
250
+ if !visible_squares.contains(&index) {
251
+ slice_tensor[view_index * SLICE_CHANNELS + 12] = 1.0;
252
+ continue;
253
+ }
254
+
255
+ let square = square_from_index(index);
256
+ if let Some(piece) = self.board.piece_on(square) {
257
+ let is_own_piece = self.board.color_on(square) == Some(current_player);
258
+ let mut piece_idx = piece_base_index(piece);
259
+ if !is_own_piece {
260
+ piece_idx += 6;
261
+ }
262
+ slice_tensor[view_index * SLICE_CHANNELS + piece_idx] = 1.0;
263
+ }
264
+ }
265
+
266
+ self.push_history(slice_tensor);
267
+ }
268
+
269
+ fn board_matrix(&self) -> Vec<Vec<i8>> {
270
+ let mut mat = Vec::with_capacity(8);
271
+ for rank in (0..8).rev() {
272
+ let mut row = Vec::with_capacity(8);
273
+ for file in 0..8 {
274
+ let square = Square::make_square(RANKS[rank], FILES[file]);
275
+ if let Some(piece) = self.board.piece_on(square) {
276
+ let mut val = piece_value(piece);
277
+ if self.board.color_on(square) == Some(Color::Black) {
278
+ val = -val;
279
+ }
280
+ row.push(val);
281
+ } else {
282
+ row.push(0);
283
+ }
284
+ }
285
+ mat.push(row);
286
+ }
287
+ mat
288
+ }
289
+
290
+ fn castle_eligible_internal(&self) -> (bool, bool) {
291
+ let rights = self.board.castle_rights(self.side_to_move());
292
+ (rights.has_kingside(), rights.has_queenside())
293
+ }
294
+
295
+ fn legal_actions_internal(&self) -> Vec<ChessMove> {
296
+ MoveGen::new_legal(&self.board).collect()
297
+ }
298
+
299
+ fn apply_internal(&mut self, action: ChessMove) {
300
+ self.board = self.board.make_move_new(action);
301
+ self.save_history_state();
302
+ }
303
+
304
+ fn is_terminal_internal(&self) -> bool {
305
+ !self.king_present(Color::White) || !self.king_present(Color::Black)
306
+ }
307
+
308
+ fn terminal_value_internal(&self, player: Color) -> f32 {
309
+ if !self.king_present(Color::White) {
310
+ return if player == Color::Black { 1.0 } else { -1.0 };
311
+ }
312
+ if !self.king_present(Color::Black) {
313
+ return if player == Color::White { 1.0 } else { -1.0 };
314
+ }
315
+ 0.0
316
+ }
317
+
318
+ fn king_present(&self, color: Color) -> bool {
319
+ let kings = *self.board.pieces(Piece::King) & *self.board.color_combined(color);
320
+ kings != EMPTY
321
+ }
322
+
323
+ fn piece_attacks(&self, square: Square, piece: Piece, color: Color) -> BitBoard {
324
+ let blockers = *self.board.combined();
325
+ match piece {
326
+ Piece::Pawn => get_pawn_attacks(square, color, !EMPTY),
327
+ Piece::Knight => get_knight_moves(square),
328
+ Piece::Bishop => get_bishop_moves(square, blockers),
329
+ Piece::Rook => get_rook_moves(square, blockers),
330
+ Piece::Queen => get_bishop_moves(square, blockers) | get_rook_moves(square, blockers),
331
+ Piece::King => get_king_moves(square),
332
+ }
333
+ }
334
+
335
+ fn compute_visibility(&self, player: Color) -> HashSet<usize> {
336
+ let mut visible = HashSet::new();
337
+
338
+ for index in 0..BOARD_SQUARES {
339
+ let square = square_from_index(index);
340
+ let piece = self.board.piece_on(square);
341
+ let color = self.board.color_on(square);
342
+
343
+ if piece.is_some() && color == Some(player) {
344
+ visible.insert(index);
345
+ let attacks = self.piece_attacks(square, piece.unwrap(), player);
346
+ for target in attacks {
347
+ visible.insert(square_index(target));
348
+ }
349
+
350
+ if piece == Some(Piece::Pawn) {
351
+ let forward_offset: i32 = if player == Color::White { 8 } else { -8 };
352
+ let fwd_index = index as i32 + forward_offset;
353
+ if (0..BOARD_SQUARES as i32).contains(&fwd_index) {
354
+ visible.insert(fwd_index as usize);
355
+
356
+ let rank = index / 8;
357
+ let start_rank = if player == Color::White { 1 } else { 6 };
358
+ if rank == start_rank {
359
+ let fwd_square = square_from_index(fwd_index as usize);
360
+ if self.board.piece_on(fwd_square).is_none() {
361
+ let dbl_index = index as i32 + 2 * forward_offset;
362
+ if (0..BOARD_SQUARES as i32).contains(&dbl_index) {
363
+ visible.insert(dbl_index as usize);
364
+ }
365
+ }
366
+ }
367
+ }
368
+ }
369
+ }
370
+ }
371
+
372
+ visible
373
+ }
374
+
375
+ fn encode_flat(&self) -> Vec<f32> {
376
+ let mut history_list: Vec<Vec<f32>> = if self.history_tensors.is_empty() {
377
+ vec![vec![0.0; BOARD_SQUARES * SLICE_CHANNELS]]
378
+ } else {
379
+ self.history_tensors.iter().cloned().collect()
380
+ };
381
+
382
+ while history_list.len() < HISTORY_LEN {
383
+ let first = history_list[0].clone();
384
+ history_list.insert(0, first);
385
+ }
386
+
387
+ let history_channels = SLICE_CHANNELS * HISTORY_LEN;
388
+ let mut tensor = vec![0.0; BOARD_SQUARES * history_channels];
389
+
390
+ for square in 0..BOARD_SQUARES {
391
+ for (idx, slice) in history_list.iter().enumerate() {
392
+ let dst_base = square * history_channels + idx * SLICE_CHANNELS;
393
+ let src_base = square * SLICE_CHANNELS;
394
+ tensor[dst_base..dst_base + SLICE_CHANNELS]
395
+ .copy_from_slice(&slice[src_base..src_base + SLICE_CHANNELS]);
396
+ }
397
+ }
398
+
399
+ let mut castling_channel = vec![0.0; BOARD_SQUARES * CASTLING_CHANNELS];
400
+ let (king_side, queen_side) = self.castle_eligible_internal();
401
+ let king_val = if king_side { 1.0 } else { 0.0 };
402
+ let queen_val = if queen_side { 1.0 } else { 0.0 };
403
+
404
+ if BOARD_SQUARES >= 2 {
405
+ castling_channel[0] += king_val;
406
+ castling_channel[1] += king_val;
407
+ castling_channel[2] += queen_val;
408
+ castling_channel[3] += queen_val;
409
+ }
410
+
411
+ let mut encoded = vec![0.0; BOARD_SQUARES * (history_channels + CASTLING_CHANNELS)];
412
+ for square in 0..BOARD_SQUARES {
413
+ let dst_base = square * (history_channels + CASTLING_CHANNELS);
414
+ let src_base = square * history_channels;
415
+ encoded[dst_base..dst_base + history_channels]
416
+ .copy_from_slice(&tensor[src_base..src_base + history_channels]);
417
+
418
+ let castling_base = square * CASTLING_CHANNELS;
419
+ encoded[dst_base + history_channels..dst_base + history_channels + CASTLING_CHANNELS]
420
+ .copy_from_slice(&castling_channel[castling_base..castling_base + CASTLING_CHANNELS]);
421
+ }
422
+
423
+ encoded
424
+ }
425
+
426
+ fn action_index_internal(&self, action: ChessMove) -> Result<usize, String> {
427
+ let mut from_sq = action.get_source();
428
+ let mut to_sq = action.get_dest();
429
+
430
+ if self.side_to_move() == Color::Black {
431
+ from_sq = mirror_square(from_sq);
432
+ to_sq = mirror_square(to_sq);
433
+ }
434
+
435
+ if let Some(promotion) = action.get_promotion() {
436
+ let promotion_idx = promotion_index(promotion)
437
+ .ok_or_else(|| format!("Unsupported promotion piece: {:?}", promotion))?;
438
+ return Ok(PROMOTION_BASE + square_index(from_sq) * 4 + promotion_idx);
439
+ }
440
+
441
+ Ok(square_index(from_sq) * BOARD_SQUARES + square_index(to_sq))
442
+ }
443
+ }
444
+
445
+ #[pymethods]
446
+ impl SelfPlayGame {
447
+ #[new]
448
+ #[pyo3(signature = (board=None, history_tensors=None))]
449
+ fn new(
450
+ py: Python,
451
+ board: Option<PyObject>,
452
+ history_tensors: Option<PyObject>,
453
+ ) -> PyResult<Self> {
454
+ let board = match board {
455
+ Some(obj) => {
456
+ let bound = obj.bind(py);
457
+ let fen = fen_from_py(&bound)?;
458
+ Board::from_str(&fen).map_err(|_| PyValueError::new_err("Invalid FEN string"))?
459
+ }
460
+ None => Board::default(),
461
+ };
462
+
463
+ let history_deque = match history_tensors {
464
+ Some(obj) => {
465
+ let bound = obj.bind(py);
466
+ Some(history_from_py(&bound)?)
467
+ }
468
+ None => None,
469
+ };
470
+
471
+ Ok(SelfPlayGame::new_internal(board, history_deque))
472
+ }
473
+
474
+ #[getter]
475
+ fn turn(&self) -> bool {
476
+ self.side_to_move() == Color::White
477
+ }
478
+
479
+ #[getter]
480
+ fn board(&self) -> Vec<Vec<i8>> {
481
+ self.board_matrix()
482
+ }
483
+
484
+ fn castle_eligible(&self) -> (bool, bool) {
485
+ self.castle_eligible_internal()
486
+ }
487
+
488
+ fn clone(&self) -> SelfPlayGame {
489
+ Clone::clone(self)
490
+ }
491
+
492
+ fn current_player(&self) -> bool {
493
+ self.side_to_move() == Color::White
494
+ }
495
+
496
+ #[pyo3(name = "legal_actions")]
497
+ fn legal_actions_py(&self, py: Python) -> PyResult<Vec<PyObject>> {
498
+ let chess_mod = PyModule::import(py, "chess")
499
+ .map_err(|_| PyValueError::new_err("python-chess is required for legal_actions()"))?;
500
+ let move_type = chess_mod.getattr("Move")?;
501
+ let from_uci = move_type.getattr("from_uci")?;
502
+
503
+ let mut moves = Vec::new();
504
+ for mv in self.legal_actions_internal() {
505
+ let obj = from_uci.call1((mv.to_string(),))?;
506
+ moves.push(obj.unbind());
507
+ }
508
+
509
+ Ok(moves)
510
+ }
511
+
512
+ fn apply(&mut self, py: Python, action: PyObject) -> PyResult<()> {
513
+ let bound = action.bind(py);
514
+ let uci = uci_from_py(&bound)?;
515
+ let action = ChessMove::from_str(&uci)
516
+ .map_err(|_| PyValueError::new_err("Invalid UCI move"))?;
517
+ self.apply_internal(action);
518
+ Ok(())
519
+ }
520
+
521
+ fn is_terminal(&self) -> bool {
522
+ self.is_terminal_internal()
523
+ }
524
+
525
+ fn terminal_value(&self, player_is_white: bool) -> f32 {
526
+ let player = if player_is_white { Color::White } else { Color::Black };
527
+ self.terminal_value_internal(player)
528
+ }
529
+
530
+ fn encode(&self, py: Python) -> PyResult<PyObject> {
531
+ let flat = self.encode_flat();
532
+ let cols = SLICE_CHANNELS * HISTORY_LEN + CASTLING_CHANNELS;
533
+ let rows = reshape_rows(flat, cols);
534
+
535
+ let torch = PyModule::import(py, "torch")
536
+ .map_err(|_| PyValueError::new_err("torch is required for encode()"))?;
537
+ let tensor_fn = torch.getattr("tensor")?;
538
+ let kwargs = PyDict::new(py);
539
+ kwargs.set_item("dtype", torch.getattr("float32")?)?;
540
+ let tensor = tensor_fn.call((rows,), Some(&kwargs))?;
541
+ Ok(tensor.unbind())
542
+ }
543
+
544
+ fn action_index(&self, py: Python, action: PyObject) -> PyResult<usize> {
545
+ let bound = action.bind(py);
546
+ let uci = uci_from_py(&bound)?;
547
+ let action = ChessMove::from_str(&uci)
548
+ .map_err(|_| PyValueError::new_err("Invalid UCI move"))?;
549
+ self.action_index_internal(action)
550
+ .map_err(PyValueError::new_err)
551
+ }
552
+ }
553
+
554
+ /// A Python module implemented in Rust.
555
+ #[pymodule]
556
+ fn fow_rl(m: &Bound<'_, PyModule>) -> PyResult<()> {
557
+ m.add_class::<SelfPlayGame>()?;
558
+ Ok(())
559
+ }
560
+
561
/// Pure-Rust entry points for fuzz targets, bypassing the Python bindings.
#[cfg(feature = "fuzzing")]
pub mod fuzzing {
    use super::*;

    /// Fresh game from the standard starting position, with its first observation recorded.
    pub fn new_default_game() -> SelfPlayGame {
        let start = Board::default();
        SelfPlayGame::new_internal(start, None)
    }

    /// Plays `action` without a legality check; pass a move from `legal_actions`.
    pub fn apply(game: &mut SelfPlayGame, action: ChessMove) {
        game.apply_internal(action)
    }

    /// Moves produced by the strictly-legal generator for the current position.
    pub fn legal_actions(game: &SelfPlayGame) -> Vec<ChessMove> {
        game.legal_actions_internal()
    }

    /// Full board as an 8x8 signed piece-code matrix, rank 8 first.
    pub fn board_matrix(game: &SelfPlayGame) -> Vec<Vec<i8>> {
        game.board_matrix()
    }

    /// `(kingside, queenside)` castling rights of the side to move.
    pub fn castle_eligible(game: &SelfPlayGame) -> (bool, bool) {
        game.castle_eligible_internal()
    }

    /// Flattened network input (see `SelfPlayGame::encode_flat`).
    pub fn encode_flat(game: &SelfPlayGame) -> Vec<f32> {
        game.encode_flat()
    }

    /// Policy index of `action`, or an error message for unsupported promotions.
    pub fn action_index(game: &SelfPlayGame, action: ChessMove) -> Result<usize, String> {
        game.action_index_internal(action)
    }

    /// `true` if White is to move.
    pub fn is_white_to_move(game: &SelfPlayGame) -> bool {
        matches!(game.side_to_move(), Color::White)
    }
}
597
+
598
#[cfg(test)]
mod tests {
    //! Regression tests comparing the Rust `SelfPlayGame` with a reference Python
    //! implementation (built on python-chess and torch) embedded in `PY_SELFPLAY_CODE`.
    //!
    //! When `chess` or `torch` cannot be imported, a test returns early and prints a skip
    //! notice to stderr. A missing Python stack is therefore visible in the test output
    //! instead of looking like a silent pass.

    use super::*;
    use pyo3::types::PyModule;
    use std::ffi::CString;

    // Reference Python `SelfPlayGame`, loaded as the `selfplay_ref` module by
    // `python_selfplay_module`.
    //
    // NOTE(review): the Rust side is asserted equal to this reference, including two quirks
    // worth confirming as intended:
    // * `encode` adds the kingside flag to both columns of row 0 and the queenside flag to
    //   both columns of row 1 of `castling_channel` (`castling_channel[0]` / `[1]` index rows
    //   of the `(64, 2)` tensor), rather than filling one whole column per flag.
    // * `action_index` encodes a promotion as `4096 + from_sq * 4 + piece`, which ignores the
    //   destination square. A push-promotion and a capture-promotion from the same square
    //   with the same piece therefore share an index.
    const PY_SELFPLAY_CODE: &str = r#"
import chess
import torch
from collections import deque

_PROMOTION_PIECE_ORDER = {
    chess.QUEEN: 0,
    chess.ROOK: 1,
    chess.BISHOP: 2,
    chess.KNIGHT: 3,
}
_PROMOTION_OPTION_COUNT = len(_PROMOTION_PIECE_ORDER)


class SelfPlayGame:
    """Mutable game-state protocol consumed by PPO for Fog of War Chess."""

    def __init__(self, board: chess.Board = None, history_tensors: deque = None):
        self._board = board if board else chess.Board()

        if history_tensors is not None:
            self._history_tensors = history_tensors
        else:
            self._history_tensors = deque(maxlen=8)
            self._save_history_state()

    def _save_history_state(self) -> None:
        current_player = self.turn
        visible_squares = self._compute_visibility(self._board, current_player)

        slice_tensor = torch.zeros((64, 13), dtype=torch.float32)

        for sq in chess.SQUARES:
            view_sq = sq if current_player == chess.WHITE else chess.square_mirror(sq)

            if sq not in visible_squares:
                slice_tensor[view_sq, 12] = 1.0
            else:
                piece = self._board.piece_at(sq)
                if piece is not None:
                    is_own_piece = (piece.color == current_player)
                    piece_idx = piece.piece_type - 1
                    if not is_own_piece:
                        piece_idx += 6
                    slice_tensor[view_sq, piece_idx] = 1.0

        self._history_tensors.append(slice_tensor)

    @property
    def turn(self) -> bool:
        return self._board.turn

    @property
    def board(self):
        mat = []
        for rank in range(7, -1, -1):
            row = []
            for file in range(8):
                piece = self._board.piece_at(chess.square(file, rank))
                if not piece:
                    row.append(0)
                else:
                    val = piece.piece_type
                    row.append(val if piece.color == chess.WHITE else -val)
            mat.append(row)
        return mat

    def castle_eligible(self):
        return (
            self._board.has_kingside_castling_rights(self.turn),
            self._board.has_queenside_castling_rights(self.turn),
        )

    def clone(self):
        cloned_history = deque([t.clone() for t in self._history_tensors], maxlen=8)
        return SelfPlayGame(board=self._board.copy(), history_tensors=cloned_history)

    def current_player(self):
        return self.turn

    def legal_actions(self):
        return list(self._board.pseudo_legal_moves)

    def apply(self, action):
        self._board.push(action)
        self._save_history_state()

    def is_terminal(self) -> bool:
        return (self._board.king(chess.WHITE) is None) or (
            self._board.king(chess.BLACK) is None
        )

    def terminal_value(self, player):
        white_king = self._board.king(chess.WHITE)
        black_king = self._board.king(chess.BLACK)
        if white_king is None:
            return 1.0 if player == chess.BLACK else -1.0
        if black_king is None:
            return 1.0 if player == chess.WHITE else -1.0
        return 0.0

    def _compute_visibility(self, board: chess.Board, player: bool):
        visible = set()
        for sq in chess.SQUARES:
            piece = board.piece_at(sq)
            if piece and piece.color == player:
                visible.add(sq)
                visible.update(board.attacks(sq))

                if piece.piece_type == chess.PAWN:
                    forward_offset = 8 if player == chess.WHITE else -8
                    fwd_sq = sq + forward_offset
                    if 0 <= fwd_sq < 64:
                        visible.add(fwd_sq)
                        rank = chess.square_rank(sq)
                        if (player == chess.WHITE and rank == 1) or (
                            player == chess.BLACK and rank == 6
                        ):
                            if board.piece_at(fwd_sq) is None:
                                visible.add(sq + 2 * forward_offset)
        return visible

    def encode(self):
        history_list = list(self._history_tensors)
        while len(history_list) < 8:
            history_list.insert(0, history_list[0])

        tensor = torch.cat(history_list, dim=1)

        castle_eligibility = self.castle_eligible()

        castling_channel = torch.zeros((64, 2), dtype=torch.float32)
        castling_channel[0] += castle_eligibility[0]
        castling_channel[1] += castle_eligibility[1]

        return torch.cat([tensor, castling_channel], dim=1)

    def action_index(self, action):
        from_sq = action.from_square
        to_sq = action.to_square

        if not self.turn:
            from_sq = chess.square_mirror(from_sq)
            to_sq = chess.square_mirror(to_sq)

        if action.promotion is not None:
            promotion_idx = _PROMOTION_PIECE_ORDER.get(action.promotion)
            if promotion_idx is None:
                raise ValueError(f"Unsupported promotion piece: {action.promotion}")
            return 4096 + from_sq * 4 + promotion_idx

        return from_sq * 64 + to_sq
"#;

    /// Compiles `PY_SELFPLAY_CODE` into a fresh `selfplay_ref` Python module.
    fn python_selfplay_module<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyModule>> {
        let code = CString::new(PY_SELFPLAY_CODE).expect("Python code contains no NUL bytes");
        let filename = CString::new("selfplay_ref.py").expect("valid filename");
        let module = CString::new("selfplay_ref").expect("valid module name");
        PyModule::from_code(py, code.as_c_str(), filename.as_c_str(), module.as_c_str())
    }

    /// Calls the reference `encode()` and flattens its 2-D result row by row. This is the
    /// layout the tests compare against `SelfPlayGame::encode_flat`.
    fn py_encode_flat(game: &Bound<'_, PyAny>) -> PyResult<Vec<f32>> {
        let rows: Vec<Vec<f32>> = game
            .call_method0("encode")?
            .call_method0("tolist")?
            .extract()?;
        Ok(rows.into_iter().flatten().collect())
    }

    /// Python-side handles shared by the regression tests.
    struct Reference<'py> {
        /// The imported python-chess `chess` module.
        chess_mod: Bound<'py, PyModule>,
        /// The reference `SelfPlayGame` class compiled from `PY_SELFPLAY_CODE`.
        game_cls: Bound<'py, PyAny>,
    }

    impl<'py> Reference<'py> {
        /// Loads the reference implementation.
        ///
        /// Returns `Ok(None)` when `chess` or `torch` cannot be imported, after printing a
        /// skip notice that names `test_name`. Errors from compiling the reference code are
        /// propagated.
        fn load(py: Python<'py>, test_name: &str) -> PyResult<Option<Self>> {
            let Ok(chess_mod) = PyModule::import(py, "chess") else {
                eprintln!("skipping {test_name}: Python module `chess` is not importable");
                return Ok(None);
            };
            if PyModule::import(py, "torch").is_err() {
                eprintln!("skipping {test_name}: Python module `torch` is not importable");
                return Ok(None);
            }
            let game_cls = python_selfplay_module(py)?.getattr("SelfPlayGame")?;
            Ok(Some(Self {
                chess_mod,
                game_cls,
            }))
        }

        /// Builds a python-chess move from a UCI string via `chess.Move.from_uci`.
        fn uci_move(&self, uci: &str) -> PyResult<Bound<'py, PyAny>> {
            self.chess_mod
                .getattr("Move")?
                .getattr("from_uci")?
                .call1((uci,))
        }
    }

    /// Asserts that the side to move, the board matrix, the castling eligibility and the
    /// flattened encoding agree between the Rust game and the reference game.
    fn assert_same_state(rust_game: &SelfPlayGame, py_game: &Bound<'_, PyAny>) -> PyResult<()> {
        let py_turn: bool = py_game.getattr("turn")?.extract()?;
        assert_eq!(rust_game.side_to_move() == Color::White, py_turn, "side to move");

        let py_board: Vec<Vec<i8>> = py_game.getattr("board")?.extract()?;
        assert_eq!(rust_game.board_matrix(), py_board, "board matrix");

        let py_castle: (bool, bool) = py_game.call_method0("castle_eligible")?.extract()?;
        assert_eq!(
            rust_game.castle_eligible_internal(),
            py_castle,
            "castling eligibility"
        );

        let py_encoded = py_encode_flat(py_game)?;
        assert_eq!(rust_game.encode_flat(), py_encoded, "encoded observation");
        Ok(())
    }

    #[test]
    fn regression_default_state() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let Some(reference) = Reference::load(py, "regression_default_state")? else {
                return Ok(());
            };
            let py_game = reference.game_cls.call0()?;
            let rust_game = SelfPlayGame::new_internal(Board::default(), None);

            assert_same_state(&rust_game, &py_game)?;

            let py_idx: usize = py_game
                .call_method1("action_index", (reference.uci_move("e2e4")?,))?
                .extract()?;
            let rust_idx = rust_game
                .action_index_internal(ChessMove::from_str("e2e4").unwrap())
                .unwrap();
            assert_eq!(rust_idx, py_idx);

            Ok(())
        })
    }

    #[test]
    fn regression_after_moves() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let Some(reference) = Reference::load(py, "regression_after_moves")? else {
                return Ok(());
            };
            let py_game = reference.game_cls.call0()?;
            let mut rust_game = SelfPlayGame::new_internal(Board::default(), None);

            for uci in ["e2e4", "e7e5", "g1f3"] {
                py_game.call_method1("apply", (reference.uci_move(uci)?,))?;
                rust_game.apply_internal(ChessMove::from_str(uci).unwrap());
            }

            assert_same_state(&rust_game, &py_game)?;

            Ok(())
        })
    }

    #[test]
    fn regression_promotion_action_index() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let Some(reference) = Reference::load(py, "regression_promotion_action_index")?
            else {
                return Ok(());
            };

            let fen = "8/P7/8/8/8/8/8/K6k w - - 0 1";
            let py_board = reference.chess_mod.getattr("Board")?.call1((fen,))?;
            let py_game = reference.game_cls.call1((py_board,))?;
            let rust_game = SelfPlayGame::new_internal(Board::from_str(fen).unwrap(), None);

            let py_idx: usize = py_game
                .call_method1("action_index", (reference.uci_move("a7a8q")?,))?
                .extract()?;
            let rust_idx = rust_game
                .action_index_internal(ChessMove::from_str("a7a8q").unwrap())
                .unwrap();
            assert_eq!(rust_idx, py_idx);

            Ok(())
        })
    }
}