mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,378 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dtype conversion utilities between MPLang ScalarType and external libraries.
16
+
17
+ This module provides bidirectional conversion between MPLang's type system
18
+ (ScalarType hierarchy) and external library types (NumPy, JAX, PyArrow, Pandas).
19
+
20
+ Usage:
21
+ from mplang.v2.dialects import dtypes
22
+
23
+ # MPLang ScalarType → JAX/NumPy
24
+ jax_dtype = dtypes.to_jax(scalar_types.f32) # → jnp.float32
25
+ np_dtype = dtypes.to_numpy(scalar_types.i64) # → np.dtype('int64')
26
+
27
+ # JAX/NumPy → MPLang ScalarType
28
+ scalar_type = dtypes.from_dtype(np.float32) # → scalar_types.f32
29
+ scalar_type = dtypes.from_dtype(jnp.int64) # → scalar_types.i64
30
+
31
+ # PyArrow/Pandas → MPLang ScalarType
32
+ scalar_type = dtypes.from_arrow(pa.int64()) # → scalar_types.i64
33
+ scalar_type = dtypes.from_pandas(df["col"].dtype) # → scalar_types.f64
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ from typing import Any
39
+
40
+ import jax.numpy as jnp
41
+ import numpy as np
42
+
43
+ import mplang.v2.edsl.typing as scalar_types
44
+
45
+ # ==============================================================================
46
+ # MPLang ScalarType → JAX/NumPy conversion
47
+ # ==============================================================================
48
+
49
# Canonical MPLang ScalarType -> JAX dtype table. The reverse lookup table
# used by ``from_dtype`` is derived by inverting this dict, so new scalar
# types only need to be registered here.
_SCALAR_TO_JAX: dict[scalar_types.ScalarType, Any] = dict(
    [
        # Signed integers
        (scalar_types.i8, jnp.int8),
        (scalar_types.i16, jnp.int16),
        (scalar_types.i32, jnp.int32),
        (scalar_types.i64, jnp.int64),
        # Unsigned integers
        (scalar_types.u8, jnp.uint8),
        (scalar_types.u16, jnp.uint16),
        (scalar_types.u32, jnp.uint32),
        (scalar_types.u64, jnp.uint64),
        # Floating point
        (scalar_types.f16, jnp.float16),
        (scalar_types.f32, jnp.float32),
        (scalar_types.f64, jnp.float64),
        # Complex
        (scalar_types.c64, jnp.complex64),
        (scalar_types.c128, jnp.complex128),
        # Boolean (i1)
        (scalar_types.bool_, jnp.bool_),
    ]
)
71
+
72
+
73
def to_jax(dtype: scalar_types.ScalarType) -> Any:
    """Map an MPLang scalar type onto the matching JAX dtype.

    Registered singletons (``scalar_types.f32`` etc.) are resolved through the
    static ``_SCALAR_TO_JAX`` table. Other instances of the integer, float or
    complex type classes are resolved structurally: the ``jnp`` attribute name
    is built from signedness and bit width (``int8``, ``uint32``,
    ``float16``, ``complex64`` ...). A 1-bit integer maps to ``jnp.bool_``.

    Args:
        dtype: MPLang ScalarType (IntegerType, FloatType, or ComplexType)

    Returns:
        Corresponding JAX/NumPy dtype

    Raises:
        TypeError: If ``dtype`` is not a ScalarType.
        ValueError: If no ``jnp`` attribute exists for the derived name, or
            the scalar type kind has no JAX counterpart.

    Examples:
        >>> dtypes.to_jax(scalar_types.f32)
        <class 'jax.numpy.float32'>
        >>> dtypes.to_jax(scalar_types.i64)
        <class 'jax.numpy.int64'>
    """
    if not isinstance(dtype, scalar_types.ScalarType):
        raise TypeError(f"Expected ScalarType, got {type(dtype).__name__}")

    registered = _SCALAR_TO_JAX.get(dtype)
    if registered is not None:
        return registered

    # Structurally-equal types that are not the registered singletons:
    # derive the jnp attribute name from the type's shape.
    jnp_name: str | None = None
    if isinstance(dtype, scalar_types.IntegerType):
        if dtype.bitwidth == 1:
            return jnp.bool_
        jnp_name = f"{'int' if dtype.signed else 'uint'}{dtype.bitwidth}"
    elif isinstance(dtype, scalar_types.FloatType):
        jnp_name = f"float{dtype.bitwidth}"
    elif isinstance(dtype, scalar_types.ComplexType):
        # Complex width counts both the real and the imaginary component.
        jnp_name = f"complex{dtype.inner_type.bitwidth * 2}"

    if jnp_name is not None:
        resolved = getattr(jnp, jnp_name, None)
        if resolved is not None:
            return resolved

    raise ValueError(f"No JAX dtype equivalent for {dtype}")
