nbs-bl 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. nbs_bl/__init__.py +15 -0
  2. nbs_bl/beamline.py +450 -0
  3. nbs_bl/configuration.py +838 -0
  4. nbs_bl/detectors.py +89 -0
  5. nbs_bl/devices/__init__.py +12 -0
  6. nbs_bl/devices/detectors.py +154 -0
  7. nbs_bl/devices/motors.py +242 -0
  8. nbs_bl/devices/sampleholders.py +360 -0
  9. nbs_bl/devices/shutters.py +120 -0
  10. nbs_bl/devices/slits.py +51 -0
  11. nbs_bl/gGrEqns.py +171 -0
  12. nbs_bl/geometry/__init__.py +0 -0
  13. nbs_bl/geometry/affine.py +197 -0
  14. nbs_bl/geometry/bars.py +189 -0
  15. nbs_bl/geometry/frames.py +534 -0
  16. nbs_bl/geometry/linalg.py +138 -0
  17. nbs_bl/geometry/polygons.py +56 -0
  18. nbs_bl/help.py +126 -0
  19. nbs_bl/hw.py +270 -0
  20. nbs_bl/load.py +113 -0
  21. nbs_bl/motors.py +19 -0
  22. nbs_bl/planStatus.py +5 -0
  23. nbs_bl/plans/__init__.py +8 -0
  24. nbs_bl/plans/batches.py +174 -0
  25. nbs_bl/plans/conditions.py +77 -0
  26. nbs_bl/plans/flyscan_base.py +180 -0
  27. nbs_bl/plans/groups.py +55 -0
  28. nbs_bl/plans/maximizers.py +423 -0
  29. nbs_bl/plans/metaplans.py +179 -0
  30. nbs_bl/plans/plan_stubs.py +246 -0
  31. nbs_bl/plans/preprocessors.py +160 -0
  32. nbs_bl/plans/scan_base.py +58 -0
  33. nbs_bl/plans/scan_decorators.py +524 -0
  34. nbs_bl/plans/scans.py +145 -0
  35. nbs_bl/plans/suspenders.py +87 -0
  36. nbs_bl/plans/time_estimation.py +168 -0
  37. nbs_bl/plans/xas.py +123 -0
  38. nbs_bl/printing.py +221 -0
  39. nbs_bl/qt/models/beamline.py +11 -0
  40. nbs_bl/qt/models/energy.py +53 -0
  41. nbs_bl/qt/widgets/energy.py +225 -0
  42. nbs_bl/queueserver.py +249 -0
  43. nbs_bl/redisDevice.py +96 -0
  44. nbs_bl/run_engine.py +63 -0
  45. nbs_bl/samples.py +130 -0
  46. nbs_bl/settings.py +68 -0
  47. nbs_bl/shutters.py +39 -0
  48. nbs_bl/sim/__init__.py +2 -0
  49. nbs_bl/sim/config/polphase.nc +0 -0
  50. nbs_bl/sim/energy.py +403 -0
  51. nbs_bl/sim/manipulator.py +14 -0
  52. nbs_bl/sim/utils.py +36 -0
  53. nbs_bl/startup.py +27 -0
  54. nbs_bl/status.py +114 -0
  55. nbs_bl/tests/__init__.py +0 -0
  56. nbs_bl/tests/modify_regions.py +160 -0
  57. nbs_bl/tests/test_frames.py +99 -0
  58. nbs_bl/tests/test_panels.py +69 -0
  59. nbs_bl/utils.py +235 -0
  60. nbs_bl-0.2.0.dist-info/METADATA +71 -0
  61. nbs_bl-0.2.0.dist-info/RECORD +64 -0
  62. nbs_bl-0.2.0.dist-info/WHEEL +4 -0
  63. nbs_bl-0.2.0.dist-info/entry_points.txt +2 -0
  64. nbs_bl-0.2.0.dist-info/licenses/LICENSE +13 -0
nbs_bl/startup.py ADDED
@@ -0,0 +1,27 @@
1
+ from bluesky.plan_stubs import abs_set, mv
2
+ import nbs_bl
3
+ from nbs_bl.hw import *
4
+ from nbs_bl.detectors import (
5
+ list_detectors,
6
+ activate_detector,
7
+ deactivate_detector,
8
+ activate_detector_set,
9
+ )
10
+ from nbs_bl.motors import list_motors
11
+ import nbs_bl.plans.scans
12
+
13
+ from nbs_bl.run_engine import setup_run_engine, create_run_engine
14
+
15
+ from nbs_bl.plans.groups import group
16
+ from nbs_bl.plans.plan_stubs import set_exposure
17
+ from nbs_bl.queueserver import request_update, get_status
18
+ from nbs_bl.samples import list_samples
19
+ from nbs_bl.beamline import GLOBAL_BEAMLINE
20
+
21
+
22
# Announce that the nbs_bl startup script is running.
print("NBS Startup")

