cellfinder 1.1.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,22 @@
1
1
  from dataclasses import dataclass
2
- from typing import Dict, Optional, TypeVar
2
+ from typing import Dict, Optional, Tuple, TypeVar, Union
3
3
 
4
4
  import numba.typed
5
5
  import numpy as np
6
6
  import numpy.typing as npt
7
- from numba import njit
7
+ from numba import njit, typed
8
8
  from numba.core import types
9
9
  from numba.experimental import jitclass
10
10
  from numba.types import DictType
11
11
 
12
+ T = TypeVar("T")
13
+ # type used for the domain of the volume - the size of the vol
14
+ vol_np_type = np.int64
15
+ vol_numba_type = types.int64
16
+ # type used for the structure id
17
+ sid_np_type = np.int64
18
+ sid_numba_type = types.int64
19
+
12
20
 
13
21
  @dataclass
14
22
  class Point:
@@ -32,18 +40,15 @@ def get_non_zero_dtype_min(values: np.ndarray) -> int:
32
40
  return min_val
33
41
 
34
42
 
35
- T = TypeVar("T")
36
-
37
-
38
43
  @njit
39
44
  def traverse_dict(d: Dict[T, T], a: T) -> T:
40
45
  """
41
46
  Traverse d, until a is not present as a key.
42
47
  """
43
- if a in d:
44
- return traverse_dict(d, d[a])
45
- else:
46
- return a
48
+ value = a
49
+ while value in d:
50
+ value = d[value]
51
+ return value
47
52
 
48
53
 
49
54
  @njit
@@ -54,14 +59,28 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray:
54
59
  Centre calculated as the mean of each pixel coordinate,
55
60
  rounded to the nearest integer.
56
61
  """
57
- # can't do np.mean(structure, axis=0)
58
- # because axis is not supported by numba
62
+ # numba support axis for sum, but not mean
63
+ return np.round(np.sum(structure, axis=0) / structure.shape[0])
64
+
65
+
66
+ @njit
67
+ def _get_structure_centre(structure: types.ListType) -> np.ndarray:
68
+ # See get_structure_centre.
69
+ # this is for our own points stored as list optimized by numba
70
+ a_sum = 0.0
71
+ b_sum = 0.0
72
+ c_sum = 0.0
73
+ for a, b, c in structure:
74
+ a_sum += a
75
+ b_sum += b
76
+ c_sum += c
77
+
59
78
  return np.round(
60
79
  np.array(
61
80
  [
62
- np.mean(structure[:, 0]),
63
- np.mean(structure[:, 1]),
64
- np.mean(structure[:, 2]),
81
+ a_sum / len(structure),
82
+ b_sum / len(structure),
83
+ c_sum / len(structure),
65
84
  ]
66
85
  )
67
86
  )
@@ -69,15 +88,18 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray:
69
88
 
70
89
  # Type declaration has to come outside of the class,
71
90
  # see https://github.com/numba/numba/issues/8808
72
- uint_2d_type = types.uint64[:, :]
91
+ tuple_point_type = types.Tuple(
92
+ (vol_numba_type, vol_numba_type, vol_numba_type)
93
+ )
94
+ list_of_points_type = types.ListType(tuple_point_type)
73
95
 
74
96
 
75
97
  spec = [
76
- ("z", types.uint64),
77
- ("next_structure_id", types.uint64),
78
- ("shape", types.UniTuple(types.int64, 2)),
79
- ("obsolete_ids", DictType(types.int64, types.int64)),
80
- ("coords_maps", DictType(types.uint64, uint_2d_type)),
98
+ ("z", vol_numba_type),
99
+ ("next_structure_id", sid_numba_type),
100
+ ("shape", types.UniTuple(vol_numba_type, 2)),
101
+ ("obsolete_ids", DictType(sid_numba_type, sid_numba_type)),
102
+ ("coords_maps", DictType(sid_numba_type, list_of_points_type)),
81
103
  ]
82
104
 
83
105
 
@@ -103,8 +125,12 @@ class CellDetector:
103
125
  are scanned.
104
126
  coords_maps :
105
127
  Mapping from structure ID to the coordinates of pixels within that
106
- structure. Coordinates are stored in a 2D array, with the second
107
- axis indexing (x, y, z) coordinates.
128
+ structure. Coordinates are stored in a list of (x, y, z) tuples of
129
+ the coordinates.
130
+
131
+ Use `get_structures` to get it as a dict whose values are each
132
+ a 2D array, where rows are points, and columns x, y, z of the
133
+ points.
108
134
  """
