plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/models/multimodal/blending.py
ADDED
@@ -0,0 +1,142 @@
"""
The overfitting value is the gap between the train loss L_i^T and the
ground-truth loss L_i^* w.r.t. the hypothetical target distribution.

Note: L^* is approximated by the validation loss L^V.
"""

import numpy as np


def compute_overfitting_o(eval_avg_loss, train_avg_loss):
    """We define overfitting at epoch N as the gap between the train loss L^T_N
    and L^*_N (approximated by O_N in fig. 3)."""
    return eval_avg_loss - train_avg_loss


def compute_delta_overfitting_o(
    n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss
):
    """Compute the change in overfitting O_{N,n} = O_N - O_n between steps n and N (n < N)."""
    delta_O = compute_overfitting_o(
        N_eval_avg_loss, N_train_avg_loss
    ) - compute_overfitting_o(n_eval_avg_loss, n_train_avg_loss)
    return delta_O


def compute_generalization_g(eval_avg_loss):
    """Compute the generalization G, which is approximated by the evaluation loss."""
    return eval_avg_loss


def compute_delta_generalization(eval_avg_loss_n, eval_avg_loss_N):
    """Compute the change in generalization G_{N,n} = L^*_n - L^*_N between steps n and N."""
    return compute_generalization_g(eval_avg_loss_n) - compute_generalization_g(
        eval_avg_loss_N
    )


# n < N
def OGR_n2N(n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss):
    """Compute the overfitting-to-generalization ratio OGR = abs(delta_O / delta_G)."""
    delta_O = compute_delta_overfitting_o(
        n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss
    )
    delta_G = compute_delta_generalization(n_eval_avg_loss, N_eval_avg_loss)

    ogr = abs(delta_O / delta_G)
    return ogr


# Optimal Gradient Blend
# n << N
def get_optimal_gradient_blend_weights(modalities_losses_n, modalities_losses_N):
    """Get the weights of the modalities for optimal gradient blending.

    Args:
        modalities_losses_n (dict): contains the train/eval losses for each modality in epoch n
        modalities_losses_N (dict): contains the train/eval losses for each modality in epoch N

    The structure of the above two dicts should be, for example:
        {"train": {"RGB": float, "Flow": float},
         "eval": {"RGB": float, "Flow": float}}

    The equation:
        w^i = <∇L^*, v_i> / (σ_i)^2 * 1/Z
            = <∇L^*, v_i> / (σ_i)^2 * 1 / (sum_i <∇L^*, v_i> / (2 * (σ_i)^2))
            = G^i / (O^i)^2 * 1 / (sum_i G^i / (2 * (O^i)^2))

    where G^i = G_{N,n} = L^*_n − L^*_N = compute_delta_generalization,
          O^i = O_{N,n} = O_N − O_n = compute_delta_overfitting_o
    """
    modality_names = list(modalities_losses_n["train"].keys())

    Z = 0
    modls_GO = dict()
    for modality_nm in modality_names:
        modl_eval_avg_loss_n = modalities_losses_n["eval"][modality_nm]
        modl_subtrain_avg_loss_n = modalities_losses_n["train"][modality_nm]
        modl_eval_avg_loss_N = modalities_losses_N["eval"][modality_nm]
        modl_subtrain_avg_loss_N = modalities_losses_N["train"][modality_nm]
        G_i = compute_delta_generalization(
            eval_avg_loss_n=modl_eval_avg_loss_n, eval_avg_loss_N=modl_eval_avg_loss_N
        )
        O_i = compute_delta_overfitting_o(
            n_eval_avg_loss=modl_eval_avg_loss_n,
            n_train_avg_loss=modl_subtrain_avg_loss_n,
            N_eval_avg_loss=modl_eval_avg_loss_N,
            N_train_avg_loss=modl_subtrain_avg_loss_N,
        )

        modls_GO[modality_nm] = G_i / (O_i * O_i)
        Gi_div_sqr_Oi = G_i / (2 * O_i * O_i)

        Z += Gi_div_sqr_Oi

    optimal_weights = dict()
    for modality_nm in modality_names:
        optimal_weights[modality_nm] = modls_GO[modality_nm] / Z

    return optimal_weights


