mplang-nightly 0.1.dev163__py3-none-any.whl → 0.1.dev165__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,626 @@
+# Copyright 2025 Ant Group Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
+
+from mplang.protos.v1alpha1 import value_pb2 as _value_pb2
+
+if TYPE_CHECKING:
+    import numpy as np
+
+__all__ = [
+    "BytesBlob",
+    "TableValue",
+    "TensorValue",
+    "Value",
+    "ValueDecodeError",
+    "ValueError",
+    "ValueProtoBuilder",
+    "ValueProtoReader",
+    "decode_value",
+    "encode_value",
+    "is_value_envelope",
+    "list_value_kinds",
+    "register_value",
+]
+
+
+class ValueError(Exception):  # shadow built-in intentionally local
+    """Base exception for backend Value related errors."""
+
+
+class ValueDecodeError(ValueError):
+    """Raised when decoding a Value envelope or payload fails."""
+
+
+class ValueProtoBuilder:
+    """Builder for creating ValueProto messages with a fluent API.
+
+    Provides a cleaner, more ergonomic interface than directly working with pb2 messages.
+
+    Example:
+        proto = ValueProtoBuilder("my.custom.Type", version=1) \
+            .set_attr("shape", [1, 2, 3]) \
+            .set_attr("dtype", "float32") \
+            .set_payload(b"some bytes") \
+            .build()
+    """
+
+    def __init__(self, kind: str, version: int):
+        """Initialize a new proto builder.
+
+        Args:
+            kind: The globally unique KIND identifier for this Value type
+            version: The WIRE_VERSION for this Value type
+        """
+        self._proto = _value_pb2.ValueProto()
+        self._proto.kind = kind
+        self._proto.value_version = version
+
+    def set_payload(self, payload: bytes) -> ValueProtoBuilder:
+        """Set the payload bytes.
+
+        Args:
+            payload: Raw bytes to store in the proto
+
+        Returns:
+            Self for method chaining
+        """
+        self._proto.payload = payload
+        return self
+
+    def set_attr(self, key: str, value: Any) -> ValueProtoBuilder:
+        """Set a single runtime attribute.
+
+        Args:
+            key: Attribute key
+            value: Attribute value. Supported types: bool, int, float, str, bytes,
+                list[int/float/str]
+
+        Returns:
+            Self for method chaining
+        """
+        self._proto.runtime_attrs[key].CopyFrom(_python_to_attr_proto(value))
+        return self
+
+    def build(self) -> _value_pb2.ValueProto:
+        """Return the built proto.
+
+        Returns:
+            The fully constructed ValueProto
+        """
+        return self._proto
+
+
+class ValueProtoReader:
+    """Reader for extracting data from ValueProto messages.
+
+    Provides a convenient interface for reading proto fields and attributes.
+
+    Example:
+        reader = ValueProtoReader(proto)
+        shape = reader.get_attr("shape")
+        dtype = reader.get_attr("dtype")
+        payload = reader.payload
+    """
+
+    def __init__(self, proto: _value_pb2.ValueProto):
+        """Initialize a reader for an existing proto.
+
+        Args:
+            proto: The ValueProto to read from
+        """
+        self._proto = proto
+
+    @property
+    def kind(self) -> str:
+        """Get the KIND identifier."""
+        return self._proto.kind
+
+    @property
+    def version(self) -> int:
+        """Get the WIRE_VERSION."""
+        return self._proto.value_version
+
+    @property
+    def payload(self) -> bytes:
+        """Get the payload bytes."""
+        return self._proto.payload
+
+    def get_attr(self, key: str, default: Any = ...) -> Any:
+        """Get a single attribute with optional default.
+
+        Args:
+            key: Attribute key to retrieve
+            default: Default value if key is missing. If not provided (default),
+                raises ValueDecodeError when key is missing.
+
+        Returns:
+            The decoded attribute value or default
+
+        Raises:
+            ValueDecodeError: If key is missing and no default provided
+        """
+        if key not in self._proto.runtime_attrs:
+            if default is ...:
+                raise ValueDecodeError(f"Missing required runtime_attr: {key}")
+            return default
+        return _attr_proto_to_python(self._proto.runtime_attrs[key])
+
+
+class Value(ABC):
+    """Abstract base for backend-level transferable values.
+
+    Subclasses MUST define:
+        KIND (ClassVar[str]): globally unique stable identifier
+        WIRE_VERSION (ClassVar[int]): per-kind payload version integer >=1
+
+    Use ValueProtoBuilder and ValueProtoReader for proto serialization.
+    """
+
+    KIND: ClassVar[str]
+    WIRE_VERSION: ClassVar[int] = 1
+
+    def estimated_wire_size(self) -> int | None:  # optional hint
+        return None
+
+    @abstractmethod
+    def to_proto(self) -> _value_pb2.ValueProto:
+        """Return fully-populated ValueProto for this value."""
+
+    @classmethod
+    @abstractmethod
+    def from_proto(cls, proto: _value_pb2.ValueProto) -> Value:
+        """Construct instance from a parsed ValueProto."""
+
+    def to_bool(self) -> bool:
+        """Convert value to bool (for predicates in control flow).
+
+        Default implementation raises NotImplementedError.
+        Subclasses should override to provide appropriate conversion.
+
+        Raises:
+            NotImplementedError: If the value type cannot be converted to bool
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not support conversion to bool"
+        )
+
+
+T = TypeVar("T", bound=Value)
+
+_VALUE_REGISTRY: dict[str, type[Value]] = {}
+
+
+def _python_to_attr_proto(value: Any) -> _value_pb2.ValueAttrProto:
+    attr = _value_pb2.ValueAttrProto()
+    if isinstance(value, bool):
+        attr.type = _value_pb2.ValueAttrProto.BOOL
+        attr.b = value
+    elif isinstance(value, int) and not isinstance(value, bool):
+        attr.type = _value_pb2.ValueAttrProto.INT
+        attr.i = value
+    elif isinstance(value, float):
+        attr.type = _value_pb2.ValueAttrProto.FLOAT
+        attr.f = value
+    elif isinstance(value, str):
+        attr.type = _value_pb2.ValueAttrProto.STRING
+        attr.s = value
+    elif isinstance(value, (bytes, bytearray, memoryview)):
+        attr.type = _value_pb2.ValueAttrProto.BYTES
+        attr.raw_bytes = bytes(value)
+    elif isinstance(value, (list, tuple)):
+        if not value:
+            # Represent empty list explicitly
+            attr.type = _value_pb2.ValueAttrProto.EMPTY
+            return attr
+        if all(isinstance(v, float) for v in value):
+            attr.type = _value_pb2.ValueAttrProto.FLOATS
+            attr.floats.extend(float(v) for v in value)
+        elif all(isinstance(v, int) and not isinstance(v, bool) for v in value):
+            attr.type = _value_pb2.ValueAttrProto.INTS
+            attr.ints.extend(int(v) for v in value)
+        elif all(isinstance(v, str) for v in value):
+            attr.type = _value_pb2.ValueAttrProto.STRINGS
+            attr.strs.extend(value)
+        else:
+            raise TypeError(
+                "Unsupported iterable element type for AttrProto: "
+                f"{type(value[0]).__name__}"
+            )
+    elif value is None:
+        attr.type = _value_pb2.ValueAttrProto.UNDEFINED
+    else:
+        raise TypeError(
+            "Unsupported runtime attr type for Value serialization: "
+            f"{type(value).__name__}"
+        )
+    return attr
+
+
+def _attr_proto_to_python(attr: _value_pb2.ValueAttrProto) -> Any:
+    if attr.type == _value_pb2.ValueAttrProto.FLOAT:
+        return attr.f
+    if attr.type == _value_pb2.ValueAttrProto.INT:
+        return attr.i
+    if attr.type == _value_pb2.ValueAttrProto.STRING:
+        return attr.s
+    if attr.type == _value_pb2.ValueAttrProto.BOOL:
+        return attr.b
+    if attr.type == _value_pb2.ValueAttrProto.BYTES:
+        return attr.raw_bytes
+    if attr.type == _value_pb2.ValueAttrProto.FLOATS:
+        return list(attr.floats)
+    if attr.type == _value_pb2.ValueAttrProto.INTS:
+        return list(attr.ints)
+    if attr.type == _value_pb2.ValueAttrProto.STRINGS:
+        return list(attr.strs)
+    if attr.type == _value_pb2.ValueAttrProto.EMPTY:
+        return []
+    if attr.type == _value_pb2.ValueAttrProto.UNDEFINED:
+        return None
+    raise ValueDecodeError(f"Unsupported AttrProto type {attr.type}")
+
+
+def _looks_like_pyarrow_table(obj: Any) -> bool:
+    if obj is None:
+        return False
+    module = getattr(obj.__class__, "__module__", "")
+    return (
+        module.startswith("pyarrow.")
+        and hasattr(obj, "schema")
+        and hasattr(obj, "column_names")
+        and hasattr(obj, "num_rows")
+    )
+
+
+def _looks_like_pandas_df(obj: Any) -> bool:
+    if obj is None:
+        return False
+    module = getattr(obj.__class__, "__module__", "")
+    return (
+        module.startswith("pandas.")
+        and hasattr(obj, "columns")
+        and hasattr(obj, "dtypes")
+    )
+
+
+def register_value(cls: type[T]) -> type[T]:
+    kind = getattr(cls, "KIND", None)
+    if not kind or not isinstance(kind, str):
+        raise ValueError(f"Value subclass {cls.__name__} missing KIND str")
+    if kind in _VALUE_REGISTRY:
+        raise ValueError(f"Duplicate Value KIND '{kind}'")
+    if getattr(cls, "WIRE_VERSION", None) is None:
+        raise ValueError(f"Value subclass {cls.__name__} missing WIRE_VERSION")
+    _VALUE_REGISTRY[kind] = cls
+    return cls
+
+
+def list_value_kinds() -> list[str]:
+    return sorted(_VALUE_REGISTRY.keys())
+
+
+def encode_value(val: Value) -> bytes:
+    """Encode using protobuf envelope.
+
+    Raises:
+        ValueError if protobuf module not available.
+    """
+    if _value_pb2 is None:  # pragma: no cover
+        raise ValueError("protobuf value_pb2 not generated yet")
+    proto = val.to_proto()
+    if not isinstance(proto, _value_pb2.ValueProto):
+        raise ValueError("Value.to_proto must return ValueProto")
+    if not proto.kind:
+        proto.kind = val.KIND
+    elif proto.kind != val.KIND:
+        raise ValueError(
+            f"ValueProto.kind mismatch: expected '{val.KIND}', got '{proto.kind}'"
+        )
+    if proto.value_version == 0:
+        proto.value_version = val.WIRE_VERSION
+    elif proto.value_version != val.WIRE_VERSION:
+        raise ValueError(
+            f"ValueProto.value_version mismatch: expected {val.WIRE_VERSION}, got {proto.value_version}"
+        )
+    return proto.SerializeToString()  # type: ignore[no-any-return]
+
+
+def is_value_envelope(data: bytes) -> bool:
+    if _value_pb2 is None:
+        return False
+    env = _value_pb2.ValueProto()
+    try:
+        env.ParseFromString(data)
+        return bool(env.kind)
+    except Exception:
+        return False
+
+
+def decode_value(data: bytes) -> Value:
+    if _value_pb2 is None:
+        raise ValueDecodeError("protobuf value_pb2 not available for decode")
+    env = _value_pb2.ValueProto()
+    try:
+        env.ParseFromString(data)
+    except Exception as e:  # pragma: no cover
+        raise ValueDecodeError(f"Failed parsing ValueProto: {e}") from e
+    if not env.kind:
+        raise ValueDecodeError("Envelope missing kind")
+    cls = _VALUE_REGISTRY.get(env.kind)
+    if cls is None:
+        raise ValueDecodeError(f"Unknown Value kind '{env.kind}'")
+    if env.value_version and env.value_version != cls.WIRE_VERSION:
+        raise ValueDecodeError(
+            f"Unsupported {cls.__name__} version {env.value_version}"
+        )
+    if env.value_version == 0:
+        env.value_version = cls.WIRE_VERSION
+    return cls.from_proto(env)
+
+
+@register_value
+class BytesBlob(Value):  # demo subclass
+    KIND = "mplang.demo.BytesBlob"
+    WIRE_VERSION = 1
+
+    def __init__(self, data: bytes):
+        self._data = data
+
+    def to_proto(self) -> _value_pb2.ValueProto:
+        return (
+            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
+            .set_payload(self._data)
+            .build()
+        )
+
+    @classmethod
+    def from_proto(cls, proto: _value_pb2.ValueProto) -> BytesBlob:
+        reader = ValueProtoReader(proto)
+        if reader.version != cls.WIRE_VERSION:
+            raise ValueDecodeError(f"Unsupported BytesBlob version {reader.version}")
+        if proto.runtime_attrs:
+            raise ValueDecodeError("BytesBlob does not expect runtime attributes")
+        return cls(reader.payload)
+
+    def __repr__(self) -> str:  # pragma: no cover
+        return f"BytesBlob(len={len(self._data)})"
+
+
+@register_value
+class TensorValue(Value):  # well-known tensor (ndarray) Value
+    """Numpy ndarray serialization via raw buffer + runtime metadata."""
+
+    KIND = "mplang.ndarray"
+    WIRE_VERSION = 1
+
+    def __init__(self, array):  # type: ignore[no-untyped-def]
+        import numpy as np
+
+        if not isinstance(array, np.ndarray):
+            raise TypeError("TensorValue expects a numpy.ndarray")
+        if not array.flags.c_contiguous:
+            array = np.ascontiguousarray(array)
+        self._arr = array
+
+    def to_proto(self) -> _value_pb2.ValueProto:
+        import numpy as np
+
+        arr = self._arr
+        if not arr.flags.c_contiguous:
+            arr = np.ascontiguousarray(arr)
+        return (
+            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
+            .set_attr("dtype", arr.dtype.str)
+            .set_attr("shape", [int(dim) for dim in arr.shape])
+            .set_payload(arr.tobytes())
+            .build()
+        )
+
+    @classmethod
+    def from_proto(cls, proto: _value_pb2.ValueProto) -> TensorValue:
+        import numpy as np
+
+        reader = ValueProtoReader(proto)
+        if reader.version != cls.WIRE_VERSION:
+            raise ValueDecodeError(f"Unsupported TensorValue version {reader.version}")
+        dtype_val = reader.get_attr("dtype")
+        if not isinstance(dtype_val, str):
+            raise ValueDecodeError("TensorValue runtime attr 'dtype' must be str")
+        shape_val = reader.get_attr("shape")
+        if not isinstance(shape_val, list):
+            raise ValueDecodeError("TensorValue runtime attr 'shape' must be list")
+        shape = tuple(int(dim) for dim in shape_val)
+        try:
+            arr = np.frombuffer(reader.payload, dtype=np.dtype(dtype_val)).reshape(
+                shape
+            )
+        except Exception as e:  # pragma: no cover
+            raise ValueDecodeError(f"Failed reconstruct ndarray: {e}") from e
+        return cls(np.array(arr, copy=True))
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return tuple(self._arr.shape)
+
+    @property
+    def dtype(self) -> np.dtype[Any]:  # pragma: no cover - simple accessor
+        return self._arr.dtype  # type: ignore[no-any-return]
+
+    @property
+    def ndim(self) -> int:  # pragma: no cover - simple accessor
+        return int(self._arr.ndim)
+
+    def to_numpy(
+        self, *, copy: bool = False
+    ) -> np.ndarray[Any, Any]:  # pragma: no cover - simple accessor
+        if copy:
+            import numpy as np
+
+            return np.array(self._arr, copy=True)
+        return self._arr  # type: ignore[no-any-return]
+
+    def __array__(
+        self, dtype: np.dtype[Any] | None = None
+    ) -> np.ndarray[Any, Any]:  # pragma: no cover - numpy bridge
+        import numpy as np
+
+        return np.asarray(self._arr, dtype=dtype)
+
+    def to_bool(self) -> bool:
+        """Convert tensor to bool (for scalar predicates).
+
+        Returns:
+            bool value if tensor is scalar
+
+        Raises:
+            ValueError: If tensor is not a scalar
+        """
+        if self._arr.size != 1:
+            raise ValueError(
+                f"Cannot convert non-scalar tensor (shape={self._arr.shape}) to bool"
+            )
+        return bool(self._arr.item())
+
+    def __repr__(self) -> str:  # pragma: no cover
+        return f"TensorValue(shape={self._arr.shape}, dtype={self._arr.dtype})"
+
+
+@register_value
+class TableValue(Value):  # well-known table (Arrow IPC) Value
+    """Table value backed by PyArrow, serialized via Arrow IPC stream.
+
+    KIND: mplang.dataframe.arrow
+    WIRE_VERSION: increments if wire semantics (not DataFrame contents) change.
+
+    Internal representation: Always pyarrow.Table for consistency and performance.
+    Wire format: Arrow IPC stream.
+
+    Accepts pandas DataFrame or pyarrow.Table as input, but internally converts
+    everything to pyarrow.Table for unified handling.
+    """
+
+    KIND = "mplang.dataframe.arrow"
+    WIRE_VERSION = 1
+
+    def __init__(self, data):  # type: ignore[no-untyped-def]
+        """Initialize TableValue from pandas DataFrame or pyarrow.Table.
+
+        Args:
+            data: pandas.DataFrame, pyarrow.Table, or dict-like object
+        """
+        try:
+            import pyarrow as pa  # type: ignore
+        except ImportError as e:
+            raise ValueError("pyarrow is required for TableValue") from e
+
+        if _looks_like_pyarrow_table(data):
+            self._table = data
+        elif _looks_like_pandas_df(data):
+            try:
+                self._table = pa.Table.from_pandas(data, preserve_index=False)
+            except Exception as e:
+                raise ValueError(
+                    f"Cannot convert pandas DataFrame to Arrow: {e}"
+                ) from e
+        else:
+            # Try to convert dict-like or other structures
+            try:
+                self._table = pa.table(data)
+            except Exception as e:
+                raise TypeError(
+                    f"TableValue requires pandas.DataFrame or pyarrow.Table, got {type(data).__name__}"
+                ) from e
+
+    def to_proto(self) -> _value_pb2.ValueProto:
+        """Serialize to Arrow IPC stream format."""
+        import pyarrow as pa  # type: ignore
+        import pyarrow.ipc as pa_ipc  # type: ignore
+
+        sink = pa.BufferOutputStream()
+        with pa_ipc.new_stream(sink, self._table.schema) as writer:  # type: ignore[arg-type]
+            writer.write_table(self._table)  # type: ignore[arg-type]
+        return (
+            ValueProtoBuilder(self.KIND, self.WIRE_VERSION)
+            .set_payload(sink.getvalue().to_pybytes())
+            .build()
+        )
+
+    @classmethod
+    def from_proto(cls, proto: _value_pb2.ValueProto) -> TableValue:
+        """Deserialize from Arrow IPC stream format."""
+        reader = ValueProtoReader(proto)
+        if reader.version != cls.WIRE_VERSION:
+            raise ValueDecodeError(f"Unsupported TableValue version {reader.version}")
+        if proto.runtime_attrs:
+            raise ValueDecodeError("TableValue does not expect runtime attributes")
+
+        import pyarrow as pa  # type: ignore
+        import pyarrow.ipc as pa_ipc  # type: ignore
+
+        buf = pa.py_buffer(reader.payload)
+        ipc_reader = pa_ipc.open_stream(buf)
+        table = ipc_reader.read_all()
+        return cls(table)
+
+    def to_arrow(self) -> Any:  # pyarrow.Table
+        """Get the underlying pyarrow.Table (primary interface).
+
+        Returns:
+            pyarrow.Table: The table data
+        """
+        return self._table
+
+    def to_pandas(self) -> Any:  # pandas.DataFrame
+        """Convert to pandas DataFrame (compatibility interface).
+
+        Note: This creates a copy and converts from Arrow to pandas format.
+        For better performance, consider using to_arrow() and working with
+        Arrow-native APIs (DuckDB, Ibis, etc.) directly.
+
+        Returns:
+            pandas.DataFrame: Converted dataframe
+        """
+        return self._table.to_pandas()  # type: ignore[attr-defined]
+
+    @property
+    def columns(self) -> list[str]:
+        """Return column names (TableLike protocol compatibility)."""
+        return [str(name) for name in self._table.column_names]
+
+    @property
+    def dtypes(self) -> Any:  # pyarrow.Schema
+        """Return column dtypes as Arrow schema (TableLike protocol compatibility)."""
+        return self._table.schema
+
+    def num_rows(self) -> int:
+        """Get number of rows in the table.
+
+        Returns:
+            Number of rows
+        """
+        return self._table.num_rows  # type: ignore[attr-defined,return-value,no-any-return]
+
+    def __repr__(self) -> str:
+        """String representation of TableValue."""
+        try:
+            rows = self.num_rows()
+            cols = self.columns
+            return f"TableValue(rows={rows}, cols={cols})"
+        except Exception:
+            return "TableValue()"
mplang/ops/tee.py CHANGED
@@ -14,11 +14,7 @@
 
 from __future__ import annotations
 
-from jax.tree_util import PyTreeDef, tree_flatten
-
 from mplang.core.dtype import UINT8
-from mplang.core.mpobject import MPObject
-from mplang.core.pfunc import PFunction
 from mplang.core.tensor import TensorType
 from mplang.ops.base import stateless_mod
 
@@ -32,20 +28,10 @@ def quote_gen(pk: TensorType) -> TensorType:
     return TensorType(UINT8, (-1,))
 
 
-@_TEE_MOD.op_def()
-def attest(
-    quote: MPObject, platform: str
-) -> tuple[PFunction, list[MPObject], PyTreeDef]:
-    """TEE quote verification returning the attested TEE public key."""
-
-    ins_info = [TensorType.from_obj(quote)]
-    outs_info = [TensorType(UINT8, (32,))]  # pk is always 32 bytes for x25519
-    pfunc = PFunction(
-        fn_type="tee.attest",
-        ins_info=ins_info,
-        outs_info=outs_info,
-        platform=platform,
-    )
-    _, treedef = tree_flatten(outs_info[0])
-
-    return pfunc, [quote], treedef
+@_TEE_MOD.simple_op()
+def attest(quote: TensorType) -> TensorType:
+    """TEE quote verification returning the attested TEE public key.
+    API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
+    """
+    _ = quote  # Mark as used for the decorator
+    return TensorType(UINT8, (32,))
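
The attest rewrite replaces the hand-built PFunction (the removed op_def body) with the module's simple_op decorator, so the op now only states its type-level contract and drops the jax tree_flatten plumbing. A small sketch of that contract, reusing only identifiers that appear in the hunks above; it is illustrative, not an executable attestation flow:

    # Shapes/dtypes declared by the mock TEE ops in this file.
    from mplang.core.dtype import UINT8
    from mplang.core.tensor import TensorType

    quote_ty = TensorType(UINT8, (-1,))    # quote_gen output: variable-length byte vector
    tee_pk_ty = TensorType(UINT8, (32,))   # attest output: 32-byte x25519 public key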