airtrain 0.1.51__py3-none-any.whl → 0.1.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,465 @@
1
+ """
2
+ Agent Registry System for AirTrain
3
+
4
+ This module provides a registry system for agents that can be used to build AI systems.
5
+ It includes:
6
+ - Base agent class
7
+ - Registration decorators
8
+ - Factory methods for agent creation
9
+ - Discovery utilities for finding available agents
10
+ """
11
+
12
+ from abc import ABC, abstractmethod
13
+ import time
14
+ import uuid
15
+ from typing import List, Optional, Type, TypeVar, Union
16
+ import inspect
17
+
18
+ # Import tool registry components
19
+ from airtrain.tools import ToolFactory, BaseTool
20
+ from airtrain.telemetry import (
21
+ telemetry,
22
+ AgentRunTelemetryEvent,
23
+ AgentStepTelemetryEvent,
24
+ AgentEndTelemetryEvent,
25
+ ModelInvocationTelemetryEvent,
26
+ ErrorTelemetryEvent
27
+ )
28
+ from .memory import AgentMemoryManager
29
+
30
+
31
+ # Type variable for agent classes
32
+ A = TypeVar('A', bound='BaseAgent')
33
+
34
+ # Registry structure for agent classes
35
+ AGENT_REGISTRY = {}
36
+
37
+
38
class BaseAgent(ABC):
    """Abstract base class for all agents.

    An agent owns a list of model identifiers, a list of tools and an
    ``AgentMemoryManager``. It also keeps per-run counters (steps, token
    usage, error messages) that are reported through the telemetry events
    emitted by ``start_run``, ``record_step``, ``record_model_usage``,
    ``record_error`` and ``end_run``. Subclasses must implement ``process``.
    """

    # Set on the class by the registration decorator (AgentRegistry.register).
    agent_name: Optional[str] = None

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None
    ):
        """
        Initialize an agent.

        Args:
            name: Name of the agent instance
            models: List of model identifiers to use
            tools: List of tools the agent can use
        """
        self.name = name
        # Copy the incoming lists: add_tool()/register_tools() mutate
        # self.tools in place and must never modify a list the caller owns.
        self.models = list(models) if models else []
        self.tools = list(tools) if tools else []
        self.memory = AgentMemoryManager()

        # Generate unique agent ID for telemetry
        self.agent_id = f"{name}-{uuid.uuid4()}"
        # Wall-clock timestamp of the current run; None until start_run().
        self.start_time: Optional[float] = None
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors: List[str] = []

        # Initialize default short-term memory
        self.memory.create_short_term_memory("default")

    def add_tool(self, tool: BaseTool) -> "BaseAgent":
        """
        Add a tool to the agent.

        Args:
            tool: Tool instance to add

        Returns:
            Self for method chaining
        """
        self.tools.append(tool)
        return self

    def register_tools(self, tools: List[BaseTool]) -> "BaseAgent":
        """
        Register multiple tools with the agent.

        Args:
            tools: List of tool instances to add

        Returns:
            Self for method chaining
        """
        self.tools.extend(tools)
        return self

    def create_memory(self, name: str, max_messages: int = 10):
        """
        Create a new short-term memory.

        Args:
            name: Name for the memory
            max_messages: Maximum messages before summarization

        Returns:
            The created memory instance
        """
        return self.memory.create_short_term_memory(name, max_messages)

    def reset_memory(self, name: str = "default") -> "BaseAgent":
        """
        Reset a specific short-term memory.

        Args:
            name: Name of the memory to reset

        Returns:
            Self for method chaining
        """
        self.memory.reset_short_term_memory(name)
        return self

    def start_run(
        self, task: str, model_name: str, model_provider: str
    ) -> "BaseAgent":
        """
        Start an agent run and send telemetry.

        Resets the step, token and error counters before emitting the event.

        Args:
            task: Description of the task
            model_name: Name of the model being used
            model_provider: Provider of the model

        Returns:
            Self for method chaining
        """
        self.start_time = time.time()
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        # Rebind (rather than clear) so a list handed to an earlier
        # telemetry event is left untouched.
        self.errors = []

        # Send run event
        event = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=task,
            model_name=model_name,
            model_provider=model_provider,
            version=self._get_package_version(),
            source=self.__class__.__name__
        )
        telemetry.capture(event)
        return self

    def record_step(
        self, actions: List[dict], step_error: Optional[List[str]] = None
    ) -> "BaseAgent":
        """
        Record an agent step and send telemetry.

        Args:
            actions: List of actions taken in this step
            step_error: Optional list of errors encountered

        Returns:
            Self for method chaining
        """
        self.step_count += 1
        # NOTE(review): this is the number of errors recorded through
        # record_error() since the last start_run(), not a count of
        # consecutive failing steps; step_error is forwarded to telemetry
        # but is not added to self.errors. Confirm that this is intended.
        consecutive_failures = len(self.errors)

        # Send step event
        event = AgentStepTelemetryEvent(
            agent_id=self.agent_id,
            step=self.step_count,
            step_error=step_error or [],
            consecutive_failures=consecutive_failures,
            actions=actions
        )
        telemetry.capture(event)
        return self

    def record_model_usage(
        self,
        model_name: str,
        model_provider: str,
        tokens: int,
        prompt_tokens: int,
        completion_tokens: int,
        duration_seconds: float,
        error: Optional[str] = None
    ) -> "BaseAgent":
        """
        Record model usage and send telemetry.

        The token counts are added to the agent's running totals for the
        current run.

        Args:
            model_name: Name of the model used
            model_provider: Provider of the model
            tokens: Total tokens used
            prompt_tokens: Tokens used in the prompt
            completion_tokens: Tokens used in the completion
            duration_seconds: Duration of the model call
            error: Optional error message

        Returns:
            Self for method chaining
        """
        self.total_tokens += tokens
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        # Send model usage event
        event = ModelInvocationTelemetryEvent(
            agent_id=self.agent_id,
            model_name=model_name,
            model_provider=model_provider,
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            duration_seconds=duration_seconds,
            error=error
        )
        telemetry.capture(event)
        return self

    def end_run(
        self, is_done: bool, success: Optional[bool] = None
    ) -> "BaseAgent":
        """
        End an agent run and send telemetry.

        Args:
            is_done: Whether the agent completed its task
            success: Whether the agent was successful

        Returns:
            Self for method chaining
        """
        # start_time stays None when end_run() is called without start_run();
        # report a zero duration in that case.
        if self.start_time is not None:
            duration = time.time() - self.start_time
        else:
            duration = 0

        # Send end event
        event = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            steps=self.step_count,
            is_done=is_done,
            success=success,
            total_tokens=self.total_tokens,
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_duration_seconds=duration,
            # Snapshot: later record_error() calls must not alter this event.
            errors=list(self.errors)
        )
        telemetry.capture(event)
        return self

    def record_error(
        self, error_type: str, error_message: str, component: str
    ) -> "BaseAgent":
        """
        Record an error and send telemetry.

        The message is also appended to ``self.errors`` so it is included in
        the end-of-run event.

        Args:
            error_type: Type of the error
            error_message: Error message
            component: Component where the error occurred

        Returns:
            Self for method chaining
        """
        self.errors.append(error_message)

        # Send error event
        event = ErrorTelemetryEvent(
            error_type=error_type,
            error_message=error_message,
            component=component,
            agent_id=self.agent_id
        )
        telemetry.capture(event)
        return self

    def _get_package_version(self) -> str:
        """Return ``airtrain.__version__``, or "unknown" if it cannot be imported."""
        try:
            # Imported lazily, at call time rather than at module import.
            from airtrain import __version__
            return __version__
        except ImportError:
            return "unknown"

    @abstractmethod
    def process(self, user_input: str, memory_name: str = "default") -> str:
        """
        Process user input using a specific memory context.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use for context

        Returns:
            Agent's response
        """
        pass
