datastock 0.0.33__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -58,37 +58,37 @@ def binning(
58
58
  store_keys=None,
59
59
  ):
60
60
  """ Return the binned data
61
-
61
+
62
62
  data: the data on which to apply binning, can be
63
63
  - a list of np.ndarray to be binned
64
64
  (any dimension as long as they all have the same)
65
65
  - a list of keys to ddata items sharing the same refs
66
-
66
+
67
67
  data_units: str only necessary if data is a list of arrays
68
-
68
+
69
69
  axis: int or array of int indices
70
70
  the axis of data along which to bin
71
71
  data will be flattened along all those axis priori to binning
72
- If None, assumes bin_data is not variable and uses all its axis
73
-
72
+ If None, assumes bin_data is not variable and uses all its axis
73
+
74
74
  bins0: the bins (centers), can be
75
75
  - a 1d vector of monotonous bins
76
76
  - a int, used to compute a bins vector from max(data), min(data)
77
-
77
+
78
78
  bin_data0: the data used to compute binning indices, can be:
79
79
  - a str, key to a ddata item
80
80
  - a np.ndarray
81
81
  _ a list of any of the above if each data has different size along axis
82
-
82
+
83
83
  bin_units: str
84
84
  only used if integrate = True and bin_data is a np.ndarray
85
-
85
+
86
86
  integrate: bool
87
87
  flag indicating whether binning is used for integration
88
88
  Implies that:
89
89
  Only usable for 1d binning (axis has to be a single index)
90
90
  data is multiplied by the underlying bin_data0 step prior to binning
91
-
91
+
92
92
  statistic: str
93
93
  the statistic kwd feed to scipy.stats.binned_statistic()
94
94
  automatically set to 'sum' if integrate = True
@@ -117,7 +117,7 @@ def binning(
117
117
 
118
118
  dout = {k0: {'units': v0['units']} for k0, v0 in ddata.items()}
119
119
  for k0, v0 in ddata.items():
120
-
120
+
121
121
  # handle dbins1
122
122
  if dbins1 is None:
123
123
  bins1, vect1, bin_ref1 = None, None, None
@@ -125,7 +125,7 @@ def binning(
125
125
  bins1 = dbins1['edges']
126
126
  vect1 = dbins1['data']
127
127
  bin_ref1 = dbins1[k0].get('bin_ref')
128
-
128
+
129
129
  # compute
130
130
  dout[k0]['data'], dout[k0]['ref'] = _bin_fixed_bin(
131
131
  # data to bin
@@ -146,7 +146,7 @@ def binning(
146
146
  # integration
147
147
  variable_data=dvariable['data'],
148
148
  )
149
-
149
+
150
150
  else:
151
151
  msg = (
152
152
  "Variable bin vectors not implemented yet!\n"
@@ -158,9 +158,9 @@ def binning(
158
158
 
159
159
  # --------------
160
160
  # storing
161
-
161
+
162
162
  if store is True:
163
-
163
+
164
164
  _store(
165
165
  coll=coll,
166
166
  dout=dout,
@@ -169,7 +169,7 @@ def binning(
169
169
 
170
170
  # -------------
171
171
  # return
172
-
172
+
173
173
  if returnas is True:
174
174
  return dout
175
175
 
@@ -208,14 +208,14 @@ def _check(
208
208
  # -----------------
209
209
  # store and verb
210
210
  # -------------------
211
-
211
+
212
212
  # verb
213
213
  verb = _generic_check._check_var(
214
214
  verb, 'verb',
215
215
  types=bool,
216
216
  default=True,
217
217
  )
218
-
218
+
219
219
  # ------------------
220
220
  # data: str vs array
221
221
  # -------------------
@@ -226,9 +226,9 @@ def _check(
226
226
  data_units=data_units,
227
227
  store=store,
228
228
  )
229
-
230
- ndim_data = list(ddata.values())[0]['data'].ndim
231
-
229
+
230
+ ndim_data = list(ddata.values())[0]['data'].ndim
231
+
232
232
  # -----------------
233
233
  # check statistic
234
234
  # -------------------
@@ -242,11 +242,11 @@ def _check(
242
242
  types=str,
243
243
  default='sum',
244
244
  )
245
-
245
+
246
246
  # -----------
247
247
  # bins
248
248
  # ------------
249
-
249
+
250
250
  dbins0 = _check_bins(
251
251
  coll=coll,
252
252
  lkdata=list(ddata.keys()),
@@ -262,7 +262,7 @@ def _check(
262
262
  dref_vector=dref_vector,
263
263
  store=store,
264
264
  )
265
-
265
+
266
266
  # -----------
267
267
  # bins
268
268
  # ------------
@@ -279,12 +279,12 @@ def _check(
279
279
  safety_ratio=safety_ratio,
280
280
  store=store,
281
281
  )
282
-
282
+
283
283
  # data vs axis
284
284
  if np.any(axis > ndim_data - 1):
285
285
  msg = f"axis too large\n{axis}"
286
286
  raise Exception(msg)
287
-
287
+
288
288
  variable_data = len(axis) < ndim_data
289
289
 
290
290
  # dbins1
@@ -300,11 +300,11 @@ def _check(
300
300
  safety_ratio=safety_ratio,
301
301
  store=store,
302
302
  )
303
-
303
+
304
304
  if variable_bin0 != variable_bin1:
305
305
  msg = "bin_data0 and bin_data1 have different shapes, todo"
306
306
  raise NotImplementedError(msg)
307
-
307
+
308
308
  else:
309
309
  dbins1 = None
310
310
  variable_bin1 = False
@@ -312,36 +312,36 @@ def _check(
312
312
  # -----------------
313
313
  # check integrate
314
314
  # -------------------
315
-
315
+
316
316
  # integrate
317
317
  integrate = _generic_check._check_var(
318
318
  integrate, 'integrate',
319
319
  types=bool,
320
320
  default=False,
321
321
  )
322
-
322
+
323
323
  # safety checks
324
324
  if integrate is True:
325
-
325
+
326
326
  if bin_data1 is not None:
327
327
  msg = (
328
328
  "If integrate = True, bin_data1 must be None!\n"
329
329
  "\t- bin_data1: {bin_data1}\n"
330
330
  )
331
331
  raise Exception(msg)
332
-
332
+
333
333
  if len(axis) > 1:
334
334
  msg = (
335
335
  "If integrate is true, binning can only be done on one axis!\n"
336
336
  f"\t- axis: {axis}\n"
337
337
  )
338
338
  raise Exception(msg)
339
-
339
+
340
340
 
341
341
  # -----------------------
342
342
  # additional safety check
343
343
 
344
- if integrate is True:
344
+ if integrate is True:
345
345
 
346
346
  if variable_bin0:
347
347
  axbin = axis[0]
@@ -349,12 +349,12 @@ def _check(
349
349
  axbin = 0
350
350
 
351
351
  for k0, v0 in ddata.items():
352
-
352
+
353
353
  ddata[k0]['units'] = v0['units'] * dbins0[k0]['units']
354
354
  if dbins0[k0]['data'].size == 0:
355
355
  continue
356
-
357
- dv = np.diff(dbins0[k0]['data'], axis=axbin)
356
+
357
+ dv = np.diff(dbins0[k0]['data'], axis=axbin)
358
358
  dv = np.concatenate(
359
359
  (np.take(dv, [0], axis=axbin), dv),
360
360
  axis=axbin,
@@ -362,12 +362,12 @@ def _check(
362
362
 
363
363
  # reshape
364
364
  if variable_data != variable_bin0:
365
-
365
+
366
366
  if variable_data:
367
367
  shape_dv = np.ones((ndim_data,), dtype=int)
368
368
  shape_dv[axis[0]] = -1
369
369
  dv = dv.reshape(tuple(shape_dv))
370
-
370
+
371
371
  if variable_bin0:
372
372
  raise NotImplementedError()
373
373
 
@@ -375,16 +375,16 @@ def _check(
375
375
 
376
376
  # --------
377
377
  # variability dict
378
-
378
+
379
379
  dvariable = {
380
380
  'data': variable_data,
381
381
  'bin0': variable_bin0,
382
382
  'bin1': variable_bin1,
383
383
  }
384
-
384
+
385
385
  # --------
386
386
  # returnas
387
-
387
+
388
388
  returnas = _generic_check._check_var(
389
389
  returnas, 'returnas',
390
390
  types=bool,
@@ -407,23 +407,23 @@ def _check_data(
407
407
  ):
408
408
  # -----------
409
409
  # store
410
-
410
+
411
411
  store = _generic_check._check_var(
412
412
  store, 'store',
413
413
  types=bool,
414
414
  default=False,
415
415
  )
416
-
416
+
417
417
  # ---------------------
418
418
  # make sure it's a list
419
-
419
+
420
420
  if isinstance(data, (np.ndarray, str)):
421
421
  data = [data]
422
422
  assert isinstance(data, list)
423
423
 
424
424
  # ------------------------------------------------
425
425
  # identify case: str vs array, all with same ndim
426
-
426
+
427
427
  lc = [
428
428
  all([
429
429
  isinstance(dd, str)
@@ -437,9 +437,9 @@ def _check_data(
437
437
  for dd in data
438
438
  ]),
439
439
  ]
440
-
440
+
441
441
  # vs store
442
- if store is True:
442
+ if store is True:
443
443
  if not lc[0]:
444
444
  msg = "If storing, all data, bin data and bins must be declared!"
445
445
  raise Exception(msg)
@@ -457,7 +457,7 @@ def _check_data(
457
457
 
458
458
  # --------------------
459
459
  # sort cases
460
-
460
+
461
461
  # str => keys to existing data
462
462
  if lc[0]:
463
463
  ddata = {
@@ -480,8 +480,8 @@ def _check_data(
480
480
  'units': data_units,
481
481
  }
482
482
  for ii in range(len(data))
483
- }
484
-
483
+ }
484
+
485
485
  return ddata
486
486
 
487
487
 
@@ -505,7 +505,7 @@ def _check_bins(
505
505
  lok_bins = list(coll.dobj.get(wb, {}).keys())
506
506
  else:
507
507
  lok_bins = []
508
-
508
+
509
509
  bins = _generic_check._check_var(
510
510
  bins, 'bins',
511
511
  types=str,
@@ -519,31 +519,31 @@ def _check_bins(
519
519
  unique=True,
520
520
  can_be_None=False,
521
521
  )
522
-
522
+
523
523
  # --------------
524
524
  # check vs store
525
-
525
+
526
526
  if store is True and not isinstance(bins, str):
527
527
  msg = "With store=True, bins must be keys to coll.dobj['bins'] items!"
528
528
  raise Exception(msg)
529
-
529
+
530
530
  # ----------------------------
531
531
  # compute bin edges if needed
532
-
532
+
533
533
  if isinstance(bins, str):
534
-
534
+
535
535
  if bins in lok_bins:
536
536
  for k0 in lkdata:
537
537
  dbins[k0]['bin_ref'] = coll.dobj[wb][bins]['ref']
538
538
  dbins[k0]['edges'] = coll.dobj[wb][bins]['edges']
539
-
539
+
540
540
  else:
541
-
541
+
542
542
  if bins in lok_ref:
543
-
543
+
544
544
  if dref_vector is None:
545
545
  dref_vector = {}
546
-
546
+
547
547
  bins = coll.get_ref_vector(
548
548
  ref=bins,
549
549
  **dref_vector,
@@ -551,7 +551,7 @@ def _check_bins(
551
551
  if bins is None:
552
552
  msg = "No ref vector identified!"
553
553
  raise Exception(msg)
554
-
554
+
555
555
  binc = coll.ddata[bins]['data']
556
556
  for k0 in lkdata:
557
557
  dbins[k0]['bin_ref'] = coll.ddata[bins]['ref']
@@ -559,19 +559,19 @@ def _check_bins(
559
559
  binc[0] - 0.5*(binc[1] - binc[0]),
560
560
  0.5*(binc[1:] + binc[:-1]),
561
561
  binc[-1] + 0.5*(binc[-1] - binc[-2]),
562
- ]
563
-
562
+ ]
563
+
564
564
  else:
565
-
565
+
566
566
  for k0 in lkdata:
567
567
  bin_edges = np.r_[
568
568
  bins[0] - 0.5*(bins[1] - bins[0]),
569
569
  0.5*(bins[1:] + bins[:-1]),
570
570
  bins[-1] + 0.5*(bins[-1] - bins[-2]),
571
571
  ]
572
-
572
+
573
573
  dbins[k0]['edges'] = bin_edges
574
-
574
+
575
575
  return dbins
576
576
 
577
577
 
@@ -616,7 +616,7 @@ def _check_bins_data(
616
616
  # make list
617
617
  if isinstance(bin_data, (str, np.ndarray)):
618
618
  bin_data = [bin_data for ii in range(len(ddata))]
619
-
619
+
620
620
  # check consistency
621
621
  if not (isinstance(bin_data, list) and len(bin_data) == len(ddata)):
622
622
  msg = (
@@ -629,13 +629,13 @@ def _check_bins_data(
629
629
  f"\t- len(bin_data) = {len(bin_data)}\n"
630
630
  )
631
631
  raise Exception(msg)
632
-
632
+
633
633
  # -------------
634
634
  # case sorting
635
-
635
+
636
636
  lok_ref = list(coll.dref.keys())
637
637
  lok_data = [k0 for k0, v0 in coll.ddata.items()]
638
-
638
+
639
639
  lok = lok_data + lok_ref
640
640
  lc = [
641
641
  all([isinstance(bb, str) and bb in lok for bb in bin_data]),
@@ -650,37 +650,37 @@ def _check_bins_data(
650
650
  f"Available:\n{sorted(lok)}"
651
651
  )
652
652
  raise Exception(msg)
653
-
653
+
654
654
  # --------------
655
655
  # check vs store
656
-
656
+
657
657
  if store is True and not lc[0]:
658
658
  msg = "With store=True, all bin_data must be keys to ddata or ref"
659
659
  raise Exception(msg)
660
-
660
+
661
661
  # case with all str
662
662
  if lc[0]:
663
-
663
+
664
664
  if dref_vector is None:
665
665
  dref_vector = {}
666
-
666
+
667
667
  # derive dbins
668
668
  for ii, k0 in enumerate(ddata.keys()):
669
-
669
+
670
670
  # if ref => identify vector
671
671
  if bin_data[ii] in lok_ref:
672
-
672
+
673
673
  key_vect = coll.get_ref_vector(
674
674
  ref=bin_data[ii],
675
675
  **dref_vector,
676
676
  )[3]
677
-
677
+
678
678
  if key_vect is None:
679
679
  msg = "bin_data '{bin_data[ii]}' has no reference vector!"
680
680
  raise Exception(msg)
681
-
681
+
682
682
  bin_data[ii] = key_vect
683
-
683
+
684
684
  # fill dict
685
685
  dbins[k0].update({
686
686
  'key': bin_data[ii],
@@ -700,7 +700,7 @@ def _check_bins_data(
700
700
 
701
701
  # -----------------------------------
702
702
  # check nb of dimensions consistency
703
-
703
+
704
704
  ldim = list(set([v0['data'].ndim for v0 in dbins.values()]))
705
705
  if len(ldim) > 1:
706
706
  msg = (
@@ -708,22 +708,22 @@ def _check_bins_data(
708
708
  f"Provided: {ldim}"
709
709
  )
710
710
  raise Exception(msg)
711
-
711
+
712
712
  # -------------------------
713
713
  # check dimensions vs axis
714
-
714
+
715
715
  # None => set to all bin (assuming variable_bin = False)
716
716
  if axis is None:
717
717
  for k0, v0 in dbins.items():
718
-
718
+
719
719
  if ddata[k0]['ref'] is not None and v0['ref'] is not None:
720
720
  seq_data = list(ddata[k0]['ref'])
721
721
  seq_bin = v0['ref']
722
-
722
+
723
723
  else:
724
724
  seq_data = list(ddata[k0]['data'].shape)
725
725
  seq_bin = v0['data'].shape
726
-
726
+
727
727
  # get start indices of subsequence seq_bin in sequence seq_data
728
728
  laxis0 = list(_generic_utils.KnuthMorrisPratt(seq_data, seq_bin))
729
729
  if len(laxis0) != 1:
@@ -734,17 +734,17 @@ def _check_bins_data(
734
734
  f"=> laxis0 = {laxis0}\n"
735
735
  )
736
736
  raise Exception(msg)
737
-
737
+
738
738
  axisi = laxis0[0] + np.arange(0, len(seq_bin))
739
739
  if axis is None:
740
740
  axis = axisi
741
741
  else:
742
742
  assert axis == axisi
743
-
743
+
744
744
  # --------------
745
745
  # axis
746
746
  # -------------------
747
-
747
+
748
748
  axis = _generic_check._check_flat1darray(
749
749
  axis, 'axis',
750
750
  dtype=int,
@@ -752,7 +752,7 @@ def _check_bins_data(
752
752
  can_be_None=False,
753
753
  sign='>=0',
754
754
  )
755
-
755
+
756
756
  if np.any(np.diff(axis) > 1):
757
757
  msg = f"axis must be adjacent indices!\n{axis}"
758
758
  raise Exception(msg)
@@ -767,7 +767,7 @@ def _check_bins_data(
767
767
  f"\t- bin_data: {bin_data}"
768
768
  )
769
769
  raise Exception(msg)
770
-
770
+
771
771
  variable_bin = ndim_bin > len(axis)
772
772
 
773
773
  # -------------------------------
@@ -776,10 +776,10 @@ def _check_bins_data(
776
776
  ndim_data = list(ddata.values())[0]['data'].ndim
777
777
  variable_data = len(axis) < ndim_data
778
778
  for k0, v0 in dbins.items():
779
-
779
+
780
780
  shape_data = ddata[k0]['data'].shape
781
781
  shape_bin = v0['data'].shape
782
-
782
+
783
783
  if variable_bin == variable_data and shape_data != v0['data'].shape:
784
784
  msg = (
785
785
  "variable_bin == variable_data => shapes should be the same!\n"
@@ -796,7 +796,7 @@ def _check_bins_data(
796
796
  sh_var, sh_fix = shape_data, shape_bin
797
797
  else:
798
798
  sh_fix, sh_var = shape_data, shape_bin
799
-
799
+
800
800
  shape_axis = [ss for ii, ss in enumerate(sh_var) if ii in axis]
801
801
  if sh_fix != tuple(shape_axis):
802
802
  msg = (
@@ -805,25 +805,25 @@ def _check_bins_data(
805
805
  f"\t- shape_bin: {shape_bin}\n"
806
806
  f"\t- axis: {axis}\n"
807
807
  )
808
- raise Exception(msg)
808
+ raise Exception(msg)
809
809
 
810
810
  # ----------------------------------------
811
811
  # safety check on bin sizes
812
812
  # ----------------------------------------
813
813
 
814
814
  if len(axis) == 1:
815
-
815
+
816
816
  for k0, v0 in dbins.items():
817
-
817
+
818
818
  if variable_bin:
819
819
  raise NotImplementedError()
820
820
  else:
821
821
  dv = np.abs(np.diff(v0['data']))
822
-
822
+
823
823
  dvmean = np.mean(dv) + np.std(dv)
824
-
824
+
825
825
  if strict is True:
826
-
826
+
827
827
  lim = safety_ratio * dvmean
828
828
  db = np.mean(np.diff(dbins[k0]['edges']))
829
829
  if db < lim:
@@ -871,15 +871,15 @@ def _bin_fixed_bin(
871
871
 
872
872
  # -------------
873
873
  # prepare shape
874
-
874
+
875
875
  shape_data = data.shape
876
876
  ind_other = np.arange(data.ndim)
877
877
  nomit = len(axis) - 1
878
878
  ind_other_flat = np.r_[ind_other[:axis[0]], ind_other[axis[-1]+1:] - nomit]
879
879
  ind_other = np.r_[ind_other[:axis[0]], ind_other[axis[-1]+1:]]
880
-
880
+
881
881
  shape_other = [ss for ii, ss in enumerate(shape_data) if ii not in axis]
882
-
882
+
883
883
  shape_val = list(shape_other)
884
884
  shape_val.insert(axis[0], int(bins0.size - 1))
885
885
  if bins1 is not None:
@@ -900,20 +900,30 @@ def _bin_fixed_bin(
900
900
  # data
901
901
  sli = [slice(None) for ii in shape_other]
902
902
  sli.insert(axis[0], indin)
903
-
903
+
904
904
  data = data[tuple(sli)]
905
905
 
906
+ # ---------------
907
+ # custom
908
+
909
+ if statistic == 'sum_smooth':
910
+ stat = 'mean'
911
+ else:
912
+ stat = statistic
913
+
906
914
  # ------------------
907
915
  # simple case
908
-
916
+
909
917
  if variable_data is False:
910
-
918
+
911
919
  if bins1 is None:
920
+
921
+ # compute
912
922
  val[...] = scpst.binned_statistic(
913
923
  vect0,
914
924
  data,
915
925
  bins=bins0,
916
- statistic=statistic,
926
+ statistic=stat,
917
927
  )[0]
918
928
 
919
929
  else:
@@ -922,14 +932,18 @@ def _bin_fixed_bin(
922
932
  vect1,
923
933
  data,
924
934
  bins=[bins0, bins1],
925
- statistic=statistic,
935
+ statistic=stat,
926
936
  )[0]
927
-
937
+
928
938
  # -------------------------------------------------------
929
939
  # variable data, but axis = int and ufunc exists (faster)
930
-
931
- elif len(axis) == 1 and statistic in _DUFUNC.keys() and bins1 is None:
932
-
940
+
941
+ elif len(axis) == 1 and stat in _DUFUNC.keys() and bins1 is None:
942
+
943
+ if statistic == 'sum_smooth':
944
+ msg = "statistic 'sum_smooth' not properly handled here yet"
945
+ raise NotImplementedError(msg)
946
+
933
947
  # safety check
934
948
  vect0s = np.sort(vect0)
935
949
  if not np.allclose(vect0s, vect0):
@@ -944,9 +958,9 @@ def _bin_fixed_bin(
944
958
  f"\t- vect0s: {vect0s}\n"
945
959
  )
946
960
  raise Exception(msg)
947
-
961
+
948
962
  # get ufunc
949
- ufunc = _DUFUNC[statistic]
963
+ ufunc = _DUFUNC[stat]
950
964
 
951
965
  # get indices
952
966
  ind0 = np.searchsorted(
@@ -962,7 +976,7 @@ def _bin_fixed_bin(
962
976
  # cases
963
977
  if indu.size == 1:
964
978
  sli[axis[0]] = indu[0]
965
- val[sli] = np.nansum(data, axis=axis[0])
979
+ val[tuple(sli)] = np.nansum(data, axis=axis[0])
966
980
 
967
981
  else:
968
982
 
@@ -974,68 +988,85 @@ def _bin_fixed_bin(
974
988
 
975
989
  # sum
976
990
  val[tuple(sli)] = ufunc(data, ind, axis=axis[0])
977
-
991
+
978
992
  # -----------------------------------
979
993
  # other statistic with variable data
980
-
994
+
981
995
  else:
982
-
996
+
983
997
  # indices
984
998
  linds = [range(nn) for nn in shape_other]
985
-
999
+
986
1000
  # slice_data
987
1001
  sli = [0 for ii in shape_other]
988
1002
  sli.insert(axis[0], slice(None))
989
1003
  sli = np.array(sli)
990
-
1004
+
991
1005
  if bins1 is None:
992
-
1006
+
993
1007
  for ind in itt.product(linds):
994
1008
  sli[ind_other_flat] = ind
995
-
1009
+
996
1010
  val[tuple(sli)] = scpst.binned_statistic(
997
1011
  vect0,
998
1012
  data[tuple(sli)],
999
1013
  bins=bins0,
1000
- statistic=statistic,
1014
+ statistic=stat,
1001
1015
  )[0]
1002
-
1016
+
1017
+ if statistic == 'sum_smooth':
1018
+ val[tuple(sli)] *= (
1019
+ np.nansum(data[tuple(sli)]) / np.nansum(val[tuple(sli)])
1020
+ )
1021
+
1003
1022
  else:
1004
-
1023
+
1005
1024
  sli_val = np.copy(sli)
1006
1025
  sli_val = np.insert(axis[0] + 1, slice(None))
1007
1026
 
1008
1027
  for ind in itt.product(linds):
1009
-
1028
+
1010
1029
  sli[ind_other_flat] = ind
1011
1030
  sli_val[ind_other_flat] = ind
1012
-
1031
+
1013
1032
  val[tuple(sli_val)] = scpst.binned_statistic_2d(
1014
1033
  vect0,
1015
1034
  vect1,
1016
1035
  data[tuple(sli)],
1017
1036
  bins=[bins0, bins1],
1018
- statistic=statistic,
1037
+ statistic=stat,
1019
1038
  )[0]
1020
-
1039
+
1040
+ if statistic == 'sum_smooth':
1041
+ val[tuple(sli_val)] *= (
1042
+ np.nansum(data[tuple(sli)]) / np.nansum(val[tuple(sli_val)])
1043
+ )
1044
+
1045
+ # ---------------
1046
+ # adjust custom
1047
+
1048
+ if statistic == 'sum_smooth':
1049
+ if variable_data is False:
1050
+ val[...] *= np.nansum(data) / np.nansum(val)
1051
+
1021
1052
  # ------------
1022
1053
  # references
1023
-
1054
+
1024
1055
  if data_ref is not None:
1025
1056
  ref = [
1026
1057
  rr for ii, rr in enumerate(data_ref)
1027
1058
  if ii not in axis
1028
1059
  ]
1029
-
1060
+
1030
1061
  if bin_ref0 is not None:
1031
1062
  bin_ref0 = bin_ref0[0]
1032
1063
  if bin_ref1 is not None:
1033
1064
  bin_ref1 = bin_ref1[0]
1034
-
1065
+
1035
1066
  ref.insert(axis[0], bin_ref0)
1036
1067
  if bins1 is not None:
1037
1068
  ref.insert(axis[0] + 1, bin_ref1)
1038
-
1069
+
1039
1070
  ref = tuple(ref)
1040
1071
  else:
1041
1072
  ref = None
@@ -1056,10 +1087,10 @@ def _store(
1056
1087
 
1057
1088
  # ----------------
1058
1089
  # check store_keys
1059
-
1090
+
1060
1091
  if len(dout) == 1 and isinstance(store_keys, str):
1061
1092
  store_keys = [store_keys]
1062
-
1093
+
1063
1094
  ldef = [f"{k0}_binned" for k0 in dout.items()]
1064
1095
  lex = list(coll.ddata.keys())
1065
1096
  store_keys = _generic_check._check_var_iter(
@@ -1069,18 +1100,14 @@ def _store(
1069
1100
  default=ldef,
1070
1101
  excluded=lex,
1071
1102
  )
1072
-
1103
+
1073
1104
  # -------------
1074
1105
  # store
1075
-
1106
+
1076
1107
  for ii, (k0, v0) in enumerate(dout.items()):
1077
1108
  coll.add_data(
1078
1109
  key=store_keys[ii],
1079
1110
  data=v0['data'],
1080
1111
  ref=v0['ref'],
1081
1112
  units=v0['units'],
1082
- )
1083
-
1084
-
1085
-
1086
-
1113
+ )