mrio-toolbox 1.1.2__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mrio-toolbox might be problematic (see the registry page for more details).

Files changed (61)
  1. {mrio_toolbox-1.1.2.dist-info → mrio_toolbox-1.1.3.dist-info}/METADATA +1 -1
  2. mrio_toolbox-1.1.3.dist-info/RECORD +5 -0
  3. mrio_toolbox-1.1.3.dist-info/top_level.txt +1 -0
  4. __init__.py +0 -21
  5. _parts/_Axe.py +0 -539
  6. _parts/_Part.py +0 -1739
  7. _parts/__init__.py +0 -7
  8. _parts/part_operations.py +0 -57
  9. extractors/__init__.py +0 -20
  10. extractors/downloaders.py +0 -36
  11. extractors/emerging/__init__.py +0 -3
  12. extractors/emerging/emerging_extractor.py +0 -117
  13. extractors/eora/__init__.py +0 -3
  14. extractors/eora/eora_extractor.py +0 -132
  15. extractors/exiobase/__init__.py +0 -3
  16. extractors/exiobase/exiobase_extractor.py +0 -270
  17. extractors/extractors.py +0 -81
  18. extractors/figaro/__init__.py +0 -3
  19. extractors/figaro/figaro_downloader.py +0 -280
  20. extractors/figaro/figaro_extractor.py +0 -187
  21. extractors/gloria/__init__.py +0 -3
  22. extractors/gloria/gloria_extractor.py +0 -202
  23. extractors/gtap11/__init__.py +0 -7
  24. extractors/gtap11/extraction/__init__.py +0 -3
  25. extractors/gtap11/extraction/extractor.py +0 -129
  26. extractors/gtap11/extraction/harpy_files/__init__.py +0 -6
  27. extractors/gtap11/extraction/harpy_files/_header_sets.py +0 -279
  28. extractors/gtap11/extraction/harpy_files/har_file.py +0 -262
  29. extractors/gtap11/extraction/harpy_files/har_file_io.py +0 -974
  30. extractors/gtap11/extraction/harpy_files/header_array.py +0 -300
  31. extractors/gtap11/extraction/harpy_files/sl4.py +0 -229
  32. extractors/gtap11/gtap_mrio/__init__.py +0 -6
  33. extractors/gtap11/gtap_mrio/mrio_builder.py +0 -158
  34. extractors/icio/__init__.py +0 -3
  35. extractors/icio/icio_extractor.py +0 -121
  36. extractors/wiod/__init__.py +0 -3
  37. extractors/wiod/wiod_extractor.py +0 -143
  38. mrio.py +0 -899
  39. mrio_toolbox-1.1.2.dist-info/RECORD +0 -59
  40. mrio_toolbox-1.1.2.dist-info/top_level.txt +0 -6
  41. msm/__init__.py +0 -6
  42. msm/multi_scale_mapping.py +0 -863
  43. utils/__init__.py +0 -3
  44. utils/converters/__init__.py +0 -5
  45. utils/converters/pandas.py +0 -244
  46. utils/converters/xarray.py +0 -132
  47. utils/formatting/__init__.py +0 -0
  48. utils/formatting/formatter.py +0 -527
  49. utils/loaders/__init__.py +0 -7
  50. utils/loaders/_loader.py +0 -312
  51. utils/loaders/_loader_factory.py +0 -96
  52. utils/loaders/_nc_loader.py +0 -184
  53. utils/loaders/_np_loader.py +0 -112
  54. utils/loaders/_pandas_loader.py +0 -128
  55. utils/loaders/_parameter_loader.py +0 -386
  56. utils/savers/__init__.py +0 -11
  57. utils/savers/_path_checker.py +0 -37
  58. utils/savers/_to_folder.py +0 -165
  59. utils/savers/_to_nc.py +0 -60
  60. {mrio_toolbox-1.1.2.dist-info → mrio_toolbox-1.1.3.dist-info}/WHEEL +0 -0
  61. {mrio_toolbox-1.1.2.dist-info → mrio_toolbox-1.1.3.dist-info}/licenses/LICENSE +0 -0