299
+
300
+
301
class AgentRegistry:
    """Registry for agent classes, backed by the module-level AGENT_REGISTRY dict."""

    @classmethod
    def register(cls, name: Optional[str] = None):
        """
        Decorator to register an agent class.

        Args:
            name: Optional name for the agent class; defaults to the
                class's ``__name__``

        Returns:
            Decorator function
        """
        def decorator(agent_class: Type[BaseAgent]) -> Type[BaseAgent]:
            """
            Register an agent class with the registry.

            Args:
                agent_class: Agent class to register

            Returns:
                The registered agent class

            Raises:
                TypeError: If agent_class is not a BaseAgent subclass or
                    does not provide a concrete process method
                ValueError: If the name is already registered
            """
            # Validate agent class. The isclass() guard keeps the error
            # message meaningful when something other than a class is passed.
            if not (
                inspect.isclass(agent_class)
                and issubclass(agent_class, BaseAgent)
            ):
                class_label = getattr(agent_class, '__name__', repr(agent_class))
                raise TypeError(
                    f"Agent class {class_label} must inherit from BaseAgent"
                )

            # Check for required methods. Every BaseAgent subclass inherits
            # the abstract process(), so merely finding a function named
            # 'process' proves nothing; the method must also no longer be
            # marked abstract.
            process_impl = getattr(agent_class, 'process', None)
            is_abstract = getattr(process_impl, '__isabstractmethod__', False)
            if not callable(process_impl) or is_abstract:
                raise TypeError(
                    f"Agent class {agent_class.__name__} must implement process method"
                )

            # Determine agent name
            agent_name = name or agent_class.__name__

            # Check for name conflict
            if agent_name in AGENT_REGISTRY:
                raise ValueError(f"Agent '{agent_name}' already registered")

            # Register the agent class
            AGENT_REGISTRY[agent_name] = agent_class

            # Add metadata to the class
            agent_class.agent_name = agent_name

            return agent_class

        return decorator

    @classmethod
    def get_agent_class(cls, name: str) -> Type[BaseAgent]:
        """
        Get agent class by name.

        Args:
            name: Name of the agent class

        Returns:
            The agent class

        Raises:
            ValueError: If agent not found
        """
        try:
            return AGENT_REGISTRY[name]
        except KeyError:
            # The KeyError carries no extra information; hide it from the
            # traceback so callers only see the documented ValueError.
            raise ValueError(f"Agent '{name}' not found in registry") from None

    @classmethod
    def list_agents(cls) -> List[str]:
        """
        List all registered agents.

        Returns:
            List of agent names
        """
        return list(AGENT_REGISTRY)
