mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,279 @@
1
+ # MPLang EDSL - Experimental Architecture
2
+
3
+ **⚠️ Status**: Experimental / Work in Progress
4
+
5
+ This directory contains the next-generation EDSL (Embedded Domain-Specific Language) architecture for MPLang.
6
+
7
+ ## Why a New Architecture?
8
+
9
+ The current `mplang.core` architecture (Expr Tree + @primitive) has served us well, but we're hitting limitations:
10
+
11
+ 1. **Expr Tree** is hard to optimize (visitor pattern, nested structure)
12
+ 2. **@primitive decorators** hide complexity and limit flexibility
13
+ 3. **Type system** is split between `mptype.MPType` and `typing.BaseType`
14
+ 4. **No clear separation** between IR, frontend, and backend
15
+
16
+ Modern EDSLs (torch.fx, JAX) use **Operation List + SSA** for better analyzability and optimization.
17
+
18
+ ## Goals
19
+
20
+ ### 1. Modern IR (Operation List)
21
+
22
+ **From** (Expr Tree):
23
+ ```python
24
+ CallExpr(
25
+ func=add,
26
+ args=[VariableExpr("x"), VariableExpr("y")]
27
+ )
28
+ ```
29
+
30
+ **To** (Operation List):
31
+ ```text
32
+ %0 = input "x"
33
+ %1 = input "y"
34
+ %2 = add %0, %1
35
+ return %2
36
+ ```
37
+
38
+ ### 2. Unified Type System
39
+
40
+ **Single source of truth**: `mplang.v2.edsl.typing.MPType`
41
+
42
+ ```python
43
+ from mplang.v2.edsl.typing import Tensor, Vector, MPType, f32
44
+
45
+ # All types use BaseType
46
+ plaintext: MPType = Tensor[f32, (4096,)]
47
+ ciphertext: MPType = Vector[f32, 4096]
48
+ ```
49
+
50
+ ### 3. Explicit Tracing
51
+
52
+ **Clean context management**:
53
+ ```python
54
+ from mplang.v2.edsl import Tracer
55
+
56
+ tracer = Tracer()
57
+ with tracer: # Context manager protocol
58
+ result = my_function(x, y)
59
+ graph = tracer.finalize(result)
60
+ ```
61
+
62
+ ### 4. Extensibility
63
+
64
+ Easy to add new backends:
65
+ - FHE (Fully Homomorphic Encryption)
66
+ - TEE (Trusted Execution Environment)
67
+ - Custom accelerators
68
+
69
+ ### 5. Layered API Architecture
70
+
71
+ The EDSL provides two distinct API layers:
72
+
73
+ 1. **Low-Level API (Graph Manipulation)**:
74
+ - Direct manipulation of the `Graph` IR.
75
+ - Generic `add_op` method (pure graph API, no op semantics).
76
+ - Analogous to MLIR's generic operation construction.
77
+ - Used by compiler passes and backend implementations.
78
+
79
+ 2. **High-Level API (Tracing)**:
80
+ - Uses `Tracer` + `Primitive` (with `abstract_eval`).
81
+ - Pythonic interface (functions, operators).
82
+ - Automatic type inference and graph construction.
83
+ - The primary interface for users.
84
+
85
+ ## Directory Structure
86
+
87
+ ```
88
+ mplang/v2/edsl/
89
+ ├── __init__.py # Public API
90
+ ├── README.md # This file
91
+
92
+ ├── design/ # Design documents
93
+ │ ├── architecture.md # Complete architecture overview
94
+ │ ├── type_system.md # Type system design
95
+ │ └── migration.md # Migration from mplang.core
96
+
97
+ ├── typing.py # ✅ Unified type system
98
+ ├── graph.py # ✅ IR: Operation List + SSA
99
+ ├── primitive.py # ✅ Primitive abstraction
100
+ ├── object.py # ✅ TraceObject/InterpObject
101
+ ├── context.py # ✅ Context management
102
+ ├── tracer.py # ✅ Explicit tracer
103
+ ├── interpreter.py # ✅ Interpreter + GraphInterpreter
104
+ └── jit.py # ✅ @jit decorator
105
+ ```
106
+
107
+ ## Implementation Status
108
+
109
+ ### ✅ Completed (Phase 1-4)
110
+
111
+ - [x] Type system (`typing.py`) - 649 lines
112
+ - [x] Graph IR (`graph.py`) - 388 lines
113
+ - [x] Primitive abstraction (`primitive.py`) - 338 lines
114
+ - [x] Object hierarchy (`object.py`) - 153 lines
115
+ - [x] Context system (`context.py`) - 117 lines
116
+ - [x] Tracer (`tracer.py`) - 201 lines
117
+ - [x] Interpreter (`interpreter.py`) - 66 lines
118
+ - [x] JIT decorator (`jit.py`) - 42 lines
119
+ - [x] Design documents
120
+ - [x] **153 tests passing** (140 edsl + 13 core2)
121
+
122
+ ### 🚧 In Progress
123
+ - [ ] Integration with existing ops/kernels
124
+ - [ ] Migration utilities
125
+ - [ ] Performance benchmarks
126
+
127
+ ### ❌ Dropped / Deprecated
128
+ - [x] Builder API (`builder.py`) - Integrated into `Tracer`
129
+
130
+ ### 📋 Planned
131
+ - [ ] Advanced optimizations
132
+ - [ ] More backends (TEE, MPC)
133
+
134
+ ## Quick Start
135
+
136
+ ### Using the New Type System
137
+
138
+ ```python
139
+ from mplang.v2.edsl.typing import Tensor, Vector, CustomType, f32
140
+
141
+ # Define types
142
+ PlaintextVec = Tensor[f32, (4096,)]
143
+ CiphertextVec = Vector[f32, 4096]
144
+ EncryptionKey = CustomType("EncryptionKey")
145
+
146
+ # Type annotations
147
+ def encrypt(data: PlaintextVec, key: EncryptionKey) -> CiphertextVec:
148
+ ...
149
+ ```
150
+
151
+ ### Using the Tracer (Graph Construction)
152
+
153
+ ```python
154
+ from mplang.v2.edsl import Tracer
155
+ from mplang.v2.dialects.simp import pcall_static
156
+
157
+ def my_program(x, y):
158
+ # This function is traced into a Graph
159
+ return pcall_static((0, 1), lambda a, b: a + b, x, y)
160
+
161
+ tracer = Tracer()
162
+ with tracer:
163
+ # Inputs are automatically lifted to TraceObjects
164
+ result = my_program(x, y)
165
+
166
+ # Finalize graph
167
+ graph = tracer.finalize(result)
168
+ ```
169
+
170
+ ## Design Documents
171
+
172
+ Detailed design documents are in the `design/` subdirectory:
173
+
174
+ ### 1. [architecture.md](design/architecture.md)
175
+
176
+ Complete EDSL architecture overview covering:
177
+ - Core components (Tracer, Graph)
178
+ - Design principles (Closed-World, TracedFunction vs First-Class Functions)
179
+ - Control flow handling (Dialect-specific, e.g., `simp.uniform_cond`)
180
+ - Comparison with JAX, PyTorch, TensorFlow
181
+
182
+ ### 2. [type_system.md](design/type_system.md)
183
+
184
+ New type system design:
185
+ - Three orthogonal dimensions (Layout, Encryption, Distribution)
186
+ - Type composition examples
187
+ - Ops writing guide
188
+ - Migration strategy
189
+
190
+ ### 3. [migration.md](design/migration.md)
191
+
192
+ Migration path from `mplang.core` to `mplang.edsl`:
193
+ - 6-phase migration plan
194
+ - Backward compatibility strategy
195
+ - Type conversion utilities
196
+
197
+ ## Relationship with mplang.core
198
+
199
+ ```
200
+ mplang/
201
+ ├── core/ # Stable API (current production)
202
+ │ ├── primitive.py
203
+ │ ├── tracer.py
204
+ │ └── expr/
205
+
206
+ ├── edsl/ # Experimental (this directory)
207
+ │ ├── typing.py # Can be used independently
208
+ │ ├── graph.py # Future replacement for core.expr
209
+ │ └── tracer.py # Future replacement for core.tracer
210
+
211
+ ├── ops/ # Shared between core and edsl
212
+ ├── kernels/ # Shared between core and edsl
213
+ └── runtime/ # Shared between core and edsl
214
+ ```
215
+
216
+ **Migration Strategy**:
217
+ 1. Develop `edsl` in parallel (no breaking changes to `core`)
218
+ 2. Gradually move internal code to use `edsl.typing`
219
+ 3. Add adapters between `core` and `edsl`
220
+ 4. Deprecate `core` in future major version
221
+
222
+ ## Contributing
223
+
224
+ We welcome contributions! Since this is experimental:
225
+
226
+ 1. **Read the design docs first**: Understand the architecture
227
+ 2. **Start small**: Pick a specific component (e.g., Graph IR)
228
+ 3. **Discuss early**: Open an issue before implementing
229
+ 4. **Test thoroughly**: Add unit tests for new code
230
+
231
+ ### Development Workflow
232
+
233
+ ```bash
234
+ # Install dev dependencies
235
+ uv sync --group dev
236
+
237
+ # Run tests (future)
238
+ uv run pytest mplang/v2/edsl/
239
+
240
+ # Lint
241
+ uv run ruff check mplang/v2/edsl/
242
+ uv run ruff format mplang/v2/edsl/
243
+
244
+ # Type check
245
+ uv run mypy mplang/v2/edsl/
246
+ ```
247
+
248
+ ## FAQ
249
+
250
+ ### Q: Should I use `mplang.edsl` in production?
251
+
252
+ **A**: No, use `mplang.core`. `mplang.edsl` is experimental.
253
+
254
+ ### Q: Can I use `mplang.edsl.typing` independently?
255
+
256
+ **A**: Yes! The type system is stable and can be used for type annotations.
257
+
258
+ ### Q: When will `edsl` replace `core`?
259
+
260
+ **A**: No timeline yet. We need to:
261
+ 1. Complete the implementation
262
+ 2. Validate performance
263
+ 3. Migrate all tests
264
+ 4. Get community feedback
265
+
266
+ ### Q: How can I help?
267
+
268
+ **A**: Check the implementation status above and pick an unimplemented component. Open an issue to discuss!
269
+
270
+ ## References
271
+
272
+ - **torch.fx**: https://pytorch.org/docs/stable/fx.html
273
+ - **JAX jaxpr**: https://jax.readthedocs.io/en/latest/jaxpr.html
274
+ - **MLIR**: https://mlir.llvm.org/
275
+
276
+ ---
277
+
278
+ **Last Updated**: 2025-01-11
279
+ **Maintainers**: MPLang Team
@@ -0,0 +1,99 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Public entrypoint for the MPLang EDSL.
16
+
17
+ This module keeps the surface area intentionally small so downstream code can
18
+ simply write::
19
+
20
+ import mplang.v2.edsl as el
21
+ import mplang.v2.edsl.typing as elt
22
+
23
+ The `el` namespace re-exports the commonly used building blocks (context,
24
+ graph, tracer, primitives, etc.), while the full type system lives under
25
+ ``mplang.v2.edsl.typing``.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ # Re-export the typing module so callers can `import mplang.v2.edsl.typing as elt`
31
+ from . import typing as typing
32
+
33
+ # Context management
34
+ from .context import (
35
+ Context,
36
+ find_context,
37
+ find_context_with_state,
38
+ find_interpreter,
39
+ get_current_context,
40
+ get_default_context,
41
+ is_tracing,
42
+ pop_context,
43
+ push_context,
44
+ register_default_context_factory,
45
+ set_root_context,
46
+ )
47
+
48
+ # Graph IR
49
+ from .graph import Graph, Operation, Value
50
+
51
+ # High-level helpers
52
+ from .jit import jit
53
+ from .object import Object
54
+ from .primitive import Primitive, primitive
55
+ from .printer import GraphPrinter, format_graph
56
+ from .tracer import TracedFunction, TraceObject, Tracer, trace
57
+ from .typing import MPType, ScalarType, SSType, TableType, TensorType, VectorType
58
+
59
# Type aliases for strong typing: each is ``Object`` subscripted with the
# specific type it carries, so signatures can say e.g. ``TensorObject`` instead
# of a bare ``Object``.
MPObject = Object[MPType]
ScalarObject = Object[ScalarType]
SSObject = Object[SSType]
TableObject = Object[TableType]
TensorObject = Object[TensorType]
VectorObject = Object[VectorType]

