diffstar 0.3.3__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. diffstar-0.3.4/.github/dependabot.yml +10 -0
  2. {diffstar-0.3.3 → diffstar-0.3.4}/.github/workflows/linting.yaml +2 -2
  3. {diffstar-0.3.3 → diffstar-0.3.4}/.github/workflows/monthly-warning-test.yaml +2 -2
  4. {diffstar-0.3.3 → diffstar-0.3.4}/.github/workflows/test_releases.yaml +3 -3
  5. {diffstar-0.3.3 → diffstar-0.3.4}/.github/workflows/tests_cron.yaml +2 -2
  6. {diffstar-0.3.3 → diffstar-0.3.4}/.readthedocs.yml +1 -1
  7. {diffstar-0.3.3 → diffstar-0.3.4}/CHANGES.rst +5 -0
  8. {diffstar-0.3.3/diffstar.egg-info → diffstar-0.3.4}/PKG-INFO +1 -1
  9. diffstar-0.3.4/diffstar/_version.py +1 -0
  10. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/diffstarnet/diffstarnet_tdata.py +41 -3
  11. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/diffstarnet/tests/test_diffstarnet_tdata.py +20 -0
  12. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/test_utils.py +35 -10
  13. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/utils.py +33 -0
  14. {diffstar-0.3.3 → diffstar-0.3.4/diffstar.egg-info}/PKG-INFO +1 -1
  15. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar.egg-info/SOURCES.txt +1 -0
  16. diffstar-0.3.3/diffstar/_version.py +0 -1
  17. {diffstar-0.3.3 → diffstar-0.3.4}/.coveragerc +0 -0
  18. {diffstar-0.3.3 → diffstar-0.3.4}/.git_archival.txt +0 -0
  19. {diffstar-0.3.3 → diffstar-0.3.4}/.gitattributes +0 -0
  20. {diffstar-0.3.3 → diffstar-0.3.4}/.gitignore +0 -0
  21. {diffstar-0.3.3 → diffstar-0.3.4}/LICENSE.rst +0 -0
  22. {diffstar-0.3.3 → diffstar-0.3.4}/README.md +0 -0
  23. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/__init__.py +0 -0
  24. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/__init__.py +0 -0
  25. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/load_bpl.py +0 -0
  26. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/load_smah_data.py +0 -0
  27. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/tests/__init__.py +0 -0
  28. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  29. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/tests/test_load_smah_data.py +0 -0
  30. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/data_loaders/tests/testing_data/subvol_000_diffmah_fits.h5 +0 -0
  31. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/defaults.py +0 -0
  32. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/diffstarnet/__init__.py +0 -0
  33. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/diffstarnet/tests/__init__.py +0 -0
  34. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/__init__.py +0 -0
  35. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/fit_smah_helpers_tpeak.py +0 -0
  36. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/fitting_kernels.py +0 -0
  37. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/param_clippers.py +0 -0
  38. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/stars.py +0 -0
  39. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/__init__.py +0 -0
  40. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +0 -0
  41. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +0 -0
  42. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +0 -0
  43. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/test_param_clippers.py +0 -0
  44. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/tests/test_stars.py +0 -0
  45. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/fitting_helpers/utils.py +0 -0
  46. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/__init__.py +0 -0
  47. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/gas_consumption.py +0 -0
  48. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/history_kernel_builders_tpeak.py +0 -0
  49. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/main_sequence_kernels_tpeak.py +0 -0
  50. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/quenching_kernels.py +0 -0
  51. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/__init__.py +0 -0
  52. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/test_quenching_kernels.py +0 -0
  53. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  54. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  55. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  56. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  57. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  58. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  59. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  60. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/sfh_model_tpeak.py +0 -0
  61. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/__init__.py +0 -0
  62. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/test_defaults.py +0 -0
  63. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/test_quenching.py +0 -0
  64. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/test_sfh_model_tpeak.py +0 -0
  65. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  66. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  67. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  68. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  69. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  70. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  71. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  72. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  73. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  74. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  75. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  76. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  77. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  78. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  79. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  80. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar.egg-info/dependency_links.txt +0 -0
  81. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar.egg-info/requires.txt +0 -0
  82. {diffstar-0.3.3 → diffstar-0.3.4}/diffstar.egg-info/top_level.txt +0 -0
  83. {diffstar-0.3.3 → diffstar-0.3.4}/docs/Makefile +0 -0
  84. {diffstar-0.3.3 → diffstar-0.3.4}/docs/make.bat +0 -0
  85. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/_static/README.txt +0 -0
  86. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/citation.rst +0 -0
  87. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/conf.py +0 -0
  88. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/demo_diffstar_sfh.ipynb +0 -0
  89. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/index.rst +0 -0
  90. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/installation.rst +0 -0
  91. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/reference.rst +0 -0
  92. {diffstar-0.3.3 → diffstar-0.3.4}/docs/source/rtd_environment.yaml +0 -0
  93. {diffstar-0.3.3 → diffstar-0.3.4}/pyproject.toml +0 -0
  94. {diffstar-0.3.3 → diffstar-0.3.4}/requirements.txt +0 -0
  95. {diffstar-0.3.3 → diffstar-0.3.4}/scripts/generate_unit_testing_data.py +0 -0
  96. {diffstar-0.3.3 → diffstar-0.3.4}/scripts/history_fitting_script.py +0 -0
  97. {diffstar-0.3.3 → diffstar-0.3.4}/scripts/history_fitting_script_SMDPL_tpeak.py +0 -0
  98. {diffstar-0.3.3 → diffstar-0.3.4}/setup.cfg +0 -0
  99. {diffstar-0.3.3 → diffstar-0.3.4}/setup.py +0 -0
