mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,303 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import pathlib
18
- import struct
19
- from dataclasses import dataclass
20
- from typing import Any
21
- from urllib.parse import ParseResult, urlparse
22
-
23
- import numpy as np
24
-
25
- from mplang.v1.core import TableLike, TableType, TensorType
26
- from mplang.v1.kernels.base import KernelContext
27
- from mplang.v1.kernels.value import (
28
- TableValue,
29
- TensorValue,
30
- Value,
31
- decode_value,
32
- encode_value,
33
- )
34
- from mplang.v1.utils import table_utils
35
-
36
-
37
@dataclass(frozen=True)
class ResolvedURI:
    """Result of resolving a resource path into a normalized form.

    Attributes:
        scheme: Lower-cased URI scheme (e.g. 'file', 's3', 'mem', 'var',
            'secret'); 'file' when the input carried no scheme.
        raw: The original path string exactly as provided by the user.
        parsed: The ``ParseResult`` if a scheme was present; otherwise None.
        local_path: For the ``file`` scheme, the filesystem path (absolute or
            as given); None for every other scheme.
    """

    scheme: str
    raw: str
    parsed: ParseResult | None
    local_path: str | None


def resolve_uri(path: str) -> ResolvedURI:
    """Resolve a user-provided resource location into a normalized URI form.

    Plain filesystem paths and RFC 3986 style URIs are accepted. An input is
    treated as a ``file`` path when ``urlparse(path).scheme`` is empty or is a
    single letter (a Windows drive such as ``C:``). Detection does not depend
    on the literal substring ``"://"``, so forms like ``mem:foo`` (no slashes)
    are still recognized as URIs.

    For ``file:`` URIs, ``local_path`` is the percent-decoded URI path. A
    ``localhost`` authority is dropped; any other authority is kept as a
    UNC-style ``//host/path`` prefix instead of being silently discarded.

    Captured fields
    - ``scheme``: Lower-cased scheme (``file`` when absent)
    - ``raw``: Original input
    - ``parsed``: ``ParseResult`` when a scheme was provided, else ``None``
    - ``local_path``: Filesystem path for ``file`` scheme, else ``None``

    Supported (pluggable) schemes out-of-the-box:
    * ``file`` (default)
    * ``mem``
    * ``s3`` (stub)
    * ``secret`` (stub)
    * ``symbols`` (registered server-side)

    Examples
    >>> resolve_uri("data/train.npy").scheme
    'file'
    >>> resolve_uri("mem:dataset1").scheme
    'mem'
    >>> resolve_uri("mem://dataset1").scheme  # both forms acceptable
    'mem'
    >>> resolve_uri("symbols://shared_model").scheme
    'symbols'
    >>> resolve_uri("file:///tmp/x.npy").local_path
    '/tmp/x.npy'
    >>> resolve_uri("file:///tmp/my%20file.npy").local_path
    '/tmp/my file.npy'
    """
    from urllib.parse import unquote

    pr = urlparse(path)
    # urlparse reports a Windows drive letter ("C:\\x.npy", "C:/x.npy") as a
    # one-letter scheme; treat such inputs as plain filesystem paths.
    if not pr.scheme or (len(pr.scheme) == 1 and pr.scheme.isalpha()):
        return ResolvedURI("file", path, None, path)

    scheme = pr.scheme.lower()
    local_path: str | None = None
    if scheme == "file":
        # File URIs percent-encode reserved characters (e.g. spaces as %20),
        # so decode them to recover the real filesystem path.
        local_path = unquote(pr.path)
        host = pr.netloc
        if host and host.lower() != "localhost":
            # Keep a non-local authority rather than silently dropping it.
            local_path = f"//{host}{local_path}"
    return ResolvedURI(scheme, path, pr, local_path)
99
-
100
-
101
class DataProvider:
    """Base class for pluggable resource back ends.

    A provider reads and writes values addressed by a ``ResolvedURI``. The
    expected output type is handed to ``read``; implementations are free to
    ignore it but SHOULD validate against it when they can. Both methods
    here raise ``NotImplementedError`` and must be overridden.
    """

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Load the resource at ``uri``; subclasses must override this."""
        raise NotImplementedError

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Persist ``value`` at ``uri``; subclasses must override this."""
        raise NotImplementedError