# Optimal Gradient Blend
# n << N
def get_optimal_gradient_blend_weights_og(delta_OGs):
    """Get the weights of the clients for optimal gradient blending.

    Args:
        delta_OGs (list): each item is a tuple that contains (delta_O, delta_G)

    For example:
        [(0.2, 0.45), (0.3, 0.67)]

    The equation is the same as the weights computation for modalities:
        w^i = <∇L^*, v_i> / (σ_i)^2 * 1/Z
            = <∇L^*, v_i> / (σ_i)^2 * 1 / (sum_i <∇L^*, v_i> / (2 * (σ_i)^2))
            = G^i / (O^i)^2 * 1 / (sum_i G^i / (2 * (O^i)^2))

    where G^i = G_{N,n} = delta_G,
          O^i = O_{N,n} = O_N − O_n = delta_O
    """
    num_of_clients = len(delta_OGs)

    Z = 0
    clients_ratios = list()
    for cli_i in range(num_of_clients):
        cli_delta_O, cli_delta_G = delta_OGs[cli_i]

        G_i = cli_delta_G
        O_i = cli_delta_O

        Gi_div_sqr_Oi = G_i / (2 * O_i * O_i)

        clients_ratios.append(Gi_div_sqr_Oi)

        Z += Gi_div_sqr_Oi

    optimal_weights = np.array(clients_ratios) / Z

    return optimal_weights
plato/models/multimodal/fc_net.py
ADDED
@@ -0,0 +1,77 @@
"""
Build the fully-connected net based on the configs.

An example of the fc layer config is:
    fuse_model = dict(
        type='FullyConnectedHead',
        num_classes=400,
        in_channels=rgb_model['cls_head']['in_channels'] +
        flow_model['cls_head']['in_channels'] +
        audio_model['cls_head']['in_channels'],
        hidden_layer_size=[1024, 512],
        dropout_ratio=0.5,
    )
"""

from collections import OrderedDict

from torch import nn


# FullyConnectedHead
def build_fc_from_config(fc_configs):
    """Build one fully-connected network based on our settings."""

    out_classes_n = fc_configs["num_classes"]

    hidden_layer_dims = fc_configs["hidden_layer_size"]
    hidden_n = len(hidden_layer_dims)

    drop_out_probs = (
        fc_configs["dropout_ratio"]
        if isinstance(fc_configs["dropout_ratio"], list)
        else [fc_configs["dropout_ratio"]]
    )
    drop_out_probs = (
        drop_out_probs * hidden_n if len(drop_out_probs) == 1 else drop_out_probs
    )

    fc_structure = OrderedDict()

    # the first layer
    fc_structure["fc1"] = nn.Linear(fc_configs["in_channels"], hidden_layer_dims[0])
    fc_structure["relu1"] = nn.ReLU()
    fc_structure["drop1"] = nn.Dropout(p=drop_out_probs[0])
    # the hidden layers
    for hidden_l_i, layer_dim in enumerate(hidden_layer_dims):
        if hidden_l_i == hidden_n - 1:  # the final prediction layer
            layer_name = "fcf"  # the final layer

            fc_structure[layer_name] = nn.Linear(layer_dim, out_classes_n)
            fc_structure["sigmoid"] = nn.Sigmoid()

        else:
            next_layer_in_dim = hidden_layer_dims[hidden_l_i + 1]
            layer_name = "fc" + str(hidden_l_i + 2)
            relu_name = "relu" + str(hidden_l_i + 2)
            dropout_name = "dropout" + str(hidden_l_i + 2)
            fc_structure[layer_name] = nn.Linear(layer_dim, next_layer_in_dim)
            fc_structure[relu_name] = nn.ReLU()
            fc_structure[dropout_name] = nn.Dropout(p=drop_out_probs[hidden_l_i + 1])

    fc_net = nn.Sequential(fc_structure)

    return fc_net


if __name__ == "__main__":
    fuse_model = dict(
        type="FullyConnectedHead",
        num_classes=400,
        in_channels=512 * 3,
        hidden_layer_size=[1024, 512],
        dropout_ratio=0.5,
    )
    fc_model = build_fc_from_config(fuse_model)
    print(fc_model)
plato/models/multimodal/fusion_net.py
ADDED
@@ -0,0 +1,78 @@
# @Date : 2021-06-27 13:23:05
"""
This multimodal network is the core network used in our paper.
It receives datasets from three modalities (RGB, optical flow, and audio)
and then processes them with three different networks:
    - RGB and flow: ResNet3D from the paper 'A closer look at spatiotemporal
        convolutions for action recognition'.
        This is the r2plus1d model in the mmaction package.
    - audio: ResNet, from 'Deep residual learning for image recognition', CVPR 2016,
        both with 50 layers.

    - For fusion, we use a two-FC-layer network on concatenated
        features from the visual and audio backbones,
        followed by one prediction layer.
"""

import torch
import torch.nn as nn

from mmaction.models import build_loss

from plato.models.multimodal import fc_net


