vibe-client 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vibe_client/models.py ADDED
@@ -0,0 +1,710 @@
1
+ from __future__ import annotations
2
+ # Code generated by smithy-python-codegen DO NOT EDIT.
3
+
4
+ from dataclasses import dataclass
5
+ import logging
6
+ from typing import Any, ClassVar, Literal, Self, Union
7
+
8
+ from smithy_core.deserializers import ShapeDeserializer
9
+ from smithy_core.documents import TypeRegistry
10
+ from smithy_core.exceptions import SmithyException
11
+ from smithy_core.schemas import APIOperation, Schema
12
+ from smithy_core.serializers import ShapeSerializer
13
+ from smithy_core.shapes import ShapeID
14
+
15
+ from ._private.schemas import (
16
+ ACTION_VALUE as _SCHEMA_ACTION_VALUE,
17
+ OBSERVATION_VALUE as _SCHEMA_OBSERVATION_VALUE,
18
+ QUERY_AGENT as _SCHEMA_QUERY_AGENT,
19
+ QUERY_AGENT_INPUT as _SCHEMA_QUERY_AGENT_INPUT,
20
+ QUERY_AGENT_OUTPUT as _SCHEMA_QUERY_AGENT_OUTPUT,
21
+ UNAUTHORIZED_EXCEPTION as _SCHEMA_UNAUTHORIZED_EXCEPTION,
22
+ VALIDATION_EXCEPTION as _SCHEMA_VALIDATION_EXCEPTION,
23
+ VALIDATION_EXCEPTION_FIELD as _SCHEMA_VALIDATION_EXCEPTION_FIELD,
24
+ VIBE_VALIDATION_EXCEPTION as _SCHEMA_VIBE_VALIDATION_EXCEPTION,
25
+ )
26
+
27
+
28
+
29
# Module-level logger; the deserializers below use it to report unexpected member schemas.
logger: logging.Logger = logging.getLogger(name=__name__)
30
+
31
class ServiceError(SmithyException):
    """Base error for all errors in the service.

    Every error class defined in this module derives from it, so callers can catch
    this one type to handle any service error.
    """
34
+
35
@dataclass
class ApiError(ServiceError):
    """Base error for all API errors in the service.

    Concrete subclasses assign the ``code`` class variable (the modeled error name)
    and ``fault`` (``"client"`` or ``"server"``).
    """

    code: ClassVar[str]
    fault: ClassVar[Literal["client", "server"]]

    message: str

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ does not call the exception base
        # initializer, so hand it the message explicitly.
        message = self.message
        super().__init__(message)
45
+
46
@dataclass
class UnknownApiError(ApiError):
    """Error representing any unknown api errors.

    It carries the generic code ``"Unknown"`` and is classified as a client fault.
    """

    code: ClassVar[str] = "Unknown"
    fault: ClassVar[Literal["client", "server"]] = "client"
51
+
52
+ @dataclass(kw_only=True)
53
+ class ValidationExceptionField:
54
+ """
55
+ Describes one specific validation failure for an input member.
56
+
57
+ :param path:
58
+ **[Required]** - A JSONPointer expression to the structure member whose value
59
+ failed to satisfy the modeled constraints.
60
+
61
+ :param message:
62
+ **[Required]** - A detailed description of the validation failure.
63
+
64
+ """
65
+
66
+ path: str
67
+
68
+ message: str
69
+
70
+ def serialize(self, serializer: ShapeSerializer):
71
+ serializer.write_struct(_SCHEMA_VALIDATION_EXCEPTION_FIELD, self)
72
+
73
+ def serialize_members(self, serializer: ShapeSerializer):
74
+ serializer.write_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["path"], self.path)
75
+ serializer.write_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["message"], self.message)
76
+
77
+ @classmethod
78
+ def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
79
+ return cls(**cls.deserialize_kwargs(deserializer))
80
+
81
+ @classmethod
82
+ def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
83
+ kwargs: dict[str, Any] = {}
84
+
85
+ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
86
+ match schema.expect_member_index():
87
+ case 0:
88
+ kwargs["path"] = de.read_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["path"])
89
+
90
+ case 1:
91
+ kwargs["message"] = de.read_string(_SCHEMA_VALIDATION_EXCEPTION_FIELD.members["message"])
92
+
93
+ case _:
94
+ logger.debug("Unexpected member schema: %s", schema)
95
+
96
+ deserializer.read_struct(_SCHEMA_VALIDATION_EXCEPTION_FIELD, consumer=_consumer)
97
+ return kwargs
98
+
99
def _serialize_validation_exception_field_list(serializer: ShapeSerializer, schema: Schema, value: list[ValidationExceptionField]) -> None:
    """Write ``value`` as the list shape ``schema``, emitting each field as a struct."""
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_writer:
        for field in value:
            list_writer.write_struct(element_schema, field)