115
-
116
-
117
# Scheme (lower-cased) -> provider instance.
_REGISTRY: dict[str, DataProvider] = {}


def register_provider(
    scheme: str, provider: DataProvider, *, replace: bool = False, quiet: bool = False
) -> None:
    """Register a provider implementation for a URI scheme.

    Args:
        scheme: URI scheme handled (case-insensitive; stored lower-cased).
        provider: Implementation that serves the scheme.
        replace: If False and the scheme is already registered, raise
            ``ValueError``; if True, overwrite the existing provider.
        quiet: If True, suppress the info log emitted when replacing.

    Raises:
        ValueError: If ``scheme`` is already registered and ``replace`` is False.
    """
    import logging

    key = scheme.lower()
    already_registered = key in _REGISTRY
    if already_registered and not replace:
        raise ValueError(f"provider already registered for scheme: {scheme}")
    if already_registered and not quiet:
        # Log through this module's logger (not the root logger) with lazy
        # %-style formatting so per-module log configuration applies.
        logging.getLogger(__name__).info(
            "Replacing existing provider for scheme '%s'", scheme
        )
    _REGISTRY[key] = provider


def get_provider(scheme: str) -> DataProvider | None:
    """Return the provider registered for ``scheme`` (case-insensitive), or None."""
    return _REGISTRY.get(scheme.lower())
143
-
144
-
145
- # ---------------- Default Providers ----------------
146
# Leading-byte signatures that FileProvider.read uses to sniff file formats.
MAGIC_MPLANG = b"MPLG"  # mplang container: this magic, a version byte, then the payload
MAGIC_PARQUET = b"PAR1"  # Apache Parquet
MAGIC_ORC = b"ORC"  # Apache ORC
MAGIC_NUMPY = b"\x93NUMPY"  # NumPy .npy
# Container version written after MAGIC_MPLANG; reads reject any other value.
VERSION = 1
151
-
152
-
153
class FileProvider(DataProvider):
    """Local filesystem provider.

    Reading sniffs the leading magic bytes of the file:

    * mplang container (``MAGIC_MPLANG`` + 1-byte ``VERSION``) -> ``decode_value``
    * Parquet / ORC -> ``table_utils.read_table`` (requires a ``TableType`` spec)
    * NumPy ``.npy`` -> ``np.load`` (requires a ``TensorType`` spec)
    * anything else -> CSV via ``table_utils`` when a ``TableType`` is requested

    Writing stores tables as Parquet, tensors as NumPy ``.npy`` and any other
    ``Value`` in the mplang container format, creating parent directories.
    """

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Read and decode the file referenced by ``uri``.

        Args:
            uri: Resolved location; ``local_path`` is used, falling back to ``raw``.
            out_spec: Expected output type; validated against the detected
                format and used to select table columns.
            ctx: Kernel context (not used by this provider).

        Raises:
            ValueError: If the mplang header is truncated or has an
                unsupported version, if the detected format does not match
                ``out_spec``, or if a tensor is requested from a file that is
                neither a NumPy ``.npy`` file nor an mplang container.
        """
        path = pathlib.Path(uri.local_path or uri.raw)
        with path.open("rb") as f:
            # Enough bytes to recognize every supported magic prefix
            # (NumPy's b"\x93NUMPY" is the longest at 6 bytes).
            magic_bytes_len_max = 6
            magic = f.read(magic_bytes_len_max)
            f.seek(0)

            if magic.startswith(MAGIC_MPLANG):
                header_len = len(MAGIC_MPLANG) + 1  # magic + version byte
                header = f.read(header_len)
                if len(header) < header_len:
                    raise ValueError(f"truncated mplang header in {path}")
                _, version = struct.unpack(">4sB", header)
                if version != VERSION:
                    raise ValueError(f"unsupported mplang version {version}")
                return decode_value(f.read())

            if magic.startswith(MAGIC_PARQUET):
                return self._read_table(f, out_spec, "parquet", "PARQUET")
            if magic.startswith(MAGIC_ORC):
                return self._read_table(f, out_spec, "orc", "ORC")
            if magic.startswith(MAGIC_NUMPY):
                if not isinstance(out_spec, TensorType):
                    raise ValueError(
                        f"NumPy files require TensorType output spec, got {type(out_spec).__name__}"
                    )
                return np.load(f)

            # No known magic: CSV is the only remaining table format.
            if isinstance(out_spec, TableType):
                return table_utils.read_table(
                    f, format="csv", columns=list(out_spec.column_names())
                )
            # Without the .npy magic, np.load could only return a lazily read
            # .npz archive bound to ``f`` (closed as soon as we return) or try
            # to unpickle (refused by default), so fail with a clear message.
            raise ValueError(
                f"cannot read {path} as a tensor: expected a NumPy .npy file "
                "or an mplang container"
            )

    @staticmethod
    def _read_table(f: Any, out_spec: Any, fmt: str, label: str) -> Any:
        """Read a columnar table file after checking that a table was requested."""
        if not isinstance(out_spec, TableType):
            raise ValueError(
                f"{label} files require TableType output spec, got {type(out_spec).__name__}"
            )
        return table_utils.read_table(
            f, format=fmt, columns=list(out_spec.column_names())
        )

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Write ``value`` to the file referenced by ``uri``.

        Raw (non-``Value``) inputs are wrapped first: table-like objects become
        ``TableValue`` and everything else becomes ``TensorValue``.
        """
        path = uri.local_path or uri.raw
        # Create missing parent directories (a no-op for bare file names).
        pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)

        if not isinstance(value, Value):
            value = (
                TableValue(value)
                if isinstance(value, TableLike)
                else TensorValue(value)
            )

        if isinstance(value, TableValue):
            table_utils.write_table(value.to_arrow(), path, format="parquet")
        elif isinstance(value, TensorValue):
            with open(path, "wb") as f:
                np.save(f, value.to_numpy())
        else:
            payload = encode_value(value)
            with open(path, "wb") as f:
                f.write(struct.pack(">4sB", MAGIC_MPLANG, VERSION))
                f.write(payload)
