qilisdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. qilisdk/__init__.py +11 -2
  2. qilisdk/__init__.pyi +2 -3
  3. qilisdk/_logging.py +135 -0
  4. qilisdk/_optionals.py +5 -7
  5. qilisdk/analog/__init__.py +3 -18
  6. qilisdk/analog/exceptions.py +2 -4
  7. qilisdk/analog/hamiltonian.py +455 -110
  8. qilisdk/analog/linear_schedule.py +118 -0
  9. qilisdk/analog/schedule.py +272 -79
  10. qilisdk/backends/__init__.py +45 -0
  11. qilisdk/{digital/digital_algorithm.py → backends/__init__.pyi} +3 -5
  12. qilisdk/backends/backend.py +117 -0
  13. qilisdk/{extras/cuda → backends}/cuda_backend.py +153 -161
  14. qilisdk/backends/qutip_backend.py +492 -0
  15. qilisdk/common/__init__.py +48 -2
  16. qilisdk/common/algorithm.py +2 -1
  17. qilisdk/{extras/qaas/qaas_settings.py → common/exceptions.py} +12 -6
  18. qilisdk/common/model.py +1019 -1
  19. qilisdk/common/parameterizable.py +75 -0
  20. qilisdk/common/qtensor.py +666 -0
  21. qilisdk/common/result.py +2 -1
  22. qilisdk/common/variables.py +1931 -0
  23. qilisdk/{extras/cuda/cuda_analog_result.py → cost_functions/__init__.py} +3 -4
  24. qilisdk/cost_functions/cost_function.py +77 -0
  25. qilisdk/cost_functions/model_cost_function.py +145 -0
  26. qilisdk/cost_functions/observable_cost_function.py +109 -0
  27. qilisdk/digital/__init__.py +3 -22
  28. qilisdk/digital/ansatz.py +203 -160
  29. qilisdk/digital/circuit.py +81 -9
  30. qilisdk/digital/exceptions.py +12 -6
  31. qilisdk/digital/gates.py +228 -85
  32. qilisdk/{extras/qaas/qaas_analog_result.py → functionals/__init__.py} +14 -5
  33. qilisdk/functionals/functional.py +39 -0
  34. qilisdk/{extras/cuda/cuda_digital_result.py → functionals/functional_result.py} +3 -4
  35. qilisdk/functionals/sampling.py +81 -0
  36. qilisdk/functionals/sampling_result.py +92 -0
  37. qilisdk/functionals/time_evolution.py +98 -0
  38. qilisdk/functionals/time_evolution_result.py +84 -0
  39. qilisdk/functionals/variational_program.py +80 -0
  40. qilisdk/functionals/variational_program_result.py +69 -0
  41. qilisdk/logging_config.yaml +16 -0
  42. qilisdk/{common/backend.py → optimizers/__init__.py} +2 -1
  43. qilisdk/optimizers/optimizer.py +39 -0
  44. qilisdk/{common → optimizers}/optimizer_result.py +3 -12
  45. qilisdk/{common/optimizer.py → optimizers/scipy_optimizer.py} +10 -28
  46. qilisdk/settings.py +78 -0
  47. qilisdk/{extras → speqtrum}/__init__.py +7 -8
  48. qilisdk/{extras → speqtrum}/__init__.pyi +3 -3
  49. qilisdk/speqtrum/experiments/__init__.py +25 -0
  50. qilisdk/speqtrum/experiments/experiment_functional.py +124 -0
  51. qilisdk/speqtrum/experiments/experiment_result.py +231 -0
  52. qilisdk/{extras/qaas → speqtrum}/keyring.py +8 -4
  53. qilisdk/speqtrum/speqtrum.py +432 -0
  54. qilisdk/speqtrum/speqtrum_models.py +300 -0
  55. qilisdk/utils/__init__.py +0 -14
  56. qilisdk/utils/openqasm2.py +1 -1
  57. qilisdk/utils/serialization.py +1 -1
  58. qilisdk/utils/visualization/PlusJakartaSans-SemiBold.ttf +0 -0
  59. qilisdk/utils/visualization/__init__.py +24 -0
  60. qilisdk/utils/visualization/circuit_renderers.py +781 -0
  61. qilisdk/utils/visualization/schedule_renderers.py +161 -0
  62. qilisdk/utils/visualization/style.py +154 -0
  63. qilisdk/utils/visualization/themes.py +76 -0
  64. qilisdk/yaml.py +126 -0
  65. {qilisdk-0.1.3.dist-info → qilisdk-0.1.5.dist-info}/METADATA +180 -135
  66. qilisdk-0.1.5.dist-info/RECORD +69 -0
  67. qilisdk/analog/algorithms.py +0 -111
  68. qilisdk/analog/analog_backend.py +0 -43
  69. qilisdk/analog/analog_result.py +0 -114
  70. qilisdk/analog/quantum_objects.py +0 -533
  71. qilisdk/digital/digital_backend.py +0 -90
  72. qilisdk/digital/digital_result.py +0 -145
  73. qilisdk/digital/vqe.py +0 -166
  74. qilisdk/extras/cuda/__init__.py +0 -13
  75. qilisdk/extras/qaas/__init__.py +0 -13
  76. qilisdk/extras/qaas/models.py +0 -132
  77. qilisdk/extras/qaas/qaas_backend.py +0 -255
  78. qilisdk/extras/qaas/qaas_digital_result.py +0 -20
  79. qilisdk/extras/qaas/qaas_time_evolution_result.py +0 -20
  80. qilisdk/extras/qaas/qaas_vqe_result.py +0 -20
  81. qilisdk-0.1.3.dist-info/RECORD +0 -51
  82. {qilisdk-0.1.3.dist-info → qilisdk-0.1.5.dist-info}/WHEEL +0 -0
  83. {qilisdk-0.1.3.dist-info → qilisdk-0.1.5.dist-info}/licenses/LICENCE +0 -0
