airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,305 @@
1
+ from typing import Optional, Dict, Any, List, Generator, Union
2
+ from pydantic import Field, validator
3
+ from airtrain.core.skills import Skill, ProcessingError
4
+ from airtrain.core.schemas import InputSchema, OutputSchema
5
+ from .credentials import TogetherAICredentials
6
+ from .models import TogetherAIImageInput, TogetherAIImageOutput, GeneratedImage
7
+ from .models_config import get_max_completion_tokens
8
+ from pathlib import Path
9
+ import base64
10
+ import time
11
+ from together import Together
12
+
13
+
14
class TogetherAIInput(InputSchema):
    """Input schema for a single Together AI chat request.

    Carries the user's message, an optional system prompt and prior
    conversation turns, sampling settings, and optional tool definitions.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description=(
            "List of previous conversation messages in "
            "[{'role': 'user|assistant', 'content': 'message'}] format"
        ),
    )
    model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct",
        description="Together AI model to use",
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    stream: bool = Field(
        default=False,
        description="Whether to stream the response",
    )
    tools: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description=(
            "A list of tools the model may use. "
            "Currently only functions supported."
        ),
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        default=None,
        description=(
            "Controls which tool is called by the model. "
            "'none', 'auto', or specific tool."
        ),
    )

    @validator('max_tokens')
    def validate_max_tokens(cls, v, values):
        """Clamp ``max_tokens`` to the selected model's completion limit.

        ``model`` is declared before ``max_tokens``, so it is normally present
        in ``values`` here; if it is absent (e.g. its own validation failed)
        the requested value is returned unchanged.
        """
        model_id = values.get('model')
        if model_id is None:
            return v
        # Never exceed the model's limit; smaller requests pass through as-is.
        return min(v, get_max_completion_tokens(model_id))
71
+
72
+
73
class TogetherAIOutput(OutputSchema):
    """Output schema returned by a Together AI chat request."""

    # Generated text.
    response: str = Field(
        ...,
        description="Model's response text",
    )
    # Model ID the request was issued against.
    used_model: str = Field(
        ...,
        description="Model used for generation",
    )
    # Token usage reported by the API (empty when unavailable).
    usage: Dict[str, Any] = Field(
        default_factory=dict,
        description="Usage statistics",
    )
    # Tool calls requested by the model, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Tool calls generated by the model",
    )
82
+
83
+
84
class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
    """Skill for Together AI chat completions (blocking or streamed)."""

    input_schema = TogetherAIInput
    output_schema = TogetherAIOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """
        Create the skill and its Together client.

        Args:
            credentials: Together AI credentials. When omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: TogetherAIInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: System prompt first, then any history, then the user input
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        # Add conversation history if present
        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        # Add current user input
        messages.append({"role": "user", "content": input_data.user_input})

        return messages

    def process_stream(self, input_data: TogetherAIInput) -> Generator[str, None, None]:
        """
        Stream the response text piece by piece.

        ``tools`` and ``tool_choice`` are not forwarded in streaming mode.

        Args:
            input_data: Chat request parameters.

        Yields:
            str: Each non-None content delta as it arrives.

        Raises:
            ProcessingError: If the request or the stream fails.
        """
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # A chunk may carry no choices (e.g. a metadata-only chunk);
                # indexing [0] on it would raise IndexError.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            raise ProcessingError(f"Together AI streaming failed: {str(e)}") from e

    def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
        """
        Process the input and return the complete response.

        When ``input_data.stream`` is set, the streamed pieces are joined; usage
        statistics and tool calls are not available in that mode.

        Args:
            input_data: Chat request parameters.

        Returns:
            TogetherAIOutput: Response text, model, usage and any tool calls.

        Raises:
            ProcessingError: On API failures (with tool-call-specific hints
                where recognizable) or when the response cannot be processed.
        """
        try:
            if input_data.stream:
                response = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not available in streaming
                tool_calls = None  # Tool calls not available in streaming
            else:
                completion = self._create_completion(input_data)
                message = completion.choices[0].message
                response = message.content or ""
                usage = self._extract_usage(completion)
                tool_calls = self._extract_tool_calls(message)

            return TogetherAIOutput(
                response=response,
                used_model=input_data.model,
                usage=usage,
                tool_calls=tool_calls,
            )

        except ProcessingError:
            # Re-raise ProcessingError without modification
            raise
        except Exception as e:
            raise ProcessingError(f"Together AI processing failed: {str(e)}") from e

    def _create_completion(self, input_data: TogetherAIInput) -> Any:
        """Send a non-streaming completion request, translating API errors."""
        api_params: Dict[str, Any] = {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": False,
        }

        # Add tools and tool_choice if provided
        if input_data.tools:
            api_params["tools"] = input_data.tools
        if input_data.tool_choice:
            api_params["tool_choice"] = input_data.tool_choice

        # Only the API call itself is guarded here, so response-parsing bugs
        # are not misreported as API errors.
        try:
            return self.client.chat.completions.create(**api_params)
        except Exception as api_error:
            raise self._api_error(api_error) from api_error

    @staticmethod
    def _api_error(api_error: Exception) -> ProcessingError:
        """Map an API exception to a ProcessingError with a helpful message."""
        error_message = str(api_error)
        if "tool_call_id" in error_message:
            return ProcessingError(
                "Tool call error: Missing tool_call_id in conversation history. "
                "When adding tool responses to conversation history, "
                "make sure each message with role='tool' includes a 'tool_call_id' field."
            )
        if "messages" in error_message and "tool" in error_message:
            return ProcessingError(
                "Tool call error: Invalid message format in conversation history. "
                f"Original error: {error_message}"
            )
        return ProcessingError(f"Together AI API error: {error_message}")

    @staticmethod
    def _extract_usage(completion: Any) -> Dict[str, Any]:
        """Return usage statistics as a dict, or {} if absent or None."""
        usage = getattr(completion, "usage", None)
        # `usage` may exist but be None; calling model_dump() on it would crash.
        return usage.model_dump() if usage is not None else {}

    @staticmethod
    def _extract_tool_calls(message: Any) -> Optional[List[Dict[str, Any]]]:
        """Convert the message's tool calls to plain dicts (None if there are none)."""
        raw_calls = getattr(message, "tool_calls", None)
        if not raw_calls:
            return None
        return [
            {
                "id": tool_call.id,
                "type": tool_call.type,
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in raw_calls
        ]
225
+
226
+
227
class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
    """Skill for Together AI image generation."""

    input_schema = TogetherAIImageInput
    output_schema = TogetherAIImageOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """
        Create the skill and its Together client.

        Args:
            credentials: Together AI credentials. When omitted they are loaded
                with ``TogetherAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or TogetherAICredentials.from_env()
        self.client = Together(
            api_key=self.credentials.together_api_key.get_secret_value()
        )

    def process(self, input_data: TogetherAIImageInput) -> TogetherAIImageOutput:
        """
        Generate images for the given prompt.

        Args:
            input_data: Prompt, model and generation settings.

        Returns:
            TogetherAIImageOutput: Generated images plus model, prompt, elapsed
            wall time for the API call (seconds) and any usage info.

        Raises:
            ProcessingError: If generation or result conversion fails.
        """
        try:
            # perf_counter is monotonic, so the measured duration is not
            # affected by system clock adjustments (unlike time.time()).
            start_time = time.perf_counter()

            # Generate images
            response = self.client.images.generate(
                prompt=input_data.prompt,
                model=input_data.model,
                steps=input_data.steps,
                n=input_data.n,
                size=input_data.size,
                negative_prompt=input_data.negative_prompt,
                seed=input_data.seed,
            )

            total_time = time.perf_counter() - start_time

            # Convert response to our output format
            generated_images = [
                GeneratedImage(
                    b64_json=img.b64_json,
                    seed=getattr(img, "seed", None),
                    finish_reason=getattr(img, "finish_reason", None),
                )
                for img in response.data
            ]

            return TogetherAIImageOutput(
                images=generated_images,
                model=input_data.model,
                prompt=input_data.prompt,
                total_time=total_time,
                usage=getattr(response, "usage", {}),
            )

        except Exception as e:
            raise ProcessingError(f"Together AI image generation failed: {str(e)}") from e

    def save_images(
        self, output: TogetherAIImageOutput, output_dir: Union[str, Path]
    ) -> List[Path]:
        """
        Save generated images to disk.

        Files are named ``image_<index>.png``; existing files with the same
        name are overwritten. The ``.png`` suffix assumes the API returned PNG
        data (NOTE(review): not verified here).

        Args:
            output (TogetherAIImageOutput): Generation output containing images
            output_dir (str | Path): Directory to save images (created if missing)

        Returns:
            List[Path]: List of paths to saved images, in image order
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        saved_paths = []
        for index, img in enumerate(output.images):
            output_path = output_dir / f"image_{index}.png"
            # Assumes each image carries base64 data in b64_json.
            output_path.write_bytes(base64.b64decode(img.b64_json))
            saved_paths.append(output_path)

        return saved_paths
@@ -0,0 +1,49 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
4
class VisionModelConfig(NamedTuple):
    """Static metadata describing one Together AI vision model."""

    # Organization that publishes the model (e.g. "Meta").
    organization: str
    # Human-readable model name.
    display_name: str
    # Context length as listed for the model (presumably in tokens).
    context_length: int
+
9
+
10
# Known Together AI vision models, keyed by model ID.
# Insertion order matters: list_vision_models_by_organization() returns
# matches in this order.
TOGETHER_VISION_MODELS: Dict[str, VisionModelConfig] = {
    "meta-llama/Llama-Vision-Free": VisionModelConfig(
        organization="Meta",
        display_name="(Free) Llama 3.2 11B Vision Instruct Turbo",
        context_length=131072,
    ),
    "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo": VisionModelConfig(
        organization="Meta",
        display_name="Llama 3.2 11B Vision Instruct Turbo",
        context_length=131072,
    ),
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo": VisionModelConfig(
        organization="Meta",
        display_name="Llama 3.2 90B Vision Instruct Turbo",
        context_length=131072,
    ),
}
27
+
28
+
29
def get_vision_model_config(model_id: str) -> VisionModelConfig:
    """Look up the configuration of a Together AI vision model.

    Raises:
        ValueError: If ``model_id`` is not a known vision model.
    """
    config = TOGETHER_VISION_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in Together AI vision models")
    return config
