mrio-toolbox 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mrio-toolbox might be problematic. Click here for more details.

Files changed (61) hide show
  1. {mrio_toolbox-1.1.1.dist-info → mrio_toolbox-1.1.3.dist-info}/METADATA +2 -2
  2. mrio_toolbox-1.1.3.dist-info/RECORD +5 -0
  3. mrio_toolbox-1.1.3.dist-info/top_level.txt +1 -0
  4. mrio_toolbox/__init__.py +0 -21
  5. mrio_toolbox/_parts/_Axe.py +0 -539
  6. mrio_toolbox/_parts/_Part.py +0 -1698
  7. mrio_toolbox/_parts/__init__.py +0 -7
  8. mrio_toolbox/_parts/part_operations.py +0 -57
  9. mrio_toolbox/extractors/__init__.py +0 -20
  10. mrio_toolbox/extractors/downloaders.py +0 -36
  11. mrio_toolbox/extractors/emerging/__init__.py +0 -3
  12. mrio_toolbox/extractors/emerging/emerging_extractor.py +0 -117
  13. mrio_toolbox/extractors/eora/__init__.py +0 -3
  14. mrio_toolbox/extractors/eora/eora_extractor.py +0 -132
  15. mrio_toolbox/extractors/exiobase/__init__.py +0 -3
  16. mrio_toolbox/extractors/exiobase/exiobase_extractor.py +0 -270
  17. mrio_toolbox/extractors/extractors.py +0 -79
  18. mrio_toolbox/extractors/figaro/__init__.py +0 -3
  19. mrio_toolbox/extractors/figaro/figaro_downloader.py +0 -280
  20. mrio_toolbox/extractors/figaro/figaro_extractor.py +0 -187
  21. mrio_toolbox/extractors/gloria/__init__.py +0 -3
  22. mrio_toolbox/extractors/gloria/gloria_extractor.py +0 -202
  23. mrio_toolbox/extractors/gtap11/__init__.py +0 -7
  24. mrio_toolbox/extractors/gtap11/extraction/__init__.py +0 -3
  25. mrio_toolbox/extractors/gtap11/extraction/extractor.py +0 -129
  26. mrio_toolbox/extractors/gtap11/extraction/harpy_files/__init__.py +0 -6
  27. mrio_toolbox/extractors/gtap11/extraction/harpy_files/_header_sets.py +0 -279
  28. mrio_toolbox/extractors/gtap11/extraction/harpy_files/har_file.py +0 -262
  29. mrio_toolbox/extractors/gtap11/extraction/harpy_files/har_file_io.py +0 -974
  30. mrio_toolbox/extractors/gtap11/extraction/harpy_files/header_array.py +0 -300
  31. mrio_toolbox/extractors/gtap11/extraction/harpy_files/sl4.py +0 -229
  32. mrio_toolbox/extractors/gtap11/gtap_mrio/__init__.py +0 -6
  33. mrio_toolbox/extractors/gtap11/gtap_mrio/mrio_builder.py +0 -158
  34. mrio_toolbox/extractors/icio/__init__.py +0 -3
  35. mrio_toolbox/extractors/icio/icio_extractor.py +0 -121
  36. mrio_toolbox/extractors/wiod/__init__.py +0 -3
  37. mrio_toolbox/extractors/wiod/wiod_extractor.py +0 -143
  38. mrio_toolbox/mrio.py +0 -899
  39. mrio_toolbox/msm/__init__.py +0 -6
  40. mrio_toolbox/msm/multi_scale_mapping.py +0 -863
  41. mrio_toolbox/utils/__init__.py +0 -3
  42. mrio_toolbox/utils/converters/__init__.py +0 -5
  43. mrio_toolbox/utils/converters/pandas.py +0 -247
  44. mrio_toolbox/utils/converters/xarray.py +0 -130
  45. mrio_toolbox/utils/formatting/__init__.py +0 -0
  46. mrio_toolbox/utils/formatting/formatter.py +0 -528
  47. mrio_toolbox/utils/loaders/__init__.py +0 -7
  48. mrio_toolbox/utils/loaders/_loader.py +0 -312
  49. mrio_toolbox/utils/loaders/_loader_factory.py +0 -96
  50. mrio_toolbox/utils/loaders/_nc_loader.py +0 -184
  51. mrio_toolbox/utils/loaders/_np_loader.py +0 -112
  52. mrio_toolbox/utils/loaders/_pandas_loader.py +0 -128
  53. mrio_toolbox/utils/loaders/_parameter_loader.py +0 -386
  54. mrio_toolbox/utils/savers/__init__.py +0 -11
  55. mrio_toolbox/utils/savers/_path_checker.py +0 -37
  56. mrio_toolbox/utils/savers/_to_folder.py +0 -165
  57. mrio_toolbox/utils/savers/_to_nc.py +0 -60
  58. mrio_toolbox-1.1.1.dist-info/RECORD +0 -59
  59. mrio_toolbox-1.1.1.dist-info/top_level.txt +0 -1
  60. {mrio_toolbox-1.1.1.dist-info → mrio_toolbox-1.1.3.dist-info}/WHEEL +0 -0
  61. {mrio_toolbox-1.1.1.dist-info → mrio_toolbox-1.1.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,1698 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Fri Apr 14 14:13:17 2023
4
-
5
- @author: beaufils
6
- """
7
-
8
- import os
9
- import itertools
10
- import numpy as np
11
- import pandas as pd
12
- import xarray as xr
13
- import copy
14
- from mrio_toolbox._parts._Axe import Axe
15
- import logging
16
- from mrio_toolbox.utils import converters
17
- from mrio_toolbox.utils.loaders import make_loader
18
- from mrio_toolbox.utils.savers import save_part_to_folder,save_to_nc
19
- from mrio_toolbox._parts import part_operations
20
-
21
- log = logging.getLogger(__name__)
22
-
23
def load_part(**kwargs):
    """
    Load a Part through the loader matching the given arguments.

    Parameters
    ----------
    **kwargs : dict
        Forwarded unchanged both to ``make_loader`` and to the
        ``load_part`` method of the loader it returns.

    Returns
    -------
    Part
        Part built from the keyword arguments returned by the loader.
    """
    part_loader = make_loader(**kwargs)
    part_kwargs = part_loader.load_part(**kwargs)
    return Part(**part_kwargs)
28
-
29
- class Part:
30
- """
31
- Representation of an MRIO Part object.
32
-
33
- MRIO Parts are the basic building blocks of the MRIO toolbox. A Part is
34
- built from a numpy array and a set of Axes, corresponding to the dimensions
35
- of the array. The Axes hold the labels of the Part in the different
36
- dimensions and are used to perform advanced indexing and operations on the Part.
37
-
38
- Axes support multi-level indexing and groupings.
39
-
40
- Instance variables
41
- ------------------
42
- data : numpy.ndarray
43
- Numerical data of the Part.
44
- axes : list of Axe instances
45
- Axes corresponding to the dimensions of the Part.
46
- groupings : dict
47
- Groupings of the labels of the Part, for each label defined.
48
- metadata : dict
49
- Additional metadata of the Part (e.g., path, name, multiplier, unit).
50
- name : str
51
- Name of the Part.
52
- ndim : int
53
- Number of dimensions of the Part.
54
- shape : tuple
55
- Shape of the Part.
56
-
57
- Methods
58
- -------
59
- __init__(data=None, labels=None, axes=None, **kwargs):
60
- Initialize a Part object.
61
- alias(**kwargs):
62
- Create a new Part with modified parameters.
63
- fix_dims(skip_labels=False, skip_data=False):
64
- Align the number of axes with the number of dimensions.
65
- get(*args, aspart=True, squeeze=False):
66
- Extract data from the Part object.
67
- setter(value, *args):
68
- Change the value of a data selection.
69
- develop(axis=None, on=None, squeeze=True):
70
- Reshape a Part to avoid double labels.
71
- reformat(new_dimensions):
72
- Reshape a Part to match a new dimensions combination.
73
- combine_axes(start=0, end=None, in_place=False):
74
- Combine axes of a Part into a single one.
75
- swap_axes(axis1, axis2):
76
- Swap two axes of a Part.
77
- swap_ax_levels(axis, dim1, dim2):
78
- Swap two levels of an axis.
79
- flatten(invert=False):
80
- Flatten a 2D Part into a 1D Part.
81
- squeeze():
82
- Remove dimensions of length 1 from the Part.
83
- expand_dims(axis, copy=None):
84
- Add dimensions to a Part instance.
85
- copy():
86
- Return a copy of the current Part object.
87
- extraction(dimensions, labels=["all"], on_groupings=True, domestic_only=False, axis="all"):
88
- Set labels over dimension(s) to 0.
89
- leontief_inversion():
90
- Compute the Leontief inverse of a square Part.
91
- update_groupings(groupings, ax=None):
92
- Update the groupings of the current Part object.
93
- aggregate(on="countries", axis=None):
94
- Aggregate dimensions along one or several axes.
95
- aggregate_on(on, axis):
96
- Aggregate a Part along a given axis.
97
- get_labels(axis=None):
98
- Returns a list with the labels of the axes as dictionaries.
99
- list_labels():
100
- List the labels of the Part.
101
- get_dimensions(axis=None):
102
- Return the list of dimensions of the Part.
103
- rename_labels(old, new):
104
- Rename some labels of the Part.
105
- replace_labels(name, labels, axis=None):
106
- Update a label of the Part.
107
- set_labels(labels, axis=None):
108
- Change the labels of the Part.
109
- add_labels(labels, dimension=None, axes=None, fill_value=0):
110
- Add indices to one or multiple Part axes.
111
- expand(axis=None, over="countries"):
112
- Expand an axis of the Part.
113
- issquare():
114
- Check whether the Part is square.
115
- hasneg():
116
- Check whether the Part has negative elements.
117
- hasax(name=None):
118
- Return the dimensions along which a Part has given labels.
119
- sum(axis=None, on=None, keepdims=False):
120
- Sum the Part along one or several axes or on a given dimension.
121
- save(file=None, name=None, extension=".npy", overwrite=False, include_labels=False, write_instructions=False, **kwargs):
122
- Save the Part object to a file.
123
- to_pandas():
124
- Return the current Part object as a Pandas DataFrame.
125
- to_xarray():
126
- Save the Part object to an xarray DataArray.
127
- mean(axis=None):
128
- Compute the mean of the Part along a given axis.
129
- min(axis=None):
130
- Compute the minimum value of the Part along a given axis.
131
- max(axis=None):
132
- Compute the maximum value of the Part along a given axis.
133
- mul(a, propagate_labels=True):
134
- Perform matrix multiplication between Parts with label propagation.
135
- filter(threshold, fill_value=0):
136
- Set to 0 the values below a given threshold.
137
- diag():
138
- Create a diagonal Part from a 1D Part or extract the diagonal of a 2D Part.
139
- transpose():
140
- Transpose the Part object.
141
- """
142
-
143
def __init__(self, data=None, labels=None, axes=None, **kwargs):
    """
    Initialize a Part object.

    Parameters
    ----------
    data : numpy.ndarray, optional
        Numerical data of the Part. If not provided, a Part filled with
        zeros is created based on the shape of the axes.
        Part, xarray and pandas.DataFrame inputs are also accepted.
    labels : list of str or dict, optional
        Labels for the axes. Only used when `axes` is not provided.
    axes : list of Axe instances, optional
        Custom Axes for the Part. If not provided, axes are created from
        the labels or inferred from the data.
    kwargs : dict
        Additional metadata for the Part (e.g., path, name, multiplier, unit).

    Raises
    ------
    ValueError
        If the length of the labels does not match the data dimensions.

    Notes
    -----
    When `data` is an xarray object or a pandas DataFrame, the Part is
    entirely rebuilt from the output of the matching converter: the
    `labels`, `axes` and keyword arguments passed alongside are not
    forwarded.
    """
    if data is not None:
        if isinstance(data, Part):
            #Only the numerical content of the source Part is reused
            data = data.data
        if isinstance(data, (xr.DataArray, xr.Dataset)):
            converted = converters.xarray.make_part(data)
            self.__init__(**converted)
            return
        if isinstance(data, pd.DataFrame):
            converted = converters.pandas.make_part(data)
            self.__init__(**converted)
            return
        self.data = data
        self.ndim = data.ndim
        self.shape = self.data.shape

    self.name = kwargs.pop("name", "new_part")
    log.debug("Create Part instance " + self.name)

    self.groupings = kwargs.get("groupings", dict())
    #"groupings" and "metadata" are read, not popped:
    #both keys are therefore also stored in the metadata
    self.metadata = {**kwargs.get("metadata", dict()), **kwargs}

    if axes is not None:
        self.axes = axes
    else:
        self.axes = []
        self._create_axis(labels)

    if data is None:
        log.debug("Create empty Part")
        data = np.zeros([len(ax) for ax in self.axes])
        self.data = data
        self.ndim = data.ndim
        self.shape = self.data.shape

    self.fix_dims()
    for dim in range(self.ndim):
        n_labels, n_data = len(self.axes[dim]), self.shape[dim]
        #Data dimensions of length 1 are tolerated whatever the labels
        if n_labels != n_data and n_data != 1:
            message = f"Length of label {dim} does not match data: "+\
                f"{len(self.axes[dim])} and {self.shape[dim]}"
            log.critical(message)
            raise ValueError(message)

    if "_original_dimensions" in self.metadata:
        #This set of instructions is intended to handle
        #data loaded from netcdf files
        log.info("Checking if a reformatting is needed.")
        encoded = self.metadata.pop("_original_dimensions")
        #The dimensions are stored flat, with "_sep_" between two axes,
        #because netcdf files do not support multi-level attributes
        decoded = [[]]
        for item in encoded:
            if item == "_sep_":
                decoded.append([])
            else:
                decoded[-1].append(item)

        if decoded != self.get_dimensions():
            log.info("Reformat the Part")
            reformatted = self.reformat(decoded)
            self.data = reformatted.data
            self.axes = reformatted.axes
            self.ndim = self.data.ndim
            self.shape = self.data.shape

    self._store_labels()
246
-
247
-
248
def alias(self, **kwargs):
    """Create a new Part in which only prescribed parameters are changed

    The current Part is taken as reference: all arguments not explicitly
    set are taken from the current part. The data is always copied;
    the name defaults to the current name with an "_alias" suffix and
    the labels default to None.
    """
    parameters = {
        "data": kwargs.get("data", self.data).copy(),
        "name": kwargs.get("name", self.name+"_alias"),
        "groupings": kwargs.get("groupings", self.groupings),
        "labels": kwargs.get("labels", None),
        "axes": kwargs.get("axes", self.axes),
        "metadata": kwargs.get("metadata", self.metadata),
    }
    return Part(**parameters)
266
-
267
def _create_axis(self, labels):
    """
    Create the Axe objects of the Part from a set of labels.

    Parameters
    ----------
    labels : None, dict, or list/tuple
        Labels of the axes, one item per axis.
        A flat list of str, int or float is interpreted as the labels
        of a single axis.
        If a dict is passed, the keys are used as axis names.
        Axes without labels are labelled by indices only
        (this requires the data to be already set).

    Raises
    ------
    TypeError
        Raised if the labels are of an unsupported type.
    ValueError
        Raised if there is neither data nor labels to build the axes from.

    Returns
    -------
    None.

    """
    self.axes = []

    if isinstance(labels, (tuple, list)) and len(labels) > 0 and\
        isinstance(labels[0], (str, int, float)):
        #If the first item of the labels is not iterable,
        #we assume the label is an axis label
        #We add a dimension to the labels such that enumeration works properly
        labels = [labels]

    if labels is not None and not isinstance(labels, (dict, list, tuple)):
        #Reject unsupported types upfront, before any len() call on them
        message = f"Unknown label type: {type(labels)}"
        log.critical(message)
        raise TypeError(message)

    if "data" in self.__dict__:
        enum = self.ndim
    else:
        if labels is None:
            raise ValueError("Cannot create axes without data, axes or labels")
        enum = len(labels)

    for dim in range(enum):
        if labels is None or dim >= len(labels):
            #Fill empty labels with indices
            self.axes.append(
                Axe(list(range(self.shape[dim])),
                    groupings=self.groupings)
                )
        elif isinstance(labels, dict):
            axname = list(labels.keys())[dim]
            self.axes.append(
                Axe(labels[axname], groupings=self.groupings, name=axname)
                )
        else:
            self.axes.append(
                Axe(labels[dim], groupings=self.groupings)
                )
        log.debug(f"Create ax {dim} with len {len(self.axes[-1])}")
328
-
329
def fix_dims(self,
             skip_labels=False,
             skip_data=False):
    """Align the number of axes with the number of dimensions

    If one length exceeds the other, axes and/or data are squeezed,
    i.e. dimensions of length 1 are removed.

    Parameters
    ----------
    skip_labels : bool, optional
        If True, the axes are not squeezed. The default is False.
    skip_data : bool, optional
        If True, the data is not squeezed. The default is False.

    Raises
    ------
    IndexError
        If axes and data cannot be reconciled by squeezing.
    """
    if len(self.axes) == self.ndim:
        return
    log.warning(
        f"The number of axes ({len(self.axes)})"\
        +f" does not match data dimensions ({self.ndim})"
        )

    if len(self.axes) > self.ndim and not skip_labels:
        log.debug(
            "Try to squeeze axe(s) of len 1"
        )
        #self.axes is a plain list: it cannot be indexed with a boolean mask
        self.axes = [ax for ax in self.axes if len(ax) != 1]
        return self.fix_dims(
            skip_labels=True,
            skip_data=skip_data)
    if self.ndim > len(self.axes) and not skip_data:
        self.data = self.data.squeeze()
        self.shape = self.data.shape
        self.ndim = self.data.ndim
        return self.fix_dims(
            skip_labels=skip_labels,
            skip_data=True)

    message = f"Cannot reconcile data of dims {self.ndim}"\
        +f" with axes of dim {len(self.axes)}"
    log.critical(message)
    raise IndexError(message)
364
-
365
-
366
def __getitem__(self, args):
    """Select data with the indexing syntax, see the get method.

    A single str, int, numpy integer or dict is treated as one selection;
    any other input is unpacked into one selection per argument.
    """
    if isinstance(args, (str, int, np.integer, dict)):
        return self.get(args)
    return self.get(*args)
370
-
371
- def __setitem__(self,args,value):
372
- if isinstance(value,Part):
373
- value = value.data
374
- if isinstance(args,str) or isinstance(args,int) or isinstance(args,np.integer) or isinstance(args,dict):
375
- args = (args,)
376
- self.setter(value,*args)
377
-
378
- def setter(self,value,*args):
379
- """
380
- Change the value of a data selection
381
-
382
- Parameters
383
- ----------
384
- value : float or numpy like
385
- Value to set.
386
- *args : list of tuples
387
- Indices along the respective axes.
388
-
389
- Returns
390
- -------
391
- None.
392
- Modification is applied to the current Part object
393
-
394
- """
395
- sels = []
396
- try:
397
- #First tries to interpret one arg per ax
398
- for i,arg in enumerate(args):
399
- sels.append(self.axes[i].get(arg))
400
- except (IndexError,ValueError):
401
- #Otherwise, tries to interpret all args on the first ax
402
- sels = []
403
- sels = [self.axes[0].get(args)]
404
- if isinstance(value,(np.ndarray,Part)) and len(sels)!=value.ndim:
405
- if len(sels) < value.ndim:
406
- value = value.squeeze()
407
- if len(sels) > value.ndim:
408
- target_shape = [len(sel) for sel in sels]
409
- value = np.reshape(value,target_shape)
410
- self.data[np.ix_(*sels)] = value
411
-
412
def get(self, *args, aspart=True, squeeze=False):
    """
    Extract data from the current Part object

    Parameters
    ----------
    *args : list of tuples
        Selection along the Axes of the Part.
        Axes without selection are selected entirely.
    aspart : bool, optional
        Whether to return the selection as a Part object.
        If False, the selection is returned as a numpy object.
        The default is True.
    squeeze : bool, optional
        Whether to remove dimensions of length 1.
        The default is False

    Returns
    -------
    New Part object or numpy object
    """
    def collect(axe, selection, sels, axes):
        #Resolve a selection on an axe and store indices and labels
        datasel, labs, groupings = axe.get(selection, True)
        if not squeeze or len(datasel) > 1:
            axes.append(Axe(labs, groupings))
        sels.append(datasel)

    sels, axes = [], []
    try:
        #First tries to interpret one arg per ax
        for i, arg in enumerate(args):
            collect(self.axes[i], arg, sels, axes)
    except (ValueError, IndexError):
        #Try interpreting all args on the first ax
        sels, axes = [], []
        collect(self.axes[0], args, sels, axes)

    #If the selection is not complete, fill with all
    for i in range(len(sels), self.ndim):
        collect(self.axes[i], "all", sels, axes)

    #Execute the selection
    data = self.data[np.ix_(*sels)]
    if squeeze:
        data = data.squeeze()

    if not aspart:
        return data
    return Part(data=data, name=f"sel_{self.name}",
                groupings=self.groupings, axes=axes)
470
-
471
def develop(self, axis=None, on=None, squeeze=True):
    """
    Reshape a Part to avoid double labels

    Parameters
    ----------
    axis : int or list of int, optional
        Axis to develop.
        If left empty, all axes are developed.
        The default is None.
    on : str or list of str, optional
        Dimensions to develop.
        If left empty, all dimensions are developed.
        Developing must leave the order of the dimensions unchanged.
        The default is None.
    squeeze : bool, optional
        Whether to remove dimensions of length 1.
        The default is True.

    Raises
    ------
    NotImplementedError
        If developing changes the order of the dimensions.

    Returns
    -------
    Developed Part : Part object
        The developed part
    """
    if isinstance(on, str):
        on = [on]
    if isinstance(axis, int):
        axis = [axis]

    axes = []
    for position, current in enumerate(self.axes):
        if axis is not None and position not in axis:
            #Axes that are not developed are kept as they are
            axes.append(current)
            continue
        remaining = current.labels.copy()
        for dim in current.dimensions:
            if on is None or dim in on:
                #Add dimension that needs to be developed
                axes.append(
                    Axe({dim: remaining.pop(dim)}, groupings=current.groupings)
                    )
        if len(remaining) > 0:
            #Keep remaining dimensions together
            axes.append(Axe(remaining, groupings=current.groupings))

    previous_order = [dim for ax in self.axes for dim in ax.dimensions]
    developed_order = [dim for ax in axes for dim in ax.dimensions]
    if developed_order != previous_order:
        raise NotImplementedError(
            "Developping the part misaligns the dimensions. "+\
            "This operation is not yet supported."
        )

    #If the order of the dimensions is unchanged, we can simply reshape
    developed = Part(data=self.data.reshape([len(ax) for ax in axes]),
                     name=f"developped_{self.name}",
                     groupings=self.groupings, axes=axes)
    return developed.squeeze() if squeeze else developed
533
-
534
def reformat(self, new_dimensions):
    """
    Reshape a Part to match a new dimensions combination.

    Equivalent to a combination of the develop and combine_axes methods.
    The operation is delegated to ``part_operations.reformat``.

    This only works for contiguous dimensions in the current Part,
    without overlapping dimensions.

    Parameters
    ----------
    new_dimensions : list of list of str
        Target dimensions to reshape into.

    Returns
    -------
    Output of ``part_operations.reformat``
        NOTE(review): the Part initializer reads the ``data`` and ``axes``
        attributes of this output - presumably a Part, to be confirmed.

    Examples
    --------
    If the Part has dimensions::

        [["countries"], ["sectors"], ["sectors"]]

    The following is allowed::

        [["countries", "sectors"], ["sectors"]]

    The following is not allowed::

        [["countries"], ["sectors", "sectors"]]
        [["sectors"], ["countries", "sectors"]]
        [["sectors", "countries"], ["sectors"]]
    """
    reformatted = part_operations.reformat(self, new_dimensions)
    return reformatted
572
-
573
def combine_axes(self, start=0, end=None, in_place=False):
    """
    Combine axes of a Part into a single one.

    The order of dimensions is preserved in the new axis.
    Only consecutive axes can be combined.
    The method can be used to revert the develop method.

    Parameters
    ----------
    start : int, optional
        Index of the first axis to combine, by default 0
    end : int, optional
        Index of the final axis to combine, by default None:
        all axes from `start` on are combined.
    in_place : bool, optional
        If True, the reshaped data and the new axes are returned
        as a tuple instead of a Part. The default is False.

    Returns
    -------
    Part instance, or tuple (data, axes) if in_place is True

    Raises
    ------
    IndexError
        Axes should have no overlapping dimensions.
    """
    if end is None:
        end = self.ndim - 1
    span = range(start, end+1)

    merged_labels, merged_groupings = dict(), dict()
    seen = []
    axes = []
    for position, current in enumerate(self.axes):
        if position in span:
            for name in current.dimensions:
                if name in seen:
                    raise IndexError(
                        "Cannot undevelop axes with overlapping dimensions"
                    )
                seen.append(name)
                merged_labels[name] = current.labels[name]
                if name in current.groupings:
                    merged_groupings[name] = current.groupings[name]
        else:
            axes.append(current)
        if position == end:
            #The combined axis takes the position of the merged axes
            axes.append(Axe(merged_labels, merged_groupings))

    data = self.data.reshape([len(ax) for ax in axes])
    if in_place:
        return data, axes
    return self.alias(data=data, name=f"combined_{self.name}",
                      axes=axes)
625
-
626
def swap_axes(self, axis1, axis2):
    """Swap two axes of a Part

    Parameters
    ----------
    axis1 : int
        First axis to swap.
    axis2 : int
        Second axis to swap.

    Returns
    -------
    Part instance
        Part with swapped axes.
        The groupings of the current Part are carried over.

    """
    axes = self.axes.copy()
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    data = self.data.swapaxes(axis1, axis2)
    #Groupings are passed on, as in the other reshaping methods
    return Part(data=data, name=f"swapped_{self.name}",
                groupings=self.groupings, axes=axes)
646
-
647
def swap_ax_levels(self, axis, dim1, dim2):
    """Swap two levels of an axis

    The current Part is left untouched.

    Parameters
    ----------
    axis : int
        Axis to modify.
    dim1 : str
        First dimension to swap.
    dim2 : str
        Second dimension to swap.

    Returns
    -------
    Part instance
        Part with swapped levels.

    """
    len1, len2 = len(self.axes[axis].labels[dim1]), len(self.axes[axis].labels[dim2])
    if len1 == 1 or len2 == 1:
        #The data order is not affected: only the axis needs to be updated
        axes = self.axes.copy()
        #Work on a copy of the Axe: swap_levels is applied in place
        #and the Axe is otherwise shared with the current Part
        axes[axis] = copy.deepcopy(axes[axis])
        axes[axis].swap_levels(dim1, dim2)
        return Part(data=self.data, name=f"swapped_{self.name}",
                    groupings=self.groupings, axes=axes)
    dimensions = self.axes[axis].dimensions
    id1, id2 = dimensions.index(dim1), dimensions.index(dim2)
    #Only the selected axis is developed: the preceding axes each still
    #count for one position, whatever their number of dimensions
    offset = axis
    #Dimensions of length 1 are kept, otherwise the positions would shift
    dev = self.develop(axis, squeeze=False)
    dev = dev.swap_axes(id1+offset, id2+offset)
    dev = dev.combine_axes(axis, axis+len(dimensions)-1)
    dev.name = f"swapped_{self.name}"
    return dev
678
-
679
def flatten(self, invert=False):
    """Flatten a 2D Part into a 1D Part

    Parameters
    ----------
    invert : bool, optional
        Whether to flatten in the inverse level order, i.e. with the
        dimensions of the second axis first. The default is False.

    Raises
    ------
    ValueError
        If the Part is not 2D.

    Returns
    -------
    Part instance
        Flattened Part, with a single axis holding all dimensions.
    """
    if self.ndim != 2:
        raise ValueError(f"Cannot flatten Part with {self.ndim} dimensions")
    if invert:
        #Browse the axes from the last to the first one
        ax_order = reversed(range(self.ndim))
        order = "C"
    else:
        ax_order = range(self.ndim)
        order = "F"
    labels = {
        dimension : self.axes[i].labels[dimension] \
            for i in ax_order \
            for dimension in self.axes[i].dimensions
    }
    ax = Axe(labels, self.groupings)
    return self.alias(data=self.data.flatten(order=order),
                      name=f"flattened_{self.name}",
                      axes=[ax])
707
-
708
-
709
def squeeze(self):
    """Remove dimensions of length 1 from the Part.

    Returns
    -------
    Part instance
        New Part with squeezed data, keeping only the axes
        holding more than one label.
    """
    kept_axes = [ax for ax in self.axes if len(ax) > 1]
    squeezed_data = np.squeeze(self.data)
    return self.alias(data=squeezed_data, axes=kept_axes,
                      name=f"squeezed_{self.name}")
716
-
717
def expand_dims(self, axis, copy=None):
    """Add dimensions to a Part instance

    Parameters
    ----------
    axis : int
        Position of the new axis.
    copy : int, optional
        Axis to copy the labels from.
        If left empty, the new axis holds a single "expanded" label.

    Returns
    -------
    Part instance
        New Part with the additional axis.
    """
    if copy is None:
        new_axe = Axe({"expanded":[0]})
    else:
        new_axe = self.axes[copy]
    axes = self.axes.copy()
    axes.insert(axis, new_axe)
    expanded_data = np.expand_dims(self.data, axis)
    return self.alias(data=expanded_data,
                      axes=axes, name=f"expanded_{self.name}")
735
-
736
def copy(self):
    """Return a copy of the current Part object

    The copy is an alias of the Part created without any modification.
    """
    duplicate = self.alias()
    return duplicate
739
-
740
def extraction(self,
               dimensions,
               labels=None,
               on_groupings=True,
               domestic_only=False,
               axis="all"):
    """
    Set labels over dimension(s) to 0.

    Parameters
    ----------
    dimensions : str, list of str, dict
        Name of the dimensions on which the extraction is done.
        If dict is passed, the keys are interpreted as the dimensions
        and the values as the labels
    labels : list of (list of) str, optional
        Selection on the dimension to put to 0.
        By default, all labels are selected.
        If a single dimension is passed, the labels are all applied
        to this dimension.
    on_groupings : bool, optional
        Whether to use the groupings to select the labels.
        This matters only when the domestic_only argument is set to True.
    domestic_only : bool, optional
        If yes, only domestic transactions are set to 0 and trade flows
        are left untouched. The default is False.
    axis : list of ints, optional
        Axis along which the extraction is done. The default is "all".
        In any case, the extraction only applies to axes holding
        all the dimensions of the selection.

    Raises
    ------
    ValueError
        If the number of labels does not match the dimensions,
        or if the extraction cannot be done on the selected axes.

    Returns
    -------
    Part object
        New Part with selection set to 0.
        The current Part and the input labels are left untouched.

    """
    if isinstance(dimensions, str):
        dimensions = [dimensions]
    if labels is None or (isinstance(labels, str) and labels == "all"):
        #Default selection: all labels of each dimension
        labels = ["all"]*len(dimensions)
    elif isinstance(labels, str):
        labels = [labels]

    if isinstance(dimensions, dict):
        to_select = dict(dimensions)
        labels = list(to_select.values())
        dimensions = list(to_select.keys())
    else:
        #Work on copies to leave the arguments of the caller untouched
        dimensions = list(dimensions)
        labels = list(labels)
        if len(labels) != len(dimensions):
            if len(dimensions) == 1:
                #If only one dimension is passed, we broadcast the labels
                labels = [labels]
            else:
                #Raise an error for ambiguous cases
                log.critical("Number of dimensions and labels do not match for extraction")
                raise ValueError("Number of dimensions and labels do not match for extraction")
        #The selection is built once the labels are aligned with the
        #dimensions, such that broadcast labels are all selected
        to_select = dict(zip(dimensions, labels))

    allowed = [
        i for i, ax in enumerate(self.axes)
        if all(dimension in ax.dimensions for dimension in dimensions)
    ]
    if len(allowed) == 0:
        if len(dimensions) == 1:
            log.critical("No axis found for extraction on "+str(dimensions))
            raise ValueError("No axis found for extraction on "+str(dimensions))
        log.info(f"No axis found for simultaneous extractions on {dimensions}")
        log.info(f"Try successive extractions on {dimensions}")
        #Chain the extractions: each one applies to the output of the last
        output = self
        for dim, label in zip(dimensions, labels):
            output = output.extraction(dim, label,
                                       on_groupings=on_groupings,
                                       domestic_only=domestic_only,
                                       axis=axis)
        return output

    if isinstance(axis, str) and axis == "all":
        log.info(f"Extract {to_select} on axes "+ str(allowed))
        axis = allowed
    if isinstance(axis, int):
        axis = [axis]
    wrong = [ax for ax in axis if ax not in allowed]
    if len(wrong) > 0:
        log.critical(f"Cannot extract {dimensions} on axis {wrong}")
        raise ValueError(f"Cannot extract {dimensions} on axis {wrong}")

    if not domestic_only:
        #If no domestic_only, we can simply set the selection to 0
        sel = [to_select if i in axis else "all" for i in range(self.ndim)]
        output = self.copy()
        output[sel] = 0
        return output

    #Resolve the labels to browse for each dimension, in new lists
    resolved = []
    for i, (dim, label) in enumerate(zip(dimensions, labels)):
        if isinstance(label, str):
            if label == "all":
                reference = self.axes[axis[i] if i < len(axis) else axis[0]]
                if on_groupings and dim in self.groupings:
                    label = list(reference.groupings[dim].keys())
                else:
                    label = list(reference.labels[dim])
            else:
                #A single label is browsed as a whole, not by character
                label = [label]
        else:
            label = list(label)
            if not on_groupings and dim in self.groupings:
                #If no groupings, develop the selected groupings:
                #a grouping name is replaced by the content of the grouping
                label = [
                    self.groupings[dim][item]
                    if item in self.groupings[dim] else item
                    for item in label
                ]
        resolved.append(label)

    output = self.copy()
    for combination in itertools.product(*resolved):
        #Iteratively set domestic selections to 0
        seldict = dict(zip(dimensions, combination))
        sel = [seldict if i in axis else "all" for i in range(self.ndim)]
        output[sel] = 0
    return output
859
-
860
def leontief_inversion(self):
    """
    Compute the Leontief inverse of the Part.

    The inverse of (identity - data) is evaluated with numpy.

    Raises
    ------
    ValueError
        If the Part is not a square 2D Part.

    Returns
    -------
    Part object
        Alias of the current Part named "l_<name>" holding the inverse.
    """
    # Guard clause: only square 2D parts can be inverted
    if self.ndim != 2 or not self.issquare():
        raise ValueError("Can only compute the Leontief inverse on square parts")
    size = len(self.axes[0])
    inverse = np.linalg.inv(np.identity(size) - self.data)
    return self.alias(name=f"l_{self.name}", data=inverse)
866
-
867
def zone(self):
    """
    Apply a grouping by zone dependent on the shape of the Part.

    Deprecated: the result depends on the shape of the Part.
    Use group or expand instead.

    Non-square 2D Parts whose second axis has a "countries" dimension
    are grouped on that axis.
    1D Parts are expanded over "countries".

    Raises
    ------
    AttributeError
        Parts with other shapes are rejected.

    Returns
    -------
    Part object
        Grouped or expanded part.

    """
    log.warning("This function is deprecated as it returns different "+\
                "results depending on the shape of the Part. "+\
                "Use group or expand instead.")
    if self.ndim == 2 and not self.issquare():
        if "countries" in self.axes[1].dimensions:
            #Group non-square 2D parts on their second axis
            return self.group(1,"countries")
    if self.ndim == 1:
        #Expand 1D parts.
        #The dimension must be passed as `over`: the first positional
        #argument of expand is the axis, not the dimension name.
        return self.expand(over="countries")
    raise AttributeError(f"Part {self.name} has no predefined grouping.")
896
-
897
def update_groupings(self,groupings,ax=None):
    """Update the groupings of the current Part object

    Parameters
    ----------
    groupings : dict
        Description of the groupings
    ax : int, list of int, optional
        Axes to update. If left empty, all axes are updated.
    """
    self.groupings = groupings
    if ax is None:
        ax = range(self.ndim)
    elif isinstance(ax,int):
        #A single axis index is accepted, as documented
        ax = [ax]
    for axe in ax:
        self.axes[axe].update_groupings(groupings)
910
-
911
def aggregate(self,on="countries",axis=None):
    """Aggregate a dimension along one or several axes.

    The aggregation relies on the groupings defined for the dimension.
    To sum over a dimension of an axis, use the sum method instead.

    Parameters
    ----------
    on : str or list of str, optional
        Dimension to aggregate on. The default is "countries".
        If a list is passed, the dimensions are aggregated one after
        the other.
    axis : int or list of int, optional
        Axes along which the aggregation is applied.
        If left empty, every axis holding the dimension is aggregated.

    Raises
    ------
    ValueError
        Raised if no groupings are defined for the dimension.

    Returns
    -------
    Part object
        Aggregated Part, named "<on>_grouped_<name>".

    """
    log.debug(f"Aggregate Part {self.name} along axis {axis} on {on}")

    if isinstance(on,list):
        #Chain the aggregations, one dimension at a time
        current = self
        for dimension in on:
            current = current.aggregate(on = dimension, axis=axis)
        return current

    if on not in self.groupings:
        raise ValueError(f"No groupings defined for dimensions {on}")

    targets = self.hasax(on) if axis is None else axis
    if isinstance(targets,int):
        targets = [targets]

    result = self.alias()
    for index in targets:
        result = result.aggregate_on(on,index)

    result.name = f"{on}_grouped_{self.name}"
    return result
965
-
966
def aggregate_on(self,on,axis):
    """Aggregate a Part along a given axis

    The labels of dimension `on` are replaced by the keys of the
    groupings of the axis for that dimension.

    Parameters
    ----------
    on : str
        Dimension to aggregate on
    axis : int
        Axis to aggregate

    Raises
    ------
    ValueError
        If `on` is not a dimension of the selected axis.

    Returns
    -------
    Part instance
        Aggregated Part, named "<on>_grouped_<name>".
    """
    if on not in self.axes[axis].dimensions:
        raise ValueError(f"Dimension {on} not found in axis {axis}")

    #Labels of the aggregated axis: group names replace the labels of `on`
    new_labels = self.axes[axis].labels.copy()
    new_labels[on] = list(self.axes[axis].groupings[on].keys())
    #Each group of `on` now maps to itself only
    new_groupings = self.groupings.copy()
    new_groupings[on] = {
        item : [item] for item in self.groupings[on]
        }

    new_axis = self.axes.copy()
    new_axis[axis] = Axe(new_labels,new_groupings)
    new_shape = [len(ax) for ax in new_axis]

    #Empty Part with the target axes, filled group by group below
    output = Part(axes=new_axis)
    idsum = new_axis[axis].dimensions.index(on) #Index of the dimension to sum on
    ref_dev = self.develop(axis, squeeze=False)
    new_dev = output.develop(axis,squeeze=False)
    selector = ["all"]*ref_dev.ndim
    for label in new_labels[on]:
        #Select one group in the developed parts and store its sum
        selector[axis+idsum] = label
        new_dev[selector] = ref_dev[selector].sum(
            axis=axis+idsum,
            keepdims=True
        )
    #Back to the shape given by the new axes
    output = new_dev.data.reshape(new_shape)
    return self.alias(data=output,name=f"{on}_grouped_{self.name}",
                      axes=new_axis)
1009
-
1010
-
1011
def get_labels(self,axis=None):
    """
    Return the labels of the selected axes of the Part.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis to investigate, by default None,
        all axes are investigated.

    Returns
    -------
    list
        Labels of each selected axis, in the order of the selection.
    """
    if axis is None:
        selected = range(self.ndim)
    elif isinstance(axis,int):
        #Wrap a single index so that it can be iterated
        selected = [axis]
    else:
        selected = axis
    return [self.axes[index].labels for index in selected]
1036
-
1037
def list_labels(self):
    """
    List the labels of the Part.

    Returns
    -------
    dict
        Labels by dimension name. When a dimension appears on several
        axes, the labels of the first axis holding it are kept.
    """
    merged = dict()
    for axis_labels in self.get_labels():
        for dimension, values in axis_labels.items():
            #setdefault keeps the first occurrence of each dimension
            merged.setdefault(dimension, values)
    return merged
1046
-
1047
def get_dimensions(self,axis=None):
    """
    Return the dimensions of the selected axes of the Part.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis to investigate, by default None,
        all axes are investigated.

    Returns
    -------
    list
        Dimensions of each selected axis, in the order of the selection.
    """
    if axis is None:
        selected = range(self.ndim)
    elif isinstance(axis,int):
        #Wrap a single index so that it can be iterated
        selected = [axis]
    else:
        selected = axis
    return [self.axes[index].dimensions for index in selected]
1071
-
1072
def rename_labels(self,old,new):
    """
    Rename some labels of the Part

    The renaming is delegated to every axis whose dimensions contain
    `old`; the labels stored on the Part are refreshed afterwards.

    Parameters
    ----------
    old : str
        Name of the label to change.
    new : str
        New label name.
    """
    for axe in self.axes:
        #The test is made right before each renaming, axis after axis
        if old not in axe.dimensions:
            continue
        axe.rename_labels(old,new)
    self._store_labels()
1087
-
1088
def replace_labels(self,name,labels,axis=None):
    """
    Update a label of the part

    Parameters
    ----------
    name : str
        Name of the label to update.
    labels : dict or list
        New labels for the corresponding ax.
        If a list is passed, the former label name is used.
    axis : int, list of int, optional
        List of axis on which the label is changed.
        By default None, all possible axes are updated.
    """
    if axis is None:
        targets = range(self.ndim)
    elif isinstance(axis,int):
        targets = [axis]
    else:
        targets = axis
    #A bare list is stored under the label name
    replacement = {name:labels} if isinstance(labels,list) else labels
    for index in targets:
        axe = self.axes[index]
        if name in axe.dimensions:
            axe.replace_labels(name,replacement)
    self._store_labels()
1113
-
1114
def set_labels(self,labels,axis=None):
    """
    Change the labels of the Part

    Parameters
    ----------
    labels : dict or nested list
        New labels of the axes, indexed by axis.
        If a nested list with one item per axis is passed,
        the first level corresponds to the axes.
    axis : int or list of int, optional
        Axis on which the labels are changed, by default None,
        all axes are updated.
    """
    if axis is None:
        targets = range(self.ndim)
    elif isinstance(axis,int):
        targets = [axis]
    else:
        targets = axis
    per_axis = labels
    if isinstance(labels,list) and len(labels) == self.ndim:
        #Index the nested list by axis position
        per_axis = dict(enumerate(labels))
    for index in targets:
        self.axes[index].set_labels(per_axis[index])
    self._store_labels()
1136
-
1137
def _store_labels(self):
    """Refresh the `labels` attribute from the labels of the axes"""
    current = self.list_labels()
    self.labels = current
1140
-
1141
def add_labels(self,labels,dimension=None,axes=None,
               fill_value=0):
    """
    Add indices to one or multiple Part axes.

    Parameters
    ----------
    labels : list of str or dict
        Indices to add.
        If a dict is passed, its first key gives the dimension.
    dimension : str, optional
        Dimension the new indices should be appended to,
        in case labels is not a dict.
        If labels is a dict, dimension is ignored.
    axes : int or set of ints, optional
        Axes or list of axes to modify.
        In case it is not specified, the axes are detected
        by looking for the dimension in each ax.
    fill_value : float, optional
        Value used to initialize the new Part

    Returns
    -------
    Part instance
        Part instance with the additional ax indices.

    Raises
    ------
    ValueError
        If labels is a list and no dimension is given.
    """
    if isinstance(labels,list):
        if dimension is None:
            #Without a dimension the list cannot be attached to an axis
            raise ValueError(
                "A dimension is required when labels are passed as a list"
                )
        labels = {dimension:labels}
    dimension = list(labels.keys())[0]
    if axes is None:
        #Identify the axes holding the dimension
        axes = self.hasax(dimension)
    elif isinstance(axes,int):
        axes = [axes]

    new_axes = self.axes.copy()
    sel = ["all"]*self.ndim
    for ax in axes:
        log.debug("Add labels to axis "+str(ax))
        old_labels = self.axes[ax].labels
        new_labels = old_labels.copy()
        new_labels[dimension] = (
            list(old_labels[dimension]) + list(labels[dimension])
            )
        new_axes[ax] = Axe(new_labels,self.groupings)
        #The original data is selected with the former labels
        sel[ax] = old_labels

    new_shape = [len(ax) for ax in new_axes]
    output = self.alias(data=np.full(new_shape,fill_value,dtype="float64"),
                        axes=new_axes)

    #Put original data back in place
    output[sel] = self.data
    return output
1199
-
1200
def reorder_data(self,new_labels):
    """
    Reorder the data of the Part according to new labels.

    Only the axes holding every dimension of `new_labels` are reordered.
    New labels that are unknown to an axis are ignored for that axis.
    The `new_labels` argument is left untouched.

    Parameters
    ----------
    new_labels : dict
        New labels for the axes.
        The keys are the dimensions, the values are the labels.

    Raises
    ------
    ValueError
        If new_labels is not a dict,
        or if the new labels of a dimension do not contain
        all the current labels of a reordered axis.
    """
    if not isinstance(new_labels,dict):
        raise ValueError("New labels should be a dictionary")

    sels = []
    reordered = dict() #Axis index -> labels actually applied, by dimension
    for index,axis in enumerate(self.axes):
        old_labels = axis.labels
        if not set(new_labels.keys()).issubset(set(old_labels.keys())):
            #The axis does not hold all dimensions: keep it as is
            sels.append(axis.get("all"))
            continue
        for key in new_labels.keys():
            set_old = set(old_labels[key])
            set_new = set(new_labels[key])
            if not set_old.issubset(set_new):
                raise ValueError(f"The new labels provided for dimension '{key}' is not a superset of the old labels. " +
                                 f"Old labels: {old_labels[key]}, new labels: {new_labels[key]}. "
                                 "If you want to rename the labels of this dimensions, use the method 'replace_labels() before reordering the data")

        ax_label_dict = {}
        for key in old_labels.keys():
            if key in new_labels.keys():
                #Build a filtered copy: removing items from the list
                #while iterating over it would skip elements and
                #alter the caller's argument
                known = set(old_labels[key])
                ax_label_dict[key] = [
                    lab for lab in new_labels[key] if lab in known
                    ]
            else:
                ax_label_dict[key] = old_labels[key]

        sels.append(axis.get(ax_label_dict))
        reordered[index] = {key: ax_label_dict[key] for key in new_labels}

    if len(sels) == 0:
        raise ValueError(f"None of the dimensions provided in the new labels dict {new_labels.keys()} are present "+
                         f"in the labels of part '{self.name}', which only contains the dimensions {self.get_dimensions()}")

    #Execute the selection
    self.data = self.data[np.ix_(*sels)]

    #Update the labels of the reordered axes only,
    #so that labels and data stay aligned
    for index,applied in reordered.items():
        for dim,labels in applied.items():
            self.replace_labels(name = dim, labels = labels, axis = index)
1259
-
1260
-
1261
def expand(self,axis=None,over="countries"):
    """
    Expand an axis of the Part

    A new axis holding only the `over` dimension is inserted before
    the expanded axis. The data of each `over` label is stored at the
    matching position of the new axis, the rest is filled with zeros.
    Note that this operation significantly expands the size of the Part.
    It is recommended to use this method with Extension parts only.

    Parameters
    ----------
    axis : int or list of int, optional
        Axe to extend.
        If left empty, the first suitable axe is expanded.
        If several axes are given, only the first one is expanded.
    over : str, optional
        Axe dimension to expand the Part by.
        The default is "countries".

    Raises
    ------
    ValueError
        If no axis is available for the expansion.

    Returns
    -------
    Part object
        New Part object with an additional dimension.

    """
    if axis is None:
        axis = self.hasax(over)
    if isinstance(axis,int):
        axis = [axis]
    axis = list(axis)
    if len(axis) == 0:
        raise ValueError(
            f"Cannot expand Part {self.name}: no axis holds {over}"
            )
    ax = axis[0]

    ref_ax = self.axes[ax]
    new_ax = Axe(ref_ax.labels[over],groupings=self.groupings)
    axes = self.axes.copy()
    axes.insert(ax,new_ax)
    new_shape = list(self.shape)
    #The position of the new axis is the axis index, not the axis list
    new_shape.insert(ax,len(new_ax))
    output = np.zeros(new_shape)
    selector = [slice(None)]*self.ndim
    for item in ref_ax.labels[over]:
        newsel,refsel = selector.copy(),selector.copy()
        #Position of the item on the new axis...
        newsel.insert(ax,new_ax.sel(item))
        #...and on the expanded axis, now shifted by one
        newsel[ax+1] = ref_ax.sel(item)
        refsel[ax]= ref_ax.sel(item)
        output[tuple(newsel)] = self.data[tuple(refsel)]
    return self.alias(data=output,name=f"expanded_{self.name}",axes=axes)
1303
-
1304
def issquare(self):
    """Tell whether the Part has two axes of the same length"""
    if self.ndim != 2:
        return False
    return len(self.axes[0]) == len(self.axes[1])
1307
-
1308
def hasneg(self):
    """Tell whether at least one coefficient of the Part is negative"""
    return bool(np.any(self.data<0))
1313
-
1314
def hasax(self,name=None):
    """Return the axes of the Part holding a given dimension

    If no axis can be found, an empty list is returned.
    This method can be used to assert the existence of a given dimension
    in the part.

    Parameters
    ----------
    name : str, optional
        Name of the dimension to look for.
        If no name is given, or if name is "any", all axes are returned.

    Returns
    -------
    axes : list of ints
        Indices of the axes on which the dimension is found.

    """
    if name is None or name == "any":
        return list(range(self.ndim))
    return [
        index for index,axe in enumerate(self.axes)
        if name in axe.dimensions
        ]
1340
-
1341
def __str__(self):
    """Short description: name and number of dimensions of the Part"""
    label, rank = self.name, self.ndim
    return f"{label} Part object with {rank} dimensions"
1343
-
1344
def sum(self,axis=None,on=None,keepdims=False):
    """
    Sum the Part along one or several axis, and/or on a given dimension.

    Parameters
    ----------
    axis : int or list of int, optional
        Axe along which the sum is evaluated.
        By default None, the sum of all coefficients of the Part is returned
    on : str, optional
        name of the dimension to be summed on.
        If no axis is defined, the Part is summed over all axis having
        the corresponding dimension.
        By default None, the full ax is summed
    keepdims : bool, optional
        Whether to keep the number of dimensions of the original.
        By default False, the dimensions of length 1 are removed.

    Returns
    -------
    Part instance or float
        Result of the sum.
    """
    if axis is None:
        if on is None:
            #No axis and no dimension: sum of all coefficients
            return self.data.sum()
        if not keepdims:
            #From here on, `self` refers to the squeezed Part
            self = self.squeeze()
        #Sum over every axis holding the dimension
        axis = self.hasax(on)
    if isinstance(axis,int):
        if on is not None:
            return self._sum_on(axis,on,keepdims)
        ax = self.axes.copy()
        if not keepdims:
            del ax[axis]
        else:
            #The summed axis is kept with a single "all" label
            ax[axis] = Axe(["all"])
        return self.alias(
            data=self.data.sum(axis,keepdims=keepdims),
            name=f"{self.name}_sum_{axis}",
            axes = ax
            )
    #Several axes: sum them one by one, starting with the highest index
    axis = sorted(axis)
    for ax in axis[::-1]:
        self = self.sum(ax,on,keepdims)
    return self
1390
-
1391
def _sum_on(self,axis,on,keepdims=False):
    """
    Sum a Part along an axis on a given dimension

    Parameters
    ----------
    axis : int
        Axis along which the sum is evaluated.
    on : str
        Dimension of the axis to sum on.
    keepdims : bool, optional
        Whether to keep the summed dimension in the result.
        The default is False.

    Raises
    ------
    ValueError
        If `on` is not a dimension of the axis.

    Returns
    -------
    Part instance
        Result of the sum.
    """
    ax = self.axes[axis]
    if on not in ax.dimensions:
        raise ValueError(f"Cannot sum on {on} as it is not a dimension of axis {axis}")
    if ax.levels == 1:
        #If the axis has a single level, this is a simple sum
        axes = self.axes.copy()
        if keepdims:
            #The summed axis has length 1: replace it by a single
            #"all" label, as done in sum, so axes and data agree
            axes[axis] = Axe(["all"])
        else:
            del axes[axis]
        return self.alias(
            data = self.data.sum(axis,keepdims=keepdims),
            name=f"{self.name}_sum_{axis}",
            axes = axes
            )
    #Otherwise, sum on the relevant levels
    idsum = ax.dimensions.index(on) #Index of the dimension to sum on
    dev = self.develop(axis,squeeze=False)
    dev = dev.sum(axis+idsum,keepdims=keepdims)
    if keepdims:
        dev = dev.combine_axes(axis,axis+idsum)
    dev.name = f"{self.name}_sum_on_{on}_{axis}"
    return dev
1416
-
1417
def save(self,
         file=None,
         name=None,
         extension=".npy",
         overwrite=False,
         include_labels=False,
         write_instructions=False,
         **kwargs):
    """
    Save the Part object to a file

    Parameters
    ----------
    file : path-like, optional
        Full path to the file to save the Part to.
        This overrides the path and name arguments,
        and the extension if the file name has one.
    name : str, optional
        Name under which the Part is saved.
        By default, the name of the current part is used.
    extension : str, optional
        Format under which the part is saved. The default ".npy"
        If ".nc" is chosen, the part is saved with save_to_nc.
    overwrite : bool, optional
        Whether to overwrite an existing file.
        The default is False.
    include_labels : bool, optional
        Whether to include the labels in the saved file.
        Not used for ".nc" files.
    write_instructions : bool, optional
        Whether to write the loading instructions.
        The default is False.
    **kwargs : dict
        Additional arguments to pass to the saving function.
        The "path" keyword gives the directory in which the Part is
        saved; it is required when no file is given.

    Raises
    ------
    FileNotFoundError
        If neither a file nor a path is given.
    """
    #Remove path from kwargs: it is forwarded explicitly below and
    #would otherwise be passed twice to the saving functions
    path = kwargs.pop("path",None)
    if file is not None:
        path,name = os.path.split(file)
        name,possible_extension = os.path.splitext(name)
        if possible_extension != "":
            extension = possible_extension
    if name is None:
        name = self.name
    if path is None:
        raise FileNotFoundError("No path specified for saving the Part")
    if extension == ".nc":
        path = os.path.join(path,name+extension)
        save_to_nc(self,path,overwrite,
                   write_instructions=write_instructions,
                   **kwargs)
    else:
        save_part_to_folder(
            self,
            path = path,
            name = name,
            extension = extension,
            overwrite = overwrite,
            include_labels=include_labels,
            write_instructions = write_instructions,
            **kwargs
            )
1486
-
1487
def to_pandas(self):
    """Convert the current Part object with the pandas converter

    Only applicable to Parts objects with 1 or 2 dimensions.
    """
    convert = converters.pandas.to_pandas
    return convert(self)
1493
-
1494
def to_xarray(self):
    """
    Convert the Part object to an xarray DataArray

    The conversion is delegated to the xarray converter.
    Labels are directly passed to the DataArray as coords.
    Note that data will be flattened.
    The dimension order will be saved as an attribute.
    If you're loading the data back,
    the Part will be automatically reshaped to its original dimensions.

    Returns
    -------
    xr.DataArray
        Corresponding DataArray
    """
    convert = converters.xarray.to_DataArray
    return convert(self)
1510
-
1511
def mean(self,axis=None):
    """Return the mean of the data, along `axis` if given"""
    values = self.data
    return values.mean(axis)
1513
-
1514
def min(self,axis=None):
    """Return the minimum of the data, along `axis` if given"""
    values = self.data
    return values.min(axis)
1516
-
1517
def max(self,axis=None):
    """Return the maximum of the data, along `axis` if given"""
    values = self.data
    return values.max(axis)
1519
-
1520
def mul(self,a,propagate_labels=True):
    """
    Matrix multiplication between parts with labels propagation

    Parameters
    ----------
    a : Part or array-like
        Right-hand multiplicator.
    propagate_labels : bool, optional
        Whether to try propagating the labels from the right hand multiplicator
        By default True.
        If right-hand multiplicator is not a Part object, becomes False.

    Returns
    -------
    Part instance
        result of the multiplication
    """
    if isinstance(a,Part):
        name = a.name
        right = a.data
    else:
        propagate_labels = False
        name="array"
        #Convert explicitly: `.data` of an array is a raw buffer,
        #and other array-likes (e.g. lists) have no such attribute
        right = np.asarray(a)
    data = np.matmul(self.data,right)
    #All axes but the last come from the left-hand Part
    axes = [self.axes[i] for i in range(self.ndim-1)]
    #All axes but the first come from the right-hand multiplicator
    for ax in range(right.ndim-1):
        if propagate_labels:
            axes.append(a.axes[ax+1])
        else:
            axes.append(Axe([i for i in range(right.shape[ax+1])]))
    return self.alias(data=data,name=f"{self.name}.{name}",axes=axes)
1551
-
1552
def filter(self,threshold,fill_value=0):
    """
    Replace the values below a given threshold

    The data of the current Part is left unchanged.

    Parameters
    ----------
    threshold : float
        Threshold value.
    fill_value : float, optional
        Value to replace the filtered values with.
        The default is 0.

    Returns
    -------
    Part instance
        Filtered Part, named "filtered_<name>_<threshold>".

    """
    filtered = self.data.copy()
    below = filtered<threshold
    filtered[below] = fill_value
    return self.alias(data=filtered,name=f"filtered_{self.name}_{threshold}")
1573
-
1574
def diag(self):
    """
    Diagonalize a 1D Part

    Parts with more dimensions are squeezed first.

    Raises
    ------
    ValueError
        If the Part is not 1D, even after squeezing.

    Returns
    -------
    Part instance
        2D Part with the data on the diagonal,
        both axes being the axis of the 1D Part.
    """
    if self.ndim == 1:
        log.info("Diagonalize a 1D part")
        return self.alias(data=np.diag(self.data),
                          name=f"diag_{self.name}",
                          axes = self.axes*2)
    log.info("The part has too many dimensions: try to diagonalize the squeezed part")
    squeezed = self.squeeze()
    if squeezed.ndim != 1:
        #Stop here: recursing on a part that squeezing cannot reduce
        #to 1D would only end at the recursion limit
        raise ValueError("Cannot diagonalize a part with more than 2 dimensions")
    return squeezed.diag()
1585
-
1586
def __add__(self,a):
    """Add a Part, an array or a scalar to the Part

    If the operand is an array with a different number of dimensions,
    both operands are squeezed before the addition.
    """
    if isinstance(a,Part):
        other_name, operand = a.name, a.data
    else:
        other_name, operand = "", a
    left = self
    if isinstance(operand,np.ndarray) and left.ndim != operand.ndim:
        operand = operand.squeeze()
        left = left.squeeze()
    return left.alias(data=operand+left.data,name=f"{left.name}+{other_name}")
1596
-
1597
def __radd__(self,a):
    """Right-hand addition, delegated to __add__"""
    total = self.__add__(a)
    return total
1599
-
1600
def __rmul__(self,a):
    """Right-hand multiplication, delegated to __mul__"""
    product = self.__mul__(a)
    return product
1602
-
1603
def __mul__(self,a):
    """Element-wise multiplication by a Part, an array or a scalar

    The result is named after both operands.
    """
    if isinstance(a,Part):
        #f-string: the names of the operands are interpolated
        name = f"{a.name}*{self.name}"
        a = a.data
    else:
        if isinstance(a,int):
            name = f"{a}*{self.name}"
        else:
            name = f"array*{self.name}"
    data = self.data*a
    if data.ndim!=self.ndim:
        data = data.squeeze()
    #Trust numpy to broadcast the multiplication
    #Squeeze to get rid of unused dimensions
    return self.alias(data=data,
                      name=name)
1619
-
1620
def __neg__(self):
    """Return the opposite of the Part, named "-<name>" """
    opposite = -self.data
    return self.alias(data=opposite,name=f"-{self.name}")
1622
-
1623
def __sub__(self,a):
    """Subtract a Part, an array or a scalar from the Part"""
    if isinstance(a,Part):
        other_name, operand = a.name, a.data
    else:
        other_name, operand = "", a
    return self.alias(data=self.data-operand,name=f"{self.name}-{other_name}")
1630
-
1631
def __rsub__(self,a):
    """Subtract the Part from a Part, an array or a scalar"""
    if isinstance(a,Part):
        other_name, operand = a.name, a.data
    else:
        other_name, operand = "", a
    return self.alias(data=operand-self.data,name=f"{other_name}-{self.name}")
1638
-
1639
def power(self,a):
    """
    Raise the Part to a power, element-wise

    Parameters
    ----------
    a : Part, array or scalar
        Exponent.

    Returns
    -------
    Part instance
        Result, named "<name>**<exponent>".
    """
    if isinstance(a,Part):
        #Read the name before replacing the Part by its data
        name = f"{self.name}**{a.name}"
        a = a.data
    elif isinstance(a,(int,float)):
        name = f"{self.name}**{a}"
    else:
        name = f"{self.name}**array"
    #The power applies to the data, not to the Part object itself
    return self.alias(data=np.power(self.data,a),name=name)
1648
-
1649
def __pow__(self,a):
    """Power operator, delegated to the power method"""
    raised = self.power(a)
    return raised
1651
-
1652
def __eq__(self,a):
    """Compare the data of two Parts

    Objects that are not Parts are never equal to a Part.
    Only the data is compared.
    """
    if not isinstance(a,Part):
        return False
    return np.all(self.data==a.data)
1656
-
1657
def __rtruediv__(self,a):
    """Divide a Part, an array or a scalar by the Part (a/self)

    A warning is logged if the Part holds zeros.
    """
    if isinstance(a,Part):
        #The operation is a/self: the name follows the same order
        name = f"{a.name}/{self.name}"
        a = a.data
    else:
        if isinstance(a,int):
            name = f"{a}/{self.name}"
        else:
            name= f"array/{self.name}"
    if np.any(self.data==0):
        log.warning("Division by zero in "+name)
    return self.alias(data=a/self.data,
                      name=name)
1670
-
1671
def __truediv__(self,a):
    """Divide the Part by a Part, an array or a scalar (self/a)

    A warning is logged if the divisor holds zeros.
    """
    if isinstance(a,Part):
        #The operation is self/a: the name follows the same order
        name = f"{self.name}/{a.name}"
        a = a.data
    else:
        if isinstance(a,int):
            name = f"{self.name}/{a}"
        else:
            name= f"{self.name}/array"
    if np.any(a==0):
        log.warning("Division by zero in "+name)
    return self.alias(data=self.data/a,
                      name=name)
1684
-
1685
def __getattr__(self,name):
    """
    Look for missing attributes in the metadata of the Part

    The lookup is case-insensitive: the name is casefolded first.

    Raises
    ------
    AttributeError
        If the casefolded name is not a key of the metadata.
    """
    name = name.casefold()
    try:
        #Plain attribute lookup: going through `self.metadata` would
        #call __getattr__ again when the metadata is not set yet
        metadata = object.__getattribute__(self,"metadata")
        return metadata[name]
    except (AttributeError,KeyError,TypeError):
        pass
    raise AttributeError(f"Attribute {name} not found")
1692
-
1693
def transpose(self):
    """Return the transposed Part, with the axes in reverse order"""
    flipped_axes = self.axes[::-1]
    flipped_data = self.data.transpose()
    return self.alias(data=flipped_data,
                      name=f"transposed_{self.name}",
                      axes=flipped_axes)
1697
-
1698
-