104
+
105
def _deserialize_validation_exception_field_list(deserializer: ShapeDeserializer, schema: Schema) -> list[ValidationExceptionField]:
    """Read a list of ``ValidationExceptionField`` structures.

    Null elements are consumed and skipped, so the result never contains ``None``.
    """
    fields: list[ValidationExceptionField] = []

    def _on_element(element: ShapeDeserializer):
        if element.is_null():
            element.read_null()
            return
        fields.append(ValidationExceptionField.deserialize(element))

    deserializer.read_list(schema, _on_element)
    return fields
115
+
116
@dataclass(kw_only=True)
class ValidationException(ApiError):
    """
    A standard error for input validation failures. This should be thrown by
    services when a member of the input structure falls outside of the modeled or
    documented constraints.

    :param message: A message associated with the specific error.

    :param field_list:
        A list of specific failures encountered while validating the input. A member can
        appear in this list more than once if it failed to satisfy multiple constraints.

    """

    code: ClassVar[str] = "ValidationException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str
    field_list: list[ValidationExceptionField] | None = None

    def serialize(self, serializer: ShapeSerializer):
        """Write this error through ``serializer`` using its own schema."""
        serializer.write_struct(_SCHEMA_VALIDATION_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write ``message``, and ``fieldList`` when it is set."""
        members = _SCHEMA_VALIDATION_EXCEPTION.members
        serializer.write_string(members["message"], self.message)
        # fieldList is optional; leave it out entirely when unset.
        if self.field_list is None:
            return
        _serialize_validation_exception_field_list(serializer, members["fieldList"], self.field_list)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the keyword arguments read by ``deserialize_kwargs``."""
        return cls(**cls.deserialize_kwargs(deserializer))

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the error structure and return the constructor keyword arguments found."""
        members = _SCHEMA_VALIDATION_EXCEPTION.members
        kwargs: dict[str, Any] = {}

        def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
            index = schema.expect_member_index()
            if index == 0:
                kwargs["message"] = de.read_string(members["message"])
            elif index == 1:
                kwargs["field_list"] = _deserialize_validation_exception_field_list(de, members["fieldList"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_VALIDATION_EXCEPTION, consumer=_consumer)
        return kwargs
166
+
167
def _serialize_float_list(serializer: ShapeSerializer, schema: Schema, value: list[float]) -> None:
    """Write ``value`` as the list shape ``schema``, one float per element."""
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_writer:
        for number in value:
            list_writer.write_float(element_schema, number)
172
+
173
def _deserialize_float_list(deserializer: ShapeDeserializer, schema: Schema) -> list[float]:
    """Read a list of floats; null elements are consumed and skipped."""
    element_schema = schema.members["member"]
    numbers: list[float] = []

    def _on_element(element: ShapeDeserializer):
        if element.is_null():
            element.read_null()
            return
        numbers.append(element.read_float(element_schema))

    deserializer.read_list(schema, _on_element)
    return numbers
184
+
185
def _serialize_float_list_list(serializer: ShapeSerializer, schema: Schema, value: list[list[float]]) -> None:
    """Write ``value`` as a list of float lists; each row goes through ``_serialize_float_list``."""
    row_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as rows_writer:
        for row in value:
            _serialize_float_list(rows_writer, row_schema, row)
190
+
191
def _deserialize_float_list_list(deserializer: ShapeDeserializer, schema: Schema) -> list[list[float]]:
    """Read a list of float lists; null rows are consumed and skipped."""
    row_schema = schema.members["member"]
    rows: list[list[float]] = []

    def _on_row(row_de: ShapeDeserializer):
        if row_de.is_null():
            row_de.read_null()
            return
        rows.append(_deserialize_float_list(row_de, row_schema))

    deserializer.read_list(schema, _on_row)
    return rows
202
+
203
def _serialize_boolean_list(serializer: ShapeSerializer, schema: Schema, value: list[bool]) -> None:
    """Write ``value`` as the list shape ``schema``, one boolean per element."""
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_writer:
        for flag in value:
            list_writer.write_boolean(element_schema, flag)
208
+
209
def _deserialize_boolean_list(deserializer: ShapeDeserializer, schema: Schema) -> list[bool]:
    """Read a list of booleans; null elements are consumed and skipped."""
    element_schema = schema.members["member"]
    flags: list[bool] = []

    def _on_element(element: ShapeDeserializer):
        if element.is_null():
            element.read_null()
            return
        flags.append(element.read_boolean(element_schema))

    deserializer.read_list(schema, _on_element)
    return flags
220
+
221
def _serialize_integer_list(serializer: ShapeSerializer, schema: Schema, value: list[int]) -> None:
    """Write ``value`` as the list shape ``schema``, one integer per element."""
    element_schema = schema.members["member"]
    with serializer.begin_list(schema, len(value)) as list_writer:
        for integer in value:
            list_writer.write_integer(element_schema, integer)
226
+
227
def _deserialize_integer_list(deserializer: ShapeDeserializer, schema: Schema) -> list[int]:
    """Read a list of integers; null elements are consumed and skipped."""
    element_schema = schema.members["member"]
    integers: list[int] = []

    def _on_element(element: ShapeDeserializer):
        if element.is_null():
            element.read_null()
            return
        integers.append(element.read_integer(element_schema))

    deserializer.read_list(schema, _on_element)
    return integers
238
+
239
@dataclass(kw_only=True)
class UnauthorizedException(ApiError):
    """Modeled client-fault error with code ``UnauthorizedException``.

    Presumably returned when the caller is not authorized; confirm against the
    service model.

    :param message: A message associated with the error.
    """

    code: ClassVar[str] = "UnauthorizedException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str

    def serialize(self, serializer: ShapeSerializer):
        """Write this error through ``serializer`` using its own schema."""
        serializer.write_struct(_SCHEMA_UNAUTHORIZED_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write the single ``message`` member."""
        message_schema = _SCHEMA_UNAUTHORIZED_EXCEPTION.members["message"]
        serializer.write_string(message_schema, self.message)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the keyword arguments read by ``deserialize_kwargs``."""
        return cls(**cls.deserialize_kwargs(deserializer))

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the error structure and return the constructor keyword arguments found."""
        kwargs: dict[str, Any] = {}

        def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
            if schema.expect_member_index() == 0:
                kwargs["message"] = de.read_string(_SCHEMA_UNAUTHORIZED_EXCEPTION.members["message"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_UNAUTHORIZED_EXCEPTION, consumer=_consumer)
        return kwargs
271
+
272
@dataclass(kw_only=True)
class VibeValidationException(ApiError):
    """Service-specific client-fault error with code ``VibeValidationException``.

    :param message: A message associated with the error.
    """

    code: ClassVar[str] = "VibeValidationException"
    fault: ClassVar[Literal["client", "server"]] = "client"

    message: str

    def serialize(self, serializer: ShapeSerializer):
        """Write this error through ``serializer`` using its own schema."""
        serializer.write_struct(_SCHEMA_VIBE_VALIDATION_EXCEPTION, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write the single ``message`` member."""
        message_schema = _SCHEMA_VIBE_VALIDATION_EXCEPTION.members["message"]
        serializer.write_string(message_schema, self.message)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Build an instance from the keyword arguments read by ``deserialize_kwargs``."""
        return cls(**cls.deserialize_kwargs(deserializer))

    @classmethod
    def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
        """Read the error structure and return the constructor keyword arguments found."""
        kwargs: dict[str, Any] = {}

        def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
            if schema.expect_member_index() == 0:
                kwargs["message"] = de.read_string(_SCHEMA_VIBE_VALIDATION_EXCEPTION.members["message"])
            else:
                logger.debug("Unexpected member schema: %s", schema)

        deserializer.read_struct(_SCHEMA_VIBE_VALIDATION_EXCEPTION, consumer=_consumer)
        return kwargs
304
+
305
@dataclass
class ObservationValueDiscrete:
    """The ``discrete`` variant of ``ObservationValue``: a single integer."""

    value: int

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``discrete`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["discrete"]
        serializer.write_integer(member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``discrete`` member's integer and wrap it."""
        member = _SCHEMA_OBSERVATION_VALUE.members["discrete"]
        return cls(value=deserializer.read_integer(member))
319
+
320
@dataclass
class ObservationValueContinuous:
    """The ``continuous`` variant of ``ObservationValue``: a list of floats."""

    value: list[float]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``continuous`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["continuous"]
        _serialize_float_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``continuous`` member's float list and wrap it."""
        member = _SCHEMA_OBSERVATION_VALUE.members["continuous"]
        return cls(value=_deserialize_float_list(deserializer, member))
334
+
335
@dataclass
class ObservationValueMultiDiscrete:
    """The ``multiDiscrete`` variant of ``ObservationValue``: a list of integers."""

    value: list[int]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``multiDiscrete`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["multiDiscrete"]
        _serialize_integer_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``multiDiscrete`` member's integer list and wrap it."""
        member = _SCHEMA_OBSERVATION_VALUE.members["multiDiscrete"]
        return cls(value=_deserialize_integer_list(deserializer, member))
349
+
350
@dataclass
class ObservationValueMultiBinary:
    """The ``multiBinary`` variant of ``ObservationValue``: a list of booleans."""

    value: list[bool]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``multiBinary`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["multiBinary"]
        _serialize_boolean_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``multiBinary`` member's boolean list and wrap it."""
        member = _SCHEMA_OBSERVATION_VALUE.members["multiBinary"]
        return cls(value=_deserialize_boolean_list(deserializer, member))
364
+
365
@dataclass
class ObservationValueBox:
    """The ``box`` variant of ``ObservationValue``: a list of float lists."""

    value: list[list[float]]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``box`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["box"]
        _serialize_float_list_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``box`` member's nested float lists and wrap them."""
        member = _SCHEMA_OBSERVATION_VALUE.members["box"]
        return cls(value=_deserialize_float_list_list(deserializer, member))
379
+
380
@dataclass
class ObservationValueDict:
    """The ``dict`` variant of ``ObservationValue``: string keys mapped to nested values."""

    value: dict[str, ObservationValue]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ObservationValue`` union shape."""
        serializer.write_struct(_SCHEMA_OBSERVATION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``dict`` member of the union."""
        member = _SCHEMA_OBSERVATION_VALUE.members["dict"]
        _serialize_value_map(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``dict`` member's value map and wrap it."""
        member = _SCHEMA_OBSERVATION_VALUE.members["dict"]
        return cls(value=_deserialize_value_map(deserializer, member))
394
+
395
@dataclass
class ObservationValueUnknown:
    """Represents an unknown variant.

    If you receive this value, you will need to update your library to receive the
    parsed value.

    This value may not be deliberately sent.
    """

    tag: str

    def serialize(self, serializer: ShapeSerializer):
        """Always raises ``SmithyException``: unknown variants cannot be written."""
        self.serialize_members(serializer)

    def serialize_members(self, serializer: ShapeSerializer):
        """Always raises ``SmithyException``: unknown variants cannot be written."""
        raise SmithyException("Unknown union variants may not be serialized.")

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()
416
+
417
# Union of every ObservationValue variant class. Each variant wraps one member of the
# union shape; ObservationValueUnknown is the catch-all variant described in its docstring.
# Union[...] takes the alternatives as separate arguments. The previous form,
# Union[A | B | ...], nested a PEP 604 union inside Union, which was redundant.
ObservationValue = Union[
    ObservationValueDiscrete,
    ObservationValueContinuous,
    ObservationValueMultiDiscrete,
    ObservationValueMultiBinary,
    ObservationValueBox,
    ObservationValueDict,
    ObservationValueUnknown,
]
418
+
419
class _ObservationValueDeserializer:
    """Reads exactly one ``ObservationValue`` variant from the union shape."""

    _result: ObservationValue | None = None

    def deserialize(self, deserializer: ShapeDeserializer) -> ObservationValue:
        """Read the union and return its single variant.

        :raises SmithyException: if no recognized member was read, or if more than one
            member was read.
        """
        self._result = None
        deserializer.read_struct(_SCHEMA_OBSERVATION_VALUE, self._consumer)

        found = self._result
        if found is None:
            raise SmithyException("Unions must have exactly one value, but found none.")
        return found

    def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
        # Position i in this tuple is the variant for union member index i.
        variants = (
            ObservationValueDiscrete,
            ObservationValueContinuous,
            ObservationValueMultiDiscrete,
            ObservationValueMultiBinary,
            ObservationValueBox,
            ObservationValueDict,
        )
        index = schema.expect_member_index()
        if 0 <= index < len(variants):
            self._set_result(variants[index].deserialize(de))
        else:
            logger.debug("Unexpected member schema: %s", schema)

    def _set_result(self, value: ObservationValue) -> None:
        """Record ``value`` as the union's variant; a second variant is an error."""
        if self._result is None:
            self._result = value
            return
        raise SmithyException("Unions must have exactly one value, but found more than one.")
458
+
459
def _serialize_value_map(serializer: ShapeSerializer, schema: Schema, value: dict[str, ObservationValue]) -> None:
    """Write ``value`` as the map shape ``schema``.

    Each entry's value is written as a struct using the map's ``value`` member schema.
    """
    with serializer.begin_map(schema, len(value)) as map_writer:
        value_schema = schema.members["value"]
        for key, entry_value in value.items():
            # Bind the current entry value as a default argument. A plain closure over
            # the loop variable would write only the last value if ``entry`` ran the
            # writer after the loop had advanced.
            map_writer.entry(key, lambda vs, _item=entry_value: vs.write_struct(value_schema, _item))
464
+
465
def _deserialize_value_map(deserializer: ShapeDeserializer, schema: Schema) -> dict[str, ObservationValue]:
    """Read a map of string keys to ``ObservationValue`` unions; null entries are skipped."""
    entries: dict[str, ObservationValue] = {}
    # Not used below: _ObservationValueDeserializer reads each union with its own schema.
    value_schema = schema.members["value"]

    def _on_entry(key: str, entry_de: ShapeDeserializer):
        if entry_de.is_null():
            entry_de.read_null()
            return
        entries[key] = _ObservationValueDeserializer().deserialize(entry_de)

    deserializer.read_map(schema, _on_entry)
    return entries
476
+
477
@dataclass
class ActionValueDiscrete:
    """The ``discrete`` variant of ``ActionValue``: a single integer."""

    value: int

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``discrete`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["discrete"]
        serializer.write_integer(member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``discrete`` member's integer and wrap it."""
        member = _SCHEMA_ACTION_VALUE.members["discrete"]
        return cls(value=deserializer.read_integer(member))
491
+
492
@dataclass
class ActionValueContinuous:
    """The ``continuous`` variant of ``ActionValue``: a list of floats."""

    value: list[float]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``continuous`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["continuous"]
        _serialize_float_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``continuous`` member's float list and wrap it."""
        member = _SCHEMA_ACTION_VALUE.members["continuous"]
        return cls(value=_deserialize_float_list(deserializer, member))
506
+
507
@dataclass
class ActionValueMultiDiscrete:
    """The ``multiDiscrete`` variant of ``ActionValue``: a list of integers."""

    value: list[int]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``multiDiscrete`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["multiDiscrete"]
        _serialize_integer_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``multiDiscrete`` member's integer list and wrap it."""
        member = _SCHEMA_ACTION_VALUE.members["multiDiscrete"]
        return cls(value=_deserialize_integer_list(deserializer, member))
521
+
522
@dataclass
class ActionValueMultiBinary:
    """The ``multiBinary`` variant of ``ActionValue``: a list of booleans."""

    value: list[bool]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``multiBinary`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["multiBinary"]
        _serialize_boolean_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``multiBinary`` member's boolean list and wrap it."""
        member = _SCHEMA_ACTION_VALUE.members["multiBinary"]
        return cls(value=_deserialize_boolean_list(deserializer, member))
536
+
537
@dataclass
class ActionValueBox:
    """The ``box`` variant of ``ActionValue``: a list of float lists."""

    value: list[list[float]]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``box`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["box"]
        _serialize_float_list_list(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``box`` member's nested float lists and wrap them."""
        member = _SCHEMA_ACTION_VALUE.members["box"]
        return cls(value=_deserialize_float_list_list(deserializer, member))
551
+
552
@dataclass
class ActionValueDict:
    """The ``dict`` variant of ``ActionValue``: string keys mapped to nested values.

    The entries are typed as ``ObservationValue`` and go through the shared
    value-map helpers.
    """

    value: dict[str, ObservationValue]

    def serialize(self, serializer: ShapeSerializer):
        """Write this variant as the ``ActionValue`` union shape."""
        serializer.write_struct(_SCHEMA_ACTION_VALUE, self)

    def serialize_members(self, serializer: ShapeSerializer):
        """Write only the ``dict`` member of the union."""
        member = _SCHEMA_ACTION_VALUE.members["dict"]
        _serialize_value_map(serializer, member, self.value)

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Read the ``dict`` member's value map and wrap it."""
        member = _SCHEMA_ACTION_VALUE.members["dict"]
        return cls(value=_deserialize_value_map(deserializer, member))
566
+
567
@dataclass
class ActionValueUnknown:
    """Represents an unknown variant.

    If you receive this value, you will need to update your library to receive the
    parsed value.

    This value may not be deliberately sent.
    """

    tag: str

    def serialize(self, serializer: ShapeSerializer):
        """Always raises ``SmithyException``: unknown variants cannot be written."""
        self.serialize_members(serializer)

    def serialize_members(self, serializer: ShapeSerializer):
        """Always raises ``SmithyException``: unknown variants cannot be written."""
        raise SmithyException("Unknown union variants may not be serialized.")

    @classmethod
    def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()
588
+
589
# Union of every ActionValue variant class. Each variant wraps one member of the
# union shape; ActionValueUnknown is the catch-all variant described in its docstring.
# Union[...] takes the alternatives as separate arguments. The previous form,
# Union[A | B | ...], nested a PEP 604 union inside Union, which was redundant.
ActionValue = Union[
    ActionValueDiscrete,
    ActionValueContinuous,
    ActionValueMultiDiscrete,
    ActionValueMultiBinary,
    ActionValueBox,
    ActionValueDict,
    ActionValueUnknown,
]
590
+
591
class _ActionValueDeserializer:
    """Reads exactly one ``ActionValue`` variant from the union shape."""

    _result: ActionValue | None = None

    def deserialize(self, deserializer: ShapeDeserializer) -> ActionValue:
        """Read the union and return its single variant.

        :raises SmithyException: if no recognized member was read, or if more than one
            member was read.
        """
        self._result = None
        deserializer.read_struct(_SCHEMA_ACTION_VALUE, self._consumer)

        found = self._result
        if found is None:
            raise SmithyException("Unions must have exactly one value, but found none.")
        return found

    def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
        # Position i in this tuple is the variant for union member index i.
        variants = (
            ActionValueDiscrete,
            ActionValueContinuous,
            ActionValueMultiDiscrete,
            ActionValueMultiBinary,
            ActionValueBox,
            ActionValueDict,
        )
        index = schema.expect_member_index()
        if 0 <= index < len(variants):
            self._set_result(variants[index].deserialize(de))
        else:
            logger.debug("Unexpected member schema: %s", schema)

    def _set_result(self, value: ActionValue) -> None:
        """Record ``value`` as the union's variant; a second variant is an error."""
        if self._result is None:
            self._result = value
            return
        raise SmithyException("Unions must have exactly one value, but found more than one.")
630
+
631
+ @dataclass(kw_only=True)
632
+ class QueryAgentInput:
633
+
634
+ experiment_id: str | None = None
635
+ observations: 'ObservationValue | None' = None
636
+
637
+ def serialize(self, serializer: ShapeSerializer):
638
+ serializer.write_struct(_SCHEMA_QUERY_AGENT_INPUT, self)
639
+
640
+ def serialize_members(self, serializer: ShapeSerializer):
641
+ if self.observations is not None:
642
+ serializer.write_struct(_SCHEMA_QUERY_AGENT_INPUT.members["observations"], self.observations)
643
+
644
+ @classmethod
645
+ def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
646
+ return cls(**cls.deserialize_kwargs(deserializer))
647
+
648
+ @classmethod
649
+ def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
650
+ kwargs: dict[str, Any] = {}
651
+
652
+ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
653
+ match schema.expect_member_index():
654
+ case 0:
655
+ kwargs["experiment_id"] = de.read_string(_SCHEMA_QUERY_AGENT_INPUT.members["experimentId"])
656
+
657
+ case 1:
658
+ kwargs["observations"] = _ObservationValueDeserializer().deserialize(de)
659
+
660
+ case _:
661
+ logger.debug("Unexpected member schema: %s", schema)
662
+
663
+ deserializer.read_struct(_SCHEMA_QUERY_AGENT_INPUT, consumer=_consumer)
664
+ return kwargs
665
+
666
+ @dataclass(kw_only=True)
667
+ class QueryAgentOutput:
668
+
669
+ actions: 'ActionValue'
670
+
671
+ def serialize(self, serializer: ShapeSerializer):
672
+ serializer.write_struct(_SCHEMA_QUERY_AGENT_OUTPUT, self)
673
+
674
+ def serialize_members(self, serializer: ShapeSerializer):
675
+ serializer.write_struct(_SCHEMA_QUERY_AGENT_OUTPUT.members["actions"], self.actions)
676
+
677
+ @classmethod
678
+ def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
679
+ return cls(**cls.deserialize_kwargs(deserializer))
680
+
681
+ @classmethod
682
+ def deserialize_kwargs(cls, deserializer: ShapeDeserializer) -> dict[str, Any]:
683
+ kwargs: dict[str, Any] = {}
684
+
685
+ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
686
+ match schema.expect_member_index():
687
+ case 0:
688
+ kwargs["actions"] = _ActionValueDeserializer().deserialize(de)
689
+
690
+ case _:
691
+ logger.debug("Unexpected member schema: %s", schema)
692
+
693
+ deserializer.read_struct(_SCHEMA_QUERY_AGENT_OUTPUT, consumer=_consumer)
694
+ return kwargs
695
+
696
# Operation descriptor for QueryAgent. It ties the operation, input, and output schemas
# to the generated input/output classes and to the modeled error types.
QUERY_AGENT = APIOperation(
    input=QueryAgentInput,
    output=QueryAgentOutput,
    schema=_SCHEMA_QUERY_AGENT,
    input_schema=_SCHEMA_QUERY_AGENT_INPUT,
    output_schema=_SCHEMA_QUERY_AGENT_OUTPUT,
    # Maps each modeled error shape ID to its exception class.
    error_registry=TypeRegistry(
        {
            ShapeID("vibe.astar.public.api.shared#UnauthorizedException"): UnauthorizedException,
            ShapeID("smithy.framework#ValidationException"): ValidationException,
            ShapeID("vibe.astar.public.api.shared#VibeValidationException"): VibeValidationException,
        }
    ),
    # The only effective auth scheme is Smithy's built-in noAuth.
    effective_auth_schemes=[ShapeID("smithy.api#noAuth")],
)