langchain-camb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_camb/__init__.py +81 -0
- langchain_camb/toolkits/__init__.py +5 -0
- langchain_camb/toolkits/camb_toolkit.py +148 -0
- langchain_camb/tools/__init__.py +40 -0
- langchain_camb/tools/audio_separation.py +189 -0
- langchain_camb/tools/base.py +161 -0
- langchain_camb/tools/text_to_sound.py +156 -0
- langchain_camb/tools/transcription.py +189 -0
- langchain_camb/tools/translated_tts.py +340 -0
- langchain_camb/tools/translation.py +150 -0
- langchain_camb/tools/tts.py +182 -0
- langchain_camb/tools/voice_clone.py +152 -0
- langchain_camb/tools/voice_list.py +108 -0
- langchain_camb/version.py +3 -0
- langchain_camb-0.1.0.dist-info/METADATA +307 -0
- langchain_camb-0.1.0.dist-info/RECORD +17 -0
- langchain_camb-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""LangChain integration for CAMB AI.
|
|
2
|
+
|
|
3
|
+
CAMB AI provides multilingual audio and localization services including:
|
|
4
|
+
- Text-to-Speech (140+ languages)
|
|
5
|
+
- Translation
|
|
6
|
+
- Transcription with speaker identification
|
|
7
|
+
- Voice cloning
|
|
8
|
+
- Text-to-Sound generation
|
|
9
|
+
- Audio separation
|
|
10
|
+
|
|
11
|
+
Example:
|
|
12
|
+
```python
|
|
13
|
+
from langchain_camb import CambToolkit, CambTTSTool
|
|
14
|
+
from langchain_openai import ChatOpenAI
|
|
15
|
+
from langgraph.prebuilt import create_react_agent
|
|
16
|
+
|
|
17
|
+
# Use individual tools
|
|
18
|
+
tts = CambTTSTool()
|
|
19
|
+
audio_path = tts.invoke({
|
|
20
|
+
"text": "Hello, world!",
|
|
21
|
+
"language": "en-us",
|
|
22
|
+
"voice_id": 147320
|
|
23
|
+
})
|
|
24
|
+
|
|
25
|
+
# Or use the toolkit with an agent
|
|
26
|
+
toolkit = CambToolkit()
|
|
27
|
+
agent = create_react_agent(ChatOpenAI(), toolkit.get_tools())
|
|
28
|
+
agent.invoke({
|
|
29
|
+
"messages": [{"role": "user", "content": "Say hello in Spanish"}]
|
|
30
|
+
})
|
|
31
|
+
```
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from langchain_camb.tools import (
|
|
35
|
+
AudioSeparationInput,
|
|
36
|
+
CambAudioSeparationTool,
|
|
37
|
+
CambBaseTool,
|
|
38
|
+
CambTextToSoundTool,
|
|
39
|
+
CambTranscriptionTool,
|
|
40
|
+
CambTranslatedTTSTool,
|
|
41
|
+
CambTranslationTool,
|
|
42
|
+
CambTTSTool,
|
|
43
|
+
CambVoiceCloneTool,
|
|
44
|
+
CambVoiceListTool,
|
|
45
|
+
TextToSoundInput,
|
|
46
|
+
TranscriptionInput,
|
|
47
|
+
TranslatedTTSInput,
|
|
48
|
+
TranslationInput,
|
|
49
|
+
TTSInput,
|
|
50
|
+
VoiceCloneInput,
|
|
51
|
+
VoiceListInput,
|
|
52
|
+
)
|
|
53
|
+
from langchain_camb.toolkits import CambToolkit
|
|
54
|
+
from langchain_camb.version import __version__
|
|
55
|
+
|
|
56
|
+
# Public API of the package, assembled group by group. Only the plain
# ``__all__ = [...]`` / ``__all__ += [...]`` forms are used so static type
# checkers can still read the exported names.
__all__ = ["__version__"]

# Toolkit bundling every CAMB tool
__all__ += ["CambToolkit"]

# Shared base class
__all__ += ["CambBaseTool"]

# Concrete tools
__all__ += [
    "CambTTSTool",
    "CambTranslatedTTSTool",
    "CambTranslationTool",
    "CambTranscriptionTool",
    "CambVoiceListTool",
    "CambVoiceCloneTool",
    "CambTextToSoundTool",
    "CambAudioSeparationTool",
]

