tmnt 0.7.44b20240120__py3-none-any.whl → 0.7.44b20240122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/distribution.py CHANGED
@@ -83,7 +83,7 @@ class GaussianDistribution(BaseDistribution):
83
83
  z = self.post_sample_dr_o(z)
84
84
  return z, KL
85
85
 
86
- def get_mu_encoding(self, data, include_bn=False):
86
+ def get_mu_encoding(self, data, include_bn=True, normalize=False):
87
87
  """Provide the distribution mean as the natural result of running the full encoder
88
88
 
89
89
  Parameters:
@@ -94,7 +94,8 @@ class GaussianDistribution(BaseDistribution):
94
94
  enc = self.mu_encoder(data)
95
95
  if include_bn:
96
96
  enc = self.mu_bn(enc)
97
- return self.softplus(enc)
97
+ mu = self.softplus(enc) if normalize else enc
98
+ return mu
98
99
 
99
100
 
100
101
 
@@ -126,7 +127,7 @@ class GaussianUnitVarDistribution(BaseDistribution):
126
127
  KL = self._get_kl_term(mu_bn)
127
128
  return self.post_sample_dr_o(z), KL
128
129
 
129
- def get_mu_encoding(self, data, include_bn=False):
130
+ def get_mu_encoding(self, data, include_bn=True, normalize=False):
130
131
  """Provide the distribution mean as the natural result of running the full encoder
131
132
 
132
133
  Parameters:
@@ -137,7 +138,8 @@ class GaussianUnitVarDistribution(BaseDistribution):
137
138
  enc = self.mu_encoder(data)
138
139
  if include_bn:
139
140
  enc = self.mu_bn(enc)
140
- return self.softplus(enc)
141
+ mu = self.softplus(enc) if normalize else enc
142
+ return mu
141
143
 
142
144
 
143
145
  class LogisticGaussianDistribution(BaseDistribution):
@@ -182,7 +184,7 @@ class LogisticGaussianDistribution(BaseDistribution):
182
184
  z = self.post_sample_dr_o(z_p)
183
185
  return self.softmax(z), KL
184
186
 
185
- def get_mu_encoding(self, data, include_bn=False):
187
+ def get_mu_encoding(self, data, include_bn=True, normalize=False):
186
188
  """Provide the distribution mean as the natural result of running the full encoder
187
189
 
188
190
  Parameters:
@@ -193,7 +195,8 @@ class LogisticGaussianDistribution(BaseDistribution):
193
195
  enc = self.mu_encoder(data)
194
196
  if include_bn:
195
197
  enc = self.mu_bn(enc)
196
- return self.softmax(enc)
198
+ mu = self.softmax(enc) if normalize else enc
199
+ return mu
197
200
 
198
201
 
199
202
  class VonMisesDistribution(BaseDistribution):
@@ -220,7 +223,7 @@ class VonMisesDistribution(BaseDistribution):
220
223
  kld = self.kld_v.expand(batch_size)
221
224
  return z_p, kld
222
225
 
223
- def get_mu_encoding(self, data, include_bn=False):
226
+ def get_mu_encoding(self, data, include_bn=True, normalize=False):
224
227
  """Provide the distribution mean as the natural result of running the full encoder
225
228
 
226
229
  Parameters:
@@ -231,7 +234,8 @@ class VonMisesDistribution(BaseDistribution):
231
234
  enc = self.mu_encoder(data)
232
235
  if include_bn:
233
236
  enc = self.mu_bn(enc)
234
- return self.softplus(enc)
237
+ mu = self.softplus(enc) if normalize else enc
238
+ return mu
235
239
 
236
240
 
237
241
 
@@ -247,7 +251,7 @@ class Projection(BaseDistribution):
247
251
  kld = torch.zeros(batch_size).to(self.device)
248
252
  return mu_bn, kld
249
253
 
250
- def get_mu_encoding(self, data, include_bn=False):
254
+ def get_mu_encoding(self, data, include_bn=True, normalize=False):
251
255
  """Provide the distribution mean as the natural result of running the full encoder
252
256
 
253
257
  Parameters:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tmnt
3
- Version: 0.7.44b20240120
3
+ Version: 0.7.44b20240122
4
4
  Summary: Topic modeling neural toolkit
5
5
  Home-page: https://github.com/mitre/tmnt.git
6
6
  Author: The MITRE Corporation
@@ -1,7 +1,7 @@
1
1
  tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
2
2
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
3
3
  tmnt/data_loading.py,sha256=_NpAwmpeFBoQp7xtWOLb6i3WS271JoSJqDx9BMrXtKM,18207
4
- tmnt/distribution.py,sha256=vNvBq7vLwLOqic-r8knD_pqjnUGKiBmmwbhbwOvl-cE,10633
4
+ tmnt/distribution.py,sha256=Pmyc5gwDd_-jP7vLVb0vdNQaSSvF1EuiTZEWg3KfmI8,10866
5
5
  tmnt/estimator.py,sha256=xk4QATqqD8ukxtraOQ6BvSJrdqGTQvX52fNdcgfQ3w8,77801
6
6
  tmnt/eval_npmi.py,sha256=ODRDMsBgDM__iCNEX399ck7bAhl7ydvgDqmpfR7Y-q4,5048
7
7
  tmnt/inference.py,sha256=Sw7GO7QiWVEtbPJKBjFB7AiKRmUOZbFZn3tCrsStzWw,17845
@@ -17,9 +17,9 @@ tmnt/utils/ngram_helpers.py,sha256=VrIzou2oQHCLBLSWODDeikN3PYat1NqqvEeYQj_GhbA,1
17
17
  tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,1274
18
18
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
19
19
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
20
- tmnt-0.7.44b20240120.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
- tmnt-0.7.44b20240120.dist-info/METADATA,sha256=GZZj0rVgcjfff-EtNd-AUJV-_bN6E3nE57smSxI1sHo,1403
22
- tmnt-0.7.44b20240120.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
- tmnt-0.7.44b20240120.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
- tmnt-0.7.44b20240120.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
- tmnt-0.7.44b20240120.dist-info/RECORD,,
20
+ tmnt-0.7.44b20240122.dist-info/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
21
+ tmnt-0.7.44b20240122.dist-info/METADATA,sha256=ATwYUpYO65LnsfQagNNk0n_tfXsrfq5-Me1Ojg-5jVI,1403
22
+ tmnt-0.7.44b20240122.dist-info/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
23
+ tmnt-0.7.44b20240122.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
24
+ tmnt-0.7.44b20240122.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
25
+ tmnt-0.7.44b20240122.dist-info/RECORD,,