flexeval 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. {flexeval-0.3.2 → flexeval-0.4.0}/PKG-INFO +2 -1
  2. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/__init__.py +1 -0
  3. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/base.py +11 -2
  4. flexeval-0.4.0/flexeval/core/chat_dataset/sacrebleu_dataset.py +32 -0
  5. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_chat_response.py +28 -10
  6. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_from_file.py +2 -3
  7. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_generation.py +15 -13
  8. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_multiple_choice.py +9 -6
  9. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_pairwise.py +2 -3
  10. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/evaluate_perplexity.py +9 -5
  11. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/few_shot_generator/balanced.py +4 -1
  12. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/few_shot_generator/base.py +22 -4
  13. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/few_shot_generator/rand.py +1 -1
  14. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/generation_dataset/base.py +2 -1
  15. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/language_model/__init__.py +1 -1
  16. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/language_model/hf_lm.py +8 -6
  17. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/language_model/openai_chatgpt.py +4 -8
  18. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/code_eval.py +7 -1
  19. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/normalizer/regex.py +1 -1
  20. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/perspective_api.py +1 -3
  21. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/multiple_choice_dataset/base.py +2 -1
  22. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/judge/base.py +0 -17
  23. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/judge/llm_judge.py +2 -6
  24. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/scorer/bradley_terry.py +1 -7
  25. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/text_dataset/base.py +7 -3
  26. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/text_dataset/hf.py +6 -5
  27. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/text_dataset/jsonl.py +9 -7
  28. flexeval-0.4.0/flexeval/preset_configs/EvalSetup/code_chat/mbpp_chat.jsonnet +44 -0
  29. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/mbpp.jsonnet +8 -11
  30. flexeval-0.4.0/flexeval/preset_configs/EvalSetup/ja_chat/aio_chat.jsonnet +38 -0
  31. flexeval-0.4.0/flexeval/preset_configs/EvalSetup/ja_chat/mgsm_ja_chat.jsonnet +36 -0
  32. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/aio.jsonnet +0 -1
  33. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/jcommonsenseqa.jsonnet +0 -1
  34. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/jnli.jsonnet +0 -1
  35. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/jsquad.jsonnet +0 -1
  36. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/mgsm_ja.jsonnet +0 -1
  37. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/wrime_pos_neg.jsonnet +0 -1
  38. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_generation/xlsum_ja.jsonnet +0 -1
  39. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_multiple_choice/jcommonsenseqa_mc.jsonnet +0 -1
  40. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/translation/wmt20_en_ja.jsonnet +0 -1
  41. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/translation/wmt20_ja_en.jsonnet +0 -1
  42. flexeval-0.4.0/flexeval/preset_configs/EvalSetup/translation_chat/wmt20_en_ja_chat.jsonnet +35 -0
  43. flexeval-0.4.0/flexeval/preset_configs/EvalSetup/translation_chat/wmt20_ja_en_chat.jsonnet +35 -0
  44. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/Metric/assistant_eval_gpt4_en_single_turn.jsonnet +1 -1
  45. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/Metric/assistant_eval_gpt4_ja_single_turn.jsonnet +1 -1
  46. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/PairwiseJudge/assistant_judge_gpt4_en_single_turn.jsonnet +1 -1
  47. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/PairwiseJudge/assistant_judge_gpt4_ja_single_turn.jsonnet +1 -1
  48. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/common.py +16 -31
  49. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/flexeval_file.py +24 -26
  50. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/flexeval_lm.py +68 -73
  51. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/flexeval_pairwise.py +4 -13
  52. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/flexeval_presets.py +5 -1
  53. {flexeval-0.3.2 → flexeval-0.4.0}/pyproject.toml +2 -1
  54. {flexeval-0.3.2 → flexeval-0.4.0}/LICENSE +0 -0
  55. {flexeval-0.3.2 → flexeval-0.4.0}/README.md +0 -0
  56. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/__init__.py +0 -0
  57. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/__init__.py +0 -0
  58. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench.py +0 -0
  59. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/README.md +0 -0
  60. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/mt-en-ref-gpt4.jsonl +0 -0
  61. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/mt-en.jsonl +0 -0
  62. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/mt-ja-ref-gpt4.jsonl +0 -0
  63. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/mt-ja.jsonl +0 -0
  64. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/rakuda-v2-ja.jsonl +0 -0
  65. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/vicuna-en-ref-gpt4.jsonl +0 -0
  66. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/vicuna-en.jsonl +0 -0
  67. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/vicuna-ja-ref-gpt4.jsonl +0 -0
  68. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/chatbot_bench_datasets/vicuna-ja.jsonl +0 -0
  69. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/chat_dataset/hf_dataset.py +0 -0
  70. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/few_shot_generator/__init__.py +0 -0
  71. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/generation_dataset/__init__.py +0 -0
  72. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/generation_dataset/hf_dataset.py +0 -0
  73. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/generation_dataset/jsonl.py +0 -0
  74. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/generation_dataset/sacrebleu_dataset.py +0 -0
  75. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/language_model/base.py +0 -0
  76. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/language_model/vllm_model.py +0 -0
  77. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/__init__.py +0 -0
  78. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/base.py +0 -0
  79. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/bleu.py +0 -0
  80. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/char_f1.py +0 -0
  81. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/common_prefix_length.py +0 -0
  82. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/common_string_length.py +0 -0
  83. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/exact_match.py +0 -0
  84. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/llm_score.py +0 -0
  85. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/normalizer/__init__.py +0 -0
  86. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/normalizer/aio.py +0 -0
  87. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/normalizer/base.py +0 -0
  88. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/output_length_stats.py +0 -0
  89. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/rouge.py +0 -0
  90. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/substring_match.py +0 -0
  91. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/tokenizer/__init__.py +0 -0
  92. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/tokenizer/base.py +0 -0
  93. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/tokenizer/mecab.py +0 -0
  94. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/tokenizer/sacrebleu_tokenizer.py +0 -0
  95. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/tokenizer/whitespace.py +0 -0
  96. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/metric/xer.py +0 -0
  97. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/multiple_choice_dataset/__init__.py +0 -0
  98. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/multiple_choice_dataset/hf_dataset.py +0 -0
  99. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/__init__.py +0 -0
  100. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/judge/__init__.py +0 -0
  101. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/match.py +0 -0
  102. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/match_maker/__init__.py +0 -0
  103. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/match_maker/all_combinations.py +0 -0
  104. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/match_maker/base.py +0 -0
  105. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/match_maker/random_combinations.py +0 -0
  106. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/scorer/__init__.py +0 -0
  107. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/scorer/base.py +0 -0
  108. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/pairwise_comparison/scorer/win_rate.py +0 -0
  109. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/prompt_template/__init__.py +0 -0
  110. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/prompt_template/base.py +0 -0
  111. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/prompt_template/jinja2.py +0 -0
  112. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/text_dataset/__init__.py +0 -0
  113. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/utils/__init__.py +0 -0
  114. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/utils/data_util.py +0 -0
  115. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/core/utils/jinja2_env.py +0 -0
  116. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval.jsonnet +0 -0
  117. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/jhumaneval_tab_indent.jsonnet +0 -0
  118. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/mbpp_tab_indent.jsonnet +0 -0
  119. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval.jsonnet +0 -0
  120. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/code_generation/openai_humaneval_tab_indent.jsonnet +0 -0
  121. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_chat/mt-en.jsonnet +0 -0
  122. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_chat/vicuna-en.jsonnet +0 -0
  123. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/babi.jsonnet +0 -0
  124. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/commonsense_qa.jsonnet +0 -0
  125. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/gsm8k.jsonnet +0 -0
  126. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/squad_v1.jsonnet +0 -0
  127. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/trivia_qa.jsonnet +0 -0
  128. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_generation/twitter_sentiment.jsonnet +0 -0
  129. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_multiple_choice/commonsense_qa_mc.jsonnet +0 -0
  130. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_multiple_choice/hellaswag.jsonnet +0 -0
  131. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_multiple_choice/openbookqa.jsonnet +0 -0
  132. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_multiple_choice/xwinograd_en.jsonnet +0 -0
  133. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/en_perplexity/tiny_shakespeare.jsonnet +0 -0
  134. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_chat/elyze_tasks_100.jsonnet +0 -0
  135. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_chat/mt-ja.jsonnet +0 -0
  136. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_chat/rakuda-v2-ja.jsonnet +0 -0
  137. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_chat/vicuna-ja.jsonnet +0 -0
  138. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/preset_configs/EvalSetup/ja_multiple_choice/xwinograd_ja.jsonnet +0 -0
  139. {flexeval-0.3.2 → flexeval-0.4.0}/flexeval/scripts/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flexeval
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary:
5
5
  Author: ryokan-ri
