tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +5 -1
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +312 -5
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +92 -3
- tensorcircuit/backends/jax_ops.py +108 -0
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +123 -82
- tensorcircuit/circuit.py +67 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +1 -0
- tensorcircuit/densitymatrix.py +16 -11
- tensorcircuit/experimental.py +7 -152
- tensorcircuit/fgs.py +5 -6
- tensorcircuit/gates.py +66 -22
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +109 -61
- tensorcircuit/quantum.py +697 -133
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +45 -31
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +4 -2
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +29 -8
- tensorcircuit/templates/lattice.py +676 -335
- tensorcircuit/timeevol.py +896 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1035
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -409
- tests/test_circuit.py +0 -1713
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -318
- tests/test_gates.py +0 -156
- tests/test_hamiltonians.py +0 -159
- tests/test_interfaces.py +0 -557
- tests/test_keras.py +0 -160
- tests/test_lattice.py +0 -1666
- tests/test_miscs.py +0 -334
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -549
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -379
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_stabilizer.py +0 -226
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,8 +3,11 @@ Customized ops for ML framework
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
# pylint: disable=invalid-name
|
|
6
|
+
# pylint: disable=unused-variable
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
from typing import Any, Tuple, Sequence
|
|
10
|
+
from functools import partial
|
|
8
11
|
|
|
9
12
|
import jax
|
|
10
13
|
import jax.numpy as jnp
|
|
@@ -174,3 +177,108 @@ def jaxeigh_bwd(r: Array, tangents: Array) -> Array:
|
|
|
174
177
|
|
|
175
178
|
adaware_eigh.defvjp(jaxeigh_fwd, jaxeigh_bwd)
|
|
176
179
|
adaware_eigh_jit = jax.jit(adaware_eigh)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
@partial(jax.jit, static_argnums=[0, 2])
def bessel_jv_jax_rescaled(k: int, x: jnp.ndarray, M: int) -> jnp.ndarray:
    """
    Computes Bessel functions of the first kind ``J_0(x), ..., J_{k-1}(x)``
    using Miller's backward-recurrence algorithm with dynamic rescaling,
    implemented in JAX.

    :param k: Number of orders to return (orders ``0 .. k-1``). Static under jit.
    :param x: Argument of the Bessel functions, a scalar or an array of any shape.
    :param M: Starting order of the backward recurrence. Must be larger than
        ``k``; for accuracy it should also be comfortably larger than ``|x|``.
        Static under jit.
    :return: Array of shape ``x.shape + (k,)`` (``(k,)`` for a scalar ``x``);
        the last axis indexes the order.
    :raises ValueError: If ``M <= k``.
    """
    if M <= k:
        raise ValueError(
            f"Recurrence length M ({M}) must be greater than the required order k ({k})."
        )

    x = jnp.asarray(x)
    if x.ndim == 0:
        return _bessel_jv_scalar_rescaled(k, M, x)

    # The scalar kernel branches with ``lax.cond`` on scalar predicates, so it
    # cannot consume an array directly: map it over every element of ``x``.
    flat_results = jax.vmap(partial(_bessel_jv_scalar_rescaled, k, M))(
        jnp.reshape(x, (-1,))
    )
    return jnp.reshape(flat_results, x.shape + (k,))


