mplang-nightly 0.1.dev139__py3-none-any.whl → 0.1.dev140__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mplang/backend/builtin.py CHANGED
@@ -14,7 +14,6 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- import os
18
17
  from typing import Any
19
18
 
20
19
  import numpy as np
@@ -24,6 +23,7 @@ from mplang.backend.base import cur_kctx, kernel_def
24
23
  from mplang.core.pfunc import PFunction
25
24
  from mplang.core.table import TableType
26
25
  from mplang.core.tensor import TensorType
26
+ from mplang.runtime.data_providers import get_provider, resolve_uri
27
27
  from mplang.utils import table_utils
28
28
 
29
29
 
@@ -50,16 +50,14 @@ def _read(pfunc: PFunction) -> Any:
50
50
  if path is None:
51
51
  raise ValueError("missing path attr for builtin.read")
52
52
  out_t = pfunc.outs_info[0]
53
+ uri = resolve_uri(str(path))
54
+ prov = get_provider(uri.scheme)
55
+ if prov is None:
56
+ raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
57
+ ctx = cur_kctx()
53
58
  try:
54
- if isinstance(out_t, TableType):
55
- with open(path, "rb") as f:
56
- csv_bytes = f.read()
57
- df = table_utils.csv_to_dataframe(csv_bytes)
58
- return df
59
- else:
60
- data = np.load(path)
61
- return data
62
- except Exception as e: # pragma: no cover - filesystem errors
59
+ return prov.read(uri, out_t, ctx=ctx)
60
+ except Exception as e: # pragma: no cover - provider errors
63
61
  raise RuntimeError(f"builtin.read failed: {e}") from e
64
62
 
65
63
 
@@ -68,16 +66,13 @@ def _write(pfunc: PFunction, obj: Any) -> Any:
68
66
  path = pfunc.attrs.get("path")
69
67
  if path is None:
70
68
  raise ValueError("missing path attr for builtin.write")
69
+ uri = resolve_uri(str(path))
70
+ prov = get_provider(uri.scheme)
71
+ if prov is None:
72
+ raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
73
+ ctx = cur_kctx()
71
74
  try:
72
- dir_name = os.path.dirname(path)
73
- if dir_name:
74
- os.makedirs(dir_name, exist_ok=True)
75
- if hasattr(obj, "__dataframe__") or isinstance(obj, pd.DataFrame):
76
- csv_bytes = table_utils.dataframe_to_csv(obj) # type: ignore
77
- with open(path, "wb") as f:
78
- f.write(csv_bytes)
79
- else:
80
- np.save(path, _to_numpy(obj))
75
+ prov.write(uri, obj, ctx=ctx)
81
76
  return obj
82
77
  except Exception as e: # pragma: no cover
83
78
  raise RuntimeError(f"builtin.write failed: {e}") from e
@@ -28,22 +28,32 @@ _BUILTIN_MOD = stateless_mod("builtin")
28
28
 
29
29
  @_BUILTIN_MOD.simple_op()
30
30
  def identity(x: TensorType) -> TensorType:
31
- """Identity on type: captures the underlying MPObject (if any) but kernel sees only the type.
31
+ """Return the input type unchanged.
32
32
 
33
- Under strict typed_op semantics positional MPObject is converted to its TensorType before entering.
33
+ Args:
34
+ x: The input tensor type. If called with an MPObject, the value is
35
+ captured positionally; the kernel sees only the type.
36
+
37
+ Returns:
38
+ The same type as ``x``.
34
39
  """
35
40
  return x
36
41
 
37
42
 
38
43
  @_BUILTIN_MOD.simple_op()
39
44
  def read(*, path: str, ty: TensorType) -> TensorType:
40
- """Type-only kernel for reading a tensor/table from a path.
45
+ """Declare reading a value of type ``ty`` from ``path`` (type-only).
46
+
47
+ Args:
48
+ path: Non-empty path or URI to read from (stored as an attribute).
49
+ ty: The expected output type/schema.
41
50
 
42
- Attributes:
43
- - path: str destination to read from (carried as PFunction attr)
44
- - ty: TensorType/TableType describing the expected output type
51
+ Returns:
52
+ Exactly ``ty``.
45
53
 
46
- Returns: ty (shape/dtype/schema), no inputs captured.
54
+ Raises:
55
+ ValueError: If ``path`` is empty.
56
+ TypeError: If ``ty`` is not a TensorType or TableType.
47
57
  """
48
58
  if not isinstance(path, str) or path == "":
49
59
  raise ValueError("path must be a non-empty string")
@@ -55,9 +65,14 @@ def read(*, path: str, ty: TensorType) -> TensorType:
55
65
 
56
66
  @_BUILTIN_MOD.simple_op()
57
67
  def write(x: TensorType, *, path: str) -> TensorType:
58
- """Write op: returns same type it consumes; runtime handles side effect.
68
+ """Declare writing the input value to ``path`` and return the same type.
69
+
70
+ Args:
71
+ x: The value's type to be written; values are captured positionally.
72
+ path: Destination path or URI (attribute).
59
73
 