qilisdk/utils/visualization/schedule_renderers.py (new file)
@@ -0,0 +1,161 @@
1
+ # Copyright 2025 Qilimanjaro Quantum Tech
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import TYPE_CHECKING
19
+
20
+ import matplotlib.pyplot as plt
21
+
22
+ if TYPE_CHECKING:
23
+ from qilisdk.analog.schedule import Schedule
24
+ from qilisdk.common.variables import Number
25
+
26
+ from qilisdk.utils.visualization.style import ScheduleStyle
27
+
28
+
29
class MatplotlibScheduleRenderer:
    """Render a Schedule using matplotlib, with theme support."""

    def __init__(self, schedule: Schedule, ax: plt.Axes | None = None, *, style: ScheduleStyle | None = None) -> None:
        """
        Args:
            schedule (Schedule): Schedule whose Hamiltonian coefficients are plotted.
            ax (plt.Axes | None): Axes to draw on. If None, a new figure/axes pair is created
                from ``style`` (figsize, dpi and theme background).
            style (ScheduleStyle | None): Visual configuration. Defaults to ``ScheduleStyle()``.
        """
        self.schedule = schedule
        self.style = style if style is not None else ScheduleStyle()
        self.ax = ax if ax is not None else self._make_axes(self.style.dpi, self.style)

    def plot(self, ax: plt.Axes | None = None) -> None:
        """
        Plot the schedule coefficients for each Hamiltonian over time.

        The schedule is sampled at ``t = k * dt`` for ``k = 0 .. int(T / dt) - 1``. One line is
        drawn per Hamiltonian. Each line is coloured from the theme's primary -> accent gradient
        unless the style supplies an explicit colour for it.

        Args:
            ax (plt.Axes | None): The matplotlib axes to plot on. When given, it replaces the
                renderer's axes, so a later ``save`` writes that figure. Default is None, which
                keeps the axes chosen at construction time.
        """
        if ax is not None:
            self.ax = ax
        axes = self.ax
        figure = axes.figure
        style = self.style
        theme = style.theme
        text_color = theme.on_background
        # Honour ScheduleStyle.tick_color; its documented fallback is theme.on_background.
        tick_color = style.tick_color or theme.on_background

        # Theme both the axes and the figure background. The figure used to be themed only when
        # the (otherwise unused) `ax` argument had a figure, i.e. never for the default call.
        axes.set_facecolor(theme.background)
        figure.set_facecolor(theme.background)

        # Sample every Hamiltonian's coefficient on the time grid.
        schedule = self.schedule
        dt = schedule.dt
        hamiltonians = schedule.hamiltonians
        # NOTE(review): int() truncates, so float round-off in T / dt (e.g. 0.3 / 0.1) can drop a
        # step. Kept to match the original grid; confirm against Schedule's own stepping.
        n_steps = int(schedule.T / dt)
        times = [step * dt for step in range(n_steps)]
        plots: dict[str, list[Number]] = {h: [] for h in hamiltonians}
        for t in times:
            for h in hamiltonians:
                plots[h].append(schedule.get_coefficient(t, h))

        gradient = self._gradient_colors(theme.primary, theme.accent, len(hamiltonians))
        for idx, h in enumerate(hamiltonians):
            line_style = style.line_styles.get(h, style.default_line_style)
            # An explicit per-Hamiltonian colour wins; otherwise use this line's gradient colour.
            if "color" not in line_style:
                line_style = {**line_style, "color": gradient[idx]}
            axes.plot(times, plots[h], label=h, marker=style.marker, markersize=style.marker_size, **line_style)

        if style.grid:
            grid_style = dict(style.grid_style)  # copy so the style's own dict is never mutated
            grid_style.setdefault("color", theme.surface_muted)
            axes.grid(**grid_style)

        legend = axes.legend(
            loc=style.legend_loc,
            fontsize=style.legend_fontsize,
            frameon=style.legend_frame,
            facecolor=theme.surface,
            edgecolor=theme.border,
        )
        # Match legend text to the theme's text colour.
        if legend is not None:
            for text in legend.get_texts():
                text.set_color(text_color)

        axes.set_title(
            style.title or "Schedule Plot",
            fontsize=style.title_fontsize,
            color=text_color,
            fontweight=style.fontweight,
            family=style.fontfamily,
        )
        axes.set_xlabel(
            style.xlabel,
            fontsize=style.label_fontsize,
            color=text_color,
            fontweight=style.fontweight,
            family=style.fontfamily,
        )
        axes.set_ylabel(
            style.ylabel,
            fontsize=style.label_fontsize,
            color=text_color,
            fontweight=style.fontweight,
            family=style.fontfamily,
        )
        axes.tick_params(axis="x", labelsize=style.xtick_fontsize, colors=tick_color)
        axes.tick_params(axis="y", labelsize=style.ytick_fontsize, colors=tick_color)

        # Act on this renderer's figure rather than pyplot's "current" figure, which may differ
        # when several figures are open.
        if style.tight_layout:
            figure.tight_layout()
        figure.canvas.draw_idle()

    def save(self, filename: str) -> None:  # thin wrapper
        """Save current figure to disk.

        Args:
            filename: Path to save the figure (e.g., 'circuit.png').
        """
        self.ax.figure.savefig(filename, bbox_inches="tight")  # type: ignore[union-attr]

    @staticmethod
    def _hex_to_rgb(hex_color: str) -> tuple[int, ...]:
        """Convert a ``#RRGGBB`` string to an ``(r, g, b)`` tuple of 0-255 integers."""
        digits = hex_color.lstrip("#")
        return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))

    @classmethod
    def _gradient_colors(cls, start_hex: str, end_hex: str, n: int) -> list[str]:
        """Return ``n`` hex colours linearly interpolated from ``start_hex`` to ``end_hex`` (both inclusive)."""
        start = cls._hex_to_rgb(start_hex)
        end = cls._hex_to_rgb(end_hex)
        colors = []
        for i in range(n):
            ratio = i / max(n - 1, 1)  # a single colour (n == 1) maps to start_hex
            rgb = tuple(int(s + (e - s) * ratio) for s, e in zip(start, end))
            colors.append("#{:02x}{:02x}{:02x}".format(*rgb))
        return colors

    @staticmethod
    def _make_axes(dpi: int, style: ScheduleStyle) -> plt.Axes:
        """
        Create a new figure and axes with the given DPI.

        Args:
            dpi: Resolution of the new figure.
            style: Style configuration supplying the figure size and theme background.

        Returns:
            A newly created Matplotlib Axes.
        """
        _, ax = plt.subplots(figsize=style.figsize, dpi=dpi, facecolor=style.theme.background)
        return ax