235
-
236
-
237
class MemProvider(DataProvider):
    """In-memory key/value provider scoped to one runtime (per rank, per session).

    Values live in a dict kept in the runtime state under ``STATE_KEY`` and
    are keyed by the URI string exactly as written (``uri.raw``). As a
    result, ``mem:foo`` and ``mem://foo`` are distinct keys.
    """

    STATE_KEY = "resource.providers.mem"

    @staticmethod
    def _store(ctx: KernelContext) -> dict[str, Any]:
        """Return this runtime's backing dict, creating it on first use."""
        # ensure_state centralizes creation of the entry; anything that is not
        # a dict under this key is rejected.
        state = ctx.runtime.ensure_state(MemProvider.STATE_KEY, dict)
        if isinstance(state, dict):
            return state  # type: ignore[return-value]
        raise TypeError(  # pragma: no cover - defensive
            f"runtime state key '{MemProvider.STATE_KEY}' expected dict, got {type(state).__name__}"
        )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Return the value stored under ``uri.raw``.

        Raises:
            FileNotFoundError: If nothing has been written under that key.
        """
        entries = self._store(ctx)
        wanted = uri.raw
        if wanted in entries:
            return entries[wanted]
        raise FileNotFoundError(f"mem resource not found: {wanted}")

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Store ``value`` under ``uri.raw``, overwriting any previous value."""
        self._store(ctx)[uri.raw] = value
264
-
265
-
266
class S3Provider(DataProvider):
    """Placeholder provider for the ``s3`` scheme.

    It is registered by default so that ``s3`` URIs fail with an explicit
    message. Because the scheme is therefore already taken, a real
    implementation must be installed with ``replace=True``.
    """

    # Shared by read/write. It names ``replace=True``, which is required
    # because register_provider rejects an already-registered scheme otherwise.
    _NOT_INSTALLED_MSG = (
        "S3 provider not installed. Provide an external plugin via "
        "register_provider('s3', ..., replace=True)."
    )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(self._NOT_INSTALLED_MSG)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(self._NOT_INSTALLED_MSG)
