likelihood 2.2.0.dev1__cp312-cp312-musllinux_1_2_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- likelihood/VERSION +1 -0
- likelihood/__init__.py +20 -0
- likelihood/graph/__init__.py +9 -0
- likelihood/graph/_nn.py +283 -0
- likelihood/graph/graph.py +86 -0
- likelihood/graph/nn.py +329 -0
- likelihood/main.py +273 -0
- likelihood/models/__init__.py +3 -0
- likelihood/models/deep/__init__.py +13 -0
- likelihood/models/deep/_autoencoders.py +896 -0
- likelihood/models/deep/_predictor.py +809 -0
- likelihood/models/deep/autoencoders.py +903 -0
- likelihood/models/deep/bandit.py +97 -0
- likelihood/models/deep/gan.py +313 -0
- likelihood/models/deep/predictor.py +805 -0
- likelihood/models/deep/rl.py +345 -0
- likelihood/models/environments.py +202 -0
- likelihood/models/hmm.py +163 -0
- likelihood/models/regression.py +451 -0
- likelihood/models/simulation.py +213 -0
- likelihood/models/utils.py +87 -0
- likelihood/pipes.py +382 -0
- likelihood/rust_py_integration.cpython-312-x86_64-linux-musl.so +0 -0
- likelihood/tools/__init__.py +4 -0
- likelihood/tools/cat_embed.py +212 -0
- likelihood/tools/figures.py +348 -0
- likelihood/tools/impute.py +278 -0
- likelihood/tools/models_tools.py +866 -0
- likelihood/tools/numeric_tools.py +390 -0
- likelihood/tools/reports.py +375 -0
- likelihood/tools/tools.py +1336 -0
- likelihood-2.2.0.dev1.dist-info/METADATA +68 -0
- likelihood-2.2.0.dev1.dist-info/RECORD +39 -0
- likelihood-2.2.0.dev1.dist-info/WHEEL +5 -0
- likelihood-2.2.0.dev1.dist-info/licenses/LICENSE +21 -0
- likelihood-2.2.0.dev1.dist-info/sboms/auditwheel.cdx.json +1 -0
- likelihood-2.2.0.dev1.dist-info/top_level.txt +7 -0
- likelihood.libs/libgcc_s-0cd532bd.so.1 +0 -0
- src/lib.rs +12 -0
likelihood/models/deep/rl.py
@@ -0,0 +1,345 @@
+import random
+from collections import deque
+from typing import Any
+
+import numpy as np
+import tensorflow as tf
+
+
+def print_progress_bar(iteration, total, length=30):
+    percent = f"{100 * (iteration / float(total)):.1f}"
+    filled_length = int(length * iteration // total)
+    bar = "█" * filled_length + "-" * (length - filled_length)
+    print(f"\rProgress: |{bar}| {percent}% Complete", end="\r")
+    if iteration == total:
+        print()
+
+
+class Env:
+    def __init__(self, model: Any, maxlen: int = 100, name: str = "likenasium"):
+        """
+        Initialize the environment with a model.
+
+        Parameters
+        ----------
+        model : Any
+            Model with `.predict()` method (e.g., Keras model).
+        maxlen : int
+            Maximum length of deque. By default it is set to `100`.
+        name : str
+            The name of the environment. By default it is set to `likenasium`.
+        """
+        self.model = model
+        self.maxlen = maxlen
+        self.transitions = deque(
+            maxlen=self.maxlen
+        )  # Stores (state, action, reward, next_action, done)
+        self.current_state = None
+        self.current_step = 0
+        self.done = False
+
+    def step(self, state: np.ndarray, action: int, verbose: int = 0):
+        """
+        Perform an environment step with the given action.
+
+        Parameters
+        ----------
+        state : `np.ndarray`
+            Current state to process (input to the model).
+        action : `int`
+            Expected action to process.
+
+        Returns
+        -------
+        tuple : (current_state, action_pred, reward, next_action, done)
+        """
+        if self.done:
+            return None, None, 0, None, True
+
+        # Process action through model
+        model_output = self.model.predict(state.reshape((1, -1)), verbose=verbose)
+        action_pred = np.argmax(model_output, axis=1)[0]
+        model_output[:, action_pred] = 0.0
+        next_action = np.max(model_output, axis=1)[0]  # Second most probable action
+
+        # Calculate reward (1 if correct prediction, 0 otherwise)
+        reward = 1 if action_pred == action else 0
+
+        # Update current state
+        self.current_state = state
+        self.current_step += 1
+
+        # Add transition to history
+        if self.current_step <= self.maxlen:
+            self.transitions.append(
+                (
+                    self.current_state,  # Previous state
+                    action_pred,  # Current action
+                    reward,  # Reward
+                    next_action,  # Next action
+                    self.done,  # Done flag
+                )
+            )
+        return self.current_state, action_pred, reward, next_action, self.done
+
+    def reset(self):
+        """Reset the environment to initial state."""
+        self.current_state = None
+        self.current_step = 0
+        self.done = False
+        self.transitions = deque(maxlen=self.maxlen)
+        return self.current_state
+
+    def get_transitions(self):
+        """Get all stored transitions."""
+        return self.transitions
+
+
+class AutoQL:
+    """
+    AutoQL: A reinforcement learning agent using Q-learning with Epsilon-greedy policy.
+
+    This class implements a Q-learning agent with:
+    - Epsilon-greedy policy for exploration
+    - Replay buffer for experience replay
+    - Automatic model version handling for TensorFlow
+    """
+
+    def __init__(
+        self,
+        env: Any,
+        model: tf.keras.Model,
+        maxlen: int = 2000,
+    ):
+        """Initialize AutoQL agent
+
+        Parameters
+        ----------
+        env : `Any`
+            The environment to interact with
+        model : `tf.keras.Model`
+            The Q-network model
+        """
+
+        self.env = env
+        self.model = model
+        self.maxlen = maxlen
+        self.replay_buffer = deque(maxlen=self.maxlen)
+
+    def epsilon_greedy_policy(self, state: np.ndarray, action: int, epsilon: float = 0.0) -> tuple:
+        """
+        Epsilon-greedy policy for action selection
+
+        Parameters
+        ----------
+        state : `np.ndarray`
+            Current state.
+        action : `int`
+            Expected action to process.
+        epsilon : `float`
+            Exploration probability. By default it is set to `0.0`
+
+        Returns
+        -------
+        tuple : (state, action, reward, next_action, done)
+        """
+        current_state, value, reward, next_action, done = self.env.step(state, action)
+
+        if np.random.rand() > epsilon:
+            state = np.asarray(state).astype(np.float32)
+            return current_state, value, reward, next_action, done
+        step_ = random.sample(self.env.get_transitions(), 1)
+        _state, greedy_action, _reward, _next_action, _done = zip(*step_)
+
+        return _state[0], greedy_action[0], _reward[0], _next_action[0], _done[0]
+
+    def play_one_step(self, state: np.ndarray, action: int, epsilon: float):
+        """
+        Perform one step in the environment and add experience to buffer
+
+        Parameters
+        ----------
+        state : `np.ndarray`
+            Current state
+        action : `int`
+            Expected action to process.
+
+        epsilon : `float`
+            Exploration probability.
+
+        Returns
+        -------
+        tuple : (state, action, reward, next_action, done)
+        """
+        current_state, greedy_action, reward, next_action, done = self.epsilon_greedy_policy(
+            state, action, epsilon
+        )
+
+        done = 1 if done else 0
+
+        # Add experience to replay buffer
+        self.replay_buffer.append(
+            (
+                current_state,  # Previous state
+                greedy_action,  # Current action
+                reward,  # Reward
+                next_action,  # Next action
+                done,  # Done flag
+            )
+        )
+
+        return current_state, greedy_action, reward, next_action, done
+
+    @tf.function
+    def _training_step(self):
+        """
+        Perform one training step using experience replay
+
+        Returns
+        -------
+        float : Training loss
+        """
+
+        batch_ = random.sample(self.replay_buffer, self.batch_size)
+        states, actions, rewards, next_actions, dones = zip(*batch_)
+        states = np.array(states).reshape(self.batch_size, -1)
+        actions = np.array(actions).reshape(
+            self.batch_size,
+        )
+        rewards = np.array(rewards).reshape(
+            self.batch_size,
+        )
+        max_next_Q_values = np.array(next_actions).reshape(self.batch_size, -1)
+        dones = np.array(dones).reshape(
+            self.batch_size,
+        )
+        target_Q_values = rewards + (1 - dones) * self.gamma * max_next_Q_values
+
+        actions = tf.convert_to_tensor(actions, dtype=tf.int32)
+        states = tf.convert_to_tensor(states, dtype=tf.float32)
+        target_Q_values = tf.convert_to_tensor(target_Q_values, dtype=tf.float32)
+
+        with tf.GradientTape() as tape:
+            all_Q_values = self.model(states)
+            indices = tf.stack([tf.range(tf.shape(actions)[0]), actions], axis=1)
+            Q_values = tf.gather_nd(all_Q_values, indices)
+            loss = tf.reduce_mean(self.loss_fn(target_Q_values, Q_values))
+        grads = tape.gradient(loss, self.model.trainable_variables)
+        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
+        return loss
+
+    def train(
+        self,
+        x_data,
+        y_data,
+        optimizer="adam",
+        loss_fn="mse",
+        num_episodes=50,
+        num_steps=100,
+        gamma=0.7,
+        batch_size=32,
+        patience=10,
+        alpha=0.01,
+    ):
+        """Train the agent for a fixed number of episodes
+
+        Parameters
+        ----------
+        optimizer : `str`
+            The optimizer for training (e.g., `sgd`). By default it is set to `adam`.
+        loss_fn : `str`
+            The loss function. By default it is set to `mse`.
+        num_episodes : `int`
+            Total number of episodes to train. By default it is set to `50`.
+        num_steps : `int`
+            Steps per episode. By default it is set to `100`. If `num_steps` is less than `self.env.maxlen`, then the second will be chosen.
+        gamma : `float`
+            Discount factor. By default it is set to `0.7`.
+        batch_size : `int`
+            Size of training batches. By default it is set to `32`.
+        patience : `int`
+            How many episodes to wait for improvement.
+        alpha : `float`
+            Trade-off factor between loss and reward.
+        """
+        rewards = []
+        self.best_weights = None
+        self.best_loss = float("inf")
+
+        optimizers = {
+            "sgd": tf.keras.optimizers.SGD(),
+            "adam": tf.keras.optimizers.Adam(),
+            "adamw": tf.keras.optimizers.AdamW(),
+            "adadelta": tf.keras.optimizers.Adadelta(),
+            "rmsprop": tf.keras.optimizers.RMSprop(),
+        }
+        self.optimizer = optimizers[optimizer]
+        losses = {
+            "mse": tf.keras.losses.MeanSquaredError(),
+            "mae": tf.keras.losses.MeanAbsoluteError(),
+            "mape": tf.keras.losses.MeanAbsolutePercentageError(),
+        }
+        self.loss_fn = losses[loss_fn]
+        self.num_episodes = num_episodes
+        self.num_steps = num_steps if num_steps >= self.env.maxlen else self.env.maxlen
+        self.gamma = gamma
+        self.batch_size = batch_size
+        loss = float("inf")
+        no_improve_count = 0
+        best_combined_metric = float("inf")
+
+        for episode in range(self.num_episodes):
+            print_progress_bar(episode + 1, self.num_episodes)
+            self.env.reset()
+            sum_rewards = 0
+            epsilon = max(1 - episode / (self.num_episodes * 0.8), 0.01)
+
+            for step in range(self.num_steps):
+                state, action, reward, next_action, done = self.play_one_step(
+                    x_data[step], y_data[step], epsilon
+                )
+                sum_rewards += reward if isinstance(reward, int) else reward[0]
+
+                # Train if buffer has enough samples
+                if len(self.replay_buffer) > self.batch_size:
+                    loss = self._training_step()
+
+                if done:
+                    break
+
+            combined_metric = loss - alpha * sum_rewards
+
+            if combined_metric < best_combined_metric:
+                best_combined_metric = combined_metric
+                self.best_weights = self.model.get_weights()
+                self.best_loss = loss
+                no_improve_count = 0  # Reset counter on improvement
+            else:
+                no_improve_count += 1
+
+            rewards.append(sum_rewards)
+
+            # Logging
+            if episode % (self.num_episodes // 10) == 0:
+                print(
+                    f"Episode: {episode}, Steps: {step+1}, Epsilon: {epsilon:.3f}, Loss: {loss:.2e}, Reward: {sum_rewards}, No Improve Count: {no_improve_count}"
+                )
+
+            # Early stopping condition
+            if no_improve_count >= patience:
+                print(
+                    f"Early stopping at episode {episode} due to no improvement in {patience} episodes."
+                )
+                break
+
+        # Save best model
+        self.model.set_weights(self.best_weights)
+
+    def __str__(self):
+        return (
+            f"AutoQL (Env: {self.env.name}, Episodes: {self.num_episodes}, Steps: {self.num_steps})"
+        )
+
+
+if __name__ == "__main__":
+    pass
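For orientation, here is a minimal usage sketch of the two classes above: an Env wraps a Keras classifier so that each step rewards a correct prediction, and AutoQL replays the resulting transitions to fine-tune the same network. The import path, toy data, and model are assumptions chosen for illustration, not taken from the package's documentation, and the train() call is shown commented out with its default hyperparameters rather than executed.

import numpy as np
import tensorflow as tf

from likelihood.models.deep.rl import AutoQL, Env  # assumed import path

# Toy two-class problem: 4 features, labels in {0, 1}.
rng = np.random.default_rng(0)
x_data = rng.normal(size=(200, 4)).astype(np.float32)
y_data = (x_data[:, 0] > 0).astype(np.int32)

# Q-network with one output per action (here, per class).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(2),
    ]
)

env = Env(model, maxlen=100)
agent = AutoQL(env, model, maxlen=2000)

# One manual interaction: reward is 1 when argmax(model(state)) equals the label.
state, action_pred, reward, next_action, done = env.step(x_data[0], int(y_data[0]))
print(action_pred, reward, done)

# Full training entry point (defaults shown); illustrative only, not executed here.
# agent.train(x_data, y_data, optimizer="adam", loss_fn="mse",
#             num_episodes=50, num_steps=100, gamma=0.7, batch_size=32)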
likelihood/models/environments.py
@@ -0,0 +1,202 @@
+from collections import defaultdict
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+
+class ActionSpace:
+    def __init__(self, num_actions):
+        self._num_actions = num_actions
+
+    @property
+    def n(self):
+        return self._num_actions
+
+
+class OptionCriticEnv:
+    """
+    An environment for Option Critic reinforcement learning that processes a dataset of episodes.
+
+    Attributes
+    ----------
+    episodes : `Dict[str, Dict]`
+        Dataset of episodes with state, action, selected_option, reward, next_state, and done information.
+    observation_space : `np.ndarray`
+        Initial observation space shape (from first episode's state)
+    done : `bool`
+        Whether the current episode has terminated
+    num_options : `int`
+        Number of distinct options available in the dataset
+    actions_by_option : `defaultdict(set)`
+        Maps selected options to sets of actions that were taken with them
+    unique_actions_count : `List[int]`
+        Count of unique actions per option index (used for action space definition)
+    action_space : `ActionSpace`
+        Custom action space defined by unique actions per option
+    idx_episode : `int`
+        Current episode index being processed
+    current_state : `np.ndarray`
+        Current state observation in the environment
+    """
+
+    def __init__(
+        self,
+        episodes: Dict[int, Dict[str, List]],
+    ):
+        """
+        Initializes the OptionCriticEnv with a dataset of episodes.
+
+        Parameters
+        ----------
+        episodes : `Dict[int, Dict]`
+            Dataset of episodes where keys are episode identifiers and values are episode data.
+            Each episode must contain at least:
+            - "state": List of state observations
+            - "selected_option": List of selected options
+            - "action": List of actions taken
+            - "reward": List of rewards
+            - "next_state": List of next states
+            - "done": List of termination flags
+
+        Raises
+        ------
+        ValueError
+            If required fields ("state" or "selected_option") are missing from episode data
+        """
+        self.episodes = episodes
+
+        required_keys = ["state", "action", "selected_option", "reward", "next_state", "done"]
+        for episode_id, data in episodes.items():
+            if not all(k in data for k in required_keys):
+                raise ValueError(
+                    f"Episode {episode_id} missing keys: {set(required_keys) - set(data.keys())}"
+                )
+
+        self.observation_space = np.array(episodes[0]["state"][0])
+        self.done = False
+        self.idx_episode = 0
+        self.current_state = None
+        self.num_options = len(set(episodes[0]["selected_option"]))
+        self.actions_by_option = defaultdict(set)
+
+        # Build fast lookup for transitions
+        self.state_action_option_to_transition: Dict[Tuple, Dict[str, Any]] = {}
+
+        for episode_id, data in episodes.items():
+            states = data["state"]
+            actions = data["action"]
+            options = data["selected_option"]
+            next_states = data["next_state"]
+            rewards = data["reward"]
+            dones = data["done"]
+
+            for i in range(len(states)):
+                state_key = tuple(states[i])
+                key = (state_key, options[i], actions[i])
+
+                self.state_action_option_to_transition[key] = {
+                    "next_state": next_states[i],
+                    "reward": rewards[i],
+                    "done": dones[i],
+                }
+
+            for i, selected in enumerate(options):
+                self.actions_by_option[selected].add(actions[i])
+
+        self.unique_actions_count = [
+            len(self.actions_by_option.get(i, set()))
+            for i in range(max(self.actions_by_option.keys()) + 1)
+        ]
+
+        self.action_space = ActionSpace(self.unique_actions_count)
+
+    def reset(self) -> tuple[np.ndarray, dict]:
+        """
+        Resets the environment to a random episode and returns the initial state.
+
+        Returns
+        -------
+        observation : `np.ndarray`
+            Initial state observation
+        info : `Dict`
+            Empty dictionary (no additional information)
+        """
+        episode_id = np.random.choice(list(self.episodes.keys()))
+        self.idx_episode = episode_id
+        self.current_state = self.episodes[episode_id]["state"][0]
+        return self.current_state, {}
+
+    def step(self, action: int, option: int) -> tuple[np.ndarray, float, bool, bool, dict]:
+        """
+        Takes an action with a specific option and returns the next state, reward, and termination status.
+
+        Parameters
+        ----------
+        action : `int`
+            Action index to execute
+        option : `int`
+            Selected option index
+
+        Returns
+        -------
+        next_state : `np.ndarray`
+            State after taking the action
+        reward : `float`
+            Immediate reward for the transition
+        done : `bool`
+            Whether the episode has terminated (from episode data)
+        terminated : `bool`
+            Whether the action-option pair was found in the dataset
+        info : `Dict`
+            Empty dictionary (no additional information)
+        """
+        key = (tuple(self.current_state), option, action)
+        if key in self.state_action_option_to_transition:
+            trans = self.state_action_option_to_transition[key]
+            self.current_state = trans["next_state"]
+            return trans["next_state"].copy(), trans["reward"], trans["done"], True, {}
+        else:
+            return self.current_state, 0.0, False, False, {}
+
+
+if __name__ == "__main__":
+    data = {
+        0: {
+            "state": [
+                np.array([0.03503893, 0.0471871, 0.00121938, -0.00847874]),
+                np.array([0.03598267, -0.14795232, 0.00104981, 0.28458866]),
+            ],
+            "selected_option": [0, 0],
+            "action": [0, 0],
+            "next_state": [
+                np.array([0.03598267, -0.14795232, 0.00104981, 0.28458866]),
+                np.array([0.03302363, -0.34308922, 0.00674158, 0.5776025]),
+            ],
+            "reward": [1.0, 1.0],
+            "done": [False, False],
+        },
+        1: {
+            "state": [
+                np.array([0.04769269, -0.03987791, -0.01187594, 0.02884407]),
+                np.array([0.04689513, -0.23482755, -0.01129905, 0.31775647]),
+            ],
+            "selected_option": [0, 0],
+            "action": [0, 0],
+            "next_state": [
+                np.array([0.04689513, -0.23482755, -0.01129905, 0.31775647]),
+                np.array([0.04219858, -0.42978677, -0.00494392, 0.6068548]),
+            ],
+            "reward": [1.0, 1.0],
+            "done": [False, False],
+        },
+    }
+
+    # Initialize environment
+    env = OptionCriticEnv(episodes=data)
+    env.reset()
+    num_actions = env.action_space.n
+    print("current state :", env.current_state)
+    print("environment step :", env.step(1, 0))
+    print("current state :", env.current_state)
+    print("environment step :", env.step(1, 0))
+    print("num_actions :", num_actions)
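A note on the __main__ demo above: action 1 never appears in the dataset, so both env.step(1, 0) calls miss the transition lookup and fall through to the else branch, returning the unchanged current state with reward 0.0 and a found flag of False. Continuing that demo, a call that does match a stored (state, option, action) key looks like the sketch below; the unpacked variable names are illustrative, following the step() docstring.

env.reset()                                    # pick a random episode; current_state = its first state
next_state, reward, done, found, info = env.step(0, 0)
print(found, reward)                           # True 1.0 — (state, option=0, action=0) is in the dataset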