lbm_caiman_python 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
# This module is meant to be imported before any other library, so that the
# environment variable below is already set when those libraries initialise.
import os

# NOTE(review): presumably silences TensorFlow's C++ log output ("3" being the
# most restrictive level) pulled in by a dependency — confirm which one.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
@@ -0,0 +1,569 @@
1
+ import os
2
+ import shutil
3
+ import sys
4
+ import tempfile
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Iterable
8
+
9
+ import caiman as cm
10
+ import cv2
11
+
12
+ import numpy as np
13
+
14
+ import pandas as pd
15
+ import scipy
16
+ import tifffile
17
+ from tqdm import tqdm
18
+
19
+ from .batch import load_batch
20
+
21
# Default parameter names that `_params_from_df` copies from each batch item's
# params['main'] into the summary table.
SUMMARY_PARAMS = ("K", "gSig", "gSig_filt", "min_SNR", "rval_thr")
28
+
29
+
30
def get_all_batch_items(files: list, algo="all") -> pd.DataFrame:
    """
    Load every successfully completed item from a list of batch .pickle files.

    Each batch file is loaded with `load_batch`, tagged with a ``batch_path``
    column pointing back to the file it came from, and filtered so that only
    rows whose ``outputs`` is a dict with a truthy ``"success"`` entry remain.
    Files that fail to load are reported on stderr and skipped.

    Parameters
    ----------
    files : list
        List of .pickle batch files to load.
    algo : str, optional
        Algorithm to filter by. Default is "all", which keeps every algorithm.
        Any other value (e.g. "cnmf", "cnmfe", "mcorr") keeps only rows whose
        ``algo`` column equals it exactly.

    Returns
    -------
    df : DataFrame
        DataFrame containing the selected rows. Index labels are taken from the
        source batch dataframes, so they may repeat across files.
    """
    selected_rows = []
    for file in files:
        try:
            df = load_batch(file)
            df.paths.set_batch_path(file)
            df['batch_path'] = file
        except Exception as e:
            print(f"Error loading {file}: {e}", file=sys.stderr)
            continue

        for _, row in df.iterrows():
            outputs = row["outputs"]
            # Only keep finished, successful runs. This rejects None (never run),
            # dicts without a truthy "success", and any other placeholder value
            # (e.g. NaN) that downstream `.get("success")` calls would choke on.
            if not (isinstance(outputs, dict) and outputs.get("success")):
                continue
            if algo == "all" or row["algo"] == algo:
                selected_rows.append(row)
    return pd.DataFrame(selected_rows)
67
+
68
+
69
def get_summary_batch(df) -> pd.DataFrame:
    """
    Count successful and unsuccessful runs for motion correction and CNMF.

    Rows with ``algo == 'mcorr'`` are summarised under 'mcorr'. Rows whose
    ``algo`` is 'cnmf' or 'cnmfe' are summarised together under 'cnmf'. Rows
    with any other ``algo`` are ignored. Success is counted with
    `_num_successful_from_df`.

    Parameters
    ----------
    df : pd.DataFrame
        Batch dataframe containing the columns `item_name`, `algo` and `outputs`.

    Returns
    -------
    summary_df : pd.DataFrame
        Two rows ('mcorr' and 'cnmf') with the columns 'algo', 'Runs',
        'Successful' and 'Unsuccessful'.

    Raises
    ------
    ValueError
        If `df` is empty or lacks any of the required columns.
    """
    if df.empty:
        raise ValueError("Input DataFrame is empty.")
    elif 'item_name' not in df.columns:
        raise ValueError("Input DataFrame does not have an 'item_name' column.")
    # The counts below read `algo` and `outputs`; validate them up front so a
    # malformed frame fails with a clear ValueError rather than AttributeError.
    missing = [col for col in ('algo', 'outputs') if col not in df.columns]
    if missing:
        raise ValueError(f"Input DataFrame is missing required column(s): {missing}.")

    mcorr_df = df[df.algo == 'mcorr']
    # cnmfe runs are reported together with cnmf under the 'cnmf' label.
    cnmf_df = df[df.algo.isin(['cnmf', 'cnmfe'])]
    succ_mcorr = _num_successful_from_df(mcorr_df)
    succ_cnmf = _num_successful_from_df(cnmf_df)

    return pd.DataFrame([
        {'algo': 'mcorr', 'Runs': len(mcorr_df), 'Successful': succ_mcorr,
         'Unsuccessful': len(mcorr_df) - succ_mcorr},
        {'algo': 'cnmf', 'Runs': len(cnmf_df), 'Successful': succ_cnmf,
         'Unsuccessful': len(cnmf_df) - succ_cnmf},
    ])