qilisdk/utils/visualization/style.py (new file)
@@ -0,0 +1,154 @@
1
+ # Copyright 2025 Qilimanjaro Quantum Tech
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import Any, Literal, Optional
17
+
18
+ import matplotlib.font_manager as fm
19
+ from pydantic import BaseModel, Field
20
+
21
+ from .themes import Theme, light
22
+
23
# Bundled fallback font, shipped next to this module.
_DEFAULT_FONT_PATH = Path(__file__).with_name("PlusJakartaSans-SemiBold.ttf")
24
+
25
+
26
class Style(BaseModel):
    """Base visual configuration shared by circuit and schedule plots (theme, font, DPI, title)."""

    theme: Theme = Field(default=light, description="Colour theme.")

    # --- FontProperties-mapped fields (mirror matplotlib.font_manager.FontProperties) ---
    # If `fontfname` points to an existing file, it takes precedence and loads the exact TTF.
    fontfamily: str | list[str] | None = Field(
        default=None, description="Font family name(s), e.g. 'Outfit' or ['Outfit', 'DejaVu Sans']."
    )
    fontstyle: Literal["normal", "italic", "oblique"] = Field(
        default="normal", description="Font style: 'normal', 'italic', or 'oblique'."
    )
    fontvariant: Literal["normal", "small-caps"] = Field(
        default="normal", description="Font variant: typically 'normal' or 'small-caps'."
    )
    fontweight: str | int = Field(
        default="normal", description="Font weight: 'normal', 'bold', 'light', or numeric (100-900)."
    )
    fontstretch: str | int = Field(
        default="normal", description="Width/condensation: 'ultra-condensed'..'ultra-expanded' or numeric."
    )
    fontsize: float | str = Field(
        default=10, description="Font size in pt or keywords like 'small', 'medium', 'large'."
    )
    fontfname: str | None = Field(
        default=str(_DEFAULT_FONT_PATH), description="Absolute path to the TTF/OTF file. If present, overrides family."
    )
    math_fontfamily: str | None = Field(default=None, description="Math text family, e.g. 'dejavusans', 'cm', or None.")

    # --- Figure-level settings ---
    dpi: int = Field(default=150, description="Figure DPI.")
    title: str | None = Field(default=None, description="Figure title.")

    @property
    def font(self) -> fm.FontProperties:
        """
        Construct a Matplotlib FontProperties from the configured fields.

        If `fontfname` points to an existing file, it is passed as ``fname`` and overrides
        ``fontfamily``. An unset or missing path is dropped, so the family-based lookup is used
        instead of pointing matplotlib at a file that does not exist.

        Returns:
            fm.FontProperties: Font description built from this style's font fields.
        """
        fname = self.fontfname if self.fontfname and Path(self.fontfname).is_file() else None
        return fm.FontProperties(
            family=self.fontfamily,
            style=self.fontstyle,
            variant=self.fontvariant,
            weight=self.fontweight,
            stretch=self.fontstretch,
            size=self.fontsize,
            fname=fname,
            math_fontfamily=self.math_fontfamily,
        )
