janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (74) hide show
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,466 @@
1
+ import inspect
2
+ import json
3
+ from pathlib import Path
4
+ from typing import Callable, Optional
5
+
6
+ import click
7
+ import typer
8
+ from typing_extensions import Annotated
9
+
10
+ from janus.llm import load_model
11
+ from janus.utils.enums import LANGUAGES
12
+ from janus.utils.logger import create_logger
13
+
14
+ from ..utils.progress import track
15
+ from .cli import evaluate
16
+ from .file_pairing import FILE_PAIRING_METHODS
17
+ from .splitting import SPLITTING_METHODS
18
+
19
+ log = create_logger(__name__)
20
+
21
+
22
def _collect_json_targets(json_obj, target_key, reference_key=None):
    """Gather per-model target strings (optionally paired with references).

    Walks every document in ``json_obj``, looks at each entry under its
    ``"experiments"`` key and, for entries that are dicts containing
    ``target_key``, copies the values found there.

    Arguments:
        json_obj: The parsed JSON mapping of documents.
        target_key: Key inside each experiment entry holding the targets.
        reference_key: Key inside each document holding the references. When
            None, bare target values are collected instead of
            ``(target, reference)`` tuples.

    Returns:
        A dict mapping experiment (model) key -> {item key -> value}.
    """
    collected = {}
    for doc_key in json_obj:
        doc = json_obj[doc_key]
        # Look the reference up before touching the experiments so that a
        # missing reference key fails on the document itself.
        references = doc[reference_key] if reference_key is not None else None
        experiments = doc["experiments"]
        for model_key in experiments:
            model_dict = experiments[model_key]
            # Non-dict entries and entries without targets are skipped.
            if not isinstance(model_dict, dict) or target_key not in model_dict:
                continue
            per_model = collected.setdefault(model_key, {})
            targets = model_dict[target_key]
            for item_key in targets:
                if reference_key is None:
                    per_model[item_key] = targets[item_key]
                else:
                    per_model[item_key] = (targets[item_key], references[item_key])
    return collected


def _read_text(path):
    """Return the full contents of the text file at ``path``."""
    with open(path, "r") as f:
        return f.read()


def _apply_grouped(apply, items, /, *call_args, **call_kwargs):
    """Run ``apply`` once per group when ``items`` is a dict, else once overall.

    The first two parameters are positional-only so that metric options
    forwarded through ``call_kwargs`` can never collide with them.
    """
    if isinstance(items, dict):
        return {
            key: apply(items[key], *call_args, **call_kwargs) for key in items
        }
    return apply(items, *call_args, **call_kwargs)


def _save_results(results, out_file):
    """Write ``results`` as JSON to ``out_file``, creating parent directories."""
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(results, f)
    log.info(f"Saved results to file: {out_path}")


def _merge_cli_signature(wrapper, function, skip):
    """Expose the metric's own options on the CLI wrapper's signature.

    The wrapper's named parameters are kept, its ``*args``/``**kwargs`` are
    dropped, and the metric function's parameters after the first ``skip``
    (the text inputs) are appended, again without variadic parameters.

    Arguments:
        wrapper: The CLI callback whose ``__signature__`` is replaced.
        function: The decorated metric function.
        skip: Number of leading text parameters of ``function`` to omit.
    """
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    wrapper_sig = inspect.signature(wrapper)
    own = [p for p in wrapper_sig.parameters.values() if p.kind not in variadic]
    metric_params = list(inspect.signature(function).parameters.values())[skip:]
    extra = [p for p in metric_params if p.kind not in variadic]
    wrapper.__signature__ = wrapper_sig.replace(parameters=tuple(own + extra))


