tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.

Potentially problematic release: this version of tensorcircuit-nightly has been flagged as possibly problematic.

Files changed (72)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +92 -3
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +123 -82
  14. tensorcircuit/circuit.py +67 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +1 -0
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +7 -152
  22. tensorcircuit/fgs.py +5 -6
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/keras.py +3 -3
  25. tensorcircuit/mpscircuit.py +109 -61
  26. tensorcircuit/quantum.py +697 -133
  27. tensorcircuit/quditcircuit.py +733 -0
  28. tensorcircuit/quditgates.py +618 -0
  29. tensorcircuit/results/counts.py +45 -31
  30. tensorcircuit/shadows.py +1 -1
  31. tensorcircuit/simplify.py +3 -1
  32. tensorcircuit/stabilizercircuit.py +4 -2
  33. tensorcircuit/templates/blocks.py +2 -2
  34. tensorcircuit/templates/hamiltonians.py +29 -8
  35. tensorcircuit/templates/lattice.py +676 -335
  36. tensorcircuit/timeevol.py +896 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
  38. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  39. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  40. tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
  41. tests/__init__.py +0 -0
  42. tests/conftest.py +0 -67
  43. tests/test_backends.py +0 -1035
  44. tests/test_calibrating.py +0 -149
  45. tests/test_channels.py +0 -409
  46. tests/test_circuit.py +0 -1713
  47. tests/test_cloud.py +0 -219
  48. tests/test_compiler.py +0 -147
  49. tests/test_dmcircuit.py +0 -555
  50. tests/test_ensemble.py +0 -72
  51. tests/test_fgs.py +0 -318
  52. tests/test_gates.py +0 -156
  53. tests/test_hamiltonians.py +0 -159
  54. tests/test_interfaces.py +0 -557
  55. tests/test_keras.py +0 -160
  56. tests/test_lattice.py +0 -1666
  57. tests/test_miscs.py +0 -334
  58. tests/test_mpscircuit.py +0 -341
  59. tests/test_noisemodel.py +0 -156
  60. tests/test_qaoa.py +0 -86
  61. tests/test_qem.py +0 -152
  62. tests/test_quantum.py +0 -549
  63. tests/test_quantum_attr.py +0 -42
  64. tests/test_results.py +0 -379
  65. tests/test_shadows.py +0 -160
  66. tests/test_simplify.py +0 -46
  67. tests/test_stabilizer.py +0 -226
  68. tests/test_templates.py +0 -218
  69. tests/test_torchnn.py +0 -99
  70. tests/test_van.py +0 -102
  71. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
  72. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
tensorcircuit/backends/jax_ops.py

@@ -3,8 +3,11 @@ Customized ops for ML framework
 """
 
 # pylint: disable=invalid-name
+# pylint: disable=unused-variable
+
 
 from typing import Any, Tuple, Sequence
+from functools import partial
 
 import jax
 import jax.numpy as jnp
@@ -174,3 +177,108 @@ def jaxeigh_bwd(r: Array, tangents: Array) -> Array:
 
 adaware_eigh.defvjp(jaxeigh_fwd, jaxeigh_bwd)
 adaware_eigh_jit = jax.jit(adaware_eigh)
+
+
+@partial(jax.jit, static_argnums=[0, 2])
+def bessel_jv_jax_rescaled(k: int, x: jnp.ndarray, M: int) -> jnp.ndarray:
+    """
+    Computes Bessel function Jv using Miller's algorithm with dynamic rescaling,
+    implemented in JAX.
+    """
+    if M <= k:
+        raise ValueError(
+            f"Recurrence length M ({M}) must be greater than the required order k ({k})."
+        )
+
+    # Use vmap to handle array inputs for x efficiently.
+    # We map _bessel_jv_scalar_rescaled over the last dimension of x.
+    return _bessel_jv_scalar_rescaled(k, M, x)
+
+
+def _bessel_jv_scalar_rescaled(k: int, M: int, x_val: jnp.ndarray) -> jnp.ndarray:
+    """
+    JAX implementation for a scalar input x_val.
+    This function will be vmapped for array inputs.
+    """
+    rescale_threshold = 1e250
+
+    # Define the body of the recurrence loop
+    def recurrence_body(i, state):  # type: ignore
+        # M - i is the current 'm' value in the original loop.
+        # Loop from M down to 1. jax.lax.fori_loop goes from lower to upper-1.
+        # So for m from M down to 1, we map i from 0 to M-1.
+        # Current m_val = M - i
+        # The loop range for m in numpy was `range(M, 0, -1)`, which means m goes from M, M-1, ..., 1.
+        # For lax.fori_loop (start, stop, body_fn, init_val), start is inclusive, stop is exclusive.
+        # So to iterate M times for m from M down to 1, we do i from 0 to M-1.
+        # m_val = M - i means that for i=0, m_val=M; for i=M-1, m_val=1.
+        m_val = M - i
+        f_m, f_m_p1, f_vals = state
+
+        # If x_val is near zero, this division could be an issue,
+        # but the outer lax.cond handles the x_val near zero case before this loop runs.
+        f_m_m1 = (2.0 * m_val / x_val) * f_m - f_m_p1
+
+        # --- Rescaling Step ---
+        # jax.lax.cond requires all branches to return the exact same type and shape.
+        def rescale_branch(vals):  # type: ignore
+            f_m_val, f_m_p1_val, f_vals_arr = vals
+            scale_factor = f_m_m1
+            # Return new f_m, f_m_p1, updated f_vals_arr, and the new f_m_m1 value (which is 1.0)
+            return (
+                f_m_val / scale_factor,
+                f_m_p1_val / scale_factor,
+                f_vals_arr / scale_factor,
+                1.0,
+            )
+
+        def no_rescale_branch(vals):  # type: ignore
+            f_m_val, f_m_p1_val, f_vals_arr = (
+                vals  # Unpack to keep signatures consistent
+            )
+            # Return original f_m, f_m_p1, original f_vals_arr, and the computed f_m_m1
+            return (f_m_val, f_m_p1_val, f_vals_arr, f_m_m1)
+
+        f_m_rescaled, f_m_p1_rescaled, f_vals_rescaled, f_m_m1_effective = jax.lax.cond(
+            jnp.abs(f_m_m1) > rescale_threshold,
+            rescale_branch,
+            no_rescale_branch,
+            (f_m, f_m_p1, f_vals),  # Arguments passed to branches
+        )
+
+        # Update f_vals at index m_val - 1. JAX uses .at[idx].set(val) for non-in-place updates.
+        f_vals_updated = f_vals_rescaled.at[m_val - 1].set(f_m_m1_effective)
+
+        # Return new state for the next iteration: (new f_m, new f_m_p1, updated f_vals)
+        return (f_m_m1_effective, f_m_rescaled, f_vals_updated)
+
+    # Initial state for the recurrence loop
+    f_m_p1_init = 0.0
+    f_m_init = 1e-30  # Start with a very small number
+    f_vals_init = jnp.zeros(M + 1).at[M].set(f_m_init)
+
+    # Use jax.lax.fori_loop for the backward recurrence
+    # Loop from i = 0 to M-1 (total M iterations)
+    # The 'body' function gets current 'i' and 'state', returns 'new_state'.
+    # We don't need the final f_m_p1, only f_m and f_vals.
+    final_f_m, _, f_vals = jax.lax.fori_loop(
+        0, M, recurrence_body, (f_m_init, f_m_p1_init, f_vals_init)
+    )
+
+    # Normalization using Neumann's sum rule
+    even_sum = jnp.sum(f_vals[2::2])
+    norm_const = f_vals[0] + 2.0 * even_sum
+
+    # Handle division by near-zero normalization constant
+    norm_const_safe = jnp.where(jnp.abs(norm_const) < 1e-12, 1e-12, norm_const)
+
+    # Conditional logic for x_val close to zero
+    def x_is_zero_case() -> jnp.ndarray:
+        # For x=0, J_0(0)=1, J_k(0)=0 for k>0
+        return jnp.zeros(k).at[0].set(1.0)
+
+    def x_is_not_zero_case() -> jnp.ndarray:
+        return f_vals[:k] / norm_const_safe  # type: ignore
+
+    # Use lax.cond to select between the two cases based on x_val
+    return jax.lax.cond(jnp.abs(x_val) < 1e-12, x_is_zero_case, x_is_not_zero_case)  # type: ignore
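
A quick way to sanity-check the new routine is to compare it with scipy.special.jv for a small scalar argument. This is a minimal sketch, not part of the package: the import path follows the module above, double precision is enabled so the backward recurrence does not overflow, and the choices M = 40 and the tolerance are illustrative assumptions.

import jax
import numpy as np
from scipy.special import jv

jax.config.update("jax_enable_x64", True)  # float64 keeps the backward recurrence finite

from tensorcircuit.backends.jax_ops import bessel_jv_jax_rescaled

k, M = 5, 40  # static arguments: orders 0..k-1, recurrence length M > k
x = 2.0

approx = bessel_jv_jax_rescaled(k, x, M)  # Miller's algorithm, jitted
exact = jv(np.arange(k), x)               # scipy reference values J_0(x)..J_4(x)
np.testing.assert_allclose(np.asarray(approx), exact, atol=1e-8)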
tensorcircuit/backends/numpy_backend.py

@@ -17,7 +17,7 @@ except ImportError: # np2.0 compatibility
 
 import tensornetwork
 from scipy.linalg import expm, solve, schur
-from scipy.special import softmax, expit
+from scipy.special import softmax, expit, jv
 from scipy.sparse import coo_matrix, issparse
 from tensornetwork.backends.numpy import numpy_backend
 from .abstract_backend import ExtendedBackend
@@ -35,10 +35,14 @@ def _sum_numpy(
     # see https://github.com/google/TensorNetwork/issues/952
 
 
-def _convert_to_tensor_numpy(self: Any, a: Tensor) -> Tensor:
+def _convert_to_tensor_numpy(
+    self: Any, a: Tensor, dtype: Optional[str] = None
+) -> Tensor:
     if not isinstance(a, np.ndarray) and not np.isscalar(a):
         a = np.array(a)
     a = np.asarray(a)
+    if dtype is not None:
+        a = a.astype(getattr(np, dtype))
     return a
 
 
@@ -132,6 +136,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
     def kron(self, a: Tensor, b: Tensor) -> Tensor:
         return np.kron(a, b)
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        return np.meshgrid(*args, **kwargs)
+
     def dtype(self, a: Tensor) -> str:
         return a.dtype.__str__()  # type: ignore
 
@@ -151,6 +158,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
         dtype = getattr(np, dtype)
         return np.array(1j, dtype=dtype)
 
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return np.expand_dims(a, axis)
+
     def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
         return np.stack(a, axis=axis)
 
@@ -173,6 +183,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
     ) -> Tensor:
         return np.std(a, axis=axis, keepdims=keepdims)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return np.all(a, axis=axis)
+
     def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return np.unique(a, return_counts=True)  # type: ignore
 
@@ -188,6 +201,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return np.argmin(a, axis=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return np.sort(a, axis=axis)
+
     def sigmoid(self, a: Tensor) -> Tensor:
         return expit(a)
 
@@ -200,6 +216,7 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
         return softmax(a, axis=axis)
 
     def onehot(self, a: Tensor, num: int) -> Tensor:
+        a = np.asarray(a)
         res = np.eye(num)[a.reshape([-1])]
         return res.reshape(list(a.shape) + [num])
         # https://stackoverflow.com/questions/38592324/one-hot-encoding-using-numpy
@@ -233,6 +250,15 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
     def mod(self, x: Tensor, y: Tensor) -> Tensor:
         return np.mod(x, y)
 
+    def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
+        return np.floor_divide(x, y)
+
+    def floor(self, a: Tensor) -> Tensor:
+        return np.floor(a)
+
+    def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
+        return np.clip(a, a_min, a_max)
+
     def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
         return np.right_shift(x, y)
 
@@ -244,9 +270,15 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
         # https://stackoverflow.com/questions/44672029/difference-between-numpy-linalg-solve-and-numpy-linalg-lu-solve/44710451
         return solve(A, b, assume_a=assume_a)
 
+    def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
+        return jv(np.arange(v), z)
+
     def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
         return np.searchsorted(a, v, side=side)  # type: ignore
 
+    def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return np.argsort(a, axis=axis)
+
     def set_random_state(
         self, seed: Optional[int] = None, get_only: bool = False
     ) -> Any:
@@ -329,12 +361,26 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
     ) -> Tensor:
         return sp_a @ b
 
+    def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
+        return coo.tocsr()
+
     def to_dense(self, sp_a: Tensor) -> Tensor:
         return sp_a.todense()
 
     def is_sparse(self, a: Tensor) -> bool:
         return issparse(a)  # type: ignore
 
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return np.where(condition)
+        assert x is not None and y is not None
+        return np.where(condition, x, y)
+
     def cond(
         self,
         pred: bool,
@@ -390,7 +436,7 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend): # type: ignore
         f: Callable[..., Any],
        static_argnums: Optional[Union[int, Sequence[int]]] = None,
         jit_compile: Optional[bool] = None,
-        **kws: Any
+        **kws: Any,
     ) -> Callable[..., Any]:
         logger.info("numpy backend has no jit interface, just do nothing")
         return f
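
These NumPy-backend additions fill out the backend-agnostic API, so the same method names work regardless of the active backend. The following is a small usage sketch, assuming the usual tc.set_backend entry point returns the backend object and that convert_to_tensor is bound to the patched _convert_to_tensor_numpy shown above; the inline comments give the expected values.

import tensorcircuit as tc

K = tc.set_backend("numpy")

a = K.convert_to_tensor([3, 1, 2], dtype="float64")  # dtype keyword is new in this release

K.sort(a)                 # array([1., 2., 3.])
K.argsort(a)              # array([1, 2, 0])
K.where(a > 1.5, a, 0.0)  # array([3., 0., 2.])
K.floor_divide(a, 2.0)    # array([1., 0., 1.])
K.clip(a, 1.5, 2.5)       # array([2.5, 1.5, 2. ])
K.special_jv(4, 0.5, 30)  # J_0(0.5)..J_3(0.5); M is in the common signature but unused here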
tensorcircuit/backends/pytorch_backend.py

@@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union
 from operator import mul
 from functools import reduce, partial
 
+from scipy.sparse import coo_matrix
 import tensornetwork
 from tensornetwork.backends.pytorch import pytorch_backend
 from .abstract_backend import ExtendedBackend
@@ -23,7 +24,6 @@ logger = logging.getLogger(__name__)
 
 # TODO(@refraction-ray): lack stateful random methods implementation for now
 # TODO(@refraction-ray): lack scatter impl for now
-# TODO(@refraction-ray): lack sparse relevant methods for now
 # To be added once pytorch backend is ready
 
 
@@ -229,6 +229,9 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
         r = torchlib.ones(shape)
         return self.cast(r, dtype)
 
+    def exp(self, tensor: Tensor) -> Tensor:
+        return torchlib.exp(tensor)
+
     def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
         if dtype is None:
             dtype = dtypestr
@@ -238,8 +241,18 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
     def copy(self, a: Tensor) -> Tensor:
         return a.clone()
 
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
+        if self.is_tensor(tensor):
+            result = tensor
+        else:
+            result = torchlib.tensor(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
+        return result
+
     def expm(self, a: Tensor) -> Tensor:
-        raise NotImplementedError("pytorch backend doesn't support expm")
+        return torchlib.linalg.matrix_exp(a)
+        # raise NotImplementedError("pytorch backend doesn't support expm")
         # in 2020, torch has no expm, hmmm. but that's ok,
         # it doesn't support complex numbers which is more severe issue.
         # see https://github.com/pytorch/pytorch/issues/9983
@@ -293,6 +306,9 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
         return torchlib.kron(a, b)
 
     def numpy(self, a: Tensor) -> Tensor:
+        if self.is_sparse(a):
+            a = a.coalesce()
+            return coo_matrix((a.values().numpy(), a.indices().numpy()), shape=a.shape)
         a = a.cpu()
         if a.is_conj():
             return a.resolve_conj().numpy()
@@ -369,6 +385,20 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return torchlib.argmin(a, dim=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return torchlib.sort(a, dim=axis).values
+
+    def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return torchlib.argsort(a, dim=axis)
+
+    def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        """
+        Corresponds to torch.all.
+        """
+        if axis is None:
+            return torchlib.all(tensor)
+        return torchlib.all(tensor, dim=axis)
+
     def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return torchlib.unique(a, return_counts=True)  # type: ignore
 
@@ -382,6 +412,7 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
         return torchlib.nn.Softmax(a, dim=axis)
 
     def onehot(self, a: Tensor, num: int) -> Tensor:
+        a = a.long()
         return torchlib.nn.functional.one_hot(a, num)
 
     def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
@@ -409,6 +440,15 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
     def mod(self, x: Tensor, y: Tensor) -> Tensor:
         return torchlib.fmod(x, y)
 
+    def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
+        return torchlib.floor_divide(x, y)
+
+    def floor(self, a: Tensor) -> Tensor:
+        return torchlib.floor(a)
+
+    def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
+        return torchlib.clamp(a, a_min, a_max)
+
     def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
         return torchlib.bitwise_right_shift(x, y)
 
@@ -425,9 +465,52 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
         v = self.convert_to_tensor(v)
         return torchlib.searchsorted(a, v, side=side)
 
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return torchlib.where(condition)
+        return torchlib.where(condition, x, y)
+
     def reverse(self, a: Tensor) -> Tensor:
         return torchlib.flip(a, dims=(-1,))
 
+    def coo_sparse_matrix(
+        self, indices: Tensor, values: Tensor, shape: Tensor
+    ) -> Tensor:
+        # Convert COO format to PyTorch sparse tensor
+        indices = self.convert_to_tensor(indices)
+        return torchlib.sparse_coo_tensor(self.transpose(indices), values, shape)
+
+    def sparse_dense_matmul(
+        self,
+        sp_a: Tensor,
+        b: Tensor,
+    ) -> Tensor:
+        # Matrix multiplication between sparse and dense tensor
+        return torchlib.sparse.mm(sp_a, b)
+
+    def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
+        try:
+            # Convert COO to CSR format if supported
+            return coo.to_sparse_csr()
+        except AttributeError as e:
+            if not strict:
+                return coo
+            else:
+                raise e
+
+    def to_dense(self, sp_a: Tensor) -> Tensor:
+        # Convert sparse tensor to dense
+        return sp_a.to_dense()
+
+    def is_sparse(self, a: Tensor) -> bool:
+        # Check if tensor is sparse
+        return a.is_sparse or a.is_sparse_csr  # type: ignore
+
     def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
         # torch native tree_map not support multiple pytree args
         # return torchlib.utils._pytree.tree_map(f, *pytrees)
@@ -661,7 +744,7 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
         f: Callable[..., Any],
         static_argnums: Optional[Union[int, Sequence[int]]] = None,
         jit_compile: Optional[bool] = None,
-        **kws: Any
+        **kws: Any,
    ) -> Any:
         if jit_compile is True:
             # experimental feature reusing the jit_compile flag for tf
@@ -706,6 +789,12 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend): # type:
 
         return wrapper
 
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return torchlib.unsqueeze(a, dim=axis)
+
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kws: Any) -> Tensor:
+        return torchlib.meshgrid(*args, **kws)
+
     optimizer = torch_optimizer
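
With the sparse methods above, the PyTorch backend exposes the same COO entry points as the TensorFlow and NumPy backends. The following is a minimal sketch (assuming tc.set_backend returns the backend object); indices are given row-wise with shape [nnz, 2] and transposed internally, as in coo_sparse_matrix above, and the expected results are noted in comments.

import tensorcircuit as tc

K = tc.set_backend("pytorch")

# the 2x2 matrix [[0., 1.], [2., 0.]] in COO form
indices = K.convert_to_tensor([[0, 1], [1, 0]])  # one [row, col] pair per nonzero
values = K.convert_to_tensor([1.0, 2.0])
sp = K.coo_sparse_matrix(indices, values, shape=(2, 2))

b = K.convert_to_tensor([[1.0], [1.0]])

K.is_sparse(sp)               # True
K.sparse_dense_matmul(sp, b)  # dense tensor [[1.], [2.]]
K.numpy(sp)                   # scipy.sparse coo_matrix with the same entries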
tensorcircuit/backends/tensorflow_backend.py

@@ -75,6 +75,12 @@ class keras_optimizer:
 def _tensordot_tf(
     self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]]
 ) -> Tensor:
+    # Use TensorFlow's dtype promotion rules by converting both to a common dtype
+    if a.dtype != b.dtype:
+        # Find the result dtype using TensorFlow's type promotion rules
+        common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype)
+        a = tf.cast(a, common_dtype)
+        b = tf.cast(b, common_dtype)
     return tf.tensordot(a, b, axes)
 
 
@@ -360,6 +366,38 @@ tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.rq = _rq_
 tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.svd = _svd_tf
 
 
+def sparse_tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
+    """
+    An implementation of matrix multiplication (@) for tf.SparseTensor.
+
+    This function is designed to be monkey-patched onto the tf.SparseTensor class.
+    It handles multiplication with a dense vector (rank-1 Tensor) by temporarily
+    promoting it to a matrix (rank-2 Tensor) for the underlying TensorFlow call.
+    """
+    # Ensure the 'other' tensor is of a compatible dtype
+    if not other.dtype.is_compatible_with(self.dtype):
+        other = tf.cast(other, self.dtype)
+
+    # tf.sparse.sparse_dense_matmul requires the dense tensor to be a 2D matrix.
+    # If we get a 1D vector, we need to reshape it.
+    is_vector = len(other.shape) == 1
+
+    if is_vector:
+        # Promote the vector to a column matrix [N] -> [N, 1]
+        other_matrix = tf.expand_dims(other, axis=1)
+    else:
+        other_matrix = other
+
+    # Perform the actual multiplication
+    result_matrix = tf.sparse.sparse_dense_matmul(self, other_matrix)
+
+    if is_vector:
+        # Demote the result matrix back to a vector [M, 1] -> [M]
+        return tf.squeeze(result_matrix, axis=1)
+    else:
+        return result_matrix
+
+
 class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend): # type: ignore
     """
     See the original backend API at `tensorflow backend
@@ -378,6 +416,8 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
         )
         tf = tensorflow
         tf.sparse.SparseTensor.__add__ = tf.sparse.add
+        tf.SparseTensor.__matmul__ = sparse_tensor_matmul
+
         self.minor = int(tf.__version__.split(".")[1])
         self.name = "tensorflow"
         logger = tf.get_logger() # .setLevel('ERROR')
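
Since the patch is installed in the backend constructor, any code that instantiates the TensorFlow backend (for example via tc.set_backend) can then use @ between a tf.SparseTensor and a dense matrix or vector; rank-1 operands take the promote/demote path shown in sparse_tensor_matmul. A short sketch follows, with expected results in comments; the remaining tensorflow_backend.py hunks continue below.

import tensorflow as tf
import tensorcircuit as tc

tc.set_backend("tensorflow")  # constructing the backend attaches __matmul__ to tf.SparseTensor

# the 2x2 matrix [[0., 1.], [2., 0.]] as a SparseTensor
sp = tf.sparse.SparseTensor(indices=[[0, 1], [1, 0]], values=[1.0, 2.0], dense_shape=[2, 2])

sp @ tf.constant([1.0, 1.0])      # rank-1 operand, returns [1., 2.]
sp @ tf.constant([[1.0], [1.0]])  # rank-2 operand, returns [[1.], [2.]]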
@@ -407,6 +447,12 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def copy(self, a: Tensor) -> Tensor:
         return tf.identity(a)
 
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
+        result = tf.convert_to_tensor(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
+        return result
+
     def expm(self, a: Tensor) -> Tensor:
         return tf.linalg.expm(a)
 
@@ -490,12 +536,35 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
         return tf.reduce_max(a, axis=axis)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return tf.reduce_all(tf.cast(a, tf.bool), axis=axis)
+
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            # Return a tuple of tensors to be consistent with other backends
+            return tuple(tf.unstack(tf.where(condition), axis=1))
+        return tf.where(condition, x, y)
+
     def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
         return tf.math.argmax(a, axis=axis)
 
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return tf.math.argmin(a, axis=axis)
 
+    def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return tf.argsort(a, axis=axis)
+
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return tf.sort(a, axis=axis)
+
+    def shape_tuple(self, a: Tensor) -> Tuple[int, ...]:
+        return tuple(a.shape)
+
     def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         r = tf.unique_with_counts(a)
         order = tf.argsort(r.y)
@@ -504,6 +573,17 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
         return tf.stack(a, axis=axis)
 
+    def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
+        return tf.clip_by_value(a, a_min, a_max)
+
+    def floor(self, a: Tensor) -> Tensor:
+        if a.dtype.is_integer:
+            return a
+        return tf.math.floor(a)
+
+    def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
+        return tf.math.floordiv(x, y)
+
     def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
         return tf.concat(a, axis=axis)
 
@@ -678,7 +758,14 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
         sp_a: Tensor,
         b: Tensor,
     ) -> Tensor:
-        return tf.sparse.sparse_dense_matmul(sp_a, b)
+        is_vec = False
+        if len(b.shape) == 1:
+            b = self.reshape(b, [-1, 1])
+            is_vec = True
+        r = tf.sparse.sparse_dense_matmul(sp_a, b)
+        if is_vec:
+            return self.reshape(r, [-1])
+        return r
 
     def _densify(self) -> Tensor:
         @partial(self.jit, jit_compile=True)
@@ -712,7 +799,10 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def scan(
         self, f: Callable[[Tensor, Tensor], Tensor], xs: Tensor, init: Tensor
    ) -> Tensor:
-        return tf.scan(f, xs, init)[-1]
+        stacked_results = tf.scan(f, xs, init)
+        final_state = tf.nest.map_structure(lambda x: x[-1], stacked_results)
+        return final_state
+        # return tf.scan(f, xs, init)[-1]
 
     def device(self, a: Tensor) -> str:
         dev = a.device
@@ -864,7 +954,7 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
         f: Callable[..., Any],
         static_argnums: Optional[Union[int, Sequence[int]]] = None,
         jit_compile: Optional[bool] = None,
-        **kws: Any
+        **kws: Any,
     ) -> Any:
         # static_argnums not supported in tf case, this is only for a consistent interface
         # for more on static_argnums in tf.function, see issue: https://github.com/tensorflow/tensorflow/issues/52193
@@ -1011,4 +1101,13 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
 
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Backend-agnostic meshgrid function.
+        """
+        return tf.meshgrid(*args, **kwargs)
+
     optimizer = keras_optimizer
+
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return tf.expand_dims(a, axis)
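
Two behavior changes in the hunks above are worth calling out: sparse_dense_matmul now accepts a rank-1 dense operand and reshapes it internally, and where called with only a condition returns a tuple of index tensors so that all backends mirror np.where. The following is a minimal sketch on the TensorFlow backend (assuming tc.set_backend returns the backend object), also showing the new argsort; expected results are in comments.

import tensorflow as tf
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

sp = tf.sparse.SparseTensor(indices=[[0, 1], [1, 0]], values=[1.0, 2.0], dense_shape=[2, 2])

K.sparse_dense_matmul(sp, tf.constant([1.0, 1.0]))  # rank-1 result [1., 2.]
K.where(tf.constant([0.0, 2.0]) > 1.0)              # one-element tuple of index tensors, like np.where
K.argsort(tf.constant([3.0, 1.0, 2.0]))             # [1, 2, 0]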