@@ -0,0 +1,10 @@
1
+ version: 2
2
+ updates:
3
+ - package-ecosystem: "github-actions"
4
+ directory: "/"
5
+ schedule:
6
+ interval: "monthly"
7
+ groups:
8
+ github-actions:
9
+ patterns:
10
+ - '*'
@@ -12,9 +12,9 @@ jobs:
12
12
  runs-on: "ubuntu-latest"
13
13
 
14
14
  steps:
15
- - uses: actions/checkout@v2
15
+ - uses: actions/checkout@v4
16
16
 
17
- - uses: conda-incubator/setup-miniconda@v2
17
+ - uses: conda-incubator/setup-miniconda@v3
18
18
  with:
19
19
  python-version: 3.11
20
20
  channels: conda-forge,defaults
@@ -12,11 +12,11 @@ jobs:
12
12
  runs-on: "ubuntu-latest"
13
13
 
14
14
  steps:
15
- - uses: actions/checkout@v2
15
+ - uses: actions/checkout@v4
16
16
  with:
17
17
  fetch-depth: 0
18
18
 
19
- - uses: conda-incubator/setup-miniconda@v2
19
+ - uses: conda-incubator/setup-miniconda@v3
20
20
  with:
21
21
  python-version: 3.11
22
22
  channels: conda-forge,defaults
@@ -16,11 +16,11 @@ jobs:
16
16
  runs-on: "ubuntu-latest"
17
17
 
18
18
  steps:
19
- - uses: actions/checkout@v2
19
+ - uses: actions/checkout@v4
20
20
  with:
21
21
  fetch-depth: 0
22
22
 
23
- - uses: conda-incubator/setup-miniconda@v2
23
+ - uses: conda-incubator/setup-miniconda@v3
24
24
  with:
25
25
  python-version: 3.11
26
26
  channels: conda-forge,defaults
@@ -53,7 +53,7 @@ jobs:
53
53
  pytest -v diffstar --cov --cov-report=xml
54
54
 
55
55
  - name: Upload coverage reports to Codecov
56
- uses: codecov/codecov-action@v3
56
+ uses: codecov/codecov-action@v5
57
57
 
58
58
  - name: test versions
59
59
  shell: bash -el {0}
@@ -16,11 +16,11 @@ jobs:
16
16
  runs-on: "ubuntu-latest"
17
17
 
18
18
  steps:
19
- - uses: actions/checkout@v2
19
+ - uses: actions/checkout@v4
20
20
  with:
21
21
  fetch-depth: 0
22
22
 
23
- - uses: conda-incubator/setup-miniconda@v2
23
+ - uses: conda-incubator/setup-miniconda@v3
24
24
  with:
25
25
  python-version: 3.11
26
26
  channels: conda-forge,defaults
@@ -6,7 +6,7 @@ sphinx:
6
6
  build:
7
7
  os: "ubuntu-22.04"
8
8
  tools:
9
- python: "mambaforge-22.9"
9
+ python: "miniconda-latest"
10
10
 
11
11
  conda:
12
12
  environment: docs/source/rtd_environment.yaml