109
135
 
110
136
  def __init__(self, width: int, height: int, start_z: int):
@@ -123,11 +149,11 @@ class CellDetector:
123
149
  # Mapping from obsolete IDs to the IDs that they have been
124
150
  # made obsolete by
125
151
  self.obsolete_ids = numba.typed.Dict.empty(
126
- key_type=types.int64, value_type=types.int64
152
+ key_type=sid_numba_type, value_type=sid_numba_type
127
153
  )
128
154
  # Mapping from IDs to list of points in that structure
129
155
  self.coords_maps = numba.typed.Dict.empty(
130
- key_type=types.int64, value_type=uint_2d_type
156
+ key_type=sid_numba_type, value_type=list_of_points_type
131
157
  )
132
158
 
133
159
  def process(
@@ -136,7 +162,7 @@ class CellDetector:
136
162
  """
137
163
  Process a new plane.
138
164
  """
139
- if [e for e in plane.shape[:2]] != [e for e in self.shape]:
165
+ if plane.shape[:2] != self.shape:
140
166
  raise ValueError("plane does not have correct shape")
141
167
 
142
168
  plane = self.connect_four(plane, previous_plane)
@@ -166,7 +192,7 @@ class CellDetector:
166
192
  for x in range(plane.shape[0]):
167
193
  if plane[x, y] == SOMA_CENTRE_VALUE:
168
194
  # Labels of structures below, left and behind
169
- neighbour_ids = np.zeros(3, dtype=np.uint64)
195
+ neighbour_ids = np.zeros(3, dtype=sid_np_type)
170
196
  # If in bounds look at neighbours
171
197
  if x > 0:
172
198
  neighbour_ids[0] = plane[x - 1, y]
@@ -191,17 +217,54 @@ class CellDetector:
191
217
  def get_cell_centres(self) -> np.ndarray:
192
218
  return self.structures_to_cells()
193
219
 
194
- def get_coords_dict(self) -> Dict:
195
- return self.coords_maps
220
+ def get_structures(self) -> Dict[int, np.ndarray]:
221
+ """
222
+ Gets the structures as a dict of structure IDs mapped to the 2D array
223
+ of structure points.
224
+ """
225
+ d = {}
226
+ for sid, points in self.coords_maps.items():
227
+ # numba silliness - it cannot handle
228
+ # `item = np.array(points, dtype=vol_np_type)` so we need to create
229
+ # array and then fill in the point
230
+ item = np.empty((len(points), 3), dtype=vol_np_type)
231
+ d[sid] = item
232
+
233
+ for i, point in enumerate(points):
234
+ item[i, :] = point
235
+
236
+ return d
237
+
238
+ def add_point(
239
+ self, sid: int, point: Union[tuple, list, np.ndarray]
240
+ ) -> None:
241
+ """
242
+ Add single 3d *point* to the structure with the given *sid*.
243
+ """
244
+ if sid not in self.coords_maps:
245
+ self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
246
+
247
+ self._add_point(sid, (int(point[0]), int(point[1]), int(point[2])))
196
248
 
197
- def add_point(self, sid: int, point: np.ndarray) -> None:
249
+ def add_points(self, sid: int, points: np.ndarray):
198
250
  """
199
- Add *point* to the structure with the given *sid*.
251
+ Adds ndarray of *points* to the structure with the given *sid*.
252
+ Each row is a 3d point.
200
253
  """
201
- self.coords_maps[sid] = np.row_stack((self.coords_maps[sid], point))
254
+ if sid not in self.coords_maps:
255
+ self.coords_maps[sid] = typed.List.empty_list(tuple_point_type)
256
+
257
+ append = self.coords_maps[sid].append
258
+ pts = np.round(points).astype(vol_np_type)
259
+ for point in pts:
260
+ append((point[0], point[1], point[2]))
261
+
262
+ def _add_point(self, sid: int, point: Tuple[int, int, int]) -> None:
263
+ # sid must exist
264
+ self.coords_maps[sid].append(point)
202
265
 
203
266
  def add(
204
- self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[np.uint64]
267
+ self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type]
205
268
  ) -> int:
