janus-llm 4.3.5__py3-none-any.whl → 4.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. janus/__init__.py +1 -1
  2. janus/cli/aggregate.py +2 -2
  3. janus/cli/cli.py +6 -0
  4. janus/cli/constants.py +6 -0
  5. janus/cli/diagram.py +36 -7
  6. janus/cli/document.py +10 -1
  7. janus/cli/llm.py +7 -3
  8. janus/cli/partition.py +10 -1
  9. janus/cli/pipeline.py +126 -0
  10. janus/cli/self_eval.py +10 -3
  11. janus/cli/translate.py +10 -1
  12. janus/converter/__init__.py +2 -0
  13. janus/converter/_tests/test_translate.py +6 -5
  14. janus/converter/chain.py +100 -0
  15. janus/converter/converter.py +467 -90
  16. janus/converter/diagram.py +12 -8
  17. janus/converter/document.py +17 -7
  18. janus/converter/evaluate.py +174 -147
  19. janus/converter/partition.py +6 -11
  20. janus/converter/passthrough.py +29 -0
  21. janus/converter/pool.py +74 -0
  22. janus/converter/requirements.py +7 -40
  23. janus/converter/translate.py +2 -58
  24. janus/language/_tests/test_combine.py +1 -0
  25. janus/language/block.py +115 -5
  26. janus/llm/model_callbacks.py +6 -0
  27. janus/llm/models_info.py +19 -0
  28. janus/metrics/_tests/test_reading.py +48 -4
  29. janus/metrics/_tests/test_rouge_score.py +5 -11
  30. janus/metrics/metric.py +47 -124
  31. janus/metrics/reading.py +48 -28
  32. janus/metrics/rouge_score.py +21 -34
  33. janus/parsers/_tests/test_code_parser.py +1 -1
  34. janus/parsers/code_parser.py +2 -2
  35. janus/parsers/eval_parsers/incose_parser.py +3 -3
  36. janus/parsers/reqs_parser.py +3 -3
  37. janus/prompts/templates/cyclic/human.txt +16 -0
  38. janus/prompts/templates/cyclic/system.txt +1 -0
  39. janus/prompts/templates/eval_prompts/incose/human.txt +1 -1
  40. janus/prompts/templates/extract_variables/human.txt +5 -0
  41. janus/prompts/templates/extract_variables/system.txt +1 -0
  42. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/METADATA +14 -15
  43. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/RECORD +46 -40
  44. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/WHEEL +1 -1
  45. janus/metrics/_tests/test_llm.py +0 -90
  46. janus/metrics/llm_metrics.py +0 -202
  47. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/LICENSE +0 -0
  48. {janus_llm-4.3.5.dist-info → janus_llm-4.5.4.dist-info}/entry_points.txt +0 -0
@@ -1,8 +1,4 @@
1
- import json
2
- from pathlib import Path
3
-
4
1
  from janus.converter.document import Documenter
5
- from janus.language.block import TranslatedCodeBlock
6
2
  from janus.language.combine import ChunkCombiner
7
3
  from janus.parsers.reqs_parser import RequirementsParser
8
4
  from janus.utils.logger import create_logger
@@ -16,41 +12,12 @@ class RequirementsDocumenter(Documenter):
16
12
  A class that translates code from one programming language to its requirements.