6
6
  Author-email: ryokan.ri@sbintuitions.co.jp
@@ -18,6 +18,7 @@ Requires-Dist: google-api-python-client (>=2.131.0,<3.0.0)
18
18
  Requires-Dist: jinja2 (>=3.1.2,<4.0.0)
19
19
  Requires-Dist: jiwer (>=3.0.4,<4.0.0)
20
20
  Requires-Dist: jsonargparse[jsonnet] (>=4.26.1,<5.0.0)
21
+ Requires-Dist: loguru (>=0.7.2,<0.8.0)
21
22
  Requires-Dist: openai (>=1.16.1,<2.0.0)
22
23
  Requires-Dist: peft (>=0.10.0,<0.11.0)
23
24
  Requires-Dist: python-levenshtein (>=0.23.0,<0.24.0)
@@ -1,3 +1,4 @@
1
1
  from .base import ChatDataset, ChatInstance
2
2
  from .chatbot_bench import ChatbotBench
3
3
  from .hf_dataset import HfChatDataset
4
+ from .sacrebleu_dataset import SacreBleuChatDataset
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from dataclasses import dataclass
5
- from typing import Any
5
+ from typing import Any, Sequence
6
6
 
7
7
 
8
8
  @dataclass
@@ -43,8 +43,17 @@ class ChatInstance:
43
43
  msg = "extra_info cannot contain a key named 'messages'. It will conflict with the 'messages' attribute."
