gradexp 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gradexp-0.1.0/.agent/workflows/run_tests.md +11 -0
- gradexp-0.1.0/.gitignore +36 -0
- gradexp-0.1.0/MANIFEST.in +1 -0
- gradexp-0.1.0/PKG-INFO +86 -0
- gradexp-0.1.0/README.md +74 -0
- gradexp-0.1.0/gradexp/__init__.py +2 -0
- gradexp-0.1.0/gradexp/auth.py +103 -0
- gradexp-0.1.0/gradexp/cli.py +21 -0
- gradexp-0.1.0/gradexp/client.py +468 -0
- gradexp-0.1.0/pyproject.toml +22 -0
- gradexp-0.1.0/schema.ts +62 -0
- gradexp-0.1.0/setup.txt +6 -0
- gradexp-0.1.0/test-scripts/init-log.py +7 -0
- gradexp-0.1.0/test-scripts/mock_verify.py +91 -0
- gradexp-0.1.0/test-scripts/with-wandb.py +20 -0
- gradexp-0.1.0/tests/test_client_status.py +62 -0
- gradexp-0.1.0/tests/test_gradexp_mock.py +93 -0
gradexp-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.so
|
|
6
|
+
.Python
|
|
7
|
+
build/
|
|
8
|
+
develop-eggs/
|
|
9
|
+
dist/
|
|
10
|
+
downloads/
|
|
11
|
+
eggs/
|
|
12
|
+
.eggs/
|
|
13
|
+
lib/
|
|
14
|
+
lib64/
|
|
15
|
+
parts/
|
|
16
|
+
sdist/
|
|
17
|
+
var/
|
|
18
|
+
wheels/
|
|
19
|
+
*.egg-info/
|
|
20
|
+
.installed.cfg
|
|
21
|
+
*.egg
|
|
22
|
+
MANIFEST
|
|
23
|
+
|
|
24
|
+
# Virtual Env
|
|
25
|
+
venv/
|
|
26
|
+
env/
|
|
27
|
+
ENV/
|
|
28
|
+
|
|
29
|
+
# Mac
|
|
30
|
+
.DS_Store
|
|
31
|
+
|
|
32
|
+
# IDEs
|
|
33
|
+
.idea/
|
|
34
|
+
.vscode/
|
|
35
|
+
|
|
36
|
+
wandb
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
include README.md
|
gradexp-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: gradexp
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Gradient Explorer Client Library
|
|
5
|
+
Author-email: Misha Obu <misha@parallel-ocean.xyz>
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Requires-Dist: appdirs
|
|
8
|
+
Requires-Dist: click
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: requests
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
|
|
13
|
+
source ./venv/bin/activate
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
This repo is intended to reimplement some key features of wandb for our own purposes. This is the pythonic frontend for our wandb-for-tensors service.
|
|
18
|
+
|
|
19
|
+
Repo features:
|
|
20
|
+
pip install gradexp
|
|
21
|
+
gradexp login -> opens webpage -> gets token and stores it in a permanent context
|
|
22
|
+
In python file:
|
|
23
|
+
```
|
|
24
|
+
import gradexp
|
|
25
|
+
gradexp.init("project name")
|
|
26
|
+
|
|
27
|
+
...
|
|
28
|
+
|
|
29
|
+
# Automatically starts tracking run id, step, etc
|
|
30
|
+
gradexp.log(TODO ... )
|
|
31
|
+
|
|
32
|
+
...
|
|
33
|
+
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
App-wide features:
|
|
37
|
+
|
|
38
|
+
1) Local staging + asynchronous streaming
|
|
39
|
+
When you call wandb.log() in your training script, the client SDK:
|
|
40
|
+
|
|
41
|
+
collects the metric/media payload locally (in memory and also writes to local run files).
|
|
42
|
+
|
|
43
|
+
hands it off to a separate streaming thread/process that runs alongside your main process.
|
|
44
|
+
|
|
45
|
+
this uploader process asynchronously batches and sends events over the network. This avoids blocking your training process.
|
|
46
|
+
|
|
47
|
+
events are queued in memory and written to disk if needed (e.g., offline mode) and then eventually synced.
|
|
48
|
+
|
|
49
|
+
if you set WANDB_MODE=offline, nothing is sent to the server until a sync is triggered.
|
|
50
|
+
This behavior effectively decouples ingestion from your training loop. It’s not “direct write to bucket” from your app — it goes through this upload pipeline first.
|
|
51
|
+
Weights & Biases Documentation
|
|
52
|
+
|
|
53
|
+
2) Upload endpoints and storage target
|
|
54
|
+
On the server side (hosted or self-managed):
|
|
55
|
+
|
|
56
|
+
incoming streaming events (metrics, history tuples) arrive via W&B’s application backend.
|
|
57
|
+
|
|
58
|
+
metadata and small event records (like per-step metrics) are stored relationally (MySQL in self-managed reference architecture).
|
|
59
|
+
|
|
60
|
+
larger blobs (logs, media files, artifacts) are written out to object storage buckets (e.g., S3 or your own BYOB bucket).
|
|
61
|
+
You don’t stream directly into the bucket in small per-log increments — the backend receives the event first and the service layer persists the data.
|
|
62
|
+
Weights & Biases Documentation
|
|
63
|
+
+1
|
|
64
|
+
|
|
65
|
+
3) How “appending” works
|
|
66
|
+
Object storage (S3/compatible) isn’t a traditional file system — you can’t literally open and append to an existing file like with a local file. Instead:
|
|
67
|
+
|
|
68
|
+
W&B will write each piece of logged data as a separate object or part of a structured object.
|
|
69
|
+
|
|
70
|
+
for metric history exports (e.g., after a run completes), W&B creates Parquet exports and pushes those into the bucket as artifacts/history files.
|
|
71
|
+
|
|
72
|
+
this is a batch write rather than incremental byte-level append.
|
|
73
|
+
The UI and service layer then stitch these pieces together logically for history views.
|
|
74
|
+
(This pattern is common across object-store backed systems — you don’t append bytes to objects at every metric call in practice.)
|
|
75
|
+
Weights & Biases Documentation
|
|
76
|
+
|
|
77
|
+
4) Back-end buffering & batching
|
|
78
|
+
Even though metrics are “live” in the UI, there’s batching:
|
|
79
|
+
|
|
80
|
+
the SDK batches small updates and flushes over HTTP.
|
|
81
|
+
|
|
82
|
+
on the server side, those are ingested through API layers and persisted into the database or stored as discrete objects.
|
|
83
|
+
|
|
84
|
+
visualization updates pull from the latest ingested state.
|
|
85
|
+
The “real-time” aspect is achieved by frequent flushes and update propagation — not by writing to a monolithic streaming file.
|
|
86
|
+
W&B’s process model means it tolerates network latency/outages by buffering locally and then syncing when available.
|
gradexp-0.1.0/README.md
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
source ./venv/bin/activate
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
This repo is intended to reimplement some key features of wandb for our own purposes. This is the pythonic frontend for our wandb-for-tensors service.
|
|
6
|
+
|
|
7
|
+
Repo features:
|
|
8
|
+
pip install gradexp
|
|
9
|
+
gradexp login -> opens webpage -> gets token and stores it in a permanent context
|
|
10
|
+
In python file:
|
|
11
|
+
```
|
|
12
|
+
import gradexp
|
|
13
|
+
gradexp.init("project name")
|
|
14
|
+
|
|
15
|
+
...
|
|
16
|
+
|
|
17
|
+
# Automatically starts tracking run id, step, etc
|
|
18
|
+
gradexp.log(TODO ... )
|
|
19
|
+
|
|
20
|
+
...
|
|
21
|
+
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
App-wide features:
|
|
25
|
+
|
|
26
|
+
1) Local staging + asynchronous streaming
|
|
27
|
+
When you call wandb.log() in your training script, the client SDK:
|
|
28
|
+
|
|
29
|
+
collects the metric/media payload locally (in memory and also writes to local run files).
|
|
30
|
+
|
|
31
|
+
hands it off to a separate streaming thread/process that runs alongside your main process.
|
|
32
|
+
|
|
33
|
+
this uploader process asynchronously batches and sends events over the network. This avoids blocking your training process.
|
|
34
|
+
|
|
35
|
+
events are queued in memory and written to disk if needed (e.g., offline mode) and then eventually synced.
|
|
36
|
+
|
|
37
|
+
if you set WANDB_MODE=offline, nothing is sent to the server until a sync is triggered.
|
|
38
|
+
This behavior effectively decouples ingestion from your training loop. It’s not “direct write to bucket” from your app — it goes through this upload pipeline first.
|
|
39
|
+
Weights & Biases Documentation
|
|
40
|
+
|
|
41
|
+
2) Upload endpoints and storage target
|
|
42
|
+
On the server side (hosted or self-managed):
|
|
43
|
+
|
|
44
|
+
incoming streaming events (metrics, history tuples) arrive via W&B’s application backend.
|
|
45
|
+
|
|
46
|
+
metadata and small event records (like per-step metrics) are stored relationally (MySQL in self-managed reference architecture).
|
|
47
|
+
|
|
48
|
+
larger blobs (logs, media files, artifacts) are written out to object storage buckets (e.g., S3 or your own BYOB bucket).
|
|
49
|
+
You don’t stream directly into the bucket in small per-log increments — the backend receives the event first and the service layer persists the data.
|
|
50
|
+
Weights & Biases Documentation
|
|
51
|
+
+1
|
|
52
|
+
|
|
53
|
+
3) How “appending” works
|
|
54
|
+
Object storage (S3/compatible) isn’t a traditional file system — you can’t literally open and append to an existing file like with a local file. Instead:
|
|
55
|
+
|
|
56
|
+
W&B will write each piece of logged data as a separate object or part of a structured object.
|
|
57
|
+
|
|
58
|
+
for metric history exports (e.g., after a run completes), W&B creates Parquet exports and pushes those into the bucket as artifacts/history files.
|
|
59
|
+
|
|
60
|
+
this is a batch write rather than incremental byte-level append.
|
|
61
|
+
The UI and service layer then stitch these pieces together logically for history views.
|
|
62
|
+
(This pattern is common across object-store backed systems — you don’t append bytes to objects at every metric call in practice.)
|
|
63
|
+
Weights & Biases Documentation
|
|
64
|
+
|
|
65
|
+
4) Back-end buffering & batching
|
|
66
|
+
Even though metrics are “live” in the UI, there’s batching:
|
|
67
|
+
|
|
68
|
+
the SDK batches small updates and flushes over HTTP.
|
|
69
|
+
|
|
70
|
+
on the server side, those are ingested through API layers and persisted into the database or stored as discrete objects.
|
|
71
|
+
|
|
72
|
+
visualization updates pull from the latest ingested state.
|
|
73
|
+
The “real-time” aspect is achieved by frequent flushes and update propagation — not by writing to a monolithic streaming file.
|
|
74
|
+
W&B’s process model means it tolerates network latency/outages by buffering locally and then syncing when available.
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import appdirs
|
|
3
|
+
import json
|
|
4
|
+
import webbrowser
|
|
5
|
+
import click
|
|
6
|
+
import requests
|
|
7
|
+
|
|
8
|
+
APP_NAME = "gradexp"
|
|
9
|
+
APP_AUTHOR = "GradientExplorer"
|
|
10
|
+
BASE_URL = "gradient-explorer.xyz"
|
|
11
|
+
|
|
12
|
+
def get_config_dir():
    """Return the per-user configuration directory for gradexp."""
    config_dir = appdirs.user_config_dir(APP_NAME, APP_AUTHOR)
    return config_dir
