mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,303 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any
16
+
17
+ import sqlglot as sg
18
+ from jax.tree_util import PyTreeDef, tree_flatten
19
+ from sqlglot import exp as sge
20
+ from sqlglot.optimizer import annotate_types as opt_annot
21
+ from sqlglot.optimizer import qualify as opt_qualify
22
+
23
+ from mplang.v1.core import MPObject, PFunction, TableType
24
+ from mplang.v1.core.dtypes import (
25
+ BINARY,
26
+ BOOL,
27
+ DATE,
28
+ DECIMAL,
29
+ FLOAT32,
30
+ FLOAT64,
31
+ INT8,
32
+ INT16,
33
+ INT32,
34
+ INT64,
35
+ INTERVAL,
36
+ JSON,
37
+ STRING,
38
+ TIME,
39
+ TIMESTAMP,
40
+ UINT8,
41
+ UINT16,
42
+ UINT32,
43
+ UINT64,
44
+ UUID,
45
+ DType,
46
+ )
47
+ from mplang.v1.ops.base import stateless_mod
48
+
49
# Module handle created by stateless_mod("sql"); run_sql and run_sql_raw below
# are declared through its op_def() decorator.
_SQL_MOD = stateless_mod("sql")
50
+
51
+
52
# Static dtype mappings (MPLang <-> SQL)
#
# Forward map: MPLang DType -> SQL type name. _deduce_out_schema uses it to
# describe input columns to sqlglot; a dtype that is not listed here is
# declared as "VARCHAR" by that function.
MP_TO_SQL_TYPE: dict[DType, str] = {
    # Floats
    FLOAT64: "DOUBLE",
    FLOAT32: "FLOAT",
    # Signed ints
    INT8: "TINYINT",
    INT16: "SMALLINT",
    INT32: "INT",
    INT64: "BIGINT",
    # Unsigned ints (portable approximations): each one is widened to the next
    # larger signed SQL type. This is lossy in the reverse direction, e.g.
    # UINT8 -> "SMALLINT", which SQL_TYPE_TO_MP maps back to INT16.
    UINT8: "SMALLINT",
    UINT16: "INT",
    UINT32: "BIGINT",
    UINT64: "DECIMAL(38)",
    # Booleans & strings
    BOOL: "BOOLEAN",
    STRING: "VARCHAR",
    # Dates / times
    DATE: "DATE",
    TIME: "TIME",
    TIMESTAMP: "TIMESTAMP",
    # Other table types
    DECIMAL: "DECIMAL",
    JSON: "JSON",
    BINARY: "BLOB",
    UUID: "UUID",
    INTERVAL: "INTERVAL",
}
81
+
82
# Reverse map: SQL type name -> MPLang DType.
#
# Keys are lower-case base type names. _deduce_out_schema lower-cases the
# sqlglot type text, removes " with time zone" and drops any parenthesised
# parameters before looking it up here; names that are not listed resolve to
# STRING in that function.
SQL_TYPE_TO_MP: dict[str, DType] = {
    # Floats
    "double": FLOAT64,
    "double precision": FLOAT64,
    "float": FLOAT32,
    "real": FLOAT32,
    # Signed ints
    "bigint": INT64,
    "long": INT64,
    "int": INT32,
    "integer": INT32,
    "int4": INT32,
    "smallint": INT16,
    "int2": INT16,
    "tinyint": INT8,
    "int1": INT8,
    # Unsigned (rare in SQL)
    "uint8": UINT8,
    "ubyte": UINT8,
    "uint16": UINT16,
    "uint32": UINT32,
    "uint64": UINT64,
    # Booleans / strings
    "bool": BOOL,
    "boolean": BOOL,
    "char": STRING,
    "varchar": STRING,
    "text": STRING,
    "string": STRING,
    # Dates / times
    "date": DATE,
    "time": TIME,
    "timestamp": TIMESTAMP,
    # Decimal / numeric
    "decimal": DECIMAL,
    "numeric": DECIMAL,
    # Others
    "json": JSON,
    "binary": BINARY,
    "varbinary": BINARY,
    "blob": BINARY,
    "uuid": UUID,
    "interval": INTERVAL,
}
126
+
127
+
128
def _deduce_out_schema(
    parsed: sge.Expression,
    dialect: str,
    in_schemas: dict[str, TableType],
) -> TableType:
    """Infer the output TableType of a SELECT using sqlglot's optimizer.

    The input schemas are handed to sqlglot so that ``qualify`` can resolve
    table/column references (including star expansion) and ``annotate_types``
    can attach a type to every projection. Each projection type is then
    translated back to an mplang DType.

    Args:
        parsed: The already-parsed SQL expression.
        dialect: sqlglot dialect name passed to ``qualify``.
        in_schemas: Mapping of table name to that table's TableType.

    Returns:
        A TableType built from the (name, dtype) pair of each projection.

    Raises:
        NotImplementedError: If no SELECT is found, or a projection carries
            no type annotation.
        ValueError: If two projections end up with the same output name.
    """
    # Describe every input table to sqlglot as {table: {column: SQL type}}.
    # Dtypes that have no entry in MP_TO_SQL_TYPE are declared as VARCHAR.
    catalog: dict[str, dict[str, str]] = {}
    for table_name, table_type in in_schemas.items():
        catalog[table_name] = {
            column: MP_TO_SQL_TYPE.get(column_dtype, "VARCHAR")
            for column, column_dtype in table_type.columns
        }

    # Resolve names / expand stars first, then annotate expression types.
    resolved = opt_qualify.qualify(parsed, schema=catalog, dialect=dialect)
    annotated = opt_annot.annotate_types(resolved, schema=catalog)

    # Use the expression itself when it is a SELECT, else the first one found.
    if isinstance(annotated, sge.Select):
        select = annotated
    else:
        select = annotated.find(sge.Select)
    if select is None:
        raise NotImplementedError(
            "Only SELECT queries are supported for schema deduction"
        )

    columns: list[tuple[str, DType]] = []
    seen_names: set[str] = set()
    # NOTE(review): indentation was lost in the reviewed view; this assumes the
    # fallback index advances once per projection (positional) - confirm.
    for position, projection in enumerate(select.expressions):
        column_name = getattr(projection, "alias_or_name", None) or getattr(
            projection, "name", None
        )
        if not column_name:
            column_name = f"expr_{position}"

        sql_type = getattr(projection, "type", None)
        if sql_type is None:
            raise NotImplementedError(
                "Cannot infer type for projection; please provide out_type explicitly"
            )

        # Normalise the type text: lower-case, drop the time-zone suffix and
        # any "(precision, scale)" parameters, then look up the base name.
        type_text = str(sql_type).lower().replace(" with time zone", "").strip()
        base_name = type_text.partition("(")[0].strip()
        column_dtype = SQL_TYPE_TO_MP.get(base_name, STRING)

        if column_name in seen_names:
            raise ValueError(
                f"Duplicate output column name '{column_name}' after qualification"
            )
        seen_names.add(column_name)
        columns.append((column_name, column_dtype))

    return TableType.from_pairs(columns)