60
- Positional MPObject (tensor/table) will be captured; kernel sees only its type.
74
+ Returns:
75
+ The same type as ``x``.
61
76
  """
62
77
  return x
63
78
 
@@ -66,6 +81,20 @@ def write(x: TensorType, *, path: str) -> TensorType:
66
81
  def constant(
67
82
  data: TensorLike | ScalarType | TableLike,
68
83
  ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
84
+ """Embed a literal tensor/table and return the full triad.
85
+
86
+ Args:
87
+ data: Constant payload. Supports scalars, array-like tensors, or
88
+ table-like dataframes.
89
+
90
+ Returns:
91
+ Tuple[PFunction, list[MPObject], PyTreeDef]:
92
+ - PFunction: ``fn_type='builtin.constant'`` with one output whose type
93
+ matches ``data``; payload serialized via ``data_bytes`` with
94
+ ``data_format`` ('bytes[numpy]' or 'bytes[csv]').
95
+ - list[MPObject]: Empty (no inputs captured).
96
+ - PyTreeDef: Output tree (single leaf).
97
+ """
69
98
  import numpy as np
70
99
 
71
100
  data_bytes: bytes
@@ -103,19 +132,23 @@ def constant(
103
132
 
104
133
  @_BUILTIN_MOD.simple_op()
105
134
  def rank() -> TensorType:
106
- """Type-only kernel: returns the UINT64 scalar tensor type for current rank.
135
+ """Return the scalar UINT64 tensor type for the current party rank.
107
136
 
108
- Runtime provides the concrete rank value per party during execution; here we
109
- only declare the output type with no inputs captured and no attributes.
137
+ Returns:
138
+ A scalar ``UINT64`` tensor type (shape ``()``).
110
139
  """
111
140
  return TensorType(UINT64, ())
112
141
 
113
142
 
114
143
  @_BUILTIN_MOD.simple_op()
115
144
  def prand(*, shape: Shape = ()) -> TensorType:
116
- """Type-only kernel: private random UINT64 tensor of given shape.
145
+ """Declare a private random UINT64 tensor with the given shape.
146
+
147
+ Args:
148
+ shape: Output tensor shape. Defaults to ``()``.
117
149
 
118
- Shape is attached as a PFunction attribute via typed_op; no inputs.
150
+ Returns:
151
+ A ``UINT64`` tensor type with the specified shape.
119
152
  """
120
153
  return TensorType(UINT64, shape)
121
154
 
@@ -124,21 +157,30 @@ def prand(*, shape: Shape = ()) -> TensorType:
124
157
  def debug_print(
125
158
  x: TensorType | TableType, *, prefix: str = ""
126
159
  ) -> TableType | TensorType:
127
- """Debug-print pass-through: type identity with side-effect attribute.
160
+ """Print a value at runtime and return the same type.
128
161
 
129
- Accepts tensor/table type; MPObject positional (if provided) is captured automatically.
162
+ Args:
163
+ x: The value to print (captured positionally; kernel sees only type).
164
+ prefix: Optional text prefix for the printed output.
165
+
166
+ Returns:
167
+ The same type as ``x``.
130
168
  """
131
169
  return x
132
170
 
133
171
 
134
172
  @_BUILTIN_MOD.simple_op()
135
173
  def pack(x: TensorType | TableType) -> TensorType:
136
- """Type-only pack operator: models serialization into a byte vector.
174
+ """Serialize a tensor/table into a byte vector (type-only).
175
+
176
+ Args:
177
+ x: Input type to pack.
137
178
 
138
- The frontend only declares the type transformation; the runtime decides
139
- whether any actual serialization takes place. The result is always a
140
- one-dimensional UINT8 tensor with unknown length (-1 means runtime
141
- determined).
179
+ Returns:
180
+ A ``UINT8`` tensor type with shape ``(-1,)`` (length decided at runtime).
181
+
182
+ Raises:
183
+ TypeError: If ``x`` is not a TensorType or TableType.
142
184
  """
143
185
 
144
186
  if not isinstance(x, (TensorType, TableType)):
@@ -149,10 +191,19 @@ def pack(x: TensorType | TableType) -> TensorType:
149
191
 
150
192
  @_BUILTIN_MOD.simple_op()
151
193
  def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | TableType:
152
- """Type-only unpack operator: inverse of `pack`.
194
+ """Deserialize a byte vector into the explicit output type.
195
+
196
+ Args:
197
+ b: Byte vector type. Must be ``UINT8`` with shape ``(N,)`` (``N`` may be
198
+ ``-1``).
199
+ out_ty: Resulting type/schema after unpacking.
153
200
 
154
- Requires a one-dimensional UINT8 tensor input (length can be -1) and
155
- returns the explicit `out_ty` type description.
201
+ Returns:
202
+ Exactly ``out_ty``.
203
+
204
+ Raises:
205
+ TypeError: If ``out_ty`` is not a TensorType/TableType, or if ``b`` is
206
+ not a 1-D UINT8 tensor.
156
207
  """
157
208
 
158
209
  if not isinstance(out_ty, (TensorType, TableType)):
@@ -166,6 +217,21 @@ def unpack(b: TensorType, *, out_ty: TensorType | TableType) -> TensorType | Tab
166
217
 
167
218
  @_BUILTIN_MOD.simple_op()
168
219
  def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
220
+ """Convert a homogeneous-typed table to a dense 2D tensor.
221
+
222
+ Args:
223
+ table: Input table whose columns all share the same dtype.
224
+ number_rows: Number of rows in the resulting tensor. Must be ``>= 0``.
225
+
226
+ Returns:
227
+ A rank-2 tensor with dtype equal to the table column dtype and shape
228
+ ``(number_rows, table.num_columns())``.
229
+
230
+ Raises:
231
+ ValueError: If the table is empty or ``number_rows < 0``.
232
+ TypeError: If the table has heterogeneous column dtypes or ``number_rows``
233
+ is not an int.
234
+ """
169
235
  if table.num_columns() == 0:
170
236
  raise ValueError("Cannot pack empty table")
171
237
  col_dtypes = list(table.column_types())
@@ -184,6 +250,21 @@ def table_to_tensor(table: TableType, *, number_rows: int) -> TensorType:
184
250
 
185
251
  @_BUILTIN_MOD.simple_op()
186
252
  def tensor_to_table(tensor: TensorType, *, column_names: list[str]) -> TableType:
253
+ """Convert a rank-2 tensor into a table with named columns.
254
+
255
+ Args:
256
+ tensor: Rank-2 tensor with shape ``(N, F)``.
257
+ column_names: List of unique, non-whitespace column names of length ``F``.
258
+
259
+ Returns:
260
+ A table with ``F`` columns named as provided, each with dtype
261
+ ``tensor.dtype``.
262
+
263
+ Raises:
264
+ TypeError: If ``tensor`` is not rank-2, or if any column name is not a
265
+ string.
266
+ ValueError: If names are empty/whitespace, duplicated, or length != ``F``.
267
+ """
187
268
  if len(tensor.shape) != 2:
188
269
  raise TypeError("tensor_to_table expects a rank-2 tensor (N,F)")
189
270
  n_cols = tensor.shape[1]
mplang/runtime/client.py CHANGED
@@ -403,3 +403,57 @@ class HttpExecutorClient:
403
403
  raise self._raise_http_error(
404
404
  f"list computations for session {session_name}", e
405
405
  ) from e
406
+
407
+ # ---------------- Global Symbols (process-level) ----------------
408
+ async def create_global_symbol(
409
+ self, symbol_name: str, data: Any, mptype: dict | None = None
410
+ ) -> None:
411
+ """Create or replace a process-global symbol.
412
+
413
+ Args:
414
+ symbol_name: Identifier
415
+ data: Python object to store (pickle based)
416
+ mptype: Optional metadata dict
417
+ """
418
+ url = f"/api/v1/symbols/{symbol_name}"
419
+ try:
420
+ payload = {
421
+ "data": base64.b64encode(pickle.dumps(data)).decode("utf-8"),
422
+ "mptype": mptype or {},
423
+ }
424
+ resp = await self._client.put(url, json=payload)
425
+ resp.raise_for_status()
426
+ except (httpx.HTTPStatusError, httpx.RequestError) as e:
427
+ raise self._raise_http_error(
428
+ f"create global symbol {symbol_name}", e
429
+ ) from e
430
+
431
async def get_global_symbol(self, symbol_name: str) -> Any:
    """Fetch a process-global symbol from the server and unpickle its value.

    Method of ``HttpExecutorClient``. The symbol name is percent-encoded in
    the request path so reserved characters cannot address another symbol.

    WARNING: the returned value is produced by ``pickle.loads`` on bytes sent
    by the server; only use this against a trusted server.

    Args:
        symbol_name: Identifier of the symbol to fetch.

    Returns:
        The deserialized Python object stored under ``symbol_name``.

    Raises:
        The exception built by ``self._raise_http_error`` for HTTP status
        errors and transport errors.
    """
    from urllib.parse import quote

    url = f"/api/v1/symbols/{quote(symbol_name, safe='')}"
    try:
        resp = await self._client.get(url)
        resp.raise_for_status()
        payload = resp.json()
    except (httpx.HTTPStatusError, httpx.RequestError) as e:
        raise self._raise_http_error(f"get global symbol {symbol_name}", e) from e
    data_bytes = base64.b64decode(payload["data"])
    return pickle.loads(data_bytes)
441
+
442
async def delete_global_symbol(self, symbol_name: str) -> None:
    """Delete a process-global symbol on the server.

    Method of ``HttpExecutorClient``. The symbol name is percent-encoded in
    the request path. Without this, a name like ``"foo?x"`` would be sent as
    path ``foo`` plus a query string and delete the wrong symbol.

    Args:
        symbol_name: Identifier of the symbol to delete.

    Raises:
        The exception built by ``self._raise_http_error`` for HTTP status
        errors and transport errors.
    """
    from urllib.parse import quote

    url = f"/api/v1/symbols/{quote(symbol_name, safe='')}"
    try:
        resp = await self._client.delete(url)
        resp.raise_for_status()
    except (httpx.HTTPStatusError, httpx.RequestError) as e:
        raise self._raise_http_error(
            f"delete global symbol {symbol_name}", e
        ) from e
451
+
452
async def list_global_symbols(self) -> list[str]:
    """Return the names of all process-global symbols held by the server.

    Method of ``HttpExecutorClient``. A response without a ``symbols`` field
    yields an empty list.

    Raises:
        The exception built by ``self._raise_http_error`` for HTTP status
        errors and transport errors.
    """
    try:
        response = await self._client.get("/api/v1/symbols")
        response.raise_for_status()
        body = response.json()
    except (httpx.HTTPStatusError, httpx.RequestError) as e:
        raise self._raise_http_error("list global symbols", e) from e
    return list(body.get("symbols", []))
@@ -0,0 +1,248 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any
19
+ from urllib.parse import ParseResult, urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+
24
+ from mplang.backend.base import KernelContext
25
+ from mplang.core.table import TableType
26
+ from mplang.core.tensor import TensorType
27
+ from mplang.utils import table_utils
28
+
29
+
30
@dataclass(frozen=True)
class ResolvedURI:
    """Result of resolving a resource path into a normalized form.

    Attributes:
        scheme: Lower-cased URI scheme (e.g. 'file', 's3', 'mem', 'secret');
            'file' for plain filesystem paths.
        raw: The original path string exactly as provided by the user.
        parsed: The ParseResult if a scheme was present; otherwise None.
        local_path: For the 'file' scheme: the filesystem path (percent-decoded
            when it came from a ``file:`` URI); otherwise None.
    """

    scheme: str
    raw: str
    parsed: ParseResult | None
    local_path: str | None


