dashscope 1.8.0__py3-none-any.whl → 1.25.6__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (110)
  1. dashscope/__init__.py +61 -14
  2. dashscope/aigc/__init__.py +10 -3
  3. dashscope/aigc/chat_completion.py +282 -0
  4. dashscope/aigc/code_generation.py +145 -0
  5. dashscope/aigc/conversation.py +71 -12
  6. dashscope/aigc/generation.py +288 -16
  7. dashscope/aigc/image_synthesis.py +473 -31
  8. dashscope/aigc/multimodal_conversation.py +299 -14
  9. dashscope/aigc/video_synthesis.py +610 -0
  10. dashscope/api_entities/aiohttp_request.py +8 -5
  11. dashscope/api_entities/api_request_data.py +4 -2
  12. dashscope/api_entities/api_request_factory.py +68 -20
  13. dashscope/api_entities/base_request.py +20 -3
  14. dashscope/api_entities/chat_completion_types.py +344 -0
  15. dashscope/api_entities/dashscope_response.py +243 -15
  16. dashscope/api_entities/encryption.py +179 -0
  17. dashscope/api_entities/http_request.py +216 -62
  18. dashscope/api_entities/websocket_request.py +43 -34
  19. dashscope/app/__init__.py +5 -0
  20. dashscope/app/application.py +203 -0
  21. dashscope/app/application_response.py +246 -0
  22. dashscope/assistants/__init__.py +16 -0
  23. dashscope/assistants/assistant_types.py +175 -0
  24. dashscope/assistants/assistants.py +311 -0
  25. dashscope/assistants/files.py +197 -0
  26. dashscope/audio/__init__.py +4 -2
  27. dashscope/audio/asr/__init__.py +17 -1
  28. dashscope/audio/asr/asr_phrase_manager.py +203 -0
  29. dashscope/audio/asr/recognition.py +167 -27
  30. dashscope/audio/asr/transcription.py +107 -14
  31. dashscope/audio/asr/translation_recognizer.py +1006 -0
  32. dashscope/audio/asr/vocabulary.py +177 -0
  33. dashscope/audio/qwen_asr/__init__.py +7 -0
  34. dashscope/audio/qwen_asr/qwen_transcription.py +189 -0
  35. dashscope/audio/qwen_omni/__init__.py +11 -0
  36. dashscope/audio/qwen_omni/omni_realtime.py +524 -0
  37. dashscope/audio/qwen_tts/__init__.py +5 -0
  38. dashscope/audio/qwen_tts/speech_synthesizer.py +77 -0
  39. dashscope/audio/qwen_tts_realtime/__init__.py +10 -0
  40. dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py +355 -0
  41. dashscope/audio/tts/__init__.py +2 -0
  42. dashscope/audio/tts/speech_synthesizer.py +5 -0
  43. dashscope/audio/tts_v2/__init__.py +12 -0
  44. dashscope/audio/tts_v2/enrollment.py +179 -0
  45. dashscope/audio/tts_v2/speech_synthesizer.py +886 -0
  46. dashscope/cli.py +157 -37
  47. dashscope/client/base_api.py +652 -87
  48. dashscope/common/api_key.py +2 -0
  49. dashscope/common/base_type.py +135 -0
  50. dashscope/common/constants.py +13 -16
  51. dashscope/common/env.py +2 -0
  52. dashscope/common/error.py +58 -22
  53. dashscope/common/logging.py +2 -0
  54. dashscope/common/message_manager.py +2 -0
  55. dashscope/common/utils.py +276 -46
  56. dashscope/customize/__init__.py +0 -0
  57. dashscope/customize/customize_types.py +192 -0
  58. dashscope/customize/deployments.py +146 -0
  59. dashscope/customize/finetunes.py +234 -0
  60. dashscope/embeddings/__init__.py +5 -1
  61. dashscope/embeddings/batch_text_embedding.py +208 -0
  62. dashscope/embeddings/batch_text_embedding_response.py +65 -0
  63. dashscope/embeddings/multimodal_embedding.py +118 -10
  64. dashscope/embeddings/text_embedding.py +13 -1
  65. dashscope/{file.py → files.py} +19 -4
  66. dashscope/io/input_output.py +2 -0
  67. dashscope/model.py +11 -2
  68. dashscope/models.py +43 -0
  69. dashscope/multimodal/__init__.py +20 -0
  70. dashscope/multimodal/dialog_state.py +56 -0
  71. dashscope/multimodal/multimodal_constants.py +28 -0
  72. dashscope/multimodal/multimodal_dialog.py +648 -0
  73. dashscope/multimodal/multimodal_request_params.py +313 -0
  74. dashscope/multimodal/tingwu/__init__.py +10 -0
  75. dashscope/multimodal/tingwu/tingwu.py +80 -0
  76. dashscope/multimodal/tingwu/tingwu_realtime.py +579 -0
  77. dashscope/nlp/__init__.py +0 -0
  78. dashscope/nlp/understanding.py +64 -0
  79. dashscope/protocol/websocket.py +3 -0
  80. dashscope/rerank/__init__.py +0 -0
  81. dashscope/rerank/text_rerank.py +69 -0
  82. dashscope/resources/qwen.tiktoken +151643 -0
  83. dashscope/threads/__init__.py +26 -0
  84. dashscope/threads/messages/__init__.py +0 -0
  85. dashscope/threads/messages/files.py +113 -0
  86. dashscope/threads/messages/messages.py +220 -0
  87. dashscope/threads/runs/__init__.py +0 -0
  88. dashscope/threads/runs/runs.py +501 -0
  89. dashscope/threads/runs/steps.py +112 -0
  90. dashscope/threads/thread_types.py +665 -0
  91. dashscope/threads/threads.py +212 -0
  92. dashscope/tokenizers/__init__.py +7 -0
  93. dashscope/tokenizers/qwen_tokenizer.py +111 -0
  94. dashscope/tokenizers/tokenization.py +125 -0
  95. dashscope/tokenizers/tokenizer.py +45 -0
  96. dashscope/tokenizers/tokenizer_base.py +32 -0
  97. dashscope/utils/__init__.py +0 -0
  98. dashscope/utils/message_utils.py +838 -0
  99. dashscope/utils/oss_utils.py +243 -0
  100. dashscope/utils/param_utils.py +29 -0
  101. dashscope/version.py +3 -1
  102. {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/METADATA +53 -50
  103. dashscope-1.25.6.dist-info/RECORD +112 -0
  104. {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/WHEEL +1 -1
  105. {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/entry_points.txt +0 -1
  106. {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info/licenses}/LICENSE +2 -4
  107. dashscope/deployment.py +0 -129
  108. dashscope/finetune.py +0 -149
  109. dashscope-1.8.0.dist-info/RECORD +0 -49
  110. {dashscope-1.8.0.dist-info → dashscope-1.25.6.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,28 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import datetime
1
3
  import json
4
+ import ssl
2
5
  from http import HTTPStatus
6
+ from typing import Optional
3
7
 
8
+ import aiohttp
9
+ import certifi
4
10
  import requests
5
11
 
6
- from dashscope.api_entities.base_request import BaseRequest
12
+ from dashscope.api_entities.base_request import AioBaseRequest
7
13
  from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
8
14
  from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS,
9
- SSE_CONTENT_TYPE, HTTPMethod,
10
- StreamResultMode)
15
+ SSE_CONTENT_TYPE, HTTPMethod)
11
16
  from dashscope.common.error import UnsupportedHTTPMethod
12
17
  from dashscope.common.logging import logger
18
+ from dashscope.common.utils import (_handle_aio_stream,
19
+ _handle_aiohttp_failed_response,
20
+ _handle_http_failed_response,
21
+ _handle_stream)
22
+ from dashscope.api_entities.encryption import Encryption
13
23
 
14
24
 
15
- class HttpRequest(BaseRequest):
25
+ class HttpRequest(AioBaseRequest):
16
26
  def __init__(self,
17
27
  url: str,
18
28
  api_key: str,
@@ -20,9 +30,11 @@ class HttpRequest(BaseRequest):
20
30
  stream: bool = True,
21
31
  async_request: bool = False,
22
32
  query: bool = False,
23
- stream_result_mode: str = StreamResultMode.ACCUMULATE,
24
33
  timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
25
- task_id: str = None) -> None:
34
+ task_id: str = None,
35
+ flattened_output: bool = False,
36
+ encryption: Optional[Encryption] = None,
37
+ user_agent: str = '') -> None:
26
38
  """HttpSSERequest, processing http server sent event stream.
27
39
 
28
40
  Args:
@@ -32,16 +44,31 @@ class HttpRequest(BaseRequest):
32
44
  stream (bool, optional): Is stream request. Defaults to True.
33
45
  timeout (int, optional): Total request timeout.
34
46
  Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS.
47
+ user_agent (str, optional): Additional user agent string to
48
+ append. Defaults to ''.
35
49
  """
36
50
 
37
- super().__init__()
51
+ super().__init__(user_agent=user_agent)
38
52
  self.url = url
53
+ self.flattened_output = flattened_output
39
54
  self.async_request = async_request
55
+ self.encryption = encryption
40
56
  self.headers = {
41
57
  'Accept': 'application/json',
42
58
  'Authorization': 'Bearer %s' % api_key,
43
59
  **self.headers,
44
60
  }
61
+
62
+ if encryption and encryption.is_valid():
63
+ self.headers = {
64
+ "X-DashScope-EncryptionKey": json.dumps({
65
+ "public_key_id": encryption.get_pub_key_id(),
66
+ "encrypt_key": encryption.get_encrypted_aes_key_str(),
67
+ "iv": encryption.get_base64_iv_str()
68
+ }),
69
+ **self.headers,
70
+ }
71
+
45
72
  self.query = query
46
73
  if self.async_request and self.query is False:
47
74
  self.headers = {
@@ -83,34 +110,177 @@ class HttpRequest(BaseRequest):
83
110
  pass
84
111
  return output
85
112
 
86
- def _handle_stream(self, response: requests.Response):
87
- # TODO define done message.
88
- is_error = False
89
- status_code = HTTPStatus.INTERNAL_SERVER_ERROR
90
- for line in response.iter_lines():
91
- if line:
92
- line = line.decode('utf8')
93
- line = line.rstrip('\n').rstrip('\r')
94
- if line.startswith('event:error'):
95
- is_error = True
96
- elif line.startswith('status:'):
97
- status_code = line[len('status:'):]
98
- status_code = int(status_code.strip())
99
- elif line.startswith('data:'):
100
- line = line[len('data:'):]
101
- yield (is_error, status_code, line)
102
- if is_error:
103
- break
113
+ async def aio_call(self):
114
+ response = self._handle_aio_request()
115
+ if self.stream:
116
+ return (item async for item in response)
117
+ else:
118
+ result = await response.__anext__()
119
+ try:
120
+ await response.__anext__()
121
+ except StopAsyncIteration:
122
+ pass
123
+ return result
124
+
125
+ async def _handle_aio_request(self):
126
+ try:
127
+ connector = aiohttp.TCPConnector(
128
+ ssl=ssl.create_default_context(
129
+ cafile=certifi.where()))
130
+ async with aiohttp.ClientSession(
131
+ connector=connector,
132
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
133
+ headers=self.headers) as session:
134
+ logger.debug('Starting request: %s' % self.url)
135
+ if self.method == HTTPMethod.POST:
136
+ is_form, obj = False, {}
137
+ if hasattr(self, 'data') and self.data is not None:
138
+ is_form, obj = self.data.get_aiohttp_payload()
139
+ if is_form:
140
+ headers = {**self.headers, **obj.headers}
141
+ response = await session.post(url=self.url,
142
+ data=obj,
143
+ headers=headers)
144
+ else:
145
+ response = await session.request('POST',
146
+ url=self.url,
147
+ json=obj,
148
+ headers=self.headers)
149
+ elif self.method == HTTPMethod.GET:
150
+ # 添加条件判断
151
+ params = {}
152
+ if hasattr(self, 'data') and self.data is not None:
153
+ params = getattr(self.data, 'parameters', {})
154
+ if params:
155
+ params = self.__handle_parameters(params)
156
+ response = await session.get(url=self.url,
157
+ params=params,
158
+ headers=self.headers)
104
159
  else:
105
- continue # ignore heartbeat...
160
+ raise UnsupportedHTTPMethod('Unsupported http method: %s' %
161
+ self.method)
162
+ logger.debug('Response returned: %s' % self.url)
163
+ async with response:
164
+ async for rsp in self._handle_aio_response(response):
165
+ yield rsp
166
+ except aiohttp.ClientConnectorError as e:
167
+ logger.error(e)
168
+ raise e
169
+ except BaseException as e:
170
+ logger.error(e)
171
+ raise e
172
+
173
+ @staticmethod
174
+ def __handle_parameters(params: dict) -> dict:
175
+ def __format(value):
176
+ if isinstance(value, bool):
177
+ return str(value).lower()
178
+ elif isinstance(value, (str, int, float)):
179
+ return value
180
+ elif value is None:
181
+ return ''
182
+ elif isinstance(value, (datetime.datetime, datetime.date)):
183
+ return value.isoformat()
184
+ elif isinstance(value, (list, tuple)):
185
+ return ','.join(str(__format(x)) for x in value)
186
+ elif isinstance(value, dict):
187
+ return json.dumps(value)
188
+ else:
189
+ try:
190
+ return str(value)
191
+ except Exception as e:
192
+ raise ValueError(f"Unsupported type {type(value)} for param formatting: {e}")
193
+
194
+ formatted = {}
195
+ for k, v in params.items():
196
+ formatted[k] = __format(v)
197
+ return formatted
198
+
199
+ async def _handle_aio_response(self, response: aiohttp.ClientResponse):
200
+ request_id = ''
201
+ if (response.status == HTTPStatus.OK and self.stream
202
+ and SSE_CONTENT_TYPE in response.content_type):
203
+ async for is_error, status_code, data in _handle_aio_stream(
204
+ response):
205
+ try:
206
+ output = None
207
+ usage = None
208
+ msg = json.loads(data)
209
+ if not is_error:
210
+ if 'output' in msg:
211
+ output = msg['output']
212
+ if 'usage' in msg:
213
+ usage = msg['usage']
214
+ if 'request_id' in msg:
215
+ request_id = msg['request_id']
216
+ except json.JSONDecodeError:
217
+ yield DashScopeAPIResponse(
218
+ request_id=request_id,
219
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
220
+ code='Unknown',
221
+ message=data)
222
+ continue
223
+ if is_error:
224
+ yield DashScopeAPIResponse(request_id=request_id,
225
+ status_code=status_code,
226
+ code=msg['code'],
227
+ message=msg['message'])
228
+ else:
229
+ if self.encryption and self.encryption.is_valid():
230
+ output = self.encryption.decrypt(output)
231
+ yield DashScopeAPIResponse(request_id=request_id,
232
+ status_code=HTTPStatus.OK,
233
+ output=output,
234
+ usage=usage)
235
+ elif (response.status == HTTPStatus.OK
236
+ and 'multipart' in response.content_type):
237
+ reader = aiohttp.MultipartReader.from_response(response)
238
+ output = {}
239
+ while True:
240
+ part = await reader.next()
241
+ if part is None:
242
+ break
243
+ output[part.name] = await part.read()
244
+ if 'request_id' in output:
245
+ request_id = output['request_id']
246
+ if self.encryption and self.encryption.is_valid():
247
+ output = self.encryption.decrypt(output)
248
+ yield DashScopeAPIResponse(request_id=request_id,
249
+ status_code=HTTPStatus.OK,
250
+ output=output)
251
+ elif response.status == HTTPStatus.OK:
252
+ json_content = await response.json()
253
+ output = None
254
+ usage = None
255
+ if 'output' in json_content and json_content['output'] is not None:
256
+ output = json_content['output']
257
+ # Compatible with wan
258
+ elif 'data' in json_content and json_content['data'] is not None\
259
+ and isinstance(json_content['data'], list)\
260
+ and len(json_content['data']) > 0\
261
+ and 'task_id' in json_content['data'][0]:
262
+ output = json_content
263
+ if 'usage' in json_content:
264
+ usage = json_content['usage']
265
+ if 'request_id' in json_content:
266
+ request_id = json_content['request_id']
267
+ if self.encryption and self.encryption.is_valid():
268
+ output = self.encryption.decrypt(output)
269
+ yield DashScopeAPIResponse(request_id=request_id,
270
+ status_code=HTTPStatus.OK,
271
+ output=output,
272
+ usage=usage)
273
+ else:
274
+ yield await _handle_aiohttp_failed_response(response)
106
275
 
107
276
  def _handle_response(self, response: requests.Response):
108
277
  request_id = ''
109
278
  if (response.status_code == HTTPStatus.OK and self.stream
110
279
  and SSE_CONTENT_TYPE in response.headers.get(
111
280
  'content-type', '')):
112
- for is_error, status_code, data in self._handle_stream(response):
281
+ for is_error, status_code, event in _handle_stream(response):
113
282
  try:
283
+ data = event.data
114
284
  output = None
115
285
  usage = None
116
286
  msg = json.loads(data)
@@ -140,10 +310,15 @@ class HttpRequest(BaseRequest):
140
310
  message=msg['message']
141
311
  if 'message' in msg else None) # noqa E501
142
312
  else:
143
- yield DashScopeAPIResponse(request_id=request_id,
144
- status_code=HTTPStatus.OK,
145
- output=output,
146
- usage=usage)
313
+ if self.flattened_output:
314
+ yield msg
315
+ else:
316
+ if self.encryption and self.encryption.is_valid():
317
+ output = self.encryption.decrypt(output)
318
+ yield DashScopeAPIResponse(request_id=request_id,
319
+ status_code=HTTPStatus.OK,
320
+ output=output,
321
+ usage=usage)
147
322
  elif response.status_code == HTTPStatus.OK:
148
323
  json_content = response.json()
149
324
  logger.debug('Response: %s' % json_content)
@@ -157,38 +332,17 @@ class HttpRequest(BaseRequest):
157
332
  usage = json_content['usage']
158
333
  if 'request_id' in json_content:
159
334
  request_id = json_content['request_id']
160
- yield DashScopeAPIResponse(request_id=request_id,
161
- status_code=HTTPStatus.OK,
162
- output=output,
163
- usage=usage)
164
- else:
165
- if 'application/json' in response.headers.get('content-type', ''):
166
- error = response.json()
167
- if 'request_id' in error:
168
- request_id = error['request_id']
169
- if 'message' not in error:
170
- message = ''
171
- logger.error('Request: %s failed, status: %s' %
172
- (self.url, response.status_code))
173
- else:
174
- message = error['message']
175
- logger.error(
176
- 'Request: %s failed, status: %s, message: %s' %
177
- (self.url, response.status_code, error['message']))
178
- yield DashScopeAPIResponse(
179
- request_id=request_id,
180
- status_code=response.status_code,
181
- output=None,
182
- code=error['code']
183
- if 'code' in error else None, # noqa E501
184
- message=message)
335
+ if self.flattened_output:
336
+ yield json_content
185
337
  else:
186
- msg = response.content
338
+ if self.encryption and self.encryption.is_valid():
339
+ output = self.encryption.decrypt(output)
187
340
  yield DashScopeAPIResponse(request_id=request_id,
188
- status_code=response.status_code,
189
- output=None,
190
- code='Unknown',
191
- message=msg.decode('utf-8'))
341
+ status_code=HTTPStatus.OK,
342
+ output=output,
343
+ usage=usage)
344
+ else:
345
+ yield _handle_http_failed_response(response)
192
346
 
