mplang-nightly 0.1.dev151__py3-none-any.whl → 0.1.dev153__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,365 +0,0 @@
1
- # Copyright 2025 Ant Group Co., Ltd.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """
16
- This module provides the resource management for the HTTP backend.
17
- It is a simplified, in-memory version of the original executor's resource manager.
18
- """
19
-
20
- import base64
21
- import logging
22
- from dataclasses import dataclass, field
23
- from typing import Any
24
- from urllib.parse import urlparse
25
-
26
- import cloudpickle as pickle
27
- import spu.libspu as libspu
28
-
29
- from mplang.core.expr.ast import Expr
30
- from mplang.core.expr.evaluator import IEvaluator, create_evaluator
31
- from mplang.core.mask import Mask
32
- from mplang.kernels.context import RuntimeContext
33
- from mplang.kernels.spu import PFunction # type: ignore
34
- from mplang.runtime.communicator import HttpCommunicator
35
- from mplang.runtime.exceptions import InvalidRequestError, ResourceNotFound
36
- from mplang.runtime.link_comm import LinkCommunicator
37
- from mplang.utils.spu_utils import parse_field, parse_protocol
38
-
39
-
40
class LinkCommFactory:
    """Factory for creating and caching link communicators.

    Communicators are cached by ``(rank, tuple(addrs))``. A repeated request
    for the same rank and the same ordered address list returns the existing
    instance instead of constructing a new one.
    """

    def __init__(self) -> None:
        # Key is (rank, addresses); the address list is stored as a tuple so
        # that the key is hashable.
        self._cache: dict[tuple[int, tuple[str, ...]], LinkCommunicator] = {}

    def create_link(self, rank: int, addrs: list[str]) -> LinkCommunicator:
        """Return the communicator for ``(rank, addrs)``, creating it on a cache miss.

        Args:
            rank: Rank passed to ``LinkCommunicator``.
            addrs: Address list passed to ``LinkCommunicator``. Its order is
                part of the cache key.

        Returns:
            The cached or newly created ``LinkCommunicator``.

        NOTE(review): the cache is not synchronized. If this is called
        concurrently for the same key, two communicators may be constructed.
        Confirm whether callers can race.
        """
        key = (rank, tuple(addrs))
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        new_link = LinkCommunicator(rank, addrs)
        # Log only after construction succeeds, so a failing constructor does
        # not leave a misleading "created" message behind.
        logging.info("LinkCommunicator created: %s %s", rank, addrs)
        self._cache[key] = new_link
        return new_link
56
-
57
-
58
# Process-wide factory shared by every session's execute_computation call, so
# link communicators for an identical (rank, addrs) pair are reused.
g_link_factory: LinkCommFactory = LinkCommFactory()
60
-
61
-
62
@dataclass()
class Symbol:
    """A named value held in a symbol table (session-level or global).

    Fields are stored exactly as given; no validation is performed.
    """

    # Identifier of the symbol within its table.
    name: str
    # Type metadata; deliberately loose so it can hold a dict or an MPType.
    mptype: Any
    # The stored Python value (e.g. an unpickled payload or an evaluator result).
    data: Any
67
-
68
-
69
@dataclass()
class Computation:
    """A named expression registered in a session for later execution."""

    # Identifier of the computation within its session.
    name: str
    # The expression that execute_computation hands to the evaluator.
    expr: Expr
73
-
74
-
75
@dataclass()
class Session:
    """Per-session state: communicator, registered computations and symbols."""

    # Session identifier (key in the module-level session table).
    name: str
    # Communicator providing rank, world_size and endpoints for this session.
    communicator: HttpCommunicator
    # Computations registered in this session, keyed by computation name.
    computations: dict[str, Computation] = field(default_factory=dict)
    # Session-level symbols, keyed by symbol name.
    symbols: dict[str, Symbol] = field(default_factory=dict)

    # --- SPU configuration ---
    # Bitmask of SPU-participating ranks; -1 makes execute_computation use
    # Mask.all(world_size).
    spu_mask: int = field(default=-1)
    # Protocol name, converted with parse_protocol at execution time.
    spu_protocol: str = field(default="SEMI2K")
    # Field name, converted with parse_field at execution time.
    spu_field: str = field(default="FM64")
86
-
87
-
88
# In-memory table of all live sessions, keyed by session name.
_sessions: dict[str, Session] = dict()
90
-
91
-
92
- # Session Management
93
def create_session(
    name: str,
    rank: int,
    endpoints: list[str],
    # SPU related
    spu_mask: int = 0,
    spu_protocol: str = "SEMI2K",
    spu_field: str = "FM64",
) -> Session:
    """Create a session, or return the already registered one with this name.

    The call is idempotent on ``name``. If a session with that name exists,
    it is returned unchanged and the remaining arguments are ignored. A
    warning is logged when those arguments disagree with the existing
    session's configuration, since that usually indicates a misconfigured
    client.

    Args:
        name: Session identifier.
        rank: This process's rank, passed to ``HttpCommunicator``.
        endpoints: Endpoint list, passed to ``HttpCommunicator``.
        spu_mask: Bitmask of SPU-participating ranks (``-1`` means all ranks
            in ``execute_computation``). NOTE(review): this default (0)
            differs from ``Session.spu_mask``'s default (-1); presumably 0
            selects no SPU ranks. Confirm that callers always pass it.
        spu_protocol: SPU protocol name (parsed later by ``parse_protocol``).
        spu_field: SPU field name (parsed later by ``parse_field``).

    Returns:
        The newly created or previously registered ``Session``.
    """
    logging.info(f"Creating session: {name}, rank: {rank}, spu_mask: {spu_mask}")
    existing = _sessions.get(name)
    if existing is not None:
        # Return existing session (idempotent operation)
        logging.info(f"Session {name} already exists, returning existing session")
        comm = existing.communicator
        # Surface a silently ignored reconfiguration instead of hiding it.
        mismatch = (
            existing.spu_mask != spu_mask
            or existing.spu_protocol != spu_protocol
            or existing.spu_field != spu_field
            or (
                comm is not None
                and (comm.rank != rank or list(comm.endpoints) != list(endpoints))
            )
        )
        if mismatch:
            logging.warning(
                f"Session {name} already exists with a different configuration; "
                f"ignoring requested rank={rank}, endpoints={endpoints}, "
                f"spu_mask={spu_mask}, spu_protocol={spu_protocol}, "
                f"spu_field={spu_field}"
            )
        return existing

    session = Session(
        name=name,
        communicator=HttpCommunicator(session_name=name, rank=rank, endpoints=endpoints),
        spu_mask=spu_mask,
        spu_protocol=spu_protocol,
        spu_field=spu_field,
    )
    _sessions[name] = session
    logging.info(f"Session {name} created successfully")
    return session
118
-
119
-
120
def get_session(name: str) -> Session | None:
    """Look up a session by name, returning ``None`` when it is not registered."""
    try:
        return _sessions[name]
    except KeyError:
        return None
122
-
123
-
124
def delete_session(name: str) -> bool:
    """Remove the named session from the session table.

    The session's computations and symbols live on the ``Session`` object,
    so they are dropped along with it. ``g_link_factory``'s cache is not
    touched.

    Returns:
        True if a session was removed, False if no session had that name.
    """
    if name not in _sessions:
        return False
    _sessions.pop(name)
    logging.info(f"Session {name} deleted successfully")
    return True
135
-
136
-
137
# Process-wide symbol table shared by all sessions, keyed by symbol name.
_global_symbols: dict[str, Symbol] = dict()
139
-
140
-
141
def create_global_symbol(name: str, mptype: dict[str, Any], data_b64: str) -> Symbol:
    """Create or replace a process-wide symbol.

    Args:
        name: Symbol identifier; an existing global symbol with this name is
            overwritten.
        mptype: Metadata dict (shape/dtype, etc.), stored as-is.
        data_b64: Base64-encoded pickled payload.

    Returns:
        The stored ``Symbol``.

    Raises:
        InvalidRequestError: if the payload cannot be base64-decoded or
            unpickled.

    Security: unpickling executes code embedded in the payload. Only accept
    payloads from trusted peers.
    """
    try:
        payload = pickle.loads(base64.b64decode(data_b64))
    except Exception as e:  # pragma: no cover - defensive
        raise InvalidRequestError(f"Failed to decode symbol payload: {e}") from e
    symbol = Symbol(name=name, mptype=mptype, data=payload)
    _global_symbols[name] = symbol
    return symbol
157
-
158
-
159
def get_global_symbol(name: str) -> Symbol:
    """Return the process-wide symbol registered under ``name``.

    Raises:
        ResourceNotFound: if no global symbol with that name exists.
    """
    found = _global_symbols.get(name)
    if found is not None:
        return found
    raise ResourceNotFound(f"Global symbol '{name}' not found")
164
-
165
-
166
def delete_global_symbol(name: str) -> bool:
    """Remove the process-wide symbol ``name``.

    Returns:
        True if a symbol was removed, False if none was registered.
    """
    removed = _global_symbols.pop(name, None)
    return removed is not None
168
-
169
-
170
def list_global_symbols() -> list[str]:  # pragma: no cover - trivial
    """Return the names of all process-wide symbols in sorted order."""
    return sorted(_global_symbols)
172
-
173
-
174
- # Computation Management
175
def create_computation(
    session_name: str, computation_name: str, expr: Expr
) -> Computation:
    """Register ``expr`` as computation ``computation_name`` in a session.

    An existing computation with the same name is overwritten.

    Raises:
        ResourceNotFound: if the session does not exist.
    """
    session = get_session(session_name)
    if session is None:
        raise ResourceNotFound(f"Session '{session_name}' not found.")
    comp = Computation(name=computation_name, expr=expr)
    session.computations[computation_name] = comp
    logging.info(f"Computation {computation_name} created for session {session_name}")
    return comp
186
-
187
-
188
def get_computation(session_name: str, comp_name: str) -> Computation | None:
    """Return the named computation, or ``None`` if the session or computation is missing."""
    session = get_session(session_name)
    return None if session is None else session.computations.get(comp_name)
193
-
194
-
195
def delete_computation(session_name: str, comp_name: str) -> bool:
    """Remove a computation from a session.

    Returns:
        True if the computation was removed; False if the session or the
        computation does not exist.
    """
    session = get_session(session_name)
    if session is None or comp_name not in session.computations:
        return False
    session.computations.pop(comp_name)
    logging.info(f"Computation {comp_name} deleted from session {session_name}")
    return True
210
-
211
-
212
def _spu_link_addrs(
    endpoints: list[str], spu_mask: Mask, port_offset: int = 100
) -> list[str]:
    """Map each SPU-participating rank's endpoint to ``host:(port + port_offset)``.

    Addresses are returned in rank order of ``endpoints``, restricted to the
    ranks contained in ``spu_mask``.

    Raises:
        InvalidRequestError: if an endpoint lacks a host or a valid port, or
            the shifted port would exceed 65535.
    """
    addrs: list[str] = []
    for r, endpoint in enumerate(endpoints):
        if r not in spu_mask:
            continue
        # urlparse only fills hostname/port from a netloc; a bare "host:port"
        # has no scheme, so prefix "//" to force netloc parsing.
        url = endpoint if "://" in endpoint else f"//{endpoint}"
        parsed = urlparse(url)
        try:
            port = parsed.port  # raises ValueError for non-numeric/out-of-range ports
        except ValueError as e:
            raise InvalidRequestError(
                f"Invalid port in endpoint '{endpoint}': {e}"
            ) from e
        # Explicit checks instead of `assert`, which is stripped under -O and
        # would otherwise let "None:..." addresses through.
        if parsed.hostname is None or port is None:
            raise InvalidRequestError(
                f"Endpoint '{endpoint}' must specify a host and port"
            )
        spu_port = port + port_offset
        if spu_port > 65535:
            raise InvalidRequestError(
                f"Endpoint '{endpoint}' port {port} + {port_offset} exceeds 65535"
            )
        addrs.append(f"{parsed.hostname}:{spu_port}")
    return addrs


def execute_computation(
    session_name: str, comp_name: str, input_names: list[str], output_names: list[str]
) -> None:
    """Execute a stored computation and store its results as session symbols.

    Inputs are read from the session's symbols by name. SPU runtime state is
    seeded on the evaluator, and the computation's expression is evaluated.
    Each result is then stored as a session symbol named by the
    corresponding entry of ``output_names``.

    Raises:
        ResourceNotFound: if the session, the computation, or an input
            symbol does not exist.
        InvalidRequestError: if the session has no communicator or an SPU
            endpoint cannot be turned into a link address.
        RuntimeError: if the number of results differs from
            ``len(output_names)``.
    """
    session = get_session(session_name)
    if session is None:
        raise ResourceNotFound(f"Session '{session_name}' not found.")

    computation = get_computation(session_name, comp_name)
    if computation is None:
        raise ResourceNotFound(
            f"Computation '{comp_name}' not found in session '{session_name}'."
        )

    if not session.communicator:
        raise InvalidRequestError(
            f"Communicator not initialized for session '{session_name}'."
        )

    rank = session.communicator.rank
    world_size = session.communicator.world_size

    # Input bindings come from session-level symbols only.
    bindings: dict[str, Any] = {}
    for input_name in input_names:
        symbol = get_symbol(session_name, input_name)
        if symbol is None:
            raise ResourceNotFound(
                f"Input symbol '{input_name}' not found in session '{session_name}'"
            )
        bindings[input_name] = symbol.data

    # -1 is the "all ranks participate in SPU" sentinel.
    spu_mask = (
        Mask(session.spu_mask) if session.spu_mask != -1 else Mask.all(world_size)
    )

    # Explicit per-rank backend runtime (deglobalized)
    runtime = RuntimeContext(rank=rank, world_size=world_size)
    evaluator: IEvaluator = create_evaluator(
        rank=rank, env=bindings, comm=session.communicator, runtime=runtime
    )

    if rank in spu_mask:
        # The address list is rebuilt on every call. The link itself is
        # reused through g_link_factory's (rel_index, addrs) cache.
        spu_addrs = _spu_link_addrs(session.communicator.endpoints, spu_mask)
        # This rank's position among the participating ranks.
        rel_index = sum(1 for r in range(rank) if r in spu_mask)
        link_ctx = g_link_factory.create_link(rel_index, spu_addrs)
    else:
        link_ctx = None

    # Always seed config/world; link is None on non-participating ranks.
    spu_config = libspu.RuntimeConfig(
        protocol=parse_protocol(session.spu_protocol),
        field=parse_field(session.spu_field),
        fxp_fraction_bits=18,
    )
    seed_pfunc = PFunction(
        fn_type="spu.seed_env",
        ins_info=(),
        outs_info=(),
        config=spu_config,
        world=spu_mask.num_parties(),
        link=link_ctx,
    )
    # Run the seeding kernel inside the evaluator's kernel context (no I/O).
    evaluator.runtime.run_kernel(seed_pfunc, [])

    results = evaluator.evaluate(computation.expr)

    # NOTE(review): an empty/None result skips output validation entirely
    # (preserved from the original). Confirm this is intended when
    # output_names is non-empty.
    if results:
        if len(results) != len(output_names):
            raise RuntimeError(
                f"Expected {len(output_names)} results, got {len(results)}"
            )
        for name, val in zip(output_names, results, strict=True):
            session.symbols[name] = Symbol(name=name, mptype={}, data=val)
302
-
303
-
304
- # Symbol Management
305
def create_symbol(session_name: str, name: str, mptype: Any, data: Any) -> Symbol:
    """Create (or overwrite) symbol ``name`` in a session's symbol table.

    Args:
        session_name: Session that owns the symbol table.
        name: Symbol identifier.
        mptype: Type metadata, stored as-is.
        data: Base64-encoded pickled Python object; it is decoded and
            unpickled before being stored.

    Returns:
        The stored ``Symbol``.

    Raises:
        ResourceNotFound: if the session does not exist.
        InvalidRequestError: if ``data`` cannot be decoded or unpickled.

    Security: unpickling executes code embedded in the payload. Only accept
    data from trusted peers.
    """
    session = get_session(session_name)
    if session is None:
        raise ResourceNotFound(f"Session '{session_name}' not found.")

    try:
        raw = base64.b64decode(data)
        value = pickle.loads(raw)
    except Exception as e:
        raise InvalidRequestError(f"Invalid symbol data encoding: {e!s}") from e

    created = Symbol(name, mptype, value)
    session.symbols[name] = created
    return created
324
-
325
-
326
def get_symbol(session_name: str, name: str) -> Symbol | None:
    """Return a session-level symbol, or ``None`` if the session or symbol is missing.

    Global symbols are not consulted.
    """
    session = get_session(session_name)
    return None if session is None else session.symbols.get(name)
334
-
335
-
336
def list_symbols(session_name: str) -> list[str]:
    """Return the names of a session's symbols in insertion order.

    Raises:
        ResourceNotFound: if the session does not exist.
    """
    session = get_session(session_name)
    if session is None:
        raise ResourceNotFound(f"Session '{session_name}' not found.")
    return [*session.symbols]
344
-
345
-
346
def delete_symbol(session_name: str, symbol_name: str) -> bool:
    """Remove a symbol from a session's symbol table.

    Returns:
        True if the symbol was removed; False if the session or the symbol
        does not exist.
    """
    session = get_session(session_name)
    if session is None or symbol_name not in session.symbols:
        return False
    session.symbols.pop(symbol_name)
    logging.info(f"Symbol {symbol_name} deleted from session {session_name}")
    return True
361
-
362
-
363
def list_all_sessions() -> list[str]:
    """Return the names of all registered sessions in insertion order."""
    return [*_sessions]