mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -14,16 +14,24 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import pathlib
18
+ import struct
17
19
  from dataclasses import dataclass
18
20
  from typing import Any
19
21
  from urllib.parse import ParseResult, urlparse
20
22
 
21
23
  import numpy as np
22
- import pandas as pd
23
24
 
24
- from mplang.core import TableType, TensorType
25
- from mplang.kernels.base import KernelContext
26
- from mplang.utils import table_utils
25
+ from mplang.v1.core import TableLike, TableType, TensorType
26
+ from mplang.v1.kernels.base import KernelContext
27
+ from mplang.v1.kernels.value import (
28
+ TableValue,
29
+ TensorValue,
30
+ Value,
31
+ decode_value,
32
+ encode_value,
33
+ )
34
+ from mplang.v1.utils import table_utils
27
35
 
28
36
 
29
37
  @dataclass(frozen=True)
@@ -135,6 +143,11 @@ def get_provider(scheme: str) -> DataProvider | None:
135
143
 
136
144
 
137
145
  # ---------------- Default Providers ----------------
146
+ MAGIC_MPLANG = b"MPLG"
147
+ MAGIC_PARQUET = b"PAR1"
148
+ MAGIC_ORC = b"ORC"
149
+ MAGIC_NUMPY = b"\x93NUMPY"
150
+ VERSION = 0x01
138
151
 
139
152
 
140
153
  class FileProvider(DataProvider):
@@ -147,14 +160,52 @@ class FileProvider(DataProvider):
147
160
  def read(
148
161
  self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
149
162
  ) -> Any:
150
- path = uri.local_path or uri.raw
151
- if isinstance(out_spec, TableType):
152
- with open(path, "rb") as f:
153
- csv_bytes = f.read()
154
- # Pass schema to enable column projection
155
- return table_utils.csv_to_dataframe(csv_bytes, schema=out_spec)
156
- # tensor path
157
- return np.load(path)
163
+ path = pathlib.Path(uri.local_path or uri.raw)
164
+ # try load by magic
165
+ with path.open("rb") as f:
166
+ # this is the maximum length needed to detect all supported formats
167
+ # (numpy requires 6 bytes: '\x93NUMPY').
168
+ MAGIC_BYTES_LEN_MAX = 6
169
+ magic = f.read(MAGIC_BYTES_LEN_MAX)
170
+ f.seek(0)
171
+ if magic.startswith(MAGIC_MPLANG):
172
+ MPLANG_HEADER_LEN = len(MAGIC_MPLANG) + 1
173
+ header = f.read(MPLANG_HEADER_LEN)
174
+ _, version = struct.unpack(">4sB", header)
175
+ if version != VERSION:
176
+ raise ValueError(f"unsupported mplang version {version}")
177
+ payload = f.read()
178
+ return decode_value(payload)
179
+ elif magic.startswith(MAGIC_PARQUET):
180
+ if not isinstance(out_spec, TableType):
181
+ raise ValueError(
182
+ f"PARQUET files require TableType output spec, got {type(out_spec).__name__}"
183
+ )
184
+ return table_utils.read_table(
185
+ f, format="parquet", columns=list(out_spec.column_names())
186
+ )
187
+ elif magic.startswith(MAGIC_ORC):
188
+ if not isinstance(out_spec, TableType):
189
+ raise ValueError(
190
+ f"ORC files require TableType output spec, got {type(out_spec).__name__}"
191
+ )
192
+ return table_utils.read_table(
193
+ f, format="orc", columns=list(out_spec.column_names())
194
+ )
195
+ elif magic.startswith(MAGIC_NUMPY):
196
+ if not isinstance(out_spec, TensorType):
197
+ raise ValueError(
198
+ f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
199
+ )
200
+ return np.load(f)
201
+
202
+ # Fallback: open the file for CSV or NumPy loading.
203
+ if isinstance(out_spec, TableType):
204
+ return table_utils.read_table(
205
+ f, format="csv", columns=list(out_spec.column_names())
206
+ )
207
+ else:
208
+ return np.load(f)
158
209
 
