vastai-sdk 0.4.2.dev2__tar.gz → 0.4.2.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/PKG-INFO +1 -1
  2. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/pyproject.toml +1 -1
  3. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/client.py +9 -37
  4. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/connection.py +1 -9
  5. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/endpoint.py +7 -5
  6. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/lib/backend.py +117 -68
  7. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/lib/data_types.py +1 -0
  8. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/lib/server.py +1 -1
  9. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/vast.py +10 -9
  10. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/vastai_sdk.py +8 -11
  11. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/LICENSE +0 -0
  12. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/README.md +0 -0
  13. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/__init__.py +0 -0
  14. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/__init__.py +0 -0
  15. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/__init__.py +0 -0
  16. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/session.py +0 -0
  17. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/client/worker.py +0 -0
  18. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/__init__.py +0 -0
  19. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/lib/__init__.py +0 -0
  20. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/lib/metrics.py +0 -0
  21. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/serverless/server/worker.py +0 -0
  22. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai/vastai_base.py +0 -0
  23. {vastai_sdk-0.4.2.dev2 → vastai_sdk-0.4.2.dev3}/vastai_sdk/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: vastai-sdk
3
- Version: 0.4.2.dev2
3
+ Version: 0.4.2.dev3
4
4
  Summary: SDK for Vast.ai GPU Cloud Service
5
5
  License-File: LICENSE
6
6
  Author: Chris McKenzie
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "vastai-sdk"
7
- version = "0.4.2.dev2"
7
+ version = "0.4.2.dev3"
8
8
  description = "SDK for Vast.ai GPU Cloud Service"
9
9
  readme = "README.md"
10
10
  authors = [
@@ -136,7 +136,6 @@ class Serverless:
136
136
 
137
137
  return self._ssl_context
138
138
 
139
-
140
139
  async def get_endpoint(self, name="") -> Endpoint:
141
140
  endpoints = await self.get_endpoints()
142
141
  for e in endpoints:
@@ -179,17 +178,6 @@ class Serverless:
179
178
  raise RuntimeError(f"get_endpoint_workers failed: HTTP {resp.status} - {text}")
180
179
 
181
180
  data = await resp.json(content_type=None)
182
-
183
- # If error message from authenticate_endpoint_apikey_by_id occurs, there is a possibility that
184
- # the endpoint's worker instances are not ready to be queried. If an error message occurs,
185
- # return an empty list and print the error message to the user. The endpoint get_endpoint_workers
186
- # should normally return a list of dictionaries containing worker instance information.
187
- if isinstance(data,dict):
188
- if 'error_msg' in data.keys():
189
- self.logger.warning(f"Received the following error from get_endpoint_workers:{data['error_msg']}.\nEndpoint may not be ready for query. Check credentials or wait a few minutes and try again.")
190
- return []
191
-
192
-
193
181
  if not isinstance(data, list):
194
182
  raise RuntimeError(f"Unexpected response type (wanted list): {type(data)}")
195
183
 
@@ -296,7 +284,7 @@ class Serverless:
296
284
  session: Session = None,
297
285
  serverless_request: Optional[ServerlessRequest] = None,
298
286
  cost: int = 100,
299
- timeout: Optional[float] = None,
287
+ max_wait_time: Optional[float] = None,
300
288
  retry: bool = True,
301
289
  max_retries: int = None,
302
290
  stream: bool = False
@@ -308,7 +296,6 @@ class Serverless:
308
296
  async def task(request: ServerlessRequest):
309
297
  request_idx: int = 0
310
298
  total_attempts = 0
311
- start_time = time.time()
312
299
  try:
313
300
  while True:
314
301
  total_attempts += 1
@@ -317,10 +304,6 @@ class Serverless:
317
304
  auth_data = {}
318
305
  session_id = None
319
306
 
320
- # Check total elapsed time
321
- if timeout is not None and (time.time() - start_time) >= timeout:
322
- raise asyncio.TimeoutError(f"Timed out after {time.time() - start_time:.1f}s waiting for worker")
323
-
324
307
  if session is None:
325
308
  if request_idx == 0:
326
309
  self.logger.debug(f"Sending initial route call for request_idx {request_idx}")
@@ -336,21 +319,19 @@ class Serverless:
336
319
  self.logger.error("Did not get request_idx from initial route")
337
320
 
338
321
  poll_interval = 1
339
- poll_elapsed = 0
322
+ elapsed_time = 0
340
323
  attempt = 0
341
324
  while route.status != "READY":
342
325
  request.status = "Polling"
343
-
344
- # Check total elapsed time
345
- if timeout is not None and (time.time() - start_time) >= timeout:
346
- raise asyncio.TimeoutError(f"Timed out after {time.time() - start_time:.1f}s waiting for worker to become ready")
326
+ if max_wait_time is not None and elapsed_time >= max_wait_time:
327
+ raise asyncio.TimeoutError("Timed out waiting for worker to become ready")
347
328
 
348
329
  await asyncio.sleep(poll_interval)
349
- poll_elapsed += poll_interval
330
+ elapsed_time += poll_interval
350
331
 
351
332
  route = await endpoint._route(cost=cost, req_idx=request_idx, timeout=60.0)
352
333
  request_idx = route.request_idx or request_idx
353
-
334
+
354
335
  attempt += 1
355
336
  poll_interval = random.uniform(0.1, min((2 ** attempt) + random.uniform(0, 1), self.max_poll_interval))
356
337
  self.logger.debug(f"Polling route, attempt {attempt}")
@@ -385,7 +366,7 @@ class Serverless:
385
366
  body=worker_request_body,
386
367
  method="POST",
387
368
  retries=1, # avoid stacking retries with the outer loop
388
- timeout=600,
369
+ timeout=30,
389
370
  stream=stream
390
371
  )
