gr-libs 0.1.7.post0__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. evaluation/analyze_results_cross_alg_cross_domain.py +236 -246
  2. evaluation/create_minigrid_map_image.py +10 -6
  3. evaluation/file_system.py +16 -5
  4. evaluation/generate_experiments_results.py +123 -74
  5. evaluation/generate_experiments_results_new_ver1.py +227 -243
  6. evaluation/generate_experiments_results_new_ver2.py +317 -317
  7. evaluation/generate_task_specific_statistics_plots.py +481 -253
  8. evaluation/get_plans_images.py +41 -26
  9. evaluation/increasing_and_decreasing_.py +97 -56
  10. gr_libs/__init__.py +2 -1
  11. gr_libs/_version.py +2 -2
  12. gr_libs/environment/__init__.py +16 -8
  13. gr_libs/environment/environment.py +167 -39
  14. gr_libs/environment/utils/utils.py +22 -12
  15. gr_libs/metrics/__init__.py +5 -0
  16. gr_libs/metrics/metrics.py +76 -34
  17. gr_libs/ml/__init__.py +2 -0
  18. gr_libs/ml/agent.py +21 -6
  19. gr_libs/ml/base/__init__.py +1 -1
  20. gr_libs/ml/base/rl_agent.py +13 -10
  21. gr_libs/ml/consts.py +1 -1
  22. gr_libs/ml/neural/deep_rl_learner.py +433 -352
  23. gr_libs/ml/neural/utils/__init__.py +1 -1
  24. gr_libs/ml/neural/utils/dictlist.py +3 -3
  25. gr_libs/ml/neural/utils/penv.py +5 -2
  26. gr_libs/ml/planner/mcts/mcts_model.py +524 -302
  27. gr_libs/ml/planner/mcts/utils/__init__.py +1 -1
  28. gr_libs/ml/planner/mcts/utils/node.py +11 -7
  29. gr_libs/ml/planner/mcts/utils/tree.py +14 -10
  30. gr_libs/ml/sequential/__init__.py +1 -1
  31. gr_libs/ml/sequential/lstm_model.py +256 -175
  32. gr_libs/ml/tabular/state.py +7 -7
  33. gr_libs/ml/tabular/tabular_q_learner.py +123 -73
  34. gr_libs/ml/tabular/tabular_rl_agent.py +20 -19
  35. gr_libs/ml/utils/__init__.py +8 -2
  36. gr_libs/ml/utils/format.py +78 -70
  37. gr_libs/ml/utils/math.py +2 -1
  38. gr_libs/ml/utils/other.py +1 -1
  39. gr_libs/ml/utils/storage.py +88 -28
  40. gr_libs/problems/consts.py +1549 -1227
  41. gr_libs/recognizer/gr_as_rl/gr_as_rl_recognizer.py +145 -80
  42. gr_libs/recognizer/graml/gr_dataset.py +209 -110
  43. gr_libs/recognizer/graml/graml_recognizer.py +431 -240
  44. gr_libs/recognizer/recognizer.py +38 -27
  45. gr_libs/recognizer/utils/__init__.py +1 -1
  46. gr_libs/recognizer/utils/format.py +8 -3
  47. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/METADATA +1 -1
  48. gr_libs-0.1.8.dist-info/RECORD +70 -0
  49. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/WHEEL +1 -1
  50. tests/test_gcdraco.py +10 -0
  51. tests/test_graml.py +8 -4
  52. tests/test_graql.py +2 -1
  53. tutorials/gcdraco_panda_tutorial.py +66 -0
  54. tutorials/gcdraco_parking_tutorial.py +61 -0
  55. tutorials/graml_minigrid_tutorial.py +42 -12
  56. tutorials/graml_panda_tutorial.py +35 -14
  57. tutorials/graml_parking_tutorial.py +37 -20
  58. tutorials/graml_point_maze_tutorial.py +33 -13
  59. tutorials/graql_minigrid_tutorial.py +31 -15
  60. gr_libs-0.1.7.post0.dist-info/RECORD +0 -67
  61. {gr_libs-0.1.7.post0.dist-info → gr_libs-0.1.8.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,15 @@ import sys
6
6
 
7
7
  from .other import device
8
8
 
9
+
9
10
  def create_folders_if_necessary(path):
10
11
  if not os.path.exists(path):
11
12
  os.makedirs(path)
12
13
 
13
14
 
14
15
  def get_storage_framework_dir(recognizer: str):
15
- return os.path.join(get_storage_dir(),recognizer)
16
+ return os.path.join(get_storage_dir(), recognizer)
17
+
16
18
 
17
19
  def get_storage_dir():
18
20
  # Prefer local directory if it exists (e.g., in GitHub workspace)
@@ -24,70 +26,128 @@ def get_storage_dir():
24
26
  # Default to "dataset" even if it doesn't exist (e.g., will be created)
25
27
  return "dataset"
26
28
 
29
+
27
30
  def _get_models_directory_name():
28
31
  return "models"
29
32
 
33
+
30
34
  def _get_siamese_datasets_directory_name():
31
35
  return "siamese_datasets"
32
36
 
37
+
33
38
  def _get_observations_directory_name():
34
39
  return "observations"
35
40
 
41
+
36
42
  def get_observation_file_name(observability_percentage: float):
37
- return 'obs' + str(observability_percentage) + '.pkl'
43
+ return "obs" + str(observability_percentage) + ".pkl"
38
44
 
39
45
 
40
- def get_domain_dir(domain_name, recognizer:str):
46
+ def get_domain_dir(domain_name, recognizer: str):
41
47
  return os.path.join(get_storage_framework_dir(recognizer), domain_name)
42
48
 
43
- def get_env_dir(domain_name, env_name, recognizer:str):
49
+
50
+ def get_env_dir(domain_name, env_name, recognizer: str):
44
51
  return os.path.join(get_domain_dir(domain_name, recognizer), env_name)
45
52
 
46
- def get_observations_dir(domain_name, env_name, recognizer:str):
47
- return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_observations_directory_name())
53
+
54
+ def get_observations_dir(domain_name, env_name, recognizer: str):
55
+ return os.path.join(
56
+ get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
57
+ _get_observations_directory_name(),
58
+ )
59
+
48
60
 