280
-
281
-
282
class SecretProvider(DataProvider):
    """Placeholder provider for the ``secret`` scheme.

    Integrate with a KMS or secret manager via a plugin. Because this stub is
    registered by default, the plugin must be installed with ``replace=True``.
    """

    # Shared by read/write. It names ``replace=True``, which is required
    # because register_provider rejects an already-registered scheme otherwise.
    _NOT_INSTALLED_MSG = (
        "secret provider not installed. Provide an external plugin via "
        "register_provider('secret', ..., replace=True)."
    )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(self._NOT_INSTALLED_MSG)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError(self._NOT_INSTALLED_MSG)
296
-
297
-
298
# Install the built-in providers. The s3/secret entries are stubs that make a
# missing plugin fail explicitly; a plugin overrides one with replace=True.
for _scheme, _provider in (
    ("file", FileProvider()),
    ("mem", MemProvider()),
    ("s3", S3Provider()),
    ("secret", SecretProvider()),
):
    register_provider(_scheme, _provider)
del _scheme, _provider
@@ -1,324 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- HTTP-based driver implementation for distributed execution.
17
-
18
- This module provides an HTTP-based driver, using REST APIs
19
- for distributed multi-party computation coordination.
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- import asyncio
25
- import base64
26
- import uuid
27
- from collections.abc import Sequence
28
- from typing import Any
29
-
30
- import numpy as np
31
-
32
- from mplang.v1.core import (
33
- ClusterSpec,
34
- InterpContext,
35
- InterpVar,
36
- IrWriter,
37
- Mask,
38
- MPObject,
39
- MPType,
40
- )
41
- from mplang.v1.core.expr.ast import Expr
42
- from mplang.v1.kernels.value import TableValue, TensorValue
43
- from mplang.v1.runtime.client import HttpExecutorClient
44
-
45
-
46
def new_uuid() -> str:
    """Return a short (8-character) URL-safe random identifier.

    The value is the URL-safe Base64 encoding of a random UUID4, with any
    ``=`` padding removed and truncated to its first 8 characters.
    """
    raw_bytes = uuid.uuid4().bytes
    text = base64.urlsafe_b64encode(raw_bytes).decode("utf-8")
    return text.rstrip("=")[:8]
56
-
57
-
58
class DriverVar(InterpVar):
    """Handle to a value stored on the remote HTTP executor nodes.

    Only the remote symbol name is held locally. ``Driver.fetch`` retrieves
    the value from every node by that name.
    """

    def __init__(self, ctx: Driver, symbol_name: str, mptype: MPType) -> None:
        super().__init__(ctx, mptype)
        # Name under which the executor nodes store this value.
        self.symbol_name = symbol_name

    @property
    def mptype(self) -> MPType:
        """The multi-party type of the referenced value."""
        # NOTE(review): `_mptype` is assumed to be set by InterpVar.__init__.
        return self._mptype

    def __repr__(self) -> str:
        fields = f"symbol_name={self.symbol_name}, mptype={self.mptype}"
        return f"HttpDriverVar({fields})"
