dsipts 1.1.9__py3-none-any.whl → 1.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dsipts might be problematic.
dsipts/models/base.py CHANGED
@@ -136,9 +136,9 @@ class Base(pl.LightningModule):
         self.is_classification = False
         if len(self.quantiles)>0:
             if self.loss_type=='cprs':
-                self.use_quantiles = False
+                self.use_quantiles = True
                 self.mul = len(self.quantiles)
-                self.loss = CPRS()
+                self.loss = CPRS(alpha=self.persistence_weight)
             else:
                 assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
                 self.use_quantiles = True
@@ -193,7 +193,9 @@ class Base(pl.LightningModule):
         """
         if self.loss_type=='cprs':
             tmp = self(batch)
-            return tmp.mean(axis=-1)
+            tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
+            return tmp
+            #return tmp.mean(axis=-1).unsqueeze(-1)

         return self(batch)

dsipts/models/base_v2.py CHANGED
@@ -138,9 +138,9 @@ class Base(pl.LightningModule):
         self.is_classification = False
         if len(self.quantiles)>0:
             if self.loss_type=='cprs':
-                self.use_quantiles = False
+                self.use_quantiles = True
                 self.mul = len(self.quantiles)
-                self.loss = CPRS()
+                self.loss = CPRS(alpha=self.persistence_weight)
             else:
                 assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
                 self.use_quantiles = True
@@ -197,7 +197,9 @@ class Base(pl.LightningModule):

         if self.loss_type=='cprs':
             tmp = self(batch)
-            return tmp.mean(axis=-1)
+            tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
+            return tmp
+            #return tmp.mean(axis=-1).unsqueeze(-1)

         return self(batch)

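Note on the base.py and base_v2.py change: when the CRPS loss is active the model output is an ensemble over the last dimension, and prediction now returns the 5%, 50% and 95% quantiles of that ensemble instead of its mean (the CPRS loss is also now constructed with alpha=self.persistence_weight). A minimal sketch of the equivalent post-processing, assuming a hypothetical (batch, horizon, channels, n_members) output layout consistent with the .permute(1,2,3,0) call in the diff:

import torch

# Hypothetical ensemble output: (batch, horizon, channels, n_members).
pred = torch.randn(8, 24, 3, 100)

# Old behaviour: collapse the ensemble to its mean -> (8, 24, 3).
point_forecast = pred.mean(dim=-1)

# New behaviour: 5%, 50%, 95% quantiles over the ensemble axis.
# torch.quantile prepends the quantile dimension, so permute restores
# (batch, horizon, channels, n_quantiles).
q = torch.tensor([0.05, 0.5, 0.95])
quantile_forecast = torch.quantile(pred, q, dim=-1).permute(1, 2, 3, 0)
print(quantile_forecast.shape)  # torch.Size([8, 24, 3, 3])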
dsipts/models/utils.py CHANGED
@@ -633,7 +633,7 @@ class CPRS(nn.Module):
     with large ensembles.
     """

-    def __init__(self, alpha=0.95, reduction='mean'):
+    def __init__(self, alpha=0.5, reduction='mean'):
         super().__init__()
         self.alpha = alpha
         self.reduction = reduction
@@ -674,17 +674,19 @@ class CPRS(nn.Module):
         # Create mask to exclude diagonal (i=j)
         mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
         mask = mask.view(1, n_members, n_members, *[1]*(len(ensemble.shape)-2))
-
+
         # Apply mask and compute mean
-        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) / (n_members * (n_members - 1))
+        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) ##formula 3 second term

         # Combine terms according to afCRPS formula
-        loss = mae_term - (1 - epsilon) * pairwise_term
+        loss = mae_term - (1 - epsilon) * pairwise_term/ (2*n_members * (n_members - 1))

         # Apply weights if provided
         if weights is not None:
             loss = loss * weights
-
+        #if loss.mean()<-2:
+        #    import pdb
+        #    pdb.set_trace()
         # Apply reduction
         if self.reduction == 'none':
             return loss
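Note on the utils.py change: the default alpha of CPRS drops from 0.95 to 0.5, and the pairwise (spread) term of the almost-fair CRPS is no longer averaged where it is computed; the normalisation, now with an extra factor of 2, is applied when the two terms are combined, i.e. roughly loss = mean_i |y - x_i| - (1 - epsilon) * sum over i != j of |x_i - x_j| / (2 * M * (M - 1)) for M ensemble members. A standalone sketch of that computation under assumed shapes (ensemble as (batch, n_members, ...), target as (batch, ...), and epsilon passed in directly, whereas the real class presumably derives it from alpha):

import torch

def afcrps_sketch(ensemble: torch.Tensor, target: torch.Tensor, epsilon: float = 0.05) -> torch.Tensor:
    n_members = ensemble.shape[1]

    # First term: mean absolute error of each member against the target.
    mae_term = (ensemble - target.unsqueeze(1)).abs().mean(dim=1)

    # Second term: pairwise absolute differences between members,
    # with the diagonal (i == j) masked out, as in the diff.
    pairwise_diffs = (ensemble.unsqueeze(1) - ensemble.unsqueeze(2)).abs()
    mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
    mask = mask.view(1, n_members, n_members, *[1] * (ensemble.dim() - 2))
    pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2))

    # New scaling: divide by 2 * M * (M - 1) only when combining the terms.
    return mae_term - (1 - epsilon) * pairwise_term / (2 * n_members * (n_members - 1))

loss = afcrps_sketch(torch.randn(4, 16, 24), torch.randn(4, 24))
print(loss.shape)  # torch.Size([4, 24])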
dsipts-1.1.9.dist-info/METADATA → dsipts-1.1.10.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.9
+Version: 1.1.10
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
dsipts-1.1.9.dist-info/RECORD → dsipts-1.1.10.dist-info/RECORD RENAMED
@@ -28,9 +28,9 @@ dsipts/models/TimeXER.py,sha256=aCg0003LxYZzqZWyWugpbW_iOybcdHN4OH6_v77qp4o,7056
 dsipts/models/VQVAEA.py,sha256=sNJi8UZh-10qEIKcZK3SzhlOFUUjvqjoglzeZBFaeZM,13789
 dsipts/models/VVA.py,sha256=BnPkJ0Nzue0oShSHZVRNlf5RvT0Iwtf9bx19vLB9Nn0,11939
 dsipts/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dsipts/models/base.py,sha256=-K6ZxmXism231GqBxM3-pXE_KA4a4QWuYJ6FM_uSRl4,18859
-dsipts/models/base_v2.py,sha256=39EJO3m00HvT3zkn8PO67YEckAVa3Ez3NQ5oEnwz9g8,19137
-dsipts/models/utils.py,sha256=eBEpczdHn--ftK9I0pOiSY4ANGLzkw1WIL3SOoV9y7Y,24412
+dsipts/models/base.py,sha256=0r_gGD9CPAVmuqTmySugTpCVUgoHJrwaMAqLx3P-ZBw,19021
+dsipts/models/base_v2.py,sha256=b_RaVTBnA2dU4HpVPI-P0_VkmbsQHtYzxVf5iFVvp1U,19299
+dsipts/models/utils.py,sha256=kjTwyktNCFMpPUy6zoleBCSKlvMvK_Jkgyh2T1OXg3E,24497
 dsipts/models/autoformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/autoformer/layers.py,sha256=xHt8V1lKdD1cIvgxXdDbI_EqOz4zgOQ6LP8l7M1pAxM,13276
 dsipts/models/crossformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -76,7 +76,7 @@ dsipts/models/vva/minigpt.py,sha256=bg0JddqSD322uxSGexen3nPXL_hGTsk3vNLR62d7-w8,
 dsipts/models/vva/vqvae.py,sha256=RzCQ_M9xBprp7_x20dSV3EQqlO0FjPUGWV-qdyKrQsM,19680
 dsipts/models/xlstm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dsipts/models/xlstm/xLSTM.py,sha256=ZKZZmffmIq1Vb71CR4GSyM8viqVx-u0FChxhcNgHub8,10081
-dsipts-1.1.9.dist-info/METADATA,sha256=vraJDpYWc4hhcOfaj3C4E5hACrTNlYSEgGsT2zKyiPs,24794
-dsipts-1.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dsipts-1.1.9.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
-dsipts-1.1.9.dist-info/RECORD,,
+dsipts-1.1.10.dist-info/METADATA,sha256=hwFJB926XiPZjhisLz-Usqpic_ty16lk3ZwvHoZHC0c,24795
+dsipts-1.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dsipts-1.1.10.dist-info/top_level.txt,sha256=i6o0rf5ScFwZK21E89dSKjVNjUBkrEQpn0-Vij43748,7
+dsipts-1.1.10.dist-info/RECORD,,