49
61
  def get_agent_model_dir(domain_name, model_name, class_name):
50
- return os.path.join(get_storage_dir(), _get_models_directory_name(), domain_name, model_name, class_name)
62
+ return os.path.join(
63
+ get_storage_dir(),
64
+ _get_models_directory_name(),
65
+ domain_name,
66
+ model_name,
67
+ class_name,
68
+ )
69
+
70
+
71
+ def get_lstm_model_dir(domain_name, env_name, model_name, recognizer: str):
72
+ return os.path.join(
73
+ get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
74
+ model_name,
75
+ )
76
+
51
77
 
52
- def get_lstm_model_dir(domain_name, env_name, model_name, recognizer:str):
53
- return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), model_name)
78
+ def get_models_dir(domain_name, env_name, recognizer: str):
79
+ return os.path.join(
80
+ get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
81
+ _get_models_directory_name(),
82
+ )
54
83
 
55
- def get_models_dir(domain_name, env_name, recognizer:str):
56
- return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), _get_models_directory_name())
57
84
 
58
85
  ### GRAML PATHS ###
59
86
 
60
- def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer:str):
61
- return os.path.join(get_lstm_model_dir(domain_name, env_name, model_name, recognizer), _get_siamese_datasets_directory_name())
62
87
 
63
- def get_embeddings_result_path(domain_name, env_name, recognizer:str):
64
- return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
88
+ def get_siamese_dataset_path(domain_name, env_name, model_name, recognizer: str):
89
+ return os.path.join(
90
+ get_lstm_model_dir(domain_name, env_name, model_name, recognizer),
91
+ _get_siamese_datasets_directory_name(),
92
+ )
93
+
94
+
95
+ def get_embeddings_result_path(domain_name, env_name, recognizer: str):
96
+ return os.path.join(
97
+ get_env_dir(domain_name, env_name=env_name, recognizer=recognizer),
98
+ "goal_embeddings",
99
+ )
100
+
101
+
102
+ def get_embeddings_result_path(domain_name, env_name, recognizer: str):
103
+ return os.path.join(
104
+ get_env_dir(domain_name, env_name=env_name, recognizer=recognizer),
105
+ "goal_embeddings",
106
+ )
65
107
 
66
- def get_embeddings_result_path(domain_name, env_name, recognizer:str):
67
- return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "goal_embeddings")
68
108
 
69
109
  def get_and_create(path):
70
110
  create_folders_if_necessary(path)
71
111
  return path
72
112
 
73
- def get_experiment_results_path(domain, env_name, task, recognizer:str):
74
- return os.path.join(get_env_dir(domain, env_name=env_name, recognizer=recognizer), "experiment_results", env_name, task, "experiment_results")
75
113
 
76
- def get_plans_result_path(domain_name, env_name, recognizer:str):
77
- return os.path.join(get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans")
114
+ def get_experiment_results_path(domain, env_name, task, recognizer: str):
115
+ return os.path.join(
116
+ get_env_dir(domain, env_name=env_name, recognizer=recognizer),
117
+ "experiment_results",
118
+ env_name,
119
+ task,
120
+ "experiment_results",
121
+ )
122
+
123
+
124
+ def get_plans_result_path(domain_name, env_name, recognizer: str):
125
+ return os.path.join(
126
+ get_env_dir(domain_name, env_name=env_name, recognizer=recognizer), "plans"
127
+ )
128
+
129
+
130
+ def get_policy_sequences_result_path(domain_name, env_name, recognizer: str):
131
+ return os.path.join(
132
+ get_env_dir(domain_name, env_name, recognizer=recognizer), "policy_sequences"
133
+ )
78
134
 
79
- def get_policy_sequences_result_path(domain_name, env_name, recognizer:str):
80
- return os.path.join(get_env_dir(domain_name, env_name, recognizer=recognizer), "policy_sequences")
81
135
 
82
136
  ### END GRAML PATHS ###
83
- ''
137
+ ""
84
138
  ### GRAQL PATHS ###
85
139
 
86
- def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer:str):
87
- return os.path.join(get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer), "experiments")
140
+
141
+ def get_gr_as_rl_experiment_confidence_path(domain_name, env_name, recognizer: str):
142
+ return os.path.join(
143
+ get_env_dir(domain_name=domain_name, env_name=env_name, recognizer=recognizer),
144
+ "experiments",
145
+ )
146
+
88
147
 
89
148
  ### GRAQL PATHS ###
90
149
 
150
+
91
151
  def get_status_path(model_dir):
92
152
  return os.path.join(model_dir, "status.pt")
93
153
 
@@ -120,8 +180,8 @@ def get_txt_logger(model_dir):
120
180
  format="%(message)s",
121
181
  handlers=[
122
182
  logging.FileHandler(filename=path),
123
- logging.StreamHandler(sys.stdout)
124
- ]
183
+ logging.StreamHandler(sys.stdout),
184
+ ],
125
185
  )
126
186
 
127
187
  return logging.getLogger()