# NOTE(review): ``RE`` is neither imported nor defined in this file; it is
# presumably provided by ``from nbs_bl.hw import *`` or by an earlier IPython
# startup script -- confirm, otherwise this line raises NameError.
# Runs the ``set_exposure`` plan with an argument of 1.0.
RE(set_exposure(1.0))

# Restoring a previously saved configuration is currently disabled.
# load_saved_configuration()

# Activate the detector set named "default" (see nbs_bl.detectors).
activate_detector_set("default")
nbs_bl/status.py ADDED
@@ -0,0 +1,114 @@
1
+ import uuid
2
+ from abc import ABC, abstractmethod
3
+ from redis_json_dict import RedisJSONDict
4
+
5
+
6
class StatusContainerBase(ABC):
    """
    Mixin that gives a builtin container a uid refreshed by selected methods.

    Concrete subclasses combine this class with a container type (e.g.
    ``list`` or ``dict``) and declare two class attributes:

    NORMAL_METHODS
        Names of methods wrapped so that each call refreshes the uid and then
        returns whatever the underlying container method returns.
    REINIT_METHODS
        Names of methods (typically binary operators) whose result is
        converted back into an instance of the calling class.  Results that
        are ``NotImplemented`` or the container itself (in-place operators)
        are passed through unchanged.  The uid of the container the method is
        called on is refreshed as well.

    Comparing ``get_uid()`` values before and after therefore reveals whether
    a wrapped method was called in between.
    """

    # NOTE: stacking classmethod on property was deprecated in Python 3.11
    # and removed in 3.13.  Subclasses override these with plain lists, so
    # these declarations only mark the attributes as abstract.
    @classmethod
    @property
    @abstractmethod
    def NORMAL_METHODS(cls):
        raise NotImplementedError

    @classmethod
    @property
    @abstractmethod
    def REINIT_METHODS(cls):
        raise NotImplementedError

    @classmethod
    def _make_normal_method(cls, method):
        """Install on ``cls`` a wrapper for ``method`` that refreshes the uid first."""

        def _inner(self, *args, **kwargs):
            self._uid = uuid.uuid4()
            # Zero-argument super() resolves past StatusContainerBase, i.e. to
            # the container base class's implementation of ``method``.
            return getattr(super(), method)(*args, **kwargs)

        _inner.__name__ = method
        setattr(cls, method, _inner)

    @classmethod
    def _make_reinit_method(cls, method):
        """Install on ``cls`` a wrapper for ``method`` that re-wraps its result."""

        def _inner(self, *args, **kwargs):
            self._uid = uuid.uuid4()
            newitem = getattr(super(), method)(*args, **kwargs)
            # NotImplemented must reach the interpreter so that it can try the
            # reflected operator; an in-place operator that returns ``self``
            # must keep its identity so other references see later changes.
            if newitem is NotImplemented or newitem is self:
                return newitem
            return self.__class__(newitem)

        _inner.__name__ = method
        setattr(cls, method, _inner)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._uid = uuid.uuid4()
        # The wrappers are (re)installed on the concrete class every time an
        # instance is created; re-installing them is redundant but harmless.
        for method in self.NORMAL_METHODS:
            self.__class__._make_normal_method(method)

        for method in self.REINIT_METHODS:
            self.__class__._make_reinit_method(method)

    def get_uid(self):
        """Return the uid most recently assigned to this container."""
        return self._uid
50
+
51
+
52
class StatusList(StatusContainerBase, list):
    """
    ``list`` whose uid (see ``get_uid``) is refreshed by the wrapped mutators.

    In-place operators (``+=``, ``*=``) keep returning the same list object;
    ``+`` and ``*`` return a new ``StatusList``.
    """

    # Each call refreshes the uid and returns the list method's own result.
    # ``__iadd__``/``__imul__`` return ``self``, so wrapping them here (rather
    # than in REINIT_METHODS) preserves the list's identity for other holders.
    # NOTE(review): ``sort`` is not wrapped, so in-place sorts do not refresh
    # the uid.
    NORMAL_METHODS = [
        "__delitem__",
        "__setitem__",
        "__iadd__",
        "__imul__",
        "append",
        "clear",
        "extend",
        "insert",
        "pop",
        "remove",
        "reverse",
    ]
    # The (new) results of these operators are converted back into StatusList.
    REINIT_METHODS = ["__rmul__", "__add__", "__mul__"]