def _bessel_jv_scalar_rescaled(k: int, M: int, x_val: jnp.ndarray) -> jnp.ndarray:
    """
    Scalar kernel of :func:`bessel_jv_jax_rescaled`.

    Runs the backward recurrence ``f_{m-1} = (2m / x) f_m - f_{m+1}`` for
    ``m = M, ..., 1`` from a tiny seed ``f_M``. All values are rescaled whenever
    they grow past a threshold. The result is then normalized with Neumann's
    sum rule ``J_0(x) + 2 * sum_{j>=1} J_{2j}(x) = 1``.

    :param k: Number of orders to return.
    :param M: Starting order of the recurrence.
    :param x_val: Scalar argument (array inputs are vmapped by the caller).
    :return: 1-D array of length ``k`` holding ``J_0(x_val) .. J_{k-1}(x_val)``.
    """
    # Work in the default float dtype, i.e. the dtype of ``jnp.zeros``. Every
    # loop-carried value then has one consistent dtype regardless of x's dtype.
    f_vals_init = jnp.zeros(M + 1)
    work_dtype = f_vals_init.dtype

    x_val = jnp.asarray(x_val)
    x_is_small = jnp.abs(x_val) < 1e-12
    # "Double where": when x ~ 0, run the recurrence on a harmless argument so
    # that neither the forward values nor the gradients pick up inf/nan from
    # 2m/x. The exact x = 0 result is substituted at the end.
    x_safe = jnp.where(x_is_small, 1.0, x_val).astype(work_dtype)

    # 1e250 leaves ample headroom in float64. Narrower floats such as float32
    # cannot represent it, so use sqrt(max) there to rescale before overflowing.
    finfo_max = float(jnp.finfo(work_dtype).max)
    rescale_threshold = 1e250 if finfo_max > 1e250 else finfo_max**0.5

    def recurrence_body(i, state):  # type: ignore
        f_m, f_m_p1, f_vals = state
        # fori_loop's upper bound is exclusive: i = 0 .. M-1 maps to m = M .. 1.
        m_val = M - i
        f_m_m1 = (2.0 * m_val / x_safe) * f_m - f_m_p1

        # jax.lax.cond requires both branches to return identical types and
        # shapes, hence ``ones_like`` instead of a Python 1.0.
        def rescale_branch(vals):  # type: ignore
            # Divide everything computed so far by f_{m-1}, which becomes 1.
            f_m_val, f_m_p1_val, f_vals_arr = vals
            return (
                f_m_val / f_m_m1,
                f_m_p1_val / f_m_m1,
                f_vals_arr / f_m_m1,
                jnp.ones_like(f_m_m1),
            )

        def no_rescale_branch(vals):  # type: ignore
            f_m_val, f_m_p1_val, f_vals_arr = vals
            return (f_m_val, f_m_p1_val, f_vals_arr, f_m_m1)

        f_m_new, _, f_vals_new, f_m_m1_effective = jax.lax.cond(
            jnp.abs(f_m_m1) > rescale_threshold,
            rescale_branch,
            no_rescale_branch,
            (f_m, f_m_p1, f_vals),
        )
        f_vals_new = f_vals_new.at[m_val - 1].set(f_m_m1_effective)
        # Slide the window down one order: (f_{m-1}, f_m) are the next (f_m, f_{m+1}).
        return (f_m_m1_effective, f_m_new, f_vals_new)

    f_m_init = jnp.asarray(1e-30, dtype=work_dtype)  # arbitrary tiny seed f_M
    f_m_p1_init = jnp.zeros((), dtype=work_dtype)  # f_{M+1} = 0
    f_vals_init = f_vals_init.at[M].set(f_m_init)

    _, _, f_vals = jax.lax.fori_loop(
        0, M, recurrence_body, (f_m_init, f_m_p1_init, f_vals_init)
    )

    # Normalization using Neumann's sum rule.
    norm_const = f_vals[0] + 2.0 * jnp.sum(f_vals[2::2])
    # Guard against division by a near-zero normalization constant.
    norm_const_safe = jnp.where(jnp.abs(norm_const) < 1e-12, 1e-12, norm_const)
    regular = f_vals[:k] / norm_const_safe

    # J_0(0) = 1 and J_n(0) = 0 for n > 0; this is an empty array when k == 0.
    at_zero = (jnp.arange(k) == 0).astype(regular.dtype)
    return jnp.where(x_is_small, at_zero, regular)  # type: ignore