193
347
  def _handle_request(self):
194
348
  try:
@@ -220,6 +374,6 @@ class HttpRequest(BaseRequest):
220
374
  self.method)
221
375
  for rsp in self._handle_response(response):
222
376
  yield rsp
223
- except Exception as e:
377
+ except BaseException as e:
224
378
  logger.error(e)
225
379
  raise e
@@ -1,3 +1,5 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
1
3
  import asyncio
2
4
  import json
3
5
  import uuid
@@ -10,7 +12,7 @@ from dashscope.api_entities.base_request import AioBaseRequest
10
12
  from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
11
13
  from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS,
12
14
  SERVICE_503_MESSAGE,
13
- WEBSOCKET_ERROR_CODE, StreamResultMode)
15
+ WEBSOCKET_ERROR_CODE)
14
16
  from dashscope.common.error import (RequestFailure, UnexpectedMessageReceived,
15
17
  UnknownMessageReceived)
16
18
  from dashscope.common.logging import logger
@@ -29,10 +31,12 @@ class WebSocketRequest(AioBaseRequest):
29
31
  stream: bool = True,
30
32
  ws_stream_mode: str = WebsocketStreamingMode.OUT,
31
33
  is_binary_input: bool = False,
32
- stream_result_mode: str = StreamResultMode.ACCUMULATE,
33
34
  timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
