mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/runtime/server.py +0 -501
@@ -1,501 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module implements the HTTP server for the toy backend.
17
- It uses FastAPI to provide a RESTful API for managing computations.
18
- """
19
-
20
- import base64
21
- import logging
22
- import re
23
- from typing import Any
24
-
25
- from fastapi import (
26
- FastAPI,
27
- HTTPException,
28
- Request,
29
- )
30
- from fastapi.responses import JSONResponse
31
- from pydantic import BaseModel
32
-
33
- from mplang.v1.core import IrReader, TableType, TensorType
34
- from mplang.v1.core.cluster import ClusterSpec
35
- from mplang.v1.kernels.base import KernelContext
36
- from mplang.v1.kernels.value import Value, decode_value, encode_value
37
- from mplang.v1.protos.v1alpha1 import mpir_pb2
38
- from mplang.v1.runtime.data_providers import (
39
- DataProvider,
40
- ResolvedURI,
41
- register_provider,
42
- )
43
- from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
44
- from mplang.v1.runtime.session import (
45
- Computation,
46
- Session,
47
- Symbol,
48
- create_session_from_spec,
49
- )
50
-
51
- logger = logging.getLogger(__name__)
52
-
53
- app = FastAPI()
54
-
55
- # per-server global state
56
- _sessions: dict[str, Session] = {}
57
- _global_symbols: dict[str, Symbol] = {}
58
-
59
-
60
- def register_session(session: Session) -> Session: # pragma: no cover - test helper
61
- existing = _sessions.get(session.name)
62
- if existing:
63
- return existing
64
- _sessions[session.name] = session
65
- return session
66
-
67
-
68
- class _SymbolsProvider(DataProvider):
69
- """Server-local symbols provider backed by BackendRuntime.state."""
70
-
71
- @staticmethod
72
- def _symbol_name(uri: ResolvedURI) -> str:
73
- if uri.scheme != "symbols" or uri.parsed is None:
74
- raise InvalidRequestError(
75
- "symbols provider expects URI in the form symbols://{name}"
76
- )
77
-
78
- parsed = uri.parsed
79
- if parsed.query or parsed.params or parsed.fragment:
80
- raise InvalidRequestError(
81
- "symbols:// URI must not contain query or fragment"
82
- )
83
-
84
- if parsed.netloc:
85
- # e.g. symbols://foo -> name is carried in netloc (path may be empty or "/")
86
- if parsed.path not in ("", "/"):
87
- raise InvalidRequestError("symbols:// URIs cannot include subpaths")
88
- name = parsed.netloc
89
- else:
90
- # e.g. symbols:///foo -> netloc empty, single path segment is the symbol name
91
- path = parsed.path.lstrip("/")
92
- if not path or "/" in path:
93
- raise InvalidRequestError(
94
- "symbols:// URI must specify a single symbol name"
95
- )
96
- name = path
97
-
98
- if not name:
99
- raise InvalidRequestError("symbols:// URI missing symbol name")
100
- return name
101
-
102
- def read(
103
- self,
104
- uri: ResolvedURI,
105
- out_spec: TensorType | TableType,
106
- *,
107
- ctx: KernelContext,
108
- ) -> Any: # type: ignore[override]
109
- name = self._symbol_name(uri)
110
- sym = _global_symbols.get(name)
111
- if sym is None:
112
- raise ResourceNotFound(f"Global symbol '{name}' not found")
113
- return sym.data
114
-
115
- def write(
116
- self,
117
- uri: ResolvedURI,
118
- value: Any,
119
- *,
120
- ctx: KernelContext,
121
- ) -> None: # type: ignore[override]
122
- name = self._symbol_name(uri)
123
- if not isinstance(value, Value):
124
- raise InvalidRequestError(
125
- f"symbols:// write expects Value instance, got {type(value)}"
126
- )
127
- _global_symbols[name] = Symbol(name=name, mptype={}, data=value)
128
-
129
-
130
- # Register symbols provider explicitly for server runtime
131
- register_provider("symbols", _SymbolsProvider())
132
-
133
-
134
- @app.exception_handler(ResourceNotFound)
135
- def resource_not_found_handler(request: Request, exc: ResourceNotFound) -> JSONResponse:
136
- """Handler for ResourceNotFound exceptions."""
137
- logger.warning(f"Resource not found at {request.url}: {exc}")
138
- return JSONResponse(
139
- status_code=404,
140
- content={"detail": str(exc)},
141
- )
142
-
143
-
144
- @app.exception_handler(InvalidRequestError)
145
- def invalid_request_handler(request: Request, exc: InvalidRequestError) -> JSONResponse:
146
- """Handler for InvalidRequestError exceptions."""
147
- logger.warning(f"Invalid request at {request.url}: {exc}")
148
- return JSONResponse(
149
- status_code=400,
150
- content={"detail": str(exc)},
151
- )
152
-
153
-
154
- @app.exception_handler(Exception)
155
- def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
156
- """Global exception handler for better error reporting."""
157
- logger.error(f"Unhandled exception at {request.url}: {exc}", exc_info=True)
158
- return JSONResponse(
159
- status_code=500,
160
- content={
161
- "detail": f"Internal server error: {exc!s}",
162
- "error_type": type(exc).__name__,
163
- "path": str(request.url.path),
164
- },
165
- )
166
-
167
-
168
- def validate_name(name: str, name_type: str) -> None:
169
- """Validate that a name is safe for use in URL paths.
170
-
171
- Args:
172
- name: The name to validate
173
- name_type: Type of name (for error messages, e.g., "session", "computation")
174
-
175
- Raises:
176
- HTTPException: If the name contains invalid characters
177
- """
178
- if not name:
179
- raise HTTPException(status_code=400, detail=f"{name_type} name cannot be empty")
180
-
181
- # Only allow alphanumeric, hyphens, underscores, and dots
182
- if not re.match(r"^[a-zA-Z0-9._-]+$", name):
183
- raise HTTPException(
184
- status_code=400,
185
- detail=f"{name_type} name can only contain letters, numbers, dots, hyphens, and underscores",
186
- )
187
-
188
-
189
- # Request/Response Models
190
- class CreateSessionRequest(BaseModel):
191
- rank: int
192
- cluster_spec: dict
193
-
194
-
195
- class SessionResponse(BaseModel):
196
- name: str
197
-
198
-
199
- class CreateComputationRequest(BaseModel):
200
- mpprogram: str # Base64 encoded MPProgram proto
201
- input_names: list[str] # Mandatory input symbol names
202
- output_names: list[str] # Mandatory output symbol names
203
-
204
-
205
- class ComputationResponse(BaseModel):
206
- name: str
207
-
208
-
209
- class CreateSymbolRequest(BaseModel):
210
- mptype: dict
211
- data: str # Base64 encoded Value data
212
-
213
-
214
- class SymbolResponse(BaseModel):
215
- name: str
216
- mptype: dict
217
- data: str # Base64 encoded Value data
218
-
219
-
220
- class CommSendRequest(BaseModel):
221
- data: str # Base64 encoded binary data
222
- is_raw_bytes: bool = False # True for SPU channel raw bytes
223
-
224
-
225
- # Response Models for enhanced status
226
- class SessionListResponse(BaseModel):
227
- sessions: list[str]
228
-
229
-
230
- class ComputationListResponse(BaseModel):
231
- computations: list[str]
232
-
233
-
234
- class GlobalSymbolResponse(BaseModel):
235
- name: str
236
- mptype: dict
237
- data: str # Base64 encoded Value data
238
-
239
-
240
- @app.get("/health")
241
- async def health_check() -> dict[str, str]:
242
- """Health check endpoint."""
243
- return {"status": "ok"}
244
-
245
-
246
- # List all sessions
247
- @app.get("/sessions", response_model=SessionListResponse)
248
- def list_sessions() -> SessionListResponse:
249
- """List all session names."""
250
- return SessionListResponse(sessions=list(_sessions.keys()))
251
-
252
-
253
- # List all computations in a session
254
- @app.get(
255
- "/sessions/{session_name}/computations", response_model=ComputationListResponse
256
- )
257
- def list_session_computations(session_name: str) -> ComputationListResponse:
258
- """List all computation names in a session."""
259
- sess = _sessions.get(session_name)
260
- if not sess:
261
- raise ResourceNotFound(f"Session '{session_name}' not found")
262
- return ComputationListResponse(computations=sess.list_computations())
263
-
264
-
265
- # Session endpoints
266
- @app.put("/sessions/{session_name}", response_model=SessionResponse)
267
- def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
268
- validate_name(session_name, "session")
269
- # Delegate cluster spec parsing & session construction to resource layer
270
-
271
- if session_name in _sessions:
272
- sess = _sessions[session_name]
273
- else:
274
- spec = ClusterSpec.from_dict(request.cluster_spec)
275
- sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec)
276
- _sessions[session_name] = sess
277
- return SessionResponse(name=sess.name)
278
-
279
-
280
- @app.get("/sessions/{session_name}", response_model=SessionResponse)
281
- def get_session(session_name: str) -> SessionResponse:
282
- sess = _sessions.get(session_name)
283
- if not sess:
284
- raise ResourceNotFound(f"Session '{session_name}' not found")
285
- return SessionResponse(name=sess.name)
286
-
287
-
288
- @app.delete("/sessions/{session_name}")
289
- def delete_session(session_name: str) -> dict[str, str]:
290
- """Delete a session and all its associated resources."""
291
- if session_name in _sessions:
292
- del _sessions[session_name]
293
- logging.info(f"Session {session_name} deleted successfully")
294
- return {"message": f"Session '{session_name}' deleted successfully"}
295
- else:
296
- raise ResourceNotFound(f"Session '{session_name}' not found")
297
-
298
-
299
- # Computation endpoints
300
- @app.put(
301
- "/sessions/{session_name}/computations/{computation_id}",
302
- response_model=ComputationResponse,
303
- )
304
- def create_and_execute_computation(
305
- session_name: str, computation_id: str, request: CreateComputationRequest
306
- ) -> ComputationResponse:
307
- graph_proto = mpir_pb2.GraphProto()
308
- try:
309
- graph_proto.ParseFromString(base64.b64decode(request.mpprogram))
310
- except Exception as e:
311
- raise InvalidRequestError(
312
- f"Invalid base64 or protobuf for mpprogram: {e!s}"
313
- ) from e
314
-
315
- reader = IrReader()
316
- expr = reader.loads(graph_proto)
317
-
318
- if expr is None:
319
- raise InvalidRequestError("Failed to parse expression from protobuf")
320
-
321
- # Create the computation resource
322
- sess = _sessions.get(session_name)
323
- if not sess:
324
- raise ResourceNotFound(f"Session '{session_name}' not found.")
325
- comp = sess.get_computation(computation_id)
326
- if not comp:
327
- comp = Computation(name=computation_id, expr=expr)
328
- sess.add_computation(comp)
329
- sess.execute(comp, request.input_names, request.output_names)
330
- return ComputationResponse(name=computation_id)
331
-
332
-
333
- @app.delete("/sessions/{session_name}/computations/{computation_id}")
334
- def delete_computation(session_name: str, computation_id: str) -> dict[str, str]:
335
- """Delete a specific computation."""
336
- sess = _sessions.get(session_name)
337
- if sess and sess.delete_computation(computation_id):
338
- logging.info(
339
- f"Computation {computation_id} deleted from session {session_name}"
340
- )
341
- return {"message": f"Computation '{computation_id}' deleted successfully"}
342
- else:
343
- raise ResourceNotFound(
344
- f"Computation '{computation_id}' not found in session '{session_name}'"
345
- )
346
-
347
-
348
- # Symbol endpoints
349
- @app.put(
350
- "/sessions/{session_name}/symbols/{symbol_name}", response_model=SymbolResponse
351
- )
352
- def create_session_symbol(
353
- session_name: str, symbol_name: str, request: CreateSymbolRequest
354
- ) -> SymbolResponse:
355
- """Create a symbol in a session."""
356
- sess = _sessions.get(session_name)
357
- if not sess:
358
- raise ResourceNotFound(f"Session '{session_name}' not found.")
359
- try:
360
- obj = decode_value(base64.b64decode(request.data))
361
- except Exception as e:
362
- raise InvalidRequestError(f"Invalid symbol data: {e!s}") from e
363
- symbol = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
364
- sess.add_symbol(symbol)
365
- # Return the base64 data back to client; server stores Python object
366
- return SymbolResponse(
367
- name=symbol.name,
368
- mptype=symbol.mptype,
369
- data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
370
- )
371
-
372
-
373
- @app.get(
374
- "/sessions/{session_name}/symbols/{symbol_name}", response_model=SymbolResponse
375
- )
376
- def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
377
- """Get a symbol from a session."""
378
- try:
379
- logger.debug(
380
- f"Looking for symbol: '{symbol_name}' in session: '{session_name}'"
381
- )
382
- sess = _sessions.get(session_name)
383
- symbol = sess.get_symbol(symbol_name) if sess else None
384
- if not symbol:
385
- raise HTTPException(
386
- status_code=404, detail=f"Symbol {symbol_name} not found"
387
- )
388
-
389
- # symbol data is None means this party does not participate the computation
390
- # that produced the symbol.
391
- if symbol.data is None:
392
- raise ResourceNotFound(f"Symbol '{symbol_name}' has no data on this party")
393
-
394
- # Serialize using Value envelope
395
- return SymbolResponse(
396
- name=symbol.name,
397
- mptype=symbol.mptype,
398
- data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
399
- )
400
- except ValueError as e:
401
- raise HTTPException(status_code=404, detail=str(e)) from e
402
-
403
-
404
- @app.get("/sessions/{session_name}/symbols")
405
- def list_session_symbols(session_name: str) -> dict[str, list[str]]:
406
- """List all symbols in a session."""
407
- sess = _sessions.get(session_name)
408
- if not sess:
409
- raise ResourceNotFound(f"Session '{session_name}' not found.")
410
- symbols = sess.list_symbols()
411
- return {"symbols": symbols}
412
-
413
-
414
- @app.delete("/sessions/{session_name}/symbols/{symbol_name}")
415
- def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
416
- """Delete a specific symbol."""
417
- sess = _sessions.get(session_name)
418
- if sess and sess.delete_symbol(symbol_name):
419
- logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
420
- return {"message": f"Symbol '{symbol_name}' deleted successfully"}
421
- else:
422
- raise ResourceNotFound(
423
- f"Symbol '{symbol_name}' not found in session '{session_name}'"
424
- )
425
-
426
-
427
- # Global Symbols endpoints
428
- @app.put("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
429
- def create_global_symbol(
430
- symbol_name: str, request: CreateSymbolRequest
431
- ) -> GlobalSymbolResponse:
432
- validate_name(symbol_name, "symbol")
433
- try:
434
- obj = decode_value(base64.b64decode(request.data))
435
- except Exception as e:
436
- raise InvalidRequestError(f"Invalid global symbol data: {e!s}") from e
437
- sym = Symbol(name=symbol_name, mptype=request.mptype, data=obj)
438
- _global_symbols[symbol_name] = sym
439
- return GlobalSymbolResponse(
440
- name=sym.name,
441
- mptype=sym.mptype,
442
- data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
443
- )
444
-
445
-
446
- @app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
447
- def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse: # route handler
448
- sym = _global_symbols.get(symbol_name)
449
- if not sym:
450
- raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
451
- # Serialize using Value envelope
452
- return GlobalSymbolResponse(
453
- name=sym.name,
454
- mptype=sym.mptype,
455
- data=base64.b64encode(encode_value(sym.data)).decode("utf-8"),
456
- )
457
-
458
-
459
- @app.get("/api/v1/symbols")
460
- def list_global_symbols() -> dict[str, list[str]]:
461
- return {"symbols": list(_global_symbols.keys())}
462
-
463
-
464
- @app.delete("/api/v1/symbols/{symbol_name}")
465
- def delete_global_symbol(symbol_name: str) -> dict[str, str]: # route handler
466
- if symbol_name in _global_symbols:
467
- del _global_symbols[symbol_name]
468
- return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
469
- else:
470
- raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
471
-
472
-
473
- # Communication endpoints
474
- # TODO(jint) this should be computation level, add multi computation parallel support.
475
- @app.put("/sessions/{session_name}/comm/{key}/from/{from_rank}")
476
- def comm_send(
477
- session_name: str, key: str, from_rank: int, request: CommSendRequest
478
- ) -> dict[str, str]:
479
- """
480
- Receive a message from another party and deliver it to the session's communicator.
481
- This endpoint runs on the receiver's server.
482
- """
483
- sess = _sessions.get(session_name)
484
- if not sess or not sess.communicator:
485
- logger.error(f"Session or communicator not found: session={session_name}")
486
- raise HTTPException(status_code=404, detail="Session or communicator not found")
487
-
488
- # The receiver rank should be the rank of the server hosting this endpoint
489
- # We don't need to validate to_rank since the request is coming to this server
490
-
491
- # For raw bytes (SPU channel), pass through as dict with flag
492
- # For normal data, pass the base64 string directly
493
- data_payload: str | dict[str, object]
494
- if request.is_raw_bytes:
495
- data_payload = {"data": request.data, "is_raw_bytes": True}
496
- else:
497
- data_payload = request.data
498
-
499
- # Use the proper onSent mechanism from CommunicatorBase
500
- sess.communicator.onSent(from_rank, key, data_payload)
501
- return {"status": "ok"}