diffstar 0.3.2__tar.gz → 0.3.3__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (118) hide show
  1. diffstar-0.3.2/.github/workflows/linting.yml → diffstar-0.3.3/.github/workflows/linting.yaml +2 -3
  2. diffstar-0.3.2/.github/workflows/monthly-warning-test.yml → diffstar-0.3.3/.github/workflows/monthly-warning-test.yaml +2 -4
  3. diffstar-0.3.2/.github/workflows/test_releases.yml → diffstar-0.3.3/.github/workflows/test_releases.yaml +2 -4
  4. diffstar-0.3.2/.github/workflows/tests_cron.yml → diffstar-0.3.3/.github/workflows/tests_cron.yaml +2 -4
  5. {diffstar-0.3.2 → diffstar-0.3.3}/CHANGES.rst +5 -0
  6. {diffstar-0.3.2/diffstar.egg-info → diffstar-0.3.3}/PKG-INFO +4 -4
  7. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/__init__.py +2 -1
  8. diffstar-0.3.3/diffstar/_version.py +1 -0
  9. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_bpl.py +2 -1
  10. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/defaults.py +3 -2
  11. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fit_smah_helpers_tpeak.py +1 -2
  12. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/fitting_kernels.py +7 -11
  13. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/param_clippers.py +2 -1
  14. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +7 -43
  15. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +24 -27
  16. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +8 -10
  17. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_param_clippers.py +2 -1
  18. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/test_stars.py +2 -1
  19. diffstar-0.3.3/diffstar/kernels/__init__.py +4 -0
  20. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/test_quenching_kernels.py +1 -0
  21. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_defaults.py +2 -1
  22. diffstar-0.3.3/diffstar/tests/test_sfh_model_tpeak.py +62 -0
  23. {diffstar-0.3.2 → diffstar-0.3.3/diffstar.egg-info}/PKG-INFO +4 -4
  24. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/SOURCES.txt +4 -21
  25. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/requires.txt +1 -1
  26. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/index.rst +0 -1
  27. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/rtd_environment.yaml +2 -2
  28. {diffstar-0.3.2 → diffstar-0.3.3}/pyproject.toml +2 -2
  29. diffstar-0.3.3/requirements.txt +4 -0
  30. diffstar-0.3.2/diffstar/_version.py +0 -1
  31. diffstar-0.3.2/diffstar/fitting_helpers/fit_smah_helpers.py +0 -1715
  32. diffstar-0.3.2/diffstar/kernels/__init__.py +0 -5
  33. diffstar-0.3.2/diffstar/kernels/history_kernel_builders.py +0 -268
  34. diffstar-0.3.2/diffstar/kernels/kernel_builders.py +0 -249
  35. diffstar-0.3.2/diffstar/kernels/main_sequence_kernels.py +0 -233
  36. diffstar-0.3.2/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -121
  37. diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders.py +0 -161
  38. diffstar-0.3.2/diffstar/kernels/tests/test_kernel_builders_tpeak.py +0 -172
  39. diffstar-0.3.2/diffstar/sfh.py +0 -220
  40. diffstar-0.3.2/diffstar/sfh_model.py +0 -140
  41. diffstar-0.3.2/diffstar/tests/test_gas.py +0 -40
  42. diffstar-0.3.2/diffstar/tests/test_lax_main_sequence.py +0 -153
  43. diffstar-0.3.2/diffstar/tests/test_lax_sfh.py +0 -154
  44. diffstar-0.3.2/diffstar/tests/test_main_sequence_kernels.py +0 -68
  45. diffstar-0.3.2/diffstar/tests/test_sfh.py +0 -198
  46. diffstar-0.3.2/diffstar/tests/test_sfh_model.py +0 -156
  47. diffstar-0.3.2/diffstar/tests/test_sfh_model_tpeak.py +0 -155
  48. diffstar-0.3.2/docs/source/demo_diffstar_fitter.ipynb +0 -337
  49. diffstar-0.3.2/docs/source/demos.rst +0 -8
  50. diffstar-0.3.2/requirements.txt +0 -4
  51. {diffstar-0.3.2 → diffstar-0.3.3}/.coveragerc +0 -0
  52. {diffstar-0.3.2 → diffstar-0.3.3}/.git_archival.txt +0 -0
  53. {diffstar-0.3.2 → diffstar-0.3.3}/.gitattributes +0 -0
  54. {diffstar-0.3.2 → diffstar-0.3.3}/.gitignore +0 -0
  55. {diffstar-0.3.2 → diffstar-0.3.3}/.readthedocs.yml +0 -0
  56. {diffstar-0.3.2 → diffstar-0.3.3}/LICENSE.rst +0 -0
  57. {diffstar-0.3.2 → diffstar-0.3.3}/README.md +0 -0
  58. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/__init__.py +0 -0
  59. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/load_smah_data.py +0 -0
  60. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/__init__.py +0 -0
  61. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  62. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/test_load_smah_data.py +0 -0
  63. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/data_loaders/tests/testing_data/subvol_000_diffmah_fits.h5 +0 -0
  64. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/__init__.py +0 -0
  65. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/diffstarnet_tdata.py +0 -0
  66. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/__init__.py +0 -0
  67. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/diffstarnet/tests/test_diffstarnet_tdata.py +0 -0
  68. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/__init__.py +0 -0
  69. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/stars.py +0 -0
  70. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/tests/__init__.py +0 -0
  71. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/fitting_helpers/utils.py +0 -0
  72. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/gas_consumption.py +0 -0
  73. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/history_kernel_builders_tpeak.py +0 -0
  74. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/main_sequence_kernels_tpeak.py +0 -0
  75. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/quenching_kernels.py +0 -0
  76. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/__init__.py +0 -0
  77. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  78. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  79. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  80. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  81. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  82. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  83. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  84. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/sfh_model_tpeak.py +0 -0
  85. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/__init__.py +0 -0
  86. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_quenching.py +0 -0
  87. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/test_utils.py +0 -0
  88. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  89. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  90. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  91. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  92. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  93. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  94. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  95. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  96. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  97. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  98. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  99. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  100. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  101. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  102. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  103. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar/utils.py +0 -0
  104. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/dependency_links.txt +0 -0
  105. {diffstar-0.3.2 → diffstar-0.3.3}/diffstar.egg-info/top_level.txt +0 -0
  106. {diffstar-0.3.2 → diffstar-0.3.3}/docs/Makefile +0 -0
  107. {diffstar-0.3.2 → diffstar-0.3.3}/docs/make.bat +0 -0
  108. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/_static/README.txt +0 -0
  109. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/citation.rst +0 -0
  110. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/conf.py +0 -0
  111. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/demo_diffstar_sfh.ipynb +0 -0
  112. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/installation.rst +0 -0
  113. {diffstar-0.3.2 → diffstar-0.3.3}/docs/source/reference.rst +0 -0
  114. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/generate_unit_testing_data.py +0 -0
  115. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script.py +0 -0
  116. {diffstar-0.3.2 → diffstar-0.3.3}/scripts/history_fitting_script_SMDPL_tpeak.py +0 -0
  117. {diffstar-0.3.2 → diffstar-0.3.3}/setup.cfg +0 -0
  118. {diffstar-0.3.2 → diffstar-0.3.3}/setup.py +0 -0