|
|
14
|
+
|
|
15
|
+
def get_token_path():
    """Return the full path of the secrets file inside the config dir."""
    return os.path.join(get_config_dir(), "secrets.json")
|
|
18
|
+
|
|
19
|
+
def save_token(token):
    """
    Persist the API key to the user config directory.

    The secrets file is created with 0600 permissions from the start (via
    os.open) so there is no window in which other local users could read
    the key, and the directory is created with makedirs(exist_ok=True)
    to avoid a check-then-create race.
    """
    config_dir = get_config_dir()
    os.makedirs(config_dir, exist_ok=True)

    token_path = get_token_path()
    # Open with restrictive permissions atomically instead of chmod-ing
    # after the write, which briefly exposed the file with default perms.
    fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        json.dump({"api_key": token}, f)

    # Enforce 0600 even if the file already existed with looser
    # permissions (os.open's mode argument only applies on creation).
    os.chmod(token_path, 0o600)
|
|
30
|
+
|
|
31
|
+
def load_token():
    """
    Read the saved API key from disk.

    Returns the key string, or None when the secrets file is missing,
    unreadable, or not valid JSON.
    """
    token_path = get_token_path()
    if not os.path.exists(token_path):
        return None

    try:
        with open(token_path, 'r') as f:
            return json.load(f).get("api_key")
    except (json.JSONDecodeError, IOError):
        # Corrupt or unreadable secrets file is treated as "not logged in".
        return None
