CUQIpy 1.1.1.post0.dev36__py3-none-any.whl → 1.4.1.post0.dev124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- cuqi/__init__.py +2 -0
- cuqi/_version.py +3 -3
- cuqi/algebra/__init__.py +2 -0
- cuqi/algebra/_abstract_syntax_tree.py +358 -0
- cuqi/algebra/_ordered_set.py +82 -0
- cuqi/algebra/_random_variable.py +457 -0
- cuqi/array/_array.py +4 -13
- cuqi/config.py +7 -0
- cuqi/density/_density.py +9 -1
- cuqi/distribution/__init__.py +3 -2
- cuqi/distribution/_beta.py +7 -11
- cuqi/distribution/_cauchy.py +2 -2
- cuqi/distribution/_custom.py +0 -6
- cuqi/distribution/_distribution.py +31 -45
- cuqi/distribution/_gamma.py +7 -3
- cuqi/distribution/_gaussian.py +2 -12
- cuqi/distribution/_inverse_gamma.py +4 -10
- cuqi/distribution/_joint_distribution.py +112 -15
- cuqi/distribution/_lognormal.py +0 -7
- cuqi/distribution/{_modifiedhalfnormal.py → _modified_half_normal.py} +23 -23
- cuqi/distribution/_normal.py +34 -7
- cuqi/distribution/_posterior.py +9 -0
- cuqi/distribution/_truncated_normal.py +129 -0
- cuqi/distribution/_uniform.py +47 -1
- cuqi/experimental/__init__.py +2 -2
- cuqi/experimental/_recommender.py +216 -0
- cuqi/geometry/__init__.py +2 -0
- cuqi/geometry/_geometry.py +15 -1
- cuqi/geometry/_product_geometry.py +181 -0
- cuqi/implicitprior/__init__.py +5 -3
- cuqi/implicitprior/_regularized_gaussian.py +483 -0
- cuqi/implicitprior/{_regularizedGMRF.py → _regularized_gmrf.py} +4 -2
- cuqi/implicitprior/{_regularizedUnboundedUniform.py → _regularized_unbounded_uniform.py} +3 -2
- cuqi/implicitprior/_restorator.py +269 -0
- cuqi/legacy/__init__.py +2 -0
- cuqi/{experimental/mcmc → legacy/sampler}/__init__.py +7 -11
- cuqi/legacy/sampler/_conjugate.py +55 -0
- cuqi/legacy/sampler/_conjugate_approx.py +52 -0
- cuqi/legacy/sampler/_cwmh.py +196 -0
- cuqi/legacy/sampler/_gibbs.py +231 -0
- cuqi/legacy/sampler/_hmc.py +335 -0
- cuqi/{experimental/mcmc → legacy/sampler}/_langevin_algorithm.py +82 -111
- cuqi/legacy/sampler/_laplace_approximation.py +184 -0
- cuqi/legacy/sampler/_mh.py +190 -0
- cuqi/legacy/sampler/_pcn.py +244 -0
- cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +132 -90
- cuqi/legacy/sampler/_sampler.py +182 -0
- cuqi/likelihood/_likelihood.py +9 -1
- cuqi/model/__init__.py +1 -1
- cuqi/model/_model.py +1361 -359
- cuqi/pde/__init__.py +4 -0
- cuqi/pde/_observation_map.py +36 -0
- cuqi/pde/_pde.py +134 -33
- cuqi/problem/_problem.py +93 -87
- cuqi/sampler/__init__.py +120 -8
- cuqi/sampler/_conjugate.py +376 -35
- cuqi/sampler/_conjugate_approx.py +40 -16
- cuqi/sampler/_cwmh.py +132 -138
- cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
- cuqi/sampler/_gibbs.py +288 -130
- cuqi/sampler/_hmc.py +328 -201
- cuqi/sampler/_langevin_algorithm.py +284 -100
- cuqi/sampler/_laplace_approximation.py +87 -117
- cuqi/sampler/_mh.py +47 -157
- cuqi/sampler/_pcn.py +65 -213
- cuqi/sampler/_rto.py +211 -142
- cuqi/sampler/_sampler.py +553 -136
- cuqi/samples/__init__.py +1 -1
- cuqi/samples/_samples.py +24 -18
- cuqi/solver/__init__.py +6 -4
- cuqi/solver/_solver.py +230 -26
- cuqi/testproblem/_testproblem.py +2 -3
- cuqi/utilities/__init__.py +6 -1
- cuqi/utilities/_get_python_variable_name.py +2 -2
- cuqi/utilities/_utilities.py +182 -2
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/METADATA +10 -6
- cuqipy-1.4.1.post0.dev124.dist-info/RECORD +101 -0
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/WHEEL +1 -1
- CUQIpy-1.1.1.post0.dev36.dist-info/RECORD +0 -92
- cuqi/experimental/mcmc/_conjugate.py +0 -197
- cuqi/experimental/mcmc/_conjugate_approx.py +0 -81
- cuqi/experimental/mcmc/_cwmh.py +0 -191
- cuqi/experimental/mcmc/_gibbs.py +0 -268
- cuqi/experimental/mcmc/_hmc.py +0 -470
- cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
- cuqi/experimental/mcmc/_mh.py +0 -78
- cuqi/experimental/mcmc/_pcn.py +0 -89
- cuqi/experimental/mcmc/_sampler.py +0 -561
- cuqi/experimental/mcmc/_utilities.py +0 -17
- cuqi/implicitprior/_regularizedGaussian.py +0 -323
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info/licenses}/LICENSE +0 -0
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/top_level.txt +0 -0
cuqi/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from . import data
|
|
2
2
|
from . import density
|
|
3
3
|
from . import diagnostics
|
|
4
|
+
from . import algebra
|
|
4
5
|
from . import distribution
|
|
5
6
|
from . import experimental
|
|
6
7
|
from . import geometry
|
|
@@ -11,6 +12,7 @@ from . import operator
|
|
|
11
12
|
from . import pde
|
|
12
13
|
from . import problem
|
|
13
14
|
from . import sampler
|
|
15
|
+
from . import legacy
|
|
14
16
|
from . import array
|
|
15
17
|
from . import samples
|
|
16
18
|
from . import solver
|
cuqi/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "
|
|
11
|
+
"date": "2025-12-09T08:46:13+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.
|
|
14
|
+
"full-revisionid": "9d4bcd3dbc233e3ae3f26c8a0897cfecad93a5f1",
|
|
15
|
+
"version": "1.4.1.post0.dev124"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
cuqi/algebra/_abstract_syntax_tree.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.
|
|
3
|
+
|
|
4
|
+
The AST is used to record the operations applied to variables allowing a delayed evaluation
|
|
5
|
+
of said operations when needed by traversing the tree with the __call__ method.
|
|
6
|
+
|
|
7
|
+
For example, the following code
|
|
8
|
+
|
|
9
|
+
x = VariableNode('x')
|
|
10
|
+
y = VariableNode('y')
|
|
11
|
+
z = 2*x + 3*y
|
|
12
|
+
|
|
13
|
+
will create the following AST:
|
|
14
|
+
|
|
15
|
+
z = AddNode(
|
|
16
|
+
MultiplyNode(
|
|
17
|
+
ValueNode(2),
|
|
18
|
+
VariableNode('x')
|
|
19
|
+
),
|
|
20
|
+
MultiplyNode(
|
|
21
|
+
ValueNode(3),
|
|
22
|
+
VariableNode('y')
|
|
23
|
+
)
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
which can be evaluated by calling the __call__ method:
|
|
27
|
+
|
|
28
|
+
z(x=1, y=2) # returns 8
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from abc import ABC, abstractmethod
|
|
33
|
+
|
|
34
|
+
def convert_to_node(x):
    """Convert any non-Node object to a ValueNode object.

    Objects that already are a Node are returned unchanged, so the result can
    always be used directly as an operand when building the abstract syntax
    tree.
    """
    return x if isinstance(x, Node) else ValueNode(x)
|
|
36
|
+
|
|
37
|
+
# ====== Base classes for the nodes ======
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Node(ABC):
    """Abstract base for every node of the abstract syntax tree (AST).

    Applying an operator (``+``, ``-``, ``*``, ``/``, ``**``, ``@``, unary
    ``-``, ``abs`` or indexing) to a node computes nothing. It returns a new
    node recording the operation, so the expression can be evaluated later by
    calling the resulting tree.

    Every concrete subclass must implement ``__call__`` (evaluation),
    ``condition`` (partial substitution of variables) and ``__repr__``
    (printing).

    """

    @abstractmethod
    def __call__(self, **kwargs):
        """Evaluate the sub-tree rooted at this node, applying the recorded operations to the given parameter values."""

    @abstractmethod
    def condition(self, **kwargs):
        """Condition the tree: every VariableNode whose name is a key of kwargs is replaced by a ValueNode holding that value."""

    @abstractmethod
    def __repr__(self):
        """Return the printable form of the node; used to display the AST."""

    def get_variables(self, variables=None):
        """Collect the names of all variables in the sub-tree rooted at this node.

        Parameters
        ----------
        variables : set, optional
            Set that the names are added to. A new set is created when None.

        Returns
        -------
        set
            The given (or newly created) set, updated with the names found.
        """
        found = set() if variables is None else variables
        if isinstance(self, VariableNode):
            found.add(self.name)
        # Unary nodes store their operand as ``child``; binary nodes store
        # ``left`` and ``right``. Recurse into whichever this node has.
        for attribute in ("child", "left", "right"):
            if hasattr(self, attribute):
                getattr(self, attribute).get_variables(found)
        return found

    # ------------------------------------------------------------------
    # Operator overloads. Each wraps a non-Node operand in a ValueNode
    # (via convert_to_node) and returns a new node recording the operation.
    # The reflected variants (``__r*__``) handle ``constant <op> node`` and
    # place the constant as the left operand.
    # ------------------------------------------------------------------

    def __add__(self, other):
        return AddNode(self, convert_to_node(other))

    def __radd__(self, other):
        return AddNode(convert_to_node(other), self)

    def __sub__(self, other):
        return SubtractNode(self, convert_to_node(other))

    def __rsub__(self, other):
        return SubtractNode(convert_to_node(other), self)

    def __mul__(self, other):
        return MultiplyNode(self, convert_to_node(other))

    def __rmul__(self, other):
        return MultiplyNode(convert_to_node(other), self)

    def __truediv__(self, other):
        return DivideNode(self, convert_to_node(other))

    def __rtruediv__(self, other):
        return DivideNode(convert_to_node(other), self)

    def __pow__(self, other):
        return PowerNode(self, convert_to_node(other))

    def __rpow__(self, other):
        return PowerNode(convert_to_node(other), self)

    def __neg__(self):
        return NegateNode(self)

    def __abs__(self):
        return AbsNode(self)

    def __getitem__(self, i):
        # The index itself is stored as a node (right operand of GetItemNode).
        return GetItemNode(self, convert_to_node(i))

    def __matmul__(self, other):
        return MatMulNode(self, convert_to_node(other))

    def __rmatmul__(self, other):
        return MatMulNode(convert_to_node(other), self)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class UnaryNode(Node, ABC):
    """Abstract base for nodes that apply an operation to a single operand.

    Parameters
    ----------
    child : Node
        Operand node that the unary operation acts on.

    """

    def __init__(self, child: Node):
        self.child = child

    def condition(self, **kwargs):
        # Rebuild a node of the same concrete type around the conditioned operand.
        conditioned_child = self.child.condition(**kwargs)
        return type(self)(conditioned_child)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class BinaryNode(Node, ABC):
    """Abstract base for nodes that combine a left and a right operand.

    Subclasses provide ``op_symbol``, the operator text that ``__repr__``
    places between the two operands.

    Parameters
    ----------
    left : Node
        Node for the left operand.

    right : Node
        Node for the right operand.

    """

    @property
    @abstractmethod
    def op_symbol(self):
        """Operator text shown between the two operands when printing."""

    def __init__(self, left: Node, right: Node):
        self.left, self.right = left, right

    def condition(self, **kwargs):
        # Condition both operands (left first) and rebuild a node of the same type.
        new_left = self.left.condition(**kwargs)
        new_right = self.right.condition(**kwargs)
        return type(self)(new_left, new_right)

    def __repr__(self):
        return "{} {} {}".format(self.left, self.op_symbol, self.right)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class BinaryNodeWithParenthesis(BinaryNode, ABC):
    """Abstract base for binary nodes that print binary-node operands inside parentheses."""

    def __repr__(self):
        def wrap(operand):
            # Compound (binary) operands get parentheses so precedence stays visible.
            return f"({operand})" if isinstance(operand, BinaryNode) else str(operand)

        return f"{wrap(self.left)} {self.op_symbol} {wrap(self.right)}"
|
|
182
|
+
|
|
183
|
+
class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
    """Abstract base for binary nodes printed without spaces around the operator, with binary-node operands in parentheses."""

    def __repr__(self):
        pieces = []
        for operand in (self.left, self.right):
            text = str(operand)
            # Compound (binary) operands get parentheses so precedence stays visible.
            pieces.append(f"({text})" if isinstance(operand, BinaryNode) else text)
        return f"{pieces[0]}{self.op_symbol}{pieces[1]}"
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# ====== Specific implementations of the "leaf" nodes ======
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class VariableNode(Node):
    """Leaf node standing for a named variable, e.g. "x" or "y".

    Parameters
    ----------
    name : str
        Variable name. It is what gets printed, and it is the key looked up
        in the keyword arguments when the tree is evaluated or conditioned.

    """

    def __init__(self, name):
        self.name = name

    def __call__(self, **kwargs):
        """Return the value passed for this variable; raise KeyError if no value was given."""
        if self.name not in kwargs:
            raise KeyError(
                f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression."
            )
        return kwargs[self.name]

    def condition(self, **kwargs):
        # Substitute a constant when a value is supplied; otherwise stay a variable.
        if self.name not in kwargs:
            return self
        return ValueNode(kwargs[self.name])

    def __repr__(self):
        return self.name
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class ValueNode(Node):
    """Leaf node wrapping a constant. The value may be any Python object that is not a Node.

    Parameters
    ----------
    value : object
        The constant held by the node.

    """

    def __init__(self, value):
        self.value = value

    def __call__(self, **kwargs):
        """Return the stored constant; the given parameters are ignored."""
        return self.value

    def condition(self, **kwargs):
        # A constant has no variables to substitute, so it is its own conditioned form.
        return self

    def __repr__(self):
        return f"{self.value!s}"
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# ====== Specific implementations of the "internal" nodes ======
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class AddNode(BinaryNode):
    """Node recording the addition ``left + right``."""

    @property
    def op_symbol(self):
        return "+"

    def __call__(self, **kwargs):
        lhs = self.left(**kwargs)
        rhs = self.right(**kwargs)
        return lhs + rhs
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class SubtractNode(BinaryNode):
    """Node recording the subtraction ``left - right``.

    Subtraction is not associative. When the right operand is itself an
    addition or subtraction, it is therefore printed in parentheses:
    ``x - (y + z)`` instead of the misleading ``x - y + z``.
    """

    @property
    def op_symbol(self):
        return "-"

    def __call__(self, **kwargs):
        return self.left(**kwargs) - self.right(**kwargs)

    def __repr__(self):
        right = str(self.right)
        # Without parentheses, ``x - (y + z)`` / ``x - (y - z)`` would read as
        # ``(x - y) + z`` / ``(x - y) - z``. A left operand needs no parentheses
        # because printing is left-associative.
        if isinstance(self.right, (AddNode, SubtractNode)):
            right = f"({right})"
        return f"{self.left} {self.op_symbol} {right}"
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class MultiplyNode(BinaryNodeWithParenthesis):
    """Node recording the product ``left * right``."""

    @property
    def op_symbol(self):
        return "*"

    def __call__(self, **kwargs):
        factor_left = self.left(**kwargs)
        factor_right = self.right(**kwargs)
        return factor_left * factor_right
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class DivideNode(BinaryNodeWithParenthesis):
    """Node recording the division ``left / right``."""

    @property
    def op_symbol(self):
        return "/"

    def __call__(self, **kwargs):
        numerator = self.left(**kwargs)
        denominator = self.right(**kwargs)
        return numerator / denominator
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class PowerNode(BinaryNodeWithParenthesisNoSpace):
    """Node recording the exponentiation ``left ** right``, printed as ``left^right``.

    Binary-node operands are printed in parentheses. A base whose printed form
    starts with a minus sign (a negation such as ``-x``, or a negative constant
    such as ``-2``) is also parenthesized. As a result, ``(-x)**2`` prints as
    ``(-x)^2`` rather than ``-x^2``, which would read as ``-(x^2)``.
    """

    @property
    def op_symbol(self):
        return "^"

    def __call__(self, **kwargs):
        return self.left(**kwargs) ** self.right(**kwargs)

    def __repr__(self):
        def wrap(operand, parenthesize_leading_minus):
            text = str(operand)
            if isinstance(operand, BinaryNode) or (
                parenthesize_leading_minus and text.startswith("-")
            ):
                return f"({text})"
            return text

        # Only the base is ambiguous with a leading minus; an exponent such as
        # ``x^-1`` is unambiguous and is left as is.
        return f"{wrap(self.left, True)}{self.op_symbol}{wrap(self.right, False)}"
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class GetItemNode(BinaryNode):
    """Node recording the indexing ``left[right]``: the left operand is the indexed object, the right operand the index."""

    def __call__(self, **kwargs):
        container = self.left(**kwargs)
        index = self.right(**kwargs)
        return container[index]

    def __repr__(self):
        container = str(self.left)
        if isinstance(self.left, BinaryNode):
            container = f"({container})"
        return f"{container}[{self.right}]"

    @property
    def op_symbol(self):
        # Indexing is printed entirely by __repr__, so there is no infix symbol.
        return None
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class NegateNode(UnaryNode):
    """Node recording the arithmetic negation ``-child``."""

    def __call__(self, **kwargs):
        value = self.child(**kwargs)
        return -value

    def __repr__(self):
        operand = str(self.child)
        # Compound operands are parenthesized so the minus visibly applies to all of them.
        if isinstance(self.child, (BinaryNode, UnaryNode)):
            operand = f"({operand})"
        return f"-{operand}"
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class AbsNode(UnaryNode):
    """Node recording the absolute value ``abs(child)``."""

    def __call__(self, **kwargs):
        value = self.child(**kwargs)
        return abs(value)

    def __repr__(self):
        return "abs({})".format(self.child)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class MatMulNode(BinaryNodeWithParenthesis):
    """Node recording the matrix product ``left @ right``."""

    @property
    def op_symbol(self):
        return "@"

    def __call__(self, **kwargs):
        lhs_value = self.left(**kwargs)
        rhs_value = self.right(**kwargs)
        return lhs_value @ rhs_value
|
|
cuqi/algebra/_ordered_set.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
class _OrderedSet:
|
|
2
|
+
"""A set (i.e. unique elements) that keeps its elements in the order they were added.
|
|
3
|
+
|
|
4
|
+
This is a minimal implementation of an ordered set, using a dictionary for storage.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
def __init__(self, iterable=None):
|
|
8
|
+
"""Initialize the OrderedSet.
|
|
9
|
+
|
|
10
|
+
If an iterable is provided, add all its elements to the set.
|
|
11
|
+
"""
|
|
12
|
+
self.dict = dict.fromkeys(iterable if iterable else [])
|
|
13
|
+
|
|
14
|
+
def add(self, item):
|
|
15
|
+
"""Add an item to the set.
|
|
16
|
+
|
|
17
|
+
If the item is already in the set, it does nothing.
|
|
18
|
+
Otherwise, the item is stored as a key in the dictionary, with None as its value.
|
|
19
|
+
"""
|
|
20
|
+
self.dict[item] = None
|
|
21
|
+
|
|
22
|
+
def remove(self, item):
|
|
23
|
+
"""Remove an item from the set.
|
|
24
|
+
|
|
25
|
+
If the item is not in the set, it raises a KeyError.
|
|
26
|
+
"""
|
|
27
|
+
del self.dict[item]
|
|
28
|
+
|
|
29
|
+
def __contains__(self, item):
|
|
30
|
+
"""Check if an item is in the set.
|
|
31
|
+
|
|
32
|
+
This is equivalent to checking if the item is a key in the dictionary.
|
|
33
|
+
"""
|
|
34
|
+
return item in self.dict
|
|
35
|
+
|
|
36
|
+
def __iter__(self):
|
|
37
|
+
"""Return an iterator over the set.
|
|
38
|
+
|
|
39
|
+
This iterates over the keys in the dictionary.
|
|
40
|
+
"""
|
|
41
|
+
return iter(self.dict)
|
|
42
|
+
|
|
43
|
+
def __len__(self):
|
|
44
|
+
"""Return the number of items in the set."""
|
|
45
|
+
return len(self.dict)
|
|
46
|
+
|
|
47
|
+
def extend(self, other):
|
|
48
|
+
"""Extend the set with the items in another set.
|
|
49
|
+
|
|
50
|
+
Raises a TypeError if the other object is not an _OrderedSet.
|
|
51
|
+
"""
|
|
52
|
+
if not isinstance(other, _OrderedSet):
|
|
53
|
+
raise TypeError("unsupported operand type(s) for extend: '_OrderedSet' and '{}'".format(type(other).__name__))
|
|
54
|
+
for item in other:
|
|
55
|
+
self.add(item)
|
|
56
|
+
|
|
57
|
+
def replace(self, old_item, new_item):
|
|
58
|
+
"""Replace old_item with new_item at the same position, preserving order."""
|
|
59
|
+
if old_item not in self.dict:
|
|
60
|
+
raise KeyError(f"{old_item} not in set")
|
|
61
|
+
|
|
62
|
+
items = list(self.dict.keys()) # Preserve order
|
|
63
|
+
index = items.index(old_item) # Find position
|
|
64
|
+
items[index] = new_item # Replace at the same position
|
|
65
|
+
|
|
66
|
+
# Reconstruct the ordered set with the new item in place
|
|
67
|
+
self.dict = dict.fromkeys(items)
|
|
68
|
+
|
|
69
|
+
def __or__(self, other):
|
|
70
|
+
"""Return a new set that is the union of this set and another set.
|
|
71
|
+
|
|
72
|
+
Raises a TypeError if the other object is not an _OrderedSet.
|
|
73
|
+
"""
|
|
74
|
+
if not isinstance(other, _OrderedSet):
|
|
75
|
+
raise TypeError("unsupported operand type(s) for |: '_OrderedSet' and '{}'".format(type(other).__name__))
|
|
76
|
+
new_set = _OrderedSet(self.dict.keys())
|
|
77
|
+
new_set.extend(other)
|
|
78
|
+
return new_set
|
|
79
|
+
|
|
80
|
+
def __repr__(self):
|
|
81
|
+
"""Return a string representation of the set."""
|
|
82
|
+
return "_OrderedSet({})".format(list(self.dict.keys()))
|