quadra-2.3.2a2-py3-none-any.whl → quadra-2.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quadra/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "2.3.2a2"
+ __version__ = "2.4.0"


  def get_version():
@@ -307,6 +307,14 @@ class Classification(Generic[ClassificationDataModuleT], LightningTask[Classific
          # TODO: What happens if we have 64 precision?
          half_precision = "16" in self.trainer.precision

+         example_input: torch.Tensor | None = None
+
+         if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "val_dataset"):
+             # Retrieve a better input to evaluate fp16 performance, otherwise efficientnetb0
+             # sometimes does not export properly
+             example_input = self.trainer.datamodule.val_dataset[0][0]
+
+         # The selected rtol and atol are quite high; this is mostly done for efficientnetb0,
+         # which seems to be quite unstable in fp16
          self.model_json, export_paths = export_model(
              config=self.config,
              model=module.model,
@@ -314,6 +322,9 @@ class Classification(Generic[ClassificationDataModuleT], LightningTask[Classific
              half_precision=half_precision,
              input_shapes=input_shapes,
              idx_to_class=idx_to_class,
+             example_inputs=example_input,
+             rtol=0.05,
+             atol=0.01,
          )

          if len(export_paths) == 0:
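The guard above only fires when the trainer actually exposes a datamodule with a val_dataset; val_dataset[0][0] then picks the image tensor of the first validation sample, without a batch dimension. A minimal sketch of that lookup, with hypothetical dummy classes standing in for the Lightning trainer and datamodule (only the attribute names come from the diff):

    import torch

    class DummyDataset:
        def __getitem__(self, idx):
            # (image without batch dimension, label), as a classification dataset returns
            return torch.rand(3, 224, 224), 0

    class DummyDataModule:
        val_dataset = DummyDataset()

    class DummyTrainer:
        datamodule = DummyDataModule()

    trainer = DummyTrainer()
    example_input = None
    if hasattr(trainer, "datamodule") and hasattr(trainer.datamodule, "val_dataset"):
        example_input = trainer.datamodule.val_dataset[0][0]  # image of the first sample

    print(example_input.shape)  # torch.Size([3, 224, 224])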
@@ -1136,7 +1147,7 @@ class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
              return

          if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
-             target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]]
+             target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]]  # type: ignore[index]
              self.cam = GradCAM(
                  model=self.deployment_model.model,
                  target_layers=target_layers,
quadra/utils/export.py CHANGED
@@ -119,6 +119,7 @@ def export_torchscript_model(
      input_shapes: list[Any] | None = None,
      half_precision: bool = False,
      model_name: str = "model.pt",
+     example_inputs: list[torch.Tensor] | tuple[torch.Tensor, ...] | torch.Tensor | None = None,
  ) -> tuple[str, Any] | None:
      """Export a PyTorch model with TorchScript.

@@ -128,6 +129,8 @@
          output_path: Path to save the model
          half_precision: If True, the model will be exported with half precision
          model_name: Name of the exported model
+         example_inputs: If provided, used to evaluate the model instead of randomly generated inputs; expected to
+             be a list of tensors or a single tensor without a batch dimension

      Returns:
          If the model is exported successfully, the path to the model and the input shape are returned.
@@ -144,7 +147,32 @@
      else:
          model.cpu()

-     model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
+     batch_size = 1
+     model_inputs: tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None
+     if example_inputs is not None:
+         if isinstance(example_inputs, Sequence):
+             model_input_tensors = []
+             model_input_shapes = []
+
+             for example_input in example_inputs:
+                 new_inp = example_input.to(
+                     device="cuda:0" if half_precision else "cpu",
+                     dtype=torch.float16 if half_precision else torch.float32,
+                 )
+                 new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
+                 model_input_tensors.append(new_inp)
+                 model_input_shapes.append(new_inp[0].shape)
+
+             model_inputs = (model_input_tensors, [model_input_shapes])
+         else:
+             new_inp = example_inputs.to(
+                 device="cuda:0" if half_precision else "cpu",
+                 dtype=torch.float16 if half_precision else torch.float32,
+             )
+             new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
+             model_inputs = (new_inp, [new_inp[0].shape])
+     else:
+         model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)

      if model_inputs is None:
          return None
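The unsqueeze/repeat idiom above turns a sample stored without a batch dimension into a batch: unsqueeze(0) prepends the batch axis and repeat() tiles it batch_size times while leaving every other dimension untouched. A standalone sketch of just that step (shapes are illustrative):

    import torch

    batch_size = 1
    sample = torch.rand(3, 224, 224)  # e.g. val_dataset[0][0], no batch dimension
    batched = sample.unsqueeze(0).repeat(batch_size, *(1 for _ in sample.shape))

    print(batched.shape)     # torch.Size([1, 3, 224, 224])
    print(batched[0].shape)  # torch.Size([3, 224, 224]), the shape that gets recorded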
@@ -182,6 +210,9 @@ def export_onnx_model(
      input_shapes: list[Any] | None = None,
      half_precision: bool = False,
      model_name: str = "model.onnx",
+     example_inputs: list[torch.Tensor] | tuple[torch.Tensor, ...] | torch.Tensor | None = None,
+     rtol: float = 0.01,
+     atol: float = 5e-3,
  ) -> tuple[str, Any] | None:
      """Export a PyTorch model with ONNX.

@@ -192,6 +223,10 @@
          onnx_config: ONNX export configuration
          half_precision: If True, the model will be exported with half precision
          model_name: Name of the exported model
+         example_inputs: If provided, used to evaluate the model instead of randomly generated inputs; expected to
+             be a list of tensors or a single tensor without a batch dimension
+         rtol: Relative tolerance for the ONNX safe export in fp16
+         atol: Absolute tolerance for the ONNX safe export in fp16
      """
      if not ONNX_AVAILABLE:
          log.warning("ONNX is not installed, can not export model in this format.")
@@ -210,9 +245,32 @@
      else:
          batch_size = 1

-     model_inputs = extract_torch_model_inputs(
-         model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
-     )
+     model_inputs: tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None
+     if example_inputs is not None:
+         if isinstance(example_inputs, Sequence):
+             model_input_tensors = []
+             model_input_shapes = []
+
+             for example_input in example_inputs:
+                 new_inp = example_input.to(
+                     device="cuda:0" if half_precision else "cpu",
+                     dtype=torch.float16 if half_precision else torch.float32,
+                 )
+                 new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
+                 model_input_tensors.append(new_inp)
+                 model_input_shapes.append(new_inp[0].shape)
+
+             model_inputs = (model_input_tensors, [model_input_shapes])
+         else:
+             new_inp = example_inputs.to(
+                 device="cuda:0" if half_precision else "cpu",
+                 dtype=torch.float16 if half_precision else torch.float32,
+             )
+             new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
+             model_inputs = ([new_inp], [new_inp[0].shape])
+     else:
+         model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
+
      if model_inputs is None:
          return None

@@ -266,6 +324,8 @@

      if isinstance(inp, list):
          inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs
+     elif isinstance(inp, torch.Tensor):
+         inp = (inp,)

      if isinstance(inp, dict):
          raise ValueError("ONNX export does not support model with dict inputs")
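torch.onnx.export treats its args parameter as a tuple of positional inputs, so the new branch above wraps a bare tensor into a 1-tuple, just as a list is converted to a tuple first. A minimal, self-contained sketch of the same normalization (the toy model and file name are ours; requires the onnx package):

    import torch

    model = torch.nn.Linear(8, 2)
    inp = torch.rand(1, 8)

    if isinstance(inp, list):
        inp = tuple(inp)  # a list could be misread as a single list-typed input
    elif isinstance(inp, torch.Tensor):
        inp = (inp,)  # single tensor -> 1-tuple of positional inputs

    torch.onnx.export(model, inp, "toy_model.onnx", input_names=["input"])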
@@ -290,6 +350,8 @@
              onnx_config=onnx_config,
              input_shapes=input_shapes,
              input_names=input_names,
+             rtol=rtol,
+             atol=atol,
          )

          if not is_export_ok:
@@ -324,6 +386,8 @@ def _safe_export_half_precision_onnx(
      onnx_config: DictConfig,
      input_shapes: list[Any],
      input_names: list[str],
+     rtol: float = 0.01,
+     atol: float = 5e-3,
  ) -> bool:
      """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export
      the model with a more stable export and overwrite the original model.
@@ -335,6 +399,8 @@
          onnx_config: ONNX export configuration
          input_shapes: Input shapes for the model
          input_names: Input names for the model
+         rtol: Relative tolerance to evaluate the model
+         atol: Absolute tolerance to evaluate the model

      Returns:
          True if the model is stable or it was possible to export a more stable model, False otherwise.
@@ -381,7 +447,7 @@
          with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
              # This function prints a lot of information that is not useful for the user
              model_fp16 = auto_convert_mixed_precision(
-                 model_fp32, test_data, rtol=0.01, atol=5e-3, keep_io_types=False
+                 model_fp32, test_data, rtol=rtol, atol=atol, keep_io_types=False
              )
          onnx.save(model_fp16, export_model_path)

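Passing rtol and atol through means the fp16 output only has to stay within roughly atol + rtol * |reference| of the fp32 output. A minimal illustration of that trade-off using numpy.allclose semantics (the exact check inside onnxconverter-common's auto_convert_mixed_precision may differ; the sample values are ours):

    import numpy as np

    fp32_out = np.array([0.500, 2.000, -1.000], dtype=np.float32)
    fp16_out = np.array([0.507, 2.080, -1.004], dtype=np.float32)

    # old hard-coded defaults vs. the looser values the classification task now passes
    for rtol, atol in [(0.01, 5e-3), (0.05, 0.01)]:
        ok = np.allclose(fp16_out, fp32_out, rtol=rtol, atol=atol)
        print(f"rtol={rtol}, atol={atol}: {'pass' if ok else 'fail'}")  # fail, then pass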
@@ -431,6 +497,9 @@ def export_model(
      input_shapes: list[Any] | None = None,
      idx_to_class: dict[int, str] | None = None,
      pytorch_model_type: Literal["backbone", "model"] = "model",
+     example_inputs: list[Any] | tuple[Any, ...] | torch.Tensor | None = None,
+     rtol: float = 0.01,
+     atol: float = 5e-3,
  ) -> tuple[dict[str, Any], dict[str, str]]:
      """Generate deployment models for the task.

@@ -443,6 +512,9 @@
          idx_to_class: Mapping from class index to class name
          pytorch_model_type: Type of the pytorch model config to be exported; if it's "backbone", the config.backbone
              config is saved to disk, otherwise config.model is saved
+         example_inputs: If provided, used to evaluate the model instead of randomly generated inputs
+         rtol: Relative tolerance for the ONNX safe export in fp16
+         atol: Absolute tolerance for the ONNX safe export in fp16

      Returns:
          If the model is exported successfully, return a dictionary containing information about the exported model and
@@ -468,6 +540,7 @@
          input_shapes=input_shapes,
          output_path=export_folder,
          half_precision=half_precision,
+         example_inputs=example_inputs,
      )

      if out is None:
@@ -495,6 +568,9 @@
          onnx_config=config.export.onnx,
          input_shapes=input_shapes,
          half_precision=half_precision,
+         example_inputs=example_inputs,
+         rtol=rtol,
+         atol=atol,
      )

      if out is None:
quadra-2.3.2a2.dist-info/METADATA → quadra-2.4.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: quadra
- Version: 2.3.2a2
+ Version: 2.4.0
  Summary: Deep Learning experiment orchestration library
  License: Apache-2.0
  Keywords: deep learning,experiment,lightning,hydra-core
@@ -39,7 +39,7 @@ Requires-Dist: onnxsim (==0.4.28) ; extra == "onnx"
  Requires-Dist: opencv_python_headless (>=4.7.0,<4.8.0)
  Requires-Dist: pandas (<2.0)
  Requires-Dist: pillow (>=10,<11)
- Requires-Dist: pydantic (==1.10.10)
+ Requires-Dist: pydantic (>=1.10.10)
  Requires-Dist: python_dotenv (>=0.21,<0.22)
  Requires-Dist: pytorch_lightning (>=2.4,<2.5)
  Requires-Dist: rich (>=13.2,<13.3)
@@ -49,11 +49,11 @@ Requires-Dist: seaborn (>=0.12,<0.13)
  Requires-Dist: segmentation_models_pytorch-orobix (==0.3.3.dev1)
  Requires-Dist: tensorboard (>=2.11,<2.12)
  Requires-Dist: timm (==0.9.12)
- Requires-Dist: torch (==2.4.1)
+ Requires-Dist: torch (==2.6.0)
  Requires-Dist: torchinfo (>=1.8,<1.9)
  Requires-Dist: torchmetrics (>=0.10,<0.11)
  Requires-Dist: torchsummary (>=1.5,<1.6)
- Requires-Dist: torchvision (>=0.19,<0.20)
+ Requires-Dist: torchvision (>=0.21,<0.22)
  Requires-Dist: tripy (>=1.0,<1.1)
  Requires-Dist: typing_extensions (==4.11.0) ; python_version < "3.10"
  Requires-Dist: xxhash (>=3.2,<3.3)
quadra-2.3.2a2.dist-info/RECORD → quadra-2.4.0.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
- quadra/__init__.py,sha256=7sypQyykpPeQGER7BmqRc3Z6RChcUEoiVg5zWeB9vn8,114
+ quadra/__init__.py,sha256=fv-5hfERt0uLXmjb7dOuh4wKtsXYgvW72hTUW_IvWQo,112
  quadra/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  quadra/callbacks/anomalib.py,sha256=WLBEGhZA9HoP4Yh9UbbC2GzDOKYTkvU9EY1lkZcV7Fs,11971
  quadra/callbacks/lightning.py,sha256=qvtzDiv8ZUV7K11gKHKWCyo-a9XR_Jm_M-IEicTM1Yo,20242
@@ -250,7 +250,7 @@ quadra/schedulers/warmup.py,sha256=chzzrK7OqqlicBCxiF4CqMYNrWu6nflIbRE-C86Jrw0,4
  quadra/tasks/__init__.py,sha256=tmAfMoH0k3UC7r2pNrgbBa1Pfc3tpLl3IObFF6Z0eRE,820
  quadra/tasks/anomaly.py,sha256=RHeiM1vZF1zsva37iYdiGx_HLgdAp8lXnmUzXja69YU,24638
  quadra/tasks/base.py,sha256=piYlTFtvqH-4s4oEq4GczdAs_gL29UHAJGsOC5Sd3Bc,14187
- quadra/tasks/classification.py,sha256=05l3QM3dsU2yTWhXxNAcJ8sZM0Vbfgey-e5EV6p1TX8,52816
+ quadra/tasks/classification.py,sha256=_GQOPMGuOZ_uLA9jFhLEaJkW_Sid_WKHNn9ALXnGNmo,53407
  quadra/tasks/patch.py,sha256=nzo8o-ei7iF1Iarvd8-c08s0Rs_lPvVPDLAbkFMx-Qw,20251
  quadra/tasks/segmentation.py,sha256=9Qy-V0Wvoofl7IrfotnSMgBIXcZd-WfZZtetyqmB0FY,16260
  quadra/tasks/ssl.py,sha256=XsaC9hbhvTA5UfHeRaaCstx9mTYacLRmgoCF5Tj9R5M,20547
@@ -262,7 +262,7 @@ quadra/utils/anomaly.py,sha256=49vFvT5-4SxczsEM2Akcut_M1DDwKlOVdGv36oLTgR0,4067
  quadra/utils/classification.py,sha256=dKFuv4RywWhvhstOnEOnaf-6qcViUK0dTgah9m9mw2Q,24917
  quadra/utils/deprecation.py,sha256=zF_S-yqenaZxRBOudhXts0mX763WjEUWCnHd09TZnwY,852
  quadra/utils/evaluation.py,sha256=oooRJPu1AaHhOwvB1Y6SFjQ645OkgrDzKtUvwWq8oq4,19005
- quadra/utils/export.py,sha256=ghNF8mQw-JjZiVeBJ0y8yIQkx8EG8ssPorn3aaIsgcA,20840
+ quadra/utils/export.py,sha256=dIbhnFPHo2wYoeyE48TeSzGjsf1FowCin3_ASR7BFJc,24621
  quadra/utils/imaging.py,sha256=Cz7sGb_axEmnGcwQJP2djFZpIpGCPFIBGT8NWVV-OOE,866
  quadra/utils/logger.py,sha256=tQJ4xpTAFKx1g-UUm5K1x7zgoP6qoXpcUHQyu0rOr1w,556
  quadra/utils/mlflow.py,sha256=DVso1lxn126hil8i4tTf5WFUPJ8uJNAzNU8OXbXwOzw,3586
@@ -293,8 +293,8 @@ quadra/utils/validator.py,sha256=wmVXycB90VNyAbKBUVncFCxK4nsYiOWJIY3ISXwxYCY,463
  quadra/utils/visualization.py,sha256=yYm7lPziUOlybxigZ2qTycNewb67Q80H4hjQGWUh788,16094
  quadra/utils/vit_explainability.py,sha256=Gh6BHaDEzWxOjJp1aqvCxLt9Rb8TXd5uKXOAx7-acUk,13351
  hydra_plugins/quadra_searchpath_plugin.py,sha256=AAn4TzR87zUK7nwSsK-KoqALiPtfQ8FvX3fgZPTGIJ0,1189
- quadra-2.3.2a2.dist-info/LICENSE,sha256=8cTbQtcWa02YJoSpMeV_gxj3jpMTkxvl-w3WJ5gV_QE,11342
- quadra-2.3.2a2.dist-info/METADATA,sha256=NL_uzLkCuKb52aWZ-xMbXc4zBt_egd18lpXNZK4Pqw0,17612
- quadra-2.3.2a2.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- quadra-2.3.2a2.dist-info/entry_points.txt,sha256=sRYonBZyx-sAJeWcQNQoVQIU5lm02cnCQt6b15k0WHU,43
- quadra-2.3.2a2.dist-info/RECORD,,
+ quadra-2.4.0.dist-info/LICENSE,sha256=8cTbQtcWa02YJoSpMeV_gxj3jpMTkxvl-w3WJ5gV_QE,11342
+ quadra-2.4.0.dist-info/METADATA,sha256=FOt90lNFxRQd84gcN-nFewN-IoclxQ8eDZvhWIeh1Do,17610
+ quadra-2.4.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ quadra-2.4.0.dist-info/entry_points.txt,sha256=sRYonBZyx-sAJeWcQNQoVQIU5lm02cnCQt6b15k0WHU,43
+ quadra-2.4.0.dist-info/RECORD,,
quadra-2.3.2a2.dist-info/WHEEL → quadra-2.4.0.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 2.1.2
+ Generator: poetry-core 2.1.3
  Root-Is-Purelib: true
  Tag: py3-none-any