zea 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/agent/selection.py +24 -18
- zea/data/data_format.py +28 -26
- zea/internal/core.py +1 -1
- zea/internal/viewer.py +1 -1
- zea/log.py +8 -0
- zea/models/diffusion.py +43 -17
- zea/ops.py +36 -4
- zea/utils.py +31 -0
- zea/visualize.py +10 -4
- {zea-0.0.3.dist-info → zea-0.0.4.dist-info}/METADATA +1 -1
- {zea-0.0.3.dist-info → zea-0.0.4.dist-info}/RECORD +15 -15
- {zea-0.0.3.dist-info → zea-0.0.4.dist-info}/LICENSE +0 -0
- {zea-0.0.3.dist-info → zea-0.0.4.dist-info}/WHEEL +0 -0
- {zea-0.0.3.dist-info → zea-0.0.4.dist-info}/entry_points.txt +0 -0
zea/__init__.py
CHANGED
zea/agent/selection.py
CHANGED
|
@@ -155,7 +155,9 @@ class GreedyEntropy(LinesActionModel):
|
|
|
155
155
|
# TODO: I think we only need to compute the lower triangular
|
|
156
156
|
# of this matrix, since it's symmetric
|
|
157
157
|
squared_l2_error_matrices = (particles[:, :, None, ...] - particles[:, None, :, ...]) ** 2
|
|
158
|
-
gaussian_error_per_pixel_i_j = ops.exp(
|
|
158
|
+
gaussian_error_per_pixel_i_j = ops.exp(
|
|
159
|
+
-(squared_l2_error_matrices) / (2 * entropy_sigma**2)
|
|
160
|
+
)
|
|
159
161
|
# Vertically stack all columns corresponding with the same line
|
|
160
162
|
# This way we can just sum across the height axis and get the entropy
|
|
161
163
|
# for each pixel in a given line
|
|
@@ -176,33 +178,35 @@ class GreedyEntropy(LinesActionModel):
|
|
|
176
178
|
# [n_particles, n_particles, batch, height, width]
|
|
177
179
|
return gaussian_error_per_pixel_stacked
|
|
178
180
|
|
|
179
|
-
def
|
|
180
|
-
"""
|
|
181
|
-
|
|
181
|
+
def compute_pixelwise_entropy(self, particles):
|
|
182
|
+
"""
|
|
182
183
|
This function computes the entropy for each line using a Gaussian Mixture Model
|
|
183
184
|
approximation of the posterior distribution.
|
|
184
|
-
For more details see Section
|
|
185
|
+
For more details see Section VI. B here: https://arxiv.org/pdf/2410.13310
|
|
185
186
|
|
|
186
187
|
Args:
|
|
187
188
|
particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
|
|
188
189
|
|
|
189
190
|
Returns:
|
|
190
|
-
Tensor: batch of entropies per
|
|
191
|
+
Tensor: batch of entropies per pixel, of shape (batch, height, width)
|
|
191
192
|
"""
|
|
192
|
-
|
|
193
|
+
n_particles = ops.shape(particles)[1]
|
|
194
|
+
gaussian_error_per_pixel_stacked = self.compute_pairwise_pixel_gaussian_error(
|
|
193
195
|
particles,
|
|
194
196
|
self.stack_n_cols,
|
|
195
197
|
self.n_possible_actions,
|
|
196
198
|
self.entropy_sigma,
|
|
197
199
|
)
|
|
198
|
-
gaussian_error_per_line = ops.sum(gaussian_error_per_pixel_stacked, axis=3)
|
|
199
200
|
# sum out first dimension of (n_particles x n_particles) error matrix
|
|
200
|
-
# [n_particles, batch,
|
|
201
|
-
|
|
201
|
+
# [n_particles, batch, height, width]
|
|
202
|
+
pixelwise_entropy_sum_j = ops.sum(
|
|
203
|
+
(1 / n_particles) * gaussian_error_per_pixel_stacked, axis=1
|
|
204
|
+
)
|
|
205
|
+
log_pixelwise_entropy_sum_j = ops.log(pixelwise_entropy_sum_j)
|
|
202
206
|
# sum out second dimension of (n_particles x n_particles) error matrix
|
|
203
|
-
# [batch,
|
|
204
|
-
|
|
205
|
-
return
|
|
207
|
+
# [batch, height, width]
|
|
208
|
+
pixelwise_entropy = -ops.sum((1 / n_particles) * log_pixelwise_entropy_sum_j, axis=1)
|
|
209
|
+
return pixelwise_entropy
|
|
206
210
|
|
|
207
211
|
def select_line_and_reweight_entropy(self, entropy_per_line):
|
|
208
212
|
"""Select the line with maximum entropy and reweight the entropies.
|
|
@@ -260,17 +264,19 @@ class GreedyEntropy(LinesActionModel):
|
|
|
260
264
|
particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
|
|
261
265
|
|
|
262
266
|
Returns:
|
|
263
|
-
|
|
267
|
+
Tuple[Tensor, Tensor]:
|
|
264
268
|
- Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
|
|
265
|
-
|
|
269
|
+
- Masks of shape (batch_size, img_height, img_width)
|
|
266
270
|
"""
|
|
267
|
-
|
|
271
|
+
|
|
272
|
+
pixelwise_entropy = self.compute_pixelwise_entropy(particles)
|
|
273
|
+
linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
|
|
268
274
|
|
|
269
275
|
# Greedily select best line, reweight entropies, and repeat
|
|
270
276
|
all_selected_lines = []
|
|
271
277
|
for _ in range(self.n_actions):
|
|
272
|
-
max_entropy_line,
|
|
273
|
-
self.select_line_and_reweight_entropy,
|
|
278
|
+
max_entropy_line, linewise_entropy = ops.vectorized_map(
|
|
279
|
+
self.select_line_and_reweight_entropy, linewise_entropy
|
|
274
280
|
)
|
|
275
281
|
all_selected_lines.append(max_entropy_line)
|
|
276
282
|
|
zea/data/data_format.py
CHANGED
|
@@ -468,32 +468,34 @@ def _write_datasets(
|
|
|
468
468
|
),
|
|
469
469
|
unit="-",
|
|
470
470
|
)
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
range(
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
471
|
+
|
|
472
|
+
if waveforms_one_way is not None:
|
|
473
|
+
for n in range(len(waveforms_one_way)):
|
|
474
|
+
_add_dataset(
|
|
475
|
+
group_name=scan_group_name + "/waveforms_one_way",
|
|
476
|
+
name=f"waveform_{str(n).zfill(3)}",
|
|
477
|
+
data=waveforms_one_way[n],
|
|
478
|
+
description=(
|
|
479
|
+
"One-way waveform as simulated by the Verasonics system, "
|
|
480
|
+
"sampled at 250MHz. This is the waveform after being filtered "
|
|
481
|
+
"by the tranducer bandwidth once."
|
|
482
|
+
),
|
|
483
|
+
unit="V",
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
if waveforms_two_way is not None:
|
|
487
|
+
for n in range(len(waveforms_two_way)):
|
|
488
|
+
_add_dataset(
|
|
489
|
+
group_name=scan_group_name + "/waveforms_two_way",
|
|
490
|
+
name=f"waveform_{str(n).zfill(3)}",
|
|
491
|
+
data=waveforms_two_way[n],
|
|
492
|
+
description=(
|
|
493
|
+
"Two-way waveform as simulated by the Verasonics system, "
|
|
494
|
+
"sampled at 250MHz. This is the waveform after being filtered "
|
|
495
|
+
"by the tranducer bandwidth twice."
|
|
496
|
+
),
|
|
497
|
+
unit="V",
|
|
498
|
+
)
|
|
497
499
|
|
|
498
500
|
# Add additional elements
|
|
499
501
|
if additional_elements is not None:
|
zea/internal/core.py
CHANGED
zea/internal/viewer.py
CHANGED
|
@@ -28,7 +28,7 @@ def plt_window_has_been_closed(fig):
|
|
|
28
28
|
return not plt.fignum_exists(fig.number)
|
|
29
29
|
|
|
30
30
|
|
|
31
|
-
def filename_from_window_dialog(window_name=None, filetypes=None, initialdir=None):
|
|
31
|
+
def filename_from_window_dialog(window_name=None, filetypes=None, initialdir=None) -> Path:
|
|
32
32
|
"""Get filename through dialog window
|
|
33
33
|
Args:
|
|
34
34
|
window_name: string with name of window
|
zea/log.py
CHANGED
|
@@ -289,6 +289,14 @@ def critical(message, *args, **kwargs):
|
|
|
289
289
|
return message
|
|
290
290
|
|
|
291
291
|
|
|
292
|
+
def number_to_str(number, decimals=2):
|
|
293
|
+
"""Formats a number to a string with the given number of decimals."""
|
|
294
|
+
if isinstance(number, (int, float)):
|
|
295
|
+
return f"{number:.{decimals}f}"
|
|
296
|
+
else:
|
|
297
|
+
raise ValueError(f"Expected a number, got {type(number)}: {number}")
|
|
298
|
+
|
|
299
|
+
|
|
292
300
|
def set_file_logger_directory(directory):
|
|
293
301
|
"""Sets the log level of the logger."""
|
|
294
302
|
global LOG_DIR, file_logger
|
zea/models/diffusion.py
CHANGED
|
@@ -161,16 +161,15 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
161
161
|
f"DiffusionGuidance object, got {guidance}"
|
|
162
162
|
)
|
|
163
163
|
|
|
164
|
-
def call(self, inputs, training=False, **kwargs):
|
|
164
|
+
def call(self, inputs, training=False, network=None, **kwargs):
|
|
165
|
+
"""Calls the score network.
|
|
166
|
+
|
|
167
|
+
If network is not provided, will use the exponential moving
|
|
168
|
+
average network if training is False, otherwise the regular network.
|
|
165
169
|
"""
|
|
166
|
-
|
|
170
|
+
if network is None:
|
|
171
|
+
network = self.network if training else self.ema_network
|
|
167
172
|
|
|
168
|
-
Will use the exponential moving average network if training is False,
|
|
169
|
-
otherwise the regular network."""
|
|
170
|
-
if training:
|
|
171
|
-
network = self.network
|
|
172
|
-
else:
|
|
173
|
-
network = self.ema_network
|
|
174
173
|
return network(inputs, training=training, **kwargs)
|
|
175
174
|
|
|
176
175
|
def sample(self, n_samples=1, n_steps=20, seed=None, **kwargs):
|
|
@@ -367,18 +366,27 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
367
366
|
|
|
368
367
|
def linear_diffusion_schedule(self, diffusion_times):
|
|
369
368
|
"""Create a linear diffusion schedule"""
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
369
|
+
|
|
370
|
+
def _compute_alpha_t(t):
|
|
371
|
+
"""Compute alpha_t for linear diffusion schedule"""
|
|
372
|
+
return ops.prod(1 - diffusion_times[:t], axis=diffusion_times.shape[1:])
|
|
373
|
+
|
|
374
|
+
alphas = ops.vectorized_map(_compute_alpha_t, ops.arange(len(diffusion_times)))
|
|
374
375
|
signal_rates = ops.sqrt(alphas)
|
|
375
376
|
noise_rates = ops.sqrt(1 - alphas)
|
|
376
377
|
return signal_rates, noise_rates
|
|
377
378
|
|
|
378
|
-
def denoise(
|
|
379
|
-
|
|
379
|
+
def denoise(
|
|
380
|
+
self,
|
|
381
|
+
noisy_images,
|
|
382
|
+
noise_rates,
|
|
383
|
+
signal_rates,
|
|
384
|
+
training,
|
|
385
|
+
network=None,
|
|
386
|
+
):
|
|
387
|
+
"""Predict noise component and calculate the image component using it."""
|
|
380
388
|
|
|
381
|
-
pred_noises = self([noisy_images, noise_rates**2], training=training)
|
|
389
|
+
pred_noises = self([noisy_images, noise_rates**2], training=training, network=network)
|
|
382
390
|
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
|
|
383
391
|
|
|
384
392
|
return pred_noises, pred_images
|
|
@@ -435,6 +443,9 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
435
443
|
seed: keras.random.SeedGenerator | None = None,
|
|
436
444
|
verbose: bool = True,
|
|
437
445
|
track_progress_type: Literal[None, "x_0", "x_t"] = "x_0",
|
|
446
|
+
disable_jit: bool = False,
|
|
447
|
+
training: bool = False,
|
|
448
|
+
network_type: Literal[None, "main", "ema"] = None,
|
|
438
449
|
):
|
|
439
450
|
"""Reverse diffusion process to generate images from noise.
|
|
440
451
|
|
|
@@ -447,6 +458,10 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
447
458
|
seed: Random seed generator.
|
|
448
459
|
verbose: Whether to show a progress bar.
|
|
449
460
|
track_progress_type: Type of progress tracking ("x_0" or "x_t").
|
|
461
|
+
disable_jit: Whether to disable JIT compilation.
|
|
462
|
+
training: Whether to use the training mode of the network.
|
|
463
|
+
network_type: Which network to use ("main" or "ema"). If None, uses the
|
|
464
|
+
network based on the `training` argument.
|
|
450
465
|
|
|
451
466
|
Returns:
|
|
452
467
|
Generated images.
|
|
@@ -478,8 +493,19 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
478
493
|
next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times)
|
|
479
494
|
|
|
480
495
|
# denoise
|
|
496
|
+
if network_type == "ema":
|
|
497
|
+
network = self.ema_network
|
|
498
|
+
elif network_type == "main":
|
|
499
|
+
network = self.network
|
|
500
|
+
else:
|
|
501
|
+
network = None
|
|
502
|
+
|
|
481
503
|
pred_noises, pred_images = self.denoise(
|
|
482
|
-
noisy_images,
|
|
504
|
+
noisy_images,
|
|
505
|
+
noise_rates,
|
|
506
|
+
signal_rates,
|
|
507
|
+
training=training,
|
|
508
|
+
network=network,
|
|
483
509
|
)
|
|
484
510
|
|
|
485
511
|
seed, seed1 = split_seed(seed, 2)
|
|
@@ -515,7 +541,7 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
515
541
|
seed,
|
|
516
542
|
),
|
|
517
543
|
# can't jit this with progbar or tracking intermediate values
|
|
518
|
-
disable_jit=verbose or track_progress_type,
|
|
544
|
+
disable_jit=verbose or track_progress_type or disable_jit,
|
|
519
545
|
)
|
|
520
546
|
|
|
521
547
|
return pred_images
|
zea/ops.py
CHANGED
|
@@ -100,7 +100,7 @@ from zea.probes import Probe
|
|
|
100
100
|
from zea.scan import Scan
|
|
101
101
|
from zea.simulator import simulate_rf
|
|
102
102
|
from zea.tensor_ops import batched_map, patched_map, resample, reshape_axis
|
|
103
|
-
from zea.utils import deep_compare, map_negative_indices, translate
|
|
103
|
+
from zea.utils import FunctionTimer, deep_compare, map_negative_indices, translate
|
|
104
104
|
|
|
105
105
|
|
|
106
106
|
def get_ops(ops_name):
|
|
@@ -378,6 +378,7 @@ class Pipeline:
|
|
|
378
378
|
jit_kwargs: dict | None = None,
|
|
379
379
|
name="pipeline",
|
|
380
380
|
validate=True,
|
|
381
|
+
timed: bool = False,
|
|
381
382
|
):
|
|
382
383
|
"""Initialize a pipeline
|
|
383
384
|
|
|
@@ -399,6 +400,8 @@ class Pipeline:
|
|
|
399
400
|
"""
|
|
400
401
|
self._call_pipeline = self.call
|
|
401
402
|
self.name = name
|
|
403
|
+
self.timer = FunctionTimer()
|
|
404
|
+
self.timed = timed
|
|
402
405
|
|
|
403
406
|
self._pipeline_layers = operations
|
|
404
407
|
|
|
@@ -519,6 +522,30 @@ class Pipeline:
|
|
|
519
522
|
"""Alias for self.layers to match the zea naming convention"""
|
|
520
523
|
return self._pipeline_layers
|
|
521
524
|
|
|
525
|
+
def timed_call(self, **inputs):
|
|
526
|
+
"""Process input data through the pipeline."""
|
|
527
|
+
|
|
528
|
+
for op in self._pipeline_layers:
|
|
529
|
+
timed_op = self.timer(op, name=op.__class__.__name__)
|
|
530
|
+
try:
|
|
531
|
+
outputs = timed_op(**inputs)
|
|
532
|
+
except KeyError as exc:
|
|
533
|
+
raise KeyError(
|
|
534
|
+
f"[zea.Pipeline] Operation '{op.__class__.__name__}' "
|
|
535
|
+
f"requires input key '{exc.args[0]}', "
|
|
536
|
+
"but it was not provided in the inputs.\n"
|
|
537
|
+
"Check whether the objects (such as `zea.Scan`) passed to "
|
|
538
|
+
"`pipeline.prepare_parameters()` contain all required keys.\n"
|
|
539
|
+
f"Current list of all passed keys: {list(inputs.keys())}\n"
|
|
540
|
+
f"Valid keys for this pipeline: {self.valid_keys}"
|
|
541
|
+
) from exc
|
|
542
|
+
except Exception as exc:
|
|
543
|
+
raise RuntimeError(
|
|
544
|
+
f"[zea.Pipeline] Error in operation '{op.__class__.__name__}': {exc}"
|
|
545
|
+
) from exc
|
|
546
|
+
inputs = outputs
|
|
547
|
+
return outputs
|
|
548
|
+
|
|
522
549
|
def call(self, **inputs):
|
|
523
550
|
"""Process input data through the pipeline."""
|
|
524
551
|
for operation in self._pipeline_layers:
|
|
@@ -605,13 +632,18 @@ class Pipeline:
|
|
|
605
632
|
if operation.jittable and operation._jit_compile:
|
|
606
633
|
operation.set_jit(value == "ops")
|
|
607
634
|
|
|
635
|
+
@property
|
|
636
|
+
def _call_fn(self):
|
|
637
|
+
"""Get the call function of the pipeline."""
|
|
638
|
+
return self.call if not self.timed else self.timed_call
|
|
639
|
+
|
|
608
640
|
def jit(self):
|
|
609
641
|
"""JIT compile the pipeline."""
|
|
610
|
-
self._call_pipeline = jit(self.
|
|
642
|
+
self._call_pipeline = jit(self._call_fn, **self.jit_kwargs)
|
|
611
643
|
|
|
612
644
|
def unjit(self):
|
|
613
645
|
"""Un-JIT compile the pipeline."""
|
|
614
|
-
self._call_pipeline = self.
|
|
646
|
+
self._call_pipeline = self._call_fn
|
|
615
647
|
|
|
616
648
|
@property
|
|
617
649
|
def jittable(self):
|
|
@@ -2090,7 +2122,7 @@ class Demodulate(Operation):
|
|
|
2090
2122
|
|
|
2091
2123
|
@ops_registry("lambda")
|
|
2092
2124
|
class Lambda(Operation):
|
|
2093
|
-
"""Use any
|
|
2125
|
+
"""Use any function as an operation."""
|
|
2094
2126
|
|
|
2095
2127
|
def __init__(self, func, func_kwargs=None, **kwargs):
|
|
2096
2128
|
super().__init__(**kwargs)
|
zea/utils.py
CHANGED
|
@@ -662,3 +662,34 @@ class FunctionTimer:
|
|
|
662
662
|
yaml.dump(cropped_timings, f, default_flow_style=False)
|
|
663
663
|
|
|
664
664
|
self.last_append = len(self.timings[func_name])
|
|
665
|
+
|
|
666
|
+
def print(self, drop_first: bool | int = False):
|
|
667
|
+
"""Print timing statistics for all recorded functions using formatted output."""
|
|
668
|
+
|
|
669
|
+
# Print title
|
|
670
|
+
print(log.bold("Function Timing Statistics"))
|
|
671
|
+
header = (
|
|
672
|
+
f"{log.cyan('Function'):<30} {log.green('Mean'):<22} "
|
|
673
|
+
f"{log.green('Median'):<22} {log.green('Std Dev'):<22} "
|
|
674
|
+
f"{log.yellow('Min'):<22} {log.yellow('Max'):<22} {log.magenta('Count'):<18}"
|
|
675
|
+
)
|
|
676
|
+
length = len(log.remove_color_escape_codes(header))
|
|
677
|
+
print("=" * length)
|
|
678
|
+
|
|
679
|
+
# Print header
|
|
680
|
+
print(header)
|
|
681
|
+
print("-" * length)
|
|
682
|
+
|
|
683
|
+
# Print data rows
|
|
684
|
+
for func_name in self.timings.keys():
|
|
685
|
+
stats = self.get_stats(func_name, drop_first=drop_first)
|
|
686
|
+
row = (
|
|
687
|
+
f"{log.cyan(func_name):<30} "
|
|
688
|
+
f"{log.green(log.number_to_str(stats['mean'], 6)):<22} "
|
|
689
|
+
f"{log.green(log.number_to_str(stats['median'], 6)):<22} "
|
|
690
|
+
f"{log.green(log.number_to_str(stats['std_dev'], 6)):<22} "
|
|
691
|
+
f"{log.yellow(log.number_to_str(stats['min'], 6)):<22} "
|
|
692
|
+
f"{log.yellow(log.number_to_str(stats['max'], 6)):<22} "
|
|
693
|
+
f"{log.magenta(str(stats['count'])):<18}"
|
|
694
|
+
)
|
|
695
|
+
print(row)
|
zea/visualize.py
CHANGED
|
@@ -169,7 +169,7 @@ def plot_image_grid(
|
|
|
169
169
|
return fig, fig_contents
|
|
170
170
|
|
|
171
171
|
|
|
172
|
-
def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid=None):
|
|
172
|
+
def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid=None, **kwargs):
|
|
173
173
|
"""
|
|
174
174
|
For a given 3D array, plot a plane with fixed_coord using four individual quadrants.
|
|
175
175
|
|
|
@@ -183,6 +183,7 @@ def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid
|
|
|
183
183
|
stride (int, optional): The stride step for plotting. Defaults to 1.
|
|
184
184
|
centroid (tuple, optional): centroid around which to break the quadrants.
|
|
185
185
|
If None, the middle of the image is used.
|
|
186
|
+
**kwargs: Additional keyword arguments for the plot_surface method.
|
|
186
187
|
|
|
187
188
|
Returns:
|
|
188
189
|
matplotlib.axes.Axes3DSubplot: The axis with the plotted quadrants.
|
|
@@ -238,6 +239,7 @@ def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid
|
|
|
238
239
|
cstride=stride,
|
|
239
240
|
facecolors=facecolors,
|
|
240
241
|
shade=False,
|
|
242
|
+
**kwargs,
|
|
241
243
|
)
|
|
242
244
|
elif fixed_coord == "y":
|
|
243
245
|
X, Z = np.mgrid[: quadrant.shape[0] + 1, : quadrant.shape[1] + 1]
|
|
@@ -252,6 +254,7 @@ def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid
|
|
|
252
254
|
cstride=stride,
|
|
253
255
|
facecolors=facecolors,
|
|
254
256
|
shade=False,
|
|
257
|
+
**kwargs,
|
|
255
258
|
)
|
|
256
259
|
elif fixed_coord == "z":
|
|
257
260
|
X, Y = np.mgrid[: quadrant.shape[0] + 1, : quadrant.shape[1] + 1]
|
|
@@ -266,6 +269,7 @@ def plot_quadrants(ax, array, fixed_coord, cmap, slice_index, stride=1, centroid
|
|
|
266
269
|
cstride=stride,
|
|
267
270
|
facecolors=facecolors,
|
|
268
271
|
shade=False,
|
|
272
|
+
**kwargs,
|
|
269
273
|
)
|
|
270
274
|
return ax
|
|
271
275
|
|
|
@@ -281,6 +285,7 @@ def plot_biplanes(
|
|
|
281
285
|
show_axes=None,
|
|
282
286
|
fig=None,
|
|
283
287
|
ax=None,
|
|
288
|
+
**kwargs,
|
|
284
289
|
):
|
|
285
290
|
"""
|
|
286
291
|
Plot three intersecting planes from a 3D volume in 3D space.
|
|
@@ -301,6 +306,7 @@ def plot_biplanes(
|
|
|
301
306
|
Defaults to None. Can be used to reuse the figure in a loop.
|
|
302
307
|
ax (matplotlib.axes.Axes3DSubplot, optional): Matplotlib 3D axes object.
|
|
303
308
|
Defaults to None. Can be used to reuse the axes in a loop.
|
|
309
|
+
**kwargs: Additional keyword arguments for the plot_surface method.
|
|
304
310
|
|
|
305
311
|
Returns:
|
|
306
312
|
tuple: A tuple containing the figure and axes objects (fig, ax).
|
|
@@ -340,11 +346,11 @@ def plot_biplanes(
|
|
|
340
346
|
ax.zaxis.pane.fill = False
|
|
341
347
|
|
|
342
348
|
if slice_x is not None:
|
|
343
|
-
plot_quadrants(ax, volume, "x", cmap=cmap, slice_index=slice_x, stride=stride)
|
|
349
|
+
plot_quadrants(ax, volume, "x", cmap=cmap, slice_index=slice_x, stride=stride, **kwargs)
|
|
344
350
|
if slice_y is not None:
|
|
345
|
-
plot_quadrants(ax, volume, "y", cmap=cmap, slice_index=slice_y, stride=stride)
|
|
351
|
+
plot_quadrants(ax, volume, "y", cmap=cmap, slice_index=slice_y, stride=stride, **kwargs)
|
|
346
352
|
if slice_z is not None:
|
|
347
|
-
plot_quadrants(ax, volume, "z", cmap=cmap, slice_index=slice_z, stride=stride)
|
|
353
|
+
plot_quadrants(ax, volume, "z", cmap=cmap, slice_index=slice_z, stride=stride, **kwargs)
|
|
348
354
|
|
|
349
355
|
# Optionally show axes
|
|
350
356
|
if show_axes:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: zea
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.4
|
|
4
4
|
Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
|
|
5
5
|
Keywords: ultrasound,machine learning,beamforming
|
|
6
6
|
Author: Tristan Stevens
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
zea/__init__.py,sha256
|
|
1
|
+
zea/__init__.py,sha256=-et4Iz8oJRRaQahurztVr7Rvz9Do7WkXfq5D5CuhX7w,2274
|
|
2
2
|
zea/__main__.py,sha256=QQaxECJxTUo99q43OobV1Tc9m4POtmjRbeKqpkCZ4Lo,2020
|
|
3
3
|
zea/agent/__init__.py,sha256=uJjMiPvvCXmUxC2mkSkh1Q9Ege0Vf7WazKX1_Ul80GY,924
|
|
4
4
|
zea/agent/gumbel.py,sha256=WbvSrM8meXtZmDLkyXDGUyRzjB4dWZmNO_qzloRxB_s,3770
|
|
5
5
|
zea/agent/masks.py,sha256=qdSGbTs9449hUxcX6gAl_s47mrs1FKI6U_T2KjS-iz8,6581
|
|
6
|
-
zea/agent/selection.py,sha256=
|
|
6
|
+
zea/agent/selection.py,sha256=j9FejT28jRGnRIEL6F1YZYdl22G0XxKOkua9qhKavIg,19764
|
|
7
7
|
zea/backend/__init__.py,sha256=9evPyuxJtIeleZlX_eAggJtxxKRofLQiD_ZV3yARg2Y,4227
|
|
8
8
|
zea/backend/autograd.py,sha256=buWy19ctDFsAoktZaLm5qyLN9nQRicI91-V3ND-t3f4,6868
|
|
9
9
|
zea/backend/jax/__init__.py,sha256=kU89ZZeVQOm9tdrlGSLLlkrJzhUNW0wOet5lnU5BMPU,2168
|
|
@@ -43,7 +43,7 @@ zea/data/convert/echonetlvh/precompute_crop.py,sha256=A-dSbGngEvL5QfAid3d9OoZRnx
|
|
|
43
43
|
zea/data/convert/images.py,sha256=8a2oHqAA_fLhFuNnqBsNCjKTdX5pADX-ntZ6BHUoZdQ,5229
|
|
44
44
|
zea/data/convert/matlab.py,sha256=nep-n-CK7OoTQQytDZuMx3dTa7AngmYRodE55ZTCCvI,41992
|
|
45
45
|
zea/data/convert/picmus.py,sha256=1TK1vDWzfHmnFErtt0ad36TInIiy4fm3nMmfx83G5nw,6203
|
|
46
|
-
zea/data/data_format.py,sha256=
|
|
46
|
+
zea/data/data_format.py,sha256=n4OTQKl1newgvTamlPprAsRubAVt3bPeNYHJXNc89G0,25684
|
|
47
47
|
zea/data/dataloader.py,sha256=ZAlWbZB6F43EtLvRcUzfeHVtMhYViJTvyaRb9CmpEFc,14865
|
|
48
48
|
zea/data/datasets.py,sha256=XcNlJVmkfNo7TXp80y4Si_by75Ix2ULIU_iEiNEhl6Q,24492
|
|
49
49
|
zea/data/file.py,sha256=dG6if-hofBzfOTr_39_X_OFlTx7t8s7ggrYR7XwyhI8,29387
|
|
@@ -59,22 +59,22 @@ zea/internal/config/create.py,sha256=f8dy6YcAUZbROxwCvORI26ACMRO4Q7EL188KyTL8Bpk
|
|
|
59
59
|
zea/internal/config/parameters.py,sha256=3YE8aBeiosxTuIT0RAEEMkJPmkykSyZZ7hbmmyGnsuc,6602
|
|
60
60
|
zea/internal/config/validation.py,sha256=vPGFWC-UVtNh8jqfQ4V8HTUIjfnOtsJrqcG2R2ae5hg,6522
|
|
61
61
|
zea/internal/convert.py,sha256=EDg6vY73vQht-AwwrYUJ2XrE2L64qL3JBXxnft7Xw70,4596
|
|
62
|
-
zea/internal/core.py,sha256=
|
|
62
|
+
zea/internal/core.py,sha256=QAiN1yFgIm8PGTn8RyE3LyZd5YWpbxuSGwbjfbEYJzU,9387
|
|
63
63
|
zea/internal/device.py,sha256=pcgKHWudolflxe6umQh5DVNDz8QtQBZ8tDiWDNwTays,15175
|
|
64
64
|
zea/internal/git_info.py,sha256=vEeN7cdppNIJPRkC69pUQqtAfTdwCN-tpfo702xpGzY,1040
|
|
65
65
|
zea/internal/operators.py,sha256=0CMv4s0k8MIlD0uMKcxiHq4YryIfacklBFnCqf7RO0A,1837
|
|
66
66
|
zea/internal/parameters.py,sha256=MbYYQQuESQtzXA1b7S3psESdfTbntzQCgljkV_zTlLw,17857
|
|
67
67
|
zea/internal/registry.py,sha256=lQsJbYUz1S2eKzE-7shRYWUBulX2TjHTRN37evrYIGA,7884
|
|
68
68
|
zea/internal/setup_zea.py,sha256=P8dmHk_0hwfukaf-DfPxeDofOH4Kj1s5YvUs8yeFqAQ,7800
|
|
69
|
-
zea/internal/viewer.py,sha256=
|
|
69
|
+
zea/internal/viewer.py,sha256=VsDehOQgqvVp7gkw0kz2Y2v_My8P8FshJEOCM4E0RIk,15745
|
|
70
70
|
zea/io_lib.py,sha256=z65zGUDaRsxeYnddgy9LtEkbHfKhW3O3wPumqozjZvg,12527
|
|
71
|
-
zea/log.py,sha256=
|
|
71
|
+
zea/log.py,sha256=UJIL91lHUgWc-vrlJWOn1PX60fX3TFQ5slIs-b5EWEQ,10540
|
|
72
72
|
zea/metrics.py,sha256=22BonPRPs6sfpDqtSEQGC1icNhsSywBak4iaf_1cook,4894
|
|
73
73
|
zea/models/__init__.py,sha256=gBW1pXrD01beK10dlV6GeY6H17Ym_6wrzMDXrbUs1dk,4183
|
|
74
74
|
zea/models/base.py,sha256=_l1RlXIYS2V5py-en5llJpX1oU0IXK_hzLhfCYybzHg,7121
|
|
75
75
|
zea/models/carotid_segmenter.py,sha256=qxv4xSSbwLQ3AWeP8xoVCFhpPLOqsN-4dNq1ECkG3FM,5401
|
|
76
76
|
zea/models/dense.py,sha256=EwrSawfSTi9oWE55Grr7jtwLXC9MNqEOO3su4oYHqfg,4067
|
|
77
|
-
zea/models/diffusion.py,sha256=
|
|
77
|
+
zea/models/diffusion.py,sha256=DJhrV_NGgRFK8p3kIBelM-0mqvmuDbdxRyWugaftv-0,31841
|
|
78
78
|
zea/models/echonet.py,sha256=toLw6QjpdROvGrm6VEuqHuxxKkiAWtX-__5YzrE-LJA,6346
|
|
79
79
|
zea/models/generative.py,sha256=iujicyFDuCD7NEk_cZ8thlZ2Rl3Qa8LfkwPsZdWYpR0,2625
|
|
80
80
|
zea/models/gmm.py,sha256=6YYoiizsD7BtISOToEXkki6Cc8iXMkgmPH_rMFQKs3E,8324
|
|
@@ -85,7 +85,7 @@ zea/models/presets.py,sha256=drTQF2XCAxMFSqZtu_6m-kpPtcG6io4DrXsBcdFss-4,2623
|
|
|
85
85
|
zea/models/taesd.py,sha256=Vab5jywo4uxXPXMQ-8VdpOMhVwQsEc_PNztctDzFZfk,8403
|
|
86
86
|
zea/models/unet.py,sha256=B_600WLn3jj_BuhDdRrj5I1XQIHylaQQN8ISq-vt7zM,6392
|
|
87
87
|
zea/models/utils.py,sha256=My6VY27S4udOn7xatIM22Qgn8jED1FmnA2yZw1mcuVw,2015
|
|
88
|
-
zea/ops.py,sha256=
|
|
88
|
+
zea/ops.py,sha256=wgm_p1rbQFFQ5856u2SIYGGTv3dFaJR1ogdMkiteChc,109464
|
|
89
89
|
zea/probes.py,sha256=991X4ilpMablekNxAHwBU6DoBIZHrmzoJgs2C6zfU0U,7632
|
|
90
90
|
zea/scan.py,sha256=dp-PiQZWk4HnRROeu6cwrGcM8FNjezDzkF05ZJcacVI,26935
|
|
91
91
|
zea/simulator.py,sha256=KziYNRNAwIyKNE4nzVp2cbNyLATaaUkF-Ity6O1fx20,11279
|
|
@@ -95,11 +95,11 @@ zea/tools/fit_scan_cone.py,sha256=Pw1kiXUAXenCVIO1nfqWVne2i0eWWXHFmoAQ7l_Qzro,24
|
|
|
95
95
|
zea/tools/hf.py,sha256=ibuTGaLitK1vOlQdZJuVmDBoUkYRrEfCctj3kiiOQGs,5477
|
|
96
96
|
zea/tools/selection_tool.py,sha256=dNksdVw_sNcelAyszbrseOTQRDxJpwlapabdh7feNoQ,29848
|
|
97
97
|
zea/tools/wndb.py,sha256=8XY056arnDKpVV7k-B5PrMa-RANur3ldPSR4GW3jgS4,666
|
|
98
|
-
zea/utils.py,sha256=
|
|
99
|
-
zea/visualize.py,sha256=
|
|
98
|
+
zea/utils.py,sha256=8gxZaojsC0D2S63gs6lN0ulcSJ1jry3AzOMFvBdl4H4,23351
|
|
99
|
+
zea/visualize.py,sha256=4-bEHRYG3hUukmBkx7x4VBoa3oAJffO3HnnnzTfduDE,23905
|
|
100
100
|
zea/zea_darkmode.mplstyle,sha256=wHTXkgy00tLEbRmr8GZULb5zIzU0MTMn9xC0Z3WT7Bo,42141
|
|
101
|
-
zea-0.0.
|
|
102
|
-
zea-0.0.
|
|
103
|
-
zea-0.0.
|
|
104
|
-
zea-0.0.
|
|
105
|
-
zea-0.0.
|
|
101
|
+
zea-0.0.4.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
102
|
+
zea-0.0.4.dist-info/METADATA,sha256=VL_XaEx_hcxAPQTU2EcGEaz1RfYxbZ5p_5XRfmqOA6Q,6489
|
|
103
|
+
zea-0.0.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
104
|
+
zea-0.0.4.dist-info/entry_points.txt,sha256=hQcQYCHdMu2LRM1PGZuaGU5EwAjTGErC-QakgwZKZeo,41
|
|
105
|
+
zea-0.0.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|