class ConcatFusionNet(nn.Module):
    """This supports concatenating the features of different modalities into one vector."""

    def __init__(self, support_modalities, modalities_fea_dim, net_configs):
        super(ConcatFusionNet, self).__init__()

        # the support modality names define the pre-defined order that must be
        # followed in the forward process, especially in the fusion part
        self.support_modality_names = support_modalities  # a list
        self.modalities_fea_dim = modalities_fea_dim
        self.net_configs = net_configs
        # 1. build the model based on the configurations
        self._fuse_net = fc_net.build_fc_from_config(net_configs)

        self.loss_cls = build_loss(self.net_configs["loss_cls"])

    def create_fusion_feature(self, batch_size, modalities_features_container):
        """Concatenate the modality features into one fused feature vector.

        Args:
            batch_size (int): the size of the current batch
            modalities_features_container (dict): key is the name of the modality
                        while the value is the corresponding features
        """
        # obtain the fused feats by concatenating the modalities' features
        # The order should follow that in support_modality_names
        modalities_feature = []
        for modality_name in self.support_modality_names:
            if modality_name not in modalities_features_container:
                modality_dim = self.modalities_fea_dim[modality_name]
                # insert all-zero features if that modality is missing
                modality_feature = torch.zeros(size=(batch_size, modality_dim))
            else:
                modality_feature = modalities_features_container[modality_name]

            modalities_feature.append(modality_feature)

        fused_feat = torch.cat(modalities_feature, 1)

        return fused_feat

    def forward(self, fused_features, gt_labels, return_loss):
        """Forward the network."""
        fused_cls_score = self._fuse_net(fused_features)

        if return_loss:
            fused_loss = self.loss_cls(fused_cls_score, gt_labels)

            return [fused_cls_score, fused_loss]
        else:
            return [fused_cls_score]
plato/models/multimodal/multimodal_module.py
ADDED
@@ -0,0 +1,152 @@
"""
This multimodal network is the core network used in our paper.
It receives datasets from three modalities (RGB, optical flow, and audio)
and then processes them with three different networks:
    - RGB and flow: ResNet3D from the paper 'A closer look at spatiotemporal
        convolutions for action recognition'.
        This is the r2plus1d model in the mmaction package.
    - audio: ResNet, from 'Deep residual learning for image recognition', CVPR 2016,
        both with 50 layers.

    - For fusion, we use a two-FC-layer network on concatenated features from
        the visual and audio backbones,
        followed by one prediction layer.
"""

import logging

import torch.nn as nn

from plato.models.multimodal import base_net
from plato.models.multimodal import fusion_net


class DynamicMultimodalModule(nn.Module):
    """DynamicMultimodalModule network.

    This network supports the learning of several modalities (the set of modalities can be dynamic).

    Args:
        multimodal_nets_configs (dict): contains the configurations for the
                                different modalities: 'rgb_model', 'audio_model',
                                'flow_model', 'text_model'
    """

    def __init__(
        self,
        support_modality_names,
        multimodal_nets_configs,  # multimodal_data_model
        is_fused_head=True,
    ):  # a cls head makes predictions based on the fused multimodal feature
        super().__init__()

        # e.g., ['rgb', "flow", "audio"]
        self.support_modality_names = support_modality_names
        self.support_nets = [
            mod_nm + "_model" for mod_nm in self.support_modality_names
        ]

        self.is_fused_head = is_fused_head

        assert all(
            [s_net in multimodal_nets_configs.keys() for s_net in self.support_nets]
        )
        self.name_net_mapper = {}
        self.modality_fea_dims_mapper = {}
        for idx, modality_net in enumerate(self.support_nets):
            modality_name = self.support_modality_names[idx]
            if modality_net in multimodal_nets_configs.keys():
                logging.info("Building the %s......", modality_net)
                net_config = multimodal_nets_configs[modality_net]
                is_head_included = "cls_head" in net_config.keys()

                if is_head_included:
                    logging.info("The head is defined")
                    # the feature dimension is the input dimension of the cls head
                    fea_dims = net_config["cls_head"]["in_channels"]
                    self.modality_fea_dims_mapper[modality_name] = fea_dims

                self.name_net_mapper[modality_name] = base_net.BaseClassificationNet(
                    net_configs=net_config, is_head_included=is_head_included
                )

        if is_fused_head:
            fuse_net_config = multimodal_nets_configs["fuse_model"]

            if "modalities_feature_dim" in list(fuse_net_config.keys()):
                self.modality_fea_dims_mapper.update(
                    fuse_net_config["modalities_feature_dim"]
                )

            self.cat_fusion_net = fusion_net.ConcatFusionNet(
                support_modalities=support_modality_names,
                modalities_fea_dim=self.modality_fea_dims_mapper,
                net_configs=fuse_net_config,
            )

    def assing_weights(self, net_name, weights):
        """Assign the weights to the specific network."""
        self.name_net_mapper[net_name].load_state_dict(weights, strict=True)

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``."""
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

            for i in range(1, self.frozen_stages + 1):
                layer_module = getattr(self, f"layer{i}")
                layer_module.eval()
                for param in layer_module.parameters():
                    param.requires_grad = False

    def forward(self, data_container, label=None, return_loss=True, **kwargs):
        """Forward the data through the whole net.

        Args:
            data_container (dict): key is the name of the modality
                        while the value is a batch of data
            label (torch.tensor, optional): the label of the sample. Defaults to None.
            return_loss (bool, optional): whether to return the loss. Defaults to True.
        """
        modalities_pred_scores_container = dict()
        modalities_losses_container = dict()
        modalities_features_container = dict()

        for modality_name in data_container.keys():
            modality_net = self.name_net_mapper[modality_name]
            modality_ipt_data = data_container[modality_name]
            batch_size = modality_ipt_data.shape[0]

            logging.debug("modality_name: %s", modality_name)
            logging.debug("modality_net: %s", modality_net)
            logging.debug("modality_net inner net: %s", modality_net.get_net())
            logging.debug("modality_ipt_data: %s", modality_ipt_data.shape)
            logging.debug("batch_size: %s", batch_size)

            # obtain the modality features and the class output
            modality_opt = modality_net.forward(
                ipt_data=modality_ipt_data, label=label, return_loss=return_loss
            )

            modalities_features_container[modality_name] = modality_opt[0]
            modalities_pred_scores_container[modality_name] = modality_opt[1]
            modalities_losses_container[modality_name] = modality_opt[2]

        if self.is_fused_head:
            # obtain the fused feats by concatenating the modalities' features
            # The order should follow that in support_modality_names
            fused_feat = self.cat_fusion_net.create_fusion_feature(
                batch_size=batch_size,
                modalities_features_container=modalities_features_container,
            )
            fused_cls_score, fused_loss = self.cat_fusion_net.forward(
                fused_feat, label, return_loss=return_loss
            )
            modalities_pred_scores_container["fused"] = fused_cls_score
            modalities_losses_container["fused"] = fused_loss

        return modalities_pred_scores_container, modalities_losses_container
plato/models/registry.py
ADDED
@@ -0,0 +1,99 @@
"""
The registry for machine learning models.

