docling_ibm_models-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
docling_ibm_models/tableformer/test_prepare_image.py
@@ -0,0 +1,99 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import glob
+ import os
+
+ import numpy as np
+ from PIL import Image
+
+ import docling_ibm_models.tableformer.common as c
+ from docling_ibm_models.tableformer.data_management.data_transformer import (
+     DataTransformer,
+ )
+
+
+ def dump_np(img_np: np.ndarray, fn, n=6):
+     # Expect a numpy array for an image with the shape [channels, rows, columns]
+     s = img_np.shape
+     if len(s) != 3 or s[0] not in [1, 2, 3, 4]:
+         print("Image of invalid shape: {}".format(s))
+         return
+
+     channels = s[0]
+     rows = s[1]
+     cols = s[2]
+     w = n + 6
+     with open(fn, "w") as fd:
+         for r in range(rows):
+             for col in range(cols):
+                 for ch in range(channels):
+                     x = img_np[ch][r][col]
+                     if isinstance(x, np.float32):
+                         f_str = "0:>{}.{}f".format(w, n)
+                     elif isinstance(x, np.uint8):
+                         f_str = "0:>{}".format(w)
+                     else:
+                         return False
+
+                     x_str = ("{" + f_str + "}").format(x)
+                     fd.write(x_str)
+                     if ch < channels - 1:
+                         fd.write(" ")
+                 fd.write("\n")
+     return True
+
+
+ def dump_channels(save_dir, fn_prefix, img_np: np.ndarray):
+     # Dump the np array into 3 files, one per channel
+     img_np_ch0 = img_np[0, :, :]
+     img_np_ch1 = img_np[1, :, :]
+     img_np_ch2 = img_np[2, :, :]
+     txt_ch0_fn = os.path.join(save_dir, fn_prefix + "_ch0.txt")
+     txt_ch1_fn = os.path.join(save_dir, fn_prefix + "_ch1.txt")
+     txt_ch2_fn = os.path.join(save_dir, fn_prefix + "_ch2.txt")
+     np.savetxt(txt_ch0_fn, img_np_ch0)
+     np.savetxt(txt_ch1_fn, img_np_ch1)
+     np.savetxt(txt_ch2_fn, img_np_ch2)
+     print(f"{txt_ch0_fn}")
+     print(f"{txt_ch1_fn}")
+     print(f"{txt_ch2_fn}")
+
+
+ def prepare_image(config):
+     transformer = DataTransformer(config)
+     predict_dir = config["predict"]["predict_dir"]
+     use_normalization = config["dataset"]["image_normalization"]["state"]
+
+     pattern = os.path.join(predict_dir, "*.png")
+     for img_fn in glob.glob(pattern):
+         print(f"img_fn: {img_fn}")
+
+         with Image.open(img_fn) as img:
+             # Dump the initial image in txt files
+             img_np = np.array(img)
+
+             # Reshape the image in order to print it
+             img_np_m = np.moveaxis(img_np, 2, 0)
+             print(
+                 "orig. img_np.shape: {}, reshaped image: {}".format(
+                     img_np.shape, img_np_m.shape
+                 )
+             )
+             original_fn = img_fn + "_python.txt"
+             dump_np(img_np_m, original_fn)
+
+             r_img_ten = transformer.rescale_in_memory(img, use_normalization)
+             print("npimgc: {} - {}".format(r_img_ten.type(), r_img_ten.size()))
+
+             # Dump the processed image tensor in txt files
+             r_img_np = r_img_ten.numpy()
+
+             prepared_fn = img_fn + "_python_prepared.txt"
+             dump_np(r_img_np, prepared_fn)
+
+
+ if __name__ == "__main__":
+     config = c.parse_arguments()
+     prepare_image(config)
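For reference, dump_np above is self-contained and can be exercised on its own. A minimal sketch, assuming the wheel is installed so that the module is importable under the path listed in the RECORD; the output file path is a hypothetical placeholder:

    import numpy as np
    from docling_ibm_models.tableformer.test_prepare_image import dump_np

    # A 3-channel 2x2 float32 image in the [channels, rows, columns] layout
    # the function expects; dump_np writes one "R G B" line per pixel.
    img = np.random.rand(3, 2, 2).astype(np.float32)
    dump_np(img, "/tmp/img_dump.txt")  # hypothetical output path

prepare_image additionally needs a config dict with at least the "predict.predict_dir" and "dataset.image_normalization.state" keys it reads, plus whatever DataTransformer requires; it is driven via common.parse_arguments() in the __main__ block.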
docling_ibm_models/tableformer/utils/app_profiler.py
@@ -0,0 +1,243 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import time
+ from collections import deque
+ from statistics import mean, median
+
+
+ class SingletonClass(type):
+     r"""
+     Generic singleton metaclass
+     """
+
+     def __init__(self, name, bases, dic):
+         self._instance = None
+         super().__init__(name, bases, dic)
+
+     def __call__(cls, *args, **kwargs):
+         # Create the singleton instance on first use
+         if cls._instance is None:
+             singleton = cls.__new__(cls)
+             singleton.__init__(*args, **kwargs)
+             cls._instance = singleton
+         return cls._instance
+
+
+ class Profiler:
+     r"""
+     Application-specific profiler.
+     Decompose the application into "sections"; each section is a label.
+     The total time a section consumes is split into "intervals".
+     Use the `begin` and `end` methods to mark the beginning and end of an
+     interval for a certain section.
+     """
+
+     def __init__(self):
+         self._section_dts = {}  # section name -> sum(section intervals)
+         self._section_calls = {}  # section name -> number of invocations
+         self._section_kB = {}  # section name -> max kB of used heap
+
+         # section name -> beginning of the last interval
+         self._last_begin = {}
+
+     def begin(self, section_name, enable=True):
+         r"""
+         Mark the beginning of an interval
+
+         Parameters
+         ----------
+         section_name : string
+             Name of the section
+         enable : bool
+             The interval is recorded only if enable is true
+
+         Returns
+         -------
+         True if the interval has actually begun
+         """
+         if not enable:
+             return False
+         self._last_begin[section_name] = time.time()
+         return True
+
+     def end(self, section_name, enable=True):
+         r"""
+         Mark the end of an interval for a certain section
+
+         Parameters
+         ----------
+         section_name : string
+             Name of the section
+         enable : bool
+             The interval is recorded only if enable is true
+
+         Returns
+         -------
+         True if the section name is valid and an interval for this section
+         has already begun, False otherwise
+         """
+         if not enable:
+             return False
+         if section_name not in self._last_begin:
+             return False
+
+         dt = time.time() - self._last_begin[section_name]
+         if section_name not in self._section_dts:
+             self._section_dts[section_name] = dt
+             self._section_calls[section_name] = 1
+         else:
+             self._section_dts[section_name] += dt
+             self._section_calls[section_name] += 1
+
+         return True
+
+     def get_data(self, section_names=None):
+         r"""
+         Return a dict with profiling data for the specified sections.
+
+         Parameters
+         ----------
+         section_names : list of string
+             List with the section names to get their cumulative dt.
+             If it is None, all sections are returned.
+
+         Returns
+         -------
+         dict of dicts
+             Outer key: section name
+             Inner keys: "dt": cumulative time for the section,
+             "calls": number of calls, "kB": max kB of used heap
+         """
+         # Filter the section names to apply
+         filtered_names = list(
+             filter(lambda x: x in section_names, self._section_dts.keys())
+             if section_names is not None
+             else self._section_dts.keys()
+         )
+         data = {}
+         for section_name in filtered_names:
+             data[section_name] = {
+                 "dt": self._section_dts[section_name],
+                 "calls": self._section_calls[section_name],
+                 # _section_kB is never populated; default to 0 instead of
+                 # raising a KeyError
+                 "kB": self._section_kB.get(section_name, 0),
+             }
+         return data
+
+
+ class AppProfiler(Profiler, metaclass=SingletonClass):
+     r"""
+     AppProfiler is a singleton of the Profiler for application-wide usage
+     """
+
+     def __init__(self):
+         super(AppProfiler, self).__init__()
+
+
+ class AggProfiler(metaclass=SingletonClass):
+     r"""
+     Wrapper of Profiler that aggregates profiling statistics around cycles
+
+     - When a new cycle begins, a new Profiler is created to keep the
+       profiling data per section
+     - Keep the last n cycles in a sliding-window manner
+     - At any time we can get profiling data about the last cycle and
+       statistics over the last n cycles
+     """
+
+     def __init__(self, window_size=20):
+         self._window_size = window_size
+         # deque with up to the last "window_size" Profilers. The newest at index 0
+         self._cycles = deque()
+
+     def start_agg(self, enable=True):
+         r"""
+         Returns
+         -------
+         0: not enabled
+         1: a new cycle has started
+         """
+         if not enable:
+             return 0
+
+         # Add a new profiler
+         self._cycles.appendleft(Profiler())
+         # In case the deque has grown too much, remove the oldest Profiler
+         if len(self._cycles) > self._window_size:
+             self._cycles.pop()
+         return 1
+
+     def begin(self, section_name, enable=True):
+         if not enable:
+             return False
+         if len(self._cycles) == 0:
+             print("AggProfiler begin | Start Aggregator not initialized.")
+             return False
+         profiler = self._cycles[0]
+         return profiler.begin(section_name)
+
+     def end(self, section_name, enable=True):
+         if not enable:
+             return False
+         if len(self._cycles) == 0:
+             print("AggProfiler end | Start Aggregator not initialized.")
+             return False
+         profiler = self._cycles[0]
+         return profiler.end(section_name)
+
+     def get_data(self):
+         r"""
+         Get profiling data for:
+         - The last cycle
+         - Aggregated statistics (mean, median) per section and per metric
+           across all cycles
+         - The dt numbers for the mean/median are the average time per section
+           ACROSS the cycles; there is NO need to compute the average yourself
+
+         Returns
+         -------
+         dict with the structure:
+         - window: int with the size of the sliding window
+         - last: dict with the metrics for the last cycle (as provided by the Profiler)
+         - mean: dict with the mean metrics per section across the cycles
+             - section_name
+                 - metric_name: mean of the metric values
+         - median: dict with the median metrics per section across the cycles
+             - section_name
+                 - metric_name: median of the metric values
+         """
+         last_data = self._cycles[0].get_data()
+         data = {
+             "window": len(self._cycles),
+             "last": last_data,
+             "mean": {},
+             "median": {},
+         }
+
+         # Section -> metric -> [values]
+         section_metric_values = {}
+
+         # Collect the metrics
+         for p in self._cycles:
+             p_data = p.get_data()
+             for section_name, m_dict in p_data.items():
+                 for m_name, m_val in m_dict.items():
+                     if section_name not in section_metric_values:
+                         section_metric_values[section_name] = {}
+                     s_metrics = section_metric_values[section_name]
+                     if m_name not in s_metrics:
+                         s_metrics[m_name] = []
+                     s_metrics[m_name].append(m_val)
+
+         # Aggregate the metrics
+         for section_name, m_dict in section_metric_values.items():
+             for m_name, m_values in m_dict.items():
+                 if section_name not in data["mean"]:
+                     data["mean"][section_name] = {}
+                 if section_name not in data["median"]:
+                     data["median"][section_name] = {}
+
+                 data["mean"][section_name][m_name] = mean(m_values)
+                 data["median"][section_name][m_name] = median(m_values)
+
+         return data
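A minimal sketch of how the AggProfiler above is meant to be driven, based only on the methods shown (the "inference" section name and the sleep are illustrative stand-ins for real work):

    import time
    from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler

    profiler = AggProfiler()   # SingletonClass: every call returns the same instance
    profiler.start_agg()       # open a new cycle before any begin/end calls
    profiler.begin("inference")
    time.sleep(0.01)           # stand-in for the profiled work
    profiler.end("inference")

    stats = profiler.get_data()
    print(stats["window"], stats["last"]["inference"]["dt"])

Because AggProfiler is a singleton, code anywhere in the application can call AggProfiler().begin(...)/end(...) and contribute to the same sliding window of cycles.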
docling_ibm_models/tableformer/utils/torch_utils.py
@@ -0,0 +1,216 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import torch
+
+
+ def model_info(model, verbose=False):
+     # Prints a line-by-line description of a PyTorch model
+     n_p = sum(x.numel() for x in model.parameters())  # number of parameters
+     n_g = sum(
+         x.numel() for x in model.parameters() if x.requires_grad
+     )  # number of gradients
+     if verbose:
+         print(
+             "%5s %40s %9s %12s %20s %10s %10s"
+             % ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
+         )
+         for i, (name, p) in enumerate(model.named_parameters()):
+             name = name.replace("module_list.", "")
+             print(
+                 "%5g %40s %9s %12g %20s %10.3g %10.3g"
+                 % (
+                     i,
+                     name,
+                     p.requires_grad,
+                     p.numel(),
+                     list(p.shape),
+                     p.mean(),
+                     p.std(),
+                 )
+             )
+
+     try:  # FLOPS
+         from thop import profile
+
+         macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
+         fs = ", %.1f GFLOPS" % (macs / 1e9 * 2)
+     except Exception:
+         fs = ""
+
+     print(
+         "Model Summary: %g layers, %g parameters, %g gradients%s"
+         % (len(list(model.parameters())), n_p, n_g, fs)
+     )
+
+
+ # def init_seeds(seed=0):
+ #     torch.manual_seed(seed)
+ #
+ #     # Reduce randomness (may be slower on Tesla GPUs)
+ #     # https://pytorch.org/docs/stable/notes/randomness.html
+ #     if seed == 0:
+ #         cudnn.deterministic = False
+ #         cudnn.benchmark = True
+ #
+ #
+ # def select_device(device='', apex=False, batch_size=None):
+ #     # device = 'cpu' or '0' or '0,1,2,3'
+ #     cpu_request = device.lower() == 'cpu'
+ #     if device and not cpu_request:  # if device requested other than 'cpu'
+ #         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
+ #         # check availability
+ #         assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
+ #
+ #     cuda = False if cpu_request else torch.cuda.is_available()
+ #     if cuda:
+ #         c = 1024 ** 2  # bytes to MB
+ #         ng = torch.cuda.device_count()
+ #         if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
+ #             assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % \
+ #                 (batch_size, ng)
+ #         x = [torch.cuda.get_device_properties(i) for i in range(ng)]
+ #         # apex for mixed precision https://github.com/NVIDIA/apex
+ #         s = 'Using CUDA ' + ('Apex ' if apex else '')
+ #         for i in range(0, ng):
+ #             if i == 1:
+ #                 s = ' ' * len(s)
+ #             print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
+ #                   (s, i, x[i].name, x[i].total_memory / c))
+ #     else:
+ #         print('Using CPU')
+ #
+ #     print('')  # skip a line
+ #     return torch.device('cuda:0' if cuda else 'cpu')
+ #
+ #
+ # def time_synchronized():
+ #     torch.cuda.synchronize() if torch.cuda.is_available() else None
+ #     return time.time()
+ #
+ #
+ # def initialize_weights(model):
+ #     for m in model.modules():
+ #         t = type(m)
+ #         if t is nn.Conv2d:
+ #             pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ #         elif t is nn.BatchNorm2d:
+ #             m.eps = 1e-4
+ #             m.momentum = 0.03
+ #         elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+ #             m.inplace = True
+ #
+ #
+ # def find_modules(model, mclass=nn.Conv2d):
+ #     # finds layer indices matching module class 'mclass'
+ #     return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
+ #
+ #
+ # def fuse_conv_and_bn(conv, bn):
+ #     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ #     with torch.no_grad():
+ #         # init
+ #         fusedconv = torch.nn.Conv2d(conv.in_channels,
+ #                                     conv.out_channels,
+ #                                     kernel_size=conv.kernel_size,
+ #                                     stride=conv.stride,
+ #                                     padding=conv.padding,
+ #                                     bias=True)
+ #
+ #         # prepare filters
+ #         w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ #         w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ #         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+ #
+ #         # prepare spatial bias
+ #         if conv.bias is not None:
+ #             b_conv = conv.bias
+ #         else:
+ #             b_conv = torch.zeros(conv.weight.size(0))
+ #         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ #         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+ #
+ #         return fusedconv
+ #
+ #
+ # def load_classifier(name='resnet101', n=2):
+ #     # Loads a pretrained model reshaped to n-class output
+ #     import pretrainedmodels  # https://github.com/Cadene/pretrained-models.pytorch#torchvision
+ #     model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
+ #
+ #     # Display model properties
+ #     for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean',
+ #               'model.std']:
+ #         print(x + ' =', eval(x))
+ #
+ #     # Reshape output to n classes
+ #     filters = model.last_linear.weight.shape[1]
+ #     model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
+ #     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
+ #     model.last_linear.out_features = n
+ #     return model
+ #
+ #
+ # def scale_img(img, ratio=1.0, same_shape=True):  # img(16,3,256,416), r=ratio
+ #     # scales img(bs,3,y,x) by ratio
+ #     h, w = img.shape[2:]
+ #     s = (int(h * ratio), int(w * ratio))  # new size
+ #     img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
+ #     if not same_shape:  # pad/crop img
+ #         gs = 64  # (pixels) grid size
+ #         h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
+ #     return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
+ #
+ #
+ # class ModelEMA:
+ #     """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+ #     Keep a moving average of everything in the model state_dict (parameters and buffers).
+ #     This is intended to allow functionality like
+ #     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+ #     A smoothed version of the weights is necessary for some training schemes to perform well.
+ #     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+ #     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+ #     smoothing of weights to match results. Pay attention to the decay constant you are using
+ #     relative to your update count per epoch.
+ #     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+ #     disable validation of the EMA weights. Validation will have to be done manually in a separate
+ #     process, or after the training stops converging.
+ #     This class is sensitive where it is initialized in the sequence of model init,
+ #     GPU assignment and distributed training wrappers.
+ #     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and
+ #     single-GPU.
+ #     """
+ #
+ #     def __init__(self, model, decay=0.9999, device=''):
+ #         # make a copy of the model for accumulating moving average of weights
+ #         self.ema = deepcopy(model)
+ #         self.ema.eval()
+ #         self.updates = 0  # number of EMA updates
+ #         # decay exponential ramp (to help early epochs)
+ #         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
+ #         self.device = device  # perform ema on different device from model if set
+ #         if device:
+ #             self.ema.to(device=device)
+ #         for p in self.ema.parameters():
+ #             p.requires_grad_(False)
+ #
+ #     def update(self, model):
+ #         self.updates += 1
+ #         d = self.decay(self.updates)
+ #         with torch.no_grad():
+ #             if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+ #                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
+ #             else:
+ #                 msd, esd = model.state_dict(), self.ema.state_dict()
+ #
+ #             for k, v in esd.items():
+ #                 if v.dtype.is_floating_point:
+ #                     v *= d
+ #                     v += (1. - d) * msd[k].detach()
+ #
+ #     def update_attr(self, model):
+ #         # Assign attributes (which may change during training)
+ #         for k in model.__dict__.keys():
+ #             if not k.startswith('_'):
+ #                 setattr(self.ema, k, getattr(model, k))
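model_info is the only active function in this file; the rest is retained commented-out. A minimal sketch of calling it, using a tiny stand-in model (the layer sizes are arbitrary; if `thop` is not installed the GFLOPS part is silently skipped by the except branch above):

    import torch.nn as nn
    from docling_ibm_models.tableformer.utils.torch_utils import model_info

    # Tiny illustrative model; verbose=True prints one row per parameter tensor
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
    model_info(model, verbose=True)

Note that the FLOPS estimate assumes a 1x3x480x640 input, so it is only meaningful for models that accept 3-channel images of that size.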