@@ -21,15 +21,14 @@ jobs:
21
21
  channel-priority: strict
22
22
  show-channel-urls: true
23
23
  miniforge-version: latest
24
- miniforge-variant: Mambaforge
25
24
 
26
25
  - name: configure conda and install code
27
26
  shell: bash -l {0}
28
27
  run: |
29
- mamba install --quiet \
28
+ conda install --quiet \
30
29
  --file=requirements.txt
31
30
  python -m pip install --no-deps -e .
32
- mamba install -y -q \
31
+ conda install -y -q \
33
32
  flake8
34
33
 
35
34
  - name: lint
@@ -23,18 +23,16 @@ jobs:
23
23
  channel-priority: strict
24
24
  show-channel-urls: true
25
25
  miniforge-version: latest
26
- miniforge-variant: Mambaforge
27
- use-mamba: true
28
26
 
29
27
  - name: configure conda and install code
30
28
  # Test against current main branch of diffmah and dsps
31
29
  shell: bash -l {0}
32
30
  run: |
33
31
  conda config --set always_yes yes
34
- mamba install --quiet \
32
+ conda install --quiet \
35
33
  --file=requirements.txt
36
34
  python -m pip install --no-deps -e .
37
- mamba install -y -q \
35
+ conda install -y -q \
38
36
  flake8 \
39
37
  pytest \
40
38
  pytest-xdist \
@@ -27,18 +27,16 @@ jobs:
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
29
29
  miniforge-version: latest
30
- miniforge-variant: Mambaforge
31
- use-mamba: true
32
30
 
33
31
  - name: configure conda and install code
34
32
  # Test against current conda-forge release of diffmah
35
33
  shell: bash -l {0}
36
34
  run: |
37
35
  conda config --set always_yes yes