|
|
@@ -17,7 +17,7 @@ except ImportError: # np2.0 compatibility
|
|
|
17
17
|
|
|
18
18
|
import tensornetwork
|
|
19
19
|
from scipy.linalg import expm, solve, schur
|
|
20
|
-
from scipy.special import softmax, expit
|
|
20
|
+
from scipy.special import softmax, expit, jv
|
|
21
21
|
from scipy.sparse import coo_matrix, issparse
|
|
22
22
|
from tensornetwork.backends.numpy import numpy_backend
|
|
23
23
|
from .abstract_backend import ExtendedBackend
|
|
@@ -35,10 +35,14 @@ def _sum_numpy(
|
|
|
35
35
|
# see https://github.com/google/TensorNetwork/issues/952
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
def _convert_to_tensor_numpy(
|
|
38
|
+
def _convert_to_tensor_numpy(
|
|
39
|
+
self: Any, a: Tensor, dtype: Optional[str] = None
|
|
40
|
+
) -> Tensor:
|
|
39
41
|
if not isinstance(a, np.ndarray) and not np.isscalar(a):
|
|
40
42
|
a = np.array(a)
|
|
41
43
|
a = np.asarray(a)
|
|
44
|
+
if dtype is not None:
|
|
45
|
+
a = a.astype(getattr(np, dtype))
|
|
42
46
|
return a
|
|
43
47
|
|
|
44
48
|
|
|
@@ -132,6 +136,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
132
136
|
def kron(self, a: Tensor, b: Tensor) -> Tensor:
    """Kronecker product of ``a`` and ``b`` via ``numpy.kron``."""
    return np.kron(a, b)


def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
    """Coordinate grids from 1-D inputs; all arguments go straight to ``numpy.meshgrid``."""
    return np.meshgrid(*args, **kwargs)


def dtype(self, a: Tensor) -> str:
    """Name of ``a``'s dtype as a string, e.g. ``"complex64"``."""
    return str(a.dtype)
|
|
137
144
|
|
|
@@ -151,6 +158,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
151
158
|
dtype = getattr(np, dtype)
|
|
152
159
|
return np.array(1j, dtype=dtype)
|
|
153
160
|
|
|
161
|
+
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at position ``axis``."""
    return np.expand_dims(a, axis)


def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
    """Join the arrays in ``a`` along a newly created axis ``axis``."""
    return np.stack(a, axis=axis)
|
|
156
166
|
|
|
@@ -173,6 +183,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
173
183
|
) -> Tensor:
|
|
174
184
|
return np.std(a, axis=axis, keepdims=keepdims)
|
|
175
185
|
|
|
186
|
+
def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """Logical-AND reduction of ``a`` over ``axis``, or over every element when ``axis`` is None."""
    return np.all(a, axis=axis)


def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
    """
    Sorted unique values of ``a`` together with their occurrence counts.

    Extra keyword arguments are accepted for interface compatibility and ignored.
    """
    values, counts = np.unique(a, return_counts=True)
    return values, counts  # type: ignore
|
|
178
191
|
|
|
@@ -188,6 +201,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
188
201
|
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    """Index of the minimum of ``a`` along ``axis``."""
    return np.argmin(a, axis=axis)


def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sorted copy of ``a`` along ``axis`` (last axis by default)."""
    return np.sort(a, axis=axis)


def sigmoid(self, a: Tensor) -> Tensor:
    """Element-wise logistic sigmoid ``1 / (1 + exp(-a))`` via ``scipy.special.expit``."""
    return expit(a)
|
|
193
209
|
|
|
@@ -200,6 +216,7 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
200
216
|
return softmax(a, axis=axis)
|
|
201
217
|
|
|
202
218
|
def onehot(self, a: Tensor, num: int) -> Tensor:
    """
    One-hot encode class labels.

    :param a: Array-like of class indices in ``[0, num)``. The labels are cast
        to integers first, so float labels index rows (as the pytorch backend's
        ``a.long()`` does). Bool labels are treated as indices 0/1 rather than
        as a boolean mask.
    :param num: Number of classes.
    :return: Float array of shape ``a.shape + (num,)``.
    """
    a = np.asarray(a).astype(np.int64)
    res = np.eye(num)[a.reshape([-1])]
    return res.reshape(list(a.shape) + [num])
|
|
205
222
|
# https://stackoverflow.com/questions/38592324/one-hot-encoding-using-numpy
|
|
@@ -233,6 +250,15 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
233
250
|
def mod(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise remainder whose sign follows the divisor (``numpy.mod``)."""
    return np.mod(x, y)


def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise ``floor(x / y)`` (``numpy.floor_divide``)."""
    return np.floor_divide(x, y)


def floor(self, a: Tensor) -> Tensor:
    """Element-wise floor of ``a``."""
    return np.floor(a)


def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Clamp ``a`` element-wise into ``[a_min, a_max]``."""
    return np.clip(a, a_min, a_max)


def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise bitwise right shift of ``x`` by ``y`` bits."""
    return np.right_shift(x, y)
|
|
238
264
|
|
|
@@ -244,9 +270,15 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
244
270
|
# https://stackoverflow.com/questions/44672029/difference-between-numpy-linalg-solve-and-numpy-linalg-lu-solve/44710451
|
|
245
271
|
return solve(A, b, assume_a=assume_a)
|
|
246
272
|
|
|
273
|
+
def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
    """
    Bessel functions of the first kind ``J_0(z), ..., J_{v-1}(z)``.

    Computed directly with ``scipy.special.jv``. ``M`` is accepted for
    interface compatibility and is unused here.
    """
    orders = np.arange(v)
    return jv(orders, z)


def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
    """Insertion indices of ``v`` into the sorted array ``a``."""
    return np.searchsorted(a, v, side=side)  # type: ignore


def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that would sort ``a`` along ``axis``."""
    return np.argsort(a, axis=axis)
|
|
281
|
+
|
|
250
282
|
def set_random_state(
|
|
251
283
|
self, seed: Optional[int] = None, get_only: bool = False
|
|
252
284
|
) -> Any:
|
|
@@ -329,12 +361,26 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
329
361
|
) -> Tensor:
|
|
330
362
|
return sp_a @ b
|
|
331
363
|
|
|
364
|
+
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a scipy sparse matrix to CSR format via ``.tocsr()``.

    ``strict`` is accepted for interface parity with other backends and is
    unused here.
    """
    return coo.tocsr()


def to_dense(self, sp_a: Tensor) -> Tensor:
    """
    Densify a scipy sparse matrix into a plain ``numpy.ndarray``.

    Uses ``.toarray()`` rather than ``.todense()``. For ``scipy.sparse`` matrix
    types, ``.todense()`` returns a ``numpy.matrix``, whose always-2-D semantics
    differ from the ordinary dense tensors returned by the other backends.
    """
    return sp_a.toarray()


def is_sparse(self, a: Tensor) -> bool:
    """Whether ``a`` is a scipy sparse matrix/array."""
    return issparse(a)  # type: ignore


def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    ``numpy.where`` with both its one- and three-argument forms.

    :param condition: Boolean array.
    :param x: Values taken where ``condition`` is True.
    :param y: Values taken where ``condition`` is False.
    :return: With neither ``x`` nor ``y``, a tuple of index arrays (one per
        dimension) of the True entries. Otherwise, the element-wise selection.
    :raises ValueError: If exactly one of ``x`` and ``y`` is given. This is an
        explicit check rather than an ``assert``, which ``python -O`` strips.
    """
    if x is None and y is None:
        return np.where(condition)
    if x is None or y is None:
        raise ValueError("`x` and `y` must both be given or both be None")
    return np.where(condition, x, y)
|
|
383
|
+
|
|
338
384
|
def cond(
|
|
339
385
|
self,
|
|
340
386
|
pred: bool,
|
|
@@ -390,7 +436,7 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
|
|
|
390
436
|
f: Callable[..., Any],
|
|
391
437
|
static_argnums: Optional[Union[int, Sequence[int]]] = None,
|
|
392
438
|
jit_compile: Optional[bool] = None,
|
|
393
|
-
**kws: Any
|
|
439
|
+
**kws: Any,
|
|
394
440
|
) -> Callable[..., Any]:
|
|
395
441
|
logger.info("numpy backend has no jit interface, just do nothing")
|
|
396
442
|
return f
|
|
@@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
|
9
9
|
from operator import mul
|
|
10
10
|
from functools import reduce, partial
|
|
11
11
|
|
|
12
|
+
from scipy.sparse import coo_matrix
|
|
12
13
|
import tensornetwork
|
|
13
14
|
from tensornetwork.backends.pytorch import pytorch_backend
|
|
14
15
|
from .abstract_backend import ExtendedBackend
|
|
@@ -23,7 +24,6 @@ logger = logging.getLogger(__name__)
|
|
|
23
24
|
|
|
24
25
|
# TODO(@refraction-ray): lack stateful random methods implementation for now
|
|
25
26
|
# TODO(@refraction-ray): lack scatter impl for now
|
|
26
|
-
# TODO(@refraction-ray): lack sparse relevant methods for now
|
|
27
27
|
# To be added once pytorch backend is ready
|
|
28
28
|
|
|
29
29
|
|
|
@@ -229,6 +229,9 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
229
229
|
r = torchlib.ones(shape)
|
|
230
230
|
return self.cast(r, dtype)
|
|
231
231
|
|
|
232
|
+
def exp(self, tensor: Tensor) -> Tensor:
    """Element-wise exponential, delegated to ``torch.exp``."""
    result = torchlib.exp(tensor)
    return result
|
|
234
|
+
|
|
232
235
|
def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
|
|
233
236
|
if dtype is None:
|
|
234
237
|
dtype = dtypestr
|
|
@@ -238,8 +241,18 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
238
241
|
def copy(self, a: Tensor) -> Tensor:
    """Return an independent copy of ``a`` (``Tensor.clone``)."""
    return a.clone()


def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Turn ``tensor`` into a torch tensor.

    Existing tensors pass through untouched. Anything else is wrapped with
    ``torch.tensor``. When ``dtype`` is given, the result is cast with ``self.cast``.
    """
    result = tensor if self.is_tensor(tensor) else torchlib.tensor(tensor)
    if dtype is None:
        return result
    return self.cast(result, dtype)


def expm(self, a: Tensor) -> Tensor:
    """Matrix exponential of ``a`` via ``torch.linalg.matrix_exp``."""
    return torchlib.linalg.matrix_exp(a)
|
|
255
|
+
# raise NotImplementedError("pytorch backend doesn't support expm")
|
|
243
256
|
# in 2020, torch has no expm, hmmm. but that's ok,
|
|
244
257
|
# it doesn't support complex numbers which is more severe issue.
|
|
245
258
|
# see https://github.com/pytorch/pytorch/issues/9983
|
|
@@ -293,6 +306,9 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
293
306
|
return torchlib.kron(a, b)
|
|
294
307
|
|
|
295
308
|
def numpy(self, a: Tensor) -> Tensor:
|
|
309
|
+
if self.is_sparse(a):
|
|
310
|
+
a = a.coalesce()
|
|
311
|
+
return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape)
|
|
296
312
|
a = a.cpu()
|
|
297
313
|
if a.is_conj():
|
|
298
314
|
return a.resolve_conj().numpy()
|
|
@@ -369,6 +385,20 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
369
385
|
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    """Index of the minimum of ``a`` along ``axis``."""
    return torchlib.argmin(a, dim=axis)


def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sorted values of ``a`` along ``axis`` (the indices are discarded)."""
    sorted_values, _ = torchlib.sort(a, dim=axis)
    return sorted_values


def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that would sort ``a`` along ``axis``."""
    return torchlib.argsort(a, dim=axis)


def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """
    Corresponds to torch.all: reduce over ``axis``, or over every element
    when ``axis`` is None.
    """
    reduce_kwargs = {} if axis is None else {"dim": axis}
    return torchlib.all(tensor, **reduce_kwargs)


def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
    """Unique values of ``a`` and their counts; extra keywords are ignored."""
    values, counts = torchlib.unique(a, return_counts=True)
    return values, counts  # type: ignore
|
|
374
404
|
|
|
@@ -382,6 +412,7 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
382
412
|
return torchlib.nn.Softmax(a, dim=axis)
|
|
383
413
|
|
|
384
414
|
def onehot(self, a: Tensor, num: int) -> Tensor:
    """
    One-hot encode the labels in ``a`` into ``num`` classes.

    Labels are converted to int64 (``Tensor.long``) first, as
    ``torch.nn.functional.one_hot`` requires integer input.
    """
    labels = a.long()
    return torchlib.nn.functional.one_hot(labels, num)
|
|
386
417
|
|
|
387
418
|
def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
|
|
@@ -409,6 +440,15 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
409
440
|
def mod(self, x: Tensor, y: Tensor) -> Tensor:
    """
    Element-wise remainder whose sign follows the divisor.

    Uses ``torch.remainder`` (Python/``numpy.mod`` semantics) rather than
    ``torch.fmod``, whose result takes the dividend's sign. This keeps
    ``floor_divide(x, y) * y + mod(x, y) == x`` for negative operands and
    matches the numpy backend.
    """
    return torchlib.remainder(x, y)


def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise floor division (``torch.floor_divide``)."""
    return torchlib.floor_divide(x, y)


def floor(self, a: Tensor) -> Tensor:
    """Element-wise floor of ``a``."""
    return torchlib.floor(a)


def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Clamp ``a`` element-wise into ``[a_min, a_max]`` (``torch.clamp``)."""
    return torchlib.clamp(a, a_min, a_max)


def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise bitwise right shift of ``x`` by ``y`` bits."""
    return torchlib.bitwise_right_shift(x, y)
|
|
414
454
|
|
|
@@ -425,9 +465,52 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
425
465
|
v = self.convert_to_tensor(v)
|
|
426
466
|
return torchlib.searchsorted(a, v, side=side)
|
|
427
467
|
|
|
468
|
+
def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    ``torch.where`` with both its one- and three-argument forms.

    :param condition: Boolean tensor.
    :param x: Values taken where ``condition`` is True.
    :param y: Values taken where ``condition`` is False.
    :return: With neither ``x`` nor ``y``, a tuple of index tensors (one per
        dimension) of the True entries. Otherwise, the element-wise selection.
    :raises ValueError: If exactly one of ``x`` and ``y`` is given.
    """
    if x is None and y is None:
        return torchlib.where(condition)
    if x is None or y is None:
        raise ValueError("`x` and `y` must both be given or both be None")
    return torchlib.where(condition, x, y)


def reverse(self, a: Tensor) -> Tensor:
    """Reverse ``a`` along its last axis."""
    return torchlib.flip(a, dims=(-1,))
|
|
430
480
|
|
|
481
|
+
def coo_sparse_matrix(
    self, indices: Tensor, values: Tensor, shape: Tensor
) -> Tensor:
    """
    Build a PyTorch sparse COO tensor.

    :param indices: Nonzero coordinates. They are transposed before being
        handed to ``torch.sparse_coo_tensor`` (which takes ``[ndim, nnz]``),
        so they are expected with one row per entry, ``[nnz, ndim]``.
    :param values: The nonzero values.
    :param shape: Dense shape of the result.
    :return: Sparse COO tensor.
    """
    indices = self.convert_to_tensor(indices)
    return torchlib.sparse_coo_tensor(self.transpose(indices), values, shape)


def sparse_dense_matmul(
    self,
    sp_a: Tensor,
    b: Tensor,
) -> Tensor:
    """
    Multiply the sparse matrix ``sp_a`` by the dense ``b``.

    ``torch.sparse.mm`` only accepts a 2-D dense operand, so a 1-D ``b`` is
    treated as a column vector and a 1-D result is returned. This mirrors how
    the tensorflow backend reshapes vectors.
    """
    if b.dim() == 1:
        return torchlib.sparse.mm(sp_a, b.unsqueeze(-1)).squeeze(-1)
    return torchlib.sparse.mm(sp_a, b)


def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a sparse COO tensor to CSR layout via ``Tensor.to_sparse_csr``.

    If that method is unavailable (``AttributeError``), ``coo`` is returned
    unchanged, unless ``strict`` is set, in which case the error propagates.
    """
    try:
        return coo.to_sparse_csr()
    except AttributeError:
        if strict:
            raise
        return coo


def to_dense(self, sp_a: Tensor) -> Tensor:
    """Convert a sparse tensor to a dense one."""
    return sp_a.to_dense()


def is_sparse(self, a: Tensor) -> bool:
    """Whether ``a`` uses a sparse COO or sparse CSR layout."""
    return a.is_sparse or a.is_sparse_csr  # type: ignore
|
|
513
|
+
|
|
431
514
|
def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
|
|
432
515
|
# torch native tree_map not support multiple pytree args
|
|
433
516
|
# return torchlib.utils._pytree.tree_map(f, *pytrees)
|
|
@@ -661,7 +744,7 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
661
744
|
f: Callable[..., Any],
|
|
662
745
|
static_argnums: Optional[Union[int, Sequence[int]]] = None,
|
|
663
746
|
jit_compile: Optional[bool] = None,
|
|
664
|
-
**kws: Any
|
|
747
|
+
**kws: Any,
|
|
665
748
|
) -> Any:
|
|
666
749
|
if jit_compile is True:
|
|
667
750
|
# experimental feature reusing the jit_compile flag for tf
|
|
@@ -706,6 +789,12 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
|
|
|
706
789
|
|
|
707
790
|
return wrapper
|
|
708
791
|
|
|
792
|
+
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at position ``axis`` (``torch.unsqueeze``)."""
    expanded = torchlib.unsqueeze(a, axis)
    return expanded
|
|
794
|
+
|
|
709
795
|
vvag = vectorized_value_and_grad
|
|
710
796
|
|
|
797
|
+
def meshgrid(self, *args: Any, **kws: Any) -> Tensor:
    """
    Coordinate grids from 1-D inputs, forwarded to ``torch.meshgrid``.

    ``torch.meshgrid`` defaults to ``indexing="ij"``, while ``numpy.meshgrid``
    and ``tf.meshgrid`` default to ``"xy"``. Default to ``"xy"`` here so the
    backends agree; an explicit ``indexing=`` keyword is still honored.
    """
    kws.setdefault("indexing", "xy")
    return torchlib.meshgrid(*args, **kws)
|
|
799
|
+
|
|
711
800
|
optimizer = torch_optimizer
|
|
@@ -75,6 +75,12 @@ class keras_optimizer:
|
|
|
75
75
|
def _tensordot_tf(
    self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]]
) -> Tensor:
    """
    ``tf.tensordot`` that also accepts operands of different dtypes.

    When the dtypes differ, both operands are first cast to the common dtype
    given by ``tf.experimental.numpy.result_type``.
    """
    if a.dtype == b.dtype:
        return tf.tensordot(a, b, axes)
    common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype)
    return tf.tensordot(tf.cast(a, common_dtype), tf.cast(b, common_dtype), axes)