34
+
35
+
36
def list_vision_models_by_organization(
    organization: str,
) -> Dict[str, VisionModelConfig]:
    """Return every vision model published by ``organization``.

    The organization name is compared case-insensitively; results keep the
    registry's insertion order.
    """
    wanted = organization.lower()
    matches: Dict[str, VisionModelConfig] = {}
    for model_id, config in TOGETHER_VISION_MODELS.items():
        if config.organization.lower() == wanted:
            matches[model_id] = config
    return matches
45
+
46
+
47
def get_default_vision_model() -> str:
    """Return the ID of the default Together AI vision model."""
    default_model_id = "meta-llama/Llama-Vision-Free"
    return default_model_id
@@ -0,0 +1,38 @@
1
+ """
2
+ Airtrain Telemetry
3
+
4
+ This package provides telemetry functionality for Airtrain usage.
5
+ Telemetry is enabled by default to help improve the library and can be disabled by
6
+ setting AIRTRAIN_TELEMETRY_ENABLED=false in your environment variables or .env file.
7
+ """
8
+
9
+ from airtrain.telemetry.service import ProductTelemetry
10
+ from airtrain.telemetry.views import (
11
+ AgentRunTelemetryEvent,
12
+ AgentStepTelemetryEvent,
13
+ AgentEndTelemetryEvent,
14
+ ModelInvocationTelemetryEvent,
15
+ ErrorTelemetryEvent,
16
+ UserFeedbackTelemetryEvent,
17
+ SkillInitTelemetryEvent,
18
+ SkillProcessTelemetryEvent,
19
+ PackageInstallTelemetryEvent,
20
+ PackageImportTelemetryEvent,
21
+ )
22
+
23
# Public API of the telemetry package.
__all__ = [
    "ProductTelemetry",
    "AgentRunTelemetryEvent",
    "AgentStepTelemetryEvent",
    "AgentEndTelemetryEvent",
    "ModelInvocationTelemetryEvent",
    "ErrorTelemetryEvent",
    "UserFeedbackTelemetryEvent",
    "SkillInitTelemetryEvent",
    "SkillProcessTelemetryEvent",
    "PackageInstallTelemetryEvent",
    "PackageImportTelemetryEvent",
]

