plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,254 @@
1
+ """
2
+ A factory that generates a Multi-layer perceptron (MLP), with the ability to build a fully-
3
+ connected network based on a specific configuration. This is a very flexible MLP network generator to
4
+ define any type of MLP network.
5
+
6
+ Note: The general order of components in one MLP layer is:
7
+ Schema A: From the original paper of bn and dropout.
8
+ fc -> bn -> activation -> dropout -> ....
9
+
10
+ Schema B: From the researcher "https://math.stackexchange.com/users/167500/pseudomarvin".
11
+ fc -> activation -> dropout -> bn -> ....
12
+
13
+ See more discussion on the website:
14
+ https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout
15
+
16
+ Our work uses schema A.
17
+
18
+ Trick: One may just drop the Dropout (when you have BN) as BN eliminates the need for Dropout in
19
+ some cases, since intuitively BN provides similar regularization benefits as Dropout.
20
+
21
+ """
22
+
23
+ from typing import Union, Dict, List
24
+ from collections import OrderedDict
25
+
26
+ from torch import nn
27
+
28
# Maps the activation names accepted in an MLP configuration to the torch
# module class that is instantiated (without arguments) for that activation.
activations_func = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "softmax": nn.Softmax,
    "tanh": nn.Tanh,
}


# pylint: disable=too-many-locals
def build_mlp_from_config(
    mlp_configs: Dict[str, Union[int, List[Union[str, None, dict]]]],
    layer_name_prefix: str = "layer",
):
    """
    Build the fully-connected network (Multi-layer perceptron)
    based on the input configuration.

    The network consists of N layers: one per entry of ``hidden_layers_dim``
    followed by a final layer that maps to ``output_dim``. Every layer is an
    ``nn.Sequential`` holding, in this order, a linear layer (``fc``), an
    optional batch normalization (``bn``), an optional activation (named
    after the activation itself) and an optional dropout (``drop``).

    The configuration passed in is only read; it is never modified, so the
    same dictionary can be reused to build several networks.

    :param mlp_configs: A Dict type containing the hyper-parameters for definition.
        It should contain:
            input_dim: Integer
            output_dim: Integer
            hidden_layers_dim: List[int], with length N - 1
            batch_norms: List[Union[None, dict]], with length N; a dict is
                forwarded as keyword arguments to ``nn.BatchNorm1d`` and
                ``None`` disables batch normalization for that layer
            activations: List[Union[None, str]], with length N; a string must
                be a key of ``activations_func`` and ``None`` disables the
                activation for that layer
            dropout_ratios: List[float], with length N; ``0.0`` disables
                dropout for that layer. A single non-list value is treated as
                a one-element list.
    :param layer_name_prefix: A string added to the layer's name. Layers are
        named ``<prefix>1``, ``<prefix>2``, ...
    :return: An ``nn.Sequential`` containing the N layers.
    :raises AssertionError: If the per-layer lists do not all have length N.
    :raises KeyError: If a required configuration key is missing or an
        activation name is not present in ``activations_func``.
    """
    input_dim = mlp_configs["input_dim"]
    output_dim = mlp_configs["output_dim"]

    # Build a fresh list instead of appending to the caller's list: appending
    # in place would leave `output_dim` inside the caller's configuration and
    # make a second call with the same configuration fail the length check.
    layer_output_dims = [*mlp_configs["hidden_layers_dim"], output_dim]

    batch_norms = mlp_configs["batch_norms"]
    activations = mlp_configs["activations"]
    dropout_ratios = mlp_configs["dropout_ratios"]
    dropout_probs = (
        dropout_ratios if isinstance(dropout_ratios, list) else [dropout_ratios]
    )

    assert (
        len(batch_norms) == len(activations) == len(dropout_probs)
    ), "batch_norms, activations and dropout_ratios must have the same length"
    assert len(layer_output_dims) == len(
        batch_norms
    ), "hidden_layers_dim must contain exactly one entry fewer than batch_norms"

    def build_one_layer(
        layer_ipt_dim, layer_opt_dim, batch_norm_param, activation, dropout_prob
    ):
        """Build one layer of the MLP: fc -> bn -> activation -> dropout.

        Only the linear layer is always present; each of the other components
        is added when its setting is enabled.
        """
        layer_structure = OrderedDict()
        layer_structure["fc"] = nn.Linear(layer_ipt_dim, layer_opt_dim)

        if batch_norm_param is not None:
            layer_structure["bn"] = nn.BatchNorm1d(layer_opt_dim, **batch_norm_param)
        if activation is not None:
            layer_structure[activation] = activations_func[activation]()
        if dropout_prob != 0.0:
            layer_structure["drop"] = nn.Dropout(p=dropout_prob)

        return nn.Sequential(layer_structure)

    # Every layer consumes the output of the previous one; the first layer
    # consumes the network input.
    layer_input_dims = [input_dim, *layer_output_dims[:-1]]

    mlp_layers = OrderedDict()
    per_layer_settings = zip(
        layer_input_dims, layer_output_dims, batch_norms, activations, dropout_probs
    )
    for layer_id, layer_settings in enumerate(per_layer_settings, start=1):
        mlp_layers[layer_name_prefix + str(layer_id)] = build_one_layer(
            *layer_settings
        )

    return nn.Sequential(mlp_layers)
