janus-llm 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (74)
  1. janus/__init__.py +9 -1
  2. janus/__main__.py +4 -0
  3. janus/_tests/test_cli.py +128 -0
  4. janus/_tests/test_translate.py +49 -7
  5. janus/cli.py +530 -46
  6. janus/converter.py +50 -19
  7. janus/embedding/_tests/test_collections.py +2 -8
  8. janus/embedding/_tests/test_database.py +32 -0
  9. janus/embedding/_tests/test_vectorize.py +9 -4
  10. janus/embedding/collections.py +49 -6
  11. janus/embedding/embedding_models_info.py +130 -0
  12. janus/embedding/vectorize.py +53 -62
  13. janus/language/_tests/__init__.py +0 -0
  14. janus/language/_tests/test_combine.py +62 -0
  15. janus/language/_tests/test_splitter.py +16 -0
  16. janus/language/binary/_tests/test_binary.py +16 -1
  17. janus/language/binary/binary.py +10 -3
  18. janus/language/block.py +31 -30
  19. janus/language/combine.py +26 -34
  20. janus/language/mumps/_tests/test_mumps.py +2 -2
  21. janus/language/mumps/mumps.py +93 -9
  22. janus/language/naive/__init__.py +4 -0
  23. janus/language/naive/basic_splitter.py +14 -0
  24. janus/language/naive/chunk_splitter.py +26 -0
  25. janus/language/naive/registry.py +13 -0
  26. janus/language/naive/simple_ast.py +18 -0
  27. janus/language/naive/tag_splitter.py +61 -0
  28. janus/language/splitter.py +168 -74
  29. janus/language/treesitter/_tests/test_treesitter.py +19 -14
  30. janus/language/treesitter/treesitter.py +37 -13
  31. janus/llm/model_callbacks.py +177 -0
  32. janus/llm/models_info.py +165 -72
  33. janus/metrics/__init__.py +8 -0
  34. janus/metrics/_tests/__init__.py +0 -0
  35. janus/metrics/_tests/reference.py +2 -0
  36. janus/metrics/_tests/target.py +2 -0
  37. janus/metrics/_tests/test_bleu.py +56 -0
  38. janus/metrics/_tests/test_chrf.py +67 -0
  39. janus/metrics/_tests/test_file_pairing.py +59 -0
  40. janus/metrics/_tests/test_llm.py +91 -0
  41. janus/metrics/_tests/test_reading.py +28 -0
  42. janus/metrics/_tests/test_rouge_score.py +65 -0
  43. janus/metrics/_tests/test_similarity_score.py +23 -0
  44. janus/metrics/_tests/test_treesitter_metrics.py +110 -0
  45. janus/metrics/bleu.py +66 -0
  46. janus/metrics/chrf.py +55 -0
  47. janus/metrics/cli.py +7 -0
  48. janus/metrics/complexity_metrics.py +208 -0
  49. janus/metrics/file_pairing.py +113 -0
  50. janus/metrics/llm_metrics.py +202 -0
  51. janus/metrics/metric.py +466 -0
  52. janus/metrics/reading.py +70 -0
  53. janus/metrics/rouge_score.py +96 -0
  54. janus/metrics/similarity.py +53 -0
  55. janus/metrics/splitting.py +38 -0
  56. janus/parsers/_tests/__init__.py +0 -0
  57. janus/parsers/_tests/test_code_parser.py +32 -0
  58. janus/parsers/code_parser.py +24 -253
  59. janus/parsers/doc_parser.py +169 -0
  60. janus/parsers/eval_parser.py +80 -0
  61. janus/parsers/reqs_parser.py +72 -0
  62. janus/prompts/prompt.py +103 -30
  63. janus/translate.py +636 -111
  64. janus/utils/_tests/__init__.py +0 -0
  65. janus/utils/_tests/test_logger.py +67 -0
  66. janus/utils/_tests/test_progress.py +20 -0
  67. janus/utils/enums.py +56 -3
  68. janus/utils/progress.py +56 -0
  69. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/METADATA +27 -11
  70. janus_llm-2.0.1.dist-info/RECORD +94 -0
  71. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/WHEEL +1 -1
  72. janus_llm-1.0.0.dist-info/RECORD +0 -48
  73. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/LICENSE +0 -0
  74. {janus_llm-1.0.0.dist-info → janus_llm-2.0.1.dist-info}/entry_points.txt +0 -0
janus/metrics/metric.py (new file)
@@ -0,0 +1,466 @@
+ import inspect
+ import json
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import click
+ import typer
+ from typing_extensions import Annotated
+
+ from janus.llm import load_model
+ from janus.utils.enums import LANGUAGES
+ from janus.utils.logger import create_logger
+
+ from ..utils.progress import track
+ from .cli import evaluate
+ from .file_pairing import FILE_PAIRING_METHODS
+ from .splitting import SPLITTING_METHODS
+
+ log = create_logger(__name__)
+
+
+ def metric(
+     name: None | str = None,
+     help: None | str = None,
+     use_reference: bool = True,
+ ) -> Callable:
+     """Returns a decorator to add a given metric to the cli
+
+     Metrics must follow the format (src_str, cmp_str, **other_params)
+
+     Arguments:
+         name: The name of the metric. If None, the function name is used.
+         help: The help text for the metric.
+         use_reference: Whether the metric requires a reference string.
+
+     Returns:
+         The decorator function.
+     """
+
+     def decorator(function):
+         if use_reference:
+
+             def func(
+                 out_file: Annotated[
+                     str,
+                     typer.Option("--out-file", "-o", help="Output JSON file to write."),
+                 ],
+                 language: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--language",
+                         "-l",
+                         help="The language of the source code.",
+                         click_type=click.Choice(sorted(LANGUAGES)),
+                     ),
+                 ] = None,
+                 target: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--target",
+                         "-t",
+                         help="Target file or string to evaluate.",
+                     ),
+                 ] = None,
+                 reference: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--reference",
+                         "-r",
+                         help="Reference file or string to use as reference/baseline.",
+                     ),
+                 ] = None,
+                 json_file_name: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--json",
+                         "-j",
+                         help="Json file to extract pairs from \
+                             (if set ignores --target and --reference)",
+                     ),
+                 ] = None,
+                 target_key: Annotated[
+                     str,
+                     typer.Option(
+                         "--target-key",
+                         "-tk",
+                         help="json key to extract list of target strings",
+                     ),
+                 ] = "target",
+                 reference_key: Annotated[
+                     str,
+                     typer.Option(
+                         "--reference-key",
+                         "-rk",
+                         help="json key to extract list of reference strings",
+                     ),
+                 ] = "reference",
+                 file_pairing_method: Annotated[
+                     str,
+                     typer.Option(
+                         "--method",
+                         "-m",
+                         click_type=click.Choice(FILE_PAIRING_METHODS.keys()),
+                         help="Method to use for pairing\
+                             segments of target and reference files \
+                             (ignored for json).",
+                     ),
+                 ] = "file",
+                 llm_name: Annotated[
+                     str,
+                     typer.Option(
+                         "--llm",
+                         "-L",
+                         help="The custom name of the model set with 'janus llm add'.",
+                     ),
+                 ] = "gpt-3.5-turbo-0125",
+                 progress: Annotated[
+                     bool,
+                     typer.Option(
+                         "--progress",
+                         "-p",
+                         help="Whether to display a progress bar.",
+                         is_flag=True,
+                     ),
+                 ] = False,
+                 use_strings: Annotated[
+                     bool,
+                     typer.Option(
+                         "--string",
+                         "-S",
+                         help="Indicate that the target and reference are strings",
+                         is_flag=True,
+                     ),
+                 ] = False,
+                 *args,
+                 **kwargs,
+             ):
+                 out = []
+                 llm, token_limit, model_cost = load_model(llm_name)
+                 if json_file_name is not None:
+                     with open(json_file_name, "r") as f:
+                         json_obj = json.load(f)
+                     pairs = {}
+                     for key in json_obj:
+                         doc = json_obj[key]
+                         ref = doc[reference_key]
+                         experiments = doc["experiments"]
+                         for model_key in experiments:
+                             model_dict = experiments[model_key]
+                             if not isinstance(model_dict, dict):
+                                 continue
+                             if target_key not in model_dict:
+                                 continue
+                             if model_key not in pairs:
+                                 pairs[model_key] = {}
+                             for k in model_dict[target_key]:
+                                 pairs[model_key][k] = (model_dict[target_key][k], ref[k])
+                 elif target is not None and reference is not None:
+                     if use_strings:
+                         target_contents = target
+                         reference_contents = reference
+                     else:
+                         with open(target, "r") as f:
+                             target_contents = f.read()
+                         with open(reference, "r") as f:
+                             reference_contents = f.read()
+                     pairs = FILE_PAIRING_METHODS[file_pairing_method](
+                         target_contents,
+                         reference_contents,
+                         target_file=None if use_strings else target,
+                         reference_file=None if use_strings else reference,
+                         out_file=out_file,
+                         lang=language,
+                         llm=llm,
+                         token_limit=token_limit,
+                         model_cost=model_cost,
+                     )
+                 else:
+                     raise ValueError(
+                         "Error, specify json or target and reference files/strings"
+                     )
+                 if isinstance(pairs, dict):
+                     out = {}
+                     for k in pairs:
+                         out[k] = apply_function_pairs(
+                             pairs[k],
+                             function,
+                             progress,
+                             language,
+                             llm,
+                             token_limit,
+                             model_cost,
+                             *args,
+                             **kwargs,
+                         )
+                 else:
+                     out = apply_function_pairs(
+                         pairs,
+                         function,
+                         progress,
+                         language,
+                         llm,
+                         token_limit,
+                         model_cost,
+                         *args,
+                         **kwargs,
+                     )
+                 out_file = Path(out_file)
+                 out_file.parent.mkdir(parents=True, exist_ok=True)
+                 with open(out_file, "w") as f:
+                     json.dump(out, f)
+                 log.info(f"Saved results to file: {out_file}")
+
+             sig1 = inspect.signature(function)
+             sig2 = inspect.signature(func)
+             func.__signature__ = sig2.replace(
+                 parameters=tuple(
+                     list(sig2.parameters.values())[:11]
+                     + list(sig1.parameters.values())[2:-1]
+                 )
+             )
+         else:
+
+             def func(
+                 out_file: Annotated[
+                     str,
+                     typer.Option("--out-file", "-o", help="Output JSON file to write."),
+                 ],
+                 language: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--language",
+                         "-l",
+                         help="The language of the source code.",
+                         click_type=click.Choice(sorted(LANGUAGES)),
+                     ),
+                 ] = None,
+                 target: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--target", "-t", help="Target file or string to evaluate."
+                     ),
+                 ] = None,
+                 json_file_name: Annotated[
+                     Optional[str],
+                     typer.Option(
+                         "--json",
+                         "-j",
+                         help="Json file to extract pairs from \
+                             (if set ignores --target)",
+                     ),
+                 ] = None,
+                 target_key: Annotated[
+                     str,
+                     typer.Option(
+                         "--target-key",
+                         "-tk",
+                         help="json key to extract list of target strings",
+                     ),
+                 ] = "target",
+                 splitting_method: Annotated[
+                     str,
+                     typer.Option(
+                         "--method",
+                         "-m",
+                         click_type=click.Choice(SPLITTING_METHODS.keys()),
+                         help="Method to use for pairing\
+                             segments of target and reference files.",
+                     ),
+                 ] = "file",
+                 llm_name: Annotated[
+                     str,
+                     typer.Option(
+                         "--llm",
+                         "-L",
+                         help="The custom name of the model set with 'janus llm add'.",
+                     ),
+                 ] = "gpt-3.5-turbo-0125",
+                 progress: Annotated[
+                     bool,
+                     typer.Option(
+                         "--progress",
+                         "-p",
+                         help="Whether to display a progress bar.",
+                         is_flag=True,
+                     ),
+                 ] = False,
+                 use_strings: Annotated[
+                     bool,
+                     typer.Option(
+                         "--string",
+                         "-S",
+                         help="Indicate that the target and reference are strings",
+                         is_flag=True,
+                     ),
+                 ] = False,
+                 *args,
+                 **kwargs,
+             ):
+                 llm, token_limit, model_cost = load_model(llm_name)
+                 if json_file_name is not None:
+                     with open(json_file_name, "r") as f:
+                         json_obj = json.load(f)
+                     strings = {}
+                     for key in json_obj:
+                         doc = json_obj[key]
+                         experiments = doc["experiments"]
+                         for model_key in experiments:
+                             model_dict = experiments[model_key]
+                             if not isinstance(model_dict, dict):
+                                 continue
+                             if target_key not in model_dict:
+                                 continue
+                             if model_key not in strings:
+                                 strings[model_key] = {}
+                             for k in model_dict[target_key]:
+                                 strings[model_key][k] = model_dict[target_key][k]
+                     # strings += list(json_obj[key][target_key].values())
+                 elif target is not None:
+                     if use_strings:
+                         target_contents = target
+                     else:
+                         with open(target, "r") as f:
+                             target_contents = f.read()
+
+                     strings = SPLITTING_METHODS[splitting_method](
+                         target_contents,
+                         target_file=target if not use_strings else None,
+                         out_file=out_file,
+                         lang=language,
+                         llm=llm,
+                         token_limit=token_limit,
+                         model_cost=model_cost,
+                     )
+                 else:
+                     raise ValueError(
+                         "Error: must specify either json file or target file/string"
+                     )
+                 if isinstance(strings, dict):
+                     out = {}
+                     for k in strings:
+                         out[k] = apply_function_strings(
+                             strings[k],
+                             function,
+                             progress,
+                             language,
+                             llm,
+                             token_limit,
+                             model_cost,
+                             *args,
+                             **kwargs,
+                         )
+                 else:
+                     out = apply_function_strings(
+                         strings,
+                         function,
+                         progress,
+                         language,
+                         llm,
+                         token_limit,
+                         model_cost,
+                         *args,
+                         **kwargs,
+                     )
+                 out_file = Path(out_file)
+                 out_file.parent.mkdir(parents=True, exist_ok=True)
+                 with open(out_file, "w") as f:
+                     json.dump(out, f)
+                 log.info(f"Saved results to file: {out_file}")
+
+             sig1 = inspect.signature(function)
+             sig2 = inspect.signature(func)
+             func.__signature__ = sig2.replace(
+                 parameters=tuple(
+                     list(sig2.parameters.values())[:9]
+                     + list(sig1.parameters.values())[1:-1]
+                 )
+             )
+         if name is None:
+             func.__name__ = function.__name__
+         else:
+             func.__name__ = name
+         if help is None:
+             func = evaluate.command()(func)
+         else:
+             func = evaluate.command(help=help)(func)
+         return function
+
+     return decorator
+
+
+ def apply_function_pairs(
+     pairs,
+     function,
+     progress,
+     language,
+     llm,
+     token_limit,
+     model_cost,
+     *args,
+     **kwargs,
+ ):
+     out = []
+     pair_keys = None
+     if isinstance(pairs, dict):
+         pair_keys = list(pairs.keys())
+         pair_values = list(pairs.values())
+     else:
+         pair_values = pairs
+     if progress:
+         loop = track(pair_values, description="Evaluating pairs")
+     else:
+         loop = pair_values
+     for src, cmp in loop:
+         if not (isinstance(src, str) and isinstance(cmp, str)):
+             out.append(False)
+         else:
+             out.append(
+                 function(
+                     src,
+                     cmp,
+                     *args,
+                     **kwargs,
+                     language=language,
+                     llm=llm,
+                     token_limit=token_limit,
+                     model_cost=model_cost,
+                 )
+             )
+     if pair_keys is not None:
+         return {k: v for k, v in zip(pair_keys, out)}
+     return out
+
+
+ def apply_function_strings(
+     strings, function, progress, language, llm, token_limit, model_cost, *args, **kwargs
+ ):
+     out = []
+     string_keys = None
+     if isinstance(strings, dict):
+         string_keys = list(strings.keys())
+         string_values = list(strings.values())
+     else:
+         string_values = strings
+     if progress:
+         loop = track(string_values, description="Evaluating strings")
+     else:
+         loop = string_values
+     for string in loop:
+         if not isinstance(string, str):
+             out.append(False)
+         else:
+             out.append(
+                 function(
+                     string,
+                     *args,
+                     **kwargs,
+                     language=language,
+                     llm=llm,
+                     token_limit=token_limit,
+                     model_cost=model_cost,
+                 )
+             )
+     if string_keys is not None:
+         return {k: v for k, v in zip(string_keys, out)}
+     return out
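
Note (not part of the diff): a minimal sketch of how a new reference-based metric would hook into the decorator above. The generated CLI wrapper loads the model and passes language, llm, token_limit, and model_cost as keyword arguments, which a simple metric can absorb via **kwargs; the char_overlap function and its command name below are hypothetical, not from this package.

    from janus.metrics.metric import metric

    @metric(name="char-overlap", help="Fraction of reference characters found in the target.")
    def char_overlap(target: str, reference: str, **kwargs) -> float:
        # Positional order follows the (src_str, cmp_str, **other_params) contract above.
        ref_chars = set(reference)
        return len(ref_chars & set(target)) / len(ref_chars) if ref_chars else 0.0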
janus/metrics/reading.py (new file)
@@ -0,0 +1,70 @@
+ import nltk
+ import readability
+
+ from .metric import metric
+
+
+ def _repeat_text(text):
+     """Repeats a string until its length is over 100 words.
+
+     Arguments:
+         text: The input string.
+
+     Returns:
+         A string repeated to have more than 100 words.
+     """
+     # Strip to remove a newline
+     text = text.strip()
+
+     # Check if the text ends with a period
+     if not text.endswith("."):
+         text += "."  # Add a period if missing
+
+     # Check if repeated text is long enough, repeat more if needed
+     repeated_text = text
+     while len(repeated_text.split()) < 100:
+         repeated_text += " " + text
+
+     return repeated_text
+
+
+ def get_readability(target: str) -> readability.Readability:
+     """Create a Readability object from an input string
+
+     Arguments:
+         target: The target text.
+
+     Returns:
+         py-readability-metrics Readability object for that text
+     """
+     nltk.download("punkt", quiet=True)
+     target = _repeat_text(target)
+     return readability.Readability(target)
+
+
+ @metric(use_reference=False, help="The Flesch Readability score")
+ def flesch(target: str, **kwargs) -> float:
+     """Calculate the Flesch Score using py-readability-metrics.
+
+     Arguments:
+         target: The target text.
+
+     Returns:
+         The Flesch score.
+     """
+
+     return get_readability(target).flesch().score
+
+
+ @metric(use_reference=False, help="The Gunning-Fog Readability score")
+ def gunning_fog(target: str, **kwargs) -> float:
+     """Calculate the Gunning-Fog Score using py-readability-metrics.
+
+     Arguments:
+         target: The target text.
+
+     Returns:
+         The Gunning-Fog score.
+     """
+
+     return get_readability(target).gunning_fog().score
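
Note (not part of the diff): because the decorator returns the original function, these metrics can also be called directly from Python in addition to being registered as CLI commands. A minimal sketch, assuming py-readability-metrics and the NLTK punkt data are available; the sample text is made up.

    from janus.metrics.reading import flesch, gunning_fog

    text = "The quick brown fox jumps over the lazy dog."
    print(flesch(text))       # Flesch reading-ease score
    print(gunning_fog(text))  # Gunning-Fog grade level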
janus/metrics/rouge_score.py (new file)
@@ -0,0 +1,96 @@
+ import click
+ import nltk
+ import typer
+ from rouge import Rouge
+ from typing_extensions import Annotated
+
+ from .metric import metric
+
+
+ @metric(help="ROUGE score")
+ def rouge(
+     target: str,
+     reference: str,
+     granularity: Annotated[
+         str,
+         typer.Option(
+             "--granularity",
+             "-g",
+             help=(
+                 "The granularity of the ROUGE score. `n` refers to "
+                 "ROUGE-N, `l` refers to ROUGE-L, and `w` refers to ROUGE-W."
+             ),
+             click_type=click.Choice(["n", "l", "w"]),
+         ),
+     ] = "n",
+     n_gram: Annotated[
+         int,
+         typer.Option(
+             "--n-gram",
+             "-n",
+             help=("The n-gram overlap calculated for ROUGE-N. Can be an integer."),
+         ),
+     ] = 2,
+     score_type: Annotated[
+         str,
+         typer.Option(
+             "--score",
+             "-s",
+             help=(
+                 "Whether to use the F-score, precision, or recall. For example, `f` "
+                 "refers to the F-score, `p` refers to precision, and `r` refers to "
+                 "recall."
+             ),
+             click_type=click.Choice(["f", "p", "r"]),
+         ),
+     ] = "f",
+     **kwargs,
+ ) -> float:
+     """Calculate the ROUGE Score.
+
+     Arguments:
+         target: The target text.
+         reference: The reference text.
+         granularity: The granularity of the ROUGE score. `n` refers to ROUGE-N, `l`
+             refers to ROUGE-L, and `w` refers to ROUGE-W.
+         n_gram: The n-gram overlap calculated for ROUGE-N. Can be an integer.
+         score_type: Whether to use the F-score, precision, or recall. For example, `f`
+             refers to the F-score, `p` refers to precision, and `r` refers to recall.
+
+     Returns:
+         The ROUGE score.
+     """
+     nltk.download("punkt", quiet=True)
+
+     if granularity.lower() == "n":
+         metric_name = "rouge-n"
+         metric_name_output = f"rouge-{n_gram}"
+         max_n = n_gram
+     elif granularity.lower() == "l":
+         metric_name = "rouge-l"
+         metric_name_output = "rouge-l"
+         max_n = 4
+     elif granularity.lower() == "w":
+         metric_name = "rouge-w"
+         metric_name_output = "rouge-w"
+         max_n = 4
+     else:
+         raise ValueError("Invalid granularity. Must be one of `n`, `l`, or `w`.")
+
+     if score_type.lower() not in ["f", "p", "r"]:
+         raise ValueError("Invalid score type. Must be one of `f`, `p`, or `r`.")
+
+     evaluator = Rouge(
+         metrics=[metric_name],
+         max_n=max_n,
+         limit_length=False,
+         length_limit=1_000,
+         length_limit_type="words",
+         apply_avg=False,
+         apply_best=False,
+         alpha=0.5,  # Default F1_score
+         weight_factor=1.2,
+         stemming=True,
+     )
+     scores = evaluator.get_scores(target, reference)
+     return scores[metric_name_output][0][score_type.lower()][0]
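
Note (not part of the diff): a minimal direct-call sketch of the metric above, assuming the py-rouge backend implied by these constructor arguments is installed; the example strings are made up.

    from janus.metrics.rouge_score import rouge

    candidate = "the cat sat on the mat"
    reference = "a cat was sitting on the mat"
    # ROUGE-2 F-score; granularity, n_gram, and score_type mirror the CLI options.
    print(rouge(candidate, reference, granularity="n", n_gram=2, score_type="f"))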
janus/metrics/similarity.py (new file)
@@ -0,0 +1,53 @@
+ import click
+ import typer
+ from langchain.evaluation import EmbeddingDistance, load_evaluator
+ from typing_extensions import Annotated
+
+ from ..embedding.embedding_models_info import load_embedding_model
+ from .metric import metric
+
+
+ @metric(name="similarity-score", help="Distance between embeddings of strings.")
+ def similarity_score(
+     target: str,
+     reference: str,
+     model_name: Annotated[
+         str,
+         typer.Option("-e", "--embedding-model", help="Name of embedding model to use."),
+     ] = "text-embedding-3-small",
+     distance_metric: Annotated[
+         str,
+         typer.Option(
+             "-d",
+             "--distance-metric",
+             click_type=click.Choice([e.value for e in list(EmbeddingDistance)]),
+             help="Distance metric to use.",
+         ),
+     ] = "cosine",
+     **kwargs,
+ ) -> float:
+     """Computes the similarity score of two strings
+
+     Arguments:
+         target: The target string.
+         reference: The reference string.
+         model_name: The name of the embedding model to use.
+         distance_metric: The distance metric to use. Can be one of:
+             - cosine
+             - euclidean
+             - manhattan
+             - chebyshev
+             - hamming
+
+     Returns:
+         The similarity score of the two strings.
+     """
+     embedding_model, _, _ = load_embedding_model(model_name)
+     evaluator = load_evaluator(
+         "pairwise_embedding_distance",
+         embeddings=embedding_model,
+         distance_metric=distance_metric,
+     )
+     return evaluator.evaluate_string_pairs(prediction=target, prediction_b=reference)[
+         "score"
+     ]
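
Note (not part of the diff): a minimal direct-call sketch of similarity_score, assuming credentials for the chosen embedding provider (OpenAI's text-embedding-3-small by default) are configured; the two snippets compared are made up.

    from janus.metrics.similarity import similarity_score

    score = similarity_score(
        "def add(a, b): return a + b",
        "def add(x, y): return x + y",
        model_name="text-embedding-3-small",
        distance_metric="cosine",
    )
    print(score)  # Lower cosine distance means more similar embeddings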