plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/models/resnet.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ The ResNet model (for the CIFAR-10 dataset only).
3
+
4
+ Reference:
5
+
6
+ https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
7
+ """
8
+
9
+ import collections
10
+
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ResNet-18/34).

    The shortcut is the identity unless the block changes the spatial size
    (``stride != 1``) or the channel count, in which case a strided 1x1
    convolution followed by batch norm projects the input to match.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Submodules are created in this exact order so that parameter
        # initialization and state_dict key order stay stable.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + self.shortcut(x))
48
+
49
+
50
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (used by ResNet-50/101/152).

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity unless the stride or the channel count changes, in which case a
    strided 1x1 convolution plus batch norm projects the input.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Creation order is kept fixed: it determines parameter initialization
        # order and state_dict key order.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the three-convolution branch and add the shortcut before the final ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        residual = self.bn3(self.conv3(hidden))
        return F.relu(residual + self.shortcut(x))
86
+
87
+
88
class Model(nn.Module):
    """ResNet for 32x32 images (e.g. CIFAR-10) that can be split at a named layer.

    The stem is a stride-1 3x3 convolution and only ``layer2``-``layer4``
    downsample (stride 2 each), so a 32x32 input reaches the head as a 4x4
    map; the fixed 4x4 average pool then yields ``512 * block.expansion``
    features for ``self.linear``.

    Args:
        block: Residual block class exposing an ``expansion`` class attribute
            and an ``(in_planes, planes, stride)`` constructor.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Number of output logits.
        cut_layer: Name from ``self.layers`` where ``forward_to`` stops and
            ``forward_from`` resumes (used to split the model between client
            and server).
    """

    def __init__(self, block, num_blocks, num_classes=10, cut_layer=None):
        super().__init__()

        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # Named, ordered layers so that the model can be split and straddle
        # across the client and the server.
        self.layerdict = collections.OrderedDict(
            [
                ("conv1", self.conv1),
                ("bn1", self.bn1),
                ("relu", F.relu),
                ("layer1", self.layer1),
                ("layer2", self.layer2),
                ("layer3", self.layer3),
                ("layer4", self.layer4),
            ]
        )
        self.layers = list(self.layerdict)
        self.cut_layer = cut_layer

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            # Each block's output width becomes the next block's input width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def _classify(self, features):
        """Average-pool the last feature map (4x4 kernel), flatten, and apply the linear head."""
        out = F.avg_pool2d(features, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

    def forward(self, x):
        """Full forward pass returning class logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return self._classify(out)

    def forward_to(self, x):
        """Forward pass through ``self.layers`` up to and including ``cut_layer``.

        Raises:
            ValueError: if ``cut_layer`` is not one of ``self.layers``.
        """
        layer_index = self.layers.index(self.cut_layer)
        for name in self.layers[: layer_index + 1]:
            x = self.layerdict[name](x)
        return x

    def forward_from(self, x):
        """Forward pass starting after ``cut_layer``, through the classification head.

        Raises:
            ValueError: if ``cut_layer`` is not one of ``self.layers``.
        """
        layer_index = self.layers.index(self.cut_layer)
        for name in self.layers[layer_index + 1 :]:
            x = self.layerdict[name](x)
        return self._classify(x)

    @staticmethod
    def is_valid_model_type(model_type):
        """Return True for names of the form ``resnet_<depth>`` with a supported depth.

        Malformed names (non-strings, non-numeric depths) return False rather
        than raising.
        """
        if not isinstance(model_type, str) or not model_type.startswith("resnet_"):
            return False
        parts = model_type.split("_")
        if len(parts) != 2:
            return False
        try:
            depth = int(parts[1])
        except ValueError:
            return False
        return depth in (18, 34, 50, 101, 152)

    @staticmethod
    def get(model_name=None, num_classes=None, cut_layer=None, **kwargs):
        """Returns a suitable ResNet model according to its type.

        Args:
            model_name: ``resnet_18``, ``resnet_34``, ``resnet_50``,
                ``resnet_101`` or ``resnet_152``.
            num_classes: Number of output logits; defaults to 10 when None.
            cut_layer: Split point forwarded to the model.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            ValueError: if ``model_name`` is not a supported ResNet name.
        """
        if not Model.is_valid_model_type(model_name):
            raise ValueError(f"Invalid Resnet model name: {model_name}")

        if num_classes is None:
            num_classes = 10

        architectures = {
            18: (BasicBlock, [2, 2, 2, 2]),
            34: (BasicBlock, [3, 4, 6, 3]),
            50: (Bottleneck, [3, 4, 6, 3]),
            101: (Bottleneck, [3, 4, 23, 3]),
            152: (Bottleneck, [3, 8, 36, 3]),
        }
        block, num_blocks = architectures[int(model_name.split("_")[1])]
        return Model(block, num_blocks, num_classes, cut_layer)