38
- mamba install --quiet \
36
+ conda install --quiet \
39
37
  --file=requirements.txt
40
38
  python -m pip install --no-deps -e .
41
- mamba install -y -q \
39
+ conda install -y -q \
42
40
  flake8 \
43
41
  pytest \
44
42
  pytest-xdist \
@@ -27,18 +27,16 @@ jobs:
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
29
29
  miniforge-version: latest
30
- miniforge-variant: Mambaforge
31
- use-mamba: true
32
30
 
33
31
  - name: configure conda and install code
34
32
  # Test against current main branch of diffmah and dsps
35
33
  shell: bash -l {0}
36
34
  run: |
37
35
  conda config --set always_yes yes
38
- mamba install --quiet \
36
+ conda install --quiet \
39
37
  --file=requirements.txt
40
38
  python -m pip install --no-deps -e .
41
- mamba install -y -q \
39
+ conda install -y -q \
42
40
  flake8 \
43
41
  pytest \
44
42
  pytest-xdist \
@@ -1,3 +1,8 @@
1
+ 0.3.3 (2025-01-15)
2
+ ------------------
3
+ - Clean out old code so that only tpeak-based diffmah models remain (https://github.com/ArgonneCPAC/diffstar/pull/72)
4
+
5
+
1
6
  0.3.2 (2024-10-25)
2
7
  ------------------
3
8
  - Adapt sfh kernels to be compatible with diffmah 0.6.1
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.9
38
+ Requires-Python: >=3.11
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
- Requires-Dist: diffmah>=0.6.1
41
+ Requires-Dist: diffmah>=0.7.0
42
42
  Requires-Dist: numpy
43
43
  Requires-Dist: jax
44
44
  Requires-Dist: h5py
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  # flake8: noqa
4
5
  from ._version import __version__
5
6
  from .defaults import (
@@ -14,4 +15,4 @@ from .defaults import (
14
15
  get_bounded_diffstar_params,
15
16
  get_unbounded_diffstar_params,
16
17
  )
17
- from .sfh_model import calc_sfh_galpop, calc_sfh_singlegal
18
+ from .sfh_model_tpeak import calc_sfh_galpop, calc_sfh_singlegal
@@ -0,0 +1 @@
1
+ __version__ = '0.3.3'
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  import os
4
5
  from collections import OrderedDict
5
6
 
@@ -13,7 +14,7 @@ try:
13
14
  except ImportError:
14
15
  HAS_ASTROPY = False
15
16
 
16
- from ..kernels.main_sequence_kernels import _get_bounded_sfr_params_vmap
17
+ from ..kernels.main_sequence_kernels_tpeak import _get_bounded_sfr_params_vmap
17
18
  from ..kernels.quenching_kernels import _get_bounded_q_params_vmap
18
19
  from ..utils import _jax_get_dt_array
19
20
 
@@ -1,10 +1,11 @@
1
1
  """
2
2
  """
3
+
3
4
  # flake8: noqa
4
5
  from collections import namedtuple
5
6
 
6
7
  import numpy as np
7
- from diffmah.defaults import DEFAULT_MAH_PARAMS, DEFAULT_MAH_PDICT
8
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, DEFAULT_MAH_PDICT
8
9
  from jax import jit as jjit
9
10
 
10
11
  TODAY = 13.8
@@ -20,7 +21,7 @@ DEFAULT_N_STEPS = 50
20
21
 
21
22
 
22
23
  from .kernels.gas_consumption import FB
23
- from .kernels.main_sequence_kernels import (
24
+ from .kernels.main_sequence_kernels_tpeak import (
24
25
  DEFAULT_MS_PARAMS,
25
26
  DEFAULT_MS_PDICT,
26
27
  DEFAULT_U_MS_PARAMS,
@@ -146,7 +146,6 @@ def get_loss_data_default(
146
146
  log_smah_sim,
147
147
  logmp,
148
148
  mah_params,
149
- t_peak,
150
149
  dlogm_cut=DLOGM_CUT,
151
150
  t_fit_min=T_FIT_MIN,
152
151
  mass_fit_min=MIN_MASS_CUT,
@@ -251,7 +250,7 @@ def get_loss_data_default(
251
250
  )
252
251
 
253
252
  logt = jnp.log10(t_sim)
254
- dmhdt, log_mah = _diffmah_kern(mah_params, t_sim, t_peak, lgt0)
253
+ dmhdt, log_mah = _diffmah_kern(mah_params, t_sim, lgt0)
255
254
 
256
255
  weight, weight_fstar = get_weights(
257
256
  t_sim,
@@ -1,14 +1,14 @@
1
1
  """
2
2
  """
3
3
 
4
- from diffmah.individual_halo_assembly import _calc_halo_history
4
+ from diffmah import DEFAULT_MAH_PARAMS, mah_singlehalo
5
5
  from jax import jit as jjit
6
6
  from jax import numpy as jnp
7
7
  from jax import vmap
8
8
 
9
- from ..defaults import FB
9
+ from ..defaults import FB, LGT0
10
10
  from ..kernels.gas_consumption import _get_lagged_gas
11
- from ..kernels.main_sequence_kernels import (
11
+ from ..kernels.main_sequence_kernels_tpeak import (
12
12
  MS_BOUNDING_SIGMOID_PDICT,
13
13
  _get_bounded_sfr_params,
14
14
  _sfr_eff_plaw,
@@ -127,13 +127,7 @@ def calculate_sm_sfr_history_from_mah(
127
127
 
128
128
  @jjit
129
129
  def calculate_histories(
130
- lgt,
131
- dt,
132
- mah_params,
133
- u_ms_params,
134
- u_q_params,
135
- fstar_tdelay,
136
- fb=FB,
130
+ lgt, dt, mah_params, u_ms_params, u_q_params, fstar_tdelay, fb=FB, lgt0=LGT0
137
131
  ):
138
132
  """Calculate individual halo mass MAH and galaxy SFH
139
133
 
@@ -185,7 +179,9 @@ def calculate_histories(
185
179
  Base-10 log of cumulative peak halo mass in units of Msun assuming h=1
186
180
 
187
181
  """
188
- dmhdt, log_mah = _calc_halo_history(lgt, *mah_params)
182
+ tarr = 10**lgt
183
+ mah_params = DEFAULT_MAH_PARAMS._make(mah_params)
184
+ dmhdt, log_mah = mah_singlehalo(mah_params, tarr, lgt0)
189
185
  mstar, sfr, fstar = calculate_sm_sfr_fstar_history_from_mah(
190
186
  lgt,
191
187
  dt,
@@ -1,13 +1,14 @@
1
1
  """Functions ms_param_clipper and q_param_clipper implement clips on the diffstar
2
2
  parameters to help protect against NaNs and infinities
3
3
  """
4
+
4
5
  from collections import OrderedDict
5
6
 
6
7
  from jax import jit as jjit
7
8
  from jax import numpy as jnp
8
9
  from jax import vmap
9
10
 
10
- from ..kernels.main_sequence_kernels import MS_PARAM_BOUNDS_PDICT
11
+ from ..kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
11
12
  from ..kernels.quenching_kernels import Q_PARAM_BOUNDS_PDICT
12
13
 
13
14
  _EPS = 0.001
@@ -2,14 +2,12 @@
2
2
  """
3
3
 
4
4
  import numpy as np
5
- from diffmah.defaults import DEFAULT_MAH_PARAMS, MAH_K
6
- from diffmah.individual_halo_assembly import _calc_halo_history
5
+ from diffmah import mah_singlehalo
6
+ from diffmah.defaults import DEFAULT_MAH_PARAMS
7
7
 
8
8
  from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB, LGT0
9
- from ...kernels import get_sfh_from_mah_kern
10
9
  from ...utils import _jax_get_dt_array
11
10
  from ..fitting_kernels import (
12
- _sfr_history_from_mah,
13
11
  calculate_histories,
14
12
  calculate_histories_vmap,
15
13
  calculate_sm_sfr_fstar_history_from_mah,
@@ -19,17 +17,6 @@ from ..fitting_kernels import (
19
17
  DEFAULT_LOGM0 = 12.0
20
18
 
21
19
 
22
- def _get_default_diffmah_args():
23
- return (
24
- LGT0,
25
- DEFAULT_MAH_PARAMS.logmp,
26
- DEFAULT_MAH_PARAMS.logtc,
27
- MAH_K,
28
- DEFAULT_MAH_PARAMS.early_index,
29
- DEFAULT_MAH_PARAMS.late_index,
30
- )
31
-
32
-
33
20
  def test_calculate_sm_sfr_fstar_history_from_mah():
34
21
  n_t = 100
35
22
  tarr = np.linspace(0.1, 10**LGT0, n_t)
@@ -41,7 +28,7 @@ def test_calculate_sm_sfr_fstar_history_from_mah():
41
28
  _mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
42
29
  fstar_indx_high = fstar_indx_high[_mask]
43
30
 
44
- dmhdt, log_mah = _calc_halo_history(lgtarr, *_get_default_diffmah_args())
31
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
45
32
  args = (
46
33
  lgtarr,
47
34
  dtarr,
@@ -68,7 +55,7 @@ def test_calculate_sm_sfr_history_from_mah():
68
55
  tarr = np.linspace(0.1, 10**LGT0, n_t)
69
56
  lgtarr = np.log10(tarr)
70
57
  dtarr = _jax_get_dt_array(tarr)
71
- dmhdt, log_mah = _calc_halo_history(lgtarr, *_get_default_diffmah_args())
58
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, tarr, LGT0)
72
59
 
73
60
  args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
74
61
  _res = calculate_sm_sfr_history_from_mah(*args)
@@ -86,7 +73,6 @@ def test_calculate_histories():
86
73
  tarr = np.linspace(0.1, 10**LGT0, n_t)
87
74
  lgtarr = np.log10(tarr)
88
75
  dtarr = _jax_get_dt_array(tarr)
89
- all_diffmah_args = _get_default_diffmah_args()
90
76
 
91
77
  fstar_tdelay = 0.5 # gyr
92
78
  fstar_indx_high = np.searchsorted(tarr, tarr - fstar_tdelay)
@@ -96,7 +82,7 @@ def test_calculate_histories():
96
82
  args = (
97
83
  lgtarr,
98
84
  dtarr,
99
- all_diffmah_args,
85
+ DEFAULT_MAH_PARAMS,
100
86
  DEFAULT_U_MS_PARAMS,
101
87
  DEFAULT_U_Q_PARAMS,
102
88
  fstar_tdelay,
@@ -112,16 +98,12 @@ def test_calculate_histories():
112
98
  assert x.shape == (n_t,)
113
99
  assert np.all(np.diff(log_mah) > 0)
114
100
 
115
- logmp = all_diffmah_args[1]
116
- assert log_mah[-1] == logmp
117
-
118
101
 
119
102
  def test_calculate_histories_vmap():
120
103
  n_t = 100
121
104
  tarr = np.linspace(0.1, 10**LGT0, n_t)
122
105
  lgtarr = np.log10(tarr)
123
106
  dtarr = _jax_get_dt_array(tarr)
124
- all_diffmah_args = np.array(_get_default_diffmah_args()).reshape((1, -1))
125
107
  u_ms_params = np.array(DEFAULT_U_MS_PARAMS).reshape((1, -1))
126
108
  u_q_params = np.array(DEFAULT_U_Q_PARAMS).reshape((1, -1))
127
109
  fstar_tdelay = 0.5 # gyr
@@ -129,11 +111,12 @@ def test_calculate_histories_vmap():
129
111
  _mask = tarr > fstar_tdelay + fstar_tdelay / 2.0
130
112
  fstar_indx_high = fstar_indx_high[_mask]
131
113
 
114
+ mah_params = DEFAULT_MAH_PARAMS._make([np.zeros(1) + x for x in DEFAULT_MAH_PARAMS])
132
115
  # in_axes = (None, None, 0, 0, 0, None, None, None)
133
116
  args = (
134
117
  lgtarr,
135
118
  dtarr,
136
- all_diffmah_args,
119
+ mah_params,
137
120
  u_ms_params,
138
121
  u_q_params,
139
122
  fstar_tdelay,
@@ -145,22 +128,3 @@ def test_calculate_histories_vmap():
145
128
  mstar_galpop, sfr_galpop, fstar_galpop, dmhdt_galpop, log_mah_galpop = _res
146
129
  for x in mstar_galpop, sfr_galpop, dmhdt_galpop, log_mah_galpop:
147
130
  assert x.shape == (1, n_t)
148
-
149
-
150
- def test_sfr_history_from_mah():
151
- n_t = 200
152
- tarr = np.linspace(0.1, 10**LGT0, n_t)
153
- lgtarr = np.log10(tarr)
154
- dtarr = _jax_get_dt_array(tarr)
155
- all_diffmah_args = _get_default_diffmah_args()
156
- dmhdt, log_mah = _calc_halo_history(lgtarr, *all_diffmah_args)
157
- args = lgtarr, dtarr, dmhdt, log_mah, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, FB
158
- sfh_from_fitting_kernels = _sfr_history_from_mah(*args)
159
- lgt0, logmp, logtc, k, early_index, late_index = all_diffmah_args
160
- mah_params = logmp, logtc, early_index, late_index
161
-
162
- sfh_kern = get_sfh_from_mah_kern(tobs_loop="vmap")
163
- sfh_from_diffstar_kernels = sfh_kern(
164
- tarr, mah_params, DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0, FB
165
- )
166
- assert np.allclose(sfh_from_fitting_kernels, sfh_from_diffstar_kernels, atol=0.01)
@@ -1,10 +1,11 @@
1
1
  """Unit tests enforcing that the behavior of Diffstar on the default params is frozen.
2
2
  """
3
+
3
4
  import os
4
5
 
5
6
  import numpy as np
7
+ from diffmah import mah_singlehalo
6
8
  from diffmah.defaults import DEFAULT_MAH_PARAMS, MAH_K
7
- from diffmah.individual_halo_assembly import _calc_halo_history
8
9
  from jax import numpy as jnp
9
10
 
10
11
  from ...defaults import DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS, LGT0
@@ -19,18 +20,6 @@ TESTING_DATA_DRN = os.path.join(
19
20
  )
20
21
 
21
22
 
22
- def _get_default_mah_params():
23
- """Return (logt0, logmp, logtc, k, early, late)"""
24
- return (
25
- LGT0,
26
- DEFAULT_MAH_PARAMS.logmp,
27
- DEFAULT_MAH_PARAMS.logtc,
28
- MAH_K,
29
- DEFAULT_MAH_PARAMS.early_index,
30
- DEFAULT_MAH_PARAMS.late_index,
31
- )
32
-
33
-
34
23
  def _get_default_sfr_u_params():
35
24
  u_ms_params = jnp.array(DEFAULT_U_MS_PARAMS)
36
25
  u_q_params = jnp.array(DEFAULT_U_Q_PARAMS)
@@ -43,11 +32,10 @@ def calc_sfh_on_default_params(n_t=100):
43
32
  This function is used to generate the unit-testing data used in this module
44
33
  to freeze the behavior of Diffstar evaluated on the default parameters.
45
34
  """
46
- mah_params = _get_default_mah_params()
47
35
 
48
36
  lgt = jnp.linspace(-1, LGT0, n_t)
49
37
  dt = _get_dt_array(10**lgt)
50
- dmhdt, log_mah = _calc_halo_history(lgt, *mah_params)
38
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, LGT0)
51
39
  u_ms_params, u_q_params = _get_default_sfr_u_params()
52
40
  args = lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params
53
41
  sfh = _sfr_history_from_mah(*args)
@@ -120,7 +108,14 @@ def test_diffmah_behavior_is_frozen():
120
108
  (at which point this test will need to be updated).
121
109
 
122
110
  """
123
- assumed_default_params = _get_default_mah_params()
111
+ assumed_default_params = (
112
+ LGT0,
113
+ DEFAULT_MAH_PARAMS.logm0,
114
+ DEFAULT_MAH_PARAMS.logtc,
115
+ MAH_K,
116
+ DEFAULT_MAH_PARAMS.early_index,
117
+ DEFAULT_MAH_PARAMS.late_index,
118
+ )
124
119
  lgt0, logmp, logtc, k, early_index, late_index = assumed_default_params
125
120
 
126
121
  msg = "Default age of the universe assumed by Diffmah has changed"
@@ -144,7 +139,7 @@ def test_diffmah_behavior_is_frozen():
144
139
 
145
140
  args, sfh = calc_sfh_on_default_params()
146
141
  lgt, dt, dmhdt, log_mah, u_ms_params, u_q_params = args
147
- dmhdt, log_mah = _calc_halo_history(lgt, *assumed_default_params)
142
+ dmhdt, log_mah = mah_singlehalo(DEFAULT_MAH_PARAMS, 10**lgt, lgt0)
148
143
 
149
144
  log_mah_fn = os.path.join(TESTING_DATA_DRN, "default_params_test_log_mah.txt")
150
145
  frozen_log_mah = np.loadtxt(log_mah_fn)
@@ -171,26 +166,28 @@ def test_sfh_is_frozen_on_example_bpl_sample():
171
166
  frozen_sfhs = np.loadtxt(sfh_fn)
172
167
  lgt_bpl = np.loadtxt(lgt_fn)
173
168
  dt_bpl = np.loadtxt(dt_fn)
169
+ t_bpl = 10**lgt_bpl
170
+ lgt0 = lgt_bpl[-1]
171
+ assert np.allclose(LGT0_BPL, lgt0, rtol=1e-3)
172
+
174
173
  mah_params_test_sample = np.loadtxt(mah_params_fn)
175
174
  ms_u_params_test_sample = np.loadtxt(ms_params_fn)
176
175
  q_u_params_test_sample = np.loadtxt(q_params_fn)
177
176
 
178
177
  sfh_test_sample = []
179
178
  for ih in range(mah_params_test_sample.shape[0]):
180
- all_mah_params_ih = np.array(
181
- (
182
- LGT0_BPL,
183
- mah_params_test_sample[ih, 0],
184
- mah_params_test_sample[ih, 1],
185
- MAH_K,
186
- mah_params_test_sample[ih, 2],
187
- mah_params_test_sample[ih, 3],
188
- )
179
+ logm0 = mah_params_test_sample[ih, 0]
180
+ logtc = mah_params_test_sample[ih, 1]
181
+ early_indx = mah_params_test_sample[ih, 2]
182
+ late_indx = mah_params_test_sample[ih, 3]
183
+ t_peak = DEFAULT_MAH_PARAMS.t_peak
184
+ mah_params = DEFAULT_MAH_PARAMS._make(
185
+ [logm0, logtc, early_indx, late_indx, t_peak]
189
186
  )
190
187
  ms_u_params_ih = np.array(ms_u_params_test_sample[ih, :])
191
188
  q_u_params_ih = np.array(q_u_params_test_sample[ih, :])
192
189
 
193
- dmhdt_ih, log_mah_ih = _calc_halo_history(lgt_bpl, *all_mah_params_ih)
190
+ dmhdt_ih, log_mah_ih = mah_singlehalo(mah_params, t_bpl, lgt0)
194
191
  sfh_ih = _sfr_history_from_mah(
195
192
  lgt_bpl, dt_bpl, dmhdt_ih, log_mah_ih, ms_u_params_ih, q_u_params_ih
196
193
  )
@@ -2,33 +2,33 @@
2
2
  """
3
3
 
4
4
  import numpy as np
5
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
5
6
 
6
7
  from ...defaults import DEFAULT_MS_PDICT, DEFAULT_Q_PDICT
7
8
  from ...utils import _jax_get_dt_array
8
- from ..fit_smah_helpers import get_header, get_loss_data_fixed_hi
9
+ from ..fit_smah_helpers_tpeak import get_header, get_loss_data_default
9
10
 
10
11
  DIFFMAH_K = 3.5
11
12
 
12
13
 
13
14
  def test_get_header_colnames_agree_with_model_param_names():
14
- header = get_header()
15
+ header, colnames = get_header()
15
16
  assert header[0] == "#"
16
- colnames = header[1:].strip().split()
17
17
 
18
18
  assert colnames[0] == "halo_id"
19
19
 
20
20
  u_ms_colnames_from_header = colnames[1:6]
21
- ms_colnames_from_header = [s[2:] for s in u_ms_colnames_from_header]
21
+ ms_colnames_from_header = [s for s in u_ms_colnames_from_header]
22
22
  assert ms_colnames_from_header == list(DEFAULT_MS_PDICT.keys())
23
23
 
24
24
  u_q_colnames_from_header = colnames[6:10]
25
- q_colnames_from_header = [s[2:] for s in u_q_colnames_from_header]
25
+ q_colnames_from_header = [s for s in u_q_colnames_from_header]
26
26
  assert q_colnames_from_header == list(DEFAULT_Q_PDICT.keys())
27
27
 
28
28
  assert colnames[10:] == ["loss", "success"]
29
29
 
30
30
 
31
- def test_get_loss_data_fixed_hi():
31
+ def test_get_loss_data_evaluates():
32
32
  t_sim = np.linspace(0.1, 13.8, 100)
33
33
  dt_sim = _jax_get_dt_array(t_sim)
34
34
  sfrh = np.random.uniform(0, 10, t_sim.size)
@@ -36,8 +36,6 @@ def test_get_loss_data_fixed_hi():
36
36
  log_smah_sim = np.log10(smh)
37
37
 
38
38
  logmp = 12.0
39
- logtc, early, late = 0.1, 2.0, 1.0
40
- mah_params = logtc, DIFFMAH_K, early, late
41
- p_init, loss_data = get_loss_data_fixed_hi(
42
- t_sim, dt_sim, sfrh, log_smah_sim, logmp, mah_params
39
+ p_init, loss_data = get_loss_data_default(
40
+ t_sim, dt_sim, sfrh, log_smah_sim, logmp, DEFAULT_MAH_PARAMS
43
41
  )
@@ -1,9 +1,10 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  from jax import random as jran
5
6
 
6
- from ...kernels.main_sequence_kernels import (
7
+ from ...kernels.main_sequence_kernels_tpeak import (
7
8
  MS_PARAM_BOUNDS_PDICT,
8
9
  _get_bounded_sfr_params_vmap,
9
10
  _get_unbounded_sfr_params_vmap,
@@ -1,7 +1,8 @@
1
1
  """
2
2
  """
3
+
3
4
  from ...defaults import DEFAULT_MS_PDICT
4
- from ...kernels.main_sequence_kernels import MS_PARAM_BOUNDS_PDICT
5
+ from ...kernels.main_sequence_kernels_tpeak import MS_PARAM_BOUNDS_PDICT
5
6
 
6
7
 
7
8
  def test_sfh_parameter_bounds():
@@ -0,0 +1,4 @@
1
+ """
2
+ """
3
+
4
+ # flake8: noqa
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
 
5
6
  from ...defaults import DEFAULT_Q_PARAMS, DEFAULT_U_Q_PARAMS, Q_PARAM_BOUNDS_PDICT
@@ -1,10 +1,11 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  import pytest
5
6
 
6
7
  from .. import defaults
7
- from ..kernels.main_sequence_kernels import (
8
+ from ..kernels.main_sequence_kernels_tpeak import (
8
9
  DEFAULT_MS_PDICT,
9
10
  DEFAULT_U_MS_PDICT,
10
11
  _get_bounded_sfr_params,
@@ -0,0 +1,62 @@
1
+ """"""
2
+
3
+ import numpy as np
4
+ from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS
5
+ from jax import random as jran
6
+
7
+ from ..defaults import (
8
+ DEFAULT_MS_PARAMS,
9
+ DEFAULT_Q_PARAMS,
10
+ DEFAULT_U_MS_PARAMS,
11
+ DEFAULT_U_Q_PARAMS,
12
+ FB,
13
+ LGT0,
14
+ DiffstarUParams,
15
+ MSUParams,
16
+ QUParams,
17
+ get_bounded_diffstar_params,
18
+ )
19
+ from ..sfh_model_tpeak import calc_sfh_singlegal
20
+
21
+
22
+ def _get_all_default_params():
23
+ ms_params, q_params = DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS
24
+ return LGT0, DEFAULT_MAH_PARAMS, ms_params, q_params
25
+
26
+
27
+ def _get_all_default_u_params():
28
+ u_ms_params, u_q_params = DEFAULT_U_MS_PARAMS, DEFAULT_U_Q_PARAMS
29
+ return LGT0, DEFAULT_MAH_PARAMS, u_ms_params, u_q_params
30
+
31
+
32
+ def test_calc_sfh_singlegal_imports_from_top_level():
33
+ from .. import calc_sfh_singlegal as _func # noqa
34
+
35
+
36
+ def test_calc_sfh_galpop_imports_from_top_level():
37
+ from .. import calc_sfh_galpop as _func # noqa
38
+
39
+
40
+ def test_sfh_singlegal_evaluates_on_wide_param_range():
41
+ lgt0, mah_params, u_ms_params_init, u_q_params_init = _get_all_default_u_params()
42
+
43
+ n_t = 100
44
+ tarr = np.linspace(0.1, 10**lgt0, n_t)
45
+
46
+ ran_key = jran.PRNGKey(0)
47
+ ntests = 20
48
+ ran_keys = jran.split(ran_key, ntests)
49
+ for test_key in ran_keys:
50
+ ms_key, q_key = jran.split(test_key, 2)
51
+ u_ms_params = jran.normal(ms_key, shape=(5,)) + np.array(u_ms_params_init)
52
+ u_q_params = jran.normal(q_key, shape=(4,)) + np.array(u_q_params_init)
53
+ sfh_u_params = DiffstarUParams(MSUParams(*u_ms_params), QUParams(*u_q_params))
54
+ sfh_params = get_bounded_diffstar_params(sfh_u_params)
55
+ sfh_new = calc_sfh_singlegal(sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB)
56
+ assert np.all(np.isfinite(sfh_new))
57
+
58
+ sfh_new2, smh_new2 = calc_sfh_singlegal(
59
+ sfh_params, mah_params, tarr, lgt0=lgt0, fb=FB, return_smh=True
60
+ )
61
+ assert np.allclose(sfh_new, sfh_new2)
62
+ assert np.all(np.isfinite(smh_new2))
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,10 +35,10 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.9
38
+ Requires-Python: >=3.11
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
- Requires-Dist: diffmah>=0.6.1
41
+ Requires-Dist: diffmah>=0.7.0
42
42
  Requires-Dist: numpy
43
43
  Requires-Dist: jax
44
44
  Requires-Dist: h5py