108
+
109
+
110
class Model:
    """
    Factory for the pre-defined Multi-layer perceptron (MLP) networks.

    The networks that can be requested by name are:
        - linear_mlp, a single linear layer without hidden layers.
        - simclr_projection_mlp, The projection layer of SimCLR method.
        - simsiam_projection_mlp, The projection layer of SimSiam method.
        - simsiam_prediction_mlp, The prediction layer of SimSiam method.
        - byol_projection_mlp, The projection layer of BYOL method.
        - byol_prediction_mlp, The prediction layer of BYOL method.
        - moco_final_mlp, The final layer of MoCo method.
        - plato_multilayer, The Plato's multilayer.
        - customized_mlp, a network fully described by the keyword arguments.
    """

    # pylint: disable=too-few-public-methods
    @staticmethod
    def get(
        model_name: str,
        input_dim: int,
        output_dim: int,
        **kwargs: Dict[str, Union[int, List[Union[str, None, dict]]]],
    ):
        # pylint:disable=too-many-return-statements
        """Return the MLP registered under `model_name`.

        :param model_name: One of the names listed in the class docstring.
        :param input_dim: The number of input features of the network.
        :param output_dim: The number of output features of the network.
        :param kwargs: Extra hyper-parameters. The projection networks read
            `projection_hidden_dim`, the prediction networks read
            `prediction_hidden_dim`, and `customized_mlp` forwards every
            keyword argument as part of the configuration.
        :return: The network built by `build_mlp_from_config`.
        :raises KeyError: If a hyper-parameter required by the requested
            network is missing from `kwargs`.
        :raises ValueError: If `model_name` is not a known network.
        """

        def assemble(hidden_dims, batch_norms, activations):
            """Build a dropout-free MLP from the given per-layer settings."""
            return build_mlp_from_config(
                {
                    "output_dim": output_dim,
                    "input_dim": input_dim,
                    "hidden_layers_dim": hidden_dims,
                    "batch_norms": batch_norms,
                    "activations": activations,
                    "dropout_ratios": [0.0] * len(activations),
                }
            )

        def default_bn():
            """Return a fresh copy of the batch-norm settings shared by all nets."""
            return {"momentum": 0.1, "eps": 1e-5}

        if model_name == "linear_mlp":
            return assemble([], [None], [None])

        # SimCLR and MoCo share the same structure: one hidden layer with a
        # ReLU and no batch normalization.
        if model_name in ("simclr_projection_mlp", "moco_final_mlp"):
            return assemble(
                [kwargs["projection_hidden_dim"]], [None, None], ["relu", None]
            )

        if model_name == "simsiam_projection_mlp":
            hidden_width = kwargs["projection_hidden_dim"]
            return assemble(
                [hidden_width, hidden_width],
                [default_bn() for _ in range(3)],
                ["relu", "relu", None],
            )

        # The remaining projection/prediction heads all use one hidden layer
        # with batch normalization and a ReLU; they only differ in which
        # keyword argument provides the hidden width.
        if model_name in ("simsiam_prediction_mlp", "byol_prediction_mlp"):
            return assemble(
                [kwargs["prediction_hidden_dim"]],
                [default_bn(), None],
                ["relu", None],
            )

        if model_name == "byol_projection_mlp":
            return assemble(
                [kwargs["projection_hidden_dim"]],
                [default_bn(), None],
                ["relu", None],
            )

        if model_name == "plato_multilayer":
            return assemble(
                [1024, 512, 256, 128],
                [None] * 5,
                ["tanh"] * 4 + [None],
            )

        # The customized network takes its whole definition from `kwargs`,
        # so the caller must supply every remaining configuration entry.
        if model_name == "customized_mlp":
            return build_mlp_from_config(
                dict(output_dim=output_dim, input_dim=input_dim, **kwargs)
            )

        raise ValueError(f"No such MLP model: {model_name}")