121
+
122
+
123
def to_numpy(dtype: scalar_types.ScalarType) -> np.dtype:
    """Map an MPLang scalar type onto the matching NumPy dtype.

    The conversion is routed through :func:`to_jax`, so it accepts the same
    inputs and raises the same ``TypeError`` / ``ValueError``.

    Args:
        dtype: MPLang ScalarType

    Returns:
        Corresponding NumPy dtype

    Examples:
        >>> dtypes.to_numpy(scalar_types.f32)
        dtype('float32')
    """
    return np.dtype(to_jax(dtype))  # type: ignore[no-any-return]
138
+
139
+
140
+ # ==============================================================================
141
+ # JAX/NumPy → MPLang ScalarType conversion
142
+ # ==============================================================================
143
+
144
# Reverse of _SCALAR_TO_JAX: JAX scalar type -> MPLang ScalarType. Derived
# rather than hand-written so the two directions cannot drift apart.
_JAX_TO_SCALAR: dict[Any, scalar_types.ScalarType] = {
    jax_type: scalar for scalar, jax_type in _SCALAR_TO_JAX.items()
}

# NumPy scalar type (np.int8, np.float32, ...) -> MPLang ScalarType. Derived
# from the same table via ``np.dtype(jax_type).type`` -- the conversion that
# ``to_numpy`` relies on -- so it stays in sync with _SCALAR_TO_JAX too.
_NUMPY_TO_SCALAR: dict[type, scalar_types.ScalarType] = {
    np.dtype(jax_type).type: scalar for scalar, jax_type in _SCALAR_TO_JAX.items()
}
166
+
167
+
168
def from_dtype(dtype: Any) -> scalar_types.ScalarType:
    """Convert JAX/NumPy dtype to MPLang scalar type.

    Resolution order:
      1. JAX scalar types (``jnp.float32``) and NumPy scalar types
         (``np.float32``) via direct table lookup.
      2. ``np.dtype`` instances, via their ``.type``.
      3. Anything ``jnp.dtype`` can interpret (e.g. ``"f4"``, ``"double"``,
         Python ``int`` / ``float`` / ``bool``), via the normalized ``.type``.
      4. Last resort: substring match on the dtype's ``name`` (or ``str``),
         which covers spellings such as pandas' ``"Int64"``.

    Args:
        dtype: JAX dtype, NumPy dtype, or dtype-like object

    Returns:
        Corresponding MPLang ScalarType

    Raises:
        ValueError: If dtype cannot be converted. ``bfloat16`` is rejected
            explicitly instead of being mistaken for ``float16``.

    Examples:
        >>> dtypes.from_dtype(jnp.float32)
        f32
        >>> dtypes.from_dtype(np.dtype("int64"))
        i64
        >>> dtypes.from_dtype(float)
        f64
    """
    # Direct lookup for JAX scalar types
    if dtype in _JAX_TO_SCALAR:
        return _JAX_TO_SCALAR[dtype]

    # Direct lookup for NumPy scalar types
    if dtype in _NUMPY_TO_SCALAR:
        return _NUMPY_TO_SCALAR[dtype]

    # np.dtype instances hash differently from scalar types; use ``.type``.
    if isinstance(dtype, np.dtype) and dtype.type in _NUMPY_TO_SCALAR:
        return _NUMPY_TO_SCALAR[dtype.type]

    # Normalize dtype-likes ("f4", float, int, ...). The result is an
    # np.dtype, so it must also be looked up by its ``.type``; testing the
    # np.dtype itself against the scalar-type keyed tables never matches.
    try:
        normalized = jnp.dtype(dtype)
    except Exception:
        normalized = None  # not interpretable; fall through to name matching
    if normalized is not None and normalized.type in _NUMPY_TO_SCALAR:
        return _NUMPY_TO_SCALAR[normalized.type]

    # Fallback: match by name
    name = getattr(dtype, "name", str(dtype)).lower()

    # "bfloat16" contains the substring "float16"; without this guard it
    # would silently be reported as f16, which is a different format.
    if "bfloat16" in name:
        raise ValueError(f"Cannot convert dtype '{dtype}' to MPLang ScalarType")

    # Integer types ("uint" excluded so e.g. "uint8" is not read as int8)
    if "int8" in name and "uint" not in name:
        return scalar_types.i8
    elif "int16" in name and "uint" not in name:
        return scalar_types.i16
    elif "int32" in name and "uint" not in name:
        return scalar_types.i32
    elif "int64" in name and "uint" not in name:
        return scalar_types.i64
    elif "uint8" in name:
        return scalar_types.u8
    elif "uint16" in name:
        return scalar_types.u16
    elif "uint32" in name:
        return scalar_types.u32
    elif "uint64" in name:
        return scalar_types.u64
    # Float types
    elif "float16" in name:
        return scalar_types.f16
    elif "float32" in name:
        return scalar_types.f32
    elif "float64" in name:
        return scalar_types.f64
    # Complex types
    elif "complex64" in name:
        return scalar_types.c64
    elif "complex128" in name:
        return scalar_types.c128
    # Boolean
    elif "bool" in name:
        return scalar_types.bool_

    raise ValueError(f"Cannot convert dtype '{dtype}' to MPLang ScalarType")