71
+
72
+
73
class CircuitStyle(Style):
    """All visual parameters controlling the appearance of a circuit plot."""

    # --- Spacing and gate-box geometry ---
    # NOTE(review): only some descriptions state a unit (inches); the unit of the others is
    # presumably the same, but confirm against the circuit renderer.
    end_wire_ext: int = Field(default=2, description="Extra space after last layer.")
    padding: float = Field(default=0.3, description="Padding around drawing (inches).")
    gate_margin: float = Field(default=0.15, description="Left/right margin per gate.")
    wire_sep: float = Field(default=0.5, description="Vertical separation of wires.")
    layer_sep: float = Field(default=0.5, description="Horizontal separation of layers.")
    gate_pad: float = Field(default=0.05, description="Padding around gate text.")
    label_pad: float = Field(default=0.1, description="Padding before wire label.")
    bulge: str = Field(default="round", description="Box-style for gate rectangles.")
    align_layer: bool = Field(default=True, description="Align layers across wires.")

    # --- Wire labels and minimum sizes ---
    wire_label: list[Any] | None = Field(default=None, description="Custom wire labels.")
    start_pad: float = Field(
        default=0.1, description="Minimum spacing (inches) before the first layer so wire labels fit."
    )
    min_gate_h: float = Field(default=0.2, description="Minimum gate box height (inches).")
    min_gate_w: float = Field(default=0.2, description="Minimum gate box width (inches).")

    # --- Glyph radii for control/target markers ---
    connector_r: float = Field(
        default=0.01, description="Radius (inches) of small connector dots on multi-target gates."
    )
    target_r: float = Field(default=0.12, description="Radius (inches) of ⊕ target circle and SWAP half-width.")
    control_r: float = Field(default=0.05, description="Radius (inches) of a filled control dot.")

    # --- Layer placement strategy ---
    layout: Literal["normal", "compact"] = Field(
        default="normal",
        description="If 'compact' minimizes the layers to highlight circuit depth, if 'normal' conserves the order of the circuit",
    )
