pcntoolkit 1.2.0.post1__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/PKG-INFO +14 -3
  2. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/__init__.py +2 -1
  3. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/dataio/data_factory.py +111 -1
  4. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/dataio/norm_data.py +51 -16
  5. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/basis_function.py +208 -10
  6. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/likelihood.py +63 -28
  7. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/prior.py +50 -21
  8. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/scaler.py +35 -7
  9. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/shash.py +18 -32
  10. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/normative_model.py +44 -8
  11. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/blr.py +61 -26
  12. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/hbr.py +12 -2
  13. pcntoolkit-1.3.0/pcntoolkit/util/data_utils.py +80 -0
  14. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/evaluator.py +328 -39
  15. pcntoolkit-1.3.0/pcntoolkit/util/migration.py +271 -0
  16. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/output.py +41 -6
  17. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/plotter.py +295 -119
  18. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit.egg-info/PKG-INFO +14 -3
  19. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit.egg-info/SOURCES.txt +2 -0
  20. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit.egg-info/requires.txt +13 -2
  21. pcntoolkit-1.3.0/pcntoolkit.egg-info/top_level.txt +2 -0
  22. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pyproject.toml +15 -4
  23. pcntoolkit-1.2.0.post1/pcntoolkit.egg-info/top_level.txt +0 -1
  24. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/LICENSE +0 -0
  25. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/README.md +0 -0
  26. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/dataio/__init__.py +0 -0
  27. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/dataio/fileio.py +0 -0
  28. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/__init__.py +0 -0
  29. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/factorize.py +0 -0
  30. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/thrive.py +0 -0
  31. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/math_functions/warp.py +0 -0
  32. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/normative.py +0 -0
  33. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/__init__.py +0 -0
  34. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/factory.py +0 -0
  35. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/regression_model.py +0 -0
  36. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/regression_model/test_model.py +0 -0
  37. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/__init__.py +0 -0
  38. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/autoscale_plot.py +0 -0
  39. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/job_observer.py +0 -0
  40. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/model_comparison.py +0 -0
  41. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/paths.py +0 -0
  42. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit/util/runner.py +0 -0
  43. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit.egg-info/dependency_links.txt +0 -0
  44. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/pcntoolkit.egg-info/entry_points.txt +0 -0
  45. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/setup.cfg +0 -0
  46. {pcntoolkit-1.2.0.post1 → pcntoolkit-1.3.0}/test/test_normative.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pcntoolkit
3
- Version: 1.2.0.post1
3
+ Version: 1.3.0
4
4
  Summary: Predictive Clinical Neuroscience Toolkit
5
5
  Author: Andre Marquand, Stijn de Boer, Seyed Mostafa Kia, Saige Rutherford, Charlotte Fraza, Barbora Rehák Bučková, Pieter Barkema, Thomas Wolfers, Mariam Zabihi, Richard Dinga, Johanna Bayer, Maarten Mennes, Hester Huijsdens, Linden Parkes, Pierre Berthet
6
6
  License-Expression: GPL-3.0-only
@@ -8,20 +8,30 @@ Requires-Python: <3.13,>=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
10
  Requires-Dist: nibabel>=5.3.1
11
- Requires-Dist: pymc>=5.19.1
11
+ Requires-Dist: h5py>=3.11.0
12
+ Requires-Dist: h5netcdf>=1.3.0
13
+ Requires-Dist: pymc<6.0.0,>=5.19.1
14
+ Requires-Dist: pytensor<3.0.0,>=2.22.0
12
15
  Requires-Dist: scikit-learn>=1.5.2
13
16
  Requires-Dist: six>=1.16.0
14
17
  Requires-Dist: scipy>=1.12
15
18
  Requires-Dist: matplotlib>=3.9.2
16
19
  Requires-Dist: seaborn>=0.13.2
17
20
  Requires-Dist: numba>=0.60.0
18
- Requires-Dist: nutpie>=0.16.5
21
+ Requires-Dist: nutpie<0.16.9,>=0.16.5
19
22
  Requires-Dist: joblib>=1.4.2
20
23
  Requires-Dist: dill>=0.3.9
21
24
  Requires-Dist: ipywidgets>=8.1.5
