diffstar 0.3.0__tar.gz → 0.3.1__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102)
  1. {diffstar-0.3.0 → diffstar-0.3.1}/.github/workflows/linting.yml +1 -1
  2. diffstar-0.3.1/.github/workflows/monthly-warning-test.yml +55 -0
  3. {diffstar-0.3.0 → diffstar-0.3.1}/.github/workflows/test_releases.yml +1 -1
  4. {diffstar-0.3.0 → diffstar-0.3.1}/.github/workflows/tests_cron.yml +1 -1
  5. {diffstar-0.3.0 → diffstar-0.3.1}/CHANGES.rst +5 -0
  6. {diffstar-0.3.0/diffstar.egg-info → diffstar-0.3.1}/PKG-INFO +2 -2
  7. diffstar-0.3.1/diffstar/_version.py +1 -0
  8. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/gas_consumption.py +6 -0
  9. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/history_kernel_builders.py +15 -1
  10. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/main_sequence_kernels.py +36 -2
  11. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/sfh_model.py +3 -0
  12. diffstar-0.3.1/diffstar/tests/test_main_sequence_kernels.py +28 -0
  13. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/utils.py +3 -2
  14. {diffstar-0.3.0 → diffstar-0.3.1/diffstar.egg-info}/PKG-INFO +2 -2
  15. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar.egg-info/SOURCES.txt +2 -0
  16. {diffstar-0.3.0 → diffstar-0.3.1}/pyproject.toml +1 -1
  17. diffstar-0.3.0/diffstar/_version.py +0 -1
  18. {diffstar-0.3.0 → diffstar-0.3.1}/.coveragerc +0 -0
  19. {diffstar-0.3.0 → diffstar-0.3.1}/.git_archival.txt +0 -0
  20. {diffstar-0.3.0 → diffstar-0.3.1}/.gitattributes +0 -0
  21. {diffstar-0.3.0 → diffstar-0.3.1}/.gitignore +0 -0
  22. {diffstar-0.3.0 → diffstar-0.3.1}/.readthedocs.yml +0 -0
  23. {diffstar-0.3.0 → diffstar-0.3.1}/LICENSE.rst +0 -0
  24. {diffstar-0.3.0 → diffstar-0.3.1}/README.md +0 -0
  25. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/__init__.py +0 -0
  26. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/data_loaders/__init__.py +0 -0
  27. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/data_loaders/load_bpl.py +0 -0
  28. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/data_loaders/load_smah_data.py +0 -0
  29. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/data_loaders/tests/__init__.py +0 -0
  30. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/data_loaders/tests/test_load_bpl.py +0 -0
  31. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/defaults.py +0 -0
  32. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/__init__.py +0 -0
  33. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/fit_smah_helpers.py +0 -0
  34. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/fitting_kernels.py +0 -0
  35. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/param_clippers.py +0 -0
  36. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/stars.py +0 -0
  37. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/__init__.py +0 -0
  38. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels.py +0 -0
  39. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_kernels_are_frozen.py +0 -0
  40. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_fitting_smah_helpers.py +0 -0
  41. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_param_clippers.py +0 -0
  42. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/tests/test_stars.py +0 -0
  43. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/fitting_helpers/utils.py +0 -0
  44. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/__init__.py +0 -0
  45. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/kernel_builders.py +0 -0
  46. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/quenching_kernels.py +0 -0
  47. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/__init__.py +0 -0
  48. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/test_frozen_diffstar_kernels.py +0 -0
  49. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/test_kernel_builders.py +0 -0
  50. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/test_quenching_kernels.py +0 -0
  51. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/mah_params_testing_v0.1.0.txt +0 -0
  52. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/ms_params_testing_v0.1.0.txt +0 -0
  53. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/q_params_testing_v0.1.0.txt +0 -0
  54. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/sfh_table_testing_v0.1.0.txt +0 -0
  55. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/t_table_testing_v0.1.0.txt +0 -0
  56. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_ms_params_testing_v0.1.0.txt +0 -0
  57. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/kernels/tests/testing_data/u_q_params_testing_v0.1.0.txt +0 -0
  58. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/sfh.py +0 -0
  59. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/__init__.py +0 -0
  60. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_defaults.py +0 -0
  61. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_gas.py +0 -0
  62. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_lax_main_sequence.py +0 -0
  63. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_lax_sfh.py +0 -0
  64. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_quenching.py +0 -0
  65. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_sfh.py +0 -0
  66. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_sfh_model.py +0 -0
  67. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/test_utils.py +0 -0
  68. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_diffmah_params.txt +0 -0
  69. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dmhdt.txt +0 -0
  70. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_dt.txt +0 -0
  71. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_lgt.txt +0 -0
  72. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_log_mah.txt +0 -0
  73. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_sfh.txt +0 -0
  74. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_ms_params.txt +0 -0
  75. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/default_params_test_u_q_params.txt +0 -0
  76. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/dt_bpl.txt +0 -0
  77. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/halo_ids_test_sample.txt +0 -0
  78. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/lgt_bpl.txt +0 -0
  79. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/mah_params_test_sample.txt +0 -0
  80. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/ms_u_params_test_sample.txt +0 -0
  81. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/q_u_params_test_sample.txt +0 -0
  82. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar/tests/testing_data/sfh_test_sample.txt +0 -0
  83. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar.egg-info/dependency_links.txt +0 -0
  84. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar.egg-info/requires.txt +0 -0
  85. {diffstar-0.3.0 → diffstar-0.3.1}/diffstar.egg-info/top_level.txt +0 -0
  86. {diffstar-0.3.0 → diffstar-0.3.1}/docs/Makefile +0 -0
  87. {diffstar-0.3.0 → diffstar-0.3.1}/docs/make.bat +0 -0
  88. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/_static/README.txt +0 -0
  89. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/citation.rst +0 -0
  90. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/conf.py +0 -0
  91. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/demo_diffstar_fitter.ipynb +0 -0
  92. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/demo_diffstar_sfh.ipynb +0 -0
  93. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/demos.rst +0 -0
  94. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/index.rst +0 -0
  95. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/installation.rst +0 -0
  96. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/reference.rst +0 -0
  97. {diffstar-0.3.0 → diffstar-0.3.1}/docs/source/rtd_environment.yaml +0 -0
  98. {diffstar-0.3.0 → diffstar-0.3.1}/requirements.txt +0 -0
  99. {diffstar-0.3.0 → diffstar-0.3.1}/scripts/generate_unit_testing_data.py +0 -0
  100. {diffstar-0.3.0 → diffstar-0.3.1}/scripts/history_fitting_script.py +0 -0
  101. {diffstar-0.3.0 → diffstar-0.3.1}/setup.cfg +0 -0
  102. {diffstar-0.3.0 → diffstar-0.3.1}/setup.py +0 -0