Having a registry of all available classes is convenient for retrieving an instance
based on a configuration at run-time.
"""

from typing import Any, Dict, TypedDict, cast

from plato.config import Config
from plato.models import (
    cnn_encoder,
    dcgan,
    general_multilayer,
    huggingface,
    lenet5,
    multilayer,
    resnet,
    torch_hub,
    vgg,
    vit,
)

registered_models = {
    "lenet5": lenet5.Model,
    "dcgan": dcgan.Model,
    "multilayer": multilayer.Model,
}

registered_factories = {
    "resnet": resnet.Model,
    "vgg": vgg.Model,
    "cnn_encoder": cnn_encoder.Model,
    "general_multilayer": general_multilayer.Model,
    "torch_hub": torch_hub.Model,
    "huggingface": huggingface.Model,
    "vit": vit.Model,
}


class ModelKwargs(TypedDict, total=False):
    model_name: str
    model_type: str
    model_params: Dict[str, Any]


def get(**kwargs: Any) -> Any:
    """Get the model with the provided name."""
    config = Config()

    # Get model name
    model_name: str = ""
    if "model_name" in kwargs:
        model_name = cast(str, kwargs["model_name"])
    elif hasattr(config, "trainer"):
        trainer = getattr(config, "trainer")
        if hasattr(trainer, "model_name"):
            model_name = getattr(trainer, "model_name")

    # Get model type
    model_type: str = ""
    if "model_type" in kwargs:
        model_type = cast(str, kwargs["model_type"])
    elif hasattr(config, "trainer"):
        trainer = getattr(config, "trainer")
        if hasattr(trainer, "model_type"):
            model_type = getattr(trainer, "model_type")

    # If model_type is still empty, derive it from model_name
    if not model_type and model_name:
        model_type = model_name.split("_")[0]

    # Get model parameters
    model_params: Dict[str, Any] = {}
    if "model_params" in kwargs:
        model_params = cast(Dict[str, Any], kwargs["model_params"])
    elif hasattr(config, "parameters"):
        parameters = getattr(config, "parameters")
        if hasattr(parameters, "model"):
            model = getattr(parameters, "model")
            if hasattr(model, "_asdict"):
                model_params = model._asdict()

    if model_type in registered_models:
        registered_model = registered_models[model_type]
        return registered_model(**model_params)

    if model_type in registered_factories:
        return registered_factories[model_type].get(
            model_name=model_name, **model_params
        )

    # The YOLOv8 model needs special handling as it needs to import third-party packages
    if model_name == "yolov8":
        from plato.models import yolov8

        return yolov8.Model.get()

    raise ValueError(f"No such model: {model_name}")