mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,157 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Core components for multi-party computation.
17
+
18
+ This package provides the fundamental building blocks for multi-party computation,
19
+ including type systems, tracing mechanisms, and interpreter contexts.
20
+ """
21
+
22
+ # Core type system
23
+ # Communication interfaces & core symbols
24
+ from mplang.v1.core.cluster import ClusterSpec, Device, Node, RuntimeInfo
25
+ from mplang.v1.core.comm import (
26
+ CollectiveMixin,
27
+ CommunicatorBase,
28
+ ICollective,
29
+ ICommunicator,
30
+ )
31
+ from mplang.v1.core.context_mgr import cur_ctx, set_ctx, with_ctx
32
+ from mplang.v1.core.dtypes import (
33
+ BINARY,
34
+ BOOL,
35
+ COMPLEX64,
36
+ COMPLEX128,
37
+ DATE,
38
+ DECIMAL,
39
+ FLOAT16,
40
+ FLOAT32,
41
+ FLOAT64,
42
+ INT8,
43
+ INT16,
44
+ INT32,
45
+ INT64,
46
+ INTERVAL,
47
+ JSON,
48
+ STRING,
49
+ TIME,
50
+ TIMESTAMP,
51
+ UINT8,
52
+ UINT16,
53
+ UINT32,
54
+ UINT64,
55
+ UUID,
56
+ DType,
57
+ )
58
+ from mplang.v1.core.interp import InterpContext, InterpVar
59
+ from mplang.v1.core.mask import Mask
60
+ from mplang.v1.core.mpir import IrReader, IrWriter
61
+ from mplang.v1.core.mpobject import MPContext, MPObject
62
+ from mplang.v1.core.mptype import MPType, Rank, Shape
63
+ from mplang.v1.core.pfunc import PFunction, get_fn_name
64
+
65
+ # Import primitive-dependent symbols at the end to avoid cycles when frontend ops
66
+ # import from the core facade during package initialization.
67
+ from mplang.v1.core.primitive import (
68
+ builtin_function,
69
+ function,
70
+ pconv,
71
+ peval,
72
+ pmask,
73
+ pshfl,
74
+ pshfl_s,
75
+ psize,
76
+ uniform_cond,
77
+ while_loop,
78
+ )
79
+ from mplang.v1.core.table import TableLike, TableType
80
+ from mplang.v1.core.tensor import ScalarType, TensorLike, TensorType
81
+ from mplang.v1.core.tracer import (
82
+ TraceContext,
83
+ TracedFunction,
84
+ TraceVar,
85
+ VarNamer,
86
+ trace,
87
+ )
88
+
89
+ __all__ = [
90
+ "BINARY",
91
+ "BOOL",
92
+ "COMPLEX64",
93
+ "COMPLEX128",
94
+ "DATE",
95
+ "DECIMAL",
96
+ "FLOAT16",
97
+ "FLOAT32",
98
+ "FLOAT64",
99
+ "INT8",
100
+ "INT16",
101
+ "INT32",
102
+ "INT64",
103
+ "INTERVAL",
104
+ "JSON",
105
+ "STRING",
106
+ "TIME",
107
+ "TIMESTAMP",
108
+ "UINT8",
109
+ "UINT16",
110
+ "UINT32",
111
+ "UINT64",
112
+ "UUID",
113
+ "ClusterSpec",
114
+ "CollectiveMixin",
115
+ "CommunicatorBase",
116
+ "DType",
117
+ "Device",
118
+ "ICollective",
119
+ "ICommunicator",
120
+ "InterpContext",
121
+ "InterpVar",
122
+ "IrReader",
123
+ "IrWriter",
124
+ "MPContext",
125
+ "MPObject",
126
+ "MPType",
127
+ "Mask",
128
+ "Node",
129
+ "PFunction",
130
+ "Rank",
131
+ "RuntimeInfo",
132
+ "ScalarType",
133
+ "Shape",
134
+ "TableLike",
135
+ "TableType",
136
+ "TensorLike",
137
+ "TensorType",
138
+ "TraceContext",
139
+ "TraceVar",
140
+ "TracedFunction",
141
+ "VarNamer",
142
+ "builtin_function",
143
+ "cur_ctx",
144
+ "function",
145
+ "get_fn_name",
146
+ "pconv",
147
+ "peval",
148
+ "pmask",
149
+ "pshfl",
150
+ "pshfl_s",
151
+ "psize",
152
+ "set_ctx",
153
+ "trace",
154
+ "uniform_cond",
155
+ "while_loop",
156
+ "with_ctx",
157
+ ]
@@ -20,6 +20,7 @@ MPLang cluster configuration.
20
20
  from __future__ import annotations
21
21
 
22
22
  from dataclasses import dataclass, field
23
+ from functools import cached_property
23
24
  from typing import Any
24
25
 
25
26
 
@@ -132,11 +133,11 @@ class ClusterSpec:
132
133
  "which is not defined in nodes"
133
134
  )
134
135
 
135
- # ensure local devices have exactly one member
136
+ # ensure ppu devices have exactly one member
136
137
  for device in self.devices.values():
137
- if device.kind.lower() == "local" and len(device.members) != 1:
138
+ if device.kind.lower() == "ppu" and len(device.members) != 1:
138
139
  raise ValueError(
139
- f"Local device '{device.name}' must have exactly one member"
140
+ f"PPU device '{device.name}' must have exactly one member"
140
141
  )
141
142
 
142
143
  def get_node(self, name: str) -> Node:
@@ -169,6 +170,16 @@ class ClusterSpec:
169
170
  },
170
171
  }
171
172
 
173
+ @cached_property
174
+ def endpoints(self) -> list[str]:
175
+ eps: list[str] = []
176
+ for n in sorted(
177
+ self.nodes.values(),
178
+ key=lambda x: x.rank, # type: ignore[attr-defined]
179
+ ):
180
+ eps.append(n.endpoint)
181
+ return eps
182
+
172
183
  @classmethod
173
184
  def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
174
185
  """Parses a raw config dictionary and returns a validated ClusterSpec."""
@@ -237,12 +248,13 @@ class ClusterSpec:
237
248
  world_size: int,
238
249
  *,
239
250
  endpoints: list[str] | None = None,
251
+ spu_world_size: int | None = None,
240
252
  spu_protocol: str = "SEMI2K",
241
253
  spu_field: str = "FM128",
242
254
  runtime_version: str = "simulated",
243
255
  runtime_platform: str = "simulated",
244
256
  op_bindings: list[dict[str, str]] | None = None,
245
- enable_local_device: bool = True,
257
+ enable_ppu_device: bool = True,
246
258
  enable_spu_device: bool = True,
247
259
  ) -> ClusterSpec:
248
260
  """Convenience constructor used heavily in tests.
