mplang-nightly 0.1.dev278__py3-none-any.whl → 0.1.dev279__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +30 -6
- mplang/edsl/__init__.py +3 -0
- mplang/edsl/program.py +134 -0
- mplang/tool/__init__.py +46 -0
- mplang/tool/program.py +335 -0
- {mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/RECORD +10 -7
- {mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/licenses/LICENSE +0 -0
mplang/__init__.py
CHANGED
|
@@ -34,6 +34,7 @@ except Exception:
|
|
|
34
34
|
# Fallback for development/editable installs when package is not installed
|
|
35
35
|
__version__ = "0.0.0-dev"
|
|
36
36
|
|
|
37
|
+
import mplang.tool as tool
|
|
37
38
|
from mplang import dialects
|
|
38
39
|
from mplang.backends.simp_driver.ops import DRIVER_HANDLERS
|
|
39
40
|
from mplang.backends.simp_worker import SimpWorker
|
|
@@ -41,6 +42,8 @@ from mplang.backends.simp_worker.mem import LocalMesh
|
|
|
41
42
|
from mplang.backends.simp_worker.ops import WORKER_HANDLERS
|
|
42
43
|
from mplang.dialects.simp import make_driver, make_simulator
|
|
43
44
|
from mplang.edsl import (
|
|
45
|
+
CompiledProgram,
|
|
46
|
+
FlatIOSignature,
|
|
44
47
|
Graph,
|
|
45
48
|
GraphPrinter,
|
|
46
49
|
Object,
|
|
@@ -125,7 +128,7 @@ def _get_context(context: Interpreter | None) -> Interpreter:
|
|
|
125
128
|
|
|
126
129
|
|
|
127
130
|
def evaluate(
|
|
128
|
-
fn: Callable[..., Any] | TracedFunction,
|
|
131
|
+
fn: Callable[..., Any] | TracedFunction | CompiledProgram,
|
|
129
132
|
*args: Any,
|
|
130
133
|
context: Interpreter | None = None,
|
|
131
134
|
**kwargs: Any,
|
|
@@ -158,15 +161,33 @@ def evaluate(
|
|
|
158
161
|
return val.runtime_obj
|
|
159
162
|
return val
|
|
160
163
|
|
|
164
|
+
def eval_graph(graph: Graph, inputs: list[Any]) -> list[InterpObject]:
|
|
165
|
+
runtime_inputs = [unwrap_if_interp(v) for v in inputs]
|
|
166
|
+
raw_result = interp.evaluate_graph(graph, runtime_inputs)
|
|
167
|
+
return [
|
|
168
|
+
InterpObject(v, graph.outputs[i].type, interp)
|
|
169
|
+
for i, v in enumerate(raw_result)
|
|
170
|
+
]
|
|
171
|
+
|
|
161
172
|
with interp:
|
|
173
|
+
if isinstance(fn, CompiledProgram):
|
|
174
|
+
if kwargs:
|
|
175
|
+
raise TypeError(
|
|
176
|
+
"mp.evaluate(CompiledProgram, ...) does not accept keyword arguments; "
|
|
177
|
+
"pass flat positional inputs only."
|
|
178
|
+
)
|
|
179
|
+
if len(args) != fn.signature.input_arity:
|
|
180
|
+
raise ValueError(
|
|
181
|
+
"CompiledProgram requires flat positional inputs matching its signature; "
|
|
182
|
+
f"expected {fn.signature.input_arity}, got {len(args)}."
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
return eval_graph(fn.graph, list(args))
|
|
186
|
+
|
|
162
187
|
if isinstance(fn, TracedFunction):
|
|
163
188
|
inputs = fn.prepare_inputs(*args, **kwargs)
|
|
164
189
|
inputs = [unwrap_if_interp(v) for v in inputs]
|
|
165
|
-
|
|
166
|
-
wrapped = [
|
|
167
|
-
InterpObject(v, fn.graph.outputs[i].type, interp)
|
|
168
|
-
for i, v in enumerate(raw_result)
|
|
169
|
-
]
|
|
190
|
+
wrapped = eval_graph(fn.graph, inputs)
|
|
170
191
|
return fn.reconstruct_outputs(wrapped)
|
|
171
192
|
|
|
172
193
|
return fn(*args, **kwargs)
|
|
@@ -417,6 +438,9 @@ __all__ = [ # noqa: RUF022
|
|
|
417
438
|
"WORKER_HANDLERS",
|
|
418
439
|
"make_driver",
|
|
419
440
|
"make_simulator",
|
|
441
|
+
"tool",
|
|
442
|
+
"CompiledProgram",
|
|
443
|
+
"FlatIOSignature",
|
|
420
444
|
# Dialects
|
|
421
445
|
"dialects",
|
|
422
446
|
"register_default_context_factory",
|
mplang/edsl/__init__.py
CHANGED
|
@@ -53,6 +53,7 @@ from .jit import jit
|
|
|
53
53
|
from .object import Object
|
|
54
54
|
from .primitive import Primitive, primitive
|
|
55
55
|
from .printer import GraphPrinter, format_graph
|
|
56
|
+
from .program import CompiledProgram, FlatIOSignature
|
|
56
57
|
from .tracer import TracedFunction, TraceObject, Tracer, trace
|
|
57
58
|
from .typing import MPType, ScalarType, SSType, TableType, TensorType, VectorType
|
|
58
59
|
|
|
@@ -65,7 +66,9 @@ TensorObject = Object[TensorType]
|
|
|
65
66
|
VectorObject = Object[VectorType]
|
|
66
67
|
|
|
67
68
|
__all__ = [
|
|
69
|
+
"CompiledProgram",
|
|
68
70
|
"Context",
|
|
71
|
+
"FlatIOSignature",
|
|
69
72
|
"Graph",
|
|
70
73
|
"GraphPrinter",
|
|
71
74
|
"MPObject",
|
mplang/edsl/program.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import json
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import Any, ClassVar
|
|
21
|
+
|
|
22
|
+
from mplang.edsl import serde
|
|
23
|
+
from mplang.edsl.graph import Graph
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class FlatIOSignature:
    """Portable I/O signature for source-free execution.

    Describes only flat positional inputs/outputs, i.e. the values that map
    one-to-one onto `graph.inputs` / `graph.outputs`. Instances are immutable.
    """

    # Format tag written into (and required back from) the JSON form.
    kind: ClassVar[str] = "flat_list_v0"
    input_arity: int
    output_arity: int

    def to_json(self) -> dict[str, Any]:
        """Encode as a plain dict tagged with `kind`."""
        encoded: dict[str, Any] = {"kind": self.kind}
        encoded["input_arity"] = self.input_arity
        encoded["output_arity"] = self.output_arity
        return encoded

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> FlatIOSignature:
        """Decode a dict produced by `to_json`.

        Raises:
            ValueError: if the `kind` tag is absent or not `cls.kind`.
            KeyError: if an arity field is missing.
        """
        found_kind = data.get("kind")
        if found_kind != cls.kind:
            raise ValueError(f"Unsupported signature kind: {found_kind}")
        n_in = int(data["input_arity"])
        n_out = int(data["output_arity"])
        return cls(input_arity=n_in, output_arity=n_out)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@serde.register_class
@dataclass
class CompiledProgram:
    """Executable program that no longer needs the user's Python source.

    This is only the *logical model*; packaging into a file/zip/tar container
    is the tool layer's job.

    Current constraints:
    - the signature is a flat positional list (`FlatIOSignature`);
    - no closure captures;
    - no constant outputs (out_imms) until a future signature can carry them.
    """

    # Tag used by `serde` to route JSON payloads back to this class.
    _serde_kind: ClassVar[str] = "mplang.CompiledProgram"

    graph: Graph
    signature: FlatIOSignature
    required_opcodes: list[str]
    graph_digest: str
    required_world_size: int | None = None
    created_at: str | None = None
    mplang_version: str | None = None
    schema_version: int = 1
    name: str | None = None

    def to_json(self) -> dict[str, Any]:
        """Encode as a JSON-compatible dict; the graph is encoded via `serde`."""
        doc: dict[str, Any] = {}
        doc["schema_version"] = self.schema_version
        doc["name"] = self.name
        doc["graph"] = serde.to_json(self.graph)
        doc["signature"] = self.signature.to_json()
        doc["required_opcodes"] = list(self.required_opcodes)
        doc["graph_digest"] = self.graph_digest
        doc["required_world_size"] = self.required_world_size
        doc["created_at"] = self.created_at
        doc["mplang_version"] = self.mplang_version
        return doc

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> CompiledProgram:
        """Decode a dict produced by `to_json`.

        Raises:
            KeyError: if `schema_version` (or another required field) is missing.
            ValueError: if `schema_version` is not 1, or the signature is invalid.
            TypeError: if the `graph` payload does not decode to a `Graph`.
        """
        if "schema_version" not in data:
            raise KeyError("Missing required field: schema_version")
        version = int(data["schema_version"])
        if version != 1:
            raise ValueError(
                f"Unsupported CompiledProgram schema_version: {version}"
            )

        decoded_graph = serde.from_json(data["graph"])
        if not isinstance(decoded_graph, Graph):
            raise TypeError(
                f"Expected graph to deserialize to Graph, got {type(decoded_graph).__name__}"
            )

        sig = FlatIOSignature.from_json(data["signature"])

        # `required_world_size` is optional; coerce to int only when present.
        world = data.get("required_world_size")
        world = None if world is None else int(world)

        return cls(
            graph=decoded_graph,
            signature=sig,
            required_opcodes=list(data.get("required_opcodes", [])),
            graph_digest=str(data["graph_digest"]),
            required_world_size=world,
            created_at=data.get("created_at"),
            mplang_version=data.get("mplang_version"),
            schema_version=version,
            name=data.get("name"),
        )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def compute_graph_digest(graph: Graph) -> str:
    """Return a deterministic SHA-256 hex digest of *graph*.

    The graph's `serde` JSON form is re-encoded with sorted keys and compact
    separators so the digest does not depend on dict ordering. `serde.dumps()`
    is deliberately not used because it does not sort keys.
    """
    payload = serde.to_json(graph)
    text = json.dumps(payload, separators=(",", ":"), sort_keys=True)
    hasher = hashlib.sha256(text.encode("utf-8"))
    return hasher.hexdigest()
|
mplang/tool/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tool-layer APIs for MPLang.
|
|
16
|
+
|
|
17
|
+
This package contains utilities that are intentionally *not* part of the core
|
|
18
|
+
EDSL execution surface. In particular, compile/execute decoupling lives here:
|
|
19
|
+
- build a portable `CompiledProgram`
|
|
20
|
+
- pack/unpack to a container format
|
|
21
|
+
|
|
22
|
+
These helpers must not depend on user source code being available at execution.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from mplang.edsl.program import CompiledProgram, FlatIOSignature
|
|
28
|
+
from mplang.tool.program import (
|
|
29
|
+
compile_program,
|
|
30
|
+
inspect_artifact,
|
|
31
|
+
pack,
|
|
32
|
+
pack_to_path,
|
|
33
|
+
unpack,
|
|
34
|
+
unpack_path,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
"CompiledProgram",
|
|
39
|
+
"FlatIOSignature",
|
|
40
|
+
"compile_program",
|
|
41
|
+
"inspect_artifact",
|
|
42
|
+
"pack",
|
|
43
|
+
"pack_to_path",
|
|
44
|
+
"unpack",
|
|
45
|
+
"unpack_path",
|
|
46
|
+
]
|
mplang/tool/program.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
# Copyright 2026 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import io
|
|
18
|
+
import json
|
|
19
|
+
import tarfile
|
|
20
|
+
from datetime import UTC, datetime
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any, Literal
|
|
23
|
+
|
|
24
|
+
import mplang
|
|
25
|
+
from mplang.edsl import serde
|
|
26
|
+
from mplang.edsl.graph import Graph
|
|
27
|
+
from mplang.edsl.program import (
|
|
28
|
+
CompiledProgram,
|
|
29
|
+
FlatIOSignature,
|
|
30
|
+
compute_graph_digest,
|
|
31
|
+
)
|
|
32
|
+
from mplang.edsl.tracer import TracedFunction, trace
|
|
33
|
+
|
|
34
|
+
# Default cap on the size of `artifact.json` accepted by `unpack`; it is
# compared against the tar header's recorded member size before reading.
DEFAULT_MAX_ARTIFACT_JSON_BYTES = 512 * 1024 * 1024  # 512 MiB
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _utc_now_iso() -> str:
|
|
38
|
+
return datetime.now(UTC).isoformat()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _iter_graphs(root: Graph) -> list[Graph]:
    """Return *root* followed by every graph reachable through op regions.

    Uses an explicit work stack instead of recursion, so deep nesting cannot
    hit Python's recursion limit. Graphs are deduplicated by object identity,
    so a region cycle cannot cause an infinite loop.
    """
    seen_ids: set[int] = set()
    ordered: list[Graph] = []
    pending: list[Graph] = [root]
    while pending:
        current = pending.pop()
        key = id(current)
        if key in seen_ids:
            continue
        seen_ids.add(key)
        ordered.append(current)
        for operation in current.operations:
            # Ops without regions may carry an empty (or falsy) value.
            if operation.regions:
                pending.extend(operation.regions)
    return ordered
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _collect_opcodes(graph: Graph) -> set[str]:
    """Distinct opcodes used anywhere in *graph*, nested region graphs included."""
    return {
        operation.opcode
        for sub_graph in _iter_graphs(graph)
        for operation in sub_graph.operations
    }
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _collect_parties(graph: Graph) -> set[int]:
    """Collect every party id referenced by a ``parties`` op attribute.

    Walks *graph* and all nested region graphs. Ops whose ``parties``
    attribute is absent or ``None`` are skipped; each listed id is converted
    with ``int()``.

    Raises:
        TypeError: if a ``parties`` attribute is not a list, tuple, set or
            frozenset.
        ValueError: if any party id is negative.
    """
    parties: set[int] = set()
    for g in _iter_graphs(graph):
        for op in g.operations:
            raw = op.attrs.get("parties")
            if raw is None:
                continue

            # frozenset is accepted alongside set: both are natural containers
            # for an unordered collection of party ids.
            if not isinstance(raw, (list, tuple, set, frozenset)):
                raise TypeError(
                    "Invalid 'parties' attribute: expected list/tuple/set/frozenset of ints, "
                    f"got {type(raw).__name__}"
                )
            for p in raw:
                p_int = int(p)
                if p_int < 0:
                    raise ValueError(
                        f"Invalid 'parties' attribute: negative party id {p_int}"
                    )
                parties.add(p_int)
    return parties
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _compute_required_world_size(graph: Graph) -> int:
    """Smallest world size covering every referenced party: max id + 1, or 0 if none."""
    # default=-1 makes an empty party set yield a world size of 0.
    return max(_collect_parties(graph), default=-1) + 1
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _validate_traced_for_artifact(traced: TracedFunction) -> None:
    """Reject traced functions that cannot be represented as a `CompiledProgram`.

    The checks run in this order:
    1. there are no closure captures (`traced.captured`);
    2. there are no constant outputs (`traced.out_imms`);
    3. `graph.inputs` maps 1:1 onto the runtime-provided inputs (`in_var_pos`).

    Raises:
        ValueError: describing the first restriction that is violated.
    """
    if traced.captured:
        raise ValueError(
            "CompiledProgram does not support closure captures; "
            "please refactor to pass all values explicitly."
        )

    if traced.out_imms:
        raise ValueError(
            "CompiledProgram does not support constant outputs (out_imms); "
            "return only traced values (graph outputs)."
        )

    # Only flat positional I/O is supported: (args, kwargs) pytree metadata is
    # not kept, so every graph input must be a runtime-provided value.
    graph_input_count = len(traced.graph.inputs)
    runtime_input_count = len(traced.in_var_pos)
    if graph_input_count != runtime_input_count:
        raise ValueError(
            "CompiledProgram requires flat positional inputs that map 1:1 to graph.inputs; "
            f"got graph.inputs={graph_input_count} but in_var_pos={runtime_input_count}."
        )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _validate_program(program: CompiledProgram) -> None:
    """Cross-check a `CompiledProgram`'s declared metadata against its graph.

    Checks are performed in this order:
    - signature kind and input/output arities;
    - `required_opcodes` against the opcodes actually present;
    - `graph_digest`, skipped when the stored digest is empty;
    - `required_world_size`, skipped when it is ``None``;
    - that the whole program serializes through `serde`.

    Raises:
        ValueError: on the first metadata mismatch.
    """
    sig = program.signature
    g = program.graph

    if sig.kind != FlatIOSignature.kind:
        raise ValueError(f"Unsupported signature kind: {sig.kind}")

    n_inputs = len(g.inputs)
    if sig.input_arity != n_inputs:
        raise ValueError(
            "Signature input_arity does not match graph.inputs: "
            f"input_arity={sig.input_arity}, inputs={n_inputs}"
        )
    n_outputs = len(g.outputs)
    if sig.output_arity != n_outputs:
        raise ValueError(
            "Signature output_arity does not match graph.outputs: "
            f"output_arity={sig.output_arity}, outputs={n_outputs}"
        )

    present_opcodes = sorted(_collect_opcodes(g))
    if sorted(program.required_opcodes) != present_opcodes:
        raise ValueError(
            "required_opcodes mismatch with graph content; "
            "artifact may be corrupted or constructed inconsistently."
        )

    recomputed = compute_graph_digest(g)
    stored = program.graph_digest
    if stored and stored != recomputed:
        raise ValueError(
            "Graph digest mismatch: "
            f"expected={stored}, actual={recomputed}"
        )

    derived_world = _compute_required_world_size(g)
    declared_world = program.required_world_size
    if declared_world is not None and declared_world != derived_world:
        raise ValueError(
            "required_world_size mismatch with graph content; "
            f"expected={derived_world}, got={declared_world}."
        )

    # Fail fast if any attribute cannot be serialized by serde.
    serde.to_json(program)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def compile_program(
    fn_or_traced: Any,
    *args: Any,
    context: Any | None = None,
    name: str | None = None,
    **kwargs: Any,
) -> CompiledProgram:
    """Compile (trace) into a source-free executable `CompiledProgram`.

    If *fn_or_traced* is already a `TracedFunction` it is used as-is, and
    *args*, *kwargs* and *context* are ignored. Otherwise the callable is
    traced with *args*/*kwargs*, inside *context* when one is given.

    Restrictions (enforced):
    - no closure captures
    - no constant outputs (`out_imms` must be empty)
    - signature is flat list (positional) I/O

    Note: `in_imms` (compile-time constants) are allowed: they are baked into the graph.

    Args:
        fn_or_traced: A callable to trace, or an existing `TracedFunction`.
        *args: Positional tracing arguments.
        context: Optional context manager to enter while tracing.
        name: Program name; falls back to the traced function's name when falsy.
        **kwargs: Keyword tracing arguments.

    Raises:
        ValueError: if the traced function violates a restriction above.
    """
    if isinstance(fn_or_traced, TracedFunction):
        traced: TracedFunction = fn_or_traced
    elif context is None:
        traced = trace(fn_or_traced, *args, **kwargs)
    else:
        with context:
            traced = trace(fn_or_traced, *args, **kwargs)

    _validate_traced_for_artifact(traced)

    graph = traced.graph
    return CompiledProgram(
        graph=graph,
        signature=FlatIOSignature(
            input_arity=len(graph.inputs),
            output_arity=len(graph.outputs),
        ),
        required_opcodes=sorted(_collect_opcodes(graph)),
        graph_digest=compute_graph_digest(graph),
        required_world_size=_compute_required_world_size(graph),
        created_at=_utc_now_iso(),
        mplang_version=getattr(mplang, "__version__", None),
        name=name or traced.name,
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def pack(program: CompiledProgram, *, compress: bool = True) -> bytes:
    """Pack a `CompiledProgram` into portable bytes.

    Container format (recommended): a `tar.gz` archive containing a single
    human-readable JSON file `artifact.json` (UTF-8, indented, sorted keys).

    This allows users to inspect artifacts via:
        `tar -xzf program.tar.gz && cat artifact.json`

    If `compress=False`, returns an uncompressed tar archive (still extractable
    via `tar -xf`).
    """
    document = serde.to_json(program)
    encoded = json.dumps(
        document, ensure_ascii=False, indent=2, sort_keys=True
    ).encode("utf-8")

    entry = tarfile.TarInfo(name="artifact.json")
    entry.size = len(encoded)

    sink = io.BytesIO()
    tar_mode: Literal["w:gz", "w"] = "w:gz" if compress else "w"
    with tarfile.open(fileobj=sink, mode=tar_mode) as archive:
        archive.addfile(entry, io.BytesIO(encoded))
    return sink.getvalue()
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def pack_to_path(
    program: CompiledProgram, path: str | Path, *, compress: bool = True
) -> Path:
    """Pack *program* with `pack` and write the bytes to *path*.

    Args:
        program: Program to pack.
        path: Output path (typically ends with `.tar.gz`); ``~`` is expanded.
        compress: Whether to gzip the tar archive.

    Returns:
        The absolute, resolved output path.
    """
    target = Path(path).expanduser().resolve()
    blob = pack(program, compress=compress)
    target.write_bytes(blob)
    return target
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def unpack(
    data: bytes, *, max_artifact_json_bytes: int = DEFAULT_MAX_ARTIFACT_JSON_BYTES
) -> CompiledProgram:
    """Unpack bytes into a `CompiledProgram`.

    Supported container format: a tar archive, opened with ``"r:*"`` so the
    compression is auto-detected, that contains a regular file `artifact.json`.

    Args:
        data: Archive bytes, e.g. as produced by `pack`.
        max_artifact_json_bytes: Maximum accepted size of `artifact.json` as
            recorded in its tar header. Larger members are rejected before
            their payload is read.

    Returns:
        The deserialized program, after `_validate_program` has accepted it.

    Raises:
        ValueError: if the container is unreadable, `artifact.json` exceeds
            the size limit, or `_validate_program` finds inconsistent metadata.
            "Unreadable" covers: not a tar archive; a truncated or corrupt
            compressed stream; a missing or non-regular `artifact.json`; and
            invalid UTF-8 or JSON.
        TypeError: if `artifact.json` does not deserialize to a
            `CompiledProgram`.
    """
    # Imported locally: only `zlib.error` is needed, to classify corrupt
    # deflate data surfaced while reading a gzip-compressed member.
    import zlib

    try:
        with tarfile.open(fileobj=io.BytesIO(data), mode="r:*") as tf:
            member = tf.getmember("artifact.json")

            if not member.isfile():
                raise ValueError("artifact.json is not a regular file")

            if member.size < 0:
                raise ValueError("Invalid artifact.json size in tar header")

            # Enforce the limit using the header size, before reading the payload.
            if member.size > max_artifact_json_bytes:
                raise ValueError(
                    "artifact.json is too large to unpack safely: "
                    f"size={member.size} bytes, limit={max_artifact_json_bytes} bytes"
                )

            f = tf.extractfile(member)
            if f is None:
                raise ValueError("artifact.json not found in tar archive")
            payload = json.loads(f.read().decode("utf-8"))
    except (
        tarfile.TarError,  # ReadError, CompressionError, ...
        KeyError,  # getmember(): no artifact.json member
        OSError,  # includes gzip.BadGzipFile
        EOFError,  # truncated compressed stream
        zlib.error,  # corrupt deflate data inside a gzip stream
        UnicodeDecodeError,  # artifact.json is not valid UTF-8
        json.JSONDecodeError,
    ) as exc:
        raise ValueError(
            "Invalid artifact container: expected tar(.gz) with artifact.json"
        ) from exc

    program = serde.from_json(payload)
    if not isinstance(program, CompiledProgram):
        raise TypeError(
            f"Expected artifact.json to deserialize to CompiledProgram, got {type(program).__name__}"
        )

    _validate_program(program)
    return program
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def unpack_path(
    path: str | Path,
    *,
    max_artifact_json_bytes: int = DEFAULT_MAX_ARTIFACT_JSON_BYTES,
) -> CompiledProgram:
    """Read an artifact from disk and unpack it.

    Args:
        path: Artifact location; ``~`` is expanded.
        max_artifact_json_bytes: Forwarded to `unpack`; caps the size of
            `artifact.json` recorded in the tar header.

    Returns:
        The unpacked and validated `CompiledProgram`.
    """
    in_path = Path(path).expanduser().resolve()
    return unpack(
        in_path.read_bytes(), max_artifact_json_bytes=max_artifact_json_bytes
    )
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def inspect_artifact(
    data: bytes, *, max_artifact_json_bytes: int = DEFAULT_MAX_ARTIFACT_JSON_BYTES
) -> dict[str, Any]:
    """Return a JSON-friendly inspection report without executing.

    The artifact is unpacked and validated by `unpack` first. A corrupt or
    inconsistent artifact therefore raises instead of producing a report.

    Args:
        data: Archive bytes, e.g. as produced by `pack`.
        max_artifact_json_bytes: Forwarded to `unpack`.

    Note:
        ``graph.ops`` and ``graph.region_count`` describe the top-level graph
        only. Operations inside nested regions are not counted.
    """
    program = unpack(data, max_artifact_json_bytes=max_artifact_json_bytes)
    top = program.graph
    return {
        "schema_version": program.schema_version,
        "name": program.name,
        "mplang_version": program.mplang_version,
        "created_at": program.created_at,
        "graph_digest": program.graph_digest,
        "required_world_size": program.required_world_size,
        "signature": program.signature.to_json(),
        # Copy so callers mutating the report cannot alter the program.
        "required_opcodes": list(program.required_opcodes),
        "graph": {
            "inputs": len(top.inputs),
            "ops": len(top.operations),
            "outputs": len(top.outputs),
            # `op.regions` may be falsy (as tolerated by _iter_graphs).
            "region_count": sum(len(op.regions or ()) for op in top.operations),
        },
    }
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
mplang/__init__.py,sha256=
|
|
1
|
+
mplang/__init__.py,sha256=PsUcGqKGQym3N_CU1Rav304YQwGVj8CLVx9S_6UTD9c,14519
|
|
2
2
|
mplang/cli.py,sha256=NW0GmxZeRC4rrYg8RVBlZiDjkihBXGcHmrll-JFOqWM,20317
|
|
3
3
|
mplang/cli_guide.md,sha256=hKC6AKgJn-lM_wZ0CzZIP2QUBxGPnT0Op_1YyeUhCfI,3581
|
|
4
4
|
mplang/logging_config.py,sha256=6Zm1Z_EBnzxAfeKr94xiLpWYDD8fa4ZEs2g_kMoH8eI,7579
|
|
@@ -42,13 +42,14 @@ mplang/dialects/table.py,sha256=i9ruyh91_tSWu9rsLomrBUfqRdbHiZMMMJzNKfMrAUc,1353
|
|
|
42
42
|
mplang/dialects/tee.py,sha256=BMFSbeK-Ck2jQP4qY9bZeNYTxEa7uEtUWLZLC4BPQxk,10111
|
|
43
43
|
mplang/dialects/tensor.py,sha256=7aAYKaMaFjJ8N25yPFnmVhUuUdKJYy-M-a4NsZGE7kY,39893
|
|
44
44
|
mplang/edsl/README.md,sha256=viflvdRojOa6Xk_UMRPqpuPGXcPGmdlv2-XR6LO7B58,7592
|
|
45
|
-
mplang/edsl/__init__.py,sha256=
|
|
45
|
+
mplang/edsl/__init__.py,sha256=WL4efo6uY1br781_8IaCkSi7yCUldcfJfbtFsn6Fdj4,2698
|
|
46
46
|
mplang/edsl/context.py,sha256=Ln8n3bDe8_ISe42TAGzUuz8fw57-tu1APuihMfAtW1Y,10075
|
|
47
47
|
mplang/edsl/graph.py,sha256=nCeCN7-bxfzyv40fmxcEXOaVUx14cOCaHfFb7A9OBnE,14968
|
|
48
48
|
mplang/edsl/jit.py,sha256=7eLZHoIuL5FZo9G5eF9nI4EeayLK-OvJ0NoH3VG5vLI,2393
|
|
49
49
|
mplang/edsl/object.py,sha256=dBl58q-ondjpjPNBh8zZvIEj6pJw2yEoz6TCaM_oleA,1906
|
|
50
50
|
mplang/edsl/primitive.py,sha256=gDrn4FH682DUOgTqcQ2-9aqDYJau9L8E1ElswyOmmdw,10859
|
|
51
51
|
mplang/edsl/printer.py,sha256=drmfRkdCNqbkRfSDmejxtO-rEAaM13QyHB3AbAmKVFk,4393
|
|
52
|
+
mplang/edsl/program.py,sha256=_JdEU2-nb79VlFLcgMJf4JS30TARBeUIzno0y0SFVsg,4467
|
|
52
53
|
mplang/edsl/registry.py,sha256=hudXZPUrUUueEwgksDKN0cnE3iiXucuTaDdDK8uSPmk,6822
|
|
53
54
|
mplang/edsl/serde.py,sha256=8K94laE8ObeGuBoF6m7g3A-xEe98EvqQ_6ZPPspddAY,11641
|
|
54
55
|
mplang/edsl/tracer.py,sha256=EWN3eMVRG-CZsamTyINOnhhEUKhgd4CYwFMWeRpjycU,23129
|
|
@@ -94,10 +95,12 @@ mplang/runtime/dialect_state.py,sha256=HxO1i4kSOujS2tQzAF9-WmI3nChSaGgupf2_07dHe
|
|
|
94
95
|
mplang/runtime/interpreter.py,sha256=wcCWXpAGylqdw_HecR4suJtwmozHLrK5x6Q8xM-Pn24,43593
|
|
95
96
|
mplang/runtime/object_store.py,sha256=yT6jtKG2GUEJVmpq3gnQ8mCMvUFYzgBciC5A-J5KRdk,5998
|
|
96
97
|
mplang/runtime/value.py,sha256=EqlhSgxLTJi_FF3ppyKjMe4eHS6-ROx-zK1YesG1U4o,4311
|
|
98
|
+
mplang/tool/__init__.py,sha256=9K-T50W_vClUlyERcVx5xGZaeyv0Ts63SaQX6AZtjIs,1341
|
|
99
|
+
mplang/tool/program.py,sha256=W3H8bpPirnoJ4ZrmyPYuMCPadJis20o__n_1MKqCsWU,11058
|
|
97
100
|
mplang/utils/__init__.py,sha256=toubeyISiT6WDdITdfAvdY2iXVZU3PKVNWVeC9sYxuA,947
|
|
98
101
|
mplang/utils/func_utils.py,sha256=aZ-X43w8JKJgiF-IUMS0G7QqrNeoTM5ZPzRNd-tKxpw,5180
|
|
99
|
-
mplang_nightly-0.1.
|
|
100
|
-
mplang_nightly-0.1.
|
|
101
|
-
mplang_nightly-0.1.
|
|
102
|
-
mplang_nightly-0.1.
|
|
103
|
-
mplang_nightly-0.1.
|
|
102
|
+
mplang_nightly-0.1.dev279.dist-info/METADATA,sha256=Q4l1RV5WC5NfRV0heHxzHH8qVf3CkjOi6Ag1kbvsX38,16783
|
|
103
|
+
mplang_nightly-0.1.dev279.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
104
|
+
mplang_nightly-0.1.dev279.dist-info/entry_points.txt,sha256=mG1oJT-GAjQR834a62_QIWb7litzWPPyVnwFqm-rWuY,55
|
|
105
|
+
mplang_nightly-0.1.dev279.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
106
|
+
mplang_nightly-0.1.dev279.dist-info/RECORD,,
|
|
File without changes
|
{mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/entry_points.txt
RENAMED
|
File without changes
|
{mplang_nightly-0.1.dev278.dist-info → mplang_nightly-0.1.dev279.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|