102
+
103
+
104
class ScheduleStyle(Style):
    """
    Customization options for matplotlib schedule plots, with theme support.
    """

    # Figure and axes
    figsize: tuple[float, float] | None = Field(default=(8, 5), description="Figure size in inches (width, height).")
    grid: bool = Field(default=True, description="Whether to show grid lines on the plot.")
    # No default "color": the schedule renderer fills in theme.surface_muted when "color" is
    # absent, so the grid follows the active theme. A hard-coded light grey here was previously
    # drawn even on the dark theme and made that fallback dead code.
    grid_style: dict[str, Any] = Field(
        default_factory=lambda: {"linestyle": "--", "alpha": 0.7},
        description="Style dictionary for grid lines (linestyle, color, alpha, etc.).",
    )

    # Title and labels
    title_fontsize: int = Field(default=16, description="Font size for the plot title.")
    xlabel: str = Field(default="time (dt)", description="Label for the x-axis.")
    ylabel: str = Field(default="coefficient value", description="Label for the y-axis.")
    label_fontsize: int = Field(default=14, description="Font size for axis labels.")

    # Legend
    legend_loc: str = Field(
        default="best", description="Location of the legend (matplotlib string, e.g. 'best', 'upper right')."
    )
    legend_fontsize: int = Field(default=12, description="Font size for legend text.")
    legend_frame: bool = Field(default=True, description="Whether to draw a frame around the legend.")

    # Line style
    line_styles: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="Custom line style dictionary for each Hamiltonian (e.g. {label: {color, linestyle, linewidth}}).",
    )
    default_line_style: dict[str, Any] = Field(
        default_factory=lambda: {"linestyle": "-", "linewidth": 2},
        description="Default line style for Hamiltonians not in line_styles.",
    )

    # Marker style
    marker: str | None = Field(
        default=None, description="Matplotlib marker style for data points (e.g. 'o', 's', None for no marker)."
    )
    marker_size: int = Field(default=6, description="Size of markers if used.")

    # Ticks
    xtick_fontsize: int = Field(default=12, description="Font size for x-axis tick labels.")
    ytick_fontsize: int = Field(default=12, description="Font size for y-axis tick labels.")
    tick_color: str | None = Field(
        default=None, description="Color for tick labels (None uses theme.on_background)."
    )

    # Misc
    tight_layout: bool = Field(default=True, description="Whether to use matplotlib's tight_layout for figure spacing.")
qilisdk/utils/visualization/themes.py (new file)
@@ -0,0 +1,76 @@
1
+ # Copyright 2025 Qilimanjaro Quantum Tech
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Final
16
+
17
+ from pydantic import BaseModel, ConfigDict, Field
18
+
19
# Palette used by the built-in themes. Every value is a "#RRGGBB" hex string, which is the form
# the schedule renderer's gradient helper parses.
WHITE: Final[str] = "#FFFFFF"
BLACK: Final[str] = "#000000"

# Neutral ramp (light → dark)
NEUTRAL_050: Final[str] = "#F0F0F0"
NEUTRAL_100: Final[str] = "#DCDCDC"
NEUTRAL_200: Final[str] = "#CACACA"
NEUTRAL_600: Final[str] = "#7A7A7A"
NEUTRAL_800: Final[str] = "#2F2F2F"
NEUTRAL_900: Final[str] = "#1F1F1F"

