mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -18,8 +18,8 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any, ClassVar, Self
20
20
 
21
- from mplang.v2.edsl import serde
22
- from mplang.v2.runtime.value import Value
21
+ from mplang.edsl import serde
22
+ from mplang.runtime.value import Value
23
23
 
24
24
 
25
25
  @serde.register_class
@@ -17,9 +17,9 @@
17
17
  Provides Worker-side state and ops for the simp dialect.
18
18
  """
19
19
 
20
- from mplang.v2.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
21
- from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
22
- from mplang.v2.backends.simp_worker.state import SimpWorker
20
+ from mplang.backends.simp_worker.mem import LocalMesh, ThreadCommunicator
21
+ from mplang.backends.simp_worker.ops import WORKER_HANDLERS
22
+ from mplang.backends.simp_worker.state import SimpWorker
23
23
 
24
24
  __all__ = [
25
25
  "WORKER_HANDLERS",
@@ -21,7 +21,7 @@ This module contains:
21
21
 
22
22
  Usage:
23
23
  # Start a worker server
24
- from mplang.v2.backends.simp_http_worker import create_worker_app
24
+ from mplang.backends.simp_http_worker import create_worker_app
25
25
  import uvicorn
26
26
 
27
27
  app = create_worker_app(rank=0, world_size=3, endpoints=[...])
@@ -47,16 +47,16 @@ import httpx
47
47
  from fastapi import FastAPI, HTTPException
48
48
  from pydantic import BaseModel
49
49
 
50
- from mplang.v2.backends import spu_impl as _spu_impl # noqa: F401
51
- from mplang.v2.backends import tensor_impl as _tensor_impl # noqa: F401
50
+ from mplang.backends import spu_impl as _spu_impl # noqa: F401
51
+ from mplang.backends import tensor_impl as _tensor_impl # noqa: F401
52
52
 
53
53
  # Register operation implementations (side-effect imports)
54
- from mplang.v2.backends.simp_worker import SimpWorker
55
- from mplang.v2.backends.simp_worker import ops as _simp_worker_ops # noqa: F401
56
- from mplang.v2.edsl import serde
57
- from mplang.v2.edsl.graph import Graph
58
- from mplang.v2.runtime.interpreter import ExecutionTracer, Interpreter
59
- from mplang.v2.runtime.object_store import ObjectStore
54
+ from mplang.backends.simp_worker import SimpWorker
55
+ from mplang.backends.simp_worker import ops as _simp_worker_ops # noqa: F401
56
+ from mplang.edsl import serde
57
+ from mplang.edsl.graph import Graph
58
+ from mplang.runtime.interpreter import ExecutionTracer, Interpreter
59
+ from mplang.runtime.object_store import ObjectStore
60
60
 
61
61
  logger = logging.getLogger(__name__)
62
62
 
@@ -250,7 +250,7 @@ def create_worker_app(
250
250
  from collections.abc import Callable
251
251
  from typing import cast
252
252
 
253
- from mplang.v2.backends.simp_worker.ops import WORKER_HANDLERS
253
+ from mplang.backends.simp_worker.ops import WORKER_HANDLERS
254
254
 
255
255
  # func_impl is already imported at module level for side-effects
256
256
  handlers: dict[str, Callable[..., Any]] = {**WORKER_HANDLERS} # type: ignore[dict-item]
@@ -53,7 +53,7 @@ class ThreadCommunicator:
53
53
  def send(self, to: int, key: str, data: Any) -> None:
54
54
  assert 0 <= to < self.world_size
55
55
  if self.use_serde:
56
- from mplang.v2.edsl import serde
56
+ from mplang.edsl import serde
57
57
 
58
58
  data = serde.loads(serde.dumps(data))
59
59
  self.peers[to]._on_receive(self.rank, key, data)
@@ -22,9 +22,9 @@ from __future__ import annotations
22
22
 
23
23
  from typing import Any
24
24
 
25
- from mplang.v2.dialects import simp
26
- from mplang.v2.edsl.graph import Operation
27
- from mplang.v2.runtime.interpreter import Interpreter
25
+ from mplang.dialects import simp
26
+ from mplang.edsl.graph import Operation
27
+ from mplang.runtime.interpreter import Interpreter
28
28
 
29
29
 
30
30
  def _ensure_worker_context(interpreter: Any, op_name: str) -> Any:
@@ -111,7 +111,7 @@ def _uniform_cond_worker_impl(
111
111
  interpreter: Interpreter, op: Operation, pred: Any, *args: Any
112
112
  ) -> Any:
113
113
  """Worker implementation of simp.uniform_cond."""
114
- from mplang.v2.backends.tensor_impl import TensorValue
114
+ from mplang.backends.tensor_impl import TensorValue
115
115
 
116
116
  if op.attrs.get("verify_uniform", True):
117
117
  pass # TODO: Implement AllReduce verification
@@ -128,7 +128,7 @@ def _uniform_cond_worker_impl(
128
128
 
129
129
  def _while_loop_worker_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
130
130
  """Worker implementation of simp.while_loop."""
131
- from mplang.v2.backends.tensor_impl import TensorValue
131
+ from mplang.backends.tensor_impl import TensorValue
132
132
 
133
133
  cond_graph = op.regions[0]
134
134
  body_graph = op.regions[1]
@@ -18,10 +18,8 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any
20
20
 
21
- import mplang.v2.backends.field_impl # noqa: F401
22
- import mplang.v2.backends.tensor_impl # noqa: F401
23
- from mplang.v2.runtime.dialect_state import DialectState
24
- from mplang.v2.runtime.object_store import ObjectStore
21
+ from mplang.runtime.dialect_state import DialectState
22
+ from mplang.runtime.object_store import ObjectStore
25
23
 
26
24
 
27
25
  class SimpWorker(DialectState):
@@ -26,13 +26,13 @@ import numpy as np
26
26
  import spu.api as spu_api
27
27
  import spu.libspu as libspu
28
28
 
29
- from mplang.v2.backends.spu_state import SPUState
30
- from mplang.v2.backends.tensor_impl import TensorValue
31
- from mplang.v2.dialects import spu
32
- from mplang.v2.edsl import serde
33
- from mplang.v2.edsl.graph import Operation
34
- from mplang.v2.runtime.interpreter import Interpreter
35
- from mplang.v2.runtime.value import WrapValue
29
+ from mplang.backends.spu_state import SPUState
30
+ from mplang.backends.tensor_impl import TensorValue
31
+ from mplang.dialects import spu
32
+ from mplang.edsl import serde
33
+ from mplang.edsl.graph import Operation
34
+ from mplang.runtime.interpreter import Interpreter
35
+ from mplang.runtime.value import WrapValue
36
36
 
37
37
  # =============================================================================
38
38
  # SPU Share Wrapper
@@ -160,7 +160,7 @@ def exec_impl(interpreter: Interpreter, op: Operation, *args: Any) -> Any:
160
160
  The SPU config must contain parties info to correctly map global rank
161
161
  to local SPU rank and determine SPU world size.
162
162
  """