81
-
82
-
83
class Driver(InterpContext):
    """Driver for distributed execution using HTTP-based executor services.

    Each public operation (``evaluate``, ``fetch``, ``ping``) runs its own
    event loop with ``asyncio.run`` and uses short-lived HTTP clients that are
    closed before it returns. Because of ``asyncio.run``, these methods cannot
    be called from a thread whose event loop is already running.

    Args:
        cluster_spec: The cluster specification defining the distributed environment.
        trace_ranks: List of ranks to trace execution for debugging.
        timeout: HTTP request timeout in seconds.
    """

    def __init__(
        self,
        cluster_spec: ClusterSpec,
        *,
        trace_ranks: list[int] | None = None,
        timeout: int = 120,
    ) -> None:
        """Initialize a driver with the given cluster specification.

        Args:
            cluster_spec: The cluster specification defining the distributed environment.
            trace_ranks: List of ranks to trace execution for debugging.
            timeout: HTTP request timeout in seconds.

        Raises:
            ValueError: If the cluster spec defines no SPU device or more than one.
        """
        super().__init__(cluster_spec)
        self._trace_ranks = trace_ranks or []
        self.timeout = timeout

        # Created lazily on first use by _get_or_create_session.
        self._session_id: str | None = None
        # Counter backing new_name().
        self._counter = 0

        # node_id -> endpoint. Iteration order also defines each node's rank
        # when the session is created.
        self.node_addrs = {
            node_id: node.endpoint for node_id, node in cluster_spec.nodes.items()
        }

        # Exactly one SPU device is required.
        spu_devices = cluster_spec.get_devices_by_kind("SPU")
        if not spu_devices:
            raise ValueError("No SPU device found in the cluster specification")
        if len(spu_devices) > 1:
            raise ValueError("Multiple SPU devices found in the cluster specification")
        spu_device = spu_devices[0]

        # SPU configuration kept as strings for readability.
        self.spu_protocol_str = spu_device.config["protocol"]
        self.spu_field_str = spu_device.config["field"]

        # Mask of the SPU member ranks, stored as an int.
        spu_mask = Mask.from_ranks([member.rank for member in spu_device.members])
        self.spu_mask_int = int(spu_mask)

    def _create_clients(self) -> dict[str, HttpExecutorClient]:
        """Create one HTTP client per node, in ``node_addrs`` order."""
        return {
            node_id: HttpExecutorClient(endpoint, self.timeout)
            for node_id, endpoint in self.node_addrs.items()
        }

    async def _close_clients(self, clients: dict[str, HttpExecutorClient]) -> None:
        """Close all provided HTTP clients concurrently."""
        await asyncio.gather(*[client.close() for client in clients.values()])

    def new_name(self, prefix: str = "var") -> str:
        """Generate a unique symbol name of the form ``{prefix}_{n}``."""
        name = f"{prefix}_{self._counter}"
        self._counter += 1
        return name

    async def _get_or_create_session(self) -> str:
        """Return the session id, creating the session on every node on first use.

        Raises:
            RuntimeError: If session creation fails on any node, or a node
                reports a session id different from the one requested.
        """
        if self._session_id is None:
            new_session_id = new_uuid()
            clients = self._create_clients()
            try:
                # A node's rank is its position in ``node_addrs``; ``clients``
                # is built from ``node_addrs`` and iterates in the same order.
                tasks = [
                    client.create_session(
                        name=new_session_id,
                        rank=rank,
                        cluster_spec=self.cluster_spec.to_dict(),
                    )
                    for rank, client in enumerate(clients.values())
                ]
                try:
                    results = await asyncio.gather(*tasks)
                except RuntimeError as e:
                    raise RuntimeError(
                        f"Failed to create session on one or more parties: {e}"
                    ) from e
                # Explicit check rather than `assert`, which is stripped under -O.
                mismatched = [sid for sid in results if sid != new_session_id]
                if mismatched:
                    raise RuntimeError(
                        f"Session id mismatch: requested {new_session_id!r}, "
                        f"nodes returned {mismatched!r}"
                    )
                self._session_id = new_session_id
            finally:
                await self._close_clients(clients)

        assert self._session_id is not None  # narrowing for type checkers
        return self._session_id

    async def _evaluate(
        self, expr: Expr, bindings: dict[str, MPObject]
    ) -> Sequence[MPObject]:
        """Run ``expr`` on every node and return handles to its outputs.

        Raises:
            ValueError: If a binding belongs to another context or is not a
                ``DriverVar``.
            RuntimeError: If the computation fails on any node.
        """
        session_id = await self._get_or_create_session()

        # Map each expression input name to the remote symbol holding its value.
        var_name_mapping: dict[str, str] = {}
        for name, var in bindings.items():
            if var.ctx is not self:
                raise ValueError(f"Variable {name} not in this context, got {var.ctx}.")
            # Explicit check rather than `assert` (stripped under -O); the
            # ValueError matches what _fetch raises for the same condition.
            if not isinstance(var, DriverVar):
                raise ValueError(f"Expected HttpDriverVar, got {type(var)}")
            var_name_mapping[name] = var.symbol_name
        party_symbol_names = list(var_name_mapping.values())

        # Serialize the program once and send the same bytes to every node.
        program_bytes = IrWriter(var_name_mapping).dumps(expr).SerializeToString()
        output_symbols = [self.new_name() for _ in range(expr.num_outputs)]

        clients = self._create_clients()
        try:
            computation_id = new_uuid()
            tasks = [
                client.create_and_execute_computation(
                    session_id,
                    computation_id,
                    program_bytes,
                    party_symbol_names,
                    output_symbols,
                )
                for client in clients.values()
            ]
            try:
                await asyncio.gather(*tasks)
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to create and execute computation on one or more parties: {e}"
                ) from e
        finally:
            await self._close_clients(clients)

        return [
            DriverVar(self, symbol_name, mptype)
            for symbol_name, mptype in zip(output_symbols, expr.mptypes, strict=True)
        ]

    def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]:
        """Evaluate an expression using distributed HTTP execution."""
        return asyncio.run(self._evaluate(expr, bindings))

    @staticmethod
    def _to_python(value: Any) -> Any:
        """Convert a fetched runtime value into a NumPy/pandas/Python object."""
        if isinstance(value, TensorValue):
            arr = value.to_numpy()
            # Single-element arrays are unwrapped to a Python scalar.
            if isinstance(arr, np.ndarray) and arr.size == 1:
                return arr.item()
            return arr
        if isinstance(value, TableValue):
            return value.to_pandas()
        return value

    async def _fetch(self, obj: MPObject) -> list[Any]:
        """Fetch ``obj`` from every node; results follow node (rank) order.

        Raises:
            ValueError: If ``obj`` is not a ``DriverVar``.
            RuntimeError: If fetching fails on any node.
        """
        if not isinstance(obj, DriverVar):
            raise ValueError(f"Expected HttpDriverVar, got {type(obj)}")

        session_id = await self._get_or_create_session()

        clients = self._create_clients()
        try:
            tasks = [
                client.get_symbol(session_id, obj.symbol_name)
                for client in clients.values()
            ]
            try:
                results = await asyncio.gather(*tasks)
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to fetch symbol from one or more parties: {e}"
                ) from e
        finally:
            await self._close_clients(clients)

        return [self._to_python(value) for value in results]

    def fetch(self, obj: MPObject) -> list[Any]:
        """Fetch results from the distributed HTTP execution."""
        return asyncio.run(self._fetch(obj))

    async def _ping(self, node_id: str) -> bool:
        """Health-check one node with a temporary client.

        Args:
            node_id: The ID of the node to ping.

        Returns:
            The client's health-check result; False if it raised any exception.

        Raises:
            ValueError: If ``node_id`` is not part of the cluster.
        """
        if node_id not in self.node_addrs:
            raise ValueError(f"Node {node_id} not found in party addresses")

        client = HttpExecutorClient(self.node_addrs[node_id], self.timeout)
        try:
            return await client.health_check()
        except Exception:
            # Deliberate best effort: any failure means "not healthy".
            return False
        finally:
            await client.close()

    def ping(self, node_id: str) -> bool:
        """Ping a node to check if it's healthy.

        Args:
            node_id: The ID of the node to ping.

        Returns:
            True if the node is healthy, False otherwise.
        """
        return asyncio.run(self._ping(node_id))
@@ -1,27 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Custom exception types for the HTTP backend."""
16
-
17
-
18
class HttpBackendError(Exception):
    """Root of the HTTP backend exception hierarchy; catch it to handle any backend error."""


class ResourceNotFound(HttpBackendError):
    """A requested resource (session, computation, etc.) does not exist."""


class InvalidRequestError(HttpBackendError, ValueError):
    """A request was rejected, e.g. due to bad parameters or an invalid state.

    It also derives from ``ValueError`` so generic argument-validation handlers
    catch it.
    """