mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -16,11 +16,11 @@ from __future__ import annotations
16
16
 
17
17
  import numpy as np
18
18
 
19
- from mplang.core import PFunction, TableType, TensorType
20
- from mplang.kernels.base import cur_kctx, kernel_def
21
- from mplang.kernels.value import TableValue, TensorValue, Value
22
- from mplang.runtime.data_providers import get_provider, resolve_uri
23
- from mplang.utils import table_utils
19
+ from mplang.v1.core import PFunction, TableType, TensorType
20
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
21
+ from mplang.v1.kernels.value import TableValue, TensorValue, Value
22
+ from mplang.v1.runtime.data_providers import get_provider, resolve_uri
23
+ from mplang.v1.utils import table_utils
24
24
 
25
25
 
26
26
  @kernel_def("basic.identity")
@@ -45,17 +45,17 @@ def _read(pfunc: PFunction) -> Value:
45
45
  except Exception as e: # pragma: no cover - provider errors
46
46
  raise RuntimeError(f"basic.read failed: {e}") from e
47
47
 
48
+ if isinstance(data, Value):
49
+ return data
50
+
48
51
  if isinstance(out_t, TableType):
49
- if isinstance(data, TableValue):
50
- return data
51
52
  return TableValue(data)
52
- if isinstance(out_t, TensorType):
53
- if isinstance(data, TensorValue):
54
- return data
53
+ elif isinstance(out_t, TensorType):
55
54
  return TensorValue(np.asarray(data))
56
- raise TypeError(
57
- f"basic.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
58
- )
55
+ else:
56
+ raise TypeError(
57
+ f"basic.read only supports TableType/TensorType outputs, got {type(out_t).__name__}"
58
+ )
59
59
 
60
60
 
61
61
  @kernel_def("basic.write")
@@ -85,9 +85,9 @@ def _constant(pfunc: PFunction) -> Value:
85
85
  out_t = pfunc.outs_info[0]
86
86
  fmt = pfunc.attrs.get("data_format")
87
87
  if isinstance(out_t, TableType):
88
- if fmt != "bytes[csv]":
88
+ if fmt != "bytes[parquet]":
89
89
  raise ValueError(f"unsupported table constant format {fmt}")
90
- df = table_utils.csv_to_dataframe(data_bytes)
90
+ df = table_utils.decode_table(data_bytes, format="parquet")
91
91
  return TableValue(df)
92
92
  # tensor path
93
93
  shape = out_t.shape # type: ignore[attr-defined,union-attr]
@@ -17,12 +17,12 @@ from __future__ import annotations
17
17
  from collections.abc import Mapping
18
18
  from typing import Any
19
19
 
20
- from mplang.core.dtypes import UINT8, DType
21
- from mplang.core.pfunc import PFunction
22
- from mplang.core.table import TableLike, TableType
23
- from mplang.core.tensor import TensorLike, TensorType
24
- from mplang.kernels import base
25
- from mplang.kernels.base import KernelContext, get_kernel_spec, kernel_exists
20
+ from mplang.v1.core.dtypes import UINT8, DType
21
+ from mplang.v1.core.pfunc import PFunction
22
+ from mplang.v1.core.table import PandasTableLike, TableLike, TableType
23
+ from mplang.v1.core.tensor import TensorLike, TensorType
24
+ from mplang.v1.kernels import base
25
+ from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
26
26
 
27
27
  # Default bindings
28
28
  # Import kernel implementation modules explicitly so their @kernel_def entries
@@ -35,14 +35,14 @@ def _ensure_impl_imported() -> None:
35
35
  global _IMPL_IMPORTED
36
36
  if _IMPL_IMPORTED:
37
37
  return
38
- from mplang.kernels import basic as _impl_basic # noqa: F401
39
- from mplang.kernels import crypto as _impl_crypto # noqa: F401
40
- from mplang.kernels import fhe as _impl_fhe # noqa: F401
41
- from mplang.kernels import mock_tee as _impl_tee # noqa: F401
42
- from mplang.kernels import phe as _impl_phe # noqa: F401
43
- from mplang.kernels import spu as _impl_spu # noqa: F401
44
- from mplang.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
45
- from mplang.kernels import stablehlo as _impl_stablehlo # noqa: F401
38
+ from mplang.v1.kernels import basic as _impl_basic # noqa: F401
39
+ from mplang.v1.kernels import crypto as _impl_crypto # noqa: F401
40
+ from mplang.v1.kernels import fhe as _impl_fhe # noqa: F401
41
+ from mplang.v1.kernels import mock_tee as _impl_tee # noqa: F401
42
+ from mplang.v1.kernels import phe as _impl_phe # noqa: F401
43
+ from mplang.v1.kernels import spu as _impl_spu # noqa: F401
44
+ from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb # noqa: F401
45
+ from mplang.v1.kernels import stablehlo as _impl_stablehlo # noqa: F401
46
46
 
