plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/models/resnet.py
ADDED
@@ -0,0 +1,190 @@
"""
The ResNet model (for the CIFAR-10 dataset only).

Reference:

https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
"""

import collections

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Model(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, cut_layer=None):
        super().__init__()

        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # Preparing named layers so that the model can be split and straddle
        # across the client and the server
        self.layers = []
        self.layerdict = collections.OrderedDict()
        self.layerdict["conv1"] = self.conv1
        self.layerdict["bn1"] = self.bn1
        self.layerdict["relu"] = F.relu
        self.layerdict["layer1"] = self.layer1
        self.layerdict["layer2"] = self.layer2
        self.layerdict["layer3"] = self.layer3
        self.layerdict["layer4"] = self.layer4
        self.layers.append("conv1")
        self.layers.append("bn1")
        self.layers.append("relu")
        self.layers.append("layer1")
        self.layers.append("layer2")
        self.layers.append("layer3")
        self.layers.append("layer4")
        self.cut_layer = cut_layer

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def forward_to(self, x):
        """Forward pass, but only to the layer specified by cut_layer."""
        layer_index = self.layers.index(self.cut_layer)

        for i in range(0, layer_index + 1):
            x = self.layerdict[self.layers[i]](x)
        return x

    def forward_from(self, x):
        """Forward pass, starting from the layer specified by cut_layer."""
        layer_index = self.layers.index(self.cut_layer)
        for i in range(layer_index + 1, len(self.layers)):
            x = self.layerdict[self.layers[i]](x)

        out = F.avg_pool2d(x, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    @staticmethod
    def is_valid_model_type(model_type):
        return (
            model_type.startswith("resnet_")
            and len(model_type.split("_")) == 2
            and int(model_type.split("_")[1]) in [18, 34, 50, 101, 152]
        )

    @staticmethod
    def get(model_name=None, num_classes=None, cut_layer=None, **kwargs):
        """Returns a suitable ResNet model according to its type."""
        if not Model.is_valid_model_type(model_name):
            raise ValueError(f"Invalid Resnet model name: {model_name}")

        resnet_type = int(model_name.split("_")[1])

        if num_classes is None:
            num_classes = 10

        if resnet_type == 18:
            return Model(BasicBlock, [2, 2, 2, 2], num_classes, cut_layer)
        elif resnet_type == 34:
            return Model(BasicBlock, [3, 4, 6, 3], num_classes, cut_layer)
        elif resnet_type == 50:
            return Model(Bottleneck, [3, 4, 6, 3], num_classes, cut_layer)
        elif resnet_type == 101:
            return Model(Bottleneck, [3, 4, 23, 3], num_classes, cut_layer)
        elif resnet_type == 152:
            return Model(Bottleneck, [3, 8, 36, 3], num_classes, cut_layer)

        return None
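The cut_layer plumbing above exists so the model can be partitioned for split learning: forward_to() runs the layers up to and including the cut layer on the client side, and forward_from() resumes from there on the server side. A minimal sketch, not part of the package; the batch size and the "layer2" cut point are arbitrary assumptions:

import torch
from plato.models import resnet

# Build a ResNet-18 split at "layer2".
model = resnet.Model.get(model_name="resnet_18", num_classes=10, cut_layer="layer2")

images = torch.randn(8, 3, 32, 32)     # a CIFAR-10-sized batch
features = model.forward_to(images)    # client side: conv1 .. layer2
logits = model.forward_from(features)  # server side: layer3 .. linear
assert logits.shape == (8, 10)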
plato/models/torch_hub.py
ADDED
@@ -0,0 +1,19 @@
"""
Obtaining a model from the PyTorch Hub.
"""

import torch


class Model:
    """
    The model loaded from PyTorch Hub.

    We will soon be using the get_model() method for torchvision 0.14 when it is released.
    """

    @staticmethod
    # pylint: disable=unused-argument
    def get(model_name=None, **kwargs):
        """Returns a named model from PyTorch Hub."""
        return torch.hub.load("pytorch/vision", model_name, **kwargs)
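Since get() simply forwards to torch.hub.load() against the pytorch/vision repository, any torchvision hub entrypoint name works. A hedged sketch, assuming network access on first use; "resnet18" and the weights tag are torchvision names, not Plato's:

from plato.models import torch_hub

# Downloads the entrypoint from the pytorch/vision GitHub repo on first call.
model = torch_hub.Model.get(model_name="resnet18", weights="IMAGENET1K_V1")
model.eval()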
plato/models/vgg.py
ADDED
@@ -0,0 +1,113 @@
"""
A VGG-style neural network model for image classification.
"""

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """A VGG-style neural network model for image classification."""

    class ConvModule(nn.Module):
        """A single convolutional module in a VGG network."""

        def __init__(self, in_filters, out_filters):
            super().__init__()
            self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(out_filters)

        def forward(self, x):
            return F.relu(self.bn(self.conv(x)))

    def __init__(self, plan, outputs=10):
        super().__init__()

        layers = []
        filters = 3

        for spec in plan:
            if spec == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(Model.ConvModule(filters, spec))
                filters = spec

        self.layers = nn.Sequential(*layers)
        self.fc = nn.Linear(512, outputs)

    def forward(self, x):
        x = self.layers(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    @staticmethod
    def is_valid_model_name(model_name):
        return (
            model_name.startswith("vgg")
            and len(model_name.split("_")) == 2
            and model_name.split("_")[1].isdigit()
            and int(model_name.split("_")[1]) in [11, 13, 16, 19]
        )

    @staticmethod
    def get(model_name, num_classes=10):
        """Returns a suitable VGG model corresponding to the provided name."""
        if not Model.is_valid_model_name(model_name):
            raise ValueError(f"Invalid VGG model name: {model_name}")

        outputs = num_classes

        num = int(model_name.split("_")[1])

        if num == 11:
            plan = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512]
        elif num == 13:
            plan = [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512]
        elif num == 16:
            plan = [
                64,
                64,
                "M",
                128,
                128,
                "M",
                256,
                256,
                256,
                "M",
                512,
                512,
                512,
                "M",
                512,
                512,
                512,
            ]
        elif num == 19:
            plan = [
                64,
                64,
                "M",
                128,
                128,
                "M",
                256,
                256,
                256,
                256,
                "M",
                512,
                512,
                512,
                512,
                "M",
                512,
                512,
                512,
                512,
            ]

        return Model(plan, outputs)
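The fixed nn.Linear(512, outputs) head assumes CIFAR-sized 3x32x32 inputs: the four "M" max-pooling stages halve the spatial size 32 -> 16 -> 8 -> 4 -> 2, and the closing nn.AvgPool2d(2) reduces the 2x2 map to 1x1 over 512 channels. A quick sketch verifying the shapes:

import torch
from plato.models import vgg

model = vgg.Model.get("vgg_16", num_classes=10)
logits = model(torch.randn(4, 3, 32, 32))
assert logits.shape == (4, 10)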
plato/models/vit.py
ADDED
@@ -0,0 +1,166 @@
"""
Obtaining a Vision Transformer (ViT) model for image classification from HuggingFace.

Reference to the Tokens-to-Token ViT (T2T-ViT) model:
https://github.com/yitu-opensource/T2T-ViT

Reference to the Deep Vision Transformer (DeepViT) model:
https://github.com/zhoudaquan/dvit_repo
"""

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForImageClassification

from plato.config import Config


class ResolutionAdjustedModel(nn.Module):
    """
    Transforms the image resolution to the assigned resolution of a pretrained model.
    """

    def __init__(self, model_name, config) -> nn.Module:
        super().__init__()

        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and not Config().parameters.model.pretrained
        ):
            ignore_mismatched_sizes = True
        else:
            ignore_mismatched_sizes = False

        self.model = AutoModelForImageClassification.from_pretrained(
            model_name,
            config=config,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            cache_dir=Config().params["model_path"] + "/huggingface",
        )

        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and not Config().parameters.model.pretrained
        ):
            self.model.init_weights()
        self.resolution = config.image_size

    def forward(self, image):
        """
        Adjusts the image resolution and outputs the logits.
        """
        if image.size(-1) != self.resolution:
            image = nn.functional.interpolate(
                image, size=self.resolution, mode="bicubic"
            )
        outputs = self.model(image)
        return outputs.logits


class T2TVIT(nn.Module):
    """Wrapper for the T2T-ViT model."""

    def __init__(self, name) -> nn.Module:
        super().__init__()
        # pylint:disable=import-outside-toplevel
        from plato.models import t2tvit
        from plato.models.t2tvit.models import t2t_vit

        model_name = getattr(t2t_vit, name)
        t2t = model_name(num_classes=Config().trainer.num_classes)

        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and Config().parameters.model.pretrained
        ):
            t2tvit.utils.load_for_transfer_learning(
                t2t,
                Config().parameters.model.pretrain_path,
                use_ema=True,
                strict=False,
                num_classes=Config().trainer.num_classes,
            )
        self.model = t2t
        self.resolution = 224

    def forward(self, feature):
        """The forward pass."""
        if feature.size(-1) != self.resolution:
            feature = nn.functional.interpolate(
                feature, size=self.resolution, mode="bicubic"
            )
        return self.model(feature)


class DeepViT(nn.Module):
    """Wrapper for the DeepViT model."""

    def __init__(self, name) -> nn.Module:
        super().__init__()
        # pylint:disable=import-outside-toplevel
        from plato.models.dvit.models import deep_vision_transformer

        model_name = getattr(deep_vision_transformer, name)
        deepvit = model_name(
            pretrained=False, num_classes=Config.trainer.num_classes, in_chans=3
        )
        if (
            hasattr(Config().parameters, "model")
            and hasattr(Config().parameters.model, "pretrained")
            and Config().parameters.model.pretrained
        ):
            state_dict = torch.load(
                Config().parameters.model.pretrain_path, map_location="cpu"
            )
            del state_dict["head.weight"]
            del state_dict["head.bias"]
            deepvit.load_state_dict(state_dict)
        self.model = deepvit
        self.resolution = 224

    def forward(self, feature):
        """The forward pass."""
        if feature.size(-1) != self.resolution:
            feature = nn.functional.interpolate(
                feature, size=self.resolution, mode="bicubic"
            )

        return self.model(feature)


class Model:
    """
    The Transformer and other models loaded from HuggingFace.
    Supported by HuggingFace:
    https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModel
    """

    # pylint:disable=too-few-public-methods
    @staticmethod
    def get(model_name=None, **kwargs):  # pylint: disable=unused-argument
        """Returns a named model from HuggingFace."""
        if "t2t" in model_name:
            return T2TVIT(model_name)

        if "deepvit" in model_name:
            return DeepViT(model_name)

        config_kwargs = {
            "cache_dir": None,
            "revision": "main",
            "use_auth_token": None,
        }
        config_kwargs.update(kwargs)

        model_name = model_name.replace("@", "/")
        # Only prepend "google/" if the model name doesn't already contain a namespace
        if "/" not in model_name:
            model_name = "google/" + model_name
        config = AutoConfig.from_pretrained(model_name, **config_kwargs)

        return ResolutionAdjustedModel(model_name, config)
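Model.get() implies a small naming convention: "@" in a model name stands in for "/" in HuggingFace repository names, and bare names get the "google/" namespace prepended. A hedged sketch, assuming a Plato configuration has already been loaded (ResolutionAdjustedModel consults Config()); the checkpoint name is an example, not fixed by Plato:

from plato.models import vit

# Both names resolve to the "google/vit-base-patch16-224" checkpoint:
model_a = vit.Model.get(model_name="google@vit-base-patch16-224")
model_b = vit.Model.get(model_name="vit-base-patch16-224")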
plato/models/yolov8.py
ADDED
@@ -0,0 +1,22 @@
"""
Obtaining a model from Ultralytics.
"""

from ultralytics import YOLO

from plato.config import Config


class Model:
    """
    The model loaded from YOLOv8.
    """

    @staticmethod
    # pylint: disable=unused-argument
    def get(model_name=None, **kwargs):
        """Returns the YOLOv8 model loaded from Ultralytics."""
        model_type = Config().parameters.model.type

        return YOLO(model_type)
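Note that get() ignores its model_name argument and reads the checkpoint name from the loaded configuration instead. A hedged sketch; "yolov8n.pt" is the standard Ultralytics nano checkpoint, used here only as an example:

from plato.models import yolov8

# Assumes a loaded Plato config along the lines of:
#
#     parameters:
#         model:
#             type: yolov8n.pt
#
model = yolov8.Model.get()  # equivalent to YOLO("yolov8n.pt")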
plato/processors/__init__.py
File without changes
plato/processors/base.py
ADDED
@@ -0,0 +1,35 @@
"""
The Processor class is designed for pre-processing data payloads before or after they
are transmitted over the network between the clients and the servers.
"""

from abc import abstractmethod
from collections.abc import Iterable
from typing import Any


class Processor:
    """
    The base Processor class does nothing on the data payload.
    """

    def __init__(self, name=None, trainer=None, **kwargs) -> None:
        """Constructor for Processor."""
        self.name = name
        self.trainer = trainer

    @abstractmethod
    def process(self, data: Any) -> Any:
        """
        Processing a data payload.
        """
        return data

    def process_iterable(self, data: Iterable) -> Iterable:
        """
        Processing an Iterable of data payload.
        """
        return map(self.process, data)

    def __repr__(self) -> str:
        return self.name
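The intended extension pattern is to subclass Processor and override process(); the inherited process_iterable() then maps the override lazily over a sequence of payloads. A minimal sketch with a hypothetical subclass, not part of the package:

from typing import Any

from plato.processors import base


class ScaleProcessor(base.Processor):
    """Multiplies every numeric payload by a constant factor."""

    def __init__(self, factor=2, **kwargs) -> None:
        super().__init__(name="ScaleProcessor", **kwargs)
        self.factor = factor

    def process(self, data: Any) -> Any:
        return data * self.factor


processor = ScaleProcessor(factor=10)
assert list(processor.process_iterable([1, 2, 3])) == [10, 20, 30]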
plato/processors/compress.py
ADDED
@@ -0,0 +1,46 @@
"""
Implements a Processor for compressing a numpy array.
"""

from typing import Any

import zstd

from plato.processors import base


class Processor(base.Processor):
    """Implements a Processor for compressing a numpy array."""

    def __init__(self, cr=1, **kwargs) -> None:
        super().__init__(**kwargs)
        self.compression_ratio = cr

    def process(self, data: Any) -> Any:
        """Compresses a numpy array, or a list of (features, targets) pairs."""
        if isinstance(data, list):
            ret = []
            datashape_feature = data[0][0].shape
            datatype_feature = data[0][0].dtype
            ret.append((datashape_feature, datatype_feature))
            for logits, targets in data:
                datashape_target = targets.shape
                datatype_target = targets.dtype
                datacom_feature = zstd.compress(logits, self.compression_ratio)
                datacom_target = zstd.compress(targets, self.compression_ratio)
                ret.append(
                    (
                        datacom_feature,
                        datacom_target,
                        datashape_target,
                        datatype_target,
                    )
                )
        else:
            ret = (
                data.shape,
                data.dtype,
                zstd.compress(data, self.compression_ratio),
            )

        return ret
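One caveat worth noting: zstd.compress() treats its second argument as a compression level (higher means smaller output at more CPU cost), not a target ratio, so cr is effectively a zstd level. A quick sketch of the difference:

import numpy as np
import zstd

payload = np.arange(1024, dtype=np.float32).tobytes()
print(len(zstd.compress(payload, 1)))   # fast, larger output
print(len(zstd.compress(payload, 19)))  # slower, smaller output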
plato/processors/decompress.py
ADDED
@@ -0,0 +1,48 @@
"""
Implements a Processor for decompressing a numpy array.
"""

from typing import Any

import numpy as np
import zstd

from plato.processors import base


class Processor(base.Processor):
    """Implements a Processor for decompressing a numpy array."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def process(self, data: Any) -> Any:
        """Decompresses a numpy array, or a list of (features, targets) pairs."""
        if isinstance(data, list):
            ret = []
            datashape_feature = data[0][0]
            datatype_feature = data[0][1]
            for (
                datacom_feature,
                datacom_target,
                datashape_target,
                datatype_target,
            ) in data[1:]:
                datacom_feature = zstd.decompress(datacom_feature)
                datacom_feature = np.frombuffer(
                    datacom_feature, datatype_feature
                ).reshape(datashape_feature)
                if len(datashape_target) > 0 and datashape_target[0] == 0:
                    datacom_target = np.zeros(datashape_target)
                else:
                    datacom_target = zstd.decompress(datacom_target)
                    datacom_target = np.frombuffer(
                        datacom_target, datatype_target
                    ).reshape(datashape_target)
                ret.append((datacom_feature, datacom_target))
        else:
            shape, dtype, modelcom = data
            modelcom = zstd.decompress(modelcom)
            ret = np.frombuffer(modelcom, dtype).reshape(shape)

        return ret
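The two processors are inverses along the non-list branch: compress emits a (shape, dtype, bytes) triple and decompress rebuilds the array with np.frombuffer(). A round-trip sketch; the array contents are arbitrary:

import numpy as np

from plato.processors import compress, decompress

original = np.random.randn(4, 8).astype(np.float32)
packed = compress.Processor(cr=3).process(original)  # (shape, dtype, bytes)
restored = decompress.Processor().process(packed)
assert np.array_equal(original, restored)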
plato/processors/feature.py
ADDED
@@ -0,0 +1,51 @@
"""
Implements a generalized Processor for applying operations onto MistNet PyTorch features.
"""

from typing import Any, Callable

import torch

from plato.processors import base


class Processor(base.Processor):
    """
    Implements a generalized Processor for applying operations onto MistNet PyTorch features.
    """

    def __init__(
        self,
        method: Callable = lambda x, y: (x, y),
        client_id=None,
        use_numpy=True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.client_id = client_id
        self.method = method
        self.use_numpy = use_numpy

    def process(self, data: Any) -> Any:
        """
        Applies the configured method to each (logits, targets) feature pair.
        """

        output = []

        for logits, targets in data:
            if self.use_numpy:
                logits = logits.detach().numpy()

            logits, targets = self.method(logits, targets)

            if self.use_numpy:
                if self.trainer.device != "cpu":
                    logits = torch.from_numpy(logits.astype("float16"))
                else:
                    logits = torch.from_numpy(logits.astype("float32"))

            output.append((logits, targets))

        return output
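The method callable receives each (logits, targets) pair, with logits already converted to a numpy array when use_numpy is set. A hedged sketch with a hypothetical clipping operation and a stand-in trainer exposing the single device attribute the processor consults; neither is part of the package:

import torch

from plato.processors import feature


class StubTrainer:
    """Stand-in exposing the only attribute Processor.process() reads."""

    device = "cpu"


def clip(logits, targets):
    # logits arrive here as a numpy array because use_numpy defaults to True.
    return logits.clip(-1.0, 1.0), targets


processor = feature.Processor(method=clip, trainer=StubTrainer())
processed = processor.process([(torch.randn(2, 4), torch.zeros(2))])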