python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/io/image/tensorflow.py
DELETED
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
# Copyright (C) 2021-2025, Mindee.
|
|
2
|
-
|
|
3
|
-
# This program is licensed under the Apache License 2.0.
|
|
4
|
-
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import tensorflow as tf
|
|
9
|
-
from PIL import Image
|
|
10
|
-
from tensorflow.keras.utils import img_to_array
|
|
11
|
-
|
|
12
|
-
from doctr.utils.common_types import AbstractPath
|
|
13
|
-
|
|
14
|
-
__all__ = [
    "tensor_from_pil",
    "read_img_as_tensor",
    "decode_img_as_tensor",
    "tensor_from_numpy",
    "get_img_shape",
]
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a PIL Image to a TensorFlow tensor

    Args:
        pil_img: a PIL image
        dtype: the output tensor data type. If it is float-related, values will be divided by 255.

    Returns:
        decoded image as tensor
    """
    # `img_to_array` defaults to float32 output, keeping the raw [0, 255] range. `tensor_from_numpy` expects
    # uint8 input: `tf.image.convert_image_dtype` only rescales integer inputs, so a float32 array in [0, 255]
    # would be passed through unscaled and then saturated by the clip to [0, 1]. Request uint8 explicitly.
    npy_img = img_to_array(pil_img, dtype="uint8")

    return tensor_from_numpy(npy_img, dtype)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read an image file as a TensorFlow tensor

    Args:
        img_path: location of the image file (converted with `str()`, so path objects are accepted)
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
        decoded image as a tensor with 3 channels

    Raises:
        ValueError: if `dtype` is not one of tf.uint8, tf.float16 or tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        raise ValueError("unsupported value for dtype")

    # `tf.io.read_file` expects a string filename: normalise path objects before handing them to TF
    img = tf.io.read_file(str(img_path))
    # `channels=3` forces a 3-channel decoded image
    img = tf.image.decode_jpeg(img, channels=3)

    if dtype != tf.uint8:
        # Converting the decoded uint8 image to a float dtype rescales values to [0, 1]
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read a byte stream as a TensorFlow tensor

    Args:
        img_content: bytes of a decoded image
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
        decoded image as a tensor of shape (H, W, 3)

    Raises:
        ValueError: if `dtype` is not one of tf.uint8, tf.float16 or tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        raise ValueError("unsupported value for dtype")

    # `expand_animations=False` keeps the output 3-D (first frame only for animated GIFs), matching the
    # (H, W, C) layout the rest of this module assumes; otherwise animated GIFs decode to a 4-D tensor.
    img = tf.io.decode_image(img_content, channels=3, expand_animations=False)

    if dtype != tf.uint8:
        # Converting the decoded uint8 image to a float dtype rescales values to [0, 1]
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a numpy image array to a TensorFlow tensor

    Args:
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
        same image as a tensor of shape (H, W, C)

    Raises:
        ValueError: if `dtype` is not one of tf.uint8, tf.float16 or tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        raise ValueError("unsupported value for dtype")

    if dtype == tf.uint8:
        img = tf.convert_to_tensor(npy_img, dtype=dtype)
    else:
        # `convert_image_dtype` only rescales integer inputs (uint8 -> [0, 1]); float inputs are cast as-is,
        # which is why a uint8 array is expected here.
        img = tf.image.convert_image_dtype(npy_img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def get_img_shape(img: tf.Tensor) -> tuple[int, int]:
    """Return the two leading dimensions of an image tensor's shape.

    For the channels-last (H, W, C) tensors handled in this module, this is (height, width).

    Args:
        img: image tensor

    Returns:
        the first two entries of `img.shape`
    """
    leading_dims = img.shape[0:2]
    return leading_dims
|
|
@@ -1,196 +0,0 @@
|
|
|
1
|
-
# Copyright (C) 2021-2025, Mindee.
|
|
2
|
-
|
|
3
|
-
# This program is licensed under the Apache License 2.0.
|
|
4
|
-
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
|
-
|
|
6
|
-
import math
|
|
7
|
-
from copy import deepcopy
|
|
8
|
-
from functools import partial
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
import tensorflow as tf
|
|
12
|
-
from tensorflow.keras import activations, layers
|
|
13
|
-
from tensorflow.keras.models import Sequential
|
|
14
|
-
|
|
15
|
-
from doctr.datasets import VOCABS
|
|
16
|
-
|
|
17
|
-
from ...utils import _build_model, load_pretrained_params
|
|
18
|
-
from ..resnet.tensorflow import ResNet
|
|
19
|
-
|
|
20
|
-
__all__ = ["magc_resnet31"]


# Per-architecture configuration: normalization statistics, expected input shape, output classes and the
# location of the pretrained weights.
default_cfgs: dict[str, dict[str, Any]] = {
    "magc_resnet31": dict(
        mean=(0.694, 0.695, 0.693),  # per-channel mean (presumably used for input normalization — confirm)
        std=(0.299, 0.296, 0.301),  # per-channel std (presumably used for input normalization — confirm)
        input_shape=(32, 32, 3),
        classes=list(VOCABS["french"]),
        url="https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
    ),
}
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class MAGC(layers.Layer):
    """Implements the Multi-Aspect Global Context Attention, as described in
    `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition",
    <https://arxiv.org/pdf/1910.02562.pdf>`_.

    Args:
        inplanes: input channels
        headers: number of headers to split channels
        attn_scale: if True, re-scale attention to counteract the variance distributions
        ratio: bottleneck ratio
        **kwargs: keyword arguments forwarded to `layers.Layer`

    Raises:
        ValueError: if `inplanes` is not divisible by `headers`
    """

    def __init__(
        self,
        inplanes: int,
        headers: int = 8,
        attn_scale: bool = False,
        ratio: float = 0.0625,  # bottleneck ratio of 1/16 as described in paper
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # `context_modeling` splits the channels evenly across the heads: reject a remainder here instead of
        # letting it surface later as an opaque reshape error at call time.
        if inplanes % headers != 0:
            raise ValueError(f"inplanes ({inplanes}) must be divisible by headers ({headers})")

        self.headers = headers  # h
        self.inplanes = inplanes  # C
        self.attn_scale = attn_scale
        self.ratio = ratio
        self.planes = int(inplanes * ratio)

        self.single_header_inplanes = inplanes // headers  # C / h

        # 1x1 conv producing a single attention logit per spatial position (applied per head)
        self.conv_mask = layers.Conv2D(filters=1, kernel_size=1, kernel_initializer=tf.initializers.he_normal())

        # Bottleneck transform of the pooled context: C -> planes -> C
        self.transform = Sequential(
            [
                layers.Conv2D(filters=self.planes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
                layers.LayerNormalization([1, 2, 3]),
                layers.ReLU(),
                layers.Conv2D(filters=self.inplanes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
            ],
            name="transform",
        )

    def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
        """Pool the input into a global context vector using one attention map per head.

        Args:
            inputs: feature map of shape (B, H, W, C)

        Returns:
            context tensor of shape (B, 1, 1, C)
        """
        # NOTE: in the shape comments below, lower-case `h` is the number of heads,
        # whereas the local variable `h` is the spatial height.
        b, h, w, c = (tf.shape(inputs)[i] for i in range(4))

        # B, H, W, C -->> B*h, H, W, C/h
        x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes))
        x = tf.transpose(x, perm=(0, 3, 1, 2, 4))
        x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes))

        # Compute shortcut
        shortcut = x
        # B*h, 1, H*W, C/h
        shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes))
        # B*h, 1, C/h, H*W
        shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2])

        # Compute context mask
        # B*h, H, W, 1
        context_mask = self.conv_mask(x)
        # B*h, 1, H*W, 1
        context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1))
        # scale variance
        if self.attn_scale and self.headers > 1:
            context_mask = context_mask / math.sqrt(self.single_header_inplanes)
        # B*h, 1, H*W, 1 -- softmax over the spatial positions
        context_mask = activations.softmax(context_mask, axis=2)

        # Compute context
        # B*h, 1, C/h, 1
        context = tf.matmul(shortcut, context_mask)
        # Regroup the heads back into the channel axis: B, 1, C, 1
        context = tf.reshape(context, shape=(b, 1, c, 1))
        # B, 1, 1, C
        context = tf.transpose(context, perm=(0, 1, 3, 2))
        # Set shape to resolve shape when calling this module in the Sequential MAGCResnet
        batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1]
        context.set_shape([batch, 1, 1, chan])
        return context

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """Add the transformed global context (broadcast over H and W) to `inputs`."""
        # Context modeling: B, H, W, C -> B, 1, 1, C
        context = self.context_modeling(inputs)
        # Transform: B, 1, 1, C -> B, 1, 1, C
        transformed = self.transform(context, **kwargs)
        return inputs + transformed
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def _magc_resnet(
    arch: str,
    pretrained: bool,
    num_blocks: list[int],
    output_channels: list[int],
    stage_downsample: list[bool],
    stage_conv: list[bool],
    stage_pooling: list[tuple[int, int] | None],
    origin_stem: bool = True,
    **kwargs: Any,
) -> ResNet:
    """Build a ResNet whose attention module is MAGC, optionally loading pretrained weights.

    Args:
        arch: key of the architecture in `default_cfgs`
        pretrained: whether to load the pretrained parameters referenced by the config URL
        num_blocks: number of blocks per stage
        output_channels: output channels per stage
        stage_downsample: per-stage downsampling flags
        stage_conv: per-stage convolution flags
        stage_pooling: per-stage pooling configuration
        origin_stem: forwarded to `ResNet`
        **kwargs: extra keyword arguments of `ResNet`; `num_classes`, `input_shape` and `classes`
            default to the values of the architecture config

    Returns:
        the built model
    """
    arch_cfg = default_cfgs[arch]
    default_classes = arch_cfg["classes"]

    # Resolve user overrides; `classes` only goes into the config, never into the ResNet constructor
    num_classes = kwargs.get("num_classes", len(default_classes))
    input_shape = kwargs.get("input_shape", arch_cfg["input_shape"])
    classes = kwargs.pop("classes", default_classes)
    kwargs["num_classes"] = num_classes
    kwargs["input_shape"] = input_shape

    model_cfg = deepcopy(arch_cfg)
    model_cfg.update(num_classes=num_classes, classes=classes, input_shape=input_shape)

    # Build the model
    model = ResNet(
        num_blocks,
        output_channels,
        stage_downsample,
        stage_conv,
        stage_pooling,
        origin_stem,
        attn_module=partial(MAGC, headers=8, attn_scale=True),
        cfg=model_cfg,
        **kwargs,
    )
    _build_model(model)

    if not pretrained:
        return model

    # A different number of classes than the checkpoint means the mismatching layers
    # must be skipped when loading (fine-tuning setting)
    load_pretrained_params(model, arch_cfg["url"], skip_mismatch=num_classes != len(default_classes))

    return model
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """Resnet31 architecture with Multi-Aspect Global Context Attention as described in
    `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition",
    <https://arxiv.org/pdf/1910.02562.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import magc_resnet31
    >>> model = magc_resnet31(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the ResNet architecture

    Returns:
        A feature extractor model
    """
    return _magc_resnet(
        "magc_resnet31",
        pretrained,
        num_blocks=[1, 2, 5, 3],
        output_channels=[256, 256, 512, 512],
        stage_downsample=[False] * 4,
        stage_conv=[True] * 4,
        stage_pooling=[(2, 2), (2, 1), None, None],
        origin_stem=False,
        stem_channels=128,
        **kwargs,
    )
|