|
|
42
|
+
|
|
43
|
+
def validate_api_key(api_key, debug=False):
    """
    Validate an API key against the backend's /me endpoint.

    Returns the decoded user-info JSON on success, or None when the
    request fails or the backend rejects the key.

    When *debug* is true, request/response details are echoed (only the
    first and last 4 characters of the key are shown).
    """
    try:
        url = f"https://api.{BASE_URL}/api/v1/me"
        headers = {"Authorization": f"Bearer {api_key}"}

        if debug:
            click.echo(f"DEBUG: Request URL: {url}")
            click.echo(f"DEBUG: Request Headers: Authorization: Bearer {api_key[:4]}...{api_key[-4:]}")

        # A timeout keeps the CLI from hanging forever when the backend
        # is unreachable (the previous version had no timeout at all).
        # Timeout errors are RequestException subclasses, so they fall
        # into the existing handler and return None.
        response = requests.get(
            url,
            headers=headers,
            timeout=10
        )

        if debug:
            click.echo(f"DEBUG: Response Status: {response.status_code}")
            click.echo(f"DEBUG: Response Body: {response.text}")

        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        if debug:
            click.echo(f"DEBUG: Request failed: {e}")
        return None
|
|
67
|
+
|
|
68
|
+
def login_flow(debug=False):
    """
    Run the interactive login flow:

    1. Open the browser at the authorization URL.
    2. Prompt the user to paste the API key (hidden input).
    3. Validate the key against the backend.
    4. Persist the key on success.
    """
    auth_url = f"https://{BASE_URL}/authorize"
    click.echo(f"Opening {auth_url} in your default browser...")
    webbrowser.open(auth_url)

    api_key = click.prompt("Please paste your API key here", hide_input=True)
    if not api_key:
        click.echo("No API key provided. Login failed.")
        return

    # validate_api_key returns the user-info dict on success, None otherwise.
    if validate_api_key(api_key, debug=debug):
        save_token(api_key)
        click.secho("Successfully logged in. Your API key is saved.", fg="green")
    else:
        click.secho("Error: Invalid API Key", fg="red")
|
|
95
|
+
|
|
96
|
+
def check_login():
    """
    Report login state to the terminal.

    Returns True when a stored token exists, False otherwise.
    """
    token = load_token()
    if not token:
        click.echo("Not logged in. Please run `gradexp login`.")
        return False

    # Show only the first 4 characters of the token.
    click.echo(f"Logged in with token: {token[:4]}..." + "*" * 10)
    return True
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import click
|
|
2
|
+
from . import auth
|
|
3
|
+
|
|
4
|
+
@click.group()
@click.version_option()
@click.option('--debug', is_flag=True, help="Enable debug logging")
@click.pass_context
def main(ctx, debug):
    """Gradient Explorer CLI"""
    # Store the group-level --debug flag on the click context object so
    # subcommands can read it via ctx.obj.get('DEBUG').
    ctx.ensure_object(dict)
    ctx.obj['DEBUG'] = debug
|
|
12
|
+
|
|
13
|
+
@main.command()
@click.option('--debug', is_flag=True, help="Enable debug logging")
@click.pass_context
def login(ctx, debug):
    """Log in to Gradient Explorer"""
    # Debug is enabled by either the subcommand flag or the group flag.
    debug_enabled = debug or ctx.obj.get('DEBUG', False)
    auth.login_flow(debug=debug_enabled)
|
|
19
|
+
|
|
20
|
+
# Allow invoking the CLI directly (python -m gradexp.cli) in addition to
# the console-script entry point declared in pyproject.toml.
if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,468 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import signal
|
|
4
|
+
import atexit
|
|
5
|
+
import base64
|
|
6
|
+
import json
|
|
7
|
+
import uuid
|
|
8
|
+
import time
|
|
9
|
+
import threading
|
|
10
|
+
import queue
|
|
11
|
+
import shutil
|
|
12
|
+
import appdirs
|
|
13
|
+
import requests
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from . import auth
|
|
17
|
+
|
|
18
|
+
# Calm red: \033[38;5;174m
|
|
19
|
+
# Reset: \033[0m
|
|
20
|
+
GRADEXP_PREFIX = "\033[38;5;174mgradexp\033[0m: "
|
|
21
|
+
|
|
22
|
+
def _term_log(message, end="\n"):
    """
    Print *message* prefixed with the colored 'gradexp:' tag.

    A single leading newline in the message is moved in front of the
    prefix, so the blank line appears before the tag rather than inside
    the message body.
    """
    leading = ""
    body = message
    if isinstance(message, str) and message.startswith("\n"):
        leading = "\n"
        body = message[1:]
    print(f"{leading}{GRADEXP_PREFIX}{body}", end=end)