|
|
79
85
|
|
|
80
86
|
|
|
@@ -360,6 +366,38 @@ tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.rq = _rq_
|
|
|
360
366
|
tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.svd = _svd_tf
|
|
361
367
|
|
|
362
368
|
|
|
369
|
+
def sparse_tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
    """
    An implementation of matrix multiplication (@) for tf.SparseTensor.

    This function is designed to be monkey-patched onto the tf.SparseTensor
    class, so ``self`` is the sparse left operand.

    :param self: Sparse matrix.
    :param other: Dense vector (rank 1) or matrix (rank 2). It may also be
        anything ``tf.convert_to_tensor`` accepts, such as a numpy array or a
        list. It is cast to ``self.dtype`` when the dtypes are incompatible.
    :return: Dense result. A rank-1 ``other`` gives a rank-1 result.
    """
    # numpy arrays / Python lists lack ``.dtype.is_compatible_with``.
    if not tf.is_tensor(other):
        other = tf.convert_to_tensor(other)
    if not other.dtype.is_compatible_with(self.dtype):
        other = tf.cast(other, self.dtype)

    # tf.sparse.sparse_dense_matmul requires the dense operand to be a 2D
    # matrix, so promote a vector [N] to a column matrix [N, 1] and demote the
    # result back afterwards.
    is_vector = len(other.shape) == 1
    other_matrix = tf.expand_dims(other, axis=1) if is_vector else other
    result_matrix = tf.sparse.sparse_dense_matmul(self, other_matrix)
    return tf.squeeze(result_matrix, axis=1) if is_vector else result_matrix