def resolve_uri(path: str) -> ResolvedURI:
    """Resolve a user-provided resource location into a normalized URI form.

    This helper accepts plain filesystem paths and RFC 3986 style URIs. A path
    is treated as ``file`` when ``urlparse(path).scheme`` is empty or is a
    single letter. A single letter is a Windows drive such as ``C:``, which
    ``urlparse`` would otherwise report as scheme ``c``. Detection does not
    depend on the literal substring ``"://"``, so forms like ``mem:foo`` (no
    slashes) are still recognized as a URI.

    Captured fields
    - ``scheme``: Lower-cased scheme (``file`` when absent)
    - ``raw``: Original input
    - ``parsed``: ``ParseResult`` when a scheme was provided, else ``None``
    - ``local_path``: Filesystem path for ``file`` scheme, else ``None``.
      For ``file:`` URIs the path is percent-decoded. A host other than
      ``localhost`` is kept as a ``//host`` prefix.

    Supported (pluggable) schemes out-of-the-box:
    * ``file`` (default)
    * ``mem``
    * ``s3`` (stub)
    * ``secret`` (stub)
    * ``symbols`` (registered server-side)

    Examples
    >>> resolve_uri("data/train.npy").scheme
    'file'
    >>> resolve_uri("C:/data/train.npy").scheme
    'file'
    >>> resolve_uri("mem:dataset1").scheme
    'mem'
    >>> resolve_uri("mem://dataset1").scheme  # both forms acceptable
    'mem'
    >>> resolve_uri("symbols://shared_model").scheme
    'symbols'
    >>> resolve_uri("file:///tmp/x.npy").local_path
    '/tmp/x.npy'
    >>> resolve_uri("file:///tmp/my%20file.npy").local_path
    '/tmp/my file.npy'
    """
    from urllib.parse import unquote

    pr = urlparse(path)
    # No scheme, or a one-letter "scheme" that is really a Windows drive letter.
    if not pr.scheme or (len(pr.scheme) == 1 and pr.scheme.isalpha()):
        return ResolvedURI("file", path, None, path)

    scheme = pr.scheme.lower()
    local_path: str | None = None
    if scheme == "file":
        # file URIs percent-encode reserved characters (e.g. space as %20).
        local_path = unquote(pr.path)
        host = pr.netloc
        if host and host.lower() != "localhost":
            # Keep a non-local authority instead of silently dropping it.
            local_path = f"//{host}{local_path}"
    return ResolvedURI(scheme, path, pr, local_path)
92
+
93
+
94
class DataProvider:
    """Base class for scheme-specific resource providers.

    A provider reads or writes a value addressed by a ``ResolvedURI``. The
    base implementations of both operations raise ``NotImplementedError``,
    so subclasses override whichever operations they support. ``out_spec``
    describes the expected type; providers may ignore it but SHOULD validate
    when feasible.
    """

    def read(
        self,
        uri: ResolvedURI,
        out_spec: TensorType | TableType,
        *,
        ctx: KernelContext,
    ) -> Any:
        """Load the resource at ``uri``; subclasses must override."""
        raise NotImplementedError

    def write(
        self,
        uri: ResolvedURI,
        value: Any,
        *,
        ctx: KernelContext,
    ) -> None:
        """Store ``value`` at ``uri``; subclasses must override."""
        raise NotImplementedError
108
+
109
+
110
# Lower-cased scheme -> provider instance. Filled by register_provider().
_REGISTRY: dict[str, DataProvider] = {}


def register_provider(
    scheme: str, provider: DataProvider, *, replace: bool = False, quiet: bool = False
) -> None:
    """Register ``provider`` as the handler for URIs with ``scheme``.

    Args:
        scheme: URI scheme handled; matched case-insensitively. Must be
            non-empty.
        provider: Implementation to dispatch to.
        replace: Allow overwriting an existing registration. When False and
            the scheme is already registered, ``ValueError`` is raised.
        quiet: If True, suppress the info log emitted when an existing
            registration is replaced.

    Raises:
        ValueError: If ``scheme`` is empty, or is already registered and
            ``replace`` is False.
    """
    import logging

    key = scheme.lower()
    if not key:
        # resolve_uri never produces an empty scheme, so such an entry could
        # never be looked up.
        raise ValueError("provider scheme must be a non-empty string")
    if key in _REGISTRY:
        if not replace:
            raise ValueError(f"provider already registered for scheme: {scheme}")
        if not quiet:
            # Module logger with lazy %-args instead of an eager f-string on
            # the root logger.
            logging.getLogger(__name__).info(
                "Replacing existing provider for scheme '%s'", scheme
            )
    _REGISTRY[key] = provider


def get_provider(scheme: str) -> DataProvider | None:
    """Return the provider registered for ``scheme`` (case-insensitive), or None."""
    return _REGISTRY.get(scheme.lower())
136
+
137
+
138
+ # ---------------- Default Providers ----------------
139
+
140
+
141
class FileProvider(DataProvider):
    """Provider backed by the local filesystem.

    Tables are stored as CSV bytes via ``table_utils``. Everything else is
    stored with ``np.save`` and read back with ``np.load``.
    """

    @staticmethod
    def _target(uri: ResolvedURI) -> str:
        # Prefer the resolved filesystem path; fall back to the raw string.
        return uri.local_path or uri.raw

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Read a table (CSV) or tensor (NumPy file); ``ctx`` is unused."""
        target = self._target(uri)
        if not isinstance(out_spec, TableType):
            return np.load(target)
        with open(target, "rb") as fh:
            payload = fh.read()
        return table_utils.csv_to_dataframe(payload)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Write ``value`` to disk, creating parent directories first.

        DataFrame-like values (pandas, or anything exposing ``__dataframe__``)
        are written as CSV. Everything else is passed through ``np.asarray``
        and ``np.save``.

        NOTE(review): per the NumPy docs, ``np.save`` appends ``.npy`` when
        the filename lacks it. A later read of the same un-suffixed path may
        therefore not find the file; confirm callers use ``.npy`` paths.
        """
        import os

        target = self._target(uri)
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        is_table = hasattr(value, "__dataframe__") or isinstance(value, pd.DataFrame)
        if not is_table:
            np.save(target, np.asarray(value))
            return
        encoded = table_utils.dataframe_to_csv(value)  # type: ignore
        with open(target, "wb") as fh:
            fh.write(encoded)