163
- from mplang.v2.backends.simp_worker.state import SimpWorker
163
+ from mplang.backends.simp_worker.state import SimpWorker
164
164
 
165
165
  # Get SPU config from attrs (passed through from run_jax)
166
166
  config: spu.SPUConfig = op.attrs["config"]
@@ -25,10 +25,10 @@ from typing import TYPE_CHECKING, Any
25
25
  import spu.api as spu_api
26
26
  import spu.libspu as libspu
27
27
 
28
- from mplang.v2.runtime.dialect_state import DialectState
28
+ from mplang.runtime.dialect_state import DialectState
29
29
 
30
30
  if TYPE_CHECKING:
31
- from mplang.v2.dialects import spu
31
+ from mplang.dialects import spu
32
32
 
33
33
 
34
34
  class SPUState(DialectState):
@@ -74,7 +74,7 @@ class SPUState(DialectState):
74
74
  Returns:
75
75
  A tuple of (Runtime, Io) for this party.
76
76
  """
77
- from mplang.v2.backends.spu_impl import to_runtime_config
77
+ from mplang.backends.spu_impl import to_runtime_config
78
78
 
79
79
  # Determine link mode
80
80
  if communicator is not None:
@@ -143,7 +143,7 @@ class SPUState(DialectState):
143
143
  Returns:
144
144
  libspu link context using BaseChannel adapters
145
145
  """
146
- from mplang.v2.backends.channel import BaseChannel
146
+ from mplang.backends.channel import BaseChannel
147
147
 
148
148
  # Get this worker's global rank
149
149
  global_rank = parties[local_rank]
@@ -18,9 +18,9 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any
20
20
 
21
- from mplang.v2.dialects import store
22
- from mplang.v2.edsl.graph import Operation
23
- from mplang.v2.runtime.interpreter import Interpreter
21
+ from mplang.dialects import store
22
+ from mplang.edsl.graph import Operation
23
+ from mplang.runtime.interpreter import Interpreter
24
24
 
25
25
 
26
26
  def _get_uri(uri_base: str) -> str:
@@ -28,13 +28,13 @@ import duckdb
28
28
  import pandas as pd
29
29
  import pyarrow as pa
30
30
 
31
- import mplang.v2.edsl.typing as elt
32
- from mplang.v2.backends.tensor_impl import TensorValue
33
- from mplang.v2.dialects import table
34
- from mplang.v2.edsl import serde
35
- from mplang.v2.edsl.graph import Operation
36
- from mplang.v2.runtime.interpreter import Interpreter
37
- from mplang.v2.runtime.value import WrapValue
31
+ import mplang.edsl.typing as elt
32
+ from mplang.backends.tensor_impl import TensorValue
33
+ from mplang.dialects import table
34
+ from mplang.edsl import serde
35
+ from mplang.edsl.graph import Operation
36
+ from mplang.runtime.interpreter import Interpreter
37
+ from mplang.runtime.value import WrapValue
38
38
 
39
39
 
40
40
  class BatchReader(ABC):
@@ -631,7 +631,7 @@ def table2tensor_impl(interpreter: Interpreter, op: Operation, table_val: Any) -
631
631
 
632
632
  Returns TensorValue if tensor_impl is available, otherwise raw np.ndarray.
633
633
  """
634
- from mplang.v2.backends.tensor_impl import TensorValue
634
+ from mplang.backends.tensor_impl import TensorValue
635
635
 
636
636
  tbl = _unwrap(table_val)
637
637
  df = tbl.to_pandas()
@@ -31,14 +31,14 @@ from typing import TYPE_CHECKING, Any, ClassVar
31
31
 
32
32
  import numpy as np
33
33
 
34
- from mplang.v2.backends.crypto_impl import BytesValue, PublicKeyValue
35
- from mplang.v2.dialects import tee
36
- from mplang.v2.edsl import serde
37
- from mplang.v2.runtime.value import Value
34
+ from mplang.backends.crypto_impl import BytesValue, PublicKeyValue
35
+ from mplang.dialects import tee
36
+ from mplang.edsl import serde
37
+ from mplang.runtime.value import Value
38
38
 
39
39
  if TYPE_CHECKING:
40
- from mplang.v2.edsl.graph import Operation
41
- from mplang.v2.runtime.interpreter import Interpreter
40
+ from mplang.edsl.graph import Operation
41
+ from mplang.runtime.interpreter import Interpreter
42
42
 
43
43
 
44
44
  # ==============================================================================
@@ -32,12 +32,12 @@ import numpy as np
32
32
  from jax._src import compiler
33
33
  from numpy.typing import ArrayLike
34
34
 
35
- import mplang.v2.edsl.typing as elt
36
- from mplang.v2.dialects import dtypes, tensor
37
- from mplang.v2.edsl import serde
38
- from mplang.v2.edsl.graph import Operation
39
- from mplang.v2.runtime.interpreter import Interpreter
40
- from mplang.v2.runtime.value import Value, WrapValue
35
+ import mplang.edsl.typing as elt
36
+ from mplang.dialects import dtypes, tensor
37
+ from mplang.edsl import serde
38
+ from mplang.edsl.graph import Operation
39
+ from mplang.runtime.interpreter import Interpreter
40
+ from mplang.runtime.value import Value, WrapValue
41
41
 
42
42
  # =============================================================================
43
43
  # TensorValue Wrapper
@@ -18,19 +18,19 @@ Command-line interface for MPLang2 clusters and jobs.
18
18
 
19
19
  Examples:
20
20
  # Generate a cluster config file
21
- python -m mplang.v2.cli config gen -w 3 -p 8100 -o cluster.yaml
21
+ python -m mplang.cli config gen -w 3 -p 8100 -o cluster.yaml
22
22
 
23
23
  # Start a single worker (production usage)
24
- python -m mplang.v2.cli worker --rank 0 -c cluster.yaml
24
+ python -m mplang.cli worker --rank 0 -c cluster.yaml
25
25
 
26
26
  # Start 3 local workers (development usage)
27
- python -m mplang.v2.cli up -c cluster.yaml
27
+ python -m mplang.cli up -c cluster.yaml
28
28
 
29
29
  # Check cluster status
30
- python -m mplang.v2.cli status -c cluster.yaml
30
+ python -m mplang.cli status -c cluster.yaml
31
31
 
32
32
  # Run a job
33
- python -m mplang.v2.cli run -c cluster.yaml -f my_job.py
33
+ python -m mplang.cli run -c cluster.yaml -f my_job.py
34
34
  """
