janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- janus/__init__.py +9 -1
- janus/__main__.py +4 -0
- janus/_tests/test_cli.py +128 -0
- janus/_tests/test_translate.py +49 -7
- janus/cli.py +530 -46
- janus/converter.py +50 -19
- janus/embedding/_tests/test_collections.py +2 -8
- janus/embedding/_tests/test_database.py +32 -0
- janus/embedding/_tests/test_vectorize.py +9 -4
- janus/embedding/collections.py +49 -6
- janus/embedding/embedding_models_info.py +130 -0
- janus/embedding/vectorize.py +53 -62
- janus/language/_tests/__init__.py +0 -0
- janus/language/_tests/test_combine.py +62 -0
- janus/language/_tests/test_splitter.py +16 -0
- janus/language/binary/_tests/test_binary.py +16 -1
- janus/language/binary/binary.py +10 -3
- janus/language/block.py +31 -30
- janus/language/combine.py +26 -34
- janus/language/mumps/_tests/test_mumps.py +2 -2
- janus/language/mumps/mumps.py +93 -9
- janus/language/naive/__init__.py +4 -0
- janus/language/naive/basic_splitter.py +14 -0
- janus/language/naive/chunk_splitter.py +26 -0
- janus/language/naive/registry.py +13 -0
- janus/language/naive/simple_ast.py +18 -0
- janus/language/naive/tag_splitter.py +61 -0
- janus/language/splitter.py +168 -74
- janus/language/treesitter/_tests/test_treesitter.py +19 -14
- janus/language/treesitter/treesitter.py +37 -13
- janus/llm/model_callbacks.py +177 -0
- janus/llm/models_info.py +165 -72
- janus/metrics/__init__.py +8 -0
- janus/metrics/_tests/__init__.py +0 -0
- janus/metrics/_tests/reference.py +2 -0
- janus/metrics/_tests/target.py +2 -0
- janus/metrics/_tests/test_bleu.py +56 -0
- janus/metrics/_tests/test_chrf.py +67 -0
- janus/metrics/_tests/test_file_pairing.py +59 -0
- janus/metrics/_tests/test_llm.py +91 -0
- janus/metrics/_tests/test_reading.py +28 -0
- janus/metrics/_tests/test_rouge_score.py +65 -0
- janus/metrics/_tests/test_similarity_score.py +23 -0
- janus/metrics/_tests/test_treesitter_metrics.py +110 -0
- janus/metrics/bleu.py +66 -0
- janus/metrics/chrf.py +55 -0
- janus/metrics/cli.py +7 -0
- janus/metrics/complexity_metrics.py +208 -0
- janus/metrics/file_pairing.py +113 -0
- janus/metrics/llm_metrics.py +202 -0
- janus/metrics/metric.py +466 -0
- janus/metrics/reading.py +70 -0
- janus/metrics/rouge_score.py +96 -0
- janus/metrics/similarity.py +53 -0
- janus/metrics/splitting.py +38 -0
- janus/parsers/_tests/__init__.py +0 -0
- janus/parsers/_tests/test_code_parser.py +32 -0
- janus/parsers/code_parser.py +24 -253
- janus/parsers/doc_parser.py +169 -0
- janus/parsers/eval_parser.py +80 -0
- janus/parsers/reqs_parser.py +72 -0
- janus/prompts/prompt.py +103 -30
- janus/translate.py +636 -111
- janus/utils/_tests/__init__.py +0 -0
- janus/utils/_tests/test_logger.py +67 -0
- janus/utils/_tests/test_progress.py +20 -0
- janus/utils/enums.py +56 -3
- janus/utils/progress.py +56 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
- janus_llm-2.0.1.dist-info/RECORD +94 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
- janus_llm-1.0.0.dist-info/RECORD +0 -48
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
- {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/metrics/metric.py
ADDED
@@ -0,0 +1,466 @@
import inspect
import json
from pathlib import Path
from typing import Callable, Optional

import click
import typer
from typing_extensions import Annotated

from janus.llm import load_model
from janus.utils.enums import LANGUAGES
from janus.utils.logger import create_logger

from ..utils.progress import track
from .cli import evaluate
from .file_pairing import FILE_PAIRING_METHODS
from .splitting import SPLITTING_METHODS

log = create_logger(__name__)


def metric(
    name: None | str = None,
    help: None | str = None,
    use_reference: bool = True,
) -> Callable:
    """Returns a decorator to add a given metric to the cli

    Metrics must follow the format (src_str, cmp_str, **other_params)

    Arguments:
        name: The name of the metric. If None, the function name is used.
        help: The help text for the metric.
        use_reference: Whether the metric requires a reference string.

    Returns:
        The decorator function.
    """

    def decorator(function):
        if use_reference:

            def func(
                out_file: Annotated[
                    str,
                    typer.Option("--out-file", "-o", help="Output JSON file to write."),
                ],
                language: Annotated[
                    Optional[str],
                    typer.Option(
                        "--language",
                        "-l",
                        help="The language of the source code.",
                        click_type=click.Choice(sorted(LANGUAGES)),
                    ),
                ] = None,
                target: Annotated[
                    Optional[str],
                    typer.Option(
                        "--target",
                        "-t",
                        help="Target file or string to evaluate.",
                    ),
                ] = None,
                reference: Annotated[
                    Optional[str],
                    typer.Option(
                        "--reference",
                        "-r",
                        help="Reference file or string to use as reference/baseline.",
                    ),
                ] = None,
                json_file_name: Annotated[
                    Optional[str],
                    typer.Option(
                        "--json",
                        "-j",
                        help="Json file to extract pairs from \
                            (if set ignores --target and --reference)",
                    ),
                ] = None,
                target_key: Annotated[
                    str,
                    typer.Option(
                        "--target-key",
                        "-tk",
                        help="json key to extract list of target strings",
                    ),
                ] = "target",
                reference_key: Annotated[
                    str,
                    typer.Option(
                        "--reference-key",
                        "-rk",
                        help="json key to extract list of reference strings",
                    ),
                ] = "reference",
                file_pairing_method: Annotated[
                    str,
                    typer.Option(
                        "--method",
                        "-m",
                        click_type=click.Choice(FILE_PAIRING_METHODS.keys()),
                        help="Method to use for pairing\
                            segments of target and reference files \
                            (ignored for json).",
                    ),
                ] = "file",
                llm_name: Annotated[
                    str,
                    typer.Option(
                        "--llm",
                        "-L",
                        help="The custom name of the model set with 'janus llm add'.",
                    ),
                ] = "gpt-3.5-turbo-0125",
                progress: Annotated[
                    bool,
                    typer.Option(
                        "--progress",
                        "-p",
                        help="Whether to display a progress bar.",
                        is_flag=True,
                    ),
                ] = False,
                use_strings: Annotated[
                    bool,
                    typer.Option(
                        "--string",
                        "-S",
                        help="Indicate that the target and reference are strings",
                        is_flag=True,
                    ),
                ] = False,
                *args,
                **kwargs,
            ):
                out = []
                llm, token_limit, model_cost = load_model(llm_name)
                if json_file_name is not None:
                    with open(json_file_name, "r") as f:
                        json_obj = json.load(f)
                    pairs = {}
                    for key in json_obj:
                        doc = json_obj[key]
                        ref = doc[reference_key]
                        experiments = doc["experiments"]
                        for model_key in experiments:
                            model_dict = experiments[model_key]
                            if not isinstance(model_dict, dict):
                                continue
                            if target_key not in model_dict:
                                continue
                            if model_key not in pairs:
                                pairs[model_key] = {}
                            for k in model_dict[target_key]:
                                pairs[model_key][k] = (model_dict[target_key][k], ref[k])
                elif target is not None and reference is not None:
                    if use_strings:
                        target_contents = target
                        reference_contents = reference
                    else:
                        with open(target, "r") as f:
                            target_contents = f.read()
                        with open(reference, "r") as f:
                            reference_contents = f.read()
                    pairs = FILE_PAIRING_METHODS[file_pairing_method](
                        target_contents,
                        reference_contents,
                        target_file=None if use_strings else target,
                        reference_file=None if use_strings else reference,
                        out_file=out_file,
                        lang=language,
                        llm=llm,
                        token_limit=token_limit,
                        model_cost=model_cost,
                    )
                else:
                    raise ValueError(
                        "Error, specify json or target and reference files/strings"
                    )
                if isinstance(pairs, dict):
                    out = {}
                    for k in pairs:
                        out[k] = apply_function_pairs(
                            pairs[k],
                            function,
                            progress,
                            language,
                            llm,
                            token_limit,
                            model_cost,
                            *args,
                            **kwargs,
                        )
                else:
                    out = apply_function_pairs(
                        pairs,
                        function,
                        progress,
                        language,
                        llm,
                        token_limit,
                        model_cost,
                        *args,
                        **kwargs,
                    )
                out_file = Path(out_file)
                out_file.parent.mkdir(parents=True, exist_ok=True)
                with open(out_file, "w") as f:
                    json.dump(out, f)
                log.info(f"Saved results to file: {out_file}")

            sig1 = inspect.signature(function)
            sig2 = inspect.signature(func)
            func.__signature__ = sig2.replace(
                parameters=tuple(
                    list(sig2.parameters.values())[:11]
                    + list(sig1.parameters.values())[2:-1]
                )
            )
        else:

            def func(
                out_file: Annotated[
                    str,
                    typer.Option("--out-file", "-o", help="Output JSON file to write."),
                ],
                language: Annotated[
                    Optional[str],
                    typer.Option(
                        "--language",
                        "-l",
                        help="The language of the source code.",
                        click_type=click.Choice(sorted(LANGUAGES)),
                    ),
                ] = None,
                target: Annotated[
                    Optional[str],
                    typer.Option(
                        "--target", "-t", help="Target file or string to evaluate."
                    ),
                ] = None,
                json_file_name: Annotated[
                    Optional[str],
                    typer.Option(
                        "--json",
                        "-j",
                        help="Json file to extract pairs from \
                            (if set ignores --target)",
                    ),
                ] = None,
                target_key: Annotated[
                    str,
                    typer.Option(
                        "--target-key",
                        "-tk",
                        help="json key to extract list of target strings",
                    ),
                ] = "target",
                splitting_method: Annotated[
                    str,
                    typer.Option(
                        "--method",
                        "-m",
                        click_type=click.Choice(SPLITTING_METHODS.keys()),
                        help="Method to use for pairing\
                            segments of target and reference files.",
                    ),
                ] = "file",
                llm_name: Annotated[
                    str,
                    typer.Option(
                        "--llm",
                        "-L",
                        help="The custom name of the model set with 'janus llm add'.",
                    ),
                ] = "gpt-3.5-turbo-0125",
                progress: Annotated[
                    bool,
                    typer.Option(
                        "--progress",
                        "-p",
                        help="Whether to display a progress bar.",
                        is_flag=True,
                    ),
                ] = False,
                use_strings: Annotated[
                    bool,
                    typer.Option(
                        "--string",
                        "-S",
                        help="Indicate that the target and reference are strings",
                        is_flag=True,
                    ),
                ] = False,
                *args,
                **kwargs,
            ):
                llm, token_limit, model_cost = load_model(llm_name)
                if json_file_name is not None:
                    with open(json_file_name, "r") as f:
                        json_obj = json.load(f)
                    strings = {}
                    for key in json_obj:
                        doc = json_obj[key]
                        experiments = doc["experiments"]
                        for model_key in experiments:
                            model_dict = experiments[model_key]
                            if not isinstance(model_dict, dict):
                                continue
                            if target_key not in model_dict:
                                continue
                            if model_key not in strings:
                                strings[model_key] = {}
                            for k in model_dict[target_key]:
                                strings[model_key][k] = model_dict[target_key][k]
                        # strings += list(json_obj[key][target_key].values())
                elif target is not None:
                    if use_strings:
                        target_contents = target
                    else:
                        with open(target, "r") as f:
                            target_contents = f.read()

                    strings = SPLITTING_METHODS[splitting_method](
                        target_contents,
                        target_file=target if not use_strings else None,
                        out_file=out_file,
                        lang=language,
                        llm=llm,
                        token_limit=token_limit,
                        model_cost=model_cost,
                    )
                else:
                    raise ValueError(
                        "Error: must specify either json file or target file/string"
                    )
                if isinstance(strings, dict):
                    out = {}
                    for k in strings:
                        out[k] = apply_function_strings(
                            strings[k],
                            function,
                            progress,
                            language,
                            llm,
                            token_limit,
                            model_cost,
                            *args,
                            **kwargs,
                        )
                else:
                    out = apply_function_strings(
                        strings,
                        function,
                        progress,
                        language,
                        llm,
                        token_limit,
                        model_cost,
                        *args,
                        **kwargs,
                    )
                out_file = Path(out_file)
                out_file.parent.mkdir(parents=True, exist_ok=True)
                with open(out_file, "w") as f:
                    json.dump(out, f)
                log.info(f"Saved results to file: {out_file}")

            sig1 = inspect.signature(function)
            sig2 = inspect.signature(func)
            func.__signature__ = sig2.replace(
                parameters=tuple(
                    list(sig2.parameters.values())[:9]
                    + list(sig1.parameters.values())[1:-1]
                )
            )
        if name is None:
            func.__name__ = function.__name__
        else:
            func.__name__ = name
        if help is None:
            func = evaluate.command()(func)
        else:
            func = evaluate.command(help=help)(func)
        return function

    return decorator


def apply_function_pairs(
    pairs,
    function,
    progress,
    language,
    llm,
    token_limit,
    model_cost,
    *args,
    **kwargs,
):
    out = []
    pair_keys = None
    if isinstance(pairs, dict):
        pair_keys = list(pairs.keys())
        pair_values = list(pairs.values())
    else:
        pair_values = pairs
    if progress:
        loop = track(pair_values, description="Evaluating pairs")
    else:
        loop = pair_values
    for src, cmp in loop:
        if not (isinstance(src, str) and isinstance(cmp, str)):
            out.append(False)
        else:
            out.append(
                function(
                    src,
                    cmp,
                    *args,
                    **kwargs,
                    language=language,
                    llm=llm,
                    token_limit=token_limit,
                    model_cost=model_cost,
                )
            )
    if pair_keys is not None:
        return {k: v for k, v in zip(pair_keys, out)}
    return out


def apply_function_strings(
    strings, function, progress, language, llm, token_limit, model_cost, *args, **kwargs
):
    out = []
    string_keys = None
    if isinstance(strings, dict):
        string_keys = list(strings.keys())
        string_values = list(strings.values())
    else:
        string_values = strings
    if progress:
        loop = track(string_values, description="Evaluating strings")
    else:
        loop = string_values
    for string in loop:
        if not isinstance(string, str):
            out.append(False)
        else:
            out.append(
                function(
                    string,
                    *args,
                    **kwargs,
                    language=language,
                    llm=llm,
                    token_limit=token_limit,
                    model_cost=model_cost,
                )
            )
    if string_keys is not None:
        return {k: v for k, v in zip(string_keys, out)}
    return out
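Usage note (not part of the diff): the decorator above wires a plain scoring function into the `evaluate` Typer app, injecting the shared CLI options (output file, language, target/reference or JSON input, LLM name, and so on); because it returns the original function, the metric also stays directly callable from Python. A minimal sketch of how a reference-based metric might be declared with it; the `exact_match` name is hypothetical and not part of the package:

    from janus.metrics.metric import metric

    @metric(help="Exact string match")
    def exact_match(target: str, reference: str, **kwargs) -> float:
        # Follows the (src_str, cmp_str, **other_params) convention; the extra
        # kwargs passed in by the wrapper (language, llm, token_limit,
        # model_cost) are accepted but ignored here.
        return float(target.strip() == reference.strip())

Any parameters declared between the two strings and the trailing **kwargs would be exposed as additional CLI options, which is how the signature-splicing at the end of the decorator is used by the bundled metrics.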
janus/metrics/reading.py
ADDED
@@ -0,0 +1,70 @@
import nltk
import readability

from .metric import metric


def _repeat_text(text):
    """Repeats a string until its length is over 100 words.

    Arguments:
        text: The input string.

    Returns:
        A string repeated to have more than 100 words.
    """
    # Strip to remove a newline
    text = text.strip()

    # Check if the text ends with a period
    if not text.endswith("."):
        text += "."  # Add a period if missing

    # Check if repeated text is long enough, repeat more if needed
    repeated_text = text
    while len(repeated_text.split()) < 100:
        repeated_text += " " + text

    return repeated_text


def get_readability(target: str) -> readability.Readability:
    """Create a Readability object from an input string

    Arguments:
        target: The target text.

    Returns:
        py-readability-metrics Readability object for that text
    """
    nltk.download("punkt", quiet=True)
    target = _repeat_text(target)
    return readability.Readability(target)


@metric(use_reference=False, help="The Flesch Readability score")
def flesch(target: str, **kwargs) -> float:
    """Calculate the Flesch Score using py-readability-metrics.

    Arguments:
        target: The target text.

    Returns:
        The Flesch score.
    """

    return get_readability(target).flesch().score


@metric(use_reference=False, help="The Gunning-Fog Readability score")
def gunning_fog(target: str, **kwargs) -> float:
    """Calculate the Gunning-Fog Score using py-readability-metrics.

    Arguments:
        target: The target text.

    Returns:
        The Gunning-Fog score.
    """

    return get_readability(target).gunning_fog().score
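Note (not part of the diff): py-readability-metrics expects roughly 100 words of input, which is why `_repeat_text` pads short snippets before scoring. Since `@metric` returns the undecorated function, the readability metrics can also be called in-process; a small sketch, assuming the nltk and py-readability-metrics dependencies are installed:

    from janus.metrics.reading import flesch, gunning_fog

    snippet = "Read the configuration file and return the parsed settings."
    print(flesch(snippet))       # higher Flesch scores indicate easier reading
    print(gunning_fog(snippet))  # lower Gunning-Fog scores indicate easier reading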
janus/metrics/rouge_score.py
ADDED
@@ -0,0 +1,96 @@
import click
import nltk
import typer
from rouge import Rouge
from typing_extensions import Annotated

from .metric import metric


@metric(help="ROUGE score")
def rouge(
    target: str,
    reference: str,
    granularity: Annotated[
        str,
        typer.Option(
            "--granularity",
            "-g",
            help=(
                "The granularity of the ROUGE score. `n` refers to "
                "ROUGE-N, `l` refers to ROUGE-L, and `w` refers to ROUGE-W."
            ),
            click_type=click.Choice(["n", "l", "w"]),
        ),
    ] = "n",
    n_gram: Annotated[
        int,
        typer.Option(
            "--n-gram",
            "-n",
            help=("The n-gram overlap calculated for ROUGE-N. Can be an integer."),
        ),
    ] = 2,
    score_type: Annotated[
        str,
        typer.Option(
            "--score",
            "-s",
            help=(
                "Whether to use the F-score, precision, or recall. For example, `f` "
                "refers to the F-score, `p` refers to precision, and `r` refers to "
                "recall."
            ),
            click_type=click.Choice(["f", "p", "r"]),
        ),
    ] = "f",
    **kwargs,
) -> float:
    """Calculate the ROUGE Score.

    Arguments:
        target: The target text.
        reference: The reference text.
        granularity: The granularity of the ROUGE score. `n` refers to ROUGE-N, `l`
            refers to ROUGE-L, and `w` refers to ROUGE-W.
        n_gram: The n-gram overlap calculated for ROUGE-N. Can be an integer.
        score_type: Whether to use the F-score, precision, or recall. For example, `f`
            refers to the F-score, `p` refers to precision, and `r` refers to recall.

    Returns:
        The ROUGE score.
    """
    nltk.download("punkt", quiet=True)

    if granularity.lower() == "n":
        metric_name = "rouge-n"
        metric_name_output = f"rouge-{n_gram}"
        max_n = n_gram
    elif granularity.lower() == "l":
        metric_name = "rouge-l"
        metric_name_output = "rouge-l"
        max_n = 4
    elif granularity.lower() == "w":
        metric_name = "rouge-w"
        metric_name_output = "rouge-w"
        max_n = 4
    else:
        raise ValueError("Invalid granularity. Must be one of `n`, `l`, or `w`.")

    if score_type.lower() not in ["f", "p", "r"]:
        raise ValueError("Invalid score type. Must be one of `f`, `p`, or `r`.")

    evaluator = Rouge(
        metrics=[metric_name],
        max_n=max_n,
        limit_length=False,
        length_limit=1_000,
        length_limit_type="words",
        apply_avg=False,
        apply_best=False,
        alpha=0.5,  # Default F1_score
        weight_factor=1.2,
        stemming=True,
    )
    scores = evaluator.get_scores(target, reference)
    return scores[metric_name_output][0][score_type.lower()][0]
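Usage sketch (not part of the diff): like the other metrics, `rouge` remains an ordinary function after decoration, so it can be scored in-process as well as through the CLI. A short example using the defaults from the signature above:

    from janus.metrics.rouge_score import rouge

    score = rouge(
        "the translated routine opens the file and reads it",
        "the routine opens the file and reads the contents",
        granularity="n",   # ROUGE-N
        n_gram=2,          # bigram overlap
        score_type="f",    # F-score
    )
    print(score)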
janus/metrics/similarity.py
ADDED
@@ -0,0 +1,53 @@
import click
import typer
from langchain.evaluation import EmbeddingDistance, load_evaluator
from typing_extensions import Annotated

from ..embedding.embedding_models_info import load_embedding_model
from .metric import metric


@metric(name="similarity-score", help="Distance between embeddings of strings.")
def similarity_score(
    target: str,
    reference: str,
    model_name: Annotated[
        str,
        typer.Option("-e", "--embedding-model", help="Name of embedding model to use."),
    ] = "text-embedding-3-small",
    distance_metric: Annotated[
        str,
        typer.Option(
            "-d",
            "--distance-metric",
            click_type=click.Choice([e.value for e in list(EmbeddingDistance)]),
            help="Distance metric to use.",
        ),
    ] = "cosine",
    **kwargs,
) -> float:
    """Computes the similarity score of two strings

    Arguments:
        target: The target string.
        reference: The reference string.
        model_name: The name of the embedding model to use.
        distance_metric: The distance metric to use. Can be one of:
            - cosine
            - euclidean
            - manhattan
            - chebyshev
            - hamming

    Returns:
        The similarity score of the two strings.
    """
    embedding_model, _, _ = load_embedding_model(model_name)
    evaluator = load_evaluator(
        "pairwise_embedding_distance",
        embeddings=embedding_model,
        distance_metric=distance_metric,
    )
    return evaluator.evaluate_string_pairs(prediction=target, prediction_b=reference)[
        "score"
    ]
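Note (not part of the diff): the returned value is an embedding distance from LangChain's pairwise_embedding_distance evaluator, so lower scores mean the two strings are closer. A minimal sketch, assuming credentials for the default OpenAI embedding model are configured via the package's embedding setup:

    from janus.metrics.similarity import similarity_score

    distance = similarity_score(
        "def add(a, b): return a + b",
        "def add(x, y): return x + y",
        model_name="text-embedding-3-small",
        distance_metric="cosine",
    )
    print(distance)  # smaller distance -> more similar strings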