sqil-core 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+
3
+ from ._const import _PARAM_METADATA
4
+ from ._formatter import format_number
5
+ from ._read import read_json
6
+
7
+
8
def set_plot_style(plt):
    """Apply the SQIL-curated matplotlib plotting style.

    Parameters
    ----------
    plt : module
        ``matplotlib.pyplot`` (or any object exposing an ``rcParams``
        mapping). Its ``rcParams`` are updated in place.

    Returns
    -------
    None
        Whatever ``plt.rcParams.update`` returns.
    """
    # Font and tick-label sizes
    text_style = {
        "font.size": 20,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
    }
    # Thick lines with hollow circular markers
    line_style = {
        "lines.linewidth": 2.5,
        "lines.marker": "o",
        "lines.markersize": 7,
        "lines.markeredgewidth": 1.5,
        "lines.markerfacecolor": "none",
    }
    # Dashed grid and enlarged major ticks
    axes_style = {
        "axes.grid": True,
        "grid.linestyle": "--",
        "xtick.major.size": 8,
        "xtick.major.width": 1.5,
        "ytick.major.size": 8,
        "ytick.major.width": 1.5,
    }
    figure_style = {"figure.figsize": (20, 7)}

    combined = {**text_style, **line_style, **axes_style, **figure_style}
    return plt.rcParams.update(combined)
28
+
29
+
30
def reset_plot_style(plt):
    """Restore matplotlib's default style.

    Every entry of ``plt.rcParamsDefault`` is copied back into
    ``plt.rcParams`` in place. Keys absent from the defaults are left alone.

    Parameters
    ----------
    plt : module
        ``matplotlib.pyplot`` (or any object exposing ``rcParams`` and
        ``rcParamsDefault``).

    Returns
    -------
    None
        Whatever ``plt.rcParams.update`` returns.
    """
    defaults = plt.rcParamsDefault
    current = plt.rcParams
    return current.update(defaults)
33
+
34
+
35
+ def get_x_id_by_plot_dim(exp_id: str, plot_dim: str, sweep_param_id: str | None) -> str:
36
+ """Returns the param_id of the parameter that should be used as the x-axis."""
37
+ if exp_id == "CW_onetone" or exp_id == "pulsed_onetone":
38
+ if plot_dim == "1":
39
+ return sweep_param_id or "ro_freq"
40
+ return "ro_freq"
41
+ elif exp_id == "CW_twotone" or exp_id == "pulsed_twotone":
42
+ if plot_dim == "1":
43
+ return sweep_param_id or "qu_freq"
44
+ return "qu_freq"
45
+
46
+
47
def build_title(title: str, path: str, params: list[str]) -> str:
    """Build a plot title that includes the values of given parameters found in
    the param_dict.json file, e.g. One tone with I = 0.5 mA.

    Parameters
    ----------
    title : str
        Title of the plot to which the parameters will be appended.

    path : str
        Path to the folder containing param_dict.json, or to the
        param_dict.json file itself.

    params : list[str]
        Keys of the parameters in param_dict.json to show in the title.

    Returns
    -------
    str
        ``title`` followed by " with " and the parameter entries joined by
        " & ". A line break follows the separator after the 3rd, 5th, 7th, ...
        entry. Parameters without metadata, or missing from param_dict.json,
        are shown as "<param> = ?". If ``params`` is empty, ``title`` is
        returned unchanged and no file is read.
    """
    import os

    if not params:
        # Nothing to append: avoid a dangling " with " (or a truncated title)
        return title

    # Accept both the measurement folder and the param_dict.json file itself
    if os.path.isdir(path):
        path = os.path.join(path, "param_dict.json")
    param_values = read_json(path)

    entries = []
    for param in params:
        if param not in _PARAM_METADATA or param not in param_values:
            # Unknown parameter, or value not recorded in param_dict.json
            entries.append(f"{param} = ?")
            continue
        meta = _PARAM_METADATA[param]
        value = format_number(param_values[param], 3, meta["unit"])
        entries.append(f"${meta['symbol']} =${value}")

    # Join the entries; separators are only placed *between* entries so no
    # trailing characters need to be trimmed afterwards
    suffix = ""
    last = len(entries) - 1
    for idx, entry in enumerate(entries):
        suffix += entry
        if idx == last:
            break
        suffix += " & "
        # Wrap after the 3rd, 5th, 7th, ... entry to keep long titles readable
        if idx % 2 == 0 and idx != 0:
            suffix += "\n"
    return f"{title} with {suffix}"
