scalebox-sdk 0.1.25__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. scalebox/__init__.py +2 -2
  2. scalebox/api/__init__.py +3 -1
  3. scalebox/api/client/api/sandboxes/get_sandboxes.py +1 -1
  4. scalebox/api/client/api/sandboxes/post_sandboxes_sandbox_id_connect.py +193 -0
  5. scalebox/api/client/models/connect_sandbox.py +59 -0
  6. scalebox/api/client/models/error.py +2 -2
  7. scalebox/api/client/models/listed_sandbox.py +24 -3
  8. scalebox/api/client/models/new_sandbox.py +10 -0
  9. scalebox/api/client/models/sandbox.py +13 -0
  10. scalebox/api/client/models/sandbox_detail.py +24 -0
  11. scalebox/cli.py +125 -125
  12. scalebox/client/aclient.py +57 -57
  13. scalebox/client/client.py +102 -102
  14. scalebox/code_interpreter/__init__.py +12 -12
  15. scalebox/code_interpreter/charts.py +230 -230
  16. scalebox/code_interpreter/code_interpreter_async.py +3 -1
  17. scalebox/code_interpreter/code_interpreter_sync.py +3 -1
  18. scalebox/code_interpreter/constants.py +3 -3
  19. scalebox/code_interpreter/exceptions.py +13 -13
  20. scalebox/code_interpreter/models.py +485 -485
  21. scalebox/connection_config.py +36 -1
  22. scalebox/csx_connect/__init__.py +1 -1
  23. scalebox/csx_connect/client.py +485 -485
  24. scalebox/csx_desktop/main.py +651 -651
  25. scalebox/exceptions.py +83 -83
  26. scalebox/generated/api.py +61 -61
  27. scalebox/generated/api_pb2.py +203 -203
  28. scalebox/generated/api_pb2.pyi +956 -956
  29. scalebox/generated/api_pb2_connect.py +1407 -1407
  30. scalebox/generated/rpc.py +50 -50
  31. scalebox/sandbox/main.py +146 -139
  32. scalebox/sandbox/sandbox_api.py +105 -91
  33. scalebox/sandbox/signature.py +40 -40
  34. scalebox/sandbox/utils.py +34 -34
  35. scalebox/sandbox_async/main.py +226 -44
  36. scalebox/sandbox_async/sandbox_api.py +124 -3
  37. scalebox/sandbox_sync/main.py +205 -130
  38. scalebox/sandbox_sync/sandbox_api.py +119 -3
  39. scalebox/test/CODE_INTERPRETER_TESTS_READY.md +323 -323
  40. scalebox/test/README.md +329 -329
  41. scalebox/test/bedrock_openai_adapter.py +73 -0
  42. scalebox/test/code_interpreter_test.py +34 -34
  43. scalebox/test/code_interpreter_test_sync.py +34 -34
  44. scalebox/test/run_stress_code_interpreter_sync.py +178 -0
  45. scalebox/test/simple_upload_example.py +131 -0
  46. scalebox/test/stabitiy_test.py +323 -0
  47. scalebox/test/test_browser_use.py +27 -0
  48. scalebox/test/test_browser_use_scalebox.py +62 -0
  49. scalebox/test/test_code_interpreter_execcode.py +289 -211
  50. scalebox/test/test_code_interpreter_sync_comprehensive.py +116 -69
  51. scalebox/test/test_connect_pause_async.py +300 -0
  52. scalebox/test/test_connect_pause_sync.py +300 -0
  53. scalebox/test/test_csx_desktop_examples.py +3 -3
  54. scalebox/test/test_desktop_sandbox_sf.py +112 -0
  55. scalebox/test/test_download_url.py +41 -0
  56. scalebox/test/test_existing_sandbox.py +1037 -0
  57. scalebox/test/test_sandbox_async_comprehensive.py +5 -3
  58. scalebox/test/test_sandbox_object_storage_example.py +151 -0
  59. scalebox/test/test_sandbox_object_storage_example_async.py +159 -0
  60. scalebox/test/test_sandbox_sync_comprehensive.py +1 -1
  61. scalebox/test/test_sf.py +141 -0
  62. scalebox/test/test_watch_dir_async.py +58 -0
  63. scalebox/test/testacreate.py +1 -1
  64. scalebox/test/testagetinfo.py +1 -3
  65. scalebox/test/testcomputeuse.py +243 -243
  66. scalebox/test/testsandbox_api.py +5 -5
  67. scalebox/test/testsandbox_async.py +17 -47
  68. scalebox/test/testsandbox_sync.py +19 -15
  69. scalebox/test/upload_100mb_example.py +377 -0
  70. scalebox/utils/httpcoreclient.py +297 -297
  71. scalebox/utils/httpxclient.py +403 -403
  72. scalebox/version.py +2 -2
  73. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/METADATA +1 -1
  74. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/RECORD +78 -60
  75. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/WHEEL +1 -1
  76. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/entry_points.txt +0 -0
  77. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/licenses/LICENSE +0 -0
  78. {scalebox_sdk-0.1.25.dist-info → scalebox_sdk-1.0.2.dist-info}/top_level.txt +0 -0
@@ -1,485 +1,485 @@
import gzip
import json
import struct
from enum import Enum, Flag
from typing import Any, Callable, Dict, Generator, Optional, Tuple

from google.protobuf import json_format
from httpcore import (
    URL,
    AsyncConnectionPool,
    ConnectionPool,
    RemoteProtocolError,
    Response,
)


class EnvelopeFlags(Flag):
    compressed = 0b00000001
    end_stream = 0b00000010


class Code(Enum):
    canceled = "canceled"
    unknown = "unknown"
    invalid_argument = "invalid_argument"
    deadline_exceeded = "deadline_exceeded"
    not_found = "not_found"
    already_exists = "already_exists"
    permission_denied = "permission_denied"
    resource_exhausted = "resource_exhausted"
    failed_precondition = "failed_precondition"
    aborted = "aborted"
    out_of_range = "out_of_range"
    unimplemented = "unimplemented"
    internal = "internal"
    unavailable = "unavailable"
    data_loss = "data_loss"
    unauthenticated = "unauthenticated"


def make_error_from_http_code(http_code: int):
    error_code_map = {
        400: Code.invalid_argument,
        401: Code.unauthenticated,
        403: Code.permission_denied,
        404: Code.not_found,
        409: Code.already_exists,
        413: Code.resource_exhausted,
        429: Code.resource_exhausted,
        499: Code.canceled,
        500: Code.internal,
        501: Code.unimplemented,
        502: Code.unavailable,
        503: Code.unavailable,
        504: Code.deadline_exceeded,
        505: Code.unimplemented,
    }

    return error_code_map.get(http_code, Code.unknown)


class ConnectException(Exception):
    def __init__(self, status: Code, message: str):
        self.status = status
        self.message = message


envelope_header_length = 5
envelope_header_pack = ">BI"


def encode_envelope(*, flags: EnvelopeFlags, data):
    return encode_envelope_header(flags=flags.value, data=data) + data


def encode_envelope_header(*, flags, data):
    return struct.pack(envelope_header_pack, flags, len(data))


def decode_envelope_header(header):
    flags, data_len = struct.unpack(envelope_header_pack, header)
    return EnvelopeFlags(flags), data_len


def error_for_response(http_resp: Response):
    try:
        error = json.loads(http_resp.content)
        return make_error(error)
    except (json.decoder.JSONDecodeError, KeyError):
        error = {"code": http_resp.status, "message": http_resp.content.decode("utf-8")}
        return make_error(error)


def make_error(error):
    status = None
    try:
        code_value = error.get("code")
        # return error code from http status code
        if isinstance(code_value, int):
            status = make_error_from_http_code(code_value)
        else:
            status = Code(code_value)
    except (KeyError, ValueError):
        status = Code.unknown

    return ConnectException(status, error.get("message", ""))


class GzipCompressor:
    name = "gzip"
    decompress = gzip.decompress
    compress = gzip.compress


class JSONCodec:
    content_type = "json"

    @staticmethod
    def encode(msg):
        return json_format.MessageToJson(msg).encode("utf8")

    @staticmethod
    def decode(data, *, msg_type):
        msg = msg_type()
        json_format.Parse(data.decode("utf8"), msg, ignore_unknown_fields=True)
        return msg


class ProtobufCodec:
    content_type = "proto"

    @staticmethod
    def encode(msg):
        return msg.SerializeToString()

    @staticmethod
    def decode(data, *, msg_type):
        msg = msg_type()
        msg.ParseFromString(data)
        return msg


class Client:
    def __init__(
        self,
        *,
        pool: Optional[ConnectionPool] = None,
        async_pool: Optional[AsyncConnectionPool] = None,
        url: str,
        response_type,
        compressor=None,
        json: Optional[bool] = False,
        headers: Optional[Dict[str, str]] = None,
    ):
        if headers is None:
            headers = {}

        self.pool = pool
        self.async_pool = async_pool
        self.url = url
        self._codec = JSONCodec if json else ProtobufCodec
        self._response_type = response_type
        self._compressor = compressor
        self._headers = headers
        self._connection_retries = 3

    def _prepare_unary_request(
        self,
        req,
        request_timeout=None,
        headers={},
        **opts,
    ):
        data = self._codec.encode(req)

        if self._compressor is not None:
            data = self._compressor.compress(data)

        extensions = (
            None
            if request_timeout is None
            else {
                "timeout": {
                    "connect": request_timeout,
                    "pool": request_timeout,
                    "read": request_timeout,
                    "write": request_timeout,
                }
            }
        )

        return {
            "method": "POST",
            "url": self.url,
            "content": data,
            "extensions": extensions,
            "headers": {
                **self._headers,
                **headers,
                **opts.get("headers", {}),
                "connect-protocol-version": "1",
                "content-encoding": (
                    "identity" if self._compressor is None else self._compressor.name
                ),
                "content-type": f"application/{self._codec.content_type}",
            },
        }

    def _process_unary_response(
        self,
        http_resp: Response,
    ):
        if http_resp.status != 200:
            raise error_for_response(http_resp)

        content = http_resp.content

        if self._compressor is not None:
            content = self._compressor.decompress(content)

        return self._codec.decode(
            content,
            msg_type=self._response_type,
        )

    async def acall_unary(
        self,
        req,
        request_timeout=None,
        headers={},
        **opts,
    ):
        if self.async_pool is None:
            raise ValueError("async_pool is required")

        req_data = self._prepare_unary_request(
            req,
            request_timeout,
            headers,
            **opts,
        )

        conn = self.async_pool

        for _ in range(self._connection_retries):
            try:
                res = await conn.request(**req_data)
                return self._process_unary_response(res)
            except RemoteProtocolError:
                conn = self.async_pool.create_connection(URL(req_data["url"]).origin)

                continue
            except:
                raise

    def call_unary(self, req, request_timeout=None, headers={}, **opts):
        if self.pool is None:
            raise ValueError("pool is required")

        req_data = self._prepare_unary_request(
            req,
            request_timeout,
            headers,
            **opts,
        )

        conn = self.pool

        for _ in range(self._connection_retries):
            try:
                res = conn.request(**req_data)
                return self._process_unary_response(res)
            except RemoteProtocolError:
                conn = self.pool.create_connection(URL(req_data["url"]).origin)

                continue
            except:
                raise

    def _create_stream_timeout(self, timeout: Optional[int]):
        if timeout:
            return {"connect-timeout-ms": str(timeout * 1000)}
        return {}

    def _prepare_server_stream_request(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers={},
        **opts,
    ):
        data = self._codec.encode(req)
        flags = EnvelopeFlags(0)

        extensions = (
            None
            if request_timeout is None
            else {"timeout": {"connect": request_timeout, "pool": request_timeout}}
        )

        if self._compressor is not None:
            data = self._compressor.compress(data)
            flags |= EnvelopeFlags.compressed

        stream_timeout = self._create_stream_timeout(timeout)

        return {
            "method": "POST",
            "url": self.url,
            "content": encode_envelope(
                flags=flags,
                data=data,
            ),
            "extensions": extensions,
            "headers": {
                **self._headers,
                **headers,
                **opts.get("headers", {}),
                **stream_timeout,
                "connect-protocol-version": "1",
                "connect-content-encoding": (
                    "identity" if self._compressor is None else self._compressor.name
                ),
                "content-type": f"application/connect+{self._codec.content_type}",
            },
        }

    async def acall_server_stream(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers={},
        **opts,
    ):
        if self.async_pool is None:
            raise ValueError("async_pool is required")

        req_data = self._prepare_server_stream_request(
            req,
            request_timeout,
            timeout,
            headers,
            **opts,
        )

        conn = self.async_pool

        for _ in range(self._connection_retries):
            try:
                async with conn.stream(**req_data) as http_resp:
                    if http_resp.status != 200:
                        await http_resp.aread()
                        raise error_for_response(http_resp)

                    parser = ServerStreamParser(
                        decode=self._codec.decode,
                        response_type=self._response_type,
                    )

                    async for chunk in http_resp.aiter_stream():
                        for chunk in parser.parse(chunk):
                            yield chunk

                return
            except RemoteProtocolError:
                conn = self.async_pool.create_connection(URL(req_data["url"]).origin)

                continue
            except:
                raise

    def call_server_stream(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers={},
        **opts,
    ):
        if self.pool is None:
            raise ValueError("pool is required")

        req_data = self._prepare_server_stream_request(
            req,
            request_timeout,
            timeout,
            headers,
            **opts,
        )

        conn = self.pool

        for _ in range(self._connection_retries):
            try:
                with conn.stream(**req_data) as http_resp:
                    if http_resp.status != 200:
                        http_resp.read()
                        raise error_for_response(http_resp)

                    parser = ServerStreamParser(
                        decode=self._codec.decode,
                        response_type=self._response_type,
                    )

                    for chunk in http_resp.iter_stream():
                        yield from parser.parse(chunk)

                return
            except RemoteProtocolError:
                conn = self.pool.create_connection(URL(req_data["url"]).origin)

                continue
            except:
                raise

    def call_client_stream(self, req, **opts):
        raise NotImplementedError("client stream not supported")

    def acall_client_stream(self, req, **opts):
        raise NotImplementedError("client stream not supported")

    def call_bidi_stream(self, req, **opts):
        raise NotImplementedError("bidi stream not supported")

    def acall_bidi_stream(self, req, **opts):
        raise NotImplementedError("bidi stream not supported")


DataLen = int


class ServerStreamParser:
    def __init__(
        self,
        decode: Callable,
        response_type: Any,
    ):
        self.decode = decode
        self.response_type = response_type

        self.buffer: bytes = b""
        self._header: Optional[tuple[EnvelopeFlags, DataLen]] = None

    def shift_buffer(self, size: int):
        buffer = self.buffer[:size]
        self.buffer = self.buffer[size:]
        return buffer

    @property
    def header(self) -> Tuple[EnvelopeFlags, DataLen]:
        if self._header:
            return self._header

        header_data = self.shift_buffer(envelope_header_length)
        self._header = decode_envelope_header(header_data)

        return self._header

    @header.deleter
    def header(self):
        self._header = None

    def parse(self, chunk: bytes) -> Generator[Any, None, None]:
        self.buffer += chunk

        while len(self.buffer) >= envelope_header_length:
            flags, data_len = self.header

            if data_len > len(self.buffer):
                break

            data = self.shift_buffer(data_len)

            if EnvelopeFlags.end_stream in flags:
                data = json.loads(data)

                if "error" in data:
                    raise make_error(data["error"])

                return

            yield self.decode(data, msg_type=self.response_type)
            del self.header
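For reference, a small worked example of the 5-byte envelope framing used by the streaming methods above. The payload bytes are made up; the assertions only restate what encode_envelope and decode_envelope_header already do (a 1-byte flag field plus a 4-byte big-endian length, followed by the payload).

# Worked example of the ">BI" envelope header; the payload is arbitrary.
payload = b'{"ok":true}'
frame = encode_envelope(flags=EnvelopeFlags(0), data=payload)
assert frame[:5] == struct.pack(">BI", 0, len(payload))  # b'\x00\x00\x00\x00\x0b'
flags, length = decode_envelope_header(frame[:5])
assert flags == EnvelopeFlags(0) and length == 11
assert frame[5:] == payload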
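A minimal usage sketch of the client above, under stated assumptions: ExampleRequest and ExampleResponse stand in for generated protobuf message types, and the endpoint URL and bearer token are placeholders, none of which come from this package; only Client, call_unary, call_server_stream, and httpcore.ConnectionPool come from the module and its imports.

import httpcore

# Hypothetical generated protobuf messages; placeholders, not part of this package.
from example_pb2 import ExampleRequest, ExampleResponse

# Client is the class defined in the module above.
pool = httpcore.ConnectionPool()
client = Client(
    pool=pool,
    url="https://envd.example.invalid/example.v1.ExampleService/Run",  # placeholder endpoint
    response_type=ExampleResponse,
    headers={"authorization": "Bearer <token>"},  # merged into every request
)

# call_unary POSTs the protobuf-encoded request with the Connect headers set by
# _prepare_unary_request and decodes the HTTP response into an ExampleResponse.
reply = client.call_unary(ExampleRequest(), request_timeout=30)

# call_server_stream yields decoded messages as enveloped frames arrive.
for msg in client.call_server_stream(ExampleRequest(), timeout=60):
    print(msg)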