scalebox-sdk 0.1.24__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scalebox/__init__.py +2 -2
- scalebox/api/__init__.py +130 -128
- scalebox/api/client/__init__.py +8 -8
- scalebox/api/client/api/sandboxes/get_sandboxes_sandbox_id_metrics.py +2 -2
- scalebox/api/client/api/sandboxes/post_sandboxes.py +2 -2
- scalebox/api/client/api/sandboxes/post_sandboxes_sandbox_id_connect.py +193 -0
- scalebox/api/client/client.py +288 -288
- scalebox/api/client/models/connect_sandbox.py +59 -0
- scalebox/api/client/models/error.py +2 -2
- scalebox/api/client/models/listed_sandbox.py +19 -1
- scalebox/api/client/models/new_sandbox.py +10 -0
- scalebox/api/client/models/sandbox.py +138 -125
- scalebox/api/client/models/sandbox_detail.py +24 -0
- scalebox/api/client/types.py +46 -46
- scalebox/cli.py +125 -125
- scalebox/client/aclient.py +57 -57
- scalebox/client/client.py +102 -102
- scalebox/code_interpreter/__init__.py +12 -12
- scalebox/code_interpreter/charts.py +230 -230
- scalebox/code_interpreter/constants.py +3 -3
- scalebox/code_interpreter/exceptions.py +13 -13
- scalebox/code_interpreter/models.py +485 -485
- scalebox/connection_config.py +34 -1
- scalebox/csx_connect/__init__.py +1 -1
- scalebox/csx_connect/client.py +485 -485
- scalebox/csx_desktop/main.py +651 -651
- scalebox/exceptions.py +83 -83
- scalebox/generated/api.py +61 -61
- scalebox/generated/api_pb2.py +203 -203
- scalebox/generated/api_pb2.pyi +956 -956
- scalebox/generated/api_pb2_connect.py +1407 -1407
- scalebox/generated/rpc.py +50 -50
- scalebox/sandbox/main.py +146 -139
- scalebox/sandbox/sandbox_api.py +105 -91
- scalebox/sandbox/signature.py +40 -40
- scalebox/sandbox/utils.py +34 -34
- scalebox/sandbox_async/commands/command.py +307 -307
- scalebox/sandbox_async/commands/command_handle.py +187 -187
- scalebox/sandbox_async/commands/pty.py +187 -187
- scalebox/sandbox_async/filesystem/filesystem.py +557 -557
- scalebox/sandbox_async/filesystem/watch_handle.py +61 -61
- scalebox/sandbox_async/main.py +228 -46
- scalebox/sandbox_async/sandbox_api.py +124 -3
- scalebox/sandbox_async/utils.py +7 -7
- scalebox/sandbox_sync/__init__.py +2 -2
- scalebox/sandbox_sync/commands/command.py +300 -300
- scalebox/sandbox_sync/commands/command_handle.py +150 -150
- scalebox/sandbox_sync/commands/pty.py +181 -181
- scalebox/sandbox_sync/filesystem/filesystem.py +3 -3
- scalebox/sandbox_sync/filesystem/watch_handle.py +66 -66
- scalebox/sandbox_sync/main.py +208 -133
- scalebox/sandbox_sync/sandbox_api.py +119 -3
- scalebox/test/CODE_INTERPRETER_TESTS_READY.md +323 -323
- scalebox/test/README.md +329 -329
- scalebox/test/bedrock_openai_adapter.py +67 -0
- scalebox/test/code_interpreter_test.py +34 -34
- scalebox/test/code_interpreter_test_sync.py +34 -34
- scalebox/test/run_stress_code_interpreter_sync.py +166 -0
- scalebox/test/simple_upload_example.py +123 -0
- scalebox/test/stabitiy_test.py +310 -0
- scalebox/test/test_browser_use.py +25 -0
- scalebox/test/test_browser_use_scalebox.py +61 -0
- scalebox/test/test_code_interpreter_sync_comprehensive.py +115 -65
- scalebox/test/test_connect_pause_async.py +277 -0
- scalebox/test/test_connect_pause_sync.py +267 -0
- scalebox/test/test_desktop_sandbox_sf.py +117 -0
- scalebox/test/test_download_url.py +49 -0
- scalebox/test/test_sandbox_async_comprehensive.py +1 -1
- scalebox/test/test_sandbox_object_storage_example.py +146 -0
- scalebox/test/test_sandbox_object_storage_example_async.py +156 -0
- scalebox/test/test_sf.py +137 -0
- scalebox/test/test_watch_dir_async.py +56 -0
- scalebox/test/testacreate.py +1 -1
- scalebox/test/testagetinfo.py +1 -1
- scalebox/test/testcomputeuse.py +243 -243
- scalebox/test/testsandbox_api.py +13 -0
- scalebox/test/testsandbox_sync.py +1 -1
- scalebox/test/upload_100mb_example.py +355 -0
- scalebox/utils/httpcoreclient.py +297 -297
- scalebox/utils/httpxclient.py +403 -403
- scalebox/version.py +2 -2
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/METADATA +1 -1
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/RECORD +87 -69
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/WHEEL +1 -1
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/entry_points.txt +0 -0
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {scalebox_sdk-0.1.24.dist-info → scalebox_sdk-1.0.1.dist-info}/top_level.txt +0 -0
scalebox/csx_connect/client.py
CHANGED
|
@@ -1,485 +1,485 @@
|
|
|
1
|
-
import gzip
|
|
2
|
-
import json
|
|
3
|
-
import struct
|
|
4
|
-
from enum import Enum, Flag
|
|
5
|
-
from typing import Any, Callable, Dict, Generator, Optional, Tuple
|
|
6
|
-
|
|
7
|
-
from google.protobuf import json_format
|
|
8
|
-
from httpcore import (
|
|
9
|
-
URL,
|
|
10
|
-
AsyncConnectionPool,
|
|
11
|
-
ConnectionPool,
|
|
12
|
-
RemoteProtocolError,
|
|
13
|
-
Response,
|
|
14
|
-
)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class EnvelopeFlags(Flag):
|
|
18
|
-
compressed = 0b00000001
|
|
19
|
-
end_stream = 0b00000010
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class Code(Enum):
|
|
23
|
-
canceled = "canceled"
|
|
24
|
-
unknown = "unknown"
|
|
25
|
-
invalid_argument = "invalid_argument"
|
|
26
|
-
deadline_exceeded = "deadline_exceeded"
|
|
27
|
-
not_found = "not_found"
|
|
28
|
-
already_exists = "already_exists"
|
|
29
|
-
permission_denied = "permission_denied"
|
|
30
|
-
resource_exhausted = "resource_exhausted"
|
|
31
|
-
failed_precondition = "failed_precondition"
|
|
32
|
-
aborted = "aborted"
|
|
33
|
-
out_of_range = "out_of_range"
|
|
34
|
-
unimplemented = "unimplemented"
|
|
35
|
-
internal = "internal"
|
|
36
|
-
unavailable = "unavailable"
|
|
37
|
-
data_loss = "data_loss"
|
|
38
|
-
unauthenticated = "unauthenticated"
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def make_error_from_http_code(http_code: int):
|
|
42
|
-
error_code_map = {
|
|
43
|
-
400: Code.invalid_argument,
|
|
44
|
-
401: Code.unauthenticated,
|
|
45
|
-
403: Code.permission_denied,
|
|
46
|
-
404: Code.not_found,
|
|
47
|
-
409: Code.already_exists,
|
|
48
|
-
413: Code.resource_exhausted,
|
|
49
|
-
429: Code.resource_exhausted,
|
|
50
|
-
499: Code.canceled,
|
|
51
|
-
500: Code.internal,
|
|
52
|
-
501: Code.unimplemented,
|
|
53
|
-
502: Code.unavailable,
|
|
54
|
-
503: Code.unavailable,
|
|
55
|
-
504: Code.deadline_exceeded,
|
|
56
|
-
505: Code.unimplemented,
|
|
57
|
-
}
|
|
58
|
-
|
|
59
|
-
return error_code_map.get(http_code, Code.unknown)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class ConnectException(Exception):
|
|
63
|
-
def __init__(self, status: Code, message: str):
|
|
64
|
-
self.status = status
|
|
65
|
-
self.message = message
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
envelope_header_length = 5
|
|
69
|
-
envelope_header_pack = ">BI"
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def encode_envelope(*, flags: EnvelopeFlags, data):
|
|
73
|
-
return encode_envelope_header(flags=flags.value, data=data) + data
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def encode_envelope_header(*, flags, data):
|
|
77
|
-
return struct.pack(envelope_header_pack, flags, len(data))
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def decode_envelope_header(header):
|
|
81
|
-
flags, data_len = struct.unpack(envelope_header_pack, header)
|
|
82
|
-
return EnvelopeFlags(flags), data_len
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
def error_for_response(http_resp: Response):
|
|
86
|
-
try:
|
|
87
|
-
error = json.loads(http_resp.content)
|
|
88
|
-
return make_error(error)
|
|
89
|
-
except (json.decoder.JSONDecodeError, KeyError):
|
|
90
|
-
error = {"code": http_resp.status, "message": http_resp.content.decode("utf-8")}
|
|
91
|
-
return make_error(error)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def make_error(error):
|
|
95
|
-
status = None
|
|
96
|
-
try:
|
|
97
|
-
code_value = error.get("code")
|
|
98
|
-
# return error code from http status code
|
|
99
|
-
if isinstance(code_value, int):
|
|
100
|
-
status = make_error_from_http_code(code_value)
|
|
101
|
-
else:
|
|
102
|
-
status = Code(code_value)
|
|
103
|
-
except (KeyError, ValueError):
|
|
104
|
-
status = Code.unknown
|
|
105
|
-
|
|
106
|
-
return ConnectException(status, error.get("message", ""))
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
class GzipCompressor:
|
|
110
|
-
name = "gzip"
|
|
111
|
-
decompress = gzip.decompress
|
|
112
|
-
compress = gzip.compress
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
class JSONCodec:
|
|
116
|
-
content_type = "json"
|
|
117
|
-
|
|
118
|
-
@staticmethod
|
|
119
|
-
def encode(msg):
|
|
120
|
-
return json_format.MessageToJson(msg).encode("utf8")
|
|
121
|
-
|
|
122
|
-
@staticmethod
|
|
123
|
-
def decode(data, *, msg_type):
|
|
124
|
-
msg = msg_type()
|
|
125
|
-
json_format.Parse(data.decode("utf8"), msg, ignore_unknown_fields=True)
|
|
126
|
-
return msg
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
class ProtobufCodec:
|
|
130
|
-
content_type = "proto"
|
|
131
|
-
|
|
132
|
-
@staticmethod
|
|
133
|
-
def encode(msg):
|
|
134
|
-
return msg.SerializeToString()
|
|
135
|
-
|
|
136
|
-
@staticmethod
|
|
137
|
-
def decode(data, *, msg_type):
|
|
138
|
-
msg = msg_type()
|
|
139
|
-
msg.ParseFromString(data)
|
|
140
|
-
return msg
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
class Client:
|
|
144
|
-
def __init__(
|
|
145
|
-
self,
|
|
146
|
-
*,
|
|
147
|
-
pool: Optional[ConnectionPool] = None,
|
|
148
|
-
async_pool: Optional[AsyncConnectionPool] = None,
|
|
149
|
-
url: str,
|
|
150
|
-
response_type,
|
|
151
|
-
compressor=None,
|
|
152
|
-
json: Optional[bool] = False,
|
|
153
|
-
headers: Optional[Dict[str, str]] = None,
|
|
154
|
-
):
|
|
155
|
-
if headers is None:
|
|
156
|
-
headers = {}
|
|
157
|
-
|
|
158
|
-
self.pool = pool
|
|
159
|
-
self.async_pool = async_pool
|
|
160
|
-
self.url = url
|
|
161
|
-
self._codec = JSONCodec if json else ProtobufCodec
|
|
162
|
-
self._response_type = response_type
|
|
163
|
-
self._compressor = compressor
|
|
164
|
-
self._headers = headers
|
|
165
|
-
self._connection_retries = 3
|
|
166
|
-
|
|
167
|
-
def _prepare_unary_request(
|
|
168
|
-
self,
|
|
169
|
-
req,
|
|
170
|
-
request_timeout=None,
|
|
171
|
-
headers={},
|
|
172
|
-
**opts,
|
|
173
|
-
):
|
|
174
|
-
data = self._codec.encode(req)
|
|
175
|
-
|
|
176
|
-
if self._compressor is not None:
|
|
177
|
-
data = self._compressor.compress(data)
|
|
178
|
-
|
|
179
|
-
extensions = (
|
|
180
|
-
None
|
|
181
|
-
if request_timeout is None
|
|
182
|
-
else {
|
|
183
|
-
"timeout": {
|
|
184
|
-
"connect": request_timeout,
|
|
185
|
-
"pool": request_timeout,
|
|
186
|
-
"read": request_timeout,
|
|
187
|
-
"write": request_timeout,
|
|
188
|
-
}
|
|
189
|
-
}
|
|
190
|
-
)
|
|
191
|
-
|
|
192
|
-
return {
|
|
193
|
-
"method": "POST",
|
|
194
|
-
"url": self.url,
|
|
195
|
-
"content": data,
|
|
196
|
-
"extensions": extensions,
|
|
197
|
-
"headers": {
|
|
198
|
-
**self._headers,
|
|
199
|
-
**headers,
|
|
200
|
-
**opts.get("headers", {}),
|
|
201
|
-
"connect-protocol-version": "1",
|
|
202
|
-
"content-encoding": (
|
|
203
|
-
"identity" if self._compressor is None else self._compressor.name
|
|
204
|
-
),
|
|
205
|
-
"content-type": f"application/{self._codec.content_type}",
|
|
206
|
-
},
|
|
207
|
-
}
|
|
208
|
-
|
|
209
|
-
def _process_unary_response(
|
|
210
|
-
self,
|
|
211
|
-
http_resp: Response,
|
|
212
|
-
):
|
|
213
|
-
if http_resp.status != 200:
|
|
214
|
-
raise error_for_response(http_resp)
|
|
215
|
-
|
|
216
|
-
content = http_resp.content
|
|
217
|
-
|
|
218
|
-
if self._compressor is not None:
|
|
219
|
-
content = self._compressor.decompress(content)
|
|
220
|
-
|
|
221
|
-
return self._codec.decode(
|
|
222
|
-
content,
|
|
223
|
-
msg_type=self._response_type,
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
async def acall_unary(
|
|
227
|
-
self,
|
|
228
|
-
req,
|
|
229
|
-
request_timeout=None,
|
|
230
|
-
headers={},
|
|
231
|
-
**opts,
|
|
232
|
-
):
|
|
233
|
-
if self.async_pool is None:
|
|
234
|
-
raise ValueError("async_pool is required")
|
|
235
|
-
|
|
236
|
-
req_data = self._prepare_unary_request(
|
|
237
|
-
req,
|
|
238
|
-
request_timeout,
|
|
239
|
-
headers,
|
|
240
|
-
**opts,
|
|
241
|
-
)
|
|
242
|
-
|
|
243
|
-
conn = self.async_pool
|
|
244
|
-
|
|
245
|
-
for _ in range(self._connection_retries):
|
|
246
|
-
try:
|
|
247
|
-
res = await conn.request(**req_data)
|
|
248
|
-
return self._process_unary_response(res)
|
|
249
|
-
except RemoteProtocolError:
|
|
250
|
-
conn = self.async_pool.create_connection(URL(req_data["url"]).origin)
|
|
251
|
-
|
|
252
|
-
continue
|
|
253
|
-
except:
|
|
254
|
-
raise
|
|
255
|
-
|
|
256
|
-
def call_unary(self, req, request_timeout=None, headers={}, **opts):
|
|
257
|
-
if self.pool is None:
|
|
258
|
-
raise ValueError("pool is required")
|
|
259
|
-
|
|
260
|
-
req_data = self._prepare_unary_request(
|
|
261
|
-
req,
|
|
262
|
-
request_timeout,
|
|
263
|
-
headers,
|
|
264
|
-
**opts,
|
|
265
|
-
)
|
|
266
|
-
|
|
267
|
-
conn = self.pool
|
|
268
|
-
|
|
269
|
-
for _ in range(self._connection_retries):
|
|
270
|
-
try:
|
|
271
|
-
res = conn.request(**req_data)
|
|
272
|
-
return self._process_unary_response(res)
|
|
273
|
-
except RemoteProtocolError:
|
|
274
|
-
conn = self.pool.create_connection(URL(req_data["url"]).origin)
|
|
275
|
-
|
|
276
|
-
continue
|
|
277
|
-
except:
|
|
278
|
-
raise
|
|
279
|
-
|
|
280
|
-
def _create_stream_timeout(self, timeout: Optional[int]):
|
|
281
|
-
if timeout:
|
|
282
|
-
return {"connect-timeout-ms": str(timeout * 1000)}
|
|
283
|
-
return {}
|
|
284
|
-
|
|
285
|
-
def _prepare_server_stream_request(
|
|
286
|
-
self,
|
|
287
|
-
req,
|
|
288
|
-
request_timeout=None,
|
|
289
|
-
timeout=None,
|
|
290
|
-
headers={},
|
|
291
|
-
**opts,
|
|
292
|
-
):
|
|
293
|
-
data = self._codec.encode(req)
|
|
294
|
-
flags = EnvelopeFlags(0)
|
|
295
|
-
|
|
296
|
-
extensions = (
|
|
297
|
-
None
|
|
298
|
-
if request_timeout is None
|
|
299
|
-
else {"timeout": {"connect": request_timeout, "pool": request_timeout}}
|
|
300
|
-
)
|
|
301
|
-
|
|
302
|
-
if self._compressor is not None:
|
|
303
|
-
data = self._compressor.compress(data)
|
|
304
|
-
flags |= EnvelopeFlags.compressed
|
|
305
|
-
|
|
306
|
-
stream_timeout = self._create_stream_timeout(timeout)
|
|
307
|
-
|
|
308
|
-
return {
|
|
309
|
-
"method": "POST",
|
|
310
|
-
"url": self.url,
|
|
311
|
-
"content": encode_envelope(
|
|
312
|
-
flags=flags,
|
|
313
|
-
data=data,
|
|
314
|
-
),
|
|
315
|
-
"extensions": extensions,
|
|
316
|
-
"headers": {
|
|
317
|
-
**self._headers,
|
|
318
|
-
**headers,
|
|
319
|
-
**opts.get("headers", {}),
|
|
320
|
-
**stream_timeout,
|
|
321
|
-
"connect-protocol-version": "1",
|
|
322
|
-
"connect-content-encoding": (
|
|
323
|
-
"identity" if self._compressor is None else self._compressor.name
|
|
324
|
-
),
|
|
325
|
-
"content-type": f"application/connect+{self._codec.content_type}",
|
|
326
|
-
},
|
|
327
|
-
}
|
|
328
|
-
|
|
329
|
-
async def acall_server_stream(
|
|
330
|
-
self,
|
|
331
|
-
req,
|
|
332
|
-
request_timeout=None,
|
|
333
|
-
timeout=None,
|
|
334
|
-
headers={},
|
|
335
|
-
**opts,
|
|
336
|
-
):
|
|
337
|
-
if self.async_pool is None:
|
|
338
|
-
raise ValueError("async_pool is required")
|
|
339
|
-
|
|
340
|
-
req_data = self._prepare_server_stream_request(
|
|
341
|
-
req,
|
|
342
|
-
request_timeout,
|
|
343
|
-
timeout,
|
|
344
|
-
headers,
|
|
345
|
-
**opts,
|
|
346
|
-
)
|
|
347
|
-
|
|
348
|
-
conn = self.async_pool
|
|
349
|
-
|
|
350
|
-
for _ in range(self._connection_retries):
|
|
351
|
-
try:
|
|
352
|
-
async with conn.stream(**req_data) as http_resp:
|
|
353
|
-
if http_resp.status != 200:
|
|
354
|
-
await http_resp.aread()
|
|
355
|
-
raise error_for_response(http_resp)
|
|
356
|
-
|
|
357
|
-
parser = ServerStreamParser(
|
|
358
|
-
decode=self._codec.decode,
|
|
359
|
-
response_type=self._response_type,
|
|
360
|
-
)
|
|
361
|
-
|
|
362
|
-
async for chunk in http_resp.aiter_stream():
|
|
363
|
-
for chunk in parser.parse(chunk):
|
|
364
|
-
yield chunk
|
|
365
|
-
|
|
366
|
-
return
|
|
367
|
-
except RemoteProtocolError:
|
|
368
|
-
conn = self.async_pool.create_connection(URL(req_data["url"]).origin)
|
|
369
|
-
|
|
370
|
-
continue
|
|
371
|
-
except:
|
|
372
|
-
raise
|
|
373
|
-
|
|
374
|
-
def call_server_stream(
|
|
375
|
-
self,
|
|
376
|
-
req,
|
|
377
|
-
request_timeout=None,
|
|
378
|
-
timeout=None,
|
|
379
|
-
headers={},
|
|
380
|
-
**opts,
|
|
381
|
-
):
|
|
382
|
-
if self.pool is None:
|
|
383
|
-
raise ValueError("pool is required")
|
|
384
|
-
|
|
385
|
-
req_data = self._prepare_server_stream_request(
|
|
386
|
-
req,
|
|
387
|
-
request_timeout,
|
|
388
|
-
timeout,
|
|
389
|
-
headers,
|
|
390
|
-
**opts,
|
|
391
|
-
)
|
|
392
|
-
|
|
393
|
-
conn = self.pool
|
|
394
|
-
|
|
395
|
-
for _ in range(self._connection_retries):
|
|
396
|
-
try:
|
|
397
|
-
with conn.stream(**req_data) as http_resp:
|
|
398
|
-
if http_resp.status != 200:
|
|
399
|
-
http_resp.read()
|
|
400
|
-
raise error_for_response(http_resp)
|
|
401
|
-
|
|
402
|
-
parser = ServerStreamParser(
|
|
403
|
-
decode=self._codec.decode,
|
|
404
|
-
response_type=self._response_type,
|
|
405
|
-
)
|
|
406
|
-
|
|
407
|
-
for chunk in http_resp.iter_stream():
|
|
408
|
-
yield from parser.parse(chunk)
|
|
409
|
-
|
|
410
|
-
return
|
|
411
|
-
except RemoteProtocolError:
|
|
412
|
-
conn = self.pool.create_connection(URL(req_data["url"]).origin)
|
|
413
|
-
|
|
414
|
-
continue
|
|
415
|
-
except:
|
|
416
|
-
raise
|
|
417
|
-
|
|
418
|
-
def call_client_stream(self, req, **opts):
|
|
419
|
-
raise NotImplementedError("client stream not supported")
|
|
420
|
-
|
|
421
|
-
def acall_client_stream(self, req, **opts):
|
|
422
|
-
raise NotImplementedError("client stream not supported")
|
|
423
|
-
|
|
424
|
-
def call_bidi_stream(self, req, **opts):
|
|
425
|
-
raise NotImplementedError("bidi stream not supported")
|
|
426
|
-
|
|
427
|
-
def acall_bidi_stream(self, req, **opts):
|
|
428
|
-
raise NotImplementedError("bidi stream not supported")
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
DataLen = int
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
class ServerStreamParser:
|
|
435
|
-
def __init__(
|
|
436
|
-
self,
|
|
437
|
-
decode: Callable,
|
|
438
|
-
response_type: Any,
|
|
439
|
-
):
|
|
440
|
-
self.decode = decode
|
|
441
|
-
self.response_type = response_type
|
|
442
|
-
|
|
443
|
-
self.buffer: bytes = b""
|
|
444
|
-
self._header: Optional[tuple[EnvelopeFlags, DataLen]] = None
|
|
445
|
-
|
|
446
|
-
def shift_buffer(self, size: int):
|
|
447
|
-
buffer = self.buffer[:size]
|
|
448
|
-
self.buffer = self.buffer[size:]
|
|
449
|
-
return buffer
|
|
450
|
-
|
|
451
|
-
@property
|
|
452
|
-
def header(self) -> Tuple[EnvelopeFlags, DataLen]:
|
|
453
|
-
if self._header:
|
|
454
|
-
return self._header
|
|
455
|
-
|
|
456
|
-
header_data = self.shift_buffer(envelope_header_length)
|
|
457
|
-
self._header = decode_envelope_header(header_data)
|
|
458
|
-
|
|
459
|
-
return self._header
|
|
460
|
-
|
|
461
|
-
@header.deleter
|
|
462
|
-
def header(self):
|
|
463
|
-
self._header = None
|
|
464
|
-
|
|
465
|
-
def parse(self, chunk: bytes) -> Generator[Any, None, None]:
|
|
466
|
-
self.buffer += chunk
|
|
467
|
-
|
|
468
|
-
while len(self.buffer) >= envelope_header_length:
|
|
469
|
-
flags, data_len = self.header
|
|
470
|
-
|
|
471
|
-
if data_len > len(self.buffer):
|
|
472
|
-
break
|
|
473
|
-
|
|
474
|
-
data = self.shift_buffer(data_len)
|
|
475
|
-
|
|
476
|
-
if EnvelopeFlags.end_stream in flags:
|
|
477
|
-
data = json.loads(data)
|
|
478
|
-
|
|
479
|
-
if "error" in data:
|
|
480
|
-
raise make_error(data["error"])
|
|
481
|
-
|
|
482
|
-
return
|
|
483
|
-
|
|
484
|
-
yield self.decode(data, msg_type=self.response_type)
|
|
485
|
-
del self.header
|
|
1
|
+
import gzip
|
|
2
|
+
import json
|
|
3
|
+
import struct
|
|
4
|
+
from enum import Enum, Flag
|
|
5
|
+
from typing import Any, Callable, Dict, Generator, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
from google.protobuf import json_format
|
|
8
|
+
from httpcore import (
|
|
9
|
+
URL,
|
|
10
|
+
AsyncConnectionPool,
|
|
11
|
+
ConnectionPool,
|
|
12
|
+
RemoteProtocolError,
|
|
13
|
+
Response,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class EnvelopeFlags(Flag):
    """Bit flags carried in the first byte of a Connect envelope header."""

    # Payload bytes are compressed with the negotiated compressor.
    compressed = 0b00000001
    # Frame is the end-of-stream marker; its body is JSON metadata, not a message.
    end_stream = 0b00000010
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Code(Enum):
    """Connect protocol error codes (mirrors the gRPC status-code set).

    Values are the wire-format strings found in the ``code`` field of a
    Connect error payload.
    """

    canceled = "canceled"
    unknown = "unknown"
    invalid_argument = "invalid_argument"
    deadline_exceeded = "deadline_exceeded"
    not_found = "not_found"
    already_exists = "already_exists"
    permission_denied = "permission_denied"
    resource_exhausted = "resource_exhausted"
    failed_precondition = "failed_precondition"
    aborted = "aborted"
    out_of_range = "out_of_range"
    unimplemented = "unimplemented"
    internal = "internal"
    unavailable = "unavailable"
    data_loss = "data_loss"
    unauthenticated = "unauthenticated"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# HTTP status → Connect error code translation table. Hoisted to module
# level so it is built once at import time instead of on every call.
_HTTP_STATUS_TO_CODE = {
    400: Code.invalid_argument,
    401: Code.unauthenticated,
    403: Code.permission_denied,
    404: Code.not_found,
    409: Code.already_exists,
    413: Code.resource_exhausted,
    429: Code.resource_exhausted,
    499: Code.canceled,
    500: Code.internal,
    501: Code.unimplemented,
    502: Code.unavailable,
    503: Code.unavailable,
    504: Code.deadline_exceeded,
    505: Code.unimplemented,
}


def make_error_from_http_code(http_code: int) -> Code:
    """Translate an HTTP status code into a Connect error ``Code``.

    Any status not in the mapping falls back to ``Code.unknown``.
    """
    return _HTTP_STATUS_TO_CODE.get(http_code, Code.unknown)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ConnectException(Exception):
    """Error raised for Connect RPC failures.

    Attributes:
        status: the Connect error ``Code``.
        message: human-readable error detail from the server.
    """

    def __init__(self, status: Code, message: str):
        # BUGFIX: forward the message to Exception.__init__ so that
        # str(exc) and exc.args are populated (previously both were empty,
        # making logs and tracebacks uninformative).
        super().__init__(message)
        self.status = status
        self.message = message
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
envelope_header_length = 5
|
|
69
|
+
envelope_header_pack = ">BI"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def encode_envelope(*, flags: EnvelopeFlags, data):
    """Frame *data* as a Connect envelope: 5-byte header followed by the payload."""
    header = encode_envelope_header(flags=flags.value, data=data)
    return header + data
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def encode_envelope_header(*, flags, data):
    """Pack an envelope header: one flag byte plus a big-endian uint32 length."""
    payload_length = len(data)
    return struct.pack(envelope_header_pack, flags, payload_length)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def decode_envelope_header(header):
    """Unpack a 5-byte envelope header into ``(EnvelopeFlags, payload_length)``."""
    raw_flags, payload_length = struct.unpack(envelope_header_pack, header)
    return EnvelopeFlags(raw_flags), payload_length
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def error_for_response(http_resp: Response):
    """Build a ConnectException from a failed HTTP response.

    Prefers the JSON error object in the response body; falls back to the
    raw HTTP status plus body text when the body is not a JSON object.
    """
    try:
        error = json.loads(http_resp.content)
        # ROBUSTNESS: a body that is valid JSON but not an object (e.g. a
        # list or bare string) previously reached make_error and crashed on
        # ``.get``; treat it as unparseable and use the fallback instead.
        if not isinstance(error, dict):
            raise ValueError("error body is not a JSON object")
        return make_error(error)
    # json.JSONDecodeError is a subclass of ValueError, so this also covers
    # malformed JSON; KeyError is kept from the original as a defensive net.
    except (ValueError, KeyError):
        error = {"code": http_resp.status, "message": http_resp.content.decode("utf-8")}
        return make_error(error)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def make_error(error):
    """Build a ConnectException from a Connect error payload dict.

    ``error["code"]`` may be an HTTP status (int, from the transport layer)
    or a Connect code string; anything unrecognized maps to ``Code.unknown``.
    """
    try:
        code_value = error.get("code")
        if isinstance(code_value, int):
            # Numeric codes are HTTP statuses, not Connect code strings.
            status = make_error_from_http_code(code_value)
        else:
            # Code(None) or an unknown string raises ValueError below.
            status = Code(code_value)
    except ValueError:
        # Dropped the dead ``status = None`` initializer and the unreachable
        # KeyError from the original (dict.get never raises KeyError).
        status = Code.unknown

    return ConnectException(status, error.get("message", ""))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class GzipCompressor:
    """Gzip content-coding used for Connect request/response payloads."""

    name = "gzip"

    @staticmethod
    def decompress(data, *args, **kwargs):
        return gzip.decompress(data, *args, **kwargs)

    @staticmethod
    def compress(data, *args, **kwargs):
        return gzip.compress(data, *args, **kwargs)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class JSONCodec:
    """Codec that (de)serializes protobuf messages via their JSON mapping."""

    content_type = "json"

    @staticmethod
    def encode(msg):
        # MessageToJson returns str; the wire carries bytes.
        return json_format.MessageToJson(msg).encode("utf8")

    @staticmethod
    def decode(data, *, msg_type):
        message = msg_type()
        json_format.Parse(data.decode("utf8"), message, ignore_unknown_fields=True)
        return message
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class ProtobufCodec:
    """Codec that (de)serializes messages as binary protobuf."""

    content_type = "proto"

    @staticmethod
    def encode(msg):
        return msg.SerializeToString()

    @staticmethod
    def decode(data, *, msg_type):
        instance = msg_type()
        instance.ParseFromString(data)
        return instance
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class Client:
    """Connect-protocol RPC client speaking over httpcore connection pools.

    One instance targets a single RPC ``url``/``response_type``. Unary and
    server-streaming calls are supported in sync (``pool``) and async
    (``async_pool``) flavors; client- and bidi-streaming are not implemented.
    """

    def __init__(
        self,
        *,
        pool: Optional[ConnectionPool] = None,
        async_pool: Optional[AsyncConnectionPool] = None,
        url: str,
        response_type,
        compressor=None,
        json: Optional[bool] = False,
        headers: Optional[Dict[str, str]] = None,
    ):
        if headers is None:
            headers = {}

        self.pool = pool
        self.async_pool = async_pool
        self.url = url
        # JSON wire format is opt-in; binary protobuf is the default.
        self._codec = JSONCodec if json else ProtobufCodec
        self._response_type = response_type
        self._compressor = compressor
        self._headers = headers
        self._connection_retries = 3

    def _prepare_unary_request(
        self,
        req,
        request_timeout=None,
        headers=None,  # BUGFIX: was a shared mutable default ({})
        **opts,
    ):
        """Build the kwargs dict passed to httpcore ``request()`` for a unary call."""
        headers = headers or {}
        data = self._codec.encode(req)

        if self._compressor is not None:
            data = self._compressor.compress(data)

        extensions = (
            None
            if request_timeout is None
            else {
                "timeout": {
                    "connect": request_timeout,
                    "pool": request_timeout,
                    "read": request_timeout,
                    "write": request_timeout,
                }
            }
        )

        return {
            "method": "POST",
            "url": self.url,
            "content": data,
            "extensions": extensions,
            "headers": {
                **self._headers,
                **headers,
                **opts.get("headers", {}),
                "connect-protocol-version": "1",
                "content-encoding": (
                    "identity" if self._compressor is None else self._compressor.name
                ),
                "content-type": f"application/{self._codec.content_type}",
            },
        }

    def _process_unary_response(
        self,
        http_resp: Response,
    ):
        """Decode a unary response body; raise ConnectException on non-200 status."""
        if http_resp.status != 200:
            raise error_for_response(http_resp)

        content = http_resp.content

        if self._compressor is not None:
            content = self._compressor.decompress(content)

        return self._codec.decode(
            content,
            msg_type=self._response_type,
        )

    async def acall_unary(
        self,
        req,
        request_timeout=None,
        headers=None,
        **opts,
    ):
        """Async unary call, retrying on stale pooled connections.

        Raises ConnectException for RPC-level errors and RemoteProtocolError
        when every retry attempt fails at the protocol level.
        """
        if self.async_pool is None:
            raise ValueError("async_pool is required")

        req_data = self._prepare_unary_request(
            req,
            request_timeout,
            headers,
            **opts,
        )

        conn = self.async_pool
        last_exc = None

        for _ in range(self._connection_retries):
            try:
                res = await conn.request(**req_data)
                return self._process_unary_response(res)
            except RemoteProtocolError as exc:
                # Pooled connection went stale; retry on a fresh one.
                # (The original's trailing ``except: raise`` was a no-op and
                # has been removed.)
                last_exc = exc
                conn = self.async_pool.create_connection(URL(req_data["url"]).origin)

        # BUGFIX: previously fell off the loop and implicitly returned None
        # after exhausting retries; surface the last protocol error instead.
        raise last_exc

    def call_unary(self, req, request_timeout=None, headers=None, **opts):
        """Sync unary call; see :meth:`acall_unary` for semantics."""
        if self.pool is None:
            raise ValueError("pool is required")

        req_data = self._prepare_unary_request(
            req,
            request_timeout,
            headers,
            **opts,
        )

        conn = self.pool
        last_exc = None

        for _ in range(self._connection_retries):
            try:
                res = conn.request(**req_data)
                return self._process_unary_response(res)
            except RemoteProtocolError as exc:
                last_exc = exc
                conn = self.pool.create_connection(URL(req_data["url"]).origin)

        # BUGFIX: raise instead of silently returning None (see acall_unary).
        raise last_exc

    def _create_stream_timeout(self, timeout: Optional[int]):
        """Connect server-timeout header for streaming calls (milliseconds)."""
        if timeout:
            return {"connect-timeout-ms": str(timeout * 1000)}
        return {}

    def _prepare_server_stream_request(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers=None,  # BUGFIX: was a shared mutable default ({})
        **opts,
    ):
        """Build the kwargs dict passed to httpcore ``stream()`` for a server stream."""
        headers = headers or {}
        data = self._codec.encode(req)
        flags = EnvelopeFlags(0)

        extensions = (
            None
            if request_timeout is None
            else {"timeout": {"connect": request_timeout, "pool": request_timeout}}
        )

        if self._compressor is not None:
            data = self._compressor.compress(data)
            flags |= EnvelopeFlags.compressed

        stream_timeout = self._create_stream_timeout(timeout)

        return {
            "method": "POST",
            "url": self.url,
            "content": encode_envelope(
                flags=flags,
                data=data,
            ),
            "extensions": extensions,
            "headers": {
                **self._headers,
                **headers,
                **opts.get("headers", {}),
                **stream_timeout,
                "connect-protocol-version": "1",
                "connect-content-encoding": (
                    "identity" if self._compressor is None else self._compressor.name
                ),
                "content-type": f"application/connect+{self._codec.content_type}",
            },
        }

    async def acall_server_stream(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers=None,
        **opts,
    ):
        """Async server-stream call yielding decoded response messages."""
        if self.async_pool is None:
            raise ValueError("async_pool is required")

        req_data = self._prepare_server_stream_request(
            req,
            request_timeout,
            timeout,
            headers,
            **opts,
        )

        conn = self.async_pool
        last_exc = None

        for _ in range(self._connection_retries):
            try:
                async with conn.stream(**req_data) as http_resp:
                    if http_resp.status != 200:
                        await http_resp.aread()
                        raise error_for_response(http_resp)

                    parser = ServerStreamParser(
                        decode=self._codec.decode,
                        response_type=self._response_type,
                    )

                    # BUGFIX: the original shadowed the transport chunk with
                    # the parsed message in the inner loop; renamed for clarity.
                    async for raw_chunk in http_resp.aiter_stream():
                        for message in parser.parse(raw_chunk):
                            yield message

                    return
            except RemoteProtocolError as exc:
                last_exc = exc
                conn = self.async_pool.create_connection(URL(req_data["url"]).origin)

        # BUGFIX: raise after exhausting retries instead of ending the
        # stream silently as if it had completed.
        raise last_exc

    def call_server_stream(
        self,
        req,
        request_timeout=None,
        timeout=None,
        headers=None,
        **opts,
    ):
        """Sync server-stream call; see :meth:`acall_server_stream` for semantics."""
        if self.pool is None:
            raise ValueError("pool is required")

        req_data = self._prepare_server_stream_request(
            req,
            request_timeout,
            timeout,
            headers,
            **opts,
        )

        conn = self.pool
        last_exc = None

        for _ in range(self._connection_retries):
            try:
                with conn.stream(**req_data) as http_resp:
                    if http_resp.status != 200:
                        http_resp.read()
                        raise error_for_response(http_resp)

                    parser = ServerStreamParser(
                        decode=self._codec.decode,
                        response_type=self._response_type,
                    )

                    for raw_chunk in http_resp.iter_stream():
                        yield from parser.parse(raw_chunk)

                    return
            except RemoteProtocolError as exc:
                last_exc = exc
                conn = self.pool.create_connection(URL(req_data["url"]).origin)

        # BUGFIX: raise after exhausting retries (see acall_server_stream).
        raise last_exc

    def call_client_stream(self, req, **opts):
        """Client streaming is not supported by this transport."""
        raise NotImplementedError("client stream not supported")

    def acall_client_stream(self, req, **opts):
        """Client streaming is not supported by this transport."""
        raise NotImplementedError("client stream not supported")

    def call_bidi_stream(self, req, **opts):
        """Bidirectional streaming is not supported by this transport."""
        raise NotImplementedError("bidi stream not supported")

    def acall_bidi_stream(self, req, **opts):
        """Bidirectional streaming is not supported by this transport."""
        raise NotImplementedError("bidi stream not supported")
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
DataLen = int
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
class ServerStreamParser:
    """Incremental parser for Connect server-stream envelopes.

    Feed raw transport chunks to :meth:`parse`; partial frames are buffered
    across calls and each completed envelope yields one decoded message.
    """

    def __init__(
        self,
        decode: Callable,
        response_type: Any,
    ):
        # Codec decode function and the message type to decode bodies into.
        self.decode = decode
        self.response_type = response_type

        # Bytes received from the transport but not yet consumed as a frame.
        self.buffer: bytes = b""
        # Header of the frame currently being assembled; None means the next
        # buffered bytes start a fresh envelope header.
        self._header: Optional[tuple[EnvelopeFlags, DataLen]] = None

    def shift_buffer(self, size: int):
        """Remove and return the first *size* bytes of the buffer."""
        buffer = self.buffer[:size]
        self.buffer = self.buffer[size:]
        return buffer

    @property
    def header(self) -> Tuple[EnvelopeFlags, DataLen]:
        """Current frame header; consumes 5 bytes from the buffer on first read.

        The decoded header is cached so parse() can re-read it safely while
        waiting for the rest of the frame body to arrive in later chunks.
        """
        if self._header:
            return self._header

        header_data = self.shift_buffer(envelope_header_length)
        self._header = decode_envelope_header(header_data)

        return self._header

    @header.deleter
    def header(self):
        # Clear the cache so the next property read parses a new header.
        self._header = None

    def parse(self, chunk: bytes) -> Generator[Any, None, None]:
        """Consume *chunk* and yield every fully-received decoded message.

        Raises the ConnectException built by make_error() when the
        end-of-stream frame carries an "error" payload; returns silently on
        a clean end of stream.
        """
        self.buffer += chunk

        # NOTE(review): the loop gate compares the remaining buffer against
        # the 5-byte header length even when a header is already cached, so
        # a frame whose body is shorter than 5 bytes and split across chunk
        # boundaries may not be processed until more bytes arrive — confirm
        # whether that window matters for small end-stream frames.
        while len(self.buffer) >= envelope_header_length:
            # Reading the property consumes the header bytes (once) and
            # caches the decoded (flags, length) pair.
            flags, data_len = self.header

            if data_len > len(self.buffer):
                # Frame body incomplete; the cached header keeps our place
                # until the next chunk is fed in.
                break

            data = self.shift_buffer(data_len)

            if EnvelopeFlags.end_stream in flags:
                # End-of-stream frame body is JSON metadata, not a message.
                data = json.loads(data)

                if "error" in data:
                    raise make_error(data["error"])

                return

            yield self.decode(data, msg_type=self.response_type)
            # Frame fully handled: reset the header cache for the next one.
            del self.header
|