kaggle-environments 1.15.3__py2.py3-none-any.whl → 1.16.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (31) hide show
  1. kaggle_environments/__init__.py +1 -1
  2. kaggle_environments/envs/chess/chess.js +15 -0
  3. kaggle_environments/envs/chess/chess.json +5 -4
  4. kaggle_environments/envs/chess/chess.py +206 -51
  5. kaggle_environments/envs/chess/test_chess.py +43 -1
  6. kaggle_environments/envs/lux_ai_s3/agents.py +4 -0
  7. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  8. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  9. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +138 -0
  10. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  11. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +924 -0
  12. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +13 -0
  13. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  14. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +140 -0
  15. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +270 -0
  16. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +30 -0
  17. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +399 -0
  18. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  19. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +187 -0
  20. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +71 -0
  21. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  22. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +27 -0
  23. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  24. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +53 -0
  25. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  26. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/METADATA +2 -2
  27. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/RECORD +31 -11
  28. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/LICENSE +0 -0
  29. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/WHEEL +0 -0
  30. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/entry_points.txt +0 -0
  31. {kaggle_environments-1.15.3.dist-info → kaggle_environments-1.16.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,13 @@
1
+ import os
2
+
3
TERM_COLORS = True


def _colors_enabled(environ=None):
    """Return True unless the ``LUX_COLORS`` environment variable equals "False".

    Args:
        environ: mapping to read ``LUX_COLORS`` from; defaults to ``os.environ``.

    Returns:
        False only when ``LUX_COLORS`` is exactly the string "False"; True when
        it is unset or has any other value.
    """
    if environ is None:
        environ = os.environ
    return environ.get("LUX_COLORS", "True") != "False"


try:
    from termcolor import colored

    # Previously this evaluated to True when LUX_COLORS == "False" (inverted).
    TERM_COLORS = _colors_enabled()
except ImportError:
    # termcolor is optional; colors are simply disabled when it is missing.
    TERM_COLORS = False
@@ -0,0 +1,101 @@
1
+ from flax import struct
2
+ import jax
3
+
4
# Names of the available map generation algorithms. EnvParams.map_type is
# presumably an index into this list (default 1 -> "random") - TODO confirm.
MAP_TYPES = ["dev0", "random"]


@struct.dataclass
class EnvParams:
    """Configuration parameters for a Lux AI Season 3 environment.

    Several fields are documented as able to change between games;
    ``env_params_ranges`` in this module lists candidate values for a subset
    of them.
    """

    max_steps_in_match: int = 100
    map_type: int = 1
    """Map generation algorithm. Can change between games"""
    map_width: int = 24
    map_height: int = 24
    num_teams: int = 2
    match_count_per_episode: int = 5
    """number of matches to play in one episode"""

    # configs for units
    max_units: int = 16
    init_unit_energy: int = 100
    min_unit_energy: int = 0
    max_unit_energy: int = 400
    unit_move_cost: int = 2
    spawn_rate: int = 3

    unit_sap_cost: int = 10
    """
    The unit sap cost is the amount of energy a unit uses when it saps another unit. Can change between games.
    """
    unit_sap_range: int = 4
    """
    The unit sap range is the range of the unit's sap action.
    """
    unit_sap_dropoff_factor: float = 0.5
    """
    The unit sap dropoff factor, multiplied by the sap drain amount
    (presumably ``unit_sap_cost`` - TODO confirm against env.py).
    """
    unit_energy_void_factor: float = 0.125
    """
    The unit energy void factor multiplied by unit_energy
    """

    # configs for energy nodes
    max_energy_nodes: int = 6
    max_energy_per_tile: int = 20
    min_energy_per_tile: int = -20

    max_relic_nodes: int = 6
    relic_config_size: int = 5
    fog_of_war: bool = True
    """
    whether there is fog of war or not
    """
    unit_sensor_range: int = 2
    """
    The unit sensor range is the range of the unit's sensor.
    Units provide "vision power" over tiles in range, equal to manhattan distance to the unit.

    If a team's vision power over a tile is > 0, that team can see the tile's properties.
    """

    # nebula tile params
    nebula_tile_vision_reduction: int = 1
    """
    The nebula tile vision reduction is the amount of vision reduction a nebula tile provides.
    A tile can be seen if the vision power over it is > 0.
    """

    nebula_tile_energy_reduction: int = 0
    """amount of energy nebula tiles reduce from a unit"""

    nebula_tile_drift_speed: float = -0.05
    """
    how fast nebula tiles drift in one of the diagonal directions over time. If positive, flows to the top/right, negative flows to bottom/left
    """
    # TODO (stao): allow other kinds of symmetric drifts?

    # Annotated as float: both the default and the candidate values in
    # env_params_ranges are fractional (this was mis-annotated as int).
    energy_node_drift_speed: float = 0.02
    """
    how fast energy nodes will move around over time
    """
    energy_node_drift_magnitude: int = 5

    # option to change sap configurations
85
+
86
+
87
# Candidate values for individual EnvParams fields, keyed by the EnvParams
# attribute name (presumably sampled per game - confirm in env.py). Repeated
# entries, such as the two 0s for nebula_tile_energy_reduction, make that
# value proportionally more likely if sampling is uniform over the list.
env_params_ranges = {
    # "map_type": [1],
    "unit_move_cost": [1, 2, 3, 4, 5],
    "unit_sensor_range": [2, 3, 4],
    "nebula_tile_vision_reduction": [0, 1, 2, 3],
    "nebula_tile_energy_reduction": [0, 0, 10, 25],
    "unit_sap_cost": list(range(30, 51)),
    "unit_sap_range": [3, 4, 5, 6, 7],
    "unit_sap_dropoff_factor": [0.25, 0.5, 1],
    "unit_energy_void_factor": [0.0625, 0.125, 0.25, 0.375],
    # map randomizations
    "nebula_tile_drift_speed": [-0.05, -0.025, 0.025, 0.05],
    "energy_node_drift_speed": [0.01, 0.02, 0.03, 0.04, 0.05],
    "energy_node_drift_magnitude": [3, 4, 5],
}
@@ -0,0 +1,140 @@
1
+ from collections import defaultdict
2
+ import os
3
+ import time
4
+ from contextlib import contextmanager
5
+ from typing import Literal
6
+ import numpy as np
7
+
8
+ import psutil
9
+ import pynvml
10
+ import subprocess as sp
11
def flatten_dict_keys(d: dict, prefix=""):
    """Return a flat copy of *d* whose nested keys are joined with "/".

    Nested dicts are expanded recursively, e.g. ``{"a": {"b": 1}}`` becomes
    ``{"a/b": 1}``; non-dict values are copied through unchanged. ``prefix``
    is prepended to every produced key.
    """
    flat = {}
    for key, value in d.items():
        path = prefix + key
        if not isinstance(value, dict):
            flat[path] = value
            continue
        flat.update(flatten_dict_keys(value, path + "/"))
    return flat
20
class Profiler:
    """
    A simple class to help profile/benchmark simulator code.

    :meth:`profile` accumulates per-trial timing and memory measurements in
    ``self.stats`` (benchmark name -> list of stat dicts). :meth:`log_stats`
    summarizes them and :meth:`update_csv` persists them. GPU memory is
    queried through NVML for the GPU at index 0.
    """

    def __init__(
        self, output_format: Literal["stdout", "json"], synchronize_torch: bool = True
    ) -> None:
        """Create a profiler and initialize NVML for GPU 0.

        Args:
            output_format: ``"stdout"`` makes :meth:`log` print; any other value
                silences it.
            synchronize_torch: stored on the instance; not currently used by
                this class.
        """
        self.output_format = output_format
        self.synchronize_torch = synchronize_torch
        self.stats = defaultdict(list)
        # Initialize NVML
        pynvml.nvmlInit()

        # Get handle for the first GPU (index 0)
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        # Get the PID of the current process
        self.current_pid = os.getpid()

    def log(self, msg):
        """Print ``msg`` to stdout when ``output_format == "stdout"``."""
        if self.output_format == "stdout":
            print(msg)

    def update_csv(self, csv_path: str, data: dict):
        """Update a csv file with the given data and the flattened stats.

        ``data`` is a dict representing a unique identifier of the result row.
        If the file does not exist, it is created. If ``data`` matches an
        existing row, that row is replaced. If there are multiple matches,
        only the first is replaced and the rest are deleted. An empty ``data``
        dict matches every existing row.
        """
        import pandas as pd

        df = pd.read_csv(csv_path) if os.path.exists(csv_path) else pd.DataFrame()
        stats_flat = flatten_dict_keys(self.stats)

        # Ensure every stat column exists.
        for k in stats_flat:
            if k not in df:
                df[k] = None

        # AND together one mask per identifier column. Starting from an
        # all-True mask (the identity of AND) lets an empty ``data`` match all
        # rows instead of failing on ``None.any()``.
        cond = pd.Series(True, index=df.index)
        for k, v in data.items():
            if k not in df:
                df[k] = None
            cond &= df[k].isna() if v is None else df[k] == v

        data_dict = {**data, **stats_flat}
        if not cond.any():
            df = pd.concat([df, pd.DataFrame(data_dict, index=[len(df)])])
        else:
            matches = df.loc[cond].index
            # replace the first instance ...
            df.loc[matches[0]] = data_dict
            # ... and delete other instances
            df.drop(matches[1:], inplace=True)
        df.to_csv(csv_path, index=False)

    def profile(self, function, name: str, total_steps: int, num_envs: int, trials=1):
        """Time ``function()`` ``trials`` times and record stats under ``name``.

        Each trial appends a dict with ``dt`` (seconds), ``fps``
        (``total_steps * num_envs / dt``), ``psps`` (``total_steps / dt``,
        env.step calls per second), ``total_steps``, ``cpu_mem_use`` (process
        RSS) and ``gpu_mem_use`` (0 when NVML reports nothing for this
        process). Memory is sampled once, before the first trial runs.
        """
        print(f"start recording {name} metrics")
        process = psutil.Process(os.getpid())
        cpu_mem_use = process.memory_info().rss
        gpu_mem_use = self.get_current_process_gpu_memory()
        if gpu_mem_use is None:
            gpu_mem_use = 0

        for _ in range(trials):
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # which can jump and whose coarse ticks can make dt == 0.
            stime = time.perf_counter()
            function()
            dt = time.perf_counter() - stime
            # dt: delta time (s)
            # fps: frames per second
            # psps: parallel steps per second (number of env.step calls per second)
            self.stats[name].append(dict(
                dt=dt,
                fps=total_steps * num_envs / dt,
                psps=total_steps / dt,
                total_steps=total_steps,
                cpu_mem_use=cpu_mem_use,
                gpu_mem_use=gpu_mem_use,
            ))
            # torch.cuda.synchronize()

    def log_stats(self, name: str):
        """Log the mean of every stat recorded under ``name``.

        The standard deviation is also logged when there are multiple trials.
        Does nothing, and does not create an entry, if no trials were recorded.
        """
        trials = self.stats.get(name, [])
        if not trials:
            return
        # Gather each metric across trials: metric -> list of values.
        per_metric = defaultdict(list)
        for data in trials:
            for k, v in data.items():
                per_metric[k].append(v)
        stats = {
            k: {"avg": np.mean(v), "std": np.std(v) if len(v) > 1 else None}
            for k, v in per_metric.items()
        }
        self.log(
            f"{name} ({len(trials)} trials)"
        )
        self.log(
            f"AVG: {stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s"
        )
        if len(trials) > 1:
            self.log(
                f"STD: {stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s"
            )
        # Memory values are divided by 1024**2 and labelled MB (presumably bytes).
        self.log(
            f"{' ' * 4}CPU mem: {stats['cpu_mem_use']['avg'] / (1024**2):0.3f} MB, GPU mem: {stats['gpu_mem_use']['avg'] / (1024**2):0.3f} MB"
        )

    def get_current_process_gpu_memory(self):
        """Return this process's ``usedGpuMemory`` on GPU 0, or ``None``.

        ``None`` is returned when this process is not among the compute
        processes NVML reports for the GPU.
        """
        # Get all processes running on the GPU
        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(self.handle)

        # Iterate through the processes to find the current process
        for process in processes:
            if process.pid == self.current_pid:
                return process.usedGpuMemory
        return None
@@ -0,0 +1,270 @@
1
+ from luxai_s3.params import EnvParams
2
+ from luxai_s3.state import ASTEROID_TILE, NEBULA_TILE, EnvState
3
+ import numpy as np
4
+
5
try:
    import pygame
except ImportError:
    # pygame is optional: it is only referenced inside the renderer's methods,
    # so the module can still be imported without it.
    pass

TILE_SIZE = 64  # side length of one map tile, in screen pixels
11
+
12
+
13
class LuxAIPygameRenderer:
    """Interactive pygame viewer for Lux AI Season 3 environment states.

    Keyboard controls handled in :meth:`render` (via ``TEXTINPUT`` events):

    * space - pause / resume. While paused, ``render`` keeps redrawing and
      does not return.
    * ``r`` - toggle relic configuration spots
    * ``s`` - toggle the per-team sensor mask overlay
    * ``e`` - toggle the energy field overlay
    """

    # Text input -> display option toggled by it.
    _TOGGLE_KEYS = {
        "r": "show_relic_spots",
        "s": "show_sensor_mask",
        "e": "show_energy_field",
    }

    def __init__(self):
        # The window, clock and font are created lazily by _ensure_window() on
        # the first render(), so a renderer can be built without a display.
        self.clock = None
        self.screen = None
        self.surface = None
        self._font = None
        self.display_options = {
            "show_grid": True,
            "show_relic_spots": False,
            "show_sensor_mask": True,
            "show_vision_power_map": True,  # NOTE: no vision-power layer is drawn yet
            "show_energy_field": False,
        }

    def render(self, state: EnvState, params: EnvParams):
        """Render the environment.

        Draws ``state`` and processes pending keyboard events. Returns after one
        frame unless the viewer is paused, in which case it keeps redrawing (at
        most 60 FPS) until unpaused.
        """
        self._ensure_window(params)

        # Handle events to keep the window responsive
        render_state = "running"
        while True:
            self._update_display(state, params)
            for event in pygame.event.get():
                if event.type != pygame.TEXTINPUT:
                    continue
                if event.text == " ":
                    render_state = "paused" if render_state == "running" else "running"
                elif event.text in self._TOGGLE_KEYS:
                    option = self._TOGGLE_KEYS[event.text]
                    self.display_options[option] = not self.display_options[option]
            if render_state != "paused":
                break
            self.clock.tick(60)

    def _ensure_window(self, params: EnvParams):
        """Initialize pygame and create the window/surface/font if needed.

        This previously ran only when pygame was uninitialized. If pygame had
        been initialized elsewhere, ``self.screen`` was never created.
        """
        if pygame.get_init() and self.screen is not None:
            return
        if not pygame.get_init():
            pygame.init()
        self.clock = pygame.time.Clock()
        # Set up the display
        screen_width = params.map_width * TILE_SIZE
        screen_height = params.map_height * TILE_SIZE
        self.screen = pygame.display.set_mode((screen_width, screen_height))
        self.surface = pygame.Surface(self.screen.get_size(), pygame.SRCALPHA)
        pygame.display.set_caption("Lux AI Season 3")
        self._font = pygame.font.Font(None, 32)  # You may need to adjust the font size

    def _update_display(self, state: EnvState, params: EnvParams):
        """Redraw every layer of ``state`` onto the window and flip the display."""
        # Fill the screen with a background color
        self.screen.fill((200, 200, 200))
        self.surface.fill((200, 200, 200, 255))  # Light gray background

        self._draw_tiles(state, params)
        if self.display_options["show_relic_spots"]:
            self._draw_relic_spots(state, params)
        self._draw_energy_nodes(state, params)
        self._draw_relic_nodes(state, params)
        if self.display_options["show_sensor_mask"]:
            self._draw_sensor_mask(state, params)
        if self.display_options["show_energy_field"]:
            self._draw_energy_field(state, params)
        self._draw_units(state, params)
        if self.display_options["show_grid"]:
            self._draw_grid(params)

        self.screen.blit(self.surface, (0, 0))
        # Update the display
        pygame.display.flip()

    @staticmethod
    def _tile_rect(x, y):
        """Return the screen rectangle covering map tile (x, y)."""
        return pygame.Rect(x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE)

    @staticmethod
    def _tile_center(x, y):
        """Return the screen-space center point of map tile (x, y)."""
        return ((x + 0.5) * TILE_SIZE, (y + 0.5) * TILE_SIZE)

    @staticmethod
    def _draw_rect_alpha(surface, color, rect):
        """Blend a possibly semi-transparent filled rectangle onto ``surface``."""
        shape_surf = pygame.Surface(pygame.Rect(rect).size, pygame.SRCALPHA)
        pygame.draw.rect(shape_surf, color, shape_surf.get_rect())
        surface.blit(shape_surf, rect)

    @staticmethod
    def _scaled_alpha(value, bound):
        """Map ``value / bound`` to an integer alpha clamped to [0, 255].

        pygame rejects alpha components outside 0-255. ``value`` may exceed
        ``bound``, and ``0 / negative`` yields ``-0.0``.
        """
        if bound == 0:
            return 0
        return max(0, min(255, int(255 * float(value) / bound)))

    def _draw_tiles(self, state: EnvState, params: EnvParams):
        """Fill every map tile with a color chosen by its tile type."""
        for x in range(params.map_width):
            for y in range(params.map_height):
                tile_type = state.map_features.tile_type[x, y]
                if tile_type == NEBULA_TILE:
                    color = (166, 177, 225, 255)  # Light blue (a6b1e1)
                elif tile_type == ASTEROID_TILE:
                    color = (51, 56, 68, 255)  # Dark slate
                else:
                    color = (255, 255, 255, 255)  # White for other tile types
                pygame.draw.rect(self.surface, color, self._tile_rect(x, y))

    def _draw_relic_spots(self, state: EnvState, params: EnvParams):
        """Shade in-bounds tiles around each relic node whose config value is > 0."""
        half_size = params.relic_config_size // 2
        for i in range(params.max_relic_nodes):
            if not state.relic_nodes_mask[i]:
                continue
            x, y = state.relic_nodes[i, :2]
            for dx in range(-half_size, half_size + 1):
                for dy in range(-half_size, half_size + 1):
                    config_x = x + dx
                    config_y = y + dy
                    if not (
                        0 <= config_x < params.map_width
                        and 0 <= config_y < params.map_height
                    ):
                        continue
                    config_value = state.relic_node_configs[
                        i, dy + half_size, dx + half_size
                    ]
                    if config_value > 0:
                        self._draw_rect_alpha(
                            self.surface,
                            (255, 215, 0, 50),  # Semi-transparent gold
                            self._tile_rect(config_x, config_y),
                        )

    def _draw_energy_nodes(self, state: EnvState, params: EnvParams):
        """Draw each active energy node as a green circle centered in its tile."""
        radius = TILE_SIZE // 4  # Adjust this value to change the size of the circle
        for i in range(params.max_energy_nodes):
            if state.energy_nodes_mask[i]:
                x, y = state.energy_nodes[i, :2]
                center_x, center_y = self._tile_center(x, y)
                pygame.draw.circle(
                    self.surface,
                    (0, 255, 0, 255),
                    (int(center_x), int(center_y)),
                    radius,
                )

    def _draw_relic_nodes(self, state: EnvState, params: EnvParams):
        """Draw each active relic node as a gold square, half a tile wide."""
        rect_size = TILE_SIZE // 2  # Make the square smaller than the tile
        offset = (TILE_SIZE - rect_size) // 2
        for i in range(params.max_relic_nodes):
            if state.relic_nodes_mask[i]:
                x, y = state.relic_nodes[i, :2]
                rect = pygame.Rect(
                    x * TILE_SIZE + offset, y * TILE_SIZE + offset, rect_size, rect_size
                )
                pygame.draw.rect(self.surface, (173, 151, 32, 255), rect)  # Gold

    def _draw_sensor_mask(self, state: EnvState, params: EnvParams):
        """Tint, in faint red, every tile marked in any team's sensor mask."""
        for team in range(params.num_teams):
            for x in range(params.map_width):
                for y in range(params.map_height):
                    if state.sensor_mask[team, x, y]:
                        self._draw_rect_alpha(
                            self.surface, (255, 0, 0, 25), self._tile_rect(x, y)
                        )

    def _draw_energy_field(self, state: EnvState, params: EnvParams):
        """Tint each tile by its energy value (green > 0, red otherwise) and label it."""
        for x in range(params.map_width):
            for y in range(params.map_height):
                energy_field_value = state.map_features.energy[x, y]
                if energy_field_value > 0:
                    alpha = self._scaled_alpha(energy_field_value, params.max_energy_per_tile)
                    color = (0, 255, 0, alpha)
                else:
                    alpha = self._scaled_alpha(energy_field_value, params.min_energy_per_tile)
                    color = (255, 0, 0, alpha)
                self._draw_rect_alpha(self.surface, color, self._tile_rect(x, y))
                # Draw the value after the tint so the tint does not cover it.
                text = self._font.render(str(energy_field_value), True, (255, 255, 255))
                self.surface.blit(text, text.get_rect(center=self._tile_center(x, y)))

    def _draw_units(self, state: EnvState, params: EnvParams):
        """Draw every active unit and label each occupied tile with its unit count."""
        radius = TILE_SIZE // 3  # Adjust this value to change the size of the circle
        unit_counts = {}
        for team in range(params.num_teams):
            # Red for team 0, blue for every other team
            color = (255, 0, 0, 255) if team == 0 else (0, 0, 255, 255)
            for i in range(params.max_units):
                if not state.units_mask[team, i]:
                    continue
                x, y = (int(v) for v in np.asarray(state.units.position[team, i]))
                center_x, center_y = self._tile_center(x, y)
                pygame.draw.circle(
                    self.surface, color, (int(center_x), int(center_y)), radius
                )
                unit_counts[(x, y)] = unit_counts.get((x, y), 0) + 1

        # Counts are drawn after all circles so no circle covers a label.
        for (x, y), count in unit_counts.items():
            text = self._font.render(str(count), True, (255, 255, 255))  # White text
            self.surface.blit(text, text.get_rect(center=self._tile_center(x, y)))

    def _draw_grid(self, params: EnvParams):
        """Draw grey lines along every tile boundary."""
        line_color = (100, 100, 100)
        width_px = params.map_width * TILE_SIZE
        height_px = params.map_height * TILE_SIZE
        for x in range(params.map_width + 1):
            pygame.draw.line(
                self.surface, line_color, (x * TILE_SIZE, 0), (x * TILE_SIZE, height_px)
            )
        for y in range(params.map_height + 1):
            pygame.draw.line(
                self.surface, line_color, (0, y * TILE_SIZE), (width_px, y * TILE_SIZE)
            )
@@ -0,0 +1,30 @@
1
+ import chex
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from gymnax.environments.spaces import Space
6
+
7
+
8
class MultiDiscrete(Space):
    """Minimal jittable class for multi discrete gymnax spaces.

    Element ``i`` of the space is an integer in the half-open range
    ``[low[i], high[i])``, stored as ``jnp.int16``.
    """

    def __init__(self, low: np.ndarray, high: np.ndarray):
        self.low = low
        self.high = high
        self.dist = self.high - self.low
        assert low.shape == high.shape
        self.shape = low.shape
        self.dtype = jnp.int16

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw a random point of the space using ``rng``.

        A uniform float in [0, 1) is scaled into [low, high) and cast to
        ``self.dtype``.
        """
        return (
            jax.random.uniform(rng, shape=self.shape, minval=0, maxval=1) * self.dist
            + self.low
        ).astype(self.dtype)

    def contains(self, x) -> jnp.ndarray:
        """Check whether specific object is within space.

        Returns a boolean scalar array that is True iff every element of ``x``
        satisfies ``low <= x < high``. The previous version referenced a
        nonexistent ``self.n`` and always raised ``AttributeError``.
        """
        # type_cond = isinstance(x, self.dtype)
        # shape_cond = (x.shape == self.shape)
        range_cond = jnp.logical_and(x >= self.low, x < self.high)
        return jnp.all(range_cond)