mplang-nightly 0.1.dev158__py3-none-any.whl → 0.1.dev268__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. mplang/__init__.py +21 -45
  2. mplang/py.typed +13 -0
  3. mplang/v1/__init__.py +157 -0
  4. mplang/v1/_device.py +602 -0
  5. mplang/{analysis → v1/analysis}/__init__.py +1 -1
  6. mplang/{analysis → v1/analysis}/diagram.py +5 -7
  7. mplang/v1/core/__init__.py +157 -0
  8. mplang/{core → v1/core}/cluster.py +30 -14
  9. mplang/{core → v1/core}/comm.py +5 -1
  10. mplang/{core → v1/core}/context_mgr.py +1 -1
  11. mplang/{core/dtype.py → v1/core/dtypes.py} +44 -2
  12. mplang/{core → v1/core}/expr/__init__.py +7 -7
  13. mplang/{core → v1/core}/expr/ast.py +13 -14
  14. mplang/{core → v1/core}/expr/evaluator.py +65 -24
  15. mplang/{core → v1/core}/expr/printer.py +24 -18
  16. mplang/{core → v1/core}/expr/transformer.py +3 -3
  17. mplang/{core → v1/core}/expr/utils.py +2 -2
  18. mplang/{core → v1/core}/expr/visitor.py +1 -1
  19. mplang/{core → v1/core}/expr/walk.py +1 -1
  20. mplang/{core → v1/core}/interp.py +6 -6
  21. mplang/{core → v1/core}/mpir.py +23 -16
  22. mplang/{core → v1/core}/mpobject.py +6 -6
  23. mplang/{core → v1/core}/mptype.py +13 -10
  24. mplang/{core → v1/core}/pfunc.py +4 -4
  25. mplang/{core → v1/core}/primitive.py +106 -201
  26. mplang/{core → v1/core}/table.py +36 -8
  27. mplang/{core → v1/core}/tensor.py +1 -1
  28. mplang/{core → v1/core}/tracer.py +9 -9
  29. mplang/{api.py → v1/host.py} +38 -6
  30. mplang/v1/kernels/__init__.py +41 -0
  31. mplang/{kernels → v1/kernels}/base.py +1 -1
  32. mplang/v1/kernels/basic.py +240 -0
  33. mplang/{kernels → v1/kernels}/context.py +42 -27
  34. mplang/{kernels → v1/kernels}/crypto.py +44 -37
  35. mplang/v1/kernels/fhe.py +858 -0
  36. mplang/{kernels → v1/kernels}/mock_tee.py +12 -13
  37. mplang/{kernels → v1/kernels}/phe.py +263 -57
  38. mplang/{kernels → v1/kernels}/spu.py +137 -48
  39. mplang/{kernels → v1/kernels}/sql_duckdb.py +12 -15
  40. mplang/{kernels → v1/kernels}/stablehlo.py +30 -23
  41. mplang/v1/kernels/value.py +626 -0
  42. mplang/{ops → v1/ops}/__init__.py +5 -16
  43. mplang/{ops → v1/ops}/base.py +2 -5
  44. mplang/{ops/builtin.py → v1/ops/basic.py} +34 -26
  45. mplang/v1/ops/crypto.py +262 -0
  46. mplang/v1/ops/fhe.py +272 -0
  47. mplang/{ops → v1/ops}/jax_cc.py +33 -68
  48. mplang/v1/ops/nnx_cc.py +168 -0
  49. mplang/{ops → v1/ops}/phe.py +16 -4
  50. mplang/{ops → v1/ops}/spu.py +3 -5
  51. mplang/v1/ops/sql_cc.py +303 -0
  52. mplang/{ops → v1/ops}/tee.py +9 -24
  53. mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.pyi +71 -21
  54. mplang/v1/protos/v1alpha1/value_pb2.py +34 -0
  55. mplang/v1/protos/v1alpha1/value_pb2.pyi +169 -0
  56. mplang/{runtime → v1/runtime}/__init__.py +2 -2
  57. mplang/v1/runtime/channel.py +230 -0
  58. mplang/{runtime → v1/runtime}/cli.py +35 -20
  59. mplang/{runtime → v1/runtime}/client.py +19 -8
  60. mplang/{runtime → v1/runtime}/communicator.py +59 -15
  61. mplang/{runtime → v1/runtime}/data_providers.py +80 -19
  62. mplang/{runtime → v1/runtime}/driver.py +30 -12
  63. mplang/v1/runtime/link_comm.py +196 -0
  64. mplang/{runtime → v1/runtime}/server.py +58 -42
  65. mplang/{runtime → v1/runtime}/session.py +57 -71
  66. mplang/{runtime → v1/runtime}/simulation.py +55 -28
  67. mplang/v1/simp/api.py +353 -0
  68. mplang/{simp → v1/simp}/mpi.py +8 -9
  69. mplang/{simp/__init__.py → v1/simp/party.py} +19 -145
  70. mplang/{simp → v1/simp}/random.py +21 -22
  71. mplang/v1/simp/smpc.py +238 -0
  72. mplang/v1/utils/table_utils.py +185 -0
  73. mplang/v2/__init__.py +424 -0
  74. mplang/v2/backends/__init__.py +57 -0
  75. mplang/v2/backends/bfv_impl.py +705 -0
  76. mplang/v2/backends/channel.py +217 -0
  77. mplang/v2/backends/crypto_impl.py +723 -0
  78. mplang/v2/backends/field_impl.py +454 -0
  79. mplang/v2/backends/func_impl.py +107 -0
  80. mplang/v2/backends/phe_impl.py +148 -0
  81. mplang/v2/backends/simp_design.md +136 -0
  82. mplang/v2/backends/simp_driver/__init__.py +41 -0
  83. mplang/v2/backends/simp_driver/http.py +168 -0
  84. mplang/v2/backends/simp_driver/mem.py +280 -0
  85. mplang/v2/backends/simp_driver/ops.py +135 -0
  86. mplang/v2/backends/simp_driver/state.py +60 -0
  87. mplang/v2/backends/simp_driver/values.py +52 -0
  88. mplang/v2/backends/simp_worker/__init__.py +29 -0
  89. mplang/v2/backends/simp_worker/http.py +354 -0
  90. mplang/v2/backends/simp_worker/mem.py +102 -0
  91. mplang/v2/backends/simp_worker/ops.py +167 -0
  92. mplang/v2/backends/simp_worker/state.py +49 -0
  93. mplang/v2/backends/spu_impl.py +275 -0
  94. mplang/v2/backends/spu_state.py +187 -0
  95. mplang/v2/backends/store_impl.py +62 -0
  96. mplang/v2/backends/table_impl.py +838 -0
  97. mplang/v2/backends/tee_impl.py +215 -0
  98. mplang/v2/backends/tensor_impl.py +519 -0
  99. mplang/v2/cli.py +603 -0
  100. mplang/v2/cli_guide.md +122 -0
  101. mplang/v2/dialects/__init__.py +36 -0
  102. mplang/v2/dialects/bfv.py +665 -0
  103. mplang/v2/dialects/crypto.py +689 -0
  104. mplang/v2/dialects/dtypes.py +378 -0
  105. mplang/v2/dialects/field.py +210 -0
  106. mplang/v2/dialects/func.py +135 -0
  107. mplang/v2/dialects/phe.py +723 -0
  108. mplang/v2/dialects/simp.py +944 -0
  109. mplang/v2/dialects/spu.py +349 -0
  110. mplang/v2/dialects/store.py +63 -0
  111. mplang/v2/dialects/table.py +407 -0
  112. mplang/v2/dialects/tee.py +346 -0
  113. mplang/v2/dialects/tensor.py +1175 -0
  114. mplang/v2/edsl/README.md +279 -0
  115. mplang/v2/edsl/__init__.py +99 -0
  116. mplang/v2/edsl/context.py +311 -0
  117. mplang/v2/edsl/graph.py +463 -0
  118. mplang/v2/edsl/jit.py +62 -0
  119. mplang/v2/edsl/object.py +53 -0
  120. mplang/v2/edsl/primitive.py +284 -0
  121. mplang/v2/edsl/printer.py +119 -0
  122. mplang/v2/edsl/registry.py +207 -0
  123. mplang/v2/edsl/serde.py +375 -0
  124. mplang/v2/edsl/tracer.py +614 -0
  125. mplang/v2/edsl/typing.py +816 -0
  126. mplang/v2/kernels/Makefile +30 -0
  127. mplang/v2/kernels/__init__.py +23 -0
  128. mplang/v2/kernels/gf128.cpp +148 -0
  129. mplang/v2/kernels/ldpc.cpp +82 -0
  130. mplang/v2/kernels/okvs.cpp +283 -0
  131. mplang/v2/kernels/okvs_opt.cpp +291 -0
  132. mplang/v2/kernels/py_kernels.py +398 -0
  133. mplang/v2/libs/collective.py +330 -0
  134. mplang/v2/libs/device/__init__.py +51 -0
  135. mplang/v2/libs/device/api.py +813 -0
  136. mplang/v2/libs/device/cluster.py +352 -0
  137. mplang/v2/libs/ml/__init__.py +23 -0
  138. mplang/v2/libs/ml/sgb.py +1861 -0
  139. mplang/v2/libs/mpc/__init__.py +41 -0
  140. mplang/v2/libs/mpc/_utils.py +99 -0
  141. mplang/v2/libs/mpc/analytics/__init__.py +35 -0
  142. mplang/v2/libs/mpc/analytics/aggregation.py +372 -0
  143. mplang/v2/libs/mpc/analytics/groupby.md +99 -0
  144. mplang/v2/libs/mpc/analytics/groupby.py +331 -0
  145. mplang/v2/libs/mpc/analytics/permutation.py +386 -0
  146. mplang/v2/libs/mpc/common/constants.py +39 -0
  147. mplang/v2/libs/mpc/ot/__init__.py +32 -0
  148. mplang/v2/libs/mpc/ot/base.py +222 -0
  149. mplang/v2/libs/mpc/ot/extension.py +477 -0
  150. mplang/v2/libs/mpc/ot/silent.py +217 -0
  151. mplang/v2/libs/mpc/psi/__init__.py +40 -0
  152. mplang/v2/libs/mpc/psi/cuckoo.py +228 -0
  153. mplang/v2/libs/mpc/psi/okvs.py +49 -0
  154. mplang/v2/libs/mpc/psi/okvs_gct.py +79 -0
  155. mplang/v2/libs/mpc/psi/oprf.py +310 -0
  156. mplang/v2/libs/mpc/psi/rr22.py +344 -0
  157. mplang/v2/libs/mpc/psi/unbalanced.py +200 -0
  158. mplang/v2/libs/mpc/vole/__init__.py +31 -0
  159. mplang/v2/libs/mpc/vole/gilboa.py +327 -0
  160. mplang/v2/libs/mpc/vole/ldpc.py +383 -0
  161. mplang/v2/libs/mpc/vole/silver.py +336 -0
  162. mplang/v2/runtime/__init__.py +15 -0
  163. mplang/v2/runtime/dialect_state.py +41 -0
  164. mplang/v2/runtime/interpreter.py +871 -0
  165. mplang/v2/runtime/object_store.py +194 -0
  166. mplang/v2/runtime/value.py +141 -0
  167. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/METADATA +24 -17
  168. mplang_nightly-0.1.dev268.dist-info/RECORD +180 -0
  169. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/WHEEL +1 -1
  170. mplang/core/__init__.py +0 -92
  171. mplang/device.py +0 -340
  172. mplang/kernels/builtin.py +0 -207
  173. mplang/ops/crypto.py +0 -109
  174. mplang/ops/ibis_cc.py +0 -139
  175. mplang/ops/sql.py +0 -61
  176. mplang/protos/v1alpha1/mpir_pb2_grpc.py +0 -3
  177. mplang/runtime/link_comm.py +0 -131
  178. mplang/simp/smpc.py +0 -201
  179. mplang/utils/table_utils.py +0 -73
  180. mplang_nightly-0.1.dev158.dist-info/RECORD +0 -77
  181. /mplang/{core → v1/core}/mask.py +0 -0
  182. /mplang/{protos → v1/protos}/v1alpha1/mpir_pb2.py +0 -0
  183. /mplang/{runtime → v1/runtime}/exceptions.py +0 -0
  184. /mplang/{runtime → v1/runtime}/http_api.md +0 -0
  185. /mplang/{kernels → v1/simp}/__init__.py +0 -0
  186. /mplang/{utils → v1/utils}/__init__.py +0 -0
  187. /mplang/{utils → v1/utils}/crypto.py +0 -0
  188. /mplang/{utils → v1/utils}/func_utils.py +0 -0
  189. /mplang/{utils → v1/utils}/spu_utils.py +0 -0
  190. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/entry_points.txt +0 -0
  191. {mplang_nightly-0.1.dev158.dist-info → mplang_nightly-0.1.dev268.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,303 @@
1
+ # Copyright 2025 Ant Group Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any
16
+
17
+ import sqlglot as sg
18
+ from jax.tree_util import PyTreeDef, tree_flatten
19
+ from sqlglot import exp as sge
20
+ from sqlglot.optimizer import annotate_types as opt_annot
21
+ from sqlglot.optimizer import qualify as opt_qualify
22
+
23
+ from mplang.v1.core import MPObject, PFunction, TableType
24
+ from mplang.v1.core.dtypes import (
25
+ BINARY,
26
+ BOOL,
27
+ DATE,
28
+ DECIMAL,
29
+ FLOAT32,
30
+ FLOAT64,
31
+ INT8,
32
+ INT16,
33
+ INT32,
34
+ INT64,
35
+ INTERVAL,
36
+ JSON,
37
+ STRING,
38
+ TIME,
39
+ TIMESTAMP,
40
+ UINT8,
41
+ UINT16,
42
+ UINT32,
43
+ UINT64,
44
+ UUID,
45
+ DType,
46
+ )
47
+ from mplang.v1.ops.base import stateless_mod
48
+
49
_SQL_MOD = stateless_mod("sql")


# Static dtype mappings (MPLang <-> SQL)

# MPLang dtype -> SQL type string, used to describe input tables to sqlglot.
# Dtypes missing from this map are described as VARCHAR by the schema
# deduction code.
MP_TO_SQL_TYPE: dict[DType, str] = {
    # Floats
    FLOAT64: "DOUBLE",
    FLOAT32: "FLOAT",
    # Signed ints
    INT8: "TINYINT",
    INT16: "SMALLINT",
    INT32: "INT",
    INT64: "BIGINT",
    # Unsigned ints (portable approximations): each is widened to a larger
    # signed type; DECIMAL(38) is wide enough for every uint64 value.
    UINT8: "SMALLINT",
    UINT16: "INT",
    UINT32: "BIGINT",
    UINT64: "DECIMAL(38)",
    # Booleans & strings
    BOOL: "BOOLEAN",
    STRING: "VARCHAR",
    # Dates / times
    DATE: "DATE",
    TIME: "TIME",
    TIMESTAMP: "TIMESTAMP",
    # Other table types
    DECIMAL: "DECIMAL",
    JSON: "JSON",
    BINARY: "BLOB",
    UUID: "UUID",
    INTERVAL: "INTERVAL",
}

# Lower-cased SQL base type name (the text before any "(" precision suffix)
# -> MPLang dtype. Used to map sqlglot-annotated result types back to MPLang;
# names not listed here fall back to STRING during schema deduction.
SQL_TYPE_TO_MP: dict[str, DType] = {
    # Floats
    "double": FLOAT64,
    "double precision": FLOAT64,
    "float": FLOAT32,
    "real": FLOAT32,
    # Signed ints
    "bigint": INT64,
    "long": INT64,
    "int": INT32,
    "integer": INT32,
    "int4": INT32,
    "smallint": INT16,
    "int2": INT16,
    "tinyint": INT8,
    "int1": INT8,
    # Unsigned (rare in SQL)
    "uint8": UINT8,
    "ubyte": UINT8,
    "uint16": UINT16,
    "uint32": UINT32,
    "uint64": UINT64,
    # Unsigned spellings produced by sqlglot for e.g. DuckDB's
    # UTINYINT / USMALLINT / UINTEGER / UBIGINT; without these the result
    # columns would silently degrade to STRING.
    "utinyint": UINT8,
    "usmallint": UINT16,
    "uint": UINT32,
    "ubigint": UINT64,
    # Booleans / strings
    "bool": BOOL,
    "boolean": BOOL,
    "char": STRING,
    "varchar": STRING,
    "text": STRING,
    "string": STRING,
    # Dates / times
    "date": DATE,
    "time": TIME,
    "timestamp": TIMESTAMP,
    # Compact spellings of timestamp variants; "timestamp with time zone" is
    # already reduced to "timestamp" before lookup, so keep these consistent.
    "timestamptz": TIMESTAMP,
    "datetime": TIMESTAMP,
    # Decimal / numeric
    "decimal": DECIMAL,
    "numeric": DECIMAL,
    # Others
    "json": JSON,
    "binary": BINARY,
    "varbinary": BINARY,
    "blob": BINARY,
    "uuid": UUID,
    "interval": INTERVAL,
}
126
+
127
+
128
def _deduce_out_schema(
    parsed: sge.Expression,
    dialect: str,
    in_schemas: dict[str, TableType],
) -> TableType:
    """Infer the output ``TableType`` of an already-parsed SQL query.

    The input table schemas are translated into a sqlglot schema mapping. The
    query is then run through sqlglot's ``qualify`` pass (name resolution and
    star expansion) and its ``annotate_types`` pass. Each projection of the
    resulting SELECT becomes one output column.

    Args:
        parsed: The sqlglot expression tree of the query.
        dialect: The sqlglot dialect name used for qualification.
        in_schemas: Mapping from table name to its MPLang ``TableType``.

    Returns:
        A ``TableType`` with one ``(name, dtype)`` pair per projection.

    Raises:
        NotImplementedError: If no SELECT can be found, or a projection
            carries no type annotation.
        ValueError: If two projections resolve to the same output name.
    """
    # sqlglot expects {table: {column: sql_type}}. Dtypes without an entry in
    # MP_TO_SQL_TYPE are described to sqlglot as VARCHAR.
    sqlglot_schema: dict[str, dict[str, str]] = {}
    for table_name, table_type in in_schemas.items():
        sqlglot_schema[table_name] = {
            column: MP_TO_SQL_TYPE.get(column_dtype, "VARCHAR")
            for column, column_dtype in table_type.columns
        }

    qualified = opt_qualify.qualify(parsed, schema=sqlglot_schema, dialect=dialect)
    annotated = opt_annot.annotate_types(qualified, schema=sqlglot_schema)

    # Use the top-level SELECT, or the first SELECT that sqlglot's find()
    # reaches (e.g. one branch of a set operation).
    if isinstance(annotated, sge.Select):
        select = annotated
    else:
        select = annotated.find(sge.Select)
    if select is None:
        raise NotImplementedError(
            "Only SELECT queries are supported for schema deduction"
        )

    columns: list[tuple[str, DType]] = []
    seen: set[str] = set()
    unnamed_count = 0
    for projection in select.expressions:
        col_name = getattr(projection, "alias_or_name", None) or getattr(
            projection, "name", None
        )
        if not col_name:
            # Synthesize a name for anonymous expressions: expr_0, expr_1, ...
            col_name = f"expr_{unnamed_count}"
            unnamed_count += 1

        sql_type = getattr(projection, "type", None)
        if sql_type is None:
            raise NotImplementedError(
                "Cannot infer type for projection; please provide out_type explicitly"
            )
        # Reduce e.g. "DECIMAL(10, 2)" -> "decimal" and drop a
        # "with time zone" qualifier before looking up the MPLang dtype.
        # NOTE(review): sqlglot may annotate un-inferable expressions with an
        # UNKNOWN type rather than None; those land on the STRING fallback.
        normalized = str(sql_type).lower().replace(" with time zone", "").strip()
        base_name = normalized.split("(", 1)[0].strip()
        col_dtype = SQL_TYPE_TO_MP.get(base_name, STRING)

        if col_name in seen:
            raise ValueError(
                f"Duplicate output column name '{col_name}' after qualification"
            )
        seen.add(col_name)
        columns.append((col_name, col_dtype))

    return TableType.from_pairs(columns)
187
+
188
+
189
@_SQL_MOD.op_def()
def run_sql(
    query: str,
    *,
    out_type: TableType | None = None,
    dialect: str = "duckdb",
    **in_tables: Any,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query with optional schema deduction.

    API: run_sql(query: str, *, out_type: TableType | None = None, dialect: str = "duckdb", **in_tables) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - Parses the SQL and binds only the tables that are actually referenced in
      the query by name.
    - Names defined in ``WITH`` clauses (CTEs) are query-local relations and
      are not treated as input tables.
    - If ``out_type`` is not provided, attempts to deduce the output table
      schema using sqlglot (qualify + annotate types).
    - Returns a triad: the constructed PFunction (``fn_type='sql.run'``), the
      ordered list of input MPObjects, and the output PyTreeDef.

    Raises:
        ValueError: If ``in_tables`` contains tables the query does not reference.
        KeyError: If the query references a table missing from ``in_tables``.
        TypeError: If a referenced table is not an MPObject.

    Difference vs ``run_sql_raw``: this op can infer ``out_type`` and will
    parse the SQL to filter inputs; ``run_sql_raw`` requires an explicit
    ``out_type`` and does not parse/filter inputs.
    """
    parsed = sg.parse_one(query, read=dialect)

    # Relations introduced by WITH clauses are resolved inside the query, so
    # references to them must not be demanded as inputs. Without this,
    # `WITH c AS (...) SELECT * FROM c` would require an input table "c".
    # NOTE(review): exclusion is by name only; a CTE that shadows an input
    # table of the same name hides that input everywhere in the query.
    cte_names = {cte.alias_or_name for cte in parsed.find_all(sge.CTE)}

    # Collect referenced input tables, de-duplicated, in the order sqlglot's
    # find_all traversal discovers them; this order fixes the input order.
    required_names: list[str] = []
    for t in parsed.find_all(sge.Table):
        # Prefer .name; fallback to str(this) if needed
        tname = getattr(t, "name", None) or str(t.this)
        if not t.db and tname in cte_names:
            continue  # reference to a CTE, not to an input table
        if tname not in required_names:
            required_names.append(tname)

    # Disallow extras not referenced by the query to avoid surprises
    extra = set(in_tables.keys()) - set(required_names)
    if extra:
        raise ValueError(
            f"Unexpected tables provided that are not referenced in SQL: {sorted(extra)}"
        )

    # Validate required tables and require MPObject for runtime registration
    in_names: list[str] = []
    ins_info: list[TableType] = []
    in_vars: list[MPObject] = []
    for name in required_names:
        if name not in in_tables:
            raise KeyError(f"Missing required table '{name}' for SQL query")
        obj = in_tables[name]
        if not isinstance(obj, MPObject):
            raise TypeError(
                f"Table '{name}' must be an MPObject (for runtime registration), got {type(obj).__name__}"
            )
        assert obj.schema is not None, f"Input table '{name}' missing schema"
        in_vars.append(obj)
        ins_info.append(obj.schema)
        in_names.append(name)

    if out_type is None:
        # Reuse the schemas validated above instead of re-reading the objects.
        in_schemas: dict[str, TableType] = dict(zip(in_names, ins_info))
        out_type = _deduce_out_schema(parsed, dialect, in_schemas)

    pfn = PFunction(
        fn_type="sql.run",
        ins_info=tuple(ins_info),
        outs_info=(out_type,),
        fn_name="",
        fn_text=query,
        in_names=tuple(in_names),
        dialect=dialect,
    )
    _, treedef = tree_flatten(out_type)
    return pfn, in_vars, treedef
258
+
259
+
260
@_SQL_MOD.op_def()
def run_sql_raw(
    query: str,
    out_type: TableType,
    *,
    dialect: str = "duckdb",
    in_tables: dict[str, MPObject] | None = None,
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Build a sql.run PFunction from a SQL query and an explicit output schema.

    API: run_sql_raw(query: str, out_type: TableType, *, dialect: str = "duckdb", in_tables: dict[str, MPObject] | None = None) -> (PFunction, [MPObject], PyTreeDef)

    Semantics:
    - The SQL text is never parsed. Every entry of ``in_tables`` is bound as
      an input, in the mapping's iteration order.
    - ``out_type`` is required; no schema deduction is attempted.
    - Returns a triad: the constructed PFunction (``fn_type='sql.run'``), the
      ordered list of input MPObjects, and the output PyTreeDef.

    Raises:
        TypeError: If any value in ``in_tables`` is not an MPObject.

    Difference vs ``run_sql``: this op requires ``out_type`` and does not
    parse/filter inputs; ``run_sql`` can infer ``out_type`` and selects only
    tables referenced by the query.
    """
    # A missing or empty mapping simply means "no input tables".
    provided = in_tables or {}

    names: list[str] = []
    schemas: list[TableType] = []
    objects: list[MPObject] = []
    for table_name, table_obj in provided.items():
        if not isinstance(table_obj, MPObject):
            raise TypeError(
                f"Input table '{table_name}' is not an MPObject {type(table_obj)}"
            )
        assert table_obj.schema is not None, (
            f"Input table '{table_name}' is missing a schema"
        )
        names.append(table_name)
        schemas.append(table_obj.schema)
        objects.append(table_obj)

    pfn = PFunction(
        fn_type="sql.run",
        fn_name="",
        fn_text=query,
        ins_info=tuple(schemas),
        outs_info=(out_type,),
        in_names=tuple(names),
        dialect=dialect,
    )
    out_treedef = tree_flatten(out_type)[1]
    return pfn, objects, out_treedef
@@ -14,13 +14,8 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- from jax.tree_util import PyTreeDef, tree_flatten
18
-
19
- from mplang.core.dtype import UINT8
20
- from mplang.core.mpobject import MPObject
21
- from mplang.core.pfunc import PFunction
22
- from mplang.core.tensor import TensorType
23
- from mplang.ops.base import stateless_mod
17
+ from mplang.v1.core import UINT8, TensorType
18
+ from mplang.v1.ops.base import stateless_mod
24
19
 
25
20
  _TEE_MOD = stateless_mod("tee")
26
21
 
@@ -32,20 +27,10 @@ def quote_gen(pk: TensorType) -> TensorType:
32
27
  return TensorType(UINT8, (-1,))
33
28
 
34
29
 
35
- @_TEE_MOD.op_def()
36
- def attest(
37
- quote: MPObject, platform: str
38
- ) -> tuple[PFunction, list[MPObject], PyTreeDef]:
39
- """TEE quote verification returning the attested TEE public key."""
40
-
41
- ins_info = [TensorType.from_obj(quote)]
42
- outs_info = [TensorType(UINT8, (32,))] # pk is always 32 bytes for x25519
43
- pfunc = PFunction(
44
- fn_type="tee.attest",
45
- ins_info=ins_info,
46
- outs_info=outs_info,
47
- platform=platform,
48
- )
49
- _, treedef = tree_flatten(outs_info[0])
50
-
51
- return pfunc, [quote], treedef
30
@_TEE_MOD.simple_op()
def attest(quote: TensorType) -> TensorType:
    """TEE quote verification returning the attested TEE public key.

    Type rule only: the quote's type is not inspected, and the result is
    always a fixed-size ``u8[32]`` tensor.

    API (mock): attest(quote: u8[33]) -> tee_pk: u8[32]
    """
    del quote  # only the output type is computed here
    pk_shape = (32,)
    return TensorType(UINT8, pk_shape)
@@ -38,6 +38,7 @@ class _DataType:
38
38
  class _DataTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_DataType.ValueType], builtins.type):