47
47
  _IMPL_IMPORTED = True
48
48
 
@@ -317,9 +317,12 @@ def _validate_table_arg(
317
317
  raise TypeError(
318
318
  f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
319
319
  )
320
- if len(value.columns) != len(spec.columns):
320
+ columns = (
321
+ value.columns if isinstance(value, PandasTableLike) else value.column_names
322
+ )
323
+ if len(columns) != len(spec.columns):
321
324
  raise ValueError(
322
- f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(value.columns)}, expected {len(spec.columns)}"
325
+ f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(columns)}, expected {len(spec.columns)}"
323
326
  )
324
327
 
325
328
 
@@ -18,10 +18,10 @@ import os
18
18
 
19
19
  import numpy as np
20
20
 
21
- from mplang.core import PFunction
22
- from mplang.kernels.base import cur_kctx, kernel_def
23
- from mplang.kernels.value import TensorValue
24
- from mplang.utils.crypto import blake2b
21
+ from mplang.v1.core import PFunction
22
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
23
+ from mplang.v1.kernels.value import TensorValue
24
+ from mplang.v1.utils.crypto import blake2b
25
25
 
26
26
  __all__: list[str] = [] # No public exports currently
27
27
 
@@ -45,11 +45,9 @@ def _get_rng() -> np.random.Generator:
45
45
  def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
46
46
  # WARNING (INSECURE): hash-based keystream (key||nonce||counter)
47
47
  out = bytearray()
48
- counter = 0
49
48
  while len(out) < length:
50
- chunk = blake2b(key + nonce + counter.to_bytes(4, "little"))
49
+ chunk = blake2b(key + nonce)
51
50
  out.extend(chunk)
52
- counter += 1
53
51
  return bytes(out[:length])
54
52
 
55
53
 
