vibe-client 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vibe_client/__init__.py +3 -0
- vibe_client/_private/__init__.py +1 -0
- vibe_client/_private/schemas.py +352 -0
- vibe_client/client.py +498 -0
- vibe_client/config.py +74 -0
- vibe_client/deserialize.py +115 -0
- vibe_client/models.py +710 -0
- vibe_client/serialize.py +53 -0
- vibe_client-0.1.0.dist-info/METADATA +128 -0
- vibe_client-0.1.0.dist-info/RECORD +11 -0
- vibe_client-0.1.0.dist-info/WHEEL +4 -0
vibe_client/models.py
ADDED
@@ -0,0 +1,710 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
# Code generated by smithy-python-codegen DO NOT EDIT.
|
3
|
+
|
4
|
+
from dataclasses import dataclass
|
5
|
+
import logging
|
6
|
+
from typing import Any, ClassVar, Literal, Self, Union
|
7
|
+
|
8
|
+
from smithy_core.deserializers import ShapeDeserializer
|
9
|
+
from smithy_core.documents import TypeRegistry
|
10
|
+
from smithy_core.exceptions import SmithyException
|
11
|
+
from smithy_core.schemas import APIOperation, Schema
|
12
|
+
from smithy_core.serializers import ShapeSerializer
|
13
|
+
from smithy_core.shapes import ShapeID
|
14
|
+
|
15
|
+
from ._private.schemas import (
|
16
|
+
ACTION_VALUE as _SCHEMA_ACTION_VALUE,
|
17
|
+
OBSERVATION_VALUE as _SCHEMA_OBSERVATION_VALUE,
|
18
|
+
QUERY_AGENT as _SCHEMA_QUERY_AGENT,
|
19
|
+
QUERY_AGENT_INPUT as _SCHEMA_QUERY_AGENT_INPUT,
|
20
|
+
QUERY_AGENT_OUTPUT as _SCHEMA_QUERY_AGENT_OUTPUT,
|
21
|
+
UNAUTHORIZED_EXCEPTION as _SCHEMA_UNAUTHORIZED_EXCEPTION,
|
22
|
+
VALIDATION_EXCEPTION as _SCHEMA_VALIDATION_EXCEPTION,
|
23
|
+
VALIDATION_EXCEPTION_FIELD as _SCHEMA_VALIDATION_EXCEPTION_FIELD,
|
24
|
+
VIBE_VALIDATION_EXCEPTION as _SCHEMA_VIBE_VALIDATION_EXCEPTION,
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
class ServiceError(SmithyException):
    """Root of the exception hierarchy defined for this service.

    ``ApiError`` and, through it, every concrete error class in this module
    derive from this type, so catching ``ServiceError`` catches all of them.
    """
|
34
|
+
|
35
|
+
@dataclass
class ApiError(ServiceError):
    """Common base class for errors that represent an API-level failure.

    ``code`` and ``fault`` are class-level constants that subclasses assign;
    an instance itself only carries ``message``.

    :param message: Text describing the error.
    """

    # Declared without a value here; concrete subclasses set the error code.
    code: ClassVar[str]
    # Fault classification; restricted to "client" or "server".
    fault: ClassVar[Literal["client", "server"]]

    message: str

    def __post_init__(self) -> None:
        # The dataclass-generated ``__init__`` only assigns fields, so the
        # base-class initializer is invoked here with the message.
        super().__init__(self.message)
|
45
|
+
|
46
|
+
@dataclass
class UnknownApiError(ApiError):
    """Error that stands in for any API error not otherwise recognised.

    Its ``code`` is the fixed string ``"Unknown"`` and its ``fault`` is
    ``"client"``.
    """

    code: ClassVar[str] = "Unknown"
    fault: ClassVar[Literal["client", "server"]] = "client"
|
51
|
+
|
52
|
+
@dataclass(kw_only=True)
|
53
|
+
class ValidationExceptionField:
|
54
|
+
"""
|
55
|
+
Describes one specific validation failure for an input member.
|
56
|
+
|
57
|
+
:param path:
|
58
|
+
**[Required]** - A JSONPointer expression to the structure member whose value
|
59
|
+
failed to satisfy the modeled constraints.
|
60
|
+
|
61
|
+
:param message:
|
62
|
+
**[Required]** - A detailed description of the validation failure.
|
63
|
+
|
64
|
+
"""
|
65
|
+
|
66
|
+
path: str
|
67
|
+
|
68
|
+
message: str
|
69
|
+
|
70
|
+
def serialize(self, serializer: ShapeSerializer):
|
71
|
+
serializer.write_struct(_SCHEMA_VALIDATION_EXCEPTION_FIELD, self)
|
72
|
+
|
73
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
74
|
+
serializer.write_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["path"], self.path)
|
75
|
+
serializer.write_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["message"], self.message)
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
79
|
+
return cls(**cls.deserialize_kwargs(deserializer))
|
80
|
+
|
81
|
+
@classmethod
|
82
|
+
def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
|
83
|
+
kwargs: dict[str, Any] = {}
|
84
|
+
|
85
|
+
def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
|
86
|
+
match schema.expect_member_index():
|
87
|
+
case 0:
|
88
|
+
kwargs["path"] = de.read_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["path"])
|
89
|
+
|
90
|
+
case 1:
|
91
|
+
kwargs["message"] = de.read_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["message"])
|
92
|
+
|
93
|
+
case _:
|
94
|
+
logger.debug("Unexpected member schema: %s", schema)
|
95
|
+
|
96
|
+
deserializer.read_struct(_SCHEMA_VALIDATION_EXCEPTION_FIELD, consumer=_consumer)
|
97
|
+
return kwargs
|
98
|
+
|
99
|
+
def _serialize_validation_exception_field_list(serializer: ShapeSerializer, schema: Schema, value: list[ValidationExceptionField]) -> None:
    """Write ``value`` as a list whose elements are structs.

    A list of ``len(value)`` elements is opened on ``serializer`` and every
    element is written with the ``member`` schema of ``schema``.
    """
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_serializer:
        for field in value:
            list_serializer.write_struct(element_schema, field)
