xcoll 0.3.6__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. xcoll/__init__.py +12 -4
  2. xcoll/beam_elements/__init__.py +7 -5
  3. xcoll/beam_elements/absorber.py +41 -7
  4. xcoll/beam_elements/base.py +1161 -244
  5. xcoll/beam_elements/collimators_src/black_absorber.h +118 -0
  6. xcoll/beam_elements/collimators_src/black_crystal.h +111 -0
  7. xcoll/beam_elements/collimators_src/everest_block.h +40 -28
  8. xcoll/beam_elements/collimators_src/everest_collimator.h +129 -50
  9. xcoll/beam_elements/collimators_src/everest_crystal.h +217 -73
  10. xcoll/beam_elements/everest.py +60 -113
  11. xcoll/colldb.py +250 -750
  12. xcoll/general.py +2 -2
  13. xcoll/headers/checks.h +1 -1
  14. xcoll/headers/particle_states.h +2 -2
  15. xcoll/initial_distribution.py +195 -0
  16. xcoll/install.py +177 -0
  17. xcoll/interaction_record/__init__.py +1 -0
  18. xcoll/interaction_record/interaction_record.py +252 -0
  19. xcoll/interaction_record/interaction_record_src/interaction_record.h +98 -0
  20. xcoll/{impacts → interaction_record}/interaction_types.py +11 -4
  21. xcoll/line_tools.py +83 -0
  22. xcoll/lossmap.py +209 -0
  23. xcoll/manager.py +2 -937
  24. xcoll/rf_sweep.py +1 -1
  25. xcoll/scattering_routines/everest/amorphous.h +239 -0
  26. xcoll/scattering_routines/everest/channeling.h +245 -0
  27. xcoll/scattering_routines/everest/crystal_parameters.h +137 -0
  28. xcoll/scattering_routines/everest/everest.h +8 -30
  29. xcoll/scattering_routines/everest/everest.py +13 -10
  30. xcoll/scattering_routines/everest/jaw.h +27 -197
  31. xcoll/scattering_routines/everest/materials.py +2 -0
  32. xcoll/scattering_routines/everest/multiple_coulomb_scattering.h +31 -10
  33. xcoll/scattering_routines/everest/nuclear_interaction.h +86 -0
  34. xcoll/scattering_routines/geometry/__init__.py +6 -0
  35. xcoll/scattering_routines/geometry/collimator_geometry.h +219 -0
  36. xcoll/scattering_routines/geometry/crystal_geometry.h +150 -0
  37. xcoll/scattering_routines/geometry/geometry.py +26 -0
  38. xcoll/scattering_routines/geometry/get_s.h +92 -0
  39. xcoll/scattering_routines/geometry/methods.h +111 -0
  40. xcoll/scattering_routines/geometry/objects.h +154 -0
  41. xcoll/scattering_routines/geometry/rotation.h +23 -0
  42. xcoll/scattering_routines/geometry/segments.h +226 -0
  43. xcoll/scattering_routines/geometry/sort.h +184 -0
  44. {xcoll-0.3.6.dist-info → xcoll-0.4.0.dist-info}/METADATA +1 -1
  45. {xcoll-0.3.6.dist-info → xcoll-0.4.0.dist-info}/RECORD +48 -33
  46. xcoll/beam_elements/collimators_src/absorber.h +0 -141
  47. xcoll/collimator_settings.py +0 -457
  48. xcoll/impacts/__init__.py +0 -1
  49. xcoll/impacts/impacts.py +0 -102
  50. xcoll/impacts/impacts_src/impacts.h +0 -99
  51. xcoll/scattering_routines/everest/crystal.h +0 -1302
  52. xcoll/scattering_routines/everest/scatter.h +0 -169
  53. xcoll/scattering_routines/everest/scatter_crystal.h +0 -260
  54. {xcoll-0.3.6.dist-info → xcoll-0.4.0.dist-info}/LICENSE +0 -0
  55. {xcoll-0.3.6.dist-info → xcoll-0.4.0.dist-info}/NOTICE +0 -0
  56. {xcoll-0.3.6.dist-info → xcoll-0.4.0.dist-info}/WHEEL +0 -0
xcoll/general.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # copyright ############################### #
2
2
  # This file is part of the Xcoll Package. #
3
- # Copyright (c) CERN, 2023. #
3
+ # Copyright (c) CERN, 2024. #
4
4
  # ######################################### #
5
5
 
6
6
  from pathlib import Path
@@ -12,5 +12,5 @@ citation = "F.F. Van der Veken, et al.: Recent Developments with the New Tools f
12
12
  # ===================
13
13
  # Do not change
14
14
  # ===================
15
- __version__ = '0.3.6'
15
+ __version__ = '0.4.0'
16
16
  # ===================
xcoll/headers/checks.h CHANGED
@@ -31,5 +31,5 @@ int8_t xcoll_check_particle_init(RandomRutherfordData rng, LocalParticle* part)
31
31
  }
32
32
  return is_tracking*rng_is_set*ruth_is_set;
33
33
  }
34
-
34
+
35
35
  #endif /* XCOLL_CHECKS_H */
@@ -1,6 +1,6 @@
1
1
  // copyright ############################### #
2
2
  // This file is part of the Xcoll Package. #
