plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/utils/reinforcement_learning/rl_server.py ADDED
@@ -0,0 +1,113 @@
+ """
+ A federated learning server with an RL Agent.
+ """
+
+ import asyncio
+ import logging
+ from abc import abstractmethod
+
+ from plato.servers import fedavg
+
+
+ class RLServer(fedavg.Server):
+     """A federated learning server with an RL Agent."""
+
+     def __init__(
+         self,
+         agent,
+         model=None,
+         datasource=None,
+         algorithm=None,
+         trainer=None,
+         callbacks=None,
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+         self.agent = agent
+
+     def reset(self):
+         """Resets the model, trainer, and algorithm on the server."""
+         logging.info(
+             "Reconfiguring the server for episode %d", self.agent.current_episode
+         )
+
+         self.model = None
+         self.trainer = None
+         self.algorithm = None
+         self.init_trainer()
+
+         self.current_round = 0
+
+     async def aggregate_deltas(self, updates, deltas_received):
+         """Aggregates weight updates from the clients using smart weighting."""
+         self.update_state()
+
+         # Extract the total number of samples
+         num_samples = [update.report.num_samples for update in updates]
+         self.total_samples = sum(num_samples)
+
+         # Perform weighted averaging
+         avg_update = {
+             name: self.trainer.zeros(weights.shape)
+             for name, weights in deltas_received[0].items()
+         }
+
+         # e.g., wait for the new action from the RL agent
+         # if the action affects the global aggregation
+         self.agent.num_samples = num_samples
+         await self.agent.prep_agent_update()
+         await self.update_action()
+
+         # Use an adaptive weighted average
+         for i, update in enumerate(deltas_received):
+             for name, delta in update.items():
+                 if delta.type() == "torch.LongTensor":
+                     avg_update[name] += delta * self.smart_weighting[i][0]
+                 else:
+                     avg_update[name] += delta * self.smart_weighting[i]
+
+         # Yield to other tasks in the server
+         await asyncio.sleep(0)
+
+         return avg_update
+
+     async def update_action(self):
+         """Updates the RL agent's actions."""
+         if self.agent.current_step == 0:
+             logging.info("[RL Agent] Preparing initial action...")
+             self.agent.prep_action()
+         else:
+             await self.agent.action_updated.wait()
+             self.agent.action_updated.clear()
+
+         self.apply_action()
+
+     def update_state(self):
+         """Wraps up the state update to the RL Agent."""
+         # Pass the new state to the RL Agent
+         self.agent.new_state = self.prep_state()
+         self.agent.process_env_update()
+
+     async def wrap_up(self) -> None:
+         """Wraps up when each round of training is done."""
+         self.save_to_checkpoint()
+
+         if self.agent.reset_env:
+             self.agent.reset_env = False
+             self.reset()
+         if self.agent.finished:
+             await self._close()
+
+     @abstractmethod
+     def prep_state(self):
+         """Prepares the state update for the RL Agent."""
+         return
+
+     @abstractmethod
+     def apply_action(self):
+         """Applies the action update from the RL Agent to the FL environment."""
plato/utils/rl_env.py ADDED
@@ -0,0 +1,154 @@
+ """
+ An environment for a reinforcement learning agent that tunes parameters
+ during the training of federated learning.
+ This environment follows the gym interface, in order to use stable-baselines3:
+ https://github.com/DLR-RM/stable-baselines3.
+
+ To create and use other custom environments, check out:
+ https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html.
+ """
+
+ import asyncio
+ import logging
+
+ import gym
+ import numpy as np
+ from gym import spaces
+ from plato.config import Config
+
+
+ class RLEnv(gym.Env):
+     """The environment of federated learning."""
+
+     metadata = {"render.modes": ["fl"]}
+
+     def __init__(self, rl_agent):
+         super().__init__()
+
+         self.rl_agent = rl_agent
+         self.time_step = 0
+         self.state = None
+         self.is_episode_done = False
+
+         # The RL env waits for the event that it receives the current state from the RL agent
+         self.state_got = asyncio.Event()
+
+         # The RL agent waits for the event that the RL env finishes step()
+         # so that it can start a new FL round
+         self.step_done = asyncio.Event()
+
+         # Normalize the action space and make it symmetric when continuous.
+         # The reasons behind this:
+         # https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html#tips-and-tricks-when-creating-a-custom-environment
+         n_actions = 1
+         self.action_space = spaces.Box(
+             low=-1, high=1, shape=(n_actions,), dtype="float32"
+         )
+
+         # Use only the global model accuracy as the state for now
+         self.n_states = 1
+         # Also normalize the observation space for better RL training
+         self.observation_space = spaces.Box(
+             low=-1, high=1, shape=(self.n_states,), dtype="float32"
+         )
+
+         self.state = [0 for i in range(self.n_states)]
+
+     def reset(self):
+         if self.rl_agent.rl_episode >= Config().algorithm.rl_episodes:
+             while True:
+                 # Give the RL agent some time to close connections and exit
+                 current_loop = asyncio.get_event_loop()
+                 task = current_loop.create_task(asyncio.sleep(1))
+                 current_loop.run_until_complete(task)
+
+         logging.info("Resetting the RL environment.")
+
+         self.time_step = 0
+         # Let the RL agent restart FL training
+         self.rl_agent.reset_rl_env()
+
+         self.rl_agent.new_episode_begin.set()
+
+         self.state = [0 for i in range(self.n_states)]
+         return np.array(self.state)
+
+     def step(self, action):
+         """One step of reinforcement learning."""
+         assert self.action_space.contains(action), "%r (%s) invalid" % (
+             action,
+             type(action),
+         )
+         self.time_step += 1
+         reward = float(0)
+         self.is_episode_done = False
+
+         # For testing the code
+         current_edge_agg_num = self.time_step
+
+         # Rescale the action from [-1, 1] to [1, 2, ..., 9]
+         # The action is the number of aggregations on edge servers
+         # current_edge_agg_num = int((action + 2) * (action + 2))
+
+         logging.info("RL Agent: Start time step #%s...", self.time_step)
+         logging.info(
+             "Each edge server will run %s rounds of local aggregation.",
+             current_edge_agg_num,
+         )
+
+         # Pass the tuned parameter to the RL agent
+         self.rl_agent.get_tuned_para(current_edge_agg_num, self.time_step)
+
+         # Wait for the state
+         current_loop = asyncio.get_event_loop()
+         get_state_task = current_loop.create_task(self.wait_for_state())
+         current_loop.run_until_complete(get_state_task)
+         # print('State:', self.state)
+
+         self.normalize_state()
+         # print('Normalized state:', self.state)
+
+         reward = self.get_reward()
+         info = {}
+
+         self.rl_agent.cumulative_reward += reward
+
+         # Signal the RL agent to start the next time step (the next round of FL)
+         self.step_done.set()
+
+         return np.array([self.state]), reward, self.is_episode_done, info
+
+     async def wait_for_state(self):
+         """Waits until the current state has been received."""
+         await self.state_got.wait()
+         assert self.time_step == self.rl_agent.current_round
+         self.state_got.clear()
+
+     def get_state(self, state, is_episode_done):
+         """
+         Gets the transmitted state from the RL agent.
+         This function is called by the RL agent.
+         """
+         self.state = state
+         self.is_episode_done = is_episode_done
+         # Signal the RL env that it has received the current state
+         self.state_got.set()
+         print("RL env: Get state", state)
+         self.rl_agent.is_rl_tuned_para_got = False
+
+     def normalize_state(self):
+         """Normalizes each element of the state to [-1, 1]."""
+         self.state = 2 * (self.state - 0.5)
+
+     def get_reward(self):
+         """Gets the reward based on the state."""
+         accuracy = self.state
+         # Use accuracy as the reward for now.
+         reward = accuracy
+         return reward
+
+     def render(self, mode="rl"):
+         pass
+
+     def close(self):
+         pass
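The commented-out rescaling in `step()` maps the continuous action in [-1, 1] onto the integer range [1, 9]; a quick endpoint check (illustrative only):

```python
# Endpoint check for the rescale int((action + 2) * (action + 2)):
# it maps the normalized action range [-1, 1] onto integers 1 through 9.
for action in (-1.0, 0.0, 1.0):
    print(action, "->", int((action + 2) * (action + 2)))
# -1.0 -> 1, 0.0 -> 4, 1.0 -> 9
```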
plato/utils/s3.py ADDED
@@ -0,0 +1,141 @@
+ """
+ Utilities to transmit Python objects to and from an S3-compatible object storage service.
+ """
+
+ import pickle
+ from typing import Any
+
+ import boto3
+ import botocore
+ import requests
+
+ from plato.config import Config
+
+
+ class S3:
+     """Manages the utilities to transmit Python objects to and from an S3-compatible
+     object storage service.
+     """
+
+     def __init__(self, endpoint=None, access_key=None, secret_key=None, bucket=None):
+         """All S3-related credentials, such as the access key and the secret key,
+         are either to be stored in ~/.aws/credentials by using the 'aws configure'
+         command, passed into the constructor as parameters, or specified in the
+         `server` section of the configuration file.
+         """
+         self.endpoint = endpoint
+         self.bucket = bucket
+         self.key_prefix = ""
+         self.access_key = access_key
+         self.secret_key = secret_key
+
+         if hasattr(Config().server, "s3_endpoint_url"):
+             self.endpoint = Config().server.s3_endpoint_url
+
+         if hasattr(Config().server, "s3_bucket"):
+             self.bucket = Config().server.s3_bucket
+
+         if hasattr(Config().server, "access_key"):
+             self.access_key = Config().server.access_key
+
+         if hasattr(Config().server, "secret_key"):
+             self.secret_key = Config().server.secret_key
+
+         if self.bucket is None:
+             raise ValueError("The S3 storage service has not been properly configured.")
+
+         if "s3://" in self.bucket:
+             bucket_part = self.bucket[5:]
+             str_list = bucket_part.split("/")
+             self.bucket = str_list[0]
+             if len(str_list) > 1:
+                 self.key_prefix = bucket_part[len(self.bucket) :]
+
+         if self.access_key is not None and self.secret_key is not None:
+             self.s3_client = boto3.client(
+                 "s3",
+                 endpoint_url=self.endpoint,
+                 aws_access_key_id=self.access_key,
+                 aws_secret_access_key=self.secret_key,
+             )
+         else:
+             # The access key and secret key are stored locally in ~/.aws/credentials
+             self.s3_client = boto3.client("s3", endpoint_url=self.endpoint)
+
+         # Does the bucket exist?
+         try:
+             self.s3_client.head_bucket(Bucket=self.bucket)
+         except botocore.exceptions.ClientError:
+             try:
+                 self.s3_client.create_bucket(Bucket=self.bucket)
+             except botocore.exceptions.ClientError as s3_exception:
+                 raise ValueError("Failed to create a bucket.") from s3_exception
+
+     def send_to_s3(self, object_key, object_to_send) -> str:
+         """Sends an object to an S3-compatible object storage service.
+
+         Returns: A presigned URL for use later to retrieve the data.
+         """
+         object_key = self.key_prefix + "/" + object_key
+         try:
+             # Does the object key exist already in S3?
+             self.s3_client.head_object(Bucket=self.bucket, Key=object_key)
+         except botocore.exceptions.ClientError:
+             try:
+                 # Only send the object if the key does not exist yet
+                 data = pickle.dumps(object_to_send)
+                 put_url = self.s3_client.generate_presigned_url(
+                     ClientMethod="put_object",
+                     Params={"Bucket": self.bucket, "Key": object_key},
+                     ExpiresIn=300,
+                 )
+                 response = requests.put(put_url, data=data)
+
+                 if response.status_code != 200:
+                     raise ValueError(
+                         f"Error occurred sending data: status code = {response.status_code}"
+                     ) from None
+
+             except botocore.exceptions.ClientError as error:
+                 raise ValueError(
+                     f"Error occurred sending data to S3: {error}"
+                 ) from error
+
+             except botocore.exceptions.ParamValidationError as error:
+                 raise ValueError(f"Incorrect parameters: {error}") from error
+
+     def receive_from_s3(self, object_key) -> Any:
+         """Retrieves an object from an S3-compatible object storage service.
+
+         All S3-related credentials, such as the access key and the secret key,
+         are assumed to be stored in ~/.aws/credentials by using the 'aws configure'
+         command.
+
+         Returns: The object to be retrieved.
+         """
+         object_key = self.key_prefix + "/" + object_key
+         get_url = self.s3_client.generate_presigned_url(
+             ClientMethod="get_object",
+             Params={"Bucket": self.bucket, "Key": object_key},
+             ExpiresIn=300,
+         )
+         response = requests.get(get_url)
+
+         if response.status_code == 200:
+             return pickle.loads(response.content)
+
+         raise ValueError(
+             f"Error occurred receiving data: request status code = {response.status_code}"
+         )
+
+     def delete_from_s3(self, object_key):
+         """Deletes an object using its key from S3."""
+         __ = self.s3_client.delete_object(Bucket=self.bucket, Key=object_key)
+
+     def lists(self):
+         """Retrieves the keys of all the objects in the S3 bucket."""
+         response = self.s3_client.list_objects_v2(Bucket=self.bucket)
+         keys = []
+         for obj in response["Contents"]:
+             keys.append(obj["Key"])
+         return keys
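A hedged usage sketch of the `S3` class follows; the endpoint, bucket name, and object key are placeholders, and a configuration whose `server` section does not override them is assumed.

```python
# A hedged usage sketch (not part of this package): round-tripping a Python
# object through an S3-compatible store. The endpoint and bucket below are
# placeholders; Config() is assumed not to override them.
from plato.utils.s3 import S3

storage = S3(endpoint="http://localhost:9000", bucket="plato-artifacts")
storage.send_to_s3("weights/round_1", {"layer": [1.0, 2.0, 3.0]})
restored = storage.receive_from_s3("weights/round_1")
assert restored == {"layer": [1.0, 2.0, 3.0]}
print(storage.lists())
```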
plato/utils/trainer_utils.py ADDED
@@ -0,0 +1,21 @@
+ """
+ The necessary tools used by trainers.
+ """
+
+
+ def freeze_model(model, layer_names=None):
+     """Freeze a part of the model."""
+     if layer_names is not None:
+         frozen_params = []
+         for name, param in model.named_parameters():
+             if any(param_name in name for param_name in layer_names):
+                 param.requires_grad = False
+                 frozen_params.append(name)
+
+
+ def activate_model(model, layer_names=None):
+     """Activate a part of the model."""
+     if layer_names is not None:
+         for name, param in model.named_parameters():
+             if any(param_name in name for param_name in layer_names):
+                 param.requires_grad = True
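Both helpers match layer names by substring against `model.named_parameters()`; a hedged usage sketch with an illustrative two-layer model:

```python
# A hedged usage sketch: freezing and re-activating parameters by substring
# match on their names. The two-layer model is illustrative only.
import torch.nn as nn

from plato.utils.trainer_utils import activate_model, freeze_model

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
freeze_model(model, layer_names=["0."])   # matches "0.weight" and "0.bias"
assert not model[0].weight.requires_grad
activate_model(model, layer_names=["0."])
assert model[0].weight.requires_grad
```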
plato/utils/unary_encoding.py ADDED
@@ -0,0 +1,47 @@
+ """Implements unary encoding, used by Google's RAPPOR, as the local differential privacy mechanism.
+
+ References:
+
+ Wang, et al. "Optimizing Locally Differentially Private Protocols," USENIX ATC 2017.
+
+ Erlingsson, et al. "RAPPOR: Randomized Aggregatable Privacy-Preserving Ordinal Response,"
+ ACM CCS 2014.
+
+ """
+
+ import numpy as np
+
+
+ def encode(x: np.ndarray):
+     x[x > 0] = 1
+     x[x <= 0] = 0
+     return x
+
+
+ def randomize(bit_array: np.ndarray, epsilon):
+     """
+     The default unary encoding method is symmetric.
+     """
+     assert isinstance(bit_array, np.ndarray)
+     return symmetric_unary_encoding(bit_array, epsilon)
+
+
+ def symmetric_unary_encoding(bit_array: np.ndarray, epsilon):
+     p = np.e ** (epsilon / 2) / (np.e ** (epsilon / 2) + 1)
+     q = 1 / (np.e ** (epsilon / 2) + 1)
+     return produce_randomized_response(bit_array, p, q)
+
+
+ def optimized_unary_encoding(bit_array: np.ndarray, epsilon):
+     p = 1 / 2
+     q = 1 / (np.e**epsilon + 1)
+     return produce_randomized_response(bit_array, p, q)
+
+
+ def produce_randomized_response(bit_array: np.ndarray, p, q=None):
+     """Implements randomized response as the perturbation method."""
+     q = 1 - p if q is None else q
+
+     p_binomial = np.random.binomial(1, p, bit_array.shape)
+     q_binomial = np.random.binomial(1, q, bit_array.shape)
+     return np.where(bit_array == 1, p_binomial, q_binomial)
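A hedged worked example: for epsilon = 2, symmetric unary encoding keeps a 1-bit with probability p = e^(epsilon/2) / (e^(epsilon/2) + 1) = e/(e + 1) ≈ 0.731 and flips a 0-bit to 1 with probability q = 1/(e + 1) ≈ 0.269; the input values below are illustrative.

```python
# A hedged worked example: encode a vector into bits, then apply symmetric
# unary encoding with epsilon = 2, where p = e/(e + 1) ~ 0.731 and
# q = 1/(e + 1) ~ 0.269. The input values are illustrative.
import numpy as np

from plato.utils.unary_encoding import encode, randomize

features = np.array([0.7, -0.2, 1.3, 0.0])
bits = encode(features.copy())        # -> array([1., 0., 1., 0.])
noisy = randomize(bits, epsilon=2)    # each bit randomized independently
print(noisy)
```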
plato_learn-1.1.dist-info/METADATA ADDED
@@ -0,0 +1,35 @@
+ Metadata-Version: 2.4
+ Name: plato-learn
+ Version: 1.1
+ Summary: Plato: a research framework for federated learning
+ Project-URL: Homepage, https://github.com/TL-System/plato
+ Project-URL: Repository, https://github.com/TL-System/plato
+ Project-URL: Documentation, https://platodocs.netlify.app/
+ License-Expression: Apache-2.0
+ License-File: LICENSE
+ Requires-Python: >=3.13
+ Requires-Dist: aiohttp
+ Requires-Dist: boto3
+ Requires-Dist: datasets
+ Requires-Dist: evaluate
+ Requires-Dist: gym
+ Requires-Dist: lightly
+ Requires-Dist: numpy
+ Requires-Dist: opacus
+ Requires-Dist: python-socketio
+ Requires-Dist: pyyaml
+ Requires-Dist: requests
+ Requires-Dist: tenseal
+ Requires-Dist: timm
+ Requires-Dist: torch
+ Requires-Dist: torch-optimizer
+ Requires-Dist: torchvision
+ Requires-Dist: transformers
+ Requires-Dist: zstd
+ Provides-Extra: dev
+ Requires-Dist: pytest; extra == 'dev'
+ Description-Content-Type: text/markdown
+
+ # Plato: A New Framework for Scalable Federated Learning Research
+
+ Welcome to *Plato*, a software framework to facilitate scalable, reproducible, and extensible federated learning research. Please refer to the documentation website, available in `/documentation`, for more details on installing, running, and deploying Plato.