mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
@@ -1,369 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from collections.abc import Mapping
18
- from typing import Any
19
-
20
- from mplang.v1.core.dtypes import UINT8, DType
21
- from mplang.v1.core.pfunc import PFunction
22
- from mplang.v1.core.table import PandasTableLike, TableLike, TableType
23
- from mplang.v1.core.tensor import TensorLike, TensorType
24
- from mplang.v1.kernels import base
25
- from mplang.v1.kernels.base import KernelContext, get_kernel_spec, kernel_exists
26
-
27
# Kernel implementation modules register their @kernel_def entries as an
# import side effect. They are imported lazily, the first time a
# RuntimeContext is constructed, instead of at module import time. Each module
# is bound to an underscore alias and marked ``# noqa: F401`` because the
# imported names themselves are never used.
_IMPL_IMPORTED = False


def _ensure_impl_imported() -> None:
    """Import every kernel implementation module exactly once.

    The module-level flag is set only after all imports succeed. If any import
    raises, the exception propagates, the flag stays False, and the next call
    retries the imports.
    """
    global _IMPL_IMPORTED
    if _IMPL_IMPORTED:
        return
    from mplang.v1.kernels import basic as _impl_basic  # noqa: F401
    from mplang.v1.kernels import crypto as _impl_crypto  # noqa: F401
    from mplang.v1.kernels import fhe as _impl_fhe  # noqa: F401
    from mplang.v1.kernels import mock_tee as _impl_tee  # noqa: F401
    from mplang.v1.kernels import phe as _impl_phe  # noqa: F401
    from mplang.v1.kernels import spu as _impl_spu  # noqa: F401
    from mplang.v1.kernels import sql_duckdb as _impl_sql_duckdb  # noqa: F401
    from mplang.v1.kernels import stablehlo as _impl_stablehlo  # noqa: F401

    _IMPL_IMPORTED = True
48
-
49
-
50
# Default op_type -> kernel_id dispatch table used as the template for every
# RuntimeContext. Most ops are served by a kernel registered under the very
# same id, so those are generated from one ordered name list. Ops whose kernel
# id differs are appended explicitly afterwards.
_DEFAULT_BINDINGS: dict[str, str] = {
    **{
        op_name: op_name
        for op_name in (
            # basic
            "basic.identity",
            "basic.read",
            "basic.write",
            "basic.constant",
            "basic.rank",
            "basic.prand",
            "basic.table_to_tensor",
            "basic.tensor_to_table",
            "basic.debug_print",
            "basic.pack",
            "basic.unpack",
            # crypto
            "crypto.keygen",
            "crypto.enc",
            "crypto.dec",
            "crypto.kem_keygen",
            "crypto.kem_derive",
            "crypto.hkdf",
            # phe
            "phe.keygen",
            "phe.encrypt",
            "phe.mul",
            "phe.add",
            "phe.decrypt",
            "phe.dot",
            "phe.gather",
            "phe.scatter",
            "phe.concat",
            "phe.reshape",
            "phe.transpose",
            # fhe
            "fhe.keygen",
            "fhe.encrypt",
            "fhe.decrypt",
            "fhe.add",
            "fhe.mul",
            "fhe.dot",
            "fhe.polyval",
            "fhe.sub",
            "fhe.negate",
            "fhe.square",
            # spu
            "spu.seed_env",
            "spu.makeshares",
            "spu.reconstruct",
            "spu.run_pphlo",
            # stablehlo
            "mlir.stablehlo",
        )
    },
    # sql: the generic SQL op dispatches to the backend-specific duckdb kernel.
    "sql.run": "duckdb.run_sql",
    # NOTE: tee ops ("tee.quote_gen" -> "mock_tee.quote_gen",
    # "tee.attest" -> "mock_tee.attest") are currently not bound by default.
}
109
-
110
-
111
# --- RuntimeContext ---