64
+
65
+
66
class StatusDict(StatusContainerBase, dict):
    """``dict`` whose uid (see ``get_uid``) is refreshed by the wrapped mutators."""

    # ``popitem``, ``setdefault`` and ``|=`` (``__ior__``) also mutate the
    # dict.  ``__ior__`` returns ``self``, so wrapping it here keeps the
    # dict's identity.  ``setdefault`` refreshes the uid even when the key
    # already exists.
    NORMAL_METHODS = [
        "__delitem__",
        "__setitem__",
        "__ior__",
        "clear",
        "pop",
        "popitem",
        "setdefault",
        "update",
    ]
    # The results of binary ``|`` are converted back into StatusDict.
    REINIT_METHODS = ["__or__", "__ror__"]
69
+
70
+
71
class StatusTuple(StatusContainerBase, tuple):
    """
    Immutable ``tuple`` carrying a uid.

    Concatenation and repetition results are converted back into
    ``StatusTuple`` instances.
    """

    NORMAL_METHODS = []
    REINIT_METHODS = ["__add__", "__mul__", "__rmul__"]

    def __init__(self, *args, **kwargs):
        # Unlike StatusContainerBase.__init__, this does not call
        # super().__init__; the tuple's contents are set by tuple.__new__.
        self._uid = uuid.uuid4()
        klass = self.__class__
        for name in self.NORMAL_METHODS:
            klass._make_normal_method(name)
        for name in self.REINIT_METHODS:
            klass._make_reinit_method(name)
82
+
83
+
84
class StatusSet(StatusContainerBase, set):
    """``set`` whose uid (see ``get_uid``) is refreshed by the wrapped mutators."""

    # Mutating methods.  The in-place operators return ``self``, so wrapping
    # them here (rather than re-wrapping their result) keeps the set's identity.
    NORMAL_METHODS = [
        "clear",
        "pop",
        "add",
        "remove",
        "update",
        "discard",
        "intersection_update",
        "difference_update",
        "symmetric_difference_update",
        "__iand__",
        "__ior__",
        "__isub__",
        "__ixor__",
    ]
    # Non-mutating operations whose new set is converted back into StatusSet.
    REINIT_METHODS = [
        "__and__",
        "__rand__",
        "__or__",
        "__ror__",
        "__sub__",
        "__rsub__",
        "__xor__",
        "__rxor__",
        "intersection",
        "difference",
        "union",
        "symmetric_difference",
    ]
110
+
111
+
112
class RedisStatusDict(StatusContainerBase, RedisJSONDict):
    """
    ``RedisJSONDict`` whose uid (see ``get_uid``) is refreshed whenever one of
    the methods listed in ``NORMAL_METHODS`` is called.
    """

    NORMAL_METHODS = [
        "__delitem__",
        "__setitem__",
        "clear",
        "pop",
        "update",
    ]
    # No operator results are re-wrapped for the Redis-backed dict.
    REINIT_METHODS = []
