diffstar 0.3.1__tar.gz → 0.3.2__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (117) hide show
  1. {diffstar-0.3.1 → diffstar-0.3.2}/CHANGES.rst +5 -0
  2. {diffstar-0.3.1/diffstar.egg-info → diffstar-0.3.2}/PKG-INFO +2 -2
  3. diffstar-0.3.2/diffstar/_version.py +1 -0
  4. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/data_loaders/load_smah_data.py +172 -20
  5. diffstar-0.3.2/diffstar/data_loaders/tests/test_load_smah_data.py +28 -0
  6. diffstar-0.3.2/diffstar/data_loaders/tests/testing_data/subvol_000_diffmah_fits.h5 +0 -0
  7. diffstar-0.3.2/diffstar/diffstarnet/diffstarnet_tdata.py +226 -0
  8. diffstar-0.3.2/diffstar/diffstarnet/tests/test_diffstarnet_tdata.py +80 -0
  9. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/fit_smah_helpers.py +42 -160
  10. diffstar-0.3.2/diffstar/fitting_helpers/fit_smah_helpers_tpeak.py +467 -0
  11. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/fitting_kernels.py +11 -36
  12. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +2 -10
  13. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +1 -0
  14. diffstar-0.3.2/diffstar/kernels/history_kernel_builders_tpeak.py +269 -0
  15. diffstar-0.3.2/diffstar/kernels/main_sequence_kernels_tpeak.py +220 -0
  16. diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders_tpeak.py +172 -0
  17. diffstar-0.3.2/diffstar/sfh_model_tpeak.py +144 -0
  18. diffstar-0.3.2/diffstar/tests/__init__.py +0 -0
  19. diffstar-0.3.2/diffstar/tests/test_main_sequence_kernels.py +68 -0
  20. diffstar-0.3.2/diffstar/tests/test_sfh_model_tpeak.py +155 -0
  21. {diffstar-0.3.1 → diffstar-0.3.2/diffstar.egg-info}/PKG-INFO +2 -2
  22. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar.egg-info/SOURCES.txt +14 -1
  23. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar.egg-info/requires.txt +1 -1
  24. diffstar-0.3.2/docs/source/_static/README.txt +0 -0
  25. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/demo_diffstar_fitter.ipynb +2 -4
  26. {diffstar-0.3.1 → diffstar-0.3.2}/pyproject.toml +1 -1
  27. diffstar-0.3.2/requirements.txt +4 -0
  28. diffstar-0.3.2/scripts/history_fitting_script_SMDPL_tpeak.py +189 -0
  29. diffstar-0.3.1/diffstar/_version.py +0 -1
  30. diffstar-0.3.1/diffstar/tests/test_main_sequence_kernels.py +0 -28
  31. diffstar-0.3.1/requirements.txt +0 -4
  32. {diffstar-0.3.1 → diffstar-0.3.2}/.coveragerc +0 -0
  33. {diffstar-0.3.1 → diffstar-0.3.2}/.git_archival.txt +0 -0
  34. {diffstar-0.3.1 → diffstar-0.3.2}/.gitattributes +0 -0
  35. {diffstar-0.3.1 → diffstar-0.3.2}/.github/workflows/linting.yml +0 -0
  36. {diffstar-0.3.1 → diffstar-0.3.2}/.github/workflows/monthly-warning-test.yml +0 -0
  37. {diffstar-0.3.1 → diffstar-0.3.2}/.github/workflows/test_releases.yml +0 -0
  38. {diffstar-0.3.1 → diffstar-0.3.2}/.github/workflows/tests_cron.yml +0 -0
  39. {diffstar-0.3.1 → diffstar-0.3.2}/.gitignore +0 -0
  40. {diffstar-0.3.1 → diffstar-0.3.2}/.readthedocs.yml +0 -0
  41. {diffstar-0.3.1 → diffstar-0.3.2}/LICENSE.rst +0 -0
  42. {diffstar-0.3.1 → diffstar-0.3.2}/README.md +0 -0
  43. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/__init__.py +0 -0
  44. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/data_loaders/__init__.py +0 -0
  45. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/data_loaders/load_bpl.py +0 -0
  46. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/data_loaders/tests/__init__.py +0 -0
  47. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  48. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/defaults.py +0 -0
  49. {diffstar-0.3.1/diffstar/fitting_helpers/tests → diffstar-0.3.2/diffstar/diffstarnet}/__init__.py +0 -0
  50. {diffstar-0.3.1/diffstar/kernels → diffstar-0.3.2/diffstar/diffstarnet}/tests/__init__.py +0 -0
  51. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/__init__.py +0 -0
  52. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/param_clippers.py +0 -0
  53. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/stars.py +0 -0
  54. {diffstar-0.3.1/diffstar → diffstar-0.3.2/diffstar/fitting_helpers}/tests/__init__.py +0 -0
  55. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +0 -0
  56. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/tests/test_param_clippers.py +0 -0
  57. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/tests/test_stars.py +0 -0
  58. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/fitting_helpers/utils.py +0 -0
  59. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/__init__.py +0 -0
  60. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/gas_consumption.py +0 -0
  61. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/history_kernel_builders.py +0 -0
  62. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/kernel_builders.py +0 -0
  63. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/main_sequence_kernels.py +0 -0
  64. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/quenching_kernels.py +0 -0
  65. /diffstar-0.3.1/docs/source/_static/README.txt → /diffstar-0.3.2/diffstar/kernels/tests/__init__.py +0 -0
  66. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -0
  67. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/test_kernel_builders.py +0 -0
  68. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/test_quenching_kernels.py +0 -0
  69. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  70. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  71. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  72. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  73. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  74. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  75. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  76. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/sfh.py +0 -0
  77. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/sfh_model.py +0 -0
  78. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_defaults.py +0 -0
  79. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_gas.py +0 -0
  80. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_lax_main_sequence.py +0 -0
  81. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_lax_sfh.py +0 -0
  82. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_quenching.py +0 -0
  83. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_sfh.py +0 -0
  84. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_sfh_model.py +0 -0
  85. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/test_utils.py +0 -0
  86. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  87. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  88. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  89. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  90. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  91. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  92. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  93. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  94. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  95. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  96. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  97. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  98. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  99. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  100. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  101. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar/utils.py +0 -0
  102. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar.egg-info/dependency_links.txt +0 -0
  103. {diffstar-0.3.1 → diffstar-0.3.2}/diffstar.egg-info/top_level.txt +0 -0
  104. {diffstar-0.3.1 → diffstar-0.3.2}/docs/Makefile +0 -0
  105. {diffstar-0.3.1 → diffstar-0.3.2}/docs/make.bat +0 -0
  106. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/citation.rst +0 -0
  107. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/conf.py +0 -0
  108. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/demo_diffstar_sfh.ipynb +0 -0
  109. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/demos.rst +0 -0
  110. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/index.rst +0 -0
  111. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/installation.rst +0 -0
  112. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/reference.rst +0 -0
  113. {diffstar-0.3.1 → diffstar-0.3.2}/docs/source/rtd_environment.yaml +0 -0
  114. {diffstar-0.3.1 → diffstar-0.3.2}/scripts/generate_unit_testing_data.py +0 -0
  115. {diffstar-0.3.1 → diffstar-0.3.2}/scripts/history_fitting_script.py +0 -0
  116. {diffstar-0.3.1 → diffstar-0.3.2}/setup.cfg +0 -0
  117. {diffstar-0.3.1 → diffstar-0.3.2}/setup.py +0 -0
