mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -21,20 +21,33 @@ import base64
21
21
  import logging
22
22
  from typing import Any
23
23
 
24
- import cloudpickle as pickle
25
24
  import httpx
26
25
 
27
- from mplang.core.comm import CommunicatorBase
26
+ from mplang.v1.core.comm import CommunicatorBase
27
+ from mplang.v1.kernels.value import Value, decode_value, encode_value
28
28
 
29
29
 
30
30
  class HttpCommunicator(CommunicatorBase):
31
31
  def __init__(self, session_name: str, rank: int, endpoints: list[str]):
32
+ # Validate endpoints
33
+ if not endpoints:
34
+ raise ValueError("endpoints cannot be empty")
35
+
36
+ if not all(endpoint for endpoint in endpoints):
37
+ raise ValueError("endpoints cannot contain empty elements")
38
+
32
39
  super().__init__(rank, len(endpoints))
33
40
  self.session_name = session_name
34
- self.endpoints = endpoints
41
+ # Ensure all endpoints have protocol prefix
42
+ self.endpoints = [
43
+ endpoint
44
+ if endpoint.startswith(("http://", "https://"))
45
+ else f"http://{endpoint}"
46
+ for endpoint in endpoints
47
+ ]
35
48
  self._counter = 0
36
49
  logging.info(
37
- f"HttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={endpoints}"
50
+ f"HttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={self.endpoints}"
38
51
  )
39
52
 
40
53
  # override
@@ -44,7 +57,12 @@ class HttpCommunicator(CommunicatorBase):
44
57
  return str(res)
45
58
 
46
59
  def send(self, to: int, key: str, data: Any) -> None:
47
- """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint."""
60
+ """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
61
+
62
+ Supports two modes:
63
+ - SPU channel (key starts with "spu:"): sends raw bytes directly
64
+ - Normal channel: wraps data in Value envelope
65
+ """
48
66
  target_endpoint = self.endpoints[to]
49
67
  url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}"
50
68
  logging.debug(
@@ -52,13 +70,20 @@ class HttpCommunicator(CommunicatorBase):
52
70
  )
53
71
 
54
72
  try:
55
- # Use cloudpickle for robust serialization of complex Python objects
56
- data_bytes = pickle.dumps(data)
57
- data_b64 = base64.b64encode(data_bytes).decode("utf-8")
58
-
59
- request_data = {
60
- "data": data_b64,
61
- }
73
+ # SPU channel mode: send raw bytes directly
74
+ if key.startswith("spu:") and isinstance(data, bytes):
75
+ data_b64 = base64.b64encode(data).decode("utf-8")
76
+ request_data = {"data": data_b64, "is_raw_bytes": True}
77
+ # Normal mode: serialize using Value envelope
78
+ elif isinstance(data, Value):
79
+ data_bytes = encode_value(data)
80
+ data_b64 = base64.b64encode(data_bytes).decode("utf-8")
81
+ request_data = {"data": data_b64}
82
+ else:
83
+ raise TypeError(
84
+ f"Communicator requires Value instance, got {type(data).__name__}. "
85
+ "Wrap data in TensorValue or custom Value subclass."
86
+ )
62
87
 
63
88
  response = httpx.put(url, json=request_data, timeout=60)
64
89
  logging.debug(f"Send response: status={response.status_code}")
@@ -72,14 +97,33 @@ class HttpCommunicator(CommunicatorBase):
72
97
  raise OSError(f"Failed to send data to rank {to}") from e
73
98
 
74
99
  def recv(self, frm: int, key: str) -> Any:
75
- """Wait until the key is set, returns the value. Override to add logging."""
100
+ """Wait until the key is set, returns the value.
101
+
102
+ Supports two modes:
103
+ - SPU channel (key starts with "spu:"): returns raw bytes
104
+ - Normal channel: returns deserialized Value
105
+ """
76
106
  logging.debug(
77
107
  f"Waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}"
78
108
  )
