nltkor-1.2.14-cp311-cp311-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. nltkor/Kor_char.py +193 -0
  2. nltkor/__init__.py +16 -0
  3. nltkor/alignment/__init__.py +1315 -0
  4. nltkor/cider/__init__.py +2 -0
  5. nltkor/cider/cider.py +55 -0
  6. nltkor/cider/cider_scorer.py +207 -0
  7. nltkor/distance/__init__.py +441 -0
  8. nltkor/distance/wasserstein.py +126 -0
  9. nltkor/etc.py +22 -0
  10. nltkor/lazyimport.py +144 -0
  11. nltkor/make_requirement.py +11 -0
  12. nltkor/metrics/__init__.py +63 -0
  13. nltkor/metrics/bartscore.py +301 -0
  14. nltkor/metrics/bertscore.py +331 -0
  15. nltkor/metrics/bleu_tensor.py +20 -0
  16. nltkor/metrics/classical.py +847 -0
  17. nltkor/metrics/entment.py +24 -0
  18. nltkor/metrics/eval.py +517 -0
  19. nltkor/metrics/mauve.py +273 -0
  20. nltkor/metrics/mauve_utils.py +131 -0
  21. nltkor/misc/__init__.py +11 -0
  22. nltkor/misc/string2string_basic_functions.py +59 -0
  23. nltkor/misc/string2string_default_tokenizer.py +83 -0
  24. nltkor/misc/string2string_hash_functions.py +159 -0
  25. nltkor/misc/string2string_word_embeddings.py +503 -0
  26. nltkor/search/__init__.py +10 -0
  27. nltkor/search/classical.py +569 -0
  28. nltkor/search/faiss_search.py +787 -0
  29. nltkor/search/kobert_tokenizer.py +181 -0
  30. nltkor/sejong/__init__.py +3 -0
  31. nltkor/sejong/__pycache__/__init__.cpython-38.pyc +0 -0
  32. nltkor/sejong/__pycache__/__init__.cpython-39.pyc +0 -0
  33. nltkor/sejong/__pycache__/sejong_download.cpython-38.pyc +0 -0
  34. nltkor/sejong/__pycache__/sejong_download.cpython-39.pyc +0 -0
  35. nltkor/sejong/__pycache__/ssem.cpython-38.pyc +0 -0
  36. nltkor/sejong/__pycache__/ssem.cpython-39.pyc +0 -0
  37. nltkor/sejong/ch.py +12 -0
  38. nltkor/sejong/dict_semClassNum.txt +491 -0
  39. nltkor/sejong/layer.txt +630 -0
  40. nltkor/sejong/sejong_download.py +87 -0
  41. nltkor/sejong/ssem.py +684 -0
  42. nltkor/similarity/__init__.py +3 -0
  43. nltkor/similarity/bartscore____.py +337 -0
  44. nltkor/similarity/bertscore____.py +339 -0
  45. nltkor/similarity/classical.py +245 -0
  46. nltkor/similarity/cosine_similarity.py +175 -0
  47. nltkor/tag/__init__.py +71 -0
  48. nltkor/tag/__pycache__/__init__.cpython-38.pyc +0 -0
  49. nltkor/tag/__pycache__/__init__.cpython-39.pyc +0 -0
  50. nltkor/tag/__pycache__/espresso_tag.cpython-38.pyc +0 -0
  51. nltkor/tag/__pycache__/espresso_tag.cpython-39.pyc +0 -0
  52. nltkor/tag/espresso_tag.py +220 -0
  53. nltkor/tag/libs/__init__.py +10 -0
  54. nltkor/tag/libs/__pycache__/__init__.cpython-38.pyc +0 -0
  55. nltkor/tag/libs/__pycache__/__init__.cpython-39.pyc +0 -0
  56. nltkor/tag/libs/__pycache__/attributes.cpython-38.pyc +0 -0
  57. nltkor/tag/libs/__pycache__/attributes.cpython-39.pyc +0 -0
  58. nltkor/tag/libs/__pycache__/config.cpython-38.pyc +0 -0
  59. nltkor/tag/libs/__pycache__/config.cpython-39.pyc +0 -0
  60. nltkor/tag/libs/__pycache__/metadata.cpython-38.pyc +0 -0
  61. nltkor/tag/libs/__pycache__/metadata.cpython-39.pyc +0 -0
  62. nltkor/tag/libs/__pycache__/reader.cpython-38.pyc +0 -0
  63. nltkor/tag/libs/__pycache__/reader.cpython-39.pyc +0 -0
  64. nltkor/tag/libs/__pycache__/taggers.cpython-38.pyc +0 -0
  65. nltkor/tag/libs/__pycache__/taggers.cpython-39.pyc +0 -0
  66. nltkor/tag/libs/__pycache__/utils.cpython-38.pyc +0 -0
  67. nltkor/tag/libs/__pycache__/utils.cpython-39.pyc +0 -0
  68. nltkor/tag/libs/__pycache__/word_dictionary.cpython-38.pyc +0 -0
  69. nltkor/tag/libs/__pycache__/word_dictionary.cpython-39.pyc +0 -0
  70. nltkor/tag/libs/arguments.py +280 -0
  71. nltkor/tag/libs/attributes.py +231 -0
  72. nltkor/tag/libs/config.py +159 -0
  73. nltkor/tag/libs/metadata.py +129 -0
  74. nltkor/tag/libs/ner/__init__.py +2 -0
  75. nltkor/tag/libs/ner/__pycache__/__init__.cpython-38.pyc +0 -0
  76. nltkor/tag/libs/ner/__pycache__/__init__.cpython-39.pyc +0 -0
  77. nltkor/tag/libs/ner/__pycache__/ner_reader.cpython-38.pyc +0 -0
  78. nltkor/tag/libs/ner/__pycache__/ner_reader.cpython-39.pyc +0 -0
  79. nltkor/tag/libs/ner/macmorphoreader.py +7 -0
  80. nltkor/tag/libs/ner/ner_reader.py +92 -0
  81. nltkor/tag/libs/network.c +72325 -0
  82. nltkor/tag/libs/network.cpython-311-darwin.so +0 -0
  83. nltkor/tag/libs/network.pyx +878 -0
  84. nltkor/tag/libs/networkconv.pyx +1028 -0
  85. nltkor/tag/libs/networkdependencyconv.pyx +451 -0
  86. nltkor/tag/libs/parse/__init__.py +1 -0
  87. nltkor/tag/libs/parse/__pycache__/__init__.cpython-38.pyc +0 -0
  88. nltkor/tag/libs/parse/__pycache__/__init__.cpython-39.pyc +0 -0
  89. nltkor/tag/libs/parse/__pycache__/parse_reader.cpython-38.pyc +0 -0
  90. nltkor/tag/libs/parse/__pycache__/parse_reader.cpython-39.pyc +0 -0
  91. nltkor/tag/libs/parse/parse_reader.py +283 -0
  92. nltkor/tag/libs/pos/__init__.py +2 -0
  93. nltkor/tag/libs/pos/__pycache__/__init__.cpython-38.pyc +0 -0
  94. nltkor/tag/libs/pos/__pycache__/__init__.cpython-39.pyc +0 -0
  95. nltkor/tag/libs/pos/__pycache__/pos_reader.cpython-38.pyc +0 -0
  96. nltkor/tag/libs/pos/__pycache__/pos_reader.cpython-39.pyc +0 -0
  97. nltkor/tag/libs/pos/macmorphoreader.py +7 -0
  98. nltkor/tag/libs/pos/pos_reader.py +97 -0
  99. nltkor/tag/libs/reader.py +485 -0
  100. nltkor/tag/libs/srl/__init__.py +3 -0
  101. nltkor/tag/libs/srl/__pycache__/__init__.cpython-38.pyc +0 -0
  102. nltkor/tag/libs/srl/__pycache__/__init__.cpython-39.pyc +0 -0
  103. nltkor/tag/libs/srl/__pycache__/srl_reader.cpython-38.pyc +0 -0
  104. nltkor/tag/libs/srl/__pycache__/srl_reader.cpython-39.pyc +0 -0
  105. nltkor/tag/libs/srl/__pycache__/train_srl.cpython-38.pyc +0 -0
  106. nltkor/tag/libs/srl/__pycache__/train_srl.cpython-39.pyc +0 -0
  107. nltkor/tag/libs/srl/__srl_reader_.py +535 -0
  108. nltkor/tag/libs/srl/srl_reader.py +436 -0
  109. nltkor/tag/libs/srl/train_srl.py +87 -0
  110. nltkor/tag/libs/taggers.py +926 -0
  111. nltkor/tag/libs/utils.py +384 -0
  112. nltkor/tag/libs/word_dictionary.py +239 -0
  113. nltkor/tag/libs/wsd/__init__.py +2 -0
  114. nltkor/tag/libs/wsd/__pycache__/__init__.cpython-38.pyc +0 -0
  115. nltkor/tag/libs/wsd/__pycache__/__init__.cpython-39.pyc +0 -0
  116. nltkor/tag/libs/wsd/__pycache__/wsd_reader.cpython-38.pyc +0 -0
  117. nltkor/tag/libs/wsd/__pycache__/wsd_reader.cpython-39.pyc +0 -0
  118. nltkor/tag/libs/wsd/macmorphoreader.py +7 -0
  119. nltkor/tag/libs/wsd/wsd_reader.py +93 -0
  120. nltkor/tokenize/__init__.py +62 -0
  121. nltkor/tokenize/ko_tokenize.py +115 -0
  122. nltkor/trans.py +121 -0
  123. nltkor-1.2.14.dist-info/LICENSE.txt +1093 -0
  124. nltkor-1.2.14.dist-info/METADATA +41 -0
  125. nltkor-1.2.14.dist-info/RECORD +127 -0
  126. nltkor-1.2.14.dist-info/WHEEL +5 -0
  127. nltkor-1.2.14.dist-info/top_level.txt +1 -0
nltkor/lazyimport.py ADDED
@@ -0,0 +1,144 @@
+ # This module is from mx/DateTime/LazyModule.py and is
+ # distributed under the terms of the eGenix.com Public License Agreement
+ # http://www.egenix.com/products/eGenix.com-Public-License-1.1.0.pdf
+
+ """ Helper to enable simple lazy module import.
+
+ 'Lazy' means the actual import is deferred until an attribute is
+ requested from the module's namespace. This has the advantage of
+ allowing all imports to be done at the top of a script (in a
+ prominent and visible place) without having a great impact
+ on startup time.
+
+ Copyright (c) 1999-2005, Marc-Andre Lemburg; mailto:mal@lemburg.com
+ See the documentation for further information on copyrights,
+ or contact the author. All Rights Reserved.
+ """
+
+ ### Constants
+
+ _debug = 0
+
+ ###
+
+
+ class LazyModule:
+
+     """ Lazy module class.
+
+     Lazy modules are imported into the given namespaces whenever a
+     non-special attribute (there are some attributes like __doc__
+     that class instances handle without calling __getattr__) is
+     requested. The module is then registered under the given name
+     in locals usually replacing the import wrapper instance. The
+     import itself is done using globals as global namespace.
+
+     Example of creating a lazy load module:
+
+     ISO = LazyModule('ISO', locals(), globals())
+
+     Later, requesting an attribute from ISO will load the module
+     automatically into the locals() namespace, overriding the
+     LazyModule instance:
+
+     t = ISO.Week(1998, 1, 1)
+
+     """
+
+     # Flag which indicates whether the LazyModule is initialized or not
+     __lazymodule_init = 0
+
+     # Name of the module to load
+     __lazymodule_name = ""
+
+     # Flag which indicates whether the module was loaded or not
+     __lazymodule_loaded = 0
+
+     # Locals dictionary where to register the module
+     __lazymodule_locals = None
+
+     # Globals dictionary to use for the module import
+     __lazymodule_globals = None
+
+     def __init__(self, name, locals, globals=None):
+
+         """ Create a LazyModule instance wrapping module name.
+
+         The module will later on be registered in locals under the
+         given module name.
+
+         globals is optional and defaults to locals.
+
+         """
+         self.__lazymodule_locals = locals
+         if globals is None:
+             globals = locals
+         self.__lazymodule_globals = globals
+         mainname = globals.get("__name__", "")
+         if mainname:
+             self.__name__ = mainname + "." + name
+             self.__lazymodule_name = name
+         else:
+             self.__name__ = self.__lazymodule_name = name
+         self.__lazymodule_init = 1
+
+     def __lazymodule_import(self):
+
+         """ Import the module now.
+         """
+         # Load and register module
+         name = self.__lazymodule_name
+         if self.__lazymodule_loaded:
+             return self.__lazymodule_locals[name]
+         if _debug:
+             print("LazyModule: Loading module %r" % name)
+         self.__lazymodule_locals[name] = module = __import__(
+             name, self.__lazymodule_locals, self.__lazymodule_globals, "*"
+         )
+
+         # Fill namespace with all symbols from original module to
+         # provide faster access.
+         self.__dict__.update(module.__dict__)
+
+         # Set import flag
+         self.__dict__["__lazymodule_loaded"] = 1
+
+         if _debug:
+             print("LazyModule: Module %r loaded" % name)
+         return module
+
+     def __getattr__(self, name):
+
+         """ Import the module on demand and get the attribute.
+         """
+         if self.__lazymodule_loaded:
+             raise AttributeError(name)
+         if _debug:
+             print(
+                 "LazyModule: "
+                 "Module load triggered by attribute %r read access" % name
+             )
+         module = self.__lazymodule_import()
+         return getattr(module, name)
+
+     def __setattr__(self, name, value):
+
+         """ Import the module on demand and set the attribute.
+         """
+         if not self.__lazymodule_init:
+             self.__dict__[name] = value
+             return
+         if self.__lazymodule_loaded:
+             self.__lazymodule_locals[self.__lazymodule_name] = value
+             self.__dict__[name] = value
+             return
+         if _debug:
+             print(
+                 "LazyModule: "
+                 "Module load triggered by attribute %r write access" % name
+             )
+         module = self.__lazymodule_import()
+         setattr(module, name, value)
+
+     def __repr__(self):
+         return "<LazyModule '%s'>" % self.__name__
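The docstring's ISO example is abstract; a minimal module-level sketch against nltkor itself (assuming the package and its dependencies are importable) would look like this:

    # Sketch only: defer the heavy nltkor.metrics import until first use.
    from nltkor.lazyimport import LazyModule

    metrics = LazyModule("nltkor.metrics", locals(), globals())
    # Nothing has been imported yet; the first attribute read triggers the
    # real import and registers the module in locals() over the wrapper.
    scorer = metrics.DefaultMetric()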
nltkor/make_requirement.py ADDED
@@ -0,0 +1,11 @@
+ import datetime
+ import os
+
+ def make_requirement(packages):
+     file_path = os.path.abspath(__file__)
+     file_path = file_path + "__requirement__NLTKor.txt"
+     with open(file_path, "a") as f:
+         for package in packages:
+             f.write(package + '\n')
+
+     return file_path
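For context, make_requirement appends each package specifier to a text file created next to this module and returns the file's path; a usage sketch (the package pins are illustrative, not taken from nltkor):

    # Sketch only: record two illustrative dependency pins and report the file.
    from nltkor.make_requirement import make_requirement

    path = make_requirement(["torch", "transformers>=4.0.0"])
    print(path)  # .../make_requirement.py__requirement__NLTKor.txt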
nltkor/metrics/__init__.py ADDED
@@ -0,0 +1,63 @@
+ # Natural Language Toolkit: Metrics
+ #
+ # Copyright (C) 2001-2020 NLTK Project
+ # Author: Steven Bird <stevenbird1@gmail.com>
+ #         Edward Loper <edloper@gmail.com>
+ # URL: <http://nltk.org/>
+ # For license information, see LICENSE.TXT
+ #
+
+ """
+ NLTKor Metrics
+
+ Classes and methods for scoring processing modules.
+ """
+ """
+ from nltk.metrics.scores import (
+     accuracy,
+     precision,
+     recall,
+     f_measure,
+     log_likelihood,
+     approxrand,
+ )
+ from nltk.metrics.confusionmatrix import ConfusionMatrix
+ from nltk.metrics.distance import (
+     edit_distance,
+     edit_distance_align,
+     binary_distance,
+     jaccard_distance,
+     masi_distance,
+     interval_distance,
+     custom_distance,
+     presence,
+     fractional_presence,
+ )
+ from nltk.metrics.paice import Paice
+ from nltk.metrics.segmentation import windowdiff, ghd, pk
+ from nltk.metrics.agreement import AnnotationTask
+ from nltk.metrics.association import (
+     NgramAssocMeasures,
+     BigramAssocMeasures,
+     TrigramAssocMeasures,
+     QuadgramAssocMeasures,
+     ContingencyMeasures,
+ )
+ from nltk.metrics.spearman import (
+     spearman_correlation,
+     ranks_from_sequence,
+     ranks_from_scores,
+ )
+ from nltk.metrics.aline import align
+ from nltkor.metrics.eval import StringMetric
+ """
+ from nltkor.metrics.classical import DefaultMetric
+ from nltkor.metrics.entment import EntMent
+ from nltkor.metrics.bleu_tensor import *
+ #DefaultMetric = lazy_import.lazy_callable("nltkor.metrics.classical.DefaultMetric")
+ #Mauve = lazy_import.lazy_callable("nltkor.metrics.mauve.Mauve")
+ from nltkor.metrics.mauve import Mauve
+ #BERTScore = lazy_import.lazy_callable("nltkor.metrics.bertscore.BERTScore")
+ from .bertscore import BERTScore
+ #BARTScore = lazy_import.lazy_callable("nltkor.metrics.bartscore.BARTScore")
+ from .bartscore import BARTScore
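Note that only the final block of imports is live; the triple-quoted NLTK re-exports above it are a string literal, not code. The effective public surface of the subpackage is therefore the five names below (a sketch, assuming the wheel's dependencies are installed):

    # Sketch only: the names actually exported by nltkor.metrics.
    from nltkor.metrics import DefaultMetric, EntMent, Mauve, BERTScore, BARTScore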
nltkor/metrics/bartscore.py ADDED
@@ -0,0 +1,301 @@
+ """
+ string2string similarity
+ src = https://github.com/stanfordnlp/string2string
+
+
+ MIT License
+
+ Copyright (c) 2023 Mirac Suzgun
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+
+ """
+
+
+ """
+ This class contains the original implementation of the BARTScore algorithm by Yuan et al. (2021).
+
+ BARTScore: BART-based Evaluation Metric for Text Generation
+
+ @inproceedings{bartscore2021,
+     author = {Yuan, Weizhe and Neubig, Graham and Liu, Pengfei},
+     booktitle = {Advances in Neural Information Processing Systems},
+     editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
+     pages = {27263--27277},
+     publisher = {Curran Associates, Inc.},
+     title = {BARTScore: Evaluating Generated Text as Text Generation},
+     url = {https://proceedings.neurips.cc/paper/2021/file/e4d2b6e6fdeca3e60e0f1a62fee3d9dd-Paper.pdf},
+     volume = {34},
+     year = {2021}
+ }
+
+ Disclaimer:
+ This code is adapted from https://github.com/neulab/BARTScore/blob/main/bart_score.py
+ """
+
+ import numpy as np
+ from typing import List, Union, Dict
+ import traceback
+ from nltkor.make_requirement import make_requirement
+
+ import torch
+ import torch.nn as nn
+ from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+ # BARTScore class
+ class BARTScore:
+     """
+     This class implements the BARTScore algorithm.
+     """
+
+     def __init__(self,
+                  model_name_or_path='facebook/bart-large-cnn',
+                  tokenizer_name_or_path: str = None,
+                  device: str = 'cpu',
+                  max_length=1024,
+                  ) -> None:
+         r"""
+         This function initializes the BARTScore class, which computes the BARTScore between two pieces of text.
+
+         Arguments:
+             model_name_or_path (str): The name or path of the model. Defaults to 'facebook/bart-large-cnn'.
+             tokenizer_name_or_path (str): The name or path of the tokenizer. Defaults to None.
+             device (str): The device to use. Defaults to 'cpu'.
+             max_length (int): The maximum length of the input. Defaults to 1024.
+
+         Returns:
+             None
+
+         Raises:
+             ValueError: If the device is not 'cpu' or 'cuda'.
+
+         .. attention::
+
+             If you use this class, please make sure to cite the following paper:
+
+             .. code-block:: latex
+
+                 @inproceedings{bartscore2021,
+                     author = {Yuan, Weizhe and Neubig, Graham and Liu, Pengfei},
+                     booktitle = {Advances in Neural Information Processing Systems},
+                     editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
+                     pages = {27263--27277},
+                     publisher = {Curran Associates, Inc.},
+                     title = {BARTScore: Evaluating Generated Text as Text Generation},
+                     url = {https://proceedings.neurips.cc/paper/2021/file/e4d2b6e6fdeca3e60e0f1a62fee3d9dd-Paper.pdf},
+                     volume = {34},
+                     year = {2021}
+                 }
+
+         .. note::
+             * The default model is the BART-large-cnn model.
+             * If the tokenizer name or path is not specified, then the model name or path will be used.
+             * If the device is 'cuda', then the model will be loaded onto the GPU.
+             * If device is not specified, use the GPU if available, otherwise use the CPU.
+
+         """
+
+         if tokenizer_name_or_path is None:
+             tokenizer_name_or_path = model_name_or_path
+
+         # Set the attributes
+         self.device = device
+         self.max_length = max_length
+
+         # Load model and tokenizer
+         self.tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path)
+         self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path)
+         self.model.eval()
+         self.model.to(device)
+
+         # Set up loss
+         self.loss_fct = nn.NLLLoss(reduction='none', ignore_index=self.model.config.pad_token_id)
+         self.lsm = nn.LogSoftmax(dim=1)
+
+
+
+     # Loads the model weights from a specified path
+     def load(self,
+              weights_path=None,
+              ) -> None:
+         """
+         This function loads the model weights from a specified path.
+
+         Arguments:
+             weights_path (str): The path to the weights.
+
+         Returns:
+             None
+         """
+         if weights_path is None:
+             weights_path = 'models/bart.pth'
+
+         self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
+
+
+
+     # Compute the BARTScore between source sentences and target sentences
+     def compute(self,
+                 source_sentences: List[str],
+                 target_sentences: Union[List[str], List[List[str]]],
+                 batch_size: int = 4,
+                 agg: str = 'mean',
+                 ) -> Dict[str, List[float]]:
+         """
+         This function scores the target sentences against the source sentences using BARTScore.
+
+         Arguments:
+             source_sentences (List[str]): The source sentences.
+             target_sentences (Union[List[str], List[List[str]]]): The target sentences.
+             batch_size (int): The batch size to use (default: 4).
+             agg (str): The aggregation method. Defaults to 'mean'; used only when target_sentences is a list of lists.
+
+         Returns:
+             Dict[str, List[float]]: The BARTScore for each example.
+
+         Raises:
+             ValueError: If the number of source sentences and target sentences do not match.
+         """
+         # Check the number of source sentences and target sentences
+         if len(source_sentences) != len(target_sentences):
+             raise ValueError(f'Number of source sentences ({len(source_sentences)}) and number of target sentences ({len(target_sentences)}) do not match.')
+
+         # If the target sentences are a list of lists, then call the multi_ref_score function
+         if isinstance(target_sentences[0], list):
+             return self.compute_multi_ref_score(
+                 source_sentences=source_sentences,
+                 target_sentences=target_sentences,
+                 batch_size=batch_size,
+                 agg=agg
+             )
+
+         # Score for each example
+         score_list = []
+
+         for i in range(0, len(source_sentences), batch_size):
+             # Get the current batch
+             src_batch = source_sentences[i: i + batch_size]
+             tgt_batch = target_sentences[i: i + batch_size]
+             try:
+                 with torch.no_grad():
+                     # Encode the batch
+                     encoded_src = self.tokenizer(
+                         src_batch,
+                         max_length=self.max_length,
+                         truncation=True,
+                         padding=True,
+                         return_tensors='pt'
+                     )
+                     encoded_tgt = self.tokenizer(
+                         tgt_batch,
+                         max_length=self.max_length,
+                         truncation=True,
+                         padding=True,
+                         return_tensors='pt'
+                     )
+
+                     # Get the input ids and attention masks for the source and target sentences
+                     src_tokens = encoded_src['input_ids'].to(self.device)
+                     src_mask = encoded_src['attention_mask'].to(self.device)
+                     tgt_tokens = encoded_tgt['input_ids'].to(self.device)
+                     tgt_mask = encoded_tgt['attention_mask']
+                     tgt_len = tgt_mask.sum(dim=1).to(self.device)
+
+                     # Feed the batch to the model and get the loss
+                     output = self.model(
+                         input_ids=src_tokens,
+                         attention_mask=src_mask,
+                         labels=tgt_tokens
+                     )
+                     logits = output.logits.view(-1, self.model.config.vocab_size)
+                     # Compute the loss
+                     loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1))
+                     loss = loss.view(tgt_tokens.shape[0], -1)
+                     loss = loss.sum(dim=1) / tgt_len
+                     # Get the score
+                     curr_score_list = [-x.item() for x in loss]
+                     # Append the score to the list
+                     score_list += curr_score_list
+
+             except Exception:
+                 # If there is an error, re-raise it with the full traceback
+                 raise Exception(f'Error in scoring batch {i // batch_size}:\n{traceback.format_exc()}')
+         return {'score': np.array(score_list)}
+
+
+
+     # Score a batch of examples with multiple references
+     def compute_multi_ref_score(self,
+                                 source_sentences: List[str],
+                                 target_sentences: List[List[str]],
+                                 batch_size: int = 4,
+                                 agg: str = "mean",
+                                 ) -> Dict[str, List[float]]:
+         """
+         Score a batch of examples with multiple references.
+
+         Arguments:
+             source_sentences (List[str]): The source sentences.
+             target_sentences (List[List[str]]): The target sentences.
+             agg (str): The aggregation method. Can be "mean" or "max".
+             batch_size (int): The batch size.
+
+         Returns:
+             Dict[str, List[float]]: The BARTScore for each example.
+
+         Raises:
+             ValueError: If the number of source sentences and target sentences do not match.
+         """
+
+         # Assert we have the same number of references per test sample
+         ref_nums = [len(x) for x in target_sentences]
+         if len(set(ref_nums)) > 1:
+             raise Exception("You have different number of references per test sample.")
+
+         ref_num = len(target_sentences[0])
+         score_matrix = []
+         for i in range(ref_num):
+             curr_target_sentences = [x[i] for x in target_sentences]
+             scores = self.compute(source_sentences, curr_target_sentences, batch_size)
+             score_matrix.append(scores['score'])
+         if agg == "mean":
+             score_list = np.mean(score_matrix, axis=0)
+         elif agg == "max":
+             score_list = np.max(score_matrix, axis=0)
+         else:
+             raise NotImplementedError(f"Aggregation method {agg} not implemented yet.")
+         return {"score": score_list}
+
+ def demo():
+     demo_sentences = [
+         ("I am a student", "He is a teacher"),
+         ("나는 학생이다", "그는 선생님이다"),
+         ("점심에 온기동에서 삼겹차슈덮밥을 먹었다.", "저녁에 피나치공에서 피자와 치킨을 먹었다."),
+         ('제가 나와 있는 곳은 경남 거제시 옥포동 덕포 해수욕장에 나와 있습니다.', '강한 바람에 간판이나 지붕이 떨어지는 등 피해가 잇따르기도 했습니다.'),
+         ('Outraged mortuary workers in Kenya have criticised the country’s police chief after he accused them of leasing corpses to opposition politicians.',
+          'Head of police Japheth Koome earlier this week claimed that opposition politicians hired bodies from mortuaries and planted them at the scenes of protests so as to blame the police for brutality.')
+     ]
+     for str1, str2 in demo_sentences:
+         print("demo : ", BARTScore().compute([str1], [str2]))
+
+ if __name__ == "__main__":
+     demo()
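As the compute method shows, each score is the negated target-token NLL averaged over target length, i.e. the mean log-likelihood of the reference given the source, so values are at most 0 and higher is better. A usage sketch complementing demo() above, covering both the single-reference and multi-reference paths (the model is downloaded on first construction; the sentences are illustrative):

    # Sketch only: single-reference and multi-reference BARTScore.
    scorer = BARTScore(device='cpu')
    single = scorer.compute(["I am a student"], ["He is a teacher"])
    multi = scorer.compute(["I am a student"],
                           [["He is a teacher", "She teaches math"]],
                           agg='max')  # 'max' keeps the best reference per example
    print(single['score'], multi['score'])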