mplang-nightly 0.1.dev269__py3-none-any.whl → 0.1.dev270__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. mplang/__init__.py +391 -17
  2. mplang/{v2/backends → backends}/__init__.py +9 -7
  3. mplang/{v2/backends → backends}/bfv_impl.py +6 -6
  4. mplang/{v2/backends → backends}/crypto_impl.py +6 -6
  5. mplang/{v2/backends → backends}/field_impl.py +5 -5
  6. mplang/{v2/backends → backends}/func_impl.py +4 -4
  7. mplang/{v2/backends → backends}/phe_impl.py +3 -3
  8. mplang/{v2/backends → backends}/simp_design.md +1 -1
  9. mplang/{v2/backends → backends}/simp_driver/__init__.py +5 -5
  10. mplang/{v2/backends → backends}/simp_driver/http.py +8 -8
  11. mplang/{v2/backends → backends}/simp_driver/mem.py +9 -9
  12. mplang/{v2/backends → backends}/simp_driver/ops.py +4 -4
  13. mplang/{v2/backends → backends}/simp_driver/state.py +2 -2
  14. mplang/{v2/backends → backends}/simp_driver/values.py +2 -2
  15. mplang/{v2/backends → backends}/simp_worker/__init__.py +3 -3
  16. mplang/{v2/backends → backends}/simp_worker/http.py +10 -10
  17. mplang/{v2/backends → backends}/simp_worker/mem.py +1 -1
  18. mplang/{v2/backends → backends}/simp_worker/ops.py +5 -5
  19. mplang/{v2/backends → backends}/simp_worker/state.py +2 -4
  20. mplang/{v2/backends → backends}/spu_impl.py +8 -8
  21. mplang/{v2/backends → backends}/spu_state.py +4 -4
  22. mplang/{v2/backends → backends}/store_impl.py +3 -3
  23. mplang/{v2/backends → backends}/table_impl.py +8 -8
  24. mplang/{v2/backends → backends}/tee_impl.py +6 -6
  25. mplang/{v2/backends → backends}/tensor_impl.py +6 -6
  26. mplang/{v2/cli.py → cli.py} +9 -9
  27. mplang/{v2/cli_guide.md → cli_guide.md} +12 -12
  28. mplang/{v2/dialects → dialects}/__init__.py +5 -5
  29. mplang/{v2/dialects → dialects}/bfv.py +6 -6
  30. mplang/{v2/dialects → dialects}/crypto.py +5 -5
  31. mplang/{v2/dialects → dialects}/dtypes.py +2 -2
  32. mplang/{v2/dialects → dialects}/field.py +3 -3
  33. mplang/{v2/dialects → dialects}/func.py +2 -2
  34. mplang/{v2/dialects → dialects}/phe.py +6 -6
  35. mplang/{v2/dialects → dialects}/simp.py +6 -6
  36. mplang/{v2/dialects → dialects}/spu.py +7 -7
  37. mplang/{v2/dialects → dialects}/store.py +2 -2
  38. mplang/{v2/dialects → dialects}/table.py +3 -3
  39. mplang/{v2/dialects → dialects}/tee.py +6 -6
  40. mplang/{v2/dialects → dialects}/tensor.py +5 -5
  41. mplang/{v2/edsl → edsl}/__init__.py +3 -3
  42. mplang/{v2/edsl → edsl}/context.py +6 -6
  43. mplang/{v2/edsl → edsl}/graph.py +5 -5
  44. mplang/{v2/edsl → edsl}/jit.py +2 -2
  45. mplang/{v2/edsl → edsl}/object.py +1 -1
  46. mplang/{v2/edsl → edsl}/primitive.py +5 -5
  47. mplang/{v2/edsl → edsl}/printer.py +1 -1
  48. mplang/{v2/edsl → edsl}/serde.py +1 -1
  49. mplang/{v2/edsl → edsl}/tracer.py +7 -7
  50. mplang/{v2/edsl → edsl}/typing.py +1 -1
  51. mplang/{v2/kernels → kernels}/ldpc.cpp +13 -13
  52. mplang/{v2/kernels → kernels}/okvs.cpp +4 -4
  53. mplang/{v2/kernels → kernels}/okvs_opt.cpp +31 -31
  54. mplang/{v2/kernels → kernels}/py_kernels.py +1 -1
  55. mplang/{v2/libs → libs}/collective.py +5 -5
  56. mplang/{v2/libs → libs}/device/__init__.py +1 -1
  57. mplang/{v2/libs → libs}/device/api.py +12 -12
  58. mplang/{v2/libs → libs}/ml/__init__.py +1 -1
  59. mplang/{v2/libs → libs}/ml/sgb.py +4 -4
  60. mplang/{v2/libs → libs}/mpc/__init__.py +3 -3
  61. mplang/{v2/libs → libs}/mpc/_utils.py +2 -2
  62. mplang/{v2/libs → libs}/mpc/analytics/aggregation.py +1 -1
  63. mplang/{v2/libs → libs}/mpc/analytics/groupby.py +2 -2
  64. mplang/{v2/libs → libs}/mpc/analytics/permutation.py +3 -3
  65. mplang/{v2/libs → libs}/mpc/ot/base.py +3 -3
  66. mplang/{v2/libs → libs}/mpc/ot/extension.py +2 -2
  67. mplang/{v2/libs → libs}/mpc/ot/silent.py +4 -4
  68. mplang/{v2/libs → libs}/mpc/psi/cuckoo.py +3 -3
  69. mplang/{v2/libs → libs}/mpc/psi/okvs.py +1 -1
  70. mplang/{v2/libs → libs}/mpc/psi/okvs_gct.py +3 -3
  71. mplang/{v2/libs → libs}/mpc/psi/oprf.py +3 -3
  72. mplang/{v2/libs → libs}/mpc/psi/rr22.py +7 -7
  73. mplang/{v2/libs → libs}/mpc/psi/unbalanced.py +4 -4
  74. mplang/{v2/libs → libs}/mpc/vole/gilboa.py +3 -3
  75. mplang/{v2/libs → libs}/mpc/vole/ldpc.py +2 -2
  76. mplang/{v2/libs → libs}/mpc/vole/silver.py +6 -6
  77. mplang/{v2/runtime → runtime}/interpreter.py +11 -11
  78. mplang/{v2/runtime → runtime}/value.py +2 -2
  79. mplang/{v1/runtime → utils}/__init__.py +18 -15
  80. mplang/{v1/utils → utils}/func_utils.py +1 -1
  81. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/METADATA +2 -2
  82. mplang_nightly-0.1.dev270.dist-info/RECORD +102 -0
  83. mplang/v1/__init__.py +0 -157
  84. mplang/v1/_device.py +0 -602
  85. mplang/v1/analysis/__init__.py +0 -37
  86. mplang/v1/analysis/diagram.py +0 -567
  87. mplang/v1/core/__init__.py +0 -157
  88. mplang/v1/core/cluster.py +0 -343
  89. mplang/v1/core/comm.py +0 -281
  90. mplang/v1/core/context_mgr.py +0 -50
  91. mplang/v1/core/dtypes.py +0 -335
  92. mplang/v1/core/expr/__init__.py +0 -80
  93. mplang/v1/core/expr/ast.py +0 -542
  94. mplang/v1/core/expr/evaluator.py +0 -581
  95. mplang/v1/core/expr/printer.py +0 -285
  96. mplang/v1/core/expr/transformer.py +0 -141
  97. mplang/v1/core/expr/utils.py +0 -78
  98. mplang/v1/core/expr/visitor.py +0 -85
  99. mplang/v1/core/expr/walk.py +0 -387
  100. mplang/v1/core/interp.py +0 -160
  101. mplang/v1/core/mask.py +0 -325
  102. mplang/v1/core/mpir.py +0 -965
  103. mplang/v1/core/mpobject.py +0 -117
  104. mplang/v1/core/mptype.py +0 -407
  105. mplang/v1/core/pfunc.py +0 -130
  106. mplang/v1/core/primitive.py +0 -877
  107. mplang/v1/core/table.py +0 -218
  108. mplang/v1/core/tensor.py +0 -75
  109. mplang/v1/core/tracer.py +0 -383
  110. mplang/v1/host.py +0 -130
  111. mplang/v1/kernels/__init__.py +0 -41
  112. mplang/v1/kernels/base.py +0 -125
  113. mplang/v1/kernels/basic.py +0 -240
  114. mplang/v1/kernels/context.py +0 -369
  115. mplang/v1/kernels/crypto.py +0 -122
  116. mplang/v1/kernels/fhe.py +0 -858
  117. mplang/v1/kernels/mock_tee.py +0 -72
  118. mplang/v1/kernels/phe.py +0 -1864
  119. mplang/v1/kernels/spu.py +0 -341
  120. mplang/v1/kernels/sql_duckdb.py +0 -44
  121. mplang/v1/kernels/stablehlo.py +0 -90
  122. mplang/v1/kernels/value.py +0 -626
  123. mplang/v1/ops/__init__.py +0 -35
  124. mplang/v1/ops/base.py +0 -424
  125. mplang/v1/ops/basic.py +0 -294
  126. mplang/v1/ops/crypto.py +0 -262
  127. mplang/v1/ops/fhe.py +0 -272
  128. mplang/v1/ops/jax_cc.py +0 -147
  129. mplang/v1/ops/nnx_cc.py +0 -168
  130. mplang/v1/ops/phe.py +0 -216
  131. mplang/v1/ops/spu.py +0 -151
  132. mplang/v1/ops/sql_cc.py +0 -303
  133. mplang/v1/ops/tee.py +0 -36
  134. mplang/v1/protos/v1alpha1/mpir_pb2.py +0 -63
  135. mplang/v1/protos/v1alpha1/mpir_pb2.pyi +0 -557
  136. mplang/v1/protos/v1alpha1/value_pb2.py +0 -34
  137. mplang/v1/protos/v1alpha1/value_pb2.pyi +0 -169
  138. mplang/v1/runtime/channel.py +0 -230
  139. mplang/v1/runtime/cli.py +0 -451
  140. mplang/v1/runtime/client.py +0 -456
  141. mplang/v1/runtime/communicator.py +0 -131
  142. mplang/v1/runtime/data_providers.py +0 -303
  143. mplang/v1/runtime/driver.py +0 -324
  144. mplang/v1/runtime/exceptions.py +0 -27
  145. mplang/v1/runtime/http_api.md +0 -56
  146. mplang/v1/runtime/link_comm.py +0 -196
  147. mplang/v1/runtime/server.py +0 -501
  148. mplang/v1/runtime/session.py +0 -270
  149. mplang/v1/runtime/simulation.py +0 -324
  150. mplang/v1/simp/__init__.py +0 -13
  151. mplang/v1/simp/api.py +0 -353
  152. mplang/v1/simp/mpi.py +0 -131
  153. mplang/v1/simp/party.py +0 -225
  154. mplang/v1/simp/random.py +0 -120
  155. mplang/v1/simp/smpc.py +0 -238
  156. mplang/v1/utils/__init__.py +0 -13
  157. mplang/v1/utils/crypto.py +0 -32
  158. mplang/v1/utils/spu_utils.py +0 -130
  159. mplang/v1/utils/table_utils.py +0 -185
  160. mplang/v2/__init__.py +0 -424
  161. mplang_nightly-0.1.dev269.dist-info/RECORD +0 -180
  162. /mplang/{v2/backends → backends}/channel.py +0 -0
  163. /mplang/{v2/edsl → edsl}/README.md +0 -0
  164. /mplang/{v2/edsl → edsl}/registry.py +0 -0
  165. /mplang/{v2/kernels → kernels}/Makefile +0 -0
  166. /mplang/{v2/kernels → kernels}/__init__.py +0 -0
  167. /mplang/{v2/kernels → kernels}/gf128.cpp +0 -0
  168. /mplang/{v2/libs → libs}/device/cluster.py +0 -0
  169. /mplang/{v2/libs → libs}/mpc/analytics/__init__.py +0 -0
  170. /mplang/{v2/libs → libs}/mpc/analytics/groupby.md +0 -0
  171. /mplang/{v2/libs → libs}/mpc/common/constants.py +0 -0
  172. /mplang/{v2/libs → libs}/mpc/ot/__init__.py +0 -0
  173. /mplang/{v2/libs → libs}/mpc/psi/__init__.py +0 -0
  174. /mplang/{v2/libs → libs}/mpc/vole/__init__.py +0 -0
  175. /mplang/{v2/runtime → runtime}/__init__.py +0 -0
  176. /mplang/{v2/runtime → runtime}/dialect_state.py +0 -0
  177. /mplang/{v2/runtime → runtime}/object_store.py +0 -0
  178. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/WHEEL +0 -0
  179. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/entry_points.txt +0 -0
  180. {mplang_nightly-0.1.dev269.dist-info → mplang_nightly-0.1.dev270.dist-info}/licenses/LICENSE +0 -0