187
+
188
+
189
@_SQL_MOD.op_def()
def run_sql(
    query: str,
    *,
    out_type: TableType | None = None,
    dialect: str = "duckdb",
    **in_tables: Any,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with optional schema deduction.

    API: run_sql(query: str, *, out_type: TableType | None = None, dialect: str = "duckdb", **in_tables) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Parses the SQL and binds only the tables that are actually referenced in the query by name.
    - Names defined by a WITH clause (CTEs) are not treated as required inputs unless the caller explicitly supplies a table with that name.
    - If ``out_type`` is not provided, attempts to deduce the output table schema using sqlglot (qualify + annotate types).
    - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.

    Difference vs ``run_sql_raw``: this op can infer ``out_type`` and will parse the SQL to filter inputs; ``run_sql_raw`` requires an explicit ``out_type`` and does not parse/filter inputs.

    Args:
        query: SQL text; stored verbatim as the PFunction's ``fn_text``.
        out_type: Explicit output schema. When None it is deduced.
        dialect: sqlglot dialect used for parsing and recorded on the PFunction.
        **in_tables: Input tables keyed by the name used in the query.

    Raises:
        ValueError: A supplied table is not referenced by the query.
        KeyError: A table referenced by the query was not supplied.
        TypeError: A supplied table is not an MPObject.
        NotImplementedError: Schema deduction was requested but not possible.
    """
    parsed = sg.parse_one(query, read=dialect)

    # sqlglot parses a reference to a CTE as an ordinary Table node, so collect
    # the CTE aliases to avoid demanding an input for a name the query itself
    # defines (e.g. "WITH t AS (...) SELECT * FROM t").
    cte_names = {cte.alias for cte in parsed.find_all(sge.CTE)}

    # Extract required table names from SQL (order by first appearance)
    required_names: list[str] = []
    seen: set[str] = set()
    for t in parsed.find_all(sge.Table):
        # Prefer .name; fallback to str(this) if needed
        tname = getattr(t, "name", None) or str(t.this)
        if tname in seen:
            continue
        # Skip a CTE name only when the caller did not pass a table under that
        # name; a caller-supplied table keeps the previous binding behaviour.
        if tname in cte_names and tname not in in_tables:
            continue
        seen.add(tname)
        required_names.append(tname)

    # Disallow extras not referenced by the query to avoid surprises
    extra = set(in_tables) - seen
    if extra:
        raise ValueError(
            f"Unexpected tables provided that are not referenced in SQL: {sorted(extra)}"
        )

    # Validate required tables and require MPObject for runtime registration
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    for name in required_names:
        if name not in in_tables:
            raise KeyError(f"Missing required table '{name}' for SQL query")
        obj = in_tables[name]
        if not isinstance(obj, MPObject):
            raise TypeError(
                f"Table '{name}' must be an MPObject (for runtime registration), got {type(obj).__name__}"
            )
        assert obj.schema is not None, f"Input table '{name}' missing schema"
        in_vars.append(obj)
        ins_info.append(obj.schema)
        in_names.append(name)

    if out_type is None:
        # in_names/ins_info were filled in lockstep above, one per required table.
        in_schemas: dict[str, TableType] = dict(zip(in_names, ins_info))
        out_type = _deduce_out_schema(parsed, dialect, in_schemas)

    pfn = PFunction(
        fn_type="sql.run",
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        fn_name="",
        fn_text=query,
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
258
+
259
+
260
@_SQL_MOD.op_def()
def run_sql_raw(
    query: str,
    out_type: TableType,
    *,
    dialect: str = "duckdb",
    in_tables: dict[str, MPObject] | None = None,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Wrap a SQL query in a ``sql.run`` PFunction using a caller-given schema.

    Unlike ``run_sql`` the query text is never parsed: every entry of
    ``in_tables`` is bound, in the mapping's iteration order, and ``out_type``
    is taken as-is with no schema deduction.

    Args:
        query: SQL text; stored verbatim as the PFunction's ``fn_text``.
        out_type: Schema of the query result.
        dialect: Dialect name recorded on the PFunction.
        in_tables: Optional mapping of table name to input MPObject.

    Returns:
        The PFunction, the input MPObjects in binding order, and the PyTreeDef
        obtained by flattening ``out_type``.

    Raises:
        TypeError: An entry of ``in_tables`` is not an MPObject.
    """
    # Validate each provided table, keeping the caller's order.
    bound: list[tuple[str, MPObject]] = []
    for name, tbl in (in_tables or {}).items():
        if not isinstance(tbl, MPObject):
            raise TypeError(f"Input table '{name}' is not an MPObject {type(tbl)}")
        assert tbl.schema is not None, f"Input table '{name}' is missing a schema"
        bound.append((name, tbl))

    in_vars: list[MPObject] = [tbl for _, tbl in bound]
    pfn = PFunction(
        fn_type="sql.run",
        fn_name="",
        fn_text=query,
        ins_info=tuple(tbl.schema for tbl in in_vars),
        outs_info=(out_type,),
        in_names=tuple(name for name, _ in bound),
        dialect=dialect,
    )
    treedef = tree_flatten(out_type)[1]
    return pfn, in_vars, treedef
@@ -14,8 +14,8 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from mplang.core import UINT8, TensorType
18
- from mplang.ops.base import stateless_mod
17
+ from mplang.v1.core import UINT8, TensorType
18
+ from mplang.v1.ops.base import stateless_mod
19
19
 
20
20
  _TEE_MOD = stateless_mod("tee")
21
21
 
@@ -20,8 +20,8 @@ This module contains runtime implementations including:
20
20
  - Driver for distributed execution
21
21
  """
22
22
 
23
- from mplang.runtime.driver import Driver, DriverVar
24
- from mplang.runtime.simulation import Simulator
23
+ from mplang.v1.runtime.driver import Driver, DriverVar
24
+ from mplang.v1.runtime.simulation import Simulator
25
25
 
26
26
  __all__ = [
27
27
  "Driver",
@@ -0,0 +1,230 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SPU IChannel implementation that bridges to MPLang CommunicatorBase.
16
+
17
+ This module provides BaseChannel, which allows SPU to reuse MPLang's
18
+ existing communication layer (ThreadCommunicator/HttpCommunicator) instead
19
+ of creating separate BRPC connections.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ from typing import TYPE_CHECKING
26
+
27
+ import spu.libspu as libspu
28
+
29
+ if TYPE_CHECKING:
30
+ from mplang.v1.core.comm import CommunicatorBase
31
+
32
+
33
class BaseChannel(libspu.link.IChannel):
    """Bridge MPLang CommunicatorBase to SPU IChannel interface.

    This adapter allows SPU to use MPLang's existing communication layer
    (ThreadCommunicator or HttpCommunicator) instead of creating separate
    BRPC connections.

    Each BaseChannel represents a channel to ONE peer rank.

    Communication Protocol:
        - SPU calls Send(tag, data) -> comm.send(peer_rank, key, data)
        - SPU calls Recv(tag) -> comm.recv(peer_rank, key) -> bytes

    Tag Namespace:
        Every tag is turned into the key "<tag_prefix>:<tag>" ("spu:<tag>" by
        default) to avoid collision with other MPLang traffic on the same
        communicator.

        NOTE(review): the HttpCommunicator.send change shipped in the same
        release only sends raw bytes when the key starts with "spu:"; with any
        other ``tag_prefix`` it raises TypeError for bytes payloads. Keep the
        default prefix when running over HTTP.
    """

    def __init__(
        self,
        comm: CommunicatorBase,
        local_rank: int,
        peer_rank: int,
        tag_prefix: str = "spu",
    ):
        """Initialize channel to a specific peer.

        Args:
            comm: MPLang communicator instance (Thread/Http)
            local_rank: Global rank of this party (only used in log messages)
            peer_rank: Global rank of the peer party
            tag_prefix: Prefix for all tags to avoid collision (default: "spu")
        """
        super().__init__()
        self._comm = comm
        self._local_rank = local_rank
        self._peer_rank = peer_rank
        self._tag_prefix = tag_prefix

        logging.debug(
            "BaseChannel initialized: local_rank=%s, peer_rank=%s, tag_prefix=%s",
            local_rank,
            peer_rank,
            tag_prefix,
        )

    def _make_key(self, tag: str) -> str:
        """Create the communicator key for an SPU tag.

        Args:
            tag: SPU-provided tag (e.g., "send_0", "recv_0")

        Returns:
            Prefixed key (e.g., "spu:send_0")
        """
        return f"{self._tag_prefix}:{tag}"

    def Send(self, tag: str, data: bytes) -> None:
        """Send bytes to the peer via the communicator's ``send``.

        Args:
            tag: Message tag for matching send/recv pairs
            data: Raw bytes to send
        """
        key = self._make_key(tag)
        # Lazy %-style arguments: this runs for every message, so the text is
        # only rendered when DEBUG logging is actually enabled.
        logging.debug(
            "BaseChannel.Send: %s -> %s, tag=%s, key=%s, size=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
            len(data),
        )

        # The payload is handed over unchanged (no Value envelope).
        self._comm.send(self._peer_rank, key, data)

    def Recv(self, tag: str) -> bytes:
        """Receive bytes from the peer via the communicator's ``recv``.

        Args:
            tag: Message tag for matching send/recv pairs

        Returns:
            Raw bytes received

        Raises:
            TypeError: If the communicator returns something other than bytes.
        """
        key = self._make_key(tag)
        logging.debug(
            "BaseChannel.Recv: %s <- %s, tag=%s, key=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
        )

        data = self._comm.recv(self._peer_rank, key)

        # SPU needs the exact bytes back; anything else means the communicator
        # wrapped or decoded the payload.
        if not isinstance(data, bytes):
            raise TypeError(
                f"Expected bytes from communicator, got {type(data).__name__}. "
                f"Communicator must support raw bytes transmission for SPU channels."
            )

        logging.debug(
            "BaseChannel.Recv complete: %s <- %s, tag=%s, size=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            len(data),
        )
        return data

    def SendAsync(self, tag: str, data: bytes) -> None:
        """Send bytes to the peer; implemented as a plain call to ``Send``.

        No background task is created here, so this returns only after the
        communicator's ``send`` returns.

        NOTE(review): the HttpCommunicator.send shown in the same release
        issues a synchronous ``httpx.put(..., timeout=60)``, so over HTTP this
        call blocks for the duration of the request.

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.Send(tag, data)

    def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
        """Throttled async send.

        Currently maps to regular SendAsync; no rate limiting is applied.

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.SendAsync(tag, data)

    def TestSend(self, timeout: int) -> None:
        """Send a dummy handshake message to the peer.

        The fixed tag "__test__" and fixed one-byte payload are used on every
        call, so the caller may retry this repeatedly.

        Args:
            timeout: Timeout in milliseconds. Currently ignored: it is not
                forwarded to the communicator.
        """
        test_data = b"\x00"  # Minimal 1-byte handshake payload
        self.Send("__test__", test_data)

    def TestRecv(self) -> None:
        """Wait for the dummy handshake message from the peer.

        An unexpected payload is only logged as a warning, not raised.

        NOTE(review): the original note says the timeout is controlled by
        recv_timeout_ms in the link descriptor; no timeout is passed to the
        communicator from here - confirm where it is enforced.
        """
        test_data = self.Recv("__test__")
        if test_data != b"\x00":
            logging.warning(
                "TestRecv: unexpected handshake data from %s, "
                "expected b'\\x00', got %r",
                self._peer_rank,
                test_data,
            )

    def WaitLinkTaskFinish(self) -> None:
        """Wait for all pending async tasks.

        This is a no-op: SendAsync delegates to Send, so this channel never
        has outstanding tasks of its own.
        """

    def Abort(self) -> None:
        """Abort communication.

        Currently only logs a warning; nothing is cancelled or cleaned up.
        This could be extended to notify the communicator to drop pending
        messages for this channel.
        """
        logging.warning(
            "BaseChannel.Abort called: %s <-> %s",
            self._local_rank,
            self._peer_rank,
        )
        # Future: Could call comm.abort_session() if implemented

    def SetThrottleWindowSize(self, size: int) -> None:
        """Set throttle window size.

        Not applicable to MPLang communicators. No-op.

        Args:
            size: Window size (ignored)
        """

    def SetChunkParallelSendSize(self, size: int) -> None:
        """Set chunk parallel send size.

        Not applicable to MPLang communicators. No-op.

        Args:
            size: Chunk size (ignored)
        """
@@ -26,9 +26,9 @@ from typing import Any
26
26
  import uvicorn
27
27
  import yaml
28
28
 
29
- from mplang.core import ClusterSpec
30
- from mplang.runtime.client import HttpExecutorClient
31
- from mplang.runtime.server import app
29
+ from mplang.v1.core import ClusterSpec
30
+ from mplang.v1.runtime.client import HttpExecutorClient
31
+ from mplang.v1.runtime.server import app
32
32
 
33
33
 
34
34
  def load_config(config_path: str) -> ClusterSpec:
@@ -27,7 +27,7 @@ from typing import Any
27
27
 
28
28
  import httpx
29
29
 
30
- from mplang.kernels.value import Value, decode_value, encode_value
30
+ from mplang.v1.kernels.value import Value, decode_value, encode_value
31
31
 
32
32
 
33
33
  class ExecutionStatus:
@@ -23,8 +23,8 @@ from typing import Any
23
23
 
24
24
  import httpx
25
25
 
26
- from mplang.core.comm import CommunicatorBase
27
- from mplang.kernels.value import Value, decode_value, encode_value
26
+ from mplang.v1.core.comm import CommunicatorBase
27
+ from mplang.v1.kernels.value import Value, decode_value, encode_value
28
28
 
29
29
 
30
30
  class HttpCommunicator(CommunicatorBase):
@@ -57,7 +57,12 @@ class HttpCommunicator(CommunicatorBase):
57
57
  return str(res)
58
58
 
59
59
  def send(self, to: int, key: str, data: Any) -> None:
60
- """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint."""
60
+ """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
61
+
62
+ Supports two modes:
63
+ - SPU channel (key starts with "spu:"): sends raw bytes directly
64
+ - Normal channel: wraps data in Value envelope
65
+ """
61
66
  target_endpoint = self.endpoints[to]
62
67
  url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}"
63
68
  logging.debug(
@@ -65,19 +70,20 @@ class HttpCommunicator(CommunicatorBase):
65
70
  )
66
71
 
67
72
  try:
68
- # Serialize data using Value envelope.
69
- if not isinstance(data, Value):
73
+ # SPU channel mode: send raw bytes directly
74
+ if key.startswith("spu:") and isinstance(data, bytes):
75
+ data_b64 = base64.b64encode(data).decode("utf-8")
76
+ request_data = {"data": data_b64, "is_raw_bytes": True}
77
+ # Normal mode: serialize using Value envelope
78
+ elif isinstance(data, Value):
79
+ data_bytes = encode_value(data)
80
+ data_b64 = base64.b64encode(data_bytes).decode("utf-8")
81
+ request_data = {"data": data_b64}
82
+ else:
70
83
  raise TypeError(
71
84
  f"Communicator requires Value instance, got {type(data).__name__}. "
72
85
  "Wrap data in TensorValue or custom Value subclass."
73
86
  )
74
- data_bytes = encode_value(data)
75
-
76
- data_b64 = base64.b64encode(data_bytes).decode("utf-8")
77
-
78
- request_data = {
79
- "data": data_b64,
80
- }
81
87
 
82
88
  response = httpx.put(url, json=request_data, timeout=60)
83
89
  logging.debug(f"Send response: status={response.status_code}")
@@ -91,14 +97,32 @@ class HttpCommunicator(CommunicatorBase):
91
97
  raise OSError(f"Failed to send data to rank {to}") from e
92
98
 
93
99
  def recv(self, frm: int, key: str) -> Any:
94
- """Wait until the key is set, returns the value. Override to add logging."""
100
+ """Wait until the key is set, returns the value.
101
+
102
+ Supports two modes:
103
+ - SPU channel (key starts with "spu:"): returns raw bytes
104
+ - Normal channel: returns deserialized Value
105
+ """
95
106
  logging.debug(
96
107
  f"Waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}"
97
108
  )
98
- data_b64 = super().recv(frm, key)
109
+ received_data = super().recv(frm, key)
99
110
 
111
+ # Check if this is raw bytes (SPU channel)
112
+ if isinstance(received_data, dict) and received_data.get("is_raw_bytes"):
113
+ data_bytes = base64.b64decode(received_data["data"])
114
+ logging.debug(
115
+ f"Received raw bytes: from_rank={frm}, to_rank={self._rank}, key={key}, size={len(data_bytes)}"
116
+ )
117
+ return data_bytes
118
+
119
+ # Normal mode: deserialize Value envelope
120
+ data_b64 = (
121
+ received_data
122
+ if isinstance(received_data, str)
123
+ else received_data.get("data")
124
+ )
100
125
  data_bytes = base64.b64decode(data_b64)
101
- # Deserialize using Value envelope
102
126
  result = decode_value(data_bytes)
103
127
 
104
128
  logging.debug(