# Brand colours
VIOLET: Final[str] = "#5E56A1"
MAGENTA: Final[str] = "#AC115F"
33
+
34
+
35
class Theme(BaseModel):
    """Colour Theme."""

    # Frozen: theme instances are shared (e.g. `light` is the default for Style.theme), so
    # they must not be mutated in place by one caller.
    model_config = ConfigDict(frozen=True)

    # Each field is a colour string; the bundled themes use "#RRGGBB" hex values.
    # "on_*" fields are the foreground colour to use on top of the matching fill.
    background: str = Field(description="Canvas background.")
    on_background: str = Field(description="Default text/line color on background.")
    surface: str = Field(description="Raised surface/panel fill.")
    on_surface: str = Field(description="Text/line color on surface.")
    surface_muted: str = Field(description="Muted lines on background (wires/grid).")
    border: str = Field(description="Neutral stroke/border color.")
    primary: str = Field(description="Primary/brand fill.")
    on_primary: str = Field(description="Text/icons over primary.")
    accent: str = Field(description="Accent/highlight color.")
    on_accent: str = Field(description="Text/icons over accent.")
50
+
51
+
52
# Default theme (also the default of Style.theme): white canvas with black text. The schedule
# renderer's line gradient runs from primary (violet) to accent (magenta).
light = Theme(
    background=WHITE,
    on_background=BLACK,
    surface=NEUTRAL_100,
    on_surface=BLACK,
    surface_muted=NEUTRAL_050,
    border=NEUTRAL_200,
    primary=VIOLET,
    on_primary=WHITE,
    accent=MAGENTA,
    on_accent=WHITE,
)

# Dark counterpart: black canvas with white text; primary and accent are swapped.
dark = Theme(
    background=BLACK,
    on_background=WHITE,
    surface=NEUTRAL_800,
    # Was BLACK, which is nearly invisible on the #2F2F2F surface. WHITE gives the same kind of
    # contrasting surface/on_surface pair that `light` uses.
    on_surface=WHITE,
    surface_muted=NEUTRAL_900,
    border=NEUTRAL_600,
    primary=MAGENTA,
    on_primary=WHITE,
    accent=VIOLET,
    on_accent=WHITE,
)
qilisdk/yaml.py CHANGED
@@ -16,11 +16,33 @@
16
16
 
17
17
  import base64
18
18
  import types
19
+ from collections import defaultdict
19
20
 
20
21
  import numpy as np
21
22
  from dill import dumps, loads
22
23
  from pydantic import BaseModel
23
24
  from ruamel.yaml import YAML
25
+ from scipy import sparse
26
+
27
+
28
def csr_representer(representer, data: sparse.csr_matrix):
    """
    Representer for CSR matrix.

    Emits the three CSR component arrays, the shape and the dtype under the ``!csr_matrix`` tag.
    The dtype is recorded because ``.tolist()`` turns the stored values into plain Python
    numbers. Those alone cannot distinguish, e.g., float32 from float64 or int32 from int64.

    Args:
        representer: The ruamel.yaml representer performing the dump.
        data: The CSR matrix to serialize.

    Returns:
        The YAML mapping node produced by ``representer.represent_mapping``.
    """
    value = {
        "data": data.data.tolist(),
        "indices": data.indices.tolist(),
        "indptr": data.indptr.tolist(),
        "shape": data.shape,
        "dtype": str(data.dtype),
    }
    return representer.represent_mapping("!csr_matrix", value)
37
+
38
+
39
def csr_constructor(constructor, node):
    """
    Constructor for CSR matrix.

    Rebuilds the matrix from its ``data``/``indices``/``indptr`` arrays and ``shape``. When the
    document also records a ``dtype``, the matrix is restored with that dtype. Documents without
    one (older dumps) fall back to SciPy's dtype inference as before.

    Args:
        constructor: The ruamel.yaml constructor performing the load.
        node: The ``!csr_matrix`` mapping node.

    Returns:
        sparse.csr_matrix: The reconstructed matrix.
    """
    mapping = constructor.construct_mapping(node, deep=True)
    dtype = mapping.get("dtype")
    return sparse.csr_matrix(
        (mapping["data"], mapping["indices"], mapping["indptr"]),
        shape=tuple(mapping["shape"]),
        dtype=None if dtype is None else np.dtype(dtype),
    )
24
46
 
25
47
 
26
48
  def ndarray_representer(representer, data):
@@ -38,6 +60,49 @@ def ndarray_constructor(constructor, node):
38
60
  return np.array(data, dtype=dtype).reshape(shape)
39
61
 
40
62
 
