stratifyai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
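The stratifyai.chat modules diffed below share one surface: an async chat() that requires an explicit model, a chat_sync() blocking wrapper, and a set of with_*() builder helpers. A minimal sketch of that surface, using only names visible in the files below (the model string is just the docstrings' example; grok and groq expose the same functions):

import asyncio
from stratifyai.chat import google

async def main():
    # Model is always required; there is no default model.
    response = await google.chat(
        "Hello!",
        model="gemini-2.5-flash",
        system="You are a helpful assistant",
    )
    print(response.content)

asyncio.run(main())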
stratifyai/chat/stratifyai_google.py
@@ -0,0 +1,194 @@
"""Google Gemini chat interface for StratifyAI.

Provides convenient functions for Google Gemini chat completions.
Model must be specified for each request.

Environment Variable: GOOGLE_API_KEY

Usage:
    # Model is always required
    from stratifyai.chat import google
    response = await google.chat("Hello!", model="gemini-2.5-flash")

    # Builder pattern (model required)
    client = (
        google
        .with_model("gemini-2.5-pro")
        .with_system("You are a helpful assistant")
        .with_developer("Use markdown")
    )
    response = await client.chat("Hello!")
"""

import asyncio
from typing import AsyncIterator, Optional, Union

from stratifyai import LLMClient
from stratifyai.models import ChatResponse, Message
from stratifyai.chat.builder import ChatBuilder, create_module_builder

# Default configuration (no default model - must be specified)
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = None

# Module-level client (lazy initialization)
_client: Optional[LLMClient] = None


def _get_client() -> LLMClient:
    """Get or create the module-level client."""
    global _client
    if _client is None:
        _client = LLMClient(provider="google")
    return _client


# Module-level builder for chaining
_builder = create_module_builder(
    provider="google",
    default_temperature=DEFAULT_TEMPERATURE,
    default_max_tokens=DEFAULT_MAX_TOKENS,
    client_factory=_get_client,
)


# Builder pattern methods (delegate to _builder)
def with_model(model: str) -> ChatBuilder:
    """Set the model to use. Returns a new ChatBuilder for chaining."""
    return _builder.with_model(model)


def with_system(prompt: str) -> ChatBuilder:
    """Set the system prompt. Returns a new ChatBuilder for chaining."""
    return _builder.with_system(prompt)


def with_developer(instructions: str) -> ChatBuilder:
    """Set developer instructions. Returns a new ChatBuilder for chaining."""
    return _builder.with_developer(instructions)


def with_temperature(temperature: float) -> ChatBuilder:
    """Set the temperature. Returns a new ChatBuilder for chaining."""
    return _builder.with_temperature(temperature)


def with_max_tokens(max_tokens: int) -> ChatBuilder:
    """Set max tokens. Returns a new ChatBuilder for chaining."""
    return _builder.with_max_tokens(max_tokens)


def with_options(**kwargs) -> ChatBuilder:
    """Set additional options. Returns a new ChatBuilder for chaining."""
    return _builder.with_options(**kwargs)


async def chat(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    stream: bool = False,
    **kwargs,
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
    """
    Send a chat completion request to Google Gemini.

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "gemini-2.5-flash", "gemini-2.5-pro"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        stream: Whether to stream the response. Default: False
        **kwargs: Additional parameters passed to the API.

    Returns:
        ChatResponse object, or AsyncIterator[ChatResponse] if streaming.

    Example:
        >>> from stratifyai.chat import google
        >>> response = await google.chat("What is Python?", model="gemini-2.5-flash")
        >>> print(response.content)
    """
    client = _get_client()

    # Build messages list
    if isinstance(prompt, str):
        messages = []
        if system:
            messages.append(Message(role="system", content=system))
        messages.append(Message(role="user", content=prompt))
    else:
        messages = prompt

    return await client.chat(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=stream,
        **kwargs,
    )


async def chat_stream(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    **kwargs,
) -> AsyncIterator[ChatResponse]:
    """
    Send a streaming chat completion request to Google Gemini.

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "gemini-2.5-flash"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        **kwargs: Additional parameters passed to the API.

    Yields:
        ChatResponse chunks.

    Example:
        >>> from stratifyai.chat import google
        >>> async for chunk in google.chat_stream("Tell me a story", model="gemini-2.5-flash"):
        ...     print(chunk.content, end="", flush=True)
    """
    return await chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )


def chat_sync(
    prompt,
    *,
    model: str,
    system=None,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS,
    **kwargs,
):
    """Synchronous wrapper for chat()."""
    return asyncio.run(chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=False,
        **kwargs,
    ))
stratifyai/chat/stratifyai_grok.py
@@ -0,0 +1,194 @@
"""Grok (X.AI) chat interface for StratifyAI.

Provides convenient functions for Grok chat completions.
Model must be specified for each request.

Environment Variable: GROK_API_KEY

Usage:
    # Model is always required
    from stratifyai.chat import grok
    response = await grok.chat("Hello!", model="grok-beta")

    # Builder pattern (model required)
    client = (
        grok
        .with_model("grok-2")
        .with_system("You are a helpful assistant")
        .with_developer("Use markdown")
    )
    response = await client.chat("Hello!")
"""

import asyncio
from typing import AsyncIterator, Optional, Union

from stratifyai import LLMClient
from stratifyai.models import ChatResponse, Message
from stratifyai.chat.builder import ChatBuilder, create_module_builder

# Default configuration (no default model - must be specified)
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = None

# Module-level client (lazy initialization)
_client: Optional[LLMClient] = None


def _get_client() -> LLMClient:
    """Get or create the module-level client."""
    global _client
    if _client is None:
        _client = LLMClient(provider="grok")
    return _client


# Module-level builder for chaining
_builder = create_module_builder(
    provider="grok",
    default_temperature=DEFAULT_TEMPERATURE,
    default_max_tokens=DEFAULT_MAX_TOKENS,
    client_factory=_get_client,
)


# Builder pattern methods (delegate to _builder)
def with_model(model: str) -> ChatBuilder:
    """Set the model to use. Returns a new ChatBuilder for chaining."""
    return _builder.with_model(model)


def with_system(prompt: str) -> ChatBuilder:
    """Set the system prompt. Returns a new ChatBuilder for chaining."""
    return _builder.with_system(prompt)


def with_developer(instructions: str) -> ChatBuilder:
    """Set developer instructions. Returns a new ChatBuilder for chaining."""
    return _builder.with_developer(instructions)


def with_temperature(temperature: float) -> ChatBuilder:
    """Set the temperature. Returns a new ChatBuilder for chaining."""
    return _builder.with_temperature(temperature)


def with_max_tokens(max_tokens: int) -> ChatBuilder:
    """Set max tokens. Returns a new ChatBuilder for chaining."""
    return _builder.with_max_tokens(max_tokens)


def with_options(**kwargs) -> ChatBuilder:
    """Set additional options. Returns a new ChatBuilder for chaining."""
    return _builder.with_options(**kwargs)


async def chat(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    stream: bool = False,
    **kwargs,
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
    """
    Send a chat completion request to Grok (X.AI).

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "grok-beta", "grok-2"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        stream: Whether to stream the response. Default: False
        **kwargs: Additional parameters passed to the API.

    Returns:
        ChatResponse object, or AsyncIterator[ChatResponse] if streaming.

    Example:
        >>> from stratifyai.chat import grok
        >>> response = await grok.chat("What is Python?", model="grok-beta")
        >>> print(response.content)
    """
    client = _get_client()

    # Build messages list
    if isinstance(prompt, str):
        messages = []
        if system:
            messages.append(Message(role="system", content=system))
        messages.append(Message(role="user", content=prompt))
    else:
        messages = prompt

    return await client.chat(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=stream,
        **kwargs,
    )


async def chat_stream(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    **kwargs,
) -> AsyncIterator[ChatResponse]:
    """
    Send a streaming chat completion request to Grok (X.AI).

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "grok-beta"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        **kwargs: Additional parameters passed to the API.

    Yields:
        ChatResponse chunks.

    Example:
        >>> from stratifyai.chat import grok
        >>> async for chunk in grok.chat_stream("Tell me a story", model="grok-beta"):
        ...     print(chunk.content, end="", flush=True)
    """
    return await chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )


def chat_sync(
    prompt,
    *,
    model: str,
    system=None,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS,
    **kwargs,
):
    """Synchronous wrapper for chat()."""
    return asyncio.run(chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=False,
        **kwargs,
    ))
stratifyai/chat/stratifyai_groq.py
@@ -0,0 +1,195 @@
"""Groq chat interface for StratifyAI.

Provides convenient functions for Groq chat completions.
Groq provides ultra-fast inference for open-source models.
Model must be specified for each request.

Environment Variable: GROQ_API_KEY

Usage:
    # Model is always required
    from stratifyai.chat import groq
    response = await groq.chat("Hello!", model="llama-3.3-70b-versatile")

    # Builder pattern (model required)
    client = (
        groq
        .with_model("mixtral-8x7b-32768")
        .with_system("You are a helpful assistant")
        .with_developer("Use markdown")
    )
    response = await client.chat("Hello!")
"""

import asyncio
from typing import AsyncIterator, Optional, Union

from stratifyai import LLMClient
from stratifyai.models import ChatResponse, Message
from stratifyai.chat.builder import ChatBuilder, create_module_builder

# Default configuration (no default model - must be specified)
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = None

# Module-level client (lazy initialization)
_client: Optional[LLMClient] = None


def _get_client() -> LLMClient:
    """Get or create the module-level client."""
    global _client
    if _client is None:
        _client = LLMClient(provider="groq")
    return _client


# Module-level builder for chaining
_builder = create_module_builder(
    provider="groq",
    default_temperature=DEFAULT_TEMPERATURE,
    default_max_tokens=DEFAULT_MAX_TOKENS,
    client_factory=_get_client,
)


# Builder pattern methods (delegate to _builder)
def with_model(model: str) -> ChatBuilder:
    """Set the model to use. Returns a new ChatBuilder for chaining."""
    return _builder.with_model(model)


def with_system(prompt: str) -> ChatBuilder:
    """Set the system prompt. Returns a new ChatBuilder for chaining."""
    return _builder.with_system(prompt)


def with_developer(instructions: str) -> ChatBuilder:
    """Set developer instructions. Returns a new ChatBuilder for chaining."""
    return _builder.with_developer(instructions)


def with_temperature(temperature: float) -> ChatBuilder:
    """Set the temperature. Returns a new ChatBuilder for chaining."""
    return _builder.with_temperature(temperature)


def with_max_tokens(max_tokens: int) -> ChatBuilder:
    """Set max tokens. Returns a new ChatBuilder for chaining."""
    return _builder.with_max_tokens(max_tokens)


def with_options(**kwargs) -> ChatBuilder:
    """Set additional options. Returns a new ChatBuilder for chaining."""
    return _builder.with_options(**kwargs)


async def chat(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    stream: bool = False,
    **kwargs,
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
    """
    Send a chat completion request to Groq.

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "llama-3.3-70b-versatile", "mixtral-8x7b-32768"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        stream: Whether to stream the response. Default: False
        **kwargs: Additional parameters passed to the API.

    Returns:
        ChatResponse object, or AsyncIterator[ChatResponse] if streaming.

    Example:
        >>> from stratifyai.chat import groq
        >>> response = await groq.chat("What is Python?", model="llama-3.3-70b-versatile")
        >>> print(response.content)
    """
    client = _get_client()

    # Build messages list
    if isinstance(prompt, str):
        messages = []
        if system:
            messages.append(Message(role="system", content=system))
        messages.append(Message(role="user", content=prompt))
    else:
        messages = prompt

    return await client.chat(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=stream,
        **kwargs,
    )


async def chat_stream(
    prompt: Union[str, list[Message]],
    *,
    model: str,
    system: Optional[str] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    **kwargs,
) -> AsyncIterator[ChatResponse]:
    """
    Send a streaming chat completion request to Groq.

    Args:
        prompt: User message string or list of Message objects.
        model: Model name (required). E.g., "llama-3.3-70b-versatile"
        system: Optional system prompt (ignored if prompt is list of Messages).
        temperature: Sampling temperature (0.0-2.0). Default: 0.7
        max_tokens: Maximum tokens to generate. Default: None (model default)
        **kwargs: Additional parameters passed to the API.

    Yields:
        ChatResponse chunks.

    Example:
        >>> from stratifyai.chat import groq
        >>> async for chunk in groq.chat_stream("Tell me a story", model="llama-3.3-70b-versatile"):
        ...     print(chunk.content, end="", flush=True)
    """
    return await chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
        **kwargs,
    )


def chat_sync(
    prompt,
    *,
    model: str,
    system=None,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS,
    **kwargs,
):
    """Synchronous wrapper for chat()."""
    return asyncio.run(chat(
        prompt,
        model=model,
        system=system,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=False,
        **kwargs,
    ))
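The module-level with_*() helpers each return a ChatBuilder, so a configured client can be assembled once and, presumably, reused across requests; the reuse itself is an assumption, since the docstring only shows a single call. A sketch with the Groq module above:

import asyncio
from stratifyai.chat import groq

async def main():
    client = (
        groq
        .with_model("llama-3.3-70b-versatile")
        .with_system("You are a helpful assistant")
        .with_temperature(0.3)
    )
    # Assumed: the built client can serve more than one prompt.
    for question in ("What is Python?", "What is Groq?"):
        response = await client.chat(question)
        print(response.content)

asyncio.run(main())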