grok-cli-acp 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.env.example +42 -0
- package/.github/workflows/ci.yml +30 -0
- package/.github/workflows/rust.yml +22 -0
- package/.grok/.env.example +85 -0
- package/.grok/COMPLETE_FIX_SUMMARY.md +466 -0
- package/.grok/ENV_CONFIG_GUIDE.md +173 -0
- package/.grok/QUICK_REFERENCE.md +180 -0
- package/.grok/README.md +104 -0
- package/.grok/TESTING_GUIDE.md +393 -0
- package/CHANGELOG.md +465 -0
- package/CODE_REVIEW_SUMMARY.md +414 -0
- package/COMPLETE_FIX_SUMMARY.md +415 -0
- package/CONFIGURATION.md +489 -0
- package/CONTEXT_FILES_GUIDE.md +419 -0
- package/CONTRIBUTING.md +55 -0
- package/CURSOR_POSITION_FIX.md +206 -0
- package/Cargo.toml +88 -0
- package/ERROR_HANDLING_REPORT.md +361 -0
- package/FINAL_FIX_SUMMARY.md +462 -0
- package/FIXES.md +37 -0
- package/FIXES_SUMMARY.md +87 -0
- package/GROK_API_MIGRATION_SUMMARY.md +111 -0
- package/LICENSE +22 -0
- package/MIGRATION_TO_GROK_API.md +223 -0
- package/README.md +504 -0
- package/REVIEW_COMPLETE.md +416 -0
- package/REVIEW_QUICK_REFERENCE.md +173 -0
- package/SECURITY.md +463 -0
- package/SECURITY_AUDIT.md +661 -0
- package/SETUP.md +287 -0
- package/TESTING_TOOLS.md +88 -0
- package/TESTING_TOOL_EXECUTION.md +239 -0
- package/TOOL_EXECUTION_FIX.md +491 -0
- package/VERIFICATION_CHECKLIST.md +419 -0
- package/docs/API.md +74 -0
- package/docs/CHAT_LOGGING.md +39 -0
- package/docs/CURSOR_FIX_DEMO.md +306 -0
- package/docs/ERROR_HANDLING_GUIDE.md +547 -0
- package/docs/FILE_OPERATIONS.md +449 -0
- package/docs/INTERACTIVE.md +401 -0
- package/docs/PROJECT_CREATION_GUIDE.md +570 -0
- package/docs/QUICKSTART.md +378 -0
- package/docs/QUICK_REFERENCE.md +691 -0
- package/docs/RELEASE_NOTES_0.1.2.md +240 -0
- package/docs/TOOLS.md +459 -0
- package/docs/TOOLS_QUICK_REFERENCE.md +210 -0
- package/docs/ZED_INTEGRATION.md +371 -0
- package/docs/extensions.md +464 -0
- package/docs/settings.md +293 -0
- package/examples/extensions/logging-hook/README.md +91 -0
- package/examples/extensions/logging-hook/extension.json +22 -0
- package/package.json +30 -0
- package/scripts/test_acp.py +252 -0
- package/scripts/test_acp.sh +143 -0
- package/scripts/test_acp_simple.sh +72 -0
- package/src/acp/mod.rs +741 -0
- package/src/acp/protocol.rs +323 -0
- package/src/acp/security.rs +298 -0
- package/src/acp/tools.rs +697 -0
- package/src/bin/banner_demo.rs +216 -0
- package/src/bin/docgen.rs +18 -0
- package/src/bin/installer.rs +217 -0
- package/src/cli/app.rs +310 -0
- package/src/cli/commands/acp.rs +721 -0
- package/src/cli/commands/chat.rs +485 -0
- package/src/cli/commands/code.rs +513 -0
- package/src/cli/commands/config.rs +394 -0
- package/src/cli/commands/health.rs +442 -0
- package/src/cli/commands/history.rs +421 -0
- package/src/cli/commands/mod.rs +14 -0
- package/src/cli/commands/settings.rs +1384 -0
- package/src/cli/mod.rs +166 -0
- package/src/config/mod.rs +2212 -0
- package/src/display/ascii_art.rs +139 -0
- package/src/display/banner.rs +289 -0
- package/src/display/components/input.rs +323 -0
- package/src/display/components/mod.rs +2 -0
- package/src/display/components/settings_list.rs +306 -0
- package/src/display/interactive.rs +1255 -0
- package/src/display/mod.rs +62 -0
- package/src/display/terminal.rs +42 -0
- package/src/display/tips.rs +316 -0
- package/src/grok_client_ext.rs +177 -0
- package/src/hooks/loader.rs +407 -0
- package/src/hooks/mod.rs +158 -0
- package/src/lib.rs +174 -0
- package/src/main.rs +65 -0
- package/src/mcp/client.rs +195 -0
- package/src/mcp/config.rs +20 -0
- package/src/mcp/mod.rs +6 -0
- package/src/mcp/protocol.rs +67 -0
- package/src/utils/auth.rs +41 -0
- package/src/utils/chat_logger.rs +568 -0
- package/src/utils/context.rs +390 -0
- package/src/utils/mod.rs +16 -0
- package/src/utils/network.rs +320 -0
- package/src/utils/rate_limiter.rs +166 -0
- package/src/utils/session.rs +73 -0
- package/src/utils/shell_permissions.rs +389 -0
- package/src/utils/telemetry.rs +41 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
//! Network utilities for detecting and handling network issues
|
|
2
|
+
//!
|
|
3
|
+
//! This module provides utilities specifically designed for handling network
|
|
4
|
+
//! instability common with satellite internet connections like Starlink,
|
|
5
|
+
//! including connection drops, timeouts, and recovery strategies.
|
|
6
|
+
|
|
7
|
+
use anyhow::{anyhow, Error};
|
|
8
|
+
use std::time::{Duration, Instant};
|
|
9
|
+
use tracing::{debug, info, warn};
|
|
10
|
+
|
|
11
|
+
/// Lower-case substrings that indicate Starlink or other satellite network issues.
///
/// Callers lower-case the error text before searching it, so every entry here
/// must itself be lower-case.
const STARLINK_ERROR_PATTERNS: &[&str] = &[
    "connection reset",
    "connection dropped",
    "network unreachable",
    "no route to host",
    "broken pipe",
    "connection refused",
    "timeout",
    // `std::io::ErrorKind::TimedOut` displays as "timed out" and the OS reports
    // ETIMEDOUT as "Connection timed out"; neither contains "timeout".
    "timed out",
    "dns resolution failed",
    "temporary failure in name resolution",
    "network is down",
    "host is unreachable",
];
|
|
25
|
+
|
|
26
|
+
/// HTTP status codes that commonly occur during satellite network issues.
///
/// The standard gateway/availability errors come first, followed by
/// Cloudflare's origin-connectivity codes.
const SATELLITE_HTTP_ERRORS: &[u16] = &[
    // 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout
    502, 503, 504,
    // Cloudflare: 520 Unknown Error, 521 Web Server Is Down,
    // 522 Connection Timed Out, 523 Origin Is Unreachable, 524 A Timeout Occurred
    520, 521, 522, 523, 524,
];
|
|
37
|
+
|
|
38
|
+
/// Network drop detection result
|
|
39
|
+
#[derive(Debug, Clone)]
|
|
40
|
+
pub struct NetworkDropInfo {
|
|
41
|
+
pub is_drop: bool,
|
|
42
|
+
pub confidence: DropConfidence,
|
|
43
|
+
pub suggested_action: SuggestedAction,
|
|
44
|
+
pub retry_delay: Duration,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
/// Confidence level in network drop detection.
///
/// Variants are declared from least to most confident. The enum is fieldless,
/// so it is cheap to copy and can be compared, hashed, and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DropConfidence {
    /// Weak or no evidence of a drop.
    Low,
    /// Probable transient failure (e.g. timeouts, refused connections, gateway errors).
    Medium,
    /// Strong evidence the connection dropped (e.g. reset, broken pipe).
    High,
}
|
|
54
|
+
|
|
55
|
+
/// Suggested action to take when a network drop is detected.
///
/// Every payload is a `Duration`, so the enum is `Copy` and can be compared
/// directly (e.g. in tests or when deduplicating advice).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SuggestedAction {
    /// Retry the request.
    Retry,
    /// Retry with increasing delays between attempts.
    RetryWithBackoff,
    /// Wait for the given duration, then retry.
    WaitAndRetry(Duration),
    /// Check the connection before retrying.
    CheckConnection,
    /// Stop retrying.
    Abort,
}
|
|
64
|
+
|
|
65
|
+
/// Detect if an error indicates a network drop, particularly from Starlink
|
|
66
|
+
pub fn detect_network_drop(error: &Error) -> bool {
|
|
67
|
+
let error_string = error.to_string().to_lowercase();
|
|
68
|
+
|
|
69
|
+
// Check for direct error patterns
|
|
70
|
+
for pattern in STARLINK_ERROR_PATTERNS {
|
|
71
|
+
if error_string.contains(pattern) {
|
|
72
|
+
debug!("Network drop detected: pattern '{}' found", pattern);
|
|
73
|
+
return true;
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Check for HTTP status codes
|
|
78
|
+
for &status in SATELLITE_HTTP_ERRORS {
|
|
79
|
+
if error_string.contains(&status.to_string()) {
|
|
80
|
+
debug!("Network drop detected: HTTP status {} found", status);
|
|
81
|
+
return true;
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
// Check for reqwest-specific timeout errors
|
|
86
|
+
if error_string.contains("reqwest") && error_string.contains("timeout") {
|
|
87
|
+
debug!("Network drop detected: reqwest timeout");
|
|
88
|
+
return true;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
false
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/// Analyze network error and provide detailed information
|
|
95
|
+
pub fn analyze_network_error(error: &Error) -> NetworkDropInfo {
|
|
96
|
+
let error_string = error.to_string().to_lowercase();
|
|
97
|
+
let mut confidence = DropConfidence::Low;
|
|
98
|
+
let mut suggested_action = SuggestedAction::Retry;
|
|
99
|
+
let mut retry_delay = Duration::from_secs(1);
|
|
100
|
+
|
|
101
|
+
// High confidence indicators
|
|
102
|
+
if error_string.contains("connection reset")
|
|
103
|
+
|| error_string.contains("broken pipe")
|
|
104
|
+
|| error_string.contains("network unreachable")
|
|
105
|
+
{
|
|
106
|
+
confidence = DropConfidence::High;
|
|
107
|
+
suggested_action = SuggestedAction::WaitAndRetry(Duration::from_secs(5));
|
|
108
|
+
retry_delay = Duration::from_secs(5);
|
|
109
|
+
}
|
|
110
|
+
// Medium confidence indicators
|
|
111
|
+
else if error_string.contains("timeout")
|
|
112
|
+
|| error_string.contains("connection refused")
|
|
113
|
+
|| SATELLITE_HTTP_ERRORS
|
|
114
|
+
.iter()
|
|
115
|
+
.any(|&status| error_string.contains(&status.to_string()))
|
|
116
|
+
{
|
|
117
|
+
confidence = DropConfidence::Medium;
|
|
118
|
+
suggested_action = SuggestedAction::RetryWithBackoff;
|
|
119
|
+
retry_delay = Duration::from_secs(2);
|
|
120
|
+
}
|
|
121
|
+
// Low confidence - generic network errors
|
|
122
|
+
else if error_string.contains("network") || error_string.contains("dns") {
|
|
123
|
+
confidence = DropConfidence::Low;
|
|
124
|
+
suggested_action = SuggestedAction::Retry;
|
|
125
|
+
retry_delay = Duration::from_secs(1);
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
let is_drop = confidence != DropConfidence::Low || detect_network_drop(error);
|
|
129
|
+
|
|
130
|
+
NetworkDropInfo {
|
|
131
|
+
is_drop,
|
|
132
|
+
confidence,
|
|
133
|
+
suggested_action,
|
|
134
|
+
retry_delay,
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
/// Check if we're likely on a Starlink connection
|
|
139
|
+
pub async fn detect_starlink_connection() -> bool {
|
|
140
|
+
// Try to resolve Starlink-specific domains or check for satellite-specific patterns
|
|
141
|
+
// This is a heuristic approach
|
|
142
|
+
|
|
143
|
+
// Check if we can resolve starlink.com (indicates possible Starlink connection)
|
|
144
|
+
if let Ok(addrs) = tokio::net::lookup_host("starlink.com:80").await
|
|
145
|
+
&& addrs.count() > 0 {
|
|
146
|
+
info!("Starlink domain resolution successful - possible Starlink connection");
|
|
147
|
+
return true;
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
// Additional heuristics could be added here:
|
|
151
|
+
// - Check for specific IP ranges
|
|
152
|
+
// - Analyze latency patterns
|
|
153
|
+
// - Check for satellite-specific network characteristics
|
|
154
|
+
|
|
155
|
+
false
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
/// Perform a network connectivity test
|
|
159
|
+
pub async fn test_connectivity(timeout: Duration) -> Result<Duration, Error> {
|
|
160
|
+
let start = Instant::now();
|
|
161
|
+
|
|
162
|
+
// Test connectivity to multiple reliable endpoints
|
|
163
|
+
let test_hosts = vec!["google.com:80", "cloudflare.com:80", "github.com:80"];
|
|
164
|
+
|
|
165
|
+
for host in test_hosts {
|
|
166
|
+
match tokio::time::timeout(timeout, tokio::net::TcpStream::connect(host)).await {
|
|
167
|
+
Ok(Ok(_stream)) => {
|
|
168
|
+
let elapsed = start.elapsed();
|
|
169
|
+
info!("Connectivity test successful to {} in {:?}", host, elapsed);
|
|
170
|
+
return Ok(elapsed);
|
|
171
|
+
}
|
|
172
|
+
Ok(Err(e)) => {
|
|
173
|
+
warn!("Failed to connect to {}: {}", host, e);
|
|
174
|
+
continue;
|
|
175
|
+
}
|
|
176
|
+
Err(_) => {
|
|
177
|
+
warn!("Timeout connecting to {}", host);
|
|
178
|
+
continue;
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
Err(anyhow!("All connectivity tests failed"))
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
/// Calculate optimal retry delay based on network conditions
|
|
187
|
+
pub fn calculate_retry_delay(attempt: u32, is_starlink: bool) -> Duration {
|
|
188
|
+
let base_delay = if is_starlink {
|
|
189
|
+
// Longer delays for satellite connections
|
|
190
|
+
Duration::from_secs(2_u64.pow(attempt.min(4)))
|
|
191
|
+
} else {
|
|
192
|
+
// Standard exponential backoff
|
|
193
|
+
Duration::from_secs(2_u64.pow(attempt.min(3)))
|
|
194
|
+
};
|
|
195
|
+
|
|
196
|
+
// Add jitter to prevent thundering herd
|
|
197
|
+
let jitter = Duration::from_millis(rand::random::<u64>() % 1000);
|
|
198
|
+
base_delay + jitter
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/// Network health monitor for continuous connection quality assessment.
///
/// Records the outcome of each request and derives a health score in
/// `[0.0, 1.0]`. The score is the overall success rate, reduced by a penalty
/// for the current run of consecutive failures.
#[derive(Debug, Clone)]
pub struct NetworkHealthMonitor {
    /// Failures since the last success (or since creation / reset).
    consecutive_failures: u32,
    /// When the most recent success was recorded, if any.
    last_success: Option<Instant>,
    /// Every recorded request, successful or not.
    total_requests: u64,
    /// Recorded requests that failed; never exceeds `total_requests`.
    failed_requests: u64,
}

impl NetworkHealthMonitor {
    /// Creates a monitor with no recorded requests.
    pub const fn new() -> Self {
        Self {
            consecutive_failures: 0,
            last_success: None,
            total_requests: 0,
            failed_requests: 0,
        }
    }

    /// Records a successful request, ending any consecutive-failure streak.
    pub fn record_success(&mut self) {
        self.consecutive_failures = 0;
        self.last_success = Some(Instant::now());
        self.total_requests += 1;
    }

    /// Records a failed request.
    pub fn record_failure(&mut self) {
        self.consecutive_failures += 1;
        self.total_requests += 1;
        self.failed_requests += 1;
    }

    /// When the most recent success was recorded.
    ///
    /// Returns `None` if no success has been recorded since creation or the
    /// last [`reset`](Self::reset).
    pub fn last_success(&self) -> Option<Instant> {
        self.last_success
    }

    /// Returns a health score in `[0.0, 1.0]`; `1.0` when nothing has been recorded.
    ///
    /// The score is the success rate minus 0.1 per consecutive failure. The
    /// penalty is capped at 0.5 and the result is clamped at 0.0.
    #[must_use]
    pub fn health_score(&self) -> f64 {
        if self.total_requests == 0 {
            return 1.0;
        }

        let success_rate =
            (self.total_requests - self.failed_requests) as f64 / self.total_requests as f64;

        // Penalize consecutive failures so a fresh outage shows up quickly even
        // after a long healthy history.
        let consecutive_penalty = (self.consecutive_failures as f64 * 0.1).min(0.5);

        (success_rate - consecutive_penalty).max(0.0)
    }

    /// Whether longer timeouts are advisable.
    ///
    /// True after three or more consecutive failures, or when the health score
    /// falls below 0.5.
    pub fn should_increase_timeout(&self) -> bool {
        self.consecutive_failures >= 3 || self.health_score() < 0.5
    }

    /// Clears all recorded state, as if newly created.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for NetworkHealthMonitor {
    fn default() -> Self {
        Self::new()
    }
}
|
|
262
|
+
|
|
263
|
+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_network_drop() {
        for message in [
            "Connection reset by peer",
            "Network unreachable",
            "HTTP 502 Bad Gateway",
        ] {
            assert!(detect_network_drop(&anyhow!(message)), "{message}");
        }

        for message in ["Invalid API key", "JSON parsing error"] {
            assert!(!detect_network_drop(&anyhow!(message)), "{message}");
        }
    }

    #[test]
    fn test_analyze_network_error() {
        let reset = analyze_network_error(&anyhow!("Connection reset by peer"));
        assert!(reset.is_drop);
        assert_eq!(reset.confidence, DropConfidence::High);

        let timeout = analyze_network_error(&anyhow!("Request timeout"));
        assert!(timeout.is_drop);
        assert_eq!(timeout.confidence, DropConfidence::Medium);
    }

    #[test]
    fn test_calculate_retry_delay() {
        let first = calculate_retry_delay(1, false);
        let second = calculate_retry_delay(2, false);
        assert!(second >= first);

        // Jitter makes direct Starlink-vs-regular comparisons unreliable, so
        // only check that both stay above a sensible floor.
        let floor = Duration::from_secs(1);
        assert!(calculate_retry_delay(1, true) >= floor);
        assert!(calculate_retry_delay(1, false) >= floor);
    }

    #[test]
    fn test_network_health_monitor() {
        let mut monitor = NetworkHealthMonitor::new();
        assert_eq!(monitor.health_score(), 1.0);

        monitor.record_success();
        assert_eq!(monitor.health_score(), 1.0);

        monitor.record_failure();
        let score = monitor.health_score();
        assert!(score < 1.0 && score > 0.0);

        // Three consecutive failures in total
        for _ in 0..2 {
            monitor.record_failure();
        }
        assert!(monitor.should_increase_timeout());
    }
}
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
use crate::config::RateLimitConfig;
|
|
2
|
+
use anyhow::{Result, anyhow};
|
|
3
|
+
use serde::{Deserialize, Serialize};
|
|
4
|
+
use std::fs;
|
|
5
|
+
use std::path::PathBuf;
|
|
6
|
+
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
7
|
+
use tracing::warn;
|
|
8
|
+
|
|
9
|
+
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
|
|
10
|
+
pub struct UsageStats {
|
|
11
|
+
pub total_input_tokens: u64,
|
|
12
|
+
pub total_output_tokens: u64,
|
|
13
|
+
pub request_count: u64,
|
|
14
|
+
pub last_request_time: Option<u64>, // Unix timestamp in seconds
|
|
15
|
+
|
|
16
|
+
// We store timestamps as u64 (Unix timestamp) for serialization
|
|
17
|
+
pub request_history: Vec<(u64, u32)>, // (Timestamp, TokenCount)
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
impl UsageStats {
|
|
21
|
+
pub fn new() -> Self {
|
|
22
|
+
Self::default()
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
/// Load usage stats from disk
|
|
26
|
+
pub fn load() -> Result<Self> {
|
|
27
|
+
let path = get_usage_stats_path()?;
|
|
28
|
+
if path.exists() {
|
|
29
|
+
let json = fs::read_to_string(&path)?;
|
|
30
|
+
let stats: UsageStats = serde_json::from_str(&json)?;
|
|
31
|
+
Ok(stats)
|
|
32
|
+
} else {
|
|
33
|
+
Ok(Self::default())
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
/// Save usage stats to disk
|
|
38
|
+
pub fn save(&self) -> Result<()> {
|
|
39
|
+
let path = get_usage_stats_path()?;
|
|
40
|
+
if let Some(parent) = path.parent() {
|
|
41
|
+
fs::create_dir_all(parent)?;
|
|
42
|
+
}
|
|
43
|
+
let json = serde_json::to_string_pretty(self)?;
|
|
44
|
+
fs::write(path, json)?;
|
|
45
|
+
Ok(())
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
/// Checks if the next request of estimated `tokens` size is allowed
|
|
49
|
+
pub fn check_limit(
|
|
50
|
+
&mut self,
|
|
51
|
+
config: &RateLimitConfig,
|
|
52
|
+
estimated_tokens: u32,
|
|
53
|
+
) -> Result<(), String> {
|
|
54
|
+
self.clean_old_history(Duration::from_secs(60));
|
|
55
|
+
|
|
56
|
+
let current_tokens: u32 = self.request_history.iter().map(|(_, tokens)| *tokens).sum();
|
|
57
|
+
let current_requests = self.request_history.len() as u32;
|
|
58
|
+
|
|
59
|
+
if current_requests >= config.max_requests_per_minute {
|
|
60
|
+
return Err("Rate limit exceeded: Requests per minute".to_string());
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
if current_tokens + estimated_tokens > config.max_tokens_per_minute {
|
|
64
|
+
return Err("Rate limit exceeded: Tokens per minute".to_string());
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
Ok(())
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
/// Call this AFTER a successful API call to record actual usage
|
|
71
|
+
pub fn record_usage(&mut self, input_tokens: u32, output_tokens: u32) {
|
|
72
|
+
let now = SystemTime::now()
|
|
73
|
+
.duration_since(UNIX_EPOCH)
|
|
74
|
+
.unwrap()
|
|
75
|
+
.as_secs();
|
|
76
|
+
let total = input_tokens + output_tokens;
|
|
77
|
+
|
|
78
|
+
self.total_input_tokens += input_tokens as u64;
|
|
79
|
+
self.total_output_tokens += output_tokens as u64;
|
|
80
|
+
self.request_count += 1;
|
|
81
|
+
self.last_request_time = Some(now);
|
|
82
|
+
self.request_history.push((now, total));
|
|
83
|
+
|
|
84
|
+
// Auto-save after update
|
|
85
|
+
if let Err(e) = self.save() {
|
|
86
|
+
warn!("Failed to save usage stats: {}. Stats will not persist.", e);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
fn clean_old_history(&mut self, window: Duration) {
|
|
91
|
+
let now = SystemTime::now()
|
|
92
|
+
.duration_since(UNIX_EPOCH)
|
|
93
|
+
.unwrap()
|
|
94
|
+
.as_secs();
|
|
95
|
+
let window_secs = window.as_secs();
|
|
96
|
+
|
|
97
|
+
self.request_history
|
|
98
|
+
.retain(|(time, _)| now.saturating_sub(*time) < window_secs);
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
fn get_usage_stats_path() -> Result<PathBuf> {
|
|
103
|
+
let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("Could not determine home directory"))?;
|
|
104
|
+
Ok(home_dir.join(".grok").join("usage_stats.json"))
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in seconds, for building history entries.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Simulates a completed request of `tokens` tokens in memory only.
    ///
    /// Deliberately avoids `record_usage`, which would persist to (and clobber)
    /// the developer's real `~/.grok/usage_stats.json`.
    fn push_request(stats: &mut UsageStats, tokens: u32) {
        stats.request_history.push((now_secs(), tokens));
    }

    #[test]
    fn test_clean_old_history() {
        let mut stats = UsageStats::default();
        let now = now_secs();

        // Add an old record (61 seconds ago)
        stats.request_history.push((now - 61, 100));
        // Add a recent record (10 seconds ago)
        stats.request_history.push((now - 10, 50));

        stats.clean_old_history(Duration::from_secs(60));

        assert_eq!(stats.request_history.len(), 1);
        assert_eq!(stats.request_history[0].1, 50);
    }

    #[test]
    fn test_check_limit_requests() {
        let config = RateLimitConfig {
            max_requests_per_minute: 2,
            max_tokens_per_minute: 1000,
        };
        let mut stats = UsageStats::default();

        assert!(stats.check_limit(&config, 10).is_ok());
        push_request(&mut stats, 10); // 1st request

        assert!(stats.check_limit(&config, 10).is_ok());
        push_request(&mut stats, 10); // 2nd request

        // 3rd request should fail
        assert!(stats.check_limit(&config, 10).is_err());
    }

    #[test]
    fn test_check_limit_tokens() {
        let config = RateLimitConfig {
            max_requests_per_minute: 10,
            max_tokens_per_minute: 100,
        };
        let mut stats = UsageStats::default();

        assert!(stats.check_limit(&config, 50).is_ok());
        push_request(&mut stats, 50);

        assert!(stats.check_limit(&config, 50).is_ok());
        push_request(&mut stats, 50);

        // 101st token should fail
        assert!(stats.check_limit(&config, 1).is_err());
    }
}
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
use crate::display::interactive::InteractiveSession;
|
|
2
|
+
use anyhow::{anyhow, Result};
|
|
3
|
+
use std::fs;
|
|
4
|
+
use std::path::{Path, PathBuf};
|
|
5
|
+
|
|
6
|
+
/// Get the sessions directory path
|
|
7
|
+
fn get_sessions_dir() -> Result<PathBuf> {
|
|
8
|
+
let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("Could not determine home directory"))?;
|
|
9
|
+
Ok(home_dir.join(".grok").join("sessions"))
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
/// Save a session to disk
|
|
13
|
+
pub fn save_session(session: &InteractiveSession, name: &str) -> Result<PathBuf> {
|
|
14
|
+
let sessions_dir = get_sessions_dir()?;
|
|
15
|
+
if !sessions_dir.exists() {
|
|
16
|
+
fs::create_dir_all(&sessions_dir)?;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
let file_path = sessions_dir.join(format!("{}.json", name));
|
|
20
|
+
let json = serde_json::to_string_pretty(session)?;
|
|
21
|
+
|
|
22
|
+
fs::write(&file_path, json)?;
|
|
23
|
+
Ok(file_path)
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
/// Load a session from disk
|
|
27
|
+
pub fn load_session(name: &str) -> Result<InteractiveSession> {
|
|
28
|
+
let sessions_dir = get_sessions_dir()?;
|
|
29
|
+
let file_path = sessions_dir.join(format!("{}.json", name));
|
|
30
|
+
|
|
31
|
+
if !file_path.exists() {
|
|
32
|
+
return Err(anyhow!("Session '{}' not found", name));
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
let json = fs::read_to_string(&file_path)?;
|
|
36
|
+
let session: InteractiveSession = serde_json::from_str(&json)?;
|
|
37
|
+
Ok(session)
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
/// List all saved sessions
|
|
41
|
+
pub fn list_sessions() -> Result<Vec<String>> {
|
|
42
|
+
let sessions_dir = get_sessions_dir()?;
|
|
43
|
+
if !sessions_dir.exists() {
|
|
44
|
+
return Ok(Vec::new());
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
let mut sessions = Vec::new();
|
|
48
|
+
for entry in fs::read_dir(sessions_dir)? {
|
|
49
|
+
let entry = entry?;
|
|
50
|
+
let path = entry.path();
|
|
51
|
+
if path.extension().and_then(|s| s.to_str()) == Some("json")
|
|
52
|
+
&& let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
|
53
|
+
sessions.push(stem.to_string());
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
sessions.sort();
|
|
58
|
+
Ok(sessions)
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sessions_dir_location() {
        // Only computes the path; nothing is read from or written to disk.
        let dir = get_sessions_dir().expect("home directory should be resolvable in tests");
        assert!(dir.ends_with(Path::new(".grok").join("sessions")));
    }

    // A real round-trip test needs the sessions directory to be injectable
    // (e.g. a variant taking the base directory) so it can run against a temp
    // dir instead of the real `~/.grok/sessions`. Until then, this test is
    // ignored rather than passing while testing nothing.
    #[test]
    #[ignore = "needs an injectable sessions directory to avoid touching ~/.grok/sessions"]
    fn test_save_and_load_session() {}
}
|