245
+
246
+
247
+ # ==============================================================================
248
+ # PyArrow → MPLang ScalarType conversion
249
+ # ==============================================================================
250
+
251
+
252
def from_arrow(arrow_type: Any) -> scalar_types.BaseType:
    """Convert PyArrow type to MPLang scalar type.

    Numeric and boolean Arrow types map to the matching ScalarType. String,
    binary, date, time and timestamp types map to the corresponding custom
    types. Arrow ``duration`` maps to ``INTERVAL``, mirroring how
    ``from_pandas`` treats pandas ``timedelta64`` (Arrow's duration is the
    Arrow counterpart of that pandas dtype).

    Args:
        arrow_type: PyArrow DataType (e.g., pa.int64(), pa.float32())

    Returns:
        Corresponding MPLang BaseType (ScalarType or CustomType)

    Raises:
        ValueError: If arrow_type cannot be converted

    Examples:
        >>> import pyarrow as pa
        >>> dtypes.from_arrow(pa.int64())
        i64
        >>> dtypes.from_arrow(pa.float32())
        f32
        >>> dtypes.from_arrow(pa.string())
        Custom[string]
    """
    import pyarrow as pa

    if pa.types.is_boolean(arrow_type):
        return scalar_types.bool_
    elif pa.types.is_int8(arrow_type):
        return scalar_types.i8
    elif pa.types.is_int16(arrow_type):
        return scalar_types.i16
    elif pa.types.is_int32(arrow_type):
        return scalar_types.i32
    elif pa.types.is_int64(arrow_type):
        return scalar_types.i64
    elif pa.types.is_uint8(arrow_type):
        return scalar_types.u8
    elif pa.types.is_uint16(arrow_type):
        return scalar_types.u16
    elif pa.types.is_uint32(arrow_type):
        return scalar_types.u32
    elif pa.types.is_uint64(arrow_type):
        return scalar_types.u64
    elif pa.types.is_float16(arrow_type):
        return scalar_types.f16
    elif pa.types.is_float32(arrow_type):
        return scalar_types.f32
    elif pa.types.is_float64(arrow_type) or pa.types.is_floating(arrow_type):
        return scalar_types.f64
    elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
        return scalar_types.STRING
    elif pa.types.is_date(arrow_type):
        return scalar_types.DATE
    elif pa.types.is_time(arrow_type):
        return scalar_types.TIME
    elif pa.types.is_timestamp(arrow_type):
        return scalar_types.TIMESTAMP
    elif pa.types.is_duration(arrow_type):
        # Keep parity with from_pandas, which maps timedelta64 to INTERVAL.
        return scalar_types.INTERVAL
    elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type):
        return scalar_types.BINARY
    else:
        raise ValueError(
            f"Cannot convert PyArrow type '{arrow_type}' to MPLang ScalarType"
        )
