egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
egogym/evaluate.py ADDED
@@ -0,0 +1,216 @@
1
+ import os
2
+ import time
3
+ import cv2
4
+ import gymnasium as gym
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from itertools import cycle, islice
8
+ import multiprocessing
9
+ import json
10
+
11
+ import egogym
12
+ from egogym.wrappers import EpisodeMonitor, get_vlm_wrapper
13
+ import egogym.assets.constants as constants
14
+
15
def _resolve_objects_set(task_name, objects_set):
    """Resolve the ``objects_set`` argument of ``evaluate_policy``.

    Args:
        task_name: Lower-cased task name (e.g. ``"pick"`` or ``"open"``).
        objects_set: ``None`` to use the task default, the name of a preset
            from ``egogym.assets.constants`` (``all_pick``, ``lite_pick``,
            ``diverse_pick``, ``full_eval``, ``all_open``, ``all_close``,
            ``cabinet``, ``drawer``), or an explicit collection of objects.

    Returns:
        The object collection to evaluate on. Returns ``None`` when
        ``objects_set`` is ``None`` and the task has no default; the
        environment then receives ``None``.

    Raises:
        ValueError: If ``objects_set`` is a string that names no known preset.
    """
    if objects_set is None:
        if task_name == "pick":
            return constants.lite_pick_objects_set
        if task_name == "open":
            return constants.all_open_objects_set
        return None
    if not isinstance(objects_set, str):
        # An explicit collection is used as-is. (Looking an unhashable list up
        # in the preset table below would raise TypeError.)
        return objects_set
    objects_set_options = {
        'all_pick': constants.all_pick_objects_set,
        'lite_pick': constants.lite_pick_objects_set,
        'diverse_pick': constants.diverse_pick_objects_set,
        'full_eval': constants.full_eval_objects_set,
        'all_open': constants.all_open_objects_set,
        "all_close": constants.all_close_objects_set,
        'cabinet': constants.cabinet_objects_set,
        'drawer': constants.drawer_objects_set,
    }
    if objects_set in objects_set_options:
        return objects_set_options[objects_set]
    raise ValueError(f"Unknown objects_set: {objects_set}")


def _build_env(task_name, robot, action_space, render, render_size, num_objs,
               use_unprivileged_vlm, seed, objects_set):
    """Create one ``Egogym-<Task>-v0`` environment for ``evaluate_policy``.

    When ``use_unprivileged_vlm`` is not None, the environment is wrapped with
    the class returned by ``get_vlm_wrapper(use_unprivileged_vlm)``.
    """
    env = gym.make(
        f"Egogym-{task_name.capitalize()}-v0",
        robot=robot,
        action_space=action_space,
        render_mode='human' if render else None,
        render_size=render_size,
        num_objs=num_objs,
        seed=seed,
        objects_set=objects_set,
    )
    if use_unprivileged_vlm is not None:
        wrapper_cls = get_vlm_wrapper(use_unprivileged_vlm)
        env = wrapper_cls(env)
    return env


