tmnt 0.7.44b20240120__py3-none-any.whl → 0.7.44b20240122__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
- tmnt/distribution.py +13 -9
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/METADATA +1 -1
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/RECORD +7 -7
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/LICENSE +0 -0
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/NOTICE +0 -0
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/WHEEL +0 -0
- {tmnt-0.7.44b20240120.dist-info → tmnt-0.7.44b20240122.dist-info}/top_level.txt +0 -0
tmnt/distribution.py
CHANGED
@@ -83,7 +83,7 @@ class GaussianDistribution(BaseDistribution):
|
|
83
83
|
z = self.post_sample_dr_o(z)
|
84
84
|
return z, KL
|
85
85
|
|
86
|
-
def get_mu_encoding(self, data, include_bn=False):
|
86
|
+
def get_mu_encoding(self, data, include_bn=True, normalize=False):
|
87
87
|
"""Provide the distribution mean as the natural result of running the full encoder
|
88
88
|
|
89
89
|
Parameters:
|
@@ -94,7 +94,8 @@ class GaussianDistribution(BaseDistribution):
|
|
94
94
|
enc = self.mu_encoder(data)
|
95
95
|
if include_bn:
|
96
96
|
enc = self.mu_bn(enc)
|
97
|
-
|
97
|
+
mu = self.softplus(enc) if normalize else enc
|
98
|
+
return mu
|
98
99
|
|
99
100
|
|
100
101
|
|
@@ -126,7 +127,7 @@ class GaussianUnitVarDistribution(BaseDistribution):
|
|
126
127
|
KL = self._get_kl_term(mu_bn)
|
127
128
|
return self.post_sample_dr_o(z), KL
|
128
129
|
|
129
|
-
def get_mu_encoding(self, data, include_bn=False):
|
130
|
+
def get_mu_encoding(self, data, include_bn=True, normalize=False):
|
130
131
|
"""Provide the distribution mean as the natural result of running the full encoder
|
131
132
|
|
132
133
|
Parameters:
|
@@ -137,7 +138,8 @@ class GaussianUnitVarDistribution(BaseDistribution):
|
|
137
138
|
enc = self.mu_encoder(data)
|
138
139
|
if include_bn:
|
139
140
|
enc = self.mu_bn(enc)
|
140
|
-
|
141
|
+
mu = self.softplus(enc) if normalize else enc
|
142
|
+
return mu
|
141
143
|
|
142
144
|
|
143
145
|
class LogisticGaussianDistribution(BaseDistribution):
|
@@ -182,7 +184,7 @@ class LogisticGaussianDistribution(BaseDistribution):
|
|
182
184
|
z = self.post_sample_dr_o(z_p)
|
183
185
|
return self.softmax(z), KL
|
184
186
|
|
185
|
-
def get_mu_encoding(self, data, include_bn=False):
|
187
|
+
def get_mu_encoding(self, data, include_bn=True, normalize=False):
|
186
188
|
"""Provide the distribution mean as the natural result of running the full encoder
|
187
189
|
|
188
190
|
Parameters:
|
@@ -193,7 +195,8 @@ class LogisticGaussianDistribution(BaseDistribution):
|
|
193
195
|
enc = self.mu_encoder(data)
|
194
196
|
if include_bn:
|
195
197
|
enc = self.mu_bn(enc)
|
196
|
-
|
198
|
+
mu = self.softmax(enc) if normalize else enc
|
199
|
+
return mu
|
197
200
|
|
198
201
|
|
199
202
|
class VonMisesDistribution(BaseDistribution):
|
@@ -220,7 +223,7 @@ class VonMisesDistribution(BaseDistribution):
|
|
220
223
|
kld = self.kld_v.expand(batch_size)
|
221
224
|
return z_p, kld
|
222
225
|
|
223
|
-
def get_mu_encoding(self, data, include_bn=False):
|
226
|
+
def get_mu_encoding(self, data, include_bn=True, normalize=False):
|
224
227
|
"""Provide the distribution mean as the natural result of running the full encoder
|
225
228
|
|
226
229
|
Parameters:
|
@@ -231,7 +234,8 @@ class VonMisesDistribution(BaseDistribution):
|
|
231
234
|
enc = self.mu_encoder(data)
|
232
235
|
if include_bn:
|
233
236
|
enc = self.mu_bn(enc)
|
234
|
-
|
237
|
+
mu = self.softplus(enc) if normalize else enc
|
238
|
+
return mu
|
235
239
|
|
236
240
|
|
237
241
|
|
@@ -247,7 +251,7 @@ class Projection(BaseDistribution):
|
|
247
251
|
kld = torch.zeros(batch_size).to(self.device)
|
248
252
|
return mu_bn, kld
|
249
253
|
|
250
|
-
def get_mu_encoding(self, data, include_bn=False):
|
254
|
+
def get_mu_encoding(self, data, include_bn=True, normalize=False):
|
251
255
|
"""Provide the distribution mean as the natural result of running the full encoder
|
252
256
|
|
253
257
|
Parameters:
|
@@ -1,7 +1,7 @@
|
|
1
1
|
tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
|
2
2
|
tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
|
3
3
|
tmnt/data_loading.py,sha256=_NpAwmpeFBoQp7xtWOLb6i3WS271JoSJqDx9BMrXtKM,18207
|
4
|
-
tmnt/distribution.py,sha256=
|
4
|
+
tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
|
5
5
|
tmnt/estimator.py,sha256=xk4QATqqD8ukxtraOQ6BvSJrdqGTQvX52fNdcgfQ3w8,77801
|
6
6
|
tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
|
7
7
|
tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
|
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
|
|
17
17
|
tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
|
18
18
|
tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
|
19
19
|
tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
|
20
|
-
tmnt-0.7.
|
21
|
-
tmnt-0.7.
|
22
|
-
tmnt-0.7.
|
23
|
-
tmnt-0.7.
|
24
|
-
tmnt-0.7.
|
25
|
-
tmnt-0.7.
|
20
|
+
tmnt-0.7.44b20240122.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
|
21
|
+
tmnt-0.7.44b20240122.dist-info/METADATA,sha256=ATwYUpYO65LnsfQagNNk0n_tfXsrfq5-Me1Ojg-5jVI,1403
|
22
|
+
tmnt-0.7.44b20240122.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
|
23
|
+
tmnt-0.7.44b20240122.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
24
|
+
tmnt-0.7.44b20240122.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
|
25
|
+
tmnt-0.7.44b20240122.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|