raccoonai 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of raccoonai might be problematic; see the details below.

Files changed (44)
  1. raccoonai/__init__.py +96 -0
  2. raccoonai/_base_client.py +2051 -0
  3. raccoonai/_client.py +473 -0
  4. raccoonai/_compat.py +219 -0
  5. raccoonai/_constants.py +14 -0
  6. raccoonai/_exceptions.py +108 -0
  7. raccoonai/_files.py +123 -0
  8. raccoonai/_models.py +795 -0
  9. raccoonai/_qs.py +150 -0
  10. raccoonai/_resource.py +43 -0
  11. raccoonai/_response.py +830 -0
  12. raccoonai/_streaming.py +333 -0
  13. raccoonai/_types.py +217 -0
  14. raccoonai/_utils/__init__.py +57 -0
  15. raccoonai/_utils/_logs.py +25 -0
  16. raccoonai/_utils/_proxy.py +62 -0
  17. raccoonai/_utils/_reflection.py +42 -0
  18. raccoonai/_utils/_streams.py +12 -0
  19. raccoonai/_utils/_sync.py +71 -0
  20. raccoonai/_utils/_transform.py +392 -0
  21. raccoonai/_utils/_typing.py +149 -0
  22. raccoonai/_utils/_utils.py +414 -0
  23. raccoonai/_version.py +4 -0
  24. raccoonai/lib/.keep +4 -0
  25. raccoonai/py.typed +0 -0
  26. raccoonai/resources/__init__.py +33 -0
  27. raccoonai/resources/fleet.py +485 -0
  28. raccoonai/resources/lam.py +1161 -0
  29. raccoonai/types/__init__.py +15 -0
  30. raccoonai/types/fleet_create_params.py +77 -0
  31. raccoonai/types/fleet_create_response.py +20 -0
  32. raccoonai/types/fleet_logs_response.py +14 -0
  33. raccoonai/types/fleet_status_response.py +17 -0
  34. raccoonai/types/fleet_terminate_response.py +17 -0
  35. raccoonai/types/lam_extract_params.py +51 -0
  36. raccoonai/types/lam_extract_response.py +28 -0
  37. raccoonai/types/lam_integration_run_params.py +35 -0
  38. raccoonai/types/lam_integration_run_response.py +47 -0
  39. raccoonai/types/lam_run_params.py +41 -0
  40. raccoonai/types/lam_run_response.py +21 -0
  41. raccoonai-0.1.0a1.dist-info/METADATA +422 -0
  42. raccoonai-0.1.0a1.dist-info/RECORD +44 -0
  43. raccoonai-0.1.0a1.dist-info/WHEEL +4 -0
  44. raccoonai-0.1.0a1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import Any, Callable
5
+
6
+
7
def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
    """Return True if `func` declares a parameter named `arg_name`."""
    signature = inspect.signature(func)
    return arg_name in signature.parameters
11
+
12
+
13
def assert_signatures_in_sync(
    source_func: Callable[..., Any],
    check_func: Callable[..., Any],
    *,
    exclude_params: set[str] = set(),
) -> None:
    """Ensure that the signature of the second function matches the first.

    Every parameter of `source_func` (except those named in `exclude_params`)
    must exist on `check_func` with an identical type annotation.

    Args:
        source_func: the reference function whose signature is authoritative.
        check_func: the function being validated against `source_func`.
        exclude_params: parameter names to skip entirely.

    Raises:
        AssertionError: listing every missing or mismatched parameter.
    """
    # NOTE: `exclude_params` is a mutable default but is only ever read here.
    check_sig = inspect.signature(check_func)
    source_sig = inspect.signature(source_func)

    errors: list[str] = []

    for name, source_param in source_sig.parameters.items():
        if name in exclude_params:
            continue

        custom_param = check_sig.parameters.get(name)
        if not custom_param:
            errors.append(f"the `{name}` param is missing")
            continue

        if custom_param.annotation != source_param.annotation:
            # fixed grammar: was "are do not match"
            errors.append(
                f"types for the `{name}` param do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}"
            )
            continue

    if errors:
        raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
@@ -0,0 +1,12 @@
1
+ from typing import Any
2
+ from typing_extensions import Iterator, AsyncIterator
3
+
4
+
5
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
    """Exhaust the given iterator, discarding every item."""
    for _item in iterator:
        pass
