judgeval 0.16.8__tar.gz → 0.17.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of judgeval might be problematic. Click here for more details.

Files changed (165)
  1. {judgeval-0.16.8 → judgeval-0.17.0}/PKG-INFO +2 -3
  2. {judgeval-0.16.8 → judgeval-0.17.0}/README.md +1 -2
  3. {judgeval-0.16.8 → judgeval-0.17.0}/pyproject.toml +1 -1
  4. judgeval-0.17.0/src/judgeval/trainer/__init__.py +14 -0
  5. judgeval-0.17.0/src/judgeval/trainer/base_trainer.py +117 -0
  6. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/trainer/config.py +1 -1
  7. judgeval-0.16.8/src/judgeval/trainer/trainer.py → judgeval-0.17.0/src/judgeval/trainer/fireworks_trainer.py +14 -38
  8. judgeval-0.17.0/src/judgeval/trainer/trainer.py +70 -0
  9. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/version.py +1 -1
  10. judgeval-0.16.8/src/judgeval/trainer/__init__.py +0 -5
  11. {judgeval-0.16.8 → judgeval-0.17.0}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  12. {judgeval-0.16.8 → judgeval-0.17.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  13. {judgeval-0.16.8 → judgeval-0.17.0}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  14. {judgeval-0.16.8 → judgeval-0.17.0}/.github/pull_request_template.md +0 -0
  15. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/blocked-pr.yaml +0 -0
  16. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/ci.yaml +0 -0
  17. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/claude-code-review.yml +0 -0
  18. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/claude.yml +0 -0
  19. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/lint.yaml +0 -0
  20. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/merge-branch-check.yaml +0 -0
  21. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/mypy.yaml +0 -0
  22. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/pre-commit-autoupdate.yaml +0 -0
  23. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/release.yaml +0 -0
  24. {judgeval-0.16.8 → judgeval-0.17.0}/.github/workflows/validate-branch.yaml +0 -0
  25. {judgeval-0.16.8 → judgeval-0.17.0}/.gitignore +0 -0
  26. {judgeval-0.16.8 → judgeval-0.17.0}/.pre-commit-config.yaml +0 -0
  27. {judgeval-0.16.8 → judgeval-0.17.0}/CONTRIBUTING.md +0 -0
  28. {judgeval-0.16.8 → judgeval-0.17.0}/LICENSE.md +0 -0
  29. {judgeval-0.16.8 → judgeval-0.17.0}/assets/Screenshot 2025-05-17 at 8.14.27 PM.png +0 -0
  30. {judgeval-0.16.8 → judgeval-0.17.0}/assets/agent.gif +0 -0
  31. {judgeval-0.16.8 → judgeval-0.17.0}/assets/agent_trace_example.png +0 -0
  32. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/company.jpg +0 -0
  33. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/company_banner.jpg +0 -0
  34. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/darkmode.svg +0 -0
  35. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/full_logo.png +0 -0
  36. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/icon.png +0 -0
  37. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/lightmode.svg +0 -0
  38. {judgeval-0.16.8 → judgeval-0.17.0}/assets/brand/white_background.png +0 -0
  39. {judgeval-0.16.8 → judgeval-0.17.0}/assets/custom_scorer_online_abm.png +0 -0
  40. {judgeval-0.16.8 → judgeval-0.17.0}/assets/data.gif +0 -0
  41. {judgeval-0.16.8 → judgeval-0.17.0}/assets/dataset_clustering_screenshot.png +0 -0
  42. {judgeval-0.16.8 → judgeval-0.17.0}/assets/dataset_clustering_screenshot_dm.png +0 -0
  43. {judgeval-0.16.8 → judgeval-0.17.0}/assets/datasets_preview_screenshot.png +0 -0
  44. {judgeval-0.16.8 → judgeval-0.17.0}/assets/document.gif +0 -0
  45. {judgeval-0.16.8 → judgeval-0.17.0}/assets/error_analysis_dashboard.png +0 -0
  46. {judgeval-0.16.8 → judgeval-0.17.0}/assets/errors.png +0 -0
  47. {judgeval-0.16.8 → judgeval-0.17.0}/assets/experiments_dashboard_screenshot.png +0 -0
  48. {judgeval-0.16.8 → judgeval-0.17.0}/assets/experiments_page.png +0 -0
  49. {judgeval-0.16.8 → judgeval-0.17.0}/assets/experiments_pagev2.png +0 -0
  50. {judgeval-0.16.8 → judgeval-0.17.0}/assets/logo_darkmode.svg +0 -0
  51. {judgeval-0.16.8 → judgeval-0.17.0}/assets/logo_lightmode.svg +0 -0
  52. {judgeval-0.16.8 → judgeval-0.17.0}/assets/monitoring_screenshot.png +0 -0
  53. {judgeval-0.16.8 → judgeval-0.17.0}/assets/online_eval.png +0 -0
  54. {judgeval-0.16.8 → judgeval-0.17.0}/assets/product_shot.png +0 -0
  55. {judgeval-0.16.8 → judgeval-0.17.0}/assets/quickstart_trajectory_ss.png +0 -0
  56. {judgeval-0.16.8 → judgeval-0.17.0}/assets/test.png +0 -0
  57. {judgeval-0.16.8 → judgeval-0.17.0}/assets/tests.png +0 -0
  58. {judgeval-0.16.8 → judgeval-0.17.0}/assets/trace.gif +0 -0
  59. {judgeval-0.16.8 → judgeval-0.17.0}/assets/trace_demo.png +0 -0
  60. {judgeval-0.16.8 → judgeval-0.17.0}/assets/trace_screenshot.png +0 -0
  61. {judgeval-0.16.8 → judgeval-0.17.0}/assets/trace_screenshot_old.png +0 -0
  62. {judgeval-0.16.8 → judgeval-0.17.0}/pytest.ini +0 -0
  63. {judgeval-0.16.8 → judgeval-0.17.0}/scripts/api_generator.py +0 -0
  64. {judgeval-0.16.8 → judgeval-0.17.0}/scripts/openapi_transform.py +0 -0
  65. {judgeval-0.16.8 → judgeval-0.17.0}/scripts/update_types.sh +0 -0
  66. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/__init__.py +0 -0
  67. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/api/__init__.py +0 -0
  68. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/api/api_types.py +0 -0
  69. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/cli.py +0 -0
  70. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/constants.py +0 -0
  71. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/__init__.py +0 -0
  72. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/evaluation_run.py +0 -0
  73. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/example.py +0 -0
  74. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/judgment_types.py +0 -0
  75. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/result.py +0 -0
  76. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/scorer_data.py +0 -0
  77. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/scripts/fix_default_factory.py +0 -0
  78. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/scripts/openapi_transform.py +0 -0
  79. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/data/trace.py +0 -0
  80. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/dataset/__init__.py +0 -0
  81. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/env.py +0 -0
  82. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/evaluation/__init__.py +0 -0
  83. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/exceptions.py +0 -0
  84. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/integrations/langgraph/__init__.py +0 -0
  85. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/integrations/openlit/__init__.py +0 -0
  86. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/judges/__init__.py +0 -0
  87. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/judges/base_judge.py +0 -0
  88. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/judges/litellm_judge.py +0 -0
  89. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/judges/together_judge.py +0 -0
  90. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/judges/utils.py +0 -0
  91. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/logger.py +0 -0
  92. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/__init__.py +0 -0
  93. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/agent_scorer.py +0 -0
  94. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/api_scorer.py +0 -0
  95. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/base_scorer.py +0 -0
  96. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/example_scorer.py +0 -0
  97. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/exceptions.py +0 -0
  98. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/__init__.py +0 -0
  99. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -0
  100. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +0 -0
  101. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +0 -0
  102. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +0 -0
  103. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +0 -0
  104. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +0 -0
  105. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/score.py +0 -0
  106. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/scorers/utils.py +0 -0
  107. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/__init__.py +0 -0
  108. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/constants.py +0 -0
  109. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/exporters/__init__.py +0 -0
  110. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/exporters/s3.py +0 -0
  111. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/exporters/store.py +0 -0
  112. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/exporters/utils.py +0 -0
  113. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/keys.py +0 -0
  114. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/__init__.py +0 -0
  115. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/config.py +0 -0
  116. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/constants.py +0 -0
  117. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_anthropic/__init__.py +0 -0
  118. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_anthropic/config.py +0 -0
  119. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_anthropic/messages.py +0 -0
  120. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_anthropic/messages_stream.py +0 -0
  121. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_anthropic/wrapper.py +0 -0
  122. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_google/__init__.py +0 -0
  123. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_google/config.py +0 -0
  124. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_google/generate_content.py +0 -0
  125. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_google/wrapper.py +0 -0
  126. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/__init__.py +0 -0
  127. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/beta_chat_completions.py +0 -0
  128. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/chat_completions.py +0 -0
  129. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/config.py +0 -0
  130. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/responses.py +0 -0
  131. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_openai/wrapper.py +0 -0
  132. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_together/__init__.py +0 -0
  133. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_together/chat_completions.py +0 -0
  134. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_together/config.py +0 -0
  135. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/llm_together/wrapper.py +0 -0
  136. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/llm/providers.py +0 -0
  137. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/managers.py +0 -0
  138. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/processors/__init__.py +0 -0
  139. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/tracer/utils.py +0 -0
  140. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/trainer/console.py +0 -0
  141. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/trainer/trainable_model.py +0 -0
  142. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/async_utils.py +0 -0
  143. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/decorators/__init__.py +0 -0
  144. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/decorators/dont_throw.py +0 -0
  145. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/decorators/use_once.py +0 -0
  146. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/file_utils.py +0 -0
  147. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/guards.py +0 -0
  148. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/meta.py +0 -0
  149. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/serialize.py +0 -0
  150. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/testing.py +0 -0
  151. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/url.py +0 -0
  152. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/version_check.py +0 -0
  153. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/README.md +0 -0
  154. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/__init__.py +0 -0
  155. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/immutable_wrap_async.py +0 -0
  156. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/immutable_wrap_async_iterator.py +0 -0
  157. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/immutable_wrap_sync.py +0 -0
  158. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/immutable_wrap_sync_iterator.py +0 -0
  159. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/mutable_wrap_async.py +0 -0
  160. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/mutable_wrap_sync.py +0 -0
  161. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/py.typed +0 -0
  162. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/utils/wrappers/utils.py +0 -0
  163. {judgeval-0.16.8 → judgeval-0.17.0}/src/judgeval/warnings.py +0 -0
  164. {judgeval-0.16.8 → judgeval-0.17.0}/update_version.py +0 -0
  165. {judgeval-0.16.8 → judgeval-0.17.0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: judgeval
3
- Version: 0.16.8
3
+ Version: 0.17.0
4
4
  Summary: Judgeval Package
5
5
  Project-URL: Homepage, https://github.com/JudgmentLabs/judgeval
6
6
  Project-URL: Issues, https://github.com/JudgmentLabs/judgeval/issues
@@ -63,8 +63,7 @@ Judgeval's agent monitoring infra provides a simple harness for integrating GRPO
63
63
  await trainer.train(
64
64
  agent_function=your_agent_function, # entry point to your agent
65
65
  scorers=[RewardScorer()], # Custom scorer you define based on task criteria, acts as reward
66
- prompts=training_prompts, # Tasks
67
- rft_provider="fireworks"
66
+ prompts=training_prompts # Tasks
68
67
  )
69
68
  ```
70
69
 
@@ -36,8 +36,7 @@ Judgeval's agent monitoring infra provides a simple harness for integrating GRPO
36
36
  await trainer.train(
37
37
  agent_function=your_agent_function, # entry point to your agent
38
38
  scorers=[RewardScorer()], # Custom scorer you define based on task criteria, acts as reward
39
- prompts=training_prompts, # Tasks
40
- rft_provider="fireworks"
39
+ prompts=training_prompts # Tasks
41
40
  )
42
41
  ```
43
42
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "judgeval"
3
- version = "0.16.8"
3
+ version = "0.17.0"
4
4
  authors = [
5
5
  { name = "Andrew Li", email = "andrew@judgmentlabs.ai" },
6
6
  { name = "Alex Shan", email = "alex@judgmentlabs.ai" },
@@ -0,0 +1,14 @@
1
+ from judgeval.trainer.trainer import JudgmentTrainer
2
+ from judgeval.trainer.config import TrainerConfig, ModelConfig
3
+ from judgeval.trainer.trainable_model import TrainableModel
4
+ from judgeval.trainer.base_trainer import BaseTrainer
5
+ from judgeval.trainer.fireworks_trainer import FireworksTrainer
6
+
7
+ __all__ = [
8
+ "JudgmentTrainer",
9
+ "TrainerConfig",
10
+ "ModelConfig",
11
+ "TrainableModel",
12
+ "BaseTrainer",
13
+ "FireworksTrainer",
14
+ ]
@@ -0,0 +1,117 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Callable, List, Optional, Union, Dict, TYPE_CHECKING
3
+ from .config import TrainerConfig, ModelConfig
4
+ from judgeval.scorers import ExampleScorer, ExampleAPIScorerConfig
5
+
6
+ if TYPE_CHECKING:
7
+ from judgeval.tracer import Tracer
8
+ from .trainable_model import TrainableModel
9
+
10
+
11
+ class BaseTrainer(ABC):
12
+ """
13
+ Abstract base class for training providers.
14
+
15
+ This class defines the interface that all training provider implementations
16
+ must follow. Each provider (Fireworks, Verifiers, etc.) will have its own
17
+ concrete implementation of this interface.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ config: TrainerConfig,
23
+ trainable_model: "TrainableModel",
24
+ tracer: "Tracer",
25
+ project_name: Optional[str] = None,
26
+ ):
27
+ """
28
+ Initialize the base trainer.
29
+
30
+ Args:
31
+ config: TrainerConfig instance with training parameters
32
+ trainable_model: TrainableModel instance to use for training
33
+ tracer: Tracer for observability
34
+ project_name: Project name for organizing training runs
35
+ """
36
+ self.config = config
37
+ self.trainable_model = trainable_model
38
+ self.tracer = tracer
39
+ self.project_name = project_name or "judgment_training"
40
+
41
+ @abstractmethod
42
+ async def generate_rollouts_and_rewards(
43
+ self,
44
+ agent_function: Callable[[Any], Any],
45
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
46
+ prompts: List[Any],
47
+ num_prompts_per_step: Optional[int] = None,
48
+ num_generations_per_prompt: Optional[int] = None,
49
+ concurrency: Optional[int] = None,
50
+ ) -> Any:
51
+ """
52
+ Generate rollouts and compute rewards using the current model snapshot.
53
+
54
+ Args:
55
+ agent_function: Function/agent to call for generating responses
56
+ scorers: List of scorer objects to evaluate responses
57
+ prompts: List of prompts to use for training
58
+ num_prompts_per_step: Number of prompts to use per step
59
+ num_generations_per_prompt: Generations per prompt
60
+ concurrency: Concurrency limit
61
+
62
+ Returns:
63
+ Provider-specific dataset format for training
64
+ """
65
+ pass
66
+
67
+ @abstractmethod
68
+ async def run_reinforcement_learning(
69
+ self,
70
+ agent_function: Callable[[Any], Any],
71
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
72
+ prompts: List[Any],
73
+ ) -> ModelConfig:
74
+ """
75
+ Run the iterative reinforcement learning fine-tuning loop.
76
+
77
+ Args:
78
+ agent_function: Function/agent to call for generating responses
79
+ scorers: List of scorer objects to evaluate responses
80
+ prompts: List of prompts to use for training
81
+
82
+ Returns:
83
+ ModelConfig: Configuration of the trained model
84
+ """
85
+ pass
86
+
87
+ @abstractmethod
88
+ async def train(
89
+ self,
90
+ agent_function: Callable[[Any], Any],
91
+ scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
92
+ prompts: List[Any],
93
+ ) -> ModelConfig:
94
+ """
95
+ Start the reinforcement learning fine-tuning process.
96
+
97
+ This is the main entry point for running the training.
98
+
99
+ Args:
100
+ agent_function: Function/agent to call for generating responses
101
+ scorers: List of scorer objects to evaluate responses
102
+ prompts: List of prompts to use for training
103
+
104
+ Returns:
105
+ ModelConfig: Configuration of the trained model
106
+ """
107
+ pass
108
+
109
+ @abstractmethod
110
+ def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
111
+ """
112
+ Extract message history from spans for training purposes.
113
+
114
+ Returns:
115
+ List of message dictionaries with 'role' and 'content' keys
116
+ """
117
+ pass
@@ -16,7 +16,7 @@ class TrainerConfig:
16
16
  user_id: str
17
17
  model_id: str
18
18
  base_model_name: str = "qwen2p5-7b-instruct"
19
- rft_provider: str = "fireworks"
19
+ rft_provider: str = "fireworks" # Supported: "fireworks", "verifiers" (future)
20
20
  num_steps: int = 5
21
21
  num_generations_per_prompt: int = 4
22
22
  num_prompts_per_step: int = 4
@@ -1,9 +1,9 @@
1
1
  import asyncio
2
2
  import json
3
- import time
4
3
  from typing import Optional, Callable, Any, List, Union, Dict
5
4
  from fireworks import Dataset # type: ignore[import-not-found]
6
5
  from .config import TrainerConfig, ModelConfig
6
+ from .base_trainer import BaseTrainer
7
7
  from .trainable_model import TrainableModel
8
8
  from judgeval.tracer import Tracer
9
9
  from judgeval.tracer.exporters.store import SpanStore
@@ -16,12 +16,12 @@ from .console import _spinner_progress, _print_progress, _print_progress_update
16
16
  from judgeval.exceptions import JudgmentRuntimeError
17
17
 
18
18
 
19
- class JudgmentTrainer:
19
+ class FireworksTrainer(BaseTrainer):
20
20
  """
21
- A reinforcement learning trainer for Judgment models using Fine-Tuning.
21
+ Fireworks AI implementation of the training provider.
22
22
 
23
- This class handles the iterative training process where models are improved
24
- through reinforcement learning fine-tuning based on generated rollouts and rewards.
23
+ This trainer uses Fireworks AI's infrastructure for reinforcement learning
24
+ fine-tuning (RFT) of language models.
25
25
  """
26
26
 
27
27
  def __init__(
@@ -32,26 +32,23 @@ class JudgmentTrainer:
32
32
  project_name: Optional[str] = None,
33
33
  ):
34
34
  """
35
- Initialize the JudgmentTrainer.
35
+ Initialize the FireworksTrainer.
36
36
 
37
37
  Args:
38
- config: TrainerConfig instance with training parameters. If None, uses default config.
39
- tracer: Optional tracer for observability
40
- trainable_model: Optional trainable model instance
38
+ config: TrainerConfig instance with training parameters
39
+ trainable_model: TrainableModel instance for Fireworks training
40
+ tracer: Tracer for observability
41
41
  project_name: Project name for organizing training runs and evaluations
42
42
  """
43
43
  try:
44
- self.config = config
45
- self.tracer = tracer
46
- self.project_name = project_name or "judgment_training"
47
- self.trainable_model = trainable_model
44
+ super().__init__(config, trainable_model, tracer, project_name)
48
45
 
49
46
  self.judgment_client = JudgmentClient()
50
47
  self.span_store = SpanStore()
51
48
  self.span_exporter = InMemorySpanExporter(self.span_store)
52
49
  except Exception as e:
53
50
  raise JudgmentRuntimeError(
54
- f"Failed to initialize JudgmentTrainer: {str(e)}"
51
+ f"Failed to initialize FireworksTrainer: {str(e)}"
55
52
  ) from e
56
53
 
57
54
  def _extract_message_history_from_spans(self) -> List[Dict[str, str]]:
@@ -121,22 +118,7 @@ class JudgmentTrainer:
121
118
  pass
122
119
  messages.append({"role": "assistant", "content": content})
123
120
 
124
- elif span_type == "user":
125
- output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
126
- if output is not None:
127
- content = str(output)
128
- try:
129
- parsed = json.loads(content)
130
- if isinstance(parsed, dict) and "messages" in parsed:
131
- for msg in parsed["messages"]:
132
- if isinstance(msg, dict) and msg.get("role") == "user":
133
- content = msg.get("content", content)
134
- break
135
- except (json.JSONDecodeError, KeyError):
136
- pass
137
- messages.append({"role": "user", "content": content})
138
-
139
- elif span_type == "tool":
121
+ elif span_type in ("user", "tool"):
140
122
  output = span_attributes.get(AttributeKeys.JUDGMENT_OUTPUT)
141
123
  if output is not None:
142
124
  content = str(output)
@@ -347,7 +329,7 @@ class JudgmentTrainer:
347
329
  _print_progress_update(f"Training job: {current_state}")
348
330
  last_state = current_state
349
331
 
350
- time.sleep(10)
332
+ await asyncio.sleep(10)
351
333
  job = job.get()
352
334
  if job is None:
353
335
  raise JudgmentRuntimeError(
@@ -374,7 +356,6 @@ class JudgmentTrainer:
374
356
  agent_function: Callable[[Any], Any],
375
357
  scorers: List[Union[ExampleAPIScorerConfig, ExampleScorer]],
376
358
  prompts: List[Any],
377
- rft_provider: Optional[str] = None,
378
359
  ) -> ModelConfig:
379
360
  """
380
361
  Start the reinforcement learning fine-tuning process.
@@ -385,21 +366,16 @@ class JudgmentTrainer:
385
366
  agent_function: Function/agent to call for generating responses.
386
367
  scorers: List of scorer objects to evaluate responses
387
368
  prompts: List of prompts to use for training
388
- rft_provider: RFT provider to use for training. Currently only "fireworks" is supported.
389
- Support for other providers is planned for future releases.
390
369
 
391
370
  Returns:
392
371
  ModelConfig: Configuration of the trained model for future loading
393
372
  """
394
373
  try:
395
- if rft_provider is not None:
396
- self.config.rft_provider = rft_provider
397
-
398
374
  return await self.run_reinforcement_learning(
399
375
  agent_function, scorers, prompts
400
376
  )
401
377
  except JudgmentRuntimeError:
402
- # Re-raise JudgmentAPIError as-is
378
+ # Re-raise JudgmentRuntimeError as-is
403
379
  raise
404
380
  except Exception as e:
405
381
  raise JudgmentRuntimeError(f"Training process failed: {str(e)}") from e
@@ -0,0 +1,70 @@
1
+ from typing import Optional
2
+ from .config import TrainerConfig
3
+ from .base_trainer import BaseTrainer
4
+ from .fireworks_trainer import FireworksTrainer
5
+ from .trainable_model import TrainableModel
6
+ from judgeval.tracer import Tracer
7
+ from judgeval.exceptions import JudgmentRuntimeError
8
+
9
+
10
+ def JudgmentTrainer(
11
+ config: TrainerConfig,
12
+ trainable_model: TrainableModel,
13
+ tracer: Tracer,
14
+ project_name: Optional[str] = None,
15
+ ) -> BaseTrainer:
16
+ """
17
+ Factory function for creating reinforcement learning trainers.
18
+
19
+ This factory creates and returns provider-specific trainer implementations
20
+ (FireworksTrainer, VerifiersTrainer, etc.) based on the configured RFT provider.
21
+
22
+ The factory pattern allows for easy extension to support multiple training
23
+ providers without changing the client-facing API.
24
+
25
+ Example:
26
+ config = TrainerConfig(
27
+ deployment_id="my-deployment",
28
+ user_id="my-user",
29
+ model_id="my-model",
30
+ rft_provider="fireworks" # or "verifiers" in the future
31
+ )
32
+
33
+ # User creates and configures the trainable model
34
+ trainable_model = TrainableModel(config)
35
+ tracer = Tracer()
36
+
37
+ # JudgmentTrainer automatically creates the appropriate provider-specific trainer
38
+ trainer = JudgmentTrainer(config, trainable_model, tracer)
39
+
40
+ # The returned trainer implements the BaseTrainer interface
41
+ model_config = await trainer.train(agent_function, scorers, prompts)
42
+
43
+ Args:
44
+ config: TrainerConfig instance with training parameters including rft_provider
45
+ trainable_model: Provider-specific trainable model instance (e.g., TrainableModel for Fireworks)
46
+ tracer: Tracer for observability
47
+ project_name: Project name for organizing training runs and evaluations
48
+
49
+ Returns:
50
+ Provider-specific trainer instance (FireworksTrainer, etc.) that implements
51
+ the BaseTrainer interface
52
+
53
+ Raises:
54
+ JudgmentRuntimeError: If the specified provider is not supported
55
+ """
56
+ provider = config.rft_provider.lower()
57
+
58
+ if provider == "fireworks":
59
+ return FireworksTrainer(config, trainable_model, tracer, project_name)
60
+ elif provider == "verifiers":
61
+ # Placeholder for future implementation
62
+ raise JudgmentRuntimeError(
63
+ "Verifiers provider is not yet implemented. "
64
+ "Currently supported providers: 'fireworks'"
65
+ )
66
+ else:
67
+ raise JudgmentRuntimeError(
68
+ f"Unsupported RFT provider: '{config.rft_provider}'. "
69
+ f"Currently supported providers: 'fireworks'"
70
+ )
@@ -1,4 +1,4 @@
1
- __version__ = "0.16.8"
1
+ __version__ = "0.17.0"
2
2
 
3
3
 
4
4
  def get_version() -> str:
@@ -1,5 +0,0 @@
1
- from judgeval.trainer.trainer import JudgmentTrainer
2
- from judgeval.trainer.config import TrainerConfig, ModelConfig
3
- from judgeval.trainer.trainable_model import TrainableModel
4
-
5
- __all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes