mrio-toolbox 1.0.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mrio-toolbox might be problematic. Click here for more details.

Files changed (67)
  1. __init__.py +21 -0
  2. {mrio_toolbox/_parts → _parts}/_Axe.py +95 -37
  3. {mrio_toolbox/_parts → _parts}/_Part.py +346 -111
  4. _parts/__init__.py +7 -0
  5. {mrio_toolbox/_parts → _parts}/part_operations.py +24 -17
  6. extractors/__init__.py +20 -0
  7. extractors/downloaders.py +36 -0
  8. extractors/emerging/__init__.py +3 -0
  9. extractors/emerging/emerging_extractor.py +117 -0
  10. extractors/eora/__init__.py +3 -0
  11. extractors/eora/eora_extractor.py +132 -0
  12. extractors/exiobase/__init__.py +3 -0
  13. extractors/exiobase/exiobase_extractor.py +270 -0
  14. extractors/extractors.py +81 -0
  15. extractors/figaro/__init__.py +3 -0
  16. extractors/figaro/figaro_downloader.py +280 -0
  17. extractors/figaro/figaro_extractor.py +187 -0
  18. extractors/gloria/__init__.py +3 -0
  19. extractors/gloria/gloria_extractor.py +202 -0
  20. extractors/gtap11/__init__.py +7 -0
  21. extractors/gtap11/extraction/__init__.py +3 -0
  22. extractors/gtap11/extraction/extractor.py +129 -0
  23. extractors/gtap11/extraction/harpy_files/__init__.py +6 -0
  24. extractors/gtap11/extraction/harpy_files/_header_sets.py +279 -0
  25. extractors/gtap11/extraction/harpy_files/har_file.py +262 -0
  26. extractors/gtap11/extraction/harpy_files/har_file_io.py +974 -0
  27. extractors/gtap11/extraction/harpy_files/header_array.py +300 -0
  28. extractors/gtap11/extraction/harpy_files/sl4.py +229 -0
  29. extractors/gtap11/gtap_mrio/__init__.py +6 -0
  30. extractors/gtap11/gtap_mrio/mrio_builder.py +158 -0
  31. extractors/icio/__init__.py +3 -0
  32. extractors/icio/icio_extractor.py +121 -0
  33. extractors/wiod/__init__.py +3 -0
  34. extractors/wiod/wiod_extractor.py +143 -0
  35. mrio_toolbox/mrio.py → mrio.py +254 -94
  36. {mrio_toolbox-1.0.0.dist-info → mrio_toolbox-1.1.2.dist-info}/METADATA +11 -7
  37. mrio_toolbox-1.1.2.dist-info/RECORD +59 -0
  38. {mrio_toolbox-1.0.0.dist-info → mrio_toolbox-1.1.2.dist-info}/WHEEL +1 -1
  39. mrio_toolbox-1.1.2.dist-info/top_level.txt +6 -0
  40. msm/__init__.py +6 -0
  41. msm/multi_scale_mapping.py +863 -0
  42. utils/__init__.py +3 -0
  43. utils/converters/__init__.py +5 -0
  44. {mrio_toolbox/utils → utils}/converters/pandas.py +5 -6
  45. {mrio_toolbox/utils → utils}/converters/xarray.py +6 -15
  46. utils/formatting/formatter.py +527 -0
  47. utils/loaders/__init__.py +7 -0
  48. {mrio_toolbox/utils → utils}/loaders/_loader.py +60 -4
  49. {mrio_toolbox/utils → utils}/loaders/_loader_factory.py +22 -1
  50. {mrio_toolbox/utils → utils}/loaders/_nc_loader.py +37 -1
  51. {mrio_toolbox/utils → utils}/loaders/_pandas_loader.py +29 -3
  52. {mrio_toolbox/utils → utils}/loaders/_parameter_loader.py +61 -16
  53. {mrio_toolbox/utils → utils}/savers/__init__.py +3 -0
  54. utils/savers/_path_checker.py +37 -0
  55. {mrio_toolbox/utils → utils}/savers/_to_folder.py +6 -1
  56. utils/savers/_to_nc.py +60 -0
  57. mrio_toolbox/__init__.py +0 -5
  58. mrio_toolbox/_parts/__init__.py +0 -3
  59. mrio_toolbox/utils/converters/__init__.py +0 -2
  60. mrio_toolbox/utils/loaders/__init__.py +0 -3
  61. mrio_toolbox/utils/savers/_path_checker.py +0 -19
  62. mrio_toolbox/utils/savers/_to_nc.py +0 -52
  63. mrio_toolbox-1.0.0.dist-info/RECORD +0 -26
  64. mrio_toolbox-1.0.0.dist-info/top_level.txt +0 -1
  65. {mrio_toolbox-1.0.0.dist-info → mrio_toolbox-1.1.2.dist-info/licenses}/LICENSE +0 -0
  66. {mrio_toolbox/utils → utils/formatting}/__init__.py +0 -0
  67. {mrio_toolbox/utils → utils}/loaders/_np_loader.py +0 -0