@@ -16,7 +16,7 @@ jobs:
16
16
 
17
17
  - uses: conda-incubator/setup-miniconda@v2
18
18
  with:
19
- python-version: 3.9
19
+ python-version: 3.11
20
20
  channels: conda-forge,defaults
21
21
  channel-priority: strict
22
22
  show-channel-urls: true
@@ -0,0 +1,55 @@
1
+ name: Test for Warnings
2
+
3
+ on:
4
+ workflow_dispatch: null
5
+ schedule:
6
+ # Runs at 08:15 UTC on the first of every month (3:15am Central Daylight Time, 2:15am Central Standard Time)
7
+ - cron: '15 8 1 * *'
8
+
9
+ jobs:
10
+ tests:
11
+ name: tests
12
+ runs-on: "ubuntu-latest"
13
+
14
+ steps:
15
+ - uses: actions/checkout@v2
16
+ with:
17
+ fetch-depth: 0
18
+
19
+ - uses: conda-incubator/setup-miniconda@v2
20
+ with:
21
+ python-version: 3.11
22
+ channels: conda-forge,defaults
23
+ channel-priority: strict
24
+ show-channel-urls: true
25
+ miniforge-version: latest
26
+ miniforge-variant: Mambaforge
27
+ use-mamba: true
28
+
29
+ - name: configure conda and install code
30
+ # Test against current main branch of diffmah and dsps
31
+ shell: bash -l {0}
32
+ run: |
33
+ conda config --set always_yes yes
34
+ mamba install --quiet \
35
+ --file=requirements.txt
36
+ python -m pip install --no-deps -e .
37
+ mamba install -y -q \
38
+ flake8 \
39
+ pytest \
40
+ pytest-xdist \
41
+ pytest-cov \
42
+ pip \
43
+ setuptools \
44
+ "setuptools_scm>=7,<8" \
45
+ python-build
46
+ pip uninstall diffmah --yes
47
+ pip install --no-deps git+https://github.com/ArgonneCPAC/diffmah.git
48
+ pip install --no-deps git+https://github.com/ArgonneCPAC/dsps.git
49
+ python -m pip install --no-build-isolation --no-deps -e .
50
+
51
+ - name: test that no warnings are raised
52
+ shell: bash -l {0}
53
+ run: |
54
+ export PYTHONWARNINGS=error
55
+ pytest -v diffstar --cov --cov-report=xml
@@ -22,7 +22,7 @@ jobs:
22
22
 
