truss 0.10.9rc601__py3-none-any.whl → 0.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/base/constants.py +0 -1
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
- truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
- truss/cli/train_commands.py +11 -3
- truss/contexts/image_builder/cache_warmer.py +1 -3
- truss/contexts/image_builder/serving_image_builder.py +24 -32
- truss/remote/baseten/api.py +11 -0
- truss/remote/baseten/core.py +209 -1
- truss/remote/baseten/utils/time.py +15 -0
- truss/templates/server/model_wrapper.py +0 -12
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server/truss_server.py +0 -13
- truss/templates/server.Dockerfile.jinja +1 -1
- truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
- truss/tests/remote/baseten/conftest.py +18 -0
- truss/tests/remote/baseten/test_api.py +49 -14
- truss/tests/remote/baseten/test_core.py +517 -1
- truss/tests/test_data/test_openai/model/model.py +0 -3
- truss/truss_handle/truss_handle.py +0 -1
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/METADATA +2 -2
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/RECORD +30 -28
- truss_train/definitions.py +6 -0
- truss_train/deployment.py +15 -2
- truss/tests/util/test_basetenpointer.py +0 -227
- truss/util/basetenpointer.py +0 -160
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/WHEEL +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/entry_points.txt +0 -0
- {truss-0.10.9rc601.dist-info → truss-0.10.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,16 +1,23 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from tempfile import NamedTemporaryFile
|
|
3
|
+
from unittest import mock
|
|
3
4
|
from unittest.mock import MagicMock
|
|
4
5
|
|
|
5
6
|
import pytest
|
|
7
|
+
import requests
|
|
6
8
|
|
|
7
9
|
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
|
|
8
10
|
from truss.base.errors import ValidationError
|
|
9
11
|
from truss.remote.baseten import core
|
|
10
12
|
from truss.remote.baseten import custom_types as b10_types
|
|
11
13
|
from truss.remote.baseten.api import BasetenApi
|
|
12
|
-
from truss.remote.baseten.core import create_truss_service
|
|
14
|
+
from truss.remote.baseten.core import (
|
|
15
|
+
MAX_BATCH_SIZE,
|
|
16
|
+
create_truss_service,
|
|
17
|
+
get_training_job_logs_with_pagination,
|
|
18
|
+
)
|
|
13
19
|
from truss.remote.baseten.error import ApiError
|
|
20
|
+
from truss.remote.baseten.utils.time import iso_to_millis
|
|
14
21
|
|
|
15
22
|
|
|
16
23
|
def test_exists_model():
|
|
@@ -245,3 +252,512 @@ def test_validate_truss_config():
|
|
|
245
252
|
match="Validation failed with the following errors:\n error\n and another one",
|
|
246
253
|
):
|
|
247
254
|
core.validate_truss_config_against_backend(api, {"should_error": "hi"})
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_get_training_job_logs_with_pagination_single_batch(baseten_api):
|
|
258
|
+
"""Test pagination when all logs fit in a single batch"""
|
|
259
|
+
# Mock logs data
|
|
260
|
+
now_as_iso = "2022-01-01T00:00:00Z"
|
|
261
|
+
now_as_millis = iso_to_millis(now_as_iso)
|
|
262
|
+
mock_logs = [
|
|
263
|
+
{"timestamp": now_as_millis, "message": "Log 1"},
|
|
264
|
+
{"timestamp": now_as_millis + 60000, "message": "Log 2"},
|
|
265
|
+
{"timestamp": now_as_millis + 120000, "message": "Log 3"},
|
|
266
|
+
]
|
|
267
|
+
|
|
268
|
+
# Mock the _fetch_log_batch method to return logs on first call, empty on second
|
|
269
|
+
mock_fetch = mock.Mock(side_effect=[mock_logs, []])
|
|
270
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
271
|
+
|
|
272
|
+
# Mock get_training_job method
|
|
273
|
+
baseten_api.get_training_job = mock.Mock(
|
|
274
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
result = get_training_job_logs_with_pagination(
|
|
278
|
+
baseten_api, "project-123", "job-456", batch_size=5
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
# Verify the result
|
|
282
|
+
assert result == mock_logs
|
|
283
|
+
assert len(result) == 3
|
|
284
|
+
|
|
285
|
+
# Verify the mock was called twice (once for logs, once for empty batch)
|
|
286
|
+
assert mock_fetch.call_count == 2
|
|
287
|
+
|
|
288
|
+
# Verify the first call parameters
|
|
289
|
+
first_call_args = mock_fetch.call_args_list[0]
|
|
290
|
+
assert first_call_args[0][0] == "project-123" # project_id
|
|
291
|
+
assert first_call_args[0][1] == "job-456" # job_id
|
|
292
|
+
|
|
293
|
+
# Verify the query body contains expected parameters
|
|
294
|
+
query_params = first_call_args[0][2] # query_params
|
|
295
|
+
assert query_params["limit"] == 5 # batch_size
|
|
296
|
+
assert query_params["direction"] == "asc"
|
|
297
|
+
assert "start_epoch_millis" in query_params
|
|
298
|
+
assert "end_epoch_millis" in query_params
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def test_get_training_job_logs_with_pagination_multiple_batches(baseten_api):
|
|
302
|
+
"""Test pagination when logs span multiple batches"""
|
|
303
|
+
# Mock logs data for multiple batches
|
|
304
|
+
batch1_logs = [
|
|
305
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"}, # 2022-01-01 00:00:00
|
|
306
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"}, # 2022-01-01 00:01:00
|
|
307
|
+
]
|
|
308
|
+
batch2_logs = [
|
|
309
|
+
{"timestamp": "1640995320000000000", "message": "Log 3"}, # 2022-01-01 00:02:00
|
|
310
|
+
{"timestamp": "1640995380000000000", "message": "Log 4"}, # 2022-01-01 00:03:00
|
|
311
|
+
]
|
|
312
|
+
batch3_logs = [
|
|
313
|
+
{"timestamp": "1640995440000000000", "message": "Log 5"} # 2022-01-01 00:04:00
|
|
314
|
+
]
|
|
315
|
+
|
|
316
|
+
# Mock the _fetch_log_batch method directly
|
|
317
|
+
mock_fetch = mock.Mock(side_effect=[batch1_logs, batch2_logs, batch3_logs, []])
|
|
318
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
319
|
+
|
|
320
|
+
# Mock get_training_job method
|
|
321
|
+
baseten_api.get_training_job = mock.Mock(
|
|
322
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
323
|
+
)
|
|
324
|
+
|
|
325
|
+
result = get_training_job_logs_with_pagination(
|
|
326
|
+
baseten_api, "project-123", "job-456", batch_size=2
|
|
327
|
+
)
|
|
328
|
+
|
|
329
|
+
# Verify the result
|
|
330
|
+
expected_logs = batch1_logs + batch2_logs + batch3_logs
|
|
331
|
+
assert result == expected_logs
|
|
332
|
+
assert len(result) == 5
|
|
333
|
+
|
|
334
|
+
# Verify the API calls
|
|
335
|
+
assert mock_fetch.call_count == 4 # 3 batches + 1 empty batch to stop
|
|
336
|
+
|
|
337
|
+
# Verify first call parameters
|
|
338
|
+
first_call_args = mock_fetch.call_args_list[0]
|
|
339
|
+
assert first_call_args[0][0] == "project-123" # project_id
|
|
340
|
+
assert first_call_args[0][1] == "job-456" # job_id
|
|
341
|
+
|
|
342
|
+
# Verify the query body contains expected parameters
|
|
343
|
+
query_params = first_call_args[0][2] # query_params
|
|
344
|
+
assert query_params["limit"] == 2
|
|
345
|
+
assert query_params["direction"] == "asc"
|
|
346
|
+
assert "start_epoch_millis" in query_params
|
|
347
|
+
assert "end_epoch_millis" in query_params
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def test_get_training_job_logs_with_pagination_empty_response(baseten_api):
|
|
351
|
+
"""Test pagination when no logs are returned"""
|
|
352
|
+
# Mock the _fetch_log_batch method directly
|
|
353
|
+
mock_fetch = mock.Mock(return_value=[])
|
|
354
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
355
|
+
|
|
356
|
+
# Mock get_training_job method
|
|
357
|
+
baseten_api.get_training_job = mock.Mock(
|
|
358
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
result = get_training_job_logs_with_pagination(
|
|
362
|
+
baseten_api, "project-123", "job-456"
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
# Verify the result
|
|
366
|
+
assert result == []
|
|
367
|
+
assert len(result) == 0
|
|
368
|
+
|
|
369
|
+
# Verify the API call
|
|
370
|
+
mock_fetch.assert_called_once()
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def test_get_training_job_logs_with_pagination_partial_batch(baseten_api):
|
|
374
|
+
"""Test pagination when the last batch has fewer logs than batch_size"""
|
|
375
|
+
batch1_logs = [
|
|
376
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
377
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
378
|
+
]
|
|
379
|
+
batch2_logs = [{"timestamp": "1640995320000000000", "message": "Log 3"}]
|
|
380
|
+
|
|
381
|
+
# Mock the _fetch_log_batch method directly
|
|
382
|
+
mock_fetch = mock.Mock(side_effect=[batch1_logs, batch2_logs, []])
|
|
383
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
384
|
+
|
|
385
|
+
# Mock get_training_job method
|
|
386
|
+
baseten_api.get_training_job = mock.Mock(
|
|
387
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
result = get_training_job_logs_with_pagination(
|
|
391
|
+
baseten_api, "project-123", "job-456", batch_size=2
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Verify the result
|
|
395
|
+
expected_logs = batch1_logs + batch2_logs
|
|
396
|
+
assert result == expected_logs
|
|
397
|
+
assert len(result) == 3
|
|
398
|
+
|
|
399
|
+
# Verify only 3 API calls (2 batches + 1 empty batch to stop)
|
|
400
|
+
assert mock_fetch.call_count == 3
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def test_get_training_job_logs_with_pagination_max_iterations(baseten_api):
|
|
404
|
+
"""Test pagination when maximum iterations are reached"""
|
|
405
|
+
# Mock logs that would cause infinite pagination
|
|
406
|
+
batch_logs = [
|
|
407
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
408
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
409
|
+
]
|
|
410
|
+
|
|
411
|
+
# Mock the _fetch_log_batch method directly
|
|
412
|
+
# Configure mock to always return the same batch (simulating infinite pagination)
|
|
413
|
+
mock_fetch = mock.Mock(return_value=batch_logs)
|
|
414
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
415
|
+
|
|
416
|
+
# Mock get_training_job method
|
|
417
|
+
baseten_api.get_training_job = mock.Mock(
|
|
418
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
result = get_training_job_logs_with_pagination(
|
|
422
|
+
baseten_api, "project-123", "job-456", batch_size=2
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
# Verify the result (should have MAX_ITERATIONS * batch_size logs)
|
|
426
|
+
from truss.remote.baseten.core import MAX_ITERATIONS
|
|
427
|
+
|
|
428
|
+
expected_log_count = MAX_ITERATIONS * 2
|
|
429
|
+
assert len(result) == expected_log_count
|
|
430
|
+
|
|
431
|
+
# Verify MAX_ITERATIONS API calls were made
|
|
432
|
+
assert mock_fetch.call_count == MAX_ITERATIONS
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def test_get_training_job_logs_with_pagination_api_error(baseten_api):
|
|
436
|
+
"""Test pagination when API returns an error"""
|
|
437
|
+
# Mock the _fetch_log_batch method directly
|
|
438
|
+
# Configure mock to raise an exception
|
|
439
|
+
mock_fetch = mock.Mock(side_effect=Exception("API Error"))
|
|
440
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
441
|
+
|
|
442
|
+
# Mock get_training_job method
|
|
443
|
+
baseten_api.get_training_job = mock.Mock(
|
|
444
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
result = get_training_job_logs_with_pagination(
|
|
448
|
+
baseten_api, "project-123", "job-456"
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
# Verify the result is empty when error occurs
|
|
452
|
+
assert result == []
|
|
453
|
+
assert len(result) == 0
|
|
454
|
+
|
|
455
|
+
# Verify the API call was attempted
|
|
456
|
+
mock_fetch.assert_called_once()
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def test_get_training_job_logs_with_pagination_custom_batch_size(baseten_api):
|
|
460
|
+
"""Test pagination with custom batch size"""
|
|
461
|
+
mock_logs = [
|
|
462
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
463
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
464
|
+
]
|
|
465
|
+
|
|
466
|
+
# Mock the _fetch_log_batch method directly
|
|
467
|
+
mock_fetch = mock.Mock(side_effect=[mock_logs, []])
|
|
468
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
469
|
+
|
|
470
|
+
# Mock get_training_job method
|
|
471
|
+
baseten_api.get_training_job = mock.Mock(
|
|
472
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
473
|
+
)
|
|
474
|
+
|
|
475
|
+
result = get_training_job_logs_with_pagination(
|
|
476
|
+
baseten_api, "project-123", "job-456", batch_size=50
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
# Verify the result
|
|
480
|
+
assert result == mock_logs
|
|
481
|
+
|
|
482
|
+
# Verify the API call used custom batch size
|
|
483
|
+
call_args = mock_fetch.call_args
|
|
484
|
+
query_params = call_args[0][2] # query_params
|
|
485
|
+
assert query_params["limit"] == 50
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def test_get_training_job_logs_with_pagination_six_batches(baseten_api):
|
|
489
|
+
"""Test pagination with six batches"""
|
|
490
|
+
iso_time = "2022-01-01T00:00:00Z"
|
|
491
|
+
now_as_millis = iso_to_millis(iso_time)
|
|
492
|
+
mock_logs_batch_1 = [
|
|
493
|
+
{"timestamp": now_as_millis + 1000, "message": "Log 1"},
|
|
494
|
+
{"timestamp": now_as_millis + 2000, "message": "Log 2"},
|
|
495
|
+
]
|
|
496
|
+
mock_logs_batch_2 = [
|
|
497
|
+
{"timestamp": now_as_millis + 3000, "message": "Log 3"},
|
|
498
|
+
{"timestamp": now_as_millis + 4000, "message": "Log 4"},
|
|
499
|
+
]
|
|
500
|
+
mock_logs_batch_3 = [{"timestamp": now_as_millis + 5000, "message": "Log 5"}]
|
|
501
|
+
|
|
502
|
+
# Mock the _fetch_log_batch method directly
|
|
503
|
+
mock_fetch = mock.Mock(
|
|
504
|
+
side_effect=[mock_logs_batch_1, mock_logs_batch_2, mock_logs_batch_3, []]
|
|
505
|
+
)
|
|
506
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
507
|
+
|
|
508
|
+
# Mock get_training_job method
|
|
509
|
+
baseten_api.get_training_job = mock.Mock(
|
|
510
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
result = get_training_job_logs_with_pagination(
|
|
514
|
+
baseten_api, "project-123", "job-456", batch_size=2
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
# Verify the result
|
|
518
|
+
assert result == mock_logs_batch_1 + mock_logs_batch_2 + mock_logs_batch_3
|
|
519
|
+
assert mock_fetch.call_count == 4 # 3 batches + 1 empty batch to stop
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def test_get_training_job_logs_with_pagination_timestamp_conversion(baseten_api):
|
|
523
|
+
"""Test that timestamp conversion from nanoseconds to milliseconds works correctly"""
|
|
524
|
+
batch1_logs = [
|
|
525
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"} # 1640995200000 ms
|
|
526
|
+
]
|
|
527
|
+
batch2_logs = [
|
|
528
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"} # 1640995260000 ms
|
|
529
|
+
]
|
|
530
|
+
|
|
531
|
+
# Mock the _fetch_log_batch method directly
|
|
532
|
+
mock_fetch = mock.Mock(side_effect=[batch1_logs, batch2_logs, []])
|
|
533
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
534
|
+
|
|
535
|
+
# Mock get_training_job method
|
|
536
|
+
baseten_api.get_training_job = mock.Mock(
|
|
537
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
result = get_training_job_logs_with_pagination(
|
|
541
|
+
baseten_api, "project-123", "job-456", batch_size=1
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
# Verify the result
|
|
545
|
+
expected_logs = batch1_logs + batch2_logs
|
|
546
|
+
assert result == expected_logs
|
|
547
|
+
|
|
548
|
+
# Verify the second call uses correct timestamp conversion
|
|
549
|
+
second_call = mock_fetch.call_args_list[1]
|
|
550
|
+
query_params = second_call[0][2] # query_params
|
|
551
|
+
# Should be 1640995200000 + 1 = 1640995200001
|
|
552
|
+
assert query_params["start_epoch_millis"] == 1640995200001
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def test_get_training_job_logs_with_pagination_query_body_filtering(baseten_api):
|
|
556
|
+
"""Test that None values are properly filtered from query body"""
|
|
557
|
+
mock_logs = [{"timestamp": "1640995200000000000", "message": "Log 1"}]
|
|
558
|
+
|
|
559
|
+
# Mock the _fetch_log_batch method directly
|
|
560
|
+
mock_fetch = mock.Mock(side_effect=[mock_logs, []])
|
|
561
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
562
|
+
|
|
563
|
+
# Mock get_training_job method
|
|
564
|
+
baseten_api.get_training_job = mock.Mock(
|
|
565
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
566
|
+
)
|
|
567
|
+
|
|
568
|
+
get_training_job_logs_with_pagination(baseten_api, "project-123", "job-456")
|
|
569
|
+
|
|
570
|
+
# Verify the API call
|
|
571
|
+
call_args = mock_fetch.call_args
|
|
572
|
+
query_params = call_args[0][2] # query_params
|
|
573
|
+
|
|
574
|
+
# Verify that all required values are included in the query body
|
|
575
|
+
assert "start_epoch_millis" in query_params
|
|
576
|
+
assert "end_epoch_millis" in query_params
|
|
577
|
+
assert "limit" in query_params
|
|
578
|
+
assert "direction" in query_params
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# Tests for new helper methods
|
|
582
|
+
def test_build_log_query_params(baseten_api):
|
|
583
|
+
"""Test _build_log_query_params helper method"""
|
|
584
|
+
from truss.remote.baseten.core import _build_log_query_params
|
|
585
|
+
|
|
586
|
+
# Test with all parameters
|
|
587
|
+
query_params = _build_log_query_params(
|
|
588
|
+
start_time=1640995200000, end_time=1640995260000, batch_size=100
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
expected = {
|
|
592
|
+
"start_epoch_millis": 1640995200000,
|
|
593
|
+
"end_epoch_millis": 1640995260000,
|
|
594
|
+
"limit": 100,
|
|
595
|
+
"direction": "asc",
|
|
596
|
+
}
|
|
597
|
+
assert query_params == expected
|
|
598
|
+
|
|
599
|
+
# Test with None values (should be filtered out)
|
|
600
|
+
query_params = _build_log_query_params(
|
|
601
|
+
start_time=None, end_time=None, batch_size=50
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
expected = {"limit": 50, "direction": "asc"}
|
|
605
|
+
assert query_params == expected
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def test_handle_server_error_backoff(baseten_api):
|
|
609
|
+
"""Test _handle_server_error_backoff helper method"""
|
|
610
|
+
from truss.remote.baseten.core import _handle_server_error_backoff
|
|
611
|
+
|
|
612
|
+
# Create a mock HTTP error
|
|
613
|
+
mock_response = mock.Mock()
|
|
614
|
+
mock_response.status_code = 500
|
|
615
|
+
|
|
616
|
+
mock_error = requests.HTTPError("Server Error")
|
|
617
|
+
mock_error.response = mock_response
|
|
618
|
+
|
|
619
|
+
# Test backoff behavior
|
|
620
|
+
new_batch_size = _handle_server_error_backoff(mock_error, "job-456", 1, 1000)
|
|
621
|
+
|
|
622
|
+
# Should reduce batch size by half
|
|
623
|
+
assert new_batch_size == 500
|
|
624
|
+
|
|
625
|
+
# Test minimum batch size
|
|
626
|
+
new_batch_size = _handle_server_error_backoff(mock_error, "job-456", 2, 150)
|
|
627
|
+
|
|
628
|
+
# Should not go below 100
|
|
629
|
+
assert new_batch_size == 100
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def test_process_batch_logs_continue(baseten_api):
|
|
633
|
+
"""Test _process_batch_logs when pagination should continue"""
|
|
634
|
+
from truss.remote.baseten.core import _process_batch_logs
|
|
635
|
+
|
|
636
|
+
batch_logs = [
|
|
637
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
638
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
639
|
+
]
|
|
640
|
+
|
|
641
|
+
should_continue, next_start_time, next_end_time = _process_batch_logs(
|
|
642
|
+
batch_logs, "job-456", 1, 2
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
assert should_continue is True
|
|
646
|
+
# Should be 1640995260000 + 1 = 1640995260001 (last timestamp + 1ms)
|
|
647
|
+
assert next_start_time == 1640995260001
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def test_process_batch_logs_empty(baseten_api):
|
|
651
|
+
"""Test _process_batch_logs when batch is empty"""
|
|
652
|
+
from truss.remote.baseten.core import _process_batch_logs
|
|
653
|
+
|
|
654
|
+
should_continue, next_start_time, next_end_time = _process_batch_logs(
|
|
655
|
+
[], "job-456", 1, 100
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
assert should_continue is False
|
|
659
|
+
assert next_start_time is None
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def test_process_batch_logs_partial(baseten_api):
|
|
663
|
+
"""Test _process_batch_logs when batch is smaller than expected"""
|
|
664
|
+
from truss.remote.baseten.core import _process_batch_logs
|
|
665
|
+
|
|
666
|
+
batch_logs = [{"timestamp": "1640995200000000000", "message": "Log 1"}]
|
|
667
|
+
|
|
668
|
+
should_continue, next_start_time, next_end_time = _process_batch_logs(
|
|
669
|
+
batch_logs, "job-456", 1, 100
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
assert should_continue is True
|
|
673
|
+
assert next_start_time is not None
|
|
674
|
+
assert next_end_time is not None
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def test_get_training_job_logs_with_pagination_server_error_retry(baseten_api):
|
|
678
|
+
"""Test pagination with server error retry logic"""
|
|
679
|
+
batch_logs = [
|
|
680
|
+
{"timestamp": "1640995200000000000", "message": "Log 1"},
|
|
681
|
+
{"timestamp": "1640995260000000000", "message": "Log 2"},
|
|
682
|
+
]
|
|
683
|
+
|
|
684
|
+
# Mock the _fetch_log_batch method directly
|
|
685
|
+
# First call fails with 500, second call succeeds
|
|
686
|
+
mock_response_500 = mock.Mock()
|
|
687
|
+
mock_response_500.status_code = 500
|
|
688
|
+
mock_error_500 = requests.HTTPError("Server Error")
|
|
689
|
+
mock_error_500.response = mock_response_500
|
|
690
|
+
|
|
691
|
+
mock_fetch = mock.Mock(
|
|
692
|
+
side_effect=[
|
|
693
|
+
mock_error_500, # First call fails
|
|
694
|
+
batch_logs, # Second call succeeds
|
|
695
|
+
]
|
|
696
|
+
)
|
|
697
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
698
|
+
|
|
699
|
+
# Mock get_training_job method
|
|
700
|
+
baseten_api.get_training_job = mock.Mock(
|
|
701
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
result = get_training_job_logs_with_pagination(
|
|
705
|
+
baseten_api, "project-123", "job-456", batch_size=1000
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
# Should get the logs after retry
|
|
709
|
+
assert result == batch_logs
|
|
710
|
+
|
|
711
|
+
# Should have made 3 calls (first fails, retry with reduced batch size, then succeeds)
|
|
712
|
+
assert mock_fetch.call_count == 3
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def test_get_training_job_logs_with_pagination_non_server_error(baseten_api):
|
|
716
|
+
"""Test pagination with non-server error (should not retry)"""
|
|
717
|
+
# Mock the _fetch_log_batch method directly
|
|
718
|
+
|
|
719
|
+
# Mock a 400 error (client error, not server error)
|
|
720
|
+
mock_response_400 = mock.Mock()
|
|
721
|
+
mock_response_400.status_code = 400
|
|
722
|
+
mock_error_400 = requests.HTTPError("Bad Request")
|
|
723
|
+
mock_error_400.response = mock_response_400
|
|
724
|
+
|
|
725
|
+
mock_fetch = mock.Mock(side_effect=mock_error_400)
|
|
726
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
727
|
+
|
|
728
|
+
# Mock get_training_job method
|
|
729
|
+
baseten_api.get_training_job = mock.Mock(
|
|
730
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
result = get_training_job_logs_with_pagination(
|
|
734
|
+
baseten_api, "project-123", "job-456"
|
|
735
|
+
)
|
|
736
|
+
|
|
737
|
+
# Should return empty list on non-server error
|
|
738
|
+
assert result == []
|
|
739
|
+
|
|
740
|
+
# Should have made only 1 call (no retry)
|
|
741
|
+
assert mock_fetch.call_count == 1
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
|
|
745
|
+
"""Test that default batch size is MAX_BATCH_SIZE"""
|
|
746
|
+
mock_logs = [{"timestamp": "1640995200000000000", "message": "Log 1"}]
|
|
747
|
+
|
|
748
|
+
# Mock the _fetch_log_batch method directly
|
|
749
|
+
mock_fetch = mock.Mock(side_effect=[mock_logs, []])
|
|
750
|
+
baseten_api._fetch_log_batch = mock_fetch
|
|
751
|
+
|
|
752
|
+
# Mock get_training_job method
|
|
753
|
+
baseten_api.get_training_job = mock.Mock(
|
|
754
|
+
return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
|
|
755
|
+
)
|
|
756
|
+
|
|
757
|
+
get_training_job_logs_with_pagination(baseten_api, "project-123", "job-456")
|
|
758
|
+
|
|
759
|
+
# Verify the API call used default batch size
|
|
760
|
+
call_args = mock_fetch.call_args
|
|
761
|
+
query_params = call_args[0][2] # query_params
|
|
762
|
+
|
|
763
|
+
assert query_params["limit"] == MAX_BATCH_SIZE
|
|
@@ -106,7 +106,6 @@ class DockerURLs:
|
|
|
106
106
|
self.predict_url = f"{base_url}/v1/models/model:predict"
|
|
107
107
|
self.completions_url = f"{base_url}/v1/completions"
|
|
108
108
|
self.chat_completions_url = f"{base_url}/v1/chat/completions"
|
|
109
|
-
self.messages_url = f"{base_url}/v1/messages"
|
|
110
109
|
|
|
111
110
|
self.schema_url = f"{base_url}/v1/models/model/schema"
|
|
112
111
|
self.metrics_url = f"{base_url}/metrics"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: truss
|
|
3
|
-
Version: 0.10.9rc601
|
|
3
|
+
Version: 0.10.10
|
|
4
4
|
Summary: A seamless bridge from model development to model delivery
|
|
5
5
|
Project-URL: Repository, https://github.com/basetenlabs/truss
|
|
6
6
|
Project-URL: Homepage, https://truss.baseten.co
|
|
@@ -37,7 +37,7 @@ Requires-Dist: rich<14,>=13.4.2
|
|
|
37
37
|
Requires-Dist: ruff>=0.4.8
|
|
38
38
|
Requires-Dist: tenacity>=8.0.1
|
|
39
39
|
Requires-Dist: tomlkit>=0.13.2
|
|
40
|
-
Requires-Dist: truss-transfer==0.0.
|
|
40
|
+
Requires-Dist: truss-transfer==0.0.29
|
|
41
41
|
Requires-Dist: watchfiles<0.20,>=0.19.0
|
|
42
42
|
Description-Content-Type: text/markdown
|
|
43
43
|
|