plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,113 @@
|
|
1
|
+
"""
|
2
|
+
A federated learning server with RL Agent.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
from abc import abstractmethod
|
8
|
+
|
9
|
+
from plato.servers import fedavg
|
10
|
+
|
11
|
+
|
12
|
+
class RLServer(fedavg.Server):
    """A federated learning server with an RL Agent.

    During each aggregation the server hands its current state to ``agent``,
    waits for the agent's next action, applies it through ``apply_action``,
    and then combines the client deltas using the per-client weights stored in
    ``self.smart_weighting``.

    Subclasses must implement ``prep_state`` and ``apply_action``.
    NOTE(review): ``self.smart_weighting`` is read but never assigned in this
    class; presumably ``apply_action`` in a subclass sets it — confirm.
    NOTE(review): ``@abstractmethod`` is only enforced when the metaclass is
    ``abc.ABCMeta``; confirm that ``fedavg.Server`` provides it.
    """

    def __init__(
        self,
        agent,
        model=None,
        datasource=None,
        algorithm=None,
        trainer=None,
        callbacks=None,
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )
        # The RL agent this server exchanges states and actions with.
        self.agent = agent

    def reset(self):
        """Resetting the model, trainer, and algorithm on the server."""
        logging.info(
            "Reconfiguring the server for episode %d", self.agent.current_episode
        )

        # Drop the current components; init_trainer() (inherited) presumably
        # rebuilds them for the new episode.
        self.model = None
        self.trainer = None
        self.algorithm = None
        self.init_trainer()

        self.current_round = 0

    async def aggregate_deltas(self, updates, deltas_received):
        """Aggregate weight updates from the clients using smart weighting.

        Args:
            updates: the client updates; ``update.report.num_samples`` is read
                from each one.
            deltas_received: one ``{parameter name: delta}`` mapping per client,
                indexed in the same order as ``self.smart_weighting``.

        Returns:
            A dict mapping each parameter name to the weighted sum of the
            clients' deltas.
        """
        self.update_state()

        # Extract the total number of samples
        num_samples = [update.report.num_samples for update in updates]
        self.total_samples = sum(num_samples)

        # Start from zero tensors shaped like the first client's deltas
        avg_update = {
            name: self.trainer.zeros(weights.shape)
            for name, weights in deltas_received[0].items()
        }

        # Hand the sample counts to the RL agent and wait for its new action,
        # since the action may change how the global aggregation is weighted.
        self.agent.num_samples = num_samples
        await self.agent.prep_agent_update()
        await self.update_action()

        # Use adaptive weighted average
        for i, update in enumerate(deltas_received):
            for name, delta in update.items():
                # Long tensors are scaled by the first entry of the client's
                # weighting. Tensor.type() reports "torch.cuda.LongTensor" for
                # GPU tensors, so match the suffix instead of the CPU-only name.
                if delta.type().endswith("LongTensor"):
                    avg_update[name] += delta * self.smart_weighting[i][0]
                else:
                    avg_update[name] += delta * self.smart_weighting[i]

            # Yield to other tasks in the server
            await asyncio.sleep(0)

        return avg_update

    async def update_action(self):
        """Updating the RL agent's actions.

        On the agent's first step the initial action is prepared directly.
        Otherwise this waits until the agent signals ``action_updated`` and then
        clears the event. In both cases the action is applied afterwards.
        """
        if self.agent.current_step == 0:
            logging.info("[RL Agent] Preparing initial action...")
            self.agent.prep_action()
        else:
            await self.agent.action_updated.wait()
            self.agent.action_updated.clear()

        self.apply_action()

    def update_state(self):
        """Wrap up the state update to RL Agent.

        Stores the state built by ``prep_state`` as the agent's ``new_state``
        and asks the agent to process it.
        """
        # Pass new state to RL Agent
        self.agent.new_state = self.prep_state()
        self.agent.process_env_update()

    async def wrap_up(self) -> None:
        """Wrapping up when each round of training is done.

        Saves a checkpoint, resets the server if the agent requested an
        environment reset, and closes the server once the agent has finished.
        """
        self.save_to_checkpoint()

        if self.agent.reset_env:
            self.agent.reset_env = False
            self.reset()
        if self.agent.finished:
            await self._close()

    @abstractmethod
    def prep_state(self):
        """Build and return the state that is passed to the RL Agent."""
        return

    @abstractmethod
    def apply_action(self):
        """Apply action update from RL Agent to FL Env."""
|
plato/utils/rl_env.py
ADDED
@@ -0,0 +1,154 @@
|
|
1
|
+
"""
|
2
|
+
An environment of the reinforcement learning agent for tuning parameters
|
3
|
+
during the training of federated learning.
|
4
|
+
This environment follows the gym interface, in order to use stable-baselines3:
|
5
|
+
https://github.com/DLR-RM/stable-baselines3.
|
6
|
+
|
7
|
+
To create and use other custom environments, check out:
|
8
|
+
https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html.
|
9
|
+
"""
|
10
|
+
|
11
|
+
import asyncio
|
12
|
+
import logging
|
13
|
+
|
14
|
+
import gym
|
15
|
+
import numpy as np
|
16
|
+
from plato.config import Config
|
17
|
+
from gym import spaces
|
18
|
+
|
19
|
+
|
20
|
+
class RLEnv(gym.Env):
    """The environment of federated learning.

    ``step`` passes the tuned parameter to ``rl_agent``. It then blocks on the
    asyncio event loop until the agent reports the resulting state through
    ``get_state``.
    """

    metadata = {"render.modes": ["fl"]}

    def __init__(self, rl_agent):
        super().__init__()

        self.rl_agent = rl_agent
        self.time_step = 0
        self.state = None
        self.is_episode_done = False

        # An RL env waits for the event that it gets the current state from RL agent
        self.state_got = asyncio.Event()

        # An RL agent waits for the event that the RL env finishes step()
        # so that it can start a new FL round
        self.step_done = asyncio.Event()

        # Normalize action space and make it symmetric when continuous.
        # The reasons behind:
        # https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html#tips-and-tricks-when-creating-a-custom-environment
        n_actions = 1
        self.action_space = spaces.Box(
            low=-1, high=1, shape=(n_actions,), dtype="float32"
        )

        # Use only global model accuracy as state for now
        self.n_states = 1
        # Also normalize observation space for better RL training
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=(self.n_states,), dtype="float32"
        )

        self.state = [0] * self.n_states

    def reset(self):
        """Start a new RL episode and return the initial (all-zero) observation.

        Once ``rl_agent.rl_episode`` reaches ``Config().algorithm.rl_episodes``
        this method never returns. It keeps sleeping on the event loop so the
        RL agent has time to close its connections and exit.
        """
        if self.rl_agent.rl_episode >= Config().algorithm.rl_episodes:
            while True:
                # Give RL agent some time to close connections and exit
                current_loop = asyncio.get_event_loop()
                task = current_loop.create_task(asyncio.sleep(1))
                current_loop.run_until_complete(task)

        logging.info("Reseting RL environment.")

        self.time_step = 0
        # Let the RL agent restart FL training
        self.rl_agent.reset_rl_env()

        self.rl_agent.new_episode_begin.set()

        self.state = [0] * self.n_states
        return np.array(self.state)

    def step(self, action):
        """One step of reinforcement learning.

        Returns:
            The gym-style tuple ``(observation, reward, done, info)``.
        """
        assert self.action_space.contains(
            action
        ), f"{action!r} ({type(action)}) invalid"
        self.time_step += 1
        self.is_episode_done = False

        # For testing code: the tuned parameter is currently just the time step.
        # The intended mapping rescales the action from [-1, 1] to [1, 2, ..., 9]
        # aggregations on edge servers: int((action + 2) * (action + 2)).
        current_edge_agg_num = self.time_step

        logging.info("RL Agent: Start time step #%s...", self.time_step)
        logging.info(
            "Each edge server will run %s rounds of local aggregation.",
            current_edge_agg_num,
        )

        # Pass the tuned parameter to RL agent
        self.rl_agent.get_tuned_para(current_edge_agg_num, self.time_step)

        # Block until the RL agent delivers the state for this time step
        current_loop = asyncio.get_event_loop()
        get_state_task = current_loop.create_task(self.wait_for_state())
        current_loop.run_until_complete(get_state_task)

        self.normalize_state()

        reward = self.get_reward()
        info = {}

        self.rl_agent.cumulative_reward += reward

        # Signal the RL agent to start next time step (next round of FL)
        self.step_done.set()

        return np.array([self.state]), reward, self.is_episode_done, info

    async def wait_for_state(self):
        """Wait for getting the current state."""
        await self.state_got.wait()
        assert self.time_step == self.rl_agent.current_round
        self.state_got.clear()

    def get_state(self, state, is_episode_done):
        """
        Get transitted state from RL agent.
        This function is called by RL agent.
        """
        self.state = state
        self.is_episode_done = is_episode_done
        # Signal the RL env that it gets the current state
        self.state_got.set()
        logging.info("RL env: Get state %s", state)
        self.rl_agent.is_rl_tuned_para_got = False

    def normalize_state(self):
        """Normalize each element of state to [-1,1].

        Maps x -> 2 * (x - 0.5); assumes the raw state lies in [0, 1].
        """
        self.state = 2 * (self.state - 0.5)

    def get_reward(self):
        """Get reward based on the state (the normalized accuracy for now)."""
        accuracy = self.state
        # Use accuracy as reward for now.
        reward = accuracy
        return reward

    def render(self, mode="rl"):
        """Rendering is not supported; this is a no-op."""
        pass

    def close(self):
        """Nothing to release; this is a no-op."""
        pass
|
plato/utils/s3.py
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
"""
|
2
|
+
Utilities to transmit Python objects to and from an S3-compatible object storage service.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pickle
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import boto3
|
9
|
+
import botocore
|
10
|
+
import requests
|
11
|
+
|
12
|
+
from plato.config import Config
|
13
|
+
|
14
|
+
|
15
|
+
class S3:
    """Manages the utilities to transmit Python objects to and from an S3-compatible
    object storage service.
    """

    def __init__(self, endpoint=None, access_key=None, secret_key=None, bucket=None):
        """All S3-related credentials, such as the access key and the secret key,
        are either to be stored in ~/.aws/credentials by using the 'aws configure'
        command, passed into the constructor as parameters, or specified in the
        `server` section of the configuration file.

        Settings in the `server` section take precedence over the constructor
        arguments: `s3_endpoint_url`, `s3_bucket`, `access_key` and `secret_key`.
        The bucket may be a plain name or an `s3://bucket/prefix` URL. With a URL,
        the prefix is prepended to object keys in `send_to_s3` and
        `receive_from_s3`. A missing bucket is created.

        Raises:
            ValueError: if no bucket is configured, or if the bucket cannot be
                found and cannot be created.
        """
        self.endpoint = endpoint
        self.bucket = bucket
        self.key_prefix = ""
        self.access_key = access_key
        self.secret_key = secret_key

        if hasattr(Config().server, "s3_endpoint_url"):
            self.endpoint = Config().server.s3_endpoint_url

        if hasattr(Config().server, "s3_bucket"):
            self.bucket = Config().server.s3_bucket

        if hasattr(Config().server, "access_key"):
            self.access_key = Config().server.access_key

        if hasattr(Config().server, "secret_key"):
            self.secret_key = Config().server.secret_key

        if self.bucket is None:
            raise ValueError("The S3 storage service has not been properly configured.")

        self.bucket, self.key_prefix = self._split_bucket_url(self.bucket)

        if self.access_key is not None and self.secret_key is not None:
            self.s3_client = boto3.client(
                "s3",
                endpoint_url=self.endpoint,
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
            )
        else:
            # Fall back to boto3's default credential lookup
            # (e.g. ~/.aws/credentials)
            self.s3_client = boto3.client("s3", endpoint_url=self.endpoint)

        # Does the bucket exist? On any client error, try to create it.
        try:
            self.s3_client.head_bucket(Bucket=self.bucket)
        except botocore.exceptions.ClientError:
            try:
                self.s3_client.create_bucket(Bucket=self.bucket)
            except botocore.exceptions.ClientError as s3_exception:
                raise ValueError("Fail to create a bucket.") from s3_exception

    @staticmethod
    def _split_bucket_url(bucket):
        """Split a bucket setting into ``(bucket_name, key_prefix)``.

        A value that does not start with ``s3://`` is returned unchanged with an
        empty prefix. For ``s3://name/rest`` the prefix keeps its leading "/",
        e.g. ``"s3://b/run/1" -> ("b", "/run/1")`` and ``"s3://b/" -> ("b", "/")``.
        """
        scheme = "s3://"
        if not bucket.startswith(scheme):
            return bucket, ""

        bucket_part = bucket[len(scheme) :]
        name, separator, _ = bucket_part.partition("/")
        if not separator:
            return name, ""
        return name, bucket_part[len(name) :]

    def send_to_s3(self, object_key, object_to_send) -> None:
        """Sends an object to an S3-compatible object storage service.

        The object is pickled and uploaded through a presigned PUT URL, valid for
        300 seconds. The destination key is ``key_prefix + "/" + object_key``.
        If ``head_object`` finds an object under that key, nothing is uploaded.
        Nothing is returned.

        Raises:
            ValueError: if the upload request fails or its parameters are invalid.
        """
        object_key = self.key_prefix + "/" + object_key
        try:
            # Does the object key exist already in S3?
            self.s3_client.head_object(Bucket=self.bucket, Key=object_key)
        except botocore.exceptions.ClientError:
            try:
                # Only send the object if the key does not exist yet
                data = pickle.dumps(object_to_send)
                put_url = self.s3_client.generate_presigned_url(
                    ClientMethod="put_object",
                    Params={"Bucket": self.bucket, "Key": object_key},
                    ExpiresIn=300,
                )
                response = requests.put(put_url, data=data)

                if response.status_code != 200:
                    raise ValueError(
                        f"Error occurred sending data: status code = {response.status_code}"
                    ) from None

            except botocore.exceptions.ClientError as error:
                raise ValueError(
                    f"Error occurred sending data to S3: {error}"
                ) from error

            except botocore.exceptions.ParamValidationError as error:
                raise ValueError(f"Incorrect parameters: {error}") from error

    def receive_from_s3(self, object_key) -> Any:
        """Retrieves an object from an S3-compatible object storage service.

        The object under ``key_prefix + "/" + object_key`` is downloaded and
        unpickled. The download goes through a presigned GET URL, valid for 300
        seconds, signed by the client configured in the constructor.

        Returns: The object to be retrieved.

        Raises:
            ValueError: if the download does not return HTTP status 200.
        """
        object_key = self.key_prefix + "/" + object_key
        get_url = self.s3_client.generate_presigned_url(
            ClientMethod="get_object",
            Params={"Bucket": self.bucket, "Key": object_key},
            ExpiresIn=300,
        )
        response = requests.get(get_url)

        if response.status_code == 200:
            # SECURITY NOTE: pickle.loads can execute arbitrary code contained in
            # the payload; only read from buckets whose writers are trusted.
            return pickle.loads(response.content)

        raise ValueError(
            f"Error occurred receiving data: request status code = {response.status_code}"
        )

    def delete_from_s3(self, object_key):
        """Deletes an object using its key from S3.

        NOTE(review): unlike send_to_s3/receive_from_s3, ``key_prefix`` is not
        prepended here. ``object_key`` must therefore be the full stored key,
        e.g. one returned by ``lists()``; confirm this is what callers pass.
        """
        self.s3_client.delete_object(Bucket=self.bucket, Key=object_key)

    def lists(self):
        """Retrieves keys to all the objects in the S3 bucket.

        Pages through ``list_objects_v2`` so that buckets holding more than one
        page of objects are listed completely. An empty bucket yields an empty
        list, since its responses carry no ``Contents`` entry.
        """
        paginator = self.s3_client.get_paginator("list_objects_v2")
        keys = []
        for page in paginator.paginate(Bucket=self.bucket):
            keys.extend(obj["Key"] for obj in page.get("Contents", []))
        return keys
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""
|
2
|
+
The necessary tools used by trainers.
|
3
|
+
"""
|
4
|
+
|
5
|
+
|
6
|
+
def freeze_model(model, layer_names=None):
    """Freeze a part of the model.

    Sets ``requires_grad = False`` on every parameter whose name contains any
    of the substrings in ``layer_names``.

    Args:
        model: an object exposing ``named_parameters()`` (e.g. a torch module).
        layer_names: a collection of name substrings to match. It is
            re-iterated for every parameter, so pass a list/tuple rather than a
            one-shot iterator. When ``None``, nothing is frozen.

    Returns:
        The names of the parameters that were frozen, in iteration order.
    """
    if layer_names is None:
        return []

    frozen_params = []
    for name, param in model.named_parameters():
        if any(layer_name in name for layer_name in layer_names):
            param.requires_grad = False
            frozen_params.append(name)
    return frozen_params
|
14
|
+
|
15
|
+
|
16
|
+
def activate_model(model, layer_names=None):
    """Activate a part of the model by re-enabling gradients.

    Every parameter of ``model`` whose name contains at least one of the
    substrings in ``layer_names`` gets ``requires_grad = True``. When
    ``layer_names`` is ``None`` the model is left untouched.
    """
    if layer_names is None:
        return

    for param_name, parameter in model.named_parameters():
        matched = any(fragment in param_name for fragment in layer_names)
        if matched:
            parameter.requires_grad = True
|
@@ -0,0 +1,47 @@
|
|
1
|
+
"""Implements unary encoding, used by Google's RAPPOR, as the local differential privacy mechanism.
|
2
|
+
|
3
|
+
References:
|
4
|
+
|
5
|
+
Wang, et al. "Optimizing Locally Differentially Private Protocols," USENIX Security 2017.
|
6
|
+
|
7
|
+
Erlingsson, et al. "RAPPOR: Randomized Aggregatable Privacy-Preserving Ordinal Response,"
|
8
|
+
ACM CCS 2014.
|
9
|
+
|
10
|
+
"""
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
|
14
|
+
|
15
|
+
def encode(x: np.ndarray):
    """Binarize ``x`` in place and return the same array.

    Entries greater than zero become 1 and entries less than or equal to zero
    become 0. Entries satisfying neither comparison (NaN) are left unchanged.
    """
    positive_mask = x > 0
    non_positive_mask = x <= 0
    x[positive_mask] = 1
    x[non_positive_mask] = 0
    return x
|
19
|
+
|
20
|
+
|
21
|
+
def randomize(bit_array: np.ndarray, epsilon):
    """
    Perturb ``bit_array`` with the default unary encoding, which is the
    symmetric variant (``symmetric_unary_encoding``).
    """
    assert isinstance(bit_array, np.ndarray)
    perturbed = symmetric_unary_encoding(bit_array, epsilon)
    return perturbed
|
27
|
+
|
28
|
+
|
29
|
+
def symmetric_unary_encoding(bit_array: np.ndarray, epsilon):
    """Perturb ``bit_array`` with symmetric unary encoding (SUE).

    Uses p = e^(eps/2) / (e^(eps/2) + 1) for 1-bits and
    q = 1 / (e^(eps/2) + 1) for 0-bits, handed to
    ``produce_randomized_response``.
    """
    half_eps_exp = np.e ** (epsilon / 2)
    denominator = half_eps_exp + 1
    p = half_eps_exp / denominator
    q = 1 / denominator
    return produce_randomized_response(bit_array, p, q)
|
33
|
+
|
34
|
+
|
35
|
+
def optimized_unary_encoding(bit_array: np.ndarray, epsilon):
    """Perturb ``bit_array`` with optimized unary encoding (OUE).

    Uses p = 1/2 for 1-bits and q = 1 / (e^eps + 1) for 0-bits, handed to
    ``produce_randomized_response``.
    """
    keep_prob = 1 / 2
    flip_prob = 1 / (np.e**epsilon + 1)
    return produce_randomized_response(bit_array, keep_prob, flip_prob)
|
39
|
+
|
40
|
+
|
41
|
+
def produce_randomized_response(bit_array: np.ndarray, p, q=None):
    """Implements randomized response as the perturbation method.

    Each position holding 1 is reported as 1 with probability ``p``. Every
    other position is reported as 1 with probability ``q``, which defaults to
    ``1 - p``. Returns an integer array shaped like ``bit_array``.
    """
    if q is None:
        q = 1 - p

    # Draw both Bernoulli samples (p first, then q), then pick per position.
    ones_draw = np.random.binomial(1, p, bit_array.shape)
    zeros_draw = np.random.binomial(1, q, bit_array.shape)
    return np.where(bit_array == 1, ones_draw, zeros_draw)
|
@@ -0,0 +1,35 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: plato-learn
|
3
|
+
Version: 1.1
|
4
|
+
Summary: Plato: a research framework for federated learning
|
5
|
+
Project-URL: Homepage, https://github.com/TL-System/plato
|
6
|
+
Project-URL: Repository, https://github.com/TL-System/plato
|
7
|
+
Project-URL: Documentation, https://platodocs.netlify.app/
|
8
|
+
License-Expression: Apache-2.0
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Python: >=3.13
|
11
|
+
Requires-Dist: aiohttp
|
12
|
+
Requires-Dist: boto3
|
13
|
+
Requires-Dist: datasets
|
14
|
+
Requires-Dist: evaluate
|
15
|
+
Requires-Dist: gym
|
16
|
+
Requires-Dist: lightly
|
17
|
+
Requires-Dist: numpy
|
18
|
+
Requires-Dist: opacus
|
19
|
+
Requires-Dist: python-socketio
|
20
|
+
Requires-Dist: pyyaml
|
21
|
+
Requires-Dist: requests
|
22
|
+
Requires-Dist: tenseal
|
23
|
+
Requires-Dist: timm
|
24
|
+
Requires-Dist: torch
|
25
|
+
Requires-Dist: torch-optimizer
|
26
|
+
Requires-Dist: torchvision
|
27
|
+
Requires-Dist: transformers
|
28
|
+
Requires-Dist: zstd
|
29
|
+
Provides-Extra: dev
|
30
|
+
Requires-Dist: pytest; extra == 'dev'
|
31
|
+
Description-Content-Type: text/markdown
|
32
|
+
|
33
|
+
# Plato: A New Framework for Scalable Federated Learning Research
|
34
|
+
|
35
|
+
Welcome to *Plato*, a software framework to facilitate scalable, reproducible, and extensible federated learning research. Please refer to the documentation website, available in `/documentation`, for more details on installing, running and deploying Plato.
|