44
44
  raise ValueError(msg)
45
45
 
46
+ @property
47
+ def inputs(self) -> list[dict[str, str]]:
48
+ """
49
+ Alias for `messages`.
50
+ This is used in `FewShotGenerator` so that it can access the inputs with the same attribute name as
51
+ `GenerationInstance` and `MultipleChoiceInstance`.
52
+ """
53
+ return self.messages
54
+
46
55
 
47
- class ChatDataset(ABC):
56
+ class ChatDataset(Sequence[ChatInstance], ABC):
48
57
  """A dataset holding `ChatInstance`."""
49
58
 
50
59
  @abstractmethod
@@ -0,0 +1,32 @@
1
+ import sacrebleu
2
+
3
+ from .base import ChatDataset, ChatInstance
4
+
5
+
6
+ class SacreBleuChatDataset(ChatDataset):
7
+ """Load datasets from the [sacrebleu](https://github.com/mjpost/sacrebleu) library.
8
+ The available datasets are defined in sacrebleu.DATASETS.
9
+ """
10
+
11
+ def __init__(self, dataset_name: str, langpair: str) -> None:
12
+ self._source_list: list[str] = list(sacrebleu.DATASETS[dataset_name].source(langpair))
13
+ self._references_list: list[list[str]] = [
14
+ [r.strip() for r in refs] for refs in sacrebleu.DATASETS[dataset_name].references(langpair)
15
+ ]
16
+
17
+ if len(self._source_list) != len(self._references_list):
18
+ msg = "The number of source and reference pairs should be the same."
19
+ raise ValueError(msg)
20
+
21
+ def require_incremental_response(self) -> bool:
22
+ return False
23
+
24
+ def __len__(self) -> int:
25
+ return len(self._source_list)
26
+
27
+ def __getitem__(self, i: int) -> ChatInstance:
28
+ return ChatInstance(
29
+ messages=[{"role": "user", "content": self._source_list[i]}],
30
+ references=self._references_list[i],
31
+ extra_info={},
32
+ )
@@ -1,35 +1,53 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
- from typing import Any
3
+ from typing import Any, Sequence
5
4
 
5
+ from loguru import logger
6
6
  from tqdm import tqdm
7
7
 
8
8
  from .chat_dataset import ChatDataset, ChatInstance
9
+ from .few_shot_generator import FewShotGenerator
9
10
  from .language_model import LanguageModel
10
11
  from .metric import Metric
11
12
  from .utils.data_util import batch_iter
12
13
 
13
- logger = logging.getLogger(__name__)
14
14
 