8
+
9
+
10
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
    """Exhaust the given async iterator, discarding every item."""
    async for _item in iterator:
        pass
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import asyncio
5
+ import functools
6
+ import contextvars
7
+ from typing import Any, TypeVar, Callable, Awaitable
8
+ from typing_extensions import ParamSpec
9
+
10
+ T_Retval = TypeVar("T_Retval")
11
+ T_ParamSpec = ParamSpec("T_ParamSpec")
12
+
13
+
14
# `asyncio.to_thread` was added in Python 3.9; on older interpreters we bind a
# local backport with the same semantics.
if sys.version_info >= (3, 9):
    to_thread = asyncio.to_thread
else:
    # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread
    # for Python 3.8 support
    async def to_thread(
        func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
    ) -> Any:
        """Asynchronously run function *func* in a separate thread.

        Any *args and **kwargs supplied for this function are directly passed
        to *func*. Also, the current :class:`contextvars.Context` is propagated,
        allowing context variables from the main thread to be accessed in the
        separate thread.

        Returns a coroutine that can be awaited to get the eventual result of *func*.
        """
        loop = asyncio.events.get_running_loop()
        # snapshot the caller's context so context variables resolve correctly
        # inside the worker thread
        ctx = contextvars.copy_context()
        func_call = functools.partial(ctx.run, func, *args, **kwargs)
        # None -> run on the loop's default ThreadPoolExecutor
        return await loop.run_in_executor(None, func_call)
35
+
36
+
37
+ # inspired by `asyncer`, https://github.com/tiangolo/asyncer
38
# inspired by `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
    """Wrap a blocking callable so that it can be awaited.

    The returned coroutine function accepts exactly the same positional and
    keyword arguments as `function` and executes it in a worker thread via
    `to_thread` (`asyncio.to_thread` on 3.9+, the local backport on 3.8),
    resolving to the function's return value.

    Usage:

    ```python
    def blocking_func(arg1, arg2, kwarg1=None):
        # blocking code
        return result


    result = await asyncify(blocking_func)(arg1, arg2, kwarg1=value1)
    ```

    ## Arguments

    `function`: a blocking regular callable (e.g. a function)

    ## Return

    An async function that takes the same positional and keyword arguments as the
    original one, that when called runs the same original function in a thread worker
    and returns the result.
    """

    async def _async_wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
        return await to_thread(function, *args, **kwargs)

    return _async_wrapper
@@ -0,0 +1,392 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import base64
5
+ import pathlib
6
+ from typing import Any, Mapping, TypeVar, cast
7
+ from datetime import date, datetime
8
+ from typing_extensions import Literal, get_args, override, get_type_hints
9
+
10
+ import anyio
11
+ import pydantic
12
+
13
+ from ._utils import (
14
+ is_list,
15
+ is_mapping,
16
+ is_iterable,
17
+ )
18
+ from .._files import is_base64_file_input
19
+ from ._typing import (
20
+ is_list_type,
21
+ is_union_type,
22
+ extract_type_arg,
23
+ is_iterable_type,
24
+ is_required_type,
25
+ is_annotated_type,
26
+ strip_annotated_type,
27
+ )
28
+ from .._compat import model_dump, is_typeddict
29
+
30
+ _T = TypeVar("_T")
31
+
32
+
33
+ # TODO: support for drilling globals() and locals()
34
+ # TODO: ensure works correctly with forward references in all cases
35
+
36
+
37
# The set of serialization formats `_format_data` / `_async_format_data` apply.
PropertyFormat = Literal["iso8601", "base64", "custom"]
38
+
39
+
40
class PropertyInfo:
    """Metadata class to be used in Annotated types to provide information about a given type.

    For example:

    class MyParams(TypedDict):
        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
    """

    # Wire-format name to use instead of the Python field name (see `_maybe_transform_key`).
    alias: str | None
    # Serialization format applied by `_format_data` / `_async_format_data`.
    format: PropertyFormat | None
    # `strftime` template used when `format` is "custom".
    format_template: str | None
    # NOTE(review): stored but never read in this module — presumably used to
    # discriminate union members elsewhere; confirm at usage sites.
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias = alias
        self.format = format
        self.format_template = format_template
        self.discriminator = discriminator

    @override
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
72
+
73
+
74
def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """None-tolerant variant of `transform()`.

    See `transform()` for more details.
    """
    return None if data is None else transform(data, expected_type)