# Create a singleton instance for easy import.
# ProductTelemetry is wrapped by a singleton decorator, so later
# ProductTelemetry() calls return this same object.
# NOTE(review): this runs at import time; constructing ProductTelemetry can
# create a PostHog client and issue an identify call, so importing this
# package may have network side effects.
telemetry = ProductTelemetry()
@@ -0,0 +1,167 @@
1
+ import logging
2
+ import os
3
+ import platform
4
+ import sys
5
+ import uuid
6
+ from pathlib import Path
7
+
8
+ from dotenv import load_dotenv
9
+ from posthog import Posthog
10
+
11
+ from airtrain.telemetry.views import BaseTelemetryEvent
12
+
13
# Load variables from a .env file (if one is found) into the process
# environment at import time, so settings such as AIRTRAIN_TELEMETRY_ENABLED
# and AIRTRAIN_LOGGING_LEVEL placed in .env are visible to the os.getenv()
# calls made when ProductTelemetry is constructed.
load_dotenv()

# Module-level logger for telemetry diagnostics.
logger = logging.getLogger(__name__)
16
+
17
# Extra key/value pairs merged into the properties of every event sent by
# ProductTelemetry.capture(). They are transmitted as ordinary event
# properties.
# NOTE(review): whether PostHog interprets any of these keys as configuration
# is not established by this module — confirm before relying on them.
POSTHOG_EVENT_SETTINGS = {
    "process_person_profile": True,
    "enable_sent_at": True,
    "capture_performance": True,
    "capture_pageview": True,
}
24
+
25
+
26
def singleton(cls):
    """Singleton decorator for classes.

    Replaces ``cls`` with a factory that builds the instance on the first call
    (using that call's arguments) and returns the same instance on every later
    call; later arguments are ignored.

    The factory keeps the class's ``__name__``, ``__qualname__``, ``__doc__``
    and ``__module__``, and exposes the original class as ``__wrapped__``
    (useful for ``isinstance`` checks, since the decorated name is no longer a
    class).
    """
    import functools

    instances = {}

    # updated=() avoids copying the class __dict__ (methods, descriptors)
    # onto the wrapper function; only the identifying metadata is copied.
    @functools.wraps(cls, updated=())
    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
