mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,187 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU Dialect State.
16
+
17
+ Manages SPU Runtime lifecycle as a dialect state, enabling reuse across
18
+ multiple executions while binding to the Interpreter's lifecycle.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from typing import TYPE_CHECKING, Any
24
+
25
+ import spu.api as spu_api
26
+ import spu.libspu as libspu
27
+
28
+ from mplang.v2.runtime.dialect_state import DialectState
29
+
30
+ if TYPE_CHECKING:
31
+ from mplang.v2.dialects import spu
32
+
33
+
34
class SPUState(DialectState):
    """SPU Runtime cache as dialect state.

    Caches SPU ``Runtime``/``Io`` pairs so that repeated SPU kernel executions
    on the same SPU device reuse an already-created link instead of building a
    new one each time.

    This replaces the previous global ``_SPU_RUNTIMES`` cache with a properly
    lifecycle-managed dialect state.

    The cache key is (local_rank, world_size, ``config.protocol``,
    ``config.field``, link_mode, link topology). Link topology is the BRPC
    endpoint list, or the communicator identity plus the global party ranks.
    Other ``SPUConfig`` fields are NOT part of the key.
    """

    dialect_name: str = "spu"

    def __init__(self) -> None:
        # Key: (local_rank, world_size, protocol, field, link_mode, link_id)
        #   link_id pins the concrete link topology so that devices with the
        #   same rank/size/protocol but different peers never share a link.
        # Value: (Runtime, Io)
        self._runtimes: dict[
            tuple[int, int, str, str, str, tuple[Any, ...]],
            tuple[spu_api.Runtime, spu_api.Io],
        ] = {}

    def get_or_create(
        self,
        local_rank: int,
        spu_world_size: int,
        config: spu.SPUConfig,
        spu_endpoints: list[str] | None = None,
        communicator: object | None = None,
        parties: list[int] | None = None,
    ) -> tuple[spu_api.Runtime, spu_api.Io]:
        """Get or create SPU Runtime and Io for the given configuration.

        Args:
            local_rank: The local rank within the SPU device (0-indexed).
            spu_world_size: The number of parties in the SPU device.
            config: SPU configuration including protocol settings.
            spu_endpoints: Optional list of BRPC endpoints. If None, use mem link.
            communicator: Optional v2 communicator (ThreadCommunicator/HttpCommunicator).
                If provided, use Channels mode to reuse existing communication.
            parties: Optional list of global ranks for SPU parties.
                Required when communicator is provided.

        Returns:
            A tuple of (Runtime, Io) for this party.

        Raises:
            ValueError: If ``local_rank`` is outside ``[0, spu_world_size)``,
                if ``communicator`` is given without ``parties``, or if the
                length of ``parties`` / ``spu_endpoints`` differs from
                ``spu_world_size``.
        """
        from mplang.v2.backends.spu_impl import to_runtime_config

        if not 0 <= local_rank < spu_world_size:
            raise ValueError(
                f"local_rank {local_rank} out of range for SPU world size "
                f"{spu_world_size}"
            )

        # Determine link mode and an identity for the concrete link topology.
        # The topology must be in the cache key; otherwise a Runtime wired to
        # one set of peers/endpoints would be reused for a different one.
        link_id: tuple[Any, ...]
        if communicator is not None:
            if parties is None:
                raise ValueError("parties required when using communicator")
            if len(parties) != spu_world_size:
                raise ValueError(
                    f"len(parties)={len(parties)} does not match "
                    f"spu_world_size={spu_world_size}"
                )
            link_mode = "channels"
            # NOTE(review): id() is only unique while the communicator is alive;
            # the cached link's channels presumably keep a reference to it.
            link_id = (id(communicator), tuple(parties))
        elif spu_endpoints:
            if len(spu_endpoints) != spu_world_size:
                raise ValueError(
                    f"len(spu_endpoints)={len(spu_endpoints)} does not match "
                    f"spu_world_size={spu_world_size}"
                )
            link_mode = "brpc"
            link_id = tuple(spu_endpoints)
        else:
            link_mode = "mem"
            link_id = ()

        cache_key = (
            local_rank,
            spu_world_size,
            config.protocol,
            config.field,
            link_mode,
            link_id,
        )

        cached = self._runtimes.get(cache_key)
        if cached is not None:
            return cached

        # Create Link (parties / spu_endpoints were validated above).
        if communicator is not None:
            link = self._create_channels_link(
                local_rank, spu_world_size, communicator, parties
            )
        elif spu_endpoints:
            link = self._create_brpc_link(local_rank, spu_endpoints)
        else:
            link = self._create_mem_link(local_rank, spu_world_size)

        # Create Runtime and Io
        runtime_config = to_runtime_config(config)
        runtime = spu_api.Runtime(link, runtime_config)
        io = spu_api.Io(spu_world_size, runtime_config)

        self._runtimes[cache_key] = (runtime, io)
        return runtime, io

    def _create_mem_link(
        self, local_rank: int, spu_world_size: int
    ) -> libspu.link.Context:
        """Create in-memory link for simulation."""
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = 30 * 1000  # 30 seconds
        for i in range(spu_world_size):
            desc.add_party(f"P{i}", f"mem:{i}")
        return libspu.link.create_mem(desc, local_rank)

    def _create_channels_link(
        self,
        local_rank: int,
        spu_world_size: int,
        communicator: Any,
        parties: list[int],
    ) -> libspu.link.Context:
        """Create link using custom channels (reuse v2 communicator).

        Args:
            local_rank: SPU local rank (0-indexed, already converted from global)
            spu_world_size: Number of SPU parties
            communicator: v2 communicator (ThreadCommunicator/HttpCommunicator)
            parties: List of global ranks for SPU parties (ordered by local rank)

        Returns:
            libspu link context using BaseChannel adapters
        """
        from mplang.v2.backends.channel import BaseChannel

        # Get this worker's global rank
        global_rank = parties[local_rank]

        # One entry per party, ordered by local rank; the self slot must be None.
        channels = [
            None
            if idx == local_rank
            else BaseChannel(communicator, global_rank, peer_global_rank)
            for idx, peer_global_rank in enumerate(parties)
        ]

        # Create link descriptor
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = 100 * 1000  # 100 seconds

        # Add party info (required for world_size inference); the address is a
        # placeholder since traffic goes through the channels above.
        for idx in range(spu_world_size):
            desc.add_party(f"P{idx}", f"dummy_{parties[idx]}")

        return libspu.link.create_with_channels(desc, local_rank, channels)

    def _create_brpc_link(
        self, local_rank: int, spu_endpoints: list[str]
    ) -> libspu.link.Context:
        """Create BRPC link for distributed execution."""
        desc = libspu.link.Desc()  # type: ignore
        desc.recv_timeout_ms = 100 * 1000  # 100 seconds
        desc.http_max_payload_size = 32 * 1024 * 1024  # 32MB

        for i, endpoint in enumerate(spu_endpoints):
            desc.add_party(f"P{i}", endpoint)

        return libspu.link.create_brpc(desc, local_rank)

    def shutdown(self) -> None:
        """Drop all cached (Runtime, Io) pairs.

        Only the cache references are cleared; links are not explicitly
        stopped here.
        """
        self._runtimes.clear()
