truss 0.10.9rc601__py3-none-any.whl → 0.10.10rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. See the registry's release page for more details.

Files changed (32)
  1. truss/base/constants.py +0 -1
  2. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +30 -22
  3. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +8 -2
  4. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +2 -2
  5. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +63 -0
  6. truss/cli/train/deploy_from_checkpoint_config_whisper.yml +17 -0
  7. truss/cli/train_commands.py +11 -3
  8. truss/contexts/image_builder/cache_warmer.py +1 -3
  9. truss/contexts/image_builder/serving_image_builder.py +24 -32
  10. truss/remote/baseten/api.py +11 -0
  11. truss/remote/baseten/core.py +209 -1
  12. truss/remote/baseten/utils/time.py +15 -0
  13. truss/templates/server/model_wrapper.py +0 -12
  14. truss/templates/server/requirements.txt +1 -1
  15. truss/templates/server/truss_server.py +0 -13
  16. truss/templates/server.Dockerfile.jinja +1 -1
  17. truss/tests/cli/train/test_deploy_checkpoints.py +436 -0
  18. truss/tests/contexts/image_builder/test_serving_image_builder.py +1 -1
  19. truss/tests/remote/baseten/conftest.py +18 -0
  20. truss/tests/remote/baseten/test_api.py +49 -14
  21. truss/tests/remote/baseten/test_core.py +517 -1
  22. truss/tests/test_data/test_openai/model/model.py +0 -3
  23. truss/truss_handle/truss_handle.py +0 -1
  24. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/METADATA +2 -2
  25. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/RECORD +30 -28
  26. truss_train/definitions.py +6 -0
  27. truss_train/deployment.py +15 -2
  28. truss/tests/util/test_basetenpointer.py +0 -227
  29. truss/util/basetenpointer.py +0 -160
  30. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/WHEEL +0 -0
  31. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/entry_points.txt +0 -0
  32. {truss-0.10.9rc601.dist-info → truss-0.10.10rc1.dist-info}/licenses/LICENSE +0 -0
@@ -1,16 +1,23 @@
1
1
  import json
2
2
  from tempfile import NamedTemporaryFile
3
+ from unittest import mock
3
4
  from unittest.mock import MagicMock
4
5
 
5
6
  import pytest
7
+ import requests
6
8
 
7
9
  from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
8
10
  from truss.base.errors import ValidationError
9
11
  from truss.remote.baseten import core
10
12
  from truss.remote.baseten import custom_types as b10_types
11
13
  from truss.remote.baseten.api import BasetenApi
12
- from truss.remote.baseten.core import create_truss_service
14
+ from truss.remote.baseten.core import (
15
+ MAX_BATCH_SIZE,
16
+ create_truss_service,
17
+ get_training_job_logs_with_pagination,
18
+ )
13
19
  from truss.remote.baseten.error import ApiError
20
+ from truss.remote.baseten.utils.time import iso_to_millis
14
21
 
15
22
 
16
23
  def test_exists_model():
@@ -245,3 +252,512 @@ def test_validate_truss_config():
245
252
  match="Validation failed with the following errors:\n error\n and another one",
246
253
  ):
247
254
  core.validate_truss_config_against_backend(api, {"should_error": "hi"})
