mplang-nightly 0.1.dev192__py3-none-any.whl → 0.1.dev268__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/__init__.py +21 -130
- mplang/py.typed +13 -0
- mplang/v1/__init__.py +157 -0
- mplang/v1/_device.py +602 -0
- mplang/{analysis → v1/analysis}/__init__.py +1 -1
- mplang/{analysis → v1/analysis}/diagram.py +4 -4
- mplang/{core → v1/core}/__init__.py +20 -14
- mplang/{core → v1/core}/cluster.py +6 -1
- mplang/{core → v1/core}/comm.py +1 -1
- mplang/{core → v1/core}/context_mgr.py +1 -1
- mplang/{core → v1/core}/dtypes.py +38 -0
- mplang/{core → v1/core}/expr/__init__.py +7 -7
- mplang/{core → v1/core}/expr/ast.py +11 -13
- mplang/{core → v1/core}/expr/evaluator.py +8 -8
- mplang/{core → v1/core}/expr/printer.py +6 -6
- mplang/{core → v1/core}/expr/transformer.py +2 -2
- mplang/{core → v1/core}/expr/utils.py +2 -2
- mplang/{core → v1/core}/expr/visitor.py +1 -1
- mplang/{core → v1/core}/expr/walk.py +1 -1
- mplang/{core → v1/core}/interp.py +6 -6
- mplang/{core → v1/core}/mpir.py +13 -11
- mplang/{core → v1/core}/mpobject.py +6 -6
- mplang/{core → v1/core}/mptype.py +13 -10
- mplang/{core → v1/core}/pfunc.py +2 -2
- mplang/{core → v1/core}/primitive.py +12 -12
- mplang/{core → v1/core}/table.py +36 -8
- mplang/{core → v1/core}/tensor.py +1 -1
- mplang/{core → v1/core}/tracer.py +9 -9
- mplang/{host.py → v1/host.py} +5 -5
- mplang/{kernels → v1/kernels}/__init__.py +1 -1
- mplang/{kernels → v1/kernels}/base.py +1 -1
- mplang/{kernels → v1/kernels}/basic.py +15 -15
- mplang/{kernels → v1/kernels}/context.py +19 -16
- mplang/{kernels → v1/kernels}/crypto.py +8 -10
- mplang/{kernels → v1/kernels}/fhe.py +9 -7
- mplang/{kernels → v1/kernels}/mock_tee.py +3 -3
- mplang/{kernels → v1/kernels}/phe.py +26 -18
- mplang/{kernels → v1/kernels}/spu.py +5 -5
- mplang/{kernels → v1/kernels}/sql_duckdb.py +5 -3
- mplang/{kernels → v1/kernels}/stablehlo.py +18 -17
- mplang/{kernels → v1/kernels}/value.py +2 -2
- mplang/{ops → v1/ops}/__init__.py +3 -3
- mplang/{ops → v1/ops}/base.py +1 -1
- mplang/{ops → v1/ops}/basic.py +6 -5
- mplang/v1/ops/crypto.py +262 -0
- mplang/{ops → v1/ops}/fhe.py +2 -2
- mplang/{ops → v1/ops}/jax_cc.py +26 -59
- mplang/v1/ops/nnx_cc.py +168 -0
- mplang/{ops → v1/ops}/phe.py +16 -3
- mplang/{ops → v1/ops}/spu.py +3 -3
- mplang/v1/ops/sql_cc.py +303 -0
- mplang/{ops → v1/ops}/tee.py +2 -2
- mplang/{runtime → v1/runtime}/__init__.py +2 -2
- mplang/v1/runtime/channel.py +230 -0
- mplang/{runtime → v1/runtime}/cli.py +3 -3
- mplang/{runtime → v1/runtime}/client.py +1 -1
- mplang/{runtime → v1/runtime}/communicator.py +39 -15
- mplang/{runtime → v1/runtime}/data_providers.py +80 -19
- mplang/{runtime → v1/runtime}/driver.py +4 -4
- mplang/v1/runtime/link_comm.py +196 -0
- mplang/{runtime → v1/runtime}/server.py +22 -9
- mplang/{runtime → v1/runtime}/session.py +24 -51
- mplang/{runtime → v1/runtime}/simulation.py +36 -14
- mplang/{simp → v1/simp}/api.py +72 -14
- mplang/{simp → v1/simp}/mpi.py +1 -1
- mplang/{simp → v1/simp}/party.py +5 -5
- mplang/{simp → v1/simp}/random.py +2 -2
- mplang/v1/simp/smpc.py +238 -0
- mplang/v1/utils/table_utils.py +185 -0
- mplang/v2/__init__.py +424 -0
- mplang/v2/backends/__init__.py +57 -0
- mplang/v2/backends/bfv_impl.py +705 -0
- mplang/v2/backends/channel.py +217 -0
- mplang/v2/backends/crypto_impl.py +723 -0
- mplang/v2/backends/field_impl.py +454 -0
- mplang/v2/backends/func_impl.py +107 -0
- mplang/v2/backends/phe_impl.py +148 -0
- mplang/v2/backends/simp_design.md +136 -0
- mplang/v2/backends/simp_driver/__init__.py +41 -0
- mplang/v2/backends/simp_driver/http.py +168 -0
- mplang/v2/backends/simp_driver/mem.py +280 -0
- mplang/v2/backends/simp_driver/ops.py +135 -0
- mplang/v2/backends/simp_driver/state.py +60 -0
- mplang/v2/backends/simp_driver/values.py +52 -0
- mplang/v2/backends/simp_worker/__init__.py +29 -0
- mplang/v2/backends/simp_worker/http.py +354 -0
- mplang/v2/backends/simp_worker/mem.py +102 -0
- mplang/v2/backends/simp_worker/ops.py +167 -0
- mplang/v2/backends/simp_worker/state.py +49 -0
- mplang/v2/backends/spu_impl.py +275 -0
- mplang/v2/backends/spu_state.py +187 -0
- mplang/v2/backends/store_impl.py +62 -0
- mplang/v2/backends/table_impl.py +838 -0
- mplang/v2/backends/tee_impl.py +215 -0
- mplang/v2/backends/tensor_impl.py +519 -0
- mplang/v2/cli.py +603 -0
- mplang/v2/cli_guide.md +122 -0
- mplang/v2/dialects/__init__.py +36 -0
- mplang/v2/dialects/bfv.py +665 -0
- mplang/v2/dialects/crypto.py +689 -0
- mplang/v2/dialects/dtypes.py +378 -0
- mplang/v2/dialects/field.py +210 -0
- mplang/v2/dialects/func.py +135 -0
- mplang/v2/dialects/phe.py +723 -0
- mplang/v2/dialects/simp.py +944 -0
- mplang/v2/dialects/spu.py +349 -0
- mplang/v2/dialects/store.py +63 -0
- mplang/v2/dialects/table.py +407 -0
- mplang/v2/dialects/tee.py +346 -0
- mplang/v2/dialects/tensor.py +1175 -0
- mplang/v2/edsl/README.md +279 -0
- mplang/v2/edsl/__init__.py +99 -0
- mplang/v2/edsl/context.py +311 -0
- mplang/v2/edsl/graph.py +463 -0
- mplang/v2/edsl/jit.py +62 -0
- mplang/v2/edsl/object.py +53 -0
- mplang/v2/edsl/primitive.py +284 -0
- mplang/v2/edsl/printer.py +119 -0
- mplang/v2/edsl/registry.py +207 -0
- mplang/v2/edsl/serde.py +375 -0
- mplang/v2/edsl/tracer.py +614 -0
- mplang/v2/edsl/typing.py +816 -0
- mplang/v2/kernels/Makefile +30 -0
- mplang/v2/kernels/__init__.py +23 -0
- mplang/v2/kernels/gf128.cpp +148 -0
- mplang/v2/kernels/ldpc.cpp +82 -0
- mplang/v2/kernels/okvs.cpp +283 -0
- mplang/v2/kernels/okvs_opt.cpp +291 -0
- mplang/v2/kernels/py_kernels.py +398 -0
- mplang/v2/libs/collective.py +330 -0
- mplang/v2/libs/device/__init__.py +51 -0
- mplang/v2/libs/device/api.py +813 -0
- mplang/v2/libs/device/cluster.py +352 -0
- mplang/v2/libs/ml/__init__.py +23 -0
- mplang/v2/libs/ml/sgb.py +1861 -0
- mplang/v2/libs/mpc/__init__.py +41 -0
- mplang/v2/libs/mpc/_utils.py +99 -0
- mplang/v2/libs/mpc/analytics/__init__.py +35 -0
- mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
- mplang/v2/libs/mpc/analytics/groupby.md +99 -0
- mplang/v2/libs/mpc/analytics/groupby.py +331 -0
- mplang/v2/libs/mpc/analytics/permutation.py +386 -0
- mplang/v2/libs/mpc/common/constants.py +39 -0
- mplang/v2/libs/mpc/ot/__init__.py +32 -0
- mplang/v2/libs/mpc/ot/base.py +222 -0
- mplang/v2/libs/mpc/ot/extension.py +477 -0
- mplang/v2/libs/mpc/ot/silent.py +217 -0
- mplang/v2/libs/mpc/psi/__init__.py +40 -0
- mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
- mplang/v2/libs/mpc/psi/okvs.py +49 -0
- mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
- mplang/v2/libs/mpc/psi/oprf.py +310 -0
- mplang/v2/libs/mpc/psi/rr22.py +344 -0
- mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
- mplang/v2/libs/mpc/vole/__init__.py +31 -0
- mplang/v2/libs/mpc/vole/gilboa.py +327 -0
- mplang/v2/libs/mpc/vole/ldpc.py +383 -0
- mplang/v2/libs/mpc/vole/silver.py +336 -0
- mplang/v2/runtime/__init__.py +15 -0
- mplang/v2/runtime/dialect_state.py +41 -0
- mplang/v2/runtime/interpreter.py +871 -0
- mplang/v2/runtime/object_store.py +194 -0
- mplang/v2/runtime/value.py +141 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +22 -16
- mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
- mplang/device.py +0 -327
- mplang/ops/crypto.py +0 -108
- mplang/ops/ibis_cc.py +0 -136
- mplang/ops/sql_cc.py +0 -62
- mplang/runtime/link_comm.py +0 -78
- mplang/simp/smpc.py +0 -201
- mplang/utils/table_utils.py +0 -85
- mplang_nightly-0.1.dev192.dist-info/RECORD +0 -83
- /mplang/{core → v1/core}/mask.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.py +0 -0
- /mplang/{protos → v1/protos}/v1alpha1/value_pb2.pyi +0 -0
- /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
- /mplang/{runtime → v1/runtime}/http_api.md +0 -0
- /mplang/{simp → v1/simp}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/__init__.py +0 -0
- /mplang/{utils → v1/utils}/crypto.py +0 -0
- /mplang/{utils → v1/utils}/func_utils.py +0 -0
- /mplang/{utils → v1/utils}/spu_utils.py +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev192.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
mplang/v2/edsl/typing.py
ADDED
|
@@ -0,0 +1,816 @@
|
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
MPLang Core Typing System: Design and Rationale.
|
|
17
|
+
|
|
18
|
+
This module defines the production type system for MPLang, an EDSL for multi-party
|
|
19
|
+
privacy-preserving computation. This document explains the core principles and design
|
|
20
|
+
decisions that shape this system, intended for future maintainers and developers.
|
|
21
|
+
|
|
22
|
+
===========================
|
|
23
|
+
Tensor Shape System
|
|
24
|
+
===========================
|
|
25
|
+
MPLang supports a flexible shape system for tensors to handle various compilation and
|
|
26
|
+
runtime scenarios:
|
|
27
|
+
|
|
28
|
+
**Shape Representations:**
|
|
29
|
+
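- `None`: Fully dynamic/unranked tensor (shape unknown at compile time). Note: the current `TensorType` only supports ranked (tuple) shapes and rejects `None` with a `TypeError`.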
|
|
30
|
+
Example: `Tensor[i32, None]`
|
|
31
|
+
|
|
32
|
+
- `()`: Scalar (0-dimensional tensor)
|
|
33
|
+
Example: `Tensor[i32, ()]`
|
|
34
|
+
|
|
35
|
+
- `(dim1, dim2, ...)`: Ranked tensor with static or dynamic dimensions
|
|
36
|
+
- Positive integers: Static dimension sizes
|
|
37
|
+
- `-1`: Dynamic/unknown dimension size
|
|
38
|
+
Examples:
|
|
39
|
+
- `Tensor[i32, (3, 10)]` - Fully static 2D tensor
|
|
40
|
+
- `Tensor[i32, (-1, 10)]` - Dynamic batch size, static feature size
|
|
41
|
+
- `Tensor[i32, (-1, -1)]` - Fully dynamic 2D tensor
|
|
42
|
+
|
|
43
|
+
**Utility Properties:**
|
|
44
|
+
- `.is_scalar`: Check if tensor is 0-dimensional
|
|
45
|
+
- `.is_unranked`: Not currently provided by `TensorType`, which only supports ranked shapes
|
|
46
|
+
- `.is_fully_static`: Check if all dimensions are statically known
|
|
47
|
+
- `.rank`: Get number of dimensions (always an `int`, since `TensorType` shapes are ranked)
|
|
48
|
+
- `.has_dynamic_dims()`: Check if any dimension is dynamic
|
|
49
|
+
|
|
50
|
+
===========================
|
|
51
|
+
Principle 1: Orthogonality and Composition
|
|
52
|
+
===========================
|
|
53
|
+
The type system is built on three orthogonal pillars. Each type represents a single,
|
|
54
|
+
well-defined concept. Complex ideas are expressed by composing these simple types,
|
|
55
|
+
rather than by creating a large, monolithic set of specific types.
|
|
56
|
+
|
|
57
|
+
1. **Layout Types**: Describe the physical shape and structure of data.
|
|
58
|
+
- `Scalar`: Atomic data types (f32, i64).
|
|
59
|
+
- `Tensor`: A multi-dimensional array whose element type is a `BaseType` (typically a `ScalarType`; custom or encrypted element types are also accepted).
|
|
60
|
+
- `Table`: A dictionary-like structure with named columns of any type.
|
|
61
|
+
|
|
62
|
+
2. **Encryption Types**: Wrap other types to confer privacy properties by making them opaque.
|
|
63
|
+
- `SS`: A single share of a secret-shared value.
|
|
64
|
+
- Note: Element-wise HE types (like `phe.CiphertextType`) are defined in their respective dialects (e.g., `phe`).
|
|
65
|
+
|
|
66
|
+
3. **Distribution Types**: Wrap other types to describe their physical location among parties.
|
|
67
|
+
- `MP`: Represents a value logically held by multiple parties.
|
|
68
|
+
|
|
69
|
+
An example of composition: `MP[SS[Tensor[f32, (10,)]], (0, 1)]` represents a
|
|
70
|
+
10-element float tensor, which is secret-shared (`SS`), and whose shares are distributed
|
|
71
|
+
between parties 0 and 1 (`MP`).
|
|
72
|
+
|
|
73
|
+
===========================
|
|
74
|
+
Principle 2: The "Three Worlds" of Homomorphic Encryption
|
|
75
|
+
===========================
|
|
76
|
+
A critical design decision is the strict separation of HE-based computation into three
|
|
77
|
+
distinct, non-interacting "worlds." This avoids ambiguity in operator semantics (e.g., `transpose`),
|
|
78
|
+
clarifies the user's mental model, and aligns the type system with the practical realities of
|
|
79
|
+
underlying HE libraries.
|
|
80
|
+
|
|
81
|
+
- **World 1: The Plaintext World**
|
|
82
|
+
- **Core Type**: `Tensor[Scalar, ...]`
|
|
83
|
+
- **API Standard**: Follows NumPy/JAX conventions. All layout and arithmetic operations are valid.
|
|
84
|
+
|
|
85
|
+
- **World 2: The Element-wise HE World**
- **Core Type**: `Tensor[EncryptedScalar, ...]` (e.g., `Tensor[phe.CiphertextType, ...]`)
|
|
86
|
+
- **API Standard**: Follows TenSEAL-like (Tensor-level) conventions. Layout operations
|
|
87
|
+
(`transpose`, `reshape`) are valid as they merely shuffle independent ciphertext objects.
|
|
88
|
+
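Arithmetic operations are overloaded for element-wise HE computation.

- **World 3: The SIMD HE World**
- **Core Type**: packed `Vector[Scalar, size]` payloads (see `VectorType`, the underlying payload for SIMD_HE schemes such as BFV and CKKS).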
|
|
89
|
+
|
|
90
|
+
===========================
|
|
91
|
+
Principle 3: Contracts via Protocols
|
|
92
|
+
===========================
|
|
93
|
+
The system uses `typing.Protocol` to define behavioral contracts (similar to Traits in Rust).
|
|
94
|
+
This allows for writing generic functions that operate on any type satisfying a contract,
|
|
95
|
+
promoting extensibility and loose coupling via structural subtyping ("duck typing").
|
|
96
|
+
|
|
97
|
+
- `EncryptedTrait`: For types representing data in an obscured form.
|
|
98
|
+
- `Distributed`: For types describing data distribution.
|
|
99
|
+
|
|
100
|
+
===========================
|
|
101
|
+
Rationale for the `EncryptedTrait` Protocol
|
|
102
|
+
===========================
|
|
103
|
+
The name `EncryptedTrait` was deliberately chosen over the more general `PrivacyBearing` after
|
|
104
|
+
careful consideration.
|
|
105
|
+
|
|
106
|
+
1. **Scope is Naturally Limited**: Other privacy techniques like Differential Privacy or
|
|
107
|
+
Federated Learning are algorithmic or orchestration patterns that do not require new
|
|
108
|
+
type wrappers for the data itself. A DP-protected tensor is still a `Tensor`.
|
|
109
|
+
Therefore, the protocol only needs to cover technologies that transform data into an
|
|
110
|
+
opaque representation.
|
|
111
|
+
|
|
112
|
+
2. **Secret Sharing as a form of Encryption**: The key insight is to conceptualize
|
|
113
|
+
Secret Sharing (`SS`) as a form of multi-key encryption. For a holder of a single
|
|
114
|
+
share, the other parties' shares are analogous to the "key" needed to recover the
|
|
115
|
+
secret. Both `HE` and `SS` render the data opaque and require external information
|
|
116
|
+
(a key or other shares) for recovery. This powerful mental model allows both `HE`/`SIMD_HE`
|
|
117
|
+
and `SS` to logically implement the `EncryptedTrait` contract.
|
|
118
|
+
|
|
119
|
+
This makes `Encrypted` a name that is both intuitive to engineers and conceptually
|
|
120
|
+
consistent within the practical scope of this library.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
from __future__ import annotations
|
|
124
|
+
|
|
125
|
+
from typing import Any, ClassVar, Generic, TypeVar
|
|
126
|
+
|
|
127
|
+
from mplang.v2.edsl import serde
|
|
128
|
+
|
|
129
|
+
# ==============================================================================
|
|
130
|
+
# --- Base Type & Type Aliases
|
|
131
|
+
# ==============================================================================
|
|
132
|
+
|
|
133
|
+
# Type parameter used as the ``Generic[T]`` base of ``TensorType``.
T = TypeVar("T")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class BaseType:
    """Base class for all MPLang types.

    Subclasses are expected to implement ``__str__``; ``repr()`` reuses that
    textual form so a type prints the same way in both contexts.
    """

    def __repr__(self) -> str:
        # Delegating to str() is only safe when a subclass overrides __str__:
        # object.__str__ itself calls __repr__, so without an override this
        # would recurse until RecursionError. Fall back to a minimal form.
        if type(self).__str__ is object.__str__:
            return f"<{type(self).__name__}>"
        return str(self)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# ==============================================================================
|
|
144
|
+
# --- Type Protocols (Contracts)
|
|
145
|
+
# ==============================================================================
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class EncryptedTrait:
    """Mixin contract for types whose values are held in an encrypted/obscured form.

    Implementers are expected to set ``_pt_type`` (the wrapped plaintext type)
    and ``_enc_schema`` (a string identifying the encryption schema; the exact
    values are defined by implementers). This class only exposes them through
    read-only properties.
    """

    _pt_type: BaseType
    _enc_schema: str

    @property
    def pt_type(self) -> BaseType:
        """The plaintext type wrapped by this encrypted type."""
        plaintext = self._pt_type
        return plaintext

    @property
    def enc_schema(self) -> str:
        """The identifier of the encryption schema."""
        schema = self._enc_schema
        return schema
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# ==============================================================================
|
|
164
|
+
# --- Pillar 1: Layout Types
|
|
165
|
+
# ==============================================================================
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class ScalarType(BaseType):
    """Common parent of the atomic element types.

    ``IntegerType``, ``FloatType`` and ``ComplexType`` all derive from this
    class, so code that needs "any scalar" can test
    ``isinstance(x, ScalarType)`` instead of spelling out a union of the
    concrete scalar classes.
    """
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@serde.register_class
class IntegerType(ScalarType):
    """Integer type with a configurable, power-of-two bit width.

    Widths beyond 64 bits (e.g. 128, 256) are accepted, so this type can
    describe integers wider than the usual machine types such as i64.

    Examples:
        >>> i128 = IntegerType(bitwidth=128, signed=True)  # i128
        >>> u256 = IntegerType(bitwidth=256, signed=False)  # u256

    Note:
        Encoding-specific metadata (e.g., fixed-point scale, semantic type)
        belongs on the operations/objects that use IntegerType, not on the
        type itself.
    """

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.IntegerType"

    def __init__(self, *, bitwidth: int = 32, signed: bool = True):
        """Create an integer type.

        Args:
            bitwidth: Width in bits; must be a positive power of two
                (e.g. 1, 8, 16, 32, 64, 128, 256).
            signed: ``True`` for a signed integer, ``False`` for unsigned.

        Raises:
            ValueError: If ``bitwidth`` is not a positive power of two.
        """
        # n & (n - 1) clears the lowest set bit; it is zero only for powers of two.
        is_power_of_two = bitwidth > 0 and not (bitwidth & (bitwidth - 1))
        if not is_power_of_two:
            raise ValueError(f"bitwidth must be a positive power of 2, got {bitwidth}")
        self.bitwidth = bitwidth
        self.signed = signed

    def __str__(self) -> str:
        prefix = "i" if self.signed else "u"
        return prefix + str(self.bitwidth)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, IntegerType) and (self.bitwidth, self.signed) == (
            other.bitwidth,
            other.signed,
        )

    def __hash__(self) -> int:
        key = ("IntegerType", self.bitwidth, self.signed)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        return dict(bitwidth=self.bitwidth, signed=self.signed)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> IntegerType:
        bitwidth = data["bitwidth"]
        signed = data["signed"]
        return cls(bitwidth=bitwidth, signed=signed)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@serde.register_class
class FloatType(ScalarType):
    """IEEE 754 floating-point type, parameterized by bit width.

    Examples:
        >>> f16 = FloatType(bitwidth=16)  # half precision
        >>> f32 = FloatType(bitwidth=32)  # single precision
        >>> f64 = FloatType(bitwidth=64)  # double precision
    """

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.FloatType"

    def __init__(self, *, bitwidth: int = 32):
        """Create a floating-point type.

        Args:
            bitwidth: Width in bits; one of 16, 32, 64 or 128.

        Raises:
            ValueError: If ``bitwidth`` is not one of the supported widths.
        """
        supported_widths = (16, 32, 64, 128)
        if bitwidth in supported_widths:
            self.bitwidth = bitwidth
            return
        raise ValueError(f"bitwidth must be 16, 32, 64, or 128, got {bitwidth}")

    def __str__(self) -> str:
        return "f" + str(self.bitwidth)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, FloatType) and self.bitwidth == other.bitwidth

    def __hash__(self) -> int:
        key = ("FloatType", self.bitwidth)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        return dict(bitwidth=self.bitwidth)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> FloatType:
        bitwidth = data["bitwidth"]
        return cls(bitwidth=bitwidth)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@serde.register_class
class ComplexType(ScalarType):
    """Complex number type built from a pair of identical float components.

    The real and imaginary parts share one ``FloatType``, so the printed width
    is twice the component width.

    Examples:
        >>> c64 = ComplexType(inner_type=f32)  # complex64 (2x float32)
        >>> c128 = ComplexType(inner_type=f64)  # complex128 (2x float64)
    """

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.ComplexType"

    def __init__(self, *, inner_type: FloatType):
        """Create a complex type.

        Args:
            inner_type: ``FloatType`` used for both real and imaginary parts.

        Raises:
            TypeError: If ``inner_type`` is not a ``FloatType``.
        """
        if isinstance(inner_type, FloatType):
            self.inner_type = inner_type
            return
        raise TypeError(
            f"inner_type must be a FloatType, got {type(inner_type).__name__}"
        )

    def __str__(self) -> str:
        total_bits = 2 * self.inner_type.bitwidth
        return f"c{total_bits}"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ComplexType) and self.inner_type == other.inner_type

    def __hash__(self) -> int:
        key = ("ComplexType", self.inner_type)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        encoded_inner = serde.to_json(self.inner_type)
        return {"inner_type": encoded_inner}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> ComplexType:
        decoded = serde.from_json(data["inner_type"])
        if isinstance(decoded, FloatType):
            return cls(inner_type=decoded)
        raise TypeError(f"ComplexType inner must be FloatType, got {type(decoded)}")
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# ==============================================================================
|
|
327
|
+
# --- Predefined Scalar Type Instances
|
|
328
|
+
# ==============================================================================
|
|
329
|
+
|
|
330
|
+
# Numeric scalar types - comprehensive set aligned with common dtypes.

# Fixed-width integers: signed, then unsigned.
i8, i16, i32, i64 = (IntegerType(bitwidth=w, signed=True) for w in (8, 16, 32, 64))
u8, u16, u32, u64 = (IntegerType(bitwidth=w, signed=False) for w in (8, 16, 32, 64))

# Floating point types.
f16, f32, f64 = (FloatType(bitwidth=w) for w in (16, 32, 64))

# Complex types: both parts share one float type, so c64 = 2 x f32 and
# c128 = 2 x f64.
c64, c128 = ComplexType(inner_type=f32), ComplexType(inner_type=f64)

# Boolean as a 1-bit integer; ``i1`` is an alias (the same object) following
# the MLIR naming convention.
bool_ = IntegerType(bitwidth=1, signed=True)
i1 = bool_

# Wider integer types (128/256 bits).
i128, i256 = (IntegerType(bitwidth=w, signed=True) for w in (128, 256))
u128, u256 = (IntegerType(bitwidth=w, signed=False) for w in (128, 256))
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@serde.register_class
class TensorType(BaseType, Generic[T]):
    """Represents a ranked tensor of a given element type and shape.

    Following MLIR's RankedTensorType design - all tensors must have a known rank.
    This simplifies type inference and reduces complexity compared to supporting
    fully unranked tensors, so a ``None`` shape is rejected.

    Shape must be a tuple where each dimension can be:
        - Positive integer: Static dimension size
        - -1: Dynamic/unknown dimension size

    The element type may be any ``BaseType`` instance (not only a
    ``ScalarType``), so custom element types such as encrypted scalars are
    allowed.

    Examples:
        Tensor[i32, ()]  # Scalar (0-dim tensor)
        Tensor[i32, (-1, 10)]  # Partially dynamic shape (rank=2)
        Tensor[i32, (3, 10)]  # Fully static shape (rank=2)
        Tensor[i32, (-1,)]  # 1D tensor with dynamic size
    """

    def __init__(self, element_type: BaseType, shape: tuple[int, ...]):
        """Create a tensor type.

        Args:
            element_type: Element type; any ``BaseType`` instance.
            shape: Tuple of dimensions, each a positive int or ``-1``.

        Raises:
            TypeError: If ``element_type`` is not a ``BaseType``, ``shape`` is
                not a tuple, or a dimension is not an ``int``.
            ValueError: If a dimension is ``0`` or less than ``-1``.
        """
        # Allow any BaseType to support custom types like PointType, EncryptedScalar
        if not isinstance(element_type, BaseType):
            raise TypeError(
                f"Tensor element type must be a BaseType, but got {type(element_type).__name__}."
            )

        # Validate before assigning so a failed construction never leaves a
        # half-initialized instance behind.
        if not isinstance(shape, tuple):
            raise TypeError(f"Shape must be a tuple, got {type(shape).__name__}")

        for dim in shape:
            if not isinstance(dim, int):
                raise TypeError(
                    f"Shape dimensions must be integers, got {type(dim).__name__}"
                )
            if dim < -1 or dim == 0:
                raise ValueError(
                    f"Invalid dimension {dim}: must be positive or -1 for dynamic"
                )

        self.element_type = element_type
        self.shape = shape

    def __class_getitem__(cls, params: tuple | Any) -> Any:
        """Enables the syntax `Tensor[element_type, shape]`.

        Two meanings share the subscript syntax:
          * type-level specialization of the ``Generic[T]`` base, e.g.
            ``TensorType[T]`` or ``TensorType[Any]`` in annotations;
          * value-level construction, e.g. ``Tensor[i32, (3,)]``, where the
            element type is a ``BaseType`` *instance*.

        Args:
            params: Either a single type parameter or an
                ``(element_type, shape)`` tuple.

        Returns:
            A ``TensorType`` instance, or a generic alias for specializations.

        Raises:
            TypeError: If a value-level subscript lacks a shape or has the
                wrong number of parameters.
        """
        args = params if isinstance(params, tuple) else (params,)
        # Classes, TypeVars and ``Any`` denote type-level specialization.
        # TypeVar instances are never ``type`` objects, and ``Any`` only
        # became a class in Python 3.11, so both are checked explicitly.
        if any(isinstance(a, (type, TypeVar)) or a is Any for a in args):
            return super().__class_getitem__(params)  # type: ignore[misc]

        if not isinstance(params, tuple):
            raise TypeError(
                "Tensor requires shape parameter. Use Tensor[element_type, shape] "
                "where shape is (), or a tuple of integers."
            )

        if len(params) != 2:
            raise TypeError(
                f"Tensor expects 2 parameters (element_type, shape), got {len(params)}"
            )

        element_type, shape = params
        return cls(element_type, shape)

    def __str__(self) -> str:
        shape_str = ", ".join(str(d) for d in self.shape)
        return f"Tensor[{self.element_type}, ({shape_str})]"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TensorType):
            return False
        return self.element_type == other.element_type and self.shape == other.shape

    def __hash__(self) -> int:
        # Tagged like the other type classes' hashes.
        return hash(("TensorType", self.element_type, self.shape))

    @property
    def is_scalar(self) -> bool:
        """Check if this is a scalar (0-dimensional) tensor."""
        return self.shape == ()

    @property
    def is_fully_static(self) -> bool:
        """Check if all dimensions are statically known (no ``-1`` entries)."""
        return all(dim > 0 for dim in self.shape)

    @property
    def rank(self) -> int:
        """Get the rank (number of dimensions) of the tensor.

        Returns:
            int: Number of dimensions (always available for ranked tensors)
        """
        return len(self.shape)

    def has_dynamic_dims(self) -> bool:
        """Check if tensor has any dynamic dimensions (-1)."""
        return any(dim == -1 for dim in self.shape)

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.TensorType"

    def to_json(self) -> dict[str, Any]:
        return {
            "element_type": serde.to_json(self.element_type),
            # JSON has no tuple type; from_json converts the list back.
            "shape": list(self.shape),
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TensorType[Any]:
        element_type = serde.from_json(data["element_type"])
        shape = tuple(data["shape"])
        return cls(element_type, shape)


# Short alias so type expressions read as ``Tensor[element_type, shape]``.
Tensor = TensorType
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
@serde.register_class
class VectorType(BaseType):
    """Packed SIMD vector of scalar slots.

    ``Tensor`` models a logical multi-dimensional array, whereas ``Vector``
    models a physical packed (SIMD) layout; it is the underlying payload for
    SIMD_HE schemes (BFV, CKKS).

    Args:
        element_type: Type of each slot; must be a ``ScalarType``.
        size: Number of slots; must be a positive integer.
    """

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.VectorType"

    def __init__(self, element_type: ScalarType, size: int):
        if not isinstance(element_type, ScalarType):
            raise TypeError(
                f"Vector element type must be a ScalarType, got {type(element_type).__name__}"
            )
        size_ok = isinstance(size, int) and size > 0
        if not size_ok:
            raise ValueError(f"Vector size must be a positive integer, got {size}")
        self.element_type = element_type
        self.size = size

    def __class_getitem__(cls, params: tuple) -> VectorType:
        """Enables the syntax `Vector[element_type, size]`."""
        if isinstance(params, tuple) and len(params) == 2:
            element_type, size = params
            return cls(element_type, size)
        raise TypeError("Vector expects 2 parameters (element_type, size)")

    def __str__(self) -> str:
        return "Vector[{}, {}]".format(self.element_type, self.size)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, VectorType) and (
            self.element_type == other.element_type and self.size == other.size
        )

    def __hash__(self) -> int:
        key = ("VectorType", self.element_type, self.size)
        return hash(key)

    def to_json(self) -> dict[str, Any]:
        encoded_element = serde.to_json(self.element_type)
        return {"element_type": encoded_element, "size": self.size}

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> VectorType:
        element_type = serde.from_json(data["element_type"])
        if isinstance(element_type, ScalarType):
            return cls(element_type, data["size"])
        raise TypeError(
            f"VectorType element must be ScalarType, got {type(element_type)}"
        )


# Short alias enabling the ``Vector[element_type, size]`` syntax.
Vector = VectorType
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@serde.register_class
class TableType(BaseType):
    """Represents a table with a named schema of types.

    Equality compares schemas as dicts, so column order does not affect
    ``==``. ``__hash__`` is order-insensitive as well, which keeps equal
    tables hashing equally.

    Examples:
        >>> TableType({"id": i64, "name": STRING})
        Table[{'id': i64, 'name': Custom[string]}]

        >>> Table[{"col_a": i32, "col_b": f64}]
        Table[{'col_a': i32, 'col_b': f64}]
    """

    def __init__(self, schema: dict[str, BaseType]):
        # Take a private copy: if we aliased the caller's dict, mutating it
        # later would silently change this type's equality and hash (e.g.
        # while it sits in a set or as a dict key).
        self.schema = dict(schema)

    def __class_getitem__(cls, schema: dict[str, BaseType]) -> TableType:
        """Enables the syntax `Table[{'col_a': i32, ...}]`."""
        return cls(schema)

    def __str__(self) -> str:
        schema_str = ", ".join(f"'{k}': {v}" for k, v in self.schema.items())
        return f"Table[{{{schema_str}}}]"

    def __eq__(self, other: object) -> bool:
        """Equal when both schemas map the same names to equal types."""
        if not isinstance(other, TableType):
            return NotImplemented
        return self.schema == other.schema

    def __hash__(self) -> int:
        """Order-insensitive hash, matching dict-based ``__eq__``.

        Hashing ``tuple(schema.items())`` made two equal tables whose columns
        were inserted in different orders hash differently, violating the
        ``a == b`` implies ``hash(a) == hash(b)`` contract.
        """
        return hash(("TableType", frozenset(self.schema.items())))

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.TableType"

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict, preserving column order."""
        return {
            "schema": {name: serde.to_json(t) for name, t in self.schema.items()},
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> TableType:
        """Inverse of ``to_json``."""
        schema = {name: serde.from_json(t) for name, t in data["schema"].items()}
        return cls(schema)
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
Table = TableType
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
@serde.register_class
class CustomType(BaseType):
    """Opaque type identified solely by a string ``kind``.

    Intended for values with no structure the type system needs to inspect
    (for instance encryption keys, database handles, auth tokens) but which
    still need to be tracked as types.

    Examples::

        >>> key_type = CustomType("EncryptionKey")
        >>> handle_type = CustomType("DatabaseHandle")
        >>> token_type = CustomType("AuthToken")

    Identity is defined by ``kind`` alone: two CustomType objects compare
    equal exactly when their kinds are equal, and they hash accordingly.

    Attributes:
        kind: String identifier for this custom type (read-only).
    """

    def __init__(self, kind: str):
        """Create a custom type tagged with ``kind``.

        Args:
            kind: Descriptive identifier, e.g. "EncryptionKey" or "Handle".

        Raises:
            TypeError: If ``kind`` is not a ``str``.
            ValueError: If ``kind`` is empty or consists only of whitespace.
        """
        if not isinstance(kind, str):
            raise TypeError(f"kind must be str, got {type(kind).__name__}")
        # An empty string strips to "" as well, so one test covers both cases.
        if not kind.strip():
            raise ValueError("kind must be a non-empty string")
        self._kind = kind

    @property
    def kind(self) -> str:
        """The identifier this custom type was created with."""
        return self._kind

    def __eq__(self, other: object) -> bool:
        """Kinds must match; anything that is not a CustomType is unequal."""
        return isinstance(other, CustomType) and self._kind == other._kind

    def __hash__(self) -> int:
        """Hash derived from the kind, consistent with ``__eq__``."""
        return hash(("CustomType", self._kind))

    def __repr__(self) -> str:
        """Constructor-like representation, handy when debugging."""
        return "CustomType(" + repr(self._kind) + ")"

    def __str__(self) -> str:
        """Short display form, ``Custom[<kind>]``."""
        return f"Custom[{self._kind}]"

    def __class_getitem__(cls, kind: str) -> CustomType:
        """Support the ``Custom["TypeName"]`` shorthand.

        Examples::

            >>> EncryptionKey = Custom["EncryptionKey"]
            >>> # Same as:
            >>> EncryptionKey = CustomType("EncryptionKey")
        """
        return cls(kind)

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.CustomType"

    def to_json(self) -> dict[str, Any]:
        """Serialize as ``{"kind": <kind>}``."""
        return dict(kind=self.kind)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> CustomType:
        """Inverse of ``to_json``."""
        return cls(kind=data["kind"])
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
# Shorthand alias, enabling the ``Custom["Kind"]`` spelling.
Custom = CustomType

# ==============================================================================
# --- Table-only Types (for SQL/DataFrame operations)
# ==============================================================================
# These types are used in TableType schemas but have no direct tensor
# equivalents, so they are expressed as CustomType for flexibility.
# Each constant is identified only by its lowercase kind string, so for
# example ``STRING == CustomType("string")`` and ``str(STRING)`` is
# ``"Custom[string]"``.

STRING = CustomType("string")
DATE = CustomType("date")
TIME = CustomType("time")
TIMESTAMP = CustomType("timestamp")
DECIMAL = CustomType("decimal")
BINARY = CustomType("binary")
JSON = CustomType("json")
UUID = CustomType("uuid")
INTERVAL = CustomType("interval")
|
|
704
|
+
|
|
705
|
+
# ==============================================================================
# --- Pillar 2: Encryption Types
# ==============================================================================


@serde.register_class
class SSType(BaseType, EncryptedTrait, Generic[T]):
    """Represents a single share of a secret value `T`.

    Args:
        secret_type: Type of the secret value this share belongs to.
        enc_schema: Name of the sharing scheme; defaults to ``"ss"``.
    """

    # NOTE(review): ``pt_type`` / ``enc_schema`` are read below but not defined
    # in this class — presumably accessors over ``_pt_type`` / ``_enc_schema``
    # provided by EncryptedTrait; confirm.

    def __init__(self, secret_type: BaseType, enc_schema: str = "ss"):
        self._pt_type = secret_type
        self._enc_schema = enc_schema

    def __class_getitem__(cls, secret_type: BaseType | Any) -> Any:
        """Enables the syntax `SS[Tensor[...]]`.

        A BaseType *instance* builds a new SSType. Anything else (a class,
        a TypeVar, a typing special form, ...) is generic specialization and
        is forwarded to ``Generic``.
        """
        # Testing ``isinstance(x, type)`` misclassified TypeVars (which are not
        # classes), so ``SSType[T]`` returned a bogus SSType(T) instance instead
        # of a generic alias. Only BaseType instances describe a concrete type.
        if not isinstance(secret_type, BaseType):
            return super().__class_getitem__(secret_type)  # type: ignore[misc]
        return cls(secret_type)

    def __str__(self) -> str:
        """Render as ``SS[<plaintext type>]`` (the scheme is not shown)."""
        return f"SS[{self.pt_type}]"

    def __eq__(self, other: object) -> bool:
        """Equal when both the plaintext type and the scheme match."""
        if not isinstance(other, SSType):
            return False
        return self.pt_type == other.pt_type and self.enc_schema == other.enc_schema

    def __hash__(self) -> int:
        """Hash over the same fields ``__eq__`` compares."""
        return hash(("SSType", self.pt_type, self.enc_schema))

    # --- Serde methods ---
    _serde_kind: ClassVar[str] = "mplang.SSType"

    def to_json(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict (secret type via serde)."""
        return {
            "secret_type": serde.to_json(self._pt_type),
            "enc_schema": self._enc_schema,
        }

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SSType[Any]:
        """Inverse of ``to_json``; a missing ``enc_schema`` defaults to "ss"."""
        secret_type = serde.from_json(data["secret_type"])
        return cls(secret_type, enc_schema=data.get("enc_schema", "ss"))
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
# Shorthand alias, enabling the ``SS[...]`` spelling (e.g. ``SS[Tensor[...]]``).
SS = SSType

# ==============================================================================
# --- Pillar 3: Distribution Types
# ==============================================================================
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
@serde.register_class
|
|
759
|
+
class MPType(BaseType, Generic[T]):
|
|
760
|
+
"""Represents a logical value distributed among multiple parties.
|
|
761
|
+
|
|
762
|
+
Args:
|
|
763
|
+
value_type: The type of the value held by parties
|
|
764
|
+
parties: Tuple of party IDs (static mask) or None (dynamic mask)
|
|
765
|
+
"""
|
|
766
|
+
|
|
767
|
+
def __init__(self, value_type: BaseType, parties: tuple[int, ...] | None):
    # Both fields are kept private and exposed read-only via properties.
    self._value_type, self._parties = value_type, parties

@property
def value_type(self) -> BaseType:
    """Type of the value held by the participating parties."""
    return self._value_type

@property
def parties(self) -> tuple[int, ...] | None:
    """Static tuple of party IDs, or None for a dynamic mask."""
    return self._parties
|
|
778
|
+
|
|
779
|
+
def __class_getitem__(
    cls, params: tuple[BaseType, tuple[int, ...] | None] | Any
) -> Any:
    """Enables the syntax `MP[Tensor[...], (0, 1)]` or `MP[Tensor[...], None]`.

    If any subscript argument is a BaseType *instance*, the subscript means
    instance creation and must be exactly ``(value_type, parties)``.
    Otherwise (a class, a TypeVar, a typing special form, ...) it is
    generic specialization and is forwarded to ``Generic``.

    Raises:
        TypeError: If instance creation is requested with a number of
            arguments other than two.
    """
    args = params if isinstance(params, tuple) else (params,)
    # Look for concrete BaseType instances rather than for classes: a TypeVar
    # (as in ``MPType[T]``) is not a ``type``, so the previous test sent it to
    # the tuple unpacking below, which then crashed with an unpacking error.
    if not any(isinstance(a, BaseType) for a in args):
        return super().__class_getitem__(params)  # type: ignore[misc]

    if len(args) != 2:
        raise TypeError("MP expects 2 parameters (value_type, parties)")
    value_type, parties = args
    return cls(value_type, parties)
|
|
791
|
+
|
|
792
|
+
def __str__(self) -> str:
    """Render as ``MP[<value_type>, parties=<parties>]``."""
    return "MP[{}, parties={}]".format(self.value_type, self.parties)

def __eq__(self, other: object) -> bool:
    """Equal to another MPType with the same value type and parties."""
    if isinstance(other, MPType):
        return self.value_type == other.value_type and self.parties == other.parties
    return False

def __hash__(self) -> int:
    """Hash over the same fields ``__eq__`` compares."""
    key = ("MPType", self.value_type, self.parties)
    return hash(key)
|
|
802
|
+
|
|
803
|
+
# --- Serde methods ---
# NOTE(review): presumably the tag serde uses to identify this class; changing
# it would likely break previously serialized payloads — confirm.
_serde_kind: ClassVar[str] = "mplang.MPType"

def to_json(self) -> dict[str, Any]:
    """Serialize to a JSON-compatible dict.

    ``parties`` is written as a list, or as ``None`` for a dynamic mask.
    """
    party_ids = self._parties
    return {
        "value_type": serde.to_json(self._value_type),
        "parties": None if party_ids is None else list(party_ids),
    }

@classmethod
def from_json(cls, data: dict[str, Any]) -> MPType[Any]:
    """Inverse of ``to_json``; a party list is turned back into a tuple."""
    value_type = serde.from_json(data["value_type"])
    raw_parties = data["parties"]
    return cls(value_type, None if raw_parties is None else tuple(raw_parties))
|