mrio-toolbox 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mrio-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,1504 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Apr 14 14:13:17 2023
4
+
5
+ @author: beaufils
6
+ """
7
+
8
+ import os
9
+ import itertools
10
+ import numpy as np
11
+ import pandas as pd
12
+ import xarray as xr
13
+ from mrio_toolbox._parts._Axe import Axe
14
+ import logging
15
+ from mrio_toolbox.utils import converters
16
+ from mrio_toolbox.utils.loaders import make_loader
17
+ from mrio_toolbox.utils.savers import save_part_to_folder,save_to_nc
18
+ from mrio_toolbox._parts import part_operations
19
+
20
# Module-level logger, namespaced under this module's import name.
log = logging.getLogger(__name__)
21
+
22
def load_part(**kwargs):
    """Load a Part through a loader built from the given keyword arguments.

    The same keyword arguments are forwarded to ``make_loader`` (to build
    the loader) and to the loader's ``load_part`` method. The mapping it
    returns is used as the keyword arguments of the new Part.

    Returns
    -------
    Part
    """
    part_loader = make_loader(**kwargs)
    part_kwargs = part_loader.load_part(**kwargs)
    return Part(**part_kwargs)
27
+
28
+ class Part:
29
def __init__(self,data=None,
             labels=None,
             axes=None,
             **kwargs):
    """MRIO Parts object

    MRIO Parts are the basic building blocks of the MRIO toolbox.
    A Part is built from a numpy array and a set of Axes,
    corresponding to the dimensions of the array.
    The Axes hold the labels of the Part in the different dimensions
    and are used to perform advanced indexing and operations on the Part.

    Axes support multi-level indexing and groupings.

    Parameters
    ----------
    data : numpy array, Part, xarray object or pandas DataFrame, optional
        Numerical data of the part.
        A Part is unwrapped to its data; xarray and pandas objects are
        converted through the corresponding converters.
        If left empty, a Part of zeros is created with a shape matching
        the axes.
    labels : list of str or dict, optional
        Labels of the axes.
        The upper level of the list correspond to each axe (numpy dimension)
        of the part.
        The lower level correspond to the dict of labels for each axe.
        Remember that an Axe can have different levels of labels.
    axes : list of Axe instances, optional
        Custom Axes for the Part.
        If left empty, the axes are created from the labels.
    kwargs : dict
        Additional metadata of the Part (e.g. path, multiplier, unit...).
        Keys with a specific meaning:
        name : str, name of the Part (default "new_part").
        groupings : dict of label level : dict, groupings of the labels,
        passed to the Axe objects.
        metadata : dict, base metadata, merged with the remaining kwargs.

    Raises
    ------
    ValueError
        If the length of an axis does not match the data.
    """
    if data is not None:
        if isinstance(data,Part):
            data = data.data
        if isinstance(data,(xr.DataArray,xr.Dataset)):
            self.__init__(
                **converters.xarray.make_part(data)
            )
            return
        if isinstance(data,pd.DataFrame):
            self.__init__(**converters.pandas.make_part(
                data)
            )
            return
        self.data = data
        self.ndim = data.ndim
        self.shape = self.data.shape

    self.name = kwargs.pop("name","new_part")
    log.debug("Create Part instance " + self.name)

    self.groupings = kwargs.get("groupings",dict())
    # Pop the "metadata" entry before merging: keeping it in kwargs would
    # store the previous metadata dict inside itself, one level deeper on
    # every alias/copy.
    base_metadata = kwargs.pop("metadata",None) or dict()
    self.metadata = {**base_metadata,**kwargs}

    if axes is None:
        self.axes = []
        self._create_axis(labels)
    else:
        self.axes = axes

    if data is None:
        log.debug("Create empty Part")
        data = np.zeros([len(ax) for ax in self.axes])
        self.data = data
        self.ndim = data.ndim
        self.shape = self.data.shape

    self.fix_dims()
    for dim in range(self.ndim):
        # Length-1 data dimensions are tolerated (broadcast-like axes)
        if len(self.axes[dim]) != self.shape[dim] and self.shape[dim] != 1:
            message = f"Length of label {dim} does not match data: "+\
                f"{len(self.axes[dim])} and {self.shape[dim]}"
            log.critical(message)
            raise ValueError(message)

    if "_original_dimensions" in self.metadata.keys():
        #This set of instructions is intended to handle
        #Data loaded from netcdf files
        log.info("Checking if a reformatting is needed.")
        original = self.metadata.pop("_original_dimensions")
        new_dims = [[]]
        for dim in original:
            #Decode the original dimensions
            #Because netcdf files do not support multi-level attributes,
            #axes are stored flat and separated by "_sep_" markers
            if dim == "_sep_":
                new_dims.append([])
            else:
                new_dims[-1].append(dim)

        if new_dims != self.get_dimensions():
            log.info("Reformat the Part")
            new_part = self.reformat(new_dims)
            self.data = new_part.data
            self.axes = new_part.axes
            self.ndim = self.data.ndim
            self.shape = self.data.shape

    self._store_labels()
+
140
+
141
def alias(self,**kwargs):
    """Return a new Part derived from this one with some fields overridden.

    The recognised overrides are ``data``, ``name``, ``groupings``,
    ``axes``, ``labels`` and ``metadata``. Fields that are not given are
    taken from the current Part. The name defaults to the current name
    with an ``_alias`` suffix, and labels default to None. The data array
    (given or inherited) is always copied.
    """
    defaults = {
        "data": self.data,
        "name": self.name+"_alias",
        "groupings": self.groupings,
        "axes": self.axes,
        "labels": None,
        "metadata": self.metadata,
    }
    settings = {key: kwargs.get(key,fallback) for key,fallback in defaults.items()}
    settings["data"] = settings["data"].copy()
    return Part(**settings)
159
+
160
def _create_axis(self,labels):
    """
    Build ``self.axes`` from a labels specification.

    Parameters
    ----------
    labels : None, list/tuple or dict
        - None: every axis is labelled by integer indices.
        - A flat list/tuple whose first item is a str, int or float is
          treated as the labels of a single axis.
        - A list/tuple of per-axis labels: item ``i`` labels axis ``i``.
        - A dict: keys are taken, in order, as axis names and values as
          the corresponding labels.
        When data is present, axes beyond the provided labels are
        labelled by integer indices.

    Raises
    ------
    ValueError
        Raised if neither data nor labels are available.
    TypeError
        Raised if the labels are of an unsupported type.

    Returns
    -------
    None.
    """
    self.axes = []

    if isinstance(labels,(tuple,list)) and len(labels)>0 and\
            isinstance(labels[0],(str,int,float)):
        #If the first item of the labels is not iterable,
        #we assume the label is an axis label:
        #add a level such that the enumeration works properly
        labels = [labels]

    if "data" in self.__dict__.keys():
        enum = self.ndim
    else:
        if labels is None:
            raise ValueError("Cannot create axes without data, axes or labels")
        enum = len(labels)

    for dim in range(enum):
        # ">=": labels[len(labels)] would be out of range
        if labels is None or dim >= len(labels):
            #Fill empty labels with indices
            self.axes.append(
                Axe([i for i in range(self.shape[dim])],
                    groupings = self.groupings)
            )
        elif isinstance(labels,dict):
            axname = list(labels.keys())[dim]
            self.axes.append(
                Axe(labels[axname],groupings=self.groupings,name=axname)
            )
        elif isinstance(labels,(list,tuple)):
            self.axes.append(
                Axe(labels[dim],groupings=self.groupings)
            )
        else:
            # str() is required: concatenating a str and a type raises
            log.critical("Unkown label type: "+str(type(labels)))
            raise TypeError("Unknown label type: "+str(type(labels)))
        log.debug(f"Create ax {dim} with len {len(self.axes[-1])}")
221
+
222
def fix_dims(self,
             skip_labels=False,
             skip_data=False):
    """Align the number of axes with the number of data dimensions.

    If there are more axes than data dimensions, axes of length 1 are
    dropped. If there are more data dimensions than axes, the data is
    squeezed, i.e. dimensions of length 1 are removed. Each fix is tried
    at most once before giving up.

    Parameters
    ----------
    skip_labels : bool, optional
        Do not try to drop length-1 axes. The default is False.
    skip_data : bool, optional
        Do not try to squeeze the data. The default is False.

    Raises
    ------
    IndexError
        If axes and data dimensions cannot be reconciled.
    """
    if len(self.axes) == self.ndim:
        return
    log.warning(
        f"The number of axes ({len(self.axes)})"\
        +f" does not match data dimensions ({self.ndim})"
    )

    if len(self.axes) > self.ndim and not skip_labels:
        log.debug(
            "Try to squeeze axe(s) of len 1"
        )
        # Python lists do not support boolean-mask indexing
        self.axes = [ax for ax in self.axes if len(ax) != 1]
        return self.fix_dims(
            skip_labels=True,
            skip_data=skip_data)
    if self.ndim > len(self.axes) and not skip_data:
        self.data = self.data.squeeze()
        self.shape = self.data.shape
        self.ndim = self.data.ndim
        return self.fix_dims(
            skip_labels=skip_labels,
            skip_data=True)

    # f-string: the counts are ints and cannot be concatenated to a str
    message = f"Cannot reconcile data of dims {self.ndim}"+\
        f" with axes of dim {len(self.axes)}"
    log.critical(message)
    raise IndexError(message)
257
+
258
+
259
def __getitem__(self,args):
    """Select data with ``part[...]`` (delegates to ``get``).

    A tuple or list is unpacked into one selection per positional argument
    of ``get``. A single str, int, numpy integer or dict key is passed as a
    single selection.
    """
    # A dict must be wrapped too: unpacking it would pass its keys only
    if isinstance(args,(str,int,np.integer,dict)):
        args = (args,)
    return self.get(*args)
263
+
264
def __setitem__(self,args,value):
    """Assign to a selection with ``part[...] = value`` (delegates to ``setter``).

    A tuple or list is unpacked into one selection per axis. A single str,
    int, numpy integer or dict key is passed as a single selection.
    Part values are replaced by their data array.
    """
    # Without wrapping, part["FR"] = v would unpack the string into the
    # separate selections "F" and "R", and part[0] = v would fail.
    if isinstance(args,(str,int,np.integer,dict)):
        args = (args,)
    if isinstance(value,Part):
        value = value.data
    self.setter(value,*args)
268
+
269
def setter(self,value,*args):
    """
    Change the value of a data selection

    Parameters
    ----------
    value : float, numpy like or Part
        Value to set. A Part is replaced by its data array.
    *args : list of tuples
        Indices along the respective axes. If they cannot be interpreted
        as one selection per axis, they are all interpreted together as a
        selection on the first axis.

    Returns
    -------
    None.
        Modification is applied to the current Part object
    """
    if isinstance(value,Part):
        # Part.squeeze returns a Part, which np.reshape cannot handle:
        # work on the raw array instead
        value = value.data
    try:
        #First tries to interpret one arg per ax
        sels = [self.axes[i].get(arg) for i,arg in enumerate(args)]
    except (IndexError,ValueError):
        #Otherwise, tries to interpret all args on the first ax
        sels = [self.axes[0].get(args)]
    if isinstance(value,np.ndarray) and len(sels) != value.ndim:
        if len(sels) < value.ndim:
            value = value.squeeze()
        if len(sels) > value.ndim:
            target_shape = [len(sel) for sel in sels]
            value = np.reshape(value,target_shape)
    # np.ix_ applies each selection independently along its dimension
    self.data[np.ix_(*sels)] = value
302
+
303
def get(self,*args,aspart=True,squeeze=False):
    """
    Extract data from the current Part object

    The arguments are first interpreted as one selection per axis. If that
    fails with a ValueError or IndexError, all arguments are instead passed
    together as a single selection on the first axis. Axes left without a
    selection are taken in full ("all").

    Parameters
    ----------
    *args : list of tuples
        Selection along the Axes of the Part.
    aspart : bool, optional
        Whether to return the selection as a Part object.
        If False, the selection is returned as a numpy object.
        The default is True.
    squeeze : bool, optional
        Whether to remove dimensions of length 1 (and their axes).
        The default is False

    Returns
    -------
    New Part object or numpy object
        A returned Part is named ``sel_<name>`` and carries the groupings
        of the current Part.

    """
    sels = []
    axes = []

    #Extract the indices for the selection
    try:
        #First tries to interpret one arg per ax
        for i,arg in enumerate(args):
            # Axe.get(..., True) is unpacked into the data indices, the
            # selected labels and their groupings (presumably - see Axe.get)
            datasel,labs,groupings = self.axes[i].get(arg,True)
            # With squeeze, axes reduced to a single entry are dropped
            if not squeeze or len(datasel) > 1:
                axes.append(Axe(labs,groupings))
            sels.append(datasel)
    except (ValueError, IndexError):
        sels = []
        axes = []
        #Try interpreting all args on the first ax
        datasel,labs,groupings = self.axes[0].get(args,True)
        if not squeeze or len(datasel) > 1:
            axes.append(Axe(labs,groupings))
        sels.append(datasel)

    #If the selection is not complete, fill with all
    if len(sels)<self.ndim:
        for i in range(len(sels),self.ndim):
            datasel,labs,groupings = self.axes[i].get("all",True)
            if not squeeze or len(datasel) > 1:
                axes.append(Axe(labs,groupings))
            sels.append(datasel)

    #Execute the selection
    # np.ix_ builds an open mesh: each selection applies independently
    # along its own dimension (cross-product indexing)
    data = self.data[np.ix_(*sels)]

    #Return the selection
    if squeeze:
        data = data.squeeze()
    if aspart:
        return Part(data=data,name=f"sel_{self.name}",
                    groupings=self.groupings,axes = axes)
    return data
362
+
363
def develop(self,axis=None,on=None,squeeze=True):
    """
    Reshape a Part so that selected label levels get their own axis.

    Parameters
    ----------
    axis : int or list of int, optional
        Axes to develop. All axes are developed if left empty.
        The default is None.
    on : str or list of str, optional
        Dimensions to develop. All dimensions are developed if left empty.
        Undeveloped dimensions of a developed axis are kept together on
        one axis placed after the developed ones. Developing
        non-contiguous dimensions is not supported.
        The default is None.
    squeeze : bool, optional
        Whether to remove dimensions of length 1.
        The default is True.

    Raises
    ------
    NotImplementedError
        If the development would change the order of the dimensions.

    Returns
    -------
    Part object
        Developped Part, named ``developped_<name>``.
    """
    if isinstance(on,str):
        on = [on]
    if isinstance(axis,int):
        axis = [axis]

    new_axes = []
    for position,current in enumerate(self.axes):
        if axis is not None and position not in axis:
            new_axes.append(current)
            continue
        remaining = current.labels.copy()
        for dimension in current.dimensions:
            if on is not None and dimension not in on:
                continue
            # Each developed dimension gets an axis of its own
            new_axes.append(
                Axe({dimension:remaining[dimension]},groupings=current.groupings)
            )
            remaining.pop(dimension)
        if len(remaining) > 0:
            # Dimensions that are not developed stay together
            new_axes.append(Axe(remaining,groupings=current.groupings))

    before = [dimension for current in self.axes for dimension in current.dimensions]
    after = [dimension for current in new_axes for dimension in current.dimensions]
    if after != before:
        raise NotImplementedError(
            "Developping the part misaligns the dimensions. "
            "This operation is not yet supported."
        )
    # Dimension order is unchanged, so a plain reshape is enough
    reshaped = self.data.reshape([len(current) for current in new_axes])
    developed = Part(data=reshaped,name=f"developped_{self.name}",
                     groupings=self.groupings,axes=new_axes)
    return developed.squeeze() if squeeze else developed
425
+
426
def reformat(self,new_dimensions):
    """
    Reshape a Part to match a new combination of dimensions.

    Equivalent to a combination of the develop and combine_axes methods.

    This only works for contiguous dimensions in the current Part,
    without overlapping dimensions.
    For example, if the Part has dimensions:
        [["countries"],["sectors"],["sectors"]]
    The following is allowed:
        [["countries","sectors"],["sectors"]]
    The following is not allowed:
        [["countries"],["sectors","sectors"]]
        [["sectors"],["countries","sectors"]]
        [["sectors","countries"],["sectors"]]

    Parameters
    ----------
    new_dimensions : list of list of str
        Target dimensions, one list of dimension names per axis.

    Returns
    -------
    Result of ``part_operations.reformat``; it exposes at least ``data``
    and ``axes`` attributes (presumably a Part - see part_operations).
    """
    reformatted = part_operations.reformat(self,new_dimensions)
    return reformatted
457
+
458
def combine_axes(self,start=0,end=None,in_place=False):
    """
    Combine consecutive axes of a Part into a single one.

    The order of dimensions is preserved in the new axis.
    The method can be used to revert the develop method.

    Parameters
    ----------
    start : int, optional
        Index of the first axis to combine, by default 0.
    end : int, optional
        Index of the last axis to combine (inclusive). By default None:
        all axes are combined, i.e. the Part is flattened.
    in_place : bool, optional
        If True, return the reshaped data and the new list of axes instead
        of a new Part. By default False.

    Returns
    -------
    Part instance, or (numpy array, list of Axe) if in_place is True.

    Raises
    ------
    IndexError
        Axes should have no overlapping dimensions.
    """
    last = self.ndim - 1 if end is None else end
    merged_labels,merged_groupings = dict(),dict()
    seen = []
    new_axes = []
    for position,current in enumerate(self.axes):
        if position in range(start,last+1):
            for dimension in current.dimensions:
                if dimension in seen:
                    raise IndexError(
                        "Cannot undevelop axes with overlapping dimensions"
                    )
                seen.append(dimension)
                merged_labels[dimension] = current.labels[dimension]
                if dimension in current.groupings.keys():
                    merged_groupings[dimension] = current.groupings[dimension]
        else:
            new_axes.append(current)
        if position == last:
            # The merged axis takes the place of the last combined axis
            new_axes.append(Axe(merged_labels,merged_groupings))

    reshaped = self.data.reshape([len(current) for current in new_axes])
    if in_place:
        return reshaped,new_axes
    return self.alias(data=reshaped,name=f"combined_{self.name}",
                      axes=new_axes)
510
+
511
def swap_axes(self,axis1,axis2):
    """Swap two axes of a Part

    Parameters
    ----------
    axis1 : int
        First axis to swap.
    axis2 : int
        Second axis to swap.

    Returns
    -------
    Part instance
        Part with swapped axes, named ``swapped_<name>``. Groupings and
        metadata of the current Part are kept.
    """
    axes = self.axes.copy()
    axes[axis1],axes[axis2] = axes[axis2],axes[axis1]
    data = self.data.swapaxes(axis1,axis2)
    # Go through alias: building a bare Part would drop the groupings
    # (breaking later aggregations) and the metadata
    return self.alias(data=data,name=f"swapped_{self.name}",axes=axes)
531
+
532
def swap_ax_levels(self,axis,dim1,dim2):
    """Swap two levels (dimensions) of an axis

    Parameters
    ----------
    axis : int
        Axis to modify.
    dim1 : str
        First dimension to swap.
    dim2 : str
        Second dimension to swap.

    Returns
    -------
    Part instance
        New Part with swapped levels, named ``swapped_<name>``.
        The current Part is left unchanged.
    """
    from copy import deepcopy

    current = self.axes[axis]
    len1,len2 = len(current.labels[dim1]),len(current.labels[dim2])
    if len1 == 1 or len2 == 1:
        # A level of length 1 does not affect the data layout: only the
        # axis changes. swap_levels is used for its side effect on the
        # Axe, so work on a copy to leave the current Part untouched.
        axes = self.axes.copy()
        axes[axis] = deepcopy(current)
        axes[axis].swap_levels(dim1,dim2)
        return self.alias(data=self.data,name=f"swapped_{self.name}",axes=axes)

    dimensions = current.dimensions
    id1,id2 = dimensions.index(dim1),dimensions.index(dim2)
    # Only `axis` is developed (and nothing is squeezed), so its first
    # level sits at position `axis` of the developed Part, whatever the
    # number of levels of the preceding axes.
    dev = self.develop(axis,squeeze=False)
    pos1,pos2 = axis+id1,axis+id2
    swapped_axes = dev.axes.copy()
    swapped_axes[pos1],swapped_axes[pos2] = swapped_axes[pos2],swapped_axes[pos1]
    # alias keeps the groupings of the developed Part
    swapped = dev.alias(data=dev.data.swapaxes(pos1,pos2),axes=swapped_axes)
    combined = swapped.combine_axes(axis,axis+len(dimensions)-1)
    combined.name = f"swapped_{self.name}"
    return combined
563
+
564
def flatten(self,invert=False):
    """Flatten a 2D Part into a 1D Part

    The levels of both axes are merged into a single axis. As in develop
    and combine_axes, the first level of the merged axis varies slowest.

    Parameters
    ----------
    invert : bool, optional
        If False (default), the levels of axis 0 come first (outer levels),
        followed by those of axis 1. If True, the levels of axis 1 come
        first.

    Returns
    -------
    Part instance
        1D Part named ``flattened_<name>``.

    Raises
    ------
    ValueError
        If the Part is not 2D, or if both axes share a dimension name
        (the merged labels would collide).
    """
    if self.ndim != 2:
        raise ValueError(f"Cannot flatten Part with {self.ndim} dimensions")
    if invert:
        # Axis 1 is the outer level, so axis 0 must vary fastest
        axis_order = range(self.ndim-1,-1,-1)
        order = "F"
    else:
        # Axis 0 is the outer level: row-major flattening
        axis_order = range(self.ndim)
        order = "C"
    labels = {}
    for i in axis_order:
        for dimension in self.axes[i].dimensions:
            if dimension in labels:
                raise ValueError(
                    f"Cannot flatten Part with overlapping dimension {dimension}"
                )
            labels[dimension] = self.axes[i].labels[dimension]
    ax = Axe(labels,self.groupings)
    return self.alias(data=self.data.flatten(order=order),
                      name=f"flattened_{self.name}",
                      axes =[ax])
592
+
593
+
594
def squeeze(self):
    """Remove the dimensions of length 1 from the Part.

    Axes of length 1 are dropped and the data is squeezed with
    ``np.squeeze``. Returns a new Part named ``squeezed_<name>``.
    """
    kept_axes = [current for current in self.axes if len(current) > 1]
    squeezed = np.squeeze(self.data)
    return self.alias(data=squeezed,axes=kept_axes,
                      name=f"squeezed_{self.name}")
601
+
602
def expand_dims(self,axis,copy=None):
    """Insert a new data dimension of length 1 in a Part instance.

    Parameters
    ----------
    axis : int
        Position of the new axis.
    copy : int, optional
        Index of an existing axis whose Axe is reused for the new axis.
        If left empty, the new axis is an Axe with a single "expanded"
        dimension labelled [0].

    Returns
    -------
    Part instance
        New Part named ``expanded_<name>``.
    """
    inserted = Axe({"expanded":[0]}) if copy is None else self.axes[copy]
    axes = self.axes.copy()
    axes.insert(axis,inserted)
    return self.alias(data = np.expand_dims(self.data,axis),
                      axes=axes,name=f"expanded_{self.name}")
620
+
621
def copy(self):
    """Return an independent copy of the current Part object.

    The data is copied by alias, and the axes are deep-copied here.
    Otherwise in-place changes to the copy's axes (e.g. through
    rename_labels) would also affect the original Part.
    """
    from copy import deepcopy
    return self.alias(axes=[deepcopy(ax) for ax in self.axes])
624
+
625
def extraction(self,
               dimensions,
               labels=None,
               on_groupings=True,
               domestic_only=False,
               axis="all"):
    """
    Set labels over dimension(s) to 0.

    Parameters
    ----------
    dimensions : str, list of str, dict
        Name of the dimensions on which the extraction is done.
        If dict is passed, the keys are interpreted as the dimensions
        and the values as the labels (the labels argument is ignored).
    labels : str, list of (list of) str, optional
        Selection on each dimension to put to 0. "all" selects every label
        of a dimension. The default (None) means "all" on every dimension.
        With a single dimension, a flat list is the selection on it.
    on_groupings : bool, optional
        Whether to use the groupings to select the labels.
        This matters only when the domestic_only argument is set to True:
        if False, group names are replaced by their members, each member
        being treated as a separate domestic block.
    domestic_only : bool, optional
        If yes, only domestic transactions are set to 0 and trade flows
        are left untouched. The default is False.
    axis : int, list of ints or "all", optional
        Axis along which the extraction is done. The default is "all".
        In any case, the extraction only applies to axes containing all
        the selected dimensions.

    Returns
    -------
    Part object
        New Part with selection set to 0.

    Raises
    ------
    ValueError
        If the numbers of dimensions and labels do not match, or if the
        extraction is not possible on the requested axes.
    """
    # Normalise into parallel lists of dimensions and labels.
    # The caller's objects are never modified.
    if isinstance(dimensions,dict):
        to_select = dict(dimensions)
        dimensions = list(to_select.keys())
        labels = list(to_select.values())
    else:
        if isinstance(dimensions,str):
            dimensions = [dimensions]
        dimensions = list(dimensions)
        if labels is None or (isinstance(labels,str) and labels == "all"):
            labels = ["all"]*len(dimensions)
        elif isinstance(labels,str):
            labels = [labels]
        else:
            labels = list(labels)
        if len(labels) != len(dimensions):
            if len(dimensions)==1:
                #If only one dimension is passed, we broadcast the labels
                labels = [labels]
            else:
                #Raise an error for ambiguous cases
                log.critical("Number of dimensions and labels do not match for extraction")
                raise ValueError("Number of dimensions and labels do not match for extraction")
        to_select = dict(zip(dimensions,labels))

    allowed = [
        i for i,ax in enumerate(self.axes)
        if all(dimension in ax.dimensions for dimension in dimensions)
    ]
    if len(allowed) == 0:
        log.critical("No axis found for extraction on "+str(dimensions))
        raise ValueError("No axis found for extraction on "+str(dimensions))
    if isinstance(axis,str) and axis == "all":
        log.info(f"Extract {to_select} on axes "+ str(allowed))
        axis = allowed
    if isinstance(axis,int):
        axis = [axis]
    if not all(ax in allowed for ax in axis):
        wrong = [ax for ax in axis if ax not in allowed]
        log.critical(f"Cannot extract {dimensions} on axis {wrong}")
        raise ValueError(f"Cannot extract {dimensions} on axis {wrong}")

    if not domestic_only:
        #If no domestic_only, we can simply set the selection to 0
        sel = ["all"]*self.ndim
        for i in axis:
            sel[i] = to_select
        output = self.copy()
        output[sel] = 0
        return output

    # Every axis in `axis` carries all the dimensions: use the first one
    # as the reference for labels and groupings
    reference = self.axes[axis[0]]
    expanded = []
    for dim,label in zip(dimensions,labels):
        if isinstance(label,str) and label == "all":
            if on_groupings and dim in self.groupings.keys():
                expanded.append(list(reference.groupings[dim].keys()))
            else:
                expanded.append(reference.labels[dim])
            continue
        # A single label must not be iterated character by character
        if isinstance(label,(list,tuple,set,np.ndarray)):
            selection = list(label)
        else:
            selection = [label]
        if not on_groupings and dim in self.groupings.keys():
            #Replace groups by their members, handled one by one
            developed = []
            for item in selection:
                if item in self.groupings[dim].keys():
                    developed.extend(self.groupings[dim][item])
                else:
                    developed.append(item)
            selection = developed
        expanded.append(selection)

    output = self.copy()
    for combination in itertools.product(*expanded):
        #Iteratively set domestic selections to 0
        seldict = dict(zip(dimensions,combination))
        sel = ["all"]*self.ndim
        for i in axis:
            sel[i] = seldict
        output[sel] = 0
    return output
736
+
737
def leontief_inversion(self):
    """Return the Leontief inverse (I - A)^-1 of a square 2D Part.

    The result is a new Part named ``l_<name>``.

    Raises
    ------
    ValueError
        If the Part is not 2D and square.
    """
    if not (self.ndim == 2 and self.issquare()):
        raise ValueError("Can only compute the Leontief inverse on square parts")
    identity = np.identity(len(self.axes[0]))
    inverse = np.linalg.inv(identity - self.data)
    return self.alias(name=f"l_{self.name}",data=inverse)
743
+
744
def zone(self):
    """
    Apply a zone-related operation chosen from the shape of the Part.

    Deprecated: the behaviour depends on the shape of the Part.
    - 2D, non-square Parts with "countries" on axis 1 are grouped with
      ``self.group(1, "countries")``.
    - 1D Parts are expanded with ``self.expand("countries")``.

    Raises
    ------
    AttributeError
        Parts with other shapes are rejected.

    Returns
    -------
    Part object
        Grouped or expanded Part.
    """
    log.warning("This function is deprecated as it returns different "
                "results depending on the shape of the Part. "
                "Use group or expand instead.")
    if self.ndim == 2 and not self.issquare() \
            and "countries" in self.axes[1].dimensions:
        return self.group(1,"countries")
    if self.ndim == 1:
        return self.expand("countries")
    raise AttributeError(f"Part {self.name} has no predefined grouping.")
773
+
774
def update_groupings(self,groupings,ax=None):
    """Update the groupings of the current Part object

    The Part-level groupings are replaced, and the selected axes are
    updated through their own ``update_groupings`` method.

    Parameters
    ----------
    groupings : dict
        Description of the groupings
    ax : int, list of int, optional
        Axes to update. If left empty, all axes are updated.
    """
    self.groupings = groupings
    if ax is None:
        ax = range(self.ndim)
    elif isinstance(ax,(int,np.integer)):
        # A single axis index is documented but is not iterable
        ax = [ax]
    for axe in ax:
        self.axes[axe].update_groupings(groupings)
787
+
788
def aggregate(self,on="countries",axis=None):
    """Aggregate dimensions along one or several axis.

    The groupings defined for the dimension are used to aggregate labels.
    If you want to sum over the dimension of an axis, use the sum method.

    If no axis is specified, the operation is applied to all axes carrying
    the dimension.

    Parameters
    ----------
    on : str or list of str, optional
        Dimension(s) to aggregate. A list is aggregated one dimension
        after the other. The default is "countries".
    axis : int or list of int, optional
        Axes along which the aggregation is done.
        If left empty, all axes carrying the dimension are used.

    Raises
    ------
    ValueError
        Raised if no groupings are defined for the dimension.

    Returns
    -------
    Part object
        Aggregated Part, named ``<on>_grouped_<name>``.
    """
    log.debug(f"Aggregate Part {self.name} along axis {axis} on {on}")

    if isinstance(on,list):
        # Keywords avoid mixing up the order of on/axis
        output = self
        for item in on:
            output = output.aggregate(on=item,axis=axis)
        return output
    if on not in self.groupings.keys():
        raise ValueError(f"No groupings defined for dimensions {on}")

    if axis is None:
        axis = self.hasax(
            on
        )
    if isinstance(axis,int):
        axis = [axis]

    output = self.alias()

    for ax in axis:
        output = output.aggregate_on(on,ax)

    output.name = f"{on}_grouped_{self.name}"
    return output
842
+
843
def aggregate_on(self,on,axis):
    """Aggregate a Part along a given axis

    The labels of dimension ``on`` on the axis are replaced by the group
    names of its groupings, and the data of each group is summed.

    Parameters
    ----------
    on : str
        Dimension to aggregate on
    axis : int
        Axis to aggregate

    Returns
    -------
    Part instance
        Aggregated Part, named ``<on>_grouped_<name>``.

    Raises
    ------
    ValueError
        If the dimension is not on the axis.
    """
    if on not in self.axes[axis].dimensions:
        raise ValueError(f"Dimension {on} not found in axis {axis}")

    new_labels = self.axes[axis].labels.copy()
    new_labels[on] = list(self.axes[axis].groupings[on].keys())
    new_groupings = self.groupings.copy()
    # After aggregation, each group only contains itself
    new_groupings[on] = {
        item : [item] for item in self.groupings[on]
    }

    new_axes = self.axes.copy()
    new_axes[axis] = Axe(new_labels,new_groupings)
    new_shape = [len(ax) for ax in new_axes]

    output = Part(axes=new_axes)
    idsum = new_axes[axis].dimensions.index(on) #Index of the dimension to sum on
    # No squeeze: dropping length-1 axes (e.g. a single group or a single
    # label elsewhere) would shift the position axis+idsum of the summed
    # dimension.
    ref_dev = self.develop(axis,squeeze=False)
    new_dev = output.develop(axis,squeeze=False)
    selector = ["all"]*ref_dev.ndim
    for label in new_labels[on]:
        selector[axis+idsum] = label
        new_dev[selector] = ref_dev[selector].sum(
            axis=axis+idsum,
            keepdims=True
        )
    aggregated = new_dev.data.reshape(new_shape)
    return self.alias(data=aggregated,name=f"{on}_grouped_{self.name}",
                      axes=new_axes)
886
+
887
+
888
def get_labels(self,axis=None):
    """
    Return the labels of the selected axes.

    Parameters
    ----------
    axis : int or list of int, optional
        Axes to investigate. By default None: all axes are investigated.

    Returns
    -------
    list
        The ``labels`` attribute of each selected axis, in order.
    """
    if axis is None:
        axis = range(self.ndim)
    if isinstance(axis,int):
        axis = [axis]
    return [self.axes[index].labels for index in axis]
912
+
913
def list_labels(self):
    """Merge the labels of all axes into a single dict.

    When a dimension appears on several axes, the labels from the first
    axis carrying it are kept.
    """
    merged = dict()
    for axis_labels in self.get_labels():
        for dimension,values in axis_labels.items():
            merged.setdefault(dimension,values)
    return merged
922
+
923
def get_dimensions(self,axis=None):
    """
    Return the dimensions of the selected axes.

    Parameters
    ----------
    axis : int or list of int, optional
        Axes to investigate. By default None: all axes are investigated.

    Returns
    -------
    list
        The ``dimensions`` attribute of each selected axis, in order.
    """
    if axis is None:
        axis = range(self.ndim)
    if isinstance(axis,int):
        axis = [axis]
    return [self.axes[index].dimensions for index in axis]
947
+
948
def rename_labels(self,old,new):
    """
    Rename a dimension on every axis carrying it.

    Parameters
    ----------
    old : str
        Name of the dimension to rename.
    new : str
        New dimension name.
    """
    for current in self.axes:
        if old not in current.dimensions:
            continue
        current.rename_labels(old,new)
    self._store_labels()
963
+
964
def replace_labels(self,name,labels,axis=None):
    """
    Replace the labels of a dimension of the Part.

    Parameters
    ----------
    name : str
        Name of the dimension to update.
    labels : dict or list
        New labels for the corresponding axes.
        A list is wrapped as ``{name: labels}``.
    axis : int, list of int, optional
        Axes on which the labels are changed; only axes carrying the
        dimension are updated. By default None: all axes are considered.
    """
    if axis is None:
        targets = range(self.ndim)
    elif isinstance(axis,int):
        targets = [axis]
    else:
        targets = axis
    new_labels = {name:labels} if isinstance(labels,list) else labels
    for index in targets:
        if name in self.axes[index].dimensions:
            self.axes[index].replace_labels(name,new_labels)
    self._store_labels()
989
+
990
def set_labels(self,labels,axis=None):
    """
    Change the labels of the Part.

    Parameters
    ----------
    labels : dict or list
        New labels, looked up by axis index (``labels[ax]``).
        A list with one entry per axis is converted to a dict keyed by
        axis index.
    axis : int or list of int, optional
        Axes on which the labels are changed. By default None: all axes
        are updated.
    """
    if axis is None:
        axis = range(self.ndim)
    if isinstance(axis,int):
        axis = [axis]
    if isinstance(labels,list) and len(labels) == self.ndim:
        labels = dict(enumerate(labels))
    for index in axis:
        self.axes[index].set_labels(labels[index])
    self._store_labels()
1012
+
1013
def _store_labels(self):
    """Cache the merged labels of all axes in ``self.labels``."""
    merged = self.list_labels()
    self.labels = merged
1016
+
1017
def add_labels(self,labels,dimension=None,axes=None,
               fill_value=0):
    """
    Append labels to one dimension of one or multiple Part axes.

    Parameters
    ----------
    labels : list or dict
        Labels to add. A list is appended to ``dimension``. For a dict,
        the first key is used as the dimension and its value as the labels
        to add (``dimension`` is then ignored; other keys are ignored).
    dimension : str, optional
        Dimension the new labels are appended to, when labels is a list.
    axes : int or list of ints, optional
        Axes to modify. If not specified, the axes carrying the dimension
        are detected with ``hasax``.
    fill_value : float, optional
        Value given to the entries created for the new labels.

    Returns
    ----------
    Part instance
        New Part with the additional labels. It has float64 data; the
        original data is put back at the positions of the original labels.
    """
    if isinstance(labels,list):
        labels = {dimension:labels}
    dimension = list(labels.keys())[0]
    if axes is None:
        #Identify the axes carrying the dimension
        axes = self.hasax(dimension)
    elif isinstance(axes,int):
        axes = [axes]

    new_axes = self.axes.copy()
    sel = ["all"]*self.ndim
    for ax in axes:
        log.debug("Add labels to axis "+str(ax))
        old_labels = self.axes[ax].labels
        new_labels = old_labels.copy()
        # list() so that tuples (or other sequences) can be concatenated
        new_labels[dimension] = list(old_labels[dimension]) + list(labels[dimension])
        new_axes[ax] = Axe(new_labels,self.groupings)
        # Selecting the original labels locates the original data
        sel[ax] = old_labels

    new_shape = [len(ax) for ax in new_axes]
    output = self.alias(data=np.full(new_shape,fill_value,dtype="float64"),
                        axes=new_axes)

    #Put original data back in place
    output[sel] = self.data
    return output
1075
+
1076
def expand(self, axis=None, over="countries"):
    """
    Expand an axis of the Part.

    Create a new Axe made of the labels of dimension ``over`` and insert it
    before the expanded axis. Each block of the expanded axis belonging to
    a given ``over`` label is moved to the matching slice of the new axis.
    Note that this operation significantly expands the size of the Part.
    It is recommended to use this method with Extension parts only.

    Parameters
    ----------
    axis : int or iterable of int, optional
        Axis to expand. If an iterable is given, its first element is used.
        If left empty, the first axis having the dimension ``over``
        is expanded.
    over : str, optional
        Axe dimension to expand the Part by.
        The default is "countries".

    Returns
    -------
    Part object
        New Part object with an additional dimension.

    Raises
    ------
    ValueError
        If no axis to expand can be found.
    """
    if axis is None:
        axis = self.hasax(over)
    if not isinstance(axis, (int, np.integer)):
        candidates = list(axis)
        if not candidates:
            raise ValueError(f"No axis of the Part has the dimension {over}")
        axis = candidates[0]
    axis = int(axis)

    ref_ax = self.axes[axis]
    new_ax = Axe(ref_ax.labels[over], groupings=self.groupings)
    axes = self.axes.copy()
    axes.insert(axis, new_ax)
    new_shape = list(self.shape)
    new_shape.insert(axis, len(new_ax))
    output = np.zeros(new_shape)
    selector = [slice(None)] * self.ndim
    for item in ref_ax.labels[over]:
        ref_idx = ref_ax.sel(item)
        newsel, refsel = selector.copy(), selector.copy()
        # New axis picks the item's slot; the old axis keeps its block
        newsel.insert(axis, new_ax.sel(item))
        newsel[axis + 1] = ref_idx
        refsel[axis] = ref_idx
        output[tuple(newsel)] = self.data[tuple(refsel)]
    return self.alias(data=output, name=f"expanded_{self.name}", axes=axes)
1118
+
1119
def issquare(self):
    """Return True if the Part is 2-D with axes of equal length."""
    if self.ndim != 2:
        return False
    rows, cols = self.axes[0], self.axes[1]
    return len(rows) == len(cols)
1122
+
1123
def hasneg(self):
    """Return True if at least one coefficient of the Part is negative."""
    negative_mask = self.data < 0
    return bool(np.any(negative_mask))
1128
+
1129
def hasax(self, name=None):
    """
    Return the axes along which the Part has a given dimension.

    An empty list is returned if no axis carries the dimension, so this
    method can be used to test whether a dimension exists in the Part.

    Parameters
    ----------
    name : str, optional
        Name of the dimension to look for.
        If None or "any", all axes are returned.

    Returns
    -------
    list of int
        Positions of the axes whose dimensions include ``name``.
    """
    if name is None or name == "any":
        return list(range(self.ndim))
    return [idx for idx, axe in enumerate(self.axes) if name in axe.dimensions]
1155
+
1156
+ def __str__(self):
1157
+ return f"{self.name} Part object with {self.ndim} dimensions"
1158
+
1159
def sum(self, axis=None, on=None, keepdims=False):
    """
    Sum the Part along one or several axes, and/or on a given dimension.

    Parameters
    ----------
    axis : int or list of int, optional
        Axis or axes along which the sum is evaluated.
        By default None: if ``on`` is also None, the sum of all
        coefficients is returned as a scalar.
    on : str, optional
        Name of the dimension to sum on.
        If no axis is given, every axis carrying this dimension is summed.
        By default None, the full axis is summed.
    keepdims : bool, optional
        Whether to keep the number of dimensions of the original.
        By default False: summed axes are removed (and, when only ``on``
        is given, the Part is squeezed first).

    Returns
    -------
    Part instance or scalar
        Result of the sum.
    """
    part = self
    if axis is None:
        if on is None:
            return part.data.sum()
        if not keepdims:
            part = part.squeeze()
        axis = part.hasax(on)
    if not isinstance(axis, int):
        # Highest axis first, so that lower positions remain valid
        for single in sorted(axis, reverse=True):
            part = part.sum(single, on, keepdims)
        return part
    if on is not None:
        return part._sum_on(axis, on, keepdims)
    kept_axes = part.axes.copy()
    if keepdims:
        kept_axes[axis] = Axe(["all"])
    else:
        del kept_axes[axis]
    return part.alias(
        data=part.data.sum(axis, keepdims=keepdims),
        name=f"{part.name}_sum_{axis}",
        axes=kept_axes,
    )
1205
+
1206
+ def _sum_on(self,axis,on,keepdims=False):
1207
+ """
1208
+ Sum a Part along an axis on a given dimension
1209
+ """
1210
+ ax = self.axes[axis]
1211
+ if on not in ax.dimensions:
1212
+ raise ValueError(f"Cannot sum on {on} as it is not a dimension of axis {axis}")
1213
+ if ax.levels == 1:
1214
+ #If the axis has a single level, this is a simple sum
1215
+ return self.alias(
1216
+ data = self.data.sum(axis,keepdims=keepdims),
1217
+ name=f"{self.name}_sum_{axis}",
1218
+ axes = self.axes
1219
+ )
1220
+ #Otherwise, sum on the relevant levels
1221
+ idsum = ax.dimensions.index(on) #Index of the dimension to sum on
1222
+ dev = self.develop(axis,squeeze=False)
1223
+ dev = dev.sum(axis+idsum,keepdims=keepdims)
1224
+ if keepdims:
1225
+ dev = dev.combine_axes(axis,axis+idsum)
1226
+ dev.name = f"{self.name}_sum_on_{on}_{axis}"
1227
+ return dev
1228
+
1229
def save(self,
         file=None,
         name=None,
         extension=".npy",
         overwrite=False,
         include_labels=False,
         write_instructions=False,
         **kwargs):
    """
    Save the Part object to a file.

    Parameters
    ----------
    file : path-like, optional
        Full path to the file to save the Part to.
        This overrides the ``path`` and ``name`` arguments, and the
        extension if the file name has one. A bare file name is saved in
        the current directory.
    name : str, optional
        Name under which the Part is saved.
        By default, the name of the Part is used.
    extension : str, optional
        Format under which the Part is saved. The default is ".npy".
        ".nc" files are written with ``save_to_nc``; other extensions
        with ``save_part_to_folder``.
    overwrite : bool, optional
        Whether to overwrite an existing file. The default is False.
    include_labels : bool, optional
        Whether to include the labels in the saved file.
        Only applicable to .csv and .xlsx files.
    write_instructions : bool, optional
        Whether to write the loading instructions to a yaml file.
        The default is False.
    **kwargs : dict
        ``path`` (path-like): directory in which the Part is saved.
        Remaining arguments are passed to the saving function.

    Raises
    ------
    FileNotFoundError
        If neither ``file`` nor ``path`` is given.
    """
    # Pop rather than get: ``path`` is passed explicitly below, and leaving
    # it in kwargs made the call fail with "multiple values for 'path'".
    path = kwargs.pop("path", None)
    if file is not None:
        path, name = os.path.split(file)
        path = path or os.curdir
        name, possible_extension = os.path.splitext(name)
        if possible_extension != "":
            extension = possible_extension
    if name is None:
        name = self.name
    if path is None:
        raise FileNotFoundError("No path specified for saving the Part")
    if extension == ".nc":
        save_to_nc(self, os.path.join(path, name), overwrite,
                   write_instructions=write_instructions,
                   **kwargs)
    else:
        save_part_to_folder(
            self,
            path=path,
            name=name,
            extension=extension,
            overwrite=overwrite,
            include_labels=include_labels,
            write_instructions=write_instructions,
            **kwargs
        )
1298
+
1299
def to_pandas(self):
    """Convert the Part to a pandas DataFrame.

    Delegates to ``converters.pandas.to_pandas``.
    Only applicable to Part objects with 1 or 2 dimensions.
    """
    pandas_converter = converters.pandas
    return pandas_converter.to_pandas(self)
1305
+
1306
def to_xarray(self):
    """
    Convert the Part object to an xarray DataArray.

    Delegates to ``converters.xarray.to_DataArray``.
    Labels are passed to the DataArray as coords.
    Note that data will be flattened.
    The dimension order will be saved as an attribute.
    If you're loading the data back,
    the Part will be automatically reshaped to its original dimensions.

    Returns
    -------
    xr.DataArray
        Corresponding DataArray
    """
    xarray_converter = converters.xarray
    return xarray_converter.to_DataArray(self)
1322
+
1323
def mean(self, axis=None):
    """
    Mean of the Part data.

    Parameters
    ----------
    axis : int or tuple of int, optional
        Axis along which the mean is taken.
        By default None, the mean of all coefficients.

    Returns
    -------
    numpy scalar or numpy array
        Raw numpy result; labels are not propagated.
    """
    values = self.data
    return values.mean(axis=axis)
1325
+
1326
def min(self, axis=None):
    """
    Minimum of the Part data.

    Parameters
    ----------
    axis : int or tuple of int, optional
        Axis along which the minimum is taken.
        By default None, the minimum over all coefficients.

    Returns
    -------
    numpy scalar or numpy array
        Raw numpy result; labels are not propagated.
    """
    values = self.data
    return values.min(axis=axis)
1328
+
1329
def max(self, axis=None):
    """
    Maximum of the Part data.

    Parameters
    ----------
    axis : int or tuple of int, optional
        Axis along which the maximum is taken.
        By default None, the maximum over all coefficients.

    Returns
    -------
    numpy scalar or numpy array
        Raw numpy result; labels are not propagated.
    """
    values = self.data
    return values.max(axis=axis)
1331
+
1332
def mul(self, a, propagate_labels=True):
    """
    Matrix multiplication between parts with labels propagation.

    Parameters
    ----------
    a : Part or array-like
        Right-hand multiplicator.
    propagate_labels : bool, optional
        Whether to propagate the labels of the trailing axes of the
        right-hand multiplicator. By default True.
        Forced to False if the right-hand multiplicator is not a Part;
        plain integer labels are then generated.

    Returns
    -------
    Part instance
        Result of the multiplication.
    """
    if isinstance(a, Part):
        name = a.name
        other = a.data
    else:
        propagate_labels = False
        name = "array"
        # ndarray.data is a raw memory buffer, not the values: convert the
        # operand instead (this also accepts nested lists).
        other = np.asarray(a)
    data = np.matmul(self.data, other)
    # All axes of self but the contracted last one...
    axes = [self.axes[i] for i in range(self.ndim - 1)]
    # ...followed by all axes of the operand but the contracted first one
    for ax in range(1, other.ndim):
        if propagate_labels:
            axes.append(a.axes[ax])
        else:
            axes.append(Axe(list(range(other.shape[ax]))))
    return self.alias(data=data, name=f"{self.name}.{name}", axes=axes)
1363
+
1364
def filter(self, threshold, fill_value=0):
    """
    Replace the values strictly below a threshold.

    Parameters
    ----------
    threshold : float
        Values strictly lower than this are replaced.
    fill_value : float, optional
        Replacement value. The default is 0.

    Returns
    -------
    Part instance
        Filtered copy of the Part; the original data is left untouched.
    """
    below = self.data < threshold
    filtered = self.data.copy()
    filtered[below] = fill_value
    return self.alias(data=filtered, name=f"filtered_{self.name}_{threshold}")
1385
+
1386
def diag(self):
    """
    Diagonal of the Part, following ``numpy.diag`` semantics.

    - A 1-D Part becomes a square 2-D Part with its values on the
      diagonal; the axis is used for both dimensions.
    - A square 2-D Part returns its diagonal as a 1-D Part labelled
      with the first axis.

    Returns
    -------
    Part instance

    Raises
    ------
    ValueError
        If the Part is neither 1-D nor square 2-D
        (previously None was returned silently).
    """
    name = f"diag_{self.name}"
    if self.ndim == 1:
        return self.alias(data=np.diag(self.data),
                          name=name,
                          axes=self.axes * 2)
    if self.ndim == 2 and self.data.shape[0] == self.data.shape[1]:
        return self.alias(data=np.diag(self.data),
                          name=name,
                          axes=[self.axes[0]])
    raise ValueError(
        f"diag requires a 1-D or square 2-D Part, got shape {self.data.shape}")
1391
+
1392
def __add__(self, a):
    """
    Element-wise addition with a Part, array or scalar.

    If the operands differ in number of dimensions, both are squeezed
    before numpy broadcasting is applied.
    """
    if isinstance(a, Part):
        other_name, a = a.name, a.data
    else:
        other_name = ""
    part = self
    if isinstance(a, np.ndarray) and part.ndim != a.ndim:
        # Rank mismatch: drop singleton dimensions on both sides
        a = a.squeeze()
        part = part.squeeze()
    return part.alias(data=a + part.data, name=f"{part.name}+{other_name}")
1402
+
1403
+ def __radd__(self,a):
1404
+ return self.__add__(a)
1405
+
1406
+ def __rmul__(self,a):
1407
+ return self.__mul__(a)
1408
+
1409
def __mul__(self, a):
    """
    Element-wise multiplication with a Part, array or scalar.

    numpy broadcasting is trusted for the product; if the result's number
    of dimensions differs from the Part's, it is squeezed to drop unused
    singleton dimensions.
    """
    if isinstance(a, Part):
        # Was a plain string (missing f prefix), giving a literal name
        name = f"{a.name}*{self.name}"
        a = a.data
    elif isinstance(a, (int, float)):
        name = f"{a}*{self.name}"
    else:
        name = f"array*{self.name}"
    data = self.data * a
    if data.ndim != self.ndim:
        data = data.squeeze()
    return self.alias(data=data, name=name)
1425
+
1426
+ def __neg__(self):
1427
+ return self.alias(data=-self.data,name=f"-{self.name}")
1428
+
1429
def __sub__(self, a):
    """Element-wise subtraction ``self - a`` with a Part, array or scalar."""
    if isinstance(a, Part):
        label, operand = a.name, a.data
    else:
        label, operand = "", a
    return self.alias(data=self.data - operand, name=f"{self.name}-{label}")
1436
+
1437
def __rsub__(self, a):
    """Right-hand element-wise subtraction ``a - self``."""
    if isinstance(a, Part):
        label, operand = a.name, a.data
    else:
        label, operand = "", a
    return self.alias(data=operand - self.data, name=f"{label}-{self.name}")
1444
+
1445
def power(self, a):
    """
    Element-wise power of the Part.

    Parameters
    ----------
    a : Part, scalar or array-like
        Exponent.

    Returns
    -------
    Part instance
    """
    if isinstance(a, Part):
        # Read the name before replacing the Part by its data
        name = f"{self.name}**{a.name}"
        a = a.data
    elif isinstance(a, (int, float)):
        name = f"{self.name}**{a}"
    else:
        name = f"{self.name}**array"
    # Operate on the raw data: passing the Part itself to np.power routed
    # back through __pow__ and recursed.
    return self.alias(data=np.power(self.data, a), name=name)
1454
+
1455
+ def __pow__(self,a):
1456
+ return self.power(a)
1457
+
1458
def __eq__(self, a):
    """
    Two Parts are equal if their data have the same shape and values.

    Labels are not compared. Any non-Part operand compares unequal.
    """
    if not isinstance(a, Part):
        return False
    # array_equal checks shapes first: np.all(x == y) broadcast mismatched
    # shapes (reporting e.g. [1, 1] == [1]) or raised on incompatible ones.
    return bool(np.array_equal(self.data, a.data))
1462
+
1463
def __rtruediv__(self, a):
    """
    Right-hand element-wise division ``a / self``.

    A warning is logged if the Part contains zeros.
    """
    if isinstance(a, Part):
        # a is the numerator: name it first (the order was reversed)
        name = f"{a.name}/{self.name}"
        a = a.data
    elif isinstance(a, (int, float)):
        name = f"{a}/{self.name}"
    else:
        name = f"array/{self.name}"
    if np.any(self.data == 0):
        log.warning("Division by zero in %s", name)
    return self.alias(data=a / self.data,
                      name=name)
1476
+
1477
def __truediv__(self, a):
    """
    Element-wise division ``self / a``.

    A warning is logged if the divisor contains zeros.
    """
    if isinstance(a, Part):
        # self is the numerator: name it first (the order was reversed)
        name = f"{self.name}/{a.name}"
        a = a.data
    elif isinstance(a, (int, float)):
        name = f"{self.name}/{a}"
    else:
        name = f"{self.name}/array"
    # asarray so the zero check also works for list divisors
    if np.any(np.asarray(a) == 0):
        log.warning("Division by zero in %s", name)
    return self.alias(data=self.data / a,
                      name=name)
1490
+
1491
def __getattr__(self, name):
    """
    Fall back to the Part metadata for unknown attributes.

    Called only when normal attribute lookup fails. The name is
    casefolded and looked up in ``self.metadata``.

    Raises
    ------
    AttributeError
        If the metadata is missing or has no such key.
    """
    key = name.casefold()
    # Bypass __getattr__ when reading metadata: on a partially built
    # instance (e.g. during copy/unpickling) self.metadata would re-enter
    # this method and recurse until RecursionError.
    try:
        metadata = object.__getattribute__(self, "metadata")
    except AttributeError:
        metadata = None
    if metadata is not None:
        try:
            return metadata[key]
        except (KeyError, TypeError):
            pass
    raise AttributeError(f"Attribute {key} not found")
1498
+
1499
def transpose(self):
    """Transpose the Part: data axes and labels are both reversed."""
    flipped_axes = self.axes[::-1]
    flipped_data = self.data.transpose()
    return self.alias(data=flipped_data,
                      name=f"transposed_{self.name}",
                      axes=flipped_axes)
1503
+
1504
+