nbs_bl/tests/__init__.py ADDED
File without changes
nbs_bl/tests/modify_regions.py ADDED
@@ -0,0 +1,160 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to modify test_regions.toml by swapping the order of energy and step values
4
+ in each region array.
5
+
6
+ The original format is: [energy1, energy2, step1, energy3, step2, ...]
7
+ The new format will be: [energy1, step1, energy2, step2, energy3, ...]
8
+ """
9
+
10
+ import argparse
11
+ import re
12
+ from pathlib import Path
13
+
14
+
15
def modify_region_array(region_list):
    """
    Reorder a region array from "energies first" to alternating form.

    The input is expected to look like
    ``[energy1, energy2, step1, energy3, step2, ...]`` and is returned as
    ``[energy1, step1, energy2, step2, energy3, ...]``.  In other words, every
    ``(energy, step)`` pair after the first value is swapped.

    Parameters
    ----------
    region_list : list
        Region values in the original layout.

    Returns
    -------
    list
        The reordered values.  Lists with fewer than three values are returned
        unchanged (the same object).  A well-formed array has odd length.  If
        the length is even, the trailing unpaired value cannot be swapped and
        is appended unchanged instead of being silently dropped.
    """
    if len(region_list) < 3:
        return region_list  # Need at least 3 values to swap

    result = [region_list[0]]  # The first energy keeps its position.

    # Walk the remaining values two at a time: (energy_i, step_i) -> (step_i, energy_i).
    for i in range(1, len(region_list) - 1, 2):
        result.append(region_list[i + 1])  # step
        result.append(region_list[i])  # energy

    # For an even-length list, the loop above stops before the last value.
    if len(region_list) % 2 == 0:
        result.append(region_list[-1])

    return result
45
+
46
+
47
def modify_toml_content(content):
    """
    Rewrite every ``region = [...]`` array in a TOML document.

    Each matched array is split on commas.  Tokens containing a ``.`` become
    floats, other tokens become ints, and anything that fails to convert is
    kept as the raw string.  The values are reordered with
    ``modify_region_array`` and written back as ``region = [v1, v2, ...]``.

    Parameters
    ----------
    content : str
        Original TOML text.

    Returns
    -------
    str
        TOML text with every matched region array rewritten.
    """

    def _coerce(token):
        converter = float if "." in token else int
        try:
            return converter(token)
        except ValueError:
            return token

    def _rewrite(match):
        inner = match.group(1).strip("[]")
        tokens = (piece.strip() for piece in inner.split(","))
        parsed = [_coerce(token) for token in tokens if token]
        reordered = modify_region_array(parsed)
        body = "[" + ", ".join(map(str, reordered)) + "]"
        return f"region = {body}"

    return re.sub(r"region = \[([^\]]+)\]", _rewrite, content)
97
+
98
+
99
def parse_arguments():
    """Build the command-line interface and parse ``sys.argv`` with it."""
    cli = argparse.ArgumentParser(
        description="Modify TOML region arrays by swapping energy and step values"
    )
    options = (
        (
            "--input",
            "test_regions.toml",
            "Input TOML file path (default: test_regions.toml)",
        ),
        (
            "--output",
            "test_regions_modified.toml",
            "Output TOML file path (default: test_regions_modified.toml)",
        ),
    )
    for flag, default_path, help_text in options:
        cli.add_argument(flag, type=str, default=default_path, help=help_text)
    return cli.parse_args()
117
+
118
+
119
def main():
    """
    Rewrite the region arrays of a TOML file (see ``modify_toml_content``).

    Reads ``--input``, writes the transformed text to ``--output`` and prints
    an example transformation.  Both files are read and written as UTF-8, the
    encoding required by the TOML specification.

    Returns
    -------
    int
        0 on success, 1 if the input file is missing or processing fails.
        Errors are reported on stderr.
    """
    import sys

    args = parse_arguments()

    # File paths
    input_file = Path(args.input)
    output_file = Path(args.output)

    if not input_file.exists():
        print(f"Error: Input file {input_file} not found", file=sys.stderr)
        return 1

    try:
        content = input_file.read_text(encoding="utf-8")

        print(f"Processing {input_file}...")

        modified_content = modify_toml_content(content)

        output_file.write_text(modified_content, encoding="utf-8")

        print(f"Modified file written to {output_file}")

        # Show an example of the transformation
        print("\nExample transformation:")
        original_example = [1055, 1065, 1.0, 1070, 0.2, 1080, 0.1, 1100, 0.2, 1140, 1.0]
        print(f"Original: {original_example}")
        modified_example = modify_region_array(original_example)
        print(f"Modified: {modified_example}")

    except Exception as e:
        # Top-level boundary of the script: report the failure and signal it
        # through the exit status instead of exiting successfully.
        print(f"Error processing file: {e}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
nbs_bl/tests/test_frames.py ADDED
@@ -0,0 +1,99 @@
1
+ import pytest
2
+ import numpy as np
3
+ from ..geometry.frames import Frame
4
+ from ..geometry.linalg import vec
5
+
6
+
7
@pytest.fixture
def unit_frame():
    """Frame built from the points (1, 0, 0), (1, 0, 1) and (1, 1, 0)."""
    point_a, point_b, point_c = vec(1, 0, 0), vec(1, 0, 1), vec(1, 1, 0)
    return Frame(point_a, point_b, point_c)
14
+
15
+
16
@pytest.fixture
def unit_frame90():
    """
    Frame built from the points (1, 1, 0), (1, 1, 1) and (0, 1, 0).

    test_unit_frame_roffset expects this frame's ``r0`` to be 90.
    """
    points = [vec(1, 1, 0), vec(1, 1, 1), vec(0, 1, 0)]
    return Frame(*points)
23
+
24
+
25
@pytest.fixture
def compound_frame():
    """
    Frame defined with a parent frame.

    The parent is built from (1, 1, 0), (1, 1, 1) and (1, 2, 0).  The child is
    built from (0, 0, 0), (0, 1, 0) and (0, 0, -1), presumably expressed in
    the parent's coordinates.  test_compound_frame expects this frame to
    transform points like ``unit_frame90``.
    """
    parent = Frame(vec(1, 1, 0), vec(1, 1, 1), vec(1, 2, 0))
    return Frame(vec(0, 0, 0), vec(0, 1, 0), vec(0, 0, -1), parent=parent)
37
+
38
+ # TODO: add a compound-frame test for r0 once it is decided what r0 should mean for a frame that has a parent.
39
+
40
+
41
@pytest.mark.parametrize("xoffset", [-1, 0, 1])
@pytest.mark.parametrize("yoffset", [-1, 0, 1])
@pytest.mark.parametrize("zoffset", [-1, 0, 1])
def test_translated_frame(unit_frame, xoffset, yoffset, zoffset):
    """
    A child frame whose defining points are shifted by ``offset`` should map
    its own origin to the same global point that the parent maps ``offset``
    to (both evaluated at r=90).
    """
    offset = vec(xoffset, yoffset, zoffset)
    child = Frame(
        vec(0, 0, 0) + offset,
        vec(0, 1, 0) + offset,
        vec(1, 0, 0) + offset,
        parent=unit_frame,
    )

    expected = unit_frame.frame_to_global(offset, r=90, rotation="global")
    actual = child.frame_to_global(vec(0, 0, 0), r=90)
    assert np.all(np.isclose(actual, expected))
54
+
55
+
56
def test_unit_frame_roffset(unit_frame, unit_frame90):
    """The fixtures' rotation offsets ``r0`` should be 0 and 90 respectively."""
    for frame, expected_r0 in ((unit_frame, 0), (unit_frame90, 90)):
        assert frame.r0 == expected_r0