206
269
  """
207
270
  For the current coordinates takes all the neighbours and find the
@@ -215,17 +278,16 @@ class CellDetector:
215
278
  """
216
279
  updated_id = self.sanitise_ids(neighbour_ids)
217
280
  if updated_id not in self.coords_maps:
218
- self.coords_maps[updated_id] = np.zeros(
219
- shape=(0, 3), dtype=np.uint64
281
+ self.coords_maps[updated_id] = typed.List.empty_list(
282
+ tuple_point_type
220
283
  )
221
284
  self.merge_structures(updated_id, neighbour_ids)
222
285
 
223
286
  # Add point for that structure
224
- point = np.array([[x, y, z]], dtype=np.uint64)
225
- self.add_point(updated_id, point)
287
+ self._add_point(updated_id, (int(x), int(y), int(z)))
226
288
  return updated_id
227
289
 
228
- def sanitise_ids(self, neighbour_ids: npt.NDArray[np.uint64]) -> int:
290
+ def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int:
229
291
  """
230
292
  Get the smallest ID of all the structures that are connected to IDs
231
293
  in `neighbour_ids`.
@@ -246,7 +308,7 @@ class CellDetector:
246
308
  return int(updated_id)
247
309
 
248
310
  def merge_structures(
249
- self, updated_id: int, neighbour_ids: npt.NDArray[np.uint64]
311
+ self, updated_id: int, neighbour_ids: npt.NDArray[sid_np_type]
250
312
  ) -> None:
251
313
  """
252
314
  For all the neighbours, reassign all the points of neighbour to
@@ -261,14 +323,16 @@ class CellDetector:
261
323
  # minimise ID so if neighbour with higher ID, reassign its points
262
324
  # to current
263
325
  if neighbour_id > updated_id:
264
- self.add_point(updated_id, self.coords_maps[neighbour_id])
326
+ self.coords_maps[updated_id].extend(
327
+ self.coords_maps[neighbour_id]
328
+ )
265
329
  self.coords_maps.pop(neighbour_id)
266
330
  self.obsolete_ids[neighbour_id] = updated_id
267
331
 
268
332
  def structures_to_cells(self) -> np.ndarray:
269
- cell_centres = np.empty((len(self.coords_maps.keys()), 3))
333
+ cell_centres = np.empty((len(self.coords_maps), 3))
270
334
  for idx, structure in enumerate(self.coords_maps.values()):
271
- p = get_structure_centre(structure)
335
+ p = _get_structure_centre(structure)
272
336
  cell_centres[idx] = p
273
337
  return cell_centres
274
338
 
@@ -71,7 +71,7 @@ def ball_filter_imgs(
71
71
  """
72
72
  # OPTIMISE: reuse ball filter instance
73
73
 
74
- good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=bool)
74
+ good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=np.bool_)
75
75
 
76
76
  plane_width, plane_height = volume.shape[:2]
77
77
 
@@ -1,5 +1,7 @@
1
1
  import math
2
+ import multiprocessing.pool
2
3
  import os
4
+ from functools import partial
3
5
  from queue import Queue
4
6
  from threading import Lock
5
7
  from typing import Any, Callable, List, Optional, Tuple
@@ -77,7 +79,7 @@ class VolumeFilter(object):
77
79
  locks: List[Lock],
78
80
  *,
79
81
  callback: Callable[[int], None],
80
- ) -> List[Cell]:
82
+ ) -> None:
81
83
  progress_bar = tqdm(total=self.n_planes, desc="Processing planes")
82
84
  for z in range(self.n_planes):
83
85
  # Get result from the queue.
@@ -108,11 +110,13 @@ class VolumeFilter(object):
108
110
 
109
111
  progress_bar.close()
110
112
  logger.debug("3D filter done")
111
- return self.get_results()
112
113
 
113
114
  def _run_filter(self) -> None:
114
115
  logger.debug(f"🏐 Ball filtering plane {self.z}")
115
- self.ball_filter.walk()
116
+ # filtering original images, the images should be large enough in x/y
117
+ # to benefit from parallelization. Note: don't pass arg as keyword arg
118
+ # because numba gets stuck (probably b/c class jit is new)
119
+ self.ball_filter.walk(True)
116
120
 
117
121
  middle_plane = self.ball_filter.get_middle_plane()
118
122
  if self.save_planes:
@@ -134,7 +138,7 @@ class VolumeFilter(object):
134
138
  f_path = os.path.join(self.plane_directory, plane_name)
135
139
  tifffile.imsave(f_path, plane.T)
136
140
 
137
- def get_results(self) -> List[Cell]:
141
+ def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]:
138
142
  logger.info("Splitting cell clusters and writing results")
139
143
 
140
144
  max_cell_volume = sphere_volume(
@@ -142,62 +146,57 @@ class VolumeFilter(object):
142
146
  )
143
147
 
144
148
  cells = []
149
+ needs_split = []
150
+ structures = self.cell_detector.get_structures().items()
151
+ logger.debug(f"Processing {len(structures)} found cells")
145
152
 
146
- logger.debug(
147
- f"Processing {len(self.cell_detector.coords_maps.items())} cells"
148
- )
149
- for cell_id, cell_points in self.cell_detector.coords_maps.items():
153
+ # first get all the cells that are not clusters
154
+ for cell_id, cell_points in structures:
150
155
  cell_volume = len(cell_points)
151
156
 
152
157
  if cell_volume < max_cell_volume:
153
158
  cell_centre = get_structure_centre(cell_points)
154
- cells.append(
155
- Cell(
156
- (
157
- cell_centre[0],
158
- cell_centre[1],
159
- cell_centre[2],
160
- ),
161
- Cell.UNKNOWN,
162
- )
163
- )
159
+ cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
164
160
  else:
165
161
  if cell_volume < self.max_cluster_size:
166
- try:
167
- cell_centres = split_cells(
168
- cell_points, outlier_keep=self.outlier_keep
169
- )
170
- except (ValueError, AssertionError) as err:
171
- raise StructureSplitException(
172
- f"Cell {cell_id}, error; {err}"
173
- )
174
- for cell_centre in cell_centres:
175
- cells.append(
176
- Cell(
177
- (
178
- cell_centre[0],
179
- cell_centre[1],
180
- cell_centre[2],
181
- ),
182
- Cell.UNKNOWN,
183
- )
184
- )
162
+ needs_split.append((cell_id, cell_points))
185
163
  else:
186
164
  cell_centre = get_structure_centre(cell_points)
187
- cells.append(
188
- Cell(
189
- (
190
- cell_centre[0],
191
- cell_centre[1],
192
- cell_centre[2],
193
- ),
194
- Cell.ARTIFACT,
195
- )
196
- )
197
-
198
- logger.debug("Finished splitting cell clusters.")
165
+ cells.append(Cell(cell_centre.tolist(), Cell.ARTIFACT))
166
+
167
+ if not needs_split:
168
+ logger.debug("Finished splitting cell clusters - none found")
169
+ return cells
170
+
171
+ # now split clusters into cells
172
+ logger.debug(f"Splitting {len(needs_split)} clusters")
173
+ progress_bar = tqdm(
174
+ total=len(needs_split), desc="Splitting cell clusters"
175
+ )
176
+
177
+ # we are not returning Cell instances from func because it'd be pickled
178
+ # by multiprocess which slows it down
179
+ func = partial(_split_cells, outlier_keep=self.outlier_keep)
180
+ for cell_centres in worker_pool.imap_unordered(func, needs_split):
181
+ for cell_centre in cell_centres:
182
+ cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN))
183
+ progress_bar.update()
184
+
185
+ progress_bar.close()
186
+ logger.debug(
187
+ f"Finished splitting cell clusters. Found {len(cells)} total cells"
188
+ )
189
+
199
190
  return cells
200
191
 
201
192
 
193
+ def _split_cells(arg, outlier_keep):
194
+ cell_id, cell_points = arg
195
+ try:
196
+ return split_cells(cell_points, outlier_keep=outlier_keep)
197
+ except (ValueError, AssertionError) as err:
198
+ raise StructureSplitException(f"Cell {cell_id}, error; {err}")
199
+
200
+
202
201
  def sphere_volume(radius: float) -> float:
203
202
  return (4 / 3) * math.pi * radius**3
@@ -1,17 +1,29 @@
1
- import tempfile
2
1
  from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
3
2
  from pathlib import Path
4
3
 
5
- from cellfinder.core.download import models
6
- from cellfinder.core.download.download import amend_user_configuration
4
+ from cellfinder.core.download.download import (
5
+ DEFAULT_DOWNLOAD_DIRECTORY,
6
+ amend_user_configuration,
7
+ download_models,
8
+ )
7
9
 
8
- home = Path.home()
9
- DEFAULT_DOWNLOAD_DIRECTORY = home / ".cellfinder"
10
- temp_dir = tempfile.TemporaryDirectory()
11
- temp_dir_path = Path(temp_dir.name)
12
10
 
11
+ def download_parser(parser: ArgumentParser) -> ArgumentParser:
12
+ """
13
+ Configure the argument parser for downloading files.
14
+
15
+ Parameters
16
+ ----------
17
+ parser : ArgumentParser
18
+ The argument parser to configure.
19
+
20
+ Returns
21
+ -------
22
+ ArgumentParser
23
+ The configured argument parser.
24
+
25
+ """
13
26
 
14
- def download_directory_parser(parser):
15
27
  parser.add_argument(
16
28
  "--install-path",
17
29
  dest="install_path",
@@ -19,29 +31,12 @@ def download_directory_parser(parser):
19
31
  default=DEFAULT_DOWNLOAD_DIRECTORY,
20
32
  help="The path to install files to.",
21
33
  )
22
- parser.add_argument(
23
- "--download-path",
24
- dest="download_path",
25
- type=Path,
26
- default=temp_dir_path,
27
- help="The path to download files into.",
28
- )
29
34
  parser.add_argument(
30
35
  "--no-amend-config",
31
36
  dest="no_amend_config",
32
37
  action="store_true",
33
38
  help="Don't amend the config file",
34
39
  )
35
- return parser
36
-
37
-
38
- def model_parser(parser):
39
- parser.add_argument(
40
- "--no-models",
41
- dest="no_models",
42
- action="store_true",
43
- help="Don't download the model",
44
- )
45
40
  parser.add_argument(
46
41
  "--model",
47
42
  dest="model",
@@ -52,17 +47,29 @@ def model_parser(parser):
52
47
  return parser
53
48
 
54
49
 
55
- def download_parser():
50
+ def get_parser() -> ArgumentParser:
51
+ """
52
+ Create an argument parser for downloading files.
53
+
54
+ Returns
55
+ -------
56
+ ArgumentParser
57
+ The configured argument parser.
58
+
59
+ """
56
60
  parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
57
- parser = model_parser(parser)
58
- parser = download_directory_parser(parser)
61
+ parser = download_parser(parser)
59
62
  return parser
60
63
 
61
64
 
62
- def main():
63
- args = download_parser().parse_args()
64
- if not args.no_models:
65
- model_path = models.main(args.model, args.install_path)
65
+ def main() -> None:
66
+ """
67
+ Run the main download function, and optionally amend the user
68
+ configuration.
69
+
70
+ """
71
+ args = get_parser().parse_args()
72
+ model_path = download_models(args.model, args.install_path)
66
73
 
67
74
  if not args.no_amend_config:
68
75
  amend_user_configuration(new_model_path=model_path)
@@ -1,78 +1,67 @@
1
1
  import os
2
- import shutil
3
- import tarfile
4
- import urllib.request
2
+ from pathlib import Path
3
+ from typing import Literal
5
4
 
5
+ import pooch
6
6
  from brainglobe_utils.general.config import get_config_obj
7
- from brainglobe_utils.general.system import disk_free_gb
8
7
 
8
+ from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
9
9
  from cellfinder.core.tools.source_files import (
10
10
  default_configuration_path,
11
11
  user_specific_configuration_path,
12
12
  )
13
13
 
14
+ DEFAULT_DOWNLOAD_DIRECTORY = DEFAULT_CELLFINDER_DIRECTORY / "models"
14
15
 
15
- class DownloadError(Exception):
16
- pass
17
16
 
17
+ MODEL_URL = "https://gin.g-node.org/cellfinder/models/raw/master"
18
18
 
19
- def download_file(destination_path, file_url, filename):
20
- direct_download = True
21
- file_url = file_url.format(int(direct_download))
22
- print(f"Downloading file: {filename}")
23
- with urllib.request.urlopen(file_url) as response:
24
- with open(destination_path, "wb") as outfile:
25
- shutil.copyfileobj(response, outfile)
19
+ model_filenames = {
20
+ "resnet50_tv": "resnet50_tv.h5",
21
+ "resnet50_all": "resnet50_weights.h5",
22
+ }
26
23
 
24
+ model_hashes = {
25
+ "resnet50_tv": "63d36af456640590ba6c896dc519f9f29861015084f4c40777a54c18c1fc4edd", # noqa: E501
26
+ "resnet50_all": None,
27
+ }
27
28
 
28
- def extract_file(tar_file_path, destination_path):
29
- tar = tarfile.open(tar_file_path)
30
- tar.extractall(path=destination_path)
31
- tar.close()
32
29
 
30
+ model_type = Literal["resnet50_tv", "resnet50_all"]
33
31
 
34
- # TODO: check that intermediate folders exist
35
- def download(
36
- download_path,
37
- url,
38
- file_name,
39
- install_path=None,
40
- download_requires=None,
41
- extract_requires=None,
42
- ):
43
- if not os.path.exists(os.path.dirname(download_path)):
44
- raise DownloadError(
45
- f"Could not find directory '{os.path.dirname(download_path)}' "
46
- f"to download file: {file_name}"
47
- )
48
32
 
49
- if (download_requires is not None) and (
50
- disk_free_gb(os.path.dirname(download_path)) < download_requires
51
- ):
52
- raise DownloadError(
53
- f"Insufficient disk space in {os.path.dirname(download_path)} to"
54
- f"download file: {file_name}"
55
- )
33
+ def download_models(
34
+ model_name: model_type, download_path: os.PathLike
35
+ ) -> Path:
36
+ """
37
+ For a given model name and download path, download the model file
38
+ and return the path to the downloaded file.
56
39
 
57
- if install_path is not None:
58
- if not os.path.exists(install_path):
59
- raise DownloadError(
60
- f"Could not find directory '{install_path}' "
61
- f"to extract file: {file_name}"
62
- )
40
+ Parameters
41
+ ----------
42
+ model_name : model_type
43
+ The name of the model to be downloaded.
44
+ download_path : os.PathLike
45
+ The path where the model file will be downloaded.
63
46
 
64
- if (extract_requires is not None) and (
65
- disk_free_gb(install_path) < extract_requires
66
- ):
67
- raise DownloadError(
68
- f"Insufficient disk space in {install_path} to"
69
- f"extract file: {file_name}"
70
- )
47
+ Returns
48
+ -------
49
+ Path
50
+ The path to the downloaded model file.
51
+
52
+ """
71
53
 
72
- download_file(download_path, url, file_name)
73
- if install_path is not None:
74
- extract_file(download_path, install_path)
75
- os.remove(download_path)
54
+ download_path = Path(download_path)
55
+ filename = model_filenames[model_name]
56
+ model_path = pooch.retrieve(
57
+ url=f"{MODEL_URL}/{filename}",
58
+ known_hash=model_hashes[model_name],
59
+ path=download_path,
60
+ fname=filename,
61
+ progressbar=True,
62
+ )
63
+
64
+ return Path(model_path)
76
65
 
77
66
 
78
67
  def amend_user_configuration(new_model_path=None) -> None:
@@ -82,7 +71,7 @@ def amend_user_configuration(new_model_path=None) -> None:
82
71
 
83
72
  Parameters
84
73
  ----------
85
- new_model_path : str, optional
74
+ new_model_path : Path, optional
86
75
  The path to the new model configuration.
87
76
  """
88
77
  print("(Over-)writing custom user configuration")
@@ -124,5 +113,8 @@ def write_model_to_config(new_model_path, orig_config, custom_config):
124
113
  data[i] = line.replace(
125
114
  f"model_path = '{orig_path}", f"model_path = '{new_model_path}"
126
115
  )
116
+
117
+ custom_config_path = Path(custom_config)
118
+ custom_config_path.parent.mkdir(parents=True, exist_ok=True)
127
119
  with open(custom_config, "w") as out_conf:
128
120
  out_conf.writelines(data)