bdext 0.1.71__py3-none-any.whl → 0.1.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bdeissct_dl/dl_model.py CHANGED
@@ -6,7 +6,7 @@ from tensorflow.python.keras.utils.generic_utils import register_keras_serializa
6
6
  from bdeissct_dl.bdeissct_model import F_S, UPSILON, REPRODUCTIVE_NUMBER, \
7
7
  INFECTION_DURATION, X_S, X_C, INCUBATION_FRACTION
8
8
 
9
- LEARNING_RATE = 0.001
9
+ LEARNING_RATE = 0.01
10
10
 
11
11
  @register_keras_serializable(package="bdeissct_dl", name="half_sigmoid")
12
12
  def half_sigmoid(x):
@@ -36,12 +36,14 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
36
36
  inputs = tf.keras.Input(shape=(n_x,))
37
37
 
38
38
  # (Your hidden layers go here)
39
- x = tf.keras.layers.Dense(128, activation='elu', name=f'layer1_dense256_elu')(inputs)
40
- x = tf.keras.layers.Dropout(0.5, name='dropout1_50')(x)
41
- x = tf.keras.layers.Dense(64, activation='elu', name=f'layer2_dense128_elu')(x)
42
- x = tf.keras.layers.Dropout(0.5, name='dropout2_50')(x)
43
- x = tf.keras.layers.Dense(32, activation='elu', name=f'layer3_dense64elu')(x)
44
- x = tf.keras.layers.Dense(16, activation='elu', name=f'layer4_dense32_elu')(x)
39
+ x = tf.keras.layers.Dense(128, activation='elu', name=f'layer1_dense128_elu')(inputs)
40
+ # x = tf.keras.layers.Dropout(0.5, name='dropout1_50')(x)
41
+ x = tf.keras.layers.Dense(64, activation='elu', name=f'layer2_dense64_elu')(x)
42
+ # x = tf.keras.layers.Dropout(0.5, name='dropout2_50')(x)
43
+ x = tf.keras.layers.Dense(32, activation='elu', name=f'layer3_dense32elu')(x)
44
+ x = tf.keras.layers.Dense(16, activation='elu', name=f'layer4_dense16_elu')(x)
45
+ x = tf.keras.layers.Dense(8, activation='elu', name=f'layer5_dense8_elu')(x)
46
+ x = tf.keras.layers.Dense(4, activation='elu', name=f'layer5_dense4_elu')(x)
45
47
 
46
48
  outputs = {}
47
49
 
@@ -52,13 +54,13 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
52
54
  if INCUBATION_FRACTION in target_columns:
53
55
  outputs[INCUBATION_FRACTION] = tf.keras.layers.Dense(1, activation="sigmoid", name=INCUBATION_FRACTION)(x) # positive values only
54
56
  if F_S in target_columns:
55
- outputs[F_S] = tf.keras.layers.Dense(1, activation=half_sigmoid, name="FS_logits")(x)
57
+ outputs[F_S] = tf.keras.layers.Dense(1, activation=half_sigmoid, name=F_S)(x)
56
58
  if X_S in target_columns:
57
- outputs[X_S] = tf.keras.layers.Dense(1, activation=relu_plus_one, name="XS_logits")(x)
59
+ outputs[X_S] = tf.keras.layers.Dense(1, activation=relu_plus_one, name=X_S)(x)
58
60
  if UPSILON in target_columns:
59
- outputs[UPSILON] = tf.keras.layers.Dense(1, activation="sigmoid", name="ups_logits")(x)
61
+ outputs[UPSILON] = tf.keras.layers.Dense(1, activation="sigmoid", name=UPSILON)(x)
60
62
  if X_C in target_columns:
61
- outputs[X_C] = tf.keras.layers.Dense(1, activation=relu_plus_one, name="XC_logits")(x)
63
+ outputs[X_C] = tf.keras.layers.Dense(1, activation=relu_plus_one, name=X_C)(x)
62
64
 
63
65
  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
64
66
 
@@ -27,7 +27,6 @@ def main():
27
27
  parser = \
28
28
  argparse.ArgumentParser(description="Fit a BD(EI)(SS)(CT) data scaler.")
29
29
  parser.add_argument('--train_data', type=str, nargs='+',
30
- # default=[f'/home/azhukova/projects/bdeissct_dl/simulations_bdeissct/training/500_1000/{model}/{i}/trees.csv.xz' for i in range(120) for model in [BD, BDCT, BDEI, BDEICT, BDSS, BDSSCT, BDEISS, BDEISSCT]],
31
30
  help="path to the files where the encoded training data are stored")
32
31
  parser.add_argument('--model_path', default=MODEL_PATH, type=str,
33
32
  help="path to the folder where the scaler should be stored.")
bdeissct_dl/training.py CHANGED
@@ -141,14 +141,8 @@ def main():
141
141
  parser = \