@@ -0,0 +1,19 @@
1
+ """
2
+ Obtaining a model from the PyTorch Hub.
3
+ """
4
+
5
+ import torch
6
+
7
+
8
class Model:
    """Factory for models published in the ``pytorch/vision`` PyTorch Hub repository."""

    @staticmethod
    def get(model_name=None, **kwargs):
        """Load ``model_name`` from PyTorch Hub.

        Args:
            model_name: Entry-point name inside the ``pytorch/vision`` hub repo.
            **kwargs: Forwarded unchanged to ``torch.hub.load``.
        """
        repository = "pytorch/vision"
        return torch.hub.load(repository, model_name, **kwargs)
plato/models/vgg.py ADDED
@@ -0,0 +1,113 @@
1
+ """
2
+ A VGG-style neural network model for image classification.
3
+ """
4
+
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class Model(nn.Module):
    """A VGG-style neural network model for image classification.

    Every plan ends with 512 channels after four 2x2 max-pools, and the head
    applies a 2x2 average pool before ``nn.Linear(512, outputs)``. The flattened
    features therefore match the linear layer only when the last feature map is
    2x2, e.g. for 32x32 inputs.
    """

    # Layer plans per depth: an int is a 3x3 conv with that many output
    # filters, "M" is a 2x2 max-pool.
    _PLANS = {
        11: [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
        13: [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
        16: [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512],
        19: [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512],
    }

    class ConvModule(nn.Module):
        """A single convolutional module in a VGG network: conv -> batch norm -> ReLU."""

        def __init__(self, in_filters, out_filters):
            super().__init__()
            self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(out_filters)

        def forward(self, x):
            return F.relu(self.bn(self.conv(x)))

    def __init__(self, plan, outputs=10):
        """Build the feature extractor from ``plan`` and a linear head with ``outputs`` logits."""
        super().__init__()

        layers = []
        filters = 3

        for spec in plan:
            if spec == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(Model.ConvModule(filters, spec))
                filters = spec

        self.layers = nn.Sequential(*layers)
        self.fc = nn.Linear(512, outputs)

    def forward(self, x):
        x = self.layers(x)
        # Functional pooling avoids constructing a new nn.AvgPool2d module on
        # every call; with kernel 2 the default stride is also 2, as before.
        x = F.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def is_valid_model_name(model_name):
        """Return True for names like ``vgg_16`` whose depth is one of 11, 13, 16, 19.

        Non-string names and non-decimal depths return False instead of raising.
        """
        if not isinstance(model_name, str) or not model_name.startswith("vgg"):
            return False
        parts = model_name.split("_")
        return (
            len(parts) == 2
            and parts[1].isdecimal()
            and int(parts[1]) in Model._PLANS
        )

    @staticmethod
    def get(model_name, num_classes=10, **kwargs):  # pylint: disable=unused-argument
        """Returns a suitable VGG model corresponding to the provided name.

        Args:
            model_name: ``vgg_11``, ``vgg_13``, ``vgg_16`` or ``vgg_19``.
            num_classes: Number of output logits; None falls back to 10.
            **kwargs: Accepted for consistency with the other model factories
                and ignored.

        Raises:
            ValueError: if ``model_name`` is not a supported VGG name.
        """
        if not Model.is_valid_model_name(model_name):
            raise ValueError(f"Invalid VGG model name: {model_name}")

        outputs = 10 if num_classes is None else num_classes
        plan = Model._PLANS[int(model_name.split("_")[1])]
        return Model(plan, outputs)
plato/models/vit.py ADDED
@@ -0,0 +1,166 @@
1
+ """
2
+ Obtaining a Vision Transformer (ViT) model for image classification from HuggingFace.
3
+
4
+ Reference to the Tokens-to-Token ViT (T2T-ViT) model:
5
+ https://github.com/yitu-opensource/T2T-ViT
6
+
7
+ Reference to the Deep Vision Transformer (DeepViT) model:
8
+ https://github.com/zhoudaquan/dvit_repo
9
+
10
+ """
11
+
12
+ import torch
13
+ from torch import nn
14
+ from transformers import AutoConfig, AutoModelForImageClassification
15
+
16
+
17
+ from plato.config import Config
18
+
19
+
20
class ResolutionAdjustedModel(nn.Module):
    """
    Transforms the image resolution to the assigned resolution of a pretrained model.

    Wraps an ``AutoModelForImageClassification`` loaded from ``model_name`` and
    returns its logits.
    """

    def __init__(self, model_name, config) -> None:
        super().__init__()

        # True only when the configuration explicitly sets
        # ``model.pretrained`` to a falsy value.
        train_from_scratch = (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and not Config().parameters.model.pretrained
        )

        self.model = AutoModelForImageClassification.from_pretrained(
            model_name,
            config=config,
            # Tolerate checkpoint weights whose shapes differ from ``config``
            # when the weights are re-initialized below anyway.
            ignore_mismatched_sizes=train_from_scratch,
            cache_dir=Config().params["model_path"] + "/huggingface",
        )

        if train_from_scratch:
            self.model.init_weights()
        self.resolution = config.image_size

    def forward(self, image):
        """
        Adjusts the image resolution and outputs the logits.

        Both spatial dimensions (the last two) are compared with the expected
        resolution; checking only the width would pass non-square inputs
        through unresized.
        """
        if isinstance(self.resolution, (tuple, list)):
            target_size = tuple(self.resolution)
        else:
            target_size = (self.resolution, self.resolution)

        if tuple(image.shape[-2:]) != target_size:
            image = nn.functional.interpolate(image, size=target_size, mode="bicubic")
        outputs = self.model(image)
        return outputs.logits
62
+
63
+
64
class T2TVIT(nn.Module):
    """Wrapper for the T2T-ViT model.

    Builds the architecture named ``name`` from ``plato.models.t2tvit``. When
    the configuration enables ``model.pretrained``, it loads weights from
    ``model.pretrain_path`` for transfer learning. Inputs are resized to
    224x224 before the forward pass.
    """

    def __init__(self, name) -> None:
        super().__init__()
        # pylint:disable=import-outside-toplevel
        from plato.models import t2tvit
        from plato.models.t2tvit.models import t2t_vit

        factory = getattr(t2t_vit, name)
        t2t = factory(num_classes=Config().trainer.num_classes)

        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and Config().parameters.model.pretrained
        ):
            t2tvit.utils.load_for_transfer_learning(
                t2t,
                Config().parameters.model.pretrain_path,
                use_ema=True,
                strict=False,
                num_classes=Config().trainer.num_classes,
            )
        self.model = t2t
        self.resolution = 224

    def forward(self, feature):
        """The forward pass; resizes the last two (spatial) dims to 224x224 if needed."""
        target_size = (self.resolution, self.resolution)
        # Compare both spatial dims, not only the width.
        if tuple(feature.shape[-2:]) != target_size:
            feature = nn.functional.interpolate(
                feature, size=target_size, mode="bicubic"
            )
        return self.model(feature)
98
+
99
+
100
class DeepViT(nn.Module):
    """Wrapper for the DeepViT model.

    Builds the architecture named ``name`` from
    ``plato.models.dvit.models.deep_vision_transformer`` with
    ``Config().trainer.num_classes`` outputs. When ``model.pretrained`` is
    enabled, it loads every checkpoint weight except the classification head.
    """

    _HEAD_KEYS = ("head.weight", "head.bias")

    def __init__(self, name) -> None:
        super().__init__()
        # pylint:disable=import-outside-toplevel
        from plato.models.dvit.models import deep_vision_transformer

        factory = getattr(deep_vision_transformer, name)
        # ``Config()`` (the instance), not the ``Config`` class, carries the
        # loaded ``trainer`` section -- matching T2TVIT above.
        deepvit = factory(
            pretrained=False, num_classes=Config().trainer.num_classes, in_chans=3
        )
        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and Config().parameters.model.pretrained
        ):
            state_dict = torch.load(
                Config().parameters.model.pretrain_path, map_location="cpu"
            )
            # Drop the checkpoint's classification head so the freshly built
            # head (sized for num_classes) is kept.
            for key in self._HEAD_KEYS:
                state_dict.pop(key, None)
            # A strict load would always fail on the head keys removed above;
            # load non-strictly but still reject any other mismatch.
            result = deepvit.load_state_dict(state_dict, strict=False)
            missing = set(result.missing_keys) - set(self._HEAD_KEYS)
            if missing or result.unexpected_keys:
                raise RuntimeError(
                    "Incompatible DeepViT checkpoint: "
                    f"missing keys {sorted(missing)}, "
                    f"unexpected keys {list(result.unexpected_keys)}"
                )
        self.model = deepvit
        self.resolution = 224

    def forward(self, feature):
        """The forward pass; resizes the last two (spatial) dims to 224x224 if needed."""
        target_size = (self.resolution, self.resolution)
        if tuple(feature.shape[-2:]) != target_size:
            feature = nn.functional.interpolate(
                feature, size=target_size, mode="bicubic"
            )

        return self.model(feature)
