airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,147 @@
1
+ from typing import List, Optional, Dict, Any, Generator, Union
2
+ from pydantic import Field
3
+ import requests
4
+ import json
5
+ from loguru import logger
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import FireworksCredentials
10
+
11
+
12
class FireworksCompletionInput(InputSchema):
    """Input parameters accepted by the Fireworks AI completion skill.

    Only ``prompt`` is required. The numeric sampling fields carry pydantic
    range constraints; ``stop`` defaults to ``None`` and both boolean flags
    default to ``False``.
    """

    # What to complete, and with which model.
    prompt: str = Field(..., description="Input prompt for completion")
    model: str = Field(
        description="Fireworks AI model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )

    # Length and sampling controls.
    max_tokens: int = Field(description="Maximum tokens in response", default=131072)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
    top_p: float = Field(
        description="Top p sampling parameter",
        default=1.0,
        ge=0,
        le=1,
    )
    top_k: int = Field(description="Top k sampling parameter", default=50, ge=0)

    # Penalties.
    presence_penalty: float = Field(
        description="Presence penalty",
        default=0.0,
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        description="Frequency penalty",
        default=0.0,
        ge=-2.0,
        le=2.0,
    )
    repetition_penalty: float = Field(
        description="Repetition penalty",
        default=1.0,
        ge=0.0,
    )

    # Output shaping.
    stop: Optional[Union[str, List[str]]] = Field(
        description="Stop sequences",
        default=None,
    )
    echo: bool = Field(description="Echo the prompt in the response", default=False)
    stream: bool = Field(description="Whether to stream the response", default=False)
42
+
43
+
44
class FireworksCompletionOutput(OutputSchema):
    """Result returned by the Fireworks AI completion skill."""

    # Generated completion text.
    response: str
    # Model identifier the request was made with.
    used_model: str
    # Usage counters keyed by name.
    usage: Dict[str, int]
50
+
51
+
52
class FireworksCompletionSkill(
    Skill[FireworksCompletionInput, FireworksCompletionOutput]
):
    """Text-completion skill that POSTs to the Fireworks AI completions endpoint.

    The skill talks to ``BASE_URL`` directly through ``requests`` and supports
    both a single blocking response (``process``) and incremental delivery of
    text chunks (``process_stream``).
    """

    input_schema = FireworksCompletionInput
    output_schema = FireworksCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Set up credentials and the HTTP headers reused by every request.

        Args:
            credentials: Explicit credentials; when omitted they are read via
                ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        api_key = self.credentials.fireworks_api_key.get_secret_value()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

    def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
        """Translate the validated input into the JSON request body.

        ``stop`` is only included when it is set to a non-empty value.
        """
        payload: Dict[str, Any] = {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "max_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "repetition_penalty": input_data.repetition_penalty,
            "echo": input_data.echo,
            "stream": input_data.stream,
        }

        if input_data.stop:
            payload["stop"] = input_data.stop

        return payload

    def process_stream(
        self, input_data: FireworksCompletionInput
    ) -> Generator[str, None, None]:
        """Yield completion text chunks as they arrive from the API.

        Args:
            input_data: Validated completion parameters.

        Yields:
            Each non-empty ``choices[0].text`` fragment found in the stream.

        Raises:
            ProcessingError: If the request fails or a chunk cannot be handled.
        """
        try:
            payload = self._build_payload(input_data)
            # This method always consumes a streamed body, so request one even
            # when the caller left ``input_data.stream`` at its default.
            payload["stream"] = True

            # The context manager releases the connection even when the
            # consumer stops iterating early or an error is raised mid-stream.
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Lines that are not JSON carry no text; skip them.
                        continue
                    if data.get("choices") and data["choices"][0].get("text"):
                        yield data["choices"][0]["text"]

        except Exception as e:
            raise ProcessingError(
                f"Fireworks completion streaming failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksCompletionInput
    ) -> FireworksCompletionOutput:
        """Run a completion and return the full response.

        When ``input_data.stream`` is true the streamed chunks are collected
        and joined; usage statistics are then reported as an empty dict.

        Args:
            input_data: Validated completion parameters.

        Returns:
            The completion text, the model name that was requested, and usage.

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            if input_data.stream:
                # For streaming, collect the entire response.
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, int] = {}  # Usage stats not available in streaming mode
            else:
                # For non-streaming, use a regular request.
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["text"]
                usage = data["usage"]

            return FireworksCompletionOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            # Already wrapped by process_stream; do not wrap a second time.
            raise
        except Exception as e:
            raise ProcessingError(f"Fireworks completion failed: {str(e)}") from e
@@ -0,0 +1,109 @@
1
+ from typing import List, Dict, Optional
2
+ from pydantic import BaseModel, Field
3
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
4
+
5
+ # TODO: Add tests for FireworksConversationManager (send, reset, save/load state).
6
+
7
+
8
class ConversationState(BaseModel):
    """Serializable snapshot of a conversation and its generation settings."""

    # Messages exchanged so far; starts empty.
    messages: List[Dict[str, str]] = Field(
        description="List of conversation messages",
        default_factory=list,
    )
    system_prompt: str = Field(
        description="System prompt for the conversation",
        default="You are a helpful assistant.",
    )
    model: str = Field(
        description="Model being used for the conversation",
        default="accounts/fireworks/models/deepseek-r1",
    )
    temperature: float = Field(description="Temperature setting", default=0.7)
    max_tokens: Optional[int] = Field(description="Max tokens setting", default=131072)
24
+
25
+
26
class FireworksConversationManager:
    """Keep multi-turn conversation state and forward each turn to a chat skill.

    The manager owns a ``ConversationState`` and appends the user message and
    the assistant reply to it after every successful call.
    """

    def __init__(
        self,
        skill: Optional[FireworksChatSkill] = None,
        system_prompt: str = "You are a helpful assistant.",
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize conversation manager.

        Args:
            skill: FireworksChatSkill instance (creates new one if None)
            system_prompt: Initial system prompt
            model: Model to use
            temperature: Temperature setting
            max_tokens: Max tokens setting
        """
        self.skill = skill or FireworksChatSkill()
        self.state = ConversationState(
            system_prompt=system_prompt,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def send_message(self, user_input: str) -> FireworksOutput:
        """
        Send a message and get response while maintaining conversation history.

        The history is only extended after the skill returns, so a failed call
        leaves the stored conversation unchanged.

        Args:
            user_input: User's message

        Returns:
            FireworksOutput: Model's response
        """
        # Create input with current conversation state
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.state.system_prompt,
            conversation_history=self.state.messages,
            model=self.state.model,
            temperature=self.state.temperature,
            max_tokens=self.state.max_tokens,
        )

        # Get response
        result = self.skill.process(input_data)

        # Update conversation history
        self.state.messages.extend(
            [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": result.response},
            ]
        )

        return result

    def reset_conversation(self) -> None:
        """Reset the conversation history while maintaining other settings"""
        self.state.messages = []

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Get a shallow copy of the current conversation history"""
        return self.state.messages.copy()

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future messages"""
        self.state.system_prompt = new_prompt

    def save_state(self, file_path: str) -> None:
        """Save conversation state to a file as UTF-8 encoded JSON"""
        # Pin the encoding: conversation text may contain non-ASCII characters
        # and the platform default encoding is not guaranteed to handle them.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.state.model_dump_json(indent=2))

    def load_state(self, file_path: str) -> None:
        """Load conversation state from a UTF-8 encoded JSON file"""
        with open(file_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.state = ConversationState.model_validate_json(data)
@@ -0,0 +1,26 @@
1
+ from pydantic import SecretStr, BaseModel, Field
2
+ from typing import Optional
3
+ import os
4
+
5
+
6
class FireworksCredentials(BaseModel):
    """Holds the API key used to authenticate against Fireworks AI."""

    fireworks_api_key: SecretStr = Field(..., min_length=1)

    def __repr__(self) -> str:
        """Return a fixed, masked representation that never exposes the key."""
        return "FireworksCredentials(fireworks_api_key=SecretStr('**********'))"

    def __str__(self) -> str:
        """Return the same masked text as ``__repr__``."""
        return self.__repr__()

    @classmethod
    def from_env(cls) -> "FireworksCredentials":
        """Build credentials from the ``FIREWORKS_API_KEY`` environment variable.

        Raises:
            ValueError: If the variable is unset or empty.
        """
        key_from_env = os.environ.get("FIREWORKS_API_KEY")
        if key_from_env:
            return cls(fireworks_api_key=key_from_env)

        raise ValueError("FIREWORKS_API_KEY environment variable not set")
@@ -0,0 +1,128 @@
1
+ from typing import Optional, List
2
+ import requests
3
+ from pydantic import Field
4
+
5
+ from airtrain.core.skills import Skill, ProcessingError
6
+ from airtrain.core.schemas import InputSchema, OutputSchema
7
+ from .credentials import FireworksCredentials
8
+ from .models import FireworksModel
9
+
10
+
11
class FireworksListModelsInput(InputSchema):
    """Query parameters for listing the models of a Fireworks AI account.

    Only ``account_id`` is required; the paging, filter and ordering fields
    are optional.
    """

    account_id: str = Field(..., description="The Account Id")
    page_size: Optional[int] = Field(
        default=50,
        le=200,
        description=(
            "The maximum number of models to return. The maximum page_size is 200, "
            "values above 200 will be coerced to 200."
        ),
    )
    page_token: Optional[str] = Field(
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
        default=None,
    )
    filter: Optional[str] = Field(
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
        default=None,
    )
    order_by: Optional[str] = Field(
        description=(
            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
            "The default sort order is ascending. To specify a descending order for a "
            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
            "If not specified, the default order is by \"name\"."
        ),
        default=None,
    )
48
+
49
+
50
class FireworksListModelsOutput(OutputSchema):
    """One page of results from the list-models skill."""

    models: List[FireworksModel] = Field(
        description="List of Fireworks models",
        default_factory=list,
    )
    next_page_token: Optional[str] = Field(
        description="Token for retrieving the next page of results",
        default=None,
    )
    total_size: Optional[int] = Field(
        description="Total number of models available",
        default=None,
    )
65
+
66
+
67
class FireworksListModelsSkill(
    Skill[FireworksListModelsInput, FireworksListModelsOutput]
):
    """Skill that fetches one page of models for a Fireworks AI account."""

    input_schema = FireworksListModelsInput
    output_schema = FireworksListModelsOutput
    # Seconds to wait for the API before the request is abandoned.
    REQUEST_TIMEOUT = 30

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and the API base URL.

        Args:
            credentials: Explicit credentials; when omitted they are read via
                ``FireworksCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.base_url = "https://api.fireworks.ai/v1"

    @staticmethod
    def _snake_case_keys(raw: dict) -> dict:
        """Return a copy of ``raw`` with top-level camelCase keys as snake_case.

        Keys that contain no upper-case letters are returned unchanged, and
        nested values are left untouched.
        """
        converted = {}
        for key, value in raw.items():
            snake = "".join(
                f"_{ch.lower()}" if ch.isupper() and pos > 0 else ch.lower()
                for pos, ch in enumerate(key)
            )
            converted[snake] = value
        return converted

    def process(
        self, input_data: FireworksListModelsInput
    ) -> FireworksListModelsOutput:
        """Request the model list and convert it into the output schema.

        Args:
            input_data: Account id plus optional paging, filter and ordering.

        Returns:
            The parsed models together with the next-page token and total size
            when the response provides them.

        Raises:
            ProcessingError: If the HTTP request fails or the response cannot
                be converted.
        """
        try:
            url = f"{self.base_url}/accounts/{input_data.account_id}/models"

            # Only forward the query parameters that were actually provided.
            params = {}
            if input_data.page_size:
                params["pageSize"] = input_data.page_size
            if input_data.page_token:
                params["pageToken"] = input_data.page_token
            if input_data.filter:
                params["filter"] = input_data.filter
            if input_data.order_by:
                params["orderBy"] = input_data.order_by

            headers = {
                "Authorization": (
                    f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}"
                )
            }

            response = requests.get(
                url, headers=headers, params=params, timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()

            result = response.json()

            # The response is read with camelCase keys (nextPageToken,
            # totalSize) while FireworksModel declares snake_case fields, so
            # normalise each model's keys before building the model object.
            models = [
                FireworksModel(**self._snake_case_keys(model_data))
                for model_data in result.get("models", [])
            ]

            return FireworksListModelsOutput(
                models=models,
                next_page_token=result.get("nextPageToken"),
                total_size=result.get("totalSize"),
            )

        except requests.RequestException as e:
            raise ProcessingError(f"Failed to list Fireworks models: {str(e)}") from e
        except Exception as e:
            raise ProcessingError(f"Error listing Fireworks models: {str(e)}") from e
@@ -0,0 +1,139 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field, BaseModel
3
+
4
+
5
class FireworksMessage(BaseModel):
    """A single chat message sent to or received from Fireworks."""

    # Message text.
    content: str
    # Must be exactly one of: system, user, assistant.
    role: str = Field(..., pattern="^(system|user|assistant)$")
10
+
11
+
12
class FireworksUsage(BaseModel):
    """Token usage counters reported with a Fireworks API response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
18
+
19
+
20
class FireworksResponse(BaseModel):
    """Top-level shape of a Fireworks API response."""

    id: str
    # Choices are kept as plain dictionaries rather than typed models.
    choices: List[Dict[str, Any]]
    created: int
    model: str
    usage: FireworksUsage
28
+
29
+
30
class FireworksModelStatus(BaseModel):
    """Placeholder schema for a Fireworks model's status.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
33
+
34
+
35
class FireworksModelBaseDetails(BaseModel):
    """Placeholder schema for Fireworks base model details.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
38
+
39
+
40
class FireworksPeftDetails(BaseModel):
    """Placeholder schema for Fireworks PEFT details.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
43
+
44
+
45
class FireworksConversationConfig(BaseModel):
    """Placeholder schema for a Fireworks conversation configuration.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
48
+
49
+
50
class FireworksModelDeployedRef(BaseModel):
    """Placeholder schema for a Fireworks deployed model reference.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
53
+
54
+
55
class FireworksDeprecationDate(BaseModel):
    """Placeholder schema for a Fireworks deprecation date.

    No fields are declared yet; they are meant to be filled in from the
    actual API response.
    """
58
+
59
+
60
class FireworksModel(BaseModel):
    """Description of a single Fireworks model.

    ``name`` is the only required field; every other attribute is optional
    and defaults to ``None``. Structured sub-objects are kept as plain
    dictionaries.
    """

    name: str

    # Descriptive metadata.
    display_name: Optional[str] = None
    description: Optional[str] = None
    create_time: Optional[str] = None
    created_by: Optional[str] = None
    state: Optional[str] = None
    status: Optional[Dict[str, Any]] = None
    kind: Optional[str] = None
    github_url: Optional[str] = None
    hugging_face_url: Optional[str] = None

    # Nested detail objects, stored untyped.
    base_model_details: Optional[Dict[str, Any]] = None
    peft_details: Optional[Dict[str, Any]] = None
    teft_details: Optional[Dict[str, Any]] = None

    public: Optional[bool] = None
    conversation_config: Optional[Dict[str, Any]] = None
    context_length: Optional[int] = None
    supports_image_input: Optional[bool] = None
    supports_tools: Optional[bool] = None
    imported_from: Optional[str] = None
    fine_tuning_job: Optional[str] = None
    default_draft_model: Optional[str] = None
    default_draft_token_count: Optional[int] = None
    precisions: Optional[List[str]] = None
    deployed_model_refs: Optional[List[Dict[str, Any]]] = None
    cluster: Optional[str] = None
    deprecation_date: Optional[Dict[str, Any]] = None

    # Boolean flags.
    calibrated: Optional[bool] = None
    tunable: Optional[bool] = None
    supports_lora: Optional[bool] = None
    use_hf_apply_chat_template: Optional[bool] = None
93
+
94
+
95
class ListModelsInput(BaseModel):
    """Parameters for a list-models request against a Fireworks account.

    Only ``account_id`` is required; the paging, filter and ordering fields
    are optional.
    """

    account_id: str = Field(..., description="The Account Id")
    page_size: Optional[int] = Field(
        default=50,
        le=200,
        description=(
            "The maximum number of models to return. The maximum page_size is 200, "
            "values above 200 will be coerced to 200."
        ),
    )
    page_token: Optional[str] = Field(
        description=(
            "A page token, received from a previous ListModels call. Provide this "
            "to retrieve the subsequent page. When paginating, all other parameters "
            "provided to ListModels must match the call that provided the page token."
        ),
        default=None,
    )
    filter: Optional[str] = Field(
        description=(
            "Only model satisfying the provided filter (if specified) will be "
            "returned. See https://google.aip.dev/160 for the filter grammar."
        ),
        default=None,
    )
    order_by: Optional[str] = Field(
        description=(
            "A comma-separated list of fields to order by. e.g. \"foo,bar\" "
            "The default sort order is ascending. To specify a descending order for a "
            "field, append a \" desc\" suffix. e.g. \"foo desc,bar\" "
            "Subfields are specified with a \".\" character. e.g. \"foo.bar\" "
            "If not specified, the default order is by \"name\"."
        ),
        default=None,
    )
132
+
133
+
134
class ListModelsOutput(BaseModel):
    """One page of a list-models result."""

    # Required; the two paging fields below default to None.
    models: List[FireworksModel]
    next_page_token: Optional[str] = None
    total_size: Optional[int] = None