22
25
  Requires-Dist: ipykernel>=6.29.5
26
+ Requires-Dist: ipython>=8.0.0
23
27
  Requires-Dist: dask>=2025.11.0
24
28
  Requires-Dist: filelock>=3.13.0
29
+ Requires-Dist: packaging>=21.3
30
+ Requires-Dist: arviz<1.0.0,>=0.21.0
31
+ Requires-Dist: numpy>=2.0.0
32
+ Requires-Dist: pandas>=2.2.0
33
+ Requires-Dist: xarray>=2024.1.0
34
+ Requires-Dist: cloudpickle>=3.0.0
25
35
  Provides-Extra: dev
26
36
  Requires-Dist: toml; extra == "dev"
27
37
  Requires-Dist: sphinx-tabs>=3.4.7; extra == "dev"
@@ -30,6 +40,7 @@ Requires-Dist: black>=24.10.0; extra == "dev"
30
40
  Requires-Dist: sphinx-rtd-theme>=3.0.2; extra == "dev"
31
41
  Requires-Dist: ruff>=0.8.6; extra == "dev"
32
42
  Requires-Dist: pytest-cov>=6.0.0; extra == "dev"
43
+ Requires-Dist: nbconvert>=7.15.0; extra == "dev"
33
44
  Dynamic: license-file
34
45
 
35
46
  # Predictive Clinical Neuroscience Toolkit
@@ -1,6 +1,6 @@
1
1
  from .dataio.data_factory import load_fcon1000
2
2
  from .dataio.norm_data import NormData
3
- from .math_functions.basis_function import BsplineBasisFunction, LinearBasisFunction, PolynomialBasisFunction, CompositeBasisFunction
3
+ from .math_functions.basis_function import BsplineBasisFunction, LinearBasisFunction, PolynomialBasisFunction, CompositeBasisFunction, FractionalPolynomialBasisFunction
4
4
  from .math_functions.likelihood import BetaLikelihood, NormalLikelihood, SHASHbLikelihood
5
5
  from .math_functions.prior import make_prior
6
6
  from .normative_model import NormativeModel
