zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,496 @@
1
+ """
2
+ This module provides some utilities to edit zea data files.
3
+
4
+ Available operations
5
+ --------------------
6
+
7
+ - `sum`: Sum multiple raw data files into one.
8
+
9
+ - `compound_frames`: Compound frames in a raw data file to increase SNR.
10
+
11
+ - `compound_transmits`: Compound transmits in a raw data file to increase SNR.
12
+
13
+ - `resave`: Resave a zea data file. This can be used to change the file format version.
14
+
15
+ - `extract`: extract frames and transmits in a raw data file.
16
+ """
17
+
18
import argparse
import sys
from pathlib import Path

import numpy as np

from zea import Probe, Scan
from zea.data.data_format import generate_zea_dataset, load_additional_elements, load_description
from zea.data.file import load_file_all_data_types
from zea.internal.checks import _IMAGE_DATA_TYPES, _NON_IMAGE_DATA_TYPES
from zea.internal.core import DataTypes
from zea.log import logger
29
+
30
+ ALL_DATA_TYPES_EXCEPT_RAW = set(_IMAGE_DATA_TYPES + _NON_IMAGE_DATA_TYPES) - {"raw_data"}
31
+
32
+ OPERATION_NAMES = [
33
+ "sum",
34
+ "compound_frames",
35
+ "compound_transmits",
36
+ "resave",
37
+ "extract",
38
+ ]
39
+
40
+
41
def save_file(
    path,
    scan: Scan,
    probe: Probe,
    raw_data: np.ndarray = None,
    aligned_data: np.ndarray = None,
    beamformed_data: np.ndarray = None,
    envelope_data: np.ndarray = None,
    image: np.ndarray = None,
    image_sc: np.ndarray = None,
    additional_elements=None,
    description="",
    **kwargs,
):
    """Saves data to a zea data file (h5py file).

    Thin wrapper around :func:`zea.data.data_format.generate_zea_dataset` that
    unpacks the acquisition parameters from ``scan`` and ``probe``.

    Args:
        path (str, pathlike): The path to the hdf5 file.
        scan (Scan): The scan object containing the parameters of the acquisition.
        probe (Probe): The probe object containing the parameters of the probe.
        raw_data (np.ndarray, optional): Raw channel data to save. Defaults to None.
        aligned_data (np.ndarray, optional): Time-aligned channel data. Defaults to None.
        beamformed_data (np.ndarray, optional): Beamformed data. Defaults to None.
        envelope_data (np.ndarray, optional): Envelope-detected data. Defaults to None.
        image (np.ndarray, optional): Image data. Defaults to None.
        image_sc (np.ndarray, optional): Scan-converted image data. Defaults to None.
        additional_elements (list of DatasetElement, optional): Additional elements to save in the
            file. Defaults to None.
        description (str, optional): Free-text description stored in the file. Defaults to "".
        **kwargs: Accepted and silently ignored.
            NOTE(review): this swallows unknown keys (e.g. typos) without
            warning — confirm this is intentional for **data_dict call sites.
    """

    generate_zea_dataset(
        path=path,
        raw_data=raw_data,
        aligned_data=aligned_data,
        beamformed_data=beamformed_data,
        image=image,
        image_sc=image_sc,
        envelope_data=envelope_data,
        # Probe identity is not preserved; everything is stored as "generic".
        probe_name="generic",
        probe_geometry=probe.probe_geometry,
        sampling_frequency=scan.sampling_frequency,
        center_frequency=scan.center_frequency,
        initial_times=scan.initial_times,
        t0_delays=scan.t0_delays,
        sound_speed=scan.sound_speed,
        focus_distances=scan.focus_distances,
        polar_angles=scan.polar_angles,
        azimuth_angles=scan.azimuth_angles,
        tx_apodizations=scan.tx_apodizations,
        bandwidth_percent=scan.bandwidth_percent,
        time_to_next_transmit=scan.time_to_next_transmit,
        tgc_gain_curve=scan.tgc_gain_curve,
        element_width=scan.element_width,
        tx_waveform_indices=scan.tx_waveform_indices,
        waveforms_one_way=scan.waveforms_one_way,
        waveforms_two_way=scan.waveforms_two_way,
        description=description,
        additional_elements=additional_elements,
    )
