rubyx-py 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1298 @@
1
+ use crate::api;
2
+ use crate::exception::PythonException;
3
+ use crate::python_api::PythonApi;
4
+ use crate::python_ffi::PyObject;
5
+ use crate::python_guard::PyGuard;
6
+ use crate::rubyx_object::python_to_sendable;
7
+ use crate::stream::StreamItem;
8
+ use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
9
+ use std::panic::{catch_unwind, AssertUnwindSafe};
10
+ use std::thread;
11
+ use std::thread::JoinHandle;
12
+
13
+ pub(crate) const SYNC_ADAPTER_PY: &str = include_str!("python/sync_adapter.py");
14
+
15
+ /// A stream that can consume either sync or async Python generators
16
+ ///
17
+ /// Sync generators, PyIter_Next loop.
18
+ /// Async generators, Rust-side driving depending on configuration
19
+ pub struct AsyncGeneratorStream {
20
+ receiver: Receiver<StreamItem>,
21
+ cancel_sender: Sender<()>,
22
+ handle: Option<JoinHandle<()>>,
23
+ }
24
+
25
/// Selects how an async iterable is consumed by `AsyncGeneratorStream`.
#[derive(Debug, Clone, Copy)]
pub enum AsyncStrategy {
    /// Wrap the async iterable with the Python-side `AsyncToSync` adapter and
    /// then drain it like any other sync iterator.
    PythonAdapter,
    /// Call `__anext__()` from Rust and run each awaitable on an asyncio event
    /// loop created for this stream.
    RustDriving,
}
33
+
34
+ impl Drop for AsyncGeneratorStream {
35
+ fn drop(&mut self) {
36
+ // Signal the worker thread to stop
37
+ let _ = self.cancel_sender.try_send(());
38
+ // Drain the channel so the worker doesn't block on a full send
39
+ while self.receiver.try_recv().is_ok() {}
40
+ // Join the worker thread to ensure GIL is released
41
+ if let Some(handle) = self.handle.take() {
42
+ let _ = handle.join();
43
+ }
44
+ }
45
+ }
46
+ impl Iterator for AsyncGeneratorStream {
47
+ type Item = Result<magnus::Value, magnus::Error>;
48
+ fn next(&mut self) -> Option<Self::Item> {
49
+ match self.receiver.recv() {
50
+ Ok(StreamItem::Value(v)) => Some(v.try_into()),
51
+ Ok(StreamItem::Error(e)) => Some(Err(magnus::Error::new(
52
+ crate::ruby_helpers::runtime_error(),
53
+ e,
54
+ ))),
55
+ Ok(StreamItem::End) | Err(_) => None,
56
+ }
57
+ }
58
+ }
59
+
60
+ impl AsyncGeneratorStream {
61
+ /// Create a stream from Python object
62
+ pub fn from_python_object(
63
+ py_obj: *mut PyObject,
64
+ async_strategy: AsyncStrategy,
65
+ ) -> Result<Self, String> {
66
+ let api = api();
67
+ let gil = api.ensure_gil();
68
+ let result = if api.is_async_iterable(py_obj) {
69
+ match async_strategy {
70
+ AsyncStrategy::PythonAdapter => Self::from_async_via_adapter(py_obj, api),
71
+ AsyncStrategy::RustDriving => Self::from_async_via_rust(py_obj, api),
72
+ }
73
+ } else {
74
+ let py_iter = api.object_get_iter(py_obj);
75
+ if py_iter.is_null() {
76
+ api.clear_error();
77
+ Err("Object is not iterable".to_string())
78
+ } else {
79
+ Ok(Self::from_sync_iterator(py_iter))
80
+ }
81
+ };
82
+ api.release_gil(gil);
83
+ result
84
+ }
85
+ /// Async path using AsyncToSync adapter.
86
+ fn from_async_via_adapter(async_gen: *mut PyObject, api: &PythonApi) -> Result<Self, String> {
87
+ let sync_iter = api.wrap_async_generator(async_gen);
88
+ if sync_iter.is_null() {
89
+ api.clear_error();
90
+ return Err("Failed to wrap async generator".to_string());
91
+ }
92
+ Ok(Self::from_sync_iterator(sync_iter))
93
+ }
94
+ /// Async path using Rust-side event loop driving.
95
+ fn from_async_via_rust(async_gen: *mut PyObject, api: &PythonApi) -> Result<Self, String> {
96
+ let (value_tx, value_rx) = unbounded();
97
+ let (cancel_tx, cancel_rx) = bounded(1);
98
+ // worker thread own the reference - incref
99
+ api.incref(async_gen);
100
+ let gen_ptr = async_gen as usize;
101
+ let handle = thread::spawn(move || {
102
+ let api = crate::api();
103
+ let gil = api.ensure_gil();
104
+ let async_gen = gen_ptr as *mut PyObject;
105
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
106
+ api.release_gil(gil);
107
+ });
108
+ Ok(Self {
109
+ receiver: value_rx,
110
+ cancel_sender: cancel_tx,
111
+ handle: Some(handle),
112
+ })
113
+ }
114
+ fn from_sync_iterator(py_iter: *mut PyObject) -> Self {
115
+ let (value_tx, value_rx) = unbounded();
116
+ let (cancel_tx, cancel_rx) = bounded(1);
117
+
118
+ // Cast the raw pointer to usize so it can cross the thread boundary.
119
+ // *mut PyObject is not Send, but the usize value is just a number.
120
+ // This is safe because the worker thread will acquire the GIL before
121
+ // using the pointer, and the pointer remains valid (Python iterator
122
+ // is kept alive by its refcount).
123
+ let py_iter_addr = py_iter as usize;
124
+
125
+ let handle = thread::spawn(move || {
126
+ let py_iter = py_iter_addr as *mut PyObject;
127
+ // Worker thread: acquire GIL, iterate, send values
128
+ let api = api();
129
+ let gil = api.ensure_gil();
130
+
131
+ loop {
132
+ // Check if there is a cancellation
133
+ if cancel_rx.try_recv().is_ok() {
134
+ break;
135
+ }
136
+
137
+ // Get next item from Python iterator
138
+ let item = api.iter_next(py_iter);
139
+ if item.is_null() {
140
+ // Check if an exception was raised (vs normal exhaustion)
141
+ if api.has_error() {
142
+ if let Some(exc) = crate::python_api::PythonApi::extract_exception(api) {
143
+ value_tx.send(StreamItem::Error(exc.to_string())).ok();
144
+ } else {
145
+ value_tx.send(StreamItem::End).ok();
146
+ }
147
+ } else {
148
+ value_tx.send(StreamItem::End).ok();
149
+ }
150
+ break;
151
+ }
152
+ // Convert and send to ruby
153
+ let ruby_value = python_to_sendable(item, api)
154
+ .map_err(|e| format!("Error converting Python value to Ruby: {e}"));
155
+ api.decref(item);
156
+ match ruby_value {
157
+ Ok(value) => {
158
+ if value_tx.send(StreamItem::Value(value)).is_err() {
159
+ break; // Consumer dropped — stop producing
160
+ }
161
+ }
162
+ Err(e) => {
163
+ value_tx.send(StreamItem::Error(e)).ok();
164
+ break;
165
+ }
166
+ }
167
+ }
168
+ api.decref(py_iter);
169
+ api.release_gil(gil);
170
+ });
171
+ Self {
172
+ receiver: value_rx,
173
+ cancel_sender: cancel_tx,
174
+ handle: Some(handle),
175
+ }
176
+ }
177
+ }
178
+
179
+ pub(crate) fn drive_async_generator(
180
+ api: &PythonApi,
181
+ async_gen: *mut PyObject,
182
+ sender: &Sender<StreamItem>,
183
+ cancel: &Receiver<()>,
184
+ ) {
185
+ if async_gen.is_null() {
186
+ let _ = sender.send(StreamItem::Error("Async generator is null".to_string()));
187
+ return;
188
+ }
189
+
190
+ let Some(_async_gen_guard) = PyGuard::new(async_gen, api) else {
191
+ let _ = sender.send(StreamItem::Error("Async generator is null".to_string()));
192
+ return;
193
+ };
194
+
195
+ let asyncio = match api.import_module("asyncio") {
196
+ Ok(obj) => {
197
+ let Some(guard) = PyGuard::new(obj, api) else {
198
+ let _ = sender.send(StreamItem::Error("Failed to import asyncio".to_string()));
199
+ return;
200
+ };
201
+ guard
202
+ }
203
+ Err(err) => {
204
+ let _ = sender.send(StreamItem::Error(err.to_string()));
205
+ return;
206
+ }
207
+ };
208
+ let Some(new_loop_fun) = PyGuard::new(
209
+ api.object_get_attr_string(asyncio.ptr(), "new_event_loop"),
210
+ api,
211
+ ) else {
212
+ let _ = sender.send(StreamItem::Error(
213
+ "Failed to get asyncio.new_event_loop".to_string(),
214
+ ));
215
+ if api.has_error() {
216
+ api.clear_error();
217
+ }
218
+ return;
219
+ };
220
+ let Some(event_loop) = PyGuard::new(api.object_call_no_args(new_loop_fun.ptr()), api) else {
221
+ let _ = sender.send(StreamItem::Error("Failed to create event loop".to_string()));
222
+ if api.has_error() {
223
+ api.clear_error();
224
+ }
225
+ return;
226
+ };
227
+ let Some(run_fn) = PyGuard::new(
228
+ api.object_get_attr_string(event_loop.ptr(), "run_until_complete"),
229
+ api,
230
+ ) else {
231
+ let _ = sender.send(StreamItem::Error(
232
+ "Failed to get event_loop.run_until_complete".to_string(),
233
+ ));
234
+ if api.has_error() {
235
+ api.clear_error();
236
+ }
237
+ return;
238
+ };
239
+ let Some(anext_method) = PyGuard::new(api.object_get_attr_string(async_gen, "__anext__"), api)
240
+ else {
241
+ let _ = sender.send(StreamItem::Error(
242
+ "Failed to get __anext__ method".to_string(),
243
+ ));
244
+ if api.has_error() {
245
+ api.clear_error();
246
+ }
247
+ return;
248
+ };
249
+ loop {
250
+ if cancel.try_recv().is_ok() {
251
+ break;
252
+ }
253
+
254
+ let coroutine = api.object_call_no_args(anext_method.ptr());
255
+ if coroutine.is_null() {
256
+ if let Some(exc) = PythonApi::extract_exception(api) {
257
+ if is_stop_async_iteration(&exc) {
258
+ let _ = sender.send(StreamItem::End);
259
+ } else {
260
+ let _ = sender.send(StreamItem::Error(exc.to_string()));
261
+ }
262
+ } else {
263
+ let _ = sender.send(StreamItem::Error("__anext__() failed".into()));
264
+ if api.has_error() {
265
+ api.clear_error();
266
+ }
267
+ }
268
+ break;
269
+ }
270
+
271
+ let args_tuple = api.tuple_new(1);
272
+ if args_tuple.is_null() {
273
+ api.decref(coroutine);
274
+ let _ = sender.send(StreamItem::Error(
275
+ "Failed to allocate argument tuple".to_string(),
276
+ ));
277
+ if api.has_error() {
278
+ api.clear_error();
279
+ }
280
+ break;
281
+ }
282
+
283
+ api.incref(coroutine);
284
+ if api.tuple_set_item(args_tuple, 0, coroutine) != 0 {
285
+ api.decref(args_tuple);
286
+ api.decref(coroutine);
287
+ api.decref(coroutine);
288
+ let _ = sender.send(StreamItem::Error("Failed to set tuple item".to_string()));
289
+ if api.has_error() {
290
+ api.clear_error();
291
+ }
292
+ break;
293
+ }
294
+
295
+ let result = api.object_call(run_fn.ptr(), args_tuple, std::ptr::null_mut());
296
+ api.decref(args_tuple);
297
+ api.decref(coroutine);
298
+
299
+ if result.is_null() {
300
+ if let Some(exc) = PythonApi::extract_exception(api) {
301
+ if is_stop_async_iteration(&exc) {
302
+ let _ = sender.send(StreamItem::End);
303
+ } else {
304
+ let _ = sender.send(StreamItem::Error(exc.to_string()));
305
+ }
306
+ } else if api.has_error() {
307
+ api.clear_error();
308
+ } else {
309
+ let _ = sender.send(StreamItem::Error(
310
+ "run_until_complete failed without Python exception".to_string(),
311
+ ));
312
+ }
313
+ break;
314
+ }
315
+
316
+ match catch_unwind(AssertUnwindSafe(|| python_to_sendable(result, api))) {
317
+ Ok(Ok(val)) => {
318
+ api.decref(result);
319
+ if sender.send(StreamItem::Value(val)).is_err() {
320
+ break;
321
+ }
322
+ }
323
+ Ok(Err(err_msg)) => {
324
+ api.decref(result);
325
+ let _ = sender.send(StreamItem::Error(format!(
326
+ "Cannot convert Python value to Ruby: {err_msg}"
327
+ )));
328
+ break;
329
+ }
330
+ Err(_) => {
331
+ api.decref(result);
332
+ let _ = sender.send(StreamItem::Error(
333
+ "Cannot convert Python value to Ruby".to_string(),
334
+ ));
335
+ break;
336
+ }
337
+ }
338
+ }
339
+
340
+ if let Some(close_fn) = PyGuard::new(api.object_get_attr_string(event_loop.ptr(), "close"), api)
341
+ {
342
+ let close_result = api.object_call_no_args(close_fn.ptr());
343
+ if !close_result.is_null() {
344
+ drop(PyGuard::new(close_result, api));
345
+ } else if api.has_error() {
346
+ api.clear_error();
347
+ }
348
+ } else if api.has_error() {
349
+ api.clear_error();
350
+ }
351
+ }
352
+
353
+ fn is_stop_async_iteration(exc: &PythonException) -> bool {
354
+ matches!(
355
+ exc,
356
+ PythonException::Exception {
357
+ kind,
358
+ message: _,
359
+ traceback: _,
360
+ } if kind == "StopAsyncIteration"
361
+ )
362
+ }
363
+
364
#[cfg(test)]
impl AsyncGeneratorStream {
    /// Test-only constructor that feeds the stream from a plain channel.
    ///
    /// Each `Some(value)` received on `rx` is forwarded as a stream value. A
    /// `None`, or `rx` disconnecting, ends the stream with `StreamItem::End`.
    /// If the stream's own receiver has been dropped, forwarding stops without
    /// an `End`.
    pub(crate) fn from_channel(
        rx: Receiver<Option<crate::stream::SendableValue>>,
        cancel_tx: Sender<()>,
    ) -> Self {
        let (value_tx, value_rx) = unbounded();
        let worker = thread::spawn(move || {
            while let Ok(Some(val)) = rx.recv() {
                if value_tx.send(StreamItem::Value(val)).is_err() {
                    return;
                }
            }
            value_tx.send(StreamItem::End).ok();
        });
        Self {
            receiver: value_rx,
            cancel_sender: cancel_tx,
            handle: Some(worker),
        }
    }
}
396
+
397
+ #[cfg(test)]
398
+ mod tests {
399
+ use super::*;
400
+ use crate::stream::{SendableValue, StreamItem};
401
+ use crate::test_helpers::skip_if_no_python;
402
+ use crossbeam_channel::bounded;
403
+ use serial_test::serial;
404
+
405
+ const PY_EVAL_INPUT: i64 = 258;
406
+ const PY_FILE_INPUT: i64 = 257;
407
+
408
+ fn make_globals(api: &PythonApi) -> *mut PyObject {
409
+ let globals = api.dict_new();
410
+ let builtins_key = api.string_from_str("__builtins__");
411
+ let builtins = api
412
+ .import_module("builtins")
413
+ .expect("builtins should import");
414
+ api.dict_set_item(globals, builtins_key, builtins);
415
+ api.decref(builtins_key);
416
+ api.decref(builtins);
417
+ globals
418
+ }
419
+
420
+ fn run_file(api: &PythonApi, globals: *mut PyObject, code: &str) {
421
+ let result = api
422
+ .run_string(code, PY_FILE_INPUT, globals, globals)
423
+ .expect("python file input should succeed");
424
+ if !result.is_null() {
425
+ api.decref(result);
426
+ }
427
+ }
428
+
429
+ fn eval_obj(api: &PythonApi, globals: *mut PyObject, code: &str) -> *mut PyObject {
430
+ let result = api
431
+ .run_string(code, PY_EVAL_INPUT, globals, globals)
432
+ .expect("python eval should succeed");
433
+ assert!(!result.is_null());
434
+ result
435
+ }
436
+
437
+ fn restore_new_event_loop(api: &PythonApi, globals: *mut PyObject) {
438
+ run_file(
439
+ api,
440
+ globals,
441
+ r#"
442
+ import asyncio
443
+ if "_saved_new_event_loop" in globals():
444
+ asyncio.new_event_loop = _saved_new_event_loop
445
+ del _saved_new_event_loop
446
+ "#,
447
+ );
448
+ if api.has_error() {
449
+ api.clear_error();
450
+ }
451
+ }
452
+
453
+ fn cleanup_globals(api: &PythonApi, globals: *mut PyObject) {
454
+ if api.has_error() {
455
+ api.clear_error();
456
+ }
457
+ api.decref(globals);
458
+ }
459
+
460
+ fn assert_single_error_contains(items: &[StreamItem], needle: &str) {
461
+ assert_eq!(items.len(), 1, "expected one stream item");
462
+ match &items[0] {
463
+ StreamItem::Error(msg) => {
464
+ assert!(
465
+ msg.contains(needle),
466
+ "expected error message to contain '{needle}', got '{msg}'"
467
+ );
468
+ }
469
+ _ => panic!("expected StreamItem::Error"),
470
+ }
471
+ }
472
+
473
+ fn assert_values_then_end(items: &[StreamItem], expected: &[i64]) {
474
+ assert_eq!(items.len(), expected.len() + 1, "unexpected stream length");
475
+ for (idx, expected_num) in expected.iter().enumerate() {
476
+ match &items[idx] {
477
+ StreamItem::Value(SendableValue::Integer(actual)) => {
478
+ assert_eq!(*actual, *expected_num, "unexpected value at index {idx}");
479
+ }
480
+ _ => panic!("expected integer value item"),
481
+ }
482
+ }
483
+ match items.last() {
484
+ Some(StreamItem::End) => {}
485
+ _ => panic!("expected StreamItem::End as last item"),
486
+ }
487
+ }
488
+
489
+ #[test]
490
+ #[serial]
491
+ fn test_drive_async_generator_null_input() {
492
+ let Some(guard) = skip_if_no_python() else {
493
+ return;
494
+ };
495
+ let api = guard.api();
496
+
497
+ let (value_tx, value_rx) = unbounded();
498
+ let (_cancel_tx, cancel_rx) = bounded(1);
499
+
500
+ drive_async_generator(api, std::ptr::null_mut(), &value_tx, &cancel_rx);
501
+
502
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
503
+ assert_single_error_contains(&items, "Async generator is null");
504
+ }
505
+
506
+ #[test]
507
+ #[serial]
508
+ fn test_drive_async_generator_yields_values_then_end() {
509
+ let Some(guard) = skip_if_no_python() else {
510
+ return;
511
+ };
512
+ let api = guard.api();
513
+ let globals = make_globals(api);
514
+
515
+ run_file(
516
+ api,
517
+ globals,
518
+ r#"
519
+ import asyncio
520
+ _saved_new_event_loop = asyncio.new_event_loop
521
+
522
+ class _DriveLoop:
523
+ def run_until_complete(self, awaitable):
524
+ return awaitable()
525
+ def close(self):
526
+ pass
527
+
528
+ asyncio.new_event_loop = _DriveLoop
529
+
530
+ class _FakeAgen:
531
+ def __init__(self, values):
532
+ self._values = list(values)
533
+ self._idx = 0
534
+ def __anext__(self):
535
+ if self._idx >= len(self._values):
536
+ def _raise_stop():
537
+ raise StopAsyncIteration
538
+ return _raise_stop
539
+ value = self._values[self._idx]
540
+ self._idx += 1
541
+ return lambda value=value: value
542
+ "#,
543
+ );
544
+
545
+ let async_gen = eval_obj(api, globals, "_FakeAgen([0, 1, 2])");
546
+ let (value_tx, value_rx) = unbounded();
547
+ let (_cancel_tx, cancel_rx) = bounded(1);
548
+
549
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
550
+
551
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
552
+ restore_new_event_loop(api, globals);
553
+ assert_values_then_end(&items, &[0, 1, 2]);
554
+
555
+ cleanup_globals(api, globals);
556
+ }
557
+
558
+ #[test]
559
+ #[serial]
560
+ fn test_drive_async_generator_propagates_async_error() {
561
+ let Some(guard) = skip_if_no_python() else {
562
+ return;
563
+ };
564
+ let api = guard.api();
565
+ let globals = make_globals(api);
566
+
567
+ run_file(
568
+ api,
569
+ globals,
570
+ r#"
571
+ import asyncio
572
+ _saved_new_event_loop = asyncio.new_event_loop
573
+
574
+ class _DriveLoop:
575
+ def run_until_complete(self, awaitable):
576
+ return awaitable()
577
+ def close(self):
578
+ pass
579
+
580
+ asyncio.new_event_loop = _DriveLoop
581
+
582
+ class _BoomAgen:
583
+ def __init__(self):
584
+ self._idx = 0
585
+ def __anext__(self):
586
+ if self._idx == 0:
587
+ self._idx += 1
588
+ return lambda: 1
589
+ def _raise_boom():
590
+ raise ValueError("async boom")
591
+ return _raise_boom
592
+ "#,
593
+ );
594
+
595
+ let async_gen = eval_obj(api, globals, "_BoomAgen()");
596
+ let (value_tx, value_rx) = unbounded();
597
+ let (_cancel_tx, cancel_rx) = bounded(1);
598
+
599
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
600
+
601
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
602
+ restore_new_event_loop(api, globals);
603
+ assert_eq!(items.len(), 2, "expected one value then one error");
604
+ match &items[0] {
605
+ StreamItem::Value(SendableValue::Integer(v)) => assert_eq!(*v, 1),
606
+ _ => panic!("expected first item to be integer value"),
607
+ }
608
+ match &items[1] {
609
+ StreamItem::Error(msg) => {
610
+ assert!(
611
+ msg.contains("ValueError") || msg.contains("async boom"),
612
+ "unexpected error message: {msg}"
613
+ );
614
+ }
615
+ _ => panic!("expected second item to be error"),
616
+ }
617
+
618
+ cleanup_globals(api, globals);
619
+ }
620
+
621
+ #[test]
622
+ #[serial]
623
+ fn test_drive_async_generator_rejects_non_async_object() {
624
+ let Some(guard) = skip_if_no_python() else {
625
+ return;
626
+ };
627
+ let api = guard.api();
628
+ let globals = make_globals(api);
629
+
630
+ let sync_iter = eval_obj(api, globals, "iter(range(3))");
631
+ let (value_tx, value_rx) = unbounded();
632
+ let (_cancel_tx, cancel_rx) = bounded(1);
633
+
634
+ drive_async_generator(api, sync_iter, &value_tx, &cancel_rx);
635
+
636
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
637
+ assert_single_error_contains(&items, "__anext__");
638
+
639
+ cleanup_globals(api, globals);
640
+ }
641
+
642
+ #[test]
643
+ #[serial]
644
+ fn test_drive_async_generator_handles_immediate_anext_failure() {
645
+ let Some(guard) = skip_if_no_python() else {
646
+ return;
647
+ };
648
+ let api = guard.api();
649
+ let globals = make_globals(api);
650
+
651
+ run_file(
652
+ api,
653
+ globals,
654
+ r#"
655
+ class BrokenAsync:
656
+ def __anext__(self):
657
+ raise RuntimeError("anext failed")
658
+ "#,
659
+ );
660
+
661
+ let broken_obj = eval_obj(api, globals, "BrokenAsync()");
662
+ let (value_tx, value_rx) = unbounded();
663
+ let (_cancel_tx, cancel_rx) = bounded(1);
664
+
665
+ drive_async_generator(api, broken_obj, &value_tx, &cancel_rx);
666
+
667
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
668
+ assert_single_error_contains(&items, "RuntimeError");
669
+
670
+ cleanup_globals(api, globals);
671
+ }
672
+
673
+ #[test]
674
+ #[serial]
675
+ fn test_drive_async_generator_respects_cancel_signal() {
676
+ let Some(guard) = skip_if_no_python() else {
677
+ return;
678
+ };
679
+ let api = guard.api();
680
+ let globals = make_globals(api);
681
+
682
+ run_file(
683
+ api,
684
+ globals,
685
+ r#"
686
+ import asyncio
687
+ _saved_new_event_loop = asyncio.new_event_loop
688
+
689
+ class _DriveLoop:
690
+ def run_until_complete(self, awaitable):
691
+ return awaitable()
692
+ def close(self):
693
+ pass
694
+
695
+ asyncio.new_event_loop = _DriveLoop
696
+
697
+ class _FakeAgen:
698
+ def __anext__(self):
699
+ return lambda: 1
700
+ "#,
701
+ );
702
+
703
+ let async_gen = eval_obj(api, globals, "_FakeAgen()");
704
+ let (value_tx, value_rx) = unbounded();
705
+ let (cancel_tx, cancel_rx) = bounded(1);
706
+ cancel_tx.send(()).expect("cancel signal should send");
707
+
708
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
709
+
710
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
711
+ restore_new_event_loop(api, globals);
712
+ assert!(items.is_empty(), "expected no output after cancellation");
713
+
714
+ cleanup_globals(api, globals);
715
+ }
716
+
717
+ #[test]
718
+ #[serial]
719
+ fn test_drive_async_generator_missing_new_event_loop_attr() {
720
+ let Some(guard) = skip_if_no_python() else {
721
+ return;
722
+ };
723
+ let api = guard.api();
724
+ let globals = make_globals(api);
725
+
726
+ run_file(
727
+ api,
728
+ globals,
729
+ r#"
730
+ import asyncio
731
+ _saved_new_event_loop = asyncio.new_event_loop
732
+ del asyncio.new_event_loop
733
+ "#,
734
+ );
735
+
736
+ let obj = eval_obj(api, globals, "iter(range(1))");
737
+ let (value_tx, value_rx) = unbounded();
738
+ let (_cancel_tx, cancel_rx) = bounded(1);
739
+ drive_async_generator(api, obj, &value_tx, &cancel_rx);
740
+
741
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
742
+ restore_new_event_loop(api, globals);
743
+ assert_single_error_contains(&items, "asyncio.new_event_loop");
744
+
745
+ cleanup_globals(api, globals);
746
+ }
747
+
748
+ #[test]
749
+ #[serial]
750
+ fn test_drive_async_generator_event_loop_creation_failure() {
751
+ let Some(guard) = skip_if_no_python() else {
752
+ return;
753
+ };
754
+ let api = guard.api();
755
+ let globals = make_globals(api);
756
+
757
+ run_file(
758
+ api,
759
+ globals,
760
+ r#"
761
+ import asyncio
762
+ _saved_new_event_loop = asyncio.new_event_loop
763
+ def _broken_new_event_loop():
764
+ raise RuntimeError("loop create failed")
765
+ asyncio.new_event_loop = _broken_new_event_loop
766
+ "#,
767
+ );
768
+
769
+ let obj = eval_obj(api, globals, "iter(range(1))");
770
+ let (value_tx, value_rx) = unbounded();
771
+ let (_cancel_tx, cancel_rx) = bounded(1);
772
+ drive_async_generator(api, obj, &value_tx, &cancel_rx);
773
+
774
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
775
+ restore_new_event_loop(api, globals);
776
+ assert_single_error_contains(&items, "Failed to create event loop");
777
+
778
+ cleanup_globals(api, globals);
779
+ }
780
+
781
+ #[test]
782
+ #[serial]
783
+ fn test_drive_async_generator_missing_run_until_complete() {
784
+ let Some(guard) = skip_if_no_python() else {
785
+ return;
786
+ };
787
+ let api = guard.api();
788
+ let globals = make_globals(api);
789
+
790
+ run_file(
791
+ api,
792
+ globals,
793
+ r#"
794
+ import asyncio
795
+ _saved_new_event_loop = asyncio.new_event_loop
796
+ class _NoRunLoop:
797
+ def close(self):
798
+ pass
799
+ asyncio.new_event_loop = lambda: _NoRunLoop()
800
+ "#,
801
+ );
802
+
803
+ let obj = eval_obj(api, globals, "iter(range(1))");
804
+ let (value_tx, value_rx) = unbounded();
805
+ let (_cancel_tx, cancel_rx) = bounded(1);
806
+ drive_async_generator(api, obj, &value_tx, &cancel_rx);
807
+
808
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
809
+ restore_new_event_loop(api, globals);
810
+ assert_single_error_contains(&items, "run_until_complete");
811
+
812
+ cleanup_globals(api, globals);
813
+ }
814
+
815
+ #[test]
816
+ #[serial]
817
+ fn test_drive_async_generator_conversion_error() {
818
+ let Some(guard) = skip_if_no_python() else {
819
+ return;
820
+ };
821
+ let api = guard.api();
822
+ let globals = make_globals(api);
823
+
824
+ run_file(
825
+ api,
826
+ globals,
827
+ r#"
828
+ import asyncio
829
+ _saved_new_event_loop = asyncio.new_event_loop
830
+
831
+ class _DriveLoop:
832
+ def run_until_complete(self, awaitable):
833
+ return awaitable()
834
+ def close(self):
835
+ pass
836
+
837
+ asyncio.new_event_loop = _DriveLoop
838
+
839
+ class _ObjAgen:
840
+ def __anext__(self):
841
+ return lambda: object()
842
+ "#,
843
+ );
844
+
845
+ let async_gen = eval_obj(api, globals, "_ObjAgen()");
846
+ let (value_tx, value_rx) = unbounded();
847
+ let (_cancel_tx, cancel_rx) = bounded(1);
848
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
849
+
850
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
851
+ restore_new_event_loop(api, globals);
852
+ assert_eq!(items.len(), 1, "expected a single conversion error");
853
+ match &items[0] {
854
+ StreamItem::Error(msg) => {
855
+ assert!(
856
+ msg.contains("Cannot convert") || msg.contains("convert Python value"),
857
+ "unexpected conversion error message: {msg}"
858
+ );
859
+ }
860
+ _ => panic!("expected error item for conversion failure"),
861
+ }
862
+
863
+ cleanup_globals(api, globals);
864
+ }
865
+
866
+ #[test]
867
+ #[serial]
868
+ fn test_drive_async_generator_close_failure_does_not_break_output() {
869
+ let Some(guard) = skip_if_no_python() else {
870
+ return;
871
+ };
872
+ let api = guard.api();
873
+ let globals = make_globals(api);
874
+
875
+ run_file(
876
+ api,
877
+ globals,
878
+ r#"
879
+ import asyncio
880
+ _saved_new_event_loop = asyncio.new_event_loop
881
+
882
+ class _CloseFailLoop:
883
+ def run_until_complete(self, coro):
884
+ return coro()
885
+ def close(self):
886
+ raise RuntimeError("close failed")
887
+
888
+ asyncio.new_event_loop = _CloseFailLoop
889
+
890
+ class _FakeAgen:
891
+ def __init__(self, values):
892
+ self._values = list(values)
893
+ self._idx = 0
894
+ def __anext__(self):
895
+ if self._idx >= len(self._values):
896
+ def _raise_stop():
897
+ raise StopAsyncIteration
898
+ return _raise_stop
899
+ value = self._values[self._idx]
900
+ self._idx += 1
901
+ return lambda value=value: value
902
+ "#,
903
+ );
904
+
905
+ let async_gen = eval_obj(api, globals, "_FakeAgen([0, 1])");
906
+ let (value_tx, value_rx) = unbounded();
907
+ let (_cancel_tx, cancel_rx) = bounded(1);
908
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
909
+
910
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
911
+ restore_new_event_loop(api, globals);
912
+ assert_values_then_end(&items, &[0, 1]);
913
+
914
+ cleanup_globals(api, globals);
915
+ }
916
+
917
+ #[test]
918
+ #[serial]
919
+ fn test_drive_async_generator_missing_close_attr_does_not_break_output() {
920
+ let Some(guard) = skip_if_no_python() else {
921
+ return;
922
+ };
923
+ let api = guard.api();
924
+ let globals = make_globals(api);
925
+
926
+ run_file(
927
+ api,
928
+ globals,
929
+ r#"
930
+ import asyncio
931
+ _saved_new_event_loop = asyncio.new_event_loop
932
+
933
+ class _NoCloseLoop:
934
+ def run_until_complete(self, coro):
935
+ return coro()
936
+
937
+ asyncio.new_event_loop = _NoCloseLoop
938
+
939
+ class _FakeAgen:
940
+ def __init__(self, values):
941
+ self._values = list(values)
942
+ self._idx = 0
943
+ def __anext__(self):
944
+ if self._idx >= len(self._values):
945
+ def _raise_stop():
946
+ raise StopAsyncIteration
947
+ return _raise_stop
948
+ value = self._values[self._idx]
949
+ self._idx += 1
950
+ return lambda value=value: value
951
+ "#,
952
+ );
953
+
954
+ let async_gen = eval_obj(api, globals, "_FakeAgen([0])");
955
+ let (value_tx, value_rx) = unbounded();
956
+ let (_cancel_tx, cancel_rx) = bounded(1);
957
+ drive_async_generator(api, async_gen, &value_tx, &cancel_rx);
958
+
959
+ let items: Vec<StreamItem> = value_rx.try_iter().collect();
960
+ restore_new_event_loop(api, globals);
961
+ assert_values_then_end(&items, &[0]);
962
+
963
+ cleanup_globals(api, globals);
964
+ }
965
+
966
+ // ── AsyncGeneratorStream integration tests ──────────────────────
967
+
968
+ fn collect_stream(stream: &AsyncGeneratorStream) -> Vec<StreamItem> {
969
+ let mut items = Vec::new();
970
+ let timeout = std::time::Duration::from_secs(10);
971
+ loop {
972
+ match stream.receiver.recv_timeout(timeout) {
973
+ Ok(item) => {
974
+ let done = matches!(&item, StreamItem::End | StreamItem::Error(_));
975
+ items.push(item);
976
+ if done {
977
+ break;
978
+ }
979
+ }
980
+ Err(crossbeam_channel::RecvTimeoutError::Timeout) => {
981
+ panic!("timeout waiting for stream item");
982
+ }
983
+ Err(crossbeam_channel::RecvTimeoutError::Disconnected) => {
984
+ break;
985
+ }
986
+ }
987
+ }
988
+ items
989
+ }
990
+
991
+ #[test]
992
+ #[serial]
993
+ fn test_stream_from_sync_iterator_yields_values() {
994
+ let Some(guard) = skip_if_no_python() else {
995
+ return;
996
+ };
997
+ let api = guard.api();
998
+ let globals = make_globals(api);
999
+ let py_iter = eval_obj(api, globals, "iter(range(3))");
1000
+ // Drop GIL so the worker thread spawned by from_sync_iterator can acquire it
1001
+ drop(guard);
1002
+
1003
+ let stream = AsyncGeneratorStream::from_sync_iterator(py_iter);
1004
+ let items = collect_stream(&stream);
1005
+ assert_values_then_end(&items, &[0, 1, 2]);
1006
+
1007
+ let gil = api.ensure_gil();
1008
+ cleanup_globals(api, globals);
1009
+ api.release_gil(gil);
1010
+ }
1011
+
1012
+ #[test]
1013
+ #[serial]
1014
+ fn test_stream_from_sync_iterator_empty() {
1015
+ let Some(guard) = skip_if_no_python() else {
1016
+ return;
1017
+ };
1018
+ let api = guard.api();
1019
+ let globals = make_globals(api);
1020
+ let py_iter = eval_obj(api, globals, "iter(range(0))");
1021
+ drop(guard);
1022
+
1023
+ let stream = AsyncGeneratorStream::from_sync_iterator(py_iter);
1024
+ let items = collect_stream(&stream);
1025
+ assert_eq!(items.len(), 1);
1026
+ assert!(matches!(items[0], StreamItem::End));
1027
+
1028
+ let gil = api.ensure_gil();
1029
+ cleanup_globals(api, globals);
1030
+ api.release_gil(gil);
1031
+ }
1032
+
1033
+ #[test]
1034
+ #[serial]
1035
+ fn test_stream_from_python_object_sync_path() {
1036
+ let Some(guard) = skip_if_no_python() else {
1037
+ return;
1038
+ };
1039
+ let api = guard.api();
1040
+ let globals = make_globals(api);
1041
+ // range(3) is iterable but not an async generator —
1042
+ // from_python_object should detect it as sync and use object_get_iter
1043
+ let py_obj = eval_obj(api, globals, "range(3)");
1044
+ drop(guard);
1045
+
1046
+ let stream = AsyncGeneratorStream::from_python_object(py_obj, AsyncStrategy::PythonAdapter)
1047
+ .expect("should create stream from sync iterable");
1048
+ let items = collect_stream(&stream);
1049
+ assert_values_then_end(&items, &[0, 1, 2]);
1050
+
1051
+ let gil = api.ensure_gil();
1052
+ cleanup_globals(api, globals);
1053
+ api.release_gil(gil);
1054
+ }
1055
+
1056
/// `from_python_object` rejects an object that supports no iteration protocol,
/// returning an error that mentions "not iterable".
#[test]
#[serial]
fn test_stream_from_python_object_non_iterable_returns_error() {
    let guard = match skip_if_no_python() {
        Some(guard) => guard,
        None => return,
    };
    let api = guard.api();
    let globals = make_globals(api);
    // A bare int is neither a sync iterable nor an async generator.
    let not_iterable = eval_obj(api, globals, "42");
    drop(guard);

    let result =
        AsyncGeneratorStream::from_python_object(not_iterable, AsyncStrategy::PythonAdapter);
    // `let … else` rather than `unwrap_err`, so the stream type needs no `Debug`.
    let Err(msg) = result else {
        panic!("expected error for non-iterable object");
    };
    assert!(
        msg.contains("not iterable"),
        "expected 'not iterable' error, got: {msg}"
    );

    let gil = api.ensure_gil();
    cleanup_globals(api, globals);
    api.release_gil(gil);
}
1082
+
1083
/// An exception raised from a sync iterator's `__next__` is surfaced as a
/// `StreamItem::Error` after the values produced before it.
#[test]
#[serial]
fn test_stream_from_sync_iterator_propagates_error() {
    let guard = match skip_if_no_python() {
        Some(guard) => guard,
        None => return,
    };
    let api = guard.api();
    let globals = make_globals(api);
    // `_ErrorIter` yields a single `1`, then raises ValueError on the next call.
    run_file(
        api,
        globals,
        r#"
class _ErrorIter:
    def __init__(self):
        self._idx = 0
    def __iter__(self):
        return self
    def __next__(self):
        if self._idx == 0:
            self._idx += 1
            return 1
        raise ValueError("sync boom")
"#,
    );
    let failing_iterator = eval_obj(api, globals, "_ErrorIter()");
    drop(guard);

    let stream = AsyncGeneratorStream::from_sync_iterator(failing_iterator);
    let items = collect_stream(&stream);
    assert_eq!(items.len(), 2, "expected one value then one error");

    // First item: the integer produced before the failure.
    let StreamItem::Value(SendableValue::Integer(first)) = &items[0] else {
        panic!("expected first item to be integer value");
    };
    assert_eq!(*first, 1);

    // Second item: the propagated Python exception.
    let StreamItem::Error(msg) = &items[1] else {
        panic!("expected second item to be error");
    };
    assert!(
        msg.contains("ValueError") || msg.contains("sync boom"),
        "unexpected error message: {msg}"
    );

    let gil = api.ensure_gil();
    cleanup_globals(api, globals);
    api.release_gil(gil);
}
1131
+
1132
/// `AsyncStrategy::RustDriving` over a fake async generator yields every value,
/// then `End`.
///
/// The Python prelude replaces `asyncio.new_event_loop` with `_DriveLoop`,
/// whose `run_until_complete` simply calls its argument. `_FakeAsyncGen.__anext__`
/// returns zero-argument callables instead of coroutines. Together these
/// exercise the Rust-driven path without a real event loop. The patch is
/// presumably undone by `restore_new_event_loop` via `_saved_new_event_loop`.
#[test]
#[serial]
fn test_stream_from_python_object_async_rust_driving() {
    let Some(guard) = skip_if_no_python() else {
        return;
    };
    let api = guard.api();
    let globals = make_globals(api);
    run_file(
        api,
        globals,
        r#"
import asyncio
_saved_new_event_loop = asyncio.new_event_loop

class _DriveLoop:
    def run_until_complete(self, awaitable):
        return awaitable()
    def close(self):
        pass

asyncio.new_event_loop = _DriveLoop

class _FakeAsyncGen:
    def __init__(self, values):
        self._values = list(values)
        self._idx = 0
    def __aiter__(self):
        return self
    def __anext__(self):
        if self._idx >= len(self._values):
            def _raise_stop():
                raise StopAsyncIteration
            return _raise_stop
        value = self._values[self._idx]
        self._idx += 1
        return lambda value=value: value
"#,
    );
    let async_gen = eval_obj(api, globals, "_FakeAsyncGen([0, 1, 2])");
    drop(guard);

    // `asyncio.new_event_loop` is now patched on the `asyncio` module itself,
    // i.e. for the whole interpreter. Run the assertions under `catch_unwind`
    // so the patch is reverted even when one fails. Otherwise every later
    // #[serial] test would inherit `_DriveLoop`. The stream is dropped (its
    // worker thread joined) at the end of the closure, before the GIL is
    // re-acquired for cleanup.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let stream =
            AsyncGeneratorStream::from_python_object(async_gen, AsyncStrategy::RustDriving)
                .expect("should create stream from async generator");
        let items = collect_stream(&stream);
        assert_values_then_end(&items, &[0, 1, 2]);
    }));

    let gil = api.ensure_gil();
    restore_new_event_loop(api, globals);
    cleanup_globals(api, globals);
    api.release_gil(gil);

    // Re-raise the original assertion failure only after cleanup has run.
    if let Err(payload) = outcome {
        std::panic::resume_unwind(payload);
    }
}
1185
+
1186
/// Under `AsyncStrategy::RustDriving`, an exception raised while driving
/// `__anext__` surfaces as a `StreamItem::Error` after the earlier values.
///
/// Uses the same stub-loop trick as the other RustDriving tests:
/// `asyncio.new_event_loop` is swapped for `_DriveLoop`, and `__anext__`
/// returns plain callables that the stub loop invokes. The first call yields
/// `1`; the second raises ValueError.
#[test]
#[serial]
fn test_stream_from_python_object_async_rust_driving_with_error() {
    let Some(guard) = skip_if_no_python() else {
        return;
    };
    let api = guard.api();
    let globals = make_globals(api);
    run_file(
        api,
        globals,
        r#"
import asyncio
_saved_new_event_loop = asyncio.new_event_loop

class _DriveLoop:
    def run_until_complete(self, awaitable):
        return awaitable()
    def close(self):
        pass

asyncio.new_event_loop = _DriveLoop

class _ErrorAsyncGen:
    def __init__(self):
        self._yielded = False
    def __aiter__(self):
        return self
    def __anext__(self):
        if not self._yielded:
            self._yielded = True
            return lambda: 1
        def _raise():
            raise ValueError("async gen error")
        return _raise
"#,
    );
    let async_gen = eval_obj(api, globals, "_ErrorAsyncGen()");
    drop(guard);

    // The `asyncio` module is patched interpreter-wide from here on. Assert
    // inside `catch_unwind` so the patch is always reverted and a failure here
    // cannot leak `_DriveLoop` into later #[serial] tests. The stream (and its
    // worker thread) is torn down when the closure returns.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let stream =
            AsyncGeneratorStream::from_python_object(async_gen, AsyncStrategy::RustDriving)
                .expect("should create stream");
        let items = collect_stream(&stream);
        assert_eq!(items.len(), 2, "expected one value then one error");
        match &items[0] {
            StreamItem::Value(SendableValue::Integer(v)) => assert_eq!(*v, 1),
            _ => panic!("expected first item to be integer value"),
        }
        match &items[1] {
            StreamItem::Error(msg) => {
                assert!(
                    msg.contains("ValueError") || msg.contains("async gen error"),
                    "unexpected error message: {msg}"
                );
            }
            _ => panic!("expected second item to be error"),
        }
    }));

    let gil = api.ensure_gil();
    restore_new_event_loop(api, globals);
    cleanup_globals(api, globals);
    api.release_gil(gil);

    // Surface the original failure only after the interpreter state is restored.
    if let Err(payload) = outcome {
        std::panic::resume_unwind(payload);
    }
}
1250
+
1251
/// Under `AsyncStrategy::RustDriving`, an async generator that stops
/// immediately produces only `End`.
///
/// `asyncio.new_event_loop` is swapped for the `_DriveLoop` stub. The first
/// `__anext__` returns a callable that raises `StopAsyncIteration` when the
/// stub loop invokes it.
#[test]
#[serial]
fn test_stream_from_python_object_async_rust_driving_empty() {
    let Some(guard) = skip_if_no_python() else {
        return;
    };
    let api = guard.api();
    let globals = make_globals(api);
    run_file(
        api,
        globals,
        r#"
import asyncio
_saved_new_event_loop = asyncio.new_event_loop

class _DriveLoop:
    def run_until_complete(self, awaitable):
        return awaitable()
    def close(self):
        pass

asyncio.new_event_loop = _DriveLoop

class _EmptyAsyncGen:
    def __aiter__(self):
        return self
    def __anext__(self):
        def _raise_stop():
            raise StopAsyncIteration
        return _raise_stop
"#,
    );
    let async_gen = eval_obj(api, globals, "_EmptyAsyncGen()");
    drop(guard);

    // The `asyncio` patch is interpreter-wide. Keep the assertions inside
    // `catch_unwind` so the restore below always runs, even on failure.
    // Dropping the stream at the end of the closure joins its worker before
    // the GIL is taken again.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let stream =
            AsyncGeneratorStream::from_python_object(async_gen, AsyncStrategy::RustDriving)
                .expect("should create stream from empty async generator");
        let items = collect_stream(&stream);
        assert_eq!(items.len(), 1);
        assert!(matches!(items[0], StreamItem::End));
    }));

    let gil = api.ensure_gil();
    restore_new_event_loop(api, globals);
    cleanup_globals(api, globals);
    api.release_gil(gil);

    // Re-raise any assertion failure now that cleanup is complete.
    if let Err(payload) = outcome {
        std::panic::resume_unwind(payload);
    }
}
1298
+ }