142
142
  argparse.ArgumentParser(description="Train a BD(EI)(SS)(CT) model.")
143
143
  parser.add_argument('--train_data', type=str, nargs='+',
144
- # default=[f'/home/azhukova/projects/bdeissct_dl/simulations_bdeissct/train/2000_5000/BDEI/{i}/trees.csv.xz' for i in range(100)] \
145
- # + [f'/home/azhukova/projects/bdeissct_dl/simulations_bdeissct/training/2000_5000/BD/{i}/trees.csv.xz' for i in range(10)]
146
- # ,
147
144
  help="path to the files where the encoded training data are stored")
148
145
  parser.add_argument('--val_data', type=str, nargs='+',
149
- # default=[f'/home/azhukova/projects/bdeissct_dl/simulations_bdeissct/train/2000_5000/BDEI/{i}/trees.csv.xz' for i in range(100, 120)] \
150
- # + [f'/home/azhukova/projects/bdeissct_dl/simulations_bdeissct/train/2000_5000/BD/{i}/trees.csv.xz' for i in range(10, 12)]
151
- # ,
152
146
  help="path to the files where the encoded validation data are stored")
153
147
 
154
148
  parser.add_argument('--epochs', type=int, default=EPOCHS, help='number of epochs to train the model')
@@ -178,6 +172,13 @@ def main():
178
172
 
179
173
 
180
174
  for col, y_idx in y_col2index.items():
175
+ try:
176
+ if load_model_keras(path=params.model_path, model_name=f'{params.model_name}.{col}'):
177
+ print(f'Model {params.model_name}.{col} already exists at {params.model_path}. Skipping training for this target.')
178
+ continue
179
+ except:
180
+ pass
181
+
181
182
  print(f'Training to predict {col} with {params.model_name}...')
182
183
 
183
184
  if params.base_model_name is not None:
@@ -194,7 +195,7 @@ def main():
194
195
  scaler_x=scaler_x, batch_size=BATCH_SIZE, shuffle=True)
195
196
 
196
197
  #early stopping to avoid overfitting
197
- early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=25)
198
+ early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
198
199
 
199
200
  #Training of the Network, with an independent validation set
200
201
  model.fit(ds_train, verbose=1, epochs=params.epochs, validation_data=ds_val, callbacks=[early_stop])
@@ -1,13 +1,11 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: bdext
3
- Version: 0.1.71
3
+ Version: 0.1.72
4
4
  Summary: Estimation of BDEISS-CT parameters from phylogenetic trees.
5
5
  Home-page: https://github.com/modpath/bdeissct
6
6
  Author: Anna Zhukova
7
7
  Author-email: anna.zhukova@pasteur.fr
8
- License: UNKNOWN
9
8
  Keywords: phylogenetics,birth-death model,incubation,super-spreading,contact tracing
10
- Platform: UNKNOWN
11
9
  Classifier: Development Status :: 4 - Beta
12
10
  Classifier: Environment :: Console
13
11
  Classifier: Intended Audience :: Developers
@@ -15,6 +13,7 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
15
13
  Classifier: Topic :: Software Development
16
14
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
17
15
  Description-Content-Type: text/markdown
16
+ License-File: LICENSE
18
17
  Requires-Dist: tensorflow==2.19.0
19
18
  Requires-Dist: six
20
19
  Requires-Dist: ete3
@@ -24,6 +23,16 @@ Requires-Dist: biopython
24
23
  Requires-Dist: scikit-learn==1.5.2
25
24
  Requires-Dist: pandas==2.2.3
26
25
  Requires-Dist: treesumstats==0.7
26
+ Dynamic: author
27
+ Dynamic: author-email
28
+ Dynamic: classifier
29
+ Dynamic: description
30
+ Dynamic: description-content-type
31
+ Dynamic: home-page
32
+ Dynamic: keywords
33
+ Dynamic: license-file
34
+ Dynamic: requires-dist
35
+ Dynamic: summary
27
36
 
28
37
  # bdext
29
38
 
@@ -236,5 +245,3 @@ The other parameters are estimated from a time-scaled phylogenetic tree.
236
245
 
