mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -45
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +5 -7
- mplang/v1/core/__init__.py +157 -0
- mplang/{core → v1/core}/cluster.py +30 -14
- mplang/{core → v1/core}/comm.py +5 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +13 -14
- mplang/{core → v1/core}/expr/evaluator.py +65 -24
- mplang/{core → v1/core}/expr/printer.py +24 -18
- mplang/{core → v1/core}/expr/transformer.py +3 -3
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +23 -16
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +4 -4
- mplang/{core → v1/core}/primitive.py +106 -201
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{api.py → v1/host.py} +38 -6
- mplang/v1/kernels/__init__.py +41 -0
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/v1/kernels/basic.py +240 -0
- mplang/{kernels → v1/kernels}/context.py +42 -27
- mplang/{kernels → v1/kernels}/crypto.py +44 -37
- mplang/v1/kernels/fhe.py +858 -0
- mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
- mplang/{kernels → v1/kernels}/phe.py +263 -57
- mplang/{kernels → v1/kernels}/spu.py +137 -48
- mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
- mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
- mplang/v1/kernels/value.py +626 -0
- mplang/{ops → v1/ops}/__init__.py +5 -16
- mplang/{ops → v1/ops}/base.py +2 -5
- mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
- mplang/v1/ops/crypto.py +262 -0
- mplang/v1/ops/fhe.py +272 -0
- mplang/{ops → v1/ops}/jax_cc.py +33 -68
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -4
- mplang/{ops → v1/ops}/spu.py +3 -5
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +9 -24
- mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
- mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +35 -20
- mplang/{runtime → v1/runtime}/client.py +19 -8
- mplang/{runtime → v1/runtime}/communicator.py +59 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +30 -12
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +58 -42
- mplang/{runtime → v1/runtime}/session.py +57 -71
- mplang/{runtime → v1/runtime}/simulation.py +55 -28
- mplang/v1/simp/api.py +353 -0
- mplang/{simp → v1/simp}/mpi.py +8 -9
- mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
- mplang/{simp → v1/simp}/random.py +21 -22
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/core/__init__.py +0 -92
- mplang/device.py +0 -340
- mplang/kernels/builtin.py +0 -207
- mplang/ops/crypto.py +0 -109
- mplang/ops/ibis_cc.py +0 -139
- mplang/ops/sql.py +0 -61
- mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
- mplang/runtime/link_comm.py +0 -131
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -73
- mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{kernels → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Dtype conversion utilities between MPLang ScalarType and external libraries.
|
|
16
|
+
|
|
17
|
+
This module provides bidirectional conversion between MPLang's type system
|
|
18
|
+
(ScalarType hierarchy) and external library types (NumPy, JAX, PyArrow, Pandas).
|
|
19
|
+
|
|
20
|
+
Usage:
|
|
21
|
+
from mplang.v2.dialects import dtypes
|
|
22
|
+
|
|
23
|
+
# MPLang ScalarType → JAX/NumPy
|
|
24
|
+
jax_dtype = dtypes.to_jax(scalar_types.f32) # → jnp.float32
|
|
25
|
+
np_dtype = dtypes.to_numpy(scalar_types.i64) # → np.dtype('int64')
|
|
26
|
+
|
|
27
|
+
# JAX/NumPy → MPLang ScalarType
|
|
28
|
+
scalar_type = dtypes.from_dtype(np.float32) # → scalar_types.f32
|
|
29
|
+
scalar_type = dtypes.from_dtype(jnp.int64) # → scalar_types.i64
|
|
30
|
+
|
|
31
|
+
# PyArrow/Pandas → MPLang ScalarType
|
|
32
|
+
scalar_type = dtypes.from_arrow(pa.int64()) # → scalar_types.i64
|
|
33
|
+
scalar_type = dtypes.from_pandas(df["col"].dtype) # → scalar_types.f64
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
from __future__ import annotations
|
|
37
|
+
|
|
38
|
+
from typing import Any
|
|
39
|
+
|
|
40
|
+
import jax.numpy as jnp
|
|
41
|
+
import numpy as np
|
|
42
|
+
|
|
43
|
+
import mplang.v2.edsl.typing as scalar_types
|
|
44
|
+
|
|
45
|
+
# ==============================================================================
|
|
46
|
+
# MPLang ScalarType → JAX/NumPy conversion
|
|
47
|
+
# ==============================================================================
|
|
48
|
+
|
|
49
|
+
# Mapping from MPLang ScalarType instances to JAX dtypes
|
|
50
|
+
_SCALAR_TO_JAX: dict[scalar_types.ScalarType, Any] = {
    # Forward table keyed by the module-level ScalarType singletons exposed by
    # `mplang.v2.edsl.typing`. `to_jax` consults it first, and the reverse
    # table `_JAX_TO_SCALAR` is derived from it, so every value must be unique.
    # Signed integers
    scalar_types.i8: jnp.int8,
    scalar_types.i16: jnp.int16,
    scalar_types.i32: jnp.int32,
    scalar_types.i64: jnp.int64,
    # Unsigned integers
    scalar_types.u8: jnp.uint8,
    scalar_types.u16: jnp.uint16,
    scalar_types.u32: jnp.uint32,
    scalar_types.u64: jnp.uint64,
    # Floating point
    scalar_types.f16: jnp.float16,
    scalar_types.f32: jnp.float32,
    scalar_types.f64: jnp.float64,
    # Complex
    scalar_types.c64: jnp.complex64,
    scalar_types.c128: jnp.complex128,
    # Boolean (i1)
    scalar_types.bool_: jnp.bool_,
}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def to_jax(dtype: scalar_types.ScalarType) -> Any:
    """Map an MPLang scalar type onto the matching JAX dtype.

    Registered singletons are resolved through ``_SCALAR_TO_JAX``. Any other
    ``ScalarType`` instance is resolved structurally: its kind (integer,
    float, complex) and bitwidth are turned into the name of a ``jax.numpy``
    attribute.

    Args:
        dtype: MPLang ScalarType (IntegerType, FloatType, or ComplexType)

    Returns:
        Corresponding JAX/NumPy dtype

    Raises:
        TypeError: If dtype is not a ScalarType
        ValueError: If dtype has no JAX equivalent

    Examples:
        >>> dtypes.to_jax(scalar_types.f32)
        <class 'jax.numpy.float32'>
        >>> dtypes.to_jax(scalar_types.i64)
        <class 'jax.numpy.int64'>
    """
    if not isinstance(dtype, scalar_types.ScalarType):
        raise TypeError(f"Expected ScalarType, got {type(dtype).__name__}")

    # Fast path: one of the registered singletons (or an equal instance).
    registered = _SCALAR_TO_JAX.get(dtype)
    if registered is not None:
        return registered

    # Slow path: a type that is not in the table; rebuild the jax.numpy
    # attribute name from the type's structure.
    if isinstance(dtype, scalar_types.IntegerType):
        if dtype.bitwidth == 1:
            # A 1-bit integer is treated as the boolean type.
            return jnp.bool_
        family = "int" if dtype.signed else "uint"
        try:
            return getattr(jnp, f"{family}{dtype.bitwidth}")
        except AttributeError:
            # jax.numpy has no integer of this width; report it below.
            pass
    elif isinstance(dtype, scalar_types.FloatType):
        try:
            return getattr(jnp, f"float{dtype.bitwidth}")
        except AttributeError:
            pass
    elif isinstance(dtype, scalar_types.ComplexType):
        # A complex number stores two components of the inner type.
        combined_bits = dtype.inner_type.bitwidth * 2
        try:
            return getattr(jnp, f"complex{combined_bits}")
        except AttributeError:
            pass

    raise ValueError(f"No JAX dtype equivalent for {dtype}")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def to_numpy(dtype: scalar_types.ScalarType) -> np.dtype:
    """Map an MPLang scalar type onto a NumPy ``dtype`` object.

    The conversion goes through :func:`to_jax` and wraps the result in
    ``np.dtype``, so the same ``TypeError`` / ``ValueError`` conditions apply.

    Args:
        dtype: MPLang ScalarType

    Returns:
        Corresponding NumPy dtype

    Examples:
        >>> dtypes.to_numpy(scalar_types.f32)
        dtype('float32')
    """
    return np.dtype(to_jax(dtype))  # type: ignore[no-any-return]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# ==============================================================================
|
|
141
|
+
# JAX/NumPy → MPLang ScalarType conversion
|
|
142
|
+
# ==============================================================================
|
|
143
|
+
|
|
144
|
+
# Reverse mapping (built dynamically to stay in sync)
|
|
145
|
+
_JAX_TO_SCALAR: dict[Any, scalar_types.ScalarType] = {
    # Inverse of `_SCALAR_TO_JAX`: keyed by the jax.numpy scalar types, valued
    # by the MPLang singletons. Derived rather than hand-written so the two
    # directions cannot drift apart.
    v: k for k, v in _SCALAR_TO_JAX.items()
}
|
|
148
|
+
|
|
149
|
+
# NumPy dtype to MPLang ScalarType mapping
|
|
150
|
+
_NUMPY_TO_SCALAR: dict[type, scalar_types.ScalarType] = {
    # Keyed by NumPy scalar *classes* (what `np.dtype(...).type` returns), not
    # by `np.dtype` instances; `from_dtype` uses it for both NumPy scalar types
    # and `np.dtype` objects.
    np.int8: scalar_types.i8,
    np.int16: scalar_types.i16,
    np.int32: scalar_types.i32,
    np.int64: scalar_types.i64,
    np.uint8: scalar_types.u8,
    np.uint16: scalar_types.u16,
    np.uint32: scalar_types.u32,
    np.uint64: scalar_types.u64,
    np.float16: scalar_types.f16,
    np.float32: scalar_types.f32,
    np.float64: scalar_types.f64,
    np.complex64: scalar_types.c64,
    np.complex128: scalar_types.c128,
    np.bool_: scalar_types.bool_,
}
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def from_dtype(dtype: Any) -> scalar_types.ScalarType:
    """Convert JAX/NumPy dtype to MPLang scalar type.

    Resolution order:

    1. Direct lookup of JAX scalar types and NumPy scalar types.
    2. ``np.dtype`` instances, through their scalar ``.type``.
    3. Anything NumPy can coerce to a dtype (``"f4"``, ``float``, ...),
       again through the resulting ``.type``.
    4. Name matching on whole identifier tokens of the dtype's name, for
       foreign dtype objects such as ``torch.float32`` or pandas ``Int64``.

    Name matching is token based on purpose: ``bfloat16`` must not be mistaken
    for ``float16`` just because one name contains the other.

    Args:
        dtype: JAX dtype, NumPy dtype, or dtype-like object

    Returns:
        Corresponding MPLang ScalarType

    Raises:
        ValueError: If dtype cannot be converted. This includes ``None``,
            unhashable objects, and dtypes without an MPLang equivalent
            (e.g. ``bfloat16``).

    Examples:
        >>> dtypes.from_dtype(jnp.float32)
        f32
        >>> dtypes.from_dtype(np.dtype("int64"))
        i64
    """
    # Direct lookup for JAX and NumPy scalar types. Unhashable inputs cannot
    # be dict keys; let them fall through so the caller gets the documented
    # ValueError instead of a TypeError from the membership test.
    try:
        if dtype in _JAX_TO_SCALAR:
            return _JAX_TO_SCALAR[dtype]
        if dtype in _NUMPY_TO_SCALAR:
            return _NUMPY_TO_SCALAR[dtype]
    except TypeError:
        pass

    # Handle np.dtype objects
    if isinstance(dtype, np.dtype):
        by_type = _NUMPY_TO_SCALAR.get(dtype.type)
        if by_type is not None:
            return by_type

    # Try to normalize to a dtype object. `None` is skipped because NumPy
    # silently interprets it as float64, which would hide a caller bug.
    normalized = None
    if dtype is not None:
        try:
            normalized = jnp.dtype(dtype)
        except Exception:
            # Best effort only: foreign dtype objects are expected to fail
            # here and are handled by the name fallback below.
            normalized = None
    if normalized is not None:
        # The tables are keyed by scalar classes, not by np.dtype instances,
        # so the lookup has to go through `.type`.
        by_type = _NUMPY_TO_SCALAR.get(normalized.type)
        if by_type is not None:
            return by_type

    # Fallback: match by name (the object's own name first, then the name of
    # the normalized dtype when normalization succeeded).
    candidate_names = [str(getattr(dtype, "name", dtype))]
    if normalized is not None:
        candidate_names.append(normalized.name)
    for candidate in candidate_names:
        by_name = _scalar_from_name(candidate)
        if by_name is not None:
            return by_name

    raise ValueError(f"Cannot convert dtype '{dtype}' to MPLang ScalarType")


def _scalar_from_name(name: str) -> scalar_types.ScalarType | None:
    """Return the scalar type named by a whole token of *name*, else ``None``."""
    import re

    known = {
        "int8": scalar_types.i8,
        "int16": scalar_types.i16,
        "int32": scalar_types.i32,
        "int64": scalar_types.i64,
        "uint8": scalar_types.u8,
        "uint16": scalar_types.u16,
        "uint32": scalar_types.u32,
        "uint64": scalar_types.u64,
        "float16": scalar_types.f16,
        "float32": scalar_types.f32,
        "float64": scalar_types.f64,
        "complex64": scalar_types.c64,
        "complex128": scalar_types.c128,
        "bool": scalar_types.bool_,
        "boolean": scalar_types.bool_,
    }
    # Split e.g. "torch.float32" -> ["torch", "float32"] and keep "bfloat16"
    # as a single token so it cannot match "float16".
    for token in re.findall(r"[a-z]+[0-9]*", name.lower()):
        match = known.get(token)
        if match is not None:
            return match
    return None
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# ==============================================================================
|
|
248
|
+
# PyArrow → MPLang ScalarType conversion
|
|
249
|
+
# ==============================================================================
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def from_arrow(arrow_type: Any) -> scalar_types.BaseType:
    """Translate a PyArrow ``DataType`` into the matching MPLang type.

    Numeric and boolean Arrow types map to MPLang scalar types. String, date,
    time, timestamp and binary types map to the corresponding named types
    exposed by ``mplang.v2.edsl.typing``. PyArrow is imported lazily so the
    module can be imported without it.

    Args:
        arrow_type: PyArrow DataType (e.g., pa.int64(), pa.float32())

    Returns:
        Corresponding MPLang BaseType (ScalarType or CustomType)

    Raises:
        ValueError: If arrow_type cannot be converted

    Examples:
        >>> import pyarrow as pa
        >>> dtypes.from_arrow(pa.int64())
        i64
        >>> dtypes.from_arrow(pa.float32())
        f32
        >>> dtypes.from_arrow(pa.string())
        Custom[string]
    """
    import pyarrow as pa

    kinds = pa.types

    if kinds.is_boolean(arrow_type):
        return scalar_types.bool_

    # Signed integers
    if kinds.is_int8(arrow_type):
        return scalar_types.i8
    if kinds.is_int16(arrow_type):
        return scalar_types.i16
    if kinds.is_int32(arrow_type):
        return scalar_types.i32
    if kinds.is_int64(arrow_type):
        return scalar_types.i64

    # Unsigned integers
    if kinds.is_uint8(arrow_type):
        return scalar_types.u8
    if kinds.is_uint16(arrow_type):
        return scalar_types.u16
    if kinds.is_uint32(arrow_type):
        return scalar_types.u32
    if kinds.is_uint64(arrow_type):
        return scalar_types.u64

    # Floating point; any float not caught by the narrower checks is f64.
    if kinds.is_float16(arrow_type):
        return scalar_types.f16
    if kinds.is_float32(arrow_type):
        return scalar_types.f32
    if kinds.is_float64(arrow_type) or kinds.is_floating(arrow_type):
        return scalar_types.f64

    # Non-numeric types
    if kinds.is_string(arrow_type) or kinds.is_large_string(arrow_type):
        return scalar_types.STRING
    if kinds.is_date(arrow_type):
        return scalar_types.DATE
    if kinds.is_time(arrow_type):
        return scalar_types.TIME
    if kinds.is_timestamp(arrow_type):
        return scalar_types.TIMESTAMP
    if kinds.is_binary(arrow_type) or kinds.is_large_binary(arrow_type):
        return scalar_types.BINARY

    raise ValueError(
        f"Cannot convert PyArrow type '{arrow_type}' to MPLang ScalarType"
    )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ==============================================================================
|
|
316
|
+
# Pandas dtype → MPLang ScalarType conversion
|
|
317
|
+
# ==============================================================================
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def from_pandas(pd_dtype: Any) -> scalar_types.BaseType:
    """Convert Pandas dtype to MPLang scalar type.

    Matching is done on ``str(pd_dtype)``. Both the NumPy-backed names
    (``"int64"``, ``"bool"``) and the nullable extension names (``"Int64"``,
    ``"boolean"``) are accepted.

    Args:
        pd_dtype: Pandas dtype (e.g., df["col"].dtype)

    Returns:
        Corresponding MPLang BaseType (ScalarType or CustomType)

    Raises:
        ValueError: If pd_dtype cannot be converted

    Examples:
        >>> import pandas as pd
        >>> df = pd.DataFrame({"x": [1, 2, 3]})
        >>> dtypes.from_pandas(df["x"].dtype)
        i64
    """
    # Get the dtype name as string for matching
    dtype_name = str(pd_dtype)

    # Exact names: NumPy-backed spelling plus the nullable extension spelling.
    exact_names = {
        "bool": scalar_types.bool_,
        # Nullable BooleanDtype; handled like the nullable Int*/Float* names.
        "boolean": scalar_types.bool_,
        "int8": scalar_types.i8,
        "Int8": scalar_types.i8,
        "int16": scalar_types.i16,
        "Int16": scalar_types.i16,
        "int32": scalar_types.i32,
        "Int32": scalar_types.i32,
        "int64": scalar_types.i64,
        "Int64": scalar_types.i64,
        "uint8": scalar_types.u8,
        "UInt8": scalar_types.u8,
        "uint16": scalar_types.u16,
        "UInt16": scalar_types.u16,
        "uint32": scalar_types.u32,
        "UInt32": scalar_types.u32,
        "uint64": scalar_types.u64,
        "UInt64": scalar_types.u64,
        "float16": scalar_types.f16,
        "Float16": scalar_types.f16,
        "float32": scalar_types.f32,
        "Float32": scalar_types.f32,
        "float64": scalar_types.f64,
        "Float64": scalar_types.f64,
        "complex64": scalar_types.c64,
        "complex128": scalar_types.c128,
    }
    matched = exact_names.get(dtype_name)
    if matched is not None:
        return matched

    # Parameterised names ("string[pyarrow]", "datetime64[ns]", ...) are
    # matched by prefix. Plain "object" columns are treated as strings.
    if dtype_name == "object" or dtype_name.startswith("string"):
        return scalar_types.STRING
    if dtype_name.startswith("datetime"):
        return scalar_types.TIMESTAMP
    if dtype_name.startswith("timedelta"):
        return scalar_types.INTERVAL

    raise ValueError(
        f"Cannot convert Pandas dtype '{pd_dtype}' to MPLang ScalarType"
    )
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Field dialect: Finite Field Arithmetic.
|
|
16
|
+
|
|
17
|
+
This module defines the Intermediate Representation (IR) for field operations.
|
|
18
|
+
It contains:
|
|
19
|
+
1. Primitive Definitions (Abstract Operations)
|
|
20
|
+
2. Abstract Evaluation Rules (Type Inference)
|
|
21
|
+
3. Public API (Builder Functions)
|
|
22
|
+
|
|
23
|
+
Implementation logic (Backends) is strictly separated into `mplang/v2/backends/field_impl.py`.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
from typing import Any, cast
|
|
29
|
+
|
|
30
|
+
import jax.numpy as jnp
|
|
31
|
+
|
|
32
|
+
import mplang.v2.edsl as el
|
|
33
|
+
import mplang.v2.edsl.typing as elt
|
|
34
|
+
from mplang.v2.dialects import tensor
|
|
35
|
+
|
|
36
|
+
# =============================================================================
|
|
37
|
+
# Primitives
|
|
38
|
+
# =============================================================================
|
|
39
|
+
|
|
40
|
+
# One primitive per field operation. Each is given its type-inference rule via
# `def_abstract_eval` further down and is bound by the public builder function
# of the same name (without the `_p` suffix).
aes_expand_p = el.Primitive[el.Object]("field.aes_expand")
mul_p = el.Primitive[el.Object]("field.mul")
solve_okvs_p = el.Primitive[el.Object]("field.solve_okvs")
decode_okvs_p = el.Primitive[el.Object]("field.decode_okvs")
ldpc_encode_p = el.Primitive[el.Object]("field.ldpc_encode")

# Optimized Mega-Binning Primitives
solve_okvs_opt_p = el.Primitive[el.Object]("field.solve_okvs_opt")
decode_okvs_opt_p = el.Primitive[el.Object]("field.decode_okvs_opt")
|
|
49
|
+
|
|
50
|
+
# =============================================================================
|
|
51
|
+
# Abstract Evaluation (Type Inference)
|
|
52
|
+
# =============================================================================
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@aes_expand_p.def_abstract_eval
def _aes_expand_ae(seeds_type: elt.TensorType, *, length: int) -> elt.TensorType:
    """Infer the result type of ``field.aes_expand``.

    The leading dimension of the seed tensor is kept, ``length`` becomes the
    second dimension and the last dimension is fixed at 2. The element type is
    always ``u64``, regardless of the seed element type.
    """
    result_shape = (seeds_type.shape[0], length, 2)
    return elt.TensorType(elt.u64, result_shape)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@mul_p.def_abstract_eval
def _mul_ae(a: elt.TensorType, b: elt.TensorType) -> elt.TensorType:
    """Infer the result type of ``field.mul``: the type of the first operand.

    The second operand's type is not inspected.
    """
    product_type = a
    return product_type
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@solve_okvs_p.def_abstract_eval
def _solve_okvs_ae(
    key_type: elt.TensorType,
    val_type: elt.TensorType,
    seed_type: elt.TensorType,
    *,
    m: int,
) -> elt.TensorType:
    """Infer the result type of ``field.solve_okvs``.

    The storage tensor has shape ``(m, 2)`` and takes its element type from
    the values tensor; key and seed types do not influence the result.
    """
    storage_shape = (m, 2)
    return elt.TensorType(val_type.element_type, storage_shape)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@decode_okvs_p.def_abstract_eval
def _decode_okvs_ae(
    key_type: elt.TensorType,
    store_type: elt.TensorType,
    seed_type: elt.TensorType,
) -> elt.TensorType:
    """Infer the result type of ``field.decode_okvs``.

    One row of width 2 is produced per key (leading dimension of the key
    tensor); the element type is taken from the storage tensor.
    """
    decoded_shape = (key_type.shape[0], 2)
    return elt.TensorType(store_type.element_type, decoded_shape)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@solve_okvs_opt_p.def_abstract_eval
def _solve_okvs_opt_ae(
    key_type: elt.TensorType,
    val_type: elt.TensorType,
    seed_type: elt.TensorType,
    *,
    m: int,
) -> elt.TensorType:
    """Infer the result type of ``field.solve_okvs_opt``.

    Same rule as the non-optimized solver: shape ``(m, 2)`` with the element
    type of the values tensor.
    """
    storage_shape = (m, 2)
    return elt.TensorType(val_type.element_type, storage_shape)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@decode_okvs_opt_p.def_abstract_eval
def _decode_okvs_opt_ae(
    key_type: elt.TensorType,
    store_type: elt.TensorType,
    seed_type: elt.TensorType,
) -> elt.TensorType:
    """Infer the result type of ``field.decode_okvs_opt``.

    Same rule as the non-optimized decoder: one row of width 2 per key, with
    the element type of the storage tensor.
    """
    decoded_shape = (key_type.shape[0], 2)
    return elt.TensorType(store_type.element_type, decoded_shape)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@ldpc_encode_p.def_abstract_eval
def _ldpc_encode_ae(
    message: elt.TensorType,
    indices: elt.TensorType,
    indptr: elt.TensorType,
    *,
    m: int,
    n: int,
) -> elt.TensorType:
    """Infer the result type of ``field.ldpc_encode``.

    The output has shape ``(m, 2)`` and the element type of the message
    tensor. ``n`` and the CSR index types do not influence the result type.
    """
    # message: (K, 2)
    # output:  (M, 2) -- the row count of H, which is what the kernel computes.
    encoded_shape = (m, 2)
    return elt.TensorType(message.element_type, encoded_shape)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# =============================================================================
|
|
126
|
+
# Public API
|
|
127
|
+
# =============================================================================
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def aes_expand(seeds: el.Object, length: int) -> el.Object:
    """Stretch each seed into ``length`` blocks with the AES-CTR PRG.

    Args:
        seeds: (N, 2) uint64 tensor (keys)
        length: Number of 128-bit blocks to generate per seed

    Returns:
        (N, length, 2) uint64 tensor
    """
    expanded = aes_expand_p.bind(seeds, length=length)
    return expanded
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def mul(a: el.Object, b: el.Object) -> el.Object:
    """Multiply two operands in GF(2^128) by binding the ``field.mul`` primitive."""
    product = mul_p.bind(a, b)
    return product
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def solve_okvs(
    keys: el.Object, values: el.Object, m: int, seed: el.Object
) -> el.Object:
    """Build the OKVS storage that encodes the ``keys`` -> ``values`` mapping.

    Binds the ``field.solve_okvs`` primitive (C++ kernel). Note that the
    primitive takes ``seed`` as its third operand and ``m`` as a keyword
    attribute, which differs from this function's parameter order.

    Returns:
        Storage tensor of shape (m, 2).
    """
    storage = solve_okvs_p.bind(keys, values, seed, m=m)
    return storage
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def decode_okvs(keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
    """Look up ``keys`` in an OKVS ``storage`` tensor.

    Binds the ``field.decode_okvs`` primitive.

    Returns:
        Decoded values of shape (N, 2).
    """
    decoded = decode_okvs_p.bind(keys, storage, seed)
    return decoded
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def solve_okvs_opt(
    keys: el.Object, values: el.Object, m: int, seed: el.Object
) -> el.Object:
    """Build OKVS storage with the Optimized Mega-Binning kernel.

    Binds ``field.solve_okvs_opt`` with ``seed`` as the third operand and
    ``m`` as a keyword attribute.
    """
    storage = solve_okvs_opt_p.bind(keys, values, seed, m=m)
    return storage
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def decode_okvs_opt(keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
    """Look up ``keys`` in OKVS ``storage`` with the Optimized Mega-Binning kernel."""
    decoded = decode_okvs_opt_p.bind(keys, storage, seed)
    return decoded
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def ldpc_encode(
    message: el.Object, h_indices: el.Object, h_indptr: el.Object, m: int, n: int
) -> el.Object:
    """Compute S = H * M with the sparse matrix multiplication kernel.

    H is supplied in CSR form through ``h_indices`` / ``h_indptr``; the call
    binds the ``field.ldpc_encode`` primitive with ``m`` and ``n`` as keyword
    attributes.

    Args:
        message: (N, 2) or (K, 2) input vector.
        h_indices: CSR indices.
        h_indptr: CSR indptr.
        m: Number of rows in H (Output size).
        n: Number of cols in H (Input size).

    Returns:
        (M, 2) output vector.
    """
    syndrome = ldpc_encode_p.bind(message, h_indices, h_indptr, m=m, n=n)
    return syndrome
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# =============================================================================
|
|
195
|
+
# Helpers (EDSL Composition)
|
|
196
|
+
# =============================================================================
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def add(a: el.Object, b: el.Object) -> el.Object:
    """Add two operands in GF(2^128), i.e. XOR them element-wise.

    Unlike the primitives in this module, this is composed from
    ``tensor.run_jax`` applied to ``jnp.bitwise_xor``.
    """
    xored = tensor.run_jax(jnp.bitwise_xor, a, b)
    return cast(el.Object, xored)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def sum(x: el.Object, axis: int | None = None) -> el.Object:
    """GF(2^128) Summation (XOR Sum).

    NOTE(review): this name shadows the builtin ``sum`` inside this module.

    Args:
        x: Tensor object to reduce.
        axis: Axis forwarded unchanged to ``jnp.bitwise_xor.reduce``
            (including the default ``None``).

    Returns:
        The result of running the XOR reduction through ``tensor.run_jax``.
    """

    def _sum_impl(val: Any) -> Any:
        # Field addition is XOR (see `add`), so summation is an XOR reduction.
        # `axis` is captured from the enclosing call rather than passed in.
        return jnp.bitwise_xor.reduce(val, axis=axis)

    return cast(el.Object, tensor.run_jax(_sum_impl, x))
|