@@ -68,7 +66,7 @@ def _crypto_encrypt(
68
66
  pt_bytes_np = pt_bytes.to_numpy().astype(np.uint8, copy=False)
69
67
  key_np = key.to_numpy().astype(np.uint8, copy=False)
70
68
  rng = _get_rng()
71
- nonce = rng.integers(0, 256, size=(12,), dtype=np.uint8)
69
+ nonce = rng.integers(0, 256, size=(16,), dtype=np.uint8)
72
70
  stream = np.frombuffer(
73
71
  _keystream(key_np.tobytes(), nonce.tobytes(), pt_bytes_np.size), dtype=np.uint8
74
72
  )
@@ -83,8 +81,8 @@ def _crypto_decrypt(
83
81
  ) -> TensorValue:
84
82
  ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
85
83
  key_np = key.to_numpy().astype(np.uint8, copy=False)
86
- nonce = ct_np[:12]
87
- ct = ct_np[12:]
84
+ nonce = ct_np[:16]
85
+ ct = ct_np[16:]
88
86
  stream = np.frombuffer(
89
87
  _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
90
88
  )
@@ -23,8 +23,9 @@ from typing import Any
23
23
  import numpy as np
24
24
  import tenseal as ts
25
25
 
26
- from mplang.core import DType, PFunction, TensorLike
27
- from mplang.kernels.base import kernel_def
26
+ from mplang.v1.core import DType, PFunction, TensorLike
27
+ from mplang.v1.kernels.base import kernel_def
28
+ from mplang.v1.kernels.value import TensorValue
28
29
 
29
30
 
30
31
  class FHEContext:
@@ -337,13 +338,14 @@ def _fhe_decrypt(pfunc: PFunction, ciphertext: CipherText, context: FHEContext)
337
338
 
338
339
  # Restore original shape
339
340
  if ciphertext.semantic_shape == ():
340
- # Was a scalar, extract single value
341
- result = decrypted_np[0:1].reshape(())
341
+ # Scalar: shape ()
342
+ result_np = decrypted_np[0:1].reshape(())
342
343
  else:
343
- # Keep as vector
344
- result = decrypted_np
344
+ # Vector: keep 1D array
345
+ result_np = decrypted_np
345
346
 
346
- return (result,)
347
+ # Return TensorValue to adhere to kernel Value I/O convention
348
+ return (TensorValue(np.asarray(result_np)),)
347
349
 
348
350
  except Exception as e:
349
351
  raise RuntimeError(f"FHE vector decryption failed: {e}") from e
@@ -20,9 +20,9 @@ import warnings
20
20
  import numpy as np
21
21
  from numpy.typing import NDArray
22
22
 
23
- from mplang.core import PFunction
24
- from mplang.kernels.base import cur_kctx, kernel_def
25
- from mplang.kernels.value import TensorValue
23
+ from mplang.v1.core import PFunction
24
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
25
+ from mplang.v1.kernels.value import TensorValue
26
26
 
27
27
  __all__: list[str] = []
28
28
 
@@ -23,9 +23,9 @@ import numpy as np
23
23
  from lightphe import LightPHE
24
24
  from lightphe.models.Ciphertext import Ciphertext
25
25
 
26
- from mplang.core import DType, PFunction
27
- from mplang.kernels.base import kernel_def
28
- from mplang.kernels.value import (
26
+ from mplang.v1.core import DType, PFunction
27
+ from mplang.v1.kernels.base import kernel_def
28
+ from mplang.v1.kernels.value import (
29
29
  TensorValue,
30
30
  Value,
31
31
  ValueDecodeError,
@@ -33,7 +33,7 @@ from mplang.kernels.value import (
33
33
  ValueProtoReader,
34
34
  register_value,
35
35
  )
36
- from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
36
+ from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
37
37
 
38
38
  # This controls the decimal precision used in lightPHE for float operations
39
39
  # we force it to 0 to only support integer operations
@@ -473,10 +473,9 @@ def _phe_keygen(pfunc: PFunction) -> Any:
473
473
  # use small key_size to speed up tests
474
474
  # in production use at least 2048 bits or 3072 bits for better security
475
475
  key_size = pfunc.attrs.get("key_size", 2048)
476
- max_value = pfunc.attrs.get(
477
- "max_value", 2**32
478
- ) # Use larger range to avoid overflow
479
- fxp_bits = pfunc.attrs.get("fxp_bits", 12)
476
+ # Accept very large max_value; allow decimal string input, kept simple like other attrs
477
+ max_value = int(pfunc.attrs.get("max_value", 2**32))
478
+ fxp_bits = int(pfunc.attrs.get("fxp_bits", 12))
480
479
 
481
480
  # Validate scheme
482
481
  if scheme.lower() not in ["paillier"]:
@@ -638,7 +637,8 @@ def _phe_mul(pfunc: PFunction, ciphertext: CipherText, plaintext: TensorValue) -
638
637
  # Use numpy to create a properly broadcasted index mapping
639
638
  # Create a dummy array with same shape as ciphertext, fill with indices
640
639
  dummy_ct = (
641
- np.arange(np.prod(ciphertext.semantic_shape))
640
+ np
641
+ .arange(np.prod(ciphertext.semantic_shape))
642
642
  .reshape(ciphertext.semantic_shape)
643
643
  .astype(np.int64)
644
644
  )
@@ -745,7 +745,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
745
745
  # Broadcast ct1 if needed
746
746
  if ct1.semantic_shape != result_shape:
747
747
  dummy_ct1 = (
748
- np.arange(np.prod(ct1.semantic_shape))
748
+ np
749
+ .arange(np.prod(ct1.semantic_shape))
749
750
  .reshape(ct1.semantic_shape)
750
751
  .astype(np.int64)
751
752
  )
@@ -758,7 +759,8 @@ def _phe_add_ct2ct(ct1: CipherText, ct2: CipherText) -> CipherText:
758
759
  # Broadcast ct2 if needed
759
760
  if ct2.semantic_shape != result_shape:
760
761
  dummy_ct2 = (
761
- np.arange(np.prod(ct2.semantic_shape))
762
+ np
763
+ .arange(np.prod(ct2.semantic_shape))
762
764
  .reshape(ct2.semantic_shape)
763
765
  .astype(np.int64)
764
766
  )
@@ -831,7 +833,8 @@ def _phe_add_ct2pt(ciphertext: CipherText, plaintext: TensorValue) -> CipherText
831
833
  # Broadcast ciphertext if needed
832
834
  if ciphertext.semantic_shape != result_shape:
833
835
  dummy_ct = (
834
- np.arange(np.prod(ciphertext.semantic_shape))
836
+ np
837
+ .arange(np.prod(ciphertext.semantic_shape))
835
838
  .reshape(ciphertext.semantic_shape)
836
839
  .astype(np.int64)
837
840
  )
@@ -997,12 +1000,17 @@ def _phe_decrypt(
997
1000
  # Convert to target dtype
998
1001
  if target_dtype.kind in "iu": # integer types
999
1002
  # Convert floats back to integers for integer semantic types
1000
- processed_data = [round(val) for val in decoded_data]
1001
- # Handle overflow for smaller integer types
1002
- info = np.iinfo(target_dtype)
1003
- processed_data = [
1004
- max(info.min, min(info.max, val)) for val in processed_data
1005
- ]
1003
+ # decoded_data are numeric (ints or floats); normalize to Python int
1004
+ ints = [round(v) if isinstance(v, float) else v for v in decoded_data]
1005
+ if np.issubdtype(target_dtype, np.unsignedinteger):
1006
+ # Reduce modulo 2^k for unsigned to preserve ring semantics
1007
+ width = np.iinfo(target_dtype).bits
1008
+ mod = 1 << width
1009
+ processed_data = [v % mod for v in ints]
1010
+ else:
1011
+ # Signed integers: clamp to dtype range
1012
+ info = np.iinfo(target_dtype)
1013
+ processed_data = [max(info.min, min(info.max, v)) for v in ints]
1006
1014
  else: # float types
1007
1015
  processed_data = decoded_data
1008
1016
 
@@ -21,7 +21,7 @@ import numpy as np
21
21
  import spu.api as spu_api
22
22
  import spu.libspu as libspu
23
23
 
24
- from mplang.core import (
24
+ from mplang.v1.core import (
25
25
  BOOL,
26
26
  FLOAT32,
27
27
  FLOAT64,
@@ -36,8 +36,8 @@ from mplang.core import (
36
36
  DType,
37
37
  PFunction,
38
38
  )
39
- from mplang.kernels.base import cur_kctx, kernel_def
40
- from mplang.kernels.value import (
39
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
40
+ from mplang.v1.kernels.value import (
41
41
  TensorValue,
42
42
  Value,
43
43
  ValueDecodeError,
@@ -45,8 +45,8 @@ from mplang.kernels.value import (
45
45
  ValueProtoReader,
46
46
  register_value,
47
47
  )
48
- from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
49
- from mplang.runtime.link_comm import LinkCommunicator
48
+ from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
49
+ from mplang.v1.runtime.link_comm import LinkCommunicator
50
50
 
51
51
 
52
52
  def shape_spu_to_np(spu_shape: Any) -> tuple[int, ...]:
@@ -14,9 +14,9 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from mplang.core import PFunction
18
- from mplang.kernels.base import kernel_def
19
- from mplang.kernels.value import TableValue
17
+ from mplang.v1.core import PFunction
18
+ from mplang.v1.kernels.base import kernel_def
19
+ from mplang.v1.kernels.value import TableValue
20
20
 
21
21
 
22
22
  @kernel_def("duckdb.run_sql")
@@ -38,5 +38,7 @@ def _duckdb_sql(pfunc: PFunction, *args: TableValue) -> TableValue:
38
38
  arrow_table = arg.to_arrow()
39
39
  conn.register(name, arrow_table)
40
40
  # Fetch result as Arrow table for consistency
41
+ if pfunc.fn_text is None:
42
+ raise ValueError("SQL function text is None")
41
43
  res_arrow = conn.execute(pfunc.fn_text).fetch_arrow_table()
42
44
  return TableValue(res_arrow)
@@ -17,14 +17,14 @@ from __future__ import annotations
17
17
  from typing import Any
18
18
 
19
19
  import jax
20
+ import jax.extend as jxt
20
21
  import jax.numpy as jnp
21
22
  import numpy as np
22
- from jax._src import xla_bridge
23
- from jax.lib import xla_client as xc
23
+ from jax._src import compiler
24
24
 
25
- from mplang.core import PFunction
26
- from mplang.kernels.base import cur_kctx, kernel_def
27
- from mplang.kernels.value import TensorValue
25
+ from mplang.v1.core import PFunction
26
+ from mplang.v1.kernels.base import cur_kctx, kernel_def
27
+ from mplang.v1.kernels.value import TensorValue
28
28
 
29
29
 
30
30
  @kernel_def("mlir.stablehlo")
@@ -47,11 +47,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
47
47
  key = f"stablehlo.compile_cache.{h}"
48
48
  compiled = rt.get_state(key)
49
49
  if compiled is None:
50
- backend = jax.default_backend()
51
- client = xla_bridge.get_backend(backend)
52
- compile_options = xc.CompileOptions()
50
+ client = jxt.backend.get_backend()
51
+ compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
52
+
53
53
  try:
54
- compiled = client.compile(mlir_text, compile_options)
54
+ compiled = client.compile_and_load(
55
+ mlir_text, client.devices(), compile_options
56
+ )
55
57
  except Exception as e: # pragma: no cover
56
58
  raise RuntimeError(f"StableHLO compile failed: {e}") from e
57
59
  rt.set_state(key, compiled)
@@ -76,14 +78,13 @@ def _stablehlo_exec(pfunc: PFunction, *args: Any) -> Any:
76
78
  ]
77
79
 
78
80
  try:
79
- result = compiled.execute_sharded(jax_args)
80
- arrays = result.disassemble_into_single_device_arrays()
81
- flat: list[Any] = []
82
- for lst in arrays:
83
- if isinstance(lst, list) and len(lst) == 1:
84
- flat.append(TensorValue(np.asarray(lst[0])))
85
- else:
86
- flat.extend(TensorValue(np.asarray(a)) for a in lst)
81
+ # Execute with the new LoadedExecutable interface
82
+ result = compiled.execute(jax_args)
83
+
84
+ # Use jax.tree_util.tree_flatten to robustly handle any PyTree structure
85
+ flat_results, _ = jax.tree_util.tree_flatten(result)
86
+ flat = [TensorValue(np.asarray(item)) for item in flat_results]
87
+
87
88
  return tuple(flat)
88
89
  except Exception as e: # pragma: no cover
89
90
  raise RuntimeError(f"StableHLO execute failed: {e}") from e
@@ -17,7 +17,7 @@ from __future__ import annotations
17
17
  from abc import ABC, abstractmethod
18
18
  from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
19
19
 
20
- from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
20
+ from mplang.v1.protos.v1alpha1 import value_pb2 as _value_pb2
21
21
 
22
22
  if TYPE_CHECKING:
23
23
  import numpy as np
@@ -591,7 +591,7 @@ class TableValue(Value): # well-known table (Arrow IPC) Value
591
591
 
592
592
  Note: This creates a copy and converts from Arrow to pandas format.
593
593
  For better performance, consider using to_arrow() and working with
594
- Arrow-native APIs (DuckDB, Ibis, etc.) directly.
594
+ Arrow-native APIs (DuckDB, etc.) directly.
595
595
 
596
596
  Returns:
597
597
  pandas.DataFrame: Converted dataframe
@@ -19,15 +19,15 @@ This module contains compilers that transform high-level functions into
19
19
  portable, serializable intermediate representations.
20
20
  """
21
21
 
22
- from mplang.ops import basic, crypto, ibis_cc, jax_cc, phe, spu, sql_cc, tee
23
- from mplang.ops.base import FeOperation as FeOperation
22
+ from mplang.v1.ops import basic, crypto, jax_cc, nnx_cc, phe, spu, sql_cc, tee
23
+ from mplang.v1.ops.base import FeOperation as FeOperation
24
24
 
25
25
  __all__ = [
26
26
  "FeOperation",
27
27
  "basic",
28
28
  "crypto",
29
- "ibis_cc",
30
29
  "jax_cc",
30
+ "nnx_cc",
31
31
  "phe",
32
32
  "spu",
33
33
  "sql_cc",
@@ -20,7 +20,7 @@ from typing import Any
20
20
 
21
21
  from jax.tree_util import PyTreeDef, tree_flatten
22
22
 
23
- from mplang.core import MPContext, MPObject, PFunction, TableType, TensorType
23
+ from mplang.v1.core import MPContext, MPObject, PFunction, TableType, TensorType
24
24
 
25
25
  # -----------------------------------------------------------------------------
26
26
  # Triad ABI
@@ -15,7 +15,7 @@
15
15
 
16
16
  from jax.tree_util import PyTreeDef, tree_flatten
17
17
 
18
- from mplang.core import (
18
+ from mplang.v1.core import (
19
19
  UINT8,
20
20
  UINT64,
21
21
  MPObject,
@@ -27,8 +27,8 @@ from mplang.core import (
27
27
  TensorLike,
28
28
  TensorType,
29
29
  )
30
- from mplang.ops.base import stateless_mod
31
- from mplang.utils import table_utils
30
+ from mplang.v1.ops.base import stateless_mod
31
+ from mplang.v1.utils import table_utils
32
32
 
33
33
  _BASIC_MOD = stateless_mod("basic")
34
34
 
@@ -108,8 +108,9 @@ def constant(
108
108
  out_type: TableType | TensorType
109
109
 
110
110
  if isinstance(data, TableLike):
111
- data_bytes = table_utils.dataframe_to_csv(data)
112
- data_format = "bytes[csv]"
111
+ format = "parquet"
112
+ data_bytes = table_utils.encode_table(data, format=format)
113
+ data_format = f"bytes[{format}]"
113
114
  out_type = TableType.from_tablelike(data)
114
115
  elif isinstance(data, ScalarType):
115
116
  out_type = TensorType.from_obj(data)