dsipts 1.1.8__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dsipts/models/base.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
12
  from aim import Image
13
13
  import matplotlib.pyplot as plt
14
14
  from typing import List, Union
15
- from .utils import QuantileLossMO
15
+ from .utils import QuantileLossMO, CPRS
16
16
  import torch.nn as nn
17
17
 
18
18
  def standardize_momentum(x,order):
@@ -135,10 +135,15 @@ class Base(pl.LightningModule):
135
135
  if n_classes==0:
136
136
  self.is_classification = False
137
137
  if len(self.quantiles)>0:
138
- assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
139
- self.use_quantiles = True
140
- self.mul = len(self.quantiles)
141
- self.loss = QuantileLossMO(quantiles)
138
+ if self.loss_type=='cprs':
139
+ self.use_quantiles = False
140
+ self.mul = len(self.quantiles)
141
+ self.loss = CPRS()
142
+ else:
143
+ assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
144
+ self.use_quantiles = True
145
+ self.mul = len(self.quantiles)
146
+ self.loss = QuantileLossMO(quantiles)
142
147
  else:
143
148
  self.use_quantiles = False
144
149
  self.mul = 1
@@ -186,6 +191,10 @@ class Base(pl.LightningModule):
186
191
  Returns:
187
192
  torch.tensor: result
188
193
  """
194
+ if self.loss_type=='cprs':
195
+ tmp = self(batch)
196
+ return tmp.mean(axis=-1)
197
+
189
198
  return self(batch)
190
199
 
191
200
  def configure_optimizers(self):
@@ -357,6 +366,14 @@ class Base(pl.LightningModule):
357
366
  :meta private:
358
367
  """
359
368
 
369
+ if self.loss_type=='cprs':
370
+ return self.loss(y_hat,batch['y'])
371
+
372
+ if self.loss_type=='long_lag':
373
+ batch_size,width,n_variables = batch['y'].shape
374
+ tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
375
+ return tmp.mean()
376
+
360
377
  if self.use_quantiles is False:
361
378
  initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
362
379
  else:
dsipts/models/base_v2.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
12
  from aim import Image
13
13
  import matplotlib.pyplot as plt
14
14
  from typing import List, Union
15
- from .utils import QuantileLossMO
15
+ from .utils import QuantileLossMO, CPRS
16
16
  import torch.nn as nn
17
17
 
18
18
 
@@ -137,10 +137,15 @@ class Base(pl.LightningModule):
137
137
  if n_classes==0:
138
138
  self.is_classification = False
139
139
  if len(self.quantiles)>0:
140
- assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
141
- self.use_quantiles = True
142
- self.mul = len(self.quantiles)
143
- self.loss = QuantileLossMO(quantiles)
140
+ if self.loss_type=='cprs':
141
+ self.use_quantiles = False
142
+ self.mul = len(self.quantiles)
143
+ self.loss = CPRS()
144
+ else:
145
+ assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
146
+ self.use_quantiles = True
147
+ self.mul = len(self.quantiles)
148
+ self.loss = QuantileLossMO(quantiles)
144
149
  else:
145
150
  self.use_quantiles = False
146
151
  self.mul = 1
@@ -189,6 +194,11 @@ class Base(pl.LightningModule):
189
194
  Returns:
190
195
  torch.tensor: result
191
196
  """
197
+
198
+ if self.loss_type=='cprs':
199
+ tmp = self(batch)
200
+ return tmp.mean(axis=-1)
201
+
192
202
  return self(batch)
193
203
 
194
204
  def configure_optimizers(self):
@@ -365,6 +375,16 @@ class Base(pl.LightningModule):
365
375
 
366
376
  :meta private:
367
377
  """
378
+ if self.loss_type=='cprs':
379
+ return self.loss(y_hat,batch['y'])
380
+
381
+ if self.loss_type=='long_lag':
382
+
383
+ batch_size,width,n_variables = batch['y'].shape
384
+ tmp = torch.abs(y_hat[:,:,:,0]-batch['y'])*torch.linspace(1,self.persistence_weight,width).view(1,width,1).repeat(batch_size,1,n_variables)
385
+ return tmp.mean()
386
+
387
+
368
388
 
369
389
  if self.use_quantiles is False:
370
390
  initial_loss = self.loss(y_hat[:,:,:,0], batch['y'])
dsipts/models/utils.py CHANGED
@@ -621,4 +621,76 @@ class Embedding_cat_variables(nn.Module):
621
621
  emb.append(layer(cat_vars[:, :, index]).unsqueeze(2))
622
622
 
623
623
  cat_n_embd = torch.cat(emb,dim=2)
