mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,816 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ MPLang Core Typing System: Design and Rationale.
17
+
18
+ This module defines the production type system for MPLang, an EDSL for multi-party
19
+ privacy-preserving computation. This document explains the core principles and design
20
+ decisions that shape this system, intended for future maintainers and developers.
21
+
22
+ ===========================
23
+ Tensor Shape System
24
+ ===========================
25
+ MPLang supports a flexible shape system for tensors to handle various compilation and
26
+ runtime scenarios:
27
+
28
+ **Shape Representations:**
29
+ - `None`: Not supported. Every tensor must be ranked, so a non-tuple shape such as
30
+ `Tensor[i32, None]` raises `TypeError`; use `-1` dimensions for unknown sizes.
31
+
32
+ - `()`: Scalar (0-dimensional tensor)
33
+ Example: `Tensor[i32, ()]`
34
+
35
+ - `(dim1, dim2, ...)`: Ranked tensor with static or dynamic dimensions
36
+ - Positive integers: Static dimension sizes
37
+ - `-1`: Dynamic/unknown dimension size
38
+ Examples:
39
+ - `Tensor[i32, (3, 10)]` - Fully static 2D tensor
40
+ - `Tensor[i32, (-1, 10)]` - Dynamic batch size, static feature size
41
+ - `Tensor[i32, (-1, -1)]` - Fully dynamic 2D tensor
42
+
43
+ **Utility Properties:**
44
+ - `.is_scalar`: Check if tensor is 0-dimensional
45
+ - `.is_unranked`: Not provided; tensors are always ranked, so the shape is never None
46
+ - `.is_fully_static`: Check if all dimensions are statically known
47
+ - `.rank`: Get number of dimensions (always an int, since tensors are ranked)
48
+ - `.has_dynamic_dims()`: Check if any dimension is dynamic
49
+
50
+ ===========================
51
+ Principle 1: Orthogonality and Composition
52
+ ===========================
53
+ The type system is built on three orthogonal pillars. Each type represents a single,
54
+ well-defined concept. Complex ideas are expressed by composing these simple types,
55
+ rather than by creating a large, monolithic set of specific types.
56
+
57
+ 1. **Layout Types**: Describe the physical shape and structure of data.
58
+ - `Scalar`: Atomic data types (f32, i64).
59
+ - `Tensor`: A multi-dimensional array of a `ScalarType` element type.
60
+ - `Table`: A dictionary-like structure with named columns of any type.
61
+
62
+ 2. **Encryption Types**: Wrap other types to confer privacy properties by making them opaque.
63
+ - `SS`: A single share of a secret-shared value.
64
+ - Note: Element-wise HE types (like `phe.CiphertextType`) are defined in their respective dialects (e.g., `phe`).
65
+
66
+ 3. **Distribution Types**: Wrap other types to describe their physical location among parties.
67
+ - `MP`: Represents a value logically held by multiple parties.
68
+
69
+ An example of composition: `MP[SS[Tensor[f32, (10,)]], (0, 1)]` represents a
70
+ 10-element float tensor, which is secret-shared (`SS`), and whose shares are distributed
71
+ between parties 0 and 1 (`MP`).
72
+
73
+ ===========================
74
+ Principle 2: The "Three Worlds" of Homomorphic Encryption
75
+ ===========================
76
+ A critical design decision is the strict separation of HE-based computation into three
77
+ distinct, non-interacting "worlds." This avoids ambiguity in operator semantics (e.g., `transpose`),
78
+ clarifies the user's mental model, and aligns the type system with the practical realities of
79
+ underlying HE libraries.
80
+
81
+ - **World 1: The Plaintext World**
82
+ - **Core Type**: `Tensor[Scalar, ...]`
83
+ - **API Standard**: Follows NumPy/JAX conventions. All layout and arithmetic operations are valid.
84
+
85
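+ - **World 2: The Element-wise HE World**
+ - **Core Type**: `Tensor[EncryptedScalar, ...]` (e.g., `Tensor[phe.CiphertextType, ...]`)
86
+ - **API Standard**: Follows TenSEAL-like (Tensor-level) conventions. Layout operations
87
+ (`transpose`, `reshape`) are valid as they merely shuffle independent ciphertext objects.
88
+ Arithmetic operations are overloaded for element-wise HE computation.
+
+ - **World 3: The SIMD HE World**
+ - **Core Type**: SIMD_HE ciphertexts (e.g., BFV, CKKS) whose payload is a packed
+ `Vector[Scalar, size]` rather than a `Tensor`.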
89
+
90
+ ===========================
91
+ Principle 3: Contracts via Protocols
92
+ ===========================
93
+ The system uses `typing.Protocol` to define behavioral contracts (similar to Traits in Rust).
94
+ This allows for writing generic functions that operate on any type satisfying a contract,
95
+ promoting extensibility and loose coupling via structural subtyping ("duck typing").
96
+
97
+ - `EncryptedTrait`: For types representing data in an obscured form.
98
+ - `Distributed`: For types describing data distribution.
99
+
100
+ ===========================
101
+ Rationale for the `EncryptedTrait` Protocol
102
+ ===========================
103
+ The name `EncryptedTrait` was deliberately chosen over the more general `PrivacyBearing` after
104
+ careful consideration.
105
+
106
+ 1. **Scope is Naturally Limited**: Other privacy techniques like Differential Privacy or
107
+ Federated Learning are algorithmic or orchestration patterns that do not require new
108
+ type wrappers for the data itself. A DP-protected tensor is still a `Tensor`.
109
+ Therefore, the protocol only needs to cover technologies that transform data into an
110
+ opaque representation.
111
+
112
+ 2. **Secret Sharing as a form of Encryption**: The key insight is to conceptualize
113
+ Secret Sharing (`SS`) as a form of multi-key encryption. For a holder of a single
114
+ share, the other parties' shares are analogous to the "key" needed to recover the
115
+ secret. Both `HE` and `SS` render the data opaque and require external information
116
+ (a key or other shares) for recovery. This powerful mental model allows both `HE`/`SIMD_HE`
117
+ and `SS` to logically implement the `EncryptedTrait` contract.
118
+
119
+ This makes `EncryptedTrait` a name that is both intuitive to engineers and conceptually
120
+ consistent within the practical scope of this library.
121
+ """
122
+
123
+ from __future__ import annotations
124
+
125
+ from typing import Any, ClassVar, Generic, TypeVar
126
+
127
+ from mplang.v2.edsl import serde
128
+
129
+ # ==============================================================================
130
+ # --- Base Type & Type Aliases
131
+ # ==============================================================================
132
+
133
+ T = TypeVar("T")
134
+
135
+
136
class BaseType:
    """Base class for all MPLang types.

    ``repr()`` of a type mirrors its ``str()`` form, so subclasses only need
    to define ``__str__`` to get a readable representation in both contexts.
    """

    def __repr__(self) -> str:
        """Return ``str(self)``, or the default object repr if ``__str__`` is not overridden.

        ``object.__str__`` delegates to ``__repr__``. Forwarding unconditionally
        to ``str(self)`` would therefore recurse without bound for any type that
        does not define its own ``__str__`` (including ``BaseType`` itself).
        """
        if type(self).__str__ is object.__str__:
            # No custom __str__ available: use the built-in "<... object at 0x...>" form.
            return object.__repr__(self)
        return str(self)
141
+
142
+
143
+ # ==============================================================================
144
+ # --- Type Protocols (Contracts)
145
+ # ==============================================================================
146
+
147
+
148
class EncryptedTrait:
    """Contract for types whose data is held in an encrypted or obscured form.

    This class only declares the two backing attributes; it never assigns
    them, so the values must be set elsewhere before the properties are read.
    """

    # Backing storage for the read-only properties below.
    # NOTE(review): `_pt_type` presumably names the wrapped plaintext type — confirm.
    _pt_type: BaseType
    _enc_schema: str

    @property
    def pt_type(self) -> BaseType:
        """Return the value stored in ``_pt_type``."""
        stored_type = self._pt_type
        return stored_type

    @property
    def enc_schema(self) -> str:
        """Return the value stored in ``_enc_schema``."""
        stored_schema = self._enc_schema
        return stored_schema
161
+
162
+
163
+ # ==============================================================================
164
+ # --- Pillar 1: Layout Types
165
+ # ==============================================================================
166
+
167
+
168
class ScalarType(BaseType):
    """Common parent of every scalar type (integer, float, complex).

    IntegerType, FloatType and ComplexType all derive from this class, so
    code that accepts "any scalar" can name this one type instead of a union.
    """
174
+
175
+
176
@serde.register_class
class IntegerType(ScalarType):
    """Integer type with a configurable bit width and signedness.

    The width is not limited to machine sizes, so the type can describe
    integers beyond the range of fixed-width types such as i64.

    Examples:
        >>> i128 = IntegerType(bitwidth=128, signed=True)  # i128
        >>> u256 = IntegerType(bitwidth=256, signed=False)  # u256

    Note:
        Encoding-specific metadata (e.g., fixed-point scale, semantic type)
        belongs on the operations/objects that use an IntegerType, not on
        the type itself.
    """

    # Kind tag for the serde registry.
    _serde_kind: ClassVar[str] = "mplang.IntegerType"

    def __init__(self, *, bitwidth: int = 32, signed: bool = True):
        """Create an integer type.

        Args:
            bitwidth: Number of bits; must be a positive power of two
                (e.g. 1, 8, 16, 32, 64, 128, 256, ...).
            signed: True for a signed integer, False for unsigned.

        Raises:
            ValueError: If bitwidth is not a positive power of two.
        """
        # A positive power of two has exactly one bit set, so n & (n - 1) == 0.
        is_power_of_two = bitwidth > 0 and (bitwidth & (bitwidth - 1)) == 0
        if not is_power_of_two:
            raise ValueError(f"bitwidth must be a positive power of 2, got {bitwidth}")
        self.bitwidth = bitwidth
        self.signed = signed

    def __str__(self) -> str:
        """Render as ``i<bits>`` (signed) or ``u<bits>`` (unsigned)."""
        prefix = "i" if self.signed else "u"
        return prefix + str(self.bitwidth)

    def __eq__(self, other: object) -> bool:
        """Equal only to another IntegerType with the same width and signedness."""
        if isinstance(other, IntegerType):
            return self.bitwidth == other.bitwidth and self.signed == other.signed
        return False

    def __hash__(self) -> int:
        """Hash over the same fields that ``__eq__`` compares."""
        key = ("IntegerType", self.bitwidth, self.signed)
        return hash(key)

    # --- Serde methods ---

    def to_json(self) -> dict[str, Any]:
        """Return the JSON-ready mapping of this type's fields."""
        payload: dict[str, Any] = {}
        payload["bitwidth"] = self.bitwidth
        payload["signed"] = self.signed
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> IntegerType:
        """Rebuild an IntegerType from the mapping produced by ``to_json``."""
        width = data["bitwidth"]
        is_signed = data["signed"]
        return cls(bitwidth=width, signed=is_signed)
228
+
229
+
230
@serde.register_class
class FloatType(ScalarType):
    """Floating-point type identified by its bit width.

    Examples:
        >>> f16 = FloatType(bitwidth=16)  # half precision
        >>> f32 = FloatType(bitwidth=32)  # single precision
        >>> f64 = FloatType(bitwidth=64)  # double precision
    """

    # Kind tag for the serde registry.
    _serde_kind: ClassVar[str] = "mplang.FloatType"

    def __init__(self, *, bitwidth: int = 32):
        """Create a float type.

        Args:
            bitwidth: Number of bits; one of 16, 32, 64 or 128.

        Raises:
            ValueError: If bitwidth is not one of the accepted widths.
        """
        accepted_widths = (16, 32, 64, 128)
        if bitwidth not in accepted_widths:
            raise ValueError(f"bitwidth must be 16, 32, 64, or 128, got {bitwidth}")
        self.bitwidth = bitwidth

    def __str__(self) -> str:
        """Render as ``f<bits>``."""
        return "f" + str(self.bitwidth)

    def __eq__(self, other: object) -> bool:
        """Equal only to another FloatType with the same width."""
        if isinstance(other, FloatType):
            return self.bitwidth == other.bitwidth
        return False

    def __hash__(self) -> int:
        """Hash over the same field that ``__eq__`` compares."""
        key = ("FloatType", self.bitwidth)
        return hash(key)

    # --- Serde methods ---

    def to_json(self) -> dict[str, Any]:
        """Return the JSON-ready mapping of this type's fields."""
        payload: dict[str, Any] = {"bitwidth": self.bitwidth}
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> FloatType:
        """Rebuild a FloatType from the mapping produced by ``to_json``."""
        width = data["bitwidth"]
        return cls(bitwidth=width)
274
+
275
+
276
@serde.register_class
class ComplexType(ScalarType):
    """Complex number type built from one floating-point type.

    The real and imaginary parts share the same ``inner_type``.

    Examples:
        >>> c64 = ComplexType(inner_type=f32)  # complex64 (2x float32)
        >>> c128 = ComplexType(inner_type=f64)  # complex128 (2x float64)
    """

    # Kind tag for the serde registry.
    _serde_kind: ClassVar[str] = "mplang.ComplexType"

    def __init__(self, *, inner_type: FloatType):
        """Create a complex type.

        Args:
            inner_type: FloatType used for both the real and imaginary parts.

        Raises:
            TypeError: If inner_type is not a FloatType.
        """
        if isinstance(inner_type, FloatType):
            self.inner_type = inner_type
            return
        raise TypeError(
            f"inner_type must be a FloatType, got {type(inner_type).__name__}"
        )

    def __str__(self) -> str:
        """Render as ``c<bits>``, where bits is twice the inner float width."""
        total_bits = self.inner_type.bitwidth * 2
        return f"c{total_bits}"

    def __eq__(self, other: object) -> bool:
        """Equal only to another ComplexType with an equal inner type."""
        if isinstance(other, ComplexType):
            return self.inner_type == other.inner_type
        return False

    def __hash__(self) -> int:
        """Hash over the same field that ``__eq__`` compares."""
        key = ("ComplexType", self.inner_type)
        return hash(key)

    # --- Serde methods ---

    def to_json(self) -> dict[str, Any]:
        """Return the JSON-ready mapping, with the inner type serialized via serde."""
        encoded_inner = serde.to_json(self.inner_type)
        return {"inner_type": encoded_inner}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> ComplexType:
        """Rebuild a ComplexType from the mapping produced by ``to_json``.

        Raises:
            TypeError: If the decoded inner type is not a FloatType.
        """
        decoded_inner = serde.from_json(data["inner_type"])
        if isinstance(decoded_inner, FloatType):
            return cls(inner_type=decoded_inner)
        raise TypeError(f"ComplexType inner must be FloatType, got {type(decoded_inner)}")
324
+
325
+
326
+ # ==============================================================================
327
+ # --- Predefined Scalar Type Instances
328
+ # ==============================================================================
329
+
330
# Ready-made scalar type instances, aligned with the common dtype names.

# Signed integers.
i8 = IntegerType(bitwidth=8, signed=True)
i16 = IntegerType(bitwidth=16, signed=True)
i32 = IntegerType(bitwidth=32, signed=True)
i64 = IntegerType(bitwidth=64, signed=True)

# Unsigned integers.
u8 = IntegerType(bitwidth=8, signed=False)
u16 = IntegerType(bitwidth=16, signed=False)
u32 = IntegerType(bitwidth=32, signed=False)
u64 = IntegerType(bitwidth=64, signed=False)

# Floats.
f16 = FloatType(bitwidth=16)
f32 = FloatType(bitwidth=32)
f64 = FloatType(bitwidth=64)

# Complex numbers; the name carries the total width of both parts.
c64 = ComplexType(inner_type=f32)  # 2 x 32 bits
c128 = ComplexType(inner_type=f64)  # 2 x 64 bits

# Boolean, modelled as a 1-bit integer; `i1` is the MLIR-style alias.
bool_ = IntegerType(bitwidth=1, signed=True)
i1 = bool_

# Wide integers beyond 64 bits.
i128 = IntegerType(bitwidth=128, signed=True)
i256 = IntegerType(bitwidth=256, signed=True)
u128 = IntegerType(bitwidth=128, signed=False)
u256 = IntegerType(bitwidth=256, signed=False)
361
+
362
+
363
@serde.register_class
class TensorType(BaseType, Generic[T]):
    """Ranked tensor type: an element type plus a shape tuple.

    Modelled on MLIR's RankedTensorType: the rank is always known, which keeps
    type inference simpler than it would be with unranked tensors.

    Each entry of the shape tuple is either:
        - a positive integer: static dimension size
        - -1: dynamic/unknown dimension size

    Examples:
        Tensor[i32, ()]          # Scalar (0-dim tensor)
        Tensor[i32, (-1, 10)]    # Partially dynamic shape (rank=2)
        Tensor[i32, (3, 10)]     # Fully static shape (rank=2)
        Tensor[i32, (-1,)]       # 1D tensor with dynamic size
    """

    # Kind tag for the serde registry.
    _serde_kind: ClassVar[str] = "mplang.TensorType"

    def __init__(self, element_type: BaseType, shape: tuple[int, ...]):
        """Validate and store the element type and shape.

        Args:
            element_type: Any BaseType instance (not only scalars, so custom
                element types are allowed).
            shape: Tuple of dimensions, each a positive int or -1.

        Raises:
            TypeError: If element_type is not a BaseType, shape is not a
                tuple, or a dimension is not an int.
            ValueError: If a dimension is 0 or less than -1.
        """
        if not isinstance(element_type, BaseType):
            raise TypeError(
                f"Tensor element type must be a BaseType, but got {type(element_type).__name__}."
            )
        if not isinstance(shape, tuple):
            raise TypeError(f"Shape must be a tuple, got {type(shape).__name__}")

        for dim in shape:
            if not isinstance(dim, int):
                raise TypeError(
                    f"Shape dimensions must be integers, got {type(dim).__name__}"
                )
            # Only -1 (dynamic) and strictly positive sizes are legal.
            if dim != -1 and dim < 1:
                raise ValueError(
                    f"Invalid dimension {dim}: must be positive or -1 for dynamic"
                )

        self.element_type = element_type
        self.shape = shape

    def __class_getitem__(cls, params: tuple | Any) -> Any:
        """Support the ``Tensor[element_type, shape]`` construction syntax.

        Args:
            params: Either a single subscript item or an
                ``(element_type, shape)`` tuple.

        Returns:
            A TensorType instance, or the Generic alias when any subscript
            item is a class.

        Raises:
            TypeError: If the subscript is not exactly ``(element_type, shape)``.
        """
        # Heuristic: a class among the items means Generic[T] specialization,
        # not instance creation.
        items = params if isinstance(params, tuple) else (params,)
        for item in items:
            if isinstance(item, type):
                return super().__class_getitem__(params)  # type: ignore[misc]

        if not isinstance(params, tuple):
            raise TypeError(
                "Tensor requires shape parameter. Use Tensor[element_type, shape] "
                "where shape is (), or a tuple of integers."
            )

        if len(params) != 2:
            raise TypeError(
                f"Tensor expects 2 parameters (element_type, shape), got {len(params)}"
            )

        return cls(params[0], params[1])

    def __str__(self) -> str:
        """Render as ``Tensor[<element_type>, (<d0>, <d1>, ...)]``."""
        dims = ", ".join(map(str, self.shape))
        return f"Tensor[{self.element_type}, ({dims})]"

    def __eq__(self, other: object) -> bool:
        """Equal only to another TensorType with equal element type and shape."""
        if isinstance(other, TensorType):
            return self.element_type == other.element_type and self.shape == other.shape
        return False

    def __hash__(self) -> int:
        """Hash over the same fields that ``__eq__`` compares."""
        return hash((self.element_type, self.shape))

    @property
    def is_scalar(self) -> bool:
        """True when the shape is the empty tuple (0-dimensional tensor)."""
        return self.shape == ()

    @property
    def is_fully_static(self) -> bool:
        """True when no dimension is dynamic, i.e. every size is positive."""
        return not any(dim <= 0 for dim in self.shape)

    @property
    def rank(self) -> int:
        """Number of dimensions, i.e. the length of the shape tuple."""
        return len(self.shape)

    def has_dynamic_dims(self) -> bool:
        """True when at least one dimension is -1."""
        return -1 in self.shape

    # --- Serde methods ---

    def to_json(self) -> dict[str, Any]:
        """Return the JSON-ready mapping; the shape tuple becomes a list."""
        encoded_element = serde.to_json(self.element_type)
        return {"element_type": encoded_element, "shape": list(self.shape)}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TensorType[Any]:
        """Rebuild a TensorType from the mapping produced by ``to_json``."""
        decoded_element = serde.from_json(data["element_type"])
        dims = tuple(data["shape"])
        return cls(decoded_element, dims)


Tensor = TensorType
487
+
488
+
489
@serde.register_class
class VectorType(BaseType):
    """Packed SIMD vector of a scalar element type with a fixed slot count.

    A Tensor describes a logical multi-dimensional array; a Vector describes
    a physical packed (SIMD) layout. It is the underlying payload for SIMD_HE
    schemes (BFV, CKKS).

    Args:
        element_type: The type of elements in the vector (must be ScalarType).
        size: The number of elements (slots) in the vector.
    """

    # Kind tag for the serde registry.
    _serde_kind: ClassVar[str] = "mplang.VectorType"

    def __init__(self, element_type: ScalarType, size: int):
        """Validate and store the element type and slot count.

        Raises:
            TypeError: If element_type is not a ScalarType.
            ValueError: If size is not a positive integer.
        """
        if not isinstance(element_type, ScalarType):
            raise TypeError(
                f"Vector element type must be a ScalarType, got {type(element_type).__name__}"
            )
        size_is_valid = isinstance(size, int) and size > 0
        if not size_is_valid:
            raise ValueError(f"Vector size must be a positive integer, got {size}")

        self.element_type = element_type
        self.size = size

    def __class_getitem__(cls, params: tuple) -> VectorType:
        """Support the ``Vector[element_type, size]`` construction syntax.

        Raises:
            TypeError: If the subscript is not a 2-tuple.
        """
        if isinstance(params, tuple) and len(params) == 2:
            return cls(params[0], params[1])
        raise TypeError("Vector expects 2 parameters (element_type, size)")

    def __str__(self) -> str:
        """Render as ``Vector[<element_type>, <size>]``."""
        return f"Vector[{self.element_type}, {self.size}]"

    def __eq__(self, other: object) -> bool:
        """Equal only to another VectorType with equal element type and size."""
        if isinstance(other, VectorType):
            return self.element_type == other.element_type and self.size == other.size
        return False

    def __hash__(self) -> int:
        """Hash over the same fields that ``__eq__`` compares."""
        key = ("VectorType", self.element_type, self.size)
        return hash(key)

    # --- Serde methods ---

    def to_json(self) -> dict[str, Any]:
        """Return the JSON-ready mapping of this type's fields."""
        encoded_element = serde.to_json(self.element_type)
        return {"element_type": encoded_element, "size": self.size}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> VectorType:
        """Rebuild a VectorType from the mapping produced by ``to_json``.

        Raises:
            TypeError: If the decoded element type is not a ScalarType.
        """
        decoded_element = serde.from_json(data["element_type"])
        if isinstance(decoded_element, ScalarType):
            return cls(decoded_element, data["size"])
        raise TypeError(
            f"VectorType element must be ScalarType, got {type(decoded_element)}"
        )


Vector = VectorType
552
+
553
+
554
@serde.register_class
class TableType(BaseType):
    """Represents a table with a named schema of types.

    Two tables are equal when they map the same column names to the same
    types; column order does not take part in equality or hashing.

    Examples:
        >>> TableType({"id": i64, "name": STRING})
        Table[{'id': i64, 'name': Custom[string]}]

        >>> Table[{"col_a": i32, "col_b": f64}]
        Table[{'col_a': i32, 'col_b': f64}]

    Attributes:
        schema: Mapping from column name to that column's type.
    """

    def __init__(self, schema: dict[str, BaseType]):
        """Store the column-name -> type mapping (kept by reference)."""
        self.schema = schema

    def __class_getitem__(cls, schema: dict[str, BaseType]) -> TableType:
        """Enables the syntax `Table[{'col_a': i32, ...}]`."""
        return cls(schema)

    def __str__(self) -> str:
        """Render as `Table[{'name': type, ...}]` in schema iteration order."""
        schema_str = ", ".join(f"'{k}': {v}" for k, v in self.schema.items())
        return f"Table[{{{schema_str}}}]"

    def __eq__(self, other: object) -> bool:
        """Compare schemas; defer to the other operand for foreign types."""
        if not isinstance(other, TableType):
            return NotImplemented
        return self.schema == other.schema

    def __hash__(self) -> int:
        """Hash consistently with `__eq__`.

        `__eq__` uses dict equality, which ignores insertion order, so the
        hash must ignore it too. Hashing a tuple of items (the previous
        implementation) gave equal tables different hashes whenever their
        columns were declared in a different order.
        """
        return hash(("TableType", frozenset(self.schema.items())))

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.TableType"

    def to_json(self) -> dict[str, Any]:
        """Return `{"schema": {...}}` with each column type passed through `serde.to_json`."""
        return {
            "schema": {name: serde.to_json(t) for name, t in self.schema.items()},
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TableType:
        """Rebuild a table type from the dict produced by `to_json`.

        Raises:
            KeyError: If `data` has no "schema" entry.
        """
        schema = {name: serde.from_json(t) for name, t in data["schema"].items()}
        return cls(schema)
597
+
598
+
599
+ Table = TableType  # shorthand alias used for the Table[{...}] subscript syntax
600
+
601
+
602
@serde.register_class
class CustomType(BaseType):
    """A structureless type that is told apart only by its ``kind`` string.

    Intended for values the type system must track but cannot describe in
    detail, such as encryption keys or database handles.

    Examples::

        >>> key_type = CustomType("EncryptionKey")
        >>> handle_type = CustomType("DatabaseHandle")
        >>> token_type = CustomType("AuthToken")

    Equality and hashing depend on ``kind`` alone: two instances are equal
    exactly when their kind strings are equal.

    Attributes:
        kind: The identifying string of this custom type.
    """

    def __init__(self, kind: str):
        """Create a custom type named ``kind``.

        Args:
            kind: Identifier for the type; pick something descriptive
                (e.g., "EncryptionKey", "Handle").

        Raises:
            TypeError: If ``kind`` is not a ``str``.
            ValueError: If ``kind`` is empty or consists only of whitespace.
        """
        if not isinstance(kind, str):
            raise TypeError(f"kind must be str, got {type(kind).__name__}")
        # An empty string strips to "" as well, so one check covers both cases.
        if not kind.strip():
            raise ValueError("kind must be a non-empty string")

        self._kind = kind

    @property
    def kind(self) -> str:
        """The identifying string given at construction."""
        return self._kind

    def __eq__(self, other: object) -> bool:
        """Equal only to another CustomType carrying the same kind."""
        return isinstance(other, CustomType) and self._kind == other._kind

    def __hash__(self) -> int:
        """Hash derived from the kind so instances work in sets and dicts."""
        return hash(("CustomType", self._kind))

    def __repr__(self) -> str:
        """Constructor-style representation for debugging."""
        return f"CustomType({self._kind!r})"

    def __str__(self) -> str:
        """Compact display form, e.g. ``Custom[EncryptionKey]``."""
        return f"Custom[{self._kind}]"

    def __class_getitem__(cls, kind: str) -> CustomType:
        """Support the ``Custom["TypeName"]`` subscript shorthand.

        Examples::

            >>> EncryptionKey = Custom["EncryptionKey"]
            >>> # Equivalent to:
            >>> EncryptionKey = CustomType("EncryptionKey")
        """
        return cls(kind)

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.CustomType"

    def to_json(self) -> dict[str, Any]:
        """Serialize to a dict holding only the kind string."""
        return {"kind": self.kind}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> CustomType:
        """Rebuild from the dict produced by ``to_json``; validation in ``__init__`` applies."""
        return cls(data["kind"])
684
+
685
+
686
+ # Shorthand alias so callers can write Custom["TypeName"] instead of CustomType("TypeName").
687
+ Custom = CustomType
688
+
689
+ # ==============================================================================
690
+ # --- Table-only Types (for SQL/DataFrame operations)
691
+ # ==============================================================================
692
+ # These types are used in TableType schemas but don't have direct tensor
693
+ # equivalents. They use CustomType for flexibility.
694
+
695
+ STRING = CustomType("string")  # str() of this is "Custom[string]", the form shown in table schemas
695
+ DATE = CustomType("date")
696
+ TIME = CustomType("time")
697
+ TIMESTAMP = CustomType("timestamp")
698
+ DECIMAL = CustomType("decimal")  # NOTE(review): kind carries no precision/scale, so all DECIMALs compare equal
699
+ BINARY = CustomType("binary")
700
+ JSON = CustomType("json")
701
+ UUID = CustomType("uuid")
702
+ INTERVAL = CustomType("interval")
704
+
705
+ # ==============================================================================
706
+ # --- Pillar 2: Encryption Types
707
+ # ==============================================================================
708
+
709
+
710
@serde.register_class
class SSType(BaseType, EncryptedTrait, Generic[T]):
    """A single share of a secret value of type `T`.

    The wrapped plaintext type and the scheme name are stored in `_pt_type`
    and `_enc_schema`; they are read back through `pt_type` / `enc_schema`
    (presumably properties supplied by `EncryptedTrait` -- confirm there).
    """

    def __init__(self, secret_type: BaseType, enc_schema: str = "ss"):
        """Record the shared value's type and the sharing scheme name."""
        self._pt_type = secret_type
        self._enc_schema = enc_schema

    def __class_getitem__(cls, secret_type: BaseType | Any) -> Any:
        """Enables the syntax `SS[Tensor[...]]`.

        A type *instance* builds an `SSType`; an actual class is forwarded to
        the parent `__class_getitem__` for ordinary Generic[T] specialization.
        """
        if not isinstance(secret_type, type):
            return cls(secret_type)
        return super().__class_getitem__(secret_type)  # type: ignore[misc]

    def __str__(self) -> str:
        """Render as `SS[<plaintext type>]`."""
        return f"SS[{self.pt_type}]"

    def __eq__(self, other: object) -> bool:
        """Equal when both the plaintext type and the scheme name match."""
        if not isinstance(other, SSType):
            return False
        same_plaintext = self.pt_type == other.pt_type
        return same_plaintext and self.enc_schema == other.enc_schema

    def __hash__(self) -> int:
        """Hash over the same fields that `__eq__` compares."""
        return hash(("SSType", self.pt_type, self.enc_schema))

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.SSType"

    def to_json(self) -> dict[str, Any]:
        """Serialize the plaintext type (via `serde.to_json`) and the scheme name."""
        payload: dict[str, Any] = {
            "secret_type": serde.to_json(self._pt_type),
            "enc_schema": self._enc_schema,
        }
        return payload

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SSType[Any]:
        """Rebuild from `to_json` output; a missing "enc_schema" falls back to "ss"."""
        plaintext_type = serde.from_json(data["secret_type"])
        scheme = data.get("enc_schema", "ss")
        return cls(plaintext_type, enc_schema=scheme)
749
+
750
+
751
+ SS = SSType  # shorthand alias used for the SS[...] subscript syntax
752
+
753
+ # ==============================================================================
754
+ # --- Pillar 3: Distribution Types
755
+ # ==============================================================================
756
+
757
+
758
@serde.register_class
class MPType(BaseType, Generic[T]):
    """A logical value spread across several parties.

    Args:
        value_type: Type of the value each party holds.
        parties: Tuple of party IDs (static mask), or None (dynamic mask).
    """

    def __init__(self, value_type: BaseType, parties: tuple[int, ...] | None):
        """Store the per-party value type and the party mask."""
        self._value_type = value_type
        self._parties = parties

    @property
    def value_type(self) -> BaseType:
        """Type of the value held by the parties."""
        return self._value_type

    @property
    def parties(self) -> tuple[int, ...] | None:
        """Party IDs holding the value, or None when the mask is dynamic."""
        return self._parties

    def __class_getitem__(
        cls, params: tuple[BaseType, tuple[int, ...] | None] | Any
    ) -> Any:
        """Enables the syntax `MP[Tensor[...], (0, 1)]` or `MP[Tensor[...], None]`.

        Heuristic: if any subscript argument is a class, the call is treated
        as Generic[T] specialization and forwarded to the parent; otherwise
        `params` is unpacked as `(value_type, parties)` to build an instance.
        """
        candidates = params if isinstance(params, tuple) else (params,)
        for item in candidates:
            if isinstance(item, type):
                return super().__class_getitem__(params)  # type: ignore[misc]

        value_type, parties = params
        return cls(value_type, parties)

    def __str__(self) -> str:
        """Render as `MP[<value type>, parties=<mask>]`."""
        return f"MP[{self.value_type}, parties={self.parties}]"

    def __eq__(self, other: object) -> bool:
        """Equal when both the value type and the party mask match."""
        if not isinstance(other, MPType):
            return False
        same_value_type = self.value_type == other.value_type
        return same_value_type and self.parties == other.parties

    def __hash__(self) -> int:
        """Hash over the same fields that `__eq__` compares."""
        return hash(("MPType", self.value_type, self.parties))

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.MPType"

    def to_json(self) -> dict[str, Any]:
        """Serialize the value type via `serde.to_json`; the party tuple becomes a list (or None)."""
        encoded_value_type = serde.to_json(self._value_type)
        party_list = None if self._parties is None else list(self._parties)
        return {
            "value_type": encoded_value_type,
            "parties": party_list,
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> MPType[Any]:
        """Rebuild from `to_json` output, turning the party list back into a tuple."""
        value_type = serde.from_json(data["value_type"])
        raw_parties = data["parties"]
        parties = None if raw_parties is None else tuple(raw_parties)
        return cls(value_type, parties)