quadra 2.4.0__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
quadra/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "2.4.0"
1
+ __version__ = "2.5.0"
2
2
 
3
3
 
4
4
  def get_version():
@@ -9,3 +9,4 @@ experiment_path: null
9
9
  upload_artifacts: False
10
10
  upload_models: ${export.types} # Default behavior in quadra <= 1.5.6
11
11
  log_level: info
12
+ mlflow_zip_models: False
quadra/utils/export.py CHANGED
@@ -430,16 +430,15 @@ def _safe_export_half_precision_onnx(
430
430
  export_output = export_onnx_model(
431
431
  model=model,
432
432
  output_path=os.path.dirname(export_model_path),
433
- onnx_config=onnx_config,
433
+ # Force to not simplify fp32 model
434
+ onnx_config=DictConfig({**onnx_config, "simplify": False}),
434
435
  input_shapes=input_shapes,
435
436
  half_precision=False,
436
437
  model_name=os.path.basename(export_model_path),
437
438
  )
438
- if export_output is not None:
439
- export_model_path, _ = export_output
440
- else:
441
- log.warning("Failed to export model")
442
- return False
439
+ if export_output is None:
440
+ # This should not happen
441
+ raise RuntimeError("Failed to export model")
443
442
 
444
443
  model_fp32 = onnx.load(export_model_path)
445
444
  test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
quadra/utils/utils.py CHANGED
@@ -8,10 +8,12 @@ import glob
8
8
  import json
9
9
  import logging
10
10
  import os
11
+ import shutil
11
12
  import subprocess
12
13
  import sys
13
14
  import warnings
14
15
  from collections.abc import Iterable, Iterator, Sequence
16
+ from tempfile import TemporaryDirectory
15
17
  from typing import Any, cast
16
18
 
17
19
  import cv2
@@ -299,45 +301,78 @@ def finish(
299
301
  quadra_export.generate_torch_inputs(input_size, device=device, half_precision=half_precision),
300
302
  )
301
303
  types_to_upload = config.core.get("upload_models")
302
- for model_path in deployed_models:
303
- model_type = model_type_from_path(model_path)
304
- if model_type is None:
305
- logging.warning("%s model type not supported", model_path)
306
- continue
307
- if model_type is not None and model_type in types_to_upload:
308
- if model_type == "pytorch":
309
- logging.warning("Pytorch format still not supported for mlflow upload")
304
+ mlflow_zip_models = config.core.get("mlflow_zip_models", False)
305
+ model_uploaded = False
306
+ with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
307
+ for model_path in deployed_models:
308
+ model_type = model_type_from_path(model_path)
309
+ model_name = os.path.basename(model_path)
310
+
311
+ if model_type is None:
312
+ logging.warning("%s model type not supported", model_path)
310
313
  continue
311
-
312
- model = quadra_export.import_deployment_model(
313
- model_path,
314
- device=device,
315
- inference_config=config.inference,
316
- )
317
-
318
- if model_type in ["torchscript", "pytorch"]:
319
- signature = infer_signature_model(model.model, inputs)
320
- with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
321
- mlflow.pytorch.log_model(
322
- model.model,
323
- artifact_path=model_path,
324
- signature=signature,
325
- )
326
- elif model_type in ["onnx", "simplified_onnx"] and ONNX_AVAILABLE:
327
- signature = infer_signature_model(model, inputs)
328
- with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
329
- if model.model_path is None:
330
- logging.warning(
331
- "Cannot log onnx model on mlflow, \
332
- BaseEvaluationModel 'model_path' attribute is None"
314
+ if model_type is not None and model_type in types_to_upload:
315
+ if model_type == "pytorch" and not mlflow_zip_models:
316
+ logging.warning("Pytorch format still not supported for mlflow upload")
317
+ continue
318
+
319
+ if mlflow_zip_models:
320
+ with TemporaryDirectory() as temp_dir:
321
+ if model_type == "pytorch" and os.path.isfile(
322
+ os.path.join(export_folder, "model_config.yaml")
323
+ ):
324
+ shutil.copy(model_path, temp_dir)
325
+ shutil.copy(os.path.join(export_folder, "model_config.yaml"), temp_dir)
326
+ shutil.make_archive("assets", "zip", root_dir=temp_dir)
327
+ else:
328
+ shutil.make_archive(
329
+ "assets",
330
+ "zip",
331
+ root_dir=os.path.dirname(model_path),
332
+ base_dir=model_name,
333
+ )
334
+ shutil.move("assets.zip", temp_dir)
335
+ mlflow.pyfunc.log_model(
336
+ artifact_path=model_path,
337
+ loader_module="not.used",
338
+ data_path=os.path.join(temp_dir, "assets.zip"),
339
+ pip_requirements=[""],
333
340
  )
334
- else:
335
- model_proto = onnx.load(model.model_path)
336
- mlflow.onnx.log_model(
337
- model_proto,
341
+ model_uploaded = True
342
+ else:
343
+ model = quadra_export.import_deployment_model(
344
+ model_path,
345
+ device=device,
346
+ inference_config=config.inference,
347
+ )
348
+
349
+ if model_type in ["torchscript", "pytorch"]:
350
+ signature = infer_signature_model(model.model, inputs)
351
+ mlflow.pytorch.log_model(
352
+ model.model,
338
353
  artifact_path=model_path,
339
354
  signature=signature,
340
355
  )
356
+ model_uploaded = True
357
+
358
+ elif model_type in ["onnx", "simplified_onnx"] and ONNX_AVAILABLE:
359
+ if model.model_path is None:
360
+ logging.warning(
361
+ "Cannot log onnx model on mlflow, \
362
+ BaseEvaluationModel 'model_path' attribute is None"
363
+ )
364
+ else:
365
+ signature = infer_signature_model(model, inputs)
366
+ model_proto = onnx.load(model.model_path)
367
+ mlflow.onnx.log_model(
368
+ model_proto,
369
+ artifact_path=model_path,
370
+ signature=signature,
371
+ )
372
+ model_uploaded = True
373
+
374
+ if model_uploaded:
375
+ mlflow.log_artifact(os.path.join(export_folder, "model.json"), export_folder)
341
376
 
342
377
  if tensorboard_logger is not None:
343
378
  config_paths = []
@@ -376,7 +411,7 @@ def model_type_from_path(model_path: str) -> str | None:
376
411
  - "pytorch" if the model has a '.pth' extension (PyTorch).
377
412
  - "simplified_onnx" if the model file ends with 'simplified.onnx' (Simplified ONNX).
378
413
  - "onnx" if the model has a '.onnx' extension (ONNX).
379
- - "json" id the model has a '.json' extension (JSON).
414
+ - "json" if the model has a '.json' extension (JSON).
380
415
  - None if model extension is not supported.
381
416
 
382
417
  Example:
@@ -1,8 +1,9 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: quadra
3
- Version: 2.4.0
3
+ Version: 2.5.0
4
4
  Summary: Deep Learning experiment orchestration library
5
5
  License: Apache-2.0
6
+ License-File: LICENSE
6
7
  Keywords: deep learning,experiment,lightning,hydra-core
7
8
  Author: Federico Belotti
8
9
  Author-email: federico.belotti@orobix.com
@@ -1,4 +1,4 @@
1
- quadra/__init__.py,sha256=fv-5hfERt0uLXmjb7dOuh4wKtsXYgvW72hTUW_IvWQo,112
1
+ quadra/__init__.py,sha256=my_GwHSTlCXOqpoWuPRTAF1_rKRUMylmeA4eWoMa96g,112
2
2
  quadra/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  quadra/callbacks/anomalib.py,sha256=WLBEGhZA9HoP4Yh9UbbC2GzDOKYTkvU9EY1lkZcV7Fs,11971
4
4
  quadra/callbacks/lightning.py,sha256=qvtzDiv8ZUV7K11gKHKWCyo-a9XR_Jm_M-IEicTM1Yo,20242
@@ -36,7 +36,7 @@ quadra/configs/callbacks/all.yaml,sha256=LZx8d0apwv9t0KlKqzFCrYo2NpuNBgcUnHhG_ku
36
36
  quadra/configs/callbacks/default.yaml,sha256=ZFPU1bm36hJsxI-85uiJx7TpX7qWkR8dibBKWtES4Yc,1180
37
37
  quadra/configs/callbacks/default_anomalib.yaml,sha256=FjjSj6HgMvH18MV4AKR4Ew0pOerPefj9eUGinUkODLE,2256
38
38
  quadra/configs/config.yaml,sha256=IULhqUF8Z7Cqr5Xx41EGj8dwtPQWSRiZ5jwzNg0Rjwk,686
39
- quadra/configs/core/default.yaml,sha256=IfgjKHXuOknq3CKvKKBqMPfiqmZSUOnc80Q2Jkbm7go,239
39
+ quadra/configs/core/default.yaml,sha256=-qXngf5iSRIgPPBNZFu00TeYmE2DC0WuG7x_YPSAG44,264
40
40
  quadra/configs/datamodule/base/anomaly.yaml,sha256=CILLAoQHunrT4BN0ynOzizTX29k-7B9vDVIZUYm-cBU,377
41
41
  quadra/configs/datamodule/base/classification.yaml,sha256=NYtGk4lmi9Os6bP01AU9xJw3cONRPcGXisdrnoOYpjE,546
42
42
  quadra/configs/datamodule/base/multilabel_classification.yaml,sha256=zp7AQx7V6cLJn7zRMKA04zzWKgVq5jTDJMGltqd9OP8,613
@@ -262,7 +262,7 @@ quadra/utils/anomaly.py,sha256=49vFvT5-4SxczsEM2Akcut_M1DDwKlOVdGv36oLTgR0,4067
262
262
  quadra/utils/classification.py,sha256=dKFuv4RywWhvhstOnEOnaf-6qcViUK0dTgah9m9mw2Q,24917
263
263
  quadra/utils/deprecation.py,sha256=zF_S-yqenaZxRBOudhXts0mX763WjEUWCnHd09TZnwY,852
264
264
  quadra/utils/evaluation.py,sha256=oooRJPu1AaHhOwvB1Y6SFjQ645OkgrDzKtUvwWq8oq4,19005
265
- quadra/utils/export.py,sha256=dIbhnFPHo2wYoeyE48TeSzGjsf1FowCin3_ASR7BFJc,24621
265
+ quadra/utils/export.py,sha256=fUdcZ2_VKBjuM1yK9nuIEUQ6tNE21SACgBfWhyC5rjw,24651
266
266
  quadra/utils/imaging.py,sha256=Cz7sGb_axEmnGcwQJP2djFZpIpGCPFIBGT8NWVV-OOE,866
267
267
  quadra/utils/logger.py,sha256=tQJ4xpTAFKx1g-UUm5K1x7zgoP6qoXpcUHQyu0rOr1w,556
268
268
  quadra/utils/mlflow.py,sha256=DVso1lxn126hil8i4tTf5WFUPJ8uJNAzNU8OXbXwOzw,3586
@@ -288,13 +288,13 @@ quadra/utils/tests/fixtures/models/classification.py,sha256=5qpyOonqK6W2LCUWEHhm
288
288
  quadra/utils/tests/fixtures/models/segmentation.py,sha256=CTNXeEPcFxFq-YcNfQi5DbbytPZwBQaZn5dQq3L41j0,765
289
289
  quadra/utils/tests/helpers.py,sha256=9PJlwozUl_lpQW-Ck-tN7sGFcgeieEd3q56aYuwMIlk,2381
290
290
  quadra/utils/tests/models.py,sha256=KbAlv_ukxaUYsyVNUO_dM0NyIosx8RpC0EVyF1HvPkM,507
291
- quadra/utils/utils.py,sha256=3tgj_tFFhKsGNJ9jrmULI9rWxFyhuUe53Y5SBJFkwSM,19124
291
+ quadra/utils/utils.py,sha256=_-iD8MG4g_qrzzMcBrgPWSouQU96jI-5kgYSNYjs4d0,21293
292
292
  quadra/utils/validator.py,sha256=wmVXycB90VNyAbKBUVncFCxK4nsYiOWJIY3ISXwxYCY,4632
293
293
  quadra/utils/visualization.py,sha256=yYm7lPziUOlybxigZ2qTycNewb67Q80H4hjQGWUh788,16094
294
294
  quadra/utils/vit_explainability.py,sha256=Gh6BHaDEzWxOjJp1aqvCxLt9Rb8TXd5uKXOAx7-acUk,13351
295
295
  hydra_plugins/quadra_searchpath_plugin.py,sha256=AAn4TzR87zUK7nwSsK-KoqALiPtfQ8FvX3fgZPTGIJ0,1189
296
- quadra-2.4.0.dist-info/LICENSE,sha256=8cTbQtcWa02YJoSpMeV_gxj3jpMTkxvl-w3WJ5gV_QE,11342
297
- quadra-2.4.0.dist-info/METADATA,sha256=FOt90lNFxRQd84gcN-nFewN-IoclxQ8eDZvhWIeh1Do,17610
298
- quadra-2.4.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
299
- quadra-2.4.0.dist-info/entry_points.txt,sha256=sRYonBZyx-sAJeWcQNQoVQIU5lm02cnCQt6b15k0WHU,43
300
- quadra-2.4.0.dist-info/RECORD,,
296
+ quadra-2.5.0.dist-info/METADATA,sha256=s-yGU-tHTS03MQtY5lXMvPv739qgVmz4ZvdDJs_J4FE,17632
297
+ quadra-2.5.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
298
+ quadra-2.5.0.dist-info/entry_points.txt,sha256=sRYonBZyx-sAJeWcQNQoVQIU5lm02cnCQt6b15k0WHU,43
299
+ quadra-2.5.0.dist-info/licenses/LICENSE,sha256=8cTbQtcWa02YJoSpMeV_gxj3jpMTkxvl-w3WJ5gV_QE,11342
300
+ quadra-2.5.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.3
2
+ Generator: poetry-core 2.2.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any