313
+
314
+
315
+ # ==============================================================================
316
+ # Pandas dtype → MPLang ScalarType conversion
317
+ # ==============================================================================
318
+
319
+
320
def from_pandas(pd_dtype: Any) -> scalar_types.BaseType:
    """Convert Pandas dtype to MPLang scalar type.

    Arrow-backed dtypes (``pd.ArrowDtype``, e.g. ``"double[pyarrow]"``) are
    resolved through their ``pyarrow_dtype`` attribute using
    :func:`from_arrow`. All other dtypes are matched on ``str(pd_dtype)``:
    - Exact names are tried first for numeric and boolean dtypes, covering
      both NumPy-backed spellings and the pandas nullable spellings
      (``"Int64"``, ``"boolean"``, ...).
    - Prefixes are tried next for string, datetime and timedelta dtypes.

    Args:
        pd_dtype: Pandas dtype (e.g., df["col"].dtype)

    Returns:
        Corresponding MPLang BaseType (ScalarType or CustomType)

    Raises:
        ValueError: If pd_dtype cannot be converted

    Examples:
        >>> import pandas as pd
        >>> df = pd.DataFrame({"x": [1, 2, 3]})
        >>> dtypes.from_pandas(df["x"].dtype)
        i64
    """
    # pd.ArrowDtype wraps an Arrow type; reuse the Arrow mapping so names like
    # "double[pyarrow]" or "timestamp[ns][pyarrow]" are understood.
    arrow_type = getattr(pd_dtype, "pyarrow_dtype", None)
    if arrow_type is not None:
        return from_arrow(arrow_type)

    # Get the dtype name as string for matching
    dtype_name = str(pd_dtype)

    # NumPy-backed name alongside the pandas nullable ("masked") spelling.
    exact_names: dict[str, scalar_types.BaseType] = {
        "bool": scalar_types.bool_,
        "boolean": scalar_types.bool_,
        "int8": scalar_types.i8,
        "Int8": scalar_types.i8,
        "int16": scalar_types.i16,
        "Int16": scalar_types.i16,
        "int32": scalar_types.i32,
        "Int32": scalar_types.i32,
        "int64": scalar_types.i64,
        "Int64": scalar_types.i64,
        "uint8": scalar_types.u8,
        "UInt8": scalar_types.u8,
        "uint16": scalar_types.u16,
        "UInt16": scalar_types.u16,
        "uint32": scalar_types.u32,
        "UInt32": scalar_types.u32,
        "uint64": scalar_types.u64,
        "UInt64": scalar_types.u64,
        "float16": scalar_types.f16,
        "Float16": scalar_types.f16,
        "float32": scalar_types.f32,
        "Float32": scalar_types.f32,
        "float64": scalar_types.f64,
        "Float64": scalar_types.f64,
        "complex64": scalar_types.c64,
        "complex128": scalar_types.c128,
    }
    if dtype_name in exact_names:
        return exact_names[dtype_name]

    # "object" columns are treated as strings; pandas >= 3 spells its default
    # string dtype "str", older StringDtype variants start with "string".
    if dtype_name in ("object", "str") or dtype_name.startswith("string"):
        return scalar_types.STRING
    if dtype_name.startswith("datetime"):
        return scalar_types.TIMESTAMP
    if dtype_name.startswith("timedelta"):
        return scalar_types.INTERVAL

    raise ValueError(
        f"Cannot convert Pandas dtype '{pd_dtype}' to MPLang ScalarType"
    )
+ )
@@ -0,0 +1,210 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Field dialect: Finite Field Arithmetic.
16
+
17
+ This module defines the Intermediate Representation (IR) for field operations.
18
+ It contains:
19
+ 1. Primitive Definitions (Abstract Operations)
20
+ 2. Abstract Evaluation Rules (Type Inference)
21
+ 3. Public API (Builder Functions)
22
+
23
+ Implementation logic (Backends) is strictly separated into `mplang/v2/backends/field_impl.py`.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from typing import Any, cast
29
+
30
+ import jax.numpy as jnp
31
+
32
+ import mplang.v2.edsl as el
33
+ import mplang.v2.edsl.typing as elt
34
+ from mplang.v2.dialects import tensor
35
+
36
+ # =============================================================================
37
+ # Primitives
38
+ # =============================================================================
39
+
40
# Abstract primitives for the field dialect, one per operation, each named
# "field.<op>". Their type rules are attached below via def_abstract_eval;
# per the module docstring, backend implementations live in
# mplang/v2/backends/field_impl.py.
aes_expand_p = el.Primitive[el.Object]("field.aes_expand")  # AES-CTR PRG expansion
mul_p = el.Primitive[el.Object]("field.mul")  # GF(2^128) multiplication
solve_okvs_p = el.Primitive[el.Object]("field.solve_okvs")  # OKVS encode
decode_okvs_p = el.Primitive[el.Object]("field.decode_okvs")  # OKVS decode
ldpc_encode_p = el.Primitive[el.Object]("field.ldpc_encode")  # sparse H * message

