mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v1/ops/sql_cc.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import sqlglot as sg
|
|
18
|
+
from jax.tree_util import PyTreeDef, tree_flatten
|
|
19
|
+
from sqlglot import exp as sge
|
|
20
|
+
from sqlglot.optimizer import annotate_types as opt_annot
|
|
21
|
+
from sqlglot.optimizer import qualify as opt_qualify
|
|
22
|
+
|
|
23
|
+
from mplang.v1.core import MPObject, PFunction, TableType
|
|
24
|
+
from mplang.v1.core.dtypes import (
|
|
25
|
+
BINARY,
|
|
26
|
+
BOOL,
|
|
27
|
+
DATE,
|
|
28
|
+
DECIMAL,
|
|
29
|
+
FLOAT32,
|
|
30
|
+
FLOAT64,
|
|
31
|
+
INT8,
|
|
32
|
+
INT16,
|
|
33
|
+
INT32,
|
|
34
|
+
INT64,
|
|
35
|
+
INTERVAL,
|
|
36
|
+
JSON,
|
|
37
|
+
STRING,
|
|
38
|
+
TIME,
|
|
39
|
+
TIMESTAMP,
|
|
40
|
+
UINT8,
|
|
41
|
+
UINT16,
|
|
42
|
+
UINT32,
|
|
43
|
+
UINT64,
|
|
44
|
+
UUID,
|
|
45
|
+
DType,
|
|
46
|
+
)
|
|
47
|
+
from mplang.v1.ops.base import stateless_mod
|
|
48
|
+
|
|
49
|
+
# Op module handle for the "sql" namespace, created via stateless_mod("sql").
# run_sql and run_sql_raw below register themselves through its op_def()
# decorator.
_SQL_MOD = stateless_mod("sql")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Static dtype mappings (MPLang <-> SQL)
#
# MP_TO_SQL_TYPE: MPLang DType -> SQL type string. It is used to describe the
# input tables to sqlglot during output-schema deduction; dtypes absent from
# this table are rendered as "VARCHAR" there.
MP_TO_SQL_TYPE: dict[DType, str] = {
    # Floats
    FLOAT64: "DOUBLE",
    FLOAT32: "FLOAT",
    # Signed ints
    INT8: "TINYINT",
    INT16: "SMALLINT",
    INT32: "INT",
    INT64: "BIGINT",
    # Unsigned ints (portable approximations): each one is widened to a signed
    # type large enough for its full range (UINT64 needs DECIMAL(38)). Because
    # SQL_TYPE_TO_MP maps these SQL names to signed/decimal dtypes, an unsigned
    # column passed straight through is presumably deduced as the wider signed
    # (or DECIMAL) dtype.
    UINT8: "SMALLINT",
    UINT16: "INT",
    UINT32: "BIGINT",
    UINT64: "DECIMAL(38)",
    # Booleans & strings
    BOOL: "BOOLEAN",
    STRING: "VARCHAR",
    # Dates / times
    DATE: "DATE",
    TIME: "TIME",
    TIMESTAMP: "TIMESTAMP",
    # Other table types
    DECIMAL: "DECIMAL",
    JSON: "JSON",
    BINARY: "BLOB",
    UUID: "UUID",
    INTERVAL: "INTERVAL",
}
|
|
81
|
+
|
|
82
|
+
# SQL_TYPE_TO_MP: lower-case SQL base type name -> MPLang DType.
#
# Keys are bare type names: output-schema deduction lower-cases the sqlglot
# type, drops a " with time zone" suffix and any parameter list (such as the
# "(10, 2)" of "decimal(10, 2)") before looking it up here. Names missing from
# this table are treated as STRING.
SQL_TYPE_TO_MP: dict[str, DType] = {
    # Floats
    "double": FLOAT64,
    "double precision": FLOAT64,
    "float": FLOAT32,
    "real": FLOAT32,
    # Signed ints
    "bigint": INT64,
    "long": INT64,
    "int": INT32,
    "integer": INT32,
    "int4": INT32,
    "smallint": INT16,
    "int2": INT16,
    "tinyint": INT8,
    "int1": INT8,
    # Unsigned (rare in SQL)
    "uint8": UINT8,
    "ubyte": UINT8,
    "uint16": UINT16,
    "uint32": UINT32,
    "uint64": UINT64,
    # DuckDB-style unsigned type names (UTINYINT, USMALLINT, UINTEGER/UINT,
    # UBIGINT), e.g. produced by explicit casts in the query. Without these
    # entries such columns would fall back to STRING.
    # NOTE(review): exact spellings follow sqlglot's rendering of these types;
    # verify against the pinned sqlglot version.
    "utinyint": UINT8,
    "usmallint": UINT16,
    "uint": UINT32,
    "uinteger": UINT32,
    "ubigint": UINT64,
    # Booleans / strings
    "bool": BOOL,
    "boolean": BOOL,
    "char": STRING,
    "varchar": STRING,
    "text": STRING,
    "string": STRING,
    # Dates / times
    "date": DATE,
    "time": TIME,
    "timestamp": TIMESTAMP,
    # Time-zone-aware / MySQL-style spellings. The DType does not carry a time
    # zone, which matches how the " with time zone" suffix is stripped.
    "timestamptz": TIMESTAMP,
    "timetz": TIME,
    "datetime": TIMESTAMP,
    # Decimal / numeric
    "decimal": DECIMAL,
    "numeric": DECIMAL,
    # Others
    "json": JSON,
    "binary": BINARY,
    "varbinary": BINARY,
    "blob": BINARY,
    "uuid": UUID,
    "interval": INTERVAL,
}
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _dtype_to_sql(dt: DType) -> str:
    """Render an MPLang dtype as a SQL type string (``VARCHAR`` if unmapped)."""
    return MP_TO_SQL_TYPE.get(dt, "VARCHAR")


def _sqlglot_type_to_dtype(tobj: Any) -> DType:
    """Map a sqlglot type annotation back to an MPLang dtype.

    The type is rendered with ``str()``, lower-cased, and stripped of a
    ``" with time zone"`` suffix and of its parameter list (e.g. the ``(10, 2)``
    of ``decimal(10, 2)``). The remaining base name is looked up in
    SQL_TYPE_TO_MP; unknown names map to STRING.
    """
    rendered = str(tobj).lower().replace(" with time zone", "").strip()
    base = rendered.split("(", 1)[0].strip()
    return SQL_TYPE_TO_MP.get(base, STRING)


def _output_select(expr: sge.Expression) -> sge.Select | None:
    """Return the SELECT whose projections name the query's output columns.

    For set operations (UNION / INTERSECT / EXCEPT), SQL takes the output
    column names from the leftmost branch. This helper therefore walks down the
    left operands (and any parenthesized subqueries). A breadth-first search
    would, for ``a UNION b UNION c`` (parsed as ``(a UNION b) UNION c``), reach
    ``c`` first.

    Returns None if no SELECT can be found.
    """
    # Newer sqlglot versions root all set operations at SetOperation; older
    # versions derive Except/Intersect from Union.
    set_ops = getattr(sge, "SetOperation", sge.Union)
    node: sge.Expression | None = expr
    while node is not None:
        if isinstance(node, sge.Select):
            return node
        if isinstance(node, (set_ops, sge.Subquery)):
            # `.this` is the left operand of a set operation, or the inner
            # query of a parenthesized subquery.
            node = node.this
            continue
        # Any other statement shape: fall back to the first SELECT `find` yields.
        return node.find(sge.Select)
    return None