|
|
31
|
+
|
|
32
|
+
class Client:
|
|
33
|
+
def __init__(self):
|
|
34
|
+
self.api_key = None
|
|
35
|
+
self.run_id = None
|
|
36
|
+
self.project_id = None
|
|
37
|
+
self.tensors = {} # {name: {"id": uuid, "step": 0}}
|
|
38
|
+
self.active = False
|
|
39
|
+
self._session = requests.Session()
|
|
40
|
+
self._upload_queue = queue.Queue()
|
|
41
|
+
self._stop_event = threading.Event()
|
|
42
|
+
self._worker_thread = None
|
|
43
|
+
self._cache_dir = None
|
|
44
|
+
self._interrupted = False
|
|
45
|
+
self._original_sigint_handler = None
|
|
46
|
+
self._original_sigterm_handler = None
|
|
47
|
+
|
|
48
|
+
# Progress tracking
|
|
49
|
+
self._progress_lock = threading.Lock()
|
|
50
|
+
self._bytes_uploaded = 0
|
|
51
|
+
self._tensors_uploaded = 0
|
|
52
|
+
self._total_tensors_queued = 0
|
|
53
|
+
self._upload_start_time = None
|
|
54
|
+
|
|
55
|
+
    def init(self, project_name=None, run_name=None):
        """
        Initialize a new run session.

        Resolves project/run names (falling back to an active wandb run,
        then to defaults), registers the run with the backend (retrying
        on 409 name conflicts with a new name), starts the background
        upload worker, and installs signal handlers plus an atexit hook.

        Raises RuntimeError when no API key is stored, and re-raises
        non-handled HTTP/request errors.
        """
        # Try to pull from wandb if not provided. Only consults wandb if
        # the caller already imported it ("wandb" in sys.modules), so this
        # never forces a wandb import.
        if project_name is None or run_name is None:
            if "wandb" in sys.modules:
                import wandb
                if wandb.run:
                    if project_name is None:
                        project_name = wandb.run.project
                    if run_name is None:
                        run_name = wandb.run.name

        # Fallback to defaults if still None
        if project_name is None:
            project_name = "default"
        if run_name is None:
            run_name = "default-run"

        self.api_key = auth.load_token()
        if not self.api_key:
            raise RuntimeError("Not logged in. Please run 'gradexp login' first.")

        # Retry loop: on 409 (name taken) we pick a new run name and try again.
        current_run_name = run_name
        while True:
            payload = {
                "project_name": project_name,
                "run_name": current_run_name
            }

            try:
                url = f"https://api.{auth.BASE_URL}/api/v1/runs"
                headers = {
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                }

                response = self._session.post(url, headers=headers, json=payload)
                response.raise_for_status()

                data = response.json()
                self.run_id = data.get("run_id")
                self.project_id = data.get("project_id")
                self.tensors = {}
                self.active = True

                # Local staging directory for tensor step files (see log()).
                self._cache_dir = appdirs.user_cache_dir(auth.APP_NAME, auth.APP_AUTHOR)
                os.makedirs(self._cache_dir, exist_ok=True)

                # Start the background uploader. Daemon thread so it never
                # blocks interpreter exit by itself; finish() drains it.
                self._stop_event.clear()
                self._worker_thread = threading.Thread(target=self._upload_worker, daemon=True)
                self._worker_thread.start()

                base_url = f"https://{auth.BASE_URL}"
                run_url = f"{base_url}/?project={self.project_id}&run={self.run_id}"
                _term_log(f"Initialized. See run progress at {run_url}")

                # Ensure the run is flushed/closed on normal interpreter exit.
                atexit.register(self.finish)

                # Install interrupt handlers, remembering the originals so
                # finish() can restore them. signal.signal raises ValueError
                # when called outside the main thread -- ignore in that case.
                try:
                    self._original_sigint_handler = signal.getsignal(signal.SIGINT)
                    self._original_sigterm_handler = signal.getsignal(signal.SIGTERM)
                    signal.signal(signal.SIGINT, self._handle_interrupt)
                    signal.signal(signal.SIGTERM, self._handle_interrupt)
                except ValueError:
                    pass

                break # Success

            except requests.exceptions.HTTPError as e:
                if e.response.status_code == 409:
                    # Name conflict: prefer the backend's suggested name,
                    # otherwise bump a trailing "-N" suffix locally.
                    try:
                        err_data = e.response.json()
                        if "suggested_name" in err_data:
                            current_run_name = err_data["suggested_name"]
                            continue
                    except Exception:
                        pass

                    if "-" in current_run_name and current_run_name.rsplit("-", 1)[1].isdigit():
                        parts = current_run_name.rsplit("-", 1)
                        base = parts[0]
                        num = int(parts[1]) + 1
                        current_run_name = f"{base}-{num}"
                    else:
                        current_run_name = f"{current_run_name}-2"
                    continue

                _term_log(f"Failed to initialize: {e}")
                # NOTE(review): bare except below is best-effort reporting of
                # the backend's error body; it intentionally never raises.
                try:
                    _term_log(f"Backend error message: {e.response.text}")
                except:
                    pass

                if e.response.status_code == 401:
                    _term_log("Authentication failed: Unauthorized. Please run 'gradexp login' to authenticate.")
                    sys.exit(1)
                else:
                    raise
            except requests.exceptions.RequestException as e:
                # Network-level failure (DNS, connection, timeout): report and propagate.
                _term_log(f"Failed to initialize: {e}")
                raise
|
|
158
|
+
|
|
159
|
+
def _ensure_tensor(self, name, array):
|
|
160
|
+
"""
|
|
161
|
+
Ensures a tensor exists on the backend. Returns its ID.
|
|
162
|
+
"""
|
|
163
|
+
if name in self.tensors:
|
|
164
|
+
return self.tensors[name]["id"]
|
|
165
|
+
|
|
166
|
+
try:
|
|
167
|
+
url = f"https://api.{auth.BASE_URL}/api/v1/runs/{self.run_id}/tensors"
|
|
168
|
+
headers = {
|
|
169
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
170
|
+
"Content-Type": "application/json"
|
|
171
|
+
}
|
|
172
|
+
payload = {
|
|
173
|
+
"name": name,
|
|
174
|
+
"dtype": str(array.dtype),
|
|
175
|
+
"shape": list(array.shape)
|
|
176
|
+
}
|
|
177
|
+
response = self._session.post(url, headers=headers, json=payload)
|
|
178
|
+
response.raise_for_status()
|
|
179
|
+
|
|
180
|
+
data = response.json()
|
|
181
|
+
tensor_id = data.get("tensor_id")
|
|
182
|
+
self.tensors[name] = {"id": tensor_id, "step": 0}
|
|
183
|
+
return tensor_id
|
|
184
|
+
except Exception as e:
|
|
185
|
+
_term_log(f"Failed to create tensor '{name}': {e}")
|
|
186
|
+
return None
|
|
187
|
+
|
|
188
|
+
    def log(self, tensor, name="default"):
        """
        Stage one step of tensor data for background upload.

        Converts *tensor* to a float32 numpy array, writes the raw bytes
        to a uniquely named file in the local cache directory, and
        enqueues a work item for the upload worker. Never raises:
        failures are reported to the terminal and the call returns None.
        """
        if not self.active:
            _term_log("Not initialized. Please call gradexp.init() first.")
            return

        # Convert to numpy if needed
        if not isinstance(tensor, np.ndarray):
            try:
                tensor = np.array(tensor)
            except Exception:
                _term_log("Could not convert input to numpy array.")
                return

        # Ensure tensor is registered (creates it on the backend on first use)
        tensor_id = self._ensure_tensor(name, tensor)
        if not tensor_id:
            return

        try:
            # Always upload as float32. NOTE(review): _ensure_tensor reports
            # the ORIGINAL dtype to the backend, so the registered dtype and
            # the uploaded payload dtype can disagree -- confirm the backend
            # expects float32 payloads regardless of the declared dtype.
            if tensor.dtype != np.float32:
                tensor = tensor.astype(np.float32)

            tensor_bytes = tensor.tobytes()

            # Stage the raw bytes locally so the upload proceeds
            # asynchronously, decoupled from the training process.
            unique_id = str(uuid.uuid4())
            file_path = os.path.join(self._cache_dir, f"step_{unique_id}.bin")

            with open(file_path, "wb") as f:
                f.write(tensor_bytes)

            # Get and increment the per-tensor step counter
            step_index = self.tensors[name]["step"]
            self.tensors[name]["step"] += 1

            # Enqueue for background upload; the queued-count is shared with
            # the worker/progress threads, hence the lock.
            with self._progress_lock:
                self._total_tensors_queued += 1
            self._upload_queue.put({
                "file_path": file_path,
                "tensor_id": tensor_id,
                "step": step_index
            })

        except Exception as e:
            _term_log(f"Failed to buffer tensor data: {e}")