# Optimized Mega-Binning Primitives
solve_okvs_opt_p = el.Primitive[el.Object]("field.solve_okvs_opt")
decode_okvs_opt_p = el.Primitive[el.Object]("field.decode_okvs_opt")
49
+
50
+ # =============================================================================
51
+ # Abstract Evaluation (Type Inference)
52
+ # =============================================================================
53
+
54
+
55
@aes_expand_p.def_abstract_eval
def _aes_expand_ae(seeds_type: elt.TensorType, *, length: int) -> elt.TensorType:
    """Infer the output type of ``field.aes_expand``.

    The leading dimension of the seeds tensor (N) is carried through and each
    seed yields ``length`` two-limb blocks. The result element type is always
    uint64, independent of the seeds' element type.
    """
    num_seeds = seeds_type.shape[0]
    out_shape = (num_seeds, length, 2)
    return elt.TensorType(elt.u64, out_shape)
61
+
62
+
63
@mul_p.def_abstract_eval
def _mul_ae(a: elt.TensorType, b: elt.TensorType) -> elt.TensorType:
    """Infer the output type of ``field.mul``: identical to the left operand.

    NOTE(review): ``b`` is not inspected, so a shape/dtype mismatch between
    the operands is not caught at trace time.
    """
    result_type = a
    return result_type
66
+
67
+
68
@solve_okvs_p.def_abstract_eval
def _solve_okvs_ae(
    key_type: elt.TensorType,
    val_type: elt.TensorType,
    seed_type: elt.TensorType,
    *,
    m: int,
) -> elt.TensorType:
    """Infer the output type of ``field.solve_okvs``.

    The storage has ``m`` rows of two limbs and inherits the values' element
    type; the key and seed types do not influence the result.
    """
    storage_elem = val_type.element_type
    return elt.TensorType(storage_elem, (m, 2))
77
+
78
+
79
@decode_okvs_p.def_abstract_eval
def _decode_okvs_ae(
    key_type: elt.TensorType,
    store_type: elt.TensorType,
    seed_type: elt.TensorType,
) -> elt.TensorType:
    """Infer the output type of ``field.decode_okvs``.

    One two-limb row is produced per key (leading dimension of the keys),
    with the storage's element type. The seed type is not inspected.
    """
    num_keys = key_type.shape[0]
    decoded_shape = (num_keys, 2)
    return elt.TensorType(store_type.element_type, decoded_shape)
87
+
88
+
89
@solve_okvs_opt_p.def_abstract_eval
def _solve_okvs_opt_ae(
    key_type: elt.TensorType,
    val_type: elt.TensorType,
    seed_type: elt.TensorType,
    *,
    m: int,
) -> elt.TensorType:
    """Infer the output type of ``field.solve_okvs_opt`` (mega-binning variant).

    Same typing rule as the plain OKVS solve: an ``(m, 2)`` storage tensor
    with the values' element type.
    """
    storage_elem = val_type.element_type
    return elt.TensorType(storage_elem, (m, 2))
98
+
99
+
100
@decode_okvs_opt_p.def_abstract_eval
def _decode_okvs_opt_ae(
    key_type: elt.TensorType,
    store_type: elt.TensorType,
    seed_type: elt.TensorType,
) -> elt.TensorType:
    """Infer the output type of ``field.decode_okvs_opt`` (mega-binning variant).

    Same typing rule as the plain OKVS decode: ``(num_keys, 2)`` with the
    storage's element type.
    """
    num_keys = key_type.shape[0]
    decoded_shape = (num_keys, 2)
    return elt.TensorType(store_type.element_type, decoded_shape)
108
+
109
+
110
@ldpc_encode_p.def_abstract_eval
def _ldpc_encode_ae(
    message: elt.TensorType,
    indices: elt.TensorType,
    indptr: elt.TensorType,
    *,
    m: int,
    n: int,
) -> elt.TensorType:
    """Infer the output type of ``field.ldpc_encode`` (sparse H * message).

    The product has one two-limb row per row of H, i.e. shape ``(m, 2)``,
    with the message's element type. The CSR arrays and ``n`` do not affect
    the inferred type.
    """
    out_rows = m
    return elt.TensorType(message.element_type, (out_rows, 2))
