dcnum 0.22.1__tar.gz → 0.23.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dcnum might be problematic. Click here for more details.

Files changed (118) hide show
  1. {dcnum-0.22.1 → dcnum-0.23.1}/.github/workflows/check.yml +1 -1
  2. {dcnum-0.22.1 → dcnum-0.23.1}/CHANGELOG +7 -0
  3. {dcnum-0.22.1 → dcnum-0.23.1}/PKG-INFO +4 -2
  4. {dcnum-0.22.1 → dcnum-0.23.1}/pyproject.toml +4 -0
  5. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/_version.py +2 -2
  6. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_background/base.py +1 -1
  7. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_texture/tex_all.py +28 -1
  8. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/gate.py +2 -2
  9. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/queue_event_extractor.py +1 -1
  10. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/logic/ctrl.py +2 -1
  11. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/meta/ppid.py +16 -2
  12. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/__init__.py +4 -1
  13. dcnum-0.23.1/src/dcnum/segm/segm_torch/__init__.py +19 -0
  14. dcnum-0.23.1/src/dcnum/segm/segm_torch/segm_torch_base.py +125 -0
  15. dcnum-0.23.1/src/dcnum/segm/segm_torch/segm_torch_mpo.py +71 -0
  16. dcnum-0.23.1/src/dcnum/segm/segm_torch/segm_torch_sto.py +88 -0
  17. dcnum-0.23.1/src/dcnum/segm/segm_torch/torch_model.py +95 -0
  18. dcnum-0.23.1/src/dcnum/segm/segm_torch/torch_postproc.py +93 -0
  19. dcnum-0.23.1/src/dcnum/segm/segm_torch/torch_preproc.py +114 -0
  20. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/segmenter.py +41 -1
  21. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum.egg-info/PKG-INFO +4 -2
  22. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum.egg-info/SOURCES.txt +12 -1
  23. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum.egg-info/requires.txt +3 -0
  24. dcnum-0.23.1/tests/data/segm-torch-model_unet-dcnum-test_g1_910c2.zip +0 -0
  25. dcnum-0.23.1/tests/data/segm-torch-test-data_unet-dcnum-test_g1_910c2.zip +0 -0
  26. {dcnum-0.22.1 → dcnum-0.23.1}/tests/helper_methods.py +23 -21
  27. dcnum-0.23.1/tests/test_segm_torch.py +175 -0
  28. dcnum-0.23.1/tests/test_segm_torch_preproc.py +90 -0
  29. {dcnum-0.22.1 → dcnum-0.23.1}/.github/workflows/deploy_pypi.yml +0 -0
  30. {dcnum-0.22.1 → dcnum-0.23.1}/.gitignore +0 -0
  31. {dcnum-0.22.1 → dcnum-0.23.1}/.readthedocs.yml +0 -0
  32. {dcnum-0.22.1 → dcnum-0.23.1}/LICENSE +0 -0
  33. {dcnum-0.22.1 → dcnum-0.23.1}/README.rst +0 -0
  34. {dcnum-0.22.1 → dcnum-0.23.1}/docs/conf.py +0 -0
  35. {dcnum-0.22.1 → dcnum-0.23.1}/docs/extensions/github_changelog.py +0 -0
  36. {dcnum-0.22.1 → dcnum-0.23.1}/docs/index.rst +0 -0
  37. {dcnum-0.22.1 → dcnum-0.23.1}/docs/requirements.txt +0 -0
  38. {dcnum-0.22.1 → dcnum-0.23.1}/setup.cfg +0 -0
  39. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/__init__.py +0 -0
  40. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/__init__.py +0 -0
  41. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/event_extractor_manager_thread.py +0 -0
  42. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_background/__init__.py +0 -0
  43. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_background/bg_copy.py +0 -0
  44. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_background/bg_roll_median.py +0 -0
  45. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_background/bg_sparse_median.py +0 -0
  46. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_brightness/__init__.py +0 -0
  47. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_brightness/bright_all.py +0 -0
  48. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_brightness/common.py +0 -0
  49. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_contour/__init__.py +0 -0
  50. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_contour/contour.py +0 -0
  51. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_contour/moments.py +0 -0
  52. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_contour/volume.py +0 -0
  53. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_texture/__init__.py +0 -0
  54. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/feat/feat_texture/common.py +0 -0
  55. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/logic/__init__.py +0 -0
  56. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/logic/job.py +0 -0
  57. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/logic/json_encoder.py +0 -0
  58. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/meta/__init__.py +0 -0
  59. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/meta/paths.py +0 -0
  60. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/read/__init__.py +0 -0
  61. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/read/cache.py +0 -0
  62. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/read/const.py +0 -0
  63. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/read/hdf5_data.py +0 -0
  64. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/read/mapped.py +0 -0
  65. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/segm_thresh.py +0 -0
  66. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/segmenter_manager_thread.py +0 -0
  67. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/segmenter_mpo.py +0 -0
  68. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/segm/segmenter_sto.py +0 -0
  69. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/write/__init__.py +0 -0
  70. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/write/deque_writer_thread.py +0 -0
  71. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/write/queue_collector_thread.py +0 -0
  72. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum/write/writer.py +0 -0
  73. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum.egg-info/dependency_links.txt +0 -0
  74. {dcnum-0.22.1 → dcnum-0.23.1}/src/dcnum.egg-info/top_level.txt +0 -0
  75. {dcnum-0.22.1 → dcnum-0.23.1}/tests/conftest.py +0 -0
  76. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_cytoshot_extended-moments-features.zip +0 -0
  77. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_cytoshot_full-features_2023.zip +0 -0
  78. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_cytoshot_full-features_2024.zip +0 -0
  79. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_cytoshot_full-features_legacy_allev_2023.zip +0 -0
  80. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_shapein_empty.zip +0 -0
  81. {dcnum-0.22.1 → dcnum-0.23.1}/tests/data/fmt-hdf5_shapein_raw-with-variable-length-logs.zip +0 -0
  82. {dcnum-0.22.1 → dcnum-0.23.1}/tests/requirements.txt +0 -0
  83. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_background_base.py +0 -0
  84. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_background_bg_copy.py +0 -0
  85. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_background_bg_roll_median.py +0 -0
  86. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_background_bg_sparsemed.py +0 -0
  87. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_brightness.py +0 -0
  88. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_event_extractor_manager.py +0 -0
  89. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_gate.py +0 -0
  90. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_haralick.py +0 -0
  91. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_moments_based.py +0 -0
  92. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_moments_based_extended.py +0 -0
  93. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_feat_volume.py +0 -0
  94. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_init.py +0 -0
  95. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_logic_job.py +0 -0
  96. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_logic_join.py +0 -0
  97. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_logic_json.py +0 -0
  98. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_logic_pipeline.py +0 -0
  99. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_paths.py +0 -0
  100. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_base.py +0 -0
  101. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_bg.py +0 -0
  102. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_data.py +0 -0
  103. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_feat.py +0 -0
  104. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_gate.py +0 -0
  105. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_meta_ppid_segm.py +0 -0
  106. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_read_basin.py +0 -0
  107. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_read_concat_hdf5.py +0 -0
  108. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_read_hdf5.py +0 -0
  109. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_read_hdf5_basins.py +0 -0
  110. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_read_hdf5_index_mapping.py +0 -0
  111. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_segm_base.py +0 -0
  112. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_segm_mpo.py +0 -0
  113. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_segm_no_mask_proc.py +0 -0
  114. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_segm_sto.py +0 -0
  115. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_segm_thresh.py +0 -0
  116. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_write_deque_writer_thread.py +0 -0
  117. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_write_queue_collector_thread.py +0 -0
  118. {dcnum-0.22.1 → dcnum-0.23.1}/tests/test_write_writer.py +0 -0
