diffstar 0.3.2__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. diffstar-0.3.2/.github/workflows/linting.yml → diffstar-0.3.3/.github/workflows/linting.yaml +2 -3
  2. diffstar-0.3.2/.github/workflows/monthly-warning-test.yml → diffstar-0.3.3/.github/workflows/monthly-warning-test.yaml +2 -4
  3. diffstar-0.3.2/.github/workflows/test_releases.yml → diffstar-0.3.3/.github/workflows/test_releases.yaml +2 -4
  4. diffstar-0.3.2/.github/workflows/tests_cron.yml → diffstar-0.3.3/.github/workflows/tests_cron.yaml +2 -4
  5. {diffstar-0.3.2 → diffstar-0.3.3}/CHANGES.rst +5 -0
  6. {diffstar-0.3.2/diffstar.egg-info → diffstar-0.3.3}/PKG-INFO +4 -4
  7. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/__init__.py +2 -1
  8. diffstar-0.3.3/diffstar/_version.py +1 -0
  9. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_bpl.py +2 -1
  10. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/defaults.py +3 -2
  11. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fit_smah_helpers_tpeak.py +1 -2
  12. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fitting_kernels.py +7 -11
  13. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/param_clippers.py +2 -1
  14. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +7 -43
  15. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +24 -27
  16. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +8 -10
  17. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_param_clippers.py +2 -1
  18. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_stars.py +2 -1
  19. diffstar-0.3.3/diffstar/kernels/__init__.py +4 -0
  20. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/test_quenching_kernels.py +1 -0
  21. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_defaults.py +2 -1
  22. diffstar-0.3.3/diffstar/tests/test_sfh_model_tpeak.py +62 -0
  23. {diffstar-0.3.2 → diffstar-0.3.3/diffstar.egg-info}/PKG-INFO +4 -4
  24. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/SOURCES.txt +4 -21
  25. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/requires.txt +1 -1
  26. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/index.rst +0 -1
  27. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/rtd_environment.yaml +2 -2
  28. {diffstar-0.3.2 → diffstar-0.3.3}/pyproject.toml +2 -2
  29. diffstar-0.3.3/requirements.txt +4 -0
  30. diffstar-0.3.2/diffstar/_version.py +0 -1
  31. diffstar-0.3.2/diffstar/fitting_helpers/fit_smah_helpers.py +0 -1715
  32. diffstar-0.3.2/diffstar/kernels/__init__.py +0 -5
  33. diffstar-0.3.2/diffstar/kernels/history_kernel_builders.py +0 -268
  34. diffstar-0.3.2/diffstar/kernels/kernel_builders.py +0 -249
  35. diffstar-0.3.2/diffstar/kernels/main_sequence_kernels.py +0 -233
  36. diffstar-0.3.2/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -121
  37. diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders.py +0 -161
  38. diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders_tpeak.py +0 -172
  39. diffstar-0.3.2/diffstar/sfh.py +0 -220
  40. diffstar-0.3.2/diffstar/sfh_model.py +0 -140
  41. diffstar-0.3.2/diffstar/tests/test_gas.py +0 -40
  42. diffstar-0.3.2/diffstar/tests/test_lax_main_sequence.py +0 -153
  43. diffstar-0.3.2/diffstar/tests/test_lax_sfh.py +0 -154
  44. diffstar-0.3.2/diffstar/tests/test_main_sequence_kernels.py +0 -68
  45. diffstar-0.3.2/diffstar/tests/test_sfh.py +0 -198
  46. diffstar-0.3.2/diffstar/tests/test_sfh_model.py +0 -156
  47. diffstar-0.3.2/diffstar/tests/test_sfh_model_tpeak.py +0 -155
  48. diffstar-0.3.2/docs/source/demo_diffstar_fitter.ipynb +0 -337
  49. diffstar-0.3.2/docs/source/demos.rst +0 -8
  50. diffstar-0.3.2/requirements.txt +0 -4
  51. {diffstar-0.3.2 → diffstar-0.3.3}/.coveragerc +0 -0
  52. {diffstar-0.3.2 → diffstar-0.3.3}/.git_archival.txt +0 -0
  53. {diffstar-0.3.2 → diffstar-0.3.3}/.gitattributes +0 -0
  54. {diffstar-0.3.2 → diffstar-0.3.3}/.gitignore +0 -0
  55. {diffstar-0.3.2 → diffstar-0.3.3}/.readthedocs.yml +0 -0
  56. {diffstar-0.3.2 → diffstar-0.3.3}/LICENSE.rst +0 -0
  57. {diffstar-0.3.2 → diffstar-0.3.3}/README.md +0 -0
  58. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/__init__.py +0 -0
  59. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_smah_data.py +0 -0
  60. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/__init__.py +0 -0
  61. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  62. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_smah_data.py +0 -0
  63. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/testing_data/subvol_000_diffmah_fits.h5 +0 -0
  64. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/__init__.py +0 -0
  65. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/diffstarnet_tdata.py +0 -0
  66. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/__init__.py +0 -0
  67. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/test_diffstarnet_tdata.py +0 -0
  68. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/__init__.py +0 -0
  69. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/stars.py +0 -0
  70. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/__init__.py +0 -0
  71. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/utils.py +0 -0
  72. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/gas_consumption.py +0 -0
  73. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/history_kernel_builders_tpeak.py +0 -0
  74. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/main_sequence_kernels_tpeak.py +0 -0
  75. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/quenching_kernels.py +0 -0
  76. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/__init__.py +0 -0
  77. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  78. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  79. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  80. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  81. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  82. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  83. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  84. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/sfh_model_tpeak.py +0 -0
  85. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/__init__.py +0 -0
  86. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_quenching.py +0 -0
  87. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_utils.py +0 -0
  88. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  89. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  90. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  91. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  92. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  93. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  94. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  95. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  96. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  97. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  98. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  99. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  100. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  101. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  102. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  103. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/utils.py +0 -0
  104. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/dependency_links.txt +0 -0
  105. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/top_level.txt +0 -0
  106. {diffstar-0.3.2 → diffstar-0.3.3}/docs/Makefile +0 -0
  107. {diffstar-0.3.2 → diffstar-0.3.3}/docs/make.bat +0 -0
  108. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/_static/README.txt +0 -0
  109. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/citation.rst +0 -0
  110. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/conf.py +0 -0
  111. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/demo_diffstar_sfh.ipynb +0 -0
  112. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/installation.rst +0 -0
  113. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/reference.rst +0 -0
  114. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/generate_unit_testing_data.py +0 -0
  115. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script.py +0 -0
  116. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script_SMDPL_tpeak.py +0 -0
  117. {diffstar-0.3.2 → diffstar-0.3.3}/setup.cfg +0 -0
  118. {diffstar-0.3.2 → diffstar-0.3.3}/setup.py +0 -0