# Pydantic input schemas (one per tool)
__all__ += [
    "TTSInput",
    "TranslatedTTSInput",
    "TranslationInput",
    "TranscriptionInput",
    "VoiceListInput",
    "VoiceCloneInput",
    "TextToSoundInput",
    "AudioSeparationInput",
]
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""CAMB AI toolkit that bundles all CAMB tools."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import List, Optional
|
|
7
|
+
|
|
8
|
+
from langchain_core.tools import BaseTool
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
10
|
+
|
|
11
|
+
from langchain_camb.tools import (
|
|
12
|
+
CambAudioSeparationTool,
|
|
13
|
+
CambTextToSoundTool,
|
|
14
|
+
CambTranscriptionTool,
|
|
15
|
+
CambTranslatedTTSTool,
|
|
16
|
+
CambTranslationTool,
|
|
17
|
+
CambTTSTool,
|
|
18
|
+
CambVoiceCloneTool,
|
|
19
|
+
CambVoiceListTool,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CambToolkit(BaseModel):
    """Toolkit that bundles all CAMB AI tools.

    Provides convenient access to all CAMB AI services:
    - Text-to-Speech (TTS)
    - Translated TTS
    - Translation
    - Transcription
    - Voice Listing
    - Voice Cloning
    - Text-to-Sound
    - Audio Separation

    Every tool created by :meth:`get_tools` shares the toolkit's API key,
    base URL and timeout. The optional ``max_poll_attempts`` /
    ``poll_interval`` overrides are forwarded only when set; otherwise each
    tool keeps its own polling defaults.

    Example:
        ```python
        from langchain_camb import CambToolkit
        from langchain_openai import ChatOpenAI
        from langgraph.prebuilt import create_react_agent

        toolkit = CambToolkit()
        tools = toolkit.get_tools()

        agent = create_react_agent(ChatOpenAI(), tools)
        agent.invoke({
            "messages": [{"role": "user", "content": "Say hello in Spanish"}]
        })
        ```
    """

    api_key: Optional[str] = Field(
        default=None,
        description="CAMB AI API key. Falls back to CAMB_API_KEY environment variable.",
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Optional custom base URL for CAMB AI API.",
    )
    timeout: float = Field(
        default=60.0,
        description="Request timeout in seconds.",
    )
    max_poll_attempts: Optional[int] = Field(
        default=None,
        description=(
            "Maximum number of polling attempts for long-running tasks. "
            "None keeps each tool's own default."
        ),
    )
    poll_interval: Optional[float] = Field(
        default=None,
        description=(
            "Interval between polling attempts in seconds. "
            "None keeps each tool's own default."
        ),
    )
    include_tts: bool = Field(
        default=True,
        description="Include TTS tool.",
    )
    include_translated_tts: bool = Field(
        default=True,
        description="Include Translated TTS tool.",
    )
    include_translation: bool = Field(
        default=True,
        description="Include Translation tool.",
    )
    include_transcription: bool = Field(
        default=True,
        description="Include Transcription tool.",
    )
    include_voice_list: bool = Field(
        default=True,
        description="Include Voice List tool.",
    )
    include_voice_clone: bool = Field(
        default=True,
        description="Include Voice Clone tool.",
    )
    include_text_to_sound: bool = Field(
        default=True,
        description="Include Text-to-Sound tool.",
    )
    include_audio_separation: bool = Field(
        default=True,
        description="Include Audio Separation tool.",
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def _get_api_key(self) -> str:
        """Get API key from field or environment.

        Returns:
            The explicit ``api_key`` if truthy, else ``$CAMB_API_KEY``.

        Raises:
            ValueError: If neither source provides a key.
        """
        key = self.api_key or os.environ.get("CAMB_API_KEY")
        if not key:
            raise ValueError(
                "CAMB AI API key is required. "
                "Set it via 'api_key' parameter or CAMB_API_KEY environment variable."
            )
        return key

    def get_tools(self) -> List[BaseTool]:
        """Get all enabled CAMB AI tools.

        Returns:
            List of LangChain tools configured with the toolkit's settings, in
            this order: TTS, translated TTS, translation, transcription, voice
            list, voice clone, text-to-sound, audio separation (disabled tools
            are skipped).

        Raises:
            ValueError: If no API key is configured.
        """
        common_kwargs: dict = {
            "api_key": self._get_api_key(),
            "base_url": self.base_url,
            "timeout": self.timeout,
        }
        # Forward polling overrides only when explicitly set, so a tool that
        # declares its own defaults is not silently overridden.
        # NOTE(review): assumes every tool subclasses CambBaseTool, which
        # defines these fields — confirm for tools not shown here.
        if self.max_poll_attempts is not None:
            common_kwargs["max_poll_attempts"] = self.max_poll_attempts
        if self.poll_interval is not None:
            common_kwargs["poll_interval"] = self.poll_interval

        # (enabled flag, tool class); tuple order is the order of the result.
        tool_specs = (
            (self.include_tts, CambTTSTool),
            (self.include_translated_tts, CambTranslatedTTSTool),
            (self.include_translation, CambTranslationTool),
            (self.include_transcription, CambTranscriptionTool),
            (self.include_voice_list, CambVoiceListTool),
            (self.include_voice_clone, CambVoiceCloneTool),
            (self.include_text_to_sound, CambTextToSoundTool),
            (self.include_audio_separation, CambAudioSeparationTool),
        )
        tools: List[BaseTool] = [
            tool_cls(**common_kwargs) for enabled, tool_cls in tool_specs if enabled
        ]
        return tools
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""CAMB AI LangChain tools."""
|
|
2
|
+
|
|
3
|
+
from langchain_camb.tools.audio_separation import (
|
|
4
|
+
AudioSeparationInput,
|
|
5
|
+
CambAudioSeparationTool,
|
|
6
|
+
)
|
|
7
|
+
from langchain_camb.tools.base import CambBaseTool
|
|
8
|
+
from langchain_camb.tools.text_to_sound import CambTextToSoundTool, TextToSoundInput
|
|
9
|
+
from langchain_camb.tools.transcription import CambTranscriptionTool, TranscriptionInput
|
|
10
|
+
from langchain_camb.tools.translated_tts import (
|
|
11
|
+
CambTranslatedTTSTool,
|
|
12
|
+
TranslatedTTSInput,
|
|
13
|
+
)
|
|
14
|
+
from langchain_camb.tools.translation import CambTranslationTool, TranslationInput
|
|
15
|
+
from langchain_camb.tools.tts import CambTTSTool, TTSInput
|
|
16
|
+
from langchain_camb.tools.voice_clone import CambVoiceCloneTool, VoiceCloneInput
|
|
17
|
+
from langchain_camb.tools.voice_list import CambVoiceListTool, VoiceListInput
|
|
18
|
+
|
|
19
|
+
# Public API of ``langchain_camb.tools``, assembled group by group using only
# the ``__all__ = [...]`` / ``__all__ += [...]`` forms type checkers understand.
__all__ = ["CambBaseTool"]

# Concrete tools
__all__ += [
    "CambTTSTool",
    "CambTranslatedTTSTool",
    "CambTranslationTool",
    "CambTranscriptionTool",
    "CambVoiceListTool",
    "CambVoiceCloneTool",
    "CambTextToSoundTool",
    "CambAudioSeparationTool",
]

# Pydantic input schemas (one per tool)
__all__ += [
    "TTSInput",
    "TranslatedTTSInput",
    "TranslationInput",
    "TranscriptionInput",
    "VoiceListInput",
    "VoiceCloneInput",
    "TextToSoundInput",
    "AudioSeparationInput",
]
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""Audio separation tool for CAMB AI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import tempfile
|
|
7
|
+
from typing import Any, Optional, Type
|
|
8
|
+
|
|
9
|
+
from langchain_core.callbacks import (
|
|
10
|
+
AsyncCallbackManagerForToolRun,
|
|
11
|
+
CallbackManagerForToolRun,
|
|
12
|
+
)
|
|
13
|
+
from pydantic import BaseModel, Field, model_validator
|
|
14
|
+
|
|
15
|
+
from langchain_camb.tools.base import CambBaseTool
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AudioSeparationInput(BaseModel):
    """Arguments accepted by the CAMB AI audio separation tool.

    Exactly one audio source must be given: either ``audio_url`` or
    ``audio_file_path``. Empty strings count as "not given".
    """

    audio_url: Optional[str] = Field(
        default=None,
        description="URL of the audio file to separate. Provide either audio_url or audio_file_path.",
    )
    audio_file_path: Optional[str] = Field(
        default=None,
        description="Local file path to the audio. Provide either audio_url or audio_file_path.",
    )

    @model_validator(mode="after")
    def validate_audio_source(self) -> "AudioSeparationInput":
        """Reject inputs naming no audio source, or both of them."""
        has_url = bool(self.audio_url)
        has_path = bool(self.audio_file_path)

        # The two conditions are mutually exclusive, so check order is free.
        if has_url and has_path:
            raise ValueError(
                "Provide only one of audio_url or audio_file_path, not both."
            )
        if not (has_url or has_path):
            raise ValueError("Either audio_url or audio_file_path must be provided.")
        return self
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CambAudioSeparationTool(CambBaseTool):
    """Tool for separating vocals from background audio using CAMB AI.

    This tool isolates speech/vocals from background music or noise.
    It returns a JSON string with the location of each separated track. A
    location is a URL or a local temp-file path, as reported by the API.

    The audio is given either as a local file path or as an http(s) URL.
    A URL is downloaded to a temporary file and uploaded the same way as a
    local file.

    Example:
        ```python
        from langchain_camb import CambAudioSeparationTool

        separator = CambAudioSeparationTool()
        result = separator.invoke({
            "audio_file_path": "/path/to/mixed_audio.mp3"
        })
        print(result)  # JSON with paths to vocals and background
        ```
    """

    name: str = "camb_audio_separation"
    description: str = (
        "Separate vocals/speech from background audio using CAMB AI. "
        "Provide an audio URL or file path. "
        "Returns separate files for vocals and background audio."
    )
    args_schema: Type[BaseModel] = AudioSeparationInput

    def _run(
        self,
        audio_url: Optional[str] = None,
        audio_file_path: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Separate audio synchronously.

        Args:
            audio_url: http(s) URL of the audio; used when no file path is given.
            audio_file_path: Local audio path; takes precedence over ``audio_url``.
            run_manager: LangChain callback manager (unused).

        Returns:
            JSON string with paths/URLs of the vocals and background audio.

        Raises:
            ValueError: If no source is given or the URL scheme is unsupported.
        """
        with self._open_media(audio_url, audio_file_path) as media_file:
            result = self.sync_client.audio_separation.create_audio_separation(
                media_file=media_file
            )

        # The create call returns a task id; poll until done to get the run_id.
        status = self._poll_task_status_sync(
            self.sync_client.audio_separation.get_audio_separation_status,
            result.task_id,
        )
        separation_result = (
            self.sync_client.audio_separation.get_audio_separation_run_info(
                status.run_id
            )
        )
        return self._format_result(separation_result)

    async def _arun(
        self,
        audio_url: Optional[str] = None,
        audio_file_path: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Separate audio asynchronously.

        Args:
            audio_url: http(s) URL of the audio; used when no file path is given.
            audio_file_path: Local audio path; takes precedence over ``audio_url``.
            run_manager: LangChain callback manager (unused).

        Returns:
            JSON string with paths/URLs of the vocals and background audio.

        Raises:
            ValueError: If no source is given or the URL scheme is unsupported.
        """
        import asyncio

        # Opening a file / downloading a URL is blocking I/O; run it in the
        # default executor so the event loop is not stalled.
        loop = asyncio.get_running_loop()
        media_file = await loop.run_in_executor(
            None, self._open_media, audio_url, audio_file_path
        )
        with media_file:
            result = await self.async_client.audio_separation.create_audio_separation(
                media_file=media_file
            )

        status = await self._poll_task_status(
            self.async_client.audio_separation.get_audio_separation_status,
            result.task_id,
        )
        separation_result = (
            await self.async_client.audio_separation.get_audio_separation_run_info(
                status.run_id
            )
        )
        return self._format_result(separation_result)

    def _open_media(
        self, audio_url: Optional[str], audio_file_path: Optional[str]
    ) -> Any:
        """Return the audio as a readable binary file object at offset 0.

        A local path takes precedence over a URL. The caller owns the returned
        object and must close it (it is usable as a context manager).

        Raises:
            ValueError: If neither source is given, or the URL is not http(s).
        """
        if audio_file_path:
            return open(audio_file_path, "rb")
        if audio_url:
            return self._download_to_tempfile(audio_url)
        raise ValueError("Either audio_url or audio_file_path must be provided.")

    def _download_to_tempfile(self, url: str) -> Any:
        """Download ``url`` into a named temporary file and return it rewound.

        The temp file keeps the URL's extension when it looks like one (so the
        uploaded filename is meaningful). It is deleted when closed.

        Raises:
            ValueError: If the URL scheme is not http or https.
        """
        import pathlib
        import shutil
        import urllib.parse
        import urllib.request

        parsed = urllib.parse.urlparse(url)
        # urlopen also handles file:// and other schemes; only fetch remote
        # http(s) resources here (local files go through audio_file_path).
        if parsed.scheme.lower() not in ("http", "https"):
            raise ValueError(
                f"Unsupported audio_url scheme {parsed.scheme!r}; "
                "expected an http or https URL."
            )

        suffix = pathlib.PurePosixPath(parsed.path).suffix
        if not (suffix[1:].isalnum() and len(suffix) <= 10):
            suffix = ""  # avoid odd characters in the temp filename

        tmp = tempfile.NamedTemporaryFile(suffix=suffix)
        try:
            with urllib.request.urlopen(url, timeout=self.timeout) as response:
                shutil.copyfileobj(response, tmp)
            tmp.seek(0)
        except BaseException:
            tmp.close()
            raise
        return tmp

    @staticmethod
    def _extract_track(result: Any, names: tuple, suffix: str) -> Any:
        """Return the first non-None attribute of ``result`` among ``names``.

        ``bytes`` payloads are written to a persistent temp file (``delete=False``;
        the caller is responsible for cleanup) and its path is returned. Any
        other value is returned unchanged; None if no candidate is set.
        """
        for attr in names:
            value = getattr(result, attr, None)
            if value is None:
                # Response models often declare optional fields that are None,
                # so mere attribute presence is not enough.
                continue
            if isinstance(value, (bytes, bytearray)):
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
                    f.write(value)
                return f.name
            return value
        return None

    def _format_result(self, result: Any) -> str:
        """Format separation result as JSON.

        Returns:
            JSON object string with keys ``vocals``, ``background`` (URL, temp
            file path, or null) and ``status`` (always ``"completed"``).
        """
        # Candidate names in priority order. The final *_audio_url names are
        # presumably what the CAMB REST API returns — NOTE(review): confirm
        # against the installed SDK's run-info model.
        output = {
            "vocals": self._extract_track(
                result,
                ("voice_url", "vocals_url", "vocals", "foreground_audio_url"),
                "_vocals.wav",
            ),
            "background": self._extract_track(
                result,
                (
                    "instrumental_url",
                    "background_url",
                    "background",
                    "background_audio_url",
                ),
                "_background.wav",
            ),
            "status": "completed",
        }
        return json.dumps(output, indent=2)
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Base class for CAMB AI LangChain tools."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import os
|
|
7
|
+
from abc import ABC
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
from langchain_core.tools import BaseTool
|
|
11
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
12
|
+
|
|
13
|
+
from camb.client import AsyncCambAI, CambAI
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CambBaseTool(BaseTool, ABC):
    """Base class for CAMB AI tools.

    Provides shared client management and configuration for all CAMB AI tools:
    API-key resolution (field or ``CAMB_API_KEY``), lazily created sync/async
    SDK clients, and helpers that poll long-running tasks until they finish.
    """

    api_key: Optional[str] = Field(
        default=None,
        description="CAMB AI API key. Falls back to CAMB_API_KEY environment variable.",
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Optional custom base URL for CAMB AI API.",
    )
    timeout: float = Field(
        default=60.0,
        description="Request timeout in seconds.",
    )
    max_poll_attempts: int = Field(
        default=60,
        description="Maximum number of polling attempts for async tasks.",
    )
    poll_interval: float = Field(
        default=2.0,
        description="Interval between polling attempts in seconds.",
    )

    # Private attributes for lazy client initialization
    _sync_client: Optional[CambAI] = None
    _async_client: Optional[AsyncCambAI] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def validate_api_key(self) -> "CambBaseTool":
        """Fill ``api_key`` from ``CAMB_API_KEY`` if unset; fail if still missing.

        Raises:
            ValueError: If no API key is available from either source.
        """
        if not self.api_key:
            self.api_key = os.environ.get("CAMB_API_KEY")
        if not self.api_key:
            raise ValueError(
                "CAMB AI API key is required. "
                "Set it via 'api_key' parameter or CAMB_API_KEY environment variable."
            )
        return self

    @property
    def sync_client(self) -> CambAI:
        """Get or create synchronous CAMB AI client (created once, then reused)."""
        if self._sync_client is None:
            self._sync_client = CambAI(
                api_key=self.api_key,
                base_url=self.base_url,
                timeout=self.timeout,
            )
        return self._sync_client

    @property
    def async_client(self) -> AsyncCambAI:
        """Get or create asynchronous CAMB AI client (created once, then reused)."""
        if self._async_client is None:
            self._async_client = AsyncCambAI(
                api_key=self.api_key,
                base_url=self.base_url,
                timeout=self.timeout,
            )
        return self._async_client

    def _check_task_status(self, status: Any) -> bool:
        """Classify one status response.

        Returns:
            True if the task finished successfully, False if it is still
            running (or the response carries no ``status`` attribute).

        Raises:
            RuntimeError: If the task reached a terminal failure state.
        """
        if not hasattr(status, "status"):
            return False

        raw = status.status
        # Unwrap enum members and compare case-insensitively, so "SUCCESS",
        # "success" and an enum member with that value are treated alike.
        normalized = str(getattr(raw, "value", raw)).strip().lower()

        if normalized in ("completed", "success"):
            return True
        # NOTE(review): "timeout" / "payment_required" are presumed terminal
        # CAMB task states (per the REST API docs) — confirm against the SDK.
        if normalized in ("failed", "error", "timeout", "payment_required"):
            error_msg = getattr(status, "error", None) or "Unknown error"
            raise RuntimeError(f"Task failed: {error_msg}")
        return False

    async def _poll_task_status(
        self,
        get_status_fn: Any,
        task_id: str,
        *,
        run_id: Optional[int] = None,
    ) -> Any:
        """Poll for async task completion.

        Args:
            get_status_fn: Async function to check task status; called as
                ``get_status_fn(task_id, run_id=run_id)``.
            task_id: The task ID to poll.
            run_id: Optional run ID for the request.

        Returns:
            The final status result when task completes.

        Raises:
            TimeoutError: If polling exceeds max attempts.
            RuntimeError: If task fails.
        """
        for attempt in range(1, self.max_poll_attempts + 1):
            status = await get_status_fn(task_id, run_id=run_id)
            if self._check_task_status(status):
                return status
            if attempt < self.max_poll_attempts:  # no pointless final sleep
                await asyncio.sleep(self.poll_interval)

        raise TimeoutError(
            f"Task {task_id} did not complete within "
            f"{self.max_poll_attempts * self.poll_interval} seconds"
        )

    def _poll_task_status_sync(
        self,
        get_status_fn: Any,
        task_id: str,
        *,
        run_id: Optional[int] = None,
    ) -> Any:
        """Poll for async task completion (synchronous version).

        Args:
            get_status_fn: Sync function to check task status; called as
                ``get_status_fn(task_id, run_id=run_id)``.
            task_id: The task ID to poll.
            run_id: Optional run ID for the request.

        Returns:
            The final status result when task completes.

        Raises:
            TimeoutError: If polling exceeds max attempts.
            RuntimeError: If task fails.
        """
        import time

        for attempt in range(1, self.max_poll_attempts + 1):
            status = get_status_fn(task_id, run_id=run_id)
            if self._check_task_status(status):
                return status
            if attempt < self.max_poll_attempts:  # no pointless final sleep
                time.sleep(self.poll_interval)

        raise TimeoutError(
            f"Task {task_id} did not complete within "
            f"{self.max_poll_attempts * self.poll_interval} seconds"
        )
|