mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. mplang/__init__.py +21 -130
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +4 -4
  7. mplang/{core → v1/core}/__init__.py +20 -14
  8. mplang/{core → v1/core}/cluster.py +6 -1
  9. mplang/{core → v1/core}/comm.py +1 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core → v1/core}/dtypes.py +38 -0
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +11 -13
  14. mplang/{core → v1/core}/expr/evaluator.py +8 -8
  15. mplang/{core → v1/core}/expr/printer.py +6 -6
  16. mplang/{core → v1/core}/expr/transformer.py +2 -2
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +13 -11
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +2 -2
  25. mplang/{core → v1/core}/primitive.py +12 -12
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{host.py → v1/host.py} +5 -5
  30. mplang/{kernels → v1/kernels}/__init__.py +1 -1
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/{kernels → v1/kernels}/basic.py +15 -15
  33. mplang/{kernels → v1/kernels}/context.py +19 -16
  34. mplang/{kernels → v1/kernels}/crypto.py +8 -10
  35. mplang/{kernels → v1/kernels}/fhe.py +9 -7
  36. mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
  37. mplang/{kernels → v1/kernels}/phe.py +26 -18
  38. mplang/{kernels → v1/kernels}/spu.py +5 -5
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
  40. mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
  41. mplang/{kernels → v1/kernels}/value.py +2 -2
  42. mplang/{ops → v1/ops}/__init__.py +3 -3
  43. mplang/{ops → v1/ops}/base.py +1 -1
  44. mplang/{ops → v1/ops}/basic.py +6 -5
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/{ops → v1/ops}/fhe.py +2 -2
  47. mplang/{ops → v1/ops}/jax_cc.py +26 -59
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -3
  50. mplang/{ops → v1/ops}/spu.py +3 -3
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +2 -2
  53. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  54. mplang/v1/runtime/channel.py +230 -0
  55. mplang/{runtime → v1/runtime}/cli.py +3 -3
  56. mplang/{runtime → v1/runtime}/client.py +1 -1
  57. mplang/{runtime → v1/runtime}/communicator.py +39 -15
  58. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  59. mplang/{runtime → v1/runtime}/driver.py +4 -4
  60. mplang/v1/runtime/link_comm.py +196 -0
  61. mplang/{runtime → v1/runtime}/server.py +22 -9
  62. mplang/{runtime → v1/runtime}/session.py +24 -51
  63. mplang/{runtime → v1/runtime}/simulation.py +36 -14
  64. mplang/{simp → v1/simp}/api.py +72 -14
  65. mplang/{simp → v1/simp}/mpi.py +1 -1
  66. mplang/{simp → v1/simp}/party.py +5 -5
  67. mplang/{simp → v1/simp}/random.py +2 -2
  68. mplang/v1/simp/smpc.py +238 -0
  69. mplang/v1/utils/table_utils.py +185 -0
  70. mplang/v2/__init__.py +424 -0
  71. mplang/v2/backends/__init__.py +57 -0
  72. mplang/v2/backends/bfv_impl.py +705 -0
  73. mplang/v2/backends/channel.py +217 -0
  74. mplang/v2/backends/crypto_impl.py +723 -0
  75. mplang/v2/backends/field_impl.py +454 -0
  76. mplang/v2/backends/func_impl.py +107 -0
  77. mplang/v2/backends/phe_impl.py +148 -0
  78. mplang/v2/backends/simp_design.md +136 -0
  79. mplang/v2/backends/simp_driver/__init__.py +41 -0
  80. mplang/v2/backends/simp_driver/http.py +168 -0
  81. mplang/v2/backends/simp_driver/mem.py +280 -0
  82. mplang/v2/backends/simp_driver/ops.py +135 -0
  83. mplang/v2/backends/simp_driver/state.py +60 -0
  84. mplang/v2/backends/simp_driver/values.py +52 -0
  85. mplang/v2/backends/simp_worker/__init__.py +29 -0
  86. mplang/v2/backends/simp_worker/http.py +354 -0
  87. mplang/v2/backends/simp_worker/mem.py +102 -0
  88. mplang/v2/backends/simp_worker/ops.py +167 -0
  89. mplang/v2/backends/simp_worker/state.py +49 -0
  90. mplang/v2/backends/spu_impl.py +275 -0
  91. mplang/v2/backends/spu_state.py +187 -0
  92. mplang/v2/backends/store_impl.py +62 -0
  93. mplang/v2/backends/table_impl.py +838 -0
  94. mplang/v2/backends/tee_impl.py +215 -0
  95. mplang/v2/backends/tensor_impl.py +519 -0
  96. mplang/v2/cli.py +603 -0
  97. mplang/v2/cli_guide.md +122 -0
  98. mplang/v2/dialects/__init__.py +36 -0
  99. mplang/v2/dialects/bfv.py +665 -0
  100. mplang/v2/dialects/crypto.py +689 -0
  101. mplang/v2/dialects/dtypes.py +378 -0
  102. mplang/v2/dialects/field.py +210 -0
  103. mplang/v2/dialects/func.py +135 -0
  104. mplang/v2/dialects/phe.py +723 -0
  105. mplang/v2/dialects/simp.py +944 -0
  106. mplang/v2/dialects/spu.py +349 -0
  107. mplang/v2/dialects/store.py +63 -0
  108. mplang/v2/dialects/table.py +407 -0
  109. mplang/v2/dialects/tee.py +346 -0
  110. mplang/v2/dialects/tensor.py +1175 -0
  111. mplang/v2/edsl/README.md +279 -0
  112. mplang/v2/edsl/__init__.py +99 -0
  113. mplang/v2/edsl/context.py +311 -0
  114. mplang/v2/edsl/graph.py +463 -0
  115. mplang/v2/edsl/jit.py +62 -0
  116. mplang/v2/edsl/object.py +53 -0
  117. mplang/v2/edsl/primitive.py +284 -0
  118. mplang/v2/edsl/printer.py +119 -0
  119. mplang/v2/edsl/registry.py +207 -0
  120. mplang/v2/edsl/serde.py +375 -0
  121. mplang/v2/edsl/tracer.py +614 -0
  122. mplang/v2/edsl/typing.py +816 -0
  123. mplang/v2/kernels/Makefile +30 -0
  124. mplang/v2/kernels/__init__.py +23 -0
  125. mplang/v2/kernels/gf128.cpp +148 -0
  126. mplang/v2/kernels/ldpc.cpp +82 -0
  127. mplang/v2/kernels/okvs.cpp +283 -0
  128. mplang/v2/kernels/okvs_opt.cpp +291 -0
  129. mplang/v2/kernels/py_kernels.py +398 -0
  130. mplang/v2/libs/collective.py +330 -0
  131. mplang/v2/libs/device/__init__.py +51 -0
  132. mplang/v2/libs/device/api.py +813 -0
  133. mplang/v2/libs/device/cluster.py +352 -0
  134. mplang/v2/libs/ml/__init__.py +23 -0
  135. mplang/v2/libs/ml/sgb.py +1861 -0
  136. mplang/v2/libs/mpc/__init__.py +41 -0
  137. mplang/v2/libs/mpc/_utils.py +99 -0
  138. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  139. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  140. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  141. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  142. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  143. mplang/v2/libs/mpc/common/constants.py +39 -0
  144. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  145. mplang/v2/libs/mpc/ot/base.py +222 -0
  146. mplang/v2/libs/mpc/ot/extension.py +477 -0
  147. mplang/v2/libs/mpc/ot/silent.py +217 -0
  148. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  149. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  150. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  151. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  152. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  153. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  154. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  155. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  156. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  157. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  158. mplang/v2/libs/mpc/vole/silver.py +336 -0
  159. mplang/v2/runtime/__init__.py +15 -0
  160. mplang/v2/runtime/dialect_state.py +41 -0
  161. mplang/v2/runtime/interpreter.py +871 -0
  162. mplang/v2/runtime/object_store.py +194 -0
  163. mplang/v2/runtime/value.py +141 -0
  164. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
  165. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  166. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  167. mplang/device.py +0 -327
  168. mplang/ops/crypto.py +0 -108
  169. mplang/ops/ibis_cc.py +0 -136
  170. mplang/ops/sql_cc.py +0 -62
  171. mplang/runtime/link_comm.py +0 -78
  172. mplang/simp/smpc.py +0 -201
  173. mplang/utils/table_utils.py +0 -85
  174. mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
  175. /mplang/{core → v1/core}/mask.py +0 -0
  176. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  177. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
  178. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
  179. /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
  180. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  181. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  182. /mplang/{simp → v1/simp}/__init__.py +0 -0
  183. /mplang/{utils → v1/utils}/__init__.py +0 -0
  184. /mplang/{utils → v1/utils}/crypto.py +0 -0
  185. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  186. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  187. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  188. {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -248,6 +248,7 @@ class ClusterSpec:
248
248
  world_size: int,
249
249
  *,
250
250
  endpoints: list[str] | None = None,
251
+ spu_world_size: int | None = None,
251
252
  spu_protocol: str = "SEMI2K",
252
253
  spu_field: str = "FM128",
253
254
  runtime_version: str = "simulated",
@@ -325,10 +326,14 @@ class ClusterSpec:
325
326
 
326
327
  # Shared SPU device
327
328
  if enable_spu_device:
329
+ if spu_world_size is None:
330
+ spu_world_size = world_size
331
+ spu_members = [nodes[f"node{i}"] for i in range(spu_world_size)]
332
+
328
333
  devices["SP0"] = Device(
329
334
  name="SP0",
330
335
  kind="SPU",
331
- members=list(nodes.values()),
336
+ members=spu_members,
332
337
  config={
333
338
  "protocol": spu_protocol,
334
339
  "field": spu_field,
@@ -19,7 +19,7 @@ import threading
19
19
  from abc import ABC, abstractmethod
20
20
  from typing import Any
21
21
 
22
- from mplang.core.mask import Mask
22
+ from mplang.v1.core.mask import Mask
23
23
 
24
24
 
25
25
  class ICommunicator(ABC):
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  # Imported only for typing to avoid import cycles at runtime.
23
- from mplang.core.mpobject import MPContext
23
+ from mplang.v1.core.mpobject import MPContext
24
24
 
25
25
  # The global working context.
26
26
  _g_ctx: MPContext | None = None
@@ -177,6 +177,13 @@ class DType:
177
177
  # TypeError if it's not a pandas dtype we can handle
178
178
  pass
179
179
 
180
+ try:
181
+ return cls._from_arrow_dtype(dtype_like)
182
+ except (ImportError, TypeError):
183
+ # ImportError if pyarrow is not installed
184
+ # TypeError if it's not a pyarrow dtype we can handle
185
+ pass
186
+
180
187
  if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
181
188
  return cls.from_python_type(dtype_like)
182
189
  elif hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
@@ -225,6 +232,37 @@ class DType:
225
232
 
226
233
  raise TypeError(f"Unsupported pandas dtype: {dtype_like}")
227
234
 
235
+ @classmethod
236
+ def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
237
+ try:
238
+ import pyarrow as pa
239
+ except ImportError:
240
+ raise ImportError("pyarrow not available") from None
241
+
242
+ if not isinstance(dtype_like, pa.DataType):
243
+ raise TypeError("Not a pyarrow dtype")
244
+
245
+ ARROW_DTYPE_MAPPING = {
246
+ pa.bool_(): BOOL,
247
+ pa.int8(): INT8,
248
+ pa.int16(): INT16,
249
+ pa.int32(): INT32,
250
+ pa.int64(): INT64,
251
+ pa.uint8(): UINT8,
252
+ pa.uint16(): UINT16,
253
+ pa.uint32(): UINT32,
254
+ pa.uint64(): UINT64,
255
+ pa.float16(): FLOAT16,
256
+ pa.float32(): FLOAT32,
257
+ pa.float64(): FLOAT64,
258
+ pa.string(): STRING,
259
+ pa.large_string(): STRING,
260
+ }
261
+ result = ARROW_DTYPE_MAPPING.get(dtype_like)
262
+ if result is not None:
263
+ return result
264
+ raise TypeError(f"Unsupported arrow dtype: {dtype_like}")
265
+
228
266
  def to_numpy(self) -> np.dtype:
229
267
  """Convert custom DType to NumPy dtype."""
230
268
  return np.dtype(self.name)
@@ -20,7 +20,7 @@ multi-party computation graphs using the visitor pattern.
20
20
  """
21
21
 
22
22
  # Core expression types
23
- from mplang.core.expr.ast import (
23
+ from mplang.v1.core.expr.ast import (
24
24
  AccessExpr,
25
25
  CallExpr,
26
26
  CondExpr,
@@ -36,12 +36,12 @@ from mplang.core.expr.ast import (
36
36
  )
37
37
 
38
38
  # Built-in evaluator engines
39
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
40
- from mplang.core.expr.printer import Printer
41
- from mplang.core.expr.transformer import ExprTransformer
39
+ from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator
40
+ from mplang.v1.core.expr.printer import Printer
41
+ from mplang.v1.core.expr.transformer import ExprTransformer
42
42
 
43
43
  # Utility functions
44
- from mplang.core.expr.utils import (
44
+ from mplang.v1.core.expr.utils import (
45
45
  deduce_mask,
46
46
  ensure_scalar,
47
47
  ensure_tensorlist_equal,
@@ -49,8 +49,8 @@ from mplang.core.expr.utils import (
49
49
  )
50
50
 
51
51
  # Visitor pattern interface
52
- from mplang.core.expr.visitor import ExprVisitor
53
- from mplang.core.expr.walk import walk, walk_dataflow, walk_structural
52
+ from mplang.v1.core.expr.visitor import ExprVisitor
53
+ from mplang.v1.core.expr.walk import walk, walk_dataflow, walk_structural
54
54
 
55
55
  __all__ = [
56
56
  "AccessExpr",
@@ -26,15 +26,15 @@ import logging
26
26
  from abc import ABC, abstractmethod
27
27
  from typing import TYPE_CHECKING, Any
28
28
 
29
- from mplang.core.expr.utils import deduce_mask
30
- from mplang.core.mask import Mask
31
- from mplang.core.mptype import MPType, Rank
32
- from mplang.core.pfunc import PFunction
33
- from mplang.core.table import TableType
34
- from mplang.core.tensor import TensorType
29
+ from mplang.v1.core.expr.utils import deduce_mask
30
+ from mplang.v1.core.mask import Mask
31
+ from mplang.v1.core.mptype import MPType, Rank
32
+ from mplang.v1.core.pfunc import PFunction
33
+ from mplang.v1.core.table import TableType
34
+ from mplang.v1.core.tensor import TensorType
35
35
 
36
36
  if TYPE_CHECKING:
37
- from mplang.core.expr.visitor import ExprVisitor
37
+ from mplang.v1.core.expr.visitor import ExprVisitor
38
38
 
39
39
 
40
40
  class Expr(ABC):
@@ -286,8 +286,8 @@ class ConvExpr(Expr):
286
286
  # Validate dtype / shape consistency.
287
287
  first = types[0]
288
288
  for c in types[1:]:
289
- if (c.dtype, c.shape) != (first.dtype, first.shape):
290
- raise TypeError(f"Inconsistent dtype/shape in pconv: {c} vs {first}")
289
+ if c.raw_type() != first.raw_type():
290
+ raise TypeError(f"Inconsistent type in pconv: {c} vs {first}")
291
291
 
292
292
  # Deduce the pmask by intersecting all pmasks.
293
293
  pmasks = [t.pmask for t in types]
@@ -316,7 +316,7 @@ class ConvExpr(Expr):
316
316
  else:
317
317
  out_pmask = None
318
318
 
319
- return [MPType.tensor(first.dtype, first.shape, out_pmask, **first.attrs)]
319
+ return [MPType(first.raw_type(), out_pmask, first.attrs)]
320
320
 
321
321
  def accept(self, visitor: ExprVisitor) -> Any:
322
322
  return visitor.visit_conv(self)
@@ -398,9 +398,7 @@ class ShflSExpr(Expr):
398
398
  def _compute_mptypes(self) -> list[MPType]:
399
399
  # The types are the same as the source value, but with a new pmask.
400
400
  src_type = self.src_val.mptype
401
- return [
402
- MPType.tensor(src_type.dtype, src_type.shape, self.pmask, **src_type.attrs)
403
- ]
401
+ return [MPType(src_type._type, self.pmask, src_type.attrs)]
404
402
 
405
403
  def accept(self, visitor: ExprVisitor) -> Any:
406
404
  return visitor.visit_shfl_s(self)
@@ -27,8 +27,8 @@ from __future__ import annotations
27
27
  from dataclasses import dataclass
28
28
  from typing import Any, Protocol
29
29
 
30
- from mplang.core.comm import ICommunicator
31
- from mplang.core.expr.ast import (
30
+ from mplang.v1.core.comm import ICommunicator
31
+ from mplang.v1.core.expr.ast import (
32
32
  AccessExpr,
33
33
  CallExpr,
34
34
  CondExpr,
@@ -42,12 +42,12 @@ from mplang.core.expr.ast import (
42
42
  VariableExpr,
43
43
  WhileExpr,
44
44
  )
45
- from mplang.core.expr.visitor import ExprVisitor
46
- from mplang.core.expr.walk import walk_dataflow
47
- from mplang.core.mask import Mask
48
- from mplang.core.pfunc import PFunction
49
- from mplang.kernels.context import RuntimeContext
50
- from mplang.kernels.value import Value
45
+ from mplang.v1.core.expr.visitor import ExprVisitor
46
+ from mplang.v1.core.expr.walk import walk_dataflow
47
+ from mplang.v1.core.mask import Mask
48
+ from mplang.v1.core.pfunc import PFunction
49
+ from mplang.v1.kernels.context import RuntimeContext
50
+ from mplang.v1.kernels.value import Value
51
51
 
52
52
 
53
53
  class IEvaluator(Protocol):
@@ -20,8 +20,8 @@ from __future__ import annotations
20
20
 
21
21
  from typing import Any
22
22
 
23
- from mplang.core.dtypes import DType
24
- from mplang.core.expr.ast import (
23
+ from mplang.v1.core.dtypes import DType
24
+ from mplang.v1.core.expr.ast import (
25
25
  AccessExpr,
26
26
  CallExpr,
27
27
  CondExpr,
@@ -35,10 +35,10 @@ from mplang.core.expr.ast import (
35
35
  VariableExpr,
36
36
  WhileExpr,
37
37
  )
38
- from mplang.core.expr.visitor import ExprVisitor
39
- from mplang.core.mptype import MPType
40
- from mplang.core.pfunc import PFunction
41
- from mplang.core.tensor import Shape, TensorType
38
+ from mplang.v1.core.expr.visitor import ExprVisitor
39
+ from mplang.v1.core.mptype import MPType
40
+ from mplang.v1.core.pfunc import PFunction
41
+ from mplang.v1.core.tensor import Shape, TensorType
42
42
 
43
43
 
44
44
  class Printer(ExprVisitor):
@@ -18,7 +18,7 @@ Expression transformer based on visitor pattern.
18
18
 
19
19
  from collections.abc import Callable
20
20
 
21
- from mplang.core.expr.ast import (
21
+ from mplang.v1.core.expr.ast import (
22
22
  AccessExpr,
23
23
  CallExpr,
24
24
  CondExpr,
@@ -32,7 +32,7 @@ from mplang.core.expr.ast import (
32
32
  VariableExpr,
33
33
  WhileExpr,
34
34
  )
35
- from mplang.core.expr.visitor import ExprVisitor
35
+ from mplang.v1.core.expr.visitor import ExprVisitor
36
36
 
37
37
 
38
38
  class ExprTransformer(ExprVisitor):
@@ -18,8 +18,8 @@ Utility functions for expression system.
18
18
 
19
19
  from collections.abc import Sequence
20
20
 
21
- from mplang.core.mask import Mask
22
- from mplang.core.mptype import TensorLike
21
+ from mplang.v1.core.mask import Mask
22
+ from mplang.v1.core.mptype import TensorLike
23
23
 
24
24
 
25
25
  def type_equal(*args: TensorLike) -> bool:
@@ -22,7 +22,7 @@ from abc import ABC, abstractmethod
22
22
  from typing import TYPE_CHECKING, Any
23
23
 
24
24
  if TYPE_CHECKING:
25
- from mplang.core.expr.ast import (
25
+ from mplang.v1.core.expr.ast import (
26
26
  AccessExpr,
27
27
  CallExpr,
28
28
  CondExpr,
@@ -32,7 +32,7 @@ from collections import deque
32
32
  from collections.abc import Callable, Iterable, Iterator, Sequence
33
33
  from typing import cast
34
34
 
35
- from mplang.core.expr.ast import (
35
+ from mplang.v1.core.expr.ast import (
36
36
  AccessExpr,
37
37
  CallExpr,
38
38
  CondExpr,
@@ -25,12 +25,12 @@ from abc import abstractmethod
25
25
  from collections.abc import Sequence
26
26
  from typing import Any, cast
27
27
 
28
- from mplang.core.cluster import ClusterSpec
29
- from mplang.core.expr.ast import Expr, VariableExpr
30
- from mplang.core.mpobject import MPContext, MPObject
31
- from mplang.core.mptype import MPType, TensorLike
32
- from mplang.core.tracer import TracedFunction
33
- from mplang.utils.func_utils import var_demorph, var_morph
28
+ from mplang.v1.core.cluster import ClusterSpec
29
+ from mplang.v1.core.expr.ast import Expr, VariableExpr
30
+ from mplang.v1.core.mpobject import MPContext, MPObject
31
+ from mplang.v1.core.mptype import MPType, TensorLike
32
+ from mplang.v1.core.tracer import TracedFunction
33
+ from mplang.v1.utils.func_utils import var_demorph, var_morph
34
34
 
35
35
 
36
36
  # TODO(jint): Should we use inheritance or composition here?
@@ -32,9 +32,9 @@ from typing import Any
32
32
  import numpy as np
33
33
  import spu.libspu as spu_api
34
34
 
35
- from mplang.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
36
- from mplang.core.expr import Expr, FuncDefExpr
37
- from mplang.core.expr.ast import (
35
+ from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
36
+ from mplang.v1.core.expr import Expr, FuncDefExpr
37
+ from mplang.v1.core.expr.ast import (
38
38
  AccessExpr,
39
39
  CallExpr,
40
40
  CondExpr,
@@ -46,13 +46,13 @@ from mplang.core.expr.ast import (
46
46
  VariableExpr,
47
47
  WhileExpr,
48
48
  )
49
- from mplang.core.expr.walk import walk
50
- from mplang.core.mask import Mask
51
- from mplang.core.mptype import MPType
52
- from mplang.core.pfunc import PFunction
53
- from mplang.core.table import TableType
54
- from mplang.core.tensor import TensorType
55
- from mplang.protos.v1alpha1 import mpir_pb2
49
+ from mplang.v1.core.expr.walk import walk
50
+ from mplang.v1.core.mask import Mask
51
+ from mplang.v1.core.mptype import MPType
52
+ from mplang.v1.core.pfunc import PFunction
53
+ from mplang.v1.core.table import TableType
54
+ from mplang.v1.core.tensor import TensorType
55
+ from mplang.v1.protos.v1alpha1 import mpir_pb2
56
56
 
57
57
  # Single mapping table for dtype conversion
58
58
  DTYPE_MAPPING = {
@@ -217,7 +217,9 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
217
217
  # Serialize attrs dictionary
218
218
  if py_value.attrs:
219
219
  for attr_name, attr_value in py_value.attrs.items():
220
- attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
220
+ # Skip None-valued attributes to align with top-level attr handling
221
+ if attr_value is not None:
222
+ attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
221
223
 
222
224
  # Note: We don't serialize ins_info and outs_info since they can be
223
225
  # inferred from the input expressions during deserialization
@@ -17,14 +17,14 @@ from __future__ import annotations
17
17
  from abc import ABC, abstractmethod
18
18
  from typing import TYPE_CHECKING, Any
19
19
 
20
- from mplang.core.dtypes import DType
21
- from mplang.core.mask import Mask
22
- from mplang.core.mptype import MPType
23
- from mplang.core.table import TableType
24
- from mplang.core.tensor import Shape
20
+ from mplang.v1.core.dtypes import DType
21
+ from mplang.v1.core.mask import Mask
22
+ from mplang.v1.core.mptype import MPType
23
+ from mplang.v1.core.table import TableType
24
+ from mplang.v1.core.tensor import Shape
25
25
 
26
26
  if TYPE_CHECKING:
27
- from mplang.core.cluster import ClusterSpec
27
+ from mplang.v1.core.cluster import ClusterSpec
28
28
 
29
29
 
30
30
  class MPContext:
@@ -20,12 +20,12 @@ from typing import TYPE_CHECKING, Any
20
20
  import numpy as np
21
21
 
22
22
  if TYPE_CHECKING:
23
- from mplang.core.mpobject import MPObject
23
+ from mplang.v1.core.mpobject import MPObject
24
24
 
25
- from mplang.core.dtypes import STRING, DType
26
- from mplang.core.mask import Mask
27
- from mplang.core.table import TableLike, TableType
28
- from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
25
+ from mplang.v1.core.dtypes import STRING, DType
26
+ from mplang.v1.core.mask import Mask
27
+ from mplang.v1.core.table import TableLike, TableType
28
+ from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
29
29
 
30
30
  # basic type aliases
31
31
  Rank = int
@@ -195,6 +195,10 @@ class MPType:
195
195
  information about the object."""
196
196
  return self._attrs
197
197
 
198
+ def raw_type(self) -> TensorType | TableType:
199
+ """Get the raw type information (TensorType or TableType)."""
200
+ return self._type
201
+
198
202
  def set_attr(self, key: str, value: Any) -> None:
199
203
  """Set an attribute for this type."""
200
204
  self._attrs[key] = value
@@ -252,9 +256,8 @@ class MPType:
252
256
  if not isinstance(other, MPType):
253
257
  return False
254
258
  return (
255
- self._type == other._type
256
- and self._pmask == other._pmask
257
- and self._attrs == other._attrs
259
+ self._type == other._type and self._pmask == other._pmask
260
+ # and self._attrs == other._attrs # TODO(jint): attrs should be optional
258
261
  )
259
262
 
260
263
  def __hash__(self) -> int:
@@ -270,7 +273,7 @@ class MPType:
270
273
  def isInstance(self, obj: MPObject) -> bool:
271
274
  """Check if the given object is an instance of this MPType."""
272
275
  # Import here to avoid circular import
273
- from mplang.core.mpobject import MPObject
276
+ from mplang.v1.core.mpobject import MPObject
274
277
 
275
278
  if not isinstance(obj, MPObject):
276
279
  return False
@@ -373,7 +376,7 @@ class MPType:
373
376
  import pandas as pd
374
377
 
375
378
  if isinstance(obj, pd.DataFrame):
376
- from mplang.core.dtypes import DType
379
+ from mplang.v1.core.dtypes import DType
377
380
 
378
381
  schema_dict = {}
379
382
  for col_name in obj.columns:
@@ -19,8 +19,8 @@ from collections.abc import Sequence
19
19
  from types import MappingProxyType
20
20
  from typing import Any
21
21
 
22
- from mplang.core.table import TableType
23
- from mplang.core.tensor import TensorType
22
+ from mplang.v1.core.table import TableType
23
+ from mplang.v1.core.tensor import TensorType
24
24
 
25
25
  __all__ = [
26
26
  "PFunction",
@@ -28,9 +28,9 @@ from typing import Any, ParamSpec, TypeVar, cast
28
28
 
29
29
  from jax.tree_util import tree_map
30
30
 
31
- from mplang.core.context_mgr import cur_ctx
32
- from mplang.core.dtypes import BOOL
33
- from mplang.core.expr.ast import (
31
+ from mplang.v1.core.context_mgr import cur_ctx
32
+ from mplang.v1.core.dtypes import BOOL
33
+ from mplang.v1.core.expr.ast import (
34
34
  AccessExpr,
35
35
  CallExpr,
36
36
  CondExpr,
@@ -40,13 +40,13 @@ from mplang.core.expr.ast import (
40
40
  ShflSExpr,
41
41
  WhileExpr,
42
42
  )
43
- from mplang.core.interp import InterpContext, InterpVar, apply
44
- from mplang.core.mask import Mask
45
- from mplang.core.mpobject import MPContext, MPObject
46
- from mplang.core.mptype import Rank
47
- from mplang.core.pfunc import PFunction
48
- from mplang.core.tracer import TraceContext, TraceVar, trace
49
- from mplang.utils.func_utils import var_demorph, var_morph
43
+ from mplang.v1.core.interp import InterpContext, InterpVar, apply
44
+ from mplang.v1.core.mask import Mask
45
+ from mplang.v1.core.mpobject import MPContext, MPObject
46
+ from mplang.v1.core.mptype import Rank
47
+ from mplang.v1.core.pfunc import PFunction
48
+ from mplang.v1.core.tracer import TraceContext, TraceVar, trace
49
+ from mplang.v1.utils.func_utils import var_demorph, var_morph
50
50
 
51
51
 
52
52
  def _switch_ctx(ctx: MPContext, obj: MPObject | Any) -> MPObject | Any:
@@ -298,7 +298,7 @@ def uniform_cond(
298
298
 
299
299
  1. ``pred`` is a boolean scalar whose runtime value is identical for every enabled party.
300
300
  2. At least one branch contains multi-party primitives (``seal`` / ``reveal`` /
301
- ``srun`` / ``pshfl`` / mask transformations) whose cost or side-effects you
301
+ ``srun_jax`` / ``pshfl`` / mask transformations) whose cost or side-effects you
302
302
  want to avoid if the branch is not taken.
303
303
  3. You require the semantic guarantee that the *non-selected* branch does **not**
304
304
  perform communication, allocate intermediate buffers, or leak timing/side-effects.
@@ -559,7 +559,7 @@ def while_loop(
559
559
  secret-shared reduction).
560
560
 
561
561
  cond_fn::
562
- sealed_sum = smpc.reveal(smpc.srun(lambda x: jnp.sum(x))(smpc.seal(x)))
562
+ sealed_sum = smpc.reveal(smpc.srun_jax(lambda x: jnp.sum(x), smpc.seal(x)))
563
563
  return sealed_sum < constant(10)
564
564
 
565
565
  body_fn::
@@ -18,16 +18,16 @@ from collections.abc import Iterator
18
18
  from dataclasses import dataclass, field
19
19
  from typing import Any, Protocol, runtime_checkable
20
20
 
21
- from mplang.core.dtypes import DType
21
+ from mplang.v1.core.dtypes import DType
22
22
 
23
23
  __all__ = ["TableLike", "TableType"]
24
24
 
25
25
 
26
26
  @runtime_checkable
27
- class TableLike(Protocol):
27
+ class PandasTableLike(Protocol):
28
28
  """
29
29
  Protocol for objects structurally resembling tables from common libraries
30
- (pandas DataFrame, pyarrow Table, etc.), focusing on dtypes and columns attributes.
30
+ (pandas DataFrame, polars DataFrame, etc.), focusing on dtypes and columns attributes.
31
31
  """
32
32
 
33
33
  @property
@@ -37,6 +37,26 @@ class TableLike(Protocol):
37
37
  def columns(self) -> Any: ...
38
38
 
39
39
 
40
+ @runtime_checkable
41
+ class ArrowSchema(Protocol):
42
+ @property
43
+ def names(self) -> list[str]: ...
44
+ @property
45
+ def types(self) -> list[Any]: ...
46
+
47
+
48
+ @runtime_checkable
49
+ class ArrowTableLike(Protocol):
50
+ @property
51
+ def column_names(self) -> list[str]: ...
52
+
53
+ @property
54
+ def schema(self) -> ArrowSchema: ...
55
+
56
+
57
+ TableLike = PandasTableLike | ArrowTableLike
58
+
59
+
40
60
  @dataclass(frozen=True)
41
61
  class TableType:
42
62
  """Table schema: ordered list of column name-type pairs.
@@ -109,11 +129,19 @@ class TableType:
109
129
  Returns:
110
130
  TableType instance
111
131
  """
112
- columns = [
113
- (name, DType.from_any(dtype))
114
- for name, dtype in zip(table.columns, table.dtypes, strict=True)
115
- ]
116
- return cls(tuple(columns))
132
+ if isinstance(table, PandasTableLike):
133
+ columns = [
134
+ (name, DType.from_any(dtype))
135
+ for name, dtype in zip(table.columns, table.dtypes, strict=True)
136
+ ]
137
+ return cls(tuple(columns))
138
+ elif isinstance(table, ArrowTableLike):
139
+ schema = table.schema
140
+ columns = [
141
+ (name, DType.from_any(dtype))
142
+ for name, dtype in zip(schema.names, schema.types, strict=True)
143
+ ]
144
+ return cls(tuple(columns))
117
145
 
118
146
  def column_names(self) -> tuple[str, ...]:
119
147
  """Get all column names."""
@@ -19,7 +19,7 @@ from typing import Any, Protocol, runtime_checkable
19
19
 
20
20
  import numpy as np
21
21
 
22
- from mplang.core.dtypes import DType
22
+ from mplang.v1.core.dtypes import DType
23
23
 
24
24
  # basic type aliases
25
25
  Shape = tuple[int, ...]
@@ -60,15 +60,15 @@ from collections.abc import Callable
60
60
  from dataclasses import dataclass
61
61
  from typing import Any, cast
62
62
 
63
- from mplang.core.cluster import ClusterSpec
64
- from mplang.core.context_mgr import with_ctx
65
- from mplang.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
66
- from mplang.core.expr.printer import Printer
67
- from mplang.core.mask import Mask
68
- from mplang.core.mpobject import MPContext, MPObject
69
- from mplang.core.mptype import MPType
70
- from mplang.core.pfunc import get_fn_name
71
- from mplang.utils.func_utils import MorphStruct, var_demorph, var_morph
63
+ from mplang.v1.core.cluster import ClusterSpec
64
+ from mplang.v1.core.context_mgr import with_ctx
65
+ from mplang.v1.core.expr.ast import Expr, FuncDefExpr, TupleExpr, VariableExpr
66
+ from mplang.v1.core.expr.printer import Printer
67
+ from mplang.v1.core.mask import Mask
68
+ from mplang.v1.core.mpobject import MPContext, MPObject
69
+ from mplang.v1.core.mptype import MPType
70
+ from mplang.v1.core.pfunc import get_fn_name
71
+ from mplang.v1.utils.func_utils import MorphStruct, var_demorph, var_morph
72
72
 
73
73
 
74
74
  class VarNamer:
@@ -19,7 +19,7 @@ from typing import Any
19
19
 
20
20
  from jax.tree_util import tree_map
21
21
 
22
- from mplang.core import (
22
+ from mplang.v1.core import (
23
23
  ClusterSpec,
24
24
  InterpContext,
25
25
  MPContext,
@@ -28,7 +28,7 @@ from mplang.core import (
28
28
  TracedFunction,
29
29
  trace,
30
30
  )
31
- from mplang.core.context_mgr import cur_ctx, with_ctx
31
+ from mplang.v1.core.context_mgr import cur_ctx, with_ctx
32
32
 
33
33
 
34
34
  def evaluate(
@@ -76,11 +76,11 @@ def fetch(interp: InterpContext | None, objs: Any) -> Any: # type: ignore[misc]
76
76
  evaluated = evaluate(ctx, lambda x: x, objs)
77
77
 
78
78
  def fetch_impl(arg: MPObject | Any) -> Any:
79
- if isinstance(arg, MPObject):
80
- return ctx.fetch(arg)
81
- else:
79
+ if not isinstance(arg, MPObject):
82
80
  return arg
83
81
 
82
+ return ctx.fetch(arg)
83
+
84
84
  return tree_map(fetch_impl, evaluated)
85
85
 
86
86
 
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from mplang.kernels.value import (
15
+ from mplang.v1.kernels.value import (
16
16
  BytesBlob,
17
17
  TableValue,
18
18
  TensorValue,
@@ -37,7 +37,7 @@ from dataclasses import dataclass
37
37
  from typing import TYPE_CHECKING, Any
38
38
 
39
39
  if TYPE_CHECKING:
40
- from mplang.kernels.context import RuntimeContext
40
+ from mplang.v1.kernels.context import RuntimeContext
41
41
 
42
42
  __all__ = [
43
43
  "KernelContext",