@@ -1,3 +1,8 @@
1
+ 0.3.2 (2024-10-25)
2
+ ------------------
3
+ - Adapt sfh kernels to be compatible with diffmah 0.6.1
4
+
5
+
1
6
  0.3.1 (2024-6-19)
2
7
  ------------------
3
8
  - Performance improvements for calc_sfh_galpop and calc_sfh_singlegal
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diffstar
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -38,7 +38,7 @@ Classifier: Programming Language :: Python :: 3
38
38
  Requires-Python: >=3.9
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
- Requires-Dist: diffmah>=0.5.0
41
+ Requires-Dist: diffmah>=0.6.1
42
42
  Requires-Dist: numpy
43
43
  Requires-Dist: jax
44
44
  Requires-Dist: h5py
@@ -0,0 +1 @@
1
+ __version__ = '0.3.2'
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  import os
4
5
  import warnings
5
6
 
@@ -9,21 +10,31 @@ import numpy as np
9
10
  from ..defaults import SFR_MIN
10
11
  from ..utils import _get_dt_array
11
12
 
13
+ try:
14
+ from umachine_pyio.load_mock import load_mock_from_binaries
15
+
16
+ HAS_UM_LOADER = True
17
+ except ImportError:
18
+ HAS_UM_LOADER = False
19
+
12
20
  TASSO = "/Users/aphearin/work/DATA/diffmah_data"