159
210
  def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
160
211
  import os
@@ -163,14 +214,24 @@ class FileProvider(DataProvider):
163
214
  dir_name = os.path.dirname(path)
164
215
  if dir_name:
165
216
  os.makedirs(dir_name, exist_ok=True)
166
- # Table-like to CSV bytes
167
- if hasattr(value, "__dataframe__") or isinstance(value, pd.DataFrame):
168
- csv_bytes = table_utils.dataframe_to_csv(value) # type: ignore
217
+
218
+ if not isinstance(value, Value):
219
+ value = (
220
+ TableValue(value)
221
+ if isinstance(value, TableLike)
222
+ else TensorValue(value)
223
+ )
224
+
225
+ if isinstance(value, TableValue):
226
+ table_utils.write_table(value.to_arrow(), path, format="parquet")
227
+ elif isinstance(value, TensorValue):
228
+ with open(path, "wb") as f:
229
+ np.save(f, value.to_numpy())
230
+ else:
231
+ payload = encode_value(value)
169
232
  with open(path, "wb") as f:
170
- f.write(csv_bytes)
171
- return
172
- # Tensor-like via numpy
173
- np.save(path, np.asarray(value))
233
+ f.write(struct.pack(">4sB", MAGIC_MPLANG, VERSION))
234
+ f.write(payload)
174
235
 
175
236
 
176
237
  class MemProvider(DataProvider):
@@ -29,7 +29,7 @@ from typing import Any
29
29
 
30
30
  import numpy as np
31
31
 
