airtrain 0.1.51__py3-none-any.whl → 0.1.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +42 -2
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/core/skills.py +102 -0
- airtrain/integrations/combined/list_models_factory.py +9 -3
- airtrain/integrations/groq/__init__.py +18 -1
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +93 -17
- airtrain/integrations/together/__init__.py +15 -1
- airtrain/integrations/together/models_config.py +123 -1
- airtrain/integrations/together/skills.py +117 -20
- airtrain/telemetry/__init__.py +34 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +173 -0
- airtrain/tools/__init__.py +41 -0
- airtrain/tools/command.py +211 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/METADATA +37 -1
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/RECORD +27 -13
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/WHEEL +1 -1
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,465 @@
|
|
1
|
+
"""
|
2
|
+
Agent Registry System for AirTrain
|
3
|
+
|
4
|
+
This module provides a registry system for agents that can be used to build AI systems.
|
5
|
+
It includes:
|
6
|
+
- Base agent class
|
7
|
+
- Registration decorators
|
8
|
+
- Factory methods for agent creation
|
9
|
+
- Discovery utilities for finding available agents
|
10
|
+
"""
|
11
|
+
|
12
|
+
from abc import ABC, abstractmethod
|
13
|
+
import time
|
14
|
+
import uuid
|
15
|
+
from typing import List, Optional, Type, TypeVar, Union
|
16
|
+
import inspect
|
17
|
+
|
18
|
+
# Import tool registry components
|
19
|
+
from airtrain.tools import ToolFactory, BaseTool
|
20
|
+
from airtrain.telemetry import (
|
21
|
+
telemetry,
|
22
|
+
AgentRunTelemetryEvent,
|
23
|
+
AgentStepTelemetryEvent,
|
24
|
+
AgentEndTelemetryEvent,
|
25
|
+
ModelInvocationTelemetryEvent,
|
26
|
+
ErrorTelemetryEvent
|
27
|
+
)
|
28
|
+
from .memory import AgentMemoryManager
|
29
|
+
|
30
|
+
|
31
|
+
# TypeVar bound (via a forward reference) to BaseAgent, for helpers that are
# generic over agent classes.
A = TypeVar("A", bound="BaseAgent")

