plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/utils/reinforcement_learning/policies/td3.py
@@ -0,0 +1,485 @@
+ """
+ Reference:
+
+ https://github.com/AntoineTheb/RNN-RL
+ """
+
+ import copy
+ import random
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from plato.config import Config
+ from plato.utils.reinforcement_learning.policies import base
+ from torch import nn
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
+
+
+ class RNNReplayMemory:
+     def __init__(self, state_dim, action_dim, hidden_size, capacity, seed):
+         random.seed(seed)
+         self.device = Config().device()
+         self.capacity = int(capacity)
+         self.ptr = 0
+         self.size = 0
+
+         self.h = np.zeros((self.capacity, hidden_size))
+         self.nh = np.zeros((self.capacity, hidden_size))
+         self.c = np.zeros((self.capacity, hidden_size))
+         self.nc = np.zeros((self.capacity, hidden_size))
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             self.state = [0] * self.capacity
+             self.action = [0] * self.capacity
+             self.reward = [0] * self.capacity
+             self.next_state = [0] * self.capacity
+             self.done = [0] * self.capacity
+         else:
+             self.state = np.zeros((self.capacity, state_dim))
+             self.action = np.zeros((self.capacity, action_dim))
+             self.reward = np.zeros((self.capacity, 1))
+             self.next_state = np.zeros((self.capacity, state_dim))
+             self.done = np.zeros((self.capacity, 1))
+
+     def push(self, data):
+         self.state[self.ptr] = data[0]
+         self.action[self.ptr] = data[1]
+         self.reward[self.ptr] = data[2]
+         self.next_state[self.ptr] = data[3]
+         self.done[self.ptr] = data[4]
+
+         self.h[self.ptr] = data[5].detach().cpu()
+         self.c[self.ptr] = data[6].detach().cpu()
+         self.nh[self.ptr] = data[7].detach().cpu()
+         self.nc[self.ptr] = data[8].detach().cpu()
+
+         self.ptr = (self.ptr + 1) % self.capacity
+         self.size = min(self.size + 1, self.capacity)
+
+     def sample(self):
+         ind = np.random.randint(0, self.size, size=int(Config().algorithm.batch_size))
+
+         h = torch.tensor(
+             self.h[ind][None, ...], requires_grad=True, dtype=torch.float
+         ).to(self.device)
+         c = torch.tensor(
+             self.c[ind][None, ...], requires_grad=True, dtype=torch.float
+         ).to(self.device)
+         nh = torch.tensor(
+             self.nh[ind][None, ...], requires_grad=True, dtype=torch.float
+         ).to(self.device)
+         nc = torch.tensor(
+             self.nc[ind][None, ...], requires_grad=True, dtype=torch.float
+         ).to(self.device)
+
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             state = [torch.FloatTensor(self.state[i]).to(self.device) for i in ind]
+             action = [torch.FloatTensor(self.action[i]).to(self.device) for i in ind]
+             reward = [self.reward[i] for i in ind]
+             next_state = [
+                 torch.FloatTensor(self.next_state[i]).to(self.device) for i in ind
+             ]
+             done = [self.done[i] for i in ind]
+         else:
+             state = torch.FloatTensor(self.state[ind][:, None, :]).to(self.device)
+
+             action = torch.FloatTensor(self.action[ind][:, None, :]).to(self.device)
+             reward = torch.FloatTensor(self.reward[ind][:, None, :]).to(self.device)
+             next_state = torch.FloatTensor(self.next_state[ind][:, None, :]).to(
+                 self.device
+             )
+             done = torch.FloatTensor(self.done[ind][:, None, :]).to(self.device)
+
+         return state, action, reward, next_state, done, h, c, nh, nc
+
+     def __len__(self):
+         return self.size
+
+
+ class TD3Actor(base.Actor):
+     def __init__(self, state_dim, action_dim, max_action):
+         super().__init__(state_dim, action_dim, max_action)
+
+     def forward(self, x, hidden=None):
+         x = F.relu(self.l1(x))
+         x = F.relu(self.l2(x))
+         x = self.max_action * torch.tanh(self.l3(x))
+         # Normalize/Scaling aggregation weights so that the sum is 1
+         x += 1  # [-1, 1] -> [0, 2]
+         x /= x.sum()
+         return x
+
+
+ class TD3Critic(nn.Module):
+     def __init__(self, state_dim, action_dim):
+         super(TD3Critic, self).__init__()
+
+         # Q1 architecture
+         self.l1 = nn.Linear(state_dim + action_dim, 400)
+         self.l2 = nn.Linear(400, 300)
+         self.l3 = nn.Linear(300, 1)
+
+         # Q2 architecture
+         self.l4 = nn.Linear(state_dim + action_dim, 400)
+         self.l5 = nn.Linear(400, 300)
+         self.l6 = nn.Linear(300, 1)
+
+     def forward(self, state, action, hidden1=None, hidden2=None):
+         sa = torch.cat([state, action], 1)
+         q1 = F.relu(self.l1(sa))
+         q1 = F.relu(self.l2(q1))
+         q1 = self.l3(q1)
+         q2 = F.relu(self.l4(sa))
+         q2 = F.relu(self.l5(q2))
+         q2 = self.l6(q2)
+         return q1, q2
+
+     def Q1(self, state, action, hidden=None):
+         sa = torch.cat([state, action], 1)
+         q1 = F.relu(self.l1(sa))
+         q1 = F.relu(self.l2(q1))
+         q1 = self.l3(q1)
+         return q1
+
+
+ class RNNActor(nn.Module):
+     def __init__(self, state_dim, action_dim, hidden_size, max_action):
+         super(RNNActor, self).__init__()
+         self.action_dim = action_dim
+         self.max_action = max_action
+
+         self.l1 = nn.LSTM(state_dim, hidden_size, batch_first=True)
+         self.l2 = nn.Linear(hidden_size, hidden_size)
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             self.l3 = nn.Linear(hidden_size, 1)
+         else:
+             self.l3 = nn.Linear(hidden_size, action_dim)
+
+     def forward(self, state, hidden=None):
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             # Pad the first state to full dims
+             if len(state) == 1:
+                 pilot = state
+             else:
+                 pilot = state[0]
+             pilot = F.pad(
+                 input=pilot,
+                 pad=(0, 0, 0, self.action_dim - pilot.shape[-2]),
+                 mode="constant",
+                 value=0,
+             )
+             if len(state) == 1:
+                 state = pilot
+             else:
+                 state[0] = pilot
+             # Pad variable states
+             # Get the length explicitly for later packing sequences
+             lens = list(map(len, state))
+             if len(state) == 1:
+                 state = [torch.squeeze(state)]
+             # Pad and pack
+             padded = pad_sequence(state, batch_first=True)
+             state = pack_padded_sequence(
+                 padded, lengths=lens, batch_first=True, enforce_sorted=False
+             )
+         self.l1.flatten_parameters()
+         a, h = self.l1(state, hidden)
+
+         # mini-batch update
+         if (
+             hasattr(Config().server, "synchronous")
+             and not Config().server.synchronous
+             and len(state) != 1
+         ):
+             a, _ = pad_packed_sequence(a, batch_first=True)
+
+         a = F.relu(self.l2(a))
+         a = self.max_action * torch.tanh(self.l3(a))
+
+         # Normalize/Scaling aggregation weights so that the sum is 1
+         a += 1  # [-1, 1] -> [0, 2]
+         a /= a.sum()
+
+         return a, h
+
+
+ class RNNCritic(nn.Module):
+     def __init__(self, state_dim, action_dim, hidden_size):
+         super(RNNCritic, self).__init__()
+         self.action_dim = action_dim
+
+         # Q1 architecture
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             self.l1 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
+         else:
+             self.l1 = nn.LSTM(state_dim + action_dim, hidden_size, batch_first=True)
+         self.l2 = nn.Linear(hidden_size, hidden_size)
+         self.l3 = nn.Linear(hidden_size, 1)
+
+         # Q2 architecture
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             self.l4 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
+         else:
+             self.l4 = nn.LSTM(state_dim + action_dim, hidden_size, batch_first=True)
+         self.l5 = nn.Linear(hidden_size, hidden_size)
+         self.l6 = nn.Linear(hidden_size, 1)
+
+     def forward(self, state, action, hidden1, hidden2):
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             # Pad the first state to full dims
+             if len(state) == 1:
+                 pilot = state
+             else:
+                 pilot = state[0]
+             pilot = F.pad(
+                 input=pilot,
+                 pad=(0, 0, 0, self.action_dim - pilot.shape[-2]),
+                 mode="constant",
+                 value=0,
+             )
+             if len(state) == 1:
+                 state = pilot
+             else:
+                 state[0] = pilot
+             # Pad variable states
+             # Get the length explicitly for later packing sequences
+             lens = list(map(len, state))
+             if len(state) == 1:
+                 state = [torch.squeeze(state)]
+             # Pad and pack
+             padded = pad_sequence(state, batch_first=True)
+             state = padded
+         sa = torch.cat([state, action], -1)
+
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             sa = pack_padded_sequence(
+                 sa, lengths=lens, batch_first=True, enforce_sorted=False
+             )
+         self.l1.flatten_parameters()
+         self.l4.flatten_parameters()
+         q1, hidden1 = self.l1(sa, hidden1)
+         q2, hidden2 = self.l4(sa, hidden2)
+
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             q1, _ = pad_packed_sequence(q1, batch_first=True)
+             q2, _ = pad_packed_sequence(q2, batch_first=True)
+
+         q1 = F.relu(self.l2(q1))
+         q1 = self.l3(q1)
+         q1 = torch.mean(q1.reshape(q1.shape[0], -1, 1), 1)
+
+         q2 = F.relu(self.l5(q2))
+         q2 = self.l6(q2)
+         q2 = torch.mean(q2.reshape(q2.shape[0], -1, 1), 1)
+
+         return q1, q2
+
+     def Q1(self, state, action, hidden1):
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             # Pad variable states
+             # Get the length explicitly for later packing sequences
+             lens = list(map(len, state))
+             # Pad and pack
+             padded = pad_sequence(state, batch_first=True)
+             state = padded
+
+         sa = torch.cat([state, action], -1)
+
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             sa = pack_padded_sequence(
+                 sa, lengths=lens, batch_first=True, enforce_sorted=False
+             )
+         self.l1.flatten_parameters()
+         q1, hidden1 = self.l1(sa, hidden1)
+
+         if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
+             q1, _ = pad_packed_sequence(q1, batch_first=True)
+
+         q1 = F.relu(self.l2(q1))
+         q1 = self.l3(q1)
+         q1 = torch.mean(q1.reshape(q1.shape[0], -1, 1), 1)
+
+         return q1
+
+
+ class Policy(base.Policy):
+     def __init__(self, state_dim, action_dim):
+         super().__init__(state_dim, action_dim)
+
+         # Initialize NNs
+         if Config().algorithm.recurrent_actor:
+             self.actor = RNNActor(
+                 state_dim, action_dim, Config().algorithm.hidden_size, self.max_action
+             ).to(self.device)
+             self.critic = RNNCritic(
+                 state_dim, action_dim, Config().algorithm.hidden_size
+             ).to(self.device)
+         else:
+             self.actor = TD3Actor(state_dim, action_dim, self.max_action).to(
+                 self.device
+             )
+             self.critic = TD3Critic(state_dim, action_dim).to(self.device)
+
+         self.actor_target = copy.deepcopy(self.actor)
+         self.actor_optimizer = torch.optim.Adam(
+             self.actor.parameters(), lr=Config().algorithm.learning_rate
+         )
+
+         self.critic_target = copy.deepcopy(self.critic)
+         self.critic_optimizer = torch.optim.Adam(
+             self.critic.parameters(), lr=Config().algorithm.learning_rate
+         )
+
+         # Initialize replay memory
+         if Config().algorithm.recurrent_actor:
+             self.replay_buffer = RNNReplayMemory(
+                 state_dim,
+                 action_dim,
+                 Config().algorithm.hidden_size,
+                 Config().algorithm.replay_size,
+                 Config().algorithm.replay_seed,
+             )
+
+         else:
+             self.replay_buffer = base.ReplayMemory(
+                 state_dim,
+                 action_dim,
+                 Config().algorithm.replay_size,
+                 Config().algorithm.replay_seed,
+             )
+
+         self.policy_noise = Config().algorithm.policy_noise * self.max_action
+         self.noise_clip = Config().algorithm.noise_clip * self.max_action
+
+     def get_initial_states(self):
+         h_0, c_0 = None, None
+         if Config().algorithm.recurrent_actor:
+             h_0 = torch.zeros(
+                 (self.actor.l1.num_layers, 1, self.actor.l1.hidden_size),
+                 dtype=torch.float,
+             )
+             # h_0 = h_0.to(self.device)
+
+             c_0 = torch.zeros(
+                 (self.actor.l1.num_layers, 1, self.actor.l1.hidden_size),
+                 dtype=torch.float,
+             )
+             # c_0 = c_0.to(self.device)
+         return (h_0, c_0)
+
+     def select_action(self, state, hidden=None, test=False):
+         """Select action from policy."""
+         if Config().algorithm.recurrent_actor:
+             if (
+                 hasattr(Config().server, "synchronous")
+                 and not Config().server.synchronous
+             ):
+                 state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
+             else:
+                 state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)[
+                     :, None, :
+                 ]
+                 # state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
+             action, hidden = self.actor(state, hidden)
+             return action.cpu().data.numpy().flatten(), hidden
+         else:
+             state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
+             action = self.actor(state)
+             return action.cpu().data.numpy().flatten()
+
+     def update(self):
+         """Update policy."""
+         self.total_it += 1
+
+         # Sample replay buffer
+         if Config().algorithm.recurrent_actor:
+             state, action, reward, next_state, done, h, c, nh, nc = (
+                 self.replay_buffer.sample()
+             )
+             if (
+                 hasattr(Config().server, "synchronous")
+                 and not Config().server.synchronous
+             ):
+                 # Pad variable actions
+                 padded = pad_sequence(action, batch_first=True)
+                 action = padded
+                 reward = torch.FloatTensor(reward).to(self.device).unsqueeze(1)
+                 done = torch.FloatTensor(done).to(self.device).unsqueeze(1)
+             hidden = (h, c)
+             next_hidden = (nh, nc)
+         else:
+             state, action, reward, next_state, done = self.replay_buffer.sample()
+             state = torch.FloatTensor(state).to(self.device)
+             action = torch.FloatTensor(action).to(self.device)
+             reward = torch.FloatTensor(reward).to(self.device)
+             next_state = torch.FloatTensor(next_state).to(self.device)
+             done = torch.FloatTensor(done).to(self.device)
+             hidden, next_hidden = (None, None), (None, None)
+
+         with torch.no_grad():
+             # Select action according to policy and add clipped noise
+             noise = (torch.randn_like(action) * self.policy_noise).clamp(
+                 -self.noise_clip, self.noise_clip
+             )
+
+             next_action = (self.actor_target(next_state, next_hidden)[0] + noise).clamp(
+                 -self.max_action, self.max_action
+             )
+
+             # Compute the target Q value
+             target_Q1, target_Q2 = self.critic_target(
+                 next_state, next_action, next_hidden, next_hidden
+             )
+             target_Q = torch.min(target_Q1, target_Q2)
+             target_Q = reward + (1 - done) * Config().algorithm.gamma * target_Q
+
+         # Get current Q estimates
+         current_Q1, current_Q2 = self.critic(state, action, hidden, hidden)
+
+         # Compute critic loss
+         critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
+             current_Q2, target_Q
+         )
+
+         # Optimize the critic
+         self.critic_optimizer.zero_grad()
+         critic_loss.backward()
+         self.critic_optimizer.step()
+
+         actor_loss = critic_loss
+
+         # Delayed policy updates
+         if self.total_it % Config().algorithm.policy_freq == 0:
+             # Compute actor loss
+             if Config().algorithm.recurrent_actor:
+                 actor_loss = -self.critic.Q1(
+                     state, self.actor(state, hidden)[0], hidden
+                 ).mean()
+             else:
+                 actor_loss = -self.critic.Q1(
+                     state, self.actor(state, hidden), hidden
+                 ).mean()
+
+             # Optimize the actor
+             self.actor_optimizer.zero_grad()
+             actor_loss.backward()
+             self.actor_optimizer.step()
+
+             # Update the frozen target models
+             for param, target_param in zip(
+                 self.critic.parameters(), self.critic_target.parameters()
+             ):
+                 target_param.data.copy_(
+                     Config().algorithm.tau * param.data
+                     + (1 - Config().algorithm.tau) * target_param.data
+                 )
+
+             for param, target_param in zip(
+                 self.actor.parameters(), self.actor_target.parameters()
+             ):
+                 target_param.data.copy_(
+                     Config().algorithm.tau * param.data
+                     + (1 - Config().algorithm.tau) * target_param.data
+                 )
+
+         return critic_loss.item(), actor_loss.item()
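
Note (illustrative, not part of the wheel): the Policy class in td3.py is configured entirely through Config() and exposes select_action(), a replay buffer, and update(). The sketch below shows one plausible way to drive the recurrent variant, assuming a Plato configuration defining the algorithm.* fields referenced above (recurrent_actor, hidden_size, replay_size, replay_seed, batch_size, and so on) has already been loaded; the dimensions, the flattening of the hidden states before pushing, and env_step are hypothetical stand-ins rather than anything defined by the package.

# Illustrative sketch only: drives the recurrent TD3 policy defined in td3.py.
import numpy as np

from plato.config import Config
from plato.utils.reinforcement_learning.policies import td3

state_dim, action_dim = 80, 8  # hypothetical dimensions


def env_step(action):
    """Hypothetical environment stand-in returning a random transition."""
    return np.random.randn(state_dim), float(np.random.rand()), False


policy = td3.Policy(state_dim, action_dim)
state = np.random.randn(state_dim)
hidden = policy.get_initial_states()

for _ in range(1000):
    # In recurrent mode, select_action() also returns the next LSTM hidden state.
    action, next_hidden = policy.select_action(state, hidden)
    next_state, reward, done = env_step(action)

    # RNNReplayMemory.push() expects (s, a, r, s', done, h, c, nh, nc);
    # the hidden tensors are flattened here to fit its per-row storage.
    h, c = hidden
    nh, nc = next_hidden
    policy.replay_buffer.push(
        (state, action, reward, next_state, done,
         h.flatten(), c.flatten(), nh.flatten(), nc.flatten())
    )
    if len(policy.replay_buffer) >= Config().algorithm.batch_size:
        critic_loss, actor_loss = policy.update()

    state, hidden = next_state, next_hidden
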
plato/utils/reinforcement_learning/rl_agent.py
@@ -0,0 +1,142 @@
+ """
+ A basic RL environment for FL server using Gym for RL control.
+ """
+
+ import asyncio
+ import logging
+ from abc import abstractmethod
+
+ import numpy as np
+ from gym import spaces
+ from plato.config import Config
+
+
+ class RLAgent(object):
+     """A basic RL environment for the FL server, using Gym for RL control."""
+
+     def __init__(self):
+         self.n_actions = Config().clients.per_round
+         self.n_states = Config().clients.per_round * Config().algorithm.n_features
+
+         if Config().algorithm.discrete_action_space:
+             self.action_space = spaces.Discrete(self.n_actions)
+         else:
+             self.action_space = spaces.Box(
+                 low=int(Config().algorithm.min_action),
+                 high=Config().algorithm.max_action,
+                 shape=(self.n_actions,),
+                 dtype=np.float32,
+             )
+
+         self.observation_space = spaces.Box(
+             low=-np.inf, high=np.inf, shape=(self.n_states,), dtype=np.float32
+         )
+
+         self.state = None
+         self.next_state = None
+         self.new_state = None
+         self.action = None
+         self.next_action = None
+         self.reward = 0
+         self.episode_reward = 0
+         self.current_step = 0
+         self.total_steps = 0
+         self.current_episode = 0
+         self.is_done = False
+         self.reset_env = False
+         self.finished = False
+
+         # RL server waits for the event that the next action is updated
+         self.action_updated = asyncio.Event()
+
+     def step(self):
+         """Update the followings using server update."""
+         self.new_state = self.get_state()
+         self.is_done = self.get_done()
+         self.reward = self.get_reward()
+         info = self.get_info()
+
+         return self.new_state, self.reward, self.is_done, info
+
+     async def reset(self):
+         """Reset RL environment."""
+         # Start a new training session
+         logging.info("[RL Agent] Reseting RL environment.")
+
+         # Reset the episode-related variables
+         self.current_step = 0
+         self.is_done = False
+         self.episode_reward = 0
+         self.current_episode += 1
+         self.reset_env = True
+         logging.info("[RL Agent] Starting RL episode #%d.", self.current_episode)
+
+     def prep_action(self):
+         """Get action from RL policy."""
+         logging.info("[RL Agent] Selecting action...")
+         self.action = self.policy.select_action(self.state)
+
+     def get_state(self):
+         """Get state for agent."""
+         return self.new_state
+
+     def get_reward(self):
+         """Get reward for agent."""
+         return 0.0
+
+     def get_done(self):
+         """Get done condition for agent."""
+         if (
+             Config().algorithm.mode == "train"
+             and self.current_step >= Config().algorithm.steps_per_episode
+         ):
+             logging.info("[RL Agent] Episode #%d ended.", self.current_episode)
+             return True
+         return False
+
+     def get_info(self):
+         """Get info used for benchmarking."""
+         return {}
+
+     def process_env_update(self):
+         """Process state update to RL Agent."""
+         if self.current_step == 0:
+             self.state = self.get_state()
+         else:
+             self.next_state, self.reward, self.is_done, __ = self.step()
+             if Config().algorithm.mode == "train":
+                 self.process_experience()
+             self.state = self.next_state
+             self.episode_reward += self.reward
+
+     async def prep_agent_update(self):
+         """Update RL Agent."""
+         self.current_step += 1
+         self.total_steps += 1
+         logging.info("[RL Agent] Preparing action...")
+         self.prep_action()
+         self.action_updated.set()
+
+         # when episode ends
+         if Config().algorithm.mode == "train" and self.is_done:
+             self.update_policy()
+
+             # Break the loop when RL training is concluded
+             if self.current_episode >= Config().algorithm.max_episode:
+                 self.finished = True
+             else:
+                 await self.reset()
+         elif (
+             Config().algorithm.mode == "test"
+             and self.current_step >= Config().algorithm.test_step
+         ):
+             # Break the loop when RL testing is concluded
+             self.finished = True
+
+     @abstractmethod
+     def update_policy(self):
+         """Update policy if needed in training mode."""
+
+     @abstractmethod
+     def process_experience(self):
+         """Process step experience if needed in training mode."""