79
+
80
+
81
+ def guess_plot_dimension(
82
+ f: np.ndarray, sweep: np.ndarray | list = [], threshold_2D=10
83
+ ) -> tuple[list["1", "1.5", "2"] | np.ndarray]:
84
+ """Guess if the plot should be a 1D line, a collection of 1D lines (1.5D),
85
+ or a 2D color plot.
86
+
87
+ Parameters
88
+ ----------
89
+ f : np.ndarray
90
+ Main variable, usually frequency
91
+ sweep : Union[np.ndarray, List], optional
92
+ Sweep variable, by default []
93
+ threshold_2D : int, optional
94
+ Threshold of sweeping parameters after which the data is considered, by default 10
95
+
96
+ Returns
97
+ -------
98
+ Tuple[Union['1', '1.5', '2'], np.ndarray]
99
+ The plot dimension ('1', '1.5' or '2') and the vector that should be used as the x
100
+ axis in the plot.
101
+ """
102
+ if len(sweep) > threshold_2D:
103
+ return "2"
104
+ elif len(f.shape) == 2 and len(sweep.shape) == 1:
105
+ return "1.5"
106
+ else:
107
+ return "1"
@@ -1,156 +1,179 @@
1
- import json
2
- import os
3
-
4
- import h5py
5
- import numpy as np
6
-
7
- from .const import PARAM_METADATA
8
-
9
-
10
def extract_h5_data(
    path: str, keys: list[str] | None = None
) -> dict | tuple[np.ndarray, ...]:
    """Read datasets from the ``data`` group of an HDF5 file.

    Parameters
    ----------
    path : str
        Path to an HDF5 file, or to a folder containing a ``data.ddh5`` file.
    keys : list of str or None, optional
        Names of the datasets to read from ``file['data']``. When None or
        empty, the whole ``data`` group is returned as a dictionary.

    Returns
    -------
    dict or tuple of np.ndarray
        Without keys, a dictionary built from ``file['data']``.
        Otherwise, a tuple with one entry per requested key, in order.
        Keys that are empty or not present in the file give an empty list ``[]``.

    Example
    -------
    Extract the data object from the dataset:
    >>> data = extract_h5_data(path)
    Extracting only 'amp' and 'phase' from the dataset:
    >>> amp, phase = extract_h5_data(path, ['amp', 'phase'])
    Extracting only 'phase':
    >>> phase, = extract_h5_data(path, ['phase'])
    """
    # A folder was given: read the data.ddh5 file it contains
    file_path = os.path.join(path, "data.ddh5") if os.path.isdir(path) else path

    with h5py.File(file_path, "r") as h5file:
        group = h5file["data"]
        wants_subset = bool(keys) and len(keys) > 0
        if not wants_subset:
            return _h5_to_dict(group)

        available = group.keys()
        extracted = []
        for raw_key in keys:
            name = str(raw_key)
            if name and name in available:
                extracted.append(np.array(group[name][:]))
            else:
                # Placeholder keeps the output aligned with the requested keys
                extracted.append([])
        return tuple(extracted)
