python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. doctr/datasets/cord.py +10 -1
  2. doctr/datasets/funsd.py +11 -1
  3. doctr/datasets/ic03.py +11 -1
  4. doctr/datasets/ic13.py +10 -1
  5. doctr/datasets/iiit5k.py +26 -16
  6. doctr/datasets/imgur5k.py +10 -1
  7. doctr/datasets/sroie.py +11 -1
  8. doctr/datasets/svhn.py +11 -1
  9. doctr/datasets/svt.py +11 -1
  10. doctr/datasets/synthtext.py +11 -1
  11. doctr/datasets/utils.py +7 -2
  12. doctr/datasets/vocabs.py +6 -2
  13. doctr/datasets/wildreceipt.py +12 -1
  14. doctr/file_utils.py +19 -0
  15. doctr/io/elements.py +12 -4
  16. doctr/models/builder.py +2 -2
  17. doctr/models/classification/magc_resnet/tensorflow.py +13 -6
  18. doctr/models/classification/mobilenet/pytorch.py +2 -0
  19. doctr/models/classification/mobilenet/tensorflow.py +14 -8
  20. doctr/models/classification/predictor/pytorch.py +11 -7
  21. doctr/models/classification/predictor/tensorflow.py +10 -6
  22. doctr/models/classification/resnet/tensorflow.py +21 -8
  23. doctr/models/classification/textnet/tensorflow.py +11 -5
  24. doctr/models/classification/vgg/tensorflow.py +9 -3
  25. doctr/models/classification/vit/tensorflow.py +10 -4
  26. doctr/models/classification/zoo.py +22 -10
  27. doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
  28. doctr/models/detection/fast/tensorflow.py +14 -11
  29. doctr/models/detection/linknet/tensorflow.py +23 -11
  30. doctr/models/detection/predictor/tensorflow.py +2 -2
  31. doctr/models/factory/hub.py +5 -6
  32. doctr/models/kie_predictor/base.py +4 -0
  33. doctr/models/kie_predictor/pytorch.py +4 -0
  34. doctr/models/kie_predictor/tensorflow.py +8 -1
  35. doctr/models/modules/transformer/tensorflow.py +0 -2
  36. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  37. doctr/models/modules/vision_transformer/tensorflow.py +1 -1
  38. doctr/models/predictor/base.py +24 -12
  39. doctr/models/predictor/pytorch.py +4 -0
  40. doctr/models/predictor/tensorflow.py +8 -1
  41. doctr/models/preprocessor/tensorflow.py +1 -1
  42. doctr/models/recognition/crnn/tensorflow.py +8 -6
  43. doctr/models/recognition/master/tensorflow.py +9 -4
  44. doctr/models/recognition/parseq/tensorflow.py +10 -8
  45. doctr/models/recognition/sar/tensorflow.py +7 -3
  46. doctr/models/recognition/vitstr/tensorflow.py +9 -4
  47. doctr/models/utils/pytorch.py +1 -1
  48. doctr/models/utils/tensorflow.py +15 -15
  49. doctr/transforms/functional/pytorch.py +1 -1
  50. doctr/transforms/modules/pytorch.py +7 -6
  51. doctr/transforms/modules/tensorflow.py +15 -12
  52. doctr/utils/geometry.py +106 -19
  53. doctr/utils/metrics.py +1 -1
  54. doctr/utils/reconstitution.py +151 -65
  55. doctr/version.py +1 -1
  56. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
  57. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
  58. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
  59. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
  60. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
  61. {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/tensorflow.py CHANGED
@@ -12,7 +12,7 @@ from tensorflow.keras import Model, layers
 from doctr.datasets import VOCABS
 
 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
@@ -23,14 +23,14 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
     },
     "vitstr_base": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
     },
 }
 
@@ -216,9 +216,14 @@ def _vitstr(
 
     # Build the model
     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-        load_pretrained_params(model, default_cfgs[arch]["url"])
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
doctr/models/utils/pytorch.py CHANGED
@@ -157,7 +157,7 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     """
     torch.onnx.export(
         model,
-        dummy_input,
+        dummy_input,  # type: ignore[arg-type]
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
doctr/models/utils/tensorflow.py CHANGED
@@ -4,9 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-import os
 from typing import Any, Callable, List, Optional, Tuple, Union
-from zipfile import ZipFile
 
 import tensorflow as tf
 import tf2onnx
@@ -19,6 +17,7 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG)
 
 __all__ = [
     "load_pretrained_params",
+    "_build_model",
     "conv_sequence",
     "IntermediateLayerGetter",
     "export_model_to_onnx",
@@ -36,41 +35,42 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
     return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x
 
 
+def _build_model(model: Model):
+    """Build a model by calling it once with dummy input
+
+    Args:
+    ----
+        model: the model to be built
+    """
+    model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
+
+
 def load_pretrained_params(
     model: Model,
     url: Optional[str] = None,
     hash_prefix: Optional[str] = None,
-    overwrite: bool = False,
-    internal_name: str = "weights",
+    skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
     ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
-        overwrite: should the zip extraction be enforced if the archive has already been extracted
-        internal_name: name of the ckpt files
+        skip_mismatch: skip loading layers with mismatched shapes
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
     if url is None:
         logging.warning("Invalid model URL, using default initialization.")
     else:
         archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
-
-        # Unzip the archive
-        params_path = archive_path.parent.joinpath(archive_path.stem)
-        if not params_path.is_dir() or overwrite:
-            with ZipFile(archive_path, "r") as f:
-                f.extractall(path=params_path)
-
         # Load weights
-        model.load_weights(f"{params_path}{os.sep}{internal_name}")
+        model.load_weights(archive_path, skip_mismatch=skip_mismatch)
 
 
 def conv_sequence(
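Note: pretrained weights now ship as a single Keras `.weights.h5` file instead of a zipped TF checkpoint, and a model has to be built (its variables created by one forward pass) before `load_weights` can map the file onto it. A hand-rolled sketch of what the factory functions now do internally, assuming the TensorFlow backend is active:

from doctr.models import vitstr_small
from doctr.models.utils import _build_model, load_pretrained_params

model = vitstr_small(pretrained=False)
# Forward pass on tf.zeros((1, *model.cfg["input_shape"])); redundant here
# since the factory already builds, shown for illustration
_build_model(model)
load_pretrained_params(
    model,
    url="https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
)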
doctr/transforms/functional/pytorch.py CHANGED
@@ -89,7 +89,7 @@ def rotate_sample(
     rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2]
     rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1]
 
-    return rotated_img, np.clip(rotated_geoms, 0, 1)
+    return rotated_img, np.clip(np.around(rotated_geoms, decimals=15), 0, 1)
 
 
 def crop_detection(
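Note: rounding to 15 decimals before clipping scrubs floating-point noise from the rotation trigonometry; a coordinate that is mathematically 0 or 1 can come back as 5e-17 or 1 minus one ulp, and clipping alone leaves that noise in the relative geometries. A tiny self-contained illustration:

import numpy as np

geoms = np.array([5.551115123125783e-17, 0.9999999999999999])
print(np.clip(geoms, 0, 1))                          # noise survives
print(np.clip(np.around(geoms, decimals=15), 0, 1))  # [0. 1.]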
doctr/transforms/modules/pytorch.py CHANGED
@@ -74,16 +74,18 @@ class Resize(T.Resize):
             if self.symmetric_pad:
                 half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2))
                 _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1])
+            # Pad image
             img = pad(img, _pad)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
                         target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
                     else:
@@ -91,16 +93,15 @@ class Resize(T.Resize):
                         target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]
                         target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
                         target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
                     else:
                         target[..., 0] *= raw_shape[-1] / img.shape[-1]
                         target[..., 1] *= raw_shape[-2] / img.shape[-2]
                 else:
-                    raise AssertionError
-            return img, target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return img, np.clip(target, 0, 1)
 
         return img
 
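Note: two behavioral fixes land here: the symmetric-padding offset is now computed whenever padding is applied, instead of only when `np.max(target) <= 1` happened to hold, and the returned boxes are clipped to the valid [0, 1] range. A minimal usage sketch, assuming the PyTorch backend is active (shapes are illustrative):

import numpy as np
import torch
from doctr.transforms import Resize

transform = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True)
img = torch.rand(3, 32, 128)  # (C, H, W): a wide crop that needs vertical padding
boxes = np.array([[0.1, 0.2, 0.9, 0.8]], dtype=np.float32)  # relative (xmin, ymin, xmax, ymax)

out_img, out_boxes = transform(img, boxes)
assert out_boxes.min() >= 0 and out_boxes.max() <= 1  # boxes shifted by the pad offset, then clipped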
 
doctr/transforms/modules/tensorflow.py CHANGED
@@ -107,29 +107,34 @@ class Resize(NestedObject):
         target: Optional[np.ndarray] = None,
     ) -> Union[tf.Tensor, Tuple[tf.Tensor, np.ndarray]]:
         input_dtype = img.dtype
+        self.output_size = (
+            (self.output_size, self.output_size) if isinstance(self.output_size, int) else self.output_size
+        )
 
         img = tf.image.resize(img, self.wanted_size, self.method, self.preserve_aspect_ratio, self.antialias)
         # It will produce an un-padded resized image, with a side shorter than wanted if we preserve aspect ratio
         raw_shape = img.shape[:2]
+        if self.symmetric_pad:
+            half_pad = (int((self.output_size[0] - img.shape[0]) / 2), 0)
         if self.preserve_aspect_ratio:
             if isinstance(self.output_size, (tuple, list)):
                 # In that case we need to pad because we want to enforce both width and height
                 if not self.symmetric_pad:
-                    offset = (0, 0)
+                    half_pad = (0, 0)
                 elif self.output_size[0] == img.shape[0]:
-                    offset = (0, int((self.output_size[1] - img.shape[1]) / 2))
-                else:
-                    offset = (int((self.output_size[0] - img.shape[0]) / 2), 0)
-                img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size)
+                    half_pad = (0, int((self.output_size[1] - img.shape[1]) / 2))
+                # Pad image
+                img = tf.image.pad_to_bounding_box(img, *half_pad, *self.output_size)
 
         # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
         if target is not None:
+            if self.symmetric_pad:
+                offset = half_pad[0] / img.shape[0], half_pad[1] / img.shape[1]
+
             if self.preserve_aspect_ratio:
                 # Get absolute coords
                 if target.shape[1:] == (4,):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[:, [0, 2]] = offset[1] + target[:, [0, 2]] * raw_shape[1] / img.shape[1]
                         target[:, [1, 3]] = offset[0] + target[:, [1, 3]] * raw_shape[0] / img.shape[0]
                     else:
@@ -137,16 +142,15 @@ class Resize(NestedObject):
                         target[:, [1, 3]] *= raw_shape[0] / img.shape[0]
                 elif target.shape[1:] == (4, 2):
                     if isinstance(self.output_size, (tuple, list)) and self.symmetric_pad:
-                        if np.max(target) <= 1:
-                            offset = offset[0] / img.shape[0], offset[1] / img.shape[1]
                         target[..., 0] = offset[1] + target[..., 0] * raw_shape[1] / img.shape[1]
                         target[..., 1] = offset[0] + target[..., 1] * raw_shape[0] / img.shape[0]
                     else:
                         target[..., 0] *= raw_shape[1] / img.shape[1]
                         target[..., 1] *= raw_shape[0] / img.shape[0]
                 else:
-                    raise AssertionError
-            return tf.cast(img, dtype=input_dtype), target
+                    raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")
+
+            return tf.cast(img, dtype=input_dtype), np.clip(target, 0, 1)
 
         return tf.cast(img, dtype=input_dtype)
 
@@ -395,7 +399,6 @@ class GaussianBlur(NestedObject):
     def extra_repr(self) -> str:
         return f"kernel_shape={self.kernel_shape}, std={self.std}"
 
-    @tf.function
     def __call__(self, img: tf.Tensor) -> tf.Tensor:
         return tf.squeeze(
             _gaussian_filter(
doctr/utils/geometry.py CHANGED
@@ -20,6 +20,7 @@ __all__ = [
     "rotate_boxes",
     "compute_expanded_shape",
     "rotate_image",
+    "remove_image_padding",
     "estimate_page_angle",
     "convert_to_relative_coords",
     "rotate_abs_geoms",
@@ -351,6 +352,26 @@ def rotate_image(
     return rot_img
 
 
+def remove_image_padding(image: np.ndarray) -> np.ndarray:
+    """Remove black border padding from an image
+
+    Args:
+    ----
+        image: numpy tensor to remove padding from
+
+    Returns:
+    -------
+        Image with padding removed
+    """
+    # Find the bounding box of the non-black region
+    rows = np.any(image, axis=1)
+    cols = np.any(image, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    return image[rmin : rmax + 1, cmin : cmax + 1]
+
+
 def estimate_page_angle(polys: np.ndarray) -> float:
     """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the
     estimated angle ccw in degrees
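Note: `remove_image_padding` finds the bounding box of all non-zero pixels and slices the image down to it, undoing the black borders that `rotate_image` can introduce when expanding the canvas. A quick behavioral sketch:

import numpy as np
from doctr.utils.geometry import remove_image_padding

# 6x6 black canvas with a 2x2 bright patch in the middle
padded = np.zeros((6, 6), dtype=np.uint8)
padded[2:4, 2:4] = 255

print(remove_image_padding(padded).shape)  # (2, 2): only the non-black region remains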
@@ -431,7 +452,7 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True
 
 
 def extract_rcrops(
-    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
+    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True, assume_horizontal: bool = False
 ) -> List[np.ndarray]:
     """Created cropped images from list of rotated bounding boxes
 
@@ -441,6 +462,7 @@ def extract_rcrops(
         polys: bounding boxes of shape (N, 4, 2)
         dtype: target data type of bounding boxes
         channels_last: whether the channel dimensions is the last one instead of the last one
+        assume_horizontal: whether the boxes are assumed to be only horizontally oriented
 
     Returns:
     -------
@@ -458,22 +480,87 @@ def extract_rcrops(
     _boxes[:, :, 0] *= width
     _boxes[:, :, 1] *= height
 
-    src_pts = _boxes[:, :3].astype(np.float32)
-    # Preserve size
-    d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
-    d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
-    # (N, 3, 2)
-    dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
-    dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
-    dst_pts[:, 2, 1] = d2 - 1
-    # Use a warp transformation to extract the crop
-    crops = [
-        cv2.warpAffine(
-            img if channels_last else img.transpose(1, 2, 0),
-            # Transformation matrix
-            cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
-            (int(d1[idx]), int(d2[idx])),
-        )
-        for idx in range(_boxes.shape[0])
-    ]
+    src_img = img if channels_last else img.transpose(1, 2, 0)
+
+    # Handle only horizontal oriented boxes
+    if assume_horizontal:
+        crops = []
+
+        for box in _boxes:
+            # Calculate the centroid of the quadrilateral
+            centroid = np.mean(box, axis=0)
+
+            # Divide the points into left and right
+            left_points = box[box[:, 0] < centroid[0]]
+            right_points = box[box[:, 0] >= centroid[0]]
+
+            # Sort the left points according to the y-axis
+            left_points = left_points[np.argsort(left_points[:, 1])]
+            top_left_pt = left_points[0]
+            bottom_left_pt = left_points[-1]
+            # Sort the right points according to the y-axis
+            right_points = right_points[np.argsort(right_points[:, 1])]
+            top_right_pt = right_points[0]
+            bottom_right_pt = right_points[-1]
+            box_points = np.array(
+                [top_left_pt, bottom_left_pt, top_right_pt, bottom_right_pt],
+                dtype=dtype,
+            )
+
+            # Get the width and height of the rectangle that will contain the warped quadrilateral
+            width_upper = np.linalg.norm(top_right_pt - top_left_pt)
+            width_lower = np.linalg.norm(bottom_right_pt - bottom_left_pt)
+            height_left = np.linalg.norm(bottom_left_pt - top_left_pt)
+            height_right = np.linalg.norm(bottom_right_pt - top_right_pt)
+
+            # Get the maximum width and height
+            rect_width = max(int(width_upper), int(width_lower))
+            rect_height = max(int(height_left), int(height_right))
+
+            dst_pts = np.array(
+                [
+                    [0, 0],  # top-left
+                    # bottom-left
+                    [0, rect_height - 1],
+                    # top-right
+                    [rect_width - 1, 0],
+                    # bottom-right
+                    [rect_width - 1, rect_height - 1],
+                ],
+                dtype=dtype,
+            )
+
+            # Get the perspective transform matrix using the box points
+            affine_mat = cv2.getPerspectiveTransform(box_points, dst_pts)
+
+            # Perform the perspective warp to get the rectified crop
+            crop = cv2.warpPerspective(
+                src_img,
+                affine_mat,
+                (rect_width, rect_height),
+            )
+
+            # Add the crop to the list of crops
+            crops.append(crop)
+
+    # Handle any oriented boxes
+    else:
+        src_pts = _boxes[:, :3].astype(np.float32)
+        # Preserve size
+        d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
+        d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1)
+        # (N, 3, 2)
+        dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype)
+        dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1
+        dst_pts[:, 2, 1] = d2 - 1
+        # Use a warp transformation to extract the crop
+        crops = [
+            cv2.warpAffine(
+                src_img,
+                # Transformation matrix
                cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
                (int(d1[idx]), int(d2[idx])),
            )
            for idx in range(_boxes.shape[0])
        ]
     return crops  # type: ignore[return-value]
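Note: the new `assume_horizontal` path orders the four corners explicitly (left/right of the centroid, then top/bottom by y) and applies a four-point perspective warp instead of the three-point affine one, which avoids flipped or rotated crops for text known to be horizontal. A minimal sketch with one skewed quadrilateral in relative coordinates (values are illustrative):

import numpy as np
from doctr.utils.geometry import extract_rcrops

img = np.random.randint(0, 255, (128, 256, 3), dtype=np.uint8)
# One box, corners in arbitrary order, relative (x, y) coordinates
polys = np.array([[[0.1, 0.3], [0.6, 0.28], [0.61, 0.42], [0.11, 0.44]]])

crops = extract_rcrops(img, polys, assume_horizontal=True)
print(crops[0].shape)  # roughly (18, 128, 3): the max side lengths of the warped quad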
doctr/utils/metrics.py CHANGED
@@ -149,7 +149,7 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
     right = np.minimum(r1, r2.T)
     bot = np.minimum(b1, b2.T)
 
-    intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf)
+    intersection = np.clip(right - left, 0, np.inf) * np.clip(bot - top, 0, np.inf)
     union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection
     iou_mat = intersection / union
 
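Note: `np.Inf` is one of the aliases removed in NumPy 2.0 (only the lowercase `np.inf` spelling remains), so this one-character change keeps `box_iou` working under NumPy 2.x:

import numpy as np

print(np.clip(np.array([-0.5, 2.0]), 0, np.inf))  # [0. 2.]
# On NumPy >= 2.0, np.Inf raises AttributeError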