mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -15,16 +15,38 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  from dataclasses import dataclass
18
- from typing import Any
18
+ from typing import Any, ClassVar
19
19
 
20
20
  import numpy as np
21
21
  import spu.api as spu_api
22
22
  import spu.libspu as libspu
23
23
 
24
- from mplang.core.mptype import TensorLike
25
- from mplang.core.pfunc import PFunction
26
- from mplang.kernels.base import cur_kctx, kernel_def
27
- from mplang.runtime.link_comm import LinkCommunicator
24
+ from mplang.v1.core import (
25
+ BOOL,
26
+ FLOAT32,
27
+ FLOAT64,
28
+ INT8,
29
+ INT16,
30
+ INT32,
31
+ INT64,
32
+ UINT8,
33
+ UINT16,
34
+ UINT32,
35
+ UINT64,
36
+ DType,
37
+ PFunction,
38
+ )
39
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
40
+ from mplang.v1.kernels.value import (
41
+ TensorValue,
42
+ Value,
43
+ ValueDecodeError,
44
+ ValueProtoBuilder,
45
+ ValueProtoReader,
46
+ register_value,
47
+ )
48
+ from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
49
+ from mplang.v1.runtime.link_comm import LinkCommunicator
28
50
 
29
51
 
30
52
  def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
@@ -32,36 +54,106 @@ def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
32
54
  return tuple(spu_shape.dims)
33
55
 
34
56
 
# Lookup table from libspu data types to MPLang DTypes. Built once at import
# time rather than on every call: dtype_spu_to_mpl is invoked once per SPU
# output when wrapping run results.
_SPU_TO_MPL_DTYPE: dict[libspu.DataType, DType] = {
    libspu.DataType.DT_F32: FLOAT32,
    libspu.DataType.DT_F64: FLOAT64,
    libspu.DataType.DT_I1: BOOL,
    libspu.DataType.DT_I8: INT8,
    libspu.DataType.DT_U8: UINT8,
    libspu.DataType.DT_I16: INT16,
    libspu.DataType.DT_U16: UINT16,
    libspu.DataType.DT_I32: INT32,
    libspu.DataType.DT_U32: UINT32,
    libspu.DataType.DT_I64: INT64,
    libspu.DataType.DT_U64: UINT64,
}


def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
    """Convert a libspu.DataType to the equivalent MPLang DType.

    Args:
        spu_dtype: Data type reported by the SPU runtime (e.g. from
            ``get_var_meta(...).data_type``).

    Returns:
        The matching MPLang DType.

    Raises:
        KeyError: If ``spu_dtype`` has no entry in the conversion table.
    """
    try:
        return _SPU_TO_MPL_DTYPE[spu_dtype]
    except KeyError:
        # Same exception type as before, but with context instead of a bare key.
        raise KeyError(f"Unsupported SPU data type: {spu_dtype}") from None
51
73
 
52
74
 
@register_value
@dataclass
class SpuValue(Value):
    """SPU value container for secure computation (Value type).

    Holds one party's ``libspu.Share`` together with the logical shape,
    MPLang dtype and SPU visibility of the shared tensor.
    """

    KIND: ClassVar[str] = "mplang.spu.SpuValue"
    WIRE_VERSION: ClassVar[int] = 1

    shape: tuple[int, ...]
    dtype: DType  # MPLang's unified DType (not libspu.DataType)
    vtype: libspu.Visibility
    share: libspu.Share

    def __repr__(self) -> str:
        return f"SpuValue({self.shape},{self.dtype},{self.vtype})"

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize SpuValue to wire format.

        ``libspu.Share`` exposes two attributes:
          - ``meta``: bytes (serialized share metadata)
          - ``share_chunks``: sequence of bytes (the secret-share data)

        Wire layout:
          - runtime attrs: ``shape``, ``dtype`` (DType name), ``vtype`` (int),
            ``share_meta`` (``share.meta``) and ``chunk_lengths`` (byte length
            of each share chunk, in order).
          - payload: all share chunks concatenated in order; ``from_proto``
            splits it back using ``chunk_lengths``.
        """
        # Read share_chunks once; it is used for both lengths and payload.
        chunks = list(self.share.share_chunks)
        chunk_lengths = [len(chunk) for chunk in chunks]
        # b"".join is linear; repeated `+=` on bytes re-copies the growing
        # payload on every iteration (quadratic in total share size).
        payload = b"".join(chunks)

        return (
            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
            .set_attr("shape", list(self.shape))
            .set_attr("dtype", self.dtype.name)  # Serialize DType name
            .set_attr("vtype", int(self.vtype))
            .set_attr("share_meta", self.share.meta)
            .set_attr("chunk_lengths", chunk_lengths)
            .set_payload(payload)
            .build()
        )

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
        """Deserialize SpuValue from wire format produced by ``to_proto``.

        Raises:
            ValueDecodeError: If the wire version is unsupported, or if the
                recorded chunk lengths do not exactly cover the payload
                (truncated or corrupted data).
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")

        # Read metadata from runtime_attrs
        shape = tuple(reader.get_attr("shape"))
        dtype_name = reader.get_attr("dtype")
        # NOTE(review): to_proto writes DType.name; this relies on
        # DType.from_numpy accepting that name — confirm for all dtypes.
        dtype = DType.from_numpy(dtype_name)
        vtype = libspu.Visibility(reader.get_attr("vtype"))
        share_meta = reader.get_attr("share_meta")
        chunk_lengths = reader.get_attr("chunk_lengths")

        payload = reader.payload
        # Without this check a short payload would silently yield truncated
        # chunks and an unusable share.
        expected_size = sum(chunk_lengths)
        if any(n < 0 for n in chunk_lengths) or expected_size != len(payload):
            raise ValueDecodeError(
                "SpuValue payload size mismatch: chunk_lengths cover "
                f"{expected_size} bytes but payload has {len(payload)} bytes"
            )

        # Parse payload: [chunk_0][chunk_1]...
        share_chunks: list[bytes] = []
        offset = 0
        for chunk_len in chunk_lengths:
            share_chunks.append(payload[offset : offset + chunk_len])
            offset += chunk_len

        # Reconstruct libspu.Share
        share = libspu.Share()
        share.meta = share_meta
        share.share_chunks = share_chunks

        return cls(
            shape=shape,
            dtype=dtype,
            vtype=vtype,
            share=share,
        )