@@ -14,6 +14,7 @@ __version__ = version("pcntoolkit")
14
14
  __all__ = [
15
15
  "NormData",
16
16
  "BsplineBasisFunction",
17
+ "FractionalPolynomialBasisFunction",
17
18
  "LinearBasisFunction",
18
19
  "PolynomialBasisFunction",
19
20
  "CompositeBasisFunction",
@@ -8,11 +8,26 @@ from pcntoolkit.dataio.norm_data import NormData
8
8
 
9
9
 
10
10
  def load_fcon1000(save_path: str | None = None):
11
- """Download and save fcon dataset to specified path, or load it from there if it is already downloaded"""
11
+ """Download and save fcon dataset to specified path, or load it from there
12
+ if it is already downloaded
13
+
14
+ Parameters
15
+ ----------
16
+ save_path : str | None
17
+ The path to save the dataset to, or load it from if it is already
18
+ downloaded
19
+
20
+ Returns
21
+ -------
22
+ NormData
23
+ The loaded dataset as a NormData object"""
12
24
  if not save_path:
13
25
  save_path = os.path.join("pcntoolkit_resources", "data")
14
26
  os.makedirs(save_path, exist_ok=True)
15
27
  data_path = os.path.join(save_path, "fcon1000.csv")
28
+
29
+ # If the dataset is not already downloaded, download it and save it to
30
+ # the specified path
16
31
  if not os.path.exists(data_path):
17
32
  data = pd.read_csv(
18
33
  "https://raw.githubusercontent.com/predictive-clinical-neuroscience/PCNtoolkit-demo/refs/heads/main/data/fcon1000.csv"
@@ -256,3 +271,98 @@ def load_fcon1000(save_path: str | None = None):
256
271
  remove_Nan=True,
257
272
  )
258
273
  return norm_data
274
+
275
+
276
+ # NOTE: This dataset is not public
277
+ def load_lifespan_big(
278
+ n_response_vars: int | None = None,
279
+ n_largest_sites: int | None = None,
280
+ n_subjects: int | None = None
281
+ ) -> NormData:
282
+ """
283
+ Load the lifespan_big dataset, which is a large lifespan dataset with many sites.
284
+
285
+ Parameters
286
+ ----------
287
+ n_response_vars : int | None
288
+ If specified, only use the first n_response_vars response
289
+ variables.
290
+ n_largest_sites : int | None
291
+ If specified, only keep data from the n_largest_sites largest
292
+ sites.
293
+ n_subjects : int | None
294
+ If specified, randomly sample n_subjects subjects.
295
+
296
+ Returns
297
+ -------
298
+ NormData
299
+ The loaded dataset as a NormData object.
300
+ """
301
+ # Define the variables in the dataset
302
+ subject_ids = ["participant_id"]
303
+ covariates = ["age"]
304
+ batch_effects = ["sex", "site"]
305
+
306
+ # Define the dtypes for loading the dataset, to ensure that categorical
307
+ # variables are loaded as strings and numerical variables as floats
308
+ dtypes = {"participant_id": str, "group": str, "group2": str}
309
+ for col in batch_effects:
310
+ dtypes[col] = str
311
+ for col in covariates:
312
+ dtypes[col] = float
313
+
314
+ # Load the lifespan dataset with 57116 subjects from the Braincharts paper:
315
+ # https://doi.org/10.7554/eLife.72904
316
+ data = pd.read_csv(
317
+ "/project_cephfs/3022017.06/projects/stijdboe/Data/sairut_data/"
318
+ "lifespan_big.csv", dtype=dtypes)
319
+
320
+ # Drop rows where all values are NaN
321
+ data = data.dropna(axis=0, how="all", inplace=False)
322
+ # Drop columns that contain at least one NaN value
323
+ data = data.dropna(axis=1, how="any", inplace=False)
324
+
325
+ data["sex"] = data["sex"].map(
326
+ {"0.0": "Female", "1.0": "Male", "2.0": "Female"})
327
+ data["site"] = data["site_ID"]
328
+
329
+ # If requested, take only the n largest sites
330
+ if n_largest_sites is not None:
331
+ data = data[data["site_ID"].isin(
332
+ data["site_ID"].value_counts().head(n_largest_sites).index)]
333
+
334
+ # If requested, take only n subjects
335
+ if n_subjects is not None:
336
+ data = data.sample(n=n_subjects, replace=False)
337
+
338
+ # Define response variables as all variables that
339
+ # are not in subject_ids, covariates, batch_effects,
340
+ # do not start with "site_", "group" or "race", and have variance > 0
341
+ def is_response_var(col_name: str) -> bool:
342
+ return (
343
+ col_name not in subject_ids
344
+ and col_name not in covariates
345
+ and col_name not in batch_effects
346
+ and not col_name.startswith("site_")
347
+ and not col_name.startswith("group")
348
+ and not col_name.startswith("race")
349
+ and data[col_name].var() > 0
350
+ )
351
+
352
+ response_vars = [col for col in data.columns if is_response_var(col)]
353
+
354
+ # If requested, take only n response variables
355
+ if n_response_vars is not None:
356
+ response_vars = response_vars[:n_response_vars]
357
+
358
+ # Create NormData object
359
+ norm_data = NormData.from_dataframe(
360
+ name="lifespan_big",
361
+ dataframe=data,
362
+ covariates=covariates,
363
+ batch_effects=batch_effects,
364
+ response_vars=response_vars,
365
+ subject_ids=subject_ids,
366
+ )
367
+
368
+ return norm_data
@@ -510,7 +510,7 @@ class NormData(xr.Dataset):
510
510
  new_data_vars["Z"] = (["observations", "response_vars"], new_Z.data)
511
511
 
512
512
  if hasattr(self, "centiles") and hasattr(other, "centiles"):
513
- if self.centile.to_numpy() == other.centile.to_numpy():
513
+ if np.array_equal(self.centile.to_numpy(), other.centile.to_numpy()):
514
514
  new_centiles = xr.DataArray(
515
515
  np.zeros((new_X.shape[0], len(respvar_intersection), len(self.centile.to_numpy()))),
516
516
  dims=["observations", "response_vars", "centile"],
@@ -682,7 +682,23 @@ class NormData(xr.Dataset):
682
682
  names: Optional[Tuple[str, str]],
683
683
  ) -> Tuple[NormData, NormData]:
684
684
  """
685
- Split the data into two datasets, one with the specified batch effects and one without.
685
+ Split the data into two datasets, one with the specified batch effects
686
+ and one without.
687
+
688
+ This is useful when you want to split a dataset into two smaller ones.
689
+
690
+ Parameters
691
+ ----------
692
+ batch_effects : Dict[str, List[str]]
693
+ A dictionary mapping batch effect dimensions to lists of values to
694
+ split on.
695
+ names : Optional[Tuple[str, str]]
696
+ The names for the two splits.
697
+
698
+ Returns
699
+ -------
700
+ Tuple[NormData, NormData]
701
+ A tuple containing the two split NormData instances.
686
702
  """
687
703
  if names is None:
688
704
  names = ["selected", "not_selected"] # type:ignore
@@ -1033,19 +1049,31 @@ class NormData(xr.Dataset):
1033
1049
 
1034
1050
  self.attrs["is_scaled"] = False
1035
1051
 
1036
- def select_batch_effects(self, name, batch_effects: Dict[str, List[str]], invert: bool = False) -> NormData:
1052
+ def select_batch_effects(
1053
+ self,
1054
+ name: str,
1055
+ batch_effects: Dict[str, List[str]],
1056
+ invert: bool = False,
1057
+ ) -> NormData:
1037
1058
  """
1038
- Select only the specified batch effects.
1059
+ Select observations matching (or not matching) batch effects.
1039
1060
 
1040
1061
  Parameters
1041
1062
  ----------
1063
+ name : str
1064
+ Name to assign to the returned ``NormData`` instance.
1042
1065
  batch_effects : Dict[str, List[str]]
1043
- A dictionary specifying which batch effects to select.
1066
+ A dictionary mapping batch effect dimensions to lists of values to
1067
+ select.
1068
+ invert : bool, optional
1069
+ If ``True``, return observations that do *not* match
1070
+ any of the specified batch effect values. Default is ``False``.
1044
1071
 
1045
1072
  Returns
1046
1073
  -------
1047
1074
  NormData
1048
- A NormData instance with the selected batch effects.
1075
+ A NormData instance containing observations matching
1076
+ (or not matching) the specified batch effects.
1049
1077
  """
1050
1078
  mask = np.zeros(self.batch_effects.shape[0], dtype=bool)
1051
1079
  for key, values in batch_effects.items():
@@ -1077,21 +1105,28 @@ class NormData(xr.Dataset):
1077
1105
  """
1078
1106
  acc = []
1079
1107
  x_columns = [col for col in ["X"] if hasattr(self, col)]
1080
- y_columns = [col for col in ["Y", "Y_harmonized", "Z"] if hasattr(self, col)]
1108
+ y_columns = [col for col in ["Y", "Y_harmonized", "Z", "logp", "Yhat"]
1109
+ if hasattr(self, col)]
1081
1110
  acc.append(
1082
1111
  xr.Dataset.to_dataframe(self[x_columns], dim_order)
1083
1112
  .reset_index(drop=False)
1084
- .pivot(index="observations", columns="covariates", values=x_columns)
1113
+ .pivot(index="observations",
1114
+ columns="covariates",
1115
+ values=x_columns)
1085
1116
  )
1086
1117
  acc.append(
1087
1118
  xr.Dataset.to_dataframe(self[y_columns], dim_order)
1088
1119
  .reset_index(drop=False)
1089
- .pivot(index="observations", columns="response_vars", values=y_columns)
1120
+ .pivot(index="observations",
1121
+ columns="response_vars",
1122
+ values=y_columns)
1090
1123
  )
1091
1124
  be = (
1092
1125
  xr.DataArray.to_dataframe(self.batch_effects, dim_order)
1093
1126
  .reset_index(drop=False)
1094
- .pivot(index="observations", columns="batch_effect_dims", values="batch_effects")
1127
+ .pivot(index="observations",
1128
+ columns="batch_effect_dims",
1129
+ values="batch_effects")
1095
1130
  )
1096
1131
  be.columns = [("batch_effects", col) for col in be.columns]
1097
1132
 
@@ -1127,7 +1162,7 @@ class NormData(xr.Dataset):
1127
1162
  res_path = os.path.join(save_dir, f"Z_{self.name}.csv")
1128
1163
  lock_path = res_path + ".lock"
1129
1164
  with FileLock(lock_path):
1130
- with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1165
+ with open(res_path, mode="r+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1131
1166
  f.seek(0)
1132
1167
  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
1133
1168
  if old_results is not None:
@@ -1181,7 +1216,7 @@ class NormData(xr.Dataset):
1181
1216
  res_path = os.path.join(save_dir, f"centiles_{self.name}.csv")
1182
1217
  lock_path = res_path + ".lock"
1183
1218
  with FileLock(lock_path):
1184
- with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1219
+ with open(res_path, mode="r+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1185
1220
  f.seek(0)
1186
1221
  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
1187
1222
  if old_results is not None:
@@ -1238,7 +1273,7 @@ class NormData(xr.Dataset):
1238
1273
  res_path = os.path.join(save_dir, f"logp_{self.name}.csv")
1239
1274
  lock_path = res_path + ".lock"
1240
1275
  with FileLock(lock_path):
1241
- with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1276
+ with open(res_path, mode="r+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1242
1277
  f.seek(0)
1243
1278
  old_results = pd.read_csv(f) if os.path.getsize(res_path) > 0 else None
1244
1279
  if old_results is not None:
@@ -1282,7 +1317,7 @@ class NormData(xr.Dataset):
1282
1317
  res_path = os.path.join(save_dir, f"statistics_{self.name}.csv")
1283
1318
  lock_path = res_path + ".lock"
1284
1319
  with FileLock(lock_path):
1285
- with open(res_path, mode="a+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1320
+ with open(res_path, mode="r+" if os.path.exists(res_path) else "w", encoding="utf-8") as f:
1286
1321
  f.seek(0)
1287
1322
  old_results = pd.read_csv(f, index_col=0) if os.path.getsize(res_path) > 0 else None
1288
1323
  if old_results is not None:
@@ -1336,7 +1371,7 @@ class NormData(xr.Dataset):
1336
1371
 
1337
1372
  This method creates a DataArray with dimensions 'response_vars' and 'statistics',
1338
1373
  where 'response_vars' corresponds to the response variables in the dataset,
1339
- and 'statistics' includes statistics such as Rho, RMSE, SMSE, EXPV, NLL, and ShapiroW.
1374
+ and 'statistics' includes statistics such as Rho, RMSE, SMSE, EXPV, MLL, and ShapiroW.
1340
1375
  The DataArray is filled with NaN values initially.
1341
1376
  """
1342
1377
  rv = self.response_vars.to_numpy().copy().tolist()
@@ -1346,7 +1381,7 @@ class NormData(xr.Dataset):
1346
1381
  dims=("response_vars", "statistics"),
1347
1382
  coords={
1348
1383
  "response_vars": np.arange(len(rv)),
1349
- "statistics": ["Rho", "Rho_p", "R2", "RMSE", "SMSE", "MSLL", "NLL", "ShapiroW", "MACE", "MAPE", "EXPV"],
1384
+ "statistics": ["Rho", "Rho_p", "R2", "RMSE", "SMSE", "MSLL", "MLL", "ShapiroW", "MACE", "MAPE", "EXPV"],
1350
1385
  },
1351
1386
  )
