gr-libs 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +277 -0
  2. evaluation/create_minigrid_map_image.py +34 -0
  3. evaluation/file_system.py +42 -0
  4. evaluation/generate_experiments_results.py +92 -0
  5. evaluation/generate_experiments_results_new_ver1.py +254 -0
  6. evaluation/generate_experiments_results_new_ver2.py +331 -0
  7. evaluation/generate_task_specific_statistics_plots.py +272 -0
  8. evaluation/get_plans_images.py +47 -0
  9. evaluation/increasing_and_decreasing_.py +63 -0
  10. gr_libs/__init__.py +2 -0
  11. gr_libs/environment/__init__.py +0 -0
  12. gr_libs/environment/environment.py +227 -0
  13. gr_libs/environment/utils/__init__.py +0 -0
  14. gr_libs/environment/utils/utils.py +17 -0
  15. gr_libs/metrics/__init__.py +0 -0
  16. gr_libs/metrics/metrics.py +224 -0
  17. gr_libs/ml/__init__.py +6 -0
  18. gr_libs/ml/agent.py +56 -0
  19. gr_libs/ml/base/__init__.py +1 -0
  20. gr_libs/ml/base/rl_agent.py +54 -0
  21. gr_libs/ml/consts.py +22 -0
  22. gr_libs/ml/neural/__init__.py +3 -0
  23. gr_libs/ml/neural/deep_rl_learner.py +395 -0
  24. gr_libs/ml/neural/utils/__init__.py +2 -0
  25. gr_libs/ml/neural/utils/dictlist.py +33 -0
  26. gr_libs/ml/neural/utils/penv.py +57 -0
  27. gr_libs/ml/planner/__init__.py +0 -0
  28. gr_libs/ml/planner/mcts/__init__.py +0 -0
  29. gr_libs/ml/planner/mcts/mcts_model.py +330 -0
  30. gr_libs/ml/planner/mcts/utils/__init__.py +2 -0
  31. gr_libs/ml/planner/mcts/utils/node.py +33 -0
  32. gr_libs/ml/planner/mcts/utils/tree.py +102 -0
  33. gr_libs/ml/sequential/__init__.py +1 -0
  34. gr_libs/ml/sequential/lstm_model.py +192 -0
  35. gr_libs/ml/tabular/__init__.py +3 -0
  36. gr_libs/ml/tabular/state.py +21 -0
  37. gr_libs/ml/tabular/tabular_q_learner.py +453 -0
  38. gr_libs/ml/tabular/tabular_rl_agent.py +126 -0
  39. gr_libs/ml/utils/__init__.py +6 -0
  40. gr_libs/ml/utils/env.py +7 -0
  41. gr_libs/ml/utils/format.py +100 -0
  42. gr_libs/ml/utils/math.py +13 -0
  43. gr_libs/ml/utils/other.py +24 -0
  44. gr_libs/ml/utils/storage.py +127 -0
  45. gr_libs/recognizer/__init__.py +0 -0
  46. gr_libs/recognizer/gr_as_rl/__init__.py +0 -0
  47. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +102 -0
  48. gr_libs/recognizer/graml/__init__.py +0 -0
  49. gr_libs/recognizer/graml/gr_dataset.py +134 -0
  50. gr_libs/recognizer/graml/graml_recognizer.py +266 -0
  51. gr_libs/recognizer/recognizer.py +46 -0
  52. gr_libs/recognizer/utils/__init__.py +1 -0
  53. gr_libs/recognizer/utils/format.py +13 -0
  54. gr_libs-0.1.3.dist-info/METADATA +197 -0
  55. gr_libs-0.1.3.dist-info/RECORD +62 -0
  56. gr_libs-0.1.3.dist-info/WHEEL +5 -0
  57. gr_libs-0.1.3.dist-info/top_level.txt +3 -0
  58. tutorials/graml_minigrid_tutorial.py +30 -0
  59. tutorials/graml_panda_tutorial.py +32 -0
  60. tutorials/graml_parking_tutorial.py +38 -0
  61. tutorials/graml_point_maze_tutorial.py +43 -0
  62. tutorials/graql_minigrid_tutorial.py +29 -0
@@ -0,0 +1,21 @@
+ from abc import ABC
+
+
+ class TabularState(ABC):
+     def __init__(self,
+                  agent_x_position: int,
+                  agent_y_position: int,
+                  agent_direction: int
+                  ):
+         self._agent_x_position = agent_x_position
+         self._agent_y_position = agent_y_position
+         self._agent_direction = agent_direction
+
+     @staticmethod
+     def gen_tabular_state(environment, observation):
+         x, y = environment.unwrapped.agent_pos
+         direction = observation['direction']
+         return TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+
+     def __str__(self):
+         return f"({self._agent_x_position},{self._agent_y_position}):{self._agent_direction}"
@@ -0,0 +1,453 @@
+ # Don't import stuff from metrics! it's a higher level module.
+ import os.path
+ import pickle
+ import random
+ from types import MethodType
+
+ import dill
+ from gymnasium import register
+ import numpy as np
+
+ from tqdm import tqdm
+ from typing import Any
+ from random import Random
+ from typing import List, Iterable
+ from gymnasium.error import InvalidAction
+ from gr_libs.environment.environment import QLEARNING, MinigridProperty
+ from gr_libs.ml.tabular import TabularState
+ from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
+ from gr_libs.ml.utils import get_agent_model_dir, random_subset_with_order, softmax
+
+
+ class TabularQLearner(TabularRLAgent):
+     """
+     A simple Tabular Q-Learning agent.
+     """
+
+     MODEL_FILE_NAME = r"tabular_model.txt"
+     CONF_FILE = r"conf.pkl"
+
+     def __init__(self,
+                  domain_name: str,
+                  problem_name: str,
+                  algorithm: str,
+                  num_timesteps: int,
+                  decaying_eps: bool = True,
+                  eps: float = 1.0,
+                  alpha: float = 0.5,
+                  decay: float = 0.000002,
+                  gamma: float = 0.9,
+                  rand: Random = Random(),
+                  learning_rate: float = 0.001,
+                  check_partial_goals: bool = True,
+                  valid_only: bool = False
+                  ):
+         super().__init__(
+             domain_name=domain_name,
+             problem_name=problem_name,
+             episodes=num_timesteps,
+             decaying_eps=decaying_eps,
+             eps=eps,
+             alpha=alpha,
+             decay=decay,
+             gamma=gamma,
+             rand=rand,
+             learning_rate=learning_rate
+         )
+         assert algorithm == QLEARNING, f"algorithm {algorithm} is not supported by {self.__class__.__name__}"
+         self.valid_only = valid_only
+         self.check_partial_goals = check_partial_goals
+         self.goal_literals_achieved = set()
+         self.model_directory = get_agent_model_dir(domain_name=domain_name, model_name=problem_name, class_name=self.class_name())
+         self.model_file_path = os.path.join(self.model_directory, TabularQLearner.MODEL_FILE_NAME)
+         self._conf_file = os.path.join(self.model_directory, TabularQLearner.CONF_FILE)
+
+         self._learned_episodes = 0
+
+         if os.path.exists(self.model_file_path):
+             print(f"Loading pre-existing model in {self.model_file_path}")
+             self.load_q_table(path=self.model_file_path)
+         else:
+             print(f"Creating new model in {self.model_file_path}")
+         if os.path.exists(self._conf_file):
+             print(f"Loading pre-existing conf file in {self._conf_file}")
+             with open(self._conf_file, "rb") as f:
+                 conf = dill.load(file=f)
+                 self._learned_episodes = conf['learned_episodes']
+
+         # hyperparameters
+         self.base_eps = eps
+         self.patience = 400000
+         if self.decaying_eps:
+             def epsilon():
+                 self._c_eps = max((self.episodes - self.step) / self.episodes, 0.01)
+                 return self._c_eps
+
+             self.eps = epsilon
+         else:
+             self.eps = lambda: eps
+         self.decaying_eps = decaying_eps
+         self.alpha = alpha
+         self.last_state = None
+         self.last_action = None
+
+     def states_in_q(self) -> Iterable:
+         """Returns the states stored in the q_values table
+
+         Returns:
+             List: The states for which we have a mapping in the q-table
+         """
+         return self.q_table.keys()
+
+     def policy(self, state: TabularState) -> Any:
+         """Returns the greedy deterministic policy for the specified state
+
+         Args:
+             state (TabularState): the state for which we want the action
+
+         Note:
+             This only reads the learned Q-values; it does not update them.
+
+         Returns:
+             Any: The greedy action learned for state
+         """
+         return self.best_action(state)
+
+     def epsilon_greedy_policy(self, state: TabularState) -> Any:
+         eps = self.eps()
+         if self._random.random() <= eps:
+             action = self._random.randint(0, self.number_of_actions - 1)
+         else:
+             action = self.policy(state)
+         return action
+
+     def softmax_policy(self, state: TabularState) -> np.ndarray:
+         """Returns a softmax policy over the Q-values stored in the q-table
+
+         Args:
+             state (State): the state for which we want a softmax policy
+
+         Returns:
+             np.ndarray: probability of taking each action in self.actions given a state
+         """
+         if str(state) not in self.q_table:
+             self.add_new_state(state)
+             # If we query a state we have not visited, return a uniform distribution
+             # return softmax([0]*self.actions)
+         return softmax(self.q_table[str(state)])
+
+     def save_q_table(self, path: str):
+         # Note: this works because states are stored under their string form.
+         # A raw frozenset-of-literals state would not pickle cleanly; if that
+         # representation is ever used, build array states instead
+         # (e.g. via common_functions.build_state).
+
+         directory = os.path.dirname(path)
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+         with open(path, 'wb') as f:
+             pickle.dump(self.q_table, f)
+
+     def load_q_table(self, path: str):
+         with open(path, 'rb') as f:
+             table = pickle.load(f)
+         self.q_table = table
+
+     def add_new_state(self, state: TabularState):
+         self.q_table[str(state)] = [0.] * self.number_of_actions
+
+     def get_all_q_values(self, state: TabularState) -> List[float]:
+         if str(state) in self.q_table:
+             return self.q_table[str(state)]
+         else:
+             return [0.] * self.number_of_actions
+
+     def best_action(self, state: TabularState) -> int:
+         if str(state) not in self.q_table:
+             self.add_new_state(state)
+         return np.argmax(self.q_table[str(state)])
+
+     def get_max_q(self, state) -> float:
+         if str(state) not in self.q_table:
+             self.add_new_state(state)
+         return np.max(self.q_table[str(state)])
+
+     def set_q_value(self, state: TabularState, action: Any, q_value: float):
+         if str(state) not in self.q_table:
+             self.add_new_state(state)
+         self.q_table[str(state)][action] = q_value
+
+     def get_q_value(self, state: TabularState, action: Any) -> float:
+         if str(state) not in self.q_table:
+             self.add_new_state(state)
+         return self.q_table[str(state)][action]
+
+     def agent_start(self, state: TabularState) -> int:
+         """The first method called when the experiment starts,
+         called after the environment is reset.
+         Args:
+             state (TabularState): the initial state returned by the
+                 environment's reset.
+         Returns:
+             (int) the first action the agent takes.
+         """
+         self.last_state = state
+         self.last_action = self.policy(state)
+         return self.last_action
+
+     def agent_step(self, reward: float, state: TabularState) -> int:
+         """A step taken by the agent.
+
+         Args:
+             reward (float): the reward received for the last action taken
+             state (TabularState): the state from the environment's step,
+                 i.e. where the agent ended up after the
+                 last step
+         Returns:
+             (int) The action the agent takes given this state.
+         """
+         max_q = self.get_max_q(state)
+         old_q = self.get_q_value(self.last_state, self.last_action)
+
+         td_error = self.gamma * max_q - old_q
+         new_q = old_q + self.alpha * (reward + td_error)
+
+         self.set_q_value(self.last_state, self.last_action, new_q)
+         # action = self.best_action(state)
+         action = self.epsilon_greedy_policy(state)
+         self.last_state = state
+         self.last_action = action
+         return action
+
+     def agent_end(self, reward: float) -> Any:
+         """Called when the agent terminates.
+
+         Args:
+             reward (float): the reward the agent received for entering the
+                 terminal state.
+         """
+         old_q = self.get_q_value(self.last_state, self.last_action)
+
+         td_error = - old_q
+
+         new_q = old_q + self.alpha * (reward + td_error)
+         self.set_q_value(self.last_state, self.last_action, new_q)
+
+     def learn(self, init_threshold: int = 20):
+         tsteps = 2000
+         done_times = 0
+         patience = 0
+         converged_at = None
+         max_r = float("-inf")
+         print(f"{self._learned_episodes}->{self.episodes}")
+         if self._learned_episodes >= self.episodes:
+             print("learned episodes is above the requested episodes")
+             return
+         print(f'Using {self.__class__.__name__}')
+         tq = tqdm(range(self.episodes - self._learned_episodes),
+                   postfix=f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+         for n in tq:
+             self.step = n
+             episode_r = 0
+             observation, info = self.env.reset()
+             tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+             action = self.agent_start(state=tabular_state)
+
+             self.update_states_counter(observation_str=str(tabular_state))
+             done = False
+             tstep = 0
+             while tstep < tsteps and not done:
+                 observation, reward, terminated, truncated, _ = self.env.step(action)
+                 done = terminated | truncated
+                 if done:
+                     done_times += 1
+
+                 # standard q-learning algorithm
+                 next_tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+                 self.update_states_counter(observation_str=str(next_tabular_state))
+                 action = self.agent_step(reward, next_tabular_state)
+                 tstep += 1
+                 episode_r += reward
+             self._learned_episodes = self._learned_episodes + 1
+             if done:  # One last update at the terminal state
+                 self.agent_end(reward)
+
+             if episode_r > max_r:
+                 max_r = episode_r
+                 # print("New all time high reward:", episode_r)
+                 tq.set_postfix_str(
+                     f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+             if (n + 1) % 100 == 0:
+                 tq.set_postfix_str(
+                     f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+             if (n + 1) % 1000 == 0:
+                 tq.set_postfix_str(
+                     f"States: {len(self.q_table.keys())}. Goals: {done_times}. Eps: {self._c_eps:.3f}. MaxR: {max_r}")
+                 if done_times <= 10:
+                     patience += 1
+                     if patience >= self.patience:
+                         print(f"Did not find goal after {n} episodes. Retrying.")
+                         raise InvalidAction("Did not learn")
+                 else:
+                     patience = 0
+                 if done_times == 1000 and converged_at is None:  # record the first fully successful 1000-episode block
+                     converged_at = n
+                     print(f"***Policy converged to goal at {converged_at}***")
+                 done_times = 0
+                 self.goal_literals_achieved.clear()
+
+         print(f"number of unique states found during training: {self.get_number_of_unique_states()}")
+         print("finished learning, saving the agent's state")
+         self.save_models_to_files()
+
+     def exploit(self, number_of_steps=20):
+         observation, info = self.env.reset()
+         for step_number in range(number_of_steps):
+             tabular_state = TabularState.gen_tabular_state(environment=self.env, observation=observation)
+             action = self.policy(state=tabular_state)
+             observation, reward, terminated, truncated, _ = self.env.step(action)
+             done = terminated | truncated
+             if done:
+                 print(f"reached goal after {step_number + 1} steps!")
+                 break
+
+     def get_actions_probabilities(self, observation):
+         obs, agent_pos = observation
+         direction = obs['direction']
+
+         x, y = agent_pos
+         tabular_state = TabularState(agent_x_position=x, agent_y_position=y, agent_direction=direction)
+         return softmax(self.get_all_q_values(tabular_state))
+
+     def get_q_of_specific_cell(self, cell_key):
+         cell_q_table = {}
+         for i in range(4):
+             key = cell_key + ':' + str(i)
+             if key in self.q_table:
+                 cell_q_table[key] = self.q_table[key]
+         return cell_q_table
+
+     def get_all_cells(self):
+         cells = set()
+         for key in self.q_table.keys():
+             cell = key.split(':')[0]
+             cells.add(cell)
+         return list(cells)
+
+
+     def _save_conf_file(self):
+         conf = {
+             'learned_episodes': self._learned_episodes,
+             'states_counter': self.states_counter
+         }
+         with open(self._conf_file, "wb") as f:
+             dill.dump(conf, f)
+
+     def save_models_to_files(self):
+         self.save_q_table(path=self.model_file_path)
+         self._save_conf_file()
+
+     def simplify_observation(self, observation):
+         return [(obs['direction'], agent_pos_x, agent_pos_y, action) for ((obs, (agent_pos_x, agent_pos_y)), action) in observation]  # list of tuples, each tuple the sample
+
+     def generate_observation(self, action_selection_method: MethodType, random_optimalism, save_fig=False, fig_path: str = None, env_prop=None):
+         """
+         Generate a single observation (trajectory) from this agent's Q-table.
+
+         Args:
+             action_selection_method (MethodType): how to pick actions from the Q-values: stochastically, greedily, or via softmax.
+             random_optimalism (bool): if True, the first action may be perturbed to diversify otherwise-optimal plans.
+
+         Returns:
+             list: A list of state-action pairs representing the generated observation.
+
+         Notes:
+             The function rolls the agent out in its environment and records a sequence of state-action pairs
+             based on the agent's Q-table. With a stochastic selection method, each action is
+             selected according to the probability distribution defined by the Q-values.
+
+             The generated sequence terminates when a maximum number of steps is reached or when the environment
+             episode terminates.
+         """
+         if not save_fig:
+             assert fig_path is None, "You can't specify a figure path when you don't save the figure."
+         else:
+             assert fig_path is not None, "You must specify a figure path when you save the figure."
+         obs, _ = self.env.reset()
+         MAX_STEPS = 32
+         done = False
+         steps = []
+         for step_index in range(MAX_STEPS):
+             x, y = self.env.unwrapped.agent_pos
+             str_state = "({},{}):{}".format(x, y, obs['direction'])
+             relevant_actions_idx = 3
+             action_probs = self.q_table[str_state][:relevant_actions_idx] / np.sum(self.q_table[str_state][:relevant_actions_idx])  # Normalize probabilities
+             if step_index == 0 and random_optimalism:
+                 # print("in 1st step in generating plan and got random optimalism.")
+                 std_dev = np.std(action_probs)
+                 # uniques_sorted = np.unique(action_probs)
+                 num_of_stds = abs(action_probs[0] - action_probs[2]) / std_dev
+                 if num_of_stds < 2.1:
+                     # sorted_indices = np.argsort(action_probs)
+                     # action = np.random.choice([sorted_indices[-1], sorted_indices[-2]])
+                     action = np.random.choice([0, 2])
+                     if action == 0:
+                         steps.append(((obs, self.env.unwrapped.agent_pos), action))
+                         obs, reward, terminated, truncated, info = self.env.step(action)
+                         assert reward >= 0
+                         action = 2
+                         step_index += 1
+                 else: action = action_selection_method(action_probs)
+             else:
+                 action = action_selection_method(action_probs)
+             steps.append(((obs, self.env.unwrapped.agent_pos), action))
+             obs, reward, terminated, truncated, info = self.env.step(action)
+             assert reward >= 0
+             done = terminated | truncated
+             if done:
+                 break
+
+         # assert len(steps) >= 2
+         if save_fig:
+             sequence = [pos for ((state, pos), action) in steps]
+             # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
+             print(f"generating sequence image at {fig_path}.")
+             env_prop.create_sequence_image(sequence, fig_path, self.problem_name)  # TODO change that assumption, cannot assume this is minigrid env
+
+         return steps
+
+     def generate_partial_observation(self, action_selection_method: MethodType, percentage: float, save_fig=False, is_consecutive=True, random_optimalism=True, fig_path=None):
+         """
+         Generate a partial observation: a subset of a full generated trajectory.
+
+         Args:
+             action_selection_method (MethodType): how to pick actions from the Q-values: stochastically, greedily, or via softmax.
+             percentage (float): the fraction of the full observation to keep.
+             is_consecutive (bool): whether the kept subset is a consecutive slice or an order-preserving random sample.
+
+         Returns:
+             list: A list of state-action pairs representing the generated partial observation.
+
+         Notes:
+             The function first generates a full observation with generate_observation, then keeps a subset
+             of it of size percentage * len(observation), preserving the original order.
+
+             The generated sequence terminates when a maximum number of steps is reached or when the environment
+             episode terminates.
+         """
+
+         steps = self.generate_observation(action_selection_method=action_selection_method, random_optimalism=random_optimalism, save_fig=save_fig, fig_path=fig_path)  # steps are a full observation
+         result = random_subset_with_order(steps, int(percentage * len(steps)), is_consecutive)
+         if percentage >= 0.8:
+             assert len(result) > 2
+         return result
+
+ if __name__ == "__main__":
+     from gr_libs.metrics.metrics import greedy_selection
+     import gr_envs  # to register everything
+     agent = TabularQLearner(domain_name="minigrid", problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0")
+     agent.generate_observation(greedy_selection, True, True)
+
+ # python experiments.py --recognizer graml --domain point_maze --task L5 --partial_obs_type continuing --point_maze_env obstacles --collect_stats --inference_same_seq_len
+
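For orientation, a hedged usage sketch of TabularQLearner: train (or resume) an agent and sample a full observation sequence. The environment id and the imports mirror the __main__ block above; the timestep budget is illustrative, not a value taken from the package.

    import gr_envs  # registers the custom environments, as in the __main__ block above
    from gr_libs.environment.environment import QLEARNING
    from gr_libs.metrics.metrics import greedy_selection
    from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner

    agent = TabularQLearner(
        domain_name="minigrid",
        problem_name="MiniGrid-LavaCrossingS9N2-DynamicGoal-1x7-v0",
        algorithm=QLEARNING,
        num_timesteps=100_000,  # illustrative budget, not a value shipped with the package
    )
    agent.learn()  # trains until the episode budget is reached, then saves the Q-table and conf file
    steps = agent.generate_observation(greedy_selection, random_optimalism=True)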
@@ -0,0 +1,126 @@
+ import gymnasium as gym
+ from abc import abstractmethod
+ from typing import Collection, Literal, Any
+ from random import Random
+ import numpy as np
+
+ from gr_libs.ml.base import RLAgent
+ from gr_libs.ml.base import State
+
+
+ class TabularRLAgent(RLAgent):
+     """
+     Base class for tabular RL agents. It is not used for much on its own yet,
+     but tabular agents are expected to inherit from it as development
+     goes on.
+     """
+
+     def __init__(self,
+                  domain_name: str,
+                  problem_name: str,
+                  episodes: int,
+                  decaying_eps: bool,
+                  eps: float,
+                  alpha: float,
+                  decay: float,
+                  gamma: float,
+                  rand: Random,
+                  learning_rate: float
+                  ):
+         super().__init__(
+             episodes=episodes,
+             decaying_eps=decaying_eps,
+             epsilon=eps,
+             learning_rate=learning_rate,
+             gamma=gamma,
+             domain_name=domain_name,
+             problem_name=problem_name
+         )
+         self.env = gym.make(id=problem_name)
+         self.actions = self.env.unwrapped.actions
+         self.number_of_actions = len(self.actions)
+         self._actions_space = self.env.action_space
+         self._random = rand
+         self._alpha = alpha
+         self._decay = decay
+         self._c_eps = eps
+         self.q_table = {}
+
+         # TODO: maybe save the env.reset output
+         self.env.reset()
+
+     @abstractmethod
+     def agent_start(self, state) -> Any:
+         """The first method called when the experiment starts,
+         called after the environment is reset.
+         Args:
+             state (Any): the state from the
+                 environment's reset.
+         Returns:
+             (int) the first action the agent takes.
+         """
+         pass
+
+     @abstractmethod
+     def agent_step(self, reward: float, state: State) -> Any:
+         """A step taken by the agent.
+         Args:
+             reward (float): the reward received for the last action taken
+             state (Any): the state observation from the
+                 environment's step, i.e. where the agent ended up after the
+                 last step
+         Returns:
+             The action the agent is taking.
+         """
+         pass
+
+     @abstractmethod
+     def agent_end(self, reward: float) -> Any:
+         """Called when the agent terminates.
+
+         Args:
+             reward (float): the reward the agent received for entering the
+                 terminal state.
+         """
+         pass
+
+     @abstractmethod
+     def policy(self, state: State) -> Any:
+         """The action for the specified state under the currently learned policy
+         (unlike agent_step, this does not update the policy using state as a sample).
+         Args:
+             state (Any): the state observation from the environment
+         Returns:
+             The action prescribed for that state
+         """
+         pass
+
+     @abstractmethod
+     def softmax_policy(self, state: State) -> np.ndarray:
+         """Returns a softmax policy over the Q-values stored in the q-table
+
+         Args:
+             state (State): the state for which we want a softmax policy
+
+         Returns:
+             np.ndarray: probability of taking each action in self.actions given a state
+         """
+         pass
+
+     @abstractmethod
+     def learn(self, init_threshold: int = 20):
+         pass
+
+     def __getitem__(self, state: State) -> Any:
+         """Convenience indexing: agent[state] returns the softmax policy for that state.
+
+         Args:
+             state (Any): The state for which we want to get the policy
+
+         Returns:
+             np.ndarray: probability of taking each action in self.actions given the state
+         """
+         return self.softmax_policy(state)
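For context, a minimal sketch of what a concrete subclass of TabularRLAgent must provide; the uniform-random behaviour below is purely illustrative (TabularQLearner above is the actual implementation shipped in the package):

    import numpy as np
    from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent

    class UniformRandomTabularAgent(TabularRLAgent):
        """Illustrative subclass: picks actions uniformly at random and learns nothing."""

        def agent_start(self, state):
            return self._random.randint(0, self.number_of_actions - 1)

        def agent_step(self, reward, state):
            return self._random.randint(0, self.number_of_actions - 1)

        def agent_end(self, reward):
            pass  # nothing to update

        def policy(self, state):
            return self._random.randint(0, self.number_of_actions - 1)

        def softmax_policy(self, state):
            # Uniform probabilities over the environment's actions.
            return np.full(self.number_of_actions, 1.0 / self.number_of_actions)

        def learn(self, init_threshold: int = 20):
            pass  # no training loop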
@@ -0,0 +1,6 @@
+ #from .agent import *
+ from .env import make_env
+ from .format import Vocabulary, preprocess_images, preprocess_texts, get_obss_preprocessor, random_subset_with_order
+ from .other import device, seed, synthesize
+ from .storage import *
+ from .math import softmax
@@ -0,0 +1,7 @@
+ import gymnasium as gym
+
+
+ def make_env(env_key, seed=None, render_mode=None):
+     env = gym.make(env_key, render_mode=render_mode)
+     env.reset(seed=seed)
+     return env
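A brief, hedged usage note for make_env: the environment id below is an assumption and requires the minigrid package to be installed and imported so that MiniGrid-* ids are registered.

    import minigrid  # noqa: F401 -- registers MiniGrid-* environments (assumption: minigrid is installed)
    from gr_libs.ml.utils import make_env

    env = make_env("MiniGrid-Empty-5x5-v0", seed=0)  # make_env already resets with the given seed
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())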