class RuntimeContext:
    """Per-runtime execution context with isolated op->kernel bindings.

    This object owns ONLY static dispatch metadata ("op bindings") and mutable
    per-rank kernel side state/cache/stats. It does NOT store per-evaluation
    variable bindings (those are provided to the evaluator at evaluation time).

    Parameters
    ----------
    rank : int
        Local rank of this participant.
    world_size : int
        Total number of participants.
    initial_bindings : Mapping[str, str] | None, optional
        Optional partial overrides applied on top of the default binding table
        during construction (override semantics, not replace). These map
        op_type -> kernel_id and form a *template* for dispatch. After
        initialization, all (re)binding must go through ``bind_op`` /
        ``rebind_op`` on this context (scoped to THIS runtime only).
    state : dict, optional
        Mutable per-runtime key/value storage for kernels. Flat key space;
        callers SHOULD use dotted prefixes (e.g. "stablehlo.compile_cache").
        Kernels own their *state* (functional correctness data, caches,
        handles, compiled objects, RNGs, etc.). Runtime does not interpret
        structure; values may themselves be dicts if a kernel wants its own
        pocket. The given dict is used as-is (not copied); created empty when
        omitted.
    stats : dict, optional
        Mutable statistics/telemetry owned by the runtime (usage counters,
        timings, profiling aids). Kernels may increment counters but should
        avoid storing functional state here. A default "op_calls" mapping is
        ensured. Used as-is (not copied); created empty when omitted.
    """

    __slots__ = ("_ibindings", "rank", "state", "stats", "world_size")

    def __init__(
        self,
        rank: int,
        world_size: int,
        initial_bindings: Mapping[str, str] | None = None,
        *,
        state: dict[str, Any] | None = None,
        stats: dict[str, Any] | None = None,
    ) -> None:
        # Make sure every kernel module has registered its @kernel_def entries
        # before any binding can be resolved.
        _ensure_impl_imported()
        self.rank = rank
        self.world_size = world_size
        # Merge defaults with user overrides (override semantics).
        self._ibindings: dict[str, str] = {
            **_DEFAULT_BINDINGS,
            **(initial_bindings or {}),
        }
        self.state = state if state is not None else {}
        self.stats = stats if stats is not None else {}
        self.stats.setdefault("op_calls", {})

    def run_kernel(self, pfunc: PFunction, arg_list: list[Any]) -> list[Any]:
        """Dispatch ``pfunc`` to its bound kernel and return the outputs as a list.

        Resolves the kernel bound to ``pfunc.fn_type`` and validates table and
        tensor arguments against ``pfunc.ins_info``; other spec kinds are not
        checked. It then calls the kernel with a ``KernelContext`` installed,
        increments the ``op_calls`` counter, and normalizes the return value to
        ``len(pfunc.outs_info)`` items.

        Raises
        ------
        NotImplementedError
            No kernel is bound for ``pfunc.fn_type``.
        ValueError
            Argument count, shape, dtype, column count or output count does
            not match the declared signature.
        TypeError
            An argument or the kernel's return value has the wrong kind.
        """
        fn_type = pfunc.fn_type
        kid = self._ibindings.get(fn_type)
        if kid is None:
            raise NotImplementedError(f"no backend kernel registered for op {fn_type}")
        fn = get_kernel_spec(kid).fn  # kernel implementation
        if len(arg_list) != len(pfunc.ins_info):
            raise ValueError(
                f"kernel {fn_type} arg count mismatch: got {len(arg_list)}, expect {len(pfunc.ins_info)}"
            )
        for idx, (ins_spec, val) in enumerate(
            zip(pfunc.ins_info, arg_list, strict=True)
        ):
            if isinstance(ins_spec, TableType):
                _validate_table_arg(fn_type, idx, ins_spec, val)
            elif isinstance(ins_spec, TensorType):
                _validate_tensor_arg(fn_type, idx, ins_spec, val)

        # Install the kernel context only for the duration of the call; the
        # token-based reset restores whatever context was active before.
        kctx = KernelContext(rank=self.rank, world_size=self.world_size, runtime=self)
        token = base._CTX_VAR.set(kctx)
        try:
            raw = fn(pfunc, *arg_list)
        finally:
            base._CTX_VAR.reset(token)

        # Stats are best-effort telemetry: a broken stats mapping must never
        # fail an otherwise successful kernel call.
        try:
            op_calls = self.stats.setdefault("op_calls", {})
            op_calls[fn_type] = op_calls.get(fn_type, 0) + 1
        except Exception:  # pragma: no cover - never raise due to stats
            pass
        return self._normalize_outputs(fn_type, len(pfunc.outs_info), raw)

    @staticmethod
    def _normalize_outputs(fn_type: str, expected: int, raw: Any) -> list[Any]:
        """Coerce a kernel's raw return value into a list of ``expected`` items.

        - ``expected == 0``: ``raw`` must be ``None`` or an empty tuple/list.
        - ``expected == 1``: a bare value is wrapped; a 1-element tuple/list is
          unwrapped.
        - otherwise: ``raw`` must be a tuple/list of exactly ``expected`` items.
        """
        if expected == 0:
            # Explicit checks instead of ``raw in (None, (), [])`` so arbitrary
            # kernel values (e.g. arrays) never hit an ambiguous ``==``.
            if raw is None or (isinstance(raw, (tuple, list)) and len(raw) == 0):
                return []
            raise ValueError(
                f"kernel {fn_type} should return no values; got {type(raw).__name__}"
            )
        if expected == 1:
            if not isinstance(raw, (tuple, list)):
                return [raw]
            if len(raw) != 1:
                raise ValueError(
                    f"kernel {fn_type} produced {len(raw)} outputs, expected 1"
                )
            return [raw[0]]
        if not isinstance(raw, (tuple, list)):
            raise TypeError(
                f"kernel {fn_type} must return sequence (len={expected}), got {type(raw).__name__}"
            )
        if len(raw) != expected:
            raise ValueError(
                f"kernel {fn_type} produced {len(raw)} outputs, expected {expected}"
            )
        return list(raw)

    def reset(self) -> None:
        """Drop all kernel state; op bindings and stats are left untouched."""
        self.state.clear()

    # ---- runtime state API (flat key space) ----
    # Keys are treated atomically; convention encourages dotted prefixes
    # (e.g. 'stablehlo.compile_cache.hash', 'crypto.rng'). Implementation
    # does NOT parse or create hierarchical dicts; any grouping is purely
    # by string prefix. Values themselves MAY be dicts if callers want a
    # manual pocket. This keeps semantics simple and predictable.

    def ensure_state(self, key: str, factory: type | Any = dict) -> Any:
        """Return the value stored under ``key``, creating it via ``factory()`` if absent.

        Presence is decided by key membership, so a value explicitly stored as
        ``None`` is returned as-is instead of being silently replaced. The key
        is not parsed; dotted forms (e.g. 'spu.config') are a single map key.

        Raises
        ------
        ValueError
            ``key`` is empty.
        """
        if not key:
            raise ValueError("empty state key")
        if key not in self.state:
            self.state[key] = factory()
        return self.state[key]

    def get_state(self, key: str, default: Any | None = None) -> Any:
        """Return the value under ``key`` or ``default``; raise ValueError on an empty key."""
        if not key:
            raise ValueError("empty state key")
        return self.state.get(key, default)

    def set_state(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``; raise ValueError on an empty key."""
        if not key:
            raise ValueError("empty state key")
        self.state[key] = value

    def del_state(self, key: str) -> None:
        """Remove ``key`` if present (missing keys are ignored); raise ValueError on an empty key."""
        if not key:
            raise ValueError("empty state key")
        self.state.pop(key, None)

    def list_state(self, prefix: str = "") -> dict[str, Any]:
        """Return a mapping of key -> value, optionally filtered by prefix.

        Matching is string-based: with a non-empty prefix, keep keys equal to
        ``prefix`` or starting with ``prefix + '.'``. A trailing '.' on
        ``prefix`` is not doubled. The result is a new dict.
        """
        if not prefix:
            return dict(self.state)
        pref = prefix if prefix.endswith(".") else prefix + "."
        return {
            k: v for k, v in self.state.items() if k == prefix or k.startswith(pref)
        }

    # ---- explicit (re)binding API ----
    def bind_op(self, op_type: str, kernel_id: str, *, force: bool = False) -> None:
        """Bind an operation to a kernel for THIS context only.

        ``kernel_id`` must already be registered (KeyError otherwise), even
        when the call ends up being a no-op. With ``force=False`` (default) an
        existing binding is kept (no silent override).
        """
        if not kernel_exists(kernel_id):
            raise KeyError(f"kernel_id {kernel_id} not registered")
        if not force and op_type in self._ibindings:
            return
        self._ibindings[op_type] = kernel_id

    def rebind_op(self, op_type: str, kernel_id: str) -> None:
        """Force rebind an operation to a different kernel (shorthand)."""
        self.bind_op(op_type, kernel_id, force=True)

    # Introspection helpers
    def list_bound_ops(self) -> list[str]:  # pragma: no cover - convenience
        """Return all bound op types, sorted."""
        return sorted(self._ibindings.keys())

    def get_binding(self, op_type: str) -> str | None:  # pragma: no cover
        """Return the kernel id bound to ``op_type``, or None."""
        return self._ibindings.get(op_type)

    def __repr__(self) -> str:  # pragma: no cover - debug aid
        return (
            f"RuntimeContext(rank={self.rank}, world_size={self.world_size}, "
            f"bound_ops={len(self._ibindings)})"
        )
311
-
312
-
313
def _validate_table_arg(
    fn_type: str, arg_index: int, spec: TableType, value: Any
) -> None:
    """Check a kernel table argument against its declared ``TableType``.

    Only the number of columns is compared; column names and types are not
    inspected. Pandas-style tables expose ``columns``; other TableLike values
    are expected to expose ``column_names``.

    Raises:
        TypeError: ``value`` is not TableLike.
        ValueError: the column count differs from ``spec.columns``.
    """
    if not isinstance(value, TableLike):
        raise TypeError(
            f"kernel {fn_type} input[{arg_index}] expects TableLike, got {type(value).__name__}"
        )
    if isinstance(value, PandasTableLike):
        columns = value.columns
    else:
        columns = value.column_names
    if len(columns) == len(spec.columns):
        return
    raise ValueError(
        f"kernel {fn_type} input[{arg_index}] column count mismatch: got {len(columns)}, expected {len(spec.columns)}"
    )
327
-
328
-
329
def _validate_tensor_arg(
    fn_type: str, arg_index: int, spec: TensorType, value: Any
) -> None:
    """Check a kernel tensor argument's rank, shape and dtype against ``spec``.

    A spec of shape ``(-1, 0)`` with dtype UINT8 marks an opaque backend handle
    (e.g. PHE keys) and skips all checks. Python scalars are treated as rank-0
    values whose dtype is derived from their Python type. Negative spec
    dimensions are not compared.

    Raises:
        TypeError: ``value`` is neither a Python scalar nor TensorLike, or its
            dtype object cannot be converted by ``DType.from_any``.
        ValueError: rank, a fixed dimension, or the dtype does not match.
    """
    if tuple(spec.shape) == (-1, 0) and spec.dtype == UINT8:
        return

    val_shape: tuple[Any, ...]
    duck_dtype: Any
    if isinstance(value, (int, float, bool, complex)):
        val_shape, duck_dtype = (), type(value)
    elif isinstance(value, TensorLike):
        val_shape = getattr(value, "shape", ())
        duck_dtype = getattr(value, "dtype", None)
    else:
        raise TypeError(
            f"kernel {fn_type} input[{arg_index}] expects TensorLike, got {type(value).__name__}"
        )

    if len(spec.shape) != len(val_shape):
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] rank mismatch: got {val_shape}, expected {spec.shape}"
        )

    paired_dims = zip(spec.shape, val_shape, strict=True)
    for dim_idx, (spec_dim, val_dim) in enumerate(paired_dims):
        # Negative spec dims are dynamic and accept any size.
        if spec_dim < 0 or spec_dim == val_dim:
            continue
        raise ValueError(
            f"kernel {fn_type} input[{arg_index}] shape mismatch at dim {dim_idx}: got {val_dim}, expected {spec_dim}"
        )

    try:
        val_dtype = DType.from_any(duck_dtype)
    except (ValueError, TypeError):  # pragma: no cover
        raise TypeError(
            f"kernel {fn_type} input[{arg_index}] has unsupported dtype object {duck_dtype!r}"
        ) from None
    if val_dtype == spec.dtype:
        return
    raise ValueError(
        f"kernel {fn_type} input[{arg_index}] dtype mismatch: got {val_dtype}, expected {spec.dtype}"
    )
@@ -1,122 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import os
18
-
19
- import numpy as np
20
-
21
- from mplang.v1.core import PFunction
22
- from mplang.v1.kernels.base import cur_kctx, kernel_def
23
- from mplang.v1.kernels.value import TensorValue
24
- from mplang.v1.utils.crypto import blake2b
25
-
26
# Kernels in this module register themselves via @kernel_def side effects;
# there are no public names to export.
__all__: list[str] = list()
27
-
28
-
29
def _get_rng() -> np.random.Generator:
    """Return this rank's RNG for the crypto kernels, creating it on first use.

    The generator is cached in runtime state under ``"crypto.rng"``. Its seed
    is ``MPLANG_CRYPTO_SEED`` (default 0) plus ``rank * 7919``, so each rank
    gets a distinct but fully reproducible stream. That suits tests; it is not
    a source of secure key material.
    """
    kctx = cur_kctx()
    runtime = kctx.runtime
    generator = runtime.get_state("crypto.rng")
    if generator is None:
        base_seed = int(os.environ.get("MPLANG_CRYPTO_SEED", "0"))
        generator = np.random.default_rng(base_seed + kctx.rank * 7919)
        runtime.set_state("crypto.rng", generator)
    # Runtime state is untyped; narrow explicitly for the type checker.
    assert isinstance(generator, np.random.Generator)
    return generator
43
-
44
-
45
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
    """Expand ``key`` and ``nonce`` into ``length`` pseudo-random bytes.

    WARNING (INSECURE): mock hash-based keystream for testing only.

    Block ``i`` is ``blake2b(key || nonce || counter_i)``, with ``counter_i``
    encoded as an 8-byte little-endian integer. Blocks are concatenated and
    truncated to ``length``. The counter makes consecutive blocks differ;
    without it every block would be the same digest and the XOR keystream
    would repeat with a period of one digest.

    Returns ``b""`` when ``length <= 0``.
    """
    out = bytearray()
    counter = 0
    while len(out) < length:
        out.extend(blake2b(key + nonce + counter.to_bytes(8, "little")))
        counter += 1
    return bytes(out[:length])
52
-
53
-
54
@kernel_def("crypto.keygen")
def _crypto_keygen(pfunc: PFunction) -> TensorValue:
    """Generate a random symmetric key as a 1-D uint8 tensor.

    The length comes from the ``length`` attribute (default 32 bytes). The
    bytes are drawn from the per-rank RNG returned by ``_get_rng``.
    """
    n_bytes = int(pfunc.attrs.get("length", 32))
    key_bytes = _get_rng().integers(0, 256, size=(n_bytes,), dtype=np.uint8)
    return TensorValue(key_bytes)
60
-
61
-
62
@kernel_def("crypto.enc")
def _crypto_encrypt(
    pfunc: PFunction, pt_bytes: TensorValue, key: TensorValue
) -> TensorValue:
    """XOR-encrypt ``pt_bytes`` under ``key`` with a fresh 16-byte nonce.

    Output layout is ``nonce (16 bytes) || ciphertext``, where the ciphertext
    has one byte per plaintext element. The plaintext is expected to be 1-D,
    because the result is concatenated with the 1-D nonce. This is an
    insecure mock cipher built on ``_keystream``.
    """
    plain = pt_bytes.to_numpy().astype(np.uint8, copy=False)
    key_arr = key.to_numpy().astype(np.uint8, copy=False)
    nonce = _get_rng().integers(0, 256, size=(16,), dtype=np.uint8)
    keystream = np.frombuffer(
        _keystream(key_arr.tobytes(), nonce.tobytes(), plain.size), dtype=np.uint8
    )
    cipher = (plain ^ keystream).astype(np.uint8)
    return TensorValue(np.concatenate([nonce, cipher]).astype(np.uint8))
76
-
77
-
78
@kernel_def("crypto.dec")
def _crypto_decrypt(
    pfunc: PFunction, ct_with_nonce: TensorValue, key: TensorValue
) -> TensorValue:
    """Decrypt the ``nonce || ciphertext`` layout produced by ``crypto.enc``.

    The first 16 bytes are taken as the nonce. The remainder is XORed with the
    keystream derived from ``key`` and that nonce.

    Raises:
        ValueError: ``ct_with_nonce`` is not a 1-D byte tensor of at least 16
            bytes, so it cannot carry the nonce prefix written by ``crypto.enc``.
    """
    nonce_len = 16  # must match the nonce size written by crypto.enc
    ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
    if ct_np.ndim != 1 or ct_np.size < nonce_len:
        # Without this check a short input would be split into a truncated
        # nonce and an empty plaintext, silently returning garbage.
        raise ValueError(
            f"crypto.dec expects a 1-D uint8 tensor of at least {nonce_len} bytes "
            f"(nonce || ciphertext), got shape {ct_np.shape}"
        )
    key_np = key.to_numpy().astype(np.uint8, copy=False)
    nonce = ct_np[:nonce_len]
    ct = ct_np[nonce_len:]
    stream = np.frombuffer(
        _keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
    )
    pt_bytes = (ct ^ stream).astype(np.uint8)
    return TensorValue(pt_bytes)
91
-
92
-
93
@kernel_def("crypto.kem_keygen")
def _crypto_kem_keygen(pfunc: PFunction) -> tuple[TensorValue, TensorValue]:
    """Produce a mock KEM key pair ``(sk, pk)`` as uint8 tensors.

    ``sk`` is 32 bytes from the per-rank RNG. ``pk`` is the first 32 bytes of
    ``blake2b(sk)``. This is an insecure placeholder, not a real KEM.
    """
    secret_key = _get_rng().integers(0, 256, size=(32,), dtype=np.uint8)
    public_key = np.frombuffer(blake2b(secret_key.tobytes())[:32], dtype=np.uint8)
    return TensorValue(secret_key), TensorValue(public_key)
100
-
101
-
102
@kernel_def("crypto.kem_derive")
def _crypto_kem_derive(
    pfunc: PFunction, sk: TensorValue, peer_pk: TensorValue
) -> TensorValue:
    """Derive a 32-byte shared secret from our ``sk`` and the peer's public key.

    The kernel recomputes our own public key exactly as ``crypto.kem_keygen``
    does (``blake2b(sk)[:32]``), XORs it with ``peer_pk``, and hashes the
    result. Because XOR is symmetric, both parties obtain the same secret.
    NOTE(review): the result depends only on the two public keys, so anyone
    holding both can compute it. This is a mock, not a real KEM.
    """
    own_sk = sk.to_numpy().astype(np.uint8, copy=False)
    peer = peer_pk.to_numpy().astype(np.uint8, copy=False)
    own_pk = np.frombuffer(blake2b(own_sk.tobytes())[:32], dtype=np.uint8)
    mixed = (own_pk ^ peer).astype(np.uint8)
    shared = blake2b(mixed.tobytes())[:32]
    return TensorValue(np.frombuffer(shared, dtype=np.uint8))
114
-
115
-
116
@kernel_def("crypto.hkdf")
def _crypto_hkdf(pfunc: PFunction, secret: TensorValue) -> TensorValue:
    """Derive a key as the first 32 bytes of ``blake2b(secret || info)``.

    ``info`` comes from the ``info`` attribute (default empty), stringified
    and UTF-8 encoded. Despite the op name this is not RFC 5869 HKDF: there is
    no salt and no extract/expand step, just a single hash.
    """
    secret_bytes = secret.to_numpy().astype(np.uint8, copy=False).tobytes()
    info = str(pfunc.attrs.get("info", "")).encode("utf-8")
    digest = blake2b(secret_bytes + info)[:32]
    return TensorValue(np.frombuffer(digest, dtype=np.uint8))