3
- // Copyright (c) CERN, 2023. #
3
+ // Copyright (c) CERN, 2024. #
4
4
  // ######################################### #
5
5
 
6
6
  #ifndef XCOLL_STATES_H
@@ -21,5 +21,5 @@
21
21
  #define XC_ERR_NOT_IMPLEMENTED -391
22
22
  #define XC_ERR_INVALID_XOFIELD -392
23
23
  #define XC_ERR -399
24
-
24
+
25
25
  #endif /* XCOLL_STATES_H */
@@ -0,0 +1,195 @@
1
+ # copyright ############################### #
2
+ # This file is part of the Xcoll Package. #
3
+ # Copyright (c) CERN, 2024. #
4
+ # ######################################### #
5
+
6
+ import numpy as np
7
+
8
+ import xtrack as xt
9
+ import xobjects as xo
10
+ import xpart as xp
11
+
12
+ from .beam_elements import collimator_classes
13
+
14
+
15
def generate_pencil_on_collimator(line, name, num_particles, *, side='+-', pencil_spread=1e-6,
                                  impact_parameter=0, sigma_z=7.61e-2, tw=None, longitudinal=None,
                                  longitudinal_betatron_cut=None):
    """
    Generate a pencil beam on a collimator.

    Parameters
    ----------
    line
        Line containing the collimator; its tracker must already be built.
    name
        Name of the collimator element in ``line``.
    num_particles
        Number of particles to generate (cast to ``int``).
    side
        ``'+'``, ``'-'`` or ``'+-'``. Overridden by ``'+'`` / ``'-'`` when the
        collimator's ``side`` attribute is ``'left'`` / ``'right'``. With
        ``'+-'`` the particles are split over both jaws.
    pencil_spread
        Spread of the pencil; divided by the beam sigma in the collimator
        plane before being handed to the pencil generator.
    impact_parameter
        Forwarded to the per-jaw generator.
    sigma_z
        Bunch length, only used when ``longitudinal == 'bucket'``.
    tw
        Twiss table; computed with ``line.twiss()`` when ``None``.
    longitudinal
        ``None`` (zeta = delta = 0), ``'bucket'`` (gaussian longitudinal
        distribution), a dict with keys ``'zeta'`` and ``'delta'``, or a
        two-element sequence ``(zeta, delta)``. ``'matched_dispersion'`` is
        not implemented.
    longitudinal_betatron_cut
        Currently unused (reserved for ``'matched_dispersion'``).

    Returns
    -------
    The particles object built by ``xp.build_particles``.

    Raises
    ------
    Exception
        If the tracker is not built or the collimator has no optics assigned.
    ValueError
        If ``name`` is not a collimator or ``longitudinal`` is malformed.
    NotImplementedError
        For skew collimators and for ``longitudinal='matched_dispersion'``.
    """

    if not line._has_valid_tracker():
        raise Exception("Please build tracker before generating pencil distribution!")

    coll = line[name]

    if not isinstance(coll, tuple(collimator_classes)):
        raise ValueError("Need to provide a valid collimator!")

    if coll.optics is None:
        raise Exception("Need to assign optics to collimators before generating pencil distribution!")

    num_particles = int(num_particles)

    # A one-sided collimator dictates the side, whatever the caller asked for
    if coll.side == 'left':
        side = '+'
    elif coll.side == 'right':
        side = '-'

    # Define the plane
    angle = coll.angle
    if abs(np.mod(angle-90, 180)-90) < 1e-6:
        plane = 'x'
    elif abs(np.mod(angle, 180)-90) < 1e-6:
        plane = 'y'
    else:
        raise NotImplementedError("Pencil beam on a skew collimator not yet supported!")

    if tw is None:
        tw = line.twiss()    # TODO: can we do this smarter by caching?

    # Is it converging or diverging?    # TODO: This might change with a tilt!!!!!!
    s_front = line.get_s_position(name)
    s_back = s_front + coll.length
    is_converging = tw[f'alf{plane}', name] > 0
    print(f"Collimator {name} is {'con' if is_converging else 'di'}verging.")

    beam_sizes = tw.get_beam_covariance(nemitt_x=coll.nemitt_x, nemitt_y=coll.nemitt_y)
    sigmas = beam_sizes.rows[name:f'{name}%%1'][f'sigma_{plane}']
    if is_converging:
        # pencil at front of jaw
        match_at_s = s_front
        sigma = sigmas[0]
    else:
        # pencil at back of jaw
        match_at_s = s_back
        sigma = sigmas[1]
    dr_sigmas = pencil_spread/sigma

    # Generate 4D coordinates
    # TODO: there is some looping in the calculation here and in xpart. Can it be improved?
    if side == '+-':
        num_plus = int(num_particles/2)
        num_min = int(num_particles - num_plus)
        coords_plus = _generate_4D_pencil_one_jaw(line, name, num_plus, plane, '+', impact_parameter,
                                                  dr_sigmas, match_at_s, is_converging)
        coords_min = _generate_4D_pencil_one_jaw(line, name, num_min, plane, '-', impact_parameter,
                                                 dr_sigmas, match_at_s, is_converging)
        # Concatenate the two jaws coordinate by coordinate
        coords = [[*c_plus, *c_min] for c_plus, c_min in zip(coords_plus, coords_min)]
    else:
        coords = _generate_4D_pencil_one_jaw(line, name, num_particles, plane, side, impact_parameter,
                                             dr_sigmas, match_at_s, is_converging)
    pencil = coords[0]
    p_pencil = coords[1]
    transverse_norm = coords[2]
    p_transverse_norm = coords[3]

    # Longitudinal plane
    # TODO: make this more general, make this better
    if longitudinal is None:
        delta = 0
        zeta = 0
    elif longitudinal == 'matched_dispersion':
        # Sketch of the intended implementation:
        # if longitudinal_betatron_cut is None:
        #     cut = 0
        # else:
        #     cut = np.random.uniform(-longitudinal_betatron_cut, longitudinal_betatron_cut,
        #                             num_particles)
        # delta = generate_delta_from_dispersion(line, name, plane=plane, position_mm=pencil,
        #                                        nemitt_x=nemitt_x, nemitt_y=nemitt_y, twiss=tw,
        #                                        betatron_cut=cut, match_at_front=is_converging)
        # zeta = 0
        raise NotImplementedError
    elif longitudinal == 'bucket':
        zeta, delta = xp.generate_longitudinal_coordinates(
                num_particles=num_particles, distribution='gaussian', sigma_z=sigma_z, line=line
        )
    elif isinstance(longitudinal, str) or not hasattr(longitudinal, '__iter__'):
        raise ValueError("`longitudinal` must be None, 'bucket', 'matched_dispersion', a dict with "
                         f"keys 'zeta' and 'delta', or a (zeta, delta) pair; got {longitudinal!r}.")
    elif len(longitudinal) != 2:
        raise ValueError("`longitudinal` must contain exactly two entries (zeta and delta); "
                         f"got {len(longitudinal)}.")
    elif isinstance(longitudinal, dict):
        zeta = longitudinal['zeta']
        delta = longitudinal['delta']
    else:
        zeta = longitudinal[0]
        delta = longitudinal[1]

    # Build the particles: physical coordinates in the collimator plane,
    # normalised coordinates in the other plane
    if plane == 'x':
        part = xp.build_particles(
                x=pencil, px=p_pencil, y_norm=transverse_norm, py_norm=p_transverse_norm,
                zeta=zeta, delta=delta, nemitt_x=coll.nemitt_x, nemitt_y=coll.nemitt_y,
                line=line, at_element=name, match_at_s=match_at_s,
                _context=coll._buffer.context
        )
    else:
        part = xp.build_particles(
                x_norm=transverse_norm, px_norm=p_transverse_norm, y=pencil, py=p_pencil,
                zeta=zeta, delta=delta, nemitt_x=coll.nemitt_x, nemitt_y=coll.nemitt_y,
                line=line, at_element=name, match_at_s=match_at_s,
                _context=coll._buffer.context
        )

    part._init_random_number_generator()

    return part