95
+
96
+
97
def sum_data(input_paths: list[Path], output_path: Path, overwrite=False):
    """
    Sums multiple raw data files and saves the result to a new file.

    All present data types are summed element-wise. Image data (``image``,
    ``image_sc``) is assumed to be log-compressed and is therefore summed in
    the linear domain via exp/log.

    Args:
        input_paths (list[Path]): List of paths to the input raw data files.
        output_path (Path): Path to the output file where the summed data will be saved.
        overwrite (bool, optional): Whether to overwrite the output file if it exists. Defaults to
            False.
    """

    data_dict, scan, probe = load_file_all_data_types(input_paths[0])
    description = load_description(input_paths[0])
    additional_elements = load_additional_elements(input_paths[0])

    # Data types that are summed directly.
    linear_keys = ("raw_data", "aligned_data", "beamformed_data", "envelope_data")
    # Log-compressed data types, summed in the linear domain.
    log_keys = ("image", "image_sc")

    for file in input_paths[1:]:
        new_data, new_scan, new_probe = load_file_all_data_types(file)

        for key in linear_keys:
            if data_dict[key] is not None:
                _assert_shapes_equal(data_dict[key], new_data[key], key)
                data_dict[key] += new_data[key]

        for key in log_keys:
            if data_dict[key] is not None:
                _assert_shapes_equal(data_dict[key], new_data[key], key)
                data_dict[key] = np.log(np.exp(new_data[key]) + np.exp(data_dict[key]))

        assert scan == new_scan, "Scan parameters do not match."
        assert probe == new_probe, "Probe parameters do not match."

    if overwrite:
        _delete_file_if_exists(output_path)

    save_file(
        path=output_path,
        scan=scan,
        probe=probe,
        additional_elements=additional_elements,
        description=description,
        **data_dict,
    )
160
+
161
+
162
+ def _assert_shapes_equal(array0, array1, name="array"):
163
+ shape0, shape1 = array0.shape, array1.shape
164
+ assert shape0 == shape1, f"{name} shapes do not match. Got {shape0} and {shape1}."
165
+
166
+
167
def compound_frames(input_path: Path, output_path: Path, overwrite=False):
    """
    Compounds frames in a raw data file by averaging them.

    Frames are averaged along the first axis, producing a single-frame file.
    Image data is assumed to be log-compressed and is averaged in the linear
    domain via exp/log.

    Args:
        input_path (Path): Path to the input raw data file.
        output_path (Path): Path to the output file where the compounded data will be saved.
        overwrite (bool, optional): Whether to overwrite the output file if it exists. Defaults to
            False.
    """

    data_dict, scan, probe = load_file_all_data_types(input_path)
    additional_elements = load_additional_elements(input_path)
    description = load_description(input_path)

    # Log-compressed data types must be averaged in the linear domain.
    log_domain_keys = ("image", "image_sc")

    # Frames are assumed to be stacked along the first axis.
    compounded = {}
    for data_type in DataTypes:
        key = data_type.value
        array = data_dict[key]
        if array is None:
            compounded[key] = None
        elif key in log_domain_keys:
            compounded[key] = np.log(np.mean(np.exp(array), axis=0, keepdims=True))
        else:
            compounded[key] = np.mean(array, axis=0, keepdims=True)

    # Reduce per-frame scan parameters to the single remaining frame.
    scan = _scan_reduce_frames(scan, [0])

    if overwrite:
        _delete_file_if_exists(output_path)

    save_file(
        path=output_path,
        scan=scan,
        probe=probe,
        additional_elements=additional_elements,
        description=description,
        **compounded,
    )
208
+
209
+
210
def compound_transmits(input_path: Path, output_path: Path, overwrite=False):
    """
    Compounds transmits in a raw data file by averaging them.

    Note
    ----
    This function assumes that all transmits are identical. If this is not the case the function
    will result in incorrect scan parameters.

    Args:
        input_path (Path): Path to the input raw data file.
        output_path (Path): Path to the output file where the compounded data will be saved.
        overwrite (bool, optional): Whether to overwrite the output file if it exists. Defaults to
            False.
    """

    data_dict, scan, probe = load_file_all_data_types(input_path)
    additional_elements = load_additional_elements(input_path)
    description = load_description(input_path)

    if not _all_tx_are_identical(scan):
        logger.warning(
            "Not all transmits are identical. Compounding transmits may lead to unexpected results."
        )

    # Only per-transmit data types are averaged; transmits are assumed to be
    # stacked along the second axis.
    for key in ("raw_data", "aligned_data"):
        array = data_dict[key]
        if array is not None:
            data_dict[key] = np.mean(array, axis=1, keepdims=True)

    # Keep only the first transmit's parameters in the scan.
    scan.set_transmits([0])

    if overwrite:
        _delete_file_if_exists(output_path)

    save_file(
        path=output_path,
        scan=scan,
        probe=probe,
        additional_elements=additional_elements,
        description=description,
        **data_dict,
    )
