mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -20,8 +20,8 @@ from __future__ import annotations
20
20
 
21
21
  from typing import Any
22
22
 
23
- from mplang.core.dtype import DType
24
- from mplang.core.expr.ast import (
23
+ from mplang.v1.core.dtypes import DType
24
+ from mplang.v1.core.expr.ast import (
25
25
  AccessExpr,
26
26
  CallExpr,
27
27
  CondExpr,
@@ -35,10 +35,10 @@ from mplang.core.expr.ast import (
35
35
  VariableExpr,
36
36
  WhileExpr,
37
37
  )
38
- from mplang.core.expr.visitor import ExprVisitor
39
- from mplang.core.mptype import MPType
40
- from mplang.core.pfunc import PFunction
41
- from mplang.core.tensor import Shape, TensorType
38
+ from mplang.v1.core.expr.visitor import ExprVisitor
39
+ from mplang.v1.core.mptype import MPType
40
+ from mplang.v1.core.pfunc import PFunction
41
+ from mplang.v1.core.tensor import Shape, TensorType
42
42
 
43
43
 
44
44
  class Printer(ExprVisitor):
@@ -50,11 +50,13 @@ class Printer(ExprVisitor):
50
50
  compact_format: bool = True,
51
51
  *,
52
52
  verbose_peval: bool = False,
53
+ inline_pcall: bool = True,
53
54
  ):
54
55
  super().__init__() # Initialize MemorizedVisitor
55
56
  self.indent_size = indent_size
56
57
  self.compact_format = compact_format
57
58
  self.verbose_peval = verbose_peval
59
+ self.inline_pcall = inline_pcall
58
60
  self._cur_indent = 0
59
61
  self._output: list[str] = []
60
62
  self._visited: dict[Expr, str] = {}
@@ -92,6 +94,7 @@ class Printer(ExprVisitor):
92
94
  body_printer = Printer(
93
95
  indent_size=self.indent_size,
94
96
  compact_format=self.compact_format,
97
+ inline_pcall=self.inline_pcall,
95
98
  )
96
99
  func_def_expr.accept(body_printer)
97
100
  regions_str += f"{indent}{r_name}: "
@@ -161,13 +164,9 @@ class Printer(ExprVisitor):
161
164
  arg_names = [self._var_name(arg) for arg in expr.args]
162
165
  fn_type = expr.pfunc.fn_type
163
166
 
164
- # for well known builtin functions
165
- if fn_type == "builtin.constant":
167
+ # for well known basic functions
168
+ if fn_type == "basic.constant":
166
169
  return self._print_const(expr.pfunc, expr.mptypes)
167
- elif fn_type == "builtin.rank":
168
- return self._do_print("prank", [], mptypes=expr.mptypes)
169
- elif fn_type == "builtin.prand":
170
- return self._do_print("prand", [], mptypes=expr.mptypes)
171
170
 
172
171
  attrs = {"fn_type": fn_type}
173
172
  if expr.pfunc.fn_name:
@@ -209,12 +208,19 @@ class Printer(ExprVisitor):
209
208
 
210
209
  def visit_call(self, expr: CallExpr) -> str:
211
210
  arg_names = [self._var_name(arg) for arg in expr.args]
212
- return self._do_print(
213
- "pcall",
214
- arg_names,
215
- regions={"fn": expr.fn},
216
- mptypes=expr.mptypes,
217
- )
211
+ if self.inline_pcall:
212
+ return self._do_print(
213
+ expr.name,
214
+ arg_names,
215
+ mptypes=expr.mptypes,
216
+ )
217
+ else:
218
+ return self._do_print(
219
+ "pcall",
220
+ arg_names,
221
+ regions={"fn": expr.fn},
222
+ mptypes=expr.mptypes,
223
+ )
218
224
 
219
225
  def visit_while(self, expr: WhileExpr) -> str:
220
226
  arg_names = [self._var_name(arg) for arg in expr.args]
@@ -18,7 +18,7 @@ Expression transformer based on visitor pattern.
18
18
 
19
19
  from collections.abc import Callable
20
20
 
21
- from mplang.core.expr.ast import (
21
+ from mplang.v1.core.expr.ast import (
22
22
  AccessExpr,
23
23
  CallExpr,
24
24
  CondExpr,
@@ -32,7 +32,7 @@ from mplang.core.expr.ast import (
32
32
  VariableExpr,
33
33
  WhileExpr,
34
34
  )
35
- from mplang.core.expr.visitor import ExprVisitor
35
+ from mplang.v1.core.expr.visitor import ExprVisitor
36
36
 
37
37
 
38
38
  class ExprTransformer(ExprVisitor):
@@ -79,7 +79,7 @@ class ExprTransformer(ExprVisitor):
79
79
  def visit_call(self, expr: CallExpr) -> Expr:
80
80
  # Transform child expressions first
81
81
  transformed_args = [arg.accept(self) for arg in expr.args]
82
- new_expr = CallExpr(expr.fn, transformed_args)
82
+ new_expr = CallExpr(expr.name, expr.fn, transformed_args)
83
83
 
84
84
  if "call" in self.trans_rules:
85
85
  return self.trans_rules["call"](new_expr)
@@ -18,8 +18,8 @@ Utility functions for expression system.
18
18
 
19
19
  from collections.abc import Sequence
20
20
 
21
- from mplang.core.mask import Mask
22
- from mplang.core.mptype import TensorLike
21
+ from mplang.v1.core.mask import Mask
22
+ from mplang.v1.core.mptype import TensorLike
23
23
 
24
24
 
25
25
  def type_equal(*args: TensorLike) -> bool:
@@ -22,7 +22,7 @@ from abc import ABC, abstractmethod
22
22
  from typing import TYPE_CHECKING, Any
23
23
 
24
24
  if TYPE_CHECKING:
25
- from mplang.core.expr.ast import (
25
+ from mplang.v1.core.expr.ast import (
26
26
  AccessExpr,
27
27
  CallExpr,
28
28
  CondExpr,
@@ -32,7 +32,7 @@ from collections import deque
32
32
  from collections.abc import Callable, Iterable, Iterator, Sequence
33
33
  from typing import cast
34
34
 
35
- from mplang.core.expr.ast import (
35
+ from mplang.v1.core.expr.ast import (
36
36
  AccessExpr,
37
37
  CallExpr,
38
38
  CondExpr,
@@ -25,12 +25,12 @@ from abc import abstractmethod
25
25
  from collections.abc import Sequence
26
26
  from typing import Any, cast
27
27
 
28
- from mplang.core.cluster import ClusterSpec
29
- from mplang.core.expr.ast import Expr, VariableExpr
30
- from mplang.core.mpobject import MPContext, MPObject
31
- from mplang.core.mptype import MPType, TensorLike
32
- from mplang.core.tracer import TracedFunction
33
- from mplang.utils.func_utils import var_demorph, var_morph
28
+ from mplang.v1.core.cluster import ClusterSpec
29
+ from mplang.v1.core.expr.ast import Expr, VariableExpr
30
+ from mplang.v1.core.mpobject import MPContext, MPObject
31
+ from mplang.v1.core.mptype import MPType, TensorLike
32
+ from mplang.v1.core.tracer import TracedFunction
33
+ from mplang.v1.utils.func_utils import var_demorph, var_morph
34
34
 
35
35
 
36
36
  # TODO(jint): Should we use inheritance or composition here?
@@ -32,9 +32,9 @@ from typing import Any
32
32
  import numpy as np
33
33
  import spu.libspu as spu_api
34
34
 
35
- from mplang.core.dtype import DATE, JSON, STRING, TIME, TIMESTAMP, DType
36
- from mplang.core.expr import Expr, FuncDefExpr
37
- from mplang.core.expr.ast import (
35
+ from mplang.v1.core.dtypes import DATE, JSON, STRING, TIME, TIMESTAMP, DType
36
+ from mplang.v1.core.expr import Expr, FuncDefExpr
37
+ from mplang.v1.core.expr.ast import (
38
38
  AccessExpr,
39
39
  CallExpr,
40
40
  CondExpr,
@@ -46,13 +46,13 @@ from mplang.core.expr.ast import (
46
46
  VariableExpr,
47
47
  WhileExpr,
48
48
  )
49
- from mplang.core.expr.walk import walk
50
- from mplang.core.mask import Mask
51
- from mplang.core.mptype import MPType
52
- from mplang.core.pfunc import PFunction
53
- from mplang.core.table import TableType
54
- from mplang.core.tensor import TensorType
55
- from mplang.protos.v1alpha1 import mpir_pb2
49
+ from mplang.v1.core.expr.walk import walk
50
+ from mplang.v1.core.mask import Mask
51
+ from mplang.v1.core.mptype import MPType
52
+ from mplang.v1.core.pfunc import PFunction
53
+ from mplang.v1.core.table import TableType
54
+ from mplang.v1.core.tensor import TensorType
55
+ from mplang.v1.protos.v1alpha1 import mpir_pb2
56
56
 
57
57
  # Single mapping table for dtype conversion
58
58
  DTYPE_MAPPING = {
@@ -204,7 +204,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
204
204
  raise TypeError(f"Unsupported tuple/list type: {type(py_value)}")
205
205
  elif isinstance(py_value, FuncDefExpr):
206
206
  # Convert FuncDefExpr to GraphProto
207
- graph = Writer().dumps(py_value)
207
+ graph = IrWriter().dumps(py_value)
208
208
  attr_proto.type = mpir_pb2.AttrProto.GRAPH
209
209
  attr_proto.graph.CopyFrom(graph)
210
210
  elif isinstance(py_value, PFunction):
@@ -217,7 +217,9 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
217
217
  # Serialize attrs dictionary
218
218
  if py_value.attrs:
219
219
  for attr_name, attr_value in py_value.attrs.items():
220
- attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
220
+ # Skip None-valued attributes to align with top-level attr handling
221
+ if attr_value is not None:
222
+ attr_proto.func.attrs[attr_name].CopyFrom(attr_to_proto(attr_value))
221
223
 
222
224
  # Note: We don't serialize ins_info and outs_info since they can be
223
225
  # inferred from the input expressions during deserialization
@@ -234,7 +236,7 @@ def attr_to_proto(py_value: Any) -> mpir_pb2.AttrProto:
234
236
  return attr_proto
235
237
 
236
238
 
237
- class Writer:
239
+ class IrWriter:
238
240
  """Writer for serializing Expr-based expressions to GraphProto.
239
241
 
240
242
  This class traverses an expression tree and converts it into a serialized
@@ -491,6 +493,7 @@ class Writer:
491
493
  op = self._create_node_proto(expr, "call")
492
494
  self._add_single_expr_inputs(op, expr.fn)
493
495
  self._add_expr_inputs(op, *expr.args)
496
+ self._add_attrs(op, name=expr.name)
494
497
  self._finalize_node(op, expr)
495
498
  elif isinstance(expr, WhileExpr):
496
499
  op = self._create_node_proto(expr, "while")
@@ -524,7 +527,7 @@ class Writer:
524
527
  raise TypeError(f"Unsupported expr type for serialization: {type(expr)}")
525
528
 
526
529
 
527
- class Reader:
530
+ class IrReader:
528
531
  """Reader for deserializing GraphProto back to Expr-based expressions.
529
532
 
530
533
  This class is responsible for converting serialized GraphProto representations
@@ -822,8 +825,12 @@ class Reader:
822
825
  arg_exprs.append(self._value_cache[dep_name])
823
826
  else:
824
827
  raise ValueError(f"Input {input_name} not found for call node")
828
+ # Optional call-site name attribute
829
+ call_name = None
830
+ if "name" in node_proto.attrs:
831
+ call_name = self._proto_to_attr(node_proto.attrs["name"]) # type: ignore[assignment]
825
832
 
826
- return CallExpr(fn_expr, arg_exprs)
833
+ return CallExpr(call_name or "", fn_expr, arg_exprs)
827
834
 
828
835
  def _proto_to_mptype(self, type_proto: mpir_pb2.MPTypeProto) -> MPType:
829
836
  """Convert MPTypeProto to MPType."""
@@ -897,7 +904,7 @@ class Reader:
897
904
  )
898
905
  elif attr_proto.type == mpir_pb2.AttrProto.GRAPH:
899
906
  # Handle nested expressions (for control flow)
900
- reader = Reader()
907
+ reader = IrReader()
901
908
  return reader.loads(attr_proto.graph)
902
909
  else:
903
910
  raise TypeError(f"Unsupported attribute type: {attr_proto.type}")
@@ -17,14 +17,14 @@ from __future__ import annotations
17
17
  from abc import ABC, abstractmethod
18
18
  from typing import TYPE_CHECKING, Any
19
19
 
20
- from mplang.core.dtype import DType
21
- from mplang.core.mask import Mask
22
- from mplang.core.mptype import MPType
23
- from mplang.core.table import TableType
24
- from mplang.core.tensor import Shape
20
+ from mplang.v1.core.dtypes import DType
21
+ from mplang.v1.core.mask import Mask
22
+ from mplang.v1.core.mptype import MPType
23
+ from mplang.v1.core.table import TableType
24
+ from mplang.v1.core.tensor import Shape
25
25
 
26
26
  if TYPE_CHECKING:
27
- from mplang.core.cluster import ClusterSpec
27
+ from mplang.v1.core.cluster import ClusterSpec
28
28
 
29
29
 
30
30
  class MPContext:
@@ -20,12 +20,12 @@ from typing import TYPE_CHECKING, Any
20
20
  import numpy as np
21
21
 
22
22
  if TYPE_CHECKING:
23
- from mplang.core.mpobject import MPObject
23
+ from mplang.v1.core.mpobject import MPObject
24
24
 
25
- from mplang.core.dtype import STRING, DType
26
- from mplang.core.mask import Mask
27
- from mplang.core.table import TableLike, TableType
28
- from mplang.core.tensor import ScalarType, Shape, TensorLike, TensorType
25
+ from mplang.v1.core.dtypes import STRING, DType
26
+ from mplang.v1.core.mask import Mask
27
+ from mplang.v1.core.table import TableLike, TableType
28
+ from mplang.v1.core.tensor import ScalarType, Shape, TensorLike, TensorType
29
29
 
30
30
  # basic type aliases
31
31
  Rank = int
@@ -195,6 +195,10 @@ class MPType:
195
195
  information about the object."""
196
196
  return self._attrs
197
197
 
198
+ def raw_type(self) -> TensorType | TableType:
199
+ """Get the raw type information (TensorType or TableType)."""
200
+ return self._type
201
+
198
202
  def set_attr(self, key: str, value: Any) -> None:
199
203
  """Set an attribute for this type."""
200
204
  self._attrs[key] = value
@@ -252,9 +256,8 @@ class MPType:
252
256
  if not isinstance(other, MPType):
253
257
  return False
254
258
  return (
255
- self._type == other._type
256
- and self._pmask == other._pmask
257
- and self._attrs == other._attrs
259
+ self._type == other._type and self._pmask == other._pmask
260
+ # and self._attrs == other._attrs # TODO(jint): attrs should be optional
258
261
  )
259
262
 
260
263
  def __hash__(self) -> int:
@@ -270,7 +273,7 @@ class MPType:
270
273
  def isInstance(self, obj: MPObject) -> bool:
271
274
  """Check if the given object is an instance of this MPType."""
272
275
  # Import here to avoid circular import
273
- from mplang.core.mpobject import MPObject
276
+ from mplang.v1.core.mpobject import MPObject
274
277
 
275
278
  if not isinstance(obj, MPObject):
276
279
  return False
@@ -373,7 +376,7 @@ class MPType:
373
376
  import pandas as pd
374
377
 
375
378
  if isinstance(obj, pd.DataFrame):
376
- from mplang.core.dtype import DType
379
+ from mplang.v1.core.dtypes import DType
377
380
 
378
381
  schema_dict = {}
379
382
  for col_name in obj.columns:
@@ -19,8 +19,8 @@ from collections.abc import Sequence
19
19
  from types import MappingProxyType
20
20
  from typing import Any
21
21
 
22
- from mplang.core.table import TableType
23
- from mplang.core.tensor import TensorType
22
+ from mplang.v1.core.table import TableType
23
+ from mplang.v1.core.tensor import TensorType
24
24
 
25
25
  __all__ = [
26
26
  "PFunction",
@@ -33,7 +33,7 @@ class PFunction:
33
33
 
34
34
  PFunction serves as a unified interface for describing single-party computations
35
35
  in multi-party computing scenarios. It can represent both:
36
- 1. Built-in operations (e.g., "spu.makeshares", "builtin.read")
36
+ 1. Built-in operations (e.g., "spu.makeshares", "basic.read")
37
37
  2. User-defined programmable functions with custom code
38
38
 
39
39
  The PFunction accepts a list of typed inputs (TensorType/TableType). For
@@ -47,7 +47,7 @@ class PFunction:
47
47
 
48
48
  Args:
49
49
  fn_type: The type/category identifier of this PFunction, indicating which
50
- backend or handler should process it (e.g., "spu.makeshares", "builtin.read",
50
+ backend or handler should process it (e.g., "spu.makeshares", "basic.read",
51
51
  "mlir.stablehlo"). This serves as a routing mechanism for execution.
52
52
  ins_info: Type information for input parameters (TensorType or TableType)
53
53
  outs_info: Type information for output values (TensorType or TableType)