mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -1,456 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- HTTP Executor Client Library.
17
-
18
- This module provides a clean HTTP client interface for interacting with
19
- HTTP-based executor services. It handles all HTTP communication details
20
- and provides domain-specific methods for session, computation, and symbol management.
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import base64
26
- from typing import Any
27
-
28
- import httpx
29
-
30
- from mplang.v1.kernels.value import Value, decode_value, encode_value
31
-
32
-
33
class ExecutionStatus:
    """String constants naming the states a computation execution can be in.

    The attributes are plain ``str`` values, so they compare equal to the
    literal strings they hold.
    """

    # Listed in the order pending -> running -> completed / failed.
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
40
-
41
-
42
class HttpExecutorClient:
    """Async HTTP client for an HTTP-based executor service.

    Wraps one ``httpx.AsyncClient`` bound to a single endpoint and exposes
    session, computation and symbol operations. HTTP status errors and
    transport errors are re-raised as ``RuntimeError`` with the original
    httpx exception attached as ``__cause__``.

    The instance owns the underlying httpx client: call :meth:`close` when
    finished, or use ``async with HttpExecutorClient(...) as client:``.
    """

    def __init__(self, endpoint: str, timeout: int = 60):
        """Initialize the HTTP executor client.

        Args:
            endpoint: Base URL of the executor service. ``http://`` is
                prepended when no http/https scheme is present, and trailing
                slashes are removed.
            timeout: Default timeout for HTTP requests, in seconds.
        """
        self.endpoint = self._normalize_endpoint(endpoint)
        self.timeout = timeout
        self._client = httpx.AsyncClient(base_url=self.endpoint, timeout=self.timeout)

    async def __aenter__(self) -> "HttpExecutorClient":
        """Return the client itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the underlying HTTP client when the ``async with`` block exits."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    # Internal helpers
    @staticmethod
    def _normalize_endpoint(endpoint: str) -> str:
        """Return ``endpoint`` with a scheme and without trailing slashes.

        The scheme check is case-insensitive (URL schemes are), so
        ``HTTP://host`` is left alone instead of becoming ``http://HTTP://host``.
        """
        if not endpoint.lower().startswith(("http://", "https://")):
            endpoint = f"http://{endpoint}"
        return endpoint.rstrip("/")

    def _raise_http_error(self, action: str, e: Exception) -> RuntimeError:
        """Build (but do not raise) the ``RuntimeError`` for a failed request.

        Args:
            action: Short description used in the message, e.g. ``"get session s1"``.
            e: The exception caught by the caller.

        Returns:
            ``RuntimeError("Failed to <action>: <detail>")``. For an
            ``httpx.HTTPStatusError`` the detail is the ``"detail"`` field of
            the JSON response body when available; otherwise it is ``str(e)``.
        """
        if isinstance(e, httpx.HTTPStatusError):
            try:
                error_detail = e.response.json().get("detail", str(e))
            except Exception:
                # Body is not JSON, or is JSON but not an object: fall back
                # to the exception text rather than masking the real failure.
                error_detail = str(e)
        else:
            error_detail = str(e)
        return RuntimeError(f"Failed to {action}: {error_detail}")

    async def _send(
        self,
        verb: str,
        url: str,
        action: str,
        *,
        allow_404: bool = False,
        **kwargs: Any,
    ) -> Any:
        """Issue one request and convert httpx failures to ``RuntimeError``.

        Args:
            verb: Name of the ``httpx.AsyncClient`` method to call
                (``"get"``, ``"put"`` or ``"delete"``).
            url: Path relative to the client's base URL.
            action: Description used in the error message.
            allow_404: When true, a 404 response yields ``None`` instead of
                an error.
            **kwargs: Forwarded to the httpx call (e.g. ``json=payload``).

        Returns:
            The httpx response, or ``None`` for a tolerated 404.

        Raises:
            RuntimeError: On an error status or a transport-level failure.
        """
        try:
            response = await getattr(self._client, verb)(url, **kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if allow_404 and e.response is not None and e.response.status_code == 404:
                return None
            raise self._raise_http_error(action, e) from e
        except httpx.RequestError as e:
            raise self._raise_http_error(action, e) from e
        return response

    @staticmethod
    def _encode_symbol_payload(data: Any, mptype: dict | None) -> dict[str, Any]:
        """Build the JSON body for a symbol PUT.

        ``data`` is serialized with ``encode_value`` and base64-encoded.

        Raises:
            TypeError: If ``data`` is not a ``Value`` instance.
        """
        if not isinstance(data, Value):
            raise TypeError(f"Data must be a Value instance, got {type(data)}")
        data_b64 = base64.b64encode(encode_value(data)).decode("utf-8")
        return {"data": data_b64, "mptype": mptype or {}}

    @staticmethod
    def _decode_symbol_payload(payload: Any) -> Any:
        """Decode the base64 ``"data"`` field of a symbol response with ``decode_value``."""
        return decode_value(base64.b64decode(payload["data"]))

    # Session Management
    async def create_session(
        self,
        name: str,
        rank: int,
        cluster_spec: dict,
    ) -> str:
        """Create a new session.

        Args:
            name: Session name/ID.
            rank: This party's rank.
            cluster_spec: Full cluster specification dict (ClusterSpec.to_dict()).

        Returns:
            The session name/ID as reported by the server.

        Raises:
            RuntimeError: If session creation fails.
        """
        payload: dict[str, Any] = {"rank": rank, "cluster_spec": cluster_spec}
        response = await self._send(
            "put", f"/sessions/{name}", "create session", json=payload
        )
        return str(response.json()["name"])

    async def get_session(self, session_name: str) -> dict[str, Any]:
        """Get session information.

        Args:
            session_name: The session name/ID.

        Returns:
            Session information dictionary.

        Raises:
            RuntimeError: If session retrieval fails.
        """
        response = await self._send(
            "get", f"/sessions/{session_name}", f"get session {session_name}"
        )
        return dict(response.json())

    async def delete_session(self, session_name: str) -> None:
        """Delete a session and all its associated resources.

        Args:
            session_name: The session name/ID.

        Raises:
            RuntimeError: If session deletion fails.
        """
        await self._send(
            "delete", f"/sessions/{session_name}", f"delete session {session_name}"
        )

    # Computation Management
    async def create_and_execute_computation(
        self,
        session_id: str,
        computation_id: str,
        program: bytes,
        input_names: list[str],
        output_names: list[str],
    ) -> str:
        """Submit a computation to a session.

        The program is base64-encoded and PUT to the computation resource.
        NOTE(review): the method name implies the server also executes the
        computation; that is not visible from this client - confirm against
        the server.

        Args:
            session_id: The session name/ID.
            computation_id: The computation name/ID.
            program: Serialized computation program (protobuf bytes).
            input_names: List of input symbol names.
            output_names: List of output symbol names.

        Returns:
            The computation name/ID as reported by the server.

        Raises:
            RuntimeError: If computation creation fails.
        """
        payload = {
            "mpprogram": base64.b64encode(program).decode("utf-8"),
            "input_names": input_names,
            "output_names": output_names,
        }
        response = await self._send(
            "put",
            f"/sessions/{session_id}/computations/{computation_id}",
            "create computation",
            json=payload,
        )
        return str(response.json()["name"])

    async def get_computation(
        self, session_id: str, computation_id: str
    ) -> dict[str, Any]:
        """Get computation information.

        Args:
            session_id: The session name/ID.
            computation_id: The computation name/ID.

        Returns:
            Computation information dictionary.

        Raises:
            RuntimeError: If computation retrieval fails.
        """
        response = await self._send(
            "get",
            f"/sessions/{session_id}/computations/{computation_id}",
            f"get computation {computation_id}",
        )
        return dict(response.json())

    async def delete_computation(self, session_id: str, computation_id: str) -> None:
        """Delete a computation from a session.

        Args:
            session_id: The session name/ID.
            computation_id: The computation name/ID.

        Raises:
            RuntimeError: If computation deletion fails.
        """
        await self._send(
            "delete",
            f"/sessions/{session_id}/computations/{computation_id}",
            f"delete computation {computation_id}",
        )

    # Symbol Management
    async def create_symbol(
        self, session_name: str, symbol_name: str, data: Any, mptype: dict | None = None
    ) -> None:
        """Create a symbol with data.

        Args:
            session_name: The session name/ID.
            symbol_name: The symbol name/ID.
            data: The ``Value`` to store.
            mptype: Optional type information; an empty dict is sent when omitted.

        Raises:
            TypeError: If ``data`` is not a ``Value`` instance.
            RuntimeError: If symbol creation fails.
        """
        payload = self._encode_symbol_payload(data, mptype)
        await self._send(
            "put",
            f"/sessions/{session_name}/symbols/{symbol_name}",
            f"create symbol {symbol_name}",
            json=payload,
        )

    async def get_symbol(self, session_name: str, symbol_name: str) -> Any:
        """Get symbol data.

        Args:
            session_name: The session name/ID.
            symbol_name: The symbol name/ID. It is interpolated into the URL
                path as-is, so it should not contain slashes.

        Returns:
            The deserialized symbol data, or ``None`` if the server answers 404.

        Raises:
            RuntimeError: If symbol retrieval fails for any other reason.
        """
        response = await self._send(
            "get",
            f"/sessions/{session_name}/symbols/{symbol_name}",
            f"get symbol {symbol_name}",
            allow_404=True,
        )
        if response is None:
            return None
        return self._decode_symbol_payload(response.json())

    async def delete_symbol(self, session_name: str, symbol_name: str) -> None:
        """Delete a symbol from a session.

        Args:
            session_name: The session name/ID.
            symbol_name: The symbol name/ID.

        Raises:
            RuntimeError: If symbol deletion fails.
        """
        await self._send(
            "delete",
            f"/sessions/{session_name}/symbols/{symbol_name}",
            f"delete symbol {symbol_name}",
        )

    async def health_check(self) -> bool:
        """Perform a health check on the HTTP executor service.

        Returns:
            True if the response's ``"status"`` field equals ``"ok"``,
            False otherwise.

        Raises:
            RuntimeError: If the request itself fails or returns an error status.
        """
        response = await self._send("get", "/health", "perform health check")
        return bool(response.json().get("status") == "ok")

    async def list_symbols(self, session_name: str) -> list[str]:
        """List all symbols in a session.

        Args:
            session_name: The session name/ID.

        Returns:
            List of symbol names.

        Raises:
            RuntimeError: If symbol listing fails.
        """
        response = await self._send(
            "get", f"/sessions/{session_name}/symbols", "list symbols"
        )
        return list(response.json()["symbols"])

    async def list_sessions(self) -> list[str]:
        """List all sessions on this node.

        Returns:
            List of session names.

        Raises:
            RuntimeError: If session listing fails.
        """
        response = await self._send("get", "/sessions", "list sessions")
        return list(response.json()["sessions"])

    async def list_computations(self, session_name: str) -> list[str]:
        """List all computations in a session.

        Args:
            session_name: The session name/ID.

        Returns:
            List of computation names.

        Raises:
            RuntimeError: If computation listing fails.
        """
        response = await self._send(
            "get",
            f"/sessions/{session_name}/computations",
            f"list computations for session {session_name}",
        )
        return list(response.json()["computations"])

    # ---------------- Global Symbols (process-level) ----------------
    async def create_global_symbol(
        self, symbol_name: str, data: Any, mptype: dict | None = None
    ) -> None:
        """Create or replace a process-global symbol.

        Args:
            symbol_name: Identifier.
            data: The ``Value`` to store (serialized with ``encode_value``).
            mptype: Optional metadata dict; an empty dict is sent when omitted.

        Raises:
            TypeError: If ``data`` is not a ``Value`` instance.
            RuntimeError: If the request fails.
        """
        payload = self._encode_symbol_payload(data, mptype)
        await self._send(
            "put",
            f"/api/v1/symbols/{symbol_name}",
            f"create global symbol {symbol_name}",
            json=payload,
        )

    async def get_global_symbol(self, symbol_name: str) -> Any:
        """Fetch and deserialize a process-global symbol.

        Unlike :meth:`get_symbol`, a 404 is not tolerated here and surfaces
        as ``RuntimeError``.

        Raises:
            RuntimeError: If the request fails.
        """
        resp = await self._send(
            "get",
            f"/api/v1/symbols/{symbol_name}",
            f"get global symbol {symbol_name}",
        )
        return self._decode_symbol_payload(resp.json())

    async def delete_global_symbol(self, symbol_name: str) -> None:
        """Delete a process-global symbol.

        Raises:
            RuntimeError: If the request fails.
        """
        await self._send(
            "delete",
            f"/api/v1/symbols/{symbol_name}",
            f"delete global symbol {symbol_name}",
        )

    async def list_global_symbols(self) -> list[str]:
        """List process-global symbol names (empty if the response has no ``"symbols"`` key).

        Raises:
            RuntimeError: If the request fails.
        """
        resp = await self._send("get", "/api/v1/symbols", "list global symbols")
        return list(resp.json().get("symbols", []))
@@ -1,131 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides a client-side communicator for interacting with the HTTP backend.
17
- Its sole responsibility is to handle inter-party data exchange (send/recv).
18
- """
19
-
20
- import base64
21
- import logging
22
- from typing import Any
23
-
24
- import httpx
25
-
26
- from mplang.v1.core.comm import CommunicatorBase
27
- from mplang.v1.kernels.value import Value, decode_value, encode_value
28
-
29
-
30
class HttpCommunicator(CommunicatorBase):
    """Communicator that sends data to peer parties over HTTP.

    ``send`` PUTs a base64-encoded payload to the peer's
    ``/sessions/{session}/comm/{key}/from/{rank}`` endpoint. ``recv``
    delegates the waiting to ``CommunicatorBase.recv`` and decodes whatever
    that returns.
    """

    def __init__(
        self,
        session_name: str,
        rank: int,
        endpoints: list[str],
        timeout: float = 60,
    ):
        """Create a communicator for one party.

        Args:
            session_name: Session name used in the peer URL.
            rank: This party's rank.
            endpoints: One endpoint per party, indexed by rank. ``http://``
                is prepended to entries without an http/https scheme.
            timeout: Timeout in seconds for each outgoing PUT (default 60,
                the value that was previously hard-coded in ``send``).

        Raises:
            ValueError: If ``endpoints`` is empty or contains an empty element.
        """
        # Validate endpoints
        if not endpoints:
            raise ValueError("endpoints cannot be empty")
        if not all(endpoints):
            raise ValueError("endpoints cannot contain empty elements")

        super().__init__(rank, len(endpoints))
        self.session_name = session_name
        # Ensure all endpoints have protocol prefix
        self.endpoints = [
            endpoint
            if endpoint.startswith(("http://", "https://"))
            else f"http://{endpoint}"
            for endpoint in endpoints
        ]
        self._timeout = timeout
        self._counter = 0
        logging.info(
            "HttpCommunicator initialized: session=%s, rank=%s, endpoints=%s",
            session_name,
            rank,
            self.endpoints,
        )

    # override
    def new_id(self) -> str:
        """Return the current counter value as a string, then increment it."""
        res = self._counter
        self._counter += 1
        return str(res)

    def send(self, to: int, key: str, data: Any) -> None:
        """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.

        Supports two modes:
        - SPU channel (key starts with "spu:" and data is ``bytes``): sends
          the raw bytes, base64-encoded, flagged with ``is_raw_bytes``
        - Normal channel: data must be a ``Value``; it is serialized with
          ``encode_value`` and base64-encoded

        Args:
            to: Rank of the receiving party (index into ``self.endpoints``).
            key: Message key; becomes part of the URL path.
            data: ``bytes`` (SPU channel) or a ``Value`` instance.

        Raises:
            TypeError: If ``data`` matches neither mode.
            OSError: If the request could not be delivered (``httpx.RequestError``).
            httpx.HTTPStatusError: If the peer answers with an error status.
        """
        target_endpoint = self.endpoints[to]
        url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}"
        logging.debug(
            "Sending data: from_rank=%s, to_rank=%s, key=%s, target_url=%s",
            self._rank,
            to,
            key,
            url,
        )

        # Build the payload before touching the network so a bad argument
        # fails fast with TypeError.
        if key.startswith("spu:") and isinstance(data, bytes):
            # SPU channel mode: send raw bytes directly
            request_data = {
                "data": base64.b64encode(data).decode("utf-8"),
                "is_raw_bytes": True,
            }
        elif isinstance(data, Value):
            # Normal mode: serialize using Value envelope
            request_data = {
                "data": base64.b64encode(encode_value(data)).decode("utf-8")
            }
        else:
            raise TypeError(
                f"Communicator requires Value instance, got {type(data).__name__}. "
                "Wrap data in TensorValue or custom Value subclass."
            )

        try:
            response = httpx.put(url, json=request_data, timeout=self._timeout)
        except httpx.RequestError as e:
            logging.error(
                "Send failed with exception: from_rank=%s, to_rank=%s, key=%s, error=%s",
                self._rank,
                to,
                key,
                e,
            )
            raise OSError(f"Failed to send data to rank {to}") from e

        logging.debug("Send response: status=%s", response.status_code)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError:
            # Log only when the status is actually treated as a failure, so a
            # successful non-200 2xx response is not reported as an error.
            logging.error("Send failed: %s", response.text)
            raise

    def recv(self, frm: int, key: str) -> Any:
        """Wait until the key is set, returns the value.

        Supports two modes:
        - SPU channel (payload is a dict with a truthy ``is_raw_bytes``):
          returns the base64-decoded raw bytes
        - Normal channel: returns the Value decoded with ``decode_value``

        Args:
            frm: Rank of the sending party.
            key: Message key.

        Raises:
            TypeError: If the received payload is neither a ``str`` nor a
                ``dict``, or carries no ``"data"`` entry.
        """
        logging.debug(
            "Waiting to receive: from_rank=%s, to_rank=%s, key=%s",
            frm,
            self._rank,
            key,
        )
        received_data = super().recv(frm, key)

        # Check if this is raw bytes (SPU channel)
        if isinstance(received_data, dict) and received_data.get("is_raw_bytes"):
            data_bytes = base64.b64decode(received_data["data"])
            logging.debug(
                "Received raw bytes: from_rank=%s, to_rank=%s, key=%s, size=%s",
                frm,
                self._rank,
                key,
                len(data_bytes),
            )
            return data_bytes

        # Normal mode: deserialize Value envelope
        if isinstance(received_data, str):
            data_b64 = received_data
        elif isinstance(received_data, dict):
            data_b64 = received_data.get("data")
        else:
            raise TypeError(
                f"Unexpected payload type {type(received_data).__name__} "
                f"received from rank {frm} for key {key}"
            )
        if data_b64 is None:
            raise TypeError(
                f"Payload received from rank {frm} for key {key} has no 'data' entry"
            )
        result = decode_value(base64.b64decode(data_b64))

        logging.debug(
            "Received data: from_rank=%s, to_rank=%s, key=%s",
            frm,
            self._rank,
            key,
        )
        return result