def metric(
    name: None | str = None,
    help: None | str = None,
    use_reference: bool = True,
) -> Callable:
    """Returns a decorator to add a given metric to the cli

    Metrics must follow the format (src_str, cmp_str, **other_params) when
    `use_reference` is True, or (src_str, **other_params) otherwise. Any
    further named parameters of the metric are exposed as options of the
    generated command.

    Arguments:
        name: The name of the metric. If None, the function name is used.
        help: The help text for the metric.
        use_reference: Whether the metric requires a reference string.

    Returns:
        The decorator function. The decorator registers a command on
        `evaluate` and returns the decorated function unchanged.
    """

    def decorator(function):
        if use_reference:

            def func(
                out_file: Annotated[
                    str,
                    typer.Option("--out-file", "-o", help="Output JSON file to write."),
                ],
                language: Annotated[
                    Optional[str],
                    typer.Option(
                        "--language",
                        "-l",
                        help="The language of the source code.",
                        click_type=click.Choice(sorted(LANGUAGES)),
                    ),
                ] = None,
                target: Annotated[
                    Optional[str],
                    typer.Option(
                        "--target",
                        "-t",
                        help="Target file or string to evaluate.",
                    ),
                ] = None,
                reference: Annotated[
                    Optional[str],
                    typer.Option(
                        "--reference",
                        "-r",
                        help="Reference file or string to use as reference/baseline.",
                    ),
                ] = None,
                json_file_name: Annotated[
                    Optional[str],
                    typer.Option(
                        "--json",
                        "-j",
                        help=(
                            "Json file to extract pairs from "
                            "(if set ignores --target and --reference)"
                        ),
                    ),
                ] = None,
                target_key: Annotated[
                    str,
                    typer.Option(
                        "--target-key",
                        "-tk",
                        help="json key to extract list of target strings",
                    ),
                ] = "target",
                reference_key: Annotated[
                    str,
                    typer.Option(
                        "--reference-key",
                        "-rk",
                        help="json key to extract list of reference strings",
                    ),
                ] = "reference",
                file_pairing_method: Annotated[
                    str,
                    typer.Option(
                        "--method",
                        "-m",
                        click_type=click.Choice(FILE_PAIRING_METHODS.keys()),
                        help=(
                            "Method to use for pairing segments of target and "
                            "reference files (ignored for json)."
                        ),
                    ),
                ] = "file",
                llm_name: Annotated[
                    str,
                    typer.Option(
                        "--llm",
                        "-L",
                        help="The custom name of the model set with 'janus llm add'.",
                    ),
                ] = "gpt-3.5-turbo-0125",
                progress: Annotated[
                    bool,
                    typer.Option(
                        "--progress",
                        "-p",
                        help="Whether to display a progress bar.",
                        is_flag=True,
                    ),
                ] = False,
                use_strings: Annotated[
                    bool,
                    typer.Option(
                        "--string",
                        "-S",
                        help="Indicate that the target and reference are strings",
                        is_flag=True,
                    ),
                ] = False,
                *args,
                **kwargs,
            ):
                llm, token_limit, model_cost = load_model(llm_name)
                if json_file_name is not None:
                    # JSON input takes precedence over --target/--reference.
                    with open(json_file_name, "r") as f:
                        json_obj = json.load(f)
                    pairs = _collect_json_targets(json_obj, target_key, reference_key)
                elif target is not None and reference is not None:
                    if use_strings:
                        target_contents = target
                        reference_contents = reference
                    else:
                        target_contents = _read_text(target)
                        reference_contents = _read_text(reference)
                    pairs = FILE_PAIRING_METHODS[file_pairing_method](
                        target_contents,
                        reference_contents,
                        target_file=None if use_strings else target,
                        reference_file=None if use_strings else reference,
                        out_file=out_file,
                        lang=language,
                        llm=llm,
                        token_limit=token_limit,
                        model_cost=model_cost,
                    )
                else:
                    raise ValueError(
                        "Error, specify json or target and reference files/strings"
                    )
                out = _apply_grouped(
                    apply_function_pairs,
                    pairs,
                    function,
                    progress,
                    language,
                    llm,
                    token_limit,
                    model_cost,
                    *args,
                    **kwargs,
                )
                _save_results(out, out_file)

            # The metric's first two parameters are the target/reference texts.
            _merge_cli_signature(func, function, skip=2)
        else:

            def func(
                out_file: Annotated[
                    str,
                    typer.Option("--out-file", "-o", help="Output JSON file to write."),
                ],
                language: Annotated[
                    Optional[str],
                    typer.Option(
                        "--language",
                        "-l",
                        help="The language of the source code.",
                        click_type=click.Choice(sorted(LANGUAGES)),
                    ),
                ] = None,
                target: Annotated[
                    Optional[str],
                    typer.Option(
                        "--target", "-t", help="Target file or string to evaluate."
                    ),
                ] = None,
                json_file_name: Annotated[
                    Optional[str],
                    typer.Option(
                        "--json",
                        "-j",
                        help=(
                            "Json file to extract target strings from "
                            "(if set ignores --target)"
                        ),
                    ),
                ] = None,
                target_key: Annotated[
                    str,
                    typer.Option(
                        "--target-key",
                        "-tk",
                        help="json key to extract list of target strings",
                    ),
                ] = "target",
                splitting_method: Annotated[
                    str,
                    typer.Option(
                        "--method",
                        "-m",
                        click_type=click.Choice(SPLITTING_METHODS.keys()),
                        help=(
                            "Method to use for splitting the target into "
                            "segments (ignored for json)."
                        ),
                    ),
                ] = "file",
                llm_name: Annotated[
                    str,
                    typer.Option(
                        "--llm",
                        "-L",
                        help="The custom name of the model set with 'janus llm add'.",
                    ),
                ] = "gpt-3.5-turbo-0125",
                progress: Annotated[
                    bool,
                    typer.Option(
                        "--progress",
                        "-p",
                        help="Whether to display a progress bar.",
                        is_flag=True,
                    ),
                ] = False,
                use_strings: Annotated[
                    bool,
                    typer.Option(
                        "--string",
                        "-S",
                        help="Indicate that the target is a string",
                        is_flag=True,
                    ),
                ] = False,
                *args,
                **kwargs,
            ):
                llm, token_limit, model_cost = load_model(llm_name)
                if json_file_name is not None:
                    # JSON input takes precedence over --target.
                    with open(json_file_name, "r") as f:
                        json_obj = json.load(f)
                    strings = _collect_json_targets(json_obj, target_key)
                elif target is not None:
                    target_contents = target if use_strings else _read_text(target)
                    strings = SPLITTING_METHODS[splitting_method](
                        target_contents,
                        target_file=target if not use_strings else None,
                        out_file=out_file,
                        lang=language,
                        llm=llm,
                        token_limit=token_limit,
                        model_cost=model_cost,
                    )
                else:
                    raise ValueError(
                        "Error: must specify either json file or target file/string"
                    )
                out = _apply_grouped(
                    apply_function_strings,
                    strings,
                    function,
                    progress,
                    language,
                    llm,
                    token_limit,
                    model_cost,
                    *args,
                    **kwargs,
                )
                _save_results(out, out_file)

            # The metric's first parameter is the target text.
            _merge_cli_signature(func, function, skip=1)

        # The wrapper's __name__ is set before registration so the command
        # is registered under the metric's name.
        func.__name__ = function.__name__ if name is None else name
        if help is None:
            evaluate.command()(func)
        else:
            evaluate.command(help=help)(func)
        # Hand back the undecorated metric so it stays directly callable.
        return function

    return decorator