|
|
241
|
+
|
|
242
|
+
def _update_status(self, status):
|
|
243
|
+
"""
|
|
244
|
+
Updates the status of the run.
|
|
245
|
+
Uses PATCH /api/v1/runs/{run_id}
|
|
246
|
+
"""
|
|
247
|
+
if not self.run_id or not self.api_key:
|
|
248
|
+
return
|
|
249
|
+
|
|
250
|
+
try:
|
|
251
|
+
url = f"https://api.{auth.BASE_URL}/api/v1/runs/{self.run_id}"
|
|
252
|
+
headers = {
|
|
253
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
254
|
+
"Content-Type": "application/json"
|
|
255
|
+
}
|
|
256
|
+
payload = {"status": status}
|
|
257
|
+
response = self._session.patch(url, headers=headers, json=payload)
|
|
258
|
+
response.raise_for_status()
|
|
259
|
+
except Exception as e:
|
|
260
|
+
_term_log(f"Failed to update session status to {status}: {e}")
|
|
261
|
+
|
|
262
|
+
def _print_progress(self):
|
|
263
|
+
"""
|
|
264
|
+
Prints a progress bar showing upload status and throughput.
|
|
265
|
+
"""
|
|
266
|
+
with self._progress_lock:
|
|
267
|
+
uploaded = self._tensors_uploaded
|
|
268
|
+
total = self._total_tensors_queued
|
|
269
|
+
bytes_up = self._bytes_uploaded
|
|
270
|
+
start_time = self._upload_start_time
|
|
271
|
+
|
|
272
|
+
# Calculate throughput
|
|
273
|
+
if start_time and bytes_up > 0:
|
|
274
|
+
elapsed = time.time() - start_time
|
|
275
|
+
if elapsed > 0:
|
|
276
|
+
mbps = (bytes_up / (1024 * 1024)) / elapsed
|
|
277
|
+
else:
|
|
278
|
+
mbps = 0.0
|
|
279
|
+
else:
|
|
280
|
+
mbps = 0.0
|
|
281
|
+
|
|
282
|
+
# Build progress bar
|
|
283
|
+
bar_width = 30
|
|
284
|
+
if total > 0:
|
|
285
|
+
filled = int(bar_width * uploaded / total)
|
|
286
|
+
else:
|
|
287
|
+
filled = 0
|
|
288
|
+
bar = "█" * filled + "░" * (bar_width - filled)
|
|
289
|
+
|
|
290
|
+
sys.stdout.write(f"\r{GRADEXP_PREFIX}[{bar}] {uploaded}/{total} tensors | {mbps:.2f} MB/s ")
|
|
291
|
+
sys.stdout.flush()
|
|
292
|
+
|
|
293
|
+
    def _handle_interrupt(self, signum, frame):
        """
        Handle SIGINT (Ctrl+C) or SIGTERM.

        First signal: if uploads are pending, block in the handler showing
        progress until the queue drains, then finish("stopped") and exit 0.
        Second signal (or Ctrl+C during the drain): best-effort status
        update and immediate os._exit(1).
        """
        if not self._interrupted:
            self._interrupted = True
            q_size = self._upload_queue.qsize()

            if q_size == 0:
                _term_log("\nInterrupt received. No pending uploads. Shutting down...")
                self.finish("stopped")
                sys.exit(0)
            else:
                _term_log(f"\nInterrupt received. Waiting for {q_size} pending uploads to complete.")
                _term_log("Press Ctrl+C again to force quit.\n")

                # Show progress while waiting for uploads
                try:
                    # Loop until the queue is empty AND the worker has caught
                    # up with everything that was queued.
                    while not self._upload_queue.empty() or (self._worker_thread and self._worker_thread.is_alive() and self._tensors_uploaded < self._total_tensors_queued):
                        self._print_progress()
                        time.sleep(0.2)
                    self._print_progress() # Final update
                    print() # Newline after progress bar
                    self.finish("stopped")
                    sys.exit(0)
                except KeyboardInterrupt:
                    # Second Ctrl+C during progress display - force quit
                    _term_log("\nForce quitting...")
                    self._update_status("stopped")
                    os._exit(1)
        else:
            # A second signal arrived while _interrupted was already set.
            _term_log("\nForce quitting...")
            # Attempt a quick status update if possible, but don't block
            try:
                # Direct status update attempt to backend, skipping the queue
                # This might fail if the session is already half-closed but worth a try
                self._update_status("stopped")
            except:
                pass
            os._exit(1) # Immediate exit