237
246
  [//]: # ()
238
247
  [//]: # ()
239
-
240
-
@@ -0,0 +1,17 @@
1
+ README.md,sha256=Ngj8bt0Yu3LUsvwblmMtUqqjvGyqxv6ku2_cYCb5_DQ,6539
2
+ bdeissct_dl/__init__.py,sha256=QPEiIP-xVqGQgydeqN_9AZgT26IYWeJC4-JlHnd8Rjo,296
3
+ bdeissct_dl/bdeissct_model.py,sha256=um1nEQf4uym_jkrkuUjpvIbVc7VRfmAY3gnW9xgXv6I,2016
4
+ bdeissct_dl/dl_model.py,sha256=gl6uBK6rwEJxWgzInQfyn-1UWbePQJPDWwc7Lwq5F0U,3250
5
+ bdeissct_dl/estimator.py,sha256=QBWA8R0pBPZPd3JvItdJS2lN1J3VqvdJqBMzCi-NADs,3336
6
+ bdeissct_dl/model_serializer.py,sha256=s1yBzQjhtr-w7eT8bTsNkG9_xnYRZrUc3HkeOzNZpQY,2464
7
+ bdeissct_dl/scaler_fitting.py,sha256=wvHLtLmg5QP58NKSUnYBOQ4TzAtTAi_AfLVaxKXfJzM,1522
8
+ bdeissct_dl/sumstat_checker.py,sha256=TQ0nb86-BXmusqgMnOJusLpR4ul3N3Hi886IWUovrMI,1846
9
+ bdeissct_dl/training.py,sha256=EvD1n3uiaUb8gubxwhP1kt4xUW2hokHB7ywoUScCtmI,7979
10
+ bdeissct_dl/tree_encoder.py,sha256=WAwn3e1lPiksZNCnwTt9wsoEX3rgF8O0b2vOx7g0gUY,20286
11
+ bdeissct_dl/tree_manager.py,sha256=UXxUVmEkxwUhKpJeACVgiXZ8Kp1o_hiv8Qb80b6qmVU,11814
12
+ bdext-0.1.72.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
13
+ bdext-0.1.72.dist-info/METADATA,sha256=5ix9OE4DIpC4K4xoW61JNPePZro3J2K5tt1AcGR6puc,7676
14
+ bdext-0.1.72.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
15
+ bdext-0.1.72.dist-info/entry_points.txt,sha256=DP-XVnUjSLJt-PHOJUurpkEUkkicdtGoEuGVeVb0gGg,207
16
+ bdext-0.1.72.dist-info/top_level.txt,sha256=z4dadFfcLghr4lwROy7QR3zEICpa-eCPT6mmcoHeEJY,12
17
+ bdext-0.1.72.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.45.1)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -3,4 +3,3 @@ bdeissct_encode = bdeissct_dl.tree_encoder:main
3
3
  bdeissct_fit_scaler = bdeissct_dl.scaler_fitting:main
4
4
  bdeissct_infer = bdeissct_dl.estimator:main
5
5
  bdeissct_train = bdeissct_dl.training:main
6
-
@@ -1,17 +0,0 @@
1
- README.md,sha256=Ngj8bt0Yu3LUsvwblmMtUqqjvGyqxv6ku2_cYCb5_DQ,6539
2
- bdeissct_dl/__init__.py,sha256=QPEiIP-xVqGQgydeqN_9AZgT26IYWeJC4-JlHnd8Rjo,296
3
- bdeissct_dl/bdeissct_model.py,sha256=um1nEQf4uym_jkrkuUjpvIbVc7VRfmAY3gnW9xgXv6I,2016
4
- bdeissct_dl/dl_model.py,sha256=xtAov5s_l5XUd569t1BUrgnmsT14FFRT-gissXZZkwo,3115
5
- bdeissct_dl/estimator.py,sha256=QBWA8R0pBPZPd3JvItdJS2lN1J3VqvdJqBMzCi-NADs,3336
6
- bdeissct_dl/model_serializer.py,sha256=s1yBzQjhtr-w7eT8bTsNkG9_xnYRZrUc3HkeOzNZpQY,2464
7
- bdeissct_dl/scaler_fitting.py,sha256=9X0O7-Wc9xGTI-iF-Pfp1PPoW7j01wZUfJVZf8ky-IU,1752
8
- bdeissct_dl/sumstat_checker.py,sha256=TQ0nb86-BXmusqgMnOJusLpR4ul3N3Hi886IWUovrMI,1846
9
- bdeissct_dl/training.py,sha256=ahkQoVq88sVyX8Q5bEG5DBfeamu1TH5pXQ1daf5Z-Cw,8359
10
- bdeissct_dl/tree_encoder.py,sha256=WAwn3e1lPiksZNCnwTt9wsoEX3rgF8O0b2vOx7g0gUY,20286
11
- bdeissct_dl/tree_manager.py,sha256=UXxUVmEkxwUhKpJeACVgiXZ8Kp1o_hiv8Qb80b6qmVU,11814
12
- bdext-0.1.71.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
13
- bdext-0.1.71.dist-info/METADATA,sha256=U2hCvUQLoGiZQruc8Y5MyipYQSNzoUT-A8YXzMue6Ag,7479
14
- bdext-0.1.71.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
15
- bdext-0.1.71.dist-info/entry_points.txt,sha256=lcAwyk-Fc0G_w4Ex7KDivh7h1tzSA99PRMcy971b-nM,208
16
- bdext-0.1.71.dist-info/top_level.txt,sha256=z4dadFfcLghr4lwROy7QR3zEICpa-eCPT6mmcoHeEJY,12
17
- bdext-0.1.71.dist-info/RECORD,,