kaggle-environments 1.17.11__py2.py3-none-any.whl → 1.18.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +1 -1
- kaggle_environments/api.py +5 -13
- kaggle_environments/envs/cabt/cabt.js +164 -0
- kaggle_environments/envs/cabt/cabt.json +28 -0
- kaggle_environments/envs/cabt/cabt.py +119 -0
- kaggle_environments/envs/cabt/cg/__init__.py +0 -0
- kaggle_environments/envs/cabt/cg/cg.dll +0 -0
- kaggle_environments/envs/cabt/cg/game.py +70 -0
- kaggle_environments/envs/cabt/cg/libcg.so +0 -0
- kaggle_environments/envs/cabt/cg/sim.py +44 -0
- kaggle_environments/envs/open_spiel/games/chess/chess.js +25 -22
- kaggle_environments/envs/open_spiel/open_spiel.py +53 -1
- kaggle_environments/envs/open_spiel/test_open_spiel.py +85 -1
- kaggle_environments/helpers.py +126 -86
- kaggle_environments/main.py +29 -44
- kaggle_environments/static/player.html +84 -37
- kaggle_environments/utils.py +8 -12
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/METADATA +2 -71
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/RECORD +23 -15
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/WHEEL +1 -1
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/licenses/LICENSE +0 -0
- {kaggle_environments-1.17.11.dist-info → kaggle_environments-1.18.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Kaggle environment wrapper for OpenSpiel games."""
|
|
2
2
|
|
|
3
3
|
import copy
|
|
4
|
+
import json
|
|
4
5
|
import importlib
|
|
5
6
|
import logging
|
|
6
7
|
import os
|
|
@@ -51,7 +52,7 @@ DEFAULT_INVALID_ACTION_REWARD = -1
|
|
|
51
52
|
AGENT_ERROR_ACTION = -2
|
|
52
53
|
|
|
53
54
|
DEFAULT_ACT_TIMEOUT = 60 * 60 # sixty minutes
|
|
54
|
-
DEFAULT_RUN_TIMEOUT = 60 * 60 *
|
|
55
|
+
DEFAULT_RUN_TIMEOUT = 60 * 60 * 48 # forty-eight hours
|
|
55
56
|
# Buffer in addition to max game length to account for timeouts, retries, etc.
|
|
56
57
|
DEFAULT_STEP_BUFFER = 100
|
|
57
58
|
# TODO(jhtschultz): Add individual game descriptions.
|
|
@@ -80,6 +81,22 @@ CONFIGURATION_SPEC_TEMPLATE = {
|
|
|
80
81
|
"type": "object",
|
|
81
82
|
"default": {}
|
|
82
83
|
},
|
|
84
|
+
"useOpenings": {
|
|
85
|
+
"description": "Whether to start from a position in an opening book.",
|
|
86
|
+
"type": "boolean",
|
|
87
|
+
"default": False
|
|
88
|
+
},
|
|
89
|
+
"seed": {
|
|
90
|
+
"description": "Integer currently only used for selecting starting position.",
|
|
91
|
+
"type": "number",
|
|
92
|
+
},
|
|
93
|
+
"initialActions": {
|
|
94
|
+
"description": "Actions applied to initial state before play begins to set up starting position.",
|
|
95
|
+
"type": "array",
|
|
96
|
+
"items": {
|
|
97
|
+
"type": "integer"
|
|
98
|
+
},
|
|
99
|
+
},
|
|
83
100
|
"metadata": {
|
|
84
101
|
"description": "Arbitrary metadata.",
|
|
85
102
|
"type": "object",
|
|
@@ -159,8 +176,35 @@ ENV_SPEC_TEMPLATE = {
|
|
|
159
176
|
}
|
|
160
177
|
|
|
161
178
|
|
|
179
|
+
def _get_initial_actions(
|
|
180
|
+
configuration: dict[str, Any],
|
|
181
|
+
) -> tuple[list[int], dict[str, Any]]:
|
|
182
|
+
initial_actions = configuration.get("initialActions", [])
|
|
183
|
+
if initial_actions:
|
|
184
|
+
if configuration.get("useOpenings"):
|
|
185
|
+
raise ValueError("Cannot set both useOpenings and initialActions.")
|
|
186
|
+
else:
|
|
187
|
+
return initial_actions, {}
|
|
188
|
+
if not configuration.get("useOpenings"):
|
|
189
|
+
return [], {}
|
|
190
|
+
seed = configuration.get("seed", None)
|
|
191
|
+
if seed is None:
|
|
192
|
+
raise ValueError("Must provide seed if useOpenings is True.")
|
|
193
|
+
openings_path = pathlib.Path(
|
|
194
|
+
GAMES_DIR, configuration.get("openSpielGameName"), "openings.jsonl",
|
|
195
|
+
)
|
|
196
|
+
if not openings_path.is_file():
|
|
197
|
+
raise ValueError(f"No opening file found at {openings_path}")
|
|
198
|
+
with open(openings_path, "r", encoding="utf-8") as f:
|
|
199
|
+
openings = f.readlines()
|
|
200
|
+
opening = json.loads(openings[seed % len(openings)])
|
|
201
|
+
initial_actions = opening.pop("initialActions")
|
|
202
|
+
return initial_actions, opening
|
|
203
|
+
|
|
204
|
+
|
|
162
205
|
# --- Core step logic ---
|
|
163
206
|
|
|
207
|
+
|
|
164
208
|
def interpreter(
|
|
165
209
|
state: list[utils.Struct],
|
|
166
210
|
env: core.Environment,
|
|
@@ -185,6 +229,14 @@ def interpreter(
|
|
|
185
229
|
env.info['stateHistory'] = [str(env.os_state)]
|
|
186
230
|
env.info['actionHistory'] = []
|
|
187
231
|
env.info['moveDurations'] = []
|
|
232
|
+
initial_actions, metadata = _get_initial_actions(env.configuration)
|
|
233
|
+
if initial_actions:
|
|
234
|
+
env.info["initialActions"] = initial_actions
|
|
235
|
+
env.info["openingMetadata"] = metadata
|
|
236
|
+
for action in initial_actions:
|
|
237
|
+
env.os_state.apply_action(action)
|
|
238
|
+
env.info["actionHistory"].append(str(action))
|
|
239
|
+
env.info["stateHistory"].append(str(env.os_state))
|
|
188
240
|
|
|
189
241
|
os_game = env.os_game
|
|
190
242
|
os_state = env.os_state
|
|
@@ -1,5 +1,8 @@
|
|
|
1
|
-
|
|
1
|
+
import json
|
|
2
|
+
import pathlib
|
|
2
3
|
import sys
|
|
4
|
+
|
|
5
|
+
from absl.testing import absltest
|
|
3
6
|
from kaggle_environments import make
|
|
4
7
|
import pyspiel
|
|
5
8
|
from . import open_spiel as open_spiel_env
|
|
@@ -91,6 +94,87 @@ class OpenSpielEnvTest(absltest.TestCase):
|
|
|
91
94
|
self.assertEqual(json["rewards"], [None, None])
|
|
92
95
|
self.assertEqual(json["statuses"], ["ERROR", "ERROR"])
|
|
93
96
|
|
|
97
|
+
def test_initial_actions(self):
|
|
98
|
+
open_spiel_env._register_game_envs(["tic_tac_toe"])
|
|
99
|
+
env = make(
|
|
100
|
+
"open_spiel_tic_tac_toe",
|
|
101
|
+
{"initialActions": [0, 1, 3, 4]},
|
|
102
|
+
debug=True,
|
|
103
|
+
)
|
|
104
|
+
env.reset()
|
|
105
|
+
# Setup step
|
|
106
|
+
env.step([
|
|
107
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
108
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
109
|
+
])
|
|
110
|
+
env.step([
|
|
111
|
+
{"submission": 2},
|
|
112
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
113
|
+
])
|
|
114
|
+
env.step([
|
|
115
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
116
|
+
{"submission": 7},
|
|
117
|
+
])
|
|
118
|
+
self.assertTrue(env.done)
|
|
119
|
+
json_playthrough = env.toJSON()
|
|
120
|
+
self.assertEqual(json_playthrough["rewards"], [-1, 1])
|
|
121
|
+
|
|
122
|
+
def test_chess_openings_manually_configured(self):
|
|
123
|
+
open_spiel_env._register_game_envs(["chess"])
|
|
124
|
+
openings_path = pathlib.Path(
|
|
125
|
+
open_spiel_env.GAMES_DIR,
|
|
126
|
+
"chess/openings.jsonl",
|
|
127
|
+
)
|
|
128
|
+
self.assertTrue(openings_path.is_file())
|
|
129
|
+
with open(openings_path, "r", encoding="utf-8") as f:
|
|
130
|
+
for line in f:
|
|
131
|
+
opening = json.loads(line)
|
|
132
|
+
config = {
|
|
133
|
+
"initialActions": opening.pop("initialActions"),
|
|
134
|
+
"metadata": opening,
|
|
135
|
+
}
|
|
136
|
+
env = make(
|
|
137
|
+
"open_spiel_chess",
|
|
138
|
+
config,
|
|
139
|
+
debug=True,
|
|
140
|
+
)
|
|
141
|
+
env.reset()
|
|
142
|
+
# Setup step
|
|
143
|
+
env.step([
|
|
144
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
145
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
146
|
+
])
|
|
147
|
+
obs = env.state[0]["observation"]
|
|
148
|
+
_, state = pyspiel.deserialize_game_and_state(
|
|
149
|
+
obs["serializedGameAndState"]
|
|
150
|
+
)
|
|
151
|
+
self.assertEqual(str(state), opening["fen"])
|
|
152
|
+
self.assertEqual(str(state),
|
|
153
|
+
env.toJSON()["configuration"]["metadata"]["fen"])
|
|
154
|
+
|
|
155
|
+
def test_chess_openings_configured_with_seed(self):
|
|
156
|
+
open_spiel_env._register_game_envs(["chess"])
|
|
157
|
+
config = {
|
|
158
|
+
"useOpenings": True,
|
|
159
|
+
"seed": 0,
|
|
160
|
+
}
|
|
161
|
+
env = make(
|
|
162
|
+
"open_spiel_chess",
|
|
163
|
+
config,
|
|
164
|
+
debug=True,
|
|
165
|
+
)
|
|
166
|
+
env.reset()
|
|
167
|
+
# Setup step
|
|
168
|
+
env.step([
|
|
169
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
170
|
+
{"submission": pyspiel.INVALID_ACTION},
|
|
171
|
+
])
|
|
172
|
+
obs = env.state[0]["observation"]
|
|
173
|
+
game, state = pyspiel.deserialize_game_and_state(
|
|
174
|
+
obs["serializedGameAndState"]
|
|
175
|
+
)
|
|
176
|
+
# Check that selected opening state does not equal standard start state.
|
|
177
|
+
self.assertNotEqual(str(state), str(game.new_initial_state()))
|
|
94
178
|
|
|
95
179
|
if __name__ == '__main__':
|
|
96
180
|
absltest.main()
|
kaggle_environments/helpers.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
import operator
|
|
2
|
-
import math
|
|
3
|
-
from enum import Enum, auto
|
|
4
2
|
import random
|
|
5
|
-
from
|
|
3
|
+
from enum import Enum, auto
|
|
4
|
+
from typing import Any, Callable, Dict, Generic, Iterable, List, Tuple, Type, TypeVar, Union
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
class Point(tuple):
|
|
@@ -14,7 +13,8 @@ class Point(tuple):
|
|
|
14
13
|
Note that operators in this class do not constrain points to the board.
|
|
15
14
|
You can generally constrain a point to the board by calling point % board.configuration.size.
|
|
16
15
|
"""
|
|
17
|
-
|
|
16
|
+
|
|
17
|
+
def __new__(cls: Type["Point"], x: int, y: int):
|
|
18
18
|
return super(Point, cls).__new__(cls, tuple((x, y)))
|
|
19
19
|
|
|
20
20
|
@property
|
|
@@ -25,22 +25,22 @@ class Point(tuple):
|
|
|
25
25
|
def y(self):
|
|
26
26
|
return self[1]
|
|
27
27
|
|
|
28
|
-
def map(self, f: Callable[[int], int]) ->
|
|
28
|
+
def map(self, f: Callable[[int], int]) -> "Point":
|
|
29
29
|
return Point(f(self[0]), f(self[1]))
|
|
30
30
|
|
|
31
|
-
def map2(self, other: Union[Tuple[int, int],
|
|
31
|
+
def map2(self, other: Union[Tuple[int, int], "Point"], f: Callable[[int, int], int]) -> "Point":
|
|
32
32
|
return Point(f(self[0], other[0]), f(self[1], other[1]))
|
|
33
33
|
|
|
34
|
-
def translate(self, offset:
|
|
34
|
+
def translate(self, offset: "Point", size: int):
|
|
35
35
|
"""Translates the current point by offset and wraps the result around a square board whose width and height are both size."""
|
|
36
36
|
return (self + offset) % size
|
|
37
37
|
|
|
38
|
-
def distance_to(self, other:
|
|
38
|
+
def distance_to(self, other: "Point", size: int):
|
|
39
39
|
"""Computes total distance (manhattan) to travel to other Point"""
|
|
40
40
|
abs_x = abs(self.x - other.x)
|
|
41
|
-
dist_x = abs_x if abs_x < size/2 else size - abs_x
|
|
41
|
+
dist_x = abs_x if abs_x < size / 2 else size - abs_x
|
|
42
42
|
abs_y = abs(self.y - other.y)
|
|
43
|
-
dist_y = abs_y if abs_y < size/2 else size - abs_y
|
|
43
|
+
dist_y = abs_y if abs_y < size / 2 else size - abs_y
|
|
44
44
|
return dist_x + dist_y
|
|
45
45
|
|
|
46
46
|
def to_index(self, size: int):
|
|
@@ -51,7 +51,7 @@ class Point(tuple):
|
|
|
51
51
|
return (size - self.y - 1) * size + self.x
|
|
52
52
|
|
|
53
53
|
@staticmethod
|
|
54
|
-
def from_index(index: int, size: int) ->
|
|
54
|
+
def from_index(index: int, size: int) -> "Point":
|
|
55
55
|
"""
|
|
56
56
|
Converts an index in the observation.halite list to a 2d position in the form (x, y).
|
|
57
57
|
See Point method to_index for the inverse.
|
|
@@ -59,37 +59,37 @@ class Point(tuple):
|
|
|
59
59
|
y, x = divmod(index, size)
|
|
60
60
|
return Point(x, (size - y - 1))
|
|
61
61
|
|
|
62
|
-
def __abs__(self) ->
|
|
62
|
+
def __abs__(self) -> "Point":
|
|
63
63
|
return self.map(operator.abs)
|
|
64
64
|
|
|
65
|
-
def __add__(self, other: Union[Tuple[int, int],
|
|
65
|
+
def __add__(self, other: Union[Tuple[int, int], "Point"]) -> "Point":
|
|
66
66
|
return self.map2(other, operator.add)
|
|
67
67
|
|
|
68
|
-
def __eq__(self, other: Union[Tuple[int, int],
|
|
68
|
+
def __eq__(self, other: Union[Tuple[int, int], "Point"]) -> bool:
|
|
69
69
|
try:
|
|
70
70
|
return self[0] == other[0] and self[1] == other[1]
|
|
71
71
|
except (TypeError, IndexError):
|
|
72
72
|
return False
|
|
73
73
|
|
|
74
|
-
def __floordiv__(self, denominator: int) ->
|
|
74
|
+
def __floordiv__(self, denominator: int) -> "Point":
|
|
75
75
|
return self.map(lambda x: x // denominator)
|
|
76
76
|
|
|
77
77
|
def __hash__(self) -> int:
|
|
78
78
|
return hash((self.x, self.y))
|
|
79
79
|
|
|
80
|
-
def __mod__(self, mod: int) ->
|
|
80
|
+
def __mod__(self, mod: int) -> "Point":
|
|
81
81
|
return self.map(lambda x: x % mod)
|
|
82
82
|
|
|
83
|
-
def __mul__(self, factor: int) ->
|
|
83
|
+
def __mul__(self, factor: int) -> "Point":
|
|
84
84
|
return self.map(lambda x: x * factor)
|
|
85
85
|
|
|
86
|
-
def __neg__(self) ->
|
|
86
|
+
def __neg__(self) -> "Point":
|
|
87
87
|
return self.map(operator.neg)
|
|
88
88
|
|
|
89
89
|
def __str__(self):
|
|
90
90
|
return f"({self.x}, {self.y})"
|
|
91
91
|
|
|
92
|
-
def __sub__(self, other: Union[Tuple[int, int],
|
|
92
|
+
def __sub__(self, other: Union[Tuple[int, int], "Point"]) -> "Point":
|
|
93
93
|
return self.map2(other, operator.sub)
|
|
94
94
|
|
|
95
95
|
|
|
@@ -101,18 +101,22 @@ class Direction(Enum):
|
|
|
101
101
|
|
|
102
102
|
def to_point(self) -> Point:
|
|
103
103
|
"""
|
|
104
|
-
This returns the position offset associated with a particular action
|
|
104
|
+
This returns the position offset associated with a particular action
|
|
105
105
|
NORTH -> (0, 1)
|
|
106
106
|
EAST -> (1, 0)
|
|
107
107
|
SOUTH -> (0, -1)
|
|
108
108
|
WEST -> (-1, 0)
|
|
109
109
|
"""
|
|
110
110
|
return (
|
|
111
|
-
Point(0, 1)
|
|
112
|
-
|
|
113
|
-
Point(
|
|
114
|
-
|
|
115
|
-
|
|
111
|
+
Point(0, 1)
|
|
112
|
+
if self == Direction.NORTH
|
|
113
|
+
else Point(1, 0)
|
|
114
|
+
if self == Direction.EAST
|
|
115
|
+
else Point(0, -1)
|
|
116
|
+
if self == Direction.SOUTH
|
|
117
|
+
else Point(-1, 0)
|
|
118
|
+
if self == Direction.WEST
|
|
119
|
+
else None
|
|
116
120
|
)
|
|
117
121
|
|
|
118
122
|
def __str__(self) -> str:
|
|
@@ -120,93 +124,125 @@ class Direction(Enum):
|
|
|
120
124
|
|
|
121
125
|
def to_index(self) -> int:
|
|
122
126
|
return (
|
|
123
|
-
0
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
127
|
+
0
|
|
128
|
+
if self == Direction.NORTH
|
|
129
|
+
else 1
|
|
130
|
+
if self == Direction.EAST
|
|
131
|
+
else 2
|
|
132
|
+
if self == Direction.SOUTH
|
|
133
|
+
else 3
|
|
134
|
+
if self == Direction.WEST
|
|
135
|
+
else None
|
|
128
136
|
)
|
|
129
137
|
|
|
130
138
|
def to_char(self) -> str:
|
|
131
139
|
return (
|
|
132
|
-
"N"
|
|
133
|
-
|
|
134
|
-
"
|
|
135
|
-
|
|
136
|
-
|
|
140
|
+
"N"
|
|
141
|
+
if self == Direction.NORTH
|
|
142
|
+
else "E"
|
|
143
|
+
if self == Direction.EAST
|
|
144
|
+
else "S"
|
|
145
|
+
if self == Direction.SOUTH
|
|
146
|
+
else "W"
|
|
147
|
+
if self == Direction.WEST
|
|
148
|
+
else None
|
|
137
149
|
)
|
|
138
150
|
|
|
139
|
-
def opposite(self) ->
|
|
151
|
+
def opposite(self) -> "Direction":
|
|
140
152
|
return (
|
|
141
|
-
Direction.SOUTH
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
153
|
+
Direction.SOUTH
|
|
154
|
+
if self == Direction.NORTH
|
|
155
|
+
else Direction.WEST
|
|
156
|
+
if self == Direction.EAST
|
|
157
|
+
else Direction.NORTH
|
|
158
|
+
if self == Direction.SOUTH
|
|
159
|
+
else Direction.EAST
|
|
160
|
+
if self == Direction.WEST
|
|
161
|
+
else None
|
|
146
162
|
)
|
|
147
163
|
|
|
148
|
-
def rotate_left(self) ->
|
|
164
|
+
def rotate_left(self) -> "Direction":
|
|
149
165
|
return (
|
|
150
|
-
Direction.WEST
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
166
|
+
Direction.WEST
|
|
167
|
+
if self == Direction.NORTH
|
|
168
|
+
else Direction.NORTH
|
|
169
|
+
if self == Direction.EAST
|
|
170
|
+
else Direction.EAST
|
|
171
|
+
if self == Direction.SOUTH
|
|
172
|
+
else Direction.SOUTH
|
|
173
|
+
if self == Direction.WEST
|
|
174
|
+
else None
|
|
155
175
|
)
|
|
156
176
|
|
|
157
|
-
def rotate_right(self) ->
|
|
177
|
+
def rotate_right(self) -> "Direction":
|
|
158
178
|
return (
|
|
159
|
-
Direction.EAST
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
179
|
+
Direction.EAST
|
|
180
|
+
if self == Direction.NORTH
|
|
181
|
+
else Direction.SOUTH
|
|
182
|
+
if self == Direction.EAST
|
|
183
|
+
else Direction.WEST
|
|
184
|
+
if self == Direction.SOUTH
|
|
185
|
+
else Direction.NORTH
|
|
186
|
+
if self == Direction.WEST
|
|
187
|
+
else None
|
|
164
188
|
)
|
|
165
189
|
|
|
166
190
|
@staticmethod
|
|
167
191
|
def from_str(str_dir: str):
|
|
168
|
-
return
|
|
169
|
-
Direction.NORTH
|
|
170
|
-
|
|
171
|
-
Direction.
|
|
172
|
-
|
|
173
|
-
|
|
192
|
+
return (
|
|
193
|
+
Direction.NORTH
|
|
194
|
+
if str_dir == "NORTH"
|
|
195
|
+
else Direction.EAST
|
|
196
|
+
if str_dir == "EAST"
|
|
197
|
+
else Direction.SOUTH
|
|
198
|
+
if str_dir == "SOUTH"
|
|
199
|
+
else Direction.WEST
|
|
200
|
+
if str_dir == "WEST"
|
|
201
|
+
else None
|
|
174
202
|
)
|
|
175
203
|
|
|
176
204
|
@staticmethod
|
|
177
205
|
def from_char(str_char: str):
|
|
178
|
-
return
|
|
179
|
-
Direction.NORTH
|
|
180
|
-
|
|
181
|
-
Direction.
|
|
182
|
-
|
|
183
|
-
|
|
206
|
+
return (
|
|
207
|
+
Direction.NORTH
|
|
208
|
+
if str_char == "N"
|
|
209
|
+
else Direction.EAST
|
|
210
|
+
if str_char == "E"
|
|
211
|
+
else Direction.SOUTH
|
|
212
|
+
if str_char == "S"
|
|
213
|
+
else Direction.WEST
|
|
214
|
+
if str_char == "W"
|
|
215
|
+
else None
|
|
184
216
|
)
|
|
185
217
|
|
|
186
218
|
@staticmethod
|
|
187
219
|
def from_index(idx: int):
|
|
188
220
|
return (
|
|
189
|
-
Direction.NORTH
|
|
190
|
-
|
|
191
|
-
Direction.
|
|
192
|
-
|
|
193
|
-
|
|
221
|
+
Direction.NORTH
|
|
222
|
+
if idx == 0
|
|
223
|
+
else Direction.EAST
|
|
224
|
+
if idx == 1
|
|
225
|
+
else Direction.SOUTH
|
|
226
|
+
if idx == 2
|
|
227
|
+
else Direction.WEST
|
|
228
|
+
if idx == 3
|
|
229
|
+
else None
|
|
194
230
|
)
|
|
195
231
|
|
|
196
232
|
@staticmethod
|
|
197
|
-
def random_direction() ->
|
|
233
|
+
def random_direction() -> "Direction":
|
|
198
234
|
rand = random.random()
|
|
199
|
-
if rand <= .25:
|
|
235
|
+
if rand <= 0.25:
|
|
200
236
|
return Direction.NORTH
|
|
201
|
-
elif rand <= .5:
|
|
237
|
+
elif rand <= 0.5:
|
|
202
238
|
return Direction.EAST
|
|
203
|
-
elif rand <= .75:
|
|
239
|
+
elif rand <= 0.75:
|
|
204
240
|
return Direction.SOUTH
|
|
205
241
|
else:
|
|
206
242
|
return Direction.WEST
|
|
207
243
|
|
|
208
244
|
@staticmethod
|
|
209
|
-
def list_directions() -> List[
|
|
245
|
+
def list_directions() -> List["Direction"]:
|
|
210
246
|
return [
|
|
211
247
|
Direction.NORTH,
|
|
212
248
|
Direction.EAST,
|
|
@@ -215,8 +251,8 @@ class Direction(Enum):
|
|
|
215
251
|
]
|
|
216
252
|
|
|
217
253
|
|
|
218
|
-
TItem = TypeVar(
|
|
219
|
-
THash = TypeVar(
|
|
254
|
+
TItem = TypeVar("TItem")
|
|
255
|
+
THash = TypeVar("THash")
|
|
220
256
|
|
|
221
257
|
|
|
222
258
|
def group_by(items: Iterable[TItem], selector: Callable[[TItem], THash]) -> Dict[THash, List[TItem]]:
|
|
@@ -250,6 +286,7 @@ class Observation(Dict[str, any]):
|
|
|
250
286
|
"""
|
|
251
287
|
Observation provides access to per-step parameters in the environment.
|
|
252
288
|
"""
|
|
289
|
+
|
|
253
290
|
@property
|
|
254
291
|
def step(self) -> int:
|
|
255
292
|
"""Current step within the episode."""
|
|
@@ -265,6 +302,7 @@ class Configuration(Dict[str, any]):
|
|
|
265
302
|
"""
|
|
266
303
|
Configuration provides access to tunable parameters in the environment.
|
|
267
304
|
"""
|
|
305
|
+
|
|
268
306
|
@property
|
|
269
307
|
def episode_steps(self) -> int:
|
|
270
308
|
"""Total number of steps/turns in the run."""
|
|
@@ -281,9 +319,9 @@ class Configuration(Dict[str, any]):
|
|
|
281
319
|
return self["runTimeout"]
|
|
282
320
|
|
|
283
321
|
|
|
284
|
-
TConfiguration = TypeVar(
|
|
285
|
-
TObservation = TypeVar(
|
|
286
|
-
TAction = TypeVar(
|
|
322
|
+
TConfiguration = TypeVar("TConfiguration", bound=Configuration)
|
|
323
|
+
TObservation = TypeVar("TObservation", bound=Observation)
|
|
324
|
+
TAction = TypeVar("TAction")
|
|
287
325
|
Agent = Callable[[TObservation, TConfiguration], TAction]
|
|
288
326
|
|
|
289
327
|
|
|
@@ -319,17 +357,19 @@ class AgentState(Generic[TObservation, TAction], Dict[str, any]):
|
|
|
319
357
|
|
|
320
358
|
class Environment(Generic[TConfiguration, TObservation, TAction]):
|
|
321
359
|
@property
|
|
322
|
-
def specification(self) -> Dict[str,
|
|
323
|
-
raise
|
|
360
|
+
def specification(self) -> Dict[str, Any]:
|
|
361
|
+
raise NotImplementedError()
|
|
324
362
|
|
|
325
|
-
def interpret(
|
|
326
|
-
|
|
363
|
+
def interpret(
|
|
364
|
+
self, configuration: TConfiguration, state: List[AgentState[TObservation, TAction]]
|
|
365
|
+
) -> List[AgentState[TObservation, TAction]]:
|
|
366
|
+
raise NotImplementedError()
|
|
327
367
|
|
|
328
368
|
def render_html(self, configuration: TConfiguration, state: List[AgentState[TObservation, TAction]]) -> str:
|
|
329
|
-
raise
|
|
369
|
+
raise NotImplementedError()
|
|
330
370
|
|
|
331
371
|
def render_text(self, configuration: TConfiguration, state: List[AgentState[TObservation, TAction]]) -> str:
|
|
332
|
-
raise
|
|
372
|
+
raise NotImplementedError()
|
|
333
373
|
|
|
334
374
|
@property
|
|
335
375
|
def builtin_agents(self) -> Dict[str, Agent]:
|