624
- return cat_n_embd
624
+ return cat_n_embd
625
+
626
+
627
+
628
class CPRS(nn.Module):
    """Almost fair CRPS (afCRPS) loss for ensemble forecasts.

    For an ensemble of ``M`` members ``x_1..x_M`` and an observation ``y`` the
    loss computed per element is::

        afCRPS = 1/M * sum_j |x_j - y|
                 - (1 - eps) / (2 * M * (M - 1)) * sum_{j != k} |x_j - x_k|

    with ``eps = (1 - alpha) / M``. ``alpha = 1`` gives the fair CRPS.

    The implementation is fully vectorized (no Python loops); note that it
    materializes a tensor with ``n_members ** 2`` entries per output element.

    Args:
        alpha: Blending factor used to build ``eps`` (see formula above).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. The value is checked
            when ``forward`` is called.
    """

    def __init__(self, alpha: float = 0.95, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, y_hat, target, weights=None):
        """Compute the almost fair CRPS loss.

        Args:
            y_hat: Ensemble prediction of shape
                ``(batch_size, width, n_variables, n_members)``.
            target: Ground truth of shape ``(batch_size, width, n_variables)``.
            weights: Optional tensor multiplied element-wise with the
                un-reduced loss; it must broadcast against
                ``(batch_size, width, n_variables)``.

        Returns:
            The loss tensor: a scalar for ``'mean'``/``'sum'``, or a tensor of
            shape ``(batch_size, width, n_variables)`` for ``'none'``.

        Raises:
            ValueError: If ``self.reduction`` is not one of the supported values.
        """
        # Move the member axis next to the batch axis:
        # (batch, width, n_variables, n_members) -> (batch, n_members, width, n_variables)
        ensemble = y_hat.permute(0, 3, 1, 2)
        n_members = ensemble.shape[1]

        # First term: mean absolute error of the members w.r.t. the target.
        target_expanded = target.unsqueeze(1).expand_as(ensemble)
        mae_term = torch.abs(ensemble - target_expanded).mean(dim=1)

        if n_members > 1:
            epsilon = (1 - self.alpha) / n_members

            # Second term: all pairwise |x_j - x_k| via broadcasting,
            # (batch, M, 1, ...) - (batch, 1, M, ...). The diagonal (j == k)
            # is identically zero, so no explicit mask is required.
            pairwise_sum = torch.abs(
                ensemble.unsqueeze(2) - ensemble.unsqueeze(1)
            ).sum(dim=(1, 2))

            # The sum runs over ordered pairs, so every unordered pair is
            # counted twice: hence the factor 2 in the denominator.
            spread_term = pairwise_sum / (2 * n_members * (n_members - 1))

            loss = mae_term - (1 - epsilon) * spread_term
        else:
            # A single member has no spread; the score reduces to the MAE
            # (and we avoid a 0/0 division that would yield NaN).
            loss = mae_term

        if weights is not None:
            loss = loss * weights

        if self.reduction == 'none':
            return loss
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'mean':
            return loss.mean()
        else:
            raise ValueError(f"Invalid reduction: {self.reduction}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dsipts
3
- Version: 1.1.8
3
+ Version: 1.1.9
4
4
  Summary: Unified library for timeseries modelling
5
5
  Author-email: Andrea Gobbi <agobbi@fbk.eu>
6
6
  Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
@@ -28,9 +28,9 @@ dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
28
28
  dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
29
29
  dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
30
30
  dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
- dsipts/models/base.py,sha256=mIsEUkuyj_2MlYEvH97PPD790DrS0PQw4UCiWN8uqKI,18159
32
- dsipts/models/base_v2.py,sha256=jjlX5fIw2stCx5J3i3xFTgzYmCX-n8Lf4-4cLoq-diQ,18426
33
- dsipts/models/utils.py,sha256=H1lr1lukDk7FNyXXTJh217tyTBsBW8hVDQ6jL9oev7I,21765
31
+ dsipts/models/base.py,sha256=-K6ZxmXism231GqBxM3-pXE_KA4a4QWuYJ6FM_uSRl4,18859
32
+ dsipts/models/base_v2.py,sha256=39EJO3m00HvT3zkn8PO67YEckAVa3Ez3NQ5oEnwz9g8,19137
33
+ dsipts/models/utils.py,sha256=eBEpczdHn--ftK9I0pOiSY4ANGLzkw1WIL3SOoV9y7Y,24412
34
34
  dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
35
  dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
36
36
  dsipts/models/crossformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
76
76
  dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
77
77
  dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
78
78
  dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
79
- dsipts-1.1.8.dist-info/METADATA,sha256=fObwUSnqEBaCA_sDxvmOnfKsmb-Mu9gOrITzl3Tp4qQ,24794
80
- dsipts-1.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
81
- dsipts-1.1.8.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
82
- dsipts-1.1.8.dist-info/RECORD,,
79
+ dsipts-1.1.9.dist-info/METADATA,sha256=vraJDpYWc4hhcOfaj3C4E5hACrTNlYSEgGsT2zKyiPs,24794
80
+ dsipts-1.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
81
+ dsipts-1.1.9.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
82
+ dsipts-1.1.9.dist-info/RECORD,,
File without changes