@@ -263,8 +275,8 @@ class ClusterSpec:
263
275
  op_bindings:
264
276
  Optional list of length ``world_size`` supplying per-node op_bindings
265
277
  override dicts (defaults to empty dicts).
266
- enable_local_device:
267
- If True (default), create one ``local_{rank}`` device per node.
278
+ enable_ppu_device:
279
+ If True (default), create one ``P{rank}`` PPU device per node.
268
280
  enable_spu_device:
269
281
  If True (default) create a shared SPU device named ``SP0``.
270
282
  """
@@ -282,9 +294,9 @@ class ClusterSpec:
282
294
  f"{len(op_bindings)} != {world_size}"
283
295
  )
284
296
 
285
- if not enable_local_device and not enable_spu_device:
297
+ if not enable_ppu_device and not enable_spu_device:
286
298
  raise ValueError(
287
- "At least one of enable_local_device or enable_spu_device must be True"
299
+ "At least one of enable_ppu_device or enable_spu_device must be True"
288
300
  )
289
301
 
290
302
  nodes: dict[str, Node] = {}
@@ -303,21 +315,25 @@ class ClusterSpec:
303
315
  )
304
316
 
305
317
  devices: dict[str, Device] = {}
306
- # Optional per-node local devices
307
- if enable_local_device:
318
+ # Optional per-node PPU devices
319
+ if enable_ppu_device:
308
320
  for i in range(world_size):
309
- devices[f"local_{i}"] = Device(
310
- name=f"local_{i}",
311
- kind="local",
321
+ devices[f"P{i}"] = Device(
322
+ name=f"P{i}",
323
+ kind="ppu",
312
324
  members=[nodes[f"node{i}"]],
313
325
  )
314
326
 
315
327
  # Shared SPU device
316
328
  if enable_spu_device:
329
+ if spu_world_size is None:
330
+ spu_world_size = world_size
331
+ spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]
332
+
317
333
  devices["SP0"] = Device(
318
334
  name="SP0",
319
335
  kind="SPU",
320
- members=list(nodes.values()),
336
+ members=spu_members,
321
337
  config={
322
338
  "protocol": spu_protocol,
323
339
  "field": spu_field,
@@ -19,7 +19,7 @@ import threading
19
19
  from abc import ABC, abstractmethod
20
20
  from typing import Any
21
21
 
22
- from mplang.core.mask import Mask
22
+ from mplang.v1.core.mask import Mask
23
23
 
24
24
 
25
25
  class ICommunicator(ABC):
@@ -48,6 +48,10 @@ class ICommunicator(ABC):
48
48
  def recv(self, frm: int, key: str) -> Any:
49
49
  """Receive data from peer with the given key"""
50
50
 
51
+ @abstractmethod
52
+ def onSent(self, frm: int, key: str, data: Any) -> None:
53
+ """Called when a key is sent to self"""
54
+
51
55
 
52
56
  class ICollective(ABC):
53
57
  """Interface for collective communication"""
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  # Imported only for typing to avoid import cycles at runtime.
23
- from mplang.core.mpobject import MPContext
23
+ from mplang.v1.core.mpobject import MPContext
24
24
 
25
25
  # The global working context.
26
26
  _g_ctx: MPContext | None = None
@@ -21,7 +21,8 @@ import numpy as np
21
21
 
22
22
  try:
23
23
  # Check if JAX is available
24
- import jax # noqa: F401
24
+ import jax
25
+ import jax.numpy as jnp
25
26
 
26
27
  _JAX_AVAILABLE = True
27
28
  except ImportError:
@@ -140,6 +141,10 @@ class DType:
140
141
  """Convert from JAX dtype to custom DType."""
141
142
  if not _JAX_AVAILABLE:
142
143
  raise ImportError("JAX is not available")
144
+ # Special handling for PRNG KeyTy: <class jax._src.prng.KeyTy>
145
+ if jnp.issubdtype(jax_dtype, jax.dtypes.prng_key):
146
+ return cls.from_numpy(np.uint32)
147
+
143
148
  # JAX dtypes are essentially NumPy dtypes
144
149
  return cls.from_numpy(jax_dtype)
145
150
 
@@ -172,6 +177,13 @@ class DType:
172
177
  # TypeError if it's not a pandas dtype we can handle
173
178
  pass
174
179
 
180
+ try:
181
+ return cls._from_arrow_dtype(dtype_like)
182
+ except (ImportError, TypeError):
183
+ # ImportError if pyarrow is not installed
184
+ # TypeError if it's not a pyarrow dtype we can handle
185
+ pass
186
+
175
187
  if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
176
188
  return cls.from_python_type(dtype_like)
177
189
  elif hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
@@ -220,6 +232,37 @@ class DType:
220
232
 
221
233
  raise TypeError(f"Unsupported pandas dtype: {dtype_like}")
222
234
 
235
+ @classmethod
236
+ def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
237
+ try:
238
+ import pyarrow as pa
239
+ except ImportError:
240
+ raise ImportError("pyarrow not available") from None
241
+
242
+ if not isinstance(dtype_like, pa.DataType):
243
+ raise TypeError("Not a pyarrow dtype")
244
+
245
+ ARROW_DTYPE_MAPPING = {
246
+ pa.bool_(): BOOL,
247
+ pa.int8(): INT8,
248
+ pa.int16(): INT16,
249
+ pa.int32(): INT32,
250
+ pa.int64(): INT64,
251
+ pa.uint8(): UINT8,
252
+ pa.uint16(): UINT16,
253
+ pa.uint32(): UINT32,
254
+ pa.uint64(): UINT64,
255
+ pa.float16(): FLOAT16,
256
+ pa.float32(): FLOAT32,
257
+ pa.float64(): FLOAT64,
258
+ pa.string(): STRING,
259
+ pa.large_string(): STRING,
260
+ }
261
+ result = ARROW_DTYPE_MAPPING.get(dtype_like)
262
+ if result is not None:
263
+ return result
264
+ raise TypeError(f"Unsupported arrow dtype: {dtype_like}")
265
+
223
266
  def to_numpy(self) -> np.dtype:
224
267
  """Convert custom DType to NumPy dtype."""
225
268
  return np.dtype(self.name)
@@ -228,7 +271,6 @@ class DType:
228
271
  """Convert custom DType to JAX dtype."""
229
272
  if not _JAX_AVAILABLE:
230
273
  raise ImportError("JAX is not available")
231
- import jax.numpy as jnp
232
274
 
233
275
  return jnp.dtype(self.name)
234
276
 
@@ -20,7 +20,7 @@ multi-party computation graphs using the visitor pattern.
20
20
  """
