mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/libs/mpc/psi/rr22.py +303 -0
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang/v2/libs/mpc/psi/rr22.py +0 -344
- mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/runtime/server.py
DELETED
|
@@ -1,501 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
This module implements the HTTP server for the toy backend.
|
|
17
|
-
It uses FastAPI to provide a RESTful API for managing computations.
|
|
18
|
-
"""
|
|
19
|
-
|
|
20
|
-
import base64
|
|
21
|
-
import logging
|
|
22
|
-
import re
|
|
23
|
-
from typing import Any
|
|
24
|
-
|
|
25
|
-
from fastapi import (
|
|
26
|
-
FastAPI,
|
|
27
|
-
HTTPException,
|
|
28
|
-
Request,
|
|
29
|
-
)
|
|
30
|
-
from fastapi.responses import JSONResponse
|
|
31
|
-
from pydantic import BaseModel
|
|
32
|
-
|
|
33
|
-
from mplang.v1.core import IrReader, TableType, TensorType
|
|
34
|
-
from mplang.v1.core.cluster import ClusterSpec
|
|
35
|
-
from mplang.v1.kernels.base import KernelContext
|
|
36
|
-
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
37
|
-
from mplang.v1.protos.v1alpha1 import mpir_pb2
|
|
38
|
-
from mplang.v1.runtime.data_providers import (
|
|
39
|
-
DataProvider,
|
|
40
|
-
ResolvedURI,
|
|
41
|
-
register_provider,
|
|
42
|
-
)
|
|
43
|
-
from mplang.v1.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
|
44
|
-
from mplang.v1.runtime.session import (
|
|
45
|
-
Computation,
|
|
46
|
-
Session,
|
|
47
|
-
Symbol,
|
|
48
|
-
create_session_from_spec,
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
# Module-level logger used by the handlers in this server.
logger = logging.getLogger(__name__)

# FastAPI application; the route and exception handlers below register on it.
app = FastAPI()

# per-server global state
# Both registries are plain in-memory dicts; nothing in this module persists
# them, so their contents live only as long as this process.
# Sessions, keyed by session name.
_sessions: dict[str, Session] = {}
# Session-independent symbols, keyed by symbol name. Served by the
# /api/v1/symbols endpoints and also read/written via the symbols:// provider.
_global_symbols: dict[str, Symbol] = {}
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def register_session(session: Session) -> Session:  # pragma: no cover - test helper
    """Register ``session`` under its name unless one is already registered.

    Args:
        session: Session to register, keyed by ``session.name``.

    Returns:
        The session already registered under that name if there is one,
        otherwise ``session`` itself (which is now registered).
    """
    # setdefault does lookup + insert in one dict operation and treats only a
    # *missing* key as "not registered". The previous truthiness check would
    # have replaced an existing session that happened to evaluate as falsy.
    return _sessions.setdefault(session.name, session)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class _SymbolsProvider(DataProvider):
    """Server-local provider for ``symbols://`` URIs.

    Reads and writes the module-level ``_global_symbols`` registry, which is
    the same store served by the ``/api/v1/symbols`` endpoints.
    """

    @staticmethod
    def _symbol_name(uri: ResolvedURI) -> str:
        """Extract the symbol name from a ``symbols://`` URI.

        Accepted forms are ``symbols://{name}`` (path empty or ``/``) and
        ``symbols:///{name}`` (name as the single path segment).

        Raises:
            InvalidRequestError: If the scheme is not ``symbols``, the URI has
                no parsed form, it carries params/query/fragment, or it does
                not name exactly one symbol.
        """
        if uri.scheme != "symbols" or uri.parsed is None:
            raise InvalidRequestError(
                "symbols provider expects URI in the form symbols://{name}"
            )

        parsed = uri.parsed
        if parsed.query or parsed.params or parsed.fragment:
            # Params are rejected too, so the message now says so.
            raise InvalidRequestError(
                "symbols:// URI must not contain params, query or fragment"
            )

        if parsed.netloc:
            # e.g. symbols://foo -> name is carried in netloc (path may be empty or "/")
            if parsed.path not in ("", "/"):
                raise InvalidRequestError("symbols:// URIs cannot include subpaths")
            name = parsed.netloc
        else:
            # e.g. symbols:///foo -> netloc empty, single path segment is the symbol name
            path = parsed.path.lstrip("/")
            if not path or "/" in path:
                raise InvalidRequestError(
                    "symbols:// URI must specify a single symbol name"
                )
            name = path

        # Defensive: both branches above already guarantee a non-empty name.
        if not name:
            raise InvalidRequestError("symbols:// URI missing symbol name")
        return name

    def read(
        self,
        uri: ResolvedURI,
        out_spec: TensorType | TableType,
        *,
        ctx: KernelContext,
    ) -> Any:  # type: ignore[override]
        """Return the data of the global symbol named by ``uri``.

        ``out_spec`` and ``ctx`` are accepted for interface compatibility but
        are not used here.

        Raises:
            InvalidRequestError: If ``uri`` is malformed.
            ResourceNotFound: If no global symbol has that name.
        """
        name = self._symbol_name(uri)
        sym = _global_symbols.get(name)
        if sym is None:
            raise ResourceNotFound(f"Global symbol '{name}' not found")
        return sym.data

    def write(
        self,
        uri: ResolvedURI,
        value: Any,
        *,
        ctx: KernelContext,
    ) -> None:  # type: ignore[override]
        """Store ``value`` as the global symbol named by ``uri``.

        Any existing symbol with that name is replaced. The stored symbol
        gets an empty ``mptype``; ``ctx`` is not used.

        Raises:
            InvalidRequestError: If ``uri`` is malformed or ``value`` is not a
                ``Value`` instance.
        """
        name = self._symbol_name(uri)
        if not isinstance(value, Value):
            raise InvalidRequestError(
                f"symbols:// write expects Value instance, got {type(value)}"
            )
        _global_symbols[name] = Symbol(name=name, mptype={}, data=value)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# Register symbols provider explicitly for server runtime.
# This runs at import time: importing this module registers a
# _SymbolsProvider (backed by _global_symbols) for the "symbols" scheme.
register_provider("symbols", _SymbolsProvider())
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
@app.exception_handler(ResourceNotFound)
def resource_not_found_handler(request: Request, exc: ResourceNotFound) -> JSONResponse:
    """Translate a ``ResourceNotFound`` into an HTTP 404 JSON response.

    The response body has the shape ``{"detail": <exception message>}``.
    """
    logger.warning("Resource not found at %s: %s", request.url, exc)
    body = {"detail": str(exc)}
    return JSONResponse(status_code=404, content=body)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
@app.exception_handler(InvalidRequestError)
def invalid_request_handler(request: Request, exc: InvalidRequestError) -> JSONResponse:
    """Translate an ``InvalidRequestError`` into an HTTP 400 JSON response.

    The response body has the shape ``{"detail": <exception message>}``.
    """
    logger.warning("Invalid request at %s: %s", request.url, exc)
    body = {"detail": str(exc)}
    return JSONResponse(status_code=400, content=body)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
@app.exception_handler(Exception)
def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler: log the failure with its traceback and return HTTP 500.

    The JSON body carries the exception message, the exception class name
    and the request path.
    """
    # NOTE(review): the raw exception text is returned to the client; confirm
    # that exposing internal error messages is acceptable for this server.
    logger.error("Unhandled exception at %s: %s", request.url, exc, exc_info=True)
    payload = {
        "detail": f"Internal server error: {exc!s}",
        "error_type": type(exc).__name__,
        "path": str(request.url.path),
    }
    return JSONResponse(status_code=500, content=payload)
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def validate_name(name: str, name_type: str) -> None:
    """Validate that a name is safe for use as a single URL path segment.

    A valid name:
      * is non-empty,
      * consists only of ASCII letters, digits, dots, hyphens and underscores,
      * is not one of the dot-segments ``.`` or ``..``. URL path
        normalization collapses these, so the request would address a
        different path.

    Args:
        name: The name to validate.
        name_type: Kind of name, used in error messages (e.g. "session",
            "symbol").

    Raises:
        HTTPException: With status 400 if the name is empty, contains
            disallowed characters, or is a dot-segment.
    """
    if not name:
        raise HTTPException(status_code=400, detail=f"{name_type} name cannot be empty")

    # Only allow alphanumeric, hyphens, underscores, and dots.
    # fullmatch rather than re.match(r"^...$"): "$" also matches just before a
    # trailing newline, so the old pattern accepted names such as "abc\n".
    if not re.fullmatch(r"[a-zA-Z0-9._-]+", name):
        raise HTTPException(
            status_code=400,
            detail=f"{name_type} name can only contain letters, numbers, dots, hyphens, and underscores",
        )

    if name in (".", ".."):
        raise HTTPException(
            status_code=400,
            detail=f"{name_type} name cannot be '.' or '..'",
        )
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
# Request/Response Models
|
|
190
|
-
class CreateSessionRequest(BaseModel):
    """Body of ``PUT /sessions/{session_name}``."""

    rank: int  # passed as ``rank`` to create_session_from_spec
    cluster_spec: dict  # parsed with ClusterSpec.from_dict
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
class SessionResponse(BaseModel):
    """Session name returned by the session create and get endpoints."""

    name: str
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
class CreateComputationRequest(BaseModel):
    """Body of ``PUT /sessions/{session_name}/computations/{computation_id}``."""

    mpprogram: str  # Base64 encoded MPProgram proto (parsed as mpir_pb2.GraphProto)
    input_names: list[str]  # Mandatory input symbol names
    output_names: list[str]  # Mandatory output symbol names
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
class ComputationResponse(BaseModel):
    """Computation name returned after a computation is created/executed."""

    name: str
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
class CreateSymbolRequest(BaseModel):
    """Body of the session-symbol and global-symbol PUT endpoints."""

    mptype: dict  # stored verbatim on the Symbol and echoed back
    data: str  # Base64 encoded Value data (decoded with decode_value)
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
class SymbolResponse(BaseModel):
    """Session symbol returned by the session-symbol PUT and GET endpoints."""

    name: str
    mptype: dict
    data: str  # Base64 encoded Value data (produced with encode_value)
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
class CommSendRequest(BaseModel):
    """Body of ``PUT /sessions/{session_name}/comm/{key}/from/{from_rank}``."""

    data: str  # Base64 encoded binary data
    is_raw_bytes: bool = False  # True for SPU channel raw bytes
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
# Response Models for enhanced status
|
|
226
|
-
class SessionListResponse(BaseModel):
    """Response of ``GET /sessions``."""

    sessions: list[str]  # session names
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
class ComputationListResponse(BaseModel):
    """Response of ``GET /sessions/{session_name}/computations``."""

    computations: list[str]  # computation names, as returned by Session.list_computations
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
class GlobalSymbolResponse(BaseModel):
    """Global symbol returned by the ``/api/v1/symbols/{symbol_name}`` PUT and GET endpoints."""

    name: str
    mptype: dict
    data: str  # Base64 encoded Value data (produced with encode_value)
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Liveness probe; always answers with ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
# List all sessions
|
|
247
|
-
@app.get("/sessions", response_model=SessionListResponse)
def list_sessions() -> SessionListResponse:
    """Return the names of every session registered on this server."""
    names = [name for name in _sessions]
    return SessionListResponse(sessions=names)
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
# List all computations in a session
|
|
254
|
-
@app.get(
    "/sessions/{session_name}/computations", response_model=ComputationListResponse
)
def list_session_computations(session_name: str) -> ComputationListResponse:
    """Return the computation names held by session ``session_name``.

    Raises:
        ResourceNotFound: If no such session is registered (HTTP 404).
    """
    session = _sessions.get(session_name)
    if session:
        return ComputationListResponse(computations=session.list_computations())
    raise ResourceNotFound(f"Session '{session_name}' not found")
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
# Session endpoints
|
|
266
|
-
@app.put("/sessions/{session_name}", response_model=SessionResponse)
def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
    """Create session ``session_name``, or return it if it already exists.

    The call is idempotent on the name. If a session with this name is
    already registered, it is returned as-is and ``request`` is ignored.

    Raises:
        HTTPException: (400) if ``session_name`` fails ``validate_name``.
        InvalidRequestError: (400) if ``request.cluster_spec`` cannot be parsed.
    """
    validate_name(session_name, "session")

    existing = _sessions.get(session_name)
    if existing is not None:
        # NOTE(review): the request's rank / cluster_spec are not compared with
        # the existing session, so a mismatched re-creation is silently accepted.
        return SessionResponse(name=existing.name)

    # Delegate cluster spec parsing & session construction to resource layer.
    # A malformed spec is a client error. Report it as 400, as the mpprogram
    # parsing does, instead of letting it reach the catch-all 500 handler.
    try:
        spec = ClusterSpec.from_dict(request.cluster_spec)
    except Exception as e:
        raise InvalidRequestError(f"Invalid cluster_spec: {e!s}") from e

    sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec)
    _sessions[session_name] = sess
    return SessionResponse(name=sess.name)
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
@app.get("/sessions/{session_name}", response_model=SessionResponse)
def get_session(session_name: str) -> SessionResponse:
    """Look up session ``session_name`` and echo its name.

    Raises:
        ResourceNotFound: If the session is not registered (HTTP 404).
    """
    found = _sessions.get(session_name)
    if found:
        return SessionResponse(name=found.name)
    raise ResourceNotFound(f"Session '{session_name}' not found")
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
@app.delete("/sessions/{session_name}")
def delete_session(session_name: str) -> dict[str, str]:
    """Unregister session ``session_name`` from this server.

    Only the entry in the in-memory session registry is removed. This
    handler does not itself close or release anything the ``Session`` holds.

    Raises:
        ResourceNotFound: If no such session is registered (HTTP 404).
    """
    # pop() checks and removes in one step. With check-then-del, two concurrent
    # DELETEs (sync handlers run in a threadpool) could both pass the check,
    # and the second would fail with KeyError -> HTTP 500.
    if _sessions.pop(session_name, None) is None:
        raise ResourceNotFound(f"Session '{session_name}' not found")
    # Use the module logger, like the rest of this module (was the root logger).
    logger.info("Session %s deleted successfully", session_name)
    return {"message": f"Session '{session_name}' deleted successfully"}
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
# Computation endpoints
|
|
300
|
-
@app.put(
    "/sessions/{session_name}/computations/{computation_id}",
    response_model=ComputationResponse,
)
def create_and_execute_computation(
    session_name: str, computation_id: str, request: CreateComputationRequest
) -> ComputationResponse:
    """Register (if new) and execute a computation in session ``session_name``.

    ``request.mpprogram`` is a base64-encoded ``mpir_pb2.GraphProto``, turned
    into an expression with ``IrReader``. If a computation named
    ``computation_id`` already exists in the session, that existing
    computation is executed and the newly submitted program is not used.

    Raises:
        ResourceNotFound: If the session does not exist (HTTP 404).
        InvalidRequestError: If ``mpprogram`` is not valid base64/protobuf or
            cannot be turned into an expression (HTTP 400).
    """
    # Look the session up first. It is cheap, and there is no point decoding a
    # possibly large program for a session that does not exist.
    sess = _sessions.get(session_name)
    if not sess:
        raise ResourceNotFound(f"Session '{session_name}' not found.")

    graph_proto = mpir_pb2.GraphProto()
    try:
        graph_proto.ParseFromString(base64.b64decode(request.mpprogram))
    except Exception as e:
        raise InvalidRequestError(
            f"Invalid base64 or protobuf for mpprogram: {e!s}"
        ) from e

    # The graph is client-supplied, so failing to rebuild the expression is a
    # bad request (400), not an internal server error (500).
    try:
        expr = IrReader().loads(graph_proto)
    except Exception as e:
        raise InvalidRequestError(
            f"Failed to parse expression from protobuf: {e!s}"
        ) from e
    if expr is None:
        raise InvalidRequestError("Failed to parse expression from protobuf")

    # Create the computation resource
    comp = sess.get_computation(computation_id)
    if not comp:
        comp = Computation(name=computation_id, expr=expr)
        sess.add_computation(comp)
    # NOTE(review): an existing computation is reused even if the submitted
    # program differs; confirm this idempotent-PUT behavior is intended.
    sess.execute(comp, request.input_names, request.output_names)
    return ComputationResponse(name=computation_id)
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
@app.delete("/sessions/{session_name}/computations/{computation_id}")
def delete_computation(session_name: str, computation_id: str) -> dict[str, str]:
    """Delete computation ``computation_id`` from session ``session_name``.

    Raises:
        ResourceNotFound: If the session does not exist or
            ``Session.delete_computation`` returns a falsy value (HTTP 404).
    """
    sess = _sessions.get(session_name)
    if not (sess and sess.delete_computation(computation_id)):
        raise ResourceNotFound(
            f"Computation '{computation_id}' not found in session '{session_name}'"
        )
    # Use the module logger, like the rest of this module (was the root logger).
    logger.info(
        "Computation %s deleted from session %s", computation_id, session_name
    )
    return {"message": f"Computation '{computation_id}' deleted successfully"}
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
# Symbol endpoints
|
|
349
|
-
@app.put(
    "/sessions/{session_name}/symbols/{symbol_name}", response_model=SymbolResponse
)
def create_session_symbol(
    session_name: str, symbol_name: str, request: CreateSymbolRequest
) -> SymbolResponse:
    """Store a symbol in session ``session_name`` and echo it back.

    ``request.data`` is base64 of a serialized ``Value``. It is decoded to a
    Python object for storage and re-encoded for the response.

    Raises:
        ResourceNotFound: If the session does not exist (HTTP 404).
        InvalidRequestError: If ``request.data`` cannot be decoded (HTTP 400).
    """
    session = _sessions.get(session_name)
    if not session:
        raise ResourceNotFound(f"Session '{session_name}' not found.")

    try:
        payload = base64.b64decode(request.data)
        decoded = decode_value(payload)
    except Exception as err:
        raise InvalidRequestError(f"Invalid symbol data: {err!s}") from err

    new_symbol = Symbol(name=symbol_name, mptype=request.mptype, data=decoded)
    session.add_symbol(new_symbol)

    # The server keeps the Python object; the client gets the re-encoded bytes.
    encoded = encode_value(new_symbol.data)
    return SymbolResponse(
        name=new_symbol.name,
        mptype=new_symbol.mptype,
        data=base64.b64encode(encoded).decode("utf-8"),
    )
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
@app.get(
    "/sessions/{session_name}/symbols/{symbol_name}", response_model=SymbolResponse
)
def get_session_symbol(session_name: str, symbol_name: str) -> SymbolResponse:
    """Return symbol ``symbol_name`` of session ``session_name``, base64-encoded.

    Raises:
        HTTPException: (404) if the session or symbol does not exist, or if
            the symbol lookup raises ``ValueError``.
        ResourceNotFound: (404) if the symbol exists but has no data on this
            party.
    """
    logger.debug(
        "Looking for symbol: '%s' in session: '%s'", symbol_name, session_name
    )
    sess = _sessions.get(session_name)
    try:
        symbol = sess.get_symbol(symbol_name) if sess else None
    except ValueError as e:
        # Only lookup failures map to 404. Previously the try also covered
        # serialization, so an encode/response-validation ValueError was
        # misreported as "not found".
        raise HTTPException(status_code=404, detail=str(e)) from e
    if not symbol:
        raise HTTPException(status_code=404, detail=f"Symbol {symbol_name} not found")

    # symbol data is None means this party does not participate the computation
    # that produced the symbol.
    if symbol.data is None:
        raise ResourceNotFound(f"Symbol '{symbol_name}' has no data on this party")

    # Serialize using Value envelope
    return SymbolResponse(
        name=symbol.name,
        mptype=symbol.mptype,
        data=base64.b64encode(encode_value(symbol.data)).decode("utf-8"),
    )
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
@app.get("/sessions/{session_name}/symbols")
def list_session_symbols(session_name: str) -> dict[str, list[str]]:
    """Return ``{"symbols": [...]}`` with the symbol names of a session.

    Raises:
        ResourceNotFound: If the session is not registered (HTTP 404).
    """
    session = _sessions.get(session_name)
    if session:
        return {"symbols": session.list_symbols()}
    raise ResourceNotFound(f"Session '{session_name}' not found.")
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
@app.delete("/sessions/{session_name}/symbols/{symbol_name}")
def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
    """Delete symbol ``symbol_name`` from session ``session_name``.

    Raises:
        ResourceNotFound: If the session does not exist or
            ``Session.delete_symbol`` returns a falsy value (HTTP 404).
    """
    sess = _sessions.get(session_name)
    if not (sess and sess.delete_symbol(symbol_name)):
        raise ResourceNotFound(
            f"Symbol '{symbol_name}' not found in session '{session_name}'"
        )
    # Use the module logger, like the rest of this module (was the root logger).
    logger.info("Symbol %s deleted from session %s", symbol_name, session_name)
    return {"message": f"Symbol '{symbol_name}' deleted successfully"}
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
# Global Symbols endpoints
|
|
428
|
-
@app.put("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
def create_global_symbol(
    symbol_name: str, request: CreateSymbolRequest
) -> GlobalSymbolResponse:
    """Create or overwrite global symbol ``symbol_name`` and echo it back.

    Raises:
        HTTPException: (400) if ``symbol_name`` fails ``validate_name``.
        InvalidRequestError: (400) if ``request.data`` cannot be decoded.
    """
    validate_name(symbol_name, "symbol")
    try:
        value = decode_value(base64.b64decode(request.data))
    except Exception as err:
        raise InvalidRequestError(f"Invalid global symbol data: {err!s}") from err

    stored = Symbol(name=symbol_name, mptype=request.mptype, data=value)
    _global_symbols[symbol_name] = stored

    encoded = base64.b64encode(encode_value(stored.data)).decode("utf-8")
    return GlobalSymbolResponse(name=stored.name, mptype=stored.mptype, data=encoded)
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
@app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:  # route handler
    """Return global symbol ``symbol_name`` with its data base64-encoded.

    Raises:
        ResourceNotFound: If no such global symbol exists (HTTP 404).
    """
    found = _global_symbols.get(symbol_name)
    if not found:
        raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
    # Serialize using the Value envelope, then base64 for JSON transport.
    blob = encode_value(found.data)
    return GlobalSymbolResponse(
        name=found.name,
        mptype=found.mptype,
        data=base64.b64encode(blob).decode("utf-8"),
    )
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
@app.get("/api/v1/symbols")
def list_global_symbols() -> dict[str, list[str]]:
    """Return ``{"symbols": [...]}`` listing every global symbol name."""
    return {"symbols": [*_global_symbols]}
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
@app.delete("/api/v1/symbols/{symbol_name}")
def delete_global_symbol(symbol_name: str) -> dict[str, str]:  # route handler
    """Remove global symbol ``symbol_name``.

    Raises:
        ResourceNotFound: If no such global symbol exists (HTTP 404).
    """
    if symbol_name not in _global_symbols:
        raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
    del _global_symbols[symbol_name]
    return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
# Communication endpoints
|
|
474
|
-
# TODO(jint) this should be computation level, add multi computation parallel support.
|
|
475
|
-
@app.put("/sessions/{session_name}/comm/{key}/from/{from_rank}")
def comm_send(
    session_name: str, key: str, from_rank: int, request: CommSendRequest
) -> dict[str, str]:
    """Accept a message from party ``from_rank`` and deliver it to this session's communicator.

    This endpoint is served by the *receiving* party. The receiver rank is
    implicitly the rank of this server, so it is not validated here.

    The payload handed to ``communicator.onSent`` is the base64 string itself
    for normal messages. When the sender flagged SPU-channel raw bytes, it is
    ``{"data": <base64>, "is_raw_bytes": True}`` instead.

    Raises:
        HTTPException: (404) if the session is unknown or has no communicator.
    """
    session = _sessions.get(session_name)
    if not (session and session.communicator):
        logger.error("Session or communicator not found: session=%s", session_name)
        raise HTTPException(status_code=404, detail="Session or communicator not found")

    payload: str | dict[str, object] = (
        {"data": request.data, "is_raw_bytes": True}
        if request.is_raw_bytes
        else request.data
    )

    # Use the proper onSent mechanism from CommunicatorBase
    session.communicator.onSent(from_rank, key, payload)
    return {"status": "ok"}
|