391
372
  except Exception as ex:
@@ -396,10 +377,6 @@ class Serverless:
396
377
 
397
378
  if not result.get("ok"):
398
379
  if retry and result.get("retryable") and (max_retries is None or total_attempts < max_retries):
399
- # Check if we have time left before retrying
400
- if timeout is not None and (time.time() - start_time) >= timeout:
401
- raise asyncio.TimeoutError(f"Request timed out after {time.time() - start_time:.1f}s")
402
-
403
380
  request.status = "Retrying"
404
381
  await asyncio.sleep(min((2 ** total_attempts) + random.uniform(0, 1), self.max_poll_interval))
405
382
  continue
@@ -412,9 +389,7 @@ class Serverless:
412
389
 
413
390
  response = {
414
391
  "response": result.get("json") if result.get("json") is not None else {"error": result.get("text", "")},
415
- "ok": result.get("ok"),
416
- "status": result.get("status"),
417
- "text" : result.get("text"),
392
+ "result": result,
418
393
  "latency": (request.complete_time - request.start_time) if request.start_time else None,
419
394
  "url": worker_url,
420
395
  "request_idx": request_idx,
@@ -433,9 +408,7 @@ class Serverless:
433
408
 
434
409
  response = {
435
410
  "response": worker_response,
436
- "ok" : result.get("ok"),
437
- "status" : result.get("status"),
438
- "text" : result.get("text"),
411
+ "result": result,
439
412
  "latency": request.complete_time - request.start_time,
440
413
  "url": worker_url,
441
414
  "request_idx": request_idx,
@@ -450,7 +423,6 @@ class Serverless:
450
423
  except Exception as ex:
451
424
  request.status = "Errored"
452
425
  self.logger.error(f"Request errored: {ex}")
453
- request.set_exception(ex)
454
426
  return
455
427
 
456
428
  bg_task = asyncio.create_task(task(serverless_request))
@@ -3,7 +3,7 @@ import aiohttp
3
3
  import asyncio
4
4
  import random
5
5
  import json
6
- from typing import AsyncIterator, Dict, Optional, Any
6
+ from typing import AsyncIterator, Dict, Optional, Union, Any
7
7
 
8
8
  _JITTER_CAP_SECONDS = 5.0
9
9
 
@@ -198,10 +198,6 @@ async def _make_request(
198
198
  last_result["stream"] = _stream_iter()
199
199
  return last_result
200
200
 
201
- except asyncio.TimeoutError as ex:
202
- if attempt == retries:
203
- raise TimeoutError(f"Request to {full_url} timed out after {timeout}s") from ex
204
- await asyncio.sleep(_backoff_delay(attempt))
205
201
  except Exception as ex:
206
202
  if attempt == retries:
207
203
  raise ex
@@ -261,10 +257,6 @@ async def _make_request(
261
257
 
262
258
  return result
263
259
 
264
- except asyncio.TimeoutError as ex:
265
- if attempt == retries:
266
- raise TimeoutError(f"Request to {full_url} timed out after {timeout}s") from ex
267
- await asyncio.sleep(_backoff_delay(attempt))
268
260
  except Exception as ex:
269
261
  if attempt == retries:
270
262
  raise ex
@@ -24,7 +24,7 @@ class Endpoint:
24
24
  self.id = id
25
25
  self.api_key = api_key
26
26
 
27
- def request(self, route, payload, serverless_request=None, cost: int = 100, retry: bool = True, stream: bool = False, timeout: float = None, session: "Session" = None):
27
+ def request(self, route, payload, serverless_request=None, cost: int = 100, retry: bool = True, stream: bool = False, session: "Session" = None):
28
28
  return self.client.queue_endpoint_request(
29
29
  endpoint=self,
30
30
  worker_route=route,
@@ -33,12 +33,13 @@ class Endpoint:
33
33
  cost=cost,
34
34
  retry=retry,
35
35
  stream=stream,
36
- timeout=timeout,
37
36
  session=session
38
37
  )
39
38
 
40
39
  def close_session(self, session: "Session"):
41
- return self.client.end_endpoint_session(session=session)
40
+ return self.client.end_endpoint_session(
41
+ session=session
42
+ )
42
43
 
43
44
  async def session_healthcheck(self, session: "Session"):
44
45
  result = await self.client.get_endpoint_session(
@@ -85,7 +86,7 @@ class Endpoint:
85
86
  },
86
87
  method="POST",
87
88
  timeout=10.0,
88
- retries=1,
89
+ retries=5,
89
90
  stream=False,
90
91
  )
91
92
  except Exception as ex:
@@ -95,7 +96,8 @@ class Endpoint:
95
96
  raise RuntimeError(f"Failed to route endpoint: HTTP {result.get('status')} - {result.get('text','')[:512]}")
96
97
 
97
98
  return RouteResponse(result.get("json") or {})
98
-
99
+
100
+
99
101
  class RouteResponse:
100
102
  status: str
101
103
  body: dict
@@ -5,7 +5,7 @@ import base64
5
5
  import subprocess
6
6
  import dataclasses
7
7
  import logging
8
- from asyncio import sleep, gather, Semaphore, create_task
8
+ from asyncio import wait, sleep, gather, Semaphore, FIRST_COMPLETED, create_task
9
9
  from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional, Any, Dict
10
10
  from functools import cached_property
11
11
  from distutils.util import strtobool
@@ -35,8 +35,9 @@ from .data_types import (
35
35
  Session
36
36
  )
37
37
 
38
- VERSION = "1.1.0"
38
+ VERSION = "1.0.1"
39
39
 
40
+ MSG_HISTORY_LEN = 100
40
41
  log = logging.getLogger(__file__)
41
42
 
42
43
  # defines the minimum wait time between sending updates to autoscaler
@@ -65,6 +66,7 @@ class Backend:
65
66
  log_actions: List[Tuple[LogAction, str]]
66
67
  reqnum = -1
67
68
  version = VERSION
69
+ msg_history = []
68
70
  sem: Semaphore = dataclasses.field(default_factory=Semaphore)
69
71
  queue: deque = dataclasses.field(default_factory=deque, repr=False)
70
72
  unsecured: bool = dataclasses.field(
@@ -184,7 +186,8 @@ class Backend:
184
186
  if session is None:
185
187
  return False
186
188
 
187
- # Cancel all in-flight request handler tasks
189
+ session.cancel_event.set()
190
+
188
191
  for req in list(session.requests):
189
192
  try:
190
193
  tr = getattr(req, "transport", None)
@@ -193,15 +196,15 @@ class Backend:
193
196
  except Exception:
194
197
  pass
195
198
  session.requests.clear()
199
+
196
200
  request_metrics = self.session_metrics.pop(session_id, None)
197
201
 
198
- # Run the on_close callback
199
202
  try:
200
203
  await self.__run_session_on_close(session)
201
204
  except Exception:
202
205
  pass
203
206
 
204
- # Update metrics outside lock
207
+ # metrics outside lock
205
208
  if request_metrics is not None:
206
209
  self.metrics._request_success(request_metrics)
207
210
  self.metrics._request_end(request_metrics)
@@ -310,8 +313,6 @@ class Backend:
310
313
  self._total_pubkey_fetch_errors = 0
311
314
  self._pubkey = self._fetch_pubkey()
312
315
  self.__start_healthcheck: bool = False
313
- self.__healthcheck_ready: asyncio.Event = asyncio.Event()
314
- self.__healthcheck_succeeded: bool = False
315
316
 
316
317
  @property
317
318
  def pubkey(self) -> Optional[RSA.RsaKey]:
@@ -374,6 +375,7 @@ class Backend:
374
375
  request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
375
376
 
376
377
  event = asyncio.Event()
378
+ finished = asyncio.Event()
377
379
 
378
380
  session = None
379
381
  if session_id is not None:
@@ -398,6 +400,14 @@ class Backend:
398
400
  except ValueError:
399
401
  pass
400
402
 
403
+ async def cancel_api_call_if_disconnected() -> None:
404
+ await request.wait_for_disconnection()
405
+ if not finished.is_set():
406
+ self.metrics._request_canceled(request_metrics)
407
+
408
+ async def cancel_if_session_closed() -> None:
409
+ await session.cancel_event.wait()
410
+
401
411
  async def make_request() -> Union[web.Response, web.StreamResponse]:
402
412
  try:
403
413
  response = await self.__call_backend(handler=handler, payload=payload)
@@ -419,7 +429,13 @@ class Backend:
419
429
  if handler.max_queue_time is not None and self.metrics.model_metrics.wait_time > handler.max_queue_time:
420
430
  self.metrics._request_reject(request_metrics)
421
431
  return web.Response(status=429)
432
+
433
+ disconnect_task = create_task(cancel_api_call_if_disconnected())
434
+ session_cancel_task = None
435
+ if session is not None:
436
+ session_cancel_task = create_task(cancel_if_session_closed())
422
437
 
438
+ next_request_task = None
423
439
  work_task = None
424
440
 
425
441
  self.metrics._request_start(request_metrics, session)
@@ -427,7 +443,24 @@ class Backend:
427
443
  try:
428
444
  if handler.allow_parallel_requests:
429
445
  work_task = create_task(make_request())
430
- # Handler cancellation will raise CancelledError on client disconnect
446
+
447
+ wait_set = [work_task, disconnect_task]
448
+ if session_cancel_task is not None:
449
+ wait_set.append(session_cancel_task)
450
+
451
+ done, pending = await wait(
452
+ wait_set,
453
+ return_when=FIRST_COMPLETED,
454
+ )
455
+
456
+ for t in pending:
457
+ t.cancel()
458
+ await asyncio.gather(*pending, return_exceptions=True)
459
+
460
+ if disconnect_task in done or (session_cancel_task is not None and session_cancel_task in done):
461
+ return web.Response(status=499) # request cancelled
462
+
463
+ # otherwise work_task completed
431
464
  return await work_task
432
465
 
433
466
  # FIFO-queue branch
@@ -438,50 +471,80 @@ class Backend:
438
471
  if self.queue and self.queue[0] is event:
439
472
  event.set()
440
473
 
441
- # Wait for our turn - CancelledError raised if client disconnects
442
- await event.wait()
474
+ # Race between our request being next and request being cancelled
475
+ next_request_task = create_task(event.wait())
476
+
477
+ wait_set = [next_request_task, disconnect_task]
478
+ if session_cancel_task is not None:
479
+ wait_set.append(session_cancel_task)
480
+
481
+ first_done, first_pending = await wait(
482
+ wait_set,
483
+ return_when=FIRST_COMPLETED,
484
+ )
485
+ # If the disconnect task wins the race
486
+ if disconnect_task in first_done or (session_cancel_task is not None and session_cancel_task in first_done):
487
+ # Clean up the next_request_task, then exit
488
+ for t in first_pending:
489
+ t.cancel()
490
+ await asyncio.gather(*first_pending, return_exceptions=True)
491
+ return web.Response(status=499)
443
492
 
444
493
  # We are the next-up request in the queue
445
494
  if session is not None:
446
495
  log.debug(f"Starting work on request {request_metrics.reqnum}")
447
496
 
448
- # Execute the work task
497
+ # Race the backend API call with the disconnect task
449
498
  work_task = create_task(make_request())
499
+
500
+ wait_set = [work_task, disconnect_task]
501
+ if session_cancel_task is not None:
502
+ wait_set.append(session_cancel_task)
503
+
504
+ done, pending = await wait(
505
+ wait_set,
506
+ return_when=FIRST_COMPLETED,
507
+ )
508
+
509
+ for t in pending:
510
+ t.cancel()
511
+ await asyncio.gather(*pending, return_exceptions=True)
512
+
513
+ if disconnect_task in done or (session_cancel_task is not None and session_cancel_task in done):
514
+ return web.Response(status=499)
515
+
516
+ # otherwise work_task completed
450
517
  return await work_task
451
518
 
452
519
  except asyncio.CancelledError:
453
- # With handler_cancellation enabled, this indicates client disconnect
454
- log.debug(f"Request {request_metrics.reqnum} cancelled (client disconnect)")
455
- self.metrics._request_canceled(request_metrics)
456
520
  return web.Response(status=499)
457
-
521
+
458
522
  except Exception as e:
459
523
  log.debug(f"Exception in main handler loop {e}")
460
524
  return web.Response(status=500)
461
525
 
462
526
  finally:
463
- try:
464
- # Remove request from session if present
465
- if session is not None and session_id is not None:
466
- async with self._sessions_lock:
467
- s = self.sessions.get(session_id)
468
- if s is not None:
469
- try:
470
- s.requests.remove(request)
471
- except ValueError:
472
- pass
473
-
474
- if not handler.allow_parallel_requests:
475
- advance_queue_after_completion(event)
476
-
477
- self.metrics._request_end(request_metrics)
478
-
479
- # Cleanup work task if still pending
480
- if work_task and not work_task.done():
481
- work_task.cancel()
482
- await asyncio.gather(work_task, return_exceptions=True)
483
- except Exception as e:
484
- log.error(f"Error during request cleanup: {e}")
527
+ # Set finished flag so we don't cancel after completion
528
+ finished.set()
529
+
530
+ if session is not None and session_id is not None:
531
+ async with self._sessions_lock:
532
+ s = self.sessions.get(session_id)
533
+ if s is not None:
534
+ try:
535
+ s.requests.remove(request)
536
+ except ValueError:
537
+ pass
538
+
539
+ if not handler.allow_parallel_requests:
540
+ advance_queue_after_completion(event)
541
+ self.metrics._request_end(request_metrics, session)
542
+ cleanup_tasks = [t for t in (next_request_task, work_task, disconnect_task, session_cancel_task) if t]
543
+ for t in cleanup_tasks:
544
+ if not t.done():
545
+ t.cancel()
546
+ if cleanup_tasks:
547
+ await asyncio.gather(*cleanup_tasks, return_exceptions=True)
485
548
 
486
549
  async def __healthcheck(self) -> None:
487
550
  """
@@ -513,17 +576,12 @@ class Backend:
513
576
 
514
577
  if status == 200:
515
578
  log.debug("Healthcheck successful")
516
- if not self.__healthcheck_succeeded:
517
- self.__healthcheck_succeeded = True
518
- self.__healthcheck_ready.set()
519
- log.debug("First healthcheck succeeded - model is ready")
520
- else:
579
+ elif status == 503:
521
580
  msg = f"Healthcheck failed with status: {status}"
522
581
  log.debug(msg)
523
- # Only report error if we've already had a successful healthcheck
524
- # (i.e., model was working but now is broken)
525
- if self.__healthcheck_succeeded:
526
- self.backend_errored(msg)
582
+ self.backend_errored(msg)
583
+ else:
584
+ log.debug(f"Healthcheck endpoint not ready: {status}")
527
585
 
528
586
  except CancelledError:
529
587
  log.debug("Healthcheck task cancelled; exiting loop")
@@ -531,10 +589,7 @@ class Backend:
531
589
 
532
590
  except Exception as e:
533
591
  log.debug(f"Healthcheck failed with exception: {e}")
534
- # Only report connection errors AFTER the first successful healthcheck
535
- # During startup, connection failures are expected
536
- if self.__healthcheck_succeeded:
537
- self.backend_errored(str(e))
592
+ self.backend_errored(str(e))
538
593
 
539
594
  async def _start_tracking(self) -> None:
540
595
  await gather(
@@ -600,11 +655,17 @@ class Backend:
600
655
  return False
601
656
 
602
657
  message = {
603
- "url" : auth_data.url
658
+ key: value
659
+ for (key, value) in (dataclasses.asdict(auth_data).items())
660
+ if key != "signature" and key != "__request_id"
604
661
  }
605
-
606
- if verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
662
+ if auth_data.reqnum < (self.reqnum - MSG_HISTORY_LEN):
663
+ log.error(f"Signature error: reqnum failure, got {auth_data.reqnum}, current_reqnum: {self.reqnum}")
664
+ return False
665
+ elif verify_signature(json.dumps(message, indent=4, sort_keys=True), auth_data.signature):
607
666
  self.reqnum = max(auth_data.reqnum, self.reqnum)
667
+ self.msg_history.append(message)
668
+ self.msg_history = self.msg_history[-MSG_HISTORY_LEN:]
608
669
  return True
609
670
  else:
610
671
  log.error(f"Signature error: signature verification failed, sig:{auth_data.signature}, message: {message}")
@@ -688,24 +749,12 @@ class Backend:
688
749
  log.debug(
689
750
  f"Got log line indicating model is loaded: {log_line}"
690
751
  )
752
+ # some backends need a few seconds after logging successful startup before
753
+ # they can begin accepting requests
754
+ # await sleep(5)
691
755
  try:
692
756
  max_throughput = await run_benchmark()
693
757
  self.__start_healthcheck = True
694
-
695
- # Wait for the first successful healthcheck before marking model as loaded
696
- if self.healthcheck_url:
697
- log.debug("Benchmark succeeded, waiting for healthcheck to confirm model is ready...")
698
- try:
699
- await asyncio.wait_for(self.__healthcheck_ready.wait(), timeout=300.0)
700
- log.debug("Healthcheck confirmed - marking model as loaded")
701
- except asyncio.TimeoutError:
702
- raise Exception("Timed out waiting for healthcheck after benchmark (waited 300s)")
703
- else:
704
- # No healthcheck endpoint defined, wait 10 seconds as fallback
705
- log.debug("No healthcheck endpoint defined, waiting 10 seconds before marking model as loaded...")
706
- await asyncio.sleep(10)
707
- log.debug("Wait complete - marking model as loaded")
708
-
709
758
  self.metrics._model_loaded(
710
759
  max_throughput=max_throughput,
711
760
  )
@@ -353,5 +353,6 @@ class Session:
353
353
  on_close_payload: dict
354
354
  requests: list[web.Request] = field(default_factory=list)
355
355
  created_at: float = field(default_factory=time.time)
356
+ cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
356
357
 
357
358
 
@@ -37,7 +37,7 @@ async def start_server_async(backend: Backend, routes: List[web.RouteDef], **kwa
37
37
  app.router.add_post("/session/get", backend.session_get_handler)
38
38
  app.router.add_post("/session/health", backend.session_health_handler)
39
39
 
40
- runner = web.AppRunner(app, handler_cancellation=True)
40
+ runner = web.AppRunner(app)
41
41
  await runner.setup()
42
42
  site = web.TCPSite(
43
43
  runner,
@@ -3960,10 +3960,10 @@ def show__earnings(args):
3960
3960
  :rtype:
3961
3961
  """
3962
3962
 
3963
- Minutes = 60.0
3964
- Hours = 60.0*Minutes
3965
- Days = 24.0*Hours
3966
- Years = 365.0*Days
3963
+ Minutes = 60.0;
3964
+ Hours = 60.0*Minutes;
3965
+ Days = 24.0*Hours;
3966
+ Years = 365.0*Days;
3967
3967
  cday = time.time() / Days
3968
3968
  sday = cday - 1.0
3969
3969
  eday = cday - 1.0
@@ -3979,7 +3979,7 @@ def show__earnings(args):
3979
3979
  try:
3980
3980
  end_date = dateutil.parser.parse(str(args.end_date))
3981
3981
  end_date_txt = end_date.isoformat()
3982
- end_timestamp = end_date.timestamp()
3982
+ end_timestamp = time.mktime(end_date.timetuple())
3983
3983
  eday = end_timestamp / Days
3984
3984
  except ValueError as e:
3985
3985
  print(f"Warning: Invalid end date format! Ignoring end date! \n {str(e)}")
@@ -3988,20 +3988,21 @@ def show__earnings(args):
3988
3988
  try:
3989
3989
  start_date = dateutil.parser.parse(str(args.start_date))
3990
3990
  start_date_txt = start_date.isoformat()
3991
- start_timestamp = start_date.timestamp()
3991
+ start_timestamp = time.mktime(start_date.timetuple())
3992
3992
  sday = start_timestamp / Days
3993
- except ValueError as e:
3993
+ except ValueError:
3994
3994
  print(f"Warning: Invalid start date format! Ignoring start date! \n {str(e)}")
3995
3995
 
3996
+
3997
+
3996
3998
  req_url = apiurl(args, "/users/me/machine-earnings", {"owner": "me", "sday": sday, "eday": eday, "machid" :args.machine_id});
3997
3999
  r = http_get(args, req_url)
3998
4000
  r.raise_for_status()
3999
4001
  rows = r.json()
4000
4002
 
4001
- if args.raw:
4002
- return rows
4003
4003
  print(json.dumps(rows, indent=1, sort_keys=True))
4004
4004
 
4005
+
4005
4006
  def sum(X, k):
4006
4007
  y = 0
4007
4008
  for x in X:
@@ -348,21 +348,18 @@ class VastAI(VastAIBase):
348
348
 
349
349
  sig = getattr(func, "mysignature", None)
350
350
  sig_help = getattr(func, "mysignature_help", None)
351
-
352
351
  if sig:
353
352
  wrapper.__signature__, docappend = self.generate_signature_from_argparse(sig)
353
+ epi = None
354
354
 
355
- # append epilog if exists
356
- if getattr(sig, "epilog", None):
357
- wrapper.__doc__ = f"{wrapper.__doc__.rstrip()}\n\n{sig.epilog.strip()}\n"
358
-
359
- # if no epilog or func docstring, fall back to parser help text
360
- elif sig_help and not hasDoc:
361
- wrapper.__doc__ += f"\n\n{sig_help}"
362
-
363
- # finally append the arg details
364
- wrapper.__doc__ = f"{wrapper.__doc__.rstrip()}\n\n{docappend}"
355
+ if sig.epilog:
356
+ epi = re.sub('Example.?:.*', '', sig.epilog, flags=re.DOTALL|re.M).strip()
357
+ wrapper.__doc__ += epi
365
358
 
359
+ if not (epi or hasDoc) and sig_help:
360
+ wrapper.__doc__ += sig_help
361
+
362
+ wrapper.__doc__ = '\n\n'.join([ wrapper.__doc__.rstrip(), docappend ])
366
363
  return wrapper
367
364
 
368
365
  def credentials_on_disk(self):
File without changes