@@ -0,0 +1,27 @@
1
+ """
2
+ Obtaining a causal language model from HuggingFace.
3
+ """
4
+
5
+ from transformers import AutoModelForCausalLM, AutoConfig
6
+ from plato.config import Config
7
+
8
+
9
class Model:
    """Factory for causal language models loaded from HuggingFace."""

    @staticmethod
    def get(model_name=None, **kwargs):  # pylint: disable=unused-argument
        """Load the pretrained causal language model called `model_name`.

        The model configuration is fetched first (revision "main", default
        cache location, no auth token) and then handed to the model loader,
        which caches its files under the "huggingface" sub-directory of the
        configured model path.

        :param model_name: The identifier passed to `from_pretrained`.
        :param kwargs: Accepted for interface compatibility; not used.
        :return: The model returned by `AutoModelForCausalLM.from_pretrained`.
        """
        model_config = AutoConfig.from_pretrained(
            model_name,
            cache_dir=None,
            revision="main",
            use_auth_token=None,
        )

        weights_cache_dir = Config().params["model_path"] + "/huggingface"

        return AutoModelForCausalLM.from_pretrained(
            model_name,
            config=model_config,
            cache_dir=weights_cache_dir,
        )
plato/models/lenet5.py ADDED
@@ -0,0 +1,113 @@
1
+ """The LeNet-5 model for PyTorch.
2
+
3
+ Reference:
4
+
5
+ Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to
6
+ document recognition." Proceedings of the IEEE, November 1998.
7
+ """
8
+
9
+ import collections
10
+
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+
15
class Model(nn.Module):
    """The LeNet-5 model.

    :param num_classes: The number of classes. The default value is 10.
    :param cut_layer: The name of the layer at which the model is split, or
        None to always run the whole network. When it is set and the model is
        in training mode, `forward` only runs the layers after the cut layer.
    """

    def __init__(self, num_classes: int = 10, cut_layer=None):
        super().__init__()
        self.cut_layer = cut_layer

        # The first convolution pads its input by 2 so that the image is
        # treated as a 32x32 input, as in the original network of the LeCun
        # paper.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2, bias=True)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, bias=True)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(16, 120, kernel_size=5, bias=True)
        self.relu3 = nn.ReLU()

        self.fc4 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc5 = nn.Linear(84, num_classes)

        # Named layers in execution order, so that the model can be split and
        # straddle across the client and the server. "flatten" refers to the
        # method of this class rather than to a torch module.
        layer_names = (
            "conv1",
            "relu1",
            "pool1",
            "conv2",
            "relu2",
            "pool2",
            "conv3",
            "relu3",
            "flatten",
            "fc4",
            "relu4",
            "fc5",
        )
        self.layers = list(layer_names)
        self.layerdict = collections.OrderedDict(
            (name, getattr(self, name)) for name in layer_names
        )

    def flatten(self, x):
        """Collapse every dimension after the first into a single one."""
        batch_size = x.size(0)
        return x.view(batch_size, -1)

    def forward(self, x):
        """Forward pass returning log-probabilities over the classes.

        Without a cut layer, or outside training mode, `x` goes through the
        whole network. Otherwise `x` is taken to be the output of the cut
        layer and only the layers after it are applied.
        """
        if self.cut_layer is None or not self.training:
            out = self.pool1(self.relu1(self.conv1(x)))
            out = self.pool2(self.relu2(self.conv2(out)))
            out = self.relu3(self.conv3(out))
            out = self.relu4(self.fc4(self.flatten(out)))
            out = self.fc5(out)
        else:
            resume_at = self.layers.index(self.cut_layer) + 1
            out = x
            for name in self.layers[resume_at:]:
                out = self.layerdict[name](out)

        return F.log_softmax(out, dim=1)

    def forward_to(self, x):
        """Forward pass, but only to the layer specified by cut_layer."""
        stop_at = self.layers.index(self.cut_layer) + 1

        for name in self.layers[:stop_at]:
            x = self.layerdict[name](x)

        return x