57
-
58
-
59
def _h5_to_dict(obj) -> dict:
    """Recursively convert h5 data (an h5py Group or File) into a dictionary.

    Datasets become their stored values (numpy arrays, or scalars for 0-d
    datasets). Sub-groups become nested dictionaries. Any other kind of
    item is skipped.
    """
    data_dict = {}
    for key in obj.keys():
        item = obj[key]
        if isinstance(item, h5py.Dataset):
            # item[()] reads the whole dataset and, unlike item[:], also
            # works for scalar (0-d) datasets
            data_dict[key] = item[()]
        elif isinstance(item, h5py.Group):
            # Recurse on the in-memory group; extract_h5_data expects a
            # filesystem path and cannot be given a Group
            data_dict[key] = _h5_to_dict(item)
    return data_dict
69
-
70
-
71
def read_json(path: str) -> dict:
    """Read a JSON file and return its content as a dictionary.

    The file is decoded as UTF-8, the encoding required by the JSON standard,
    rather than the platform locale encoding. This way files with non-ASCII
    characters are read identically on every OS (e.g. on Windows, where the
    locale encoding is often not UTF-8).
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
76
-
77
-
78
class ParamInfo:
    """Metadata-enriched view of a single param_dict entry.

    Attributes:
        id (str): the param_dict key
        value (any): the value of the parameter
        name (str): full name (e.g. Readout frequency); defaults to ``id``
        symbol (str): LaTeX symbol (e.g. f_{RO}); defaults to ``id``
        unit (str): base unit of measurement (e.g. Hz); defaults to ""
        scale (int): scale generally applied to raw data (e.g. 1e-9 to take
            raw Hz to GHz); defaults to 1
    """

    def __init__(self, id, value):
        self.id = id
        self.value = value
        # Keys without an entry in the metadata table fall back to neutral defaults
        meta = PARAM_METADATA.get(id, {})
        self.name = meta.get("name", id)
        self.symbol = meta.get("symbol", id)
        self.unit = meta.get("unit", "")
        self.scale = meta.get("scale", 1)

    def to_dict(self):
        """Return the entry and its metadata as a plain dictionary."""
        return dict(
            id=self.id,
            value=self.value,
            name=self.name,
            symbol=self.symbol,
            unit=self.unit,
            scale=self.scale,
        )

    def __str__(self):
        """Serialize the entry as a JSON string."""
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    def __eq__(self, other):
        """Compare id and value with another ParamInfo, or the value with a scalar/str.

        Any other type compares unequal.
        """
        if isinstance(other, ParamInfo):
            same_id = self.id == other.id
            same_value = self.value == other.value
            return same_id & same_value
        if not isinstance(other, (int, float, complex, str)):
            return False
        return self.value == other
123
-
124
-
125
# An enriched param_dict: each key maps either to a ParamInfo or to a
# dictionary of ParamInfo objects (one nesting level in this alias).
ParamDict = dict[str, dict[str, ParamInfo] | ParamInfo]
126
-
127
-
128
def _enrich_param_dict(param_dict: dict) -> ParamDict:
    """Wrap every leaf value of ``param_dict`` in a ParamInfo.

    Nested dictionaries are processed recursively and keep their structure.
    Any non-dict value becomes ``ParamInfo(key, value)``.
    """
    return {
        key: (
            _enrich_param_dict(value)
            if isinstance(value, dict)
            else ParamInfo(key, value)
        )
        for key, value in param_dict.items()
    }
138
-
139
-
140
def read_param_dict(path: str) -> ParamDict:
    """Load a param_dict JSON file and attach metadata to each entry.

    Parameters
    ----------
    path : str
        Path to the JSON file, or to a folder containing a param_dict.json file.

    Returns
    -------
    ParamDict
        The param_dict with every leaf value wrapped with its metadata.
    """
    # A folder was given: use the param_dict.json it contains
    json_path = os.path.join(path, "param_dict.json") if os.path.isdir(path) else path
    raw = read_json(json_path)
    return _enrich_param_dict(raw)
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+
7
+ from ._const import _EXP_UNIT_MAP, _PARAM_METADATA
8
+
9
+
10
def extract_h5_data(
    path: str, keys: list[str] | None = None
) -> dict | tuple[np.ndarray, ...]:
    """Read datasets from the ``data`` group of an HDF5 file.

    Parameters
    ----------
    path : str
        Path to an HDF5 file, or to a folder containing a ``data.ddh5`` file.
    keys : list of str or None, optional
        Names of the datasets to read from ``file['data']``. When None or
        empty, the whole ``data`` group is returned as a dictionary.

    Returns
    -------
    dict or tuple of np.ndarray
        Without keys, a dictionary built from ``file['data']``.
        Otherwise, a tuple with one entry per requested key, in order.
        Keys that are empty or not present in the file give an empty list ``[]``.

    Example
    -------
    Extract the data object from the dataset:
    >>> data = extract_h5_data(path)
    Extracting only 'amp' and 'phase' from the dataset:
    >>> amp, phase = extract_h5_data(path, ['amp', 'phase'])
    Extracting only 'phase':
    >>> phase, = extract_h5_data(path, ['phase'])
    """
    # A folder was given: read the data.ddh5 file it contains
    file_path = os.path.join(path, "data.ddh5") if os.path.isdir(path) else path

    with h5py.File(file_path, "r") as h5file:
        group = h5file["data"]
        wants_subset = bool(keys) and len(keys) > 0
        if not wants_subset:
            return _h5_to_dict(group)

        available = group.keys()
        extracted = []
        for raw_key in keys:
            name = str(raw_key)
            if name and name in available:
                extracted.append(np.array(group[name][:]))
            else:
                # Placeholder keeps the output aligned with the requested keys
                extracted.append([])
        return tuple(extracted)
57
+
58
+
59
def _h5_to_dict(obj) -> dict:
    """Recursively convert h5 data (an h5py Group or File) into a dictionary.

    Datasets become their stored values (numpy arrays, or scalars for 0-d
    datasets). Sub-groups become nested dictionaries. Any other kind of
    item is skipped.
    """
    data_dict = {}
    for key in obj.keys():
        item = obj[key]
        if isinstance(item, h5py.Dataset):
            # item[()] reads the whole dataset and, unlike item[:], also
            # works for scalar (0-d) datasets
            data_dict[key] = item[()]
        elif isinstance(item, h5py.Group):
            # Recurse on the in-memory group; extract_h5_data expects a
            # filesystem path and cannot be given a Group
            data_dict[key] = _h5_to_dict(item)
    return data_dict
69
+
70
+
71
def read_json(path: str) -> dict:
    """Read a JSON file and return its content as a dictionary.

    The file is decoded as UTF-8, the encoding required by the JSON standard,
    rather than the platform locale encoding. This way files with non-ASCII
    characters are read identically on every OS (e.g. on Windows, where the
    locale encoding is often not UTF-8).
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
76
+
77
+
78
class ParamInfo:
    """Metadata-enriched view of a single param_dict entry.

    Attributes:
        id (str): the param_dict key
        value (any): the value of the parameter
        name (str): full name (e.g. Readout frequency); defaults to ``id``
        symbol (str): LaTeX symbol (e.g. f_{RO}); defaults to ``id``
        unit (str): base unit of measurement (e.g. Hz); defaults to ""
        scale (int): scale generally applied to raw data (e.g. 1e-9 to take
            raw Hz to GHz); defaults to 1
    """

    def __init__(self, id, value):
        self.id = id
        self.value = value
        # Keys without an entry in the metadata table fall back to neutral defaults
        meta = _PARAM_METADATA.get(id, {})
        self.name = meta.get("name", id)
        self.symbol = meta.get("symbol", id)
        self.unit = meta.get("unit", "")
        self.scale = meta.get("scale", 1)

    def get_name_and_unit(self):
        """Return the name, followed by the scaled unit in brackets when a unit is set.

        The unit prefix is looked up in ``_EXP_UNIT_MAP``. The key is the
        decimal exponent of ``scale``, floored to a multiple of 3 and negated.
        For example, a scale of 1e-9 gives key 9.
        """
        if self.unit == "":
            return self.name
        scale_exp = int(f"{self.scale:.0e}".split("e")[1])
        prefix_exp = -(scale_exp // 3) * 3
        return f"{self.name} [{_EXP_UNIT_MAP[prefix_exp]}{self.unit}]"

    def to_dict(self):
        """Return the entry and its metadata as a plain dictionary."""
        return dict(
            id=self.id,
            value=self.value,
            name=self.name,
            symbol=self.symbol,
            unit=self.unit,
            scale=self.scale,
        )

    def __str__(self):
        """Serialize the entry as a JSON string."""
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    def __eq__(self, other):
        """Compare id and value with another ParamInfo, or the value with a scalar/str.

        Any other type compares unequal.
        """
        if isinstance(other, ParamInfo):
            same_id = self.id == other.id
            same_value = self.value == other.value
            return same_id & same_value
        if not isinstance(other, (int, float, complex, str)):
            return False
        return self.value == other

    def __bool__(self):
        """An entry is truthy when its id is truthy (e.g. ParamInfo("", "") is falsy)."""
        return True if self.id else False
134
+
135
+
136
# An enriched param_dict: each key maps either to a ParamInfo or to a
# dictionary of ParamInfo objects (one nesting level in this alias).
ParamDict = dict[str, dict[str, ParamInfo] | ParamInfo]
137
+
138
+
139
def _enrich_param_dict(param_dict: dict) -> ParamDict:
    """Wrap every leaf value of ``param_dict`` in a ParamInfo.

    Nested dictionaries are processed recursively and keep their structure.
    Any non-dict value becomes ``ParamInfo(key, value)``.
    """
    return {
        key: (
            _enrich_param_dict(value)
            if isinstance(value, dict)
            else ParamInfo(key, value)
        )
        for key, value in param_dict.items()
    }
149
+
150
+
151
def read_param_dict(path: str) -> ParamDict:
    """Load a param_dict JSON file and attach metadata to each entry.

    Parameters
    ----------
    path : str
        Path to the JSON file, or to a folder containing a param_dict.json file.

    Returns
    -------
    ParamDict
        The param_dict with every leaf value wrapped with its metadata.
    """
    # A folder was given: use the param_dict.json it contains
    json_path = os.path.join(path, "param_dict.json") if os.path.isdir(path) else path
    raw = read_json(json_path)
    return _enrich_param_dict(raw)
168
+
169
+
170
def get_sweep_param(path: str, exp_id: str):
    """Return the param_dict entry of the parameter swept in experiment ``exp_id``.

    The param_dict is read with ``read_param_dict(path)``. The swept
    parameter's key is taken from ``param_dict[exp_id]["sweep"]``.

    Parameters
    ----------
    path : str
        Path to param_dict.json or to the folder containing it.
    exp_id : str
        Experiment key in param_dict, whose entry has a "sweep" field.

    Returns
    -------
    ParamInfo
        The top-level param_dict entry named by the sweep field, or
        ``ParamInfo("", "")`` when the sweep field's value is falsy
        (e.g. "" or None).

    Raises
    ------
    KeyError
        If ``exp_id`` or its "sweep" field is missing, or if the swept
        parameter is not a top-level key of param_dict.
    """
    enriched = read_param_dict(path)
    swept_id = enriched[exp_id]["sweep"].value
    if not swept_id:
        # No sweep configured for this experiment
        return ParamInfo("", "")
    return enriched[swept_id]
176
+
177
+
178
def get_measurement_id(path):
    """Return the measurement id: the first 5 characters of the last path component.

    Trailing path separators are ignored, so a folder path given as
    "data/00123-onetone/" yields the same id as "data/00123-onetone".

    Parameters
    ----------
    path : str or os.PathLike
        Path to a measurement file or folder.

    Returns
    -------
    str
        The first 5 characters of the final path component ("" for an
        empty path).
    """
    path = os.fspath(path)
    if path:
        # normpath drops trailing separators, which would otherwise make
        # basename() return "" for folder paths
        path = os.path.normpath(path)
    return os.path.basename(path)[0:5]
@@ -0,0 +1,17 @@
1
+ import inspect
2
+
3
+
4
+ def _count_function_parameters(func):
5
+ sig = inspect.signature(func)
6
+ return len(
7
+ [
8
+ param
9
+ for param in sig.parameters.values()
10
+ if param.default == inspect.Parameter.empty
11
+ and param.kind
12
+ in (
13
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
14
+ inspect.Parameter.POSITIONAL_ONLY,
15
+ )
16
+ ]
17
+ )
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.3
2
2
  Name: sqil-core
3
- Version: 0.0.1
3
+ Version: 0.1.0
4
4
  Summary: The codebase of the SQIL group in EPFL
5
5
  Author: Andrea Duina
6
6
  Requires-Python: >=3.10,<4.0
@@ -11,10 +11,11 @@ Classifier: Programming Language :: Python :: 3.12
11
11
  Classifier: Programming Language :: Python :: 3.13
12
12
  Requires-Dist: h5py (>=3.12.1,<4.0.0)
13
13
  Requires-Dist: isort (==5.9.3)
14
+ Requires-Dist: lmfit (>=1.3.2,<2.0.0)
14
15
  Requires-Dist: matplotlib (>=3.9.3,<4.0.0)
15
- Requires-Dist: numpy (>=2.2.0,<3.0.0)
16
+ Requires-Dist: numpy (>=2.2.3,<3.0.0)
16
17
  Requires-Dist: scipy (>=1.14.1,<2.0.0)
17
- Requires-Dist: seaborn (>=0.13.2,<0.14.0)
18
+ Requires-Dist: tabulate (>=0.9.0,<0.10.0)
18
19
  Description-Content-Type: text/markdown
19
20
 
20
21
  # For users
@@ -38,34 +39,44 @@ path = 'path to your data folder'
38
39
  mag, phase, freq = sqil.extract_h5_data(path, ['mag_dB', 'phase', 'ro_freq'])
39
40
  ```
40
41
 
42
+ ## Documentation
43
+ You can find the documentation for this package [here](https://sqil-epfl.github.io/sqil-core/)
44
+
41
45
  # For developers
42
46
 
43
47
  ## Development
44
48
 
45
49
  1. **Install poetry if you haven't already**
50
+
46
51
  ```bash
47
52
  $ pip install poetry
53
+ $ pip install poetry-plugin-shell
48
54
  ```
49
55
 
50
56
  2. **Install the required packages using poetry**
57
+
51
58
  ```bash
52
59
  $ poetry install
53
60
  ```
54
61
 
55
62
  3. **Install the pre-commit hooks**
56
- If you are on windows you need to install git ([https://git-scm.com/downloads](here)) and add it to your windows PATH.
57
- After the installation open a new terminal.
63
+ If you are on Windows, you need to install git ([here](https://git-scm.com/downloads)) and add it to your Windows PATH.
64
+ After the installation open a new terminal.
65
+
58
66
  ```bash
59
67
  $ poetry run pre-commit install
60
68
  ```
69
+
61
70
  This will check if your python files are formatted correctly when you try to commit.
62
71
  If that's not the case the commit will be canceled and the files will be automatically formatted.
63
72
  Then you'll have to add and commit again the new files.
64
73
 
65
74
  4. **Start the virtual environment**
75
+
66
76
  ```bash
67
77
  $ poetry shell
68
78
  ```
79
+
69
80
  To exit the virtual environment just use `exit`
70
81
 
71
82
  #### Test your changes
@@ -91,7 +102,11 @@ $ poetry run build
91
102
  $ pip install PATH_TO_SQIL_CORE_FOLDER/dist/SQIL_CORE-VERSION.whl
92
103
  ```
93
104
 
94
- If you're using a jupyter notebook remember to restart the kernel
105
+ If you're testing a new function, remember to import it in the folder's `__init__.py` file.
106
+
107
+
108
+ If you're using a Jupyter notebook, remember to restart the kernel.
109
+
95
110
 
96
111
  ## Build
97
112
 
@@ -99,6 +114,16 @@ If you're using a jupyter notebook remember to restart the kernel
99
114
  $ poetry run build
100
115
  ```
101
116
 
117
+ ## Publish
118
+
119
+ To publish version X.X.X, run the commands below. This will trigger a GitHub Action that deploys the release to PyPI (pip) and GitHub.
120
+ Remember to also update the version number in the `pyproject.toml` file before creating the tag.
121
+
122
+ ```bash
123
+ $ git tag vX.X.X
124
+ $ git push origin vX.X.X
125
+ ```
126
+
102
127
  ## Docs
103
128
 
104
129
  Serve docs
@@ -0,0 +1,19 @@
1
+ sqil_core/__init__.py,sha256=6CsaaYqp3DKoQLNiOzVix_VeHqUo0Bl3ugpnyBVs19E,193
2
+ sqil_core/config.py,sha256=x7nNdIGJh_2jU6_WuyZ_VjjwkcvibAK-Rc-k0yePUFA,166
3
+ sqil_core/fit/__init__.py,sha256=q_dqH77DsYrpJUWm9sm7T4d2k2uQsLbVLgrQaE25-Fw,320
4
+ sqil_core/fit/_core.py,sha256=izbn9Lh_5Gk3ywc9GXKbRbGEQrBniIkaatfuPMI-HmM,36003
5
+ sqil_core/fit/_fit.py,sha256=dy4Vshoy8jcuof0ROxxG4h5oiwTFy2593eEHDRRobqE,25183
6
+ sqil_core/fit/_models.py,sha256=eoIbAvtmusoUn_LN9tIe4PemRV4HXhPDkl75AxVbFRw,3034
7
+ sqil_core/resonator/__init__.py,sha256=LFq30-1r6cVNRnNYfiqlzIhLJa0yHNeZ_tHPcSJej_Q,210
8
+ sqil_core/resonator/_resonator.py,sha256=rjcUqDvglKK00M5zIqVXukNrZFMscOZZIG2YBn13BjE,29198
9
+ sqil_core/utils/__init__.py,sha256=xdefhf8GUpGPKuGEhuP3He3067nPcD3CRP-5CJrH6Gk,1363
10
+ sqil_core/utils/_analysis.py,sha256=qhipdYvxWR3P886us_gF8dBQZGzGJme8-5AT1YxBaKo,9982
11
+ sqil_core/utils/_const.py,sha256=H58XFwq7_pX6M7kapBPbeky8Ck-GGE66fzkNe2e3GW0,1217
12
+ sqil_core/utils/_formatter.py,sha256=F7JRPyU_pYymRD7G1iNVYiZS3D1ccuD9AYZWYg3SiW4,5750
13
+ sqil_core/utils/_plot.py,sha256=ABmt-e7wfqynnU2P9584EqMkmwhGqJItlgFctZQzOp4,3442
14
+ sqil_core/utils/_read.py,sha256=K0IjnUhkbCuf_wbKn_0fRP86bNGYj3dEbFXzQi9lOoU,5421
15
+ sqil_core/utils/_utils.py,sha256=BMelbGLbjBh0Mu-Cb86bkoKXcGAABVZ7UkBZs68JzUA,420
16
+ sqil_core-0.1.0.dist-info/METADATA,sha256=U0gtbLDys5_rSBQ5Y0ifCtBYMMgNd7IATqGcpjL49xM,3000
17
+ sqil_core-0.1.0.dist-info/WHEEL,sha256=7dDg4QLnNKTvwIDR9Ac8jJaAmBC_owJrckbC0jjThyA,88
18
+ sqil_core-0.1.0.dist-info/entry_points.txt,sha256=mnYKe0NKDcT5Py_8cL44TgcbLVOUS-CxmGBMEcsbGAQ,95
19
+ sqil_core-0.1.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.1
2
+ Generator: poetry-core 2.1.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,5 +1,5 @@
1
1
  [console_scripts]
2
2
  build=scripts:build
3
3
  docs-build=scripts:docs_build
4
- docs-serve=scripts:docs_serve
4
+ docs-dev=scripts:docs_dev
5
5
 
@@ -1,68 +0,0 @@
1
- import numpy as np
2
-
3
-
4
def remove_offset(data: np.ndarray, avg: int = 3) -> np.ndarray:
    """Shift data so that it starts at (about) zero.

    The offset is the mean of the first ``avg`` samples. For 1D input it is
    computed over the whole vector. For 2D input it is computed separately
    for each row, so every row is shifted by its own offset.

    Parameters
    ----------
    data : np.ndarray
        Input data, either a 1D vector or a 2D matrix.
    avg : int, optional
        Number of initial points averaged to compute the offset, by default 3.

    Returns
    -------
    np.ndarray
        ``data`` minus its offset(s), with the same shape as the input.
    """
    if data.ndim != 1:
        # One offset per row, reshaped into a column so it broadcasts across each row
        per_row = np.mean(data[:, 0:avg], axis=1)
        return data - per_row.reshape(data.shape[0], 1)
    return data - np.mean(data[0:avg])
26
-
27
-
28
def estimate_linear_background(
    x: np.ndarray, data: np.ndarray, points_cut=0.1
) -> np.ndarray:
    """Fit a straight line ``offset + slope * x`` to the start of the data.

    Only the first ``points_cut`` fraction of the points of each trace is
    used, with a minimum of 2 points so the line is determined.

    Parameters
    ----------
    x : np.ndarray
        1D independent variable, one entry per point of each trace.
    data : np.ndarray
        1D vector, or 2D matrix whose rows are fitted independently.
    points_cut : float, optional
        Fraction of the initial points used for the fit, by default 0.1.

    Returns
    -------
    np.ndarray
        Coefficients ``[offset, slope]``: shape (2,) for 1D data, or
        shape (2, n_rows) for 2D data (one column per row of ``data``).
    """
    x = np.asarray(x)
    data = np.asarray(data)
    is1D = data.ndim == 1
    points = data.shape[-1]
    # At least two points are needed to determine a line (but never more than exist)
    cut = min(points, max(int(points * points_cut), 2))

    # Design matrix [1, x] over the first `cut` points
    x_data = x[:cut]
    X = np.vstack([np.ones_like(x_data), x_data]).T
    # For 2D data take the first `cut` points of every row (columns), and
    # transpose so lstsq solves one right-hand side per row
    y_data = data[:cut] if is1D else data[:, :cut].T

    coefficients, _, _, _ = np.linalg.lstsq(X, y_data, rcond=None)
    return coefficients
44
-
45
-
46
def remove_linear_background(
    x: np.ndarray, data: np.ndarray, points_cut=0.1
) -> np.ndarray:
    """Subtract a linear background (e.g. the phase background of a
    spectroscopy) from the input data.

    The line coefficients come from ``estimate_linear_background`` using
    ``points_cut``. The line is then evaluated over the whole of ``x`` and
    subtracted from ``data``.

    Parameters
    ----------
    x : np.ndarray
        Independent variable (e.g. frequency). Presumably 1D with one entry
        per point of each trace; TODO confirm for 2D callers.
    data : np.ndarray
        Input data. Can be a 1D vector or a 2D matrix.
    points_cut : float, optional
        Fraction of the initial points used to estimate the background,
        by default 0.1.

    Returns
    -------
    np.ndarray
        The input data with the linear background removed. The shape of the
        returned array matches the input ``data``.
    """
    background_coeffs = estimate_linear_background(x, data, points_cut)

    # Design matrix with columns [1, x], evaluated over every point of x
    design = np.vstack([np.ones_like(x), x]).T
    background = (design @ background_coeffs).T
    return data - background