def _deduce_out_schema(
    parsed: sge.Expression,
    dialect: str,
    in_schemas: dict[str, TableType],
) -> TableType:
    """Deduce output schema using sqlglot's qualify + annotate_types.

    The input table schemas are rendered as SQL types (MP_TO_SQL_TYPE). The
    query is then qualified (table/column resolution, star expansion) and
    type-annotated. Finally, the projection names and types are mapped back to
    MPLang dtypes (SQL_TYPE_TO_MP).

    Args:
        parsed: The parsed SQL statement, read with ``dialect``.
        dialect: sqlglot dialect name the query was parsed with.
        in_schemas: Mapping from input table name to its TableType.

    Returns:
        A TableType with one column per projection of the output SELECT.

    Raises:
        NotImplementedError: If no SELECT is found, or a projection carries no
            type annotation.
        ValueError: If two output columns end up with the same name.
    """
    # Local import: sqlglot is already a dependency of this module.
    from sqlglot.schema import MappingSchema

    # 1) Build ONE dialect-aware schema and share it between qualify and
    #    annotate_types. If the raw dict is handed to annotate_types, it builds
    #    a second, default-dialect schema. That schema's identifier
    #    normalization can disagree with the dialect-qualified query (e.g.
    #    dialects that upper-case unquoted names), so column lookups could miss
    #    and types would silently fall back to STRING.
    schema = MappingSchema(
        {
            tname: {col: _dtype_to_sql(dt) for col, dt in tschema.columns}
            for tname, tschema in in_schemas.items()
        },
        dialect=dialect,
    )

    # 2) Qualify (resolve names, expand star); 3) annotate types.
    qualified = opt_qualify.qualify(parsed, schema=schema, dialect=dialect)
    typed = opt_annot.annotate_types(qualified, schema=schema)

    # 4) Locate the SELECT that defines the output columns.
    select = _output_select(typed)
    if select is None:
        raise NotImplementedError(
            "Only SELECT queries are supported for schema deduction"
        )

    # 5) Extract projection names and types.
    pairs: list[tuple[str, DType]] = []
    used: set[str] = set()
    idx = 0
    for proj in select.expressions:
        name = getattr(proj, "alias_or_name", None) or getattr(proj, "name", None)
        if not name:
            # Synthesize a positional name for unnamed projections.
            name = f"expr_{idx}"
            idx += 1
        t = getattr(proj, "type", None)
        if t is None:
            raise NotImplementedError(
                "Cannot infer type for projection; please provide out_type explicitly"
            )
        dtype = _sqlglot_type_to_dtype(t)
        if name in used:
            raise ValueError(
                f"Duplicate output column name '{name}' after qualification"
            )
        used.add(name)
        pairs.append((name, dtype))

    return TableType.from_pairs(pairs)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@_SQL_MOD.op_def()