21
21
 
22
22
  # Core expression types
23
- from mplang.core.expr.ast import (
23
+ from mplang.v1.core.expr.ast import (
24
24
  AccessExpr,
25
25
  CallExpr,
26
26
  CondExpr,
@@ -36,12 +36,12 @@ from mplang.core.expr.ast import (
36
36
  )
37
37
 
38
38
  # Built-in evaluator engines
39
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
40
- from mplang.core.expr.printer import Printer
41
- from mplang.core.expr.transformer import ExprTransformer
39
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
40
+ from mplang.v1.core.expr.printer import Printer
41
+ from mplang.v1.core.expr.transformer import ExprTransformer
42
42
 
43
43
  # Utility functions
44
- from mplang.core.expr.utils import (
44
+ from mplang.v1.core.expr.utils import (
45
45
  deduce_mask,
46
46
  ensure_scalar,
47
47
  ensure_tensorlist_equal,
@@ -49,8 +49,8 @@ from mplang.core.expr.utils import (
49
49
  )
50
50
 
51
51
  # Visitor pattern interface
52
- from mplang.core.expr.visitor import ExprVisitor
53
- from mplang.core.expr.walk import walk, walk_dataflow, walk_structural
52
+ from mplang.v1.core.expr.visitor import ExprVisitor
53
+ from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
54
54
 
55
55
  __all__ = [
56
56
  "AccessExpr",
@@ -26,15 +26,15 @@ import logging
26
26
  from abc import ABC, abstractmethod
27
27
  from typing import TYPE_CHECKING, Any
28
28
 
29
- from mplang.core.expr.utils import deduce_mask
30
- from mplang.core.mask import Mask
31
- from mplang.core.mptype import MPType, Rank
32
- from mplang.core.pfunc import PFunction
33
- from mplang.core.table import TableType
34
- from mplang.core.tensor import TensorType
29
+ from mplang.v1.core.expr.utils import deduce_mask
30
+ from mplang.v1.core.mask import Mask
31
+ from mplang.v1.core.mptype import MPType, Rank
32
+ from mplang.v1.core.pfunc import PFunction
33
+ from mplang.v1.core.table import TableType
34
+ from mplang.v1.core.tensor import TensorType
35
35
 
36
36
  if TYPE_CHECKING:
37
- from mplang.core.expr.visitor import ExprVisitor
37
+ from mplang.v1.core.expr.visitor import ExprVisitor
38
38
 
39
39
 
40
40
  class Expr(ABC):
@@ -286,8 +286,8 @@ class ConvExpr(Expr):
286
286
  # Validate dtype / shape consistency.
287
287
  first = types[0]
288
288
  for c in types[1:]:
289
- if (c.dtype, c.shape) != (first.dtype, first.shape):
290
- raise TypeError(f"Inconsistent dtype/shape in pconv: {c} vs {first}")
289
+ if c.raw_type() != first.raw_type():
290
+ raise TypeError(f"Inconsistent type in pconv: {c} vs {first}")
291
291
 
292
292
  # Deduce the pmask by intersecting all pmasks.
293
293
  pmasks = [t.pmask for t in types]
@@ -316,7 +316,7 @@ class ConvExpr(Expr):
316
316
  else:
317
317
  out_pmask = None
318
318
 
319
- return [MPType.tensor(first.dtype, first.shape, out_pmask, **first.attrs)]
319
+ return [MPType(first.raw_type(), out_pmask, first.attrs)]
320
320
 
321
321
  def accept(self, visitor: ExprVisitor) -> Any:
322
322
  return visitor.visit_conv(self)
@@ -398,9 +398,7 @@ class ShflSExpr(Expr):
398
398
  def _compute_mptypes(self) -> list[MPType]:
399
399
  # The types are the same as the source value, but with a new pmask.
400
400
  src_type = self.src_val.mptype
401
- return [
402
- MPType.tensor(src_type.dtype, src_type.shape, self.pmask, **src_type.attrs)
403
- ]
401
+ return [MPType(src_type._type, self.pmask, src_type.attrs)]
404
402
 
405
403
  def accept(self, visitor: ExprVisitor) -> Any:
406
404
  return visitor.visit_shfl_s(self)
@@ -528,8 +526,9 @@ class FuncDefExpr(Expr):
528
526
  class CallExpr(Expr):
529
527
  """Expression for function call."""
530
528
 
531
- def __init__(self, fn: FuncDefExpr, args: list[Expr]):
529
+ def __init__(self, name: str, fn: FuncDefExpr, args: list[Expr]):
532
530
  super().__init__()
531
+ self.name = name
533
532
  self.fn = fn
534
533
  self.args = args
535
534
 
@@ -27,8 +27,8 @@ from __future__ import annotations
27
27
  from dataclasses import dataclass
28
28
  from typing import Any, Protocol
29
29
 
30
- from mplang.core.comm import ICommunicator
31
- from mplang.core.expr.ast import (
30
+ from mplang.v1.core.comm import ICommunicator
31
+ from mplang.v1.core.expr.ast import (
32
32
  AccessExpr,
33
33
  CallExpr,
34
34
  CondExpr,
@@ -42,11 +42,12 @@ from mplang.core.expr.ast import (
42
42
  VariableExpr,
43
43
  WhileExpr,
44
44
  )
45
- from mplang.core.expr.visitor import ExprVisitor
46
- from mplang.core.expr.walk import walk_dataflow
47
- from mplang.core.mask import Mask
48
- from mplang.core.pfunc import PFunction
49
- from mplang.kernels.context import RuntimeContext
45
+ from mplang.v1.core.expr.visitor import ExprVisitor
46
+ from mplang.v1.core.expr.walk import walk_dataflow
47
+ from mplang.v1.core.mask import Mask
48
+ from mplang.v1.core.pfunc import PFunction
49
+ from mplang.v1.kernels.context import RuntimeContext
50
+ from mplang.v1.kernels.value import Value
50
51
 
51
52
 
52
53
  class IEvaluator(Protocol):
@@ -149,12 +150,12 @@ class EvalSemantic:
149
150
  def _as_optional_int(val: Any) -> int | None:
150
151
  """Convert a value to int if possible, preserving None.
151
152
 
152
- Handles Python ints, numpy scalars with .item(), and None.
153
+ Handles Python ints, floats, numpy scalar types (e.g., np.int32, np.float64), and None.
154
+ Uses int(val) for conversion which works with numpy scalars via __int__().
153
155
  """
156
+ val = EvalSemantic._unwrap_value(val)
154
157
  if val is None:
155
158
  return None
156
- if hasattr(val, "item"):
157
- return int(val.item())
158
159
  return int(val)
159
160
 
160
161
  def _simple_allgather(self, value: Any) -> list[Any]:
@@ -167,6 +168,7 @@ class EvalSemantic:
167
168
  Returns a list of length world_size with entries ordered by rank.
168
169
  """
169
170
  ws = self.comm.world_size
171
+ value = self._unwrap_value(value)
170
172
  # Trivial fast-path
171
173
  if ws == 1:
172
174
  return [value]
@@ -185,7 +187,12 @@ class EvalSemantic:
185
187
 
186
188
  def _verify_uniform_predicate(self, pred: Any) -> None:
187
189
  # Runtime uniformity check (O(P^2) send/recv emulation).
188
- vals = self._simple_allgather(bool(pred))
190
+ # Use Value.to_bool() if available, otherwise unwrap and convert
191
+ if isinstance(pred, Value):
192
+ pred_bool = pred.to_bool()
193
+ else:
194
+ pred_bool = bool(self._unwrap_value(pred))
195
+ vals = self._simple_allgather(pred_bool)
189
196
  if not vals:
190
197
  raise ValueError("uniform_cond: empty gather for predicate")
191
198
  first = vals[0]
@@ -209,13 +216,36 @@ class EvalSemantic:
209
216
  assert len(cond_result) == 1, (
210
217
  f"Condition function must return a single value, got {cond_result}"
211
218
  )
212
- cond_value = cond_result[0]
213
- if cond_value is None:
219
+ cond_val = cond_result[0]
220
+ if cond_val is None:
214
221
  raise RuntimeError(
215
222
  "while_loop condition produced None on rank "
216
223
  f"{self.rank}; ensure the predicate yields a boolean for every party."
217
224
  )
218
- return cond_value
225
+ # Use Value.to_bool() if available for cleaner conversion
226
+ if isinstance(cond_val, Value):
227
+ return cond_val.to_bool()
228
+ return bool(self._unwrap_value(cond_val))
229
+
230
+ @staticmethod
231
+ def _unwrap_value(value: Any) -> Any:
232
+ """Convert Value payloads to numpy/python equivalents when possible."""
233
+ if value is None:
234
+ return None
235
+
236
+ if isinstance(value, Value):
237
+ # Try to_numpy first for broader compatibility
238
+ to_numpy = getattr(value, "to_numpy", None)
239
+ if callable(to_numpy):
240
+ arr = to_numpy()
241
+ import numpy as np
242
+
243
+ if isinstance(arr, np.ndarray):
244
+ if arr.size == 1:
245
+ return arr.item()
246
+ return arr
247
+ return arr
248
+ return value
219
249
 
220
250
 
221
251
  class RecursiveEvaluator(EvalSemantic, ExprVisitor):
@@ -296,19 +326,25 @@ class RecursiveEvaluator(EvalSemantic, ExprVisitor):
296
326
  * Add optional static uniform inference (data provenance) to elide the
297
327
  runtime check when predicate uniformity is provable at trace time.
298
328
  """
299
- pred = self._value(expr.pred)
300
- if pred is None:
329
+ pred_val = self._value(expr.pred)
330
+ if pred_val is None:
301
331
  return [None] * len(expr.mptypes)
302
332
 
303
333
  if expr.verify_uniform:
304
- self._verify_uniform_predicate(pred)
334
+ self._verify_uniform_predicate(pred_val)
335
+
336
+ # Convert to bool using Value.to_bool() if available
337
+ if isinstance(pred_val, Value):
338
+ pred = pred_val.to_bool()
339
+ else:
340
+ pred = bool(self._unwrap_value(pred_val))
305
341
 
306
342
  # Only evaluate selected branch locally
307
- if pred:
308
- then_call = CallExpr(expr.then_fn, expr.args)
343
+ if bool(pred):
344
+ then_call = CallExpr("then", expr.then_fn, expr.args)
309
345
  return self._values(then_call)
310
346
  else:
311
- else_call = CallExpr(expr.else_fn, expr.args)
347
+ else_call = CallExpr("else", expr.else_fn, expr.args)
312
348
  return self._values(else_call)
313
349
 
314
350
  def visit_call(self, expr: CallExpr) -> Any:
@@ -435,15 +471,20 @@ class IterativeEvaluator(EvalSemantic):
435
471
  res = self._iter_eval_graph(node.fn.body, {**env, **sub_env})
436
472
  symbols[id(node)] = res
437
473
  elif isinstance(node, CondExpr):
438
- pred_v = self._first(symbols[id(node.pred)])
474
+ pred_val = self._first(symbols[id(node.pred)])
439
475
  arg_vals = [self._first(symbols[id(a)]) for a in node.args]
440
- if pred_v is None:
476
+ if pred_val is None:
441
477
  symbols[id(node)] = [None] * len(node.mptypes)
442
478
  else:
443
479
  # Optional uniform verification identical to recursive evaluator (DRY helper).
444
480
  if node.verify_uniform:
445
- self._verify_uniform_predicate(pred_v)
446
- if bool(pred_v):
481
+ self._verify_uniform_predicate(pred_val)
482
+ # Convert to bool using Value.to_bool() if available
483
+ if isinstance(pred_val, Value):
484
+ pred = pred_val.to_bool()
485
+ else:
486
+ pred = bool(self._unwrap_value(pred_val))
487
+ if pred:
447
488
  sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True))
448
489
  res = self._iter_eval_graph(
449
490
  node.then_fn.body, {**env, **sub_env}