103
+
104
+
105
def get_summary_cnmf(df: pd.DataFrame) -> pd.DataFrame:
    """
    Summarise the CNMF / CNMFE runs contained in a batch DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with at least the columns `batch_path`, `uuid` and `algo`
        (for example the output of `get_all_batch_items`). Only rows whose
        `algo` is "cnmf" or "cnmfe" are kept.

    Returns
    -------
    pd.DataFrame
        The cnmf/cnmfe rows with trace counts ("Total Traces", "Accepted",
        "Rejected") and the parameters in `SUMMARY_PARAMS` added as columns.
    """
    with_traces = _num_traces_from_df(df)
    return _params_from_df(with_traces)
116
+
117
+
118
def get_summary_mcorr(df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute (or reuse) motion-correction metrics for a batch and tabulate them.

    Parameters
    ----------
    df : pd.DataFrame
        Batch DataFrame handed to `compute_mcorr_metrics_batch`. `overwrite` is
        left at its default, so metric files that already exist are reused.

    Returns
    -------
    pd.DataFrame
        One row per metrics file, as built by `_create_df_from_metric_files`.
    """
    metric_files = compute_mcorr_metrics_batch(df)
    summary = _create_df_from_metric_files(metric_files)
    return summary
121
+
122
+
123
def concat_param_diffs(input_df, param_diffs):
    """
    Add parameter differences to a metrics DataFrame.

    Parameters
    ----------
    input_df : DataFrame
        Metrics DataFrame (e.g. from `_create_df_from_metric_files`) with the
        columns 'mean_corr', 'mean_norm', 'crispness', 'batch_index', 'uuid' and
        'metric_path'. Rows whose 'batch_index' is -1 (raw data) receive no
        parameter values. `input_df` itself is not modified.
    param_diffs : DataFrame
        Parameter differences, one row per batch item. Row ``i`` (by position)
        is matched to the metrics row whose 'batch_index' is ``i``.

    Returns
    -------
    DataFrame
        A new DataFrame with one column per parameter in `param_diffs`. Columns
        are ordered as the three metric columns, then the parameter columns,
        then 'batch_index', 'uuid' and 'metric_path'. Parameter cells stay None
        for rows whose 'batch_index' has no matching row in `param_diffs`.

    Examples
    --------
    >>> import pandas as pd
    >>> import lbm_caiman_python as lcp
    >>> import lbm_mc as mc
    >>> batch_df = mc.load_batch('path/to/batch.pickle')
    >>> metrics_files = lcp.summary.compute_mcorr_metrics_batch(batch_df)
    >>> metrics_df = lcp.summary._create_df_from_metric_files(metrics_files)
    >>> param_diffs = batch_df.caiman.get_params_diffs("mcorr", item_name=batch_df.iloc[0]["item_name"])
    >>> final_df = lcp.concat_param_diffs(metrics_df, param_diffs)
    >>> print(final_df.head())
    """
    # Work on a copy so the caller's DataFrame does not silently gain columns.
    result = input_df.copy()

    param_cols = list(param_diffs.columns)
    # add an empty column for each param diff
    for col in param_cols:
        if col not in result.columns:
            result[col] = None

    for i, row in result.iterrows():
        batch_index = int(row['batch_index'])
        # Raw data is stored with batch_index -1. Any negative value must be
        # skipped: iloc would otherwise silently index from the end.
        if not 0 <= batch_index < len(param_diffs):
            continue
        param_diff = param_diffs.iloc[batch_index]
        for col in param_cols:
            result.at[i, col] = param_diff[col]

    return result[
        ['mean_corr', 'mean_norm', 'crispness']
        + param_cols
        + ['batch_index', 'uuid', 'metric_path']
    ]
175
+
176
+
177
+ def _create_df_from_metric_files(metrics_filepaths: Iterable[str | Path]) -> pd.DataFrame:
178
+ """
179
+ Create a DataFrame from a list of metrics files.
180
+
181
+ Parameters
182
+ ----------
183
+ metrics_filepaths : list of str or Path
184
+ List of paths to the metrics files (.npz) containing 'correlations', 'norms',
185
+ 'smoothness', 'flows', and the batch item UUID.
186
+ Typically, use the output of `compute_batch_metrics` to get the list of metrics files.
187
+
188
+ Returns
189
+ -------
190
+ metrics_df : DataFrame
191
+ A DataFrame containing the mean correlation, mean norm, crispness, UUID, batch index, and metric path.
192
+
193
+ Examples
194
+ --------
195
+ >>> import pandas as pd
196
+ >>> import lbm_caiman_python as lcp
197
+ >>> import lbm_mc as mc
198
+ >>> batch_df = mc.load_batch('path/to/batch.pickle')
199
+ >>> # overwrite=False will not recompute metrics if they already exist
200
+ >>> metrics_files = lcp.summary.compute_mcorr_metrics_batch(batch_df, overwrite=False)
201
+ >>> metrics_df = lcp._create_df_from_metric_files(metrics_files)
202
+ >>> print(metrics_df.head())
203
+ """
204
+ metrics_list = []
205
+ for i, file in enumerate(metrics_filepaths):
206
+ with np.load(file) as f:
207
+ corr = f['correlations']
208
+ norms = f['norms']
209
+ crispness = f['smoothness_corr']
210
+ uuid = f['uuid']
211
+ batch_index = f['batch_id']
212
+ metrics_list.append({
213
+ 'mean_corr': np.mean(corr),
214
+ 'mean_norm': np.mean(norms),
215
+ 'crispness': float(crispness),
216
+ 'uuid': str(uuid),
217
+ 'batch_index': int(batch_index),
218
+ 'metric_path': file
219
+ })
220
+ return pd.DataFrame(metrics_list)
221
+
222
+
223
def compute_mcorr_metrics_batch(batch_df: pd.DataFrame, overwrite: bool = False) -> Iterable[Path]:
    """
    Compute and store registration metrics for every motion-correction item in a batch.

    For each row with ``algo == 'mcorr'``, the motion-corrected movie is passed
    to `_compute_metrics`, which writes ``<output stem>_metrics.npz`` next to it.

    Additionally attempts to compute metrics for the raw input movie of the
    first row, if its path resolves. The path resolves if:
    1. The raw data file is found in the batch path.
    2. The raw data file is found in the global parent directory set via `mc.set_parent_raw_data_path()`.

    Parameters
    ----------
    batch_df : DataFrame
        Batch DataFrame. Rows must support the mesmerize-core accessors
        ``.caiman.get_input_movie_path()``, ``.mcorr.get_output()`` and
        ``.mcorr.get_output_path()``.
    overwrite : bool, optional
        If True, recompute and overwrite existing metric files. Default is False.

    Returns
    -------
    metrics_paths : list of Path
        Paths of the metrics files. The raw-data file comes first (if handled),
        then one file per mcorr item whose metrics already existed or were
        computed successfully. Items whose computation fails are reported and
        omitted. An empty `batch_df` gives an empty list.

    Raises
    ------
    FileNotFoundError
        If the raw movie path resolves but the file does not exist.

    Examples
    --------
    >>> import pandas as pd
    >>> import lbm_caiman_python as lcp
    >>> import lbm_mc as mc
    >>> batch_df = mc.load_batch('path/to/batch.pickle')
    >>> metrics_paths = lcp.compute_mcorr_metrics_batch(batch_df)
    >>> print(metrics_paths)
    [Path('path/to/metrics1.npz'), Path('path/to/metrics2.npz'), ...]

    TODO: This can be made to run in parallel.
    """
    metrics_paths = []
    if batch_df.empty:
        # No row to resolve a raw movie from and nothing to process.
        return metrics_paths

    # raw_filename will be resolved if:
    # 1. It is located in the batch dir
    # 2. It is located in the global parent dir
    try:
        raw_filename = batch_df.iloc[0].caiman.get_input_movie_path()
    except AttributeError:
        print('Skipping raw data metrics computation. '
              'Could not find raw data file. '
              'Make sure to call mc.set_parent_raw_data_path(data_path) before calling this function.')
        raw_filename = None

    if raw_filename is not None:
        raw_filename = Path(raw_filename)
        if not raw_filename.exists():
            raise FileNotFoundError(f"Raw data file {raw_filename} not found.")

        raw_metrics_path = get_metrics_path(raw_filename)
        if raw_metrics_path.exists() and not overwrite:
            print(f"Raw metrics file {raw_metrics_path} already exists. Skipping. To overwrite, set `overwrite=True`.")
        else:
            if raw_metrics_path.exists():
                print(f"Overwriting raw metrics file {raw_metrics_path}.")
                raw_metrics_path.unlink(missing_ok=True)

            start = time.perf_counter()
            raw_metrics_path = _compute_raw_mcorr_metrics(raw_filename, overwrite=overwrite)
            print(f'Computed metrics for raw data in {time.perf_counter() - start:.2f} seconds.')

        metrics_paths.append(raw_metrics_path)

    for i, row in batch_df.iterrows():
        print(f'Processing batch index {i}...')

        if row.algo != 'mcorr':
            print(f"Skipping batch index {i} as algo is not 'mcorr'.")
            continue

        output_path = row.mcorr.get_output_path()
        metrics_path = get_metrics_path(output_path)

        # Existing metrics are reused unless overwriting was requested.
        if metrics_path.exists():
            if not overwrite:
                print(f"Metrics file {metrics_path} already exists. Skipping. To overwrite, set `overwrite=True`.")
                metrics_paths.append(metrics_path)
                continue
            print(f"Overwriting metrics file {metrics_path}.")
            metrics_path.unlink(missing_ok=True)

        try:
            start = time.perf_counter()
            # The corrected movie is only needed for its spatial shape, which sets
            # the crop used by _compute_metrics. Loading it inside the try means a
            # broken item is reported and skipped instead of aborting the batch.
            final_size = row.mcorr.get_output().shape[1:]
            _compute_metrics(output_path, row.uuid, i, final_size[0], final_size[1])

            print(f'Computed metrics for batch index {i} in {time.perf_counter() - start:.2f} seconds.')
            metrics_paths.append(metrics_path)
        except Exception as e:
            print(f"Failed to compute metrics for batch index {i}. Error: {e}")

    return metrics_paths
321
+
322
+
323
+ def _num_traces_from_df(df: pd.DataFrame) -> pd.DataFrame:
324
+ """
325
+ Add trace-related columns to a DataFrame for specific algorithms.
326
+
327
+ Filters the DataFrame to include rows where the `algo` column contains
328
+ either "cnmf" or "cnmfe", then adds the following columns if they
329
+ do not already exist:
330
+ - "Total Traces": Total number of temporal components.
331
+ - "Accepted": Number of accepted components.
332
+ - "Rejected": Number of rejected components.
333
+
334
+ Parameters
335
+ ----------
336
+ df : pd.DataFrame
337
+ DataFrame containing the columns `batch_path`, `uuid`, and `algo`.
338
+
339
+ Returns
340
+ -------
341
+ pd.DataFrame
342
+ DataFrame with the added trace-related columns, updated for rows
343
+ with `algo` values of "cnmf" or "cnmfe". For other rows, the new
344
+ columns are left as `None`.
345
+ """
346
+ # Safely add new columns with default values of None
347
+ df = df[df["algo"].isin(["cnmf", "cnmfe"])]
348
+
349
+ add_cols = ["Total Traces", "Accepted", "Rejected"]
350
+ for col in add_cols:
351
+ if col not in df.columns:
352
+ df[col] = None
353
+
354
+ for idx, row in df.iterrows():
355
+ batch_df = load_batch(row["batch_path"]) # Ensure access using correct key
356
+ item = batch_df[batch_df.uuid == row["uuid"]].iloc[0]
357
+ if row["algo"] in ("cnmf", "cnmfe"):
358
+ df.at[idx, "Total Traces"] = item.cnmf.get_temporal().shape[0]
359
+ df.at[idx, "Accepted"] = len(item.cnmf.get_output().estimates.idx_components)
360
+ df.at[idx, "Rejected"] = len(item.cnmf.get_output().estimates.idx_components_bad)
361
+ else:
362
+ df.at[idx, "Total Traces"] = None
363
+ df.at[idx, "Accepted"] = None
364
+ df.at[idx, "Rejected"] = None
365
+
366
+ return df
367
+
368
+
369
+ def _params_from_df(df: pd.DataFrame, params: tuple | list | None = None):
370
+ """
371
+ Add specified parameters to a DataFrame from a batch DataFrame.
372
+
373
+ Parameters
374
+ ----------
375
+ df : pd.DataFrame
376
+ DataFrame containing the columns `batch_path`, `uuid`, and `algo`.
377
+
378
+ params : tuple or list, optional
379
+ List of parameter names to add to the DataFrame.
380
+ If not provided, defaults to `SUMMARY_PARAMS`, which includes:
381
+ - "K"
382
+ - "gSig"
383
+ - "gSig_filt"
384
+ - "min_SNR"
385
+ - "rval_thr"
386
+
387
+ Returns
388
+ -------
389
+ pd.DataFrame
390
+ DataFrame with the specified parameters added as columns. The values
391
+ are extracted from the corresponding batch file for each row.
392
+ """
393
+ if params is None:
394
+ params = SUMMARY_PARAMS
395
+ for col in params:
396
+ if col not in df.columns:
397
+ df[col] = None
398
+ for idx, row in df.iterrows():
399
+ batch_df = load_batch(row.batch_path)
400
+ item = batch_df[batch_df.uuid == row.uuid].iloc[0]
401
+ for param in params:
402
+ value = item.params['main'].get(param)
403
+ # Handle iterable values
404
+ if isinstance(value, (list, tuple, np.ndarray)):
405
+ df.at[idx, param] = str(value) # Store as a string
406
+ else:
407
+ df.at[idx, param] = value
408
+ return df
409
+
410
+
411
+ def _num_successful_from_df(df: pd.DataFrame) -> int:
412
+ """ Count the number of successful runs in a DataFrame with outputs."""
413
+ return len(df[df.outputs.apply(lambda x: x.get("success"))])
414
+
415
+
416
def get_metrics_path(fname: Path) -> Path:
    """
    Return the location of the metrics file that belongs to a data file.

    The metrics file lives in the same directory as the data file. Its name is
    the data file's stem followed by ``_metrics`` and the ``.npz`` extension,
    e.g. ``movie.tif`` -> ``movie_metrics.npz``.

    Parameters
    ----------
    fname : Path
        The path to the input data file (a str is accepted too).

    Returns
    -------
    metrics_path : Path
        The path to the computed metrics file.
    """
    data_path = Path(fname)
    renamed = data_path.with_stem(f"{data_path.stem}_metrics")
    return renamed.with_suffix(".npz")
434
+
435
+
436
def _compute_metrics(fname, uuid, batch_id, final_size_x, final_size_y, swap_dim=False, pyr_scale=.5, levels=3,
                     winsize=100, iterations=15, poly_n=5, poly_sigma=1.2 / 5, flags=0,
                     resize_fact_flow=.2, template=None, gSig_filt=None):
    """
    Compute registration-quality metrics for a movie and save them next to it.

    Parameters
    ----------
    fname : str or Path
        Movie file, loaded with ``cm.load``.
    uuid : str
        Identifier stored in the metrics file. Must be truthy.
    batch_id : int or None
        Batch row index stored in the metrics file. None, 'None' and 'nan' are
        stored as -1.
    final_size_x, final_size_y : int
        Spatial size the movie is cropped to before computing metrics.
    swap_dim : bool, optional
        Forwarded to ``local_correlations``.
    pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags
        Forwarded to ``cv2.calcOpticalFlowFarneback``.
    resize_fact_flow : float, optional
        Third factor passed to ``m.resize(1, 1, resize_fact_flow)`` before the
        optical-flow pass. NOTE(review): presumably the temporal axis — confirm.
    template : array, optional
        Reference image. If None, ``cm.motion_correction.bin_median`` of the
        cropped movie is used.
    gSig_filt : optional
        If given, the movie is first passed through
        ``cm.motion_correction.high_pass_filter_space``.

    Returns
    -------
    Path
        Path of the written metrics file, ``get_metrics_path(fname)``. It holds
        the keys 'uuid', 'batch_id', 'flows', 'norms', 'correlations',
        'smoothness', 'tmpl', 'smoothness_corr' and 'img_corr'.

    Raises
    ------
    ValueError
        If `uuid` is empty or the cropped movie contains NaN values.
    """
    if not uuid:
        raise ValueError("UUID must be provided.")

    m = cm.load(fname)
    if gSig_filt is not None:
        m = cm.motion_correction.high_pass_filter_space(m, gSig_filt)

    # Crop to a final_size_x x final_size_y window. ceil((size - final) / 2)
    # pixels are dropped from the start of each spatial axis; the stop index
    # keeps exactly `final` pixels after that.
    size_x, size_y = np.shape(m)[1], np.shape(m)[2]
    max_shft_x = int(np.ceil((size_x - final_size_x) / 2))
    max_shft_y = int(np.ceil((size_y - final_size_y) / 2))
    max_shft_x_1 = - ((size_x - max_shft_x) - final_size_x)
    max_shft_y_1 = - ((size_y - max_shft_y) - final_size_y)
    # A stop index of 0 would select nothing; None means "up to the end".
    if max_shft_x_1 == 0:
        max_shft_x_1 = None
    if max_shft_y_1 == 0:
        max_shft_y_1 = None
    m = m[:, max_shft_x:max_shft_x_1, max_shft_y:max_shft_y_1]
    if np.isnan(m).any():
        raise ValueError('Movie contains NaN')

    img_corr = m.local_correlations(eight_neighbours=True, swap_dim=swap_dim)
    tmpl = cm.motion_correction.bin_median(m) if template is None else template

    # Gradient magnitude of the mean image and of the local-correlation image.
    smoothness = np.sqrt(
        np.sum(np.sum(np.array(np.gradient(np.mean(m, 0))) ** 2, 0)))
    smoothness_corr = np.sqrt(
        np.sum(np.sum(np.array(np.gradient(img_corr)) ** 2, 0)))

    # Pearson correlation of every frame with the template.
    sys.stdout.flush()
    tmpl_flat = tmpl.flatten()
    correlations = [
        scipy.stats.pearsonr(fr.flatten(), tmpl_flat)[0]
        for fr in tqdm(m, desc="Correlations")
    ]

    # Dense optical flow from the template to every (resized) frame.
    m = m.resize(1, 1, resize_fact_flow)
    norms = []
    flows = []
    sys.stdout.flush()
    for fr in tqdm(m, desc="Optical flow"):
        flow = cv2.calcOpticalFlowFarneback(
            tmpl, fr, None, pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags)
        flows.append(flow)
        norms.append(np.linalg.norm(flow))

    # cast to numpy-loadable primitives, handle variable cases of None
    uuid = str(uuid) if uuid not in [None, 'None', 'nan'] else 'None'
    batch_id = int(batch_id) if batch_id not in [None, 'None', 'nan'] else -1

    # Save where callers look the file up, so the naming cannot drift apart.
    metrics_path = get_metrics_path(fname)
    np.savez(
        metrics_path,
        uuid=uuid,
        batch_id=batch_id,
        flows=flows,
        norms=norms,
        correlations=correlations,
        smoothness=smoothness,
        tmpl=tmpl,
        smoothness_corr=smoothness_corr,
        img_corr=img_corr
    )
    return metrics_path
510
+
511
+
512
def _compute_raw_mcorr_metrics(raw_fname: Path, overwrite=False) -> Path:
    """
    Compute `_compute_metrics` for a raw (not motion-corrected) TIFF movie.

    The raw movie is memory-mapped with ``tifffile.memmap`` and copied into a
    temporary TIFF. Metrics are computed on that copy over the full frame (no
    cropping) with a fresh ``raw_<uuid4>`` identifier and no batch index. The
    resulting metrics file is then moved to ``get_metrics_path(raw_fname)``.

    Parameters
    ----------
    raw_fname : Path
        The path to the raw data file. Must be a TIFF file.
    overwrite : bool, optional
        If True, recompute the metrics even if the file already exists. Default is False.

    Returns
    -------
    final_metrics_path : Path
        The path to the computed metrics file.

    Raises
    ------
    FileNotFoundError
        If `_compute_metrics` did not produce the expected temporary metrics file.

    Notes
    -----
    The final metrics file contains the keys written by `_compute_metrics`:
    - 'correlations': Pearson correlation of each frame with the template.
    - 'flows': optical-flow fields from the template to each frame.
    - 'norms': norm of each flow field.
    - 'smoothness' and 'smoothness_corr': gradient magnitudes of the mean
      image and of the local-correlation image.
    - 'tmpl' and 'img_corr': the template and the local-correlation image.
    - 'uuid' and 'batch_id' (stored as -1 for raw data).
    """
    import uuid

    final_metrics_path = get_metrics_path(raw_fname)
    if final_metrics_path.exists() and not overwrite:
        return final_metrics_path

    data = tifffile.memmap(raw_fname)
    # NOTE(review): assumes a 3D (frames, y, x) movie — confirm for multi-plane data.
    size_x, size_y = data.shape[1], data.shape[2]

    if final_metrics_path.exists():
        final_metrics_path.unlink()

    # make a new uuid with raw_{uuid}
    raw_uuid = f'raw_{uuid.uuid4()}'

    with tempfile.NamedTemporaryFile(suffix='.tiff', delete=False) as temp_file:
        temp_path = Path(temp_file.name)
    temp_metrics_path = get_metrics_path(temp_path)

    try:
        tifffile.imwrite(temp_path, data)
        # The copy is written; drop the memmap instead of holding it for the
        # whole (long) metrics computation.
        del data
        _compute_metrics(temp_path, raw_uuid, None, size_x, size_y, swap_dim=False)

        if not temp_metrics_path.exists():
            raise FileNotFoundError(f"Expected metrics file {temp_metrics_path} not found.")
        shutil.move(temp_metrics_path, final_metrics_path)
    finally:
        temp_path.unlink(missing_ok=True)
        # Remove a leftover temporary metrics file if the move did not happen.
        temp_metrics_path.unlink(missing_ok=True)

    return final_metrics_path
@@ -0,0 +1,87 @@
1
+ import os
2
+ import sys
3
+
4
+ from .exceptions import PipelineException
5
+
6
+
7
def extract_common_key(filepath: os.PathLike):
    """
    Return the file stem of `filepath` without its last underscore-separated token.

    For example ``"dir/plane_01_roi_3.tif"`` gives ``"plane_01_roi"``. A stem
    without any underscore yields an empty string.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to the file. Plain strings are accepted as well as Path objects.

    Returns
    -------
    str
        The common key shared by files that differ only in their last token.
    """
    from pathlib import Path

    # Path() so that plain strings (which have no .stem) work too.
    parts = Path(filepath).stem.split("_")
    return "_".join(parts[:-1])
10
+
11
+
12
class CacheDict(dict):
    """
    A dictionary that prevents itself from growing too much.

    Only assignments through ``d[key] = value`` are bounded. When a new key is
    inserted while the dictionary already holds ``maxentries`` items, about 10%
    of the entries (at least one) are evicted first, in insertion order (oldest
    first). Other mutators such as ``update()`` or ``setdefault()`` bypass this
    check.

    Parameters
    ----------
    maxentries : int
        Maximum number of entries kept via item assignment.
    """

    def __init__(self, maxentries):
        self.maxentries = maxentries
        super().__init__()

    def __setitem__(self, key, value):
        # Protection against growing the cache too much. Overwriting an existing
        # key does not grow the dict, so it never triggers eviction.
        if key not in self and len(self) >= self.maxentries:
            # Integer count: a float slice bound raises TypeError. At least one
            # entry must go, or caches with maxentries < 10 would never shrink.
            entries_to_remove = max(1, self.maxentries // 10)
            for k in list(self)[:entries_to_remove]:
                super().__delitem__(k)
        super().__setitem__(key, value)
29
+
30
+
31
def detect_number_of_cores():
    """
    Detect the number of CPU cores on the system.

    Sources are tried in order:
    1. ``os.sysconf("SC_NPROCESSORS_ONLN")`` (online processors on Linux, Unix
       and macOS).
    2. The ``NUMBER_OF_PROCESSORS`` environment variable (set on Windows).
    3. ``os.cpu_count()``.

    Returns
    -------
    int
        The first positive count found, or 1 if no source yields one.
    """
    # Linux, Unix and MacOS:
    if hasattr(os, "sysconf") and "SC_NPROCESSORS_ONLN" in os.sysconf_names:
        ncpus = os.sysconf("SC_NPROCESSORS_ONLN")
        if isinstance(ncpus, int) and ncpus > 0:
            return ncpus

    # Windows:
    try:
        ncpus = int(os.environ.get("NUMBER_OF_PROCESSORS", ""))
    except ValueError:
        # Missing or malformed value: fall through instead of crashing.
        ncpus = 0
    if ncpus > 0:
        return ncpus

    # Portable fallback; os.cpu_count() may return None.
    return os.cpu_count() or 1  # Default
47
+
48
+
49
def get_size(obj, seen=None, unit="gb"):
    """
    Recursively estimate the memory footprint of `obj`.

    Byte sizes from ``sys.getsizeof`` are summed over:
    - the object itself;
    - the keys and values of dicts;
    - the ``__dict__`` of objects that have one;
    - the items of any other iterable (str, bytes and bytearray are not
      iterated).

    Each object is counted once, so shared and self-referential objects are
    handled. Note that iterating consumes one-shot iterators such as generators.

    Parameters
    ----------
    obj : object
        Object to measure.
    seen : set of int, optional
        ids of objects already counted; they contribute 0. Updated in place.
    unit : {"gb", "mb", "kb", "b"}, optional
        Output unit (case-insensitive, 1024-based). Default is "gb".

    Returns
    -------
    str
        ``"<size> <unit>"``, where size is the byte total floor-divided by the unit.

    Raises
    ------
    ValueError
        If `unit` is not one of the supported units.
    """
    unit = unit.lower()
    factors = {"gb": 1024 ** 3, "mb": 1024 ** 2, "kb": 1024, "b": 1}
    if unit not in factors:
        raise ValueError("unit must be one of 'gb', 'mb', 'kb', 'b'")
    if seen is None:
        seen = set()
    # Recurse in raw bytes; format only once at the top level (formatted strings
    # returned from recursive calls cannot be summed).
    return f"{_get_size_bytes(obj, seen) // factors[unit]} {unit}"


def _get_size_bytes(obj, seen):
    """Return the recursive size of `obj` in bytes, skipping ids already in `seen`."""
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    # Important mark as seen *before* entering recursion to gracefully handle
    # self-referential objects
    seen.add(obj_id)
    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        size += sum(_get_size_bytes(v, seen) for v in obj.values())
        size += sum(_get_size_bytes(k, seen) for k in obj.keys())
    elif hasattr(obj, "__dict__"):
        size += _get_size_bytes(obj.__dict__, seen)
    elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)):
        size += sum(_get_size_bytes(i, seen) for i in obj)
    return size
80
+
81
+
82
# Names exported by `from <module> import *`.
# NOTE(review): `extract_common_key` is defined in this module but not listed
# here — confirm whether that is intentional.
__all__ = ['PipelineException', 'detect_number_of_cores', 'CacheDict', 'get_size']
@@ -0,0 +1,6 @@
1
class PipelineException(Exception):
    """
    Base exception for pipeline errors.

    When `info` is truthy, its ``repr`` is appended to the message on a new line
    prefixed with ``Error info:``. `info` is always kept as ``self.info``.
    """

    def __init__(self, message, info=None):
        suffix = ""
        if info:
            suffix = f"\nError info: {info!r}"
        super().__init__(message + suffix)
        self.info = info