|
|
399
|
+
|
|
400
|
+
|
|
363
401
|
class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend): # type: ignore
|
|
364
402
|
"""
|
|
365
403
|
See the original backend API at `tensorflow backend
|
|
@@ -378,6 +416,8 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
378
416
|
)
|
|
379
417
|
tf = tensorflow
|
|
380
418
|
tf.sparse.SparseTensor.__add__ = tf.sparse.add
|
|
419
|
+
tf.SparseTensor.__matmul__ = sparse_tensor_matmul
|
|
420
|
+
|
|
381
421
|
self.minor = int(tf.__version__.split(".")[1])
|
|
382
422
|
self.name = "tensorflow"
|
|
383
423
|
logger = tf.get_logger() # .setLevel('ERROR')
|
|
@@ -407,6 +447,12 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
407
447
|
def copy(self, a: Tensor) -> Tensor:
    """Return a copy of ``a`` (``tf.identity``)."""
    return tf.identity(a)


def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``tensor`` with ``tf.convert_to_tensor``.

    When ``dtype`` is given, the result is additionally cast with ``self.cast``.
    """
    converted = tf.convert_to_tensor(tensor)
    return converted if dtype is None else self.cast(converted, dtype)


def expm(self, a: Tensor) -> Tensor:
    """Matrix exponential of ``a`` (``tf.linalg.expm``)."""
    return tf.linalg.expm(a)
