mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/README.md
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
# MPLang EDSL - Experimental Architecture
|
|
2
|
+
|
|
3
|
+
**⚠️ Status**: Experimental / Work in Progress
|
|
4
|
+
|
|
5
|
+
This directory contains the next-generation EDSL (Embedded Domain-Specific Language) architecture for MPLang.
|
|
6
|
+
|
|
7
|
+
## Why a New Architecture?
|
|
8
|
+
|
|
9
|
+
The current `mplang.v1.core` architecture (Expr Tree + `@primitive`) has served us well, but it has hit several limitations:
|
|
10
|
+
|
|
11
|
+
1. **Expr Tree** is hard to optimize (visitor pattern, nested structure)
|
|
12
|
+
2. **@primitive decorators** hide complexity and limit flexibility
|
|
13
|
+
3. **Type system** is split between `mptype.MPType` and `typing.BaseType`
|
|
14
|
+
4. **No clear separation** between IR, frontend, and backend
|
|
15
|
+
|
|
16
|
+
Modern EDSLs (torch.fx, JAX) use **Operation List + SSA** for better analyzability and optimization.
|
|
17
|
+
|
|
18
|
+
## Goals
|
|
19
|
+
|
|
20
|
+
### 1. Modern IR (Operation List)
|
|
21
|
+
|
|
22
|
+
**From** (Expr Tree):
|
|
23
|
+
```python
|
|
24
|
+
CallExpr(
|
|
25
|
+
func=add,
|
|
26
|
+
args=[VariableExpr("x"), VariableExpr("y")]
|
|
27
|
+
)
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
**To** (Operation List):
|
|
31
|
+
```python
|
|
32
|
+
%0 = input "x"
|
|
33
|
+
%1 = input "y"
|
|
34
|
+
%2 = add %0, %1
|
|
35
|
+
return %2
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
### 2. Unified Type System
|
|
39
|
+
|
|
40
|
+
**Single source of truth**: `mplang.v2.edsl.typing.MPType`
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
from mplang.v2.edsl.typing import Tensor, Vector, MPType, f32
|
|
44
|
+
|
|
45
|
+
# All types use BaseType
|
|
46
|
+
plaintext: MPType = Tensor[f32, (4096,)]
|
|
47
|
+
ciphertext: MPType = Vector[f32, 4096]
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
### 3. Explicit Tracing
|
|
51
|
+
|
|
52
|
+
**Clean context management**:
|
|
53
|
+
```python
|
|
54
|
+
from mplang.v2.edsl import Tracer
|
|
55
|
+
|
|
56
|
+
tracer = Tracer()
|
|
57
|
+
with tracer: # Context manager protocol
|
|
58
|
+
result = my_function(x, y)
|
|
59
|
+
graph = tracer.finalize(result)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### 4. Extensibility
|
|
63
|
+
|
|
64
|
+
Easy to add new backends:
|
|
65
|
+
- FHE (Fully Homomorphic Encryption)
|
|
66
|
+
- TEE (Trusted Execution Environment)
|
|
67
|
+
- Custom accelerators
|
|
68
|
+
|
|
69
|
+
### 5. Layered API Architecture
|
|
70
|
+
|
|
71
|
+
The EDSL provides two distinct API layers:
|
|
72
|
+
|
|
73
|
+
1. **Low-Level API (Graph Manipulation)**:
|
|
74
|
+
- Direct manipulation of the `Graph` IR.
|
|
75
|
+
- Generic `add_op` method (pure graph API, no op semantics).
|
|
76
|
+
- Analogous to MLIR's generic operation construction.
|
|
77
|
+
- Used by compiler passes and backend implementations.
|
|
78
|
+
|
|
79
|
+
2. **High-Level API (Tracing)**:
|
|
80
|
+
- Uses `Tracer` + `Primitive` (with `abstract_eval`).
|
|
81
|
+
- Pythonic interface (functions, operators).
|
|
82
|
+
- Automatic type inference and graph construction.
|
|
83
|
+
- The primary interface for users.
|
|
84
|
+
|
|
85
|
+
## Directory Structure
|
|
86
|
+
|
|
87
|
+
```
|
|
88
|
+
mplang/v2/edsl/
|
|
89
|
+
├── __init__.py # Public API
|
|
90
|
+
├── README.md # This file
|
|
91
|
+
│
|
|
92
|
+
├── design/ # Design documents
|
|
93
|
+
│ ├── architecture.md # Complete architecture overview
|
|
94
|
+
│ ├── type_system.md # Type system design
|
|
95
|
+
│ └── migration.md # Migration from mplang.core
|
|
96
|
+
│
|
|
97
|
+
├── typing.py # ✅ Unified type system
|
|
98
|
+
├── graph.py # ✅ IR: Operation List + SSA
|
|
99
|
+
├── primitive.py # ✅ Primitive abstraction
|
|
100
|
+
├── object.py # ✅ TraceObject/InterpObject
|
|
101
|
+
├── context.py # ✅ Context management
|
|
102
|
+
├── tracer.py # ✅ Explicit tracer
|
|
103
|
+
├── printer.py # ✅ GraphPrinter / format_graph
├── registry.py
├── serde.py
│ # (the Interpreter now lives in mplang/v2/runtime/interpreter.py)
|
|
104
|
+
└── jit.py # ✅ @jit decorator
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
## Implementation Status
|
|
108
|
+
|
|
109
|
+
### ✅ Completed (Phase 1-4)
|
|
110
|
+
|
|
111
|
+
- [x] Type system (`typing.py`) - 816 lines
|
|
112
|
+
- [x] Graph IR (`graph.py`) - 463 lines
|
|
113
|
+
- [x] Primitive abstraction (`primitive.py`) - 284 lines
|
|
114
|
+
- [x] Object hierarchy (`object.py`) - 53 lines
|
|
115
|
+
- [x] Context system (`context.py`) - 311 lines
|
|
116
|
+
- [x] Tracer (`tracer.py`) - 614 lines
|
|
117
|
+
- [x] Interpreter (now `mplang/v2/runtime/interpreter.py`) - 871 lines
|
|
118
|
+
- [x] JIT decorator (`jit.py`) - 62 lines
|
|
119
|
+
- [x] Design documents
|
|
120
|
+
- [x] **153 tests passing** (140 edsl + 13 core2)
|
|
121
|
+
|
|
122
|
+
### 🚧 In Progress
|
|
123
|
+
- [ ] Integration with existing ops/kernels
|
|
124
|
+
- [ ] Migration utilities
|
|
125
|
+
- [ ] Performance benchmarks
|
|
126
|
+
|
|
127
|
+
### ❌ Dropped / Deprecated
|
|
128
|
+
- [x] Builder API (`builder.py`) - Integrated into `Tracer`
|
|
129
|
+
|
|
130
|
+
### 📋 Planned
|
|
131
|
+
- [ ] Advanced optimizations
|
|
132
|
+
- [ ] More backends (TEE, MPC)
|
|
133
|
+
|
|
134
|
+
## Quick Start
|
|
135
|
+
|
|
136
|
+
### Using the New Type System
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
from mplang.v2.edsl.typing import Tensor, Vector, CustomType, f32
|
|
140
|
+
|
|
141
|
+
# Define types
|
|
142
|
+
PlaintextVec = Tensor[f32, (4096,)]
|
|
143
|
+
CiphertextVec = Vector[f32, 4096]
|
|
144
|
+
EncryptionKey = CustomType("EncryptionKey")
|
|
145
|
+
|
|
146
|
+
# Type annotations
|
|
147
|
+
def encrypt(data: PlaintextVec, key: EncryptionKey) -> CiphertextVec:
|
|
148
|
+
...
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
### Using the Tracer (Graph Construction)
|
|
152
|
+
|
|
153
|
+
```python
|
|
154
|
+
from mplang.v2.edsl import Tracer
|
|
155
|
+
from mplang.v2.dialects.simp import pcall_static
|
|
156
|
+
|
|
157
|
+
def my_program(x, y):
|
|
158
|
+
# This function is traced into a Graph
|
|
159
|
+
return pcall_static((0, 1), lambda a, b: a + b, x, y)
|
|
160
|
+
|
|
161
|
+
tracer = Tracer()
|
|
162
|
+
with tracer:
|
|
163
|
+
# Inputs are automatically lifted to TraceObjects
|
|
164
|
+
result = my_program(x, y)
|
|
165
|
+
|
|
166
|
+
# Finalize graph
|
|
167
|
+
graph = tracer.finalize(result)
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
## Design Documents
|
|
171
|
+
|
|
172
|
+
Detailed design documents are in the `design/` subdirectory:
|
|
173
|
+
|
|
174
|
+
### 1. [architecture.md](design/architecture.md)
|
|
175
|
+
|
|
176
|
+
Complete EDSL architecture overview covering:
|
|
177
|
+
- Core components (Tracer, Graph)
|
|
178
|
+
- Design principles (Closed-World, TracedFunction vs First-Class Functions)
|
|
179
|
+
- Control flow handling (Dialect-specific, e.g., `simp.uniform_cond`)
|
|
180
|
+
- Comparison with JAX, PyTorch, TensorFlow
|
|
181
|
+
|
|
182
|
+
### 2. [type_system.md](design/type_system.md)
|
|
183
|
+
|
|
184
|
+
New type system design:
|
|
185
|
+
- Three orthogonal dimensions (Layout, Encryption, Distribution)
|
|
186
|
+
- Type composition examples
|
|
187
|
+
- Ops writing guide
|
|
188
|
+
- Migration strategy
|
|
189
|
+
|
|
190
|
+
### 3. [migration.md](design/migration.md)
|
|
191
|
+
|
|
192
|
+
Migration path from `mplang.v1.core` to `mplang.v2.edsl`:
|
|
193
|
+
- 6-phase migration plan
|
|
194
|
+
- Backward compatibility strategy
|
|
195
|
+
- Type conversion utilities
|
|
196
|
+
|
|
197
|
+
## Relationship with mplang.v1.core
|
|
198
|
+
|
|
199
|
+
```
|
|
200
|
+
mplang/
|
|
201
|
+
├── v1/core/ # Stable API (current production)
|
|
202
|
+
│ ├── primitive.py
|
|
203
|
+
│ ├── tracer.py
|
|
204
|
+
│ └── expr/
|
|
205
|
+
│
|
|
206
|
+
├── v2/edsl/ # Experimental (this directory)
|
|
207
|
+
│ ├── typing.py # Can be used independently
|
|
208
|
+
│ ├── graph.py # Future replacement for core.expr
|
|
209
|
+
│ └── tracer.py # Future replacement for core.tracer
|
|
210
|
+
│
|
|
211
|
+
├── v1/ops/ # Part of the v1 stack
|
|
212
|
+
├── v1/kernels/ # v1 kernels (v2 ships its own under v2/kernels/)
|
|
213
|
+
└── v1/runtime/ # v1 runtime (v2 ships its own under v2/runtime/ and v2/backends/)
|
|
214
|
+
```
|
|
215
|
+
|
|
216
|
+
**Migration Strategy**:
|
|
217
|
+
1. Develop `edsl` in parallel (no breaking changes to `core`)
|
|
218
|
+
2. Gradually move internal code to use `edsl.typing`
|
|
219
|
+
3. Add adapters between `core` and `edsl`
|
|
220
|
+
4. Deprecate `core` in future major version
|
|
221
|
+
|
|
222
|
+
## Contributing
|
|
223
|
+
|
|
224
|
+
We welcome contributions! Since this is experimental:
|
|
225
|
+
|
|
226
|
+
1. **Read the design docs first**: Understand the architecture
|
|
227
|
+
2. **Start small**: Pick a specific component (e.g., Graph IR)
|
|
228
|
+
3. **Discuss early**: Open an issue before implementing
|
|
229
|
+
4. **Test thoroughly**: Add unit tests for new code
|
|
230
|
+
|
|
231
|
+
### Development Workflow
|
|
232
|
+
|
|
233
|
+
```bash
|
|
234
|
+
# Install dev dependencies
|
|
235
|
+
uv sync --group dev
|
|
236
|
+
|
|
237
|
+
# Run tests (future)
|
|
238
|
+
uv run pytest mplang/v2/edsl/
|
|
239
|
+
|
|
240
|
+
# Lint
|
|
241
|
+
uv run ruff check mplang/v2/edsl/
|
|
242
|
+
uv run ruff format mplang/v2/edsl/
|
|
243
|
+
|
|
244
|
+
# Type check
|
|
245
|
+
uv run mypy mplang/v2/edsl/
|
|
246
|
+
```
|
|
247
|
+
|
|
248
|
+
## FAQ
|
|
249
|
+
|
|
250
|
+
### Q: Should I use `mplang.v2.edsl` in production?
|
|
251
|
+
|
|
252
|
+
**A**: No, use `mplang.v1.core`. `mplang.v2.edsl` is experimental.
|
|
253
|
+
|
|
254
|
+
### Q: Can I use `mplang.v2.edsl.typing` independently?
|
|
255
|
+
|
|
256
|
+
**A**: Yes! The type system is stable and can be used for type annotations.
|
|
257
|
+
|
|
258
|
+
### Q: When will `edsl` replace `core`?
|
|
259
|
+
|
|
260
|
+
**A**: No timeline yet. We need to:
|
|
261
|
+
1. Complete the implementation
|
|
262
|
+
2. Validate performance
|
|
263
|
+
3. Migrate all tests
|
|
264
|
+
4. Get community feedback
|
|
265
|
+
|
|
266
|
+
### Q: How can I help?
|
|
267
|
+
|
|
268
|
+
**A**: Check the implementation status above and pick an unimplemented component. Open an issue to discuss!
|
|
269
|
+
|
|
270
|
+
## References
|
|
271
|
+
|
|
272
|
+
- **torch.fx**: https://pytorch.org/docs/stable/fx.html
|
|
273
|
+
- **JAX jaxpr**: https://jax.readthedocs.io/en/latest/jaxpr.html
|
|
274
|
+
- **MLIR**: https://mlir.llvm.org/
|
|
275
|
+
|
|
276
|
+
---
|
|
277
|
+
|
|
278
|
+
**Last Updated**: 2025-01-11
|
|
279
|
+
**Maintainers**: MPLang Team
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Public entrypoint for the MPLang EDSL.
|
|
16
|
+
|
|
17
|
+
This module keeps the surface area intentionally small so downstream code can
|
|
18
|
+
simply write::
|
|
19
|
+
|
|
20
|
+
import mplang.v2.edsl as el
|
|
21
|
+
import mplang.v2.edsl.typing as elt
|
|
22
|
+
|
|
23
|
+
The `el` namespace re-exports the commonly used building blocks (context,
|
|
24
|
+
graph, tracer, primitives, etc.), while the full type system lives under
|
|
25
|
+
``mplang.v2.edsl.typing``.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
# Re-export the typing module so callers can `import mplang.v2.edsl.typing as elt`
|
|
31
|
+
from . import typing as typing
|
|
32
|
+
|
|
33
|
+
# Context management
|
|
34
|
+
from .context import (
|
|
35
|
+
Context,
|
|
36
|
+
find_context,
|
|
37
|
+
find_context_with_state,
|
|
38
|
+
find_interpreter,
|
|
39
|
+
get_current_context,
|
|
40
|
+
get_default_context,
|
|
41
|
+
is_tracing,
|
|
42
|
+
pop_context,
|
|
43
|
+
push_context,
|
|
44
|
+
register_default_context_factory,
|
|
45
|
+
set_root_context,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# Graph IR
|
|
49
|
+
from .graph import Graph, Operation, Value
|
|
50
|
+
|
|
51
|
+
# High-level helpers
|
|
52
|
+
from .jit import jit
|
|
53
|
+
from .object import Object
|
|
54
|
+
from .primitive import Primitive, primitive
|
|
55
|
+
from .printer import GraphPrinter, format_graph
|
|
56
|
+
from .tracer import TracedFunction, TraceObject, Tracer, trace
|
|
57
|
+
from .typing import MPType, ScalarType, SSType, TableType, TensorType, VectorType
|
|
58
|
+
|
|
59
|
+
# Type Aliases for strong typing.
# ``Object`` is subscripted with the type class a value carries, so these
# aliases let annotations state which kind of EDSL value is expected or
# returned, e.g. ``def f(x: TensorObject) -> TensorObject``.
MPObject = Object[MPType]
ScalarObject = Object[ScalarType]
SSObject = Object[SSType]
TableObject = Object[TableType]
TensorObject = Object[TensorType]
VectorObject = Object[VectorType]

# Explicit public API of this package (kept in sorted order).
# NOTE(review): the type classes imported from ``.typing`` above (MPType,
# ScalarType, SSType, TableType, TensorType, VectorType) are not listed here;
# presumably intentional, since the module docstring points users at the
# ``typing`` submodule for the full type system — confirm.
__all__ = [
    "Context",
    "Graph",
    "GraphPrinter",
    "MPObject",
    "Object",
    "Operation",
    "Primitive",
    "SSObject",
    "ScalarObject",
    "TableObject",
    "TensorObject",
    "TraceObject",
    "TracedFunction",
    "Tracer",
    "Value",
    "VectorObject",
    "find_context",
    "find_context_with_state",
    "find_interpreter",
    "format_graph",
    "get_current_context",
    "get_default_context",
    "is_tracing",
    "jit",
    "pop_context",
    "primitive",
    "push_context",
    "register_default_context_factory",
    "set_root_context",
    "trace",
    "typing",
]
|
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Context: EDSL Execution Context Abstraction.
|
|
16
|
+
|
|
17
|
+
This module defines the Context hierarchy:
|
|
18
|
+
- Context: Base class for EDSL execution contexts (with bind_primitive method)
|
|
19
|
+
- Tracer: Tracing context (records operations to Graph IR)
|
|
20
|
+
- Interpreter: Execution context (executes operations immediately)
|
|
21
|
+
|
|
22
|
+
Contexts can be used directly with Python's 'with' statement:
|
|
23
|
+
|
|
24
|
+
from mplang.v2.edsl import Tracer
|
|
25
|
+
|
|
26
|
+
tracer = Tracer()
|
|
27
|
+
with tracer:
|
|
28
|
+
# Operations run under tracer context
|
|
29
|
+
result = primitive.bind(x, y)
|
|
30
|
+
|
|
31
|
+
State Management:
|
|
32
|
+
Contexts can carry arbitrary named state via set_state/get_state.
|
|
33
|
+
This allows different layers (device, ml, analytics) to attach their
|
|
34
|
+
own state without the EDSL layer knowing about specific state types.
|
|
35
|
+
|
|
36
|
+
State key conventions:
|
|
37
|
+
- "dialect.{name}": Dialect runtime state (e.g., "dialect.simp")
|
|
38
|
+
- "device.cluster": Device/cluster configuration
|
|
39
|
+
- "ml.{component}": ML pipeline components
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
from __future__ import annotations
|
|
43
|
+
|
|
44
|
+
from abc import ABC, abstractmethod
|
|
45
|
+
from collections.abc import Callable
|
|
46
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from mplang.v2.edsl.graph import Graph
|
|
50
|
+
from mplang.v2.edsl.object import Object
|
|
51
|
+
from mplang.v2.edsl.primitive import Primitive
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Context(ABC):
|
|
55
|
+
"""Base class for EDSL execution contexts with extensible state slots.
|
|
56
|
+
|
|
57
|
+
A Context represents an environment where primitives are executed.
|
|
58
|
+
There are two types of contexts:
|
|
59
|
+
- Tracer: Records operations to Graph IR (compile-time)
|
|
60
|
+
- Interpreter: Execution context (executes operations immediately)
|
|
61
|
+
|
|
62
|
+
State Management:
|
|
63
|
+
Contexts can carry arbitrary named state. Different layers can attach
|
|
64
|
+
their own state without the EDSL layer knowing specifics:
|
|
65
|
+
|
|
66
|
+
>>> ctx.set_state("device.cluster", cluster_spec)
|
|
67
|
+
>>> ctx.set_state("dialect.simp", simp_driver)
|
|
68
|
+
>>> cluster = ctx.get_state("device.cluster")
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(self) -> None:
|
|
72
|
+
self._states: dict[str, Any] = {}
|
|
73
|
+
|
|
74
|
+
# =========================================================================
|
|
75
|
+
# State Management
|
|
76
|
+
# =========================================================================
|
|
77
|
+
|
|
78
|
+
def set_state(self, key: str, value: Any) -> None:
|
|
79
|
+
"""Attach state to this context.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
key: State key (e.g., "dialect.simp", "device.cluster")
|
|
83
|
+
value: State value
|
|
84
|
+
"""
|
|
85
|
+
self._states[key] = value
|
|
86
|
+
|
|
87
|
+
def get_state(self, key: str, default: Any = None) -> Any:
|
|
88
|
+
"""Get attached state by key.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
key: State key
|
|
92
|
+
default: Default value if key not found
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
State value or default
|
|
96
|
+
"""
|
|
97
|
+
return self._states.get(key, default)
|
|
98
|
+
|
|
99
|
+
def has_state(self, key: str) -> bool:
|
|
100
|
+
"""Check if state exists.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
key: State key
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
True if state exists
|
|
107
|
+
"""
|
|
108
|
+
return key in self._states
|
|
109
|
+
|
|
110
|
+
# =========================================================================
|
|
111
|
+
# Abstract Methods
|
|
112
|
+
# =========================================================================
|
|
113
|
+
|
|
114
|
+
@abstractmethod
|
|
115
|
+
def bind_primitive(
|
|
116
|
+
self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
|
|
117
|
+
) -> Any:
|
|
118
|
+
"""Execute a primitive in this context.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
primitive: The primitive to execute
|
|
122
|
+
args: Positional arguments (Objects)
|
|
123
|
+
kwargs: Keyword arguments (plain values)
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
Result Object (TraceObject in Tracer, InterpObject in Interpreter)
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
@abstractmethod
|
|
130
|
+
def lift(self, obj: Any) -> Object:
|
|
131
|
+
"""Lift an object to this context's native Object type.
|
|
132
|
+
|
|
133
|
+
Converts objects to the appropriate type for this context:
|
|
134
|
+
- Tracer: InterpObject → TraceObject (via promote), constants → TraceObject
|
|
135
|
+
- Interpreter: keeps InterpObject as-is, may convert constants
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
obj: Object to lift (Object, constant, etc.)
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
Object in the context's native type (TraceObject or InterpObject)
|
|
142
|
+
"""
|
|
143
|
+
|
|
144
|
+
# =========================================================================
|
|
145
|
+
# Context Manager
|
|
146
|
+
# =========================================================================
|
|
147
|
+
|
|
148
|
+
def __enter__(self) -> Self:
|
|
149
|
+
"""Enter context manager (push context onto stack)."""
|
|
150
|
+
push_context(self)
|
|
151
|
+
return self
|
|
152
|
+
|
|
153
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def]
|
|
154
|
+
"""Exit context manager (pop context from stack)."""
|
|
155
|
+
pop_context()
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# =============================================================================
|
|
159
|
+
# Abstract Interpreter Interface
|
|
160
|
+
# =============================================================================
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class AbstractInterpreter(Context):
    """Interpreter-facing interface with no concrete runtime dependency.

    EDSL-level code such as JIT can program against this class (for example,
    find an interpreter on the context stack with ``isinstance`` and hand it a
    Graph) without importing the concrete runtime, which may pull in an
    ObjectStore, backends, and so on.
    """

    @abstractmethod
    def evaluate_graph(self, graph: Graph, inputs: list[Any]) -> Any:
        """Execute ``graph`` with ``inputs`` and return the result."""

    @abstractmethod
    def lift(self, obj: Any) -> Any:
        """Convert a Python value into an interpreter-side object."""
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# =============================================================================
|
|
181
|
+
# Global Context Stack Management
|
|
182
|
+
# =============================================================================
|
|
183
|
+
|
|
184
|
+
# Stack of active contexts; the last element is the innermost (current) one.
# Appended to by push_context/set_root_context, popped by pop_context, and
# searched top-down by the find_* helpers.
# NOTE(review): plain module globals with no locking — presumably not meant to
# be shared across threads; confirm before using contexts concurrently.
_context_stack: list[Context] = []
# Default context created lazily by get_default_context() on first use.
_default_context: Context | None = None
# Zero-argument callable that get_default_context() uses to build
# _default_context; installed via register_default_context_factory().
_default_context_factory: Callable[[], Context] | None = None
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def get_current_context() -> Context | None:
    """Return the innermost active context.

    Returns:
        The context on top of the stack, or None when no context is active.
    """
    if not _context_stack:
        return None
    return _context_stack[-1]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def push_context(context: Context) -> None:
    """Make ``context`` the innermost active context.

    Args:
        context: The context to place on top of the stack.
    """
    stack = _context_stack
    stack.append(context)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def pop_context() -> Context | None:
    """Remove and return the innermost active context.

    Popping an empty stack is not an error.

    Returns:
        The context that was on top of the stack, or None if the stack was
        already empty.
    """
    if _context_stack:
        return _context_stack.pop()
    return None
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def find_context(predicate: Callable[[Context], bool]) -> Context | None:
    """Search the context stack, innermost first, for a matching context.

    ``predicate`` is called on each context from the top of the stack down
    and evaluation stops at the first context it accepts.

    Args:
        predicate: Callable returning True for the context being sought.

    Returns:
        The first accepted context, or None if none matches.

    Example:
        >>> # Find context with simp dialect state
        >>> ctx = find_context(lambda c: c.has_state("dialect.simp"))
    """
    candidates = (ctx for ctx in reversed(_context_stack) if predicate(ctx))
    return next(candidates, None)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def find_context_with_state(key: str) -> Context | None:
    """Return the innermost context that has state stored under ``key``.

    Args:
        key: State key to look for (e.g. ``"dialect.simp"``).

    Returns:
        The first context, searching from the top of the stack, whose
        ``has_state(key)`` is True; None if there is no such context.
    """

    def _holds_key(ctx: Context) -> bool:
        return ctx.has_state(key)

    return find_context(_holds_key)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def find_interpreter() -> AbstractInterpreter | None:
    """Return the innermost interpreter on the context stack.

    The stack is searched from the top down and the first context that is an
    ``AbstractInterpreter`` (or subclass) is returned. The return type is
    narrowed to ``AbstractInterpreter`` so callers can use interpreter-only
    methods such as ``evaluate_graph`` without casting.

    Returns:
        The first interpreter found, or None if the stack holds none.
    """
    for ctx in reversed(_context_stack):
        # isinstance lets the type checker narrow ctx to AbstractInterpreter.
        if isinstance(ctx, AbstractInterpreter):
            return ctx
    return None
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def is_tracing() -> bool:
    """Report whether the innermost active context is a Tracer.

    Only the top of the stack is inspected: a Tracer sitting below another
    context does not count.

    Returns:
        True if the current context is a Tracer instance.
    """
    # Deferred import — presumably to avoid a circular import, since Tracer
    # builds on the Context class defined in this module (confirm).
    from mplang.v2.edsl.tracer import Tracer

    current = get_current_context()
    return isinstance(current, Tracer)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
# =============================================================================
|
|
266
|
+
# Default Context Management
|
|
267
|
+
# =============================================================================
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def register_default_context_factory(factory: Callable[[], Context]) -> None:
    """Install the callable used to build the default context.

    Only the factory is replaced. A default context that
    ``get_default_context()`` has already created is kept as-is.

    Args:
        factory: Zero-argument callable returning a new Context.
    """
    global _default_context_factory
    _default_context_factory = factory
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def get_default_context() -> Context:
    """Return the default context used for eager execution.

    The context is created once, on the first call, by the registered
    factory; later calls return that same cached instance.

    Raises:
        RuntimeError: If no default context exists yet and no factory has
            been registered.
    """
    global _default_context
    if _default_context is not None:
        return _default_context

    factory = _default_context_factory
    if factory is None:
        raise RuntimeError(
            "No default context factory registered. "
            "Ensure mplang.v2.edsl is imported or register a factory manually."
        )
    _default_context = factory()
    return _default_context
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def set_root_context(context: Context, force: bool = False) -> None:
    """Install ``context`` as the bottom entry of the context stack.

    This only changes the context stack; the lazily created context returned
    by ``get_default_context()`` is left untouched.

    Args:
        context: Context to place at the base of the stack.
        force: When True, discard every context currently on the stack first.
            When False (the default), refuse to proceed if any context is
            already active.

    Raises:
        RuntimeError: If ``force`` is False and the stack is not empty.
    """
    if force:
        _context_stack.clear()
    elif get_current_context() is not None:
        raise RuntimeError(
            "Cannot set root context: Context stack is not empty. "
            "Use force=True to overwrite the existing context."
        )

    push_context(context)
|