yomitoku-0.4.0.post1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. yomitoku/__init__.py +20 -0
  2. yomitoku/base.py +136 -0
  3. yomitoku/cli/__init__.py +0 -0
  4. yomitoku/cli/main.py +230 -0
  5. yomitoku/configs/__init__.py +13 -0
  6. yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
  7. yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
  8. yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
  9. yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
  10. yomitoku/constants.py +32 -0
  11. yomitoku/data/__init__.py +3 -0
  12. yomitoku/data/dataset.py +40 -0
  13. yomitoku/data/functions.py +279 -0
  14. yomitoku/document_analyzer.py +315 -0
  15. yomitoku/export/__init__.py +6 -0
  16. yomitoku/export/export_csv.py +71 -0
  17. yomitoku/export/export_html.py +188 -0
  18. yomitoku/export/export_json.py +34 -0
  19. yomitoku/export/export_markdown.py +145 -0
  20. yomitoku/layout_analyzer.py +66 -0
  21. yomitoku/layout_parser.py +189 -0
  22. yomitoku/models/__init__.py +9 -0
  23. yomitoku/models/dbnet_plus.py +272 -0
  24. yomitoku/models/layers/__init__.py +0 -0
  25. yomitoku/models/layers/activate.py +38 -0
  26. yomitoku/models/layers/dbnet_feature_attention.py +160 -0
  27. yomitoku/models/layers/parseq_transformer.py +218 -0
  28. yomitoku/models/layers/rtdetr_backbone.py +333 -0
  29. yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
  30. yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
  31. yomitoku/models/parseq.py +243 -0
  32. yomitoku/models/rtdetr.py +22 -0
  33. yomitoku/ocr.py +87 -0
  34. yomitoku/postprocessor/__init__.py +9 -0
  35. yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
  36. yomitoku/postprocessor/parseq_tokenizer.py +128 -0
  37. yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
  38. yomitoku/reading_order.py +214 -0
  39. yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
  40. yomitoku/resource/charset.txt +1 -0
  41. yomitoku/table_structure_recognizer.py +244 -0
  42. yomitoku/text_detector.py +103 -0
  43. yomitoku/text_recognizer.py +128 -0
  44. yomitoku/utils/__init__.py +0 -0
  45. yomitoku/utils/graph.py +20 -0
  46. yomitoku/utils/logger.py +15 -0
  47. yomitoku/utils/misc.py +102 -0
  48. yomitoku/utils/visualizer.py +179 -0
  49. yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
  50. yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
  51. yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
  52. yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
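The modules above outline the pipeline: text_detector.py and text_recognizer.py provide OCR, layout_parser.py and table_structure_recognizer.py provide layout analysis, document_analyzer.py ties them together, the export/ package writes JSON/HTML/Markdown/CSV, and cli/main.py backs the console script registered in entry_points.txt. The sketch below shows how these pieces are typically wired together; it is an assumption inferred from the module names in this listing, not taken from the diff itself, so the DocumentAnalyzer name, its constructor arguments, and its return values should be checked against the package METADATA.

```python
# Hypothetical usage sketch: every name and signature here is an assumption
# inferred from the module listing above, not verified against this wheel.
import cv2
from yomitoku import DocumentAnalyzer  # assumed to be re-exported by yomitoku/__init__.py

img = cv2.imread("sample.jpg")                # BGR image, as OpenCV loads it
analyzer = DocumentAnalyzer(device="cpu")     # constructor arguments assumed
results, ocr_vis, layout_vis = analyzer(img)  # return values assumed
print(results)                                # recognized text, layout regions, tables
```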
yomitoku/models/layers/rtdetrv2_decoder.py
@@ -0,0 +1,811 @@
+"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+
+import copy
+import functools
+import math
+from collections import OrderedDict
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from omegaconf import ListConfig
+
+from .activate import get_activation
+
+
+def bias_init_with_prob(prior_prob=0.01):
+    """initialize conv/fc bias value according to a given probability value."""
+    bias_init = float(-math.log((1 - prior_prob) / prior_prob))
+    return bias_init
+
+
+def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
+    x = x.clip(min=0.0, max=1.0)
+    return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps))
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.act = get_activation(act)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+class MSDeformableAttention(nn.Module):
+    def __init__(
+        self,
+        embed_dim=256,
+        num_heads=8,
+        num_levels=4,
+        num_points=4,
+        method="default",
+        offset_scale=0.5,
+    ):
+        """Multi-Scale Deformable Attention"""
+        super(MSDeformableAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.offset_scale = offset_scale
+
+        if isinstance(num_points, list):
+            assert len(num_points) == num_levels, ""
+            num_points_list = num_points
+        elif isinstance(num_points, ListConfig):
+            num_points_list = list(num_points)
+        else:
+            num_points_list = [num_points for _ in range(num_levels)]
+
+        self.num_points_list = num_points_list
+
+        num_points_scale = [1 / n for n in num_points_list for _ in range(n)]
+        self.register_buffer(
+            "num_points_scale",
+            torch.tensor(num_points_scale, dtype=torch.float32),
+        )
+
+        self.total_points = num_heads * sum(num_points_list)
+        self.method = method
+
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
+        self.attention_weights = nn.Linear(embed_dim, self.total_points)
+        self.value_proj = nn.Linear(embed_dim, embed_dim)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+
+        self.ms_deformable_attn_core = functools.partial(
+            deformable_attention_core_func_v2, method=self.method
+        )
+
+        self._reset_parameters()
+
+        if method == "discrete":
+            for p in self.sampling_offsets.parameters():
+                p.requires_grad = False
+
+    def _reset_parameters(self):
+        # sampling_offsets
+        init.constant_(self.sampling_offsets.weight, 0)
+        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+            2.0 * math.pi / self.num_heads
+        )
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values
+        grid_init = grid_init.reshape(self.num_heads, 1, 2).tile(
+            [1, sum(self.num_points_list), 1]
+        )
+        scaling = torch.concat(
+            [torch.arange(1, n + 1) for n in self.num_points_list]
+        ).reshape(1, -1, 1)
+        grid_init *= scaling
+        self.sampling_offsets.bias.data[...] = grid_init.flatten()
+
+        # attention_weights
+        init.constant_(self.attention_weights.weight, 0)
+        init.constant_(self.attention_weights.bias, 0)
+
+        # proj
+        init.xavier_uniform_(self.value_proj.weight)
+        init.constant_(self.value_proj.bias, 0)
+        init.xavier_uniform_(self.output_proj.weight)
+        init.constant_(self.output_proj.bias, 0)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        reference_points: torch.Tensor,
+        value: torch.Tensor,
+        value_spatial_shapes: List[int],
+        value_mask: torch.Tensor = None,
+    ):
+        """
+        Args:
+            query (Tensor): [bs, query_length, C]
+            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area
+            value (Tensor): [bs, value_length, C]
+            value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
+
+        Returns:
+            output (Tensor): [bs, Length_{query}, C]
+        """
+        bs, Len_q = query.shape[:2]
+        Len_v = value.shape[1]
+
+        value = self.value_proj(value)
+        if value_mask is not None:
+            value = value * value_mask.to(value.dtype).unsqueeze(-1)
+
+        value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)
+
+        sampling_offsets: torch.Tensor = self.sampling_offsets(query)
+        sampling_offsets = sampling_offsets.reshape(
+            bs, Len_q, self.num_heads, sum(self.num_points_list), 2
+        )
+
+        attention_weights = self.attention_weights(query).reshape(
+            bs, Len_q, self.num_heads, sum(self.num_points_list)
+        )
+        attention_weights = F.softmax(attention_weights, dim=-1).reshape(
+            bs, Len_q, self.num_heads, sum(self.num_points_list)
+        )
+
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.tensor(value_spatial_shapes)
+            offset_normalizer = offset_normalizer.flip([1]).reshape(
+                1, 1, 1, self.num_levels, 1, 2
+            )
+            sampling_locations = (
+                reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2)
+                + sampling_offsets / offset_normalizer
+            )
+        elif reference_points.shape[-1] == 4:
+            # reference_points [8, 480, None, 1, 4]
+            # sampling_offsets [8, 480, 8, 12, 2]
+            num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1)
+            offset = (
+                sampling_offsets
+                * num_points_scale
+                * reference_points[:, :, None, :, 2:]
+                * self.offset_scale
+            )
+            sampling_locations = reference_points[:, :, None, :, :2] + offset
+        else:
+            raise ValueError(
+                "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
+                    reference_points.shape[-1]
+                )
+            )
+
+        output = self.ms_deformable_attn_core(
+            value,
+            value_spatial_shapes,
+            sampling_locations,
+            attention_weights,
+            self.num_points_list,
+        )
+
+        output = self.output_proj(output)
+
+        return output
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model=256,
+        n_head=8,
+        dim_feedforward=1024,
+        dropout=0.0,
+        activation="relu",
+        n_levels=4,
+        n_points=4,
+        cross_attn_method="default",
+    ):
+        super(TransformerDecoderLayer, self).__init__()
+
+        # self attention
+        self.self_attn = nn.MultiheadAttention(
+            d_model, n_head, dropout=dropout, batch_first=True
+        )
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+
+        # cross attention
+        self.cross_attn = MSDeformableAttention(
+            d_model, n_head, n_levels, n_points, method=cross_attn_method
+        )
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        # ffn
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.activation = get_activation(activation)
+        self.dropout3 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.dropout4 = nn.Dropout(dropout)
+        self.norm3 = nn.LayerNorm(d_model)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        init.xavier_uniform_(self.linear1.weight)
+        init.xavier_uniform_(self.linear2.weight)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward_ffn(self, tgt):
+        return self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+
+    def forward(
+        self,
+        target,
+        reference_points,
+        memory,
+        memory_spatial_shapes,
+        attn_mask=None,
+        memory_mask=None,
+        query_pos_embed=None,
+    ):
+        # self attention
+        q = k = self.with_pos_embed(target, query_pos_embed)
+
+        target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask)
+        target = target + self.dropout1(target2)
+        target = self.norm1(target)
+
+        # cross attention
+        target2 = self.cross_attn(
+            self.with_pos_embed(target, query_pos_embed),
+            reference_points,
+            memory,
+            memory_spatial_shapes,
+            memory_mask,
+        )
+        target = target + self.dropout2(target2)
+        target = self.norm2(target)
+
+        # ffn
+        target2 = self.forward_ffn(target)
+        target = target + self.dropout4(target2)
+        target = self.norm3(target)
+
+        return target
+
+
+def deformable_attention_core_func_v2(
+    value: torch.Tensor,
+    value_spatial_shapes,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+    num_points_list: List[int],
+    method="default",
+):
+    """
+    Args:
+        value (Tensor): [bs, value_length, n_head, c]
+        value_spatial_shapes (Tensor|List): [n_levels, 2]
+        value_level_start_index (Tensor|List): [n_levels]
+        sampling_locations (Tensor): [bs, query_length, n_head, n_levels * n_points, 2]
+        attention_weights (Tensor): [bs, query_length, n_head, n_levels * n_points]
+
+    Returns:
+        output (Tensor): [bs, Length_{query}, C]
+    """
+    bs, _, n_head, c = value.shape
+    _, Len_q, _, _, _ = sampling_locations.shape
+
+    split_shape = [h * w for h, w in value_spatial_shapes]
+    value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
+
+    # sampling_offsets [8, 480, 8, 12, 2]
+    if method == "default":
+        sampling_grids = 2 * sampling_locations - 1
+
+    elif method == "discrete":
+        sampling_grids = sampling_locations
+
+    sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
+    sampling_locations_list = sampling_grids.split(num_points_list, dim=-2)
+
+    sampling_value_list = []
+    for level, (h, w) in enumerate(value_spatial_shapes):
+        value_l = value_list[level].reshape(bs * n_head, c, h, w)
+        sampling_grid_l: torch.Tensor = sampling_locations_list[level]
+
+        if method == "default":
+            sampling_value_l = F.grid_sample(
+                value_l,
+                sampling_grid_l,
+                mode="bilinear",
+                padding_mode="zeros",
+                align_corners=False,
+            )
+
+        elif method == "discrete":
+            # n * m, seq, n, 2
+            sampling_coord = (
+                sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5
+            ).to(torch.int64)
+
+            # FIX ME? for rectangle input
+            sampling_coord = sampling_coord.clamp(0, h - 1)
+            sampling_coord = sampling_coord.reshape(
+                bs * n_head, Len_q * num_points_list[level], 2
+            )
+
+            s_idx = (
+                torch.arange(sampling_coord.shape[0], device=value.device)
+                .unsqueeze(-1)
+                .repeat(1, sampling_coord.shape[1])
+            )
+            sampling_value_l: torch.Tensor = value_l[
+                s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]
+            ]  # n l c
+
+            sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(
+                bs * n_head, c, Len_q, num_points_list[level]
+            )
+
+        sampling_value_list.append(sampling_value_l)
+
+    attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
+        bs * n_head, 1, Len_q, sum(num_points_list)
+    )
+    weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights
+    output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)
+
+    return output.permute(0, 2, 1)
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
+        super(TransformerDecoder, self).__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for _ in range(num_layers)]
+        )
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
+
+    def forward(
+        self,
+        target,
+        ref_points_unact,
+        memory,
+        memory_spatial_shapes,
+        bbox_head,
+        score_head,
+        query_pos_head,
+        attn_mask=None,
+        memory_mask=None,
+    ):
+        dec_out_bboxes = []
+        dec_out_logits = []
+        ref_points_detach = F.sigmoid(ref_points_unact)
+
+        output = target
+        for i, layer in enumerate(self.layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = query_pos_head(ref_points_detach)
+
+            output = layer(
+                output,
+                ref_points_input,
+                memory,
+                memory_spatial_shapes,
+                attn_mask,
+                memory_mask,
+                query_pos_embed,
+            )
+
+            inter_ref_bbox = F.sigmoid(
+                bbox_head[i](output) + inverse_sigmoid(ref_points_detach)
+            )
+
+            if i == self.eval_idx:
+                dec_out_logits.append(score_head[i](output))
+                dec_out_bboxes.append(inter_ref_bbox)
+                break
+
+            ref_points_detach = inter_ref_bbox.detach()
+
+        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+
+
+class RTDETRTransformerv2(nn.Module):
+    __share__ = ["num_classes", "eval_spatial_size"]
+
+    def __init__(
+        self,
+        num_classes=80,
+        hidden_dim=256,
+        num_queries=300,
+        feat_channels=[512, 1024, 2048],
+        feat_strides=[8, 16, 32],
+        num_levels=3,
+        num_points=4,
+        nhead=8,
+        num_layers=6,
+        dim_feedforward=1024,
+        dropout=0.0,
+        activation="relu",
+        num_denoising=100,
+        label_noise_ratio=0.5,
+        box_noise_scale=1.0,
+        learn_query_content=False,
+        eval_spatial_size=None,
+        eval_idx=-1,
+        eps=1e-2,
+        aux_loss=True,
+        cross_attn_method="default",
+        query_select_method="default",
+    ):
+        super().__init__()
+        assert len(feat_channels) <= num_levels
+        assert len(feat_strides) == len(feat_channels)
+
+        for _ in range(num_levels - len(feat_strides)):
+            feat_strides.append(feat_strides[-1] * 2)
+
+        self.hidden_dim = hidden_dim
+        self.nhead = nhead
+        self.feat_strides = feat_strides
+        self.num_levels = num_levels
+        self.num_classes = num_classes
+        self.num_queries = num_queries
+        self.eps = eps
+        self.num_layers = num_layers
+        self.eval_spatial_size = eval_spatial_size
+        self.aux_loss = aux_loss
+
+        assert query_select_method in ("default", "one2many", "agnostic"), ""
+        assert cross_attn_method in ("default", "discrete"), ""
+        self.cross_attn_method = cross_attn_method
+        self.query_select_method = query_select_method
+
+        # backbone feature projection
+        self._build_input_proj_layer(feat_channels)
+
+        # Transformer module
+        decoder_layer = TransformerDecoderLayer(
+            hidden_dim,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            num_levels,
+            num_points,
+            cross_attn_method=cross_attn_method,
+        )
+        self.decoder = TransformerDecoder(
+            hidden_dim, decoder_layer, num_layers, eval_idx
+        )
+
+        # denoising
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+        if num_denoising > 0:
+            self.denoising_class_embed = nn.Embedding(
+                num_classes + 1, hidden_dim, padding_idx=num_classes
+            )
+            init.normal_(self.denoising_class_embed.weight[:-1])
+
+        # decoder embedding
+        self.learn_query_content = learn_query_content
+        if learn_query_content:
+            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
+        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2)
+
+        # if num_select_queries != self.num_queries:
+        # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu')
+        # self.encoder = TransformerEncoder(layer, 1)
+
+        self.enc_output = nn.Sequential(
+            OrderedDict(
+                [
+                    ("proj", nn.Linear(hidden_dim, hidden_dim)),
+                    (
+                        "norm",
+                        nn.LayerNorm(
+                            hidden_dim,
+                        ),
+                    ),
+                ]
+            )
+        )
+
+        if query_select_method == "agnostic":
+            self.enc_score_head = nn.Linear(hidden_dim, 1)
+        else:
+            self.enc_score_head = nn.Linear(hidden_dim, num_classes)
+
+        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3)
+
+        # decoder head
+        self.dec_score_head = nn.ModuleList(
+            [nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)]
+        )
+        self.dec_bbox_head = nn.ModuleList(
+            [MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers)]
+        )
+
+        # init encoder output anchors and valid_mask
+        if self.eval_spatial_size:
+            anchors, valid_mask = self._generate_anchors()
+            self.register_buffer("anchors", anchors)
+            self.register_buffer("valid_mask", valid_mask)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        bias = bias_init_with_prob(0.01)
+        init.constant_(self.enc_score_head.bias, bias)
+        init.constant_(self.enc_bbox_head.layers[-1].weight, 0)
+        init.constant_(self.enc_bbox_head.layers[-1].bias, 0)
+
+        for _cls, _reg in zip(self.dec_score_head, self.dec_bbox_head):
+            init.constant_(_cls.bias, bias)
+            init.constant_(_reg.layers[-1].weight, 0)
+            init.constant_(_reg.layers[-1].bias, 0)
+
+        init.xavier_uniform_(self.enc_output[0].weight)
+        if self.learn_query_content:
+            init.xavier_uniform_(self.tgt_embed.weight)
+        init.xavier_uniform_(self.query_pos_head.layers[0].weight)
+        init.xavier_uniform_(self.query_pos_head.layers[1].weight)
+        for m in self.input_proj:
+            init.xavier_uniform_(m[0].weight)
+
+    def _build_input_proj_layer(self, feat_channels):
+        self.input_proj = nn.ModuleList()
+        for in_channels in feat_channels:
+            self.input_proj.append(
+                nn.Sequential(
+                    OrderedDict(
+                        [
+                            (
+                                "conv",
+                                nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False),
+                            ),
+                            (
+                                "norm",
+                                nn.BatchNorm2d(
+                                    self.hidden_dim,
+                                ),
+                            ),
+                        ]
+                    )
+                )
+            )
+
+        in_channels = feat_channels[-1]
+
+        for _ in range(self.num_levels - len(feat_channels)):
+            self.input_proj.append(
+                nn.Sequential(
+                    OrderedDict(
+                        [
+                            (
+                                "conv",
+                                nn.Conv2d(
+                                    in_channels,
+                                    self.hidden_dim,
+                                    3,
+                                    2,
+                                    padding=1,
+                                    bias=False,
+                                ),
+                            ),
+                            ("norm", nn.BatchNorm2d(self.hidden_dim)),
+                        ]
+                    )
+                )
+            )
+            in_channels = self.hidden_dim
+
+    def _get_encoder_input(self, feats: List[torch.Tensor]):
+        # get projection features
+        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
+        if self.num_levels > len(proj_feats):
+            len_srcs = len(proj_feats)
+            for i in range(len_srcs, self.num_levels):
+                if i == len_srcs:
+                    proj_feats.append(self.input_proj[i](feats[-1]))
+                else:
+                    proj_feats.append(self.input_proj[i](proj_feats[-1]))
+
+        # get encoder inputs
+        feat_flatten = []
+        spatial_shapes = []
+        for i, feat in enumerate(proj_feats):
+            _, _, h, w = feat.shape
+            # [b, c, h, w] -> [b, h*w, c]
+            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
+            # [num_levels, 2]
+            spatial_shapes.append([h, w])
+        # [b, l, c]
+        feat_flatten = torch.concat(feat_flatten, 1)
+        return feat_flatten, spatial_shapes
+
+    def _generate_anchors(
+        self,
+        spatial_shapes=None,
+        grid_size=0.05,
+        dtype=torch.float32,
+        device="cpu",
+    ):
+        if spatial_shapes is None:
+            spatial_shapes = []
+            eval_h, eval_w = self.eval_spatial_size
+            for s in self.feat_strides:
+                spatial_shapes.append([int(eval_h / s), int(eval_w / s)])
+
+        anchors = []
+        for lvl, (h, w) in enumerate(spatial_shapes):
+            grid_y, grid_x = torch.meshgrid(
+                torch.arange(h), torch.arange(w), indexing="ij"
+            )
+            grid_xy = torch.stack([grid_x, grid_y], dim=-1)
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype)
+            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
+            lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4)
+            anchors.append(lvl_anchors)
+
+        anchors = torch.concat(anchors, dim=1).to(device)
+        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(
+            -1, keepdim=True
+        )
+        anchors = torch.log(anchors / (1 - anchors))
+        anchors = torch.where(valid_mask, anchors, torch.inf)
+
+        return anchors, valid_mask
+
+    def _get_decoder_input(
+        self,
+        memory: torch.Tensor,
+        spatial_shapes,
+        denoising_logits=None,
+        denoising_bbox_unact=None,
+    ):
+        # prepare input for decoder
+        anchors = self.anchors
+        valid_mask = self.valid_mask
+
+        # memory = torch.where(valid_mask, memory, 0)
+        # TODO fix type error for onnx export
+        memory = valid_mask.to(memory.dtype) * memory
+
+        output_memory: torch.Tensor = self.enc_output(memory)
+        enc_outputs_logits: torch.Tensor = self.enc_score_head(output_memory)
+        enc_outputs_coord_unact: torch.Tensor = (
+            self.enc_bbox_head(output_memory) + anchors
+        )
+
+        enc_topk_bboxes_list, enc_topk_logits_list = [], []
+        enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk(
+            output_memory,
+            enc_outputs_logits,
+            enc_outputs_coord_unact,
+            self.num_queries,
+        )
+
+        # if self.num_select_queries != self.num_queries:
+        # raise NotImplementedError('')
+
+        if self.learn_query_content:
+            content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1])
+        else:
+            content = enc_topk_memory.detach()
+
+        enc_topk_bbox_unact = enc_topk_bbox_unact.detach()
+
+        if denoising_bbox_unact is not None:
+            enc_topk_bbox_unact = torch.concat(
+                [denoising_bbox_unact, enc_topk_bbox_unact], dim=1
+            )
+            content = torch.concat([denoising_logits, content], dim=1)
+
+        return (
+            content,
+            enc_topk_bbox_unact,
+            enc_topk_bboxes_list,
+            enc_topk_logits_list,
+        )
+
+    def _select_topk(
+        self,
+        memory: torch.Tensor,
+        outputs_logits: torch.Tensor,
+        outputs_coords_unact: torch.Tensor,
+        topk: int,
+    ):
+        if self.query_select_method == "default":
+            _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1)
+
+        elif self.query_select_method == "one2many":
+            _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1)
+            topk_ind = topk_ind // self.num_classes
+
+        elif self.query_select_method == "agnostic":
+            _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1)
+
+        topk_ind: torch.Tensor
+
+        topk_coords = outputs_coords_unact.gather(
+            dim=1,
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]),
+        )
+
+        topk_logits = outputs_logits.gather(
+            dim=1,
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]),
+        )
+
+        topk_memory = memory.gather(
+            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])
+        )
+
+        return topk_memory, topk_logits, topk_coords
+
+    def forward(self, feats, targets=None):
+        # input projection and embedding
+        memory, spatial_shapes = self._get_encoder_input(feats)
+        denoising_logits, denoising_bbox_unact, attn_mask = (
+            None,
+            None,
+            None,
+        )
+
+        (
+            init_ref_contents,
+            init_ref_points_unact,
+            enc_topk_bboxes_list,
+            enc_topk_logits_list,
+        ) = self._get_decoder_input(
+            memory, spatial_shapes, denoising_logits, denoising_bbox_unact
+        )
+
+        # decoder
+        out_bboxes, out_logits = self.decoder(
+            init_ref_contents,
+            init_ref_points_unact,
+            memory,
+            spatial_shapes,
+            self.dec_bbox_head,
+            self.dec_score_head,
+            self.query_pos_head,
+            attn_mask=attn_mask,
+        )
+
+        out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]}
+
+        return out
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [
+            {"pred_logits": a, "pred_boxes": b}
+            for a, b in zip(outputs_class, outputs_coord)
+        ]
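
To make the tensor shapes in RTDETRTransformerv2.forward concrete, here is a minimal smoke-test sketch (not part of the package). It assumes the wheel and its dependencies (torch, omegaconf) are installed and imports the module by the file path shown in this diff; the class count, batch size, and 640x640 input size are arbitrary illustration values.

```python
# Minimal smoke test for the decoder above; all concrete values are illustrative.
import torch
from yomitoku.models.layers.rtdetrv2_decoder import RTDETRTransformerv2

decoder = RTDETRTransformerv2(
    num_classes=3,                    # illustrative label count
    feat_channels=[512, 1024, 2048],
    feat_strides=[8, 16, 32],
    num_levels=3,
    num_queries=300,
    eval_spatial_size=(640, 640),     # needed so the anchors/valid_mask buffers are built
)
decoder.eval()

# Dummy multi-scale features for a 640x640 image: [B, C_i, 640/s_i, 640/s_i]
feats = [
    torch.randn(1, 512, 80, 80),      # stride 8
    torch.randn(1, 1024, 40, 40),     # stride 16
    torch.randn(1, 2048, 20, 20),     # stride 32
]

with torch.no_grad():
    out = decoder(feats)

print(out["pred_logits"].shape)       # torch.Size([1, 300, 3])
print(out["pred_boxes"].shape)        # torch.Size([1, 300, 4]), normalized cxcywh
```

In the real pipeline these features come from the backbone and hybrid encoder modules listed above (rtdetr_backbone.py and rtdetr_hybrid_encoder.py); the decoder then selects the top num_queries encoder positions via _select_topk and refines their boxes through the stacked TransformerDecoderLayer blocks.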