docling-ibm-models 3.3.0__tar.gz → 3.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/PKG-INFO +2 -2
  2. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/code_formula_model/code_formula_predictor.py +69 -10
  3. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/code_formula_model/models/sam_opt.py +7 -6
  4. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py +4 -1
  5. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py +6 -7
  6. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/__init__.py +0 -0
  7. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  8. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/models/__init__.py +0 -0
  9. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  10. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  11. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +3 -3
  12. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/otsl.py +1 -1
  13. docling_ibm_models-3.3.2/docling_ibm_models/tableformer/utils/__init__.py +0 -0
  14. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/utils/mem_monitor.py +3 -2
  15. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/pyproject.toml +13 -13
  16. docling_ibm_models-3.3.0/docling_ibm_models/tableformer/utils/torch_utils.py +0 -216
  17. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/LICENSE +0 -0
  18. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/README.md +0 -0
  19. {docling_ibm_models-3.3.0/docling_ibm_models/tableformer → docling_ibm_models-3.3.2/docling_ibm_models}/__init__.py +0 -0
  20. {docling_ibm_models-3.3.0/docling_ibm_models/tableformer/data_management → docling_ibm_models-3.3.2/docling_ibm_models/code_formula_model}/__init__.py +0 -0
  21. {docling_ibm_models-3.3.0/docling_ibm_models/tableformer → docling_ibm_models-3.3.2/docling_ibm_models/code_formula_model}/models/__init__.py +0 -0
  22. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/code_formula_model/models/sam.py +0 -0
  23. {docling_ibm_models-3.3.0/docling_ibm_models/tableformer/models/common → docling_ibm_models-3.3.2/docling_ibm_models/document_figure_classifier_model}/__init__.py +0 -0
  24. {docling_ibm_models-3.3.0/docling_ibm_models/tableformer/models/table04_rs → docling_ibm_models-3.3.2/docling_ibm_models/layoutmodel}/__init__.py +0 -0
  25. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/layoutmodel/layout_predictor.py +0 -0
  26. docling_ibm_models-3.3.0/docling_ibm_models/tableformer/utils/__init__.py → docling_ibm_models-3.3.2/docling_ibm_models/py.typed +0 -0
  27. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/common.py +0 -0
  28. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/data_management/functional.py +0 -0
  29. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/data_management/matching_post_processor.py +0 -0
  30. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +0 -0
  31. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/data_management/tf_predictor.py +0 -0
  32. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/data_management/transforms.py +0 -0
  33. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/models/common/base_model.py +0 -0
  34. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +0 -0
  35. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +0 -0
  36. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +0 -0
  37. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/settings.py +0 -0
  38. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/utils/app_profiler.py +0 -0
  39. {docling_ibm_models-3.3.0 → docling_ibm_models-3.3.2}/docling_ibm_models/tableformer/utils/utils.py +0 -0
--- docling_ibm_models-3.3.0/PKG-INFO
+++ docling_ibm_models-3.3.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: docling-ibm-models
-Version: 3.3.0
+Version: 3.3.2
 Summary: This package contains the AI models used by the Docling PDF conversion package
 License: MIT
 Keywords: docling,convert,document,pdf,layout model,segmentation,table structure,table former
@@ -20,7 +20,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Dist: Pillow (>=10.0.0,<11.0.0)
+Requires-Dist: Pillow (>=10.0.0,<12.0.0)
 Requires-Dist: huggingface_hub (>=0.23,<1)
 Requires-Dist: jsonlines (>=3.1.0,<4.0.0)
 Requires-Dist: numpy (>=1.24.4,<2.0.0) ; sys_platform == "darwin" and platform_machine == "x86_64"
--- docling_ibm_models-3.3.0/docling_ibm_models/code_formula_model/code_formula_predictor.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/code_formula_model/code_formula_predictor.py
@@ -3,12 +3,12 @@
 # SPDX-License-Identifier: MIT
 #
 import logging
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
 from PIL import Image
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM
 from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import (
@@ -18,6 +18,22 @@ from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import
 _log = logging.getLogger(__name__)
 
 
+class StopOnString(StoppingCriteria):
+    def __init__(self, tokenizer, stop_string):
+        self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False)
+
+    def __call__(self, input_ids, scores, **kwargs):
+        for sequence in input_ids:
+            sequence_list = sequence.tolist()
+            for i in range(len(sequence_list) - len(self.stop_token_ids) + 1):
+                if (
+                    sequence_list[i : i + len(self.stop_token_ids)]
+                    == self.stop_token_ids
+                ):
+                    return True
+        return False
+
+
 class CodeFormulaPredictor:
     """
     Code and Formula Predictor using a multi-modal vision-language model.
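
Note: StopOnString converts a stop string to token ids once, then scans each generated sequence for that id subsequence on every decoding step. A minimal, self-contained sketch of the same pattern with a generic Hugging Face model (the "facebook/opt-125m" checkpoint and the prompt are illustrative only, not part of this package):

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        StoppingCriteria,
        StoppingCriteriaList,
    )

    class StopOnString(StoppingCriteria):
        def __init__(self, tokenizer, stop_string):
            # token ids of the stop string, without special tokens
            self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False)

        def __call__(self, input_ids, scores, **kwargs):
            n = len(self.stop_token_ids)
            for sequence in input_ids:
                seq = sequence.tolist()
                # halt as soon as the id subsequence appears anywhere in the output
                if any(seq[i : i + n] == self.stop_token_ids for i in range(len(seq) - n + 1)):
                    return True
            return False

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    inputs = tokenizer("A LaTeX table:", return_tensors="pt")
    output = model.generate(
        **inputs,
        max_new_tokens=64,
        stopping_criteria=StoppingCriteriaList(
            [StopOnString(tokenizer, r" \, \, \, \,")]
        ),
    )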
@@ -127,12 +143,37 @@ class CodeFormulaPredictor:
 
         return prompt
 
+    def _strip(self, text: str):
+        """
+        Removes any occurrences of the substrings in remove_list from the end of text.
+
+        Parameters
+        ----------
+        text : str
+            The original string.
+
+        Returns
+        -------
+        str
+            The trimmed string.
+        """
+        remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"]
+        changed = True
+        while changed:
+            changed = False
+            for substr in remove_list:
+                if text.endswith(substr):
+                    text = text[: -len(substr)]
+                    changed = True
+
+        return text.strip()
+
     @torch.inference_mode()
     def predict(
         self,
         images: List[Union[Image.Image, np.ndarray]],
         labels: List[str],
-        temperature: float = 0.1,
+        temperature: Optional[float] = 0.0,
     ) -> List[str]:
         """
         Predicts the textual representation of input images (code or LaTeX).
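
Note: _strip complements the stopping criteria; generation only halts after a filler string has already been emitted, so the tail still has to be cut off afterwards. A standalone re-implementation showing the behavior (illustrative only):

    remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"]

    def strip_tail(text: str) -> str:
        # repeatedly peel trailing filler substrings, then trim whitespace
        changed = True
        while changed:
            changed = False
            for substr in remove_list:
                if text.endswith(substr):
                    text = text[: -len(substr)]
                    changed = True
        return text.strip()

    print(strip_tail(r"x^{2}+y^{2} \quad\quad"))  # -> x^{2}+y^{2}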
@@ -143,8 +184,8 @@
             List of images to be processed, provided as PIL Image objects or numpy arrays.
         labels : List[str]
             List of labels indicating the type of each image ('code' or 'formula').
-        temperature : float, optional
-            Sampling temperature for generation, by default set to 0.1.
+        temperature : Optional[float]
+            Sampling temperature for generation, by default set to 0.0.
 
         Returns
         -------
@@ -159,7 +200,11 @@
         Excpetion
             In case the temperature is an invalid number.
         """
-        if (type(temperature) != float and type(temperature) != int) or temperature < 0:
+        if (
+            temperature is None
+            or not (isinstance(temperature, float) or isinstance(temperature, int))
+            or temperature < 0
+        ):
             raise Exception("Temperature must be a number greater or equal to 0.")
 
         do_sample = True
@@ -181,11 +226,10 @@
             else:
                 raise TypeError("Not supported input image format")
             images_tmp.append(image)
-        images = images_tmp
 
-        images_tensor = torch.stack([self._image_processor(img) for img in images]).to(
-            self._device
-        )
+        images_tensor = torch.stack(
+            [self._image_processor(img) for img in images_tmp]
+        ).to(self._device)
 
         prompts = [self._get_prompt(label) for label in labels]
 
@@ -195,6 +239,16 @@
         prompt_ids = tokenized["input_ids"]
         attention_mask = tokenized["attention_mask"]
 
+        stopping_criteria = StoppingCriteriaList(
+            [
+                StopOnString(self._tokenizer, r" \quad \quad \quad \quad"),
+                StopOnString(self._tokenizer, r" \\ \\ \\ \\"),
+                StopOnString(self._tokenizer, r" \, \, \, \,"),
+                StopOnString(self._tokenizer, r" c c c c c c c c c c c c c c c c"),
+                StopOnString(self._tokenizer, r" l l l l l l l l l l l l l l l l l"),
+            ]
+        )
+
         if self._device == "cpu":
             output_ids_list = self._model.generate(
                 input_ids=prompt_ids,
@@ -204,6 +258,8 @@
                 temperature=temperature,
                 max_new_tokens=4096 - prompt_ids.shape[1],
                 use_cache=True,
+                no_repeat_ngram_size=200,
+                stopping_criteria=stopping_criteria,
             )
         else:
             with torch.autocast(device_type=self._device, dtype=torch.bfloat16):
@@ -214,10 +270,13 @@
                     temperature=temperature,
                     max_new_tokens=4096 - prompt_ids.shape[1],
                     use_cache=True,
+                    no_repeat_ngram_size=200,
+                    stopping_criteria=stopping_criteria,
                 )
 
         outputs = self._tokenizer.batch_decode(
            output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True
        )
+        outputs = [self._strip(output) for output in outputs]
 
         return outputs
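
Note: no_repeat_ngram_size=200 only bans verbatim repeats of 200-token spans, so it targets degenerate repetition loops without touching normal code or LaTeX. A hedged end-to-end usage sketch; the constructor arguments follow the usual docling-ibm-models predictor pattern and are assumptions here, as is the input file name:

    from PIL import Image
    from docling_ibm_models.code_formula_model.code_formula_predictor import (
        CodeFormulaPredictor,
    )

    # artifacts_path/device are assumed constructor parameters, not shown in the diff
    predictor = CodeFormulaPredictor(artifacts_path="<model-dir>", device="cpu")
    images = [Image.open("formula_crop.png")]  # hypothetical input crop
    labels = ["formula"]                       # or "code"
    texts = predictor.predict(images, labels, temperature=0.0)
    print(texts[0])  # trailing filler already removed by _strip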
--- docling_ibm_models-3.3.0/docling_ibm_models/code_formula_model/models/sam_opt.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/code_formula_model/models/sam_opt.py
@@ -67,14 +67,14 @@ class SamOPTModel(OPTModel):
 
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        images: torch.FloatTensor = None,
+        images: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
 
@@ -86,6 +86,7 @@
 
         if input_ids.shape[1] != 1 or self.training:
             with torch.set_grad_enabled(self.training):
+                assert vision_tower is not None
                 image_features = vision_tower(images)
                 image_features = image_features.flatten(2).permute(0, 2, 1)
                 image_features = self.mm_projector(image_features)
@@ -94,9 +95,9 @@
             for cur_input_ids, cur_input_embeds, cur_image_features in zip(
                 input_ids, inputs_embeds, image_features
             ):
-                image_start_token_position = torch.where(
-                    cur_input_ids == im_start_token
-                )[0].item()
+                image_start_token_position = int(
+                    torch.where(cur_input_ids == im_start_token)[0].item()
+                )  # cast to int for mypy
 
                 cur_image_features = cur_image_features.to(
                     device=cur_input_embeds.device
@@ -115,7 +116,7 @@
 
                 new_input_embeds.append(cur_input_embeds)
 
-            inputs_embeds = torch.stack(new_input_embeds, dim=0)
+            inputs_embeds = torch.stack(new_input_embeds, dim=0)  # type: ignore
 
         return super(SamOPTModel, self).forward(
             input_ids=None,
--- docling_ibm_models-3.3.0/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py
@@ -28,4 +28,7 @@ class SamOptImageProcessor(ImageProcessingMixin):
         return image
 
 
-AutoImageProcessor.register(SamOptImageProcessor, SamOptImageProcessor)
+AutoImageProcessor.register(
+    config_class="SamOptImageProcessor",
+    slow_image_processor_class=SamOptImageProcessor,
+)
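
Note: the keyword form tracks the AutoImageProcessor.register signature in newer transformers releases, which distinguishes slow and fast processor classes. A sketch of how the registered processor is then resolved; the model directory and image file are placeholders:

    from PIL import Image
    from transformers import AutoImageProcessor

    # after registration, the name stored in preprocessor_config.json
    # resolves back to SamOptImageProcessor
    processor = AutoImageProcessor.from_pretrained("<model-dir>")
    pixel_values = processor(Image.open("page_crop.png"))  # __call__ from SamOptImageProcessor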
--- docling_ibm_models-3.3.0/docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py
@@ -147,24 +147,23 @@ class DocumentFigureClassifierPredictor:
 
         The predictions for each image are sorted in descending order of confidence.
         """
-        processed_images = []
+        rgb_images = []
         for image in images:
             if isinstance(image, Image.Image):
-                processed_images.append(image.convert("RGB"))
+                rgb_images.append(image.convert("RGB"))
             elif isinstance(image, np.ndarray):
-                processed_images.append(Image.fromarray(image).convert("RGB"))
+                rgb_images.append(Image.fromarray(image).convert("RGB"))
             else:
                 raise TypeError(
                     "Supported input formats are PIL.Image.Image or numpy.ndarray."
                 )
-        images = processed_images
 
         # (batch_size, 3, 224, 224)
-        images = [self._image_processor(image) for image in images]
-        images = torch.stack(images).to(self._device)
+        processed_images = [self._image_processor(image) for image in rgb_images]
+        torch_images = torch.stack(processed_images).to(self._device)
 
         with torch.no_grad():
-            logits = self._model(images).logits  # (batch_size, num_classes)
+            logits = self._model(torch_images).logits  # (batch_size, num_classes)
             probs_batch = logits.softmax(dim=1)  # (batch_size, num_classes)
             probs_batch = probs_batch.cpu().numpy().tolist()
 
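
Note: this hunk is pure type hygiene; rebinding one name across three different types is what mypy rejects. Illustrative reduction, not package code:

    # before: one name, three types -> mypy "incompatible types in assignment"
    images = [img.convert("RGB") for img in inputs]      # List[Image.Image]
    images = [processor(img) for img in images]          # List[Tensor]
    images = torch.stack(images).to(device)              # Tensor

    # after: one name per type, no errors
    rgb_images = [img.convert("RGB") for img in inputs]
    processed_images = [processor(img) for img in rgb_images]
    torch_images = torch.stack(processed_images).to(device)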
--- docling_ibm_models-3.3.0/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py
@@ -36,7 +36,7 @@ class PositionalEncoding(nn.Module):
 
 
 class TMTransformerDecoder(nn.TransformerDecoder):
-    def forward(
+    def forward(  # type: ignore
         self,
         tgt: Tensor,
         memory: Optional[Tensor] = None,
@@ -69,11 +69,11 @@
         else:
             out_cache = torch.stack(tag_cache, dim=0)
 
-        return output, out_cache
+        return output, out_cache  # type: ignore
 
 
 class TMTransformerDecoderLayer(nn.TransformerDecoderLayer):
-    def forward(
+    def forward(  # type: ignore
         self,
         tgt: Tensor,
         memory: Optional[Tensor] = None,
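
Note: nn.TransformerDecoder.forward is typed as returning a single Tensor, while these overrides return an (output, cache) tuple with different parameters, so mypy reports an incompatible override (a Liskov violation); the # type: ignore markers acknowledge the deliberate mismatch. Reduced illustration:

    from typing import Optional, Tuple
    import torch
    from torch import Tensor, nn

    class CachingDecoder(nn.TransformerDecoder):
        # parent forward returns Tensor; returning a tuple triggers
        # "Signature of 'forward' incompatible with supertype" without the ignore
        def forward(  # type: ignore
            self, tgt: Tensor, memory: Optional[Tensor] = None
        ) -> Tuple[Tensor, Tensor]:
            cache = torch.empty(0)  # placeholder for the tag cache
            return tgt, cache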
--- docling_ibm_models-3.3.0/docling_ibm_models/tableformer/otsl.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/tableformer/otsl.py
@@ -11,7 +11,7 @@ import docling_ibm_models.tableformer.settings as s
 LOG_LEVEL = logging.INFO
 # LOG_LEVEL = logging.DEBUG
 logger = s.get_custom_logger("consolidate", LOG_LEVEL)
-png_files = {}  # Evaluation files
+# png_files = {}  # Evaluation files
 total_pics = 0
 
 
--- docling_ibm_models-3.3.0/docling_ibm_models/tableformer/utils/mem_monitor.py
+++ docling_ibm_models-3.3.2/docling_ibm_models/tableformer/utils/mem_monitor.py
@@ -5,6 +5,7 @@
 import os
 import platform
 import re
+from typing import Dict, Union
 
 
 class MemMonitor:
@@ -112,7 +113,7 @@ class MemMonitor:
             regex_str = r"({}:)(\s+)(\d*)(.*)".format(mem_field)
             self._status_regex[mem_field] = re.compile(regex_str)
 
-    def get_memory_full(self) -> dict:
+    def get_memory_full(self) -> Union[Dict, int]:
         r"""
         - Parse /proc/<pid>status to get all memory info.
         - The method returns a dict with the fields self._status_fields
@@ -140,7 +141,7 @@
 
         return memory
 
-    def get_memory(self) -> dict:
+    def get_memory(self) -> Union[Dict, int]:
         r"""
         - Parse /proc/<pid>statm to get the most important memory fields
         - This is a fast implementation.
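
Note: the widened Union[Dict, int] return type reflects that the parsers can return an error code instead of a dict (e.g. when /proc is unavailable). A standalone sketch of the /proc/<pid>/status approach; the field list and the -1 error code are illustrative assumptions, with the regex shape copied from the diff:

    import os
    import platform
    import re
    from typing import Dict, Union

    def memory_status(fields=("VmSize", "VmRSS", "VmPeak")) -> Union[Dict, int]:
        if platform.system() != "Linux":
            return -1  # /proc is not available on this platform
        memory: Dict[str, int] = {}
        with open("/proc/{}/status".format(os.getpid())) as fd:
            status = fd.read()
        for field in fields:
            m = re.search(r"({}:)(\s+)(\d*)(.*)".format(field), status)
            if m:
                memory[field] = int(m.group(3))  # value in kB
        return memory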
--- docling_ibm_models-3.3.0/pyproject.toml
+++ docling_ibm_models-3.3.2/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "docling-ibm-models"
-version = "3.3.0"  # DO NOT EDIT, updated automatically
+version = "3.3.2"  # DO NOT EDIT, updated automatically
 description = "This package contains the AI models used by the Docling PDF conversion package"
 authors = ["Nikos Livathinos <nli@zurich.ibm.com>", "Maxim Lysak <mly@zurich.ibm.com>", "Ahmed Nassar <ahn@zurich.ibm.com>", "Christoph Auer <cau@zurich.ibm.com>", "Michele Dolfi <dol@zurich.ibm.com>", "Peter Staar <taa@zurich.ibm.com>"]
 license = "MIT"
@@ -33,7 +33,7 @@ numpy = [
     { version = ">=1.24.4,<2.0.0", markers = 'sys_platform == "darwin" and platform_machine == "x86_64"' },
 ]
 jsonlines = "^3.1.0"
-Pillow = "^10.0.0"
+Pillow = ">=10.0.0,<12.0.0"
 tqdm = "^4.64.0"
 opencv-python-headless = "^4.6.0.66"
 huggingface_hub = ">=0.23,<1"
@@ -106,14 +106,14 @@ parser_angular_minor_types = "feat"
 parser_angular_patch_types = "fix,perf"
 
 
-# [tool.mypy]
-# pretty = true
-# no_implicit_optional = true
-# python_version = "3.10"
-#
-# [[tool.mypy.overrides]]
-# module = [
-#     "torchvision.*",
-#     "transformers.*"
-# ]
-# ignore_missing_imports = true
+[tool.mypy]
+pretty = true
+no_implicit_optional = true
+python_version = "3.10"
+
+[[tool.mypy.overrides]]
+module = [
+    "torchvision.*",
+    "transformers.*"
+]
+ignore_missing_imports = true
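
Note: re-enabling the [tool.mypy] table makes the type checks part of the project configuration (matching the # type: ignore and typing changes across this release), so a plain run picks it up; the exact invocation is an assumption, not shown in the diff:

    poetry run mypy docling_ibm_models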
--- docling_ibm_models-3.3.0/docling_ibm_models/tableformer/utils/torch_utils.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#
-# Copyright IBM Corp. 2024 - 2024
-# SPDX-License-Identifier: MIT
-#
-import torch
-
-
-def model_info(model, verbose=False):
-    # Plots a line-by-line description of a PyTorch model
-    n_p = sum(x.numel() for x in model.parameters())  # number parameters
-    n_g = sum(
-        x.numel() for x in model.parameters() if x.requires_grad
-    )  # number gradients
-    if verbose:
-        print(
-            "%5s %40s %9s %12s %20s %10s %10s"
-            % ("layer", "name", "gradient", "parameters", "shape", "mu", "sigma")
-        )
-        for i, (name, p) in enumerate(model.named_parameters()):
-            name = name.replace("module_list.", "")
-            print(
-                "%5g %40s %9s %12g %20s %10.3g %10.3g"
-                % (
-                    i,
-                    name,
-                    p.requires_grad,
-                    p.numel(),
-                    list(p.shape),
-                    p.mean(),
-                    p.std(),
-                )
-            )
-
-    try:  # FLOPS
-        from thop import profile
-
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ", %.1f GFLOPS" % (macs / 1e9 * 2)
-    except Exception:
-        fs = ""
-
-    print(
-        "Model Summary: %g layers, %g parameters, %g gradients%s"
-        % (len(list(model.parameters())), n_p, n_g, fs)
-    )
-
-
-# def init_seeds(seed=0):
-#     torch.manual_seed(seed)
-#
-#     # Reduce randomness (may be slower on Tesla GPUs)
-#     # https://pytorch.org/docs/stable/notes/randomness.html
-#     if seed == 0:
-#         cudnn.deterministic = False
-#         cudnn.benchmark = True
-#
-#
-# def select_device(device='', apex=False, batch_size=None):
-#     # device = 'cpu' or '0' or '0,1,2,3'
-#     cpu_request = device.lower() == 'cpu'
-#     if device and not cpu_request:  # if device requested other than 'cpu'
-#         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
-#         # check availablity
-#         assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
-#
-#     cuda = False if cpu_request else torch.cuda.is_available()
-#     if cuda:
-#         c = 1024 ** 2  # bytes to MB
-#         ng = torch.cuda.device_count()
-#         if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
-#             assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % \
-#                 (batch_size, ng)
-#         x = [torch.cuda.get_device_properties(i) for i in range(ng)]
-#         # apex for mixed precision https://github.com/NVIDIA/apex
-#         s = 'Using CUDA ' + ('Apex ' if apex else '')
-#         for i in range(0, ng):
-#             if i == 1:
-#                 s = ' ' * len(s)
-#             print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
-#                   (s, i, x[i].name, x[i].total_memory / c))
-#     else:
-#         print('Using CPU')
-#
-#     print('')  # skip a line
-#     return torch.device('cuda:0' if cuda else 'cpu')
-#
-#
-# def time_synchronized():
-#     torch.cuda.synchronize() if torch.cuda.is_available() else None
-#     return time.time()
-#
-#
-# def initialize_weights(model):
-#     for m in model.modules():
-#         t = type(m)
-#         if t is nn.Conv2d:
-#             pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-#         elif t is nn.BatchNorm2d:
-#             m.eps = 1e-4
-#             m.momentum = 0.03
-#         elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
-#             m.inplace = True
-#
-#
-# def find_modules(model, mclass=nn.Conv2d):
-#     # finds layer indices matching module class 'mclass'
-#     return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
-#
-#
-# def fuse_conv_and_bn(conv, bn):
-#     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
-#     with torch.no_grad():
-#         # init
-#         fusedconv = torch.nn.Conv2d(conv.in_channels,
-#                                     conv.out_channels,
-#                                     kernel_size=conv.kernel_size,
-#                                     stride=conv.stride,
-#                                     padding=conv.padding,
-#                                     bias=True)
-#
-#         # prepare filters
-#         w_conv = conv.weight.clone().view(conv.out_channels, -1)
-#         w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
-#         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
-#
-#         # prepare spatial bias
-#         if conv.bias is not None:
-#             b_conv = conv.bias
-#         else:
-#             b_conv = torch.zeros(conv.weight.size(0))
-#         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
-#         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
-#
-#         return fusedconv
-#
-#
-# def load_classifier(name='resnet101', n=2):
-#     # Loads a pretrained model reshaped to n-class output
-#     import pretrainedmodels  # https://github.com/Cadene/pretrained-models.pytorch#torchvision
-#     model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
-#
-#     # Display model properties
-#     for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean',
-#               'model.std']:
-#         print(x + ' =', eval(x))
-#
-#     # Reshape output to n classes
-#     filters = model.last_linear.weight.shape[1]
-#     model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
-#     model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
-#     model.last_linear.out_features = n
-#     return model
-#
-#
-# def scale_img(img, ratio=1.0, same_shape=True):  # img(16,3,256,416), r=ratio
-#     # scales img(bs,3,y,x) by ratio
-#     h, w = img.shape[2:]
-#     s = (int(h * ratio), int(w * ratio))  # new size
-#     img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
-#     if not same_shape:  # pad/crop img
-#         gs = 64  # (pixels) grid size
-#         h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
-#     return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
-#
-#
-# class ModelEMA:
-#     """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
-#     Keep a moving average of everything in the model state_dict (parameters and buffers).
-#     This is intended to allow functionality like
-#     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
-#     A smoothed version of the weights is necessary for some training schemes to perform well.
-#     E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
-#     RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
-#     smoothing of weights to match results. Pay attention to the decay constant you are using
-#     relative to your update count per epoch.
-#     To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
-#     disable validation of the EMA weights. Validation will have to be done manually in a separate
-#     process, or after the training stops converging.
-#     This class is sensitive where it is initialized in the sequence of model init,
-#     GPU assignment and distributed training wrappers.
-#     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and
-#     single-GPU.
-#     """
-#
-#     def __init__(self, model, decay=0.9999, device=''):
-#         # make a copy of the model for accumulating moving average of weights
-#         self.ema = deepcopy(model)
-#         self.ema.eval()
-#         self.updates = 0  # number of EMA updates
-#         # decay exponential ramp (to help early epochs)
-#         self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
-#         self.device = device  # perform ema on different device from model if set
-#         if device:
-#             self.ema.to(device=device)
-#         for p in self.ema.parameters():
-#             p.requires_grad_(False)
-#
-#     def update(self, model):
-#         self.updates += 1
-#         d = self.decay(self.updates)
-#         with torch.no_grad():
-#             if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
-#                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
-#             else:
-#                 msd, esd = model.state_dict(), self.ema.state_dict()
-#
-#             for k, v in esd.items():
-#                 if v.dtype.is_floating_point:
-#                     v *= d
-#                     v += (1. - d) * msd[k].detach()
-#
-#     def update_attr(self, model):
-#         # Assign attributes (which may change during training)
-#         for k in model.__dict__.keys():
-#             if not k.startswith('_'):
-#                 setattr(self.ema, k, getattr(model, k))
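
Note: torch_utils.py was dead weight: one unused model_info helper plus roughly 170 commented-out lines apparently carried over from YOLO-style training utilities. If anything downstream still wants the parameter counts, a minimal stand-in is a few lines (a sketch, not part of the package):

    import torch

    def param_counts(model: torch.nn.Module):
        # equivalent of the n_p / n_g counts in the removed model_info()
        n_params = sum(p.numel() for p in model.parameters())
        n_grads = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return n_params, n_grads

    n_p, n_g = param_counts(torch.nn.Linear(8, 2))
    print(n_p, n_g)  # 18 18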