# Global name -> agent-class mapping; entries are added by
# AgentRegistry.register and read by AgentRegistry.get_agent_class.
AGENT_REGISTRY = dict()
|
36
|
+
|
37
|
+
|
38
|
+
class BaseAgent(ABC):
    """Abstract base class for all agents.

    An agent holds a list of model identifiers, a list of tools and an
    ``AgentMemoryManager`` (a ``"default"`` short-term memory is created on
    construction). It also tracks per-run counters (steps, tokens, errors) and
    sends telemetry events for run start/end, steps, model usage and errors.
    Subclasses must implement :meth:`process`.
    """

    # Registry name; set on the class by the registration decorator
    # (AgentRegistry.register). Stays None for unregistered classes.
    agent_name: Optional[str] = None

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None,
    ):
        """
        Initialize an agent.

        Args:
            name: Name of the agent instance
            models: List of model identifiers to use (copied)
            tools: List of tools the agent can use (copied, so ``add_tool`` /
                ``register_tools`` never mutate the caller's list)
        """
        self.name = name
        # Copy the inputs: storing the caller's list directly would let
        # add_tool/register_tools mutate a list the caller may share.
        self.models = list(models) if models else []
        self.tools = list(tools) if tools else []
        self.memory = AgentMemoryManager()

        # Unique agent ID used to correlate this agent's telemetry events
        self.agent_id = f"{name}-{uuid.uuid4()}"
        self.start_time: Optional[float] = None
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors: List[str] = []
        # Length of the current streak of steps that reported errors.
        self._consecutive_failures = 0

        # Initialize default short-term memory
        self.memory.create_short_term_memory("default")

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent.

        Args:
            tool: Tool instance to add

        Returns:
            Self for method chaining
        """
        self.tools.append(tool)
        return self

    def register_tools(self, tools: List[BaseTool]):
        """
        Register multiple tools with the agent.

        Args:
            tools: List of tool instances to add

        Returns:
            Self for method chaining
        """
        self.tools.extend(tools)
        return self

    def create_memory(self, name: str, max_messages: int = 10):
        """
        Create a new short-term memory.

        Args:
            name: Name for the memory
            max_messages: Maximum messages before summarization

        Returns:
            The created memory instance
        """
        return self.memory.create_short_term_memory(name, max_messages)

    def reset_memory(self, name: str = "default"):
        """
        Reset a specific short-term memory.

        Args:
            name: Name of the memory to reset

        Returns:
            Self for method chaining
        """
        self.memory.reset_short_term_memory(name)
        return self

    def start_run(self, task: str, model_name: str, model_provider: str):
        """
        Start an agent run: reset per-run counters and send a run event.

        Args:
            task: Description of the task
            model_name: Name of the model being used
            model_provider: Provider of the model

        Returns:
            Self for method chaining
        """
        self.start_time = time.time()
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors = []
        self._consecutive_failures = 0

        # Send run event
        event = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=task,
            model_name=model_name,
            model_provider=model_provider,
            version=self._get_package_version(),
            source=self.__class__.__name__,
        )
        telemetry.capture(event)
        return self

    def record_step(self, actions: List[dict], step_error: Optional[List[str]] = None):
        """
        Record an agent step and send telemetry.

        A step that reports at least one error extends the consecutive-failure
        streak; a step without errors resets it to zero.

        Args:
            actions: List of actions taken in this step
            step_error: Optional list of errors encountered in this step

        Returns:
            Self for method chaining
        """
        self.step_count += 1
        step_errors = list(step_error) if step_error else []

        # Previously this reported len(self.errors), i.e. the total number of
        # record_error calls, which is not a count of *consecutive* failures.
        if step_errors:
            self._consecutive_failures += 1
        else:
            self._consecutive_failures = 0

        # Send step event
        event = AgentStepTelemetryEvent(
            agent_id=self.agent_id,
            step=self.step_count,
            step_error=step_errors,
            consecutive_failures=self._consecutive_failures,
            actions=actions,
        )
        telemetry.capture(event)
        return self

    def record_model_usage(
        self,
        model_name: str,
        model_provider: str,
        tokens: int,
        prompt_tokens: int,
        completion_tokens: int,
        duration_seconds: float,
        error: Optional[str] = None,
    ):
        """
        Accumulate token counters and send a model-invocation event.

        Args:
            model_name: Name of the model used
            model_provider: Provider of the model
            tokens: Total tokens used
            prompt_tokens: Tokens used in the prompt
            completion_tokens: Tokens used in the completion
            duration_seconds: Duration of the model call
            error: Optional error message

        Returns:
            Self for method chaining
        """
        self.total_tokens += tokens
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        # Send model usage event
        event = ModelInvocationTelemetryEvent(
            agent_id=self.agent_id,
            model_name=model_name,
            model_provider=model_provider,
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            duration_seconds=duration_seconds,
            error=error,
        )
        telemetry.capture(event)
        return self

    def end_run(self, is_done: bool, success: Optional[bool] = None):
        """
        End an agent run and send a summary telemetry event.

        Args:
            is_done: Whether the agent completed its task
            success: Whether the agent was successful

        Returns:
            Self for method chaining
        """
        # Duration is 0 when start_run was never called.
        duration = time.time() - self.start_time if self.start_time is not None else 0.0

        # Send end event. Pass a snapshot of the error list so later
        # record_error calls cannot alter an event that is still queued.
        event = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            steps=self.step_count,
            is_done=is_done,
            success=success,
            total_tokens=self.total_tokens,
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_duration_seconds=duration,
            errors=list(self.errors),
        )
        telemetry.capture(event)
        return self

    def record_error(self, error_type: str, error_message: str, component: str):
        """
        Record an error and send telemetry.

        The message is also appended to ``self.errors``, which ``end_run``
        reports.

        Args:
            error_type: Type of the error
            error_message: Error message
            component: Component where the error occurred

        Returns:
            Self for method chaining
        """
        self.errors.append(error_message)

        # Send error event
        event = ErrorTelemetryEvent(
            error_type=error_type,
            error_message=error_message,
            component=component,
            agent_id=self.agent_id,
        )
        telemetry.capture(event)
        return self

    def _get_package_version(self) -> str:
        """Return ``airtrain.__version__``, or ``"unknown"`` if it cannot be imported."""
        try:
            # Imported lazily; the package __init__ may not be importable here.
            from airtrain import __version__
            return __version__
        except ImportError:
            return "unknown"

    @abstractmethod
    def process(self, user_input: str, memory_name: str = "default") -> str:
        """
        Process user input using a specific memory context.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use for context

        Returns:
            Agent's response
        """
        pass
|
299
|
+
|
300
|
+
|
301
|
+
class AgentRegistry:
    """Registry of agent classes keyed by name (backed by ``AGENT_REGISTRY``)."""

    @classmethod
    def register(cls, name: Optional[str] = None):
        """
        Decorator to register an agent class.

        Args:
            name: Optional registry name; defaults to the class's ``__name__``

        Returns:
            Decorator function

        The decorator raises ``TypeError`` when the target is not a BaseAgent
        subclass or does not implement ``process``, and ``ValueError`` when the
        name is already registered.
        """
        def decorator(agent_class: Type[BaseAgent]) -> Type[BaseAgent]:
            """
            Register an agent class with the registry.

            Args:
                agent_class: Agent class to register

            Returns:
                The registered agent class (with ``agent_name`` set)
            """
            # Validate agent class. Check isclass first so non-class input
            # gets this message instead of issubclass()'s generic TypeError.
            if not inspect.isclass(agent_class) or not issubclass(agent_class, BaseAgent):
                class_name = getattr(agent_class, "__name__", repr(agent_class))
                raise TypeError(
                    f"Agent class {class_name} must inherit from BaseAgent"
                )

            # Check for required methods. BaseAgent.process is itself a plain
            # function, so isfunction() alone accepts subclasses that never
            # override it; also reject a process that is still abstract.
            process = getattr(agent_class, "process", None)
            if not inspect.isfunction(process) or getattr(
                process, "__isabstractmethod__", False
            ):
                raise TypeError(
                    f"Agent class {agent_class.__name__} must implement process method"
                )

            # Determine agent name
            agent_name = name or agent_class.__name__

            # Check for name conflict
            if agent_name in AGENT_REGISTRY:
                raise ValueError(f"Agent '{agent_name}' already registered")

            # Register the agent class and record its registry name on it
            AGENT_REGISTRY[agent_name] = agent_class
            agent_class.agent_name = agent_name

            return agent_class

        return decorator

    @classmethod
    def get_agent_class(cls, name: str) -> Type[BaseAgent]:
        """
        Get agent class by name.

        Args:
            name: Name of the agent class

        Returns:
            The agent class

        Raises:
            ValueError: If agent not found
        """
        try:
            return AGENT_REGISTRY[name]
        except KeyError:
            raise ValueError(f"Agent '{name}' not found in registry") from None

    @classmethod
    def list_agents(cls) -> List[str]:
        """
        List all registered agents.

        Returns:
            List of agent names, in registration order
        """
        return list(AGENT_REGISTRY)
|
383
|
+
|
384
|
+
|
385
|
+
class AgentFactory:
    """Factory for creating instances of registered agent types."""

    @staticmethod
    def _resolve_tool(tool: Union[str, BaseTool]) -> BaseTool:
        """Return a tool instance, looking up string names via ``ToolFactory``.

        Names are tried as stateless tools first, then as ``"stateful"`` tools.

        Raises:
            ValueError: If a tool name is found in neither registry
        """
        if not isinstance(tool, str):
            # Anything that is not a name is assumed to be a tool instance.
            return tool
        try:
            return ToolFactory.get_tool(tool)
        except ValueError:
            pass
        try:
            return ToolFactory.get_tool(tool, "stateful")
        except ValueError:
            raise ValueError(f"Tool '{tool}' not found in registry") from None

    @staticmethod
    def create_agent(
        agent_type: str,
        name: Optional[str] = None,
        models: Optional[List[str]] = None,
        tools: Optional[List[Union[str, BaseTool]]] = None,
        **kwargs
    ) -> BaseAgent:
        """
        Create an agent instance.

        Args:
            agent_type: Registered name of the agent type to create
            name: Name for the agent instance; a unique default is generated
            models: List of model identifiers
            tools: List of tool instances and/or registered tool names
            **kwargs: Additional arguments for the agent constructor

        Returns:
            Agent instance

        Raises:
            ValueError: If the agent type or a named tool is not registered
        """
        # Get agent class
        agent_class = AgentRegistry.get_agent_class(agent_type)

        # id(agent_class) is the same for every instance of a type, so it
        # cannot make default names unique; use a random suffix instead.
        instance_name = name or f"{agent_type}_{uuid.uuid4().hex[:8]}"

        # Resolve tool names to instances (instances pass through unchanged)
        tool_instances = [AgentFactory._resolve_tool(tool) for tool in tools or []]

        # Create agent instance
        return agent_class(
            name=instance_name,
            models=models,
            tools=tool_instances,
            **kwargs
        )

    @staticmethod
    def list_available_agents() -> List[str]:
        """
        List all available agent types.

        Returns:
            List of agent type names
        """
        return AgentRegistry.list_agents()
|
452
|
+
|
453
|
+
|
454
|
+
# Module-level shortcut so callers can write @register_agent(...) directly.
def register_agent(name: Optional[str] = None):
    """
    Build a class decorator that adds an agent class to the registry.

    This is a thin alias for ``AgentRegistry.register``.

    Args:
        name: Registry key for the class; when omitted, the class's
            ``__name__`` is used

    Returns:
        Decorator function that registers the class and returns the same class
    """
    registering_decorator = AgentRegistry.register(name)
    return registering_decorator
|
airtrain/core/skills.py
CHANGED
@@ -1,8 +1,17 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from typing import Any, Dict, Optional, Type, Generic, TypeVar
|
3
3
|
from uuid import UUID, uuid4
|
4
|
+
import time
|
5
|
+
import functools
|
4
6
|
from .schemas import InputSchema, OutputSchema
|
5
7
|
|
8
|
+
# Import telemetry
|
9
|
+
from airtrain.telemetry import (
|
10
|
+
telemetry,
|
11
|
+
SkillInitTelemetryEvent,
|
12
|
+
SkillProcessTelemetryEvent,
|
13
|
+
)
|
14
|
+
|
6
15
|
# Generic type variables for input and output schemas
|
7
16
|
InputT = TypeVar("InputT", bound=InputSchema)
|
8
17
|
OutputT = TypeVar("OutputT", bound=OutputSchema)
|
@@ -17,6 +26,92 @@ class Skill(ABC, Generic[InputT, OutputT]):
|
|
17
26
|
input_schema: Type[InputT]
|
18
27
|
output_schema: Type[OutputT]
|
19
28
|
_skill_id: Optional[UUID] = None
|
29
|
+
_original_process = None
|
30
|
+
|
31
|
+
def __init__(self):
    """Initialize the skill, instrument ``process`` and emit an init event.

    On first construction of each class, that class's ``process`` is replaced
    (at class level) by a wrapper that times every call and emits a
    ``SkillProcessTelemetryEvent``. A ``SkillInitTelemetryEvent`` is then sent
    for this instance. Telemetry failures never propagate to the caller.
    """
    # Initialize skill_id if not already set
    if not self._skill_id:
        self._skill_id = uuid4()

    cls = self.__class__

    # Look in the class's *own* namespace. hasattr() would also see a flag
    # inherited from an already-patched parent, so a subclass overriding
    # `process` would never be instrumented.
    if "_patched_process" not in cls.__dict__:
        current_process = cls.process

        # A class that merely inherits an already-wrapped `process` must not
        # wrap it again, or each call would emit duplicate telemetry.
        # NOTE(review): an override that calls super().process() on an
        # instrumented parent still produces a nested event.
        if not getattr(current_process, "_telemetry_wrapped", False):
            cls._original_process = current_process

            def _serialize_input(input_data):
                """Best-effort conversion of input_data to a dict for telemetry."""
                # NOTE(review): the full input payload is sent to telemetry;
                # confirm this is acceptable for user data.
                try:
                    if hasattr(input_data, "model_dump"):
                        # Pydantic v2 style; .dict() is deprecated there.
                        return input_data.model_dump()
                    if hasattr(input_data, "dict"):
                        # Pydantic v1 style
                        return input_data.dict()
                    if hasattr(input_data, "__dataclass_fields__"):
                        from dataclasses import asdict
                        return asdict(input_data)
                    return {"__str__": str(input_data)}
                except Exception:
                    # If serialization fails, provide simple info
                    return {"error": "Failed to serialize input data"}

            def _create_wrapper(original_method):
                @functools.wraps(original_method)
                def wrapped_process(instance, input_data, *args, **kwargs):
                    # perf_counter is monotonic, so durations are not skewed
                    # by wall-clock adjustments.
                    start = time.perf_counter()
                    error = None
                    try:
                        return original_method(instance, input_data, *args, **kwargs)
                    except Exception as e:
                        error = str(e)
                        raise
                    finally:
                        duration = time.perf_counter() - start
                        try:
                            telemetry.capture(
                                SkillProcessTelemetryEvent(
                                    skill_id=str(instance.skill_id),
                                    skill_class=instance.__class__.__name__,
                                    input_schema=instance.input_schema.__name__,
                                    output_schema=instance.output_schema.__name__,
                                    input_data=_serialize_input(input_data),
                                    duration_seconds=duration,
                                    error=error,
                                )
                            )
                        except Exception:
                            # Telemetry is best-effort; never break process().
                            pass

                # Marker so inheriting classes can detect the wrapper.
                wrapped_process._telemetry_wrapped = True
                return wrapped_process

            # Replace the process method with the wrapped version at class level
            cls.process = _create_wrapper(current_process)

        # Mark this exact class as handled to prevent double-patching
        cls._patched_process = True

    # Capture telemetry for initialization
    try:
        telemetry.capture(
            SkillInitTelemetryEvent(
                skill_id=str(self.skill_id),
                skill_class=self.__class__.__name__,
            )
        )
    except Exception:
        # Silently continue if telemetry fails
        pass
|
20
115
|
|
21
116
|
@abstractmethod
|
22
117
|
def process(self, input_data: InputT) -> OutputT:
|
@@ -34,6 +129,13 @@ class Skill(ABC, Generic[InputT, OutputT]):
|
|
34
129
|
"""
|
35
130
|
pass
|
36
131
|
|
132
|
+
def __call__(self, input_data: InputT) -> OutputT:
    """Run the skill as a function: validate input, process it, validate output.

    Returns whatever ``process`` produced, after it passes ``validate_output``.
    """
    self.validate_input(input_data)
    output = self.process(input_data)
    self.validate_output(output)
    return output