390
+
391
+
392
def apply_function_pairs(
    pairs,
    function,
    progress,
    language,
    llm,
    token_limit,
    model_cost,
    *args,
    **kwargs,
):
    """Evaluate a two-text metric over a collection of (source, compare) pairs.

    Arguments:
        pairs: Either an iterable of 2-item pairs or a dict whose values are
            such pairs.
        function: The metric, called as
            ``function(src, cmp, *args, **kwargs, language=..., llm=...,
            token_limit=..., model_cost=...)``.
        progress: When truthy, iteration is wrapped in `track`.
        language: Forwarded to the metric as ``language``.
        llm: Forwarded to the metric as ``llm``.
        token_limit: Forwarded to the metric as ``token_limit``.
        model_cost: Forwarded to the metric as ``model_cost``.

    Returns:
        A list of results in input order, or a dict with the same keys as
        ``pairs`` when a dict was given. Pairs whose members are not both
        strings yield ``False`` instead of a metric value.
    """
    keyed = isinstance(pairs, dict)
    labels = list(pairs.keys()) if keyed else None
    items = list(pairs.values()) if keyed else pairs
    iterable = track(items, description="Evaluating pairs") if progress else items

    results = []
    for left, right in iterable:
        if isinstance(left, str) and isinstance(right, str):
            score = function(
                left,
                right,
                *args,
                **kwargs,
                language=language,
                llm=llm,
                token_limit=token_limit,
                model_cost=model_cost,
            )
        else:
            # Non-string members are not scored.
            score = False
        results.append(score)

    return dict(zip(labels, results)) if keyed else results