def evaluate_policy(
    task_name="pick",
    policy=None,
    robot="rum",
    action_space="delta",
    num_objs=1,
    num_episodes=100,
    max_steps=50,
    render=False,
    record=False,
    render_freq=0,
    render_size=(299, 224),
    num_envs=1,
    seed=42,
    log_file=None,
    objects_set=None,
    reward_threshold=None,
    use_unprivileged_vlm=None,
    logs_dir="logs"
):
    """Roll out ``policy`` in an Egogym task and return the success rate in percent.

    An episode counts as a success as soon as a step reward exceeds
    ``reward_threshold``. An episode ends on success, on termination or
    truncation, or after ``max_steps`` steps.

    Args:
        task_name: Task name, case-insensitive; selects ``Egogym-<Task>-v0``.
        policy: Object with ``get_action(obs)``, which returns batched actions
            (the single-env path uses ``action[0]``), and ``reset``. ``reset``
            is called with no argument in the single-env path and with a list
            of env indices in the vectorized path.
        num_episodes: Total number of episodes to evaluate.
        num_envs: Number of parallel environments. The value is clamped to
            ``[1, num_episodes]``; more than one uses ``AsyncVectorEnv`` with
            the ``spawn`` start method.
        render_freq: Forced to 0 when more than one environment is used.
        objects_set: ``None`` for the task default, a preset name, or an
            explicit collection (see ``_resolve_objects_set``).
        reward_threshold: Success threshold. Defaults to 0.05 for ``pick`` and
            0.5 for ``open``; it must be given explicitly for any other task.
        logs_dir, log_file: Logs and ``eval_args.json`` are written under
            ``{logs_dir}/{log_file}``. ``log_file`` defaults to
            ``evaluation_<unix time>``.

    Returns:
        Success rate as a percentage of ``num_episodes``. Returns 0.0 when
        ``num_episodes <= 0``.

    Raises:
        ValueError: On an unknown preset name, a missing ``reward_threshold``
            for a task without a default, or a missing ``objects_set`` when
            running several environments.
    """
    num_envs = max(1, min(num_envs, num_episodes))
    if num_envs > 1 and render_freq > 0:
        print(f"Warning: render_freq={render_freq} is not supported with multiple environments. Setting render_freq=0.")
        render_freq = 0
    task_name = task_name.lower()
    objects_set = _resolve_objects_set(task_name, objects_set)
    if num_envs > 1 and objects_set is None:
        # Episodes are distributed across envs by cycling over objects_set.
        raise ValueError(
            f"objects_set must be provided for task '{task_name}' when num_envs > 1"
        )

    if reward_threshold is None:
        if task_name == "pick":
            reward_threshold = 0.05
        elif task_name == "open":
            reward_threshold = 0.5
        else:
            # Success is decided by `reward > reward_threshold`; comparing
            # against None would raise TypeError on the first step.
            raise ValueError(
                f"No default reward_threshold for task '{task_name}'; "
                "pass reward_threshold explicitly."
            )

    if not log_file:
        log_file = f"evaluation_{int(time.time())}"

    logs_dir = f"{logs_dir}/{log_file}"
    os.makedirs(logs_dir, exist_ok=True)

    eval_args = {
        "task_name": task_name,
        "robot": robot,
        "action_space": action_space,
        "num_objs": num_objs,
        "num_episodes": num_episodes,
        "max_steps": max_steps,
        "render": render,
        "record": record,
        "render_freq": render_freq,
        "render_size": render_size,
        "num_envs": num_envs,
        "seed": seed,
        "log_file": log_file,
        "objects_set": objects_set,
        "reward_threshold": reward_threshold,
    }

    with open(os.path.join(logs_dir, "eval_args.json"), "w") as f:
        json.dump(eval_args, f, indent=4)

    env_kwargs = {
        "task_name": task_name,
        "robot": robot,
        "action_space": action_space,
        "render": render,
        "render_size": render_size,
        "num_objs": num_objs,
        "use_unprivileged_vlm": use_unprivileged_vlm,
    }

    if num_envs > 1:
        multiprocessing.set_start_method("spawn", force=True)

        # Round-robin split: env i evaluates expanded_objects[i], [i + num_envs], ...
        expanded_objects = list(islice(cycle(objects_set), num_episodes))
        env_objects = [expanded_objects[i::num_envs] for i in range(num_envs)]

        def make_env(env_id):
            def _init():
                return _build_env(
                    seed=seed + env_id, objects_set=env_objects[env_id], **env_kwargs
                )
            return _init

        env = gym.vector.AsyncVectorEnv(
            [make_env(i) for i in range(num_envs)],
            autoreset_mode=gym.vector.AutoresetMode.DISABLED
        )
        try:
            env = EpisodeMonitor(env, logs_dir=logs_dir, record=record, render_freq=render_freq, num_envs=num_envs)
            obs, _ = env.reset()
            env_episode_count = np.zeros(num_envs, dtype=int)
            env_step_count = np.zeros(num_envs, dtype=int)
            env_success_count = 0
            episodes_per_env = np.array([len(objs) for objs in env_objects])

            pbar = tqdm(total=num_episodes, desc="Evaluating | Success Rate: 0.0%")

            while np.any(env_episode_count < episodes_per_env):
                actions = policy.get_action(obs)
                # Envs that already finished their share of episodes get a None action.
                inactive = env_episode_count >= episodes_per_env
                actions[inactive] = None

                obs, rewards, terminated, truncated, _ = env.step(actions)
                env_step_count += 1

                if render:
                    ego_exo = np.concatenate([obs["rgb_ego"], obs["rgb_exo"]], axis=2)
                    combined_view = np.concatenate(ego_exo, axis=0)
                    cv2.imshow("Ego View", cv2.cvtColor(combined_view, cv2.COLOR_RGB2BGR))
                    cv2.waitKey(1)

                # Autoreset is disabled, so episodes are ended here (success,
                # termination, truncation, or step budget) and reset explicitly.
                dones = (
                    (rewards > reward_threshold)
                    | terminated
                    | truncated
                    | (env_step_count >= max_steps)
                )
                done_indices = np.where(dones)[0]
                active_done = done_indices[env_episode_count[done_indices] < episodes_per_env[done_indices]]

                if len(active_done) > 0:
                    env.log_episodes(active_done)

                    successes = rewards[active_done] > reward_threshold
                    env_success_count += np.sum(successes)

                    env_episode_count[active_done] += 1
                    pbar.update(len(active_done))

                    total_completed = np.sum(env_episode_count)
                    success_rate = (env_success_count / total_completed) * 100 if total_completed > 0 else 0
                    pbar.set_description(f"Evaluating | Success Rate: {success_rate:.1f}%")

                    policy.reset(active_done.tolist())

                    reset_mask = np.zeros(num_envs, dtype=bool)
                    reset_mask[active_done] = True
                    obs, _ = env.reset(options={"reset_mask": reset_mask})
                    env_step_count[active_done] = 0

            pbar.close()
        finally:
            # Shut down the worker processes even if the policy or env raises.
            env.close()
        return (env_success_count / num_episodes) * 100

    env = _build_env(seed=seed, objects_set=objects_set, **env_kwargs)
    try:
        env = EpisodeMonitor(env, logs_dir=logs_dir, record=record, render_freq=render_freq, num_envs=num_envs)

        success_count = 0
        pbar = tqdm(range(num_episodes), desc="Evaluating | Success Rate: 0.0%")
        for episode_idx in pbar:
            obs, _ = env.reset()
            policy.reset()

            for _ in range(max_steps):
                action = policy.get_action(obs)
                obs, reward, terminated, truncated, _ = env.step(action[0])

                if reward > reward_threshold:
                    success_count += 1
                    break
                if terminated or truncated:
                    break

            success_rate = (success_count / (episode_idx + 1)) * 100
            pbar.set_description(f"Evaluating | Success Rate: {success_rate:.1f}%")
    finally:
        env.close()

    # With num_episodes <= 0 no episode ran; report 0% instead of dividing by zero.
    return (success_count / num_episodes) * 100 if num_episodes > 0 else 0.0