35
+ flattened_output: bool = False,
36
+ pre_task_id=None,
37
+ user_agent: str = '',
34
38
  ) -> None:
35
- super().__init__()
39
+ super().__init__(user_agent=user_agent)
36
40
  """HttpRequest.
37
41
 
38
42
  Args:
@@ -45,12 +49,12 @@ class WebSocketRequest(AioBaseRequest):
45
49
  """
46
50
  self.url = url
47
51
  self.stream = stream
52
+ self.flattened_output = flattened_output
48
53
  if timeout is None:
49
54
  self.timeout = DEFAULT_REQUEST_TIMEOUT_SECONDS
50
55
  else:
51
56
  self.timeout = timeout
52
57
  self.ws_stream_mode = ws_stream_mode
53
- self.stream_result_mode = stream_result_mode
54
58
  self.is_binary_input = is_binary_input
55
59
 
56
60
  self.headers = {
@@ -61,6 +65,7 @@ class WebSocketRequest(AioBaseRequest):
61
65
  self.task_headers = {
62
66
  'streaming': self.ws_stream_mode,
63
67
  }
68
+ self.pre_task_id = pre_task_id
64
69
 
65
70
  def add_headers(self, headers):
66
71
  self.headers = {**self.headers, **headers}
@@ -77,6 +82,10 @@ class WebSocketRequest(AioBaseRequest):
77
82
  pass
78
83
  return output
79
84
 
85
+ async def close(self):
86
+ if self.ws is not None and not self.ws.closed:
87
+ await self.ws.close()
88
+
80
89
  async def aio_call(self):
81
90
  response = self.connection_handler()
82
91
  if self.stream:
@@ -140,7 +149,11 @@ class WebSocketRequest(AioBaseRequest):
140
149
  code=e.name,
141
150
  message=e.message)