59
+
60
+
61
def test_frame_to_global(unit_frame):
    """
    Check ``frame_to_global`` on ``unit_frame``.

    Covers the frame origin, the frame point (1, 0, 0), and the origin at r=90.
    All comparisons use ``np.isclose`` (as the r=90 case already did), so the
    test does not depend on exact floating-point equality.
    """
    v_f = vec(0, 0, 0)
    v_g = vec(1, 0, 0)
    assert np.all(np.isclose(unit_frame.frame_to_global(v_f, rotation="global"), v_g))

    v_f1 = vec(1, 0, 0)
    v_g1 = vec(1, 1, 0)
    assert np.all(np.isclose(unit_frame.frame_to_global(v_f1, rotation="global"), v_g1))

    v_g2 = vec(0, 1, 0)
    assert np.all(
        np.isclose(unit_frame.frame_to_global(v_f, r=90, rotation="frame"), v_g2)
    )
74
+
75
+
76
def test_compound_frame(unit_frame90, compound_frame):
    """
    ``compound_frame`` (defined with a parent) should transform points exactly
    like the single frame ``unit_frame90``.

    Compares ``frame_to_global`` (with and without a rotation ``r``) and
    ``global_to_frame`` (with and without ``manip``).  Random inputs come from
    a fixed-seed generator, so any failure is reproducible.
    """
    rng = np.random.default_rng(0)

    v1 = vec(0, 0, 0)
    vu1 = unit_frame90.frame_to_global(v1, rotation='global')
    vc1 = compound_frame.frame_to_global(v1, rotation='global')
    assert np.all(np.isclose(vu1, vc1))

    v2 = vec(*rng.random(3))
    vu2 = unit_frame90.frame_to_global(v2, rotation='global')
    vc2 = compound_frame.frame_to_global(v2, rotation='global')
    assert np.all(np.isclose(vu2, vc2))

    r = 90 * rng.random()
    vu3 = unit_frame90.frame_to_global(v2, r=r, rotation='frame')
    vc3 = compound_frame.frame_to_global(v2, r=r, rotation='frame')
    assert np.all(np.isclose(vu3, vc3))

    vu4 = unit_frame90.global_to_frame(v2, r=r)
    vc4 = compound_frame.global_to_frame(v2, r=r)
    assert np.all(np.isclose(vu4, vc4))

    manip = vec(*rng.random(3))
    vu5 = unit_frame90.global_to_frame(v2, manip=manip, r=r)
    vc5 = compound_frame.global_to_frame(v2, manip=manip, r=r)
    assert np.all(np.isclose(vu5, vc5))
