mplang-nightly 0.1.dev268__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +46 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +19 -13
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/libs/mpc/psi/rr22.py +303 -0
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang/v2/libs/mpc/psi/rr22.py +0 -344
  162. mplang_nightly-0.1.dev268.dist-info/RECORD +0 -180
  163. /mplang/{v2/backends → backends}/channel.py +0 -0
  164. /mplang/{v2/edsl → edsl}/README.md +0 -0
  165. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  166. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  167. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  168. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  169. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  171. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  172. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  175. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  177. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  178. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  179. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  180. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  181. {mplang_nightly-0.1.dev268.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/host.py DELETED
@@ -1,130 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from collections.abc import Callable
18
- from typing import Any
19
-
20
- from jax.tree_util import tree_map
21
-
22
- from mplang.v1.core import (
23
- ClusterSpec,
24
- InterpContext,
25
- MPContext,
26
- MPObject,
27
- TraceContext,
28
- TracedFunction,
29
- trace,
30
- )
31
- from mplang.v1.core.context_mgr import cur_ctx, with_ctx
32
-
33
-
34
def evaluate(
    interp: InterpContext, mpfn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:  # type: ignore[misc]
    """Call ``mpfn`` while ``interp`` is installed as the active context.

    The signature is deliberately untyped (``Any``) because the callable and
    its arguments are forwarded as-is.

    Args:
        interp: Interpreter context activated through ``with_ctx`` for the
            duration of the call.
        mpfn: The multi-party function to invoke.
        *args: Positional arguments forwarded to ``mpfn``.
        **kwargs: Keyword arguments forwarded to ``mpfn``.

    Returns:
        Any: Whatever ``mpfn(*args, **kwargs)`` returns.

    Raises:
        AssertionError: If ``interp`` is not an ``InterpContext``.
    """
    assert isinstance(interp, InterpContext), f"Expect InterpContext, got {interp}"
    # Keep the return inside the ``with`` so the context manager's exit
    # handling applies to the call exactly as written.
    with with_ctx(interp):
        return mpfn(*args, **kwargs)
55
-
56
-
57
def fetch(interp: InterpContext | None, objs: Any) -> Any:  # type: ignore[misc]
    """Fetch computed results for every MPObject inside a nested structure.

    ``tree_map`` is used to walk ``objs``, so arbitrary nested containers are
    accepted and the same structure is returned.

    Args:
        interp: The interpreter context used for fetching. When ``None`` the
            current context returned by ``cur_ctx()`` is used instead.
        objs: Any nested structure that may contain ``MPObject`` leaves.

    Returns:
        Any: A structure shaped like ``objs`` in which each ``MPObject`` leaf
        is replaced by ``ctx.fetch(leaf)``; other leaves are returned as-is.

    Raises:
        AssertionError: If the resolved context is not an ``InterpContext``.
    """
    # Fall back only on an explicit None, as documented; a truthiness test
    # would also discard a context object that happens to evaluate falsy.
    ctx = cur_ctx() if interp is None else interp
    assert isinstance(ctx, InterpContext), f"Expect InterpContext, got {ctx}"

    # Run the identity function under the interpreter before fetching leaves.
    evaluated = evaluate(ctx, lambda x: x, objs)

    def fetch_impl(arg: MPObject | Any) -> Any:
        # Leaves that are not MPObject are passed through untouched.
        if not isinstance(arg, MPObject):
            return arg

        return ctx.fetch(arg)

    return tree_map(fetch_impl, evaluated)
85
-
86
-
87
class CompileOptions(MPContext):
    """Lightweight ``MPContext`` used for ahead-of-time (AOT) compilation.

    Args:
        cluster_spec: Cluster specification forwarded unchanged to
            ``MPContext.__init__``.
    """

    def __init__(self, cluster_spec: Any) -> None:
        super().__init__(cluster_spec)

    @classmethod
    def simple(cls, world_size: int) -> CompileOptions:
        """Create a CompileOptions from ``ClusterSpec.simple(world_size)``.

        Args:
            world_size: Number of participating parties.

        Returns:
            A CompileOptions instance wrapping the generated cluster spec.
        """
        return cls(ClusterSpec.simple(world_size))
112
-
113
-
114
def compile(
    mctx: MPContext, fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> TracedFunction:
    """Trace ``fn`` under a fresh ``TraceContext`` and return the result.

    Args:
        mctx: Multi-party context; only its ``cluster_spec`` is read here.
        fn: The function to compile.
        *args: Positional arguments handed to ``trace`` along with ``fn``.
        **kwargs: Keyword arguments handed to ``trace`` along with ``fn``.

    Returns:
        TracedFunction: The value returned by ``trace``.
    """
    # A new TraceContext is built for every call from the caller's cluster spec.
    return trace(TraceContext(mctx.cluster_spec), fn, *args, **kwargs)
@@ -1,41 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from mplang.v1.kernels.value import (
16
- BytesBlob,
17
- TableValue,
18
- TensorValue,
19
- Value,
20
- ValueDecodeError,
21
- ValueError,
22
- decode_value,
23
- encode_value,
24
- is_value_envelope,
25
- list_value_kinds,
26
- register_value,
27
- )
28
-
29
# Public names re-exported from mplang.v1.kernels.value.
# NOTE(review): "ValueError" below is the class imported from
# mplang.v1.kernels.value, so in this module's namespace (and for anyone doing
# a star-import of it) the name shadows the builtin ValueError.
__all__ = [
    "BytesBlob",
    "TableValue",
    "TensorValue",
    "Value",
    "ValueDecodeError",
    "ValueError",
    "decode_value",
    "encode_value",
    "is_value_envelope",
    "list_value_kinds",
    "register_value",
]
mplang/v1/kernels/base.py DELETED
@@ -1,125 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Backend kernel registry: mapping kernel_id -> implementation.
16
-
17
- This module provides a lightweight registry for backend kernel implementations.
18
- It does not track or decide which kernel handles a given semantic operation;
19
- that policy (op -> kernel_id) is managed externally by each ``RuntimeContext``.
20
-
21
- Exposed primitives:
22
- * ``@kernel_def(kernel_id)``: decorator to register a kernel implementation.
23
- * ``get_kernel_spec(kernel_id)``: look up a previously registered kernel.
24
- * ``cur_kctx()`` / ``KernelContext``: execution context available only
25
- inside a kernel body (rank, world_size, per-backend state pockets, and a
26
- runtime-wide cache shared by kernels of the same runtime instance).
27
-
28
- No global op binding table exists here; callers resolve an op to a kernel_id
29
- before invoking the kernel function.
30
- """
31
-
32
- from __future__ import annotations
33
-
34
- import contextvars
35
- from collections.abc import Callable
36
- from dataclasses import dataclass
37
- from typing import TYPE_CHECKING, Any
38
-
39
- if TYPE_CHECKING:
40
- from mplang.v1.kernels.context import RuntimeContext
41
-
42
# Public API of the kernel registry module. ``kernel_def`` is listed because
# the module docstring names it as an exposed primitive.
__all__ = [
    "KernelContext",
    "KernelSpec",
    "cur_kctx",
    "get_kernel_spec",
    "kernel_def",
    "kernel_exists",
    "list_kernels",
]
50
-
51
-
52
@dataclass
class KernelContext:
    """Short-lived context describing one kernel invocation.

    State that must outlive a single kernel call (RNGs, compiled artifacts,
    environment handles) belongs on ``RuntimeContext``, not on this object.
    """

    rank: int
    world_size: int
    # Owning runtime; the type is imported only under TYPE_CHECKING.
    runtime: RuntimeContext
63
-
64
-
65
# Context variable read by cur_kctx(); it defaults to None.
# NOTE(review): the code that sets it is not visible in this module —
# presumably the runtime sets it around each kernel call; confirm.
_CTX_VAR: contextvars.ContextVar[KernelContext | None] = contextvars.ContextVar(
    "_flat_backend_ctx", default=None
)


def cur_kctx() -> KernelContext:
    """Return the kernel execution context currently stored in ``_CTX_VAR``.

    Raises:
        RuntimeError: If no context is set, i.e. the call happens outside
            backend kernel execution.
    """
    active = _CTX_VAR.get()
    if active is not None:
        return active
    raise RuntimeError("cur_kctx() called outside backend kernel execution")
76
-
77
-
78
- # ---------------- Registry ----------------
79
-
80
# A kernel is called as ``fn(pfunc, *args)`` and returns a value or a sequence
# of values; keyword arguments are not part of the calling convention.
KernelFn = Callable[..., Any]


@dataclass
class KernelSpec:
    """Registry record pairing a kernel_id with its implementation."""

    kernel_id: str
    fn: KernelFn
    # Free-form keyword metadata supplied at registration time.
    meta: dict[str, Any]
89
-
90
-
91
# Registry of kernel implementations, keyed by kernel_id.
_KERNELS: dict[str, KernelSpec] = {}


def kernel_def(kernel_id: str, /, **meta: Any) -> Callable[[KernelFn], KernelFn]:
    """Return a decorator that registers a kernel implementation.

    Registration only records ``kernel_id -> fn`` in the module registry. No
    op is bound to the kernel here; that mapping is handled outside this
    module.

    Args:
        kernel_id: Unique identifier for the kernel (positional-only).
        **meta: Arbitrary metadata copied onto the resulting ``KernelSpec``.

    Returns:
        A decorator that stores the function and returns it unchanged.

    Raises:
        ValueError: From the decorator, if ``kernel_id`` is already registered.
    """

    def _register(impl: KernelFn) -> KernelFn:
        if kernel_id in _KERNELS:
            raise ValueError(f"duplicate kernel_id={kernel_id}")
        # Copy the metadata so the spec owns its own dict.
        spec = KernelSpec(kernel_id=kernel_id, fn=impl, meta={**meta})
        _KERNELS[kernel_id] = spec
        return impl

    return _register
109
-
110
-
111
def get_kernel_spec(kernel_id: str) -> KernelSpec:
    """Look up the ``KernelSpec`` registered under ``kernel_id``.

    Only the kernel registry is consulted; no op binding is involved.

    Raises:
        KeyError: If ``kernel_id`` has not been registered.
    """
    if kernel_id not in _KERNELS:
        raise KeyError(f"kernel_id {kernel_id} not registered")
    return _KERNELS[kernel_id]
117
-
118
-
119
def list_kernels() -> list[str]:
    """Return all registered kernel ids in sorted order."""
    return sorted(_KERNELS)
121
-
122
-
123
def kernel_exists(kernel_id: str) -> bool:
    """Report whether ``kernel_id`` is present in the kernel registry."""
    registered = kernel_id in _KERNELS
    return registered
@@ -1,240 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import numpy as np
18
-
19
- from mplang.v1.core import PFunction, TableType, TensorType
20
- from mplang.v1.kernels.base import cur_kctx, kernel_def
21
- from mplang.v1.kernels.value import TableValue, TensorValue, Value
22
- from mplang.v1.runtime.data_providers import get_provider, resolve_uri
23
- from mplang.v1.utils import table_utils
24
-
25
-
26
@kernel_def("basic.identity")
def _identity(pfunc: PFunction, value: Value) -> Value:
    """Return the single input value unchanged."""
    # Arity is not checked here; the runtime is relied on to pass exactly one
    # argument.
    return value
30
-
31
-
32
@kernel_def("basic.read")
def _read(pfunc: PFunction) -> Value:
    """Read a value from the location named by the ``path`` attribute.

    The path is resolved to a URI and read through the provider registered
    for its scheme. A provider result that is already a ``Value`` is returned
    as-is; otherwise it is wrapped according to the first declared output
    type.

    Raises:
        ValueError: If the ``path`` attribute is missing.
        NotImplementedError: If no provider handles the URI scheme.
        RuntimeError: If the provider raises while reading.
        TypeError: If the output type is neither ``TableType`` nor
            ``TensorType``.
    """
    location = pfunc.attrs.get("path")
    if location is None:
        raise ValueError("missing path attr for basic.read")
    expected_ty = pfunc.outs_info[0]
    uri = resolve_uri(str(location))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    kctx = cur_kctx()
    try:
        payload = provider.read(uri, expected_ty, ctx=kctx)
    except Exception as e:  # pragma: no cover - provider errors
        raise RuntimeError(f"basic.read failed: {e}") from e

    # Providers may hand back a ready-made Value; nothing to wrap then.
    if isinstance(payload, Value):
        return payload
    if isinstance(expected_ty, TableType):
        return TableValue(payload)
    if isinstance(expected_ty, TensorType):
        return TensorValue(np.asarray(payload))
    raise TypeError(
        f"basic.read only supports TableType/TensorType outputs, got {type(expected_ty).__name__}"
    )
59
-
60
-
61
@kernel_def("basic.write")
def _write(pfunc: PFunction, obj: Value) -> Value:
    """Write ``obj`` to the location named by the ``path`` attribute.

    Returns:
        The same ``obj`` that was passed in.

    Raises:
        ValueError: If the ``path`` attribute is missing.
        NotImplementedError: If no provider handles the URI scheme.
        RuntimeError: If the provider raises while writing.
    """
    location = pfunc.attrs.get("path")
    if location is None:
        raise ValueError("missing path attr for basic.write")
    uri = resolve_uri(str(location))
    provider = get_provider(uri.scheme)
    if provider is None:
        raise NotImplementedError(f"no resource provider for scheme: {uri.scheme}")
    kctx = cur_kctx()
    try:
        # The Value is handed over untouched; the provider decides how to
        # persist it.
        provider.write(uri, obj, ctx=kctx)
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"basic.write failed: {e}") from e
    return obj
77
-
78
-
79
@kernel_def("basic.constant")
def _constant(pfunc: PFunction) -> Value:
    """Materialize the constant carried in the ``data_bytes`` attribute.

    Table outputs are decoded from parquet bytes into a ``TableValue``. Any
    other output type is treated as a tensor and rebuilt from the raw buffer
    using the declared dtype and shape.

    Raises:
        ValueError: If ``data_bytes`` is missing, or a table constant uses a
            ``data_format`` other than ``"bytes[parquet]"``.
    """
    raw = pfunc.attrs.get("data_bytes")
    if raw is None:
        raise ValueError("missing data_bytes attr for basic.constant")
    out_ty = pfunc.outs_info[0]
    fmt = pfunc.attrs.get("data_format")

    if not isinstance(out_ty, TableType):
        # Tensor path: reinterpret the bytes with the declared dtype/shape.
        shape = out_ty.shape  # type: ignore[attr-defined,union-attr]
        dtype = out_ty.dtype.numpy_dtype()  # type: ignore[attr-defined,union-attr]
        return TensorValue(np.frombuffer(raw, dtype=dtype).reshape(shape))

    if fmt != "bytes[parquet]":
        raise ValueError(f"unsupported table constant format {fmt}")
    return TableValue(table_utils.decode_table(raw, format="parquet"))
97
-
98
-
99
@kernel_def("basic.rank")
def _rank(pfunc: PFunction) -> TensorValue:
    """Return the current kernel context's rank as a uint64 scalar tensor."""
    return TensorValue(np.array(cur_kctx().rank, dtype=np.uint64))
105
-
106
-
107
@kernel_def("basic.prand")
def _prand(pfunc: PFunction) -> TensorValue:
    """Return uniformly random uint64 data of the requested shape.

    The shape comes from the ``shape`` attribute and defaults to ``()``.
    Values span the full uint64 range, both ends included.
    """
    dims = pfunc.attrs.get("shape", ())
    bounds = np.iinfo(np.uint64)
    # NOTE(review): default_rng() is a statistical generator, not a CSPRNG.
    # If these values serve as security-sensitive randomness, confirm whether
    # an OS-entropy source should be used instead.
    generator = np.random.default_rng()
    samples = generator.integers(
        low=bounds.min, high=bounds.max, size=dims, dtype=np.uint64, endpoint=True
    )
    return TensorValue(samples)
117
-
118
-
119
@kernel_def("basic.table_to_tensor")
def _table_to_tensor(pfunc: PFunction, table: TableValue) -> TensorValue:
    """Stack the table's columns side by side into one tensor.

    Raises:
        ValueError: If the table has no columns.
    """
    tbl = table.to_arrow()
    ncols = tbl.num_columns
    if ncols == 0:
        raise ValueError("cannot pack empty table")
    # Each Arrow column becomes one numpy array; column_stack joins them.
    columns = [tbl.column(idx).to_numpy() for idx in range(ncols)]
    return TensorValue(np.column_stack(columns))
130
-
131
-
132
@kernel_def("basic.tensor_to_table")
def _tensor_to_table(pfunc: PFunction, tensor: TensorValue) -> TableValue:
    """Turn a rank-2 tensor into a table with one column per tensor column.

    Column names come from the ``column_names`` attribute and must match the
    number of tensor columns exactly (``zip(..., strict=True)``).

    Raises:
        ValueError: If the tensor is not rank-2, ``column_names`` is missing,
            or the name count differs from the column count.
    """
    import pyarrow as pa  # type: ignore

    matrix = tensor.to_numpy()
    if matrix.ndim != 2:
        raise ValueError("tensor_to_table expects rank-2 array")
    names = pfunc.attrs.get("column_names")
    if names is None:
        raise ValueError("missing column_names attr")
    # Build one Arrow array per tensor column, then pair them with the names.
    columns = [pa.array(matrix[:, j]) for j in range(matrix.shape[1])]
    named_columns = dict(zip(names, columns, strict=True))
    return TableValue(pa.table(named_columns))
147
-
148
-
149
def _summ(v: Value) -> str:
    """Render a short, printable summary of ``v``.

    Tables show at most their first 8 rows, tensors are abbreviated through
    ``np.array2string``, and anything else falls back to ``repr``. Any
    failure while rendering yields an ``<unprintable ...>`` placeholder
    instead of raising.
    """
    try:
        if isinstance(v, TableValue):
            tbl = v.to_arrow()
            # Limit the preview to the first 8 rows.
            head = tbl.slice(0, min(8, tbl.num_rows))
            return str(head)
        if isinstance(v, TensorValue):
            text = np.array2string(
                v.to_numpy(), threshold=64, edgeitems=3, precision=6, suppress_small=True
            )
            return str(text)
        return repr(v)
    except Exception as e:  # pragma: no cover
        return f"<unprintable {type(v).__name__}: {e}>"
167
-
168
-
169
@kernel_def("basic.debug_print")
def _debug_print(pfunc: PFunction, val: Value) -> Value:
    """Print a rank-tagged summary of ``val`` and return ``val`` unchanged.

    The optional ``prefix`` attribute is placed in front of the summary.
    """
    label = pfunc.attrs.get("prefix", "")
    rank = cur_kctx().rank
    rendered = _summ(val)
    print(f"[debug_print][rank={rank}] {label}{rendered}")
    return val
175
-
176
-
177
@kernel_def("basic.pack")
def _pack(pfunc: PFunction, value: Value) -> TensorValue:
    """Serialize a table or tensor value into a flat uint8 tensor.

    Tables are written as an Arrow IPC stream; tensors are dumped as their
    C-order raw bytes.

    Raises:
        ValueError: If the kernel does not declare exactly one output.
        TypeError: If the declared output is not a uint8 ``TensorType``, or
            ``value`` is neither a ``TableValue`` nor a ``TensorValue``.
    """
    out_types = pfunc.outs_info
    if len(out_types) != 1:
        raise ValueError("basic.pack expects single output type")
    declared = out_types[0]
    if not isinstance(declared, TensorType):
        raise TypeError("basic.pack must return TensorType")
    if declared.dtype.numpy_dtype() != np.uint8:
        raise TypeError("basic.pack output dtype must be uint8")

    if isinstance(value, TableValue):
        # Arrow IPC stream format, for consistency with Value serde.
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        tbl = value.to_arrow()
        sink = pa.BufferOutputStream()
        with pa_ipc.new_stream(sink, tbl.schema) as writer:  # type: ignore[arg-type]
            writer.write_table(tbl)  # type: ignore[arg-type]
        payload = sink.getvalue().to_pybytes()
        return TensorValue(np.frombuffer(payload, dtype=np.uint8))

    if isinstance(value, TensorValue):
        raw = value.to_numpy().tobytes(order="C")
        return TensorValue(np.frombuffer(raw, dtype=np.uint8))

    raise TypeError(f"basic.pack does not support Value type {type(value).__name__}")
205
-
206
-
207
@kernel_def("basic.unpack")
def _unpack(pfunc: PFunction, packed: TensorValue) -> Value:
    """Rebuild a tensor or table value from a packed byte tensor.

    The declared output type selects the decoding. Tensors are reinterpreted
    from raw bytes with the declared dtype and static shape; tables are read
    back from an Arrow IPC stream.

    Raises:
        ValueError: If the kernel does not declare exactly one output, the
            tensor shape has a negative (dynamic) dimension, or the byte
            count does not match the declared dtype and shape.
        TypeError: If the output type is neither ``TensorType`` nor
            ``TableType``.
    """
    out_types = pfunc.outs_info
    if len(out_types) != 1:
        raise ValueError("basic.unpack expects single output type")
    declared = out_types[0]

    # View the payload as a flat run of bytes.
    flat = packed.to_numpy().astype(np.uint8, copy=False).reshape(-1)

    if isinstance(declared, TensorType):
        np_dtype = declared.dtype.numpy_dtype()
        shape = tuple(declared.shape)
        for extent in shape:
            if extent < 0:
                raise ValueError("basic.unpack does not support dynamic tensor shapes")
        n_elems = int(np.prod(shape))
        expected = n_elems * np.dtype(np_dtype).itemsize
        if flat.size != expected:
            raise ValueError(
                f"unpack size mismatch: got {flat.size} bytes, expect {expected} for {np_dtype} {shape}"
            )
        decoded = np.frombuffer(flat.tobytes(), dtype=np_dtype)
        return TensorValue(decoded.reshape(shape))

    if isinstance(declared, TableType):
        # Read the Arrow IPC stream back into a table.
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        stream = pa_ipc.open_stream(pa.py_buffer(flat.tobytes()))
        return TableValue(stream.read_all())

    raise TypeError("basic.unpack output type must be TensorType or TableType")