kaggle-environments 0.2.1__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of kaggle-environments might be problematic.
- kaggle_environments/__init__.py +49 -13
- kaggle_environments/agent.py +177 -124
- kaggle_environments/api.py +31 -0
- kaggle_environments/core.py +295 -170
- kaggle_environments/envs/cabt/cabt.js +164 -0
- kaggle_environments/envs/cabt/cabt.json +28 -0
- kaggle_environments/envs/cabt/cabt.py +186 -0
- kaggle_environments/envs/cabt/cg/__init__.py +0 -0
- kaggle_environments/envs/cabt/cg/cg.dll +0 -0
- kaggle_environments/envs/cabt/cg/game.py +75 -0
- kaggle_environments/envs/cabt/cg/libcg.so +0 -0
- kaggle_environments/envs/cabt/cg/sim.py +48 -0
- kaggle_environments/envs/cabt/test_cabt.py +120 -0
- kaggle_environments/envs/chess/chess.js +4289 -0
- kaggle_environments/envs/chess/chess.json +60 -0
- kaggle_environments/envs/chess/chess.py +4241 -0
- kaggle_environments/envs/chess/test_chess.py +60 -0
- kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
- kaggle_environments/envs/connectx/connectx.js +1 -1
- kaggle_environments/envs/connectx/connectx.json +15 -1
- kaggle_environments/envs/connectx/connectx.py +6 -23
- kaggle_environments/envs/connectx/test_connectx.py +70 -24
- kaggle_environments/envs/football/football.ipynb +75 -0
- kaggle_environments/envs/football/football.json +91 -0
- kaggle_environments/envs/football/football.py +277 -0
- kaggle_environments/envs/football/helpers.py +95 -0
- kaggle_environments/envs/football/test_football.py +360 -0
- kaggle_environments/envs/halite/__init__.py +0 -0
- kaggle_environments/envs/halite/halite.ipynb +44741 -0
- kaggle_environments/envs/halite/halite.js +199 -83
- kaggle_environments/envs/halite/halite.json +31 -18
- kaggle_environments/envs/halite/halite.py +164 -303
- kaggle_environments/envs/halite/helpers.py +720 -0
- kaggle_environments/envs/halite/test_halite.py +190 -0
- kaggle_environments/envs/hungry_geese/__init__.py +0 -0
- kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
- kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
- kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
- kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
- kaggle_environments/envs/identity/identity.json +6 -5
- kaggle_environments/envs/identity/identity.py +15 -2
- kaggle_environments/envs/kore_fleets/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
- kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
- kaggle_environments/envs/lux_ai_2021/README.md +3 -0
- kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
- kaggle_environments/envs/lux_ai_2021/index.html +43 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
- kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
- kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
- kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
- kaggle_environments/envs/lux_ai_s3/README.md +21 -0
- kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
- kaggle_environments/envs/lux_ai_s3/index.html +42 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
- kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
- kaggle_environments/envs/mab/__init__.py +0 -0
- kaggle_environments/envs/mab/agents.py +12 -0
- kaggle_environments/envs/mab/mab.js +100 -0
- kaggle_environments/envs/mab/mab.json +74 -0
- kaggle_environments/envs/mab/mab.py +146 -0
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
- kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
- kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
- kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
- kaggle_environments/envs/open_spiel/observation.py +128 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
- kaggle_environments/envs/open_spiel/proxy.py +138 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
- kaggle_environments/envs/rps/__init__.py +0 -0
- kaggle_environments/envs/rps/agents.py +84 -0
- kaggle_environments/envs/rps/helpers.py +25 -0
- kaggle_environments/envs/rps/rps.js +117 -0
- kaggle_environments/envs/rps/rps.json +63 -0
- kaggle_environments/envs/rps/rps.py +90 -0
- kaggle_environments/envs/rps/test_rps.py +110 -0
- kaggle_environments/envs/rps/utils.py +7 -0
- kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
- kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
- kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
- kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
- kaggle_environments/errors.py +2 -4
- kaggle_environments/helpers.py +377 -0
- kaggle_environments/main.py +340 -0
- kaggle_environments/schemas.json +23 -18
- kaggle_environments/static/player.html +206 -74
- kaggle_environments/utils.py +46 -73
- kaggle_environments-1.20.0.dist-info/METADATA +25 -0
- kaggle_environments-1.20.0.dist-info/RECORD +211 -0
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
- kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
- kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
- kaggle_environments/temp.py +0 -14
- kaggle_environments-0.2.1.dist-info/METADATA +0 -393
- kaggle_environments-0.2.1.dist-info/RECORD +0 -32
- kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
- kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
kaggle_environments/core.py
CHANGED
@@ -14,11 +14,16 @@
 
 import copy
 import json
