mplang-nightly 0.1.dev277__py3-none-any.whl → 0.1.dev279__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +30 -6
- mplang/backends/simp_worker/ops.py +6 -2
- mplang/edsl/__init__.py +3 -0
- mplang/edsl/program.py +134 -0
- mplang/runtime/interpreter.py +294 -36
- mplang/tool/__init__.py +46 -0
- mplang/tool/program.py +335 -0
- {mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/RECORD +12 -9
- {mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/licenses/LICENSE +0 -0
mplang/__init__.py
CHANGED
|
@@ -34,6 +34,7 @@ except Exception:
|
|
|
34
34
|
# Fallback for development/editable installs when package is not installed
|
|
35
35
|
__version__ = "0.0.0-dev"
|
|
36
36
|
|
|
37
|
+
import mplang.tool as tool
|
|
37
38
|
from mplang import dialects
|
|
38
39
|
from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
|
|
39
40
|
from mplang.backends.simp_worker import SimpWorker
|
|
@@ -41,6 +42,8 @@ from mplang.backends.simp_worker.mem import LocalMesh
|
|
|
41
42
|
from mplang.backends.simp_worker.ops import WORKER_HANDLERS
|
|
42
43
|
from mplang.dialects.simp import make_driver, make_simulator
|
|
43
44
|
from mplang.edsl import (
|
|
45
|
+
CompiledProgram,
|
|
46
|
+
FlatIOSignature,
|
|
44
47
|
Graph,
|
|
45
48
|
GraphPrinter,
|
|
46
49
|
Object,
|
|
@@ -125,7 +128,7 @@ def _get_context(context: Interpreter | None) -> Interpreter:
|
|
|
125
128
|
|
|
126
129
|
|
|
127
130
|
def evaluate(
|
|
128
|
-
fn: Callable[..., Any] | TracedFunction,
|
|
131
|
+
fn: Callable[..., Any] | TracedFunction | CompiledProgram,
|
|
129
132
|
*args: Any,
|
|
130
133
|
context: Interpreter | None = None,
|
|
131
134
|
**kwargs: Any,
|
|
@@ -158,15 +161,33 @@ def evaluate(
|
|
|
158
161
|
return val.runtime_obj
|
|
159
162
|
return val
|
|
160
163
|
|
|
164
|
+
def eval_graph(graph: Graph, inputs: list[Any]) -> list[InterpObject]:
|
|
165
|
+
runtime_inputs = [unwrap_if_interp(v) for v in inputs]
|
|
166
|
+
raw_result = interp.evaluate_graph(graph, runtime_inputs)
|
|
167
|
+
return [
|
|
168
|
+
InterpObject(v, graph.outputs[i].type, interp)
|
|
169
|
+
for i, v in enumerate(raw_result)
|
|
170
|
+
]
|
|
171
|
+
|
|
161
172
|
with interp:
|
|
173
|
+
if isinstance(fn, CompiledProgram):
|
|
174
|
+
if kwargs:
|
|
175
|
+
raise TypeError(
|
|
176
|
+
"mp.evaluate(CompiledProgram, ...) does not accept keyword arguments; "
|
|
177
|
+
"pass flat positional inputs only."
|
|
178
|
+
)
|
|
179
|
+
if len(args) != fn.signature.input_arity:
|
|
180
|
+
raise ValueError(
|
|
181
|
+
"CompiledProgram requires flat positional inputs matching its signature; "
|
|
182
|
+
f"expected {fn.signature.input_arity}, got {len(args)}."
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
return eval_graph(fn.graph, list(args))
|
|
186
|
+
|
|
162
187
|
if isinstance(fn, TracedFunction):
|
|
163
188
|
inputs = fn.prepare_inputs(*args, **kwargs)
|
|
164
189
|
inputs = [unwrap_if_interp(v) for v in inputs]
|
|
165
|
-
|
|
166
|
-
wrapped = [
|
|
167
|
-
InterpObject(v, fn.graph.outputs[i].type, interp)
|
|
168
|
-
for i, v in enumerate(raw_result)
|
|
169
|
-
]
|
|
190
|
+
wrapped = eval_graph(fn.graph, inputs)
|
|
170
191
|
return fn.reconstruct_outputs(wrapped)
|
|
171
192
|
|
|
172
193
|
return fn(*args, **kwargs)
|
|
@@ -417,6 +438,9 @@ __all__ = [ # noqa: RUF022
|
|
|
417
438
|
"WORKER_HANDLERS",
|
|
418
439
|
"make_driver",
|
|
419
440
|
"make_simulator",
|
|
441
|
+
"tool",
|
|
442
|
+
"CompiledProgram",
|
|
443
|
+
"FlatIOSignature",
|
|
420
444
|
# Dialects
|
|
421
445
|
"dialects",
|
|
422
446
|
"register_default_context_factory",
|
|
@@ -84,16 +84,20 @@ def _shuffle_static_worker_impl(
|
|
|
84
84
|
my_rank = worker.rank
|
|
85
85
|
data = args[0]
|
|
86
86
|
|
|
87
|
+
exec_id = interpreter.current_op_exec_id()
|
|
88
|
+
graph_key = interpreter.current_graph_exec_key()
|
|
89
|
+
key_prefix = f"shuffle_{graph_key}_{op.name}_{exec_id}"
|
|
90
|
+
|
|
87
91
|
for tgt, src in routing.items():
|
|
88
92
|
if src == my_rank and tgt != my_rank:
|
|
89
|
-
key = f"
|
|
93
|
+
key = f"{key_prefix}_{tgt}"
|
|
90
94
|
comm.send(tgt, key, data)
|
|
91
95
|
|
|
92
96
|
if my_rank in routing:
|
|
93
97
|
src = routing[my_rank]
|
|
94
98
|
if src == my_rank:
|
|
95
99
|
return data
|
|
96
|
-
key = f"
|
|
100
|
+
key = f"{key_prefix}_{my_rank}"
|
|
97
101
|
return comm.recv(src, key)
|
|
98
102
|
else:
|
|
99
103
|
return None
|
mplang/edsl/__init__.py
CHANGED
|
@@ -53,6 +53,7 @@ from .jit import jit
|
|
|
53
53
|
from .object import Object
|
|
54
54
|
from .primitive import Primitive, primitive
|
|
55
55
|
from .printer import GraphPrinter, format_graph
|
|
56
|
+
from .program import CompiledProgram, FlatIOSignature
|
|
56
57
|
from .tracer import TracedFunction, TraceObject, Tracer, trace
|
|
57
58
|
from .typing import MPType, ScalarType, SSType, TableType, TensorType, VectorType
|
|
58
59
|
|
|
@@ -65,7 +66,9 @@ TensorObject = Object[TensorType]
|
|
|
65
66
|
VectorObject = Object[VectorType]
|
|
66
67
|
|
|
67
68
|
__all__ = [
|
|
69
|
+
"CompiledProgram",
|
|
68
70
|
"Context",
|
|
71
|
+
"FlatIOSignature",
|
|
69
72
|
"Graph",
|
|
70
73
|
"GraphPrinter",
|
|
71
74
|
"MPObject",
|
mplang/edsl/program.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import json
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import Any, ClassVar
|
|
21
|
+
|
|
22
|
+
from mplang.edsl import serde
|
|
23
|
+
from mplang.edsl.graph import Graph
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
|
|
27
|
+
class FlatIOSignature:
|
|
28
|
+
"""Portable I/O signature for source-free execution.
|
|
29
|
+
|
|
30
|
+
Only supports flat positional inputs/outputs corresponding to
|
|
31
|
+
`graph.inputs` / `graph.outputs`.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
kind: ClassVar[str] = "flat_list_v0"
|
|
35
|
+
input_arity: int
|
|
36
|
+
output_arity: int
|
|
37
|
+
|
|
38
|
+
def to_json(self) -> dict[str, Any]:
|
|
39
|
+
return {
|
|
40
|
+
"kind": self.kind,
|
|
41
|
+
"input_arity": self.input_arity,
|
|
42
|
+
"output_arity": self.output_arity,
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def from_json(cls, data: dict[str, Any]) -> FlatIOSignature:
|
|
47
|
+
if data.get("kind") != cls.kind:
|
|
48
|
+
raise ValueError(f"Unsupported signature kind: {data.get('kind')}")
|
|
49
|
+
return cls(
|
|
50
|
+
input_arity=int(data["input_arity"]),
|
|
51
|
+
output_arity=int(data["output_arity"]),
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@serde.register_class
|
|
56
|
+
@dataclass
|
|
57
|
+
class CompiledProgram:
|
|
58
|
+
"""Executable program decoupled from user source.
|
|
59
|
+
|
|
60
|
+
This is a *logical model*; packaging (file/zip/etc.) is handled by tool layer.
|
|
61
|
+
|
|
62
|
+
Current constraints:
|
|
63
|
+
- signature is flat positional list I/O.
|
|
64
|
+
- no closure captures.
|
|
65
|
+
- no constant outputs (out_imms) unless future signature captures them.
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
_serde_kind: ClassVar[str] = "mplang.CompiledProgram"
|
|
69
|
+
|
|
70
|
+
graph: Graph
|
|
71
|
+
signature: FlatIOSignature
|
|
72
|
+
required_opcodes: list[str]
|
|
73
|
+
graph_digest: str
|
|
74
|
+
required_world_size: int | None = None
|
|
75
|
+
created_at: str | None = None
|
|
76
|
+
mplang_version: str | None = None
|
|
77
|
+
schema_version: int = 1
|
|
78
|
+
name: str | None = None
|
|
79
|
+
|
|
80
|
+
def to_json(self) -> dict[str, Any]:
|
|
81
|
+
return {
|
|
82
|
+
"schema_version": self.schema_version,
|
|
83
|
+
"name": self.name,
|
|
84
|
+
"graph": serde.to_json(self.graph),
|
|
85
|
+
"signature": self.signature.to_json(),
|
|
86
|
+
"required_opcodes": list(self.required_opcodes),
|
|
87
|
+
"graph_digest": self.graph_digest,
|
|
88
|
+
"required_world_size": self.required_world_size,
|
|
89
|
+
"created_at": self.created_at,
|
|
90
|
+
"mplang_version": self.mplang_version,
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
@classmethod
|
|
94
|
+
def from_json(cls, data: dict[str, Any]) -> CompiledProgram:
|
|
95
|
+
if "schema_version" not in data:
|
|
96
|
+
raise KeyError("Missing required field: schema_version")
|
|
97
|
+
schema_version = int(data["schema_version"])
|
|
98
|
+
if schema_version != 1:
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"Unsupported CompiledProgram schema_version: {schema_version}"
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
graph = serde.from_json(data["graph"])
|
|
104
|
+
if not isinstance(graph, Graph):
|
|
105
|
+
raise TypeError(
|
|
106
|
+
f"Expected graph to deserialize to Graph, got {type(graph).__name__}"
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
signature = FlatIOSignature.from_json(data["signature"])
|
|
110
|
+
|
|
111
|
+
required_world_size = data.get("required_world_size")
|
|
112
|
+
if required_world_size is not None:
|
|
113
|
+
required_world_size = int(required_world_size)
|
|
114
|
+
return cls(
|
|
115
|
+
graph=graph,
|
|
116
|
+
signature=signature,
|
|
117
|
+
required_opcodes=list(data.get("required_opcodes", [])),
|
|
118
|
+
graph_digest=str(data["graph_digest"]),
|
|
119
|
+
required_world_size=required_world_size,
|
|
120
|
+
created_at=data.get("created_at"),
|
|
121
|
+
mplang_version=data.get("mplang_version"),
|
|
122
|
+
schema_version=schema_version,
|
|
123
|
+
name=data.get("name"),
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def compute_graph_digest(graph: Graph) -> str:
|
|
128
|
+
"""Compute a deterministic digest for a Graph.
|
|
129
|
+
|
|
130
|
+
We intentionally avoid `serde.dumps()` because it doesn't sort keys.
|
|
131
|
+
"""
|
|
132
|
+
|
|
133
|
+
canonical = json.dumps(serde.to_json(graph), sort_keys=True, separators=(",", ":"))
|
|
134
|
+
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
mplang/runtime/interpreter.py
CHANGED
|
@@ -24,17 +24,19 @@ from __future__ import annotations
|
|
|
24
24
|
|
|
25
25
|
import collections
|
|
26
26
|
import concurrent.futures
|
|
27
|
+
import contextlib
|
|
28
|
+
import hashlib
|
|
27
29
|
import json
|
|
28
30
|
import os
|
|
29
31
|
import pathlib
|
|
30
32
|
import queue
|
|
31
33
|
import threading
|
|
32
34
|
import time
|
|
33
|
-
from collections.abc import Callable
|
|
35
|
+
from collections.abc import Callable, Iterator
|
|
34
36
|
from typing import TYPE_CHECKING, Any, cast
|
|
35
37
|
|
|
36
38
|
from mplang.edsl.context import AbstractInterpreter
|
|
37
|
-
from mplang.edsl.graph import Graph
|
|
39
|
+
from mplang.edsl.graph import Graph, Value
|
|
38
40
|
from mplang.edsl.object import Object
|
|
39
41
|
from mplang.edsl.registry import get_impl
|
|
40
42
|
from mplang.edsl.typing import BaseType
|
|
@@ -364,12 +366,201 @@ class Interpreter(AbstractInterpreter):
|
|
|
364
366
|
# 2. MIMO Optimization: When one output of a multi-output op is computed,
|
|
365
367
|
# all sibling outputs are cached here to avoid re-execution.
|
|
366
368
|
self._execution_cache: dict[Any, InterpObject] = {}
|
|
369
|
+
|
|
370
|
+
# -----------------------------------------------------------------
|
|
371
|
+
# Graph-local op execution ids (for deterministic communication tags)
|
|
372
|
+
# -----------------------------------------------------------------
|
|
373
|
+
# We assign a monotonically increasing exec_id to each op execution
|
|
374
|
+
# within a graph namespace, and keep it deterministic across parties.
|
|
375
|
+
#
|
|
376
|
+
# IMPORTANT:
|
|
377
|
+
# - We intentionally make exec_id grow across repeated executions of the
|
|
378
|
+
# same region graph (e.g., while_loop iterations) to avoid tag/key reuse.
|
|
379
|
+
#
|
|
380
|
+
# Implementation:
|
|
381
|
+
# - Each evaluate_graph(graph, ...) reserves a contiguous exec_id range
|
|
382
|
+
# [base, base + len(graph.operations)).
|
|
383
|
+
# - Op exec_id = base + op_index_in_graph.
|
|
384
|
+
# - Reservation is persisted per graph_exec_key (structural hash).
|
|
385
|
+
# - We forbid concurrent execution of the same graph_hash to avoid
|
|
386
|
+
# message tag confusion when a backend uses only per-op tags.
|
|
387
|
+
self._exec_id_lock = threading.Lock()
|
|
388
|
+
self._graph_next_exec_base: dict[str, int] = {}
|
|
389
|
+
self._active_graph_exec_keys: set[str] = set()
|
|
390
|
+
self._tls = threading.local()
|
|
367
391
|
self.executor = executor
|
|
368
392
|
self.async_ops: set[str] = set()
|
|
369
393
|
self.name = name
|
|
370
394
|
self.trace_pid = trace_pid
|
|
371
395
|
self.store: ObjectStore | None = store
|
|
372
396
|
|
|
397
|
+
@contextlib.contextmanager
|
|
398
|
+
def _tls_exec_context(
|
|
399
|
+
self,
|
|
400
|
+
*,
|
|
401
|
+
graph_exec_key: str | None = None,
|
|
402
|
+
op_exec_id: int | None = None,
|
|
403
|
+
) -> Iterator[None]:
|
|
404
|
+
"""Temporarily set execution context in thread-local storage."""
|
|
405
|
+
|
|
406
|
+
prev_graph_key = getattr(self._tls, "current_graph_exec_key", None)
|
|
407
|
+
prev_exec_id = getattr(self._tls, "current_op_exec_id", None)
|
|
408
|
+
|
|
409
|
+
if graph_exec_key is not None:
|
|
410
|
+
self._tls.current_graph_exec_key = graph_exec_key
|
|
411
|
+
if op_exec_id is not None:
|
|
412
|
+
self._tls.current_op_exec_id = op_exec_id
|
|
413
|
+
|
|
414
|
+
try:
|
|
415
|
+
yield
|
|
416
|
+
finally:
|
|
417
|
+
if graph_exec_key is not None:
|
|
418
|
+
if prev_graph_key is None:
|
|
419
|
+
delattr(self._tls, "current_graph_exec_key")
|
|
420
|
+
else:
|
|
421
|
+
self._tls.current_graph_exec_key = prev_graph_key
|
|
422
|
+
|
|
423
|
+
if op_exec_id is not None:
|
|
424
|
+
if prev_exec_id is None:
|
|
425
|
+
delattr(self._tls, "current_op_exec_id")
|
|
426
|
+
else:
|
|
427
|
+
self._tls.current_op_exec_id = prev_exec_id
|
|
428
|
+
|
|
429
|
+
def _graph_exec_key(self, graph: Graph) -> str:
|
|
430
|
+
"""Return a deterministic, structural hash for a graph.
|
|
431
|
+
|
|
432
|
+
Used for:
|
|
433
|
+
- Namespacing per-graph exec_id counters
|
|
434
|
+
- Communication tag disambiguation (worker ops may include this key)
|
|
435
|
+
|
|
436
|
+
Note: we cache on the Graph object assuming graphs are immutable during
|
|
437
|
+
execution (finalized graphs / regions).
|
|
438
|
+
"""
|
|
439
|
+
|
|
440
|
+
cached = getattr(graph, "_exec_key", None)
|
|
441
|
+
if cached is not None:
|
|
442
|
+
return cast(str, cached)
|
|
443
|
+
|
|
444
|
+
# NOTE: We intentionally do NOT use graph.to_json() here.
|
|
445
|
+
# graph.to_json() requires all attrs to be JSON-serializable via serde,
|
|
446
|
+
# but graphs may legitimately contain runtime-only objects (e.g. JAX
|
|
447
|
+
# PyTreeDef used by func.func). For communication tag namespaces we use
|
|
448
|
+
# a simple structural fingerprint that is deterministic across parties.
|
|
449
|
+
|
|
450
|
+
def _stable_attr_value(obj: Any) -> Any | None:
|
|
451
|
+
"""Return a JSON-compatible stable value or None if unsupported.
|
|
452
|
+
|
|
453
|
+
We include only values that are likely deterministic across parties.
|
|
454
|
+
Unknown runtime objects are skipped (e.g. PyTreeDef, callables, etc.).
|
|
455
|
+
"""
|
|
456
|
+
|
|
457
|
+
if obj is None or isinstance(obj, (bool, int, float, str)):
|
|
458
|
+
return obj
|
|
459
|
+
|
|
460
|
+
if isinstance(obj, (bytes, bytearray, memoryview)):
|
|
461
|
+
b = bytes(obj)
|
|
462
|
+
return {
|
|
463
|
+
"_kind": "bytes",
|
|
464
|
+
"len": len(b),
|
|
465
|
+
"sha256": hashlib.sha256(b).hexdigest(),
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
try:
|
|
469
|
+
import numpy as np # type: ignore
|
|
470
|
+
|
|
471
|
+
if isinstance(obj, np.ndarray):
|
|
472
|
+
b = obj.tobytes(order="C")
|
|
473
|
+
return {
|
|
474
|
+
"_kind": "ndarray",
|
|
475
|
+
"dtype": str(obj.dtype),
|
|
476
|
+
"shape": list(obj.shape),
|
|
477
|
+
"sha256": hashlib.sha256(b).hexdigest(),
|
|
478
|
+
}
|
|
479
|
+
if isinstance(obj, (np.integer, np.floating)):
|
|
480
|
+
return obj.item()
|
|
481
|
+
except Exception:
|
|
482
|
+
pass
|
|
483
|
+
|
|
484
|
+
if isinstance(obj, (list, tuple)):
|
|
485
|
+
items: list[Any] = []
|
|
486
|
+
for x in obj:
|
|
487
|
+
sx = _stable_attr_value(x)
|
|
488
|
+
if sx is None:
|
|
489
|
+
return None
|
|
490
|
+
items.append(sx)
|
|
491
|
+
return items
|
|
492
|
+
|
|
493
|
+
if isinstance(obj, dict):
|
|
494
|
+
stable_items: list[tuple[Any, Any]] = []
|
|
495
|
+
for k, v in obj.items():
|
|
496
|
+
sk = _stable_attr_value(k)
|
|
497
|
+
sv = _stable_attr_value(v)
|
|
498
|
+
if sk is None or sv is None:
|
|
499
|
+
return None
|
|
500
|
+
stable_items.append((sk, sv))
|
|
501
|
+
stable_items.sort(
|
|
502
|
+
key=lambda kv: json.dumps(
|
|
503
|
+
kv[0], sort_keys=True, separators=(",", ":"), ensure_ascii=False
|
|
504
|
+
)
|
|
505
|
+
)
|
|
506
|
+
return {"_kind": "dict", "items": stable_items}
|
|
507
|
+
|
|
508
|
+
return None
|
|
509
|
+
|
|
510
|
+
def _graph_fingerprint(g: Graph) -> Any:
|
|
511
|
+
# Map SSA Values to stable indices independent of their textual names.
|
|
512
|
+
value_to_index: dict[Value, int] = {}
|
|
513
|
+
|
|
514
|
+
def _index(v: Value) -> int:
|
|
515
|
+
if v in value_to_index:
|
|
516
|
+
return value_to_index[v]
|
|
517
|
+
value_to_index[v] = len(value_to_index)
|
|
518
|
+
return value_to_index[v]
|
|
519
|
+
|
|
520
|
+
for v in g.inputs:
|
|
521
|
+
_index(v)
|
|
522
|
+
for op in g.operations:
|
|
523
|
+
for out in op.outputs:
|
|
524
|
+
_index(out)
|
|
525
|
+
|
|
526
|
+
ops_fp: list[dict[str, Any]] = []
|
|
527
|
+
for op in g.operations:
|
|
528
|
+
attr_keys = sorted(op.attrs.keys())
|
|
529
|
+
stable_attr_items: list[tuple[str, Any]] = []
|
|
530
|
+
for k in attr_keys:
|
|
531
|
+
attr_val = op.attrs.get(k)
|
|
532
|
+
sv = _stable_attr_value(attr_val)
|
|
533
|
+
if sv is not None:
|
|
534
|
+
stable_attr_items.append((k, sv))
|
|
535
|
+
|
|
536
|
+
ops_fp.append({
|
|
537
|
+
"opcode": op.opcode,
|
|
538
|
+
"inputs": [_index(v) for v in op.inputs],
|
|
539
|
+
"outputs": [str(v.type) for v in op.outputs],
|
|
540
|
+
"attrs": {"keys": attr_keys, "stable": stable_attr_items},
|
|
541
|
+
"regions": [_graph_fingerprint(r) for r in op.regions],
|
|
542
|
+
})
|
|
543
|
+
|
|
544
|
+
return {
|
|
545
|
+
"inputs": [str(v.type) for v in g.inputs],
|
|
546
|
+
"ops": ops_fp,
|
|
547
|
+
"outputs": [_index(v) for v in g.outputs],
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
fingerprint = _graph_fingerprint(graph)
|
|
551
|
+
|
|
552
|
+
payload = json.dumps(
|
|
553
|
+
fingerprint,
|
|
554
|
+
sort_keys=True,
|
|
555
|
+
separators=(",", ":"),
|
|
556
|
+
ensure_ascii=False,
|
|
557
|
+
).encode("utf-8")
|
|
558
|
+
key = hashlib.sha256(payload).hexdigest()
|
|
559
|
+
|
|
560
|
+
# Store on graph to avoid id(graph) reuse pitfalls.
|
|
561
|
+
graph._exec_key = key # type: ignore[attr-defined]
|
|
562
|
+
return key
|
|
563
|
+
|
|
373
564
|
def shutdown(self) -> None:
|
|
374
565
|
"""Shutdown the interpreter and release resources.
|
|
375
566
|
|
|
@@ -641,18 +832,70 @@ class Interpreter(AbstractInterpreter):
|
|
|
641
832
|
Returns:
|
|
642
833
|
List of runtime execution results corresponding to graph.outputs.
|
|
643
834
|
"""
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
835
|
+
graph_exec_key = self._graph_exec_key(graph)
|
|
836
|
+
|
|
837
|
+
# Prevent concurrent execution of the same graph hash.
|
|
838
|
+
with self._exec_id_lock:
|
|
839
|
+
if graph_exec_key in self._active_graph_exec_keys:
|
|
840
|
+
raise RuntimeError(
|
|
841
|
+
"Concurrent execution of the same graph is not allowed. "
|
|
842
|
+
f"graph_exec_key={graph_exec_key}"
|
|
843
|
+
)
|
|
844
|
+
self._active_graph_exec_keys.add(graph_exec_key)
|
|
845
|
+
|
|
846
|
+
try:
|
|
847
|
+
with self._tls_exec_context(graph_exec_key=graph_exec_key):
|
|
848
|
+
logger.debug(
|
|
849
|
+
"Evaluating graph: %d inputs, %d ops, %d outputs (job_id=%s, async=%s, graph_key=%s)",
|
|
850
|
+
len(inputs),
|
|
851
|
+
len(graph.operations),
|
|
852
|
+
len(graph.outputs),
|
|
853
|
+
job_id,
|
|
854
|
+
self.executor is not None,
|
|
855
|
+
graph_exec_key,
|
|
856
|
+
)
|
|
857
|
+
if self.executor:
|
|
858
|
+
return self._evaluate_graph_async(graph, inputs, job_id)
|
|
859
|
+
else:
|
|
860
|
+
return self._evaluate_graph_sync(graph, inputs, job_id)
|
|
861
|
+
finally:
|
|
862
|
+
with self._exec_id_lock:
|
|
863
|
+
self._active_graph_exec_keys.discard(graph_exec_key)
|
|
864
|
+
|
|
865
|
+
def _reserve_op_exec_base(self, graph: Graph) -> int:
|
|
866
|
+
"""Reserve a contiguous exec_id range for a single evaluate_graph call.
|
|
867
|
+
|
|
868
|
+
Counter is namespaced by the current graph_exec_key.
|
|
869
|
+
"""
|
|
870
|
+
key = self.current_graph_exec_key()
|
|
871
|
+
with self._exec_id_lock:
|
|
872
|
+
base = self._graph_next_exec_base.get(key, 0)
|
|
873
|
+
self._graph_next_exec_base[key] = base + len(graph.operations)
|
|
874
|
+
return base
|
|
875
|
+
|
|
876
|
+
def current_graph_exec_key(self) -> str:
|
|
877
|
+
"""Return current graph execution key during evaluate_graph execution."""
|
|
878
|
+
|
|
879
|
+
key = getattr(self._tls, "current_graph_exec_key", None)
|
|
880
|
+
if key is None:
|
|
881
|
+
raise RuntimeError(
|
|
882
|
+
"current_graph_exec_key() called outside of evaluate_graph execution"
|
|
883
|
+
)
|
|
884
|
+
return cast(str, key)
|
|
885
|
+
|
|
886
|
+
def current_op_exec_id(self) -> int:
|
|
887
|
+
"""Return current op exec_id during graph execution.
|
|
888
|
+
|
|
889
|
+
Worker-side implementations can use this to build deterministic,
|
|
890
|
+
unique communication tags without coupling to any specific op.
|
|
891
|
+
"""
|
|
892
|
+
|
|
893
|
+
exec_id = getattr(self._tls, "current_op_exec_id", None)
|
|
894
|
+
if exec_id is None:
|
|
895
|
+
raise RuntimeError(
|
|
896
|
+
"current_op_exec_id() called outside of evaluate_graph execution"
|
|
897
|
+
)
|
|
898
|
+
return cast(int, exec_id)
|
|
656
899
|
|
|
657
900
|
def _evaluate_graph_sync(
|
|
658
901
|
self, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
@@ -661,7 +904,10 @@ class Interpreter(AbstractInterpreter):
|
|
|
661
904
|
# Local environment: Value -> Runtime Object
|
|
662
905
|
env = dict(zip(graph.inputs, inputs, strict=True))
|
|
663
906
|
|
|
664
|
-
|
|
907
|
+
op_exec_base = self._reserve_op_exec_base(graph)
|
|
908
|
+
|
|
909
|
+
for op_index, op in enumerate(graph.operations):
|
|
910
|
+
exec_id = op_exec_base + op_index
|
|
665
911
|
# Resolve inputs
|
|
666
912
|
try:
|
|
667
913
|
args = [env[val] for val in op.inputs]
|
|
@@ -685,15 +931,16 @@ class Interpreter(AbstractInterpreter):
|
|
|
685
931
|
if not handler:
|
|
686
932
|
handler = get_impl(op.opcode)
|
|
687
933
|
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
934
|
+
with self._tls_exec_context(op_exec_id=exec_id):
|
|
935
|
+
if handler:
|
|
936
|
+
# Pass interpreter to support recursive execution (HOFs)
|
|
937
|
+
# Pass op to access attributes and regions
|
|
938
|
+
# Pass args as runtime values
|
|
939
|
+
results = handler(self, op, *args)
|
|
940
|
+
else:
|
|
941
|
+
raise NotImplementedError(
|
|
942
|
+
f"No implementation registered for opcode: {op.opcode}"
|
|
943
|
+
)
|
|
697
944
|
|
|
698
945
|
# Update environment with outputs
|
|
699
946
|
# Handler should return a single value or a tuple/list of values
|
|
@@ -719,6 +966,9 @@ class Interpreter(AbstractInterpreter):
|
|
|
719
966
|
self, graph: Graph, inputs: list[Any], job_id: str | None = None
|
|
720
967
|
) -> list[Any]:
|
|
721
968
|
"""Asynchronous execution with non-blocking DAG scheduling."""
|
|
969
|
+
graph_exec_key = self.current_graph_exec_key()
|
|
970
|
+
op_exec_base = self._reserve_op_exec_base(graph)
|
|
971
|
+
op_to_index = {op: i for i, op in enumerate(graph.operations)}
|
|
722
972
|
# Tracer setup (if not provided, use a disabled stub)
|
|
723
973
|
tracer: ExecutionTracer | _NullTracer
|
|
724
974
|
if self.tracer:
|
|
@@ -817,6 +1067,8 @@ class Interpreter(AbstractInterpreter):
|
|
|
817
1067
|
# Extract args from env (must be ready)
|
|
818
1068
|
args = [env[val] for val in op.inputs]
|
|
819
1069
|
|
|
1070
|
+
exec_id = op_exec_base + op_to_index[op]
|
|
1071
|
+
|
|
820
1072
|
handler = self.handlers.get(op.opcode)
|
|
821
1073
|
if not handler:
|
|
822
1074
|
handler = get_impl(op.opcode)
|
|
@@ -833,12 +1085,15 @@ class Interpreter(AbstractInterpreter):
|
|
|
833
1085
|
|
|
834
1086
|
# Submit to executor
|
|
835
1087
|
def task() -> Any:
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
)
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
1088
|
+
with self._tls_exec_context(
|
|
1089
|
+
graph_exec_key=graph_exec_key, op_exec_id=exec_id
|
|
1090
|
+
):
|
|
1091
|
+
start_ts = tracer.log_start(
|
|
1092
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
1093
|
+
)
|
|
1094
|
+
res = handler(self, op, *args)
|
|
1095
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1096
|
+
return res
|
|
842
1097
|
|
|
843
1098
|
def callback(fut: Any) -> None:
|
|
844
1099
|
try:
|
|
@@ -852,12 +1107,15 @@ class Interpreter(AbstractInterpreter):
|
|
|
852
1107
|
else:
|
|
853
1108
|
# Sync execution (run immediately)
|
|
854
1109
|
try:
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
)
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
1110
|
+
with self._tls_exec_context(
|
|
1111
|
+
graph_exec_key=graph_exec_key, op_exec_id=exec_id
|
|
1112
|
+
):
|
|
1113
|
+
start_ts = tracer.log_start(
|
|
1114
|
+
op, pid=self.trace_pid, namespace=self.trace_pid
|
|
1115
|
+
)
|
|
1116
|
+
res = handler(self, op, *args)
|
|
1117
|
+
tracer.log_end(op, start_ts, pid=self.trace_pid)
|
|
1118
|
+
on_op_done(op, res)
|
|
861
1119
|
except Exception as e:
|
|
862
1120
|
on_op_done(op, None, error=e)
|
|
863
1121
|
|
mplang/tool/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tool-layer APIs for MPLang.
|
|
16
|
+
|
|
17
|
+
This package contains utilities that are intentionally *not* part of the core
|
|
18
|
+
EDSL execution surface. In particular, compile/execute decoupling lives here:
|
|
19
|
+
- build a portable `CompiledProgram`
|
|
20
|
+
- pack/unpack to a container format
|
|
21
|
+
|
|
22
|
+
These helpers must not depend on user source code being available at execution.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from mplang.edsl.program import CompiledProgram, FlatIOSignature
|
|
28
|
+
from mplang.tool.program import (
|
|
29
|
+
compile_program,
|
|
30
|
+
inspect_artifact,
|
|
31
|
+
pack,
|
|
32
|
+
pack_to_path,
|
|
33
|
+
unpack,
|
|
34
|
+
unpack_path,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
"CompiledProgram",
|
|
39
|
+
"FlatIOSignature",
|
|
40
|
+
"compile_program",
|
|
41
|
+
"inspect_artifact",
|
|
42
|
+
"pack",
|
|
43
|
+
"pack_to_path",
|
|
44
|
+
"unpack",
|
|
45
|
+
"unpack_path",
|
|
46
|
+
]
|
mplang/tool/program.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import io
|
|
18
|
+
import json
|
|
19
|
+
import tarfile
|
|
20
|
+
from datetime import UTC, datetime
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any, Literal
|
|
23
|
+
|
|
24
|
+
import mplang
|
|
25
|
+
from mplang.edsl import serde
|
|
26
|
+
from mplang.edsl.graph import Graph
|
|
27
|
+
from mplang.edsl.program import (
|
|
28
|
+
CompiledProgram,
|
|
29
|
+
FlatIOSignature,
|
|
30
|
+
compute_graph_digest,
|
|
31
|
+
)
|
|
32
|
+
from mplang.edsl.tracer import TracedFunction, trace
|
|
33
|
+
|
|
34
|
+
DEFAULT_MAX_ARTIFACT_JSON_BYTES = 512 * 1024 * 1024 # 512 MiB
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _utc_now_iso() -> str:
|
|
38
|
+
return datetime.now(UTC).isoformat()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _iter_graphs(root: Graph) -> list[Graph]:
|
|
42
|
+
# Use an explicit stack to avoid Python recursion limits.
|
|
43
|
+
# Also guard against potential region graph cycles.
|
|
44
|
+
out: list[Graph] = []
|
|
45
|
+
stack: list[Graph] = [root]
|
|
46
|
+
visited: set[int] = set()
|
|
47
|
+
while stack:
|
|
48
|
+
graph = stack.pop()
|
|
49
|
+
graph_id = id(graph)
|
|
50
|
+
if graph_id in visited:
|
|
51
|
+
continue
|
|
52
|
+
visited.add(graph_id)
|
|
53
|
+
out.append(graph)
|
|
54
|
+
|
|
55
|
+
for op in graph.operations:
|
|
56
|
+
if op.regions:
|
|
57
|
+
stack.extend(op.regions)
|
|
58
|
+
|
|
59
|
+
return out
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _collect_opcodes(graph: Graph) -> set[str]:
|
|
63
|
+
opcodes: set[str] = set()
|
|
64
|
+
for g in _iter_graphs(graph):
|
|
65
|
+
for op in g.operations:
|
|
66
|
+
opcodes.add(op.opcode)
|
|
67
|
+
return opcodes
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _collect_parties(graph: Graph) -> set[int]:
|
|
71
|
+
parties: set[int] = set()
|
|
72
|
+
for g in _iter_graphs(graph):
|
|
73
|
+
for op in g.operations:
|
|
74
|
+
raw = op.attrs.get("parties")
|
|
75
|
+
if raw is None:
|
|
76
|
+
continue
|
|
77
|
+
|
|
78
|
+
if not isinstance(raw, (list, tuple, set)):
|
|
79
|
+
raise TypeError(
|
|
80
|
+
"Invalid 'parties' attribute: expected list/tuple/set of ints, "
|
|
81
|
+
f"got {type(raw).__name__}"
|
|
82
|
+
)
|
|
83
|
+
for p in raw:
|
|
84
|
+
p_int = int(p)
|
|
85
|
+
if p_int < 0:
|
|
86
|
+
raise ValueError("Invalid 'parties' attribute: negative party id")
|
|
87
|
+
parties.add(p_int)
|
|
88
|
+
return parties
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _compute_required_world_size(graph: Graph) -> int:
    """Smallest world size able to host every party referenced by *graph*.

    Party ids are zero-based, so the answer is ``max(parties) + 1``;
    a graph with no ``parties`` attributes needs a world size of 0.
    """
    parties = _collect_parties(graph)
    return max(parties) + 1 if parties else 0
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _validate_traced_for_artifact(traced: TracedFunction) -> None:
    """Enforce the artifact restrictions on a freshly traced function.

    Rejects closure captures, constant outputs (``out_imms``), and any
    signature whose runtime inputs do not map 1:1 onto ``graph.inputs``.

    Raises:
        ValueError: if any restriction is violated.
    """
    # Captured closure values cannot be serialized into a source-free artifact.
    if traced.captured:
        raise ValueError(
            "CompiledProgram does not support closure captures; "
            "please refactor to pass all values explicitly."
        )

    # Constant outputs live outside the graph and would be lost on round-trip.
    if traced.out_imms:
        raise ValueError(
            "CompiledProgram does not support constant outputs (out_imms); "
            "return only traced values (graph outputs)."
        )

    # (args, kwargs) pytree metadata is not preserved, so every runtime
    # input must correspond exactly to a graph input.
    if len(traced.graph.inputs) != len(traced.in_var_pos):
        raise ValueError(
            "CompiledProgram requires flat positional inputs that map 1:1 to graph.inputs; "
            f"got graph.inputs={len(traced.graph.inputs)} but in_var_pos={len(traced.in_var_pos)}."
        )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _validate_program(program: CompiledProgram) -> None:
    """Check a CompiledProgram's internal consistency before use.

    Verifies the signature kind and arities against the graph, the recorded
    opcode set, the stored graph digest, the declared world size, and
    finally that the program survives JSON serialization.

    Raises:
        ValueError: on any consistency violation.
    """
    sig = program.signature
    if sig.kind != FlatIOSignature.kind:
        raise ValueError(f"Unsupported signature kind: {sig.kind}")

    if sig.input_arity != len(program.graph.inputs):
        raise ValueError(
            "Signature input_arity does not match graph.inputs: "
            f"input_arity={sig.input_arity}, inputs={len(program.graph.inputs)}"
        )
    if sig.output_arity != len(program.graph.outputs):
        raise ValueError(
            "Signature output_arity does not match graph.outputs: "
            f"output_arity={sig.output_arity}, outputs={len(program.graph.outputs)}"
        )

    # The recorded opcode set must match what the graph actually contains.
    if sorted(program.required_opcodes) != sorted(_collect_opcodes(program.graph)):
        raise ValueError(
            "required_opcodes mismatch with graph content; "
            "artifact may be corrupted or constructed inconsistently."
        )

    # An empty/absent digest skips the check; a present one must match.
    actual_digest = compute_graph_digest(program.graph)
    if program.graph_digest and program.graph_digest != actual_digest:
        raise ValueError(
            "Graph digest mismatch: "
            f"expected={program.graph_digest}, actual={actual_digest}"
        )

    expected_world_size = _compute_required_world_size(program.graph)
    declared_world_size = program.required_world_size
    if declared_world_size is not None and declared_world_size != expected_world_size:
        raise ValueError(
            "required_world_size mismatch with graph content; "
            f"expected={expected_world_size}, got={declared_world_size}."
        )

    # Ensure JSON serialization works (fail fast for non-serde attrs).
    serde.to_json(program)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def compile_program(
    fn_or_traced: Any,
    *args: Any,
    context: Any | None = None,
    name: str | None = None,
    **kwargs: Any,
) -> CompiledProgram:
    """Compile (trace) into a source-free executable `CompiledProgram`.

    Accepts either an already-traced `TracedFunction` or a callable, which
    is traced with *args*/*kwargs* (optionally under *context*).

    Restrictions (enforced):
    - no closure captures
    - no constant outputs (`out_imms` must be empty)
    - signature is flat list (positional) I/O

    Note: `in_imms` (compile-time constants) are allowed: they are baked into the graph.
    """
    if isinstance(fn_or_traced, TracedFunction):
        traced = fn_or_traced
    elif context is None:
        traced = trace(fn_or_traced, *args, **kwargs)
    else:
        with context:
            traced = trace(fn_or_traced, *args, **kwargs)

    _validate_traced_for_artifact(traced)

    io_signature = FlatIOSignature(
        input_arity=len(traced.graph.inputs),
        output_arity=len(traced.graph.outputs),
    )

    return CompiledProgram(
        graph=traced.graph,
        signature=io_signature,
        required_opcodes=sorted(_collect_opcodes(traced.graph)),
        graph_digest=compute_graph_digest(traced.graph),
        required_world_size=_compute_required_world_size(traced.graph),
        created_at=_utc_now_iso(),
        mplang_version=getattr(mplang, "__version__", None),
        name=name or traced.name,
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def pack(program: CompiledProgram, *, compress: bool = True) -> bytes:
    """Pack a `CompiledProgram` into portable bytes.

    Container format (recommended): a `tar.gz` archive containing a single
    human-readable JSON file `artifact.json`.

    This allows users to inspect artifacts via:
        `tar -xzf program.tar.gz && cat artifact.json`

    If `compress=False`, returns an uncompressed tar archive (still extractable
    via `tar -xf`).
    """
    # Stable, human-readable JSON payload (sorted keys make output deterministic).
    payload = json.dumps(
        serde.to_json(program),
        ensure_ascii=False,
        indent=2,
        sort_keys=True,
    ).encode("utf-8")

    sink = io.BytesIO()
    tar_mode: Literal["w:gz", "w"] = "w:gz" if compress else "w"
    with tarfile.open(fileobj=sink, mode=tar_mode) as archive:
        entry = tarfile.TarInfo(name="artifact.json")
        entry.size = len(payload)
        archive.addfile(entry, io.BytesIO(payload))

    return sink.getvalue()
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def pack_to_path(
    program: CompiledProgram, path: str | Path, *, compress: bool = True
) -> Path:
    """Pack *program* and write the artifact bytes to disk.

    Args:
        program: Program to pack.
        path: Output path (typically ends with `.tar.gz`).
        compress: Whether to gzip the tar archive.

    Returns:
        The resolved output path.
    """
    destination = Path(path).expanduser().resolve()
    destination.write_bytes(pack(program, compress=compress))
    return destination
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def unpack(
    data: bytes, *, max_artifact_json_bytes: int = DEFAULT_MAX_ARTIFACT_JSON_BYTES
) -> CompiledProgram:
    """Unpack bytes into a `CompiledProgram`.

    Supported container format: tar(.gz) containing `artifact.json`.

    Raises:
        ValueError: if the container is malformed, or `artifact.json` is
            missing, oversized, or not a regular file.
        TypeError: if the payload deserializes to a non-CompiledProgram.
    """
    try:
        with tarfile.open(fileobj=io.BytesIO(data), mode="r:*") as archive:
            entry = archive.getmember("artifact.json")

            if not entry.isfile():
                raise ValueError("artifact.json is not a regular file")
            if entry.size < 0:
                raise ValueError("Invalid artifact.json size in tar header")
            # Size check before extraction guards against decompression bombs.
            if entry.size > max_artifact_json_bytes:
                raise ValueError(
                    "artifact.json is too large to unpack safely: "
                    f"size={entry.size} bytes, limit={max_artifact_json_bytes} bytes"
                )

            stream = archive.extractfile(entry)
            if stream is None:
                raise ValueError("artifact.json not found in tar archive")
            payload = json.loads(stream.read().decode("utf-8"))
    except (tarfile.ReadError, KeyError, OSError, json.JSONDecodeError) as exc:
        # KeyError covers a missing member; the rest cover corrupt containers.
        raise ValueError(
            "Invalid artifact container: expected tar(.gz) with artifact.json"
        ) from exc

    program = serde.from_json(payload)
    if not isinstance(program, CompiledProgram):
        raise TypeError(
            f"Expected artifact.json to deserialize to CompiledProgram, got {type(program).__name__}"
        )

    _validate_program(program)
    return program
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def unpack_path(
    path: str | Path, *, max_artifact_json_bytes: int | None = None
) -> CompiledProgram:
    """Read an artifact from disk and unpack it.

    Args:
        path: Path to a packed artifact (tar or tar.gz).
        max_artifact_json_bytes: Optional safety limit forwarded to
            `unpack`. Defaults to `unpack`'s own limit when None.
            (Previously this file-based entry point always used the
            default, so callers could not tighten or relax the limit.)

    Returns:
        The validated `CompiledProgram`.
    """
    in_path = Path(path).expanduser().resolve()
    raw = in_path.read_bytes()
    if max_artifact_json_bytes is None:
        # Preserve historical behavior: rely on unpack's default limit.
        return unpack(raw)
    return unpack(raw, max_artifact_json_bytes=max_artifact_json_bytes)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def inspect_artifact(data: bytes) -> dict[str, Any]:
    """Return a JSON-friendly inspection report without executing."""
    program = unpack(data)
    graph = program.graph

    report: dict[str, Any] = {
        "schema_version": program.schema_version,
        "name": program.name,
        "mplang_version": program.mplang_version,
        "created_at": program.created_at,
        "graph_digest": program.graph_digest,
        "required_world_size": program.required_world_size,
        "signature": program.signature.to_json(),
        "required_opcodes": program.required_opcodes,
    }
    # Lightweight structural summary of the top-level graph.
    report["graph"] = {
        "inputs": len(graph.inputs),
        "ops": len(graph.operations),
        "outputs": len(graph.outputs),
        "region_count": sum(len(op.regions) for op in graph.operations),
    }
    return report
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
mplang/__init__.py,sha256=
|
|
1
|
+
mplang/__init__.py,sha256=PsUcGqKGQym3N_CU1Rav304YQwGVj8CLVx9S_6UTD9c,14519
|
|
2
2
|
mplang/cli.py,sha256=NW0GmxZeRC4rrYg8RVBlZiDjkihBXGcHmrll-JFOqWM,20317
|
|
3
3
|
mplang/cli_guide.md,sha256=hKC6AKgJn-lM_wZ0CzZIP2QUBxGPnT0Op_1YyeUhCfI,3581
|
|
4
4
|
mplang/logging_config.py,sha256=6Zm1Z_EBnzxAfeKr94xiLpWYDD8fa4ZEs2g_kMoH8eI,7579
|
|
@@ -26,7 +26,7 @@ mplang/backends/simp_driver/values.py,sha256=Lz1utNSIzH-dCzZAEjU6JRcxPsfKGfUJrYl
|
|
|
26
26
|
mplang/backends/simp_worker/__init__.py,sha256=gdrSY1-MDkupCoJ8xwwH7em7fgVWv3J4gBJ45uHdzgg,961
|
|
27
27
|
mplang/backends/simp_worker/http.py,sha256=90nJnNLSM9TUVRxhAFq9pyNk0LwmSmvgnv3Tb8KFWSE,12660
|
|
28
28
|
mplang/backends/simp_worker/mem.py,sha256=tMGiRppeca0TnY8WdqYQMQvsx5UVswCqdeOhiDlLQBs,3574
|
|
29
|
-
mplang/backends/simp_worker/ops.py,sha256=
|
|
29
|
+
mplang/backends/simp_worker/ops.py,sha256=ntxfkD4e6Il4w7FshK1ODcUCUPMlipt33pDY_x5iC0U,5661
|
|
30
30
|
mplang/backends/simp_worker/state.py,sha256=nIu0ybvdYqRqp0TkoSneUF2u31evDHucCRduVBaDals,1445
|
|
31
31
|
mplang/dialects/__init__.py,sha256=CYMmkeQVU0Znr9n3_5clZKb16u7acJ5jl5Zjbx4Tn1U,1478
|
|
32
32
|
mplang/dialects/bfv.py,sha256=m5YfobFCBqn0lg2zBM9RNs2AC7i4PUQH2qXjHLHwSy4,22332
|
|
@@ -42,13 +42,14 @@ mplang/dialects/table.py,sha256=i9ruyh91_tSWu9rsLomrBUfqRdbHiZMMMJzNKfMrAUc,1353
|
|
|
42
42
|
mplang/dialects/tee.py,sha256=BMFSbeK-Ck2jQP4qY9bZeNYTxEa7uEtUWLZLC4BPQxk,10111
|
|
43
43
|
mplang/dialects/tensor.py,sha256=7aAYKaMaFjJ8N25yPFnmVhUuUdKJYy-M-a4NsZGE7kY,39893
|
|
44
44
|
mplang/edsl/README.md,sha256=viflvdRojOa6Xk_UMRPqpuPGXcPGmdlv2-XR6LO7B58,7592
|
|
45
|
-
mplang/edsl/__init__.py,sha256=
|
|
45
|
+
mplang/edsl/__init__.py,sha256=WL4efo6uY1br781_8IaCkSi7yCUldcfJfbtFsn6Fdj4,2698
|
|
46
46
|
mplang/edsl/context.py,sha256=Ln8n3bDe8_ISe42TAGzUuz8fw57-tu1APuihMfAtW1Y,10075
|
|
47
47
|
mplang/edsl/graph.py,sha256=nCeCN7-bxfzyv40fmxcEXOaVUx14cOCaHfFb7A9OBnE,14968
|
|
48
48
|
mplang/edsl/jit.py,sha256=7eLZHoIuL5FZo9G5eF9nI4EeayLK-OvJ0NoH3VG5vLI,2393
|
|
49
49
|
mplang/edsl/object.py,sha256=dBl58q-ondjpjPNBh8zZvIEj6pJw2yEoz6TCaM_oleA,1906
|
|
50
50
|
mplang/edsl/primitive.py,sha256=gDrn4FH682DUOgTqcQ2-9aqDYJau9L8E1ElswyOmmdw,10859
|
|
51
51
|
mplang/edsl/printer.py,sha256=drmfRkdCNqbkRfSDmejxtO-rEAaM13QyHB3AbAmKVFk,4393
|
|
52
|
+
mplang/edsl/program.py,sha256=_JdEU2-nb79VlFLcgMJf4JS30TARBeUIzno0y0SFVsg,4467
|
|
52
53
|
mplang/edsl/registry.py,sha256=hudXZPUrUUueEwgksDKN0cnE3iiXucuTaDdDK8uSPmk,6822
|
|
53
54
|
mplang/edsl/serde.py,sha256=8K94laE8ObeGuBoF6m7g3A-xEe98EvqQ_6ZPPspddAY,11641
|
|
54
55
|
mplang/edsl/tracer.py,sha256=EWN3eMVRG-CZsamTyINOnhhEUKhgd4CYwFMWeRpjycU,23129
|
|
@@ -91,13 +92,15 @@ mplang/libs/mpc/vole/ldpc.py,sha256=gOmIbyOjkGE5lewyatl3p6FizNNH8LZ_1oOhp_-TOck,
|
|
|
91
92
|
mplang/libs/mpc/vole/silver.py,sha256=EIxhpFIVNBemgeIZzCu5Cz_4wysxRm9b1Xfu0xiweVQ,12218
|
|
92
93
|
mplang/runtime/__init__.py,sha256=VdUwJ3kDaI46FvGw7iMGwcsjt0HTGmmRmaBwj99xKIw,620
|
|
93
94
|
mplang/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07dHetY,1277
|
|
94
|
-
mplang/runtime/interpreter.py,sha256=
|
|
95
|
+
mplang/runtime/interpreter.py,sha256=wcCWXpAGylqdw_HecR4suJtwmozHLrK5x6Q8xM-Pn24,43593
|
|
95
96
|
mplang/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
|
|
96
97
|
mplang/runtime/value.py,sha256=EqlhSgxLTJi_FF3ppyKjMe4eHS6-ROx-zK1YesG1U4o,4311
|
|
98
|
+
mplang/tool/__init__.py,sha256=9K-T50W_vClUlyERcVx5xGZaeyv0Ts63SaQX6AZtjIs,1341
|
|
99
|
+
mplang/tool/program.py,sha256=W3H8bpPirnoJ4ZrmyPYuMCPadJis20o__n_1MKqCsWU,11058
|
|
97
100
|
mplang/utils/__init__.py,sha256=toubeyISiT6WDdITdfAvdY2iXVZU3PKVNWVeC9sYxuA,947
|
|
98
101
|
mplang/utils/func_utils.py,sha256=aZ-X43w8JKJgiF-IUMS0G7QqrNeoTM5ZPzRNd-tKxpw,5180
|
|
99
|
-
mplang_nightly-0.1.
|
|
100
|
-
mplang_nightly-0.1.
|
|
101
|
-
mplang_nightly-0.1.
|
|
102
|
-
mplang_nightly-0.1.
|
|
103
|
-
mplang_nightly-0.1.
|
|
102
|
+
mplang_nightly-0.1.dev279.dist-info/METADATA,sha256=Q4l1RV5WC5NfRV0heHxzHH8qVf3CkjOi6Ag1kbvsX38,16783
|
|
103
|
+
mplang_nightly-0.1.dev279.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
104
|
+
mplang_nightly-0.1.dev279.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
|
105
|
+
mplang_nightly-0.1.dev279.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
106
|
+
mplang_nightly-0.1.dev279.dist-info/RECORD,,
|
|
File without changes
|
{mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{mplang_nightly-0.1.dev277.dist-info → mplang_nightly-0.1.dev279.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|