@@ -21,15 +21,14 @@ jobs:
21
21
  channel-priority: strict
22
22
  show-channel-urls: true
23
23
  miniforge-version: latest
24
- miniforge-variant: Mambaforge
25
24
 
26
25
  - name: configure conda and install code
27
26
  shell: bash -l {0}
28
27
  run: |
29
- mamba install --quiet \
28
+ conda install --quiet \
30
29
  --file=requirements.txt
31
30
  python -m pip install --no-deps -e .
32
- mamba install -y -q \
31
+ conda install -y -q \
33
32
  flake8
34
33
 
35
34
  - name: lint
@@ -23,18 +23,16 @@ jobs:
23
23
  channel-priority: strict
24
24
  show-channel-urls: true
25
25
  miniforge-version: latest
26
- miniforge-variant: Mambaforge
27
- use-mamba: true
28
26
 
29
27
  - name: configure conda and install code
30
28
  # Test against current main branch of diffmah and dsps
31
29
  shell: bash -l {0}
32
30
  run: |
33
31
  conda config --set always_yes yes
34
- mamba install --quiet \
32
+ conda install --quiet \
35
33
  --file=requirements.txt
36
34
  python -m pip install --no-deps -e .
37
- mamba install -y -q \
35
+ conda install -y -q \
38
36
  flake8 \
39
37
  pytest \
40
38
  pytest-xdist \
@@ -27,18 +27,16 @@ jobs:
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
29
29
  miniforge-version: latest
30
- miniforge-variant: Mambaforge
31
- use-mamba: true
32
30
 
33
31
  - name: configure conda and install code
34
32
  # Test against current confa-forge release of diffmah
35
33
  shell: bash -l {0}
36
34
  run: |
37
35
  conda config --set always_yes yes
38
- mamba install --quiet \
36
+ conda install --quiet \
39
37
  --file=requirements.txt
40
38
  python -m pip install --no-deps -e .