254
+
255
+
256
def _all_tx_are_identical(scan: Scan):
    """Return True if every per-transmit parameter of the scan is identical
    across transmits (axis 0 of each attribute)."""
    per_transmit_attrs = (
        scan.polar_angles,
        scan.azimuth_angles,
        scan.t0_delays,
        scan.tx_apodizations,
        scan.focus_distances,
        scan.initial_times,
    )

    # None attributes are simply not stored and cannot differ.
    return all(
        _check_all_identical(attr, axis=0)
        for attr in per_transmit_attrs
        if attr is not None
    )
271
+
272
+
273
+ def _check_all_identical(array, axis=0):
274
+ """Checks if all elements along a given axis are identical."""
275
+ first = array.take(0, axis=axis)
276
+ return np.all(np.equal(array, first), axis=axis).all()
277
+
278
+
279
def resave(input_path: Path, output_path: Path, overwrite=False):
    """
    Resaves a zea data file to a new location.

    Can be used to rewrite a file in the current file format version.

    Args:
        input_path (Path): Path to the input zea data file.
        output_path (Path): Path to the output file where the data will be saved.
        overwrite (bool, optional): Whether to overwrite the output file if it exists. Defaults to
            False.
    """

    data_dict, scan, probe = load_file_all_data_types(input_path)
    additional_elements = load_additional_elements(input_path)
    description = load_description(input_path)

    # Write out the full set of transmits, not a previous selection.
    scan.set_transmits("all")

    if overwrite:
        _delete_file_if_exists(output_path)

    save_file(
        path=output_path,
        scan=scan,
        probe=probe,
        description=description,
        additional_elements=additional_elements,
        **data_dict,
    )
305
+
306
+
307
def extract_frames_transmits(
    input_path: Path,
    output_path: Path,
    frame_indices=slice(None),
    transmit_indices=slice(None),
    overwrite=False,
):
    """
    Extracts a subset of frames and transmits from a raw data file.

    Note that ``frame_indices`` and ``transmit_indices`` cannot both be lists; at least one of
    them must be a slice. Please refer to the documentation of
    :func:`zea.data.file.load_file_all_data_types` for more information on the supported index
    types.

    Args:
        input_path (Path): Path to the input raw data file.
        output_path (Path): Path to the output file where the extracted data will be saved.
        frame_indices (list, array-like, or slice): Indices of the frames to keep.
        transmit_indices (list, array-like, or slice): Indices of the transmits to keep.
        overwrite (bool, optional): Whether to overwrite the output file if it exists. Defaults to
            False.
    """
    data_dict, scan, probe = load_file_all_data_types(
        input_path, indices=(frame_indices, transmit_indices)
    )
    additional_elements = load_additional_elements(input_path)
    description = load_description(input_path)

    # Reduce per-frame scan parameters to the kept frames.
    scan = _scan_reduce_frames(scan, frame_indices)

    if overwrite:
        _delete_file_if_exists(output_path)

    save_file(
        path=output_path,
        scan=scan,
        probe=probe,
        additional_elements=additional_elements,
        description=description,
        **data_dict,
    )
348
+
349
+
350
+ def _delete_file_if_exists(path: Path):
351
+ """Deletes a file if it exists."""
352
+ if path.exists():
353
+ path.unlink()
354
+
355
+
356
+ def _interpret_index(input_str):
357
+ if "-" in input_str:
358
+ start, end = map(int, input_str.split("-"))
359
+ return list(range(start, end + 1))
360
+ else:
361
+ return [int(x) for x in input_str.split(" ")]
362
+
363
+
364
def _interpret_indices(input_str_list):
    """Turn CLI index arguments into a slice or an explicit list of integers.

    "all" maps to ``slice(None)``; a single "start-end" token maps to an
    inclusive slice; anything else is flattened into a list of indices.
    """
    if isinstance(input_str_list, str) and input_str_list == "all":
        return slice(None)

    if len(input_str_list) == 1 and "-" in input_str_list[0]:
        start, end = (int(part) for part in input_str_list[0].split("-"))
        return slice(start, end + 1)

    return [index for token in input_str_list for index in _interpret_index(token)]