255
+
256
+
257
def test_get_training_job_logs_with_pagination_single_batch(baseten_api):
    """All logs arrive in one batch; a second, empty batch ends pagination."""
    start_iso = "2022-01-01T00:00:00Z"
    start_millis = iso_to_millis(start_iso)
    logs = [
        {"timestamp": start_millis, "message": "Log 1"},
        {"timestamp": start_millis + 60000, "message": "Log 2"},
        {"timestamp": start_millis + 120000, "message": "Log 3"},
    ]

    # First fetch yields every log; the second yields nothing and stops the loop.
    fetch = mock.Mock(side_effect=[logs, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": start_iso}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=5
    )

    assert logs_out == logs
    assert len(logs_out) == 3
    # One data fetch plus the terminating empty fetch.
    assert fetch.call_count == 2

    # The opening request is positional: (project_id, job_id, query_params).
    project_id, job_id, params = fetch.call_args_list[0].args[:3]
    assert project_id == "project-123"
    assert job_id == "job-456"
    assert params["limit"] == 5
    assert params["direction"] == "asc"
    for key in ("start_epoch_millis", "end_epoch_millis"):
        assert key in params
299
+
300
+
301
def test_get_training_job_logs_with_pagination_multiple_batches(baseten_api):
    """Logs split across three batches are concatenated in fetch order."""
    # Nanosecond timestamp strings, one minute apart from 2022-01-01 00:00:00.
    first = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]
    second = [
        {"timestamp": "1640995320000000000", "message": "Log 3"},
        {"timestamp": "1640995380000000000", "message": "Log 4"},
    ]
    third = [{"timestamp": "1640995440000000000", "message": "Log 5"}]
    batches = [first, second, third]

    # A trailing empty batch tells the paginator there is nothing left.
    fetch = mock.Mock(side_effect=[*batches, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=2
    )

    expected = [entry for batch in batches for entry in batch]
    assert logs_out == expected
    assert len(logs_out) == 5
    # Three data batches plus the empty terminator.
    assert fetch.call_count == len(batches) + 1

    opening_call = fetch.call_args_list[0]
    assert opening_call.args[0] == "project-123"
    assert opening_call.args[1] == "job-456"

    params = opening_call.args[2]
    assert params["limit"] == 2
    assert params["direction"] == "asc"
    assert "start_epoch_millis" in params
    assert "end_epoch_millis" in params
348
+
349
+
350
def test_get_training_job_logs_with_pagination_empty_response(baseten_api):
    """An empty first batch yields no logs after exactly one fetch."""
    fetch = mock.Mock(return_value=[])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456"
    )

    assert logs_out == []
    assert len(logs_out) == 0
    fetch.assert_called_once()
371
+
372
+
373
def test_get_training_job_logs_with_pagination_partial_batch(baseten_api):
    """A final batch shorter than batch_size is still collected before stopping."""
    full_batch = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]
    short_batch = [{"timestamp": "1640995320000000000", "message": "Log 3"}]

    fetch = mock.Mock(side_effect=[full_batch, short_batch, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=2
    )

    assert logs_out == full_batch + short_batch
    assert len(logs_out) == 3
    # Two data batches plus the empty terminator.
    assert fetch.call_count == 3
401
+
402
+
403
def test_get_training_job_logs_with_pagination_max_iterations(baseten_api):
    """Pagination is capped at MAX_ITERATIONS when the API never returns an empty batch."""
    from truss.remote.baseten.core import MAX_ITERATIONS

    repeated_batch = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]

    # Returning the same non-empty batch forever would page endlessly without the cap.
    fetch = mock.Mock(return_value=repeated_batch)
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=2
    )

    # Each of the capped iterations contributes one copy of the two-entry batch.
    assert len(logs_out) == MAX_ITERATIONS * len(repeated_batch)
    assert fetch.call_count == MAX_ITERATIONS
433
+
434
+
435
def test_get_training_job_logs_with_pagination_api_error(baseten_api):
    """A generic exception from the fetch yields an empty result, with no retry."""
    fetch = mock.Mock(side_effect=Exception("API Error"))
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456"
    )

    assert logs_out == []
    assert len(logs_out) == 0
    # Only the single failing attempt is made.
    fetch.assert_called_once()
457
+
458
+
459
def test_get_training_job_logs_with_pagination_custom_batch_size(baseten_api):
    """A caller-supplied batch_size is forwarded as the query "limit"."""
    two_logs = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]

    fetch = mock.Mock(side_effect=[two_logs, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=50
    )

    assert logs_out == two_logs

    # call_args is the most recent (terminating) fetch.
    last_params = fetch.call_args.args[2]
    assert last_params["limit"] == 50
486
+
487
+
488
def test_get_training_job_logs_with_pagination_six_batches(baseten_api):
    """Test pagination across three non-empty batches followed by an empty one.

    Despite the function name (kept for stable test IDs), only three batches
    carry logs (2 + 2 + 1 entries). A fourth, empty batch terminates
    pagination, so four fetches are made in total. Timestamps here are
    millisecond integers offset from 2022-01-01T00:00:00Z.
    """
    iso_time = "2022-01-01T00:00:00Z"
    now_as_millis = iso_to_millis(iso_time)
    mock_logs_batch_1 = [
        {"timestamp": now_as_millis + 1000, "message": "Log 1"},
        {"timestamp": now_as_millis + 2000, "message": "Log 2"},
    ]
    mock_logs_batch_2 = [
        {"timestamp": now_as_millis + 3000, "message": "Log 3"},
        {"timestamp": now_as_millis + 4000, "message": "Log 4"},
    ]
    mock_logs_batch_3 = [{"timestamp": now_as_millis + 5000, "message": "Log 5"}]

    # Three data batches, then an empty batch to stop pagination.
    mock_fetch = mock.Mock(
        side_effect=[mock_logs_batch_1, mock_logs_batch_2, mock_logs_batch_3, []]
    )
    baseten_api._fetch_log_batch = mock_fetch

    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    result = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=2
    )

    # All five logs, in fetch order.
    assert result == mock_logs_batch_1 + mock_logs_batch_2 + mock_logs_batch_3
    assert mock_fetch.call_count == 4  # 3 batches + 1 empty batch to stop