@@ -31,7 +31,7 @@ jobs:
31
31
  run: |
32
32
  # https://github.com/luispedro/mahotas/issues/144
33
33
  pip install mahotas==1.4.13
34
- pip install -e .
34
+ pip install .[torch]
35
35
  - name: List installed packages
36
36
  run: |
37
37
  pip freeze
@@ -1,3 +1,10 @@
1
+ 0.23.1
2
+ - enh: support passing custom default arguments to get_class_method_info
3
+ - tests: fix torch preprocessing tests
4
+ 0.23.0
5
+ - feat: implement segmentation using PyTorch models
6
+ - fix: always compute image_bg if it is not in the input file
7
+ - enh: introduce `Segmenter.validate_applicability` method
1
8
  0.22.1
2
9
  - fix: compute pipeline identifier of origin dataset for basin mapping
3
10
  0.22.0
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dcnum
3
- Version: 0.22.1
3
+ Version: 0.23.1
4
4
  Summary: numerics toolbox for imaging deformability cytometry
5
- Author: Maximilian Schlögel, Paul Müller
5
+ Author: Maximilian Schlögel, Paul Müller, Raghava Alajangi
6
6
  Maintainer-email: Paul Müller <dev@craban.de>
7
7
  License: MIT
8
8
  Project-URL: source, https://github.com/DC-Analysis/dcnum
@@ -25,6 +25,8 @@ Requires-Dist: numpy>=1.21
25
25
  Requires-Dist: opencv-python-headless
26
26
  Requires-Dist: scikit-image
27
27
  Requires-Dist: scipy>=1.8.0
28
+ Provides-Extra: torch
29
+ Requires-Dist: torch>=2.3; extra == "torch"
28
30
 
29
31
  |dcnum|
30
32
  =======
@@ -8,6 +8,7 @@ authors = [
8
8
  # In alphabetical order.
9
9
  {name = "Maximilian Schlögel"},
10
10
  {name = "Paul Müller"},
11
+ {name = "Raghava Alajangi"},
11
12
  ]
12
13
  maintainers = [
13
14
  {name = "Paul Müller", email="dev@craban.de"},
@@ -35,6 +36,9 @@ dependencies = [
35
36
  ]
36
37
  dynamic = ["version"]
37
38
 
39
+ [project.optional-dependencies]
40
+ torch = ["torch>=2.3"]
41
+
38
42
  [project.urls]
39
43
  source = "https://github.com/DC-Analysis/dcnum"
40
44
  tracker = "https://github.com/DC-Analysis/dcnum/issues"
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.22.1'
16
- __version_tuple__ = version_tuple = (0, 22, 1)
15
+ __version__ = version = '0.23.1'
16
+ __version_tuple__ = version_tuple = (0, 23, 1)
@@ -130,7 +130,7 @@ class Background(abc.ABC):
130
130
  """Return a unique background pipeline identifier
131
131
 
132
132
  The pipeline identifier is universally applicable and must
133
- be backwards-compatible (future versions of dcevent will
133
+ be backwards-compatible (future versions of dcnum will
134
134
  correctly acknowledge the ID).
135
135
 
136
136
  The segmenter pipeline ID is defined as::
@@ -6,6 +6,34 @@ from .common import haralick_names
6
6
 
7
7
  def haralick_texture_features(
8
8
  mask, image=None, image_bg=None, image_corr=None):
9
+ """Compute Haralick texture features
10
+
11
+ The following texture features are excluded
12
+
13
+ - feature 6 "Sum Average", which is equivalent to `2 * bright_bc_avg`
14
+ since dclab 0.44.0
15
+ - feature 10 "Difference Variance", because it has a functional
16
+ dependency on the offset value and since we do background correction,
17
+ we are not interested in it
18
+ - feature 14, because nobody is using it, it is not understood by
19
+ everyone what it actually is, and it is computationally expensive.
20
+
21
+ This leaves us with the following 11 texture features (22 if you count
22
+ avg and ptp):
23
+ https://earlglynn.github.io/RNotes/package/EBImage/Haralick-Textural-Features.html
24
+
25
+ - 1. `tex_asm`: (1) Angular Second Moment
26
+ - 2. `tex_con`: (2) Contrast
27
+ - 3. `tex_cor`: (3) Correlation
28
+ - 4. `tex_var`: (4) Variance
29
+ - 5. `tex_idm`: (5) Inverse Difference Moment
30
+ - 6. `tex_sva`: (7) Sum Variance
31
+ - 7. `tex_sen`: (8) Sum Entropy
32
+ - 8. `tex_ent`: (9) Entropy
33
+ - 9. `tex_den`: (11) Difference Entropy
34
+ - 10. `tex_f12`: (12) Information Measure of Correlation 1
35
+ - 11. `tex_f13`: (13) Information Measure of Correlation 2
36
+ """
9
37
  # make sure we have a boolean array
10
38
  mask = np.array(mask, dtype=bool)
11
39
  size = mask.shape[0]
@@ -22,7 +50,6 @@ def haralick_texture_features(
22
50
 
23
51
  for ii in range(size):
24
52
  # Haralick texture features
25
- # https://gitlab.gwdg.de/blood_data_analysis/dcevent/-/issues/20
26
53
  # Preprocessing:
27
54
  # - create a copy of the array (don't edit `image_corr`)
28
55
  # - add grayscale values (negative values not supported)
@@ -20,7 +20,7 @@ class Gate:
20
20
  Parameters
21
21
  ----------
22
22
  data: .HDF5Data
23
- dcevent data instance
23
+ dcnum data instance
24
24
  online_gates: bool
25
25
  set to True to enable gating with "online" gates stored
26
26
  in the input file; online gates are applied in real-time
@@ -95,7 +95,7 @@ class Gate:
95
95
  """Return a unique gating pipeline identifier
96
96
 
97
97
  The pipeline identifier is universally applicable and must
98
- be backwards-compatible (future versions of dcevent will
98
+ be backwards-compatible (future versions of dcnum will
99
99
  correctly acknowledge the ID).
100
100
 
101
101
  The gating pipeline ID is defined as::
@@ -266,7 +266,7 @@ class QueueEventExtractor:
266
266
  """Return a unique feature extractor pipeline identifier
267
267
 
268
268
  The pipeline identifier is universally applicable and must
269
- be backwards-compatible (future versions of dcevent will
269
+ be backwards-compatible (future versions of dcnum will
270
270
  correctly acknowledge the ID).
271
271
 
272
272
  The feature extractor pipeline ID is defined as::
@@ -339,7 +339,8 @@ class DCNumJobRunner(threading.Thread):
339
339
  # hash sanity check above, check the generation, input data,
340
340
  # and background pipeline identifiers.
341
341
  redo_bg = (
342
- (datdict["gen_id"] != self.ppdict["gen_id"])
342
+ "image_bg" not in self.draw
343
+ or (datdict["gen_id"] != self.ppdict["gen_id"])
343
344
  or (datdict["dat_id"] != self.ppdict["dat_id"])
344
345
  or (datdict["bg_id"] != self.ppdict["bg_id"]))
345
346
 
@@ -59,7 +59,9 @@ def convert_to_dtype(value, dtype):
59
59
 
60
60
 
61
61
  def get_class_method_info(class_obj: ClassWithPPIDCapabilities,
62
- static_kw_methods: List = None):
62
+ static_kw_methods: List = None,
63
+ static_kw_defaults: Dict = None,
64
+ ):
63
65
  """Return dictionary of class info with static keyword methods docs
64
66
 
65
67
  Parameters
@@ -69,7 +71,16 @@ def get_class_method_info(class_obj: ClassWithPPIDCapabilities,
69
71
  static_kw_methods: list of callable
70
72
  The methods to inspect; all kwargs-only keyword arguments
71
73
  are extracted.
74
+ static_kw_defaults: dict
75
+ If a key in this dictionary matches an item in `static_kw_methods`,
76
+ then these are the default values returned in the "defaults"
77
+ dictionary. This is used in cases where a base class does
78
+ implement some annotations, but the subclass does not actually
79
+ use them, because e.g. they are taken from a property such as is
80
+ the case for the mask postprocessing of segmenter classes.
72
81
  """
82
+ if static_kw_defaults is None:
83
+ static_kw_defaults = {}
73
84
  doc = class_obj.__doc__ or class_obj.__init__.__doc__
74
85
  info = {
75
86
  "code": class_obj.get_ppid_code(),
@@ -82,7 +93,10 @@ def get_class_method_info(class_obj: ClassWithPPIDCapabilities,
82
93
  for mm in static_kw_methods:
83
94
  meth = getattr(class_obj, mm)
84
95
  spec = inspect.getfullargspec(meth)
85
- defau[mm] = spec.kwonlydefaults or {}
96
+ if mm_defaults := static_kw_defaults.get(mm):
97
+ defau[mm] = mm_defaults
98
+ else:
99
+ defau[mm] = spec.kwonlydefaults or {}
86
100
  annot[mm] = spec.annotations
87
101
  info["defaults"] = defau
88
102
  info["annotations"] = annot
@@ -1,6 +1,9 @@
1
1
  # flake8: noqa: F401
2
- from .segmenter import Segmenter, get_available_segmenters
2
+ from .segmenter import (
3
+ Segmenter, SegmenterNotApplicableError, get_available_segmenters
4
+ )
3
5
  from .segmenter_mpo import MPOSegmenter
4
6
  from .segmenter_sto import STOSegmenter
5
7
  from .segmenter_manager_thread import SegmenterManagerThread
6
8
  from . import segm_thresh
9
+ from . import segm_torch
@@ -0,0 +1,19 @@
1
+ import importlib
2
+
3
# Make the PyTorch-based segmenters available only if PyTorch is installed.
# PyTorch is an optional dependency: an ImportError while importing it is
# ignored, whereas an installed-but-too-old version raises a ValueError
# (which is not caught by the `except ImportError` clause below).
try:
    torch = importlib.import_module("torch")
    # minimum supported PyTorch version (major, minor)
    req_maj = 2
    req_min = 3
    # Only the first two dot-separated components are parsed.
    # NOTE(review): assumes `torch.__version__` starts with
    # "<major>.<minor>" - confirm for unusual builds.
    ver_tuple = torch.__version__.split(".")
    act_maj = int(ver_tuple[0])
    act_min = int(ver_tuple[1])
    if act_maj < req_maj or (act_maj == req_maj and act_min < req_min):
        raise ValueError(f"Your PyTorch version {act_maj}.{act_min} is not "
                         f"supported, please update to at least "
                         f"{req_maj}.{req_min}")
except ImportError:
    # PyTorch could not be imported; torch segmenters are not registered.
    pass
else:
    # The CPU (multiprocessing) segmenter is always imported.
    from .segm_torch_mpo import SegmentTorchMPO  # noqa: F401
    # The GPU segmenter is only imported when CUDA is available.
    if torch.cuda.is_available():
        from .segm_torch_sto import SegmentTorchSTO  # noqa: F401
@@ -0,0 +1,125 @@
1
+ import functools
2
+ import pathlib
3
+ import re
4
+ from typing import Dict
5
+
6
+ from ...meta import paths
7
+
8
+ from ..segmenter import Segmenter, SegmenterNotApplicableError
9
+
10
+ from .torch_model import load_model
11
+
12
+
13
class TorchSegmenterBase(Segmenter):
    """Torch segmenters that use a pretrained model for segmentation"""
    requires_background_correction = False
    mask_postprocessing = True
    mask_default_kwargs = {
        "clear_border": True,
        "fill_holes": True,
        "closing_disk": 0,
    }

    @classmethod
    def get_ppid_from_ppkw(cls, kwargs, kwargs_mask=None):
        """Return the segmenter pipeline identifier for keyword arguments

        The `model_file` keyword argument is reduced to its file name,
        so that the pipeline identifier only contains the name, but not
        the full path. If `model_file` points to an existing file, its
        parent directory is registered in the "torch_model_files" search
        path registry, so that other threads/processes can find the file
        by name.
        """
        kwargs_new = kwargs.copy()
        if "model_file" in kwargs:
            mpath = pathlib.Path(kwargs["model_file"])
            if mpath.exists():
                # register the location of the file in the search path
                # registry so other threads/processes will find it.
                paths.register_search_path("torch_model_files", mpath.parent)
            kwargs_new["model_file"] = mpath.name
        return super(TorchSegmenterBase, cls).get_ppid_from_ppkw(kwargs_new,
                                                                 kwargs_mask)

    @classmethod
    def validate_applicability(cls,
                               segmenter_kwargs: Dict,
                               meta: Dict = None,
                               logs: Dict = None):
        """Validate the applicability of this segmenter for a dataset

        The applicability is defined by the "validation" entries in the
        metadata of the segmentation model.

        Parameters
        ----------
        segmenter_kwargs: dict
            Keyword arguments for the segmenter; must contain `model_file`
        meta: dict
            Dictionary of metadata from an :class:`HDF5Data` instance;
            `None` is treated as an empty dictionary
        logs: dict
            Dictionary of logs from an :class:`HDF5Data` instance; each
            value is an iterable of strings that is joined with newlines
            before matching. `None` is treated as an empty dictionary

        Returns
        -------
        applicable: bool
            True if the segmenter is applicable to the dataset

        Raises
        ------
        ValueError
            If `segmenter_kwargs` does not contain a `model_file`
        SegmenterNotApplicableError
            If the segmenter is not applicable to the dataset
        """
        if "model_file" not in segmenter_kwargs:
            raise ValueError("A `model_file` must be provided in the "
                             "`segmenter_kwargs` to validate applicability")

        # The defaults are `None`; treat missing metadata/logs as empty so
        # that validation items report missing keys instead of crashing.
        meta = {} if meta is None else meta
        logs = {} if logs is None else logs

        model_file = segmenter_kwargs["model_file"]
        _, model_meta = load_model(model_file, device="cpu")

        reasons_list = []
        validators = {
            "meta": functools.partial(
                cls._validate_applicability_item,
                data_dict=meta,
                reasons_list=reasons_list),
            "logs": functools.partial(
                cls._validate_applicability_item,
                # convert logs to strings
                data_dict={key: "\n".join(val) for key, val in logs.items()},
                reasons_list=reasons_list)
        }
        for item in model_meta.get("validation", []):
            it = item["type"]
            if it in validators:
                validators[it](item)
            else:
                reasons_list.append(
                    f"invalid validation type {it} in {model_file}")

        if reasons_list:
            raise SegmenterNotApplicableError(segmenter_class=cls,
                                              reasons_list=reasons_list)

        return True

    @staticmethod
    def _validate_applicability_item(item, data_dict, reasons_list):
        """Populate `reasons_list` with invalid entries

        Example `item`::

            {"type": "meta",
             "key": "setup:region",
             "allow-missing-key": False,
             "regexp": "^channel$",
             "regexp-negate": False,
             "reason": "only channel region supported",
             }

        The value stored under `item["key"]` in `data_dict` is matched
        against `item["regexp"]` with :func:`re.match` (i.e. anchored at
        the start of the string). Non-string values are converted with
        `str` first, so numeric metadata can be validated, too.
        """
        key = item["key"]
        if key in data_dict:
            matched = re.match(item["regexp"],
                               str(data_dict[key])) is not None
            if item.get("regexp-negate", False):
                matched = not matched
            if not matched:
                reasons_list.append(item.get("reason", "unknown reason"))
        elif not item.get("allow-missing-key", False):
            reasons_list.append(f"Key '{key}' missing in {item['type']}")
@@ -0,0 +1,71 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from ..segmenter_mpo import MPOSegmenter
5
+
6
+ from .segm_torch_base import TorchSegmenterBase
7
+ from .torch_model import load_model
8
+ from .torch_preproc import preprocess_images
9
+ from .torch_postproc import postprocess_masks
10
+
11
+
12
class SegmentTorchMPO(TorchSegmenterBase, MPOSegmenter):
    """PyTorch segmentation (multiprocessing version)"""

    @staticmethod
    def segment_algorithm(image, *,
                          model_file: str = None):
        """Segment a single event image with a PyTorch model on the CPU

        Parameters
        ----------
        image: 2d ndarray
            event image
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 2d boolean or integer ndarray
            mask or labeling image for the given image

        Raises
        ------
        ValueError
            If no `model_file` is specified
        """
        if model_file is None:
            raise ValueError("Please specify a .dcnm model file!")

        # Set number of pytorch threads to 1, because dcnum is doing
        # all the multiprocessing.
        # https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads
        torch.set_num_threads(1)
        device = torch.device("cpu")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Add a leading axis so that a stack of one image is preprocessed
        image_preproc = preprocess_images(image[np.newaxis, :, :],
                                          **model_meta["preprocessing"])

        # Move image tensor to device
        image_ten = torch.from_numpy(image_preproc).to(device)

        # Model inference; no autograd bookkeeping is required
        with torch.inference_mode():
            pred_tensor = model(image_ten)

        # Convert the prediction tensor into a boolean numpy mask array.
        # The `pred_tensor` array is of the shape (1, 1, H, W). The `masks`
        # array is of shape (1, H, W). We can optionally label it
        # here (we have to if the shapes don't match) or do it in
        # postprocessing.
        masks = pred_tensor.detach().cpu().numpy()[0] >= 0.5

        # Perform postprocessing in cases where the image shapes don't match
        assert len(masks[0].shape) == len(image.shape), "sanity check"
        if masks[0].shape != image.shape:
            labels = postprocess_masks(
                masks=masks,
                original_image_shape=image.shape,
            )
            return labels[0]
        else:
            return masks[0]
@@ -0,0 +1,88 @@
1
+ from dcnum.segm import STOSegmenter
2
+ import numpy as np
3
+ import torch
4
+
5
+ from .segm_torch_base import TorchSegmenterBase
6
+ from .torch_model import load_model
7
+ from .torch_preproc import preprocess_images
8
+ from .torch_postproc import postprocess_masks
9
+
10
+
11
class SegmentTorchSTO(TorchSegmenterBase, STOSegmenter):
    """PyTorch segmentation (GPU version)"""

    @staticmethod
    def _segment_in_batches(imgs_t, model, batch_size, device):
        """Segment image data in batches

        Parameters
        ----------
        imgs_t: ndarray
            preprocessed image stack; the first axis indexes the images
            and the last two axes are the spatial axes
        model: callable
            model that returns a tensor of shape [B, C, H, W] for a batch
            (presumably C == 1, since axis 1 is squeezed)
        batch_size: int
            number of images to process in one batch
        device: torch.device
            device on which inference is performed

        Returns
        -------
        masks: 3d boolean ndarray
            thresholded (>= 0.5) model output of shape (N, H, W)
        """
        size = len(imgs_t)
        # Use the array shape (not the first image) so that an empty
        # input yields an empty mask array instead of an IndexError.
        masks = np.empty((size, *imgs_t.shape[-2:]), dtype=bool)

        with torch.inference_mode():
            for start_idx in range(0, size, batch_size):
                stop_idx = start_idx + batch_size
                # Move image tensors to the device
                batch = torch.tensor(imgs_t[start_idx:stop_idx],
                                     device=device)
                # Model inference and removal of the extra dim
                # [B, C, H, W] --> [B, H, W]
                batch_seg = model(batch).squeeze(1)
                # Convert device tensor into numpy array and fill the
                # result array with the thresholded batch
                batch_seg_np = batch_seg.detach().cpu().numpy()
                masks[start_idx:stop_idx] = batch_seg_np >= 0.5

        return masks

    @staticmethod
    def segment_algorithm(images, gpu_id=None, batch_size=50, *,
                          model_file: str = None):
        """Segment a stack of event images with a PyTorch model

        Parameters
        ----------
        images: 3d ndarray
            array of N event images of shape (N, H, W)
        gpu_id: str
            optional device specification passed to :class:`torch.device`
            (defaults to "cuda")
        batch_size: int
            number of images to process in one batch
        model_file: str
            path to or name of a dcnum model file (.dcnm); if only a
            name is provided, then the "torch_model_files" directory
            paths are searched for the file name

        Returns
        -------
        mask: 3d boolean or integer ndarray
            mask or label images of shape (N, H, W)

        Raises
        ------
        ValueError
            If no `model_file` is specified
        """
        if model_file is None:
            raise ValueError("Please specify a model file!")

        # Determine device to use
        device = torch.device(gpu_id if gpu_id is not None else "cuda")

        # Load model and metadata
        model, model_meta = load_model(model_file, device)

        # Preprocess the images
        image_preproc = preprocess_images(images,
                                          **model_meta["preprocessing"])
        # Model inference
        # The `masks` array has the shape (len(images), H, W), where
        # H and W may be different from the corresponding axes in `images`.
        masks = SegmentTorchSTO._segment_in_batches(image_preproc,
                                                    model,
                                                    batch_size,
                                                    device
                                                    )

        # Perform postprocessing in cases where the image shapes don't match
        assert len(masks.shape[1:]) == len(images.shape[1:]), "sanity check"
        if masks.shape[1:] != images.shape[1:]:
            labels = postprocess_masks(
                masks=masks,
                original_image_shape=images.shape[1:])
            return labels
        else:
            return masks
@@ -0,0 +1,95 @@
1
+ import errno
2
+ import functools
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import os
7
+ import pathlib
8
+
9
+ import torch
10
+
11
+ from ...meta import paths
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def check_md5sum(path):
    """Verify that a file name encodes the file's MD5 hash

    The part of the file stem after the last underscore (the whole stem
    if there is no underscore) must equal the first five hexadecimal
    characters of the file's MD5 hash, e.g. ``model_910c2.dcnm``.
    The file is hashed in chunks, so large model files are not loaded
    into memory at once.

    Parameters
    ----------
    path: pathlib.Path
        path to the file to check

    Raises
    ------
    ValueError
        If the MD5 hash does not match the file name
    """
    hasher = hashlib.md5()
    with path.open("rb") as fd:
        while chunk := fd.read(1024 * 1024):
            hasher.update(chunk)
    md5 = hasher.hexdigest()
    if md5[:5] != path.stem.split("_")[-1]:
        raise ValueError(f"MD5 mismatch for {path} ({md5})! Expected the "
                         f"input file to end with '{md5[:5]}{path.suffix}'.")
23
+
24
+
25
@functools.cache
def load_model(path_or_name, device):
    """Load a TorchScript model together with its dcnum metadata

    Results are cached per (`path_or_name`, `device`) combination.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        TorchScript jit checkpoint (dcnum uses the suffix .dcnm) or its
        file name; the checkpoint carries an extra file named
        "dcnum_meta.json" that holds the model metadata as JSON
    device: str or torch.device
        device the model is mapped to and run on

    Returns
    -------
    model_jit: torch.jit.ScriptModule
        model in evaluation mode, optimized for inference
    model_meta: dict
        metadata decoded from the "dcnum_meta.json" extra file
    """
    meta_name = "dcnum_meta.json"
    # `torch.jit.load` fills in the values of this mapping
    extras = {meta_name: ""}
    scripted = torch.jit.load(retrieve_model_file(path_or_name),
                              _extra_files=extras,
                              map_location=device)
    metadata = json.loads(extras[meta_name])
    scripted.eval()
    return torch.jit.optimize_for_inference(scripted), metadata
59
+
60
+
61
@functools.cache
def retrieve_model_file(path_or_name):
    """Retrieve a dcnum torch model file

    If a path to an existing model file is given, then this path is used
    directly. If a file name (or a path that does not exist) is given,
    then the name is looked up in the current working directory and then
    with :func:`dcnum.meta.paths.find_file` using the "torch_model_files"
    topic. The MD5 sum of the resolved file is verified with
    :func:`check_md5sum` before it is returned.

    Parameters
    ----------
    path_or_name: str or pathlib.Path
        path to or name of a dcnum model file

    Returns
    -------
    path:
        location of the model file

    Raises
    ------
    FileNotFoundError
        If a `pathlib.Path` is given that does not exist and its file
        name could not be located either
    ValueError
        If `path_or_name` is neither a string nor a path, or if the
        MD5 check of the located file fails
    """
    if isinstance(path_or_name, pathlib.Path):
        if path_or_name.exists():
            path = path_or_name
        else:
            try:
                path = _find_model_file_by_name(path_or_name.name)
            except Exception as exc:
                # Only the lookup is guarded; MD5 errors below propagate
                # unchanged and interrupts are not swallowed.
                raise FileNotFoundError(errno.ENOENT,
                                        os.strerror(errno.ENOENT),
                                        str(path_or_name)) from exc
    elif isinstance(path_or_name, str):
        path = _find_model_file_by_name(path_or_name)
    else:
        raise ValueError(
            f"Please pass a string or a path, got {type(path_or_name)}!")

    logger.info(f"Found dcnum model file {path}")
    check_md5sum(path)
    return path


def _find_model_file_by_name(name):
    """Locate a model file by name (working directory, then search paths)"""
    name = name.strip()
    if pathlib.Path(name).exists():
        return pathlib.Path(name)
    return paths.find_file("torch_model_files", name)