39
39
  DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
40
40
  UNDEFINED: _DataType.ValueType # 0
41
+ """Undefined data type"""
41
42
  U8: _DataType.ValueType # 1
42
43
  """uint8_t"""
43
44
  I8: _DataType.ValueType # 2
@@ -83,6 +84,7 @@ class DataType(_DataType, metaclass=_DataTypeEnumTypeWrapper):
83
84
  """Data type enumeration"""
84
85
 
85
86
  UNDEFINED: DataType.ValueType # 0
87
+ """Undefined data type"""
86
88
  U8: DataType.ValueType # 1
87
89
  """uint8_t"""
88
90
  I8: DataType.ValueType # 2
@@ -138,14 +140,23 @@ class AttrProto(google.protobuf.message.Message):
138
140
  class _AttrTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[AttrProto._AttrType.ValueType], builtins.type):
139
141
  DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
140
142
  UNDEFINED: AttrProto._AttrType.ValueType # 0
143
+ """Undefined attribute type"""
141
144
  FLOAT: AttrProto._AttrType.ValueType # 1
145
+ """Single float value"""
142
146
  INT: AttrProto._AttrType.ValueType # 2
147
+ """Single integer value"""
143
148
  STRING: AttrProto._AttrType.ValueType # 3
149
+ """Single string value"""
144
150
  BOOL: AttrProto._AttrType.ValueType # 4