134
+
135
+
136
class Model:
    """
    Factory for Vision Transformer models.

    Names containing ``t2t`` or ``deepvit`` map to the local T2T-ViT and
    DeepViT wrappers. Any other name is treated as a HuggingFace Hub id
    (``@`` stands in for ``/``; bare names are placed under ``google/``):
    https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModel
    """

    # pylint:disable=too-few-public-methods
    @staticmethod
    def get(model_name=None, **kwargs):  # pylint: disable=unused-argument
        """Returns a named model from HuggingFace (or a local ViT variant)."""
        if "t2t" in model_name:
            return T2TVIT(model_name)
        if "deepvit" in model_name:
            return DeepViT(model_name)

        # Caller-supplied kwargs override these defaults.
        config_kwargs = {
            "cache_dir": None,
            "revision": "main",
            "use_auth_token": None,
            **kwargs,
        }

        hub_id = model_name.replace("@", "/")
        if "/" not in hub_id:
            hub_id = "google/" + hub_id

        return ResolutionAdjustedModel(
            hub_id, AutoConfig.from_pretrained(hub_id, **config_kwargs)
        )
plato/models/yolov8.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Obtaining a model from the Ultralytics.
3
+ """
4
+
5
+ from ultralytics import YOLO
6
+
7
+ from plato.config import Config
8
+
9
+
10
class Model:
    """Factory for Ultralytics YOLOv8 models."""

    @staticmethod
    def get(model_name=None, **kwargs):  # pylint: disable=unused-argument
        """Returns the YOLOV8 model loaded from the Ultralytics.

        The model specification is taken from ``Config().parameters.model.type``;
        ``model_name`` and any extra keyword arguments are not used.
        """
        return YOLO(Config().parameters.model.type)
File without changes
@@ -0,0 +1,35 @@
1
+ """
2
+ The Processor class is designed for pre-processing data payloads before or after they
3
+ are transmitted over the network between the clients and the servers.
4
+ """
5
+
6
+ from abc import abstractmethod
7
+ from collections.abc import Iterable
8
+ from typing import Any
9
+
10
+
11
class Processor:
    """
    The base Processor class does nothing on the data payload.

    Subclasses override :meth:`process`; the default implementation returns
    the payload unchanged.
    """

    def __init__(self, name=None, trainer=None, **kwargs) -> None:
        """Constructor for Processor.

        Args:
            name: Display name returned by ``repr()``; defaults to the class name.
            trainer: Trainer object that subclasses may consult.
            **kwargs: Accepted and ignored, so subclasses can forward their own
                ``**kwargs`` here.
        """
        self.name = name
        self.trainer = trainer

    # NOTE: Processor does not use ABCMeta, so this decorator is not enforced:
    # the base class stays instantiable and acts as a pass-through.
    @abstractmethod
    def process(self, data: Any) -> Any:
        """
        Processing a data payload. The base implementation returns it unchanged.
        """
        return data

    def process_iterable(self, data: Iterable) -> Iterable:
        """
        Processing an Iterable of data payload.

        Returns a lazy iterator that applies :meth:`process` to each item.
        """
        return map(self.process, data)

    def __repr__(self) -> str:
        # ``name`` defaults to None, but __repr__ must return a str.
        if self.name is None:
            return type(self).__name__
        return str(self.name)
@@ -0,0 +1,46 @@
1
+ """
2
+ Implements a Processor for compressing a numpy array.
3
+ """
4
+
5
+ from typing import Any
6
+
7
+ import zstd
8
+
9
+ from plato.processors import base
10
+
11
+
12
class Processor(base.Processor):
    """Implements a Processor for compressing numpy array.

    Output format:
      * single array -> ``(shape, dtype, compressed_bytes)``
      * list of ``(features, targets)`` pairs -> a list whose first entry is
        ``(shape, dtype)`` of the FIRST pair's features, followed by one
        ``(compressed_features, compressed_targets, targets.shape,
        targets.dtype)`` tuple per pair.
    """

    def __init__(self, cr=1, **kwargs) -> None:
        super().__init__(**kwargs)
        # Passed as the second argument to zstd.compress (assumed to be the
        # zstd compression level -- confirm against the zstd package).
        self.compression_ratio = cr

    def process(self, data: Any) -> Any:
        """Implements a Processor for compressing numpy array."""
        if not isinstance(data, list):
            return (
                data.shape,
                data.dtype,
                zstd.compress(data, self.compression_ratio),
            )

        first_features = data[0][0]
        packed = [(first_features.shape, first_features.dtype)]
        for features, targets in data:
            packed.append(
                (
                    zstd.compress(features, self.compression_ratio),
                    zstd.compress(targets, self.compression_ratio),
                    targets.shape,
                    targets.dtype,
                )
            )
        return packed
@@ -0,0 +1,48 @@
1
+ """
2
+ Implements a Processor for decompressing a numpy array.
3
+ """
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import zstd
9
+
10
+ from plato.processors import base
11
+
12
+
13
class Processor(base.Processor):
    """Implements a Processor for decompressing a numpy array.

    Accepts either ``(shape, dtype, compressed_bytes)`` for a single array, or
    a list whose first entry is ``(feature_shape, feature_dtype)`` followed by
    ``(compressed_features, compressed_targets, target_shape, target_dtype)``
    tuples. The list form returns ``(features, targets)`` array pairs.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def _restore_features(buffer, dtype, shape):
        """Rebuild a feature array from raw bytes, inferring the leading dim if sizes differ."""
        array = np.frombuffer(buffer, dtype)
        shape = tuple(shape)
        if len(shape) == 0 or array.size == int(np.prod(shape)):
            return array.reshape(shape)
        # The header carries a single feature shape for all entries; an entry
        # with a different leading size (e.g. a final partial batch) would
        # fail a fixed reshape, so keep the trailing dims and infer the first.
        return array.reshape((-1,) + shape[1:])

    def process(self, data: Any) -> Any:
        """Implements a Processor for decompressing a numpy array."""
        if isinstance(data, list):
            ret = []
            datashape_feature = data[0][0]
            datatype_feature = data[0][1]
            for (
                datacom_feature,
                datacom_target,
                datashape_target,
                datatype_target,
            ) in data[1:]:
                features = self._restore_features(
                    zstd.decompress(datacom_feature),
                    datatype_feature,
                    datashape_feature,
                )
                if len(datashape_target) > 0 and datashape_target[0] == 0:
                    # Empty targets are rebuilt directly, keeping the recorded dtype.
                    targets = np.zeros(datashape_target, dtype=datatype_target)
                else:
                    targets = np.frombuffer(
                        zstd.decompress(datacom_target), datatype_target
                    ).reshape(datashape_target)
                ret.append((features, targets))
            return ret

        shape, dtype, modelcom = data
        modelcom = zstd.decompress(modelcom)
        return np.frombuffer(modelcom, dtype).reshape(shape)
