megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (132)
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
@@ -1,314 +0,0 @@
1
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
- """
3
- PyTorch utils
4
- """
5
-
6
- import math
7
- import os
8
- import platform
9
- import subprocess
10
- import time
11
- import warnings
12
- from contextlib import contextmanager
13
- from copy import deepcopy
14
- from pathlib import Path
15
-
16
- import torch
17
- import torch.distributed as dist
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
-
21
- from utils.general import LOGGER, file_date, git_describe
22
-
23
- try:
24
- import thop # for FLOPs computation
25
- except ImportError:
26
- thop = None
27
-
28
- # Suppress PyTorch warnings
29
- warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
30
-
31
-
32
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets local rank 0 run the body before the other local ranks.

    Ranks other than -1 and 0 wait on a barrier before entering the body; rank 0 reaches
    the matching barrier only after its body has finished, which releases them.
    local_rank == -1 (non-distributed run) skips both barriers.
    """
    must_wait = local_rank not in [-1, 0]
    if must_wait:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
40
-
41
-
42
def device_count():
    """Return the number of CUDA devices listed by `nvidia-smi -L`, or 0 if that fails.

    This avoids calling into torch/CUDA in this process by shelling out to nvidia-smi.
    Only Linux and Windows are supported (asserted).
    """
    system = platform.system()
    assert system in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
    if system == 'Linux':
        cmd = 'nvidia-smi -L | wc -l'
    else:  # Windows
        cmd = 'nvidia-smi -L | find /c /v ""'
    try:
        completed = subprocess.run(cmd, shell=True, capture_output=True, check=True)
        return int(completed.stdout.decode().split()[-1])
    except Exception:
        return 0
50
-
51
-
52
def select_device(device='', batch_size=0, newline=True):
    """Choose the torch device to run on and log a one-line environment summary.

    Args:
        device: None, 'cpu', 'mps', a CUDA index (0 or '0'), a comma-separated list of
            CUDA indices ('0,1,2,3'), or any of these prefixed with 'cuda:'.
        batch_size: when > 0 and several CUDA devices are requested, it must be a
            multiple of the number of devices (asserted).
        newline: if False, the trailing newline is stripped from the logged summary.

    Returns:
        torch.device('cuda:0'), torch.device('mps') or torch.device('cpu').

    Side effects:
        Sets os.environ['CUDA_VISIBLE_DEVICES'] to '-1' for cpu/mps, or to the requested
        id string for CUDA requests.
    """
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        # Count requested device ids, not characters: '10' is one device and '0,1' is two.
        n_requested = len(device.split(','))
        assert torch.cuda.is_available() and torch.cuda.device_count() >= n_requested, \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else '0'  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            # Query by position i, not id d: with CUDA_VISIBLE_DEVICES set, visible devices are renumbered from 0
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
    elif mps:
        s += 'MPS\n'
    else:
        s += 'CPU\n'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu')
84
-
85
-
86
def time_sync():
    """Return time.time(), first waiting for queued CUDA work when CUDA is available.

    Synchronizing makes wall-clock deltas include pending GPU kernels.
    """
    if not torch.cuda.is_available():
        return time.time()
    torch.cuda.synchronize()
    return time.time()
91
-
92
-
93
def profile(input, ops, n=10, device=None):
    """YOLOv5 speed/memory/FLOPs profiler for one or more ops on one or more inputs.

    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations

    Args:
        input: a tensor, or a list of tensors; every op is run on every input.
        ops: a callable / nn.Module, or a list of them.
        n: number of timed forward/backward iterations per (input, op) pair.
        device: a torch.device, or anything accepted by select_device().

    Returns:
        One entry per (input, op) pair, input-major: [params, GFLOPs, reserved GPU memory (GB),
        forward ms, backward ms, input shape, output shape], or None if that pair raised.
        Note that each input has requires_grad set to True and modules are moved (and
        possibly halved) in place.
    """
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    inputs = input if isinstance(input, list) else [input]
    modules = ops if isinstance(ops, list) else [ops]
    results = []
    for x in inputs:
        x = x.to(device)
        x.requires_grad = True
        for m in modules:
            if hasattr(m, 'to'):
                m = m.to(device)
            if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16:
                m = m.half()
            fwd_ms, bwd_ms = 0, 0
            stamps = [0, 0, 0]  # start, after forward, after backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:  # thop missing or op unsupported
                flops = 0

            try:
                for _ in range(n):
                    stamps[0] = time_sync()
                    y = m(x)
                    stamps[1] = time_sync()
                    try:
                        total = sum(yi.sum() for yi in y) if isinstance(y, list) else y
                        total.sum().backward()
                        stamps[2] = time_sync()
                    except Exception:  # no backward method
                        stamps[2] = float('nan')
                    fwd_ms += (stamps[1] - stamps[0]) * 1000 / n  # ms per op forward
                    bwd_ms += (stamps[2] - stamps[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(v.shape) if isinstance(v, torch.Tensor) else 'list' for v in (x, y))  # shapes
                p = sum(w.numel() for w in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{fwd_ms:14.4g}{bwd_ms:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, fwd_ms, bwd_ms, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
143
-
144
-
145
def is_parallel(model):
    """Return True if `model` is exactly an nn.DataParallel or nn.DistributedDataParallel wrapper.

    The check is on the exact type, so subclasses of these wrappers are not recognized.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    model_type = type(model)
    return any(model_type is t for t in wrapper_types)
