airtrain-0.1.2-py3-none-any.whl → airtrain-0.1.4-py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/together/skills.py (new file)

```diff
@@ -0,0 +1,305 @@
+from typing import Optional, Dict, Any, List, Generator, Union
+from pydantic import Field, validator
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import TogetherAICredentials
+from .models import TogetherAIImageInput, TogetherAIImageOutput, GeneratedImage
+from .models_config import get_max_completion_tokens
+from pathlib import Path
+import base64
+import time
+from together import Together
+
+
+class TogetherAIInput(InputSchema):
+    """Schema for Together AI input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description=(
+            "List of previous conversation messages in "
+            "[{'role': 'user|assistant', 'content': 'message'}] format"
+        ),
+    )
+    model: str = Field(
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        description="Together AI model to use"
+    )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum tokens in response"
+    )
+    temperature: float = Field(
+        default=0.7,
+        description="Temperature for response generation",
+        ge=0,
+        le=1
+    )
+    stream: bool = Field(
+        default=False,
+        description="Whether to stream the response"
+    )
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "A list of tools the model may use. "
+            "Currently only functions supported."
+        ),
+    )
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+        default=None,
+        description=(
+            "Controls which tool is called by the model. "
+            "'none', 'auto', or specific tool."
+        ),
+    )
+
+    @validator('max_tokens')
+    def validate_max_tokens(cls, v, values):
+        """Validate that max_tokens doesn't exceed the model's limit."""
+        if 'model' in values:
+            model_id = values['model']
+            max_limit = get_max_completion_tokens(model_id)
+            if v > max_limit:
+                return max_limit
+        return v
+
+
+class TogetherAIOutput(OutputSchema):
+    """Schema for Together AI output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+    tool_calls: Optional[List[Dict[str, Any]]] = Field(
+        default=None, description="Tool calls generated by the model"
+    )
+
+
+class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
+    """Skill for Together AI chat"""
+
+    input_schema = TogetherAIInput
+    output_schema = TogetherAIOutput
+
+    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
+        super().__init__()
+        self.credentials = credentials or TogetherAICredentials.from_env()
+        self.client = Together(
+            api_key=self.credentials.together_api_key.get_secret_value()
+        )
+
+    def _build_messages(self, input_data: TogetherAIInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Together AI
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
+
+    def process_stream(self, input_data: TogetherAIInput) -> Generator[str, None, None]:
+        """Process the input and stream the response token by token."""
+        try:
+            messages = self._build_messages(input_data)
+
+            stream = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                temperature=input_data.temperature,
+                max_tokens=input_data.max_tokens,
+                stream=True,
+            )
+
+            for chunk in stream:
+                if chunk.choices[0].delta.content is not None:
+                    yield chunk.choices[0].delta.content
+
+        except Exception as e:
+            raise ProcessingError(f"Together AI streaming failed: {str(e)}")
+
+    def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
+        """Process the input and return the complete response."""
+        try:
+            if input_data.stream:
+                response_chunks = []
+                for chunk in self.process_stream(input_data):
+                    response_chunks.append(chunk)
+                response = "".join(response_chunks)
+                usage = {}  # Usage stats not available in streaming
+                tool_calls = None  # Tool calls not available in streaming
+            else:
+                messages = self._build_messages(input_data)
+
+                # Prepare API call parameters
+                api_params = {
+                    "model": input_data.model,
+                    "messages": messages,
+                    "temperature": input_data.temperature,
+                    "max_tokens": input_data.max_tokens,
+                    "stream": False,
+                }
+
+                # Add tools and tool_choice if provided
+                if input_data.tools:
+                    api_params["tools"] = input_data.tools
+
+                if input_data.tool_choice:
+                    api_params["tool_choice"] = input_data.tool_choice
+
+                try:
+                    completion = self.client.chat.completions.create(**api_params)
+                    response = completion.choices[0].message.content or ""
+
+                    # Extract usage information
+                    usage = (
+                        completion.usage.model_dump()
+                        if hasattr(completion, "usage")
+                        else {}
+                    )
+
+                    # Check for tool calls in the response
+                    tool_calls = None
+                    if (
+                        hasattr(completion.choices[0].message, "tool_calls")
+                        and completion.choices[0].message.tool_calls
+                    ):
+                        tool_calls = [
+                            {
+                                "id": tool_call.id,
+                                "type": tool_call.type,
+                                "function": {
+                                    "name": tool_call.function.name,
+                                    "arguments": tool_call.function.arguments
+                                }
+                            }
+                            for tool_call in completion.choices[0].message.tool_calls
+                        ]
+                except Exception as api_error:
+                    # Provide more specific error messages for common tool call issues
+                    error_message = str(api_error)
+                    if "tool_call_id" in error_message:
+                        raise ProcessingError(
+                            "Tool call error: Missing tool_call_id in conversation history. "
+                            "When adding tool responses to conversation history, "
+                            "make sure each message with role='tool' includes a 'tool_call_id' field."
+                        )
+                    elif "messages" in error_message and "tool" in error_message:
+                        raise ProcessingError(
+                            "Tool call error: Invalid message format in conversation history. "
+                            f"Original error: {error_message}"
+                        )
+                    else:
+                        # Re-raise with original error message
+                        raise ProcessingError(f"Together AI API error: {error_message}")
+
+            return TogetherAIOutput(
+                response=response,
+                used_model=input_data.model,
+                usage=usage,
+                tool_calls=tool_calls
+            )
+
+        except ProcessingError:
+            # Re-raise ProcessingError without modification
+            raise
+        except Exception as e:
+            raise ProcessingError(f"Together AI processing failed: {str(e)}")
+
+
+class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
+    """Skill for Together AI image generation"""
+
+    input_schema = TogetherAIImageInput
+    output_schema = TogetherAIImageOutput
+
+    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
+        super().__init__()
+        self.credentials = credentials or TogetherAICredentials.from_env()
+        self.client = Together(
+            api_key=self.credentials.together_api_key.get_secret_value()
+        )
+
+    def process(self, input_data: TogetherAIImageInput) -> TogetherAIImageOutput:
+        try:
+            start_time = time.time()
+
+            # Generate images
+            response = self.client.images.generate(
+                prompt=input_data.prompt,
+                model=input_data.model,
+                steps=input_data.steps,
+                n=input_data.n,
+                size=input_data.size,
+                negative_prompt=input_data.negative_prompt,
+                seed=input_data.seed,
+            )
+
+            # Calculate total time
+            total_time = time.time() - start_time
+
+            # Convert response to our output format
+            generated_images = [
+                GeneratedImage(
+                    b64_json=img.b64_json,
+                    seed=getattr(img, "seed", None),
+                    finish_reason=getattr(img, "finish_reason", None),
+                )
+                for img in response.data
+            ]
+
+            return TogetherAIImageOutput(
+                images=generated_images,
+                model=input_data.model,
+                prompt=input_data.prompt,
+                total_time=total_time,
+                usage=getattr(response, "usage", {}),
+            )
+
+        except Exception as e:
+            raise ProcessingError(f"Together AI image generation failed: {str(e)}")
+
+    def save_images(
+        self, output: TogetherAIImageOutput, output_dir: Path
+    ) -> List[Path]:
+        """
+        Save generated images to disk
+
+        Args:
+            output (TogetherAIImageOutput): Generation output containing images
+            output_dir (Path): Directory to save images
+
+        Returns:
+            List[Path]: List of paths to saved images
+        """
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        saved_paths = []
+        for i, img in enumerate(output.images):
+            output_path = output_dir / f"image_{i}.png"
+            image_data = base64.b64decode(img.b64_json)
+
+            with open(output_path, "wb") as f:
+                f.write(image_data)
+
+            saved_paths.append(output_path)
+
+        return saved_paths
```
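For orientation, here is a minimal usage sketch for the chat skill added above. It is not part of the package; it assumes `TogetherAICredentials.from_env()` can resolve a Together API key from the environment (the exact variable name is defined in credentials.py, which is listed in the file list but not shown in this diff), and that the default model is available to that key.

```python
# Usage sketch (not from the package). Assumes a Together API key is set in
# the environment; the exact variable name lives in the unshown credentials.py.
from airtrain.integrations.together.skills import (
    TogetherAIChatSkill,
    TogetherAIInput,
)

skill = TogetherAIChatSkill()

# Non-streaming: returns a TogetherAIOutput with response, usage, and tool_calls.
result = skill.process(
    TogetherAIInput(
        user_input="In one sentence, what does a system prompt do?",
        temperature=0.2,
    )
)
print(result.used_model)
print(result.response)

# Streaming: process_stream yields content deltas as plain strings.
for token in skill.process_stream(TogetherAIInput(user_input="Count to five.")):
    print(token, end="", flush=True)
```

Note from the hunk itself: when `stream=True` is set on the input, `process()` just joins the streamed chunks, so `usage` comes back empty and `tool_calls` is always `None`.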
airtrain/integrations/together/vision_models_config.py (new file)

```diff
@@ -0,0 +1,49 @@
+from typing import Dict, NamedTuple
+
+
+class VisionModelConfig(NamedTuple):
+    organization: str
+    display_name: str
+    context_length: int
+
+
+TOGETHER_VISION_MODELS: Dict[str, VisionModelConfig] = {
+    "meta-llama/Llama-Vision-Free": VisionModelConfig(
+        organization="Meta",
+        display_name="(Free) Llama 3.2 11B Vision Instruct Turbo",
+        context_length=131072,
+    ),
+    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo": VisionModelConfig(
+        organization="Meta",
+        display_name="Llama 3.2 11B Vision Instruct Turbo",
+        context_length=131072,
+    ),
+    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo": VisionModelConfig(
+        organization="Meta",
+        display_name="Llama 3.2 90B Vision Instruct Turbo",
+        context_length=131072,
+    ),
+}
+
+
+def get_vision_model_config(model_id: str) -> VisionModelConfig:
+    """Get vision model configuration by model ID"""
+    if model_id not in TOGETHER_VISION_MODELS:
+        raise ValueError(f"Model {model_id} not found in Together AI vision models")
+    return TOGETHER_VISION_MODELS[model_id]
+
+
+def list_vision_models_by_organization(
+    organization: str,
+) -> Dict[str, VisionModelConfig]:
+    """Get all vision models for a specific organization"""
+    return {
+        model_id: config
+        for model_id, config in TOGETHER_VISION_MODELS.items()
+        if config.organization.lower() == organization.lower()
+    }
+
+
+def get_default_vision_model() -> str:
+    """Get the default vision model ID"""
+    return "meta-llama/Llama-Vision-Free"
```
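These helpers are plain lookups over a static registry; a short sketch of how they compose, using only names defined in the hunk above:

```python
# Sketch using only the helpers defined in vision_models_config.py above.
from airtrain.integrations.together.vision_models_config import (
    get_default_vision_model,
    get_vision_model_config,
    list_vision_models_by_organization,
)

model_id = get_default_vision_model()       # "meta-llama/Llama-Vision-Free"
config = get_vision_model_config(model_id)  # raises ValueError for unknown IDs
print(config.display_name, config.context_length)

# Organization matching is case-insensitive, so "meta" works as well as "Meta".
for mid in list_vision_models_by_organization("meta"):
    print(mid)
```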
airtrain/telemetry/__init__.py (new file)

```diff
@@ -0,0 +1,38 @@
+"""
+Airtrain Telemetry
+
+This package provides telemetry functionality for Airtrain usage.
+Telemetry is enabled by default to help improve the library and can be disabled by
+setting AIRTRAIN_TELEMETRY_ENABLED=false in your environment variables or .env file.
+"""
+
+from airtrain.telemetry.service import ProductTelemetry
+from airtrain.telemetry.views import (
+    AgentRunTelemetryEvent,
+    AgentStepTelemetryEvent,
+    AgentEndTelemetryEvent,
+    ModelInvocationTelemetryEvent,
+    ErrorTelemetryEvent,
+    UserFeedbackTelemetryEvent,
+    SkillInitTelemetryEvent,
+    SkillProcessTelemetryEvent,
+    PackageInstallTelemetryEvent,
+    PackageImportTelemetryEvent,
+)
+
+__all__ = [
+    "ProductTelemetry",
+    "AgentRunTelemetryEvent",
+    "AgentStepTelemetryEvent",
+    "AgentEndTelemetryEvent",
+    "ModelInvocationTelemetryEvent",
+    "ErrorTelemetryEvent",
+    "UserFeedbackTelemetryEvent",
+    "SkillInitTelemetryEvent",
+    "SkillProcessTelemetryEvent",
+    "PackageInstallTelemetryEvent",
+    "PackageImportTelemetryEvent",
+]
+
+# Create a singleton instance for easy import
+telemetry = ProductTelemetry()
```
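The module exports a ready-made singleton, so callers capture events through `airtrain.telemetry.telemetry` rather than constructing `ProductTelemetry` themselves. The event classes come from views.py (+237 lines, not shown in this diff), so their constructor fields are not visible here; the keyword argument in the sketch below is hypothetical.

```python
# Sketch: using the module-level singleton. The constructor fields of the
# event classes are defined in views.py (not shown in this diff), so the
# keyword argument below is hypothetical.
from airtrain.telemetry import telemetry, PackageImportTelemetryEvent

event = PackageImportTelemetryEvent(version="0.1.4")  # hypothetical field
telemetry.capture(event)  # no-op if the PostHog client was never created
```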
airtrain/telemetry/service.py (new file)

```diff
@@ -0,0 +1,167 @@
+import logging
+import os
+import platform
+import sys
+import uuid
+from pathlib import Path
+
+from dotenv import load_dotenv
+from posthog import Posthog
+
+from airtrain.telemetry.views import BaseTelemetryEvent
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+# Enhanced event settings to collect more data
+POSTHOG_EVENT_SETTINGS = {
+    'process_person_profile': True,
+    'enable_sent_at': True,  # Add timing information
+    'capture_performance': True,  # Collect performance data
+    'capture_pageview': True,  # More detailed usage tracking
+}
+
+
+def singleton(cls):
+    """Singleton decorator for classes."""
+    instances = {}
+
+    def get_instance(*args, **kwargs):
+        if cls not in instances:
+            instances[cls] = cls(*args, **kwargs)
+        return instances[cls]
+
+    return get_instance
+
+
+@singleton
+class ProductTelemetry:
+    """
+    Service for capturing telemetry data from Airtrain usage.
+
+    Telemetry is enabled by default but can be disabled by setting
+    AIRTRAIN_TELEMETRY_ENABLED=false in your environment.
+    """
+
+    USER_ID_PATH = str(
+        Path.home() / '.cache' / 'airtrain' / 'telemetry_user_id'
+    )
+    # API key for PostHog
+    PROJECT_API_KEY = 'phc_1pLNkG3QStYEXIz0CAPQaOGpcmxpE3CJXhE1HANWgIz'
+    HOST = 'https://us.i.posthog.com'
+    UNKNOWN_USER_ID = 'UNKNOWN'
+
+    _curr_user_id = None
+
+    def __init__(self) -> None:
+        telemetry_disabled = os.getenv('AIRTRAIN_TELEMETRY_ENABLED', 'true').lower() == 'false'
+        self.debug_logging = os.getenv('AIRTRAIN_LOGGING_LEVEL', 'info').lower() == 'debug'
+
+        # System information to include with telemetry
+        self.system_info = {
+            'os': platform.system(),
+            'os_version': platform.version(),
+            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
+            'platform': platform.platform(),
+            'machine': platform.machine(),
+            'hostname': platform.node(),
+            'username': os.getlogin() if hasattr(os, 'getlogin') else 'unknown'
+        }
+        isBeta = True  # TODO: remove this once out of beta
+        if telemetry_disabled and not isBeta:
+            self._posthog_client = None
+        else:
+            if not isBeta:
+                logging.info(
+                    'Telemetry enabled. To disable, set '
+                    'AIRTRAIN_TELEMETRY_ENABLED=false in your environment.'
+                )
+            if isBeta:
+                logging.info(
+                    'You are currently in beta. Telemetry is enabled by default.'
+                )
+            self._posthog_client = Posthog(
+                project_api_key=self.PROJECT_API_KEY,
+                host=self.HOST,
+                disable_geoip=False  # Collect geographical data
+            )
+
+            # Set debug mode if enabled
+            if self.debug_logging:
+                self._posthog_client.debug = True
+
+            # Identify user more specifically
+            self._posthog_client.identify(
+                self.user_id,
+                {
+                    **self.system_info,
+                    'first_seen': True
+                }
+            )
+
+            # Silence posthog's logging only if debug is off
+            if not self.debug_logging:
+                posthog_logger = logging.getLogger('posthog')
+                posthog_logger.disabled = True
+
+        if self._posthog_client is None:
+            logger.debug('Telemetry disabled')
+
+    def capture(self, event: BaseTelemetryEvent) -> None:
+        """Capture a telemetry event and send it to PostHog if telemetry is enabled."""
+        if self._posthog_client is None:
+            return
+
+        # Add system information to all events
+        enhanced_properties = {
+            **event.properties,
+            **POSTHOG_EVENT_SETTINGS,
+            **self.system_info
+        }
+
+        if self.debug_logging:
+            logger.debug(f'Telemetry event: {event.name} {enhanced_properties}')
+        self._direct_capture(event, enhanced_properties)
+
+    def _direct_capture(self, event: BaseTelemetryEvent, enhanced_properties: dict) -> None:
+        """
+        Send the event to PostHog. Should not be thread blocking because posthog handles it.
+        """
+        if self._posthog_client is None:
+            return
+
+        try:
+            self._posthog_client.capture(
+                self.user_id,
+                event.name,
+                enhanced_properties
+            )
+        except Exception as e:
+            logger.error(f'Failed to send telemetry event {event.name}: {e}')
+
+    @property
+    def user_id(self) -> str:
+        """
+        Get the user ID for telemetry.
+        Creates a new one if it doesn't exist.
+        """
+        if self._curr_user_id:
+            return self._curr_user_id
+
+        # File access may fail due to permissions or other reasons.
+        # We don't want to crash so we catch all exceptions.
+        try:
+            if not os.path.exists(self.USER_ID_PATH):
+                os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
+                with open(self.USER_ID_PATH, 'w') as f:
+                    # Use a more identifiable ID prefix
+                    new_user_id = f"airtrain-user-{uuid.uuid4()}"
+                    f.write(new_user_id)
+                self._curr_user_id = new_user_id
+            else:
+                with open(self.USER_ID_PATH, 'r') as f:
+                    self._curr_user_id = f.read()
+        except Exception:
+            self._curr_user_id = self.UNKNOWN_USER_ID
+        return self._curr_user_id
```
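One behavior worth calling out from the hunk above: because `isBeta` is hard-coded to `True`, the `telemetry_disabled and not isBeta` branch can never be taken, so setting `AIRTRAIN_TELEMETRY_ENABLED=false` does not currently prevent the PostHog client from being created. A sketch of the intended opt-out, which will only take effect once that beta flag is removed:

```python
# Sketch: intended opt-out and local debugging. As shipped, isBeta=True means
# the PostHog client is created regardless of AIRTRAIN_TELEMETRY_ENABLED.
import os

# Must be set before airtrain.telemetry is first imported.
os.environ["AIRTRAIN_TELEMETRY_ENABLED"] = "false"  # intended opt-out
os.environ["AIRTRAIN_LOGGING_LEVEL"] = "debug"      # log each event locally

from airtrain.telemetry import telemetry

# The anonymous ID is cached at ~/.cache/airtrain/telemetry_user_id and
# falls back to "UNKNOWN" if the file cannot be read or written.
print(telemetry.user_id)
```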