151
+ """Single boolean value"""
145
152
  BYTES: AttrProto._AttrType.ValueType # 5
153
+ """Binary data"""
146
154
  FLOATS: AttrProto._AttrType.ValueType # 6
155
+ """Array of float values"""
147
156
  INTS: AttrProto._AttrType.ValueType # 7
157
+ """Array of integer values"""
148
158
  STRINGS: AttrProto._AttrType.ValueType # 8
159
+ """Array of string values"""
149
160
  FUNCTION: AttrProto._AttrType.ValueType # 10
150
161
  """Textual function reference"""
151
162
  GRAPH: AttrProto._AttrType.ValueType # 11
@@ -155,14 +166,23 @@ class AttrProto(google.protobuf.message.Message):
155
166
  """Define possible attribute types"""
156
167
 
157
168
  UNDEFINED: AttrProto.AttrType.ValueType # 0
169
+ """Undefined attribute type"""
158
170
  FLOAT: AttrProto.AttrType.ValueType # 1
171
+ """Single float value"""
159
172
  INT: AttrProto.AttrType.ValueType # 2
173
+ """Single integer value"""
160
174
  STRING: AttrProto.AttrType.ValueType # 3
175
+ """Single string value"""
161
176
  BOOL: AttrProto.AttrType.ValueType # 4
177
+ """Single boolean value"""
162
178
  BYTES: AttrProto.AttrType.ValueType # 5
179
+ """Binary data"""
163
180
  FLOATS: AttrProto.AttrType.ValueType # 6
181
+ """Array of float values"""
164
182
  INTS: AttrProto.AttrType.ValueType # 7
183
+ """Array of integer values"""
165
184
  STRINGS: AttrProto.AttrType.ValueType # 8
185
+ """Array of string values"""
166
186
  FUNCTION: AttrProto.AttrType.ValueType # 10
167
187
  """Textual function reference"""
168
188
  GRAPH: AttrProto.AttrType.ValueType # 11
@@ -182,24 +202,24 @@ class AttrProto(google.protobuf.message.Message):
182
202
  type: global___AttrProto.AttrType.ValueType
183
203
  """Type of the attribute"""
184
204
  f: builtins.float
185
- """FLOAT"""
205
+ """FLOAT value"""
186
206
  i: builtins.int
187
- """INT"""
207
+ """INT value"""
188
208
  s: builtins.str
189
- """STRING"""
209
+ """STRING value"""
190
210
  b: builtins.bool
191
- """BOOL"""
211
+ """BOOL value"""
192
212
  raw_bytes: builtins.bytes
193
213
  """BYTES - for raw binary data"""