|
138
|
+
|
37
139
|
def validate_input(self, input_data: Any) -> None:
|
38
140
|
"""
|
39
141
|
Validate input data before processing.
|
@@ -85,9 +85,15 @@ class GroqListModelsSkill(BaseListModelsSkill):
|
|
85
85
|
"""Return list of Groq models."""
|
86
86
|
# Default Groq models from trmx_agent config
|
87
87
|
models = [
|
88
|
-
{"id": "llama-3-70b-
|
89
|
-
{"id": "
|
90
|
-
{"id": "
|
88
|
+
{"id": "llama-3.3-70b-versatile", "display_name": "Llama 3.3 70B Versatile (Tool Use)"},
|
89
|
+
{"id": "llama-3.1-8b-instant", "display_name": "Llama 3.1 8B Instant (Tool Use)"},
|
90
|
+
{"id": "mixtral-8x7b-32768", "display_name": "Mixtral 8x7B (32K) (Tool Use)"},
|
91
|
+
{"id": "gemma2-9b-it", "display_name": "Gemma 2 9B IT (Tool Use)"},
|
92
|
+
{"id": "qwen-qwq-32b", "display_name": "Qwen QWQ 32B (Tool Use)"},
|
93
|
+
{"id": "qwen-2.5-coder-32b", "display_name": "Qwen 2.5 Coder 32B (Tool Use)"},
|
94
|
+
{"id": "qwen-2.5-32b", "display_name": "Qwen 2.5 32B (Tool Use)"},
|
95
|
+
{"id": "deepseek-r1-distill-qwen-32b", "display_name": "DeepSeek R1 Distill Qwen 32B (Tool Use)"},
|
96
|
+
{"id": "deepseek-r1-distill-llama-70b", "display_name": "DeepSeek R1 Distill Llama 70B (Tool Use)"},
|
91
97
|
]
|
92
98
|
return models
|
93
99
|
|
@@ -2,5 +2,22 @@
|
|
2
2
|
|
3
3
|
from .credentials import GroqCredentials
|
4
4
|
from .skills import GroqChatSkill
|
5
|
+
from .models_config import (
|
6
|
+
get_model_config,
|
7
|
+
get_default_model,
|
8
|
+
supports_tool_use,
|
9
|
+
supports_parallel_tool_use,
|
10
|
+
supports_json_mode,
|
11
|
+
GROQ_MODELS_CONFIG,
|
12
|
+
)
|
5
13
|
|
6
|
-
__all__ = [
|
14
|
+
__all__ = [
|
15
|
+
"GroqCredentials",
|
16
|
+
"GroqChatSkill",
|
17
|
+
"get_model_config",
|
18
|
+
"get_default_model",
|
19
|
+
"supports_tool_use",
|
20
|
+
"supports_parallel_tool_use",
|
21
|
+
"supports_json_mode",
|
22
|
+
"GROQ_MODELS_CONFIG",
|
23
|
+
]
|