63
def np_scalar_representer(representer, data: np.generic):
    """
    Represent any NumPy scalar (e.g. np.int64, np.float32).

    The scalar becomes a ``!np_scalar`` mapping holding its dtype name and its value converted
    to the nearest Python object via ``.item()``.
    """
    payload = {"dtype": str(data.dtype), "value": data.item()}
    return representer.represent_mapping("!np_scalar", payload)
69
+
70
+
71
def np_scalar_constructor(constructor, node):
    """
    Reconstruct a NumPy scalar.

    The value is cast through a 0-d array of the recorded dtype rather than through
    ``dtype.type(value)``. The array constructor honours parametrised dtypes such as
    ``datetime64[ns]``: their ``.item()`` value is an int or a ``datetime`` object. The bare
    scalar type either rejects that (int without a unit) or picks a different unit.

    Args:
        constructor: The ruamel.yaml constructor performing the load.
        node: The ``!np_scalar`` mapping node with ``dtype`` and ``value`` keys.

    Returns:
        np.generic: A NumPy scalar of exactly the recorded dtype.
    """
    mapping = constructor.construct_mapping(node, deep=True)
    dtype = np.dtype(mapping["dtype"])
    # `[()]` extracts the NumPy scalar from the 0-d array.
    return np.asarray(mapping["value"], dtype=dtype)[()]
76
+
77
+
78
def defaultdict_representer(representer, data: defaultdict):
    """
    Represent a defaultdict as a ``!defaultdict`` mapping.

    The mapping has two keys:

    - ``default_factory``: the factory's ``module.qualname`` path, or None when the defaultdict
      has no factory.
    - ``items``: the contents as a plain dict.

    NOTE(review): lambdas and locally defined factories yield paths containing ``<lambda>`` /
    ``<locals>``, which cannot be imported back when loading.
    """
    factory = data.default_factory
    if factory is None:
        factory_name = None
    else:
        factory_name = f"{factory.__module__}.{factory.__qualname__}"
    payload = {"default_factory": factory_name, "items": dict(data)}
    return representer.represent_mapping("!defaultdict", payload)
89
+
90
+
91
def defaultdict_constructor(constructor, node):
    """
    Reconstruct a defaultdict, restoring its factory and contents.

    The factory is stored as a dotted ``module.qualname`` path. The boundary between the module
    and a nested qualname (e.g. ``pkg.mod.Class.method``) is not recorded, so the path is
    resolved as follows:

    - Import the longest importable module prefix, longest first. Plain ``module.name`` paths
      therefore resolve exactly as a single split would.
    - Walk the remaining components with ``getattr``.

    Security: this imports the module named in the document. Only load trusted YAML.

    Raises:
        ValueError: If the factory path has no ``.`` separator.
        ImportError | AttributeError: If no prefix of the path resolves to an object.
    """
    import importlib

    def _resolve(dotted):
        parts = dotted.split(".")
        if len(parts) < 2:
            raise ValueError(f"default_factory path {dotted!r} is not of the form 'module.qualname'")
        last_error = None
        for split in range(len(parts) - 1, 0, -1):
            try:
                obj = importlib.import_module(".".join(parts[:split]))
                for attr in parts[split:]:
                    obj = getattr(obj, attr)
                return obj
            except (ImportError, AttributeError) as exc:
                last_error = exc
        raise last_error

    mapping = constructor.construct_mapping(node, deep=True)
    fname = mapping["default_factory"]
    factory = None if fname is None else _resolve(fname)
    dd = defaultdict(factory)
    dd.update(mapping["items"])
    return dd
104
+
105
+
41
106
  def function_representer(representer, data):
42
107
  """Represent a non-lambda function by serializing it."""
43
108
  serialized_function = base64.b64encode(dumps(data, recurse=True)).decode("utf-8")
@@ -90,14 +155,75 @@ def complex_constructor(constructor, node):
90
155
  return complex(mapping["real"], mapping["imag"])
91
156
 
92
157
 
158
def tuple_representer(representer, data: tuple):
    """
    Representer for built-in Python tuple.

    The elements are emitted as a YAML sequence tagged ``!tuple``, so they load back as a tuple
    rather than a list.
    """
    elements = [element for element in data]
    return representer.represent_sequence("!tuple", elements)
