zarz 0.3.1-alpha → 0.3.5-alpha

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,197 +0,0 @@
1
- use anyhow::{anyhow, Context, Result};
2
- use bytes::Bytes;
3
- use futures::stream::StreamExt;
4
- use reqwest::Client;
5
- use serde::Deserialize;
6
- use serde_json::json;
7
-
8
- use super::{CompletionRequest, CompletionResponse, CompletionStream};
9
-
10
/// Default chat-completions endpoint, used when neither an override nor
/// `OPENAI_API_URL` is configured.
const DEFAULT_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
11
-
12
/// Minimal OpenAI chat-completions client used by this crate.
pub struct OpenAiClient {
    // Reusable HTTP client (connection pool + configured request timeout).
    http: Client,
    // Full chat-completions URL; defaults to `DEFAULT_ENDPOINT`.
    endpoint: String,
    // Bearer token sent with every request.
    api_key: String,
}
17
-
18
- impl OpenAiClient {
19
- pub fn from_env(
20
- api_key_override: Option<String>,
21
- endpoint_override: Option<String>,
22
- timeout_override: Option<u64>,
23
- ) -> Result<Self> {
24
- let api_key = api_key_override
25
- .or_else(|| std::env::var("OPENAI_API_KEY").ok())
26
- .ok_or_else(|| anyhow::anyhow!("OPENAI_API_KEY is required. Please set it in ~/.zarz/config.toml or as an environment variable"))?;
27
- let endpoint = endpoint_override
28
- .or_else(|| std::env::var("OPENAI_API_URL").ok())
29
- .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());
30
-
31
- let timeout_secs = timeout_override
32
- .or_else(|| {
33
- std::env::var("OPENAI_TIMEOUT_SECS")
34
- .ok()
35
- .and_then(|raw| raw.parse::<u64>().ok())
36
- })
37
- .unwrap_or(120);
38
-
39
- let http = Client::builder()
40
- .user_agent("zarz-cli/0.1")
41
- .timeout(std::time::Duration::from_secs(timeout_secs))
42
- .build()
43
- .context("Failed to build HTTP client for OpenAI")?;
44
-
45
- Ok(Self {
46
- http,
47
- endpoint,
48
- api_key,
49
- })
50
- }
51
-
52
- pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
53
- let mut messages = Vec::new();
54
- if let Some(system) = &request.system_prompt {
55
- messages.push(json!({
56
- "role": "system",
57
- "content": system,
58
- }));
59
- }
60
- messages.push(json!({
61
- "role": "user",
62
- "content": request.user_prompt,
63
- }));
64
-
65
- let payload = json!({
66
- "model": request.model,
67
- "max_tokens": request.max_output_tokens,
68
- "temperature": request.temperature,
69
- "messages": messages,
70
- });
71
-
72
- let response = self
73
- .http
74
- .post(&self.endpoint)
75
- .bearer_auth(&self.api_key)
76
- .json(&payload)
77
- .send()
78
- .await
79
- .context("OpenAI request failed")?;
80
-
81
- let response = response.error_for_status().context("OpenAI returned an error status")?;
82
-
83
- let parsed: OpenAiResponse = response
84
- .json()
85
- .await
86
- .context("Failed to decode OpenAI response")?;
87
-
88
- let text = parsed
89
- .choices
90
- .into_iter()
91
- .find_map(|choice| choice.message.content)
92
- .ok_or_else(|| anyhow!("OpenAI response did not include content"))?;
93
-
94
- Ok(CompletionResponse { text })
95
- }
96
-
97
- #[allow(dead_code)]
98
- pub async fn complete_stream(&self, request: &CompletionRequest) -> Result<CompletionStream> {
99
- let mut messages = Vec::new();
100
- if let Some(system) = &request.system_prompt {
101
- messages.push(json!({
102
- "role": "system",
103
- "content": system,
104
- }));
105
- }
106
- messages.push(json!({
107
- "role": "user",
108
- "content": request.user_prompt,
109
- }));
110
-
111
- let payload = json!({
112
- "model": request.model,
113
- "max_tokens": request.max_output_tokens,
114
- "temperature": request.temperature,
115
- "messages": messages,
116
- "stream": true,
117
- });
118
-
119
- let response = self
120
- .http
121
- .post(&self.endpoint)
122
- .bearer_auth(&self.api_key)
123
- .json(&payload)
124
- .send()
125
- .await
126
- .context("OpenAI streaming request failed")?;
127
-
128
- let response = response
129
- .error_for_status()
130
- .context("OpenAI returned an error status")?;
131
-
132
- let stream = response.bytes_stream();
133
- let text_stream = stream.map(|result| {
134
- let bytes = result?;
135
- parse_openai_sse_chunk(&bytes)
136
- });
137
-
138
- Ok(Box::pin(text_stream))
139
- }
140
- }
141
-
142
- #[allow(dead_code)]
143
- fn parse_openai_sse_chunk(bytes: &Bytes) -> Result<String> {
144
- let text = String::from_utf8_lossy(bytes);
145
- let mut result = String::new();
146
-
147
- for line in text.lines() {
148
- if let Some(data) = line.strip_prefix("data: ") {
149
- if data == "[DONE]" {
150
- break;
151
- }
152
-
153
- if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
154
- if let Some(choice) = chunk.choices.first() {
155
- if let Some(content) = &choice.delta.content {
156
- result.push_str(content);
157
- }
158
- }
159
- }
160
- }
161
- }
162
-
163
- Ok(result)
164
- }
165
-
166
/// One decoded SSE `data:` payload from the streaming chat-completions API.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}
171
-
172
/// A single choice within a streaming chunk; carries an incremental delta.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
}
177
-
178
/// Incremental message fragment; `content` is absent for role-only or
/// terminal deltas.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct StreamDelta {
    content: Option<String>,
}
183
-
184
/// Top-level non-streaming chat-completions response body (only the fields
/// this client reads).
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
}
188
-
189
/// One completion choice from a non-streaming response.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
}
193
-
194
/// Assistant message within a choice; `content` may be null in the API
/// response, hence `Option`.
#[derive(Debug, Deserialize)]
struct OpenAiMessage {
    content: Option<String>,
}