383
+
384
+
385
class AgentFactory:
    """Factory for creating agent instances."""

    @staticmethod
    def _resolve_tool(tool_name: str) -> BaseTool:
        """Look up a tool by name: default lookup first, then "stateful"."""
        # Try stateless first, then stateful
        try:
            return ToolFactory.get_tool(tool_name)
        except ValueError:
            pass

        try:
            return ToolFactory.get_tool(tool_name, "stateful")
        except ValueError as err:
            # Chain explicitly so the underlying lookup failure stays
            # visible as the cause of the error the caller sees.
            raise ValueError(f"Tool '{tool_name}' not found in registry") from err

    @staticmethod
    def create_agent(
        agent_type: str,
        name: Optional[str] = None,
        models: Optional[List[str]] = None,
        tools: Optional[List[Union[str, BaseTool]]] = None,
        **kwargs
    ) -> BaseAgent:
        """
        Create an agent instance.

        Args:
            agent_type: Type of agent to create
            name: Name for the agent instance
            models: List of model identifiers
            tools: List of tools or tool names
            **kwargs: Additional arguments for the agent constructor

        Returns:
            Agent instance

        Raises:
            ValueError: If agent_type is not registered or a tool name
                cannot be resolved
        """
        # Get agent class
        agent_class = AgentRegistry.get_agent_class(agent_type)

        # Prepare name
        # NOTE(review): id(agent_class) is the same for every instance of a
        # given agent type, so default names are not unique per instance.
        instance_name = name or f"{agent_type}_{id(agent_class)}"

        # Prepare tools: strings are treated as tool names and resolved
        # through the tool registry; anything else is assumed to already be
        # a tool instance.
        tool_instances = [
            AgentFactory._resolve_tool(tool) if isinstance(tool, str) else tool
            for tool in (tools or [])
        ]

        # Create agent instance
        return agent_class(
            name=instance_name,
            models=models,
            tools=tool_instances,
            **kwargs
        )

    @staticmethod
    def list_available_agents() -> List[str]:
        """
        List all available agent types.

        Returns:
            List of agent type names
        """
        return AgentRegistry.list_agents()
