rustest 0.14.0__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rustest/decorators.py ADDED
@@ -0,0 +1,968 @@
1
+ """User facing decorators mirroring the most common pytest helpers."""
2
+
3
from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from typing import Any, NoReturn, ParamSpec, TypeVar, cast, overload
7
+
8
+ P = ParamSpec("P")
9
+ R = TypeVar("R")
10
+ Q = ParamSpec("Q")
11
+ S = TypeVar("S")
12
+ TFunc = TypeVar("TFunc", bound=Callable[..., Any])
13
+
14
+ # Valid fixture scopes
15
+ VALID_SCOPES = frozenset(["function", "class", "module", "package", "session"])
16
+
17
+
18
class ParameterSet:
    """One explicit parameter set, as produced by ``pytest.param()``.

    Holds the positional ``values`` for a single parametrized case together
    with an optional ``id`` and ``marks`` metadata.
    """

    def __init__(self, values: tuple[Any, ...], id: str | None = None, marks: Any = None):
        super().__init__()
        # Read by the parametrization helpers when building cases.
        self.values, self.id = values, id
        # Stored for future support; nothing in this module reads it yet.
        self.marks = marks

    def __repr__(self) -> str:
        return "ParameterSet(values={!r}, id={!r})".format(self.values, self.id)
33
+
34
+
35
@overload
def fixture(
    func: Callable[P, R],
    *,
    scope: str = "function",
    autouse: bool = False,
    name: str | None = None,
    params: Sequence[Any] | None = None,
    ids: Sequence[str] | Callable[[Any], str | None] | None = None,
) -> Callable[P, R]: ...


@overload
def fixture(
    *,
    scope: str = "function",
    autouse: bool = False,
    name: str | None = None,
    params: Sequence[Any] | None = None,
    ids: Sequence[str] | Callable[[Any], str | None] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


def fixture(
    func: Callable[P, R] | None = None,
    *,
    scope: str = "function",
    autouse: bool = False,
    name: str | None = None,
    params: Sequence[Any] | None = None,
    ids: Sequence[str] | Callable[[Any], str | None] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """Register a function as a rustest fixture.

    Works bare (``@fixture``) or with options (``@fixture(scope="module")``).
    The decorated function itself is returned; only ``__rustest_fixture*``
    attributes are attached to it.

    Args:
        func: The function being decorated when used without parentheses.
        scope: Lifetime of the fixture value. One of:
            - "function": New instance for each test function (default)
            - "class": Shared across all test methods in a class
            - "module": Shared across all tests in a module
            - "package": Shared across all tests in a package
            - "session": Shared across all tests in the session
        autouse: If True, the fixture is used by every test in its scope
            without being requested explicitly (default: False).
        name: Register the fixture under this name instead of the function name.
        params: Optional parameter values. The fixture runs once per value and
            tests using it run once per value; read the current one through
            ``request.param``.
        ids: Optional IDs for ``params``: a sequence with one entry per value,
            or a callable mapping a value to an ID (a ``None`` result means
            auto-generate). Without it, IDs are generated from the values.

    Returns:
        The decorated function, or a decorator when ``func`` is None.

    Raises:
        ValueError: Immediately, if ``scope`` is not a valid scope; or when the
            decorator is applied, if an ``ids`` sequence does not have one
            entry per parameter value.

    Usage:
        @fixture
        def my_fixture():
            return 42

        @fixture(scope="module")
        def shared_fixture():
            return expensive_setup()

        @fixture(autouse=True)
        def setup_fixture():
            # This fixture will run automatically before each test
            setup_environment()

        @fixture(name="db")
        def _database_fixture():
            # This fixture is available as "db", not "_database_fixture"
            return Database()

        @fixture(params=[1, 2, 3])
        def number(request):
            # This fixture will provide values 1, 2, 3 to tests
            return request.param

        @fixture(params=["mysql", "postgres"], ids=["MySQL", "PostgreSQL"])
        def database(request):
            # Tests will run with both database types
            return create_db(request.param)
    """
    if scope not in VALID_SCOPES:
        allowed = ", ".join(sorted(VALID_SCOPES))
        raise ValueError(f"Invalid fixture scope '{scope}'. Must be one of: {allowed}")

    def register(target: Callable[P, R]) -> Callable[P, R]:
        for attr, attr_value in (
            ("__rustest_fixture__", True),
            ("__rustest_fixture_scope__", scope),
            ("__rustest_fixture_autouse__", autouse),
        ):
            setattr(target, attr, attr_value)
        if name is not None:
            setattr(target, "__rustest_fixture_name__", name)
        if params is not None:
            # Parameter cases (with their IDs) are resolved per decorated function.
            setattr(target, "__rustest_fixture_params__", _build_fixture_params(params, ids))
        return target

    # Bare ``@fixture`` passes the function directly; ``@fixture(...)`` does not.
    return register if func is None else register(func)
139
+
140
+
141
+ def _build_fixture_params(
142
+ params: Sequence[Any],
143
+ ids: Sequence[str] | Callable[[Any], str | None] | None,
144
+ ) -> list[dict[str, Any]]:
145
+ """Build fixture parameter cases with IDs.
146
+
147
+ Args:
148
+ params: The parameter values
149
+ ids: Optional IDs for each parameter value
150
+
151
+ Returns:
152
+ A list of dicts with 'id' and 'value' keys
153
+ """
154
+ cases: list[dict[str, Any]] = []
155
+ ids_is_callable = callable(ids)
156
+
157
+ if ids is not None and not ids_is_callable:
158
+ if len(ids) != len(params):
159
+ msg = "ids must match the number of params"
160
+ raise ValueError(msg)
161
+
162
+ for index, param_value in enumerate(params):
163
+ # Handle ParameterSet objects (from pytest.param())
164
+ param_set_id: str | None = None
165
+ actual_value: Any = param_value
166
+ if isinstance(param_value, ParameterSet):
167
+ param_set_id = param_value.id
168
+ # For fixture params, we expect a single value
169
+ actual_value = (
170
+ param_value.values[0] if len(param_value.values) == 1 else param_value.values
171
+ )
172
+
173
+ # Generate case ID
174
+ # Priority: ParameterSet id > ids parameter > auto-generated
175
+ if param_set_id is not None:
176
+ case_id = param_set_id
177
+ elif ids is None:
178
+ # Auto-generate ID based on value representation
179
+ case_id = _generate_param_id(actual_value, index)
180
+ elif ids_is_callable:
181
+ generated_id = ids(actual_value)
182
+ case_id = (
183
+ str(generated_id)
184
+ if generated_id is not None
185
+ else _generate_param_id(actual_value, index)
186
+ )
187
+ else:
188
+ case_id = ids[index]
189
+
190
+ cases.append({"id": case_id, "value": actual_value})
191
+
192
+ return cases
193
+
194
+
195
+ def _generate_param_id(value: Any, index: int) -> str:
196
+ """Generate a readable ID for a parameter value.
197
+
198
+ Args:
199
+ value: The parameter value
200
+ index: The index of the parameter
201
+
202
+ Returns:
203
+ A string ID for the parameter
204
+ """
205
+ # Try to generate a readable ID from the value
206
+ if value is None:
207
+ return "None"
208
+ if isinstance(value, bool):
209
+ return str(value)
210
+ if isinstance(value, (int, float)):
211
+ return str(value)
212
+ if isinstance(value, str):
213
+ # Truncate long strings
214
+ if len(value) <= 20:
215
+ return value
216
+ return f"{value[:17]}..."
217
+ if isinstance(value, (list, tuple)):
218
+ seq_value = cast(list[Any] | tuple[Any, ...], value)
219
+ if len(seq_value) == 0:
220
+ return "empty"
221
+ # Try to create a short representation
222
+ items = [_generate_param_id(v, 0) for v in seq_value[:3]]
223
+ result = "-".join(items)
224
+ if len(seq_value) > 3:
225
+ result += f"-...({len(seq_value)})"
226
+ return result
227
+ if isinstance(value, dict):
228
+ dict_value = cast(dict[Any, Any], value)
229
+ if len(dict_value) == 0:
230
+ return "empty_dict"
231
+ return f"dict({len(dict_value)})"
232
+
233
+ # Fallback to index-based ID
234
+ return f"param{index}"
235
+
236
+
237
def skip_decorator(reason: str | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that flags a test or fixture as skipped.

    This is the decorator form (e.g. ``@skip(reason="...")`` or via
    ``@mark.skip``). The function form, which raises ``Skipped`` at runtime,
    is ``skip()``.

    Args:
        reason: Why the item is skipped. A falsy reason (None or "") is
            replaced by a default message.

    Returns:
        A decorator that stores the reason on ``__rustest_skip__`` and returns
        the decorated callable unchanged.
    """
    message = reason if reason else "skipped via rustest.skip"

    def mark_skipped(target: Callable[P, R]) -> Callable[P, R]:
        setattr(target, "__rustest_skip__", message)
        return target

    return mark_skipped
249
+
250
+
251
def parametrize(
    arg_names: str | Sequence[str],
    values: Sequence[Sequence[object] | Mapping[str, object] | ParameterSet] | None = None,
    *,
    argvalues: Sequence[Sequence[object] | Mapping[str, object] | ParameterSet] | None = None,
    ids: Sequence[str] | Callable[[Any], str | None] | None = None,
    indirect: bool | Sequence[str] | str = False,
) -> Callable[[Callable[Q, S]], Callable[Q, S]]:
    """Parametrise a test function.

    Args:
        arg_names: Parameter name(s), as a comma-separated string or a sequence.
        values: Parameter values for each test case (rustest style).
        argvalues: Parameter values for each test case (pytest style). When
            both are given, ``argvalues`` wins.
        ids: Test IDs, either a sequence of strings or a callable.
        indirect: Which parameters are resolved as fixtures:
            - False (default): all parameters are direct values
            - True: every parameter is passed to the fixture of the same name
            - ["param1", "param2"]: only the listed parameters
            - "param1": a single parameter
            An indirect parameter's value is treated as a fixture name, and
            that fixture's value is used for the test.

    Returns:
        A decorator that stores the built cases on
        ``__rustest_parametrization__`` (and indirect names, if any, on
        ``__rustest_parametrization_indirect__``).

    Raises:
        TypeError: If neither ``values`` nor ``argvalues`` is provided.
        ValueError: For invalid argument names or indirect names (raised
            immediately), or values that do not match the names (raised when
            the decorator is applied).

    Example:
        @fixture
        def my_data():
            return {"value": 42}

        @parametrize("data", ["my_data"], indirect=True)
        def test_example(data):
            assert data["value"] == 42
    """
    if argvalues is not None:
        case_values = argvalues
    elif values is not None:
        case_values = values
    else:
        raise TypeError("parametrize() requires either 'values' or 'argvalues' parameter")

    names = _normalize_arg_names(arg_names)
    indirect_names = _normalize_indirect(indirect, names)

    def apply(test_func: Callable[Q, S]) -> Callable[Q, S]:
        # Cases are built per decorated function.
        setattr(test_func, "__rustest_parametrization__", _build_cases(names, case_values, ids))
        if indirect_names:
            setattr(test_func, "__rustest_parametrization_indirect__", indirect_names)
        return test_func

    return apply
301
+
302
+
303
+ def _normalize_arg_names(arg_names: str | Sequence[str]) -> tuple[str, ...]:
304
+ if isinstance(arg_names, str):
305
+ parts = [part.strip() for part in arg_names.split(",") if part.strip()]
306
+ if not parts:
307
+ msg = "parametrize() expected at least one argument name"
308
+ raise ValueError(msg)
309
+ return tuple(parts)
310
+ return tuple(arg_names)
311
+
312
+
313
+ def _normalize_indirect(
314
+ indirect: bool | Sequence[str] | str, param_names: tuple[str, ...]
315
+ ) -> list[str]:
316
+ """Normalize the indirect parameter to a list of parameter names.
317
+
318
+ Args:
319
+ indirect: The indirect value from parametrize
320
+ param_names: All parameter names from the parametrization
321
+
322
+ Returns:
323
+ A list of parameter names that should be treated as indirect (fixture references)
324
+
325
+ Raises:
326
+ ValueError: If an indirect parameter name is not in param_names
327
+ """
328
+ if indirect is False:
329
+ return []
330
+ if indirect is True:
331
+ return list(param_names)
332
+ if isinstance(indirect, str):
333
+ if indirect not in param_names:
334
+ msg = f"indirect parameter '{indirect}' not found in parametrize argument names {param_names}"
335
+ raise ValueError(msg)
336
+ return [indirect]
337
+ # It's a sequence of strings
338
+ indirect_list = list(indirect)
339
+ for param in indirect_list:
340
+ if param not in param_names:
341
+ msg = f"indirect parameter '{param}' not found in parametrize argument names {param_names}"
342
+ raise ValueError(msg)
343
+ return indirect_list
344
+
345
+
346
+ def _build_cases(
347
+ names: tuple[str, ...],
348
+ values: Sequence[Sequence[object] | Mapping[str, object] | ParameterSet],
349
+ ids: Sequence[str] | Callable[[Any], str | None] | None,
350
+ ) -> tuple[dict[str, object], ...]:
351
+ case_payloads: list[dict[str, object]] = []
352
+
353
+ # Handle callable ids (e.g., ids=str)
354
+ ids_is_callable = callable(ids)
355
+
356
+ if ids is not None and not ids_is_callable:
357
+ if len(ids) != len(values):
358
+ msg = "ids must match the number of value sets"
359
+ raise ValueError(msg)
360
+
361
+ for index, case in enumerate(values):
362
+ # Handle ParameterSet objects (from pytest.param())
363
+ param_set_id: str | None = None
364
+ actual_case: Any = case
365
+ if isinstance(case, ParameterSet):
366
+ param_set_id = case.id
367
+ actual_case = case.values # Extract the actual values
368
+ # If it's a single value tuple, unwrap it for consistency
369
+ if len(actual_case) == 1:
370
+ actual_case = actual_case[0]
371
+
372
+ # Mappings are only treated as parameter mappings when there are multiple parameters
373
+ # For single parameters, dicts/mappings are treated as values
374
+ data: dict[str, Any]
375
+ if isinstance(actual_case, Mapping) and len(names) > 1:
376
+ data = {name: actual_case[name] for name in names}
377
+ elif isinstance(actual_case, (tuple, list)):
378
+ seq_case = cast(tuple[Any, ...] | list[Any], actual_case)
379
+ if len(seq_case) == len(names):
380
+ # Tuples and lists are unpacked to match parameter names (pytest convention)
381
+ # This handles both single and multiple parameters
382
+ data = {name: seq_case[pos] for pos, name in enumerate(names)}
383
+ else:
384
+ # Length mismatch
385
+ if len(names) == 1:
386
+ data = {names[0]: actual_case}
387
+ else:
388
+ raise ValueError("Parametrized value does not match argument names")
389
+ else:
390
+ # Everything else is treated as a single value
391
+ # This includes: primitives, dicts (single param), objects
392
+ if len(names) == 1:
393
+ data = {names[0]: actual_case}
394
+ else:
395
+ raise ValueError("Parametrized value does not match argument names")
396
+
397
+ # Generate case ID
398
+ # Priority: ParameterSet id > ids parameter > auto-generated
399
+ if param_set_id is not None:
400
+ case_id = param_set_id
401
+ elif ids is None:
402
+ case_id = f"case_{index}"
403
+ elif ids_is_callable:
404
+ # Call the function on the case value to get the ID
405
+ generated_id = ids(actual_case)
406
+ case_id = str(generated_id) if generated_id is not None else f"case_{index}"
407
+ else:
408
+ case_id = ids[index]
409
+
410
+ case_payloads.append({"id": case_id, "values": data})
411
+ return tuple(case_payloads)
412
+
413
+
414
class MarkDecorator:
    """A decorator that records a named mark on a test function or class.

    Applying it appends ``{"name", "args", "kwargs"}`` to the target's
    ``__rustest_marks__`` list and returns the target unchanged.
    """

    def __init__(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
        super().__init__()
        self.name = name
        self.args = args
        self.kwargs = kwargs

    def __call__(self, func: TFunc) -> TFunc:
        """Apply this mark to ``func`` and return ``func``."""
        # Copy instead of mutating in place: getattr() also finds a list
        # inherited from a base class, or one shared through a __dict__ copied
        # by functools.wraps(). Appending to that list would leak this mark
        # onto those other objects.
        marks: list[dict[str, Any]] = list(getattr(func, "__rustest_marks__", []))
        marks.append(
            {
                "name": self.name,
                "args": self.args,
                "kwargs": self.kwargs,
            }
        )
        setattr(func, "__rustest_marks__", marks)
        return func

    def __repr__(self) -> str:
        return f"Mark({self.name!r}, {self.args!r}, {self.kwargs!r})"
442
+
443
+
444
class MarkGenerator:
    """Namespace for dynamically creating marks like pytest.mark.

    Usage:
        @mark.slow
        @mark.integration
        @mark.timeout(seconds=30)

    Standard marks:
        @mark.skipif(condition, *, reason="...")
        @mark.xfail(condition=None, *, reason=None, raises=None, run=True, strict=False)
        @mark.usefixtures("fixture1", "fixture2")
        @mark.asyncio(loop_scope="function")

    Any other public attribute produces a generic mark of that name, and
    ``mark.parametrize`` delegates to ``parametrize()``. Attribute names that
    start with an underscore are never marks and raise ``AttributeError``.
    """

    def asyncio(
        self,
        func: Callable[..., Any] | None = None,
        *,
        loop_scope: str = "function",
    ) -> Callable[..., Any]:
        """Mark an async test function to be executed with asyncio.

        Async test functions decorated with this are run to completion in an
        asyncio event loop when called. The loop_scope parameter is validated
        and recorded on the mark.

        Args:
            func: The function to decorate (when used without parentheses)
            loop_scope: The scope of the event loop. One of:
                - "function": New loop for each test function (default)
                - "class": Shared loop across all test methods in a class
                - "module": Shared loop across all tests in a module
                - "session": Shared loop across all tests in the session

        Usage:
            @mark.asyncio
            async def test_async_function():
                result = await some_async_operation()
                assert result == expected

            @mark.asyncio(loop_scope="module")
            async def test_with_module_loop():
                await another_async_operation()

        Note:
            This decorator works best with async functions (coroutines), which
            are wrapped to run in an asyncio event loop. For pytest
            compatibility, it can also be applied to regular functions (the
            mark is recorded but the function runs normally without asyncio).
            Applied to a class, it marks the class and wraps each coroutine
            method. Currently every call of a wrapped coroutine gets its own
            new event loop, whatever ``loop_scope`` says.

        Raises:
            ValueError: If ``loop_scope`` is not a valid scope.
        """
        import asyncio
        import inspect
        from functools import wraps

        valid_scopes = {"function", "class", "module", "session"}
        if loop_scope not in valid_scopes:
            valid = ", ".join(sorted(valid_scopes))
            msg = f"Invalid loop_scope '{loop_scope}'. Must be one of: {valid}"
            raise ValueError(msg)

        def _wrap_async_function(f: Callable[..., Any], loop_scope: str) -> Callable[..., Any]:
            """Wrap coroutine function ``f`` so calling it runs it to completion."""

            @wraps(f)  # also sets sync_wrapper.__wrapped__ = f
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                # For now a fresh loop is created per call; scope handling is
                # a planned future enhancement.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    # Call ``f`` itself rather than ``f.__wrapped__``. When ``f``
                    # is an async wrapper produced by another functools.wraps-based
                    # decorator, unwrapping would bypass that decorator entirely.
                    return loop.run_until_complete(f(*args, **kwargs))
                finally:
                    try:
                        # Cancel leftover tasks and let them process cancellation.
                        pending = asyncio.all_tasks(loop)
                        for task in pending:
                            task.cancel()
                        if pending:
                            loop.run_until_complete(
                                asyncio.gather(*pending, return_exceptions=True)
                            )
                        # Finalize async generators the test left suspended.
                        loop.run_until_complete(loop.shutdown_asyncgens())
                    except Exception:
                        # Best-effort cleanup must not mask the test's own outcome.
                        pass
                    finally:
                        # Don't leave a closed loop installed as this thread's
                        # current event loop (asyncio.run() resets it the same way).
                        asyncio.set_event_loop(None)
                        loop.close()

            return sync_wrapper

        def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
            mark_decorator = MarkDecorator("asyncio", (), {"loop_scope": loop_scope})

            # Class decoration: mark the class, then wrap every coroutine method.
            if inspect.isclass(f):
                marked_class = mark_decorator(f)
                for attr_name, method in inspect.getmembers(
                    marked_class, predicate=inspect.iscoroutinefunction
                ):
                    setattr(marked_class, attr_name, _wrap_async_function(method, loop_scope))
                return marked_class

            # For pytest compatibility, sync functions are only marked, not wrapped.
            if not inspect.iscoroutinefunction(f):
                return mark_decorator(f)

            # Record the mark, then wrap the coroutine to run synchronously.
            return _wrap_async_function(mark_decorator(f), loop_scope)

        # Support both @mark.asyncio and @mark.asyncio(loop_scope="...")
        if func is not None:
            return decorator(func)
        return decorator

    def skipif(
        self,
        condition: bool | str,
        reason: str | None = None,
        *,
        _kw_reason: str | None = None,
    ) -> MarkDecorator:
        """Skip test if condition is true.

        Args:
            condition: Boolean or string condition to evaluate
            reason: Explanation for why the test is skipped (positional or keyword)

        Usage:
            # Both forms are supported (pytest compatibility):
            @mark.skipif(sys.platform == "win32", reason="Not supported on Windows")
            @mark.skipif(sys.platform == "win32", "Not supported on Windows")
            def test_unix_only():
                pass
        """
        # Support both positional and keyword-only 'reason' for pytest compatibility
        # Some older pytest code uses: skipif(condition, reason) with positional
        # Modern pytest uses: skipif(condition, reason="...") with keyword-only
        actual_reason = _kw_reason if _kw_reason is not None else reason
        return MarkDecorator("skipif", (condition,), {"reason": actual_reason})

    def xfail(
        self,
        condition: bool | str | None = None,
        *,
        reason: str | None = None,
        raises: type[BaseException] | tuple[type[BaseException], ...] | None = None,
        run: bool = True,
        strict: bool = False,
    ) -> MarkDecorator:
        """Mark test as expected to fail.

        Args:
            condition: Optional condition - if False, mark is ignored
            reason: Explanation for why the test is expected to fail
            raises: Expected exception type(s)
            run: Whether to run the test (False means skip it)
            strict: If True, passing test will fail the suite

        Usage:
            @mark.xfail(reason="Known bug in backend")
            def test_known_bug():
                assert False

            @mark.xfail(sys.platform == "win32", reason="Not implemented on Windows")
            def test_feature():
                pass
        """
        kwargs = {
            "reason": reason,
            "raises": raises,
            "run": run,
            "strict": strict,
        }
        args = () if condition is None else (condition,)
        return MarkDecorator("xfail", args, kwargs)

    def usefixtures(self, *names: str) -> MarkDecorator:
        """Use fixtures without explicitly requesting them as parameters.

        Args:
            *names: Names of fixtures to use

        Usage:
            @mark.usefixtures("setup_db", "cleanup")
            def test_with_fixtures():
                pass
        """
        return MarkDecorator("usefixtures", names, {})

    def __getattr__(self, name: str) -> Any:
        """Create a mark decorator for the given name.

        Raises:
            AttributeError: For names starting with ``_``. Those are never
                marks, and fabricating a factory for them would make
                introspection probes such as ``hasattr(mark, "__wrapped__")``
                wrongly succeed.
        """
        if name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        # Return a callable that can be used as @mark.name or @mark.name(args)
        if name == "parametrize":
            return self._create_parametrize_mark()
        return self._create_mark(name)

    def _create_mark(self, name: str) -> Any:
        """Create a MarkDecorator that can be called with or without arguments."""

        class _MarkDecoratorFactory:
            """Factory that allows @mark.name or @mark.name(args)."""

            def __init__(self, mark_name: str) -> None:
                super().__init__()
                self.mark_name = mark_name

            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                # If called with a single argument that's a function, it's @mark.name
                if (
                    len(args) == 1
                    and not kwargs
                    and callable(args[0])
                    and hasattr(args[0], "__name__")
                ):
                    decorator = MarkDecorator(self.mark_name, (), {})
                    return decorator(args[0])
                # Otherwise it's @mark.name(args) - return a decorator
                return MarkDecorator(self.mark_name, args, kwargs)

        return _MarkDecoratorFactory(name)

    def _create_parametrize_mark(self) -> Callable[..., Any]:
        """Create a decorator matching top-level parametrize behaviour."""

        def _parametrize_mark(*args: Any, **kwargs: Any) -> Any:
            if len(args) == 1 and callable(args[0]) and not kwargs:
                msg = "@mark.parametrize must be called with arguments"
                raise TypeError(msg)
            return parametrize(*args, **kwargs)

        return _parametrize_mark


# Shared MarkGenerator instance, used as ``@mark.<name>`` (like pytest.mark).
mark = MarkGenerator()
697
+
698
+
699
class ExceptionInfo:
    """Details of an exception captured by ``raises()``.

    Attributes:
        type: The exception class
        value: The exception instance
        traceback: The traceback object handed to ``__exit__``
    """

    def __init__(
        self, exc_type: type[BaseException], exc_value: BaseException, exc_tb: Any
    ) -> None:
        super().__init__()
        self.type, self.value, self.traceback = exc_type, exc_value, exc_tb

    def __repr__(self) -> str:
        type_name = self.type.__name__
        return "<ExceptionInfo {}({!r})>".format(type_name, self.value)
718
+
719
+
720
class RaisesContext:
    """Context manager for asserting that code raises a specific exception.

    This mimics pytest.raises() behavior, supporting:
    - Single or tuple of exception types
    - Optional regex matching of exception messages
    - Access to caught exception information (``value``, ``type``,
      ``typename`` and ``match()``)

    Usage:
        with raises(ValueError):
            int("not a number")

        with raises(ValueError, match="invalid literal"):
            int("not a number")

        with raises((ValueError, TypeError)):
            some_function()

        # Access the caught exception
        with raises(ValueError) as exc_info:
            raise ValueError("oops")
        assert "oops" in str(exc_info.value)
        exc_info.match("oo")
    """

    def __init__(
        self,
        exc_type: type[BaseException] | tuple[type[BaseException], ...],
        *,
        match: str | None = None,
    ) -> None:
        """Create the context.

        Raises:
            TypeError: If ``exc_type`` is not an exception class or a tuple of
                exception classes.
            ValueError: If ``exc_type`` is an empty tuple.
        """
        super().__init__()
        # Validate eagerly. An invalid type would otherwise only surface in
        # __exit__ as a TypeError from issubclass(), replacing the exception
        # raised by the ``with`` body.
        self._validate_expected(exc_type)
        self.exc_type = exc_type
        self.match_pattern = match
        self.excinfo: ExceptionInfo | None = None

    @staticmethod
    def _validate_expected(exc_type: object) -> None:
        """Reject anything that is not an exception class (or tuple of them)."""
        candidates = exc_type if isinstance(exc_type, tuple) else (exc_type,)
        if not candidates:
            msg = "expected an exception type or a non-empty tuple of exception types, got ()"
            raise ValueError(msg)
        for candidate in candidates:
            if not (isinstance(candidate, type) and issubclass(candidate, BaseException)):
                msg = (
                    "expected exception must be a BaseException type, "
                    f"not {type(candidate).__name__}"
                )
                raise TypeError(msg)

    def __enter__(self) -> RaisesContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> bool:
        # No exception was raised
        if exc_type is None:
            exc_name = self._format_exc_name()
            msg = f"DID NOT RAISE {exc_name}"
            raise AssertionError(msg)

        # At this point, we know an exception was raised, so exc_val cannot be None
        assert exc_val is not None, "exc_val must not be None when exc_type is not None"

        # Check if the exception type matches
        if not issubclass(exc_type, self.exc_type):
            # Unexpected exception type - let it propagate
            return False

        # Store the exception information
        self.excinfo = ExceptionInfo(exc_type, exc_val, exc_tb)

        # Check if the message matches the pattern (if provided)
        if self.match_pattern is not None:
            import re

            exc_message = str(exc_val)
            if not re.search(self.match_pattern, exc_message):
                msg = (
                    f"Pattern {self.match_pattern!r} does not match "
                    f"{exc_message!r}. Exception: {exc_type.__name__}: {exc_message}"
                )
                raise AssertionError(msg)

        # Suppress the exception (it was expected)
        return True

    def _format_exc_name(self) -> str:
        """Format the expected exception name(s) for error messages."""
        if isinstance(self.exc_type, tuple):
            names = " or ".join(exc.__name__ for exc in self.exc_type)
            return names
        return self.exc_type.__name__

    def match(self, regexp: str) -> bool:
        """Assert that the caught exception's message matches ``regexp``.

        Uses ``re.search`` against ``str(exception)``, like the ``match=``
        argument. Mirrors pytest's ``ExceptionInfo.match``.

        Returns:
            True when the pattern matches.

        Raises:
            AttributeError: If no exception has been caught.
            AssertionError: If the pattern does not match.
        """
        import re

        exc_message = str(self.value)
        if not re.search(regexp, exc_message):
            msg = f"Pattern {regexp!r} does not match {exc_message!r}"
            raise AssertionError(msg)
        return True

    @property
    def value(self) -> BaseException:
        """Access the caught exception value."""
        if self.excinfo is None:
            msg = "No exception was caught"
            raise AttributeError(msg)
        return self.excinfo.value

    @property
    def type(self) -> type[BaseException]:
        """Access the caught exception type."""
        if self.excinfo is None:
            msg = "No exception was caught"
            raise AttributeError(msg)
        return self.excinfo.type

    @property
    def typename(self) -> str:
        """Name of the caught exception's class (e.g. ``"ValueError"``)."""
        return self.type.__name__
818
+
819
+
820
def raises(
    exc_type: type[BaseException] | tuple[type[BaseException], ...],
    *,
    match: str | None = None,
) -> RaisesContext:
    """Return a context manager asserting that its body raises ``exc_type``.

    Args:
        exc_type: The expected exception class, or a tuple of classes (any of
            which is accepted).
        match: Optional regular expression searched for in ``str(exception)``.

    Returns:
        A ``RaisesContext``. It swallows a matching exception and exposes it
        via ``.value`` / ``.type``.

    Raises:
        AssertionError: On exit, if nothing was raised or the message does not
            match ``match``. Exceptions of other types propagate unchanged.

    Usage:
        with raises(ValueError):
            int("not a number")

        with raises(ValueError, match="invalid literal"):
            int("not a number")

        with raises((ValueError, TypeError)):
            some_function()

        # Access the caught exception
        with raises(ValueError) as exc_info:
            raise ValueError("oops")
        assert "oops" in str(exc_info.value)
    """
    context = RaisesContext(exc_type, match=match)
    return context
853
+
854
+
855
class Failed(Exception):
    """Raised by ``fail()`` to mark the current test as failed."""
859
+
860
+
861
def fail(reason: str = "", pytrace: bool = True) -> NoReturn:
    """Explicitly fail the current test with the given message.

    This function immediately raises an exception to fail the test,
    similar to pytest.fail(). It's useful for conditional test failures
    where a simple assert is not sufficient. It never returns (hence the
    ``NoReturn`` annotation, so type checkers treat following code as
    unreachable).

    Args:
        reason: The failure message to display
        pytrace: Accepted for pytest compatibility; not used here.

    Raises:
        Failed: Always raised to fail the test

    Usage:
        def test_validation():
            data = load_data()
            if not is_valid(data):
                fail("Data validation failed")

        def test_conditional():
            if some_condition:
                fail("Condition should not be true")
            assert something_else

        # With detailed message
        def test_complex():
            result = complex_operation()
            if result.status == "error":
                fail(f"Operation failed: {result.error_message}")
    """
    # pytest's convention for omitting this helper frame from tracebacks.
    __tracebackhide__ = True
    raise Failed(reason)
895
+
896
+
897
class Skipped(Exception):
    """Raised by ``skip()`` to skip the current test at runtime."""
901
+
902
+
903
def skip(reason: str = "", allow_module_level: bool = False) -> NoReturn:
    """Skip the current test or module dynamically.

    This function raises an exception to skip the test at runtime,
    similar to pytest.skip(). It's useful for conditional test skipping
    based on runtime conditions. It never returns (hence ``NoReturn``).

    Args:
        reason: The reason why the test is being skipped
        allow_module_level: Accepted for pytest compatibility; not used here
            (module-level skipping is not fully implemented in rustest).

    Raises:
        Skipped: Always raised to skip the test

    Usage:
        def test_requires_linux():
            import sys
            if sys.platform != "linux":
                skip("Only runs on Linux")
            # Test code here

        def test_conditional_skip():
            import subprocess
            result = subprocess.run(["which", "docker"], capture_output=True)
            if result.returncode != 0:
                skip("Docker not available")
            # Docker tests here
    """
    # pytest's convention for omitting this helper frame from tracebacks.
    __tracebackhide__ = True
    raise Skipped(reason)
934
+
935
+
936
class XFailed(Exception):
    """Raised by ``xfail()`` to mark the current test as an expected failure."""
940
+
941
+
942
def xfail(reason: str = "") -> NoReturn:
    """Mark the current test as expected to fail dynamically.

    This function raises an exception to mark the test as an expected failure
    at runtime, similar to pytest.xfail(). Its failure won't count against the
    test suite. It never returns (hence ``NoReturn``), so code after the call
    in the test does not run.

    Args:
        reason: The reason why the test is expected to fail

    Raises:
        XFailed: Always raised to mark the test as xfail

    Usage:
        def test_known_bug():
            import sys
            if sys.version_info < (3, 11):
                xfail("Known bug in Python < 3.11")
            # Test code that fails on older Python

        def test_experimental_feature():
            if not feature_complete():
                xfail("Feature not yet complete")
            # Test code here
    """
    # pytest's convention for omitting this helper frame from tracebacks.
    __tracebackhide__ = True
    raise XFailed(reason)