|
|
412
458
|
|
|
@@ -490,12 +536,35 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
490
536
|
def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
    """Maximum of ``a`` along ``axis``, or over all elements when ``axis`` is None."""
    return tf.reduce_max(a, axis=axis)


def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """Logical-AND reduction; ``a`` is cast to ``tf.bool`` first, so numeric inputs work too."""
    as_bool = tf.cast(a, tf.bool)
    return tf.reduce_all(as_bool, axis=axis)


def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Element-wise select between ``x`` and ``y``. When both are None, return the
    coordinates of the True entries of ``condition`` instead, as a tuple with
    one index tensor per dimension (the ``numpy.where`` convention).
    """
    if x is not None or y is not None:
        return tf.where(condition, x, y)
    # One-argument tf.where yields one coordinate row per True entry; split
    # the columns so the result matches the other backends' tuple form.
    coordinates = tf.where(condition)
    return tuple(tf.unstack(coordinates, axis=1))


def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
    """Index of the maximum of ``a`` along ``axis``."""
    return tf.math.argmax(a, axis=axis)


def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    """Index of the minimum of ``a`` along ``axis``."""
    return tf.math.argmin(a, axis=axis)


def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that would sort ``a`` along ``axis``."""
    return tf.argsort(a, axis=axis)


def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sorted copy of ``a`` along ``axis``."""
    return tf.sort(a, axis=axis)


def shape_tuple(self, a: Tensor) -> Tuple[int, ...]:
    """Static shape of ``a`` as a plain tuple."""
    return tuple(a.shape)
|
|
567
|
+
|
|
499
568
|
def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
|
|
500
569
|
r = tf.unique_with_counts(a)
|
|
501
570
|
order = tf.argsort(r.y)
|
|
@@ -504,6 +573,17 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
504
573
|
def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
    """Join the tensors in ``a`` along a new axis ``axis``."""
    return tf.stack(a, axis=axis)


def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Clamp ``a`` element-wise into ``[a_min, a_max]`` (``tf.clip_by_value``)."""
    return tf.clip_by_value(a, a_min, a_max)


