mplang-nightly 0.1.dev164__py3-none-any.whl → 0.1.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/expr/evaluator.py +55 -15
- mplang/kernels/__init__.py +28 -0
- mplang/kernels/builtin.py +91 -56
- mplang/kernels/crypto.py +39 -30
- mplang/kernels/mock_tee.py +10 -8
- mplang/kernels/phe.py +238 -39
- mplang/kernels/spu.py +134 -45
- mplang/kernels/sql_duckdb.py +8 -13
- mplang/kernels/stablehlo.py +15 -9
- mplang/kernels/value.py +626 -0
- mplang/protos/v1alpha1/mpir_pb2.pyi +71 -21
- mplang/protos/v1alpha1/value_pb2.py +34 -0
- mplang/protos/v1alpha1/value_pb2.pyi +169 -0
- mplang/runtime/client.py +19 -8
- mplang/runtime/communicator.py +11 -4
- mplang/runtime/driver.py +16 -1
- mplang/runtime/link_comm.py +26 -79
- mplang/runtime/server.py +30 -29
- mplang/runtime/session.py +9 -0
- mplang/runtime/simulation.py +4 -5
- mplang/simp/__init__.py +1 -1
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/RECORD +26 -23
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev164.dist-info → mplang_nightly-0.1.dev166.dist-info}/licenses/LICENSE +0 -0
mplang/kernels/value.py
ADDED
@@ -0,0 +1,626 @@
|
|
1
|
+
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from abc import ABC, abstractmethod
|
18
|
+
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
|
19
|
+
|
20
|
+
from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
|
21
|
+
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
"BytesBlob",
|
27
|
+
"TableValue",
|
28
|
+
"TensorValue",
|
29
|
+
"Value",
|
30
|
+
"ValueDecodeError",
|
31
|
+
"ValueError",
|
32
|
+
"ValueProtoBuilder",
|
33
|
+
"ValueProtoReader",
|
34
|
+
"decode_value",
|
35
|
+
"encode_value",
|
36
|
+
"is_value_envelope",
|
37
|
+
"list_value_kinds",
|
38
|
+
"register_value",
|
39
|
+
]
|
40
|
+
|
41
|
+
|
42
|
+
class ValueError(Exception):  # deliberately shadows the built-in inside this module
    """Root of the exception hierarchy for backend Value failures.

    NOTE: within this module the bare name ``ValueError`` refers to this
    class, not to the built-in exception of the same name.
    """
|
44
|
+
|
45
|
+
|
46
|
+
class ValueDecodeError(ValueError):
    """Signals that a Value envelope or its payload could not be decoded."""
|
48
|
+
|
49
|
+
|
50
|
+
class ValueProtoBuilder:
    """Builder for creating ValueProto messages with a fluent API.

    Provides a cleaner, more ergonomic interface than directly working with
    pb2 messages. Every setter returns the builder itself so calls can be
    chained.

    Example:
        proto = (
            ValueProtoBuilder("my.custom.Type", version=1)
            .set_attr("shape", [1, 2, 3])
            .set_attr("dtype", "float32")
            .set_payload(b"some bytes")
            .build()
        )
    """

    def __init__(self, kind: str, version: int):
        """Initialize a new proto builder.

        Args:
            kind: The globally unique KIND identifier for this Value type
            version: The WIRE_VERSION for this Value type
        """
        self._proto = _value_pb2.ValueProto()
        self._proto.kind = kind
        self._proto.value_version = version

    def set_payload(self, payload: bytes) -> ValueProtoBuilder:
        """Set the payload bytes.

        Args:
            payload: Raw bytes to store in the proto

        Returns:
            Self for method chaining
        """
        self._proto.payload = payload
        return self

    def set_attr(self, key: str, value: Any) -> ValueProtoBuilder:
        """Set a single runtime attribute.

        Args:
            key: Attribute key
            value: Attribute value. Supported types: bool, int, float, str,
                bytes, list[int/float/str]

        Returns:
            Self for method chaining

        Raises:
            TypeError: If ``value`` cannot be encoded. The builder is left
                untouched in that case.
        """
        # Encode before touching the map: indexing ``runtime_attrs[key]`` on
        # a message-valued protobuf map creates the entry, so encoding
        # afterwards would leave a stray default attribute behind whenever
        # the value type turns out to be unsupported.
        encoded = _python_to_attr_proto(value)
        self._proto.runtime_attrs[key].CopyFrom(encoded)
        return self

    def build(self) -> _value_pb2.ValueProto:
        """Return the built proto.

        The builder's own message object is returned (no copy is made).

        Returns:
            The fully constructed ValueProto
        """
        return self._proto