nbs_bl/tests/test_panels.py ADDED
@@ -0,0 +1,69 @@
1
+ import pytest
2
+ import numpy as np
3
+ from sst_base.frames import Panel
4
+ from sst_base.linalg import vec
5
+
6
+
7
@pytest.fixture()
def unit_panel():
    """
    Panel of width 2 and height 10 defined by the points (1, 0, 0),
    (1, 0, 1) and (-1, 0, 0).
    """
    points = (vec(1, 0, 0), vec(1, 0, 1), vec(-1, 0, 0))
    return Panel(*points, width=2, height=10)
16
+
17
+
18
@pytest.fixture()
def rotated_panel():
    """
    Panel of width 2 and height 10 defined by the points (0, 0, 0),
    (0, 0, 1) and (0, 1, 0).
    """
    points = (vec(0, 0, 0), vec(0, 0, 1), vec(0, 1, 0))
    return Panel(*points, width=2, height=10)
27
+
28
+
29
def test_panel_position(unit_panel):
    """
    Check ``beam_to_frame`` on ``unit_panel`` for three beam positions.

    Uses ``np.allclose`` rather than exact ``==``, because the coordinates are
    presumably produced by floating-point transforms; tiny rounding errors
    should not fail the test.
    """
    cases = [
        ((0, 0, 0, 0), [1, 0, 0, 90]),
        ((0, 0, 1, 0), [1, -1, 0, 90]),
        ((1, 0, 1, 0), [2, -1, 0, 90]),
    ]
    for beam, expected in cases:
        coords = np.array(unit_panel.beam_to_frame(*beam))
        assert np.allclose(coords, expected), (beam, coords)
36
+
37
+
38
def test_panel_distance_sign_convention(unit_panel):
    """
    ``distance_to_beam`` is signed.

    It is zero on the panel edge, positive with the beam outside the panel,
    and negative with the beam inside it.
    """
    assert np.isclose(unit_panel.distance_to_beam(0, 0, 0, 0), 0)

    outside = unit_panel.distance_to_beam(0, 0, 1, 0)
    assert outside > 0

    inside = unit_panel.distance_to_beam(0, 0, -1, 0)
    assert inside < 0
46
+
47
+
48
def test_panel_distance_from_edge(unit_panel):
    """
    Signed distance as the beam moves along the third ``distance_to_beam``
    coordinate.

    The distance grows linearly outside the panel and levels off at -1 once
    the beam is well inside it.
    """
    expectations = (
        (0, 0),
        (1, 1),
        (-0.5, -0.5),
        (-1, -1),
        (-2, -1),
        (-3, -1),
    )
    for coord, expected in expectations:
        assert np.isclose(unit_panel.distance_to_beam(0, 0, coord, 0), expected)
55
+
56
+
57
def test_panel_distance_from_corner(unit_panel):
    """
    Signed distance near a corner of ``unit_panel``.

    Varies the first and third ``distance_to_beam`` coordinates; the diagonal
    case is expected to be sqrt(2).
    """
    expectations = (
        ((-1, 0, 0, 0), 0),
        ((-2, 0, -1, 0), 1),
        ((-2, 0, 0, 0), 1),
        ((-2, 0, 1, 0), np.sqrt(2)),
        ((-1, 0, 1, 0), 1),
    )
    for beam, expected in expectations:
        assert np.isclose(unit_panel.distance_to_beam(*beam), expected)
63
+
64
+
65
def test_edge_on_panel_distance(rotated_panel):
    """Signed distance from ``rotated_panel`` for a few beam positions."""
    expectations = (
        ((0, 0, 0, 0), 0),
        ((0, 0, -1, 0), 0),
        ((0, 0, 1, 0), 1),
        ((1, 0, 1, 0), np.sqrt(2)),
    )
    for beam, expected in expectations:
        assert np.isclose(rotated_panel.distance_to_beam(*beam), expected)
nbs_bl/utils.py ADDED
@@ -0,0 +1,235 @@
1
+ import collections
2
+ import inspect
3
+ from functools import update_wrapper
4
+ from numpydoc.docscrape import NumpyDocString, Parameter
5
+ import copy
6
+
7
+
8
def iterfy(x):
    """
    Return ``x`` itself if it is a non-string iterable, otherwise ``[x]``.

    Strings and bytes are iterable but are treated as single items, so
    ``iterfy("abc")`` gives ``["abc"]`` rather than the individual characters.

    Parameters
    ----------
    x : Any
        Value to normalize.

    Returns
    -------
    Iterable
        ``x`` unchanged, or a one-element list containing it.
    """
    is_text = isinstance(x, (str, bytes))
    if is_text or not isinstance(x, collections.abc.Iterable):
        return [x]
    return x