@@ -1,3 +1,8 @@
1
+ 0.3.4 (2025-03-03)
2
+ ------------------
3
+ - Add convenience function cumulative_mstar_formed_galpop (https://github.com/ArgonneCPAC/diffstar/pull/77)
4
+
5
+
1
6
  0.3.3 (2025-01-15)
2
7
  ------------------
3
8
  - Clean out old code so that only tpeak-based diffmah models remain (https://github.com/ArgonneCPAC/diffstar/pull/72)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -0,0 +1 @@
1
+ __version__ = '0.3.4'
@@ -34,10 +34,45 @@ _TDATA_SFH_ROOTKEYS = ["sfh_params", "sfh", "smh"]
34
34
  _TDATA_NOQ_KEYS = [key + "_noq" for key in _TDATA_SFH_ROOTKEYS]
35
35
  _TDATA_NOQ_NOLAG_KEYS = [key + "_nolag" for key in _TDATA_NOQ_KEYS]
36
36
  SFH_KEYS = _TDATA_SFH_ROOTKEYS + _TDATA_NOQ_KEYS + _TDATA_NOQ_NOLAG_KEYS
37
- TDATA_KEYS = ["mah_params", "log_mah"] + SFH_KEYS
37
+ TDATA_KEYS = ["mah_params", "log_mah"] + SFH_KEYS + ["time_arr"]
38
38
  TData = namedtuple("TData", TDATA_KEYS)
39
39
 
40
40
 
41
+ def tdata_generator_dithertarr(
42
+ ran_key,
43
+ logm0_sample,
44
+ n_sfh_table=N_SFH_TABLE,
45
+ logsm0_min=LGSM0_MIN,
46
+ n_epochs=float("inf")
47
+ ):
48
+ """
49
+ Same as tdata_generator, but for each generation, a new t_table_min value
50
+ is drawn uniformly within the first bin (and same for logm0_sample)
51
+ """
52
+ min_m0, max_m0 = logm0_sample.min(), logm0_sample.max()
53
+ num_m0 = logm0_sample.size
54
+
55
+ max_dither_t = (T0 - T_TABLE_MIN) / (n_sfh_table - 1)
56
+ max_dither_m = (max_m0 - min_m0) / (num_m0 - 1)
57
+ batchnum = 0
58
+ while batchnum < n_epochs:
59
+ ran_key, batch_key, *dither_keys = jran.split(ran_key, 4)
60
+ dither1, dither2 = jran.uniform(
61
+ dither_keys[0], (2,), maxval=max_dither_t
62
+ )
63
+ tarr = np.linspace(T_TABLE_MIN + dither1, T0 - dither2, n_sfh_table)
64
+ dither1, dither2 = jran.uniform(
65
+ dither_keys[1], (2,), maxval=max_dither_m
66
+ )
67
+ logm0_new = np.linspace(min_m0 + dither1, max_m0 - dither2, num_m0)
68
+ tdata = _compute_tdata(
69
+ batch_key, logm0_new, n_sfh_table, logsm0_min,
70
+ tarr=tarr
71
+ )
72
+ yield tdata
73
+ batchnum += 1
74
+
75
+
41
76
  def tdata_generator(
42
77
  ran_key,
43
78
  logm0_sample,
@@ -78,10 +113,12 @@ def tdata_generator(
78
113
 
79
114
 
80
115
  def _compute_tdata(
81
- ran_key, logm0_sample, n_sfh_table=N_SFH_TABLE, logsm0_min=LGSM0_MIN
116
+ ran_key, logm0_sample, n_sfh_table=N_SFH_TABLE, logsm0_min=LGSM0_MIN,
117
+ tarr=None,
82
118
  ):
83
119
  """"""
84
- tarr = np.linspace(T_TABLE_MIN, T0, n_sfh_table)
120
+ if tarr is None:
121
+ tarr = np.linspace(T_TABLE_MIN, T0, n_sfh_table)
85
122
 
86
123
  mah_key, early_late_key, sfh_key = jran.split(ran_key, 3)
87
124
 
@@ -199,6 +236,7 @@ def _compute_tdata(
199
236
  sfh_params_noq_nolag_out,
200
237
  sfh_noq_nolag_out,
201
238
  smh_noq_nolag_out,
239
+ tarr
202
240
  )
203
241
 
204
242
 
@@ -29,6 +29,26 @@ def enforce_good_tdata(tdata, logsm0_min=float("-inf")):
29
29
  assert arr.shape == (n_halos, n_times)
30
30
 
31
31
 
32
+ def test_tdata_generator_dithertarr():
33
+ ran_key = jran.key(0)
34
+ n_halos = 5_000
35
+ logm0_sample = np.linspace(10, 15, n_halos)
36
+
37
+ # generate 5 epochs of data
38
+ n_epochs = 5
39
+ gen = dtg.tdata_generator_dithertarr(ran_key, logm0_sample, n_epochs=n_epochs)
40
+ tdata_list = list(gen)
41
+ assert len(tdata_list) == n_epochs
42
+ for tdata in tdata_list:
43
+ enforce_good_tdata(tdata, logsm0_min=dtg.LGSM0_MIN)
44
+
45
+ # Assert that the time arrays start and end from unique values
46
+ start_times = [x.time_arr[0] for x in tdata_list]
47
+ end_times = [x.time_arr[-1] for x in tdata_list]
48
+ assert len(set(start_times)) == len(start_times)
49
+ assert len(set(end_times)) == len(end_times)
50
+
51
+
32
52
  def test_tdata_generator():
33
53
  ran_key = jran.key(0)
34
54
  n_halos = 5_000
@@ -1,11 +1,12 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  import pytest
5
6
  from jax import random as jran
6
7
 
8
+ from .. import utils
7
9
  from ..defaults import T_TABLE_MIN
8
- from ..utils import _get_dt_array, _jax_get_dt_array, cumtrapz, cumulative_mstar_formed
9
10
 
10
11
  try:
11
12
  import dsps
@@ -14,13 +15,21 @@ try:
14
15
  except ImportError:
15
16
  HAS_DSPS = False
16
17
 
18
+ try:
19
+ from scipy.integrate import trapezoid
20
+
21
+ HAS_SCIPY = True
22
+ except ImportError:
23
+ HAS_SCIPY = False
24
+
17
25
  MSG_HAS_DSPS = "Must have dsps installed to run this test"
26
+ MSG_HAS_SCIPY = "Must have scipy installed to run this test"
18
27
 
19
28
 
20
29
  def test_jax_get_dt_array_linspace():
21
30
  tarr = np.linspace(1, 13.8, 50)
22
- dtarr_np = _get_dt_array(tarr)
23
- dtarr_jnp = _jax_get_dt_array(tarr)
31
+ dtarr_np = utils._get_dt_array(tarr)
32
+ dtarr_jnp = utils._jax_get_dt_array(tarr)
24
33
  assert np.allclose(dtarr_np, dtarr_jnp, atol=0.01)
25
34
 
26
35
 
@@ -30,11 +39,12 @@ def test_jax_get_dt_array_random():
30
39
  for __ in range(n_tests):
31
40
  ran_key, key = jran.split(ran_key, 2)
32
41
  tarr = np.sort(jran.uniform(key, minval=0, maxval=14, shape=(50,)))
33
- dtarr_np = _get_dt_array(tarr)
34
- dtarr_jnp = _jax_get_dt_array(tarr)
42
+ dtarr_np = utils._get_dt_array(tarr)
43
+ dtarr_jnp = utils._jax_get_dt_array(tarr)
35
44
  assert np.allclose(dtarr_np, dtarr_jnp, atol=0.01)
36
45
 
37
46
 
47
+ @pytest.mark.skipif(not HAS_SCIPY, reason=MSG_HAS_SCIPY)
38
48
  def test_cumtrapz():
39
49
  ran_key = jran.PRNGKey(0)
40
50
  n_x = 100
@@ -43,16 +53,16 @@ def test_cumtrapz():
43
53
  x_key, y_key, ran_key = jran.split(ran_key, 3)
44
54
  xarr = np.sort(jran.uniform(x_key, minval=0, maxval=1, shape=(n_x,)))
45
55
  yarr = jran.uniform(y_key, minval=0, maxval=1, shape=(n_x,))
46
- jax_result = cumtrapz(xarr, yarr)
47
- np_result = [np.trapz(yarr[:-i], x=xarr[:-i]) for i in range(1, n_x)][::-1]
56
+ jax_result = utils.cumtrapz(xarr, yarr)
57
+ np_result = [trapezoid(yarr[:-i], x=xarr[:-i]) for i in range(1, n_x)][::-1]
48
58
  assert np.allclose(jax_result[:-1], np_result, rtol=1e-4)
49
- assert np.allclose(jax_result[-1], np.trapz(yarr, x=xarr), rtol=1e-4)
59
+ assert np.allclose(jax_result[-1], trapezoid(yarr, x=xarr), rtol=1e-4)
50
60
 
51
61
 
52
62
  def test_cumulative_mstar_formed_returns_reasonable_arrays():
53
63
  t_table = np.linspace(T_TABLE_MIN, 13.8, 200)
54
64
  sfh_table = np.random.uniform(0, 1, t_table.size)
55
- smh_table = cumulative_mstar_formed(t_table, sfh_table)
65
+ smh_table = utils.cumulative_mstar_formed(t_table, sfh_table)
56
66
  assert smh_table.shape == t_table.shape
57
67
  assert np.all(smh_table > 0)
58
68
  assert np.all(np.diff(smh_table) > 0)
@@ -67,6 +77,21 @@ def test_cumulative_mstar_formed_agrees_with_dsps():
67
77
  ran_keys = jran.split(ran_key, n_tests)
68
78
  for key in ran_keys:
69
79
  sfh_table = jran.uniform(key, minval=0, maxval=1, shape=(nt,))
70
- smh_table_diffstar = cumulative_mstar_formed(t_table, sfh_table)
80
+ smh_table_diffstar = utils.cumulative_mstar_formed(t_table, sfh_table)
71
81
  smh_table_dsps = dsps.utils.cumulative_mstar_formed(t_table, sfh_table)
72
82
  assert np.allclose(smh_table_diffstar, smh_table_dsps, rtol=1e-4)
83
+
84
+
85
+ def test_cumulative_mstar_formed_vmap():
86
+
87
+ n_t = 200
88
+ t_table = np.linspace(T_TABLE_MIN, 13.8, n_t)
89
+ ran_key = jran.PRNGKey(0)
90
+
91
+ n_gals = 25
92
+ sfh_table_galpop = jran.uniform(ran_key, minval=0, maxval=1, shape=(n_gals, n_t))
93
+ smh_table_galpop = utils.cumulative_mstar_formed_galpop(t_table, sfh_table_galpop)
94
+
95
+ for ig in range(n_gals):
96
+ smh_table_ig = utils.cumulative_mstar_formed(t_table, sfh_table_galpop[ig, :])
97
+ assert np.allclose(smh_table_ig, smh_table_galpop[ig, :], rtol=1e-5)
@@ -5,6 +5,7 @@ import numpy as np
5
5
  from jax import jit as jjit
6
6
  from jax import lax, nn
7
7
  from jax import numpy as jnp
8
+ from jax import vmap
8
9
  from jax.lax import scan
9
10
 
10
11
  from .defaults import SFR_MIN, T_BIRTH_MIN
@@ -273,3 +274,35 @@ def cumulative_mstar_formed(t_table, sfh_table):
273
274
  mstar_formed = cumtrapz(padded_t_table, padded_sfh_table)[1:] * YEAR_PER_GYR
274
275
 
275
276
  return mstar_formed
277
+
278
+
279
+ _cuml_mstar_vmap = jjit(vmap(cumulative_mstar_formed, in_axes=(None, 0)))
280
+
281
+
282
+ @jjit
283
+ def cumulative_mstar_formed_galpop(t_table, sfh_table):
284
+ """Compute the cumulative stellar mass formed at each input time
285
+
286
+ Parameters
287
+ ----------
288
+ t_table : ndarray, shape (n_t, )
289
+ Age of the Universe in Gyr.
290
+ Array should be monotonically increasing and
291
+ t_table[0] >= dsps.constants.T_TABLE_MIN
292
+
293
+ sfh_table : ndarray, shape (n_gals, n_t)
294
+ SFR in Msun/yr at each of the input times
295
+
296
+ Returns
297
+ -------
298
+ mstar_formed : ndarray, shape (n_gals, n_t)
299
+ Cumulative stellar mass formed in Msun at each input time
300
+
301
+ Notes
302
+ -----
303
+ Mstar formed is calculated using trapezoidal integration
304
+ and assuming that during the interval (dsps.constants.T_BIRTH_MIN, t_table[0]),
305
+ SFH is constant and equal to dsps.constants.SFR_MIN
306
+
307
+ """
308
+ return _cuml_mstar_vmap(t_table, sfh_table)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: diffstar
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -10,6 +10,7 @@ pyproject.toml
10
10
  requirements.txt
11
11
  setup.cfg
12
12
  setup.py
13
+ .github/dependabot.yml
13
14
  .github/workflows/linting.yaml
14
15
  .github/workflows/monthly-warning-test.yaml
15
16
  .github/workflows/test_releases.yaml
@@ -1 +0,0 @@
1
- __version__ = '0.3.3'
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes