mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/backends/tensor_impl.py
@@ -0,0 +1,519 @@
+ # Copyright 2025 Ant Group Co., Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tensor Runtime Implementation.
+
+ Implements execution logic for Tensor primitives.
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import hashlib
+ import os
+ import time
+ from typing import Any, ClassVar, cast
+
+ import jax
+ import jax.extend as jxt
+ import jax.numpy as jnp
+ import numpy as np
+ from jax._src import compiler
+ from numpy.typing import ArrayLike
+
+ import mplang.v2.edsl.typing as elt
+ from mplang.v2.dialects import dtypes, tensor
+ from mplang.v2.edsl import serde
+ from mplang.v2.edsl.graph import Operation
+ from mplang.v2.runtime.interpreter import Interpreter
+ from mplang.v2.runtime.value import Value, WrapValue
+
+ # =============================================================================
+ # TensorValue Wrapper
+ # =============================================================================
+
+
+ @serde.register_class
+ class TensorValue(WrapValue[Any]):
+     """Runtime value wrapping a numpy array or JAX array.
+
+     Handles numpy arrays, JAX arrays, and other numpy-like objects via duck typing.
+     Serialization uses base64-encoded raw bytes for efficiency.
+
+     Note: This is for numeric tensors only. Object dtype arrays (containing
+     encrypted values, etc.) should NOT be wrapped - they are handled separately
+     by elementwise_impl which returns raw np.ndarray(dtype=object).
+     """
+
+     _serde_kind: ClassVar[str] = "tensor_impl.TensorValue"
+
+     # Expose common array properties for convenience
+     @property
+     def shape(self) -> tuple[int, ...]:
+         return cast(tuple[int, ...], self._data.shape)
+
+     @property
+     def dtype(self) -> np.dtype[Any]:
+         return np.dtype(self._data.dtype)  # type: ignore[no-any-return]
+
+     @property
+     def ndim(self) -> int:
+         return cast(int, self._data.ndim)
+
+     def __getitem__(self, key: Any) -> Any:
+         """Allow indexing into the underlying array."""
+         return self._data[key]
+
+     # =========== Wrap/Unwrap ===========
+
+     def _convert(self, data: Any) -> Any:
+         """Convert input data to numpy array or JAX array."""
+         if isinstance(data, TensorValue):
+             return data._data
+
+         # Allow JAX arrays to pass through
+         if hasattr(data, "__jax_array__"):
+             return data
+
+         # Handle other numpy-like objects via np.asarray
+         if (
+             hasattr(data, "__module__")
+             and data.__module__ is not None
+             and "jax" in data.__module__
+         ):
+             return data
+
+         if isinstance(data, np.ndarray):
+             return data
+         # Try converting other array-like objects
+         return np.asarray(data)
+
+     def unwrap(self) -> np.ndarray:
+         """Get the underlying data as a numpy array.
+
+         If the data is a JAX array, it will be transferred to host.
+         """
+         return np.asarray(self._data)
+
+     def as_jax(self) -> Any:
+         """Get the underlying data as a JAX array.
+
+         If the data is a numpy array, it will be transferred to device.
+         """
+         if hasattr(self._data, "__jax_array__"):
+             return self._data
+
+         # Handle object arrays that might contain numbers (e.g. from elementwise)
+         if isinstance(self._data, np.ndarray) and self._data.dtype == object:
+             try:
+                 # Attempt to convert to numeric numpy array first
+                 # This handles cases where elementwise returned object array of numbers
+                 val_numeric = np.array(self._data.tolist())
+                 if val_numeric.dtype != object:
+                     return jax.device_put(jnp.asarray(val_numeric))
+             except Exception:
+                 # If conversion fails, proceed with original (which will likely fail in jax)
+                 pass
+
+         return jax.device_put(jnp.asarray(self._data))
+
+     # =========== Serialization ===========
+
+     def to_json(self) -> dict[str, Any]:
+         # Ensure we have numpy data for serialization
+         # This forces synchronization if data is on device
+         data_np = np.asarray(self._data)
+
+         # Handle object dtype arrays - serialize element by element
+         if data_np.dtype == np.object_:
+             return {
+                 "kind": "object",
+                 "shape": list(data_np.shape),
+                 "items": [serde.to_json(item) for item in data_np.flat],
+             }
+         # Standard numeric arrays - use raw bytes
+         return {
+             "kind": "numeric",
+             "dtype": str(data_np.dtype),
+             "shape": list(data_np.shape),
+             "data": base64.b64encode(data_np.tobytes()).decode("ascii"),
+         }
+
+     @classmethod
+     def from_json(cls, data: dict[str, Any]) -> TensorValue:
+         kind = data.get("kind", "numeric")
+         shape = tuple(data["shape"])
+
+         if kind == "object":
+             items = [serde.from_json(item) for item in data["items"]]
+             arr = np.empty(len(items), dtype=object)
+             for i, item in enumerate(items):
+                 arr[i] = item
+             return cls(arr.reshape(shape))
+         else:
+             arr = np.frombuffer(
+                 base64.b64decode(data["data"]),
+                 dtype=np.dtype(data["dtype"]),
+             )
+             return cls(arr.reshape(shape).copy())
+
+
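# --- Illustrative sketch (not part of the diff): round-tripping the "numeric"
# serialization format used by TensorValue.to_json/from_json above. Real
# transport goes through mplang.v2.edsl.serde; json.dumps here is only to show
# that the payload is plain JSON-compatible data.
import base64
import json

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
payload = {
    "kind": "numeric",
    "dtype": str(x.dtype),
    "shape": list(x.shape),
    "data": base64.b64encode(x.tobytes()).decode("ascii"),
}
restored = (
    np.frombuffer(
        base64.b64decode(json.loads(json.dumps(payload))["data"]),
        dtype=np.dtype(payload["dtype"]),
    )
    .reshape(payload["shape"])
    .copy()
)
assert np.array_equal(x, restored)
# --- end sketch ---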
+ # Module-level helpers for convenience (delegate to class methods)
+ def _wrap(val: ArrayLike | TensorValue) -> TensorValue:
+     """Wrap an array-like value into TensorValue."""
+     return TensorValue.wrap(val)
+
+
+ def _unwrap(val: TensorValue | np.ndarray | ArrayLike) -> np.ndarray:
+     """Unwrap TensorValue to np.ndarray, also accepts raw arrays."""
+     if isinstance(val, TensorValue):
+         return val.unwrap()
+     if isinstance(val, np.ndarray):
+         return val
+     # Handle JAX arrays
+     if hasattr(val, "__jax_array__"):
+         return np.asarray(val)
+     return np.asarray(val)
+
+
+ # _ensure_tensor_value removed - callers should unwrap InterpObject before calling impls
+
+
+ # =============================================================================
+ # Tensor Primitive Implementations
+ # =============================================================================
+
+
+ @tensor.constant_p.def_impl
+ def constant_impl(interpreter: Interpreter, op: Operation) -> TensorValue:
+     # Recover dtype and shape from IR type
+     output_type = op.outputs[0].type
+     if not isinstance(output_type, elt.TensorType):
+         raise TypeError(f"Expected TensorType, got {output_type}")
+
+     dtype = dtypes.to_jax(cast(elt.ScalarType, output_type.element_type))
+     if dtype is None:
+         raise ValueError(f"Unsupported scalar type {output_type.element_type}")
+
+     shape = output_type.shape
+
+     # Decode data
+     data_b64 = op.attrs["value_b64"]
+     data_bytes = base64.b64decode(data_b64)
+
+     # Create array
+     arr = np.frombuffer(data_bytes, dtype=cast(Any, dtype)).reshape(shape).copy()
+     return _wrap(arr)
+
+
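# --- Illustrative sketch (not part of the diff): constant_impl above expects
# op.attrs["value_b64"] to be the base64 of the array's raw bytes, with dtype
# and shape recovered from the output TensorType. A hypothetical producer side
# would encode, and the kernel would decode, like this.
import base64

import numpy as np

value = np.array([[1, 2], [3, 4]], dtype=np.int32)
value_b64 = base64.b64encode(value.tobytes()).decode("ascii")
decoded = (
    np.frombuffer(base64.b64decode(value_b64), dtype=value.dtype)
    .reshape(value.shape)
    .copy()
)
assert np.array_equal(decoded, value)
# --- end sketch ---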
+ @tensor.concat_p.def_impl
+ def concat_impl(
+     interpreter: Interpreter, op: Operation, *args: TensorValue
+ ) -> TensorValue:
+     axis = op.attrs.get("axis", 0)
+     unwrapped = [_unwrap(a) for a in args]
+     return _wrap(np.concatenate(unwrapped, axis=axis))
+
+
+ @tensor.elementwise_p.def_impl
+ def elementwise_impl(interpreter: Interpreter, op: Operation, *args: Value) -> Any:
+     """Execute elementwise operation by iterating over tensor elements.
+
+     Note: args typed as Value (base class) because elementwise handles polymorphic
+     inputs - TensorValue for numeric tensors, or np.ndarray with dtype=object
+     containing encrypted values (BFVValue, etc.) that are processed element-wise.
+     """
+     # args are the input tensors (or scalars)
+     # op.regions[0] is the scalar computation graph
+
+     # 1. Determine shape from IR types and runtime args
+     shape = ()
+     for i, inp_val in enumerate(op.inputs):
+         if isinstance(inp_val.type, elt.TensorType):
+             if inp_val.type.shape != ():
+                 # Found a non-scalar tensor input. Use its runtime shape.
+                 # We assume the tracer ensured all non-scalar tensors have compatible shapes.
+                 arg = args[i]
+                 if hasattr(arg, "shape"):
+                     shape = arg.shape
+                 break
+
+     # 2. Construct output container
+     # We need to know the output type/dtype.
+     # op.outputs[0].type should give us a hint, but here we are in runtime.
+     # Let's just use a list or numpy array of objects for flexibility.
+     # Since we might be mixing types (e.g. Encrypted objects), object array is safest.
+     num_outputs = len(op.outputs)
+     results: Any
+     if num_outputs > 1:
+         results = [np.empty(shape, dtype=object) for _ in range(num_outputs)]
+     else:
+         results = np.empty(shape, dtype=object)
+
+     # 3. Iterate and execute
+     # Use np.ndindex for multi-dimensional iteration
+     subgraph = op.regions[0]
+
+     if shape == ():
+         # Scalar case - return first element from result list
+         result = interpreter.evaluate_graph(subgraph, list(args))
+         return result[0] if len(result) == 1 else result
+
+     for index in np.ndindex(shape):
+         # Prepare inputs for this element (list ordered by subgraph.inputs)
+         scalar_inputs = []
+         for i, arg in enumerate(args):
+             outer_val = op.inputs[i]
+             # Check if this argument should be iterated based on OUTER IR type
+             if (
+                 isinstance(outer_val.type, elt.TensorType)
+                 and outer_val.type.shape != ()
+             ):
+                 # Tensor argument: pick element (arg is array-like at runtime)
+                 # Wrap scalar in TensorValue to maintain Value-only contract
+                 elem = cast(Any, arg)[index]
+                 if isinstance(elem, Value):
+                     scalar_inputs.append(elem)
+                 else:
+                     scalar_inputs.append(_wrap(np.array(elem)))  # type: ignore[index]
+             else:
+                 # Scalar/Broadcast argument: use as is
+                 # Ensure it is wrapped (it should be, but double check)
+                 if not isinstance(arg, Value):
+                     scalar_inputs.append(_wrap(np.array(arg)))
+                 else:
+                     scalar_inputs.append(arg)
+
+         # Recursive execution
+         scalar_out_list = interpreter.evaluate_graph(subgraph, scalar_inputs)
+         scalar_out = (
+             scalar_out_list[0] if len(scalar_out_list) == 1 else scalar_out_list
+         )
+
+         # Unwrap result if it's a TensorValue (to store in numpy array)
+         # We store raw values in the object array for now, but will wrap the final array
+         if isinstance(scalar_out, TensorValue):
+             scalar_out = scalar_out.unwrap()
+             if scalar_out.shape == ():
+                 scalar_out = scalar_out.item()
+
+         if num_outputs > 1:
+             for i, val in enumerate(scalar_out):
+                 results[i][index] = val
+         else:
+             results[index] = scalar_out
+
+     # Wrap results in TensorValue if possible
+     if num_outputs > 1:
+         return [_wrap(res) for res in results]
+     else:
+         return _wrap(results)
+
+
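# --- Illustrative sketch (not part of the diff): the core loop of
# elementwise_impl above, reduced to plain numpy. scalar_fn stands in for
# interpreter.evaluate_graph(subgraph, ...); results land in an object-dtype
# array so non-numeric element values (e.g. ciphertexts) would also fit.
import numpy as np

def elementwise_sketch(scalar_fn, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    out = np.empty(a.shape, dtype=object)
    for index in np.ndindex(a.shape):  # walk every multi-dimensional index
        out[index] = scalar_fn(a[index], b[index])
    return out

added = elementwise_sketch(lambda x, y: x + y, np.ones((2, 2)), np.full((2, 2), 2.0))
# --- end sketch ---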
+ # Global cache for compiled StableHLO executables
+ _STABLEHLO_CACHE: dict[str, Any] = {}
+
+
+ @tensor.run_jax_p.def_impl
+ def run_jax_impl(
+     interpreter: Interpreter, op: Operation, *args: TensorValue
+ ) -> TensorValue | list[TensorValue]:
+     """Execute JAX function."""
+     t0 = time.time()
+
+     # Execute via StableHLO
+     stablehlo_code = op.attrs.get("stablehlo_code")
+     if stablehlo_code is None:
+         raise NotImplementedError(
+             "run_jax execution requires 'stablehlo_code' attribute"
+         )
+
+     # Compile StableHLO
+     client = jxt.backend.get_backend()
+
+     # Use SHA256 of code as cache key for stability across runs
+     # Note: We assume compile_options are constant (num_replicas=1, num_partitions=1)
+     code_hash = hashlib.sha256(stablehlo_code.encode("utf-8")).hexdigest()
+
+     if code_hash in _STABLEHLO_CACHE:
+         compiled = _STABLEHLO_CACHE[code_hash]
+     else:
+         compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
+
+         # Try disk cache
+         cache_dir = interpreter.root_dir / "cache" / "jax"
+         cache_dir.mkdir(parents=True, exist_ok=True)
+         cache_path = str(cache_dir / f"{code_hash}.pjrt")
+         loaded_from_disk = False
+
+         if os.path.exists(cache_path):
+             try:
+                 with open(cache_path, "rb") as f:
+                     serialized = f.read()
+                 compiled = client.deserialize_executable(
+                     serialized, client.devices(), compile_options
+                 )
+                 loaded_from_disk = True
+                 # print(f"[JAX] Loaded compiled executable from {cache_path}")
+             except Exception as e:
+                 print(f"[JAX] Failed to load from disk cache: {e}")
+
+         if not loaded_from_disk:
+             try:
+                 compiled = client.compile_and_load(
+                     stablehlo_code, client.devices(), compile_options
+                 )
+                 # Save to disk
+                 try:
+                     # Directory creation handled above
+                     with open(cache_path, "wb") as f:
+                         f.write(client.serialize_executable(compiled))
+                     # print(f"[JAX] Saved compiled executable to {cache_path}")
+                 except Exception as e:
+                     print(f"[JAX] Failed to save to disk cache: {e}")
+             except Exception as e:
+                 raise RuntimeError(f"StableHLO compile failed: {e}") from e
+
+         _STABLEHLO_CACHE[code_hash] = compiled
+
+     # Cast inputs to expected types (Boundary Type Guard)
+     # This allows users to pass Python ints/floats to functions expecting f32/i32
+     t1 = time.time()
+
+     jax_input_args = []
+     for i, arg in enumerate(args):
+         # arg is TensorValue
+         if i < len(op.inputs):
+             input_type = op.inputs[i].type
+             # Check if we need casting
+             if isinstance(input_type, elt.TensorType):
+                 dtype = dtypes.to_jax(cast(elt.ScalarType, input_type.element_type))
+                 # Get as JAX array
+                 if isinstance(arg, TensorValue):
+                     val = arg.as_jax()
+                 else:
+                     val = jnp.asarray(arg)
+
+                 if (
+                     dtype is not None
+                     and isinstance(val, (jnp.ndarray, np.ndarray))
+                     and val.dtype != dtype
+                 ):
+                     val = val.astype(dtype)
+                 jax_input_args.append(val)
+             else:
+                 if isinstance(arg, TensorValue):
+                     jax_input_args.append(arg.as_jax())
+                 else:
+                     jax_input_args.append(jnp.asarray(arg))
+         else:
+             if isinstance(arg, TensorValue):
+                 jax_input_args.append(arg.as_jax())
+             else:
+                 jax_input_args.append(jnp.asarray(arg))
+
+     # Handle JAX's unused parameter elimination via arg_keep_map
+     arg_keep_map = op.attrs.get("arg_keep_map")
+     if arg_keep_map is not None:
+         # Filter out arguments that were eliminated by JAX during compilation
+         jax_input_args = [jax_input_args[i] for i in arg_keep_map]
+
+     # Convert args to JAX arrays
+     t2 = time.time()
+     # jax_input_args are already JAX arrays (or will be handled by execute_sharded if not)
+     jax_args = jax_input_args
+
+     try:
+         t3 = time.time()
+         result = compiled.execute_sharded(jax_args)
+         t4 = time.time()
+         arrays = result.disassemble_into_single_device_arrays()
+         flat: list[TensorValue] = []
+         for lst in arrays:
+             if isinstance(lst, list) and len(lst) == 1:
+                 # Wrap JAX array directly, avoiding np.asarray
+                 flat.append(_wrap(lst[0]))
+             else:
+                 flat.extend(_wrap(a) for a in lst)
+         t5 = time.time()
+
+         if interpreter.tracer:
+             p = interpreter.tracer
+             p.log_custom_event("JAX Compile/Cache", t0, t1, cat="jax")
+             p.log_custom_event("JAX Prep", t1, t2, cat="jax")
+             p.log_custom_event("JAX Transfer In", t2, t3, cat="jax")
+             p.log_custom_event("JAX Exec", t3, t4, cat="jax")
+             p.log_custom_event("JAX Transfer Out", t4, t5, cat="jax")
+
+         # If single output, return it directly (but run_jax usually returns list of vars)
+         # The primitive expects a list of results matching outputs.
+         # If op has 1 output, flat should have 1 element.
+         if len(op.outputs) == 1 and len(flat) == 1:
+             return flat[0]
+         return flat
+     except Exception as e:
+         raise RuntimeError(f"StableHLO execute failed: {e}") from e
+
+
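# --- Illustrative sketch (not part of the diff): the caching scheme used by
# run_jax_impl above, reduced to its shape. A SHA256 of the program text keys
# both an in-process dict and a file on disk, so recompilation is skipped
# within and across runs. compile_fn, serialize_fn and deserialize_fn stand in
# for the PJRT client methods used in the kernel.
import hashlib
from pathlib import Path
from typing import Any, Callable

_CACHE: dict[str, Any] = {}

def cached_compile(
    code: str,
    cache_dir: Path,
    compile_fn: Callable[[str], Any],
    serialize_fn: Callable[[Any], bytes],
    deserialize_fn: Callable[[bytes], Any],
) -> Any:
    key = hashlib.sha256(code.encode("utf-8")).hexdigest()
    if key in _CACHE:  # in-memory hit
        return _CACHE[key]
    path = cache_dir / f"{key}.bin"
    if path.exists():  # disk hit
        exe = deserialize_fn(path.read_bytes())
    else:  # miss: compile and persist
        exe = compile_fn(code)
        cache_dir.mkdir(parents=True, exist_ok=True)
        path.write_bytes(serialize_fn(exe))
    _CACHE[key] = exe
    return exe
# --- end sketch ---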
+ @tensor.gather_p.def_impl
+ def gather_impl(
+     interpreter: Interpreter, op: Operation, operand: TensorValue, indices: TensorValue
+ ) -> TensorValue:
+     axis = op.attrs.get("axis", 0)
+     operand_arr = _unwrap(operand)
+     indices_arr = _unwrap(indices)
+     # Ensure indices are integers (they might be JAX arrays or numpy arrays)
+     if hasattr(indices_arr, "astype"):
+         indices_arr = indices_arr.astype(int)
+     return _wrap(np.take(operand_arr, indices_arr, axis=axis))
+
+
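# --- Illustrative sketch (not part of the diff): gather_impl above is a thin
# wrapper over np.take; indices select slices along the given axis and are
# cast to int first because they may arrive as float or JAX arrays.
import numpy as np

operand = np.array([[10, 11], [20, 21], [30, 31]])
indices = np.array([2.0, 0.0]).astype(int)  # rows 2 and 0
assert np.array_equal(np.take(operand, indices, axis=0), [[30, 31], [10, 11]])
# --- end sketch ---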
+ @tensor.slice_p.def_impl
+ def slice_impl(
+     interpreter: Interpreter, op: Operation, operand: TensorValue
+ ) -> TensorValue:
+     starts = op.attrs["starts"]
+     ends = op.attrs["ends"]
+     strides = op.attrs.get("strides")
+
+     slices: list[Any] = []
+     for i in range(len(starts)):
+         start = starts[i]
+         end = ends[i]
+         stride = strides[i] if strides else 1
+         slices.append(slice(start, end, stride))
+
+     operand_arr = _unwrap(operand)
+     # If operand is numpy array, we can slice directly
+     # If operand has more dimensions than slices provided, we assume full slice for remaining
+     if len(slices) < operand_arr.ndim:
+         slices.append(Ellipsis)
+
+     return _wrap(operand_arr[tuple(slices)])
+
+
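# --- Illustrative sketch (not part of the diff): slice_impl above turns the
# per-dimension starts/ends/strides attrs into Python slice objects and pads
# the remaining dimensions with Ellipsis.
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)
starts, ends, strides = [0, 1], [2, 3], [1, 1]  # only the two leading dims
slices: list = [slice(s, e, st) for s, e, st in zip(starts, ends, strides)]
if len(slices) < arr.ndim:
    slices.append(Ellipsis)  # full slice for trailing dims
assert arr[tuple(slices)].shape == (2, 2, 4)
# --- end sketch ---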
+ @tensor.reshape_p.def_impl
+ def reshape_impl(
+     interpreter: Interpreter, op: Operation, tensor_data: TensorValue
+ ) -> TensorValue:
+     new_shape = op.attrs["new_shape"]
+     return _wrap(_unwrap(tensor_data).reshape(new_shape))
+
+
+ @tensor.transpose_p.def_impl
+ def transpose_impl(
+     interpreter: Interpreter, op: Operation, tensor_data: TensorValue
+ ) -> TensorValue:
+     perm = op.attrs.get("perm")
+     return _wrap(np.transpose(_unwrap(tensor_data), axes=perm))