dashscope 1.8.0__py3-none-any.whl → 1.25.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dashscope/__init__.py +61 -14
- dashscope/aigc/__init__.py +10 -3
- dashscope/aigc/chat_completion.py +282 -0
- dashscope/aigc/code_generation.py +145 -0
- dashscope/aigc/conversation.py +71 -12
- dashscope/aigc/generation.py +288 -16
- dashscope/aigc/image_synthesis.py +473 -31
- dashscope/aigc/multimodal_conversation.py +299 -14
- dashscope/aigc/video_synthesis.py +610 -0
- dashscope/api_entities/aiohttp_request.py +8 -5
- dashscope/api_entities/api_request_data.py +4 -2
- dashscope/api_entities/api_request_factory.py +68 -20
- dashscope/api_entities/base_request.py +20 -3
- dashscope/api_entities/chat_completion_types.py +344 -0
- dashscope/api_entities/dashscope_response.py +243 -15
- dashscope/api_entities/encryption.py +179 -0
- dashscope/api_entities/http_request.py +216 -62
- dashscope/api_entities/websocket_request.py +43 -34
- dashscope/app/__init__.py +5 -0
- dashscope/app/application.py +203 -0
- dashscope/app/application_response.py +246 -0
- dashscope/assistants/__init__.py +16 -0
- dashscope/assistants/assistant_types.py +175 -0
- dashscope/assistants/assistants.py +311 -0
- dashscope/assistants/files.py +197 -0
- dashscope/audio/__init__.py +4 -2
- dashscope/audio/asr/__init__.py +17 -1
- dashscope/audio/asr/asr_phrase_manager.py +203 -0
- dashscope/audio/asr/recognition.py +167 -27
- dashscope/audio/asr/transcription.py +107 -14
- dashscope/audio/asr/translation_recognizer.py +1006 -0
- dashscope/audio/asr/vocabulary.py +177 -0
- dashscope/audio/qwen_asr/__init__.py +7 -0
- dashscope/audio/qwen_asr/qwen_transcription.py +189 -0
- dashscope/audio/qwen_omni/__init__.py +11 -0
- dashscope/audio/qwen_omni/omni_realtime.py +524 -0
- dashscope/audio/qwen_tts/__init__.py +5 -0
- dashscope/audio/qwen_tts/speech_synthesizer.py +77 -0
- dashscope/audio/qwen_tts_realtime/__init__.py +10 -0
- dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py +355 -0
- dashscope/audio/tts/__init__.py +2 -0
- dashscope/audio/tts/speech_synthesizer.py +5 -0
- dashscope/audio/tts_v2/__init__.py +12 -0
- dashscope/audio/tts_v2/enrollment.py +179 -0
- dashscope/audio/tts_v2/speech_synthesizer.py +886 -0
- dashscope/cli.py +157 -37
- dashscope/client/base_api.py +652 -87
- dashscope/common/api_key.py +2 -0
- dashscope/common/base_type.py +135 -0
- dashscope/common/constants.py +13 -16
- dashscope/common/env.py +2 -0
- dashscope/common/error.py +58 -22
- dashscope/common/logging.py +2 -0
- dashscope/common/message_manager.py +2 -0
- dashscope/common/utils.py +276 -46
- dashscope/customize/__init__.py +0 -0
- dashscope/customize/customize_types.py +192 -0
- dashscope/customize/deployments.py +146 -0
- dashscope/customize/finetunes.py +234 -0
- dashscope/embeddings/__init__.py +5 -1
- dashscope/embeddings/batch_text_embedding.py +208 -0
- dashscope/embeddings/batch_text_embedding_response.py +65 -0
- dashscope/embeddings/multimodal_embedding.py +118 -10
- dashscope/embeddings/text_embedding.py +13 -1
- dashscope/{file.py → files.py} +19 -4
- dashscope/io/input_output.py +2 -0
- dashscope/model.py +11 -2
- dashscope/models.py +43 -0
- dashscope/multimodal/__init__.py +20 -0
- dashscope/multimodal/dialog_state.py +56 -0
- dashscope/multimodal/multimodal_constants.py +28 -0
- dashscope/multimodal/multimodal_dialog.py +648 -0
- dashscope/multimodal/multimodal_request_params.py +313 -0
- dashscope/multimodal/tingwu/__init__.py +10 -0
- dashscope/multimodal/tingwu/tingwu.py +80 -0
- dashscope/multimodal/tingwu/tingwu_realtime.py +579 -0
- dashscope/nlp/__init__.py +0 -0
- dashscope/nlp/understanding.py +64 -0
- dashscope/protocol/websocket.py +3 -0
- dashscope/rerank/__init__.py +0 -0
- dashscope/rerank/text_rerank.py +69 -0
- dashscope/resources/qwen.tiktoken +151643 -0
- dashscope/threads/__init__.py +26 -0
- dashscope/threads/messages/__init__.py +0 -0
- dashscope/threads/messages/files.py +113 -0
- dashscope/threads/messages/messages.py +220 -0
- dashscope/threads/runs/__init__.py +0 -0
- dashscope/threads/runs/runs.py +501 -0
- dashscope/threads/runs/steps.py +112 -0
- dashscope/threads/thread_types.py +665 -0
- dashscope/threads/threads.py +212 -0
- dashscope/tokenizers/__init__.py +7 -0
- dashscope/tokenizers/qwen_tokenizer.py +111 -0
- dashscope/tokenizers/tokenization.py +125 -0
- dashscope/tokenizers/tokenizer.py +45 -0
- dashscope/tokenizers/tokenizer_base.py +32 -0
- dashscope/utils/__init__.py +0 -0
- dashscope/utils/message_utils.py +838 -0
- dashscope/utils/oss_utils.py +243 -0
- dashscope/utils/param_utils.py +29 -0
- dashscope/version.py +3 -1
- {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/METADATA +53 -50
- dashscope-1.25.6.dist-info/RECORD +112 -0
- {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/WHEEL +1 -1
- {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/entry_points.txt +0 -1
- {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info/licenses}/LICENSE +2 -4
- dashscope/deployment.py +0 -129
- dashscope/finetune.py +0 -149
- dashscope-1.8.0.dist-info/RECORD +0 -49
- {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,28 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
import datetime
|
|
1
3
|
import json
|
|
4
|
+
import ssl
|
|
2
5
|
from http import HTTPStatus
|
|
6
|
+
from typing import Optional
|
|
3
7
|
|
|
8
|
+
import aiohttp
|
|
9
|
+
import certifi
|
|
4
10
|
import requests
|
|
5
11
|
|
|
6
|
-
from dashscope.api_entities.base_request import
|
|
12
|
+
from dashscope.api_entities.base_request import AioBaseRequest
|
|
7
13
|
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
|
8
14
|
from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS,
|
|
9
|
-
SSE_CONTENT_TYPE, HTTPMethod
|
|
10
|
-
StreamResultMode)
|
|
15
|
+
SSE_CONTENT_TYPE, HTTPMethod)
|
|
11
16
|
from dashscope.common.error import UnsupportedHTTPMethod
|
|
12
17
|
from dashscope.common.logging import logger
|
|
18
|
+
from dashscope.common.utils import (_handle_aio_stream,
|
|
19
|
+
_handle_aiohttp_failed_response,
|
|
20
|
+
_handle_http_failed_response,
|
|
21
|
+
_handle_stream)
|
|
22
|
+
from dashscope.api_entities.encryption import Encryption
|
|
13
23
|
|
|
14
24
|
|
|
15
|
-
class HttpRequest(
|
|
25
|
+
class HttpRequest(AioBaseRequest):
|
|
16
26
|
def __init__(self,
|
|
17
27
|
url: str,
|
|
18
28
|
api_key: str,
|
|
@@ -20,9 +30,11 @@ class HttpRequest(BaseRequest):
|
|
|
20
30
|
stream: bool = True,
|
|
21
31
|
async_request: bool = False,
|
|
22
32
|
query: bool = False,
|
|
23
|
-
stream_result_mode: str = StreamResultMode.ACCUMULATE,
|
|
24
33
|
timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
|
|
25
|
-
task_id: str = None
|
|
34
|
+
task_id: str = None,
|
|
35
|
+
flattened_output: bool = False,
|
|
36
|
+
encryption: Optional[Encryption] = None,
|
|
37
|
+
user_agent: str = '') -> None:
|
|
26
38
|
"""HttpSSERequest, processing http server sent event stream.
|
|
27
39
|
|
|
28
40
|
Args:
|
|
@@ -32,16 +44,31 @@ class HttpRequest(BaseRequest):
|
|
|
32
44
|
stream (bool, optional): Is stream request. Defaults to True.
|
|
33
45
|
timeout (int, optional): Total request timeout.
|
|
34
46
|
Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS.
|
|
47
|
+
user_agent (str, optional): Additional user agent string to
|
|
48
|
+
append. Defaults to ''.
|
|
35
49
|
"""
|
|
36
50
|
|
|
37
|
-
super().__init__()
|
|
51
|
+
super().__init__(user_agent=user_agent)
|
|
38
52
|
self.url = url
|
|
53
|
+
self.flattened_output = flattened_output
|
|
39
54
|
self.async_request = async_request
|
|
55
|
+
self.encryption = encryption
|
|
40
56
|
self.headers = {
|
|
41
57
|
'Accept': 'application/json',
|
|
42
58
|
'Authorization': 'Bearer %s' % api_key,
|
|
43
59
|
**self.headers,
|
|
44
60
|
}
|
|
61
|
+
|
|
62
|
+
if encryption and encryption.is_valid():
|
|
63
|
+
self.headers = {
|
|
64
|
+
"X-DashScope-EncryptionKey": json.dumps({
|
|
65
|
+
"public_key_id": encryption.get_pub_key_id(),
|
|
66
|
+
"encrypt_key": encryption.get_encrypted_aes_key_str(),
|
|
67
|
+
"iv": encryption.get_base64_iv_str()
|
|
68
|
+
}),
|
|
69
|
+
**self.headers,
|
|
70
|
+
}
|
|
71
|
+
|
|
45
72
|
self.query = query
|
|
46
73
|
if self.async_request and self.query is False:
|
|
47
74
|
self.headers = {
|
|
@@ -83,34 +110,177 @@ class HttpRequest(BaseRequest):
|
|
|
83
110
|
pass
|
|
84
111
|
return output
|
|
85
112
|
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
113
|
+
async def aio_call(self):
|
|
114
|
+
response = self._handle_aio_request()
|
|
115
|
+
if self.stream:
|
|
116
|
+
return (item async for item in response)
|
|
117
|
+
else:
|
|
118
|
+
result = await response.__anext__()
|
|
119
|
+
try:
|
|
120
|
+
await response.__anext__()
|
|
121
|
+
except StopAsyncIteration:
|
|
122
|
+
pass
|
|
123
|
+
return result
|
|
124
|
+
|
|
125
|
+
async def _handle_aio_request(self):
|
|
126
|
+
try:
|
|
127
|
+
connector = aiohttp.TCPConnector(
|
|
128
|
+
ssl=ssl.create_default_context(
|
|
129
|
+
cafile=certifi.where()))
|
|
130
|
+
async with aiohttp.ClientSession(
|
|
131
|
+
connector=connector,
|
|
132
|
+
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
|
133
|
+
headers=self.headers) as session:
|
|
134
|
+
logger.debug('Starting request: %s' % self.url)
|
|
135
|
+
if self.method == HTTPMethod.POST:
|
|
136
|
+
is_form, obj = False, {}
|
|
137
|
+
if hasattr(self, 'data') and self.data is not None:
|
|
138
|
+
is_form, obj = self.data.get_aiohttp_payload()
|
|
139
|
+
if is_form:
|
|
140
|
+
headers = {**self.headers, **obj.headers}
|
|
141
|
+
response = await session.post(url=self.url,
|
|
142
|
+
data=obj,
|
|
143
|
+
headers=headers)
|
|
144
|
+
else:
|
|
145
|
+
response = await session.request('POST',
|
|
146
|
+
url=self.url,
|
|
147
|
+
json=obj,
|
|
148
|
+
headers=self.headers)
|
|
149
|
+
elif self.method == HTTPMethod.GET:
|
|
150
|
+
# 添加条件判断
|
|
151
|
+
params = {}
|
|
152
|
+
if hasattr(self, 'data') and self.data is not None:
|
|
153
|
+
params = getattr(self.data, 'parameters', {})
|
|
154
|
+
if params:
|
|
155
|
+
params = self.__handle_parameters(params)
|
|
156
|
+
response = await session.get(url=self.url,
|
|
157
|
+
params=params,
|
|
158
|
+
headers=self.headers)
|
|
104
159
|
else:
|
|
105
|
-
|
|
160
|
+
raise UnsupportedHTTPMethod('Unsupported http method: %s' %
|
|
161
|
+
self.method)
|
|
162
|
+
logger.debug('Response returned: %s' % self.url)
|
|
163
|
+
async with response:
|
|
164
|
+
async for rsp in self._handle_aio_response(response):
|
|
165
|
+
yield rsp
|
|
166
|
+
except aiohttp.ClientConnectorError as e:
|
|
167
|
+
logger.error(e)
|
|
168
|
+
raise e
|
|
169
|
+
except BaseException as e:
|
|
170
|
+
logger.error(e)
|
|
171
|
+
raise e
|
|
172
|
+
|
|
173
|
+
@staticmethod
|
|
174
|
+
def __handle_parameters(params: dict) -> dict:
|
|
175
|
+
def __format(value):
|
|
176
|
+
if isinstance(value, bool):
|
|
177
|
+
return str(value).lower()
|
|
178
|
+
elif isinstance(value, (str, int, float)):
|
|
179
|
+
return value
|
|
180
|
+
elif value is None:
|
|
181
|
+
return ''
|
|
182
|
+
elif isinstance(value, (datetime.datetime, datetime.date)):
|
|
183
|
+
return value.isoformat()
|
|
184
|
+
elif isinstance(value, (list, tuple)):
|
|
185
|
+
return ','.join(str(__format(x)) for x in value)
|
|
186
|
+
elif isinstance(value, dict):
|
|
187
|
+
return json.dumps(value)
|
|
188
|
+
else:
|
|
189
|
+
try:
|
|
190
|
+
return str(value)
|
|
191
|
+
except Exception as e:
|
|
192
|
+
raise ValueError(f"Unsupported type {type(value)} for param formatting: {e}")
|
|
193
|
+
|
|
194
|
+
formatted = {}
|
|
195
|
+
for k, v in params.items():
|
|
196
|
+
formatted[k] = __format(v)
|
|
197
|
+
return formatted
|
|
198
|
+
|
|
199
|
+
async def _handle_aio_response(self, response: aiohttp.ClientResponse):
|
|
200
|
+
request_id = ''
|
|
201
|
+
if (response.status == HTTPStatus.OK and self.stream
|
|
202
|
+
and SSE_CONTENT_TYPE in response.content_type):
|
|
203
|
+
async for is_error, status_code, data in _handle_aio_stream(
|
|
204
|
+
response):
|
|
205
|
+
try:
|
|
206
|
+
output = None
|
|
207
|
+
usage = None
|
|
208
|
+
msg = json.loads(data)
|
|
209
|
+
if not is_error:
|
|
210
|
+
if 'output' in msg:
|
|
211
|
+
output = msg['output']
|
|
212
|
+
if 'usage' in msg:
|
|
213
|
+
usage = msg['usage']
|
|
214
|
+
if 'request_id' in msg:
|
|
215
|
+
request_id = msg['request_id']
|
|
216
|
+
except json.JSONDecodeError:
|
|
217
|
+
yield DashScopeAPIResponse(
|
|
218
|
+
request_id=request_id,
|
|
219
|
+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
220
|
+
code='Unknown',
|
|
221
|
+
message=data)
|
|
222
|
+
continue
|
|
223
|
+
if is_error:
|
|
224
|
+
yield DashScopeAPIResponse(request_id=request_id,
|
|
225
|
+
status_code=status_code,
|
|
226
|
+
code=msg['code'],
|
|
227
|
+
message=msg['message'])
|
|
228
|
+
else:
|
|
229
|
+
if self.encryption and self.encryption.is_valid():
|
|
230
|
+
output = self.encryption.decrypt(output)
|
|
231
|
+
yield DashScopeAPIResponse(request_id=request_id,
|
|
232
|
+
status_code=HTTPStatus.OK,
|
|
233
|
+
output=output,
|
|
234
|
+
usage=usage)
|
|
235
|
+
elif (response.status == HTTPStatus.OK
|
|
236
|
+
and 'multipart' in response.content_type):
|
|
237
|
+
reader = aiohttp.MultipartReader.from_response(response)
|
|
238
|
+
output = {}
|
|
239
|
+
while True:
|
|
240
|
+
part = await reader.next()
|
|
241
|
+
if part is None:
|
|
242
|
+
break
|
|
243
|
+
output[part.name] = await part.read()
|
|
244
|
+
if 'request_id' in output:
|
|
245
|
+
request_id = output['request_id']
|
|
246
|
+
if self.encryption and self.encryption.is_valid():
|
|
247
|
+
output = self.encryption.decrypt(output)
|
|
248
|
+
yield DashScopeAPIResponse(request_id=request_id,
|
|
249
|
+
status_code=HTTPStatus.OK,
|
|
250
|
+
output=output)
|
|
251
|
+
elif response.status == HTTPStatus.OK:
|
|
252
|
+
json_content = await response.json()
|
|
253
|
+
output = None
|
|
254
|
+
usage = None
|
|
255
|
+
if 'output' in json_content and json_content['output'] is not None:
|
|
256
|
+
output = json_content['output']
|
|
257
|
+
# Compatible with wan
|
|
258
|
+
elif 'data' in json_content and json_content['data'] is not None\
|
|
259
|
+
and isinstance(json_content['data'], list)\
|
|
260
|
+
and len(json_content['data']) > 0\
|
|
261
|
+
and 'task_id' in json_content['data'][0]:
|
|
262
|
+
output = json_content
|
|
263
|
+
if 'usage' in json_content:
|
|
264
|
+
usage = json_content['usage']
|
|
265
|
+
if 'request_id' in json_content:
|
|
266
|
+
request_id = json_content['request_id']
|
|
267
|
+
if self.encryption and self.encryption.is_valid():
|
|
268
|
+
output = self.encryption.decrypt(output)
|
|
269
|
+
yield DashScopeAPIResponse(request_id=request_id,
|
|
270
|
+
status_code=HTTPStatus.OK,
|
|
271
|
+
output=output,
|
|
272
|
+
usage=usage)
|
|
273
|
+
else:
|
|
274
|
+
yield await _handle_aiohttp_failed_response(response)
|
|
106
275
|
|
|
107
276
|
def _handle_response(self, response: requests.Response):
|
|
108
277
|
request_id = ''
|
|
109
278
|
if (response.status_code == HTTPStatus.OK and self.stream
|
|
110
279
|
and SSE_CONTENT_TYPE in response.headers.get(
|
|
111
280
|
'content-type', '')):
|
|
112
|
-
for is_error, status_code,
|
|
281
|
+
for is_error, status_code, event in _handle_stream(response):
|
|
113
282
|
try:
|
|
283
|
+
data = event.data
|
|
114
284
|
output = None
|
|
115
285
|
usage = None
|
|
116
286
|
msg = json.loads(data)
|
|
@@ -140,10 +310,15 @@ class HttpRequest(BaseRequest):
|
|
|
140
310
|
message=msg['message']
|
|
141
311
|
if 'message' in msg else None) # noqa E501
|
|
142
312
|
else:
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
313
|
+
if self.flattened_output:
|
|
314
|
+
yield msg
|
|
315
|
+
else:
|
|
316
|
+
if self.encryption and self.encryption.is_valid():
|
|
317
|
+
output = self.encryption.decrypt(output)
|
|
318
|
+
yield DashScopeAPIResponse(request_id=request_id,
|
|
319
|
+
status_code=HTTPStatus.OK,
|
|
320
|
+
output=output,
|
|
321
|
+
usage=usage)
|
|
147
322
|
elif response.status_code == HTTPStatus.OK:
|
|
148
323
|
json_content = response.json()
|
|
149
324
|
logger.debug('Response: %s' % json_content)
|
|
@@ -157,38 +332,17 @@ class HttpRequest(BaseRequest):
|
|
|
157
332
|
usage = json_content['usage']
|
|
158
333
|
if 'request_id' in json_content:
|
|
159
334
|
request_id = json_content['request_id']
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
output=output,
|
|
163
|
-
usage=usage)
|
|
164
|
-
else:
|
|
165
|
-
if 'application/json' in response.headers.get('content-type', ''):
|
|
166
|
-
error = response.json()
|
|
167
|
-
if 'request_id' in error:
|
|
168
|
-
request_id = error['request_id']
|
|
169
|
-
if 'message' not in error:
|
|
170
|
-
message = ''
|
|
171
|
-
logger.error('Request: %s failed, status: %s' %
|
|
172
|
-
(self.url, response.status_code))
|
|
173
|
-
else:
|
|
174
|
-
message = error['message']
|
|
175
|
-
logger.error(
|
|
176
|
-
'Request: %s failed, status: %s, message: %s' %
|
|
177
|
-
(self.url, response.status_code, error['message']))
|
|
178
|
-
yield DashScopeAPIResponse(
|
|
179
|
-
request_id=request_id,
|
|
180
|
-
status_code=response.status_code,
|
|
181
|
-
output=None,
|
|
182
|
-
code=error['code']
|
|
183
|
-
if 'code' in error else None, # noqa E501
|
|
184
|
-
message=message)
|
|
335
|
+
if self.flattened_output:
|
|
336
|
+
yield json_content
|
|
185
337
|
else:
|
|
186
|
-
|
|
338
|
+
if self.encryption and self.encryption.is_valid():
|
|
339
|
+
output = self.encryption.decrypt(output)
|
|
187
340
|
yield DashScopeAPIResponse(request_id=request_id,
|
|
188
|
-
status_code=
|
|
189
|
-
output=
|
|
190
|
-
|
|
191
|
-
|
|
341
|
+
status_code=HTTPStatus.OK,
|
|
342
|
+
output=output,
|
|
343
|
+
usage=usage)
|
|
344
|
+
else:
|
|
345
|
+
yield _handle_http_failed_response(response)
|
|
192
346
|
|
|
193
347
|
def _handle_request(self):
|
|
194
348
|
try:
|
|
@@ -220,6 +374,6 @@ class HttpRequest(BaseRequest):
|
|
|
220
374
|
self.method)
|
|
221
375
|
for rsp in self._handle_response(response):
|
|
222
376
|
yield rsp
|
|
223
|
-
except
|
|
377
|
+
except BaseException as e:
|
|
224
378
|
logger.error(e)
|
|
225
379
|
raise e
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
|
|
1
3
|
import asyncio
|
|
2
4
|
import json
|
|
3
5
|
import uuid
|
|
@@ -10,7 +12,7 @@ from dashscope.api_entities.base_request import AioBaseRequest
|
|
|
10
12
|
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
|
11
13
|
from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS,
|
|
12
14
|
SERVICE_503_MESSAGE,
|
|
13
|
-
WEBSOCKET_ERROR_CODE
|
|
15
|
+
WEBSOCKET_ERROR_CODE)
|
|
14
16
|
from dashscope.common.error import (RequestFailure, UnexpectedMessageReceived,
|
|
15
17
|
UnknownMessageReceived)
|
|
16
18
|
from dashscope.common.logging import logger
|
|
@@ -29,10 +31,12 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
29
31
|
stream: bool = True,
|
|
30
32
|
ws_stream_mode: str = WebsocketStreamingMode.OUT,
|
|
31
33
|
is_binary_input: bool = False,
|
|
32
|
-
stream_result_mode: str = StreamResultMode.ACCUMULATE,
|
|
33
34
|
timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
|
|
35
|
+
flattened_output: bool = False,
|
|
36
|
+
pre_task_id=None,
|
|
37
|
+
user_agent: str = '',
|
|
34
38
|
) -> None:
|
|
35
|
-
super().__init__()
|
|
39
|
+
super().__init__(user_agent=user_agent)
|
|
36
40
|
"""HttpRequest.
|
|
37
41
|
|
|
38
42
|
Args:
|
|
@@ -45,12 +49,12 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
45
49
|
"""
|
|
46
50
|
self.url = url
|
|
47
51
|
self.stream = stream
|
|
52
|
+
self.flattened_output = flattened_output
|
|
48
53
|
if timeout is None:
|
|
49
54
|
self.timeout = DEFAULT_REQUEST_TIMEOUT_SECONDS
|
|
50
55
|
else:
|
|
51
56
|
self.timeout = timeout
|
|
52
57
|
self.ws_stream_mode = ws_stream_mode
|
|
53
|
-
self.stream_result_mode = stream_result_mode
|
|
54
58
|
self.is_binary_input = is_binary_input
|
|
55
59
|
|
|
56
60
|
self.headers = {
|
|
@@ -61,6 +65,7 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
61
65
|
self.task_headers = {
|
|
62
66
|
'streaming': self.ws_stream_mode,
|
|
63
67
|
}
|
|
68
|
+
self.pre_task_id = pre_task_id
|
|
64
69
|
|
|
65
70
|
def add_headers(self, headers):
|
|
66
71
|
self.headers = {**self.headers, **headers}
|
|
@@ -77,6 +82,10 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
77
82
|
pass
|
|
78
83
|
return output
|
|
79
84
|
|
|
85
|
+
async def close(self):
|
|
86
|
+
if self.ws is not None and not self.ws.closed:
|
|
87
|
+
await self.ws.close()
|
|
88
|
+
|
|
80
89
|
async def aio_call(self):
|
|
81
90
|
response = self.connection_handler()
|
|
82
91
|
if self.stream:
|
|
@@ -140,7 +149,11 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
140
149
|
code=e.name,
|
|
141
150
|
message=e.message)
|
|
142
151
|
except aiohttp.ClientConnectorError as e:
|
|
143
|
-
|
|
152
|
+
logger.exception(e)
|
|
153
|
+
yield DashScopeAPIResponse(request_id='',
|
|
154
|
+
status_code=-1,
|
|
155
|
+
code='ClientConnectorError',
|
|
156
|
+
message=str(e))
|
|
144
157
|
except aiohttp.WSServerHandshakeError as e:
|
|
145
158
|
code = e.status
|
|
146
159
|
msg = e.message
|
|
@@ -154,6 +167,13 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
154
167
|
status_code=code,
|
|
155
168
|
code=code,
|
|
156
169
|
message=msg)
|
|
170
|
+
except BaseException as e:
|
|
171
|
+
logger.exception(e)
|
|
172
|
+
yield DashScopeAPIResponse(request_id='',
|
|
173
|
+
status_code=-1,
|
|
174
|
+
code='Unknown',
|
|
175
|
+
message='Error type: %s, message: %s' %
|
|
176
|
+
(type(e), e))
|
|
157
177
|
|
|
158
178
|
def _to_DashScopeAPIResponse(self, task_id, is_binary, result):
|
|
159
179
|
if is_binary:
|
|
@@ -177,7 +197,6 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
177
197
|
# check if request stream data, re return an iterator,
|
|
178
198
|
# otherwise we collect data and return user.
|
|
179
199
|
# no matter what, the response is streaming
|
|
180
|
-
final_payload = None
|
|
181
200
|
is_binary_output = False
|
|
182
201
|
while True:
|
|
183
202
|
msg = await ws.receive()
|
|
@@ -186,36 +205,21 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
186
205
|
msg_json = msg.json()
|
|
187
206
|
logger.debug('Receive %s event' % msg_json[HEADER][EVENT_KEY])
|
|
188
207
|
if msg_json[HEADER][EVENT_KEY] == EventType.GENERATED:
|
|
189
|
-
if final_payload is None:
|
|
190
|
-
final_payload = []
|
|
191
208
|
payload = msg_json['payload']
|
|
192
|
-
|
|
193
|
-
yield False, payload
|
|
194
|
-
elif self.stream_result_mode == StreamResultMode.ACCUMULATE: # noqa E501
|
|
195
|
-
final_payload = payload
|
|
196
|
-
else:
|
|
197
|
-
final_payload.append(payload)
|
|
209
|
+
yield False, payload
|
|
198
210
|
elif msg_json[HEADER][EVENT_KEY] == EventType.FINISHED:
|
|
199
|
-
if final_payload is None:
|
|
200
|
-
final_payload = []
|
|
201
211
|
payload = None
|
|
202
212
|
if 'payload' in msg_json:
|
|
203
213
|
payload = msg_json['payload']
|
|
204
214
|
logger.debug(payload)
|
|
205
215
|
if payload:
|
|
206
|
-
|
|
207
|
-
yield False, payload
|
|
208
|
-
elif self.stream_result_mode == StreamResultMode.ACCUMULATE: # noqa E501
|
|
209
|
-
yield False, payload
|
|
210
|
-
else:
|
|
211
|
-
final_payload.extend(payload)
|
|
212
|
-
yield False, final_payload
|
|
216
|
+
yield False, payload
|
|
213
217
|
else:
|
|
214
218
|
if not self.stream:
|
|
215
219
|
if is_binary_output:
|
|
216
|
-
yield True,
|
|
220
|
+
yield True, payload
|
|
217
221
|
else:
|
|
218
|
-
yield False,
|
|
222
|
+
yield False, payload
|
|
219
223
|
break
|
|
220
224
|
elif msg_json[HEADER][EVENT_KEY] == EventType.FAILED:
|
|
221
225
|
self._on_failed(msg_json)
|
|
@@ -225,14 +229,7 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
225
229
|
raise UnknownMessageReceived(error)
|
|
226
230
|
elif msg.type == aiohttp.WSMsgType.BINARY:
|
|
227
231
|
is_binary_output = True
|
|
228
|
-
|
|
229
|
-
final_payload = b''
|
|
230
|
-
if self.stream:
|
|
231
|
-
yield True, msg.data
|
|
232
|
-
elif self.stream_result_mode == StreamResultMode.ACCUMULATE:
|
|
233
|
-
final_payload = msg.data
|
|
234
|
-
else:
|
|
235
|
-
final_payload += msg.data
|
|
232
|
+
yield True, msg.data
|
|
236
233
|
|
|
237
234
|
def _on_failed(self, details):
|
|
238
235
|
error = RequestFailure(request_id=details[HEADER][TASK_ID],
|
|
@@ -243,17 +240,22 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
243
240
|
raise error
|
|
244
241
|
|
|
245
242
|
async def _start_task(self, ws):
|
|
246
|
-
self.
|
|
243
|
+
if self.pre_task_id is not None:
|
|
244
|
+
self.task_headers['task_id'] = self.pre_task_id
|
|
245
|
+
else:
|
|
246
|
+
self.task_headers['task_id'] = uuid.uuid4().hex # create task id.
|
|
247
247
|
task_header = {**self.task_headers, ACTION_KEY: ActionType.START}
|
|
248
248
|
# for binary data, the start action has no input, only parameters.
|
|
249
249
|
start_data = self.data.get_websocket_start_data()
|
|
250
250
|
message = self._build_up_message(task_header, start_data)
|
|
251
|
+
logger.debug('Send start task: {}'.format(message))
|
|
251
252
|
await ws.send_str(message)
|
|
252
253
|
|
|
253
254
|
async def _send_finished_task(self, ws):
|
|
254
255
|
task_header = {**self.task_headers, ACTION_KEY: ActionType.FINISHED}
|
|
255
256
|
payload = {'input': {}}
|
|
256
257
|
message = self._build_up_message(task_header, payload)
|
|
258
|
+
logger.debug('Send finish task: {}'.format(message))
|
|
257
259
|
await ws.send_str(message)
|
|
258
260
|
|
|
259
261
|
async def _send_continue_task_data(self, ws):
|
|
@@ -266,12 +268,19 @@ class WebSocketRequest(AioBaseRequest):
|
|
|
266
268
|
if len(input) > 0:
|
|
267
269
|
if isinstance(input, bytes):
|
|
268
270
|
await ws.send_bytes(input)
|
|
271
|
+
logger.debug(
|
|
272
|
+
'Send continue task with bytes: {}'.format(
|
|
273
|
+
len(input)))
|
|
269
274
|
else:
|
|
270
275
|
await ws.send_bytes(list(input.values())[0])
|
|
276
|
+
logger.debug(
|
|
277
|
+
'Send continue task with list[byte]: {}'.format(
|
|
278
|
+
len(input)))
|
|
271
279
|
else:
|
|
272
280
|
if len(input) > 0:
|
|
273
281
|
message = self._build_up_message(headers=headers,
|
|
274
282
|
payload=input)
|
|
283
|
+
logger.debug('Send continue task: {}'.format(message))
|
|
275
284
|
await ws.send_str(message)
|
|
276
285
|
await asyncio.sleep(0.000001)
|
|
277
286
|
|