genesis-forge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genesis_forge/__init__.py +12 -0
- genesis_forge/gamepads/__init__.py +5 -0
- genesis_forge/gamepads/config.py +104 -0
- genesis_forge/gamepads/debug.py +34 -0
- genesis_forge/gamepads/gamepad.py +191 -0
- genesis_forge/gamepads/logitech.py +63 -0
- genesis_forge/genesis_env.py +218 -0
- genesis_forge/managed_env.py +271 -0
- genesis_forge/managers/__init__.py +24 -0
- genesis_forge/managers/action/__init__.py +9 -0
- genesis_forge/managers/action/base.py +62 -0
- genesis_forge/managers/action/position_action_manager.py +450 -0
- genesis_forge/managers/action/position_within_limits.py +123 -0
- genesis_forge/managers/base.py +43 -0
- genesis_forge/managers/command/__init__.py +7 -0
- genesis_forge/managers/command/command_manager.py +316 -0
- genesis_forge/managers/command/velocity_command.py +344 -0
- genesis_forge/managers/contact_manager.py +500 -0
- genesis_forge/managers/entity/__init__.py +7 -0
- genesis_forge/managers/entity/config.py +46 -0
- genesis_forge/managers/entity/entity_manager.py +120 -0
- genesis_forge/managers/entity/reset.py +164 -0
- genesis_forge/managers/observation_manager.py +205 -0
- genesis_forge/managers/reward_manager.py +175 -0
- genesis_forge/managers/termination_manager.py +185 -0
- genesis_forge/managers/terrain_manager.py +288 -0
- genesis_forge/mdp/__init__.py +4 -0
- genesis_forge/mdp/rewards.py +262 -0
- genesis_forge/mdp/terminations.py +107 -0
- genesis_forge/rl/__init__.py +3 -0
- genesis_forge/rl/skrl/__init__.py +3 -0
- genesis_forge/rl/skrl/skrl_wrapper.py +79 -0
- genesis_forge/rl/skrl/utils.py +88 -0
- genesis_forge/utils.py +88 -0
- genesis_forge/wrappers/__init__.py +5 -0
- genesis_forge/wrappers/video.py +221 -0
- genesis_forge/wrappers/wrapper.py +98 -0
- genesis_forge-0.0.1.dist-info/METADATA +23 -0
- genesis_forge-0.0.1.dist-info/RECORD +41 -0
- genesis_forge-0.0.1.dist-info/WHEEL +4 -0
- genesis_forge-0.0.1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .rl.skrl import create_skrl_env
|
|
2
|
+
from .wrappers import VideoWrapper
|
|
3
|
+
from .genesis_env import GenesisEnv, EnvMode
|
|
4
|
+
from .managed_env import ManagedEnvironment
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"GenesisEnv",
|
|
8
|
+
"ManagedEnvironment",
|
|
9
|
+
"EnvMode",
|
|
10
|
+
"VideoWrapper",
|
|
11
|
+
"create_skrl_env",
|
|
12
|
+
]
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import TypedDict
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Button(Enum):
    """
    Identifiers for the buttons a gamepad mapping entry can refer to.

    A mapping entry names one of these members under its ``"button"`` key.
    ``Gamepad.parse_data`` reports a pressed button by appending the member's
    ``name`` (e.g. ``"A"``) to ``GamepadState.buttons``.
    """

    A = 1
    B = 2
    X = 3
    Y = 4
    LB = 5
    RB = 6
    LT = 7
    RT = 8
    BACK = 9
    START = 10
    MODE = 11
    # NOTE(review): the bundled Logitech mappings report D-pad directions through
    # the separate DPad enum rather than through these four members.
    UP = 12
    DOWN = 13
    LEFT = 14
    RIGHT = 15
    LEFT_JOYSTICK = 16
    RIGHT_JOYSTICK = 17
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DPad(Enum):
    """
    D-pad directions that a gamepad mapping entry can refer to.

    A mapping entry names one of these members under its ``"dpad"`` key.
    ``Gamepad.parse_data`` stores the ``name`` of the matching member
    (e.g. ``"UP"``) in ``GamepadState.dpad``.
    """

    # NOTE(review): no bundled mapping entry uses NONE; when no entry matches,
    # parse_data leaves the state's dpad as None instead.
    NONE = 0
    UP = 1
    DOWN = 2
    LEFT = 3
    RIGHT = 4
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GamepadMapping(TypedDict, total=False):
    """
    Defines how to extract one value from a gamepad data array.

    Each entry describes a single axis, button or D-pad direction. The keys are
    optional (``total=False``) because an entry only carries the keys that apply
    to it: for example an axis entry is just ``{"axis": 0, "data": 1}``, while a
    button entry is ``{"button": Button.A, "data": 5, "bitmask": 32}``.
    ``Gamepad.parse_data`` skips entries that have no ``data`` key.
    """

    data: int
    """The index of the data in the gamepad data array."""

    button: Button
    """If this is a button, what is its name."""

    axis: int
    """If this is a joystick, what is its axis number."""

    dpad: DPad
    """D-pad direction."""

    bitmask: int
    """The bitmask to extract the value from the data."""

    value: int
    """Match this value to the (masked) value at the data index."""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class GamepadState:
    """
    Data from a gamepad.

    Attributes:
        axis_values: The value (-1 to 1) for each axis.
        buttons: A list of the names of the buttons that are pressed.
        dpad: D-pad direction, or None when no direction is active.
    """

    def __init__(self) -> None:
        # Create the containers per instance. As class-level defaults the two
        # lists would be shared by every GamepadState, and `dpad` would have no
        # value at all until the first report was parsed (breaking __repr__).
        self.axis_values: list[float] = []
        self.buttons: list[str] = []
        # Gamepad.parse_data stores the DPad member *name* here, hence `str`.
        self.dpad: DPad | str | None = None

    def axis(self, index: int) -> float:
        """
        Get the axis value at an index.
        This is the preferred way to get the axis value, because the axis array will not be filled until the gamepad
        receives input.

        Args:
            index: The index of the axis to get the value of.

        Returns:
            The value of the axis at the index, or 0.0 if there is no value at that index yet.
        """
        count = len(self.axis_values)
        # Guard both ends so a negative index on a short/empty list cannot raise.
        if not -count <= index < count:
            return 0.0
        return self.axis_values[index]

    def __repr__(self):
        return f"GamepadState(axis={self.axis_values}, buttons={self.buttons}, dpad={self.dpad})"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class GamepadConfig(TypedDict):
    """
    Defines a gamepad, how to connect to it, and how to map the buttons and axes.

    ``Gamepad`` opens the HID device with ``vendor_id``/``product_id`` and walks
    ``mapping`` for every report it receives to build the ``GamepadState``.
    """

    name: str
    """The name of the gamepad."""

    vendor_id: int
    """The vendor id of the gamepad."""

    product_id: int
    """The product id of the gamepad."""

    # One entry per axis, button or D-pad direction to extract from a report.
    mapping: list[GamepadMapping]
    """The mapping of the gamepad."""
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import argparse
|
|
3
|
+
|
|
4
|
+
from .gamepad import Gamepad, GamepadState
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DebugGamepad(Gamepad):
    """
    Gamepad variant for working out the mapping of a new controller: each raw HID
    report is echoed to the console instead of being interpreted.
    """

    def __init__(self, *args, **kwargs):
        # Nothing extra to set up; every argument is handed to Gamepad unchanged.
        super().__init__(*args, **kwargs)

    def parse_data(self, data) -> GamepadState:
        """Print the raw report and return a fresh, empty state."""
        print(data)
        # Tip: print([bin(d) for d in data]) shows the individual bits instead.
        empty_state = GamepadState()
        return empty_state
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
if __name__ == "__main__":
    # Run from the CLI with: python -m genesis_forge.gamepads.debug --help

    def _hex_id(text: str) -> int:
        """Parse a hexadecimal USB id such as ``046d`` or ``0x046d``."""
        return int(text, 16)

    parser = argparse.ArgumentParser(
        description="Print the raw HID reports of a gamepad", add_help=True
    )
    # `--vender_id` is the original (misspelled) flag, kept as an alias.
    parser.add_argument(
        "-v",
        "--vendor_id",
        "--vender_id",
        dest="vendor_id",
        type=_hex_id,
        required=True,
        help="USB vendor id of the gamepad, in hexadecimal",
    )
    parser.add_argument(
        "-p",
        "--product_id",
        type=_hex_id,
        required=True,
        help="USB product id of the gamepad, in hexadecimal",
    )
    args = parser.parse_args()

    vendor_id = args.vendor_id
    product_id = args.product_id

    gamepad = DebugGamepad(vendor_id=vendor_id, product_id=product_id)
    try:
        # The reports are printed by the gamepad's background read thread;
        # the main thread only has to stay alive.
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        # Let the read loop finish and close the device on Ctrl-C.
        gamepad.stop()
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Logitech F310/F710 Gamepad class that uses HID under the hood.
|
|
2
|
+
|
|
3
|
+
Adapted from: https://github.com/google-deepmind/mujoco_playground/blob/a873d53765a4c83572cf44fa74768ab62ceb7be1/mujoco_playground/experimental/sim2sim/gamepad_reader.py.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import argparse
|
|
8
|
+
import hid
|
|
9
|
+
import threading
|
|
10
|
+
|
|
11
|
+
from .config import GamepadConfig, GamepadState
|
|
12
|
+
from .logitech import LOGITECH_F710_CONFIG, LOGITECH_F310_CONFIG
|
|
13
|
+
|
|
14
|
+
# Gamepad configs that Gamepad.auto_connect() tries, in this order, when no
# config or vendor/product id was given.
GAMEPAD_CONFIGS = [
    LOGITECH_F710_CONFIG,
    LOGITECH_F310_CONFIG,
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Gamepad:
    """
    General gamepad controller.

    Opens the gamepad as a HID device and reads its reports on a background
    (daemon) thread, keeping ``state`` up to date.
    If neither a config nor vendor/product ids are provided, it will automatically
    attempt to connect to one of the configured gamepads in ``GAMEPAD_CONFIGS``.

    Args:
        config: The gamepad config used to connect and to parse the reports.
        vendor_id: Vendor id to connect to; overrides the one in ``config``.
        product_id: Product id to connect to; overrides the one in ``config``.
        debug: If True, print the state after every parsed report.

    Raises:
        IOError: If the gamepad cannot be opened.

    Example:
        >>> gamepad = Gamepad(config=LOGITECH_F710_CONFIG)
        >>> gamepad.state
        GamepadState(axis=[0.0, 0.0, 0.0, 0.0], buttons=['A'], dpad=UP)
        >>> gamepad.state.axis_values
        [0.0, 0.0, 0.0, 0.0]
        >>> gamepad.state.axis(0)
        0.0
        >>> gamepad.state.buttons
        ['A']
        >>> gamepad.state.dpad
        'UP'
        >>> gamepad.stop()
    """

    # The device is opened non-blocking, so an empty read returns immediately.
    # Sleep briefly between empty reads instead of spinning a CPU core.
    _IDLE_POLL_SEC = 0.001
    # Pause after a failed read so a missing device doesn't flood the console.
    _READ_ERROR_BACKOFF_SEC = 0.1

    def __init__(
        self,
        config: GamepadConfig = None,
        vendor_id=None,
        product_id=None,
        debug=False,
    ):
        self._config = config
        self._vendor_id = vendor_id
        self._product_id = product_id

        # Explicit ids win; otherwise fall back to the ids in the config.
        if vendor_id is None and config is not None:
            self._vendor_id = config["vendor_id"]
        if product_id is None and config is not None:
            self._product_id = config["product_id"]

        self._state = GamepadState()
        self._debug = debug

        self.is_running = True
        self._device = None

        self.connect()
        self.read_thread = threading.Thread(target=self.read_loop, daemon=True)
        self.read_thread.start()

    @property
    def state(self) -> GamepadState:
        """
        The current state of the gamepad.
        """
        return self._state

    def parse_data(self, data: list[int]) -> GamepadState:
        """
        Parse gamepad data into a GamepadState object.

        Args:
            data: The raw values of one report read from the device.

        Returns:
            The updated state. Without a config the data cannot be interpreted,
            and the current state is returned unchanged.
        """
        # No gamepad config, so we can't parse the data
        if self._config is None:
            return self._state

        axis: list[float] = []
        buttons: list[str] = []
        dpad = None

        for cfg in self._config["mapping"]:
            if "data" not in cfg:
                print(f"Warning: {cfg} has no data value")
                continue
            if cfg["data"] >= len(data):
                print(f"Error: {cfg} data is out of range")
                continue
            value = data[cfg["data"]]
            is_active = False

            # Apply the bitmask to the value
            if "bitmask" in cfg:
                value &= cfg["bitmask"]
                is_active = value != 0
            elif "button" in cfg or "dpad" in cfg:
                print(f"Warning: {cfg} has no bitmask value")
                continue

            # An explicit value must match exactly (this also allows matching 0)
            if "value" in cfg:
                is_active = value == cfg["value"]

            if "button" in cfg and is_active:
                buttons.append(cfg["button"].name)
            elif "dpad" in cfg and is_active:
                dpad = cfg["dpad"].name
            elif "axis" in cfg:
                # Store by index (padding with 0.0) so the result does not depend
                # on the order of the axis entries in the mapping.
                index = cfg["axis"]
                if index >= len(axis):
                    axis.extend([0.0] * (index + 1 - len(axis)))
                # Map the value 0..255 onto roughly 1..-1, with 128 as the centre.
                axis[index] = -(value - 128) / 128.0

        self._state.axis_values = axis
        self._state.buttons = buttons
        self._state.dpad = dpad
        return self._state

    def auto_connect(self):
        """
        Loop through the available gamepad configs until one connects.

        Raises:
            IOError: If none of the configured gamepads could be connected.
        """
        last_error = None
        for config in GAMEPAD_CONFIGS:
            self._vendor_id = config["vendor_id"]
            self._product_id = config["product_id"]
            self._config = config
            try:
                if self.connect():
                    return
            except Exception as e:
                # Best effort: this gamepad isn't available, try the next one.
                # (Unlike a bare `except`, this lets Ctrl-C/SystemExit through.)
                last_error = e
        raise IOError("Could not find a gamepad to connect to") from last_error

    def connect(self, vendor_id=None, product_id=None):
        """
        Attempt to connect to a gamepad.

        Args:
            vendor_id: Vendor id to connect to. Defaults to the configured one.
            product_id: Product id to connect to. Defaults to the configured one.

        Returns:
            True once a gamepad is connected.

        Raises:
            ValueError: If only one of the vendor/product ids is available.
            IOError: If the gamepad cannot be opened.
        """
        if vendor_id is None:
            vendor_id = self._vendor_id
        if product_id is None:
            product_id = self._product_id

        # If the vendor/product IDs aren't set, loop through the available gamepad configs
        if product_id is None and vendor_id is None:
            self.auto_connect()
            return True

        if product_id is None or vendor_id is None:
            raise ValueError(
                "Both vendor_id and product_id are required to connect to a gamepad"
            )

        try:
            self._device = hid.device()
            self._device.open(vendor_id, product_id)
            self._device.set_nonblocking(True)
            print(
                f"Connected to gamepad {self._device.get_manufacturer_string()} {self._device.get_product_string()}"
            )
            return True
        except IOError as e:
            raise IOError(
                f"Error connecting to gamepad 0x{vendor_id:04x}:0x{product_id:04x}: {e}"
            ) from e

    def read_loop(self):
        """
        Wait for gamepad input, and then update the gamepad state.

        Runs until ``stop()`` is called, then closes the device.
        """
        while self.is_running:
            try:
                data = self._device.read(64)
            except Exception as e:
                print(f"Error reading from device: {e}")
                time.sleep(self._READ_ERROR_BACKOFF_SEC)
                continue

            if not data:
                time.sleep(self._IDLE_POLL_SEC)
                continue

            try:
                state = self.parse_data(data)
                # Never replace the state with None (e.g. from a subclass parser
                # that returns nothing); `state` must stay a GamepadState.
                if state is not None:
                    self._state = state
                if self._debug:
                    print(self._state)
            except Exception as e:
                print(f"Error parsing data: {e}")

        self._device.close()

    def stop(self):
        """
        Stop reading gamepad input.
        """
        self.is_running = False
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
if __name__ == "__main__":
    # Manual connection test: auto-connects to a configured gamepad and prints
    # every state update (debug=True) until the process is interrupted.
    parser = argparse.ArgumentParser(
        add_help=True,
        description="Test the Gamepad connection",
    )
    args = parser.parse_args()

    gamepad = Gamepad(debug=True)
    # The read thread does the work; keep the main thread alive.
    while True:
        time.sleep(1.0)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Logitech F310/F710 Gamepad configuration."""
|
|
2
|
+
|
|
3
|
+
from .config import Button, DPad, GamepadConfig
|
|
4
|
+
|
|
5
|
+
# USB vendor id shared by both Logitech gamepad configs in this module.
VENDOR_ID = 0x046D
|
|
6
|
+
|
|
7
|
+
# Connection ids and report mapping for the F710.
# "data" is the index into the report read from the device; "bitmask" selects
# the bits of that value; "value", when present, must equal the masked value.
LOGITECH_F710_CONFIG: GamepadConfig = {
    "name": "F710",
    "vendor_id": VENDOR_ID,
    "product_id": 0xC219,
    "mapping": [
        # Joystick axes: one full data value each (indices 1-4).
        {"axis": 0, "data": 1},
        {"axis": 1, "data": 2},
        {"axis": 2, "data": 3},
        {"axis": 3, "data": 4},
        # D-pad: the low four bits of index 5 (bitmask 15) hold a direction code.
        {"dpad": DPad.UP, "data": 5, "bitmask": 15, "value": 0},
        {"dpad": DPad.DOWN, "data": 5, "bitmask": 15, "value": 4},
        {"dpad": DPad.RIGHT, "data": 5, "bitmask": 15, "value": 2},
        {"dpad": DPad.LEFT, "data": 5, "bitmask": 15, "value": 6},
        # Buttons sharing index 5: one bit each in the high four bits.
        {"button": Button.A, "data": 5, "bitmask": 32},
        {"button": Button.B, "data": 5, "bitmask": 64},
        {"button": Button.X, "data": 5, "bitmask": 16},
        {"button": Button.Y, "data": 5, "bitmask": 128},
        # Buttons at index 6: one bit each (MODE is the exception, at index 7).
        {"button": Button.LB, "data": 6, "bitmask": 1},
        {"button": Button.RB, "data": 6, "bitmask": 2},
        {"button": Button.LT, "data": 6, "bitmask": 4},
        {"button": Button.RT, "data": 6, "bitmask": 8},
        {"button": Button.BACK, "data": 6, "bitmask": 16},
        {"button": Button.START, "data": 6, "bitmask": 32},
        {"button": Button.MODE, "data": 7, "bitmask": 8},
        {"button": Button.LEFT_JOYSTICK, "data": 6, "bitmask": 64},
        {"button": Button.RIGHT_JOYSTICK, "data": 6, "bitmask": 128},
    ],
}
|
|
35
|
+
|
|
36
|
+
# Connection ids and report mapping for the F310.
# The bitmasks and values are the same as in LOGITECH_F710_CONFIG; every "data"
# index is one lower.
LOGITECH_F310_CONFIG: GamepadConfig = {
    "name": "F310",
    "vendor_id": VENDOR_ID,
    "product_id": 0xC216,
    "mapping": [
        # Joystick axes: one full data value each (indices 0-3).
        {"axis": 0, "data": 0},
        {"axis": 1, "data": 1},
        {"axis": 2, "data": 2},
        {"axis": 3, "data": 3},
        # D-pad: the low four bits of index 4 (bitmask 15) hold a direction code.
        {"dpad": DPad.UP, "data": 4, "bitmask": 15, "value": 0},
        {"dpad": DPad.DOWN, "data": 4, "bitmask": 15, "value": 4},
        {"dpad": DPad.RIGHT, "data": 4, "bitmask": 15, "value": 2},
        {"dpad": DPad.LEFT, "data": 4, "bitmask": 15, "value": 6},
        # Buttons sharing index 4: one bit each in the high four bits.
        {"button": Button.A, "data": 4, "bitmask": 32},
        {"button": Button.B, "data": 4, "bitmask": 64},
        {"button": Button.X, "data": 4, "bitmask": 16},
        {"button": Button.Y, "data": 4, "bitmask": 128},
        # Buttons at index 5: one bit each (MODE is the exception, at index 6).
        {"button": Button.LB, "data": 5, "bitmask": 1},
        {"button": Button.RB, "data": 5, "bitmask": 2},
        {"button": Button.LT, "data": 5, "bitmask": 4},
        {"button": Button.RT, "data": 5, "bitmask": 8},
        {"button": Button.BACK, "data": 5, "bitmask": 16},
        {"button": Button.START, "data": 5, "bitmask": 32},
        {"button": Button.MODE, "data": 6, "bitmask": 8},
        {"button": Button.LEFT_JOYSTICK, "data": 5, "bitmask": 64},
        {"button": Button.RIGHT_JOYSTICK, "data": 5, "bitmask": 128},
    ],
}
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base vectorized environment for Genesis simulations.
Defines GenesisEnv, the common base class that tracks actions, step counts and episode lengths.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
import torch
|
|
8
|
+
import genesis as gs
|
|
9
|
+
from gymnasium import spaces
|
|
10
|
+
from genesis.engine.entities import RigidEntity
|
|
11
|
+
from typing import Any, Literal
|
|
12
|
+
|
|
13
|
+
# The run modes accepted by GenesisEnv's `mode` argument.
EnvMode = Literal["train", "eval", "play"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GenesisEnv:
    """
    Base vectorized environment for Genesis.

    This base class only does the bookkeeping (actions, step counter, episode
    lengths, extras); ``step`` and ``reset`` return ``None`` for the
    observations/rewards/terminations, which subclasses are expected to provide.

    Args:
        num_envs: Number of parallel environments.
        dt: Simulation time step.
        max_episode_length_sec: Maximum episode length in seconds.
                                If None or not positive, `max_episode_length` stays None.
        max_episode_random_scaling: Upper bound, in seconds, of a random amount that is added to the maximum episode
                                    length of each environment when it is reset, so that not all environments reset
                                    at the same time.
        headless: Whether to run the environment in headless mode.
        mode: The run mode of the environment ("train", "eval" or "play").
        extras_logging_key: The key used, in info/extras dict, which is returned by step and reset functions, to send data to tensorboard by the RL agent.
    """

    action_space: spaces.Space | None = None
    observation_space: spaces.Space | None = None

    def __init__(
        self,
        num_envs: int = 1,
        dt: float = 1 / 100,
        max_episode_length_sec: int | None = 10,
        max_episode_random_scaling: float = 0.0,
        headless: bool = True,
        mode: EnvMode = "train",
        extras_logging_key: str = "episode",
    ):
        self.dt = dt
        self.device = gs.device
        self.num_envs = num_envs
        self.headless = headless
        self.mode = mode
        self.scene: gs.Scene = None
        self.robot: RigidEntity = None
        self.terrain: RigidEntity = None

        self.extras_logging_key = extras_logging_key
        self._extras = {}
        self._extras[extras_logging_key] = {}

        # Allocated on the first reset(), once the action space is known.
        self._actions: torch.Tensor = None
        self.last_actions: torch.Tensor = None

        self.step_count: int = 0
        self.episode_length = torch.zeros(
            (self.num_envs,), device=gs.device, dtype=torch.int32
        )
        self.max_episode_length: torch.Tensor = None

        self._max_episode_length_sec = 0.0
        self._max_episode_random_scaling = 0.0
        self._base_max_episode_length = None
        if max_episode_length_sec and max_episode_length_sec > 0:
            # Stored in steps, not seconds.
            self._max_episode_random_scaling = max_episode_random_scaling / self.dt
            self.max_episode_length = torch.zeros(
                (self.num_envs,), device=gs.device, dtype=gs.tc_int
            )
            self.max_episode_length[:] = self.set_max_episode_length(
                max_episode_length_sec
            )

    #
    # Properties
    #

    @property
    def unwrapped(self):
        """Returns the base non-wrapped environment.

        Returns:
            Env: The base non-wrapped :class:`GenesisEnv` instance
        """
        return self

    @property
    def max_episode_length_sec(self) -> int | None:
        """The max episode length, in seconds, for each environment."""
        return self._max_episode_length_sec

    @property
    def extras(self) -> dict:
        """
        The extras/infos dictionary that should be returned by the step and reset functions.
        This dictionary will be cleared at the start of every step.
        """
        return self._extras

    @property
    def actions(self) -> torch.Tensor:
        """The current actions for each environment for this step."""
        return self._actions

    @actions.setter
    def actions(self, actions: torch.Tensor):
        """Set the actions for each environment for this step."""
        self._actions = actions

    #
    # Utilities
    #

    def set_max_episode_length(self, max_episode_length_sec: int) -> int:
        """
        Set or change the maximum episode length.

        The per-environment `max_episode_length` values are all set to the new length
        (any random extension is applied again the next time an environment is reset).

        Args:
            max_episode_length_sec: The maximum episode length in seconds.

        Returns:
            The maximum episode length in steps.
        """
        self._max_episode_length_sec = max_episode_length_sec
        self._base_max_episode_length = math.ceil(max_episode_length_sec / self.dt)

        # Keep the per-environment limits in sync, so that changing the length
        # after construction actually takes effect.
        if self.max_episode_length is None:
            self.max_episode_length = torch.zeros(
                (self.num_envs,), device=gs.device, dtype=gs.tc_int
            )
        self.max_episode_length[:] = self._base_max_episode_length
        return self._base_max_episode_length

    #
    # Operations
    #

    def build(self) -> None:
        """
        Builds the scene and other supporting components necessary for the training environment.
        This assumes that the scene has already been constructed and assigned to the <env>.scene attribute.
        """
        assert (
            self.scene is not None
        ), "The scene must be constructed and assigned to the <env>.scene attribute before building."
        self.scene.build(n_envs=self.num_envs)

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """
        Take an action for each parallel environment.

        Args:
            actions: Batch of actions with the :attr:`action_space` shape.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos/extras)
        """
        self._extras = {}
        self._extras[self.extras_logging_key] = {}
        self.step_count += 1
        self.episode_length += 1

        # The buffers are normally created by reset(); create them here if step()
        # is called first, instead of failing on `None[:]`.
        if self._actions is None:
            self._actions = torch.zeros_like(actions)
        if self.last_actions is None:
            self.last_actions = torch.zeros_like(actions)

        self.last_actions[:] = self.actions[:]
        # NOTE(review): this keeps a reference to the caller's tensor (no copy),
        # so reset() zeroes rows of that same tensor in place.
        self._actions = actions

        return None, None, None, None, self._extras

    def reset(
        self,
        envs_idx: list[int] | torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        Reset one or all parallel environments.

        Args:
            envs_idx: The environment ids to reset. If None, all environments are reset.

        Returns:
            A batch of observations and info from the vectorized environment.
        """
        if envs_idx is None:
            envs_idx = torch.arange(self.num_envs, device=gs.device)
        elif not isinstance(envs_idx, torch.Tensor):
            # Accept plain lists/tuples of ids, as the signature advertises.
            envs_idx = torch.as_tensor(envs_idx, device=gs.device, dtype=torch.long)
        reset_count = envs_idx.numel()

        # Initial reset, set buffers
        if self.step_count == 0 or self._actions is None:
            self.actions = torch.zeros(
                (self.num_envs, self.action_space.shape[0]),
                device=gs.device,
                dtype=gs.tc_float,
            )
            self.last_actions = torch.zeros_like(self.actions, device=gs.device)

        # Actions
        if reset_count > 0:
            self.actions[envs_idx] = 0.0
            self.last_actions[envs_idx] = 0.0

        # Episode length
        self.episode_length[envs_idx] = 0

        # Randomize max episode length for env_ids
        if (
            reset_count > 0
            and self._max_episode_random_scaling > 0.0
            and self._base_max_episode_length
        ):
            # Created on the same device as the buffer it is written into.
            scale = (
                torch.rand((reset_count,), device=gs.device)
                * self._max_episode_random_scaling
            )
            self.max_episode_length[envs_idx] = torch.round(
                self._base_max_episode_length + scale
            ).to(gs.tc_int)

        return None, self.extras

    def close(self):
        """Close the environment."""
        pass

    def render(self):
        """Not implemented."""
        pass
|