bdext 0.1.69__py3-none-any.whl → 0.1.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bdeissct_dl/dl_model.py CHANGED
@@ -4,17 +4,9 @@ from tensorflow.python.keras.utils.generic_utils import register_keras_serializa
4
4
  from bdeissct_dl.bdeissct_model import F_S, UPSILON, REPRODUCTIVE_NUMBER, \
5
5
  INFECTION_DURATION, X_S, X_C, RHO, INCUBATION_PERIOD
6
6
 
7
- LEARNING_RATE = 0.001
7
+ from collections import defaultdict
8
8
 
9
- LOSS_WEIGHTS = {
10
- REPRODUCTIVE_NUMBER: 1,
11
- INFECTION_DURATION: 1,
12
- INCUBATION_PERIOD: 1,
13
- F_S: 200, # as it is a value between 0 and 0.5, we multiply by 200 to scale it to [0, 100]
14
- UPSILON: 100,
15
- X_C: 1,
16
- X_S: 1
17
- }
9
+ LEARNING_RATE = 0.001
18
10
 
19
11
  @register_keras_serializable(package="bdeissct_dl", name="half_sigmoid")
20
12
  def half_sigmoid(x):
@@ -26,16 +18,17 @@ def relu_plus_one(x):
26
18
 
27
19
 
28
20
 
29
- LOSS_FUNCTIONS = {
30
- REPRODUCTIVE_NUMBER: "mean_absolute_percentage_error",
31
- INFECTION_DURATION: "mean_absolute_percentage_error",
32
- INCUBATION_PERIOD: "mae",
33
- UPSILON: 'mae',
34
- RHO: 'mean_absolute_percentage_error',
35
- X_C: "mean_absolute_percentage_error",
36
- F_S: 'mae',
37
- X_S: "mean_absolute_percentage_error",
38
- }
21
+ LOSS_FUNCTIONS = defaultdict(lambda: "mean_squared_error")
22
+ LOSS_FUNCTIONS.update({
23
+ REPRODUCTIVE_NUMBER: "mean_squared_error",
24
+ INFECTION_DURATION: "mean_squared_error",
25
+ INCUBATION_PERIOD: "mean_squared_error",
26
+ UPSILON: 'mean_squared_error',
27
+ RHO: 'mean_squared_error',
28
+ X_C: "mean_squared_error",
29
+ F_S: 'mean_squared_error',
30
+ X_S: "mean_squared_error",
31
+ })
39
32
 
40
33
 
41
34
  def build_model(target_columns, n_x, optimizer=None, metrics=None):
@@ -63,11 +56,11 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
63
56
  outputs = {}
64
57
 
65
58
  if REPRODUCTIVE_NUMBER in target_columns:
66
- outputs[REPRODUCTIVE_NUMBER] = tf.keras.layers.Dense(1, activation="softplus", name=REPRODUCTIVE_NUMBER)(x) # positive values only
59
+ outputs[REPRODUCTIVE_NUMBER] = tf.keras.layers.Dense(1, activation="relu", name=REPRODUCTIVE_NUMBER)(x) # positive values only
67
60
  if INFECTION_DURATION in target_columns:
68
- outputs[INFECTION_DURATION] = tf.keras.layers.Dense(1, activation="softplus", name=INFECTION_DURATION)(x) # positive values only
61
+ outputs[INFECTION_DURATION] = tf.keras.layers.Dense(1, activation="relu", name=INFECTION_DURATION)(x) # positive values only
69
62
  if INCUBATION_PERIOD in target_columns:
70
- outputs[INCUBATION_PERIOD] = tf.keras.layers.Dense(1, activation="softplus", name=INCUBATION_PERIOD)(x) # positive values only
63
+ outputs[INCUBATION_PERIOD] = tf.keras.layers.Dense(1, activation="relu", name=INCUBATION_PERIOD)(x) # positive values only
71
64
  if F_S in target_columns:
72
65
  outputs[F_S] = tf.keras.layers.Dense(1, activation=half_sigmoid, name="FS_logits")(x)
73
66
  if X_S in target_columns:
@@ -84,6 +77,5 @@ def build_model(target_columns, n_x, optimizer=None, metrics=None):
84
77
 