174
+
175
+
176
class _KeyedPocket:
    """Per-namespace dict kept inside ``KernelContext.state``.

    All pockets share the ``"resource.providers"`` entry of ``ctx.state``.
    Each instance owns the sub-dict stored under its own namespace key.
    """

    def __init__(self, ns: str):
        self.ns = ns

    def get_map(self, ctx: KernelContext) -> dict[str, Any]:
        """Return this namespace's dict, creating it on first access."""
        providers = ctx.state.setdefault("resource.providers", {})
        existing = providers.get(self.ns)
        if existing is not None:
            return existing  # type: ignore[return-value]
        fresh: dict[str, Any] = {}
        providers[self.ns] = fresh
        return fresh
189
+
190
+
191
class MemProvider(DataProvider):
    """In-memory key/value provider stored in the kernel context's state.

    Entries live in ``ctx.state`` under the ``"mem"`` pocket, so they are only
    visible to code sharing that state. The original intent was "per rank,
    per session/runtime" — confirm against how runtimes are created.

    The storage key comes from the URI's authority and path. As a result
    ``mem:name``, ``mem://name`` and ``mem:///name`` address the same entry,
    matching ``resolve_uri``, which accepts both ``mem:`` spellings.
    """

    def __init__(self) -> None:
        self._pocket = _KeyedPocket("mem")

    @staticmethod
    def _key(uri: ResolvedURI) -> str:
        """Map ``uri`` to its storage key."""
        parsed = uri.parsed
        if parsed is None or parsed.params or parsed.query or parsed.fragment:
            # Extra URI components: keep the exact raw string so such URIs
            # stay distinct from each other.
            return uri.raw
        name = f"{parsed.netloc}{parsed.path}".lstrip("/")
        # Degenerate URIs such as "mem:" have no name; use the raw string.
        return name or uri.raw

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        """Return the stored value (``out_spec`` is not checked).

        Raises:
            FileNotFoundError: If nothing was written under this URI's key.
        """
        store = self._pocket.get_map(ctx)
        key = self._key(uri)
        if key not in store:
            raise FileNotFoundError(f"mem resource not found: {uri.raw}")
        return store[key]

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        """Store ``value`` (by reference) under this URI's key."""
        store = self._pocket.get_map(ctx)
        store[self._key(uri)] = value
209
+
210
+
211
class S3Provider(DataProvider):
    """Stub for the ``s3`` scheme; every operation raises NotImplementedError.

    Registering a real implementation via ``register_provider('s3', ...)``
    replaces this stub.
    """

    _MESSAGE = (
        "S3 provider not installed. Provide an external plugin via "
        "register_provider('s3', ...) ."
    )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        raise NotImplementedError(self._MESSAGE)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        raise NotImplementedError(self._MESSAGE)
225
+
226
+
227
class SecretProvider(DataProvider):
    """Stub for the ``secret`` scheme; every operation raises NotImplementedError.

    A KMS/secret-manager integration registered through
    ``register_provider('secret', ...)`` replaces this stub.
    """

    _MESSAGE = (
        "secret provider not installed. Provide an external plugin via "
        "register_provider('secret', ...) ."
    )

    def read(
        self, uri: ResolvedURI, out_spec: TensorType | TableType, *, ctx: KernelContext
    ) -> Any:
        raise NotImplementedError(self._MESSAGE)

    def write(self, uri: ResolvedURI, value: Any, *, ctx: KernelContext) -> None:
        raise NotImplementedError(self._MESSAGE)
241
+
242
+
243
# Register the default providers at import time. "s3" and "secret" are stubs
# that raise NotImplementedError so a missing plugin is reported explicitly;
# a plugin can override them with register_provider(..., replace=True).
for _scheme, _provider in (
    ("file", FileProvider()),
    ("mem", MemProvider()),
    ("s3", S3Provider()),
    ("secret", SecretProvider()),
):
    register_provider(_scheme, _provider)
del _scheme, _provider
@@ -36,7 +36,7 @@ from mplang.backend import ( # noqa: F401
36
36
  stablehlo,
37
37
  tee,
38
38
  )
39
- from mplang.backend.base import create_runtime
39
+ from mplang.backend.base import BackendRuntime, create_runtime
40
40
  from mplang.backend.spu import PFunction # type: ignore
41
41
  from mplang.core.expr.ast import Expr
42
42
  from mplang.core.expr.evaluator import IEvaluator, create_evaluator
@@ -98,6 +98,9 @@ class Session:
98
98
  # Global session storage
99
99
  _sessions: dict[str, Session] = {}
100
100
 
101
# Process-wide table of global symbols, keyed by symbol name. It is separate
# from each session's own ``symbols`` table, and execute_computation exposes
# it to backend runtimes under state["resource.providers"]["symbols"].
_global_symbols: dict[str, Symbol] = dict()
103
+
101
104
 
102
105
  # Session Management
