dsipts 1.1.8__tar.gz → 1.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- {dsipts-1.1.8 → dsipts-1.1.10}/PKG-INFO +1 -1
- {dsipts-1.1.8 → dsipts-1.1.10}/pyproject.toml +1 -1
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/base.py +24 -5
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/base_v2.py +27 -5
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/utils.py +75 -1
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts.egg-info/PKG-INFO +1 -1
- {dsipts-1.1.8 → dsipts-1.1.10}/README.md +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/setup.cfg +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_structure/data_structure.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ITransformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/RNN.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/Simple.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/TTM.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.10}/src/dsipts.egg-info/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@ import numpy as np
|
|
|
12
12
|
from aim import Image
|
|
13
13
|
import matplotlib.pyplot as plt
|
|
14
14
|
from typing import List, Union
|
|
15
|
-
from .utils import QuantileLossMO
|
|
15
|
+
from .utils import QuantileLossMO, CPRS
|
|
16
16
|
import torch.nn as nn
|
|
17
17
|
|
|
18
18
|
def standardize_momentum(x,order):
|
|
@@ -135,10 +135,15 @@ class Base(pl.LightningModule):
|
|
|
135
135
|
if n_classes==0:
|
|
136
136
|
self.is_classification = False
|
|
137
137
|
if len(self.quantiles)>0:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
138
|
+
if self.loss_type=='cprs':
|
|
139
|
+
self.use_quantiles = True
|
|
140
|
+
self.mul = len(self.quantiles)
|
|
141
|
+
self.loss = CPRS(alpha=self.persistence_weight)
|
|
142
|
+
else:
|
|
143
|
+
assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
|
|
144
|
+
self.use_quantiles = True
|
|
145
|
+
self.mul = len(self.quantiles)
|
|
146
|
+
self.loss = QuantileLossMO(quantiles)
|
|
142
147
|
else:
|
|
143
148
|
self.use_quantiles = False
|
|
144
149
|
self.mul = 1
|
|
@@ -186,6 +191,12 @@ class Base(pl.LightningModule):
|
|
|
186
191
|
Returns:
|
|
187
192
|
torch.tensor: result
|
|
188
193
|
"""
|
|
194
|
+
if self.loss_type=='cprs':
|
|
195
|
+
tmp = self(batch)
|
|
196
|
+
tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
|
|
197
|
+
return tmp
|
|
198
|
+
#return tmp.mean(axis=-1).unsqueeze(-1)
|
|
199
|
+
|
|
189
200
|
return self(batch)
|
|
190
201
|
|
|
191
202
|
def configure_optimizers(self):
|
|
@@ -357,6 +368,14 @@ class Base(pl.LightningModule):
|
|
|
357
368
|
:meta private:
|
|
358
369
|
"""
|
|
359
370
|
|
|
371
|
+
if self.loss_type=='cprs':
|
|
372
|
+
return self.loss(y_hat,batch['y'])
|
|
373
|
+
|
|
374
|
+
if self.loss_type=='long_lag':
|
|
375
|
+
batch_size,width,n_variables = batch['y'].shape
|
|
376
|
+
tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
|
|
377
|
+
return tmp.mean()
|
|
378
|
+
|
|
360
379
|
if self.use_quantiles is False:
|
|
361
380
|
initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
|
|
362
381
|
else:
|
|
@@ -12,7 +12,7 @@ import numpy as np
|
|
|
12
12
|
from aim import Image
|
|
13
13
|
import matplotlib.pyplot as plt
|
|
14
14
|
from typing import List, Union
|
|
15
|
-
from .utils import QuantileLossMO
|
|
15
|
+
from .utils import QuantileLossMO, CPRS
|
|
16
16
|
import torch.nn as nn
|
|
17
17
|
|
|
18
18
|
|
|
@@ -137,10 +137,15 @@ class Base(pl.LightningModule):
|
|
|
137
137
|
if n_classes==0:
|
|
138
138
|
self.is_classification = False
|
|
139
139
|
if len(self.quantiles)>0:
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
140
|
+
if self.loss_type=='cprs':
|
|
141
|
+
self.use_quantiles = True
|
|
142
|
+
self.mul = len(self.quantiles)
|
|
143
|
+
self.loss = CPRS(alpha=self.persistence_weight)
|
|
144
|
+
else:
|
|
145
|
+
assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
|
|
146
|
+
self.use_quantiles = True
|
|
147
|
+
self.mul = len(self.quantiles)
|
|
148
|
+
self.loss = QuantileLossMO(quantiles)
|
|
144
149
|
else:
|
|
145
150
|
self.use_quantiles = False
|
|
146
151
|
self.mul = 1
|
|
@@ -189,6 +194,13 @@ class Base(pl.LightningModule):
|
|
|
189
194
|
Returns:
|
|
190
195
|
torch.tensor: result
|
|
191
196
|
"""
|
|
197
|
+
|
|
198
|
+
if self.loss_type=='cprs':
|
|
199
|
+
tmp = self(batch)
|
|
200
|
+
tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
|
|
201
|
+
return tmp
|
|
202
|
+
#return tmp.mean(axis=-1).unsqueeze(-1)
|
|
203
|
+
|
|
192
204
|
return self(batch)
|
|
193
205
|
|
|
194
206
|
def configure_optimizers(self):
|
|
@@ -365,6 +377,16 @@ class Base(pl.LightningModule):
|
|
|
365
377
|
|
|
366
378
|
:meta private:
|
|
367
379
|
"""
|
|
380
|
+
if self.loss_type=='cprs':
|
|
381
|
+
return self.loss(y_hat,batch['y'])
|
|
382
|
+
|
|
383
|
+
if self.loss_type=='long_lag':
|
|
384
|
+
|
|
385
|
+
batch_size,width,n_variables = batch['y'].shape
|
|
386
|
+
tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
|
|
387
|
+
return tmp.mean()
|
|
388
|
+
|
|
389
|
+
|
|
368
390
|
|
|
369
391
|
if self.use_quantiles is False:
|
|
370
392
|
initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
|
|
@@ -621,4 +621,78 @@ class Embedding_cat_variables(nn.Module):
|
|
|
621
621
|
emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))
|
|
622
622
|
|
|
623
623
|
cat_n_embd = torch.cat(emb,dim=2)
|
|
624
|
-
return cat_n_embd
|
|
624
|
+
return cat_n_embd
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
class CPRS(nn.Module):
    """Almost fair Continuous Ranked Probability Score (afCRPS) loss.

    Vectorized implementation of the almost-fair CRPS, which blends the fair
    ensemble CRPS (weight ``alpha``) with the standard ensemble CRPS
    (weight ``1 - alpha``). For ``M`` ensemble members ``x_j`` and target
    ``y``::

        afCRPS = mean_j |x_j - y|
                 - (1 - eps) / (2 M (M - 1)) * sum_{j != k} |x_j - x_k|
        eps    = (1 - alpha) / M

    The class keeps its historical ``CPRS`` spelling so existing imports
    keep working.

    Args:
        alpha: weight of the fair CRPS component (``1.0`` = fully fair,
            ``0.0`` = standard ensemble CRPS).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``; validated when
            ``forward`` is called.
    """

    def __init__(self, alpha=0.5, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, y_hat, target, weights=None):
        """Compute the afCRPS loss.

        Args:
            y_hat: ensemble prediction of shape
                ``(batch_size, width, n_variables, n_members)``, with the
                member axis last.
            target: observed values of shape ``(batch_size, width, n_variables)``.
            weights: optional tensor broadcastable to
                ``(batch_size, width, n_variables)``. It is multiplied into
                the per-element loss before reduction.

        Returns:
            For the ``'mean'`` and ``'sum'`` reductions, a scalar tensor.
            For ``'none'``, the per-element loss with shape
            ``(batch_size, width, n_variables)``.

        Raises:
            ValueError: if ``self.reduction`` is not a supported value.
        """
        # Move the member axis next to the batch axis:
        # (B, width, n_variables, M) -> (B, M, width, n_variables).
        ensemble = y_hat.permute(0, 3, 1, 2)
        n_members = ensemble.shape[1]
        epsilon = (1 - self.alpha) / n_members

        # First term: mean absolute error of each member against the target,
        # with the target broadcast over the member axis.
        mae_term = torch.abs(ensemble - target.unsqueeze(1)).mean(dim=1)

        if n_members > 1:
            # Second term: sum of |x_j - x_k| over all ordered member pairs.
            # Diagonal entries (j == k) are exactly zero, so no mask is needed
            # to exclude them.
            pairwise_diffs = torch.abs(ensemble.unsqueeze(2) - ensemble.unsqueeze(1))
            pairwise_term = pairwise_diffs.sum(dim=(1, 2))
            loss = mae_term - (1 - epsilon) * pairwise_term / (
                2 * n_members * (n_members - 1)
            )
        else:
            # A single member has no spread term, so the score reduces to the
            # MAE. Using the general formula here would divide 0 by 0 (NaN).
            loss = mae_term

        if weights is not None:
            loss = loss * weights

        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'mean':
            return loss.mean()
        else:
            raise ValueError(f"Invalid reduction: {self.reduction}")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|