modelbase2 0.1.78__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. modelbase2/__init__.py +138 -26
  2. modelbase2/distributions.py +306 -0
  3. modelbase2/experimental/__init__.py +17 -0
  4. modelbase2/experimental/codegen.py +239 -0
  5. modelbase2/experimental/diff.py +227 -0
  6. modelbase2/experimental/notes.md +4 -0
  7. modelbase2/experimental/tex.py +521 -0
  8. modelbase2/fit.py +284 -0
  9. modelbase2/fns.py +185 -0
  10. modelbase2/integrators/__init__.py +19 -0
  11. modelbase2/integrators/int_assimulo.py +146 -0
  12. modelbase2/integrators/int_scipy.py +147 -0
  13. modelbase2/label_map.py +610 -0
  14. modelbase2/linear_label_map.py +301 -0
  15. modelbase2/mc.py +548 -0
  16. modelbase2/mca.py +280 -0
  17. modelbase2/model.py +1621 -0
  18. modelbase2/npe.py +343 -0
  19. modelbase2/parallel.py +171 -0
  20. modelbase2/parameterise.py +28 -0
  21. modelbase2/paths.py +36 -0
  22. modelbase2/plot.py +829 -0
  23. modelbase2/sbml/__init__.py +14 -0
  24. modelbase2/sbml/_data.py +77 -0
  25. modelbase2/sbml/_export.py +656 -0
  26. modelbase2/sbml/_import.py +585 -0
  27. modelbase2/sbml/_mathml.py +691 -0
  28. modelbase2/sbml/_name_conversion.py +52 -0
  29. modelbase2/sbml/_unit_conversion.py +74 -0
  30. modelbase2/scan.py +616 -0
  31. modelbase2/scope.py +96 -0
  32. modelbase2/simulator.py +635 -0
  33. modelbase2/surrogates/__init__.py +32 -0
  34. modelbase2/surrogates/_poly.py +66 -0
  35. modelbase2/surrogates/_torch.py +249 -0
  36. modelbase2/surrogates.py +316 -0
  37. modelbase2/types.py +352 -11
  38. modelbase2-0.2.0.dist-info/METADATA +81 -0
  39. modelbase2-0.2.0.dist-info/RECORD +42 -0
  40. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info}/WHEEL +1 -1
  41. modelbase2/core/__init__.py +0 -29
  42. modelbase2/core/algebraic_module_container.py +0 -130
  43. modelbase2/core/constant_container.py +0 -113
  44. modelbase2/core/data.py +0 -109
  45. modelbase2/core/name_container.py +0 -29
  46. modelbase2/core/reaction_container.py +0 -115
  47. modelbase2/core/utils.py +0 -28
  48. modelbase2/core/variable_container.py +0 -24
  49. modelbase2/ode/__init__.py +0 -13
  50. modelbase2/ode/integrator.py +0 -80
  51. modelbase2/ode/mca.py +0 -270
  52. modelbase2/ode/model.py +0 -470
  53. modelbase2/ode/simulator.py +0 -153
  54. modelbase2/utils/__init__.py +0 -0
  55. modelbase2/utils/plotting.py +0 -372
  56. modelbase2-0.1.78.dist-info/METADATA +0 -44
  57. modelbase2-0.1.78.dist-info/RECORD +0 -22
  58. {modelbase2-0.1.78.dist-info → modelbase2-0.2.0.dist-info/licenses}/LICENSE +0 -0
modelbase2/npe.py ADDED
@@ -0,0 +1,343 @@
+ """Neural Network Parameter Estimation (NPE) Module.
+
+ This module provides classes and functions for training neural network models to estimate
+ parameters in metabolic models. It includes functionality for both steady-state and
+ time-series data.
+
+ Classes:
+     DefaultSSAproximator: Default neural network model for steady-state approximation
+     DefaultTimeSeriesApproximator: Default neural network model for time-series approximation
+     TorchSSEstimator: Estimator for steady-state data using PyTorch models
+     TorchTimeCourseEstimator: Estimator for time course data using PyTorch models
+
+ Functions:
+     train_torch_ss_estimator: Train a PyTorch steady-state estimator
+     train_torch_time_course_estimator: Train a PyTorch time course estimator
+ """
+
+ from __future__ import annotations
+
+ __all__ = [
+     "AbstractEstimator",
+     "DefaultCache",
+     "DefaultDevice",
+     "DefaultSSAproximator",
+     "DefaultTimeSeriesApproximator",
+     "TorchSSEstimator",
+     "TorchTimeCourseEstimator",
+     "train_torch_ss_estimator",
+     "train_torch_time_course_estimator",
+ ]
+
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import cast
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import tqdm
+ from torch import nn
+ from torch.optim.adam import Adam
+
+ from modelbase2.parallel import Cache
+
+ DefaultDevice = torch.device("cpu")
+ DefaultCache = Cache(Path(".cache"))
+
+
+ class DefaultSSAproximator(nn.Module):
+     """Default neural network model for steady-state approximation."""
+
+     def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int = 50) -> None:
+         """Initializes the neural network with the specified number of inputs and outputs.
+
+         Args:
+             n_inputs (int): The number of input features.
+             n_outputs (int): The number of output features.
+             n_hidden (int): The number of hidden units in the fully connected layers.
+
+         The network consists of three fully connected layers with ReLU activations in between.
+         The weights of the linear layers are initialized with a normal distribution (mean=0, std=0.1),
+         and the biases are initialized to zero.
+
+         """
+         super().__init__()
+
+         self.net = nn.Sequential(
+             nn.Linear(n_inputs, n_hidden),
+             nn.ReLU(),
+             nn.Linear(n_hidden, n_hidden),
+             nn.ReLU(),
+             nn.Linear(n_hidden, n_outputs),
+         )
+
+         for m in self.net.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, mean=0, std=0.1)
+                 nn.init.constant_(m.bias, val=0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the neural network."""
+         return cast(torch.Tensor, self.net(x))
+
+
+ class DefaultTimeSeriesApproximator(nn.Module):
+     """Default neural network model for time-series approximation."""
+
+     def __init__(self, n_inputs: int, n_outputs: int, n_hidden: int) -> None:
+         """Initializes the neural network model.
+
+         Args:
+             n_inputs (int): Number of input features.
+             n_outputs (int): Number of output features.
+             n_hidden (int): Number of hidden units in the LSTM layer.
+
+         """
+         super().__init__()
+
+         self.n_hidden = n_hidden
+
+         self.lstm = nn.LSTM(n_inputs, n_hidden)
+         self.to_out = nn.Linear(n_hidden, n_outputs)
+
+         nn.init.normal_(self.to_out.weight, mean=0, std=0.1)
+         nn.init.constant_(self.to_out.bias, val=0)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the neural network."""
+         # lstm_out, (hidden_state, cell_state)
+         _, (hn, _) = self.lstm(x)
+         return cast(torch.Tensor, self.to_out(hn[-1]))  # Use last hidden state
+
+
+ @dataclass(kw_only=True)
+ class AbstractEstimator:
+     """Abstract class for parameter estimation using neural networks."""
+
+     parameter_names: list[str]
+
+     @abstractmethod
+     def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+         """Predict the target values for the given features."""
+
+
+ @dataclass(kw_only=True)
+ class TorchSSEstimator(AbstractEstimator):
+     """Estimator for steady state data using PyTorch models."""
+
+     model: torch.nn.Module
+
+     def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+         """Predict the target values for the given features."""
+         with torch.no_grad():
+             pred = self.model(torch.tensor(features.to_numpy(), dtype=torch.float32))
+         return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+ @dataclass(kw_only=True)
+ class TorchTimeCourseEstimator(AbstractEstimator):
+     """Estimator for time course data using PyTorch models."""
+
+     model: torch.nn.Module
+
+     def predict(self, features: pd.Series | pd.DataFrame) -> pd.DataFrame:
+         """Predict the target values for the given features."""
+         idx = cast(pd.MultiIndex, features.index)
+         # Reshape the flat (sample, time) MultiIndexed frame into
+         # (time, sample, feature) order, as expected by nn.LSTM
+         features_ = torch.Tensor(
+             np.swapaxes(
+                 features.to_numpy().reshape(
+                     (
+                         len(idx.levels[0]),
+                         len(idx.levels[1]),
+                         len(features.columns),
+                     )
+                 ),
+                 axis1=0,
+                 axis2=1,
+             ),
+         )
+         with torch.no_grad():
+             pred = self.model(features_)
+         return pd.DataFrame(pred, columns=self.parameter_names)
+
+
+ def _train_batched(
+     approximator: nn.Module,
+     features: torch.Tensor,
+     targets: torch.Tensor,
+     epochs: int,
+     optimizer: Adam,
+     batch_size: int,
+ ) -> pd.Series:
+     losses = {}
+
+     for epoch in tqdm.trange(epochs):
+         permutation = torch.randperm(features.size()[0])
+         epoch_loss = 0
+         for i in range(0, features.size()[0], batch_size):
+             optimizer.zero_grad()
+             indices = permutation[i : i + batch_size]
+
+             # Mean absolute error between predictions and targets
+             loss = torch.mean(
+                 torch.abs(approximator(features[indices]) - targets[indices])
+             )
+             loss.backward()
+             optimizer.step()
+             epoch_loss += loss.detach().numpy()
+
+         losses[epoch] = epoch_loss / (features.size()[0] / batch_size)
+     return pd.Series(losses, dtype=float)
+
+
+ def _train_full(
+     approximator: nn.Module,
+     features: torch.Tensor,
+     targets: torch.Tensor,
+     epochs: int,
+     optimizer: Adam,
+ ) -> pd.Series:
+     losses = {}
+     for i in tqdm.trange(epochs):
+         optimizer.zero_grad()
+         loss = torch.mean(torch.abs(approximator(features) - targets))
+         loss.backward()
+         optimizer.step()
+         losses[i] = loss.detach().numpy()
+     return pd.Series(losses, dtype=float)
+
+
+ def train_torch_ss_estimator(
+     features: pd.DataFrame,
+     targets: pd.DataFrame,
+     epochs: int,
+     batch_size: int | None = None,
+     approximator: nn.Module | None = None,
+     optimimzer_cls: type[Adam] = Adam,
+     device: torch.device = DefaultDevice,
+ ) -> tuple[TorchSSEstimator, pd.Series]:
+     """Train a PyTorch steady state estimator.
+
+     This function trains a neural network model to estimate steady state data
+     using the provided features and targets. It supports both full-batch and
+     mini-batch training.
+
+     Examples:
+         >>> train_torch_ss_estimator(features, targets, epochs=100)
+
+     Args:
+         features: DataFrame containing the input features for training
+         targets: DataFrame containing the target values for training
+         epochs: Number of training epochs
+         batch_size: Size of mini-batches for training (None for full-batch)
+         approximator: Predefined neural network model (None to use default)
+         optimimzer_cls: Optimizer class to use for training (default: Adam)
+         device: Device to run the training on (default: DefaultDevice)
+
+     Returns:
+         tuple[TorchSSEstimator, pd.Series]: Trained estimator and loss history
+
+     """
+     if approximator is None:
+         approximator = DefaultSSAproximator(
+             n_inputs=len(features.columns),
+             n_outputs=len(targets.columns),
+             n_hidden=max(2 * len(features.columns) * len(targets.columns), 10),
+         ).to(device)
+
+     features_ = torch.Tensor(features.to_numpy(), device=device)
+     targets_ = torch.Tensor(targets.to_numpy(), device=device)
+
+     optimizer = optimimzer_cls(approximator.parameters())
+     if batch_size is None:
+         losses = _train_full(
+             approximator=approximator,
+             features=features_,
+             targets=targets_,
+             epochs=epochs,
+             optimizer=optimizer,
+         )
+     else:
+         losses = _train_batched(
+             approximator=approximator,
+             features=features_,
+             targets=targets_,
+             epochs=epochs,
+             optimizer=optimizer,
+             batch_size=batch_size,
+         )
+
+     return TorchSSEstimator(
+         model=approximator,
+         parameter_names=list(targets.columns),
+     ), losses
+
+
+ def train_torch_time_course_estimator(
+     features: pd.DataFrame,
+     targets: pd.DataFrame,
+     epochs: int,
+     batch_size: int | None = None,
+     approximator: nn.Module | None = None,
+     optimimzer_cls: type[Adam] = Adam,
+     device: torch.device = DefaultDevice,
+ ) -> tuple[TorchTimeCourseEstimator, pd.Series]:
+     """Train a PyTorch time course estimator.
+
+     This function trains a neural network model to estimate time course data
+     using the provided features and targets. It supports both full-batch and
+     mini-batch training.
+
+     Examples:
+         >>> train_torch_time_course_estimator(features, targets, epochs=100)
+
+     Args:
+         features: DataFrame containing the input features for training
+         targets: DataFrame containing the target values for training
+         epochs: Number of training epochs
+         batch_size: Size of mini-batches for training (None for full-batch)
+         approximator: Predefined neural network model (None to use default)
+         optimimzer_cls: Optimizer class to use for training (default: Adam)
+         device: Device to run the training on (default: DefaultDevice)
+
+     Returns:
+         tuple[TorchTimeCourseEstimator, pd.Series]: Trained estimator and loss history
+
+     """
+     if approximator is None:
+         approximator = DefaultTimeSeriesApproximator(
+             n_inputs=len(features.columns),
+             n_outputs=len(targets.columns),
+             n_hidden=1,
+         ).to(device)
+
+     optimizer = optimimzer_cls(approximator.parameters())
+     # Reshape to (time, sample, feature) order, as expected by nn.LSTM
+     features_ = torch.Tensor(
+         np.swapaxes(
+             features.to_numpy().reshape((len(targets), -1, len(features.columns))),
+             axis1=0,
+             axis2=1,
+         ),
+         device=device,
+     )
+     targets_ = torch.Tensor(targets.to_numpy(), device=device)
+     if batch_size is None:
+         losses = _train_full(
+             approximator=approximator,
+             features=features_,
+             targets=targets_,
+             epochs=epochs,
+             optimizer=optimizer,
+         )
+     else:
+         losses = _train_batched(
+             approximator=approximator,
+             features=features_,
+             targets=targets_,
+             epochs=epochs,
+             optimizer=optimizer,
+             batch_size=batch_size,
+         )
+     return TorchTimeCourseEstimator(
+         model=approximator,
+         parameter_names=list(targets.columns),
+     ), losses
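
For orientation, a minimal sketch of how the new estimator API might be used (the data frames and column names below are synthetic placeholders, not part of the package):

    import numpy as np
    import pandas as pd
    from modelbase2.npe import train_torch_ss_estimator

    # Hypothetical training data: steady-state measurements as features,
    # the parameter sets that produced them as targets.
    rng = np.random.default_rng(seed=0)
    features = pd.DataFrame(rng.normal(size=(100, 3)), columns=["S1", "S2", "S3"])
    targets = pd.DataFrame(rng.uniform(size=(100, 2)), columns=["k1", "k2"])

    estimator, losses = train_torch_ss_estimator(features, targets, epochs=100)
    predicted = estimator.predict(features)  # DataFrame with columns ["k1", "k2"]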
modelbase2/parallel.py ADDED
@@ -0,0 +1,171 @@
+ """Parallel Execution Module.
+
+ This module provides functions and classes for parallel execution and caching of
+ computation results. It includes functionality for parallel processing and result
+ caching using multiprocessing and pickle.
+
+ Classes:
+     Cache: Cache class for storing and retrieving computation results.
+
+ Functions:
+     parallelise: Execute a function in parallel over a collection of inputs.
+ """
+
+ from __future__ import annotations
+
+ import multiprocessing
+ import pickle
+ import sys
+ from dataclasses import dataclass
+ from functools import partial
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, cast
+
+ import pebble
+ from tqdm import tqdm
+
+ __all__ = ["Cache", "parallelise"]
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, Collection, Hashable
+
+
+ def _pickle_name(k: Hashable) -> str:
+     return f"{k}.p"
+
+
+ def _pickle_load(file: Path) -> Any:
+     with file.open("rb") as fp:
+         return pickle.load(fp)  # nosec
+
+
+ def _pickle_save(file: Path, data: Any) -> None:
+     with file.open("wb") as fp:
+         pickle.dump(data, fp)
+
+
+ @dataclass
+ class Cache:
+     """Cache class for storing and retrieving computation results.
+
+     Attributes:
+         tmp_dir: Directory to store cache files.
+         name_fn: Function to generate file names from keys.
+         load_fn: Function to load data from files.
+         save_fn: Function to save data to files.
+
+     """
+
+     tmp_dir: Path = Path(".cache")
+     name_fn: Callable[[Any], str] = _pickle_name
+     load_fn: Callable[[Path], Any] = _pickle_load
+     save_fn: Callable[[Path, Any], None] = _pickle_save
+
+
+ def _load_or_run[K: Hashable, Tin, Tout](
+     inp: tuple[K, Tin],
+     fn: Callable[[Tin], Tout],
+     cache: Cache | None,
+ ) -> tuple[K, Tout]:
+     """Load data from cache or execute function and save result.
+
+     Args:
+         inp: Tuple containing a key and input value.
+         fn: Function to execute if result is not in cache.
+         cache: Optional cache to store and retrieve results.
+
+     Returns:
+         tuple[K, Tout]: Tuple containing the key and the result of the function.
+
+     """
+     k, v = inp
+     if cache is None:
+         res = fn(v)
+     else:
+         file = cache.tmp_dir / cache.name_fn(k)
+         if file.exists():
+             return k, cast(Tout, cache.load_fn(file))
+         res = fn(v)
+         cache.save_fn(file, res)
+     return k, res
+
+
+ def parallelise[K: Hashable, Tin, Tout](
+     fn: Callable[[Tin], Tout],
+     inputs: Collection[tuple[K, Tin]],
+     *,
+     cache: Cache | None = None,
+     parallel: bool = True,
+     max_workers: int | None = None,
+     timeout: float | None = None,
+     disable_tqdm: bool = False,
+     tqdm_desc: str | None = None,
+ ) -> dict[K, Tout]:
+     """Execute a function in parallel over a collection of inputs.
+
+     Examples:
+         >>> parallelise(square, [("a", 2), ("b", 3), ("c", 4)])
+         {"a": 4, "b": 9, "c": 16}
+
+     Args:
+         fn: Function to execute in parallel. Takes a single input and returns a result.
+         inputs: Collection of (key, input) tuples to process.
+         cache: Optional cache to store and retrieve results.
+         parallel: Whether to execute in parallel (default: True).
+         max_workers: Maximum number of worker processes (default: None, uses all available CPUs).
+         timeout: Maximum time (in seconds) to wait for each worker to complete (default: None).
+         disable_tqdm: Whether to disable the tqdm progress bar (default: False).
+         tqdm_desc: Description for the tqdm progress bar (default: None).
+
+     Returns:
+         dict[K, Tout]: Dictionary mapping keys to their corresponding outputs.
+
+     """
+     if cache is not None:
+         cache.tmp_dir.mkdir(parents=True, exist_ok=True)
+
+     # Parallel execution is disabled on Windows platforms
+     if sys.platform in ["win32", "cygwin"]:
+         parallel = False
+
+     worker: Callable[[tuple[K, Tin]], tuple[K, Tout]] = partial(
+         _load_or_run,
+         fn=fn,
+         cache=cache,
+     )  # type: ignore
+
+     results: dict[K, Tout]
+     if parallel:
+         results = {}
+         max_workers = (
+             multiprocessing.cpu_count() if max_workers is None else max_workers
+         )
+
+         with (
+             tqdm(
+                 total=len(inputs),
+                 disable=disable_tqdm,
+                 desc=tqdm_desc,
+             ) as pbar,
+             pebble.ProcessPool(max_workers=max_workers) as pool,
+         ):
+             future = pool.map(worker, inputs, timeout=timeout)
+             it = future.result()
+             while True:
+                 try:
+                     key, value = next(it)
+                     pbar.update(1)
+                     results[key] = value
+                 except StopIteration:
+                     break
+                 except TimeoutError:
+                     pbar.update(1)
+     else:
+         results = dict(
+             tqdm(
+                 map(worker, inputs),  # type: ignore
+                 total=len(inputs),
+                 disable=disable_tqdm,
+                 desc=tqdm_desc,
+             )  # type: ignore
+         )  # type: ignore
+     return results
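
A minimal usage sketch (the square helper is hypothetical, mirroring the docstring example; wrapping it in a Cache persists results between runs):

    from pathlib import Path
    from modelbase2.parallel import Cache, parallelise

    def square(x: int) -> int:
        return x * x

    # Keys ("a", "b", "c") identify the work items; cached results are
    # written to .cache/<key>.p and reused on subsequent calls.
    results = parallelise(
        square,
        inputs=[("a", 2), ("b", 3), ("c", 4)],
        cache=Cache(tmp_dir=Path(".cache")),
    )
    # {"a": 4, "b": 9, "c": 16}

Note that square must be defined at module level so it can be pickled for the worker processes, and that on Windows the function falls back to serial execution.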
modelbase2/parameterise.py ADDED
@@ -0,0 +1,28 @@
+ """Module to parameterise models."""
+
+ from pathlib import Path
+
+ import pandas as pd
+ from parameteriser.brenda.v0 import Brenda
+
+ __all__ = ["get_km_and_kcat_from_brenda"]
+
+
+ def get_km_and_kcat_from_brenda(
+     ec: str,
+     brenda_path: Path,
+ ) -> tuple[pd.DataFrame, pd.DataFrame]:
+     """Obtain Michaelis and catalytic constants for the given EC number.
+
+     You can obtain the database from https://www.brenda-enzymes.org/download.php
+     """
+     brenda = Brenda()
+     if brenda_path is not None:
+         brenda.read_database(brenda_path)
+
+     kms, kcats = brenda.get_kms_and_kcats(
+         ec=ec,
+         filter_mutant=True,
+         filter_missing_sequences=True,
+     )
+     return kms, kcats
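
A usage sketch, assuming the BRENDA database has already been downloaded from the link in the docstring (the file name below is a placeholder):

    from pathlib import Path
    from modelbase2.parameterise import get_km_and_kcat_from_brenda

    # EC 1.1.1.1 = alcohol dehydrogenase; the database path is hypothetical
    kms, kcats = get_km_and_kcat_from_brenda(
        ec="1.1.1.1",
        brenda_path=Path("brenda_download.json"),
    )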
modelbase2/paths.py ADDED
@@ -0,0 +1,36 @@
+ """Shared paths used across the modelbase2 package."""
+
+ from __future__ import annotations
+
+ import shutil
+ from pathlib import Path
+
+ __all__ = [
+     "default_tmp_dir",
+ ]
+
+
+ def default_tmp_dir(tmp_dir: Path | None, *, remove_old_cache: bool) -> Path:
+     """Returns the default temporary directory path.
+
+     If `tmp_dir` is None, it defaults to the user's home directory under
+     ".cache/modelbase". Optionally removes the old cache first.
+
+     Args:
+         tmp_dir (Path | None): The temporary directory path. If None, defaults to
+             Path.home() / ".cache" / "modelbase".
+         remove_old_cache (bool): If True, removes the old cache directory if it exists.
+
+     Returns:
+         Path: The path to the temporary directory.
+
+     """
+     if tmp_dir is None:
+         tmp_dir = Path.home() / ".cache" / "modelbase"
+
+     if tmp_dir.exists() and remove_old_cache:
+         shutil.rmtree(tmp_dir)
+
+     tmp_dir.mkdir(exist_ok=True, parents=True)
+     return tmp_dir
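
A short sketch of the helper (the paths shown are the actual defaults from the code above):

    from modelbase2.paths import default_tmp_dir

    # Resolves to ~/.cache/modelbase and wipes any previous cache first
    cache_dir = default_tmp_dir(None, remove_old_cache=True)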