mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +391 -17
- mplang/{v2/backends → backends}/__init__.py +9 -7
- mplang/{v2/backends → backends}/bfv_impl.py +6 -6
- mplang/{v2/backends → backends}/crypto_impl.py +6 -6
- mplang/{v2/backends → backends}/field_impl.py +5 -5
- mplang/{v2/backends → backends}/func_impl.py +4 -4
- mplang/{v2/backends → backends}/phe_impl.py +3 -3
- mplang/{v2/backends → backends}/simp_design.md +1 -1
- mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
- mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
- mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
- mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
- mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
- mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
- mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
- mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
- mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
- mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
- mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
- mplang/{v2/backends → backends}/spu_impl.py +8 -8
- mplang/{v2/backends → backends}/spu_state.py +4 -4
- mplang/{v2/backends → backends}/store_impl.py +3 -3
- mplang/{v2/backends → backends}/table_impl.py +8 -8
- mplang/{v2/backends → backends}/tee_impl.py +6 -6
- mplang/{v2/backends → backends}/tensor_impl.py +6 -6
- mplang/{v2/cli.py → cli.py} +9 -9
- mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
- mplang/{v2/dialects → dialects}/__init__.py +5 -5
- mplang/{v2/dialects → dialects}/bfv.py +6 -6
- mplang/{v2/dialects → dialects}/crypto.py +5 -5
- mplang/{v2/dialects → dialects}/dtypes.py +2 -2
- mplang/{v2/dialects → dialects}/field.py +3 -3
- mplang/{v2/dialects → dialects}/func.py +2 -2
- mplang/{v2/dialects → dialects}/phe.py +6 -6
- mplang/{v2/dialects → dialects}/simp.py +6 -6
- mplang/{v2/dialects → dialects}/spu.py +7 -7
- mplang/{v2/dialects → dialects}/store.py +2 -2
- mplang/{v2/dialects → dialects}/table.py +3 -3
- mplang/{v2/dialects → dialects}/tee.py +6 -6
- mplang/{v2/dialects → dialects}/tensor.py +5 -5
- mplang/{v2/edsl → edsl}/__init__.py +3 -3
- mplang/{v2/edsl → edsl}/context.py +6 -6
- mplang/{v2/edsl → edsl}/graph.py +5 -5
- mplang/{v2/edsl → edsl}/jit.py +2 -2
- mplang/{v2/edsl → edsl}/object.py +1 -1
- mplang/{v2/edsl → edsl}/primitive.py +5 -5
- mplang/{v2/edsl → edsl}/printer.py +1 -1
- mplang/{v2/edsl → edsl}/serde.py +1 -1
- mplang/{v2/edsl → edsl}/tracer.py +7 -7
- mplang/{v2/edsl → edsl}/typing.py +1 -1
- mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
- mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
- mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
- mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
- mplang/{v2/libs → libs}/collective.py +5 -5
- mplang/{v2/libs → libs}/device/__init__.py +1 -1
- mplang/{v2/libs → libs}/device/api.py +12 -12
- mplang/{v2/libs → libs}/ml/__init__.py +1 -1
- mplang/{v2/libs → libs}/ml/sgb.py +4 -4
- mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
- mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
- mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
- mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
- mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
- mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
- mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
- mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
- mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
- mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
- mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
- mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
- mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
- mplang/{v2/runtime → runtime}/interpreter.py +11 -11
- mplang/{v2/runtime → runtime}/value.py +2 -2
- mplang/{v1/runtime → utils}/__init__.py +18 -15
- mplang/{v1/utils → utils}/func_utils.py +1 -1
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/METADATA +2 -2
- mplang_nightly-0.1.dev271.dist-info/RECORD +102 -0
- mplang/v1/__init__.py +0 -157
- mplang/v1/_device.py +0 -602
- mplang/v1/analysis/__init__.py +0 -37
- mplang/v1/analysis/diagram.py +0 -567
- mplang/v1/core/__init__.py +0 -157
- mplang/v1/core/cluster.py +0 -343
- mplang/v1/core/comm.py +0 -281
- mplang/v1/core/context_mgr.py +0 -50
- mplang/v1/core/dtypes.py +0 -335
- mplang/v1/core/expr/__init__.py +0 -80
- mplang/v1/core/expr/ast.py +0 -542
- mplang/v1/core/expr/evaluator.py +0 -581
- mplang/v1/core/expr/printer.py +0 -285
- mplang/v1/core/expr/transformer.py +0 -141
- mplang/v1/core/expr/utils.py +0 -78
- mplang/v1/core/expr/visitor.py +0 -85
- mplang/v1/core/expr/walk.py +0 -387
- mplang/v1/core/interp.py +0 -160
- mplang/v1/core/mask.py +0 -325
- mplang/v1/core/mpir.py +0 -965
- mplang/v1/core/mpobject.py +0 -117
- mplang/v1/core/mptype.py +0 -407
- mplang/v1/core/pfunc.py +0 -130
- mplang/v1/core/primitive.py +0 -877
- mplang/v1/core/table.py +0 -218
- mplang/v1/core/tensor.py +0 -75
- mplang/v1/core/tracer.py +0 -383
- mplang/v1/host.py +0 -130
- mplang/v1/kernels/__init__.py +0 -41
- mplang/v1/kernels/base.py +0 -125
- mplang/v1/kernels/basic.py +0 -240
- mplang/v1/kernels/context.py +0 -369
- mplang/v1/kernels/crypto.py +0 -122
- mplang/v1/kernels/fhe.py +0 -858
- mplang/v1/kernels/mock_tee.py +0 -72
- mplang/v1/kernels/phe.py +0 -1864
- mplang/v1/kernels/spu.py +0 -341
- mplang/v1/kernels/sql_duckdb.py +0 -44
- mplang/v1/kernels/stablehlo.py +0 -90
- mplang/v1/kernels/value.py +0 -626
- mplang/v1/ops/__init__.py +0 -35
- mplang/v1/ops/base.py +0 -424
- mplang/v1/ops/basic.py +0 -294
- mplang/v1/ops/crypto.py +0 -262
- mplang/v1/ops/fhe.py +0 -272
- mplang/v1/ops/jax_cc.py +0 -147
- mplang/v1/ops/nnx_cc.py +0 -168
- mplang/v1/ops/phe.py +0 -216
- mplang/v1/ops/spu.py +0 -151
- mplang/v1/ops/sql_cc.py +0 -303
- mplang/v1/ops/tee.py +0 -36
- mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
- mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
- mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
- mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
- mplang/v1/runtime/channel.py +0 -230
- mplang/v1/runtime/cli.py +0 -451
- mplang/v1/runtime/client.py +0 -456
- mplang/v1/runtime/communicator.py +0 -131
- mplang/v1/runtime/data_providers.py +0 -303
- mplang/v1/runtime/driver.py +0 -324
- mplang/v1/runtime/exceptions.py +0 -27
- mplang/v1/runtime/http_api.md +0 -56
- mplang/v1/runtime/link_comm.py +0 -196
- mplang/v1/runtime/server.py +0 -501
- mplang/v1/runtime/session.py +0 -270
- mplang/v1/runtime/simulation.py +0 -324
- mplang/v1/simp/__init__.py +0 -13
- mplang/v1/simp/api.py +0 -353
- mplang/v1/simp/mpi.py +0 -131
- mplang/v1/simp/party.py +0 -225
- mplang/v1/simp/random.py +0 -120
- mplang/v1/simp/smpc.py +0 -238
- mplang/v1/utils/__init__.py +0 -13
- mplang/v1/utils/crypto.py +0 -32
- mplang/v1/utils/spu_utils.py +0 -130
- mplang/v1/utils/table_utils.py +0 -185
- mplang/v2/__init__.py +0 -424
- mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
- /mplang/{v2/backends → backends}/channel.py +0 -0
- /mplang/{v2/edsl → edsl}/README.md +0 -0
- /mplang/{v2/edsl → edsl}/registry.py +0 -0
- /mplang/{v2/kernels → kernels}/Makefile +0 -0
- /mplang/{v2/kernels → kernels}/__init__.py +0 -0
- /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
- /mplang/{v2/libs → libs}/device/cluster.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
- /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
- /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
- /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/__init__.py +0 -0
- /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
- /mplang/{v2/runtime → runtime}/object_store.py +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev271.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/comm.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import logging
|
|
18
|
-
import threading
|
|
19
|
-
from abc import ABC, abstractmethod
|
|
20
|
-
from typing import Any
|
|
21
|
-
|
|
22
|
-
from mplang.v1.core.mask import Mask
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class ICommunicator(ABC):
    """Abstract point-to-point messaging endpoint.

    An implementation represents one process (``rank``) in a group of
    ``world_size`` processes and exchanges values addressed by string keys.
    """

    @property
    @abstractmethod
    def rank(self) -> int:
        """Rank of this process inside the group."""

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Total number of processes in the group."""

    @abstractmethod
    def new_id(self) -> str:
        """Return a fresh message key; concrete classes must override this."""
        raise NotImplementedError

    @abstractmethod
    def send(self, to: int, key: str, data: Any) -> None:
        """Deliver ``data`` to rank ``to`` under message key ``key``."""

    @abstractmethod
    def recv(self, frm: int, key: str) -> Any:
        """Return the value that rank ``frm`` sent under message key ``key``."""

    @abstractmethod
    def onSent(self, frm: int, key: str, data: Any) -> None:
        """Hook invoked when rank ``frm`` delivers ``data`` to this process under ``key``."""
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class ICollective(ABC):
    """Abstract collective-communication operations.

    Every operation comes in two flavours: one over all processes and an
    ``_m`` variant restricted to the parties selected by the integer party
    mask ``pmask``.
    """

    @abstractmethod
    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Move ``data`` from rank ``frm`` to rank ``to``."""

    @abstractmethod
    def gather(self, root: int, data: Any) -> list[Any]:
        """Collect ``data`` from every process onto ``root``."""

    @abstractmethod
    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        """Collect ``data`` from the parties in ``pmask`` onto ``root``."""

    @abstractmethod
    def scatter(self, root: int, args: list[Any]) -> Any:
        """Distribute one element of ``args`` from ``root`` to each process."""

    @abstractmethod
    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        """Distribute one element of ``args`` from ``root`` to each party in ``pmask``."""

    @abstractmethod
    def allgather(self, arg: Any) -> list[Any]:
        """Make every process's ``arg`` available on every process."""

    @abstractmethod
    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        """Make the ``arg`` of each party in ``pmask`` available across the group."""

    @abstractmethod
    def bcast(self, root: int, arg: Any) -> Any:
        """Send ``arg`` from ``root`` to every process."""

    @abstractmethod
    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        """Send ``arg`` from ``root`` to every party in ``pmask``."""
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def is_rank_in(rank: int, mask: int) -> bool:
    """Return True when bit number ``rank`` of the party bitmask ``mask`` is set.

    A negative ``rank`` raises ValueError (negative shift count).
    """
    # Shift the mask down instead of shifting a 1 up; both test the same bit.
    return (mask >> rank) & 1 == 1
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class CollectiveMixin(ICommunicator, ICollective):
    """Mixin class providing default implementations of collective communication algorithms

    All collectives are built purely from the ICommunicator primitives
    (``rank``, ``world_size``, ``new_id``, ``send``, ``recv``), which the class
    mixing this in must provide. Each collective call draws one id from
    ``new_id()`` and uses it as the message key for all of its traffic.
    NOTE(review): presumably every rank must issue collectives in the same
    order so their ids line up -- confirm against the communicator
    implementations.
    """

    # Module-level logger: the root-logger helpers (logging.debug, ...) call
    # logging.basicConfig() implicitly when the root logger has no handlers,
    # which a library must not do on behalf of its host application.
    _log = logging.getLogger(__name__)

    # Placeholders: the mixing class supplies the real implementations. Note
    # these shadow the abstract declarations from ICommunicator, so a missing
    # implementation surfaces at call time rather than at instantiation.
    @property
    def rank(self) -> int:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    @property
    def world_size(self) -> int:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def send(self, to: int, key: str, data: Any) -> None:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def recv(self, frm: int, key: str) -> Any:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def new_id(self) -> str:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Send ``data`` from rank ``frm`` to rank ``to``.

        p2p is treated as a collective: every rank must call it. Returns the
        received value on rank ``to`` and ``None`` on every other rank.
        """
        assert 0 <= frm < self.world_size
        assert 0 <= to < self.world_size

        cid = self.new_id()

        if self.rank == frm:
            self.send(to, cid, data)

        if self.rank == to:
            return self.recv(frm, cid)
        return None

    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        """Gather ``data`` from the parties in ``pmask`` onto ``root``.

        Returns:
            On ``root``: the received values, ordered by iteration over
            ``Mask(pmask)``. On other ranks: a list of ``None`` of the same
            length.
        """
        assert 0 <= root < self.world_size
        # NOTE: pmask is not validated to be a subset of the world mask.

        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank in mask:
            self.send(root, cid, data)

        if self.rank == root:
            return [self.recv(idx, cid) for idx in mask]
        return [None] * mask.num_parties()

    def gather(self, root: int, data: Any) -> list[Any]:
        """Gather data from all processes to root"""
        pmask = Mask.all(self.world_size)
        return self.gather_m(pmask.value, root, data)

    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        """Send ``args[i]`` from ``root`` to the i-th party of ``Mask(pmask)``.

        Returns:
            The value received by this rank, or ``None`` if this rank is not
            in ``pmask``.
        """
        # Lazy %-formatting: ``args`` may be large and is only rendered when
        # DEBUG logging is actually enabled.
        self._log.debug(
            "[%s]: scatter_m: pmask=%s, root=%s, args=%s", self.rank, pmask, root, args
        )
        assert 0 <= root < self.world_size
        mask = Mask(pmask)
        assert len(args) == mask.num_parties(), f"{len(args)} != {mask.num_parties()}"

        cid = self.new_id()

        if self.rank == root:
            for idx, arg in zip(mask, args, strict=True):
                self.send(idx, cid, arg)

        if self.rank in mask:
            return self.recv(root, cid)
        return None

    def scatter(self, root: int, args: list[Any]) -> Any:
        """Scatter data from root to all processes"""
        pmask = Mask.all(self.world_size)
        return self.scatter_m(pmask.value, root, args)

    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        """Exchange ``arg`` among all parties in ``pmask``.

        Returns:
            On ranks in ``pmask``: every party's ``arg`` ordered by iteration
            over ``Mask(pmask)``. On other ranks: a list of ``None`` of the
            same length.
        """
        self._log.debug("allgather_m: pmask=%s, arg=%s", pmask, arg)
        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank not in mask:
            return [None] * mask.num_parties()

        for idx in mask:
            self.send(idx, cid, arg)
        return [self.recv(idx, cid) for idx in mask]

    def allgather(self, arg: Any) -> list[Any]:
        """Gather data from all processes to all processes"""
        pmask = Mask.all(self.world_size)
        return self.allgather_m(pmask.value, arg)

    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        """Broadcast ``arg`` from ``root`` to the parties in ``pmask``.

        ``root`` itself only receives if it is in ``pmask``. Returns the
        received value on ranks in ``pmask`` and ``None`` elsewhere.
        """
        self._log.debug("bcast_m: pmask=%s, root=%s, arg=%s", pmask, root, arg)
        assert 0 <= root < self.world_size

        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank == root:
            for idx in mask:
                self.send(idx, cid, arg)

        if self.rank in mask:
            return self.recv(root, cid)
        return None

    def bcast(self, root: int, arg: Any) -> Any:
        """Broadcast data from root to all processes"""
        pmask = Mask.all(self.world_size)
        return self.bcast_m(pmask.value, root, arg)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
class CommunicatorBase(ICommunicator):
    """Mailbox-backed base for communicators; subclasses provide ``send``.

    Delivered messages are parked by ``onSent`` under ``(sender_rank, key)``
    and handed out exactly once by ``recv``. One ``threading.Condition``
    guards both the mailbox and the message-id counter.
    """

    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self._world_size = world_size
        # (sender_rank, key) -> payload awaiting recv()
        self._msgboxes: dict = {}
        self._cond = threading.Condition()
        self._counter = 0

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    # override
    def new_id(self) -> str:
        """Return the next id from a per-instance counter as a decimal string."""
        with self._cond:
            current = self._counter
            self._counter = current + 1
        return str(current)

    def recv(self, frm: int, key: str) -> Any:
        """Block until ``(frm, key)`` has arrived, then remove and return it."""
        slot = (frm, key)
        with self._cond:
            # No timeout: waits indefinitely for the matching onSent().
            self._cond.wait_for(lambda: slot in self._msgboxes)
            return self._msgboxes.pop(slot)

    def onSent(self, frm: int, key: str, data: Any) -> None:
        """Store ``data`` from rank ``frm`` under ``key`` and wake waiting receivers.

        Asserts that the same ``(frm, key)`` is not delivered twice before
        being consumed.
        """
        slot = (frm, key)
        with self._cond:
            assert slot not in self._msgboxes, f"{slot} exist {self._msgboxes.keys()}"
            self._msgboxes[slot] = data
            self._cond.notify_all()
|
mplang/v1/core/context_mgr.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
import contextlib
|
|
18
|
-
from collections.abc import Iterator
|
|
19
|
-
from typing import TYPE_CHECKING
|
|
20
|
-
|
|
21
|
-
if TYPE_CHECKING:
|
|
22
|
-
# Imported only for typing to avoid import cycles at runtime.
|
|
23
|
-
from mplang.v1.core.mpobject import MPContext
|
|
24
|
-
|
|
25
|
-
# Process-wide working context read by cur_ctx(); stays None until a context
# is installed with set_ctx() or with_ctx().
_g_ctx: MPContext | None
_g_ctx = None
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def cur_ctx() -> MPContext:
    """Return the currently installed global MPContext.

    Raises:
        ValueError: if no context has been installed yet.
    """
    ctx = _g_ctx
    if ctx is not None:
        return ctx
    # Wording (including the set_interp() mention) is kept verbatim for
    # backward compatibility with existing callers/tests.
    raise ValueError("Interpreter not set. Please call set_interp() first.")
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def set_ctx(ctx: MPContext) -> None:
    """Permanently replace the global working context with ``ctx``.

    The previous context is discarded; use ``with_ctx`` for a scoped swap.
    """
    global _g_ctx
    _g_ctx = ctx
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@contextlib.contextmanager
def with_ctx(tmp_ctx: MPContext) -> Iterator[MPContext]:
    """Temporarily install ``tmp_ctx`` as the global working context.

    On exit -- normal or via an exception -- the previous context (which may
    be None) is put back.

    Yields:
        ``tmp_ctx`` itself.
    """
    global _g_ctx
    previous = _g_ctx
    _g_ctx = tmp_ctx
    try:
        yield tmp_ctx
    finally:
        _g_ctx = previous
|
mplang/v1/core/dtypes.py
DELETED
|
@@ -1,335 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from dataclasses import dataclass
|
|
18
|
-
from typing import Any, final
|
|
19
|
-
|
|
20
|
-
import numpy as np
|
|
21
|
-
|
|
22
|
-
try:
|
|
23
|
-
# Check if JAX is available
|
|
24
|
-
import jax
|
|
25
|
-
import jax.numpy as jnp
|
|
26
|
-
|
|
27
|
-
_JAX_AVAILABLE = True
|
|
28
|
-
except ImportError:
|
|
29
|
-
_JAX_AVAILABLE = False
|
|
30
|
-
|
|
31
|
-
# Public API: the dtype constants (alphabetical), the DType class and the
# module-level conversion helpers.
__all__ = [
    "BINARY", "BOOL", "COMPLEX64", "COMPLEX128", "DATE", "DECIMAL",
    "FLOAT16", "FLOAT32", "FLOAT64",
    "INT8", "INT16", "INT32", "INT64",
    "INTERVAL", "JSON", "STRING", "TIME", "TIMESTAMP",
    "UINT8", "UINT16", "UINT32", "UINT64",
    "UUID",
    "DType", "from_numpy", "to_numpy",
]
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
@final
@dataclass(frozen=True)
class DType:
    """Custom dtype representation that can convert between different libraries.

    Instances are immutable value objects (frozen dataclass), so two DTypes
    with identical fields compare equal. Numeric DTypes use the NumPy dtype
    name as ``name``; table-only types use their own names and may have no
    NumPy/JAX equivalent.
    """

    name: str
    # Width in bits; 0 is used for variable-length table types (string, ...).
    bitwidth: int
    is_signed: bool | None = None  # None for non-numeric types
    is_floating: bool = False
    is_complex: bool = False
    is_table_only: bool = False  # True for types only supported in tables

    def __post_init__(self) -> None:
        # Validate the dtype configuration
        if self.is_complex and not self.is_floating:
            raise ValueError("Complex types must be floating point")
        if self.is_floating and self.is_signed is None:
            # Floating point types are always signed; frozen dataclass needs
            # object.__setattr__ to patch a field after construction.
            object.__setattr__(self, "is_signed", True)

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"DType('{self.name}')"

    def short_name(self) -> str:
        """Return a compact name (e.g. ``"i32"``, ``"f64"``); unknown names pass through."""
        name_map = {
            "bool": "bool",
            "int8": "i8",
            "int16": "i16",
            "int32": "i32",
            "int64": "i64",
            "uint8": "u8",
            "uint16": "u16",
            "uint32": "u32",
            "uint64": "u64",
            "float16": "f16",
            "float32": "f32",
            "float64": "f64",
            "complex64": "c64",
            "complex128": "c128",
            # Table-only types
            "string": "str",
            "date": "date",
            "time": "time",
            "timestamp": "timestamp",
            "decimal": "decimal",
            "binary": "binary",
            "json": "json",
            "uuid": "uuid",
            "interval": "interval",
        }
        return name_map.get(self.name, self.name)

    @classmethod
    def from_numpy(cls, np_dtype: Any) -> DType:
        """Convert from NumPy dtype (or anything ``np.dtype`` accepts) to DType.

        Raises:
            ValueError: for NumPy dtype kinds that have no DType mapping.
            TypeError: if ``np.dtype`` cannot interpret ``np_dtype``.
        """
        np_dtype = np.dtype(np_dtype)
        name = np_dtype.name

        if np_dtype.kind == "b":  # boolean
            return cls(name, 8, None, False, False)  # bool is typically 8 bits
        elif np_dtype.kind in ("i", "u"):  # integer
            return cls(name, np_dtype.itemsize * 8, np_dtype.kind == "i", False, False)
        elif np_dtype.kind == "f":  # floating
            return cls(name, np_dtype.itemsize * 8, True, True, False)
        elif np_dtype.kind == "c":  # complex
            return cls(name, np_dtype.itemsize * 8, True, True, True)
        elif np_dtype.kind in ("U", "S", "O"):  # unicode, byte string, or object
            # All string-like dtypes (and object, which commonly holds
            # strings) map to the variable-length STRING type; the fixed
            # per-element width of "U"/"S" dtypes is not preserved.
            return STRING
        else:
            raise ValueError(f"Unsupported NumPy dtype kind: {np_dtype.kind}")

    @classmethod
    def from_jax(cls, jax_dtype: Any) -> DType:
        """Convert from JAX dtype to custom DType.

        PRNG key dtypes are mapped to uint32; everything else goes through
        :meth:`from_numpy`.

        Raises:
            ImportError: if JAX is not installed.
        """
        if not _JAX_AVAILABLE:
            raise ImportError("JAX is not available")
        # Special handling for PRNG KeyTy: <class jax._src.prng.KeyTy>
        if jnp.issubdtype(jax_dtype, jax.dtypes.prng_key):
            return cls.from_numpy(np.uint32)

        # JAX dtypes are essentially NumPy dtypes
        return cls.from_numpy(jax_dtype)

    @classmethod
    def from_python_type(cls, py_type: type) -> DType:
        """Convert a builtin ``bool``/``int``/``float``/``complex`` to DType.

        Raises:
            ValueError: for any other type.
        """
        if py_type is bool:
            return cls("bool", 8, None, False, False)
        elif py_type is int:
            # Python ints are mapped to int64
            return cls("int64", 64, True, False, False)
        elif py_type is float:
            return cls("float64", 64, True, True, False)
        elif py_type is complex:
            return cls("complex128", 128, True, True, True)
        else:
            raise ValueError(f"Unsupported Python type: {py_type}")

    @classmethod
    def from_any(cls, dtype_like: Any) -> DType:
        """Convert from any supported dtype representation.

        Strategies are tried in order: DType passthrough, pandas dtypes,
        pyarrow types, Python builtin types, objects exposing ``.dtype``
        (arrays, Series, ...), NumPy, then JAX.

        Raises:
            ValueError: if no strategy can interpret ``dtype_like``. Errors
                raised while converting an object's ``.dtype`` propagate
                unchanged.
        """
        if isinstance(dtype_like, cls):
            return dtype_like

        # Try pandas specific dtype conversion first
        try:
            return cls._from_pandas_dtype(dtype_like)
        except (ImportError, TypeError):
            # ImportError if pandas is not installed
            # TypeError if it's not a pandas dtype we can handle
            pass

        try:
            return cls._from_arrow_dtype(dtype_like)
        except (ImportError, TypeError):
            # ImportError if pyarrow is not installed
            # TypeError if it's not a pyarrow dtype we can handle
            pass

        if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
            return cls.from_python_type(dtype_like)
        elif hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
            # Objects with dtype attribute (arrays, etc.) but not dtype types themselves
            return cls.from_numpy(dtype_like.dtype)
        else:
            # Try NumPy conversion first (handles dtype types, strings, etc.)
            try:
                return cls.from_numpy(dtype_like)
            except (TypeError, ValueError):
                pass

            # Try JAX conversion if available
            if _JAX_AVAILABLE:
                try:
                    return cls.from_jax(dtype_like)
                except (TypeError, ValueError):
                    pass

            raise ValueError(f"Cannot convert {type(dtype_like)} to DType")

    @classmethod
    def _from_pandas_dtype(cls, dtype_like: Any) -> DType:
        """Convert pandas-specific dtypes to DType.

        Raises:
            ImportError: if pandas is not installed.
            TypeError: if ``dtype_like`` is not something this method maps;
                ``from_any`` treats this as "try the next strategy".
        """
        # Check if pandas is available
        try:
            import pandas as pd
            from pandas.api.types import is_any_real_numeric_dtype, is_bool_dtype
        except ImportError:
            raise ImportError("pandas not available") from None

        if not hasattr(dtype_like, "__module__") or "pandas" not in str(
            dtype_like.__module__
        ):
            # If it's not a pandas dtype, don't handle it here
            raise TypeError("Not a pandas dtype")

        if isinstance(dtype_like, pd.StringDtype):
            return STRING
        elif is_bool_dtype(dtype_like):
            # Catches pd.BooleanDtype() and 'bool'
            return BOOL
        elif is_any_real_numeric_dtype(dtype_like):
            # Extension dtypes (Int64Dtype, Float64Dtype, ...) expose their
            # NumPy counterpart as ``numpy_dtype``. pandas containers such as
            # Series/Index also pass the checks above but lack that attribute;
            # fall through to TypeError so from_any uses their ``.dtype``
            # instead of crashing with AttributeError.
            np_equiv = getattr(dtype_like, "numpy_dtype", None)
            if np_equiv is not None:
                return cls.from_numpy(np_equiv)

        raise TypeError(f"Unsupported pandas dtype: {dtype_like}")

    @classmethod
    def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
        """Convert a ``pyarrow.DataType`` to DType.

        Raises:
            ImportError: if pyarrow is not installed.
            TypeError: if ``dtype_like`` is not a mapped pyarrow type.
        """
        try:
            import pyarrow as pa
        except ImportError:
            raise ImportError("pyarrow not available") from None

        if not isinstance(dtype_like, pa.DataType):
            raise TypeError("Not a pyarrow dtype")

        ARROW_DTYPE_MAPPING = {
            pa.bool_(): BOOL,
            pa.int8(): INT8,
            pa.int16(): INT16,
            pa.int32(): INT32,
            pa.int64(): INT64,
            pa.uint8(): UINT8,
            pa.uint16(): UINT16,
            pa.uint32(): UINT32,
            pa.uint64(): UINT64,
            pa.float16(): FLOAT16,
            pa.float32(): FLOAT32,
            pa.float64(): FLOAT64,
            pa.string(): STRING,
            pa.large_string(): STRING,
        }
        result = ARROW_DTYPE_MAPPING.get(dtype_like)
        if result is not None:
            return result
        raise TypeError(f"Unsupported arrow dtype: {dtype_like}")

    def to_numpy(self) -> np.dtype:
        """Convert custom DType to NumPy dtype via ``np.dtype(self.name)``."""
        return np.dtype(self.name)

    def to_jax(self) -> Any:
        """Convert custom DType to JAX dtype.

        Raises:
            ImportError: if JAX is not installed.
        """
        if not _JAX_AVAILABLE:
            raise ImportError("JAX is not available")

        return jnp.dtype(self.name)

    def to_python_type(self) -> type:
        """Convert to Python builtin type if possible.

        Raises:
            ValueError: for names that are not bool/int/uint/float/complex.
        """
        if self.name == "bool":
            return bool
        elif self.name.startswith("int") or self.name.startswith("uint"):
            return int
        elif self.name.startswith("float"):
            return float
        elif self.name.startswith("complex"):
            return complex
        else:
            raise ValueError(f"Cannot convert {self.name} to Python builtin type")

    def numpy_dtype(self) -> np.dtype:
        """Convert DType to NumPy dtype for compatibility with external libraries."""
        return self.to_numpy()
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
# Common dtype constants for convenience (unspecified flags use DType defaults:
# is_signed=None, is_floating=False, is_complex=False, is_table_only=False).
BOOL = DType("bool", 8)
INT8 = DType("int8", 8, is_signed=True)
INT16 = DType("int16", 16, is_signed=True)
INT32 = DType("int32", 32, is_signed=True)
INT64 = DType("int64", 64, is_signed=True)
UINT8 = DType("uint8", 8, is_signed=False)
UINT16 = DType("uint16", 16, is_signed=False)
UINT32 = DType("uint32", 32, is_signed=False)
UINT64 = DType("uint64", 64, is_signed=False)
FLOAT16 = DType("float16", 16, is_signed=True, is_floating=True)
FLOAT32 = DType("float32", 32, is_signed=True, is_floating=True)
FLOAT64 = DType("float64", 64, is_signed=True, is_floating=True)
COMPLEX64 = DType("complex64", 64, is_signed=True, is_floating=True, is_complex=True)
COMPLEX128 = DType("complex128", 128, is_signed=True, is_floating=True, is_complex=True)

# Table-only types (is_table_only=True); bitwidth 0 marks variable length.
STRING = DType("string", 0, is_table_only=True)  # Variable length string
DATE = DType("date", 32, is_table_only=True)  # Date only
TIME = DType("time", 32, is_table_only=True)  # Time only
TIMESTAMP = DType("timestamp", 64, is_table_only=True)  # Timestamp
DECIMAL = DType("decimal", 128, is_signed=True, is_table_only=True)  # Arbitrary precision decimal
BINARY = DType("binary", 0, is_table_only=True)  # Binary data
JSON = DType("json", 0, is_table_only=True)  # JSON data
UUID = DType("uuid", 128, is_table_only=True)  # UUID type

# Additional types commonly used in relational databases but keep minimal
INTERVAL = DType("interval", 64, is_table_only=True)  # Time interval
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
# Helper functions for easy conversion
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
def from_numpy(np_dtype: Any) -> DType:
    """Module-level shortcut for ``DType.from_numpy(np_dtype)``."""
    converter = DType.from_numpy
    return converter(np_dtype)
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
def to_numpy(dtype: DType) -> np.dtype:
    """Module-level shortcut for ``dtype.to_numpy()``."""
    converted = dtype.to_numpy()
    return converted
|