plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/utils/reinforcement_learning/policies/td3.py
@@ -0,0 +1,485 @@
"""
Reference:

https://github.com/AntoineTheb/RNN-RL
"""

import copy
import random

import numpy as np
import torch
import torch.nn.functional as F
from plato.config import Config
from plato.utils.reinforcement_learning.policies import base
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence


class RNNReplayMemory:
    """A replay buffer that also stores the LSTM hidden states of each transition."""

    def __init__(self, state_dim, action_dim, hidden_size, capacity, seed):
        random.seed(seed)
        self.device = Config().device()
        self.capacity = int(capacity)
        self.ptr = 0
        self.size = 0

        self.h = np.zeros((self.capacity, hidden_size))
        self.nh = np.zeros((self.capacity, hidden_size))
        self.c = np.zeros((self.capacity, hidden_size))
        self.nc = np.zeros((self.capacity, hidden_size))
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            # Asynchronous mode: states and actions vary in length, so store them as lists
            self.state = [0] * self.capacity
            self.action = [0] * self.capacity
            self.reward = [0] * self.capacity
            self.next_state = [0] * self.capacity
            self.done = [0] * self.capacity
        else:
            self.state = np.zeros((self.capacity, state_dim))
            self.action = np.zeros((self.capacity, action_dim))
            self.reward = np.zeros((self.capacity, 1))
            self.next_state = np.zeros((self.capacity, state_dim))
            self.done = np.zeros((self.capacity, 1))

    def push(self, data):
        """Store a transition (state, action, reward, next_state, done, h, c, nh, nc)."""
        self.state[self.ptr] = data[0]
        self.action[self.ptr] = data[1]
        self.reward[self.ptr] = data[2]
        self.next_state[self.ptr] = data[3]
        self.done[self.ptr] = data[4]

        self.h[self.ptr] = data[5].detach().cpu()
        self.c[self.ptr] = data[6].detach().cpu()
        self.nh[self.ptr] = data[7].detach().cpu()
        self.nc[self.ptr] = data[8].detach().cpu()

        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self):
        """Sample a mini-batch of transitions together with their hidden states."""
        ind = np.random.randint(0, self.size, size=int(Config().algorithm.batch_size))

        h = torch.tensor(
            self.h[ind][None, ...], requires_grad=True, dtype=torch.float
        ).to(self.device)
        c = torch.tensor(
            self.c[ind][None, ...], requires_grad=True, dtype=torch.float
        ).to(self.device)
        nh = torch.tensor(
            self.nh[ind][None, ...], requires_grad=True, dtype=torch.float
        ).to(self.device)
        nc = torch.tensor(
            self.nc[ind][None, ...], requires_grad=True, dtype=torch.float
        ).to(self.device)

        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            state = [torch.FloatTensor(self.state[i]).to(self.device) for i in ind]
            action = [torch.FloatTensor(self.action[i]).to(self.device) for i in ind]
            reward = [self.reward[i] for i in ind]
            next_state = [
                torch.FloatTensor(self.next_state[i]).to(self.device) for i in ind
            ]
            done = [self.done[i] for i in ind]
        else:
            state = torch.FloatTensor(self.state[ind][:, None, :]).to(self.device)
            action = torch.FloatTensor(self.action[ind][:, None, :]).to(self.device)
            reward = torch.FloatTensor(self.reward[ind][:, None, :]).to(self.device)
            next_state = torch.FloatTensor(self.next_state[ind][:, None, :]).to(
                self.device
            )
            done = torch.FloatTensor(self.done[ind][:, None, :]).to(self.device)

        return state, action, reward, next_state, done, h, c, nh, nc

    def __len__(self):
        return self.size


class TD3Actor(base.Actor):
    """A feed-forward TD3 actor whose output is normalized into aggregation weights."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__(state_dim, action_dim, max_action)

    def forward(self, x, hidden=None):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.max_action * torch.tanh(self.l3(x))
        # Normalize/scale the aggregation weights so that they sum to 1
        x += 1  # [-1, 1] -> [0, 2]
        x /= x.sum()
        return x


class TD3Critic(nn.Module):
    """A feed-forward twin-Q critic for TD3."""

    def __init__(self, state_dim, action_dim):
        super(TD3Critic, self).__init__()

        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

    def forward(self, state, action, hidden1=None, hidden2=None):
        sa = torch.cat([state, action], 1)
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def Q1(self, state, action, hidden=None):
        sa = torch.cat([state, action], 1)
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1


class RNNActor(nn.Module):
    """An LSTM-based actor that supports variable-length state sequences."""

    def __init__(self, state_dim, action_dim, hidden_size, max_action):
        super(RNNActor, self).__init__()
        self.action_dim = action_dim
        self.max_action = max_action

        self.l1 = nn.LSTM(state_dim, hidden_size, batch_first=True)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            self.l3 = nn.Linear(hidden_size, 1)
        else:
            self.l3 = nn.Linear(hidden_size, action_dim)

    def forward(self, state, hidden=None):
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            # Pad the first state to full dims
            if len(state) == 1:
                pilot = state
            else:
                pilot = state[0]
            pilot = F.pad(
                input=pilot,
                pad=(0, 0, 0, self.action_dim - pilot.shape[-2]),
                mode="constant",
                value=0,
            )
            if len(state) == 1:
                state = pilot
            else:
                state[0] = pilot
            # Pad variable states
            # Get the lengths explicitly for later packing sequences
            lens = list(map(len, state))
            if len(state) == 1:
                state = [torch.squeeze(state)]
            # Pad and pack
            padded = pad_sequence(state, batch_first=True)
            state = pack_padded_sequence(
                padded, lengths=lens, batch_first=True, enforce_sorted=False
            )
        self.l1.flatten_parameters()
        a, h = self.l1(state, hidden)

        # mini-batch update
        if (
            hasattr(Config().server, "synchronous")
            and not Config().server.synchronous
            and len(state) != 1
        ):
            a, _ = pad_packed_sequence(a, batch_first=True)

        a = F.relu(self.l2(a))
        a = self.max_action * torch.tanh(self.l3(a))

        # Normalize/scale the aggregation weights so that they sum to 1
        a += 1  # [-1, 1] -> [0, 2]
        a /= a.sum()

        return a, h


class RNNCritic(nn.Module):
    """An LSTM-based twin-Q critic that supports variable-length state sequences."""

    def __init__(self, state_dim, action_dim, hidden_size):
        super(RNNCritic, self).__init__()
        self.action_dim = action_dim

        # Q1 architecture
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            self.l1 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
        else:
            self.l1 = nn.LSTM(state_dim + action_dim, hidden_size, batch_first=True)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, 1)

        # Q2 architecture
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            self.l4 = nn.LSTM(state_dim + 1, hidden_size, batch_first=True)
        else:
            self.l4 = nn.LSTM(state_dim + action_dim, hidden_size, batch_first=True)
        self.l5 = nn.Linear(hidden_size, hidden_size)
        self.l6 = nn.Linear(hidden_size, 1)

    def forward(self, state, action, hidden1, hidden2):
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            # Pad the first state to full dims
            if len(state) == 1:
                pilot = state
            else:
                pilot = state[0]
            pilot = F.pad(
                input=pilot,
                pad=(0, 0, 0, self.action_dim - pilot.shape[-2]),
                mode="constant",
                value=0,
            )
            if len(state) == 1:
                state = pilot
            else:
                state[0] = pilot
            # Pad variable states
            # Get the lengths explicitly for later packing sequences
            lens = list(map(len, state))
            if len(state) == 1:
                state = [torch.squeeze(state)]
            # Pad and pack
            padded = pad_sequence(state, batch_first=True)
            state = padded
        sa = torch.cat([state, action], -1)

        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            sa = pack_padded_sequence(
                sa, lengths=lens, batch_first=True, enforce_sorted=False
            )
        self.l1.flatten_parameters()
        self.l4.flatten_parameters()
        q1, hidden1 = self.l1(sa, hidden1)
        q2, hidden2 = self.l4(sa, hidden2)

        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            q1, _ = pad_packed_sequence(q1, batch_first=True)
            q2, _ = pad_packed_sequence(q2, batch_first=True)

        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        q1 = torch.mean(q1.reshape(q1.shape[0], -1, 1), 1)

        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        q2 = torch.mean(q2.reshape(q2.shape[0], -1, 1), 1)

        return q1, q2

    def Q1(self, state, action, hidden1):
        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            # Pad variable states
            # Get the lengths explicitly for later packing sequences
            lens = list(map(len, state))
            # Pad and pack
            padded = pad_sequence(state, batch_first=True)
            state = padded

        sa = torch.cat([state, action], -1)

        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            sa = pack_padded_sequence(
                sa, lengths=lens, batch_first=True, enforce_sorted=False
            )
        self.l1.flatten_parameters()
        q1, hidden1 = self.l1(sa, hidden1)

        if hasattr(Config().server, "synchronous") and not Config().server.synchronous:
            q1, _ = pad_packed_sequence(q1, batch_first=True)

        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        q1 = torch.mean(q1.reshape(q1.shape[0], -1, 1), 1)

        return q1


class Policy(base.Policy):
    """A TD3 policy with an optional recurrent actor-critic and replay memory."""

    def __init__(self, state_dim, action_dim):
        super().__init__(state_dim, action_dim)

        # Initialize NNs
        if Config().algorithm.recurrent_actor:
            self.actor = RNNActor(
                state_dim, action_dim, Config().algorithm.hidden_size, self.max_action
            ).to(self.device)
            self.critic = RNNCritic(
                state_dim, action_dim, Config().algorithm.hidden_size
            ).to(self.device)
        else:
            self.actor = TD3Actor(state_dim, action_dim, self.max_action).to(
                self.device
            )
            self.critic = TD3Critic(state_dim, action_dim).to(self.device)

        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=Config().algorithm.learning_rate
        )

        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=Config().algorithm.learning_rate
        )

        # Initialize replay memory
        if Config().algorithm.recurrent_actor:
            self.replay_buffer = RNNReplayMemory(
                state_dim,
                action_dim,
                Config().algorithm.hidden_size,
                Config().algorithm.replay_size,
                Config().algorithm.replay_seed,
            )
        else:
            self.replay_buffer = base.ReplayMemory(
                state_dim,
                action_dim,
                Config().algorithm.replay_size,
                Config().algorithm.replay_seed,
            )

        self.policy_noise = Config().algorithm.policy_noise * self.max_action
        self.noise_clip = Config().algorithm.noise_clip * self.max_action

    def get_initial_states(self):
        """Return zero-initialized hidden states for the recurrent actor."""
        h_0, c_0 = None, None
        if Config().algorithm.recurrent_actor:
            h_0 = torch.zeros(
                (self.actor.l1.num_layers, 1, self.actor.l1.hidden_size),
                dtype=torch.float,
            )
            # h_0 = h_0.to(self.device)

            c_0 = torch.zeros(
                (self.actor.l1.num_layers, 1, self.actor.l1.hidden_size),
                dtype=torch.float,
            )
            # c_0 = c_0.to(self.device)
        return (h_0, c_0)

    def select_action(self, state, hidden=None, test=False):
        """Select action from policy."""
        if Config().algorithm.recurrent_actor:
            if (
                hasattr(Config().server, "synchronous")
                and not Config().server.synchronous
            ):
                state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            else:
                state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)[
                    :, None, :
                ]
                # state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            action, hidden = self.actor(state, hidden)
            return action.cpu().data.numpy().flatten(), hidden
        else:
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action = self.actor(state)
            return action.cpu().data.numpy().flatten()

    def update(self):
        """Update policy."""
        self.total_it += 1

        # Sample replay buffer
        if Config().algorithm.recurrent_actor:
            state, action, reward, next_state, done, h, c, nh, nc = (
                self.replay_buffer.sample()
            )
            if (
                hasattr(Config().server, "synchronous")
                and not Config().server.synchronous
            ):
                # Pad variable actions
                padded = pad_sequence(action, batch_first=True)
                action = padded
                reward = torch.FloatTensor(reward).to(self.device).unsqueeze(1)
                done = torch.FloatTensor(done).to(self.device).unsqueeze(1)
            hidden = (h, c)
            next_hidden = (nh, nc)
        else:
            state, action, reward, next_state, done = self.replay_buffer.sample()
            state = torch.FloatTensor(state).to(self.device)
            action = torch.FloatTensor(action).to(self.device)
            reward = torch.FloatTensor(reward).to(self.device)
            next_state = torch.FloatTensor(next_state).to(self.device)
            done = torch.FloatTensor(done).to(self.device)
            hidden, next_hidden = (None, None), (None, None)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )

            next_action = (self.actor_target(next_state, next_hidden)[0] + noise).clamp(
                -self.max_action, self.max_action
            )

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(
                next_state, next_action, next_hidden, next_hidden
            )
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * Config().algorithm.gamma * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action, hidden, hidden)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q
        )

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        actor_loss = critic_loss

        # Delayed policy updates
        if self.total_it % Config().algorithm.policy_freq == 0:
            # Compute actor loss
            if Config().algorithm.recurrent_actor:
                actor_loss = -self.critic.Q1(
                    state, self.actor(state, hidden)[0], hidden
                ).mean()
            else:
                actor_loss = -self.critic.Q1(
                    state, self.actor(state, hidden), hidden
                ).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(
                self.critic.parameters(), self.critic_target.parameters()
            ):
                target_param.data.copy_(
                    Config().algorithm.tau * param.data
                    + (1 - Config().algorithm.tau) * target_param.data
                )

            for param, target_param in zip(
                self.actor.parameters(), self.actor_target.parameters()
            ):
                target_param.data.copy_(
                    Config().algorithm.tau * param.data
                    + (1 - Config().algorithm.tau) * target_param.data
                )

        return critic_loss.item(), actor_loss.item()
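The recurrent actor and critic above rely on torch.nn.utils.rnn to batch episodes whose states have different lengths in asynchronous mode. The short, self-contained sketch below is not part of the package and uses made-up dimensions; it only illustrates the pad_sequence / pack_padded_sequence / pad_packed_sequence round trip that RNNActor.forward and RNNCritic.forward apply before and after their LSTM layers.

# Illustrative sketch only: the pad-and-pack pattern used by the RNN actor/critic.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# Three "episodes" with different numbers of per-client state vectors.
states = [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)]

lens = list(map(len, states))                    # keep the original lengths
padded = pad_sequence(states, batch_first=True)  # shape: (batch, max_len, 4)
packed = pack_padded_sequence(
    padded, lengths=lens, batch_first=True, enforce_sorted=False
)

out, (h, c) = lstm(packed)
out, _ = pad_packed_sequence(out, batch_first=True)  # back to (batch, max_len, 8)
print(out.shape, h.shape)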
plato/utils/reinforcement_learning/rl_agent.py
@@ -0,0 +1,142 @@
"""
A basic RL environment for the FL server, using Gym for RL control.
"""

import asyncio
import logging
from abc import abstractmethod

import numpy as np
from gym import spaces
from plato.config import Config


class RLAgent(object):
    """A basic RL environment for the FL server, using Gym for RL control."""

    def __init__(self):
        self.n_actions = Config().clients.per_round
        self.n_states = Config().clients.per_round * Config().algorithm.n_features

        if Config().algorithm.discrete_action_space:
            self.action_space = spaces.Discrete(self.n_actions)
        else:
            self.action_space = spaces.Box(
                low=int(Config().algorithm.min_action),
                high=Config().algorithm.max_action,
                shape=(self.n_actions,),
                dtype=np.float32,
            )

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.n_states,), dtype=np.float32
        )

        self.state = None
        self.next_state = None
        self.new_state = None
        self.action = None
        self.next_action = None
        self.reward = 0
        self.episode_reward = 0
        self.current_step = 0
        self.total_steps = 0
        self.current_episode = 0
        self.is_done = False
        self.reset_env = False
        self.finished = False

        # The RL server waits for the event that the next action is updated
        self.action_updated = asyncio.Event()

    def step(self):
        """Update the state, reward, and done flag using the latest server update."""
        self.new_state = self.get_state()
        self.is_done = self.get_done()
        self.reward = self.get_reward()
        info = self.get_info()

        return self.new_state, self.reward, self.is_done, info

    async def reset(self):
        """Reset the RL environment."""
        # Start a new training session
        logging.info("[RL Agent] Resetting RL environment.")

        # Reset the episode-related variables
        self.current_step = 0
        self.is_done = False
        self.episode_reward = 0
        self.current_episode += 1
        self.reset_env = True
        logging.info("[RL Agent] Starting RL episode #%d.", self.current_episode)

    def prep_action(self):
        """Get an action from the RL policy (self.policy is supplied by a subclass)."""
        logging.info("[RL Agent] Selecting action...")
        self.action = self.policy.select_action(self.state)

    def get_state(self):
        """Get the state for the agent."""
        return self.new_state

    def get_reward(self):
        """Get the reward for the agent."""
        return 0.0

    def get_done(self):
        """Get the done condition for the agent."""
        if (
            Config().algorithm.mode == "train"
            and self.current_step >= Config().algorithm.steps_per_episode
        ):
            logging.info("[RL Agent] Episode #%d ended.", self.current_episode)
            return True
        return False

    def get_info(self):
        """Get info used for benchmarking."""
        return {}

    def process_env_update(self):
        """Process a state update sent to the RL agent."""
        if self.current_step == 0:
            self.state = self.get_state()
        else:
            self.next_state, self.reward, self.is_done, __ = self.step()
            if Config().algorithm.mode == "train":
                self.process_experience()
            self.state = self.next_state
            self.episode_reward += self.reward

    async def prep_agent_update(self):
        """Update the RL agent and prepare the next action."""
        self.current_step += 1
        self.total_steps += 1
        logging.info("[RL Agent] Preparing action...")
        self.prep_action()
        self.action_updated.set()

        # When the episode ends
        if Config().algorithm.mode == "train" and self.is_done:
            self.update_policy()

            # Break the loop when RL training is concluded
            if self.current_episode >= Config().algorithm.max_episode:
                self.finished = True
            else:
                await self.reset()
        elif (
            Config().algorithm.mode == "test"
            and self.current_step >= Config().algorithm.test_step
        ):
            # Break the loop when RL testing is concluded
            self.finished = True

    @abstractmethod
    def update_policy(self):
        """Update the policy if needed in training mode."""

    @abstractmethod
    def process_experience(self):
        """Process the step experience if needed in training mode."""