mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- """TEE (Trusted Execution Environment) dialect for mplang.v2 EDSL.
15
+ """TEE (Trusted Execution Environment) dialect for mplang EDSL.
16
16
 
17
17
  This dialect provides primitives for TEE remote attestation, enabling secure
18
18
  computation where:
@@ -38,7 +38,7 @@ Supported Platforms:
38
38
 
39
39
  Example:
40
40
  ```python
41
- from mplang.v2.dialects import tee, crypto
41
+ from mplang.dialects import tee, crypto
42
42
 
43
43
  # On TEE side: generate keypair and quote
44
44
  sk, pk = crypto.kem_keygen("x25519")
@@ -62,10 +62,10 @@ from __future__ import annotations
62
62
 
63
63
  from typing import Any, ClassVar, Literal
64
64
 
65
- import mplang.v2.edsl as el
66
- import mplang.v2.edsl.typing as elt
67
- from mplang.v2.dialects.crypto import PublicKeyType
68
- from mplang.v2.edsl import serde
65
+ import mplang.edsl as el
66
+ import mplang.edsl.typing as elt
67
+ from mplang.dialects.crypto import PublicKeyType
68
+ from mplang.edsl import serde
69
69
 
70
70
  # ==============================================================================
71
71
  # --- Type Definitions
@@ -58,10 +58,10 @@ import numpy as np
58
58
  from jax import ShapeDtypeStruct
59
59
  from jax.tree_util import PyTreeDef, tree_flatten
60
60
 
61
- import mplang.v2.edsl as el
62
- import mplang.v2.edsl.typing as elt
63
- from mplang.v1.utils.func_utils import normalize_fn
64
- from mplang.v2.dialects import dtypes
61
+ import mplang.edsl as el
62
+ import mplang.edsl.typing as elt
63
+ from mplang.dialects import dtypes
64
+ from mplang.utils import normalize_fn
65
65
 
66
66
  run_jax_p = el.Primitive[Any]("tensor.run_jax")
67
67
  constant_p = el.Primitive[el.Object]("tensor.constant")
@@ -77,7 +77,7 @@ class RunJaxCompilation:
77
77
 
78
78
  fn: Callable[..., Any]
79
79
  stablehlo: str
80
- out_tree: PyTreeDef
80
+ out_tree: PyTreeDef # type: ignore
81
81
  output_types: list[elt.BaseType]
82
82
  arg_keep_map: list[int] | None = None
83
83
 
@@ -17,8 +17,8 @@
17
17
  This module keeps the surface area intentionally small so downstream code can
18
18
  simply write::
19
19
 
20
- import mplang.v2.edsl as el
21
- import mplang.v2.edsl.typing as elt
20
+ import mplang.edsl as el
21
+ import mplang.edsl.typing as elt
22
22
 
23
23
  The `el` namespace re-exports the commonly used building blocks (context,
24
24
  graph, tracer, primitives, etc.), while the full type system lives under
@@ -27,7 +27,7 @@ graph, tracer, primitives, etc.), while the full type system lives under
27
27
 
28
28
  from __future__ import annotations
29
29
 
30
- # Re-export the typing module so callers can `import mplang.v2.edsl.typing as elt`
30
+ # Re-export the typing module so callers can `import mplang.edsl.typing as elt`
31
31
  from . import typing as typing
32
32
 
33
33
  # Context management
@@ -21,7 +21,7 @@ This module defines the Context hierarchy:
21
21
 
22
22
  Contexts can be used directly with Python's 'with' statement:
23
23
 
24
- from mplang.v2.edsl import Tracer
24
+ from mplang.edsl import Tracer
25
25
 
26
26
  tracer = Tracer()
27
27
  with tracer:
@@ -46,9 +46,9 @@ from collections.abc import Callable
46
46
  from typing import TYPE_CHECKING, Any, Self
47
47
 
48
48
  if TYPE_CHECKING:
49
- from mplang.v2.edsl.graph import Graph
50
- from mplang.v2.edsl.object import Object
51
- from mplang.v2.edsl.primitive import Primitive
49
+ from mplang.edsl.graph import Graph
50
+ from mplang.edsl.object import Object
51
+ from mplang.edsl.primitive import Primitive
52
52
 
53
53
 
54
54
  class Context(ABC):
@@ -257,7 +257,7 @@ def is_tracing() -> bool:
257
257
  Returns:
258
258
  True if the top of the context stack is a Tracer.
259
259
  """
260
- from mplang.v2.edsl.tracer import Tracer
260
+ from mplang.edsl.tracer import Tracer
261
261
 
262
262
  return isinstance(get_current_context(), Tracer)
263
263
 
@@ -280,7 +280,7 @@ def get_default_context() -> Context:
280
280
  if _default_context_factory is None:
281
281
  raise RuntimeError(
282
282
  "No default context factory registered. "
283
- "Ensure mplang.v2.edsl is imported or register a factory manually."
283
+ "Ensure mplang.edsl is imported or register a factory manually."
284
284
  )
285
285
  _default_context = _default_context_factory()
286
286
  return _default_context
@@ -27,8 +27,8 @@ Key Design Principles:
27
27
 
28
28
  Example:
29
29
  --------
30
- from mplang.v2.edsl.graph import Graph, Operation, Value
31
- from mplang.v2.edsl.typing import Tensor, f32
30
+ from mplang.edsl.graph import Graph, Operation, Value
31
+ from mplang.edsl.typing import Tensor, f32
32
32
 
33
33
  graph = Graph()
34
34
 
@@ -61,8 +61,8 @@ from collections.abc import Sequence
61
61
  from dataclasses import dataclass, field
62
62
  from typing import Any, ClassVar
63
63
 
64
- from mplang.v2.edsl import serde
65
- from mplang.v2.edsl.typing import BaseType
64
+ from mplang.edsl import serde
65
+ from mplang.edsl.typing import BaseType
66
66
 
67
67
 
68
68
  @dataclass
@@ -74,7 +74,7 @@ class Value:
74
74
 
75
75
  Attributes:
76
76
  name: Unique SSA name (e.g., "%0", "%1", ...)
77
- type: Type of this value (from mplang.v2.edsl.typing)
77
+ type: Type of this value (from mplang.edsl.typing)
78
78
  defining_op: Operation that produces this value (None for inputs)
79
79
  uses: List of operations that consume this value
80
80
  """
@@ -19,12 +19,12 @@ from typing import Any
19
19
 
20
20
  from jax.tree_util import tree_map
21
21
 
22
- from mplang.v2.edsl.context import (
22
+ from mplang.edsl.context import (
23
23
  AbstractInterpreter,
24
24
  get_current_context,
25
25
  get_default_context,
26
26
  )
27
- from mplang.v2.edsl.tracer import Tracer
27
+ from mplang.edsl.tracer import Tracer
28
28
 
29
29
 
30
30
  def jit(fn: Callable) -> Callable:
@@ -25,7 +25,7 @@ from __future__ import annotations
25
25
  from abc import ABC, abstractmethod
26
26
  from typing import Generic, TypeVar
27
27
 
28
- from mplang.v2.edsl.typing import BaseType
28
+ from mplang.edsl.typing import BaseType
29
29
 
30
30
  T = TypeVar("T", bound=BaseType)
31
31
 
@@ -27,11 +27,11 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
27
27
 
28
28
  from jax.tree_util import tree_map
29
29
 
30
- from mplang.v2.edsl.context import get_current_context, get_default_context
31
- from mplang.v2.edsl.object import Object
30
+ from mplang.edsl.context import get_current_context, get_default_context
31
+ from mplang.edsl.object import Object
32
32
 
33
33
  if TYPE_CHECKING:
34
- from mplang.v2.edsl.typing import BaseType
34
+ from mplang.edsl.typing import BaseType
35
35
 
36
36
  T_Ret = TypeVar("T_Ret")
37
37
 
@@ -54,7 +54,7 @@ class Primitive(Generic[T_Ret]):
54
54
  >>>
55
55
  >>> @encrypt_p.def_abstract_eval
56
56
  >>> def encrypt_abstract(x_type):
57
- >>> from mplang.v2.edsl.typing import Vector
57
+ >>> from mplang.edsl.typing import Vector
58
58
  >>> return Vector[x_type.dtype, x_type.shape]
59
59
  >>>
60
60
  >>> # Execution happens via Graph IR → Backend
@@ -91,7 +91,7 @@ class Primitive(Generic[T_Ret]):
91
91
  """
92
92
  self._impl = fn
93
93
  # Register with the global interpreter registry
94
- from mplang.v2.edsl.registry import register_impl
94
+ from mplang.edsl.registry import register_impl
95
95
 
96
96
  register_impl(self.name, fn)
97
97
  return fn
@@ -18,7 +18,7 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any
20
20
 
21
- from mplang.v2.edsl.graph import Graph, Operation, Value
21
+ from mplang.edsl.graph import Graph, Operation, Value
22
22
 
23
23
 
24
24
  class GraphPrinter:
@@ -20,7 +20,7 @@ cloudpickle. Each type is responsible for its own serialization via the
20
20
  `@register_class` decorator pattern.
21
21
 
22
22
  Usage:
23
- from mplang.v2.edsl import serde
23
+ from mplang.edsl import serde
24
24
 
25
25
  @serde.register_class
26
26
  class MyType:
@@ -31,14 +31,14 @@ from typing import TYPE_CHECKING, Any, cast
31
31
 
32
32
  from jax.tree_util import PyTreeDef, tree_flatten, tree_map
33
33
 
34
- from mplang.v2.edsl.context import Context
35
- from mplang.v2.edsl.graph import Graph
36
- from mplang.v2.edsl.graph import Value as GraphValue
37
- from mplang.v2.edsl.object import Object
38
- from mplang.v2.edsl.typing import BaseType
34
+ from mplang.edsl.context import Context
35
+ from mplang.edsl.graph import Graph
36
+ from mplang.edsl.graph import Value as GraphValue
37
+ from mplang.edsl.object import Object
38
+ from mplang.edsl.typing import BaseType
39
39
 
40
40
  if TYPE_CHECKING:
41
- from mplang.v2.edsl.primitive import Primitive
41
+ from mplang.edsl.primitive import Primitive
42
42
 
43
43
 
44
44
  class TraceObject(Object):
@@ -48,7 +48,7 @@ class TraceObject(Object):
48
48
  All operations delegate to primitives which record into Graph.
49
49
 
50
50
  Example:
51
- >>> from mplang.v2.edsl import trace
51
+ >>> from mplang.edsl import trace
52
52
  >>> def compute(x, y):
53
53
  ... z = x + y # TraceObject.__add__ → add_p.bind(x, y)
54
54
  ... return z
@@ -124,7 +124,7 @@ from __future__ import annotations
124
124
 
125
125
  from typing import Any, ClassVar, Generic, TypeVar
126
126
 
127
- from mplang.v2.edsl import serde
127
+ from mplang.edsl import serde
128
128
 
129
129
  # ==============================================================================
130
130
  # --- Base Type & Type Aliases
@@ -27,13 +27,13 @@ extern "C" {
27
27
 
28
28
  /**
29
29
  * @brief LDPC Encoding: Compute Syndrome s = H * x
30
- *
30
+ *
31
31
  * H is a sparse M x N binary matrix (CSR format).
32
32
  * x is a dense N-vector of 128-bit blocks (N * 16 bytes).
33
33
  * s is a dense M-vector of 128-bit blocks (M * 16 bytes).
34
- *
34
+ *
35
35
  * Logic: For each row i of H, s[i] = XOR(x[j]) for all j where H[i, j] = 1.
36
- *
36
+ *
37
37
  * @param message_ptr Pointer to message x (N * 2 uint64_t)
38
38
  * @param indices_ptr Pointer to CSR indices (uint64_t)
39
39
  * @param indptr_ptr Pointer to CSR indptr (M+1 uint64_t)
@@ -41,17 +41,17 @@ extern "C" {
41
41
  * @param m Number of rows in H (syndrome length)
42
42
  * @param n Number of cols in H (message length)
43
43
  */
44
- void ldpc_encode(const uint64_t* message_ptr,
45
- const uint64_t* indices_ptr,
46
- const uint64_t* indptr_ptr,
47
- uint64_t* output_ptr,
48
- uint64_t m,
44
+ void ldpc_encode(const uint64_t* message_ptr,
45
+ const uint64_t* indices_ptr,
46
+ const uint64_t* indptr_ptr,
47
+ uint64_t* output_ptr,
48
+ uint64_t m,
49
49
  uint64_t n) {
50
-
50
+
51
51
  // Check alignment
52
52
  // We assume message_ptr and output_ptr are 16-byte aligned for SSE/AVX?
53
53
  // JAX/Numpy arrays are usually aligned.
54
-
54
+
55
55
  // Cast to __m128i for efficiency
56
56
  // But we need to handle potential unaligned access if numpy doesn't align.
57
57
  // _mm_loadu_si128 handles unaligned.
@@ -63,10 +63,10 @@ void ldpc_encode(const uint64_t* message_ptr,
63
63
  for (uint64_t i = 0; i < m; ++i) {
64
64
  // Row i
65
65
  __m128i sum = _mm_setzero_si128();
66
-
66
+
67
67
  uint64_t start = indptr_ptr[i];
68
68
  uint64_t end = indptr_ptr[i+1];
69
-
69
+
70
70
  for (uint64_t k = start; k < end; ++k) {
71
71
  uint64_t col_idx = indices_ptr[k];
72
72
  // XOR accumulation
@@ -74,7 +74,7 @@ void ldpc_encode(const uint64_t* message_ptr,
74
74
  __m128i val = _mm_loadu_si128(&x_vec[col_idx]);
75
75
  sum = _mm_xor_si128(sum, val);
76
76
  }
77
-
77
+
78
78
  _mm_storeu_si128(&s_vec[i], sum);
79
79
  }
80
80
  }
@@ -86,7 +86,7 @@ extern "C" {
86
86
  // 3. Build CSR Structure (Flat Arrays) to replace vector<vector>
87
87
  // col_start[j] points to start of column j's rows in flat_rows
88
88
  std::vector<int> col_start(m + 1, 0);
89
-
89
+
90
90
  // Prefix sum to compute start positions
91
91
  // col_start[0] = 0
92
92
  // col_start[j+1] = col_start[j] + degree[j]
@@ -96,7 +96,7 @@ extern "C" {
96
96
 
97
97
  // Total edges = 3 * N implies flat_rows size
98
98
  std::vector<int> flat_rows(n * 3);
99
-
99
+
100
100
  // Temporary copy of start indices to use as fill pointers
101
101
  std::vector<int> fill_ptr = col_start;
102
102
 
@@ -106,7 +106,7 @@ extern "C" {
106
106
  flat_rows[fill_ptr[r.h2]++] = i;
107
107
  flat_rows[fill_ptr[r.h3]++] = i;
108
108
  }
109
-
109
+
110
110
  // 4. Initialize Peeling
111
111
  std::vector<int> peel_stack;
112
112
  peel_stack.reserve(m);
@@ -135,7 +135,7 @@ extern "C" {
135
135
  int owner_row = -1;
136
136
  int start = col_start[j];
137
137
  int end = col_start[j+1];
138
-
138
+
139
139
  for(int k=start; k<end; ++k) {
140
140
  int r_idx = flat_rows[k];
141
141
  if(!row_removed[r_idx]) {
@@ -28,7 +28,7 @@
28
28
  extern "C" {
29
29
 
30
30
  // Number of Bins for Mega-Binning strategy.
31
- // 1024 bins implies ~1000 items per bin for N=1M, fitting the working set
31
+ // 1024 bins implies ~1000 items per bin for N=1M, fitting the working set
32
32
  // entirely in L1 cache (32KB/48KB) for maximum performance.
33
33
  static const uint64_t NUM_BINS = 1024;
34
34
 
@@ -54,14 +54,14 @@ extern "C" {
54
54
  inline Indices get_bin_local_indices(uint64_t key, uint64_t m_local, __m128i seed) {
55
55
  // Use a distinct seed mix to decorrelate from bin selection
56
56
  __m128i k = _mm_set_epi64x(0, key);
57
- __m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
57
+ __m128i s2 = _mm_add_epi64(seed, _mm_set_epi64x(1, 1));
58
58
  __m128i h = _mm_aesenc_si128(k, s2);
59
59
  h = _mm_aesenc_si128(h, s2);
60
60
  h = _mm_aesenc_si128(h, s2);
61
61
 
62
62
  uint64_t r = _mm_extract_epi64(h, 0);
63
63
  Indices idx;
64
-
64
+
65
65
  // Fast modulo for local indices
66
66
  idx.h1 = r % m_local;
67
67
  r = r * 6364136223846793005ULL + 1442695040888963407ULL; // LCG step
@@ -80,10 +80,10 @@ extern "C" {
80
80
 
81
81
  // Core Peeling Solver for a single Bin
82
82
  bool solve_bin(
83
- const std::vector<uint64_t>& keys,
84
- const std::vector<__m128i>& vals,
85
- __m128i* P_local,
86
- uint64_t m,
83
+ const std::vector<uint64_t>& keys,
84
+ const std::vector<__m128i>& vals,
85
+ __m128i* P_local,
86
+ uint64_t m,
87
87
  __m128i seed
88
88
  ) {
89
89
  uint64_t n = keys.size();
@@ -95,7 +95,7 @@ extern "C" {
95
95
  };
96
96
  std::vector<Edge> edges(n);
97
97
  std::vector<int> col_degree(m, 0);
98
-
98
+
99
99
  // 1. Build Local Graph
100
100
  for(uint64_t i=0; i<n; ++i) {
101
101
  Indices idx = get_bin_local_indices(keys[i], m, seed);
@@ -127,14 +127,14 @@ extern "C" {
127
127
 
128
128
  std::vector<bool> row_removed(n, false);
129
129
  std::vector<bool> col_removed(m, false);
130
-
130
+
131
131
  struct Assignment {
132
132
  int col;
133
133
  int row_idx;
134
134
  };
135
135
  std::vector<Assignment> assignment_stack;
136
136
  assignment_stack.reserve(n);
137
-
137
+
138
138
  int head = 0;
139
139
  while(head < peel_stack.size()) {
140
140
  int j = peel_stack[head++];
@@ -173,15 +173,15 @@ extern "C" {
173
173
  for(int i=(int)assignment_stack.size()-1; i>=0; --i) {
174
174
  auto a = assignment_stack[i];
175
175
  const auto& e = edges[a.row_idx];
176
-
176
+
177
177
  __m128i val1 = _mm_loadu_si128(&P_local[e.h1]);
178
178
  __m128i val2 = _mm_loadu_si128(&P_local[e.h2]);
179
179
  __m128i val3 = _mm_loadu_si128(&P_local[e.h3]);
180
180
  __m128i target = vals[e.key_idx];
181
-
181
+
182
182
  __m128i current = _mm_xor_si128(_mm_xor_si128(val1, val2), val3);
183
183
  __m128i diff = _mm_xor_si128(target, current);
184
-
184
+
185
185
  _mm_storeu_si128(&P_local[a.col], diff);
186
186
  }
187
187
  return true;
@@ -189,15 +189,15 @@ extern "C" {
189
189
 
190
190
  void solve_okvs_opt(uint64_t* keys, uint64_t* values, uint64_t* output, uint64_t n, uint64_t m, uint64_t* seed_ptr) {
191
191
  __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
192
-
192
+
193
193
  // 1. Calculate Bin Boundaries
194
194
  // We divide M evenly among bins. The remainder is distributed to the first few bins.
195
195
  std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
196
196
  std::vector<uint64_t> m_per_bin(NUM_BINS);
197
-
197
+
198
198
  uint64_t base_m = m / NUM_BINS;
199
199
  uint64_t remainder = m % NUM_BINS;
200
-
200
+
201
201
  uint64_t current_offset = 0;
202
202
  for(uint64_t b=0; b<NUM_BINS; ++b) {
203
203
  bin_offsets[b] = current_offset;
@@ -208,18 +208,18 @@ extern "C" {
208
208
 
209
209
  // 2. Partition Data (Stateless)
210
210
  // Note on "Two-Choice Hashing":
211
- // While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
211
+ // While Two-Choice Hashing (selecting the lighter of 2 potential bins) would significantly
212
212
  // reduce max bin load variance, it introduces "Statefulness".
213
213
  // The bin assignment for Key K would depend on the load of bins, which depends on other keys.
214
- // In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
214
+ // In standard PSI protocols (like RR22), the Decode step must be capable of processing keys
215
215
  // independently or without knowledge of the full set distribution (Sender/Receiver separation).
216
216
  // Therefore, we use **Simple Binning** (Stateless Hash) where Bin(K) = H(K) % Bins.
217
- // We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
217
+ // We mitigate the resulting variance ("Balls-in-Bins" problem) by using a slightly larger
218
218
  // expansion factor (epsilon ~ 1.35) which is bandwidth-acceptable and ensures stability.
219
-
219
+
220
220
  std::vector<std::vector<uint64_t>> bin_keys(NUM_BINS);
221
221
  std::vector<std::vector<__m128i>> bin_vals(NUM_BINS);
222
-
222
+
223
223
  // Pre-allocate to reduce reallocation overhead (assume ~uniform distribution)
224
224
  // 1.5x margin for pre-allocation safety
225
225
  size_t est_size = (n / NUM_BINS) * 3 / 2;
@@ -227,14 +227,14 @@ extern "C" {
227
227
  bin_keys[b].reserve(est_size);
228
228
  bin_vals[b].reserve(est_size);
229
229
  }
230
-
230
+
231
231
  const __m128i* V_ptr = (const __m128i*)values;
232
232
  for(uint64_t i=0; i<n; ++i) {
233
233
  uint64_t b = get_bin_index(keys[i], seed);
234
234
  bin_keys[b].push_back(keys[i]);
235
235
  bin_vals[b].push_back(_mm_loadu_si128(&V_ptr[i]));
236
236
  }
237
-
237
+
238
238
  // 3. Parallel Solve
239
239
  // Each bin is solved independently. This logic is perfectly parallelizable (embarrassingly parallel).
240
240
  // The working set for each bin (~1000 items) stays hot in L1 Cache.
@@ -244,10 +244,10 @@ extern "C" {
244
244
  #pragma omp parallel for schedule(dynamic)
245
245
  for(uint64_t b=0; b<NUM_BINS; ++b) {
246
246
  if(bin_keys[b].empty()) continue;
247
-
247
+
248
248
  uint64_t offset = bin_offsets[b];
249
249
  uint64_t valid_m = m_per_bin[b];
250
-
250
+
251
251
  if(!solve_bin(bin_keys[b], bin_vals[b], &P_vec[offset], valid_m, seed)) {
252
252
  // On failure, log and fall back to the robust solver for this bin.
253
253
  // The fallback is executed inside a critical section to avoid nested OpenMP
@@ -259,7 +259,7 @@ extern "C" {
259
259
 
260
260
  // Prepare pointers for the safe solver
261
261
  uint64_t* keys_ptr = &bin_keys[b][0];
262
- uint64_t* vals_ptr = &bin_vals[b][0];
262
+ uint64_t* vals_ptr = (uint64_t*)&bin_vals[b][0]; // Cast __m128i* to uint64_t*
263
263
  uint64_t* out_ptr = output + (offset * 2ULL); // each 128-bit slot == 2 uint64_t
264
264
 
265
265
  // Call the safe solver implemented in okvs.cpp
@@ -273,10 +273,10 @@ extern "C" {
273
273
  __m128i seed = _mm_loadu_si128((__m128i*)seed_ptr);
274
274
  __m128i* P_vec = (__m128i*)storage;
275
275
  __m128i* out_vec = (__m128i*)output;
276
-
276
+
277
277
  // Replicate Boundary Logic
278
278
  std::vector<uint64_t> bin_offsets(NUM_BINS + 1);
279
- std::vector<uint64_t> m_per_bin(NUM_BINS);
279
+ std::vector<uint64_t> m_per_bin(NUM_BINS);
280
280
  uint64_t base_m = m / NUM_BINS;
281
281
  uint64_t remainder = m % NUM_BINS;
282
282
  uint64_t current_offset = 0;
@@ -290,12 +290,12 @@ extern "C" {
290
290
  #pragma omp parallel for schedule(static)
291
291
  for(uint64_t i=0; i<n; ++i) {
292
292
  uint64_t b = get_bin_index(keys[i], seed);
293
-
293
+
294
294
  uint64_t m_local = m_per_bin[b];
295
295
  uint64_t offset = bin_offsets[b];
296
-
296
+
297
297
  Indices idx = get_bin_local_indices(keys[i], m_local, seed);
298
-
298
+
299
299
  __m128i val = _mm_xor_si128(
300
300
  _mm_xor_si128(_mm_loadu_si128(&P_vec[offset + idx.h1]), _mm_loadu_si128(&P_vec[offset + idx.h2])),
301
301
  _mm_loadu_si128(&P_vec[offset + idx.h3])
@@ -23,7 +23,7 @@ from __future__ import annotations
23
23
 
24
24
  import numpy as np
25
25
 
26
- from mplang.v2.libs.mpc.common.constants import (
26
+ from mplang.libs.mpc.common.constants import (
27
27
  GOLDEN_RATIO_64,
28
28
  SPLITMIX64_GAMMA_2,
29
29
  SPLITMIX64_GAMMA_3,
@@ -30,8 +30,8 @@ Naming Convention:
30
30
  - collect: gather (N parties → 1 party, stacked)
31
31
 
32
32
  Example:
33
- >>> from mplang.v2.libs.collective import transfer, replicate, distribute, collect
34
- >>> from mplang.v2.dialects.simp import constant, converge
33
+ >>> from mplang.libs.collective import transfer, replicate, distribute, collect
34
+ >>> from mplang.dialects.simp import constant, converge
35
35
  >>>
36
36
  >>> # Create data on party 0
37
37
  >>> x = constant((0,), 42)
@@ -47,9 +47,9 @@ from __future__ import annotations
47
47
 
48
48
  from typing import TYPE_CHECKING
49
49
 
50
- from mplang.v2.dialects.simp import converge, shuffle_static
51
- from mplang.v2.edsl import Object
52
- from mplang.v2.edsl.typing import MPType
50
+ from mplang.dialects.simp import converge, shuffle_static
51
+ from mplang.edsl import Object
52
+ from mplang.edsl.typing import MPType
53
53
 
54
54
  if TYPE_CHECKING:
55
55
  pass
@@ -17,7 +17,7 @@
17
17
  This module provides the high-level device-centric programming interface.
18
18
  """
19
19
 
20
- from mplang.v2.dialects.tensor import jax_fn
20
+ from mplang.dialects.tensor import jax_fn
21
21
 
22
22
  from .api import (
23
23
  DeviceContext,
@@ -28,10 +28,10 @@ from typing import Any, cast
28
28
 
29
29
  from jax.tree_util import tree_flatten, tree_map
30
30
 
31
- from mplang.v2.backends import load_builtins
32
- from mplang.v2.dialects import crypto, simp, spu, tee
33
- from mplang.v2.edsl.object import Object
34
- from mplang.v2.libs.device.cluster import Device
31
+ from mplang.backends import load_builtins
32
+ from mplang.dialects import crypto, simp, spu, tee
33
+ from mplang.edsl.object import Object
34
+ from mplang.libs.device.cluster import Device
35
35
 
36
36
  load_builtins()
37
37
 
@@ -43,7 +43,7 @@ def _resolve_cluster() -> Any:
43
43
  Interpreter with a _cluster_spec attribute. This allows nested contexts
44
44
  to override the cluster if needed.
45
45
  """
46
- from mplang.v2.edsl.context import find_context
46
+ from mplang.edsl.context import find_context
47
47
 
48
48
  ctx = find_context(lambda c: getattr(c, "_cluster_spec", None) is not None)
49
49
  if ctx is not None:
@@ -356,7 +356,7 @@ class DeviceContext:
356
356
  if self._is_spu_device():
357
357
  return self(fn)
358
358
  # PPU/TEE need tensor.jax_fn to compile JAX code
359
- from mplang.v2.dialects.tensor import jax_fn
359
+ from mplang.dialects.tensor import jax_fn
360
360
 
361
361
  return self(jax_fn(fn))
362
362
 
@@ -443,7 +443,7 @@ def _ensure_tee_session(
443
443
  Returns:
444
444
  Tuple of (sess_frm, sess_tee) where each is a symmetric key Object
445
445
  """
446
- import mplang.v2.edsl as el
446
+ import mplang.edsl as el
447
447
 
448
448
  # Get current context ID for cache isolation
449
449
  current_ctx = el.get_current_context()
@@ -749,11 +749,11 @@ def fetch(obj: Object) -> Any:
749
749
  Returns:
750
750
  Python value (numpy array, scalar, etc.)
751
751
  """
752
- from mplang.v2.backends.simp_driver.state import SimpDriver
753
- from mplang.v2.backends.simp_driver.values import DriverVar
754
- from mplang.v2.edsl.context import get_current_context
755
- from mplang.v2.runtime.interpreter import InterpObject, Interpreter
756
- from mplang.v2.runtime.value import WrapValue
752
+ from mplang.backends.simp_driver.state import SimpDriver
753
+ from mplang.backends.simp_driver.values import DriverVar
754
+ from mplang.edsl.context import get_current_context
755
+ from mplang.runtime.interpreter import InterpObject, Interpreter
756
+ from mplang.runtime.value import WrapValue
757
757
 
758
758
  def _unwrap_value(val: Any) -> Any:
759
759
  """Unwrap WrapValue to get the underlying data."""