79
- data_b64 = super().recv(frm, key)
109
+ received_data = super().recv(frm, key)
110
+
111
+ # Check if this is raw bytes (SPU channel)
112
+ if isinstance(received_data, dict) and received_data.get("is_raw_bytes"):
113
+ data_bytes = base64.b64decode(received_data["data"])
114
+ logging.debug(
115
+ f"Received raw bytes: from_rank={frm}, to_rank={self._rank}, key={key}, size={len(data_bytes)}"
116
+ )
117
+ return data_bytes
80
118
 
119
+ # Normal mode: deserialize Value envelope
120
+ data_b64 = (
121
+ received_data
122
+ if isinstance(received_data, str)
123
+ else received_data.get("data")
124
+ )
81
125
  data_bytes = base64.b64decode(data_b64)
82
- result = pickle.loads(data_bytes)
126
+ result = decode_value(data_bytes)
83
127
 
84
128
  logging.debug(
85
129
  f"Received data: from_rank={frm}, to_rank={self._rank}, key={key}"
@@ -14,17 +14,24 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import pathlib
18
+ import struct
17
19
  from dataclasses import dataclass
18
20
  from typing import Any
19
21
  from urllib.parse import ParseResult, urlparse
20
22
 
21
23
  import numpy as np
22
- import pandas as pd
23
24
 
24
- from mplang.core.table import TableType
25
- from mplang.core.tensor import TensorType
26
- from mplang.kernels.base import KernelContext
27
- from mplang.utils import table_utils
25
+ from mplang.v1.core import TableLike, TableType, TensorType
26
+ from mplang.v1.kernels.base import KernelContext
27
+ from mplang.v1.kernels.value import (
28
+ TableValue,
29
+ TensorValue,
30
+ Value,
31
+ decode_value,
32
+ encode_value,
33
+ )
34
+ from mplang.v1.utils import table_utils
28
35
 
29
36
 
30
37
  @dataclass(frozen=True)
@@ -136,6 +143,11 @@ def get_provider(scheme: str) -> DataProvider | None:
136
143
 
137
144
 
138
145
  # ---------------- Default Providers ----------------
146
# Leading byte signatures used by FileProvider to sniff on-disk formats.
# mplang Value envelope; on disk it is followed by one VERSION byte.
MAGIC_MPLANG = b"MPLG"
MAGIC_PARQUET = b"PAR1"
MAGIC_ORC = b"ORC"
# NumPy .npy files start with byte 0x93 followed by "NUMPY".
MAGIC_NUMPY = b"\x93" + b"NUMPY"
# Version byte written after MAGIC_MPLANG; other versions are rejected on read.
VERSION = 1
139
151
 
140
152
 
141
153
  class FileProvider(DataProvider):
@@ -148,13 +160,52 @@ class FileProvider(DataProvider):
148
160
  def read(
149
161
  self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
150
162
  ) -> Any:
151
- path = uri.local_path or uri.raw
152
- if isinstance(out_spec, TableType):
153
- with open(path, "rb") as f:
154
- csv_bytes = f.read()
155
- return table_utils.csv_to_dataframe(csv_bytes)
156
- # tensor path
157
- return np.load(path)
163
+ path = pathlib.Path(uri.local_path or uri.raw)
164
+ # try load by magic
165
+ with path.open("rb") as f:
166
+ # this is the maximum length needed to detect all supported formats
167
+ # (numpy requires 6 bytes: '\x93NUMPY').
168
+ MAGIC_BYTES_LEN_MAX = 6
169
+ magic = f.read(MAGIC_BYTES_LEN_MAX)
170
+ f.seek(0)
171
+ if magic.startswith(MAGIC_MPLANG):
172
+ MPLANG_HEADER_LEN = len(MAGIC_MPLANG) + 1
173
+ header = f.read(MPLANG_HEADER_LEN)
174
+ _, version = struct.unpack(">4sB", header)
175
+ if version != VERSION:
176
+ raise ValueError(f"unsupported mplang version {version}")
177
+ payload = f.read()
178
+ return decode_value(payload)
179
+ elif magic.startswith(MAGIC_PARQUET):
180
+ if not isinstance(out_spec, TableType):
181
+ raise ValueError(
182
+ f"PARQUET files require TableType output spec, got {type(out_spec).__name__}"
183
+ )
184
+ return table_utils.read_table(
185
+ f, format="parquet", columns=list(out_spec.column_names())
186
+ )
187
+ elif magic.startswith(MAGIC_ORC):
188
+ if not isinstance(out_spec, TableType):
189
+ raise ValueError(
190
+ f"ORC files require TableType output spec, got {type(out_spec).__name__}"
191
+ )
192
+ return table_utils.read_table(
193
+ f, format="orc", columns=list(out_spec.column_names())
194
+ )
195
+ elif magic.startswith(MAGIC_NUMPY):
196
+ if not isinstance(out_spec, TensorType):
197
+ raise ValueError(
198
+ f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
199
+ )
200
+ return np.load(f)
201
+
202
+ # Fallback: open the file for CSV or NumPy loading.
203
+ if isinstance(out_spec, TableType):
204
+ return table_utils.read_table(
205
+ f, format="csv", columns=list(out_spec.column_names())
206
+ )
207
+ else:
208
+ return np.load(f)
158
209
 
159
210
  def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
160
211
  import os
@@ -163,14 +214,24 @@ class FileProvider(DataProvider):
163
214
  dir_name = os.path.dirname(path)
164
215
  if dir_name:
165
216
  os.makedirs(dir_name, exist_ok=True)
166
- # Table-like to CSV bytes
167
- if hasattr(value, "__dataframe__") or isinstance(value, pd.DataFrame):
168
- csv_bytes = table_utils.dataframe_to_csv(value) # type: ignore
217
+
218
+ if not isinstance(value, Value):
219
+ value = (
220
+ TableValue(value)
221
+ if isinstance(value, TableLike)
222
+ else TensorValue(value)
223
+ )
224
+
225
+ if isinstance(value, TableValue):
226
+ table_utils.write_table(value.to_arrow(), path, format="parquet")
227
+ elif isinstance(value, TensorValue):
228
+ with open(path, "wb") as f:
229
+ np.save(f, value.to_numpy())
230
+ else:
231
+ payload = encode_value(value)
169
232
  with open(path, "wb") as f:
170
- f.write(csv_bytes)
171
- return
172
- # Tensor-like via numpy
173
- np.save(path, np.asarray(value))
233
+ f.write(struct.pack(">4sB", MAGIC_MPLANG, VERSION))
234
+ f.write(payload)
174
235
 
175
236
 
176
237
  class MemProvider(DataProvider):
@@ -15,8 +15,8 @@
15
15
  """
16
16
  HTTP-based driver implementation for distributed execution.
17
17
 
18
- This module provides an HTTP-based alternative to the gRPC Driver,
19
- using REST APIs for distributed multi-party computation coordination.
18
+ This module provides an HTTP-based driver, using REST APIs
19
+ for distributed multi-party computation coordination.
20
20
  """
21
21
 
22
22
  from __future__ import annotations
@@ -27,14 +27,20 @@ import uuid
27
27
  from collections.abc import Sequence
28
28
  from typing import Any
29
29
 
30
- from mplang.core.cluster import ClusterSpec
31
- from mplang.core.expr.ast import Expr
32
- from mplang.core.interp import InterpContext, InterpVar
33
- from mplang.core.mask import Mask
34
- from mplang.core.mpir import Writer
35
- from mplang.core.mpobject import MPObject
36
- from mplang.core.mptype import MPType
37
- from mplang.runtime.client import HttpExecutorClient
30
+ import numpy as np
31
+
32
+ from mplang.v1.core import (
33
+ ClusterSpec,
34
+ InterpContext,
35
+ InterpVar,
36
+ IrWriter,
37
+ Mask,
38
+ MPObject,
39
+ MPType,
40
+ )
41
+ from mplang.v1.core.expr.ast import Expr
42
+ from mplang.v1.kernels.value import TableValue, TensorValue
43
+ from mplang.v1.runtime.client import HttpExecutorClient
38
44
 
39
45
 
40
46
  def new_uuid() -> str:
@@ -195,7 +201,7 @@ class Driver(InterpContext):
195
201
 
196
202
  var_name_mapping = dict(zip(var_names, party_symbol_names, strict=True))
197
203
 
198
- writer = Writer(var_name_mapping)
204
+ writer = IrWriter(var_name_mapping)
199
205
  program_proto = writer.dumps(expr)
200
206
 
201
207
  output_symbols = [self.new_name() for _ in range(expr.num_outputs)]
@@ -257,7 +263,19 @@ class Driver(InterpContext):
257
263
  try:
258
264
  # The results will be in the same order as the clients (ranks)
259
265
  results = await asyncio.gather(*tasks)
260
- return list(results)
266
+ converted: list[Any] = []
267
+ for value in results:
268
+ if isinstance(value, TensorValue):
269
+ arr = value.to_numpy()
270
+ if isinstance(arr, np.ndarray) and arr.size == 1:
271
+ converted.append(arr.item())
272
+ else:
273
+ converted.append(arr)
274
+ elif isinstance(value, TableValue):
275
+ converted.append(value.to_pandas())
276
+ else:
277
+ converted.append(value)
278
+ return converted
261
279
  except RuntimeError as e:
262
280
  raise RuntimeError(
263
281
  f"Failed to fetch symbol from one or more parties: {e}"
@@ -0,0 +1,196 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from typing import TYPE_CHECKING
19
+
20
+ import spu.libspu as libspu
21
+
22
+ if TYPE_CHECKING:
23
+ from mplang.v1.core.comm import CommunicatorBase
24
+ from mplang.v1.core.mask import Mask
25
+
26
+
27
class LinkCommunicator:
    """Thin wrapper that builds and owns a libspu link context.

    The link mode is chosen from the constructor arguments, in this order:

    1. Channels -- ``comm`` is given: reuse an existing MPLang communicator by
       bridging each SPU peer through ``BaseChannel``. ``spu_mask`` is required.
    2. Mem -- ``mem_link`` is True: in-memory links (legacy, for testing).
    3. BRPC -- otherwise: libspu BRPC links to ``addrs`` (legacy).

    Arguments that do not belong to the selected mode are ignored.
    """

    # Receive timeout set on every link descriptor: 100 seconds.
    _RECV_TIMEOUT_MS = 100 * 1000
    # Maximum HTTP payload for the address-based (Mem/BRPC) modes: 32 MiB.
    _HTTP_MAX_PAYLOAD_SIZE = 32 * 1024 * 1024

    def __init__(
        self,
        rank: int,
        addrs: list[str] | None = None,
        *,
        mem_link: bool = False,
        comm: CommunicatorBase | None = None,
        spu_mask: Mask | None = None,
    ):
        """Initialize the link communicator for SPU.

        Args:
            rank: Global rank of this party.
            addrs: One address per SPU party (required for Mem/BRPC modes).
            mem_link: Use in-memory links (Mem mode) when ``comm`` is not given.
            comm: MPLang communicator to reuse (Channels mode).
            spu_mask: Mask of SPU parties (required for Channels mode).

        Raises:
            ValueError: If the arguments are invalid for the selected mode.
        """
        self._rank = rank

        if comm is not None:
            self._init_channels_mode(rank, comm, spu_mask)
        elif mem_link:
            self._init_mem_mode(rank, addrs)
        else:
            self._init_brpc_mode(rank, addrs)

    def _init_channels_mode(
        self, rank: int, comm: CommunicatorBase, spu_mask: Mask | None
    ) -> None:
        """Initialize Channels mode (reuse an MPLang communicator).

        Args:
            rank: Global rank of this party.
            comm: MPLang communicator to reuse.
            spu_mask: Mask of SPU parties.

        Raises:
            ValueError: If ``spu_mask`` is None or does not contain ``rank``.
        """
        if spu_mask is None:
            raise ValueError("spu_mask required when using comm")
        if rank not in spu_mask:
            raise ValueError(f"rank {rank} not in spu_mask {spu_mask}")

        # Imported lazily to avoid a circular import.
        from mplang.v1.runtime.channel import BaseChannel

        rel_rank = spu_mask.global_to_relative_rank(rank)
        # libspu expects one channel slot per SPU party, in mask order. The
        # slot for this party is None; it is not accessed by SPU.
        channels = [
            None if peer_rank == rank else BaseChannel(comm, rank, peer_rank)
            for peer_rank in spu_mask
        ]

        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = self._RECV_TIMEOUT_MS
        # Party entries let libspu infer world_size. The addresses are
        # placeholders, because traffic flows through ``channels``.
        for idx, peer_rank in enumerate(spu_mask):
            desc.add_party(f"P{idx}", f"dummy_{peer_rank}")

        self.lctx = libspu.link.create_with_channels(desc, rel_rank, channels)
        self._world_size = spu_mask.num_parties()

        logging.info(
            "LinkCommunicator initialized with BaseChannel: "
            "rank=%s, rel_rank=%s, spu_mask=%s, world_size=%s",
            rank,
            rel_rank,
            spu_mask,
            self._world_size,
        )

    def _build_addr_desc(
        self, rank: int, addrs: list[str] | None, mode: str
    ) -> libspu.link.Desc:
        """Validate ``addrs`` and build a link descriptor with one party per address.

        Also records ``self._world_size``. Shared by Mem and BRPC modes.

        NOTE(review): ``rank`` is documented as a global rank but is used as an
        index into ``addrs``. Callers presumably pass a rank that lines up with
        ``addrs``; confirm when the SPU mask is a strict subset of parties.

        Raises:
            ValueError: If ``addrs`` is None or empty, or ``rank`` is not a
                valid index into it.
        """
        if addrs is None:
            raise ValueError(f"addrs required for {mode} mode")
        if not addrs:
            raise ValueError(f"addrs cannot be empty for {mode} mode")
        if not 0 <= rank < len(addrs):
            raise ValueError(
                f"rank {rank} out of range for {len(addrs)} addrs in {mode} mode"
            )

        self._world_size = len(addrs)

        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = self._RECV_TIMEOUT_MS
        desc.http_max_payload_size = self._HTTP_MAX_PAYLOAD_SIZE
        for idx, addr in enumerate(addrs):
            desc.add_party(f"P{idx}", addr)
        return desc

    def _init_mem_mode(self, rank: int, addrs: list[str] | None) -> None:
        """Initialize Mem mode (in-memory links, for testing).

        Raises:
            ValueError: If ``addrs`` is None or empty, or ``rank`` is out of range.
        """
        desc = self._build_addr_desc(rank, addrs, "Mem")
        self.lctx = libspu.link.create_mem(desc, rank)
        logging.info(
            "LinkCommunicator initialized with Mem: rank=%s, world_size=%s, addrs=%s",
            rank,
            self._world_size,
            addrs,
        )

    def _init_brpc_mode(self, rank: int, addrs: list[str] | None) -> None:
        """Initialize BRPC mode (libspu BRPC links to ``addrs``).

        Raises:
            ValueError: If ``addrs`` is None or empty, or ``rank`` is out of range.
        """
        desc = self._build_addr_desc(rank, addrs, "BRPC")
        self.lctx = libspu.link.create_brpc(desc, rank)
        logging.info(
            "LinkCommunicator initialized with BRPC: rank=%s, world_size=%s, addrs=%s",
            rank,
            self._world_size,
            addrs,
        )

    @property
    def rank(self) -> int:
        """Rank reported by the underlying link context.

        In Channels mode the context is created with the SPU-relative rank, so
        this presumably differs from the global ``rank`` given to the
        constructor.
        """
        return self.lctx.rank  # type: ignore[no-any-return]

    @property
    def world_size(self) -> int:
        """World size reported by the underlying link context."""
        return self.lctx.world_size  # type: ignore[no-any-return]

    def get_lctx(self) -> libspu.link.Context:
        """Return the underlying libspu link context.

        This is the primary interface; the SPU runtime uses this context
        directly.
        """
        return self.lctx