@@ -0,0 +1,51 @@
1
+ """
2
+ Implements a generalized Processor for applying operations onto MistNet PyTorch features.
3
+ """
4
+
5
+ from typing import Any, Callable
6
+
7
+ import torch
8
+
9
+ from plato.processors import base
10
+
11
+
12
class Processor(base.Processor):
    """
    Implements a generalized Processor for applying operations onto MistNet PyTorch features.

    Each ``(logits, targets)`` pair in the payload is passed through
    ``method(logits, targets)``. With ``use_numpy`` the logits are handed to
    ``method`` as a numpy array and converted back to a tensor afterwards.
    """

    def __init__(
        self,
        method: Callable = lambda x, y: (x, y),
        client_id=None,
        use_numpy=True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.client_id = client_id
        self.method = method
        self.use_numpy = use_numpy

    def process(self, data: Any) -> Any:
        """
        Implements a generalized Processor for applying operations onto MistNet PyTorch features.

        Returns a list of ``(logits, targets)`` pairs produced by ``method``.
        """
        output = []

        # Dtype used when turning numpy logits back into tensors: float16 when
        # the trainer's device is not "cpu" (presumably to shrink the payload),
        # float32 otherwise. A missing trainer or device is treated as CPU.
        device = getattr(self.trainer, "device", "cpu")
        restore_dtype = "float16" if device != "cpu" else "float32"

        for logits, targets in data:
            if self.use_numpy:
                # .cpu() is a no-op for CPU tensors and lets .numpy() work for
                # tensors that still live on an accelerator.
                logits = logits.detach().cpu().numpy()

            logits, targets = self.method(logits, targets)

            if self.use_numpy:
                logits = torch.from_numpy(logits.astype(restore_dtype))

            output.append((logits, targets))

        return output