@@ -10,6 +10,7 @@ import itertools
10
10
  import numpy as np
11
11
  import pandas as pd
12
12
  import xarray as xr
13
+ import copy
13
14
  from mrio_toolbox._parts._Axe import Axe
14
15
  import logging
15
16
  from mrio_toolbox.utils import converters
@@ -26,46 +27,151 @@ def load_part(
26
27
  return Part(**loader.load_part(**kwargs))
27
28
 
28
29
  class Part:
30
+ """
31
+ Representation of an MRIO Part object.
32
+
33
+ MRIO Parts are the basic building blocks of the MRIO toolbox. A Part is
34
+ built from a numpy array and a set of Axes, corresponding to the dimensions
35
+ of the array. The Axes hold the labels of the Part in the different
36
+ dimensions and are used to perform advanced indexing and operations on the Part.
37
+
38
+ Axes support multi-level indexing and groupings.
39
+
40
+ Instance variables
41
+ ------------------
42
+ data : numpy.ndarray
43
+ Numerical data of the Part.
44
+ axes : list of Axe instances
45
+ Axes corresponding to the dimensions of the Part.
46
+ groupings : dict
47
+ Groupings of the labels of the Part, for each label defined.
48
+ metadata : dict
49
+ Additional metadata of the Part (e.g., path, name, multiplier, unit).
50
+ name : str
51
+ Name of the Part.
52
+ ndim : int
53
+ Number of dimensions of the Part.
54
+ shape : tuple
55
+ Shape of the Part.
56
+
57
+ Methods
58
+ -------
59
+ __init__(data=None, labels=None, axes=None, **kwargs):
60
+ Initialize a Part object.
61
+ alias(**kwargs):
62
+ Create a new Part with modified parameters.
63
+ fix_dims(skip_labels=False, skip_data=False):
64
+ Align the number of axes with the number of dimensions.
65
+ get(*args, aspart=True, squeeze=False):
66
+ Extract data from the Part object.
67
+ setter(value, *args):
68
+ Change the value of a data selection.
69
+ develop(axis=None, on=None, squeeze=True):
70
+ Reshape a Part to avoid double labels.
71
+ reformat(new_dimensions):
72
+ Reshape a Part to match a new dimensions combination.
73
+ combine_axes(start=0, end=None, in_place=False):
74
+ Combine axes of a Part into a single one.
75
+ swap_axes(axis1, axis2):
76
+ Swap two axes of a Part.
77
+ swap_ax_levels(axis, dim1, dim2):
78
+ Swap two levels of an axis.
79
+ flatten(invert=False):
80
+ Flatten a 2D Part into a 1D Part.
81
+ squeeze():
82
+ Remove dimensions of length 1 from the Part.
83
+ expand_dims(axis, copy=None):
84
+ Add dimensions to a Part instance.
85
+ copy():
86
+ Return a copy of the current Part object.
87
+ extraction(dimensions, labels=["all"], on_groupings=True, domestic_only=False, axis="all"):
88
+ Set labels over dimension(s) to 0.
89
+ leontief_inversion():
90
+ Compute the Leontief inverse of a square Part.
91
+ update_groupings(groupings, ax=None):
92
+ Update the groupings of the current Part object.
93
+ aggregate(on="countries", axis=None):
94
+ Aggregate dimensions along one or several axes.
95
+ aggregate_on(on, axis):
96
+ Aggregate a Part along a given axis.
97
+ get_labels(axis=None):
98
+ Returns a list with the labels of the axes as dictionaries.
99
+ list_labels():
100
+ List the labels of the Part.
101
+ get_dimensions(axis=None):
102
+ Return the list of dimensions of the Part.
103
+ rename_labels(old, new):
104
+ Rename some labels of the Part.
105
+ replace_labels(name, labels, axis=None):
106
+ Update a label of the Part.
107
+ set_labels(labels, axis=None):
108
+ Change the labels of the Part.
109
+ add_labels(labels, dimension=None, axes=None, fill_value=0):
110
+ Add indices to one or multiple Part axes.
111
+ expand(axis=None, over="countries"):
112
+ Expand an axis of the Part.
113
+ issquare():
114
+ Check whether the Part is square.
115
+ hasneg():
116
+ Check whether the Part has negative elements.
117
+ hasax(name=None):
118
+ Return the dimensions along which a Part has given labels.
119
+ sum(axis=None, on=None, keepdims=False):
120
+ Sum the Part along one or several axes or on a given dimension.
121
+ save(file=None, name=None, extension=".npy", overwrite=False, include_labels=False, write_instructions=False, **kwargs):
122
+ Save the Part object to a file.
123
+ to_pandas():
124
+ Return the current Part object as a Pandas DataFrame.
125
+ to_xarray():
126
+ Save the Part object to an xarray DataArray.
127
+ mean(axis=None):
128
+ Compute the mean of the Part along a given axis.
129
+ min(axis=None):
130
+ Compute the minimum value of the Part along a given axis.
131
+ max(axis=None):
132
+ Compute the maximum value of the Part along a given axis.
133
+ mul(a, propagate_labels=True):
134
+ Perform matrix multiplication between Parts with label propagation.
135
+ filter(threshold, fill_value=0):
136
+ Set to 0 the values below a given threshold.
137
+ diag():
138
+ Create a diagonal Part from a 1D Part or extract the diagonal of a 2D Part.
139
+ transpose():
140
+ Transpose the Part object.
141
+ """
142
+
29
143
  def __init__(self,data=None,
30
144
  labels=None,
31
145
  axes=None,
32
146
  **kwargs):
33
- """MRIO Parts object
34
-
35
- MRIO Parts are the basic building blocks of the MRIO toolbox.
36
- A Part is built from a numpy array and a set of Axes,
37
- corresponding to the dimensions of the array.
38
- The Axes hold the labels of the Part in the different dimensions
39
- and are used to perform advanced indexing and operations on the Part.
40
-
41
- Axes support multi-level indexing and groupings.
147
+ """
148
+ Initialize a Part object.
42
149
 
43
150
  Parameters
44
151
  ----------
45
- data : numpy array
46
- Numerical data of the part.
47
- If left empty, a Part of zeros (or any other fill value) is created
48
- with a shape matching the axes.
49
- groupings : dict of label level : dict
50
- Groupings of the labels of the Part, for each label defined.
51
- The groupings are passed to the Axe objects.
152
+ data : numpy.ndarray, optional
153
+ Numerical data of the Part. If not provided, a Part filled with
154
+ zeros (or another fill value) is created based on the shape of the axes.
52
155
  labels : list of str or dict, optional
53
- Labels of the axes.
54
- The upper level of the list correspond to each axe (numpy dimension)
55
- of the part
56
- The lower level correspond to the dict of labels for each axe.
57
- Remember that an Axe can have different levels of labels.
156
+ Labels for the axes. If provided, the labels define the structure
157
+ of the axes. If not provided, axes are created based on the data.
58
158
  axes : list of Axe instances, optional
59
- Custom Axes for the Part.
60
- If left empty, the axes are created from the labels.
159
+ Custom Axes for the Part. If not provided, axes are created from
160
+ the labels or inferred from the data.
61
161
  kwargs : dict
62
- Additional metadata of the Part.
63
- (e.g. path, name, multiplier, unit...)
162
+ Additional metadata for the Part (e.g., path, name, multiplier, unit).
64
163
 
65
- Returns
66
- -------
67
- None.
164
+ Raises
165
+ ------
166
+ ValueError
167
+ If the length of the labels does not match the data dimensions.
168
+ TypeError
169
+ If the provided labels are of an unsupported type.
68
170
 
171
+ Notes
172
+ -----
173
+ If both `data` and `axes` are not provided, the method creates an
174
+ empty Part with default axes and zero-filled data.
69
175
  """
70
176
 
71
177
  if data is not None:
@@ -128,6 +234,7 @@ class Part:
128
234
  new_dims[-1].append(dim)
129
235
 
130
236
  if new_dims != self.get_dimensions():
237
+
131
238
  log.info("Reformat the Part")
132
239
  new_part = self.reformat(new_dims)
133
240
  self.data = new_part.data
@@ -257,13 +364,15 @@ class Part:
257
364
 
258
365
 
259
366
  def __getitem__(self,args):
260
- if isinstance(args,str) or isinstance(args,int) or isinstance(args,np.integer):
367
+ if isinstance(args,str) or isinstance(args,int) or isinstance(args,np.integer) or isinstance(args,dict):
261
368
  args = (args,)
262
369
  return self.get(*args)
263
370
 
264
371
  def __setitem__(self,args,value):
265
372
  if isinstance(value,Part):
266
373
  value = value.data
374
+ if isinstance(args,str) or isinstance(args,int) or isinstance(args,np.integer) or isinstance(args,dict):
375
+ args = (args,)
267
376
  self.setter(value,*args)
268
377
 
269
378
  def setter(self,value,*args):
@@ -319,7 +428,6 @@ class Part:
319
428
  Returns
320
429
  -------
321
430
  New Part object or numpy object
322
-
323
431
  """
324
432
  sels = []
325
433
  axes = []
@@ -379,12 +487,11 @@ class Part:
379
487
  squeeze : bool, optional
380
488
  Whether to remove dimensions of length 1.
381
489
  The default is True.
382
-
490
+
383
491
  Returns
384
492
  -------
385
- Part object
386
- Developped Part
387
-
493
+ Developped Part : Part object
494
+ The developed part
388
495
  """
389
496
  if isinstance(on,str):
390
497
  on = [on]
@@ -414,6 +521,7 @@ class Part:
414
521
  "This operation is not yet supported."
415
522
  )
416
523
  #If the order of the dimensions is unchanged, we can simply reshape
524
+
417
525
  shape = [len(ax) for ax in axes]
418
526
  data = self.data.reshape(shape)
419
527
  if squeeze:
@@ -423,35 +531,42 @@ class Part:
423
531
  return Part(data=data,name=f"developped_{self.name}",
424
532
  groupings=self.groupings,axes=axes)
425
533
 
426
- def reformat(self,new_dimensions):
534
+ def reformat(self, new_dimensions):
427
535
  """
428
- Reshape a Part to match a new dimensions combination
536
+ Reshape a Part to match a new dimensions combination.
429
537
 
430
538
  Equivalent to a combination of the develop and combine_axes methods.
431
539
 
432
540
  This only works for contiguous dimensions in the current Part,
433
541
  without overlapping dimensions.
434
- For example, if the Part has dimensions:
435
- [["countries"],["sectors"],["sectors"]]
436
- The following is allowed:
437
- [["countries","sectors"],["sectors"]]
438
- The following is not allowed:
439
- [["countries"],["sectors","sectors"]]
440
- [["sectors"],["countries","sectors"]]
441
- [["sectors","countries"],["sectors"]]
442
542
 
443
543
  Parameters
444
544
  ----------
445
- dimensions : list of list of str
446
- Original dimensions of the Part
545
+ new_dimensions : list of list of str
546
+ Target dimensions to reshape into.
447
547
 
448
548
  Returns
449
549
  -------
450
- data : numpy array
451
- Reshaped data
452
- axes : list of Axe instances
453
- Reshaped axes
550
+ data : numpy.ndarray
551
+ Reshaped data.
552
+ axes : list of Axe
553
+ Reshaped axes.
554
+
555
+ Examples
556
+ --------
557
+ If the Part has dimensions::
558
+
559
+ [["countries"], ["sectors"], ["sectors"]]
560
+
561
+ The following is allowed::
562
+
563
+ [["countries", "sectors"], ["sectors"]]
454
564
 
565
+ The following is not allowed::
566
+
567
+ [["countries"], ["sectors", "sectors"]]
568
+ [["sectors"], ["countries", "sectors"]]
569
+ [["sectors", "countries"], ["sectors"]]
455
570
  """
456
571
  return part_operations.reformat(self,new_dimensions)
457
572
 
@@ -561,40 +676,44 @@ class Part:
561
676
  dev.name = f"swapped_{self.name}"
562
677
  return dev
563
678
 
564
- def flatten(self,invert=False):
565
- """Flatten a 2D Part into a 1D Part
679
+ def flatten(self):
680
+ """Flatten a multidimensional Part into a 1D Part
566
681
 
567
- Parameters
568
- ----------
569
- inverse : bool, optional
570
- Whether to in the inverse level order.
682
+ Because Parts do not support repeated dimensions over the same axis,
683
+ Axis dimensions are disambiguated if needed by appending an index to the dimension name.
684
+
571
685
  """
572
- if self.ndim != 2:
573
- raise ValueError(f"Cannot flatten Part with {self.ndim} dimensions")
574
- if invert:
575
- labels = {
576
- dimension : self.axes[i].labels[dimension] \
577
- for i in range(self.ndim,0,-1) \
578
- for dimension in self.axes[i].dimensions
579
- }
580
- order = "C"
581
- else:
582
- labels = {
583
- dimension : self.axes[i].labels[dimension] \
584
- for i in range(self.ndim) \
585
- for dimension in self.axes[i].dimensions
586
- }
587
- order = "F"
588
- ax = Axe(labels,self.groupings)
686
+ def disambiguous_dimension(dim,existing,i=1):
687
+ while f"{dim}_{i}" in existing:
688
+ return disambiguous_dimension(dim,existing,i+1)
689
+
690
+ return f"{dim}_{i}"
691
+
692
+ enumerator = range(self.ndim)
693
+ order = "C"
694
+ labels = dict()
695
+ groupings = dict()
696
+ for ax in enumerator:
697
+ for dim in self.axes[ax].dimensions:
698
+ dim_name = dim
699
+ if dim in labels.keys():
700
+ dim_name = disambiguous_dimension(dim,labels)
701
+ log.info(f"Disambiguate dimension {dim} on axis {ax} into {dim_name}")
702
+ labels[dim_name] = self.axes[ax].labels[dim]
703
+ if dim in self.groupings.keys():
704
+ groupings[dim_name] = self.axes[ax].groupings[dim]
705
+ ax = Axe(labels,groupings)
589
706
  return self.alias(data=self.data.flatten(order=order),
590
707
  name=f"flattened_{self.name}",
591
708
  axes =[ax])
592
709
 
593
710
 
594
- def squeeze(self):
711
+ def squeeze(self,drop_ax=True,drop_dims=True):
595
712
  axes = []
596
713
  for ax in self.axes:
597
- if len(ax) > 1:
714
+ if drop_dims:
715
+ ax.squeeze()
716
+ if len(ax) > 1 or not drop_ax:
598
717
  axes.append(ax)
599
718
  return self.alias(data=np.squeeze(self.data),axes=axes,
600
719
  name=f"squeezed_{self.name}")
@@ -661,14 +780,6 @@ class Part:
661
780
  dimensions = [dimensions]
662
781
  if isinstance(labels,str) and labels!="all":
663
782
  labels = [labels]
664
- if len(labels) != len(dimensions):
665
- if len(dimensions)==1:
666
- #If only one dimension is passed, we broadcast the labels
667
- labels = [labels]
668
- else:
669
- #Raise an error for ambiguous cases
670
- log.critical("Number of dimensions and labels do not match for extraction")
671
- raise ValueError("Number of dimensions and labels do not match for extraction")
672
783
  if isinstance(dimensions,dict):
673
784
  to_select = dimensions
674
785
  labels = list(to_select.values())
@@ -677,14 +788,30 @@ class Part:
677
788
  to_select = dict()
678
789
  for dim,label in zip(dimensions,labels):
679
790
  to_select[dim] = label
791
+ if len(labels) != len(dimensions):
792
+ if len(dimensions)==1:
793
+ #If only one dimension is passed, we broadcast the labels
794
+ labels = [labels]
795
+ else:
796
+ #Raise an error for ambiguous cases
797
+ log.critical("Number of dimensions and labels do not match for extraction")
798
+ raise ValueError("Number of dimensions and labels do not match for extraction")
680
799
 
681
800
  allowed = []
682
801
  for i,ax in enumerate(self.axes):
683
802
  if all(dimension in ax.dimensions for dimension in dimensions):
684
803
  allowed.append(i)
685
804
  if len(allowed) == 0:
686
- log.critical("No axis found for extraction on "+str(dimensions))
687
- raise ValueError("No axis found for extraction on "+str(dimensions))
805
+ if len(dimensions) == 1:
806
+ log.critical("No axis found for extraction on "+str(dimensions))
807
+ raise ValueError("No axis found for extraction on "+str(dimensions))
808
+ log.info(f"No axis found for simultaneous extractions on {dimensions}")
809
+ log.info(f"Try successive extractions on {dimensions}")
810
+ for dim,label in zip(dimensions,labels):
811
+ self.extraction(dim,label,
812
+ on_groupings=on_groupings,
813
+ domestic_only=domestic_only,
814
+ axis=axis)
688
815
  if axis == "all":
689
816
  log.info(f"Extract {to_select} on axes "+ str(allowed))
690
817
  axis = allowed
@@ -785,13 +912,14 @@ class Part:
785
912
  for axe in list(ax):
786
913
  self.axes[axe].update_groupings(groupings)
787
914
 
788
- def aggregate(self,on="countries",axis=None):
915
+ def aggregate(self,on=None,axis=None):
789
916
  """Aggregate dimensions along one or several axis.
790
917
 
791
918
  If groupings are defined, these are taken into account.
792
919
  If you want to sum over the dimension of an axis, use the sum method.
793
920
 
794
921
  If no axis is specified, the operation is applied to all axes.
922
+ If no dimension is specified, the operation is applied to all possible dimensions.
795
923
 
796
924
  Parameters
797
925
  ----------
@@ -816,11 +944,14 @@ class Part:
816
944
 
817
945
  """
818
946
  log.debug(f"Aggregate Part {self.name} along axis {axis} on {on}")
947
+
948
+ if on is None:
949
+ on = list(self.groupings.keys())
819
950
 
820
951
 
821
952
  if isinstance(on,list):
822
953
  for item in on:
823
- self = self.aggregate(axis,item)
954
+ self = self.aggregate(on = item, axis=axis)
824
955
  return self
825
956
  if on not in self.groupings.keys():
826
957
  raise ValueError(f"No groupings defined for dimensions {on}")
@@ -871,8 +1002,8 @@ class Part:
871
1002
 
872
1003
  output = Part(axes=new_axis)
873
1004
  idsum = new_axis[axis].dimensions.index(on) #Index of the dimension to sum on
874
- ref_dev = self.develop(axis)
875
- new_dev = output.develop(axis)
1005
+ ref_dev = self.develop(axis, squeeze=False)
1006
+ new_dev = output.develop(axis,squeeze=False)
876
1007
  selector = ["all"]*ref_dev.ndim
877
1008
  for label in new_labels[on]:
878
1009
  selector[axis+idsum] = label
@@ -887,8 +1018,9 @@ class Part:
887
1018
 
888
1019
  def get_labels(self,axis=None):
889
1020
  """
890
- Returns the dictionnary of the Part labels
891
-
1021
+ Returns a list with the labels of each axis
1022
+ of the part in a the dictionary.
1023
+
892
1024
  Parameters
893
1025
  ----------
894
1026
  axis : int or list of int, optional
@@ -1035,12 +1167,12 @@ class Part:
1035
1167
  Value used to initialize the new Part
1036
1168
 
1037
1169
  Returns
1038
- ----------
1170
+ -------
1039
1171
  Part instance
1040
1172
  Part instance with the additional ax indices.
1041
1173
 
1042
1174
  Raise
1043
- ----------
1175
+ -----
1044
1176
  ValueError
1045
1177
  A Value Error is raised if neither the axes nor the
1046
1178
  ref_set arguments are set.
@@ -1073,6 +1205,67 @@ class Part:
1073
1205
  output[sel] = self.data
1074
1206
  return output
1075
1207
 
1208
+ def reorder_data(self,new_labels):
1209
+ """
1210
+ Reorder the data of the Part according to new labels.
1211
+
1212
+ Parameters
1213
+ ----------
1214
+ new_labels : dict
1215
+ New labels for the axes.
1216
+ The keys are the dimensions, the values are the labels.
1217
+
1218
+ Raises
1219
+ ------
1220
+ ValueError
1221
+ If the new labels do not match the current axes.
1222
+ """
1223
+
1224
+ if not isinstance(new_labels,dict):
1225
+ raise ValueError("New labels should be a dictionary")
1226
+
1227
+
1228
+
1229
+ sels = []
1230
+ for axis in self.axes:
1231
+ old_labels = axis.labels
1232
+ if not set(new_labels.keys()).issubset(set(old_labels.keys())):
1233
+ sels.append(axis.get("all"))
1234
+ continue
1235
+ for key in new_labels.keys():
1236
+ set_old = set(old_labels[key])
1237
+ set_new = set(new_labels[key])
1238
+ if not set_old.issubset(set_new):
1239
+ raise ValueError(f"The new labels provided for dimension '{key}' is not a superset of the old labels. " +
1240
+ f"Old labels: {old_labels[key]}, new labels: {new_labels[key]}. "
1241
+ "If you want to rename the labels of this dimensions, use the method 'replace_labels() before reordering the data")
1242
+
1243
+
1244
+ ax_label_dict = {}
1245
+ for key in old_labels.keys():
1246
+ if key in new_labels.keys():
1247
+ ax_label_dict[key] = new_labels[key]
1248
+ for lab in new_labels[key]:
1249
+ if lab not in old_labels[key]:
1250
+ # If the label is not in the list, remove it
1251
+ ax_label_dict[key].remove(lab)
1252
+ else:
1253
+ ax_label_dict[key] = old_labels[key]
1254
+
1255
+ sels.append(axis.get(ax_label_dict))
1256
+
1257
+ if len(sels) == 0:
1258
+ raise ValueError(f"None of the dimensions provided in the new labels dict {new_labels.keys()} are present "+
1259
+ f"in the labels of part '{self.name}', which only contains the dimensions {self.get_dimensions()}")
1260
+
1261
+ #Execute the selection
1262
+ self.data = self.data[np.ix_(*sels)]
1263
+
1264
+ # Update the axes with the new labels
1265
+ for dim in new_labels.keys():
1266
+ self.replace_labels(name = dim, labels = new_labels[dim])
1267
+
1268
+
1076
1269
  def expand(self,axis=None,over="countries"):
1077
1270
  """
1078
1271
  Expand an axis of the Part
@@ -1098,23 +1291,36 @@ class Part:
1098
1291
  """
1099
1292
  if axis is None:
1100
1293
  axis = self.hasax(over)
1294
+ if isinstance(axis,int):
1295
+ axis = [axis]
1296
+
1297
+ output = self.copy()
1101
1298
 
1102
- for ax in axis:
1103
- ref_ax = self.axes[ax]
1104
- new_ax = Axe(ref_ax.labels[over],groupings=self.groupings)
1105
- axes = self.axes.copy()
1106
- axes.insert(ax,new_ax)
1107
- new_shape = list(self.shape)
1108
- new_shape.insert(axis,len(new_ax))
1109
- output = np.zeros(new_shape)
1110
- selector = [slice(None)]*self.ndim
1111
- for item in ref_ax.labels[over]:
1112
- newsel,refsel = selector.copy(),selector.copy()
1113
- newsel.insert(ax,new_ax.sel(item))
1114
- newsel[ax+1] = ref_ax.sel(item)
1115
- refsel[ax]= ref_ax.sel(item)
1116
- output[tuple(newsel)] = self.data[tuple(refsel)]
1299
+ for ax in axis[::-1]:
1300
+ output = output._expand_on(ax,over)
1301
+ return output
1302
+
1303
+ def _expand_on(self,ax,over):
1304
+ """Expand over a single axis"""
1305
+ ref_ax = self.axes[ax]
1306
+ new_ax = Axe({over: ref_ax.labels[over]}, groupings=self.groupings)
1307
+ axes = self.axes.copy()
1308
+ axes.insert(ax,new_ax)
1309
+ new_shape = list(self.shape)
1310
+ new_shape.insert(ax,len(new_ax))
1311
+ output = np.zeros(new_shape)
1312
+ selector = [slice(None)]*self.ndim
1313
+ ordering = ref_ax.dimensions.index(over)
1314
+ for item in ref_ax.labels[over]:
1315
+ newsel,refsel = selector.copy(),selector.copy()
1316
+ newsel.insert(ax,new_ax.get(item))
1317
+ ax_selector = ["all" for i in range(ordering)]
1318
+ ax_selector.append(item)
1319
+ newsel[ax+1] = ref_ax.get(ax_selector)
1320
+ refsel[ax]= ref_ax.get(ax_selector)
1321
+ output[tuple(newsel)] = self.data[tuple(refsel)]
1117
1322
  return self.alias(data=output,name=f"expanded_{self.name}",axes=axes)
1323
+
1118
1324
 
1119
1325
  def issquare(self):
1120
1326
  """Assert wether the Part is square"""
@@ -1212,10 +1418,13 @@ class Part:
1212
1418
  raise ValueError(f"Cannot sum on {on} as it is not a dimension of axis {axis}")
1213
1419
  if ax.levels == 1:
1214
1420
  #If the axis has a single level, this is a simple sum
1421
+ axes = self.axes.copy()
1422
+ if not keepdims:
1423
+ del axes[axis]
1215
1424
  return self.alias(
1216
1425
  data = self.data.sum(axis,keepdims=keepdims),
1217
1426
  name=f"{self.name}_sum_{axis}",
1218
- axes = self.axes
1427
+ axes = axes
1219
1428
  )
1220
1429
  #Otherwise, sum on the relevant levels
1221
1430
  idsum = ax.dimensions.index(on) #Index of the dimension to sum on
@@ -1280,7 +1489,7 @@ class Part:
1280
1489
  if path is None:
1281
1490
  raise FileNotFoundError("No path specified for saving the Part")
1282
1491
  if extension == ".nc":
1283
- path = os.path.join(path,name)
1492
+ path = os.path.join(path,name+extension)
1284
1493
  save_to_nc(self,path,overwrite,
1285
1494
  write_instructions=write_instructions,
1286
1495
  **kwargs)
@@ -1385,10 +1594,16 @@ class Part:
1385
1594
 
1386
1595
  def diag(self):
1387
1596
  if self.ndim == 1:
1597
+ log.info("Diagonalize a 1D part")
1388
1598
  return self.alias(data=np.diag(self.data),
1389
1599
  name=f"diag_{self.name}",
1390
1600
  axes = self.axes*2)
1391
-
1601
+ try:
1602
+ log.info("The part has too many dimensions: try to diagonalize the squeezed part")
1603
+ return self.squeeze().diag()
1604
+ except:
1605
+ raise ValueError("Cannot diagonalize a part with more than 2 dimensions")
1606
+
1392
1607
  def __add__(self,a):
1393
1608
  if isinstance(a,Part):
1394
1609
  name = a.name
@@ -1426,6 +1641,26 @@ class Part:
1426
1641
  def __neg__(self):
1427
1642
  return self.alias(data=-self.data,name=f"-{self.name}")
1428
1643
 
1644
+ def __lt__(self,other):
1645
+ if isinstance(other,Part):
1646
+ return self.data < other.data
1647
+ return self.data < other
1648
+
1649
+ def __le__(self,other):
1650
+ if isinstance(other,Part):
1651
+ return self.data <= other.data
1652
+ return self.data <= other
1653
+
1654
+ def __gt__(self,other):
1655
+ if isinstance(other,Part):
1656
+ return self.data > other.data
1657
+ return self.data > other
1658
+
1659
+ def __ge__(self,other):
1660
+ if isinstance(other,Part):
1661
+ return self.data >= other.data
1662
+ return self.data >= other
1663
+
1429
1664
  def __sub__(self,a):
1430
1665
  if isinstance(a,Part):
1431
1666
  name = a.name
_parts/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ This module provides the Part and Axe classes.
3
+ """
4
+
5
+ from ._Part import Part,load_part
6
+
7
+ __all__ = ['Part','load_part']