123
+
124
+
125
+ # =============================================================================
126
+ # Public API
127
+ # =============================================================================
128
+
129
+
130
def aes_expand(seeds: el.Object, length: int) -> el.Object:
    """Stretch each seed into ``length`` pseudo-random blocks via AES-CTR.

    Args:
        seeds: Seed keys, an ``(N, 2)`` uint64 tensor.
        length: Number of 128-bit blocks generated per seed.

    Returns:
        An ``(N, length, 2)`` uint64 tensor.
    """
    expanded = aes_expand_p.bind(seeds, length=length)
    return expanded
141
+
142
+
143
def mul(a: el.Object, b: el.Object) -> el.Object:
    """Multiply two GF(2^128) operands; the result type follows ``a``."""
    product = mul_p.bind(a, b)
    return product
146
+
147
+
148
def solve_okvs(
    keys: el.Object, values: el.Object, m: int, seed: el.Object
) -> el.Object:
    """Encode a keys -> values mapping into an OKVS storage (C++ kernel).

    Args:
        keys: Keys to encode.
        values: Values paired with ``keys``; the storage takes their
            element type.
        m: Number of storage rows.
        seed: Seed tensor forwarded to the kernel.

    Returns:
        Storage tensor of shape ``(m, 2)``.
    """
    # The primitive's operand order places the seed after the values.
    operands = (keys, values, seed)
    return solve_okvs_p.bind(*operands, m=m)
155
+
156
+
157
def decode_okvs(keys: el.Object, storage: el.Object, seed: el.Object) -> el.Object:
    """Look up the values stored for ``keys`` in an OKVS storage.

    Returns:
        Decoded values, one ``(N, 2)`` row per key.
    """
    operands = (keys, storage, seed)
    return decode_okvs_p.bind(*operands)
162
+
163
+
164
def solve_okvs_opt(
    keys: el.Object, values: el.Object, m: int, seed: el.Object
) -> el.Object:
    """OKVS encode via the optimized mega-binning kernel; returns ``(m, 2)``."""
    operands = (keys, values, seed)
    return solve_okvs_opt_p.bind(*operands, m=m)
169
+
170
+
171
def decode_okvs_opt(
    keys: el.Object, storage: el.Object, seed: el.Object
) -> el.Object:
    """OKVS decode via the optimized mega-binning kernel; one row per key."""
    operands = (keys, storage, seed)
    return decode_okvs_opt_p.bind(*operands)
174
+
175
+
176
def ldpc_encode(
    message: el.Object, h_indices: el.Object, h_indptr: el.Object, m: int, n: int
) -> el.Object:
    """Compute ``S = H * message`` with a CSR sparse-matrix kernel.

    Args:
        message: Input vector of two-limb rows. Presumably ``(n, 2)``, one
            row per column of H; TODO confirm against callers.
        h_indices: CSR column indices of H.
        h_indptr: CSR row pointers of H.
        m: Row count of H (length of the output).
        n: Column count of H (length of the input).

    Returns:
        The ``(m, 2)`` product vector.
    """
    shape_attrs = {"m": m, "n": n}
    return ldpc_encode_p.bind(message, h_indices, h_indptr, **shape_attrs)
192
+
193
+
194
+ # =============================================================================
195
+ # Helpers (EDSL Composition)
196
+ # =============================================================================
197
+
198
+
199
def add(a: el.Object, b: el.Object) -> el.Object:
    """Add two GF(2^128) operands, which in characteristic 2 is bitwise XOR."""
    xored = tensor.run_jax(jnp.bitwise_xor, a, b)
    return cast(el.Object, xored)
202
+
203
+
204
def sum(x: el.Object, axis: int | None = None) -> el.Object:
    """GF(2^128) summation: XOR-reduce ``x``.

    Args:
        x: Tensor to reduce.
        axis: Axis to reduce along; ``None`` reduces over every axis.

    NOTE(review): with ``axis=None`` every axis is XOR-folded together,
    including a trailing 2-limb axis if ``x`` uses the ``(..., 2)`` uint64
    layout seen elsewhere in this module -- confirm that is intended.

    This name intentionally shadows the builtin ``sum`` within this module.
    """

    def _sum_impl(val: Any) -> Any:
        reduced = jnp.bitwise_xor.reduce(val, axis=axis)
        return reduced

    traced = tensor.run_jax(_sum_impl, x)
    return cast(el.Object, traced)