376
+
377
+
378
def _scan_reduce_frames(scan, frame_indices):
    """Reduce per-frame scan parameters to the given frame indices (in place).

    ``time_to_next_transmit`` is indexed per frame along its first axis, so it
    must be sliced when frames are dropped.

    Args:
        scan: Scan object to modify.
        frame_indices (list, array-like, or slice): Frames to keep.

    Returns:
        The same (mutated) scan object.
    """
    # NOTE(review): the transmit selection is reset to "all" before slicing
    # and restored afterwards — presumably set_transmits expands
    # time_to_next_transmit to the full transmit set, so slicing must happen
    # in the unselected state. Confirm against Scan.set_transmits.
    transmit_indices = scan.selected_transmits
    scan.set_transmits("all")
    if scan.time_to_next_transmit is not None:
        scan.time_to_next_transmit = scan.time_to_next_transmit[frame_indices]
    scan.set_transmits(transmit_indices)
    return scan
385
+
386
+
387
def get_parser():
    """Build the command line argument parser with one subcommand per operation."""

    parser = argparse.ArgumentParser(
        description="Manipulate zea data files.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="operation", required=True)

    # Register every subcommand on the shared subparsers object.
    for register in (
        _add_parser_sum,
        _add_parser_compound_frames,
        _add_parser_compound_transmits,
        _add_parser_resave,
        _add_parser_extract,
    ):
        register(subparsers)

    return parser
402
+
403
+
404
+ def _add_parser_sum(subparsers):
405
+ sum_parser = subparsers.add_parser("sum", help="Sum the raw data of multiple files.")
406
+ sum_parser.add_argument("input_paths", type=Path, nargs="+", help="Paths to the input files.")
407
+ sum_parser.add_argument("output_path", type=Path, help="Output HDF5 file.")
408
+ sum_parser.add_argument(
409
+ "--overwrite", action="store_true", default=False, help="Overwrite existing output file."
410
+ )
411
+
412
+
413
+ def _add_parser_compound_frames(subparsers):
414
+ cf_parser = subparsers.add_parser("compound_frames", help="Compound frames to increase SNR.")
415
+ cf_parser.add_argument("input_path", type=Path, help="Input HDF5 file.")
416
+ cf_parser.add_argument("output_path", type=Path, help="Output HDF5 file.")
417
+ cf_parser.add_argument(
418
+ "--overwrite", action="store_true", default=False, help="Overwrite existing output file."
419
+ )
420
+
421
+
422
+ def _add_parser_compound_transmits(subparsers):
423
+ ct_parser = subparsers.add_parser(
424
+ "compound_transmits", help="Compound transmits to increase SNR."
425
+ )
426
+ ct_parser.add_argument("input_path", type=Path, help="Input HDF5 file.")
427
+ ct_parser.add_argument("output_path", type=Path, help="Output HDF5 file.")
428
+ ct_parser.add_argument(
429
+ "--overwrite", action="store_true", default=False, help="Overwrite existing output file."
430
+ )
431
+
432
+
433
+ def _add_parser_resave(subparsers):
434
+ resave_parser = subparsers.add_parser("resave", help="Resave a file to change format version.")
435
+ resave_parser.add_argument("input_path", type=Path, help="Input HDF5 file.")
436
+ resave_parser.add_argument("output_path", type=Path, help="Output HDF5 file.")
437
+ resave_parser.add_argument(
438
+ "--overwrite", action="store_true", default=False, help="Overwrite existing output file."
439
+ )
440
+
441
+
442
+ def _add_parser_extract(subparsers):
443
+ extract_parser = subparsers.add_parser("extract", help="Extract subset of frames or transmits.")
444
+ extract_parser.add_argument("input_path", type=Path, help="Input HDF5 file.")
445
+ extract_parser.add_argument("output_path", type=Path, help="Output HDF5 file.")
446
+ extract_parser.add_argument(
447
+ "--transmits",
448
+ type=str,
449
+ nargs="+",
450
+ default="all",
451
+ help="Target transmits. Can be a list of integers or ranges (e.g. 0-3 7).",
452
+ )
453
+ extract_parser.add_argument(
454
+ "--frames",
455
+ type=str,
456
+ nargs="+",
457
+ default="all",
458
+ help="Target frames. Can be a list of integers or ranges (e.g. 0-3 7).",
459
+ )
460
+ extract_parser.add_argument(
461
+ "--overwrite", action="store_true", default=False, help="Overwrite existing output file."
462
+ )
463
+
464
+
465
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    # Refuse to clobber an existing output file unless explicitly requested.
    if args.output_path.exists() and not args.overwrite:
        logger.error(
            f"Output file {args.output_path} already exists. Use --overwrite to overwrite it."
        )
        # sys.exit instead of the builtin exit(): exit() is installed by the
        # site module and is not guaranteed to be available.
        sys.exit(1)

    # Dispatch explicitly on the operation name; the subparsers are declared
    # with required=True, so args.operation is always one of the known names.
    if args.operation == "sum":
        sum_data(
            input_paths=args.input_paths, output_path=args.output_path, overwrite=args.overwrite
        )
    elif args.operation == "compound_frames":
        compound_frames(
            input_path=args.input_path, output_path=args.output_path, overwrite=args.overwrite
        )
    elif args.operation == "compound_transmits":
        compound_transmits(
            input_path=args.input_path, output_path=args.output_path, overwrite=args.overwrite
        )
    elif args.operation == "resave":
        resave(input_path=args.input_path, output_path=args.output_path, overwrite=args.overwrite)
    elif args.operation == "extract":
        extract_frames_transmits(
            input_path=args.input_path,
            output_path=args.output_path,
            frame_indices=_interpret_indices(args.frames),
            transmit_indices=_interpret_indices(args.transmits),
            overwrite=args.overwrite,
        )
    else:
        # Previously an unknown operation silently fell through to sum_data;
        # fail loudly instead.
        parser.error(f"Unknown operation: {args.operation}")
zea/data/layers.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import keras
4
4
  import numpy as np
5
- from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
5
+ from keras.src.layers.preprocessing.data_layer import DataLayer
6
6
 
7
7
  from zea.ops import Pad as PadOp
8
8
  from zea.utils import map_negative_indices
@@ -11,7 +11,7 @@ from zea.utils import map_negative_indices
11
11
  class Pad(PadOp):
12
12
  """Pad layer for padding tensors to a specified shape which can be used in tf.data pipelines."""
13
13
 
14
- __call__ = TFDataLayer.__call__
14
+ __call__ = DataLayer.__call__
15
15
 
16
16
  def call(self, inputs):
17
17
  """
@@ -20,7 +20,7 @@ class Pad(PadOp):
20
20
  return super().call(data=inputs)["data"]
21
21
 
22
22
 
23
- class Resizer(TFDataLayer):
23
+ class Resizer(DataLayer):
24
24
  """
25
25
  Resize layer for resizing images. Can deal with N-dimensional images.
26
26
  Can do resize, center_crop, random_crop and crop_or_pad.
zea/data/preset_utils.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Preset utils for zea datasets hosted on Hugging Face.
2
2
 
3
- See https://huggingface.co/zea/
3
+ See https://huggingface.co/zeahub/
4
4
  """
5
5
 
6
6
  from pathlib import Path
zea/datapaths.py CHANGED
@@ -11,12 +11,24 @@ to set up your local data paths.
11
11
  Example usage
12
12
  ^^^^^^^^^^^^^
13
13
 
14
- .. code-block:: python
14
+ .. doctest::
15
15
 
16
- from zea.datapaths import set_data_paths
16
+ >>> import yaml
17
+ >>> from zea.datapaths import set_data_paths
17
18
 
18
- user = set_data_paths("users.yaml")
19
- print(user.data_root)
19
+ >>> user_config = {"data_root": "/path/to/data", "output": "/path/to/output"}
20
+ >>> with open("users.yaml", "w", encoding="utf-8") as file:
21
+ ... yaml.dump(user_config, file)
22
+
23
+ >>> user = set_data_paths("users.yaml")
24
+ >>> print(user.data_root)
25
+ /path/to/data
26
+
27
+ .. testcleanup::
28
+
29
+ import os
30
+
31
+ os.remove("users.yaml")
20
32
 
21
33
  """
22
34
 
zea/display.py CHANGED
@@ -8,9 +8,8 @@ import scipy
8
8
  from keras import ops
9
9
  from PIL import Image
10
10
 
11
- from zea import log
11
+ from zea.tensor_ops import translate
12
12
  from zea.tools.fit_scan_cone import fit_and_crop_around_scan_cone
13
- from zea.utils import translate
14
13
 
15
14
 
16
15
  def to_8bit(image, dynamic_range: Union[None, tuple] = None, pillow: bool = True):
@@ -340,12 +339,14 @@ def scan_convert(
340
339
  def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value=0):
341
340
  """map_coordinates using keras.ops or scipy.ndimage when order > 1."""
342
341
  if order > 1:
343
- inputs = ops.convert_to_numpy(inputs)
344
- coordinates = ops.convert_to_numpy(coordinates)
342
+ # Preserve original dtype before conversion
343
+ original_dtype = ops.dtype(inputs)
344
+ inputs_np = ops.convert_to_numpy(inputs).astype(np.float32)
345
+ coordinates_np = ops.convert_to_numpy(coordinates).astype(np.float32)
345
346
  out = scipy.ndimage.map_coordinates(
346
- inputs, coordinates, order=order, mode=fill_mode, cval=fill_value
347
+ inputs_np, coordinates_np, order=order, mode=fill_mode, cval=fill_value
347
348
  )
348
- return ops.convert_to_tensor(out)
349
+ return ops.convert_to_tensor(out.astype(original_dtype))
349
350
  else:
350
351
  return ops.image.map_coordinates(
351
352
  inputs,
@@ -439,11 +440,7 @@ def cartesian_to_polar_matrix(
439
440
  Returns:
440
441
  polar_matrix (Array): The image re-sampled in polar coordinates with shape `polar_shape`.
441
442
  """
442
- if ops.dtype(cartesian_matrix) != "float32":
443
- log.info(
444
- f"Cartesian matrix with dtype {ops.dtype(cartesian_matrix)} has been cast to float32."
445
- )
446
- cartesian_matrix = ops.cast(cartesian_matrix, "float32")
443
+ assert "float" in ops.dtype(cartesian_matrix), "Input image must be float type"
447
444
 
448
445
  # Assume that polar grid is same shape as cartesian grid unless specified
449
446
  cartesian_rows, cartesian_cols = ops.shape(cartesian_matrix)
@@ -501,6 +498,7 @@ def inverse_scan_convert_2d(
501
498
  output_size=None,
502
499
  interpolation_order=1,
503
500
  find_scan_cone=True,
501
+ image_range: tuple | None = None,
504
502
  ):
505
503
  """
506
504
  Convert a Cartesian-format ultrasound image to a polar representation.
@@ -510,7 +508,7 @@ def inverse_scan_convert_2d(
510
508
  Optionally, it can detect and crop around the scan cone before conversion.
511
509
 
512
510
  Args:
513
- cartesian_image (tensor): 2D image array in Cartesian coordinates.
511
+ cartesian_image (tensor): 2D image array in Cartesian coordinates of type float.
514
512
  fill_value (float): Value used to fill regions outside the original image
515
513
  during interpolation.
516
514
  angle (float): Angular field of view (in radians) used for the polar transformation.
@@ -523,12 +521,15 @@ def inverse_scan_convert_2d(
523
521
  in the Cartesian image before polar conversion, ensuring that the scan cone is
524
522
  centered without padding. Can be set to False if the image is already cropped
525
523
  and centered.
524
+ image_range (tuple, optional): Tuple (vmin, vmax) for display scaling
525
+ when detecting the scan cone.
526
526
 
527
527
  Returns:
528
528
  polar_image (Array): 2D image in polar coordinates (sector-shaped scan).
529
529
  """
530
530
  if find_scan_cone:
531
- cartesian_image = fit_and_crop_around_scan_cone(cartesian_image)
531
+ assert image_range is not None, "image_range must be provided when find_scan_cone is True"
532
+ cartesian_image = fit_and_crop_around_scan_cone(cartesian_image, image_range)
532
533
  polar_image = cartesian_to_polar_matrix(
533
534
  cartesian_image,
534
535
  fill_value=fill_value,