140
+
141
+
142
def generate_delta_from_dispersion(line, at_element, *, plane, position_mm, nemitt_x, nemitt_y,
                                   twiss=None, betatron_cut=0, match_at_front=True):
    """
    Compute the delta that places a particle at a given transverse position
    through dispersion.

    The result is
    ``(position_mm - betatron_cut*sigma - closed_orbit) / dispersion``,
    with all optics quantities taken at ``at_element`` in the given ``plane``.

    Parameters
    ----------
    line
        Line with a built tracker.
    at_element
        Name of the element at which the optics are evaluated.
    plane
        ``'x'`` or ``'y'``.
    position_mm
        Scalar or sequence of transverse positions.
        NOTE(review): the name suggests millimetres, but no unit conversion
        is done here -- confirm the expected unit against callers.
    nemitt_x, nemitt_y
        Normalised emittances passed to ``twiss.get_beam_covariance``.
    twiss
        Twiss table; computed with ``line.twiss()`` when ``None``.
    betatron_cut
        Scalar or sequence, in units of beam sigma.
    match_at_front
        Use the first (``True``) or second (``False``) sigma of the rows
        selected for ``at_element``.

    Raises
    ------
    ValueError
        If the tracker is not built, if ``plane`` is invalid, or if
        ``position_mm`` and ``betatron_cut`` are sequences of different length.
    """
    if line.tracker is None:
        raise ValueError("Need to build tracker first!")

    pos_is_seq = hasattr(position_mm, '__iter__')
    cut_is_seq = hasattr(betatron_cut, '__iter__')
    if pos_is_seq and cut_is_seq and len(position_mm) != len(betatron_cut):
        raise ValueError("The variables 'position_mm' and 'betatron_cut' need to have the same "
                         f"length (got {len(position_mm)} and {len(betatron_cut)})!")
    # Convert to float arrays and rely on broadcasting for the scalar/sequence
    # combinations. np.full_like would inherit the dtype of the other argument
    # and silently truncate e.g. a cut of 0.5 to 0 for integer positions.
    position_mm = np.asarray(position_mm, dtype=float)
    betatron_cut = np.asarray(betatron_cut, dtype=float)

    if plane not in ['x', 'y']:
        raise ValueError("The variable 'plane' needs to be either 'x' or 'y'!")

    if twiss is None:
        twiss = line.twiss()

    beam_sizes = twiss.get_beam_covariance(nemitt_x=nemitt_x, nemitt_y=nemitt_y)
    beam_sizes = beam_sizes.rows[at_element:f'{at_element}%%1'][f'sigma_{plane}']
    sigma = beam_sizes[0] if match_at_front else beam_sizes[1]
    delta = (position_mm - betatron_cut*sigma - twiss.rows[at_element][plane])
    delta /= twiss.rows[at_element][f'd{plane}']

    return delta
