mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/device.py
DELETED
|
@@ -1,327 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
This module provides the device oriented programming interface for MPC.
|
|
17
|
-
|
|
18
|
-
The device oriented programming interface is designed to provide a high-level
|
|
19
|
-
abstraction for MPC programming. It allows the user to write the program
|
|
20
|
-
in a device-oriented manner, and the runtime will take care of the data
|
|
21
|
-
transformation between devices.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
from __future__ import annotations
|
|
25
|
-
|
|
26
|
-
from collections.abc import Callable
|
|
27
|
-
from functools import partial, wraps
|
|
28
|
-
from typing import Any
|
|
29
|
-
|
|
30
|
-
from jax.tree_util import tree_map
|
|
31
|
-
|
|
32
|
-
import mplang.host as mphost
|
|
33
|
-
from mplang.core import (
|
|
34
|
-
ClusterSpec,
|
|
35
|
-
Device,
|
|
36
|
-
InterpContext,
|
|
37
|
-
MPObject,
|
|
38
|
-
TensorType,
|
|
39
|
-
cur_ctx,
|
|
40
|
-
primitive,
|
|
41
|
-
)
|
|
42
|
-
from mplang.ops import basic, crypto, ibis_cc, jax_cc, tee
|
|
43
|
-
from mplang.ops.base import FeOperation
|
|
44
|
-
from mplang.ops.ibis_cc import IbisRunner
|
|
45
|
-
from mplang.ops.jax_cc import JaxRunner
|
|
46
|
-
from mplang.simp import mpi, smpc
|
|
47
|
-
from mplang.simp.api import run_at
|
|
48
|
-
|
|
49
|
-
# When True, `_device_run` moves every MPObject argument onto the target
# device (via `_d2d`) before running the device function.
g_auto_trans: bool = True

# HKDF "info" string used by both ends when deriving PPU<->TEE session keys.
_HKDF_INFO_LITERAL: str = "mplang/device/tee/v1"
# KEM suite used for TEE session establishment.
# TODO: make this configurable via ClusterSpec.
_TEE_KEM_SUITE: str = "x25519"

# The `function` decorator can also compile device-level APIs.
function = primitive.function

# Key in `MPObject.attrs` under which the owning device id is recorded.
DEVICE_ATTR_NAME: str = "_devid_"
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def _is_device_obj(obj: Any) -> bool:
    """Return True if *obj* is an MPObject carrying a device-id attribute."""
    return isinstance(obj, MPObject) and DEVICE_ATTR_NAME in obj.attrs
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def _set_devid(obj: MPObject, dev_id: str) -> MPObject:
    """Record *dev_id* on *obj* (mutating its attrs) and return *obj*.

    Raises:
        TypeError: if *obj* is not an MPObject.
    """
    if isinstance(obj, MPObject):
        obj.attrs[DEVICE_ATTR_NAME] = dev_id
        return obj
    raise TypeError(f"Input must be an instance of Object, {obj}")
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def _get_devid(obj: MPObject) -> str:
    """Return the device id previously recorded on *obj* by `_set_devid`.

    Raises:
        TypeError: if *obj* is not an MPObject.

    NOTE(review): the attrs lookup presumably raises KeyError when no device
    id has been recorded; confirm the attrs container type.
    """
    if isinstance(obj, MPObject):
        dev_id: str = obj.attrs[DEVICE_ATTR_NAME]
        return dev_id
    raise TypeError("Input must be an instance of Object")
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def _is_mpobj(x: Any) -> bool:
    """Return True if *x* is an MPObject.

    Defined with ``def`` rather than a lambda assignment (PEP 8 / E731) so it
    has a proper name in tracebacks and can carry a docstring.
    """
    return isinstance(x, MPObject)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