520
+
521
+
522
def test_get_training_job_logs_with_pagination_timestamp_conversion(baseten_api):
    """The next page starts 1 ms after the last log's nanosecond timestamp, in millis."""
    first_ns = "1640995200000000000"  # 1640995200000 ms
    second_ns = "1640995260000000000"  # 1640995260000 ms
    batch_a = [{"timestamp": first_ns, "message": "Log 1"}]
    batch_b = [{"timestamp": second_ns, "message": "Log 2"}]

    fetch = mock.Mock(side_effect=[batch_a, batch_b, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=1
    )

    assert logs_out == batch_a + batch_b

    # Second request: 1640995200000 ms (derived from first_ns) plus one millisecond.
    follow_up_params = fetch.call_args_list[1].args[2]
    assert follow_up_params["start_epoch_millis"] == 1640995200001
553
+
554
+
555
def test_get_training_job_logs_with_pagination_query_body_filtering(baseten_api):
    """Test that every query body has the required keys and no None values.

    Checking values as well as keys is what actually verifies filtering. A key
    mapped to None would pass a presence-only check.
    """
    mock_logs = [{"timestamp": "1640995200000000000", "message": "Log 1"}]

    mock_fetch = mock.Mock(side_effect=[mock_logs, []])
    baseten_api._fetch_log_batch = mock_fetch

    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    get_training_job_logs_with_pagination(baseten_api, "project-123", "job-456")

    assert mock_fetch.call_args_list, "expected at least one fetch"
    # Inspect every request, not just the last, so a None leaking into any page is caught.
    for fetch_call in mock_fetch.call_args_list:
        query_params = fetch_call.args[2]  # query_params
        for key in ("start_epoch_millis", "end_epoch_millis", "limit", "direction"):
            assert key in query_params
        assert all(value is not None for value in query_params.values())
579
+
580
+
581
+ # Tests for new helper methods
582
def test_build_log_query_params(baseten_api):
    """_build_log_query_params maps its arguments to query keys and omits None bounds."""
    from truss.remote.baseten.core import _build_log_query_params

    cases = [
        # Both bounds supplied: every key is present.
        (
            {"start_time": 1640995200000, "end_time": 1640995260000, "batch_size": 100},
            {
                "start_epoch_millis": 1640995200000,
                "end_epoch_millis": 1640995260000,
                "limit": 100,
                "direction": "asc",
            },
        ),
        # Both bounds None: the time keys are dropped entirely.
        (
            {"start_time": None, "end_time": None, "batch_size": 50},
            {"limit": 50, "direction": "asc"},
        ),
    ]
    for kwargs, expected in cases:
        assert _build_log_query_params(**kwargs) == expected
606
+
607
+
608
def test_handle_server_error_backoff(baseten_api):
    """_handle_server_error_backoff halves the batch size but never goes below 100."""
    from truss.remote.baseten.core import _handle_server_error_backoff

    server_error = requests.HTTPError("Server Error")
    server_error.response = mock.Mock()
    server_error.response.status_code = 500

    # Halving: 1000 -> 500.
    assert _handle_server_error_backoff(server_error, "job-456", 1, 1000) == 500
    # Floor: half of 150 would be 75, which is clamped up to 100.
    assert _handle_server_error_backoff(server_error, "job-456", 2, 150) == 100
630
+
631
+
632
def test_process_batch_logs_continue(baseten_api):
    """A non-empty batch keeps pagination going and moves the start past the last log."""
    from truss.remote.baseten.core import _process_batch_logs

    logs = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]

    keep_going, next_start, _ = _process_batch_logs(logs, "job-456", 1, 2)

    assert keep_going is True
    # Last entry: 1640995260000000000 ns -> 1640995260000 ms, plus one millisecond.
    assert next_start == 1640995260001
648
+
649
+
650
def test_process_batch_logs_empty(baseten_api):
    """An empty batch stops pagination and yields no next start time."""
    from truss.remote.baseten.core import _process_batch_logs

    keep_going, next_start, _ = _process_batch_logs([], "job-456", 1, 100)

    assert keep_going is False
    assert next_start is None