162
+
163
+
164
def tuple_constructor(constructor, node):
    """Constructor for built-in Python tuple: load the ``!tuple`` sequence and freeze it into a tuple."""
    return tuple(constructor.construct_sequence(node, deep=True))
168
+
169
+
170
def type_representer(representer, data: type):
    """
    Represent any Python class/type by its import path.

    The path is ``module.qualname`` (e.g. ``datetime.datetime``), emitted as a plain scalar
    tagged ``!type``.
    """
    module_name, qualified_name = data.__module__, data.__qualname__
    return representer.represent_scalar("!type", f"{module_name}.{qualified_name}")
178
+
179
+
180
def type_constructor(constructor, node):
    """
    Reconstruct a class/type from its import path.

    The path is ``module.qualname``. A qualname may itself contain dots (nested classes such as
    ``pkg.mod.Outer.Inner``), so the path is resolved as follows:

    - Import the longest importable module prefix, longest first. Plain ``module.Class`` paths
      therefore resolve exactly as a single split would.
    - Walk the remaining components with ``getattr``.

    Security: this imports the module named in the document. Only load trusted YAML.

    Raises:
        ValueError: If the path has no ``.`` separator.
        ImportError | AttributeError: If no prefix of the path resolves to an object.
    """
    import importlib

    path = node.value  # e.g. "datetime.datetime"
    parts = path.split(".")
    if len(parts) < 2:
        raise ValueError(f"type path {path!r} is not of the form 'module.qualname'")
    last_error = None
    for split in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:split]))
            for attr in parts[split:]:
                obj = getattr(obj, attr)
            return obj
        except (ImportError, AttributeError) as exc:
            last_error = exc
    raise last_error
188
+
189
+
190
# Create YAML handler and register all custom types
# Each representer below is paired with the constructor for the tag it emits, so the
# corresponding objects round-trip through this handler.
# NOTE(review): typ="unsafe" builds arbitrary Python objects, the !function/!lambda tags carry
# dill-serialized payloads (see function_representer), and !type/!defaultdict import modules
# named in the document. Only ever load YAML from trusted sources with this handler.
yaml = YAML(typ="unsafe")

# SciPy CSR
yaml.representer.add_representer(sparse.csr_matrix, csr_representer)
yaml.constructor.add_constructor("!csr_matrix", csr_constructor)

# NumPy scalars
# Registered as a multi-representer so every np.generic subtype (np.int64, np.float32, ...) uses it.
yaml.representer.add_multi_representer(np.generic, np_scalar_representer)
yaml.constructor.add_constructor("!np_scalar", np_scalar_constructor)

# defaultdict
yaml.representer.add_representer(defaultdict, defaultdict_representer)
yaml.constructor.add_constructor("!defaultdict", defaultdict_constructor)

# NumPy arrays
yaml.representer.add_representer(np.ndarray, ndarray_representer)
yaml.constructor.add_constructor("!ndarray", ndarray_constructor)

# Python functions and lambdas
# NOTE(review): types.LambdaType is the same object as types.FunctionType, so the second
# add_representer call below replaces the first. Every function is then represented by
# lambda_representer. Confirm lambda_representer also handles named functions (or dispatches
# to function_representer).
yaml.representer.add_representer(types.FunctionType, function_representer)
yaml.constructor.add_constructor("!function", function_constructor)
yaml.representer.add_representer(types.LambdaType, lambda_representer)
yaml.constructor.add_constructor("!lambda", lambda_constructor)

# Pydantic models
# NOTE(review): add_representer keys on the exact type. Instances of BaseModel *subclasses* may
# not reach pydantic_model_representer unless they are registered elsewhere; confirm, or
# consider add_multi_representer.
yaml.representer.add_representer(BaseModel, pydantic_model_representer)
yaml.constructor.add_constructor("!PydanticModel", pydantic_model_constructor)

# Built-in complex numbers
yaml.representer.add_representer(complex, complex_representer)
yaml.constructor.add_constructor("!complex", complex_constructor)

# Built-in tuples
yaml.representer.add_representer(tuple, tuple_representer)
yaml.constructor.add_constructor("!tuple", tuple_constructor)

# Built-in type
# Multi-representer on `type`, so classes whose metaclass subclasses `type` are covered too.
yaml.representer.add_multi_representer(type, type_representer)
yaml.constructor.add_constructor("!type", type_constructor)