13
21
  BEBOP = "/lcrc/project/halotools/diffmah_data"
14
22
  LAPTOP = "/Users/alarcon/Documents/diffmah_data"
15
-
23
+ BEBOP_SMDPL = os.path.join(
24
+ "/lcrc/project/galsampler/SMDPL/",
25
+ "dr1_no_merging_upidh/sfh_binary_catalogs/a_1.000000/",
26
+ )
16
27
  H_BPL = 0.678
17
28
  H_TNG = 0.6774
18
29
  H_MDPL = H_BPL
19
30
 
20
31
 
21
def load_fit_mah(basename, data_drn=BEBOP):
    """Load the best fit diffmah parameter data

    Parameters
    ----------
    basename : string
        Name of the h5 file where the diffmah best fit parameters are stored

    data_drn : string
        Directory containing the h5 file

    Returns
    -------
    mah_fit_params: ndarray of shape (n_gal, 4)
        Best fit diffmah parameters for each halo, with columns
        (logm0, logtc, early_index, late_index)

    logmp: ndarray of shape (n_gal, )
        Base-10 logarithm of the present day peak halo mass.
        Same values as the logm0 column of mah_fit_params.

    """
    fn = os.path.join(data_drn, basename)
    with h5py.File(fn, "r") as hdf:
        logm0 = hdf["logm0"][:]
        # np.array copies logm0, so logmp below does not alias mah_fit_params
        mah_fit_params = np.array(
            [
                logm0,
                hdf["logtc"][:],
                hdf["early_index"][:],
                hdf["late_index"][:],
            ]
        ).T
    logmp = logm0

    return mah_fit_params, logmp
63
69
 
64
70
 
71
def load_fit_mah_tpeak(basename, data_drn=BEBOP):
    """Load the best fit diffmah parameter data, including t_peak

    Parameters
    ----------
    basename : string
        Name of the h5 file where the diffmah best fit parameters are stored

    data_drn : string
        Directory containing the h5 file

    Returns
    -------
    mah_fit_params: ndarray of shape (n_gal, 4)
        Best fit diffmah parameters for each halo, with columns
        (logm0, logtc, early_index, late_index)

    logmp: ndarray of shape (n_gal, )
        Base-10 logarithm of the present day peak halo mass.
        Same values as the logm0 column of mah_fit_params.

    t_peak: ndarray of shape (n_gal, )
        Values of the t_peak dataset stored in the h5 file

    """
    fn = os.path.join(data_drn, basename)
    with h5py.File(fn, "r") as hdf:
        logm0 = hdf["logm0"][:]
        # np.array copies logm0, so logmp below does not alias mah_fit_params
        mah_fit_params = np.array(
            [
                logm0,
                hdf["logtc"][:],
                hdf["early_index"][:],
                hdf["late_index"][:],
            ]
        ).T
        t_peak = hdf["t_peak"][:]
    logmp = logm0

    return mah_fit_params, logmp, t_peak