17
13
  """
18
14
 
19
- def __init__(self, **kwargs):
20
- super().__init__(**kwargs)
21
- self.set_prompt("requirements")
15
+ def __init__(
16
+ self, combine_output: bool = False, output_type: str = "requirements", **kwargs
17
+ ):
18
+ kwargs.update(output_type=output_type)
19
+ super().__init__(combine_output=combine_output, **kwargs)
20
+ self.set_prompts("requirements")
22
21
  self._combiner = ChunkCombiner()
23
22
  self._parser = RequirementsParser()
24
-
25
- @staticmethod
26
- def get_prompt_replacements(block) -> dict[str, str]:
27
- prompt_replacements: dict[str, str] = {"SOURCE_CODE": block.original.text}
28
- return prompt_replacements
29
-
30
- def _save_to_file(self, block: TranslatedCodeBlock, out_path: Path) -> None:
31
- """Save a file to disk.
32
-
33
- Arguments:
34
- block: The `CodeBlock` to save to a file.
35
- """
36
- output_list = list()
37
- # For each chunk of code, get generation metadata, the text of the code,
38
- # and the LLM generated requirements
39
- blocks = [block for block in block.children] if len(block.children) else [block]
40
- for block in blocks:
41
- code = block.original.text
42
- requirements = self._parser.parse_combined_output(block.complete_text)
43
- metadata = dict(
44
- retries=block.total_retries,
45
- cost=block.total_cost,
46
- processing_time=block.processing_time,
47
- )
48
- # Put them all in a top level 'output' key
49
- output_list.append(
50
- dict(metadata=metadata, code=code, requirements=requirements)
51
- )
52
- obj = dict(
53
- output=output_list,
54
- )
55
- out_path.parent.mkdir(parents=True, exist_ok=True)
56
- out_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
23
+ self._load_parameters()
@@ -1,8 +1,5 @@
1
1
  from janus.converter.converter import Converter, run_if_changed
2
- from janus.llm.models_info import MODEL_PROMPT_ENGINES
3
2
  from janus.parsers.code_parser import CodeParser
4
- from janus.prompts.prompt import SAME_OUTPUT
5
- from janus.utils.enums import LANGUAGES
6
3
  from janus.utils.logger import create_logger
7
4
 
8
5
  log = create_logger(__name__)
@@ -29,13 +26,11 @@ class Translator(Converter):
29
26
  max_prompts: The maximum number of prompts to try before giving up.
30
27
  max_tokens: The maximum number of tokens the model will take in.
31
28
  If unspecificed, model's default max will be used.
32
- prompt_template: name of prompt template directory
33
- (see janus/prompts/templates) or path to a directory.
29
+ prompt_templates: name of prompt template directories
30
+ (see janus/prompts/templates) or paths to directories.
34
31
  """
35
32
  super().__init__(**kwargs)
36
33
 
37
- self._target_version: str | None
38
-
39
34
  self.set_target_language(
40
35
  target_language=target_language,
41
36
  target_version=target_version,
@@ -47,57 +42,6 @@ class Translator(Converter):
47
42
  self._load_parser()
48
43
  super()._load_parameters()
49
44
 
50
- def set_target_language(
51
- self, target_language: str, target_version: str | None
52
- ) -> None:
53
- """Validate and set the target language.
54
-
55
- The affected objects will not be updated until translate() is called.
56
-
57
- Arguments:
58
- target_language: The target programming language.
59
- target_version: The target version of the target programming language.
60
- """
61
- target_language = target_language.lower()
62
- if target_language not in LANGUAGES:
63
- raise ValueError(
64
- f"Invalid target language: {target_language}. "
65
- "Valid target languages are found in `janus.utils.enums.LANGUAGES`."
66
- )
67
- self._target_language = target_language
68
- self._target_version = target_version
69
- # Taking the first suffix as the default for output files
70
- self._target_suffix = f".{LANGUAGES[target_language]['suffixes'][0]}"
71
-
72
- @run_if_changed(
73
- "_prompt_template_name",
74
- "_source_language",
75
- "_target_language",
76
- "_target_version",
77
- "_model_name",
78
- )
79
- def _load_prompt(self) -> None:
80
- """Load the prompt according to this instance's attributes.
81
-
82
- If the relevant fields have not been changed since the last time this
83
- method was called, nothing happens.
84
- """
85
- if self._prompt_template_name in SAME_OUTPUT:
86
- if self._target_language != self._source_language:
87
- raise ValueError(
88
- f"Prompt template ({self._prompt_template_name}) suggests "
89
- f"source and target languages should match, but do not "
90
- f"({self._source_language} != {self._target_language})"
91
- )
92
-
93
- prompt_engine = MODEL_PROMPT_ENGINES[self._llm.short_model_id](
94
- source_language=self._source_language,
95
- target_language=self._target_language,
96
- target_version=self._target_version,
97
- prompt_template=self._prompt_template_name,
98
- )
99
- self._prompt = prompt_engine.prompt
100
-
101
45
  @run_if_changed("_target_language")
102
46
  def _load_parser(self) -> None:
103
47
  """Load the parser according to this instance's attributes.
@@ -36,6 +36,7 @@ class TestCombiner(unittest.TestCase):
36
36
  self.translated_block = TranslatedCodeBlock(
37
37
  self.block,
38
38
  language="python",
39
+ converter=None,
39
40
  )
40
41
 
41
42
  def test_combine(self):
janus/language/block.py CHANGED
@@ -1,9 +1,12 @@
1
1
  from functools import total_ordering
2
- from typing import ForwardRef, Hashable, Optional, Tuple
2
+ from typing import TYPE_CHECKING, ForwardRef, Hashable, Optional, Tuple
3
3
 
4
4
  from janus.language.node import NodeType
5
5
  from janus.utils.logger import create_logger
6
6
 
7
+ if TYPE_CHECKING:
8
+ from janus.converter.converter import Converter
9
+
7
10
  log = create_logger(__name__)
8
11
 
9
12
 
@@ -46,6 +49,9 @@ class CodeBlock:
46
49
  embedding_id: Optional[str] = None,
47
50
  affixes: Tuple[str, str] = ("", ""),
48
51
  context_tags: dict[str, str] = {},
52
+ previous_generations: list["TranslatedCodeBlock"] = [],
53
+ block_type: str | None = None,
54
+ block_label: str | None = None,
49
55
  ) -> None:
50
56
  self.id: Hashable = id
51
57
  self.name: Optional[str] = name
@@ -65,6 +71,9 @@ class CodeBlock:
65
71
  self.complete = True
66
72
  self.omit_prefix = True
67
73
  self.omit_suffix = False
74
+ self.previous_generations = previous_generations
75
+ self.block_type = block_type
76
+ self.block_label = block_label
68
77
 
69
78
  if self.children:
70
79
  self.children[0].omit_prefix = False
@@ -184,12 +193,23 @@ class TranslatedCodeBlock(CodeBlock):
184
193
  translated: Whether this block has been successfully translated
185
194
  """
186
195
 
187
- def __init__(self, original: CodeBlock, language: str) -> None:
196
+ def __init__(
197
+ self,
198
+ original: CodeBlock,
199
+ language: str,
200
+ converter: ForwardRef("Converter"),
201
+ block_type: str | None = None,
202
+ block_label: str | None = None,
203
+ ) -> None:
188
204
  """Create an "empty" `TranslatedCodeBlock` from the given original
189
205
 
190
206
  Arguments:
191
207
  original: The original code block
192
208
  language: The language to translate to
209
+ converter: the converter used to translate
210
+ block_type: type of the block
211
+ block_label: label for block
212
+ (for mapping outputs to inputs through ConverterChain)
193
213
 
194
214
  Returns:
195
215
  A `TranslatedCodeBlock` with the same attributes as the original, except
@@ -207,18 +227,24 @@ class TranslatedCodeBlock(CodeBlock):
207
227
  end_byte=None,
208
228
  tokens=0,
209
229
  children=[
210
- TranslatedCodeBlock(child, language) for child in original.children
230
+ TranslatedCodeBlock(child, language, block_type, block_label)
231
+ for child in original.children
211
232
  ],
212
233
  affixes=original.affixes,
234
+ previous_generations=original.previous_generations,
235
+ block_type=block_type,
236
+ block_label=block_label,
213
237
  )
238
+
214
239
  self.original = original
240
+ self.converter = converter
215
241
 
216
242
  self.complete = original.complete
217
243
  self.translated = False
218
- self.cost = 0.0
244
+ self.cost = 0
219
245
  self.num_requests = 0
220
246
  self.tokens = 0
221
- self.processing_time = 0.0
247
+ self.processing_time = 0
222
248
 
223
249
  self.request_input_tokens = 0
224
250
  self.request_output_tokens = 0
@@ -276,6 +302,11 @@ class TranslatedCodeBlock(CodeBlock):
276
302
  children_sum = sum(c.total_num_requests for c in self.children)
277
303
  return children_sum + self.num_requests
278
304
 
305
+ @property
306
+ def total_processing_time(self) -> float:
307
+ children_sum = sum(c.total_processing_time for c in self.children)
308
+ return children_sum + self.processing_time
309
+
279
310
  @property
280
311
  def translation_completed(self) -> bool:
281
312
  """Whether or not the code block was successfully translated
@@ -297,3 +328,82 @@ class TranslatedCodeBlock(CodeBlock):
297
328
  if self.original.total_tokens
298
329
  else 0
299
330
  )
