plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/utils/reinforcement_learning/policies/ddpg.py
@@ -0,0 +1,75 @@
+ """
+ Reference:
+
+ https://github.com/sweetice/Deep-reinforcement-learning-with-pytorch
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from plato.config import Config
+ from plato.utils.reinforcement_learning.policies import base
+
+
+ class Policy(base.Policy):
+     def __init__(self, state_dim, action_space):
+         super().__init__(state_dim, action_space)
+
+     def select_action(self, state):
+         """Select action from policy."""
+         state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
+         return self.actor(state).cpu().data.numpy().flatten()
+
+     def update(self):
+         """Update policy."""
+         for _ in range(Config().algorithm.update_iteration):
+             # Sample replay buffer
+             state, action, reward, next_state, done = self.replay_buffer.sample()
+             state = torch.FloatTensor(state).to(self.device).unsqueeze(1)
+             action = torch.FloatTensor(action).to(self.device).unsqueeze(1)
+             reward = torch.FloatTensor(reward).to(self.device).unsqueeze(1)
+             next_state = torch.FloatTensor(next_state).to(self.device).unsqueeze(1)
+             done = torch.FloatTensor(done).to(self.device).unsqueeze(1)
+
+             # Compute the target Q value
+             target_Q = self.critic_target(next_state, self.actor_target(next_state))
+             target_Q = (
+                 reward + ((1 - done) * Config().algorithm.gamma * target_Q).detach()
+             )
+
+             # Get current Q estimate
+             current_Q = self.critic(state, action)
+
+             # Compute critic loss
+             critic_loss = F.mse_loss(current_Q, target_Q)
+
+             # Optimize the critic
+             self.critic_optimizer.zero_grad()
+             critic_loss.backward()
+             self.critic_optimizer.step()
+
+             # Compute actor loss
+             actor_loss = -self.critic(state, self.actor(state)).mean()
+
+             # Optimize the actor
+             self.actor_optimizer.zero_grad()
+             actor_loss.backward()
+             self.actor_optimizer.step()
+
+             # Update the frozen target models
+             for param, target_param in zip(
+                 self.critic.parameters(), self.critic_target.parameters()
+             ):
+                 target_param.data.copy_(
+                     Config().algorithm.tau * param.data
+                     + (1 - Config().algorithm.tau) * target_param.data
+                 )
+
+             for param, target_param in zip(
+                 self.actor.parameters(), self.actor_target.parameters()
+             ):
+                 target_param.data.copy_(
+                     Config().algorithm.tau * param.data
+                     + (1 - Config().algorithm.tau) * target_param.data
+                 )
+
+         return critic_loss.item(), actor_loss.item()
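
Note: the update() method above follows the standard DDPG recipe: a one-step Bellman target for the critic, a deterministic policy-gradient loss for the actor, and a Polyak blend of the online weights into the frozen targets. The following is a minimal numeric sketch of those two formulas with made-up values; in the code above, gamma and tau are read from Config().algorithm.

import torch

# Bellman target: y = r + (1 - done) * gamma * Q_target(s', actor_target(s'))
gamma = 0.99                           # assumed discount factor
reward = torch.tensor([[1.0]])         # r from the replay batch
done = torch.tensor([[0.0]])           # 0.0 while the episode continues
target_q_next = torch.tensor([[2.5]])  # critic_target(next_state, actor_target(next_state))

target_q = reward + (1 - done) * gamma * target_q_next
print(target_q)  # tensor([[3.4750]]) -- 1.0 + 0.99 * 2.5

# Polyak (soft) target update: theta_target <- tau * theta + (1 - tau) * theta_target
tau = 0.005                            # assumed blending factor
online_w, target_w = torch.tensor(1.0), torch.tensor(0.0)
target_w = tau * online_w + (1 - tau) * target_w
print(target_w)  # tensor(0.0050)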
plato/utils/reinforcement_learning/policies/registry.py
@@ -0,0 +1,32 @@
+ """
+ Having a registry of all available classes is convenient for retrieving an instance
+ based on a configuration at run-time.
+ """
+
+ import logging
+ from collections import OrderedDict
+
+ from plato.config import Config
+ from plato.utils.reinforcement_learning.policies import base, ddpg, sac, td3
+
+ registered_policies = OrderedDict(
+     [
+         ("base", base.Policy),
+         ("ddpg", ddpg.Policy),
+         ("sac", sac.Policy),
+         ("td3", td3.Policy),
+     ]
+ )
+
+
+ def get(state_dim, action_space):
+     """Get the DRL policy with the provided name."""
+     policy_name = Config().algorithm.model_name
+     logging.info("DRL Policy: %s", policy_name)
+
+     if policy_name in registered_policies:
+         registered_policy = registered_policies[policy_name](state_dim, action_space)
+     else:
+         raise ValueError("No such policy: {}".format(policy_name))
+
+     return registered_policy
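
Note: the registry above maps a configuration string (Config().algorithm.model_name) to a policy class and instantiates it with the environment's state dimension and action space. The sketch below shows the same lookup pattern standalone, with placeholder classes instead of the real DDPG/SAC/TD3 policies so it runs without a Plato configuration file.

from collections import OrderedDict

# Placeholder policy classes; the real registry maps to base.Policy, ddpg.Policy, etc.
class BasePolicy: ...
class DDPGPolicy(BasePolicy): ...
class TD3Policy(BasePolicy): ...

registered_policies = OrderedDict(
    [("base", BasePolicy), ("ddpg", DDPGPolicy), ("td3", TD3Policy)]
)

def get(policy_name, *args, **kwargs):
    # In Plato, policy_name would come from Config().algorithm.model_name.
    if policy_name not in registered_policies:
        raise ValueError(f"No such policy: {policy_name}")
    return registered_policies[policy_name](*args, **kwargs)

policy = get("td3")
print(type(policy).__name__)  # TD3Policy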
plato/utils/reinforcement_learning/policies/sac.py
@@ -0,0 +1,343 @@
+ """
+ Reference:
+
+ https://github.com/pranz24/pytorch-soft-actor-critic
+ """
+
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from plato.config import Config
+ from plato.utils.reinforcement_learning.policies import base
+ from torch.distributions import Normal
+
+ LOG_SIG_MAX = 2
+ LOG_SIG_MIN = -20
+
+
+ def create_log_gaussian(mean, log_std, t):
+     quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
+     l = mean.shape
+     log_z = log_std
+     z = l[-1] * math.log(2 * math.pi)
+     log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
+     return log_p
+
+
+ def logsumexp(inputs, dim=None, keepdim=False):
+     if dim is None:
+         inputs = inputs.view(-1)
+         dim = 0
+     s, _ = torch.max(inputs, dim=dim, keepdim=True)
+     outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
+     if not keepdim:
+         outputs = outputs.squeeze(dim)
+     return outputs
+
+
+ def soft_update(target, source, tau):
+     for target_param, param in zip(target.parameters(), source.parameters()):
+         target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
+
+
+ def hard_update(target, source):
+     for target_param, param in zip(target.parameters(), source.parameters()):
+         target_param.data.copy_(param.data)
+
+
+ # Initialize Policy weights
+ def weights_init_(m):
+     if isinstance(m, nn.Linear):
+         torch.nn.init.xavier_uniform_(m.weight, gain=1)
+         torch.nn.init.constant_(m.bias, 0)
+
+
+ class ValueNetwork(nn.Module):
+     def __init__(self, num_inputs, hidden_dim):
+         super(ValueNetwork, self).__init__()
+
+         self.linear1 = nn.Linear(num_inputs, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+         self.linear3 = nn.Linear(hidden_dim, 1)
+
+         self.apply(weights_init_)
+
+     def forward(self, state):
+         x = F.relu(self.linear1(state))
+         x = F.relu(self.linear2(x))
+         x = self.linear3(x)
+         return x
+
+
+ class QNetwork(nn.Module):
+     def __init__(self, num_inputs, num_actions, hidden_dim):
+         super(QNetwork, self).__init__()
+
+         # Q1 architecture
+         self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+         self.linear3 = nn.Linear(hidden_dim, 1)
+
+         # Q2 architecture
+         self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
+         self.linear5 = nn.Linear(hidden_dim, hidden_dim)
+         self.linear6 = nn.Linear(hidden_dim, 1)
+
+         self.apply(weights_init_)
+
+     def forward(self, state, action):
+         xu = torch.cat([state, action], 1)
+
+         x1 = F.relu(self.linear1(xu))
+         x1 = F.relu(self.linear2(x1))
+         x1 = self.linear3(x1)
+
+         x2 = F.relu(self.linear4(xu))
+         x2 = F.relu(self.linear5(x2))
+         x2 = self.linear6(x2)
+
+         return x1, x2
+
+
+ class GaussianPolicy(nn.Module):
+     def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
+         super(GaussianPolicy, self).__init__()
+
+         self.linear1 = nn.Linear(num_inputs, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+
+         self.mean_linear = nn.Linear(hidden_dim, num_actions)
+         self.log_std_linear = nn.Linear(hidden_dim, num_actions)
+
+         self.apply(weights_init_)
+
+         # action rescaling
+         if action_space is None:
+             self.action_scale = torch.tensor(1.0)
+             self.action_bias = torch.tensor(0.0)
+         else:
+             self.action_scale = torch.FloatTensor(
+                 (action_space.high - action_space.low) / 2.0
+             )
+             self.action_bias = torch.FloatTensor(
+                 (action_space.high + action_space.low) / 2.0
+             )
+
+     def forward(self, state):
+         x = F.relu(self.linear1(state))
+         x = F.relu(self.linear2(x))
+         mean = self.mean_linear(x)
+         log_std = self.log_std_linear(x)
+         log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
+         return mean, log_std
+
+     def sample(self, state):
+         mean, log_std = self.forward(state)
+         std = log_std.exp()
+         normal = Normal(mean, std)
+         x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
+         y_t = torch.tanh(x_t)
+         action = y_t * self.action_scale + self.action_bias
+         log_prob = normal.log_prob(x_t)
+         # Enforcing Action Bound
+         log_prob -= torch.log(
+             self.action_scale * (1 - y_t.pow(2)) + Config().algorithm.epsilon
+         )
+         log_prob = log_prob.sum(1, keepdim=True)
+         mean = torch.tanh(mean) * self.action_scale + self.action_bias
+         return action, log_prob, mean
+
+     def to(self, device):
+         self.action_scale = self.action_scale.to(device)
+         self.action_bias = self.action_bias.to(device)
+         return super(GaussianPolicy, self).to(device)
+
+
+ class DeterministicPolicy(nn.Module):
+     def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
+         super(DeterministicPolicy, self).__init__()
+         self.linear1 = nn.Linear(num_inputs, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, hidden_dim)
+
+         self.mean = nn.Linear(hidden_dim, num_actions)
+         self.noise = torch.Tensor(num_actions)
+
+         self.apply(weights_init_)
+
+         # action rescaling
+         if action_space is None:
+             self.action_scale = 1.0
+             self.action_bias = 0.0
+         else:
+             self.action_scale = torch.FloatTensor(
+                 (action_space.high - action_space.low) / 2.0
+             )
+             self.action_bias = torch.FloatTensor(
+                 (action_space.high + action_space.low) / 2.0
+             )
+
+     def forward(self, state):
+         x = F.relu(self.linear1(state))
+         x = F.relu(self.linear2(x))
+         mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
+         return mean
+
+     def sample(self, state):
+         mean = self.forward(state)
+         noise = self.noise.normal_(0.0, std=0.1)
+         noise = noise.clamp(-0.25, 0.25)
+         action = mean + noise
+         return action, torch.tensor(0.0), mean
+
+     def to(self, device):
+         self.action_scale = self.action_scale.to(device)
+         self.action_bias = self.action_bias.to(device)
+         self.noise = self.noise.to(device)
+         return super(DeterministicPolicy, self).to(device)
+
+
+ class Policy(base.Policy):
+     def __init__(self, state_dim, action_space):
+         super().__init__(state_dim, action_space)
+
+         # Initialize NNs
+         self.critic = QNetwork(
+             state_dim, action_space.shape[0], Config().algorithm.hidden_size
+         ).to(self.device)
+         self.critic_optimizer = torch.optim.Adam(
+             self.critic.parameters(), lr=Config().algorithm.learning_rate
+         )
+
+         self.critic_target = QNetwork(
+             state_dim, action_space.shape[0], Config().algorithm.hidden_size
+         ).to(self.device)
+         hard_update(self.critic_target, self.critic)
+
+         if Config().algorithm.deterministic:
+             self.alpha = 0
+             self.automatic_entropy_tuning = False
+             self.actor = DeterministicPolicy(
+                 state_dim,
+                 action_space.shape[0],
+                 Config().algorithm.hidden_size,
+                 action_space,
+             ).to(self.device)
+             self.actor_optimizer = torch.optim.Adam(
+                 self.actor.parameters(), lr=Config().algorithm.learning_rate
+             )
+         else:
+             if self.automatic_entropy_tuning is True:
+                 self.target_entropy = -torch.prod(
+                     torch.Tensor(action_space.shape).to(self.device)
+                 ).item()
+                 self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
+                 self.alpha_optimizer = torch.optim.Adam(
+                     [self.log_alpha], lr=Config().algorithm.learning_rate
+                 )
+
+             self.actor = GaussianPolicy(
+                 state_dim,
+                 action_space.shape[0],
+                 Config().algorithm.hidden_size,
+                 action_space,
+             ).to(self.device)
+             self.actor_optimizer = torch.optim.Adam(
+                 self.actor.parameters(), lr=Config().algorithm.learning_rate
+             )
+
+         # Initialize replay memory
+         self.replay_buffer = base.ReplayMemory(
+             state_dim,
+             action_space.shape[0],
+             Config().algorithm.replay_size,
+             Config().algorithm.replay_seed,
+         )
+         self.alpha = Config().algorithm.alpha
+         self.automatic_entropy_tuning = Config().algorithm.automatic_entropy_tuning
+
+     def select_action(self, state, test=False):
+         state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
+         if test is False:
+             action, _, _ = self.actor.sample(state)
+         else:
+             _, _, action = self.actor.sample(state)
+         return action.detach().cpu().numpy().flatten()
+
+     def update(self):
+         for _ in range(Config().algorithm.update_iteration):
+             # Sample a batch from memory
+             state_batch, action_batch, reward_batch, next_state_batch, mask_batch = (
+                 self.replay_buffer.sample()
+             )
+
+             state_batch = torch.FloatTensor(state_batch).to(self.device)
+             next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
+             action_batch = torch.FloatTensor(action_batch).to(self.device)
+             reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
+             mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
+
+             with torch.no_grad():
+                 next_state_action, next_state_log_pi, _ = self.actor.sample(
+                     next_state_batch
+                 )
+                 qf1_next_target, qf2_next_target = self.critic_target(
+                     next_state_batch, next_state_action
+                 )
+                 min_qf_next_target = (
+                     torch.min(qf1_next_target, qf2_next_target)
+                     - self.alpha * next_state_log_pi
+                 )
+                 next_q_value = reward_batch + (
+                     1 - mask_batch
+                 ) * Config().algorithm.gamma * (min_qf_next_target)
+             qf1, qf2 = self.critic(
+                 state_batch, action_batch
+             )  # Two Q-functions to mitigate positive bias in the policy improvement step
+             qf1_loss = F.mse_loss(qf1, next_q_value)
+             qf2_loss = F.mse_loss(qf2, next_q_value)
+             qf_loss = qf1_loss + qf2_loss
+
+             self.critic_optimizer.zero_grad()
+             qf_loss.backward()
+             self.critic_optimizer.step()
+
+             pi, log_pi, _ = self.actor.sample(state_batch)
+
+             qf1_pi, qf2_pi = self.critic(state_batch, pi)
+             min_qf_pi = torch.min(qf1_pi, qf2_pi)
+
+             policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
+
+             self.actor_optimizer.zero_grad()
+             policy_loss.backward()
+             self.actor_optimizer.step()
+
+             if self.automatic_entropy_tuning:
+                 alpha_loss = -(
+                     self.log_alpha * (log_pi + self.target_entropy).detach()
+                 ).mean()
+
+                 self.alpha_optimizer.zero_grad()
+                 alpha_loss.backward()
+                 self.alpha_optimizer.step()
+
+                 self.alpha = self.log_alpha.exp()
+                 alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
+             else:
+                 alpha_loss = torch.tensor(0.0).to(self.device)
+                 alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs
+
+             soft_update(self.critic_target, self.critic, Config().algorithm.tau)
+
+             self.total_it += 1
+
+         return (
+             qf1_loss.item(),
+             qf2_loss.item(),
+             policy_loss.item(),
+             alpha_loss.item(),
+             alpha_tlogs.item(),
+         )
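
Note: when automatic_entropy_tuning is enabled, Policy.update() above also adjusts the entropy temperature alpha by gradient descent on log_alpha. The sketch below isolates that step with made-up log-probabilities and target entropy; in the code above, target_entropy is -|action space| and log_pi comes from the sampled batch.

import torch

log_alpha = torch.zeros(1, requires_grad=True)  # alpha = exp(log_alpha) = 1.0 initially
target_entropy = -2.0                           # e.g. -action_dim for a 2-D action space
log_pi = torch.tensor([[-1.5], [-2.5]])         # log-probs of actions sampled from the actor

# alpha grows when the policy entropy (-log_pi) falls below target_entropy,
# and shrinks when it stays above it.
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_loss.backward()
print(log_alpha.grad)  # tensor([4.]) -- mean of (log_pi + target_entropy) is -4.0, negated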