def floor(self, a: Tensor) -> Tensor:
    """Element-wise floor; integer tensors are returned unchanged."""
    return a if a.dtype.is_integer else tf.math.floor(a)


def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise floor division (``tf.math.floordiv``)."""
    return tf.math.floordiv(x, y)


def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
    """Concatenate the tensors in ``a`` along the existing axis ``axis``."""
    return tf.concat(a, axis=axis)
|
|
509
589
|
|
|
@@ -678,7 +758,14 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
678
758
|
sp_a: Tensor,
|
|
679
759
|
b: Tensor,
|
|
680
760
|
) -> Tensor:
|
|
681
|
-
|
|
761
|
+
is_vec = False
|
|
762
|
+
if len(b.shape) == 1:
|
|
763
|
+
b = self.reshape(b, [-1, 1])
|
|
764
|
+
is_vec = True
|
|
765
|
+
r = tf.sparse.sparse_dense_matmul(sp_a, b)
|
|
766
|
+
if is_vec:
|
|
767
|
+
return self.reshape(r, [-1])
|
|
768
|
+
return r
|
|
682
769
|
|
|
683
770
|
def _densify(self) -> Tensor:
|
|
684
771
|
@partial(self.jit, jit_compile=True)
|
|
@@ -712,7 +799,10 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
712
799
|
def scan(
    self, f: Callable[[Tensor, Tensor], Tensor], xs: Tensor, init: Tensor
) -> Tensor:
    """
    Fold ``f(carry, x)`` over the leading axis of ``xs``, starting from ``init``,
    and return only the final carry.

    ``tf.scan`` stacks every intermediate carry. The last entry is taken from
    each leaf, so nested carry structures are supported.
    """
    all_carries = tf.scan(f, xs, init)
    return tf.nest.map_structure(lambda leaf: leaf[-1], all_carries)
|
|
716
806
|
|
|
717
807
|
def device(self, a: Tensor) -> str:
|
|
718
808
|
dev = a.device
|
|
@@ -864,7 +954,7 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
864
954
|
f: Callable[..., Any],
|
|
865
955
|
static_argnums: Optional[Union[int, Sequence[int]]] = None,
|
|
866
956
|
jit_compile: Optional[bool] = None,
|
|
867
|
-
**kws: Any
|
|
957
|
+
**kws: Any,
|
|
868
958
|
) -> Any:
|
|
869
959
|
# static_argnums not supported in tf case, this is only for a consistent interface
|
|
870
960
|
# for more on static_argnums in tf.function, see issue: https://github.com/tensorflow/tensorflow/issues/52193
|
|
@@ -1011,4 +1101,13 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
|
|
|
1011
1101
|
|
|
1012
1102
|
vvag = vectorized_value_and_grad
|
|
1013
1103
|
|
|
1104
|
+
def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
    """
    Backend-agnostic meshgrid function: every argument is forwarded to
    ``tf.meshgrid``.
    """
    grids = tf.meshgrid(*args, **kwargs)
    return grids
|
|
1109
|
+
|
|
1014
1110
|
optimizer = keras_optimizer
|
|
1111
|
+
|
|
1112
|
+
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at position ``axis`` (``tf.expand_dims``)."""
    expanded = tf.expand_dims(a, axis)
    return expanded
|