27
+
28
+
29
def adjust_signature(*omit_args):
    """
    Decorator factory that hides named parameters from a function's signature.

    Only ``func.__signature__`` (what ``inspect.signature`` reports) is
    changed.  The decorated function object itself is returned rather than a
    wrapper, so its calling behavior is untouched.

    Parameters
    ----------
    *omit_args : str
        Parameter names to drop from the reported signature.

    Returns
    -------
    function
        A decorator that sets ``__signature__`` on the function it receives
        and returns that same function.

    Example
    -------
    @adjust_signature('arg_to_omit')
    def func(arg_to_keep, arg_to_omit):
        pass
    """

    def decorator(func):
        original = inspect.signature(func)
        kept = [
            param
            for param in original.parameters.values()
            if param.name not in omit_args
        ]
        func.__signature__ = original.replace(parameters=kept)
        return func

    return decorator
59
+
60
+
61
def merge_signatures(func):
    """
    Decorator factory that gives a wrapper a signature combining its own
    parameters with those of ``func``.

    The merged signature consists of three parts, in order:

    1. the wrapper's parameters, except its ``**kwargs``;
    2. ``func``'s parameters after the first one, excluding any name the
       wrapper already declares;
    3. the wrapper's ``**kwargs``, if any.

    ``func``'s own ``**kwargs`` is dropped when the wrapper already has one,
    since a signature cannot contain two.  ``functools.update_wrapper`` is
    also applied, so the wrapper takes ``func``'s name, docstring and
    ``__wrapped__``.

    Parameters
    ----------
    func : callable
        The function whose parameters (all but the first) are merged in.
        NOTE(review): the first parameter is always skipped, presumably
        because it is ``self`` or is supplied by the wrapper -- confirm.

    Returns
    -------
    callable
        A decorator that updates the wrapper in place and returns it.
    """

    def decorator(wrapper):
        # Read both signatures before update_wrapper sets wrapper.__wrapped__;
        # after that, inspect.signature(wrapper) would report func's signature.
        sig_func = inspect.signature(func)
        sig_wrapper = inspect.signature(wrapper)

        func_params = list(sig_func.parameters.values())[1:]
        wrapper_params = [
            param
            for param in sig_wrapper.parameters.values()
            if param.kind != param.VAR_KEYWORD
        ]
        wrapper_var_keyword_params = [
            param
            for param in sig_wrapper.parameters.values()
            if param.kind == param.VAR_KEYWORD
        ]

        extra_params = [
            param
            for param in func_params
            if param.name not in sig_wrapper.parameters
            and not (wrapper_var_keyword_params and param.kind == param.VAR_KEYWORD)
        ]
        new_sig = sig_wrapper.replace(
            parameters=wrapper_params + extra_params + wrapper_var_keyword_params
        )

        # update_wrapper copies func.__dict__ into the wrapper.  That would
        # overwrite the merged signature if func carries its own
        # __signature__ (e.g. from adjust_signature), so assign the merged
        # signature afterwards.
        update_wrapper(wrapper, func)
        wrapper.__signature__ = new_sig

        return wrapper

    return decorator
104
+
105
+
106
def merge_docstrings(doc1, doc2, omit_params=None, param_order=None):
    """
    Merge the "Parameters" sections of two numpy-style docstrings.

    Every section other than "Parameters" is taken from ``doc1``.  When a
    parameter is documented in both docstrings, the entry from ``doc2`` wins,
    but it keeps the position it had in ``doc1``.

    Parameters
    ----------
    doc1 : str or None
        The first docstring; ``None`` is treated as empty.
    doc2 : str or None
        The second docstring; ``None`` is treated as empty.
    omit_params : list of str, optional
        The names of parameters to omit from the merged docstring.
    param_order : list of str, optional
        Desired parameter order.  Parameters not listed are placed after the
        listed ones, sorted by name.  If not given, the merged order is kept,
        except that names containing ``"**"`` are moved to the end.

    Returns
    -------
    str
        The merged docstring.
    """
    if omit_params is None:
        omit_params = ()

    # Parse the docstrings (None is treated as an empty docstring).
    parsed_doc1 = NumpyDocString(doc1 if doc1 is not None else "")
    parsed_doc2 = NumpyDocString(doc2 if doc2 is not None else "")

    # Merge the parameters: doc1's entries first, then new ones from doc2;
    # dict.update lets doc2's description override on name clashes.
    params1 = {name: (typ, desc) for name, typ, desc in parsed_doc1["Parameters"]}
    params2 = {name: (typ, desc) for name, typ, desc in parsed_doc2["Parameters"]}
    params1.update(params2)
    merged_params = [
        Parameter(name, *params1[name]) for name in params1 if name not in omit_params
    ]

    if param_order:
        # First position of each name, so every sort-key lookup is O(1).
        position = {}
        for index, name in enumerate(param_order):
            position.setdefault(name, index)
        unlisted = len(param_order)
        merged_params.sort(
            key=lambda param: (position.get(param.name, unlisted), param.name)
        )
    else:
        # Stable sort: only moves "**kwargs"-style entries to the end.
        merged_params.sort(key=lambda param: "**" in param.name)

    # Copy the first parsed docstring and replace its 'Parameters' section.
    merged_doc = copy.deepcopy(parsed_doc1)
    merged_doc["Parameters"] = merged_params

    return str(merged_doc)