452
+
453
+
454
+ # Convenience decorator for registering agents
455
def register_agent(name: Optional[str] = None):
    """
    Module-level shortcut for ``AgentRegistry.register``.

    Args:
        name: Optional name to register the agent class under

    Returns:
        The decorator produced by ``AgentRegistry.register(name)``
    """
    decorator = AgentRegistry.register(name)
    return decorator
airtrain/core/skills.py CHANGED
@@ -1,8 +1,17 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from typing import Any, Dict, Optional, Type, Generic, TypeVar
3
3
  from uuid import UUID, uuid4
4
+ import time
5
+ import functools
4
6
  from .schemas import InputSchema, OutputSchema
5
7
 
8
+ # Import telemetry
9
+ from airtrain.telemetry import (
10
+ telemetry,
11
+ SkillInitTelemetryEvent,
12
+ SkillProcessTelemetryEvent,
13
+ )
14
+
6
15
  # Generic type variables for input and output schemas
7
16
  InputT = TypeVar("InputT", bound=InputSchema)
8
17
  OutputT = TypeVar("OutputT", bound=OutputSchema)
@@ -17,6 +26,92 @@ class Skill(ABC, Generic[InputT, OutputT]):
17
26
  input_schema: Type[InputT]
18
27
  output_schema: Type[OutputT]
19
28
  _skill_id: Optional[UUID] = None
29
+ _original_process = None
30
+
31
def __init__(self):
    """Initialize the skill and capture telemetry.

    Assigns a skill id if none is set, makes sure the class's ``process``
    method is wrapped with a telemetry-capturing wrapper, and emits a
    ``SkillInitTelemetryEvent``. Telemetry failures are swallowed and never
    propagate to the caller.
    """
    # Initialize skill_id if not already set
    if not self._skill_id:
        self._skill_id = uuid4()

    def _serialize_input(input_data):
        """Convert input_data to a dict for telemetry; never raises."""
        try:
            # Convert input_data to dict if it's a Pydantic model
            if hasattr(input_data, "dict"):
                return input_data.dict()
            # If it's a dataclass
            if hasattr(input_data, "__dataclass_fields__"):
                from dataclasses import asdict
                return asdict(input_data)
            # Fallback
            return {"__str__": str(input_data)}
        except Exception:
            # If serialization fails, provide simple info
            return {"error": "Failed to serialize input data"}

    # Create a wrapper function that will capture telemetry
    def _create_wrapper(original_method):
        @functools.wraps(original_method)
        def wrapped_process(instance, input_data):
            # perf_counter() is monotonic, so the measured duration cannot
            # be skewed by system clock adjustments.
            started = time.perf_counter()
            error = None

            try:
                # Call the original process method
                return original_method(instance, input_data)
            except Exception as e:
                error = str(e)
                raise
            finally:
                duration = time.perf_counter() - started

                try:
                    telemetry.capture(
                        SkillProcessTelemetryEvent(
                            skill_id=str(instance.skill_id),
                            skill_class=instance.__class__.__name__,
                            input_schema=instance.input_schema.__name__,
                            output_schema=instance.output_schema.__name__,
                            input_data=_serialize_input(input_data),
                            duration_seconds=duration,
                            error=error,
                        )
                    )
                except Exception:
                    # Silently continue if telemetry fails
                    pass

        # Marker used to recognise an already-wrapped method. Set after
        # functools.wraps so it is not overwritten by the copied __dict__.
        wrapped_process._telemetry_wrapped = True
        return wrapped_process

    skill_cls = self.__class__
    current_process = skill_cls.process

    # Monkey patch the process method if it hasn't been patched yet.
    # The decision is based on the method that would actually run, not on a
    # class-level flag: a flag set on a parent class is inherited by every
    # subclass, so a subclass overriding process() after its parent had been
    # patched would otherwise never be wrapped. A subclass that merely
    # inherits an already-wrapped process() is left alone, which also
    # prevents double-wrapping.
    if not getattr(current_process, "_telemetry_wrapped", False):
        # Store the original process method implementation
        skill_cls._original_process = current_process

        # Replace the process method with our wrapped version at the class level
        skill_cls.process = _create_wrapper(current_process)

        # Kept for backward compatibility with code that reads this flag
        skill_cls._patched_process = True

    # Capture telemetry for initialization
    try:
        telemetry.capture(
            SkillInitTelemetryEvent(
                skill_id=str(self.skill_id),
                skill_class=self.__class__.__name__,
            )
        )
    except Exception:
        # Silently continue if telemetry fails
        pass