32
- from mplang.core import (
32
+ from mplang.v1.core import (
33
33
  ClusterSpec,
34
34
  InterpContext,
35
35
  InterpVar,
@@ -38,9 +38,9 @@ from mplang.core import (
38
38
  MPObject,
39
39
  MPType,
40
40
  )
41
- from mplang.core.expr.ast import Expr
42
- from mplang.kernels.value import TableValue, TensorValue
43
- from mplang.runtime.client import HttpExecutorClient
41
+ from mplang.v1.core.expr.ast import Expr
42
+ from mplang.v1.kernels.value import TableValue, TensorValue
43
+ from mplang.v1.runtime.client import HttpExecutorClient
44
44
 
45
45
 
46
46
  def new_uuid() -> str:
@@ -0,0 +1,196 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from typing import TYPE_CHECKING
19
+
20
+ import spu.libspu as libspu
21
+
22
+ if TYPE_CHECKING:
23
+ from mplang.v1.core.comm import CommunicatorBase
24
+ from mplang.v1.core.mask import Mask
25
+
26
+
27
+ class LinkCommunicator:
28
+ """Minimal wrapper for libspu link context.
29
+
30
+ Supports three modes:
31
+ 1. BRPC: Production mode with separate BRPC ports (legacy)
32
+ 2. Mem: In-memory links for testing (legacy)
33
+ 3. Channels: Reuse MPLang communicator via IChannel bridge (NEW)
34
+
35
+ The mode is selected based on constructor arguments:
36
+ - If `comm` is provided: Channels mode (NEW)
37
+ - Elif `mem_link` is True: Mem mode
38
+ - Else: BRPC mode
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ rank: int,
44
+ addrs: list[str] | None = None,
45
+ *,
46
+ mem_link: bool = False,
47
+ comm: CommunicatorBase | None = None,
48
+ spu_mask: Mask | None = None,
49
+ ):
50
+ """Initialize link communicator for SPU.
51
+
52
+ Args:
53
+ rank: Global rank of this party
54
+ addrs: List of addresses for all SPU parties (required for BRPC/Mem mode)
55
+ mem_link: If True, use in-memory link (Mem mode)
56
+ comm: MPLang communicator to reuse (Channels mode, NEW)
57
+ spu_mask: SPU parties mask (required for Channels mode)
58
+
59
+ Raises:
60
+ ValueError: If arguments are invalid for the selected mode
61
+ """
62
+ self._rank = rank
63
+
64
+ # Select initialization mode based on arguments
65
+ if comm is not None:
66
+ self._init_channels_mode(rank, comm, spu_mask)
67
+ elif mem_link:
68
+ self._init_mem_mode(rank, addrs)
69
+ else:
70
+ self._init_brpc_mode(rank, addrs)
71
+
72
+ def _init_channels_mode(
73
+ self, rank: int, comm: CommunicatorBase, spu_mask: Mask | None
74
+ ) -> None:
75
+ """Initialize Channels mode (reuse MPLang communicator).
76
+
77
+ Args:
78
+ rank: Global rank of this party
79
+ comm: MPLang communicator to reuse
80
+ spu_mask: SPU parties mask
81
+
82
+ Raises:
83
+ ValueError: If spu_mask is None or rank not in mask
84
+ """
85
+ if spu_mask is None:
86
+ raise ValueError("spu_mask required when using comm")
87
+ if rank not in spu_mask:
88
+ raise ValueError(f"rank {rank} not in spu_mask {spu_mask}")
89
+
90
+ # Lazy import to avoid circular dependency
91
+ from mplang.v1.runtime.channel import BaseChannel
92
+
93
+ # Create channels to ALL SPU parties (including self)
94
+ # libspu expects world_size channels, with self channel being None
95
+ channels = []
96
+ rel_rank = spu_mask.global_to_relative_rank(rank)
97
+
98
+ for _, peer_rank in enumerate(spu_mask):
99
+ if peer_rank == rank:
100
+ # For self, use None (won't be accessed by SPU)
101
+ channel = None
102
+ else:
103
+ channel = BaseChannel(comm, rank, peer_rank)
104
+ channels.append(channel)
105
+
106
+ # Create link context with custom channels
107
+ desc = libspu.link.Desc() # type: ignore
108
+ desc.recv_timeout_ms = 100 * 1000 # 100 seconds
109
+
110
+ # Add party info to desc (required for world_size inference)
111
+ for idx, peer_rank in enumerate(spu_mask):
112
+ desc.add_party(f"P{idx}", f"dummy_{peer_rank}")
113
+
114
+ self.lctx = libspu.link.create_with_channels(desc, rel_rank, channels)
115
+ self._world_size = spu_mask.num_parties()
116
+
117
+ logging.info(
118
+ f"LinkCommunicator initialized with BaseChannel: "
119
+ f"rank={rank}, rel_rank={rel_rank}, spu_mask={spu_mask}, "
120
+ f"world_size={self._world_size}"
121
+ )
122
+
123
+ def _init_mem_mode(self, rank: int, addrs: list[str] | None) -> None:
124
+ """Initialize Mem mode (in-memory links for testing).
125
+
126
+ Args:
127
+ rank: Global rank of this party
128
+ addrs: List of addresses for all SPU parties
129
+
130
+ Raises:
131
+ ValueError: If addrs is None
132
+ """
133
+ if addrs is None:
134
+ raise ValueError("addrs required for Mem mode")
135
+
136
+ self._world_size = len(addrs)
137
+
138
+ desc = libspu.link.Desc() # type: ignore
139
+ desc.recv_timeout_ms = 100 * 1000 # 100 seconds
140
+ desc.http_max_payload_size = 32 * 1024 * 1024 # 32M
141
+ for rank_idx, addr in enumerate(addrs):
142
+ desc.add_party(f"P{rank_idx}", addr)
143
+
144
+ self.lctx = libspu.link.create_mem(desc, self._rank)
145
+ logging.info(
146
+ f"LinkCommunicator initialized with Mem: "
147
+ f"rank={self._rank}, world_size={self._world_size}, addrs={addrs}"
148
+ )
149
+
150
+ def _init_brpc_mode(self, rank: int, addrs: list[str] | None) -> None:
151
+ """Initialize BRPC mode (production mode with separate BRPC ports).
152
+
153
+ Args:
154
+ rank: Global rank of this party
155
+ addrs: List of addresses for all SPU parties
156
+
157
+ Raises:
158
+ ValueError: If addrs is None
159
+ """
160
+ if addrs is None:
161
+ raise ValueError("addrs required for BRPC mode")
162
+
163
+ self._world_size = len(addrs)
164
+
165
+ desc = libspu.link.Desc() # type: ignore
166
+ desc.recv_timeout_ms = 100 * 1000 # 100 seconds
167
+ desc.http_max_payload_size = 32 * 1024 * 1024 # 32M
168
+ for rank_idx, addr in enumerate(addrs):
169
+ desc.add_party(f"P{rank_idx}", addr)
170
+
171
+ self.lctx = libspu.link.create_brpc(desc, self._rank)
172
+ logging.info(
173
+ f"LinkCommunicator initialized with BRPC: "
174
+ f"rank={self._rank}, world_size={self._world_size}, addrs={addrs}"
175
+ )
176
+
177
+ @property
178
+ def rank(self) -> int:
179
+ """Get rank from underlying link context."""
180
+ return self.lctx.rank # type: ignore[no-any-return]
181
+
182
+ @property
183
+ def world_size(self) -> int:
184
+ """Get world size from underlying link context."""
185
+ return self.lctx.world_size # type: ignore[no-any-return]
186
+
187
+ def get_lctx(self) -> libspu.link.Context:
188
+ """Get the underlying libspu link context.
189
+
190
+ This is the primary interface - SPU runtime uses this context directly.
191
+ All communication and serialization is handled by libspu internally.
192
+
193
+ Returns:
194
+ The underlying libspu.link.Context instance.
195
+ """
196
+ return self.lctx
@@ -30,14 +30,18 @@ from fastapi import (
30
30
  from fastapi.responses import JSONResponse
31
31
  from pydantic import BaseModel
32
32
 
33
- from mplang.core import IrReader, TableType, TensorType
34
- from mplang.core.cluster import ClusterSpec
35
- from mplang.kernels.base import KernelContext
36
- from mplang.kernels.value import Value, decode_value, encode_value
37
- from mplang.protos.v1alpha1 import mpir_pb2
38
- from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
39
- from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
40
- from mplang.runtime.session import (
33
+ from mplang.v1.core import IrReader, TableType, TensorType
34
+ from mplang.v1.core.cluster import ClusterSpec
35
+ from mplang.v1.kernels.base import KernelContext
36
+ from mplang.v1.kernels.value import Value, decode_value, encode_value
37
+ from mplang.v1.protos.v1alpha1 import mpir_pb2
38
+ from mplang.v1.runtime.data_providers import (
39
+ DataProvider,
40
+ ResolvedURI,
41
+ register_provider,
42
+ )
43
+ from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
44
+ from mplang.v1.runtime.session import (
41
45
  Computation,
42
46
  Session,
43
47
  Symbol,
@@ -215,6 +219,7 @@ class SymbolResponse(BaseModel):
215
219
 
216
220
  class CommSendRequest(BaseModel):
217
221
  data: str # Base64 encoded binary data
222
+ is_raw_bytes: bool = False # True for SPU channel raw bytes
218
223
 
219
224
 
220
225
  # Response Models for enhanced status
@@ -483,6 +488,14 @@ def comm_send(
483
488
  # The receiver rank should be the rank of the server hosting this endpoint
484
489
  # We don't need to validate to_rank since the request is coming to this server
485
490
 
491
+ # For raw bytes (SPU channel), pass through as dict with flag
492
+ # For normal data, pass the base64 string directly
493
+ data_payload: str | dict[str, object]
494
+ if request.is_raw_bytes:
495
+ data_payload = {"data": request.data, "is_raw_bytes": True}
496
+ else:
497
+ data_payload = request.data
498
+
486
499
  # Use the proper onSent mechanism from CommunicatorBase
487
- sess.communicator.onSent(from_rank, key, request.data)
500
+ sess.communicator.onSent(from_rank, key, data_payload)
488
501
  return {"status": "ok"}
@@ -25,51 +25,28 @@ Process-wide registries (sessions, global symbols) live in the server layer
25
25
 
26
26
  from __future__ import annotations
27
27
 
28
- import logging
29
28
  import time
30
29
  from dataclasses import dataclass, field
31
30
  from functools import cached_property
32
31
  from typing import TYPE_CHECKING, Any, cast
33
- from urllib.parse import urlparse
34
32
 
35
33
  import spu.libspu as libspu
36
34
 
37
- from mplang.core.cluster import ClusterSpec
38
- from mplang.core.comm import ICommunicator
39
- from mplang.core.expr.ast import Expr
40
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
41
- from mplang.core.mask import Mask
42
- from mplang.kernels.context import RuntimeContext
43
- from mplang.kernels.spu import PFunction # type: ignore
44
- from mplang.kernels.value import Value
45
- from mplang.runtime.communicator import HttpCommunicator
46
- from mplang.runtime.exceptions import ResourceNotFound
47
- from mplang.runtime.link_comm import LinkCommunicator
48
- from mplang.utils.spu_utils import parse_field, parse_protocol
35
+ from mplang.v1.core.cluster import ClusterSpec
36
+ from mplang.v1.core.comm import ICommunicator
37
+ from mplang.v1.core.expr.ast import Expr
38
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
39
+ from mplang.v1.core.mask import Mask
40
+ from mplang.v1.kernels.context import RuntimeContext
41
+ from mplang.v1.kernels.spu import PFunction # type: ignore
42
+ from mplang.v1.kernels.value import Value
43
+ from mplang.v1.runtime.communicator import HttpCommunicator
44
+ from mplang.v1.runtime.exceptions import ResourceNotFound
45
+ from mplang.v1.runtime.link_comm import LinkCommunicator
46
+ from mplang.v1.utils.spu_utils import parse_field, parse_protocol
49
47
 
50
48
  if TYPE_CHECKING: # pragma: no cover - import only for type checking
51
- from mplang.core.cluster import ClusterSpec, Node, RuntimeInfo
52
-
53
-
54
- class LinkCommFactory:
55
- """Factory for creating and caching link communicators."""
56
-
57
- def __init__(self) -> None:
58
- self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
59
-
60
- def create_link(self, rel_rank: int, addrs: list[str]) -> LinkCommunicator:
61
- key = (rel_rank, tuple(addrs))
62
- link = self._cache.get(key)
63
- if link is not None:
64
- return link
65
- logging.info(f"LinkCommunicator created: rel_rank={rel_rank} addrs={addrs}")
66
- link = LinkCommunicator(rel_rank, addrs)
67
- self._cache[key] = link
68
- return link
69
-
70
-
71
- # Shared link factory (module-local, not global registry of sessions)
72
- g_link_factory = LinkCommFactory()
49
+ from mplang.v1.core.cluster import ClusterSpec, Node, RuntimeInfo
73
50
 
74
51
 
75
52
  @dataclass
@@ -184,23 +161,19 @@ class Session:
184
161
  return
185
162
 
186
163
  link_ctx = None
187
- # TODO(jint): reuse same port for mplang and spu.
188
- SPU_PORT_OFFSET = 100
189
164
 
190
165
  if self.is_spu_party:
191
- # Build SPU address list across all endpoints for ranks in mask
192
- spu_addrs: list[str] = []
193
- for r, addr in enumerate(self.cluster_spec.endpoints):
194
- if r in self.spu_mask:
195
- # TODO(oeqqwq): addr may contain other schema like grpc://
196
- if not addr.startswith(("http://", "https://")):
197
- addr = f"http://{addr}"
198
- parsed = urlparse(addr)
199
- assert isinstance(parsed.port, int)
200
- new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}"
201
- spu_addrs.append(new_addr)
202
- rel_index = sum(1 for r in range(self.rank) if r in self.spu_mask)
203
- link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
166
+ # Use Channels mode to reuse existing HttpCommunicator
167
+ # This eliminates the need for separate BRPC ports (SPU_PORT_OFFSET)
168
+ from mplang.v1.core.comm import CommunicatorBase
169
+
170
+ # Type assertion: ICommunicator is actually CommunicatorBase
171
+ comm = cast(CommunicatorBase, self.communicator)
172
+ link_ctx = LinkCommunicator(
173
+ rank=self.rank,
174
+ comm=comm,
175
+ spu_mask=self.spu_mask,
176
+ )
204
177
 
205
178
  spu_config = libspu.RuntimeConfig(
206
179
  protocol=parse_protocol(self.spu_protocol),
@@ -18,13 +18,14 @@ import concurrent.futures
18
18
  import faulthandler
19
19
  import logging
20
20
  import sys
21
+ import threading
21
22
  import traceback
22
23
  from collections.abc import Sequence
23
24
  from typing import Any, cast
24
25
 
25
26
  import spu.libspu as libspu
26
27
 
27
- from mplang.core import (
28
+ from mplang.v1.core import (
28
29
  ClusterSpec,
29
30
  CollectiveMixin,
30
31
  CommunicatorBase,
@@ -38,11 +39,11 @@ from mplang.core import (
38
39
  PFunction, # for spu.seed_env kernel seeding
39
40
  TensorLike,
40
41
  )
41
- from mplang.core.expr.ast import Expr
42
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
43
- from mplang.kernels.context import RuntimeContext
44
- from mplang.runtime.link_comm import LinkCommunicator
45
- from mplang.utils.spu_utils import parse_field, parse_protocol
42
+ from mplang.v1.core.expr.ast import Expr
43
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
44
+ from mplang.v1.kernels.context import RuntimeContext
45
+ from mplang.v1.runtime.link_comm import LinkCommunicator
46
+ from mplang.v1.utils.spu_utils import parse_field, parse_protocol
46
47
 
47
48
 
48
49
  class ThreadCommunicator(CommunicatorBase, CollectiveMixin):
@@ -129,16 +130,37 @@ class Simulator(InterpContext):
129
130
  comm.set_peers(self._comms)
130
131
 
131
132
  # Prepare link contexts for SPU parties (store for evaluator-time initialization)
132
- spu_addrs = [f"P{spu_rank}" for spu_rank in spu_mask]
133
+ # Use Channels mode to reuse ThreadCommunicator instead of separate mem_link
133
134
  self._spu_link_ctxs: list[LinkCommunicator | None] = [None] * world_size
134
- link_ctx_list = [
135
- LinkCommunicator(idx, spu_addrs, mem_link=True)
136
- for idx in range(spu_mask.num_parties())
135
+
136
+ # Create LinkCommunicators in parallel to avoid deadlock
137
+ # (create_with_channels does handshake via TestSend/TestRecv)
138
+ exceptions: dict[int, Exception] = {}
139
+
140
+ def create_link(g_rank: int) -> None:
141
+ try:
142
+ self._spu_link_ctxs[g_rank] = LinkCommunicator(
143
+ rank=g_rank,
144
+ comm=self._comms[g_rank],
145
+ spu_mask=spu_mask,
146
+ )
147
+ except Exception as e:
148
+ exceptions[g_rank] = e
149
+
150
+ threads = [
151
+ threading.Thread(target=create_link, args=(g_rank,)) for g_rank in spu_mask
137
152
  ]
138
- for g_rank in range(world_size):
139
- if g_rank in spu_mask:
140
- rel = Mask(spu_mask).global_to_relative_rank(g_rank)
141
- self._spu_link_ctxs[g_rank] = link_ctx_list[rel]
153
+ for t in threads:
154
+ t.start()
155
+ for t in threads:
156
+ t.join()
157
+
158
+ # Check for exceptions during link creation
159
+ if exceptions:
160
+ first_exc = next(iter(exceptions.values()))
161
+ raise RuntimeError(
162
+ f"Failed to create SPU link contexts for ranks {list(exceptions.keys())}"
163
+ ) from first_exc
142
164
 
143
165
  self._spu_runtime_cfg = libspu.RuntimeConfig(
144
166
  protocol=spu_protocol, field=spu_field