23
23
  - uses: conda-incubator/setup-miniconda@v2
24
24
  with:
25
- python-version: 3.9
25
+ python-version: 3.11
26
26
  channels: conda-forge,defaults
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
@@ -22,7 +22,7 @@ jobs:
22
22
 
23
23
  - uses: conda-incubator/setup-miniconda@v2
24
24
  with:
25
- python-version: 3.9
25
+ python-version: 3.11
26
26
  channels: conda-forge,defaults
27
27
  channel-priority: strict
28
28
  show-channel-urls: true
@@ -1,3 +1,8 @@
1
+ 0.3.1 (2024-06-19)
2
+ ------------------
3
+ - Performance improvements for calc_sfh_galpop and calc_sfh_singlegal, which now use the faster vmap-based main-sequence kernel (``tform_loop="sum"``)
4
+
5
+
1
6
  0.3.0 (2024-01-17)
2
7
  ------------------
3
8
  - Implement new API for primary user-facing functions calc_sfh_galpop and calc_sfh_singlegal
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diffstar
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,7 +35,7 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.8
38
+ Requires-Python: >=3.9
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
41
  Requires-Dist: diffmah>=0.5.0
@@ -0,0 +1 @@
1
+ __version__ = '0.3.1'
@@ -1,5 +1,6 @@
1
1
  """Calculate mass of gas available for star formation from the freshly accreted gas.
2
2
  """
3
+
3
4
  from jax import jit as jjit
4
5
  from jax import lax
5
6
  from jax import numpy as jnp
@@ -28,6 +29,11 @@ def _gas_conversion_kern(t_form, t_acc, dt, tau_dep, tau_dep_max):
28
29
  return tri_kern
29
30
 
30
31
 
32
+ _vmap_gas_conversion_kern = jjit(
33
+ vmap(_gas_conversion_kern, in_axes=(None, 0, None, None, None))
34
+ )
35
+
36
+
31
37
  _a, _b = (0, None, 0, None, None), (None, 0, None, None, None)
32
38
  _depletion_kernel = jjit(vmap(vmap(_gas_conversion_kern, in_axes=_b), in_axes=_a))
33
39
 
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  from jax import jit as jjit
4
5
  from jax import lax
5
6
  from jax import numpy as jnp
@@ -12,7 +13,10 @@ from ..defaults import (
12
13
  SFR_MIN,
13
14
  T_BIRTH_MIN,
14
15
  )
15
- from .main_sequence_kernels import _lax_ms_sfh_scalar_kern
16
+ from .main_sequence_kernels import (
17
+ _lax_ms_sfh_scalar_kern_scan,
18
+ _lax_ms_sfh_scalar_kern_sum,
19
+ )
16
20
  from .quenching_kernels import _quenching_kern
17
21
 
18
22
  __all__ = ("build_sfh_from_mah_kernel",)
@@ -27,6 +31,7 @@ def build_sfh_from_mah_kernel(
27
31
  tacc_integration_min=T_BIRTH_MIN,
28
32
  tobs_loop=None,
29
33
  galpop_loop=None,
34
+ tform_loop="sum",
30
35
  ):
31
36
  """Build a JAX-jitted kernel to calculate SFHs of a galaxy population.
32
37
 
@@ -51,6 +56,10 @@ def build_sfh_from_mah_kernel(
51
56
  For a JAX kernel that assumes galaxy population,
52
57
  options are either 'vmap' or 'scan', specifying the calculation method
53
58
 
59
+ tform_loop : string
60
+ Use 'sum' for the faster vmap-based calculation, or 'scan' for the slower scan-based alternative.
61
+ Default is 'sum'
62
+
54
63
  Returns
55
64
  -------
56
65
  sfh_from_mah_kern : function
@@ -61,6 +70,11 @@ def build_sfh_from_mah_kernel(
61
70
  return sfh
62
71
 
63
72
  """
73
+ if tform_loop == "sum":
74
+ _lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
75
+ elif tform_loop == "scan":
76
+ _lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_scan
77
+
64
78
  uniform_table = jnp.linspace(0, 1, n_steps)