156
+
65
157
 
66
158
  def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
67
159
  kctx = cur_kctx()
@@ -128,33 +220,25 @@ def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
128
220
 
129
221
 
130
222
  @kernel_def("spu.makeshares")
131
- def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
132
- """Create SPU shares from input data.
133
-
134
- Args:
135
- pfunc: PFunction containing makeshares metadata
136
- args: Input data to be shared (single tensor)
137
-
138
- Returns:
139
- Tuple of SPU shares (SpuValue), one for each party.
140
- """
141
- assert len(args) == 1
142
-
223
+ def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
224
+ """Create SPU shares from input TensorValue data."""
143
225
  visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
144
226
  if isinstance(visibility_value, int):
145
227
  visibility = libspu.Visibility(visibility_value)
146
228
  else:
147
229
  visibility = visibility_value
148
230
 
149
- arg = np.array(args[0], copy=False)
231
+ arg = tensor.to_numpy()
150
232
  cfg, world = _get_spu_config_and_world()
151
233
  spu_io = spu_api.Io(world, cfg)
152
234
  shares = spu_io.make_shares(arg, visibility)
153
235
  assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
236
+ # Store MPLang DType instead of libspu.DataType
237
+ dtype = DType.from_numpy(arg.dtype)
154
238
  return tuple(
155
239
  SpuValue(
156
240
  shape=arg.shape,
157
- dtype=arg.dtype,
241
+ dtype=dtype,
158
242
  vtype=visibility,
159
243
  share=share,
160
244
  )
@@ -163,24 +247,29 @@ def _spu_makeshares(pfunc: PFunction, *args: Any) -> Any:
163
247
 
164
248
 
@kernel_def("spu.reconstruct")
def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
    """Reconstruct plaintext data from SPU shares.

    Args:
        pfunc: PFunction for the ``spu.reconstruct`` op (its attrs are not
            read here).
        *shares: One SpuValue per SPU party; exactly ``world`` of them.

    Returns:
        TensorValue holding the reconstructed tensor, cast to the dtype and
        reshaped to the shape recorded on the first share.

    Raises:
        ValueError: If the number of shares differs from the SPU world size,
            or if any input is not a SpuValue.
    """
    cfg, world = _get_spu_config_and_world()
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if len(shares) != world:
        raise ValueError(f"Expected {world} shares, got {len(shares)}")
    for i, share in enumerate(shares):
        if not isinstance(share, SpuValue):
            raise ValueError(
                f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
            )

    spu_io = spu_api.Io(world, cfg)
    reconstructed = spu_io.reconstruct([spu_val.share for spu_val in shares])

    # Respect semantic dtype/shape recorded on shares (all shares carry the
    # same meta). np.asarray avoids copies when possible; the previous
    # np.array(..., copy=False) raises ValueError under NumPy >= 2 whenever a
    # copy is actually required.
    semantic_dtype = shares[0].dtype.to_numpy()
    semantic_shape = shares[0].shape
    restored = np.asarray(reconstructed, dtype=semantic_dtype).reshape(semantic_shape)
    return TensorValue(restored)
180
269
 
181
270
 
182
271
  @kernel_def("spu.run_pphlo")
183
- def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
272
+ def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
184
273
  """Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
185
274
 
186
275
  Participation rule: a rank participates iff its entry in the stored
@@ -240,10 +329,10 @@ def _spu_run_mlir(pfunc: PFunction, *args: Any) -> Any:
240
329
  spu_rt.run(executable)
241
330
  shares = [spu_rt.get_var(out_name) for out_name in output_names]
242
331
  metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
243
- results: list[TensorLike] = [
332
+ results: list[SpuValue] = [
244
333
  SpuValue(
245
334
  shape=shape_spu_to_np(meta.shape),
246
- dtype=dtype_spu_to_np(meta.data_type),
335
+ dtype=dtype_spu_to_mpl(meta.data_type),
247
336
  vtype=meta.visibility,
248
337
  share=shares[idx],
249
338
  )
@@ -14,16 +14,14 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Any
18
-
19
- from mplang.core.pfunc import PFunction
20
- from mplang.kernels.base import kernel_def
17
+ from mplang.v1.core import PFunction
18
+ from mplang.v1.kernels.base import kernel_def
19
+ from mplang.v1.kernels.value import TableValue
21
20
 
22
21
 
23
22
  @kernel_def("duckdb.run_sql")
24
- def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
23
+ def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
25
24
  import duckdb
26
- import pandas as pd
27
25
 
28
26
  # TODO: maybe we could translate the sql to duckdb dialect
29
27
  # instead of raising an exception
@@ -36,12 +34,11 @@ def _duckdb_sql(pfunc: PFunction, *args: Any) -> Any:
36
34
  if in_names is None:
37
35
  raise ValueError("duckdb sql missing in_names attr")
38
36
  for arg, name in zip(args, in_names, strict=True):
39
- if isinstance(arg, pd.DataFrame):
40
- df = arg
41
- elif isinstance(arg, list): # const list-of-dict for tests
42
- df = pd.DataFrame.from_records(arg)
43
- else:
44
- raise ValueError(f"unsupported duckdb input type {type(arg)}")
45
- conn.register(name, df)
46
- res_df = conn.execute(pfunc.fn_text).fetchdf()
47
- return res_df
37
+ # Use Arrow directly for zero-copy data transfer
38
+ arrow_table = arg.to_arrow()
39
+ conn.register(name, arrow_table)
40
+ # Fetch result as Arrow table for consistency
41
+ if pfunc.fn_text is None:
42
+ raise ValueError("SQL function text is None")
43
+ res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
44
+ return TableValue(res_arrow)
@@ -17,12 +17,14 @@ from __future__ import annotations
17
17
  from typing import Any
18
18
 
19
19
  import jax
20
+ import jax.extend as jxt
20
21
  import jax.numpy as jnp
21
- from jax._src import xla_bridge
22
- from jax.lib import xla_client as xc
22
+ import numpy as np
23
+ from jax._src import compiler
23
24
 
24
- from mplang.core.pfunc import PFunction
25
- from mplang.kernels.base import cur_kctx, kernel_def
25
+ from mplang.v1.core import PFunction
26
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
27
+ from mplang.v1.kernels.value import TensorValue
26
28
 
27
29
 
28
30
  @kernel_def("mlir.stablehlo")
@@ -45,11 +47,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
45
47
  key = f"stablehlo.compile_cache.{h}"
46
48
  compiled = rt.get_state(key)
47
49
  if compiled is None:
48
- backend = jax.default_backend()
49
- client = xla_bridge.get_backend(backend)
50
- compile_options = xc.CompileOptions()
50
+ client = jxt.backend.get_backend()
51
+ compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
52
+
51
53
  try:
52
- compiled = client.compile(mlir_text, compile_options)
54
+ compiled = client.compile_and_load(
55
+ mlir_text, client.devices(), compile_options
56
+ )
53
57
  except Exception as e: # pragma: no cover
54
58
  raise RuntimeError(f"StableHLO compile failed: {e}") from e
55
59
  rt.set_state(key, compiled)
@@ -61,23 +65,26 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
61
65
  # Filter out arguments that were eliminated by JAX during compilation
62
66
  runtime_args = tuple(args[i] for i in keep_indices)
63
67
 
64
- jax_args = []
65
- for arg in runtime_args:
66
- if hasattr(arg, "numpy"):
67
- jax_arg = jnp.array(arg.numpy()) # type: ignore
68
- else:
69
- jax_arg = jnp.array(arg)
70
- jax_args.append(jax.device_put(jax_arg))
68
+ tensor_args: list[TensorValue] = []
69
+ for idx, arg in enumerate(runtime_args):
70
+ if not isinstance(arg, TensorValue):
71
+ raise TypeError(
72
+ f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
73
+ )
74
+ tensor_args.append(arg)
75
+
76
+ jax_args = [
77
+ jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
78
+ ]
71
79
 
72
80
  try:
73
- result = compiled.execute_sharded(jax_args)
74
- arrays = result.disassemble_into_single_device_arrays()
75
- flat: list[Any] = []
76
- for lst in arrays:
77
- if isinstance(lst, list) and len(lst) == 1:
78
- flat.append(jnp.array(lst[0]))
79
- else:
80
- flat.extend([jnp.array(a) for a in lst])
81
+ # Execute with the new LoadedExecutable interface
82
+ result = compiled.execute(jax_args)
83
+
84
+ # Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
85
+ flat_results, _ = jax.tree_util.tree_flatten(result)
86
+ flat = [TensorValue(np.asarray(item)) for item in flat_results]
87
+
81
88
  return tuple(flat)
82
89
  except Exception as e: # pragma: no cover
83
90
  raise RuntimeError(f"StableHLO execute failed: {e}") from e