cellfinder 1.1.3__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellfinder might be problematic. Click here for more details.
- cellfinder/__init__.py +21 -12
- cellfinder/core/classify/classify.py +13 -6
- cellfinder/core/classify/cube_generator.py +27 -11
- cellfinder/core/classify/resnet.py +9 -6
- cellfinder/core/classify/tools.py +13 -11
- cellfinder/core/detect/detect.py +12 -1
- cellfinder/core/detect/filters/volume/ball_filter.py +198 -113
- cellfinder/core/detect/filters/volume/structure_detection.py +105 -41
- cellfinder/core/detect/filters/volume/structure_splitting.py +1 -1
- cellfinder/core/detect/filters/volume/volume_filter.py +48 -49
- cellfinder/core/download/cli.py +39 -32
- cellfinder/core/download/download.py +44 -56
- cellfinder/core/main.py +53 -68
- cellfinder/core/tools/prep.py +12 -20
- cellfinder/core/tools/source_files.py +5 -3
- cellfinder/core/tools/system.py +10 -0
- cellfinder/core/train/train_yml.py +29 -27
- cellfinder/napari/curation.py +1 -1
- cellfinder/napari/detect/detect.py +259 -58
- cellfinder/napari/detect/detect_containers.py +11 -1
- cellfinder/napari/detect/thread_worker.py +16 -2
- cellfinder/napari/train/train.py +2 -9
- cellfinder/napari/train/train_containers.py +3 -3
- cellfinder/napari/utils.py +88 -47
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/METADATA +12 -11
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/RECORD +30 -34
- cellfinder/core/download/models.py +0 -49
- cellfinder/core/tools/IO.py +0 -48
- cellfinder/core/tools/tf.py +0 -46
- cellfinder/napari/images/brainglobe.png +0 -0
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/LICENSE +0 -0
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/WHEEL +0 -0
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/entry_points.txt +0 -0
- {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,22 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Dict, Optional, TypeVar
|
|
2
|
+
from typing import Dict, Optional, Tuple, TypeVar, Union
|
|
3
3
|
|
|
4
4
|
import numba.typed
|
|
5
5
|
import numpy as np
|
|
6
6
|
import numpy.typing as npt
|
|
7
|
-
from numba import njit
|
|
7
|
+
from numba import njit, typed
|
|
8
8
|
from numba.core import types
|
|
9
9
|
from numba.experimental import jitclass
|
|
10
10
|
from numba.types import DictType
|
|
11
11
|
|
|
12
|
+
T = TypeVar("T")
|
|
13
|
+
# type used for the domain of the volume - the size of the vol
|
|
14
|
+
vol_np_type = np.int64
|
|
15
|
+
vol_numba_type = types.int64
|
|
16
|
+
# type used for the structure id
|
|
17
|
+
sid_np_type = np.int64
|
|
18
|
+
sid_numba_type = types.int64
|
|
19
|
+
|
|
12
20
|
|
|
13
21
|
@dataclass
|
|
14
22
|
class Point:
|
|
@@ -32,18 +40,15 @@ def get_non_zero_dtype_min(values: np.ndarray) -> int:
|
|
|
32
40
|
return min_val
|
|
33
41
|
|
|
34
42
|
|
|
35
|
-
T = TypeVar("T")
|
|
36
|
-
|
|
37
|
-
|
|
38
43
|
@njit
|
|
39
44
|
def traverse_dict(d: Dict[T, T], a: T) -> T:
|
|
40
45
|
"""
|
|
41
46
|
Traverse d, until a is not present as a key.
|
|
42
47
|
"""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
48
|
+
value = a
|
|
49
|
+
while value in d:
|
|
50
|
+
value = d[value]
|
|
51
|
+
return value
|
|
47
52
|
|
|
48
53
|
|
|
49
54
|
@njit
|
|
@@ -54,14 +59,28 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray:
|
|
|
54
59
|
Centre calculated as the mean of each pixel coordinate,
|
|
55
60
|
rounded to the nearest integer.
|
|
56
61
|
"""
|
|
57
|
-
#
|
|
58
|
-
|
|
62
|
+
# numba support axis for sum, but not mean
|
|
63
|
+
return np.round(np.sum(structure, axis=0) / structure.shape[0])
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@njit
|
|
67
|
+
def _get_structure_centre(structure: types.ListType) -> np.ndarray:
|
|
68
|
+
# See get_structure_centre.
|
|
69
|
+
# this is for our own points stored as list optimized by numba
|
|
70
|
+
a_sum = 0.0
|
|
71
|
+
b_sum = 0.0
|
|
72
|
+
c_sum = 0.0
|
|
73
|
+
for a, b, c in structure:
|
|
74
|
+
a_sum += a
|
|
75
|
+
b_sum += b
|
|
76
|
+
c_sum += c
|
|
77
|
+
|
|
59
78
|
return np.round(
|
|
60
79
|
np.array(
|
|
61
80
|
[
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
81
|
+
a_sum / len(structure),
|
|
82
|
+
b_sum / len(structure),
|
|
83
|
+
c_sum / len(structure),
|
|
65
84
|
]
|
|
66
85
|
)
|
|
67
86
|
)
|
|
@@ -69,15 +88,18 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray:
|
|
|
69
88
|
|
|
70
89
|
# Type declaration has to come outside of the class,
|
|
71
90
|
# see https://github.com/numba/numba/issues/8808
|
|
72
|
-
|
|
91
|
+
tuple_point_type = types.Tuple(
|
|
92
|
+
(vol_numba_type, vol_numba_type, vol_numba_type)
|
|
93
|
+
)
|
|
94
|
+
list_of_points_type = types.ListType(tuple_point_type)
|
|
73
95
|
|
|
74
96
|
|
|
75
97
|
spec = [
|
|
76
|
-
("z",
|
|
77
|
-
("next_structure_id",
|
|
78
|
-
("shape", types.UniTuple(
|
|
79
|
-
("obsolete_ids", DictType(
|
|
80
|
-
("coords_maps", DictType(
|
|
98
|
+
("z", vol_numba_type),
|
|
99
|
+
("next_structure_id", sid_numba_type),
|
|
100
|
+
("shape", types.UniTuple(vol_numba_type, 2)),
|
|
101
|
+
("obsolete_ids", DictType(sid_numba_type, sid_numba_type)),
|
|
102
|
+
("coords_maps", DictType(sid_numba_type, list_of_points_type)),
|
|
81
103
|
]
|
|
82
104
|
|
|
83
105
|
|
|
@@ -103,8 +125,12 @@ class CellDetector:
|
|
|
103
125
|
are scanned.
|
|
104
126
|
coords_maps :
|
|
105
127
|
Mapping from structure ID to the coordinates of pixels within that
|
|
106
|
-
structure. Coordinates are stored in a
|
|
107
|
-
|
|
128
|
+
structure. Coordinates are stored in a list of (x, y, z) tuples of
|
|
129
|
+
the coordinates.
|
|
130
|
+
|
|
131
|
+
Use `get_structures` to get it as a dict whose values are each
|
|
132
|
+
a 2D array, where rows are points, and columns x, y, z of the
|
|
133
|
+
points.
|
|
108
134
|
"""
|
|
109
135
|
|
|
110
136
|
def __init__(self, width: int, height: int, start_z: int):
|
|
@@ -123,11 +149,11 @@ class CellDetector:
|
|
|
123
149
|
# Mapping from obsolete IDs to the IDs that they have been
|
|
124
150
|
# made obsolete by
|
|
125
151
|
self.obsolete_ids = numba.typed.Dict.empty(
|
|
126
|
-
key_type=
|
|
152
|
+
key_type=sid_numba_type, value_type=sid_numba_type
|
|
127
153
|
)
|
|
128
154
|
# Mapping from IDs to list of points in that structure
|
|
129
155
|
self.coords_maps = numba.typed.Dict.empty(
|
|
130
|
-
key_type=
|
|
156
|
+
key_type=sid_numba_type, value_type=list_of_points_type
|
|
131
157
|
)
|
|
132
158
|
|
|
133
159
|
def process(
|
|
@@ -136,7 +162,7 @@ class CellDetector:
|
|
|
136
162
|
"""
|
|
137
163
|
Process a new plane.
|
|
138
164
|
"""
|
|
139
|
-
if
|
|
165
|
+
if plane.shape[:2] != self.shape:
|
|
140
166
|
raise ValueError("plane does not have correct shape")
|
|
141
167
|
|
|
142
168
|
plane = self.connect_four(plane, previous_plane)
|
|
@@ -166,7 +192,7 @@ class CellDetector:
|
|
|
166
192
|
for x in range(plane.shape[0]):
|
|
167
193
|
if plane[x, y] == SOMA_CENTRE_VALUE:
|
|
168
194
|
# Labels of structures below, left and behind
|
|
169
|
-
neighbour_ids = np.zeros(3, dtype=
|
|
195
|
+
neighbour_ids = np.zeros(3, dtype=sid_np_type)
|
|
170
196
|
# If in bounds look at neighbours
|
|
171
197
|
if x > 0:
|
|
172
198
|
neighbour_ids[0] = plane[x - 1, y]
|
|
@@ -191,17 +217,54 @@ class CellDetector:
|
|
|
191
217
|
def get_cell_centres(self) -> np.ndarray:
|
|
192
218
|
return self.structures_to_cells()
|
|
193
219
|
|
|
194
|
-
def
|
|
195
|
-
|
|
220
|
+
def get_structures(self) -> Dict[int, np.ndarray]:
|
|
221
|
+
"""
|
|
222
|
+
Gets the structures as a dict of structure IDs mapped to the 2D array
|
|
223
|
+
of structure points.
|
|
224
|
+
"""
|
|
225
|
+
d = {}
|
|
226
|
+
for sid, points in self.coords_maps.items():
|
|
227
|
+
# numba silliness - it cannot handle
|
|
228
|
+
# `item = np.array(points, dtype=vol_np_type)` so we need to create
|
|
229
|
+
# array and then fill in the point
|
|
230
|
+
item = np.empty((len(points), 3), dtype=vol_np_type)
|
|
231
|
+
d[sid] = item
|
|
232
|
+
|
|
233
|
+
for i, point in enumerate(points):
|
|
234
|
+
item[i, :] = point
|
|
235
|
+
|
|
236
|
+
return d
|
|
237
|
+
|
|
238
|
+
def add_point(
|
|
239
|
+
self, sid: int, point: Union[tuple, list, np.ndarray]
|
|
240
|
+
) -> None:
|
|
241
|
+
"""
|
|
242
|
+
Add single 3d *point* to the structure with the given *sid*.
|
|
243
|
+
"""
|
|
244
|
+
if sid not in self.coords_maps:
|
|
245
|
+
self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
|
|
246
|
+
|
|
247
|
+
self._add_point(sid, (int(point[0]), int(point[1]), int(point[2])))
|
|
196
248
|
|
|
197
|
-
def
|
|
249
|
+
def add_points(self, sid: int, points: np.ndarray):
|
|
198
250
|
"""
|
|
199
|
-
|
|
251
|
+
Adds ndarray of *points* to the structure with the given *sid*.
|
|
252
|
+
Each row is a 3d point.
|
|
200
253
|
"""
|
|
201
|
-
|
|
254
|
+
if sid not in self.coords_maps:
|
|
255
|
+
self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
|
|
256
|
+
|
|
257
|
+
append = self.coords_maps[sid].append
|
|
258
|
+
pts = np.round(points).astype(vol_np_type)
|
|
259
|
+
for point in pts:
|
|
260
|
+
append((point[0], point[1], point[2]))
|
|
261
|
+
|
|
262
|
+
def _add_point(self, sid: int, point: Tuple[int, int, int]) -> None:
|
|
263
|
+
# sid must exist
|
|
264
|
+
self.coords_maps[sid].append(point)
|
|
202
265
|
|
|
203
266
|
def add(
|
|
204
|
-
self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[
|
|
267
|
+
self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type]
|
|
205
268
|
) -> int:
|
|
206
269
|
"""
|
|
207
270
|
For the current coordinates takes all the neighbours and find the
|
|
@@ -215,17 +278,16 @@ class CellDetector:
|
|
|
215
278
|
"""
|
|
216
279
|
updated_id = self.sanitise_ids(neighbour_ids)
|
|
217
280
|
if updated_id not in self.coords_maps:
|
|
218
|
-
self.coords_maps[updated_id] =
|
|
219
|
-
|
|
281
|
+
self.coords_maps[updated_id] = typed.List.empty_list(
|
|
282
|
+
tuple_point_type
|
|
220
283
|
)
|
|
221
284
|
self.merge_structures(updated_id, neighbour_ids)
|
|
222
285
|
|
|
223
286
|
# Add point for that structure
|
|
224
|
-
|
|
225
|
-
self.add_point(updated_id, point)
|
|
287
|
+
self._add_point(updated_id, (int(x), int(y), int(z)))
|
|
226
288
|
return updated_id
|
|
227
289
|
|
|
228
|
-
def sanitise_ids(self, neighbour_ids: npt.NDArray[
|
|
290
|
+
def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int:
|
|
229
291
|
"""
|
|
230
292
|
Get the smallest ID of all the structures that are connected to IDs
|
|
231
293
|
in `neighbour_ids`.
|
|
@@ -246,7 +308,7 @@ class CellDetector:
|
|
|
246
308
|
return int(updated_id)
|
|
247
309
|
|
|
248
310
|
def merge_structures(
|
|
249
|
-
self, updated_id: int, neighbour_ids: npt.NDArray[
|
|
311
|
+
self, updated_id: int, neighbour_ids: npt.NDArray[sid_np_type]
|
|
250
312
|
) -> None:
|
|
251
313
|
"""
|
|
252
314
|
For all the neighbours, reassign all the points of neighbour to
|
|
@@ -261,14 +323,16 @@ class CellDetector:
|
|
|
261
323
|
# minimise ID so if neighbour with higher ID, reassign its points
|
|
262
324
|
# to current
|
|
263
325
|
if neighbour_id > updated_id:
|
|
264
|
-
self.
|
|
326
|
+
self.coords_maps[updated_id].extend(
|
|
327
|
+
self.coords_maps[neighbour_id]
|
|
328
|
+
)
|
|
265
329
|
self.coords_maps.pop(neighbour_id)
|
|
266
330
|
self.obsolete_ids[neighbour_id] = updated_id
|
|
267
331
|
|
|
268
332
|
def structures_to_cells(self) -> np.ndarray:
|
|
269
|
-
cell_centres = np.empty((len(self.coords_maps
|
|
333
|
+
cell_centres = np.empty((len(self.coords_maps), 3))
|
|
270
334
|
for idx, structure in enumerate(self.coords_maps.values()):
|
|
271
|
-
p =
|
|
335
|
+
p = _get_structure_centre(structure)
|
|
272
336
|
cell_centres[idx] = p
|
|
273
337
|
return cell_centres
|
|
274
338
|
|
|
@@ -71,7 +71,7 @@ def ball_filter_imgs(
|
|
|
71
71
|
"""
|
|
72
72
|
# OPTIMISE: reuse ball filter instance
|
|
73
73
|
|
|
74
|
-
good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=
|
|
74
|
+
good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=np.bool_)
|
|
75
75
|
|
|
76
76
|
plane_width, plane_height = volume.shape[:2]
|
|
77
77
|
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import math
|
|
2
|
+
import multiprocessing.pool
|
|
2
3
|
import os
|
|
4
|
+
from functools import partial
|
|
3
5
|
from queue import Queue
|
|
4
6
|
from threading import Lock
|
|
5
7
|
from typing import Any, Callable, List, Optional, Tuple
|
|
@@ -77,7 +79,7 @@ class VolumeFilter(object):
|
|
|
77
79
|
locks: List[Lock],
|
|
78
80
|
*,
|
|
79
81
|
callback: Callable[[int], None],
|
|
80
|
-
) ->
|
|
82
|
+
) -> None:
|
|
81
83
|
progress_bar = tqdm(total=self.n_planes, desc="Processing planes")
|
|
82
84
|
for z in range(self.n_planes):
|
|
83
85
|
# Get result from the queue.
|
|
@@ -108,11 +110,13 @@ class VolumeFilter(object):
|
|
|
108
110
|
|
|
109
111
|
progress_bar.close()
|
|
110
112
|
logger.debug("3D filter done")
|
|
111
|
-
return self.get_results()
|
|
112
113
|
|
|
113
114
|
def _run_filter(self) -> None:
|
|
114
115
|
logger.debug(f"🏐 Ball filtering plane {self.z}")
|
|
115
|
-
|
|
116
|
+
# filtering original images, the images should be large enough in x/y
|
|
117
|
+
# to benefit from parallelization. Note: don't pass arg as keyword arg
|
|
118
|
+
# because numba gets stuck (probably b/c class jit is new)
|
|
119
|
+
self.ball_filter.walk(True)
|
|
116
120
|
|
|
117
121
|
middle_plane = self.ball_filter.get_middle_plane()
|
|
118
122
|
if self.save_planes:
|
|
@@ -134,7 +138,7 @@ class VolumeFilter(object):
|
|
|
134
138
|
f_path = os.path.join(self.plane_directory, plane_name)
|
|
135
139
|
tifffile.imsave(f_path, plane.T)
|
|
136
140
|
|
|
137
|
-
def get_results(self) -> List[Cell]:
|
|
141
|
+
def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]:
|
|
138
142
|
logger.info("Splitting cell clusters and writing results")
|
|
139
143
|
|
|
140
144
|
max_cell_volume = sphere_volume(
|
|
@@ -142,62 +146,57 @@ class VolumeFilter(object):
|
|
|
142
146
|
)
|
|
143
147
|
|
|
144
148
|
cells = []
|
|
149
|
+
needs_split = []
|
|
150
|
+
structures = self.cell_detector.get_structures().items()
|
|
151
|
+
logger.debug(f"Processing {len(structures)} found cells")
|
|
145
152
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
)
|
|
149
|
-
for cell_id, cell_points in self.cell_detector.coords_maps.items():
|
|
153
|
+
# first get all the cells that are not clusters
|
|
154
|
+
for cell_id, cell_points in structures:
|
|
150
155
|
cell_volume = len(cell_points)
|
|
151
156
|
|
|
152
157
|
if cell_volume < max_cell_volume:
|
|
153
158
|
cell_centre = get_structure_centre(cell_points)
|
|
154
|
-
cells.append(
|
|
155
|
-
Cell(
|
|
156
|
-
(
|
|
157
|
-
cell_centre[0],
|
|
158
|
-
cell_centre[1],
|
|
159
|
-
cell_centre[2],
|
|
160
|
-
),
|
|
161
|
-
Cell.UNKNOWN,
|
|
162
|
-
)
|
|
163
|
-
)
|
|
159
|
+
cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
|
|
164
160
|
else:
|
|
165
161
|
if cell_volume < self.max_cluster_size:
|
|
166
|
-
|
|
167
|
-
cell_centres = split_cells(
|
|
168
|
-
cell_points, outlier_keep=self.outlier_keep
|
|
169
|
-
)
|
|
170
|
-
except (ValueError, AssertionError) as err:
|
|
171
|
-
raise StructureSplitException(
|
|
172
|
-
f"Cell {cell_id}, error; {err}"
|
|
173
|
-
)
|
|
174
|
-
for cell_centre in cell_centres:
|
|
175
|
-
cells.append(
|
|
176
|
-
Cell(
|
|
177
|
-
(
|
|
178
|
-
cell_centre[0],
|
|
179
|
-
cell_centre[1],
|
|
180
|
-
cell_centre[2],
|
|
181
|
-
),
|
|
182
|
-
Cell.UNKNOWN,
|
|
183
|
-
)
|
|
184
|
-
)
|
|
162
|
+
needs_split.append((cell_id, cell_points))
|
|
185
163
|
else:
|
|
186
164
|
cell_centre = get_structure_centre(cell_points)
|
|
187
|
-
cells.append(
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
165
|
+
cells.append(Cell(cell_centre.tolist(), Cell.ARTIFACT))
|
|
166
|
+
|
|
167
|
+
if not needs_split:
|
|
168
|
+
logger.debug("Finished splitting cell clusters - none found")
|
|
169
|
+
return cells
|
|
170
|
+
|
|
171
|
+
# now split clusters into cells
|
|
172
|
+
logger.debug(f"Splitting {len(needs_split)} clusters")
|
|
173
|
+
progress_bar = tqdm(
|
|
174
|
+
total=len(needs_split), desc="Splitting cell clusters"
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# we are not returning Cell instances from func because it'd be pickled
|
|
178
|
+
# by multiprocess which slows it down
|
|
179
|
+
func = partial(_split_cells, outlier_keep=self.outlier_keep)
|
|
180
|
+
for cell_centres in worker_pool.imap_unordered(func, needs_split):
|
|
181
|
+
for cell_centre in cell_centres:
|
|
182
|
+
cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
|
|
183
|
+
progress_bar.update()
|
|
184
|
+
|
|
185
|
+
progress_bar.close()
|
|
186
|
+
logger.debug(
|
|
187
|
+
f"Finished splitting cell clusters. Found {len(cells)} total cells"
|
|
188
|
+
)
|
|
189
|
+
|
|
199
190
|
return cells
|
|
200
191
|
|
|
201
192
|
|
|
193
|
+
def _split_cells(arg, outlier_keep):
|
|
194
|
+
cell_id, cell_points = arg
|
|
195
|
+
try:
|
|
196
|
+
return split_cells(cell_points, outlier_keep=outlier_keep)
|
|
197
|
+
except (ValueError, AssertionError) as err:
|
|
198
|
+
raise StructureSplitException(f"Cell {cell_id}, error; {err}")
|
|
199
|
+
|
|
200
|
+
|
|
202
201
|
def sphere_volume(radius: float) -> float:
|
|
203
202
|
return (4 / 3) * math.pi * radius**3
|
cellfinder/core/download/cli.py
CHANGED
|
@@ -1,17 +1,29 @@
|
|
|
1
|
-
import tempfile
|
|
2
1
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
|
|
5
|
-
from cellfinder.core.download import
|
|
6
|
-
|
|
4
|
+
from cellfinder.core.download.download import (
|
|
5
|
+
DEFAULT_DOWNLOAD_DIRECTORY,
|
|
6
|
+
amend_user_configuration,
|
|
7
|
+
download_models,
|
|
8
|
+
)
|
|
7
9
|
|
|
8
|
-
home = Path.home()
|
|
9
|
-
DEFAULT_DOWNLOAD_DIRECTORY = home / ".cellfinder"
|
|
10
|
-
temp_dir = tempfile.TemporaryDirectory()
|
|
11
|
-
temp_dir_path = Path(temp_dir.name)
|
|
12
10
|
|
|
11
|
+
def download_parser(parser: ArgumentParser) -> ArgumentParser:
|
|
12
|
+
"""
|
|
13
|
+
Configure the argument parser for downloading files.
|
|
14
|
+
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
parser : ArgumentParser
|
|
18
|
+
The argument parser to configure.
|
|
19
|
+
|
|
20
|
+
Returns
|
|
21
|
+
-------
|
|
22
|
+
ArgumentParser
|
|
23
|
+
The configured argument parser.
|
|
24
|
+
|
|
25
|
+
"""
|
|
13
26
|
|
|
14
|
-
def download_directory_parser(parser):
|
|
15
27
|
parser.add_argument(
|
|
16
28
|
"--install-path",
|
|
17
29
|
dest="install_path",
|
|
@@ -19,29 +31,12 @@ def download_directory_parser(parser):
|
|
|
19
31
|
default=DEFAULT_DOWNLOAD_DIRECTORY,
|
|
20
32
|
help="The path to install files to.",
|
|
21
33
|
)
|
|
22
|
-
parser.add_argument(
|
|
23
|
-
"--download-path",
|
|
24
|
-
dest="download_path",
|
|
25
|
-
type=Path,
|
|
26
|
-
default=temp_dir_path,
|
|
27
|
-
help="The path to download files into.",
|
|
28
|
-
)
|
|
29
34
|
parser.add_argument(
|
|
30
35
|
"--no-amend-config",
|
|
31
36
|
dest="no_amend_config",
|
|
32
37
|
action="store_true",
|
|
33
38
|
help="Don't amend the config file",
|
|
34
39
|
)
|
|
35
|
-
return parser
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def model_parser(parser):
|
|
39
|
-
parser.add_argument(
|
|
40
|
-
"--no-models",
|
|
41
|
-
dest="no_models",
|
|
42
|
-
action="store_true",
|
|
43
|
-
help="Don't download the model",
|
|
44
|
-
)
|
|
45
40
|
parser.add_argument(
|
|
46
41
|
"--model",
|
|
47
42
|
dest="model",
|
|
@@ -52,17 +47,29 @@ def model_parser(parser):
|
|
|
52
47
|
return parser
|
|
53
48
|
|
|
54
49
|
|
|
55
|
-
def
|
|
50
|
+
def get_parser() -> ArgumentParser:
|
|
51
|
+
"""
|
|
52
|
+
Create an argument parser for downloading files.
|
|
53
|
+
|
|
54
|
+
Returns
|
|
55
|
+
-------
|
|
56
|
+
ArgumentParser
|
|
57
|
+
The configured argument parser.
|
|
58
|
+
|
|
59
|
+
"""
|
|
56
60
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
|
57
|
-
parser =
|
|
58
|
-
parser = download_directory_parser(parser)
|
|
61
|
+
parser = download_parser(parser)
|
|
59
62
|
return parser
|
|
60
63
|
|
|
61
64
|
|
|
62
|
-
def main():
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
def main() -> None:
|
|
66
|
+
"""
|
|
67
|
+
Run the main download function, and optionally amend the user
|
|
68
|
+
configuration.
|
|
69
|
+
|
|
70
|
+
"""
|
|
71
|
+
args = get_parser().parse_args()
|
|
72
|
+
model_path = download_models(args.model, args.install_path)
|
|
66
73
|
|
|
67
74
|
if not args.no_amend_config:
|
|
68
75
|
amend_user_configuration(new_model_path=model_path)
|
|
@@ -1,79 +1,67 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import shutil
|
|
3
|
-
import tarfile
|
|
4
|
-
import urllib.request
|
|
5
2
|
from pathlib import Path
|
|
3
|
+
from typing import Literal
|
|
6
4
|
|
|
5
|
+
import pooch
|
|
7
6
|
from brainglobe_utils.general.config import get_config_obj
|
|
8
|
-
from brainglobe_utils.general.system import disk_free_gb
|
|
9
7
|
|
|
8
|
+
from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
|
|
10
9
|
from cellfinder.core.tools.source_files import (
|
|
11
10
|
default_configuration_path,
|
|
12
11
|
user_specific_configuration_path,
|
|
13
12
|
)
|
|
14
13
|
|
|
14
|
+
DEFAULT_DOWNLOAD_DIRECTORY = DEFAULT_CELLFINDER_DIRECTORY / "models"
|
|
15
15
|
|
|
16
|
-
class DownloadError(Exception):
|
|
17
|
-
pass
|
|
18
16
|
|
|
17
|
+
MODEL_URL = "https://gin.g-node.org/cellfinder/models/raw/master"
|
|
19
18
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
with urllib.request.urlopen(file_url) as response:
|
|
25
|
-
with open(destination_path, "wb") as outfile:
|
|
26
|
-
shutil.copyfileobj(response, outfile)
|
|
19
|
+
model_filenames = {
|
|
20
|
+
"resnet50_tv": "resnet50_tv.h5",
|
|
21
|
+
"resnet50_all": "resnet50_weights.h5",
|
|
22
|
+
}
|
|
27
23
|
|
|
24
|
+
model_hashes = {
|
|
25
|
+
"resnet50_tv": "63d36af456640590ba6c896dc519f9f29861015084f4c40777a54c18c1fc4edd", # noqa: E501
|
|
26
|
+
"resnet50_all": None,
|
|
27
|
+
}
|
|
28
28
|
|
|
29
|
-
def extract_file(tar_file_path, destination_path):
|
|
30
|
-
tar = tarfile.open(tar_file_path)
|
|
31
|
-
tar.extractall(path=destination_path)
|
|
32
|
-
tar.close()
|
|
33
29
|
|
|
30
|
+
model_type = Literal["resnet50_tv", "resnet50_all"]
|
|
34
31
|
|
|
35
|
-
# TODO: check that intermediate folders exist
|
|
36
|
-
def download(
|
|
37
|
-
download_path,
|
|
38
|
-
url,
|
|
39
|
-
file_name,
|
|
40
|
-
install_path=None,
|
|
41
|
-
download_requires=None,
|
|
42
|
-
extract_requires=None,
|
|
43
|
-
):
|
|
44
|
-
if not os.path.exists(os.path.dirname(download_path)):
|
|
45
|
-
raise DownloadError(
|
|
46
|
-
f"Could not find directory '{os.path.dirname(download_path)}' "
|
|
47
|
-
f"to download file: {file_name}"
|
|
48
|
-
)
|
|
49
32
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
33
|
+
def download_models(
|
|
34
|
+
model_name: model_type, download_path: os.PathLike
|
|
35
|
+
) -> Path:
|
|
36
|
+
"""
|
|
37
|
+
For a given model name and download path, download the model file
|
|
38
|
+
and return the path to the downloaded file.
|
|
39
|
+
|
|
40
|
+
Parameters
|
|
41
|
+
----------
|
|
42
|
+
model_name : model_type
|
|
43
|
+
The name of the model to be downloaded.
|
|
44
|
+
download_path : os.PathLike
|
|
45
|
+
The path where the model file will be downloaded.
|
|
57
46
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
47
|
+
Returns
|
|
48
|
+
-------
|
|
49
|
+
Path
|
|
50
|
+
The path to the downloaded model file.
|
|
51
|
+
|
|
52
|
+
"""
|
|
64
53
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
54
|
+
download_path = Path(download_path)
|
|
55
|
+
filename = model_filenames[model_name]
|
|
56
|
+
model_path = pooch.retrieve(
|
|
57
|
+
url=f"{MODEL_URL}/{filename}",
|
|
58
|
+
known_hash=model_hashes[model_name],
|
|
59
|
+
path=download_path,
|
|
60
|
+
fname=filename,
|
|
61
|
+
progressbar=True,
|
|
62
|
+
)
|
|
72
63
|
|
|
73
|
-
|
|
74
|
-
if install_path is not None:
|
|
75
|
-
extract_file(download_path, install_path)
|
|
76
|
-
os.remove(download_path)
|
|
64
|
+
return Path(model_path)
|
|
77
65
|
|
|
78
66
|
|
|
79
67
|
def amend_user_configuration(new_model_path=None) -> None:
|
|
@@ -83,7 +71,7 @@ def amend_user_configuration(new_model_path=None) -> None:
|
|
|
83
71
|
|
|
84
72
|
Parameters
|
|
85
73
|
----------
|
|
86
|
-
new_model_path :
|
|
74
|
+
new_model_path : Path, optional
|
|
87
75
|
The path to the new model configuration.
|
|
88
76
|
"""
|
|
89
77
|
print("(Over-)writing custom user configuration")
|