bdext 0.1.64__py3-none-any.whl → 0.1.66__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- README.md +165 -103
- bdeissct_dl/__init__.py +1 -3
- bdeissct_dl/bdeissct_model.py +11 -65
- bdeissct_dl/dl_model.py +7 -119
- bdeissct_dl/estimator.py +8 -108
- bdeissct_dl/model_serializer.py +3 -33
- bdeissct_dl/scaler_fitting.py +3 -6
- bdeissct_dl/sumstat_checker.py +2 -2
- bdeissct_dl/training.py +9 -30
- bdeissct_dl/tree_encoder.py +13 -32
- bdext-0.1.66.dist-info/METADATA +240 -0
- bdext-0.1.66.dist-info/RECORD +17 -0
- {bdext-0.1.64.dist-info → bdext-0.1.66.dist-info}/entry_points.txt +0 -2
- bdeissct_dl/estimator_ct.py +0 -63
- bdeissct_dl/main_covid.py +0 -76
- bdeissct_dl/model_finder.py +0 -47
- bdeissct_dl/pinball_loss.py +0 -48
- bdeissct_dl/train_ct.py +0 -125
- bdext-0.1.64.dist-info/METADATA +0 -178
- bdext-0.1.64.dist-info/RECORD +0 -22
- {bdext-0.1.64.dist-info → bdext-0.1.66.dist-info}/LICENSE +0 -0
- {bdext-0.1.64.dist-info → bdext-0.1.66.dist-info}/WHEEL +0 -0
- {bdext-0.1.64.dist-info → bdext-0.1.66.dist-info}/top_level.txt +0 -0
README.md
CHANGED
|
@@ -1,149 +1,211 @@
|
|
|
1
|
-
#
|
|
1
|
+
# bdext
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
The bdext package provides scripts to train and assess
|
|
4
|
+
Deep-Learning-enabled estimators of BD(EI)(SS)(CT) model parameters from phylogenetic trees
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
[//]: # ([](https://doi.org/10.1093/sysbio/syad059))
|
|
8
|
-
[![GitHub release]
|
|
9
|
-
[](https://github.com/evolbioinfo/bdext/releases))
|
|
10
|
+
[](https://pypi.org/project/bdext/)
|
|
11
|
+
[](https://pypi.org/project/bdext)
|
|
12
|
+
[](https://hub.docker.com/r/evolbioinfo/bdext/tags)
|
|
12
13
|
|
|
13
14
|
## BDEISS-CT model
|
|
14
15
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
get removed (become non-infectious) with a constant rate ψ,
|
|
18
|
-
and their pathogen can be sampled upon removal
|
|
19
|
-
with a constant probability ρ. On top of that, in the BD-PN model,
|
|
20
|
-
at the moment of sampling the sampled individual
|
|
21
|
-
might notify their most recent partner with a constant probability υ.
|
|
22
|
-
Upon notification, the partner is removed almost instantaneously (modeled via a constant notified
|
|
23
|
-
removal rate φ >> ψ) and their pathogen is sampled.
|
|
16
|
+
The Birth-Death (BD) Exposed-Infectious (EI) with SuperSpreading (SS) and Contact-Tracing (CT) model (BDEISS-CT)
|
|
17
|
+
can be described with the following 8 parameters:
|
|
24
18
|
|
|
25
|
-
|
|
26
|
-
*
|
|
27
|
-
*
|
|
28
|
-
*
|
|
29
|
-
*
|
|
30
|
-
*
|
|
19
|
+
* average reproduction number R;
|
|
20
|
+
* average total infection duration d;
|
|
21
|
+
* incubation period d<sub>inc</sub>;
|
|
22
|
+
* sampling probability ρ;
|
|
23
|
+
* fraction of superspreaders f<sub>S</sub>;
|
|
24
|
+
* super-spreading transmission increase X<sub>S</sub>;
|
|
25
|
+
* contact tracing probability υ;
|
|
26
|
+
* contact-traced removal speed up X<sub>C</sub>.
|
|
31
27
|
|
|
32
|
-
|
|
33
|
-
* R<sub>0</sub>=λ/ψ -- reproduction number
|
|
34
|
-
* 1/ψ -- infectious time
|
|
35
|
-
* 1/φ -- partner removal time
|
|
28
|
+
Setting d<sub>inc</sub>=0 removes incubation (EI), setting f<sub>S</sub>=0 removes superspreading (SS), while setting υ=0 removes contact-tracing (CT).
|
|
36
29
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
2. notified individuals are always observed upon removal;
|
|
40
|
-
3. only the most recent partner can get notified.
|
|
30
|
+
For identifiability, we require the sampling probability ρ to be given by the user.
|
|
31
|
+
The other parameters are estimated from a time-scaled phylogenetic tree.
|
|
41
32
|
|
|
42
|
-
|
|
33
|
+
[//]: # (## BDEISS-CT parameter estimator)
|
|
43
34
|
|
|
44
|
-
|
|
35
|
+
[//]: # ()
|
|
36
|
+
[//]: # (The bdeissct_dl package provides deep-learning-based BDEISS-CT model parameter estimator )
|
|
45
37
|
|
|
46
|
-
|
|
47
|
-
from a user-supplied time-scaled phylogenetic tree.
|
|
48
|
-
User must also provide a value for one of the three BD model parameters (λ, ψ, or ρ).
|
|
49
|
-
We recommend providing the sampling probability ρ,
|
|
50
|
-
which could be estimated as the number of tree tips divided by the number of declared cases for the same time period.
|
|
38
|
+
[//]: # (from a user-supplied time-scaled phylogenetic tree. )
|
|
51
39
|
|
|
40
|
+
[//]: # (User must also provide a value for one of the three BD model parameters (λ, ψ, or ρ). )
|
|
52
41
|
|
|
53
|
-
|
|
54
|
-
One needs to supply a time-scaled phylogenetic tree in newick format.
|
|
55
|
-
In the examples below we will use an HIV tree reconstructed from 200 sequences,
|
|
56
|
-
published in [[Rasmussen _et al._ PLoS Comput. Biol. 2017]](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1005448),
|
|
57
|
-
which you can find at [PairTree GitHub](https://github.com/davidrasm/PairTree)
|
|
58
|
-
and in [hiv_zurich/Zurich.nwk](hiv_zurich/Zurich.nwk).
|
|
42
|
+
[//]: # (We recommend providing the sampling probability ρ, )
|
|
59
43
|
|
|
60
|
-
|
|
44
|
+
[//]: # (which could be estimated as the number of tree tips divided by the number of declared cases for the same time period.)
|
|
61
45
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
[
|
|
65
|
-
in Python3, or via command line (requires installation with Python3).
|
|
46
|
+
[//]: # ()
|
|
47
|
+
[//]: # ()
|
|
48
|
+
[//]: # (## Input data)
|
|
66
49
|
|
|
50
|
+
[//]: # (One needs to supply a time-scaled phylogenetic tree in newick format. )
|
|
67
51
|
|
|
52
|
+
[//]: # (In the examples below we will use an HIV tree reconstructed from 200 sequences, )
|
|
68
53
|
|
|
69
|
-
|
|
54
|
+
[//]: # (published in [[Rasmussen _et al._ PLoS Comput. Biol. 2017]](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1005448), )
|
|
70
55
|
|
|
71
|
-
|
|
72
|
-
```bash
|
|
73
|
-
sudo apt install -y python3 python3-pip python3-setuptools python3-distutils
|
|
74
|
-
pip3 install bdeissct_dl
|
|
75
|
-
```
|
|
56
|
+
[//]: # (which you can find at [PairTree GitHub](https://github.com/davidrasm/PairTree) )
|
|
76
57
|
|
|
77
|
-
|
|
78
|
-
Here we will create a conda environment called _phyloenv_:
|
|
79
|
-
```bash
|
|
80
|
-
conda create --name phyloenv python=3.12
|
|
81
|
-
conda activate phyloenv
|
|
82
|
-
pip install bdeissct_dl
|
|
83
|
-
```
|
|
58
|
+
[//]: # (and in [hiv_zurich/Zurich.nwk](hiv_zurich/Zurich.nwk). )
|
|
84
59
|
|
|
60
|
+
[//]: # ()
|
|
61
|
+
[//]: # (## Installation)
|
|
85
62
|
|
|
86
|
-
|
|
87
|
-
|
|
63
|
+
[//]: # ()
|
|
64
|
+
[//]: # (There are 4 alternative ways to run __bdeissct_dl__ on your computer: )
|
|
88
65
|
|
|
89
|
-
|
|
90
|
-
conda activate phyloenv
|
|
91
|
-
```
|
|
66
|
+
[//]: # (with [docker](https://www.docker.com/community-edition), )
|
|
92
67
|
|
|
93
|
-
|
|
94
|
-
and save the estimated parameters to a comma-separated file estimates.csv.
|
|
95
|
-
```bash
|
|
96
|
-
bdeissct_infer --nwk Zurich.nwk --ci --p 0.25 --log estimates.csv
|
|
97
|
-
```
|
|
68
|
+
[//]: # ([apptainer](https://apptainer.org/),)
|
|
98
69
|
|
|
99
|
-
|
|
70
|
+
[//]: # (in Python3, or via command line (requires installation with Python3).)
|
|
100
71
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
72
|
+
[//]: # ()
|
|
73
|
+
[//]: # ()
|
|
74
|
+
[//]: # ()
|
|
75
|
+
[//]: # (### Run in python3 or command-line (for linux systems, recommended Ubuntu 21 or newer versions))
|
|
105
76
|
|
|
77
|
+
[//]: # ()
|
|
78
|
+
[//]: # (You could either install python (version 3.9 or higher) system-wide and then install bdeissct_dl via pip:)
|
|
106
79
|
|
|
107
|
-
|
|
80
|
+
[//]: # (```bash)
|
|
108
81
|
|
|
109
|
-
|
|
110
|
-
Once [docker](https://www.docker.com/community-edition) is installed,
|
|
111
|
-
run the following command to estimate BDEISS-CT model parameters:
|
|
112
|
-
```bash
|
|
113
|
-
docker run -v <path_to_the_folder_containing_the_tree>:/data:rw -t evolbioinfo/bdeissct --nwk /data/Zurich.nwk --ci --p 0.25 --log /data/estimates.csv
|
|
114
|
-
```
|
|
82
|
+
[//]: # (sudo apt install -y python3 python3-pip python3-setuptools python3-distutils)
|
|
115
83
|
|
|
116
|
-
|
|
117
|
-
containing the estimated parameter values and their 95% CIs (can be viewed with a text editor, Excel or Libre Office Calc).
|
|
84
|
+
[//]: # (pip3 install bdeissct_dl)
|
|
118
85
|
|
|
119
|
-
|
|
86
|
+
[//]: # (```)
|
|
120
87
|
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
docker run -t evolbioinfo/bdeissct -h
|
|
124
|
-
```
|
|
88
|
+
[//]: # ()
|
|
89
|
+
[//]: # (or alternatively, you could install python (version 3.9 or higher) and bdeissct_dl via [conda](https://conda.io/docs/) (make sure that conda is installed first). )
|
|
125
90
|
|
|
91
|
+
[//]: # (Here we will create a conda environment called _phyloenv_:)
|
|
126
92
|
|
|
93
|
+
[//]: # (```bash)
|
|
127
94
|
|
|
128
|
-
|
|
95
|
+
[//]: # (conda create --name phyloenv python=3.12)
|
|
129
96
|
|
|
130
|
-
|
|
131
|
-
Once [apptainer](https://apptainer.org/docs/user/latest/quick_start.html#installation) is installed,
|
|
132
|
-
run the following command to estimate BDEISS-CT model parameters (from the folder where the Zurich.nwk tree is contained):
|
|
97
|
+
[//]: # (conda activate phyloenv)
|
|
133
98
|
|
|
134
|
-
|
|
135
|
-
apptainer run docker://evolbioinfo/bdeissct --nwk Zurich.nwk --ci --p 0.25 --log estimates.csv
|
|
136
|
-
```
|
|
99
|
+
[//]: # (pip install bdeissct_dl)
|
|
137
100
|
|
|
138
|
-
|
|
139
|
-
containing the estimated parameter values and their 95% CIs (can be viewed with a text editor, Excel or Libre Office Calc).
|
|
101
|
+
[//]: # (```)
|
|
140
102
|
|
|
103
|
+
[//]: # ()
|
|
104
|
+
[//]: # ()
|
|
105
|
+
[//]: # (#### Basic usage in a command line)
|
|
141
106
|
|
|
142
|
-
|
|
107
|
+
[//]: # (If you installed __bdeissct_dl__ in a conda environment (here named _phyloenv_), do not forget to first activate it, e.g.)
|
|
143
108
|
|
|
144
|
-
|
|
145
|
-
```bash
|
|
146
|
-
apptainer run docker://evolbioinfo/bdeissct -h
|
|
147
|
-
```
|
|
109
|
+
[//]: # ()
|
|
110
|
+
[//]: # (```bash)
|
|
148
111
|
|
|
112
|
+
[//]: # (conda activate phyloenv)
|
|
149
113
|
|
|
114
|
+
[//]: # (```)
|
|
115
|
+
|
|
116
|
+
[//]: # ()
|
|
117
|
+
[//]: # (Run the following command to estimate the BDEISS-CT parameters and their 95% CIs for this tree, assuming a sampling probability of 0.25, )
|
|
118
|
+
|
|
119
|
+
[//]: # (and save the estimated parameters to a comma-separated file estimates.csv.)
|
|
120
|
+
|
|
121
|
+
[//]: # (```bash)
|
|
122
|
+
|
|
123
|
+
[//]: # (bdeissct_infer --nwk Zurich.nwk --ci --p 0.25 --log estimates.csv)
|
|
124
|
+
|
|
125
|
+
[//]: # (```)
|
|
126
|
+
|
|
127
|
+
[//]: # ()
|
|
128
|
+
[//]: # (#### Help)
|
|
129
|
+
|
|
130
|
+
[//]: # ()
|
|
131
|
+
[//]: # (To see detailed options, run:)
|
|
132
|
+
|
|
133
|
+
[//]: # (```bash)
|
|
134
|
+
|
|
135
|
+
[//]: # (bdeissct_infer --help)
|
|
136
|
+
|
|
137
|
+
[//]: # (```)
|
|
138
|
+
|
|
139
|
+
[//]: # ()
|
|
140
|
+
[//]: # ()
|
|
141
|
+
[//]: # (### Run with docker)
|
|
142
|
+
|
|
143
|
+
[//]: # ()
|
|
144
|
+
[//]: # (#### Basic usage)
|
|
145
|
+
|
|
146
|
+
[//]: # (Once [docker](https://www.docker.com/community-edition) is installed, )
|
|
147
|
+
|
|
148
|
+
[//]: # (run the following command to estimate BDEISS-CT model parameters:)
|
|
149
|
+
|
|
150
|
+
[//]: # (```bash)
|
|
151
|
+
|
|
152
|
+
[//]: # (docker run -v <path_to_the_folder_containing_the_tree>:/data:rw -t evolbioinfo/bdeissct --nwk /data/Zurich.nwk --ci --p 0.25 --log /data/estimates.csv)
|
|
153
|
+
|
|
154
|
+
[//]: # (```)
|
|
155
|
+
|
|
156
|
+
[//]: # ()
|
|
157
|
+
[//]: # (This will produce a comma-separated file estimates.csv in the <path_to_the_folder_containing_the_tree> folder,)
|
|
158
|
+
|
|
159
|
+
[//]: # ( containing the estimated parameter values and their 95% CIs (can be viewed with a text editor, Excel or Libre Office Calc).)
|
|
160
|
+
|
|
161
|
+
[//]: # ()
|
|
162
|
+
[//]: # (#### Help)
|
|
163
|
+
|
|
164
|
+
[//]: # ()
|
|
165
|
+
[//]: # (To see advanced options, run)
|
|
166
|
+
|
|
167
|
+
[//]: # (```bash)
|
|
168
|
+
|
|
169
|
+
[//]: # (docker run -t evolbioinfo/bdeissct -h)
|
|
170
|
+
|
|
171
|
+
[//]: # (```)
|
|
172
|
+
|
|
173
|
+
[//]: # ()
|
|
174
|
+
[//]: # ()
|
|
175
|
+
[//]: # ()
|
|
176
|
+
[//]: # (### Run with apptainer)
|
|
177
|
+
|
|
178
|
+
[//]: # ()
|
|
179
|
+
[//]: # (#### Basic usage)
|
|
180
|
+
|
|
181
|
+
[//]: # (Once [apptainer](https://apptainer.org/docs/user/latest/quick_start.html#installation) is installed, )
|
|
182
|
+
|
|
183
|
+
[//]: # (run the following command to estimate BDEISS-CT model parameters (from the folder where the Zurich.nwk tree is contained):)
|
|
184
|
+
|
|
185
|
+
[//]: # ()
|
|
186
|
+
[//]: # (```bash)
|
|
187
|
+
|
|
188
|
+
[//]: # (apptainer run docker://evolbioinfo/bdeissct --nwk Zurich.nwk --ci --p 0.25 --log estimates.csv)
|
|
189
|
+
|
|
190
|
+
[//]: # (```)
|
|
191
|
+
|
|
192
|
+
[//]: # ()
|
|
193
|
+
[//]: # (This will produce a comma-separated file estimates.csv,)
|
|
194
|
+
|
|
195
|
+
[//]: # ( containing the estimated parameter values and their 95% CIs (can be viewed with a text editor, Excel or Libre Office Calc).)
|
|
196
|
+
|
|
197
|
+
[//]: # ()
|
|
198
|
+
[//]: # ()
|
|
199
|
+
[//]: # (#### Help)
|
|
200
|
+
|
|
201
|
+
[//]: # ()
|
|
202
|
+
[//]: # (To see advanced options, run)
|
|
203
|
+
|
|
204
|
+
[//]: # (```bash)
|
|
205
|
+
|
|
206
|
+
[//]: # (apptainer run docker://evolbioinfo/bdeissct -h)
|
|
207
|
+
|
|
208
|
+
[//]: # (```)
|
|
209
|
+
|
|
210
|
+
[//]: # ()
|
|
211
|
+
[//]: # ()
|
bdeissct_dl/__init__.py
CHANGED
|
@@ -7,12 +7,10 @@ warnings.filterwarnings('ignore', r'divide by zero encountered in log')
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models')
|
|
10
|
-
TRAINING_PATH = os.path.join(os.path.dirname(__file__), 'data')
|
|
11
|
-
|
|
12
10
|
|
|
13
11
|
|
|
14
12
|
EPOCHS = 1000
|
|
15
|
-
BATCH_SIZE =
|
|
13
|
+
BATCH_SIZE = 8192
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
|
bdeissct_dl/bdeissct_model.py
CHANGED
|
@@ -1,42 +1,26 @@
|
|
|
1
1
|
from collections import defaultdict
|
|
2
2
|
|
|
3
3
|
LA = 'la'
|
|
4
|
-
LA_AVG = 'avg_la'
|
|
5
4
|
PSI = 'psi'
|
|
6
5
|
RHO = 'rho'
|
|
7
6
|
INFECTIOUS_TIME = 'd_I'
|
|
8
7
|
REPRODUCTIVE_NUMBER = 'R'
|
|
9
8
|
INFECTION_DURATION = 'd'
|
|
10
9
|
|
|
11
|
-
F_E = 'f_E'
|
|
12
10
|
MU = 'mu'
|
|
13
11
|
INCUBATION_PERIOD = 'd_E'
|
|
14
12
|
|
|
15
13
|
F_S = 'f_S'
|
|
16
14
|
X_S = 'X_S'
|
|
17
15
|
|
|
18
|
-
F_S_X_S = 'f_S_X_S'
|
|
19
|
-
|
|
20
|
-
PI_I = 'pi_I'
|
|
21
|
-
PI_IC = 'pi_IC'
|
|
22
|
-
PI_S = 'pi_S'
|
|
23
|
-
PI_SC = 'pi_SC'
|
|
24
|
-
PI_E = 'pi_E'
|
|
25
|
-
PI_EC = 'pi_EC'
|
|
26
|
-
|
|
27
|
-
PIS = 'pi'
|
|
28
|
-
|
|
29
16
|
|
|
30
17
|
X_C = 'X_C'
|
|
31
18
|
UPSILON = 'upsilon'
|
|
32
19
|
|
|
33
|
-
UPS_X_C = 'upsilon_X_C'
|
|
34
|
-
|
|
35
20
|
KAPPA = 'kappa'
|
|
36
|
-
PHI = 'phi'
|
|
37
21
|
REMOVAL_TIME_AFTER_NOTIFICATION = 'd_C'
|
|
38
22
|
|
|
39
|
-
RATE_PARAMETERS = (LA, PSI,
|
|
23
|
+
RATE_PARAMETERS = (LA, PSI, MU)
|
|
40
24
|
TIME_PARAMETERS = (INCUBATION_PERIOD, INFECTIOUS_TIME, REMOVAL_TIME_AFTER_NOTIFICATION, INFECTION_DURATION)
|
|
41
25
|
|
|
42
26
|
|
|
@@ -49,27 +33,15 @@ DEFAULT_MAX_RATE = 1e3
|
|
|
49
33
|
|
|
50
34
|
BD = 'BD'
|
|
51
35
|
BDCT = 'BDCT'
|
|
52
|
-
BDCT1 = 'BDCT1'
|
|
53
|
-
BDCT2 = 'BDCT2'
|
|
54
|
-
BDCT2000 = 'BDCT2000'
|
|
55
36
|
|
|
56
37
|
BDEI = 'BDEI'
|
|
57
38
|
BDEICT = 'BDEICT'
|
|
58
|
-
BDEICT1 = 'BDEICT1'
|
|
59
|
-
BDEICT2 = 'BDEICT2'
|
|
60
|
-
BDEICT2000 = 'BDEICT2000'
|
|
61
39
|
|
|
62
40
|
BDSS = 'BDSS'
|
|
63
41
|
BDSSCT = 'BDSSCT'
|
|
64
|
-
BDSSCT1 = 'BDSSCT1'
|
|
65
|
-
BDSSCT2 = 'BDSSCT2'
|
|
66
|
-
BDSSCT2000 = 'BDSSCT2000'
|
|
67
42
|
|
|
68
43
|
BDEISS = 'BDEISS'
|
|
69
44
|
BDEISSCT = 'BDEISSCT'
|
|
70
|
-
BDEISSCT1 = 'BDEISSCT1'
|
|
71
|
-
BDEISSCT2 = 'BDEISSCT2'
|
|
72
|
-
BDEISSCT2000 = 'BDEISSCT2000'
|
|
73
45
|
|
|
74
46
|
|
|
75
47
|
|
|
@@ -80,16 +52,9 @@ MODELS = (BD, BDCT, \
|
|
|
80
52
|
BDSS, BDSSCT, \
|
|
81
53
|
BDEISS, BDEISSCT)
|
|
82
54
|
|
|
83
|
-
DATA_TYPES = ['tree']
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
WITH_PIS = False
|
|
87
55
|
|
|
88
|
-
# TARGET_COLUMNS_BDCT = (REPRODUCTIVE_NUMBER, INFECTIOUS_TIME, REMOVAL_TIME_AFTER_NOTIFICATION, UPSILON)
|
|
89
|
-
# TARGET_COLUMNS_BD = (REPRODUCTIVE_NUMBER, INFECTIOUS_TIME)
|
|
90
|
-
# TARGET_CT_COLUMNS = (UPSILON, X_C, KAPPA)
|
|
91
56
|
TARGET_CT_COLUMNS = (UPSILON, X_C)
|
|
92
|
-
TARGET_INCUBATION_COLUMNS = (
|
|
57
|
+
TARGET_INCUBATION_COLUMNS = (INCUBATION_PERIOD,)
|
|
93
58
|
TARGET_SS_COLUMNS = (F_S, X_S)
|
|
94
59
|
TARGET_COLUMNS_BD = (REPRODUCTIVE_NUMBER, INFECTION_DURATION)
|
|
95
60
|
TARGET_COLUMNS_BDCT = TARGET_COLUMNS_BD + TARGET_CT_COLUMNS
|
|
@@ -100,33 +65,14 @@ TARGET_COLUMNS_BDSSCT = TARGET_COLUMNS_BDSS + TARGET_CT_COLUMNS
|
|
|
100
65
|
TARGET_COLUMNS_BDEISS = TARGET_COLUMNS_BDEI + TARGET_SS_COLUMNS
|
|
101
66
|
TARGET_COLUMNS_BDEISSCT = TARGET_COLUMNS_BDEISS + TARGET_CT_COLUMNS
|
|
102
67
|
|
|
103
|
-
if WITH_PIS:
|
|
104
|
-
TARGET_COLUMNS_BD = TARGET_COLUMNS_BD
|
|
105
|
-
TARGET_COLUMNS_BDCT = TARGET_COLUMNS_BDCT + (PI_I, PI_IC, )
|
|
106
|
-
TARGET_COLUMNS_BDEI = TARGET_COLUMNS_BDEI + (PI_E, PI_I, )
|
|
107
|
-
TARGET_COLUMNS_BDSS = TARGET_COLUMNS_BDSS + (PI_I, PI_S, )
|
|
108
|
-
TARGET_COLUMNS_BDEISS = TARGET_COLUMNS_BDEISS + (PI_E, PI_I, PI_S)
|
|
109
|
-
TARGET_COLUMNS_BDEICT = TARGET_COLUMNS_BDEICT + (PI_E, PI_I, PI_EC, PI_IC)
|
|
110
|
-
TARGET_COLUMNS_BDSSCT = TARGET_COLUMNS_BDSSCT + (PI_I, PI_S, PI_IC, PI_SC)
|
|
111
|
-
TARGET_COLUMNS_BDEISSCT = TARGET_COLUMNS_BDEISSCT + (PI_E, PI_I, PI_S, PI_EC, PI_IC, PI_SC)
|
|
112
|
-
|
|
113
68
|
|
|
114
69
|
MODEL2TARGET_COLUMNS = defaultdict(lambda: TARGET_COLUMNS_BDEISSCT)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
for model in (BDSSCT, BDSSCT1, BDSSCT2, BDSSCT2000):
|
|
125
|
-
MODEL2TARGET_COLUMNS[model] = TARGET_COLUMNS_BDSSCT
|
|
126
|
-
for model in (BDEISSCT, BDEISSCT1, BDEISSCT2, BDEISSCT2000):
|
|
127
|
-
MODEL2TARGET_COLUMNS[model] = TARGET_COLUMNS_BDEISSCT
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
CT_EPI_COLUMNS = [REPRODUCTIVE_NUMBER, LA, F_E, F_S, X_S, UPSILON, X_C]
|
|
132
|
-
CT_RATE_COLUMNS = [PSI]
|
|
70
|
+
MODEL2TARGET_COLUMNS.update({BD: TARGET_COLUMNS_BD,
|
|
71
|
+
BDEI: TARGET_COLUMNS_BDEI,
|
|
72
|
+
BDSS: TARGET_COLUMNS_BDSS,
|
|
73
|
+
BDEISS: TARGET_COLUMNS_BDEISS,
|
|
74
|
+
BDCT: TARGET_COLUMNS_BDCT,
|
|
75
|
+
BDEICT: TARGET_COLUMNS_BDEICT,
|
|
76
|
+
BDSSCT: TARGET_COLUMNS_BDSSCT,
|
|
77
|
+
BDEISSCT: TARGET_COLUMNS_BDEISSCT
|
|
78
|
+
})
|
bdeissct_dl/dl_model.py
CHANGED
|
@@ -1,124 +1,21 @@
|
|
|
1
1
|
import tensorflow as tf
|
|
2
2
|
from tensorflow.python.keras.utils.generic_utils import register_keras_serializable
|
|
3
3
|
|
|
4
|
-
from bdeissct_dl.bdeissct_model import
|
|
5
|
-
INFECTION_DURATION, X_S, X_C,
|
|
4
|
+
from bdeissct_dl.bdeissct_model import F_S, UPSILON, REPRODUCTIVE_NUMBER, \
|
|
5
|
+
INFECTION_DURATION, X_S, X_C, RHO, INCUBATION_PERIOD
|
|
6
6
|
|
|
7
7
|
LEARNING_RATE = 0.001
|
|
8
8
|
|
|
9
|
-
DELTA = 0.001
|
|
10
9
|
LOSS_WEIGHTS = {
|
|
11
|
-
LA: 1,
|
|
12
|
-
PSI: 1,
|
|
13
10
|
REPRODUCTIVE_NUMBER: 1,
|
|
14
11
|
INFECTION_DURATION: 1,
|
|
15
|
-
|
|
12
|
+
INCUBATION_PERIOD: 1,
|
|
16
13
|
F_S: 200, # as it is a value between 0 and 0.5, we multiply by 200 to scale it to [0, 100]
|
|
17
|
-
F_E: 100,
|
|
18
14
|
UPSILON: 100,
|
|
19
|
-
F_S_X_S: 200, # as there are 2 outputs, we multiply by 200 to scale it to [0, 200]
|
|
20
15
|
X_C: 1,
|
|
21
|
-
X_S: 1
|
|
22
|
-
RHO: 1,
|
|
16
|
+
X_S: 1
|
|
23
17
|
}
|
|
24
18
|
|
|
25
|
-
QUANTILES = (0.5, )
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@register_keras_serializable(package='bdeissct_dl', name='SSLayer')
|
|
29
|
-
class SSLayer(tf.keras.layers.Layer):
|
|
30
|
-
def call(self, inputs):
|
|
31
|
-
# inputs shape: (batch, 2)
|
|
32
|
-
f_S = half_sigmoid(inputs[:, 0:1]) # keepdims -> (batch, 1)
|
|
33
|
-
X_S = relu_plus_one(inputs[:, 1:2]) # (batch, 1)
|
|
34
|
-
return tf.concat([f_S, X_S], axis=-1) # (batch, 2)
|
|
35
|
-
|
|
36
|
-
def compute_output_shape(self, input_shape):
|
|
37
|
-
# input_shape is (batch, 2) -> output_shape is (batch, 2)
|
|
38
|
-
return input_shape[:-1] + (2,)
|
|
39
|
-
|
|
40
|
-
def get_config(self):
|
|
41
|
-
# If there are no special args, only return super() config
|
|
42
|
-
return super().get_config()
|
|
43
|
-
|
|
44
|
-
@classmethod
|
|
45
|
-
def from_config(cls, config):
|
|
46
|
-
return cls(**config)
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@register_keras_serializable(package='bdeissct_dl', name='CTLayer')
|
|
50
|
-
class CTLayer(tf.keras.layers.Layer):
|
|
51
|
-
def call(self, inputs):
|
|
52
|
-
# inputs shape: (batch, 2)
|
|
53
|
-
ups = tf.sigmoid(inputs[:, 0:1]) # keepdims -> (batch, 1)
|
|
54
|
-
X_C = relu_plus_one(inputs[:, 1:2]) # (batch, 1)
|
|
55
|
-
return tf.concat([ups, X_C], axis=-1) # (batch, 2)
|
|
56
|
-
|
|
57
|
-
def compute_output_shape(self, input_shape):
|
|
58
|
-
# input_shape is (batch, 2) -> output_shape is (batch, 2)
|
|
59
|
-
return input_shape[:-1] + (2,)
|
|
60
|
-
|
|
61
|
-
def get_config(self):
|
|
62
|
-
# If there are no special args, only return super() config
|
|
63
|
-
return super().get_config()
|
|
64
|
-
|
|
65
|
-
@classmethod
|
|
66
|
-
def from_config(cls, config):
|
|
67
|
-
return cls(**config)
|
|
68
|
-
|
|
69
|
-
@tf.keras.utils.register_keras_serializable(package='bdeissct_dl', name='loss_ct')
|
|
70
|
-
def loss_ct(y_true, y_pred):
|
|
71
|
-
|
|
72
|
-
# Unpack the true values
|
|
73
|
-
p_true = y_true[:, 0]
|
|
74
|
-
X_true = y_true[:, 1]
|
|
75
|
-
|
|
76
|
-
# Unpack the predicted values
|
|
77
|
-
p_pred = y_pred[:, 0]
|
|
78
|
-
X_pred = y_pred[:, 1]
|
|
79
|
-
|
|
80
|
-
# Relative error for X_C
|
|
81
|
-
X_loss = tf.abs((X_pred - X_true) / X_true)
|
|
82
|
-
|
|
83
|
-
# Absolute error for ups
|
|
84
|
-
p_loss = tf.abs(p_pred - p_true)
|
|
85
|
-
# p_loss = tf.abs((p_pred - p_true) / tf.maximum(p_true, 1e-2))
|
|
86
|
-
|
|
87
|
-
mask = tf.cast(tf.greater(p_true, 1e-6), tf.float32)
|
|
88
|
-
X_loss = tf.reduce_mean(mask * X_loss)
|
|
89
|
-
|
|
90
|
-
# Combine the losses
|
|
91
|
-
return tf.reduce_mean(X_loss + p_loss)
|
|
92
|
-
|
|
93
|
-
@tf.keras.utils.register_keras_serializable(package='bdeissct_dl', name='loss_ss')
|
|
94
|
-
def loss_ss(y_true, y_pred):
|
|
95
|
-
|
|
96
|
-
# Unpack the true values
|
|
97
|
-
p_true = y_true[:, 0]
|
|
98
|
-
X_true = y_true[:, 1]
|
|
99
|
-
|
|
100
|
-
# Unpack the predicted values
|
|
101
|
-
p_pred = y_pred[:, 0]
|
|
102
|
-
X_pred = y_pred[:, 1]
|
|
103
|
-
|
|
104
|
-
# Relative error for X_S
|
|
105
|
-
X_loss = tf.abs((X_pred - X_true) / X_true)
|
|
106
|
-
|
|
107
|
-
# Absolute error for f_S, multiplied by 2, as f_S is in [0, 0.5]
|
|
108
|
-
p_loss = 2 * tf.abs(p_pred - p_true)
|
|
109
|
-
# p_loss = tf.abs((p_pred - p_true) / tf.maximum(p_true, 1e-2 / 2))
|
|
110
|
-
|
|
111
|
-
mask = tf.cast(tf.greater(p_true, 1e-6), tf.float32)
|
|
112
|
-
X_loss = tf.reduce_mean(mask * X_loss)
|
|
113
|
-
|
|
114
|
-
# Combine the losses
|
|
115
|
-
return tf.reduce_mean(X_loss + p_loss)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
@tf.keras.utils.register_keras_serializable(package='bdeissct_dl', name='loss_prob')
|
|
119
|
-
def loss_prob(y_true, y_pred):
|
|
120
|
-
return tf.reduce_mean(tf.abs((y_pred - y_true) / tf.maximum(y_true, 1e-2)))
|
|
121
|
-
|
|
122
19
|
@register_keras_serializable(package="bdeissct_dl", name="half_sigmoid")
|
|
123
20
|
def half_sigmoid(x):
|
|
124
21
|
return 0.5 * tf.sigmoid(x) # range ~ [0, 0.5)
|
|
@@ -130,16 +27,12 @@ def relu_plus_one(x):
|
|
|
130
27
|
|
|
131
28
|
|
|
132
29
|
LOSS_FUNCTIONS = {
|
|
133
|
-
LA: "mean_absolute_percentage_error",
|
|
134
|
-
PSI: "mean_absolute_percentage_error",
|
|
135
30
|
REPRODUCTIVE_NUMBER: "mean_absolute_percentage_error",
|
|
136
31
|
INFECTION_DURATION: "mean_absolute_percentage_error",
|
|
137
|
-
|
|
32
|
+
INCUBATION_PERIOD: "mean_absolute_percentage_error",
|
|
138
33
|
UPSILON: 'mae',
|
|
139
34
|
RHO: 'mean_absolute_percentage_error',
|
|
140
35
|
X_C: "mean_absolute_percentage_error",
|
|
141
|
-
F_E: 'mae',
|
|
142
|
-
F_S_X_S: loss_ss,
|
|
143
36
|
F_S: 'mae',
|
|
144
37
|
X_S: "mean_absolute_percentage_error",
|
|
145
38
|
}
|
|
@@ -165,7 +58,6 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
|
|
|
165
58
|
x = tf.keras.layers.Dense(64, activation='elu', name=f'layer2_dense128_elu')(x)
|
|
166
59
|
x = tf.keras.layers.Dropout(0.5, name='dropout2_50')(x)
|
|
167
60
|
x = tf.keras.layers.Dense(32, activation='elu', name=f'layer3_dense64elu')(x)
|
|
168
|
-
# x = tf.keras.layers.Dropout(0.5, name='dropout3_50')(x)
|
|
169
61
|
x = tf.keras.layers.Dense(16, activation='elu', name=f'layer4_dense32_elu')(x)
|
|
170
62
|
|
|
171
63
|
outputs = {}
|
|
@@ -174,16 +66,12 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
|
|
|
174
66
|
outputs[REPRODUCTIVE_NUMBER] = tf.keras.layers.Dense(1, activation="softplus", name=REPRODUCTIVE_NUMBER)(x) # positive values only
|
|
175
67
|
if INFECTION_DURATION in target_columns:
|
|
176
68
|
outputs[INFECTION_DURATION] = tf.keras.layers.Dense(1, activation="softplus", name=INFECTION_DURATION)(x) # positive values only
|
|
177
|
-
if
|
|
178
|
-
outputs[
|
|
179
|
-
# if F_S in target_columns:
|
|
180
|
-
# outputs[F_S_X_S] = SSLayer(name=F_S_X_S)(tf.keras.layers.Dense(2, activation=None, name="FS_XS_logits")(x))
|
|
69
|
+
if INCUBATION_PERIOD in target_columns:
|
|
70
|
+
outputs[INCUBATION_PERIOD] = tf.keras.layers.Dense(1, activation="softplus", name=INCUBATION_PERIOD)(x) # positive values only
|
|
181
71
|
if F_S in target_columns:
|
|
182
72
|
outputs[F_S] = tf.keras.layers.Dense(1, activation=half_sigmoid, name="FS_logits")(x)
|
|
183
73
|
if X_S in target_columns:
|
|
184
74
|
outputs[X_S] = tf.keras.layers.Dense(1, activation=relu_plus_one, name="XS_logits")(x)
|
|
185
|
-
# if UPSILON in target_columns:
|
|
186
|
-
# outputs[UPS_X_C] = CTLayer(name=UPS_X_C)(tf.keras.layers.Dense(2, activation=None, name="ups_XC_logits")(x))
|
|
187
75
|
if UPSILON in target_columns:
|
|
188
76
|
outputs[UPSILON] = tf.keras.layers.Dense(1, activation="sigmoid", name="ups_logits")(x)
|
|
189
77
|
if X_C in target_columns:
|