|
104
|
+
|
105
|
+
def _deserialize_validation_exception_field_list(deserializer: ShapeDeserializer, schema: Schema) -> list[ValidationExceptionField]:
    """Read a list of ``ValidationExceptionField`` structs.

    Null elements are consumed and left out of the returned list.
    """
    fields: list[ValidationExceptionField] = []

    def _collect(element: ShapeDeserializer) -> None:
        if element.is_null():
            # Consume the null so reading can continue, but do not store it.
            element.read_null()
            return
        fields.append(ValidationExceptionField.deserialize(element))

    deserializer.read_list(schema, _collect)
    return fields
|
115
|
+
|
116
|
+
@dataclass(kw_only=True)
class ValidationException(ApiError):
    """
    Standard error for input validation failures, intended to be raised by
    services when a member of the input structure is outside the modeled or
    documented constraints.

    :param message: A message associated with the specific error.

    :param field_list:
        The individual failures found while validating the input. The same
        member may be listed several times when it violates several constraints.

    """

    code: ClassVar[str] = "ValidationException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str
    field_list: list[ValidationExceptionField] | None = None

    def serialize(self, serializer: ShapeSerializer):
        """Write this error as a struct using its schema."""
        serializer.write_struct(_SCHEMA_VALIDATION_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write ``message`` and, when present, the ``fieldList`` member."""
        message_schema = _SCHEMA_VALIDATION_EXCEPTION.members["message"]
        serializer.write_string(message_schema, self.message)
        if self.field_list is None:
            # An absent list is simply not written.
            return
        _serialize_validation_exception_field_list(serializer, _SCHEMA_VALIDATION_EXCEPTION.members["fieldList"], self.field_list)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the members read off ``deserializer``."""
        fields = cls.deserialize_kwargs(deserializer)
        return cls(**fields)

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the struct and return its members as constructor keyword arguments.

        Members with an unrecognised index are logged at debug level and
        otherwise ignored.
        """
        fields: dict[str, Any] = {}

        def _on_member(schema: Schema, de: ShapeDeserializer) -> None:
            # Dispatch on the member's index within the struct schema.
            index = schema.expect_member_index()
            if index == 0:
                fields["message"] = de.read_string(_SCHEMA_VALIDATION_EXCEPTION.members["message"])
            elif index == 1:
                fields["field_list"] = _deserialize_validation_exception_field_list(de, _SCHEMA_VALIDATION_EXCEPTION.members["fieldList"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_VALIDATION_EXCEPTION, consumer=_on_member)
        return fields
|
166
|
+
|
167
|
+
def _serialize_float_list(serializer: ShapeSerializer, schema: Schema, value: list[float]) -> None:
    """Write ``value`` as a list of floats.

    A list of ``len(value)`` elements is opened on ``serializer`` and every
    element is written with the ``member`` schema of ``schema``.
    """
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_serializer:
        for number in value:
            list_serializer.write_float(element_schema, number)
|
172
|
+
|
173
|
+
def _deserialize_float_list(deserializer: ShapeDeserializer, schema: Schema) -> list[float]:
    """Read a list of floats.

    Each element is read with the ``member`` schema of ``schema``; null
    elements are consumed and left out of the returned list.
    """
    element_schema = schema.members["member"]
    floats: list[float] = []

    def _collect(element: ShapeDeserializer) -> None:
        if element.is_null():
            # Consume the null so reading can continue, but do not store it.
            element.read_null()
            return
        floats.append(element.read_float(element_schema))

    deserializer.read_list(schema, _collect)
    return floats
|
184
|
+
|
185
|
+
def _serialize_float_list_list(serializer: ShapeSerializer, schema: Schema, value: list[list[float]]) -> None:
    """Write ``value`` as a list of float lists.

    The outer list is opened with ``len(value)`` elements; every inner list is
    delegated to ``_serialize_float_list`` with the ``member`` schema of
    ``schema``.
    """
    row_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_serializer:
        for row in value:
            _serialize_float_list(list_serializer, row_schema, row)
|
190
|
+
|
191
|
+
def _deserialize_float_list_list(deserializer: ShapeDeserializer, schema: Schema) -> list[list[float]]:
    """Read a list of float lists.

    Each inner list is read by ``_deserialize_float_list`` with the ``member``
    schema of ``schema``; null elements of the outer list are consumed and
    left out of the result.
    """
    row_schema = schema.members["member"]
    rows: list[list[float]] = []

    def _collect(element: ShapeDeserializer) -> None:
        if element.is_null():
            # Consume the null so reading can continue, but do not store it.
            element.read_null()
            return
        rows.append(_deserialize_float_list(element, row_schema))

    deserializer.read_list(schema, _collect)
    return rows
|
202
|
+
|
203
|
+
def _serialize_boolean_list(serializer: ShapeSerializer, schema: Schema, value: list[bool]) -> None:
    """Write ``value`` as a list of booleans.

    A list of ``len(value)`` elements is opened on ``serializer`` and every
    element is written with the ``member`` schema of ``schema``.
    """
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_serializer:
        for flag in value:
            list_serializer.write_boolean(element_schema, flag)
|
208
|
+
|
209
|
+
def _deserialize_boolean_list(deserializer: ShapeDeserializer, schema: Schema) -> list[bool]:
    """Read a list of booleans.

    Each element is read with the ``member`` schema of ``schema``; null
    elements are consumed and left out of the returned list.
    """
    element_schema = schema.members["member"]
    flags: list[bool] = []

    def _collect(element: ShapeDeserializer) -> None:
        if element.is_null():
            # Consume the null so reading can continue, but do not store it.
            element.read_null()
            return
        flags.append(element.read_boolean(element_schema))

    deserializer.read_list(schema, _collect)
    return flags
|
220
|
+
|
221
|
+
def _serialize_integer_list(serializer: ShapeSerializer, schema: Schema, value: list[int]) -> None:
    """Write ``value`` as a list of integers.

    A list of ``len(value)`` elements is opened on ``serializer`` and every
    element is written with the ``member`` schema of ``schema``.
    """
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_serializer:
        for number in value:
            list_serializer.write_integer(element_schema, number)
|
226
|
+
|
227
|
+
def _deserialize_integer_list(deserializer: ShapeDeserializer, schema: Schema) -> list[int]:
    """Read a list of integers.

    Each element is read with the ``member`` schema of ``schema``; null
    elements are consumed and left out of the returned list.
    """
    element_schema = schema.members["member"]
    numbers: list[int] = []

    def _collect(element: ShapeDeserializer) -> None:
        if element.is_null():
            # Consume the null so reading can continue, but do not store it.
            element.read_null()
            return
        numbers.append(element.read_integer(element_schema))

    deserializer.read_list(schema, _collect)
    return numbers
|
238
|
+
|
239
|
+
@dataclass(kw_only=True)
class UnauthorizedException(ApiError):
    """API error with the code ``UnauthorizedException``, classified as a client fault.

    :param message: Text describing the error.
    """

    code: ClassVar[str] = "UnauthorizedException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str

    def serialize(self, serializer: ShapeSerializer):
        """Write this error as a struct using its schema."""
        serializer.write_struct(_SCHEMA_UNAUTHORIZED_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write the ``message`` member as a string."""
        message_schema = _SCHEMA_UNAUTHORIZED_EXCEPTION.members["message"]
        serializer.write_string(message_schema, self.message)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the members read off ``deserializer``."""
        fields = cls.deserialize_kwargs(deserializer)
        return cls(**fields)

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the struct and return its members as constructor keyword arguments.

        Members with an unrecognised index are logged at debug level and
        otherwise ignored.
        """
        fields: dict[str, Any] = {}

        def _on_member(schema: Schema, de: ShapeDeserializer) -> None:
            # ``message`` is the only member handled; it has index 0.
            if schema.expect_member_index() == 0:
                fields["message"] = de.read_string(_SCHEMA_UNAUTHORIZED_EXCEPTION.members["message"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_UNAUTHORIZED_EXCEPTION, consumer=_on_member)
        return fields
|
271
|
+
|
272
|
+
@dataclass(kw_only=True)
class VibeValidationException(ApiError):
    """API error with the code ``VibeValidationException``, classified as a client fault.

    :param message: Text describing the error.
    """

    code: ClassVar[str] = "VibeValidationException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str

    def serialize(self, serializer: ShapeSerializer):
        """Write this error as a struct using its schema."""
        serializer.write_struct(_SCHEMA_VIBE_VALIDATION_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write the ``message`` member as a string."""
        message_schema = _SCHEMA_VIBE_VALIDATION_EXCEPTION.members["message"]
        serializer.write_string(message_schema, self.message)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the members read off ``deserializer``."""
        fields = cls.deserialize_kwargs(deserializer)
        return cls(**fields)

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the struct and return its members as constructor keyword arguments.

        Members with an unrecognised index are logged at debug level and
        otherwise ignored.
        """
        fields: dict[str, Any] = {}

        def _on_member(schema: Schema, de: ShapeDeserializer) -> None:
            # ``message`` is the only member handled; it has index 0.
            if schema.expect_member_index() == 0:
                fields["message"] = de.read_string(_SCHEMA_VIBE_VALIDATION_EXCEPTION.members["message"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_VIBE_VALIDATION_EXCEPTION, consumer=_on_member)
        return fields
|
304
|
+
|
305
|
+
@dataclass
|
306
|
+
class ObservationValueDiscrete:
|
307
|
+
|
308
|
+
value: int
|
309
|
+
|
310
|
+
def serialize(self, serializer: ShapeSerializer):
|
311
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
312
|
+
|
313
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
314
|
+
serializer.write_integer(_SCHEMA_OBSERVATION_VALUE.members["discrete"], self.value)
|
315
|
+
|
316
|
+
@classmethod
|
317
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
318
|
+
return cls(value=deserializer.read_integer(_SCHEMA_OBSERVATION_VALUE.members["discrete"]))
|
319
|
+
|
320
|
+
@dataclass
|
321
|
+
class ObservationValueContinuous:
|
322
|
+
|
323
|
+
value: list[float]
|
324
|
+
|
325
|
+
def serialize(self, serializer: ShapeSerializer):
|
326
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
327
|
+
|
328
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
329
|
+
_serialize_float_list(serializer, _SCHEMA_OBSERVATION_VALUE.members["continuous"], self.value)
|
330
|
+
|
331
|
+
@classmethod
|
332
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
333
|
+
return cls(value=_deserialize_float_list(deserializer, _SCHEMA_OBSERVATION_VALUE.members["continuous"]))
|
334
|
+
|
335
|
+
@dataclass
|
336
|
+
class ObservationValueMultiDiscrete:
|
337
|
+
|
338
|
+
value: list[int]
|
339
|
+
|
340
|
+
def serialize(self, serializer: ShapeSerializer):
|
341
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
342
|
+
|
343
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
344
|
+
_serialize_integer_list(serializer, _SCHEMA_OBSERVATION_VALUE.members["multiDiscrete"], self.value)
|
345
|
+
|
346
|
+
@classmethod
|
347
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
348
|
+
return cls(value=_deserialize_integer_list(deserializer, _SCHEMA_OBSERVATION_VALUE.members["multiDiscrete"]))
|
349
|
+
|
350
|
+
@dataclass
|
351
|
+
class ObservationValueMultiBinary:
|
352
|
+
|
353
|
+
value: list[bool]
|
354
|
+
|
355
|
+
def serialize(self, serializer: ShapeSerializer):
|
356
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
357
|
+
|
358
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
359
|
+
_serialize_boolean_list(serializer, _SCHEMA_OBSERVATION_VALUE.members["multiBinary"], self.value)
|
360
|
+
|
361
|
+
@classmethod
|
362
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
363
|
+
return cls(value=_deserialize_boolean_list(deserializer, _SCHEMA_OBSERVATION_VALUE.members["multiBinary"]))
|
364
|
+
|
365
|
+
@dataclass
|
366
|
+
class ObservationValueBox:
|
367
|
+
|
368
|
+
value: list[list[float]]
|
369
|
+
|
370
|
+
def serialize(self, serializer: ShapeSerializer):
|
371
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
372
|
+
|
373
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
374
|
+
_serialize_float_list_list(serializer, _SCHEMA_OBSERVATION_VALUE.members["box"], self.value)
|
375
|
+
|
376
|
+
@classmethod
|
377
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
378
|
+
return cls(value=_deserialize_float_list_list(deserializer, _SCHEMA_OBSERVATION_VALUE.members["box"]))
|
379
|
+
|
380
|
+
@dataclass
|
381
|
+
class ObservationValueDict:
|
382
|
+
|
383
|
+
value: dict[str, ObservationValue]
|
384
|
+
|
385
|
+
def serialize(self, serializer: ShapeSerializer):
|
386
|
+
serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)
|
387
|
+
|
388
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
389
|
+
_serialize_value_map(serializer, _SCHEMA_OBSERVATION_VALUE.members["dict"], self.value)
|
390
|
+
|
391
|
+
@classmethod
|
392
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
393
|
+
return cls(value=_deserialize_value_map(deserializer, _SCHEMA_OBSERVATION_VALUE.members["dict"]))
|
394
|
+
|
395
|
+
@dataclass
|
396
|
+
class ObservationValueUnknown:
|
397
|
+
"""Represents an unknown variant.
|
398
|
+
|
399
|
+
If you receive this value, you will need to update your library to receive the
|
400
|
+
parsed value.
|
401
|
+
|
402
|
+
This value may not be deliberately sent.
|
403
|
+
"""
|
404
|
+
|
405
|
+
tag: str
|
406
|
+
|
407
|
+
def serialize(self, serializer: ShapeSerializer):
|
408
|
+
raise SmithyException("Unknown union variants may not be serialized.")
|
409
|
+
|
410
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
411
|
+
raise SmithyException("Unknown union variants may not be serialized.")
|
412
|
+
|
413
|
+
@classmethod
|
414
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
415
|
+
raise NotImplementedError()
|
416
|
+
|
417
|
+
# Union of every variant class an observation value can take, including the
# placeholder for unrecognised variants.
ObservationValue = Union[
    ObservationValueDiscrete,
    ObservationValueContinuous,
    ObservationValueMultiDiscrete,
    ObservationValueMultiBinary,
    ObservationValueBox,
    ObservationValueDict,
    ObservationValueUnknown,
]
|
418
|
+
|
419
|
+
class _ObservationValueDeserializer:
    """Reads one ``ObservationValue`` union and enforces that exactly one member is set."""

    _result: ObservationValue | None = None

    def deserialize(self, deserializer: ShapeDeserializer) -> ObservationValue:
        """Read the union struct and return the single variant it contains.

        :raises SmithyException: if no recognised member was read, or (via
            ``_set_result``) if more than one was read.
        """
        # Reset so the instance does not carry over a previous result.
        self._result = None
        deserializer.read_struct(_SCHEMA_OBSERVATION_VALUE, self._consumer)

        variant = self._result
        if variant is not None:
            return variant
        raise SmithyException("Unions must have exactly one value, but found none.")

    def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
        """Deserialize the member identified by ``schema`` into its variant class."""
        # Member index within the union schema -> class that parses it.
        variant_by_index = {
            0: ObservationValueDiscrete,
            1: ObservationValueContinuous,
            2: ObservationValueMultiDiscrete,
            3: ObservationValueMultiBinary,
            4: ObservationValueBox,
            5: ObservationValueDict,
        }
        variant_cls = variant_by_index.get(schema.expect_member_index())
        if variant_cls is None:
            # Unrecognised members are only logged; no result is recorded for
            # them, so a union holding nothing else ends in the "found none"
            # error raised by ``deserialize``.
            logger.debug("Unexpected member schema: %s", schema)
            return
        self._set_result(variant_cls.deserialize(de))

    def _set_result(self, value: ObservationValue) -> None:
        """Store ``value``, rejecting a second value for the same union."""
        if self._result is not None:
            raise SmithyException("Unions must have exactly one value, but found more than one.")
        self._result = value
|
458
|
+
|
459
|
+
def _serialize_value_map(serializer: ShapeSerializer, schema: Schema, value: dict[str, ObservationValue]) -> None:
    """Write ``value`` as a map of string keys to struct values.

    A map of ``len(value)`` entries is opened on ``serializer``; for every
    key a writer callback is registered that writes the entry's value as a
    struct using the ``value`` member schema of ``schema``.
    """
    with serializer.begin_map(schema, len(value)) as m:
        value_schema = schema.members["value"]
        for k, v in value.items():
            # Bind the current entry as a default argument. A plain closure
            # looks ``v`` up only when it is called, so a writer invoked after
            # the loop has advanced would serialize a later entry instead.
            m.entry(k, lambda vs, v=v: vs.write_struct(value_schema, v))
|
464
|
+
|
465
|
+
def _deserialize_value_map(deserializer: ShapeDeserializer, schema: Schema) -> dict[str, ObservationValue]:
    """Read a map of string keys to ``ObservationValue`` unions.

    Every non-null entry is parsed with a fresh
    ``_ObservationValueDeserializer``; null entries are consumed and left out
    of the returned dict.
    """
    result: dict[str, ObservationValue] = {}

    def _read_value(k: str, d: ShapeDeserializer) -> None:
        if d.is_null():
            # Consume the null so reading can continue, but do not store it.
            d.read_null()
            return
        result[k] = _ObservationValueDeserializer().deserialize(d)

    deserializer.read_map(schema, _read_value)
    return result
|
476
|
+
|
477
|
+
@dataclass
|
478
|
+
class ActionValueDiscrete:
|
479
|
+
|
480
|
+
value: int
|
481
|
+
|
482
|
+
def serialize(self, serializer: ShapeSerializer):
|
483
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
484
|
+
|
485
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
486
|
+
serializer.write_integer(_SCHEMA_ACTION_VALUE.members["discrete"], self.value)
|
487
|
+
|
488
|
+
@classmethod
|
489
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
490
|
+
return cls(value=deserializer.read_integer(_SCHEMA_ACTION_VALUE.members["discrete"]))
|
491
|
+
|
492
|
+
@dataclass
|
493
|
+
class ActionValueContinuous:
|
494
|
+
|
495
|
+
value: list[float]
|
496
|
+
|
497
|
+
def serialize(self, serializer: ShapeSerializer):
|
498
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
499
|
+
|
500
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
501
|
+
_serialize_float_list(serializer, _SCHEMA_ACTION_VALUE.members["continuous"], self.value)
|
502
|
+
|
503
|
+
@classmethod
|
504
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
505
|
+
return cls(value=_deserialize_float_list(deserializer, _SCHEMA_ACTION_VALUE.members["continuous"]))
|
506
|
+
|
507
|
+
@dataclass
|
508
|
+
class ActionValueMultiDiscrete:
|
509
|
+
|
510
|
+
value: list[int]
|
511
|
+
|
512
|
+
def serialize(self, serializer: ShapeSerializer):
|
513
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
514
|
+
|
515
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
516
|
+
_serialize_integer_list(serializer, _SCHEMA_ACTION_VALUE.members["multiDiscrete"], self.value)
|
517
|
+
|
518
|
+
@classmethod
|
519
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
520
|
+
return cls(value=_deserialize_integer_list(deserializer, _SCHEMA_ACTION_VALUE.members["multiDiscrete"]))
|
521
|
+
|
522
|
+
@dataclass
|
523
|
+
class ActionValueMultiBinary:
|
524
|
+
|
525
|
+
value: list[bool]
|
526
|
+
|
527
|
+
def serialize(self, serializer: ShapeSerializer):
|
528
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
529
|
+
|
530
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
531
|
+
_serialize_boolean_list(serializer, _SCHEMA_ACTION_VALUE.members["multiBinary"], self.value)
|
532
|
+
|
533
|
+
@classmethod
|
534
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
535
|
+
return cls(value=_deserialize_boolean_list(deserializer, _SCHEMA_ACTION_VALUE.members["multiBinary"]))
|
536
|
+
|
537
|
+
@dataclass
|
538
|
+
class ActionValueBox:
|
539
|
+
|
540
|
+
value: list[list[float]]
|
541
|
+
|
542
|
+
def serialize(self, serializer: ShapeSerializer):
|
543
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
544
|
+
|
545
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
546
|
+
_serialize_float_list_list(serializer, _SCHEMA_ACTION_VALUE.members["box"], self.value)
|
547
|
+
|
548
|
+
@classmethod
|
549
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
550
|
+
return cls(value=_deserialize_float_list_list(deserializer, _SCHEMA_ACTION_VALUE.members["box"]))
|
551
|
+
|
552
|
+
@dataclass
|
553
|
+
class ActionValueDict:
|
554
|
+
|
555
|
+
value: dict[str, ObservationValue]
|
556
|
+
|
557
|
+
def serialize(self, serializer: ShapeSerializer):
|
558
|
+
serializer.write_struct(_SCHEMA_ACTION_VALUE, self)
|
559
|
+
|
560
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
561
|
+
_serialize_value_map(serializer, _SCHEMA_ACTION_VALUE.members["dict"], self.value)
|
562
|
+
|
563
|
+
@classmethod
|
564
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
565
|
+
return cls(value=_deserialize_value_map(deserializer, _SCHEMA_ACTION_VALUE.members["dict"]))
|
566
|
+
|
567
|
+
@dataclass
|
568
|
+
class ActionValueUnknown:
|
569
|
+
"""Represents an unknown variant.
|
570
|
+
|
571
|
+
If you receive this value, you will need to update your library to receive the
|
572
|
+
parsed value.
|
573
|
+
|
574
|
+
This value may not be deliberately sent.
|
575
|
+
"""
|
576
|
+
|
577
|
+
tag: str
|
578
|
+
|
579
|
+
def serialize(self, serializer: ShapeSerializer):
|
580
|
+
raise SmithyException("Unknown union variants may not be serialized.")
|
581
|
+
|
582
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
583
|
+
raise SmithyException("Unknown union variants may not be serialized.")
|
584
|
+
|
585
|
+
@classmethod
|
586
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
587
|
+
raise NotImplementedError()
|
588
|
+
|
589
|
+
# Union of every variant class an action value can take, including the
# placeholder for unrecognised variants.
ActionValue = Union[
    ActionValueDiscrete,
    ActionValueContinuous,
    ActionValueMultiDiscrete,
    ActionValueMultiBinary,
    ActionValueBox,
    ActionValueDict,
    ActionValueUnknown,
]
|
590
|
+
|
591
|
+
class _ActionValueDeserializer:
    """Reads one ``ActionValue`` union and enforces that exactly one member is set."""

    _result: ActionValue | None = None

    def deserialize(self, deserializer: ShapeDeserializer) -> ActionValue:
        """Read the union struct and return the single variant it contains.

        :raises SmithyException: if no recognised member was read, or (via
            ``_set_result``) if more than one was read.
        """
        # Reset so the instance does not carry over a previous result.
        self._result = None
        deserializer.read_struct(_SCHEMA_ACTION_VALUE, self._consumer)

        variant = self._result
        if variant is not None:
            return variant
        raise SmithyException("Unions must have exactly one value, but found none.")

    def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
        """Deserialize the member identified by ``schema`` into its variant class."""
        # Member index within the union schema -> class that parses it.
        variant_by_index = {
            0: ActionValueDiscrete,
            1: ActionValueContinuous,
            2: ActionValueMultiDiscrete,
            3: ActionValueMultiBinary,
            4: ActionValueBox,
            5: ActionValueDict,
        }
        variant_cls = variant_by_index.get(schema.expect_member_index())
        if variant_cls is None:
            # Unrecognised members are only logged; no result is recorded for
            # them, so a union holding nothing else ends in the "found none"
            # error raised by ``deserialize``.
            logger.debug("Unexpected member schema: %s", schema)
            return
        self._set_result(variant_cls.deserialize(de))

    def _set_result(self, value: ActionValue) -> None:
        """Store ``value``, rejecting a second value for the same union."""
        if self._result is not None:
            raise SmithyException("Unions must have exactly one value, but found more than one.")
        self._result = value
|
630
|
+
|
631
|
+
@dataclass(kw_only=True)
|
632
|
+
class QueryAgentInput:
|
633
|
+
|
634
|
+
experiment_id: str | None = None
|
635
|
+
observations: 'ObservationValue | None' = None
|
636
|
+
|
637
|
+
def serialize(self, serializer: ShapeSerializer):
|
638
|
+
serializer.write_struct(_SCHEMA_QUERY_AGENT_INPUT, self)
|
639
|
+
|
640
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
641
|
+
if self.observations is not None:
|
642
|
+
serializer.write_struct(_SCHEMA_QUERY_AGENT_INPUT.members["observations"], self.observations)
|
643
|
+
|
644
|
+
@classmethod
|
645
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
646
|
+
return cls(**cls.deserialize_kwargs(deserializer))
|
647
|
+
|
648
|
+
@classmethod
|
649
|
+
def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
|
650
|
+
kwargs: dict[str, Any] = {}
|
651
|
+
|
652
|
+
def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
|
653
|
+
match schema.expect_member_index():
|
654
|
+
case 0:
|
655
|
+
kwargs["experiment_id"] = de.read_string(_SCHEMA_QUERY_AGENT_INPUT.members["experimentId"])
|
656
|
+
|
657
|
+
case 1:
|
658
|
+
kwargs["observations"] = _ObservationValueDeserializer().deserialize(de)
|
659
|
+
|
660
|
+
case _:
|
661
|
+
logger.debug("Unexpected member schema: %s", schema)
|
662
|
+
|
663
|
+
deserializer.read_struct(_SCHEMA_QUERY_AGENT_INPUT, consumer=_consumer)
|
664
|
+
return kwargs
|
665
|
+
|
666
|
+
@dataclass(kw_only=True)
|
667
|
+
class QueryAgentOutput:
|
668
|
+
|
669
|
+
actions: 'ActionValue'
|
670
|
+
|
671
|
+
def serialize(self, serializer: ShapeSerializer):
|
672
|
+
serializer.write_struct(_SCHEMA_QUERY_AGENT_OUTPUT, self)
|
673
|
+
|
674
|
+
def serialize_members(self, serializer: ShapeSerializer):
|
675
|
+
serializer.write_struct(_SCHEMA_QUERY_AGENT_OUTPUT.members["actions"], self.actions)
|
676
|
+
|
677
|
+
@classmethod
|
678
|
+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
|
679
|
+
return cls(**cls.deserialize_kwargs(deserializer))
|
680
|
+
|
681
|
+
@classmethod
|
682
|
+
def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
|
683
|
+
kwargs: dict[str, Any] = {}
|
684
|
+
|
685
|
+
def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
|
686
|
+
match schema.expect_member_index():
|
687
|
+
case 0:
|
688
|
+
kwargs["actions"] = _ActionValueDeserializer().deserialize(de)
|
689
|
+
|
690
|
+
case _:
|
691
|
+
logger.debug("Unexpected member schema: %s", schema)
|
692
|
+
|
693
|
+
deserializer.read_struct(_SCHEMA_QUERY_AGENT_OUTPUT, consumer=_consumer)
|
694
|
+
return kwargs
|
695
|
+
|
696
|
+
# Descriptor of the QueryAgent operation: its input/output classes and schemas,
# the error classes keyed by shape ID, and its auth schemes.
QUERY_AGENT = APIOperation(
    input=QueryAgentInput,
    output=QueryAgentOutput,
    schema=_SCHEMA_QUERY_AGENT,
    input_schema=_SCHEMA_QUERY_AGENT_INPUT,
    output_schema=_SCHEMA_QUERY_AGENT_OUTPUT,
    error_registry=TypeRegistry(
        {
            ShapeID("vibe.astar.public.api.shared#UnauthorizedException"): UnauthorizedException,
            ShapeID("smithy.framework#ValidationException"): ValidationException,
            ShapeID("vibe.astar.public.api.shared#VibeValidationException"): VibeValidationException,
        }
    ),
    effective_auth_schemes=[ShapeID("smithy.api#noAuth")],
)
|