-
+import traceback
 import uuid
+from contextlib import redirect_stderr, redirect_stdout
+from io import StringIO
+from multiprocessing import Pool
+from time import perf_counter
+
 from .agent import Agent
-from .errors import DeadlineExceeded, FailedPrecondition,
-from .utils import get,
+from .errors import DeadlineExceeded, FailedPrecondition, InvalidArgument
+from .utils import get, get_player, has, process_schema, schemas, structify
 
 # Registered Environments.
 environments = {}
@@ -33,71 +38,137 @@ def register(name, environment):
      * specification - JSON Schema representing the environment.
      * interpreter - Function(state, environment) -> new_state
      * renderer - Function(state, environment) -> string
-     * html_renderer(
+     * html_renderer - Function(environment) -> JavaScript HTML renderer function.
      * agents(optional) - List of default agents [Function(observation, config) -> action]
     """
     environments[name] = environment
 
 
-def evaluate(environment, agents=
+def evaluate(environment, agents=None, configuration=None, steps=None, num_episodes=1, debug=False, state=None):
     """
     Evaluate and return the rewards of one or more episodes (environment and agents combo).
 
     Args:
-        environment (str|Environment):
+        environment (str|Environment):
         agents (list):
         configuration (dict, optional):
         steps (list, optional):
         num_episodes (int=1, optional): How many episodes to execute (run until done).
+        debug (bool=False, optional): Render print() statments to stdout
+        state (optional)
 
     Returns:
         list of list of int: List of final rewards for all agents for all episodes.
     """
-
-
+    if agents is None:
+        agents = []
+    if configuration is None:
+        configuration = {}
+    if steps is None:
+        steps = []
+
+    e = make(environment, configuration, steps, debug=debug, state=state)
+    rewards = [[] for i in range(num_episodes)]
     for i in range(num_episodes):
         last_state = e.run(agents)[-1]
         rewards[i] = [state.reward for state in last_state]
     return rewards
 
 
-def make(environment, configuration=
+def make(environment, configuration=None, info=None, steps=None, logs=None, debug=False, state=None):
    """
    Creates an instance of an Environment.
 
    Args:
-        environment (str|Environment):
+        environment (str|Environment):
        configuration (dict, optional):
+        info (dict, optional):
        steps (list, optional):
-        debug (bool=False, optional):
+        debug (bool=False, optional): Render print() statments to stdout
+        state (optional):
 
    Returns:
        Environment: Instance of a specific environment.
    """
+    if configuration is None:
+        configuration = {}
+    if info is None:
+        info = {}
+    if steps is None:
+        steps = []
+    if logs is None:
+        logs = []
+
    if has(environment, str) and has(environments, dict, path=[environment]):
-        return Environment(
+        return Environment(
+            **environments[environment],
+            configuration=configuration,
+            info=info,
+            steps=steps,
+            logs=logs,
+            debug=debug,
+            state=state,
+        )
    elif callable(environment):
-        return Environment(
+        return Environment(
+            interpreter=environment,
+            configuration=configuration,
+            info=info,
+            steps=steps,
+            logs=logs,
+            debug=debug,
+            state=state,
+        )
    elif has(environment, path=["interpreter"], is_callable=True):
-        return Environment(
+        return Environment(
+            **environment, configuration=configuration, info=info, steps=steps, logs=logs, debug=debug, state=state
+        )
    raise InvalidArgument("Unknown Environment Specification")
 
 
-
+def act_agent(args):
+    agent, state, configuration, none_action = args
+    if state["status"] != "ACTIVE":
+        return None, {}
+    elif agent is None:
+        return none_action, {}
+    else:
+        return agent.act(state["observation"])
 
+
+class Environment:
    def __init__(
        self,
-        specification=
-        configuration=
-
-
+        specification=None,
+        configuration=None,
+        info=None,
+        steps=None,
+        logs=None,
+        agents=None,
        interpreter=None,
        renderer=None,
        html_renderer=None,
        debug=False,
+        state=None,
    ):
+        if specification is None:
+            specification = {}
+        if configuration is None:
+            configuration = {}
+        if info is None:
+            info = {}
+        if steps is None:
+            steps = []
+        if logs is None:
+            logs = []
+        if agents is None:
+            agents = {}
+
+        self.logs = logs
        self.id = str(uuid.uuid1())
        self.debug = debug
+        self.info = info
+        self.pool = None
 
        err, specification = self.__process_specification(specification)
        if err:
@@ -106,7 +177,7 @@ class Environment:
 
        err, configuration = process_schema(
            {"type": "object", "properties": self.specification.configuration},
-            {} if configuration
+            {} if configuration is None else configuration,
        )
        if err:
            raise InvalidArgument("Configuration Invalid: " + err)
@@ -120,30 +191,37 @@ class Environment:
            raise InvalidArgument("Renderer is not Callable.")
        self.renderer = renderer
 
-        if callable(html_renderer):
-
-        self.html_renderer =
+        if not callable(html_renderer):
+            raise InvalidArgument("Html_renderer is not Callable.")
+        self.html_renderer = html_renderer
 
        if not all([callable(a) for a in agents.values()]):
            raise InvalidArgument("Default agents must be Callable.")
        self.agents = structify(agents)
 
-        if steps
-            self.reset()
-        else:
+        if steps is not None and len(steps) > 0:
            self.__set_state(steps[-1])
            self.steps = steps[0:-1] + self.steps
+        elif state is not None:
+            step = [{}] * self.specification.agents[0]
+            step[0] = state
+            self.__set_state(step)
+        else:
+            self.reset()
 
-    def step(self, actions):
+    def step(self, actions, logs=None):
        """
        Execute the environment interpreter using the current state and a list of actions.
 
        Args:
            actions (list): Actions to pair up with the current agent states.
+            logs (list): Logs to pair up with each agent for the current step.
 
        Returns:
            list of dict: The agents states after the step.
        """
+        if logs is None:
+            logs = []
 
        if self.done:
            raise FailedPrecondition("Environment done, reset required.")
@@ -158,26 +236,27 @@ class Environment:
                self.debug_print(f"Timeout: {str(action)}")
                action_state[index]["status"] = "TIMEOUT"
            elif isinstance(action, BaseException):
-                self.debug_print(f"Error: {
+                self.debug_print(f"Error: {traceback.format_exception(None, action, action.__traceback__)}")
                action_state[index]["status"] = "ERROR"
            else:
-                err, data = process_schema(
-                    self.__state_schema.properties.action, action)
+                err, data = process_schema(self.__state_schema.properties.action, action)
                if err:
                    self.debug_print(f"Invalid Action: {str(err)}")
                    action_state[index]["status"] = "INVALID"
                else:
                    action_state[index]["action"] = data
 
-        self.state = self.__run_interpreter(action_state)
+        self.state = self.__run_interpreter(action_state, logs)
 
        # Max Steps reached. Mark ACTIVE/INACTIVE agents as DONE.
-        if
+        if self.state[0].observation.step >= self.configuration.episodeSteps - 1:
            for s in self.state:
                if s.status == "ACTIVE" or s.status == "INACTIVE":
                    s.status = "DONE"
 
        self.steps.append(self.state)
+        if logs is not None:
+            self.logs.append(logs)
 
        return self.state
 
@@ -189,19 +268,25 @@ class Environment:
            agents (list of any): List of agents to obtain actions from.
 
        Returns:
-
+            tuple of:
+                list of list of dict: The agent states of all steps executed.
+                list of list of dict: The agent logs of all steps executed.
        """
-        if self.state
+        if self.state is None or len(self.steps) == 1 or self.done:
            self.reset(len(agents))
        if len(self.state) != len(agents):
-            raise InvalidArgument(
-                f"{len(self.state)} agents were expected, but {len(agents)} was given.")
+            raise InvalidArgument(f"{len(self.state)} agents were expected, but {len(agents)} was given.")
 
        runner = self.__agent_runner(agents)
-        start =
-        while not self.done and
-
-
+        start = perf_counter()
+        while not self.done and perf_counter() - start < self.configuration.runTimeout:
+            actions, logs = runner.act()
+            self.step(actions, logs)
+        if not self.done and perf_counter() - start >= self.configuration.runTimeout:
+            raise DeadlineExceeded(
+                f"runtime of {perf_counter() - start} exceeded the runTimeout of {self.configuration.runTimeout}"
+            )
+
        return self.steps
 
    def reset(self, num_agents=None):
@@ -215,7 +300,7 @@ class Environment:
            list of dict: The agents states after the reset.
        """
 
-        if num_agents
+        if num_agents is None:
            num_agents = self.specification.agents[0]
 
        # Get configuration default state.
@@ -225,7 +310,9 @@ class Environment:
        for agent in self.state:
            agent.status = "INACTIVE"
        # Give the interpreter an opportunity to make any initializations.
-
+        logs = []
+        self.__set_state(self.__run_interpreter(self.state, logs))
+        self.logs.append(logs)
        # Replace the starting "status" if still "done".
        if self.done and len(self.state) == len(statuses):
            for i in range(len(self.state)):
@@ -247,35 +334,40 @@ class Environment:
        mode = get(kwargs, str, "human", path=["mode"])
        if mode == "ansi" or mode == "human":
            args = [self.state, self]
-            out = self.renderer(*args[:self.renderer.__code__.co_argcount])
+            out = self.renderer(*args[: self.renderer.__code__.co_argcount])
            if mode == "ansi":
                return out
-            print(out)
        elif mode == "html" or mode == "ipython":
+            is_playing = get(kwargs, bool, self.done, path=["playing"])
            window_kaggle = {
                "debug": get(kwargs, bool, self.debug, path=["debug"]),
-                "
-                "step": 0 if
+                "playing": is_playing,
+                "step": 0 if is_playing else len(self.steps) - 1,
                "controls": get(kwargs, bool, self.done, path=["controls"]),
                "environment": self.toJSON(),
+                "logs": self.logs,
                **kwargs,
            }
-
+            args = [self]
+            player_html = get_player(
+                window_kaggle, self.html_renderer(*args[: self.html_renderer.__code__.co_argcount])
+            )
            if mode == "html":
                return player_html
-
-
+
+            from IPython.display import HTML, display
+
+            player_html = player_html.replace('"', "&quot;")
            width = get(kwargs, int, 300, path=["width"])
            height = get(kwargs, int, 300, path=["height"])
            html = f'<iframe srcdoc="{player_html}" width="{width}" height="{height}" frameborder="0"></iframe> '
            display(HTML(html))
        elif mode == "json":
-            return json.dumps(self.toJSON(), sort_keys=True)
+            return json.dumps(self.toJSON(), sort_keys=True, indent=2 if self.debug else None)
        else:
-            raise InvalidArgument(
-                "Available render modes: human, ansi, html, ipython")
+            raise InvalidArgument("Available render modes: human, ansi, html, ipython")
 
-    def play(self, agents=
+    def play(self, agents=None, **kwargs):
        """
        Renders a visual representation of the environment and allows interactive action selection.
 
@@ -285,12 +377,15 @@ class Environment:
        Returns:
            None: prints directly to an IPython notebook
        """
+        if agents is None:
+            agents = []
+
        env = self.clone()
        trainer = env.train(agents)
        interactives[env.id] = (env, trainer)
        env.render(mode="ipython", interactive=True, **kwargs)
 
-    def train(self, agents=
+    def train(self, agents=None):
        """
        Setup a lightweight training environment for a single agent.
        Note: This is designed to be a lightweight starting point which can
@@ -318,43 +413,41 @@ class Environment:
        `dict`.reset: Reset def that reset the environment, then advances until the agents turn.
        `dict`.step: Steps using the agent action, then advance until agents turn again.
        """
+        if agents is None:
+            agents = []
+
        runner = None
        position = None
        for index, agent in enumerate(agents):
-            if agent
-                if position
-                    raise InvalidArgument(
-                        "Only one agent can be marked 'None'")
+            if agent is None:
+                if position is not None:
+                    raise InvalidArgument("Only one agent can be marked 'None'")
                position = index
 
-        if position
+        if position is None:
            raise InvalidArgument("One agent must be marked 'None' to train.")
 
        def advance():
            while not self.done and self.state[position].status == "INACTIVE":
-
+                actions, logs = runner.act()
+                self.step(actions, logs)
 
        def reset():
            nonlocal runner
            self.reset(len(agents))
-            if runner != None:
-                runner.destroy()
            runner = self.__agent_runner(agents)
            advance()
-            return self.
+            return self.__get_shared_state(position).observation
 
        def step(action):
-
+            actions, logs = runner.act(action)
+            self.step(actions, logs)
            advance()
-            agent = self.
+            agent = self.__get_shared_state(position)
            reward = agent.reward
-            if len(self.steps) > 1 and reward
+            if len(self.steps) > 1 and reward is not None:
                reward -= self.steps[-2][position].reward
-
-            runner.destroy()
-            return [
-                agent.observation, reward, agent.status != "ACTIVE", agent.info
-            ]
+            return [agent.observation, reward, agent.status != "ACTIVE", agent.info]
 
        reset()
 
@@ -395,12 +488,13 @@ class Environment:
                    "configuration": spec.configuration,
                    "info": spec.info,
                    "observation": spec.observation,
-                    "reward": spec.reward
+                    "reward": spec.reward,
                },
                "steps": self.steps,
                "rewards": [state.reward for state in self.steps[-1]],
                "statuses": [state.status for state in self.steps[-1]],
                "schema_version": 1,
+                "info": self.info,
            }
        )
 
@@ -427,44 +521,35 @@ class Environment:
            self.__state_schema_value = {
                **schemas["state"],
                "properties": {
-                    "action": {
-
-                        **get(spec, dict, path=["action"], fallback={})
-                    },
-                    "reward": {
-                        **schemas.state.properties.reward,
-                        **get(spec, dict, path=["reward"], fallback={})
-                    },
+                    "action": {**schemas.state.properties.action, **get(spec, dict, path=["action"], fallback={})},
+                    "reward": {**schemas.state.properties.reward, **get(spec, dict, path=["reward"], fallback={})},
                    "info": {
                        **schemas.state.properties.info,
-                        "properties": get(spec, dict, path=["info"], fallback={})
+                        "properties": get(spec, dict, path=["info"], fallback={}),
                    },
                    "observation": {
                        **schemas.state.properties.observation,
-                        "properties": get(spec, dict, path=["observation"], fallback={})
-                    },
-                    "status": {
-                        **schemas.state.properties.status,
-                        **get(spec, dict, path=["status"], fallback={})
+                        "properties": get(spec, dict, path=["observation"], fallback={}),
                    },
+                    "status": {**schemas.state.properties.status, **get(spec, dict, path=["status"], fallback={})},
                },
            }
        return structify(self.__state_schema_value)
 
-    def __set_state(self, state=
+    def __set_state(self, state=None):
+        if state is None:
+            state = []
+
        if len(state) not in self.specification.agents:
-            raise InvalidArgument(
-                f"{len(state)} is not a valid number of agent(s).")
+            raise InvalidArgument(f"{len(state)} is not a valid number of agent(s).")
 
-        self.state = structify([self.__get_state(index, s)
-                                for index, s in enumerate(state)])
+        self.state = structify([self.__get_state(index, s) for index, s in enumerate(state)])
        self.steps = [self.state]
        return self.state
 
    def __get_state(self, position, state):
        key = f"__state_schema_{position}"
        if not hasattr(self, key):
-
            # Update a property default value based on position in defaults.
            # Remove shared properties from non-first agents.
            def update_props(props):
@@ -479,32 +564,79 @@ class Environment:
                        update_props(prop["properties"])
                return props
 
-            props = structify(update_props(
-                copy.deepcopy(self.__state_schema.properties)))
+            props = structify(update_props(copy.deepcopy(self.__state_schema.properties)))
 
            setattr(self, key, {**self.__state_schema, "properties": props})
 
        err, data = process_schema(getattr(self, key), state)
        if err:
-            raise InvalidArgument(
-                f"Default state generation failed for #{position}: " + err
-            )
+            raise InvalidArgument(f"Default state generation failed for #{position}: " + err)
        return data
 
-    def
+    def __loop_through_interpreter(self, state, logs):
+        args = [structify(state), self, logs]
+        new_state = structify(self.interpreter(*args[: self.interpreter.__code__.co_argcount]))
+        new_state[0].observation.step = 0 if self.done else len(self.steps)
+
+        for index, agent in enumerate(new_state):
+            if index < len(logs) and "duration" in logs[index]:
+                duration = logs[index]["duration"]
+                overage_time_consumed = max(0, duration - self.configuration.actTimeout)
+                agent.observation.remainingOverageTime -= overage_time_consumed
+            if agent.status not in self.__state_schema.properties.status.enum:
+                self.debug_print(f"Invalid Action: {agent.status}")
+                agent.status = "INVALID"
+            if agent.status in ["ERROR", "INVALID", "TIMEOUT"]:
+                agent.reward = None
+        return new_state
+
+    def __run_interpreter_prod(self, state, logs):
+        out = None
+        err = None
        try:
-
-
-
-
-
-
-
-
-
-
-
-
+            with (
+                StringIO() as out_buffer,
+                StringIO() as err_buffer,
+                redirect_stdout(out_buffer),
+                redirect_stderr(err_buffer),
+            ):
+                try:
+                    new_state = self.__loop_through_interpreter(state, logs)
+                    return new_state
+                except Exception as e:
+                    # Print the exception stack trace to our log
+                    traceback.print_exc(file=err_buffer)
+                    # Reraise e to ensure that the program exits
+                    raise e
+                finally:
+                    out = out_buffer.getvalue()
+                    err = err_buffer.getvalue()
+
+                    # strip if needed
+                    # Allow up to 10k (default) log characters per step which is ~10MB per 600 step episode
+                    max_log_length = self.configuration.get("maxLogLength", 10000)
+                    if max_log_length is not None:
+                        out = out[0:max_log_length]
+                        err = err[0:max_log_length]
+
+                    if out or err:
+                        logs.append({"stdout": out, "stderr": err})
+        finally:
+            if out:
+                while out.endswith("\n"):
+                    out = out[:-1]
+                self.debug_print(out)
+            if err:
+                while err.endswith("\n"):
+                    err = err[:-1]
+                self.debug_print(err)
+
+    def __run_interpreter(self, state, logs):
+        # Append any environmental logs to any agent logs we collected.
+        if self.debug:
+            return self.__loop_through_interpreter(state, logs)
+        else:
+            return self.__run_interpreter_prod(state, logs)
 
    def __process_specification(self, spec):
        if has(spec, path=["reward"]):
@@ -514,86 +646,79 @@ class Environment:
                return ("type must be an integer or number", None)
            reward["type"] = [reward_type, "null"]
 
-        # Allow environments to extend the
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Allow environments to extend various parts of the specification.
+        def extend_specification(source, field_name):
+            field = copy.deepcopy(source[field_name]["properties"])
+            for k, v in get(spec, dict, {}, [field_name]).items():
+                # Set a new default value.
+                if not isinstance(v, dict):
+                    if not has(field, path=[k]):
+                        raise InvalidArgument(f"Field {field} was unable to set default of missing property: {k}")
+                    field[k]["default"] = v
+                # Add a new field.
+                elif not has(field, path=[k]):
+                    field[k] = v
+                # Override an existing field if types match.
+                elif field[k]["type"] == get(v, path=["type"]):
+                    field[k] = v
+                # Types don't match - unable to extend.
+                else:
+                    raise InvalidArgument(f"Field {field} was unable to extend: {k}")
+
+            spec[field_name] = field
+
+        extend_specification(schemas, "configuration")
+        extend_specification(schemas["state"]["properties"], "observation")
 
-        spec["configuration"] = configuration
        return process_schema(schemas.specification, spec)
 
    def __agent_runner(self, agents):
-        # Replace default agents with their source.
-        for i, agent in enumerate(agents):
-            if has(self.agents, path=[agent]):
-                agents[i] = self.agents[agent]
-
        # Generate the agents.
-        agents = [Agent(
-            None else None for a in agents]
-
-        # Have the agents had a chance to initialize (first non-empty act).
-        initialized = [False] * len(agents)
+        agents = [Agent(agent, self) if agent is not None else None for agent in agents]
 
        def act(none_action=None):
            if len(agents) != len(self.state):
-                raise InvalidArgument(
-
-
-
-
-
-
-
-
-
-
-                if not initialized[i]:
-                    initialized[i] = True
-                    timeout += self.configuration.agentTimeout
-                state = self.__get_shared_state(i)
-                actions[i] = agent.act(state, timeout)
-            return actions
+                raise InvalidArgument("Number of agents must match the state length")
+
+            act_args = [
+                (
+                    agent,
+                    self.__get_shared_state(i),
+                    self.configuration,
+                    none_action,
+                )
+                for i, agent in enumerate(agents)
+            ]
 
-
-
-
-
+            if all((agent is None or agent.is_parallelizable) for agent in agents):
+                if self.pool is None:
+                    self.pool = Pool(processes=len(agents))
+                results = self.pool.map(act_agent, act_args)
+            else:
+                results = list(map(act_agent, act_args))
 
-
+            # results is a list of tuples where the first element is an agent action and the second is the agent log
+            # This destructures into two lists, a list of actions and a list of logs.
+            actions, logs = zip(*results)
+            return list(actions), list(logs)
 
-
-        if position == 0:
-            return self.state[0]
-        state = copy.deepcopy(self.state[position])
+        return structify({"act": act})
 
+    def __get_shared_state(self, position):
        # Note: state and schema are required to be in sync (apart from shared ones).
        def update_props(shared_state, state, schema_props):
            for k, prop in schema_props.items():
-
+                # Hidden fields are tracked in the episode replay but are not provided to the agent at runtime
+                if get(prop, bool, path=["hidden"], fallback=False):
+                    if k in state:
+                        del state[k]
+                elif get(prop, bool, path=["shared"], fallback=False):
                    state[k] = shared_state[k]
                elif has(prop, dict, path=["properties"]):
                    update_props(shared_state[k], state[k], prop["properties"])
            return state
 
-        return update_props(self.state[0], state, self.__state_schema.properties)
+        return update_props(self.state[0], copy.deepcopy(self.state[position]), self.__state_schema.properties)
 
    def debug_print(self, message):
        if self.debug: