plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/models/multimodal/blending.py
@@ -0,0 +1,142 @@
+ """
+ The overfitting value is the gap between the training loss L_i^T and
+ the ground-truth loss L_i^* w.r.t. the hypothetical target distribution.
+ Note: L^* is approximated by the validation loss L^V.
+ """
+
+ import numpy as np
+
+
+ def compute_overfitting_o(eval_avg_loss, train_avg_loss):
+     """We define overfitting at epoch N as the gap between L^T_N and
+     L^*_N (approximated by O_N in fig. 3)."""
+     return eval_avg_loss - train_avg_loss
+
+
+ def compute_delta_overfitting_o(
+     n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss
+ ):
+     """Compute the change in overfitting O between steps n and N (n < N)."""
+     delta_O = compute_overfitting_o(
+         n_eval_avg_loss, n_train_avg_loss
+     ) - compute_overfitting_o(N_eval_avg_loss, N_train_avg_loss)
+     return delta_O
+
+
+ def compute_generalization_g(eval_avg_loss):
+     """Compute the generalization G, which is simply the evaluation loss."""
+     return eval_avg_loss
+
+
+ def compute_delta_generalization(eval_avg_loss_n, eval_avg_loss_N):
+     """Compute the difference of the generalization between steps n and N."""
+     return compute_generalization_g(eval_avg_loss_n) - compute_generalization_g(
+         eval_avg_loss_N
+     )
+
+
+ # n < N
+ def OGR_n2N(n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss):
+     """Compute the OGR = abs(delta_O / delta_G)."""
+     delta_O = compute_delta_overfitting_o(
+         n_eval_avg_loss, n_train_avg_loss, N_eval_avg_loss, N_train_avg_loss
+     )
+     delta_G = compute_delta_generalization(n_eval_avg_loss, N_eval_avg_loss)
+
+     ogr = abs(delta_O / delta_G)
+     return ogr
+
+
+ # Optimal gradient blending
+ # n << N
+ def get_optimal_gradient_blend_weights(modalities_losses_n, modalities_losses_N):
+     """Get the weights of modalities for optimal gradient blending.
+
+     Args:
+         modalities_losses_n (dict): contains the train/eval losses for each modality in epoch n
+         modalities_losses_N (dict): contains the train/eval losses for each modality in epoch N
+
+     The structure of the above two dicts should be, for example:
+         {"train": {"RGB": float, "Flow": float},
+          "eval": {"RGB": float, "Flow": float}}
+
+     The equation:
+         w^i = <∇L^*, v_i> / (σ_i)^2 * 1/Z
+             = <∇L^*, v_i> / (σ_i)^2 * 1 / (sum_i <∇L^*, v_i> / (2 * (σ_i)^2))
+             = G^i / (O^i)^2 * 1 / (sum_i G^i / (2 * (O^i)^2))
+
+     where G^i = G_{N,n} = L^*_n - L^*_N = compute_delta_generalization,
+           O^i = O_{N,n} = O_N - O_n = compute_delta_overfitting_o
+     """
+     modality_names = list(modalities_losses_n["train"].keys())
+
+     Z = 0
+     modls_GO = dict()
+     for modality_nm in modality_names:
+         modl_eval_avg_loss_n = modalities_losses_n["eval"][modality_nm]
+         modl_subtrain_avg_loss_n = modalities_losses_n["train"][modality_nm]
+         modl_eval_avg_loss_N = modalities_losses_N["eval"][modality_nm]
+         modl_subtrain_avg_loss_N = modalities_losses_N["train"][modality_nm]
+         G_i = compute_delta_generalization(
+             eval_avg_loss_n=modl_eval_avg_loss_n, eval_avg_loss_N=modl_eval_avg_loss_N
+         )
+         O_i = compute_delta_overfitting_o(
+             n_eval_avg_loss=modl_eval_avg_loss_n,
+             n_train_avg_loss=modl_subtrain_avg_loss_n,
+             N_eval_avg_loss=modl_eval_avg_loss_N,
+             N_train_avg_loss=modl_subtrain_avg_loss_N,
+         )
+
+         modls_GO[modality_nm] = G_i / (O_i * O_i)
+         Gi_div_sqr_Oi = G_i / (2 * O_i * O_i)
+
+         Z += Gi_div_sqr_Oi
+
+     optimal_weights = dict()
+     for modality_nm in modality_names:
+         optimal_weights[modality_nm] = modls_GO[modality_nm] / Z
+
+     return optimal_weights
+
+
+ # Optimal gradient blending
+ # n << N
+ def get_optimal_gradient_blend_weights_og(delta_OGs):
+     """Get the weights of clients for optimal gradient blending.
+
+     Args:
+         delta_OGs (list): each item is a tuple that contains (delta_O, delta_G),
+             for example: [(0.2, 0.45), (0.3, 0.67)]
+
+     The equation is the same as the weights computation for modalities:
+         w^i = <∇L^*, v_i> / (σ_i)^2 * 1/Z
+             = <∇L^*, v_i> / (σ_i)^2 * 1 / (sum_i <∇L^*, v_i> / (2 * (σ_i)^2))
+             = G^i / (O^i)^2 * 1 / (sum_i G^i / (2 * (O^i)^2))
+
+     where G^i = G_{N,n} = delta_G,
+           O^i = O_{N,n} = O_N - O_n = delta_O
+     """
+     num_of_clients = len(delta_OGs)
+
+     Z = 0
+     clients_ratios = list()
+     for cli_i in range(num_of_clients):
+         cli_delta_O, cli_delta_G = delta_OGs[cli_i]
+
+         G_i = cli_delta_G
+         O_i = cli_delta_O
+
+         Gi_div_sqr_Oi = G_i / (2 * O_i * O_i)
+
+         clients_ratios.append(Gi_div_sqr_Oi)
+
+         Z += Gi_div_sqr_Oi
+
+     optimal_weights = np.array(clients_ratios) / Z
+
+     return optimal_weights
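
For reference, a minimal usage sketch of the blending helpers above; the loss values are made-up placeholders, and the dict keys follow the structure documented in get_optimal_gradient_blend_weights:

    # Hypothetical per-modality train/eval losses at epochs n and N (n < N).
    losses_n = {"train": {"RGB": 1.9, "Flow": 2.1},
                "eval": {"RGB": 2.0, "Flow": 2.3}}
    losses_N = {"train": {"RGB": 0.9, "Flow": 1.4},
                "eval": {"RGB": 1.4, "Flow": 1.9}}

    weights = get_optimal_gradient_blend_weights(losses_n, losses_N)
    # weights maps each modality name to G_i / O_i^2 normalized by
    # Z = sum_i G_i / (2 * O_i^2), per the equation in the docstring.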
plato/models/multimodal/fc_net.py
@@ -0,0 +1,77 @@
+ """
+ Build a fully-connected network based on the configs.
+
+ An example of the fc config is:
+     fuse_model = dict(
+         type='FullyConnectedHead',
+         num_classes=400,
+         in_channels=rgb_model['cls_head']['in_channels'] +
+         flow_model['cls_head']['in_channels'] +
+         audio_model['cls_head']['in_channels'],
+         hidden_layer_size=[1024, 512],
+         dropout_ratio=0.5,
+     )
+ """
+
+ from collections import OrderedDict
+
+ from torch import nn
+
+
+ # FullyConnectedHead
+ def build_fc_from_config(fc_configs):
+     """Build one fully-connected network based on our settings."""
+
+     out_classes_n = fc_configs["num_classes"]
+
+     hidden_layer_dims = fc_configs["hidden_layer_size"]
+     hidden_n = len(hidden_layer_dims)
+
+     # dropout_ratio may be a scalar or a list with one probability per layer
+     drop_out_probs = (
+         fc_configs["dropout_ratio"]
+         if isinstance(fc_configs["dropout_ratio"], list)
+         else [fc_configs["dropout_ratio"]]
+     )
+     drop_out_probs = (
+         drop_out_probs * hidden_n if len(drop_out_probs) == 1 else drop_out_probs
+     )
+
+     fc_structure = OrderedDict()
+
+     # the first layer
+     fc_structure["fc1"] = nn.Linear(fc_configs["in_channels"], hidden_layer_dims[0])
+     fc_structure["relu1"] = nn.ReLU()
+     fc_structure["drop1"] = nn.Dropout(p=drop_out_probs[0])
+     # the hidden layers
+     for hidden_l_i, layer_dim in enumerate(hidden_layer_dims):
+         if hidden_l_i == hidden_n - 1:  # the final prediction layer
+             layer_name = "fcf"  # the final layer
+
+             fc_structure[layer_name] = nn.Linear(layer_dim, out_classes_n)
+             fc_structure["sigmoid"] = nn.Sigmoid()
+
+         else:
+             next_layer_in_dim = hidden_layer_dims[hidden_l_i + 1]
+             layer_name = "fc" + str(hidden_l_i + 2)
+             relu_name = "relu" + str(hidden_l_i + 2)
+             dropout_name = "dropout" + str(hidden_l_i + 2)
+             fc_structure[layer_name] = nn.Linear(layer_dim, next_layer_in_dim)
+             fc_structure[relu_name] = nn.ReLU()
+             fc_structure[dropout_name] = nn.Dropout(p=drop_out_probs[hidden_l_i + 1])
+
+     fc_net = nn.Sequential(fc_structure)
+
+     return fc_net
+
+
+ if __name__ == "__main__":
+     fuse_model = dict(
+         type="FullyConnectedHead",
+         num_classes=400,
+         in_channels=512 * 3,
+         hidden_layer_size=[1024, 512],
+         dropout_ratio=0.5,
+     )
+     fc_model = build_fc_from_config(fuse_model)
+     print(fc_model)
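
Since dropout_ratio may be either a scalar or a list, a per-layer variant of the __main__ example above would look like this (a sketch; the values are illustrative):

    fuse_model_per_layer = dict(
        type="FullyConnectedHead",
        num_classes=400,
        in_channels=512 * 3,
        hidden_layer_size=[1024, 512],
        dropout_ratio=[0.5, 0.3],  # one probability per hidden layer
    )
    # drop1 uses p=0.5; the dropout after the second fc layer uses p=0.3
    print(build_fc_from_config(fuse_model_per_layer))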
plato/models/multimodal/fusion_net.py
@@ -0,0 +1,78 @@
+ # @Date : 2021-06-27 13:23:05
+ """
+ This multimodal network is the core network used in our paper.
+ It receives datasets from three modalities (RGB, optical flow, and audio)
+ and then processes them with three different networks:
+     - RGB and flow: ResNet3D, from the paper 'A closer look at spatiotemporal
+       convolutions for action recognition'. This is the r2plus1d model in the
+       mmaction package.
+     - audio: ResNet, from 'Deep residual learning for image recognition',
+       CVPR 2016.
+     Both networks have 50 layers.
+
+     - For fusion, we use a two-FC-layer network on concatenated features
+       from the visual and audio backbones, followed by one prediction layer.
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from mmaction.models import build_loss
+
+ from plato.models.multimodal import fc_net
+
+
+ class ConcatFusionNet(nn.Module):
+     """Concatenates the features of different modalities into one vector."""
+
+     def __init__(self, support_modalities, modalities_fea_dim, net_configs):
+         super().__init__()
+
+         # the supported modality names define the order that must be
+         # followed in the forward process, especially in the fusion part
+         self.support_modality_names = support_modalities  # a list
+         self.modalities_fea_dim = modalities_fea_dim
+         self.net_configs = net_configs
+         # 1. build the model based on the configurations
+         self._fuse_net = fc_net.build_fc_from_config(net_configs)
+
+         self.loss_cls = build_loss(self.net_configs["loss_cls"])
+
+     def create_fusion_feature(self, batch_size, modalities_features_container):
+         """Concatenate the modality features into one fused feature.
+
+         Args:
+             batch_size (int): the batch size of the current input
+             modalities_features_container (dict): key is the name of the modality
+                 while the value is the corresponding features
+         """
+         # obtain the fused features by concatenating the modality features;
+         # the order should follow that in support_modality_names
+         modalities_feature = []
+         for modality_name in self.support_modality_names:
+             if modality_name not in modalities_features_container:
+                 modality_dim = self.modalities_fea_dim[modality_name]
+                 # insert an all-zeros feature if that modality is missing
+                 modality_feature = torch.zeros(size=(batch_size, modality_dim))
+             else:
+                 modality_feature = modalities_features_container[modality_name]
+
+             modalities_feature.append(modality_feature)
+
+         fused_feat = torch.cat(modalities_feature, 1)
+
+         return fused_feat
+
+     def forward(self, fused_features, gt_labels, return_loss):
+         """Forward the network."""
+         fused_cls_score = self._fuse_net(fused_features)
+
+         if return_loss:
+             fused_loss = self.loss_cls(fused_cls_score, gt_labels)
+
+             return [fused_cls_score, fused_loss]
+         else:
+             return [fused_cls_score]
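
To make the zero-filling in create_fusion_feature concrete, here is a standalone sketch of the same concatenation logic, independent of mmaction (all names and sizes are illustrative):

    import torch

    batch_size = 4
    feature_dims = {"rgb": 512, "flow": 512, "audio": 512}
    # Suppose the "flow" modality is missing from this batch.
    features = {"rgb": torch.randn(4, 512), "audio": torch.randn(4, 512)}

    ordered = []
    for name in ["rgb", "flow", "audio"]:  # the fixed modality order
        if name in features:
            ordered.append(features[name])
        else:
            # a missing modality contributes an all-zeros feature
            ordered.append(torch.zeros(batch_size, feature_dims[name]))

    fused = torch.cat(ordered, dim=1)  # shape: (4, 1536)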
plato/models/multimodal/multimodal_module.py
@@ -0,0 +1,152 @@
+ """
+ This multimodal network is the core network used in our paper.
+ It receives datasets from three modalities (RGB, optical flow, and audio)
+ and then processes them with three different networks:
+     - RGB and flow: ResNet3D, from the paper 'A closer look at spatiotemporal
+       convolutions for action recognition'. This is the r2plus1d model in the
+       mmaction package.
+     - audio: ResNet, from 'Deep residual learning for image recognition',
+       CVPR 2016.
+     Both networks have 50 layers.
+
+     - For fusion, we use a two-FC-layer network on concatenated features from
+       the visual and audio backbones, followed by one prediction layer.
+ """
+
+ import logging
+
+ import torch.nn as nn
+
+ from plato.models.multimodal import base_net
+ from plato.models.multimodal import fusion_net
+
+
+ class DynamicMultimodalModule(nn.Module):
+     """DynamicMultimodalModule network.
+
+     This network supports learning from several modalities (the set of
+     modalities can be dynamic).
+
+     Args:
+         multimodal_nets_configs (dict): contains the configurations for the
+             different modalities: 'rgb_model', 'audio_model', 'flow_model',
+             'text_model'
+     """
+
+     def __init__(
+         self,
+         support_modality_names,
+         multimodal_nets_configs,  # multimodal_data_model
+         is_fused_head=True,
+     ):  # a cls head makes predictions based on the fused multimodal feature
+         super().__init__()
+
+         # e.g., ['rgb', 'flow', 'audio']
+         self.support_modality_names = support_modality_names
+         self.support_nets = [
+             mod_nm + "_model" for mod_nm in self.support_modality_names
+         ]
+
+         self.is_fused_head = is_fused_head
+
+         assert all(
+             s_net in multimodal_nets_configs.keys() for s_net in self.support_nets
+         )
+         self.name_net_mapper = {}
+         self.modality_fea_dims_mapper = {}
+         for idx, modality_net in enumerate(self.support_nets):
+             modality_name = self.support_modality_names[idx]
+             if modality_net in multimodal_nets_configs.keys():
+                 logging.info("Building the %s......", modality_net)
+                 net_config = multimodal_nets_configs[modality_net]
+                 is_head_included = "cls_head" in net_config.keys()
+
+                 if is_head_included:
+                     logging.info("The head is defined")
+                     # the feature dimension is the input dimension of the cls head
+                     fea_dims = net_config["cls_head"]["in_channels"]
+                     self.modality_fea_dims_mapper[modality_name] = fea_dims
+
+                 self.name_net_mapper[modality_name] = base_net.BaseClassificationNet(
+                     net_configs=net_config, is_head_included=is_head_included
+                 )
+
+         if is_fused_head:
+             fuse_net_config = multimodal_nets_configs["fuse_model"]
+
+             if "modalities_feature_dim" in list(fuse_net_config.keys()):
+                 self.modality_fea_dims_mapper.update(
+                     fuse_net_config["modalities_feature_dim"]
+                 )
+
+             self.cat_fusion_net = fusion_net.ConcatFusionNet(
+                 support_modalities=support_modality_names,
+                 modalities_fea_dim=self.modality_fea_dims_mapper,
+                 net_configs=fuse_net_config,
+             )
+
+     def assign_weights(self, net_name, weights):
+         """Assign the weights to the specific network."""
+         self.name_net_mapper[net_name].load_state_dict(weights, strict=True)
+
+     def _freeze_stages(self):
+         """Prevent the parameters of the stages before ``self.frozen_stages``
+         from being optimized."""
+         if self.frozen_stages >= 0:
+             self.conv1.eval()
+             for param in self.conv1.parameters():
+                 param.requires_grad = False
+
+             for i in range(1, self.frozen_stages + 1):
+                 layer_module = getattr(self, f"layer{i}")
+                 layer_module.eval()
+                 for param in layer_module.parameters():
+                     param.requires_grad = False
+
+     def forward(self, data_container, label=None, return_loss=True, **kwargs):
+         """Forward the data through the whole net.
+
+         Args:
+             data_container (dict): key is the name of the modality
+                 while the value is a batch of data
+             label (torch.tensor, optional): the label of the sample. Defaults to None.
+             return_loss (bool, optional): whether to return the loss. Defaults to True.
+         """
+         modalities_pred_scores_container = dict()
+         modalities_losses_container = dict()
+         modalities_features_container = dict()
+
+         for modality_name in data_container.keys():
+             modality_net = self.name_net_mapper[modality_name]
+             modality_ipt_data = data_container[modality_name]
+             batch_size = modality_ipt_data.shape[0]
+
+             logging.debug("modality_name: %s", modality_name)
+             logging.debug("modality_net: %s", modality_net)
+             logging.debug("modality_net inner net: %s", modality_net.get_net())
+             logging.debug("modality_ipt_data: %s", modality_ipt_data.shape)
+             logging.debug("batch_size: %s", batch_size)
+
+             # obtain the modality feature and the class output
+             modality_opt = modality_net.forward(
+                 ipt_data=modality_ipt_data, label=label, return_loss=return_loss
+             )
+
+             modalities_features_container[modality_name] = modality_opt[0]
+             modalities_pred_scores_container[modality_name] = modality_opt[1]
+             modalities_losses_container[modality_name] = modality_opt[2]
+
+         if self.is_fused_head:
+             # obtain the fused features by concatenating the modality features;
+             # the order should follow that in support_modality_names
+             fused_feat = self.cat_fusion_net.create_fusion_feature(
+                 batch_size=batch_size,
+                 modalities_features_container=modalities_features_container,
+             )
+             # the fusion net returns [score, loss] when return_loss is True,
+             # and [score] otherwise
+             fused_opt = self.cat_fusion_net.forward(
+                 fused_feat, label, return_loss=return_loss
+             )
+             modalities_pred_scores_container["fused"] = fused_opt[0]
+             if return_loss:
+                 modalities_losses_container["fused"] = fused_opt[1]
+
+         return modalities_pred_scores_container, modalities_losses_container
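
The constructor above expects one "<modality>_model" entry per supported modality, plus a "fuse_model" entry when is_fused_head is set. A hypothetical sketch of that configuration shape (the real backbone settings come from mmaction-style configs; only the keys this module inspects are shown):

    support_modality_names = ["rgb", "flow", "audio"]
    multimodal_nets_configs = {
        # per-modality backbone configs; the module reads the head's
        # "in_channels" to record each modality's feature dimension
        "rgb_model": {"cls_head": {"in_channels": 512}},
        "flow_model": {"cls_head": {"in_channels": 512}},
        "audio_model": {"cls_head": {"in_channels": 512}},
        # consumed by ConcatFusionNet / fc_net when is_fused_head=True
        "fuse_model": {
            "num_classes": 400,
            "in_channels": 512 * 3,
            "hidden_layer_size": [1024, 512],
            "dropout_ratio": 0.5,
            "loss_cls": {"type": "CrossEntropyLoss"},  # passed to build_loss
        },
    }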
plato/models/registry.py
@@ -0,0 +1,99 @@
+ """
+ The registry for machine learning models.
+
+ Having a registry of all available classes is convenient for retrieving an instance
+ based on a configuration at run-time.
+ """
+
+ from typing import Any, Dict, TypedDict, cast
+
+ from plato.config import Config
+ from plato.models import (
+     cnn_encoder,
+     dcgan,
+     general_multilayer,
+     huggingface,
+     lenet5,
+     multilayer,
+     resnet,
+     torch_hub,
+     vgg,
+     vit,
+ )
+
+ registered_models = {
+     "lenet5": lenet5.Model,
+     "dcgan": dcgan.Model,
+     "multilayer": multilayer.Model,
+ }
+
+ registered_factories = {
+     "resnet": resnet.Model,
+     "vgg": vgg.Model,
+     "cnn_encoder": cnn_encoder.Model,
+     "general_multilayer": general_multilayer.Model,
+     "torch_hub": torch_hub.Model,
+     "huggingface": huggingface.Model,
+     "vit": vit.Model,
+ }
+
+
+ class ModelKwargs(TypedDict, total=False):
+     model_name: str
+     model_type: str
+     model_params: Dict[str, Any]
+
+
+ def get(**kwargs: Any) -> Any:
+     """Get the model with the provided name."""
+     config = Config()
+
+     # Get the model name
+     model_name: str = ""
+     if "model_name" in kwargs:
+         model_name = cast(str, kwargs["model_name"])
+     elif hasattr(config, "trainer"):
+         trainer = getattr(config, "trainer")
+         if hasattr(trainer, "model_name"):
+             model_name = getattr(trainer, "model_name")
+
+     # Get the model type
+     model_type: str = ""
+     if "model_type" in kwargs:
+         model_type = cast(str, kwargs["model_type"])
+     elif hasattr(config, "trainer"):
+         trainer = getattr(config, "trainer")
+         if hasattr(trainer, "model_type"):
+             model_type = getattr(trainer, "model_type")
+
+     # If model_type is still empty, derive it from model_name
+     if not model_type and model_name:
+         model_type = model_name.split("_")[0]
+
+     # Get the model parameters
+     model_params: Dict[str, Any] = {}
+     if "model_params" in kwargs:
+         model_params = cast(Dict[str, Any], kwargs["model_params"])
+     elif hasattr(config, "parameters"):
+         parameters = getattr(config, "parameters")
+         if hasattr(parameters, "model"):
+             model = getattr(parameters, "model")
+             if hasattr(model, "_asdict"):
+                 model_params = model._asdict()
+
+     if model_type in registered_models:
+         registered_model = registered_models[model_type]
+         return registered_model(**model_params)
+
+     if model_type in registered_factories:
+         return registered_factories[model_type].get(
+             model_name=model_name, **model_params
+         )
+
+     # The YOLOv8 model needs special handling, as it imports third-party packages
+     if model_name == "yolov8":
+         from plato.models import yolov8
+
+         return yolov8.Model.get()
+
+     raise ValueError(f"No such model: {model_name}")
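
For reference, a minimal sketch of how this registry is typically driven; it assumes a Config has already been loaded and that the underlying model constructors accept these arguments:

    from plato.models import registry as models_registry

    # Explicit kwargs take precedence over the values in Config().trainer.
    model = models_registry.get(model_name="lenet5")

    # Factory-backed families dispatch on the prefix of model_name:
    # "resnet_18" -> model_type "resnet" -> resnet.Model.get(...)
    resnet18 = models_registry.get(model_name="resnet_18")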