cellfinder 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cellfinder has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- cellfinder/__init__.py +18 -12
- cellfinder/core/classify/classify.py +13 -6
- cellfinder/core/classify/cube_generator.py +27 -11
- cellfinder/core/classify/resnet.py +9 -6
- cellfinder/core/classify/tools.py +13 -11
- cellfinder/core/main.py +3 -28
- cellfinder/core/tools/prep.py +1 -10
- cellfinder/core/tools/system.py +10 -0
- cellfinder/core/train/train_yml.py +25 -21
- cellfinder/napari/curation.py +1 -1
- cellfinder/napari/detect/detect.py +8 -2
- cellfinder/napari/detect/detect_containers.py +2 -0
- cellfinder/napari/detect/thread_worker.py +2 -2
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/METADATA +4 -3
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/RECORD +19 -20
- cellfinder/core/tools/tf.py +0 -46
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/LICENSE +0 -0
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/WHEEL +0 -0
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/entry_points.txt +0 -0
- {cellfinder-1.2.0.dist-info → cellfinder-1.3.0.dist-info}/top_level.txt +0 -0
cellfinder/__init__.py
CHANGED
|
@@ -1,27 +1,33 @@
|
|
|
1
|
+
import os
|
|
1
2
|
from importlib.metadata import PackageNotFoundError, version
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
|
|
5
|
+
# Check cellfinder is installed
|
|
4
6
|
try:
|
|
5
7
|
__version__ = version("cellfinder")
|
|
6
8
|
except PackageNotFoundError as e:
|
|
7
9
|
raise PackageNotFoundError("cellfinder package not installed") from e
|
|
8
10
|
|
|
9
|
-
# If
|
|
11
|
+
# If Keras is not present, tools cannot be used.
|
|
10
12
|
# Throw an error in this case to prevent invocation of functions.
|
|
11
13
|
try:
|
|
12
|
-
|
|
14
|
+
KERAS_VERSION = version("keras")
|
|
13
15
|
except PackageNotFoundError as e:
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
16
|
+
raise PackageNotFoundError(
|
|
17
|
+
f"cellfinder tools cannot be invoked without Keras. "
|
|
18
|
+
f"Please install Keras with a backend into your environment "
|
|
19
|
+
f"to use cellfinder tools. "
|
|
20
|
+
f"For more information on Keras backends, please see "
|
|
21
|
+
f"https://keras.io/getting_started/#installing-keras-3."
|
|
22
|
+
f"For more information on brainglobe, please see "
|
|
23
|
+
f"https://github.com/brainglobe/brainglobe-meta#readme."
|
|
24
|
+
) from e
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Set the Keras backend to torch
|
|
28
|
+
os.environ["KERAS_BACKEND"] = "torch"
|
|
29
|
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
23
30
|
|
|
24
|
-
__author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
|
|
25
31
|
__license__ = "BSD-3-Clause"
|
|
26
32
|
|
|
27
33
|
DEFAULT_CELLFINDER_DIRECTORY = Path.home() / ".brainglobe" / "cellfinder"
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from datetime import datetime
|
|
2
3
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
3
4
|
|
|
5
|
+
import keras
|
|
4
6
|
import numpy as np
|
|
5
7
|
from brainglobe_utils.cells.cells import Cell
|
|
6
8
|
from brainglobe_utils.general.system import get_num_processes
|
|
7
|
-
from tensorflow import keras
|
|
8
9
|
|
|
9
10
|
from cellfinder.core import logger, types
|
|
10
11
|
from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
|
|
@@ -48,9 +49,9 @@ def main(
|
|
|
48
49
|
callbacks = None
|
|
49
50
|
|
|
50
51
|
# Too many workers doesn't increase speed, and uses huge amounts of RAM
|
|
51
|
-
workers = get_num_processes(
|
|
52
|
-
|
|
53
|
-
)
|
|
52
|
+
workers = get_num_processes(min_free_cpu_cores=n_free_cpus)
|
|
53
|
+
|
|
54
|
+
start_time = datetime.now()
|
|
54
55
|
|
|
55
56
|
logger.debug("Initialising cube generator")
|
|
56
57
|
inference_generator = CubeGeneratorFromFile(
|
|
@@ -63,6 +64,8 @@ def main(
|
|
|
63
64
|
cube_width=cube_width,
|
|
64
65
|
cube_height=cube_height,
|
|
65
66
|
cube_depth=cube_depth,
|
|
67
|
+
use_multiprocessing=False,
|
|
68
|
+
workers=workers,
|
|
66
69
|
)
|
|
67
70
|
|
|
68
71
|
model = get_model(
|
|
@@ -73,10 +76,9 @@ def main(
|
|
|
73
76
|
)
|
|
74
77
|
|
|
75
78
|
logger.info("Running inference")
|
|
79
|
+
# in Keras 3.0 multiprocessing params are specified in the generator
|
|
76
80
|
predictions = model.predict(
|
|
77
81
|
inference_generator,
|
|
78
|
-
use_multiprocessing=True,
|
|
79
|
-
workers=workers,
|
|
80
82
|
verbose=True,
|
|
81
83
|
callbacks=callbacks,
|
|
82
84
|
)
|
|
@@ -91,6 +93,11 @@ def main(
|
|
|
91
93
|
cell.type = predictions[idx] + 1
|
|
92
94
|
points_list.append(cell)
|
|
93
95
|
|
|
96
|
+
time_elapsed = datetime.now() - start_time
|
|
97
|
+
print(
|
|
98
|
+
"Classfication complete - all points done in : {}".format(time_elapsed)
|
|
99
|
+
)
|
|
100
|
+
|
|
94
101
|
return points_list
|
|
95
102
|
|
|
96
103
|
|
|
@@ -2,13 +2,13 @@ from pathlib import Path
|
|
|
2
2
|
from random import shuffle
|
|
3
3
|
from typing import Dict, List, Optional, Tuple, Union
|
|
4
4
|
|
|
5
|
+
import keras
|
|
5
6
|
import numpy as np
|
|
6
|
-
import tensorflow as tf
|
|
7
7
|
from brainglobe_utils.cells.cells import Cell, group_cells_by_z
|
|
8
8
|
from brainglobe_utils.general.numerical import is_even
|
|
9
|
+
from keras.utils import Sequence
|
|
9
10
|
from scipy.ndimage import zoom
|
|
10
11
|
from skimage.io import imread
|
|
11
|
-
from tensorflow.keras.utils import Sequence
|
|
12
12
|
|
|
13
13
|
from cellfinder.core import types
|
|
14
14
|
from cellfinder.core.classify.augment import AugmentationParameters, augment
|
|
@@ -40,7 +40,7 @@ class CubeGeneratorFromFile(Sequence):
|
|
|
40
40
|
background_array: types.array,
|
|
41
41
|
voxel_sizes: Tuple[int, int, int],
|
|
42
42
|
network_voxel_sizes: Tuple[int, int, int],
|
|
43
|
-
batch_size: int =
|
|
43
|
+
batch_size: int = 64,
|
|
44
44
|
cube_width: int = 50,
|
|
45
45
|
cube_height: int = 50,
|
|
46
46
|
cube_depth: int = 20,
|
|
@@ -56,7 +56,14 @@ class CubeGeneratorFromFile(Sequence):
|
|
|
56
56
|
translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
|
|
57
57
|
shuffle: bool = False,
|
|
58
58
|
interpolation_order: int = 2,
|
|
59
|
+
*args,
|
|
60
|
+
**kwargs,
|
|
59
61
|
):
|
|
62
|
+
# pass any additional arguments not specified in signature to the
|
|
63
|
+
# constructor of the superclass (e.g.: `use_multiprocessing` or
|
|
64
|
+
# `workers`)
|
|
65
|
+
super().__init__(*args, **kwargs)
|
|
66
|
+
|
|
60
67
|
self.points = points
|
|
61
68
|
self.signal_array = signal_array
|
|
62
69
|
self.background_array = background_array
|
|
@@ -218,10 +225,10 @@ class CubeGeneratorFromFile(Sequence):
|
|
|
218
225
|
|
|
219
226
|
if self.train:
|
|
220
227
|
batch_labels = [cell.type - 1 for cell in cell_batch]
|
|
221
|
-
batch_labels =
|
|
228
|
+
batch_labels = keras.utils.to_categorical(
|
|
222
229
|
batch_labels, num_classes=self.classes
|
|
223
230
|
)
|
|
224
|
-
return images, batch_labels
|
|
231
|
+
return images, batch_labels.astype(np.float32)
|
|
225
232
|
elif self.extract:
|
|
226
233
|
batch_info = self.__get_batch_dict(cell_batch)
|
|
227
234
|
return images, batch_info
|
|
@@ -252,7 +259,8 @@ class CubeGeneratorFromFile(Sequence):
|
|
|
252
259
|
(number_images,)
|
|
253
260
|
+ (self.cube_height, self.cube_width, self.cube_depth)
|
|
254
261
|
+ (self.channels,)
|
|
255
|
-
)
|
|
262
|
+
),
|
|
263
|
+
dtype=np.float32,
|
|
256
264
|
)
|
|
257
265
|
|
|
258
266
|
for idx, cell in enumerate(cell_batch):
|
|
@@ -337,7 +345,7 @@ class CubeGeneratorFromDisk(Sequence):
|
|
|
337
345
|
signal_list: List[Union[str, Path]],
|
|
338
346
|
background_list: List[Union[str, Path]],
|
|
339
347
|
labels: Optional[List[int]] = None, # only if training or validating
|
|
340
|
-
batch_size: int =
|
|
348
|
+
batch_size: int = 64,
|
|
341
349
|
shape: Tuple[int, int, int] = (50, 50, 20),
|
|
342
350
|
channels: int = 2,
|
|
343
351
|
classes: int = 2,
|
|
@@ -350,7 +358,14 @@ class CubeGeneratorFromDisk(Sequence):
|
|
|
350
358
|
translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
|
|
351
359
|
train: bool = False, # also return labels
|
|
352
360
|
interpolation_order: int = 2,
|
|
361
|
+
*args,
|
|
362
|
+
**kwargs,
|
|
353
363
|
):
|
|
364
|
+
# pass any additional arguments not specified in signature to the
|
|
365
|
+
# constructor of the superclass (e.g.: `use_multiprocessing` or
|
|
366
|
+
# `workers`)
|
|
367
|
+
super().__init__(*args, **kwargs)
|
|
368
|
+
|
|
354
369
|
self.im_shape = shape
|
|
355
370
|
self.batch_size = batch_size
|
|
356
371
|
self.labels = labels
|
|
@@ -410,10 +425,10 @@ class CubeGeneratorFromDisk(Sequence):
|
|
|
410
425
|
|
|
411
426
|
if self.train and self.labels is not None:
|
|
412
427
|
batch_labels = [self.labels[k] for k in indexes]
|
|
413
|
-
batch_labels =
|
|
428
|
+
batch_labels = keras.utils.to_categorical(
|
|
414
429
|
batch_labels, num_classes=self.classes
|
|
415
430
|
)
|
|
416
|
-
return images, batch_labels
|
|
431
|
+
return images, batch_labels.astype(np.float32)
|
|
417
432
|
else:
|
|
418
433
|
return images
|
|
419
434
|
|
|
@@ -424,7 +439,8 @@ class CubeGeneratorFromDisk(Sequence):
|
|
|
424
439
|
) -> np.ndarray:
|
|
425
440
|
number_images = len(list_signal_tmp)
|
|
426
441
|
images = np.empty(
|
|
427
|
-
((number_images,) + self.im_shape + (self.channels,))
|
|
442
|
+
((number_images,) + self.im_shape + (self.channels,)),
|
|
443
|
+
dtype=np.float32,
|
|
428
444
|
)
|
|
429
445
|
|
|
430
446
|
for idx, signal_im in enumerate(list_signal_tmp):
|
|
@@ -433,7 +449,7 @@ class CubeGeneratorFromDisk(Sequence):
|
|
|
433
449
|
images, idx, signal_im, background_im
|
|
434
450
|
)
|
|
435
451
|
|
|
436
|
-
return images
|
|
452
|
+
return images
|
|
437
453
|
|
|
438
454
|
def __populate_array_with_cubes(
|
|
439
455
|
self,
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
from
|
|
3
|
+
from keras import (
|
|
4
|
+
KerasTensor as Tensor,
|
|
5
|
+
)
|
|
6
|
+
from keras import Model
|
|
7
|
+
from keras.initializers import Initializer
|
|
8
|
+
from keras.layers import (
|
|
7
9
|
Activation,
|
|
8
10
|
Add,
|
|
9
11
|
BatchNormalization,
|
|
@@ -14,7 +16,7 @@ from tensorflow.keras.layers import (
|
|
|
14
16
|
MaxPooling3D,
|
|
15
17
|
ZeroPadding3D,
|
|
16
18
|
)
|
|
17
|
-
from
|
|
19
|
+
from keras.optimizers import Adam, Optimizer
|
|
18
20
|
|
|
19
21
|
#####################################################################
|
|
20
22
|
# Define the types of ResNet
|
|
@@ -113,7 +115,7 @@ def non_residual_block(
|
|
|
113
115
|
activation: str = "relu",
|
|
114
116
|
use_bias: bool = False,
|
|
115
117
|
bn_epsilon: float = 1e-5,
|
|
116
|
-
pooling_padding: str = "
|
|
118
|
+
pooling_padding: str = "valid",
|
|
117
119
|
axis: int = 3,
|
|
118
120
|
) -> Tensor:
|
|
119
121
|
"""
|
|
@@ -131,6 +133,7 @@ def non_residual_block(
|
|
|
131
133
|
)(x)
|
|
132
134
|
x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x)
|
|
133
135
|
x = Activation(activation, name="conv1_activation")(x)
|
|
136
|
+
|
|
134
137
|
x = MaxPooling3D(
|
|
135
138
|
max_pool_size,
|
|
136
139
|
strides=strides,
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
3
4
|
|
|
5
|
+
import keras
|
|
4
6
|
import numpy as np
|
|
5
|
-
|
|
6
|
-
from tensorflow.keras import Model
|
|
7
|
+
from keras import Model
|
|
7
8
|
|
|
8
9
|
from cellfinder.core import logger
|
|
9
10
|
from cellfinder.core.classify.resnet import build_model, layer_type
|
|
@@ -17,8 +18,7 @@ def get_model(
|
|
|
17
18
|
inference: bool = False,
|
|
18
19
|
continue_training: bool = False,
|
|
19
20
|
) -> Model:
|
|
20
|
-
"""
|
|
21
|
-
Returns the correct model based on the arguments passed
|
|
21
|
+
"""Returns the correct model based on the arguments passed
|
|
22
22
|
:param existing_model: An existing, trained model. This is returned if it
|
|
23
23
|
exists
|
|
24
24
|
:param model_weights: This file is used to set the model weights if it
|
|
@@ -30,29 +30,31 @@ def get_model(
|
|
|
30
30
|
by using the default one
|
|
31
31
|
:param continue_training: If True, will ensure that a trained model
|
|
32
32
|
exists. E.g. by using the default one
|
|
33
|
-
:return: A
|
|
33
|
+
:return: A keras model
|
|
34
34
|
|
|
35
35
|
"""
|
|
36
36
|
if existing_model is not None or network_depth is None:
|
|
37
37
|
logger.debug(f"Loading model: {existing_model}")
|
|
38
|
-
return
|
|
38
|
+
return keras.models.load_model(existing_model)
|
|
39
39
|
else:
|
|
40
40
|
logger.debug(f"Creating a new instance of model: {network_depth}")
|
|
41
41
|
model = build_model(
|
|
42
|
-
network_depth=network_depth,
|
|
42
|
+
network_depth=network_depth,
|
|
43
|
+
learning_rate=learning_rate,
|
|
43
44
|
)
|
|
44
45
|
if inference or continue_training:
|
|
45
46
|
logger.debug(
|
|
46
|
-
f"Setting model weights according to: {model_weights}"
|
|
47
|
+
f"Setting model weights according to: {model_weights}",
|
|
47
48
|
)
|
|
48
49
|
if model_weights is None:
|
|
49
|
-
raise
|
|
50
|
+
raise OSError("`model_weights` must be provided")
|
|
50
51
|
model.load_weights(model_weights)
|
|
51
52
|
return model
|
|
52
53
|
|
|
53
54
|
|
|
54
55
|
def make_lists(
|
|
55
|
-
tiff_files: Sequence,
|
|
56
|
+
tiff_files: Sequence,
|
|
57
|
+
train: bool = True,
|
|
56
58
|
) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
|
|
57
59
|
signal_list = []
|
|
58
60
|
background_list = []
|
cellfinder/core/main.py
CHANGED
|
@@ -1,23 +1,13 @@
|
|
|
1
|
-
"""
|
|
2
|
-
N.B imports are within functions to prevent tensorflow being imported before
|
|
3
|
-
it's warnings are silenced
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
1
|
import os
|
|
7
2
|
from typing import Callable, List, Optional, Tuple
|
|
8
3
|
|
|
9
4
|
import numpy as np
|
|
10
5
|
from brainglobe_utils.cells.cells import Cell
|
|
11
|
-
from brainglobe_utils.general.logging import suppress_specific_logs
|
|
12
6
|
|
|
13
7
|
from cellfinder.core import logger
|
|
14
8
|
from cellfinder.core.download.download import model_type
|
|
15
9
|
from cellfinder.core.train.train_yml import depth_type
|
|
16
10
|
|
|
17
|
-
tf_suppress_log_messages = [
|
|
18
|
-
"multiprocessing can interact badly with TensorFlow"
|
|
19
|
-
]
|
|
20
|
-
|
|
21
11
|
|
|
22
12
|
def main(
|
|
23
13
|
signal_array: np.ndarray,
|
|
@@ -28,7 +18,7 @@ def main(
|
|
|
28
18
|
trained_model: Optional[os.PathLike] = None,
|
|
29
19
|
model_weights: Optional[os.PathLike] = None,
|
|
30
20
|
model: model_type = "resnet50_tv",
|
|
31
|
-
batch_size: int =
|
|
21
|
+
batch_size: int = 64,
|
|
32
22
|
n_free_cpus: int = 2,
|
|
33
23
|
network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
|
|
34
24
|
soma_diameter: int = 16,
|
|
@@ -58,13 +48,11 @@ def main(
|
|
|
58
48
|
Called every time a plane has finished being processed during the
|
|
59
49
|
detection stage. Called with the plane number that has finished.
|
|
60
50
|
classify_callback : Callable[int], optional
|
|
61
|
-
Called every time
|
|
51
|
+
Called every time a point has finished being classified.
|
|
62
52
|
Called with the batch number that has just finished.
|
|
63
53
|
detect_finished_callback : Callable[list], optional
|
|
64
54
|
Called after detection is finished with the list of detected points.
|
|
65
55
|
"""
|
|
66
|
-
suppress_tf_logging(tf_suppress_log_messages)
|
|
67
|
-
|
|
68
56
|
from cellfinder.core.classify import classify
|
|
69
57
|
from cellfinder.core.detect import detect
|
|
70
58
|
from cellfinder.core.tools import prep
|
|
@@ -98,7 +86,7 @@ def main(
|
|
|
98
86
|
if not skip_classification:
|
|
99
87
|
install_path = None
|
|
100
88
|
model_weights = prep.prep_model_weights(
|
|
101
|
-
model_weights, install_path, model
|
|
89
|
+
model_weights, install_path, model
|
|
102
90
|
)
|
|
103
91
|
if len(points) > 0:
|
|
104
92
|
logger.info("Running classification")
|
|
@@ -120,17 +108,4 @@ def main(
|
|
|
120
108
|
)
|
|
121
109
|
else:
|
|
122
110
|
logger.info("No candidates, skipping classification")
|
|
123
|
-
|
|
124
111
|
return points
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def suppress_tf_logging(tf_suppress_log_messages: List[str]) -> None:
|
|
128
|
-
"""
|
|
129
|
-
Prevents many lines of logs such as:
|
|
130
|
-
"2019-10-24 16:54:41.363978: I tensorflow/stream_executor/platform/default
|
|
131
|
-
/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1"
|
|
132
|
-
"""
|
|
133
|
-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
134
|
-
|
|
135
|
-
for message in tf_suppress_log_messages:
|
|
136
|
-
suppress_specific_logs("tensorflow", message)
|
cellfinder/core/tools/prep.py
CHANGED
|
@@ -9,9 +9,7 @@ from pathlib import Path
|
|
|
9
9
|
from typing import Optional
|
|
10
10
|
|
|
11
11
|
from brainglobe_utils.general.config import get_config_obj
|
|
12
|
-
from brainglobe_utils.general.system import get_num_processes
|
|
13
12
|
|
|
14
|
-
import cellfinder.core.tools.tf as tf_tools
|
|
15
13
|
from cellfinder.core import logger
|
|
16
14
|
from cellfinder.core.download.download import (
|
|
17
15
|
DEFAULT_DOWNLOAD_DIRECTORY,
|
|
@@ -26,20 +24,13 @@ def prep_model_weights(
|
|
|
26
24
|
model_weights: Optional[os.PathLike],
|
|
27
25
|
install_path: Optional[os.PathLike],
|
|
28
26
|
model_name: model_type,
|
|
29
|
-
n_free_cpus: int,
|
|
30
27
|
) -> Path:
|
|
31
|
-
|
|
32
|
-
prep_tensorflow(n_processes)
|
|
28
|
+
# prepare models (get default weights or provided ones)
|
|
33
29
|
model_weights = prep_models(model_weights, install_path, model_name)
|
|
34
30
|
|
|
35
31
|
return model_weights
|
|
36
32
|
|
|
37
33
|
|
|
38
|
-
def prep_tensorflow(max_threads: int) -> None:
|
|
39
|
-
tf_tools.set_tf_threads(max_threads)
|
|
40
|
-
tf_tools.allow_gpu_memory_growth()
|
|
41
|
-
|
|
42
|
-
|
|
43
34
|
def prep_models(
|
|
44
35
|
model_weights_path: Optional[os.PathLike],
|
|
45
36
|
install_path: Optional[os.PathLike],
|
cellfinder/core/tools/system.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
2
|
|
|
3
|
+
import keras
|
|
3
4
|
from brainglobe_utils.general.exceptions import CommandLineInputError
|
|
4
5
|
|
|
5
6
|
|
|
@@ -80,3 +81,12 @@ def memory_in_bytes(memory_amount, unit):
|
|
|
80
81
|
)
|
|
81
82
|
else:
|
|
82
83
|
return memory_amount * 10 ** supported_units[unit]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def force_cpu():
|
|
87
|
+
"""
|
|
88
|
+
Forces the CPU to be used, even if a GPU is available
|
|
89
|
+
"""
|
|
90
|
+
keras.src.backend.common.global_state.set_global_attribute(
|
|
91
|
+
"torch_device", "cpu"
|
|
92
|
+
)
|
|
@@ -22,7 +22,10 @@ from brainglobe_utils.general.numerical import (
|
|
|
22
22
|
check_positive_float,
|
|
23
23
|
check_positive_int,
|
|
24
24
|
)
|
|
25
|
-
from brainglobe_utils.general.system import
|
|
25
|
+
from brainglobe_utils.general.system import (
|
|
26
|
+
ensure_directory_exists,
|
|
27
|
+
get_num_processes,
|
|
28
|
+
)
|
|
26
29
|
from brainglobe_utils.IO.cells import find_relevant_tiffs
|
|
27
30
|
from brainglobe_utils.IO.yaml import read_yaml_section
|
|
28
31
|
from fancylog import fancylog
|
|
@@ -33,11 +36,6 @@ from cellfinder.core import logger
|
|
|
33
36
|
from cellfinder.core.classify.resnet import layer_type
|
|
34
37
|
from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY
|
|
35
38
|
|
|
36
|
-
tf_suppress_log_messages = [
|
|
37
|
-
"sample_weight modes were coerced from",
|
|
38
|
-
"multiprocessing can interact badly with TensorFlow",
|
|
39
|
-
]
|
|
40
|
-
|
|
41
39
|
depth_type = Literal["18", "34", "50", "101", "152"]
|
|
42
40
|
|
|
43
41
|
models: Dict[depth_type, layer_type] = {
|
|
@@ -318,11 +316,7 @@ def run(
|
|
|
318
316
|
save_progress=False,
|
|
319
317
|
epochs=100,
|
|
320
318
|
):
|
|
321
|
-
from
|
|
322
|
-
|
|
323
|
-
suppress_tf_logging(tf_suppress_log_messages)
|
|
324
|
-
|
|
325
|
-
from tensorflow.keras.callbacks import (
|
|
319
|
+
from keras.callbacks import (
|
|
326
320
|
CSVLogger,
|
|
327
321
|
ModelCheckpoint,
|
|
328
322
|
TensorBoard,
|
|
@@ -339,7 +333,6 @@ def run(
|
|
|
339
333
|
model_weights=model_weights,
|
|
340
334
|
install_path=install_path,
|
|
341
335
|
model_name=model,
|
|
342
|
-
n_free_cpus=n_free_cpus,
|
|
343
336
|
)
|
|
344
337
|
|
|
345
338
|
yaml_contents = parse_yaml(yaml_file)
|
|
@@ -361,6 +354,7 @@ def run(
|
|
|
361
354
|
|
|
362
355
|
signal_train, background_train, labels_train = make_lists(tiff_files)
|
|
363
356
|
|
|
357
|
+
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
|
|
364
358
|
if test_fraction > 0:
|
|
365
359
|
logger.info("Splitting data into training and validation datasets")
|
|
366
360
|
(
|
|
@@ -387,15 +381,17 @@ def run(
|
|
|
387
381
|
labels=labels_test,
|
|
388
382
|
batch_size=batch_size,
|
|
389
383
|
train=True,
|
|
384
|
+
use_multiprocessing=False,
|
|
385
|
+
workers=n_processes,
|
|
390
386
|
)
|
|
391
387
|
|
|
392
388
|
# for saving checkpoints
|
|
393
|
-
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}
|
|
389
|
+
base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"
|
|
394
390
|
|
|
395
391
|
else:
|
|
396
392
|
logger.info("No validation data selected.")
|
|
397
393
|
validation_generator = None
|
|
398
|
-
base_checkpoint_file_name = "-epoch.{epoch:02d}
|
|
394
|
+
base_checkpoint_file_name = "-epoch.{epoch:02d}"
|
|
399
395
|
|
|
400
396
|
training_generator = CubeGeneratorFromDisk(
|
|
401
397
|
signal_train,
|
|
@@ -405,6 +401,8 @@ def run(
|
|
|
405
401
|
shuffle=True,
|
|
406
402
|
train=True,
|
|
407
403
|
augment=not no_augment,
|
|
404
|
+
use_multiprocessing=False,
|
|
405
|
+
workers=n_processes,
|
|
408
406
|
)
|
|
409
407
|
callbacks = []
|
|
410
408
|
|
|
@@ -421,9 +419,14 @@ def run(
|
|
|
421
419
|
|
|
422
420
|
if not no_save_checkpoints:
|
|
423
421
|
if save_weights:
|
|
424
|
-
filepath = str(
|
|
422
|
+
filepath = str(
|
|
423
|
+
output_dir
|
|
424
|
+
/ ("weight" + base_checkpoint_file_name + ".weights.h5")
|
|
425
|
+
)
|
|
425
426
|
else:
|
|
426
|
-
filepath = str(
|
|
427
|
+
filepath = str(
|
|
428
|
+
output_dir / ("model" + base_checkpoint_file_name + ".keras")
|
|
429
|
+
)
|
|
427
430
|
|
|
428
431
|
checkpoints = ModelCheckpoint(
|
|
429
432
|
filepath,
|
|
@@ -432,25 +435,26 @@ def run(
|
|
|
432
435
|
callbacks.append(checkpoints)
|
|
433
436
|
|
|
434
437
|
if save_progress:
|
|
435
|
-
|
|
436
|
-
csv_logger = CSVLogger(
|
|
438
|
+
csv_filepath = str(output_dir / "training.csv")
|
|
439
|
+
csv_logger = CSVLogger(csv_filepath)
|
|
437
440
|
callbacks.append(csv_logger)
|
|
438
441
|
|
|
439
442
|
logger.info("Beginning training.")
|
|
443
|
+
# Keras 3.0: `use_multiprocessing` input is set in the
|
|
444
|
+
# `training_generator` (False by default)
|
|
440
445
|
model.fit(
|
|
441
446
|
training_generator,
|
|
442
447
|
validation_data=validation_generator,
|
|
443
|
-
use_multiprocessing=False,
|
|
444
448
|
epochs=epochs,
|
|
445
449
|
callbacks=callbacks,
|
|
446
450
|
)
|
|
447
451
|
|
|
448
452
|
if save_weights:
|
|
449
453
|
logger.info("Saving model weights")
|
|
450
|
-
model.save_weights(
|
|
454
|
+
model.save_weights(output_dir / "model.weights.h5")
|
|
451
455
|
else:
|
|
452
456
|
logger.info("Saving model")
|
|
453
|
-
model.save(output_dir / "model.
|
|
457
|
+
model.save(output_dir / "model.keras")
|
|
454
458
|
|
|
455
459
|
logger.info(
|
|
456
460
|
"Finished training, " "Total time taken: %s",
|
cellfinder/napari/curation.py
CHANGED
|
@@ -253,8 +253,9 @@ def detect_widget() -> FunctionGui:
|
|
|
253
253
|
max_cluster_size: int,
|
|
254
254
|
classification_options,
|
|
255
255
|
skip_classification: bool,
|
|
256
|
-
trained_model: Optional[Path],
|
|
257
256
|
use_pre_trained_weights: bool,
|
|
257
|
+
trained_model: Optional[Path],
|
|
258
|
+
batch_size: int,
|
|
258
259
|
misc_options,
|
|
259
260
|
start_plane: int,
|
|
260
261
|
end_plane: int,
|
|
@@ -298,6 +299,8 @@ def detect_widget() -> FunctionGui:
|
|
|
298
299
|
should be attempted
|
|
299
300
|
use_pre_trained_weights : bool
|
|
300
301
|
Select to use pre-trained model weights
|
|
302
|
+
batch_size : int
|
|
303
|
+
How many points to classify at one time
|
|
301
304
|
skip_classification : bool
|
|
302
305
|
If selected, the classification step is skipped and all cells from
|
|
303
306
|
the detection stage are added
|
|
@@ -372,7 +375,10 @@ def detect_widget() -> FunctionGui:
|
|
|
372
375
|
if use_pre_trained_weights:
|
|
373
376
|
trained_model = None
|
|
374
377
|
classification_inputs = ClassificationInputs(
|
|
375
|
-
skip_classification,
|
|
378
|
+
skip_classification,
|
|
379
|
+
use_pre_trained_weights,
|
|
380
|
+
trained_model,
|
|
381
|
+
batch_size,
|
|
376
382
|
)
|
|
377
383
|
|
|
378
384
|
if analyse_local:
|
|
@@ -114,6 +114,7 @@ class ClassificationInputs(InputContainer):
|
|
|
114
114
|
skip_classification: bool = False
|
|
115
115
|
use_pre_trained_weights: bool = True
|
|
116
116
|
trained_model: Optional[Path] = Path.home()
|
|
117
|
+
batch_size: int = 64
|
|
117
118
|
|
|
118
119
|
def as_core_arguments(self) -> dict:
|
|
119
120
|
args = super().as_core_arguments()
|
|
@@ -131,6 +132,7 @@ class ClassificationInputs(InputContainer):
|
|
|
131
132
|
skip_classification=dict(
|
|
132
133
|
value=cls.defaults()["skip_classification"]
|
|
133
134
|
),
|
|
135
|
+
batch_size=dict(value=cls.defaults()["batch_size"]),
|
|
134
136
|
)
|
|
135
137
|
|
|
136
138
|
|
|
@@ -72,10 +72,10 @@ class Worker(WorkerBase):
|
|
|
72
72
|
def classify_callback(batch: int) -> None:
|
|
73
73
|
self.update_progress_bar.emit(
|
|
74
74
|
"Classifying cells",
|
|
75
|
-
# Default cellfinder-core batch size is
|
|
75
|
+
# Default cellfinder-core batch size is 64. This seems to give
|
|
76
76
|
# a slight underestimate of the number of batches though, so
|
|
77
77
|
# allow for batch number to go over this
|
|
78
|
-
max(self.npoints_detected //
|
|
78
|
+
max(self.npoints_detected // 64 + 1, batch + 1),
|
|
79
79
|
batch + 1,
|
|
80
80
|
)
|
|
81
81
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cellfinder
|
|
3
|
-
Version: 1.2.0
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: Automated 3D cell detection in large microscopy images
|
|
5
5
|
Author-email: "Adam Tyson, Christian Niedworok, Charly Rousseau" <code@adamltyson.com>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -18,6 +18,7 @@ Classifier: Programming Language :: Python
|
|
|
18
18
|
Classifier: Programming Language :: Python :: 3
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.9
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
21
22
|
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
22
23
|
Requires-Python: >=3.9
|
|
23
24
|
Description-Content-Type: text/markdown
|
|
@@ -31,10 +32,10 @@ Requires-Dist: numba
|
|
|
31
32
|
Requires-Dist: numpy
|
|
32
33
|
Requires-Dist: scikit-image
|
|
33
34
|
Requires-Dist: scikit-learn
|
|
35
|
+
Requires-Dist: keras >=3.0.0
|
|
36
|
+
Requires-Dist: torch >=2.1.0
|
|
34
37
|
Requires-Dist: tifffile
|
|
35
38
|
Requires-Dist: tqdm
|
|
36
|
-
Requires-Dist: tensorflow <2.12.0,>=2.5.0 ; platform_system != "Darwin" or platform_machine != "arm64"
|
|
37
|
-
Requires-Dist: tensorflow-macos <2.12.0,>=2.5.0 ; platform_system == "Darwin" and platform_machine == "arm64"
|
|
38
39
|
Provides-Extra: dev
|
|
39
40
|
Requires-Dist: black ; extra == 'dev'
|
|
40
41
|
Requires-Dist: pre-commit ; extra == 'dev'
|
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
cellfinder/__init__.py,sha256=
|
|
1
|
+
cellfinder/__init__.py,sha256=S5oQ3EORuyQTMYC4uUuzGKZ23J3Ya6q-1DOBib1KfiA,1166
|
|
2
2
|
cellfinder/cli_migration_warning.py,sha256=gPtNrtnXvWpl5q0k_EGAQZg0DwcpCmuBTgpg56n5NfA,1578
|
|
3
3
|
cellfinder/core/__init__.py,sha256=pRFuQsl78HEK0S6gvhJw70QLbjjSBzP-GFO0AtVaGtk,62
|
|
4
|
-
cellfinder/core/main.py,sha256=
|
|
4
|
+
cellfinder/core/main.py,sha256=t2mkq6iieEytbPckehBB43juwN5E-vhzstLSs620vdM,3625
|
|
5
5
|
cellfinder/core/types.py,sha256=lTqWE4v0SMM0qLAZJdyAzqV1nLgDtobEpglNJcXt160,106
|
|
6
6
|
cellfinder/core/classify/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
cellfinder/core/classify/augment.py,sha256=8dMbM7KhimM6NMgdMC53oHoCfYj1CIB-h3Yk8CZAxPw,6321
|
|
8
|
-
cellfinder/core/classify/classify.py,sha256=
|
|
9
|
-
cellfinder/core/classify/cube_generator.py,sha256=
|
|
10
|
-
cellfinder/core/classify/resnet.py,sha256=
|
|
11
|
-
cellfinder/core/classify/tools.py,sha256=
|
|
8
|
+
cellfinder/core/classify/classify.py,sha256=33ZvNDmgVabH_6p4jV9Xi8bLKwDZnVhzlImrL_TlnBk,3269
|
|
9
|
+
cellfinder/core/classify/cube_generator.py,sha256=jC5aogTVy213PHouViSR9CgKkuOks3yk6csQC5kRoOE,17125
|
|
10
|
+
cellfinder/core/classify/resnet.py,sha256=vGa85y_NcQnOXwAt5EtatLx5mrO8IoShCcNKtJ5-EFg,10034
|
|
11
|
+
cellfinder/core/classify/tools.py,sha256=s5PEKAsZVbVuoferZ_nqMUM0f2bV_8WEKsdKe3SXEuE,2560
|
|
12
12
|
cellfinder/core/config/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
cellfinder/core/config/cellfinder.conf,sha256=5i8axif7ekMutKDiVnZRs-LiJrgVQljg_beltidqtNk,56
|
|
14
14
|
cellfinder/core/detect/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -31,30 +31,29 @@ cellfinder/core/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
|
|
|
31
31
|
cellfinder/core/tools/array_operations.py,sha256=LDbqWA_N2YtxeKS877QYdJRL2FCMC1R1ExJtdb6-vEA,3371
|
|
32
32
|
cellfinder/core/tools/geometry.py,sha256=xlEQQmVQ9jcXqRzUqU554P8VxkaWkc5J_YhjkkKlI0Q,1124
|
|
33
33
|
cellfinder/core/tools/image_processing.py,sha256=z27bGjf3Iv3G4Nt1OYzpEnIYQNc4nNomj_QitqvZB78,2269
|
|
34
|
-
cellfinder/core/tools/prep.py,sha256=
|
|
34
|
+
cellfinder/core/tools/prep.py,sha256=YU4VDyjEsOXn3S3gymEfPUzka-WYiHmR4Q4XdRebhDI,2175
|
|
35
35
|
cellfinder/core/tools/source_files.py,sha256=vvwsIMe1ULKvXg_x22L75iqvCyMjEbUjJsDFQk3GYzY,856
|
|
36
|
-
cellfinder/core/tools/system.py,sha256=
|
|
37
|
-
cellfinder/core/tools/tf.py,sha256=jNU97-gIB51rLube9EoZ9bREkbqWRk-xPZhDTXG_jE4,1713
|
|
36
|
+
cellfinder/core/tools/system.py,sha256=WvEzPr7v-ohLDnzf4n13TMcN5OYIAXOEkaSmrHzdnwc,2438
|
|
38
37
|
cellfinder/core/tools/tiff.py,sha256=NzIz6wq2GzxmcIhawFMwZADe-uQO2rIG46H7xkpGKLs,2899
|
|
39
38
|
cellfinder/core/tools/tools.py,sha256=G8oDGNRuWkzEJDnnC4r3SNGgpVbqbelCZR5ODk9JRzs,4867
|
|
40
39
|
cellfinder/core/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
41
|
-
cellfinder/core/train/train_yml.py,sha256=
|
|
40
|
+
cellfinder/core/train/train_yml.py,sha256=9QXv2wk24G8hYskMnBij7OngEELUWySK2fH4NFbYWw4,13260
|
|
42
41
|
cellfinder/napari/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
43
|
-
cellfinder/napari/curation.py,sha256=
|
|
42
|
+
cellfinder/napari/curation.py,sha256=nbxCwY2bhEPM15Wf3S_Ff8qGdLSWojr4X48mmAJqD3U,21447
|
|
44
43
|
cellfinder/napari/input_container.py,sha256=tkm0dkPt7kSL8Xkvs1fh8M6vKWw57QLIt_wv74HFkGk,2150
|
|
45
44
|
cellfinder/napari/napari.yaml,sha256=WMR1CIAmYIVyQngbdbomTRZLvlDbb6LxsXsvTRClQnE,921
|
|
46
45
|
cellfinder/napari/sample_data.py,sha256=oUST23q09MM8dxHbUCmO0AjtXG6OlR_32LLqP0EU2UA,732
|
|
47
46
|
cellfinder/napari/utils.py,sha256=AwTs76M9azutHhHj2yuaKErDEQ5F6pFbIIakBfzen6M,3824
|
|
48
47
|
cellfinder/napari/detect/__init__.py,sha256=BD9Bg9NTAr6yRTq2A_p58U6j4w5wbY0sdXwhPJ3MSMY,34
|
|
49
|
-
cellfinder/napari/detect/detect.py,sha256=
|
|
50
|
-
cellfinder/napari/detect/detect_containers.py,sha256=
|
|
51
|
-
cellfinder/napari/detect/thread_worker.py,sha256=
|
|
48
|
+
cellfinder/napari/detect/detect.py,sha256=VB3SLZvqJhjuypjpaZuV9JscQ-sE8yVHG8RzEsZWfeA,13809
|
|
49
|
+
cellfinder/napari/detect/detect_containers.py,sha256=j9NTsIyyDNrhlI2dc7hvc7QlxvI1NRHlCe137v7fsPg,5467
|
|
50
|
+
cellfinder/napari/detect/thread_worker.py,sha256=PWM3OE-FpK-dpdhaE_Gi-2lD3u8sL-SJ13mp0pMhTyI,3078
|
|
52
51
|
cellfinder/napari/train/__init__.py,sha256=xo4CK-DvSecInGEc2ohcTgQYlH3iylFnGvKTCoq2WkI,35
|
|
53
52
|
cellfinder/napari/train/train.py,sha256=zJY7zKcLqDTDtD76thmbwViEU4tTFCmXZze-zHsTpoo,5941
|
|
54
53
|
cellfinder/napari/train/train_containers.py,sha256=1wZ_GPe7B5XsLYs5XIx4m8GMw5KeVhg6SchhPtXu4V8,4386
|
|
55
|
-
cellfinder-1.2.0.dist-info/LICENSE,sha256=Tw8iMytIDXLSmcIUsbQmRWojstl9yOWsPCx6ZT6dZLY,1564
|
|
56
|
-
cellfinder-1.
|
|
57
|
-
cellfinder-1.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
58
|
-
cellfinder-1.2.0.dist-info/entry_points.txt,sha256=cKKjU8GPiN-TRelG2sT2JCKAcB9XDCjP6g9atE9pSoY,247
|
|
59
|
-
cellfinder-1.2.0.dist-info/top_level.txt,sha256=jyTQzX-tDjbsMr6s-E71Oy0IKQzmHTXSk4ZhpG5EDSE,11
|
|
60
|
-
cellfinder-1.2.0.dist-info/RECORD,,
|
|
54
|
+
cellfinder-1.3.0.dist-info/LICENSE,sha256=Tw8iMytIDXLSmcIUsbQmRWojstl9yOWsPCx6ZT6dZLY,1564
|
|
55
|
+
cellfinder-1.3.0.dist-info/METADATA,sha256=b40Zf7cFJV6nv3HxqXCtJubKshRQmXNpXaPx0vVrFtI,6528
|
|
56
|
+
cellfinder-1.3.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
57
|
+
cellfinder-1.3.0.dist-info/entry_points.txt,sha256=cKKjU8GPiN-TRelG2sT2JCKAcB9XDCjP6g9atE9pSoY,247
|
|
58
|
+
cellfinder-1.3.0.dist-info/top_level.txt,sha256=jyTQzX-tDjbsMr6s-E71Oy0IKQzmHTXSk4ZhpG5EDSE,11
|
|
59
|
+
cellfinder-1.3.0.dist-info/RECORD,,
|
cellfinder/core/tools/tf.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
import tensorflow as tf
|
|
2
|
-
|
|
3
|
-
from cellfinder.core import logger
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def allow_gpu_memory_growth():
|
|
7
|
-
"""
|
|
8
|
-
If a gpu is present, prevent tensorflow from using all the memory straight
|
|
9
|
-
away. Allows multiple processes to use the GPU (and avoid occasional
|
|
10
|
-
errors on some systems) at the cost of a slight performance penalty.
|
|
11
|
-
"""
|
|
12
|
-
gpus = tf.config.experimental.list_physical_devices("GPU")
|
|
13
|
-
if gpus:
|
|
14
|
-
logger.debug("Allowing GPU memory growth")
|
|
15
|
-
try:
|
|
16
|
-
# Currently, memory growth needs to be the same across GPUs
|
|
17
|
-
for gpu in gpus:
|
|
18
|
-
tf.config.experimental.set_memory_growth(gpu, True)
|
|
19
|
-
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
|
|
20
|
-
logger.debug(
|
|
21
|
-
f"{len(gpus)} physical GPUs, {len(logical_gpus)} logical GPUs"
|
|
22
|
-
)
|
|
23
|
-
except RuntimeError as e:
|
|
24
|
-
# Memory growth must be set before GPUs have been initialized
|
|
25
|
-
print(e)
|
|
26
|
-
else:
|
|
27
|
-
logger.debug("No GPUs found, using CPU.")
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def set_tf_threads(max_threads):
|
|
31
|
-
"""
|
|
32
|
-
Limit the number of threads that tensorflow uses
|
|
33
|
-
:param max_threads: Maximum number of threads to use
|
|
34
|
-
:return:
|
|
35
|
-
"""
|
|
36
|
-
logger.debug(
|
|
37
|
-
f"Setting maximum number of threads for tensorflow "
|
|
38
|
-
f"to: {max_threads}"
|
|
39
|
-
)
|
|
40
|
-
|
|
41
|
-
# If statements are for testing. If tf is initialised, then setting these
|
|
42
|
-
# parameters throws an error
|
|
43
|
-
if tf.config.threading.get_inter_op_parallelism_threads() != 0:
|
|
44
|
-
tf.config.threading.set_inter_op_parallelism_threads(max_threads)
|
|
45
|
-
if tf.config.threading.get_intra_op_parallelism_threads() != 0:
|
|
46
|
-
tf.config.threading.set_intra_op_parallelism_threads(max_threads)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|