660
+
661
+
662
def test_process_batch_logs_partial(baseten_api):
    """A batch smaller than requested still continues and yields both next bounds."""
    from truss.remote.baseten.core import _process_batch_logs

    single_log = [{"timestamp": "1640995200000000000", "message": "Log 1"}]

    keep_going, next_start, next_end = _process_batch_logs(
        single_log, "job-456", 1, 100
    )

    assert keep_going is True
    assert next_start is not None
    assert next_end is not None
675
+
676
+
677
def test_get_training_job_logs_with_pagination_server_error_retry(baseten_api):
    """Test that a 5xx error from the fetch is retried and pagination then completes.

    Fetch sequence: HTTP 500, then the retried fetch returns the logs, then an
    empty batch ends pagination. The empty terminator is supplied explicitly.
    Otherwise the third call would hit an exhausted ``side_effect`` list, and
    the mock would raise StopIteration. The paginator's generic exception
    handling would swallow that error, so the test would pass only by accident.
    """
    batch_logs = [
        {"timestamp": "1640995200000000000", "message": "Log 1"},
        {"timestamp": "1640995260000000000", "message": "Log 2"},
    ]

    mock_response_500 = mock.Mock()
    mock_response_500.status_code = 500
    mock_error_500 = requests.HTTPError("Server Error")
    mock_error_500.response = mock_response_500

    mock_fetch = mock.Mock(
        side_effect=[
            mock_error_500,  # First call fails with a server error
            batch_logs,  # Retried call succeeds
            [],  # Empty batch signals the end of the logs
        ]
    )
    baseten_api._fetch_log_batch = mock_fetch

    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    result = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456", batch_size=1000
    )

    # The logs from the successful retry are returned.
    assert result == batch_logs

    # 1 failed call + 1 successful retry + 1 terminating empty fetch.
    assert mock_fetch.call_count == 3
713
+
714
+
715
def test_get_training_job_logs_with_pagination_non_server_error(baseten_api):
    """A 4xx client error is not retried: one fetch and an empty result."""
    client_error = requests.HTTPError("Bad Request")
    client_error.response = mock.Mock()
    client_error.response.status_code = 400

    # Every fetch raises the 400 error.
    fetch = mock.Mock(side_effect=client_error)
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    logs_out = get_training_job_logs_with_pagination(
        baseten_api, "project-123", "job-456"
    )

    assert logs_out == []
    assert fetch.call_count == 1
742
+
743
+
744
def test_get_training_job_logs_with_pagination_default_batch_size(baseten_api):
    """Without an explicit batch_size, the request limit defaults to MAX_BATCH_SIZE."""
    one_log = [{"timestamp": "1640995200000000000", "message": "Log 1"}]

    fetch = mock.Mock(side_effect=[one_log, []])
    baseten_api._fetch_log_batch = fetch
    baseten_api.get_training_job = mock.Mock(
        return_value={"training_job": {"created_at": "2022-01-01T00:00:00Z"}}
    )

    get_training_job_logs_with_pagination(baseten_api, "project-123", "job-456")

    # call_args is the most recent (terminating) fetch.
    last_params = fetch.call_args.args[2]
    assert last_params["limit"] == MAX_BATCH_SIZE
@@ -13,6 +13,3 @@ class Model:
13
13
 
14
14
  def predict(self, input: Dict) -> str:
      """Return the constant string "predict"; ``input`` is not used."""
      return "predict"
16
-
17
- def messages(self, input: Dict) -> str:
18
- return "messages"
@@ -106,7 +106,6 @@ class DockerURLs:
106
106
  self.predict_url = f"{base_url}/v1/models/model:predict"
107
107
  self.completions_url = f"{base_url}/v1/completions"
108
108
  self.chat_completions_url = f"{base_url}/v1/chat/completions"
109
- self.messages_url = f"{base_url}/v1/messages"
110
109
 
111
110
  self.schema_url = f"{base_url}/v1/models/model/schema"
112
111
  self.metrics_url = f"{base_url}/metrics"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.10.9rc601
3
+ Version: 0.10.10rc1
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -37,7 +37,7 @@ Requires-Dist: rich<14,>=13.4.2
37
37
  Requires-Dist: ruff>=0.4.8
38
38
  Requires-Dist: tenacity>=8.0.1
39
39
  Requires-Dist: tomlkit>=0.13.2
40
- Requires-Dist: truss-transfer==0.0.27
40
+ Requires-Dist: truss-transfer==0.0.29
41
41
  Requires-Dist: watchfiles<0.20,>=0.19.0
42
42
  Description-Content-Type: text/markdown
43
43