331
+
332
+ def to_codeblock(self) -> CodeBlock:
333
+ return CodeBlock(
334
+ id=self.id,
335
+ name=self.name,
336
+ node_type=self.node_type,
337
+ language=self.language,
338
+ text=self.text,
339
+ start_point=self.start_point,
340
+ end_point=self.end_point,
341
+ start_byte=self.start_byte,
342
+ end_byte=self.end_byte,
343
+ embedding_id=self.embedding_id,
344
+ tokens=self.tokens,
345
+ children=[child.to_codeblock() for child in self.children],
346
+ affixes=self.affixes,
347
+ previous_generations=self.previous_generations + [self],
348
+ block_type=self.block_type,
349
+ block_label=self.block_label,
350
+ )
351
+
352
+ def __iadd__(self, other):
353
+ self.cost += other.cost
354
+ self.num_requests += other.num_requests
355
+ self.processing_time += other.processing_time
356
+ self.request_input_tokens += other.request_input_tokens
357
+ self.request_output_tokens += other.request_output_tokens
358
+ return self
359
+
360
+
361
+ class BlockCollection:
362
+ def __init__(
363
+ self,
364
+ blocks: list[CodeBlock],
365
+ previous_generations: list[ForwardRef("BlockCollection")] = [],
366
+ ):
367
+ self.blocks = blocks
368
+ self.previous_generations = previous_generations
369
+
370
+ def to_codeblock(self) -> ForwardRef("BlockCollection"):
371
+ return BlockCollection(
372
+ [b.to_codeblock() for b in self.blocks], self.previous_generations + [self]
373
+ )
374
+
375
+ @property
376
+ def total_cost(self):
377
+ return sum(b.total_cost for b in self.blocks)
378
+
379
+ @property
380
+ def total_processing_time(self):
381
+ return sum(b.total_processing_time for b in self.blocks)
382
+
383
+ @property
384
+ def total_request_input_tokens(self):
385
+ return sum(b.total_request_input_tokens for b in self.blocks)
386
+
387
+ @property
388
+ def total_request_output_tokens(self):
389
+ return sum(b.total_request_output_tokens for b in self.blocks)
390
+
391
+ @property
392
+ def total_num_requests(self):
393
+ return sum(b.total_num_requests for b in self.blocks)
394
+
395
+ @property
396
+ def block_type(self):
397
+ return None
398
+
399
+ @property
400
+ def block_label(self):
401
+ return None
402
+
403
+ @property
404
+ def translation_completed(self):
405
+ return all(b.translation_completed for b in self.blocks)
406
+
407
+ @property
408
+ def complete(self):
409
+ return all(b.complete for b in self.blocks)
@@ -44,12 +44,18 @@ COST_PER_1K_TOKENS: dict[str, dict[str, float]] = {
44
44
  "anthropic.claude-instant-v1": {"input": 0.0008, "output": 0.0024},
45
45
  "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125},
46
46
  "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 0.003, "output": 0.015},
47
+ "anthropic.claude-3-5-sonnet-20240620-v1:0": {"input": 0.003, "output": 0.015},
48
+ "anthropic.claude-3-5-sonnet-20241022-v2:0": {"input": 0.003, "output": 0.015},
47
49
  "meta.llama2-13b-chat-v1": {"input": 0.00075, "output": 0.001},
48
50
  "meta.llama2-70b-chat-v1": {"input": 0.00195, "output": 0.00256},
49
51
  "meta.llama2-13b-v1": {"input": 0.0, "output": 0.0},
50
52
  "meta.llama2-70b-v1": {"input": 0.00265, "output": 0.0035},
51
53
  "meta.llama3-8b-instruct-v1:0": {"input": 0.0003, "output": 0.0006},
52
54
  "meta.llama3-70b-instruct-v1:0": {"input": 0.00265, "output": 0.0035},
55
+ "meta.llama3-3-70b-instruct-v1:0": {"input": 0.00072, "output": 0.00072},
56
+ "amazon.nova-lite-v1:0": {"input": 0.00006, "output": 0.00024},
57
+ "amazon.nova-micro-v1:0": {"input": 0.000035, "output": 0.00014},
58
+ "amazon.nova-pro-v1:0": {"input": 0.0008, "output": 0.0032},
53
59
  "amazon.titan-text-lite-v1": {"input": 0.00015, "output": 0.0002},
54
60
  "amazon.titan-text-express-v1": {"input": 0.0002, "output": 0.0006},
55
61
  "ai21.j2-mid-v1": {"input": 0.0125, "output": 0.0125},
janus/llm/models_info.py CHANGED
@@ -96,12 +96,16 @@ claude_models = [
96
96
  "bedrock-claude-haiku",
97
97
  "bedrock-claude-sonnet",
98
98
  "bedrock-claude-sonnet-3.5",
99
+ "bedrock-claude-sonnet-3.5-v2",
99
100
  ]
100
101
  llama2_models = [
101
102
  "bedrock-llama2-70b",
102
103
  "bedrock-llama2-70b-chat",
103
104
  "bedrock-llama2-13b",
104
105
  "bedrock-llama2-13b-chat",
106
+ "bedrock-llama3-8b-instruct",
107
+ "bedrock-llama3-70b-instruct",
108
+ "bedrock-llama3-3-70b-instruct",
105
109
  ]
106
110
  llama3_models = [
107
111
  "bedrock-llama3-8b-instruct",
@@ -113,6 +117,11 @@ titan_models = [
113
117
  "bedrock-jurassic-2-mid",
114
118
  "bedrock-jurassic-2-ultra",
115
119
  ]
120
+ nova_models = [
121
+ "bedrock-nova-lite",
122
+ "bedrock-nova-micro",
123
+ "bedrock-nova-pro",
124
+ ]
116
125
  cohere_models = [
117
126
  "bedrock-command-r-plus",
118
127
  ]
@@ -160,12 +169,17 @@ MODEL_ID_TO_LONG_ID = {
160
169
  "bedrock-claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
161
170
  "bedrock-claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
162
171
  "bedrock-claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
172
+ "bedrock-claude-sonnet-3.5-v2": "anthropic.claude-3-5-sonnet-20241022-v2:0",
163
173
  "bedrock-llama2-70b": "meta.llama2-70b-v1",
164
174
  "bedrock-llama2-70b-chat": "meta.llama2-70b-chat-v1",
165
175
  "bedrock-llama2-13b": "meta.llama2-13b-chat-v1",
166
176
  "bedrock-llama2-13b-chat": "meta.llama2-13b-v1",
167
177
  "bedrock-llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
168
178
  "bedrock-llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
179
+ "bedrock-llama3-3-70b-instruct": "meta.llama3-3-70b-instruct-v1:0",
180
+ "bedrock-nova-lite": "amazon.nova-lite-v1:0",
181
+ "bedrock-nova-micro": "amazon.nova-micro-v1:0",
182
+ "bedrock-nova-pro": "amazon.nova-pro-v1:0",
169
183
  "bedrock-titan-text-lite": "amazon.titan-text-lite-v1",
170
184
  "bedrock-titan-text-express": "amazon.titan-text-express-v1",
171
185
  "bedrock-jurassic-2-mid": "ai21.j2-mid-v1",
@@ -208,12 +222,17 @@ TOKEN_LIMITS: dict[str, int] = {
208
222
  "anthropic.claude-3-haiku-20240307-v1:0": 248_000,
209
223
  "anthropic.claude-3-sonnet-20240229-v1:0": 248_000,
210
224
  "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000,
225
+ "anthropic.claude-3-5-sonnet-20241022-v2:0": 200_000,
211
226
  "meta.llama2-70b-v1": 4096,
212
227
  "meta.llama2-70b-chat-v1": 4096,
213
228
  "meta.llama2-13b-chat-v1": 4096,
214
229
  "meta.llama2-13b-v1": 4096,
215
230
  "meta.llama3-8b-instruct-v1:0": 8000,
216
231
  "meta.llama3-70b-instruct-v1:0": 8000,
232
+ "meta.llama3-3-70b-instruct-v1:0": 128_000,
233
+ "amazon.nova-lite-v1:0": 300_000,
234
+ "amazon.nova-micro-v1:0": 128_000,
235
+ "amazon.nova-pro-v1:0": 300_000,
217
236
  "amazon.titan-text-lite-v1": 4096,
218
237
  "amazon.titan-text-express-v1": 8192,
219
238
  "ai21.j2-mid-v1": 8192,
@@ -1,11 +1,25 @@
1
1
  import unittest
2
2
 
3
- from janus.metrics.reading import _repeat_text, flesch, gunning_fog
3
+ from janus.metrics.reading import (
4
+ _repeat_text,
5
+ automated_readability,
6
+ coleman_liau,
7
+ dale_chall,
8
+ flesch,
9
+ flesch_grade,
10
+ gunning_fog,
11
+ word_count,
12
+ )
4
13
 
5
14
 
6
15
  class TestReading(unittest.TestCase):
7
16
  def setUp(self):
8
- self.text = "This is a sample text for testing readability metrics"
17
+ self.text = "This is a sample text for testing readability metrics."
18
+
19
+ def test_word_count(self):
20
+ """Test the word_count function."""
21
+ count = word_count(self.text)
22
+ self.assertEqual(count, 9)
9
23
 
10
24
  def test_repeat_text(self):
11
25
  """Test the _repeat_text function."""
@@ -16,12 +30,42 @@ class TestReading(unittest.TestCase):
16
30
  def test_flesch(self):
17
31
  """Test the Flesch readability score."""
18
32
  score = flesch(self.text)
19
- self.assertAlmostEqual(score, 47.3, places=2)
33
+ self.assertAlmostEqual(score, 45.42, places=2)
34
+
35
+ def test_flesch_grade(self):
36
+ """Test the Flesch Grade Level readability score."""
37
+ score = flesch_grade(self.text)
38
+ self.assertAlmostEqual(score, 9.2, places=2)
20
39
 
21
40
  def test_gunning_fog(self):
22
41
  """Test the Gunning-Fog readability score."""
23
42
  score = gunning_fog(self.text)
24
- self.assertAlmostEqual(score, 8.04, places=2)
43
+ self.assertAlmostEqual(score, 3.97, places=2)
44
+
45
+ def test_dale_chall(self):
46
+ """Test the Dale-Chall readability score."""
47
+ score = dale_chall(self.text)
48
+ self.assertAlmostEqual(score, 4.67, places=2)
49
+
50
+ def test_automated_readability(self):
51
+ """Test the Automated Readability Index score."""
52
+ score = automated_readability(self.text)
53
+ self.assertAlmostEqual(score, 7.1, places=2)
54
+
55
+ def test_coleman_liau(self):
56
+ """Test the Coleman-Liau Index."""
57
+ score = coleman_liau(self.text)
58
+ self.assertAlmostEqual(score, 9.94, places=2)
59
+
60
+ def test_blank_target(self):
61
+ """Test that blank targets return None for all metric functions."""
62
+ blank = " " # blank string with whitespaces
63
+ self.assertIsNone(flesch(blank))
64
+ self.assertIsNone(flesch_grade(blank))
65
+ self.assertIsNone(gunning_fog(blank))
66
+ self.assertIsNone(dale_chall(blank))
67
+ self.assertIsNone(automated_readability(blank))
68
+ self.assertIsNone(coleman_liau(blank))
25
69
 
26
70
 
27
71
  if __name__ == "__main__":
@@ -12,19 +12,13 @@ class TestRouge(unittest.TestCase):
12
12
  score = rouge(
13
13
  self.target, self.reference, granularity="n", n_gram=2, score_type="f"
14
14
  )
15
- self.assertIsInstance(score, float)
15
+ self.assertEqual(score, 0.5)
16
16
 
17
17
  def test_rouge_with_granularity_l(self):
18
18
  score = rouge(
19
19
  self.target, self.reference, granularity="l", n_gram=2, score_type="f"
20
20
  )
21
- self.assertIsInstance(score, float)
22
-
23
- def test_rouge_with_granularity_w(self):
24
- score = rouge(
25
- self.target, self.reference, granularity="w", n_gram=2, score_type="f"
26
- )
27
- self.assertIsInstance(score, float)
21
+ self.assertAlmostEqual(score, 0.8, places=2)
28
22
 
29
23
  def test_rouge_with_invalid_granularity(self):
30
24
  with self.assertRaises(ValueError):
@@ -40,19 +34,19 @@ class TestRouge(unittest.TestCase):
40
34
  score = rouge(
41
35
  self.target, self.reference, granularity="n", n_gram=2, score_type="f"
42
36
  )
43
- self.assertIsInstance(score, float)
37
+ self.assertAlmostEqual(score, 0.5, places=2)
44
38
 
45
39
  def test_rouge_with_score_type_p(self):
46
40
  score = rouge(
47
41
  self.target, self.reference, granularity="n", n_gram=2, score_type="p"
48
42
  )
49
- self.assertIsInstance(score, float)
43
+ self.assertAlmostEqual(score, 0.5, places=2)
50
44
 
51
45
  def test_rouge_with_score_type_r(self):
52
46
  score = rouge(
53
47
  self.target, self.reference, granularity="n", n_gram=2, score_type="r"
54
48
  )
55
- self.assertIsInstance(score, float)
49
+ self.assertAlmostEqual(score, 0.5, places=2)
56
50
 
57
51
  def test_rouge_with_invalid_score_type(self):
58
52
  with self.assertRaises(ValueError):