mplang-nightly 0.1.dev151__py3-none-any.whl → 0.1.dev153__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplang/core/cluster.py +95 -24
- mplang/kernels/sql_duckdb.py +5 -0
- mplang/ops/base.py +2 -2
- mplang/ops/ibis_cc.py +1 -0
- mplang/ops/sql.py +2 -1
- mplang/runtime/client.py +4 -18
- mplang/runtime/driver.py +1 -6
- mplang/runtime/server.py +95 -49
- mplang/runtime/session.py +285 -0
- mplang/runtime/simulation.py +15 -13
- {mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/METADATA +1 -1
- {mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/RECORD +15 -15
- mplang/runtime/resource.py +0 -365
- {mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/WHEEL +0 -0
- {mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/entry_points.txt +0 -0
- {mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/licenses/LICENSE +0 -0
mplang/runtime/resource.py
DELETED
@@ -1,365 +0,0 @@
|
|
1
|
-
# Copyright 2025 Ant Group Co., Ltd.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
"""
|
16
|
-
This module provides the resource management for the HTTP backend.
|
17
|
-
It is a simplified, in-memory version of the original executor's resource manager.
|
18
|
-
"""
|
19
|
-
|
20
|
-
import base64
|
21
|
-
import logging
|
22
|
-
from dataclasses import dataclass, field
|
23
|
-
from typing import Any
|
24
|
-
from urllib.parse import urlparse
|
25
|
-
|
26
|
-
import cloudpickle as pickle
|
27
|
-
import spu.libspu as libspu
|
28
|
-
|
29
|
-
from mplang.core.expr.ast import Expr
|
30
|
-
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
|
31
|
-
from mplang.core.mask import Mask
|
32
|
-
from mplang.kernels.context import RuntimeContext
|
33
|
-
from mplang.kernels.spu import PFunction # type: ignore
|
34
|
-
from mplang.runtime.communicator import HttpCommunicator
|
35
|
-
from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
|
36
|
-
from mplang.runtime.link_comm import LinkCommunicator
|
37
|
-
from mplang.utils.spu_utils import parse_field, parse_protocol
|
38
|
-
|
39
|
-
|
40
|
-
class LinkCommFactory:
|
41
|
-
"""Factory for creating and caching link communicators."""
|
42
|
-
|
43
|
-
def __init__(self) -> None:
|
44
|
-
self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}
|
45
|
-
|
46
|
-
def create_link(self, rank: int, addrs: list[str]) -> LinkCommunicator:
|
47
|
-
key = (rank, tuple(addrs))
|
48
|
-
val = self._cache.get(key, None)
|
49
|
-
if val is not None:
|
50
|
-
return val
|
51
|
-
|
52
|
-
logging.info(f"LinkCommunicator created: {rank} {addrs}")
|
53
|
-
new_link = LinkCommunicator(rank, addrs)
|
54
|
-
self._cache[key] = new_link
|
55
|
-
return new_link
|
56
|
-
|
57
|
-
|
58
|
-
# Global link factory instance
|
59
|
-
g_link_factory = LinkCommFactory()
|
60
|
-
|
61
|
-
|
62
|
-
@dataclass
|
63
|
-
class Symbol:
|
64
|
-
name: str
|
65
|
-
mptype: Any # More flexible type to handle dict or MPType
|
66
|
-
data: Any # More flexible data type
|
67
|
-
|
68
|
-
|
69
|
-
@dataclass
|
70
|
-
class Computation:
|
71
|
-
name: str
|
72
|
-
expr: Expr # The computation expression
|
73
|
-
|
74
|
-
|
75
|
-
@dataclass
|
76
|
-
class Session:
|
77
|
-
name: str
|
78
|
-
communicator: HttpCommunicator
|
79
|
-
computations: dict[str, Computation] = field(default_factory=dict)
|
80
|
-
symbols: dict[str, Symbol] = field(default_factory=dict) # Session-level symbols
|
81
|
-
|
82
|
-
# spu related
|
83
|
-
spu_mask: int = -1
|
84
|
-
spu_protocol: str = "SEMI2K"
|
85
|
-
spu_field: str = "FM64"
|
86
|
-
|
87
|
-
|
88
|
-
# Global session storage
|
89
|
-
_sessions: dict[str, Session] = {}
|
90
|
-
|
91
|
-
|
92
|
-
# Session Management
|
93
|
-
def create_session(
|
94
|
-
name: str,
|
95
|
-
rank: int,
|
96
|
-
endpoints: list[str],
|
97
|
-
# SPU related
|
98
|
-
spu_mask: int = 0,
|
99
|
-
spu_protocol: str = "SEMI2K",
|
100
|
-
spu_field: str = "FM64",
|
101
|
-
) -> Session:
|
102
|
-
logging.info(f"Creating session: {name}, rank: {rank}, spu_mask: {spu_mask}")
|
103
|
-
if name in _sessions:
|
104
|
-
# Return existing session (idempotent operation)
|
105
|
-
logging.info(f"Session {name} already exists, returning existing session")
|
106
|
-
return _sessions[name]
|
107
|
-
session = Session(
|
108
|
-
name, HttpCommunicator(session_name=name, rank=rank, endpoints=endpoints)
|
109
|
-
)
|
110
|
-
|
111
|
-
session.spu_mask = spu_mask
|
112
|
-
session.spu_protocol = spu_protocol
|
113
|
-
session.spu_field = spu_field
|
114
|
-
|
115
|
-
_sessions[name] = session
|
116
|
-
logging.info(f"Session {name} created successfully")
|
117
|
-
return session
|
118
|
-
|
119
|
-
|
120
|
-
def get_session(name: str) -> Session | None:
|
121
|
-
return _sessions.get(name)
|
122
|
-
|
123
|
-
|
124
|
-
def delete_session(name: str) -> bool:
|
125
|
-
"""Delete a session and all associated resources.
|
126
|
-
|
127
|
-
Returns:
|
128
|
-
True if session was deleted, False if session was not found.
|
129
|
-
"""
|
130
|
-
if name in _sessions:
|
131
|
-
del _sessions[name]
|
132
|
-
logging.info(f"Session {name} deleted successfully")
|
133
|
-
return True
|
134
|
-
return False
|
135
|
-
|
136
|
-
|
137
|
-
# Global symbol management (process-wide, not per-session)
|
138
|
-
_global_symbols: dict[str, Symbol] = {}
|
139
|
-
|
140
|
-
|
141
|
-
def create_global_symbol(name: str, mptype: dict[str, Any], data_b64: str) -> Symbol:
|
142
|
-
"""Create or replace a global symbol.
|
143
|
-
|
144
|
-
Args:
|
145
|
-
name: Symbol identifier
|
146
|
-
mptype: Metadata dict (shape/dtype, etc.)
|
147
|
-
data_b64: Base64-encoded pickled data
|
148
|
-
"""
|
149
|
-
try:
|
150
|
-
raw = base64.b64decode(data_b64)
|
151
|
-
data = pickle.loads(raw)
|
152
|
-
except Exception as e: # pragma: no cover - defensive
|
153
|
-
raise InvalidRequestError(f"Failed to decode symbol payload: {e}") from e
|
154
|
-
sym = Symbol(name=name, mptype=mptype, data=data)
|
155
|
-
_global_symbols[name] = sym
|
156
|
-
return sym
|
157
|
-
|
158
|
-
|
159
|
-
def get_global_symbol(name: str) -> Symbol:
|
160
|
-
sym = _global_symbols.get(name)
|
161
|
-
if sym is None:
|
162
|
-
raise ResourceNotFound(f"Global symbol '{name}' not found")
|
163
|
-
return sym
|
164
|
-
|
165
|
-
|
166
|
-
def delete_global_symbol(name: str) -> bool:
|
167
|
-
return _global_symbols.pop(name, None) is not None
|
168
|
-
|
169
|
-
|
170
|
-
def list_global_symbols() -> list[str]: # pragma: no cover - trivial
|
171
|
-
return sorted(_global_symbols.keys())
|
172
|
-
|
173
|
-
|
174
|
-
# Computation Management
|
175
|
-
def create_computation(
|
176
|
-
session_name: str, computation_name: str, expr: Expr
|
177
|
-
) -> Computation:
|
178
|
-
"""Creates a computation resource within a session."""
|
179
|
-
session = get_session(session_name)
|
180
|
-
if not session:
|
181
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
182
|
-
computation = Computation(computation_name, expr)
|
183
|
-
session.computations[computation_name] = computation
|
184
|
-
logging.info(f"Computation {computation_name} created for session {session_name}")
|
185
|
-
return computation
|
186
|
-
|
187
|
-
|
188
|
-
def get_computation(session_name: str, comp_name: str) -> Computation | None:
|
189
|
-
session = get_session(session_name)
|
190
|
-
if session:
|
191
|
-
return session.computations.get(comp_name)
|
192
|
-
return None
|
193
|
-
|
194
|
-
|
195
|
-
def delete_computation(session_name: str, comp_name: str) -> bool:
|
196
|
-
"""Delete a computation from a session.
|
197
|
-
|
198
|
-
Returns:
|
199
|
-
True if computation was deleted, False if not found.
|
200
|
-
"""
|
201
|
-
session = get_session(session_name)
|
202
|
-
if not session:
|
203
|
-
return False
|
204
|
-
|
205
|
-
if comp_name in session.computations:
|
206
|
-
del session.computations[comp_name]
|
207
|
-
logging.info(f"Computation {comp_name} deleted from session {session_name}")
|
208
|
-
return True
|
209
|
-
return False
|
210
|
-
|
211
|
-
|
212
|
-
def execute_computation(
|
213
|
-
session_name: str, comp_name: str, input_names: list[str], output_names: list[str]
|
214
|
-
) -> None:
|
215
|
-
"""Execute a computation using the Evaluator."""
|
216
|
-
session = get_session(session_name)
|
217
|
-
if not session:
|
218
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
219
|
-
|
220
|
-
computation = get_computation(session_name, comp_name)
|
221
|
-
if not computation:
|
222
|
-
raise ResourceNotFound(
|
223
|
-
f"Computation '{comp_name}' not found in session '{session_name}'."
|
224
|
-
)
|
225
|
-
|
226
|
-
if not session.communicator:
|
227
|
-
raise InvalidRequestError(
|
228
|
-
f"Communicator not initialized for session '{session_name}'."
|
229
|
-
)
|
230
|
-
|
231
|
-
# Get rank from session communicator
|
232
|
-
rank = session.communicator.rank
|
233
|
-
|
234
|
-
# Prepare input bindings from session symbols
|
235
|
-
bindings = {}
|
236
|
-
for input_name in input_names:
|
237
|
-
symbol = get_symbol(session_name, input_name)
|
238
|
-
if not symbol:
|
239
|
-
raise ResourceNotFound(
|
240
|
-
f"Input symbol '{input_name}' not found in session '{session_name}'"
|
241
|
-
)
|
242
|
-
bindings[input_name] = symbol.data
|
243
|
-
|
244
|
-
spu_mask = (
|
245
|
-
Mask(session.spu_mask)
|
246
|
-
if session.spu_mask != -1
|
247
|
-
else Mask.all(session.communicator.world_size)
|
248
|
-
)
|
249
|
-
|
250
|
-
# Build evaluator
|
251
|
-
# Explicit per-rank backend runtime (deglobalized)
|
252
|
-
runtime = RuntimeContext(rank=rank, world_size=session.communicator.world_size)
|
253
|
-
evaluator: IEvaluator = create_evaluator(
|
254
|
-
rank=rank, env=bindings, comm=session.communicator, runtime=runtime
|
255
|
-
)
|
256
|
-
|
257
|
-
# Initialize SPU runtime state for flat kernels (once per evaluator invocation)
|
258
|
-
if rank in spu_mask:
|
259
|
-
# Build SPU address list (only once per rank; consistent ordering of participating ranks)
|
260
|
-
spu_addrs: list[str] = []
|
261
|
-
for r, addr in enumerate(session.communicator.endpoints):
|
262
|
-
if r in spu_mask:
|
263
|
-
if "://" not in addr:
|
264
|
-
addr = f"//{addr}"
|
265
|
-
parsed = urlparse(addr)
|
266
|
-
assert isinstance(parsed.port, int)
|
267
|
-
new_addr = f"{parsed.hostname}:{parsed.port + 100}"
|
268
|
-
spu_addrs.append(new_addr)
|
269
|
-
# Determine this rank's relative index among participating ranks
|
270
|
-
rel_index = sum(1 for r in range(rank) if r in spu_mask)
|
271
|
-
link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
|
272
|
-
else:
|
273
|
-
link_ctx = None
|
274
|
-
# Always seed config/world; provide per-rank link (may be None if not participating)
|
275
|
-
spu_config = libspu.RuntimeConfig(
|
276
|
-
protocol=parse_protocol(session.spu_protocol),
|
277
|
-
field=parse_field(session.spu_field),
|
278
|
-
fxp_fraction_bits=18,
|
279
|
-
)
|
280
|
-
# Seed SPU env via backend kernel (inside evaluator's kernel context)
|
281
|
-
seed_pfunc = PFunction(
|
282
|
-
fn_type="spu.seed_env",
|
283
|
-
ins_info=(),
|
284
|
-
outs_info=(),
|
285
|
-
config=spu_config,
|
286
|
-
world=spu_mask.num_parties(),
|
287
|
-
link=link_ctx,
|
288
|
-
)
|
289
|
-
# Run seeding kernel with evaluator (no inputs, no outputs)
|
290
|
-
evaluator.runtime.run_kernel(seed_pfunc, [])
|
291
|
-
|
292
|
-
results = evaluator.evaluate(computation.expr)
|
293
|
-
|
294
|
-
# Store results in session symbols using output_names
|
295
|
-
if results:
|
296
|
-
if len(results) != len(output_names):
|
297
|
-
raise RuntimeError(
|
298
|
-
f"Expected {len(output_names)} results, got {len(results)}"
|
299
|
-
)
|
300
|
-
for name, val in zip(output_names, results, strict=True):
|
301
|
-
session.symbols[name] = Symbol(name=name, mptype={}, data=val)
|
302
|
-
|
303
|
-
|
304
|
-
# Symbol Management
|
305
|
-
def create_symbol(session_name: str, name: str, mptype: Any, data: Any) -> Symbol:
|
306
|
-
"""Create a symbol in a session's symbol table.
|
307
|
-
|
308
|
-
The `data` is expected to be a base64-encoded pickled Python object.
|
309
|
-
"""
|
310
|
-
session = get_session(session_name)
|
311
|
-
if not session:
|
312
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
313
|
-
|
314
|
-
# Deserialize base64-encoded data to Python object
|
315
|
-
try:
|
316
|
-
data_bytes = base64.b64decode(data)
|
317
|
-
obj = pickle.loads(data_bytes)
|
318
|
-
except Exception as e:
|
319
|
-
raise InvalidRequestError(f"Invalid symbol data encoding: {e!s}") from e
|
320
|
-
|
321
|
-
symbol = Symbol(name, mptype, obj)
|
322
|
-
session.symbols[name] = symbol
|
323
|
-
return symbol
|
324
|
-
|
325
|
-
|
326
|
-
def get_symbol(session_name: str, name: str) -> Symbol | None:
|
327
|
-
"""Get a symbol from a session's symbol table (session-level only)."""
|
328
|
-
session = get_session(session_name)
|
329
|
-
if not session:
|
330
|
-
return None
|
331
|
-
|
332
|
-
# Only session-level symbols are supported now
|
333
|
-
return session.symbols.get(name)
|
334
|
-
|
335
|
-
|
336
|
-
def list_symbols(session_name: str) -> list[str]:
|
337
|
-
"""List all symbols in a session's symbol table."""
|
338
|
-
session = get_session(session_name)
|
339
|
-
if not session:
|
340
|
-
raise ResourceNotFound(f"Session '{session_name}' not found.")
|
341
|
-
|
342
|
-
# Only session-level symbols are supported now
|
343
|
-
return list(session.symbols.keys())
|
344
|
-
|
345
|
-
|
346
|
-
def delete_symbol(session_name: str, symbol_name: str) -> bool:
|
347
|
-
"""Delete a symbol from a session.
|
348
|
-
|
349
|
-
Returns:
|
350
|
-
True if symbol was deleted, False if not found.
|
351
|
-
"""
|
352
|
-
session = get_session(session_name)
|
353
|
-
if not session:
|
354
|
-
return False
|
355
|
-
|
356
|
-
if symbol_name in session.symbols:
|
357
|
-
del session.symbols[symbol_name]
|
358
|
-
logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
|
359
|
-
return True
|
360
|
-
return False
|
361
|
-
|
362
|
-
|
363
|
-
def list_all_sessions() -> list[str]:
|
364
|
-
"""List all session names."""
|
365
|
-
return list(_sessions.keys())
|
File without changes
|
{mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/entry_points.txt
RENAMED
File without changes
|
{mplang_nightly-0.1.dev151.dist-info → mplang_nightly-0.1.dev153.dist-info}/licenses/LICENSE
RENAMED
File without changes
|