142
151
  except aiohttp.ClientConnectorError as e:
143
- raise e
152
+ logger.exception(e)
153
+ yield DashScopeAPIResponse(request_id='',
154
+ status_code=-1,
155
+ code='ClientConnectorError',
156
+ message=str(e))
144
157
  except aiohttp.WSServerHandshakeError as e:
145
158
  code = e.status
146
159
  msg = e.message
@@ -154,6 +167,13 @@ class WebSocketRequest(AioBaseRequest):
154
167
  status_code=code,
155
168
  code=code,
156
169
  message=msg)
170
+ except BaseException as e:
171
+ logger.exception(e)
172
+ yield DashScopeAPIResponse(request_id='',
173
+ status_code=-1,
174
+ code='Unknown',
175
+ message='Error type: %s, message: %s' %
176
+ (type(e), e))
157
177
 
158
178
  def _to_DashScopeAPIResponse(self, task_id, is_binary, result):
159
179
  if is_binary:
@@ -177,7 +197,6 @@ class WebSocketRequest(AioBaseRequest):
177
197
  # check if request stream data, re return an iterator,
178
198
  # otherwise we collect data and return user.
179
199
  # no matter what, the response is streaming
180
- final_payload = None
181
200
  is_binary_output = False
182
201
  while True:
183
202
  msg = await ws.receive()
