gym-examples 3.0.753__py3-none-any.whl → 3.0.755__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gym_examples/__init__.py +1 -1
- gym_examples/envs/wsn_env.py +15 -12
- {gym_examples-3.0.753.dist-info → gym_examples-3.0.755.dist-info}/METADATA +1 -1
- gym_examples-3.0.755.dist-info/RECORD +7 -0
- gym_examples-3.0.753.dist-info/RECORD +0 -7
- {gym_examples-3.0.753.dist-info → gym_examples-3.0.755.dist-info}/WHEEL +0 -0
- {gym_examples-3.0.753.dist-info → gym_examples-3.0.755.dist-info}/top_level.txt +0 -0
gym_examples/__init__.py
CHANGED
gym_examples/envs/wsn_env.py
CHANGED
@@ -30,7 +30,7 @@ base_back_up_dir = "results/data/"
|
|
30
30
|
max_reward = 1 # maximum reward value when the sensors sent data to the base station. The opposite value is when the sensors perform an unauthorized action
|
31
31
|
|
32
32
|
class ScalarAttentionModel(nn.Module):
|
33
|
-
def __init__(self, input_dim
|
33
|
+
def __init__(self, input_dim, output_dim=1):
|
34
34
|
super(ScalarAttentionModel, self).__init__()
|
35
35
|
# Initialize GaussianAdaptiveAttention
|
36
36
|
self.ga_attention = GaussianAdaptiveAttention(
|
@@ -51,8 +51,8 @@ class ScalarAttentionModel(nn.Module):
|
|
51
51
|
scalar_output = self.output_layer(attention_output)
|
52
52
|
return scalar_output
|
53
53
|
|
54
|
-
net = ScalarAttentionModel(input_dim)
|
55
|
-
net = net.double() # Convert the weights to Double
|
54
|
+
# net = ScalarAttentionModel(input_dim)
|
55
|
+
# net = net.double() # Convert the weights to Double
|
56
56
|
|
57
57
|
class WSNRoutingEnv(gym.Env):
|
58
58
|
|
@@ -227,7 +227,7 @@ class WSNRoutingEnv(gym.Env):
|
|
227
227
|
|
228
228
|
# rewards = [reward.item() if isinstance(reward, torch.Tensor) else reward for reward in rewards] # Convert the reward to a float
|
229
229
|
|
230
|
-
rewards = self.compute_attention_rewards(rewards)
|
230
|
+
rewards = self.compute_attention_rewards(rewards, self.n_sensors)
|
231
231
|
print(f"Rewards: {rewards}")
|
232
232
|
# rewards = np.mean(rewards)
|
233
233
|
|
@@ -466,17 +466,20 @@ class WSNRoutingEnv(gym.Env):
|
|
466
466
|
return np.clip(normalized_throughput, 0, 1)
|
467
467
|
|
468
468
|
|
469
|
-
def compute_attention_rewards(self, rewards):
|
469
|
+
def compute_attention_rewards(self, rewards, n_sensors):
|
470
470
|
'''
|
471
471
|
Compute the attention-based rewards
|
472
472
|
'''
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
473
|
+
net = ScalarAttentionModel(n_sensors)
|
474
|
+
net = net.double() # Convert the weights to Double
|
475
|
+
final_reward = []
|
476
|
+
for i in range(len(rewards[0])):
|
477
|
+
rewards_i = [reward[i] for reward in rewards]
|
478
|
+
rewards_i = torch.tensor(rewards_i, dtype=torch.double)
|
479
|
+
rewards_i = rewards_i.unsqueeze(0) # Add batch dimension
|
480
|
+
final_reward.append(net(rewards_i).item())
|
481
|
+
|
482
|
+
return final_reward.mean().item()
|
480
483
|
|
481
484
|
|
482
485
|
# def compute_attention_reward(self, rewards):
|
@@ -0,0 +1,7 @@
|
|
1
|
+
gym_examples/__init__.py,sha256=815cyqeKkNC7ftFt8Vtj07athkj36c2STNlFxmRcqlc,166
|
2
|
+
gym_examples/envs/__init__.py,sha256=lgMe4pyOuUTgTBUddM0iwMlETsYTwFShny6ifm8PGM8,53
|
3
|
+
gym_examples/envs/wsn_env.py,sha256=h56qfOsV7mjaQi9eaZ1Mc92X322sGXPZV_KYUeUgiRE,27000
|
4
|
+
gym_examples-3.0.755.dist-info/METADATA,sha256=2wawIwJy9MsQXABHUT0Oq63fUBw2mtHxKRw2Ipv-R_Q,412
|
5
|
+
gym_examples-3.0.755.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
6
|
+
gym_examples-3.0.755.dist-info/top_level.txt,sha256=rJRksoAF32M6lTLBEwYzRdo4PgtejceaNnnZ3HeY_Rk,13
|
7
|
+
gym_examples-3.0.755.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
gym_examples/__init__.py,sha256=HPc6ReDuVSkSVN757l9_qqSk5jUFLgJf2quo1k9tnnw,166
|
2
|
-
gym_examples/envs/__init__.py,sha256=lgMe4pyOuUTgTBUddM0iwMlETsYTwFShny6ifm8PGM8,53
|
3
|
-
gym_examples/envs/wsn_env.py,sha256=DX1Ejs3qj417OAFgCi4Y2Bq7kFX8wUFp3flGeX0IZPI,27008
|
4
|
-
gym_examples-3.0.753.dist-info/METADATA,sha256=9kxvKD1m4MgAjh6PeOfoWA6DlEgIdMlnsvC_mttYQxU,412
|
5
|
-
gym_examples-3.0.753.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
6
|
-
gym_examples-3.0.753.dist-info/top_level.txt,sha256=rJRksoAF32M6lTLBEwYzRdo4PgtejceaNnnZ3HeY_Rk,13
|
7
|
-
gym_examples-3.0.753.dist-info/RECORD,,
|
File without changes
|
File without changes
|