# Public surface of ``mplang.v2.edsl`` (also what ``from mplang.v2.edsl import *``
# binds). Entries are kept in ASCII-sorted order.
__all__ = [
    "Context",
    "Graph",
    "GraphPrinter",
    "MPObject",
    "Object",
    "Operation",
    "Primitive",
    "SSObject",
    "ScalarObject",
    "TableObject",
    "TensorObject",
    "TraceObject",
    "TracedFunction",
    "Tracer",
    "Value",
    "VectorObject",
    "find_context",
    "find_context_with_state",
    "find_interpreter",
    "format_graph",
    "get_current_context",
    "get_default_context",
    "is_tracing",
    "jit",
    "pop_context",
    "primitive",
    "push_context",
    "register_default_context_factory",
    "set_root_context",
    "trace",
    "typing",
]
@@ -0,0 +1,311 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Context: EDSL Execution Context Abstraction.
16
+
17
+ This module defines the Context hierarchy:
18
+ - Context: Base class for EDSL execution contexts (with bind_primitive method)
19
+ - Tracer: Tracing context (records operations to Graph IR)
20
+ - Interpreter: Execution context (executes operations immediately)
21
+
22
+ Contexts can be used directly with Python's 'with' statement:
23
+
24
+ from mplang.v2.edsl import Tracer
25
+
26
+ tracer = Tracer()
27
+ with tracer:
28
+ # Operations run under tracer context
29
+ result = primitive.bind(x, y)
30
+
31
+ State Management:
32
+ Contexts can carry arbitrary named state via set_state/get_state.
33
+ This allows different layers (device, ml, analytics) to attach their
34
+ own state without the EDSL layer knowing about specific state types.
35
+
36
+ State key conventions:
37
+ - "dialect.{name}": Dialect runtime state (e.g., "dialect.simp")
38
+ - "device.cluster": Device/cluster configuration
39
+ - "ml.{component}": ML pipeline components
40
+ """
41
+
42
+ from __future__ import annotations
43
+
44
+ from abc import ABC, abstractmethod
45
+ from collections.abc import Callable
46
+ from typing import TYPE_CHECKING, Any, Self
47
+
48
+ if TYPE_CHECKING:
49
+ from mplang.v2.edsl.graph import Graph
50
+ from mplang.v2.edsl.object import Object
51
+ from mplang.v2.edsl.primitive import Primitive
52
+
53
+
54
+ class Context(ABC):
55
+ """Base class for EDSL execution contexts with extensible state slots.
56
+
57
+ A Context represents an environment where primitives are executed.
58
+ There are two types of contexts:
59
+ - Tracer: Records operations to Graph IR (compile-time)
60
+ - Interpreter: Execution context (executes operations immediately)
61
+
62
+ State Management:
63
+ Contexts can carry arbitrary named state. Different layers can attach
64
+ their own state without the EDSL layer knowing specifics:
65
+
66
+ >>> ctx.set_state("device.cluster", cluster_spec)
67
+ >>> ctx.set_state("dialect.simp", simp_driver)
68
+ >>> cluster = ctx.get_state("device.cluster")
69
+ """
70
+
71
+ def __init__(self) -> None:
72
+ self._states: dict[str, Any] = {}
73
+
74
+ # =========================================================================
75
+ # State Management
76
+ # =========================================================================
77
+
78
+ def set_state(self, key: str, value: Any) -> None:
79
+ """Attach state to this context.
80
+
81
+ Args:
82
+ key: State key (e.g., "dialect.simp", "device.cluster")
83
+ value: State value
84
+ """
85
+ self._states[key] = value
86
+
87
+ def get_state(self, key: str, default: Any = None) -> Any:
88
+ """Get attached state by key.
89
+
90
+ Args:
91
+ key: State key
92
+ default: Default value if key not found
93
+
94
+ Returns:
95
+ State value or default
96
+ """
97
+ return self._states.get(key, default)
98
+
99
+ def has_state(self, key: str) -> bool:
100
+ """Check if state exists.
101
+
102
+ Args:
103
+ key: State key
104
+
105
+ Returns:
106
+ True if state exists
107
+ """
108
+ return key in self._states
109
+
110
+ # =========================================================================
111
+ # Abstract Methods
112
+ # =========================================================================
113
+
114
+ @abstractmethod
115
+ def bind_primitive(
116
+ self, primitive: Primitive, args: tuple[Any, ...], kwargs: dict[str, Any]
117
+ ) -> Any:
118
+ """Execute a primitive in this context.
119
+
120
+ Args:
121
+ primitive: The primitive to execute
122
+ args: Positional arguments (Objects)
123
+ kwargs: Keyword arguments (plain values)
124
+
125
+ Returns:
126
+ Result Object (TraceObject in Tracer, InterpObject in Interpreter)
127
+ """
128
+
129
+ @abstractmethod
130
+ def lift(self, obj: Any) -> Object:
131
+ """Lift an object to this context's native Object type.
132
+
133
+ Converts objects to the appropriate type for this context:
134
+ - Tracer: InterpObject → TraceObject (via promote), constants → TraceObject
135
+ - Interpreter: keeps InterpObject as-is, may convert constants
136
+
137
+ Args:
138
+ obj: Object to lift (Object, constant, etc.)
139
+
140
+ Returns:
141
+ Object in the context's native type (TraceObject or InterpObject)
142
+ """
143
+
144
+ # =========================================================================
145
+ # Context Manager
146
+ # =========================================================================
147
+
148
+ def __enter__(self) -> Self:
149
+ """Enter context manager (push context onto stack)."""
150
+ push_context(self)
151
+ return self
152
+
153
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def]
154
+ """Exit context manager (pop context from stack)."""
155
+ pop_context()
156
+
157
+
158
+ # =============================================================================
159
+ # Abstract Interpreter Interface
160
+ # =============================================================================
161
+
162
+
163
class AbstractInterpreter(Context):
    """Interface that concrete Interpreters implement.

    EDSL-level components (such as the JIT) can depend on this abstraction
    rather than on the concrete runtime Interpreter, which may pull in heavier
    dependencies like an ObjectStore or backend implementations.
    """

    @abstractmethod
    def lift(self, obj: Any) -> Any:
        """Wrap a plain Python object as an interpreter-side object."""

    @abstractmethod
    def evaluate_graph(self, graph: Graph, inputs: list[Any]) -> Any:
        """Run the Graph IR ``graph`` on ``inputs`` and return its result."""
178
+
179
+
180
+ # =============================================================================
181
+ # Global Context Stack Management
182
+ # =============================================================================
183
+
184
# Module-level bookkeeping. The stack holds the active contexts with the
# innermost one last; the default context and its factory are used by
# get_default_context().
_context_stack: list[Context] = []
_default_context: Context | None = None
_default_context_factory: Callable[[], Context] | None = None


def get_current_context() -> Context | None:
    """Return the innermost active context, or None when the stack is empty."""
    if not _context_stack:
        return None
    return _context_stack[-1]


def push_context(context: Context) -> None:
    """Make ``context`` the innermost active context (enter it)."""
    _context_stack.append(context)


def pop_context() -> Context | None:
    """Remove and return the innermost active context (exit it).

    Returns:
        The removed context, or None if the stack was already empty; no
        error is raised in that case.
    """
    try:
        return _context_stack.pop()
    except IndexError:
        return None
209
+
210
+
211
def find_context(predicate: Callable[[Context], bool]) -> Context | None:
    """Return the innermost context on the stack accepted by ``predicate``.

    Contexts are examined from the most recently pushed one down to the
    root. ``predicate`` is called lazily and stops being called after the
    first match.

    Args:
        predicate: Callable receiving a Context and returning a truthy value
            when that context is the one wanted.

    Returns:
        The first matching Context, or None if nothing matches.

    Example:
        >>> # Find context with simp dialect state
        >>> ctx = find_context(lambda c: c.has_state("dialect.simp"))
    """
    candidates = (ctx for ctx in reversed(_context_stack) if predicate(ctx))
    return next(candidates, None)
231
+
232
+
233
def find_context_with_state(key: str) -> Context | None:
    """Return the innermost context whose state table contains ``key``.

    Args:
        key: State key to look for (e.g. "dialect.simp").

    Returns:
        The matching context, or None if no context on the stack has it.
    """

    def _holds_key(ctx: Context) -> bool:
        return ctx.has_state(key)

    return find_context(_holds_key)
243
+
244
+
245
def find_interpreter() -> AbstractInterpreter | None:
    """Return the innermost AbstractInterpreter on the context stack.

    Walks the global context stack from the most recently pushed entry down to
    the root. It returns the first entry that is an ``AbstractInterpreter``.

    The return type is narrowed to ``AbstractInterpreter``, so callers can use
    interpreter-only methods such as ``evaluate_graph`` without casting.
    ``AbstractInterpreter`` subclasses ``Context``, so code expecting a
    ``Context | None`` keeps working.

    Returns:
        The innermost interpreter context, or None if the stack holds none.
    """
    # Iterate directly instead of going through find_context(): the isinstance
    # check here lets static type checkers narrow the returned value.
    for ctx in reversed(_context_stack):
        if isinstance(ctx, AbstractInterpreter):
            return ctx
    return None
252
+
253
+
254
def is_tracing() -> bool:
    """Report whether the innermost active context is a Tracer.

    Returns:
        True only when the top of the context stack is a ``Tracer``; False
        when the stack is empty or its top is any other kind of context.
    """
    # Imported lazily, presumably to avoid an import cycle with tracer.py.
    from mplang.v2.edsl.tracer import Tracer

    current = get_current_context()
    return isinstance(current, Tracer)
263
+
264
+
265
+ # =============================================================================
266
+ # Default Context Management
267
+ # =============================================================================
268
+
269
+
270
def register_default_context_factory(factory: Callable[[], Context]) -> None:
    """Install ``factory`` as the builder used by get_default_context().

    Args:
        factory: Zero-argument callable that creates the default context.

    Only the factory reference is recorded.

    NOTE(review): a default context that was already built and cached is not
    discarded here. A factory registered after the first get_default_context()
    call is therefore never used; confirm this is intended.
    """
    # Rebind the module-level slot rather than creating a local name.
    global _default_context_factory
    _default_context_factory = factory
274
+
275
+
276
def get_default_context() -> Context:
    """Return the default context for eager execution, creating it on first use.

    The context is built lazily by the registered factory and cached in a
    module-level variable. Later calls return the cached object without
    invoking the factory again.

    Raises:
        RuntimeError: If no context has been cached yet and no factory has
            been registered.
    """
    global _default_context
    if _default_context is not None:
        return _default_context
    if _default_context_factory is None:
        raise RuntimeError(
            "No default context factory registered. "
            "Ensure mplang.v2.edsl is imported or register a factory manually."
        )
    _default_context = _default_context_factory()
    return _default_context
287
+
288
+
289
def set_root_context(context: Context, force: bool = False) -> None:
    """Install ``context`` at the bottom of the context stack.

    Once installed, ``context`` is the outermost environment that later
    operations run under.

    This function only edits the context stack. It does not change the cached
    object returned by get_default_context().

    Args:
        context: Context to install as the root.
        force: When True, every context currently on the stack is discarded
            first. When False (the default), the stack must be empty.

    Raises:
        RuntimeError: If ``force`` is False and some context is already active.
    """
    if not force:
        if get_current_context() is not None:
            raise RuntimeError(
                "Cannot set root context: Context stack is not empty. "
                "Use force=True to overwrite the existing context."
            )
        push_context(context)
        return

    # Forced: drop all active contexts and make `context` the sole entry.
    # NOTE(review): contexts entered via `with` that are still open will pop
    # this new root when they exit.
    _context_stack.clear()
    _context_stack.append(context)