@@ -186,36 +205,21 @@ class WebSocketRequest(AioBaseRequest):
186
205
  msg_json = msg.json()
187
206
  logger.debug('Receive %s event' % msg_json[HEADER][EVENT_KEY])
188
207
  if msg_json[HEADER][EVENT_KEY] == EventType.GENERATED:
189
- if final_payload is None:
190
- final_payload = []
191
208
  payload = msg_json['payload']
192
- if self.stream:
193
- yield False, payload
194
- elif self.stream_result_mode == StreamResultMode.ACCUMULATE: # noqa E501
195
- final_payload = payload
196
- else:
197
- final_payload.append(payload)
209
+ yield False, payload
198
210
  elif msg_json[HEADER][EVENT_KEY] == EventType.FINISHED:
199
- if final_payload is None:
200
- final_payload = []
201
211
  payload = None
202
212
  if 'payload' in msg_json:
203
213
  payload = msg_json['payload']
204
214
  logger.debug(payload)
205
215
  if payload:
206
- if self.stream:
207
- yield False, payload
208
- elif self.stream_result_mode == StreamResultMode.ACCUMULATE: # noqa E501
209
- yield False, payload
210
- else:
211
- final_payload.extend(payload)
212
- yield False, final_payload
216
+ yield False, payload
213
217
  else:
214
218
  if not self.stream:
215
219
  if is_binary_output:
216
- yield True, final_payload
220
+ yield True, payload
217
221
  else:
218
- yield False, final_payload
222
+ yield False, payload
219
223
  break
220
224
  elif msg_json[HEADER][EVENT_KEY] == EventType.FAILED:
221
225
  self._on_failed(msg_json)
@@ -225,14 +229,7 @@ class WebSocketRequest(AioBaseRequest):
225
229
  raise UnknownMessageReceived(error)
226
230
  elif msg.type == aiohttp.WSMsgType.BINARY:
227
231
  is_binary_output = True
228
- if final_payload is None:
229
- final_payload = b''
230
- if self.stream:
231
- yield True, msg.data
232
- elif self.stream_result_mode == StreamResultMode.ACCUMULATE:
233
- final_payload = msg.data
234
- else:
235
- final_payload += msg.data
232
+ yield True, msg.data
236
233
 
237
234
  def _on_failed(self, details):
238
235
  error = RequestFailure(request_id=details[HEADER][TASK_ID],
@@ -243,17 +240,22 @@ class WebSocketRequest(AioBaseRequest):
243
240
  raise error
244
241
 
245
242
  async def _start_task(self, ws):
246
- self.task_headers['task_id'] = uuid.uuid4().hex # create task id.
243
+ if self.pre_task_id is not None:
244
+ self.task_headers['task_id'] = self.pre_task_id
245
+ else:
246
+ self.task_headers['task_id'] = uuid.uuid4().hex # create task id.
247
247
  task_header = {**self.task_headers, ACTION_KEY: ActionType.START}
248
248
  # for binary data, the start action has no input, only parameters.
249
249
  start_data = self.data.get_websocket_start_data()
250
250
  message = self._build_up_message(task_header, start_data)
251
+ logger.debug('Send start task: {}'.format(message))
251
252
  await ws.send_str(message)
252
253
 
253
254
  async def _send_finished_task(self, ws):
254
255
  task_header = {**self.task_headers, ACTION_KEY: ActionType.FINISHED}
255
256
  payload = {'input': {}}
256
257
  message = self._build_up_message(task_header, payload)
258
+ logger.debug('Send finish task: {}'.format(message))
257
259
  await ws.send_str(message)
258
260
 
259
261
  async def _send_continue_task_data(self, ws):
@@ -266,12 +268,19 @@ class WebSocketRequest(AioBaseRequest):
266
268
  if len(input) > 0:
267
269
  if isinstance(input, bytes):
268
270
  await ws.send_bytes(input)
271
+ logger.debug(
272
+ 'Send continue task with bytes: {}'.format(
273
+ len(input)))
269
274
  else:
270
275
  await ws.send_bytes(list(input.values())[0])
276
+ logger.debug(
277
+ 'Send continue task with list[byte]: {}'.format(
278
+ len(input)))
271
279
  else:
272
280
  if len(input) > 0:
273
281
  message = self._build_up_message(headers=headers,
274
282
  payload=input)
283
+ logger.debug('Send continue task: {}'.format(message))
275
284
  await ws.send_str(message)
276
285
  await asyncio.sleep(0.000001)
277
286
 
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from .application import Application
4
+
5
+ __all__ = [Application]