mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
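Note: the rename entries above promote every `mplang.v2.*` module to the top-level `mplang` package and remove the `mplang.v1` tree entirely, so downstream imports lose the `v2` segment. A minimal sketch of the implied migration, assuming callers import these modules by the package paths shown in the rename list (the specific modules chosen here are illustrative only):

# Before (0.1.dev269): modules lived under the mplang.v2 namespace
# from mplang.v2.edsl import tracer
# from mplang.v2.dialects import spu

# After (0.1.dev271): the same modules are imported from the top level
from mplang.edsl import tracer
from mplang.dialects import spu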
mplang/v1/kernels/spu.py DELETED
@@ -1,341 +0,0 @@
- # Copyright 2025 Ant Group Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- from dataclasses import dataclass
- from typing import Any, ClassVar
-
- import numpy as np
- import spu.api as spu_api
- import spu.libspu as libspu
-
- from mplang.v1.core import (
-     BOOL,
-     FLOAT32,
-     FLOAT64,
-     INT8,
-     INT16,
-     INT32,
-     INT64,
-     UINT8,
-     UINT16,
-     UINT32,
-     UINT64,
-     DType,
-     PFunction,
- )
- from mplang.v1.kernels.base import cur_kctx, kernel_def
- from mplang.v1.kernels.value import (
-     TensorValue,
-     Value,
-     ValueDecodeError,
-     ValueProtoBuilder,
-     ValueProtoReader,
-     register_value,
- )
- from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
- from mplang.v1.runtime.link_comm import LinkCommunicator
-
-
- def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
-     """Convert SPU shape to numpy tuple."""
-     return tuple(spu_shape.dims)
-
-
- def dtype_spu_to_mpl(spu_dtype: libspu.DataType) -> DType:
-     """Convert libspu.DataType to MPLang DType."""
-     MAP = {
-         libspu.DataType.DT_F32: FLOAT32,
-         libspu.DataType.DT_F64: FLOAT64,
-         libspu.DataType.DT_I1: BOOL,
-         libspu.DataType.DT_I8: INT8,
-         libspu.DataType.DT_U8: UINT8,
-         libspu.DataType.DT_I16: INT16,
-         libspu.DataType.DT_U16: UINT16,
-         libspu.DataType.DT_I32: INT32,
-         libspu.DataType.DT_U32: UINT32,
-         libspu.DataType.DT_I64: INT64,
-         libspu.DataType.DT_U64: UINT64,
-     }
-     return MAP[spu_dtype]
-
-
- @register_value
- @dataclass
- class SpuValue(Value):
-     """SPU value container for secure computation (Value type)."""
-
-     KIND: ClassVar[str] = "mplang.spu.SpuValue"
-     WIRE_VERSION: ClassVar[int] = 1
-
-     shape: tuple[int, ...]
-     dtype: DType  # Now uses MPLang's unified DType
-     vtype: libspu.Visibility
-     share: libspu.Share
-
-     def __repr__(self) -> str:
-         return f"SpuValue({self.shape},{self.dtype},{self.vtype})"
-
-     def to_proto(self) -> _value_pb2.ValueProto:
-         """Serialize SpuValue to wire format.
-
-         libspu.Share has two attributes:
-         - meta: bytes (protobuf serialized metadata)
-         - share_chunks: list[bytes] (the actual secret share data)
-
-         Strategy: Store shape/dtype/vtype in runtime_attrs, concatenate share.meta + all chunks in payload.
-         """
-         # Store metadata in runtime_attrs; keep chunk lengths for payload splitting
-         chunk_lengths = [len(chunk) for chunk in self.share.share_chunks]
-
-         # Payload contains only share chunks (meta stored in attrs)
-         payload = b""
-         for chunk in self.share.share_chunks:
-             payload += chunk
-
-         return (
-             ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
-             .set_attr("shape", list(self.shape))
-             .set_attr("dtype", self.dtype.name)  # Serialize DType name
-             .set_attr("vtype", int(self.vtype))
-             .set_attr("share_meta", self.share.meta)
-             .set_attr("chunk_lengths", chunk_lengths)
-             .set_payload(payload)
-             .build()
-         )
-
-     @classmethod
-     def from_proto(cls, proto: _value_pb2.ValueProto) -> SpuValue:
-         """Deserialize SpuValue from wire format."""
-         reader = ValueProtoReader(proto)
-         if reader.version != cls.WIRE_VERSION:
-             raise ValueDecodeError(f"Unsupported SpuValue version {reader.version}")
-
-         # Read metadata from runtime_attrs
-         shape = tuple(reader.get_attr("shape"))
-         dtype_name = reader.get_attr("dtype")
-         # Reconstruct DType from serialized name (numpy dtype string)
-         dtype = DType.from_numpy(dtype_name)
-         vtype = libspu.Visibility(reader.get_attr("vtype"))
-         share_meta = reader.get_attr("share_meta")
-         chunk_lengths = reader.get_attr("chunk_lengths")
-
-         # Parse payload: [chunk_0][chunk_1]...
-         payload = reader.payload
-         offset = 0
-
-         share_chunks: list[bytes] = []
-         for chunk_len in chunk_lengths:
-             chunk = payload[offset : offset + chunk_len]
-             offset += chunk_len
-             share_chunks.append(chunk)
-
-         # Reconstruct libspu.Share
-         share = libspu.Share()
-         share.meta = share_meta
-         share.share_chunks = share_chunks
-
-         return cls(
-             shape=shape,
-             dtype=dtype,
-             vtype=vtype,
-             share=share,
-         )
-
-
- def _get_spu_config_and_world() -> tuple[libspu.RuntimeConfig, int]:
-     kctx = cur_kctx()
-     cfg = kctx.runtime.get_state("spu.config")
-     world = kctx.runtime.get_state("spu.world")
-     if cfg is None or world is None:
-         raise RuntimeError("SPU kernel state not initialized (config/world)")
-     return cfg, int(world)
-
-
- def _register_spu_env(
-     config: libspu.RuntimeConfig, world_size: int, link_ctx: LinkCommunicator | None
- ) -> None:
-     """Register SPU config/world/link inside current kernel context.
-
-     Idempotent: if config/world already set, they must match; link is recorded per rank.
-     This replaces previous global fallback seeding logic.
-     """
-     kctx = cur_kctx()
-     prev_cfg = kctx.runtime.get_state("spu.config")
-     prev_world = kctx.runtime.get_state("spu.world")
-     if prev_cfg is None:
-         kctx.runtime.set_state("spu.config", config)
-         kctx.runtime.set_state("spu.world", world_size)
-     else:
-         # libspu RuntimeConfig may not implement __eq__; compare serialized repr
-         same_cfg = (
-             prev_cfg.SerializeToString() == config.SerializeToString()  # type: ignore[attr-defined]
-             if hasattr(prev_cfg, "SerializeToString")
-             and hasattr(config, "SerializeToString")
-             else prev_cfg == config
-         )
-         if not (same_cfg and prev_world == world_size):
-             raise RuntimeError("Conflicting SPU env registration")
-     # Store single link per runtime (one runtime per rank)
-     if link_ctx is not None:
-         kctx.runtime.set_state("spu.link", link_ctx)
-
-
- @kernel_def("spu.seed_env")
- def _spu_seed_env(pfunc: PFunction, *args: Any) -> Any:
-     """Backend kernel to seed SPU environment.
-
-     NOTE: This is a control-plane style operation (side-effect: installs SPU
-     config/link into the per-runtime state pocket) rather than a pure data
-     transformation. It remains a kernel temporarily for minimal surface
-     changes during the backend deglobalization refactor. Callers MUST invoke
-     it explicitly via `runtime.run_kernel(seed_pfunc, [])`, never through
-     `Evaluator.evaluate` (fast-path removed) to keep IR evaluation semantics
-     clean. A future cleanup may promote this to a dedicated runtime helper
-     (e.g. `seed_spu_env(runtime, config, world, link)`), at which point this
-     kernel can be deprecated.
-
-     Required attrs: config (RuntimeConfig), world (int)
-     Optional attr: link (LinkCommunicator or None)
-     """
-     cfg = pfunc.attrs.get("config")
-     world = pfunc.attrs.get("world")
-     link_ctx = pfunc.attrs.get("link", None)
-     if cfg is None or world is None:
-         raise ValueError("spu.seed_env requires 'config' and 'world' attrs")
-     _register_spu_env(cfg, int(world), link_ctx)
-     return None
-
-
- @kernel_def("spu.makeshares")
- def _spu_makeshares(pfunc: PFunction, tensor: TensorValue) -> tuple[SpuValue, ...]:
-     """Create SPU shares from input TensorValue data."""
-     visibility_value = pfunc.attrs.get("visibility", libspu.Visibility.VIS_SECRET.value)
-     if isinstance(visibility_value, int):
-         visibility = libspu.Visibility(visibility_value)
-     else:
-         visibility = visibility_value
-
-     arg = tensor.to_numpy()
-     cfg, world = _get_spu_config_and_world()
-     spu_io = spu_api.Io(world, cfg)
-     shares = spu_io.make_shares(arg, visibility)
-     assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
-     # Store MPLang DType instead of libspu.DataType
-     dtype = DType.from_numpy(arg.dtype)
-     return tuple(
-         SpuValue(
-             shape=arg.shape,
-             dtype=dtype,
-             vtype=visibility,
-             share=share,
-         )
-         for share in shares
-     )
-
-
- @kernel_def("spu.reconstruct")
- def _spu_reconstruct(pfunc: PFunction, *shares: SpuValue) -> TensorValue:
-     """Reconstruct plaintext data from SPU shares."""
-     cfg, world = _get_spu_config_and_world()
-     assert len(shares) == world, f"Expected {world} shares, got {len(shares)}"
-     for i, share in enumerate(shares):
-         if not isinstance(share, SpuValue):
-             raise ValueError(
-                 f"Input {i} must be SpuValue, got {type(share)}. Reconstruction requires SPU shares as input."
-             )
-     spu_args: list[SpuValue] = list(shares)  # type: ignore
-     share_payloads = [spu_arg.share for spu_arg in spu_args]
-     spu_io = spu_api.Io(world, cfg)
-     reconstructed = spu_io.reconstruct(share_payloads)
-     base = np.array(reconstructed, copy=False)
-     # Respect semantic dtype/shape recorded on shares (all shares share same meta).
-     semantic_dtype = shares[0].dtype.to_numpy()  # DType now has to_numpy() method
-     semantic_shape = shares[0].shape
-     restored = np.asarray(base, dtype=semantic_dtype).reshape(semantic_shape)
-     return TensorValue(np.array(restored, copy=False))
-
-
- @kernel_def("spu.run_pphlo")
- def _spu_run_mlir(pfunc: PFunction, *args: SpuValue) -> tuple[SpuValue, ...]:
-     """Execute compiled SPU function (spu.run_pphlo) and return SpuValue outputs.
-
-     Participation rule: a rank participates iff its entry in the stored
-     link_ctx list is non-None. This allows us to allocate a world-sized list
-     (indexed by global rank) and simply assign None for non-SPU parties.
-     """
-     if pfunc.fn_type != "spu.run_pphlo":
-         raise ValueError(
-             f"Unsupported format: {pfunc.fn_type}. Expected 'spu.run_pphlo'"
-         )
-
-     cfg, _ = _get_spu_config_and_world()
-     kctx = cur_kctx()
-     link_ctx = kctx.runtime.get_state("spu.link")
-     if link_ctx is None:
-         raise RuntimeError("Rank not participating in SPU; no link set via seed_env")
-
-     # Lazy runtime cache under key spu.runtime
-     spu_rt = kctx.runtime.get_state("spu.runtime")
-     if spu_rt is None:
-         spu_rt = spu_api.Runtime(link_ctx.get_lctx(), cfg)
-         kctx.runtime.set_state("spu.runtime", spu_rt)
-
-     # Validate that all inputs are SpuValue objects
-     for i, arg in enumerate(args):
-         if not isinstance(arg, SpuValue):
-             raise ValueError(
-                 f"Input {i} must be SpuValue, got {type(arg)}. In real SPU environments, all inputs must be SpuValue objects."
-             )
-
-     # Cast for type checking (we've validated above)
-     spu_args: list[SpuValue] = list(args)  # type: ignore
-
-     # Reconstruct SPU executable from MLIR code and metadata
-     if pfunc.fn_text is None:
-         raise ValueError("PFunction does not contain executable data")
-     if not isinstance(pfunc.fn_text, str):
-         raise ValueError(f"Expected str, got {type(pfunc.fn_text)}")
-
-     # Extract metadata for executable reconstruction
-     attrs: dict[str, Any] = dict(pfunc.attrs or {})
-     input_names = attrs.get("input_names", [])
-     output_names = attrs.get("output_names", [])
-     executable_name = attrs.get("executable_name", pfunc.fn_name)
-
-     # Create executable from MLIR code and metadata
-     executable = libspu.Executable(
-         name=executable_name,
-         input_names=input_names,
-         output_names=output_names,
-         code=pfunc.fn_text,
-     )
-
-     # Set input variables in SPU runtime
-     for idx, spu_arg in enumerate(spu_args):
-         spu_rt.set_var(input_names[idx], spu_arg.share)
-     spu_rt.run(executable)
-     shares = [spu_rt.get_var(out_name) for out_name in output_names]
-     metas = [spu_rt.get_var_meta(out_name) for out_name in output_names]
-     results: list[SpuValue] = [
-         SpuValue(
-             shape=shape_spu_to_np(meta.shape),
-             dtype=dtype_spu_to_mpl(meta.data_type),
-             vtype=meta.visibility,
-             share=shares[idx],
-         )
-         for idx, meta in enumerate(metas)
-     ]
-     return tuple(results)
mplang/v1/kernels/sql_duckdb.py DELETED
@@ -1,44 +0,0 @@
- # Copyright 2025 Ant Group Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- from mplang.v1.core import PFunction
- from mplang.v1.kernels.base import kernel_def
- from mplang.v1.kernels.value import TableValue
-
-
- @kernel_def("duckdb.run_sql")
- def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
-     import duckdb
-
-     # TODO: maybe we could translate the sql to duckdb dialect
-     # instead of raising an exception
-     if pfunc.attrs.get("dialect") != "duckdb":
-         raise ValueError("duckdb.run_sql must have dialect=duckdb attr")
-
-     conn = duckdb.connect(":memory:")
-     if args:
-         in_names = pfunc.attrs.get("in_names")
-         if in_names is None:
-             raise ValueError("duckdb sql missing in_names attr")
-         for arg, name in zip(args, in_names, strict=True):
-             # Use Arrow directly for zero-copy data transfer
-             arrow_table = arg.to_arrow()
-             conn.register(name, arrow_table)
-     # Fetch result as Arrow table for consistency
-     if pfunc.fn_text is None:
-         raise ValueError("SQL function text is None")
-     res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
-     return TableValue(res_arrow)
mplang/v1/kernels/stablehlo.py DELETED
@@ -1,90 +0,0 @@
- # Copyright 2025 Ant Group Co., Ltd.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- from typing import Any
-
- import jax
- import jax.extend as jxt
- import jax.numpy as jnp
- import numpy as np
- from jax._src import compiler
-
- from mplang.v1.core import PFunction
- from mplang.v1.kernels.base import cur_kctx, kernel_def
- from mplang.v1.kernels.value import TensorValue
-
-
- @kernel_def("mlir.stablehlo")
- def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
-     if pfunc.fn_type != "mlir.stablehlo":
-         raise ValueError("stablehlo kernel received wrong fn_type")
-
-     mlir_text = pfunc.fn_text
-     if mlir_text is None:
-         raise ValueError("StableHLO kernel missing fn_text")
-     if isinstance(mlir_text, bytes):
-         mlir_text = mlir_text.decode("utf-8")
-
-     # Flat-key compile cache: stablehlo.compile_cache.<hash>
-     ctx = cur_kctx()
-     rt = ctx.runtime
-     import hashlib
-
-     h = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest()[:16]
-     key = f"stablehlo.compile_cache.{h}"
-     compiled = rt.get_state(key)
-     if compiled is None:
-         client = jxt.backend.get_backend()
-         compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
-
-         try:
-             compiled = client.compile_and_load(
-                 mlir_text, client.devices(), compile_options
-             )
-         except Exception as e:  # pragma: no cover
-             raise RuntimeError(f"StableHLO compile failed: {e}") from e
-         rt.set_state(key, compiled)
-
-     # Handle JAX's unused parameter elimination via arg_keep_map
-     runtime_args = args
-     if "arg_keep_map" in pfunc.attrs:
-         keep_indices = pfunc.attrs["arg_keep_map"]
-         # Filter out arguments that were eliminated by JAX during compilation
-         runtime_args = tuple(args[i] for i in keep_indices)
-
-     tensor_args: list[TensorValue] = []
-     for idx, arg in enumerate(runtime_args):
-         if not isinstance(arg, TensorValue):
-             raise TypeError(
-                 f"StableHLO kernel expects TensorValue inputs, got {type(arg).__name__} at position {idx}"
-             )
-         tensor_args.append(arg)
-
-     jax_args = [
-         jax.device_put(jnp.asarray(tensor.to_numpy())) for tensor in tensor_args
-     ]
-
-     try:
-         # Execute with the new LoadedExecutable interface
-         result = compiled.execute(jax_args)
-
-         # Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
-         flat_results, _ = jax.tree_util.tree_flatten(result)
-         flat = [TensorValue(np.asarray(item)) for item in flat_results]
-
-         return tuple(flat)
-     except Exception as e:  # pragma: no cover
-         raise RuntimeError(f"StableHLO execute failed: {e}") from e