kaggle-environments 1.17.2__py2.py3-none-any.whl → 1.17.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +1 -1
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +296 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy_test.py +57 -0
- kaggle_environments/envs/open_spiel/observation.py +133 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +83 -21
- kaggle_environments/envs/open_spiel/proxy.py +139 -0
- kaggle_environments/envs/open_spiel/proxy_test.py +64 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +1 -1
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/METADATA +1 -1
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/RECORD +18 -9
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/WHEEL +0 -0
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/licenses/LICENSE +0 -0
- {kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/top_level.txt +0 -0
kaggle_environments/__init__.py
CHANGED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
function renderer(options) {
|
|
2
|
+
// Only `environment`, `step` and `parent` are read by this renderer.
const { environment, step, parent } = options;

// Board geometry assumed when the observation does not say otherwise.
const DEFAULT_NUM_ROWS = 6;
const DEFAULT_NUM_COLS = 7;

// PLAYER_SYMBOLS and PLAYER_COLORS are parallel arrays: index 0 is 'O' (yellow),
// index 1 is 'X' (red). These are array indices, NOT OpenSpiel player ids:
// connect_four_proxy.py reports player 0 as 'x' and player 1 as 'o'.
const PLAYER_SYMBOLS = ['O', 'X'];
const PLAYER_COLORS = ['#facc15', '#ef4444'];
const EMPTY_CELL_COLOR = '#e5e7eb';
const BOARD_COLOR = '#3b82f6';

// SVG layout: every board cell occupies a CELL_UNIT_SIZE square of viewBox units.
const SVG_NS = "http://www.w3.org/2000/svg";
const CELL_UNIT_SIZE = 100;
const CIRCLE_RADIUS = CELL_UNIT_SIZE * 0.42;
const SVG_VIEWBOX_WIDTH = DEFAULT_NUM_COLS * CELL_UNIT_SIZE;
const SVG_VIEWBOX_HEIGHT = DEFAULT_NUM_ROWS * CELL_UNIT_SIZE;

// DOM handles shared by the helper functions below. They are (re)assigned by
// _ensureRendererElements; the message box is looked up by id so that an
// element created by an earlier renderer call is reused.
let currentBoardSvgElement = null;
let currentStatusTextElement = null;
let currentWinnerTextElement = null;
let currentMessageBoxElement = typeof document !== 'undefined' ? document.getElementById('messageBox') : null;
let currentRendererContainer = null;
let currentTitleElement = null;
|
|
23
|
+
|
|
24
|
+
/**
 * Shows a transient toast message at the top of the page.
 *
 * The message box element is created lazily (id "messageBox") and reused on
 * later calls. Does nothing when there is no DOM or no <body>.
 *
 * @param {string} message Text to display.
 * @param {string} [type='info'] 'error' renders a red box; anything else green.
 * @param {number} [duration=3000] Milliseconds before the box fades out.
 */
function _showMessage(message, type = 'info', duration = 3000) {
    if (typeof document === 'undefined' || !document.body) return;

    if (!currentMessageBoxElement) {
        currentMessageBoxElement = document.createElement('div');
        currentMessageBoxElement.id = 'messageBox';
        Object.assign(currentMessageBoxElement.style, {
            position: 'fixed',
            top: '10px',
            left: '50%',
            transform: 'translateX(-50%)',
            padding: '0.75rem 1rem',
            borderRadius: '0.375rem',
            boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
            zIndex: '1000',
            opacity: '0',
            transition: 'opacity 0.3s ease-in-out, background-color 0.3s',
            fontSize: '0.875rem',
            fontFamily: "'Inter', sans-serif",
        });
        document.body.appendChild(currentMessageBoxElement);
    }

    const box = currentMessageBoxElement;
    box.textContent = message;
    box.style.backgroundColor = type === 'error' ? '#ef4444' : '#10b981';
    box.style.color = 'white';
    box.style.opacity = '1';

    // Cancel the hide timer of any previous message; otherwise it would fire
    // early and hide this newer message before `duration` has elapsed. The
    // timer id is kept on the element itself because the element (found by id)
    // outlives a single renderer() call, unlike the local variables here.
    if (box._hideTimerId) clearTimeout(box._hideTimerId);
    box._hideTimerId = setTimeout(() => {
        box.style.opacity = '0';
        box._hideTimerId = null;
    }, duration);
}
|
|
49
|
+
|
|
50
|
+
/**
 * Rebuilds the renderer DOM (title, SVG board, status panel) inside a parent.
 *
 * The parent is emptied first. The shared handles currentRendererContainer,
 * currentTitleElement, currentBoardSvgElement, currentStatusTextElement and
 * currentWinnerTextElement are reassigned to the freshly created elements.
 * Each board circle gets the id `cell-<visualRow>-<col>`, with visual row 0
 * at the top of the SVG.
 *
 * @param {Element} parentElementToClear Element that will host the renderer.
 * @param {number} rows Number of board rows to draw.
 * @param {number} cols Number of board columns to draw.
 * @returns {boolean} true when the elements were created; false when there is
 *     no parent element or no DOM to create them in.
 */
function _ensureRendererElements(parentElementToClear, rows, cols) {
    // Check for the DOM up front: everything below needs `document`.
    if (!parentElementToClear || typeof document === 'undefined') return false;
    parentElementToClear.innerHTML = '';

    // Size the drawing from the requested dimensions so the background and
    // viewBox always agree with the circles created in the loop below.
    const viewBoxWidth = cols * CELL_UNIT_SIZE;
    const viewBoxHeight = rows * CELL_UNIT_SIZE;

    currentRendererContainer = document.createElement('div');
    currentRendererContainer.style.display = 'flex';
    currentRendererContainer.style.flexDirection = 'column';
    currentRendererContainer.style.alignItems = 'center';
    currentRendererContainer.style.padding = '20px';
    currentRendererContainer.style.boxSizing = 'border-box';
    currentRendererContainer.style.width = '100%';
    currentRendererContainer.style.height = '100%';
    currentRendererContainer.style.fontFamily = "'Inter', sans-serif";

    currentTitleElement = document.createElement('h1');
    currentTitleElement.textContent = 'Connect Four';
    currentTitleElement.style.fontSize = '1.875rem';
    currentTitleElement.style.fontWeight = 'bold';
    currentTitleElement.style.marginBottom = '1rem';
    currentTitleElement.style.textAlign = 'center';
    currentTitleElement.style.color = '#2563eb';
    currentRendererContainer.appendChild(currentTitleElement);

    currentBoardSvgElement = document.createElementNS(SVG_NS, "svg");
    currentBoardSvgElement.setAttribute("viewBox", `0 0 ${viewBoxWidth} ${viewBoxHeight}`);
    currentBoardSvgElement.setAttribute("preserveAspectRatio", "xMidYMid meet");
    currentBoardSvgElement.style.width = "auto";
    currentBoardSvgElement.style.maxWidth = "500px";
    currentBoardSvgElement.style.maxHeight = `calc(100vh - 200px)`;
    currentBoardSvgElement.style.aspectRatio = `${cols} / ${rows}`;
    currentBoardSvgElement.style.display = "block";
    currentBoardSvgElement.style.margin = "0 auto 20px auto";

    const boardBgRect = document.createElementNS(SVG_NS, "rect");
    boardBgRect.setAttribute("x", "0");
    boardBgRect.setAttribute("y", "0");
    boardBgRect.setAttribute("width", viewBoxWidth.toString());
    boardBgRect.setAttribute("height", viewBoxHeight.toString());
    boardBgRect.setAttribute("fill", BOARD_COLOR);
    boardBgRect.setAttribute("rx", (CELL_UNIT_SIZE * 0.1).toString());
    currentBoardSvgElement.appendChild(boardBgRect);

    // SVG circles are created with (0,0) being the top-left visual circle.
    for (let r_visual = 0; r_visual < rows; r_visual++) {
        for (let c_visual = 0; c_visual < cols; c_visual++) {
            const circle = document.createElementNS(SVG_NS, "circle");
            const cx = c_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
            const cy = r_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
            circle.setAttribute("id", `cell-${r_visual}-${c_visual}`);
            circle.setAttribute("cx", cx.toString());
            circle.setAttribute("cy", cy.toString());
            circle.setAttribute("r", CIRCLE_RADIUS.toString());
            circle.setAttribute("fill", EMPTY_CELL_COLOR);
            currentBoardSvgElement.appendChild(circle);
        }
    }
    currentRendererContainer.appendChild(currentBoardSvgElement);

    const statusContainer = document.createElement('div');
    statusContainer.style.padding = '10px 15px';
    statusContainer.style.backgroundColor = 'white';
    statusContainer.style.borderRadius = '8px';
    statusContainer.style.boxShadow = '0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -1px rgba(0,0,0,0.06)';
    statusContainer.style.textAlign = 'center';
    statusContainer.style.width = 'auto';
    statusContainer.style.minWidth = '200px';
    statusContainer.style.maxWidth = '90vw';
    currentRendererContainer.appendChild(statusContainer);

    currentStatusTextElement = document.createElement('p');
    currentStatusTextElement.style.fontSize = '1.1rem';
    currentStatusTextElement.style.fontWeight = '600';
    currentStatusTextElement.style.margin = '0 0 5px 0';
    statusContainer.appendChild(currentStatusTextElement);

    currentWinnerTextElement = document.createElement('p');
    currentWinnerTextElement.style.fontSize = '1.25rem';
    currentWinnerTextElement.style.fontWeight = '700';
    currentWinnerTextElement.style.margin = '5px 0 0 0';
    statusContainer.appendChild(currentWinnerTextElement);

    parentElementToClear.appendChild(currentRendererContainer);

    // Mark the page once; `document.body` can be null while the page is loading.
    if (document.body && !document.body.hasAttribute('data-renderer-initialized')) {
        document.body.setAttribute('data-renderer-initialized', 'true');
    }
    return true;
}
|
|
138
|
+
|
|
139
|
+
/**
 * Paints a game state onto the SVG board and updates the status texts.
 *
 * Data row 0 is drawn on the bottom visual row (the board is flipped
 * vertically). A missing or empty board resets every cell and shows a
 * "waiting" status. A data row that is not an array of exactly
 * `displayCols` entries is painted magenta.
 *
 * @param {?Object} gameStateToDisplay State with `board`, `current_player`,
 *     `is_terminal` and `winner`, or null when no state is available.
 * @param {number} displayRows Number of rows to paint.
 * @param {number} displayCols Number of columns to paint.
 */
function _renderBoardDisplay_svg(gameStateToDisplay, displayRows, displayCols) {
    if (!(currentBoardSvgElement && currentStatusTextElement && currentWinnerTextElement)) return;

    // Lower-case tokens, index-aligned with PLAYER_SYMBOLS / PLAYER_COLORS.
    const playerTokens = ['o', 'x'];

    // Sets the fill of one circle, silently ignoring cells that do not exist.
    const setCellFill = (visualRow, col, color) => {
        const cell = currentBoardSvgElement.querySelector(`#cell-${visualRow}-${col}`);
        if (cell) cell.setAttribute("fill", color);
    };
    // Paints an entire visual row in a single color.
    const fillVisualRow = (visualRow, color) => {
        for (let col = 0; col < displayCols; col++) setCellFill(visualRow, col, color);
    };
    // Colored, bold player symbol used inside the status lines.
    const playerBadge = (idx) =>
        `<span style="color: ${PLAYER_COLORS[idx]}; font-weight: bold;">${PLAYER_SYMBOLS[idx]}</span>`;

    const hasBoard = Boolean(gameStateToDisplay) &&
        Array.isArray(gameStateToDisplay.board) &&
        gameStateToDisplay.board.length > 0;
    if (!hasBoard) {
        currentStatusTextElement.textContent = "Waiting for game data...";
        currentWinnerTextElement.textContent = "";
        for (let visualRow = 0; visualRow < displayRows; visualRow++) {
            fillVisualRow(visualRow, EMPTY_CELL_COLOR);
        }
        return;
    }

    const { board, current_player, is_terminal, winner } = gameStateToDisplay;

    for (let dataRowIndex = 0; dataRowIndex < displayRows; dataRowIndex++) {
        // Flip vertically: data row 0 lands on the last (bottom) visual row.
        const visualRow = (displayRows - 1) - dataRowIndex;
        const rowValues = board[dataRowIndex];

        if (!Array.isArray(rowValues) || rowValues.length !== displayCols) {
            fillVisualRow(visualRow, '#FF00FF'); // Magenta marks a malformed row
            continue;
        }

        rowValues.forEach((rawValue, col) => {
            const token = String(rawValue).trim().toLowerCase();
            const tokenIndex = playerTokens.indexOf(token);
            setCellFill(visualRow, col, tokenIndex === -1 ? EMPTY_CELL_COLOR : PLAYER_COLORS[tokenIndex]);
        });
    }

    currentStatusTextElement.innerHTML = '';
    currentWinnerTextElement.innerHTML = '';

    if (!is_terminal) {
        const moverIndex = playerTokens.indexOf(String(current_player).toLowerCase());
        if (moverIndex !== -1) {
            currentStatusTextElement.innerHTML = `Current Player: ${playerBadge(moverIndex)}`;
        } else {
            currentStatusTextElement.textContent = "Waiting for player...";
        }
        return;
    }

    currentStatusTextElement.textContent = "Game Over!";
    if (winner === null || winner === undefined) {
        currentWinnerTextElement.textContent = "Game ended.";
        return;
    }

    const winnerToken = String(winner).toLowerCase();
    if (winnerToken === 'draw') {
        currentWinnerTextElement.textContent = "It's a Draw!";
        return;
    }

    const winnerIndex = playerTokens.indexOf(winnerToken);
    if (winnerIndex !== -1) {
        currentWinnerTextElement.innerHTML = `Player ${playerBadge(winnerIndex)} Wins!`;
    } else {
        // Unknown winner token: show it verbatim, upper-cased.
        currentWinnerTextElement.textContent = `Winner: ${String(winner).toUpperCase()}`;
    }
}
|
|
233
|
+
|
|
234
|
+
// --- Main execution logic ---
|
|
235
|
+
if (!_ensureRendererElements(parent, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS)) {
|
|
236
|
+
if (parent && typeof parent.innerHTML !== 'undefined') {
|
|
237
|
+
parent.innerHTML = "<p style='color:red; font-family: sans-serif;'>Critical Error: Renderer element setup failed.</p>";
|
|
238
|
+
}
|
|
239
|
+
return;
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
if (!environment || !environment.steps || !environment.steps[step]) {
|
|
243
|
+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
|
|
244
|
+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Initializing environment...";
|
|
245
|
+
return;
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
const currentStepAgents = environment.steps[step];
|
|
249
|
+
if (!currentStepAgents || !Array.isArray(currentStepAgents) || currentStepAgents.length === 0) {
|
|
250
|
+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
|
|
251
|
+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Waiting for agent data...";
|
|
252
|
+
return;
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
const gameMasterAgentIndex = currentStepAgents.length - 1;
|
|
256
|
+
const gameMasterAgent = currentStepAgents[gameMasterAgentIndex];
|
|
257
|
+
|
|
258
|
+
if (!gameMasterAgent || typeof gameMasterAgent.observation === 'undefined') {
|
|
259
|
+
_renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
|
|
260
|
+
if(currentStatusTextElement) currentStatusTextElement.textContent = "Waiting for observation data...";
|
|
261
|
+
return;
|
|
262
|
+
}
|
|
263
|
+
const observationForRenderer = gameMasterAgent.observation;
|
|
264
|
+
|
|
265
|
+
let gameSpecificState = null;
|
|
266
|
+
|
|
267
|
+
if (observationForRenderer && typeof observationForRenderer.observation_string === 'string' && observationForRenderer.observation_string.trim() !== '') {
|
|
268
|
+
try {
|
|
269
|
+
gameSpecificState = JSON.parse(observationForRenderer.observation_string);
|
|
270
|
+
} catch (e) {
|
|
271
|
+
_showMessage("Error: Corrupted game state (obs_string).", 'error');
|
|
272
|
+
}
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
if (!gameSpecificState && observationForRenderer && typeof observationForRenderer.json === 'string' && observationForRenderer.json.trim() !== '') {
|
|
276
|
+
try {
|
|
277
|
+
gameSpecificState = JSON.parse(observationForRenderer.json);
|
|
278
|
+
} catch (e) {
|
|
279
|
+
_showMessage("Error: Corrupted game state (json).", 'error');
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
if (!gameSpecificState && observationForRenderer &&
|
|
284
|
+
Array.isArray(observationForRenderer.board) &&
|
|
285
|
+
typeof observationForRenderer.current_player !== 'undefined'
|
|
286
|
+
) {
|
|
287
|
+
if( (observationForRenderer.board.length === DEFAULT_NUM_ROWS &&
|
|
288
|
+
(observationForRenderer.board.length === 0 ||
|
|
289
|
+
(Array.isArray(observationForRenderer.board[0]) && observationForRenderer.board[0].length === DEFAULT_NUM_COLS)))
|
|
290
|
+
){
|
|
291
|
+
gameSpecificState = observationForRenderer;
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
_renderBoardDisplay_svg(gameSpecificState, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
|
|
296
|
+
}
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Change Connect Four state and action string representations."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ... import proxy
|
|
7
|
+
import pyspiel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConnectFourState(proxy.State):
    """Connect Four state proxy.

    Replaces the state and action string representations with JSON: the state
    is a dict with `board`, `current_player`, `is_terminal` and `winner`, and
    an action is a dict with a single `col` entry.
    """

    def _player_string(self, player: int) -> str:
        """Return the display name of a player id.

        Args:
            player: 0 or 1 for the two players, or a negative special id.

        Returns:
            'x' for player 0, 'o' for player 1, and the lower-cased
            `pyspiel.PlayerId` member name for negative ids.

        Raises:
            ValueError: If `player` is a non-negative id other than 0 or 1.
        """
        if player < 0:
            return pyspiel.PlayerId(player).name.lower()
        symbol = {0: 'x', 1: 'o'}.get(player)
        if symbol is None:
            raise ValueError(f'Invalid player: {player}')
        return symbol

    def state_dict(self) -> dict[str, Any]:
        """Return the state as a JSON-serializable dict.

        `board` is a list of rows, each a list of single characters, ordered
        so that row 0 is the bottom row. `winner` is None until the game is
        terminal, then 'x', 'o' or 'draw'.
        """
        # The text rows are reversed so that row 0 is the bottom row.
        # NOTE(review): assumes to_string() yields one text line per board row.
        text_rows = self.to_string().strip().split('\n')
        board = [list(row) for row in reversed(text_rows)]

        # Query the terminal flag and the returns once instead of once per use.
        terminal = self.is_terminal()
        winner = None
        if terminal:
            returns = self.returns()
            if returns[0] > returns[1]:
                winner = 'x'
            elif returns[1] > returns[0]:
                winner = 'o'
            else:
                winner = 'draw'

        return {
            'board': board,
            'current_player': self._player_string(self.current_player()),
            'is_terminal': terminal,
            'winner': winner,
        }

    def to_json(self) -> str:
        """Return `state_dict()` serialized as a JSON string."""
        return json.dumps(self.state_dict())

    def action_to_dict(self, action: int) -> dict[str, Any]:
        """Return the dict form of an action: `{'col': action}`."""
        return {'col': action}

    def action_to_json(self, action: int) -> str:
        """Return the JSON form of an action."""
        return json.dumps(self.action_to_dict(action))

    def dict_to_action(self, action_dict: dict[str, Any]) -> int:
        """Return the action id stored under `'col'`, converted to int.

        Raises:
            KeyError: If `action_dict` has no `'col'` entry.
        """
        return int(action_dict['col'])

    def json_to_action(self, action_json: str) -> int:
        """Parse a JSON action string and return its action id."""
        return self.dict_to_action(json.loads(action_json))

    def observation_string(self, player: int) -> str:
        """Return the observation for `player`; same as `observation_json`."""
        return self.observation_json(player)

    def observation_json(self, player: int) -> str:
        """Return the JSON state; the observing player is ignored."""
        del player  # Every player receives the same full state.
        return self.to_json()

    def __str__(self):
        return self.to_json()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class ConnectFourGame(proxy.Game):
    """Connect Four game proxy."""

    def __init__(self, params: Any | None = None):
        """Load OpenSpiel's `connect_four` and wrap it as `connect_four_proxy`.

        Args:
            params: Game parameters passed to `pyspiel.load_game`; a falsy
                value is replaced by an empty dict.
        """
        super().__init__(
            pyspiel.load_game('connect_four', params or {}),
            short_name='connect_four_proxy',
            long_name='Connect Four (proxy)',
        )

    def new_initial_state(self, *args) -> ConnectFourState:
        """Return a proxied initial state created by the wrapped game."""
        wrapped_state = self.__wrapped__.new_initial_state(*args)
        return ConnectFourState(wrapped_state, game=self)


# Register at import time so the proxy can be loaded by its short name.
pyspiel.register_game(ConnectFourGame().get_type(), ConnectFourGame)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Test for proxied Connect Four game."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from absl.testing import absltest
|
|
6
|
+
from absl.testing import parameterized
|
|
7
|
+
import pyspiel
|
|
8
|
+
from . import connect_four_proxy as connect_four
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
NUM_ROWS = 6
|
|
12
|
+
NUM_COLS = 7
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ConnectFourTest(parameterized.TestCase):
    """Tests for the JSON state and action encoding of the Connect Four proxy."""

    def test_game_is_registered(self):
        game = pyspiel.load_game('connect_four_proxy')
        self.assertIsInstance(game, connect_four.ConnectFourGame)

    def test_random_sim(self):
        game = connect_four.ConnectFourGame()
        pyspiel.random_sim_test(game, num_sims=10, serialize=False, verbose=False)

    def test_state_to_json(self):
        game = connect_four.ConnectFourGame()
        state = game.new_initial_state()

        json_state = json.loads(state.to_json())
        expected_board = [['.'] * NUM_COLS for _ in range(NUM_ROWS)]
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'x')
        self.assertFalse(json_state['is_terminal'])
        self.assertIsNone(json_state['winner'])

        # Row 0 of the JSON board is the bottom row, so pieces land there first.
        state.apply_action(3)
        json_state = json.loads(state.to_json())
        expected_board[0][3] = 'x'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'o')

        state.apply_action(2)
        json_state = json.loads(state.to_json())
        expected_board[0][2] = 'o'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'x')

        # A second piece in column 2 stacks on top of the first one.
        state.apply_action(2)
        json_state = json.loads(state.to_json())
        expected_board[1][2] = 'x'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'o')

    def test_action_to_json(self):
        game = connect_four.ConnectFourGame()
        state = game.new_initial_state()
        action_json = state.action_to_json(3)
        self.assertEqual(json.loads(action_json), {'col': 3})
        # The encoding must round-trip back to the original action id.
        self.assertEqual(state.json_to_action(action_json), 3)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
if __name__ == '__main__':
|
|
57
|
+
absltest.main()
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2019 DeepMind Technologies Limited
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""An observation of a game.
|
|
16
|
+
|
|
17
|
+
This is intended to be the main way to get observations of states in Python.
|
|
18
|
+
The usage pattern is as follows:
|
|
19
|
+
|
|
20
|
+
0. Create the game we will be playing
|
|
21
|
+
1. Create each kind of observation required, using `make_observation`
|
|
22
|
+
2. Every time a new observation is required, call:
|
|
23
|
+
`observation.set_from(state, player)`
|
|
24
|
+
The tensor contained in the Observation class will be updated with an
|
|
25
|
+
observation of the supplied state. This tensor is updated in-place, so if
|
|
26
|
+
you wish to retain it, you must make a copy.
|
|
27
|
+
|
|
28
|
+
The following options are available when creating an Observation:
|
|
29
|
+
- perfect_recall: if true, each observation must allow the observing player to
|
|
30
|
+
reconstruct their history of actions and observations.
|
|
31
|
+
- public_info: if true, the observation should include public information
|
|
32
|
+
- private_info: specifies for which players private information should be
|
|
33
|
+
included - all players, the observing player, or no players
|
|
34
|
+
- params: game-specific parameters for observations
|
|
35
|
+
|
|
36
|
+
We ultimately aim to have all games support all combinations of these arguments.
|
|
37
|
+
However, initially many games will only support the combinations corresponding
|
|
38
|
+
to ObservationTensor and InformationStateTensor:
|
|
39
|
+
- ObservationTensor: perfect_recall=False, public_info=True,
|
|
40
|
+
private_info=SinglePlayer
|
|
41
|
+
- InformationStateTensor: perfect_recall=True, public_info=True,
|
|
42
|
+
private_info=SinglePlayer
|
|
43
|
+
|
|
44
|
+
Three formats of observation are supported:
|
|
45
|
+
a. 1-D numpy array, accessed by `observation.tensor`
|
|
46
|
+
b. Dict of numpy arrays, accessed by `observation.dict`. These are pieces of the
|
|
47
|
+
1-D array, reshaped. The np.array objects refer to the same memory as the
|
|
48
|
+
1-D array (no copying!).
|
|
49
|
+
c. String, hopefully human-readable (primarily for debugging purposes)
|
|
50
|
+
|
|
51
|
+
For usage examples, see `observation_test.py`.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
import numpy as np
|
|
55
|
+
|
|
56
|
+
import pyspiel
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Corresponds to the old information_state_XXX methods.
|
|
60
|
+
INFO_STATE_OBS_TYPE = pyspiel.IIGObservationType(perfect_recall=True)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _Observation:
    """Contains an observation from a game.

    Attributes are filled in by `__init__`:
      tensor: flat float32 view over the wrapped observation's buffer, or
        None when the observation has no tensor.
      dict: maps each tensor piece's name to a reshaped slice of `tensor`.
    """

    def __init__(self, game, observer):
        self._observation = pyspiel._Observation(game, observer)
        self.dict = {}
        if not self._observation.has_tensor():
            self.tensor = None
            return

        self.tensor = np.frombuffer(self._observation, np.float32)
        # Carve the flat tensor into consecutive named pieces.
        start = 0
        for info in self._observation.tensors_info():
            stop = start + np.prod(info.shape, dtype=np.int64)
            self.dict[info.name] = self.tensor[start:stop].reshape(info.shape)
            start = stop

    def set_from(self, state, player):
        """Update the wrapped observation from `state` as seen by `player`."""
        self._observation.set_from(state, player)

    def string_from(self, state, player):
        """Return the string observation, or None if strings are unsupported."""
        if not self._observation.has_string():
            return None
        return self._observation.string_from(state, player)

    def compress(self):
        """Return the wrapped observation's compressed form."""
        return self._observation.compress()

    def decompress(self, compressed_observation):
        """Load a previously compressed observation into the wrapped one."""
        self._observation.decompress(compressed_observation)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def make_observation(
    game,
    imperfect_information_observation_type=None,
    params=None,
):
    """Create an observation object for `game`.

    Games that define `make_py_observer` build their own observer, which is
    returned directly. Otherwise the observer comes from `game.make_observer`
    and is wrapped in an `_Observation`.

    Args:
        game: The game to observe.
        imperfect_information_observation_type: Optional observation type;
            when None, `make_observer` is called with `params` only.
        params: Optional observation parameters; a falsy value becomes `{}`.

    Returns:
        The observation object, or None when `make_observer` returns None
        (the requested observation type is not supported).
    """
    params = params or {}
    if hasattr(game, 'make_py_observer'):
        return game.make_py_observer(imperfect_information_observation_type, params)

    if imperfect_information_observation_type is None:
        observer_args = (params,)
    else:
        observer_args = (imperfect_information_observation_type, params)
    observer = game.make_observer(*observer_args)
    return None if observer is None else _Observation(game, observer)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class IIGObserverForPublicInfoGame:
    """Observer for imperfect information observations of public-info games.

    It carries no tensor data: `tensor` is always None and `dict` is always
    empty. Only the string observation is meaningful.
    """

    def __init__(self, iig_obs_type, params):
        """Store the observation type.

        Args:
            iig_obs_type: Observation type; its `public_info` attribute is
                read by `string_from`.
            params: Must be empty or None.

        Raises:
            ValueError: If any observation parameters are passed.
        """
        if params:
            raise ValueError(f'Observation parameters not supported; passed {params}')
        self._iig_obs_type = iig_obs_type
        self.tensor = None
        self.dict = {}

    def set_from(self, state, player):
        """Do nothing; there is no tensor to update."""
        pass

    def string_from(self, state, player):
        """Return the state's history string when public info is requested.

        The observing player is ignored. When the observation type does not
        include public info, an empty string is returned.
        """
        del player
        if self._iig_obs_type.public_info:
            return state.history_str()
        else:
            return ''  # No private information to return
|
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
"""Kaggle environment wrapper for OpenSpiel games."""
|
|
2
2
|
|
|
3
3
|
import copy
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
4
6
|
import random
|
|
5
|
-
from typing import Any
|
|
7
|
+
from typing import Any, Callable
|
|
6
8
|
|
|
7
9
|
from kaggle_environments import core
|
|
8
10
|
from kaggle_environments import utils
|
|
9
11
|
import numpy as np
|
|
10
12
|
import pyspiel
|
|
11
|
-
|
|
13
|
+
from .games.connect_four import connect_four_proxy
|
|
12
14
|
|
|
13
15
|
DEFAULT_ACT_TIMEOUT = 5
|
|
14
16
|
DEFAULT_RUN_TIMEOUT = 1200
|
|
@@ -236,43 +238,84 @@ def renderer(state: list[utils.Struct], env: core.Environment) -> str:
|
|
|
236
238
|
print(f"Error rendering {env.name} at state: {state}.")
|
|
237
239
|
raise e
|
|
238
240
|
|
|
241
|
+
# --- HTML Renderer Logic ---
|
|
239
242
|
|
|
240
|
-
def
|
|
241
|
-
"""Provides the
|
|
243
|
+
def _default_html_renderer() -> str:
|
|
244
|
+
"""Provides the JavaScript string for the default HTML renderer."""
|
|
242
245
|
return """
|
|
243
246
|
function renderer(context) {
|
|
244
247
|
const { parent, environment, step } = context;
|
|
245
248
|
parent.innerHTML = ''; // Clear previous rendering
|
|
246
249
|
|
|
247
|
-
// Get the current step's data
|
|
248
250
|
const currentStepData = environment.steps[step];
|
|
249
|
-
|
|
250
|
-
|
|
251
|
+
if (!currentStepData) {
|
|
252
|
+
parent.textContent = "Waiting for step data...";
|
|
253
|
+
return;
|
|
254
|
+
}
|
|
255
|
+
const numAgents = currentStepData.length;
|
|
256
|
+
const gameMasterIndex = numAgents - 1;
|
|
251
257
|
let obsString = "Observation not available for this step.";
|
|
258
|
+
let title = `Step: ${step}`;
|
|
259
|
+
|
|
260
|
+
if (environment.configuration && environment.configuration.openSpielGameName) {
|
|
261
|
+
title = `${environment.configuration.openSpielGameName} - Step: ${step}`;
|
|
262
|
+
}
|
|
252
263
|
|
|
253
|
-
// Try to get
|
|
254
|
-
if (currentStepData
|
|
264
|
+
// Try to get obs_string from game_master of current step
|
|
265
|
+
if (currentStepData[gameMasterIndex] &&
|
|
266
|
+
currentStepData[gameMasterIndex].observation &&
|
|
267
|
+
typeof currentStepData[gameMasterIndex].observation.observation_string === 'string') {
|
|
255
268
|
obsString = currentStepData[gameMasterIndex].observation.observation_string;
|
|
256
|
-
}
|
|
257
|
-
|
|
269
|
+
}
|
|
270
|
+
// Fallback to initial step if current is unavailable (e.g. very first render call)
|
|
271
|
+
else if (step === 0 && environment.steps[0] && environment.steps[0][gameMasterIndex] &&
|
|
272
|
+
environment.steps[0][gameMasterIndex].observation &&
|
|
273
|
+
typeof environment.steps[0][gameMasterIndex].observation.observation_string === 'string') {
|
|
258
274
|
obsString = environment.steps[0][gameMasterIndex].observation.observation_string;
|
|
259
275
|
}
|
|
260
276
|
|
|
261
|
-
// Create a <pre> element to preserve formatting
|
|
262
277
|
const pre = document.createElement("pre");
|
|
263
|
-
pre.style.fontFamily = "monospace";
|
|
264
|
-
pre.style.margin = "10px";
|
|
278
|
+
pre.style.fontFamily = "monospace";
|
|
279
|
+
pre.style.margin = "10px";
|
|
265
280
|
pre.style.border = "1px solid #ccc";
|
|
266
|
-
pre.style.padding = "
|
|
267
|
-
pre.style.backgroundColor = "#
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
pre.textContent = `Step: ${step}\\n\\n${obsString}`; // Add step number for context
|
|
281
|
+
pre.style.padding = "10px";
|
|
282
|
+
pre.style.backgroundColor = "#f9f9f9";
|
|
283
|
+
pre.style.whiteSpace = "pre-wrap";
|
|
284
|
+
pre.style.wordBreak = "break-all";
|
|
271
285
|
|
|
286
|
+
pre.textContent = `${title}\\n\\n${obsString}`;
|
|
272
287
|
parent.appendChild(pre);
|
|
273
288
|
}
|
|
274
289
|
"""
|
|
275
290
|
|
|
291
|
+
def _get_html_renderer_content(
|
|
292
|
+
open_spiel_short_name: str,
|
|
293
|
+
base_path_for_custom_renderers: pathlib.Path,
|
|
294
|
+
default_renderer_func: Callable[[], str]
|
|
295
|
+
) -> str:
|
|
296
|
+
"""
|
|
297
|
+
Tries to load a custom JS renderer for the game.
|
|
298
|
+
Falls back to the default renderer if not found or on error.
|
|
299
|
+
"""
|
|
300
|
+
if "proxy" not in open_spiel_short_name:
|
|
301
|
+
return default_renderer_func()
|
|
302
|
+
sanitized_game_name = open_spiel_short_name.replace('-', '_').replace('.', '_')
|
|
303
|
+
sanitized_game_name = sanitized_game_name.removesuffix("_proxy")
|
|
304
|
+
custom_renderer_js_path = (
|
|
305
|
+
base_path_for_custom_renderers /
|
|
306
|
+
sanitized_game_name /
|
|
307
|
+
f"{sanitized_game_name}.js"
|
|
308
|
+
)
|
|
309
|
+
if custom_renderer_js_path.is_file():
|
|
310
|
+
try:
|
|
311
|
+
with open(custom_renderer_js_path, "r", encoding="utf-8") as f:
|
|
312
|
+
content = f.read()
|
|
313
|
+
print(f"INFO: Using custom HTML renderer for {open_spiel_short_name} from {custom_renderer_js_path}")
|
|
314
|
+
return content
|
|
315
|
+
except Exception as e_render:
|
|
316
|
+
pass
|
|
317
|
+
return default_renderer_func()
|
|
318
|
+
|
|
276
319
|
|
|
277
320
|
# --- Agents ---
|
|
278
321
|
def random_agent(
|
|
@@ -299,6 +342,8 @@ def _register_open_spiel_envs(
|
|
|
299
342
|
successfully_loaded_games = []
|
|
300
343
|
skipped_games = []
|
|
301
344
|
registered_envs = {}
|
|
345
|
+
current_file_dir = pathlib.Path(__file__).parent.resolve()
|
|
346
|
+
custom_renderers_base = current_file_dir / "games"
|
|
302
347
|
if games_list is None:
|
|
303
348
|
games_list = pyspiel.registered_names()
|
|
304
349
|
for short_name in games_list:
|
|
@@ -330,16 +375,33 @@ https://github.com/google-deepmind/open_spiel/tree/master/open_spiel/games
|
|
|
330
375
|
game_spec["observation"]["properties"]["openSpielGameName"][
|
|
331
376
|
"default"] = short_name
|
|
332
377
|
|
|
378
|
+
# Building html_renderer_callable is a bit convoluted but other approaches
|
|
379
|
+
# failed for a variety of reasons. Returning a simple lambda function
|
|
380
|
+
# doesn't work because of late-binding. The last env registered will
|
|
381
|
+
# overwrite all previous renderers.
|
|
382
|
+
js_string_content = _get_html_renderer_content(
|
|
383
|
+
open_spiel_short_name=short_name,
|
|
384
|
+
base_path_for_custom_renderers=custom_renderers_base,
|
|
385
|
+
default_renderer_func=_default_html_renderer,
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
def create_html_renderer_closure(captured_content):
|
|
389
|
+
def html_renderer_callable_no_args():
|
|
390
|
+
return captured_content
|
|
391
|
+
return html_renderer_callable_no_args
|
|
392
|
+
|
|
393
|
+
html_renderer_callable = create_html_renderer_closure(js_string_content)
|
|
394
|
+
|
|
333
395
|
registered_envs[env_name] = {
|
|
334
396
|
"specification": game_spec,
|
|
335
397
|
"interpreter": interpreter,
|
|
336
398
|
"renderer": renderer,
|
|
337
|
-
"html_renderer":
|
|
399
|
+
"html_renderer": html_renderer_callable,
|
|
338
400
|
"agents": agents,
|
|
339
401
|
}
|
|
340
402
|
successfully_loaded_games.append(short_name)
|
|
341
403
|
|
|
342
|
-
except Exception: # pylint: disable=broad-exception-caught
|
|
404
|
+
except Exception as e: # pylint: disable=broad-exception-caught
|
|
343
405
|
skipped_games.append(short_name)
|
|
344
406
|
continue
|
|
345
407
|
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""OpenSpiel Game and State proxies.
|
|
2
|
+
|
|
3
|
+
Proxies that act as a pyspiel.State/Game by wrapping the original object and
|
|
4
|
+
forwarding calls. Subclassing allows to override specific methods or add
|
|
5
|
+
additional functionality, or payload to the State/Game object.
|
|
6
|
+
|
|
7
|
+
WARNING: Serialization of proxy games and states is not supported.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from . import observation
|
|
13
|
+
import pyspiel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class State(pyspiel.State):
|
|
17
|
+
"""Base class for a pyspiel.State proxy."""
|
|
18
|
+
|
|
19
|
+
def __init__(self, wrapped: pyspiel.State, game: 'Game'):
|
|
20
|
+
super().__init__(game)
|
|
21
|
+
self.__wrapped__ = wrapped
|
|
22
|
+
|
|
23
|
+
def current_player(self) -> int:
|
|
24
|
+
return self.__wrapped__.current_player()
|
|
25
|
+
|
|
26
|
+
def _legal_actions(self, player: int) -> list[int]:
|
|
27
|
+
return self.__wrapped__.legal_actions(player)
|
|
28
|
+
|
|
29
|
+
def _apply_action(self, action: int) -> None:
|
|
30
|
+
return self.__wrapped__.apply_action(action)
|
|
31
|
+
|
|
32
|
+
def _action_to_string(self, player: int, action: int) -> str:
|
|
33
|
+
return self.__wrapped__.action_to_string(player, action)
|
|
34
|
+
|
|
35
|
+
def chance_outcomes(self) -> list[tuple[int, float]]:
|
|
36
|
+
return self.__wrapped__.chance_outcomes()
|
|
37
|
+
|
|
38
|
+
def is_terminal(self) -> bool:
|
|
39
|
+
return self.__wrapped__.is_terminal()
|
|
40
|
+
|
|
41
|
+
def returns(self) -> list[float]:
|
|
42
|
+
return self.__wrapped__.returns()
|
|
43
|
+
|
|
44
|
+
def rewards(self) -> list[float]:
|
|
45
|
+
return self.__wrapped__.rewards()
|
|
46
|
+
|
|
47
|
+
def __str__(self) -> str:
|
|
48
|
+
return self.__wrapped__.__str__()
|
|
49
|
+
|
|
50
|
+
def to_string(self) -> str:
|
|
51
|
+
return self.__wrapped__.to_string()
|
|
52
|
+
|
|
53
|
+
def __getattr__(self, name: str) -> Any:
|
|
54
|
+
# Escape hatch when proxying Python implementations that have attributes
|
|
55
|
+
# that need to be accessed, e.g. TicTacToeState.board from its observer.
|
|
56
|
+
return object.__getattribute__(self.__wrapped__, name)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Game(pyspiel.Game):
|
|
60
|
+
"""Base class for a pyspiel.Game proxy."""
|
|
61
|
+
|
|
62
|
+
def __init__(self, wrapped: pyspiel.Game, **kwargs):
|
|
63
|
+
# TODO(hennes): Add serialization.
|
|
64
|
+
game_info = pyspiel.GameInfo(
|
|
65
|
+
num_distinct_actions=wrapped.num_distinct_actions(),
|
|
66
|
+
max_chance_outcomes=wrapped.max_chance_outcomes(),
|
|
67
|
+
num_players=wrapped.num_players(),
|
|
68
|
+
min_utility=wrapped.min_utility(),
|
|
69
|
+
max_utility=wrapped.max_utility(),
|
|
70
|
+
utility_sum=wrapped.utility_sum(),
|
|
71
|
+
max_game_length=wrapped.max_game_length(),
|
|
72
|
+
)
|
|
73
|
+
super().__init__(
|
|
74
|
+
_game_type(wrapped.get_type(), **kwargs),
|
|
75
|
+
game_info,
|
|
76
|
+
wrapped.get_parameters(),
|
|
77
|
+
)
|
|
78
|
+
self.__wrapped__ = wrapped
|
|
79
|
+
|
|
80
|
+
def new_initial_state(self, from_string: str | None = None) -> State:
|
|
81
|
+
args = () if from_string is None else (from_string)
|
|
82
|
+
return State(wrapped=self.__wrapped__.new_initial_state(*args), game=self)
|
|
83
|
+
|
|
84
|
+
def max_chance_nodes_in_history(self) -> int:
|
|
85
|
+
return self.__wrapped__.max_chance_nodes_in_history()
|
|
86
|
+
|
|
87
|
+
def make_py_observer(
|
|
88
|
+
self,
|
|
89
|
+
iig_obs_type: pyspiel.IIGObservationType | None = None,
|
|
90
|
+
params: dict[str, Any] | None = None,
|
|
91
|
+
) -> pyspiel.Observer:
|
|
92
|
+
return _Observation(
|
|
93
|
+
observation.make_observation(self.__wrapped__, iig_obs_type, params)
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class _Observation(observation._Observation): # pylint: disable=protected-access
|
|
98
|
+
"""_Observation proxy that passes the wrapped state to the observation."""
|
|
99
|
+
|
|
100
|
+
def __init__(self, wrapped: observation._Observation):
|
|
101
|
+
self.__wrapped__ = wrapped
|
|
102
|
+
self.dict = self.__wrapped__.dict
|
|
103
|
+
self.tensor = self.__wrapped__.tensor
|
|
104
|
+
|
|
105
|
+
def set_from(self, state: State, player: int):
|
|
106
|
+
self.__wrapped__.set_from(state.__wrapped__, player)
|
|
107
|
+
|
|
108
|
+
def string_from(self, state: State, player: int) -> str | None:
|
|
109
|
+
return self.__wrapped__.string_from(state.__wrapped__, player)
|
|
110
|
+
|
|
111
|
+
def compress(self) -> Any:
|
|
112
|
+
return self.__wrapped__.compress()
|
|
113
|
+
|
|
114
|
+
def decompress(self, compressed_observation: Any):
|
|
115
|
+
self.__wrapped__.decompress(compressed_observation)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _game_type(game_type: pyspiel.GameType, **overrides) -> pyspiel.GameType:
|
|
119
|
+
"""Returns a GameType with the given overrides."""
|
|
120
|
+
kwargs = dict(
|
|
121
|
+
short_name=game_type.short_name,
|
|
122
|
+
long_name=game_type.long_name,
|
|
123
|
+
dynamics=game_type.dynamics,
|
|
124
|
+
chance_mode=game_type.chance_mode,
|
|
125
|
+
information=game_type.information,
|
|
126
|
+
utility=game_type.utility,
|
|
127
|
+
reward_model=game_type.reward_model,
|
|
128
|
+
max_num_players=game_type.max_num_players,
|
|
129
|
+
min_num_players=game_type.min_num_players,
|
|
130
|
+
provides_information_state_string=game_type.provides_information_state_string,
|
|
131
|
+
provides_information_state_tensor=game_type.provides_information_state_tensor,
|
|
132
|
+
provides_observation_string=game_type.provides_observation_string,
|
|
133
|
+
provides_observation_tensor=game_type.provides_observation_tensor,
|
|
134
|
+
parameter_specification=game_type.parameter_specification,
|
|
135
|
+
default_loadable=game_type.default_loadable,
|
|
136
|
+
provides_factored_observation_string=game_type.provides_factored_observation_string,
|
|
137
|
+
)
|
|
138
|
+
kwargs.update(**overrides)
|
|
139
|
+
return pyspiel.GameType(**kwargs)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Proxy tests."""
|
|
2
|
+
|
|
3
|
+
from . import proxy
|
|
4
|
+
from absl.testing import absltest
|
|
5
|
+
from absl.testing import parameterized
|
|
6
|
+
import pyspiel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def make_game() -> proxy.Game:
|
|
10
|
+
return proxy.Game(pyspiel.load_game('tic_tac_toe()'))
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestState(proxy.State):
|
|
14
|
+
|
|
15
|
+
def __str__(self) -> str:
|
|
16
|
+
return 'TestState: ' + super().__str__()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestGame(proxy.Game):
|
|
20
|
+
|
|
21
|
+
def new_initial_state(self, *args, **kwargs) -> TestState:
|
|
22
|
+
return TestState(
|
|
23
|
+
self.__wrapped__.new_initial_state(*args, **kwargs), game=self
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ProxiesTest(parameterized.TestCase):
|
|
28
|
+
|
|
29
|
+
def test_types(self):
|
|
30
|
+
game = make_game()
|
|
31
|
+
self.assertIsInstance(game, pyspiel.Game)
|
|
32
|
+
state = game.new_initial_state()
|
|
33
|
+
self.assertIsInstance(state, pyspiel.State)
|
|
34
|
+
|
|
35
|
+
def test_get_game(self):
|
|
36
|
+
game = make_game()
|
|
37
|
+
state = game.new_initial_state()
|
|
38
|
+
self.assertIsInstance(state.get_game(), proxy.Game)
|
|
39
|
+
new_state = state.get_game().new_initial_state()
|
|
40
|
+
self.assertIsInstance(new_state, proxy.State)
|
|
41
|
+
self.assertIsNot(new_state, state)
|
|
42
|
+
|
|
43
|
+
def test_clone(self):
|
|
44
|
+
game = make_game()
|
|
45
|
+
state = game.new_initial_state()
|
|
46
|
+
state.apply_action(state.legal_actions()[0])
|
|
47
|
+
clone = state.clone()
|
|
48
|
+
self.assertIsInstance(clone, proxy.State)
|
|
49
|
+
self.assertEqual(state.history(), clone.history())
|
|
50
|
+
clone.apply_action(clone.legal_actions()[0])
|
|
51
|
+
self.assertEqual(state.history(), clone.history()[:-1])
|
|
52
|
+
|
|
53
|
+
def test_subclassing(self):
|
|
54
|
+
game = TestGame(pyspiel.load_game('tic_tac_toe()'))
|
|
55
|
+
state = game.new_initial_state()
|
|
56
|
+
self.assertIsInstance(state, TestState)
|
|
57
|
+
self.assertIsInstance(state.clone(), TestState)
|
|
58
|
+
self.assertIsInstance(state.get_game(), TestGame)
|
|
59
|
+
wrapped_state = state.__wrapped__ # type: ignore
|
|
60
|
+
self.assertEqual(str(state), 'TestState: ' + str(wrapped_state))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
if __name__ == '__main__':
|
|
64
|
+
absltest.main()
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
kaggle_environments/__init__.py,sha256=
|
|
1
|
+
kaggle_environments/__init__.py,sha256=rkTFEuN8ZVrkb3vR3ravIenS9gIZBIjcnlUoA_zi7ic,2189
|
|
2
2
|
kaggle_environments/agent.py,sha256=j9rLnCK_Gy0eRIuvlJ9vcMh3vxn-Wvu-pjCpannOolc,6703
|
|
3
3
|
kaggle_environments/api.py,sha256=eLBKqr11Ku4tdsMUdUqy74FIVEA_hdV3_QUpX84x3Z8,798
|
|
4
4
|
kaggle_environments/core.py,sha256=IrEkN9cIA2djBAxI8Sz1GRpGNKjhqbnBdV6irAeTm8Q,27851
|
|
@@ -188,8 +188,17 @@ kaggle_environments/envs/mab/agents.py,sha256=vPHNN5oRcbTG3FaW9iYmoeQjufXFJMjYOL
|
|
|
188
188
|
kaggle_environments/envs/mab/mab.js,sha256=zsKGVRL9qFyUoukRj-ES5dOh8Wig7UzNf0z5Potw84E,3256
|
|
189
189
|
kaggle_environments/envs/mab/mab.json,sha256=VAlpjJ7_ytYO648swQW_ICjC5JKTAdmnShuGggeSX4A,2077
|
|
190
190
|
kaggle_environments/envs/mab/mab.py,sha256=bkSIxkstS98Vr3eOA9kxQkseDqa1MlG2Egfzeaf-8EA,5241
|
|
191
|
-
kaggle_environments/envs/open_spiel/
|
|
192
|
-
kaggle_environments/envs/open_spiel/
|
|
191
|
+
kaggle_environments/envs/open_spiel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
192
|
+
kaggle_environments/envs/open_spiel/observation.py,sha256=yrJ_iZ9sBUTB6YOyEpKNwYiQEWmsPPtaDYtL4zsw1Ko,4834
|
|
193
|
+
kaggle_environments/envs/open_spiel/open_spiel.py,sha256=UMH_flpyKIBAyC40iU8HUip-COq9AqulBLjTuZujrik,14038
|
|
194
|
+
kaggle_environments/envs/open_spiel/proxy.py,sha256=8Shane4KWYKvbP9nV3l8VQfAFOfFSUrS78h_4xQthVM,4881
|
|
195
|
+
kaggle_environments/envs/open_spiel/proxy_test.py,sha256=QkmRo_uS0DgDDm2pbU2vwal5KOMCWKw92rC2_g3MziM,1837
|
|
196
|
+
kaggle_environments/envs/open_spiel/test_open_spiel.py,sha256=55oTMpGiK401rtZTiqKj5QBW0_UNYDdJqPF5YJyewUc,541
|
|
197
|
+
kaggle_environments/envs/open_spiel/games/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
198
|
+
kaggle_environments/envs/open_spiel/games/connect_four/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
199
|
+
kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js,sha256=sT8rVfjkyY_DUQSHpZUP7I7QEoJwa7Oq9H82wD7vnbQ,15440
|
|
200
|
+
kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py,sha256=2otG99felDYhNhWpsadbM9YUaHtrXqhV1GFNEHhuPwA,2348
|
|
201
|
+
kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy_test.py,sha256=vYn-QDPyRRigL8XdaHMN4FTO9zc1T9YB096HDGQH_T4,1870
|
|
193
202
|
kaggle_environments/envs/rps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
194
203
|
kaggle_environments/envs/rps/agents.py,sha256=iBtBjPbutWickm-K1EzBEYvLWj5fvU3ks0AYQMYWgEI,2140
|
|
195
204
|
kaggle_environments/envs/rps/helpers.py,sha256=NUqhJafNSzlC_ArwDIYzbLx15pkmBpzfVuG8Iv4wX9U,966
|
|
@@ -203,9 +212,9 @@ kaggle_environments/envs/tictactoe/tictactoe.js,sha256=NZDT-oSG0a6a-rso9Ldh9qkJw
|
|
|
203
212
|
kaggle_environments/envs/tictactoe/tictactoe.json,sha256=zMXZ8-fpT7FBhzz2FFBvRLn4XwtngjEqOieMvI6cCj8,1121
|
|
204
213
|
kaggle_environments/envs/tictactoe/tictactoe.py,sha256=uq3sTHWNMg0dxX2v9pTbJAKM7fwerxQt7OQjCX96m-Y,3657
|
|
205
214
|
kaggle_environments/static/player.html,sha256=XyVoe0XxMa2MO1fTDY_rjyjzPN-JZgbVwJIDoLSnlw0,23016
|
|
206
|
-
kaggle_environments-1.17.
|
|
207
|
-
kaggle_environments-1.17.
|
|
208
|
-
kaggle_environments-1.17.
|
|
209
|
-
kaggle_environments-1.17.
|
|
210
|
-
kaggle_environments-1.17.
|
|
211
|
-
kaggle_environments-1.17.
|
|
215
|
+
kaggle_environments-1.17.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
216
|
+
kaggle_environments-1.17.3.dist-info/METADATA,sha256=3yV9E05vePRsHqStWawS_yKoWzNG59fK5Ak4sbpPAYU,10955
|
|
217
|
+
kaggle_environments-1.17.3.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
|
|
218
|
+
kaggle_environments-1.17.3.dist-info/entry_points.txt,sha256=HbVC-LKGQFV6lEEYBYyDTtrkHgdHJUWQ8_qt9KHGqz4,70
|
|
219
|
+
kaggle_environments-1.17.3.dist-info/top_level.txt,sha256=v3MMWIPMQFcI-WuF_dJngHWe9Bb2yH_6p4wat1x4gAc,20
|
|
220
|
+
kaggle_environments-1.17.3.dist-info/RECORD,,
|
|
File without changes
|
{kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{kaggle_environments-1.17.2.dist-info → kaggle_environments-1.17.3.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|