148
-
149
-
150
def de_parallel(model):
    """Return the wrapped module of a DP/DDP model (per is_parallel); return any other model as-is."""
    if is_parallel(model):
        return model.module
    return model
153
-
154
-
155
def initialize_weights(model):
    """Apply YOLOv5's default settings to every submodule of `model`, in place.

    - nn.BatchNorm2d: eps=1e-3, momentum=0.03.
    - nn.Hardswish / LeakyReLU / ReLU / ReLU6 / SiLU: inplace=True.
    - nn.Conv2d: left with its default initialization.
    Matching uses the exact type, so subclasses of these modules are not touched.
    """
    inplace_activations = [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]
    for module in model.modules():
        module_type = type(module)
        if module_type is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif module_type in inplace_activations:
            module.inplace = True
        # nn.Conv2d is intentionally left alone (e.g. no kaiming_normal_ re-init)
165
-
166
-
167
def find_modules(model, mclass=nn.Conv2d):
    """Return the indices of entries in `model.module_list` that are instances of `mclass`."""
    matches = []
    for index, module in enumerate(model.module_list):
        if isinstance(module, mclass):
            matches.append(index)
    return matches
170
-
171
-
172
def sparsity(model):
    """Return the fraction of `model`'s parameter elements that are exactly zero.

    For a model with parameters the result is a 0-dim tensor (zero count / element count).
    A model with no parameters returns 0.0 instead of raising ZeroDivisionError on 0 / 0.
    """
    total, zeros = 0, 0
    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum()
    if total == 0:
        return 0.0
    return zeros / total