35
35
 
36
36
  import argparse
@@ -62,7 +62,7 @@ def run_worker(
62
62
  signal.signal(signal.SIGINT, signal.SIG_DFL)
63
63
  signal.signal(signal.SIGTERM, signal.SIG_DFL)
64
64
 
65
- from mplang.v2.backends.simp_worker.http import create_worker_app
65
+ from mplang.backends.simp_worker.http import create_worker_app
66
66
 
67
67
  app = create_worker_app(rank, world_size, endpoints, spu_endpoints)
68
68
 
@@ -323,9 +323,9 @@ def parse_spu_endpoints(
323
323
 
324
324
  def cmd_run(args: argparse.Namespace) -> None:
325
325
  """Run a user job via HTTP cluster or local simulator."""
326
- from mplang.v2 import make_driver, make_simulator
327
- from mplang.v2.edsl.context import pop_context, push_context
328
- from mplang.v2.libs.device import ClusterSpec
326
+ from mplang import make_driver, make_simulator
327
+ from mplang.edsl.context import pop_context, push_context
328
+ from mplang.libs.device import ClusterSpec
329
329
 
330
330
  cluster: ClusterSpec
331
331
 
@@ -26,7 +26,7 @@ First, generate a `cluster.yaml` file. This defines the topology of your MPLang
26
26
 
27
27
  ```bash
28
28
  # Generate a config for 2 workers starting at port 8100
29
- python -m mplang.v2.cli config gen -w 2 -p 8100 -o cluster.yaml
29
+ python -m mplang.cli config gen -w 2 -p 8100 -o cluster.yaml
30
30
  ```
31
31
 
32
32
  ### 2. Start the Cluster (Terminal 1)
@@ -35,7 +35,7 @@ In your **first terminal**, start the cluster using the `up` command. This will
35
35
 
36
36
  ```bash
37
37
  # Terminal 1
38
- python -m mplang.v2.cli up -c cluster.yaml
38
+ python -m mplang.cli up -c cluster.yaml
39
39
  ```
40
40
 
41
41
  You should see logs indicating that workers have started (e.g., `[Worker 0] INFO: Started server process...`). Keep this terminal open.
@@ -46,7 +46,7 @@ Create a Python script (e.g., `my_job.py`) that defines the computation you want
46
46
 
47
47
  ```python
48
48
  # my_job.py
49
- from mplang.v2.dialects import simp
49
+ from mplang.dialects import simp
50
50
  import numpy as np
51
51
 
52
52
  def main():
@@ -71,7 +71,7 @@ In your **second terminal**, use the `run` command to submit the script to the r
71
71
 
72
72
  ```bash
73
73
  # Terminal 2
74
- python -m mplang.v2.cli run -c cluster.yaml -f my_job.py
74
+ python -m mplang.cli run -c cluster.yaml -f my_job.py
75
75
  ```
76
76
 
77
77
  The CLI will connect to the driver, which orchestrates the execution across the workers.
@@ -82,7 +82,7 @@ You can check the health and latency of your workers at any time.
82
82
 
83
83
  ```bash
84
84
  # Terminal 2
85
- python -m mplang.v2.cli status -c cluster.yaml
85
+ python -m mplang.cli status -c cluster.yaml
86
86
  ```
87
87
 
88
88
  **Output Example:**
@@ -99,7 +99,7 @@ To debug or verify intermediate results, you can list the objects currently stor
99
99
 
100
100
  ```bash
101
101
  # Terminal 2
102
- python -m mplang.v2.cli objects -c cluster.yaml
102
+ python -m mplang.cli objects -c cluster.yaml
103
103
  ```
104
104
 
105
105
  **Output Example:**
@@ -114,9 +114,9 @@ Rank | Endpoint | Count | Objects
114
114
 
115
115
  | Command | Description | Usage |
116
116
  | :--- | :--- | :--- |
117
- | `config gen` | Generate cluster config file | `python -m mplang.v2.cli config gen -w <workers> -o <file>` |
118
- | `up` | Start all workers locally | `python -m mplang.v2.cli up -c <config>` |
119
- | `run` | Submit a job script | `python -m mplang.v2.cli run -c <config> -f <script>` |
120
- | `status` | Check worker health | `python -m mplang.v2.cli status -c <config>` |
121
- | `objects` | List objects on workers | `python -m mplang.v2.cli objects -c <config>` |
122
- | `worker` | Start a single worker (prod) | `python -m mplang.v2.cli worker --rank <id> -c <config>` |
117
+ | `config gen` | Generate cluster config file | `python -m mplang.cli config gen -w <workers> -o <file>` |
118
+ | `up` | Start all workers locally | `python -m mplang.cli up -c <config>` |
119
+ | `run` | Submit a job script | `python -m mplang.cli run -c <config> -f <script>` |
120
+ | `status` | Check worker health | `python -m mplang.cli status -c <config>` |
121
+ | `objects` | List objects on workers | `python -m mplang.cli objects -c <config>` |
122
+ | `worker` | Start a single worker (prod) | `python -m mplang.cli worker --rank <id> -c <config>` |
@@ -29,8 +29,8 @@ from __future__ import annotations
29
29
 
30
30
  # Import dialects to trigger their type registrations
31
31
  # Each dialect module registers its types at import time via _register_*_types()
32
- from mplang.v2.dialects import bfv as _bfv # noqa: F401
33
- from mplang.v2.dialects import crypto as _crypto # noqa: F401
34
- from mplang.v2.dialects import spu as _spu # noqa: F401
35
- from mplang.v2.dialects import store as _store # noqa: F401
36
- from mplang.v2.dialects import tee as _tee # noqa: F401
32
+ from mplang.dialects import bfv as _bfv # noqa: F401
33
+ from mplang.dialects import crypto as _crypto # noqa: F401
34
+ from mplang.dialects import spu as _spu # noqa: F401
35
+ from mplang.dialects import store as _store # noqa: F401
36
+ from mplang.dialects import tee as _tee # noqa: F401
@@ -54,8 +54,8 @@ Architecture:
54
54
 
55
55
  Example:
56
56
  ```python
57
- from mplang.v2.dialects import tensor, bfv
58
- import mplang.v2.edsl.typing as elt
57
+ from mplang.dialects import tensor, bfv
58
+ import mplang.edsl.typing as elt
59
59
  import numpy as np
60
60
 
61
61
  # 1. Setup
@@ -91,9 +91,9 @@ from __future__ import annotations
91
91
 
92
92
  from typing import Any, ClassVar, Literal, cast
93
93
 
94
- import mplang.v2.edsl as el
95
- import mplang.v2.edsl.typing as elt
96
- from mplang.v2.edsl import serde
94
+ import mplang.edsl as el
95
+ import mplang.edsl.typing as elt
96
+ from mplang.edsl import serde
97
97
 
98
98
  # ==============================================================================
99
99
  # --- Type Definitions
@@ -369,7 +369,7 @@ def _batch_encode_trace(
369
369
  encoder: el.Object,
370
370
  key: el.Object,
371
371
  ) -> tuple[el.Object, ...]:
372
- from mplang.v2.edsl.tracer import TraceObject, Tracer
372
+ from mplang.edsl.tracer import TraceObject, Tracer
373
373
 
374
374
  ctx = el.get_current_context()
375
375
  if not isinstance(ctx, Tracer):
@@ -21,9 +21,9 @@ from __future__ import annotations
21
21
 
22
22
  from typing import Any, ClassVar
23
23
 
24
- import mplang.v2.edsl as el
25
- import mplang.v2.edsl.typing as elt
26
- from mplang.v2.edsl import serde
24
+ import mplang.edsl as el
25
+ import mplang.edsl.typing as elt
26
+ from mplang.edsl import serde
27
27
 
28
28
  # ==============================================================================
29
29
  # --- Type Definitions
@@ -607,7 +607,7 @@ def random_tensor(shape: tuple[int, ...], dtype: elt.ScalarType) -> el.Object:
607
607
  import math
608
608
  from typing import cast
609
609
 
610
- from mplang.v2.dialects import dtypes, tensor
610
+ from mplang.dialects import dtypes, tensor
611
611
 
612
612
  # Get byte size from numpy dtype
613
613
  np_dtype = dtypes.to_numpy(dtype)
@@ -644,7 +644,7 @@ def random_bits(n: int) -> el.Object:
644
644
 
645
645
  import jax.numpy as jnp
646
646
 
647
- from mplang.v2.dialects import tensor
647
+ from mplang.dialects import tensor
648
648
 
649
649
  # Generate enough bytes to cover n bits
650
650
  num_bytes = (n + 7) // 8
@@ -18,7 +18,7 @@ This module provides bidirectional conversion between MPLang's type system
18
18
  (ScalarType hierarchy) and external library types (NumPy, JAX, PyArrow, Pandas).
19
19
 
20
20
  Usage:
21
- from mplang.v2.dialects import dtypes
21
+ from mplang.dialects import dtypes
22
22
 
23
23
  # MPLang ScalarType → JAX/NumPy
24
24
  jax_dtype = dtypes.to_jax(scalar_types.f32) # → jnp.float32
@@ -40,7 +40,7 @@ from typing import Any
40
40
  import jax.numpy as jnp
41
41
  import numpy as np
42
42
 
43
- import mplang.v2.edsl.typing as scalar_types
43
+ import mplang.edsl.typing as scalar_types
44
44
 
45
45
  # ==============================================================================
46
46
  # MPLang ScalarType → JAX/NumPy conversion
@@ -29,9 +29,9 @@ from typing import Any, cast
29
29
 
30
30
  import jax.numpy as jnp
31
31
 
32
- import mplang.v2.edsl as el
33
- import mplang.v2.edsl.typing as elt
34
- from mplang.v2.dialects import tensor
32
+ import mplang.edsl as el
33
+ import mplang.edsl.typing as elt
34
+ from mplang.dialects import tensor
35
35
 
36
36
  # =============================================================================
37
37
  # Primitives
@@ -24,8 +24,8 @@ from __future__ import annotations
24
24
  from collections.abc import Callable
25
25
  from typing import Any
26
26
 
27
- import mplang.v2.edsl as el
28
- import mplang.v2.edsl.typing as elt
27
+ import mplang.edsl as el
28
+ import mplang.edsl.typing as elt
29
29
 
30
30
  func_def_p = el.Primitive[el.TraceObject]("func.func")
31
31
  call_p = el.Primitive[Any]("func.call")
@@ -35,8 +35,8 @@ Architecture:
35
35
 
36
36
  Example:
37
37
  ```python
38
- from mplang.v2.dialects import tensor, phe
39
- import mplang.v2.edsl.typing as elt
38
+ from mplang.dialects import tensor, phe
39
+ import mplang.edsl.typing as elt
40
40
  import numpy as np
41
41
 
42
42
  # 1. Generate keys (cryptographic only)
@@ -75,9 +75,9 @@ from __future__ import annotations
75
75
  from collections.abc import Callable
76
76
  from typing import Any, NamedTuple
77
77
 
78
- import mplang.v2.edsl as el
79
- import mplang.v2.edsl.typing as elt
80
- from mplang.v2.dialects import tensor
78
+ import mplang.edsl as el
79
+ import mplang.edsl.typing as elt
80
+ from mplang.dialects import tensor
81
81
 
82
82
  # ==============================================================================
83
83
  # --- Type Definitions
@@ -415,7 +415,7 @@ def create_encoder(
415
415
  PHE encoder configured for the specified dtype
416
416
 
417
417
  Example:
418
- >>> import mplang.v2.edsl.typing as elt
418
+ >>> import mplang.edsl.typing as elt
419
419
  >>>
420
420
  >>> # Float encoder with 16-bit fractional precision
421
421
  >>> encoder_f64 = phe.create_encoder(dtype=elt.f64, fxp_bits=16)
@@ -36,8 +36,8 @@ from typing import Any, cast
36
36
 
37
37
  from jax.tree_util import tree_flatten, tree_unflatten
38
38
 
39
- import mplang.v2.edsl as el
40
- import mplang.v2.edsl.typing as elt
39
+ import mplang.edsl as el
40
+ import mplang.edsl.typing as elt
41
41
 
42
42
  # ---------------------------------------------------------------------------
43
43
  # Global configuration
@@ -809,7 +809,7 @@ def constant(parties: tuple[int, ...], data: Any) -> el.Object:
809
809
  import jax.numpy as jnp
810
810
  import numpy as np
811
811
 
812
- from mplang.v2.dialects import table, tensor
812
+ from mplang.dialects import table, tensor
813
813
 
814
814
  # 1. Scalars (int, float, bool, numpy scalars)
815
815
  if isinstance(data, (int, float, bool, np.number, np.bool_)):
@@ -888,11 +888,11 @@ def make_simulator(
888
888
  ... result = my_func()
889
889
  """
890
890
  if enable_profiling:
891
- from mplang.v2.edsl import registry
891
+ from mplang.edsl import registry
892
892
 
893
893
  registry.enable_profiling()
894
894
 
895
- from mplang.v2.backends.simp_driver.mem import make_simulator as _make_sim
895
+ from mplang.backends.simp_driver.mem import make_simulator as _make_sim
896
896
 
897
897
  return _make_sim(
898
898
  world_size, cluster_spec=cluster_spec, enable_tracing=enable_tracing
@@ -917,7 +917,7 @@ def make_driver(endpoints: list[str], *, cluster_spec: Any = None) -> Any:
917
917
  >>> with interp:
918
918
  ... result = my_func()
919
919
  """
920
- from mplang.v2.backends.simp_driver.http import make_driver as _make_drv
920
+ from mplang.backends.simp_driver.http import make_driver as _make_drv
921
921
 
922
922
  return _make_drv(endpoints, cluster_spec=cluster_spec)
923
923
 
@@ -27,8 +27,8 @@ Concepts:
27
27
  Example:
28
28
  ```python
29
29
  import jax.numpy as jnp
30
- from mplang.v2.dialects import spu, tensor, simp
31
- import mplang.v2.edsl.typing as elt
30
+ from mplang.dialects import spu, tensor, simp
31
+ import mplang.edsl.typing as elt
32
32
 
33
33
  # 0. Setup
34
34
  spu_device = spu.SPUDevice(parties=(0, 1, 2))
@@ -83,11 +83,11 @@ import spu.utils.frontend as spu_fe
83
83
  from jax import ShapeDtypeStruct
84
84
  from jax.tree_util import tree_flatten, tree_unflatten
85
85
 
86
- import mplang.v2.edsl as el
87
- import mplang.v2.edsl.typing as elt
88
- from mplang.v1.utils.func_utils import normalize_fn
89
- from mplang.v2.dialects import dtypes
90
- from mplang.v2.edsl import serde
86
+ import mplang.edsl as el
87
+ import mplang.edsl.typing as elt
88
+ from mplang.dialects import dtypes
89
+ from mplang.edsl import serde
90
+ from mplang.utils import normalize_fn
91
91
 
92
92
  # ==============================================================================
93
93
  # --- Configuration
@@ -16,8 +16,8 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import mplang.v2.edsl as el
20
- import mplang.v2.edsl.typing as elt
19
+ import mplang.edsl as el
20
+ import mplang.edsl.typing as elt
21
21
 
22
22
  save_p: el.Primitive[el.Object] = el.Primitive("store.save")
23
23
  load_p: el.Primitive[el.Object] = el.Primitive("store.load")
@@ -18,8 +18,8 @@ from __future__ import annotations
18
18
 
19
19
  from typing import Any, cast
20
20
 
21
- import mplang.v2.edsl as el
22
- import mplang.v2.edsl.typing as elt
21
+ import mplang.edsl as el
22
+ import mplang.edsl.typing as elt
23
23
 
24
24
  run_sql_p: el.Primitive[Any] = el.Primitive("table.run_sql")
25
25
  table2tensor_p: el.Primitive[el.Object] = el.Primitive("table.table2tensor")
@@ -182,7 +182,7 @@ def _constant_ae(*, data: Any) -> elt.TableType:
182
182
  import pandas as pd
183
183
  import pyarrow as pa
184
184
 
185
- from mplang.v2.dialects import dtypes
185
+ from mplang.dialects import dtypes
186
186
 
187
187
  # Handle PyArrow Table directly
188
188
  if isinstance(data, pa.Table):