433
+
434
+
435
def apply_function_strings(
    strings, function, progress, language, llm, token_limit, model_cost, *args, **kwargs
):
    """Evaluate a single-text metric over a collection of strings.

    Arguments:
        strings: Either an iterable of strings or a dict whose values are
            strings.
        function: The metric, called as
            ``function(string, *args, **kwargs, language=..., llm=...,
            token_limit=..., model_cost=...)``.
        progress: When truthy, iteration is wrapped in `track`.
        language: Forwarded to the metric as ``language``.
        llm: Forwarded to the metric as ``llm``.
        token_limit: Forwarded to the metric as ``token_limit``.
        model_cost: Forwarded to the metric as ``model_cost``.

    Returns:
        A list of results in input order, or a dict with the same keys as
        ``strings`` when a dict was given. Non-string items yield ``False``
        instead of a metric value.
    """

    def score(text):
        # Non-string items are not scored.
        if not isinstance(text, str):
            return False
        return function(
            text,
            *args,
            **kwargs,
            language=language,
            llm=llm,
            token_limit=token_limit,
            model_cost=model_cost,
        )

    if isinstance(strings, dict):
        names, texts = list(strings.keys()), list(strings.values())
    else:
        names, texts = None, strings

    if progress:
        texts = track(texts, description="Evaluating strings")

    scores = [score(text) for text in texts]
    if names is None:
        return scores
    return dict(zip(names, scores))
@@ -0,0 +1,70 @@
1
+ import nltk
2
+ import readability
3
+
4
+ from .metric import metric
5
+
6
+
7
+ def _repeat_text(text):
8
+ """Repeats a string until its length is over 100 words.
9
+
10
+ Arguments:
11
+ text: The input string.
12
+
13
+ Returns:
14
+ A string repeated to have more than 100 words.
15
+ """
16
+ # Strip to remove a newline
17
+ text = text.strip()
18
+
19
+ # Check if the text ends with a period
20
+ if not text.endswith("."):
21
+ text += "." # Add a period if missing
22
+
23
+ # Check if repeated text is long enough, repeat more if needed
24
+ repeated_text = text
25
+ while len(repeated_text.split()) < 100:
26
+ repeated_text += " " + text
27
+
28
+ return repeated_text
29
+
30
+
31
def get_readability(target: str) -> readability.Readability:
    """Build a Readability object for the given text.

    The NLTK "punkt" resource is requested (quietly) first, and the text is
    padded via `_repeat_text` before being handed to py-readability-metrics.

    Arguments:
        target: The target text.

    Returns:
        py-readability-metrics Readability object for the padded text
    """
    nltk.download("punkt", quiet=True)
    padded = _repeat_text(target)
    analysis = readability.Readability(padded)
    return analysis
43
+
44
+
45
@metric(use_reference=False, help="The Flesch Readability score")
def flesch(target: str, **kwargs) -> float:
    """Compute the Flesch score of a text with py-readability-metrics.

    Arguments:
        target: The target text.

    Returns:
        The Flesch score.
    """
    analysis = get_readability(target)
    result = analysis.flesch()
    return result.score
57
+
58
+
59
@metric(use_reference=False, help="The Gunning-Fog Readability score")
def gunning_fog(target: str, **kwargs) -> float:
    """Compute the Gunning-Fog score of a text with py-readability-metrics.

    Arguments:
        target: The target text.

    Returns:
        The Gunning-Fog score.
    """
    analysis = get_readability(target)
    result = analysis.gunning_fog()
    return result.score
