unitlab 2.3.0__tar.gz → 2.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {unitlab-2.3.0/src/unitlab.egg-info → unitlab-2.3.3}/PKG-INFO +4 -10
- {unitlab-2.3.0 → unitlab-2.3.3}/README.md +29 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/setup.py +3 -1
- unitlab-2.3.3/src/unitlab/client.py +633 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/main.py +67 -1
- unitlab-2.3.3/src/unitlab/tunnel_config.py +238 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/utils.py +3 -0
- {unitlab-2.3.0 → unitlab-2.3.3/src/unitlab.egg-info}/PKG-INFO +4 -10
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab.egg-info/SOURCES.txt +1 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab.egg-info/requires.txt +2 -0
- unitlab-2.3.0/src/unitlab/client.py +0 -236
- {unitlab-2.3.0 → unitlab-2.3.3}/LICENSE.md +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/setup.cfg +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/__init__.py +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/__main__.py +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/exceptions.py +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab.egg-info/dependency_links.txt +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab.egg-info/entry_points.txt +0 -0
- {unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab.egg-info/top_level.txt +0 -0
{unitlab-2.3.0/src/unitlab.egg-info → unitlab-2.3.3}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.1
 Name: unitlab
-Version: 2.3.0
+Version: 2.3.3
 Home-page: https://github.com/teamunitlab/unitlab-sdk
 Author: Unitlab Inc.
 Author-email: team@unitlab.ai
@@ -21,11 +21,5 @@ Requires-Dist: requests
 Requires-Dist: tqdm
 Requires-Dist: typer
 Requires-Dist: validators
-
-
-Dynamic: classifier
-Dynamic: home-page
-Dynamic: keywords
-Dynamic: license
-Dynamic: license-file
-Dynamic: requires-dist
+Requires-Dist: psutil
+Requires-Dist: pyyaml
{unitlab-2.3.0 → unitlab-2.3.3}/README.md

@@ -26,6 +26,35 @@ Once you have successfully installed the Unitlab package, you can conveniently h
 ## Quickstart
 Follow [the quickstart guide for the Python SDK](https://docs.unitlab.ai/cli-python-sdk/unitlab-python-sdk).
 
+## CLI Commands
+
+### Agent Commands
+
+The agent module provides commands for running device agents with Jupyter, SSH tunnels, and metrics reporting.
+
+#### Run Agent
+
+Run a full device agent that sets up Jupyter notebooks, SSH tunnels, and system metrics reporting:
+
+```bash
+unitlab agent run --api-key YOUR_API_KEY [OPTIONS]
+```
+
+**Options:**
+- `--api-key` (required): Your Unitlab API key
+
+**Example:**
+```bash
+# Run with auto-generated device ID
+unitlab agent run your-api-key-here
+```
+
+
+The agent will:
+- Initialize Jupyter notebook server
+- Set up SSH tunnels for remote access
+- Collect and report system metrics
+- Handle graceful shutdown on interruption
 
 ## Documentation
 [The documentation](https://docs.unitlab.ai/) provides comprehensive instructions on how to utilize the Unitlab SDK effectively.
{unitlab-2.3.0 → unitlab-2.3.3}/setup.py

@@ -2,7 +2,7 @@ from setuptools import find_packages, setup
 
 setup(
     name="unitlab",
-    version="2.3.0",
+    version="2.3.3",
     license="MIT",
     author="Unitlab Inc.",
     author_email="team@unitlab.ai",
@@ -29,6 +29,8 @@ setup(
         "tqdm",
         "typer",
         "validators",
+        'psutil',
+        'pyyaml',
     ],
     entry_points={
         "console_scripts": ["unitlab=unitlab.main:app"],
unitlab-2.3.3/src/unitlab/client.py

@@ -0,0 +1,633 @@
+import asyncio
+import glob
+import logging
+import os
+import urllib.parse
+import aiofiles
+import aiohttp
+import requests
+import tqdm
+import socket
+import subprocess
+import signal
+import re
+import time
+import threading
+import psutil
+from datetime import datetime, timezone
+from .tunnel_config import CloudflareTunnel
+from .utils import get_api_url, handle_exceptions
+
+
+try:
+    import GPUtil
+    HAS_GPU = True
+except ImportError:
+    HAS_GPU = False
+
+
+logger = logging.getLogger(__name__)
+
+class UnitlabClient:
+    """A client with a connection to the Unitlab.ai platform.
+
+    Note:
+        Please refer to the `Python SDK quickstart <https://docs.unitlab.ai/cli-python-sdk/unitlab-python-sdk>`__ for a full example of working with the Python SDK.
+
+        First install the SDK.
+
+        .. code-block:: bash
+
+            pip install --upgrade unitlab
+
+        Import the ``unitlab`` package in your python file and set up a client with an API key. An API key can be created on <https://unitlab.ai/>`__.
+
+        .. code-block:: python
+
+            from unitlab import UnitlabClient
+            api_key = 'YOUR_API_KEY'
+            client = UnitlabClient(api_key)
+
+        Or store your Unitlab API key in your environment (``UNITLAB_API_KEY = 'YOUR_API_KEY'``):
+
+        .. code-block:: python
+
+            from unitlab import UnitlabClient
+            client = UnitlabClient()
+
+    Args:
+        api_key: Your Unitlab.ai API key. If no API key given, reads ``UNITLAB_API_KEY`` from the environment. Defaults to :obj:`None`.
+    Raises:
+        :exc:`~unitlab.exceptions.AuthenticationError`: If an invalid API key is used or (when not passing the API key directly) if ``UNITLAB_API_KEY`` is not found in your environment.
+    """
+
+    def __init__(self, api_key, api_url=None):
+        self.api_key = api_key
+        self.api_url = api_url or get_api_url()
+        self.api_session = requests.Session()
+        adapter = requests.adapters.HTTPAdapter(max_retries=3)
+        self.api_session.mount("http://", adapter)
+        self.api_session.mount("https://", adapter)
+
+        # Device agent attributes (initialized when needed)
+        self.device_id = None
+        self.base_domain = None
+        self.server_url = None
+        self.hostname = socket.gethostname()
+        self.tunnel_manager = None
+        self.jupyter_url = None
+        self.ssh_url = None
+        self.jupyter_proc = None
+        self.tunnel_proc = None
+        self.jupyter_port = None
+        self.running = True
+        self.metrics_thread = None
+
+    def close(self) -> None:
+        """Close :class:`UnitlabClient` connections.
+
+        You can manually close the Unitlab client's connections:
+
+        .. code-block:: python
+
+            client = UnitlabClient()
+            client.projects()
+            client.close()
+
+        Or use the client as a context manager:
+
+        .. code-block:: python
+
+            with UnitlabClient() as client:
+                client.projects()
+        """
+        self.api_session.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(
+        self,
+        exc_type,
+        exc_value,
+        traceback,
+    ) -> None:
+        self.close()
+
+    def _get_headers(self):
+        return {"Authorization": f"Api-Key {self.api_key}"}
+
+    @handle_exceptions
+    def _get(self, endpoint):
+        return self.api_session.get(
+            urllib.parse.urljoin(self.api_url, endpoint), headers=self._get_headers()
+        )
+
+    @handle_exceptions
+    def _post(self, endpoint, data=None):
+        return self.api_session.post(
+            urllib.parse.urljoin(self.api_url, endpoint),
+            json=data or {},
+            headers=self._get_headers(),
+        )
+
+    def projects(self, pretty=0):
+        return self._get(f"/api/sdk/projects/?pretty={pretty}")
+
+    def project(self, project_id, pretty=0):
+        return self._get(f"/api/sdk/projects/{project_id}/?pretty={pretty}")
+
+    def project_members(self, project_id, pretty=0):
+        return self._get(f"/api/sdk/projects/{project_id}/members/?pretty={pretty}")
+
+    def project_upload_data(self, project_id, directory, batch_size=100):
+        if not os.path.isdir(directory):
+            raise ValueError(f"Directory {directory} does not exist")
+
+        files = [
+            file
+            for files_list in (
+                glob.glob(os.path.join(directory, "") + extension)
+                for extension in ["*jpg", "*png", "*jpeg", "*webp"]
+            )
+            for file in files_list
+        ]
+        filtered_files = []
+        for file in files:
+            file_size = os.path.getsize(file) / 1024 / 1024
+            if file_size > 6:
+                logger.warning(
+                    f"File {file} is too large ({file_size:.4f} megabytes) skipping, max size is 6 MB"
+                )
+                continue
+            filtered_files.append(file)
+
+        num_files = len(filtered_files)
+        num_batches = (num_files + batch_size - 1) // batch_size
+
+        async def post_file(session: aiohttp.ClientSession, file: str, project_id: str):
+            async with aiofiles.open(file, "rb") as f:
+                form_data = aiohttp.FormData()
+                form_data.add_field("project", project_id)
+                form_data.add_field(
+                    "file", await f.read(), filename=os.path.basename(file)
+                )
+                try:
+                    await asyncio.sleep(0.1)
+                    async with session.post(
+                        urllib.parse.urljoin(self.api_url, "/api/sdk/upload-data/"),
+                        data=form_data,
+                    ) as response:
+                        response.raise_for_status()
+                        return 1
+                except Exception as e:
+                    logger.error(f"Error uploading file {file} - {e}")
+                    return 0
+
+        async def main():
+            logger.info(f"Uploading {num_files} files to project {project_id}")
+            with tqdm.tqdm(total=num_files, ncols=80) as pbar:
+                async with aiohttp.ClientSession(
+                    headers=self._get_headers()
+                ) as session:
+                    for i in range(num_batches):
+                        tasks = []
+                        for file in filtered_files[
+                            i * batch_size : min((i + 1) * batch_size, num_files)
+                        ]:
+                            tasks.append(
+                                post_file(
+                                    session=session, file=file, project_id=project_id
+                                )
+                            )
+                        for f in asyncio.as_completed(tasks):
+                            pbar.update(await f)
+
+        asyncio.run(main())
+
+    def datasets(self, pretty=0):
+        return self._get(f"/api/sdk/datasets/?pretty={pretty}")
+
+    def dataset_download(self, dataset_id, export_type):
+        response = self._post(
+            f"/api/sdk/datasets/{dataset_id}/",
+            data={"download_type": "annotation", "export_type": export_type},
+        )
+
+        with self.api_session.get(url=response["file"], stream=True) as r:
+            r.raise_for_status()
+            filename = f"dataset-{dataset_id}.json"
+            with open(filename, "wb") as f:
+                for chunk in r.iter_content(chunk_size=1024 * 1024):
+                    f.write(chunk)
+            logger.info(f"File: {os.path.abspath(filename)}")
+            return os.path.abspath(filename)
+
+    def dataset_download_files(self, dataset_id):
+        response = self._post(
+            f"/api/sdk/datasets/{dataset_id}/", data={"download_type": "files"}
+        )
+        folder = f"dataset-files-{dataset_id}"
+        os.makedirs(folder, exist_ok=True)
+        dataset_files = [
+            dataset_file
+            for dataset_file in response
+            if not os.path.isfile(os.path.join(folder, dataset_file["file_name"]))
+        ]
+
+        async def download_file(session: aiohttp.ClientSession, dataset_file: dict):
+            async with session.get(url=dataset_file["source"]) as r:
+                try:
+                    r.raise_for_status()
+                except Exception as e:
+                    logger.error(
+                        f"Error downloading file {dataset_file['file_name']} - {e}"
+                    )
+                    return 0
+                async with aiofiles.open(
+                    os.path.join(folder, dataset_file["file_name"]), "wb"
+                ) as f:
+                    async for chunk in r.content.iter_any():
+                        await f.write(chunk)
+                    return 1
+
+        async def main():
+            with tqdm.tqdm(total=len(dataset_files), ncols=80) as pbar:
+                async with aiohttp.ClientSession() as session:
+                    tasks = [
+                        download_file(session=session, dataset_file=dataset_file)
+                        for dataset_file in dataset_files
+                    ]
+                    for f in asyncio.as_completed(tasks):
+                        pbar.update(await f)
+
+        asyncio.run(main())
+
+    def initialize_device_agent(self, server_url: str, device_id: str, base_domain: str):
+        """Initialize device agent configuration"""
+        self.server_url = server_url.rstrip('/')
+        self.device_id = device_id
+        self.base_domain = base_domain
+
+        # Initialize tunnel manager if available
+        if CloudflareTunnel:
+            self.tunnel_manager = CloudflareTunnel(base_domain, device_id)
+            self.jupyter_url = self.tunnel_manager.jupyter_url
+            self.ssh_url = self.tunnel_manager.ssh_url
+        else:
+            self.tunnel_manager = None
+            self.jupyter_url = f"https://jupyter-{device_id}.{base_domain}"
+            self.ssh_url = f"https://ssh-{device_id}.{base_domain}"
+
+        # Setup signal handlers
+        signal.signal(signal.SIGINT, self._handle_shutdown)
+        signal.signal(signal.SIGTERM, self._handle_shutdown)
+
+    def _handle_shutdown(self, signum, frame):
+        """Handle shutdown signals"""
+        _ = frame  # Unused but required by signal handler signature
+        logger.info(f"Received signal {signum}, shutting down...")
+        self.running = False
+
+    def _get_device_headers(self):
+        """Get headers for device agent API requests"""
+        headers = {
+            'Content-Type': 'application/json',
+            'User-Agent': f'UnitlabDeviceAgent/{self.device_id}'
+        }
+
+        # Add API key if provided
+        if self.api_key:
+            headers['Authorization'] = f'Api-Key {self.api_key}'
+
+        return headers
+
+    def _post_device(self, endpoint, data=None):
+        """Make authenticated POST request for device agent"""
+        full_url = urllib.parse.urljoin(self.server_url, endpoint)
+        logger.debug(f"Posting to {full_url} with data: {data}")
+
+        try:
+            response = self.api_session.post(
+                full_url,
+                json=data or {},
+                headers=self._get_device_headers(),
+            )
+            logger.debug(f"Response status: {response.status_code}, Response: {response.text}")
+            response.raise_for_status()
+            return response
+        except Exception as e:
+            logger.error(f"POST request failed to {full_url}: {e}")
+            raise
+
+    def start_jupyter(self) -> bool:
+        """Start Jupyter notebook server"""
+        try:
+            logger.info("Starting Jupyter notebook...")
+
+            cmd = [
+                "jupyter", "notebook",
+                "--no-browser",
+                "--ServerApp.token=''",
+                "--ServerApp.password=''",
+                "--ServerApp.allow_origin='*'",
+                "--ServerApp.ip='0.0.0.0'"
+            ]
+
+            self.jupyter_proc = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True
+            )
+
+            # Wait for Jupyter to start and get the port
+            start_time = time.time()
+            while time.time() - start_time < 30:
+                line = self.jupyter_proc.stdout.readline()
+                if not line:
+                    break
+
+                # Look for the port in the output
+                match = re.search(r'http://.*:(\d+)/', line)
+                if match:
+                    self.jupyter_port = match.group(1)
+                    logger.info(f"✅ Jupyter started on port {self.jupyter_port}")
+                    return True
+
+            raise Exception("Timeout waiting for Jupyter to start")
+
+        except Exception as e:
+            logger.error(f"Failed to start Jupyter: {e}")
+            if self.jupyter_proc:
+                self.jupyter_proc.terminate()
+                self.jupyter_proc = None
+            return False
+
+    def setup_tunnels(self) -> bool:
+        """Setup Cloudflare tunnels"""
+        try:
+            if not self.jupyter_port:
+                logger.error("Jupyter port not available")
+                return False
+
+            if not self.tunnel_manager:
+                logger.warning("CloudflareTunnel not available, skipping tunnel setup")
+                return True
+
+            logger.info("Setting up Cloudflare tunnels...")
+            self.tunnel_proc = self.tunnel_manager.setup(self.jupyter_port)
+
+            if self.tunnel_proc:
+                logger.info("✅ Tunnels established")
+                self.report_services()
+                return True
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Tunnel setup failed: {e}")
+            return False
+
+    def check_ssh(self) -> bool:
+        """Check if SSH service is available"""
+        try:
+            # Check if SSH is running
+            result = subprocess.run(
+                ["systemctl", "is-active", "ssh"],
+                capture_output=True,
+                text=True,
+                timeout=5
+            )
+
+            if result.stdout.strip() == "active":
+                logger.info("✅ SSH service is active")
+                return True
+            else:
+                logger.warning("SSH service is not active")
+                # Try to start SSH
+                subprocess.run(["sudo", "systemctl", "start", "ssh"], timeout=10)
+                time.sleep(2)
+                return False
+
+        except Exception as e:
+            logger.error(f"SSH check failed: {e}")
+            return False
+
+    def report_services(self):
+        """Report services to the server"""
+        try:
+            # Report Jupyter service
+            jupyter_data = {
+                'service_type': 'jupyter',
+                'service_name': f'jupyter-{self.device_id}',
+                'local_port': int(self.jupyter_port) if self.jupyter_port else 8888,
+                'tunnel_url': self.jupyter_url,
+                'status': 'online'
+            }
+
+            logger.info(f"Reporting Jupyter service with URL: {self.jupyter_url}")
+            jupyter_response = self._post_device(
+                f"/api/tunnel/agent/jupyter/{self.device_id}/",
+                jupyter_data
+            )
+            logger.info(f"Reported Jupyter service: {jupyter_response.status_code if hasattr(jupyter_response, 'status_code') else jupyter_response}")
+
+            # Report SSH service (always report, even if SSH is not running locally)
+            # Remove https:// prefix for SSH hostname
+            ssh_hostname = self.ssh_url.replace('https://', '')
+
+            # Get current system username
+            import getpass
+            current_user = getpass.getuser()
+
+            # Create SSH connection command
+            ssh_connection_cmd = f"ssh -o ProxyCommand='cloudflared access ssh --hostname {ssh_hostname}' {current_user}@{ssh_hostname}"
+
+            # Check if SSH is available
+            ssh_available = self.check_ssh()
+
+            ssh_data = {
+                'service_type': 'ssh',
+                'service_name': f'ssh-{self.device_id}',
+                'local_port': 22,
+                'tunnel_url': ssh_connection_cmd,  # Send the SSH command instead of URL
+                'status': 'online' if ssh_available else 'offline'
+            }
+
+            logger.info(f"Reporting SSH service with command: {ssh_connection_cmd}")
+            ssh_response = self._post_device(
+                f"/api/tunnel/agent/ssh/{self.device_id}/",
+                ssh_data
+            )
+            logger.info(f"Reported SSH service: {ssh_response.status_code if hasattr(ssh_response, 'status_code') else ssh_response}")
+
+        except Exception as e:
+            logger.error(f"Failed to report services: {e}", exc_info=True)
+
+    def collect_metrics(self) -> dict:
+        """Collect system metrics"""
+        metrics = {}
+
+        # CPU metrics
+        metrics['cpu'] = {
+            'percent': psutil.cpu_percent(interval=1),
+            'count': psutil.cpu_count(),
+            'timestamp': datetime.now(timezone.utc).isoformat()
+        }
+
+        # Memory metrics
+        mem = psutil.virtual_memory()
+        metrics['ram'] = {
+            'total': mem.total,
+            'used': mem.used,
+            'available': mem.available,
+            'percent': mem.percent,
+            'timestamp': datetime.now(timezone.utc).isoformat()
+        }
+
+        # GPU metrics (if available)
+        if HAS_GPU:
+            try:
+                gpus = GPUtil.getGPUs()
+                if gpus:
+                    gpu = gpus[0]
+                    metrics['gpu'] = {
+                        'name': gpu.name,
+                        'load': gpu.load * 100,
+                        'memory_used': gpu.memoryUsed,
+                        'memory_total': gpu.memoryTotal,
+                        'temperature': gpu.temperature,
+                        'timestamp': datetime.now(timezone.utc).isoformat()
+                    }
+            except Exception as e:
+                logger.debug(f"GPU metrics unavailable: {e}")
+
+        return metrics
+
+    def send_metrics(self):
+        """Send metrics to server"""
+        try:
+            metrics = self.collect_metrics()
+
+            # Send CPU metrics
+            if 'cpu' in metrics:
+                self._post_device(f"/api/tunnel/agent/cpu/{self.device_id}/", metrics['cpu'])
+
+            # Send RAM metrics
+            if 'ram' in metrics:
+                self._post_device(f"/api/tunnel/agent/ram/{self.device_id}/", metrics['ram'])
+
+            # Send GPU metrics if available
+            if 'gpu' in metrics and metrics['gpu']:
+                self._post_device(f"/api/tunnel/agent/gpu/{self.device_id}/", metrics['gpu'])
+
+            logger.debug(f"Metrics sent - CPU: {metrics['cpu']['percent']:.1f}%, RAM: {metrics['ram']['percent']:.1f}%")
+
+        except Exception as e:
+            logger.error(f"Failed to send metrics: {e}")
+
+    def metrics_loop(self):
+        """Background thread for sending metrics"""
+        logger.info("Starting metrics thread")
+
+        while self.running:
+            try:
+                self.send_metrics()
+
+                # Check if processes are still running
+                if self.jupyter_proc and self.jupyter_proc.poll() is not None:
+                    logger.warning("Jupyter process died")
+                    self.jupyter_proc = None
+
+                if self.tunnel_proc and self.tunnel_proc.poll() is not None:
+                    logger.warning("Tunnel process died")
+                    self.tunnel_proc = None
+
+            except Exception as e:
+                logger.error(f"Metrics loop error: {e}")
+
+            # Wait for next interval (default 5 seconds)
+            for _ in range(3):
+                if not self.running:
+                    break
+                time.sleep(1)
+
+        logger.info("Metrics thread stopped")
+
+    def run_device_agent(self):
+        """Main run method for device agent"""
+        logger.info("=" * 50)
+        logger.info("Starting Device Agent")
+        logger.info(f"Device ID: {self.device_id}")
+        logger.info(f"Server: {self.server_url}")
+        logger.info(f"Domain: {self.base_domain}")
+        logger.info("=" * 50)
+
+        # Check SSH
+        self.check_ssh()
+
+        # Start Jupyter
+        if not self.start_jupyter():
+            logger.error("Failed to start Jupyter")
+            return
+
+        # Wait a moment for Jupyter to fully initialize
+        time.sleep(1)
+
+        # Setup tunnels
+        if not self.setup_tunnels():
+            logger.error("Failed to setup tunnels")
+            self.cleanup_device_agent()
+            return
+
+        # Print access information
+        logger.info("=" * 50)
+        logger.info("🎉 All services started successfully!")
+        logger.info(f"📔 Jupyter: {self.jupyter_url}")
+        logger.info(f"🔐 SSH: {self.ssh_url}")
+        # Remove https:// prefix for SSH command display
+        ssh_hostname = self.ssh_url.replace('https://', '')
+        import getpass
+        current_user = getpass.getuser()
+        logger.info(f"🔐 SSH Command: ssh -o ProxyCommand='cloudflared access ssh --hostname {ssh_hostname}' {current_user}@{ssh_hostname}")
+        logger.info("=" * 50)
+
+        # Start metrics thread
+        self.metrics_thread = threading.Thread(target=self.metrics_loop, daemon=True)
+        self.metrics_thread.start()
+
+        # Main loop
+        try:
+            while self.running:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            logger.info("Interrupted by user")
+
+        self.cleanup_device_agent()
+
+    def cleanup_device_agent(self):
+        """Clean up device agent resources"""
+        logger.info("Cleaning up...")
+
+        self.running = False
+
+        # Stop Jupyter
+        if self.jupyter_proc:
+            logger.info("Stopping Jupyter...")
+            self.jupyter_proc.terminate()
+            try:
+                self.jupyter_proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.jupyter_proc.kill()
+
+        # Stop tunnel
+        if self.tunnel_proc:
+            logger.info("Stopping tunnel...")
+            self.tunnel_proc.terminate()
+            try:
+                self.tunnel_proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.tunnel_proc.kill()
+
+        logger.info("Cleanup complete")
{unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/main.py

@@ -1,6 +1,8 @@
 from enum import Enum
 from pathlib import Path
 from uuid import UUID
+import logging
+import os
 
 import typer
 import validators
@@ -9,12 +11,16 @@ from typing_extensions import Annotated
 from . import utils
 from .client import UnitlabClient
 
+
 app = typer.Typer()
 project_app = typer.Typer()
 dataset_app = typer.Typer()
+agent_app = typer.Typer()
+
 
 app.add_typer(project_app, name="project", help="Project commands")
 app.add_typer(dataset_app, name="dataset", help="Dataset commands")
+app.add_typer(agent_app, name="agent", help="Agent commands")
 
 
 API_KEY = Annotated[
@@ -42,7 +48,7 @@ class AnnotationType(str, Enum):
 @app.command(help="Configure the credentials")
 def configure(
     api_key: Annotated[str, typer.Option(help="The api-key obtained from unitlab.ai")],
-    api_url: Annotated[str, typer.Option()] = "https://
+    api_url: Annotated[str, typer.Option()] = "https://localhost/",
 ):
     if not validators.url(api_url, simple_host=True):
         raise typer.BadParameter("Invalid api url")
@@ -105,5 +111,65 @@ def dataset_download(
     get_client(api_key).dataset_download_files(pk)
 
 
+def send_metrics_to_server(server_url: str, device_id: str, metrics: dict):
+    """Standalone function to send metrics to server using client"""
+    client = UnitlabClient(api_key="dummy")  # API key not needed for metrics
+    return client.send_metrics_to_server(server_url, device_id, metrics)
+
+
+def send_metrics_into_server():
+    """Standalone function to collect system metrics using client"""
+    client = UnitlabClient(api_key="dummy")  # API key not needed for metrics
+    return client.collect_system_metrics()
+
+
+@agent_app.command(name="run", help="Run the device agent with Jupyter, SSH tunnels and metrics")
+def run_agent(
+    api_key: str,
+    device_id: Annotated[str, typer.Option(help="Device ID")] = None,
+    base_domain: Annotated[str, typer.Option(help="Base domain for tunnels")] = "1scan.uz",
+
+):
+    """Run the full device agent with Jupyter, SSH tunnels and metrics reporting"""
+
+    # Setup logging
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[logging.StreamHandler()]
+    )
+
+    # Get server URL from environment or use default
+    server_url = 'https://api-dev.unitlab.ai/'
+
+    # Generate unique device ID if not provided
+    if not device_id:
+        import uuid
+        import platform
+        # Try environment variable first
+        device_id = os.getenv('DEVICE_ID')
+        if not device_id:
+            # Generate a unique ID based on hostname and random UUID
+            hostname = platform.node().replace('.', '-').replace(' ', '-')[:20]
+            random_suffix = str(uuid.uuid4())[:8]
+            device_id = f"{hostname}-{random_suffix}"
+
+
+    # Create client and initialize device agent
+    client = UnitlabClient(api_key=api_key)
+    client.initialize_device_agent(
+        server_url=server_url,
+        device_id=device_id,
+        base_domain=base_domain
+    )
+
+    try:
+        client.run_device_agent()
+    except Exception as e:
+        logging.error(f"Fatal error: {e}")
+        client.cleanup_device_agent()
+        raise typer.Exit(1)
+
+
 if __name__ == "__main__":
     app()
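
`run_agent` resolves the device ID through a three-step fallback: the `--device-id` option, then the `DEVICE_ID` environment variable, then a generated `hostname-uuid` value. A standalone sketch of the same rule, handy for predicting which IDs a fleet of agents will register under:

```python
# Reproduces the device-ID fallback from run_agent() above (sketch, not
# package API): explicit argument, then DEVICE_ID env var, then a
# hostname-based ID with a random 8-character suffix.
import os
import platform
import uuid

def resolve_device_id(explicit_id=None):
    if explicit_id:
        return explicit_id
    env_id = os.getenv("DEVICE_ID")
    if env_id:
        return env_id
    hostname = platform.node().replace(".", "-").replace(" ", "-")[:20]
    return f"{hostname}-{str(uuid.uuid4())[:8]}"

print(resolve_device_id())  # e.g. "build-box-3f9c21ab"
```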
unitlab-2.3.3/src/unitlab/tunnel_config.py

@@ -0,0 +1,238 @@
+"""
+Cloudflare Tunnel Configuration for persistent subdomains
+"""
+
+import json
+import subprocess
+import socket
+import time
+import yaml
+from pathlib import Path
+
+
+class CloudflareTunnel:
+    def __init__(self, base_domain, device_id):
+        # Hardcode the base domain here
+        self.base_domain = "1scan.uz"  # HARDCODED - ignore the passed base_domain
+        self.device_id = device_id
+        self.hostname = socket.gethostname()
+        self.tunnel_name = f"device-{device_id}"
+        self.config_dir = Path.home() / ".cloudflared"
+        self.config_dir.mkdir(exist_ok=True)
+
+        # Subdomain names
+        self.jupyter_subdomain = f"jupyter-{device_id}"
+        self.ssh_subdomain = f"ssh-{device_id}"
+
+        # Full URLs - using hardcoded base_domain
+        self.jupyter_url = f"https://{self.jupyter_subdomain}.{self.base_domain}"
+        self.ssh_url = f"https://{self.ssh_subdomain}.{self.base_domain}"
+
+        self.tunnel_uuid = None
+        self.credentials_file = None
+
+    def login(self):
+        """Login to Cloudflare (one-time setup)"""
+        try:
+            print("🔐 Checking Cloudflare authentication...")
+            result = subprocess.run(
+                ["cloudflared", "tunnel", "login"],
+                capture_output=True,
+                text=True
+            )
+            if result.returncode == 0:
+                print("✅ Cloudflare authentication successful")
+                return True
+            else:
+                print("❌ Cloudflare authentication failed")
+                return False
+        except Exception as e:
+            print(f"❌ Error during Cloudflare login: {e}")
+            return False
+
+    def create_tunnel(self):
+        """Create a named tunnel"""
+        try:
+            print(f"🚇 Creating tunnel: {self.tunnel_name}")
+
+            # Check if tunnel already exists
+            list_result = subprocess.run(
+                ["cloudflared", "tunnel", "list", "--output", "json"],
+                capture_output=True,
+                text=True
+            )
+
+            if list_result.returncode == 0:
+                tunnels = json.loads(list_result.stdout)
+                for tunnel in tunnels:
+                    if tunnel.get("name") == self.tunnel_name:
+                        self.tunnel_uuid = tunnel.get("id")
+                        print(f"✅ Tunnel already exists with ID: {self.tunnel_uuid}")
+                        self.credentials_file = self.config_dir / f"{self.tunnel_uuid}.json"
+                        return True
+
+            # Create new tunnel
+            result = subprocess.run(
+                ["cloudflared", "tunnel", "create", self.tunnel_name],
+                capture_output=True,
+                text=True
+            )
+
+            if result.returncode == 0:
+                for line in result.stdout.split('\n'):
+                    if "Created tunnel" in line and "with id" in line:
+                        self.tunnel_uuid = line.split("with id")[1].strip()
+                        break
+
+                if not self.tunnel_uuid:
+                    list_result = subprocess.run(
+                        ["cloudflared", "tunnel", "list", "--output", "json"],
+                        capture_output=True,
+                        text=True
+                    )
+                    if list_result.returncode == 0:
+                        tunnels = json.loads(list_result.stdout)
+                        for tunnel in tunnels:
+                            if tunnel.get("name") == self.tunnel_name:
+                                self.tunnel_uuid = tunnel.get("id")
+                                break
+
+                if self.tunnel_uuid:
+                    self.credentials_file = self.config_dir / f"{self.tunnel_uuid}.json"
+                    print(f"✅ Tunnel created with ID: {self.tunnel_uuid}")
+                    return True
+
+            print(f"❌ Failed to create tunnel: {result.stderr}")
+            return False
+
+        except Exception as e:
+            print(f"❌ Error creating tunnel: {e}")
+            return False
+
+    def configure_dns(self):
+        """Configure DNS routes for the tunnel"""
+        try:
+            print("🌐 Configuring DNS routes...")
+
+            # Route for Jupyter
+            jupyter_result = subprocess.run(
+                ["cloudflared", "tunnel", "route", "dns",
+                 self.tunnel_name, f"{self.jupyter_subdomain}.{self.base_domain}"],
+                capture_output=True,
+                text=True
+            )
+
+            if jupyter_result.returncode == 0:
+                print(f"✅ Jupyter route configured: {self.jupyter_url}")
+            else:
+                print(f"⚠️ Jupyter route may already exist or failed: {jupyter_result.stderr}")
+
+            # Route for SSH
+            ssh_result = subprocess.run(
+                ["cloudflared", "tunnel", "route", "dns",
+                 self.tunnel_name, f"{self.ssh_subdomain}.{self.base_domain}"],
+                capture_output=True,
+                text=True
+            )
+
+            if ssh_result.returncode == 0:
+                print(f"✅ SSH route configured: {self.ssh_url}")
+            else:
+                print(f"⚠️ SSH route may already exist or failed: {ssh_result.stderr}")
+
+            return True
+
+        except Exception as e:
+            print(f"❌ Error configuring DNS: {e}")
+            return False
+
+    def create_config_file(self, jupyter_port):
+        """Create tunnel configuration file"""
+        config = {
+            "tunnel": self.tunnel_uuid,
+            "credentials-file": str(self.credentials_file),
+            "ingress": [
+                {
+                    "hostname": f"{self.jupyter_subdomain}.{self.base_domain}",
+                    "service": f"http://localhost:{jupyter_port}",
+                    "originRequest": {
+                        "noTLSVerify": True
+                    }
+                },
+                {
+                    "hostname": f"{self.ssh_subdomain}.{self.base_domain}",
+                    "service": "ssh://localhost:22",
+                    "originRequest": {
+                        "noTLSVerify": True
+                    }
+                },
+                {
+                    "service": "http_status:404"
+                }
+            ]
+        }
+
+        config_file = self.config_dir / f"config-{self.device_id}.yml"
+        with open(config_file, 'w') as f:
+            yaml.dump(config, f, default_flow_style=False)
+
+        print(f"📝 Configuration saved to: {config_file}")
+        return config_file
+
+    def start_tunnel(self, config_file):
+        """Start the tunnel with the configuration"""
+        try:
+            print("🚀 Starting Cloudflare tunnel...")
+
+            cmd = ["cloudflared", "tunnel", "--config", str(config_file), "run"]
+
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                bufsize=1
+            )
+
+            # Wait for tunnel to establish
+            time.sleep(5)
+
+            if process.poll() is None:
+                print("✅ Tunnel is running")
+                return process
+            else:
+                print("❌ Tunnel failed to start")
+                return None
+
+        except Exception as e:
+            print(f"❌ Error starting tunnel: {e}")
+            return None
+
+    def setup(self, jupyter_port):
+        """Complete setup process"""
+        # Check if we need to login
+        if not (self.config_dir / "cert.pem").exists():
+            if not self.login():
+                return None
+
+        # Create tunnel
+        if not self.create_tunnel():
+            return None
+
+        # Configure DNS
+        if not self.configure_dns():
+            return None
+
+        # Create config file
+        config_file = self.create_config_file(jupyter_port)
+
+        # Start tunnel
+        tunnel_process = self.start_tunnel(config_file)
+
+        if tunnel_process:
+            print("\n✅ Tunnel setup complete!")
+            print(f"📌 Jupyter URL: {self.jupyter_url}")
+            print(f"📌 SSH URL: {self.ssh_url}")
+            return tunnel_process
+
+        return None
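
`create_config_file()` serializes a cloudflared ingress configuration with three rules: Jupyter over local HTTP, SSH, and the catch-all 404 that cloudflared requires as the final ingress entry. A sketch of the YAML shape it writes, using placeholder values for anything device-specific:

```python
# Sketch of the config written by create_config_file(); the tunnel UUID,
# credentials path, device ID, and port below are placeholders.
import yaml

config = {
    "tunnel": "0f1e2d3c-placeholder-uuid",
    "credentials-file": "/home/user/.cloudflared/0f1e2d3c-placeholder-uuid.json",
    "ingress": [
        {"hostname": "jupyter-my-device-01.1scan.uz",
         "service": "http://localhost:8888",
         "originRequest": {"noTLSVerify": True}},
        {"hostname": "ssh-my-device-01.1scan.uz",
         "service": "ssh://localhost:22",
         "originRequest": {"noTLSVerify": True}},
        {"service": "http_status:404"},  # required catch-all, must be last
    ],
}
print(yaml.dump(config, default_flow_style=False))
```

One wrinkle visible in the constructor: it ignores the `base_domain` argument and hardcodes `1scan.uz` (per its own comment), so the CLI's `--base-domain` option currently has no effect on the generated hostnames.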
{unitlab-2.3.0 → unitlab-2.3.3}/src/unitlab/utils.py

@@ -1,6 +1,8 @@
 import logging
 from configparser import ConfigParser
 from pathlib import Path
+import logging
+import requests
 
 import requests
 
@@ -63,3 +65,4 @@ def get_api_url() -> str:
         return config.get("default", "api_url")
     except Exception:
         return "https://api.unitlab.ai"
+
{unitlab-2.3.0 → unitlab-2.3.3/src/unitlab.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.1
 Name: unitlab
-Version: 2.3.0
+Version: 2.3.3
 Home-page: https://github.com/teamunitlab/unitlab-sdk
 Author: Unitlab Inc.
 Author-email: team@unitlab.ai
@@ -21,11 +21,5 @@ Requires-Dist: requests
 Requires-Dist: tqdm
 Requires-Dist: typer
 Requires-Dist: validators
-
-
-Dynamic: classifier
-Dynamic: home-page
-Dynamic: keywords
-Dynamic: license
-Dynamic: license-file
-Dynamic: requires-dist
+Requires-Dist: psutil
+Requires-Dist: pyyaml
unitlab-2.3.0/src/unitlab/client.py

@@ -1,236 +0,0 @@
-import asyncio
-import glob
-import logging
-import os
-import urllib.parse
-
-import aiofiles
-import aiohttp
-import requests
-import tqdm
-
-from .utils import get_api_url, handle_exceptions
-
-logger = logging.getLogger(__name__)
-
-
-class UnitlabClient:
-    """A client with a connection to the Unitlab.ai platform.
-
-    Note:
-        Please refer to the `Python SDK quickstart <https://docs.unitlab.ai/cli-python-sdk/unitlab-python-sdk>`__ for a full example of working with the Python SDK.
-
-        First install the SDK.
-
-        .. code-block:: bash
-
-            pip install --upgrade unitlab
-
-        Import the ``unitlab`` package in your python file and set up a client with an API key. An API key can be created on <https://unitlab.ai/>`__.
-
-        .. code-block:: python
-
-            from unitlab import UnitlabClient
-            api_key = 'YOUR_API_KEY'
-            client = UnitlabClient(api_key)
-
-        Or store your Unitlab API key in your environment (``UNITLAB_API_KEY = 'YOUR_API_KEY'``):
-
-        .. code-block:: python
-
-            from unitlab import UnitlabClient
-            client = UnitlabClient()
-
-    Args:
-        api_key: Your Unitlab.ai API key. If no API key given, reads ``UNITLAB_API_KEY`` from the environment. Defaults to :obj:`None`.
-    Raises:
-        :exc:`~unitlab.exceptions.AuthenticationError`: If an invalid API key is used or (when not passing the API key directly) if ``UNITLAB_API_KEY`` is not found in your environment.
-    """
-
-    def __init__(self, api_key, api_url=None):
-        self.api_key = api_key
-        self.api_url = api_url or get_api_url()
-        self.api_session = requests.Session()
-        adapter = requests.adapters.HTTPAdapter(max_retries=3)
-        self.api_session.mount("http://", adapter)
-        self.api_session.mount("https://", adapter)
-
-    def close(self) -> None:
-        """Close :class:`UnitlabClient` connections.
-
-        You can manually close the Unitlab client's connections:
-
-        .. code-block:: python
-
-            client = UnitlabClient()
-            client.projects()
-            client.close()
-
-        Or use the client as a context manager:
-
-        .. code-block:: python
-
-            with UnitlabClient() as client:
-                client.projects()
-        """
-        self.api_session.close()
-
-    def __enter__(self):
-        return self
-
-    def __exit__(
-        self,
-        exc_type,
-        exc_value,
-        traceback,
-    ) -> None:
-        self.close()
-
-    def _get_headers(self):
-        return {"Authorization": f"Api-Key {self.api_key}"}
-
-    @handle_exceptions
-    def _get(self, endpoint):
-        return self.api_session.get(
-            urllib.parse.urljoin(self.api_url, endpoint), headers=self._get_headers()
-        )
-
-    @handle_exceptions
-    def _post(self, endpoint, data=None):
-        return self.api_session.post(
-            urllib.parse.urljoin(self.api_url, endpoint),
-            json=data or {},
-            headers=self._get_headers(),
-        )
-
-    def projects(self, pretty=0):
-        return self._get(f"/api/sdk/projects/?pretty={pretty}")
-
-    def project(self, project_id, pretty=0):
-        return self._get(f"/api/sdk/projects/{project_id}/?pretty={pretty}")
-
-    def project_members(self, project_id, pretty=0):
-        return self._get(f"/api/sdk/projects/{project_id}/members/?pretty={pretty}")
-
-    def project_upload_data(self, project_id, directory, batch_size=100):
-        if not os.path.isdir(directory):
-            raise ValueError(f"Directory {directory} does not exist")
-
-        files = [
-            file
-            for files_list in (
-                glob.glob(os.path.join(directory, "") + extension)
-                for extension in ["*jpg", "*png", "*jpeg", "*webp"]
-            )
-            for file in files_list
-        ]
-        filtered_files = []
-        for file in files:
-            file_size = os.path.getsize(file) / 1024 / 1024
-            if file_size > 6:
-                logger.warning(
-                    f"File {file} is too large ({file_size:.4f} megabytes) skipping, max size is 6 MB"
-                )
-                continue
-            filtered_files.append(file)
-
-        num_files = len(filtered_files)
-        num_batches = (num_files + batch_size - 1) // batch_size
-
-        async def post_file(session: aiohttp.ClientSession, file: str, project_id: str):
-            async with aiofiles.open(file, "rb") as f:
-                form_data = aiohttp.FormData()
-                form_data.add_field("project", project_id)
-                form_data.add_field(
-                    "file", await f.read(), filename=os.path.basename(file)
-                )
-                try:
-                    await asyncio.sleep(0.1)
-                    async with session.post(
-                        urllib.parse.urljoin(self.api_url, "/api/sdk/upload-data/"),
-                        data=form_data,
-                    ) as response:
-                        response.raise_for_status()
-                        return 1
-                except Exception as e:
-                    logger.error(f"Error uploading file {file} - {e}")
-                    return 0
-
-        async def main():
-            logger.info(f"Uploading {num_files} files to project {project_id}")
-            with tqdm.tqdm(total=num_files, ncols=80) as pbar:
-                async with aiohttp.ClientSession(
-                    headers=self._get_headers()
-                ) as session:
-                    for i in range(num_batches):
-                        tasks = []
-                        for file in filtered_files[
-                            i * batch_size : min((i + 1) * batch_size, num_files)
-                        ]:
-                            tasks.append(
-                                post_file(
-                                    session=session, file=file, project_id=project_id
-                                )
-                            )
-                        for f in asyncio.as_completed(tasks):
-                            pbar.update(await f)
-
-        asyncio.run(main())
-
-    def datasets(self, pretty=0):
-        return self._get(f"/api/sdk/datasets/?pretty={pretty}")
-
-    def dataset_download(self, dataset_id, export_type):
-        response = self._post(
-            f"/api/sdk/datasets/{dataset_id}/",
-            data={"download_type": "annotation", "export_type": export_type},
-        )
-
-        with self.api_session.get(url=response["file"], stream=True) as r:
-            r.raise_for_status()
-            filename = f"dataset-{dataset_id}.json"
-            with open(filename, "wb") as f:
-                for chunk in r.iter_content(chunk_size=1024 * 1024):
-                    f.write(chunk)
-            logger.info(f"File: {os.path.abspath(filename)}")
-            return os.path.abspath(filename)
-
-    def dataset_download_files(self, dataset_id):
-        response = self._post(
-            f"/api/sdk/datasets/{dataset_id}/", data={"download_type": "files"}
-        )
-        folder = f"dataset-files-{dataset_id}"
-        os.makedirs(folder, exist_ok=True)
-        dataset_files = [
-            dataset_file
-            for dataset_file in response
-            if not os.path.isfile(os.path.join(folder, dataset_file["file_name"]))
-        ]
-
-        async def download_file(session: aiohttp.ClientSession, dataset_file: dict):
-            async with session.get(url=dataset_file["source"]) as r:
-                try:
-                    r.raise_for_status()
-                except Exception as e:
-                    logger.error(
-                        f"Error downloading file {dataset_file['file_name']} - {e}"
-                    )
-                    return 0
-                async with aiofiles.open(
-                    os.path.join(folder, dataset_file["file_name"]), "wb"
-                ) as f:
-                    async for chunk in r.content.iter_any():
-                        await f.write(chunk)
-                    return 1
-
-        async def main():
-            with tqdm.tqdm(total=len(dataset_files), ncols=80) as pbar:
-                async with aiohttp.ClientSession() as session:
-                    tasks = [
-                        download_file(session=session, dataset_file=dataset_file)
-                        for dataset_file in dataset_files
-                    ]
-                    for f in asyncio.as_completed(tasks):
-                        pbar.update(await f)
-
-        asyncio.run(main())