egogym/managers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .objects_managers import ObjectsManager
2
+ from .textures_manager import TexturesManager
egogym/managers/objects_managers.py ADDED
@@ -0,0 +1,30 @@
1
+ from egogym.utils import make_cabinet_xml, make_drawer_xml
2
+ from egogym.components import Object, ArticulableObject
3
+
4
+ class ObjectsManager:
5
+ def __init__(self, objects_dict, np_random, shuffle=True):
6
+ self.np_random = np_random
7
+ sorted_objects = sorted(objects_dict, key=lambda x: x[0])
8
+ if shuffle:
9
+ np_random.shuffle(sorted_objects)
10
+ self.objects_dict = sorted_objects
11
+ self.index = -1
12
+
13
+ def sample(self, random=False):
14
+ if random:
15
+ index = self.np_random.integers(len(self.objects_dict))
16
+ else:
17
+ self.index = (self.index + 1) % len(self.objects_dict)
18
+ index = self.index
19
+ if self.objects_dict[index][1] == "procedurally_generated":
20
+ return ArticulableObject(
21
+ name=self.objects_dict[index][0],
22
+ func=globals()[f"make_{self.objects_dict[index][0]}_xml"],
23
+ np_random=self.np_random,
24
+ )
25
+ else:
26
+ return Object(
27
+ name=self.objects_dict[index][0],
28
+ path=self.objects_dict[index][1],
29
+ np_random=self.np_random,
30
+ )
egogym/managers/textures_manager.py ADDED
@@ -0,0 +1,21 @@
1
+ import numpy as np
2
+
3
+ from egogym.components import Object
4
+
5
+
6
+ class TexturesManager:
7
+ def __init__(self, textures_list, np_random):
8
+ self.textures_list = textures_list
9
+ self.np_random = np_random
10
+ sorted_textures = sorted(textures_list)
11
+ self.np_random.shuffle(sorted_textures)
12
+ self.textures_list = sorted_textures
13
+ self.index = -1
14
+
15
+ def sample(self, random=False):
16
+ if random:
17
+ index = self.np_random.integers(len(self.textures_list))
18
+ else:
19
+ self.index = (self.index + 1) % len(self.textures_list)
20
+ index = self.index
21
+ return self.textures_list[index]
egogym/misc/molmo_client.py ADDED
@@ -0,0 +1,49 @@
1
+ import json
2
+ import numpy as np
3
+ import websockets
4
+
5
+ class MolmoClient:
6
+ def __init__(self, host="localhost", port=8765):
7
+ self.uri = f"ws://{host}:{port}"
8
+ self.websocket = None
9
+
10
+ async def connect(self):
11
+ self.websocket = await websockets.connect(
12
+ self.uri,
13
+ max_size=50 * 1024 * 1024,
14
+ )
15
+
16
+ async def disconnect(self):
17
+ if self.websocket:
18
+ await self.websocket.close()
19
+
20
+ async def infer_point(self, rgb: np.ndarray, object_name: str = None, prompt: str = None):
21
+ if not self.websocket:
22
+ raise RuntimeError("Not connected to server. Call connect() first.")
23
+
24
+ req = {
25
+ "action": "infer_point",
26
+ "rgb": rgb.tolist(),
27
+ }
28
+
29
+ if object_name:
30
+ req["object_name"] = object_name
31
+ elif prompt:
32
+ req["prompt"] = prompt
33
+ else:
34
+ raise ValueError("Provide either 'object_name' or 'prompt'")
35
+
36
+ await self.websocket.send(json.dumps(req))
37
+ resp = json.loads(await self.websocket.recv())
38
+
39
+ if resp["status"] == "error":
40
+ raise RuntimeError(f"Server error: {resp['message']}")
41
+
42
+ return np.array(resp["point"], dtype=np.float32)
43
+
44
+ async def __aenter__(self):
45
+ await self.connect()
46
+ return self
47
+
48
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
49
+ await self.disconnect()
egogym/misc/molmo_server.py ADDED
@@ -0,0 +1,197 @@
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import logging
5
+ import re
6
+ import numpy as np
7
+ import torch
8
+ import websockets
9
+ from PIL import Image
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ def extract_molmo_points(molmo_output: str):
16
+ points = []
17
+ for match in re.finditer(
18
+ r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"',
19
+ molmo_output,
20
+ ):
21
+ try:
22
+ p = np.array([float(match.group(1)), float(match.group(2))], dtype=np.float32)
23
+ except ValueError:
24
+ continue
25
+
26
+ if np.max(p) > 100:
27
+ continue
28
+
29
+ points.append(p / 100.0)
30
+
31
+ return points
32
+
33
+
34
+ class Molmo:
35
+ def __init__(self, model_name="allenai/Molmo-7B-D-0924"):
36
+ try:
37
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
38
+ except ImportError:
39
+ raise ImportError("transformers package is not installed. Please install it with `pip install transformers`.")
40
+
41
+ self.processor = AutoProcessor.from_pretrained(
42
+ model_name,
43
+ trust_remote_code=True,
44
+ torch_dtype="auto",
45
+ device_map="auto",
46
+ )
47
+ self.model = AutoModelForCausalLM.from_pretrained(
48
+ model_name,
49
+ trust_remote_code=True,
50
+ torch_dtype="auto",
51
+ device_map="auto",
52
+ )
53
+
54
+ def infer(self, rgb: np.ndarray, prompt: str) -> str:
55
+ image = Image.fromarray(rgb)
56
+ inputs = self.processor.process(images=[image], text=prompt)
57
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
58
+
59
+ with torch.inference_mode():
60
+ output = self.model.generate_from_batch(
61
+ inputs,
62
+ GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
63
+ tokenizer=self.processor.tokenizer,
64
+ )
65
+
66
+ gen_tokens = output[0, inputs["input_ids"].size(1):]
67
+ return self.processor.tokenizer.decode(gen_tokens, skip_special_tokens=True)
68
+
69
+ def infer_point(self, rgb: np.ndarray, prompt: str) -> np.ndarray:
70
+ text = self.infer(rgb, prompt)
71
+ log.info(f"Molmo output: {text}")
72
+ points = extract_molmo_points(text)
73
+ return points[0] if points else np.array([0.5, 0.5], dtype=np.float32)
74
+
75
+
76
+ class MolmoClient:
77
+ def __init__(self, host="localhost", port=8765):
78
+ self.uri = f"ws://{host}:{port}"
79
+ self.websocket = None
80
+
81
+ async def connect(self):
82
+ self.websocket = await websockets.connect(
83
+ self.uri,
84
+ max_size=50 * 1024 * 1024,
85
+ )
86
+ log.info(f"Connected to Molmo server at {self.uri}")
87
+
88
+ async def disconnect(self):
89
+ if self.websocket:
90
+ await self.websocket.close()
91
+ log.info("Disconnected from Molmo server")
92
+
93
+ async def infer_point(self, rgb: np.ndarray, object_name: str = None, prompt: str = None):
94
+ if not self.websocket:
95
+ raise RuntimeError("Not connected to server. Call connect() first.")
96
+
97
+ req = {
98
+ "action": "infer_point",
99
+ "rgb": rgb.tolist(),
100
+ }
101
+
102
+ if object_name:
103
+ req["object_name"] = object_name
104
+ elif prompt:
105
+ req["prompt"] = prompt
106
+ else:
107
+ raise ValueError("Provide either 'object_name' or 'prompt'")
108
+
109
+ await self.websocket.send(json.dumps(req))
110
+ resp = json.loads(await self.websocket.recv())
111
+
112
+ if resp["status"] == "error":
113
+ raise RuntimeError(f"Server error: {resp['message']}")
114
+
115
+ return np.array(resp["point"], dtype=np.float32)
116
+
117
+ async def __aenter__(self):
118
+ await self.connect()
119
+ return self
120
+
121
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
122
+ await self.disconnect()
123
+
124
+
125
+ class MolmoWebSocketServer:
126
+ def __init__(self, host="0.0.0.0", port=8765):
127
+ self.host = host
128
+ self.port = port
129
+ self.molmo = Molmo()
130
+
131
+ async def handle_client(self, websocket):
132
+ client_id = id(websocket)
133
+ log.info(f"Client connected: {client_id}")
134
+
135
+ try:
136
+ async for msg in websocket:
137
+ try:
138
+ req = json.loads(msg)
139
+ action = req.get("action")
140
+
141
+ if action != "infer_point":
142
+ raise ValueError("Only 'infer_point' action is supported")
143
+
144
+ rgb = np.array(req["rgb"], dtype=np.uint8)
145
+
146
+ if "object_name" in req:
147
+ prompt = f"Point to the center of the {req['object_name']}."
148
+ label = req["object_name"].replace(" ", "_")
149
+ elif "prompt" in req:
150
+ prompt = req["prompt"]
151
+ label = "custom_prompt"
152
+ else:
153
+ raise ValueError("Provide either 'object_name' or 'prompt'")
154
+
155
+ point = self.molmo.infer_point(rgb, prompt)
156
+
157
+ resp = {
158
+ "status": "ok",
159
+ "action": "infer_point",
160
+ "point": point.tolist(),
161
+ }
162
+
163
+ except Exception as e:
164
+ log.exception("Request error")
165
+ resp = {"status": "error", "message": str(e)}
166
+
167
+ await websocket.send(json.dumps(resp))
168
+
169
+ except websockets.exceptions.ConnectionClosed:
170
+ log.info(f"Client disconnected: {client_id}")
171
+
172
+ async def start(self):
173
+ log.info(f"Starting Molmo server on ws://{self.host}:{self.port}")
174
+ async with websockets.serve(
175
+ self.handle_client,
176
+ self.host,
177
+ self.port,
178
+ max_size=50 * 1024 * 1024,
179
+ ):
180
+ await asyncio.Future()
181
+
182
+
183
+ def main():
184
+ parser = argparse.ArgumentParser("Molmo Pointing Server")
185
+ parser.add_argument("--host", default="0.0.0.0")
186
+ parser.add_argument("--port", type=int, default=8765)
187
+ args = parser.parse_args()
188
+
189
+ server = MolmoWebSocketServer(
190
+ host=args.host,
191
+ port=args.port,
192
+ )
193
+ asyncio.run(server.start())
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
egogym/policies/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .base_policy import BasePolicy
egogym/policies/base_policy.py ADDED
@@ -0,0 +1,13 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ class BasePolicy(ABC):
4
+
5
+ def __init__(self):
6
+ self.name = "base"
7
+
8
+ @abstractmethod
9
+ def get_action(self, obs):
10
+ pass
11
+
12
+ def reset(self):
13
+ pass