85
78
  model.compile(optimizer=optimizer,
86
79
  loss={col: LOSS_FUNCTIONS[col] for col in outputs.keys()},
87
- loss_weights={col: LOSS_WEIGHTS[col] for col in outputs.keys()},
88
80
  metrics=metrics)
89
81
  return model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bdext
3
- Version: 0.1.69
3
+ Version: 0.1.70
4
4
  Summary: Estimation of BDEISS-CT parameters from phylogenetic trees.
5
5
  Home-page: https://github.com/modpath/bdeissct
6
6
  Author: Anna Zhukova
@@ -1,7 +1,7 @@
1
1
  README.md,sha256=Ngj8bt0Yu3LUsvwblmMtUqqjvGyqxv6ku2_cYCb5_DQ,6539
2
2
  bdeissct_dl/__init__.py,sha256=QPEiIP-xVqGQgydeqN_9AZgT26IYWeJC4-JlHnd8Rjo,296
3
3
  bdeissct_dl/bdeissct_model.py,sha256=sQclYN5V8utw6wEMDN0_Ua-0NeuyuWHG_e0_jQIUe8Q,1986
4
- bdeissct_dl/dl_model.py,sha256=eXyFHqgJtXovzPFYn9xW5-Th92f28Ci6iuRHEwSd8y0,3678
4
+ bdeissct_dl/dl_model.py,sha256=wpwlUVy6kOhPIsT1zg-Us2_bdnntxdnCbNQB4UxYzTg,3433
5
5
  bdeissct_dl/estimator.py,sha256=QBWA8R0pBPZPd3JvItdJS2lN1J3VqvdJqBMzCi-NADs,3336
6
6
  bdeissct_dl/model_serializer.py,sha256=s1yBzQjhtr-w7eT8bTsNkG9_xnYRZrUc3HkeOzNZpQY,2464
7
7
  bdeissct_dl/scaler_fitting.py,sha256=9X0O7-Wc9xGTI-iF-Pfp1PPoW7j01wZUfJVZf8ky-IU,1752
@@ -9,9 +9,9 @@ bdeissct_dl/sumstat_checker.py,sha256=TQ0nb86-BXmusqgMnOJusLpR4ul3N3Hi886IWUovrM
9
9
  bdeissct_dl/training.py,sha256=H5wA3V72nhc9Km7kvKmzjCYw0N1itMGDbj9c-Uat5BU,8350
10
10
  bdeissct_dl/tree_encoder.py,sha256=V-7_Kis9x9JacI_mF7rWRGGKvxn7AWFCto7LkgRawBw,18286
11
11
  bdeissct_dl/tree_manager.py,sha256=UXxUVmEkxwUhKpJeACVgiXZ8Kp1o_hiv8Qb80b6qmVU,11814
12
- bdext-0.1.69.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
13
- bdext-0.1.69.dist-info/METADATA,sha256=XR1qaonSH4YNjna_G_yMzDwYdCrvFHw75msvPJ2XMUo,7479
14
- bdext-0.1.69.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
15
- bdext-0.1.69.dist-info/entry_points.txt,sha256=lcAwyk-Fc0G_w4Ex7KDivh7h1tzSA99PRMcy971b-nM,208
16
- bdext-0.1.69.dist-info/top_level.txt,sha256=z4dadFfcLghr4lwROy7QR3zEICpa-eCPT6mmcoHeEJY,12
17
- bdext-0.1.69.dist-info/RECORD,,
12
+ bdext-0.1.70.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
13
+ bdext-0.1.70.dist-info/METADATA,sha256=USWAUX3zunofN9x2-6E63lFVrNAdRx2cZf2Sc8HKGe8,7479
14
+ bdext-0.1.70.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
15
+ bdext-0.1.70.dist-info/entry_points.txt,sha256=lcAwyk-Fc0G_w4Ex7KDivh7h1tzSA99PRMcy971b-nM,208
16
+ bdext-0.1.70.dist-info/top_level.txt,sha256=z4dadFfcLghr4lwROy7QR3zEICpa-eCPT6mmcoHeEJY,12
17
+ bdext-0.1.70.dist-info/RECORD,,
File without changes