65
79
 
66
80
  @jjit
@@ -1,10 +1,12 @@
1
1
  """
2
2
  """
3
+
3
4
  from collections import OrderedDict, namedtuple
4
5
 
5
6
  import numpy as np
6
7
  from diffmah.defaults import MAH_K
7
8
  from diffmah.individual_halo_assembly import (
9
+ _calc_halo_history,
8
10
  _calc_halo_history_scalar,
9
11
  _rolling_plaw_vs_logt,
10
12
  )
@@ -14,7 +16,7 @@ from jax import numpy as jnp
14
16
  from jax import vmap
15
17
 
16
18
  from ..utils import _inverse_sigmoid, _jax_get_dt_array, _sigmoid
17
- from .gas_consumption import _gas_conversion_kern
19
+ from .gas_consumption import _gas_conversion_kern, _vmap_gas_conversion_kern
18
20
 
19
21
  DEFAULT_MS_PDICT = OrderedDict(
20
22
  lgmcrit=12.0,
@@ -53,7 +55,7 @@ MS_BOUNDING_SIGMOID_PDICT = calculate_sigmoid_bounds(MS_PARAM_BOUNDS_PDICT)
53
55
 
54
56
 
55
57
  @jjit
56
- def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
58
+ def _lax_ms_sfh_scalar_kern_scan(t_form, mah_params, ms_params, lgt0, fb, t_table):
57
59
  logmp, logtc, early, late = mah_params
58
60
  all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
59
61
  lgt_form = jnp.log10(t_form)
@@ -93,6 +95,38 @@ def _lax_ms_sfh_scalar_kern(t_form, mah_params, ms_params, lgt0, fb, t_table):
93
95
  return sfr
94
96
 
95
97
 
98
+ @jjit
99
+ def _lax_ms_sfh_scalar_kern_sum(t_form, mah_params, ms_params, lgt0, fb, t_table):
100
+ logmp, logtc, early, late = mah_params
101
+ all_mah_params = lgt0, logmp, logtc, MAH_K, early, late
102
+ lgt_form = jnp.log10(t_form)
103
+ log_mah_at_tform = _rolling_plaw_vs_logt(lgt_form, *all_mah_params)
104
+
105
+ sfr_eff_params = ms_params[:4]
106
+ sfr_eff = _sfr_eff_plaw(log_mah_at_tform, *sfr_eff_params)
107
+
108
+ tau_dep = ms_params[4]
109
+ tau_dep_max = MS_BOUNDING_SIGMOID_PDICT["tau_dep"][3]
110
+
111
+ # compute inst. gas accretion
112
+ lgtacc = jnp.log10(t_table)
113
+ res = _calc_halo_history(lgtacc, *all_mah_params)
114
+ dmhdt_at_tacc, log_mah_at_tacc = res
115
+ dmgdt_inst = fb * dmhdt_at_tacc
116
+
117
+ # compute the consumption kernel
118
+ dt = t_table[1] - t_table[0]
119
+ kern = _vmap_gas_conversion_kern(t_form, t_table, dt, tau_dep, tau_dep_max)
120
+
121
+ # convolve
122
+ dmgas_dt = jnp.sum(dmgdt_inst * kern * dt)
123
+ sfr = dmgas_dt * sfr_eff
124
+ return sfr
125
+
126
+
127
+ _lax_ms_sfh_scalar_kern = _lax_ms_sfh_scalar_kern_sum
128
+
129
+
96
130
  @jjit
97
131
  def _sfr_eff_plaw(lgm, lgmcrit, lgy_at_mcrit, indx_lo, indx_hi):
98
132
  """Instantaneous baryon conversion efficiency of main sequence galaxies
@@ -1,5 +1,6 @@
1
1
  """
2
2
  """
3
+
3
4
  from collections import namedtuple
4
5
  from functools import partial
5
6
 
@@ -14,6 +15,7 @@ _sfh_singlegal_kern = build_sfh_from_mah_kernel(
14
15
  n_steps=DEFAULT_N_STEPS,
15
16
  tacc_integration_min=T_BIRTH_MIN,
16
17
  tobs_loop="scan",
18
+ tform_loop="sum",
17
19
  )
18
20
 
19
21
  _sfh_galpop_kern = build_sfh_from_mah_kernel(
@@ -21,6 +23,7 @@ _sfh_galpop_kern = build_sfh_from_mah_kernel(
21
23
  tacc_integration_min=T_BIRTH_MIN,
22
24
  tobs_loop="scan",
23
25
  galpop_loop="vmap",
26
+ tform_loop="sum",
24
27
  )
25
28
 
26
29
  _cumulative_mstar_formed_vmap = jjit(vmap(cumulative_mstar_formed, in_axes=(None, 0)))
@@ -0,0 +1,28 @@
1
+ import numpy as np
2
+ import jax.numpy as jnp
3
+
4
+ from diffstar.kernels.main_sequence_kernels import (
5
+ _lax_ms_sfh_scalar_kern_scan,
6
+ _lax_ms_sfh_scalar_kern_sum,
7
+ )
8
+ from diffstar.kernels.main_sequence_kernels import DEFAULT_MS_PARAMS
9
+ from diffmah.defaults import DEFAULT_MAH_PARAMS
10
+ from diffstar.defaults import T_TABLE_MIN, TODAY
11
+ from diffstar.defaults import FB
12
+
13
+
14
+ def test_main_sequence_kernels_lax_ms_sfh_scalar_kern_scan_vs_sum():
15
+ lgt0 = jnp.log10(TODAY)
16
+ t_form = 12.0
17
+ t_table = jnp.linspace(T_TABLE_MIN, t_form, 20)
18
+
19
+ np.testing.assert_allclose(
20
+ _lax_ms_sfh_scalar_kern_scan(
21
+ t_form, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, lgt0, FB, t_table
22
+ ),
23
+ _lax_ms_sfh_scalar_kern_sum(
24
+ t_form, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, lgt0, FB, t_table
25
+ ),
26
+ rtol=1e-6,
27
+ atol=1e-6,
28
+ )
@@ -1,8 +1,9 @@
1
1
  """
2
2
  """
3
+
3
4
  import numpy as np
4
5
  from jax import jit as jjit
5
- from jax import lax
6
+ from jax import lax, nn
6
7
  from jax import numpy as jnp
7
8
  from jax.lax import scan
8
9
 
@@ -91,7 +92,7 @@ def _sigmoid(x, x0, k, ymin, ymax):
91
92
 
92
93
  """
93
94
  height_diff = ymax - ymin
94
- return ymin + height_diff * lax.logistic(k * (x - x0))
95
+ return ymin + height_diff * nn.sigmoid(k * (x - x0))
95
96
 
96
97
 
97
98
  @jjit
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diffstar
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Differentiable Star Formation Histories
5
5
  Author-email: Alex Alarcon <alexalarcongonzalez@gmail.com>
6
6
  License: BSD 3-Clause License
@@ -35,7 +35,7 @@ License: BSD 3-Clause License
35
35
 
36
36
  Project-URL: home, https://github.com/ArgonneCPAC/diffstar
37
37
  Classifier: Programming Language :: Python :: 3
38
- Requires-Python: >=3.8
38
+ Requires-Python: >=3.9
39
39
  Description-Content-Type: text/markdown
40
40
  License-File: LICENSE.rst
41
41
  Requires-Dist: diffmah>=0.5.0
@@ -11,6 +11,7 @@ requirements.txt
11
11
  setup.cfg
12
12
  setup.py
13
13
  .github/workflows/linting.yml
14
+ .github/workflows/monthly-warning-test.yml
14
15
  .github/workflows/test_releases.yml
15
16
  .github/workflows/tests_cron.yml
16
17
  diffstar/__init__.py
@@ -63,6 +64,7 @@ diffstar/tests/test_defaults.py
63
64
  diffstar/tests/test_gas.py
64
65
  diffstar/tests/test_lax_main_sequence.py
65
66
  diffstar/tests/test_lax_sfh.py
67
+ diffstar/tests/test_main_sequence_kernels.py
66
68
  diffstar/tests/test_quenching.py
67
69
  diffstar/tests/test_sfh.py
68
70
  diffstar/tests/test_sfh_model.py
@@ -26,7 +26,7 @@ authors = [
26
26
  ]
27
27
  description = "Differentiable Star Formation Histories"
28
28
  readme = "README.md"
29
- requires-python = ">=3.8"
29
+ requires-python = ">=3.9"
30
30
  license = {file = "LICENSE.rst"}
31
31
  classifiers = [
32
32
  "Programming Language :: Python :: 3",
@@ -1 +0,0 @@
1
- __version__ = '0.3.0'
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes