pcntoolkit 1.1.2__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/PKG-INFO +5 -3
  2. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/__init__.py +2 -1
  3. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/dataio/fileio.py +14 -17
  4. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/dataio/norm_data.py +137 -59
  5. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/basis_function.py +116 -79
  6. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/likelihood.py +10 -12
  7. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/shash.py +27 -7
  8. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/thrive.py +11 -1
  9. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/warp.py +15 -3
  10. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/normative_model.py +167 -7
  11. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/blr.py +35 -4
  12. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/evaluator.py +20 -13
  13. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/output.py +25 -3
  14. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/plotter.py +134 -72
  15. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/runner.py +44 -19
  16. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/PKG-INFO +5 -3
  17. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/requires.txt +3 -1
  18. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pyproject.toml +8 -3
  19. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/LICENSE +0 -0
  20. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/README.md +0 -0
  21. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/dataio/__init__.py +0 -0
  22. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/dataio/data_factory.py +0 -0
  23. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/__init__.py +0 -0
  24. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/factorize.py +0 -0
  25. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/prior.py +0 -0
  26. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/math_functions/scaler.py +0 -0
  27. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/normative.py +0 -0
  28. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/__init__.py +0 -0
  29. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/factory.py +0 -0
  30. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/hbr.py +0 -0
  31. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/regression_model.py +0 -0
  32. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/regression_model/test_model.py +0 -0
  33. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/__init__.py +0 -0
  34. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/job_observer.py +0 -0
  35. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/model_comparison.py +0 -0
  36. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit/util/paths.py +0 -0
  37. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/SOURCES.txt +0 -0
  38. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/dependency_links.txt +0 -0
  39. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/entry_points.txt +0 -0
  40. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/pcntoolkit.egg-info/top_level.txt +0 -0
  41. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/setup.cfg +0 -0
  42. {pcntoolkit-1.1.2 → pcntoolkit-1.2.0}/test/test_normative.py +0 -0
@@ -1,10 +1,10 @@
  Metadata-Version: 2.4
  Name: pcntoolkit
- Version: 1.1.2
+ Version: 1.2.0
  Summary: Predictive Clinical Neuroscience Toolkit
  Author: Andre Marquand, Stijn de Boer, Seyed Mostafa Kia, Saige Rutherford, Charlotte Fraza, Barbora Rehák Bučková, Pieter Barkema, Thomas Wolfers, Mariam Zabihi, Richard Dinga, Johanna Bayer, Maarten Mennes, Hester Huijsdens, Linden Parkes, Pierre Berthet
  License-Expression: GPL-3.0-only
- Requires-Python: <3.13,>=3.10
+ Requires-Python: <3.13,>=3.11
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: nibabel>=5.3.1
@@ -15,11 +15,13 @@ Requires-Dist: scipy>=1.12
  Requires-Dist: matplotlib>=3.9.2
  Requires-Dist: seaborn>=0.13.2
  Requires-Dist: numba>=0.60.0
- Requires-Dist: nutpie>=0.13.2
+ Requires-Dist: nutpie>=0.16.5
  Requires-Dist: joblib>=1.4.2
  Requires-Dist: dill>=0.3.9
  Requires-Dist: ipywidgets>=8.1.5
  Requires-Dist: ipykernel>=6.29.5
+ Requires-Dist: dask>=2025.11.0
+ Requires-Dist: filelock>=3.13.0
  Provides-Extra: dev
  Requires-Dist: toml; extra == "dev"
  Requires-Dist: sphinx-tabs>=3.4.7; extra == "dev"
@@ -1,6 +1,6 @@
  from .dataio.data_factory import load_fcon1000
  from .dataio.norm_data import NormData
- from .math_functions.basis_function import BsplineBasisFunction, LinearBasisFunction, PolynomialBasisFunction
+ from .math_functions.basis_function import BsplineBasisFunction, LinearBasisFunction, PolynomialBasisFunction, CompositeBasisFunction
  from .math_functions.likelihood import BetaLikelihood, NormalLikelihood, SHASHbLikelihood
  from .math_functions.prior import make_prior
  from .normative_model import NormativeModel
@@ -16,6 +16,7 @@ __all__ = [
      "BsplineBasisFunction",
      "LinearBasisFunction",
      "PolynomialBasisFunction",
+     "CompositeBasisFunction",
      "NormativeModel",
      "BLR",
      "HBR",
@@ -3,6 +3,7 @@ from __future__ import print_function
  import os
  import re
  import shutil
+ import subprocess
  import sys
  import tempfile

@@ -259,8 +260,8 @@ def load_cifti(filename, vol=False, mask=None, rmtmp=True):
      Output.print(Messages.EXTRACTING_CIFTI_SURFACE_DATA, outstem=outstem)
      giinamel = outstem + "-left.func.gii"
      giinamer = outstem + "-right.func.gii"
-     os.system("wb_command -cifti-separate " + filename + " COLUMN -metric CORTEX_LEFT " + giinamel)
-     os.system("wb_command -cifti-separate " + filename + " COLUMN -metric CORTEX_RIGHT " + giinamer)
+     subprocess.run(["wb_command", "-cifti-separate", filename, "COLUMN", "-metric", "CORTEX_LEFT", giinamel], check=True)
+     subprocess.run(["wb_command", "-cifti-separate", filename, "COLUMN", "-metric", "CORTEX_RIGHT", giinamer], check=True)

      # load the surface data
      giil = nib.load(giinamel)
@@ -284,7 +285,7 @@ def load_cifti(filename, vol=False, mask=None, rmtmp=True):
      if vol:
          niiname = outstem + "-vol.nii"
          Output.print(Messages.EXTRACTING_CIFTI_VOLUME_DATA, niiname=niiname)
-         os.system("wb_command -cifti-separate " + filename + " COLUMN -volume-all " + niiname)
+         subprocess.run(["wb_command", "-cifti-separate", filename, "COLUMN", "-volume-all", niiname], check=True)
          vol = load_nifti(niiname, vol=True)
          volmask = create_mask(vol)
          out = np.concatenate((out, vol2vec(vol, volmask)), axis=0)
@@ -331,8 +332,8 @@ def save_cifti(data, filename, example, mask=None, vol=True, volatlas=None):
      estem = os.path.join(tempfile.gettempdir(), str(os.getpid()) + "-" + fstem)
      giiexnamel = estem + "-left.func.gii"
      giiexnamer = estem + "-right.func.gii"
-     os.system("wb_command -cifti-separate " + example + " COLUMN -metric CORTEX_LEFT " + giiexnamel)
-     os.system("wb_command -cifti-separate " + example + " COLUMN -metric CORTEX_RIGHT " + giiexnamer)
+     subprocess.run(["wb_command", "-cifti-separate", example, "COLUMN", "-metric", "CORTEX_LEFT", giiexnamel], check=True)
+     subprocess.run(["wb_command", "-cifti-separate", example, "COLUMN", "-metric", "CORTEX_RIGHT", giiexnamer], check=True)

      # write left hemisphere
      giiexl = nib.load(giiexnamel)
@@ -359,7 +360,7 @@ def save_cifti(data, filename, example, mask=None, vol=True, volatlas=None):
      # process volumetric data
      if vol:
          niiexname = estem + "-vol.nii"
-         os.system("wb_command -cifti-separate " + example + " COLUMN -volume-all " + niiexname)
+         subprocess.run(["wb_command", "-cifti-separate", example, "COLUMN", "-volume-all", niiexname], check=True)
          niivol = load_nifti(niiexname, vol=True)
          if mask is None:
              mask = create_mask(niivol)
@@ -373,17 +374,13 @@ def save_cifti(data, filename, example, mask=None, vol=True, volatlas=None):

      # write cifti
      fname = fstem + ".dtseries.nii"
-     os.system(
-         "wb_command -cifti-create-dense-timeseries "
-         + fname
-         + " -volume "
-         + fnamev
-         + " "
-         + volatlas
-         + " -left-metric "
-         + fnamel
-         + " -right-metric "
-         + fnamer
+     subprocess.run(
+         [
+             "wb_command", "-cifti-create-dense-timeseries",
+             fname, "-volume", fnamev, volatlas,
+             "-left-metric", fnamel, "-right-metric", fnamer,
+         ],
+         check=True,
      )

      # clean up
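
All fileio.py changes in this release follow one pattern: shell-string os.system calls are replaced with argument-list subprocess.run calls. Passing a list avoids shell quoting and injection problems with paths containing spaces, and check=True turns a non-zero wb_command exit code into a raised exception instead of a silently ignored return value. A minimal sketch of the pattern (wb_command is the HCP Workbench CLI and must be on PATH; the file names are illustrative):

    import subprocess

    filename = "example.dtseries.nii"   # illustrative input
    giinamel = "example-left.func.gii"  # illustrative output

    # 1.1.2 style: one shell string, return code ignored, unquoted paths.
    # os.system("wb_command -cifti-separate " + filename + " COLUMN -metric CORTEX_LEFT " + giinamel)

    # 1.2.0 style: argument list, no shell; raises subprocess.CalledProcessError on failure.
    subprocess.run(
        ["wb_command", "-cifti-separate", filename, "COLUMN", "-metric", "CORTEX_LEFT", giinamel],
        check=True,
    )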
@@ -11,7 +11,7 @@ is used by all the models in the toolkit.
  from __future__ import annotations

  import copy
- import fcntl
+ import json
  import os
  from collections import defaultdict
  from functools import reduce
@@ -32,20 +32,21 @@ from typing import (

  # pylint: enable=deprecated-class
  import numpy as np
- from numpy.typing import ArrayLike
  import pandas as pd  # type: ignore
  import xarray as xr
  from nibabel.loadsave import load
+ from numpy.typing import ArrayLike
+ from scipy import stats
  from sklearn.model_selection import StratifiedKFold, train_test_split  # type: ignore

  # import datavars from xarray
  from xarray.core.types import DataVars

+ from filelock import FileLock
+
  from pcntoolkit.dataio.fileio import load
  from pcntoolkit.util.output import Messages, Output, Warnings

- from scipy import stats
-

  class NormData(xr.Dataset):
      """A class for handling normative modeling data, extending xarray.Dataset.
@@ -212,8 +213,7 @@ class NormData(xr.Dataset):
          NormData
              An instance of NormData.
          """
-         img = load(fsl_folder)
-         dat = img.get_fdata()
+         raise NotImplementedError("from_fsl is not yet implemented.")

      @classmethod
      def from_bids(cls, bids_folder, config_params) -> "NormData":  # type: ignore
@@ -232,6 +232,7 @@ class NormData(xr.Dataset):
          NormData
              An instance of NormData.
          """
+         raise NotImplementedError("from_bids is not yet implemented.")

      @classmethod
      def from_xarray(cls, name: str, xarray_dataset: xr.Dataset) -> NormData:
@@ -257,6 +258,35 @@ class NormData(xr.Dataset):
              xarray_dataset.attrs,
          )

+     @classmethod
+     def from_netcdf(cls, name: str, netcdf_path: str) -> NormData:
+         """
+         Load a normative dataset from a netcdf file.
+
+         Parameters
+         ----------
+         name: str
+             The name of the dataset.
+         netcdf_path: str
+             The path to the netcdf file.
+
+         Returns
+         -------
+         NormData
+             An instance of NormData.
+         """
+         xr_dset = xr.open_dataset(netcdf_path)
+
+         # Deserialize the attributes.
+         for attr in xr_dset.attrs:
+             if attr in xr_dset.attrs:
+                 xr_dset.attrs[attr] = json.loads(xr_dset.attrs[attr])
+
+         if "batch_effect_counts" in xr_dset.attrs and xr_dset.attrs["batch_effect_counts"]:
+             # Convert the batch_effect_counts to a defaultdict
+             xr_dset.attrs["batch_effect_counts"] = defaultdict(lambda: 0, xr_dset.attrs["batch_effect_counts"])
+         return cls.from_xarray(name=name, xarray_dataset=xr_dset)
+
      # pylint: disable=arguments-differ
      @classmethod
      def from_dataframe(  # type:ignore
@@ -292,7 +322,7 @@ class NormData(xr.Dataset):
          attrs : Mapping[str, Any] | None, optional
              Additional attributes for the dataset, by default None.
          remove_Nan: bool
-             Wheter or not to remove NAN values from the dataframe before creationg of the class object. By default False
+             Whether or not to remove NAN values from the dataframe before creating the class object. By default False

          Returns
          -------
@@ -358,6 +388,26 @@ class NormData(xr.Dataset):
              attrs,
          )

+     def to_netcdf(self, netcdf_path: str) -> None:
+         """
+         Save the NormData object to a netcdf file.
+
+         Parameters
+         ----------
+         netcdf_path: str
+             The path to the netcdf file.
+
+         Returns
+         -------
+         None
+         """
+         ds = self.copy(deep=False)
+         # Serialize the attributes using json so that they can be saved to netcdf
+         for attr in ds.attrs:
+             if attr in ds.attrs:
+                 ds.attrs[attr] = json.dumps(ds.attrs[attr])
+         xr.Dataset.to_netcdf(ds, netcdf_path, invalid_netcdf=False, format="NETCDF4")
+
      @classmethod
      def remove_nan(cls, dataframe: pd.DataFrame) -> pd.DataFrame:
          """
@@ -367,7 +417,6 @@ class NormData(xr.Dataset):
          Output.print(f"Removed {len(dataframe) - len(cleaned)} NANs")
          return cleaned

-
      @classmethod
      def remove_outliers(cls, dataframe: pd.DataFrame, continuous_vars: List[str], z_threshold: float = 3.0) -> pd.DataFrame:
          """
@@ -385,7 +434,6 @@ class NormData(xr.Dataset):
          Output.print(f"Removed {np.sum(~idx)} outliers")
          return dataframe.loc[idx]

-
      def merge(self, other: NormData, name: str | None = None) -> NormData:
          """
          Merge two NormData objects.
@@ -643,41 +691,71 @@ class NormData(xr.Dataset):
          B = self.select_batch_effects(names[1], batch_effects, invert=True)
          return A, B

+     def has_registered_metadata(self) -> bool:
+         """
+         Check if the batch effect and covariate metadata have been registered and are non-empty.
+
+         Returns
+         -------
+         bool
+             True if all required metadata attributes exist and are not empty, False otherwise.
+         """
+         required_attrs = [
+             "unique_batch_effects",
+             "batch_effect_counts",
+             "covariate_ranges",
+             "batch_effect_covariate_ranges",
+         ]
+
+         for attr in required_attrs:
+             # Check if attribute exists and is not an empty dict/defaultdict
+             if attr not in self.attrs or not self.attrs[attr]:
+                 return False
+
+         return True
+
      def register_batch_effects(self) -> None:
          """
          Create a mapping of batch effects to unique values.
          """
+         if self.has_registered_metadata():
+             return
          my_be: xr.DataArray = self.batch_effects
          # create a dictionary with for each column in the batch effects, a dict from value to int
          self.attrs["unique_batch_effects"] = {}
          self.attrs["batch_effect_counts"] = defaultdict(lambda: 0)
          self.attrs["covariate_ranges"] = {}
-
-         # TODO: the following can be done much easier using df.groupby.min and xarray.unstack, but that is a TODO for another day. This works for now.
          self.attrs["batch_effect_covariate_ranges"] = {}
-         for dim in self.batch_effect_dims.to_numpy():
-             dim_subset = my_be.sel(batch_effect_dims=dim)
-             uniques, counts = np.unique(dim_subset, return_counts=True)

-             self.attrs["unique_batch_effects"][dim] = list(uniques)
-             self.attrs["batch_effect_counts"][dim] = {k: int(v) for k, v in zip(uniques, counts)}
+         # Vectorized implementation using pandas groupby/agg
+         be_cols = self.batch_effect_dims.to_numpy()
+         be_df = pd.DataFrame(my_be.values, columns=be_cols)
+
+         x_available = "X" in self.data_vars
+         if x_available:
+             covs = self.covariates.to_numpy()
+             X_df = pd.DataFrame(self.X.values, columns=covs)
+
+         for dim in be_cols:
+             vc = be_df[dim].value_counts(sort=False)
+             self.attrs["unique_batch_effects"][dim] = vc.index.astype(str).tolist()
+             self.attrs["batch_effect_counts"][dim] = {str(k): int(v) for k, v in vc.to_dict().items()}
              self.attrs["batch_effect_covariate_ranges"][dim] = {}
-             if self.X is not None:
-                 for u in uniques:
-                     self.attrs["batch_effect_covariate_ranges"][dim][u] = {}
-                     for c in self.covariates.to_numpy():
-                         u_mask = dim_subset.values == u
-                         my_c = self.X.sel(covariates=c).values[u_mask]
-                         my_min = my_c.min()
-                         my_max = my_c.max()
-                         my_mean = my_c.mean()
-                         self.attrs["batch_effect_covariate_ranges"][dim][u][c] = {"mean": my_mean, "min": my_min, "max": my_max}
-         for c in self.covariates.to_numpy():
-             my_c = self.X.sel(covariates=c).values
-             my_mean = my_c.mean()
-             my_min = my_c.min()
-             my_max = my_c.max()
-             self.attrs["covariate_ranges"][c] = {"mean": my_mean, "min": my_min, "max": my_max}
+
+             if x_available:
+                 grouped = X_df.groupby(be_df[dim], sort=False).agg(["min", "max"])
+                 for u, row in grouped.iterrows():
+                     self.attrs["batch_effect_covariate_ranges"][dim][u] = {
+                         c: {"min": float(row[(c, "min")]), "max": float(row[(c, "max")])} for c in covs
+                     }
+
+         if x_available:
+             overall = X_df.agg(["min", "max"])
+             for c in covs:
+                 self.attrs["covariate_ranges"][c] = {
+                     "min": float(overall.loc["min", c]),
+                     "max": float(overall.loc["max", c]),
+                 }

      def check_compatibility(self, other: NormData) -> bool:
          """
@@ -735,7 +813,6 @@ class NormData(xr.Dataset):
              for cov in self.covariates.to_numpy()
          }

-
          mybecr = self.batch_effect_covariate_ranges
          otbecr = other.batch_effect_covariate_ranges
          nbecr = {}
@@ -757,6 +834,7 @@ class NormData(xr.Dataset):
              case False, False:
                  raise ValueError("This should never happen")

+         # Update instance attributes
          self.unique_batch_effects = copy.deepcopy(all_unique_batch_effects)
          other.unique_batch_effects = copy.deepcopy(all_unique_batch_effects)
          self.covariate_ranges = copy.deepcopy(ncr)
@@ -764,6 +842,14 @@ class NormData(xr.Dataset):
          self.batch_effect_covariate_ranges = copy.deepcopy(nbecr)
          other.batch_effect_covariate_ranges = copy.deepcopy(nbecr)

+         # Update the xarray attrs dicts to keep them in sync.
+         self.attrs["unique_batch_effects"] = copy.copy(self.unique_batch_effects)
+         other.attrs["unique_batch_effects"] = copy.copy(other.unique_batch_effects)
+         self.attrs["covariate_ranges"] = copy.copy(self.covariate_ranges)
+         other.attrs["covariate_ranges"] = copy.copy(other.covariate_ranges)
+         self.attrs["batch_effect_covariate_ranges"] = copy.copy(self.batch_effect_covariate_ranges)
+         other.attrs["batch_effect_covariate_ranges"] = copy.copy(other.batch_effect_covariate_ranges)
+
      def scale_forward(self, inscalers: Dict[str, Any], outscalers: Dict[str, Any]) -> None:
          """
          Scale the data forward in-place using provided scalers.
@@ -1036,12 +1122,12 @@ class NormData(xr.Dataset):
          zdf = self.Z.to_dataframe().unstack(level="response_vars")
          zdf.columns = zdf.columns.droplevel(0)
          zdf = zdf.merge(self.subject_ids.to_dataframe(), on="observations", how="left")
-         zdf = zdf[[ "subject_ids", *[z for z in sorted(zdf.columns.tolist()) if z not in ["subject_ids"]]]]
+         zdf = zdf[["subject_ids", *[z for z in sorted(zdf.columns.tolist()) if z not in ["subject_ids"]]]]
          zdf.index = zdf.index.astype(str)
          res_path = os.path.join(save_dir, f"Z_{self.name}.csv")
-         with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
-             try:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+         lock_path = res_path + ".lock"
+         with FileLock(lock_path):
+             with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
                  f.seek(0)
                  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
                  if old_results is not None:
@@ -1067,8 +1153,6 @@ class NormData(xr.Dataset):
                  )
                  new_results.index = new_results.index.astype(str)
                  new_results.to_csv(f)
-             finally:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_UN)

      def load_zscores(self, save_dir) -> None:
          Z_path = os.path.join(save_dir, f"Z_{self.name}.csv")
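
Every save_* method swaps fcntl.flock for filelock.FileLock (hence the new filelock runtime dependency): a sibling .lock file serializes access to the CSV, the lock is released automatically when the with block exits, even on exceptions, and unlike fcntl it also works on Windows. A minimal sketch of the pattern, with the result-merging step simplified relative to the actual NormData code:

    import os

    import pandas as pd
    from filelock import FileLock

    res_path = "Z_example.csv"  # illustrative path
    new_results = pd.DataFrame({"resp1": [0.1, -0.3]}, index=["sub-01", "sub-02"])

    # One .lock file per CSV; concurrent writers (e.g. parallel batch jobs) queue here.
    with FileLock(res_path + ".lock"):
        with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
            f.seek(0)
            old = pd.read_csv(f, index_col=0) if os.path.getsize(res_path) > 0 else None
            if old is not None:
                # Simplified merge: prefer new values, keep old rows/columns otherwise.
                new_results = new_results.combine_first(old)
            f.seek(0)
            f.truncate()
            new_results.to_csv(f)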
@@ -1089,15 +1173,15 @@ class NormData(xr.Dataset):
          subject_ids.index = subject_ids.index.astype(str)
          subject_ids.columns = pd.MultiIndex.from_tuples([("subject_ids", "X")], names=["subject_ids", "centile"])
          for c in self.centile.to_numpy():
-             subject_ids[("subject_ids", c)] = subject_ids[("subject_ids","X")]
+             subject_ids[("subject_ids", c)] = subject_ids[("subject_ids", "X")]
          subject_ids = subject_ids.drop(columns=[("subject_ids", "X")])
          subject_ids = subject_ids.stack(level="centile")
-         centiles = centiles.merge(subject_ids, on=["observations","centile"], how="left")
-         centiles = centiles[[ "subject_ids", *[z for z in sorted(centiles.columns.tolist()) if z not in ["subject_ids"]]]]
+         centiles = centiles.merge(subject_ids, on=["observations", "centile"], how="left")
+         centiles = centiles[["subject_ids", *[z for z in sorted(centiles.columns.tolist()) if z not in ["subject_ids"]]]]
          res_path = os.path.join(save_dir, f"centiles_{self.name}.csv")
-         with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
-             try:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+         lock_path = res_path + ".lock"
+         with FileLock(lock_path):
+             with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
                  f.seek(0)
                  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
                  if old_results is not None:
@@ -1123,8 +1207,6 @@ class NormData(xr.Dataset):
                  )
                  # new_results.index = new_results.index.astype(str)
                  new_results.to_csv(f)
-             finally:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_UN)

      def load_centiles(self, save_dir) -> None:
          C_path = os.path.join(save_dir, f"centiles_{self.name}.csv")
@@ -1137,7 +1219,7 @@ class NormData(xr.Dataset):
          A = np.zeros((len(centiles), len(obs), len(response_vars)))
          for i, c in enumerate(centiles):
              sub = df[df["centile"] == c]
-             sub.sort_values(by="observations")
+             sub = sub.sort_values(by="observations")
              for j, rv in enumerate(response_vars):
                  A[i, :, j] = sub[rv]

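This one-line fix in load_centiles addresses a classic pandas pitfall: DataFrame.sort_values returns a new sorted frame rather than sorting in place, so the unassigned call in 1.1.2 was a no-op and the centile rows could be written into the output array in their original order. For example:

    import pandas as pd

    sub = pd.DataFrame({"observations": [2, 0, 1], "resp1": [0.3, 0.1, 0.2]})

    sub.sort_values(by="observations")        # result discarded; sub is unchanged
    sub = sub.sort_values(by="observations")  # correct: rebind the sorted copy
    # sub.sort_values(by="observations", inplace=True) would also work
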
@@ -1151,12 +1233,12 @@ class NormData(xr.Dataset):
          logp = self.logp.to_dataframe().unstack(level="response_vars")
          logp.columns = logp.columns.droplevel(0)
          logp = logp.merge(self.subject_ids.to_dataframe(), on="observations", how="left")
-         logp = logp[[ "subject_ids", *[z for z in sorted(logp.columns.tolist()) if z not in ["subject_ids"]]]]
+         logp = logp[["subject_ids", *[z for z in sorted(logp.columns.tolist()) if z not in ["subject_ids"]]]]
          logp.index = logp.index.astype(str)
          res_path = os.path.join(save_dir, f"logp_{self.name}.csv")
-         with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
-             try:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+         lock_path = res_path + ".lock"
+         with FileLock(lock_path):
+             with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
                  f.seek(0)
                  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
                  if old_results is not None:
@@ -1182,8 +1264,6 @@ class NormData(xr.Dataset):
                  )
                  new_results.index = new_results.index.astype(str)
                  new_results.to_csv(f)
-             finally:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_UN)

      def load_logp(self, save_dir) -> None:
          logp_path = os.path.join(save_dir, f"logp_{self.name}.csv")
@@ -1200,9 +1280,9 @@ class NormData(xr.Dataset):
          mdf = self.statistics.to_dataframe().unstack(level="response_vars")
          mdf.columns = mdf.columns.droplevel(0)
          res_path = os.path.join(save_dir, f"statistics_{self.name}.csv")
-         with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
-             try:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+         lock_path = res_path + ".lock"
+         with FileLock(lock_path):
+             with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
                  f.seek(0)
                  old_results = pd.read_csv(f, index_col=0) if os.path.getsize(res_path) > 0 else None
                  if old_results is not None:
@@ -1215,8 +1295,6 @@ class NormData(xr.Dataset):
                  f.seek(0)
                  f.truncate()
                  new_results.to_csv(f)
-             finally:
-                 fcntl.flock(f.fileno(), fcntl.LOCK_UN)

      def load_statistics(self, save_dir) -> None:
          logp_path = os.path.join(save_dir, f"statistics_{self.name}.csv")