36
+
37
+
38
@singleton
class ProductTelemetry:
    """
    Service for capturing telemetry data from Airtrain usage.

    Telemetry is enabled by default but can be disabled by setting
    AIRTRAIN_TELEMETRY_ENABLED=false in your environment.

    When enabled, a PostHog client is created and the user is identified with
    a locally persisted random ID plus basic system information. Every captured
    event is enriched with the same system information.
    """

    # Where the randomly generated telemetry user ID is persisted.
    USER_ID_PATH = str(
        Path.home() / '.cache' / 'airtrain' / 'telemetry_user_id'
    )
    # API key for PostHog
    PROJECT_API_KEY = 'phc_1pLNkG3QStYEXIz0CAPQaOGpcmxpE3CJXhE1HANWgIz'
    HOST = 'https://us.i.posthog.com'
    # Fallback ID used when the ID file cannot be read or written.
    UNKNOWN_USER_ID = 'UNKNOWN'

    # Cached user ID, filled lazily by the ``user_id`` property.
    _curr_user_id = None

    def __init__(self) -> None:
        telemetry_disabled = (
            os.getenv('AIRTRAIN_TELEMETRY_ENABLED', 'true').lower() == 'false'
        )
        self.debug_logging = (
            os.getenv('AIRTRAIN_LOGGING_LEVEL', 'info').lower() == 'debug'
        )

        # System information to include with telemetry.
        self.system_info = self._collect_system_info()

        if telemetry_disabled:
            # Honour the documented opt-out. Previously a hard-coded beta flag
            # forced telemetry on regardless of this environment variable.
            self._posthog_client = None
            logger.debug('Telemetry disabled')
            return

        logger.info(
            'You are currently in beta. Telemetry is enabled by default. '
            'To disable, set AIRTRAIN_TELEMETRY_ENABLED=false in your environment.'
        )
        self._posthog_client = Posthog(
            project_api_key=self.PROJECT_API_KEY,
            host=self.HOST,
            disable_geoip=False,  # Collect geographical data
        )

        # Set debug mode if enabled
        if self.debug_logging:
            self._posthog_client.debug = True

        # Identify the user; best-effort, like event capture, so a telemetry
        # failure never breaks the importing application.
        try:
            self._posthog_client.identify(
                self.user_id,
                {
                    **self.system_info,
                    'first_seen': True,
                },
            )
        except Exception as e:
            logger.debug('Failed to identify telemetry user: %s', e)

        # Silence posthog's logging only if debug is off
        if not self.debug_logging:
            logging.getLogger('posthog').disabled = True

    @staticmethod
    def _collect_system_info() -> dict:
        """Return host/runtime details that are attached to every event."""
        # NOTE(review): hostname and username identify the machine/person;
        # consider whether they need to be collected at all.
        try:
            username = os.getlogin()
        except (AttributeError, OSError):
            # os.getlogin() may be missing on some platforms and raises OSError
            # when the process has no controlling terminal (daemons,
            # containers, CI). hasattr() alone does not guard against that.
            username = 'unknown'
        return {
            'os': platform.system(),
            'os_version': platform.version(),
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'platform': platform.platform(),
            'machine': platform.machine(),
            'hostname': platform.node(),
            'username': username,
        }

    def capture(self, event: BaseTelemetryEvent) -> None:
        """Capture a telemetry event and send it to PostHog if telemetry is enabled.

        The event's properties are merged with POSTHOG_EVENT_SETTINGS and the
        collected system information (later keys win on collisions).
        """
        if self._posthog_client is None:
            return

        # Add system information to all events
        enhanced_properties = {
            **event.properties,
            **POSTHOG_EVENT_SETTINGS,
            **self.system_info,
        }

        if self.debug_logging:
            logger.debug('Telemetry event: %s %s', event.name, enhanced_properties)
        self._direct_capture(event, enhanced_properties)

    def _direct_capture(self, event: BaseTelemetryEvent, enhanced_properties: dict) -> None:
        """
        Send the event to PostHog. Should not be thread blocking because posthog handles it.

        Failures are logged and swallowed so telemetry never breaks callers.
        """
        if self._posthog_client is None:
            return

        try:
            self._posthog_client.capture(
                self.user_id,
                event.name,
                enhanced_properties,
            )
        except Exception as e:
            logger.error('Failed to send telemetry event %s: %s', event.name, e)

    @property
    def user_id(self) -> str:
        """
        Get the user ID for telemetry, creating and persisting one if needed.

        The ID is cached after the first lookup. A missing or blank ID file is
        (re)populated with a fresh random ID. Any file-system error falls back
        to ``UNKNOWN_USER_ID``.
        """
        if self._curr_user_id:
            return self._curr_user_id

        # File access may fail due to permissions or other reasons.
        # We don't want to crash so we catch all exceptions.
        try:
            stored_id = ''
            if os.path.exists(self.USER_ID_PATH):
                with open(self.USER_ID_PATH, 'r', encoding='utf-8') as f:
                    # Strip so a trailing newline (e.g. from manual edits) is
                    # not sent as part of the ID.
                    stored_id = f.read().strip()
            if not stored_id:
                os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
                # Use a more identifiable ID prefix
                stored_id = f"airtrain-user-{uuid.uuid4()}"
                with open(self.USER_ID_PATH, 'w', encoding='utf-8') as f:
                    f.write(stored_id)
            self._curr_user_id = stored_id
        except Exception:
            self._curr_user_id = self.UNKNOWN_USER_ID
        return self._curr_user_id