|
107
|
+
|
108
|
+
|
109
|
+
class ValueProtoReader:
    """Read-only convenience wrapper around a parsed ``ValueProto``.

    Exposes the envelope fields as properties and decodes runtime
    attributes into plain Python values on demand.

    Example:
        reader = ValueProtoReader(proto)
        shape = reader.get_attr("shape")
        dtype = reader.get_attr("dtype")
        payload = reader.payload
    """

    def __init__(self, proto: _value_pb2.ValueProto):
        """Wrap an existing message; the message is stored, not copied.

        Args:
            proto: The ValueProto to read from
        """
        self._proto = proto

    @property
    def kind(self) -> str:
        """The KIND identifier stored in the envelope."""
        return self._proto.kind

    @property
    def version(self) -> int:
        """The wire version (``value_version`` field) stored in the envelope."""
        return self._proto.value_version

    @property
    def payload(self) -> bytes:
        """The raw payload bytes stored in the envelope."""
        return self._proto.payload

    def get_attr(self, key: str, default: Any = ...) -> Any:
        """Look up one runtime attribute and decode it.

        Args:
            key: Attribute key to retrieve
            default: Value returned when ``key`` is absent. When left at its
                sentinel (``...``) a missing key is treated as an error.

        Returns:
            The decoded attribute value, or ``default`` for a missing key

        Raises:
            ValueDecodeError: If key is missing and no default provided
        """
        attrs = self._proto.runtime_attrs
        if key in attrs:
            return _attr_proto_to_python(attrs[key])
        if default is Ellipsis:
            raise ValueDecodeError(f"Missing required runtime_attr: {key}")
        return default
