gr-libs 0.1.8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. gr_libs/__init__.py +3 -1
  2. gr_libs/_version.py +2 -2
  3. gr_libs/all_experiments.py +260 -0
  4. gr_libs/environment/__init__.py +14 -1
  5. gr_libs/environment/_utils/__init__.py +0 -0
  6. gr_libs/environment/{utils → _utils}/utils.py +1 -1
  7. gr_libs/environment/environment.py +278 -23
  8. gr_libs/evaluation/__init__.py +1 -0
  9. gr_libs/evaluation/generate_experiments_results.py +100 -0
  10. gr_libs/metrics/__init__.py +2 -0
  11. gr_libs/metrics/metrics.py +166 -31
  12. gr_libs/ml/__init__.py +1 -6
  13. gr_libs/ml/base/__init__.py +3 -1
  14. gr_libs/ml/base/rl_agent.py +68 -3
  15. gr_libs/ml/neural/__init__.py +1 -3
  16. gr_libs/ml/neural/deep_rl_learner.py +241 -84
  17. gr_libs/ml/neural/utils/__init__.py +1 -2
  18. gr_libs/ml/planner/mcts/{utils → _utils}/tree.py +1 -1
  19. gr_libs/ml/planner/mcts/mcts_model.py +71 -34
  20. gr_libs/ml/sequential/__init__.py +0 -1
  21. gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} +11 -14
  22. gr_libs/ml/tabular/__init__.py +1 -3
  23. gr_libs/ml/tabular/tabular_q_learner.py +27 -9
  24. gr_libs/ml/tabular/tabular_rl_agent.py +22 -9
  25. gr_libs/ml/utils/__init__.py +2 -9
  26. gr_libs/ml/utils/format.py +13 -90
  27. gr_libs/ml/utils/math.py +3 -2
  28. gr_libs/ml/utils/other.py +2 -2
  29. gr_libs/ml/utils/storage.py +41 -94
  30. gr_libs/odgr_executor.py +263 -0
  31. gr_libs/problems/consts.py +570 -292
  32. gr_libs/recognizer/{utils → _utils}/format.py +2 -2
  33. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +127 -36
  34. gr_libs/recognizer/graml/{gr_dataset.py → _gr_dataset.py} +11 -11
  35. gr_libs/recognizer/graml/graml_recognizer.py +186 -35
  36. gr_libs/recognizer/recognizer.py +59 -10
  37. gr_libs/tutorials/draco_panda_tutorial.py +58 -0
  38. gr_libs/tutorials/draco_parking_tutorial.py +56 -0
  39. {tutorials → gr_libs/tutorials}/gcdraco_panda_tutorial.py +11 -11
  40. {tutorials → gr_libs/tutorials}/gcdraco_parking_tutorial.py +6 -8
  41. {tutorials → gr_libs/tutorials}/graml_minigrid_tutorial.py +18 -14
  42. {tutorials → gr_libs/tutorials}/graml_panda_tutorial.py +11 -12
  43. {tutorials → gr_libs/tutorials}/graml_parking_tutorial.py +8 -10
  44. {tutorials → gr_libs/tutorials}/graml_point_maze_tutorial.py +17 -3
  45. {tutorials → gr_libs/tutorials}/graql_minigrid_tutorial.py +2 -2
  46. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/METADATA +95 -29
  47. gr_libs-0.2.5.dist-info/RECORD +72 -0
  48. {gr_libs-0.1.8.dist-info → gr_libs-0.2.5.dist-info}/WHEEL +1 -1
  49. gr_libs-0.2.5.dist-info/top_level.txt +2 -0
  50. tests/test_draco.py +14 -0
  51. tests/test_gcdraco.py +2 -2
  52. tests/test_graml.py +4 -4
  53. tests/test_graql.py +1 -1
  54. tests/test_odgr_executor_expertbasedgraml.py +14 -0
  55. tests/test_odgr_executor_gcdraco.py +14 -0
  56. tests/test_odgr_executor_gcgraml.py +14 -0
  57. tests/test_odgr_executor_graql.py +14 -0
  58. evaluation/analyze_results_cross_alg_cross_domain.py +0 -267
  59. evaluation/create_minigrid_map_image.py +0 -38
  60. evaluation/file_system.py +0 -53
  61. evaluation/generate_experiments_results.py +0 -141
  62. evaluation/generate_experiments_results_new_ver1.py +0 -238
  63. evaluation/generate_experiments_results_new_ver2.py +0 -331
  64. evaluation/generate_task_specific_statistics_plots.py +0 -500
  65. evaluation/get_plans_images.py +0 -62
  66. evaluation/increasing_and_decreasing_.py +0 -104
  67. gr_libs/ml/neural/utils/penv.py +0 -60
  68. gr_libs-0.1.8.dist-info/RECORD +0 -70
  69. gr_libs-0.1.8.dist-info/top_level.txt +0 -4
  70. /gr_libs/{environment/utils/__init__.py → _evaluation/_generate_experiments_results.py} +0 -0
  71. /gr_libs/ml/planner/mcts/{utils → _utils}/__init__.py +0 -0
  72. /gr_libs/ml/planner/mcts/{utils → _utils}/node.py +0 -0
  73. /gr_libs/recognizer/{utils → _utils}/__init__.py +0 -0
gr_libs/ml/planner/mcts/mcts_model.py CHANGED
@@ -1,14 +1,16 @@
+ """ model that performs mcts to find a plan in discrete state/action environments. """
+
  import os
+ import pickle
  import random
- from math import sqrt, log
+ from math import log, sqrt

+ import gymnasium as gym
  from tqdm import tqdm
- import pickle

  from gr_libs.ml.utils.storage import get_agent_model_dir
- from .utils import Node
- from .utils import Tree
- import gymnasium as gym
+
+ from ._utils import Node, Tree

  PROB = 0.8
  UNIFORM_PROB = 0.1
@@ -18,8 +20,20 @@ dict_action_id_to_str = {0: "turn left", 1: "turn right", 2: "go straight"}


  def save_figure(steps, env_name, problem_name, img_path, env_prop):
+ """
+ Save a figure representing the sequence of steps taken in a problem.
+
+ Args:
+ steps (list): List of tuples representing the state, position, and action taken at each step.
+ env_name (str): Name of the environment.
+ problem_name (str): Name of the problem.
+ img_path (str): Path to save the generated image.
+ env_prop: Object with methods to create the sequence image.
+
+ Returns:
+ None
+ """
  sequence = [pos for ((state, pos), action) in steps]
- # print(f"sequence to {self.problem_name} is:\n\t{steps}\ngenerating image at {img_path}.")
  print(f"generating sequence image at {img_path}.")
  env_prop.create_sequence_image(sequence, img_path, problem_name)

@@ -27,16 +41,39 @@ def save_figure(steps, env_name, problem_name, img_path, env_prop):
  # TODO add number of expanded nodes and debug by putting breakpoint on the creation of nodes representing (8,4) and checking if they're invalid or something


- # Explanation on hashing and uncertainty in the acto outcome:
- # We want to detect circles, while not preventing expected behavior. To achieve it, hasing must include previous state, action, and resulting state.
- # Hashing the direction means coming to the same position from different positions gets different id's.
- # Example: the agent might have stood at (2,2), picked action 2 (forward), and accidently turned right, resulting at state ((2,2), right).
- # later, when the agent stood at (2,1), looked right and walked forward, it got to the same state. We would want to enable that, because
- # this is the expected behavior, so these nodes must have unique id's.
- # The situations where circles will indeed be detected, are only if the outcome was the same for the previous state, consistent with the action - whether it was or wasn't expected.
  class MonteCarloTreeSearch:
+ """
+ Monte Carlo Tree Search class for performing search on an environment using a tree data structure.
+
+ Explanation on hashing and uncertainty in the acto outcome:
+ We want to detect circles, while not preventing expected behavior.
+ To achieve it, hasing must include previous state, action, and resulting state.
+ Hashing the direction means coming to the same position from different positions gets different id's.
+ Example: the agent might have stood at (2,2), picked action 2 (forward), and accidently turned right,
+ resulting at state ((2,2), right).
+ later, when the agent stood at (2,1), looked right and walked forward,
+ it got to the same state. We would want to enable that, because
+ this is the expected behavior, so these nodes must have unique id's.
+ The situations where circles will indeed be detected, are only if the outcome was the same for the previous state,
+ consistent with the action - whether it was or wasn't expected.
+
+ Args:
+ env (gym.Env): The environment to perform the search on.
+ tree (Tree): The tree data structure to store the search tree.
+ goal (object): The goal state of the search.
+ use_heuristic (bool, optional): Whether to use a heuristic function during the search. Defaults to True.
+ """

  def __init__(self, env, tree, goal, use_heuristic=True):
+ """
+ Initializes the Monte Carlo Tree Search.
+
+ Args:
+ env (gym.Env): The environment to perform the search on.
+ tree (Tree): The tree data structure to store the search tree.
+ goal (object): The goal state of the search.
+ use_heuristic (bool, optional): Whether to use a heuristic function during the search. Defaults to True.
+ """
  self.env = env
  self.tree = tree
  self.action_space = self.env.action_space.n
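Note on the circle-detection scheme documented above: node identities are derived from the whole transition rather than from the resulting position alone, so reaching the same cell along two different expected paths stays legal, while literally repeating a transition collides. A rough, hypothetical sketch of that idea (the real implementation also chains `hash((666, id))` to sidestep ids already taken by invalid nodes, as the hunks below show):

```python
# Illustrative sketch only: states are (position, direction) pairs as in the docstring example.
def transition_id(prev_state, action, resulting_state):
    # Keying on (previous state, action, resulting state) keeps the two
    # paths below distinct even though they end in the same resulting state.
    return hash((prev_state, action, resulting_state))

slip = transition_id(((2, 2), "up"), 2, ((2, 2), "right"))        # forward action that accidentally turned
expected = transition_id(((2, 1), "right"), 2, ((2, 2), "right"))  # deliberate forward move
assert slip != expected  # same outcome, different transitions -> not flagged as a circle
assert slip == transition_id(((2, 2), "up"), 2, ((2, 2), "right"))  # exact repeat -> detected
```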
@@ -81,13 +118,13 @@ class MonteCarloTreeSearch:
  new_node_ptr = new_node_father
  old_node_ptr = old_node

- while new_node_ptr != None:
+ while new_node_ptr is not None:
  new_visits[0] += new_node_ptr.num_visits
  new_visits[1] += 1
  new_node_ptr = self.tree.parent(new_node_ptr)

  while (
- old_node_ptr != None
+ old_node_ptr is not None
  ): # getting to the old node wasn't necessarily through the current root. check all the way until None, the original root's parent.
  old_visits[0] += old_node_ptr.num_visits
  old_visits[1] += 1
@@ -128,7 +165,7 @@ class MonteCarloTreeSearch:
  while (
  new_identifier in self.tree.nodes.keys()
  ): # iterate over all circle nodes. important not to hash the parent node id to get the next id, because it will not be the same for all circle nodes.
- if self.tree.nodes[new_identifier].invalid == False:
+ if self.tree.nodes[new_identifier].invalid is False:
  valid_id = new_identifier
  new_identifier = hash((666, new_identifier))
  # after this while, the id is for sure unused.
@@ -194,7 +231,7 @@ class MonteCarloTreeSearch:
  while (
  resulting_identifier in self.tree.nodes.keys()
  ): # iterate over all circle nodes. important not to hash the parent node id to get the next id, because it will not be the same for all circle nodes.
- if self.tree.nodes[resulting_identifier].invalid == False:
+ if self.tree.nodes[resulting_identifier].invalid is False:
  valid_id = resulting_identifier
  resulting_identifier = hash((666, resulting_identifier))
  # after this while, the id is for sure unused.
@@ -234,6 +271,7 @@ class MonteCarloTreeSearch:
  if self.use_heuristic:
  # taken from Monte-Carlo Planning for Pathfinding in Real-Time Strategy Games , 2010.
  # need to handle the case of walking into a wall here: the resulting node will be considered invalid and it's reward and performance needs to be 0, but must handle stochasticity
+ pass
  # suggestion to handle stochasticity - consider *all* the children associated with taking action 2 towards a wall as performance 0, even if they accidently led in walking to another direction.
  # which suggests the invalidity needs to be checked not according to the resulting state, rather according to the intended action itself and the environment! remember, you cannot access the "stochastic_action", it is meant to be hidden from you.
  if node.pos[0] == self.goal[0] and node.pos[1] == self.goal[1]:
@@ -382,7 +420,7 @@ class MonteCarloTreeSearch:
  def backpropagation(self, node, value):
  while node != self.tree.parent(self.tree.root):
  assert (
- node != None
+ node is not None
  ) # if we got to None it means we got to the actual root with the backpropogation instead of to the current root, which means in this path, someone had a differrent parent than it should, probably a double id.
  node.num_visits += 1
  node.total_simulation_reward += value
@@ -411,10 +449,10 @@ class MonteCarloTreeSearch:
  ) # need to add the previous node with the action leading to the next node which is a property of the next node
  prev_node = node
  if save_fig:
- assert fig_path != None
+ assert fig_path is not None
  save_figure(trace, env_name, problem_name, fig_path, env_prop)
  else:
- assert fig_path == None
+ assert fig_path is None
  return trace


@@ -430,6 +468,17 @@ def save_model_and_generate_policy(


  def plan(env_name, problem_name, goal, save_fig=False, fig_path=None, env_prop=None):
+ """
+ Plan a path using Monte Carlo Tree Search (MCTS) algorithm.
+
+ Args:
+ env_name (str): Name of the environment.
+ problem_name (str): Name of the problem.
+ goal (tuple): Goal state to reach.
+ save_fig (bool): Flag to save the figure of the plan.
+ fig_path (str): Path to save the figure.
+ env_prop: Object with methods to create the sequence image.
+ """
  global newely_expanded
  model_dir = get_agent_model_dir(
  env_name=env_name, model_name=problem_name, class_name="MCTS"
@@ -440,16 +489,14 @@ def plan(env_name, problem_name, goal, save_fig=False, fig_path=None, env_prop=N
  with open(model_file_path, "rb") as file: # Load the pre-existing model
  try:
  monteCarloTreeSearch = pickle.load(file)
- except Exception as e:
+ except Exception:

  class RenameUnpickler(pickle.Unpickler):
  def find_class(self, module, name):
  renamed_module = module
  if module.startswith("ml"):
  renamed_module = "gr_libs." + renamed_module
- return super(RenameUnpickler, self).find_class(
- renamed_module, name
- )
+ return super().find_class(renamed_module, name)

  def renamed_load(file_obj):
  return RenameUnpickler(file_obj).load()
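The fallback above handles models pickled when the modules lived at top-level `ml.` paths; those classes now resolve under `gr_libs.`, so a plain `pickle.load` can fail. The same pattern works for any package relocation; a minimal, generic sketch (package names here are hypothetical, not taken from the diff):

```python
import pickle

class RelocatedUnpickler(pickle.Unpickler):
    """Remap legacy module paths while loading an old pickle (illustrative only)."""
    def find_class(self, module, name):
        if module.startswith("old_pkg"):   # hypothetical legacy top-level package
            module = "new_pkg." + module    # it now lives inside new_pkg, mirroring "gr_libs." + module
        return super().find_class(module, name)

def load_legacy_model(path):
    with open(path, "rb") as f:
        return RelocatedUnpickler(f).load()
```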
@@ -540,13 +587,3 @@ def plan(env_name, problem_name, goal, save_fig=False, fig_path=None, env_prop=N
  return mcts.generate_full_policy_sequence(
  env_name, problem_name, save_fig, fig_path
  )
-
-
- if __name__ == "__main__":
- # register(
- # id="MiniGrid-DynamicGoalEmpty-8x8-3x6-v0",
- # entry_point="minigrid.envs:DynamicGoalEmpty",
- # kwargs={"size": 8, "agent_start_pos" : (1, 1), "goal_pos": (3,6) },
- # )
- # plan("MiniGrid-DynamicGoalEmpty-8x8-3x6-v0")
- pass
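With the dead `__main__` block gone, the module is import-only. For orientation, a hypothetical call to the documented entry point, reusing the environment id and goal position from the commented-out example that was removed (the `env_prop` object is assumed to expose `create_sequence_image`):

```python
# Illustrative sketch; argument values are placeholders consistent with the docstring above.
trace = plan(
    env_name="minigrid",                                   # str: name of the environment (placeholder)
    problem_name="MiniGrid-DynamicGoalEmpty-8x8-3x6-v0",   # str: taken from the removed example
    goal=(3, 6),                                           # tuple: goal state to reach
    save_fig=True,
    fig_path="mcts_plan.png",  # must be provided whenever save_fig=True (see the assert above)
    env_prop=env_prop,         # assumed: object with create_sequence_image(...)
)
```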
gr_libs/ml/sequential/__init__.py CHANGED
@@ -1 +0,0 @@
- from gr_libs.ml.sequential.lstm_model import LstmObservations
gr_libs/ml/sequential/{lstm_model.py → _lstm_model.py} CHANGED
@@ -1,13 +1,10 @@
- import os
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import torch.optim as optim
- from types import MethodType
- import numpy as np
- from gr_libs.ml.utils import device
  from torch.nn.utils.rnn import pack_padded_sequence

+ from gr_libs.ml.utils import device
+

  def accuracy_per_epoch(model, data_loader):
  model.eval()
@@ -119,7 +116,7 @@ class LstmObservations(nn.Module):
  def __init__(
  self, input_size, hidden_size
  ): # TODO make sure the right cuda is used!
- super(LstmObservations, self).__init__()
+ super().__init__()
  # self.embeddor = CNNImageEmbeddor(obs_space, action_space)
  # check if the traces are a bunch of images
  self.lstm = nn.LSTM(
@@ -221,10 +218,10 @@ def train_metric_model(model, train_loader, dev_loader, nepochs=5, patience=2):
  no_improvement_count = 1

  print(
- "epoch - {}/{}...".format(epoch + 1, nepochs),
- "train loss - {:.6f}...".format(sum_loss / denominator),
- "dev loss - {:.6f}...".format(dev_loss),
- "dev accuracy - {:.6f}".format(dev_accuracy),
+ f"epoch - {epoch + 1}/{nepochs}...",
+ f"train loss - {sum_loss / denominator:.6f}...",
+ f"dev loss - {dev_loss:.6f}...",
+ f"dev accuracy - {dev_accuracy:.6f}",
  )

  if no_improvement_count >= patience:
@@ -266,8 +263,8 @@ def train_metric_model_cont(model, train_loader, dev_loader, nepochs=5):
  devAccuracy.append(dev_accuracy)

  print(
- "epoch - {}/{}...".format(epoch + 1, nepochs),
- "train loss - {:.6f}...".format(sum_loss / denominator),
- "dev loss - {:.6f}...".format(dev_loss),
- "dev accuracy - {:.6f}".format(dev_accuracy),
+ f"epoch - {epoch + 1}/{nepochs}...",
+ f"train loss - {sum_loss / denominator:.6f}...",
+ f"dev loss - {dev_loss:.6f}...",
+ f"dev accuracy - {dev_accuracy:.6f}",
  )
gr_libs/ml/tabular/__init__.py CHANGED
@@ -1,3 +1 @@
- from gr_libs.ml.tabular.state import TabularState
- from gr_libs.ml.tabular.tabular_q_learner import TabularQLearner
- from gr_libs.ml.sequential.lstm_model import LstmObservations
+ from .state import TabularState
gr_libs/ml/tabular/tabular_q_learner.py CHANGED
@@ -1,18 +1,17 @@
- # Don't import stuff from metrics! it's a higher level module.
+ """ implementation of q-learning """
+
  import os.path
  import pickle
- import random
+ from collections.abc import Iterable
+ from random import Random
  from types import MethodType
+ from typing import Any

  import dill
- from gymnasium import register
  import numpy as np
-
- from tqdm import tqdm
- from typing import Any
- from random import Random
- from typing import List, Iterable
  from gymnasium.error import InvalidAction
+ from tqdm import tqdm
+
  from gr_libs.environment.environment import QLEARNING, EnvProperty
  from gr_libs.ml.tabular import TabularState
  from gr_libs.ml.tabular.tabular_rl_agent import TabularRLAgent
@@ -44,6 +43,25 @@ class TabularQLearner(TabularRLAgent):
  check_partial_goals: bool = True,
  valid_only: bool = False,
  ):
+ """
+ Initialize a TabularQLearner object.
+
+ Args:
+ domain_name (str): The name of the domain.
+ problem_name (str): The name of the problem.
+ env_prop (EnvProperty): The environment properties.
+ algorithm (str): The algorithm to use.
+ num_timesteps (int): The number of timesteps.
+ decaying_eps (bool, optional): Whether to use decaying epsilon. Defaults to True.
+ eps (float, optional): The initial epsilon value. Defaults to 1.0.
+ alpha (float, optional): The learning rate. Defaults to 0.5.
+ decay (float, optional): The decay rate. Defaults to 0.000002.
+ gamma (float, optional): The discount factor. Defaults to 0.9.
+ rand (Random, optional): The random number generator. Defaults to Random().
+ learning_rate (float, optional): The learning rate. Defaults to 0.001.
+ check_partial_goals (bool, optional): Whether to check partial goals. Defaults to True.
+ valid_only (bool, optional): Whether to use valid goals only. Defaults to False.
+ """
  super().__init__(
  domain_name=domain_name,
  problem_name=problem_name,
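For reference, a hypothetical construction matching the parameters and defaults documented in the new docstring (the domain name, problem name, and timestep budget below are placeholders, not values shipped by the package):

```python
# Illustrative sketch only; values are placeholders.
agent = TabularQLearner(
    domain_name="minigrid",                # placeholder domain name
    problem_name="MiniGrid-Empty-8x8-v0",  # placeholder problem name
    env_prop=env_prop,                     # assumed: an EnvProperty instance for that domain
    algorithm=QLEARNING,                   # assumed: the QLEARNING constant imported above
    num_timesteps=100_000,                 # placeholder training budget
)  # eps=1.0, alpha=0.5, gamma=0.9, ... keep their documented defaults
```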
@@ -169,7 +187,7 @@ class TabularQLearner(TabularRLAgent):
  def add_new_state(self, state: TabularState):
  self.q_table[str(state)] = [0.0] * self.number_of_actions

- def get_all_q_values(self, state: TabularState) -> List[float]:
+ def get_all_q_values(self, state: TabularState) -> list[float]:
  if str(state) in self.q_table:
  return self.q_table[str(state)]
  else:
gr_libs/ml/tabular/tabular_rl_agent.py CHANGED
@@ -1,11 +1,11 @@
- import gymnasium as gym
  from abc import abstractmethod
- from typing import Collection, Literal, Any
  from random import Random
+ from typing import Any
+
+ import gymnasium as gym
  import numpy as np

- from gr_libs.ml.base import RLAgent
- from gr_libs.ml.base import State
+ from gr_libs.ml.base import RLAgent, State


  class TabularRLAgent(RLAgent):
@@ -28,6 +28,24 @@ class TabularRLAgent(RLAgent):
  rand: Random,
  learning_rate,
  ):
+ """
+ Initializes a TabularRLAgent object.
+
+ Args:
+ domain_name (str): The name of the domain.
+ problem_name (str): The name of the problem.
+ episodes (int): The number of episodes to run.
+ decaying_eps (bool): Whether to use decaying epsilon.
+ eps (float): The initial epsilon value.
+ alpha (float): The learning rate.
+ decay (float): The decay rate for epsilon.
+ gamma (float): The discount factor.
+ rand (Random): The random number generator.
+ learning_rate: The learning rate.
+
+ Returns:
+ None
+ """
  super().__init__(
  episodes=episodes,
  decaying_eps=decaying_eps,
@@ -60,7 +78,6 @@ class TabularRLAgent(RLAgent):
  Returns:
  (int) the first action the agent takes.
  """
- pass

  @abstractmethod
  def agent_step(self, reward: float, state: State) -> Any:
@@ -73,7 +90,6 @@ class TabularRLAgent(RLAgent):
  Returns:
  The action the agent is taking.
  """
- pass

  @abstractmethod
  def agent_end(self, reward: float) -> Any:
@@ -83,7 +99,6 @@ class TabularRLAgent(RLAgent):
  reward (float): the reward the agent received for entering the
  terminal state.
  """
- pass

  @abstractmethod
  def policy(self, state: State) -> Any:
@@ -94,7 +109,6 @@ class TabularRLAgent(RLAgent):
  Returns:
  The action prescribed for that state
  """
- pass

  @abstractmethod
  def softmax_policy(self, state: State) -> np.array:
@@ -106,7 +120,6 @@ class TabularRLAgent(RLAgent):
  Returns:
  np.array: probability of taking each action in self.actions given a state
  """
- pass

  @abstractmethod
  def learn(self, init_threshold: int = 20):
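Dropping the trailing `pass` statements is safe: a method whose body is a docstring already has a valid body, and `@abstractmethod` only requires that concrete subclasses override it. A small self-contained sketch of the idiom:

```python
from abc import ABC, abstractmethod

class Agent(ABC):
    @abstractmethod
    def policy(self, state):
        """Return the action prescribed for the given state."""
        # no `pass` needed: the docstring alone is a complete function body

class AlwaysZero(Agent):
    def policy(self, state):
        return 0

print(AlwaysZero().policy(state=None))  # concrete subclass works as usual -> 0
```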
gr_libs/ml/utils/__init__.py CHANGED
@@ -1,12 +1,5 @@
- # from .agent import *
  from .env import make_env
- from .format import (
- Vocabulary,
- preprocess_images,
- preprocess_texts,
- get_obss_preprocessor,
- random_subset_with_order,
- )
+ from .format import random_subset_with_order
+ from .math import softmax
  from .other import device, seed, synthesize
  from .storage import *
- from .math import softmax
gr_libs/ml/utils/format.py CHANGED
@@ -1,62 +1,22 @@
- import numpy
- import re
- import torch
- import gr_libs.ml
- import gymnasium as gym
- import random
-
-
- def get_obss_preprocessor(obs_space):
- # Check if obs_space is an image space
- if isinstance(obs_space, gym.spaces.Box):
- obs_space = {"image": obs_space.shape}
-
- def preprocess_obss(obss, device=None):
- return ml.DictList({"image": preprocess_images(obss, device=device)})
-
- # Check if it is a MiniGrid observation space
- elif isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces.keys():
- obs_space = {"image": obs_space.spaces["image"].shape, "text": 100}
-
- vocab = Vocabulary(obs_space["text"])
-
- def preprocess_obss(obss, device=None):
- return ml.DictList(
- {
- "image": preprocess_images(
- [obs["image"] for obs in obss], device=device
- ),
- "text": preprocess_texts(
- [obs["mission"] for obs in obss], vocab, device=device
- ),
- }
- )
+ """ formatting-related utilities """

- preprocess_obss.vocab = vocab
-
-
- # Check if it is a MiniGrid observation space
- elif (
- isinstance(obs_space, gym.spaces.Dict)
- and "observation" in obs_space.spaces.keys()
- ):
- obs_space = {"observation": obs_space.spaces["observation"].shape}
-
- def preprocess_obss(obss, device=None):
- return ml.DictList({"observation": preprocess_images(obss, device=device)})
- else:
- raise ValueError("Unknown observation space: " + str(obs_space))
+ import random

- return obs_space, preprocess_obss

+ def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+ """
+ Returns a random subset of elements from the given sequence with a specified subset size.

- def preprocess_images(images, device=None):
- # Bug of Pytorch: very slow if not first converted to numpy array
- images = numpy.array(images)
- return torch.tensor(images, device=device, dtype=torch.float)
+ Args:
+ sequence (list): The sequence of elements to select from.
+ subset_size (int): The size of the desired subset.
+ is_consecutive (bool, optional): Whether the selected subset should be consecutive elements from the sequence.
+ Defaults to True.

+ Returns:
+ list: A random subset of elements from the sequence.

- def random_subset_with_order(sequence, subset_size, is_consecutive=True):
+ """
  if subset_size >= len(sequence):
  return sequence
  else:
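The retained helper is now the module's only public function. A quick sketch of how it behaves under its documented contract (the trace values are arbitrary):

```python
import random

random.seed(0)          # only to make the illustration reproducible
trace = list(range(10))

# is_consecutive=True: per the docstring, the subset is consecutive elements of the trace.
window = random_subset_with_order(trace, 4, is_consecutive=True)

# is_consecutive=False: elements are sampled from anywhere in the trace.
sparse = random_subset_with_order(trace, 4, is_consecutive=False)

# Asking for at least as many elements as exist returns the sequence unchanged
# (this branch is visible in the retained body above).
assert random_subset_with_order(trace, 50) == trace
```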
@@ -69,40 +29,3 @@ def random_subset_with_order(sequence, subset_size, is_consecutive=True):
  return [
  sequence[i] for i in indices_to_select
  ] # Return the elements corresponding to the selected indices
-
-
- def preprocess_texts(texts, vocab, device=None):
- var_indexed_texts = []
- max_text_len = 0
-
- for text in texts:
- tokens = re.findall("([a-z]+)", text.lower())
- var_indexed_text = numpy.array([vocab[token] for token in tokens])
- var_indexed_texts.append(var_indexed_text)
- max_text_len = max(len(var_indexed_text), max_text_len)
-
- indexed_texts = numpy.zeros((len(texts), max_text_len))
-
- for i, indexed_text in enumerate(var_indexed_texts):
- indexed_texts[i, : len(indexed_text)] = indexed_text
-
- return torch.tensor(indexed_texts, device=device, dtype=torch.long)
-
-
- class Vocabulary:
- """A mapping from tokens to ids with a capacity of `max_size` words.
- It can be saved in a `vocab.json` file."""
-
- def __init__(self, max_size):
- self.max_size = max_size
- self.vocab = {}
-
- def load_vocab(self, vocab):
- self.vocab = vocab
-
- def __getitem__(self, token):
- if not token in self.vocab.keys():
- if len(self.vocab) >= self.max_size:
- raise ValueError("Maximum vocabulary capacity reached")
- self.vocab[token] = len(self.vocab) + 1
- return self.vocab[token]
gr_libs/ml/utils/math.py CHANGED
@@ -1,8 +1,9 @@
+ """ math-related functions """
+
  import math
- from typing import Callable, Generator, List


- def softmax(values: List[float]) -> List[float]:
+ def softmax(values: list[float]) -> list[float]:
  """Computes softmax probabilities for an array of values
  TODO We should probably use numpy arrays here
  Args:
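The annotation change from `typing.List` to the built-in `list` does not affect behavior; `softmax` still maps a list of floats to probabilities that sum to 1. A hedged reference computation of the same formula (the library's exact numerics, e.g. whether it subtracts the max for stability, are not visible in this diff):

```python
import math

def softmax_reference(values: list[float]) -> list[float]:
    # exp-normalize, subtracting the max first for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

print(softmax_reference([1.0, 2.0, 3.0]))  # ≈ [0.090, 0.245, 0.665]
```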
gr_libs/ml/utils/other.py CHANGED
@@ -1,8 +1,8 @@
+ import collections
  import random
+
  import numpy
  import torch
- import collections
-

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")