41
- mamba install -y -q \
39
+ conda install -y -q \
42
40
  flake8 \
43
41
  pytest \
44
42
  pytest-xdist \
@@ -27,18 +27,16 @@ jobs:
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
29
29
  miniforge-version: latest
30
- miniforge-variant: Mambaforge
31
- use-mamba: true
32
30
 
33
31
  - name: configure conda and install code
34
32
  # Test against current main branch of diffmah and dsps
35
33
  shell: bash -l {0}
36
34
  run: |
37
35
  conda config --set always_yes yes
38
- mamba install --quiet \
36
+ conda install --quiet \
39
37
  --file=requirements.txt
40
38
  python -m pip install --no-deps -e .
41
- mamba install -y -q \
39
+ conda install -y -q \
42
40
  flake8 \
43
41
  pytest \
44
42
  pytest-xdist \
@@ -1,3 +1,8 @@
1
+ 0.3.3 (2025-01-15)
2
+ ------------------
3
+ - Clean out old code so that only tpeak-based diffmah models remain (https://github.com/ArgonneCPAC/diffstar/pull/72)
4
+
5
+
1
6
  0.3.2 (2024-10-25)
2
7
  ------------------
3
8
  - Adapt sfh kernels to be compatible with diffmah 0.6.1
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.9
38
+ Requires-Python: >=3.11
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
- Requires-Dist: diffmah>=0.6.1
41
+ Requires-Dist: diffmah>=0.7.0
42
42
  Requires-Dist: numpy
43
43
  Requires-Dist: jax
44
44
  Requires-Dist: h5py
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  # flake8: noqa
4
5
  from ._version import __version__
5
6
  from .defaults import (
@@ -14,4 +15,4 @@ from .defaults import (
14
15
  get_bounded_diffstar_params,
15
16
  get_unbounded_diffstar_params,
16
17
  )
17
- from .sfh_model import calc_sfh_galpop, calc_sfh_singlegal
18
+ from .sfh_model_tpeak import calc_sfh_galpop, calc_sfh_singlegal
@@ -0,0 +1 @@
1
+ __version__ = '0.3.3'
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  import os
4
5
  from collections import OrderedDict
5
6
 
@@ -13,7 +14,7 @@ try:
13
14
  except ImportError:
14
15
  HAS_ASTROPY = False
15
16
 
16
- from ..kernels.main_sequence_kernels import _get_bounded_sfr_params_vmap
17
+ from ..kernels.main_sequence_kernels_tpeak import _get_bounded_sfr_params_vmap
17
18
  from ..kernels.quenching_kernels import _get_bounded_q_params_vmap
18
19
  from ..utils import _jax_get_dt_array
19
20
 
@@ -1,10 +1,11 @@
1
1
  """
2
2
  """
3
+
3
4
  # flake8: noqa
4
5
  from collections import namedtuple
5
6
 
6
7
  import numpy as np
7
- from diffmah.defaults import DEFAULT_MAH_PARAMS, DEFAULT_MAH_PDICT
8
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, DEFAULT_MAH_PDICT
8
9
  from jax import jit as jjit
9
10
 
10
11
  TODAY = 13.8
@@ -20,7 +21,7 @@ DEFAULT_N_STEPS = 50
20
21
 
21
22
 
22
23
  from .kernels.gas_consumption import FB
23
- from .kernels.main_sequence_kernels import (
24
+ from .kernels.main_sequence_kernels_tpeak import (
24
25
  DEFAULT_MS_PARAMS,
25
26
  DEFAULT_MS_PDICT,
26
27
  DEFAULT_U_MS_PARAMS,
@@ -146,7 +146,6 @@ def get_loss_data_default(
146
146
  log_smah_sim,
147
147
  logmp,
148
148
  mah_params,
149
- t_peak,
150
149
  dlogm_cut=DLOGM_CUT,
151
150
  t_fit_min=T_FIT_MIN,
152
151
  mass_fit_min=MIN_MASS_CUT,
@@ -251,7 +250,7 @@ def get_loss_data_default(
251
250
  )
252
251
 
253
252
  logt = jnp.log10(t_sim)
254
- dmhdt, log_mah = _diffmah_kern(mah_params, t_sim, t_peak, lgt0)
253
+ dmhdt, log_mah = _diffmah_kern(mah_params, t_sim, lgt0)
255
254
 
256
255
  weight, weight_fstar = get_weights(
257
256
  t_sim,
@@ -1,14 +1,14 @@
1
1
  """
2
2
  """
3
3
 
4
- from diffmah.individual_halo_assembly import _calc_halo_history
4
+ from diffmah import DEFAULT_MAH_PARAMS, mah_singlehalo
5
5
  from jax import jit as jjit
6
6
  from jax import numpy as jnp
7
7
  from jax import vmap
8
8
 
9
- from ..defaults import FB
9
+ from ..defaults import FB, LGT0
10
10
  from ..kernels.gas_consumption import _get_lagged_gas
11
- from ..kernels.main_sequence_kernels import (
11
+ from ..kernels.main_sequence_kernels_tpeak import (
12
12
  MS_BOUNDING_SIGMOID_PDICT,
13
13
  _get_bounded_sfr_params,
14
14
  _sfr_eff_plaw,
@@ -127,13 +127,7 @@ def calculate_sm_sfr_history_from_mah(
127
127
 
128
128
  @jjit
129
129
  def calculate_histories(
130
- lgt,
131
- dt,
132
- mah_params,
133
- u_ms_params,
134
- u_q_params,
135
- fstar_tdelay,
136
- fb=FB,
130
+ lgt, dt, mah_params, u_ms_params, u_q_params, fstar_tdelay, fb=FB, lgt0=LGT0
137
131
  ):
138
132
  """Calculate individual halo mass MAH and galaxy SFH
139
133
 
@@ -185,7 +179,9 @@ def calculate_histories(
185
179
  Base-10 log of cumulative peak halo mass in units of Msun assuming h=1
186
180
 
187
181
  """
188
- dmhdt, log_mah = _calc_halo_history(lgt, *mah_params)
182
+ tarr = 10**lgt
183
+ mah_params = DEFAULT_MAH_PARAMS._make(mah_params)
184
+ dmhdt, log_mah = mah_singlehalo(mah_params, tarr, lgt0)
189
185
  mstar, sfr, fstar = calculate_sm_sfr_fstar_history_from_mah(
190
186
  lgt,
191
187
  dt,
@@ -1,13 +1,14 @@
1
1
  """Functions ms_param_clipper and q_param_clipper implement clips on the diffstar
2
2
  parameters to help protect against NaNs and infinities
3
3
  """
4
+
4
5
  from collections import OrderedDict
5
6
 
6
7
  from jax import jit as jjit
7
8
  from jax import numpy as jnp
8
9
  from jax import vmap
9
10
 
10
- from ..kernels.main_sequence_kernels import MS_PARAM_BOUNDS_PDICT
11
+ from ..kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
11
12
  from ..kernels.quenching_kernels import Q_PARAM_BOUNDS_PDICT
12
13
 
13
14
  _EPS = 0.001
@@ -2,14 +2,12 @@
2
2
  """
3
3
 
4
4
  import numpy as np
5
- from diffmah.defaults import DEFAULT_MAH_PARAMS, MAH_K
6
- from diffmah.individual_halo_assembly import _calc_halo_history
5
+ from diffmah import mah_singlehalo
6
+ from diffmah.defaults import DEFAULT_MAH_PARAMS
7
7
 
8
8
  from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB, LGT0
9
- from ...kernels import get_sfh_from_mah_kern
10
9
  from ...utils import _jax_get_dt_array
11
10
  from ..fitting_kernels import (
12
- _sfr_history_from_mah,
13
11
  calculate_histories,
14
12
  calculate_histories_vmap,
15
13
  calculate_sm_sfr_fstar_history_from_mah,
@@ -19,17 +17,6 @@ from ..fitting_kernels import (
19
17
  DEFAULT_LOGM0 = 12.0
20
18
 
21
19
 
22
- def _get_default_diffmah_args():
23
- return (
24
- LGT0,
25
- DEFAULT_MAH_PARAMS.logmp,
26
- DEFAULT_MAH_PARAMS.logtc,
27
- MAH_K,
28
- DEFAULT_MAH_PARAMS.early_index,
29
- DEFAULT_MAH_PARAMS.late_index,
30
- )
31
-
32
-
33
20
  def test_calculate_sm_sfr_fstar_history_from_mah():
34
21
  n_t = 100
35
22
  tarr = np.linspace(0.1, 10**LGT0, n_t)
@@ -41,7 +28,7 @@ def test_calculate_sm_sfr_fstar_history_from_mah():
41
28
  _mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
42
29
  fstar_indx_high = fstar_indx_high[_mask]
43
30
 
44
- dmhdt, log_mah = _calc_halo_history(lgtarr, *_get_default_diffmah_args())
31
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
45
32
  args = (
46
33
  lgtarr,
47
34
  dtarr,
@@ -68,7 +55,7 @@ def test_calculate_sm_sfr_history_from_mah():
68
55
  tarr = np.linspace(0.1, 10**LGT0, n_t)
69
56
  lgtarr = np.log10(tarr)
70
57
  dtarr = _jax_get_dt_array(tarr)
71
- dmhdt, log_mah = _calc_halo_history(lgtarr, *_get_default_diffmah_args())
58
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
72
59
 
73
60
  args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
74
61
  _res = calculate_sm_sfr_history_from_mah(*args)
@@ -86,7 +73,6 @@ def test_calculate_histories():
86
73
  tarr = np.linspace(0.1, 10**LGT0, n_t)
87
74
  lgtarr = np.log10(tarr)
88
75
  dtarr = _jax_get_dt_array(tarr)
89
- all_diffmah_args = _get_default_diffmah_args()
90
76
 
91
77
  fstar_tdelay = 0.5 # gyr
92
78
  fstar_indx_high = np.searchsorted(tarr, tarr - fstar_tdelay)
@@ -96,7 +82,7 @@ def test_calculate_histories():
96
82
  args = (
97
83
  lgtarr,
98
84
  dtarr,
99
- all_diffmah_args,
85
+ DEFAULT_MAH_PARAMS,
100
86
  DEFAULT_U_MS_PARAMS,
101
87
  DEFAULT_U_Q_PARAMS,
102
88
  fstar_tdelay,
@@ -112,16 +98,12 @@ def test_calculate_histories():
112
98
  assert x.shape == (n_t,)
113
99
  assert np.all(np.diff(log_mah) > 0)
114
100
 
115
- logmp = all_diffmah_args[1]
116
- assert log_mah[-1] == logmp
117
-
118
101
 
119
102
  def test_calculate_histories_vmap():
120
103
  n_t = 100
121
104
  tarr = np.linspace(0.1, 10**LGT0, n_t)
122
105
  lgtarr = np.log10(tarr)
123
106
  dtarr = _jax_get_dt_array(tarr)
124
- all_diffmah_args = np.array(_get_default_diffmah_args()).reshape((1, -1))
125
107
  u_ms_params = np.array(DEFAULT_U_MS_PARAMS).reshape((1, -1))
126
108
  u_q_params = np.array(DEFAULT_U_Q_PARAMS).reshape((1, -1))
127
109
  fstar_tdelay = 0.5 # gyr
@@ -129,11 +111,12 @@ def test_calculate_histories_vmap():
129
111
  _mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
130
112
  fstar_indx_high = fstar_indx_high[_mask]
131
113
 
114
+ mah_params = DEFAULT_MAH_PARAMS._make([np.zeros(1) + x for x in DEFAULT_MAH_PARAMS])
132
115
  # in_axes = (None, None, 0, 0, 0, None, None, None)
133
116
  args = (
134
117
  lgtarr,
135
118
  dtarr,
136
- all_diffmah_args,
119
+ mah_params,
137
120
  u_ms_params,
138
121
  u_q_params,
139
122
  fstar_tdelay,
@@ -145,22 +128,3 @@ def test_calculate_histories_vmap():
145
128
  mstar_galpop, sfr_galpop, fstar_galpop, dmhdt_galpop, log_mah_galpop = _res
146
129
  for x in mstar_galpop, sfr_galpop, dmhdt_galpop, log_mah_galpop:
147
130
  assert x.shape == (1, n_t)
148
-
149
-
150
- def test_sfr_history_from_mah():
151
- n_t = 200
152
- tarr = np.linspace(0.1, 10**LGT0, n_t)
153
- lgtarr = np.log10(tarr)
154
- dtarr = _jax_get_dt_array(tarr)
155
- all_diffmah_args = _get_default_diffmah_args()
156
- dmhdt, log_mah = _calc_halo_history(lgtarr, *all_diffmah_args)
157
- args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
158
- sfh_from_fitting_kernels = _sfr_history_from_mah(*args)
159
- lgt0, logmp, logtc, k, early_index, late_index = all_diffmah_args
160
- mah_params = logmp, logtc, early_index, late_index
161
-
162
- sfh_kern = get_sfh_from_mah_kern(tobs_loop="vmap")
163
- sfh_from_diffstar_kernels = sfh_kern(
164
- tarr, mah_params, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0, FB
165
- )
166
- assert np.allclose(sfh_from_fitting_kernels, sfh_from_diffstar_kernels, atol=0.01)
@@ -1,10 +1,11 @@
1
1
  """Unit tests enforcing that the behavior of Diffstar on the default params is frozen.
2
2
  """
3
+
3
4
  import os
4
5
 
5
6
  import numpy as np
7
+ from diffmah import mah_singlehalo
6
8
  from diffmah.defaults import DEFAULT_MAH_PARAMS, MAH_K
7
- from diffmah.individual_halo_assembly import _calc_halo_history
8
9
  from jax import numpy as jnp
9
10
 
10
11
  from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0
@@ -19,18 +20,6 @@ TESTING_DATA_DRN = os.path.join(
19
20
  )
20
21
 
21
22
 
22
- def _get_default_mah_params():
23
- """Return (logt0, logmp, logtc, k, early, late)"""
24
- return (
25
- LGT0,
26
- DEFAULT_MAH_PARAMS.logmp,
27
- DEFAULT_MAH_PARAMS.logtc,
28
- MAH_K,
29
- DEFAULT_MAH_PARAMS.early_index,
30
- DEFAULT_MAH_PARAMS.late_index,
31
- )
32
-
33
-
34
23
  def _get_default_sfr_u_params():
35
24
  u_ms_params = jnp.array(DEFAULT_U_MS_PARAMS)
36
25
  u_q_params = jnp.array(DEFAULT_U_Q_PARAMS)
@@ -43,11 +32,10 @@ def calc_sfh_on_default_params(n_t=100):
43
32
  This function is used to generate the unit-testing data used in this module
44
33
  to freeze the behavior of Diffstar evaluated on the default parameters.
45
34
  """
46
- mah_params = _get_default_mah_params()
47
35
 
48
36
  lgt = jnp.linspace(-1, LGT0, n_t)
49
37
  dt = _get_dt_array(10**lgt)
50
- dmhdt, log_mah = _calc_halo_history(lgt, *mah_params)
38
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, LGT0)
51
39
  u_ms_params, u_q_params = _get_default_sfr_u_params()
52
40
  args = lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params
53
41
  sfh = _sfr_history_from_mah(*args)
@@ -120,7 +108,14 @@ def test_diffmah_behavior_is_frozen():
120
108
  (at which point this test will need to be updated).
121
109
 
122
110
  """
123
- assumed_default_params = _get_default_mah_params()
111
+ assumed_default_params = (
112
+ LGT0,
113
+ DEFAULT_MAH_PARAMS.logm0,
114
+ DEFAULT_MAH_PARAMS.logtc,
115
+ MAH_K,
116
+ DEFAULT_MAH_PARAMS.early_index,
117
+ DEFAULT_MAH_PARAMS.late_index,
118
+ )
124
119
  lgt0, logmp, logtc, k, early_index, late_index = assumed_default_params
125
120
 
126
121
  msg = "Default age of the universe assumed by Diffmah has changed"
@@ -144,7 +139,7 @@ def test_diffmah_behavior_is_frozen():
144
139
 
145
140
  args, sfh = calc_sfh_on_default_params()
146
141
  lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params = args
147
- dmhdt, log_mah = _calc_halo_history(lgt, *assumed_default_params)
142
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, lgt0)
148
143
 
149
144
  log_mah_fn = os.path.join(TESTING_DATA_DRN, "default_params_test_log_mah.txt")
150
145
  frozen_log_mah = np.loadtxt(log_mah_fn)
@@ -171,26 +166,28 @@ def test_sfh_is_frozen_on_example_bpl_sample():
171
166
  frozen_sfhs = np.loadtxt(sfh_fn)
172
167
  lgt_bpl = np.loadtxt(lgt_fn)
173
168
  dt_bpl = np.loadtxt(dt_fn)
169
+ t_bpl = 10**lgt_bpl
170
+ lgt0 = lgt_bpl[-1]
171
+ assert np.allclose(LGT0_BPL, lgt0, rtol=1e-3)
172
+
174
173
  mah_params_test_sample = np.loadtxt(mah_params_fn)
175
174
  ms_u_params_test_sample = np.loadtxt(ms_params_fn)
176
175
  q_u_params_test_sample = np.loadtxt(q_params_fn)
177
176
 
178
177
  sfh_test_sample = []
179
178
  for ih in range(mah_params_test_sample.shape[0]):
180
- all_mah_params_ih = np.array(
181
- (
182
- LGT0_BPL,
183
- mah_params_test_sample[ih, 0],
184
- mah_params_test_sample[ih, 1],
185
- MAH_K,
186
- mah_params_test_sample[ih, 2],
187
- mah_params_test_sample[ih, 3],
188
- )
179
+ logm0 = mah_params_test_sample[ih, 0]
180
+ logtc = mah_params_test_sample[ih, 1]
181
+ early_indx = mah_params_test_sample[ih, 2]
182
+ late_indx = mah_params_test_sample[ih, 3]
183
+ t_peak = DEFAULT_MAH_PARAMS.t_peak
184
+ mah_params = DEFAULT_MAH_PARAMS._make(
185
+ [logm0, logtc, early_indx, late_indx, t_peak]
189
186
  )
190
187
  ms_u_params_ih = np.array(ms_u_params_test_sample[ih, :])
191
188
  q_u_params_ih = np.array(q_u_params_test_sample[ih, :])
192
189
 
193
- dmhdt_ih, log_mah_ih = _calc_halo_history(lgt_bpl, *all_mah_params_ih)
190
+ dmhdt_ih, log_mah_ih = mah_singlehalo(mah_params, t_bpl, lgt0)
194
191
  sfh_ih = _sfr_history_from_mah(
195
192
  lgt_bpl, dt_bpl, dmhdt_ih, log_mah_ih, ms_u_params_ih, q_u_params_ih
196
193
  )
@@ -2,33 +2,33 @@
2
2
  """
3
3
 
4
4
  import numpy as np
5
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
5
6
 
6
7
  from ...defaults import DEFAULT_MS_PDICT, DEFAULT_Q_PDICT
7
8
  from ...utils import _jax_get_dt_array
8
- from ..fit_smah_helpers import get_header, get_loss_data_fixed_hi
9
+ from ..fit_smah_helpers_tpeak import get_header, get_loss_data_default
9
10
 
10
11
  DIFFMAH_K = 3.5
11
12
 
12
13
 
13
14
  def test_get_header_colnames_agree_with_model_param_names():
14
- header = get_header()
15
+ header, colnames = get_header()
15
16
  assert header[0] == "#"
16
- colnames = header[1:].strip().split()
17
17
 
18
18
  assert colnames[0] == "halo_id"
19
19
 
20
20
  u_ms_colnames_from_header = colnames[1:6]
21
- ms_colnames_from_header = [s[2:] for s in u_ms_colnames_from_header]
21
+ ms_colnames_from_header = [s for s in u_ms_colnames_from_header]
22
22
  assert ms_colnames_from_header == list(DEFAULT_MS_PDICT.keys())
23
23
 
24
24
  u_q_colnames_from_header = colnames[6:10]
25
- q_colnames_from_header = [s[2:] for s in u_q_colnames_from_header]
25
+ q_colnames_from_header = [s for s in u_q_colnames_from_header]
26
26
  assert q_colnames_from_header == list(DEFAULT_Q_PDICT.keys())
27
27
 
28
28
  assert colnames[10:] == ["loss", "success"]
29
29
 
30
30
 
31
- def test_get_loss_data_fixed_hi():
31
+ def test_get_loss_data_evaluates():
32
32
  t_sim = np.linspace(0.1, 13.8, 100)
33
33
  dt_sim = _jax_get_dt_array(t_sim)
34
34
  sfrh = np.random.uniform(0, 10, t_sim.size)
@@ -36,8 +36,6 @@ def test_get_loss_data_fixed_hi():
36
36
  log_smah_sim = np.log10(smh)
37
37
 
38
38
  logmp = 12.0
39
- logtc, early, late = 0.1, 2.0, 1.0
40
- mah_params = logtc, DIFFMAH_K, early, late
41
- p_init, loss_data = get_loss_data_fixed_hi(
42
- t_sim, dt_sim, sfrh, log_smah_sim, logmp, mah_params
39
+ p_init, loss_data = get_loss_data_default(
40
+ t_sim, dt_sim, sfrh, log_smah_sim, logmp, DEFAULT_MAH_PARAMS
43
41
  )
@@ -1,9 +1,10 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  from jax import random as jran
5
6
 
6
- from ...kernels.main_sequence_kernels import (
7
+ from ...kernels.main_sequence_kernels_tpeak import (
7
8
  MS_PARAM_BOUNDS_PDICT,
8
9
  _get_bounded_sfr_params_vmap,
9
10
  _get_unbounded_sfr_params_vmap,
@@ -1,7 +1,8 @@
1
1
  """
2
2
  """
3
+
3
4
  from ...defaults import DEFAULT_MS_PDICT
4
- from ...kernels.main_sequence_kernels import MS_PARAM_BOUNDS_PDICT
5
+ from ...kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
5
6
 
6
7
 
7
8
  def test_sfh_parameter_bounds():
@@ -0,0 +1,4 @@
1
+ """
2
+ """
3
+
4
+ # flake8: noqa
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
 
5
6
  from ...defaults import DEFAULT_Q_PARAMS, DEFAULT_U_Q_PARAMS, Q_PARAM_BOUNDS_PDICT
@@ -1,10 +1,11 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  import pytest
5
6
 
6
7
  from .. import defaults
7
- from ..kernels.main_sequence_kernels import (
8
+ from ..kernels.main_sequence_kernels_tpeak import (
8
9
  DEFAULT_MS_PDICT,
9
10
  DEFAULT_U_MS_PDICT,
10
11
  _get_bounded_sfr_params,
@@ -0,0 +1,62 @@
1
+ """"""
2
+
3
+ import numpy as np
4
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
5
+ from jax import random as jran
6
+
7
+ from ..defaults import (
8
+ DEFAULT_MS_PARAMS,
9
+ DEFAULT_Q_PARAMS,
10
+ DEFAULT_U_MS_PARAMS,
11
+ DEFAULT_U_Q_PARAMS,
12
+ FB,
13
+ LGT0,
14
+ DiffstarUParams,
15
+ MSUParams,
16
+ QUParams,
17
+ get_bounded_diffstar_params,
18
+ )
19
+ from ..sfh_model_tpeak import calc_sfh_singlegal
20
+
21
+
22
+ def _get_all_default_params():
23
+ ms_params, q_params = DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS
24
+ return LGT0, DEFAULT_MAH_PARAMS, ms_params, q_params
25
+
26
+
27
+ def _get_all_default_u_params():
28
+ u_ms_params, u_q_params = DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS
29
+ return LGT0, DEFAULT_MAH_PARAMS, u_ms_params, u_q_params
30
+
31
+
32
+ def test_calc_sfh_singlegal_imports_from_top_level():
33
+ from .. import calc_sfh_singlegal as _func # noqa
34
+
35
+
36
+ def test_calc_sfh_galpop_imports_from_top_level():
37
+ from .. import calc_sfh_galpop as _func # noqa
38
+
39
+
40
+ def test_sfh_singlegal_evaluates_on_wide_param_range():
41
+ lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params()
42
+
43
+ n_t = 100
44
+ tarr = np.linspace(0.1, 10**lgt0, n_t)
45
+
46
+ ran_key = jran.PRNGKey(0)
47
+ ntests = 20
48
+ ran_keys = jran.split(ran_key, ntests)
49
+ for test_key in ran_keys:
50
+ ms_key, q_key = jran.split(test_key, 2)
51
+ u_ms_params = jran.normal(ms_key, shape=(5,)) + np.array(u_ms_params_init)
52
+ u_q_params = jran.normal(q_key, shape=(4,)) + np.array(u_q_params_init)
53
+ sfh_u_params = DiffstarUParams(MSUParams(*u_ms_params), QUParams(*u_q_params))
54
+ sfh_params = get_bounded_diffstar_params(sfh_u_params)
55
+ sfh_new = calc_sfh_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)
56
+ assert np.all(np.isfinite(sfh_new))
57
+
58
+ sfh_new2, smh_new2 = calc_sfh_singlegal(
59
+ sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB, return_smh=True
60
+ )
61
+ assert np.allclose(sfh_new, sfh_new2)
62
+ assert np.all(np.isfinite(smh_new2))
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.9
38
+ Requires-Python: >=3.11
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
- Requires-Dist: diffmah>=0.6.1
41
+ Requires-Dist: diffmah>=0.7.0
42
42
  Requires-Dist: numpy
43
43
  Requires-Dist: jax
44
44
  Requires-Dist: h5py