179
-
180
-
181
def prune(model, amount=0.3):
    """L1-prune the weight of every nn.Conv2d in `model`, in place, and print the resulting sparsity.

    Pruning is per layer: each Conv2d weight has its `amount` fraction of smallest-magnitude
    elements zeroed (torch.nn.utils.prune.l1_unstructured), then the mask is made permanent
    (prune.remove). Other parameters are untouched, so the printed global sparsity is
    generally below `amount`.
    """
    import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function's name
    print('Pruning model... ', end='')
    conv_layers = (module for module in model.modules() if isinstance(module, nn.Conv2d))
    for conv in conv_layers:
        torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # prune
        torch_prune.remove(conv, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
190
-
191
-
192
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d into the Conv2d that feeds it, returning one equivalent Conv2d.

    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    The fused layer reproduces ``bn(conv(x))`` using ``bn``'s running statistics, so it
    matches the pair in eval mode. Every geometric setting of ``conv`` (kernel size, stride,
    padding, dilation, groups, padding mode) is carried over. Without dilation and padding
    mode, dilated or reflect/replicate-padded convs would fuse to a layer with different
    outputs or output shapes.

    Args:
        conv: the nn.Conv2d to fuse (with or without bias).
        bn: the nn.BatchNorm2d applied directly to ``conv``'s output; ``bn.weight`` and
            ``bn.bias`` are read, so it is assumed to be affine.

    Returns:
        A new nn.Conv2d with a bias, on ``conv.weight``'s device, with requires_grad disabled.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

    # Fold without recording autograd history, so the fused parameters stay plain frozen
    # tensors instead of becoming nodes in conv/bn's graph.
    with torch.no_grad():
        # Prepare filters: W_fused = diag(gamma / sqrt(var + eps)) @ W_conv
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

        # Prepare spatial bias: b_fused = diag(...) @ b_conv + (beta - gamma * mean / sqrt(var + eps))
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
213
-
214
-
215
def model_info(model, verbose=False, img_size=640):
    """Log a one-line summary: layer count, parameter/gradient counts and, if thop works, GFLOPs.

    Args:
        model: an nn.Module; `model.stride`, `model.yaml` and `model.yaml_file` are used when present.
        verbose: also print a per-parameter table (name, requires_grad, numel, shape, mean, std).
        img_size: int or [h, w], e.g. 640 or [640, 320]; GFLOPs measured on a stride-sized
            probe image are scaled up to this size.
    """
    params = list(model.parameters())
    n_p = sum(w.numel() for w in params)  # number parameters
    n_g = sum(w.numel() for w in params if w.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for idx, (pname, w) in enumerate(model.named_parameters()):
            pname = pname.replace('module_list.', '')
            row = (idx, pname, w.requires_grad, w.numel(), list(w.shape), w.mean(), w.std())
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' % row)

    # Best effort: any failure (thop missing, no model.yaml, unsupported op) just omits GFLOPs.
    try:
        from thop import profile as thop_profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        probe = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)
        flops = thop_profile(deepcopy(model), inputs=(probe,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPs' % (flops * size[0] / stride * size[1] / stride)  # 640x640 GFLOPs
    except Exception:
        fs = ''

    if hasattr(model, 'yaml_file'):
        name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5')
    else:
        name = 'Model'
    LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
238
-
239
-
240
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """Resize an image batch img(bs, 3, h, w) by `ratio`, then pad or crop it.

    With same_shape=False the result is padded/cropped so each spatial side is the
    smallest multiple of `gs` that fits the scaled size. With same_shape=True it is
    padded/cropped back to the input (h, w). Padding uses 0.447 (ImageNet mean).
    ratio == 1.0 returns `img` itself.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    new_h, new_w = int(h * ratio), int(w * ratio)
    img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
    if same_shape:
        target_h, target_w = h, w
    else:  # round the scaled size up to a multiple of gs
        target_h, target_w = (math.ceil(side * ratio / gs) * gs for side in (h, w))
    # Negative padding amounts crop instead of pad
    return F.pad(img, [0, target_w - new_w, 0, target_h - new_h], value=0.447)
250
-
251
-
252
def copy_attr(a, b, include=(), exclude=()):
    """Copy attributes from `b.__dict__` onto `a`.

    Names starting with '_' and names in `exclude` are always skipped; when `include`
    is non-empty, only names listed in it are copied.
    """
    for key, value in b.__dict__.items():
        if key.startswith('_') or key in exclude:
            continue
        if len(include) and key not in include:
            continue
        setattr(a, key, value)
259
-
260
-
261
class EarlyStopping:
    """YOLOv5 simple early stopper.

    Signals a stop once fitness has not improved for `patience` epochs.
    A falsy patience (0 or None) becomes infinity, which disables stopping.
    """

    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving
        self.possible_stop = False  # True when a stop may occur next epoch

    def __call__(self, epoch, fitness):
        """Record `fitness` for `epoch` and return True if training should stop now."""
        # >= rather than > so that early zero-fitness epochs still count as improvements
        if fitness >= self.best_fitness:
            self.best_epoch, self.best_fitness = epoch, fitness
        stale_epochs = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = stale_epochs >= (self.patience - 1)
        should_stop = stale_epochs >= self.patience
        if should_stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return should_stop
282
-
283
-
284
class ModelEMA:
    """Exponential Moving Average of a model's state_dict (parameters and buffers).

    Updated EMA from https://github.com/rwightman/pytorch-image-models
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    The decay used at update k is decay * (1 - exp(-k / tau)), which ramps up from 0 toward
    `decay` so the average follows the model closely in early epochs.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Frozen copy of the unwrapped model, in eval mode
        self.ema = deepcopy(de_parallel(model)).eval()
        self.updates = updates  # number of EMA updates performed so far
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend `model`'s current state into the EMA: ema = d * ema + (1 - d) * model."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            model_state = de_parallel(model).state_dict()
            # state_dict() tensors share storage with self.ema, so in-place ops update it
            for key, ema_value in self.ema.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue  # non-float entries are left as they are
                ema_value *= d
                ema_value += (1 - d) * model_state[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes of `model` onto the EMA model via copy_attr."""
        copy_attr(self.ema, model, include, exclude)