109
+
110
+
111
def load_fit_sfh(basename, data_drn=BEBOP):
    """Load the best fit diffstar parameter data

    Parameters
    ----------
    basename : string
        Name of the h5 file where the diffstar best fit parameters are stored

    data_drn : string
        Directory containing the h5 file

    Returns
    -------
    ms_fit_params: ndarray of shape (n_gal, 5)
        Best fit main sequence parameters for each galaxy, with columns
        (lgmcrit, lgy_at_mcrit, indx_lo, indx_hi, tau_dep)

    q_fit_params: ndarray of shape (n_gal, 4)
        Best fit quenching parameters for each galaxy, with columns
        (lg_qt, qlglgdt, lg_drop, lg_rejuv)

    """
    ms_keys = ("lgmcrit", "lgy_at_mcrit", "indx_lo", "indx_hi", "tau_dep")
    q_keys = ("lg_qt", "qlglgdt", "lg_drop", "lg_rejuv")

    fn = os.path.join(data_drn, basename)
    with h5py.File(fn, "r") as hdf:
        ms_fit_params = np.array([hdf[key][:] for key in ms_keys]).T
        q_fit_params = np.array([hdf[key][:] for key in q_keys]).T

    return ms_fit_params, q_fit_params
151
+
152
+
65
153
  def load_bolshoi_data(gal_type, data_drn=BEBOP):