163
+
164
+
165
def sort_params(params):
    """
    Sort ``inspect.Parameter`` objects in place into a valid signature order.

    Parameters are grouped by kind (positional-only, positional-or-keyword,
    ``*args``, keyword-only, ``**kwargs``).  Within a kind, parameters without
    a default come before those with one.  Otherwise the original relative
    order is kept, because the sort is stable.

    Parameters
    ----------
    params : list of inspect.Parameter
        The list to sort in place.
    """
    # Compare the default against the sentinel by identity.  ``!=`` would call
    # the default value's own ``__ne__``, which can raise or return a
    # non-boolean (e.g. a numpy array default).
    params.sort(
        key=lambda param: (param.kind, param.default is not inspect.Parameter.empty)
    )
168
+
169
+
170
def merge_func(
    func, omit_params=None, exclude_wrapper_args=True, exclude_wrapper_kwargs=True, use_func_name=True
):
    """
    A decorator factory that merges the docstring and signature of ``func``
    into the wrapper function.

    Parameters
    ----------
    func : callable
        The function being wrapped.  Its parameters and documented parameters
        are merged into the wrapper's.
    omit_params : list of str, optional
        The names of parameters to omit from the final function signature and docstring.
    exclude_wrapper_args : bool, optional
        If True (default), drop the wrapper's ``*args`` from the merged signature.
    exclude_wrapper_kwargs : bool, optional
        If True (default), drop the wrapper's ``**kwargs`` from the merged signature.
    use_func_name : bool, optional
        If True (default), set the wrapper's ``__name__`` to ``func.__name__``.

    Returns
    -------
    callable
        A decorator that sets ``__doc__`` and ``__signature__`` (and optionally
        ``__name__``) on the wrapper and returns the wrapper itself.

    Notes
    -----
    ``func``'s parameters come first in the merged list.  Names the wrapper
    also declares, and a leading ``self``, are removed from them.  The
    combined list is then reordered by ``sort_params``, and the docstrings
    are merged with ``merge_docstrings`` using that final order.
    """
    if omit_params is None:
        omit_params = []

    def decorator(wrapper):
        sig_wrapper = inspect.signature(wrapper)
        sig_func = inspect.signature(func)

        # Parameters of the wrapper, optionally without its *args / **kwargs.
        params_wrapper = list(sig_wrapper.parameters.values())
        if exclude_wrapper_args:
            params_wrapper = [
                param for param in params_wrapper if param.kind != param.VAR_POSITIONAL
            ]
        if exclude_wrapper_kwargs:
            params_wrapper = [
                param for param in params_wrapper if param.kind != param.VAR_KEYWORD
            ]

        # Parameters of func that the wrapper does not already declare.
        # The wrapper's names are computed once rather than once per parameter.
        wrapper_names = {param.name for param in params_wrapper}
        params_func = [
            param
            for param in sig_func.parameters.values()
            if param.name not in wrapper_names
        ]
        if params_func and params_func[0].name == "self":
            params_func = params_func[1:]

        combined_params = [
            param
            for param in params_func + params_wrapper
            if param.name not in omit_params
        ]
        sort_params(combined_params)
        param_order = [param.name for param in combined_params]
        wrapper.__doc__ = merge_docstrings(
            wrapper.__doc__, func.__doc__, omit_params, param_order
        )

        # Expose the merged parameters through inspect.signature(wrapper).
        wrapper.__signature__ = sig_wrapper.replace(parameters=combined_params)
        if use_func_name:
            wrapper.__name__ = func.__name__

        return wrapper

    return decorator
+ return decorator