166
+
167
+
168
def _generate_4D_pencil_one_jaw(line, name, num_particles, plane, side, impact_parameter,
                                dr_sigmas, match_at_s, is_converging):
    """
    Generate the transverse coordinates of a pencil beam on a single jaw.

    The pencil is placed at the upstream jaw corner (``jaw_LU`` / ``jaw_RU``)
    when ``is_converging`` is true and at the downstream corner (``jaw_LD`` /
    ``jaw_RD``) otherwise. ``impact_parameter`` is added for ``side='+'`` and
    subtracted for ``side='-'``.

    Returns
    -------
    list
        ``[pencil, p_pencil, transverse_norm, p_transverse_norm]``: position
        and momentum in the collimator plane as returned by
        ``xp.generate_2D_pencil_with_absolute_cut``, followed by
        standard-normal samples for the other plane.

    Raises
    ------
    ValueError
        If ``side`` is neither ``'+'`` nor ``'-'``.
    """
    coll = line[name]

    if side == '+':
        if is_converging:
            pencil_pos = coll.jaw_LU + impact_parameter
        else:
            pencil_pos = coll.jaw_LD + impact_parameter
    elif side == '-':
        if is_converging:
            pencil_pos = coll.jaw_RU - impact_parameter
        else:
            pencil_pos = coll.jaw_RD - impact_parameter
    else:
        # Without this guard an invalid side would fail further down with an
        # obscure UnboundLocalError on pencil_pos.
        raise ValueError(f"The variable 'side' needs to be either '+' or '-', got {side!r}!")

    # Collimator plane: generate pencil distribution
    pencil, p_pencil = xp.generate_2D_pencil_with_absolute_cut(
                        num_particles, plane=plane, absolute_cut=pencil_pos, line=line,
                        dr_sigmas=dr_sigmas, nemitt_x=coll.nemitt_x, nemitt_y=coll.nemitt_y,
                        at_element=name, side=side, match_at_s=match_at_s
                       )

    # Other plane: generate gaussian distribution in normalized coordinates
    transverse_norm = np.random.normal(size=num_particles)
    p_transverse_norm = np.random.normal(size=num_particles)

    return [pencil, p_pencil, transverse_norm, p_transverse_norm]
195
+
xcoll/install.py ADDED
@@ -0,0 +1,177 @@
1
+ # copyright ############################### #
2
+ # This file is part of the Xcoll Package. #
3
+ # Copyright (c) CERN, 2024. #
4
+ # ######################################### #
5
+
6
+ import numpy as np
7
+ import xtrack as xt
8
+
9
+ from .beam_elements import element_classes
10
+
11
def install_elements(line, names, elements, *, at_s=None, apertures=None, need_apertures=False, s_tol=1.e-6):
    """
    Install one or more Xcoll elements in a line.

    Parameters
    ----------
    line
        Line to modify; its tracker must not be built yet.
    names
        Name or list of names under which the elements are installed.
    elements
        Element or list of elements (instances of ``element_classes``).
    at_s
        Start position(s) of the elements. When ``None``, the position is
        derived with ``_get_s_start`` from the element of the same name in
        the line.
    apertures
        Aperture specification(s) forwarded to ``get_aperture_for_element``;
        only used when ``need_apertures`` is true.
    need_apertures
        When true, an upstream and a downstream aperture are installed
        around each element.
    s_tol
        Tolerance on s positions.

    Raises
    ------
    Exception
        If the tracker has already been built.
    """
    if line._has_valid_tracker():
        raise Exception("Tracker already built!\nPlease install collimators before building "
                      + "tracker!")

    # Normalise all inputs to sequences of equal length
    if not hasattr(names, '__iter__') or isinstance(names, str):
        names = [names]
    if not hasattr(elements, '__iter__') or isinstance(elements, str):
        elements = [elements]
    names = np.array(names)
    length = np.array([coll.length for coll in elements])
    assert len(length) == len(names)
    if not hasattr(at_s, '__iter__'):
        at_s = [at_s for _ in range(len(names))]
    assert len(at_s) == len(names)
    if isinstance(apertures, str) or not hasattr(apertures, '__iter__'):
        apertures = [apertures for _ in range(len(names))]
    assert len(apertures) == len(names)

    # Verify elements
    for el in elements:
        assert isinstance(el, element_classes)
        el._tracking = False

    # Get positions
    tab = line.get_table()
    # Restrict the table to the names that already exist in the line
    tt = tab.rows[[name for name in names if name in line.element_names]]
    s_start = []
    for name, s, l in zip(names, at_s, length):
        if s is None:
            # NOTE(review): for a name that is not in the line this lookup
            # presumably fails before check_element_position below can raise
            # its clearer "Provide `at_s`" error -- confirm.
            s_start.append(_get_s_start(line, name, l, tt))
        else:
            s_start.append(s)
    s_start = np.array(s_start)
    s_end = s_start + length

    # Check positions
    l_line = line.get_length()
    for s1, s2, name, s3 in zip(s_start, s_end, names, at_s):
        # NOTE(review): when at_s is given, s1 equals at_s, so the "exists at
        # different location" test in check_element_position looks like it
        # can never trigger from here -- confirm intent.
        check_element_position(line, name, s1, s2, s3, l_line, s_tol=s_tol)

    # Look for apertures (before the line is modified)
    aper_upstream = []
    aper_downstream = []
    for s1, s2, name, aper in zip(s_start, s_end, names, apertures):
        if not need_apertures:
            aper_upstream.append(None)
            aper_downstream.append(None)
        else:
            aper1, aper2 = get_aperture_for_element(line, name, s1, s2, aper, tab, s_tol=s_tol)
            aper_upstream.append(aper1)
            aper_downstream.append(aper2)

    # Remove elements at location of collimator (by replacing them with drifts)
    for s1, s2, name in zip(s_start, s_end, names):
        prepare_space_for_element(line, name, s1, s2, tab=tab, s_tol=s_tol)

    # Install
    line._insert_thick_elements_at_s(element_names=list(names), elements=elements, at_s=s_start, s_tol=s_tol)

    # Install apertures
    if need_apertures:
        for s1, name, aper1, aper2 in zip(s_start, names, aper_upstream, aper_downstream):
            line.insert_element(element=aper1, name=f'{name}_aper_upstream', at=name, s_tol=s_tol)
            # Index is looked up after the upstream insertion, which shifted the element
            idx = line.element_names.index(name) + 1
            line.insert_element(element=aper2, name=f'{name}_aper_downstream', at=idx, s_tol=s_tol)