mplang/v1/core/comm.py DELETED
@@ -1,281 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import logging
18
- import threading
19
- from abc import ABC, abstractmethod
20
- from typing import Any
21
-
22
- from mplang.v1.core.mask import Mask
23
-
24
-
25
class ICommunicator(ABC):
    """Abstract interface of a keyed, rank-addressed point-to-point communicator."""

    @property
    @abstractmethod
    def rank(self) -> int:
        """Return the rank of this process."""
        return None  # type: ignore[return-value]

    @property
    @abstractmethod
    def world_size(self) -> int:
        """Return the world size this process belongs to."""
        return None  # type: ignore[return-value]

    @abstractmethod
    def new_id(self) -> str:
        """Produce a new message key; concrete classes must override this."""
        raise NotImplementedError

    @abstractmethod
    def send(self, to: int, key: str, data: Any) -> None:
        """Send ``data`` to peer ``to`` under ``key``."""
        return None

    @abstractmethod
    def recv(self, frm: int, key: str) -> Any:
        """Receive the data that peer ``frm`` sent under ``key``."""
        return None

    @abstractmethod
    def onSent(self, frm: int, key: str, data: Any) -> None:
        """Callback invoked when peer ``frm`` has sent ``data`` to self under ``key``."""
        return None
54
-
55
-
56
class ICollective(ABC):
    """Abstract interface for collective communication.

    The ``*_m`` variants restrict the operation to the parties selected by the
    integer party mask ``pmask``; the unsuffixed variants involve all processes.
    """

    @abstractmethod
    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Point-to-point transfer of ``data`` from ``frm`` to ``to``."""
        return None

    @abstractmethod
    def gather(self, root: int, data: Any) -> list[Any]:
        """Collect ``data`` from all processes at ``root``."""
        return None  # type: ignore[return-value]

    @abstractmethod
    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        """Collect ``data`` from the parties in ``pmask`` at ``root``."""
        return None  # type: ignore[return-value]

    @abstractmethod
    def scatter(self, root: int, args: list[Any]) -> Any:
        """Distribute ``args`` from ``root`` to all processes."""
        return None

    @abstractmethod
    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        """Distribute ``args`` from ``root`` to the parties in ``pmask``."""
        return None

    @abstractmethod
    def allgather(self, arg: Any) -> list[Any]:
        """Collect ``arg`` from all processes at all processes."""
        return None  # type: ignore[return-value]

    @abstractmethod
    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        """Collect ``arg`` from the parties in ``pmask`` at all processes."""
        return None  # type: ignore[return-value]

    @abstractmethod
    def bcast(self, root: int, arg: Any) -> Any:
        """Broadcast ``arg`` from ``root`` to all processes."""
        return None

    @abstractmethod
    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        """Broadcast ``arg`` from ``root`` to the parties in ``pmask``."""
        return None
94
-
95
-
96
def is_rank_in(rank: int, mask: int) -> bool:
    """Tell whether bit number ``rank`` is set in the integer ``mask``."""
    selected_bit = (mask >> rank) & 1
    return bool(selected_bit)
99
-
100
-
101
class CollectiveMixin(ICommunicator, ICollective):
    """Mixin class providing default implementations of collective communication algorithms

    Every collective here is built from the ``send``/``recv`` primitives.  The
    class that mixes this in must provide working ``rank``, ``world_size``,
    ``send``, ``recv`` and ``new_id``; the versions defined below only raise
    ``NotImplementedError``.

    Each collective obtains one id from ``new_id()`` and uses it as the message
    key for every send and receive made during that call.  ``new_id()`` is
    called on every rank that enters the collective, including ranks that end
    up neither sending nor receiving.
    """

    # Note: These will be provided by mixing classes as properties
    @property
    def rank(self) -> int:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    @property
    def world_size(self) -> int:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def send(self, to: int, key: str, data: Any) -> None:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def recv(self, frm: int, key: str) -> Any:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def new_id(self) -> str:
        """Must be implemented by mixing class"""
        raise NotImplementedError

    def p2p(self, frm: int, to: int, data: Any) -> Any:
        """Perform point-to-point communication

        Returns the received value on rank ``to`` and ``None`` on every other
        rank (p2p is treated as a collective: all ranks call it).
        """
        assert 0 <= frm < self.world_size
        assert 0 <= to < self.world_size

        cid = self.new_id()

        if self.rank == frm:
            self.send(to, cid, data)

        if self.rank == to:
            return self.recv(frm, cid)
        return None

    def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]:
        """Gather data from parties in pmask to root

        Returns, on ``root``, one received value per party in ``pmask`` (in the
        iteration order of ``Mask(pmask)``); on every other rank a list of
        ``None`` of the same length.
        """
        assert 0 <= root < self.world_size
        # NOTE: pmask is not validated to be a subset of the world mask.

        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank in mask:
            self.send(root, cid, data)

        if self.rank == root:
            return [self.recv(idx, cid) for idx in mask]
        return [None] * mask.num_parties()

    def gather(self, root: int, data: Any) -> list[Any]:
        """Gather data from all processes to root"""
        pmask = Mask.all(self.world_size)
        return self.gather_m(pmask.value, root, data)

    def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any:
        """Scatter data from root to parties in pmask

        ``args`` must hold exactly one element per party in ``pmask``.  Returns
        the element addressed to this rank, or ``None`` when this rank is not
        in ``pmask``.
        """
        # Lazy %-formatting: the payload is only stringified when DEBUG is on.
        logging.debug(
            "[%s]: scatter_m: pmask=%s, root=%s, args=%s", self.rank, pmask, root, args
        )
        assert 0 <= root < self.world_size
        mask = Mask(pmask)
        assert len(args) == mask.num_parties(), f"{len(args)} != {mask.num_parties()}"

        cid = self.new_id()

        if self.rank == root:
            for idx, arg in zip(mask, args, strict=True):
                self.send(idx, cid, arg)

        if self.rank in mask:
            return self.recv(root, cid)
        return None

    def scatter(self, root: int, args: list[Any]) -> Any:
        """Scatter data from root to all processes"""
        pmask = Mask.all(self.world_size)
        return self.scatter_m(pmask.value, root, args)

    def allgather_m(self, pmask: int, arg: Any) -> list[Any]:
        """Gather data from parties in pmask to all parties

        Parties inside ``pmask`` send ``arg`` to every party in ``pmask``
        (themselves included) and return the received values; parties outside
        ``pmask`` return a list of ``None`` of the same length.
        """
        # Lazy %-formatting: the payload is only stringified when DEBUG is on.
        logging.debug("allgather_m: pmask=%s, arg=%s", pmask, arg)
        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank not in mask:
            return [None] * mask.num_parties()

        for idx in mask:
            self.send(idx, cid, arg)
        return [self.recv(idx, cid) for idx in mask]

    def allgather(self, arg: Any) -> list[Any]:
        """Gather data from all processes to all processes"""
        pmask = Mask.all(self.world_size)
        return self.allgather_m(pmask.value, arg)

    def bcast_m(self, pmask: int, root: int, arg: Any) -> Any:
        """Broadcast data from root to parties in pmask

        Returns the broadcast value on parties in ``pmask`` and ``None`` on
        every other rank.
        """
        # Lazy %-formatting: the payload is only stringified when DEBUG is on.
        logging.debug("bcast_m: pmask=%s, root=%s, arg=%s", pmask, root, arg)
        assert 0 <= root < self.world_size

        cid = self.new_id()
        mask = Mask(pmask)

        if self.rank == root:
            for idx in mask:
                self.send(idx, cid, arg)

        if self.rank in mask:
            return self.recv(root, cid)
        return None

    def bcast(self, root: int, arg: Any) -> Any:
        """Broadcast data from root to all processes"""
        pmask = Mask.all(self.world_size)
        return self.bcast_m(pmask.value, root, arg)
237
-
238
-
239
class CommunicatorBase(ICommunicator):
    """Communicator base that keeps incoming messages in an in-process mailbox.

    Messages are stored under ``(sender_rank, key)`` and handed out by
    ``recv``; a single ``threading.Condition`` guards both the mailbox and the
    id counter.
    """

    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self._world_size = world_size
        self._msgboxes: dict = {}
        self._cond = threading.Condition()
        self._counter = 0

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    # override
    def new_id(self) -> str:
        # The counter is read and bumped while holding the condition's lock.
        with self._cond:
            issued = self._counter
            self._counter = issued + 1
            return str(issued)

    def recv(self, frm: int, key: str) -> Any:
        """Block until a message from ``frm`` under ``key`` is present, then remove and return it."""
        mkey = (frm, key)
        with self._cond:
            self._cond.wait_for(lambda: mkey in self._msgboxes)
            return self._msgboxes.pop(mkey)

    def onSent(self, frm: int, key: str, data: Any) -> None:
        """Store ``data`` sent by ``frm`` under ``key`` and wake all waiting receivers."""
        mkey = (frm, key)
        with self._cond:
            # A pending message with the same (sender, key) must not be overwritten.
            assert mkey not in self._msgboxes, f"{mkey} exist {self._msgboxes.keys()}"
            self._msgboxes[mkey] = data
            self._cond.notify_all()
mplang/v1/core/context_mgr.py DELETED
@@ -1,50 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import contextlib
18
- from collections.abc import Iterator
19
- from typing import TYPE_CHECKING
20
-
21
- if TYPE_CHECKING:
22
- # Imported only for typing to avoid import cycles at runtime.
23
- from mplang.v1.core.mpobject import MPContext
24
-
25
# Module-wide working context; None until one is installed.
_g_ctx: MPContext | None = None


def cur_ctx() -> MPContext:
    """Return the current global working context.

    Raises:
        ValueError: if no context has been installed yet.
    """
    ctx = _g_ctx
    if ctx is not None:
        return ctx
    # The wording is kept verbatim for backward compatibility with callers/tests.
    raise ValueError("Interpreter not set. Please call set_interp() first.")
34
-
35
-
36
def set_ctx(ctx: MPContext) -> None:
    """Install ``ctx`` as the global working context, replacing whatever was set before."""
    global _g_ctx
    _g_ctx = ctx
39
-
40
-
41
@contextlib.contextmanager
def with_ctx(tmp_ctx: MPContext) -> Iterator[MPContext]:
    """Temporarily make ``tmp_ctx`` the global working context.

    Yields ``tmp_ctx``.  On exit, normal or via an exception, the context that
    was active on entry is put back, including the case where it was ``None``.
    """
    global _g_ctx
    previous, _g_ctx = _g_ctx, tmp_ctx
    try:
        yield tmp_ctx
    finally:
        _g_ctx = previous
mplang/v1/core/dtypes.py DELETED
@@ -1,335 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- from dataclasses import dataclass
18
- from typing import Any, final
19
-
20
- import numpy as np
21
-
22
- try:
23
- # Check if JAX is available
24
- import jax
25
- import jax.numpy as jnp
26
-
27
- _JAX_AVAILABLE = True
28
- except ImportError:
29
- _JAX_AVAILABLE = False
30
-
31
# Public names of this module, grouped by category.
__all__ = [
    # The dtype class and the module-level conversion helpers
    "DType",
    "from_numpy",
    "to_numpy",
    # Boolean / integer constants
    "BOOL",
    "INT8",
    "INT16",
    "INT32",
    "INT64",
    "UINT8",
    "UINT16",
    "UINT32",
    "UINT64",
    # Floating point / complex constants
    "FLOAT16",
    "FLOAT32",
    "FLOAT64",
    "COMPLEX64",
    "COMPLEX128",
    # Table-only constants
    "STRING",
    "DATE",
    "TIME",
    "TIMESTAMP",
    "DECIMAL",
    "BINARY",
    "JSON",
    "UUID",
    "INTERVAL",
]
59
-
60
-
61
@final
@dataclass(frozen=True)
class DType:
    """Custom dtype representation that can convert between different libraries.

    Instances are immutable.  Conversions are offered from NumPy, JAX, pandas,
    pyarrow and Python builtin types, and to NumPy, JAX and Python builtin
    types.  JAX, pandas and pyarrow are optional and only needed by the
    conversions that touch them.
    """

    name: str
    bitwidth: int
    is_signed: bool | None = None  # None for non-numeric types
    is_floating: bool = False
    is_complex: bool = False
    is_table_only: bool = False  # True for types only supported in tables

    def __post_init__(self) -> None:
        """Validate the flag combination and default ``is_signed`` for floats.

        Raises:
            ValueError: if ``is_complex`` is set without ``is_floating``.
        """
        if self.is_complex and not self.is_floating:
            raise ValueError("Complex types must be floating point")
        if self.is_floating and self.is_signed is None:
            # Floating point types are always signed.  The dataclass is frozen,
            # so the field has to be written through object.__setattr__.
            object.__setattr__(self, "is_signed", True)

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"DType('{self.name}')"

    def short_name(self) -> str:
        """Return a short name for the dtype.

        Names without an entry in the table are returned unchanged.
        """
        name_map = {
            "bool": "bool",
            "int8": "i8",
            "int16": "i16",
            "int32": "i32",
            "int64": "i64",
            "uint8": "u8",
            "uint16": "u16",
            "uint32": "u32",
            "uint64": "u64",
            "float16": "f16",
            "float32": "f32",
            "float64": "f64",
            "complex64": "c64",
            "complex128": "c128",
            # Table-only types
            "string": "str",
            "date": "date",
            "time": "time",
            "timestamp": "timestamp",
            "decimal": "decimal",
            "binary": "binary",
            "json": "json",
            "uuid": "uuid",
            "interval": "interval",
        }
        return name_map.get(self.name, self.name)

    @classmethod
    def from_numpy(cls, np_dtype: Any) -> DType:
        """Convert from NumPy dtype to custom DType.

        Args:
            np_dtype: anything accepted by ``np.dtype(...)``.

        Raises:
            ValueError: if the NumPy dtype kind is not one of bool, integer,
                float, complex, unicode, byte string or object.
        """
        np_dtype = np.dtype(np_dtype)
        name = np_dtype.name
        kind = np_dtype.kind
        bits = np_dtype.itemsize * 8

        if kind == "b":  # boolean
            return cls(name, 8, None, False, False)  # bool is typically 8 bits
        if kind in ("i", "u"):  # integer
            return cls(name, bits, kind == "i", False, False)
        if kind == "f":  # floating
            return cls(name, bits, True, True, False)
        if kind == "c":  # complex
            return cls(name, bits, True, True, True)
        if kind in ("U", "S", "O"):  # unicode, byte string, or object
            # All of these collapse to the single table-only STRING constant;
            # the NumPy item size is not carried over.  Object is often used
            # for strings.
            return STRING
        raise ValueError(f"Unsupported NumPy dtype kind: {np_dtype.kind}")

    @classmethod
    def from_jax(cls, jax_dtype: Any) -> DType:
        """Convert from JAX dtype to custom DType.

        Raises:
            ImportError: if JAX could not be imported.
        """
        if not _JAX_AVAILABLE:
            raise ImportError("JAX is not available")
        # Special handling for PRNG KeyTy: <class jax._src.prng.KeyTy>
        if jnp.issubdtype(jax_dtype, jax.dtypes.prng_key):
            return cls.from_numpy(np.uint32)

        # JAX dtypes are essentially NumPy dtypes
        return cls.from_numpy(jax_dtype)

    @classmethod
    def from_python_type(cls, py_type: type) -> DType:
        """Convert from Python builtin type to custom DType.

        Raises:
            ValueError: if ``py_type`` is not ``bool``, ``int``, ``float`` or
                ``complex``.
        """
        if py_type is bool:
            return cls("bool", 8, None, False, False)
        if py_type is int:
            # Python ints are mapped to a fixed 64-bit signed integer.
            return cls("int64", 64, True, False, False)
        if py_type is float:
            return cls("float64", 64, True, True, False)
        if py_type is complex:
            return cls("complex128", 128, True, True, True)
        raise ValueError(f"Unsupported Python type: {py_type}")

    @classmethod
    def from_any(cls, dtype_like: Any) -> DType:
        """Convert from any supported dtype representation.

        Tried in order: an existing ``DType``, pandas dtypes, pyarrow dtypes,
        Python builtin types, objects exposing a ``dtype`` attribute, NumPy
        dtype-likes and finally JAX dtypes.

        Raises:
            ValueError: if none of the conversions applies.
        """
        if isinstance(dtype_like, cls):
            return dtype_like

        # Try pandas specific dtype conversion first
        try:
            return cls._from_pandas_dtype(dtype_like)
        except (ImportError, TypeError):
            # ImportError if pandas is not installed
            # TypeError if it's not a pandas dtype we can handle
            pass

        try:
            return cls._from_arrow_dtype(dtype_like)
        except (ImportError, TypeError):
            # ImportError if pyarrow is not installed
            # TypeError if it's not a pyarrow dtype we can handle
            pass

        if isinstance(dtype_like, type) and dtype_like in (bool, int, float, complex):
            return cls.from_python_type(dtype_like)
        if hasattr(dtype_like, "dtype") and not isinstance(dtype_like, type):
            # Objects with dtype attribute (arrays, etc.) but not dtype types themselves
            return cls.from_numpy(dtype_like.dtype)

        # Try NumPy conversion first (handles dtype types, strings, etc.)
        try:
            return cls.from_numpy(dtype_like)
        except (TypeError, ValueError):
            pass

        # Try JAX conversion if available
        if _JAX_AVAILABLE:
            try:
                return cls.from_jax(dtype_like)
            except (TypeError, ValueError):
                pass

        raise ValueError(f"Cannot convert {type(dtype_like)} to DType")

    @classmethod
    def _from_pandas_dtype(cls, dtype_like: Any) -> DType:
        """Convert pandas-specific dtypes to DType.

        Raises:
            ImportError: if pandas is not installed.
            TypeError: if ``dtype_like`` is not a pandas dtype, or is a pandas
                dtype that is not handled here.
        """
        try:
            import pandas as pd
            from pandas.api.types import is_any_real_numeric_dtype, is_bool_dtype
        except ImportError:
            raise ImportError("pandas not available") from None

        if not hasattr(dtype_like, "__module__") or "pandas" not in str(
            dtype_like.__module__
        ):
            # If it's not a pandas dtype, don't handle it here
            raise TypeError("Not a pandas dtype")

        if isinstance(dtype_like, pd.StringDtype):
            return STRING
        if is_bool_dtype(dtype_like):
            # Catches pd.BooleanDtype() and 'bool'
            return BOOL
        if is_any_real_numeric_dtype(dtype_like):
            # Catches Int64Dtype, Float64Dtype, etc.
            return cls.from_numpy(dtype_like.numpy_dtype)

        raise TypeError(f"Unsupported pandas dtype: {dtype_like}")

    @classmethod
    def _from_arrow_dtype(cls, dtype_like: Any) -> DType:
        """Convert pyarrow data types to DType.

        Raises:
            ImportError: if pyarrow is not installed.
            TypeError: if ``dtype_like`` is not a ``pa.DataType`` or has no
                entry in the mapping below.
        """
        try:
            import pyarrow as pa
        except ImportError:
            raise ImportError("pyarrow not available") from None

        if not isinstance(dtype_like, pa.DataType):
            raise TypeError("Not a pyarrow dtype")

        ARROW_DTYPE_MAPPING = {
            pa.bool_(): BOOL,
            pa.int8(): INT8,
            pa.int16(): INT16,
            pa.int32(): INT32,
            pa.int64(): INT64,
            pa.uint8(): UINT8,
            pa.uint16(): UINT16,
            pa.uint32(): UINT32,
            pa.uint64(): UINT64,
            pa.float16(): FLOAT16,
            pa.float32(): FLOAT32,
            pa.float64(): FLOAT64,
            pa.string(): STRING,
            pa.large_string(): STRING,
        }
        result = ARROW_DTYPE_MAPPING.get(dtype_like)
        if result is not None:
            return result
        raise TypeError(f"Unsupported arrow dtype: {dtype_like}")

    def to_numpy(self) -> np.dtype:
        """Convert custom DType to NumPy dtype."""
        return np.dtype(self.name)

    def to_jax(self) -> Any:
        """Convert custom DType to JAX dtype.

        Raises:
            ImportError: if JAX could not be imported.
        """
        if not _JAX_AVAILABLE:
            raise ImportError("JAX is not available")

        return jnp.dtype(self.name)

    def to_python_type(self) -> type:
        """Convert to Python builtin type if possible.

        Raises:
            ValueError: for table-only types and for any name that is not
                ``bool`` or prefixed with ``int``/``uint``/``float``/``complex``.
        """
        if self.is_table_only:
            # Must come before the prefix tests: the table-only name
            # "interval" starts with "int" and would otherwise map to ``int``.
            raise ValueError(f"Cannot convert {self.name} to Python builtin type")
        if self.name == "bool":
            return bool
        if self.name.startswith(("int", "uint")):
            return int
        if self.name.startswith("float"):
            return float
        if self.name.startswith("complex"):
            return complex
        raise ValueError(f"Cannot convert {self.name} to Python builtin type")

    def numpy_dtype(self) -> np.dtype:
        """Convert DType to NumPy dtype for compatibility with external libraries."""
        return self.to_numpy()
293
-
294
-
295
# Ready-made dtype constants.
BOOL = DType("bool", 8)
INT8 = DType("int8", 8, is_signed=True)
INT16 = DType("int16", 16, is_signed=True)
INT32 = DType("int32", 32, is_signed=True)
INT64 = DType("int64", 64, is_signed=True)
UINT8 = DType("uint8", 8, is_signed=False)
UINT16 = DType("uint16", 16, is_signed=False)
UINT32 = DType("uint32", 32, is_signed=False)
UINT64 = DType("uint64", 64, is_signed=False)
FLOAT16 = DType("float16", 16, is_signed=True, is_floating=True)
FLOAT32 = DType("float32", 32, is_signed=True, is_floating=True)
FLOAT64 = DType("float64", 64, is_signed=True, is_floating=True)
COMPLEX64 = DType("complex64", 64, is_signed=True, is_floating=True, is_complex=True)
COMPLEX128 = DType("complex128", 128, is_signed=True, is_floating=True, is_complex=True)

# Table-only types (all carry is_table_only=True)
STRING = DType("string", 0, is_table_only=True)  # Variable length string
DATE = DType("date", 32, is_table_only=True)  # Date only
TIME = DType("time", 32, is_table_only=True)  # Time only
TIMESTAMP = DType("timestamp", 64, is_table_only=True)  # Timestamp
DECIMAL = DType("decimal", 128, is_signed=True, is_table_only=True)  # Arbitrary precision decimal
BINARY = DType("binary", 0, is_table_only=True)  # Binary data
JSON = DType("json", 0, is_table_only=True)  # JSON data
UUID = DType("uuid", 128, is_table_only=True)  # UUID type

# Additional types commonly used in relational databases but keep minimal
INTERVAL = DType("interval", 64, is_table_only=True)  # Time interval
323
-
324
-
325
- # Helper functions for easy conversion
326
-
327
-
328
def from_numpy(np_dtype: Any) -> DType:
    """Module-level shortcut for ``DType.from_numpy``."""
    converted = DType.from_numpy(np_dtype)
    return converted
331
-
332
-
333
def to_numpy(dtype: DType) -> np.dtype:
    """Module-level shortcut for ``dtype.to_numpy()``."""
    converted = dtype.to_numpy()
    return converted