cuslines 2.2__tar.gz → 2.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cuslines-2.2 → cuslines-2.2.2}/PKG-INFO +1 -1
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/generic_tracker.py +61 -5
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/mt_tractography.py +2 -90
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba/nu_tractography.py +0 -2
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/wg_tractography.py +2 -91
- {cuslines-2.2 → cuslines-2.2.2}/cuslines.egg-info/PKG-INFO +1 -1
- {cuslines-2.2 → cuslines-2.2.2}/run_gpu_streamlines.py +1 -1
- {cuslines-2.2 → cuslines-2.2.2}/.github/workflows/dockerbuild.yml +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/.github/workflows/publish_pypi.yml +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/.gitignore +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/.pre-commit-config.yaml +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/CLAUDE.md +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/Dockerfile +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/LICENSE +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/README.md +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/__init__.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/boot_utils.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/boot.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/cudamacro.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/cuwsort.cuh +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/disc.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/generate_streamlines_cuda.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/globals.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/ptt.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/ptt.cuh +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/ptt_init.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/tracking_helpers.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_c/utils.cu +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/__init__.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/_globals.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/cu_direction_getters.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/cu_propagate_seeds.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/cu_tractography.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/cuda_python/cutils.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/README.md +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/__init__.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/mt_direction_getters.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/mt_propagate_seeds.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal/mutils.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/boot.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/disc.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/generate_streamlines_metal.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/globals.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/philox_rng.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/ptt.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/tracking_helpers.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/types.h +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/utils.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/metal_shaders/warp_sort.metal +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba/__init__.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba/nu_globals.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba_njit/generate_streamlines_numba.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba_njit/num_streamlines_numba.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/numba_njit/tracking_helpers.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/README.md +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/__init__.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/benchmark.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/wg_direction_getters.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/wg_propagate_seeds.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/webgpu/wgutils.py +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/boot.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/disc.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/generate_streamlines.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/globals.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/philox_rng.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/ptt.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/tracking_helpers.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/types.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/utils.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines/wgsl_shaders/warp_sort.wgsl +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines.egg-info/SOURCES.txt +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines.egg-info/dependency_links.txt +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines.egg-info/requires.txt +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/cuslines.egg-info/top_level.txt +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/pyproject.toml +0 -0
- {cuslines-2.2 → cuslines-2.2.2}/setup.cfg +0 -0
|
@@ -3,6 +3,7 @@ import numpy as np
|
|
|
3
3
|
from tqdm import tqdm
|
|
4
4
|
from trx.trx_file_memmap import TrxFile
|
|
5
5
|
from dipy.io.stateful_tractogram import Space, StatefulTractogram
|
|
6
|
+
from dipy.tracking.streamlinespeed import compress_streamlines
|
|
6
7
|
from nibabel.streamlines.array_sequence import ArraySequence
|
|
7
8
|
from nibabel.streamlines.tractogram import Tractogram
|
|
8
9
|
|
|
@@ -16,11 +17,55 @@ class GenericTracker:
|
|
|
16
17
|
def __exit__(self, exc_type, exc, tb):
|
|
17
18
|
return False
|
|
18
19
|
|
|
20
|
+
def set_compression_parameters(self, pos_dtype=np.float32, linearize=False, tol_error=0.1, max_segment_length=10):
|
|
21
|
+
"""
|
|
22
|
+
Set compression parameters to compress generated streamlines.
|
|
23
|
+
Only works with TRX.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
pos_dtype : dtype, optional
|
|
28
|
+
Data type to use for the positions of the streamlines.
|
|
29
|
+
Default: np.float32
|
|
30
|
+
|
|
31
|
+
linearize : bool, optional
|
|
32
|
+
Whether to linearize the streamlines using [1].
|
|
33
|
+
Default: False
|
|
34
|
+
|
|
35
|
+
tol_error : float, optional
|
|
36
|
+
If linearize is true, tolerance error in mm.
|
|
37
|
+
Default: 0.1
|
|
38
|
+
|
|
39
|
+
max_segment_length : float, optional
|
|
40
|
+
If linearize is true, maximum length in mm of any given segment produced by the compression.
|
|
41
|
+
Default: 10
|
|
42
|
+
|
|
43
|
+
References
|
|
44
|
+
----------
|
|
45
|
+
[1] Caroline Presseau, Pierre-Marc Jodoin, Jean-Christophe Houde, and Maxime Descoteaux.
|
|
46
|
+
A new compression format for fiber tracking datasets.
|
|
47
|
+
NeuroImage, 109:73-83, 2015. URL: 10.1016/j.neuroimage.2014.12.058
|
|
48
|
+
"""
|
|
49
|
+
self.pos_dtype = pos_dtype
|
|
50
|
+
self.linearize = linearize
|
|
51
|
+
self.tol_error = tol_error
|
|
52
|
+
self.max_segment_length = max_segment_length
|
|
53
|
+
|
|
54
|
+
|
|
19
55
|
def _ngpus(self):
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
56
|
+
return getattr(self, "ngpus", 1)
|
|
57
|
+
|
|
58
|
+
def _pos_dtype(self):
|
|
59
|
+
return getattr(self, "pos_dtype", np.float16)
|
|
60
|
+
|
|
61
|
+
def _linearize(self):
|
|
62
|
+
return getattr(self, "linearize", False)
|
|
63
|
+
|
|
64
|
+
def _tol_error(self):
|
|
65
|
+
return getattr(self, "tol_error", 0.1)
|
|
66
|
+
|
|
67
|
+
def _max_segment_length(self):
|
|
68
|
+
return getattr(self, "max_segment_length", np.inf)
|
|
24
69
|
|
|
25
70
|
def _divide_chunks(self, seeds):
|
|
26
71
|
global_chunk_sz = self.chunk_size * self._ngpus()
|
|
@@ -58,7 +103,7 @@ class GenericTracker:
|
|
|
58
103
|
# trx files use memory mapping
|
|
59
104
|
trx_reference = TrxFile(reference=ref_img)
|
|
60
105
|
trx_reference.streamlines._data = trx_reference.streamlines._data.astype(
|
|
61
|
-
|
|
106
|
+
self._pos_dtype()
|
|
62
107
|
)
|
|
63
108
|
trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(
|
|
64
109
|
np.uint64
|
|
@@ -81,9 +126,20 @@ class GenericTracker:
|
|
|
81
126
|
self.seed_propagator.as_array_sequence(),
|
|
82
127
|
affine_to_rasmm=ref_img.affine,
|
|
83
128
|
)
|
|
129
|
+
if len(tractogram) == 0:
|
|
130
|
+
continue
|
|
131
|
+
|
|
84
132
|
tractogram.to_world()
|
|
85
133
|
sls = tractogram.streamlines
|
|
86
134
|
|
|
135
|
+
if self._linearize():
|
|
136
|
+
sls = ArraySequence(compress_streamlines(
|
|
137
|
+
sls,
|
|
138
|
+
tol_error=self._tol_error(),
|
|
139
|
+
max_segment_length=self._max_segment_length(),
|
|
140
|
+
))
|
|
141
|
+
sls._data = sls._data.astype(self._pos_dtype())
|
|
142
|
+
|
|
87
143
|
new_offsets_idx = offsets_idx + len(sls._offsets)
|
|
88
144
|
new_sls_data_idx = sls_data_idx + len(sls._data)
|
|
89
145
|
|
|
@@ -5,25 +5,16 @@ we wrap numpy arrays as Metal shared buffers with zero copies.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
from tqdm import tqdm
|
|
9
8
|
import logging
|
|
10
9
|
from math import radians
|
|
11
10
|
|
|
12
11
|
from cuslines.metal.mutils import (
|
|
13
|
-
REAL_SIZE,
|
|
14
12
|
REAL_DTYPE,
|
|
15
|
-
aligned_array,
|
|
16
|
-
PAGE_SIZE,
|
|
17
|
-
checkMetalError,
|
|
18
13
|
)
|
|
19
14
|
|
|
20
15
|
from cuslines.metal.mt_direction_getters import MetalGPUDirectionGetter, MetalBootDirectionGetter
|
|
21
16
|
from cuslines.metal.mt_propagate_seeds import MetalSeedBatchPropagator
|
|
22
|
-
|
|
23
|
-
from trx.trx_file_memmap import TrxFile
|
|
24
|
-
from nibabel.streamlines.tractogram import Tractogram
|
|
25
|
-
from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE
|
|
26
|
-
from dipy.io.stateful_tractogram import Space, StatefulTractogram
|
|
17
|
+
from cuslines.generic_tracker import GenericTracker
|
|
27
18
|
|
|
28
19
|
logger = logging.getLogger("GPUStreamlines")
|
|
29
20
|
|
|
@@ -64,7 +55,7 @@ def _buffer_as_array(buf, dtype, shape):
|
|
|
64
55
|
return np.frombuffer(memview, dtype=dtype, count=count).reshape(shape)
|
|
65
56
|
|
|
66
57
|
|
|
67
|
-
class MetalGPUTracker:
|
|
58
|
+
class MetalGPUTracker(GenericTracker):
|
|
68
59
|
def __init__(
|
|
69
60
|
self,
|
|
70
61
|
dg: MetalGPUDirectionGetter,
|
|
@@ -177,82 +168,3 @@ class MetalGPUTracker:
|
|
|
177
168
|
self.dg.gen_pipeline = None
|
|
178
169
|
self._allocated = False
|
|
179
170
|
return False
|
|
180
|
-
|
|
181
|
-
def _divide_chunks(self, seeds):
|
|
182
|
-
global_chunk_sz = self.chunk_size # single GPU
|
|
183
|
-
nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz
|
|
184
|
-
return global_chunk_sz, nchunks
|
|
185
|
-
|
|
186
|
-
def generate_sft(self, seeds, ref_img):
|
|
187
|
-
global_chunk_sz, nchunks = self._divide_chunks(seeds)
|
|
188
|
-
buffer_size = 0
|
|
189
|
-
generators = []
|
|
190
|
-
|
|
191
|
-
with tqdm(total=seeds.shape[0]) as pbar:
|
|
192
|
-
for idx in range(nchunks):
|
|
193
|
-
chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
|
|
194
|
-
self.seed_propagator.propagate(chunk)
|
|
195
|
-
buffer_size += self.seed_propagator.get_buffer_size()
|
|
196
|
-
generators.append(self.seed_propagator.as_generator())
|
|
197
|
-
pbar.update(chunk.shape[0])
|
|
198
|
-
|
|
199
|
-
array_sequence = ArraySequence(
|
|
200
|
-
(item for gen in generators for item in gen), buffer_size
|
|
201
|
-
)
|
|
202
|
-
return StatefulTractogram(array_sequence, ref_img, Space.VOX)
|
|
203
|
-
|
|
204
|
-
def generate_trx(self, seeds, ref_img):
|
|
205
|
-
global_chunk_sz, nchunks = self._divide_chunks(seeds)
|
|
206
|
-
|
|
207
|
-
sl_len_guess = 100
|
|
208
|
-
sl_per_seed_guess = 2
|
|
209
|
-
n_sls_guess = sl_per_seed_guess * seeds.shape[0]
|
|
210
|
-
|
|
211
|
-
trx_reference = TrxFile(reference=ref_img)
|
|
212
|
-
trx_reference.streamlines._data = trx_reference.streamlines._data.astype(np.float32)
|
|
213
|
-
trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(np.uint64)
|
|
214
|
-
|
|
215
|
-
trx_file = TrxFile(
|
|
216
|
-
nb_streamlines=n_sls_guess,
|
|
217
|
-
nb_vertices=n_sls_guess * sl_len_guess,
|
|
218
|
-
init_as=trx_reference,
|
|
219
|
-
)
|
|
220
|
-
offsets_idx = 0
|
|
221
|
-
sls_data_idx = 0
|
|
222
|
-
|
|
223
|
-
with tqdm(total=seeds.shape[0]) as pbar:
|
|
224
|
-
for idx in range(int(nchunks)):
|
|
225
|
-
chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
|
|
226
|
-
self.seed_propagator.propagate(chunk)
|
|
227
|
-
tractogram = Tractogram(
|
|
228
|
-
self.seed_propagator.as_array_sequence(),
|
|
229
|
-
affine_to_rasmm=ref_img.affine,
|
|
230
|
-
)
|
|
231
|
-
tractogram.to_world()
|
|
232
|
-
sls = tractogram.streamlines
|
|
233
|
-
|
|
234
|
-
new_offsets_idx = offsets_idx + len(sls._offsets)
|
|
235
|
-
new_sls_data_idx = sls_data_idx + len(sls._data)
|
|
236
|
-
|
|
237
|
-
if (
|
|
238
|
-
new_offsets_idx > trx_file.header["NB_STREAMLINES"]
|
|
239
|
-
or new_sls_data_idx > trx_file.header["NB_VERTICES"]
|
|
240
|
-
):
|
|
241
|
-
logger.info("TRX resizing...")
|
|
242
|
-
trx_file.resize(
|
|
243
|
-
nb_streamlines=new_offsets_idx * 2,
|
|
244
|
-
nb_vertices=new_sls_data_idx * 2,
|
|
245
|
-
)
|
|
246
|
-
|
|
247
|
-
trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = sls._data
|
|
248
|
-
trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = (
|
|
249
|
-
sls_data_idx + sls._offsets
|
|
250
|
-
)
|
|
251
|
-
trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = sls._lengths
|
|
252
|
-
|
|
253
|
-
offsets_idx = new_offsets_idx
|
|
254
|
-
sls_data_idx = new_sls_data_idx
|
|
255
|
-
pbar.update(chunk.shape[0])
|
|
256
|
-
|
|
257
|
-
trx_file.resize()
|
|
258
|
-
return trx_file
|
|
@@ -2,9 +2,7 @@ import math
|
|
|
2
2
|
from math import radians
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
|
-
from dipy.io.stateful_tractogram import Space, StatefulTractogram
|
|
6
5
|
from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE
|
|
7
|
-
from tqdm import tqdm
|
|
8
6
|
|
|
9
7
|
from cuslines.generic_tracker import GenericTracker
|
|
10
8
|
from cuslines.numba_njit.num_streamlines_numba import getNumStreamlinesProb_generator
|
|
@@ -5,28 +5,22 @@ readbacks via device.queue.read_buffer() (similar to CUDA's cudaMemcpy).
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
from tqdm import tqdm
|
|
9
8
|
import logging
|
|
10
9
|
from math import radians
|
|
11
10
|
|
|
12
11
|
from cuslines.webgpu.wgutils import (
|
|
13
|
-
REAL_SIZE,
|
|
14
12
|
REAL_DTYPE,
|
|
15
13
|
create_buffer_from_data,
|
|
16
14
|
)
|
|
17
15
|
|
|
18
16
|
from cuslines.webgpu.wg_direction_getters import WebGPUDirectionGetter, WebGPUBootDirectionGetter
|
|
19
17
|
from cuslines.webgpu.wg_propagate_seeds import WebGPUSeedBatchPropagator
|
|
20
|
-
|
|
21
|
-
from trx.trx_file_memmap import TrxFile
|
|
22
|
-
from nibabel.streamlines.tractogram import Tractogram
|
|
23
|
-
from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE
|
|
24
|
-
from dipy.io.stateful_tractogram import Space, StatefulTractogram
|
|
18
|
+
from cuslines.generic_tracker import GenericTracker
|
|
25
19
|
|
|
26
20
|
logger = logging.getLogger("GPUStreamlines")
|
|
27
21
|
|
|
28
22
|
|
|
29
|
-
class WebGPUTracker:
|
|
23
|
+
class WebGPUTracker(GenericTracker):
|
|
30
24
|
def __init__(
|
|
31
25
|
self,
|
|
32
26
|
dg: WebGPUDirectionGetter,
|
|
@@ -207,86 +201,3 @@ class WebGPUTracker:
|
|
|
207
201
|
self.device = None
|
|
208
202
|
self._allocated = False
|
|
209
203
|
return False
|
|
210
|
-
|
|
211
|
-
def _divide_chunks(self, seeds):
|
|
212
|
-
global_chunk_sz = self.chunk_size # single GPU
|
|
213
|
-
nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz
|
|
214
|
-
return global_chunk_sz, nchunks
|
|
215
|
-
|
|
216
|
-
def generate_sft(self, seeds, ref_img):
|
|
217
|
-
global_chunk_sz, nchunks = self._divide_chunks(seeds)
|
|
218
|
-
buffer_size = 0
|
|
219
|
-
generators = []
|
|
220
|
-
|
|
221
|
-
with tqdm(total=seeds.shape[0]) as pbar:
|
|
222
|
-
for idx in range(nchunks):
|
|
223
|
-
chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
|
|
224
|
-
self.seed_propagator.propagate(chunk)
|
|
225
|
-
buffer_size += self.seed_propagator.get_buffer_size()
|
|
226
|
-
generators.append(self.seed_propagator.as_generator())
|
|
227
|
-
pbar.update(chunk.shape[0])
|
|
228
|
-
|
|
229
|
-
array_sequence = ArraySequence(
|
|
230
|
-
(item for gen in generators for item in gen), buffer_size
|
|
231
|
-
)
|
|
232
|
-
return StatefulTractogram(array_sequence, ref_img, Space.VOX)
|
|
233
|
-
|
|
234
|
-
def generate_trx(self, seeds, ref_img):
|
|
235
|
-
global_chunk_sz, nchunks = self._divide_chunks(seeds)
|
|
236
|
-
|
|
237
|
-
sl_len_guess = 100
|
|
238
|
-
sl_per_seed_guess = 2
|
|
239
|
-
n_sls_guess = sl_per_seed_guess * seeds.shape[0]
|
|
240
|
-
|
|
241
|
-
trx_reference = TrxFile(reference=ref_img)
|
|
242
|
-
trx_reference.streamlines._data = trx_reference.streamlines._data.astype(
|
|
243
|
-
np.float32
|
|
244
|
-
)
|
|
245
|
-
trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype(
|
|
246
|
-
np.uint64
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
trx_file = TrxFile(
|
|
250
|
-
nb_streamlines=n_sls_guess,
|
|
251
|
-
nb_vertices=n_sls_guess * sl_len_guess,
|
|
252
|
-
init_as=trx_reference,
|
|
253
|
-
)
|
|
254
|
-
offsets_idx = 0
|
|
255
|
-
sls_data_idx = 0
|
|
256
|
-
|
|
257
|
-
with tqdm(total=seeds.shape[0]) as pbar:
|
|
258
|
-
for idx in range(int(nchunks)):
|
|
259
|
-
chunk = seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
|
|
260
|
-
self.seed_propagator.propagate(chunk)
|
|
261
|
-
tractogram = Tractogram(
|
|
262
|
-
self.seed_propagator.as_array_sequence(),
|
|
263
|
-
affine_to_rasmm=ref_img.affine,
|
|
264
|
-
)
|
|
265
|
-
tractogram.to_world()
|
|
266
|
-
sls = tractogram.streamlines
|
|
267
|
-
|
|
268
|
-
new_offsets_idx = offsets_idx + len(sls._offsets)
|
|
269
|
-
new_sls_data_idx = sls_data_idx + len(sls._data)
|
|
270
|
-
|
|
271
|
-
if (
|
|
272
|
-
new_offsets_idx > trx_file.header["NB_STREAMLINES"]
|
|
273
|
-
or new_sls_data_idx > trx_file.header["NB_VERTICES"]
|
|
274
|
-
):
|
|
275
|
-
logger.info("TRX resizing...")
|
|
276
|
-
trx_file.resize(
|
|
277
|
-
nb_streamlines=new_offsets_idx * 2,
|
|
278
|
-
nb_vertices=new_sls_data_idx * 2,
|
|
279
|
-
)
|
|
280
|
-
|
|
281
|
-
trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = sls._data
|
|
282
|
-
trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = (
|
|
283
|
-
sls_data_idx + sls._offsets
|
|
284
|
-
)
|
|
285
|
-
trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = sls._lengths
|
|
286
|
-
|
|
287
|
-
offsets_idx = new_offsets_idx
|
|
288
|
-
sls_data_idx = new_sls_data_idx
|
|
289
|
-
pbar.update(chunk.shape[0])
|
|
290
|
-
|
|
291
|
-
trx_file.resize()
|
|
292
|
-
return trx_file
|
|
@@ -127,7 +127,7 @@ parser.add_argument(
|
|
|
127
127
|
parser.add_argument(
|
|
128
128
|
"--ngpus", type=int, default=1, help="number of GPUs to use if using gpu"
|
|
129
129
|
)
|
|
130
|
-
parser.add_argument("--write-method", type=str, default="
|
|
130
|
+
parser.add_argument("--write-method", type=str, default="trx", help="Can be trx or trk")
|
|
131
131
|
parser.add_argument(
|
|
132
132
|
"--max-angle", type=float, default=60, help="max angle (in degrees)"
|
|
133
133
|
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|