66
154
  """Load the stellar mass histories from UniverseMachine simulation
67
155
  applied to the Bolshoi-Planck (BPL) simulation.
@@ -397,3 +485,67 @@ def load_mdpl_small_data(gal_type, data_drn=BEBOP):
397
485
  log_smahs = np.where(sm_cumsum == 0, 0, np.log10(sm_cumsum))
398
486
 
399
487
  return halo_ids, log_smahs, sfrh, mdpl_t, dt
488
+
489
+
490
def load_SMDPL_data(subvols, data_drn=BEBOP_SMDPL):
    """Load star formation and halo mass histories from the UniverseMachine
    binary catalogs of the SMDPL simulation.

    Parameters
    ----------
    subvols : sequence
        Subvolumes to load; passed directly to
        umachine_pyio.load_mock_from_binaries

    data_drn : string, optional
        Directory containing the UniverseMachine binary catalogs and the
        file smdpl_cosmic_time.txt. Default is BEBOP_SMDPL.

    Returns
    -------
    halo_ids: ndarray of shape (n_gal, )
        IDs of the halos in the catalog.

    log_smahs: ndarray of shape (n_gal, n_times)
        Base-10 log of the cumulative stellar mass, computed as
        cumsum(sfrh * dt) * 1e9. Entries with zero cumulative mass are set
        to 0. (Msun if sfrh is in Msun/yr and dt in Gyr — confirm.)

    sfrh: ndarray of shape (n_gal, n_times)
        Star formation rate history of the main progenitor, as stored in
        the catalog.

    SMDPL_t : ndarray of shape (n_times, )
        Cosmic time of each snapshot, read from smdpl_cosmic_time.txt.

    dt : ndarray of shape (n_times, )
        Time step of each snapshot, computed by _get_dt_array(SMDPL_t).

    log_mahs: ndarray of shape (n_gal, n_times)
        Base-10 log of the main-progenitor peak halo mass history.
        Masses below 10**7 are replaced by 1 (so their log is 0), and the
        history is made non-decreasing with a running maximum.

    logmp: ndarray of shape (n_gal, )
        Final-snapshot value of log_mahs.

    Raises
    ------
    ImportError
        If umachine_pyio is not installed.

    """
    if not HAS_UM_LOADER:
        raise ImportError("Must have umachine_pyio installed to load this dataset")

    galprops = ["halo_id", "sfr_history_main_prog", "mpeak_history_main_prog"]
    mock = load_mock_from_binaries(subvols, root_dirname=data_drn, galprops=galprops)

    SMDPL_t = np.loadtxt(os.path.join(data_drn, "smdpl_cosmic_time.txt"))

    halo_ids = mock["halo_id"]
    dt = _get_dt_array(SMDPL_t)
    sfrh = mock["sfr_history_main_prog"]
    sm_cumsum = np.cumsum(sfrh * dt, axis=1) * 1e9

    # log10(0) warns; those entries are replaced by 0 anyway
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        log_smahs = np.where(sm_cumsum == 0, 0, np.log10(sm_cumsum))

    # Masses below mh_min are set to 1.0 so that their log10 is exactly 0
    lgmh_min = 7.0
    mh_min = 10**lgmh_min
    mpeak_history = mock["mpeak_history_main_prog"]
    clipped_mahs = np.where(mpeak_history < mh_min, 1.0, mpeak_history)

    # Running maximum along time enforces a non-decreasing peak-mass history
    log_mahs = np.maximum.accumulate(np.log10(clipped_mahs), axis=1)

    logmp = log_mahs[:, -1]

    return halo_ids, log_smahs, sfrh, SMDPL_t, dt, log_mahs, logmp
@@ -0,0 +1,28 @@
1
+ """
2
+ """
3
+
4
+ import os
5
+
6
+ import numpy as np
7
+
8
+ from ..load_smah_data import load_fit_mah_tpeak
9
+
10
+ _THIS_DRNAME = os.path.dirname(os.path.abspath(__file__))
11
+
12
+
13
def test_load_fit_mah_tpeak():
    """Loader returns finite, correctly-shaped diffmah fits with t_peak"""
    basename = "subvol_000_diffmah_fits.h5"
    data_drn = os.path.join(_THIS_DRNAME, "testing_data")
    _res = load_fit_mah_tpeak(basename, data_drn=data_drn)
    for x in _res:
        assert np.all(np.isfinite(x))
    mah_fit_params, logmp, t_peak = _res

    # columns are (logm0, logtc, early_index, late_index)
    n_halos, n_params = mah_fit_params.shape
    assert n_params == 4

    assert logmp.shape == (n_halos,)
    assert np.all(logmp > 10)
    assert np.all(logmp < 16)
    # logmp is read from the same logm0 dataset as the first column
    assert np.allclose(logmp, mah_fit_params[:, 0])

    assert t_peak.shape == (n_halos,)
    assert np.all(t_peak > 0)
    assert np.all(t_peak < 14)
@@ -0,0 +1,226 @@
1
+ """
2
+ """
3
+
4
+ from collections import namedtuple
5
+
6
+ import numpy as np
7
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
8
+ from diffmah.diffmahpop_kernels.bimod_censat_params import DEFAULT_DIFFMAHPOP_PARAMS
9
+ from diffmah.diffmahpop_kernels.mc_bimod_cens import _mc_diffmah_singlecen_vmap_kern
10
+ from jax import numpy as jnp
11
+ from jax import random as jran
12
+
13
+ from ..defaults import (
14
+ DEFAULT_DIFFSTAR_U_PARAMS,
15
+ DEFAULT_Q_PARAMS_UNQUENCHED,
16
+ DEFAULT_U_MS_PARAMS,
17
+ DEFAULT_U_Q_PARAMS,
18
+ FB,
19
+ LGT0,
20
+ T_TABLE_MIN,
21
+ get_bounded_diffstar_params,
22
+ )
23
+ from ..kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
24
+ from ..sfh_model_tpeak import calc_sfh_galpop
25
+
26
T0 = 10**LGT0
N_SFH_TABLE = 200
LGMH_MIN = 10.5
LGSM0_MIN = 5.0

# Depletion time just above its lower bound, used for the "nolag" variant
TAU_INST = MS_PARAM_BOUNDS_PDICT["tau_dep"][0] + 1e-4

# Each SFH variant stores its parameters, SFR history and stellar mass history
_TDATA_SFH_ROOTKEYS = ["sfh_params", "sfh", "smh"]
_TDATA_NOQ_KEYS = [f"{root}_noq" for root in _TDATA_SFH_ROOTKEYS]
_TDATA_NOQ_NOLAG_KEYS = [f"{noq_key}_nolag" for noq_key in _TDATA_NOQ_KEYS]
SFH_KEYS = [*_TDATA_SFH_ROOTKEYS, *_TDATA_NOQ_KEYS, *_TDATA_NOQ_NOLAG_KEYS]
TDATA_KEYS = ["mah_params", "log_mah", *SFH_KEYS]
TData = namedtuple("TData", TDATA_KEYS)
39
+
40
+
41
def tdata_generator(
    ran_key,
    logm0_sample,
    n_sfh_table=N_SFH_TABLE,
    logsm0_min=LGSM0_MIN,
    n_epochs=float("inf"),
):
    """Training data generator for diffstarnet

    Parameters
    ----------
    ran_key : jax.random.key
        Split once per batch, so each yielded batch uses a fresh key

    logm0_sample : array, shape (n_halos, )
        Array of values of diffmah parameter logm0

    n_sfh_table : int, optional
        Number of points in the time grid of the histories.
        Default is set by N_SFH_TABLE at top of module

    logsm0_min : float, optional
        Minimum z=0 stellar mass in the training data
        Default is set by LGSM0_MIN at top of module

    n_epochs : int or float, optional
        Maximum number of batches to yield.
        Default is float("inf"), i.e., the generator never stops

    Yields
    ------
    tdata : TData namedtuple
        Fields: mah_params, log_mah,
        sfh_params, sfh, smh,
        sfh_params_noq, sfh_noq, smh_noq,
        sfh_params_noq_nolag, sfh_noq_nolag, smh_noq_nolag.
        The peak time of each halo is available as tdata.mah_params.t_peak

    """
    batchnum = 0
    while batchnum < n_epochs:
        ran_key, batch_key = jran.split(ran_key, 2)
        tdata = _compute_tdata(batch_key, logm0_sample, n_sfh_table, logsm0_min)
        yield tdata
        batchnum += 1
78
+
79
+
80
def _compute_tdata(
    ran_key, logm0_sample, n_sfh_table=N_SFH_TABLE, logsm0_min=LGSM0_MIN
):
    """Generate one batch of diffstarnet training data

    Parameters
    ----------
    ran_key : jax.random.key

    logm0_sample : array, shape (n_halos, )
        Array of values of diffmah parameter logm0

    n_sfh_table : int, optional
        Number of points in the time grid, linearly spaced from T_TABLE_MIN to T0

    logsm0_min : float, optional
        A halo is kept only if the final log10 stellar mass of all three SFH
        variants (default, no quenching, no quenching & no lag) exceeds this

    Returns
    -------
    tdata : TData namedtuple
        Fields ordered as in TDATA_KEYS. Histories have shape
        (n_kept, n_sfh_table), where n_kept <= n_halos pass the mass cut.

    """
    tarr = np.linspace(T_TABLE_MIN, T0, n_sfh_table)

    mah_key, early_late_key, sfh_key = jran.split(ran_key, 3)

    # Draw both an early-forming and a late-forming MAH for every halo
    _reslist = mc_diffmah_halo_sample(mah_key, tarr, logm0_sample)
    mah_params_early, _, log_mah_early = _reslist[:3]
    mah_params_late, _, log_mah_late = _reslist[3:6]
    frac_early = _reslist[6]

    # Keep the early-forming MAH with probability frac_early, else the late one
    n_halos = mah_params_early.logm0.size
    uran_mah = jran.uniform(early_late_key, minval=0, maxval=1, shape=(n_halos,))
    msk_mah = uran_mah < frac_early
    mah_params = DEFAULT_MAH_PARAMS._make(
        [
            jnp.where(
                msk_mah, getattr(mah_params_early, key), getattr(mah_params_late, key)
            )
            for key in mah_params_late._fields
        ]
    )
    log_mah = jnp.where(msk_mah.reshape((-1, 1)), log_mah_early, log_mah_late)

    ZZ = np.zeros(n_halos)

    # Unbounded diffstar params drawn uniformly in [-100, 100);
    # u_indx_hi is held fixed at its default value
    uran = jran.uniform(sfh_key, minval=-100, maxval=100, shape=(8, n_halos))
    u_ms_late_index = ZZ + DEFAULT_U_MS_PARAMS.u_indx_hi
    u_ms_params = DEFAULT_U_MS_PARAMS._make(
        [uran[0], uran[1], uran[2], u_ms_late_index, uran[3]]
    )
    u_q_params = DEFAULT_U_Q_PARAMS._make([uran[i, :] for i in range(4, 8)])
    sfh_u_params = DEFAULT_DIFFSTAR_U_PARAMS._make((u_ms_params, u_q_params))
    sfh_params = get_bounded_diffstar_params(sfh_u_params)

    # No-quenching variant: unquenched *bounded* q params, built with the same
    # namedtuple type as sfh_params.q_params so the field names match
    q_params_noq = sfh_params.q_params._make(
        [ZZ + x for x in DEFAULT_Q_PARAMS_UNQUENCHED]
    )
    sfh_params_noq = sfh_params._replace(q_params=q_params_noq)

    # No-quenching, no-lag variant: tau_dep set to TAU_INST (just above its bound)
    ms_params_nolag = sfh_params.ms_params._replace(tau_dep=TAU_INST + ZZ)
    sfh_params_noq_nolag = sfh_params_noq._replace(ms_params=ms_params_nolag)

    sfh, smh = calc_sfh_galpop(
        sfh_params, mah_params, tarr, lgt0=LGT0, fb=FB, return_smh=True
    )
    sfh_noq, smh_noq = calc_sfh_galpop(
        sfh_params_noq, mah_params, tarr, lgt0=LGT0, fb=FB, return_smh=True
    )
    sfh_noq_nolag, smh_noq_nolag = calc_sfh_galpop(
        sfh_params_noq_nolag, mah_params, tarr, lgt0=LGT0, fb=FB, return_smh=True
    )

    # Implement stellar mass cut on the final-time stellar mass of every variant
    logsm0 = np.log10(smh[:, -1])
    logsm0_noq = np.log10(smh_noq[:, -1])
    logsm0_noq_nolag = np.log10(smh_noq_nolag[:, -1])
    msk = (
        (logsm0 > logsm0_min)
        & (logsm0_noq > logsm0_min)
        & (logsm0_noq_nolag > logsm0_min)
    )

    def _apply_mask(arrays):
        """Apply the stellar mass mask to every array of a namedtuple"""
        return arrays._make([x[msk] for x in arrays])

    mah_params_out = _apply_mask(mah_params)

    sfh_params_out = sfh_params._make(
        (_apply_mask(sfh_params.ms_params), _apply_mask(sfh_params.q_params))
    )
    sfh_params_noq_out = sfh_params_noq._make(
        (_apply_mask(sfh_params_noq.ms_params), _apply_mask(sfh_params_noq.q_params))
    )
    # nolag variant shares the unquenched q params of the noq variant
    sfh_params_noq_nolag_out = sfh_params_noq_out._replace(
        ms_params=_apply_mask(sfh_params_noq_nolag.ms_params)
    )

    return TData(
        mah_params_out,
        log_mah[msk],
        sfh_params_out,
        sfh[msk],
        smh[msk],
        sfh_params_noq_out,
        sfh_noq[msk],
        smh_noq[msk],
        sfh_params_noq_nolag_out,
        sfh_noq_nolag[msk],
        smh_noq_nolag[msk],
    )
203
+
204
+
205
def mc_diffmah_halo_sample(ran_key, tarr, logm0_sample):
    """Draw early- and late-forming MAHs for a sample of central halos

    Uses the default DiffmahPop parameters, observing every halo at the
    final time of tarr.

    Parameters
    ----------
    ran_key : jax.random.key
        Split into one key per halo

    tarr : array, shape (n_times, )
        Cosmic times at which the MAHs are evaluated

    logm0_sample : array, shape (n_halos, )
        Array of values of diffmah parameter logm0

    Returns
    -------
    tuple of 7 elements
        (mah_params_early, dmhdt_early, log_mah_early,
         mah_params_late, dmhdt_late, log_mah_late, frac_early),
        as returned by _mc_diffmah_singlecen_vmap_kern

    """
    n_halos = logm0_sample.size
    t_final = tarr[-1]
    t_obs = np.zeros(n_halos) + t_final
    halo_keys = jran.split(ran_key, n_halos)

    kern_out = _mc_diffmah_singlecen_vmap_kern(
        DEFAULT_DIFFMAHPOP_PARAMS,
        tarr,
        logm0_sample,
        t_obs,
        halo_keys,
        np.log10(t_final),
    )
    early_triplet = tuple(kern_out[0:3])
    late_triplet = tuple(kern_out[3:6])
    return (*early_triplet, *late_triplet, kern_out[6])
@@ -0,0 +1,80 @@
1
+ """
2
+ """
3
+
4
+ import numpy as np
5
+ from jax import random as jran
6
+
7
+ from .. import diffstarnet_tdata as dtg
8
+
9
+
10
def _assert_all_finite(x):
    """Recursively assert every array inside a (possibly nested) tuple is finite"""
    if isinstance(x, tuple):  # namedtuples of arrays, possibly nested
        for y in x:
            _assert_all_finite(y)
    else:
        assert np.all(np.isfinite(x))


def enforce_good_tdata(tdata, logsm0_min=float("-inf")):
    """Assert that a TData batch is finite, passes the mass cut and is consistent

    Parameters
    ----------
    tdata : TData namedtuple
        One batch yielded by diffstarnet_tdata.tdata_generator

    logsm0_min : float, optional
        Every halo's final log10 stellar mass (from tdata.smh) must be >= this

    """
    # Explicit recursion instead of relying on numpy raising on ragged input,
    # which differs across numpy versions (ValueError vs TypeError)
    _assert_all_finite(tdata)

    logsm0 = np.log10(tdata.smh[:, -1])
    assert np.all(logsm0 >= logsm0_min)

    n_halos, n_times = tdata.sfh.shape
    n_halos2 = tdata.mah_params.logm0.size
    assert n_halos2 == n_halos

    history_keys = [x for x in dtg.SFH_KEYS if "params" not in x]

    for key in history_keys:
        arr = getattr(tdata, key)
        assert arr.shape == (n_halos, n_times)
30
+
31
+
32
def test_tdata_generator():
    """Generator yields n_epochs valid, distinct batches and then stops"""
    ran_key = jran.key(0)
    n_halos = 5_000
    logm0_sample = np.linspace(10, 15, n_halos)

    # generate 5 epochs of data
    n_epochs = 5
    gen = dtg.tdata_generator(ran_key, logm0_sample, n_epochs=n_epochs)
    tdata_list = list(gen)
    assert len(tdata_list) == n_epochs
    for tdata in tdata_list:
        enforce_good_tdata(tdata, logsm0_min=dtg.LGSM0_MIN)

    # demo typical usage
    LOGSM0_MIN = 6.0
    gen = dtg.tdata_generator(ran_key, logm0_sample, n_epochs=2, logsm0_min=LOGSM0_MIN)
    tdata0 = next(gen)
    tdata1 = next(gen)

    # generator must be exhausted after n_epochs batches
    assert next(gen, None) is None

    # tdata should not contain Mstar below logsm0_min
    assert np.all(tdata0.smh[:, -1] >= 10**LOGSM0_MIN)
    assert np.all(tdata1.smh[:, -1] >= 10**LOGSM0_MIN)

    # tdata generator should yield different tdata with each iteration
    assert not np.allclose(tdata0.log_mah[0, :], tdata1.log_mah[0, :])
    assert not np.allclose(tdata0.sfh[0, :], tdata1.sfh[0, :])
    assert not np.allclose(tdata0.sfh_noq[0, :], tdata1.sfh_noq[0, :])
    assert not np.allclose(tdata0.sfh_noq_nolag[0, :], tdata1.sfh_noq_nolag[0, :])
64
+
65
+
66
def test_mc_diffmah_halo_sample():
    """Sampled MAHs are consistent with their logm0 and carry t_peak"""
    n_halos_init = 2_000
    logm0_sample = np.linspace(10, 15, n_halos_init)
    tdata = next(dtg.tdata_generator(jran.key(0), logm0_sample))

    mah_params = tdata.mah_params
    n_halos = mah_params.logm0.size
    # the stellar mass cut can only remove halos
    assert n_halos <= n_halos_init
    assert mah_params.t_peak.size == n_halos

    # final value of each log_mah history should track its logm0 parameter
    residual = mah_params.logm0 - tdata.log_mah[:, -1]
    assert np.abs(residual).mean() < 0.1
    assert np.std(residual) < 0.3
+ assert np.std(diff) < 0.3