|
|
333
|
+
|
|
334
|
+
def _upload_worker(self):
|
|
335
|
+
"""
|
|
336
|
+
Background worker that uploads data from the queue.
|
|
337
|
+
"""
|
|
338
|
+
while not self._stop_event.is_set() or not self._upload_queue.empty():
|
|
339
|
+
try:
|
|
340
|
+
# Use a timeout to occasionally check the stop event
|
|
341
|
+
item = self._upload_queue.get(timeout=1.0)
|
|
342
|
+
except queue.Empty:
|
|
343
|
+
continue
|
|
344
|
+
|
|
345
|
+
file_path = item["file_path"]
|
|
346
|
+
tensor_id = item["tensor_id"]
|
|
347
|
+
step = item["step"]
|
|
348
|
+
|
|
349
|
+
try:
|
|
350
|
+
if not os.path.exists(file_path):
|
|
351
|
+
self._upload_queue.task_done()
|
|
352
|
+
continue
|
|
353
|
+
|
|
354
|
+
with open(file_path, "rb") as f:
|
|
355
|
+
content = f.read()
|
|
356
|
+
|
|
357
|
+
content_len = len(content)
|
|
358
|
+
|
|
359
|
+
# Set upload start time on first upload
|
|
360
|
+
with self._progress_lock:
|
|
361
|
+
if self._upload_start_time is None:
|
|
362
|
+
self._upload_start_time = time.time()
|
|
363
|
+
|
|
364
|
+
url = f"https://api.{auth.BASE_URL}/api/v1/tensors/{tensor_id}/step"
|
|
365
|
+
headers = {
|
|
366
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
367
|
+
"Content-Type": "application/octet-stream"
|
|
368
|
+
}
|
|
369
|
+
params = {
|
|
370
|
+
"step": step,
|
|
371
|
+
"index": step # Kept for backward compatibility if needed
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
response = self._session.post(url, headers=headers, data=content, params=params)
|
|
375
|
+
response.raise_for_status()
|
|
376
|
+
|
|
377
|
+
# Success: update progress and delete the local file
|
|
378
|
+
with self._progress_lock:
|
|
379
|
+
self._bytes_uploaded += content_len
|
|
380
|
+
self._tensors_uploaded += 1
|
|
381
|
+
os.remove(file_path)
|
|
382
|
+
|
|
383
|
+
except requests.exceptions.RequestException as e:
|
|
384
|
+
_term_log(f"Failed to upload tensor step in background: {e}")
|
|
385
|
+
if os.path.exists(file_path):
|
|
386
|
+
os.remove(file_path)
|
|
387
|
+
except Exception as e:
|
|
388
|
+
_term_log(f"Unexpected error in upload worker: {e}")
|
|
389
|
+
if os.path.exists(file_path):
|
|
390
|
+
os.remove(file_path)
|
|
391
|
+
finally:
|
|
392
|
+
self._upload_queue.task_done()
|
|
393
|
+
|
|
394
|
+
    def finish(self, status="complete"):
        """
        Mark the session as finished and flush the upload queue.

        Blocks until the background worker has drained pending uploads
        (showing a progress bar when there are any), PATCHes the final
        *status* to the backend, restores the original signal handlers,
        and deactivates the client. Idempotent: a no-op when not active.
        """
        if not self.active:
            return

        # Signal the worker to stop after processing remaining items
        self._stop_event.set()

        if self._worker_thread and self._worker_thread.is_alive():
            # Check if there are pending uploads
            with self._progress_lock:
                pending = self._total_tensors_queued - self._tensors_uploaded

            if pending > 0:
                _term_log("finishing... Waiting for background uploads to complete.")
                # Show progress while waiting
                while self._worker_thread.is_alive() and self._tensors_uploaded < self._total_tensors_queued:
                    self._print_progress()
                    time.sleep(0.2)
                self._print_progress() # Final update
                print() # Newline after progress bar
            else:
                # No pending uploads, just wait for thread to finish
                self._worker_thread.join()

        # Update status on backend
        self._update_status(status)

        # Restore signal handlers. signal.signal raises ValueError outside
        # the main thread -- ignore in that case.
        try:
            if self._original_sigint_handler:
                signal.signal(signal.SIGINT, self._original_sigint_handler)
            if self._original_sigterm_handler:
                signal.signal(signal.SIGTERM, self._original_sigterm_handler)
        except ValueError:
            pass

        self._interrupted = False
        self.active = False
        _term_log(f"Session {status}.")
|
|
436
|
+
|
|
437
|
+
# Singleton instance
|
|
438
|
+
_client = Client()
|
|
439
|
+
|
|
440
|
+
def _excepthook(type, value, traceback):
    """Mark the active run as stopped, then delegate to the default hook."""
    # Close out the run before the interpreter prints the traceback.
    if _client.active:
        _client.finish("stopped")
    sys.__excepthook__(type, value, traceback)
|
|
444
|
+
|
|
445
|
+
def init(project_name=None, run_name=None):
    """
    Module-level entry point: install a chaining excepthook (once) and
    start a run on the singleton client.

    The hook marks the run as "stopped" on an unhandled exception and
    then delegates to whatever hook was previously installed.
    """
    global _original_excepthook
    # Install the hook exactly once. The previous guard compared
    # sys.excepthook against the unused module-level _excepthook, which
    # never matched the hook actually installed (tagged_excepthook), so
    # every init() call wrapped the hook in yet another layer. Tag the
    # installed hook and check the tag instead.
    if not getattr(sys.excepthook, "_gradexp_hook", False):
        _original_excepthook = sys.excepthook

        def tagged_excepthook(type, value, traceback):
            # Close out the run, then chain to the previous hook so any
            # user-installed hook still runs.
            if _client.active:
                _client.finish("stopped")
            _original_excepthook(type, value, traceback)

        tagged_excepthook._gradexp_hook = True
        sys.excepthook = tagged_excepthook

    _client.init(project_name=project_name, run_name=run_name)
|
|
463
|
+
|
|
464
|
+
def log(tensor, name="default"):
    """Log one step of *tensor* under *name* on the active run.

    Thin module-level convenience wrapper around the singleton client's
    ``log()``.  Per the tests in this package, uploads happen on a
    background worker thread; call ``finish()`` to flush before exit.
    """
    _client.log(tensor, name=name)
|
|
466
|
+
|
|
467
|
+
def finish(status="complete"):
    """Finalize the active run with the given terminal status.

    Delegates to the module singleton.  Status values used elsewhere in
    this file are "complete" (default) and "stopped" (set by the
    crash excepthook).
    """
    _client.finish(status)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Packaging metadata for the gradexp client (PEP 621 / hatchling build).
[project]
name = "gradexp"
version = "0.1.0"
description = "Gradient Explorer Client Library"
readme = "README.md"
requires-python = ">=3.7"
authors = [
    { name = "Misha Obu", email = "misha@parallel-ocean.xyz" }
]
# Runtime dependencies: `requests` backs the HTTP client and `numpy` the
# tensor payloads (both exercised by the tests); `click` presumably powers
# the CLI and `appdirs` token-storage paths -- confirm against cli.py/auth.py.
dependencies = [
    "click",
    "requests",
    "appdirs",
    "numpy",
]

# Console entry point: `gradexp` invokes gradexp.cli:main.
[project.scripts]
gradexp = "gradexp.cli:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
|
gradexp-0.1.0/schema.ts
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
/**
 * GradExp Schema Definitions -- Documentation
 *
 * Defines the data structures for the new GradExp API (Projects -> Runs -> Tensors -> Steps).
 *
 * Endpoints:
 * - POST /api/v1/runs (Create Run, presumably handles Project creation/lookup)
 * - POST /api/v1/runs/{run_id}/tensors (Create Tensor)
 * - POST /api/v1/tensors/{tensor_id}/step (Log Tensor Value)
 */

/**
 * Payload for creating a run.
 * POST /api/v1/runs
 */
export interface CreateRunRequest {
  project_name: string;
  run_name: string;
}

/**
 * Response for run creation.
 * NOTE(review): `status` here only lists 'active' | 'complete', while
 * UpdateRunStatusRequest below also allows 'stopped' and 'stalled' --
 * confirm with the backend whether a run can be returned in those states.
 */
export interface CreateRunResponse {
  run_id: string;
  project_id: string;
  name: string;
  status: 'active' | 'complete';
}

/**
 * Payload for creating a tensor.
 * POST /api/v1/runs/{run_id}/tensors
 */
export interface CreateTensorRequest {
  name: string;
  dtype: string; // presumably a numpy dtype string -- TODO confirm format
  shape: number[];
}

/** Response for tensor creation; echoes the request plus server-assigned ids. */
export interface CreateTensorResponse {
  tensor_id: string;
  run_id: string;
  name: string;
  dtype: string;
  shape: number[];
}

/**
 * Payload for logging a step value.
 * POST /api/v1/tensors/{tensor_id}/step?step={step}
 *
 * The body should be the raw binary data of the tensor.
 */
// export interface LogStepRequest { ... } (Removed as it's no longer JSON)


/**
 * Payload for updating run status.
 * PATCH /api/v1/runs/{run_id}
 */
export interface UpdateRunStatusRequest {
  status: 'complete' | 'stopped' | 'stalled';
}
|
|
62
|
+
|
gradexp-0.1.0/setup.txt
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import MagicMock, patch
|
|
3
|
+
import json
|
|
4
|
+
import base64
|
|
5
|
+
import numpy as np
|
|
6
|
+
import time
|
|
7
|
+
import os
|
|
8
|
+
import requests
|
|
9
|
+
|
|
10
|
+
# Set a dummy API key for the test
|
|
11
|
+
os.environ["GRADIENT_EXPLORER_API_KEY"] = "dummy_key"
|
|
12
|
+
|
|
13
|
+
import gradexp
|
|
14
|
+
from gradexp import auth, client
|
|
15
|
+
|
|
16
|
+
class TestGradExpClient(unittest.TestCase):
    """Mocked end-to-end flow: init (with run-name conflict retry), log, finish.

    All HTTP traffic goes through patched ``requests.Session`` methods, so
    the test never touches the network.
    """

    @patch('gradexp.auth.load_token', return_value="dummy_key")
    @patch('requests.Session.post')
    @patch('requests.Session.patch')
    def test_init_and_log(self, mock_patch, mock_post, mock_load_token):
        # Setup mock responses
        # 1. init -> runs (409 Conflict): forces the client to retry with a
        #    de-duplicated run name ("my-run-2").
        mock_response_conflict = MagicMock()
        mock_response_conflict.status_code = 409
        mock_response_conflict.raise_for_status.side_effect = requests.exceptions.HTTPError("Conflict", response=mock_response_conflict)

        # 2. init -> runs (success with new name)
        mock_response_init = MagicMock()
        mock_response_init.json.return_value = {"run_id": "run_123", "project_id": "proj_123"}
        mock_response_init.status_code = 200

        # 3. create tensor
        mock_response_tensor = MagicMock()
        mock_response_tensor.json.return_value = {"tensor_id": "tensor_123"}
        mock_response_tensor.status_code = 200

        # 4. log step
        mock_response_step = MagicMock()
        mock_response_step.status_code = 200

        # 5. status update (PATCH)
        mock_response_patch = MagicMock()
        mock_response_patch.status_code = 200

        mock_post.side_effect = [mock_response_conflict, mock_response_init, mock_response_tensor, mock_response_step]
        mock_patch.return_value = mock_response_patch

        # Reset client singleton so state from other tests cannot leak in.
        client._client = client.Client()

        # RUN INIT
        gradexp.init(project_name="my-project", run_name="my-run")

        # Verify call 1 (Conflict)
        call_args_list = mock_post.call_args_list
        self.assertGreaterEqual(len(call_args_list), 2)

        # First attempt: my-run
        args1, kwargs1 = call_args_list[0]
        self.assertIn("/api/v1/runs", args1[0])
        self.assertEqual(kwargs1['json']['run_name'], "my-run")

        # Second attempt: my-run-2
        args2, kwargs2 = call_args_list[1]
        self.assertIn("/api/v1/runs", args2[0])
        self.assertEqual(kwargs2['json']['run_name'], "my-run-2")

        # RUN LOG
        data = np.zeros((2, 2), dtype=np.float32)
        gradexp.log(data, name="my-tensor")

        # Wait for the background worker thread.  Poll with a deadline
        # instead of the previous fixed time.sleep(1): the test finishes as
        # soon as the tensor-creation call lands and tolerates a slow
        # worker up to the deadline (same pattern as TestGradExp).
        deadline = time.time() + 5.0
        while mock_post.call_count < 3 and time.time() < deadline:
            time.sleep(0.05)

        # Verify LOG calls...
        # Check create tensor call
        # note: index 2 because first two were init
        args_tensor, kwargs_tensor = call_args_list[2]
        self.assertIn("/api/v1/runs/run_123/tensors", args_tensor[0])

        # FINISH (triggers status update)
        gradexp.finish()

        # Verify PATCH call
        self.assertTrue(mock_patch.called)
        patch_args, patch_kwargs = mock_patch.call_args
        self.assertIn("/api/v1/runs/run_123", patch_args[0])
        self.assertEqual(patch_kwargs['json']['status'], "complete")
|
|
89
|
+
|
|
90
|
+
if __name__ == '__main__':
|
|
91
|
+
unittest.main()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Manual integration script: run gradexp alongside an active wandb run.

Exercises the documented behavior that ``gradexp.init()`` with no
arguments pulls project/run metadata from the active wandb session.
Not part of the automated test suite; requires network access and a
wandb login.
"""
import gradexp
import numpy as np
import wandb
import os



# 1. Initialize wandb
wandb.init(project="wandb-gradexp-integration", name="test-run-metadata")
wandb.log({"test_score": 0.95})

# 2. Initialize gradexp without arguments - it should pull from wandb
gradexp.init()


# 3. Log some data to gradexp
# Five small float32 tensors, each under its own name (tensor_0..tensor_4).
for i in range(5):
    data = np.random.rand(10, 10).astype(np.float32)
    gradexp.log(data, name=f"tensor_{i}")
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import unittest
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
import gradexp.client as client
|
|
5
|
+
import gradexp
|
|
6
|
+
|
|
7
|
+
class TestStatusUpdates(unittest.TestCase):
    """Unit tests for run-status reporting and the crash excepthook.

    NOTE(review): these tests assert `data-instances/{id}/{status}` POST
    endpoints, while the other test modules assert a `runs`-based API with
    PATCH status updates — confirm which contract client.py actually
    implements; one of the suites may be stale.
    """

    def setUp(self):
        # Reset client state
        client._client = client.Client()
        client._client.api_key = "fake_key" # Mock logged in
        client._client.data_instance_id = "fake_instance_id"
        client._client.active = True
        # Mock requests
        self.mock_session = MagicMock()
        client._client._session = self.mock_session

        # Mock auth.BASE_URL
        self.auth_patcher = patch('gradexp.auth.BASE_URL', "test.url")
        self.auth_patcher.start()

    def tearDown(self):
        # Always undo the BASE_URL patch so other tests see the real value.
        self.auth_patcher.stop()

    def test_finish_complete(self):
        # Default status should hit the .../complete endpoint.
        client._client.finish() # Default status="complete"
        self.mock_session.post.assert_called_with(
            "https://api.test.url/api/v1/data-instances/fake_instance_id/complete",
            headers={"Authorization": "Bearer fake_key", "Content-Type": "application/json"}
        )

    def test_finish_stopped(self):
        # Explicit status is interpolated into the endpoint path.
        client._client.finish(status="stopped")
        self.mock_session.post.assert_called_with(
            "https://api.test.url/api/v1/data-instances/fake_instance_id/stopped",
            headers={"Authorization": "Bearer fake_key", "Content-Type": "application/json"}
        )

    def test_excepthook_installed(self):
        # Check if excepthook wraps the original
        original = sys.excepthook

        # Force reinstall for test (since init might have been called or not)
        # NOTE(review): relies on the mocked _session from setUp to keep
        # init() off the network — verify Client.init uses self._session.
        client.init(project_name="test", run_name="test")

        self.assertNotEqual(sys.excepthook, original)
        self.assertTrue(hasattr(sys.excepthook, "__code__")) # simple check

        # Test that calling it triggers finish("stopped")
        client._client.finish = MagicMock()
        try:
            # We don't want to actually crash, just call the hook
            sys.excepthook(RuntimeError, RuntimeError("test"), None)
        except:
            pass

        client._client.finish.assert_called_with("stopped")
|
|
58
|
+
|
|
59
|
+
if __name__ == '__main__':
|
|
60
|
+
# Need to prevent init from running real network requests during import/setup if possible
|
|
61
|
+
# But init is only called explicitly.
|
|
62
|
+
unittest.main()
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import patch, MagicMock
|
|
3
|
+
import numpy as np
|
|
4
|
+
import sys
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
# Ensure we can import gradexp from current directory
|
|
8
|
+
sys.path.append(os.getcwd())
|
|
9
|
+
|
|
10
|
+
import gradexp
|
|
11
|
+
|
|
12
|
+
class TestGradExp(unittest.TestCase):
    """Happy-path mock test for the runs-based API: init then log one tensor.

    HTTP is fully mocked; `mock_post.side_effect` sequences are
    order-critical, matching the exact sequence of POSTs the client issues
    (run creation, then tensor creation, then step upload).
    """

    @patch('gradexp.auth.load_token')
    @patch('requests.Session.post')
    def test_full_flow(self, mock_post, mock_load_token):
        # Setup mocks
        mock_load_token.return_value = "test_api_key"

        # Mock Init Response (201 Created)
        init_response = MagicMock()
        init_response.status_code = 201
        init_response.json.return_value = {
            "run_id": "run_uuid_123",
            "project_id": "project_uuid_456",
            "name": "test-run",
            "status": "active"
        }

        # Mock Upload Response (200 OK)
        upload_response = MagicMock()
        upload_response.status_code = 200
        upload_response.json.return_value = {"status": "success"}

        # Configure side_effect to return different responses based on call order or inspection
        # But simpler: just cycle them if we know the order, or make the mock return init first then upload
        # Since we use the same session object, requests.Session.post is called
        mock_post.side_effect = [init_response, upload_response]

        # 1. run gradexp.init()
        gradexp.init(project_name="test-project", run_name="test-run")

        # Verify init call
        self.assertEqual(mock_post.call_count, 1)
        init_call_args = mock_post.call_args_list[0]
        self.assertIn('/api/v1/runs', init_call_args[0][0])
        self.assertEqual(init_call_args[1]['headers']['Authorization'], 'Bearer test_api_key')

        payload = init_call_args[1]['json']
        self.assertEqual(payload['project_name'], 'test-project')
        self.assertEqual(payload['run_name'], 'test-run')

        # Mock Tensor Creation Response (200 OK)
        tensor_response = MagicMock()
        tensor_response.status_code = 200
        tensor_response.json.return_value = {"tensor_id": "tensor_uuid_123"}

        # Update side_effect for subsequent calls
        # 2nd call: _ensure_tensor (POST /runs/.../tensors)
        # 3rd call: upload (POST /tensors/.../step)
        mock_post.side_effect = [tensor_response, upload_response]

        # 2. run gradexp.log()
        data = np.random.rand(3, 224, 224).astype(np.float32)
        gradexp.log(data)

        # Wait for potential background thread (though in this mock setup it might be immediate if we're lucky or we might need to wait)
        # Actually in the current client, it's a background thread.
        # We might need to join the queue or something.
        # For simplicity in this mock test, let's just wait a bit if needed or mock the queue.
        # But wait, mock_post.side_effect will be called by the background thread.

        # Bounded poll: waits until the worker thread has made both POSTs or
        # the deadline passes (keeps the test fast when the worker is quick).
        import time
        max_wait = 2.0
        start = time.time()
        while mock_post.call_count < 3 and time.time() - start < max_wait:
            time.sleep(0.1)

        # Verify calls
        self.assertEqual(mock_post.call_count, 3)

        # Verify tensor creation call
        tensor_call_args = mock_post.call_args_list[1]
        self.assertIn('/api/v1/runs/run_uuid_123/tensors', tensor_call_args[0][0])

        # Verify log call
        log_call_args = mock_post.call_args_list[2]
        self.assertIn('/api/v1/tensors/tensor_uuid_123/step', log_call_args[0][0])
        self.assertEqual(log_call_args[1]['params']['step'], 0)

        print("\nTest passed successfully!")
|
|
91
|
+
|
|
92
|
+
if __name__ == '__main__':
|
|
93
|
+
unittest.main()
|