15
-
16
- def evaluate_chat_response(
15
+ def evaluate_chat_response( # noqa: C901,PLR0912
17
16
  language_model: LanguageModel,
18
17
  gen_kwargs: dict[str, Any],
19
18
  eval_dataset: ChatDataset,
20
19
  metrics: list[Metric],
21
20
  batch_size: int,
21
+ max_instances: int | None = None,
22
+ few_shot_generator: FewShotGenerator | None = None,
22
23
  ) -> tuple[dict[str, float], list[dict[str, Any]]]:
23
24
  logger.info(f"Evaluate the model with gen_kwargs: {gen_kwargs}")
24
25
 
26
+ eval_instances: Sequence[ChatInstance] = eval_dataset
27
+ if max_instances is not None:
28
+ eval_instances = [eval_dataset[i] for i in range(min(max_instances, len(eval_dataset)))]
29
+
25
30
  all_messages_list: list[list[dict[str, str]]] = []
26
31
  references_list: list[list[str]] = []
27
32
  extra_info_list: list[dict[str, Any]] = []
28
- with tqdm(total=len(eval_dataset)) as pbar:
29
- for i, batch in enumerate(batch_iter(eval_dataset, batch_size)):
30
- batch: list[ChatInstance]
31
-
33
+ with tqdm(total=len(eval_instances)) as pbar:
34
+ for batch_id, batch in enumerate(batch_iter(eval_instances, batch_size)):
32
35
  input_messages_list = [chat_instance.messages for chat_instance in batch]
36
+
37
+ if few_shot_generator is not None:
38
+ for input_id in range(len(input_messages_list)):
39
+ few_shot_instances = few_shot_generator(eval_inputs=input_messages_list[input_id])
40
+ few_shot_messages: list[dict[str, str]] = []
41
+ for few_shot_instance in few_shot_instances:
42
+ if not isinstance(few_shot_instance, ChatInstance):
43
+ msg = f"Invalid instance type: {type(few_shot_instance)}"
44
+ raise TypeError(msg)
45
+ few_shot_messages += few_shot_instance.messages
46
+ if few_shot_instance.references:
47
+ # use the first reference as the assistant message
48
+ few_shot_messages += [{"role": "assistant", "content": few_shot_instance.references[0]}]
49
+ input_messages_list[input_id] = [*few_shot_messages, *input_messages_list[input_id]]
50
+
33
51
  if not eval_dataset.require_incremental_response():
34
52
  lm_outputs = language_model.batch_generate_chat_response(
35
53
  input_messages_list,
@@ -65,7 +83,7 @@ def evaluate_chat_response(
65
83
  references_list += [chat_instance.references for chat_instance in batch]
66
84
  extra_info_list += [chat_instance.extra_info for chat_instance in batch]
67
85
 
68
- if i == 0:
86
+ if batch_id == 0:
69
87
  logger.info("Example of the conversation")
70
88
  logger.info(f"{all_messages_list[0]}")
71
89
 
@@ -1,16 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
- import logging
5
4
  from os import PathLike
6
5
  from typing import Any
7
6
 
7
+ from loguru import logger
8
+
8
9
  from .chat_dataset import ChatDataset
9
10
  from .generation_dataset import GenerationDataset
10
11
  from .metric import Metric
11
12
 
12
- logger = logging.getLogger(__name__)
13
-
14
13
 
15
14
  def evaluate_from_file(
16
15
  eval_file: str | PathLike[str],
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
- from typing import Any
3
+ from typing import Any, Sequence
5
4
 
5
+ from loguru import logger
6
6
  from tqdm import tqdm
7
7
 
8
8
  from .few_shot_generator import FewShotGenerator
@@ -12,25 +12,28 @@ from .metric import Metric
12
12
  from .prompt_template import PromptTemplate
13
13
  from .utils.data_util import batch_iter
14
14
 
15
- logger = logging.getLogger(__name__)
16
15
 
17
-
18
- def evaluate_generation(
16
+ def evaluate_generation( # noqa: C901
19
17
  language_model: LanguageModel,
20
18
  gen_kwargs: dict[str, Any],
21
19
  eval_dataset: GenerationDataset,
22
20
  prompt_template: PromptTemplate,
23
21
  metrics: list[Metric],
24
22
  batch_size: int,
23
+ max_instances: int | None = None,
25
24
  few_shot_generator: FewShotGenerator | None = None,
26
25
  ) -> tuple[dict[str, float], list[dict[str, Any]]]:
27
26
  logger.info(f"Evaluate the model with gen_kwargs: {gen_kwargs}")
28
27
  logger.info(f"Prompt template: {prompt_template}")
29
- eval_instance_list: list[GenerationInstance] = []
28
+
29
+ eval_instances: Sequence[GenerationInstance] = eval_dataset
30
+ if max_instances is not None:
31
+ eval_instances = [eval_dataset[i] for i in range(min(max_instances, len(eval_dataset)))]
32
+
30
33
  lm_prompt_list: list[str] = []
31
34
  lm_output_list: list[str] = []
32
- with tqdm(total=len(eval_dataset)) as pbar:
33
- for i, batch in enumerate(batch_iter(eval_dataset, batch_size)):
35
+ with tqdm(total=len(eval_instances)) as pbar:
36
+ for i, batch in enumerate(batch_iter(eval_instances, batch_size)):
34
37
  lm_prompts: list[str] = []
35
38
  for eval_instance in batch:
36
39
  template_inputs = eval_instance.inputs
@@ -59,17 +62,16 @@ def evaluate_generation(
59
62
  logger.info(f"lm_outputs: {lm_outputs[0]}")
60
63
 
61
64
  lm_prompt_list += lm_prompts
62
- eval_instance_list += batch
63
65
  lm_output_list += lm_outputs
64
66
 
65
67
  pbar.update(len(batch))
66
68
  metrics_summary_dict: dict[str, float] = {}
67
- instance_metrics_list: list[dict[str, Any]] = [{} for _ in range(len(eval_instance_list))]
69
+ instance_metrics_list: list[dict[str, Any]] = [{} for _ in range(len(eval_instances))]
68
70
  for metric in metrics:
69
71
  metric_result = metric.evaluate(
70
72
  lm_outputs=lm_output_list,
71
- references_list=[i.references for i in eval_instance_list],
72
- task_inputs_list=[i.inputs for i in eval_instance_list],
73
+ references_list=[i.references for i in eval_instances],
74
+ task_inputs_list=[i.inputs for i in eval_instances],
73
75
  )
74
76
 
75
77
  metrics_summary_dict.update(metric_result.summary)
@@ -93,7 +95,7 @@ def evaluate_generation(
93
95
  for lm_prompt, lm_output, eval_instance, instance_metrics in zip(
94
96
  lm_prompt_list,
95
97
  lm_output_list,
96
- eval_instance_list,
98
+ eval_instances,
97
99
  instance_metrics_list,
98
100
  )
99
101
  ]
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
- from typing import Any
3
+ from typing import Any, Sequence
5
4
 
5
+ from loguru import logger
6
6
  from tqdm import tqdm
7
7
 
8
8
  from .few_shot_generator import FewShotGenerator
@@ -11,19 +11,22 @@ from .multiple_choice_dataset import MultipleChoiceDataset, MultipleChoiceInstan
11
11
  from .prompt_template import PromptTemplate
12
12
  from .utils.data_util import batch_iter
13
13
 
14
- logger = logging.getLogger(__name__)
15
-
16
14
 
17
15
  def evaluate_multiple_choice(
18
16
  language_model: LanguageModel,
19
17
  eval_dataset: MultipleChoiceDataset,
20
18
  prompt_template: PromptTemplate,
21
19
  batch_size: int,
20
+ max_instances: int | None = None,
22
21
  few_shot_generator: FewShotGenerator | None = None,
23
22
  ) -> tuple[dict[str, float], list[dict[str, Any]]]:
23
+ eval_instances: Sequence[MultipleChoiceInstance] = eval_dataset
24
+ if max_instances is not None:
25
+ eval_instances = [eval_dataset[i] for i in range(min(max_instances, len(eval_dataset)))]
26
+
24
27
  results: list[dict[str, Any]] = []
25
- with tqdm(total=len(eval_dataset)) as pbar:
26
- for batch_id, batch in enumerate(batch_iter(eval_dataset, batch_size)):
28
+ with tqdm(total=len(eval_instances)) as pbar:
29
+ for batch_id, batch in enumerate(batch_iter(eval_instances, batch_size)):
27
30
  batch: list[MultipleChoiceInstance]
28
31
 
29
32
  batch_prefixes: list[str] = []
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  from dataclasses import asdict
5
4
  from typing import Any
6
5
 
6
+ from loguru import logger
7
+
7
8
  from .pairwise_comparison import (
8
9
  AllCombinations,
9
10
  BradleyTerryScorer,
@@ -16,8 +17,6 @@ from .pairwise_comparison import (
16
17
  )
17
18
  from .utils.data_util import batch_iter
18
19
 
19
- logger = logging.getLogger(__name__)
20
-
21
20
 
22
21
  def evaluate_pairwise(
23
22
  model_items: dict[str, list[dict[str, Any]]],
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import math
5
4
  from collections import defaultdict
5
+ from typing import Sequence
6
6
 
7
+ from loguru import logger
7
8
  from tqdm import tqdm
8
9
 
9
10
  from .language_model import LanguageModel
@@ -11,20 +12,23 @@ from .metric.tokenizer import Tokenizer
11
12
  from .text_dataset import TextDataset
12
13
  from .utils.data_util import batch_iter
13
14
 
14
- logger = logging.getLogger(__name__)
15
-
16
15
 
17
16
  def evaluate_perplexity(
18
17
  language_model: LanguageModel,
19
18
  eval_dataset: TextDataset,
20
19
  batch_size: int,
20
+ max_instances: int | None = None,
21
21
  tokenizer: Tokenizer | None = None,
22
22
  ) -> dict[str, float]:
23
23
  total_log_prob = 0.0
24
24
 
25
+ eval_instances: Sequence[str] = eval_dataset
26
+ if max_instances is not None:
27
+ eval_instances = [eval_dataset[i] for i in range(min(max_instances, len(eval_dataset)))]
28
+
25
29
  token_counts: dict[str, int] = defaultdict(int)
26
- with tqdm() as pbar:
27
- for batch in batch_iter(eval_dataset, batch_size):
30
+ with tqdm(total=len(eval_instances)) as pbar:
31
+ for batch in batch_iter(eval_instances, batch_size):
28
32
  log_probs = language_model.batch_compute_log_probs(batch)
29
33
  total_log_prob += sum(log_probs)
30
34
 
@@ -38,7 +38,10 @@ class BalancedFewShotGenerator(FewShotGenerator):
38
38
  label_to_ids[instance.references[0]].append(i)
39
39
  self._label_to_ids = label_to_ids
40
40
 
41
- def _sample_instances(self, eval_inputs: dict[str, Any] | None = None) -> list[GenerationInstance]:
41
+ def _sample_instances(
42
+ self,
43
+ eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None,
44
+ ) -> list[GenerationInstance]:
42
45
  # Shuffle labels
43
46
  labels = list(self._label_to_ids.keys())
44
47
  self._rnd.shuffle(labels)
@@ -3,11 +3,12 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import Any, Union
5
5
 
6
+ from flexeval.core.chat_dataset import ChatDataset, ChatInstance
6
7
  from flexeval.core.generation_dataset import GenerationDataset, GenerationInstance
7
8
  from flexeval.core.multiple_choice_dataset import MultipleChoiceDataset, MultipleChoiceInstance
8
9
 
9
- Dataset = Union[GenerationDataset, MultipleChoiceDataset]
10
- Instance = Union[GenerationInstance, MultipleChoiceInstance]
10
+ Dataset = Union[GenerationDataset, MultipleChoiceDataset, ChatDataset]
11
+ Instance = Union[GenerationInstance, MultipleChoiceInstance, ChatInstance]
11
12
 
12
13
 
13
14
  class FewShotGenerator(ABC):
@@ -15,10 +16,27 @@ class FewShotGenerator(ABC):
15
16
  self._num_trials_to_avoid_leak = num_trials_to_avoid_leak
16
17
 
17
18
  @abstractmethod
18
- def _sample_instances(self, eval_inputs: dict[str, Any] | None = None) -> list[Instance]:
19
+ def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]:
20
+ """
21
+ Sample instances for few-shot learning.
22
+ This method should be implemented in the derived class.
23
+ """
19
24
  raise NotImplementedError
20
25
 
21
- def __call__(self, eval_inputs: dict[str, Any] | None = None) -> list[Instance]:
26
+ def __call__(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]:
27
+ """
28
+ Sample instances for few-shot learning.
29
+ This method calls `_sample_instances` and
30
+ checks if the sampled instances have the same inputs as the evaluation instance.
31
+
32
+ Args:
33
+ eval_inputs: The inputs of the evaluation instance.
34
+ This is used to avoid data leakage
35
+ by checking if the sampled instances have the same inputs as the evaluation instance.
36
+
37
+ Returns:
38
+ A list of instances for few-shot learning.
39
+ """
22
40
  sampled_instances = self._sample_instances(eval_inputs=eval_inputs)
23
41
 
24
42
  # check if the sampled instances are the same as the eval_instance
@@ -27,6 +27,6 @@ class RandomFewShotGenerator(FewShotGenerator):
27
27
  self._num_shots = num_shots
28
28
  self._rnd = random.Random(seed)
29
29
 
30
- def _sample_instances(self, eval_inputs: dict[str, Any] | None = None) -> list[Instance]:
30
+ def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]:
31
31
  sampled_indices = self._rnd.sample(range(len(self._dataset)), self._num_shots)
32
32
  return [self._dataset[i] for i in sampled_indices]
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from dataclasses import dataclass
5
+ from typing import Sequence
5
6
 
6
7
 
7
8
  @dataclass
@@ -22,7 +23,7 @@ class GenerationInstance:
22
23
  """
23
24
 
24
25
 
25
- class GenerationDataset(ABC):
26
+ class GenerationDataset(Sequence[GenerationInstance], ABC):
26
27
  """A dataset holding `GenerationInstance`."""
27
28
 
28
29
  @abstractmethod
@@ -1,4 +1,4 @@
1
1
  from .base import LanguageModel
2
2
  from .hf_lm import HuggingFaceLM
3
- from .openai_chatgpt import OpenAIChatGPT
3
+ from .openai_chatgpt import OpenAIChatAPI
4
4
  from .vllm_model import VllmModel
@@ -1,18 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import contextlib
4
- import logging
5
4
  from typing import Any, Literal, TypeVar
6
5
 
7
6
  import torch
8
7
  import torch.nn.functional as F # noqa: N812
9
8
  import transformers
9
+ from loguru import logger
10
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, PreTrainedModel, PreTrainedTokenizer
11
11
 
12
12
  from .base import LanguageModel
13
13
 
14
- logger = logging.getLogger(__name__)
15
-
16
14
  T = TypeVar("T")
17
15
 
18
16
 
@@ -177,13 +175,17 @@ class HuggingFaceLM(LanguageModel):
177
175
  model_kwargs = {**model_kwargs} # copy kwargs to avoid modifying the original dict
178
176
  if "device_map" not in model_kwargs:
179
177
  model_kwargs["device_map"] = "auto"
180
- if "torch_dtype" not in model_kwargs or model_kwargs["torch_dtype"] == "auto":
178
+ if "torch_dtype" not in model_kwargs:
181
179
  # You need to set torch_dtype to use the optimal dtype for the model.
182
180
  # https://huggingface.co/docs/transformers/main/main_classes/model#model-instantiation-dtype
183
181
  model_kwargs["torch_dtype"] = "auto"
184
- else:
182
+ elif model_kwargs["torch_dtype"] != "auto":
185
183
  # Convert string to torch.dtype
186
- model_kwargs["torch_dtype"] = getattr(torch, model_kwargs["torch_dtype"])
184
+ # We allow either "bfloat16" or "torch.bfloat16"
185
+ torch_dtype_str = model_kwargs["torch_dtype"]
186
+ if torch_dtype_str.startswith("torch."):
187
+ torch_dtype_str = torch_dtype_str[len("torch.") :]
188
+ model_kwargs["torch_dtype"] = getattr(torch, torch_dtype_str)
187
189
  if not isinstance(model_kwargs["torch_dtype"], torch.dtype):
188
190
  msg = f"Invalid torch_dtype: {model_kwargs['torch_dtype']}"
189
191
  raise ValueError(msg)
@@ -1,16 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- import logging
5
4
  from typing import Awaitable, Callable, TypeVar
6
5
 
7
6
  import openai
7
+ from loguru import logger
8
8
  from openai import AsyncOpenAI
9
9
 
10
10
  from .base import LanguageModel
11
11
 
12
- logger = logging.getLogger(__name__)
13
-
14
12
  T = TypeVar("T")
15
13
 
16
14
 
@@ -21,20 +19,18 @@ async def _retry_on_error(
21
19
  ) -> Awaitable[T] | None:
22
20
  for i in range(max_num_trials):
23
21
  try:
24
- # 関数を実行する
25
22
  return await openai_call()
26
23
  except openai.APIError as e: # noqa: PERF203
27
- # 試行回数が上限に達したらエラーを送出
28
24
  if i == max_num_trials - 1:
29
25
  raise
30
- logger.info(f"エラーを受け取りました:{e}")
26
+ logger.info(f"We got an error:{e}")
31
27
  wait_time_seconds = first_wait_time * (2**i)
32
- logger.info(f"{wait_time_seconds}秒待機します")
28
+ logger.info(f"Wait for {wait_time_seconds} seconds...")
33
29
  await asyncio.sleep(wait_time_seconds)
34
30
  return None
35
31
 
36
32
 
37
- class OpenAIChatGPT(LanguageModel):
33
+ class OpenAIChatAPI(LanguageModel):
38
34
  """
39
35
  LanguageModel implementation using OpenAI's ChatGPT API.
40
36
 
@@ -8,6 +8,7 @@ import evaluate
8
8
  from flexeval.core.utils.jinja2_env import JINJA2_ENV
9
9
 
10
10
  from .base import Metric, MetricResult
11
+ from .normalizer import Normalizer
11
12
 
12
13
  # by default, the program is not allowed to execute code and we need to set this environment variable
13
14
  os.environ["HF_ALLOW_CODE_EVAL"] = "1"
@@ -21,15 +22,17 @@ class CodeEval(Metric):
21
22
  code_prompt_template: A Jinja2 template string that will prepend the generated code.
22
23
  The template should contain variables that will be replaced with the values in `task_inputs_list`.
23
24
  If `None`, the code prompt will be the generated code itself.
25
+ normalizer: A normalizer applied to model outputs before evaluation.
24
26
  """
25
27
 
26
- def __init__(self, code_prompt_template: str | None = None) -> None:
28
+ def __init__(self, code_prompt_template: str | None = None, normalizer: Normalizer | None = None) -> None:
27
29
  self._code_prompt_template = None
28
30
  if code_prompt_template is not None:
29
31
  self._code_prompt_template = JINJA2_ENV.from_string(
30
32
  code_prompt_template,
31
33
  )
32
34
  self._code_eval = evaluate.load("code_eval")
35
+ self._normalizer = normalizer
33
36
 
34
37
  def evaluate(
35
38
  self,
@@ -48,6 +51,9 @@ class CodeEval(Metric):
48
51
  task_inputs_list,
49
52
  references_list,
50
53
  ):
54
+ if self._normalizer is not None:
55
+ lm_output = self._normalizer.normalize(lm_output) # noqa: PLW2901
56
+
51
57
  generated_function = lm_output
52
58
  if self._code_prompt_template is not None:
53
59
  generated_function = self._code_prompt_template.render(**task_inputs) + lm_output
@@ -12,7 +12,7 @@ class RegexNormalizer(Normalizer):
12
12
  """
13
13
 
14
14
  def __init__(self, pattern: str) -> None:
15
- self._pattern = re.compile(pattern)
15
+ self._pattern = re.compile(pattern, flags=re.DOTALL)
16
16
 
17
17
  def normalize(self, text: str) -> str:
18
18
  found = self._pattern.findall(text)
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import os
5
4
  import time
6
5
  from typing import Any, Callable
@@ -8,11 +7,10 @@ from typing import Any, Callable
8
7
  import numpy as np
9
8
  from googleapiclient import discovery
10
9
  from googleapiclient.errors import HttpError
10
+ from loguru import logger
11
11
 
12
12
  from .base import Metric, MetricResult
13
13
 
14
- logger = logging.getLogger(__name__)
15
-
16
14
  PERSPECTIVE_API_KEY = os.getenv("PERSPECTIVE_API_KEY")
17
15
 
18
16
 
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from dataclasses import dataclass
5
+ from typing import Sequence
5
6
 
6
7
 
7
8
  @dataclass
@@ -26,7 +27,7 @@ class MultipleChoiceInstance:
26
27
  """
27
28
 
28
29
 
29
- class MultipleChoiceDataset(ABC):
30
+ class MultipleChoiceDataset(Sequence[MultipleChoiceInstance], ABC):
30
31
  @abstractmethod
31
32
  def __len__(self) -> int:
32
33
  """
@@ -25,23 +25,6 @@ class PairwiseJudge(ABC):
25
25
  The output is a tuple of the winner and the rationale.
26
26
  """
27
27
 
28
- @abstractmethod
29
- def judge(
30
- self,
31
- model1_item: dict[str, Any],
32
- model2_item: dict[str, Any],
33
- ) -> tuple[Winner, str]:
34
- """
35
- Judge which model is better given two items.
36
-
37
- Args:
38
- model1_item: The first model item, containing the model output and other information needed for judging.
39
- model2_item: The second model item, containing the model output and other information needed for judging.
40
-
41
- Returns:
42
- A tuple of the winner and the rationale.
43
- """
44
-
45
28
  @abstractmethod
46
29
  def batch_judge(
47
30
  self,
@@ -1,16 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import re
5
4
  from typing import Any
6
5
 
6
+ from loguru import logger
7
+
7
8
  from flexeval.core.language_model.base import LanguageModel
8
9
  from flexeval.core.prompt_template.base import PromptTemplate
9
10
 
10
11
  from .base import PairwiseJudge, Winner
11
12
 
12
- logger = logging.getLogger(__name__)
13
-
14
13
 
15
14
  class ChatLLMPairwiseJudge(PairwiseJudge):
16
15
  """
@@ -60,9 +59,6 @@ class ChatLLMPairwiseJudge(PairwiseJudge):
60
59
  else:
61
60
  return winner, rationale
62
61
 
63
- def judge(self, model1_item: dict[str, Any], model2_item: dict[str, Any]) -> tuple[Winner, str]:
64
- return self.batch_judge([(model1_item, model2_item)])[0]
65
-
66
62
  def batch_judge(self, batch_model_items: list[tuple[dict[str, Any], dict[str, Any]]]) -> list[tuple[Winner, str]]:
67
63
  input_chat_messages_list: list[list[dict[str, str]]] = []
68
64
  for model1_item, model2_item in batch_model_items: