torchmonarch-nightly 2025.7.29__cp313-cp313-manylinux2014_x86_64.whl → 2025.7.31__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/_rust_bindings.so CHANGED
Binary file
@@ -24,7 +24,9 @@ from typing import (
     Callable,
     cast,
     Concatenate,
+    Coroutine,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -403,15 +405,17 @@ class Accumulator(Generic[P, R, A]):
         self._combine: Callable[[A, R], A] = combine
 
     def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
-        gen: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)
+        gen: Generator[Coroutine[None, None, R], None, None] = self._endpoint._stream(
+            *args, **kwargs
+        )
 
         async def impl() -> A:
             value = self._identity
-            async for x in gen:
-                value = self._combine(value, x)
+            for x in gen:
+                value = self._combine(value, await x)
             return value
 
-        return Future(impl=impl)
+        return Future(coro=impl())
 
 
 class ValueMesh(MeshTrait, Generic[R]):
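For orientation, a hedged sketch of the accumulate() call path after this change; the `counter.value` endpoint and the Accumulator(endpoint, identity, combine) construction are illustrative assumptions, not part of this diff:

    # Hedged sketch: assumes Accumulator(endpoint, identity, combine) and a
    # hypothetical `counter.value` endpoint that returns an int per rank.
    acc = Accumulator(counter.value, 0, lambda total, part: total + part)

    fut = acc.accumulate()   # Future[int]; internally folds over _stream() coroutines
    total = fut.get()        # resolve synchronously, outside any event loop
    # total = await acc.accumulate()   # or await it from async code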
@@ -587,7 +591,7 @@ class PortReceiver(Generic[R]):
             raise ValueError(f"Unexpected message kind: {msg.kind}")
 
     def recv(self) -> "Future[R]":
-        return Future(impl=lambda: self._recv(), requires_loop=False)
+        return Future(coro=self._recv())
 
 
 class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
@@ -18,11 +18,10 @@ from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monar
     RemoteAllocatorBase,
     SimAllocatorBase,
 )
+from monarch._src.actor.future import Future
 
 if TYPE_CHECKING:
-    from monarch._rust_bindings.monarch_hyperactor.tokio import PythonTask
-
-    from monarch._src.actor.future import Future
+    from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
 
 ALLOC_LABEL_PROC_MESH_NAME = "procmesh.monarch.meta.com/name"
 
@@ -31,9 +30,9 @@ logger: logging.Logger = logging.getLogger(__name__)
 
 class AllocateMixin(abc.ABC):
     @abc.abstractmethod
-    def allocate_nonblocking(self, spec: AllocSpec) -> "Awaitable[Alloc]": ...
+    def allocate_nonblocking(self, spec: AllocSpec) -> "PythonTask[Alloc]": ...
 
-    def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+    def allocate(self, spec: AllocSpec) -> "Future[Alloc]":
         """
         Allocate a process according to the provided spec.
 
@@ -43,7 +42,7 @@ class AllocateMixin(abc.ABC):
         Returns:
         - A future that will be fulfilled when the requested allocation is fulfilled.
         """
-        return Future(impl=lambda: self.allocate_nonblocking(spec), requires_loop=False)
+        return Future(coro=self.allocate_nonblocking(spec))
 
 
 @final
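A hedged usage sketch of the reworked allocate(): `allocator` stands for any AllocateMixin implementation and `spec` for an AllocSpec; both names are assumptions for illustration.

    # allocate() now wraps the non-blocking PythonTask in the coroutine-backed Future.
    fut = allocator.allocate(spec)      # Future[Alloc]
    alloc = fut.get(timeout=30.0)       # block, with an optional timeout
    # alloc = await allocator.allocate(spec)   # or await it inside async code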
@@ -12,7 +12,7 @@ import logging
 import os
 import sys
 from dataclasses import dataclass
-from typing import cast, Dict, Generator, List, Tuple, Union
+from typing import cast, Dict, Generator, List, Optional, Tuple, Union
 
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._src.actor.actor_mesh import (
@@ -25,6 +25,7 @@ from monarch._src.actor.actor_mesh import (
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
 from monarch._src.actor.sync_state import fake_sync_state
+from tabulate import tabulate
 
 
 logger = logging.getLogger(__name__)
@@ -43,24 +44,32 @@ def _debugger_output(msg):
 
 @dataclass
 class DebugSessionInfo:
+    actor_name: str
     rank: int
     coords: Dict[str, int]
     hostname: str
-    actor_id: ActorId
     function: str | None
     lineno: int | None
 
+    def __lt__(self, other):
+        if self.actor_name < other.actor_name:
+            return True
+        elif self.actor_name == other.actor_name:
+            return self.rank < other.rank
+        else:
+            return False
+
 
 class DebugSession:
     """Represents a single session with a remote debugger."""
 
     def __init__(
-        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
+        self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
    ):
         self.rank = rank
         self.coords = coords
         self.hostname = hostname
-        self.actor_id = actor_id
+        self.actor_name = actor_name
         self._active = False
         self._message_queue = asyncio.Queue()
         self._task = None
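The new __lt__ makes DebugSessionInfo sort by actor name and then rank, which the reworked list endpoint further below relies on via sorted(); a hedged illustration with made-up values:

    infos = sorted([
        DebugSessionInfo("trainer", 1, {"gpus": 1}, "host0", None, None),
        DebugSessionInfo("dataloader", 0, {"gpus": 0}, "host1", None, None),
        DebugSessionInfo("trainer", 0, {"gpus": 0}, "host0", None, None),
    ])
    # [(i.actor_name, i.rank) for i in infos]
    # == [("dataloader", 0), ("trainer", 0), ("trainer", 1)]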
@@ -127,7 +136,7 @@ class DebugSession:
         if self._function_lineno is not None:
             function, lineno = self._function_lineno
         return DebugSessionInfo(
-            self.rank, self.coords, self.hostname, self.actor_id, function, lineno
+            self.actor_name, self.rank, self.coords, self.hostname, function, lineno
         )
 
     async def attach(self, line=None, suppress_output=False):
@@ -160,6 +169,97 @@ class DebugSession:
 RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]]
 
 
+class DebugSessions:
+    def __init__(self):
+        self._sessions: Dict[str, Dict[int, DebugSession]] = {}
+
+    def insert(self, session: DebugSession) -> None:
+        if session.actor_name not in self._sessions:
+            self._sessions[session.actor_name] = {session.rank: session}
+        elif session.rank not in self._sessions[session.actor_name]:
+            self._sessions[session.actor_name][session.rank] = session
+        else:
+            raise ValueError(
+                f"Debug session for rank {session.rank} already exists for actor {session.actor_name}"
+            )
+
+    def remove(self, actor_name: str, rank: int) -> DebugSession:
+        if actor_name not in self._sessions:
+            raise ValueError(f"No debug sessions for actor {actor_name}")
+        elif rank not in self._sessions[actor_name]:
+            raise ValueError(f"No debug session for rank {rank} for actor {actor_name}")
+        session = self._sessions[actor_name].pop(rank)
+        if len(self._sessions[actor_name]) == 0:
+            del self._sessions[actor_name]
+        return session
+
+    def get(self, actor_name: str, rank: int) -> DebugSession:
+        if actor_name not in self._sessions:
+            raise ValueError(f"No debug sessions for actor {actor_name}")
+        elif rank not in self._sessions[actor_name]:
+            raise ValueError(f"No debug session for rank {rank} for actor {actor_name}")
+        return self._sessions[actor_name][rank]
+
+    def iter(
+        self, selection: Optional[Tuple[str, Optional[RanksType]]]
+    ) -> Generator[DebugSession, None, None]:
+        if selection is None:
+            for sessions in self._sessions.values():
+                for session in sessions.values():
+                    yield session
+            return
+        actor_name, ranks = selection
+        if actor_name not in self._sessions:
+            return
+        sessions = self._sessions[actor_name]
+        if ranks is None:
+            for session in sessions.values():
+                yield session
+        elif isinstance(ranks, int):
+            if ranks in sessions:
+                yield sessions[ranks]
+        elif isinstance(ranks, list):
+            for rank in ranks:
+                if rank in sessions:
+                    yield sessions[rank]
+        elif isinstance(ranks, dict):
+            dims = ranks
+            for session in sessions.values():
+                include_rank = True
+                for dim, ranks in dims.items():
+                    if dim not in session.coords:
+                        include_rank = False
+                        break
+                    elif (
+                        isinstance(ranks, range) or isinstance(ranks, list)
+                    ) and session.coords[dim] not in ranks:
+                        include_rank = False
+                        break
+                    elif isinstance(ranks, int) and session.coords[dim] != ranks:
+                        include_rank = False
+                        break
+                if include_rank:
+                    yield session
+        elif isinstance(ranks, range):
+            for rank, session in sessions.items():
+                if rank in ranks:
+                    yield session
+
+    def info(self) -> List[DebugSessionInfo]:
+        session_info = []
+        for sessions in self._sessions.values():
+            for session in sessions.values():
+                session_info.append(session.get_info())
+        return session_info
+
+    def __len__(self) -> int:
+        return sum(len(sessions) for sessions in self._sessions.values())
+
+    def __contains__(self, item: Tuple[str, int]) -> bool:
+        actor_name, rank = item
+        return actor_name in self._sessions and rank in self._sessions[actor_name]
+
+
 _debug_input_parser = None
 
 
@@ -181,14 +281,17 @@ def _get_debug_input_parser():
         dims: dim ("," dim)*
         ranks: "ranks(" (dims | rank_range | rank_list | INT) ")"
         pdb_command: /\\w+.*/
-        cast: "cast" ranks pdb_command
+        actor_name: /\\w+/
+        cast: "cast" _WS actor_name ranks pdb_command
         help: "h" | "help"
-        attach: ("a" | "attach") INT
+        attach: ("a" | "attach") _WS actor_name INT
         cont: "c" | "continue"
         quit: "q" | "quit"
         list: "l" | "list"
         command: attach | list | cast | help | cont | quit
 
+        _WS: WS+
+
         %import common.INT
         %import common.CNAME
         %import common.WS
@@ -255,11 +358,14 @@ def _get_debug_input_transformer():
         def pdb_command(self, items: List[Token]) -> str:
             return items[0].value
 
+        def actor_name(self, items: List[Token]) -> str:
+            return items[0].value
+
         def help(self, _items: List[Token]) -> "Help":
             return Help()
 
-        def attach(self, items: List[Token]) -> "Attach":
-            return Attach(int(items[0].value))
+        def attach(self, items: Tuple[str, Token]) -> "Attach":
+            return Attach(items[0], int(items[1].value))
 
         def cont(self, _items: List[Token]) -> "Continue":
             return Continue()
@@ -267,8 +373,8 @@ def _get_debug_input_transformer():
         def quit(self, _items: List[Token]) -> "Quit":
             return Quit()
 
-        def cast(self, items: Tuple[RanksType, str]) -> "Cast":
-            return Cast(items[0], items[1])
+        def cast(self, items: Tuple[str, RanksType, str]) -> "Cast":
+            return Cast(*items)
 
         def list(self, items: List[Token]) -> "ListCommand":
             return ListCommand()
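With the grammar and transformer changes above, debugger commands now name the target actor mesh; a hedged sketch of inputs and roughly what they should parse into (the actor name "trainer" is illustrative):

    # DebugCommand.parse is the entry point used by DebugClient.enter further below.
    DebugCommand.parse("attach trainer 2")              # -> Attach("trainer", 2)
    DebugCommand.parse("cast trainer ranks(1,4,6) bt")  # -> roughly Cast("trainer", [1, 4, 6], "bt")
    DebugCommand.parse("cast trainer ranks(dim1=1:5:2, dim2=3) up 2")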
@@ -293,6 +399,7 @@ class DebugCommand:
 
 @dataclass
 class Attach(DebugCommand):
+    actor_name: str
     rank: int
 
 
@@ -318,6 +425,7 @@ class Continue(DebugCommand):
 
 @dataclass
 class Cast(DebugCommand):
+    actor_name: str
     ranks: RanksType
     command: str
 
@@ -330,7 +438,7 @@ class DebugClient(Actor):
     """
 
     def __init__(self) -> None:
-        self.sessions = {}  # rank -> DebugSession
+        self.sessions = DebugSessions()
 
     @endpoint
     async def wait_pending_session(self):
@@ -338,39 +446,33 @@ class DebugClient(Actor):
             await asyncio.sleep(1)
 
     @endpoint
-    async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]:
-        session_info = []
-        for _, session in self.sessions.items():
-            info = session.get_info()
-            session_info.append(
-                (
-                    info.rank,
-                    info.coords,
-                    info.hostname,
-                    info.actor_id,
-                    info.function,
-                    info.lineno,
-                )
-            )
-        table_info = sorted(session_info, key=lambda r: r[0])
-
-        from tabulate import tabulate
-
+    async def list(self) -> List[DebugSessionInfo]:
+        session_info = sorted(self.sessions.info())
         print(
             tabulate(
-                table_info,
+                (
+                    (
+                        info.actor_name,
+                        info.rank,
+                        info.coords,
+                        info.hostname,
+                        info.function,
+                        info.lineno,
+                    )
+                    for info in session_info
+                ),
                 headers=[
+                    "Actor Name",
                     "Rank",
                     "Coords",
                     "Hostname",
-                    "Actor ID",
                     "Function",
                     "Line No.",
                 ],
                 tablefmt="grid",
             )
         )
-        return table_info
+        return session_info
 
     @endpoint
     async def enter(self) -> None:
@@ -383,14 +485,16 @@ class DebugClient(Actor):
         while True:
             try:
                 user_input = await _debugger_input("monarch_dbg> ")
+                if not user_input.strip():
+                    continue
                 command = DebugCommand.parse(user_input)
                 if isinstance(command, Help):
                     print("monarch_dbg commands:")
-                    print("\tattach <rank> - attach to a debug session")
+                    print("\tattach <actor_name> <rank> - attach to a debug session")
                     print("\tlist - list all debug sessions")
                     print("\tquit - exit the debugger, leaving all sessions in place")
                     print(
-                        "\tcast ranks(...) <command> - send a command to a set of ranks.\n"
+                        "\tcast <actor_name> ranks(...) <command> - send a command to a set of ranks on the specified actor mesh.\n"
                         "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n"
                         "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n"
                         "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))."
@@ -400,10 +504,7 @@ class DebugClient(Actor):
                     )
                     print("\thelp - print this help message")
                 elif isinstance(command, Attach):
-                    if command.rank not in self.sessions:
-                        print(f"No debug session for rank {command.rank}")
-                    else:
-                        await self.sessions[command.rank].attach()
+                    await self.sessions.get(command.actor_name, command.rank).attach()
                 elif isinstance(command, ListCommand):
                     # pyre-ignore
                     await self.list._method(self)
@@ -419,61 +520,22 @@ class DebugClient(Actor):
                 elif isinstance(command, Quit):
                     return
                 elif isinstance(command, Cast):
-                    await self._cast_input_and_wait(command.command, command.ranks)
+                    await self._cast_input_and_wait(
+                        command.command, (command.actor_name, command.ranks)
+                    )
             except Exception as e:
                 print(f"Error processing command: {e}")
 
     async def _cast_input_and_wait(
         self,
         command: str,
-        ranks: RanksType | None = None,
+        selection: Optional[Tuple[str, Optional[RanksType]]] = None,
     ) -> None:
-        if ranks is None:
-            ranks = self.sessions.keys()
-        elif isinstance(ranks, dict):
-            ranks = self._iter_ranks_dict(ranks)
-        elif isinstance(ranks, range):
-            ranks = self._iter_ranks_range(ranks)
-        elif isinstance(ranks, int):
-            ranks = [ranks]
         tasks = []
-        for rank in ranks:
-            if rank in self.sessions:
-                tasks.append(
-                    self.sessions[rank].attach(
-                        command,
-                        suppress_output=True,
-                    )
-                )
-            else:
-                print(f"No debug session for rank {rank}")
+        for session in self.sessions.iter(selection):
+            tasks.append(session.attach(command, suppress_output=True))
         await asyncio.gather(*tasks)
 
-    def _iter_ranks_dict(
-        self, dims: Dict[str, Union[range, List[int], int]]
-    ) -> Generator[int, None, None]:
-        for rank, session in self.sessions.items():
-            include_rank = True
-            for dim, ranks in dims.items():
-                if dim not in session.coords:
-                    include_rank = False
-                    break
-                elif (
-                    isinstance(ranks, range) or isinstance(ranks, list)
-                ) and session.coords[dim] not in ranks:
-                    include_rank = False
-                    break
-                elif isinstance(ranks, int) and session.coords[dim] != ranks:
-                    include_rank = False
-                    break
-            if include_rank:
-                yield rank
-
-    def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]:
-        for rank in self.sessions.keys():
-            if rank in rng:
-                yield rank
-
     ##########################################################################
     # Debugger APIs
     #
@@ -481,30 +543,30 @@ class DebugClient(Actor):
     # and communicate with them.
     @endpoint
     async def debugger_session_start(
-        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
+        self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
     ) -> None:
         # Create a session if it doesn't exist
-        if rank not in self.sessions:
-            self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id)
+        if (actor_name, rank) not in self.sessions:
+            self.sessions.insert(DebugSession(rank, coords, hostname, actor_name))
 
     @endpoint
-    async def debugger_session_end(self, rank: int) -> None:
+    async def debugger_session_end(self, actor_name: str, rank: int) -> None:
         """Detach from the current debug session."""
-        session = self.sessions.pop(rank)
-        await session.detach()
+        await self.sessions.remove(actor_name, rank).detach()
 
     @endpoint
-    async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str:
+    async def debugger_read(
+        self, actor_name: str, rank: int, size: int
+    ) -> DebuggerWrite | str:
         """Read from the debug session for the given rank."""
-        session = self.sessions[rank]
-
-        return await session.debugger_read(size)
+        return await self.sessions.get(actor_name, rank).debugger_read(size)
 
     @endpoint
-    async def debugger_write(self, rank: int, write: DebuggerWrite) -> None:
+    async def debugger_write(
+        self, actor_name: str, rank: int, write: DebuggerWrite
+    ) -> None:
         """Write to the debug session for the given rank."""
-        session = self.sessions[rank]
-        await session.debugger_write(write)
+        await self.sessions.get(actor_name, rank).debugger_write(write)
 
 
 class DebugManager(Actor):
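The debugger endpoints above now key sessions by (actor_name, rank) through the DebugSessions container introduced earlier; a hedged sketch of its behavior using stand-in hostnames, coordinates, and actor names:

    sessions = DebugSessions()
    sessions.insert(DebugSession(0, {"hosts": 0, "gpus": 0}, "host0", "trainer"))
    sessions.insert(DebugSession(1, {"hosts": 0, "gpus": 1}, "host0", "trainer"))

    ("trainer", 1) in sessions             # True: membership is (actor_name, rank)
    len(sessions)                          # 2
    sessions.get("trainer", 0)             # exact lookup; raises ValueError if absent
    list(sessions.iter(("trainer", [0])))  # filter by actor name plus a rank selection
    list(sessions.iter(None))              # every session across all actors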
@@ -529,7 +591,6 @@ class DebugManager(Actor):
     def __init__(self, debug_client: DebugClient) -> None:
         self._debug_client = debug_client
 
-    # pyre-ignore
     @endpoint
     def get_debug_client(self) -> DebugClient:
         return self._debug_client
@@ -16,7 +16,9 @@ from typing import (
     Callable,
     cast,
     Concatenate,
+    Coroutine,
     Dict,
+    Generator,
     Generic,
     List,
     Literal,
@@ -145,7 +147,7 @@ class Endpoint(ABC, Generic[P, R]):
 
             results: List[R] = [None] * extent.nelements  # pyre-fixme[9]
             for _ in range(extent.nelements):
-                rank, value = await r.recv()
+                rank, value = await r._recv()
                 results[rank] = value
             call_shape = Shape(
                 extent.labels,
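The next two hunks replace the async stream() generator with an internal _stream() plus a Future-yielding stream(); for orientation, a hedged sketch of the resulting call pattern (`echo.say_hello` is an illustrative endpoint, not part of this diff):

    for fut in echo.say_hello.stream("hi"):
        print(fut.get())        # consume synchronously, one Future per responding actor
    # async variant:
    # for fut in echo.say_hello.stream("hi"):
    #     print(await fut)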
@@ -153,9 +155,11 @@ class Endpoint(ABC, Generic[P, R]):
             )
             return ValueMesh(call_shape, results)
 
-        return Future(impl=process, requires_loop=False)
+        return Future(coro=process())
 
-    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
+    def _stream(
+        self, *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[Coroutine[Any, Any, R], None, None]:
         """
         Broadcasts to all actors and yields their responses as a stream / generator.
 
@@ -168,7 +172,14 @@ class Endpoint(ABC, Generic[P, R]):
         # pyre-ignore
         extent = self._send(args, kwargs, port=p)
         for _ in range(extent.nelements):
-            yield await r.recv()
+            # pyre-ignore
+            yield r._recv()
+
+    def stream(
+        self, *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[Future[R], None, None]:
+        for coro in self._stream(*args, **kwargs):
+            yield Future(coro=coro)
 
     def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
         """
@@ -7,7 +7,21 @@
 import asyncio
 import traceback
 from functools import partial
-from typing import Generator, Generic, Optional, TypeVar
+from typing import (
+    Any,
+    cast,
+    Coroutine,
+    Generator,
+    Generic,
+    Literal,
+    NamedTuple,
+    Optional,
+    TypeVar,
+)
+
+from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
+
+from typing_extensions import Self
 
 R = TypeVar("R")
 
@@ -48,43 +62,76 @@ async def _aincomplete(impl, self):
 # loop machinery, this gives it the same throughput as if we ran it synchronously.
 
 
-class Future(Generic[R]):
-    def __init__(self, *, impl, requires_loop=True):
-        self._aget = partial(_aincomplete, impl)
-        self._requires_loop = requires_loop
+class _Unawaited(NamedTuple):
+    coro: PythonTask
 
-    def get(self, timeout: Optional[float] = None) -> R:
-        if asyncio._get_running_loop() is not None:
-            raise RuntimeError("get() cannot be called from within an async context")
-        if timeout is not None:
-            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
-        if not self._requires_loop:
-            try:
-                coro = self._aget(self)
-                next(coro.__await__())
-                tb_str = "".join(traceback.format_stack(coro.cr_frame))
-                raise RuntimeError(
-                    f"a coroutine paused with a future with requires_loop=False cannot block on a python asyncio.Future. Use requires_loop=True.\n{tb_str}"
-                )
-            except StopIteration as e:
-                return e.value
-        return asyncio.run(self._aget(self))
 
-    def __await__(self) -> Generator[R, None, R]:
-        return self._aget(self).__await__()
+class _Complete(NamedTuple):
+    value: Any
+
+
+class _Exception(NamedTuple):
+    exe: Exception
+
 
-    def _set_result(self, result):
-        async def af(self):
-            return result
+class _Asyncio(NamedTuple):
+    fut: asyncio.Future
 
-        self._aget = af
-        return result
 
-    def _set_exception(self, e):
-        async def af(self):
-            raise e
+_Status = _Unawaited | _Complete | _Exception | _Asyncio
 
-        self._aget = af
+
+class Future(Generic[R]):
+    def __init__(self, *, coro: "Coroutine[Any, Any, R] | PythonTask[R]"):
+        self._status: _Status = _Unawaited(
+            coro if isinstance(coro, PythonTask) else PythonTask.from_coroutine(coro)
+        )
+
+    def get(self, timeout: Optional[float] = None) -> R:
+        match self._status:
+            case _Unawaited(coro=coro):
+                try:
+                    if timeout is not None:
+                        coro = coro.with_timeout(timeout)
+                    v = coro.block_on()
+                    self._status = _Complete(v)
+                    return cast("R", v)
+                except Exception as e:
+                    self._status = _Exception(e)
+                    raise e from None
+            case _Asyncio(_):
+                raise ValueError(
+                    "already converted into an asyncio.Future, use 'await' to get the value."
+                )
+            case _Complete(value=value):
+                return cast("R", value)
+            case _Exception(exe=exe):
+                raise exe
+            case _:
+                raise RuntimeError("unknown status")
+
+    def __await__(self) -> Generator[Any, Any, R]:
+        match self._status:
+            case _Unawaited(coro=coro):
+                loop = asyncio.get_running_loop()
+                fut = loop.create_future()
+                self._status = _Asyncio(fut)
+
+                async def mark_complete():
+                    try:
+                        func, value = fut.set_result, await coro
+                    except Exception as e:
+                        func, value = fut.set_exception, e
+                    loop.call_soon_threadsafe(func, value)
+
+                PythonTask.from_coroutine(mark_complete()).spawn()
+                return fut.__await__()
+            case _Asyncio(fut=fut):
+                return fut.__await__()
+            case _:
+                raise ValueError(
+                    "already converted into a synchronous future, use 'get' to get the value."
+                )
 
     # compatibility with old tensor engine Future objects
     # hopefully we do not need done(), add_callback because
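A hedged sketch of the reworked Future: it is constructed from a coroutine (or a pytokio PythonTask) and is consumed either synchronously with get() or by awaiting it, but not both, per the state machine above:

    async def compute() -> int:
        return 42

    f = Future(coro=compute())
    assert f.get(timeout=5.0) == 42   # runs on the pytokio task; no asyncio loop required
    f.get()                           # later calls return the cached _Complete value

    # Awaiting a Future first converts it into an asyncio.Future; calling get()
    # on it afterwards raises ValueError.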