77
+
78
+
79
def _get_s_start(line, name, length, tab=None):
    """
    Return the s position at which an element of ``length`` has to start to
    share its centre with the element ``name`` in the table.

    An existing element without a ``length`` attribute, or a name that is not
    in ``line.element_names``, counts as having zero length.
    """
    table = line.get_table() if tab is None else tab
    existing_length = 0
    if name in line.element_names and hasattr(line[name], 'length'):
        existing_length = line[name].length
    # presumably the table's s is the upstream edge, so this is the centre -- confirm
    centre = table.rows[name].s[0] + existing_length/2.
    return centre - length/2
87
+
88
+
89
def check_element_position(line, name, s_start, s_end, at_s, length=None, s_tol=1.e-6):
    """
    Validate the requested position of an element to be installed.

    ``length`` is the length of the line (queried from ``line.get_length()``
    when ``None``), not of the element.

    Raises
    ------
    ValueError
        - if ``at_s`` is ``None`` and ``name`` is not in the line;
        - if ``at_s`` is given, ``name`` is in the line, and ``at_s`` lies
          outside ``[s_start, s_end]``;
        - if the element lies within ``s_tol`` of the start or the end of
          the line.
    """
    already_in_line = name in line.element_names

    if at_s is None and not already_in_line:
        raise ValueError(f"Element {name} not found in line. Provide `at_s`.")
    if at_s is not None and already_in_line and (at_s < s_start or at_s > s_end):
        raise ValueError(f"Element {name} already exists in line at different "
                         f"location: at_s = {at_s}, exists at [{s_start}, {s_end}].")

    line_length = line.get_length() if length is None else length

    if s_start <= s_tol:
        raise ValueError(f"Position of {name} too close to start of line. Please cycle.")
    if s_end >= line_length - s_tol:
        raise ValueError(f"Position of {name} too close to end of line. Please cycle.")
103
+
104
+
105
def get_apertures_at_s(tab, s, s_tol=1.e-6):
    """
    Return the name of the aperture found at position ``s`` in the table.

    Only rows within ``s_tol`` of ``s`` whose element type starts with
    ``'Limit'`` are considered. Returns ``None`` when there is none, and
    raises ``ValueError`` when there is more than one.
    """
    window = tab.rows[s-s_tol:s+s_tol:'s']
    is_limit = [cls.startswith('Limit') for cls in window.element_type]
    aper = window.rows[is_limit]
    n_found = len(aper)
    if n_found > 1:
        raise ValueError(f"Multiple apertures found at location {s} with "
                       + f"tolerance {s_tol}: {aper.name}. Not supported.")
    if n_found == 1:
        return aper.name[0]
    return None
