docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,99 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import glob
|
6
|
+
import os
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
import docling_ibm_models.tableformer.common as c
|
12
|
+
from docling_ibm_models.tableformer.data_management.data_transformer import (
|
13
|
+
DataTransformer,
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
def dump_np(img_np: np.ndarray, fn, n=6):
    """Write a channel-first image array to a text file, one pixel per line.

    Each output line holds the values of all channels of one pixel, separated
    by single spaces and right-aligned to a width of ``n + 6`` characters.
    Pixels are visited row by row, column by column.

    Parameters
    ----------
    img_np : np.ndarray
        Image with shape ``[channels, rows, columns]`` and 1 to 4 channels.
        Only ``float32`` (written with ``n`` decimals) and ``uint8`` (written
        as integers) arrays are supported.
    fn : str
        Path of the text file to (over)write.
    n : int
        Number of decimals used for float32 values.

    Returns
    -------
    bool
        True if the file was written. False if the shape or dtype is not
        supported; in that case no file is created.
    """
    s = img_np.shape
    # Check the rank first: reading s[0] of a 0-d array would raise IndexError
    if len(s) != 3 or s[0] not in [1, 2, 3, 4]:
        print("Image of invalid shape: {}".format(s))
        return False

    channels, rows, cols = s
    w = n + 6
    # Pick the format from the dtype before opening the file so that an
    # unsupported array does not leave an empty/truncated file behind.
    if img_np.dtype == np.float32:
        fmt = "{{0:>{}.{}f}}".format(w, n)
    elif img_np.dtype == np.uint8:
        fmt = "{{0:>{}}}".format(w)
    else:
        return False

    with open(fn, "w") as fd:
        for r in range(rows):
            for col in range(cols):
                pixel = " ".join(
                    fmt.format(img_np[ch, r, col]) for ch in range(channels)
                )
                fd.write(pixel + "\n")
    return True
|
46
|
+
|
47
|
+
|
48
|
+
def dump_channels(save_dir, fn_prefix, img_np: np.ndarray, num_channels=3):
    """Save the leading channels of a channel-first array as text files.

    Channel ``k`` is written with ``np.savetxt`` to
    ``<save_dir>/<fn_prefix>_ch<k>.txt``. The file paths are printed after
    all files have been written.

    Parameters
    ----------
    save_dir : str
        Directory in which the text files are written.
    fn_prefix : str
        Prefix of the generated file names.
    img_np : np.ndarray
        Array indexed as ``img_np[channel, row, col]``.
    num_channels : int
        Number of leading channels to dump (default 3, the historical
        behavior).

    Raises
    ------
    IndexError
        If ``img_np`` has fewer than ``num_channels`` channels. This is raised
        before any file is written.
    """
    # Slice every channel first so a too-small array fails before any I/O
    channel_slices = [img_np[ch, :, :] for ch in range(num_channels)]
    txt_fns = [
        os.path.join(save_dir, fn_prefix + "_ch{}.txt".format(ch))
        for ch in range(num_channels)
    ]
    for txt_fn, ch_np in zip(txt_fns, channel_slices):
        np.savetxt(txt_fn, ch_np)
    for txt_fn in txt_fns:
        print(f"{txt_fn}")
|
62
|
+
|
63
|
+
|
64
|
+
def prepare_image(config):
    """Dump every PNG in the predict directory before and after preprocessing.

    Each ``*.png`` file in ``config["predict"]["predict_dir"]`` is visited in
    sorted order, so runs are reproducible. For each image, two text dumps
    are written next to it with ``dump_np``:

    - ``<img>_python.txt``: the raw decoded pixels, moved to channel-first
      layout.
    - ``<img>_python_prepared.txt``: the array obtained from the tensor
      returned by ``DataTransformer.rescale_in_memory``.

    Parameters
    ----------
    config : dict
        Parsed configuration. Reads ``config["predict"]["predict_dir"]`` and
        ``config["dataset"]["image_normalization"]["state"]``; the whole
        config is also passed to ``DataTransformer``.
    """
    transformer = DataTransformer(config)
    predict_dir = config["predict"]["predict_dir"]
    use_normalization = config["dataset"]["image_normalization"]["state"]

    pattern = os.path.join(predict_dir, "*.png")
    # glob returns files in arbitrary order; sort for reproducible output
    for img_fn in sorted(glob.glob(pattern)):
        print(f"img_fn: {img_fn}")

        with Image.open(img_fn) as img:
            # Dump the initial image in txt files
            img_np = np.array(img)

            # Move to channel-first [channels, rows, cols] as dump_np expects.
            # Grayscale/palette PNGs decode to a 2-D (rows, cols) array, so
            # atleast_3d first turns them into (rows, cols, 1).
            img_np_m = np.moveaxis(np.atleast_3d(img_np), 2, 0)
            print(
                "orig. img_np.shape: {}, reshaped image: {}".format(
                    img_np.shape, img_np_m.shape
                )
            )
            original_fn = img_fn + "_python.txt"
            if not dump_np(img_np_m, original_fn):
                print("Failed to dump: {}".format(original_fn))

            r_img_ten = transformer.rescale_in_memory(img, use_normalization)
            print("npimgc: {} - {}".format(r_img_ten.type(), r_img_ten.size()))

            # Dump the processed image tensor in txt files
            r_img_np = r_img_ten.numpy()

            prepared_fn = img_fn + "_python_prepared.txt"
            if not dump_np(r_img_np, prepared_fn):
                print("Failed to dump: {}".format(prepared_fn))
|
95
|
+
|
96
|
+
|
97
|
+
if __name__ == "__main__":
    # Build the config from the command line and run the image dump on it.
    prepare_image(c.parse_arguments())
|
File without changes
|
@@ -0,0 +1,243 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import time
|
6
|
+
from collections import deque
|
7
|
+
from statistics import mean, median
|
8
|
+
|
9
|
+
|
10
|
+
class SingletonClass(type):
    r"""
    Generic singleton metaclass.

    Each class created with ``metaclass=SingletonClass`` gets its own
    ``_instance`` slot. The first call to the class constructs the instance
    through the normal ``type.__call__`` protocol: ``__new__`` and
    ``__init__`` both receive the call arguments. Every later call returns
    that same instance and ignores its arguments.
    """

    def __init__(cls, name, bases, dic):
        # Per-class storage: every class using this metaclass keeps its own instance
        cls._instance = None
        super().__init__(name, bases, dic)

    def __call__(cls, *args, **kwargs):
        # Create the singleton on first use only
        if cls._instance is None:
            # Delegate to type.__call__ so __new__ also receives the arguments
            # and the standard construction protocol is followed.
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
|
26
|
+
|
27
|
+
|
28
|
+
class Profiler:
    r"""
    Application specific profiler

    Decompose the application into "sections". Each section is a label.
    The total time a section consumes is split into "intervals".
    Use the `begin`, `end` methods to mark the beginning and end of an interval for
    a certain section.
    """

    def __init__(self):
        self._section_dts = {}  # section name -> sum(section intervals), seconds
        self._section_calls = {}  # section name -> number of completed intervals
        # section name -> max kB of used heap. Nothing in this class fills it;
        # get_data() reports "kB" only for sections that have an entry here.
        self._section_kB = {}

        # section name -> start (time.perf_counter()) of the currently open interval
        self._last_begin = {}

    def begin(self, section_name, enable=True):
        r"""
        Mark the beginning of an interval

        Parameters
        ----------
        section_name : string
            Name of the section
        enable : bool
            The actual interval entry takes place only if enable is true

        Return
        ------
        True if the interval has actually begun
        """
        if not enable:
            return False
        # perf_counter is monotonic, unlike time.time(), so intervals are never negative
        self._last_begin[section_name] = time.perf_counter()
        return True

    def end(self, section_name, enable=True):
        r"""
        Mark the end of an interval for a certain section

        Parameters
        ----------
        section_name : string
            Name of the section
        enable : bool
            The actual interval entry takes place only if enable is true

        Return
        ------
        True if an interval for this section was open (and is now closed)
        False otherwise
        """
        if not enable:
            return False
        # Close the open interval so that a repeated end() cannot count it twice
        start = self._last_begin.pop(section_name, None)
        if start is None:
            return False

        dt = time.perf_counter() - start
        self._section_dts[section_name] = self._section_dts.get(section_name, 0.0) + dt
        self._section_calls[section_name] = self._section_calls.get(section_name, 0) + 1
        return True

    def get_data(self, section_names=None):
        r"""
        Return a dict with profiling data for the specified sections.

        Parameter
        ---------
        section_names : list of string
            List with the section names to get their accumulative dt
            If it is None, all sections are returned

        Return
        ------
        dict of dicts
            Outer key: section name
            Inner keys: "dt": Accumulative time for that section, "calls": Number of
            completed intervals, and "kB" only if a value was recorded for the section
        """
        # Filter the section names to apply
        if section_names is None:
            filtered_names = list(self._section_dts)
        else:
            filtered_names = [
                name for name in self._section_dts if name in section_names
            ]

        data = {}
        for section_name in filtered_names:
            entry = {
                "dt": self._section_dts[section_name],
                "calls": self._section_calls[section_name],
            }
            # _section_kB is not populated by this class; reading it unconditionally
            # raised KeyError for every recorded section.
            if section_name in self._section_kB:
                entry["kB"] = self._section_kB[section_name]
            data[section_name] = entry
        return data
|
126
|
+
|
127
|
+
|
128
|
+
class AppProfiler(Profiler, metaclass=SingletonClass):
    r"""
    Process-wide singleton flavour of `Profiler`.

    Every ``AppProfiler()`` call returns the same object, so all callers share
    one set of section timings.
    """

    def __init__(self):
        # Nothing extra to set up: just initialise the plain Profiler state.
        super().__init__()
|
135
|
+
|
136
|
+
|
137
|
+
class AggProfiler(metaclass=SingletonClass):
    r"""
    Generic wrapper of Profiler that enables aggregation of profiling statistics around cycles

    - When a new cycle begins (`start_agg`) a new Profiler is created to keep the profiling
      data per section
    - Only the last `window_size` cycles are kept, in a sliding window manner
    - At any time `get_data` returns profiling data about the last cycle and statistics over
      the retained cycles

    Being a singleton, only the arguments of the first ``AggProfiler(...)`` call take effect;
    a `window_size` passed on later calls is ignored.
    """

    def __init__(self, window_size=20):
        self._window_size = window_size
        # deque with up to the last "window_size" Profilers. The newest at index 0
        self._cycles = deque()

    def start_agg(self, enable=True):
        r"""
        Start a new aggregation cycle.

        Parameters
        ----------
        enable : bool
            The new cycle is started only if enable is true

        Returns
        -------
        0: not enabled
        1: a new scope has started
        """
        if not enable:
            return 0

        # Add a new profiler
        self._cycles.appendleft(Profiler())
        # In case the deque has grown too much, remove the oldest Profiler
        if len(self._cycles) > self._window_size:
            self._cycles.pop()
        return 1

    def begin(self, section_name, enable=True):
        r"""
        Mark the beginning of an interval for `section_name` in the newest cycle.

        Returns
        -------
        False if disabled or if no cycle has been started with `start_agg`,
        otherwise the result of `Profiler.begin` on the newest cycle.
        """
        if not enable:
            return False
        if len(self._cycles) == 0:
            print("AggProfiler begin | Start Aggregator not initialized.")
            return False
        profiler = self._cycles[0]
        return profiler.begin(section_name)

    def end(self, section_name, enable=True):
        r"""
        Mark the end of an interval for `section_name` in the newest cycle.

        Returns
        -------
        False if disabled or if no cycle has been started with `start_agg`,
        otherwise the result of `Profiler.end` on the newest cycle.
        """
        if not enable:
            return False
        if len(self._cycles) == 0:
            print("AggProfiler end | Start Aggregator not initialized.")
            return False
        profiler = self._cycles[0]
        return profiler.end(section_name)

    def get_data(self):
        r"""
        Get profiling data for:
        - The last cycle
        - Aggregated statistics (mean, median) per section and per metric across the
          retained cycles
        - The dt numbers for the mean/median are the per-cycle section times averaged
          ACROSS the cycles. There is NO need to compute the average by yourself.

        Returns
        -------
        dict with the structure:
        - window: int with the number of cycles currently retained (0 if `start_agg`
          was never called)
        - last: dict with the metrics for the last cycle (as provided by the Profiler)
        - mean: dict with the mean metrics per section across the cycles
            - section_name
                - metric_name: mean of the metric values
        - median: dict with the median metrics per section across the cycles
            - section_name
                - metric_name: median of the metric values
        """
        if not self._cycles:
            # No cycle started yet: return an empty result instead of raising IndexError
            return {"window": 0, "last": {}, "mean": {}, "median": {}}

        data = {
            "window": len(self._cycles),
            "last": self._cycles[0].get_data(),
            "mean": {},
            "median": {},
        }

        # Section -> metric -> [values], one value per retained cycle that recorded it
        section_metric_values = {}
        for profiler in self._cycles:
            for section_name, m_dict in profiler.get_data().items():
                for m_name, m_val in m_dict.items():
                    s_metrics = section_metric_values.setdefault(section_name, {})
                    s_metrics.setdefault(m_name, []).append(m_val)

        # Aggregate the metrics
        for section_name, m_dict in section_metric_values.items():
            data["mean"][section_name] = {
                m_name: mean(m_values) for m_name, m_values in m_dict.items()
            }
            data["median"][section_name] = {
                m_name: median(m_values) for m_name, m_values in m_dict.items()
            }

        return data
|
@@ -0,0 +1,216 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import torch
|
6
|
+
|
7
|
+
|
8
|
+
def model_info(model, verbose=False):
    """Print a summary of a PyTorch model's parameters.

    Always prints one "Model Summary" line containing:

    - the number of parameter tensors (reported as "layers");
    - the total number of parameter elements;
    - the number of elements that require gradients;
    - when the optional ``thop`` package can profile the model on a
      1x3x480x640 zero input, an estimate of GFLOPS.

    Parameters
    ----------
    model : torch.nn.Module
        Model to describe.
    verbose : bool
        When True, also print one row per named parameter with its gradient
        flag, element count, shape, mean and standard deviation.
    """
    params = list(model.parameters())
    n_p = sum(x.numel() for x in params)  # number parameters
    n_g = sum(x.numel() for x in params if x.requires_grad)  # number gradients
    if verbose:
        print(
            "%5s %40s %9s %12s %20s %10s %10s"
            % ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
        )
        # Statistics are only displayed, so do not record them in the autograd graph
        with torch.no_grad():
            for i, (name, p) in enumerate(model.named_parameters()):
                name = name.replace("module_list.", "")
                print(
                    "%5g %40s %9s %12g %20s %10.3g %10.3g"
                    % (
                        i,
                        name,
                        p.requires_grad,
                        p.numel(),
                        list(p.shape),
                        p.mean().item(),
                        p.std().item(),
                    )
                )

    try:  # FLOPS; best effort: thop is optional and may fail on this model
        from thop import profile

        # Create the dummy input on the model's device so GPU models can be profiled
        device = params[0].device if params else torch.device("cpu")
        dummy = torch.zeros(1, 3, 480, 640, device=device)
        macs, _ = profile(model, inputs=(dummy,), verbose=False)
        fs = ", %.1f GFLOPS" % (macs / 1e9 * 2)
    except Exception:
        fs = ""

    print(
        "Model Summary: %g layers, %g parameters, %g gradients%s"
        % (len(params), n_p, n_g, fs)
    )
|
46
|
+
|
47
|
+
|
48
|
+
# def init_seeds(seed=0):
|
49
|
+
# torch.manual_seed(seed)
|
50
|
+
#
|
51
|
+
# # Reduce randomness (may be slower on Tesla GPUs)
|
52
|
+
# # https://pytorch.org/docs/stable/notes/randomness.html
|
53
|
+
# if seed == 0:
|
54
|
+
# cudnn.deterministic = False
|
55
|
+
# cudnn.benchmark = True
|
56
|
+
#
|
57
|
+
#
|
58
|
+
# def select_device(device='', apex=False, batch_size=None):
|
59
|
+
# # device = 'cpu' or '0' or '0,1,2,3'
|
60
|
+
# cpu_request = device.lower() == 'cpu'
|
61
|
+
# if device and not cpu_request: # if device requested other than 'cpu'
|
62
|
+
# os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
63
|
+
# # check availability
|
64
|
+
# assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
|
65
|
+
#
|
66
|
+
# cuda = False if cpu_request else torch.cuda.is_available()
|
67
|
+
# if cuda:
|
68
|
+
# c = 1024 ** 2 # bytes to MB
|
69
|
+
# ng = torch.cuda.device_count()
|
70
|
+
# if ng > 1 and batch_size: # check that batch_size is compatible with device_count
|
71
|
+
# assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % \
|
72
|
+
# (batch_size, ng)
|
73
|
+
# x = [torch.cuda.get_device_properties(i) for i in range(ng)]
|
74
|
+
# # apex for mixed precision https://github.com/NVIDIA/apex
|
75
|
+
# s = 'Using CUDA ' + ('Apex ' if apex else '')
|
76
|
+
# for i in range(0, ng):
|
77
|
+
# if i == 1:
|
78
|
+
# s = ' ' * len(s)
|
79
|
+
# print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
|
80
|
+
# (s, i, x[i].name, x[i].total_memory / c))
|
81
|
+
# else:
|
82
|
+
# print('Using CPU')
|
83
|
+
#
|
84
|
+
# print('') # skip a line
|
85
|
+
# return torch.device('cuda:0' if cuda else 'cpu')
|
86
|
+
#
|
87
|
+
#
|
88
|
+
# def time_synchronized():
|
89
|
+
# torch.cuda.synchronize() if torch.cuda.is_available() else None
|
90
|
+
# return time.time()
|
91
|
+
#
|
92
|
+
#
|
93
|
+
# def initialize_weights(model):
|
94
|
+
# for m in model.modules():
|
95
|
+
# t = type(m)
|
96
|
+
# if t is nn.Conv2d:
|
97
|
+
# pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
98
|
+
# elif t is nn.BatchNorm2d:
|
99
|
+
# m.eps = 1e-4
|
100
|
+
# m.momentum = 0.03
|
101
|
+
# elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
102
|
+
# m.inplace = True
|
103
|
+
#
|
104
|
+
#
|
105
|
+
# def find_modules(model, mclass=nn.Conv2d):
|
106
|
+
# # finds layer indices matching module class 'mclass'
|
107
|
+
# return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
108
|
+
#
|
109
|
+
#
|
110
|
+
# def fuse_conv_and_bn(conv, bn):
|
111
|
+
# # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
112
|
+
# with torch.no_grad():
|
113
|
+
# # init
|
114
|
+
# fusedconv = torch.nn.Conv2d(conv.in_channels,
|
115
|
+
# conv.out_channels,
|
116
|
+
# kernel_size=conv.kernel_size,
|
117
|
+
# stride=conv.stride,
|
118
|
+
# padding=conv.padding,
|
119
|
+
# bias=True)
|
120
|
+
#
|
121
|
+
# # prepare filters
|
122
|
+
# w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
123
|
+
# w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
124
|
+
# fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
125
|
+
#
|
126
|
+
# # prepare spatial bias
|
127
|
+
# if conv.bias is not None:
|
128
|
+
# b_conv = conv.bias
|
129
|
+
# else:
|
130
|
+
# b_conv = torch.zeros(conv.weight.size(0))
|
131
|
+
# b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
132
|
+
# fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
133
|
+
#
|
134
|
+
# return fusedconv
|
135
|
+
#
|
136
|
+
#
|
137
|
+
# def load_classifier(name='resnet101', n=2):
|
138
|
+
# # Loads a pretrained model reshaped to n-class output
|
139
|
+
# import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
|
140
|
+
# model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
|
141
|
+
#
|
142
|
+
# # Display model properties
|
143
|
+
# for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean',
|
144
|
+
# 'model.std']:
|
145
|
+
# print(x + ' =', eval(x))
|
146
|
+
#
|
147
|
+
# # Reshape output to n classes
|
148
|
+
# filters = model.last_linear.weight.shape[1]
|
149
|
+
# model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
|
150
|
+
# model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
|
151
|
+
# model.last_linear.out_features = n
|
152
|
+
# return model
|
153
|
+
#
|
154
|
+
#
|
155
|
+
# def scale_img(img, ratio=1.0, same_shape=True): # img(16,3,256,416), r=ratio
|
156
|
+
# # scales img(bs,3,y,x) by ratio
|
157
|
+
# h, w = img.shape[2:]
|
158
|
+
# s = (int(h * ratio), int(w * ratio)) # new size
|
159
|
+
# img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
160
|
+
# if not same_shape: # pad/crop img
|
161
|
+
# gs = 64 # (pixels) grid size
|
162
|
+
# h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
|
163
|
+
# return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
164
|
+
#
|
165
|
+
#
|
166
|
+
# class ModelEMA:
|
167
|
+
# """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
168
|
+
# Keep a moving average of everything in the model state_dict (parameters and buffers).
|
169
|
+
# This is intended to allow functionality like
|
170
|
+
# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
171
|
+
# A smoothed version of the weights is necessary for some training schemes to perform well.
|
172
|
+
# E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
|
173
|
+
# RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
|
174
|
+
# smoothing of weights to match results. Pay attention to the decay constant you are using
|
175
|
+
# relative to your update count per epoch.
|
176
|
+
# To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
|
177
|
+
# disable validation of the EMA weights. Validation will have to be done manually in a separate
|
178
|
+
# process, or after the training stops converging.
|
179
|
+
# This class is sensitive where it is initialized in the sequence of model init,
|
180
|
+
# GPU assignment and distributed training wrappers.
|
181
|
+
# I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and
|
182
|
+
# single-GPU.
|
183
|
+
# """
|
184
|
+
#
|
185
|
+
# def __init__(self, model, decay=0.9999, device=''):
|
186
|
+
# # make a copy of the model for accumulating moving average of weights
|
187
|
+
# self.ema = deepcopy(model)
|
188
|
+
# self.ema.eval()
|
189
|
+
# self.updates = 0 # number of EMA updates
|
190
|
+
# # decay exponential ramp (to help early epochs)
|
191
|
+
# self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
|
192
|
+
# self.device = device # perform ema on different device from model if set
|
193
|
+
# if device:
|
194
|
+
# self.ema.to(device=device)
|
195
|
+
# for p in self.ema.parameters():
|
196
|
+
# p.requires_grad_(False)
|
197
|
+
#
|
198
|
+
# def update(self, model):
|
199
|
+
# self.updates += 1
|
200
|
+
# d = self.decay(self.updates)
|
201
|
+
# with torch.no_grad():
|
202
|
+
# if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
203
|
+
# msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
204
|
+
# else:
|
205
|
+
# msd, esd = model.state_dict(), self.ema.state_dict()
|
206
|
+
#
|
207
|
+
# for k, v in esd.items():
|
208
|
+
# if v.dtype.is_floating_point:
|
209
|
+
# v *= d
|
210
|
+
# v += (1. - d) * msd[k].detach()
|
211
|
+
#
|
212
|
+
# def update_attr(self, model):
|
213
|
+
# # Assign attributes (which may change during training)
|
214
|
+
# for k in model.__dict__.keys():
|
215
|
+
# if not k.startswith('_'):
|
216
|
+
# setattr(self.ema, k, getattr(model, k))
|