def run_sql(
    query: str,
    *,
    out_type: TableType | None = None,
    dialect: str = "duckdb",
    **in_tables: Any,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with optional schema deduction.

    API: run_sql(query: str, *, out_type: TableType | None = None, dialect: str = "duckdb", **in_tables) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Parses the SQL and binds only the tables that are actually referenced in the query by name.
      Names defined by the query's own ``WITH`` clauses (CTEs) are resolved inside the
      query and are not treated as inputs, unless a table of the same name is
      explicitly passed in ``in_tables``.
    - If ``out_type`` is not provided, attempts to deduce the output table schema using sqlglot (qualify + annotate types).
    - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.

    Difference vs ``run_sql_raw``: this op can infer ``out_type`` and will parse the SQL to filter inputs; ``run_sql_raw`` requires an explicit ``out_type`` and does not parse/filter inputs.

    Raises:
        ValueError: If tables not referenced by the query are provided, or an
            input table has no schema.
        KeyError: If a table referenced by the query is not provided.
        TypeError: If a provided table is not an MPObject.
    """
    parsed = sg.parse_one(query, read=dialect)

    # Names introduced by WITH clauses are resolved by the query itself.
    cte_names = {cte.alias for cte in parsed.find_all(sge.CTE)}

    # Extract required table names from SQL, deduplicated, in the order
    # `find_all` yields them.
    required_names: list[str] = []
    for t in parsed.find_all(sge.Table):
        # Prefer .name; fallback to str(this) if needed
        tname = getattr(t, "name", None) or str(t.this)
        # A reference to a CTE is not an external input. Exception: if the
        # caller explicitly bound a table under that name, it is still bound,
        # which keeps the previous behavior for that case.
        if tname in cte_names and tname not in in_tables:
            continue
        if tname not in required_names:
            required_names.append(tname)

    # Disallow extras not referenced by the query to avoid surprises
    extra = set(in_tables) - set(required_names)
    if extra:
        raise ValueError(
            f"Unexpected tables provided that are not referenced in SQL: {sorted(extra)}"
        )

    # Validate required tables and require MPObject for runtime registration
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    for name in required_names:
        if name not in in_tables:
            raise KeyError(f"Missing required table '{name}' for SQL query")
        obj = in_tables[name]
        if not isinstance(obj, MPObject):
            raise TypeError(
                f"Table '{name}' must be an MPObject (for runtime registration), got {type(obj).__name__}"
            )
        # Explicit check (not `assert`) so validation survives `python -O`.
        if obj.schema is None:
            raise ValueError(f"Input table '{name}' missing schema")
        in_vars.append(obj)
        ins_info.append(obj.schema)
        in_names.append(name)

    if out_type is None:
        # in_names mirrors required_names, so this maps every bound table to
        # the schema validated above.
        in_schemas: dict[str, TableType] = dict(zip(in_names, ins_info))
        out_type = _deduce_out_schema(parsed, dialect, in_schemas)

    pfn = PFunction(
        fn_type="sql.run",
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        fn_name="",
        fn_text=query,
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@_SQL_MOD.op_def()
def run_sql_raw(
    query: str,
    out_type: TableType,
    *,
    dialect: str = "duckdb",
    in_tables: dict[str, MPObject] | None = None,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with an explicit output schema.

    API: run_sql_raw(query: str, out_type: TableType, *, dialect: str = "duckdb", in_tables: dict[str, MPObject] | None = None) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Does not parse the SQL; carries all tables provided via ``in_tables`` in the mapping's iteration order.
    - Requires an explicit ``out_type``; no schema deduction is attempted.
    - Returns a triad consisting of the constructed PFunction (``fn_type='sql.run'``), the ordered list of input MPObjects, and the output PyTreeDef.

    Difference vs ``run_sql``: this op requires ``out_type`` and does not parse/filter inputs; ``run_sql`` can infer ``out_type`` and selects only tables referenced by the query.

    Raises:
        TypeError: If a value in ``in_tables`` is not an MPObject.
        ValueError: If an input table has no schema.
    """
    # Collect inputs strictly as provided by caller (None behaves like {}).
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    for name, tbl in (in_tables or {}).items():
        if not isinstance(tbl, MPObject):
            raise TypeError(f"Input table '{name}' is not an MPObject {type(tbl)}")
        # Explicit check (not `assert`) so validation survives `python -O`.
        if tbl.schema is None:
            raise ValueError(f"Input table '{name}' is missing a schema")
        in_names.append(name)
        ins_info.append(tbl.schema)
        in_vars.append(tbl)

    pfn = PFunction(
        fn_type="sql.run",
        fn_name="",
        fn_text=query,
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
|
mplang/{ops → v1/ops}/tee.py
RENAMED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
-
from mplang.core import UINT8, TensorType
|
|
18
|
-
from mplang.ops.base import stateless_mod
|
|
17
|
+
from mplang.v1.core import UINT8, TensorType
|
|
18
|
+
from mplang.v1.ops.base import stateless_mod
|
|
19
19
|
|
|
20
20
|
# Op module handle for the "tee" namespace, created via stateless_mod("tee");
# the TEE ops of this module are presumably registered through it (rest of the
# file not visible here).
_TEE_MOD = stateless_mod("tee")
|
|
21
21
|
|
|
@@ -20,8 +20,8 @@ This module contains runtime implementations including:
|
|
|
20
20
|
- Driver for distributed execution
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
|
-
from mplang.runtime.driver import Driver, DriverVar
|
|
24
|
-
from mplang.runtime.simulation import Simulator
|
|
23
|
+
from mplang.v1.runtime.driver import Driver, DriverVar
|
|
24
|
+
from mplang.v1.runtime.simulation import Simulator
|
|
25
25
|
|
|
26
26
|
__all__ = [
|
|
27
27
|
"Driver",
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""SPU IChannel implementation that bridges to MPLang CommunicatorBase.
|
|
16
|
+
|
|
17
|
+
This module provides BaseChannel, which allows SPU to reuse MPLang's
|
|
18
|
+
existing communication layer (ThreadCommunicator/HttpCommunicator) instead
|
|
19
|
+
of creating separate BRPC connections.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
from typing import TYPE_CHECKING
|
|
26
|
+
|
|
27
|
+
import spu.libspu as libspu
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseChannel(libspu.link.IChannel):
    """Bridge MPLang CommunicatorBase to SPU IChannel interface.

    This adapter allows SPU to use MPLang's existing communication layer
    (ThreadCommunicator or HttpCommunicator) instead of creating separate
    BRPC connections.

    Each BaseChannel represents a channel to ONE peer rank.

    Communication Protocol:
    - SPU calls Send(tag, bytes_data) -> MPLang comm.send(peer, key, bytes_data)
    - SPU calls Recv(tag) -> bytes_data <- MPLang comm.recv(peer, key)

    Tag Namespace:
        Every SPU tag is turned into the key ``"<tag_prefix>:<tag>"``
        (``"spu:<tag>"`` with the default prefix). This avoids collisions with
        other MPLang traffic on the same communicator.

    Logging uses lazy ``%``-style arguments, so the per-message debug lines on
    the Send/Recv path are only formatted when DEBUG logging is enabled.
    """

    def __init__(
        self,
        comm: CommunicatorBase,
        local_rank: int,
        peer_rank: int,
        tag_prefix: str = "spu",
    ):
        """Initialize channel to a specific peer.

        Args:
            comm: MPLang communicator instance (Thread/Http)
            local_rank: Global rank of this party (used in log messages)
            peer_rank: Global rank of the peer party
            tag_prefix: Prefix for all keys to avoid collision (default: "spu").
                NOTE(review): HttpCommunicator.send only takes its raw-bytes
                path for keys starting with the literal "spu:", so a
                non-default prefix may not carry bytes over HTTP. Verify before
                changing it.
        """
        super().__init__()
        self._comm = comm
        self._local_rank = local_rank
        self._peer_rank = peer_rank
        self._tag_prefix = tag_prefix

        logging.debug(
            "BaseChannel initialized: local_rank=%s, peer_rank=%s, tag_prefix=%s",
            local_rank,
            peer_rank,
            tag_prefix,
        )

    def _make_key(self, tag: str) -> str:
        """Return the communicator key for an SPU tag.

        Args:
            tag: SPU-provided tag (e.g., "send_0", "recv_0")

        Returns:
            ``"<tag_prefix>:<tag>"`` (e.g., "spu:send_0")
        """
        return f"{self._tag_prefix}:{tag}"

    def Send(self, tag: str, data: bytes) -> None:
        """Send bytes to the peer via the communicator.

        Args:
            tag: Message tag for matching send/recv pairs
            data: Raw bytes to send
        """
        key = self._make_key(tag)
        logging.debug(
            "BaseChannel.Send: %s -> %s, tag=%s, key=%s, size=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
            len(data),
        )
        # Raw bytes are handed to the communicator as-is; CommunicatorBase.send
        # accepts Any.
        self._comm.send(self._peer_rank, key, data)

    def Recv(self, tag: str) -> bytes:
        """Receive bytes from the peer, blocking in the communicator's recv.

        Args:
            tag: Message tag for matching send/recv pairs

        Returns:
            Raw bytes received

        Raises:
            TypeError: If the communicator delivers anything other than bytes.
        """
        key = self._make_key(tag)
        logging.debug(
            "BaseChannel.Recv: %s <- %s, tag=%s, key=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            key,
        )

        data = self._comm.recv(self._peer_rank, key)

        if not isinstance(data, bytes):
            raise TypeError(
                f"Expected bytes from communicator, got {type(data).__name__}. "
                "Communicator must support raw bytes transmission for SPU channels."
            )

        logging.debug(
            "BaseChannel.Recv complete: %s <- %s, tag=%s, size=%s",
            self._local_rank,
            self._peer_rank,
            tag,
            len(data),
        )
        return data

    def SendAsync(self, tag: str, data: bytes) -> None:
        """Async send; currently delegates to the blocking ``Send``.

        No background task is created. For HttpCommunicator, the send path
        performs a synchronous ``httpx.put`` before returning.

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.Send(tag, data)

    def SendAsyncThrottled(self, tag: str, data: bytes) -> None:
        """Throttled async send; currently identical to ``SendAsync``.

        No rate limiting is applied.

        Args:
            tag: Message tag
            data: Raw bytes to send
        """
        self.SendAsync(tag, data)

    def TestSend(self, timeout: int) -> None:
        """Send a one-byte handshake (``b"\\x00"``) to the peer.

        A fixed tag ("__test__") and payload are used every time, so repeated
        calls send the same message. Per the IChannel contract, ConnectToMesh
        may retry this call.

        Args:
            timeout: Timeout in milliseconds (not used by this implementation)
        """
        self.Send("__test__", b"\x00")

    def TestRecv(self) -> None:
        """Wait for the peer's handshake message sent by ``TestSend``.

        Blocks in the communicator's recv; no timeout is applied here. An
        unexpected payload is logged as a warning rather than raised.
        """
        test_data = self.Recv("__test__")
        if test_data != b"\x00":
            logging.warning(
                "TestRecv: unexpected handshake data from %s, "
                "expected b'\\x00', got %r",
                self._peer_rank,
                test_data,
            )

    def WaitLinkTaskFinish(self) -> None:
        """Wait for all pending async tasks (no-op).

        ``SendAsync`` delegates to the blocking ``Send``, so no send is ever
        left outstanding.
        """

    def Abort(self) -> None:
        """Abort communication.

        Only logs a warning. Pending messages on the communicator are not
        dropped (a future version could call into the communicator for that).
        """
        logging.warning(
            "BaseChannel.Abort called: %s <-> %s",
            self._local_rank,
            self._peer_rank,
        )

    def SetThrottleWindowSize(self, size: int) -> None:
        """Set throttle window size.

        Not applicable to MPLang communicators. No-op.

        Args:
            size: Window size (ignored)
        """

    def SetChunkParallelSendSize(self, size: int) -> None:
        """Set chunk parallel send size.

        Not applicable to MPLang communicators. No-op.

        Args:
            size: Chunk size (ignored)
        """
|
|
@@ -26,9 +26,9 @@ from typing import Any
|
|
|
26
26
|
import uvicorn
|
|
27
27
|
import yaml
|
|
28
28
|
|
|
29
|
-
from mplang.core import ClusterSpec
|
|
30
|
-
from mplang.runtime.client import HttpExecutorClient
|
|
31
|
-
from mplang.runtime.server import app
|
|
29
|
+
from mplang.v1.core import ClusterSpec
|
|
30
|
+
from mplang.v1.runtime.client import HttpExecutorClient
|
|
31
|
+
from mplang.v1.runtime.server import app
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def load_config(config_path: str) -> ClusterSpec:
|
|
@@ -23,8 +23,8 @@ from typing import Any
|
|
|
23
23
|
|
|
24
24
|
import httpx
|
|
25
25
|
|
|
26
|
-
from mplang.core.comm import CommunicatorBase
|
|
27
|
-
from mplang.kernels.value import Value, decode_value, encode_value
|
|
26
|
+
from mplang.v1.core.comm import CommunicatorBase
|
|
27
|
+
from mplang.v1.kernels.value import Value, decode_value, encode_value
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class HttpCommunicator(CommunicatorBase):
|
|
@@ -57,7 +57,12 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
57
57
|
return str(res)
|
|
58
58
|
|
|
59
59
|
def send(self, to: int, key: str, data: Any) -> None:
|
|
60
|
-
"""Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
|
|
60
|
+
"""Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.
|
|
61
|
+
|
|
62
|
+
Supports two modes:
|
|
63
|
+
- SPU channel (key starts with "spu:"): sends raw bytes directly
|
|
64
|
+
- Normal channel: wraps data in Value envelope
|
|
65
|
+
"""
|
|
61
66
|
target_endpoint = self.endpoints[to]
|
|
62
67
|
url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}"
|
|
63
68
|
logging.debug(
|
|
@@ -65,19 +70,20 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
65
70
|
)
|
|
66
71
|
|
|
67
72
|
try:
|
|
68
|
-
#
|
|
69
|
-
if
|
|
73
|
+
# SPU channel mode: send raw bytes directly
|
|
74
|
+
if key.startswith("spu:") and isinstance(data, bytes):
|
|
75
|
+
data_b64 = base64.b64encode(data).decode("utf-8")
|
|
76
|
+
request_data = {"data": data_b64, "is_raw_bytes": True}
|
|
77
|
+
# Normal mode: serialize using Value envelope
|
|
78
|
+
elif isinstance(data, Value):
|
|
79
|
+
data_bytes = encode_value(data)
|
|
80
|
+
data_b64 = base64.b64encode(data_bytes).decode("utf-8")
|
|
81
|
+
request_data = {"data": data_b64}
|
|
82
|
+
else:
|
|
70
83
|
raise TypeError(
|
|
71
84
|
f"Communicator requires Value instance, got {type(data).__name__}. "
|
|
72
85
|
"Wrap data in TensorValue or custom Value subclass."
|
|
73
86
|
)
|
|
74
|
-
data_bytes = encode_value(data)
|
|
75
|
-
|
|
76
|
-
data_b64 = base64.b64encode(data_bytes).decode("utf-8")
|
|
77
|
-
|
|
78
|
-
request_data = {
|
|
79
|
-
"data": data_b64,
|
|
80
|
-
}
|
|
81
87
|
|
|
82
88
|
response = httpx.put(url, json=request_data, timeout=60)
|
|
83
89
|
logging.debug(f"Send response: status={response.status_code}")
|
|
@@ -91,14 +97,32 @@ class HttpCommunicator(CommunicatorBase):
|
|
|
91
97
|
raise OSError(f"Failed to send data to rank {to}") from e
|
|
92
98
|
|
|
93
99
|
def recv(self, frm: int, key: str) -> Any:
|
|
94
|
-
"""Wait until the key is set, returns the value.
|
|
100
|
+
"""Wait until the key is set, returns the value.
|
|
101
|
+
|
|
102
|
+
Supports two modes:
|
|
103
|
+
- SPU channel (key starts with "spu:"): returns raw bytes
|
|
104
|
+
- Normal channel: returns deserialized Value
|
|
105
|
+
"""
|
|
95
106
|
logging.debug(
|
|
96
107
|
f"Waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}"
|
|
97
108
|
)
|
|
98
|
-
|
|
109
|
+
received_data = super().recv(frm, key)
|
|
99
110
|
|
|
111
|
+
# Check if this is raw bytes (SPU channel)
|
|
112
|
+
if isinstance(received_data, dict) and received_data.get("is_raw_bytes"):
|
|
113
|
+
data_bytes = base64.b64decode(received_data["data"])
|
|
114
|
+
logging.debug(
|
|
115
|
+
f"Received raw bytes: from_rank={frm}, to_rank={self._rank}, key={key}, size={len(data_bytes)}"
|
|
116
|
+
)
|
|
117
|
+
return data_bytes
|
|
118
|
+
|
|
119
|
+
# Normal mode: deserialize Value envelope
|
|
120
|
+
data_b64 = (
|
|
121
|
+
received_data
|
|
122
|
+
if isinstance(received_data, str)
|
|
123
|
+
else received_data.get("data")
|
|
124
|
+
)
|
|
100
125
|
data_bytes = base64.b64decode(data_b64)
|
|
101
|
-
# Deserialize using Value envelope
|
|
102
126
|
result = decode_value(data_bytes)
|
|
103
127
|
|
|
104
128
|
logging.debug(
|