194
214
  @property
195
215
  def floats(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.float]:
196
- """FLOATS"""
216
+ """FLOATS - array of float values"""
197
217
  @property
198
218
  def ints(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
199
- """INTS"""
219
+ """INTS - array of integer values"""
200
220
  @property
201
221
  def strs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
202
- """STRINGS"""
222
+ """STRINGS - array of string values"""
203
223
  @property
204
224
  def func(self) -> global___FuncProto:
205
225
  """FUNCTION - textual function reference"""
@@ -228,6 +248,8 @@ global___AttrProto = AttrProto
228
248
 
229
249
  @typing_extensions.final
230
250
  class FuncProto(google.protobuf.message.Message):
251
+ """Function prototype message"""
252
+
231
253
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
232
254
 
233
255
  @typing_extensions.final
@@ -254,7 +276,7 @@ class FuncProto(google.protobuf.message.Message):
254
276
  DOC_STRING_FIELD_NUMBER: builtins.int
255
277
  ATTRS_FIELD_NUMBER: builtins.int
256
278
  type: builtins.str
257
- """Function type."""
279
+ """Function type"""
258
280
  name: builtins.str
259
281
  """Function name"""
260
282
  body: builtins.str
@@ -279,13 +301,17 @@ global___FuncProto = FuncProto
279
301
 
280
302
  @typing_extensions.final
281
303
  class TensorTypeProto(google.protobuf.message.Message):
304
+ """Tensor type definition"""
305
+
282
306
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
283
307
 
284
308
  DTYPE_FIELD_NUMBER: builtins.int
285
309
  SHAPE_DIMS_FIELD_NUMBER: builtins.int
286
310
  dtype: global___DataType.ValueType
311
+ """Data type of the tensor elements"""
287
312
  @property
288
- def shape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
313
+ def shape_dims(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]:
314
+ """Shape dimensions of the tensor"""
289
315
  def __init__(
290
316
  self,
291
317
  *,
@@ -298,16 +324,22 @@ global___TensorTypeProto = TensorTypeProto
298
324
 
299
325
  @typing_extensions.final
300
326
  class TableTypeProto(google.protobuf.message.Message):
327
+ """Table type definition"""
328
+
301
329
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
302
330
 
303
331
  @typing_extensions.final
304
332
  class Column(google.protobuf.message.Message):
333
+ """Column definition within a table"""
334
+
305
335
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
306
336
 
307
337
  NAME_FIELD_NUMBER: builtins.int
308
338
  DTYPE_FIELD_NUMBER: builtins.int
309
339
  name: builtins.str
340
+ """Name of the column"""
310
341
  dtype: global___DataType.ValueType
342
+ """Data type of the column"""
311
343
  def __init__(
312
344
  self,
313
345
  *,
@@ -318,7 +350,8 @@ class TableTypeProto(google.protobuf.message.Message):
318
350
 
319
351
  COLUMNS_FIELD_NUMBER: builtins.int
320
352
  @property
321
- def columns(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TableTypeProto.Column]: ...
353
+ def columns(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TableTypeProto.Column]:
354
+ """List of columns in the table"""
322
355
  def __init__(
323
356
  self,
324
357
  *,
@@ -330,6 +363,8 @@ global___TableTypeProto = TableTypeProto
330
363
 
331
364
  @typing_extensions.final
332
365
  class MPTypeProto(google.protobuf.message.Message):
366
+ """Multi-party type definition"""
367
+
333
368
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
334
369
 
335
370
  @typing_extensions.final
@@ -355,14 +390,16 @@ class MPTypeProto(google.protobuf.message.Message):
355
390
  PMASK_FIELD_NUMBER: builtins.int
356
391
  ATTRS_FIELD_NUMBER: builtins.int
357
392
  @property
358
- def tensor_type(self) -> global___TensorTypeProto: ...
393
+ def tensor_type(self) -> global___TensorTypeProto:
394
+ """Tensor type specification"""
359
395
  @property
360
- def table_type(self) -> global___TableTypeProto: ...
396
+ def table_type(self) -> global___TableTypeProto:
397
+ """Table type specification"""
361
398
  pmask: builtins.int
362
- """party mask (-1 for dynamic mask, >=0 for static mask)"""
399
+ """Party mask (-1 for dynamic mask, >=0 for static mask)"""
363
400
  @property
364
401
  def attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___AttrProto]:
365
- """attributes"""
402
+ """Additional attributes"""
366
403
  def __init__(
367
404
  self,
368
405
  *,
@@ -379,6 +416,8 @@ global___MPTypeProto = MPTypeProto
379
416
 
380
417
  @typing_extensions.final
381
418
  class NodeProto(google.protobuf.message.Message):
419
+ """Node prototype definition"""
420
+
382
421
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
383
422
 
384
423
  @typing_extensions.final
@@ -406,16 +445,20 @@ class NodeProto(google.protobuf.message.Message):
406
445
  ATTRS_FIELD_NUMBER: builtins.int
407
446
  DOC_STRING_FIELD_NUMBER: builtins.int
408
447
  op_type: builtins.str
448
+ """Operation type of the node"""
409
449
  name: builtins.str
450
+ """Name of the node"""
410
451
  @property
411
452
  def inputs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
412
- """{name:index}"""
453
+ """Input specifications in format {name:index}"""
413
454
  @property
414
455
  def outs_info(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MPTypeProto]:
415
- """"""
456
+ """Output type information"""
416
457
  @property
417
- def attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___AttrProto]: ...
458
+ def attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___AttrProto]:
459
+ """Node attributes"""
418
460
  doc_string: builtins.str
461
+ """Documentation string"""
419
462
  def __init__(
420
463
  self,
421
464
  *,
@@ -432,6 +475,8 @@ global___NodeProto = NodeProto
432
475
 
433
476
  @typing_extensions.final
434
477
  class VersionInfo(google.protobuf.message.Message):
478
+ """Version information definition"""
479
+
435
480
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
436
481
 
437
482
  MAJOR_FIELD_NUMBER: builtins.int
@@ -460,6 +505,8 @@ global___VersionInfo = VersionInfo
460
505
 
461
506
  @typing_extensions.final
462
507
  class GraphProto(google.protobuf.message.Message):
508
+ """Graph prototype definition"""
509
+
463
510
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
464
511
 
465
512
  @typing_extensions.final
@@ -485,14 +532,17 @@ class GraphProto(google.protobuf.message.Message):
485
532
  OUTPUTS_FIELD_NUMBER: builtins.int
486
533
  ATTRS_FIELD_NUMBER: builtins.int
487
534
  @property
488
- def version(self) -> global___VersionInfo: ...
535
+ def version(self) -> global___VersionInfo:
536
+ """Version information"""
489
537
  @property
490
- def nodes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NodeProto]: ...
538
+ def nodes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___NodeProto]:
539
+ """List of nodes in the graph"""
491
540
  @property
492
541
  def outputs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
493
- """{name:index}"""
542
+ """Output specifications in format {name:index}"""
494
543
  @property
495
- def attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___AttrProto]: ...
544
+ def attrs(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, global___AttrProto]:
545
+ """Graph attributes"""
496
546
  def __init__(
497
547
  self,
498
548
  *,
@@ -0,0 +1,34 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: mplang/protos/v1alpha1/value.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"mplang/protos/v1alpha1/value.proto\x12\x16mplang.protos.v1alpha1\"\xe8\x02\n\x0eValueAttrProto\x12\x43\n\x04type\x18\x01 \x01(\x0e\x32/.mplang.protos.v1alpha1.ValueAttrProto.AttrTypeR\x04type\x12\x0c\n\x01\x66\x18\x02 \x01(\x02R\x01\x66\x12\x0c\n\x01i\x18\x03 \x01(\x03R\x01i\x12\x0c\n\x01s\x18\x04 \x01(\tR\x01s\x12\x0c\n\x01\x62\x18\x05 \x01(\x08R\x01\x62\x12\x1b\n\traw_bytes\x18\x06 \x01(\x0cR\x08rawBytes\x12\x16\n\x06\x66loats\x18\x07 \x03(\x02R\x06\x66loats\x12\x12\n\x04ints\x18\x08 \x03(\x03R\x04ints\x12\x12\n\x04strs\x18\t \x03(\tR\x04strs\"|\n\x08\x41ttrType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\x07\n\x03INT\x10\x02\x12\n\n\x06STRING\x10\x03\x12\x08\n\x04\x42OOL\x10\x04\x12\t\n\x05\x42YTES\x10\x05\x12\n\n\x06\x46LOATS\x10\x06\x12\x08\n\x04INTS\x10\x07\x12\x0b\n\x07STRINGS\x10\x08\x12\t\n\x05\x45MPTY\x10\t\"\xa3\x02\n\nValueProto\x12\x12\n\x04kind\x18\x01 \x01(\tR\x04kind\x12#\n\rvalue_version\x18\x02 \x01(\rR\x0cvalueVersion\x12\x18\n\x07payload\x18\x03 \x01(\x0cR\x07payload\x12Y\n\rruntime_attrs\x18\x04 \x03(\x0b\x32\x34.mplang.protos.v1alpha1.ValueProto.RuntimeAttrsEntryR\x0cruntimeAttrs\x1ag\n\x11RuntimeAttrsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12<\n\x05value\x18\x02 \x01(\x0b\x32&.mplang.protos.v1alpha1.ValueAttrProtoR\x05value:\x02\x38\x01\x62\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'mplang.protos.v1alpha1.value_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_VALUEPROTO_RUNTIMEATTRSENTRY']._loaded_options = None
25
+ _globals['_VALUEPROTO_RUNTIMEATTRSENTRY']._serialized_options = b'8\001'
26
+ _globals['_VALUEATTRPROTO']._serialized_start=63
27
+ _globals['_VALUEATTRPROTO']._serialized_end=423
28
+ _globals['_VALUEATTRPROTO_ATTRTYPE']._serialized_start=299
29
+ _globals['_VALUEATTRPROTO_ATTRTYPE']._serialized_end=423
30
+ _globals['_VALUEPROTO']._serialized_start=426
31
+ _globals['_VALUEPROTO']._serialized_end=717
32
+ _globals['_VALUEPROTO_RUNTIMEATTRSENTRY']._serialized_start=614
33
+ _globals['_VALUEPROTO_RUNTIMEATTRSENTRY']._serialized_end=717
34
+ # @@protoc_insertion_point(module_scope)