dsipts 1.1.9__tar.gz → 1.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- {dsipts-1.1.9 → dsipts-1.1.10}/PKG-INFO +1 -1
- {dsipts-1.1.9 → dsipts-1.1.10}/pyproject.toml +1 -1
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/base.py +5 -3
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/base_v2.py +5 -3
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/utils.py +7 -5
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts.egg-info/PKG-INFO +1 -1
- {dsipts-1.1.9 → dsipts-1.1.10}/README.md +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/setup.cfg +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_structure/data_structure.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ITransformer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/RNN.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/Simple.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/TTM.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.10}/src/dsipts.egg-info/top_level.txt +0 -0
|
@@ -136,9 +136,9 @@ class Base(pl.LightningModule):
|
|
|
136
136
|
self.is_classification = False
|
|
137
137
|
if len(self.quantiles)>0:
|
|
138
138
|
if self.loss_type=='cprs':
|
|
139
|
-
self.use_quantiles =
|
|
139
|
+
self.use_quantiles = True
|
|
140
140
|
self.mul = len(self.quantiles)
|
|
141
|
-
self.loss = CPRS()
|
|
141
|
+
self.loss = CPRS(alpha=self.persistence_weight)
|
|
142
142
|
else:
|
|
143
143
|
assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
|
|
144
144
|
self.use_quantiles = True
|
|
@@ -193,7 +193,9 @@ class Base(pl.LightningModule):
|
|
|
193
193
|
"""
|
|
194
194
|
if self.loss_type=='cprs':
|
|
195
195
|
tmp = self(batch)
|
|
196
|
-
|
|
196
|
+
tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
|
|
197
|
+
return tmp
|
|
198
|
+
#return tmp.mean(axis=-1).unsqueeze(-1)
|
|
197
199
|
|
|
198
200
|
return self(batch)
|
|
199
201
|
|
|
@@ -138,9 +138,9 @@ class Base(pl.LightningModule):
|
|
|
138
138
|
self.is_classification = False
|
|
139
139
|
if len(self.quantiles)>0:
|
|
140
140
|
if self.loss_type=='cprs':
|
|
141
|
-
self.use_quantiles =
|
|
141
|
+
self.use_quantiles = True
|
|
142
142
|
self.mul = len(self.quantiles)
|
|
143
|
-
self.loss = CPRS()
|
|
143
|
+
self.loss = CPRS(alpha=self.persistence_weight)
|
|
144
144
|
else:
|
|
145
145
|
assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
|
|
146
146
|
self.use_quantiles = True
|
|
@@ -197,7 +197,9 @@ class Base(pl.LightningModule):
|
|
|
197
197
|
|
|
198
198
|
if self.loss_type=='cprs':
|
|
199
199
|
tmp = self(batch)
|
|
200
|
-
|
|
200
|
+
tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
|
|
201
|
+
return tmp
|
|
202
|
+
#return tmp.mean(axis=-1).unsqueeze(-1)
|
|
201
203
|
|
|
202
204
|
return self(batch)
|
|
203
205
|
|
|
@@ -633,7 +633,7 @@ class CPRS(nn.Module):
|
|
|
633
633
|
with large ensembles.
|
|
634
634
|
"""
|
|
635
635
|
|
|
636
|
-
def __init__(self, alpha=0.
|
|
636
|
+
def __init__(self, alpha=0.5, reduction='mean'):
|
|
637
637
|
super().__init__()
|
|
638
638
|
self.alpha = alpha
|
|
639
639
|
self.reduction = reduction
|
|
@@ -674,17 +674,19 @@ class CPRS(nn.Module):
|
|
|
674
674
|
# Create mask to exclude diagonal (i=j)
|
|
675
675
|
mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
|
|
676
676
|
mask = mask.view(1, n_members, n_members, *[1]*(len(ensemble.shape)-2))
|
|
677
|
-
|
|
677
|
+
|
|
678
678
|
# Apply mask and compute mean
|
|
679
|
-
pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2))
|
|
679
|
+
pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) ##formula 3 second term
|
|
680
680
|
|
|
681
681
|
# Combine terms according to afCRPS formula
|
|
682
|
-
loss = mae_term - (1 - epsilon) * pairwise_term
|
|
682
|
+
loss = mae_term - (1 - epsilon) * pairwise_term/ (2*n_members * (n_members - 1))
|
|
683
683
|
|
|
684
684
|
# Apply weights if provided
|
|
685
685
|
if weights is not None:
|
|
686
686
|
loss = loss * weights
|
|
687
|
-
|
|
687
|
+
#if loss.mean()<-2:
|
|
688
|
+
# import pdb
|
|
689
|
+
# pdb.set_trace()
|
|
688
690
|
# Apply reduction
|
|
689
691
|
if self.reduction == 'none':
|
|
690
692
|
return loss
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|