115
+
116
+
117
def get_aperture_for_element(line, name, s_start, s_end, aperture=None, tab=None, s_tol=1.e-6):
    """
    Return copies of the upstream and downstream apertures for an element.

    Parameters
    ----------
    line
        Line in which apertures are looked up.
    name
        Name of the element (only used in messages).
    s_start, s_end
        Start and end position of the element.
    aperture
        ``None`` to search the line, the name of an aperture in the line, an
        aperture element, or a two-element sequence ``[upstream, downstream]``
        whose entries are each a name or an aperture element.
    tab
        Table of the line; obtained with ``line.get_table()`` when needed.
    s_tol
        Tolerance on s when searching the line.

    Returns
    -------
    tuple
        ``(upstream, downstream)``, both copies.

    Raises
    ------
    ValueError
        If a provided aperture is invalid, if a sequence does not have exactly
        two entries, or if no aperture can be found in the line.
    """
    if aperture is not None:
        if isinstance(aperture, str):
            aper1 = line[aperture]
            aper2 = line[aperture]
        elif hasattr(aperture, '__iter__'):
            if len(aperture) != 2:
                raise ValueError(f"The value `aperture` should be None or a list "
                               + f"[upstream, downstream].")
            assert aperture[0] is not None and aperture[1] is not None
            # Each entry is either a name to look up or already an element.
            # Previously element entries left aper1 / aper2 unassigned.
            aper1 = line[aperture[0]] if isinstance(aperture[0], str) else aperture[0]
            aper2 = line[aperture[1]] if isinstance(aperture[1], str) else aperture[1]
        else:
            aper1 = aperture
            aper2 = aperture
        if not xt.line._is_aperture(aper1, line):
            raise ValueError(f"Not a valid aperture: {aper1}")
        if not xt.line._is_aperture(aper2, line):
            raise ValueError(f"Not a valid aperture: {aper2}")
        return aper1.copy(), aper2.copy()
    else:
        if tab is None:
            tab = line.get_table()
        # In this branch aper1 / aper2 are aperture names, not elements
        aper1 = get_apertures_at_s(tab, s_start, s_tol=s_tol)
        aper2 = get_apertures_at_s(tab, s_end, s_tol=s_tol)
        if aper1 is None and aper2 is not None:
            aper1 = aper2
            print(f"Warning: Could not find upstream aperture for {name}! "
                 + f"Used copy of downstream aperture. Proceed with caution.")
        elif aper2 is None and aper1 is not None:
            aper2 = aper1
            print(f"Warning: Could not find downstream aperture for {name}! "
                 + f"Used copy of upstream aperture. Proceed with caution.")
        elif aper1 is None and aper2 is None:
            # Last resort: an aperture at the centre of the element
            aper_mid = get_apertures_at_s(tab, (s_start+s_end)/2, s_tol=s_tol)
            if aper_mid is None:
                raise ValueError(f"No aperture found for {name}! Please provide one.")
            if line[aper_mid].allow_rot_and_shift \
            and xt.base_element._tranformations_active(line[aper_mid]):
                print(f"Warning: Using the centre aperture for {name}, but "
                     + f"transformations are present. Proceed with caution.")
            aper1 = aper_mid
            aper2 = aper_mid
        return line[aper1].copy(), line[aper2].copy()
163
+
164
+
165
def prepare_space_for_element(line, name, s_start, s_end, tab=None, s_tol=1.e-6):
    """
    Clear the region ``[s_start, s_end]`` (widened by ``s_tol``) of the line.

    Every element in that region except the last table row is replaced by an
    ``xt.Drift`` of the same length (zero when it has no ``length``). Markers
    and drifts are left alone. A warning is printed for every replaced element
    whose type does not start with ``'Limit'``.
    """
    table = line.get_table() if tab is None else tab
    window = table.rows[s_start-s_tol:s_end+s_tol:'s']
    candidates = zip(window.name[:-1], window.element_type[:-1])
    for element_name, element_type in candidates:
        is_passive = element_type == 'Marker' or element_type.startswith('Drift')
        if is_passive:
            continue
        if not element_type.startswith('Limit'):
            print(f"Warning: Removed active element {element_name} "
                 + f"at location inside collimator!")
        existing = line[element_name]
        length = existing.length if hasattr(existing, 'length') else 0
        line.element_dict[element_name] = xt.Drift(length=length)