@@ -0,0 +1,62 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Store Runtime Implementation."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any
20
+
21
+ from mplang.v2.dialects import store
22
+ from mplang.v2.edsl.graph import Operation
23
+ from mplang.v2.runtime.interpreter import Interpreter
24
+
25
+
26
+ def _get_uri(uri_base: str) -> str:
27
+ """Generate URI: {uri_base}."""
28
+ # Handle different schemes if necessary, for now assume simple path joining
29
+ # or scheme preservation.
30
+ if "://" in uri_base:
31
+ scheme, _, path = uri_base.partition("://")
32
+ # Ensure we don't double slash if path is absolute
33
+ return f"{scheme}://{path}"
34
+ else:
35
+ # Default to fs:// for absolute paths (sandboxed)
36
+ return f"fs://{uri_base}"
37
+
38
+
39
@store.save_p.def_impl
def save_impl(interpreter: Interpreter, op: Operation, obj_val: Any) -> Any:
    """Persist a runtime value to the location named by the op's ``uri_base``.

    The value (e.g. TensorValue, TableValue, or a raw object) is handed to
    ``interpreter.store.put`` unchanged and then returned, so the save op
    passes its input through.

    Raises:
        RuntimeError: If the interpreter has no ObjectStore configured.
    """
    location: str = op.attrs["uri_base"]

    if interpreter.store is not None:
        # Stored as-is; per the original note the store pickles it.
        interpreter.store.put(obj_val, uri=_get_uri(location))
        return obj_val

    raise RuntimeError("Interpreter has no ObjectStore configured. Cannot save.")
52
+
53
+
54
@store.load_p.def_impl
def load_impl(interpreter: Interpreter, op: Operation) -> Any:
    """Load a previously saved value from the op's ``uri_base`` location.

    The ``expected_type`` attr is not consulted; the stored object is
    returned exactly as ``interpreter.store.get`` yields it.

    Raises:
        RuntimeError: If the interpreter has no ObjectStore configured.
    """
    location: str = op.attrs["uri_base"]

    if interpreter.store is not None:
        # NOTE(review): the original note says pickle restores the type; if so,
        # loading from an untrusted URI is unsafe. Confirm the store's trust model.
        return interpreter.store.get(_get_uri(location))

    raise RuntimeError("Interpreter has no ObjectStore configured. Cannot load.")