85
+
86
+
87
# Wrapper over _transform_recursive providing fake types
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Keys / data without type information are passed through untouched.

    Note that the transformations performed here are not visible to the type
    system — the return value is simply cast back to the input type.
    """
    result = _transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, result)
109
+
110
+
111
def _get_annotated_type(type_: type) -> type | None:
    """Return `type_` if it is an `Annotated` type, otherwise `None`.

    Also unwraps `Required[Annotated[T, ...]]` down to the `Annotated` layer.
    """
    if is_required_type(type_):
        # peel off the `Required` wrapper to expose the inner annotation
        type_ = get_args(type_)[0]

    return type_ if is_annotated_type(type_) else None
124
+
125
+
126
def _maybe_transform_key(key: str, type_: type) -> str:
    """Return the wire-format name for `key` based on `type_`'s annotations.

    Only `Annotated` types carrying `PropertyInfo` metadata with an `alias`
    are considered; any other type leaves the key untouched.
    """
    annotated_type = _get_annotated_type(type_)
    if annotated_type is None:
        # not `Annotated`, so there is no alias metadata to apply
        return key

    # skip the first type argument — it is the wrapped type, not metadata
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.alias is not None:
            return metadata.alias

    return key
143
+
144
+
145
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        return _transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        # recurse into each element with the container's element type
        inner_type = extract_type_arg(stripped_type, 0)
        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # serialise models to plain JSON-compatible data, dropping unset fields
        return model_dump(data, exclude_unset=True, mode="json")

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    # NOTE: `annotation` is deliberately rebound by this loop; the parameter is
    # no longer needed past this point.
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return _format_data(data, annotation.format, annotation.format_template)

    return data
207
+
208
+
209
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Serialize `data` according to `format_`.

    Dates/datetimes honour "iso8601" and "custom" (a `strftime` template);
    file-like inputs honour "base64". Anything else is returned unchanged.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        if isinstance(data, pathlib.Path):
            raw: str | bytes | None = data.read_bytes()
        elif isinstance(data, io.IOBase):
            raw = data.read()
        else:
            raw = None

        if isinstance(raw, str):  # type: ignore[unreachable]
            # text-mode file objects yield `str`; encode before base64
            raw = raw.encode()

        if not isinstance(raw, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(raw)}")

        return base64.b64encode(raw).decode("ascii")

    return data
234
+
235
+
236
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform each entry of `data` using `expected_type`'s field annotations."""
    field_types = get_type_hints(expected_type, include_extras=True)
    transformed: dict[str, object] = {}
    for key, value in data.items():
        field_type = field_types.get(key)
        if field_type is None:
            # no annotation for this field — pass it through untouched
            transformed[key] = value
        else:
            transformed[_maybe_transform_key(key, field_type)] = _transform_recursive(value, annotation=field_type)
    return transformed
250
+
251
+
252
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """None-tolerant variant of `async_transform()`.

    See `async_transform()` for more details.
    """
    return None if data is None else await async_transform(data, expected_type)
263
+
264
+
265
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Async counterpart of `transform()`, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = await async_transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Keys / data without type information are passed through untouched.

    Note that the transformations performed here are not visible to the type
    system — the return value is simply cast back to the input type.
    """
    result = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, result)
286
+
287
+
288
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Async mirror of `_transform_recursive`.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        # recurse into each element with the container's element type
        inner_type = extract_type_arg(stripped_type, 0)
        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # serialise models to plain JSON-compatible data, dropping unset fields
        return model_dump(data, exclude_unset=True, mode="json")

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    # NOTE: `annotation` is deliberately rebound by this loop; the parameter is
    # no longer needed past this point.
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            return await _async_format_data(data, annotation.format, annotation.format_template)

    return data
350
+
351
+
352
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Serialize `data` according to `format_`.

    Mirrors `_format_data` but reads file paths asynchronously via `anyio`.
    Values that match no known format are returned unchanged.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()

        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    if format_ == "base64" and is_base64_file_input(data):
        binary: str | bytes | None = None

        if isinstance(data, pathlib.Path):
            # non-blocking read of the file contents
            binary = await anyio.Path(data).read_bytes()
        elif isinstance(data, io.IOBase):
            binary = data.read()

        if isinstance(binary, str):  # type: ignore[unreachable]
            # text-mode file objects yield `str`; encode before base64
            binary = binary.encode()

        if not isinstance(binary, bytes):
            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

        return base64.b64encode(binary).decode("ascii")

    return data
377
+
378
+
379
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform each entry of `data` using `expected_type`'s field annotations."""
    field_types = get_type_hints(expected_type, include_extras=True)
    transformed: dict[str, object] = {}
    for key, value in data.items():
        field_type = field_types.get(key)
        if field_type is None:
            # no annotation for this field — pass it through untouched
            transformed[key] = value
        else:
            transformed[_maybe_transform_key(key, field_type)] = await _async_transform_recursive(
                value, annotation=field_type
            )
    return transformed
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+ import typing
5
+ import typing_extensions
6
+ from typing import Any, TypeVar, Iterable, cast
7
+ from collections import abc as _c_abc
8
+ from typing_extensions import (
9
+ TypeIs,
10
+ Required,
11
+ Annotated,
12
+ get_args,
13
+ get_origin,
14
+ )
15
+
16
+ from .._types import InheritsGeneric
17
+ from .._compat import is_union as _is_union
18
+
19
+
20
def is_annotated_type(typ: type) -> bool:
    """Whether `typ` is an `Annotated[...]` special form."""
    origin = get_origin(typ)
    return origin == Annotated