177
+
@@ -0,0 +1 @@
1
+ from .interaction_record import InteractionRecord
@@ -0,0 +1,252 @@
1
+ # copyright ############################### #
2
+ # This file is part of the Xcoll Package. #
3
+ # Copyright (c) CERN, 2024. #
4
+ # ######################################### #
5
+
6
+ import xobjects as xo
7
+ import xtrack as xt
8
+
9
+ from .interaction_types import source, interactions, shortcuts, is_point
10
+ from ..general import _pkg_root
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+
16
class InteractionRecord(xt.BeamElement):
    """
    Record table of interactions logged by Xcoll elements.

    Each row stores the element index and turn of an interaction, its type
    (``_inter``), and the coordinates of the parent and child particle.

    Instances created with ``InteractionRecord.start()`` carry extra
    bookkeeping (the line, the io buffer, the recording elements and an
    element index <-> name mapping) that the extended API relies on.
    """
    _xofields = {
        '_index':        xt.RecordIndex,
        'at_element':    xo.Int64[:],
        'at_turn':       xo.Int64[:],
        'ds':            xo.Float64[:],
        '_inter':        xo.Int64[:],
        'parent_id':     xo.Int64[:],
        'parent_x':      xo.Float64[:],
        'parent_px':     xo.Float64[:],
        'parent_y':      xo.Float64[:],
        'parent_py':     xo.Float64[:],
        'parent_zeta':   xo.Float64[:],
        'parent_delta':  xo.Float64[:],
        'parent_energy': xo.Float64[:],
        'parent_mass':   xo.Float64[:],
        'parent_charge': xo.Int64[:],
        'parent_z':      xo.Int64[:],
        'parent_a':      xo.Int64[:],
        'parent_pdgid':  xo.Int64[:],
        'child_id':      xo.Int64[:],
        'child_x':       xo.Float64[:],
        'child_px':      xo.Float64[:],
        'child_y':       xo.Float64[:],
        'child_py':      xo.Float64[:],
        'child_zeta':    xo.Float64[:],
        'child_delta':   xo.Float64[:],
        'child_energy':  xo.Float64[:],
        'child_mass':    xo.Float64[:],
        'child_charge':  xo.Int64[:],
        'child_z':       xo.Int64[:],
        'child_a':       xo.Int64[:],
        'child_pdgid':   xo.Int64[:],
    }

    allow_track = False

    _extra_c_sources = [
        source,
        _pkg_root.joinpath('headers','particle_states.h'),
        _pkg_root.joinpath('interaction_record','interaction_record_src','interaction_record.h')
    ]


    @classmethod
    def start(cls, line, names=None, *, record_touches=None, record_scatterings=None, capacity=1e6, io_buffer=None):
        """
        Start logging interactions in the given Xcoll elements of ``line``.

        ``names`` is resolved by ``_get_xcoll_elements`` (``None`` selects all
        Xcoll elements). When only one of ``record_touches`` /
        ``record_scatterings`` is given, the other defaults to its negation;
        when neither is given both are enabled. These flags are only applied
        to elements that have both flags unset. Returns the record, or
        ``None`` when no element is selected.
        """
        names = _get_xcoll_elements(line, names)
        if len(names) == 0:
            return
        elements = [line[name] for name in names]
        capacity = int(capacity)
        if io_buffer is None:
            io_buffer = xt.new_io_buffer(capacity=capacity)
        if record_touches is None and record_scatterings is None:
            record_touches = True
            record_scatterings = True
        elif record_touches is None:
            record_touches = not record_scatterings
        elif record_scatterings is None:
            record_scatterings = not record_touches
        assert record_touches is True or record_touches is False
        assert record_scatterings is True or record_scatterings is False
        for el in elements:
            if not el.record_touches and not el.record_scatterings:
                el.record_touches = record_touches
                el.record_scatterings = record_scatterings
        record = xt.start_internal_logging(io_buffer=io_buffer, capacity=capacity, \
                                           elements=elements)
        record._line = line
        record._io_buffer = io_buffer
        record._recording_elements = names
        record._coll_ids = {name: line.element_names.index(name) for name in names}
        record._coll_names = {vv: kk for kk, vv in record._coll_ids.items()}
        return record

    def stop(self, names=None):
        """Stop logging in the given elements (all Xcoll elements when ``None``)."""
        self.assert_class_init()
        names = _get_xcoll_elements(self.line, names)
        elements = [self.line[name] for name in names]
        if self.line.tracker is not None:
            self.line.tracker._check_invalidated()
        xt.stop_internal_logging(elements=elements)
        # Remove the stopped collimators from the list of logged elements
        self._recording_elements = list(set(self._recording_elements) - set(names))


    def assert_class_init(self):
        """Raise ``ValueError`` unless this record was created by ``start()``."""
        if not hasattr(self, '_io_buffer') or not hasattr(self, '_line') \
        or not hasattr(self, '_recording_elements'):
            raise ValueError("This InteractionRecord has been manually instantiated, "
                           + "hence the expanded API is not available. Use "
                           + "InteractionRecord.start() to initialise with extended API.")

    @property
    def line(self):
        """The line given to ``start()``, or ``None`` if not available."""
        if hasattr(self, '_line'):
            return self._line

    @property
    def io_buffer(self):
        """The io buffer used for logging, or ``None`` if not available."""
        if hasattr(self, '_io_buffer'):
            return self._io_buffer

    @property
    def capacity(self):
        """Capacity of the io buffer, or ``None`` if not available."""
        if hasattr(self, '_io_buffer'):
            return self.io_buffer.capacity

    # @capacity.setter
    # def capacity(self, val):
    #     if hasattr(self, '_io_buffer'):
    #         capacity = int(capacity)
    #         if capacity < self.capacity:
    #             raise NotImplementedError("Shrinking of capacity not yet implemented!")
    #         elif capacity == self.capacity:
    #             return
    #         else:
    #             self.io_buffer.grow(capacity - self.capacity)
    #             # TODO: increase capacity of iobuffer AND of fields in record table

    @property
    def recording_elements(self):
        """Names of the elements currently logging, or ``None`` if not available."""
        if hasattr(self, '_recording_elements'):
            return self._recording_elements

    @recording_elements.setter
    def recording_elements(self, val):
        self.assert_class_init()
        if val is None:
            val = []
        record_start = _get_xcoll_elements(self.line, val)
        # Stop the elements that are logging now but are not requested anymore
        self.stop(set(self.recording_elements) - set(record_start))
        # Look the elements up in the stored line (there is no local `line` here)
        elements = [self.line[name] for name in record_start]
        for el in elements:
            if not el.record_touches and not el.record_scatterings:
                el.record_touches = True
                el.record_scatterings = True
        xt.start_internal_logging(io_buffer=self.io_buffer, capacity=self.capacity, \
                                  record=self, elements=elements)
        self._recording_elements = record_start
        # Updating coll IDs: careful to correctly overwrite existing values
        self._coll_ids.update({name: self.line.element_names.index(name) for name in record_start})
        self._coll_names = {vv: kk for kk, vv in self._coll_ids.items()}

    @property
    def interaction_type(self):
        """Array with the interaction name of every entry of ``_inter``."""
        return np.array([interactions[inter] for inter in self._inter])

    def _collimator_name(self, element_id):
        """Translate an element index into its name (identity without a mapping)."""
        if not hasattr(self, '_coll_names'):
            return element_id
        elif element_id not in self._coll_names:
            raise ValueError(f"Element {element_id} not found in list of collimators of this record table! "
                           + f"Did the line change without updating the list in the table?")
        else:
            return self._coll_names[element_id]

    def _collimator_id(self, element_name):
        """Translate an element name into its index (identity without a mapping)."""
        if not hasattr(self, '_coll_ids'):
            # Mirror _collimator_name: hand the argument back unchanged
            return element_name
        elif element_name not in self._coll_ids:
            raise ValueError(f"Element {element_name} not found in list of collimators of this record table! "
                           + f"Did the line change without updating the list in the table?")
        else:
            return self._coll_ids[element_name]

    def to_pandas(self):
        """Return the recorded rows as a ``pandas.DataFrame``."""
        n_rows = self._index.num_recorded
        coll_header = 'collimator' if hasattr(self, '_coll_names') else 'collimator_id'
        df = pd.DataFrame({
                'turn':             self.at_turn[:n_rows],
                coll_header:        [self._collimator_name(element_id) for element_id in self.at_element[:n_rows]],
                'interaction_type': [interactions[inter] for inter in self._inter[:n_rows]],
                'ds':               self.ds[:n_rows],
                **{
                    f'{p}_{val}': getattr(self, f'{p}_{val}')[:n_rows]
                    for p in ['parent', 'child']
                    for val in ['id', 'x', 'px', 'y', 'py', 'zeta', 'delta', 'energy', 'mass', 'charge', 'z', 'a', 'pdgid']
                }
            })
        return df

    # TODO: list of impacted collimators


    # TODO: does not work when multiple children
    def interactions_per_collimator(self, collimator=0, *, turn=None):
        """
        List the interaction shortcuts per parent particle in one collimator.

        ``collimator`` is an element index or name. With ``turn`` given, the
        result is grouped by parent id for that turn only; otherwise it is
        grouped by parent id and turn.
        """
        if isinstance(collimator, str):
            collimator = self._collimator_id(collimator)
        mask = (self._inter > 0) & (self.at_element == collimator)
        if turn is not None:
            mask = mask & (self.at_turn == turn)
            df = pd.DataFrame({
                    'int':  [shortcuts[inter] for inter in self._inter[mask]],
                    'pid':  self.parent_id[mask]
                })
            return df.groupby('pid', sort=False)['int'].agg(list)
        else:
            df = pd.DataFrame({
                    'int':  [shortcuts[inter] for inter in self._inter[mask]],
                    'turn': self.at_turn[mask],
                    'pid':  self.parent_id[mask]
                })
            return df.groupby(['pid', 'turn'], sort=False)['int'].apply(list)

    def first_touch_per_turn(self):
        """
        Return, per turn and parent particle, the 'Enter Jaw' interaction at
        the element with the lowest index.

        The result has a ``jaw`` column (last character of the interaction
        type), no ``ds``, ``interaction_type`` or ``child_*`` columns, and the
        ``parent_`` prefix stripped from the remaining columns.
        """
        n_rows = self._index.num_recorded
        df = pd.DataFrame({'parent_id':  self.parent_id[:n_rows],
                           'at_turn':    self.at_turn[:n_rows],
                           'at_element': self.at_element[:n_rows]})
        mask = np.char.startswith(self.interaction_type[:n_rows], 'Enter Jaw')
        idx_first = [group.at_element.idxmin() for _, group in df[mask].groupby(['at_turn', 'parent_id'], sort=False)]
        df_first = self.to_pandas().loc[idx_first]
        df_first.insert(2, "jaw", df_first.interaction_type.astype(str).str[-1])
        to_drop = ['ds', 'interaction_type',
                   *[col for col in df_first.columns if col.startswith('child_')]]
        to_rename = {col: col.replace('parent_', '') for col in df_first.columns if col.startswith('parent_')}
        return df_first.drop(columns=to_drop).rename(columns=to_rename)
234
+
235
+
236
def _get_xcoll_elements(line, names):
    """
    Resolve ``names`` into an iterable of Xcoll element names in ``line``.

    ``None`` or ``True`` selects every Xcoll element in the line (raising
    ``ValueError`` when there is none), ``False`` selects nothing, and a
    single name is wrapped in a list. Every name must exist in the line and
    refer to an Xcoll element, otherwise ``ValueError`` is raised.
    """
    from xcoll import element_classes

    select_all = names is None or names is True
    if select_all:
        names = line.get_elements_of_type(element_classes)[1]
        if len(names) == 0:
            raise ValueError("No Xcoll elements in line!")
    elif names is False:
        names = []

    if isinstance(names, str) or not hasattr(names, '__iter__'):
        names = [names]

    for el_name in names:
        if el_name not in line.element_names:
            raise ValueError(f"Element {el_name} not found in line!")
        if not isinstance(line[el_name], element_classes):
            raise ValueError(f"Element {el_name} not an Xcoll element!")
    return names
252
+