arekit 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arekit/common/context/terms_mapper.py +2 -2
- arekit/common/data/const.py +5 -4
- arekit/common/{experiment/api/ops_doc.py → data/doc_provider.py} +1 -1
- arekit/common/data/input/providers/columns/sample.py +6 -1
- arekit/common/data/input/providers/instances/base.py +1 -1
- arekit/common/data/input/providers/rows/base.py +36 -13
- arekit/common/data/input/providers/rows/samples.py +57 -55
- arekit/common/data/input/providers/sample/cropped.py +2 -2
- arekit/common/data/input/sample.py +1 -1
- arekit/common/data/rows_fmt.py +82 -0
- arekit/common/data/rows_parser.py +43 -0
- arekit/common/data/storages/base.py +23 -18
- arekit/common/data/views/samples.py +2 -8
- arekit/common/{news → docs}/base.py +2 -2
- arekit/common/{news → docs}/entities_grouping.py +2 -1
- arekit/common/{news → docs}/entity.py +2 -1
- arekit/common/{news → docs}/parsed/base.py +5 -5
- arekit/common/docs/parsed/providers/base.py +68 -0
- arekit/common/{news → docs}/parsed/providers/base_pairs.py +2 -2
- arekit/common/{news → docs}/parsed/providers/entity_service.py +27 -22
- arekit/common/{news → docs}/parsed/providers/opinion_pairs.py +2 -2
- arekit/common/{news → docs}/parsed/providers/text_opinion_pairs.py +6 -6
- arekit/common/docs/parsed/service.py +31 -0
- arekit/common/docs/parser.py +66 -0
- arekit/common/{news → docs}/sentence.py +1 -1
- arekit/common/entities/base.py +11 -2
- arekit/common/experiment/api/base_samples_io.py +1 -1
- arekit/common/frames/variants/collection.py +2 -2
- arekit/common/linkage/base.py +2 -2
- arekit/common/linkage/meta.py +23 -0
- arekit/common/linkage/opinions.py +1 -1
- arekit/common/linkage/text_opinions.py +2 -2
- arekit/common/opinions/annot/algo/base.py +1 -1
- arekit/common/opinions/annot/algo/pair_based.py +15 -13
- arekit/common/opinions/annot/algo/predefined.py +4 -4
- arekit/common/opinions/annot/algo_based.py +5 -5
- arekit/common/opinions/annot/base.py +3 -3
- arekit/common/opinions/base.py +7 -7
- arekit/common/opinions/collection.py +3 -3
- arekit/common/pipeline/base.py +12 -16
- arekit/common/pipeline/batching.py +28 -0
- arekit/common/pipeline/context.py +5 -1
- arekit/common/pipeline/items/base.py +38 -1
- arekit/common/pipeline/items/flatten.py +5 -1
- arekit/common/pipeline/items/handle.py +2 -1
- arekit/common/pipeline/items/iter.py +2 -1
- arekit/common/pipeline/items/map.py +2 -1
- arekit/common/pipeline/items/map_nested.py +4 -0
- arekit/common/pipeline/utils.py +32 -0
- arekit/common/service/sqlite.py +36 -0
- arekit/common/synonyms/base.py +2 -2
- arekit/common/text/{partitioning/str.py → partitioning.py} +16 -11
- arekit/common/text_opinions/base.py +11 -11
- arekit/common/utils.py +33 -46
- arekit/contrib/networks/embedding.py +3 -3
- arekit/contrib/networks/embedding_io.py +5 -5
- arekit/contrib/networks/input/const.py +0 -2
- arekit/contrib/networks/input/providers/sample.py +15 -29
- arekit/contrib/networks/input/rows_parser.py +47 -134
- arekit/contrib/prompt/sample.py +18 -16
- arekit/contrib/utils/data/contents/opinions.py +17 -5
- arekit/contrib/utils/data/doc_provider/dict_based.py +13 -0
- arekit/contrib/utils/data/{doc_ops → doc_provider}/dir_based.py +7 -7
- arekit/contrib/utils/data/readers/base.py +3 -0
- arekit/contrib/utils/data/readers/csv_pd.py +10 -4
- arekit/contrib/utils/data/readers/jsonl.py +3 -0
- arekit/contrib/utils/data/readers/sqlite.py +14 -0
- arekit/contrib/utils/data/service/balance.py +0 -1
- arekit/contrib/utils/data/storages/pandas_based.py +3 -5
- arekit/contrib/utils/data/storages/row_cache.py +18 -6
- arekit/contrib/utils/data/storages/sqlite_based.py +17 -0
- arekit/contrib/utils/data/writers/base.py +5 -0
- arekit/contrib/utils/data/writers/csv_native.py +3 -0
- arekit/contrib/utils/data/writers/csv_pd.py +3 -0
- arekit/contrib/utils/data/writers/json_opennre.py +31 -13
- arekit/contrib/utils/data/writers/sqlite_native.py +114 -0
- arekit/contrib/utils/io_utils/embedding.py +25 -33
- arekit/contrib/utils/io_utils/utils.py +3 -24
- arekit/contrib/utils/pipelines/items/sampling/base.py +31 -26
- arekit/contrib/utils/pipelines/items/sampling/networks.py +7 -10
- arekit/contrib/utils/pipelines/items/text/entities_default.py +2 -2
- arekit/contrib/utils/pipelines/items/text/frames.py +2 -3
- arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py +3 -3
- arekit/contrib/utils/pipelines/items/text/frames_negation.py +2 -1
- arekit/contrib/utils/pipelines/items/text/tokenizer.py +3 -5
- arekit/contrib/utils/pipelines/items/text/translator.py +136 -0
- arekit/contrib/utils/pipelines/opinion_collections.py +5 -5
- arekit/contrib/utils/pipelines/text_opinion/annot/algo_based.py +7 -7
- arekit/contrib/utils/pipelines/text_opinion/extraction.py +34 -22
- arekit/contrib/utils/pipelines/text_opinion/filters/base.py +1 -1
- arekit/contrib/utils/pipelines/text_opinion/filters/distance_based.py +1 -1
- arekit/contrib/utils/pipelines/text_opinion/filters/entity_based.py +3 -3
- arekit/contrib/utils/pipelines/text_opinion/filters/limitation.py +4 -4
- arekit/contrib/utils/serializer.py +4 -23
- arekit-0.25.0.data/data/logo.png +0 -0
- arekit-0.25.0.dist-info/METADATA +82 -0
- arekit-0.25.0.dist-info/RECORD +259 -0
- {arekit-0.23.1.dist-info → arekit-0.25.0.dist-info}/WHEEL +1 -1
- arekit/common/data/row_ids/base.py +0 -79
- arekit/common/data/row_ids/binary.py +0 -38
- arekit/common/data/row_ids/multiple.py +0 -14
- arekit/common/folding/base.py +0 -36
- arekit/common/folding/fixed.py +0 -42
- arekit/common/folding/nofold.py +0 -15
- arekit/common/folding/united.py +0 -46
- arekit/common/news/objects_parser.py +0 -37
- arekit/common/news/parsed/providers/base.py +0 -48
- arekit/common/news/parsed/service.py +0 -31
- arekit/common/news/parser.py +0 -34
- arekit/common/text/parser.py +0 -12
- arekit/common/text/partitioning/__init__.py +0 -0
- arekit/common/text/partitioning/base.py +0 -4
- arekit/common/text/partitioning/terms.py +0 -35
- arekit/contrib/source/__init__.py +0 -0
- arekit/contrib/source/brat/__init__.py +0 -0
- arekit/contrib/source/brat/annot.py +0 -83
- arekit/contrib/source/brat/entities/__init__.py +0 -0
- arekit/contrib/source/brat/entities/compound.py +0 -33
- arekit/contrib/source/brat/entities/entity.py +0 -42
- arekit/contrib/source/brat/entities/parser.py +0 -53
- arekit/contrib/source/brat/news.py +0 -28
- arekit/contrib/source/brat/opinions/__init__.py +0 -0
- arekit/contrib/source/brat/opinions/converter.py +0 -19
- arekit/contrib/source/brat/relation.py +0 -32
- arekit/contrib/source/brat/sentence.py +0 -69
- arekit/contrib/source/brat/sentences_reader.py +0 -128
- arekit/contrib/source/download.py +0 -41
- arekit/contrib/source/nerel/__init__.py +0 -0
- arekit/contrib/source/nerel/entities.py +0 -55
- arekit/contrib/source/nerel/folding/__init__.py +0 -0
- arekit/contrib/source/nerel/folding/fixed.py +0 -75
- arekit/contrib/source/nerel/io_utils.py +0 -62
- arekit/contrib/source/nerel/labels.py +0 -241
- arekit/contrib/source/nerel/reader.py +0 -46
- arekit/contrib/source/nerel/utils.py +0 -24
- arekit/contrib/source/nerel/versions.py +0 -12
- arekit/contrib/source/nerelbio/__init__.py +0 -0
- arekit/contrib/source/nerelbio/io_utils.py +0 -62
- arekit/contrib/source/nerelbio/labels.py +0 -265
- arekit/contrib/source/nerelbio/reader.py +0 -8
- arekit/contrib/source/nerelbio/versions.py +0 -8
- arekit/contrib/source/ruattitudes/__init__.py +0 -0
- arekit/contrib/source/ruattitudes/collection.py +0 -36
- arekit/contrib/source/ruattitudes/entity/__init__.py +0 -0
- arekit/contrib/source/ruattitudes/entity/parser.py +0 -7
- arekit/contrib/source/ruattitudes/io_utils.py +0 -56
- arekit/contrib/source/ruattitudes/labels_fmt.py +0 -12
- arekit/contrib/source/ruattitudes/news.py +0 -51
- arekit/contrib/source/ruattitudes/news_brat.py +0 -44
- arekit/contrib/source/ruattitudes/opinions/__init__.py +0 -0
- arekit/contrib/source/ruattitudes/opinions/base.py +0 -28
- arekit/contrib/source/ruattitudes/opinions/converter.py +0 -37
- arekit/contrib/source/ruattitudes/reader.py +0 -268
- arekit/contrib/source/ruattitudes/sentence.py +0 -73
- arekit/contrib/source/ruattitudes/synonyms.py +0 -17
- arekit/contrib/source/ruattitudes/text_object.py +0 -57
- arekit/contrib/source/rusentiframes/__init__.py +0 -0
- arekit/contrib/source/rusentiframes/collection.py +0 -157
- arekit/contrib/source/rusentiframes/effect.py +0 -24
- arekit/contrib/source/rusentiframes/io_utils.py +0 -19
- arekit/contrib/source/rusentiframes/labels_fmt.py +0 -22
- arekit/contrib/source/rusentiframes/polarity.py +0 -35
- arekit/contrib/source/rusentiframes/role.py +0 -15
- arekit/contrib/source/rusentiframes/state.py +0 -24
- arekit/contrib/source/rusentiframes/types.py +0 -42
- arekit/contrib/source/rusentiframes/value.py +0 -2
- arekit/contrib/source/rusentrel/__init__.py +0 -0
- arekit/contrib/source/rusentrel/const.py +0 -3
- arekit/contrib/source/rusentrel/entities.py +0 -26
- arekit/contrib/source/rusentrel/io_utils.py +0 -125
- arekit/contrib/source/rusentrel/labels_fmt.py +0 -12
- arekit/contrib/source/rusentrel/news_reader.py +0 -51
- arekit/contrib/source/rusentrel/opinions/__init__.py +0 -0
- arekit/contrib/source/rusentrel/opinions/collection.py +0 -30
- arekit/contrib/source/rusentrel/opinions/converter.py +0 -40
- arekit/contrib/source/rusentrel/opinions/provider.py +0 -54
- arekit/contrib/source/rusentrel/opinions/writer.py +0 -42
- arekit/contrib/source/rusentrel/synonyms.py +0 -17
- arekit/contrib/source/sentinerel/__init__.py +0 -0
- arekit/contrib/source/sentinerel/entities.py +0 -52
- arekit/contrib/source/sentinerel/folding/__init__.py +0 -0
- arekit/contrib/source/sentinerel/folding/factory.py +0 -32
- arekit/contrib/source/sentinerel/folding/fixed.py +0 -73
- arekit/contrib/source/sentinerel/io_utils.py +0 -87
- arekit/contrib/source/sentinerel/labels.py +0 -53
- arekit/contrib/source/sentinerel/labels_scaler.py +0 -30
- arekit/contrib/source/sentinerel/reader.py +0 -42
- arekit/contrib/source/synonyms/__init__.py +0 -0
- arekit/contrib/source/synonyms/utils.py +0 -19
- arekit/contrib/source/zip_utils.py +0 -47
- arekit/contrib/utils/bert/rows.py +0 -0
- arekit/contrib/utils/bert/text_b_rus.py +0 -18
- arekit/contrib/utils/connotations/__init__.py +0 -0
- arekit/contrib/utils/connotations/rusentiframes_sentiment.py +0 -23
- arekit/contrib/utils/cv/__init__.py +0 -0
- arekit/contrib/utils/cv/doc_stat/__init__.py +0 -0
- arekit/contrib/utils/cv/doc_stat/base.py +0 -37
- arekit/contrib/utils/cv/doc_stat/sentence.py +0 -12
- arekit/contrib/utils/cv/splitters/__init__.py +0 -0
- arekit/contrib/utils/cv/splitters/base.py +0 -4
- arekit/contrib/utils/cv/splitters/default.py +0 -53
- arekit/contrib/utils/cv/splitters/statistical.py +0 -57
- arekit/contrib/utils/cv/two_class.py +0 -77
- arekit/contrib/utils/data/doc_ops/__init__.py +0 -0
- arekit/contrib/utils/data/doc_ops/dict_based.py +0 -13
- arekit/contrib/utils/data/ext.py +0 -31
- arekit/contrib/utils/data/views/__init__.py +0 -0
- arekit/contrib/utils/data/views/linkages/__init__.py +0 -0
- arekit/contrib/utils/data/views/linkages/base.py +0 -58
- arekit/contrib/utils/data/views/linkages/multilabel.py +0 -48
- arekit/contrib/utils/data/views/linkages/utils.py +0 -24
- arekit/contrib/utils/data/views/opinions.py +0 -14
- arekit/contrib/utils/download.py +0 -78
- arekit/contrib/utils/entities/formatters/str_rus_cased_fmt.py +0 -78
- arekit/contrib/utils/entities/formatters/str_rus_nocased_fmt.py +0 -15
- arekit/contrib/utils/entities/formatters/str_simple_fmt.py +0 -24
- arekit/contrib/utils/entities/formatters/str_simple_uppercase_fmt.py +0 -21
- arekit/contrib/utils/io_utils/opinions.py +0 -39
- arekit/contrib/utils/io_utils/samples.py +0 -78
- arekit/contrib/utils/lexicons/__init__.py +0 -0
- arekit/contrib/utils/lexicons/lexicon.py +0 -43
- arekit/contrib/utils/lexicons/relation.py +0 -45
- arekit/contrib/utils/lexicons/rusentilex.py +0 -34
- arekit/contrib/utils/nn/__init__.py +0 -0
- arekit/contrib/utils/nn/rows.py +0 -83
- arekit/contrib/utils/pipelines/items/sampling/bert.py +0 -5
- arekit/contrib/utils/pipelines/items/text/terms_splitter.py +0 -10
- arekit/contrib/utils/pipelines/items/to_output.py +0 -101
- arekit/contrib/utils/pipelines/sources/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/nerel/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/nerel/doc_ops.py +0 -27
- arekit/contrib/utils/pipelines/sources/nerel/extract_text_relations.py +0 -59
- arekit/contrib/utils/pipelines/sources/nerel/labels_fmt.py +0 -60
- arekit/contrib/utils/pipelines/sources/nerel_bio/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/nerel_bio/doc_ops.py +0 -29
- arekit/contrib/utils/pipelines/sources/nerel_bio/extrat_text_relations.py +0 -59
- arekit/contrib/utils/pipelines/sources/nerel_bio/labels_fmt.py +0 -79
- arekit/contrib/utils/pipelines/sources/ruattitudes/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/ruattitudes/doc_ops.py +0 -56
- arekit/contrib/utils/pipelines/sources/ruattitudes/entity_filter.py +0 -19
- arekit/contrib/utils/pipelines/sources/ruattitudes/extract_text_opinions.py +0 -58
- arekit/contrib/utils/pipelines/sources/rusentrel/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/rusentrel/doc_ops.py +0 -21
- arekit/contrib/utils/pipelines/sources/rusentrel/extract_text_opinions.py +0 -100
- arekit/contrib/utils/pipelines/sources/sentinerel/__init__.py +0 -0
- arekit/contrib/utils/pipelines/sources/sentinerel/doc_ops.py +0 -29
- arekit/contrib/utils/pipelines/sources/sentinerel/entity_filter.py +0 -62
- arekit/contrib/utils/pipelines/sources/sentinerel/extract_text_opinions.py +0 -175
- arekit/contrib/utils/pipelines/sources/sentinerel/labels_fmt.py +0 -50
- arekit/contrib/utils/pipelines/text_opinion/annot/predefined.py +0 -88
- arekit/contrib/utils/resources.py +0 -26
- arekit/contrib/utils/sources/__init__.py +0 -0
- arekit/contrib/utils/sources/sentinerel/__init__.py +0 -0
- arekit/contrib/utils/sources/sentinerel/text_opinion/__init__.py +0 -0
- arekit/contrib/utils/sources/sentinerel/text_opinion/prof_per_org_filter.py +0 -63
- arekit/contrib/utils/utils_folding.py +0 -19
- arekit/download_data.py +0 -11
- arekit-0.23.1.dist-info/METADATA +0 -23
- arekit-0.23.1.dist-info/RECORD +0 -403
- /arekit/common/{data/row_ids → docs}/__init__.py +0 -0
- /arekit/common/{folding → docs/parsed}/__init__.py +0 -0
- /arekit/common/{news → docs/parsed/providers}/__init__.py +0 -0
- /arekit/common/{news → docs}/parsed/term_position.py +0 -0
- /arekit/common/{news/parsed → service}/__init__.py +0 -0
- /arekit/{common/news/parsed/providers → contrib/utils/data/doc_provider}/__init__.py +0 -0
- {arekit-0.23.1.dist-info → arekit-0.25.0.dist-info}/LICENSE +0 -0
- {arekit-0.23.1.dist-info → arekit-0.25.0.dist-info}/top_level.txt +0 -0
arekit/contrib/utils/data/writers/sqlite_native.py
@@ -0,0 +1,114 @@
+import os
+import sqlite3
+from os.path import dirname
+
+from arekit.common.data import const
+from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
+from arekit.contrib.utils.data.writers.base import BaseWriter
+
+
+class SQliteWriter(BaseWriter):
+    """ TODO. This implementation is dedicated to the writing concepts of the data
+        serialization pipeline. Once we add the SQLite3 service, it would be
+        right to refactor and utilize some core functionality from core/service/sqlite.py
+    """
+
+    def __init__(self, table_name="contents", index_column_names=None, skip_existed=False, clear_table=True):
+        """ index_column_names: list or None
+                column names that should be considered to build a unique index;
+                if None, the default 'const.ID' will be considered for row indexation.
+        """
+        assert (isinstance(index_column_names, list) or index_column_names is None)
+        self.__index_column_names = index_column_names if index_column_names is not None else [const.ID]
+        self.__table_name = table_name
+        self.__conn = None
+        self.__cur = None
+        self.__need_init_table = True
+        self.__origin_column_names = None
+        self.__skip_existed = skip_existed
+        self.__clear_table = clear_table
+
+    def extension(self):
+        return ".sqlite"
+
+    @staticmethod
+    def __iter_storage_column_names(storage):
+        """ Iterate only those columns that exist in the storage.
+        """
+        assert (isinstance(storage, RowCacheStorage))
+        for col_name, col_type in zip(storage.iter_column_names(), storage.iter_column_types()):
+            if col_name in storage.RowCache:
+                yield col_name, col_type
+
+    def __init_table(self, column_data):
+        # Compose column name with the related SQLITE type.
+        column_types = ",".join([" ".join([col_name, self.type_to_sqlite(col_type)])
+                                 for col_name, col_type in column_data])
+        # Create table if not exists.
+        self.__cur.execute(f"CREATE TABLE IF NOT EXISTS {self.__table_name}({column_types})")
+        # Table exists, however we may optionally remove the content from it.
+        if self.__clear_table:
+            self.__cur.execute(f"DELETE FROM {self.__table_name};")
+        # Create index.
+        index_name = f"i_{self.__table_name}_id"
+        self.__cur.execute(f"DROP INDEX IF EXISTS {index_name};")
+        self.__cur.execute("CREATE INDEX IF NOT EXISTS {index} ON {table}({columns})".format(
+            index=index_name,
+            table=self.__table_name,
+            columns=", ".join(self.__index_column_names)
+        ))
+        self.__origin_column_names = [col_name for col_name, _ in column_data]
+
+    @staticmethod
+    def type_to_sqlite(col_type):
+        """ This is a simple function that provides conversion from the
+            base numpy types to SQLITE.
+            NOTE: this method represents a quick implementation for supporting
+            types; however, it is far from a generalized implementation.
+        """
+        if isinstance(col_type, str):
+            if 'int' in col_type:
+                return 'INTEGER'
+
+        return "TEXT"
+
+    def open_target(self, target):
+        os.makedirs(dirname(target), exist_ok=True)
+        self.__conn = sqlite3.connect(target)
+        self.__cur = self.__conn.cursor()
+
+    def commit_line(self, storage):
+        assert (isinstance(storage, RowCacheStorage))
+
+        column_data = list(self.__iter_storage_column_names(storage))
+
+        if self.__need_init_table:
+            self.__init_table(column_data)
+            self.__need_init_table = False
+
+        # Check whether the related row already exists in the SQLITE database.
+        row_id = storage.RowCache[const.ID]
+        top_row = self.__cur.execute(f"SELECT EXISTS(SELECT 1 FROM {self.__table_name} WHERE id='{row_id}');")
+        is_exists = top_row.fetchone()[0]
+        if is_exists == 1 and self.__skip_existed:
+            return
+
+        line_data = [storage.RowCache[col_name] for col_name, _ in column_data]
+        parameters = ",".join(["?"] * len(line_data))
+
+        assert (len(self.__origin_column_names) == len(line_data))
+
+        self.__cur.execute(
+            f"INSERT OR REPLACE INTO {self.__table_name} VALUES ({parameters})",
+            tuple(line_data))
+
+        self.__conn.commit()
+
+    def close_target(self):
+        self.__cur = None
+        self.__origin_column_names = None
+        self.__need_init_table = True
+        self.__conn.close()
+
+    def write_all(self, storage, target):
+        pass
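The new `SQliteWriter` follows the writer protocol shown above: `open_target`, then `commit_line` once per cached row, then `close_target`. A minimal sketch of driving it by hand and inspecting the result; the output path is illustrative, and in practice the serialization pipeline is what fills a `RowCacheStorage` and calls `commit_line`:

```python
import sqlite3

from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter

writer = SQliteWriter(table_name="contents", skip_existed=True, clear_table=False)
writer.open_target("out/samples.sqlite")
# ... the pipeline calls writer.commit_line(storage) once per cached row ...
writer.close_target()

# Inspect what was written, using the stdlib directly
# (assumes at least one commit_line created the table).
with sqlite3.connect("out/samples.sqlite") as conn:
    print(conn.execute("SELECT COUNT(*) FROM contents").fetchone())
```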
arekit/contrib/utils/io_utils/embedding.py
@@ -1,11 +1,9 @@
 from os.path import join

-from arekit.common.folding.base import BaseDataFolding
 from arekit.contrib.networks.embedding_io import BaseEmbeddingIO
 from arekit.contrib.utils.io_utils.utils import check_targets_existence
 from arekit.contrib.utils.np_utils.embedding import NpzEmbeddingHelper
 from arekit.contrib.utils.np_utils.vocab import VocabRepositoryUtils
-from arekit.contrib.utils.utils_folding import experiment_iter_index


 class NpEmbeddingIO(BaseEmbeddingIO):
@@ -17,37 +15,35 @@ class NpEmbeddingIO(BaseEmbeddingIO):
         - embedding vocabulary.
     """

-
-    VOCABULARY_FILENAME_TEMPLATE = "vocab-{cv_index}.txt"
-
-    def __init__(self, target_dir):
+    def __init__(self, target_dir, prefix_name="sample"):
        assert(isinstance(target_dir, str))
+
        self.__target_dir = target_dir
+        self.__term_emb_fn_template = "-".join([prefix_name, "term_embedding"])
+        self.__vocab_fn_template = "-".join([prefix_name, "term_embedding"])

    # region Embedding-related data

-    def save_vocab(self, data, data_folding):
-
-        target = self.__get_default_vocab_filepath(data_folding)
+    def save_vocab(self, data):
+        target = self.__get_default_vocab_filepath()
        return VocabRepositoryUtils.save(data=data, target=target)

-    def load_vocab(self, data_folding):
-        source = self.___get_vocab_source(data_folding)
+    def load_vocab(self):
+        source = self.___get_vocab_source()
        return dict(VocabRepositoryUtils.load(source))

-    def save_embedding(self, data, data_folding):
-
-        target = self.__get_default_embedding_filepath(data_folding)
+    def save_embedding(self, data):
+        target = self.__get_default_embedding_filepath()
        NpzEmbeddingHelper.save_embedding(data=data, target=target)

-    def load_embedding(self, data_folding):
-        source = self.__get_term_embedding_source(data_folding)
+    def load_embedding(self):
+        source = self.__get_term_embedding_source()
        return NpzEmbeddingHelper.load_embedding(source)

-    def check_targets_existed(self, data_folding):
+    def check_targets_existed(self):
        targets = [
-            self.__get_default_vocab_filepath(data_folding),
-            self.__get_term_embedding_target(data_folding)
+            self.__get_default_vocab_filepath(),
+            self.__get_term_embedding_target()
        ]
        return check_targets_existence(targets=targets)

@@ -55,26 +51,22 @@ class NpEmbeddingIO(BaseEmbeddingIO):

    # region embedding-related data

-    def ___get_vocab_source(self, data_folding):
+    def ___get_vocab_source(self):
        """ It is possible to load a predefined embedding from another experiment
            using the related filepath provided by model_io.
        """
-        return self.__get_default_vocab_filepath(data_folding)
+        return self.__get_default_vocab_filepath()

-    def __get_term_embedding_target(self, data_folding):
-        return self.__get_default_embedding_filepath(data_folding)
+    def __get_term_embedding_target(self):
+        return self.__get_default_embedding_filepath()

-    def __get_term_embedding_source(self, data_folding):
-        return self.__get_default_embedding_filepath(data_folding)
+    def __get_term_embedding_source(self):
+        return self.__get_default_embedding_filepath()

-    def __get_default_vocab_filepath(self, data_folding):
-        return join(self.__target_dir,
-                    self.VOCABULARY_FILENAME_TEMPLATE.format(
-                        cv_index=experiment_iter_index(data_folding)))
+    def __get_default_vocab_filepath(self):
+        return join(self.__target_dir, self.__vocab_fn_template)

-    def __get_default_embedding_filepath(self, data_folding):
-        return join(self.__target_dir,
-                    self.TERM_EMBEDDING_FILENAME_TEMPLATE.format(
-                        cv_index=experiment_iter_index(data_folding)) + '.npz')
+    def __get_default_embedding_filepath(self):
+        return join(self.__target_dir, self.__term_emb_fn_template)

    # endregion
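The `data_folding` argument disappears from the whole `NpEmbeddingIO` API; file names are now derived from a `prefix_name` rather than a cross-validation iteration index. A hedged before/after sketch (`matrix` and `vocab` stand for the data produced by the serializer):

```python
from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO

emb_io = NpEmbeddingIO(target_dir="out/", prefix_name="sample")

# 0.23.1 required a folding object:
#   emb_io.save_embedding(data=matrix, data_folding=folding)
# 0.25.0 derives file names from prefix_name instead:
emb_io.save_embedding(data=matrix)  # `matrix`, `vocab`: placeholders
emb_io.save_vocab(data=vocab)
```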
arekit/contrib/utils/io_utils/utils.py
@@ -1,35 +1,14 @@
-import collections
+from collections.abc import Iterable
 import logging
-from os.path import exists, join
-
-from arekit.common.experiment.data_type import DataType
-from arekit.common.folding.base import BaseDataFolding
-from arekit.contrib.utils.utils_folding import experiment_iter_index
+from os.path import exists


 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)


-def join_dir_with_subfolder_name(subfolder_name, dir):
-    """ Returns subfolder in in directory
-    """
-    assert(isinstance(subfolder_name, str))
-    assert(isinstance(dir, str))
-
-    target_dir = join(dir, "{}/".format(subfolder_name))
-    return target_dir
-
-
-def filename_template(data_type, data_folding):
-    assert(isinstance(data_type, DataType))
-    assert(isinstance(data_folding, BaseDataFolding))
-    return "{data_type}-{iter_index}".format(data_type=data_type.name.lower(),
-                                             iter_index=experiment_iter_index(data_folding))
-
-
 def check_targets_existence(targets):
-    assert (isinstance(targets, collections.Iterable))
+    assert (isinstance(targets, Iterable))

     result = True
     for filepath in targets:
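The `collections.Iterable` to `collections.abc.Iterable` switch is the standard Python 3.10 compatibility fix: the old alias was deprecated in 3.3 and removed in 3.10. A one-line illustration of why the hunk matters:

```python
from collections.abc import Iterable

isinstance([], Iterable)  # True on every Python 3 version;
                          # collections.Iterable raises AttributeError on 3.10+
```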
arekit/contrib/utils/pipelines/items/sampling/base.py
@@ -2,17 +2,13 @@ from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
 from arekit.common.data.storages.base import BaseRowsStorage
 from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
 from arekit.common.experiment.data_type import DataType
-from arekit.common.folding.base import BaseDataFolding
-from arekit.common.pipeline.base import BasePipeline
-from arekit.common.pipeline.context import PipelineContext
 from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.contrib.utils.serializer import InputDataSerializationHelper
-from arekit.contrib.utils.utils_folding import folding_iter_states


 class BaseSerializerPipelineItem(BasePipelineItem):

-    def __init__(self, rows_provider, samples_io, save_labels_func, balance_func, storage):
+    def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwargs):
         """ sample_rows_formatter:
                 how we format input texts for a BERT model, for example:
                 - single text
@@ -24,18 +20,20 @@ class BaseSerializerPipelineItem(BasePipelineItem):
         assert(isinstance(rows_provider, BaseSampleRowProvider))
         assert(isinstance(samples_io, BaseSamplesIO))
         assert(callable(save_labels_func))
-        assert(callable(balance_func))
         assert(isinstance(storage, BaseRowsStorage))
+        super(BaseSerializerPipelineItem, self).__init__(**kwargs)

         self._rows_provider = rows_provider
-        self._balance_func = balance_func
         self._samples_io = samples_io
         self._save_labels_func = save_labels_func
         self._storage = storage

-    def _serialize_iteration(self, data_type, pipeline, data_folding):
-        assert(isinstance(data_type, DataType))
-        assert(isinstance(pipeline, BasePipeline))
+    def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
+        assert(isinstance(data_type, DataType))
+        assert(isinstance(pipeline, list))
+        assert(isinstance(data_folding, dict) or data_folding is None)
+        assert(isinstance(doc_ids, list) or doc_ids is None)
+        assert(doc_ids is not None or data_folding is not None)

         repos = {
             "sample": InputDataSerializationHelper.create_samples_repo(
@@ -46,27 +44,36 @@ class BaseSerializerPipelineItem(BasePipelineItem):

         writer_and_targets = {
             "sample": (self._samples_io.Writer,
-                       self._samples_io.create_target(
-                           data_type=data_type, data_folding=data_folding)),
+                       self._samples_io.create_target(data_type=data_type)),
         }

         for description, repo in repos.items():
+
+            if data_folding is None:
+                # Consider only the predefined doc_ids.
+                doc_ids_iter = doc_ids
+            else:
+                # Take particular data_type.
+                doc_ids_iter = data_folding[data_type]
+                # Consider only predefined doc_ids.
+                if doc_ids is not None:
+                    doc_ids_iter = set(doc_ids_iter).intersection(doc_ids)
+
             InputDataSerializationHelper.fill_and_write(
                 repo=repo,
                 pipeline=pipeline,
-                doc_ids_iter=
-                do_balance=self._balance_func(data_type),
+                doc_ids_iter=doc_ids_iter,
                 desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
                 writer=writer_and_targets[description][0],
                 target=writer_and_targets[description][1])

-    def _handle_iteration(self, data_type_pipelines, data_folding):
+    def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
         """ Performing data serialization for a particular iteration
         """
         assert(isinstance(data_type_pipelines, dict))
-        assert(isinstance(data_folding, BaseDataFolding))
         for data_type, pipeline in data_type_pipelines.items():
-            self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding)
+            self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding,
+                                      doc_ids=doc_ids)

     def apply_core(self, input_data, pipeline_ctx):
         """
@@ -76,14 +83,12 @@ class BaseSerializerPipelineItem(BasePipelineItem):
                 DataType.Test: BasePipeline
             }

-
+            data_type_pipelines: doc_id -> parsed_doc -> annot -> opinion linkages
                 for example, function: sentiment_attitude_extraction_default_pipeline
+            doc_ids: optional
+                this parameter allows to limit the amount of documents considered for sampling
         """
-        assert(isinstance(pipeline_ctx, PipelineContext))
-
-
-
-        data_folding = pipeline_ctx.provide("data_folding")
-        for _ in folding_iter_states(data_folding):
-            self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
-                                   data_folding=data_folding)
+        assert("data_type_pipelines" in pipeline_ctx)
+        self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
+                               doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
+                               data_folding=pipeline_ctx.provide_or_none("data_folding"))
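With `BaseDataFolding` gone, `apply_core` now reads a plain dict folding and an optional `doc_ids` list from the pipeline context. A hedged sketch of the new call shape: `serializer_item` and the pipeline-item lists are placeholders, and `PipelineContext` is assumed to accept a parameter dict as in `arekit/common/pipeline/context.py`:

```python
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.context import PipelineContext

ctx = PipelineContext({
    # data_type -> list of pipeline items (asserted as `list` above).
    "data_type_pipelines": {DataType.Train: train_pipeline_items,
                            DataType.Test: test_pipeline_items},
    # A plain dict replaces BaseDataFolding: data_type -> doc ids.
    "data_folding": {DataType.Train: [0, 1, 2], DataType.Test: [3, 4]},
    # Optional: intersected with the folding per data type.
    "doc_ids": [0, 1, 3],
})
serializer_item.apply_core(input_data=None, pipeline_ctx=ctx)
```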
arekit/contrib/utils/pipelines/items/sampling/networks.py
@@ -1,4 +1,3 @@
-from arekit.common.folding.base import BaseDataFolding
 from arekit.contrib.networks.input.embedding.matrix import create_term_embedding_matrix
 from arekit.contrib.networks.input.embedding.offsets import TermsEmbeddingOffsets
 from arekit.contrib.networks.embedding import Embedding
@@ -9,8 +8,7 @@ from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem

 class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):

-    def __init__(self, save_labels_func, rows_provider, samples_io,
-                 emb_io, balance_func, storage, save_embedding=True):
+    def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage, save_embedding=True, **kwargs):
         """ This pipeline item allows to perform a data preparation for neural network models.

             considering a list of the whole data_types with the related pipelines,
@@ -25,23 +23,22 @@ class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):
             rows_provider=rows_provider,
             samples_io=samples_io,
             save_labels_func=save_labels_func,
-            balance_func=balance_func,
-            storage=storage)
+            storage=storage,
+            **kwargs)

         self.__emb_io = emb_io
         self.__save_embedding = save_embedding

-    def _handle_iteration(self, data_type_pipelines, data_folding):
+    def _handle_iteration(self, data_type_pipelines, data_folding, doc_ids):
         """ Performing data serialization for a particular iteration
         """
         assert(isinstance(data_type_pipelines, dict))
-        assert(isinstance(data_folding, BaseDataFolding))

         # Prepare for the present iteration.
         self._rows_provider.clear_embedding_pairs()

         super(NetworksInputSerializerPipelineItem, self)._handle_iteration(
-            data_type_pipelines=data_type_pipelines, data_folding=data_folding)
+            data_type_pipelines=data_type_pipelines, data_folding=data_folding, doc_ids=doc_ids)

         if not (self.__save_embedding and self._rows_provider.HasEmbeddingPairs):
             return
@@ -52,7 +49,7 @@ class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):
         vocab = list(TermsEmbeddingOffsets.extract_vocab(words_embedding=term_embedding))

         # Save embedding matrix
-        self.__emb_io.save_embedding(data=embedding_matrix, data_folding=data_folding)
-        self.__emb_io.save_vocab(data=vocab, data_folding=data_folding)
+        self.__emb_io.save_embedding(data=embedding_matrix)
+        self.__emb_io.save_vocab(data=vocab)

         del embedding_matrix
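The same migration applied to the networks serializer: `balance_func` is dropped and extra keyword arguments flow through to the base item. A hedged constructor sketch, with all argument values as placeholders:

```python
item = NetworksInputSerializerPipelineItem(
    save_labels_func=save_labels_func,  # 0.23.1 also took balance_func=...
    rows_provider=rows_provider,
    samples_io=samples_io,
    emb_io=emb_io,                      # NpEmbeddingIO from the hunk above
    storage=storage,
    save_embedding=True)
```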
arekit/contrib/utils/pipelines/items/text/entities_default.py
@@ -4,8 +4,8 @@ from arekit.common.pipeline.items.base import BasePipelineItem

 class TextEntitiesParser(BasePipelineItem):

-    def __init__(self):
-        super(TextEntitiesParser, self).__init__()
+    def __init__(self, **kwargs):
+        super(TextEntitiesParser, self).__init__(**kwargs)

     @staticmethod
     def __process_word(word):
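This and the following text-processing hunks repeat one pattern: every pipeline item now accepts `**kwargs` and forwards them to `BasePipelineItem.__init__` (extended by +38 lines in this release; its new keyword parameters are outside this diff). A custom item follows the same shape; a sketch under that assumption:

```python
from arekit.common.pipeline.items.base import BasePipelineItem


class MyCustomItem(BasePipelineItem):

    def __init__(self, **kwargs):
        # Forward keyword arguments to the extended BasePipelineItem.
        super(MyCustomItem, self).__init__(**kwargs)

    def apply_core(self, input_data, pipeline_ctx):
        # Identity transform; replace with the item's actual logic.
        return input_data
```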
arekit/contrib/utils/pipelines/items/text/frames.py
@@ -6,11 +6,10 @@ from arekit.common.pipeline.items.base import BasePipelineItem

 class FrameVariantsParser(BasePipelineItem):

-    def __init__(self, frame_variants):
+    def __init__(self, frame_variants, **kwargs):
         assert(isinstance(frame_variants, FrameVariantsCollection))
         assert(len(frame_variants) > 0)
-
-        super(FrameVariantsParser, self).__init__()
+        super(FrameVariantsParser, self).__init__(**kwargs)

         self.__frame_variants = frame_variants
         self.__max_variant_len = max([len(variant) for _, variant in frame_variants.iter_variants()])
arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py
@@ -5,10 +5,10 @@ from arekit.contrib.utils.processing.languages.ru.mods import RussianLanguageMods

 class LemmasBasedFrameVariantsParser(FrameVariantsParser):

-    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False):
+    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False, **kwargs):
         assert(isinstance(stemmer, Stemmer))
         assert(isinstance(save_lemmas, bool))
-        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants)
+        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants, **kwargs)

         self.__frame_variants = frame_variants
         self.__stemmer = stemmer
@@ -24,7 +24,7 @@ class LemmasBasedFrameVariantsParser(FrameVariantsParser):

     def __provide_lemmatized_terms(self, terms):
         """
-        Compose a list of lemmatized versions of
+        Compose a list of lemmatized versions of parsed_doc
         PS: Might be significantly slow, depending on stemmer were used.
         """
         assert(isinstance(terms, list))
arekit/contrib/utils/pipelines/items/text/frames_negation.py
@@ -7,8 +7,9 @@ from arekit.contrib.utils.processing.languages.ru.mods import RussianLanguageMods

 class FrameVariantsSentimentNegation(BasePipelineItem):

-    def __init__(self, locale_mods=RussianLanguageMods):
+    def __init__(self, locale_mods=RussianLanguageMods, **kwargs):
         assert(issubclass(locale_mods, BaseLanguageMods))
+        super(FrameVariantsSentimentNegation, self).__init__(**kwargs)
         self._locale_mods = locale_mods

     @staticmethod
arekit/contrib/utils/pipelines/items/text/tokenizer.py
@@ -1,7 +1,6 @@
 import logging

 from arekit.common.context.token import Token
-from arekit.common.pipeline.context import PipelineContext
 from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.common.utils import split_by_whitespaces
 from arekit.contrib.utils.processing.text.tokens import Tokens
@@ -14,14 +13,13 @@ class DefaultTextTokenizer(BasePipelineItem):
     """ Default parser implementation.
     """

-    def __init__(self, keep_tokens=True):
-        super(DefaultTextTokenizer, self).__init__()
+    def __init__(self, keep_tokens=True, **kwargs):
+        super(DefaultTextTokenizer, self).__init__(**kwargs)
         self.__keep_tokens = keep_tokens

     # region protected methods

     def apply_core(self, input_data, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
         output_data = self.__process_parts(input_data)
         if not self.__keep_tokens:
             output_data = [word for word in output_data if not isinstance(word, Token)]
@@ -60,7 +58,7 @@ class DefaultTextTokenizer(BasePipelineItem):
     @staticmethod
     def __split_tokens(term):
         """
-        Splitting off tokens from
+        Splitting off tokens from parsed_doc ending, i.e. for example:
            term: "сказать,-" -> "(term: "сказать", ["COMMA_TOKEN", "DASH_TOKEN"])
        return: (unicode or None, list)
            modified term and list of extracted tokens.
arekit/contrib/utils/pipelines/items/text/translator.py
@@ -0,0 +1,136 @@
+from arekit.common.data.input.providers.const import IDLE_MODE
+from arekit.common.pipeline.conts import PARENT_CTX
+from arekit.common.entities.base import Entity
+from arekit.common.pipeline.context import PipelineContext
+from arekit.common.pipeline.items.base import BasePipelineItem
+
+
+class MLTextTranslatorPipelineItem(BasePipelineItem):
+    """ Machine learning based translator pipeline item.
+    """
+
+    def __init__(self, batch_translate_model, do_translate_entity=True, **kwargs):
+        """ Model, which is based on translation of the text,
+            represented as a list of words.
+        """
+        super(MLTextTranslatorPipelineItem, self).__init__(**kwargs)
+        self.__do_translate_entity = do_translate_entity
+        self.__translate = batch_translate_model
+
+    def fast_most_accurate_approach(self, input_data, entity_placeholder_template="<entityTag={}/>"):
+        """ This approach assumes that the translation won't corrupt the original
+            meta-annotation for entities and objects mentioned in text.
+        """
+
+        def __optionally_register(prts):
+            if len(prts) > 0:
+                content.append(" ".join(prts))
+                parts_to_join.clear()
+
+        content = []
+        origin_entities = []
+        parts_to_join = []
+
+        for part in input_data:
+            if isinstance(part, str) and part.strip():
+                parts_to_join.append(part)
+            elif isinstance(part, Entity):
+                entity_index = len(origin_entities)
+                parts_to_join.append(entity_placeholder_template.format(entity_index))
+                # Register entities information for further restoration.
+                origin_entities.append(part)
+
+        # Register original text with masked named entities.
+        __optionally_register(parts_to_join)
+        # Register all named entities in order of their appearance in text.
+        content.extend([e.Value for e in origin_entities])
+
+        # Compose text parts.
+        translated_parts = self.__translate(content)
+
+        if len(translated_parts) == 0:
+            return None
+
+        # Take the original text.
+        text = translated_parts[0]
+        for entity_index in range(len(origin_entities)):
+            if entity_placeholder_template.format(entity_index) not in text:
+                return None
+
+        # Enumerate entities.
+        from_ind = 0
+        text_parts = []
+        for entity_index, translated_value in enumerate(translated_parts[1:]):
+            entity_placeholder_instance = entity_placeholder_template.format(entity_index)
+            # Cropping text part.
+            to_ind = text.index(entity_placeholder_instance)
+
+            if self.__do_translate_entity:
+                origin_entities[entity_index].set_display_value(translated_value.strip())
+
+            # Register entities.
+            text_parts.append(text[from_ind:to_ind])
+            text_parts.append(origin_entities[entity_index])
+            # Update from index.
+            from_ind = to_ind + len(entity_placeholder_instance)
+
+        # Consider the remaining part.
+        text_parts.append(text[from_ind:])
+        return text_parts
+
+    def default_pre_part_splitting_approach(self, input_data):
+        """ This is the original strategy, based on the manually cropped named entities
+            before the actual translation call.
+        """
+
+        def __optionally_register(prts):
+            if len(prts) > 0:
+                content.append(" ".join(prts))
+                parts_to_join.clear()
+
+        content = []
+        origin_entities = []
+        origin_entity_ind = []
+        parts_to_join = []
+
+        for _, part in enumerate(input_data):
+            if isinstance(part, str) and part.strip():
+                parts_to_join.append(part)
+            elif isinstance(part, Entity):
+                # First, register the prior parts that were merged.
+                __optionally_register(parts_to_join)
+                # Register entities information for further restoration.
+                origin_entity_ind.append(len(content))
+                origin_entities.append(part)
+                content.append(part.Value)
+
+        __optionally_register(parts_to_join)
+
+        # Compose text parts.
+        translated_parts = self.__translate(content)
+
+        for entity_ind, entity_part_ind in enumerate(origin_entity_ind):
+            entity = origin_entities[entity_ind]
+            if self.__do_translate_entity:
+                entity.set_display_value(translated_parts[entity_part_ind].strip())
+            translated_parts[entity_part_ind] = entity
+
+        return translated_parts
+
+    def apply_core(self, input_data, pipeline_ctx):
+        assert(isinstance(pipeline_ctx, PipelineContext))
+        assert(isinstance(input_data, list))
+
+        # Check whether the pipeline state is in idle mode or not.
+        parent_ctx = pipeline_ctx.provide(PARENT_CTX)
+        idle_mode = parent_ctx.provide(IDLE_MODE)
+
+        # When the pipeline is utilized only for assessing the expected amount
+        # of rows (a common case of idle_mode), there is no need to perform
+        # translation.
+        if idle_mode:
+            return
+
+        fast_accurate = self.fast_most_accurate_approach(input_data)
+        return self.default_pre_part_splitting_approach(input_data) \
+            if fast_accurate is None else fast_accurate
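A hedged sketch of wiring the new translator item into a text pipeline: `batch_translate_model` is any callable mapping a list of strings to a list of translated strings, and the identity stand-in below is illustrative only:

```python
from arekit.contrib.utils.pipelines.items.text.translator import MLTextTranslatorPipelineItem


def translate_batch(texts):
    # Identity stand-in; substitute any batch seq2seq translation model.
    return list(texts)


item = MLTextTranslatorPipelineItem(batch_translate_model=translate_batch,
                                    do_translate_entity=True)
# The item first tries fast_most_accurate_approach (one placeholder-based call),
# then falls back to default_pre_part_splitting_approach when placeholders
# get corrupted by the translation.
```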