gym-examples 3.0.376-py3-none-any.whl → 3.0.378-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gym_examples/__init__.py CHANGED
@@ -5,4 +5,4 @@ register(
     entry_point="gym_examples.envs:WSNRoutingEnv",
 )
 
-__version__ = "3.0.376"
+__version__ = "3.0.378"
gym_examples/envs/wsn_env.py CHANGED
@@ -29,32 +29,31 @@ base_back_up_dir = "results/data/"
 max_reward = 1 # maximum reward value when the sensors sent data to the base station. The opposite value is when the sensors perform an unauthorized action
 
 # Define the final reward function using an attention mechanism
-class CustomizedLinear(nn.Module):
-    def __init__(self, input_dim, output_dim):
-        super(CustomizedLinear, self).__init__()
-        self.weight = nn.Parameter(torch.rand(output_dim, input_dim))
-        self.bias = None # No bias term
-        self.Softplus = nn.Softplus() # SoftPlus activation function to ensure non-negative values: Check the paper for more details
-
-    def forward(self, x):
-        # Enforce non-negativity of weights
-        weight = self.Softplus(self.weight)
+# class CustomizedLinear(nn.Module):
+#     def __init__(self, input_dim, output_dim):
+#         super(CustomizedLinear, self).__init__()
+#         self.weight = nn.Parameter(torch.rand(output_dim, input_dim))
+#         self.bias = None # No bias term
 
-        # Normalize cols to ensure that if sum(x1) < sum(x2) ==> sum(Ax1 + 0) < sum(Ax2 + 0): proof in the paper
-        col_sums = weight.sum(dim=0, keepdim=True)
-        normalized_weight = weight / col_sums
+#     def forward(self, x):
+#         # Normalize cols to ensure that if sum(x1) < sum(x2) ==> sum(Ax1 + 0) < sum(Ax2 + 0): proof in the paper
+#         col_sums = self.weight.sum(dim=0, keepdim=True)
+#         normalized_weight = self.weight / col_sums
 
-        # Output
-        y = torch.matmul(x, normalized_weight.t())
-        return y
+#         # Output
+#         y = torch.matmul(x, normalized_weight.t())
+#         return y
 
 class Attention(nn.Module):
     def __init__(self, input_dim, output_dim):
         super(Attention, self).__init__() # Call the initializer of the parent class (nn.Module)
         self.input_dim = input_dim # Set the input dimension of the network
         self.output_dim = output_dim # Set the output dimension of the network
-        self.linear1 = CustomizedLinear(input_dim, 64) # Define the first linear layer. It takes input of size 'input_dim' and outputs size '64'
-        self.linear2 = CustomizedLinear(64, output_dim) # Define the second linear layer. It takes input of size '64' and outputs size 'output_dim'
+        # self.linear1 = CustomizedLinear(input_dim, 64) # Define the first linear layer. It takes input of size 'input_dim' and outputs size '64'
+        # self.linear2 = CustomizedLinear(64, output_dim) # Define the second linear layer. It takes input of size '64' and outputs size 'output_dim'
+        self.linear1 = nn.Linear(input_dim, 64) # Define the first linear layer. It takes input of size 'input_dim' and outputs size '64'
+        self.linear2 = nn.Linear(64, output_dim) # Define the second linear layer. It takes input of size '64' and outputs size 'output_dim'
+
 
     def forward(self, x):
         # Step 1: Ensure input is 2D by adding a batch dimension if necessary
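For context, the constrained layer this release disables can be reconstructed from the deleted lines as a small standalone module. The sketch below follows the 3.0.376 definition (Softplus to keep the weights non-negative, then column normalization so that sum(x1) < sum(x2) implies sum(Wx1) < sum(Wx2)); the indentation and the usage lines are reconstructed rather than copied from the package, and 3.0.378 simply swaps in plain nn.Linear layers instead.

    import torch
    import torch.nn as nn

    class CustomizedLinear(nn.Module):
        """Bias-free linear layer with non-negative, column-normalized weights."""
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.weight = nn.Parameter(torch.rand(output_dim, input_dim))
            self.softplus = nn.Softplus()  # keeps the effective weights non-negative

        def forward(self, x):
            weight = self.softplus(self.weight)                 # enforce non-negativity
            weight = weight / weight.sum(dim=0, keepdim=True)   # normalize columns
            return torch.matmul(x, weight.t())                  # no bias term

    x = torch.rand(4, 10)
    print(CustomizedLinear(10, 64)(x).shape)   # torch.Size([4, 64])
    print(nn.Linear(10, 64)(x).shape)          # same shape with the unconstrained replacement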
@@ -197,6 +196,7 @@ class WSNRoutingEnv(gym.Env):
                     self.packet_latency[i] = 0
 
                     rewards[i] = [max_reward] * input_dim # Reward for transmitting data to the base station
+                    print(f"Sensor {i} transmitted data to the base station with modified reward: {self.compute_attention_rewards(rewards[i])}")
                     dones[i] = True
                 else:
                     distance = np.linalg.norm(self.sensor_positions[i] - self.sensor_positions[action])
@@ -222,7 +222,11 @@ class WSNRoutingEnv(gym.Env):
                     self.packet_latency[action] += self.packet_latency[i] + latency_per_hop
                     self.packet_latency[i] = 0
 
-                    rewards[i] = self.compute_individual_rewards(i, action)
+                    rewards[i] = self.compute_individual_rewards(i, action)
+                    if self.number_steps > 27:
+                        raise Error("Stop here")
+                    else:
+                        print(f"Sensor {i} transmitted data to sensor {action} with modified reward: {self.compute_attention_rewards(rewards[i])}")
 
                     # Update the number of packets
                     self.number_of_packets[action] += self.number_of_packets[i]
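The new print statements pass each per-sensor reward list through self.compute_attention_rewards before logging it. That method is not part of this diff, so the snippet below is only a plausible, self-contained sketch of the call pattern, assuming it feeds the rewards through the Attention module defined above and reduces the result to a scalar; the forward pass here keeps only the two layers visible in this hunk.

    import torch
    import torch.nn as nn

    class Attention(nn.Module):
        # Layer sizes as in 3.0.378: input_dim -> 64 -> output_dim, plain nn.Linear
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.linear1 = nn.Linear(input_dim, 64)
            self.linear2 = nn.Linear(64, output_dim)

        def forward(self, x):
            if x.dim() == 1:           # add a batch dimension, as the original forward does
                x = x.unsqueeze(0)
            return self.linear2(self.linear1(x))  # simplified: the package's forward has further steps

    # Hypothetical stand-in for WSNRoutingEnv.compute_attention_rewards (not shown in this diff)
    def compute_attention_rewards(net, rewards_i):
        x = torch.tensor(rewards_i, dtype=torch.float32)  # e.g. [max_reward] * input_dim
        return net(x).sum().item()                        # scalar "modified reward" used in the prints

    net = Attention(input_dim=4, output_dim=1)
    print(compute_attention_rewards(net, [1.0, 1.0, 1.0, 1.0]))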
{gym_examples-3.0.376.dist-info → gym_examples-3.0.378.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: gym-examples
-Version: 3.0.376
+Version: 3.0.378
 Summary: A custom environment for multi-agent reinforcement learning focused on WSN routing.
 Home-page: https://github.com/gedji/CODES.git
 Author: Georges Djimefo
gym_examples-3.0.378.dist-info/RECORD ADDED
@@ -0,0 +1,7 @@
+gym_examples/__init__.py,sha256=rkNQCsuNv0lbzwDFbbINy72Zb5H42sq2w5lvX-kVL5k,166
+gym_examples/envs/__init__.py,sha256=lgMe4pyOuUTgTBUddM0iwMlETsYTwFShny6ifm8PGM8,53
+gym_examples/envs/wsn_env.py,sha256=UITF_EWvfj15xD39FPBqig-tCSh2XSLHLUDcnKGATtU,27260
+gym_examples-3.0.378.dist-info/METADATA,sha256=23rCX5hkl1xinlpHNHZIdFH1-F99ZtUk03yRGEek-U0,412
+gym_examples-3.0.378.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
+gym_examples-3.0.378.dist-info/top_level.txt,sha256=rJRksoAF32M6lTLBEwYzRdo4PgtejceaNnnZ3HeY_Rk,13
+gym_examples-3.0.378.dist-info/RECORD,,
gym_examples-3.0.376.dist-info/RECORD DELETED
@@ -1,7 +0,0 @@
-gym_examples/__init__.py,sha256=saXnZJjlgXXuJWFxaY2VD_0xdoYKK3gEsrD5TPNdybE,166
-gym_examples/envs/__init__.py,sha256=lgMe4pyOuUTgTBUddM0iwMlETsYTwFShny6ifm8PGM8,53
-gym_examples/envs/wsn_env.py,sha256=CDqmuxieq6URlTpDQWgDcOaKJVhhzd6h4jrh72WvGh4,26756
-gym_examples-3.0.376.dist-info/METADATA,sha256=ZIc51ev191-r6-0hK2-dntjvgjPinlgtJ2UAxbQHEnQ,412
-gym_examples-3.0.376.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
-gym_examples-3.0.376.dist-info/top_level.txt,sha256=rJRksoAF32M6lTLBEwYzRdo4PgtejceaNnnZ3HeY_Rk,13
-gym_examples-3.0.376.dist-info/RECORD,,