local-coze 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- local_coze/__init__.py +110 -0
- local_coze/cli/__init__.py +3 -0
- local_coze/cli/chat.py +126 -0
- local_coze/cli/cli.py +34 -0
- local_coze/cli/constants.py +7 -0
- local_coze/cli/db.py +81 -0
- local_coze/cli/embedding.py +193 -0
- local_coze/cli/image.py +162 -0
- local_coze/cli/knowledge.py +195 -0
- local_coze/cli/search.py +198 -0
- local_coze/cli/utils.py +41 -0
- local_coze/cli/video.py +191 -0
- local_coze/cli/video_edit.py +888 -0
- local_coze/cli/voice.py +351 -0
- local_coze/core/__init__.py +25 -0
- local_coze/core/client.py +253 -0
- local_coze/core/config.py +58 -0
- local_coze/core/exceptions.py +67 -0
- local_coze/database/__init__.py +29 -0
- local_coze/database/client.py +170 -0
- local_coze/database/migration.py +342 -0
- local_coze/embedding/__init__.py +31 -0
- local_coze/embedding/client.py +350 -0
- local_coze/embedding/models.py +130 -0
- local_coze/image/__init__.py +19 -0
- local_coze/image/client.py +110 -0
- local_coze/image/models.py +163 -0
- local_coze/knowledge/__init__.py +19 -0
- local_coze/knowledge/client.py +148 -0
- local_coze/knowledge/models.py +45 -0
- local_coze/llm/__init__.py +25 -0
- local_coze/llm/client.py +317 -0
- local_coze/llm/models.py +48 -0
- local_coze/memory/__init__.py +14 -0
- local_coze/memory/client.py +176 -0
- local_coze/s3/__init__.py +12 -0
- local_coze/s3/client.py +580 -0
- local_coze/s3/models.py +18 -0
- local_coze/search/__init__.py +19 -0
- local_coze/search/client.py +183 -0
- local_coze/search/models.py +57 -0
- local_coze/video/__init__.py +17 -0
- local_coze/video/client.py +347 -0
- local_coze/video/models.py +39 -0
- local_coze/video_edit/__init__.py +23 -0
- local_coze/video_edit/examples.py +340 -0
- local_coze/video_edit/frame_extractor.py +176 -0
- local_coze/video_edit/models.py +362 -0
- local_coze/video_edit/video_edit.py +631 -0
- local_coze/voice/__init__.py +17 -0
- local_coze/voice/asr.py +82 -0
- local_coze/voice/models.py +86 -0
- local_coze/voice/tts.py +94 -0
- local_coze-0.0.1.dist-info/METADATA +636 -0
- local_coze-0.0.1.dist-info/RECORD +58 -0
- local_coze-0.0.1.dist-info/WHEEL +4 -0
- local_coze-0.0.1.dist-info/entry_points.txt +3 -0
- local_coze-0.0.1.dist-info/licenses/LICENSE +21 -0
local_coze/cli/voice.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
from coze_coding_utils.runtime_ctx.context import new_context
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.panel import Panel
|
|
11
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
12
|
+
from rich.table import Table
|
|
13
|
+
|
|
14
|
+
from ..core.config import Config
|
|
15
|
+
from ..voice.asr import ASRClient
|
|
16
|
+
from ..voice.models import TTSConfig
|
|
17
|
+
from ..voice.tts import TTSClient
|
|
18
|
+
from .constants import RUN_MODE_HEADER, RUN_MODE_TEST
|
|
19
|
+
|
|
20
|
+
# Module-level rich console used by the commands below for all terminal output.
console = Console()