_parts/_Part.py DELETED
@@ -1,1739 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Fri Apr 14 14:13:17 2023
4
-
5
- @author: beaufils
6
- """
7
-
8
- import os
9
- import itertools
10
- import numpy as np
11
- import pandas as pd
12
- import xarray as xr
13
- import copy
14
- from mrio_toolbox._parts._Axe import Axe
15
- import logging
16
- from mrio_toolbox.utils import converters
17
- from mrio_toolbox.utils.loaders import make_loader
18
- from mrio_toolbox.utils.savers import save_part_to_folder,save_to_nc
19
- from mrio_toolbox._parts import part_operations
20
-
21
log = logging.getLogger(__name__)


def load_part(**kwargs):
    """
    Load a Part through the loader selected from the keyword arguments.

    The same keyword arguments are forwarded to ``make_loader`` (to pick
    the loader) and to the loader's ``load_part`` method; the resulting
    mapping is unpacked into the ``Part`` constructor.

    Returns
    -------
    Part
    """
    part_loader = make_loader(**kwargs)
    loaded = part_loader.load_part(**kwargs)
    return Part(**loaded)
28
-
29
- class Part:
30
- """
31
- Representation of an MRIO Part object.
32
-
33
- MRIO Parts are the basic building blocks of the MRIO toolbox. A Part is
34
- built from a numpy array and a set of Axes, corresponding to the dimensions
35
- of the array. The Axes hold the labels of the Part in the different
36
- dimensions and are used to perform advanced indexing and operations on the Part.
37
-
38
- Axes support multi-level indexing and groupings.
39
-
40
- Instance variables
41
- ------------------
42
- data : numpy.ndarray
43
- Numerical data of the Part.
44
- axes : list of Axe instances
45
- Axes corresponding to the dimensions of the Part.
46
- groupings : dict
47
- Groupings of the labels of the Part, for each label defined.
48
- metadata : dict
49
- Additional metadata of the Part (e.g., path, name, multiplier, unit).
50
- name : str
51
- Name of the Part.
52
- ndim : int
53
- Number of dimensions of the Part.
54
- shape : tuple
55
- Shape of the Part.
56
-
57
- Methods
58
- -------
59
- __init__(data=None, labels=None, axes=None, **kwargs):
60
- Initialize a Part object.
61
- alias(**kwargs):
62
- Create a new Part with modified parameters.
63
- fix_dims(skip_labels=False, skip_data=False):
64
- Align the number of axes with the number of dimensions.
65
- get(*args, aspart=True, squeeze=False):
66
- Extract data from the Part object.
67
- setter(value, *args):
68
- Change the value of a data selection.
69
- develop(axis=None, on=None, squeeze=True):
70
- Reshape a Part to avoid double labels.
71
- reformat(new_dimensions):
72
- Reshape a Part to match a new dimensions combination.
73
- combine_axes(start=0, end=None, in_place=False):
74
- Combine axes of a Part into a single one.
75
- swap_axes(axis1, axis2):
76
- Swap two axes of a Part.
77
- swap_ax_levels(axis, dim1, dim2):
78
- Swap two levels of an axis.
79
- flatten(invert=False):
80
- Flatten a 2D Part into a 1D Part.
81
- squeeze():
82
- Remove dimensions of length 1 from the Part.
83
- expand_dims(axis, copy=None):
84
- Add dimensions to a Part instance.
85
- copy():
86
- Return a copy of the current Part object.
87
- extraction(dimensions, labels=["all"], on_groupings=True, domestic_only=False, axis="all"):
88
- Set labels over dimension(s) to 0.
89
- leontief_inversion():
90
- Compute the Leontief inverse of a square Part.
91
- update_groupings(groupings, ax=None):
92
- Update the groupings of the current Part object.
93
- aggregate(on="countries", axis=None):
94
- Aggregate dimensions along one or several axes.
95
- aggregate_on(on, axis):
96
- Aggregate a Part along a given axis.
97
- get_labels(axis=None):
98
- Returns a list with the labels of the axes as dictionaries.
99
- list_labels():
100
- List the labels of the Part.
101
- get_dimensions(axis=None):
102
- Return the list of dimensions of the Part.
103
- rename_labels(old, new):
104
- Rename some labels of the Part.
105
- replace_labels(name, labels, axis=None):
106
- Update a label of the Part.
107
- set_labels(labels, axis=None):
108
- Change the labels of the Part.
109
- add_labels(labels, dimension=None, axes=None, fill_value=0):
110
- Add indices to one or multiple Part axes.
111
- expand(axis=None, over="countries"):
112
- Expand an axis of the Part.
113
- issquare():
114
- Check whether the Part is square.
115
- hasneg():
116
- Check whether the Part has negative elements.
117
- hasax(name=None):
118
- Return the dimensions along which a Part has given labels.
119
- sum(axis=None, on=None, keepdims=False):
120
- Sum the Part along one or several axes or on a given dimension.
121
- save(file=None, name=None, extension=".npy", overwrite=False, include_labels=False, write_instructions=False, **kwargs):
122
- Save the Part object to a file.
123
- to_pandas():
124
- Return the current Part object as a Pandas DataFrame.
125
- to_xarray():
126
- Save the Part object to an xarray DataArray.
127
- mean(axis=None):
128
- Compute the mean of the Part along a given axis.
129
- min(axis=None):
130
- Compute the minimum value of the Part along a given axis.
131
- max(axis=None):
132
- Compute the maximum value of the Part along a given axis.
133
- mul(a, propagate_labels=True):
134
- Perform matrix multiplication between Parts with label propagation.
135
- filter(threshold, fill_value=0):
136
- Set to 0 the values below a given threshold.
137
- diag():
138
- Create a diagonal Part from a 1D Part or extract the diagonal of a 2D Part.
139
- transpose():
140
- Transpose the Part object.
141
- """
142
-
143
def __init__(self, data=None,
             labels=None,
             axes=None,
             **kwargs):
    """
    Build a Part from data, labels and/or pre-built axes.

    Parameters
    ----------
    data : numpy.ndarray, Part, xarray.DataArray, xarray.Dataset or pandas.DataFrame, optional
        Numerical content. xarray and pandas inputs are converted with the
        ``converters`` module and the constructor is re-entered with the
        result. For a Part input, only its ``data`` array is used.
        When omitted, a zero array shaped after the axes is created.
    labels : list, tuple or dict, optional
        Labels used to build the axes when ``axes`` is not given.
    axes : list of Axe, optional
        Pre-built axes, used as they are.
    **kwargs
        ``name`` (default ``"new_part"``), ``groupings``, ``metadata`` and
        any other entry; everything except ``name`` is merged into
        ``self.metadata``.

    Raises
    ------
    ValueError
        If an axis length differs from the matching data dimension (data
        dimensions of length 1 are tolerated), or if neither data, axes
        nor labels are provided.
    TypeError
        If the labels have an unsupported type.
    """
    if data is not None:
        if isinstance(data, Part):
            data = data.data
        if isinstance(data, (xr.DataArray, xr.Dataset)):
            # Convert, then rebuild the Part from the converter's output
            self.__init__(**converters.xarray.make_part(data))
            return
        if isinstance(data, pd.DataFrame):
            self.__init__(**converters.pandas.make_part(data))
            return
        # Must be set before _create_axis, which looks for "data" in __dict__
        self.data = data
        self.ndim = data.ndim
        self.shape = data.shape

    self.name = kwargs.pop("name", "new_part")
    log.debug("Create Part instance " + self.name)

    self.groupings = kwargs.get("groupings", dict())
    # NOTE(review): "groupings" and "metadata" are read with get() rather
    # than pop(), so both are also stored inside self.metadata (including a
    # nested copy of the incoming metadata dict).
    incoming_metadata = kwargs.get("metadata", dict())
    self.metadata = {**incoming_metadata, **kwargs}

    if axes is not None:
        self.axes = axes
    else:
        self.axes = []
        self._create_axis(labels)

    if data is None:
        log.debug("Create empty Part")
        data = np.zeros([len(ax) for ax in self.axes])
        self.data = data
        self.ndim = data.ndim
        self.shape = data.shape

    self.fix_dims()
    for dim in range(self.ndim):
        n_labels, n_values = len(self.axes[dim]), self.shape[dim]
        if n_values != 1 and n_labels != n_values:
            message = (f"Length of label {dim} does not match data: "
                       f"{n_labels} and {n_values}")
            log.critical(message)
            raise ValueError(message)

    if "_original_dimensions" in self.metadata:
        # Parts read from netcdf files carry their dimension layout as a
        # flat list with "_sep_" markers between axes, because netcdf
        # attributes cannot hold nested lists. Decode it and reformat the
        # Part if the layout differs.
        log.info("Checking if a reformatting is needed.")
        encoded = self.metadata.pop("_original_dimensions")
        target_dims = [[]]
        for token in encoded:
            if token == "_sep_":
                target_dims.append([])
            else:
                target_dims[-1].append(token)

        if target_dims != self.get_dimensions():
            log.info("Reformat the Part")
            reformatted = self.reformat(target_dims)
            self.data = reformatted.data
            self.axes = reformatted.axes
            self.ndim = self.data.ndim
            self.shape = self.data.shape

    self._store_labels()
246
-
247
-
248
def alias(self, **kwargs):
    """
    Build a new Part from this one, overriding only the given parameters.

    Parameters
    ----------
    **kwargs
        Any of ``data``, ``name``, ``groupings``, ``axes``, ``labels`` and
        ``metadata``. Missing entries default to this Part's data,
        ``<name>_alias``, groupings, axes, no labels and metadata.
        ``labels`` only take effect when ``axes=None`` is passed explicitly,
        because the constructor uses ``axes`` whenever it is not None.

    Returns
    -------
    Part
        New Part. Its data array is always copied; the axes list and
        groupings are shared with this Part unless overridden.
    """
    def override(key, fallback):
        return kwargs.get(key, fallback)

    return Part(
        data=override("data", self.data).copy(),
        name=override("name", self.name + "_alias"),
        groupings=override("groupings", self.groupings),
        axes=override("axes", self.axes),
        labels=override("labels", None),
        metadata=override("metadata", self.metadata),
    )
266
-
267
def _create_axis(self, labels):
    """
    Populate ``self.axes`` with one Axe per dimension.

    Parameters
    ----------
    labels : None, list, tuple or dict
        - None: every axis is labelled with integer positions (requires data).
        - list/tuple whose first item is a str, int or float: the labels of
          a single axis.
        - list/tuple of label collections: one entry per axis.
        - dict: axis names mapped to their labels, in axis order.
        When data is already set, dimensions without a label entry are
        labelled with integer positions.

    Raises
    ------
    TypeError
        If ``labels`` is not None, a list, a tuple or a dict.
    ValueError
        If neither data nor labels are available.

    Returns
    -------
    None.
    """
    self.axes = []

    # Reject unsupported types up front: the original check came after
    # len(labels) and built its message with str + type, which crashed.
    if labels is not None and not isinstance(labels, (dict, list, tuple)):
        message = f"Unknown label type: {type(labels)}"
        log.critical(message)
        raise TypeError(message)

    if isinstance(labels, (tuple, list)) and len(labels) > 0 and \
            isinstance(labels[0], (str, int, float)):
        # A flat list of scalars describes a single axis
        labels = [labels]

    if "data" in self.__dict__:
        n_axes = self.ndim
    elif labels is None:
        raise ValueError("Cannot create axes without data, axes or labels")
    else:
        n_axes = len(labels)

    names = list(labels.keys()) if isinstance(labels, dict) else None
    for dim in range(n_axes):
        # ">=" (not ">"): labels[len(labels)] would be out of range
        if labels is None or dim >= len(labels):
            # No label for this dimension: use integer positions
            axe = Axe(list(range(self.shape[dim])), groupings=self.groupings)
        elif names is not None:
            axe = Axe(labels[names[dim]], groupings=self.groupings,
                      name=names[dim])
        else:
            axe = Axe(labels[dim], groupings=self.groupings)
        self.axes.append(axe)
        log.debug(f"Create ax {dim} with len {len(axe)}")
328
-
329
def fix_dims(self,
             skip_labels=False,
             skip_data=False):
    """
    Reconcile the number of axes with the number of data dimensions.

    If there are more axes than dimensions, axes of length 1 are dropped.
    If there are more dimensions than axes, the data is squeezed. After
    each adjustment the check is repeated, and each adjustment is tried at
    most once.

    Parameters
    ----------
    skip_labels : bool, optional
        Do not try to drop length-1 axes. The default is False.
    skip_data : bool, optional
        Do not try to squeeze the data. The default is False.

    Raises
    ------
    IndexError
        If the counts still differ after the allowed adjustments.
    """
    if len(self.axes) == self.ndim:
        return
    log.warning(
        f"The number of axes ({len(self.axes)})"
        f" does not match data dimensions ({self.ndim})"
    )

    if len(self.axes) > self.ndim and not skip_labels:
        log.debug("Try to squeeze axe(s) of len 1")
        # Python lists cannot be indexed with a boolean mask: filter explicitly
        self.axes = [ax for ax in self.axes if len(ax) != 1]
        return self.fix_dims(skip_labels=True, skip_data=skip_data)

    if self.ndim > len(self.axes) and not skip_data:
        self.data = self.data.squeeze()
        self.shape = self.data.shape
        self.ndim = self.data.ndim
        return self.fix_dims(skip_labels=skip_labels, skip_data=True)

    # f-string instead of str + int concatenation, which raised TypeError
    message = (f"Cannot reconcile data of dims {self.ndim}"
               f" with axes of dim {len(self.axes)}")
    log.critical(message)
    raise IndexError(message)
364
-
365
-
366
def __getitem__(self, args):
    """
    Support ``part[...]`` selection by delegating to ``get``.

    A single str, int, numpy integer or dict is passed to ``get`` as one
    selection. Anything else (typically a tuple) is unpacked into one
    selection per axis.
    """
    if not isinstance(args, (str, int, np.integer, dict)):
        return self.get(*args)
    return self.get(args)
370
-
371
def __setitem__(self, args, value):
    """
    Support ``part[...] = value`` by delegating to ``setter``.

    A Part value contributes only its data array. A single str, int,
    numpy integer or dict key is one selection; any other key is unpacked
    into one selection per axis.
    """
    payload = value.data if isinstance(value, Part) else value
    key = (args,) if isinstance(args, (str, int, np.integer, dict)) else args
    self.setter(payload, *key)
377
-
378
def setter(self, value, *args):
    """
    Assign ``value`` to a selection of the Part, in place.

    Parameters
    ----------
    value : float or numpy-like
        Value to write. An array with more dimensions than the selection is
        squeezed; one with fewer dimensions is reshaped to the selection
        lengths.
    *args
        Selections, one per axis starting from the first. If an Axe.get
        call raises IndexError or ValueError, the whole tuple is passed to
        the first axis instead.

    Returns
    -------
    None.
        ``self.data`` is modified in place.
    """
    try:
        # One selection per axis, starting from the first
        indices = [self.axes[position].get(key)
                   for position, key in enumerate(args)]
    except (IndexError, ValueError):
        # Fall back to reading every argument on the first axis
        indices = [self.axes[0].get(args)]

    if isinstance(value, (np.ndarray, Part)) and value.ndim != len(indices):
        if value.ndim > len(indices):
            value = value.squeeze()
        if value.ndim < len(indices):
            value = np.reshape(value, [len(idx) for idx in indices])
    self.data[np.ix_(*indices)] = value
411
-
412
def get(self, *args, aspart=True, squeeze=False):
    """
    Select a sub-block of the Part.

    Parameters
    ----------
    *args
        Selections, one per axis starting from the first. If they cannot be
        read that way (an Axe.get call or Axe construction raises
        IndexError or ValueError), the whole tuple is applied to the first
        axis instead. Axes without a selection are taken in full.
    aspart : bool, optional
        Return a Part (default) instead of the raw numpy array.
    squeeze : bool, optional
        Drop dimensions of length 1 from the result. The default is False.

    Returns
    -------
    Part or numpy.ndarray
        The selection, named ``sel_<name>`` when returned as a Part.
    """
    def take(axe, key, indices, new_axes):
        # Resolve one axis selection; keep its Axe unless squeezed away
        positions, labs, grp = axe.get(key, True)
        if not squeeze or len(positions) > 1:
            new_axes.append(Axe(labs, grp))
        indices.append(positions)

    indices, new_axes = [], []
    try:
        for position, key in enumerate(args):
            take(self.axes[position], key, indices, new_axes)
    except (ValueError, IndexError):
        indices, new_axes = [], []
        take(self.axes[0], args, indices, new_axes)

    # Complete the selection with full axes
    for position in range(len(indices), self.ndim):
        take(self.axes[position], "all", indices, new_axes)

    selected = self.data[np.ix_(*indices)]
    if squeeze:
        selected = selected.squeeze()
    if not aspart:
        return selected
    return Part(data=selected, name=f"sel_{self.name}",
                groupings=self.groupings, axes=new_axes)
470
-
471
def develop(self, axis=None, on=None, squeeze=True):
    """
    Split multi-level axes into one axis per dimension.

    Parameters
    ----------
    axis : int or list of int, optional
        Axes to develop. All axes when None (default).
    on : str or list of str, optional
        Dimensions to split out. All dimensions when None (default). The
        remaining dimensions of a developed axis stay together in one axis
        placed after the split ones.
    squeeze : bool, optional
        Squeeze the result (drop length-1 axes). The default is True.

    Returns
    -------
    Part
        The developed Part, named ``developped_<name>``.

    Raises
    ------
    NotImplementedError
        If the development would change the order of the dimensions
        (e.g. splitting out non-contiguous dimensions), which a plain
        reshape cannot express.
    """
    targets = [on] if isinstance(on, str) else on
    chosen = [axis] if isinstance(axis, int) else axis

    new_axes = []
    for index, ax in enumerate(self.axes):
        if chosen is not None and index not in chosen:
            new_axes.append(ax)
            continue
        remaining = ax.labels.copy()
        for dim in ax.dimensions:
            if targets is None or dim in targets:
                new_axes.append(
                    Axe({dim: remaining.pop(dim)}, groupings=ax.groupings))
        if remaining:
            # Dimensions that were not split out stay together
            new_axes.append(Axe(remaining, groupings=ax.groupings))

    before = [dim for ax in self.axes for dim in ax.dimensions]
    after = [dim for ax in new_axes for dim in ax.dimensions]
    if after != before:
        raise NotImplementedError(
            "Developping the part misaligns the dimensions. "
            "This operation is not yet supported."
        )

    # Dimension order is unchanged, so a reshape is enough
    reshaped = self.data.reshape([len(ax) for ax in new_axes])
    developed = Part(data=reshaped, name=f"developped_{self.name}",
                     groupings=self.groupings, axes=new_axes)
    return developed.squeeze() if squeeze else developed
533
-
534
def reformat(self, new_dimensions):
    """
    Reshape the Part to a new grouping of its dimensions into axes.

    Equivalent to a combination of ``develop`` and ``combine_axes``. This
    only works for contiguous dimensions of the current Part, without
    overlapping dimensions. The work is delegated to
    ``part_operations.reformat``.

    Parameters
    ----------
    new_dimensions : list of list of str
        Target dimensions of each axis.

    Returns
    -------
    Part
        The reshaped Part (callers read its ``data`` and ``axes``).

    Examples
    --------
    If the Part has dimensions::

        [["countries"], ["sectors"], ["sectors"]]

    The following is allowed::

        [["countries", "sectors"], ["sectors"]]

    The following is not allowed::

        [["countries"], ["sectors", "sectors"]]
        [["sectors"], ["countries", "sectors"]]
        [["sectors", "countries"], ["sectors"]]
    """
    reshaped = part_operations.reformat(self, new_dimensions)
    return reshaped
572
-
573
def combine_axes(self, start=0, end=None, in_place=False):
    """
    Merge consecutive axes into a single multi-level axis.

    The order of dimensions is preserved in the merged axis, which makes
    this the inverse of ``develop``.

    Parameters
    ----------
    start : int, optional
        Index of the first axis to merge. The default is 0.
    end : int, optional
        Index of the last axis to merge (inclusive). When None (default),
        all axes from ``start`` onwards are merged.
    in_place : bool, optional
        If True, return the tuple ``(data, axes)`` instead of a new Part.
        The current Part is not modified in either case.

    Returns
    -------
    Part, or tuple (numpy.ndarray, list of Axe) when ``in_place`` is True.

    Raises
    ------
    IndexError
        If the merged axes share a dimension name.
    """
    last = self.ndim - 1 if end is None else end
    merged_labels, merged_groupings = dict(), dict()
    seen = set()
    new_axes = []
    for index, ax in enumerate(self.axes):
        if index in range(start, last + 1):
            for name in ax.dimensions:
                if name in seen:
                    raise IndexError(
                        "Cannot undevelop axes with overlapping dimensions"
                    )
                seen.add(name)
                merged_labels[name] = ax.labels[name]
                if name in ax.groupings:
                    merged_groupings[name] = ax.groupings[name]
        else:
            new_axes.append(ax)
        if index == last:
            # The merged axis takes the position of the last merged one
            new_axes.append(Axe(merged_labels, merged_groupings))

    reshaped = self.data.reshape([len(ax) for ax in new_axes])
    if in_place:
        return reshaped, new_axes
    return self.alias(data=reshaped, name=f"combined_{self.name}",
                      axes=new_axes)
625
-
626
def swap_axes(self, axis1, axis2):
    """
    Exchange two axes of the Part (data and labels).

    Parameters
    ----------
    axis1, axis2 : int
        Positions of the axes to exchange.

    Returns
    -------
    Part
        New Part named ``swapped_<name>``. Groupings and metadata are not
        forwarded to it.
    """
    reordered = self.axes.copy()
    first, second = reordered[axis1], reordered[axis2]
    reordered[axis1] = second
    reordered[axis2] = first
    swapped = self.data.swapaxes(axis1, axis2)
    return Part(data=swapped, name=f"swapped_{self.name}", axes=reordered)
646
-
647
def swap_ax_levels(self, axis, dim1, dim2):
    """
    Swap two levels (dimensions) of a multi-level axis.

    Parameters
    ----------
    axis : int
        Axis to modify.
    dim1, dim2 : str
        Dimensions of that axis to swap.

    Returns
    -------
    Part
        New Part named ``swapped_<name>``. The current Part is unchanged.
    """
    target = self.axes[axis]
    dimensions = target.dimensions
    id1, id2 = dimensions.index(dim1), dimensions.index(dim2)
    low, high = sorted((id1, id2))
    # Levels of length 1 do not influence the flat ordering. If at most
    # one level of the swapped span is longer than 1, the data layout is
    # unchanged and relabelling suffices.
    long_levels = sum(
        len(target.labels[dim]) > 1 for dim in dimensions[low:high + 1]
    )
    if long_levels <= 1:
        axes = self.axes.copy()
        # Work on a copy so that the current Part's Axe is not modified
        axes[axis] = copy.deepcopy(target)
        axes[axis].swap_levels(dim1, dim2)
        return Part(data=self.data, name=f"swapped_{self.name}", axes=axes)

    # develop(axis) only splits this axis, so its first level sits at
    # position `axis`. squeeze=False keeps length-1 axes so positions hold.
    dev = self.develop(axis, squeeze=False)
    dev = dev.swap_axes(axis + id1, axis + id2)
    dev = dev.combine_axes(axis, axis + len(dimensions) - 1)
    dev.name = f"swapped_{self.name}"
    return dev
678
-
679
def flatten(self):
    """
    Flatten a multidimensional Part into a 1D Part.

    All dimensions of all axes are merged into a single axis, in axis
    order, and the data is flattened in C order to match. Because an axis
    cannot hold the same dimension twice, repeated dimension names are
    disambiguated by appending ``_1``, ``_2``, ... to the name.

    Returns
    -------
    Part
        One-dimensional Part named ``flattened_<name>``.
    """
    labels = dict()
    groupings = dict()
    for index, axe in enumerate(self.axes):
        for dim in axe.dimensions:
            name = dim
            if name in labels:
                suffix = 1
                while f"{dim}_{suffix}" in labels:
                    suffix += 1
                name = f"{dim}_{suffix}"
                log.info(f"Disambiguate dimension {dim} on axis {index} into {name}")
            labels[name] = axe.labels[dim]
            # Check the same mapping that is read: the axis groupings
            # (checking self.groupings could raise KeyError here)
            if dim in axe.groupings:
                groupings[name] = axe.groupings[dim]
    merged = Axe(labels, groupings)
    return self.alias(data=self.data.flatten(order="C"),
                      name=f"flattened_{self.name}",
                      axes=[merged])
709
-
710
-
711
def squeeze(self, drop_ax=True, drop_dims=True):
    """
    Remove length-1 structure from the Part.

    Parameters
    ----------
    drop_ax : bool, optional
        Drop axes whose length is at most 1, together with the matching
        data dimension. The default is True.
    drop_dims : bool, optional
        Call ``squeeze()`` on every Axe. Its return value is ignored, so it
        presumably removes the Axe's length-1 levels in place; these Axe
        objects are shared with the current Part. The default is True.

    Returns
    -------
    Part
        New Part named ``squeezed_<name>``.
    """
    kept_axes = []
    dropped = []
    for position, ax in enumerate(self.axes):
        if drop_dims:
            ax.squeeze()
        if len(ax) > 1 or not drop_ax:
            kept_axes.append(ax)
        else:
            dropped.append(position)
    # Only squeeze the data dimensions whose axis was dropped. A bare
    # np.squeeze would also remove length-1 data dimensions of kept axes
    # (e.g. with drop_ax=False), desynchronising data and axes.
    squeezable = tuple(p for p in dropped if self.data.shape[p] == 1)
    data = np.squeeze(self.data, axis=squeezable) if squeezable else self.data
    return self.alias(data=data, axes=kept_axes,
                      name=f"squeezed_{self.name}")
720
-
721
def expand_dims(self, axis, copy=None):
    """
    Insert a new length-1 data dimension and its axis.

    Parameters
    ----------
    axis : int
        Position of the new axis.
    copy : int, optional
        Index of an existing axis whose Axe object is reused for the new
        axis. When None (default), the new axis is ``{"expanded": [0]}``.

    Returns
    -------
    Part
        New Part named ``expanded_<name>``.
    """
    if copy is None:
        inserted = Axe({"expanded": [0]})
    else:
        inserted = self.axes[copy]
    new_axes = self.axes.copy()
    new_axes.insert(axis, inserted)
    return self.alias(data=np.expand_dims(self.data, axis),
                      axes=new_axes, name=f"expanded_{self.name}")
739
-
740
def copy(self):
    """
    Return a copy of the current Part.

    Implemented through ``alias`` with no overrides, so the data array is
    copied and the new Part is named ``<name>_alias``.
    """
    duplicate = self.alias()
    return duplicate
743
-
744
def extraction(self,
               dimensions,
               labels=None,
               on_groupings=True,
               domestic_only=False,
               axis="all"):
    """
    Return a copy of the Part with a label selection set to 0.

    Parameters
    ----------
    dimensions : str, list of str or dict
        Dimension(s) on which the extraction is done. A dict maps each
        dimension to its labels and takes precedence over ``labels``.
    labels : str, list or list of lists, optional
        Labels to set to 0. With a single dimension, a flat list is read as
        the labels of that dimension; with several dimensions, one entry
        per dimension is expected. ``"all"`` selects every label.
        When omitted, ``"all"`` is used for every dimension.
    on_groupings : bool, optional
        Only used when ``domestic_only`` is True. If True, ``"all"``
        expands to the grouping names of the dimension (when groupings
        exist), so a grouping is zeroed as one block. If False, grouping
        names are replaced by their members and each member is zeroed
        separately.
    domestic_only : bool, optional
        If True, only the blocks where every extracted axis carries the
        same label (e.g. domestic transactions) are zeroed; cross-label
        blocks (e.g. trade flows) are kept. The default is False.
    axis : "all", int or list of int, optional
        Axes on which the extraction applies. "all" (default) selects every
        axis that holds all the requested dimensions.

    Returns
    -------
    Part
        New Part with the selection set to 0; the current Part is unchanged.

    Raises
    ------
    ValueError
        If labels and dimensions cannot be matched, if no axis holds a
        single requested dimension, or if ``axis`` names an axis that does
        not hold the requested dimensions.
    """
    # Normalise into parallel lists without mutating the caller's arguments
    if isinstance(dimensions, dict):
        dims = list(dimensions.keys())
        dim_labels = list(dimensions.values())
    else:
        dims = [dimensions] if isinstance(dimensions, str) else list(dimensions)
        if labels is None:
            dim_labels = ["all"] * len(dims)
        else:
            dim_labels = [labels] if isinstance(labels, str) else list(labels)
            if len(dim_labels) != len(dims):
                if len(dims) == 1:
                    # A single dimension receives the whole label list
                    dim_labels = [dim_labels]
                else:
                    log.critical("Number of dimensions and labels do not match for extraction")
                    raise ValueError("Number of dimensions and labels do not match for extraction")
    # Built after broadcasting so that every label is kept
    to_select = dict(zip(dims, dim_labels))

    allowed = [i for i, ax in enumerate(self.axes)
               if all(dim in ax.dimensions for dim in dims)]
    if not allowed:
        if len(dims) == 1:
            log.critical("No axis found for extraction on " + str(dims))
            raise ValueError("No axis found for extraction on " + str(dims))
        log.info(f"No axis found for simultaneous extractions on {dims}")
        log.info(f"Try successive extractions on {dims}")
        # Chain the extractions: each one applies to the previous result
        output = self
        for dim, label in zip(dims, dim_labels):
            output = output.extraction(dim, label,
                                       on_groupings=on_groupings,
                                       domestic_only=domestic_only,
                                       axis=axis)
        return output

    if isinstance(axis, str) and axis == "all":
        log.info(f"Extract {to_select} on axes " + str(allowed))
        axis = allowed
    if isinstance(axis, int):
        axis = [axis]
    wrong = [ax for ax in axis if ax not in allowed]
    if wrong:
        log.critical(f"Cannot extract {dims} on axis {wrong}")
        raise ValueError(f"Cannot extract {dims} on axis {wrong}")
    if len(axis) == 0:
        # No axis selected: nothing to extract
        return self.copy()

    if not domestic_only:
        # The whole selection is set to 0 at once
        sel = ["all"] * self.ndim
        for i in axis:
            sel[i] = to_select
        output = self.copy()
        output[sel] = 0
        return output

    # All selected axes hold every requested dimension, so any of them can
    # provide the full label lists.
    ref_axis = self.axes[axis[0]]
    per_dim = []
    for dim, label in zip(dims, dim_labels):
        if isinstance(label, str) and label == "all":
            if on_groupings and dim in self.groupings:
                values = list(ref_axis.groupings[dim].keys())
            else:
                values = list(ref_axis.labels[dim])
        else:
            # Wrap scalars so that a single label is not iterated per character
            values = list(label) if isinstance(label, (list, tuple)) else [label]
            if not on_groupings and dim in self.groupings:
                groups = self.groupings[dim]
                developed = []
                for item in values:
                    if item in groups:
                        # NOTE(review): assumes a grouping maps to a label
                        # or a list of labels
                        members = groups[item]
                        developed.extend(
                            [members] if isinstance(members, str) else members)
                    else:
                        developed.append(item)
                values = developed
        per_dim.append(values)

    output = self.copy()
    for combination in itertools.product(*per_dim):
        # Zero the block where all selected axes carry the same labels
        seldict = dict(zip(dims, combination))
        sel = ["all"] * self.ndim
        for i in axis:
            sel[i] = seldict
        output[sel] = 0
    return output
863
-
864
def leontief_inversion(self):
    """
    Compute the Leontief inverse ``(I - A)^-1`` of a square 2D Part.

    Returns
    -------
    Part
        New Part named ``l_<name>`` holding the inverse.

    Raises
    ------
    ValueError
        If the Part is not two-dimensional and square.
    """
    if self.ndim != 2 or not self.issquare():
        raise ValueError("Can only compute the Leontief inverse on"
                         " square parts")
    identity = np.identity(len(self.axes[0]))
    inverse = np.linalg.inv(identity - self.data)
    return self.alias(name=f"l_{self.name}", data=inverse)
870
-
871
def zone(self):
    """
    Apply a grouping by zone dependent on the shape of the Part.

    Deprecated: the result depends on the shape of the Part.
    Use ``group`` or ``expand`` instead.

    - 2D non-square Parts whose second axis has a "countries" dimension
      are grouped with ``self.group(1, "countries")``.
    - 1D Parts are expanded over the "countries" dimension.

    Raises
    ------
    AttributeError
        Parts with other shapes are rejected.

    Returns
    -------
    Part object
        Grouped or expanded Part.

    """
    log.warning("This function is deprecated as it returns different "+\
                "results depending on the shape of the Part. "+\
                "Use group or expand instead.")
    if self.ndim == 2 and not self.issquare():
        if "countries" in self.axes[1].dimensions:
            # Group the second axis of non-square 2D parts by country
            return self.group(1,"countries")
    if self.ndim == 1:
        # Expand 1D parts over countries. ``over`` must be passed by
        # keyword: the first positional parameter of ``expand`` is ``axis``,
        # and passing the string there iterated over its characters.
        return self.expand(over="countries")
    raise AttributeError(f"Part {self.name} has no predefined grouping.")
900
-
901
def update_groupings(self,groupings,ax=None):
    """Update the groupings of the current Part object.

    Parameters
    ----------
    groupings : dict
        Description of the groupings. Stored on the Part and forwarded
        to the ``update_groupings`` method of each selected Axe.
    ax : int or list of int, optional
        Axis or axes to update. If left empty, all axes are updated.
    """
    self.groupings = groupings
    if ax is None:
        ax = range(self.ndim)
    elif isinstance(ax, int):
        # A single axis index is not iterable: wrap it
        ax = [ax]
    for axe in ax:
        self.axes[axe].update_groupings(groupings)
914
-
915
def aggregate(self,on=None,axis=None):
    """Aggregate dimensions along one or several axes using the groupings.

    Each label of the dimension is replaced by the name of its group and the
    corresponding data is summed (see ``aggregate_on``). To sum over a whole
    dimension, use the ``sum`` method instead.

    Parameters
    ----------
    on : str or list of str, optional
        Dimension(s) to aggregate. Each must have groupings defined in
        ``self.groupings``. By default, every dimension with groupings is
        aggregated in turn.
    axis : int or list of int, optional
        Axes along which to aggregate. By default, every axis holding
        the dimension ``on``.

    Raises
    ------
    ValueError
        Raised if no groupings are defined for a requested dimension.

    Returns
    -------
    Part object
        Aggregated Part, named ``<on>_grouped_<name>``.

    """
    log.debug(f"Aggregate Part {self.name} along axis {axis} on {on}")

    if on is None:
        on = list(self.groupings.keys())

    if isinstance(on, list):
        # Aggregate one dimension at a time, chaining the results
        result = self
        for dimension in on:
            result = result.aggregate(on=dimension, axis=axis)
        return result

    if on not in self.groupings.keys():
        raise ValueError(f"No groupings defined for dimensions {on}")

    if axis is None:
        axis = self.hasax(on)
    targets = [axis] if isinstance(axis, int) else axis

    aggregated = self.alias()
    for target in targets:
        aggregated = aggregated.aggregate_on(on, target)

    aggregated.name = f"{on}_grouped_{self.name}"
    return aggregated
973
-
974
def aggregate_on(self,on,axis):
    """Aggregate a Part along a given axis

    The labels of dimension ``on`` on the axis are replaced by the names of
    its groups (the keys of ``self.axes[axis].groupings[on]``). For each
    group, the data selected by the group name is summed into a single
    entry.

    Parameters
    ----------
    on : str
        Dimension to aggregate on
    axis : int
        Axis to aggregate

    Returns
    -------
    Part instance
        Aggregated Part, named ``<on>_grouped_<name>``

    Raises
    ------
    ValueError
        If ``on`` is not a dimension of the axis.
    """
    if on not in self.axes[axis].dimensions:
        raise ValueError(f"Dimension {on} not found in axis {axis}")

    # Labels of the aggregated axis: the group names replace the labels of ``on``
    new_labels = self.axes[axis].labels.copy()
    new_labels[on] = list(self.axes[axis].groupings[on].keys())
    # After aggregation, each group only contains itself
    # NOTE(review): group names come from the Axe groupings while the identity
    # groupings are built from ``self.groupings``; presumably both agree — confirm.
    new_groupings = self.groupings.copy()
    new_groupings[on] = {
        item : [item] for item in self.groupings[on]
    }

    new_axis = self.axes.copy()
    new_axis[axis] = Axe(new_labels,new_groupings)
    new_shape = [len(ax) for ax in new_axis]

    # Target Part with the aggregated axes
    # (assumes Part(axes=...) allocates data of the matching shape — TODO confirm)
    output = Part(axes=new_axis)
    idsum = new_axis[axis].dimensions.index(on) #Index of the dimension to sum on
    # Develop the axis on both Parts so that ``on`` can be indexed on its own
    # (presumably one numpy axis per level of the Axe — confirm in ``develop``)
    ref_dev = self.develop(axis, squeeze=False)
    new_dev = output.develop(axis,squeeze=False)
    selector = ["all"]*ref_dev.ndim
    for label in new_labels[on]:
        # ``label`` is a group name: on ``ref_dev`` it presumably resolves to
        # the members of the group, on ``new_dev`` to the single aggregated entry
        selector[axis+idsum] = label
        new_dev[selector] = ref_dev[selector].sum(
            axis=axis+idsum,
            keepdims=True
        )
    # Fold the developed axes back into the shape of the aggregated Part;
    # ``output`` now holds the raw array, not the Part created above
    output = new_dev.data.reshape(new_shape)
    return self.alias(data=output,name=f"{on}_grouped_{self.name}",
        axes=new_axis)
1017
-
1018
-
1019
def get_labels(self,axis=None):
    """
    Return the labels of the selected axes of the Part.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis or axes to investigate. By default None: all axes.

    Returns
    -------
    list
        One entry per selected axis, holding that Axe's ``labels``.
    """
    if axis is None:
        selected = range(self.ndim)
    elif isinstance(axis, int):
        selected = [axis]
    else:
        selected = axis
    return [self.axes[index].labels for index in selected]
1044
-
1045
def list_labels(self):
    """Map each dimension name of the Part to its labels.

    When several axes share a dimension, the labels of the first axis
    holding it are kept.
    """
    merged = {}
    for axis_labels in self.get_labels():
        for dimension in axis_labels.keys():
            merged.setdefault(dimension, axis_labels[dimension])
    return merged
1054
-
1055
def get_dimensions(self,axis=None):
    """
    Return the dimension names of the selected axes of the Part.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis or axes to investigate. By default None: all axes.

    Returns
    -------
    list
        One entry per selected axis, holding that Axe's ``dimensions``.
    """
    if axis is None:
        selected = range(self.ndim)
    elif isinstance(axis, int):
        selected = [axis]
    else:
        selected = axis
    return [self.axes[index].dimensions for index in selected]
1079
-
1080
def rename_labels(self,old,new):
    """
    Rename a dimension on every axis of the Part that holds it.

    Parameters
    ----------
    old : str
        Current name of the dimension.
    new : str
        New name of the dimension.
    """
    for axe in self.axes:
        # Checked per axis at iteration time: an Axe object shared by
        # several positions is only renamed once
        if old not in axe.dimensions:
            continue
        axe.rename_labels(old,new)
    self._store_labels()
1095
-
1096
def replace_labels(self,name,labels,axis=None):
    """
    Replace the labels of one dimension of the Part.

    Parameters
    ----------
    name : str
        Name of the dimension to update.
    labels : dict or list
        New labels. A list is wrapped as ``{name: labels}``.
    axis : int or list of int, optional
        Axes on which the labels are replaced. By default None: every
        axis holding the dimension ``name``.
    """
    if isinstance(axis, int):
        targets = [axis]
    elif axis is None:
        targets = range(self.ndim)
    else:
        targets = axis
    if isinstance(labels, list):
        labels = {name: labels}
    for index in targets:
        axe = self.axes[index]
        if name in axe.dimensions:
            axe.replace_labels(name,labels)
    self._store_labels()
1121
-
1122
def set_labels(self,labels,axis=None):
    """
    Change the labels of the Part.

    Parameters
    ----------
    labels : dict or list
        New labels, indexed by axis position. A list with one entry per
        axis is converted to ``{axis_index: entry}``.
    axis : int or list of int, optional
        Axes whose labels are changed. By default None: all axes.
    """
    if axis is None:
        targets = range(self.ndim)
    elif isinstance(axis, int):
        targets = [axis]
    else:
        targets = axis
    if isinstance(labels, list) and len(labels) == self.ndim:
        # One entry per axis: index the entries by axis position
        labels = dict(enumerate(labels))
    for index in targets:
        self.axes[index].set_labels(labels[index])
    self._store_labels()
1144
-
1145
def _store_labels(self):
    """Cache the dimension-to-labels mapping in ``self.labels``."""
    merged = self.list_labels()
    self.labels = merged
1148
-
1149
def add_labels(self,labels,dimension=None,axes=None,
               fill_value=0):
    """
    Append labels to one dimension on one or multiple Part axes.

    The original data is copied into the enlarged Part; the new entries
    are initialised with ``fill_value``.

    Parameters
    ----------
    labels : list of str or dict
        Labels to append. If a dict, its first key is the dimension to
        extend and the corresponding value the list of labels to append.
    dimension : str, optional
        Dimension the new labels are appended to, when ``labels`` is a
        list. Ignored if ``labels`` is a dict.
    axes : int or list of int, optional
        Axis or axes to modify. If not specified, every axis holding the
        dimension is extended.
    fill_value : float, optional
        Value used to initialise the new entries. Default 0.

    Returns
    -------
    Part instance
        Part instance with the additional labels.

    Raises
    ------
    ValueError
        If ``labels`` is a list and no ``dimension`` is given, or if
        ``labels`` is empty.
    """
    if isinstance(labels,list):
        if dimension is None:
            # Without a dimension, hasax(None) would select every axis and
            # the lookup of the None dimension would fail with a KeyError
            raise ValueError(
                "A dimension must be specified when labels are given as a list"
            )
        labels = {dimension:labels}
    if not labels:
        raise ValueError("No labels to add")
    dimension = next(iter(labels))
    if axes is None:
        # Identify the axes holding the dimension
        axes = self.hasax(dimension)
    elif isinstance(axes,int):
        axes = [axes]

    new_axes = self.axes.copy()
    sel = ["all"]*self.ndim
    for ax in axes:
        log.debug("Add labels to axis "+str(ax))
        old_labels = self.axes[ax].labels
        new_labels = old_labels.copy()
        new_labels[dimension] = old_labels[dimension] + labels[dimension]
        new_axes[ax] = Axe(new_labels,self.groupings)
        # Select the original labels in the enlarged Part so that the
        # original data can be written back in place
        sel[ax] = old_labels

    new_shape = [len(ax) for ax in new_axes]
    output = self.alias(data=np.full(new_shape,fill_value,dtype="float64"),
                        axes=new_axes)

    # Put original data back in place
    output[sel] = self.data
    return output
1207
-
1208
def reorder_data(self,new_labels):
    """
    Reorder the data of the Part according to new labels.

    For every axis holding all the dimensions of ``new_labels``, the data is
    permuted so that the labels of each dimension follow the order given in
    ``new_labels``. Labels in ``new_labels`` that do not exist in the Part
    are ignored. Axes missing any of the dimensions are left untouched.
    The labels of the axes are then updated.

    Parameters
    ----------
    new_labels : dict
        Maps dimension names to their labels in the desired order. Each
        list must contain (at least) all the current labels of its
        dimension. Neither the dict nor its lists are modified.

    Raises
    ------
    ValueError
        If ``new_labels`` is not a dict, or if a list does not include all
        the current labels of its dimension.
    """

    if not isinstance(new_labels,dict):
        raise ValueError("New labels should be a dictionary")

    sels = []
    kept_labels = {}  # Requested order restricted to existing labels
    matched = False
    for axis in self.axes:
        old_labels = axis.labels
        if not set(new_labels.keys()).issubset(set(old_labels.keys())):
            sels.append(axis.get("all"))
            continue
        matched = True
        for key in new_labels.keys():
            set_old = set(old_labels[key])
            set_new = set(new_labels[key])
            if not set_old.issubset(set_new):
                raise ValueError(f"The new labels provided for dimension '{key}' is not a superset of the old labels. " +
                                 f"Old labels: {old_labels[key]}, new labels: {new_labels[key]}. "
                                 "If you want to rename the labels of this dimensions, use the method 'replace_labels() before reordering the data")

        ax_label_dict = {}
        for key in old_labels.keys():
            if key in new_labels.keys():
                # Build a filtered copy instead of removing items from the
                # caller's list while iterating over it (which skipped
                # consecutive unknown labels and mutated ``new_labels``)
                ax_label_dict[key] = [
                    lab for lab in new_labels[key] if lab in old_labels[key]
                ]
                kept_labels[key] = ax_label_dict[key]
            else:
                ax_label_dict[key] = old_labels[key]

        sels.append(axis.get(ax_label_dict))

    if len(sels) == 0:
        raise ValueError(f"None of the dimensions provided in the new labels dict {new_labels.keys()} are present "+
                         f"in the labels of part '{self.name}', which only contains the dimensions {self.get_dimensions()}")
    if not matched:
        # No axis holds the requested dimensions: nothing to reorder
        return

    #Execute the selection
    self.data = self.data[np.ix_(*sels)]

    # Update the axes with the new labels
    for dim in new_labels.keys():
        self.replace_labels(name = dim, labels = list(kept_labels[dim]))
1267
-
1268
-
1269
def expand(self,axis=None,over="countries"):
    """
    Expand axes of the Part over one of their dimensions.

    For each selected axis, ``_expand_on`` inserts a new axis holding only
    the ``over`` dimension. This significantly increases the size of the
    Part; it is recommended for Extension parts only.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis or axes to expand. By default, every axis holding ``over``.
    over : str, optional
        Dimension to expand the Part by. The default is "countries".

    Returns
    -------
    Part object
        New Part object with additional dimensions.

    """
    targets = self.hasax(over) if axis is None else axis
    if isinstance(targets, int):
        targets = [targets]

    result = self.copy()
    # Work from the last axis to the first so that the inserted axes do not
    # shift the indices of the axes still to be expanded
    for target in reversed(targets):
        result = result._expand_on(target, over)
    return result
1302
-
1303
def _expand_on(self,ax,over):
    """Expand over a single axis

    A new axis holding only the ``over`` dimension of axis ``ax`` is
    inserted at position ``ax``, so the original axis moves to ``ax+1``.
    For each label ``item`` of ``over``, the entries of the original axis
    selected for ``item`` are copied into the slice of the new axis for
    ``item``. All other entries are zero.

    Parameters
    ----------
    ax : int
        Axis to expand.
    over : str
        Dimension of that axis to expand by.

    Returns
    -------
    Part instance
        Alias named ``expanded_<name>`` with one more dimension.
    """
    ref_ax = self.axes[ax]
    new_ax = Axe({over: ref_ax.labels[over]}, groupings=self.groupings)
    axes = self.axes.copy()
    axes.insert(ax,new_ax)
    new_shape = list(self.shape)
    new_shape.insert(ax,len(new_ax))
    output = np.zeros(new_shape)
    selector = [slice(None)]*self.ndim
    # Position of ``over`` among the dimensions of the axis
    ordering = ref_ax.dimensions.index(over)
    for item in ref_ax.labels[over]:
        newsel,refsel = selector.copy(),selector.copy()
        # Slot of ``item`` on the new axis
        newsel.insert(ax,new_ax.get(item))
        # "all" on the dimensions before ``over`` and ``item`` on ``over``
        # (dimensions after ``over`` are left to Axe.get's handling)
        ax_selector = ["all" for i in range(ordering)]
        ax_selector.append(item)
        newsel[ax+1] = ref_ax.get(ax_selector)
        refsel[ax]= ref_ax.get(ax_selector)
        # NOTE(review): newsel holds index sequences on two adjacent axes,
        # which numpy broadcasts together; assumes new_ax.get returns a
        # single index (or length-1 sequence) — confirm against Axe.get.
        output[tuple(newsel)] = self.data[tuple(refsel)]
    return self.alias(data=output,name=f"expanded_{self.name}",axes=axes)
1323
-
1324
-
1325
def issquare(self):
    """Return True if the Part is 2D with two axes of equal length."""
    if self.ndim != 2:
        return False
    return len(self.axes[0]) == len(self.axes[1])
1328
-
1329
def hasneg(self):
    """Return True if any coefficient of the Part is negative."""
    negative_mask = self.data < 0
    return bool(np.any(negative_mask))
1334
-
1335
def hasax(self,name=None):
    """Return the indices of the axes holding a given dimension.

    Returns an empty list if no axis holds the dimension, so the result
    can be used to test whether the Part has that dimension.

    Parameters
    ----------
    name : str, optional
        Name of the dimension to look for. If None or "any", the indices
        of all axes are returned.

    Returns
    -------
    axes : list of ints
        Indices of the axes holding the dimension.

    """
    if name is None or name == "any":
        return list(range(self.ndim))
    return [index for index, axe in enumerate(self.axes)
            if name in axe.dimensions]
1361
-
1362
def __str__(self):
    """Short human-readable description of the Part."""
    name, ndim = self.name, self.ndim
    return f"{name} Part object with {ndim} dimensions"
1364
-
1365
def sum(self,axis=None,on=None,keepdims=False):
    """
    Sum the Part along one or several axes, and/or on a given dimension.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis or axes along which the sum is evaluated.
        By default None: if ``on`` is also None, the sum of all coefficients
        of the Part is returned.
    on : str, optional
        Name of the dimension to sum on. If no axis is given, the Part is
        summed over every axis holding that dimension.
        By default None: whole axes are summed.
    keepdims : bool, optional
        Whether to keep the number of dimensions of the original.
        By default False: the summed dimensions are removed.

    Returns
    -------
    Part instance or float
        Result of the sum.
    """
    part = self
    if axis is None:
        if on is None:
            # Total of all coefficients
            return part.data.sum()
        if not keepdims:
            part = part.squeeze()
        axis = part.hasax(on)

    if not isinstance(axis, int):
        # Several axes: sum from the last to the first so that removing an
        # axis does not shift the indices of the remaining ones
        for target in sorted(axis, reverse=True):
            part = part.sum(target, on, keepdims)
        return part

    if on is not None:
        return part._sum_on(axis, on, keepdims)

    remaining_axes = part.axes.copy()
    if keepdims:
        remaining_axes[axis] = Axe(["all"])
    else:
        del remaining_axes[axis]
    return part.alias(
        data=part.data.sum(axis, keepdims=keepdims),
        name=f"{part.name}_sum_{axis}",
        axes=remaining_axes,
    )
1411
-
1412
def _sum_on(self,axis,on,keepdims=False):
    """
    Sum a Part along an axis on a given dimension.

    Parameters
    ----------
    axis : int
        Axis to sum along.
    on : str
        Dimension of that axis to sum on.
    keepdims : bool, optional
        Whether to keep the summed dimension with length 1.

    Returns
    -------
    Part instance
        Summed Part.

    Raises
    ------
    ValueError
        If ``on`` is not a dimension of the axis.
    """
    ax = self.axes[axis]
    if on not in ax.dimensions:
        raise ValueError(f"Cannot sum on {on} as it is not a dimension of axis {axis}")
    if ax.levels == 1:
        #If the axis has a single level, this is a simple sum
        axes = self.axes.copy()
        if keepdims:
            # The axis collapses to a single entry: relabel it like ``sum``
            # does, so the Axe length matches the data
            axes[axis] = Axe(["all"])
        else:
            del axes[axis]
        return self.alias(
            data = self.data.sum(axis,keepdims=keepdims),
            name=f"{self.name}_sum_{axis}",
            axes = axes
        )
    #Otherwise, sum on the relevant levels
    idsum = ax.dimensions.index(on) #Index of the dimension to sum on
    dev = self.develop(axis,squeeze=False)
    dev = dev.sum(axis+idsum,keepdims=keepdims)
    if keepdims:
        dev = dev.combine_axes(axis,axis+idsum)
    dev.name = f"{self.name}_sum_on_{on}_{axis}"
    return dev
1437
-
1438
def save(self,
         file=None,
         name=None,
         extension=".npy",
         overwrite=False,
         include_labels=False,
         write_instructions=False,
         **kwargs):
    """
    Save the Part object to a file.

    Parameters
    ----------
    file : path-like, optional
        Full path to the file to save the Part to.
        This overrides the path, name and extension arguments.
    name : str, optional
        Name under which the Part is saved.
        By default, the name of the current Part is used.
    extension : str, optional
        Format under which the Part is saved. The default is ".npy".
        ".nc" saves to a netCDF file; other extensions are handled by the
        folder writer (".csv" saves a csv file with labels).
    overwrite : bool, optional
        Whether to overwrite an existing file.
        If set False, the file is saved with a new name.
        The default is False.
    include_labels : bool, optional
        Whether to include the labels in the saved file.
        Only applicable to .csv and .xlsx files.
    write_instructions : bool, optional
        Whether to write the loading instructions to a yaml file.
        The default is False.
    **kwargs : dict
        Additional arguments passed to the saving function.
        ``path`` (directory in which the Part is saved) is accepted here
        as a keyword argument.

    Raises
    ------
    FileNotFoundError
        If no destination directory is given (neither ``file`` nor ``path``).
    """
    # Pop ``path`` so it is not forwarded a second time through **kwargs,
    # which raised "got multiple values for keyword argument 'path'"
    path = kwargs.pop("path",None)
    if file is not None:
        path,name = os.path.split(file)
        name,possible_extension = os.path.splitext(name)
        if possible_extension != "":
            extension = possible_extension
    if name is None:
        name = self.name
    if path is None:
        raise FileNotFoundError("No path specified for saving the Part")
    if extension == ".nc":
        path = os.path.join(path,name+extension)
        save_to_nc(self,path,overwrite,
                   write_instructions=write_instructions,
                   **kwargs)
    else:
        save_part_to_folder(
            self,
            path = path,
            name = name,
            extension = extension,
            overwrite = overwrite,
            include_labels=include_labels,
            write_instructions = write_instructions,
            **kwargs
        )
1507
-
1508
def to_pandas(self):
    """Convert the Part to a pandas DataFrame.

    Delegates to ``converters.pandas.to_pandas``.
    Only applicable to Part objects with 1 or 2 dimensions.
    """
    pandas_converter = converters.pandas.to_pandas
    return pandas_converter(self)
1514
-
1515
def to_xarray(self):
    """
    Convert the Part object to an xarray DataArray.

    Delegates to ``converters.xarray.to_DataArray``. Labels are passed to
    the DataArray as coords. Note that data will be flattened; the
    dimension order is saved as an attribute, so that a Part loaded back
    is reshaped to its original dimensions.

    Returns
    -------
    xr.DataArray
        Corresponding DataArray
    """
    xarray_converter = converters.xarray.to_DataArray
    return xarray_converter(self)
1531
-
1532
def mean(self,axis=None):
    """Mean of the Part data, optionally along ``axis`` (returns raw numpy)."""
    return np.mean(self.data, axis=axis)
1534
-
1535
def min(self,axis=None):
    """Minimum of the Part data, optionally along ``axis`` (returns raw numpy)."""
    return np.min(self.data, axis=axis)
1537
-
1538
def max(self,axis=None):
    """Maximum of the Part data, optionally along ``axis`` (returns raw numpy)."""
    return np.max(self.data, axis=axis)
1540
-
1541
def mul(self,a,propagate_labels=True):
    """
    Matrix multiplication between parts with labels propagation.

    Parameters
    ----------
    a : Part or array-like
        Right-hand multiplicator.
    propagate_labels : bool, optional
        Whether to propagate the axes of the right-hand multiplicator to
        the result. By default True. Forced to False if ``a`` is not a Part.

    Returns
    -------
    Part instance
        Result of ``self.data @ a``, named ``<self.name>.<a.name>``
        (``<self.name>.array`` for non-Part operands).
    """
    if isinstance(a,Part):
        name = a.name
        right = a.data
    else:
        propagate_labels = False
        name="array"
        # Convert explicitly: ``ndarray.data`` is a raw memory buffer, and
        # plain sequences have no ``.data``/``.ndim``/``.shape`` at all
        right = np.asarray(a)
    data = np.matmul(self.data,right)
    # Keep all axes of self except the contracted last one
    axes = [self.axes[i] for i in range(self.ndim-1)]
    # Append the axes of the right operand except its contracted first one
    for ax in range(right.ndim-1):
        if propagate_labels:
            axes.append(a.axes[ax+1])
        else:
            axes.append(Axe([i for i in range(right.shape[ax+1])]))
    return self.alias(data=data,name=f"{self.name}.{name}",axes=axes)
1572
-
1573
def filter(self,threshold,fill_value=0):
    """
    Replace the values below a given threshold.

    Parameters
    ----------
    threshold : float
        Values strictly below this threshold are replaced.
    fill_value : float, optional
        Replacement value. The default is 0.

    Returns
    -------
    Part instance
        Filtered Part; the original data is left unchanged.

    """
    below = self.data < threshold
    filtered = self.data.copy()
    filtered[below] = fill_value
    return self.alias(data=filtered,
                      name=f"filtered_{self.name}_{threshold}")
1594
-
1595
def diag(self):
    """
    Build a diagonal Part from a 1D Part.

    Parts with more dimensions are squeezed first; the squeezed Part
    must be 1D.

    Returns
    -------
    Part instance
        2D Part named ``diag_<name>`` with the values on its diagonal and
        the original axis on both dimensions.

    Raises
    ------
    ValueError
        If the Part cannot be reduced to a single dimension.
    """
    if self.ndim == 1:
        log.info("Diagonalize a 1D part")
        return self.alias(data=np.diag(self.data),
                          name=f"diag_{self.name}",
                          axes = self.axes*2)
    log.info("The part has too many dimensions: try to diagonalize the squeezed part")
    squeezed = self.squeeze()
    if squeezed.ndim != 1:
        # Recursing on a Part that squeeze() cannot reduce used to loop until
        # RecursionError, which a bare ``except`` turned into a ValueError
        raise ValueError(
            f"Cannot diagonalize part {self.name}: "
            "it has more than one non-singleton dimension"
        )
    return squeezed.diag()
1606
-
1607
def __add__(self,a):
    """Element-wise addition with a Part, an array or a scalar.

    If an array operand has a different number of dimensions, both sides
    are squeezed before numpy broadcasting.
    """
    other_name = a.name if isinstance(a,Part) else ""
    other = a.data if isinstance(a,Part) else a
    base = self
    if isinstance(other,np.ndarray) and base.ndim != other.ndim:
        other = other.squeeze()
        base = base.squeeze()
    return base.alias(data=other+base.data,name=f"{base.name}+{other_name}")
1617
-
1618
def __radd__(self,a):
    """Right-hand addition ``a + part``: delegates to ``__add__``."""
    add = self.__add__
    return add(a)
1620
-
1621
def __rmul__(self,a):
    """Right-hand multiplication ``a * part``: delegates to ``__mul__``."""
    multiply = self.__mul__
    return multiply(a)
1623
-
1624
def __mul__(self,a):
    """
    Element-wise multiplication with a Part, an array or a scalar.

    Numpy broadcasting is trusted for the multiplication. If the result's
    number of dimensions differs from the Part's, its singleton dimensions
    are squeezed out.
    """
    if isinstance(a,Part):
        # The f-prefix was missing: the name was the literal template text
        name = f"{a.name}*{self.name}"
        a = a.data
    elif isinstance(a,(int,float)):
        name = f"{a}*{self.name}"
    else:
        name = f"array*{self.name}"
    data = self.data*a
    if data.ndim!=self.ndim:
        data = data.squeeze()
    return self.alias(data=data,
                      name=name)
1640
-
1641
def __neg__(self):
    """Element-wise negation of the Part."""
    negated = -self.data
    return self.alias(data=negated, name=f"-{self.name}")
1643
-
1644
def __lt__(self,other):
    """Element-wise ``part < other``; returns a boolean numpy array."""
    rhs = other.data if isinstance(other,Part) else other
    return self.data < rhs
1648
-
1649
def __le__(self,other):
    """Element-wise ``part <= other``; returns a boolean numpy array."""
    rhs = other.data if isinstance(other,Part) else other
    return self.data <= rhs
1653
-
1654
def __gt__(self,other):
    """Element-wise ``part > other``; returns a boolean numpy array."""
    rhs = other.data if isinstance(other,Part) else other
    return self.data > rhs
1658
-
1659
def __ge__(self,other):
    """Element-wise ``part >= other``; returns a boolean numpy array."""
    rhs = other.data if isinstance(other,Part) else other
    return self.data >= rhs
1663
-
1664
def __sub__(self,a):
    """Element-wise ``part - a`` for a Part, an array or a scalar."""
    if not isinstance(a,Part):
        return self.alias(data=self.data-a,name=f"{self.name}-")
    return self.alias(data=self.data-a.data,name=f"{self.name}-{a.name}")
1671
-
1672
def __rsub__(self,a):
    """Element-wise ``a - part`` for a Part, an array or a scalar."""
    if isinstance(a,Part):
        lhs, lhs_name = a.data, a.name
    else:
        lhs, lhs_name = a, ""
    return self.alias(data=lhs-self.data,name=f"{lhs_name}-{self.name}")
1679
-
1680
def power(self,a):
    """
    Raise the Part element-wise to the power ``a``.

    Parameters
    ----------
    a : Part, int, float or array
        Exponent.

    Returns
    -------
    Part instance
        Part named ``<name>**<exponent>``.
    """
    if isinstance(a,Part):
        # Build the name before ``a`` is replaced by its raw data
        # (an ndarray has no ``name`` attribute)
        name = f"{self.name}**{a.name}"
        a = a.data
    elif isinstance(a,int) or isinstance(a,float):
        name = f"{self.name}**{a}"
    else:
        name = f"{self.name}**array"
    # Operate on the raw data, not on the Part object itself
    return self.alias(data=np.power(self.data,a),name=name)
1689
-
1690
def __pow__(self,a):
    """``part ** a``: delegates to the ``power`` method."""
    raise_to = self.power
    return raise_to(a)
1692
-
1693
def __eq__(self,a):
    """
    Test whether two Parts hold the same data.

    Only the data is compared, not the labels. Parts whose data have
    different shapes are never equal; they are not broadcast against each
    other. Any non-Part object compares unequal.
    """
    if not isinstance(a,Part):
        return False
    if np.shape(self.data) != np.shape(a.data):
        return False
    return bool(np.all(self.data==a.data))
1697
-
1698
def __rtruediv__(self,a):
    """
    Right-hand element-wise division ``a / part``.

    Logs a warning when the Part (the denominator) contains zeros.
    """
    if isinstance(a,Part):
        # ``a`` is the numerator: name the result "<a>/<self>"
        name = f"{a.name}/{self.name}"
        a = a.data
    elif isinstance(a,(int,float)):
        name = f"{a}/{self.name}"
    else:
        name= f"array/{self.name}"
    if np.sum(self.data==0)!=0:
        log.warning("Division by zero in "+name)
    return self.alias(data=a/self.data,
                      name=name)
1711
-
1712
def __truediv__(self,a):
    """
    Element-wise division ``part / a``.

    Logs a warning when the denominator contains zeros.
    """
    if isinstance(a,Part):
        # ``self`` is the numerator: name the result "<self>/<a>"
        name = f"{self.name}/{a.name}"
        a = a.data
    elif isinstance(a,(int,float)):
        name = f"{self.name}/{a}"
    else:
        name= f"{self.name}/array"
    # asarray so that zeros are also detected in plain sequences
    if np.any(np.asarray(a)==0):
        log.warning("Division by zero in "+name)
    return self.alias(data=self.data/a,
                      name=name)
1725
-
1726
def __getattr__(self,name):
    """
    Fall back on the Part metadata for unknown attributes.

    Only called when normal attribute lookup fails. The lookup is
    case-insensitive: ``name`` is casefolded before being looked up in
    ``self.metadata``.

    Raises
    ------
    AttributeError
        If the attribute is not found in the metadata.
    """
    if name == "metadata":
        # ``metadata`` itself is missing (e.g. on an instance created without
        # __init__, as copy/pickle do): looking it up again would recurse
        # until RecursionError
        raise AttributeError("Attribute metadata not found")
    name = name.casefold()
    try:
        return self.metadata[name]
    except (KeyError, TypeError, AttributeError):
        pass
    raise AttributeError(f"Attribute {name} not found")
1733
-
1734
def transpose(self):
    """Return the Part with its data transposed and its axes reversed."""
    flipped = self.data.transpose()
    return self.alias(data=flipped,
                      name=f"transposed_{self.name}",
                      axes=self.axes[::-1])
1738
-
1739
-