yomitoku 0.5.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yomitoku/layout_parser.py CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union
 
 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist
 
+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import LayoutParserRTDETRv2Config
 from .models import RTDETRv2
@@ -91,6 +96,7 @@ class LayoutParser(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(model_name, path_cfg, from_pretrained)
@@ -98,7 +104,6 @@ class LayoutParser(BaseModule):
         self.visualize = visualize
 
         self.model.eval()
-        self.model.to(self.device)
 
         self.postprocessor = RTDETRPostProcessor(
             num_classes=self._cfg.RTDETRTransformerv2.num_classes,
@@ -119,11 +124,49 @@ class LayoutParser(BaseModule):
         }
 
         self.role = self._cfg.role
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
 
     def preprocess(self, img):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img = Image.fromarray(cv_img)
-        img_tensor = self.transforms(img)[None].to(self.device)
+        img_tensor = self.transforms(img)[None]
         return img_tensor
 
     def postprocess(self, preds, image_size):
@@ -175,8 +218,19 @@ class LayoutParser(BaseModule):
         ori_h, ori_w = img.shape[:2]
         img_tensor = self.preprocess(img)
 
-        with torch.inference_mode():
-            preds = self.model(img_tensor)
+        if self.infer_onnx:
+            input = img_tensor.numpy()
+            results = self.sess.run(None, {"input": input})
+            preds = {
+                "pred_logits": torch.tensor(results[0]).to(self.device),
+                "pred_boxes": torch.tensor(results[1]).to(self.device),
+            }
+
+        else:
+            with torch.inference_mode():
+                img_tensor = img_tensor.to(self.device)
+                preds = self.model(img_tensor)
+
         results = self.postprocess(preds, (ori_h, ori_w))
 
         vis = None
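
Note on the hunks above: passing infer_onnx=True makes LayoutParser export its weights to {ROOT_DIR}/onnx/<repo-name>.onnx on first use (via the new convert_onnx method) and run inference through ONNX Runtime, preferring CUDAExecutionProvider when CUDA is available; the default infer_onnx=False keeps the original torch.inference_mode() path. A minimal usage sketch under those assumptions (the image file name is illustrative, and the (results, vis) return shape follows the module's existing call contract, which is not part of this hunk):

    import cv2
    from yomitoku.layout_parser import LayoutParser

    # infer_onnx=True triggers a one-time ONNX export, then inference runs
    # through onnxruntime instead of the PyTorch model.
    parser = LayoutParser(device="cuda", visualize=False, infer_onnx=True)

    img = cv2.imread("sample.jpg")  # BGR image, as expected by preprocess()
    results, vis = parser(img)      # assumed return shape; not shown in this hunk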
@@ -59,9 +59,7 @@ class ConvNormLayer(nn.Module):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()
 
         self.shortcut = shortcut
@@ -100,9 +98,7 @@ class BasicBlock(nn.Module):
 class BottleNeck(nn.Module):
     expansion = 4
 
-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()
 
         if variant == "a":
@@ -125,17 +121,13 @@ class BottleNeck(nn.Module):
                             ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                             (
                                 "conv",
-                                ConvNormLayer(
-                                    ch_in, ch_out * self.expansion, 1, 1
-                                ),
+                                ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1),
                             ),
                         ]
                     )
                 )
             else:
-                self.short = ConvNormLayer(
-                    ch_in, ch_out * self.expansion, 1, stride
-                )
+                self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
 
         self.act = nn.Identity() if act is None else get_activation(act)
 
@@ -156,9 +148,7 @@ class BottleNeck(nn.Module):
 
 
 class Blocks(nn.Module):
-    def __init__(
-        self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"
-    ):
+    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
         super().__init__()
 
         self.blocks = nn.ModuleList()
@@ -252,9 +252,7 @@ class HybridEncoder(nn.Module):
         for in_channel in in_channels:
             if version == "v1":
                 proj = nn.Sequential(
-                    nn.Conv2d(
-                        in_channel, hidden_dim, kernel_size=1, bias=False
-                    ),
+                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                     nn.BatchNorm2d(hidden_dim),
                 )
             elif version == "v2":
@@ -290,9 +288,7 @@ class HybridEncoder(nn.Module):
 
         self.encoder = nn.ModuleList(
             [
-                TransformerEncoder(
-                    copy.deepcopy(encoder_layer), num_encoder_layers
-                )
+                TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
                 for _ in range(len(use_encoder_idx))
             ]
         )
@@ -347,9 +343,7 @@ class HybridEncoder(nn.Module):
             # self.register_buffer(f'pos_embed{idx}', pos_embed)
 
     @staticmethod
-    def build_2d_sincos_position_embedding(
-        w, h, embed_dim=256, temperature=10000.0
-    ):
+    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
         """ """
         grid_w = torch.arange(int(w), dtype=torch.float32)
         grid_h = torch.arange(int(h), dtype=torch.float32)
@@ -387,9 +381,7 @@ class HybridEncoder(nn.Module):
                     src_flatten.device
                 )
 
-                memory: torch.Tensor = self.encoder[i](
-                    src_flatten, pos_embed=pos_embed
-                )
+                memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
                 proj_feats[enc_ind] = (
                     memory.permute(0, 2, 1)
                     .reshape(-1, self.hidden_dim, h, w)
@@ -401,13 +393,9 @@ class HybridEncoder(nn.Module):
         for idx in range(len(self.in_channels) - 1, 0, -1):
             feat_heigh = inner_outs[0]
             feat_low = proj_feats[idx - 1]
-            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
-                feat_heigh
-            )
+            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
             inner_outs[0] = feat_heigh
-            upsample_feat = F.interpolate(
-                feat_heigh, scale_factor=2.0, mode="nearest"
-            )
+            upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
             inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                 torch.concat([upsample_feat, feat_low], dim=1)
             )
@@ -40,9 +40,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
 
 
 class MLP(nn.Module):
-    def __init__(
-        self, input_dim, hidden_dim, output_dim, num_layers, act="relu"
-    ):
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"):
         super().__init__()
         self.num_layers = num_layers
         h = [hidden_dim] * (num_layers - 1)
@@ -193,9 +191,7 @@ class MSDeformableAttention(nn.Module):
         elif reference_points.shape[-1] == 4:
             # reference_points [8, 480, None, 1, 4]
             # sampling_offsets [8, 480, 8, 12, 2]
-            num_points_scale = self.num_points_scale.to(
-                dtype=query.dtype
-            ).unsqueeze(-1)
+            num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1)
             offset = (
                 sampling_offsets
                 * num_points_scale
@@ -330,9 +326,7 @@ def deformable_attention_core_func_v2(
     _, Len_q, _, _, _ = sampling_locations.shape
 
     split_shape = [h * w for h, w in value_spatial_shapes]
-    value_list = (
-        value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
-    )
+    value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1)
 
     # sampling_offsets [8, 480, 8, 12, 2]
     if method == "default":
@@ -361,8 +355,7 @@ def deformable_attention_core_func_v2(
         elif method == "discrete":
             # n * m, seq, n, 2
             sampling_coord = (
-                sampling_grid_l * torch.tensor([[w, h]], device=value.device)
-                + 0.5
+                sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5
             ).to(torch.int64)
 
             # FIX ME? for rectangle input
@@ -389,9 +382,7 @@ def deformable_attention_core_func_v2(
     attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(
         bs * n_head, 1, Len_q, sum(num_points_list)
     )
-    weighted_sample_locs = (
-        torch.concat(sampling_value_list, dim=-1) * attn_weights
-    )
+    weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights
    output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q)
 
    return output.permute(0, 2, 1)
@@ -606,9 +597,7 @@ class RTDETRTransformerv2(nn.Module):
                         [
                             (
                                 "conv",
-                                nn.Conv2d(
-                                    in_channels, self.hidden_dim, 1, bias=False
-                                ),
+                                nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False),
                             ),
                             (
                                 "norm",
@@ -689,13 +678,9 @@ class RTDETRTransformerv2(nn.Module):
                 torch.arange(h), torch.arange(w), indexing="ij"
             )
             grid_xy = torch.stack([grid_x, grid_y], dim=-1)
-            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor(
-                [w, h], dtype=dtype
-            )
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype)
             wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
-            lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(
-                -1, h * w, 4
-            )
+            lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4)
             anchors.append(lvl_anchors)
 
         anchors = torch.concat(anchors, dim=1).to(device)
@@ -729,22 +714,18 @@ class RTDETRTransformerv2(nn.Module):
         )
 
         enc_topk_bboxes_list, enc_topk_logits_list = [], []
-        enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = (
-            self._select_topk(
-                output_memory,
-                enc_outputs_logits,
-                enc_outputs_coord_unact,
-                self.num_queries,
-            )
+        enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk(
+            output_memory,
+            enc_outputs_logits,
+            enc_outputs_coord_unact,
+            self.num_queries,
         )
 
         # if self.num_select_queries != self.num_queries:
         #     raise NotImplementedError('')
 
         if self.learn_query_content:
-            content = self.tgt_embed.weight.unsqueeze(0).tile(
-                [memory.shape[0], 1, 1]
-            )
+            content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1])
         else:
             content = enc_topk_memory.detach()
@@ -771,9 +752,7 @@ class RTDETRTransformerv2(nn.Module):
         topk: int,
     ):
         if self.query_select_method == "default":
-            _, topk_ind = torch.topk(
-                outputs_logits.max(-1).values, topk, dim=-1
-            )
+            _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1)
 
         elif self.query_select_method == "one2many":
             _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1)
@@ -786,16 +765,12 @@ class RTDETRTransformerv2(nn.Module):
 
         topk_coords = outputs_coords_unact.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(
-                1, 1, outputs_coords_unact.shape[-1]
-            ),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]),
         )
 
         topk_logits = outputs_logits.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(
-                1, 1, outputs_logits.shape[-1]
-            ),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]),
         )
 
         topk_memory = memory.gather(
yomitoku/models/parseq.py CHANGED
@@ -22,7 +22,6 @@ from huggingface_hub import PyTorchModelHubMixin
 from timm.models.helpers import named_apply
 from torch import Tensor
 
-from ..postprocessor import ParseqTokenizer as Tokenizer
 from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding
 
 
@@ -123,7 +122,6 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
 
     def forward(
         self,
-        tokenizer: Tokenizer,
         images: Tensor,
         max_length: Optional[int] = None,
     ) -> Tensor:
@@ -150,11 +148,11 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
         if self.decode_ar:
             tgt_in = torch.full(
                 (bs, num_steps),
-                tokenizer.pad_id,
+                self.tokenizer.pad_id,
                 dtype=torch.long,
                 device=self._device,
             )
-            tgt_in[:, 0] = tokenizer.bos_id
+            tgt_in[:, 0] = self.tokenizer.bos_id
 
             logits = []
             for i in range(num_steps):
@@ -177,7 +175,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 # greedy decode. add the next token index to the target input
                 tgt_in[:, j] = p_i.squeeze().argmax(-1)
                 # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
-                if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
+                if testing and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all():
                     break
 
             logits = torch.cat(logits, dim=1)
@@ -185,7 +183,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # No prior context, so input is just <bos>. We query all positions.
             tgt_in = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -200,7 +198,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                     torch.ones(
                         num_steps,
                         num_steps,
-                        dtype=torch.bool,
+                        dtype=torch.int64,
                         device=self._device,
                     ),
                     2,
@@ -208,7 +206,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             ] = 0
             bos = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -216,7 +214,9 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 # Prior context is the previous output.
                 tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                 # Mask tokens beyond the first EOS token.
-                tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
+                tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
+                    -1
+                ) > 0
                 tgt_out = self.decode(
                     tgt_in,
                     memory,
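
In the hunks above, PARSeq.forward drops its tokenizer parameter and reads pad_id, bos_id and eos_id from self.tokenizer instead, so callers now pass only the image batch. For reference, the padding-mask expression being rewrapped in the last hunk marks every position at or after the first EOS token; a small self-contained illustration of that expression (the token ids below are made up):

    import torch

    eos_id = 0  # illustrative id; the real value comes from the model's tokenizer
    tgt_in = torch.tensor([[5, 7, 0, 3, 0],
                           [4, 4, 4, 0, 2]])

    # Positions at or after the first EOS are masked out.
    tgt_padding_mask = (tgt_in == eos_id).int().cumsum(-1) > 0
    # tensor([[False, False,  True,  True,  True],
    #         [False, False, False,  True,  True]])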
yomitoku/ocr.py CHANGED
@@ -16,16 +16,37 @@ class WordPrediction(BaseSchema):
     )
     content: str
     direction: str
-    det_score: float
     rec_score: float
+    det_score: float
 
 
 class OCRSchema(BaseSchema):
     words: List[WordPrediction]
 
 
+def ocr_aggregate(det_outputs, rec_outputs):
+    words = []
+    for points, det_score, pred, rec_score, direction in zip(
+        det_outputs.points,
+        det_outputs.scores,
+        rec_outputs.contents,
+        rec_outputs.scores,
+        rec_outputs.directions,
+    ):
+        words.append(
+            {
+                "points": points,
+                "content": pred,
+                "direction": direction,
+                "det_score": det_score,
+                "rec_score": rec_score,
+            }
+        )
+    return words
+
+
 class OCR:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(self, configs={}, device="cuda", visualize=False):
         text_detector_kwargs = {
             "device": device,
             "visualize": visualize,
36
57
  }
37
58
 
38
59
  if isinstance(configs, dict):
39
- assert (
40
- "text_detector" in configs or "text_recognizer" in configs
41
- ), "Invalid config key. Please check the config keys."
42
-
43
60
  if "text_detector" in configs:
44
61
  text_detector_kwargs.update(configs["text_detector"])
45
62
  if "text_recognizer" in configs:
@@ -52,26 +69,6 @@ class OCR:
52
69
  self.detector = TextDetector(**text_detector_kwargs)
53
70
  self.recognizer = TextRecognizer(**text_recognizer_kwargs)
54
71
 
55
- def aggregate(self, det_outputs, rec_outputs):
56
- words = []
57
- for points, det_score, pred, rec_score, direction in zip(
58
- det_outputs.points,
59
- det_outputs.scores,
60
- rec_outputs.contents,
61
- rec_outputs.scores,
62
- rec_outputs.directions,
63
- ):
64
- words.append(
65
- {
66
- "points": points,
67
- "content": pred,
68
- "direction": direction,
69
- "det_score": det_score,
70
- "rec_score": rec_score,
71
- }
72
- )
73
- return words
74
-
75
72
  def __call__(self, img):
76
73
  """_summary_
77
74
 
@@ -82,6 +79,6 @@ class OCR:
82
79
  det_outputs, vis = self.detector(img)
83
80
  rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)
84
81
 
85
- outputs = {"words": self.aggregate(det_outputs, rec_outputs)}
82
+ outputs = {"words": ocr_aggregate(det_outputs, rec_outputs)}
86
83
  results = OCRSchema(**outputs)
87
84
  return results, vis
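
Note on the ocr.py changes: word aggregation moves from the removed OCR.aggregate method to the module-level ocr_aggregate(det_outputs, rec_outputs) helper, and configs now defaults to {} with the key assertion removed, so an empty or partial config is accepted without error. A short usage sketch based only on the classes shown in this diff (the image path is illustrative):

    import cv2
    from yomitoku.ocr import OCR

    ocr = OCR(device="cuda", visualize=False)  # configs defaults to {}
    img = cv2.imread("document.jpg")           # BGR image
    results, vis = ocr(img)                    # returns (OCRSchema, visualization)

    # Each entry mirrors the dict built by ocr_aggregate().
    for word in results.words:
        print(word.content, word.direction, word.det_score, word.rec_score)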
yomitoku/onnx/.gitkeep ADDED
File without changes
@@ -54,16 +54,12 @@ class RTDETRPostProcessor(nn.Module):
         logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
         # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
 
-        bbox_pred = torchvision.ops.box_convert(
-            boxes, in_fmt="cxcywh", out_fmt="xyxy"
-        )
+        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
         bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
 
         if self.use_focal_loss:
             scores = F.sigmoid(logits)
-            scores, index = torch.topk(
-                scores.flatten(1), self.num_top_queries, dim=-1
-            )
+            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
             # TODO for older tensorrt
             # labels = index % self.num_classes
             labels = mod(index, self.num_classes)
@@ -77,9 +73,7 @@ class RTDETRPostProcessor(nn.Module):
             scores = F.softmax(logits)[:, :, :-1]
             scores, labels = scores.max(dim=-1)
             if scores.shape[1] > self.num_top_queries:
-                scores, index = torch.topk(
-                    scores, self.num_top_queries, dim=-1
-                )
+                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                 labels = torch.gather(labels, dim=1, index=index)
                 boxes = torch.gather(
                     boxes,
@@ -97,10 +91,7 @@ class RTDETRPostProcessor(nn.Module):
 
         labels = (
             torch.tensor(
-                [
-                    mscoco_label2category[int(x.item())]
-                    for x in labels.flatten()
-                ]
+                [mscoco_label2category[int(x.item())] for x in labels.flatten()]
             )
             .to(boxes.device)
             .reshape(labels.shape)