# Maps TTS speaker ids to the human-readable label shown in the `tts` summary
# table. Ids missing from this mapping are displayed as the raw id.
COMMON_SPEAKERS = {
    "zh_female_xueayi_saturn_bigtts": "儿童绘本 (有声阅读)",
    "zh_female_vv_uranus_bigtts": "vivi (通用场景, 中英)",
    "zh_male_dayi_saturn_bigtts": "大壹 (视频配音)",
    "zh_female_mizai_saturn_bigtts": "黑猫侦探社咪仔 (视频配音)",
    "zh_female_jitangnv_saturn_bigtts": "鸡汤女 (视频配音)",
    "zh_female_meilinvyou_saturn_bigtts": "魅力女友 (视频配音)",
    "zh_female_santongyongns_saturn_bigtts": "流畅女声 (视频配音)",
    "zh_male_ruyayichen_saturn_bigtts": "儒雅逸辰 (视频配音)",
    "zh_female_xiaohe_uranus_bigtts": "小何 (通用场景, 默认)",
    "zh_male_m191_uranus_bigtts": "云舟 (通用场景)",
    "zh_male_taocheng_uranus_bigtts": "小天 (通用场景)",
    "saturn_zh_female_keainvsheng_tob": "可爱女生 (角色扮演)",
    "saturn_zh_female_tiaopigongzhu_tob": "调皮公主 (角色扮演)",
    "saturn_zh_male_shuanglangshaonian_tob": "爽朗少年 (角色扮演)",
    "saturn_zh_male_tiancaitongzhuo_tob": "天才同桌 (角色扮演)",
    "saturn_zh_female_cancan_tob": "知性灿灿 (角色扮演)",
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@click.command()
@click.argument("text")
@click.option(
    "--output", "-o", required=True, type=click.Path(), help="输出音频文件路径"
)
@click.option("--uid", "-u", default="cli_user", help="用户唯一标识")
@click.option("--speaker", "-s", default=TTSConfig.DEFAULT_SPEAKER, help="音色选择")
@click.option(
    "--format",
    "-f",
    type=click.Choice(["mp3", "pcm", "ogg_opus"]),
    default="mp3",
    help="音频格式",
)
@click.option(
    "--sample-rate",
    type=int,
    default=24000,
    help="采样率 (8000/16000/22050/24000/32000/44100/48000)",
)
@click.option("--speech-rate", type=int, default=0, help="语速 (-50 到 100)")
@click.option("--loudness-rate", type=int, default=0, help="音量 (-50 到 100)")
@click.option("--ssml", is_flag=True, help="使用 SSML 格式")
@click.option("--mock", is_flag=True, help="使用 mock 模式(测试运行)")
@click.option(
    "--header",
    "-H",
    multiple=True,
    help="自定义 HTTP 请求头 (格式: 'Key: Value' 或 'Key=Value',可多次使用)",
)
@click.option("--verbose", "-v", is_flag=True, help="显示详细的 HTTP 请求日志")
def tts(
    text,
    output,
    uid,
    speaker,
    format,
    sample_rate,
    speech_rate,
    loudness_rate,
    ssml,
    mock,
    header,
    verbose,
):
    """语音合成 (Text-to-Speech)

    将文本转换为语音音频文件。

    音色列表 (按场景分类):

    有声阅读:
      - zh_female_xueayi_saturn_bigtts (儿童绘本)

    通用场景:
      - zh_female_xiaohe_uranus_bigtts (小何, 默认)
      - zh_female_vv_uranus_bigtts (vivi, 支持中英)
      - zh_male_m191_uranus_bigtts (云舟)
      - zh_male_taocheng_uranus_bigtts (小天)

    视频配音:
      - zh_male_dayi_saturn_bigtts (大壹)
      - zh_female_mizai_saturn_bigtts (黑猫侦探社咪仔)
      - zh_female_jitangnv_saturn_bigtts (鸡汤女)
      - zh_female_meilinvyou_saturn_bigtts (魅力女友)
      - zh_female_santongyongns_saturn_bigtts (流畅女声)
      - zh_male_ruyayichen_saturn_bigtts (儒雅逸辰)

    角色扮演:
      - saturn_zh_female_keainvsheng_tob (可爱女生)
      - saturn_zh_female_tiaopigongzhu_tob (调皮公主)
      - saturn_zh_male_shuanglangshaonian_tob (爽朗少年)
      - saturn_zh_male_tiancaitongzhuo_tob (天才同桌)
      - saturn_zh_female_cancan_tob (知性灿灿)

    示例:
      coze-coding-ai tts "你好,欢迎使用" -o hello.mp3
      coze-coding-ai tts "测试视频配音" -o test.mp3 -s zh_male_dayi_saturn_bigtts
      coze-coding-ai tts "儿童故事" -o story.mp3 -s zh_female_xueayi_saturn_bigtts --speech-rate 20
    """
    # The docstring above is the click help text shown by `--help`, so it is
    # kept verbatim in the original language.
    #
    # Text-to-speech command: synthesizes TEXT through TTSClient, downloads the
    # returned audio URL into --output and prints a summary panel.
    #
    # Parameters mirror the click options declared above. `ssml` sends TEXT as
    # the `ssml` argument instead of `text`; `mock` attaches a test context and
    # the run-mode test header; `header` holds the raw -H values.
    #
    # Any exception is reported on the console and converted to click.Abort.
    from rich.markup import escape

    try:
        from .utils import parse_headers

        config = Config()

        ctx = None
        custom_headers = parse_headers(header) or {}

        if mock:
            ctx = new_context(method="tts.generate", headers=custom_headers)
            custom_headers[RUN_MODE_HEADER] = RUN_MODE_TEST
            console.print("[yellow]🧪 Mock 模式已启用(测试运行)[/yellow]")

        client = TTSClient(
            config, ctx=ctx, custom_headers=custom_headers, verbose=verbose
        )

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task("[cyan]正在合成语音...", total=None)

            # SSML and plain text differ only in the keyword carrying the input.
            source_kwargs = {"ssml": text} if ssml else {"text": text}
            audio_url, audio_size = client.synthesize(
                uid=uid,
                speaker=speaker,
                audio_format=format,
                sample_rate=sample_rate,
                speech_rate=speech_rate,
                loudness_rate=loudness_rate,
                **source_kwargs,
            )

            progress.update(task, description="[green]✓ 语音合成完成")

        # abspath() always yields a non-empty parent, so this also covers a
        # bare filename (the parent is then the current working directory).
        os.makedirs(os.path.dirname(os.path.abspath(output)), exist_ok=True)

        if audio_url:
            import requests

            # Bound the download so a stalled server cannot hang the CLI.
            response = requests.get(audio_url, timeout=config.timeout)
            response.raise_for_status()
            with open(output, "wb") as f:
                f.write(response.content)
            file_size = len(response.content)
        else:
            # NOTE(review): no file is written on this path although the
            # summary below still lists `output` - confirm whether
            # synthesize() can return an empty URL and how the audio should
            # then be obtained.
            file_size = audio_size

        table = Table(show_header=False, box=None, padding=(0, 2))
        table.add_column("Key", style="cyan")
        table.add_column("Value", style="white", no_wrap=False, overflow="fold")

        display_text = text[:50] + "..." if len(text) > 50 else text
        table.add_row("文本", display_text)
        table.add_row("音色", COMMON_SPEAKERS.get(speaker, speaker))
        table.add_row("格式", format.upper())
        table.add_row("采样率", f"{sample_rate} Hz")
        if speech_rate != 0:
            table.add_row("语速", f"{speech_rate:+d}")
        if loudness_rate != 0:
            table.add_row("音量", f"{loudness_rate:+d}")
        table.add_row("文件", output)
        # Treat a missing size as 0 so the summary itself never fails.
        table.add_row("大小", f"{(file_size or 0) / 1024:.1f} KB")
        if audio_url:
            table.add_row("URL", audio_url)

        console.print()
        console.print(
            Panel(
                table,
                title="[bold green]语音合成完成[/bold green]",
                border_style="green",
            )
        )

    except Exception as e:
        # Escape the message: text such as "[/x]" would otherwise be parsed as
        # rich markup and raise inside this handler.
        console.print(f"[red]✗ 错误: {escape(str(e))}[/red]")
        raise click.Abort() from e
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@click.command()
@click.argument("audio")
@click.option("--uid", "-u", default="cli_user", help="用户唯一标识")
@click.option("--output", "-o", type=click.Path(), help="输出文本文件路径")
@click.option(
    "--format",
    "-f",
    type=click.Choice(["text", "json"]),
    default="text",
    help="输出格式",
)
@click.option("--base64", is_flag=True, help="将本地文件转为 base64 上传")
@click.option("--mock", is_flag=True, help="使用 mock 模式(测试运行)")
@click.option(
    "--header",
    "-H",
    multiple=True,
    help="自定义 HTTP 请求头 (格式: 'Key: Value' 或 'Key=Value',可多次使用)",
)
@click.option("--verbose", "-v", is_flag=True, help="显示详细的 HTTP 请求日志")
def asr(audio, uid, output, format, base64, mock, header, verbose):
    """语音识别 (Automatic Speech Recognition)

    将语音音频转换为文本。

    音频要求:
      - 音频时长 ≤ 2小时
      - 音频大小 ≤ 100MB
      - 支持编码: WAV/MP3/OGG OPUS

    支持:
      - 本地音频文件
      - 音频 URL
      - Base64 编码上传

    示例:
      coze-coding-ai asr ./audio.mp3
      coze-coding-ai asr https://example.com/audio.mp3
      coze-coding-ai asr ./audio.mp3 -o result.txt
      coze-coding-ai asr ./audio.mp3 -f json
      coze-coding-ai asr audio.mp3 --base64 --output result.txt
    """
    # The docstring above is the click help text shown by `--help`, so it is
    # kept verbatim in the original language.
    #
    # Speech-recognition command: AUDIO is either an http(s) URL, passed to
    # ASRClient.recognize as `url`, or a local file, which must be combined
    # with --base64 and is then sent as base64 data. The result is printed as
    # a panel (text) or JSON, and optionally written to --output.
    #
    # Note that the `base64` flag parameter shadows the stdlib module name
    # inside this function, hence the aliased local import below.
    #
    # Any exception is reported on the console and converted to click.Abort.
    from rich.markup import escape

    try:
        from .utils import parse_headers

        config = Config()

        ctx = None
        custom_headers = parse_headers(header) or {}

        if mock:
            ctx = new_context(method="asr.recognize", headers=custom_headers)
            custom_headers[RUN_MODE_HEADER] = RUN_MODE_TEST
            console.print("[yellow]🧪 Mock 模式已启用(测试运行)[/yellow]")

        client = ASRClient(
            config, ctx=ctx, custom_headers=custom_headers, verbose=verbose
        )

        audio_url = None
        audio_base64 = None

        if audio.startswith(("http://", "https://")):
            audio_url = audio
            console.print(f"[cyan]正在识别 URL 音频:[/cyan] {audio}")
        else:
            if not os.path.exists(audio):
                raise FileNotFoundError(f"音频文件不存在: {audio}")

            if not base64:
                raise ValueError(
                    "本地文件需要先上传到可访问的 URL,或使用 --base64 选项"
                )

            import base64 as b64_module

            console.print(f"[cyan]正在读取并编码音频文件:[/cyan] {audio}")
            with open(audio, "rb") as f:
                audio_data = f.read()
            audio_base64 = b64_module.b64encode(audio_data).decode("utf-8")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task("[cyan]正在识别语音...", total=None)

            text, data = client.recognize(
                uid=uid, url=audio_url, base64_data=audio_base64
            )

            progress.update(task, description="[green]✓ 识别完成")

        console.print()

        # `or {}` also covers an explicit `"result": null` in the payload,
        # which a plain .get("result", {}) default would not.
        result_data = data.get("result") or {}

        if output:
            # Create the parent directory so -o works for not-yet-existing
            # folders, matching the behaviour of the `tts` command.
            os.makedirs(os.path.dirname(os.path.abspath(output)), exist_ok=True)

        if format == "json":
            result = {
                "text": text,
                "duration": result_data.get("duration"),
                "utterances": result_data.get("utterances", []),
            }

            if output:
                with open(output, "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
            else:
                console.print_json(data=result)
        else:
            console.print(
                Panel(
                    text,
                    title="[bold green]识别结果[/bold green]",
                    border_style="green",
                    padding=(1, 2),
                )
            )

            duration = result_data.get("duration")
            if duration:
                # The value is divided by 1000 and labelled as seconds, i.e.
                # it is treated as milliseconds here.
                console.print(f"\n[dim]音频时长: {duration / 1000:.1f} 秒[/dim]")

            if output:
                with open(output, "w", encoding="utf-8") as f:
                    f.write(text)

        if output:
            console.print(f"[green]✓[/green] 结果已保存到: {output}")

    except Exception as e:
        # Escape the message: text such as "[/x]" would otherwise be parsed as
        # rich markup and raise inside this handler.
        console.print(f"[red]✗ 错误: {escape(str(e))}[/red]")
        raise click.Abort() from e
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from .client import BaseClient
|
|
2
|
+
from .config import Config
|
|
3
|
+
from .exceptions import (
|
|
4
|
+
CozeSDKError,
|
|
5
|
+
ConfigurationError,
|
|
6
|
+
APIError,
|
|
7
|
+
NetworkError,
|
|
8
|
+
ValidationError
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Re-export the version defined by the parent package; fall back to a
# placeholder when that import fails.
try:
    from .. import __version__
except ImportError:
    __version__ = "0.0.0"

# Names re-exported by this package.
__all__ = [
    "BaseClient",
    "Config",
    "CozeSDKError",
    "ConfigurationError",
    "APIError",
    "NetworkError",
    "ValidationError",
    "__version__",
]
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import time
|
|
3
|
+
from typing import Dict, Optional
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
from coze_coding_utils.runtime_ctx.context import Context, default_headers
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
from rich.panel import Panel
|
|
9
|
+
from rich.syntax import Syntax
|
|
10
|
+
|
|
11
|
+
from .config import Config
|
|
12
|
+
from .exceptions import APIError, NetworkError
|
|
13
|
+
|
|
14
|
+
# Module-level rich console used by BaseClient for verbose request/response logging.
console = Console()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseClient:
    """Base class for HTTP-backed clients.

    Merges context, custom and configuration headers, sends requests through
    ``requests`` with retries on transport errors, optionally logs the
    exchange to the console, and converts failures into ``NetworkError`` or
    ``APIError``.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        ctx: Optional[Context] = None,
        custom_headers: Optional[Dict[str, str]] = None,
        verbose: bool = False,
    ):
        """Store the client configuration.

        Args:
            config: Connection settings; a default ``Config()`` is built when
                omitted.
            ctx: Optional runtime context whose default headers are added to
                every request.
            custom_headers: Extra headers added to every request.
            verbose: When true, requests and responses are printed to the
                console.
        """
        if config is None:
            config = Config()
        self.config = config
        self.ctx = ctx
        self.custom_headers = custom_headers or {}
        self.verbose = verbose

    def _merge_request_headers(
        self, headers: Optional[Dict[str, str]] = None
    ) -> Dict[str, str]:
        """Build the final header dict for one request.

        Later sources win: context defaults, then ``custom_headers``, then the
        result of ``config.get_headers(headers)`` (per-call headers plus the
        configuration defaults).
        """
        merged: Dict[str, str] = {}
        if self.ctx is not None:
            merged.update(default_headers(self.ctx))
        if self.custom_headers:
            merged.update(self.custom_headers)
        merged.update(self.config.get_headers(headers))
        return merged

    def _request(
        self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
    ) -> dict:
        """Send a request and return the decoded JSON body.

        Raises:
            NetworkError: When every attempt fails at the transport level.
            APIError: When the body is not JSON or the status is an HTTP error.
        """
        response = self._make_request(
            method=method,
            url=url,
            headers=self._merge_request_headers(headers),
            **kwargs,
        )
        return self._handle_response(response)

    def _request_with_response(
        self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
    ) -> requests.Response:
        """Send a request and return the raw ``requests.Response`` unchecked."""
        return self._make_request(
            method=method,
            url=url,
            headers=self._merge_request_headers(headers),
            **kwargs,
        )

    def _request_stream(
        self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
    ) -> requests.Response:
        """Send a request with ``stream=True`` and return the raw response."""
        kwargs["stream"] = True
        return self._make_request(
            method=method,
            url=url,
            headers=self._merge_request_headers(headers),
            **kwargs,
        )

    def _sanitize_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
        """Return a copy of ``headers`` with the Authorization value masked.

        Header names are matched case-insensitively. Tokens longer than 16
        characters keep their first 8 and last 4 characters; shorter ones are
        fully masked.
        """
        sanitized = headers.copy()
        for name, value in headers.items():
            if name.lower() != "authorization":
                continue
            token = value[7:] if value.startswith("Bearer ") else value
            if len(token) > 16:
                sanitized[name] = f"Bearer {token[:8]}...{token[-4:]}"
            else:
                sanitized[name] = "Bearer ****"
        return sanitized

    def _log_request(self, method: str, url: str, **kwargs):
        """Print the outgoing request (verbose mode only)."""
        if not self.verbose:
            return

        parts = [f"[bold cyan]{method}[/bold cyan] {url}\n"]

        sanitized_headers = self._sanitize_headers(kwargs.get("headers") or {})
        if sanitized_headers:
            parts.append("[bold]Headers:[/bold]")
            parts.extend(f"  {key}: {value}" for key, value in sanitized_headers.items())
            parts.append("")

        body = kwargs.get("json")
        if body:
            parts.append("[bold]Body:[/bold]")
            parts.append("")

        console.print(
            Panel(
                "\n".join(parts),
                title="[bold green]HTTP Request[/bold green]",
                border_style="green",
            )
        )

        if body:
            try:
                json_str = json.dumps(body, ensure_ascii=False, indent=2)
            except Exception:
                # Logging is best effort: show the plain repr instead of
                # silently dropping a body that is not JSON-serialisable.
                console.print(str(body), markup=False)
            else:
                console.print(
                    Syntax(json_str, "json", theme="monokai", line_numbers=False)
                )
            console.print()

    def _log_response(self, response: requests.Response, is_stream: bool = False):
        """Print the response status, headers and body (verbose mode only).

        For streaming responses the body is not read; a notice is shown
        instead.
        """
        if not self.verbose:
            return

        parts = [f"[bold]Status:[/bold] {response.status_code} {response.reason}\n"]

        sanitized_response_headers = self._sanitize_headers(dict(response.headers))
        if sanitized_response_headers:
            parts.append("[bold]Response Headers:[/bold]")
            parts.extend(
                f"  {key}: {value}"
                for key, value in sanitized_response_headers.items()
            )
            parts.append("")

        if is_stream:
            parts.append("[yellow]⚡ Streaming response - body not shown[/yellow]")
        else:
            parts.append("[bold]Body:[/bold]")

        console.print(
            Panel(
                "\n".join(parts),
                title="[bold blue]HTTP Response[/bold blue]",
                border_style="blue",
            )
        )

        if not is_stream:
            try:
                json_str = json.dumps(response.json(), ensure_ascii=False, indent=2)
                if len(json_str) > 2000:
                    json_str = json_str[:2000] + "\n... (truncated)"
                console.print(
                    Syntax(json_str, "json", theme="monokai", line_numbers=False)
                )
            except Exception:
                # Not JSON: show a short excerpt of the raw text instead.
                body_text = response.text[:500]
                if len(response.text) > 500:
                    body_text += "... (truncated)"
                console.print(f"  {body_text}")

        console.print()

    def _make_request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Send the request, retrying on ``requests`` transport errors.

        Up to ``config.retry_times`` attempts are made (always at least one),
        sleeping ``retry_delay * attempt_number`` between attempts.
        ``config.timeout`` applies unless the caller passes its own
        ``timeout``.

        Raises:
            NetworkError: When every attempt raised a ``RequestException``.
        """
        is_stream = kwargs.get("stream", False)
        # setdefault avoids a duplicate-keyword TypeError when the caller
        # supplies an explicit timeout.
        kwargs.setdefault("timeout", self.config.timeout)
        # Always try once, even if retry_times is configured as 0 or less;
        # otherwise the loop is skipped and there is no error to raise.
        attempts = max(1, self.config.retry_times)

        self._log_request(method, url, **kwargs)

        last_error = None
        cause = None
        for attempt in range(attempts):
            try:
                response = requests.request(method=method, url=url, **kwargs)
                # Log whichever attempt succeeds, not only the first one.
                self._log_response(response, is_stream=is_stream)
                return response
            except requests.exceptions.RequestException as e:
                cause = e
                last_error = NetworkError(str(e), e)
                if attempt < attempts - 1:
                    time.sleep(self.config.retry_delay * (attempt + 1))

        raise last_error from cause

    def _handle_response(self, response: requests.Response) -> dict:
        """Decode the JSON body and raise ``APIError`` on failure.

        Raises:
            APIError: When the body cannot be parsed as JSON, or when the
                status code signals an HTTP error (the parsed body is attached
                as ``response_data``).
        """
        try:
            data = response.json()
        except Exception as e:
            raise APIError(
                f"响应解析失败: {str(e)}, logid: {response.headers.get('X-Tt-Logid')}, 响应内容: {response.text[:200]}",
                status_code=response.status_code,
            ) from e

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            error_msg = (
                f"HTTP 错误: {str(e)}, logid: {response.headers.get('X-Tt-Logid')}"
            )
            if data:
                error_msg += f", 响应数据: {data}"
            raise APIError(
                error_msg, status_code=response.status_code, response_data=data
            ) from e

        return data

    def __enter__(self):
        """Allow use as a context manager; returns the client itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """No resources to release; exceptions are not suppressed."""
        pass
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from .exceptions import ConfigurationError
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Config:
    """Connection settings for the SDK clients.

    Values not passed explicitly are read from environment variables. The API
    key and both base URLs must resolve to a non-empty value, otherwise
    ``ConfigurationError`` is raised during construction.
    """

    DEFAULT_RETRY_TIMES = 1
    DEFAULT_RETRY_DELAY = 1.0
    DEFAULT_TIMEOUT = 900

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        base_model_url: Optional[str] = None,
        retry_times: int = DEFAULT_RETRY_TIMES,
        retry_delay: float = DEFAULT_RETRY_DELAY,
        timeout: int = DEFAULT_TIMEOUT,
    ):
        """Resolve credentials and endpoints, then validate them.

        Args:
            api_key: API key; falls back to ``COZE_WORKLOAD_IDENTITY_API_KEY``.
            base_url: Base URL; falls back to ``COZE_INTEGRATION_BASE_URL``.
            base_model_url: Model base URL; falls back to
                ``COZE_INTEGRATION_MODEL_BASE_URL``.
            retry_times: Stored as ``self.retry_times``.
            retry_delay: Stored as ``self.retry_delay``.
            timeout: Stored as ``self.timeout``.

        Raises:
            ConfigurationError: When a required value is neither passed nor
                present in the environment.
        """
        # Explicit arguments win; the environment is consulted only for
        # missing or empty values, in this order.
        self.api_key = (
            api_key
            if api_key
            else self._get_env_var("COZE_WORKLOAD_IDENTITY_API_KEY")
        )
        self.base_url = (
            base_url if base_url else self._get_env_var("COZE_INTEGRATION_BASE_URL")
        )
        self.base_model_url = (
            base_model_url
            if base_model_url
            else self._get_env_var("COZE_INTEGRATION_MODEL_BASE_URL")
        )

        self.retry_times = retry_times
        self.retry_delay = retry_delay
        self.timeout = timeout

        self._validate()

    def _get_env_var(self, key: str) -> str:
        """Return the non-empty value of environment variable ``key``.

        Raises:
            ConfigurationError: When the variable is unset or empty.
        """
        found = os.environ.get(key)
        if found:
            return found
        raise ConfigurationError(
            f"环境变量 {key} 未设置,请确保已正确配置", missing_key=key
        )

    def _validate(self):
        """Check that an API key and at least one base URL are present.

        Raises:
            ConfigurationError: When the API key is missing, or when both
                base URLs are missing.
        """
        if not self.api_key:
            raise ConfigurationError("API Key 未配置")
        if not (self.base_url or self.base_model_url):
            raise ConfigurationError("Base URL 未配置")

    def get_headers(self, ctx_headers: Optional[dict] = None) -> dict:
        """Return request headers: ``ctx_headers`` plus defaults for missing keys.

        Keys already present in ``ctx_headers`` are kept; ``Content-Type``,
        ``Authorization`` and ``X-Client-Sdk`` are added only when absent.
        The input mapping is not modified.
        """
        try:
            from .. import __version__ as sdk_version
        except ImportError:
            sdk_version = "0.0.0"

        merged = dict(ctx_headers) if ctx_headers else {}
        fallback_headers = (
            ("Content-Type", "application/json"),
            ("Authorization", f"Bearer {self.api_key}"),
            ("X-Client-Sdk", f"coze-coding-dev-sdk-python/{sdk_version}"),
        )
        for name, value in fallback_headers:
            merged.setdefault(name, value)
        return merged
|