def _device_run_spu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run a JAX function on an SPU device through `smpc.srun`.

    ``args[0]`` is the function to run; the remaining positional arguments
    and all keyword arguments are forwarded to it. Every leaf of the result
    is tagged with the SPU device's name.

    Raises:
        ValueError: if *op* is not a JaxRunner.
    """
    if isinstance(op, JaxRunner):
        fn, *fn_args = args
        outputs = smpc.srun(fn)(*fn_args, **kwargs)
        tag = partial(_set_devid, dev_id=dev_info.name)
        return tree_map(tag, outputs)
    raise ValueError("SPU device only supports JAX frontend.")
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def _device_run_tee(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run a JAX or Ibis frontend operation on a single-member TEE device.

    The operation is executed at the TEE member's rank via `run_at`; every
    leaf of the result is tagged with the TEE device's name.

    Raises:
        ValueError: if *op* is neither a JaxRunner nor an IbisRunner, or the
            TEE device does not have exactly one member.
    """
    if not isinstance(op, (JaxRunner, IbisRunner)):
        raise ValueError("TEE device only supports JAX and Ibis frontend.")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(dev_info.members) != 1:
        raise ValueError(
            f"TEE device {dev_info.name} must have exactly one member, "
            f"got {len(dev_info.members)}."
        )
    rank = dev_info.members[0].rank
    var = run_at(rank, op, *args, **kwargs)
    return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def _device_run_ppu(
    dev_info: Device, op: FeOperation, *args: Any, **kwargs: Any
) -> Any:
    """Run frontend operation *op* on a single-member PPU device.

    The operation is executed at the PPU member's rank via `run_at`; every
    leaf of the result is tagged with the PPU device's name.

    Raises:
        ValueError: if the PPU device does not have exactly one member.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(dev_info.members) != 1:
        raise ValueError(
            f"PPU device {dev_info.name} must have exactly one member, "
            f"got {len(dev_info.members)}."
        )
    rank = dev_info.members[0].rank
    var = run_at(rank, op, *args, **kwargs)
    return tree_map(partial(_set_devid, dev_id=dev_info.name), var)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def _device_run(dev_id: str, op: FeOperation, *args: Any, **kwargs: Any) -> Any:
    """Run frontend operation *op* on device *dev_id* of the current cluster.

    When ``g_auto_trans`` is True, every MPObject leaf in ``(args, kwargs)``
    is first transferred to *dev_id* with `_d2d`; other leaves pass through.
    Execution is then dispatched on the device kind (SPU, TEE or PPU).

    Raises:
        TypeError: if *op* is not a FeOperation.
        ValueError: if *dev_id* is not in the cluster spec, an MPObject
            argument carries no device id, or the device kind is unknown.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(op, FeOperation):
        raise TypeError(f"Expected a FeOperation, got {type(op).__name__}")
    cluster_spec = mphost.cur_ctx().cluster_spec
    if dev_id not in cluster_spec.devices:
        raise ValueError(f"Device {dev_id} not found in cluster spec.")

    if g_auto_trans:

        def trans(obj: Any) -> Any:
            if not _is_mpobj(obj):
                return obj
            if not _is_device_obj(obj):
                raise ValueError(
                    f"Argument {obj} is not bound to any device; "
                    "place it on a device first (e.g. with `put`)."
                )
            return _d2d(dev_id, obj)

        args, kwargs = tree_map(trans, (args, kwargs))

    dev_info = cluster_spec.devices[dev_id]
    kind = dev_info.kind.upper()
    if kind == "SPU":
        return _device_run_spu(dev_info, op, *args, **kwargs)
    if kind == "TEE":
        return _device_run_tee(dev_info, op, *args, **kwargs)
    if kind == "PPU":
        return _device_run_ppu(dev_info, op, *args, **kwargs)
    raise ValueError(f"Unknown device type: {dev_info.kind}")
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
def device(dev_id: str, *, fe_type: str = "jax") -> Callable:
    """Decorator that makes a function execute on device *dev_id*.

    Args:
        dev_id: The device id (a key of the cluster spec's devices).
        fe_type: The frontend used to trace the decorated function, either
            "jax" or "ibis". Ignored when the decorated object is already a
            FeOperation. An unsupported value raises ValueError when the
            wrapped function is called, not at decoration time.

    Example:
        >>> @device("P0")
        ... def foo(x, y):
        ...     return x + y
    """

    def deco(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            if isinstance(fn, FeOperation):
                return _device_run(dev_id, fn, *args, **kwargs)
            if fe_type == "jax":
                runner = jax_cc.run_jax
            elif fe_type == "ibis":
                runner = ibis_cc.run_ibis
            else:
                raise ValueError(f"Unsupported frontend type: {fe_type}")
            # The plain function becomes the runner's first argument.
            return _device_run(dev_id, runner, fn, *args, **kwargs)

        return wrapped

    return deco
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def _d2d(to_dev_id: str, obj: MPObject) -> MPObject:
    """Transfer *obj* from the device it is tagged with to device *to_dev_id*.

    Supported (source kind, target kind) pairs:
      * SPU -> PPU: `smpc.revealTo` the PPU member's rank.
      * PPU -> SPU: `smpc.sealFrom` the PPU member's rank.
      * PPU -> PPU: plain `mpi.p2p` between the two ranks.
      * PPU -> TEE and TEE -> PPU: pack to bytes, encrypt with the pair's
        session key, p2p, decrypt and unpack on the receiver. Session keys
        come from `_ensure_tee_session`, which caches them per pair.

    Returns *obj* itself when it already lives on *to_dev_id*; otherwise the
    transferred result, with every leaf tagged with *to_dev_id*.

    Raises:
        TypeError: if *obj* is not an MPObject.
        NotImplementedError: for SPU -> SPU transfers.
        ValueError: if a device involved in the transfer does not have
            exactly one member, or the device pair is unsupported.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(obj, MPObject):
        raise TypeError(f"Expected an MPObject, got {type(obj).__name__}")
    frm_dev_id = _get_devid(obj)

    if frm_dev_id == to_dev_id:
        return obj

    cluster_spec: ClusterSpec = mphost.cur_ctx().cluster_spec
    frm_dev = cluster_spec.devices[frm_dev_id]
    to_dev = cluster_spec.devices[to_dev_id]
    frm_to_pair = (frm_dev.kind.upper(), to_dev.kind.upper())

    def sole_rank(dev: Device, dev_id: str) -> int:
        # Rank of a device that must consist of exactly one member.
        if len(dev.members) != 1:
            raise ValueError(
                f"Device {dev_id} ({dev.kind}) must have exactly one member "
                f"for this transfer, got {len(dev.members)}."
            )
        return dev.members[0].rank  # type: ignore[no-any-return]

    def tag(value: Any) -> MPObject:
        # Mark every leaf of the transferred value as living on the target.
        return tree_map(partial(_set_devid, dev_id=to_dev_id), value)  # type: ignore[no-any-return]

    def secure_send(src: int, dst: int, src_key: MPObject, dst_key: MPObject) -> Any:
        # Bytes-only path: pack -> enc -> p2p -> dec -> unpack (static out type).
        obj_ty = TensorType.from_obj(obj)
        packed = run_at(src, basic.pack, obj)
        ct = run_at(src, crypto.enc, packed, src_key)
        ct_at_dst = mpi.p2p(src, dst, ct)
        packed_at_dst = run_at(dst, crypto.dec, ct_at_dst, dst_key)
        return run_at(dst, basic.unpack, packed_at_dst, out_ty=obj_ty)

    if frm_to_pair == ("SPU", "SPU"):
        raise NotImplementedError("Only one SPU is supported for now.")
    if frm_to_pair == ("SPU", "PPU"):
        to_rank = sole_rank(to_dev, to_dev_id)
        return tag(smpc.revealTo(obj, to_rank))
    if frm_to_pair == ("PPU", "SPU"):
        frm_rank = sole_rank(frm_dev, frm_dev_id)
        return tag(smpc.sealFrom(obj, frm_rank))
    if frm_to_pair == ("PPU", "PPU"):
        frm_rank = sole_rank(frm_dev, frm_dev_id)
        to_rank = sole_rank(to_dev, to_dev_id)
        return tag(mpi.p2p(frm_rank, to_rank, obj))
    if frm_to_pair == ("PPU", "TEE"):
        ppu_rank = sole_rank(frm_dev, frm_dev_id)
        tee_rank = sole_rank(to_dev, to_dev_id)
        # Handshake runs on the first transfer for this pair; cached after.
        sess_p, sess_t = _ensure_tee_session(frm_dev_id, to_dev_id, ppu_rank, tee_rank)
        return tag(secure_send(ppu_rank, tee_rank, sess_p, sess_t))
    if frm_to_pair == ("TEE", "PPU"):
        tee_rank = sole_rank(frm_dev, frm_dev_id)
        ppu_rank = sole_rank(to_dev, to_dev_id)
        # Same (PPU, TEE) cache key as the forward direction, so the session
        # established for PPU -> TEE is reused here (and vice versa).
        sess_p, sess_t = _ensure_tee_session(to_dev_id, frm_dev_id, ppu_rank, tee_rank)
        return tag(secure_send(tee_rank, ppu_rank, sess_t, sess_p))

    supported = [
        ("SPU", "PPU"),
        ("PPU", "SPU"),
        ("PPU", "PPU"),
        ("PPU", "TEE"),
        ("TEE", "PPU"),
    ]
    raise ValueError(
        f"Unsupported device transfer: {frm_to_pair}. Supported pairs: {supported}."
    )
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
def _ensure_tee_session(
    frm_dev_id: str, to_dev_id: str, frm_rank: int, tee_rank: int
) -> tuple[MPObject, MPObject]:
    """Return the (sess_p, sess_t) session keys for a sender <-> TEE pair.

    ``sess_p`` lives at the sender (*frm_rank*), ``sess_t`` at the TEE
    (*tee_rank*). The keys are cached on the root context under
    ``(frm_dev_id, to_dev_id)``, so the handshake below is traced only once
    per pair:

    1. The TEE generates a KEM keypair and a quote over its public key.
    2. The quote is sent to the sender, which attests it to obtain the TEE
       public key.
    3. The sender generates its own KEM keypair and sends its public key to
       the TEE.
    4. Each side derives the shared secret with `crypto.kem_derive` and a
       session key with `crypto.hkdf` using the fixed info string.
    """
    ctx = cur_ctx().root()
    if not hasattr(ctx, "_tee_sessions"):
        ctx._tee_sessions = {}  # type: ignore[attr-defined]
    cache: dict[tuple[str, str], tuple[MPObject, MPObject]] = ctx._tee_sessions  # type: ignore

    key = (frm_dev_id, to_dev_id)
    cached = cache.get(key)
    if cached is not None:
        return cached

    # `suite` and `info` are declared keyword-only by the crypto ops
    # (`kem_keygen(*, suite)`, `kem_derive(sk, peer_pk, *, suite)`,
    # `hkdf(secret, *, info)`), so they are passed by keyword here.
    # TODO: read the KEM suite from the TEE device config once available.

    # 1) TEE generates (sk, pk) and quote(pk).
    tee_sk, tee_pk = run_at(tee_rank, crypto.kem_keygen, suite=_TEE_KEM_SUITE)
    quote = run_at(tee_rank, tee.quote_gen, tee_pk)

    # 2) Send the quote to the sender and attest it to obtain the TEE pk.
    quote_at_sender = mpi.p2p(tee_rank, frm_rank, quote)
    tee_pk_at_sender = run_at(frm_rank, tee.attest, quote_at_sender)

    # 3) Sender generates its ephemeral keypair and sends its pk to the TEE.
    v_sk, v_pk = run_at(frm_rank, crypto.kem_keygen, suite=_TEE_KEM_SUITE)
    v_pk_at_tee = mpi.p2p(frm_rank, tee_rank, v_pk)

    # 4) Both sides derive the shared secret, then the session key.
    shared_p = run_at(
        frm_rank, crypto.kem_derive, v_sk, tee_pk_at_sender, suite=_TEE_KEM_SUITE
    )
    shared_t = run_at(
        tee_rank, crypto.kem_derive, tee_sk, v_pk_at_tee, suite=_TEE_KEM_SUITE
    )
    # Fixed ASCII info string, identical on both sides.
    sess_p = run_at(frm_rank, crypto.hkdf, shared_p, info=_HKDF_INFO_LITERAL)
    sess_t = run_at(tee_rank, crypto.hkdf, shared_t, info=_HKDF_INFO_LITERAL)

    cache[key] = (sess_p, sess_t)
    return sess_p, sess_t
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
def put(to_dev_id: str, obj: Any) -> MPObject:
    """Place *obj* on device *to_dev_id*.

    An MPObject is moved with `_d2d`. Any other value is materialized on the
    device by running an identity function there through `device`.
    """
    if isinstance(obj, MPObject):
        return _d2d(to_dev_id, obj)
    identity_on_device = device(to_dev_id)(lambda x: x)
    return identity_on_device(obj)  # type: ignore[no-any-return]
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
def _fetch(interp: InterpContext, obj: MPObject) -> Any:
    """Fetch the concrete value of the device object *obj* through *interp*.

    SPU objects are revealed first, after which every member holds the same
    value. PPU and TEE objects are fetched directly. In all cases the value
    held by the device's first member rank is returned.

    Raises:
        ValueError: if a PPU/TEE device does not have exactly one member, or
            the device kind is not SPU, PPU or TEE.
    """
    dev_id = _get_devid(obj)
    dev_info = interp.cluster_spec.devices[dev_id]
    dev_kind = dev_info.kind.upper()

    if dev_kind == "SPU":
        revealed = mphost.evaluate(interp, smpc.reveal, obj)
        result = mphost.fetch(interp, revealed)
        # All members now hold the same value; take the first member's copy.
        return result[dev_info.members[0].rank]
    if dev_kind in ("PPU", "TEE"):
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(dev_info.members) != 1:
            raise ValueError(
                f"{dev_kind} device {dev_id} must have exactly one member, "
                f"got {len(dev_info.members)}."
            )
        rank = dev_info.members[0].rank
        result = mphost.fetch(interp, obj)
        return result[rank]
    raise ValueError(f"Unknown device kind {dev_info.kind!r} for device id: {dev_id}")
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
def fetch(interp: InterpContext | None, objs: Any) -> Any:
    """Fetch concrete values for every device object in the pytree *objs*.

    Args:
        interp: The interpreter context to fetch from. When None, the
            current host context (`mphost.cur_ctx()`) is used.
        objs: A pytree whose leaves are device-tagged MPObjects.

    Returns:
        A pytree of the same structure holding the fetched values.

    Raises:
        TypeError: if the resolved context is not an InterpContext.
    """
    # `is None` rather than truthiness, so a context object defining
    # __bool__/__len__ is never silently replaced.
    ctx = mphost.cur_ctx() if interp is None else interp
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(ctx, InterpContext):
        raise TypeError(f"Expect InterpContext, got {ctx}")
    return tree_map(partial(_fetch, ctx), objs)
|
mplang/ops/crypto.py
DELETED
|
@@ -1,108 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""
|
|
16
|
-
Crypto frontend operations: operation signatures, types, and high-level semantics.
|
|
17
|
-
|
|
18
|
-
Scope and contracts:
|
|
19
|
-
- This module defines portable API shapes; it does not implement cryptography.
|
|
20
|
-
- Backends execute the operations and must meet the security semantics required
|
|
21
|
-
by the deployment (confidentiality, authenticity, correctness, etc.).
|
|
22
|
-
- The enc/dec API in this frontend uses a conventional 12-byte nonce prefix
|
|
23
|
-
(ciphertext = nonce || payload), and dec expects that format. Other security
|
|
24
|
-
properties (e.g., AEAD) are backend responsibilities.
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
from __future__ import annotations
|
|
28
|
-
|
|
29
|
-
from mplang.core import UINT8, TensorType
|
|
30
|
-
from mplang.ops.base import stateless_mod
|
|
31
|
-
|
|
32
|
-
# Stateless frontend module that owns all crypto ops defined below.
_CRYPTO_MOD = stateless_mod("crypto")


@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
    """Type rule for generating random bytes (symmetric keys or randomness).

    API: keygen(length: int = 32) -> key: u8[length]

    Only the output type/shape is defined here; the backend supplies the
    actual random bytes.

    Raises:
        ValueError: if ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("length must be > 0")
    key_shape = (length,)
    return TensorType(UINT8, key_shape)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@_CRYPTO_MOD.simple_op()
def enc(plaintext: TensorType, key: TensorType) -> TensorType:
    """Type rule for symmetric encryption.

    API: enc(plaintext: u8[N], key: u8[M]) -> ciphertext: u8[N + 12]

    The 12 extra bytes are the nonce prefix (ciphertext = nonce || payload).
    A negative (dynamic) plaintext length yields a ciphertext length of -1.

    Raises:
        TypeError: if the plaintext is not a 1-D UINT8 tensor.
    """
    if plaintext.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(plaintext.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")
    (pt_len,) = plaintext.shape
    ct_len = pt_len + 12 if pt_len >= 0 else -1
    return TensorType(UINT8, (ct_len,))
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@_CRYPTO_MOD.simple_op()
def dec(ciphertext: TensorType, key: TensorType) -> TensorType:
    """Type rule for symmetric decryption.

    API: dec(ciphertext: u8[N + 12], key: u8[M]) -> plaintext: u8[N]

    The ciphertext must carry the 12-byte nonce prefix. A negative (dynamic)
    ciphertext length yields a plaintext length of -1.

    Raises:
        TypeError: if the ciphertext is not a 1-D UINT8 tensor, or its static
            length is shorter than the nonce.
    """
    if ciphertext.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ciphertext.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext with nonce")
    (ct_len,) = ciphertext.shape
    if ct_len < 0:
        return TensorType(UINT8, (-1,))
    if ct_len < 12:
        raise TypeError("dec expects 1-D ciphertext with nonce")
    return TensorType(UINT8, (ct_len - 12,))
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
@_CRYPTO_MOD.simple_op()
def kem_keygen(*, suite: str = "x25519") -> tuple[TensorType, TensorType]:
    """Type rule for KEM keypair generation.

    Returns the types of ``(sk, pk)``: two 32-byte UINT8 tensors. ``suite``
    does not affect the types.
    """
    return TensorType(UINT8, (32,)), TensorType(UINT8, (32,))
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
@_CRYPTO_MOD.simple_op()
def kem_derive(
    sk: TensorType, peer_pk: TensorType, *, suite: str = "x25519"
) -> TensorType:
    """Type rule for KEM shared-secret derivation.

    Returns the type of the shared secret: a 32-byte UINT8 tensor. The input
    types are not inspected.
    """
    del sk, peer_pk  # the output type is fixed regardless of the inputs
    return TensorType(UINT8, (32,))
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@_CRYPTO_MOD.simple_op()
def hkdf(secret: TensorType, *, info: str) -> TensorType:
    """Type rule for HKDF-style key derivation.

    Returns the type of the derived key: a 32-byte UINT8 tensor. The secret's
    type is not inspected.
    """
    del secret  # the output type is fixed regardless of the input
    return TensorType(UINT8, (32,))
|
mplang/ops/ibis_cc.py
DELETED
|
@@ -1,136 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
import inspect
|
|
17
|
-
from collections.abc import Callable
|
|
18
|
-
from typing import Any
|
|
19
|
-
|
|
20
|
-
import ibis
|
|
21
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
22
|
-
|
|
23
|
-
from mplang.core import MPObject, PFunction, TableType, dtypes
|
|
24
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
25
|
-
from mplang.utils.func_utils import normalize_fn
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def ibis2sql(
    expr: ibis.Table,
    in_schemas: list[ibis.Schema],
    in_names: list[str],
    fn_name: str = "",
) -> PFunction:
    """Compile an Ibis table expression to DuckDB SQL wrapped in a PFunction.

    Args:
        expr: The Ibis table expression to compile.
        in_schemas: Schemas of the input tables, in the same order as
            *in_names*.
        in_names: Names under which the input tables are referenced in the
            generated SQL (e.g. ``table0``, ``table1``, ... as produced by
            `IbisRunner.trace`).
        fn_name: Optional name recorded on the resulting PFunction.

    Returns:
        A ``sql.run`` PFunction (dialect ``duckdb``) whose text is the
        compiled SQL, with input/output table types converted from the Ibis
        schemas.

    Raises:
        ValueError: if *in_schemas* and *in_names* differ in length.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(in_schemas) != len(in_names):
        raise ValueError(
            "length of input table names and schemas mismatch. "
            f"{len(in_schemas)}!={len(in_names)}"
        )

    def _convert(schema: ibis.Schema) -> TableType:
        # Map each Ibis column dtype to an mplang dtype via its numpy dtype.
        return TableType.from_pairs([
            (name, dtypes.from_numpy(dt.to_numpy()))
            for name, dt in schema.fields.items()
        ])

    ins_info = tuple(_convert(s) for s in in_schemas)
    outs_info = (_convert(expr.schema()),)

    sql = ibis.to_sql(expr, dialect="duckdb")
    # Emit the generic sql.run op; the runtime maps it to a backend kernel.
    return PFunction(
        fn_type="sql.run",
        fn_name=fn_name,
        fn_text=sql,
        ins_info=ins_info,
        outs_info=outs_info,
        in_names=tuple(in_names),
        dialect="duckdb",
    )
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def is_ibis_function(func: Callable) -> bool:
    """Heuristically decide whether *func* is an Ibis function.

    Returns True when the return annotation or any parameter annotation is
    exactly ``ibis.Table``, e.g.
    ``def foo(t0: ibis.Table, t1: ibis.Table) -> ibis.Table``. Returns False
    when no signature can be obtained.

    Annotations are compared by identity, so string annotations (such as
    those produced by ``from __future__ import annotations``) do not match.
    """
    try:
        sig = inspect.signature(func)
    except (ValueError, TypeError):
        return False

    annotations = [sig.return_annotation]
    annotations.extend(param.annotation for param in sig.parameters.values())
    return any(anno is ibis.Table for anno in annotations)
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
# Stateless frontend module that owns the Ibis runner op.
_IBIS_MOD = stateless_mod("ibis")


class IbisRunner(FeOperation):
    """Frontend operation that compiles an Ibis function to a SQL PFunction."""

    def trace(
        self, func: Callable, *args: Any, **kwargs: Any
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Compile an Ibis function to a ``sql.run`` PFunction.

        Each MPObject argument is replaced by an unbound Ibis table named
        ``table{i}`` with the argument's schema, *func* is evaluated on those
        tables, and the resulting expression is compiled with `ibis2sql`.

        Args:
            func: The Ibis function to compile.
            *args: Positional arguments to the function.
            **kwargs: Keyword arguments to the function.

        Returns:
            The compiled PFunction, the MPObject input variables, and the
            output tree definition.

        Raises:
            TypeError: if *func* does not return an ``ibis.Table``.
        """

        def is_variable(arg: Any) -> bool:
            return isinstance(arg, MPObject)

        normalized_fn, in_vars = normalize_fn(func, args, kwargs, is_variable)

        in_args: list[ibis.Table] = []
        in_schemas: list[ibis.Schema] = []
        in_names: list[str] = []
        for idx, var in enumerate(in_vars):
            # Rebuild the input's schema in Ibis terms (via numpy dtypes).
            columns = [(p[0], p[1].to_numpy()) for p in var.schema.columns]
            schema = ibis.schema(columns)
            name = f"table{idx}"
            in_args.append(ibis.table(schema=schema, name=name))
            in_schemas.append(schema)
            in_names.append(name)

        result = normalized_fn(in_args)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(result, ibis.Table):
            raise TypeError(
                f"Ibis function must return an ibis.Table, got {type(result).__name__}"
            )
        # Callables such as functools.partial have no __name__.
        fn_name = getattr(func, "__name__", "")
        pfunc = ibis2sql(result, in_schemas, in_names, fn_name)
        _, treedef = tree_flatten(result)
        return pfunc, in_vars, treedef


run_ibis = IbisRunner(_IBIS_MOD, "run")
|
mplang/ops/sql_cc.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from jax.tree_util import PyTreeDef, tree_flatten
|
|
16
|
-
|
|
17
|
-
from mplang.core.mpobject import MPObject
|
|
18
|
-
from mplang.core.pfunc import PFunction
|
|
19
|
-
from mplang.core.table import TableType
|
|
20
|
-
from mplang.ops.base import FeOperation, stateless_mod
|
|
21
|
-
|
|
22
|
-
# Stateless frontend module that owns the SQL runner op.
_SQL_MOD = stateless_mod("sql")


class SqlRunner(FeOperation):
    """Frontend operation that wraps a raw SQL query as a ``sql.run`` PFunction."""

    def __init__(self, dialect: str = "duckdb"):
        # Bind to the sql module with a stable op name for registry/dispatch.
        super().__init__(_SQL_MOD, "run")
        self._dialect = dialect

    # TODO(jint): we should deduce out_type according to query and in_tables' schema
    def trace(
        self,
        query: str,
        out_type: TableType,
        in_tables: dict[str, MPObject] | None = None,
    ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
        """Build a ``sql.run`` PFunction for *query*.

        Args:
            query: The SQL text, in this runner's dialect.
            out_type: The declared output table type (not inferred).
            in_tables: Mapping from the table name used in *query* to the
                MPObject providing it. None or empty means no inputs.

        Returns:
            The PFunction, the input MPObjects (in mapping order), and the
            output tree definition.

        Raises:
            TypeError: if an input table is not an MPObject.
            ValueError: if an input table has no schema.
        """
        in_names: list[str] = []
        ins_info: list[TableType] = []
        in_vars: list[MPObject] = []
        for name, tbl in (in_tables or {}).items():
            # Explicit checks instead of `assert`, which is stripped under -O.
            if not isinstance(tbl, MPObject):
                raise TypeError(
                    f"Input table {name!r} must be an MPObject, got {type(tbl).__name__}"
                )
            if tbl.schema is None:
                raise ValueError(f"Input table {name!r} has no table schema")
            in_names.append(name)
            ins_info.append(tbl.schema)
            in_vars.append(tbl)

        pfn = PFunction(
            fn_type="sql.run",
            fn_name="",
            fn_text=query,
            ins_info=tuple(ins_info),
            outs_info=(out_type,),
            in_names=tuple(in_names),
            dialect=self._dialect,
        )
        _, treedef = tree_flatten(out_type)
        return pfn, in_vars, treedef


run_sql = SqlRunner("duckdb")
|