dsipts 1.1.8.tar.gz → 1.1.9.tar.gz
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- {dsipts-1.1.8 → dsipts-1.1.9}/PKG-INFO +1 -1
- {dsipts-1.1.8 → dsipts-1.1.9}/pyproject.toml +1 -1
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/base.py +22 -5
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/base_v2.py +25 -5
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/utils.py +73 -1
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts.egg-info/PKG-INFO +1 -1
- {dsipts-1.1.8 → dsipts-1.1.9}/README.md +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/setup.cfg +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_structure/data_structure.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ITransformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/RNN.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/Simple.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/TTM.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.8 → dsipts-1.1.9}/src/dsipts.egg-info/top_level.txt +0 -0
src/dsipts/models/base.py (+22 -5)

@@ -12,7 +12,7 @@ import numpy as np
 from aim import Image
 import matplotlib.pyplot as plt
 from typing import List, Union
-from .utils import QuantileLossMO
+from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
 def standardize_momentum(x,order):
@@ -135,10 +135,15 @@ class Base(pl.LightningModule):
         if n_classes==0:
             self.is_classification = False
             if len(self.quantiles)>0:
-
-
-
-
+                if self.loss_type=='cprs':
+                    self.use_quantiles = False
+                    self.mul = len(self.quantiles)
+                    self.loss = CPRS()
+                else:
+                    assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
+                    self.use_quantiles = True
+                    self.mul = len(self.quantiles)
+                    self.loss = QuantileLossMO(quantiles)
             else:
                 self.use_quantiles = False
                 self.mul = 1
@@ -186,6 +191,10 @@ class Base(pl.LightningModule):
         Returns:
             torch.tensor: result
         """
+        if self.loss_type=='cprs':
+            tmp = self(batch)
+            return tmp.mean(axis=-1)
+
         return self(batch)
 
     def configure_optimizers(self):
@@ -357,6 +366,14 @@ class Base(pl.LightningModule):
         :meta private:
         """
 
+        if self.loss_type=='cprs':
+            return self.loss(y_hat,batch['y'])
+
+        if self.loss_type=='long_lag':
+            batch_size,width,n_variables = batch['y'].shape
+            tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
+            return tmp.mean()
+
         if self.use_quantiles is False:
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
         else:
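The new `long_lag` branch in the loss computation weights absolute errors linearly with lead time: the first step of the horizon gets weight 1 and the last gets `self.persistence_weight`. A minimal sketch of that weighting, with made-up tensor sizes (only the weighting expression itself comes from the diff; `persistence_weight = 3.0` and the shapes are illustrative):

```python
import torch

# Illustrative sizes only; the real values come from the batch and model config.
batch_size, width, n_variables = 4, 5, 2
persistence_weight = 3.0

y_hat = torch.randn(batch_size, width, n_variables, 1)  # point forecast in the last channel
y = torch.randn(batch_size, width, n_variables)

# Weight grows linearly over the forecast horizon: long-lag errors count more.
w = torch.linspace(1, persistence_weight, width).view(1, width, 1)
# w.squeeze() -> tensor([1.0000, 1.5000, 2.0000, 2.5000, 3.0000])

# Broadcasting over batch and variables is equivalent to the diff's explicit
# .repeat(batch_size, 1, n_variables).
loss = (torch.abs(y_hat[:, :, :, 0] - y) * w).mean()
```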
src/dsipts/models/base_v2.py (+25 -5)

@@ -12,7 +12,7 @@ import numpy as np
 from aim import Image
 import matplotlib.pyplot as plt
 from typing import List, Union
-from .utils import QuantileLossMO
+from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
 
@@ -137,10 +137,15 @@ class Base(pl.LightningModule):
         if n_classes==0:
             self.is_classification = False
             if len(self.quantiles)>0:
-
-
-
-
+                if self.loss_type=='cprs':
+                    self.use_quantiles = False
+                    self.mul = len(self.quantiles)
+                    self.loss = CPRS()
+                else:
+                    assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
+                    self.use_quantiles = True
+                    self.mul = len(self.quantiles)
+                    self.loss = QuantileLossMO(quantiles)
             else:
                 self.use_quantiles = False
                 self.mul = 1
@@ -189,6 +194,11 @@ class Base(pl.LightningModule):
         Returns:
             torch.tensor: result
         """
+
+        if self.loss_type=='cprs':
+            tmp = self(batch)
+            return tmp.mean(axis=-1)
+
         return self(batch)
 
     def configure_optimizers(self):
@@ -365,6 +375,16 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
+        if self.loss_type=='cprs':
+            return self.loss(y_hat,batch['y'])
+
+        if self.loss_type=='long_lag':
+
+            batch_size,width,n_variables = batch['y'].shape
+            tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
+            return tmp.mean()
+
+
 
         if self.use_quantiles is False:
             initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
src/dsipts/models/utils.py (+73 -1)

@@ -621,4 +621,76 @@ class Embedding_cat_variables(nn.Module):
             emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))
 
         cat_n_embd = torch.cat(emb,dim=2)
-        return cat_n_embd
+        return cat_n_embd
+
+
+
+class CPRS(nn.Module):
+    """
+    Efficient vectorized implementation of Almost Fair CRPS.
+
+    This version avoids explicit loops and uses broadcasting for better performance
+    with large ensembles.
+    """
+
+    def __init__(self, alpha=0.95, reduction='mean'):
+        super().__init__()
+        self.alpha = alpha
+        self.reduction = reduction
+
+    def forward(self, y_hat, target, weights=None):
+        """
+        Compute the almost fair CRPS loss (efficient version).
+
+        Args:
+            ensemble: Tensor of shape (batch_size, n_members, ...)
+            target: Tensor of shape (batch_size, ...)
+            weights: Optional per-variable or per-location weights
+
+        Returns:
+            Loss tensor
+        """
+        ## initial shape BS,width,n_variables,n_members need to go into batch_size, n_members, width, n_variables
+        ensemble = y_hat.permute(0,3,1,2)
+
+
+        batch_size, n_members = ensemble.shape[:2]
+        epsilon = (1 - self.alpha) / n_members
+
+        # Expand target to match ensemble shape
+        target_expanded = target.unsqueeze(1).expand_as(ensemble)
+
+        # Compute first term: mean absolute error to target
+        mae_term = torch.abs(ensemble - target_expanded).mean(dim=1)
+
+        # Compute second term: pairwise differences between ensemble members
+        # Use broadcasting to compute all pairwise differences efficiently
+        ensemble_i = ensemble.unsqueeze(2)  # (batch, n_members, 1, ...)
+        ensemble_j = ensemble.unsqueeze(1)  # (batch, 1, n_members, ...)
+
+        pairwise_diffs = torch.abs(ensemble_i - ensemble_j)
+
+        # Sum over all pairs (excluding diagonal)
+        # Create mask to exclude diagonal (i=j)
+        mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
+        mask = mask.view(1, n_members, n_members, *[1]*(len(ensemble.shape)-2))
+
+        # Apply mask and compute mean
+        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) / (n_members * (n_members - 1))
+
+        # Combine terms according to afCRPS formula
+        loss = mae_term - (1 - epsilon) * pairwise_term
+
+        # Apply weights if provided
+        if weights is not None:
+            loss = loss * weights
+
+        # Apply reduction
+        if self.reduction == 'none':
+            return loss
+        elif self.reduction == 'sum':
+            return loss.sum()
+        elif self.reduction == 'mean':
+            return loss.mean()
+        else:
+            raise ValueError(f"Invalid reduction: {self.reduction}")
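Reading the vectorized `forward` above: for $M$ ensemble members $x_1,\dots,x_M$ and target $y$, the per-element loss it computes is

$$\mathcal{L} \;=\; \frac{1}{M}\sum_{i=1}^{M}\lvert x_i - y\rvert \;-\; \frac{1-\varepsilon}{M(M-1)}\sum_{i\neq j}\lvert x_i - x_j\rvert, \qquad \varepsilon = \frac{1-\alpha}{M},$$

followed by the requested `reduction`. A minimal usage sketch, assuming the module path from the file list above and the `(BS, width, n_variables, n_members)` output layout noted in the class comment (the tensor sizes are made up):

```python
import torch
from dsipts.models.utils import CPRS  # class added in 1.1.9

# Made-up shapes: batch 8, horizon (width) 16, 2 target variables, 3 members
# (self.mul = len(self.quantiles) in the Base branches above).
y_hat = torch.randn(8, 16, 2, 3)  # (BS, width, n_variables, n_members)
y = torch.randn(8, 16, 2)         # (BS, width, n_variables)

loss_fn = CPRS(alpha=0.95, reduction='mean')
loss = loss_fn(y_hat, y)  # forward() permutes y_hat to (BS, n_members, width, n_variables)

# At inference the cprs branch collapses the ensemble to a point forecast,
# matching the new `tmp.mean(axis=-1)` in the prediction hunks above.
point_forecast = y_hat.mean(axis=-1)  # (BS, width, n_variables)
```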