|
163
|
+
|
164
|
+
|
165
|
+
class Value(ABC):
    """Abstract base class of every value the backend can transfer.

    Concrete subclasses MUST provide:
        KIND (ClassVar[str]): globally unique stable identifier
        WIRE_VERSION (ClassVar[int]): per-kind payload version integer >=1

    ValueProtoBuilder and ValueProtoReader are the intended helpers for
    implementing ``to_proto`` / ``from_proto``.
    """

    KIND: ClassVar[str]
    WIRE_VERSION: ClassVar[int] = 1

    def estimated_wire_size(self) -> int | None:
        """Optional size hint; the base implementation offers none."""
        return None

    @abstractmethod
    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize this value into a fully-populated ValueProto."""

    @classmethod
    @abstractmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> Value:
        """Build an instance from an already-parsed ValueProto."""

    def to_bool(self) -> bool:
        """Interpret the value as a bool (used for control-flow predicates).

        The base implementation always fails; subclasses that have a
        meaningful truth value override it.

        Raises:
            NotImplementedError: If the value type cannot be converted to bool
        """
        type_name = type(self).__name__
        raise NotImplementedError(f"{type_name} does not support conversion to bool")
|
202
|
+
|
203
|
+
|
204
|
+
T = TypeVar("T", bound=Value)
|
205
|
+
|
206
|
+
_VALUE_REGISTRY: dict[str, type[Value]] = {}
|
207
|
+
|
208
|
+
|
209
|
+
def _python_to_attr_proto(value: Any) -> _value_pb2.ValueAttrProto:
    """Encode a Python runtime attribute as a ``ValueAttrProto``.

    Supported inputs: ``None`` (UNDEFINED), bool, int, float, str, bytes-like
    objects (bytes, bytearray, memoryview) and lists/tuples whose elements
    are all float, all int (bools excluded) or all str. An empty list/tuple
    is encoded with the dedicated EMPTY tag.

    Args:
        value: The Python value to encode.

    Returns:
        A new ValueAttrProto with ``type`` and the matching field populated.

    Raises:
        TypeError: If ``value`` has an unsupported type, or a list/tuple
            holds unsupported or mixed element types.
    """
    tags = _value_pb2.ValueAttrProto
    attr = tags()
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        attr.type = tags.BOOL
        attr.b = value
    elif isinstance(value, int):
        attr.type = tags.INT
        attr.i = value
    elif isinstance(value, float):
        attr.type = tags.FLOAT
        attr.f = value
    elif isinstance(value, str):
        attr.type = tags.STRING
        attr.s = value
    elif isinstance(value, (bytes, bytearray, memoryview)):
        attr.type = tags.BYTES
        attr.raw_bytes = bytes(value)
    elif isinstance(value, (list, tuple)):
        if not value:
            # Represent empty list explicitly
            attr.type = tags.EMPTY
        elif all(isinstance(v, float) for v in value):
            attr.type = tags.FLOATS
            attr.floats.extend(float(v) for v in value)
        elif all(isinstance(v, int) and not isinstance(v, bool) for v in value):
            attr.type = tags.INTS
            attr.ints.extend(int(v) for v in value)
        elif all(isinstance(v, str) for v in value):
            attr.type = tags.STRINGS
            attr.strs.extend(value)
        else:
            # Name every distinct element type: reporting only the first
            # element is misleading for mixed input such as [1, 2.5], where
            # the first type (int) is itself supported.
            type_names = ", ".join(sorted({type(v).__name__ for v in value}))
            raise TypeError(
                f"Unsupported iterable element type for AttrProto: {type_names}"
            )
    elif value is None:
        attr.type = tags.UNDEFINED
    else:
        raise TypeError(
            "Unsupported runtime attr type for Value serialization: "
            f"{type(value).__name__}"
        )
    return attr
|
253
|
+
|
254
|
+
|
255
|
+
def _attr_proto_to_python(attr: _value_pb2.ValueAttrProto) -> Any:
    """Decode a ``ValueAttrProto`` back into the Python value it carries.

    Scalar tags yield the matching scalar field, repeated tags yield a fresh
    list, EMPTY yields ``[]`` and UNDEFINED yields ``None``.

    Raises:
        ValueDecodeError: If the type tag is not one of the handled tags.
    """
    tags = _value_pb2.ValueAttrProto
    tag = attr.type

    # Type tag -> name of the scalar field holding the value.
    scalar_field = {
        tags.FLOAT: "f",
        tags.INT: "i",
        tags.STRING: "s",
        tags.BOOL: "b",
        tags.BYTES: "raw_bytes",
    }.get(tag)
    if scalar_field is not None:
        return getattr(attr, scalar_field)

    # Type tag -> name of the repeated field holding the values.
    repeated_field = {
        tags.FLOATS: "floats",
        tags.INTS: "ints",
        tags.STRINGS: "strs",
    }.get(tag)
    if repeated_field is not None:
        return list(getattr(attr, repeated_field))

    if tag == tags.EMPTY:
        return []
    if tag == tags.UNDEFINED:
        return None
    raise ValueDecodeError(f"Unsupported AttrProto type {attr.type}")
|
277
|
+
|
278
|
+
|
279
|
+
def _looks_like_pyarrow_table(obj: Any) -> bool:
|
280
|
+
if obj is None:
|
281
|
+
return False
|
282
|
+
module = getattr(obj.__class__, "__module__", "")
|
283
|
+
return (
|
284
|
+
module.startswith("pyarrow.")
|
285
|
+
and hasattr(obj, "schema")
|
286
|
+
and hasattr(obj, "column_names")
|
287
|
+
and hasattr(obj, "num_rows")
|
288
|
+
)
|
289
|
+
|
290
|
+
|
291
|
+
def _looks_like_pandas_df(obj: Any) -> bool:
|
292
|
+
if obj is None:
|
293
|
+
return False
|
294
|
+
module = getattr(obj.__class__, "__module__", "")
|
295
|
+
return (
|
296
|
+
module.startswith("pandas.")
|
297
|
+
and hasattr(obj, "columns")
|
298
|
+
and hasattr(obj, "dtypes")
|
299
|
+
)
|
300
|
+
|
301
|
+
|
302
|
+
def register_value(cls: type[T]) -> type[T]:
    """Class decorator that records a Value subclass under its KIND.

    Args:
        cls: The class to register; it must expose a non-empty string
            ``KIND`` and a non-None ``WIRE_VERSION``.

    Returns:
        ``cls`` unchanged, so the function can be used as a decorator.

    Raises:
        ValueError: (this module's ValueError) if KIND is missing or not a
            string, if the KIND is already registered, or if WIRE_VERSION
            is missing.
    """
    kind = getattr(cls, "KIND", None)
    if not (kind and isinstance(kind, str)):
        raise ValueError(f"Value subclass {cls.__name__} missing KIND str")
    if kind in _VALUE_REGISTRY:
        raise ValueError(f"Duplicate Value KIND '{kind}'")
    wire_version = getattr(cls, "WIRE_VERSION", None)
    if wire_version is None:
        raise ValueError(f"Value subclass {cls.__name__} missing WIRE_VERSION")
    _VALUE_REGISTRY[kind] = cls
    return cls
|
312
|
+
|
313
|
+
|
314
|
+
def list_value_kinds() -> list[str]:
    """Return every registered KIND string in sorted order."""
    return sorted(_VALUE_REGISTRY)
|
316
|
+
|
317
|
+
|
318
|
+
def encode_value(val: Value) -> bytes:
    """Serialize a Value into its protobuf envelope bytes.

    The proto returned by ``val.to_proto()`` is checked against the value's
    class: an unset kind / version (empty string / 0) is filled in from
    ``KIND`` / ``WIRE_VERSION``, while a conflicting one is rejected.

    Raises:
        ValueError: (this module's ValueError) if the protobuf module is not
            available, if ``to_proto`` does not return a ValueProto, or if
            the proto's kind or version disagrees with the value's class.
    """
    if _value_pb2 is None:  # pragma: no cover
        raise ValueError("protobuf value_pb2 not generated yet")

    envelope = val.to_proto()
    if not isinstance(envelope, _value_pb2.ValueProto):
        raise ValueError("Value.to_proto must return ValueProto")

    expected_kind = val.KIND
    if envelope.kind and envelope.kind != expected_kind:
        raise ValueError(
            f"ValueProto.kind mismatch: expected '{expected_kind}', got '{envelope.kind}'"
        )
    if not envelope.kind:
        envelope.kind = expected_kind

    expected_version = val.WIRE_VERSION
    if envelope.value_version == 0:
        envelope.value_version = expected_version
    elif envelope.value_version != expected_version:
        raise ValueError(
            f"ValueProto.value_version mismatch: expected {expected_version}, got {envelope.value_version}"
        )

    return envelope.SerializeToString()  # type: ignore[no-any-return]
|
342
|
+
|
343
|
+
|
344
|
+
def is_value_envelope(data: bytes) -> bool:
    """Best-effort check whether ``data`` parses as a ValueProto with a kind.

    Any parsing failure is reported as ``False`` rather than raised.
    """
    if _value_pb2 is None:
        return False
    envelope = _value_pb2.ValueProto()
    try:
        envelope.ParseFromString(data)
        has_kind = bool(envelope.kind)
    except Exception:
        # Deliberately broad: this is a probe, not a decoder.
        has_kind = False
    return has_kind
|
353
|
+
|
354
|
+
|
355
|
+
def decode_value(data: bytes) -> Value:
    """Parse envelope bytes and rebuild the registered Value they describe.

    The envelope's kind selects the class from the registry. A version of 0
    is treated as unset and replaced with the class's WIRE_VERSION before
    the proto is handed to ``from_proto``.

    Raises:
        ValueDecodeError: If the protobuf module is unavailable, the bytes
            do not parse, the kind is missing or unknown, or the version is
            set but differs from the class's WIRE_VERSION.
    """
    if _value_pb2 is None:
        raise ValueDecodeError("protobuf value_pb2 not available for decode")

    envelope = _value_pb2.ValueProto()
    try:
        envelope.ParseFromString(data)
    except Exception as e:  # pragma: no cover
        raise ValueDecodeError(f"Failed parsing ValueProto: {e}") from e

    kind = envelope.kind
    if not kind:
        raise ValueDecodeError("Envelope missing kind")

    value_cls = _VALUE_REGISTRY.get(kind)
    if value_cls is None:
        raise ValueDecodeError(f"Unknown Value kind '{kind}'")

    if envelope.value_version == 0:
        envelope.value_version = value_cls.WIRE_VERSION
    elif envelope.value_version != value_cls.WIRE_VERSION:
        raise ValueDecodeError(
            f"Unsupported {value_cls.__name__} version {envelope.value_version}"
        )
    return value_cls.from_proto(envelope)
|
375
|
+
|
376
|
+
|
377
|
+
@register_value
class BytesBlob(Value):  # demo subclass
    """Minimal Value that carries an opaque byte string as its payload."""

    KIND = "mplang.demo.BytesBlob"
    WIRE_VERSION = 1

    def __init__(self, data: bytes):
        """Store ``data`` as-is; it becomes the proto payload on encode."""
        self._data = data

    def to_proto(self) -> _value_pb2.ValueProto:
        """Wrap the stored bytes in a ValueProto without runtime attributes."""
        builder = ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
        builder.set_payload(self._data)
        return builder.build()

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> BytesBlob:
        """Rebuild a BytesBlob from a proto.

        Raises:
            ValueDecodeError: If the version differs from WIRE_VERSION or the
                proto carries any runtime attributes.
        """
        reader = ValueProtoReader(proto)
        found_version = reader.version
        if found_version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported BytesBlob version {found_version}")
        if proto.runtime_attrs:
            raise ValueDecodeError("BytesBlob does not expect runtime attributes")
        return cls(reader.payload)

    def __repr__(self) -> str:  # pragma: no cover
        size = len(self._data)
        return f"BytesBlob(len={size})"
|
403
|
+
|
404
|
+
|
405
|
+
@register_value
class TensorValue(Value):  # well-known tensor (ndarray) Value
    """Numpy ndarray serialization via raw buffer + runtime metadata.

    Wire layout: the payload holds the array's C-contiguous raw bytes; the
    runtime attributes ``dtype`` (``ndarray.dtype.str``) and ``shape`` (list
    of ints) describe how to reinterpret them on decode.
    """

    KIND = "mplang.ndarray"
    WIRE_VERSION = 1

    def __init__(self, array):  # type: ignore[no-untyped-def]
        """Wrap a numpy array, making it C-contiguous if it is not already.

        Raises:
            TypeError: If ``array`` is not a ``numpy.ndarray``.
        """
        import numpy as np

        if not isinstance(array, np.ndarray):
            raise TypeError("TensorValue expects a numpy.ndarray")
        if not array.flags.c_contiguous:
            array = np.ascontiguousarray(array)
        self._arr = array

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize the array as raw bytes plus ``dtype``/``shape`` attrs.

        Raises:
            TypeError: If the array's dtype contains Python objects.
        """
        import numpy as np

        arr = self._arr
        if arr.dtype.hasobject:
            # tobytes() on an object array emits in-process pointers, which
            # can never be reconstructed by the receiver; fail at the sender.
            raise TypeError("TensorValue cannot serialize object-dtype arrays")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        return (
            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
            .set_attr("dtype", arr.dtype.str)
            .set_attr("shape", [int(dim) for dim in arr.shape])
            .set_payload(arr.tobytes())
            .build()
        )

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> TensorValue:
        """Rebuild a TensorValue from a proto produced by ``to_proto``.

        Raises:
            ValueDecodeError: If the version is unsupported, the ``dtype`` or
                ``shape`` attribute is missing or of the wrong type, or the
                payload cannot be reinterpreted with that dtype and shape.
        """
        import numpy as np

        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported TensorValue version {reader.version}")
        dtype_val = reader.get_attr("dtype")
        if not isinstance(dtype_val, str):
            raise ValueDecodeError("TensorValue runtime attr 'dtype' must be str")
        shape_val = reader.get_attr("shape")
        if not isinstance(shape_val, list):
            raise ValueDecodeError("TensorValue runtime attr 'shape' must be list")
        try:
            # The shape conversion lives inside the try so that a malformed
            # attribute (e.g. a list of non-numeric strings) is reported as
            # ValueDecodeError instead of leaking a raw built-in exception.
            shape = tuple(int(dim) for dim in shape_val)
            arr = np.frombuffer(reader.payload, dtype=np.dtype(dtype_val)).reshape(
                shape
            )
        except Exception as e:  # pragma: no cover
            raise ValueDecodeError(f"Failed reconstruct ndarray: {e}") from e
        # frombuffer returns a view over the payload bytes; copy so the value
        # owns its own buffer.
        return cls(np.array(arr, copy=True))

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array as a plain tuple."""
        return tuple(self._arr.shape)

    @property
    def dtype(self) -> np.dtype[Any]:  # pragma: no cover - simple accessor
        """dtype of the wrapped array."""
        return self._arr.dtype  # type: ignore[no-any-return]

    @property
    def ndim(self) -> int:  # pragma: no cover - simple accessor
        """Number of dimensions of the wrapped array."""
        return int(self._arr.ndim)

    def to_numpy(
        self, *, copy: bool = False
    ) -> np.ndarray[Any, Any]:  # pragma: no cover - simple accessor
        """Return the wrapped array, or an independent copy when ``copy``."""
        if copy:
            import numpy as np

            return np.array(self._arr, copy=True)
        return self._arr  # type: ignore[no-any-return]

    def __array__(
        self, dtype: np.dtype[Any] | None = None, copy: bool | None = None
    ) -> np.ndarray[Any, Any]:  # pragma: no cover - numpy bridge
        """NumPy array protocol hook.

        Args:
            dtype: Optional target dtype.
            copy: NumPy 2 style copy request: ``True`` always copies,
                ``False`` never copies, ``None`` copies only when needed.

        Raises:
            builtins.ValueError: If ``copy`` is False but the requested dtype
                differs from the stored one, which would require a copy.
        """
        import builtins

        import numpy as np

        if copy:
            return np.array(self._arr, dtype=dtype, copy=True)
        if copy is False and dtype is not None and np.dtype(dtype) != self._arr.dtype:
            # NumPy expects the built-in ValueError here, not this module's
            # shadowing class.
            raise builtins.ValueError(
                "Unable to avoid copy while creating an array as requested."
            )
        return np.asarray(self._arr, dtype=dtype)

    def to_bool(self) -> bool:
        """Convert tensor to bool (for scalar predicates).

        Returns:
            bool value if the tensor holds exactly one element

        Raises:
            ValueError: (this module's ValueError) if the tensor does not
                hold exactly one element
        """
        if self._arr.size != 1:
            raise ValueError(
                f"Cannot convert non-scalar tensor (shape={self._arr.shape}) to bool"
            )
        return bool(self._arr.item())

    def __repr__(self) -> str:  # pragma: no cover
        return f"TensorValue(shape={self._arr.shape}, dtype={self._arr.dtype})"
|
502
|
+
|
503
|
+
|
504
|
+
@register_value
class TableValue(Value):  # well-known table (Arrow IPC) Value
    """Table value backed by PyArrow, serialized via Arrow IPC stream.

    KIND: mplang.dataframe.arrow
    WIRE_VERSION: increments if wire semantics (not DataFrame contents) change.

    Internal representation: Always pyarrow.Table for consistency and performance.
    Wire format: Arrow IPC stream.

    Accepts pandas DataFrame or pyarrow.Table as input, but internally converts
    everything to pyarrow.Table for unified handling.
    """

    KIND = "mplang.dataframe.arrow"
    WIRE_VERSION = 1

    def __init__(self, data):  # type: ignore[no-untyped-def]
        """Initialize TableValue from pandas DataFrame or pyarrow.Table.

        Args:
            data: pandas.DataFrame, pyarrow.Table, or dict-like object

        Raises:
            ValueError: (this module's ValueError) if pyarrow is not
                installed or a pandas DataFrame cannot be converted.
            TypeError: If ``data`` is neither table-like nor convertible via
                ``pyarrow.table``.
        """
        try:
            import pyarrow as pa  # type: ignore
        except ImportError as e:
            raise ValueError("pyarrow is required for TableValue") from e

        if _looks_like_pyarrow_table(data):
            self._table = data
        elif _looks_like_pandas_df(data):
            try:
                self._table = pa.Table.from_pandas(data, preserve_index=False)
            except Exception as e:
                raise ValueError(
                    f"Cannot convert pandas DataFrame to Arrow: {e}"
                ) from e
        else:
            # Try to convert dict-like or other structures
            try:
                self._table = pa.table(data)
            except Exception as e:
                raise TypeError(
                    f"TableValue requires pandas.DataFrame or pyarrow.Table, got {type(data).__name__}"
                ) from e

    def to_proto(self) -> _value_pb2.ValueProto:
        """Serialize to Arrow IPC stream format."""
        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        sink = pa.BufferOutputStream()
        with pa_ipc.new_stream(sink, self._table.schema) as writer:  # type: ignore[arg-type]
            writer.write_table(self._table)  # type: ignore[arg-type]
        return (
            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
            .set_payload(sink.getvalue().to_pybytes())
            .build()
        )

    @classmethod
    def from_proto(cls, proto: _value_pb2.ValueProto) -> TableValue:
        """Deserialize from Arrow IPC stream format.

        Raises:
            ValueDecodeError: If the version is unsupported, the proto
                carries runtime attributes, or the payload is not a readable
                Arrow IPC stream.
        """
        reader = ValueProtoReader(proto)
        if reader.version != cls.WIRE_VERSION:
            raise ValueDecodeError(f"Unsupported TableValue version {reader.version}")
        if proto.runtime_attrs:
            raise ValueDecodeError("TableValue does not expect runtime attributes")

        import pyarrow as pa  # type: ignore
        import pyarrow.ipc as pa_ipc  # type: ignore

        try:
            buf = pa.py_buffer(reader.payload)
            ipc_reader = pa_ipc.open_stream(buf)
            table = ipc_reader.read_all()
        except Exception as e:
            # Report payload corruption through the module's decode error,
            # matching how TensorValue.from_proto reports a bad payload.
            raise ValueDecodeError(f"Failed to decode Arrow IPC stream: {e}") from e
        return cls(table)

    def to_arrow(self) -> Any:  # pyarrow.Table
        """Get the underlying pyarrow.Table (primary interface).

        Returns:
            pyarrow.Table: The table data
        """
        return self._table

    def to_pandas(self) -> Any:  # pandas.DataFrame
        """Convert to pandas DataFrame (compatibility interface).

        Note: This converts from Arrow to pandas format on every call.
        For better performance, consider using to_arrow() and working with
        Arrow-native APIs (DuckDB, Ibis, etc.) directly.

        Returns:
            pandas.DataFrame: Converted dataframe
        """
        return self._table.to_pandas()  # type: ignore[attr-defined]

    @property
    def columns(self) -> list[str]:
        """Return column names (TableLike protocol compatibility)."""
        return [str(name) for name in self._table.column_names]

    @property
    def dtypes(self) -> Any:  # pyarrow.Schema
        """Return column dtypes as Arrow schema (TableLike protocol compatibility)."""
        return self._table.schema

    def num_rows(self) -> int:
        """Get number of rows in the table.

        Unlike ``columns`` and ``dtypes`` this is a method, not a property.

        Returns:
            Number of rows
        """
        return self._table.num_rows  # type: ignore[attr-defined,return-value,no-any-return]

    def __repr__(self) -> str:
        """String representation of TableValue."""
        try:
            rows = self.num_rows()
            cols = self.columns
            return f"TableValue(rows={rows}, cols={cols})"
        except Exception:
            # repr must never raise; fall back to a bare marker.
            return "TableValue()"
|