22
+
23
+
24
def is_list_type(typ: type) -> bool:
    """Whether `typ` is `list` itself or a parameterized `list`/`List` alias."""
    origin = get_origin(typ)
    if origin is None:
        origin = typ
    return origin == list
26
+
27
+
28
def is_iterable_type(typ: type) -> bool:
    """If the given type is `typing.Iterable[T]` (or `collections.abc.Iterable`)"""
    base = get_origin(typ) or typ
    return base == Iterable or base == _c_abc.Iterable
32
+
33
+
34
def is_union_type(typ: type) -> bool:
    # Delegates to the `.._compat` shim so that both `Union[X, Y]` and
    # PEP 604 `X | Y` origins are recognised consistently.
    return _is_union(get_origin(typ))
36
+
37
+
38
+ def is_required_type(typ: type) -> bool:
39
+ return get_origin(typ) == Required
40
+
41
+
42
def is_typevar(typ: type) -> bool:
    """Whether `typ` is itself a `TypeVar` instance."""
    # compare the exact runtime class; type checkers assume this expression is
    # always False, hence the ignore
    runtime_cls = type(typ)
    return runtime_cls == TypeVar  # type: ignore
46
+
47
+
48
# Runtime classes that can represent a `type X = ...` alias. The
# `typing_extensions` backport is always available; the stdlib
# `typing.TypeAliasType` only exists on 3.12+ so it is appended conditionally.
_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
51
+
52
+
53
def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
    """Return whether the provided argument is an instance of `TypeAliasType`.

    ```python
    type Int = int
    is_type_alias_type(Int)
    # > True
    Str = TypeAliasType("Str", str)
    is_type_alias_type(Str)
    # > True
    ```
    """
    # `_TYPE_ALIAS_TYPES` covers both the stdlib 3.12+ class and the
    # typing_extensions backport
    return isinstance(tp, _TYPE_ALIAS_TYPES)
66
+
67
+
68
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
    """Iteratively unwrap `Required`/`Annotated` layers down to the bare type."""
    while is_required_type(typ) or is_annotated_type(typ):
        typ = cast(type, get_args(typ)[0])
    return typ
74
+
75
+
76
def extract_type_arg(typ: type, index: int) -> type:
    """Return the `index`-th type argument of `typ`.

    Raises:
        RuntimeError: if `typ` has no type argument at that position.
    """
    type_args = get_args(typ)
    try:
        return cast(type, type_args[index])
    except IndexError as err:
        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
82
+
83
+
84
def extract_type_var_from_base(
    typ: type,
    *,
    generic_bases: tuple[type, ...],
    index: int,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Foo[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(Foo[bytes]):
        ...

    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
    ```

    And where a generic subclass is given:
    ```py
    _T = TypeVar('_T')
    class MyResponse(Foo[_T]):
        ...

    extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
    ```

    Raises:
        RuntimeError: if the type variable cannot be resolved.
    """
    cls = cast(object, get_origin(typ) or typ)
    if cls in generic_bases:
        # we're given the class directly
        return extract_type_arg(typ, index)

    # if a subclass is given
    # ---
    # this is needed as __orig_bases__ is not present in the typeshed stubs
    # because it is intended to be for internal use only, however there does
    # not seem to be a way to resolve generic TypeVars for inherited subclasses
    # without using it.
    if isinstance(cls, InheritsGeneric):
        # find the parameterised base (e.g. `Foo[bytes]`) among the original bases
        target_base_class: Any | None = None
        for base in cls.__orig_bases__:
            if base.__origin__ in generic_bases:
                target_base_class = base
                break

        if target_base_class is None:
            raise RuntimeError(
                "Could not find the generic base class;\n"
                "This should never happen;\n"
                f"Does {cls} inherit from one of {generic_bases} ?"
            )

        extracted = extract_type_arg(target_base_class, index)
        if is_typevar(extracted):
            # If the extracted type argument is itself a type variable
            # then that means the subclass itself is generic, so we have
            # to resolve the type argument from the class itself, not
            # the base class.
            #
            # Note: if there is more than 1 type argument, the subclass could
            # change the ordering of the type arguments, this is not currently
            # supported.
            return extract_type_arg(typ, index)

        return extracted

    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")