@@ -0,0 +1,96 @@
1
+ import click
2
+ import nltk
3
+ import typer
4
+ from rouge import Rouge
5
+ from typing_extensions import Annotated
6
+
7
+ from .metric import metric
8
+
9
+
10
@metric(help="ROUGE score")
def rouge(
    target: str,
    reference: str,
    granularity: Annotated[
        str,
        typer.Option(
            "--granularity",
            "-g",
            help=(
                "The granularity of the ROUGE score. `n` refers to "
                "ROUGE-N, `l` refers to ROUGE-L, and `w` refers to ROUGE-W."
            ),
            click_type=click.Choice(["n", "l", "w"]),
        ),
    ] = "n",
    n_gram: Annotated[
        int,
        typer.Option(
            "--n-gram",
            "-n",
            help=("The n-gram overlap calculated for ROUGE-N. Can be an integer."),
        ),
    ] = 2,
    score_type: Annotated[
        str,
        typer.Option(
            "--score",
            "-s",
            help=(
                "Whether to use the F-score, precision, or recall. For example, `f` "
                "refers to the F-score, `p` refers to precision, and `r` refers to "
                "recall."
            ),
            click_type=click.Choice(["f", "p", "r"]),
        ),
    ] = "f",
    **kwargs,
) -> float:
    """Compute a ROUGE score between a target and a reference text.

    Arguments:
        target: The target text.
        reference: The reference text.
        granularity: Which ROUGE variant to compute (case-insensitive): `n`
            for ROUGE-N, `l` for ROUGE-L, `w` for ROUGE-W.
        n_gram: The n-gram size used for ROUGE-N.
        score_type: Which component to return (case-insensitive): `f` for the
            F-score, `p` for precision, `r` for recall.

    Returns:
        The selected component of the requested ROUGE score.

    Raises:
        ValueError: If `granularity` or `score_type` is not one of the
            accepted letters.
    """
    nltk.download("punkt", quiet=True)

    level = granularity.lower()
    if level == "n":
        # ROUGE-N results are keyed by the concrete n, e.g. "rouge-2".
        metric_name, result_key, max_n = "rouge-n", f"rouge-{n_gram}", n_gram
    elif level in ("l", "w"):
        metric_name = result_key = f"rouge-{level}"
        max_n = 4
    else:
        raise ValueError("Invalid granularity. Must be one of `n`, `l`, or `w`.")

    component = score_type.lower()
    if component not in ("f", "p", "r"):
        raise ValueError("Invalid score type. Must be one of `f`, `p`, or `r`.")

    evaluator = Rouge(
        metrics=[metric_name],
        max_n=max_n,
        limit_length=False,
        length_limit=1_000,
        length_limit_type="words",
        apply_avg=False,
        apply_best=False,
        alpha=0.5,  # Default F1_score
        weight_factor=1.2,
        stemming=True,
    )
    all_scores = evaluator.get_scores(target, reference)
    per_hypothesis = all_scores[result_key][0]
    return per_hypothesis[component][0]
@@ -0,0 +1,53 @@
1
+ import click
2
+ import typer
3
+ from langchain.evaluation import EmbeddingDistance, load_evaluator
4
+ from typing_extensions import Annotated
5
+
6
+ from ..embedding.embedding_models_info import load_embedding_model
7
+ from .metric import metric
8
+
9
+
10
@metric(name="similarity-score", help="Distance between embeddings of strings.")
def similarity_score(
    target: str,
    reference: str,
    model_name: Annotated[
        str,
        typer.Option("-e", "--embedding-model", help="Name of embedding model to use."),
    ] = "text-embedding-3-small",
    distance_metric: Annotated[
        str,
        typer.Option(
            "-d",
            "--distance-metric",
            click_type=click.Choice([member.value for member in EmbeddingDistance]),
            help="Distance metric to use.",
        ),
    ] = "cosine",
    **kwargs,
) -> float:
    """Score two strings by the distance between their embeddings.

    Arguments:
        target: The target string.
        reference: The reference string.
        model_name: The name of the embedding model to load.
        distance_metric: The distance metric to use; the CLI accepts the
            values of langchain's `EmbeddingDistance` enum.

    Returns:
        The "score" entry reported by langchain's pairwise embedding
        distance evaluator for the two strings.
    """
    # Only the model itself is needed from the loader's 3-tuple.
    embedding_model, _, _ = load_embedding_model(model_name)
    evaluator = load_evaluator(
        "pairwise_embedding_distance",
        embeddings=embedding_model,
        distance_metric=distance_metric,
    )
    result = evaluator.evaluate_string_pairs(
        prediction=target, prediction_b=reference
    )
    return result["score"]