20
115
 
21
116
  @abstractmethod
22
117
  def process(self, input_data: InputT) -> OutputT:
@@ -34,6 +129,13 @@ class Skill(ABC, Generic[InputT, OutputT]):
34
129
  """
35
130
  pass
36
131
 
132
def __call__(self, input_data: InputT) -> OutputT:
    """Run the skill as a callable: validate input, process, validate output.

    Args:
        input_data: Input passed to ``validate_input`` and then ``process``

    Returns:
        The value returned by ``process``, after it has been passed to
        ``validate_output``
    """
    self.validate_input(input_data)
    output = self.process(input_data)
    self.validate_output(output)
    return output
138
+
37
139
  def validate_input(self, input_data: Any) -> None:
38
140
  """
39
141
  Validate input data before processing.
@@ -85,9 +85,15 @@ class GroqListModelsSkill(BaseListModelsSkill):
85
85
  """Return list of Groq models."""
86
86
  # Default Groq models from trmx_agent config
87
87
  models = [
88
- {"id": "llama-3-70b-8192", "display_name": "Llama 3 70B (8K)"},
89
- {"id": "mixtral-8x7b-32768", "display_name": "Mixtral 8x7B (32K)"},
90
- {"id": "gemma-7b-it", "display_name": "Gemma 7B Instruct"}
88
+ {"id": "llama-3.3-70b-versatile", "display_name": "Llama 3.3 70B Versatile (Tool Use)"},
89
+ {"id": "llama-3.1-8b-instant", "display_name": "Llama 3.1 8B Instant (Tool Use)"},
90
+ {"id": "mixtral-8x7b-32768", "display_name": "Mixtral 8x7B (32K) (Tool Use)"},
91
+ {"id": "gemma2-9b-it", "display_name": "Gemma 2 9B IT (Tool Use)"},
92
+ {"id": "qwen-qwq-32b", "display_name": "Qwen QWQ 32B (Tool Use)"},
93
+ {"id": "qwen-2.5-coder-32b", "display_name": "Qwen 2.5 Coder 32B (Tool Use)"},
94
+ {"id": "qwen-2.5-32b", "display_name": "Qwen 2.5 32B (Tool Use)"},
95
+ {"id": "deepseek-r1-distill-qwen-32b", "display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)"},
96
+ {"id": "deepseek-r1-distill-llama-70b", "display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)"},
91
97
  ]
92
98
  return models
93
99
 
@@ -2,5 +2,22 @@
2
2
 
3
3
  from .credentials import GroqCredentials
4
4
  from .skills import GroqChatSkill
5
+ from .models_config import (
6
+ get_model_config,
7
+ get_default_model,
8
+ supports_tool_use,
9
+ supports_parallel_tool_use,
10
+ supports_json_mode,
11
+ GROQ_MODELS_CONFIG,
12
+ )
5
13
 
6
- __all__ = ["GroqCredentials", "GroqChatSkill"]
14
+ __all__ = [
15
+ "GroqCredentials",
16
+ "GroqChatSkill",
17
+ "get_model_config",
18
+ "get_default_model",
19
+ "supports_tool_use",
20
+ "supports_parallel_tool_use",
21
+ "supports_json_mode",
22
+ "GROQ_MODELS_CONFIG",
23
+ ]