103
106
  def create_session(
@@ -223,6 +226,14 @@ def execute_computation(
223
226
  # Build evaluator
224
227
  # Explicit per-rank backend runtime (deglobalized)
225
228
  runtime = create_runtime(rank, session.communicator.world_size)
229
+ # Inject global symbol storage into backend runtime state so that
230
+ # symbols:// provider can access it during builtin.read/write.
231
+ if isinstance(runtime, BackendRuntime):
232
+ pocket = runtime.state.setdefault("resource.providers", {})
233
+ if "symbols" not in pocket:
234
+ pocket["symbols"] = _global_symbols
235
+ else:
236
+ raise RuntimeError("resource.providers.symbols already exists")
226
237
  evaluator: IEvaluator = create_evaluator(
227
238
  rank=rank, env=bindings, comm=session.communicator, runtime=runtime
228
239
  )
@@ -274,6 +285,45 @@ def execute_computation(
274
285
  session.symbols[name] = Symbol(name=name, mptype={}, data=val)
275
286
 
276
287
 
288
+ # Global symbol CRUD (service-level)
289
+ def create_global_symbol(name: str, mptype: Any, data: str) -> Symbol:
290
+ """Create or update a global symbol in the service-level table.
291
+
292
+ WARNING: Uses Python pickle for arbitrary object deserialization. Deploy
293
+ only in trusted environments. Future work may replace this with a
294
+ restricted / structured serialization.
295
+
296
+ The `data` argument is a base64-encoded pickled Python object. Minimal
297
+ validation of `mptype` is performed for tensor metadata (shape/dtype)
298
+ when present to catch obvious mismatches.
299
+ """
300
+ try:
301
+ data_bytes = base64.b64decode(data)
302
+ obj = pickle.loads(data_bytes)
303
+ except Exception as e:
304
+ raise InvalidRequestError(f"Invalid global symbol data encoding: {e!s}") from e
305
+
306
+ sym = Symbol(name, mptype, obj)
307
+ _global_symbols[name] = sym
308
+ return sym
309
+
310
+
311
def get_global_symbol(name: str) -> Symbol | None:
    """Look up a global symbol by name; return None when it does not exist."""
    try:
        return _global_symbols[name]
    except KeyError:
        return None
313
+
314
+
315
def list_global_symbols() -> list[str]:
    """Return the names of all global symbols in insertion order."""
    return list(_global_symbols)
317
+
318
+
319
def delete_global_symbol(name: str) -> bool:
    """Remove a global symbol; return True if it existed, False otherwise."""
    try:
        del _global_symbols[name]
    except KeyError:
        return False
    logging.info(f"Global symbol {name} deleted")
    return True
325
+
326
+
277
327
  # Symbol Management
278
328
  def create_symbol(session_name: str, name: str, mptype: Any, data: Any) -> Symbol:
279
329
  """Create a symbol in a session's symbol table.
mplang/runtime/server.py CHANGED
@@ -20,15 +20,20 @@ It uses FastAPI to provide a RESTful API for managing computations.
20
20
  import base64
21
21
  import logging
22
22
  import re
23
+ from typing import Any
23
24
 
24
25
  import cloudpickle as pickle
25
26
  from fastapi import FastAPI, HTTPException, Request
26
27
  from fastapi.responses import JSONResponse
27
28
  from pydantic import BaseModel
28
29
 
30
+ from mplang.backend.base import KernelContext
29
31
  from mplang.core.mpir import Reader
32
+ from mplang.core.table import TableType
33
+ from mplang.core.tensor import TensorType
30
34
  from mplang.protos.v1alpha1 import mpir_pb2
31
35
  from mplang.runtime import resource
36
+ from mplang.runtime.data_providers import DataProvider, ResolvedURI, register_provider
32
37
  from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
33
38
 
34
39
  logger = logging.getLogger(__name__)
@@ -36,6 +41,75 @@ logger = logging.getLogger(__name__)
36
41
  app = FastAPI()
37
42
 
38
43
 
44
class _SymbolsProvider(DataProvider):
    """``symbols://`` provider backed by the server's global symbol table.

    Reads and writes go through ``resource.get_global_symbol`` and
    ``resource.create_global_symbol``. The ``ctx`` argument is not used.
    """

    @staticmethod
    def _symbol_name(uri: ResolvedURI) -> str:
        """Extract the name from ``symbols://name`` or ``symbols:///name``.

        Raises:
            InvalidRequestError: For a non-``symbols`` URI, one carrying
                params/query/fragment, one with sub-paths, or one without a
                name.
        """
        parsed = uri.parsed
        if parsed is None or uri.scheme != "symbols":
            raise InvalidRequestError(
                "symbols provider expects URI in the form symbols://{name}"
            )
        if any((parsed.query, parsed.params, parsed.fragment)):
            raise InvalidRequestError(
                "symbols:// URI must not contain query or fragment"
            )

        if not parsed.netloc:
            # symbols:///foo: empty authority, single path segment is the name.
            candidate = parsed.path.lstrip("/")
            if not candidate or "/" in candidate:
                raise InvalidRequestError(
                    "symbols:// URI must specify a single symbol name"
                )
        elif parsed.path in ("", "/"):
            # symbols://foo: the name travels in the authority component.
            candidate = parsed.netloc
        else:
            raise InvalidRequestError("symbols:// URIs cannot include subpaths")

        if not candidate:
            raise InvalidRequestError("symbols:// URI missing symbol name")
        return candidate

    def read(
        self,
        uri: ResolvedURI,
        out_spec: TensorType | TableType,
        *,
        ctx: KernelContext,
    ) -> Any:  # type: ignore[override]
        """Return the stored value of the named symbol (``out_spec`` unchecked).

        Raises:
            ResourceNotFound: If no global symbol has that name.
        """
        symbol_name = self._symbol_name(uri)
        found = resource.get_global_symbol(symbol_name)
        if found is not None:
            return found.data
        raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")

    def write(
        self,
        uri: ResolvedURI,
        value: Any,
        *,
        ctx: KernelContext,
    ) -> None:  # type: ignore[override]
        """Store ``value`` as a global symbol with an empty ``mptype``.

        The value is pickled and base64-encoded because
        ``resource.create_global_symbol`` takes that encoded form and
        unpickles it again, so the stored object is a copy of ``value``.
        """
        symbol_name = self._symbol_name(uri)
        try:
            encoded = base64.b64encode(pickle.dumps(value)).decode("utf-8")
        except Exception as e:  # pragma: no cover - defensive
            raise InvalidRequestError(
                f"Failed to encode value for symbols:// write: {e!s}"
            ) from e

        resource.create_global_symbol(symbol_name, {}, encoded)
107
+
108
+
109
# Make symbols:// URIs resolvable by builtin read/write in computations run by
# this server process.
_SYMBOLS_PROVIDER = _SymbolsProvider()
register_provider("symbols", _SYMBOLS_PROVIDER)
111
+
112
+
39
113
  @app.exception_handler(ResourceNotFound)
40
114
  def resource_not_found_handler(request: Request, exc: ResourceNotFound) -> JSONResponse:
41
115
  """Handler for ResourceNotFound exceptions."""
@@ -139,6 +213,12 @@ class ComputationListResponse(BaseModel):
139
213
  computations: list[str]
140
214
 
141
215
 
216
class GlobalSymbolResponse(BaseModel):
    """Response body of the global-symbol PUT and GET endpoints."""

    # Symbol name as stored in the global table.
    name: str
    # Metadata dict stored with the symbol.
    mptype: dict
    # Base64-encoded pickle of the symbol's value.
    data: str
220
+
221
+
142
222
  @app.get("/health")
143
223
  async def health_check() -> dict[str, str]:
144
224
  """Health check endpoint."""
@@ -303,6 +383,39 @@ def delete_symbol(session_name: str, symbol_name: str) -> dict[str, str]:
303
383
  )
304
384
 
305
385
 
386
# Global Symbols endpoints
@app.put("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
def create_global_symbol(
    symbol_name: str, request: CreateSymbolRequest
) -> GlobalSymbolResponse:
    """Create or replace a global symbol from a base64-encoded pickle.

    The name is checked with ``validate_name`` first. The request's ``data``
    string is echoed back unchanged in the response.
    """
    validate_name(symbol_name, "symbol")
    stored = resource.create_global_symbol(symbol_name, request.mptype, request.data)
    return GlobalSymbolResponse(
        name=stored.name,
        mptype=stored.mptype,
        data=request.data,
    )
394
+
395
+
396
@app.get("/api/v1/symbols/{symbol_name}", response_model=GlobalSymbolResponse)
def get_global_symbol(symbol_name: str) -> GlobalSymbolResponse:
    """Return a global symbol, re-encoding its value as a base64 pickle.

    Raises:
        ResourceNotFound: If no global symbol has this name.
    """
    sym = resource.get_global_symbol(symbol_name)
    if not sym:
        raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
    encoded = base64.b64encode(pickle.dumps(sym.data)).decode("utf-8")
    return GlobalSymbolResponse(name=sym.name, mptype=sym.mptype, data=encoded)
404
+
405
+
406
@app.get("/api/v1/symbols")
def list_global_symbols() -> dict[str, list[str]]:
    """List the names of all global symbols."""
    names = resource.list_global_symbols()
    return {"symbols": names}
409
+
410
+
411
@app.delete("/api/v1/symbols/{symbol_name}")
def delete_global_symbol(symbol_name: str) -> dict[str, str]:
    """Delete a global symbol.

    Raises:
        ResourceNotFound: If no global symbol has this name.
    """
    if not resource.delete_global_symbol(symbol_name):
        raise ResourceNotFound(f"Global symbol '{symbol_name}' not found")
    return {"message": f"Global symbol '{symbol_name}' deleted successfully"}
417
+
418
+
306
419
  # Communication endpoints
307
420
  # TODO(jint) this should be computation level, add multi computation parallel support.
308
421
  @app.put("/sessions/{session_name}/comm/{key}/from/{from_rank}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mplang-nightly
3
- Version: 0.1.dev139
3
+ Version: 0.1.dev140
4
4
  Summary: Multi-Party Programming Language
5
5
  Author-email: SecretFlow Team <secretflow-contact@service.alipay.com>
6
6
  License: Apache License
@@ -5,7 +5,7 @@ mplang/analysis/__init__.py,sha256=CTHFvRsi-nFngojqjn08UaR3RY9i7CJ7T2UdR95kCrk,1
5
5
  mplang/analysis/diagram.py,sha256=ffwgD12gL1_KH1uJ_EYkjmIlDrfxYJJkWj-wHl09_Xk,19520
6
6
  mplang/backend/__init__.py,sha256=Pn1MGW7FZ8ZQWcx_r2Io4Q1JbrMINCpefQt7yCNq_dc,741
7
7
  mplang/backend/base.py,sha256=zaofB1MPC9Af8FS-rx7q6utFxiu9Mppr2gWaKFZ70Os,9863
8
- mplang/backend/builtin.py,sha256=usQjxzYPP8bYqazJveS91SXr0i1i4oe0N4b7aPABKuA,7181
8
+ mplang/backend/builtin.py,sha256=Mk1uUO2Vpw3meqZ0B7B0hG-wndea6cmFv2Uk1vM_uTg,7052
9
9
  mplang/backend/crypto.py,sha256=H_s5HI7lUP7g0xz-a9qMbSn6dhJStUilKbn3-7SIh0I,3812
10
10
  mplang/backend/phe.py,sha256=Y07fkTHTSKHgpMQ1soA3d2eBMlqr25uHCAtZOd5D5_I,10555
11
11
  mplang/backend/spu.py,sha256=i6Qqgeg-Anwpb5xX5Uz8GdmTWNkRy_pjp-xptIvlxl4,9273
@@ -37,7 +37,7 @@ mplang/core/expr/visitor.py,sha256=2Ge-I5N-wH8VVXy8d2WyNaEv8x6seiRx9peyH9S2BYU,2
37
37
  mplang/core/expr/walk.py,sha256=lXkGJEEuvKGDqQihbxXPxfz2RfR1Q1zYUlt11iooQW0,11889
38
38
  mplang/frontend/__init__.py,sha256=3ZBFX_acM96tZ2mtJaxJm150n1cf0LnnCRmkrAc4uBw,1463
39
39
  mplang/frontend/base.py,sha256=I-Hhh5o6GVqBA1YySl9Nk3zkbMoVrQqLMRyX2JypsPI,18268
40
- mplang/frontend/builtin.py,sha256=M3zquGYmZP0SgL6v34LiPibD1PiwFSn-zzo2Q1eeq4s,7259
40
+ mplang/frontend/builtin.py,sha256=8qrlbe_SSy6QTXTnMG6_ADB8jSklVZGFBrkoR-p02FE,9368
41
41
  mplang/frontend/crypto.py,sha256=Nf8zT4Eko7MIs4R2tgZecKVd7d6Hvd_CGGmANhs3Ghs,3651
42
42
  mplang/frontend/ibis_cc.py,sha256=01joUdFS_Ja9--PkezBhEcW_a9mkDrLgOhu5320s_bQ,4167
43
43
  mplang/frontend/jax_cc.py,sha256=ssP6rCvyWQ5VAr80-7z9QZUE2mWXyozJCGpq1dYQYY8,6374
@@ -50,14 +50,15 @@ mplang/protos/v1alpha1/mpir_pb2.pyi,sha256=GwXR4wPB_kB_36iYS9x-cGI9KDKFMq89KhdLh
50
50
  mplang/protos/v1alpha1/mpir_pb2_grpc.py,sha256=xYOs94SXiNYAlFodACnsXW5QovLsHY5tCk3p76RH5Zc,158
51
51
  mplang/runtime/__init__.py,sha256=IRPP3TtpFC4iSt7_uaq-S4dL7CwrXL0XBMeaBoEYLlg,948
52
52
  mplang/runtime/cli.py,sha256=WehDodeVB4AukSWx1LJxxtKUqGmLPY4qjayrPlOg3bE,14438
53
- mplang/runtime/client.py,sha256=9YGyTGox-zAUIkGGiwruCJauEfzf5hQsDU3HBd_PxHE,13555
53
+ mplang/runtime/client.py,sha256=w8sPuQzqaJI5uS_3JHu2mf0tLaFmZH3f6-SeUBfMLMY,15737
54
54
  mplang/runtime/communicator.py,sha256=Lek6_h_Wmr_W-_JpT-vMxL3CHxcVZdtf7jdaLGuxPgQ,3199
55
+ mplang/runtime/data_providers.py,sha256=TPAJSko_2J95oiHCxAKALICVM_LvnxzfgcM48ubhnKU,8226
55
56
  mplang/runtime/driver.py,sha256=Ok1jY301ctN1_KTb4jwSxOdB0lI_xhx9AwhtEGJ-VLQ,11300
56
57
  mplang/runtime/exceptions.py,sha256=c18U0xK20dRmgZo0ogTf5vXlkix9y3VAFuzkHxaXPEk,981
57
58
  mplang/runtime/http_api.md,sha256=-re1DhEqMplAkv_wnqEU-PSs8tTzf4-Ml0Gq0f3Go6s,4883
58
59
  mplang/runtime/link_comm.py,sha256=uNqTCGZVwWeuHAb7yXXQf0DUsMXLa8leHCkrcZdzYMU,4559
59
- mplang/runtime/resource.py,sha256=nGNIYGS1qL34tumoTJKZqT30AnS9C2-JrHZuFMJ1BTA,10776
60
- mplang/runtime/server.py,sha256=onDw3hnqSIK8cOYuRgblTC4myr-V7ym72GW6pWrpFQE,10698
60
+ mplang/runtime/resource.py,sha256=BkzpAjRrkjS-5FPawHHcVzxEE_htpMtS4JJCvcLrnU0,12564
61
+ mplang/runtime/server.py,sha256=gTPqAux1EdefaBFnserYIXamoi7pbEsQrFX6cXbOjik,14716
61
62
  mplang/runtime/simulation.py,sha256=kj-RtvvypISO2xyYQGtm-N8yavnkVpD33KSLpvL-Vms,11107
62
63
  mplang/simp/__init__.py,sha256=DmSMcKvHVXWS2pYsuHazEmwOWWpZeKOJQsNU6VxC10U,11614
63
64
  mplang/simp/mpi.py,sha256=Wv_Q16TQ3rdLam6OzqXiefIGSMmagGkso09ycyOkHEs,4774
@@ -68,8 +69,8 @@ mplang/utils/crypto.py,sha256=rvPomBFtznRHc3RPi6Aip9lsU8zW2oxBqGv1K3vn7Rs,1052
68
69
  mplang/utils/func_utils.py,sha256=vCJcZmu0bEbqhOQKdpttV2_MBllIcPSN0b8U4WjNGGo,5164
69
70
  mplang/utils/spu_utils.py,sha256=S3L9RBkBe2AvSuMSQQ12cBY5Y1NPthubvErSX_7nj1A,4158
70
71
  mplang/utils/table_utils.py,sha256=aC-IZOKkSmFkpr3NZchLM0Wt0GOn-rg_xHBHREWBwAU,2202
71
- mplang_nightly-0.1.dev139.dist-info/METADATA,sha256=KejJXvQkNu9KXYerLe6dUhuyIuH54OrxhHq6IMDfJNA,16547
72
- mplang_nightly-0.1.dev139.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
73
- mplang_nightly-0.1.dev139.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
74
- mplang_nightly-0.1.dev139.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
75
- mplang_nightly-0.1.dev139.dist-info/RECORD,,
72
+ mplang_nightly-0.1.dev140.dist-info/METADATA,sha256=uvtKWbcwO-wpF_BxVBDWqof7tV82vY7vXMLH77prn5g,16547
73
+ mplang_nightly-0.1.dev140.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
74
+ mplang_nightly-0.1.dev140.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
75
+ mplang_nightly-0.1.dev140.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
76
+ mplang_nightly-0.1.dev140.dist-info/RECORD,,