@@ -0,0 +1,90 @@
1
+ """
2
+ The Multi-Layer Perceptron model for PyTorch.
3
+ The model follows the previous work to use tanh as activation
4
+ Reference: https://www.comp.nus.edu.sg/~reza/files/Shokri-SP2019.pdf
5
+ """
6
+
7
+ import collections
8
+
9
+ import torch.nn as nn
10
+
11
+ from plato.config import Config
12
+
13
+
14
class Model(nn.Module):
    """The Multi-Layer Perceptron model.

    Four hidden linear layers (1024, 512, 256 and 128 units), each followed
    by a tanh activation, and a final linear layer producing the class scores.

    Arguments:
        input_dim (int): The number of input features. Default: 600.
        num_classes (int): The number of classes. Default: 10.
    """

    def __init__(self, input_dim=600, num_classes=10):
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(input_dim, 1024), nn.Tanh())
        self.fc2 = nn.Sequential(nn.Linear(1024, 512), nn.Tanh())
        self.fc3 = nn.Sequential(nn.Linear(512, 256), nn.Tanh())
        self.fc4 = nn.Sequential(nn.Linear(256, 128), nn.Tanh())
        self.fc5 = nn.Linear(128, num_classes)

        # Preparing named layers so that the model can be split and straddle
        # across the client and the server
        self.layers = ["fc1", "fc2", "fc3", "fc4", "fc5"]
        self.layerdict = collections.OrderedDict(
            (name, getattr(self, name)) for name in self.layers
        )

    def forward(self, x):
        """Forward pass through all five layers; returns the raw class scores."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        return x

    def forward_to(self, x, cut_layer):
        """Forward pass, but only to the layer specified by cut_layer.

        The cut layer itself is applied. Raises ValueError if `cut_layer` is
        not the name of one of the layers.
        """
        layer_index = self.layers.index(cut_layer)

        for name in self.layers[: layer_index + 1]:
            x = self.layerdict[name](x)

        return x

    def forward_from(self, x, cut_layer):
        """Forward pass, starting from the layer specified by cut_layer.

        The cut layer itself is skipped: `x` is treated as its output. Raises
        ValueError if `cut_layer` is not the name of one of the layers.
        """
        layer_index = self.layers.index(cut_layer)

        for name in self.layers[layer_index + 1 :]:
            x = self.layerdict[name](x)

        return x

    @staticmethod
    def get_model(*args):
        """Obtaining an instance of this model.

        When the trainer configuration defines `num_classes`, the model is
        built from the configuration; otherwise the default model is returned.
        `input_dim` is taken from the configuration only when it is defined
        there, so a configuration that sets `num_classes` alone falls back to
        the default input dimension instead of failing with AttributeError.
        """
        trainer_config = Config().trainer

        if not hasattr(trainer_config, "num_classes"):
            return Model()

        overrides = {"num_classes": trainer_config.num_classes}
        if hasattr(trainer_config, "input_dim"):
            overrides["input_dim"] = trainer_config.input_dim

        return Model(**overrides)
File without changes
@@ -0,0 +1,91 @@
1
+ """
2
+ The base class used for all following classes
3
+
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from mmaction.models import build_model
10
+
11
+
12
class BaseClassificationNet(nn.Module):
    """Base class for classification networks.

    Wraps a network built by `build_model` and exposes separate training and
    testing forward passes that return the extracted features and, when the
    classification head is included, the class scores (and the loss during
    training).
    """

    def __init__(self, net_configs, is_head_included=True):
        """Build the wrapped network.

        :param net_configs: The configuration handed to `build_model`.
        :param is_head_included: Whether the classification head of the built
            network is applied in the forward passes.
        """
        super().__init__()

        self.net_configs = net_configs
        # 1 build the model based on the configurations
        self._net = build_model(net_configs)

        self.is_head_included = is_head_included

        # the features must be forwarded the avg pool
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def get_net(self):
        """Get the built network"""
        return self._net

    def forward_train(self, ipt_data, labels, **kwargs):
        """Defines the computation performed at every call when training.

        :param ipt_data: The input tensor; its first two dimensions are merged
            into one before the backbone is applied.
        :param labels: The ground-truth labels; only used when the
            classification head is included.
        :param kwargs: Forwarded to the loss function of the head.
        :return: `[features]` without the head, otherwise
            `[features, cls_score, loss_cls]`, where `features` are the
            average-pooled backbone features with one row per merged sample.
        """
        # merge the first two dimensions so that every segment is processed
        # as an independent sample
        merged_ipt_data = ipt_data.reshape((-1,) + ipt_data.shape[2:])

        # 1. forward the backbone
        data_feat = self._net.extract_feat(merged_ipt_data)
        # pool the spatial dimensions down to 1 x 1 and drop them. Flattening
        # from dimension 1 keeps the sample dimension even when there is a
        # single sample, whereas a plain squeeze would remove it as well.
        immediate_feat = torch.flatten(self.avg_pool(data_feat), start_dim=1)

        if not self.is_head_included:
            return [immediate_feat]

        # 2. forward the classification head and obtain the losses
        cls_score = self._net.cls_head(data_feat)
        gt_labels = labels.squeeze()
        loss_cls = self._net.cls_head.loss(cls_score, gt_labels, **kwargs)

        return [immediate_feat, cls_score, loss_cls]

    def forward_test(self, ipt_data, **kwargs):
        """Defines the computation performed at every call when testing.

        :param ipt_data: The input tensor; its first two dimensions are merged
            into one before the backbone is applied.
        :param kwargs: Accepted for interface compatibility; not used.
        :return: `[data_feat]` without the head, otherwise
            `[data_feat, cls_score]`, where `data_feat` are the backbone
            features (not pooled).
        """
        ipt_data = ipt_data.reshape((-1,) + ipt_data.shape[2:])
        # 1. forward the backbone
        data_feat = self._net.extract_feat(ipt_data)

        if not self.is_head_included:
            return [data_feat]

        # 2. forward the classification head
        cls_score = self._net.cls_head(data_feat)

        return [data_feat, cls_score]

    def forward(self, ipt_data, label=None, return_loss=True, **kwargs):
        """Defines the computation performed at every call.

        Args:
            ipt_data (torch.Tensor): The input data.
            label: The ground-truth labels; required when `return_loss` is
                True.
            return_loss (bool): Run the training pass when True, otherwise
                the testing pass.
        Returns:
            list: The output of `forward_train` or `forward_test`.
        Raises:
            ValueError: If `return_loss` is True and `label` is None.
        """

        if not return_loss:
            return self.forward_test(ipt_data, **kwargs)

        if label is None:
            raise ValueError("Label should not be None.")

        # apply the blending of the wrapped network, when it defines one,
        # to both the inputs and the labels
        if self._net.blending is not None:
            blended_ipt_data, label = self._net.blending(ipt_data, label)
        else:
            blended_ipt_data = ipt_data

        return self.forward_train(blended_ipt_data, label, **kwargs)