1352
1387
 
@@ -7,12 +7,14 @@ import numpy as np
7
7
  from scipy.interpolate import BSpline
8
8
 
9
9
  from pcntoolkit.util.output import Errors, Output
10
+ from pcntoolkit.util.migration import registry
11
+
10
12
 
11
13
  def create_basis_function(
12
- basis_type: str | dict | None,
13
- basis_column: int = 0,
14
- **kwargs,
15
- ) -> BasisFunction:
14
+ basis_type: str | dict | None,
15
+ basis_column: int = 0,
16
+ **kwargs,
17
+ ) -> BasisFunction:
16
18
  if isinstance(basis_type, dict):
17
19
  return BasisFunction.from_dict(basis_type)
18
20
  elif basis_type in ["polynomial", "PolynomialBasisFunction"]:
@@ -25,6 +27,12 @@ def create_basis_function(
25
27
  elif basis_type in ["Composite", "CompositeBasis"]:
26
28
  parts = [BasisFunction.from_dict(p) for p in kwargs['parts']]
27
29
  return CompositeBasisFunction(parts)
30
+ elif basis_type in [
31
+ "fractional_polynomial",
32
+ "FractionalPolynomialBasisFunction"]:
33
+ return FractionalPolynomialBasisFunction(
34
+ basis_column, **kwargs
35
+ )
28
36
  else:
29
37
  return LinearBasisFunction(basis_column)
30
38
 
@@ -37,13 +45,19 @@ class BasisFunction(ABC):
37
45
  self.basis_column = basis_column
38
46
  self.is_fitted: bool = kwargs.get("is_fitted", False)
39
47
  self.basis_name: str = kwargs.get("basis_name", "basis")
40
- self.min: float = kwargs.get("min", 0)
41
- self.max: float = kwargs.get("max", 1)
42
- self.compute_min: bool = self.min == 0
43
- self.compute_max: bool = self.max == 1
48
+ self.min: float | None = kwargs.get("min", None)
49
+ self.max: float | None = kwargs.get("max", None)
50
+ self.compute_min: bool = self.min is None
51
+ self.compute_max: bool = self.max is None
44
52
 
45
53
  @classmethod
46
- def from_dict(cls, my_dict: dict) -> BasisFunction:
54
+ def from_dict(
55
+ cls, my_dict: dict, version: str | None = None
56
+ ) -> "BasisFunction":
57
+ # Apply any registered BasisFunction migrations for this version.
58
+ my_dict = registry.migrate(
59
+ "BasisFunction", my_dict, version=version
60
+ )
47
61
  basis_function_type = my_dict["basis_function"]
48
62
  basis_function = create_basis_function(basis_function_type, **my_dict)
49
63
  return basis_function
@@ -264,4 +278,188 @@ class CompositeBasisFunction(BasisFunction):
264
278
 
265
279
  @property
266
280
  def dimension(self):
267
- return sum([p.dimension for p in self.parts])
281
+ return sum([p.dimension for p in self.parts])
282
+
283
+
284
+ class FractionalPolynomialBasisFunction(BasisFunction):
285
+ """
286
+ Fractional polynomial basis function for modelling smooth nonlinear
287
+ effects.
288
+
289
+ The input must be strictly positive (do not standardize the covariates).
290
+ Power convention:
291
+ p = 0 -> log(x)
292
+ p != 0 -> x**p
293
+
294
+ Repeated powers:
295
+ [p, p, p] -> x**p, x**p * log(x), x**p * log(x)**2
296
+ """
297
+
298
+ DEFAULT_POWER_SET = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]
299
+
300
+ AGE_FP_POWER_PRESETS = {
301
+ 1: {
302
+ "default": [0.5],
303
+ },
304
+ 2: {
305
+ "default": [0.5, 1.0],
306
+ },
307
+ 3: {
308
+ "default": [0.5, 1.0, 2.0],
309
+ },
310
+ }
311
+
312
+ def __init__(
313
+ self,
314
+ basis_column: int = 0,
315
+ order: int = 3,
316
+ powers: list | tuple | str | None = "default",
317
+ power_set: list | tuple | None = None,
318
+ eps: float = 1e-8,
319
+ **kwargs,
320
+ ):
321
+ """
322
+ Initialise the fractional polynomial basis function.
323
+
324
+ Parameters
325
+ ----------
326
+ basis_column : int, default=0
327
+ Column index to transform.
328
+
329
+ order : int, default=3
330
+ Fractional polynomial order. Must be 1, 2, or 3.
331
+
332
+ powers : list, tuple, str, or None, default="default"
333
+
334
+ power_set : list, tuple, or None, default=None
335
+ Allowed fractional polynomial powers.
336
+
337
+ eps : float, default=1e-8
338
+ Numerical stability constant.
339
+ """
340
+ super().__init__(basis_column, **kwargs)
341
+
342
+ if order not in [1, 2, 3]:
343
+ raise ValueError("Fractional polynomial order must be 1, 2, or 3.")
344
+
345
+ self.basis_name = "fractional_polynomial"
346
+ self.order = int(order)
347
+ self.eps = float(eps)
348
+
349
+ self.power_set = (
350
+ list(self.DEFAULT_POWER_SET)
351
+ if power_set is None
352
+ else [float(p) for p in power_set]
353
+ )
354
+
355
+ if powers is None:
356
+ powers = "default"
357
+
358
+ if isinstance(powers, str):
359
+ presets = self.AGE_FP_POWER_PRESETS[self.order]
360
+
361
+ if powers not in presets:
362
+ raise ValueError(
363
+ f"Unknown preset '{powers}' for FP order {self.order}. "
364
+ f"Available presets are: {list(presets.keys())}"
365
+ )
366
+
367
+ self.powers = list(presets[powers])
368
+ else:
369
+ self.powers = [float(p) for p in powers]
370
+
371
+ if len(self.powers) != self.order:
372
+ raise ValueError(
373
+ f"FP order {self.order} requires exactly {self.order} powers, "
374
+ f"but received {len(self.powers)}: {self.powers}"
375
+ )
376
+
377
+ for power in self.powers:
378
+ if power not in self.power_set:
379
+ raise ValueError(
380
+ f"Power {power} is not in the allowed FP power set: "
381
+ f"{self.power_set}"
382
+ )
383
+
384
+ def _validate_positive_finite_input(self, data: np.ndarray) -> np.ndarray:
385
+ """
386
+ Validate that input values are finite and strictly positive.
387
+
388
+ Returns
389
+ -------
390
+ np.ndarray
391
+ One-dimensional validated input array.
392
+ """
393
+ x = np.asarray(data, dtype=float).reshape(-1)
394
+
395
+ if not np.all(np.isfinite(x)):
396
+ raise ValueError(
397
+ "FractionalPolynomialBasisFunction received non-finite values."
398
+ )
399
+
400
+ if np.any(x <= 0):
401
+ raise ValueError(
402
+ "FractionalPolynomialBasisFunction requires strictly positive "
403
+ "input values. Please shift or rescale the covariate before "
404
+ "applying this basis function."
405
+ )
406
+
407
+ return np.maximum(x, self.eps)
408
+
409
+ def _fit(self, data: np.ndarray) -> None:
410
+ """
411
+ This method exists only for compatibility with the parent class.
412
+ It only validates training data without computing or storing any
413
+ parameters.
414
+ """
415
+ self._validate_positive_finite_input(data)
416
+
417
+ def _transform(self, data: np.ndarray) -> np.ndarray:
418
+ """
419
+ Transform data into the fractional polynomial basis matrix.
420
+
421
+ Returns
422
+ -------
423
+ np.ndarray
424
+ Basis matrix of shape `(n_samples, order)`.
425
+ """
426
+ x = self._validate_positive_finite_input(data)
427
+ log_x = np.log(x)
428
+
429
+ columns = []
430
+ power_counts = {}
431
+
432
+ for power in self.powers:
433
+ repeat_index = power_counts.get(power, 0)
434
+
435
+ if power == 0.0:
436
+ column = log_x.copy()
437
+ else:
438
+ column = np.power(x, power)
439
+
440
+ if repeat_index > 0:
441
+ column = column * np.power(log_x, repeat_index)
442
+
443
+ columns.append(column)
444
+ power_counts[power] = repeat_index + 1
445
+
446
+ return np.column_stack(columns)
447
+
448
+ @property
449
+ def dimension(self) -> int:
450
+ """
451
+ Number of generated basis columns.
452
+ """
453
+ return self.order
454
+
455
+ def to_dict(self) -> dict:
456
+ """
457
+ Serialize the basis function configuration.
458
+ """
459
+ mydict = super().to_dict()
460
+ mydict["order"] = self.order
461
+ mydict["powers"] = list(self.powers)
462
+ mydict["power_set"] = list(self.power_set)
463
+ mydict["eps"] = self.eps
464
+ return mydict
465
+